/*
 * Decompiled with CFR 0.152.
 */
package scratch;

import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Option;
import fig.exec.Execution;
import fig.prob.Gaussian;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import nuts.io.IO;
import nuts.tui.Table;

/**
 * Semi-supervised binary classifier with asymmetric label noise, trained with EM.
 *
 * <p>Each instance has a hidden class (0 = negative, 1 = positive), an observed label that is a
 * noisy function of the hidden class, and real-valued features modelled, per hidden class, as
 * independent Gaussians. The class prior and the hidden-to-observed label probabilities are fixed
 * before EM from {@link #falsePositive}, {@link #fractionUnlabeledThatShouldBePos} and the fraction
 * of training rows whose observed label is 1. EM then fits only the Gaussian means and variances.
 * Held-out evaluation predicts from the features alone and reports precision/recall against the
 * observed labels.
 */
public class SemiSupAsym implements Runnable {
    /** Number of hidden classes (index 0 = negative, index 1 = positive). */
    private static final int NUM_HIDDEN_CLASSES = 2;
    /** Per-feature variance used when EM is initialized with vague parameters. */
    private static final double VAGUE_SIGMA_SQR = 10.0;
    /**
     * Lower bound on every fitted variance. It prevents a component from collapsing onto a single
     * point, and prevents E[x^2] - E[x]^2 from going slightly negative through floating-point
     * cancellation. Either case would make the Gaussian density degenerate.
     */
    private static final double MIN_SIGMA_SQR = 1e-6;

    /** Number of feature columns; (re)set whenever a dataset header is parsed. */
    private int nFeatures;
    private double hiddenPos2obsPos = Double.NaN;
    private double hiddenNeg2obsPos = Double.NaN;
    private double piPos = Double.NaN;
    @Option(required=true)
    public String trainingDataFile;
    @Option(required=true)
    public String testDataFile;
    @Option(gloss="Integer seed for random numbers")
    public Random rand = new Random(1L);
    @Option
    public int emIterations = 10;
    @Option(gloss="False positive in the data (labeled positive but are actually negative) ~.01-.1")
    public double falsePositive = 0.05;
    @Option(gloss="Approximate fraction of unlabeled examplars that should actually be positive ~.1-.3")
    public double fractionUnlabeledThatShouldBePos = 0.2;

    /**
     * Rescales, in place, every feature of every instance to {@code (x - min) * scale} using the
     * per-feature statistics in {@code nm}.
     */
    private static void normalize(List<Instance> instances, NormalizationMap nm) {
        for (Instance instance : instances) {
            double[] features = instance.observedFeatures;
            for (int feat = 0; feat < nm.nFeatures; ++feat) {
                features[feat] = (features[feat] - nm.min[feat]) * nm.scale[feat];
            }
        }
    }

    /**
     * Creates starting parameters for EM.
     *
     * @param rand source of randomness for the means (and, when not vague, the variances)
     * @param isVague if true every variance is {@link #VAGUE_SIGMA_SQR}; otherwise each variance is
     *     drawn uniformly from [0.5, 0.7)
     * @return parameters whose means are drawn uniformly from [0, 1)
     */
    private Params initParams(Random rand, boolean isVague) {
        Params result = new Params();
        for (int hc = 0; hc < NUM_HIDDEN_CLASSES; ++hc) {
            for (int feat = 0; feat < this.nFeatures; ++feat) {
                result.mu[hc][feat] = rand.nextDouble();
                result.sigmaSqr[hc][feat] = isVague ? VAGUE_SIGMA_SQR : 0.5 + rand.nextDouble() / 5.0;
            }
        }
        return result;
    }

    /**
     * Returns the index of the first strictly largest element of {@code x}, or -1 when no element
     * is greater than negative infinity (e.g. empty or all-NaN input).
     */
    private int argmax(double[] x) {
        double max = Double.NEGATIVE_INFINITY;
        int argmax = -1;
        for (int i = 0; i < x.length; ++i) {
            if (max < x[i]) {
                argmax = i;
                max = x[i];
            }
        }
        return argmax;
    }

    /**
     * M-step: computes the maximum-likelihood Gaussian parameters from the expected sufficient
     * statistics. Variances are floored at {@link #MIN_SIGMA_SQR}.
     */
    private Params m(SuffStats suffStats) {
        Params result = new Params();
        for (int hc = 0; hc < NUM_HIDDEN_CLASSES; ++hc) {
            double count = suffStats.hiddenClassSoftCounts[hc];
            assert count > 0.0 : "hidden class " + hc + " received no posterior mass";
            for (int feat = 0; feat < this.nFeatures; ++feat) {
                double mean = suffStats.featureSoftCounts[hc][feat] / count;
                double variance = suffStats.featureSoftSqrCounts[hc][feat] / count - mean * mean;
                result.mu[hc][feat] = mean;
                result.sigmaSqr[hc][feat] = Math.max(variance, MIN_SIGMA_SQR);
            }
        }
        return result;
    }

    /**
     * E-step for one instance: adds its posterior-weighted counts, feature sums and squared
     * feature sums into {@code suffStats}. The posterior conditions on the observed label.
     */
    private void e(Params params, Instance instance, SuffStats suffStats) {
        double[] posterior = this.posteriorHiddenClassPrs(params, instance, true);
        double[] x = instance.observedFeatures;
        for (int hc = 0; hc < NUM_HIDDEN_CLASSES; ++hc) {
            double weight = posterior[hc];
            suffStats.hiddenClassSoftCounts[hc] += weight;
            for (int feat = 0; feat < this.nFeatures; ++feat) {
                suffStats.featureSoftCounts[hc][feat] += weight * x[feat];
                suffStats.featureSoftSqrCounts[hc][feat] += weight * x[feat] * x[feat];
            }
        }
    }

    /**
     * Fixes the class prior and label-noise probabilities from the data and the options. It then
     * runs {@link #emIterations} rounds of EM, starting from vague parameters.
     *
     * @param instances training instances (already normalized)
     * @return the Gaussian parameters after the last M-step
     */
    public Params em(List<? extends Instance> instances) {
        TransitionParametrizationConverter converter = new TransitionParametrizationConverter();
        converter.compute(instances);
        LogInfo.logs(converter);
        this.piPos = converter.getPiPos();
        this.hiddenNeg2obsPos = converter.getHiddenNeg2obsPos();
        this.hiddenPos2obsPos = converter.getHiddenPos2obsPos();
        Params params = this.initParams(this.rand, true);
        LogInfo.track((Object)"Computing MLE with EM", true);
        for (int i = 0; i < this.emIterations; ++i) {
            SuffStats suffStats = new SuffStats();
            for (Instance instance : instances) {
                this.e(params, instance, suffStats);
            }
            params = this.m(suffStats);
            LogInfo.logs("Params after iteration " + i + "\n" + params.toString());
        }
        LogInfo.end_track();
        return params;
    }

    /**
     * Computes P(hidden class | features[, observed label]).
     *
     * <p>Scores are accumulated in log space and shifted by their maximum before exponentiating.
     * Without this, a product of many per-feature densities underflows to zero.
     *
     * @param useObsClassId whether to condition on {@code instance.observedClassId}
     * @return a normalized distribution over hidden classes, or all zeros if every hidden class has
     *     zero probability for this instance
     */
    private double[] posteriorHiddenClassPrs(Params params, Instance instance, boolean useObsClassId) {
        double[] logScores = new double[NUM_HIDDEN_CLASSES];
        double maxLogScore = Double.NEGATIVE_INFINITY;
        for (int hc = 0; hc < NUM_HIDDEN_CLASSES; ++hc) {
            double logScore = Math.log(params.pi(hc));
            if (useObsClassId) {
                logScore += Math.log(params.hidden2observedClassPr(hc, instance.observedClassId));
            }
            logScore += params.logLikelihood(hc, instance.observedFeatures);
            logScores[hc] = logScore;
            // Math.max propagates NaN, so a broken score still surfaces as NaN downstream.
            maxLogScore = Math.max(maxLogScore, logScore);
        }
        double[] result = new double[NUM_HIDDEN_CLASSES];
        if (maxLogScore == Double.NEGATIVE_INFINITY) {
            // No hidden class can explain this instance; it contributes no posterior mass.
            return result;
        }
        for (int hc = 0; hc < NUM_HIDDEN_CLASSES; ++hc) {
            result[hc] = Math.exp(logScores[hc] - maxLogScore);
        }
        NumUtils.normalize(result);
        return result;
    }

    /** Loads training and test data (test normalized with training statistics), fits, evaluates. */
    @Override
    public void run() {
        Dataset train = new Dataset(this.trainingDataFile);
        Dataset test = new Dataset(this.testDataFile, train.normMap);
        Params mle = this.em(train.data);
        this.evaluateHeldout(test.data, mle);
    }

    /**
     * Predicts each held-out instance from its features alone and logs per-instance results.
     * It then logs precision and recall of the positive class against the observed labels.
     * Either ratio is NaN when its denominator is zero.
     */
    private void evaluateHeldout(List<Instance> instances, Params mle) {
        double truePositives = 0.0;
        double predictedPositives = 0.0;
        double labeledPositives = 0.0;
        LogInfo.track((Object)"Evaluating precision and recall", true);
        for (Instance instance : instances) {
            double[] posteriorHiddenClassPrs = this.posteriorHiddenClassPrs(mle, instance, false);
            int prediction = this.argmax(posteriorHiddenClassPrs);
            boolean predictedPositive = prediction == 1;
            boolean labeledPositive = instance.observedClassId == 1;
            if (predictedPositive) {
                predictedPositives += 1.0;
            }
            if (labeledPositive) {
                labeledPositives += 1.0;
            }
            if (predictedPositive && labeledPositive) {
                truePositives += 1.0;
            }
            LogInfo.logs("Truth: " + instance.observedClassId + ", Pred: " + prediction + ", Pred dist: " + Arrays.toString(posteriorHiddenClassPrs) + ", Features: " + Arrays.toString(instance.observedFeatures));
        }
        LogInfo.end_track();
        LogInfo.logs("Precision: " + truePositives / predictedPositives);
        LogInfo.logs("Recall: " + truePositives / labeledPositives);
    }

    public static void main(String[] args) {
        Execution.run(args, new SemiSupAsym());
    }

    /** Expected (posterior-weighted) counts accumulated over one E-step pass. */
    private class SuffStats {
        private double[] hiddenClassSoftCounts = new double[NUM_HIDDEN_CLASSES];
        private double[][] featureSoftCounts = new double[NUM_HIDDEN_CLASSES][SemiSupAsym.this.nFeatures];
        private double[][] featureSoftSqrCounts = new double[NUM_HIDDEN_CLASSES][SemiSupAsym.this.nFeatures];

        private SuffStats() {
        }
    }

    /**
     * Per-hidden-class Gaussian parameters. The class prior and label-noise probabilities are
     * read from the enclosing instance.
     */
    private class Params {
        private double[][] mu;
        private double[][] sigmaSqr;

        private Params() {
            this.mu = new double[NUM_HIDDEN_CLASSES][SemiSupAsym.this.nFeatures];
            this.sigmaSqr = new double[NUM_HIDDEN_CLASSES][SemiSupAsym.this.nFeatures];
        }

        /** Renders one row per hidden class with an {@code N(mean,variance)} cell per feature. */
        @Override
        public String toString() {
            return new Table(new Table.Populator(){

                @Override
                public void populate() {
                    for (int hc = 0; hc < SemiSupAsym.NUM_HIDDEN_CLASSES; ++hc) {
                        this.set(hc, 0, "Class " + hc);
                        for (int feat = 0; feat < SemiSupAsym.this.nFeatures; ++feat) {
                            this.set(hc, 1 + feat, "N(" + Params.this.mu[hc][feat] + "," + Params.this.sigmaSqr[hc][feat] + ")");
                        }
                    }
                }
            }).toString();
        }

        /** Prior probability of hidden class {@code hiddenId} (0 or 1). */
        private double pi(int hiddenId) {
            if (hiddenId == 0) {
                return 1.0 - SemiSupAsym.this.piPos;
            }
            if (hiddenId == 1) {
                return SemiSupAsym.this.piPos;
            }
            throw new IllegalArgumentException("Unknown hidden class: " + hiddenId);
        }

        /** P(observed label = {@code obsId} | hidden class = {@code hiddenId}); both must be 0 or 1. */
        private double hidden2observedClassPr(int hiddenId, int obsId) {
            double pObsPos;
            if (hiddenId == 0) {
                pObsPos = SemiSupAsym.this.hiddenNeg2obsPos;
            } else if (hiddenId == 1) {
                pObsPos = SemiSupAsym.this.hiddenPos2obsPos;
            } else {
                throw new IllegalArgumentException("Unknown hidden class: " + hiddenId);
            }
            if (obsId == 0) {
                return 1.0 - pObsPos;
            }
            if (obsId == 1) {
                return pObsPos;
            }
            throw new IllegalArgumentException("Unknown observed class: " + obsId);
        }

        /** Sum over features of the Gaussian log-density of {@code observedFeatures} under class {@code hc}. */
        private double logLikelihood(int hc, double[] observedFeatures) {
            double sum = 0.0;
            for (int feat = 0; feat < SemiSupAsym.this.nFeatures; ++feat) {
                sum += Gaussian.logProb(this.mu[hc][feat], this.sigmaSqr[hc][feat], observedFeatures[feat]);
            }
            return sum;
        }

        /**
         * Product over features of the Gaussian densities under class {@code hc}. This may
         * underflow to 0 for many features; prefer {@link #logLikelihood}.
         */
        public double likelihood(int hc, double[] observedFeatures) {
            return Math.exp(this.logLikelihood(hc, observedFeatures));
        }
    }

    /**
     * Derives the hidden-class prior and the hidden-to-observed label probabilities. The inputs
     * are the fraction of instances with observed label 1 and the two noise options.
     */
    private class TransitionParametrizationConverter {
        private double fractionLabeled;
        private double h1o1;
        private double h1o0;
        private double h0o1;
        private double h0o0;

        private TransitionParametrizationConverter() {
        }

        /** Measures the fraction of instances whose observed label is 1, then fills the joint table. */
        private void compute(List<? extends Instance> instances) {
            double nLab = 0.0;
            for (Instance instance : instances) {
                if (instance.observedClassId == 1) {
                    nLab += 1.0;
                }
            }
            this.fractionLabeled = nLab / (double)instances.size();
            LogInfo.logs("Fraction labeled: " + this.fractionLabeled);
            this.computeJoint();
        }

        /** Fills the joint table h{hidden}o{observed}; its four cells sum to 1. */
        private void computeJoint() {
            double q1 = SemiSupAsym.this.falsePositive;
            double q2 = SemiSupAsym.this.fractionUnlabeledThatShouldBePos;
            this.h1o1 = (1.0 - q1) * this.fractionLabeled;
            this.h1o0 = q2 * (1.0 - this.fractionLabeled);
            this.h0o1 = q1 * this.fractionLabeled;
            this.h0o0 = (1.0 - q2) * (1.0 - this.fractionLabeled);
        }

        /** Marginal probability that the hidden class is positive. */
        public double getPiPos() {
            return this.h1o1 + this.h1o0;
        }

        /** P(observed = 1 | hidden = 1). */
        public double getHiddenPos2obsPos() {
            return this.h1o1 / this.getPiPos();
        }

        /** P(observed = 1 | hidden = 0). */
        public double getHiddenNeg2obsPos() {
            return this.h0o1 / (1.0 - this.getPiPos());
        }

        @Override
        public String toString() {
            return "PiPos: " + this.getPiPos() + "\nHiddenPos2obsPos: " + this.getHiddenPos2obsPos() + "\nHiddenNeg2obsPos: " + this.getHiddenNeg2obsPos();
        }
    }

    /** Per-feature min/scale computed on one dataset, for min-max rescaling into [0, 1]. */
    private static class NormalizationMap {
        private final double[] min;
        private final double[] scale;
        private final int nFeatures;

        /**
         * @throws IllegalArgumentException if {@code instances} is empty
         */
        public NormalizationMap(List<Instance> instances) {
            if (instances.isEmpty()) {
                throw new IllegalArgumentException("Cannot build a normalization map from an empty dataset");
            }
            this.nFeatures = instances.get(0).observedFeatures.length;
            double[] max = new double[this.nFeatures];
            this.min = new double[this.nFeatures];
            Arrays.fill(max, Double.NEGATIVE_INFINITY);
            Arrays.fill(this.min, Double.POSITIVE_INFINITY);
            for (Instance instance : instances) {
                for (int feat = 0; feat < this.nFeatures; ++feat) {
                    double value = instance.observedFeatures[feat];
                    if (value > max[feat]) {
                        max[feat] = value;
                    }
                    if (value < this.min[feat]) {
                        this.min[feat] = value;
                    }
                }
            }
            this.scale = new double[this.nFeatures];
            for (int feat = 0; feat < this.nFeatures; ++feat) {
                double range = max[feat] - this.min[feat];
                // A constant feature would otherwise get scale 1/0 = Infinity and (x - min) * Infinity = NaN.
                this.scale[feat] = range > 0.0 ? 1.0 / range : 1.0;
            }
        }
    }

    /**
     * A comma-separated data file. The first line is a header. The "multiconf" column is the
     * observed label; "unique_id" and "ubchi1" are skipped; every other column is a feature.
     */
    private class Dataset {
        public final List<Instance> data = new ArrayList<Instance>();
        private int observedLabelIndex = -1;
        private final List<Integer> featureIndices = new ArrayList<Integer>();
        private final NormalizationMap normMap;
        public static final String SPLIT_REGEX = "\\s*[,]\\s*";

        /**
         * Loads {@code dataFile} and normalizes it with existing statistics (e.g. from training).
         *
         * @throws IllegalArgumentException if the file's feature count differs from {@code nm}'s
         */
        public Dataset(String dataFile, NormalizationMap nm) {
            this.readData(dataFile);
            if (nm.nFeatures != SemiSupAsym.this.nFeatures) {
                throw new IllegalArgumentException(dataFile + " has " + SemiSupAsym.this.nFeatures + " features but the normalization map expects " + nm.nFeatures);
            }
            this.normMap = nm;
            SemiSupAsym.normalize(this.data, nm);
        }

        /** Loads {@code dataFile} and normalizes it with statistics computed from itself. */
        public Dataset(String dataFile) {
            this.readData(dataFile);
            this.normMap = new NormalizationMap(this.data);
            SemiSupAsym.normalize(this.data, this.normMap);
        }

        /** Parses the header line, then every non-blank line after it, into {@link #data}. */
        private void readData(String dataFile) {
            int lineNumber = 0;
            for (String line : IO.i(dataFile)) {
                if (lineNumber == 0) {
                    this.readHeaders(line);
                } else if (!line.matches("^\\s*$")) {
                    String[] fields = line.split(SPLIT_REGEX);
                    Instance instance = new Instance();
                    for (int f = 0; f < SemiSupAsym.this.nFeatures; ++f) {
                        instance.observedFeatures[f] = Double.parseDouble(fields[this.featureIndices.get(f)]);
                    }
                    instance.observedClassId = Integer.parseInt(fields[this.observedLabelIndex]);
                    this.data.add(instance);
                }
                ++lineNumber;
            }
        }

        /**
         * Locates the label and feature columns and sets the enclosing {@code nFeatures}.
         *
         * @throws IllegalArgumentException if the header has no "multiconf" column
         */
        private void readHeaders(String line) {
            String[] fields = line.split(SPLIT_REGEX);
            for (int i = 0; i < fields.length; ++i) {
                String field = fields[i];
                if (field.equals("unique_id") || field.equals("ubchi1")) {
                    continue; // explicitly excluded from the feature set
                }
                if (field.equals("multiconf")) {
                    this.observedLabelIndex = i;
                    continue;
                }
                this.featureIndices.add(i);
            }
            if (this.observedLabelIndex < 0) {
                throw new IllegalArgumentException("Header has no 'multiconf' label column: " + line);
            }
            SemiSupAsym.this.nFeatures = this.featureIndices.size();
        }
    }

    /** One data row: an observed label and a feature vector of length {@code nFeatures}. */
    public class Instance {
        protected int observedClassId;
        protected double[] observedFeatures;

        public Instance() {
            this.observedFeatures = new double[SemiSupAsym.this.nFeatures];
        }

        @Override
        public String toString() {
            return "[class=" + this.observedClassId + ",features=" + Arrays.toString(this.observedFeatures) + "]";
        }
    }
}

