package de.unijena.bioinf.IsotopePatternAnalysis.prediction;

import de.unijena.bioinf.ChemistryBase.chem.ChemicalAlphabet;
import de.unijena.bioinf.ChemistryBase.chem.Element;
import de.unijena.bioinf.ChemistryBase.chem.FormulaConstraints;
import de.unijena.bioinf.ChemistryBase.chem.PeriodicTable;
import de.unijena.bioinf.ChemistryBase.ms.utils.SimpleSpectrum;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

/* loaded from: input_file:de/unijena/bioinf/IsotopePatternAnalysis/prediction/DNNElementPredictor.class */
/**
 * {@link ElementPredictor} that decides which elements may occur in a molecular formula by feeding a
 * measured isotope pattern into pre-trained element-detection networks.
 *
 * <p>The elements in {@code FREE_ELEMENTS} (C, H, N, O, P, F, I) are always allowed. An element of
 * {@code DETECTABLE_ELEMENTS} (B, Br, Cl, S, Si, Se) is allowed only if the network output for it is
 * at least its threshold. Selenium can additionally be enabled by an intensity heuristic.
 */
public class DNNElementPredictor implements ElementPredictor {
    /** Every element this predictor can ever allow: the free elements followed by the detectable ones. */
    protected ChemicalAlphabet alphabet;
    /** Elements decided by the networks; network output {@code i} refers to element {@code i}. */
    private static final Element[] DETECTABLE_ELEMENTS;
    /** Upper bound per detectable element, parallel to {@link #DETECTABLE_ELEMENTS}. */
    private static final int[] UPPERBOUNDS;
    /** Upper bound per free element, parallel to {@link #FREE_ELEMENTS}. */
    private static final int[] FREE_UPPERBOUNDS;
    /** Elements that are always part of the predicted alphabet. */
    private static final Element[] FREE_ELEMENTS;
    /** Selenium; may also be enabled by {@link #hasSeleniumLikeTail(SimpleSpectrum)}. */
    private static final Element SELENE;

    /**
     * Classpath resources of the trained networks. Order matters: {@link #predictConstraints} uses the
     * first network whose {@code numberOfPeaks()} fits the spectrum. The numeric suffix presumably
     * is that peak count, so networks needing more peaks come first (confirm against the files).
     */
    private static final String[] NETWORK_RESOURCES = {
            "/dnn_element_detection_5.param",
            "/dnn_element_detection_4.param",
            "/dnn_element_detection_3.param"
    };

    /** Number of leading peaks used as the reference ("head") by the selenium heuristic. */
    private static final int SELENIUM_HEAD_PEAKS = 5;
    /** Tail-to-head intensity ratio above which selenium is added. */
    private static final double SELENIUM_TAIL_RATIO = 0.25d;

    /** Trained networks, tried in order; loaded from the classpath for every instance. */
    protected TrainedElementDetectionNetwork[] networks = readNetworks();
    /** Detection threshold per detectable element, parallel to {@link #DETECTABLE_ELEMENTS}. */
    protected double[] thresholds = new double[DETECTABLE_ELEMENTS.length];

    /**
     * Creates a predictor. Every threshold starts at 0.05, except silicon, which uses 0.5 and
     * therefore needs a higher network output to be accepted.
     */
    public DNNElementPredictor() {
        Arrays.fill(this.thresholds, 0.05d);
        // Write the slot directly instead of calling the overridable setThreshold(...) from the constructor.
        this.thresholds[indexOfDetectable(PeriodicTable.getInstance().getByName("Si"))] = 0.5d;
        final Element[] allElements = new Element[FREE_ELEMENTS.length + DETECTABLE_ELEMENTS.length];
        System.arraycopy(FREE_ELEMENTS, 0, allElements, 0, FREE_ELEMENTS.length);
        System.arraycopy(DETECTABLE_ELEMENTS, 0, allElements, FREE_ELEMENTS.length, DETECTABLE_ELEMENTS.length);
        this.alphabet = new ChemicalAlphabet(allElements);
    }

    /** Sets silicon's threshold to +infinity, so silicon is never predicted by the networks. */
    public void disableSilicon() {
        setThreshold("Si", Double.POSITIVE_INFINITY);
    }

    /**
     * Sets the same detection threshold for every detectable element.
     *
     * @param d minimum network output required to accept an element
     */
    public void setThreshold(double d) {
        Arrays.fill(this.thresholds, d);
    }

    /**
     * Sets the detection threshold of one element, looked up by its periodic-table name.
     *
     * @param str element name as understood by {@code PeriodicTable.getByName}
     * @param d   minimum network output required to accept the element
     * @throws IllegalArgumentException if the name is unknown or the element is not detectable
     */
    public void setThreshold(String str, double d) {
        final Element element = PeriodicTable.getInstance().getByName(str);
        // NOTE(review): assumes getByName returns null for unknown names; report that as an IAE, not an NPE.
        if (element == null) {
            throw new IllegalArgumentException(str + " is not predictable");
        }
        setThreshold(element, d);
    }

    /**
     * Sets the detection threshold of one detectable element.
     *
     * @param element element whose threshold is changed
     * @param d       minimum network output required to accept the element
     * @throws IllegalArgumentException if {@code element} is null or not one of the detectable elements
     */
    public void setThreshold(Element element, double d) {
        final int index = indexOfDetectable(element);
        if (index < 0) {
            throw new IllegalArgumentException((element == null ? "null" : element.getSymbol()) + " is not predictable");
        }
        this.thresholds[index] = d;
    }

    /** Returns the position of {@code element} in {@link #DETECTABLE_ELEMENTS}, or -1 if absent (including null). */
    private static int indexOfDetectable(Element element) {
        for (int i = 0; i < DETECTABLE_ELEMENTS.length; i++) {
            if (DETECTABLE_ELEMENTS[i].equals(element)) {
                return i;
            }
        }
        return -1;
    }

    /** Loads every network listed in {@link #NETWORK_RESOURCES}, keeping their order. */
    private static TrainedElementDetectionNetwork[] readNetworks() {
        final TrainedElementDetectionNetwork[] loaded = new TrainedElementDetectionNetwork[NETWORK_RESOURCES.length];
        for (int i = 0; i < NETWORK_RESOURCES.length; i++) {
            loaded[i] = loadNetwork(NETWORK_RESOURCES[i]);
        }
        return loaded;
    }

    /**
     * Reads one network from a classpath resource and always closes the stream afterwards.
     *
     * @throws IllegalStateException if the resource is missing from the classpath
     * @throws RuntimeException      wrapping any {@link IOException} raised while reading or closing
     */
    private static TrainedElementDetectionNetwork loadNetwork(String resource) {
        try (InputStream in = DNNElementPredictor.class.getResourceAsStream(resource)) {
            if (in == null) {
                throw new IllegalStateException("Network parameter file " + resource + " not found on classpath");
            }
            return TrainedElementDetectionNetwork.readNetwork(in);
        } catch (IOException e) {
            throw new RuntimeException("Cannot read network parameter file " + resource, e);
        }
    }

    /**
     * Predicts formula constraints for an isotope pattern.
     *
     * <p>The free elements are always included. Selenium is added when
     * {@link #hasSeleniumLikeTail(SimpleSpectrum)} fires. Only the first network whose
     * {@code numberOfPeaks()} does not exceed the spectrum size is evaluated; each detectable element
     * whose output reaches its threshold is added. Upper bounds are set for all free elements and for
     * every detectable element that ended up in the alphabet.
     *
     * @param simpleSpectrum the (picked) isotope pattern
     * @return constraints over the predicted alphabet
     */
    @Override // de.unijena.bioinf.IsotopePatternAnalysis.prediction.ElementPredictor
    public FormulaConstraints predictConstraints(SimpleSpectrum simpleSpectrum) {
        final Set<Element> elements = new HashSet<>(Arrays.asList(FREE_ELEMENTS));
        if (hasSeleniumLikeTail(simpleSpectrum)) {
            elements.add(SELENE);
        }
        for (TrainedElementDetectionNetwork network : this.networks) {
            if (network.numberOfPeaks() <= simpleSpectrum.size()) {
                final double[] prediction = network.predict(simpleSpectrum);
                for (int i = 0; i < prediction.length; i++) {
                    if (prediction[i] >= this.thresholds[i]) {
                        elements.add(DETECTABLE_ELEMENTS[i]);
                    }
                }
                // Only the first fitting network is used; without this break the loop never terminated.
                break;
            }
        }
        final FormulaConstraints constraints =
                new FormulaConstraints(new ChemicalAlphabet(elements.toArray(new Element[0])));
        for (int i = 0; i < FREE_UPPERBOUNDS.length; i++) {
            constraints.setUpperbound(FREE_ELEMENTS[i], FREE_UPPERBOUNDS[i]);
        }
        for (int i = 0; i < UPPERBOUNDS.length; i++) {
            if (elements.contains(DETECTABLE_ELEMENTS[i])) {
                constraints.setUpperbound(DETECTABLE_ELEMENTS[i], UPPERBOUNDS[i]);
            }
        }
        return constraints;
    }

    /**
     * Intensity heuristic for selenium. It needs more than {@value #SELENIUM_HEAD_PEAKS} peaks, and
     * fires when the summed intensity of the peaks after the first {@value #SELENIUM_HEAD_PEAKS}
     * exceeds {@value #SELENIUM_TAIL_RATIO} times the summed intensity of those first peaks.
     * Presumably this targets selenium's wide isotope pattern (confirm).
     */
    private static boolean hasSeleniumLikeTail(SimpleSpectrum spectrum) {
        if (spectrum.size() <= SELENIUM_HEAD_PEAKS) {
            return false;
        }
        double tailIntensity = 0.0d;
        for (int k = spectrum.size() - 1; k >= SELENIUM_HEAD_PEAKS; k--) {
            tailIntensity += spectrum.getIntensityAt(k);
        }
        double headIntensity = 0.0d;
        for (int k = 0; k < SELENIUM_HEAD_PEAKS; k++) {
            headIntensity += spectrum.getIntensityAt(k);
        }
        return tailIntensity / headIntensity > SELENIUM_TAIL_RATIO;
    }

    /** Returns the alphabet of all elements this predictor can ever allow (free + detectable). */
    @Override // de.unijena.bioinf.IsotopePatternAnalysis.prediction.ElementPredictor
    public ChemicalAlphabet getChemicalAlphabet() {
        return this.alphabet;
    }

    /** Returns true iff {@code element} is one of the detectable elements (null gives false). */
    @Override // de.unijena.bioinf.IsotopePatternAnalysis.prediction.ElementPredictor
    public boolean isPredictable(Element element) {
        return indexOfDetectable(element) >= 0;
    }

    // Element tables; parallel arrays must stay aligned (UPPERBOUNDS <-> DETECTABLE_ELEMENTS,
    // FREE_UPPERBOUNDS <-> FREE_ELEMENTS).
    static {
        PeriodicTable periodicTable = PeriodicTable.getInstance();
        DETECTABLE_ELEMENTS = new Element[]{periodicTable.getByName("B"), periodicTable.getByName("Br"), periodicTable.getByName("Cl"), periodicTable.getByName("S"), periodicTable.getByName("Si"), periodicTable.getByName("Se")};
        FREE_ELEMENTS = new Element[]{periodicTable.getByName("C"), periodicTable.getByName("H"), periodicTable.getByName("N"), periodicTable.getByName("O"), periodicTable.getByName("P"), periodicTable.getByName("F"), periodicTable.getByName("I")};
        UPPERBOUNDS = new int[]{2, 5, 5, 10, 2, 2};
        FREE_UPPERBOUNDS = new int[]{Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE, 10, 20, 6};
        SELENE = periodicTable.getByName("Se");
    }
}
