/*
 * Decompiled with CFR 0.152.
 */
package slice.processor;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import nuts.util.Counter;
import nuts.util.PriorityQueue;
import slice.SampleProcessor;
import slice.likelihood.NormalLocation;
import slice.ndp.DistributionToken;
import slice.ndp.NDPLocation;
import slice.stickrep.DPMSample;
import slice.stickrep.Sticks;

public class ExpectedTopLocation
implements SampleProcessor<NDPLocation<NormalLocation, List<Double>>, DistributionToken<List<Double>>> {
    public static int numberOfClusters = 3;
    public static int numberOfSubclusters = 3;
    public static int burnin = 100;
    private double N = 0.0;
    private double bigSum = 0.0;
    private DPMSample<NDPLocation<NormalLocation, List<Double>>, DistributionToken<List<Double>>> sample;

    public double average() {
        if (this.N - (double)burnin <= 0.0) {
            return -1.0;
        }
        return this.bigSum / (this.N - (double)burnin);
    }

    @Override
    public void processAuxiliarySliceSample(int dataIndex, double newValue) {
    }

    @Override
    public void processIndicatorSample(int dataIndex, int newIndicatorValue) {
    }

    @Override
    public void processStickSample(int clusterIndex, double newWLength) {
        this.N += 1.0;
        if (this.N > (double)burnin) {
            this.bigSum += ExpectedTopLocation.f(this.sample);
        }
    }

    @Override
    public void setSample(DPMSample<NDPLocation<NormalLocation, List<Double>>, DistributionToken<List<Double>>> sampleRef) {
        this.sample = sampleRef;
    }

    public static double f(DPMSample<NDPLocation<NormalLocation, List<Double>>, DistributionToken<List<Double>>> sample) {
        double sum = 0.0;
        for (int clusterIndex : ExpectedTopLocation.topKClusters(sample.getSticks(), numberOfClusters)) {
            NDPLocation<NormalLocation, List<Double>> currentLoc = sample.getLocationParams().get(clusterIndex);
            for (int subClusterIndex : ExpectedTopLocation.topKClusters(currentLoc.getW(), numberOfSubclusters)) {
                NormalLocation currentSubLoc = currentLoc.getSubLocations().get(subClusterIndex);
                for (double currentCoord : currentSubLoc.getMean()) {
                    sum += currentCoord * currentCoord;
                }
            }
        }
        return sum;
    }

    public static Counter<Integer> asCounter(Sticks sticks) {
        Counter<Integer> result = new Counter<Integer>();
        for (int i = 0; i < sticks.nSticks(); ++i) {
            result.setCount(i, sticks.retreiveW(i));
        }
        return result;
    }

    public static Set<Integer> topKClusters(Sticks sticks, int k) {
        HashSet<Integer> result = new HashSet<Integer>();
        Counter<Integer> sticksCounter = ExpectedTopLocation.asCounter(sticks);
        assert (sticksCounter.size() >= k);
        PriorityQueue<Integer> iter = sticksCounter.asPriorityQueue();
        for (int i = 0; i < k; ++i) {
            result.add((Integer)iter.next());
        }
        return result;
    }

    @Override
    public void processLocationSample(int clusterIndex, NDPLocation<NormalLocation, List<Double>> newLocation) {
    }

    @Override
    public void setAuxiliarySlice(List<Double> auxRef) {
    }
}

