/*
 * Decompiled with CFR 0.152.
 */
package sage;

import fig.basic.LogInfo;
import goblin.CognateId;
import goblin.DerivationTree;
import goblin.HLBaseMeasures;
import goblin.HLParams;
import goblin.HLParamsUpdater;
import java.io.IOException;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.concurrent.ConcurrentHashMap;
import nuts.io.Extensions;
import nuts.maxent.BaseMeasures;
import nuts.maxent.LabeledInstance;
import nuts.maxent.MaxentClassifier;
import nuts.util.Arbre;
import nuts.util.Counter;
import pepper.Encodings;
import pepper.editmodel.Utils;
import sage.FatContext;
import sage.FatFeatureExtractor;

/**
 * Scores derivation trees: implementations return a log likelihood for a tree
 * associated with a given cognate id.
 */
public interface LikelihoodModel {
    /**
     * Returns the log likelihood of a tree whose nodes are {@link DerivationTree.LineagedNode}s.
     *
     * @param state the lineaged tree to score
     * @param id    the cognate id the tree is associated with
     * @return the log likelihood of {@code state}
     */
    public double partialLogLikelihood(Arbre<DerivationTree.LineagedNode> state, CognateId id);

    /**
     * Returns the log likelihood of a tree whose nodes are {@link DerivationTree.DerivationNode}s.
     *
     * @param state the derivation tree to score
     * @param id    the cognate id the tree is associated with
     * @return the log likelihood of {@code state}
     */
    public double fullLogLikelihood(Arbre<DerivationTree.DerivationNode> state, CognateId id);

    /**
     * {@link LikelihoodModel} backed by a {@link MaxentClassifier} over {@link FatContext} inputs
     * and {@link HLParams.HLOutcome} labels.
     *
     * <p>Per-(context, outcome) log probabilities are memoized in a {@link ConcurrentHashMap}.
     * The memo is discarded whenever {@link #update} retrains the classifier.
     */
    public static class FatLikelihoodModel
    implements LikelihoodModel {
        private final FatBaseMeasure baseMeasure;
        // Replaced wholesale (never cleared in place) by update().
        private Map<CacheKey, Double> cachedLogLikelihood = new ConcurrentHashMap<CacheKey, Double>();
        private MaxentClassifier<FatContext, HLParams.HLOutcome, Object> maxentClassifier;
        private final FatFeatureExtractor fatExtractor;
        private final Encodings enc;
        public final MaxentClassifier.MaxentOptions<Object> learningOptions;
        // Passed unchanged to MaxentClassifier.learnMaxentClassifier on every update();
        // presumably the regularization center — confirm against MaxentClassifier.
        private final Counter reguCenter;

        /**
         * Creates a model whose initial classifier is built from explicit weights.
         *
         * @param enc             encodings used when extracting sufficient statistics
         * @param fatExtractor    feature extractor for the classifier; also supplies the granularities
         * @param learningOptions options used (with the then-current weights) when retraining
         * @param initialWeights  weights for the initial classifier
         * @param reguCenter      counter passed to the learner on each {@link #update}
         */
        public FatLikelihoodModel(Encodings enc, FatFeatureExtractor fatExtractor, MaxentClassifier.MaxentOptions<Object> learningOptions, Counter<Object> initialWeights, Counter<Object> reguCenter) {
            this.reguCenter = reguCenter;
            this.enc = enc;
            this.fatExtractor = fatExtractor;
            this.learningOptions = learningOptions;
            this.baseMeasure = new FatBaseMeasure(new HLBaseMeasures(enc));
            this.maxentClassifier = MaxentClassifier.createMaxentClassifierFromWeights(this.baseMeasure, initialWeights, fatExtractor);
        }

        /**
         * Creates a model with empty initial weights and an empty {@code reguCenter}.
         *
         * @param enc             encodings used when extracting sufficient statistics
         * @param fatExtractor    feature extractor for the classifier
         * @param learningOptions options used when retraining
         */
        public FatLikelihoodModel(Encodings enc, FatFeatureExtractor fatExtractor, MaxentClassifier.MaxentOptions<Object> learningOptions) {
            this(enc, fatExtractor, learningOptions, new Counter<Object>(), new Counter<Object>());
        }

        /**
         * Converts {@code state} with {@link DerivationTree#fullLineage} and scores the result
         * with {@link #partialLogLikelihood}.
         */
        @Override
        public double fullLogLikelihood(Arbre<DerivationTree.DerivationNode> state, CognateId id) {
            return this.partialLogLikelihood(DerivationTree.fullLineage(state), id);
        }

        /**
         * Collects the (context, outcome) counts of {@code state} via
         * {@link FatContext#addSuffStatsFromLineagedTree} and returns the count-weighted sum
         * of their {@link #logLikelihood} values.
         */
        @Override
        public double partialLogLikelihood(Arbre<DerivationTree.LineagedNode> state, CognateId id) {
            double result = 0.0;
            Counter<LabeledInstance<FatContext, HLParams.HLOutcome>> ops = new Counter<LabeledInstance<FatContext, HLParams.HLOutcome>>();
            FatContext.addSuffStatsFromLineagedTree(ops, state, this.granularities(), this.enc, id);
            for (LabeledInstance<FatContext, HLParams.HLOutcome> op : ops.keySet()) {
                result += ops.getCount(op) * this.logLikelihood(op);
            }
            return result;
        }

        /**
         * Returns the labels the current classifier assigns to {@code context}.
         *
         * @param context the input context
         * @return the classifier's label set for {@code context}
         */
        public SortedSet<HLParams.HLOutcome> outcomes(FatContext context) {
            return this.maxentClassifier.getLabels(context);
        }

        /** Granularities reported by the feature extractor. */
        private Set<FatContext.Granularity> granularities() {
            return this.fatExtractor.granularities();
        }

        /**
         * Returns the classifier's log probability of {@code operation}'s outcome given its input.
         *
         * <p>On a cache miss, the log probabilities of every label of the input context are
         * computed at once and all of them are cached.
         *
         * @param operation the (context, outcome) pair to score
         * @return the log probability of the outcome given the context
         * @throws IllegalArgumentException if the outcome is not among the classifier's labels
         *         for the context
         */
        public double logLikelihood(LabeledInstance<FatContext, HLParams.HLOutcome> operation) {
            // Read the cache reference once. update() swaps in a fresh map, and re-reading
            // the field between the lookup and the final get() could observe two maps.
            Map<CacheKey, Double> cache = this.cachedLogLikelihood;
            CacheKey key = new CacheKey(operation);
            Double cached = cache.get(key);
            if (cached != null) {
                return cached;
            }
            FatContext context = operation.getInput();
            // Use a single classifier snapshot so that probabilities and labels agree.
            MaxentClassifier<FatContext, HLParams.HLOutcome, Object> classifier = this.maxentClassifier;
            double[] logPrs = classifier.logProb(context);
            SortedSet<HLParams.HLOutcome> outcomes = classifier.getLabels(context);
            // NOTE(review): assumes logPrs[i] belongs to the i-th label in getLabels(context)'s
            // iteration order — confirm against MaxentClassifier.
            int i = 0;
            for (HLParams.HLOutcome outcome : outcomes) {
                cache.put(new CacheKey(context, outcome), logPrs[i]);
                ++i;
            }
            Double result = cache.get(key);
            if (result == null) {
                // Previously this surfaced as an opaque NullPointerException from unboxing.
                throw new IllegalArgumentException("Outcome of operation " + operation
                        + " is not among the classifier labels for its context");
            }
            return result;
        }

        /**
         * Retrains the classifier from {@code suffStats}, then discards the log-likelihood cache.
         *
         * <p>Training starts from {@link #learningOptions} cloned with the current weights. If
         * training throws, the previous classifier and cache are left in place.
         *
         * @param suffStats counts of (context, outcome) instances to learn from
         */
        public void update(Counter<LabeledInstance<FatContext, HLParams.HLOutcome>> suffStats) {
            LogInfo.track("Learning maxent model from expected suff stats");
            try {
                MaxentClassifier.MaxentOptions<Object> curOpt = MaxentClassifier.MaxentOptions.cloneWithWeights(this.learningOptions, this.currentWeights());
                this.maxentClassifier = MaxentClassifier.learnMaxentClassifier(this.baseMeasure, suffStats, this.fatExtractor, curOpt, this.reguCenter);
            }
            finally {
                // Keep LogInfo's track nesting balanced even if learning fails.
                LogInfo.end_track();
            }
            this.cachedLogLikelihood = new ConcurrentHashMap<CacheKey, Double>();
        }

        /** Returns the current classifier's weights. */
        public Counter currentWeights() {
            return this.maxentClassifier.weights();
        }

        /**
         * Saves the current weights to {@code prefix[<iteration>].weights} under the
         * execution directory.
         *
         * @param prefix    file-name prefix
         * @param iteration optional iteration number appended via
         *                  {@link Extensions#extension2String}; omitted when {@code null}
         * @throws RuntimeException wrapping the {@link IOException} if saving fails
         */
        public void saveWeightsInExec(String prefix, Integer iteration) {
            String fileName = prefix + (iteration == null ? "" : Extensions.extension2String(iteration)) + ".weights";
            try {
                HLParamsUpdater.saveCounter(this.currentWeights(), Utils.safeGetExecFilePath(fileName));
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        /** Returns the encodings held by the underlying {@link HLBaseMeasures}. */
        public Encodings getEncodings() {
            return this.baseMeasure.hlBaseMeasure.enc;
        }

        /**
         * Adapts an {@link HLBaseMeasures} to {@link FatContext} inputs by delegating on the
         * context's base context.
         */
        public static class FatBaseMeasure
        implements BaseMeasures<FatContext, HLParams.HLOutcome> {
            private static final long serialVersionUID = 1L;
            private final HLBaseMeasures hlBaseMeasure;

            public FatBaseMeasure(HLBaseMeasures hlBaseMeasure) {
                this.hlBaseMeasure = hlBaseMeasure;
            }

            /** Returns the delegate's support for {@code input.getBaseContext()}. */
            @Override
            public SortedSet<HLParams.HLOutcome> support(FatContext input) {
                return this.hlBaseMeasure.support(input.getBaseContext());
            }
        }

        /**
         * Cache key wrapping a (context, outcome) instance. The wrapped instance's hash code
         * is computed once at construction.
         */
        private static class CacheKey {
            private final LabeledInstance<FatContext, HLParams.HLOutcome> instance;
            private final int hash;

            private CacheKey(FatContext ctxt, HLParams.HLOutcome out) {
                this(LabeledInstance.create(out, ctxt));
            }

            private CacheKey(LabeledInstance<FatContext, HLParams.HLOutcome> instance) {
                this.instance = instance;
                this.hash = instance.hashCode();
            }

            @Override
            public int hashCode() {
                return this.hash;
            }

            /**
             * Two keys are equal when their wrapped instances are equal. Returns {@code false}
             * for {@code null} or non-{@code CacheKey} arguments instead of throwing.
             */
            @Override
            public boolean equals(Object obj) {
                if (this == obj) {
                    return true;
                }
                if (!(obj instanceof CacheKey)) {
                    return false;
                }
                return ((CacheKey)obj).instance.equals(this.instance);
            }
        }
    }
}

