/*
 * Decompiled with CFR 0.152.
 */
package goblin;

import fig.basic.IOUtils;
import fig.basic.LogInfo;
import goblin.HLBaseMeasures;
import goblin.HLFeatureExtractor;
import goblin.HLParams;
import goblin.Taxon;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Set;
import nuts.io.Extensions;
import nuts.maxent.LabeledInstance;
import nuts.maxent.MaxentClassifier;
import nuts.util.Counter;
import pepper.Encodings;
import pepper.editmodel.Utils;

public class HLParamsUpdater {
    private final HLBaseMeasures baseMeasure;
    private final HLFeatureExtractor featureExtractor;
    public final MaxentClassifier.MaxentOptions learningOptions;
    public Counter<String> previousWeights = new Counter();
    private final Set<Taxon> languages;
    private final int nThreads;
    private final Counter regularizerCenters;

    public HLParamsUpdater(Encodings enc, Set<Taxon> languages, HLFeatureExtractor featureExtractor, MaxentClassifier.MaxentOptions options, Counter<String> regularizerCenters, int nThreads) {
        this.nThreads = nThreads;
        this.baseMeasure = new HLBaseMeasures(enc);
        this.languages = languages;
        this.featureExtractor = featureExtractor;
        this.learningOptions = options;
        this.regularizerCenters = regularizerCenters;
    }

    public Counter currentWeights() {
        return this.previousWeights;
    }

    public HLParams update(Counter<LabeledInstance<HLParams.HLContext, HLParams.HLOutcome>> suffStats) {
        LogInfo.track("Fitting maxent model from suff stats");
        MaxentClassifier.MaxentOptions<String> currentLearningOptions = MaxentClassifier.MaxentOptions.cloneWithWeights(this.learningOptions, this.previousWeights);
        MaxentClassifier<HLParams.HLContext, HLParams.HLOutcome, String> maxentClassifier = MaxentClassifier.learnMaxentClassifier(this.baseMeasure, suffStats, this.featureExtractor, currentLearningOptions, this.regularizerCenters);
        LogInfo.end_track();
        this.previousWeights = maxentClassifier.weights();
        return HLParams.createHLParamsFromMaxent(this.baseMeasure.enc, maxentClassifier, this.languages, this.nThreads);
    }

    public void saveWeightsInExec(String prefix) {
        this.saveWeightsInExec(prefix, null);
    }

    public void saveWeightsInExec(String prefix, Integer iteration) {
        String fileName = prefix + (iteration == null ? "" : Extensions.extension2String(iteration)) + ".weights";
        try {
            HLParamsUpdater.saveCounter(this.currentWeights(), Utils.safeGetExecFilePath(fileName));
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public static void saveCounterHard(Counter weights, String file) {
        try {
            HLParamsUpdater.saveCounter(weights, file);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public static void saveCounter(Counter weights, String file) throws IOException {
        PrintWriter out = IOUtils.openOut(file);
        for (Object key : weights) {
            out.append(key.toString() + "\t" + weights.getCount(key) + "\n");
        }
        out.close();
    }

    public static Counter restoreCounter(String filePath) {
        Counter<String> result = new Counter<String>();
        BufferedReader br = IOUtils.openInHard(filePath);
        String line = null;
        try {
            while ((line = br.readLine()) != null) {
                if (line.equals("")) continue;
                String[] fields = line.split("\\t+");
                String f = fields[0];
                double w = Double.parseDouble(fields[1]);
                if (result.containsKey(f)) {
                    throw new RuntimeException("Duplicate entries for " + f + " in " + filePath);
                }
                result.setCount(f, w);
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(HLParamsUpdater.restoreCounter("data/wDna"));
    }
}

