package de.unijena.bioinf.GibbsSampling.model;

import de.unijena.bioinf.ChemistryBase.algorithm.scoring.Scored;
import de.unijena.bioinf.GibbsSampling.model.Candidate;
import de.unijena.bioinf.jjobs.BasicMasterJJob;
import de.unijena.bioinf.jjobs.JJob;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Comparator;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:de/unijena/bioinf/GibbsSampling/model/GibbsMFCorrectionNetwork.class */
/**
 * Gibbs sampler over the candidate (molecular formula) graph of a set of compounds.
 *
 * <p>Every compound in the {@link Graph} holds exactly one <em>active</em> candidate at a time.
 * In each sampling round all compounds are visited in random order, and for each compound a new
 * active candidate is drawn with probability proportional to {@code exp(prior + candidateScore)}.
 * The prior of a candidate is the sum of the log edge weights ({@code graph.getLogWeight}) from
 * the currently active candidates connected to it. Candidates of fixed compounds never receive
 * such a prior, so they are drawn from their own candidate scores only.
 *
 * <p>After the burn-in, every {@value #DEFAULT_CORRELATION_STEPSIZE}-th round counts the drawn
 * candidates. The result of {@link #getChosenFormulas()} is the relative frequency of each
 * candidate within its compound.
 *
 * <p>Instances keep mutable sampling state and perform no synchronization.
 *
 * @param <C> the candidate type
 */
public class GibbsMFCorrectionNetwork<C extends Candidate<?>> extends BasicMasterJJob<Scored<C>[][]> {
    public static final boolean DEBUG = false;
    /** Distance, in rounds, between two rounds whose draws are counted after the burn-in. */
    public static final int DEFAULT_CORRELATION_STEPSIZE = 10;
    /** The initial assignment always uses each compound's highest-scoring candidate. */
    public static final boolean iniAssignMostLikely = true;

    protected Graph<C> graph;
    private int burnInRounds;
    private int currentRound;
    /** Per candidate: sum of log edge weights from the currently active candidates connected to it. */
    double[] priorProb;
    /** Per candidate: number of active neighbours counted during initialization (only written here). */
    private int[] activeEdgeCounter;
    /** Per compound: index of its active candidate, relative to the compound's first candidate. */
    int[] activeIdx;
    /** Per candidate: whether it is the active candidate of its compound. */
    boolean[] active;
    /** Per candidate: how often it was drawn in a counted round. */
    int[] overallAssignmentFreq;
    /** Per candidate: unnormalized sampling weight {@code exp(posterior - max posterior of its compound)}. */
    double[] posteriorProbs;
    /** Per compound: sum of {@link #posteriorProbs} over its candidates. */
    double[] posteriorProbSums;
    private final Random random;
    private final TIntHashSet fixedCompounds;
    private int maxSteps;
    private int burnIn;

    /**
     * Creates a sampler without fixed compounds.
     *
     * @param graph the candidate graph to sample from
     */
    public GibbsMFCorrectionNetwork(Graph<C> graph) {
        this(graph, null);
    }

    /**
     * Creates a sampler.
     *
     * @param graph          the candidate graph to sample from
     * @param fixedCompounds indices of compounds whose candidates receive no prior from their
     *                       neighbours; {@code null} means none
     */
    public GibbsMFCorrectionNetwork(Graph<C> graph, TIntHashSet fixedCompounds) {
        super(JJob.JobType.CPU);
        this.maxSteps = -1;
        this.burnIn = -1;
        this.graph = graph;
        this.fixedCompounds = fixedCompounds == null ? new TIntHashSet() : fixedCompounds;
        this.random = new Random();
        setActive();
    }

    /** Returns whether {@code compoundIdx} is contained in {@code fixed} (a {@code null} set contains nothing). */
    private static boolean isFixed(TIntHashSet fixed, int compoundIdx) {
        return fixed != null && fixed.contains(compoundIdx);
    }

    /**
     * (Re-)initializes all sampling state.
     *
     * <p>Each compound's highest-scoring candidate becomes active, priors are built from the active
     * candidates, posteriors are computed for every compound, and the assignment counters are
     * cleared.
     *
     * @throws IllegalStateException if a compound has no candidates
     */
    private void setActive() {
        final int size = graph.getSize();
        final int numberOfCompounds = graph.numberOfCompounds();
        priorProb = new double[size];
        activeEdgeCounter = new int[size];
        activeIdx = new int[numberOfCompounds];
        active = new boolean[size];

        // Candidates of all compounds are laid out consecutively; 'offset' is the first candidate
        // index of the current compound.
        int offset = 0;
        for (int compound = 0; compound < numberOfCompounds; compound++) {
            final Scored<C>[] candidates = graph.getPossibleFormulas(compound);
            if (candidates.length == 0) {
                throw new IllegalStateException("compound has no candidates: " + graph.getIds()[compound]);
            }
            // Fall back to the first candidate if no score beats -Infinity (e.g. all -Infinity or NaN);
            // previously this produced a negative array index.
            int best = 0;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < candidates.length; k++) {
                final double score = candidates[k].getScore();
                if (score > bestScore) {
                    bestScore = score;
                    best = k;
                }
            }
            activeIdx[compound] = best;
            active[offset + best] = true;
            offset += candidates.length;
        }

        // Build priors from the active neighbours. Candidates of fixed compounds keep a prior of 0.
        for (int candidate = 0; candidate < size; candidate++) {
            if (isFixed(fixedCompounds, graph.getPeakIdx(candidate))) {
                continue;
            }
            for (int neighbour : graph.getConnections(candidate)) {
                if (active[neighbour]) {
                    addActiveEdge(neighbour, candidate);
                    activeEdgeCounter[candidate]++;
                }
            }
        }

        posteriorProbs = new double[size];
        posteriorProbSums = new double[numberOfCompounds];
        for (int compound = 0; compound < numberOfCompounds; compound++) {
            updatePeak(compound);
        }
        overallAssignmentFreq = new int[size];
    }

    /** Combines a log prior and a log candidate score into a log posterior (up to a constant). */
    private double getPosteriorScore(double logPrior, double logScore) {
        return logPrior + logScore;
    }

    /**
     * Sets the number of rounds to run; must be called before the job is executed.
     *
     * @param maxSteps number of rounds after the burn-in
     * @param burnIn   number of initial rounds whose draws are never counted
     */
    public void setIterationSteps(int maxSteps, int burnIn) {
        this.maxSteps = maxSteps;
        this.burnIn = burnIn;
    }

    /**
     * Runs {@code burnIn + maxSteps} sampling rounds and returns the sampled candidate frequencies.
     *
     * @return see {@link #getChosenFormulas()}
     * @throws IllegalArgumentException if {@link #setIterationSteps(int, int)} was not called with
     *                                  non-negative values
     * @throws Exception                propagated from {@code checkForInterruption()}
     */
    protected Scored<C>[][] compute() throws Exception {
        if (maxSteps < 0 || burnIn < 0) {
            throw new IllegalArgumentException("number of iterations steps not set.");
        }
        final int totalRounds = burnIn + maxSteps;
        updateProgress(0L, totalRounds, 0L);
        setActive();
        burnInRounds = burnIn;
        final int numberOfCompounds = graph.numberOfCompounds();
        for (int round = 0; round < totalRounds; round++) {
            currentRound = round;
            for (int compound : getRandomOrdering(0, numberOfCompounds, random)) {
                iterationStep(compound);
            }
            checkForInterruption();
            updateProgress(0L, totalRounds, round + 1);
        }
        return getChosenFormulas();
    }

    /** Returns the compound ids of the underlying graph. */
    public String[] getIds() {
        return graph.getIds();
    }

    /** Returns all candidates of all compounds, as provided by the graph. */
    public Scored<C>[][] getAllPossibleMolecularFormulas() {
        return graph.getPossibleFormulas();
    }

    /**
     * Returns every edge {@code (i, j)} with {@code j <= i} as a two-element candidate array.
     * Presumably connections are symmetric, so each edge is reported once (TODO confirm).
     */
    public Scored<C>[][] getAllEdges() {
        final List<Scored<C>[]> edges = new ArrayList<>();
        for (int i = 0; i < graph.getSize(); i++) {
            final Scored<C> source = graph.getPossibleFormulas1D(i);
            for (int j : graph.getConnections(i)) {
                if (j <= i) {
                    @SuppressWarnings("unchecked") // generic array creation; elements are Scored<C>
                    final Scored<C>[] edge = new Scored[]{source, graph.getPossibleFormulas1D(j)};
                    edges.add(edge);
                }
            }
        }
        // The template must be a 2-D array: the elements are Scored[] (a 1-D template throws ArrayStoreException).
        @SuppressWarnings("unchecked")
        final Scored<C>[][] result = edges.toArray(new Scored[0][]);
        return result;
    }

    /** Returns the edge indices, as provided by the graph. */
    public int[][] getAllEdgesIndices() {
        return graph.getAllEdgesIndices();
    }

    /**
     * Normalizes {@code weights} within each compound and returns the candidates sorted in
     * descending order of their normalized weight. Currently unused within this class.
     */
    private Scored<C>[][] getFormulasSortedByScoring(double[] weights) {
        final int numberOfCompounds = graph.numberOfCompounds();
        @SuppressWarnings("unchecked")
        final Scored<C>[][] result = new Scored[numberOfCompounds][];
        for (int compound = 0; compound < numberOfCompounds; compound++) {
            final int[] bounds = graph.getPeakBoundaries(compound);
            final int left = bounds[0];
            final int right = bounds[1];
            double total = 0.0d;
            for (int k = left; k <= right; k++) {
                total += weights[k];
            }
            @SuppressWarnings("unchecked")
            final Scored<C>[] ranked = new Scored[right - left + 1];
            for (int k = left; k <= right; k++) {
                ranked[k - left] = new Scored<>(graph.getPossibleFormulas1D(k).getCandidate(), weights[k] / total);
            }
            Arrays.sort(ranked, Comparator.reverseOrder());
            result[compound] = ranked;
        }
        return result;
    }

    /**
     * Builds the per-compound result from assignment counts.
     *
     * <p>Fixed compounds report {@code exp(score)} of each candidate, unsorted. All other compounds
     * report each candidate's share of the compound's total count, sorted in descending order.
     *
     * @throws IllegalStateException if a non-fixed compound yields a NaN share (e.g. no counts)
     */
    private Scored<C>[][] getFormulasSortedByScoring(int[] frequencies) {
        final int numberOfCompounds = graph.numberOfCompounds();
        @SuppressWarnings("unchecked")
        final Scored<C>[][] result = new Scored[numberOfCompounds][];
        for (int compound = 0; compound < numberOfCompounds; compound++) {
            if (isFixed(fixedCompounds, compound)) {
                final Scored<C>[] candidates = graph.getPossibleFormulas(compound);
                @SuppressWarnings("unchecked")
                final Scored<C>[] scored = new Scored[candidates.length];
                for (int k = 0; k < scored.length; k++) {
                    scored[k] = new Scored<>(candidates[k].getCandidate(), Math.exp(candidates[k].getScore()));
                }
                result[compound] = scored;
                continue;
            }
            final int[] bounds = graph.getPeakBoundaries(compound);
            final int left = bounds[0];
            final int right = bounds[1];
            int total = 0;
            for (int k = left; k <= right; k++) {
                total += frequencies[k];
            }
            @SuppressWarnings("unchecked")
            final Scored<C>[] ranked = new Scored[right - left + 1];
            for (int k = left; k <= right; k++) {
                final double share = (double) frequencies[k] / total;
                if (Double.isNaN(share)) {
                    throw new IllegalStateException("ZODIAC Gibbs sampling produced NaN score for: " + graph.getIds()[compound]);
                }
                ranked[k - left] = new Scored<>(graph.getPossibleFormulas1D(k).getCandidate(), share);
            }
            Arrays.sort(ranked, Comparator.reverseOrder());
            result[compound] = ranked;
        }
        return result;
    }

    /** Returns the graph this sampler operates on. */
    public Graph<C> getGraph() {
        return graph;
    }

    /** Same as {@link #getChosenFormulas()}. */
    public Scored<C>[][] getChosenFormulasBySampling() {
        return getFormulasSortedByScoring(overallAssignmentFreq);
    }

    /**
     * Returns, per compound, its candidates with their relative sampling frequency. Fixed compounds
     * instead carry {@code exp(score)} of each candidate.
     */
    public Scored<C>[][] getChosenFormulas() {
        return getFormulasSortedByScoring(overallAssignmentFreq);
    }

    /**
     * Draws a new active candidate for one compound and updates the priors and posteriors of all
     * affected non-fixed neighbouring compounds.
     *
     * @param compoundIdx the compound to resample
     * @return {@code true} if the active candidate changed
     */
    private boolean iterationStep(int compoundIdx) {
        final int[] bounds = graph.getPeakBoundaries(compoundIdx);
        final int left = bounds[0];
        final int sampled = getRandomIdx(left, bounds[1], posteriorProbSums[compoundIdx], posteriorProbs);

        // Thinning: count only every DEFAULT_CORRELATION_STEPSIZE-th round after the burn-in.
        final int roundsAfterBurnIn = currentRound - burnInRounds;
        if (roundsAfterBurnIn > 0 && roundsAfterBurnIn % DEFAULT_CORRELATION_STEPSIZE == 0) {
            overallAssignmentFreq[sampled]++;
        }

        final int previousLocal = activeIdx[compoundIdx];
        final int previous = left + previousLocal;
        final int sampledLocal = sampled - left;
        if (previousLocal == sampledLocal) {
            return false;
        }

        // Move this compound's edge contributions from the old to the new active candidate,
        // remembering which (non-fixed) compounds need their posteriors recomputed.
        final BitSet touchedCompounds = new BitSet();
        for (int neighbour : graph.getConnections(previous)) {
            final int peak = graph.getPeakIdx(neighbour);
            if (!isFixed(fixedCompounds, peak)) {
                removeActiveEdge(previous, neighbour);
                touchedCompounds.set(peak);
            }
        }
        for (int neighbour : graph.getConnections(sampled)) {
            final int peak = graph.getPeakIdx(neighbour);
            if (!isFixed(fixedCompounds, peak)) {
                addActiveEdge(sampled, neighbour);
                touchedCompounds.set(peak);
            }
        }
        // Only non-fixed compounds were set above.
        for (int peak = touchedCompounds.nextSetBit(0); peak >= 0; peak = touchedCompounds.nextSetBit(peak + 1)) {
            updatePeak(peak);
        }

        activeIdx[compoundIdx] = sampledLocal;
        active[previous] = false;
        active[sampled] = true;
        return true;
    }

    /** Removes the log weight of edge {@code (from, to)} from the prior of candidate {@code to}. */
    private void removeActiveEdge(int from, int to) {
        priorProb[to] -= graph.getLogWeight(from, to);
    }

    /** Adds the log weight of edge {@code (from, to)} to the prior of candidate {@code to}. */
    private void addActiveEdge(int from, int to) {
        priorProb[to] += graph.getLogWeight(from, to);
    }

    /**
     * Draws an index in {@code [left, right]} with probability proportional to {@code weights[idx]}.
     *
     * @param left      first index (inclusive)
     * @param right     last index (inclusive)
     * @param weightSum sum of {@code weights[left..right]}
     * @param weights   unnormalized non-negative weights
     * @return the drawn index
     * @throws RuntimeException if the cumulative weight never reaches the random threshold
     */
    private int getRandomIdx(int left, int right, double weightSum, double[] weights) {
        final double threshold = random.nextDouble() * weightSum;
        double cumulative = 0.0d;
        // Bounded loop: the previous do/while could read past 'right' before detecting the error.
        for (int idx = left; idx <= right; idx++) {
            cumulative += weights[idx];
            // Written as !(a < b), not (a >= b), so a NaN threshold still selects 'left' as before.
            if (!(cumulative < threshold)) {
                return idx;
            }
        }
        throw new RuntimeException("sampling by probability produced error");
    }

    /**
     * Recomputes the sampling weights of one compound from the current priors and candidate scores.
     * Subtracting the compound's maximum log posterior before {@code Math.exp} keeps the largest
     * weight at 1.
     */
    private void updatePeak(int compoundIdx) {
        final int[] bounds = graph.getPeakBoundaries(compoundIdx);
        final int left = bounds[0];
        final int right = bounds[1];
        double max = Double.NEGATIVE_INFINITY;
        for (int k = left; k <= right; k++) {
            posteriorProbs[k] = getPosteriorScore(priorProb[k], graph.getCandidateScore(k));
            if (posteriorProbs[k] > max) {
                max = posteriorProbs[k];
            }
        }
        double sum = 0.0d;
        for (int k = left; k <= right; k++) {
            posteriorProbs[k] = Math.exp(posteriorProbs[k] - max);
            sum += posteriorProbs[k];
        }
        assert sum > 0.0d : "sampling weights of compound " + compoundIdx + " do not sum to a positive value";
        posteriorProbSums[compoundIdx] = sum;
    }

    /** Returns a uniformly random permutation of {@code 0 .. n-1}. */
    public static int[] getRandomOrdering(int n) {
        return getRandomOrdering(0, n);
    }

    /** Returns a uniformly random permutation of {@code start .. end-1}, using a fresh {@link Random}. */
    public static int[] getRandomOrdering(int start, int end) {
        return getRandomOrdering(start, end, new Random());
    }

    /**
     * Returns a uniformly random permutation of {@code start .. end-1} (Fisher–Yates shuffle).
     *
     * @param start  first value (inclusive)
     * @param end    last value (exclusive)
     * @param random source of randomness
     * @throws IllegalArgumentException if {@code end < start}
     */
    public static int[] getRandomOrdering(int start, int end, Random random) {
        if (end < start) {
            throw new IllegalArgumentException("end (" + end + ") must not be smaller than start (" + start + ")");
        }
        final int[] order = new int[end - start];
        for (int k = 0; k < order.length; k++) {
            order[k] = start + k;
        }
        for (int k = order.length - 1; k > 0; k--) {
            final int j = random.nextInt(k + 1);
            final int tmp = order[k];
            order[k] = order[j];
            order[j] = tmp;
        }
        return order;
    }

    /**
     * Computes a candidate distribution for one compound without sampling.
     *
     * <p>Each candidate's log score is its candidate score plus, for every connected candidate,
     * {@code logWeight * exp(neighbourCandidateScore)}. The scores are then normalized with a
     * max-shifted softmax.
     *
     * @param graph    the candidate graph
     * @param peakIdx  the compound index
     * @param <C>      the candidate type
     * @return the compound's candidates with normalized scores, sorted in descending order
     */
    public static <C extends Candidate<?>> Scored<C>[] computeFromSnapshot(Graph<C> graph, int peakIdx) {
        final int left = graph.getPeakLeftBoundary(peakIdx);
        final int right = graph.getPeakRightBoundary(peakIdx);
        final int n = right - left + 1;
        final double[] values = new double[n];
        for (int candidate = left; candidate <= right; candidate++) {
            double score = graph.getCandidateScore(candidate);
            for (int neighbour : graph.getConnections(candidate)) {
                score += graph.getLogWeight(neighbour, candidate) * Math.exp(graph.getCandidateScore(neighbour));
            }
            values[candidate - left] = score;
        }
        double max = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            if (max < v) {
                max = v;
            }
        }
        double sum = 0.0d;
        for (int k = 0; k < n; k++) {
            values[k] = Math.exp(values[k] - max);
            sum += values[k];
        }
        assert sum > 0.0d : "snapshot weights of compound " + peakIdx + " do not sum to a positive value";
        @SuppressWarnings("unchecked")
        final Scored<C>[] result = new Scored[n];
        for (int candidate = left; candidate <= right; candidate++) {
            result[candidate - left] = new Scored<>(graph.getPossibleFormulas1D(candidate).getCandidate(), values[candidate - left] / sum);
        }
        Arrays.sort(result, Comparator.reverseOrder());
        return result;
    }
}
