/*
 * Decompiled with CFR 0.152.
 */
package cognates;

import cognates.CognateTree;
import cognates.DirichletProcessKernel;
import cognates.DirichletProcessState;
import fig.basic.LogInfo;
import fig.basic.Pair;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import nuts.util.CollUtils;

/**
 * Particle-style driver that advances {@link DirichletProcessKernel}s and keeps
 * at most {@code N} weighted states per iteration by multinomial resampling.
 *
 * <p>Each state is carried as a {@code Pair<DirichletProcessState, Double>} whose
 * second component is treated as an unnormalized weight.
 */
public class DirichletProcessSMC {
    /**
     * How far the accumulated probability mass may be from 1.0 before
     * {@link #sampleNormalized} treats the distribution as invalid rather than
     * as a victim of floating-point rounding.
     */
    private static final double NORMALIZATION_TOLERANCE = 1.0E-6;

    /** Kernels registered through {@link #initialize}, in registration order. */
    private final List<DirichletProcessKernel> kernels;

    /** Maximum number of states kept after each iteration. */
    private final int N;

    /**
     * Creates a sampler with no registered kernels.
     *
     * @param N maximum number of states retained per iteration; larger candidate
     *          sets are resampled down to this size
     */
    public DirichletProcessSMC(int N) {
        this.N = N;
        this.kernels = CollUtils.list();
    }

    /**
     * Builds a kernel for {@code gloss}/{@code tree}, registers it, and runs it
     * forward for as many iterations as the kernel reports are left from its
     * initial state. In every iteration all successors of all current states are
     * enumerated; if more than {@code N} candidates result they are resampled
     * down to {@code N}, otherwise all candidates are kept.
     *
     * <p>NOTE(review): the state set produced by the last iteration is not stored
     * or returned, so apart from registering the kernel and logging, the only
     * lasting effects are whatever the kernel calls themselves cause. Confirm
     * whether the result was meant to be retained.
     *
     * @param rand  source of randomness for successor enumeration and resampling
     * @param gloss passed through to the {@link DirichletProcessKernel} constructor
     * @param tree  passed through to the {@link DirichletProcessKernel} constructor
     */
    public void initialize(Random rand, String gloss, CognateTree tree) {
        DirichletProcessKernel kernel = new DirichletProcessKernel(gloss, tree);
        this.kernels.add(kernel);
        DirichletProcessState initial = kernel.getInitial();
        int R = kernel.nIterationsLeft(initial);
        // Properly typed particle list; the decompiled List<Object> could neither be
        // iterated as Pair nor be assigned the resampled/expanded lists below.
        List<Pair<DirichletProcessState, Double>> states = CollUtils.list();
        states.add(new Pair<DirichletProcessState, Double>(initial, 1.0));
        for (int r = 0; r < R; ++r) {
            // Running total of candidate weights, used to normalize when resampling.
            double norm = 0.0;
            List<Pair<DirichletProcessState, Double>> tempStates = CollUtils.list();
            LogInfo.logs("------------------iter=" + r + "------------------");
            for (Pair<DirichletProcessState, Double> state : states) {
                List<Pair<DirichletProcessState, Double>> nextStates = kernel.enumerateNext(rand, state.getFirst());
                for (Pair<DirichletProcessState, Double> nextState : nextStates) {
                    tempStates.add(nextState);
                    norm += nextState.getSecond().doubleValue();
                }
            }
            // Only resample when the candidate set exceeds the particle budget.
            states = tempStates.size() > this.N ? this.resample(rand, tempStates, norm, this.N) : tempStates;
        }
    }

    /**
     * Visits every registered kernel and loops over its iteration count without
     * doing any work per iteration.
     *
     * <p>NOTE(review): the inner loop body is empty in the available source; this
     * looks like an unfinished stub.
     */
    public void runSMC() {
        for (DirichletProcessKernel kernel : this.kernels) {
            int R = kernel.nIterationsLeft(kernel.getInitial());
            for (int r = 0; r < R; ++r) {
            }
        }
    }

    /**
     * Draws {@code N} states with replacement, each chosen with probability
     * {@code weight / norm}. Selected pairs are reused as-is, so they keep their
     * original weights and the same pair object may appear several times.
     *
     * @param rand   source of randomness
     * @param states candidate states with unnormalized weights
     * @param norm   normalizer for the weights (expected to be their sum)
     * @param N      number of draws
     * @return a new list containing the {@code N} drawn states
     * @throws RuntimeException if the normalized weights do not form a usable
     *                          distribution (see {@link #sampleNormalized})
     */
    private List<Pair<DirichletProcessState, Double>> resample(Random rand, List<Pair<DirichletProcessState, Double>> states, double norm, int N) {
        List<Double> weights = CollUtils.list();
        for (Pair<DirichletProcessState, Double> state : states) {
            weights.add(state.getSecond() / norm);
        }
        List<Pair<DirichletProcessState, Double>> resampledStates = CollUtils.list();
        for (int i = 0; i < N; ++i) {
            int index = DirichletProcessSMC.sampleNormalized(rand, weights);
            resampledStates.add(states.get(index));
        }
        return resampledStates;
    }

    /**
     * Convenience overload that first sums the weights of {@code states} and then
     * delegates to {@link #resample(Random, List, double, int)}.
     *
     * @param rand   source of randomness
     * @param states candidate states with unnormalized weights
     * @param N      number of draws
     * @return a new list containing the {@code N} drawn states
     */
    private List<Pair<DirichletProcessState, Double>> resample(Random rand, List<Pair<DirichletProcessState, Double>> states, int N) {
        double norm = 0.0;
        for (Pair<DirichletProcessState, Double> state : states) {
            norm += state.getSecond().doubleValue();
        }
        return this.resample(rand, states, norm, N);
    }

    /**
     * Samples an index from a discrete distribution by walking the cumulative sum
     * until it exceeds a uniform draw.
     *
     * <p>If the walk finishes without a hit only because rounding left the total
     * marginally short of 1.0, the last index with positive probability is
     * returned instead of failing.
     *
     * @param rand            source of randomness
     * @param normalizedProbs probabilities expected to sum to 1
     * @return the sampled index into {@code normalizedProbs}
     * @throws RuntimeException if no index is selected and the total mass is not
     *                          within {@link #NORMALIZATION_TOLERANCE} of 1
     *                          (including empty or NaN input)
     */
    private static int sampleNormalized(Random rand, List<Double> normalizedProbs) {
        double v = rand.nextDouble();
        double sum = 0.0;
        int lastPositive = -1;
        for (int i = 0; i < normalizedProbs.size(); ++i) {
            double p = normalizedProbs.get(i).doubleValue();
            if (p > 0.0) {
                lastPositive = i;
            }
            sum += p;
            if (v < sum) {
                return i;
            }
        }
        // v fell into the sliver between the rounded total and 1.0: the distribution
        // is valid up to rounding error, so assign the draw to the final positive bin.
        if (lastPositive >= 0 && Math.abs(sum - 1.0) <= NORMALIZATION_TOLERANCE) {
            return lastPositive;
        }
        throw new RuntimeException("Bad probs: " + normalizedProbs);
    }
}

