/*
 * Decompiled with CFR 0.152.
 */
package slice;

import fig.basic.NumUtils;
import fig.prob.SampleUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import nuts.util.MathUtils;
import slice.SampleProcessor;
import slice.processor.TrivialProcessor;
import slice.stickrep.DPMSample;
import slice.stickrep.ImmutableSticks;
import slice.stickrep.Location;
import slice.stickrep.LocationDistribution;
import slice.stickrep.Sticks;
import slice.util.TruncUtils;
import slice.util.TruncatedBeta;

/**
 * Random-scan samplers for a Dirichlet process mixture in its stick-breaking representation.
 *
 * <p>Two variants are built by the static factories:
 * <ul>
 *   <li>{@link #createSliceSampler}: keeps one auxiliary slice variable per observation and
 *       lets {@link DPMSample#ensureEnoughSticks} grow the set of sticks on demand;</li>
 *   <li>{@link #createTruncSampler}: resamples over the sticks that already exist.</li>
 * </ul>
 * Each call to {@link #sample(Random)} picks one node sampler uniformly at random and
 * resamples one uniformly chosen variable of that kind. Every freshly drawn value is
 * forwarded to the current {@link SampleProcessor}.
 *
 * @param <L> location (per-cluster parameter) type
 * @param <D> datum type
 */
public class Sampler<L extends Location<D>, D> {
    /** Chain state: indicators, locations, sticks, alpha0 and the observed data. */
    private final DPMSample<L, D> sample = new DPMSample<L, D>();
    /** Base distribution; used for posterior draws of a cluster's location. */
    public LocationDistribution<L, D> prior;
    /** Auxiliary slice value per observation (index-aligned with the data). */
    private final List<Double> auxiliary = new ArrayList<Double>();
    /** Node samplers {@link #sample(Random)} chooses from; installed by the factories. */
    private List<NodeSampler> nodeSamplers;
    /** Receives every newly sampled value. */
    private SampleProcessor<L, D> processor = new TrivialProcessor();

    /**
     * Builds a slice sampler that resamples auxiliary variables, indicators, locations and
     * stick lengths (in that list order, which {@link #sample(Random)} indexes into).
     *
     * @param data    observations
     * @param prior   base distribution over locations
     * @param alpha_0 DP concentration parameter
     * @return a sampler whose state still has to be initialised, e.g. via {@link #basicInit}
     */
    public static <L extends Location<D>, D> Sampler<L, D> createSliceSampler(List<D> data, LocationDistribution<L, D> prior, double alpha_0) {
        Sampler<L, D> sampler = new Sampler<L, D>(prior, alpha_0, data);
        List<NodeSampler> nodeSamplers = new ArrayList<NodeSampler>();
        nodeSamplers.add(sampler.new AuxSampler());
        nodeSamplers.add(sampler.new IndicatorSliceSampler());
        nodeSamplers.add(sampler.new LocationSampler());
        nodeSamplers.add(sampler.new StickLengthSliceSampler());
        sampler.nodeSamplers = nodeSamplers;
        return sampler;
    }

    /**
     * Builds a sampler over the current (truncated) set of sticks that resamples indicators,
     * locations and stick lengths.
     *
     * @param data    observations
     * @param prior   base distribution over locations
     * @param alpha_0 DP concentration parameter
     * @return a sampler whose state still has to be initialised, e.g. via {@link #basicInit}
     */
    public static <L extends Location<D>, D> Sampler<L, D> createTruncSampler(List<D> data, LocationDistribution<L, D> prior, double alpha_0) {
        Sampler<L, D> sampler = new Sampler<L, D>(prior, alpha_0, data);
        List<NodeSampler> nodeSamplers = new ArrayList<NodeSampler>();
        nodeSamplers.add(sampler.new IndicatorTruncSampler());
        nodeSamplers.add(sampler.new LocationSampler());
        nodeSamplers.add(sampler.new StickLengthTruncSampler());
        sampler.nodeSamplers = nodeSamplers;
        return sampler;
    }

    /**
     * Installs {@code processor} and hands it read-only views of the auxiliary variables and
     * of the chain state (unmodifiable list views; sticks wrapped in {@link ImmutableSticks}).
     *
     * @param processor receiver of subsequent samples; must not be null
     */
    public void setProcessor(SampleProcessor<L, D> processor) {
        assert (processor != null);
        this.processor = processor;
        processor.setAuxiliarySlice(Collections.unmodifiableList(this.auxiliary));
        DPMSample<L, D> sampleCopy = new DPMSample<L, D>();
        sampleCopy.setIndicators(Collections.unmodifiableList(this.sample.getIndicators()));
        sampleCopy.setLocationParams(Collections.unmodifiableList(this.sample.getLocationParams()));
        sampleCopy.setSticks(new ImmutableSticks(this.sample.getSticks()));
        sampleCopy.setAlpha0(this.sample.getAlpha0());
        sampleCopy.setData(Collections.unmodifiableList(this.sample.getData()));
        processor.setSample(sampleCopy);
    }

    /**
     * Initialises the chain state in four steps:
     * <ol>
     *   <li>consecutive blocks of {@code n} observations share a cluster;</li>
     *   <li>{@code nObs() / n + 1} sticks of equal weight are created;</li>
     *   <li>each auxiliary variable is drawn given its cluster's weight;</li>
     *   <li>each location is drawn from the posterior given its initial members.</li>
     * </ol>
     *
     * @param n    observations per initial cluster; must be positive
     * @param rand source of randomness
     * @throws IllegalArgumentException if {@code n <= 0} (previously an opaque
     *         {@code ArithmeticException} or a corrupt state)
     */
    public void basicInit(int n, Random rand) {
        if (n <= 0) {
            throw new IllegalArgumentException("n must be positive, got " + n);
        }
        int numberOfSticksReqd = this.nObs() / n + 1;
        double stickLength = this.laplace(numberOfSticksReqd, 0.1 * (double) n);
        this.nPerClusterIndicatorInit(n);
        this.uniformSticksInit(numberOfSticksReqd, stickLength);
        this.uniformAuxInit(rand, numberOfSticksReqd, stickLength);
        this.sampledLocationInit(rand);
    }

    /** Appends indicator {@code i / n} for every observation {@code i}; requires {@code n > 0}. */
    private void nPerClusterIndicatorInit(int n) {
        for (int i = 0; i < this.nObs(); ++i) {
            this.sample.getIndicators().add(i / n);
        }
    }

    /** Reserves one location slot per stick, then fills each slot with a posterior draw. */
    private void sampledLocationInit(Random rand) {
        for (int i = 0; i < this.nSticks(); ++i) {
            this.sample.getLocationParams().add(null);
        }
        // The sampler is stateless, so a single instance serves every stick.
        LocationSampler locSampler = new LocationSampler();
        for (int i = 0; i < this.nSticks(); ++i) {
            locSampler.sample(i, rand);
        }
    }

    /** Returns {@code discount / (discount + initNumberOfSticks)}. */
    private double laplace(int initNumberOfSticks, double discount) {
        return discount / (discount + (double) initNumberOfSticks);
    }

    /**
     * Appends {@code initNumberOfSticks} sticks, each with weight {@code stickLengths}.
     * The stored value for each stick is that weight divided by the still-unbroken length.
     *
     * <p>NOTE(review): if {@code initNumberOfSticks * stickLengths > 1} the remaining length
     * goes negative and the stored proportions leave (0, 1). The sizing in {@link #basicInit}
     * allows this: 100 observations with n = 20 give 6 sticks of weight 0.25. Confirm the
     * intended sizing.
     */
    private void uniformSticksInit(int initNumberOfSticks, double stickLengths) {
        double currentLength = 1.0;
        for (int i = 0; i < initNumberOfSticks; ++i) {
            double v = stickLengths / currentLength;
            currentLength -= stickLengths;
            this.sample.getSticks().add(v);
        }
    }

    /**
     * Draws one auxiliary variable per observation, uniform on [0, weight of its cluster).
     * {@code initNumberOfSticks} and {@code stickLengths} are currently unused; they are kept
     * for source compatibility.
     */
    public void uniformAuxInit(Random rand, int initNumberOfSticks, double stickLengths) {
        AuxSampler auxSampler = new AuxSampler();
        for (int i = 0; i < this.nObs(); ++i) {
            this.auxiliary.add(null);
            auxSampler.sample(i, rand);
        }
    }

    private Sampler(LocationDistribution<L, D> prior, double alpha0, List<D> data) {
        this.prior = prior;
        this.sample.setAlpha0(alpha0);
        this.sample.setData(data);
    }

    /**
     * Performs one random-scan step: picks a node sampler uniformly, then one of its
     * variables uniformly, and resamples that variable.
     *
     * @throws IllegalArgumentException (from {@link Random#nextInt(int)}) if the chosen
     *         sampler currently has no variables
     */
    public void sample(Random rand) {
        NodeSampler nodeSampler = this.nodeSamplers.get(rand.nextInt(this.nodeSamplers.size()));
        int index = rand.nextInt(nodeSampler.maxIndex());
        nodeSampler.sample(index, rand);
    }

    /** Current cluster indicator of observation {@code dataIndex}. */
    private int cluster(int dataIndex) {
        return this.sample.getIndicators().get(dataIndex);
    }

    /** Number of observations. */
    public int nObs() {
        return this.sample.getData().size();
    }

    /** Number of sticks currently represented. */
    public int nSticks() {
        return this.sample.getSticks().nSticks();
    }

    /**
     * Resamples one indicator over all existing sticks. The probability of stick {@code k}
     * is proportional to {@code w_k * likelihood(datum | location_k)}.
     */
    private class IndicatorTruncSampler
    implements NodeSampler {
        private IndicatorTruncSampler() {
        }

        @Override
        public int maxIndex() {
            return Sampler.this.nObs();
        }

        @Override
        public void sample(int dataIndex, Random rand) {
            Sticks sticks = Sampler.this.sample.getSticks();
            double[] probs = new double[sticks.nSticks()];
            for (int clusterIndex = 0; clusterIndex < probs.length; ++clusterIndex) {
                // Include log w_k: without the stick weight the indicators ignore the sticks
                // entirely, and the posterior is not the blocked-Gibbs conditional.
                probs[clusterIndex] = Math.log(sticks.retreiveW(clusterIndex))
                        + Sampler.this.sample.getLocationParams().get(clusterIndex)
                                .unnormLoglikelihood(Sampler.this.sample.getData().get(dataIndex));
            }
            // Presumably converts the log-weights in place into normalised probabilities.
            NumUtils.expNormalize(probs);
            int multSample = SampleUtils.sampleMultinomial(rand, probs);
            Sampler.this.sample.getIndicators().set(dataIndex, multSample);
            Sampler.this.processor.processIndicatorSample(dataIndex, multSample);
        }
    }

    /**
     * Resamples one stick proportion from {@code Beta(1 + n_k, alpha0 + n_{>k})}. Here
     * {@code n_k} counts observations in cluster k and {@code n_{>k}} counts observations
     * in later clusters.
     */
    private class StickLengthTruncSampler
    implements NodeSampler {
        private StickLengthTruncSampler() {
        }

        @Override
        public int maxIndex() {
            return Sampler.this.sample.getSticks().nSticks();
        }

        @Override
        public void sample(int clusterIndex, Random rand) {
            double count = 0.0;
            double summedCount = 0.0;
            for (int dataIndex = 0; dataIndex < Sampler.this.nObs(); ++dataIndex) {
                int c = Sampler.this.cluster(dataIndex);
                if (c == clusterIndex) {
                    count += 1.0;
                } else if (c > clusterIndex) {
                    summedCount += 1.0;
                }
            }
            Sticks sticks = Sampler.this.sample.getSticks();
            sticks.updateV(clusterIndex, TruncUtils.sampleBeta(rand, 1.0 + count, Sampler.this.sample.getAlpha0() + summedCount));
            Sampler.this.processor.processStickSample(clusterIndex, sticks.retreiveW(clusterIndex));
        }
    }

    /**
     * Slice-sampler indicator update. Only sticks returned by {@code largeSticks(u)} are
     * candidates, where {@code u} is the observation's auxiliary value. Among those, the
     * choice is proportional to the likelihood.
     */
    private class IndicatorSliceSampler
    implements NodeSampler {
        private IndicatorSliceSampler() {
        }

        @Override
        public int maxIndex() {
            return Sampler.this.nObs();
        }

        @Override
        public void sample(int dataIndex, Random rand) {
            double u = Sampler.this.auxiliary.get(dataIndex);
            // Presumably instantiates further sticks (with prior locations) until no stick
            // heavier than u is missing.
            Sampler.this.sample.ensureEnoughSticks(u, Sampler.this.prior, rand);
            List<Integer> clusterIndices = Sampler.this.sample.getSticks().largeSticks(u);
            double[] probs = new double[clusterIndices.size()];
            for (int i = 0; i < probs.length; ++i) {
                int clusterIndex = clusterIndices.get(i);
                probs[i] = Sampler.this.sample.getLocationParams().get(clusterIndex)
                        .unnormLoglikelihood(Sampler.this.sample.getData().get(dataIndex));
            }
            NumUtils.expNormalize(probs);
            int multSample = SampleUtils.sampleMultinomial(rand, probs);
            int newIndicValue = clusterIndices.get(multSample);
            Sampler.this.sample.getIndicators().set(dataIndex, newIndicValue);
            Sampler.this.processor.processIndicatorSample(dataIndex, newIndicValue);
        }
    }

    /** Resamples one auxiliary variable uniformly on [0, weight of its current cluster). */
    private class AuxSampler
    implements NodeSampler {
        private AuxSampler() {
        }

        @Override
        public int maxIndex() {
            return Sampler.this.nObs();
        }

        @Override
        public void sample(int dataIndex, Random rand) {
            double weight = Sampler.this.sample.getSticks().retreiveW(Sampler.this.cluster(dataIndex));
            double u = rand.nextDouble() * weight;
            Sampler.this.auxiliary.set(dataIndex, u);
            Sampler.this.processor.processAuxiliarySliceSample(dataIndex, u);
        }
    }

    /** Resamples one cluster's location from the prior's posterior given its members. */
    private class LocationSampler
    implements NodeSampler {
        private LocationSampler() {
        }

        @Override
        public int maxIndex() {
            return Sampler.this.nSticks();
        }

        @Override
        public void sample(int clusterIndex, Random rand) {
            List<D> correspData = new ArrayList<D>();
            for (int i = 0; i < Sampler.this.nObs(); ++i) {
                if (Sampler.this.cluster(i) == clusterIndex) {
                    correspData.add(Sampler.this.sample.getData().get(i));
                }
            }
            L newLoc = Sampler.this.prior.samplePosterior(rand, correspData);
            Sampler.this.sample.getLocationParams().set(clusterIndex, newLoc);
            Sampler.this.processor.processLocationSample(clusterIndex, newLoc);
        }
    }

    /**
     * Slice-sampler stick update. Draws from a {@link TruncatedBeta} built with alpha0 and
     * the bounds (min, max) that keep every auxiliary variable below its cluster's weight.
     */
    private class StickLengthSliceSampler
    implements NodeSampler {
        private StickLengthSliceSampler() {
        }

        @Override
        public int maxIndex() {
            return Sampler.this.nSticks();
        }

        @Override
        public void sample(int clusterIndex, Random rand) {
            double min = this.minTruncation(clusterIndex);
            double max = this.maxTruncation(clusterIndex);
            TruncatedBeta betaSampler = new TruncatedBeta(Sampler.this.sample.getAlpha0(), min, max);
            Sampler.this.sample.getSticks().updateV(clusterIndex, betaSampler.sample(rand));
            Sampler.this.processor.processStickSample(clusterIndex, Sampler.this.sample.getSticks().retreiveW(clusterIndex));
        }

        /**
         * Lower bound on V_k: the largest auxiliary value in cluster k, scaled by V_k / W_k.
         * Returns 0 when the cluster is empty.
         */
        private double minTruncation(int clusterIndex) {
            Sticks sticks = Sampler.this.sample.getSticks();
            double MF = sticks.retreiveV(clusterIndex) / sticks.retreiveW(clusterIndex);
            double cMax = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < Sampler.this.nObs(); ++i) {
                if (Sampler.this.cluster(i) != clusterIndex) {
                    continue;
                }
                double u = Sampler.this.auxiliary.get(i);
                if (u > cMax) {
                    cMax = u;
                }
            }
            if (cMax == Double.NEGATIVE_INFINITY) {
                return 0.0;
            }
            double result = cMax * MF;
            assert (MathUtils.close(result, this.inefMinTruncation(clusterIndex)));
            return result;
        }

        /** Direct (slower) form of {@link #minTruncation}, used only as an assertion check. */
        private double inefMinTruncation(int clusterIndex) {
            double cMax = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < Sampler.this.nObs(); ++i) {
                if (Sampler.this.cluster(i) != clusterIndex) continue;
                double current = Sampler.this.auxiliary.get(i);
                for (int l = 0; l < clusterIndex; ++l) {
                    current /= 1.0 - Sampler.this.sample.getSticks().retreiveV(l);
                }
                if (!(current > cMax)) continue;
                cMax = current;
            }
            if (cMax == Double.NEGATIVE_INFINITY) {
                return 0.0;
            }
            return cMax;
        }

        /**
         * Upper bound on V_k. Each observation in a later cluster c requires
         * {@code u < W_c}, and W_c contains the factor (1 - V_k). Returns 1 when no
         * observation lies beyond cluster k.
         */
        private double maxTruncation(int clusterIndex) {
            Sticks sticks = Sampler.this.sample.getSticks();
            // Loop-invariant: 1 - V_k.
            double oneMinusV = 1.0 - sticks.retreiveV(clusterIndex);
            double cMax = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < Sampler.this.nObs(); ++i) {
                int c = Sampler.this.cluster(i);
                if (c <= clusterIndex) {
                    continue;
                }
                double P = oneMinusV / sticks.retreiveW(c);
                double cur = Sampler.this.auxiliary.get(i) * P;
                if (cur > cMax) {
                    cMax = cur;
                }
            }
            if (cMax == Double.NEGATIVE_INFINITY) {
                return 1.0;
            }
            double result = 1.0 - cMax;
            assert (MathUtils.close(result, this.inefMaxTruncation(clusterIndex)));
            return result;
        }

        /** Direct (slower) form of {@link #maxTruncation}, used only as an assertion check. */
        private double inefMaxTruncation(final int clusterIndex) {
            final Sticks sticks = Sampler.this.sample.getSticks();
            return 1.0 - MathUtils.max(Sampler.this.nObs(), new MathUtils.I2D(){

                @Override
                public boolean inDom(int i) {
                    return Sampler.this.cluster(i) > clusterIndex;
                }

                @Override
                public double f(int i) {
                    return Sampler.this.auxiliary.get(i) / sticks.retreiveV(Sampler.this.cluster(i)) / MathUtils.prod(Sampler.this.cluster(i), new MathUtils.I2D(){

                        @Override
                        public final boolean inDom(int l) {
                            return l != clusterIndex;
                        }

                        @Override
                        public final double f(int l) {
                            return 1.0 - sticks.retreiveV(l);
                        }
                    });
                }
            });
        }
    }

    /** One family of variables (indexed 0..maxIndex()-1) that can be resampled individually. */
    private interface NodeSampler {
        /** Resamples the variable at {@code index} using {@code rand}. */
        public void sample(int index, Random rand);

        /** Exclusive upper bound on the indices accepted by {@link #sample}. */
        public int maxIndex();
    }
}

