/*
 * Decompiled with CFR 0.152.
 */
package goblin;

import fig.basic.IOUtils;
import fig.basic.LogInfo;
import goblin.BayesRiskMinimizer;
import goblin.CognateId;
import goblin.CognateSet;
import goblin.DerivationTree;
import goblin.ObservationsTracker;
import goblin.Taxon;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import ma.MultiAlignment;
import nuts.tui.Table;
import nuts.util.Arbre;
import nuts.util.CollUtils;
import nuts.util.Counter;
import pepper.CognateDetector;

/**
 * Held-out reference data used to evaluate reconstructions.
 *
 * <p>Two kinds of references are stored, both keyed by {@link CognateId}:
 * <ul>
 *   <li><b>word entries</b>: the true word of the held-out taxon {@code lang}. Proposed
 *       reconstructions are scored with {@code editDistance}, where lower is better (the oracle
 *       baseline keeps the minimum).</li>
 *   <li><b>multi-alignment references</b>: reference {@link MultiAlignment}s, compared with
 *       {@code columnScore} (CS) and {@code sumOfPairsScore} (SP).</li>
 * </ul>
 *
 * <p>After {@link #freeze()} the {@code add*} methods reject new entries. Instances are
 * {@link Serializable} and can be persisted with {@link #saveHeldout} / {@link #restoreHeldout}.
 */
public class Heldout
implements Serializable {
    private static final long serialVersionUID = 3L;
    /** True held-out word of {@code lang}, keyed by cognate set id. */
    private final Map<CognateId, String> wordHeldoutEntries = new HashMap<CognateId, String>();
    /** Held-out taxon for word entries; {@code null} when built with {@link #Heldout()}. */
    private final Taxon lang;
    /** Reference multi-alignments, keyed by cognate set id. */
    private final Map<CognateId, MultiAlignment> multiAlignmentReferences = new HashMap<CognateId, MultiAlignment>();
    /** Once {@code true}, the {@code add*} methods throw. */
    private boolean frozen = false;

    /**
     * Returns the ids that have a multi-alignment reference.
     *
     * @return a live view of the reference map's key set (changes to it affect this object)
     */
    public Set<CognateId> getMAIds() {
        return this.multiAlignmentReferences.keySet();
    }

    /** Lists every multi-alignment reference, each preceded by its id. */
    @Override
    public String toString() {
        StringBuilder result = new StringBuilder();
        result.append("Multi alignments references:\n");
        for (Map.Entry<CognateId, MultiAlignment> entry : this.multiAlignmentReferences.entrySet()) {
            result.append(entry.getKey()).append('\n');
            result.append(entry.getValue()).append('\n');
        }
        return result.toString();
    }

    /** Prevents any further call to the {@code add*} methods from succeeding. */
    public void freeze() {
        this.frozen = true;
    }

    /**
     * Creates held-out data that can hold both word entries (for {@code lang}) and
     * multi-alignment references.
     *
     * @param lang taxon whose words are held out
     */
    public Heldout(Taxon lang) {
        this.lang = lang;
    }

    /**
     * Creates held-out data without a held-out taxon. Only multi-alignment references can be
     * added; {@link #addWordHeldoutEntry} throws.
     */
    public Heldout() {
        this.lang = null;
    }

    /**
     * Records the true word of {@code lang} for a cognate set.
     *
     * @param id cognate set id
     * @param trueRecons the true (held-out) word
     * @throws IllegalStateException if this object has no held-out taxon or is frozen
     */
    public void addWordHeldoutEntry(CognateId id, String trueRecons) {
        if (this.lang == null) {
            throw new IllegalStateException("Cannot add word entry for " + id + ": no held-out language");
        }
        if (this.frozen) {
            throw new IllegalStateException("Cannot add word entry for " + id + ": heldout is frozen");
        }
        this.wordHeldoutEntries.put(id, trueRecons);
    }

    /**
     * Records the reference multi-alignment of a cognate set.
     *
     * @param id cognate set id
     * @param ma reference alignment; must be non-null and report {@code isReference()}
     * @throws IllegalArgumentException if {@code ma} is null or not a reference alignment
     * @throws IllegalStateException if this object is frozen
     */
    public void addMultiAlignmentReferenceEntry(CognateId id, MultiAlignment ma) {
        // The original code constructed this exception without throwing it, so the checks were no-ops.
        if (ma == null || !ma.isReference()) {
            throw new IllegalArgumentException("Expected a reference multi-alignment for " + id);
        }
        if (this.frozen) {
            throw new IllegalStateException("Cannot add multi-alignment reference for " + id + ": heldout is frozen");
        }
        this.multiAlignmentReferences.put(id, ma);
    }

    /**
     * Scores the words currently assigned to {@code lang} in {@code cognates} against the
     * held-out truth.
     *
     * @param cognates cognate sets holding the proposed reconstructions
     * @return per-id edit-distance scores
     */
    public WordHeldoutEvaluation evaluateWordHeldout(CognateSet cognates) {
        Map<CognateId, Double> scores = new HashMap<CognateId, Double>();
        for (Map.Entry<CognateId, String> entry : this.wordHeldoutEntries.entrySet()) {
            String proposedReconstruction = cognates.getWord(entry.getKey(), this.lang);
            scores.put(entry.getKey(), Heldout.editDistance(proposedReconstruction, entry.getValue()));
        }
        return new WordHeldoutEvaluation(scores);
    }

    /** @return the number of word held-out entries (multi-alignment references are not counted) */
    public int size() {
        return this.wordHeldoutEntries.size();
    }

    /**
     * Returns the cost between two strings as computed by {@code CognateDetector} with all three
     * weights set to 1.0. This is presumably unit-cost Levenshtein distance; not verified here.
     */
    private static double editDistance(String s1, String s2) {
        CognateDetector editDComputer = new CognateDetector(1.0, 1.0, 1.0);
        return editDComputer.cost(s1, s2);
    }

    /**
     * Baseline that, for each word entry, picks one observed language uniformly at random and
     * scores its word (read from the cognate set's derivation tree) against the truth.
     *
     * @param cognates cognate sets providing observations and trees
     * @param rand source of randomness
     * @return per-id edit-distance scores
     * @throws IllegalStateException if a cognate set has no observed language
     */
    public WordHeldoutEvaluation randomBaseline(CognateSet cognates, Random rand) {
        Map<CognateId, Double> scores = new HashMap<CognateId, Double>();
        for (Map.Entry<CognateId, String> entry : this.wordHeldoutEntries.entrySet()) {
            CognateId id = entry.getKey();
            List<Taxon> observed = new ArrayList<Taxon>(cognates.getObs(id).observedLanguages());
            if (observed.isEmpty()) {
                throw new IllegalStateException("No observed language for " + id);
            }
            Taxon randomLang = observed.get(rand.nextInt(observed.size()));
            Arbre<DerivationTree.DerivationNode> arbre = cognates.getTree(id);
            String randomBaseline = DerivationTree.findNodeByLangName(arbre, randomLang).getContents().getWord();
            scores.put(id, Heldout.editDistance(randomBaseline, entry.getValue()));
        }
        return new WordHeldoutEvaluation(scores);
    }

    /**
     * Oracle baseline: for each word entry, scores the observed word closest to the truth, i.e.
     * the minimum edit distance over all observed languages.
     *
     * @param cognates cognate sets providing observations and trees
     * @return per-id edit-distance scores
     */
    public WordHeldoutEvaluation oracleBaseline(CognateSet cognates) {
        Map<CognateId, Double> scores = new HashMap<CognateId, Double>();
        for (Map.Entry<CognateId, String> entry : this.wordHeldoutEntries.entrySet()) {
            CognateId id = entry.getKey();
            Arbre<DerivationTree.DerivationNode> arbre = cognates.getTree(id);
            List<Double> subScores = new ArrayList<Double>();
            // Named observedLang to avoid shadowing the held-out field 'lang'.
            for (Taxon observedLang : cognates.getObs(id).observedLanguages()) {
                String observedWord = DerivationTree.findNodeByLangName(arbre, observedLang).getContents().getWord();
                subScores.add(Heldout.editDistance(observedWord, entry.getValue()));
            }
            scores.put(id, Collections.min(subScores));
        }
        return new WordHeldoutEvaluation(scores);
    }

    /**
     * Formats the CS and SP scores of a reconstruction against the reference for {@code id}.
     *
     * @throws IllegalArgumentException if no reference exists for {@code id}
     */
    public String pointEstimateStats(CognateId id, MultiAlignment reconstruction) {
        MultiAlignment ref = this.multiAlignmentReferences.get(id);
        if (ref == null) {
            throw new IllegalArgumentException("No multi-alignment reference for " + id);
        }
        return "CS=" + ref.columnScore(reconstruction) + ",SP=" + ref.sumOfPairsScore(reconstruction);
    }

    /**
     * Like {@link #pointEstimateStats(CognateId, MultiAlignment)}, using the multi-alignment
     * induced by {@code reconstruction} over its modern observations.
     */
    public String pointEstimateStats(CognateId id, Arbre<DerivationTree.DerivationNode> reconstruction) {
        return this.pointEstimateStats(id, MultiAlignment.inducedMultiAlignment(reconstruction, ObservationsTracker.modernObservationsTracker(reconstruction)));
    }

    /** Writes an evaluation with Java serialization; the stream is closed even on failure. */
    public static void saveHeldoutEvaluation(WordHeldoutEvaluation heldout, String filePath) throws IOException {
        try (ObjectOutputStream oos = IOUtils.openBinOut(filePath)) {
            oos.writeObject(heldout);
        }
    }

    /**
     * Reads an evaluation written by {@link #saveHeldoutEvaluation}.
     * SECURITY: this uses Java native deserialization. Only use it on trusted files.
     */
    public static WordHeldoutEvaluation restoreHeldoutEvaluation(String filePath) throws IOException, ClassNotFoundException {
        try (ObjectInputStream ois = IOUtils.openBinIn(filePath)) {
            return (WordHeldoutEvaluation) ois.readObject();
        }
    }

    /** Writes held-out data with Java serialization; the stream is closed even on failure. */
    public static void saveHeldout(Heldout heldout, String filePath) throws IOException {
        try (ObjectOutputStream oos = IOUtils.openBinOut(filePath)) {
            oos.writeObject(heldout);
        }
    }

    /**
     * Reads held-out data written by {@link #saveHeldout}.
     * SECURITY: this uses Java native deserialization. Only use it on trusted files.
     */
    public static Heldout restoreHeldout(String filePath) throws IOException, ClassNotFoundException {
        try (ObjectInputStream ois = IOUtils.openBinIn(filePath)) {
            return (Heldout) ois.readObject();
        }
    }

    /**
     * Writes an evaluator with Java serialization. Because {@code BayesEvaluator} is an inner
     * class, its enclosing {@code Heldout} is serialized too.
     */
    public static void saveBayesEvaluator(BayesEvaluator be, String filePath) throws IOException {
        try (ObjectOutputStream oos = IOUtils.openBinOut(filePath)) {
            oos.writeObject(be);
        }
    }

    /**
     * Treats the counts as log-weights and replaces them in place with the normalized
     * distribution {@code exp(x - max) / sum}. Subtracting the maximum avoids overflow.
     *
     * @param probs log-weights; overwritten with probabilities on success
     * @return {@code false} (leaving {@code probs} untouched) if the counter is empty or its
     *     maximum is -infinity or NaN, since there is no mass to normalize; otherwise the result
     *     of {@link #normalize}
     */
    public static <T> boolean expNormalize(Counter<T> probs) {
        double max = Double.NEGATIVE_INFINITY;
        for (T key : probs.keySet()) {
            max = Math.max(max, probs.getCount(key));
        }
        // Without this guard, exp(-inf - -inf) = exp(NaN) would silently fill the counter with NaNs.
        if (Double.isNaN(max) || max == Double.NEGATIVE_INFINITY) {
            return false;
        }
        for (T key : probs.keySet()) {
            probs.setCount(key, Math.exp(probs.getCount(key) - max));
        }
        return Heldout.normalize(probs);
    }

    /**
     * Divides every count by the total so that the counts sum to one.
     *
     * @param data counts, modified in place on success
     * @return {@code false} (leaving {@code data} untouched) when the total is zero, e.g. when
     *     the counter is empty; {@code true} otherwise
     */
    public static <T> boolean normalize(Counter<T> data) {
        double sum = 0.0;
        for (T key : data.keySet()) {
            sum += data.getCount(key);
        }
        if (sum == 0.0) {
            return false;
        }
        for (T key : data.keySet()) {
            data.setCount(key, data.getCount(key) / sum);
        }
        return true;
    }

    /** Criterion used to pick a single multi-alignment among the collected samples. */
    private static enum MaxType {
        /** Oracle: the sample with the best column score against the reference. */
        ORACLE_CS,
        /** Oracle: the sample with the best sum-of-pairs score against the reference. */
        ORACLE_SP,
        /** The sample with the highest stored total log probability. */
        PR;

    }

    /**
     * Per-id word reconstruction scores, as produced by {@code editDistance}.
     * NOTE: this is a non-static inner class, so serializing it also serializes the enclosing
     * {@code Heldout}.
     */
    public class WordHeldoutEvaluation
    implements Serializable {
        private static final long serialVersionUID = 4L;
        private final Map<CognateId, Double> scores;

        private WordHeldoutEvaluation(Map<CognateId, Double> scores) {
            this.scores = scores;
        }

        /** @return the mean score over all ids; NaN when there are no scores (0/0) */
        public double averageOverWords() {
            double total = 0.0;
            for (Double score : this.scores.values()) {
                total += score.doubleValue();
            }
            return total / (double) this.scores.size();
        }
    }

    /**
     * Accumulates sampled derivation trees (presumably from a sampler; see
     * {@code initialSamplerState}) and turns them into minimum-Bayes-risk reconstructions, both
     * for held-out words and for multi-alignments.
     *
     * <p>Word samples are always counted by multiplicity. For multi-alignments,
     * {@code useTotalLogProb} selects what the counters hold:
     * <ul>
     *   <li>{@code false}: how many times each distinct alignment was sampled;</li>
     *   <li>{@code true}: the maximum total log probability seen for each alignment. These are
     *       exp-normalized before risk minimization, and only this mode supports the
     *       {@code MaxType} reconstructions.</li>
     * </ul>
     */
    public class BayesEvaluator
    implements Serializable {
        private static final long serialVersionUID = 2L;
        private final boolean useTotalLogProb;
        /** Word samples of the held-out taxon, per word-entry id. */
        private final Map<CognateId, Counter<String>> samples = new HashMap<CognateId, Counter<String>>();
        /** Word of the held-out taxon at construction time; used when no sample was recorded. */
        private final Map<CognateId, String> initialSamplerState = new HashMap<CognateId, String>();
        /** Multiplicities or max log probabilities of sampled multi-alignments (see class doc). */
        private final Map<CognateId, Counter<MultiAlignment>> multiAlignmentSampleCounters = new HashMap<CognateId, Counter<MultiAlignment>>();
        /** Multi-alignment induced by the tree at construction time; fallback when no sample exists. */
        private final Map<CognateId, MultiAlignment> initialMultiAlignmentState = new HashMap<CognateId, MultiAlignment>();
        /** Observations captured at construction time, reused to induce alignments from samples. */
        private final Map<CognateId, ObservationsTracker> observations = new HashMap<CognateId, ObservationsTracker>();
        /** When non-null, Bayes multi-alignment reconstructions are dumped here in MSF format. */
        private File dumpFolder = null;

        /**
         * Snapshots the initial state of every cognate set in {@code cognates} that has a held-out
         * word and/or a multi-alignment reference.
         *
         * @param cognates current cognate sets
         * @param useTotalLogProbabilities see the class documentation
         */
        public BayesEvaluator(CognateSet cognates, boolean useTotalLogProbabilities) {
            this.useTotalLogProb = useTotalLogProbabilities;
            for (CognateId id : cognates.getCognateIds()) {
                if (Heldout.this.wordHeldoutEntries.containsKey(id)) {
                    this.samples.put(id, new Counter<String>());
                    this.initialSamplerState.put(id, cognates.getWord(id, Heldout.this.lang));
                }
                if (Heldout.this.multiAlignmentReferences.containsKey(id)) {
                    Arbre<DerivationTree.DerivationNode> tree = cognates.getTree(id);
                    ObservationsTracker obs = cognates.getObs(id);
                    this.multiAlignmentSampleCounters.put(id, new Counter<MultiAlignment>());
                    this.initialMultiAlignmentState.put(id, MultiAlignment.inducedMultiAlignment(tree, obs));
                    this.observations.put(id, obs);
                }
            }
        }

        /** Equivalent to {@code BayesEvaluator(cognates, false)}: multi-alignments are counted by multiplicity. */
        public BayesEvaluator(CognateSet cognates) {
            this(cognates, false);
        }

        /**
         * Records one sample without a log probability.
         *
         * @throws IllegalStateException if this evaluator uses total log probabilities
         */
        public void processSample(CognateId id, Arbre<DerivationTree.DerivationNode> newTree) {
            if (this.useTotalLogProb) {
                throw new IllegalStateException("This evaluator needs a total log probability per sample");
            }
            this.processSample(id, newTree, null);
        }

        /**
         * Records one sampled tree for cognate set {@code id}.
         *
         * @param id cognate set id
         * @param newTree sampled derivation tree
         * @param totalLogPr total log probability of the sample; required (non-null) when this
         *     evaluator uses total log probabilities, ignored otherwise
         * @throws IllegalArgumentException if {@code totalLogPr} is required but null
         * @throws IllegalStateException if the tree has no word for the held-out taxon
         */
        public void processSample(CognateId id, Arbre<DerivationTree.DerivationNode> newTree, Double totalLogPr) {
            if (Heldout.this.wordHeldoutEntries.containsKey(id)) {
                String sample = DerivationTree.findNodeByLangName(newTree, Heldout.this.lang).getContents().getWord();
                if (sample == null) {
                    throw new IllegalStateException("Sampled tree has no word for " + Heldout.this.lang + " in " + id);
                }
                this.samples.get(id).incrementCount(sample, 1.0);
            }
            if (Heldout.this.multiAlignmentReferences.containsKey(id)) {
                ObservationsTracker obs = this.observations.get(id);
                MultiAlignment sample = MultiAlignment.inducedMultiAlignment(newTree, obs);
                Counter<MultiAlignment> counter = this.multiAlignmentSampleCounters.get(id);
                if (this.useTotalLogProb) {
                    if (totalLogPr == null) {
                        throw new IllegalArgumentException("Missing total log probability for sample of " + id);
                    }
                    // Keep the highest log probability seen for this exact alignment.
                    double logPr = totalLogPr.doubleValue();
                    if (counter.containsKey(sample)) {
                        logPr = Math.max(logPr, counter.getCount(sample));
                    }
                    counter.setCount(sample, logPr);
                } else {
                    counter.incrementCount(sample, 1.0);
                }
            }
        }

        /** Scores the minimum-Bayes-risk word reconstruction of every word entry. */
        public WordHeldoutEvaluation evaluate() {
            Map<CognateId, Double> scores = new HashMap<CognateId, Double>();
            for (Map.Entry<CognateId, String> entry : Heldout.this.wordHeldoutEntries.entrySet()) {
                String proposedReconstruction = this.heldoutWordReconstruction(entry.getKey());
                scores.put(entry.getKey(), Heldout.editDistance(proposedReconstruction, entry.getValue()));
            }
            return new WordHeldoutEvaluation(scores);
        }

        /**
         * Returns the word minimizing expected Levenshtein loss over the recorded samples, or the
         * initial word when no sample was recorded.
         *
         * @throws IllegalArgumentException if {@code id} has no word held-out entry
         */
        public String heldoutWordReconstruction(CognateId id) {
            if (!Heldout.this.wordHeldoutEntries.containsKey(id)) {
                throw new IllegalArgumentException("No word held-out entry for " + id);
            }
            Counter<String> sampleMultiplicities = this.samples.get(id);
            if (sampleMultiplicities.size() == 0) {
                return this.initialSamplerState.get(id);
            }
            BayesRiskMinimizer<String> riskMinimizer = new BayesRiskMinimizer<String>(new BayesRiskMinimizer.LevenshteinLoss());
            return riskMinimizer.findMin(sampleMultiplicities);
        }

        /** Minimum-Bayes-risk alignment under the SP ({@code useSP}) or CS loss. */
        private MultiAlignment multiAlignmentBayesReconstruction(CognateId id, boolean useSP) {
            return this.multiAlignmentReconstruction(id, new MultiAlignment.MALossFunction(useSP));
        }

        /**
         * Returns the sampled alignment that maximizes the criterion {@code maxType}. The first
         * maximum wins; the initial alignment is returned when there are no samples.
         *
         * @throws IllegalStateException if this evaluator does not use total log probabilities
         */
        private MultiAlignment multiAlignmentMaxReconstruction(CognateId id, MaxType maxType) {
            if (!this.useTotalLogProb) {
                throw new IllegalStateException("Max reconstructions require total log probabilities");
            }
            MultiAlignment reference = Heldout.this.multiAlignmentReferences.get(id);
            Counter<MultiAlignment> logProbs = this.multiAlignmentSampleCounters.get(id);
            double max = Double.NEGATIVE_INFINITY;
            MultiAlignment argMax = this.initialMultiAlignmentState.get(id);
            for (MultiAlignment current : logProbs.keySet()) {
                double currentValue;
                switch (maxType) {
                    case ORACLE_SP:
                        currentValue = reference.sumOfPairsScore(current);
                        break;
                    case ORACLE_CS:
                        currentValue = reference.columnScore(current);
                        break;
                    case PR:
                        currentValue = logProbs.getCount(current);
                        break;
                    default:
                        throw new IllegalArgumentException("Unknown max type: " + maxType);
                }
                // Strict '>' skips NaN values and keeps the first maximum found.
                if (currentValue > max) {
                    max = currentValue;
                    argMax = current;
                }
            }
            return argMax;
        }

        /**
         * Returns the sampled alignment that minimizes expected {@code loss}, or the initial
         * alignment when there are no samples. When a dump folder is set, the result is also
         * written to a file named after {@code id}.
         *
         * @throws RuntimeException if {@code id} has no reference or no sample counter
         */
        private MultiAlignment multiAlignmentReconstruction(CognateId id, BayesRiskMinimizer.LossFct<MultiAlignment> loss) {
            if (!Heldout.this.multiAlignmentReferences.containsKey(id) || !this.multiAlignmentSampleCounters.containsKey(id)) {
                throw new RuntimeException("Id " + id + " not found");
            }
            Counter<MultiAlignment> sampleWeights = this.multiAlignmentSampleCounters.get(id);
            if (sampleWeights.size() == 0) {
                return this.initialMultiAlignmentState.get(id);
            }
            if (this.useTotalLogProb) {
                // The counts are log probabilities. Normalize a copy so the stored values stay intact.
                sampleWeights = new Counter<MultiAlignment>(sampleWeights);
                if (!Heldout.expNormalize(sampleWeights)) {
                    LogInfo.warning("Exp norm failed");
                }
            }
            BayesRiskMinimizer<MultiAlignment> riskMinimizer = new BayesRiskMinimizer<MultiAlignment>(loss);
            MultiAlignment result = riskMinimizer.findMin(sampleWeights);
            if (this.dumpFolder != null) {
                this.dumpMultiAlign(id, result);
            }
            return result;
        }

        /**
         * Writes {@code result} in MSF format to {@code dumpFolder/<id>}, overwriting any
         * previous dump for the same id.
         *
         * @throws RuntimeException wrapping the I/O failure if the file cannot be written
         */
        private void dumpMultiAlign(CognateId id, MultiAlignment result) {
            String msfFormatted = result.createAlignmentMatrix().toMSFString();
            File msfFile = new File(this.dumpFolder, id.toString());
            try (PrintWriter out = IOUtils.openOut(msfFile)) {
                out.append(msfFormatted);
                // PrintWriter swallows write errors; surface them explicitly.
                if (out.checkError()) {
                    throw new IOException("Error while writing " + msfFile);
                }
            } catch (IOException e) {
                throw new RuntimeException("Could not dump multi-alignment for " + id + " to " + msfFile, e);
            }
        }

        /** Loss of the minimum-Bayes-risk reconstruction for {@code id} against its reference. */
        public double multiAlignmentLoss(CognateId id, BayesRiskMinimizer.LossFct<MultiAlignment> loss) {
            MultiAlignment reconstruction = this.multiAlignmentReconstruction(id, loss);
            MultiAlignment reference = Heldout.this.multiAlignmentReferences.get(id);
            return loss.loss(reference, reconstruction);
        }

        /**
         * Averages the SP and CS scores of six reconstruction strategies over all references and
         * logs them as a table: Oracle-CS, Oracle-SP, Max Pr, Bayes-CS, Bayes-SP, Init.
         * Requires total log probabilities, because the Max reconstructions need them. If a dump
         * folder is set, the Bayes-SP result is computed after Bayes-CS and overwrites its dump.
         *
         * @return the average SP score of the Bayes-SP reconstruction (NaN if there are no references)
         */
        public double evaluateMultiAlignments() {
            final int nMetrics = 6;
            final int bayesSpIndex = 4;
            final double[] csSums = new double[nMetrics];
            final double[] spSums = new double[nMetrics];
            final double nInstances = Heldout.this.multiAlignmentReferences.size();
            for (Map.Entry<CognateId, MultiAlignment> entry : Heldout.this.multiAlignmentReferences.entrySet()) {
                CognateId id = entry.getKey();
                MultiAlignment reference = entry.getValue();
                // Order must match the column labels set in the table below.
                MultiAlignment[] reconstructions = new MultiAlignment[]{
                    this.multiAlignmentMaxReconstruction(id, MaxType.ORACLE_CS),
                    this.multiAlignmentMaxReconstruction(id, MaxType.ORACLE_SP),
                    this.multiAlignmentMaxReconstruction(id, MaxType.PR),
                    this.multiAlignmentBayesReconstruction(id, false),
                    this.multiAlignmentBayesReconstruction(id, true),
                    this.initialMultiAlignmentState.get(id)};
                for (int i = 0; i < nMetrics; ++i) {
                    csSums[i] += reference.columnScore(reconstructions[i]);
                    spSums[i] += reference.sumOfPairsScore(reconstructions[i]);
                }
            }
            Table table = new Table(new Table.Populator(){

                @Override
                public void populate() {
                    this.set(0, 0, "Metric name");
                    this.set(1, 0, "SP");
                    this.set(2, 0, "CS");
                    this.set(0, 1, "Oracle-CS");
                    this.set(0, 2, "Oracle-SP");
                    this.set(0, 3, "Max Pr");
                    this.set(0, 4, "Bayes-CS");
                    this.set(0, 5, "Bayes-SP");
                    this.set(0, 6, "Init");
                    for (int i = 0; i < csSums.length; ++i) {
                        this.set(1, i + 1, spSums[i] / nInstances);
                        this.set(2, i + 1, csSums[i] / nInstances);
                    }
                }
            });
            LogInfo.logs(table);
            return spSums[bayesSpIndex] / nInstances;
        }

        @Override
        public String toString() {
            return this.heldoutWordReconstructionToString();
        }

        /** Describes, per word entry, the truth, the MBR reconstruction and the sample counts. */
        public String heldoutWordReconstructionToString() {
            StringBuilder builder = new StringBuilder();
            if (Heldout.this.lang != null) {
                builder.append("Samples used for computing the min Bayes risk estimate for " + Heldout.this.lang + " heldout\n\n");
            }
            for (CognateId id : this.samples.keySet()) {
                builder.append("Current word id: " + id.toString() + "\n");
                builder.append("\tTruth is: " + Heldout.this.wordHeldoutEntries.get(id) + "\n");
                builder.append("\tMin Bayes reconstruction is: " + this.heldoutWordReconstruction(id) + "\n");
                builder.append("\tSamples used: \n");
                Counter<String> sampleMultiplicities = this.samples.get(id);
                for (String sample : sampleMultiplicities) {
                    builder.append("\t\t" + sample + "(" + sampleMultiplicities.getCount(sample) + " times)\n");
                }
            }
            return builder.toString();
        }

        /**
         * Enables MSF dumps of Bayes multi-alignment reconstructions into the folder
         * {@code string}, creating it (and any missing parents) if needed.
         */
        public void setDumpFolder(String string) {
            this.dumpFolder = new File(string);
            if (!this.dumpFolder.isDirectory() && !this.dumpFolder.mkdirs()) {
                LogInfo.warning("Could not create dump folder " + this.dumpFolder);
            }
        }
    }
}

