/*
 * Decompiled with CFR 0.152.
 */
package fenchel.measurefacto;

import fenchel.measurefacto.ComputeFullProduct;
import fenchel.measurefacto.ComputeXis;
import fenchel.measurefacto.ComputeZeta;
import fenchel.measurefacto.ExponentialFamily;
import fenchel.measurefacto.ExtendedExponentialFamily;
import fenchel.measurefacto.PointwiseOperationApplicator;
import fig.basic.IOUtils;
import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;

/**
 * Iterative message-passing procedure over a list of exponential-family
 * approximations that share one parameter vector of type {@code M}.
 *
 * <p>Each call to {@link #iterate()}:
 * <ol>
 *   <li>computes per-approximation moments from the current per-approximation
 *       parameters ("xis"), or from {@code parameters} on the first iteration;</li>
 *   <li>turns those moments into "zetas" via {@link ComputeZeta};</li>
 *   <li>produces the next xis from the zetas via {@link ComputeXis}.</li>
 * </ol>
 *
 * <p>The helper classes' return values are never used, and the same objects are
 * reused afterwards. This suggests (NOTE(review): confirm) that they overwrite
 * their vector arguments in place. Vectors returned by the query methods may
 * therefore be overwritten by later calls.
 *
 * <p>If an output directory is set via {@link #setOutputDir(File)}, every
 * intermediate vector is dumped to its own file for debugging.
 *
 * @param <M> the vector/parameter representation, manipulated through a
 *            {@link PointwiseOperationApplicator}
 */
public class MFBeliefProp<M> {
    /** Extended statistics of the approximations, refreshed lazily by {@link #ensureMomentsComputed()}. */
    private List<M> queryMoments = null;
    /** Per-approximation parameters for the next moment computation; null until the first {@link #iterate()}. */
    private List<M> xis = null;
    /** Transient per-approximation zetas, only non-null between updateZetas and updateXis. */
    private List<M> zetas = null;
    /** Initial/global parameters; every approximation uses them during initialization. */
    private final M parameters;
    /** Per-approximation moments; consumed (and cleared) by updateZetas. */
    private List<M> moments = null;
    /** Scratch vector into which {@link #queryProductMoments()} writes its result. */
    private final M workingCopy;
    private final PointwiseOperationApplicator<M> applicator;
    private final List<ExponentialFamily<M>> approximations;
    private final int nApproximations;
    /** True until the first {@link #iterate()} has run; selects {@code parameters} over {@code xis}. */
    private boolean initializing = true;
    /** Whether {@code moments}/{@code queryMoments} reflect the current parameters. */
    private boolean momentsComputed = false;
    /** Currently always false, so the normalization branch in updateMoments is inactive. */
    private final boolean useNorm;
    private final List<Double> alphas;
    private final double correction;
    private final boolean useSum;
    /** Directory for debug dumps; null disables logging. */
    private File outputDir = null;
    /** Number of completed iterations; only used to name debug dump files. */
    private int iter = 0;

    /**
     * Enables debug dumps of intermediate vectors into the given directory.
     *
     * @param f an existing directory
     * @throws IllegalArgumentException if {@code f} is not an existing directory
     * @throws NullPointerException if {@code f} is null
     */
    public void setOutputDir(File f) {
        if (!f.isDirectory()) {
            throw new IllegalArgumentException("Not an existing directory: " + f);
        }
        this.outputDir = f;
    }

    /**
     * Creates the procedure.
     *
     * @param parameters     initial parameter vector, used by every approximation
     *                       until the first {@link #iterate()} completes
     * @param emptyVector    scratch vector that {@link #queryProductMoments()} fills
     *                       and returns
     * @param approximations the approximations to iterate over; its size is captured
     *                       here and not re-read
     * @param applicator     pointwise vector operations passed to the helper classes
     * @param alphas         weights forwarded to {@link ComputeFullProduct} and
     *                       {@link ComputeXis}
     * @param correction     forwarded to {@link ComputeFullProduct}
     * @param useSum         forwarded to {@link ComputeFullProduct}
     */
    public MFBeliefProp(M parameters, M emptyVector, List<ExponentialFamily<M>> approximations, PointwiseOperationApplicator<M> applicator, List<Double> alphas, double correction, boolean useSum) {
        this.useSum = useSum;
        this.correction = correction;
        this.alphas = alphas;
        this.useNorm = false;
        this.approximations = approximations;
        this.nApproximations = approximations.size();
        this.parameters = parameters;
        this.applicator = applicator;
        this.workingCopy = emptyVector;
    }

    /**
     * Returns the extended statistics of the approximations at the current parameters.
     *
     * <p>An entry is contributed only by approximations that are
     * {@link ExtendedExponentialFamily} instances, in approximation order. Indices in
     * this list therefore line up with approximation indices only if every
     * approximation is extended. NOTE(review): confirm consumers expect this.
     *
     * @return the internal list (not a copy); do not modify
     */
    public List<M> queryMoments() {
        this.ensureMomentsComputed();
        return this.queryMoments;
    }

    /**
     * Combines the query moments via {@link ComputeFullProduct} and returns the result.
     *
     * <p>The result is the scratch vector passed to the constructor as
     * {@code emptyVector}. The same object is returned on every call, so it is
     * presumably overwritten by the next call. It is also dumped as "mu" when
     * logging is enabled.
     *
     * @return the scratch vector holding the combined moments
     */
    public M queryProductMoments() {
        List<M> currentQueryMoments = this.queryMoments();
        ComputeFullProduct.computeFullProduct(this.applicator, currentQueryMoments, this.workingCopy, this.alphas, this.correction, this.useSum);
        this.log("mu", this.workingCopy);
        return this.workingCopy;
    }

    /** Lazily recomputes moments and query moments if they are stale. */
    private void ensureMomentsComputed() {
        if (!this.momentsComputed) {
            this.updateMoments(this.initializing);
            this.momentsComputed = true;
        }
    }

    /**
     * Performs one iteration: moments, then zetas, then the next xis.
     * Marks the moments stale so the next query recomputes them from the new xis.
     */
    public void iterate() {
        this.ensureMomentsComputed();
        this.momentsComputed = false;
        this.updateZetas(this.initializing);
        this.initializing = false;
        this.updateXis();
        ++this.iter;
    }

    /**
     * Recomputes {@code moments}, and {@code queryMoments} for extended
     * approximations, from the current parameters.
     *
     * @param isInit if true every approximation uses {@code parameters};
     *               otherwise approximation {@code i} uses {@code xis.get(i)}
     */
    private void updateMoments(boolean isInit) {
        List<M> newQueryMoments = new ArrayList<M>(this.nApproximations);
        List<M> newMoments = new ArrayList<M>(this.nApproximations);
        for (int i = 0; i < this.nApproximations; ++i) {
            ExponentialFamily<M> currentApprox = this.approximations.get(i);
            M currentParam = isInit ? this.parameters : this.xis.get(i);
            if (this.useNorm) {
                // Mutates currentParam; during init that is the shared 'parameters'
                // object. Inactive while useNorm is hard-wired to false.
                this.applicator.normalize(currentParam);
            }
            this.log("xi", i, currentParam);
            M currentMoment = currentApprox.moments(currentParam);
            if (currentApprox instanceof ExtendedExponentialFamily) {
                // Called right after moments(currentParam); presumably returns statistics
                // cached by that call (TODO confirm). The raw type erases the return type,
                // so cast back to M to store it in a List<M>.
                @SuppressWarnings("unchecked")
                M extended = (M) ((ExtendedExponentialFamily) currentApprox).getExtendedStatistics();
                newQueryMoments.add(extended);
            }
            this.log("nabla", i, currentMoment);
            newMoments.add(currentMoment);
        }
        // Publish both lists only once they are complete.
        this.queryMoments = newQueryMoments;
        this.moments = newMoments;
    }

    /** Dumps {@code contents} for the current iteration without a factor index. */
    private void log(String msg, Object contents) {
        this.log(msg, null, contents);
    }

    /**
     * Writes {@code String.valueOf(contents)} to a file in the output directory, if one is set.
     *
     * <p>The file is named {@code iter=<iter>,msg=<msg>[,fac=<i>]}.
     *
     * @throws IllegalStateException if the log file could not be opened
     */
    private void log(String msg, Integer i, Object contents) {
        if (this.outputDir == null) {
            return;
        }
        File current = new File(this.outputDir, "iter=" + this.iter + ",msg=" + msg + (i == null ? "" : ",fac=" + i));
        PrintWriter out = IOUtils.openOutEasy(current);
        if (out == null) {
            // NOTE(review): assumes the "Easy" opener signals failure with null; fail with
            // a descriptive error instead of an NPE below.
            throw new IllegalStateException("Could not open log file: " + current);
        }
        try {
            out.append(String.valueOf(contents));
        } finally {
            // Always release the file handle, even if writing fails.
            out.close();
        }
    }

    /**
     * Converts the current moments into zetas via {@link ComputeZeta}.
     *
     * <p>Each moment object is passed to {@link ComputeZeta} and then stored as the
     * zeta, presumably after being updated in place (TODO confirm). The moments list
     * is cleared afterwards because its objects now hold zetas.
     *
     * @param isInit if true every approximation is paired with {@code parameters};
     *               otherwise approximation {@code i} is paired with {@code xis.get(i)}
     */
    private void updateZetas(boolean isInit) {
        List<M> newZetas = new ArrayList<M>(this.nApproximations);
        for (int i = 0; i < this.nApproximations; ++i) {
            M currentXi = isInit ? this.parameters : this.xis.get(i);
            M currentMoment = this.moments.get(i);
            ComputeZeta.computeZeta(this.applicator, currentMoment, currentXi);
            this.log("zeta", i, currentMoment);
            newZetas.add(currentMoment);
        }
        this.moments = null;
        this.zetas = newZetas;
    }

    /**
     * Produces the next xis from the zetas via {@link ComputeXis}.
     *
     * <p>The zeta list itself becomes the xi list, presumably after being updated in
     * place by {@link ComputeXis} (TODO confirm).
     */
    private void updateXis() {
        ComputeXis.computeXis(this.applicator, this.zetas, this.parameters, this.alphas);
        this.xis = this.zetas;
        this.zetas = null;
    }
}

