/*
 * Decompiled with CFR 0.152.
 */
package ev.par;

import ev.hmm.HetPairHMM;
import ev.par.CachedParams;
import ev.par.FeatureExtractor;
import ev.par.Input;
import ev.par.Model;
import ev.par.Output;
import ev.par.StrTaxonSuffStat;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.UnorderedPair;
import fig.exec.Execution;
import goblin.HLParamsUpdater;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Map;
import ma.GreedyDecoder;
import ma.MSAPoset;
import ma.MultiAlignment;
import ma.SequenceType;
import nuts.maxent.BaseMeasures;
import nuts.maxent.LabeledInstance;
import nuts.maxent.MaxentClassifier;
import nuts.util.CollUtils;
import nuts.util.Counter;
import pepper.Encodings;

public final class ExponentialFamily {
    private CachedParams cachedParams;
    private Counter<Object> naturalParams;
    private Counter<Object> regularizationCenters;
    public Counter<LabeledInstance<Input, Output>> suffStats = new Counter();
    public final MaxentClassifier.MaxentOptions<Object> learningOptions;
    public final Model model;
    public final BaseMeasures<Input, Output> bm;
    public final FeatureExtractor featureExtractor;

    public String toString() {
        return this.cachedParams.toString();
    }

    public ExponentialFamily(Counter<Object> naturalParams, Counter<Object> regularizationCenters, MaxentClassifier.MaxentOptions<Object> learningOptions, Model model, BaseMeasures<Input, Output> bm, FeatureExtractor featureExtractor) {
        this.naturalParams = naturalParams;
        this.regularizationCenters = regularizationCenters;
        this.learningOptions = learningOptions;
        this.model = model;
        this.bm = bm;
        this.featureExtractor = featureExtractor;
        this.initParameters();
    }

    public static ExponentialFamily createExpfam(MaxentClassifier.MaxentOptions<Object> learningOptions, ExponentialFamilyOptions options, FeatureExtractor.FeatureOptions fo, Map<UnorderedPair<Taxon, Taxon>, Double> distances) {
        Counter<Object> initParams = options.getInitCounter();
        Counter<Object> centerParams = options.getCenterCounter();
        Encodings enc = options.encodingType.getEncodings();
        FeatureExtractor fe = new FeatureExtractor(distances, fo);
        Model model = Model.stdBranchSpecificModel(enc, fe.getStrTaxonSuffStat());
        Model.ThreeStatesBaseMeasure tsmb = new Model.ThreeStatesBaseMeasure(model);
        return new ExponentialFamily(initParams, centerParams, learningOptions, model, tsmb, fe);
    }

    public void updateParameters() {
        MaxentClassifier.MaxentOptions<Object> currentLearningOptions = MaxentClassifier.MaxentOptions.cloneWithWeights(this.learningOptions, this.naturalParams);
        MaxentClassifier<Input, Output, Object> maxentClassifier = MaxentClassifier.learnMaxentClassifier(this.bm, this.suffStats, this.featureExtractor, currentLearningOptions, this.regularizationCenters);
        this.naturalParams = maxentClassifier.weights();
        this.cachedParams = new CachedParams(this.model, maxentClassifier);
        this.suffStats = new Counter();
    }

    public void saveWeightsInExec(String name) {
        try {
            this.saveWeights(new File(Execution.getFile(name)));
        }
        catch (Exception e) {
            LogInfo.warning("Saving weights failed: " + e.toString());
        }
    }

    public void saveWeights(File f) {
        try {
            HLParamsUpdater.saveCounter(this.naturalParams, f.getAbsolutePath());
        }
        catch (Exception e) {
            LogInfo.warning("Saving weights failed: " + e.toString());
        }
    }

    private void initParameters() {
        MaxentClassifier<Input, Output, Object> maxentClassifier = MaxentClassifier.createMaxentClassifierFromWeights(this.bm, this.naturalParams, this.featureExtractor);
        this.cachedParams = new CachedParams(this.model, maxentClassifier);
        this.suffStats = new Counter();
    }

    public void addSufficientStatistics(HetPairHMM pairHMM, Taxon topTaxon, Taxon botTaxon) {
        this.addSufficientStatistics(this.suffStats, pairHMM, topTaxon, botTaxon);
    }

    public void addSufficientStatistics(Counter<LabeledInstance<Input, Output>> suffStats, HetPairHMM pairHMM, Taxon topTaxon, Taxon botTaxon) {
        StrTaxonSuffStat.StrTaxonSuffStatExtractor extractor = this.model.stSuffStat.getExtractor(pairHMM.str1, pairHMM.str2, topTaxon, botTaxon);
        double logSumPr = pairHMM.logSumProduct();
        for (int xpos = 0; xpos <= pairHMM.str1.length(); ++xpos) {
            for (int dx = 0; dx < 2 && xpos + dx <= pairHMM.str1.length(); ++dx) {
                int xid = this.model.charIdAt(pairHMM.str1, xpos, dx);
                for (int ypos = 0; ypos <= pairHMM.str2.length(); ++ypos) {
                    for (int dy = 0; dy < 2 && ypos + dy <= pairHMM.str2.length(); ++dy) {
                        if (dx <= 0 && dy <= 0) continue;
                        int yid = this.model.charIdAt(pairHMM.str2, ypos, dy);
                        int gss = extractor.extract(xpos, ypos);
                        for (int s1 = 0; s1 < this.model.nStates; ++s1) {
                            for (int s2 = 0; s2 < this.model.nStates; ++s2) {
                                Input in = new Input(s1, gss, this.model);
                                Output out = new Output(s2, xid, yid, this.model);
                                double value = Math.exp(pairHMM.logSumProduct(s1, s2, xpos, ypos, dx, dy) - logSumPr);
                                if (!(value > 0.0)) continue;
                                suffStats.incrementCount(new LabeledInstance<Input, Output>(out, in), value);
                            }
                        }
                    }
                }
            }
        }
    }

    public static void main(String[] args) {
    }

    public HetPairHMM getReweightedHMM(double[][][] logWeights, String top, String bot, Taxon topL, Taxon botL) {
        top = top + this.model.enc.boundChar();
        bot = bot + this.model.enc.boundChar();
        return new HetPairHMM(top, bot, this.cachedParams.getReweightedHMM(logWeights, top, bot, topL, botL));
    }

    public HetPairHMM getHMM(String top, String bot, Taxon topL, Taxon botL) {
        top = top + this.model.enc.boundChar();
        bot = bot + this.model.enc.boundChar();
        return new HetPairHMM(top, bot, this.cachedParams.getUnsupPairHMM(top, bot, topL, botL));
    }

    public HetPairHMM getSupervisedHMM(MultiAlignment truth, Taxon topL, Taxon botL) {
        String top = truth.getSequences().get(topL) + this.model.enc.boundChar();
        String bot = truth.getSequences().get(botL) + this.model.enc.boundChar();
        return new HetPairHMM(top, bot, this.cachedParams.getSupPairHMM(truth, top, bot, topL, botL));
    }

    public Counter<GreedyDecoder.Edge> allPairsPosterior(Map<Taxon, String> sequences) {
        Counter<GreedyDecoder.Edge> edgePosteriors = new Counter<GreedyDecoder.Edge>();
        ArrayList<Taxon> langs = CollUtils.list(sequences.keySet());
        for (int i = 0; i < langs.size(); ++i) {
            Taxon l1 = (Taxon)langs.get(i);
            for (int j = i + 1; j < langs.size(); ++j) {
                Taxon l2 = (Taxon)langs.get(j);
                String s1 = sequences.get(l1);
                String s2 = sequences.get(l2);
                HetPairHMM hmm = this.getHMM(s1, s2, l1, l2);
                for (int p1 = 0; p1 < s1.length(); ++p1) {
                    for (int p2 = 0; p2 < s2.length(); ++p2) {
                        GreedyDecoder.Edge current = new GreedyDecoder.Edge(p1, p2, l1, l2);
                        edgePosteriors.setCount(current, Math.exp(hmm.logPosteriorAlignment(p1, p2)));
                    }
                }
            }
        }
        return edgePosteriors;
    }

    public MSAPoset maxRecallAlignFromAllPairs(Map<Taxon, String> sequences) {
        return MSAPoset.maxRecallMSA(sequences, this.allPairsPosterior(sequences));
    }

    public static class ExponentialFamilyOptions {
        @Option
        public SequenceType encodingType = SequenceType.PROTEIN;
        @Option
        public String initParams = "ZERO";
        @Option
        public String reguCenterParams = "SAME";
        public static final String SAME = "SAME";
        public static final String ZERO = "ZERO";
        public static final String INTERNAL = "INTERNAL";
        public Counter<Object> internal = new Counter();

        public Counter<Object> getInitCounter() {
            if (SAME.equals(this.initParams)) {
                throw new RuntimeException();
            }
            if (INTERNAL.equals(this.initParams)) {
                return new Counter<Object>(this.internal);
            }
            if (ZERO.equals(this.initParams)) {
                return new Counter<Object>();
            }
            return HLParamsUpdater.restoreCounter(this.initParams);
        }

        public Counter<Object> getCenterCounter() {
            if (SAME.equals(this.reguCenterParams)) {
                return this.getInitCounter();
            }
            if (INTERNAL.equals(this.reguCenterParams)) {
                return new Counter<Object>(this.internal);
            }
            if (ZERO.equals(this.reguCenterParams)) {
                return new Counter<Object>();
            }
            return HLParamsUpdater.restoreCounter(this.reguCenterParams);
        }
    }
}

