/*
 * Decompiled with CFR 0.152.
 */
package pty.learn;

import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Option;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import ma.SequenceType;
import nuts.math.RateMtxUtils;
import nuts.tui.Table;
import nuts.util.CollUtils;
import nuts.util.Indexer;
import nuts.util.MathUtils;
import pty.io.Dataset;
import pty.io.WalsDataset;
import pty.io.WalsProcessingScript;
import pty.smc.models.CTMC;
import pty.smc.models.CTMCUtils;

/**
 * Builds the {@link CTMC} model (one or more rate matrices) used by inference, using the strategy
 * selected by {@link #loadingMethod}.
 *
 * <p>Typical use: configure the {@link Option} fields, call {@link #setData(Dataset)}, then call
 * {@link #load()}.
 */
public class CTMCLoader {
    /** Strategy used by {@link #load()} to build the CTMC. */
    @Option
    public LoadingMethod loadingMethod = LoadingMethod.HEURISTIC_ESTIMATE;
    /** Serialized CTMC file read by {@link LoadingMethod#FILE}. */
    @Option
    public String file = "init.CTMC";
    /** Sequence type passed to {@code SimpleCTMC.fromSequenceType} by {@link LoadingMethod#BUILT_IN}. */
    @Option
    public SequenceType builtInSequenceType = SequenceType.BINARY;
    /**
     * Rate parameter forwarded to {@code RateMtxUtils.reversibleRateMtx} (HEURISTIC_ESTIMATE) and
     * {@code SimpleCTMC.fromSequenceType} (BUILT_IN). SCRIPT and FILE do not use it.
     */
    @Option
    public double rate = 1.0;
    /**
     * When true, one rate matrix is built per site. Only HEURISTIC_ESTIMATE supports this.
     * BUILT_IN and SCRIPT reject it, and FILE ignores it.
     */
    @Option
    public boolean siteSpecific = false;
    /**
     * When true, the site-specific heuristic estimate uses uniform stationary distributions
     * instead of empirical ones. This is rejected by the global (non-site-specific) heuristic
     * estimate.
     */
    @Option
    public boolean forceUniform = false;
    /** Dataset the CTMC is built for. Required by every loading method except FILE. */
    public Dataset data;

    /**
     * Builds the CTMC using the configured {@link #loadingMethod}.
     *
     * @return the loaded or estimated CTMC
     */
    public CTMC load() {
        return this.loadingMethod.load(this);
    }

    /**
     * Sets the dataset used by {@link #load()}.
     *
     * @param data dataset whose sites and observations drive the estimate
     */
    public void setData(Dataset data) {
        this.data = data;
    }

    /**
     * Tests whether two characters are neighbours in the ordering given by their
     * {@code BioCharacter.index}. The test passes when no character in {@code characterIndexer}
     * has an index strictly between theirs. The argument order does not matter.
     *
     * @param characterIndexer indexer over the characters of one site
     * @param c1 indexer position of the first character
     * @param c2 indexer position of the second character
     * @return {@code true} if no other character's index lies strictly between the two
     */
    public static boolean linked(Indexer<WalsDataset.BioCharacter> characterIndexer, int c1, int c2) {
        int first = characterIndexer.i2o(c1).index;
        int second = characterIndexer.i2o(c2).index;
        int lo = Math.min(first, second);
        int hi = Math.max(first, second);
        for (int c3 = 0; c3 < characterIndexer.size(); ++c3) {
            int candidate = characterIndexer.i2o(c3).index;
            if (lo < candidate && candidate < hi) {
                return false;
            }
        }
        return true;
    }

    /**
     * Estimates a CTMC from the empirical character frequencies in {@link #data}.
     *
     * <p>If {@link #siteSpecific} is set, one matrix is built per site with
     * {@code RateMtxUtils.reversibleRateMtx}, using that site's stationary distribution.
     * Otherwise a single matrix is built from the distribution pooled over all sites, and it is
     * passed together with {@code nSites()} to {@code SimpleCTMC}.
     */
    private CTMC getHeuristicEstimate() {
        if (this.siteSpecific) {
            ArrayList<double[][]> rateMatrices = new ArrayList<double[][]>();
            for (double[] siteDistn : this.statDistn(this.data.observations())) {
                rateMatrices.add(RateMtxUtils.reversibleRateMtx(this.rate, siteDistn));
            }
            return new CTMC.GeneralCTMC(rateMatrices);
        }
        double[] sd = this.globalStatDistn(this.data.observations());
        double[][] rateMtx = RateMtxUtils.reversibleRateMtx(this.rate, sd);
        LogInfo.logs("Estimated stat dist:" + Arrays.toString(sd));
        LogInfo.logs("Estimated rate matrix from stat dist:\n" + Table.toString(rateMtx));
        return new CTMC.SimpleCTMC(rateMtx, this.data.nSites());
    }

    /**
     * Computes one stationary distribution per site.
     *
     * <p>If {@link #forceUniform} is set, each site gets a uniform distribution. Its size is taken
     * from the first taxon's observation array for that site. Otherwise the observation counts
     * from all taxa are summed, one pseudo-count is added per character, and the result is
     * normalized.
     *
     * @param observations per-taxon observations, indexed {@code [site][character]}
     * @return array of per-site distributions, indexed {@code [site][character]}
     */
    private double[][] statDistn(Map<Taxon, double[][]> observations) {
        int nSites = this.data.nSites();
        double[][] result = new double[nSites][];
        if (this.forceUniform) {
            double[][] anObsArray = observations.values().iterator().next();
            for (int s = 0; s < nSites; ++s) {
                int nChars = anObsArray[s].length;
                result[s] = new double[nChars];
                Arrays.fill(result[s], 1.0 / (double) nChars);
            }
            return result;
        }
        for (int s = 0; s < nSites; ++s) {
            int nChars = this.data.nCharacter(s);
            double[] counts = new double[nChars];
            for (double[][] taxonObs : observations.values()) {
                this.process(taxonObs[s], counts);
            }
            // Add-one smoothing: every character keeps strictly positive mass after normalizing,
            // even if it was never observed at this site.
            for (int c = 0; c < nChars; ++c) {
                counts[c] += 1.0;
            }
            NumUtils.normalize(counts);
            result[s] = counts;
        }
        return result;
    }

    /**
     * Adds one observation vector into a running count vector.
     *
     * <p>A vector whose entries sum to its own length (e.g. all ones) is skipped. This is
     * presumably the encoding for missing or unknown data. Any other vector must sum to 1 (per
     * {@code MathUtils.close}) and is added element-wise to {@code result}.
     *
     * @param currentObs observation vector for one taxon at one site
     * @param result accumulator, updated in place; must be as long as {@code currentObs}
     * @throws IllegalArgumentException if the lengths differ, or if the vector sums to neither
     *     its length nor 1
     */
    private void process(double[] currentObs, double[] result) {
        double sum = MathUtils.sum(currentObs);
        if (sum == (double) currentObs.length) {
            return;
        }
        // Without this check, a shorter vector would be silently added into the wrong slots,
        // e.g. when sites with different character counts are pooled by globalStatDistn.
        if (currentObs.length != result.length) {
            throw new IllegalArgumentException("Observation has " + currentObs.length
                + " entries but the accumulator has " + result.length + ": "
                + Arrays.toString(currentObs));
        }
        if (!MathUtils.close(1.0, sum)) {
            throw new IllegalArgumentException("Observation must sum to 1 (or to its length "
                + currentObs.length + " when missing) but sums to " + sum + ": "
                + Arrays.toString(currentObs));
        }
        for (int i = 0; i < currentObs.length; ++i) {
            result[i] += currentObs[i];
        }
    }

    /**
     * Computes a single stationary distribution by pooling observations over all taxa and sites.
     *
     * <p>Unlike {@link #statDistn}, no pseudo-counts are added. The distribution is sized by the
     * number of characters at site 0, so every pooled site must have that many characters.
     *
     * @param observations per-taxon observations, indexed {@code [site][character]}
     * @return the normalized pooled distribution
     * @throws UnsupportedOperationException if {@link #forceUniform} is set
     * @throws IllegalArgumentException if a site's observations do not match site 0's size
     */
    private double[] globalStatDistn(Map<Taxon, double[][]> observations) {
        if (this.forceUniform) {
            throw new UnsupportedOperationException(
                "forceUniform requires siteSpecific for the HEURISTIC_ESTIMATE loading method");
        }
        int nChars = this.data.nCharacter(0);
        double[] result = new double[nChars];
        for (double[][] taxonObs : observations.values()) {
            for (int s = 0; s < this.data.nSites(); ++s) {
                this.process(taxonObs[s], result);
            }
        }
        NumUtils.normalize(result);
        return result;
    }

    /** Strategies for building the CTMC, selected through {@link CTMCLoader#loadingMethod}. */
    public static enum LoadingMethod {
        /**
         * Uses {@code SimpleCTMC.fromSequenceType} with {@link CTMCLoader#builtInSequenceType}
         * and {@link CTMCLoader#rate}.
         */
        BUILT_IN {
            @Override
            public CTMC load(CTMCLoader loader) {
                if (loader.siteSpecific) {
                    throw new UnsupportedOperationException(
                        "siteSpecific is not supported by the BUILT_IN loading method");
                }
                CTMC.SimpleCTMC result = CTMC.SimpleCTMC.fromSequenceType(
                    loader.data.nSites(), loader.builtInSequenceType, loader.rate);
                LogInfo.logs("Simple built-in CTMC:\n" + result);
                return result;
            }
        },
        /**
         * Builds one 0/1 off-diagonal matrix per site of a {@link WalsDataset}. For sites listed in
         * the processing script's {@code orderFeature}, only pairs accepted by
         * {@link CTMCLoader#linked} get a 1. For all other sites, every off-diagonal pair gets a 1.
         * Diagonals are then filled by {@code RateMtxUtils.fillRateMatrixDiagonalEntries}.
         * {@link CTMCLoader#rate} is not used.
         */
        SCRIPT {
            @Override
            public CTMC load(CTMCLoader loader) {
                if (loader.siteSpecific) {
                    throw new UnsupportedOperationException(
                        "siteSpecific is not supported by the SCRIPT loading method");
                }
                if (!(loader.data instanceof WalsDataset)) {
                    throw new IllegalStateException(
                        "The SCRIPT loading method requires a WalsDataset but data is "
                            + (loader.data == null ? "null" : loader.data.getClass().getName()));
                }
                WalsDataset dataset = (WalsDataset) loader.data;
                WalsProcessingScript script = WalsDataset.getScript();
                ArrayList<double[][]> rateMatrices = CollUtils.list();
                for (int s = 0; s < dataset.nSites(); ++s) {
                    WalsDataset.Site site = dataset.siteIndexer().i2o(s);
                    int nChars = dataset.nCharacter(s);
                    Indexer<WalsDataset.BioCharacter> characterIndexer = dataset.charIndexers().get(site);
                    boolean ordered = script.orderFeature.contains(site.toString());
                    double[][] current = new double[nChars][nChars];
                    for (int c1 = 0; c1 < nChars; ++c1) {
                        for (int c2 = 0; c2 < nChars; ++c2) {
                            // Ordered features only allow moves between neighbouring characters.
                            if (c1 != c2 && (!ordered || CTMCLoader.linked(characterIndexer, c1, c2))) {
                                current[c1][c2] = 1.0;
                            }
                        }
                    }
                    RateMtxUtils.fillRateMatrixDiagonalEntries(current);
                    rateMatrices.add(current);
                }
                return new CTMC.GeneralCTMC(rateMatrices);
            }
        },
        /** Reads a serialized CTMC from {@link CTMCLoader#file} via {@code CTMCUtils.unSerialize}. */
        FILE {
            @Override
            public CTMC load(CTMCLoader loader) {
                return CTMCUtils.unSerialize(new File(loader.file));
            }
        },
        /** Estimates the CTMC from the empirical character frequencies in the loader's data. */
        HEURISTIC_ESTIMATE {
            @Override
            public CTMC load(CTMCLoader loader) {
                return loader.getHeuristicEstimate();
            }
        };

        /**
         * Builds a CTMC according to this strategy.
         *
         * @param loader loader supplying the options and dataset
         * @return the resulting CTMC
         */
        public abstract CTMC load(CTMCLoader loader);
    }
}

