package de.unijena.bioinf.lcms.adducts.assignment;

import de.unijena.bioinf.ChemistryBase.chem.MolecularFormula;
import de.unijena.bioinf.ChemistryBase.chem.PrecursorIonType;
import de.unijena.bioinf.lcms.adducts.AdductEdge;
import de.unijena.bioinf.lcms.adducts.AdductNode;
import de.unijena.bioinf.lcms.adducts.AdductRelationship;
import de.unijena.bioinf.lcms.adducts.IonType;
import de.unijena.bioinf.lcms.adducts.KnownMassDelta;
import de.unijena.bioinf.lcms.adducts.LossRelationship;
import de.unijena.bioinf.lcms.adducts.MultimereRelationship;
import de.unijena.bioinf.lcms.adducts.assignment.AdductBeamSearch;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:de/unijena/bioinf/lcms/adducts/assignment/OptimalAssignmentViaBeamSearch.class */
public class OptimalAssignmentViaBeamSearch implements SubnetworkResolver {

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:de/unijena/bioinf/lcms/adducts/assignment/OptimalAssignmentViaBeamSearch$CompatibilityEdge.class */
    /**
     * Scored, directed connection between two {@link CompatibilityNode}s that is only valid for
     * one particular pair of ion type indices (one index per end point).
     */
    public static class CompatibilityEdge {
        /** Score taken from the underlying adduct edge when this object is created. */
        private final double score;
        /** Node this edge starts at. */
        private final CompatibilityNode from;
        /** Node this edge points to. */
        private final CompatibilityNode to;
        /** Edge of the original adduct graph this compatibility edge was derived from. */
        private final AdductEdge underlyingEdge;
        /** Zero-based index into {@code from.ionTypes}. */
        private final int fromType;
        /** Zero-based index into {@code to.ionTypes}. */
        private final int toType;

        /**
         * @param compatibilityNode  start node
         * @param compatibilityNode2 end node
         * @param i                  ion type index on the start node
         * @param i2                 ion type index on the end node
         * @param adductEdge         underlying edge; its score becomes the score of this edge
         */
        public CompatibilityEdge(CompatibilityNode compatibilityNode, CompatibilityNode compatibilityNode2, int i, int i2, AdductEdge adductEdge) {
            this.underlyingEdge = adductEdge;
            this.from = compatibilityNode;
            this.fromType = i;
            this.to = compatibilityNode2;
            this.toType = i2;
            this.score = adductEdge.getScore();
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:de/unijena/bioinf/lcms/adducts/assignment/OptimalAssignmentViaBeamSearch$CompatibilityNode.class */
    /**
     * Group of adduct nodes that is treated as a single unit during assignment, together with
     * the ion types it may take and the outgoing edges available for each of those types.
     */
    public static class CompatibilityNode {
        /** Adduct nodes merged into this compatibility node. */
        AdductNode[] subnodes;
        /** Candidate ion types; filled in after construction. */
        PrecursorIonType[] ionTypes;
        /** Outgoing edges, bucketed by the index of the ion type (parallel to {@link #ionTypes}). */
        List<CompatibilityEdge>[] edgesPerIonType;
        /** Position of this node in the compatibility node array. */
        int index;

        /**
         * @param i             position of this node in the compatibility node array
         * @param adductNodeArr adduct nodes represented by this node
         */
        public CompatibilityNode(int i, AdductNode[] adductNodeArr) {
            this.subnodes = adductNodeArr;
            this.index = i;
        }

        /** Lists the contained sub nodes followed by the candidate ion types. */
        @Override
        public String toString() {
            final StringBuilder text = new StringBuilder();
            text.append(Arrays.toString(this.subnodes));
            text.append(" with modes: ");
            text.append(Arrays.toString(this.ionTypes));
            return text.toString();
        }
    }

    /**
     * Builds the compatibility graph for the given sub network and runs the beam search on it.
     *
     * @param adductNodeArr nodes of the sub network to resolve
     * @param i             value forwarded unchanged to the beam search stage
     * @return one assignment per input node, or {@code null} when no compatibility graph could
     *         be built or the beam search produced no solution
     */
    @Override // de.unijena.bioinf.lcms.adducts.assignment.SubnetworkResolver
    public AdductAssignment[] resolve(AdductNode[] adductNodeArr, int i) {
        final CompatibilityNode[] compatibilityGraph = transformGraphIntoCompatibilityGraph(adductNodeArr);
        return compatibilityGraph == null ? null : beamSearch(adductNodeArr, compatibilityGraph, i);
    }

    /**
     * Turns one beam search solution into a concrete ion type per adduct node.
     *
     * <p>Steps: (1) every sub node of a compatibility node receives the ion type selected for
     * that node, or {@code PrecursorIonType.unknown(i)} when the selection is not positive;
     * (2) multimere factors and in-source formulas are propagated along the graph; (3) if the
     * smallest multimere factor is not 1 or some in-source formula is not all-positive-or-zero,
     * all ion types are rescaled by the inverse of that minimum and shifted by a common
     * compensation formula.
     *
     * @param adductNodeArr        nodes of the sub network
     * @param compatibilityNodeArr compatibility nodes covering the sub network
     * @param iArr                 per compatibility node: 0 (or less) for "unassigned",
     *                             otherwise ion type index + 1
     * @param i                    argument handed to {@code PrecursorIonType.unknown} for
     *                             unassigned nodes (presumably the charge - confirm)
     * @return one single-candidate assignment (probability 1.0) per entry of {@code adductNodeArr}
     */
    private AdductAssignment[] resolveCompatibilityNetwork(AdductNode[] adductNodeArr, CompatibilityNode[] compatibilityNodeArr, int[] iArr, int i) {
        final Int2ObjectOpenHashMap<IonType> typeByNodeIndex = new Int2ObjectOpenHashMap<>();
        for (int pos = 0; pos < compatibilityNodeArr.length; pos++) {
            final int selection = iArr[pos];
            final IonType chosen;
            if (selection > 0) {
                chosen = new IonType(compatibilityNodeArr[pos].ionTypes[selection - 1], 1.0f, MolecularFormula.emptyFormula());
            } else {
                chosen = new IonType(PrecursorIonType.unknown(i), 1.0f, MolecularFormula.emptyFormula());
            }
            for (AdductNode member : compatibilityNodeArr[pos].subnodes) {
                typeByNodeIndex.put(member.getIndex(), chosen);
            }
        }

        spreadMultimere(adductNodeArr, typeByNodeIndex);
        spreadInsource(adductNodeArr, typeByNodeIndex);

        // Determine the normalisation: smallest multimere factor and the formula needed to
        // compensate in-source formulas that are not all-positive-or-zero.
        float smallestMultimere = 1.0f;
        MolecularFormula compensation = MolecularFormula.emptyFormula();
        for (AdductNode node : adductNodeArr) {
            final IonType current = typeByNodeIndex.get(node.getIndex());
            smallestMultimere = Math.min(smallestMultimere, current.getMultimere());
            if (!current.getInsource().isAllPositiveOrZero()) {
                compensation = current.getInsource().negate().union(compensation);
            }
        }

        final boolean alreadyNormalised = smallestMultimere == 1.0f && compensation.isEmpty();
        if (!alreadyNormalised) {
            final float scale = 1.0f / smallestMultimere;
            for (AdductNode node : adductNodeArr) {
                final IonType before = typeByNodeIndex.get(node.getIndex());
                typeByNodeIndex.put(node.getIndex(), before.multiplyMultimere(scale).addInsource(compensation));
            }
        }

        final AdductAssignment[] result = new AdductAssignment[adductNodeArr.length];
        for (int pos = 0; pos < adductNodeArr.length; pos++) {
            final IonType finalType = typeByNodeIndex.get(adductNodeArr[pos].getIndex());
            result[pos] = new AdductAssignment(new IonType[]{finalType}, new double[]{1.0d});
        }
        return result;
    }

    /**
     * Propagates multimere factors through the sub network along edges that carry a
     * {@link MultimereRelationship} explanation.
     *
     * <p>Nodes are explored depth-first. The first unvisited node of each group keeps its
     * current ion type; every newly reached neighbour gets its own ion type with the multimere
     * set to the current node's multimere times the relationship's multiplicator (the
     * reciprocal is used when the current node is the right end of the edge). A node that
     * already has a propagated value is never updated again. All results are written back into
     * the given map at the end.
     *
     * @param adductNodeArr         nodes of the sub network
     * @param int2ObjectOpenHashMap node index to ion type; read as input and updated in place
     */
    private void spreadMultimere(AdductNode[] adductNodeArr, Int2ObjectOpenHashMap<IonType> int2ObjectOpenHashMap) {
        final Int2ObjectOpenHashMap<IonType> propagated = new Int2ObjectOpenHashMap<>();
        final List<AdductNode> stack = new ArrayList<>();
        for (AdductNode seed : adductNodeArr) {
            if (propagated.containsKey(seed.getIndex())) {
                continue;
            }
            stack.add(seed);
            propagated.put(seed.getIndex(), int2ObjectOpenHashMap.get(seed.getIndex()));
            while (!stack.isEmpty()) {
                final AdductNode current = stack.remove(stack.size() - 1);
                final IonType currentType = propagated.get(current.getIndex());
                for (AdductEdge edge : current.getEdges()) {
                    for (KnownMassDelta explanation : edge.getExplanations()) {
                        if (!(explanation instanceof MultimereRelationship)) {
                            continue;
                        }
                        final AdductNode neighbour = edge.getOther(current);
                        if (propagated.get(neighbour.getIndex()) != null) {
                            continue;
                        }
                        float factor = ((MultimereRelationship) explanation).getMultiplicator();
                        if (current == edge.getRight()) {
                            // walking the edge from right to left inverts the relation
                            factor = 1.0f / factor;
                        }
                        final IonType neighbourType = int2ObjectOpenHashMap.get(neighbour.getIndex());
                        propagated.put(neighbour.getIndex(), neighbourType.withMultimere(currentType.getMultimere() * factor));
                        stack.add(neighbour);
                    }
                }
            }
        }
        int2ObjectOpenHashMap.putAll(propagated);
    }

    /**
     * Propagates in-source formulas through the sub network along edges that carry a
     * {@link LossRelationship} explanation.
     *
     * <p>Nodes are explored depth-first. The first unvisited node of each group keeps its
     * current ion type; every newly reached neighbour gets its own ion type with the in-source
     * formula set to the current node's in-source formula plus the relationship's formula (the
     * formula is negated when the current node is the left end of the edge). A node that
     * already has a propagated value is never updated again. All results are written back into
     * the given map at the end.
     *
     * @param adductNodeArr         nodes of the sub network
     * @param int2ObjectOpenHashMap node index to ion type; read as input and updated in place
     */
    private void spreadInsource(AdductNode[] adductNodeArr, Int2ObjectOpenHashMap<IonType> int2ObjectOpenHashMap) {
        final Int2ObjectOpenHashMap<IonType> propagated = new Int2ObjectOpenHashMap<>();
        final List<AdductNode> stack = new ArrayList<>();
        for (AdductNode seed : adductNodeArr) {
            if (propagated.containsKey(seed.getIndex())) {
                continue;
            }
            stack.add(seed);
            propagated.put(seed.getIndex(), int2ObjectOpenHashMap.get(seed.getIndex()));
            while (!stack.isEmpty()) {
                final AdductNode current = stack.remove(stack.size() - 1);
                final IonType currentType = propagated.get(current.getIndex());
                for (AdductEdge edge : current.getEdges()) {
                    for (KnownMassDelta explanation : edge.getExplanations()) {
                        if (!(explanation instanceof LossRelationship)) {
                            continue;
                        }
                        final AdductNode neighbour = edge.getOther(current);
                        if (propagated.get(neighbour.getIndex()) != null) {
                            continue;
                        }
                        MolecularFormula delta = ((LossRelationship) explanation).getFormula();
                        if (current == edge.getLeft()) {
                            // walking the edge from left to right flips the sign of the loss
                            delta = delta.negate();
                        }
                        final IonType neighbourType = int2ObjectOpenHashMap.get(neighbour.getIndex());
                        propagated.put(neighbour.getIndex(), neighbourType.withInsource(currentType.getInsource().add(delta)));
                        stack.add(neighbour);
                    }
                }
            }
        }
        int2ObjectOpenHashMap.putAll(propagated);
    }

    /**
     * Scores a complete selection of ion types.
     *
     * <p>For every node with a non-zero selection, each edge stored under the selected ion type
     * is inspected: if the target node selected exactly the ion type the edge expects, the edge
     * score is added; if the target selected a different ion type, the whole selection is
     * rejected. Edges to unassigned targets are ignored.
     *
     * @param compatibilityNodeArr compatibility nodes
     * @param iArr                 per node: 0 for "unassigned", otherwise ion type index + 1
     * @return the summed score, or {@link Double#NEGATIVE_INFINITY} for a contradictory selection
     */
    private double evaluate(CompatibilityNode[] compatibilityNodeArr, int[] iArr) {
        double total = 0.0d;
        for (int pos = 0; pos < compatibilityNodeArr.length; pos++) {
            final int selection = iArr[pos];
            if (selection == 0) {
                continue;
            }
            for (CompatibilityEdge edge : compatibilityNodeArr[pos].edgesPerIonType[selection - 1]) {
                final int targetType = iArr[edge.to.index] - 1;
                if (edge.toType == targetType) {
                    total += edge.score;
                } else if (targetType >= 0) {
                    return Double.NEGATIVE_INFINITY;
                }
            }
        }
        return total;
    }

    /**
     * Debugging aid: runs the beam search and dumps the best solutions to {@code System.out}.
     *
     * <p>Prints up to three top solutions (assignment and score). If the runner-up scores within
     * 3.0 of the best solution, a {@code ######} marker is printed followed by a pretty-printed
     * line for every solution inside that window. A separator line always terminates the output.
     *
     * <p>An empty result from the beam search is tolerated: only the separator is printed.
     * Previously the best score was read unconditionally, which raised an
     * {@link ArrayIndexOutOfBoundsException} for an empty result.
     *
     * @param adductNodeArr        nodes of the sub network (not used by this method)
     * @param compatibilityNodeArr compatibility nodes whose edges feed the beam search
     * @param i                    not used by this method
     */
    private void compareBeamSearch(AdductNode[] adductNodeArr, CompatibilityNode[] compatibilityNodeArr, int i) {
        // Every relation is stored once per end point; keep only the low-index -> high-index copy.
        final List<CompatibilityEdge> forwardEdges = new ArrayList<>();
        for (CompatibilityNode node : compatibilityNodeArr) {
            for (List<CompatibilityEdge> bucket : node.edgesPerIonType) {
                for (CompatibilityEdge edge : bucket) {
                    if (edge.from.index < edge.to.index) {
                        forwardEdges.add(edge);
                    }
                }
            }
        }
        // Highest score first.
        forwardEdges.sort(Comparator.comparingDouble((CompatibilityEdge edge) -> -edge.score));

        final AdductBeamSearch search = new AdductBeamSearch(compatibilityNodeArr.length, 10);
        for (CompatibilityEdge edge : forwardEdges) {
            // Ion type indices are shifted by one because 0 means "unassigned" in an assignment.
            search.add(edge.from.index, edge.fromType + 1, edge.to.index, edge.toType + 1, edge.score);
        }

        final AdductBeamSearch.MatchNode[] topSolutions = search.getTopSolutions();
        for (int rank = 0; rank < Math.min(3, topSolutions.length); rank++) {
            System.out.println(Arrays.toString(topSolutions[rank].assignment()) + "\t" + topSolutions[rank].score());
        }

        // The close-call report needs at least two solutions, so the best score is only read
        // when it exists.
        if (topSolutions.length > 1) {
            final double threshold = topSolutions[0].score() - 3.0d;
            if (topSolutions[1].score() >= threshold) {
                System.out.println("######");
                for (AdductBeamSearch.MatchNode solution : topSolutions) {
                    if (solution.score() >= threshold) {
                        prettyprint(compatibilityNodeArr, solution.assignment());
                    }
                }
            }
        }
        System.out.println("-----------------------");
    }

    /**
     * Prints, on a single line of {@code System.out}, every edge that is realised by the given
     * selection, i.e. whose both end points are assigned and selected exactly the ion types the
     * edge was created for. Each entry has the form {@code from -> to (score),} followed by a tab.
     *
     * @param compatibilityNodeArr compatibility nodes
     * @param iArr                 per node: 0 for "unassigned", otherwise ion type index + 1
     */
    private void prettyprint(CompatibilityNode[] compatibilityNodeArr, int[] iArr) {
        for (CompatibilityNode node : compatibilityNodeArr) {
            final int selection = iArr[node.index];
            if (selection == 0) {
                continue;
            }
            for (CompatibilityEdge edge : node.edgesPerIonType[selection - 1]) {
                final CompatibilityNode source = edge.from;
                final CompatibilityNode target = edge.to;
                if (iArr[target.index] == 0 || iArr[source.index] == 0) {
                    continue;
                }
                final int sourceType = iArr[source.index] - 1;
                if (sourceType != edge.fromType) {
                    continue;
                }
                final int targetType = iArr[target.index] - 1;
                if (targetType != edge.toType) {
                    continue;
                }
                final StringBuilder entry = new StringBuilder();
                entry.append(String.valueOf(source.ionTypes[sourceType]));
                entry.append(" -> ");
                entry.append(String.valueOf(target.ionTypes[targetType]));
                entry.append(" (").append(edge.score).append("),\t");
                System.out.print(entry.toString());
            }
        }
        System.out.println();
    }

    /**
     * Feeds all compatibility edges into an {@link AdductBeamSearch} (beam parameter 10) and
     * merges its top solutions into per-node assignments.
     *
     * <p>Only edges pointing from a lower to a higher node index are used, and they are handed
     * to the search in order of descending score (ties keep their encounter order). Ion type
     * indices are shifted by one when passed to the search.
     *
     * @param adductNodeArr        nodes of the sub network
     * @param compatibilityNodeArr compatibility nodes whose edges feed the beam search
     * @param i                    value forwarded unchanged to the merge step
     * @return the merged assignments, or {@code null} if the merge step yields none
     */
    private AdductAssignment[] beamSearch(AdductNode[] adductNodeArr, CompatibilityNode[] compatibilityNodeArr, int i) {
        final AdductBeamSearch search = new AdductBeamSearch(compatibilityNodeArr.length, 10);
        Arrays.stream(compatibilityNodeArr)
                .flatMap(node -> Arrays.stream(node.edgesPerIonType))
                .flatMap(List::stream)
                .filter(edge -> edge.from.index < edge.to.index)
                .sorted(Comparator.comparingDouble((CompatibilityEdge edge) -> -edge.score))
                .forEachOrdered(edge -> search.add(edge.from.index, edge.fromType + 1, edge.to.index, edge.toType + 1, edge.score));
        final AdductBeamSearch.MatchNode[] topSolutions = search.getTopSolutions();
        return mergeTopResults(topSolutions, adductNodeArr, compatibilityNodeArr, i);
    }

    /**
     * Combines the best beam search solutions into one assignment per adduct node.
     *
     * <p>The leading run of solutions whose score is within 3.0 of the first solution is kept.
     * A single kept solution is resolved directly. Otherwise every kept solution is resolved
     * separately and, for each adduct node, the per-solution assignments are merged via
     * {@code AdductAssignment.merge} using the weights {@code exp(score - bestScore)}.
     *
     * <p>Fix: the per-solution results form an array of arrays, so the stream must be collected
     * with a two-dimensional array generator. The previous one-dimensional generator cannot hold
     * {@code AdductAssignment[]} elements.
     *
     * @param matchNodeArr         solutions as returned by the beam search; the first entry is
     *                             treated as the best one
     * @param adductNodeArr        nodes of the sub network
     * @param compatibilityNodeArr compatibility nodes covering the sub network
     * @param i                    value forwarded to {@code resolveCompatibilityNetwork} and
     *                             {@code AdductAssignment.merge}
     * @return one assignment per entry of {@code adductNodeArr}, or {@code null} when there is
     *         no solution at all
     */
    private AdductAssignment[] mergeTopResults(AdductBeamSearch.MatchNode[] matchNodeArr, AdductNode[] adductNodeArr, CompatibilityNode[] compatibilityNodeArr, int i) {
        if (matchNodeArr.length == 0) {
            return null;
        }
        final double bestScore = matchNodeArr[0].score();
        final double threshold = bestScore - 3.0d;
        final AdductBeamSearch.MatchNode[] kept = Arrays.stream(matchNodeArr)
                .takeWhile(solution -> solution.score() >= threshold)
                .toArray(AdductBeamSearch.MatchNode[]::new);
        if (kept.length == 1) {
            return resolveCompatibilityNetwork(adductNodeArr, compatibilityNodeArr, kept[0].assignment(), i);
        }

        // perSolution[s][n] = assignment of adduct node n according to kept solution s
        final AdductAssignment[][] perSolution = Arrays.stream(kept)
                .map(solution -> resolveCompatibilityNetwork(adductNodeArr, compatibilityNodeArr, solution.assignment(), i))
                .toArray(AdductAssignment[][]::new);
        // Relative weights; the best solution gets exp(0) = 1.
        final double[] weights = Arrays.stream(kept)
                .mapToDouble(solution -> Math.exp(solution.score() - bestScore))
                .toArray();

        final AdductAssignment[] merged = new AdductAssignment[adductNodeArr.length];
        for (int node = 0; node < adductNodeArr.length; node++) {
            final AdductAssignment[] candidates = new AdductAssignment[perSolution.length];
            for (int s = 0; s < perSolution.length; s++) {
                candidates[s] = perSolution[s][node];
            }
            merged[node] = AdductAssignment.merge(i, candidates, weights);
        }
        return merged;
    }

    /**
     * Collapses the adduct graph into a compatibility graph.
     *
     * <p>Phase 1 groups the adduct nodes into components that are connected through edges for
     * which {@code isAdductEdge()} is false, and collects for every component the incident edges
     * for which it is true. Phase 2 gives every component the list of distinct ion types that
     * appear on its side of the collected edges' {@link AdductRelationship} explanations.
     * Phase 3 creates one {@link CompatibilityEdge} per collected edge and
     * {@link AdductRelationship}, filed under the ion type index of the component it starts at.
     *
     * @param adductNodeArr nodes of the sub network
     * @return one compatibility node per component, or {@code null} if fewer than two
     *         components were found
     */
    private CompatibilityNode[] transformGraphIntoCompatibilityGraph(AdductNode[] adductNodeArr) {
        // adduct node index -> component id; -1 means "not visited yet"
        Int2IntOpenHashMap int2IntOpenHashMap = new Int2IntOpenHashMap();
        int2IntOpenHashMap.defaultReturnValue(-1);
        // arrayList: per component, the set of adduct edges touching it
        ArrayList arrayList = new ArrayList();
        // arrayList2: per component, the adduct nodes it contains
        ArrayList arrayList2 = new ArrayList();
        // arrayList3: work stack for the depth-first traversal
        ArrayList arrayList3 = new ArrayList();
        // id of the component currently being filled; doubles as "last component id" afterwards
        int i = -1;
        // ---- Phase 1: find components and collect their adduct edges ----
        for (AdductNode adductNode : adductNodeArr) {
            if (int2IntOpenHashMap.get(adductNode.getIndex()) < 0) {
                // unvisited node: open a new component and seed the traversal with it
                i++;
                int2IntOpenHashMap.put(adductNode.getIndex(), i);
                arrayList2.add(new ArrayList());
                arrayList.add(new HashSet());
                ((ArrayList) arrayList2.get(i)).add(adductNode);
                arrayList3.add(adductNode);
            }
            // The stack is only non-empty right after a new component was opened above,
            // so `i` always refers to the component being traversed here.
            while (!arrayList3.isEmpty()) {
                AdductNode adductNode2 = (AdductNode) arrayList3.remove(arrayList3.size() - 1);
                for (AdductEdge adductEdge : adductNode2.getEdges()) {
                    if (adductEdge.isAdductEdge()) {
                        // adduct edges are not followed; they are only remembered for this component
                        ((HashSet) arrayList.get(i)).add(adductEdge);
                    } else {
                        // any other edge merges its other end point into the current component
                        AdductNode other = adductEdge.getOther(adductNode2);
                        if (int2IntOpenHashMap.get(other.getIndex()) < 0) {
                            arrayList3.add(other);
                            int2IntOpenHashMap.put(other.getIndex(), i);
                            ((ArrayList) arrayList2.get(i)).add(other);
                        }
                    }
                }
            }
        }
        // zero or one component: nothing to relate to each other
        if (i <= 0) {
            return null;
        }
        CompatibilityNode[] compatibilityNodeArr = new CompatibilityNode[i + 1];
        for (int i2 = 0; i2 <= i; i2++) {
            compatibilityNodeArr[i2] = new CompatibilityNode(i2, (AdductNode[]) ((ArrayList) arrayList2.get(i2)).toArray(i3 -> {
                return new AdductNode[i3];
            }));
        }
        // ---- Phase 2: per component, enumerate the ion types seen on its side of the edges ----
        // object2IntOpenHashMapArr[c]: ion type -> index into compatibilityNodeArr[c].ionTypes
        Object2IntOpenHashMap[] object2IntOpenHashMapArr = new Object2IntOpenHashMap[i + 1];
        int i4 = 0;
        while (i4 <= i) {
            HashSet hashSet = (HashSet) arrayList.get(i4);
            Object2IntOpenHashMap object2IntOpenHashMap = new Object2IntOpenHashMap();
            Iterator it = hashSet.iterator();
            while (it.hasNext()) {
                AdductEdge adductEdge2 = (AdductEdge) it.next();
                // component of the edge's left end point decides which side belongs to us
                int i5 = int2IntOpenHashMap.get(adductEdge2.getLeft().getIndex());
                for (KnownMassDelta knownMassDelta : adductEdge2.getExplanations()) {
                    if (knownMassDelta instanceof AdductRelationship) {
                        // NOTE(review): if both end points of an adduct edge lie in the same
                        // component, only the left ion type is registered here - confirm that
                        // this cannot happen or is intended.
                        PrecursorIonType left = i5 == i4 ? ((AdductRelationship) knownMassDelta).getLeft() : ((AdductRelationship) knownMassDelta).getRight();
                        if (!object2IntOpenHashMap.containsKey(left)) {
                            // next free index = current number of registered ion types
                            object2IntOpenHashMap.put(left, object2IntOpenHashMap.size());
                        }
                    }
                }
            }
            CompatibilityNode compatibilityNode = compatibilityNodeArr[i4];
            compatibilityNode.ionTypes = new PrecursorIonType[object2IntOpenHashMap.size()];
            // invert the map into an array indexed by the assigned ion type index
            object2IntOpenHashMap.forEach((precursorIonType, num) -> {
                compatibilityNode.ionTypes[num.intValue()] = precursorIonType;
            });
            // one (initially empty) edge bucket per ion type
            compatibilityNode.edgesPerIonType = new ArrayList[compatibilityNode.ionTypes.length];
            for (int i6 = 0; i6 < compatibilityNode.edgesPerIonType.length; i6++) {
                compatibilityNode.edgesPerIonType[i6] = new ArrayList();
            }
            object2IntOpenHashMapArr[i4] = object2IntOpenHashMap;
            i4++;
        }
        // ---- Phase 3: create the compatibility edges, always starting at component i7 ----
        for (int i7 = 0; i7 <= i; i7++) {
            Iterator it2 = ((HashSet) arrayList.get(i7)).iterator();
            while (it2.hasNext()) {
                AdductEdge adductEdge3 = (AdductEdge) it2.next();
                if (int2IntOpenHashMap.get(adductEdge3.getLeft().getIndex()) == i7) {
                    // this component is the left side: from = left ion type, to = right ion type
                    for (KnownMassDelta knownMassDelta2 : adductEdge3.getExplanations()) {
                        if (knownMassDelta2 instanceof AdductRelationship) {
                            int i8 = int2IntOpenHashMap.get(adductEdge3.getRight().getIndex());
                            int i9 = object2IntOpenHashMapArr[i7].getInt(((AdductRelationship) knownMassDelta2).getLeft());
                            compatibilityNodeArr[i7].edgesPerIonType[i9].add(new CompatibilityEdge(compatibilityNodeArr[i7], compatibilityNodeArr[i8], i9, object2IntOpenHashMapArr[i8].getInt(((AdductRelationship) knownMassDelta2).getRight()), adductEdge3));
                        }
                    }
                } else {
                    // this component is the right side: from = right ion type, to = left ion type
                    for (KnownMassDelta knownMassDelta3 : adductEdge3.getExplanations()) {
                        if (knownMassDelta3 instanceof AdductRelationship) {
                            int i10 = int2IntOpenHashMap.get(adductEdge3.getLeft().getIndex());
                            int i11 = object2IntOpenHashMapArr[i7].getInt(((AdductRelationship) knownMassDelta3).getRight());
                            compatibilityNodeArr[i7].edgesPerIonType[i11].add(new CompatibilityEdge(compatibilityNodeArr[i7], compatibilityNodeArr[i10], i11, object2IntOpenHashMapArr[i10].getInt(((AdductRelationship) knownMassDelta3).getLeft()), adductEdge3));
                        }
                    }
                }
            }
        }
        return compatibilityNodeArr;
    }
}
