package de.unijena.bioinf.fingerid.pvalues;

import de.unijena.bioinf.graphUtils.tree.Tree;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;

/**
 * Tree-shaped dynamic program over a {@link FingerprintTree}.
 *
 * <p>Every tree node owns a {@link DPTableUnit} whose {@code forOne}/{@code forZero} arrays hold
 * {@link Probability} mass indexed by a discrete cost accumulated over the node's subtree, split by
 * whether the node's variable is 1 or 0. Entries whose cost exceeds the threshold are dropped, and
 * the root tables are finally summed over all costs {@code <= threshold} (inclusive).
 *
 * <p>Two cost models are supported:
 * <ul>
 *   <li>{@link #computeUnitScores}: every variable whose value disagrees with the observed bit adds
 *       cost 1.</li>
 *   <li>{@link #computePlattScores}: every variable adds a per-bit, per-value score that is
 *       discretized by a resolution factor.</li>
 * </ul>
 *
 * <p>NOTE(review): the DP assumes {@code FingerprintTree.nodes} lists every child before its parent,
 * because tables are filled in that order and a parent consumes (and clears) its children's tables.
 * TODO confirm against {@code FingerprintTree}. Instances are not safe for concurrent use: each
 * computation overwrites and clears the shared per-node tables.
 */
public class TreeDP {
    /** Per-node tables keyed by the node variable's {@code to} index. */
    final HashMap<Integer, DPTableUnit> tableMap = new HashMap<>();
    /** Per-node tables in {@code FingerprintTree.nodes} order; this is the DP evaluation order. */
    final ArrayList<DPTableUnit> tables = new ArrayList<>();
    final FingerprintTree tree;

    /**
     * DP table for a single tree node.
     *
     * <p>{@code forOne[c]} / {@code forZero[c]} hold the mass with accumulated cost {@code c} for the
     * case that the node's variable is 1 / 0. {@code capaOne} / {@code capaZero} record the array
     * lengths set by the last {@link #reserve} call. They are kept after {@link #clear()}.
     */
    public static final class DPTableUnit {
        private final Tree<FPVariable> node;
        private final FPVariable variable;
        private Probability[] forZero;
        private Probability[] forOne;
        private int capaOne;
        private int capaZero;

        public DPTableUnit(Tree<FPVariable> tree) {
            this.node = tree;
            this.variable = (FPVariable) tree.getLabel();
        }

        /** Allocates both arrays with the same length, filled with {@code Probability.ZERO}. */
        void reserve(int capacity) {
            reserve(capacity, capacity);
        }

        /** Allocates {@code forZero} and {@code forOne} with the given lengths, filled with {@code Probability.ZERO}. */
        void reserve(int zeroCapacity, int oneCapacity) {
            this.capaOne = oneCapacity;
            this.capaZero = zeroCapacity;
            this.forZero = new Probability[zeroCapacity];
            Arrays.fill(this.forZero, Probability.ZERO);
            this.forOne = new Probability[oneCapacity];
            Arrays.fill(this.forOne, Probability.ZERO);
        }

        /** Releases both arrays so the memory can be reclaimed once a parent has consumed them. */
        void clear() {
            this.forOne = null;
            this.forZero = null;
        }

        /** Copies this table into the prefix of {@code target}, whose arrays must be at least as long. */
        void copyTo(DPTableUnit target) {
            System.arraycopy(this.forOne, 0, target.forOne, 0, this.forOne.length);
            System.arraycopy(this.forZero, 0, target.forZero, 0, this.forZero.length);
        }
    }

    /** Creates one (still empty) DP table per node of {@code fingerprintTree}. */
    public TreeDP(FingerprintTree fingerprintTree) {
        this.tree = fingerprintTree;
        for (Tree<FPVariable> node : fingerprintTree.nodes) {
            DPTableUnit unit = new DPTableUnit(node);
            this.tableMap.put(unit.variable.to, unit);
            this.tables.add(unit);
        }
    }

    /**
     * Sanity check used by assertions: the summed mass of {@code probabilities} has a non-positive
     * exponent ({@code getExp() <= 0}). Presumably this means "total <= 1"; confirm against
     * {@code Probability}.
     */
    private static boolean checkProbability(Probability[] probabilities) {
        Probability total = Probability.ZERO;
        for (Probability p : probabilities) {
            total = total.add(p);
        }
        return total.getExp() <= 0;
    }

    /** Returns the DP table of {@code node}, looked up by its variable's {@code to} index. */
    private DPTableUnit tableOf(Tree<?> node) {
        return this.tableMap.get(((FPVariable) node.getLabel()).to);
    }

    /** Returns the DP tables of all children of {@code node}, in child order. */
    private DPTableUnit[] childTables(Tree<FPVariable> node) {
        DPTableUnit[] result = new DPTableUnit[node.degree()];
        int k = 0;
        for (Object child : node.children()) {
            result[k++] = tableOf((Tree<?>) child);
        }
        return result;
    }

    /** Multiplies every score by {@code resolution} and rounds it to the nearest integer cost. */
    private static int[] discretize(double[] scores, int resolution) {
        int[] result = new int[scores.length];
        for (int k = 0; k < scores.length; k++) {
            result[k] = (int) Math.round(scores[k] * resolution);
        }
        return result;
    }

    /**
     * Runs the DP with per-bit scores as costs.
     *
     * @param observed   observed bits; not read by the Platt cost model (kept for interface compatibility)
     * @param resolution factor used to discretize scores and threshold into integer costs
     * @param zeroScores per-variable score charged when the variable is 0 (indexed by variable index)
     * @param oneScores  per-variable score charged when the variable is 1 (indexed by variable index)
     * @param threshold  maximal (undiscretized) score; discretized as {@code ceil(threshold * resolution)}
     * @return the exponent ({@code Probability.getExp()}) of the root mass with cost {@code <=} the threshold
     */
    public long computePlattScores(boolean[] observed, int resolution, double[] zeroScores, double[] oneScores, double threshold) {
        // NOTE(review): negative scores would produce negative table indices; callers presumably pass
        // non-negative scores. Confirm.
        int[] zeroCosts = discretize(zeroScores, resolution);
        int[] oneCosts = discretize(oneScores, resolution);
        int maxCost = (int) Math.ceil(threshold * resolution);
        for (DPTableUnit unit : this.tables) {
            int degree = unit.node.degree();
            if (degree == 0) {
                leafPlatt(unit, zeroCosts, oneCosts);
            } else if (degree == 1) {
                innerVertexPlatt(unit, childTables(unit.node)[0], zeroCosts, oneCosts, maxCost);
            } else {
                multipleChildrenVertexPlatt(unit, childTables(unit.node), zeroCosts, oneCosts, maxCost);
            }
        }
        return sumRootTable(tableOf(this.tree.root), maxCost).getExp();
    }

    /**
     * Runs the DP where each variable disagreeing with its observed bit costs 1.
     *
     * @param observed  observed bit per variable index
     * @param threshold maximal number of disagreeing variables (inclusive)
     * @return the exponent ({@code Probability.getExp()}) of the root mass with cost {@code <= threshold}
     */
    public long computeUnitScores(boolean[] observed, int threshold) {
        for (DPTableUnit unit : this.tables) {
            int degree = unit.node.degree();
            if (degree == 0) {
                leafUnit(unit, observed);
            } else if (degree == 1) {
                innerVertexUnit(unit, childTables(unit.node)[0], observed, threshold);
            } else {
                multipleChildrenVertexUnit(unit, childTables(unit.node), observed, threshold);
            }
        }
        return sumRootTable(tableOf(this.tree.root), threshold).getExp();
    }

    /**
     * Sums {@code forOne[c] + forZero[c]} for all costs {@code c <= threshold} that exist in the table.
     * The array lengths may be shorter than {@code threshold + 1} (e.g. a single-leaf tree), and in the
     * Platt model the two arrays may differ in length.
     */
    private static Probability sumRootTable(DPTableUnit root, int threshold) {
        Probability total = Probability.ZERO;
        int last = Math.min(threshold, Math.max(root.forOne.length, root.forZero.length) - 1);
        for (int c = 0; c <= last; c++) {
            // Same interleaved forOne/forZero addition order as before, minus out-of-range reads.
            if (c < root.forOne.length) {
                total = total.add(root.forOne[c]);
            }
            if (c < root.forZero.length) {
                total = total.add(root.forZero[c]);
            }
        }
        return total;
    }

    /**
     * Unit-cost table of a node with several children: each child is lifted through its edge via
     * {@link #innerVertexUnit}, then the lifted tables are convolved over cost (truncated at
     * {@code threshold}). All children's tables are cleared afterwards.
     */
    private void multipleChildrenVertexUnit(DPTableUnit target, DPTableUnit[] children, boolean[] observed, int threshold) {
        DPTableUnit[] lifted = new DPTableUnit[children.length];
        int totalCapacity = 0;
        for (int c = 0; c < children.length; c++) {
            lifted[c] = new DPTableUnit(target.node);
            innerVertexUnit(lifted[c], children[c], observed, threshold); // also clears children[c]
            totalCapacity += lifted[c].capaOne;
            assert lifted[c].capaOne == lifted[c].capaZero;
        }
        // NOTE(review): each lifted table already contains this node's own disagreement shift, so the
        // convolution below applies that shift once per child. Confirm this is intended.
        int size = Math.min(totalCapacity, threshold + 1);
        DPTableUnit accumulated = new DPTableUnit(target.node);
        accumulated.reserve(size);
        target.reserve(size);
        lifted[0].copyTo(accumulated);
        for (int c = 1; c < lifted.length; c++) {
            DPTableUnit next = lifted[c];
            for (int cost = 0; cost < size; cost++) {
                int maxPart = Math.min(cost, next.capaOne - 1);
                for (int part = 0; part <= maxPart; part++) {
                    target.forOne[cost] = target.forOne[cost].add(accumulated.forOne[cost - part].multiply(next.forOne[part]));
                    target.forZero[cost] = target.forZero[cost].add(accumulated.forZero[cost - part].multiply(next.forZero[part]));
                }
            }
            // target is scratch space for this round; move the result into the accumulator.
            target.copyTo(accumulated);
            Arrays.fill(target.forZero, Probability.ZERO);
            Arrays.fill(target.forOne, Probability.ZERO);
        }
        accumulated.copyTo(target);
        for (DPTableUnit unit : lifted) {
            unit.clear();
        }
        assert checkProbability(target.forOne);
        assert checkProbability(target.forZero);
    }

    /**
     * Unit-cost table of {@code target} derived from a single {@code child} table through the child's
     * edge variable (edge probabilities {@code PII}, {@code PoI}, {@code PIo}, {@code Poo}).
     *
     * <p>The parent's disagreement with {@code observed[edge.from]} shifts the cost by one: the 1-branch
     * shifts when the observed bit is false, the 0-branch when it is true. Clears {@code child}.
     */
    private void innerVertexUnit(DPTableUnit target, DPTableUnit child, boolean[] observed, int threshold) {
        assert child.capaOne == child.capaZero;
        FPVariable edge = child.variable;
        boolean parentObserved = observed[edge.from];
        int oneShift = parentObserved ? 0 : 1;
        int zeroShift = parentObserved ? 1 : 0;
        target.reserve(Math.min(threshold, child.capaOne) + 1);
        // Child costs 0..threshold (inclusive) can still contribute after the parent's shift. The
        // threshold is inclusive, as in innerVertexPlatt and sumRootTable. Previously the entry at
        // exactly `threshold` was dropped.
        int readable = Math.min(threshold, child.capaOne - 1) + 1;
        for (int k = 0; k < readable; k++) {
            if (k + oneShift <= threshold) {
                target.forOne[k + oneShift] = edge.PII.multiply(child.forOne[k]).add(edge.PoI.multiply(child.forZero[k]));
            }
            if (k + zeroShift <= threshold) {
                target.forZero[k + zeroShift] = edge.PIo.multiply(child.forOne[k]).add(edge.Poo.multiply(child.forZero[k]));
            }
        }
        assert checkProbability(target.forOne);
        assert checkProbability(target.forZero);
        child.clear();
    }

    /**
     * Unit-cost leaf table: mass {@code variable.I} on the 1-branch and {@code variable.o} on the
     * 0-branch, at cost 0 when the value agrees with {@code observed[variable.to]} and cost 1 otherwise.
     */
    private void leafUnit(DPTableUnit leaf, boolean[] observed) {
        FPVariable variable = leaf.variable;
        boolean bit = observed[variable.to];
        leaf.reserve(2);
        leaf.forOne[bit ? 0 : 1] = variable.I;
        leaf.forZero[bit ? 1 : 0] = variable.o;
        assert checkProbability(leaf.forOne);
        assert checkProbability(leaf.forZero);
    }

    /**
     * Platt-cost leaf table: mass {@code variable.I} at cost {@code oneCosts[to]} and
     * {@code variable.o} at cost {@code zeroCosts[to]}.
     */
    private void leafPlatt(DPTableUnit leaf, int[] zeroCosts, int[] oneCosts) {
        FPVariable variable = leaf.variable;
        int index = variable.to;
        leaf.reserve(zeroCosts[index] + 1, oneCosts[index] + 1);
        leaf.forOne[oneCosts[index]] = variable.I;
        leaf.forZero[zeroCosts[index]] = variable.o;
        assert checkProbability(leaf.forOne);
        assert checkProbability(leaf.forZero);
    }

    /**
     * Platt-cost table of {@code target} derived from a single {@code child} table through the child's
     * edge variable. The parent's value shifts the cost by {@code oneCosts[from]} / {@code zeroCosts[from]};
     * costs above {@code threshold} are dropped. Clears {@code child}.
     */
    private void innerVertexPlatt(DPTableUnit target, DPTableUnit child, int[] zeroCosts, int[] oneCosts, int threshold) {
        FPVariable edge = child.variable;
        int parent = edge.from;
        int zeroShift = zeroCosts[parent];
        int oneShift = oneCosts[parent];
        int childMaxCost = Math.max(child.capaOne - 1, child.capaZero - 1);
        target.reserve(Math.min(threshold, childMaxCost + zeroShift) + 1, Math.min(threshold, childMaxCost + oneShift) + 1);
        // Child's 0-branch first, then its 1-branch (keeps the original addition order).
        int zeroEntries = Math.min(threshold + 1, child.capaZero);
        for (int k = 0; k < zeroEntries; k++) {
            if (k + oneShift <= threshold) {
                target.forOne[k + oneShift] = target.forOne[k + oneShift].add(edge.PoI.multiply(child.forZero[k]));
            }
            if (k + zeroShift <= threshold) {
                target.forZero[k + zeroShift] = target.forZero[k + zeroShift].add(edge.Poo.multiply(child.forZero[k]));
            }
        }
        int oneEntries = Math.min(threshold + 1, child.capaOne);
        for (int k = 0; k < oneEntries; k++) {
            if (k + oneShift <= threshold) {
                target.forOne[k + oneShift] = target.forOne[k + oneShift].add(edge.PII.multiply(child.forOne[k]));
            }
            if (k + zeroShift <= threshold) {
                target.forZero[k + zeroShift] = target.forZero[k + zeroShift].add(edge.PIo.multiply(child.forOne[k]));
            }
        }
        assert checkProbability(target.forOne);
        assert checkProbability(target.forZero);
        child.clear();
    }

    /**
     * Platt-cost table of a node with several children: each child is lifted through its edge via
     * {@link #innerVertexPlatt}, then the lifted tables are convolved over cost (truncated at
     * {@code threshold}). All children's tables are cleared afterwards.
     */
    private void multipleChildrenVertexPlatt(DPTableUnit target, DPTableUnit[] children, int[] zeroCosts, int[] oneCosts, int threshold) {
        DPTableUnit[] lifted = new DPTableUnit[children.length];
        int oneCapacity = 0;
        int zeroCapacity = 0;
        for (int c = 0; c < children.length; c++) {
            lifted[c] = new DPTableUnit(target.node);
            innerVertexPlatt(lifted[c], children[c], zeroCosts, oneCosts, threshold); // also clears children[c]
            oneCapacity += lifted[c].capaOne;
            zeroCapacity += lifted[c].capaZero;
        }
        // NOTE(review): each lifted table already contains this node's own score shift, so the
        // convolution below applies that shift once per child. Confirm this is intended.
        int zeroSize = Math.min(threshold + 1, zeroCapacity);
        int oneSize = Math.min(threshold + 1, oneCapacity);
        DPTableUnit accumulated = new DPTableUnit(target.node);
        accumulated.reserve(zeroSize, oneSize);
        target.reserve(zeroSize, oneSize);
        lifted[0].copyTo(accumulated);
        for (int c = 1; c < lifted.length; c++) {
            addTruncatedConvolution(accumulated.forOne, lifted[c].forOne, target.forOne, threshold);
            addTruncatedConvolution(accumulated.forZero, lifted[c].forZero, target.forZero, threshold);
            // target is scratch space for this round; move the result into the accumulator.
            target.copyTo(accumulated);
            Arrays.fill(target.forZero, Probability.ZERO);
            Arrays.fill(target.forOne, Probability.ZERO);
        }
        accumulated.copyTo(target);
        for (DPTableUnit unit : lifted) {
            unit.clear();
        }
        assert checkProbability(target.forOne);
        assert checkProbability(target.forZero);
    }

    /**
     * Adds {@code left[a] * right[b]} into {@code out[a + b]} for all {@code a + b <= threshold}, skipping
     * zero entries of {@code left}. The skip also keeps writes inside {@code out}, because the non-zero
     * entries of {@code left} never extend past the accumulated capacity.
     */
    private static void addTruncatedConvolution(Probability[] left, Probability[] right, Probability[] out, int threshold) {
        for (int a = 0; a < left.length; a++) {
            if (left[a].isZeroProbability()) {
                continue;
            }
            for (int b = 0; b < right.length && a + b <= threshold; b++) {
                out[a + b] = out[a + b].add(left[a].multiply(right[b]));
            }
        }
    }
}
