/*
 * Decompiled with CFR 0.152.
 */
package ev.multi;

import goblin.CognateId;
import goblin.Taxon;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import ma.GreedyDecoder;
import ma.MSAPoset;
import nuts.util.CollUtils;
import nuts.util.Counter;
import nuts.util.CounterMap;

public class Segmenter {
    private final List<SegmentBoundary> segmentBounds;
    private final Map<Taxon, String> sequences;
    private final Map<CognateId, Map<Taxon, String>> segmentedSequences = CollUtils.map();
    private final Map<Taxon, Map<Integer, Integer>> inverseMaps = CollUtils.map();
    private final CognateId baseCognateId;

    public Map<Taxon, String> sequences() {
        return Collections.unmodifiableMap(this.sequences);
    }

    public Segmenter(CognateId id, Map<Taxon, String> sequences, List<SegmentBoundary> segmentBounds) {
        this.baseCognateId = id;
        this.sequences = sequences;
        this.segmentBounds = CollUtils.list(segmentBounds);
        for (Taxon t : sequences.keySet()) {
            if ((Integer)CollUtils.first(segmentBounds).points.get(t) == 0 && ((Integer)CollUtils.last(segmentBounds).points.get(t)).intValue() == sequences.get(t).length()) continue;
            throw new RuntimeException();
        }
        this.init();
    }

    private void init() {
        for (Taxon t : this.sequences.keySet()) {
            this.inverseMaps.put(t, new HashMap());
        }
        for (int i = 0; i < this.segmentBounds.size() - 1; ++i) {
            SegmentBoundary left = this.segmentBounds.get(i);
            SegmentBoundary right = this.segmentBounds.get(i + 1);
            SegmentedCognateId currentId = new SegmentedCognateId(this.baseCognateId, i);
            HashMap currentSegmentedSequence = CollUtils.map();
            for (Taxon t : this.sequences.keySet()) {
                currentSegmentedSequence.put(t, this.sequences.get(t).substring((Integer)left.points.get(t), (Integer)right.points.get(t)));
            }
            this.segmentedSequences.put(currentId, currentSegmentedSequence);
            for (Taxon t : this.sequences.keySet()) {
                Map<Integer, Integer> current = this.inverseMaps.get(t);
                int rightPt = (Integer)right.points.get(t);
                for (int p = ((Integer)left.points.get(t)).intValue(); p < rightPt; ++p) {
                    current.put(p, i);
                }
            }
        }
    }

    public Counter<GreedyDecoder.Edge> desegment(CounterMap<SegmentedCognateId, GreedyDecoder.Edge> counters) {
        Counter<GreedyDecoder.Edge> result = new Counter<GreedyDecoder.Edge>();
        for (SegmentedCognateId id : counters.keySet()) {
            Counter<GreedyDecoder.Edge> current = counters.getCounter(id);
            for (GreedyDecoder.Edge e : current.keySet()) {
                GreedyDecoder.Edge transformed = this.desegment(e, id);
                if (result.getCount(transformed) > 0.0) {
                    throw new RuntimeException();
                }
                result.setCount(transformed, current.getCount(e));
            }
        }
        return result;
    }

    public MSAPoset desegment(Map<CognateId, MSAPoset> segments) {
        MSAPoset msa = new MSAPoset(this.sequences);
        for (CognateId id : segments.keySet()) {
            SegmentedCognateId sci = (SegmentedCognateId)id;
            for (MSAPoset.Column c : segments.get(id).columns()) {
                for (GreedyDecoder.Edge e : c.spanningEdges()) {
                    if (msa.tryAdding(this.desegment(e, sci))) continue;
                    throw new RuntimeException();
                }
            }
        }
        return msa;
    }

    public GreedyDecoder.Edge desegment(GreedyDecoder.Edge e, SegmentedCognateId id) {
        SegmentBoundary leftBound = this.segmentBounds.get(id.index);
        return new GreedyDecoder.Edge(e.index1() + (Integer)leftBound.points.get(e.lang1()), e.index2() + (Integer)leftBound.points.get(e.lang2()), e.lang1(), e.lang2());
    }

    public Map<CognateId, MSAPoset> segmentMSA(MSAPoset fullMSA) {
        HashMap<CognateId, MSAPoset> result = CollUtils.map();
        for (CognateId id : this.segmentedSequences.keySet()) {
            result.put(id, new MSAPoset(this.segmentedSequences.get(id)));
        }
        for (MSAPoset.Column c : fullMSA.columns()) {
            int segmentIndex = -1;
            for (Taxon aTaxon : c.getPoints().keySet()) {
                if (segmentIndex == -1) {
                    segmentIndex = this.inverseMaps.get(aTaxon).get(c.getPoints().get(aTaxon));
                    continue;
                }
                if (segmentIndex == this.inverseMaps.get(aTaxon).get(c.getPoints().get(aTaxon))) continue;
                throw new RuntimeException("MSA violates segmentation:" + c);
            }
            SegmentBoundary leftBoundary = this.segmentBounds.get(segmentIndex);
            MSAPoset currentMSA = (MSAPoset)result.get(new SegmentedCognateId(this.baseCognateId, segmentIndex));
            if (currentMSA.tryAdding(this.transform(c.getPoints(), leftBoundary))) continue;
            throw new RuntimeException();
        }
        return result;
    }

    private Map<Taxon, Integer> transform(Map<Taxon, Integer> points, SegmentBoundary leftBoundary) {
        HashMap<Taxon, Integer> result = CollUtils.map();
        for (Taxon t : points.keySet()) {
            result.put(t, points.get(t) - (Integer)leftBoundary.points.get(t));
        }
        return result;
    }

    public static List<SegmentBoundary> chunk(MSAPoset msa, int targetSize) {
        List<MSAPoset.Column> linearized = msa.linearizedColumns();
        ArrayList<SegmentBoundary> result = CollUtils.list();
        result.add(Segmenter.first(msa.sequences()));
        int nConsumed = 0;
        for (MSAPoset.Column c : linearized) {
            if (c.getPoints().size() != msa.sequences().size() || ++nConsumed != targetSize) continue;
            nConsumed = 0;
            result.add(new SegmentBoundary(c.getPoints()));
        }
        result.add(Segmenter.last(msa.sequences()));
        return result;
    }

    public static SegmentBoundary first(Map<Taxon, String> seq) {
        HashMap<Taxon, Integer> result = CollUtils.map();
        for (Taxon t : seq.keySet()) {
            result.put(t, 0);
        }
        return new SegmentBoundary(result);
    }

    public static SegmentBoundary last(Map<Taxon, String> seq) {
        HashMap<Taxon, Integer> result = CollUtils.map();
        for (Taxon t : seq.keySet()) {
            result.put(t, seq.get(t).length());
        }
        return new SegmentBoundary(result);
    }

    public static void main(String[] args) {
        MSAPoset msa = MSAPoset.parseFASTA(new File("/Users/bouchard/w/evolvere/data/bench1.0/bali2dna/ref/1aab_ref1"));
        System.out.println(msa);
        List<SegmentBoundary> bounds = Segmenter.chunk(msa, 50);
        Segmenter seg = new Segmenter(new CognateId("id"), msa.sequences(), bounds);
        Map<CognateId, MSAPoset> mapp = seg.segmentMSA(msa);
        for (CognateId id : mapp.keySet()) {
            System.out.println("" + id + "\n" + mapp.get(id));
        }
    }

    public static class SegmentedCognateId
    extends CognateId {
        public final int index;
        public final CognateId originalIdStr;
        private static final long serialVersionUID = 1L;

        public SegmentedCognateId(CognateId base, int index) {
            super(base.toString() + "-seg" + index);
            this.index = index;
            this.originalIdStr = base;
        }
    }

    public static class SegmentBoundary {
        private final Map<Taxon, Integer> points;

        public SegmentBoundary(Map<Taxon, Integer> points) {
            this.points = points;
        }
    }
}

