/*
 * Created on 23-Apr-2006
 * Martin Robinson
 * This software is released and may be distributed under the GNU GPL license
 * See: http://www.gnu.org/licenses/gpl.txt
 */
package uk.co.miajo.NN;
import java.io.*;

/**
 * A layered neural network built from a chain of {@link Layer}s, where layer
 * {@code i} maps {@code numNodes[i]} inputs to {@code numNodes[i+1]} outputs.
 * <p>
 * File I/O ({@link #write}, {@link #read}) and {@link #train} each start a new
 * background thread and report back through a {@link Listener}. No state in this
 * class is synchronized, so callers should not use the network while such an
 * operation is still running.
 */
public class Network {
    private double  learnrate, 
                    actfuncoff;
    private Layer   layers[];
    // Written by the background write/read threads and read from the caller's
    // thread (getLastPath, writeagain, readagain), hence volatile.
    private volatile String lastpath = "Untitled.txt";
    
    /**
     * Creates a network with the given layer sizes and training parameters.
     *
     * @param numNodes    node count of every layer, input layer first; needs at
     *                    least two entries. Counts below 1 are raised to 1.
     * @param _learnrate  learning rate passed to each layer during back-propagation
     * @param _actfuncoff activation-function offset passed to each layer during
     *                    back-propagation
     * @throws IllegalArgumentException if {@code numNodes} has fewer than two entries
     */
    public Network(int numNodes[], double _learnrate, double _actfuncoff)
    {
        try {
            checkNumNodes(numNodes.length);
        }
        catch(NNException e) {
            // Fail fast rather than leaving layers == null, which would make every
            // later call on this object throw a NullPointerException.
            throw new IllegalArgumentException(e.getMessage(), e);
        }
        
        learnrate   = _learnrate;
        actfuncoff  = _actfuncoff;
        
        layers = new Layer[numNodes.length-1];
        
        for(int i = 0; i < layers.length; i++)
        {
            int numinputs  = Math.max(1, numNodes[i]);
            int numoutputs = Math.max(1, numNodes[i+1]);
            
            layers[i] = new Layer(numoutputs, numinputs);
        }
    }
    
    /**
     * Creates a network using {@code Utility.defaultLearnRate} and
     * {@code Utility.defaultActFuncOffset}.
     *
     * @param numNodes node count of every layer, input layer first (at least two entries)
     * @throws IllegalArgumentException if {@code numNodes} has fewer than two entries
     */
    public Network(int numNodes[])
    {
        this(numNodes, Utility.defaultLearnRate, Utility.defaultActFuncOffset);
    }
    
    /**
     * @return the number of {@link Layer} objects, i.e. {@code numNodes.length - 1}
     */
    public int numLayers()
    {
        return layers.length;
    }
    
    /**
     * @return the path of the last successful read or write, or
     *         {@code "Untitled.txt"} if none has completed yet
     */
    public String getLastPath()
    {
        return lastpath;
    }
    
    /**
     * Writes the network to a file with comments enabled; see
     * {@link #write(String, boolean, Listener)}.
     */
    public void write(String _filepath, final Listener callback)
    {
        this.write(_filepath, true, callback);
    }
    
    /**
     * Writes the network to {@code _filepath} on a new background thread.
     * <p>
     * File layout:
     * <ol>
     *   <li>the input count followed by each layer's output count, space separated;</li>
     *   <li>the learning rate and activation-function offset;</li>
     *   <li>each layer's own data, as produced by {@code Layer.write}.</li>
     * </ol>
     * When {@code doComments} is true, {@code #} comments are appended to the
     * header lines and each layer gets a {@code "Layer i"} comment.
     *
     * @param _filepath  destination file; {@code null} means reuse {@link #getLastPath()}
     * @param doComments whether to emit {@code #} comments
     * @param callback   notified with {@code nnDoneWriteNet(path)} on success, or
     *                   {@code nnError(...)} then {@code nnDoneWriteNet(null)} on
     *                   failure; may be {@code null}
     */
    public void write(String _filepath, final boolean doComments, final Listener callback)
    {
        final String filepath = (_filepath == null) ? lastpath : _filepath;
        
        Thread t = new Thread() 
        {
            public void run() 
            {
                BufferedWriter buffer = null;
                try {
                    buffer = new BufferedWriter(new FileWriter(filepath));
                    
                    buffer.write(layers[0].numInputs()+" ");
                    for(int i = 0; i < layers.length; i++) {
                        buffer.write(layers[i].numOutputs()+" ");
                    }
                    if(doComments) buffer.write("\t\t\t# Network structure");
                    buffer.write("\n");
                    
                    buffer.write(learnrate+" "+actfuncoff);
                    if(doComments) buffer.write("\t\t# Network parameters");
                    buffer.write("\n");
                    
                    for(int i = 0; i < layers.length; i++) {
                        String comment = doComments ? "Layer "+i : null;
                        layers[i].write(comment, buffer, callback);
                    }
                    
                    // Close inside the try so a failed final flush counts as a write error.
                    buffer.close();
                    lastpath = filepath;
                    
                    if(callback != null) callback.nnDoneWriteNet(filepath);
                }
                catch(Exception e)
                {
                    if(callback != null) {
                        callback.nnError("network write error: "+filepath);
                        callback.nnDoneWriteNet(null);
                    }
                }
                finally {
                    // Releases the file handle when anything above failed;
                    // closing an already-closed writer has no effect.
                    closeQuietly(buffer);
                }
            }
        };
        t.start();
    }
    
    /**
     * Writes the network (with comments) to {@link #getLastPath()}.
     */
    public void writeagain(final Listener callback)
    {
        this.write(lastpath, callback);
    }
    
    /**
     * Reads network parameters and layer data from {@code filepath} on a new
     * background thread. The file's layer sizes must match this network's.
     * On success the learning rate, activation-function offset and every layer's
     * data are replaced, and {@link #getLastPath()} becomes {@code filepath}.
     *
     * @param filepath file in the format produced by {@link #write(String, boolean, Listener)}
     * @param callback notified with {@code nnDoneReadNet(path)} on success, or
     *                 {@code nnError(...)} then {@code nnDoneReadNet(null)} on
     *                 failure; may be {@code null}
     */
    public void read(final String filepath, final Listener callback)
    {
        Thread t = new Thread() 
        {
            public void run() 
            {
                BufferedReader buffer = null;
                try {
                    buffer = new BufferedReader(new FileReader(filepath));
                    String problem = readContents(buffer, callback);
                    buffer.close();
                    
                    if(problem != null) {
                        reportReadFailure(callback, problem);
                        return;
                    }
                    
                    lastpath = filepath;
                    
                    if(callback != null) callback.nnDoneReadNet(filepath);
                }
                catch(Exception e)
                {
                    reportReadFailure(callback, "network read error");
                }
                finally {
                    // Releases the file handle when anything above failed;
                    // closing an already-closed reader has no effect.
                    closeQuietly(buffer);
                }
            }
        };
        t.start();
    }
    
    /**
     * Parses the contents of a network file into this network.
     *
     * @return {@code null} on success, otherwise the error message to report
     * @throws Exception any I/O or parse failure. It is declared broadly because
     *         the caller treats every thrown failure as a generic read error.
     */
    private String readContents(BufferedReader buffer, Listener callback) throws Exception
    {
        if(!buffer.ready()) return "network read error";
        
        int structure[] = Utility.parseIntArray(readUncommentedLine(buffer));
        
        if(structure.length != numLayers()+1) {
            return "network read: numLayers doesn't match";
        }
        if(structure[0] != layers[0].numInputs()) {
            return "network read: input layer doesn't match";
        }
        for(int i = 1; i < structure.length; i++) {
            if(structure[i] != layers[i-1].numOutputs()) {
                return "network read: layer "+i+" doesn't match";
            }
        }
        
        double params[] = Utility.parseDoubleArray(readUncommentedLine(buffer));
        
        if(params.length != 2) {
            return "network read: network paramters must be two floats";
        }
        
        learnrate  = params[0];
        actfuncoff = params[1];
        
        for(int i = 0; i < layers.length; i++) {
            layers[i].read(buffer, callback);
        }
        return null;
    }
    
    /**
     * Reads one line and drops everything from the first {@code '#'} onwards.
     *
     * @throws EOFException if the file ends before the expected line
     */
    private static String readUncommentedLine(BufferedReader buffer) throws IOException
    {
        String line = buffer.readLine();
        if(line == null) throw new EOFException("unexpected end of network file");
        int comm = line.indexOf('#');
        return (comm >= 0) ? line.substring(0, comm) : line;
    }
    
    /** Sends the failure notifications for a read, if a callback is present. */
    private static void reportReadFailure(Listener callback, String message)
    {
        if(callback != null) {
            callback.nnError(message);
            callback.nnDoneReadNet(null);
        }
    }
    
    /** Closes {@code c} if non-null, ignoring failures (cleanup only). */
    private static void closeQuietly(Closeable c)
    {
        if(c == null) return;
        try {
            c.close();
        }
        catch(IOException ignored) {
            // Cleanup path: the operation's outcome is reported elsewhere.
        }
    }
    
    /**
     * Reads the network from {@link #getLastPath()}.
     */
    public void readagain(final Listener callback)
    {
        this.read(lastpath, callback);
    }
    
    /** Rejects layer-size arrays with fewer than two entries (input plus one output). */
    private void checkNumNodes(int n) throws NNException
    {
        if(n < 2) throw new NNException("minimum number of layers is 2");
    }
    
    /** @return the input count of the first layer */
    public int numInputs()
    {
        return layers[0].numInputs();
    }
    
    /** @return the output count of the last layer */
    public int numOutputs()
    {
        return layers[layers.length-1].numOutputs();
    }
    
    /** Initialises every layer with {@code Utility.defaultWeight}. */
    public void init()
    {
        this.init(Utility.defaultWeight);
    }
    
    /**
     * Calls {@code init(w)} on every layer.
     *
     * @param w value forwarded to {@code Layer.init}. It is presumably an initial
     *          weight or weight range; TODO confirm against Layer.
     */
    public void init(double w)
    {
        for(int i = 0; i < layers.length; i++)
        {
            layers[i].init(w);
        }
    }
    
    /**
     * Feeds {@code inputVector} through every layer in order.
     *
     * @return the output vector of the last layer
     */
    public double[] propogate(double inputVector[])
    {
        double vector[] = inputVector;
        
        // traverse the layers forwards to propagate the input
        for(int i = 0; i < layers.length; i++)
        {
            vector = layers[i].propogate(vector);
        }
        
        return vector;
    }
    
    /**
     * Runs one training step. It propagates {@code inputVector}, forms
     * {@code target - output}, and passes that error backwards through the layers
     * using the current learning rate and activation-function offset.
     *
     * @return the error vector returned by the first layer's {@code backProp}
     */
    public double[] backProp(double inputVector[], double targetVector[])
    {   
        double outputVector[] = propogate(inputVector);
        double errorVector[] = Utility.subtractVectors(targetVector, outputVector);
        
        // traverse the layers backwards to propagate the error
        for(int i = layers.length-1; i >= 0; i--)
        {
            errorVector = layers[i].backProp(errorVector, actfuncoff, learnrate);
        }
        
        return errorVector;
    }
    
    /**
     * Starts a background thread that calls {@code patterns.train(this, n, callback)}.
     *
     * @param patterns training set whose input/output sizes must match this network
     * @param n        forwarded unchanged to {@code TrainingPatternCollection.train}
     * @param callback forwarded unchanged to {@code TrainingPatternCollection.train}
     * @throws NNException if the pattern sizes do not match the network's
     *         input or output size
     */
    public void train(final TrainingPatternCollection patterns, final int n, final Listener callback) throws NNException
    {
        if(patterns.numInputs() != this.numInputs()) throw new NNException("input size doesn't match");
        if(patterns.numOutputs() != this.numOutputs()) throw new NNException("output size doesn't match");
        
        Thread t = new Thread() 
        {
            public void run() 
            {
                patterns.train(Network.this, n, callback);
            }
        };
        t.start();
    }
    
}
