/*
 * Decompiled with CFR 0.152.
 */
package ev.par;

import ev.hmm.HetPairHMMSpecification;
import ev.par.Input;
import ev.par.Model;
import ev.par.Output;
import ev.par.StrTaxonSuffStat;
import goblin.Taxon;
import java.util.Arrays;
import java.util.List;
import java.util.SortedSet;
import ma.MultiAlignment;
import nuts.lang.ArrayUtils;
import nuts.maxent.MaxentClassifier;

public final class CachedParams {
    private final double[][][][][] cachedLogPrs;
    private final Model model;
    public static final int NEUTRAL = -1;
    public static final int PROHIBITED = -2;

    public HetPairHMMSpecification getUnsupPairHMM(String top, String bot, Taxon topTaxon, Taxon botTaxon) {
        return this.getPairHMM(top, bot, topTaxon, botTaxon, null);
    }

    public HetPairHMMSpecification getSupPairHMM(MultiAlignment truth, String top, String bot, Taxon topL, Taxon botL) {
        return this.getPairHMM(top, bot, topL, botL, truth);
    }

    public HetPairHMMSpecification getReweightedHMM(final double[][][] logWeights, final String top, final String bot, Taxon topTaxon, Taxon botTaxon) {
        final StrTaxonSuffStat.StrTaxonSuffStatExtractor extractor = this.model.stSuffStat.getExtractor(top, bot, topTaxon, botTaxon);
        return new HetPairHMMSpecification(){

            @Override
            public final int startState() {
                return ((CachedParams)CachedParams.this).model.startState;
            }

            @Override
            public final int endState() {
                return ((CachedParams)CachedParams.this).model.endState;
            }

            @Override
            public final int nStates() {
                return ((CachedParams)CachedParams.this).model.nStates;
            }

            @Override
            public final double logWeight(int prevState, int currentState, int xpos, int ypos, int deltaX, int deltaY) {
                int xid = CachedParams.this.model.charIdAt(top, xpos, deltaX);
                int yid = CachedParams.this.model.charIdAt(bot, ypos, deltaY);
                int stss = extractor.extract(xpos, ypos);
                boolean isAligned = deltaX == 1 && deltaY == 1;
                double logWeight = isAligned ? logWeights[xpos][ypos][1] - logWeights[xpos][ypos][0] : 0.0;
                return logWeight + CachedParams.this.getLogPr(prevState, currentState, stss, xid, yid);
            }
        };
    }

    private HetPairHMMSpecification getPairHMM(final String top, final String bot, Taxon topTaxon, Taxon botTaxon, final MultiAlignment msa) {
        int[] inverseConnections;
        final StrTaxonSuffStat.StrTaxonSuffStatExtractor extractor = this.model.stSuffStat.getExtractor(top, bot, topTaxon, botTaxon);
        final int topSeqLength = msa == null ? -1 : msa.getSequences().get(topTaxon).length();
        final int botSeqLength = msa == null ? -1 : msa.getSequences().get(botTaxon).length();
        final int[] directconnections = msa == null ? null : new int[topSeqLength];
        int[] nArray = inverseConnections = msa == null ? null : new int[botSeqLength];
        if (msa != null) {
            int b;
            Arrays.fill(directconnections, -1);
            Arrays.fill(inverseConnections, -1);
            for (int t = 0; t < topSeqLength; ++t) {
                if (!msa.isCoreBlock(topTaxon, t)) continue;
                directconnections[t] = -2;
            }
            for (b = 0; b < botSeqLength; ++b) {
                if (!msa.isCoreBlock(botTaxon, b)) continue;
                inverseConnections[b] = -2;
            }
            for (b = 0; b < botSeqLength; ++b) {
                if (!msa.isCoreBlock(botTaxon, b)) continue;
                for (int t = 0; t < topSeqLength; ++t) {
                    if (!msa.isAligned(topTaxon, t, botTaxon, b)) continue;
                    directconnections[t] = b;
                    inverseConnections[b] = t;
                }
            }
        }
        return new HetPairHMMSpecification(){

            @Override
            public final int startState() {
                return ((CachedParams)CachedParams.this).model.startState;
            }

            @Override
            public final int endState() {
                return ((CachedParams)CachedParams.this).model.endState;
            }

            @Override
            public final int nStates() {
                return ((CachedParams)CachedParams.this).model.nStates;
            }

            @Override
            public final double logWeight(int prevState, int currentState, int xpos, int ypos, int deltaX, int deltaY) {
                int xid = CachedParams.this.model.charIdAt(top, xpos, deltaX);
                int yid = CachedParams.this.model.charIdAt(bot, ypos, deltaY);
                if (msa != null && xpos < topSeqLength && ypos < botSeqLength) {
                    boolean guessIsAligned;
                    boolean bl = guessIsAligned = deltaX == 1 && deltaY == 1;
                    if (guessIsAligned) {
                        if (directconnections[xpos] != -1 && directconnections[xpos] != ypos || inverseConnections[ypos] != -1 && inverseConnections[ypos] != xpos) {
                            return Double.NEGATIVE_INFINITY;
                        }
                    } else if (deltaX == 0) {
                        if (inverseConnections[ypos] >= 0) {
                            return Double.NEGATIVE_INFINITY;
                        }
                    } else if (deltaY == 0) {
                        if (directconnections[xpos] >= 0) {
                            return Double.NEGATIVE_INFINITY;
                        }
                    } else {
                        throw new RuntimeException();
                    }
                }
                int stss = extractor.extract(xpos, ypos);
                return CachedParams.this.getLogPr(prevState, currentState, stss, xid, yid);
            }
        };
    }

    public CachedParams(Model model) {
        this.model = model;
        this.cachedLogPrs = new double[model.nStates][model.nStates][model.stSuffStat.valuesIndexer.size()][model.enc.N() + 1][model.enc.N() + 1];
        ArrayUtils.deepFill(this.cachedLogPrs, Double.NEGATIVE_INFINITY);
    }

    private void set(Input in, Output out, double logpr) {
        if (logpr > 1.0E-4) {
            throw new RuntimeException("Invalid log pr:" + logpr);
        }
        this.cachedLogPrs[in.state1][out.state2][in.strTaxSuffStat][out.topSymbol][out.botSymbol] = logpr;
    }

    public double getLogPr(Input in, Output out) {
        return this.cachedLogPrs[in.state1][out.state2][in.strTaxSuffStat][out.topSymbol][out.botSymbol];
    }

    public final double getLogPr(int state1, int state2, int strTaxSuffStat, int topSymbol, int botSymbol) {
        return this.cachedLogPrs[state1][state2][strTaxSuffStat][topSymbol][botSymbol];
    }

    public CachedParams(Model model, MaxentClassifier<Input, Output, Object> maxentClassifier) {
        this(model);
        for (Input in : model.allInputs()) {
            SortedSet<Output> outs = maxentClassifier.getLabels(in);
            double[] logprs = maxentClassifier.logProb(in);
            int i = 0;
            for (Output out : outs) {
                this.set(in, out, logprs[i]);
                ++i;
            }
        }
    }

    public String toString() {
        StringBuilder result = new StringBuilder();
        List<Output> allOuts = this.model.allOutputs();
        for (Input in : this.model.allInputs()) {
            for (Output out : allOuts) {
                double val = this.getLogPr(in, out);
                if (Double.isInfinite(val)) continue;
                result.append("" + in + "\t" + out + "\t" + Math.exp(val) + "\n");
            }
        }
        return result.toString();
    }
}

