/*
 * Decompiled with CFR 0.152.
 */
package pedi.tests;

import fenchel.algo.FactorGraphSumProduct;
import fenchel.factor.FactorUtils;
import fenchel.factor.multisites.MSFactorGraph;
import nuts.math.MtxUtils;
import nuts.util.MathUtils;
import org.junit.Assert;
import org.junit.Test;
import pedi.Genotype;
import pedi.Genotypes;
import pedi.Individual;
import pedi.Pedigree;
import pedi.RegularPedigree;
import pedi.factor.Pedigree2FactorGraph;
import pedi.factor.PedigreeNode;
import pedi.factor.ProjectionEncodings;
import pedi.io.PedigreeFileReader;

/**
 * Tests that build a pedigree factor graph from a {@code .ped} file and inspect
 * the log-partition value reported by {@link FactorGraphSumProduct}.
 *
 * <p>Two construction paths are exercised: the current one in
 * {@link Pedigree2FactorGraph} and the legacy {@code _old_*} builders kept as
 * static methods of this class.
 */
public class Pedigree2FactorGraphTest {
    /** Regular pedigree produced by the most recent {@link #setup} call. */
    private RegularPedigree p;
    /** Genotypes returned by the reader in the most recent {@link #setup} call. */
    private Genotypes g;
    /** Factor graph populated by the most recent {@link #setup} call. */
    private MSFactorGraph<PedigreeNode> fg;

    /**
     * Encoding table (8 rows, 4 entries each) used by the legacy inheritance
     * factors. {@link #_old_addInheritanceFactors} requests factors 0 and 1 for
     * the two grand-parent alleles, factor 2 for the recombination node and
     * factor 3 for the child allele.
     */
    public static final ProjectionEncodings old_inheritanceNodeProjections = new ProjectionEncodings(new double[][]{
        {0.0, 0.0, 0.0, 0.0},
        {1.0, 0.0, 0.0, 1.0},
        {0.0, 1.0, 0.0, 0.0},
        {1.0, 1.0, 0.0, 1.0},
        {0.0, 0.0, 1.0, 0.0},
        {0.0, 1.0, 1.0, 1.0},
        {1.0, 0.0, 1.0, 0.0},
        {1.0, 1.0, 1.0, 1.0}
    });

    /**
     * Encoding table (4 rows, 2 entries each) used by the legacy genotype
     * factors; {@link #_old_addGenotypeFactors} requests factors 0 and 1, one
     * per parental allele.
     */
    public static final ProjectionEncodings haplotypeEncodings = new ProjectionEncodings(new double[][]{
        {0.0, 0.0},
        {0.0, 1.0},
        {1.0, 0.0},
        {1.0, 1.0}
    });

    /**
     * Checks that {@code exp(logZ)} summed over a family of input files equals
     * 1 within {@code MathUtils.threshold}.
     *
     * <p>The {@code simple1..simple3} files are checked with both the legacy
     * and the current builders; the {@code manychildren.0..8} and
     * {@code morecomplex.0..8} families are checked with the current builders
     * only. Presumably each family enumerates all observable outcomes, which is
     * why the values are expected to sum to one -- confirm against the fixtures.
     */
    @Test
    public void testNorm() {
        for (boolean useOld : new boolean[]{true, false}) {
            double total = 0.0;
            for (int idx = 1; idx < 4; ++idx) {
                String path = "src-pedi/pedi/tests/simple" + idx + ".ped";
                total += Math.exp(logZ(path, useOld));
            }
            Assert.assertEquals(1.0, total, MathUtils.threshold);
        }
        for (String family : new String[]{"manychildren", "morecomplex"}) {
            double total = 0.0;
            for (int idx = 0; idx < 9; ++idx) {
                String path = "src-pedi/pedi/tests/" + family + "." + idx + ".ped";
                total += Math.exp(logZ(path, false));
            }
            Assert.assertEquals(1.0, total, MathUtils.threshold);
        }
    }

    /**
     * Checks that, for each of {@code simple1..simple3}, the legacy and the
     * current builders yield the same {@code exp(logZ)} within
     * {@code MathUtils.threshold}.
     */
    @Test
    public void testOldVsNew() {
        for (int idx = 1; idx < 4; ++idx) {
            String path = "src-pedi/pedi/tests/simple" + idx + ".ped";
            // Legacy graph is evaluated first, then the current one.
            double legacyZ = Math.exp(logZ(path, true));
            double currentZ = Math.exp(logZ(path, false));
            Assert.assertEquals(legacyZ, currentZ, MathUtils.threshold);
        }
    }

    /**
     * Rebuilds the factor graph for {@code file} and returns the value of
     * {@code logZ()} from a freshly initialised sum-product instance.
     *
     * @param file   path of the pedigree file to read
     * @param useOld {@code true} to use the legacy {@code _old_*} builders
     * @return the log-partition value reported by the sum-product object
     */
    private double logZ(String file, boolean useOld) {
        setup(file, useOld);
        FactorGraphSumProduct<PedigreeNode> inference = new FactorGraphSumProduct<PedigreeNode>();
        inference.init(fg);
        return inference.logZ();
    }

    /**
     * Reads {@code file}, stores the resulting pedigree and genotypes in the
     * instance fields, and fills a new factor graph with genotype, inheritance,
     * recombination and founder factors, in that order.
     *
     * @param file   path of the pedigree file to read
     * @param useOld selects the legacy builders for the genotype, inheritance
     *               and founder factors; the recombination factors always come
     *               from {@link Pedigree2FactorGraph}
     */
    private void setup(String file, boolean useOld) {
        PedigreeFileReader pedReader = new PedigreeFileReader();
        pedReader.read(file);
        p = RegularPedigree.getRegularPedigree(pedReader.getPedigree());
        g = pedReader.getGenotypes();
        System.out.println(p.founders());
        fg = FactorUtils.newFactorGraph();

        // 1. genotype factors
        if (!useOld) {
            Pedigree2FactorGraph.addGenotypeFactors(p, g, fg);
        } else {
            _old_addGenotypeFactors(p, g, fg);
        }

        // 2. inheritance factors
        if (!useOld) {
            Pedigree2FactorGraph.addInheritanceFactors(p, fg, g.getFactorEncodings());
        } else {
            _old_addInheritanceFactors(p, fg);
        }

        // 3. recombination factors (shared by both paths)
        Pedigree2FactorGraph.addIndependentRecombFactors(p, fg, g.genomeSize());

        // 4. founder factors
        if (!useOld) {
            Pedigree2FactorGraph.addSimpleFounderFactors(
                    p,
                    Pedigree2FactorGraph.getUniformHardyWeinbergDistribution(1, g.getFactorEncodings(), 0.5),
                    fg);
        } else {
            _old_addSimpleFounderFactors(p, getUniformAlleleDistribution(1), fg);
        }
    }

    /**
     * Legacy founder prior: attaches {@code factor} as a unary factor to both
     * allele nodes (index 0 and 1) of every founder of {@code p}.
     *
     * @param p      pedigree whose founders receive the factor
     * @param factor unary factor passed unchanged to {@code addUnary}
     * @param f      factor graph to extend
     */
    public static void _old_addSimpleFounderFactors(RegularPedigree p, double[][] factor, MSFactorGraph<PedigreeNode> f) {
        for (Individual founder : p.founders()) {
            for (int side = 0; side < 2; ++side) {
                PedigreeNode founderAllele = PedigreeNode.createAlleleNode(founder, side);
                f.addUnary(founderAllele, factor);
            }
        }
    }

    /**
     * Builds a unary factor from the two-valued distribution {@code {0.5, 0.5}}
     * via {@code FactorUtils.simpleUnary}.
     *
     * @param nSites forwarded to {@code FactorUtils.simpleUnary}
     * @return the array returned by {@code FactorUtils.simpleUnary}
     */
    public static double[][] getUniformAlleleDistribution(int nSites) {
        return FactorUtils.simpleUnary(new double[]{0.5, 0.5}, nSites);
    }

    /**
     * Legacy inheritance wiring. For every non-founder individual and each of
     * its two parent slots, an inheritance node is connected by binary factors
     * to: the two allele nodes of the corresponding parent, the individual's
     * own allele node for that slot, and the recombination node for that slot.
     *
     * <p>Each binary factor is supplied together with its transpose.
     *
     * @param p pedigree to traverse
     * @param f factor graph to extend
     */
    public static void _old_addInheritanceFactors(RegularPedigree p, MSFactorGraph<PedigreeNode> f) {
        // Factors are fetched in index order 0..3 before any transpose is computed.
        double[][] grandParentFactor0 = old_inheritanceNodeProjections.getFactor(0);
        double[][] grandParentFactor1 = old_inheritanceNodeProjections.getFactor(1);
        double[][][] grandParentFactors = new double[][][]{grandParentFactor0, grandParentFactor1};
        double[][] recombFactor = old_inheritanceNodeProjections.getFactor(2);
        double[][] childFactor = old_inheritanceNodeProjections.getFactor(3);

        double[][][] grandParentFactorsT = new double[][][]{
            MtxUtils.transpose(grandParentFactors[0]),
            MtxUtils.transpose(grandParentFactors[1])
        };
        double[][] recombFactorT = MtxUtils.transpose(recombFactor);
        double[][] childFactorT = MtxUtils.transpose(childFactor);

        for (Individual person : p.individuals()) {
            // Founders have no parents in the pedigree, so they get no inheritance nodes.
            if (p.founders().contains(person)) {
                continue;
            }
            for (int side = 0; side < 2; ++side) {
                PedigreeNode inheritance = PedigreeNode.createInheritanceNode(person, side);

                for (int gp = 0; gp < 2; ++gp) {
                    PedigreeNode parentAllele = PedigreeNode.createAlleleNode(p.parent(person, side), gp);
                    f.addBinary(parentAllele, inheritance, grandParentFactors[gp], grandParentFactorsT[gp]);
                }

                PedigreeNode ownAllele = PedigreeNode.createAlleleNode(person, side);
                f.addBinary(ownAllele, inheritance, childFactor, childFactorT);

                PedigreeNode recomb = PedigreeNode.createRecombNode(person, side);
                f.addBinary(recomb, inheritance, recombFactor, recombFactorT);
            }
        }
    }

    /**
     * Legacy genotype wiring. For every genotyped individual, a haplotype node
     * receives the genotype's unary factor and is linked by a binary factor
     * (plus its transpose) to each of the individual's two allele nodes.
     *
     * @param p         pedigree; not read by this method
     * @param genotypes source of the genotyped individuals and their genotypes
     * @param f         factor graph to extend
     */
    public static void _old_addGenotypeFactors(Pedigree p, Genotypes genotypes, MSFactorGraph<PedigreeNode> f) {
        double[][][] alleleFactors = new double[][][]{
            haplotypeEncodings.getFactor(0),
            haplotypeEncodings.getFactor(1)
        };
        double[][][] alleleFactorsT = new double[][][]{
            MtxUtils.transpose(alleleFactors[0]),
            MtxUtils.transpose(alleleFactors[1])
        };

        for (Individual person : genotypes.genotypedIndividuals()) {
            Genotype observed = genotypes.getGenotype(person);
            PedigreeNode haplotype = PedigreeNode.createHaplotypeNode(person);
            f.addUnary(haplotype, observed.getUnaryFactor());

            for (int side = 0; side < 2; ++side) {
                PedigreeNode allele = PedigreeNode.createAlleleNode(person, side);
                f.addBinary(allele, haplotype, alleleFactors[side], alleleFactorsT[side]);
            }
        }
    }
}

