/*
 * Created on 23-Apr-2006
 * Martin Robinson
 * This software is released and may be distributed under the GNU GPL license
 * See: http://www.gnu.org/licenses/gpl.txt
 */
package uk.co.miajo.NN;
import java.io.*;
import java.util.*;

/**
 * An ordered collection of {@link TrainingPattern}s that all have the same
 * number of inputs and outputs.
 * <p>
 * File reading and writing each run on a newly started background thread and
 * report completion or failure through a {@link Listener}; training runs on the
 * caller's thread.
 * <p>
 * Thread-safety: the pattern list is not synchronized. A background
 * {@link #read} appends to it directly, so callers should not modify or train
 * from the collection until {@code nnDoneReadPat} has been delivered.
 * <p>
 * File format (as produced by {@link #write}): a header line holding the input
 * and output counts as two integers, then, for each pattern, a line containing
 * {@code pattern} followed by whatever {@code TrainingPattern.write} emits. On
 * the header and {@code pattern} marker lines, anything after {@code #} is
 * ignored as a comment.
 */
public class TrainingPatternCollection {
    private final int numinputs, numoutputs;
    private List patterns;
    /** Path of the last successful read or write; written by worker threads, hence volatile. */
    private volatile String lastpath = "Untitled.txt";
    
    /**
     * Creates an empty collection for patterns of the given dimensions.
     *
     * @param _numinputs  number of inputs every pattern must have
     * @param _numoutputs number of outputs every pattern must have
     */
    public TrainingPatternCollection(int _numinputs, int _numoutputs)
    {
        numinputs = _numinputs;
        numoutputs = _numoutputs;
        
        // Initialise directly instead of calling the overridable clear().
        patterns = new ArrayList();
    }
    
    /**
     * Returns the path of the most recent successful read or write, or
     * {@code "Untitled.txt"} if none has completed yet.
     */
    public String getLastPath()
    {
        return lastpath;
    }
    
    /**
     * Asynchronously writes the patterns to a file, including comments.
     * Equivalent to {@code write(_filepath, true, callback)}.
     */
    public void write(String _filepath, final Listener callback)
    {
        this.write(_filepath, true, callback);
    }
    
    /**
     * Asynchronously writes the patterns to a file on a background thread.
     * <p>
     * The patterns are snapshotted on the calling thread, so changes made to the
     * collection after this call do not affect the file being written.
     *
     * @param _filepath  destination file, or {@code null} to reuse the last path
     * @param doComments whether to append explanatory {@code #} comments
     * @param callback   receives {@code nnDoneWritePat(path)} on success, or
     *                   {@code nnError(msg)} followed by {@code nnDoneWritePat(null)}
     *                   on failure; may be {@code null}
     */
    public void write(String _filepath, final boolean doComments, final Listener callback)
    {
        final String filepath = (_filepath == null) ? lastpath : _filepath;
        final TrainingPattern[] snapshot = this.getArray();
        
        Thread t = new Thread() 
        {
            public void run() 
            {
                // Callbacks run outside the I/O try block so that an exception thrown
                // by the listener is not misreported as a write failure.
                if(writePatterns(filepath, snapshot, doComments, callback)) {
                    lastpath = filepath;
                    if(callback != null) callback.nnDoneWritePat(filepath);
                } else if(callback != null) {
                    callback.nnError("pattern write error: "+filepath);
                    callback.nnDoneWritePat(null);
                }
            }
        };
        t.start();
    }
    
    /**
     * Writes the header line and every pattern in {@code toWrite} to {@code filepath}.
     * The file is always closed, even on failure.
     *
     * @return true if the file was fully written and closed, false on any error
     */
    private boolean writePatterns(String filepath, TrainingPattern[] toWrite,
                                  boolean doComments, Listener callback)
    {
        BufferedWriter buffer = null;
        try {
            buffer = new BufferedWriter(new FileWriter(filepath));
            
            buffer.write(numinputs+" "+numoutputs);
            if(doComments) buffer.write("\t\t\t\t# Network inputs/outputs");
            buffer.write("\n");
            
            for(int i = 0; i < toWrite.length; i++) {
                buffer.write("pattern");
                if(doComments) buffer.write("\t\t\t\t# header (pattern "+i+")");
                buffer.write("\n");
                toWrite[i].write(buffer, doComments, callback);
            }
            
            // close() flushes, so a failure here means the file is incomplete
            // and must be reported rather than ignored.
            buffer.close();
            buffer = null;
            return true;
        }
        catch(Exception e) {
            return false;
        }
        finally {
            // Only non-null here if something above failed; release the handle.
            if(buffer != null) {
                try { buffer.close(); } catch(IOException ignored) { /* already failing */ }
            }
        }
    }
    
    /**
     * Asynchronously rewrites the file last read or written.
     * Equivalent to {@code write(getLastPath(), callback)}.
     */
    public void writeagain(final Listener callback)
    {
        this.write(lastpath, callback);
    }
    
    /**
     * Asynchronously reads patterns from a file on a background thread and
     * appends them to this collection (existing patterns are kept).
     * <p>
     * The file's header must declare the same input and output counts as this
     * collection. If an error occurs part-way through, patterns parsed before
     * the error remain in the collection.
     *
     * @param filepath file to read
     * @param callback receives {@code nnDoneReadPat(path)} on success, or
     *                 {@code nnError(msg)} followed by {@code nnDoneReadPat(null)}
     *                 on failure; may be {@code null}
     */
    public void read(final String filepath, final Listener callback)
    {
        Thread t = new Thread() 
        {
            public void run() 
            {
                // As in write(): report outside the I/O try block so that listener
                // exceptions are not turned into a spurious read error.
                String error = readPatterns(filepath, callback);
                if(error == null) {
                    lastpath = filepath;
                    if(callback != null) callback.nnDoneReadPat(filepath);
                } else if(callback != null) {
                    callback.nnError(error);
                    callback.nnDoneReadPat(null);
                }
            }
        };
        t.start();
    }
    
    /**
     * Parses the pattern file at {@code filepath} and appends its patterns to
     * this collection. The file is always closed, even on failure.
     *
     * @return null on success, otherwise the error message to report
     */
    private String readPatterns(String filepath, Listener callback)
    {
        BufferedReader buffer = null;
        try {
            buffer = new BufferedReader(new FileReader(filepath));
            
            // readLine() == null is the reliable end-of-stream test; ready() only
            // says whether a read would block.
            String line = buffer.readLine();
            if(line == null) return "pattern read error: "+filepath;
            
            int patArgs[] = Utility.parseIntArray(stripComment(line));
            
            if(patArgs.length != 2)
                return "pattern read: network inputs/outputs must be two ints: "+filepath;
            if(patArgs[0] != numinputs)
                return "pattern read: input layer doesn't match: "+filepath;
            if(patArgs[1] != numoutputs)
                return "pattern read: output layer doesn't match: "+filepath;
            
            while((line = buffer.readLine()) != null) {
                if(stripComment(line).indexOf("pattern") >= 0) {
                    // The TrainingPattern constructor consumes its own data lines
                    // from the same reader.
                    this.add(new TrainingPattern(buffer, numinputs, numoutputs, callback));
                }
            }
            return null;
        }
        catch(Exception e) {
            return "pattern read error: "+filepath;
        }
        finally {
            if(buffer != null) {
                try { buffer.close(); } catch(IOException ignored) { /* nothing useful to do */ }
            }
        }
    }
    
    /** Returns {@code line} with everything from the first {@code '#'} onward removed. */
    private static String stripComment(String line)
    {
        int comm = line.indexOf('#');
        return (comm >= 0) ? line.substring(0, comm) : line;
    }
    
    /**
     * Asynchronously re-reads the file last read or written.
     * Equivalent to {@code read(getLastPath(), callback)}.
     */
    public void readagain(final Listener callback)
    {
        this.read(lastpath, callback);
    }
    
    /** Returns the number of patterns currently held. */
    public int numPatterns()
    {
        return patterns.size();
    }
    
    /** Returns the number of inputs every pattern in this collection has. */
    public int numInputs()
    {
        return numinputs;
    }
    
    /** Returns the number of outputs every pattern in this collection has. */
    public int numOutputs()
    {
        return numoutputs;
    }
    
    /** Removes all patterns by replacing the backing list with a new empty one. */
    public void clear()
    {
        patterns = new ArrayList();
    }
    
    /**
     * Appends a pattern.
     *
     * @param p pattern to add
     * @throws NNException if {@code p}'s input or output count differs from this collection's
     */
    public void add(TrainingPattern p) throws NNException
    {
        if(numinputs != p.numInputs()) throw new NNException("input size doesn't match");
        if(numoutputs != p.numOutputs()) throw new NNException("output size doesn't match");
        patterns.add(p);
    }
    
    /**
     * Appends a new pattern built from the given input and target output values.
     *
     * @param inputs  input values; length must equal {@link #numInputs()}
     * @param outputs target output values; length must equal {@link #numOutputs()}
     * @throws NNException if either array has the wrong length
     */
    public void add(double inputs[], double outputs[]) throws NNException
    {
        if(numinputs != inputs.length) throw new NNException("input size doesn't match");
        if(numoutputs != outputs.length) throw new NNException("output size doesn't match");
        patterns.add(new TrainingPattern(inputs, outputs));
    }
    
    /**
     * Trains {@code network} on every pattern, {@code n} times over (at least once).
     * Equivalent to {@code train(network, n, null)}.
     */
    public void train(Network network, int n)
    {
        this.train(network, n, null);
    }
    
    /**
     * Synchronously trains {@code network} by presenting every pattern in order,
     * repeating the whole pass {@code n} times (values below 1 are treated as 1).
     *
     * @param network  network to train
     * @param n        number of passes over the patterns
     * @param callback if non-null, receives {@code nnDoneTrain()} after the last pass
     */
    public void train(Network network, int n, Listener callback)
    {
        TrainingPattern patternArray[] = this.getArray();
        
        if(n < 1) n = 1;
        
        for(int e = 0; e < n; e++)
        {
            for(int i = 0; i < patternArray.length; i++)
            {
                patternArray[i].train(network);
            }
        }
        
        if(callback != null) callback.nnDoneTrain();
    }
    
    /** Returns a snapshot of the patterns as an array, in insertion order. */
    private TrainingPattern[] getArray()
    {
        return (TrainingPattern[])patterns.toArray(new TrainingPattern[patterns.size()]);
    }
    
}
