/*
 * Decompiled with CFR 0.152.
 */
package nuts.maxent;

import fig.basic.Indexer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import nuts.maxent.BaseMeasures;
import nuts.maxent.FeatureExtractor;
import nuts.maxent.FeatureVectorsInterface;
import nuts.maxent.LabeledInstance;
import nuts.maxent.SparseVector;
import nuts.util.Counter;

/**
 * Maps labeled instances to sparse feature vectors over a fixed feature index.
 *
 * <p>An instance is built around an {@link Indexer} that assigns an integer
 * index to every known feature and a {@link FeatureExtractor} that turns a
 * {@link LabeledInstance} into a {@link Counter} of features. Vectors that have
 * been computed may be kept in an internal map so they are not re-extracted.
 *
 * <p>Instances are obtained through the static factories
 * {@link #createFeatureVectors} and {@link #createFeatureVectorsFromSet}; the
 * constructor is private.
 *
 * @param <I> input type of the labeled instances
 * @param <L> label type of the labeled instances
 * @param <F> feature type produced by the extractor
 */
public class FeatureVectors<I, L, F>
implements FeatureVectorsInterface<I, L> {
    /** Cache of already extracted vectors, keyed by labeled instance. */
    private final Map<LabeledInstance<I, L>, SparseVector> featureVectors = new HashMap<LabeledInstance<I, L>, SparseVector>();
    /** Feature-to-index mapping shared by every vector produced here. */
    private final Indexer<F> indexer;
    /** Produces the feature counts for a labeled instance. */
    private final FeatureExtractor<LabeledInstance<I, L>, F> extractor;
    /** Number of indexed features, captured from the indexer at construction time. */
    private final int dim;
    /**
     * Per-feature regularization factors, indexed like the features.
     * This is {@code null} for instances built by {@link #createFeatureVectorsFromSet}.
     */
    private final double[] regularizationFactors;

    /**
     * Builds feature vectors for a training set.
     *
     * <p>The feature index is computed from every (input, label) pair where the
     * label ranges over {@code baseMeasures.support(input)}; regularization
     * factors are then queried from the extractor for each indexed feature, and
     * finally the vector of every training instance is extracted and cached.
     *
     * @param baseMeasures supplies the candidate labels for each input
     * @param training the labeled training instances
     * @param extractor the feature extractor
     * @return the populated feature vectors
     * @throws RuntimeException if a training instance yields a feature missing from the index
     */
    public static <I, L, F> FeatureVectorsInterface<I, L> createFeatureVectors(BaseMeasures<I, L> baseMeasures, Set<LabeledInstance<I, L>> training, FeatureExtractor<LabeledInstance<I, L>, F> extractor) {
        Indexer<F> indexer = FeatureVectors.computeIndexes(baseMeasures, training, extractor);
        double[] regularizationFactor = FeatureVectors.computeRegFactors(indexer, extractor);
        FeatureVectors<I, L, F> result = new FeatureVectors<I, L, F>(extractor, indexer, regularizationFactor);
        for (LabeledInstance<I, L> labeledInstance : training) {
            // Must target the freshly built instance: this is a static method,
            // so there is no "this"/"super" receiver to call through.
            result.addFeatureVector(labeledInstance);
        }
        return result;
    }

    /**
     * Builds feature vectors over a predefined feature set.
     *
     * <p>No vectors are pre-computed and no regularization factors are stored,
     * so {@link #getRegularizationFactor(int)} will throw a
     * {@link NullPointerException} on the returned object.
     *
     * @param features the features to index
     * @param extractor the feature extractor
     * @return feature vectors with an empty cache
     */
    public static <I, L, F> FeatureVectorsInterface<I, L> createFeatureVectorsFromSet(Set<F> features, FeatureExtractor<LabeledInstance<I, L>, F> extractor) {
        Indexer<F> indexer = FeatureVectors.computeIndexesFromSet(features);
        return new FeatureVectors<I, L, F>(extractor, indexer, null);
    }

    /**
     * @return the number of indexed features
     */
    @Override
    public int dim() {
        return this.dim;
    }

    /**
     * @return the number of feature vectors currently held in the cache
     */
    public int size() {
        return this.featureVectors.size();
    }

    /**
     * Returns the stored regularization factor of a feature.
     *
     * @param index the feature index
     * @return the factor stored at {@code index}
     * @throws NullPointerException if this object was created without regularization factors
     */
    @Override
    public double getRegularizationFactor(int index) {
        return this.regularizationFactors[index];
    }

    /**
     * Returns the feature vector of a labeled instance, extracting it when it
     * is not cached. Features absent from the index are silently skipped.
     *
     * @param labeledInstance the instance to look up
     * @param cache whether a newly extracted vector should be stored in the cache
     * @return the cached or freshly extracted vector
     */
    @Override
    public SparseVector getFeatureVector(LabeledInstance<I, L> labeledInstance, boolean cache) {
        SparseVector vector = this.featureVectors.get(labeledInstance);
        if (vector == null) {
            vector = FeatureVectors.extractFeatureVector(labeledInstance, this.indexer, this.extractor, false);
            if (cache) {
                this.featureVectors.put(labeledInstance, vector);
            }
        }
        return vector;
    }

    /**
     * @return the string form of the vector cache
     */
    @Override
    public String toString() {
        return this.featureVectors.toString();
    }

    /**
     * Creates a dense weight array initialised from a counter.
     *
     * <p>Each key of {@code init} that is present in the index contributes its
     * count at the corresponding position; other keys are ignored and all
     * remaining positions stay at {@code 0.0}.
     *
     * @param init initial counts keyed by feature
     * @return a new array of length {@link #dim()}
     */
    @Override
    public double[] createInitialWeight(Counter init) {
        double[] result = new double[this.dim()];
        for (Object feature : init.keySet()) {
            if (!this.indexer.contains(feature)) continue;
            int index = this.indexer.indexOf(feature);
            result[index] = init.getCount(feature);
        }
        return result;
    }

    /**
     * Creates a weight array from an empty counter.
     *
     * @return a new array of length {@link #dim()}
     */
    @Override
    public double[] createInitialWeight() {
        return this.createInitialWeight(new Counter());
    }

    /**
     * Converts a dense weight array into a counter keyed by feature.
     *
     * @param weights weights indexed like the features; expected to have length {@link #dim()}
     * @return a counter holding {@code weights[i]} for the feature at index {@code i}
     */
    @Override
    public Counter namedCounter(double[] weights) {
        assert (weights.length == this.dim());
        Counter<F> named = new Counter<F>();
        for (int i = 0; i < this.dim(); ++i) {
            F name = this.indexer.getObject(i);
            double count = weights[i];
            named.setCount(name, count);
        }
        return named;
    }

    /**
     * Extracts the vector of a labeled instance and stores it in the cache,
     * failing if the instance produces a feature that is not indexed.
     */
    private void addFeatureVector(LabeledInstance<I, L> labeledInstance) {
        SparseVector vector = FeatureVectors.extractFeatureVector(labeledInstance, this.indexer, this.extractor, true);
        this.featureVectors.put(labeledInstance, vector);
    }

    /**
     * @param extractor the feature extractor
     * @param indexer the feature index; its current size becomes {@link #dim()}
     * @param regularizationFactors per-feature factors, or {@code null} when unavailable
     */
    private FeatureVectors(FeatureExtractor<LabeledInstance<I, L>, F> extractor, Indexer<F> indexer, double[] regularizationFactors) {
        this.dim = indexer.size();
        this.indexer = indexer;
        this.extractor = extractor;
        this.regularizationFactors = regularizationFactors;
    }

    /**
     * Indexes every feature produced for each training input paired with each
     * label returned by {@code baseMeasures.support(input)}.
     */
    private static <I, L, F> Indexer<F> computeIndexes(BaseMeasures<I, L> baseMeasures, Set<LabeledInstance<I, L>> training, FeatureExtractor<LabeledInstance<I, L>, F> extractor) {
        Indexer<F> indexer = new Indexer<F>();
        for (LabeledInstance<I, L> labeledInstance : training) {
            // NOTE(review): raw types are kept as emitted by the decompiler because the
            // element type returned by support() is not visible from this file.
            for (Object label : baseMeasures.support(labeledInstance.getInput())) {
                LabeledInstance currentLabeledInstance = new LabeledInstance(label, labeledInstance.getInput());
                Counter<F> features = extractor.extractFeatures(currentLabeledInstance);
                for (F feature : features) {
                    // Called for its side effect; presumably registers unseen features.
                    indexer.getIndex(feature);
                }
            }
        }
        return indexer;
    }

    /**
     * Builds an index by passing each feature of the set to {@code getIndex}.
     */
    private static <I, L, F> Indexer<F> computeIndexesFromSet(Set<F> features) {
        Indexer<F> indexer = new Indexer<F>();
        for (F feature : features) {
            indexer.getIndex(feature);
        }
        return indexer;
    }

    /**
     * Queries the extractor for the regularization factor of every indexed
     * feature and returns them in index order.
     */
    private static <I, L, F> double[] computeRegFactors(Indexer<F> indexer, FeatureExtractor<LabeledInstance<I, L>, F> extractor) {
        int size = indexer.size();
        double[] result = new double[size];
        for (int index = 0; index < size; ++index) {
            F feature = indexer.getObject(index);
            result[index] = extractor.regularizationFactor(feature);
        }
        return result;
    }

    /**
     * Extracts the features of an instance and packs the indexed ones into a
     * sparse vector whose dimension is the indexer size.
     *
     * @param ensureFeaturesPresent when {@code true}, a feature missing from the
     *        index raises a {@link RuntimeException}; when {@code false} it is skipped
     */
    private static <I, L, F> SparseVector extractFeatureVector(LabeledInstance<I, L> instance, Indexer<F> indexer, FeatureExtractor<LabeledInstance<I, L>, F> extractor, boolean ensureFeaturesPresent) {
        Counter<F> features = extractor.extractFeatures(instance);
        ArrayList<Integer> indices = new ArrayList<Integer>();
        ArrayList<Double> values = new ArrayList<Double>();
        for (F feature : features) {
            // Look the index up once; -1 signals a feature unknown to the indexer.
            int index = indexer.indexOf(feature);
            if (index == -1) {
                if (!ensureFeaturesPresent) continue;
                throw new RuntimeException("Unknown feature: " + feature);
            }
            indices.add(index);
            values.add(features.getCount(feature));
        }
        return new SparseVector(indices, values, indexer.size());
    }
}

