/*
 * Decompiled with CFR 0.152.
 */
package pty.io;

import goblin.Taxon;
import nuts.lang.ArrayUtils;
import nuts.math.GMFct;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;
import pty.Observations;
import pty.RootedTree;
import pty.learn.DiscreteBP;
import pty.smc.PartialCoalescentState;
import pty.smc.models.CTMC;

/**
 * Leave-one-out style accuracy evaluation of a tree + CTMC model on discrete observations.
 *
 * <p>For every (taxon, site) cell whose observation is a single known state, the posterior
 * marginal for that taxon is obtained from {@link DiscreteBP#posteriorMarginalTransitions}. The
 * most probable state is taken as the prediction, and the fraction of cells where the prediction
 * matches the observed state is reported.
 */
public class LeaveOneOut {
    /**
     * Convenience overload that evaluates a {@link PartialCoalescentState} using its full
     * coalescent tree, its CTMC and its observations.
     *
     * @param pcs the state to evaluate
     * @return the accuracy as computed by {@link #loo(RootedTree, CTMC, Observations)}
     */
    public static double loo(PartialCoalescentState pcs) {
        return LeaveOneOut.loo(pcs.getFullCoalescentState(), pcs.getCTMC(), pcs.getObservations());
    }

    /**
     * Computes the prediction accuracy over all known (taxon, site) cells.
     *
     * <p>A cell is "known" when {@link #isKnown} holds. For each such cell, the posterior
     * marginal for {@code lang} at {@code site} is computed. The highest-scoring state is
     * compared with the observed state. NOTE(review): presumably {@code DiscreteBP} excludes
     * {@code lang}'s own observation when computing this marginal (hence "leave one out") —
     * confirm against DiscreteBP.
     *
     * @param tree the rooted tree over the observed taxa
     * @param ctmc the substitution model used for belief propagation
     * @param observations the observed character data, indexed by taxon then site
     * @return the mean of per-cell 0/1 correctness indicators, in [0, 1]; {@code NaN} when no
     *     cell is known (the behavior of {@code SummaryStatistics.getMean()} with no values)
     */
    public static double loo(RootedTree tree, CTMC ctmc, Observations observations) {
        SummaryStatistics stats = new SummaryStatistics();
        for (Taxon lang : observations.observations().keySet()) {
            for (int site = 0; site < observations.nSites(); ++site) {
                if (!LeaveOneOut.isKnown(observations, lang, site)) {
                    continue;
                }
                GMFct<Taxon> post = DiscreteBP.posteriorMarginalTransitions(tree, ctmc, observations, site, lang);
                int prediction = LeaveOneOut.argMax(post, lang, observations.nCharacter(site));
                int truth = LeaveOneOut.getObs(observations, lang, site);
                // A failed prediction (-1) must never count as a hit, even when getObs also
                // returns -1 for a cell that sums to 1 but has no entry exactly equal to 1.
                boolean correct = prediction >= 0 && prediction == truth;
                stats.addValue(correct ? 1.0 : 0.0);
            }
        }
        return stats.getMean();
    }

    /**
     * Returns the state in {@code [0, nChars)} with the largest posterior value for {@code lang}.
     * Ties are resolved in favor of the lowest index.
     *
     * @param post the posterior marginal table
     * @param lang the taxon whose marginal is inspected
     * @param nChars the number of character states at the site
     * @return the argmax state, or -1 if no value is strictly greater than negative infinity
     *     (e.g. all values are NaN or -infinity)
     */
    private static int argMax(GMFct<Taxon> post, Taxon lang, int nChars) {
        int best = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int ch = 0; ch < nChars; ++ch) {
            double current = post.get(lang, ch);
            if (current > max) {
                best = ch;
                max = current;
            }
        }
        return best;
    }

    /**
     * Returns the index of the observed state at {@code site} for {@code lang}, i.e. the unique
     * entry of the observation vector that equals exactly 1.0.
     *
     * @param observations the observed character data
     * @param lang the taxon to look up
     * @param site the site index
     * @return the observed state index, or -1 if no entry equals exactly 1.0
     * @throws IllegalStateException if more than one entry equals 1.0
     */
    private static int getObs(Observations observations, Taxon lang, int site) {
        double[] indicators = observations.observations().get(lang)[site];
        int result = -1;
        for (int i = 0; i < indicators.length; ++i) {
            if (indicators[i] != 1.0) {
                continue;
            }
            if (result != -1) {
                throw new IllegalStateException("Multiple observed states (" + result + " and " + i
                        + ") for taxon " + lang + " at site " + site);
            }
            result = i;
        }
        return result;
    }

    /**
     * Returns whether the observation vector at {@code site} for {@code lang} sums to exactly 1.0.
     * NOTE(review): exact floating-point equality assumes 0/1 indicator vectors; confirm that
     * the observations are never fractional.
     *
     * @param observations the observed character data
     * @param lang the taxon to look up
     * @param site the site index
     * @return {@code true} if the entries sum to exactly 1.0
     */
    private static boolean isKnown(Observations observations, Taxon lang, int site) {
        double sum = ArrayUtils.sum(observations.observations().get(lang)[site]);
        return sum == 1.0;
    }
}

