/*
 * Decompiled with CFR 0.152.
 */
package dr.math.distributions;

import dr.math.GammaFunction;
import dr.math.distributions.KernelDensityEstimatorDistribution;
import dr.stats.DiscreteStatistics;

/**
 * Kernel density estimator for a sample confined to a finite interval
 * {@code [lowerBound, upperBound]}, built from beta kernels.
 * <p>
 * {@code processBounds} linearly rescales the stored sample onto {@code [0, 1]}.
 * {@code evaluateKernel} then smooths with a beta kernel whose shape parameters depend on
 * the evaluation point. Within {@code 2 * bandWidth} of either edge, one shape parameter
 * is replaced by a boundary-corrected value ({@code getRho}). The result is divided by the
 * interval width so that the density is expressed on the original scale.
 * NOTE(review): the boundary correction appears to follow the "modified" beta kernel of
 * Chen (1999) — confirm.
 */
public class BetaKDEDistribution
extends KernelDensityEstimatorDistribution {
    /** Width of the support, {@code upperBound - lowerBound}; assigned in {@code processBounds}. */
    private double range;

    /**
     * Creates the estimator. All arguments are passed unchanged to the superclass constructor.
     * NOTE(review): presumably the superclass calls {@code processBounds(lowerBound, upperBound)}
     * and then {@code setBandWidth(bandWidth)} — confirm the call order there.
     *
     * @param sample     observations; must all lie within {@code [lowerBound, upperBound]}
     * @param lowerBound lower end of the support (must be non-null)
     * @param upperBound upper end of the support (must be non-null and exceed {@code lowerBound})
     * @param bandWidth  kernel bandwidth on the rescaled {@code [0, 1]} scale, or {@code null}
     *                   to have one chosen automatically
     */
    public BetaKDEDistribution(Double[] sample, Double lowerBound, Double upperBound, Double bandWidth) {
        super(sample, lowerBound, upperBound, bandWidth);
    }

    /**
     * Validates the bounds, stores them, and replaces the stored sample with a copy rescaled
     * onto {@code [0, 1]} via {@code (x - lowerBound) / (upperBound - lowerBound)}.
     *
     * @param lowerBound lower end of the support
     * @param upperBound upper end of the support
     * @throws IllegalArgumentException if either bound is {@code null}, if
     *         {@code upperBound <= lowerBound}, or if the sample extends beyond the bounds
     */
    @Override
    protected void processBounds(Double lowerBound, Double upperBound) {
        if (lowerBound == null || upperBound == null || upperBound - lowerBound <= 0.0) {
            throw new IllegalArgumentException("BetaKDEDistribution must be bounded");
        }
        // Computed once and reused for both the check and the error message.
        double sampleMin = DiscreteStatistics.min(this.sample);
        double sampleMax = DiscreteStatistics.max(this.sample);
        if (lowerBound > sampleMin || upperBound < sampleMax) {
            throw new IllegalArgumentException("Sample range outside bounds: " + sampleMin + " -> " + sampleMax);
        }
        this.lowerBound = lowerBound;
        this.upperBound = upperBound;
        this.range = upperBound - lowerBound;
        double[] oldSample = this.sample;
        this.sample = new double[oldSample.length];
        for (int i = 0; i < this.N; ++i) {
            this.sample[i] = (oldSample[i] - this.lowerBound) / this.range;
        }
    }

    /**
     * Sets the kernel bandwidth, which is interpreted on the rescaled {@code [0, 1]} scale.
     * If {@code bandWidth} is {@code null}, uses the rule of thumb
     * {@code stdev(sample) * N^(-2/5)}. This is computed from the stored (rescaled) sample,
     * provided {@code processBounds} has already run.
     *
     * @param bandWidth explicit bandwidth, or {@code null} for the automatic choice
     */
    @Override
    protected void setBandWidth(Double bandWidth) {
        if (bandWidth == null) {
            double sigma = DiscreteStatistics.stdev(this.sample);
            this.bandWidth = sigma * Math.pow(this.N, -0.4);
        } else {
            this.bandWidth = bandWidth;
        }
    }

    /**
     * Evaluates the density estimate at {@code x} on the original (unscaled) axis.
     *
     * @param x point at which to evaluate the density
     * @return the estimated density at {@code x}, or {@code 0} when {@code x} lies outside
     *         {@code [lowerBound, upperBound]}
     */
    @Override
    protected double evaluateKernel(double x) {
        // The estimator has no mass outside its support. Without this guard the edge
        // correction below would be evaluated at x' < 0 or x' > 1. getRho would receive a
        // negative argument, and the shape parameters could drop to <= 0. That yields
        // spurious positive densities or NaN.
        if (x < this.lowerBound || x > this.upperBound) {
            return 0.0;
        }
        double xPrime = (x - this.lowerBound) / this.range;
        // Interior shape parameters: alpha = x'/b, beta = (1 - x')/b (stored minus one).
        double alphaMinus1 = xPrime / this.bandWidth - 1.0;
        double betaMinus1 = (1.0 - xPrime) / this.bandWidth - 1.0;
        // Near an edge, swap in the boundary-corrected value. rho(0) = 1 and rho(2b) = 2 = 2b/b,
        // so the corrected parameter meets the interior value at the region boundary.
        // NOTE(review): when bandWidth > 0.25 the two edge regions overlap; only the
        // lower-edge correction is applied in that case.
        if (xPrime < 2.0 * this.bandWidth) {
            alphaMinus1 = this.getRho(xPrime, this.bandWidth) - 1.0;
        } else if (xPrime > 1.0 - 2.0 * this.bandWidth) {
            betaMinus1 = this.getRho(1.0 - xPrime, this.bandWidth) - 1.0;
        }
        // log of the beta normalising constant: lnGamma(a + b) - lnGamma(a) - lnGamma(b).
        double logK = GammaFunction.lnGamma(alphaMinus1 + betaMinus1 + 2.0)
                - GammaFunction.lnGamma(alphaMinus1 + 1.0)
                - GammaFunction.lnGamma(betaMinus1 + 1.0);
        double pdf = 0.0;
        for (int i = 0; i < this.N; ++i) {
            pdf += Math.pow(this.sample[i], alphaMinus1) * Math.pow(1.0 - this.sample[i], betaMinus1);
        }
        // Average over the sample, then divide by range (the Jacobian of the rescaling)
        // so the result is a density on the original scale.
        return pdf * Math.exp(logK) / (double) this.N / this.range;
    }

    /**
     * Boundary-corrected shape parameter
     * {@code rho(x) = 2b^2 + 2.5 - sqrt(4b^4 + 6b^2 + 2.25 - x^2 - x/b)}, where {@code b} is
     * the bandwidth. Intended for {@code 0 <= x < 2b}; there the square-root argument is
     * at least {@code (2b^2 + 0.5)^2} and hence positive.
     *
     * @param x         distance (on the rescaled axis) from the nearer edge
     * @param bandWidth kernel bandwidth {@code b}
     * @return {@code rho(x)}
     */
    private double getRho(double x, double bandWidth) {
        return 2.0 * bandWidth * bandWidth + 2.5 - Math.sqrt(4.0 * bandWidth * bandWidth * bandWidth * bandWidth + 6.0 * bandWidth * bandWidth + 2.25 - x * x - x / bandWidth);
    }

    /**
     * Returns the arithmetic mean of the stored sample. After {@code processBounds} has run,
     * the stored sample is on the rescaled {@code [0, 1]} axis. In that case this is NOT the
     * mean of the original observations; multiply by the range and add the lower bound to
     * recover that.
     *
     * @return mean of the stored (rescaled) sample
     */
    public double sampleMean() {
        return DiscreteStatistics.mean(this.sample);
    }
}

