/*
 * Decompiled with CFR 0.152.
 */
package ma;

import fig.basic.IOUtils;
import fig.basic.Option;
import fig.exec.Execution;
import goblin.CognateId;
import goblin.DataPrepUtils;
import goblin.DerivationTree;
import goblin.HLParams;
import goblin.Taxon;
import goblin.TreeSamplers;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.regex.Pattern;
import ma.AffineGapAlignmentSampler;
import ma.ArbreInitializer;
import ma.BalibaseCorpus;
import ma.Baselines;
import ma.BioCorpus;
import ma.GeneratedCorpus;
import ma.MultiAlignment;
import nuts.io.IO;
import nuts.lang.StringUtils;
import nuts.math.MeanVar;
import nuts.math.MeasureZeroException;
import nuts.math.RejectionSampler;
import nuts.util.All2OneMap;
import nuts.util.Arbre;
import nuts.util.Tree;

/**
 * Experiment driver that initializes a corpus of multiple alignments, evaluates
 * the "Handel" and "Goblin" aligners on a test split and, unless learning is
 * disabled, alternates training and evaluation rounds (see {@link #run()}).
 *
 * <p>Public fields annotated with {@code @Option} are configuration values; the
 * option-holder objects below are registered with {@code Execution.run} in
 * {@link #main(String[])}.
 */
public class YAM
implements Runnable {
    // Corpus being aligned; created in init() as a GeneratedCorpus or a BalibaseCorpus.
    private BioCorpus corpus;
    // Ids whose initialization threw a RuntimeException in init(); skipped by train() and evaluate().
    private Set<CognateId> ignored = new HashSet<CognateId>();
    // Train / test partitions of the corpus ids, filled in by splitCorpus().
    private List<CognateId> train;
    private List<CognateId> test;
    // Default TKF parameters of the corpus type, obtained in init().
    private AffineGapAlignmentSampler.TKFParams baseParams;
    // Registered with Execution.run under the "rs" prefix in main(); not otherwise used in this file.
    private RejectionSampler<Arbre<DerivationTree.DerivationNode>> rejectionSampler = new RejectionSampler();
    // Initial derivation tree for each id, built in init().
    private Map<CognateId, Arbre<DerivationTree.DerivationNode>> initializations = new HashMap<CognateId, Arbre<DerivationTree.DerivationNode>>();
    // Time budget factor: multiplied by the number of characters in the initialization (see maxTime()).
    @Option(gloss="times total number of characters in the init")
    public int maxTimePerCharacter = 20;
    // Which learning mode to use; LearnType.NO disables the train/evaluate rounds in run().
    @Option
    public LearnType learnType = LearnType.HL;
    // Initialize from the corpus' own multiple alignment; init() rejects combining this with initUsingAln.
    @Option
    public boolean useCheatInit = false;
    // Re-initialize Goblin from a cached Handel argmax when one exists; init() requires sampleUsingHandel too.
    @Option
    public boolean initUsingHandel = true;
    // Initialize from "<alnPath>/<id>.aln" when that file exists.
    @Option
    public boolean initUsingAln = true;
    // Directory searched for the .aln initialization files.
    @Option
    public String alnPath = ".";
    // Use a GeneratedCorpus instead of a BalibaseCorpus.
    @Option
    public boolean generateCorpus = false;
    // Run the Handel aligner during the first evaluation round.
    @Option
    public boolean sampleUsingHandel = true;
    // Run the Goblin aligner during evaluation.
    @Option
    public boolean sampleUsingGoblin = true;
    // NOTE(review): not read anywhere in this file.
    @Option
    public boolean etrange = false;
    // Length of one discretized branch segment (see hobGoblinTransform() and hobCompileParams()).
    @Option
    public double branchDiscretizationLength = 0.2;
    // Number of train/evaluate rounds performed by run() when learning is enabled.
    @Option
    public int nEMIters = 10;
    // Fraction of the shuffled ids assigned to the test split.
    @Option
    public double heldoutProp = 0.5;
    // During training, start from an initialization derived from the corpus' multiple alignment.
    @Option
    public boolean useGoldToStartAnneal = true;
    // NOTE(review): not read anywhere in this file.
    @Option
    public int goldAlignRetries = 10;
    // Random source for shuffling ids in splitCorpus().
    @Option
    public Random permutationRandom = new Random(1L);
    // Random source handed to every ArbreInitializer.
    @Option
    public Random initRandom = new Random(1L);
    // Option holders registered with Execution.run in main() under "so", "gao", "ho", "go" and "bo".
    public TreeSamplers.AncestryMCMCKernelOptions sOptions = new TreeSamplers.AncestryMCMCKernelOptions();
    public GoblinAlignerOptions gaOptions = new GoblinAlignerOptions();
    public HandelAligner.HandelAlignerOptions hOptions = new HandelAligner.HandelAlignerOptions();
    public GeneratedCorpus.GeneratedCorpusOptions gOptions = new GeneratedCorpus.GeneratedCorpusOptions();
    public BalibaseCorpus.BalibaseCorpusOptions bOptions = new BalibaseCorpus.BalibaseCorpusOptions();
    // NOTE(review): the code below passes the literal 1.0 to compile(); presumably this constant was inlined.
    public static final double BRANCH_LENGTH_USED_FOR_INIT = 1.0;
    // Sufficient statistics from the latest train() round; null until the first round completes.
    private AffineGapAlignmentSampler.AlignmentSufficientStatistics currentSuffStat = null;
    // Handel aligners created by handelAlign(), keyed by id, consulted by goblinAlign().
    private Map<CognateId, HandelAligner> handelCache = new HashMap<CognateId, HandelAligner>();
    // Captures the numeric value of "LgP(alignment|tree)= <value> bits" (see getHandelLL()).
    public static final Pattern HANDEL_LL_REGEX = Pattern.compile(".*LgP[(]alignment[|]tree[)][=] ([.0123456789-]+) bits.*");

    /**
     * Runs the experiment: initialization, one evaluation round and then, unless
     * {@code learnType} is {@code NO}, {@code nEMIters} rounds of training each
     * followed by an evaluation. Any exception is printed, not propagated.
     */
    @Override
    public void run() {
        System.out.println();
        try {
            this.init();
            this.evaluate(true);
            if (this.learnType == LearnType.NO) {
                // Learning disabled: the initial evaluation is all there is to do.
                return;
            }
            int round = 0;
            while (round < this.nEMIters) {
                this.train();
                this.evaluate(false);
                round++;
            }
        }
        catch (Exception failure) {
            failure.printStackTrace();
        }
    }

    /**
     * Validates the option combination, loads the corpus, splits it into train and
     * test ids and builds an initial derivation tree for every id.
     *
     * <p>Ids whose initialization throws a {@link RuntimeException} are reported on
     * {@code System.err} and added to {@code ignored}.
     *
     * @throws RuntimeException if {@code useCheatInit} and {@code initUsingAln} are both
     *         set, or if {@code initUsingHandel} is set without {@code sampleUsingHandel}
     */
    private void init() {
        if (this.useCheatInit && this.initUsingAln) {
            throw new RuntimeException("Incompatible options: useCheatInit and initUsingAln cannot both be true");
        }
        if (this.initUsingHandel && !this.sampleUsingHandel) {
            throw new RuntimeException("Incompatible options: initUsingHandel requires sampleUsingHandel to be true");
        }
        this.corpus = this.generateCorpus ? new GeneratedCorpus(this.gOptions) : new BalibaseCorpus(this.bOptions);
        this.splitCorpus();
        this.baseParams = this.corpus.getType().defaultTKFParams();
        for (CognateId id : this.corpus.intersectedIds()) {
            ArbreInitializer initializer = null;
            try {
                // The result of this first initialization is discarded. It is kept as-is because it
                // receives initRandom (presumably advancing it) and any RuntimeException it throws
                // makes the id ignored -- TODO confirm whether it is still needed.
                new ArbreInitializer(this.baseParams.compile(1.0), this.initRandom).init(DataPrepUtils.tree2arbre2(this.corpus.getTopology(id), this.corpus.getMultiAlignment(id).getSequences()));
                // Initializer priority: corpus alignment ("cheat"), then .aln file, then default parameters.
                initializer = this.useCheatInit ? new ArbreInitializer(this.corpus.getMultiAlignment(id), this.initRandom) : (this.initUsingAln ? this.tryReadAlnAndCreateInitializer(id) : new ArbreInitializer(this.baseParams.compile(1.0), this.initRandom));
                this.initializations.put(id, initializer.init(DataPrepUtils.tree2arbre2(this.corpus.getTopology(id), this.corpus.getMultiAlignment(id).getSequences())));
            }
            catch (RuntimeException re) {
                System.err.println("Problem with: " + id + ", entry ignored. Details: " + re);
                this.ignored.add(id);
            }
        }
    }

    /**
     * Builds an initializer from the file {@code <alnPath>/<id>.aln} when it exists;
     * otherwise warns on {@code System.err} and falls back to an initializer built
     * from the default parameters.
     *
     * @param id the entry to initialize
     * @return an initializer for {@code id}
     * @throws RuntimeException wrapping the {@link IOException} if the file exists but cannot be read
     */
    private ArbreInitializer tryReadAlnAndCreateInitializer(CognateId id) {
        File alnFile = new File(this.alnPath, id.toString() + ".aln");
        if (alnFile.exists()) {
            try {
                return new ArbreInitializer(MultiAlignment.parseALNToMultiAlignment(alnFile.getPath(), ""), this.initRandom);
            }
            catch (IOException e) {
                // Keep the cause and the offending path so the failure can be diagnosed.
                throw new RuntimeException("Could not read aln file " + alnFile.getPath(), e);
            }
        }
        System.err.println("Warning: could not use aln for " + id);
        return new ArbreInitializer(this.baseParams.compile(1.0), this.initRandom);
    }

    /**
     * Partitions the corpus ids into {@code test} and {@code train}.
     *
     * <p>When learning is enabled the ids are shuffled with {@code permutationRandom};
     * the first {@code (int)(heldoutProp * size)} ids form the test split and the
     * remainder the train split. When learning is disabled every id is a test id and
     * the train split is empty.
     */
    private void splitCorpus() {
        if (this.learnType != LearnType.NO) {
            ArrayList<CognateId> all = new ArrayList<CognateId>(this.corpus.intersectedIds());
            int testSize = (int)(this.heldoutProp * (double)all.size());
            Collections.shuffle(all, this.permutationRandom);
            this.test = new ArrayList<CognateId>(all.subList(0, testSize));
            this.train = new ArrayList<CognateId>(all.subList(testSize, all.size()));
        } else {
            this.test = new ArrayList<CognateId>(this.corpus.intersectedIds());
            // Never leave the train split null, so iterating over it is always safe.
            this.train = new ArrayList<CognateId>();
        }
    }

    /**
     * Performs one training round. For {@code LearnType.HL} this delegates to
     * {@link #hlTrain()}. Otherwise every non-ignored training id is aligned with
     * Goblin (annealing to gold) and the resulting sufficient statistics are
     * accumulated and stored in {@code currentSuffStat}.
     */
    private void train() {
        System.out.println("Training");
        if (this.learnType == LearnType.HL) {
            this.hlTrain();
            return;
        }
        AffineGapAlignmentSampler.AlignmentSufficientStatistics accumulated = new AffineGapAlignmentSampler.AlignmentSufficientStatistics();
        for (CognateId id : this.train) {
            if (!this.ignored.contains(id)) {
                // This tree is only used to size the time budget; goblinAlign builds its own start point.
                Arbre<DerivationTree.DerivationNode> sizingTree;
                if (this.useGoldToStartAnneal) {
                    sizingTree = new ArbreInitializer(this.corpus.getMultiAlignment(id), this.initRandom).init(this.initializations.get(id));
                } else {
                    sizingTree = this.initializations.get(id);
                }
                long budget = YAM.maxTime(sizingTree, this.maxTimePerCharacter);
                System.out.println("Entry: " + id + ",\tmaxTime=" + budget);
                GoblinAligner aligner = this.goblinAlign(id, budget, true);
                aligner.align();
                accumulated.add(aligner.getSuffStat());
            }
        }
        this.currentSuffStat = accumulated;
    }

    /**
     * Training step for {@code LearnType.HL}, called from {@link #train()}.
     * The body is empty in this source, so HL training currently does nothing.
     */
    private void hlTrain() {
    }

    /**
     * Aligns every non-ignored test id and prints the mean column score (CS) and
     * sum-of-pairs score (SP) per system.
     *
     * @param firstRound whether this is the initial evaluation; Handel is only run
     *        (and its scores only collected) on the first round
     */
    private void evaluate(boolean firstRound) {
        System.out.println("Evaluating");
        MeanVar handelColumn = new MeanVar();
        MeanVar goblinColumn = new MeanVar();
        MeanVar handelPairs = new MeanVar();
        MeanVar goblinPairs = new MeanVar();
        for (CognateId id : this.test) {
            if (!this.ignored.contains(id)) {
                Arbre<DerivationTree.DerivationNode> start = this.initializations.get(id);
                long budget = YAM.maxTime(start, this.maxTimePerCharacter);
                System.out.println("Entry: " + id + ",\tmaxTime=" + budget);
                System.out.println("Gold:\n\t" + this.corpus.getMultiAlignment(id).toString().replaceAll("\n", "\n\t"));
                if (firstRound && this.sampleUsingHandel) {
                    HandelAligner handel = this.handelAlign(id, budget);
                    handel.align();
                    handelColumn.addPoint(handel.getArgmaxColumnScore());
                    handelPairs.addPoint(handel.getArgmaxSumOfPairs());
                }
                if (this.sampleUsingGoblin) {
                    GoblinAligner goblin = this.goblinAlign(id, budget, false);
                    goblin.align();
                    goblinColumn.addPoint(goblin.getArgmaxColumnScore());
                    goblinPairs.addPoint(goblin.getArgmaxSumOfPairs());
                }
            }
        }
        // Summary table.
        System.out.println();
        System.out.println("System\t\tCS\t\tSP");
        if (this.sampleUsingGoblin) {
            System.out.println("Goblin\t\t" + goblinColumn.getMean() + "\t\t" + goblinPairs.getMean());
        }
        if (this.sampleUsingHandel) {
            System.out.println("Handel\t\t" + handelColumn.getMean() + "\t\t" + handelPairs.getMean());
        }
        System.out.println();
    }

    /**
     * Creates a configured (but not yet run) Goblin aligner for one entry.
     *
     * @param id the entry to align
     * @param maxTime time budget handed to the aligner
     * @param annealToGold whether the aligner should anneal to gold; combined with
     *        {@code useGoldToStartAnneal} it also selects an initialization derived
     *        from the corpus' multiple alignment
     * @return the aligner, with gold alignment and time budget set
     */
    private GoblinAligner goblinAlign(CognateId id, long maxTime, boolean annealToGold) {
        Map<Taxon, Double> lengths = this.corpus.getBranchLengths(id);

        // Choose the starting derivation tree.
        Arbre<DerivationTree.DerivationNode> start;
        if (annealToGold && this.useGoldToStartAnneal) {
            start = new ArbreInitializer(this.corpus.getMultiAlignment(id), this.initRandom).init(this.initializations.get(id));
        } else {
            start = this.initializations.get(id);
        }
        if (this.initUsingHandel) {
            HandelAligner cached = this.handelCache.get(id);
            if (cached != null) {
                start = this.handelInit(cached, start);
            }
        }
        if (this.learnType != LearnType.NO) {
            start = YAM.hobGoblinTransform(start, this.branchDiscretizationLength, lengths);
        }

        // Choose the parameters matching the learning mode.
        Map<Taxon, AffineGapAlignmentSampler.GapAlignmentParams> params;
        if (this.learnType == LearnType.NO) {
            params = this.compileParams(lengths);
        } else if (this.learnType == LearnType.HOB) {
            params = this.hobCompileParams();
        } else {
            params = this.hlParams();
        }

        GoblinAligner result = new GoblinAligner(lengths, this.sOptions, this.gaOptions, start, params);
        result.setGold(this.corpus.getMultiAlignment(id));
        if (annealToGold) {
            result.annealToGold();
        }
        result.setMaxTime(maxTime);
        return result;
    }

    /**
     * Parameters used by {@link #goblinAlign} when the learn type is neither
     * {@code NO} nor {@code HOB}.
     *
     * @return always {@code null} in this source; a {@code GoblinAligner} built with
     *         {@code null} parameters reports {@code isHL() == true}
     */
    private Map<Taxon, AffineGapAlignmentSampler.GapAlignmentParams> hlParams() {
        return null;
    }

    /**
     * Builds the discretized gap parameters shared by every taxon: the base TKF
     * parameters compiled at {@code branchDiscretizationLength}, re-estimated from
     * {@code currentSuffStat} once a training round has produced statistics.
     * The resulting parameters are printed.
     *
     * @return a map associating the same parameters with every taxon
     */
    private Map<Taxon, AffineGapAlignmentSampler.GapAlignmentParams> hobCompileParams() {
        AffineGapAlignmentSampler.DiscreteAffineGapParameters discretized = AffineGapAlignmentSampler.DiscreteAffineGapParameters.createFromTKF(this.baseParams.compile(this.branchDiscretizationLength));
        AffineGapAlignmentSampler.AlignmentSufficientStatistics stats = this.currentSuffStat;
        if (stats != null) {
            discretized = discretized.reestimate(stats);
        }
        System.out.println("Current params: " + discretized);
        return new All2OneMap<Taxon, AffineGapAlignmentSampler.GapAlignmentParams>(discretized);
    }

    /**
     * Compiles the base parameters once per taxon, using that taxon's branch length.
     *
     * @param lengths branch length of each taxon
     * @return a new map from each key of {@code lengths} to its compiled parameters
     */
    private Map<Taxon, AffineGapAlignmentSampler.GapAlignmentParams> compileParams(Map<Taxon, Double> lengths) {
        Map<Taxon, AffineGapAlignmentSampler.GapAlignmentParams> compiled = new HashMap<Taxon, AffineGapAlignmentSampler.GapAlignmentParams>();
        for (Taxon taxon : lengths.keySet()) {
            Double branchLength = lengths.get(taxon);
            compiled.put(taxon, this.baseParams.compile(branchLength));
        }
        return compiled;
    }

    /**
     * Re-initializes a derivation tree from the best alignment found by Handel.
     *
     * @param handelAligner aligner whose argmax alignment seeds the initializer
     * @param oldInit tree passed to the initializer
     * @return the tree returned by the initializer
     */
    private Arbre<DerivationTree.DerivationNode> handelInit(HandelAligner handelAligner, Arbre<DerivationTree.DerivationNode> oldInit) {
        return new ArbreInitializer(handelAligner.getArgmax(), this.initRandom).init(oldInit);
    }

    /**
     * Creates a configured (but not yet run) Handel aligner for one entry and
     * records it in {@code handelCache}.
     *
     * @param id the entry to align
     * @param maxTime time budget handed to the aligner
     * @return the aligner, with gold alignment and time budget set
     */
    private HandelAligner handelAlign(CognateId id, long maxTime) {
        Tree<String> topology = this.corpus.getTopology(id);
        Map<Taxon, Double> lengths = this.corpus.getBranchLengths(id);
        MultiAlignment start = MultiAlignment.inducedMultiAlignment(this.initializations.get(id));
        HandelAligner result = new HandelAligner(topology, lengths, this.hOptions, start);
        result.setGold(this.corpus.getMultiAlignment(id));
        result.setMaxTime(maxTime);
        this.handelCache.put(id, result);
        return result;
    }

    /**
     * Command-line entry point: echoes the arguments and hands a new {@code YAM}
     * together with its option holders to {@code Execution.run}.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) throws MeasureZeroException {
        String echoedArgs = Arrays.toString(args);
        System.out.println("Invo. Args.: " + echoedArgs);
        System.out.println();
        YAM experiment = new YAM();
        Execution.run(args, experiment,
                "rs", experiment.rejectionSampler,
                "gao", experiment.gaOptions,
                "so", experiment.sOptions,
                "ho", experiment.hOptions,
                "go", experiment.gOptions,
                "bo", experiment.bOptions);
    }

    /**
     * Computes a time budget proportional to the size of an initialization.
     *
     * @param initialization tree whose node words are measured
     * @param timeAllowedPerInitialCharacter budget per character; the value
     *        {@code Integer.MAX_VALUE} is returned unchanged without visiting the tree
     * @return total word length over all nodes times the per-character budget
     */
    public static long maxTime(Arbre<DerivationTree.DerivationNode> initialization, long timeAllowedPerInitialCharacter) {
        if (timeAllowedPerInitialCharacter == Integer.MAX_VALUE) {
            return Integer.MAX_VALUE;
        }
        int totalCharacters = 0;
        for (Arbre<DerivationTree.DerivationNode> node : initialization.nodes()) {
            String word = node.getContents().getWord();
            totalCharacters += word.length();
        }
        return timeAllowedPerInitialCharacter * (long)totalCharacters;
    }

    /**
     * Returns a copy of {@code currentSample} in which every non-root node is placed
     * beneath a chain of additional single-child nodes, one per discretized branch
     * segment.
     *
     * <p>For a non-root node the chain has {@code (int)(length / discreteBranchLength)}
     * extra nodes, where {@code length} is the node language's entry in {@code lengths}.
     * The extra nodes are named {@code <language>_<i>}, carry the node's word and an
     * identity derivation (see {@link #id(String)}); the node's original derivation is
     * moved to the topmost node of the chain.
     *
     * @param currentSample tree to transform (children are transformed recursively)
     * @param discreteBranchLength length of one branch segment
     * @param lengths branch length per language; must contain every non-root language,
     *        otherwise unboxing the missing value fails
     * @return the root of the transformed subtree
     */
    public static Arbre<DerivationTree.DerivationNode> hobGoblinTransform(Arbre<DerivationTree.DerivationNode> currentSample, double discreteBranchLength, Map<Taxon, Double> lengths) {
        DerivationTree.DerivationNode contents = currentSample.getContents();
        // Copy of this node with an identity derivation; the real derivation is re-attached below.
        Arbre<DerivationTree.DerivationNode> result = Arbre.arbre(new DerivationTree.DerivationNode(contents.getLanguage(), contents.getWord(), YAM.id(contents.getWord())));
        for (Arbre<DerivationTree.DerivationNode> children : currentSample.getChildren()) {
            result.addLeaves(YAM.hobGoblinTransform(children, discreteBranchLength, lengths));
        }
        if (!currentSample.isRoot()) {
            // Number of whole segments on the branch above this node (fraction truncated).
            int n = (int)(lengths.get(contents.getLanguage()) / discreteBranchLength);
            for (int i = 0; i < n; ++i) {
                // Each new node becomes the parent of the chain built so far.
                Arbre<DerivationTree.DerivationNode> temp = Arbre.arbre(new DerivationTree.DerivationNode(new Taxon(contents.getLanguage() + "_" + i), contents.getWord(), YAM.id(contents.getWord())));
                temp.addLeaves(result);
                result = temp;
            }
            // The topmost node of the chain keeps its own language and word but takes the original derivation.
            result.setContents(new DerivationTree.DerivationNode(result.getContents().getLanguage(), result.getContents().getWord(), contents.getDerivation()));
        }
        return result;
    }

    /**
     * Builds the identity derivation of a word onto itself: position {@code i} has
     * ancestor {@code i}.
     *
     * @param word the word, may be {@code null}
     * @return the identity derivation, or {@code null} when {@code word} is {@code null}
     */
    public static DerivationTree.Derivation id(String word) {
        if (word == null) {
            return null;
        }
        final int length = word.length();
        int[] identity = new int[length];
        int position = 0;
        while (position < length) {
            identity[position] = position;
            position++;
        }
        return new DerivationTree.Derivation(identity, word, word);
    }

    /**
     * Writes the "wei" file consumed by the external aligner.
     *
     * <p>NOTE(review): in this source the content is a {@code null}
     * {@link CharSequence}, which {@link PrintWriter#append(CharSequence)} writes as
     * the four characters {@code "null"}; {@code topology} and {@code branchLengths}
     * are not used. The code that serialized them appears to be missing -- confirm
     * against the original sources.
     *
     * @param file destination file
     * @param topology tree topology (currently unused)
     * @param branchLengths branch length per taxon (currently unused)
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be opened
     */
    public static void saveToWei(File file, Tree<String> topology, Map<Taxon, Double> branchLengths) {
        CharSequence wei = null;
        try {
            PrintWriter out = IOUtils.openOut(file);
            try {
                out.append(wei);
            }
            finally {
                // Release the file handle even if writing fails.
                out.close();
            }
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Extracts the log-probability reported by the external aligner, i.e. the first
     * value captured by {@link #HANDEL_LL_REGEX}.
     *
     * @param string raw output of the external aligner
     * @return the captured value parsed as a double
     * @throws RuntimeException if no value can be extracted; the message contains the
     *         raw output and the underlying failure is attached as the cause
     * @throws NumberFormatException if the captured text is not a valid double
     */
    public static double getHandelLL(String string) {
        String logLLStr;
        try {
            logLLStr = StringUtils.selectRegex(HANDEL_LL_REGEX, string).get(0);
        }
        catch (Exception e) {
            // Chain the cause so the reason for the extraction failure is not lost.
            throw new RuntimeException("Bad handel output:\n" + string, e);
        }
        return Double.parseDouble(logLLStr);
    }

    /**
     * Formats one progress line: elapsed time, the scores of {@code guess} against
     * {@code gold} (omitted when {@code gold} is {@code null}) and the likelihood.
     *
     * @param guess alignment to score
     * @param gold reference alignment, or {@code null} to skip the scores
     * @param firstStatDump start time in milliseconds, subtracted from the current time
     * @param max likelihood value to report
     * @return the formatted line
     */
    public static String statLine(MultiAlignment guess, MultiAlignment gold, long firstStatDump, double max) {
        StringBuilder line = new StringBuilder();
        line.append("time=").append(System.currentTimeMillis() - firstStatDump);
        if (gold != null) {
            line.append(",\tcolumnScore=").append(gold.columnScore(guess));
            line.append(",\tsumOfPairs=").append(gold.sumOfPairsScore(guess));
        }
        line.append(",\tlikelihood=").append(max);
        return line.toString();
    }

    /**
     * Compiler-generated (synthetic) accessor exposing the private
     * {@code rejectionSampler} field of the given instance.
     */
    static /* synthetic */ RejectionSampler access$000(YAM x0) {
        return x0.rejectionSampler;
    }

    /**
     * Aligner that repeatedly invokes the external {@code Baselines.tkfalign}
     * command, feeding each round's sample back in as the next round's starting
     * alignment and remembering the sample with the highest reported log-likelihood.
     */
    public static class HandelAligner
    extends IterativeTreeBasedAligner {
        /** Alignment written to disk before the first round. */
        private final MultiAlignment initialization;
        /** Sample obtained in the previous round (initially the initialization). */
        private MultiAlignment current;
        /** Number of consecutive rounds whose sample equalled the previous one. */
        private int nSame = 0;
        /** Sampling stops once more than this many consecutive samples are identical. */
        public static final int MAX_N_SAME = 10;
        private final Tree<String> topology;
        private final HandelAlignerOptions options;
        /** Best sample seen so far; {@code null} until {@link #align()} finds one. */
        private MultiAlignment argmax = null;

        /** @return the sample with the highest log-likelihood seen so far, or {@code null} */
        @Override
        public MultiAlignment getArgmax() {
            return this.argmax;
        }

        /**
         * @param topology tree topology passed to {@link YAM#saveToWei}
         * @param branchLengths branch length per taxon
         * @param options sampling options; cloned, so later changes by the caller are not seen
         * @param initialization starting alignment
         * @throws RuntimeException if the options cannot be cloned
         */
        public HandelAligner(Tree<String> topology, Map<Taxon, Double> branchLengths, HandelAlignerOptions options, MultiAlignment initialization) {
            super(branchLengths);
            this.topology = topology;
            try {
                this.options = (HandelAlignerOptions)options.clone();
            }
            catch (CloneNotSupportedException cnse) {
                throw new RuntimeException(cnse);
            }
            this.initialization = initialization;
            this.current = initialization;
        }

        /**
         * Runs up to {@code nSamplingRounds} rounds of the external aligner. Stops
         * early when the time budget is exceeded or when more than
         * {@link #MAX_N_SAME} consecutive samples are identical. The two temporary
         * files used to talk to the external command are removed afterwards.
         *
         * @throws RuntimeException wrapping any {@link IOException}
         */
        @Override
        public void align() {
            System.out.println("Handel");
            long start = System.currentTimeMillis();
            double maxLL = Double.NEGATIVE_INFINITY;
            File wei = null;
            File initPath = null;
            try {
                wei = File.createTempFile("wei" + System.currentTimeMillis(), null);
                YAM.saveToWei(wei, this.topology, this.branchLengths);
                initPath = File.createTempFile("fasta" + System.currentTimeMillis(), null);
                this.initialization.saveToMSF(initPath);
                for (int i = 0; i < this.options.nSamplingRounds; ++i) {
                    String result = IO.call(Baselines.tkfalign + " -s " + this.options.samplesPerRound + " " + wei.getPath() + " " + initPath.getPath());
                    double ll = YAM.getHandelLL(result);
                    MultiAlignment ma = MultiAlignment.parseALNStringToMultiAlignment(Baselines.stock2ALN(result), "");
                    if (ll > maxLL) {
                        this.argmax = ma;
                        maxLL = ll;
                    }
                    System.out.println(YAM.statLine(this.argmax, this.gold, start, maxLL));
                    // The next round starts from the sample just obtained.
                    ma.saveToMSF(initPath);
                    if (System.currentTimeMillis() - start > this.maxTime) break;
                    if (this.current.equals(ma)) {
                        ++this.nSame;
                        if (this.nSame > MAX_N_SAME) {
                            break;
                        }
                    } else {
                        this.nSame = 0;
                    }
                    this.current = ma;
                }
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
            finally {
                // The files are private to this call, so clean them up on every exit path.
                HandelAligner.deleteTempFile(wei);
                HandelAligner.deleteTempFile(initPath);
            }
        }

        /** Deletes a temporary file now, or schedules deletion at JVM exit if that fails. */
        private static void deleteTempFile(File file) {
            if (file != null && !file.delete()) {
                file.deleteOnExit();
            }
        }

        /** Options of {@link HandelAligner}. */
        public static class HandelAlignerOptions
        implements Cloneable {
            /** Maximum number of invocations of the external aligner. */
            @Option
            public int nSamplingRounds = Integer.MAX_VALUE;
            /** Value passed to the external aligner's {@code -s} flag. */
            @Option
            public int samplesPerRound = 20;

            public Object clone() throws CloneNotSupportedException {
                return super.clone();
            }
        }
    }

    public class GoblinAligner
    extends IterativeTreeBasedAligner {
        private final TreeSamplers.AncestryMCMCKernelOptions options;
        private final GoblinAlignerOptions alignerOptions;
        private final Arbre<DerivationTree.DerivationNode> initialization;
        private final Map<Taxon, AffineGapAlignmentSampler.GapAlignmentParams> parameters;
        private final HLParams hlParams;
        private MaxSampleProcessor msp;
        private CollectSuffStatProcessor cssp;
        private boolean annealToGold;

        public GoblinAligner(Map<Taxon, Double> branchLengths, TreeSamplers.AncestryMCMCKernelOptions options, GoblinAlignerOptions alignerOptions, Arbre<DerivationTree.DerivationNode> initialization, Map<Taxon, AffineGapAlignmentSampler.GapAlignmentParams> parameters) {
            super(branchLengths);
            this.msp = null;
            this.cssp = null;
            this.annealToGold = false;
            this.options = options;
            this.alignerOptions = alignerOptions;
            this.initialization = initialization;
            this.parameters = parameters;
            this.hlParams = null;
        }

        public GoblinAligner(Map<Taxon, Double> branchLengths, TreeSamplers.AncestryMCMCKernelOptions options, GoblinAlignerOptions alignerOptions, Arbre<DerivationTree.DerivationNode> initialization, HLParams hlParams) {
            super(branchLengths);
            this.msp = null;
            this.cssp = null;
            this.annealToGold = false;
            this.options = options;
            this.alignerOptions = alignerOptions;
            this.initialization = initialization;
            this.parameters = null;
            this.hlParams = hlParams;
        }

        public boolean isHL() {
            return this.parameters == null;
        }

        private TreeSamplers.SampleProcessor getProcessor() {
            TreeSamplers.ForkedSampleProcessor result = new TreeSamplers.ForkedSampleProcessor();
            this.msp = new MaxSampleProcessor();
            this.msp.setGold(this.gold);
            result.processors.add(this.msp);
            this.cssp = new CollectSuffStatProcessor();
            result.processors.add(this.cssp);
            return result;
        }

        public void annealToGold() {
            this.annealToGold = true;
        }

        @Override
        public void align() {
            System.out.println("Goblin");
            long start = System.currentTimeMillis();
            throw new RuntimeException();
        }

        @Override
        public MultiAlignment getArgmax() {
            return MultiAlignment.inducedMultiAlignment(this.msp.getArgmax());
        }

        public AffineGapAlignmentSampler.AlignmentSufficientStatistics getSuffStat() {
            return this.cssp.getAveragedSuffStat();
        }
    }

    /**
     * Options holder for {@code GoblinAligner}, registered with
     * {@code Execution.run} under the "gao" prefix in {@code main}.
     */
    public static class GoblinAlignerOptions
    implements Cloneable {
        // NOTE(review): none of these fields is read in this file; their use is presumably elsewhere.
        @Option
        public int nSamplingRounds = Integer.MAX_VALUE;
        @Option
        public Random samplerRandom = new Random(1L);
        @Option
        public Random permutationRandom = new Random(1L);

        // Shallow copy via Object.clone(): the clone shares the same Random instances.
        public Object clone() throws CloneNotSupportedException {
            return super.clone();
        }
    }

    /**
     * Base class of the aligners: holds the branch lengths, an optional gold
     * alignment used for scoring and a time budget.
     */
    public static abstract class IterativeTreeBasedAligner {
        /** Reference alignment used for scoring; {@code null} until {@link #setGold} is called. */
        protected MultiAlignment gold = null;
        protected final Map<Taxon, Double> branchLengths;
        /** Time budget; unlimited ({@code Long.MAX_VALUE}) by default. */
        protected long maxTime = Long.MAX_VALUE;

        /** @param branchLengths branch length per taxon */
        public IterativeTreeBasedAligner(Map<Taxon, Double> branchLengths) {
            this.branchLengths = branchLengths;
        }

        /** @return the best alignment found so far */
        public abstract MultiAlignment getArgmax();

        /** Runs the aligner. */
        public abstract void align();

        /** @return the column score of {@link #getArgmax()} against the gold alignment, which must be set */
        public double getArgmaxColumnScore() {
            return this.gold.columnScore(this.getArgmax());
        }

        /** @return the sum-of-pairs score of {@link #getArgmax()} against the gold alignment, which must be set */
        public double getArgmaxSumOfPairs() {
            return this.gold.sumOfPairsScore(this.getArgmax());
        }

        /**
         * Sets the gold alignment; may only be done once.
         *
         * @param gold the reference alignment
         * @throws IllegalStateException if a gold alignment has already been set
         */
        public void setGold(MultiAlignment gold) {
            if (this.gold != null) {
                throw new IllegalStateException("The gold alignment has already been set");
            }
            this.gold = gold;
        }

        /** @param maxTime the new time budget */
        public void setMaxTime(long maxTime) {
            this.maxTime = maxTime;
        }
    }

    /**
     * Sample processor that remembers the highest-scoring sample and prints a
     * statistics line for each sample that ties or sets the maximum, and also
     * periodically when a gold alignment is available.
     */
    public static class MaxSampleProcessor
    implements TreeSamplers.SampleProcessor {
        /** Highest score seen so far. */
        private double max = Double.NEGATIVE_INFINITY;
        /** Sample that achieved {@link #max}; {@code null} until one is seen. */
        private Arbre<DerivationTree.DerivationNode> argmax = null;
        private MultiAlignment gold = null;
        /** Minimum number of milliseconds between two periodic statistics lines. */
        private long interval = 10000L;
        private int nSamples = 0;
        private long firstStatDump = System.currentTimeMillis();
        private long lastStatDump = System.currentTimeMillis();

        /** @return the gold alignment, or {@code null} */
        public MultiAlignment getGold() {
            return this.gold;
        }

        /**
         * Sets the gold alignment and restarts both statistics timers.
         *
         * @param gold the reference alignment, may be {@code null}
         */
        public void setGold(MultiAlignment gold) {
            this.gold = gold;
            this.firstStatDump = System.currentTimeMillis();
            this.lastStatDump = System.currentTimeMillis();
        }

        /** @return the best sample seen so far, or {@code null} */
        public Arbre<DerivationTree.DerivationNode> getArgmax() {
            return this.argmax;
        }

        /** @return the best score seen so far */
        public double getMax() {
            return this.max;
        }

        /**
         * Records one sample, updating the maximum when {@code score} exceeds it,
         * and prints a statistics line when due.
         */
        @Override
        public void process(Arbre<DerivationTree.DerivationNode> currentSample, double score, CognateId id) {
            this.nSamples++;
            if (score > this.max) {
                this.max = score;
                this.argmax = currentSample;
            }
            boolean matchesBest = this.max == score;
            boolean periodicDumpDue = this.gold != null && System.currentTimeMillis() - this.lastStatDump > this.interval;
            if (!periodicDumpDue && !matchesBest) {
                return;
            }
            MultiAlignment induced = MultiAlignment.inducedMultiAlignment(currentSample);
            String marker = matchesBest ? "\t*" : "\t";
            System.out.println(YAM.statLine(induced, this.gold, this.firstStatDump, score) + "\tnSamples=" + this.nSamples + marker);
            this.lastStatDump = System.currentTimeMillis();
        }
    }

    /**
     * Sample processor that accumulates alignment sufficient statistics over all
     * processed samples and can report them averaged over the sample count.
     */
    public static class CollectSuffStatProcessor
    implements TreeSamplers.SampleProcessor {
        /** Running sum of the statistics of every processed sample. */
        private AffineGapAlignmentSampler.AlignmentSufficientStatistics ass = new AffineGapAlignmentSampler.AlignmentSufficientStatistics();
        /** Number of processed samples. */
        private double N = 0.0;

        /** Counts the sample and adds it to the running statistics. */
        @Override
        public void process(Arbre<DerivationTree.DerivationNode> currentSample, double score, CognateId id) {
            this.N = this.N + 1.0;
            this.ass.add(currentSample);
        }

        /**
         * @return a new statistics object holding the accumulated statistics scaled
         *         by {@code 1 / N}; the accumulator itself is left untouched
         */
        public AffineGapAlignmentSampler.AlignmentSufficientStatistics getAveragedSuffStat() {
            AffineGapAlignmentSampler.AlignmentSufficientStatistics averaged = new AffineGapAlignmentSampler.AlignmentSufficientStatistics();
            averaged.add(this.ass);
            final double weight = 1.0 / this.N;
            averaged.scale(weight);
            return averaged;
        }
    }

    /**
     * Learning mode of the experiment.
     */
    public static enum LearnType {
        // No learning: run() performs a single evaluation and every id is a test id.
        NO,
        // goblinAlign uses the discretized parameters from hobCompileParams().
        HOB,
        // train() delegates to hlTrain() and goblinAlign uses hlParams().
        HL;

    }
}

