package de.unijena.bioinf.IsotopePatternAnalysis.prediction;

import de.unijena.bioinf.ChemistryBase.chem.Element;
import de.unijena.bioinf.ChemistryBase.chem.PeriodicTable;
import de.unijena.bioinf.ChemistryBase.ms.utils.SimpleSpectrum;
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;

/* loaded from: input_file:de/unijena/bioinf/IsotopePatternAnalysis/prediction/TrainedElementDetectionNetwork.class */
/**
 * Small feed-forward network that is evaluated on the feature vector of a {@link SimpleSpectrum}.
 * <p>
 * A network is a fixed pipeline of {@link Layer}s applied in order by {@link #predict}:
 * <ol>
 *   <li>a {@link PreprocessingLayer} that centres and scales the input,</li>
 *   <li>one {@link FullyConnectedLayer} per layer size stored in the model file (TANH activation
 *       for every layer except the last, which is LINEAR),</li>
 *   <li>an output transformation chosen by the loader: {@link #readNetwork} appends a
 *       {@link PlattSigmoidLayer}, {@link #readRegressionNetwork} an {@link ExponentialLayer}.</li>
 * </ol>
 */
class TrainedElementDetectionNetwork {
    // NOTE(review): the three constants below are not referenced anywhere in this class; the layer
    // layout is read from the model stream instead. Presumably they describe the shipped model —
    // confirm and remove if obsolete.
    private static final int INPUT_SIZE = 69;
    private static final int[] NEURONS = {48, 32, 5};
    private static final ActivationFunction[] ACTIVATION_FUNCTIONS = {ActivationFunction.TANH, ActivationFunction.TANH, ActivationFunction.LINEAR};

    /** Number of peaks used when building the feature vector in {@link #predict}. */
    private final int npeaks;
    /** Layers applied in order to the feature vector. */
    private final Layer[] layers;

    /**
     * Element-wise activation applied by a {@link FullyConnectedLayer}.
     * The loaders in this class only ever use TANH and LINEAR.
     */
    protected enum ActivationFunction {
        LINEAR,
        RELU,
        TANH
    }

    /** Output layer that replaces every value {@code x} with {@code exp(x)}. */
    private static class ExponentialLayer implements Layer {
        @Override
        public double[] activate(double[] input) {
            final double[] output = new double[input.length];
            for (int k = 0; k < input.length; ++k) {
                output[k] = Math.exp(input[k]);
            }
            return output;
        }
    }

    /**
     * Dense layer computing {@code output[i] = f(b[i] + sum_j W[i][j] * input[j])}.
     * The number of outputs equals {@code W.length}; row {@code i} reads the first
     * {@code W[i].length} input values.
     */
    protected static class FullyConnectedLayer implements Layer {
        protected final double[][] W;
        protected final double[] b;
        protected final ActivationFunction function;

        public FullyConnectedLayer(double[][] dArr, double[] dArr2, ActivationFunction activationFunction) {
            this.W = dArr;
            this.b = dArr2;
            this.function = activationFunction;
        }

        @Override
        public double[] activate(double[] input) {
            final double[] output = new double[this.W.length];
            for (int row = 0; row < this.W.length; ++row) {
                final double[] weights = this.W[row];
                double sum = this.b[row];
                for (int col = 0; col < weights.length; ++col) {
                    sum += input[col] * weights[col];
                }
                switch (this.function) {
                    case RELU:
                        sum = Math.max(0.0d, sum);
                        break;
                    case TANH:
                        sum = Math.tanh(sum);
                        break;
                    default:
                        // LINEAR: identity
                        break;
                }
                output[row] = sum;
            }
            return output;
        }
    }

    /** One stage of the network: maps an input vector to an output vector. */
    protected interface Layer {
        double[] activate(double[] dArr);
    }

    /**
     * Platt-scaling output layer.
     * <p>
     * For each index {@code i < As.length} it computes {@code 1 / (1 + exp(As[i] * x + Bs[i]))}.
     * The output array has the same length as the input. Entries at or beyond {@code As.length}
     * are left at {@code 0.0}.
     */
    protected static class PlattSigmoidLayer implements Layer {
        protected final double[] As;
        protected final double[] Bs;

        public PlattSigmoidLayer(double[] dArr, double[] dArr2) {
            this.As = dArr;
            this.Bs = dArr2;
        }

        @Override
        public double[] activate(double[] input) {
            final double[] output = new double[input.length];
            for (int k = 0; k < this.As.length; ++k) {
                output[k] = sigmoid_predict(input[k], this.As[k], this.Bs[k]);
            }
            return output;
        }

        /**
         * Computes {@code 1 / (1 + exp(value * a + b))}. The exponent's sign selects the branch, so
         * {@code exp} is only ever called on a non-positive argument and cannot overflow. This is
         * the same formulation as LIBSVM's {@code sigmoid_predict}.
         */
        private static double sigmoid_predict(double value, double a, double b) {
            final double fApB = (value * a) + b;
            if (fApB >= 0.0d) {
                final double e = Math.exp(-fApB);
                return e / (1.0d + e);
            }
            return 1.0d / (1.0d + Math.exp(fApB));
        }
    }

    /**
     * Standardisation layer computing {@code (x[i] - centering[i]) / scaling[i]} for each
     * {@code i < centering.length}. The output array has the same length as the input. Entries at
     * or beyond {@code centering.length} are left at {@code 0.0}.
     */
    protected static class PreprocessingLayer implements Layer {
        protected final double[] centering;
        protected final double[] scaling;

        public PreprocessingLayer(double[] dArr, double[] dArr2) {
            this.centering = dArr;
            this.scaling = dArr2;
        }

        @Override
        public double[] activate(double[] input) {
            final double[] output = new double[input.length];
            for (int k = 0; k < this.centering.length; ++k) {
                output[k] = (input[k] - this.centering[k]) / this.scaling[k];
            }
            return output;
        }
    }

    /** @return the number of peaks used to build the feature vector in {@link #predict} */
    public int numberOfPeaks() {
        return this.npeaks;
    }

    /**
     * Reads a regression network whose output layer is an {@link ExponentialLayer}.
     * <p>
     * Stream layout (big-endian {@code int}/{@code double}, as read by {@link DataInputStream}):
     * npeaks, input size, element count, element ids, layer count, layer sizes, parameter count,
     * parameters. The parameters are: centering[inputSize], scaling[inputSize], and then, for each
     * layer, its weight rows followed by its biases.
     * <p>
     * The given stream is closed when this method returns, whether it succeeds or fails.
     *
     * @param inputStream stream containing the serialized network; closed by this method
     * @return the loaded network
     * @throws IOException if reading fails, or if the stream is truncated or malformed
     */
    public static TrainedElementDetectionNetwork readRegressionNetwork(InputStream inputStream) throws IOException {
        try (DataInputStream in = new DataInputStream(new BufferedInputStream(inputStream))) {
            final int npeaks = in.readInt();
            final int inputSize = readLength(in, "input size");
            // Element ids are part of the file layout; the resolved elements are not used by the network.
            readElements(in, readLength(in, "number of elements"));
            final int[] layerSizes = readInts(in, readLength(in, "number of layers"));
            final double[] parameters = readDoubles(in, readLength(in, "number of parameters"));
            final Layer[] layers = buildLayers(inputSize, layerSizes, parameters, new ExponentialLayer());
            return new TrainedElementDetectionNetwork(npeaks, layers);
        }
    }

    /**
     * Reads a network whose output layer is a {@link PlattSigmoidLayer}.
     * <p>
     * Stream layout (big-endian {@code int}/{@code double}, as read by {@link DataInputStream}):
     * npeaks, input size, element count n, n element ids, n Platt A values, n Platt B values,
     * layer count, layer sizes, parameter count, parameters. The parameters are laid out as in
     * {@link #readRegressionNetwork}.
     * <p>
     * The given stream is closed when this method returns, whether it succeeds or fails.
     *
     * @param inputStream stream containing the serialized network; closed by this method
     * @return the loaded network
     * @throws IOException if reading fails, or if the stream is truncated or malformed
     */
    public static TrainedElementDetectionNetwork readNetwork(InputStream inputStream) throws IOException {
        try (DataInputStream in = new DataInputStream(new BufferedInputStream(inputStream))) {
            final int npeaks = in.readInt();
            final int inputSize = readLength(in, "input size");
            final int numberOfElements = readLength(in, "number of elements");
            // Element ids are part of the file layout; the resolved elements are not used by the network.
            readElements(in, numberOfElements);
            final double[] plattA = readDoubles(in, numberOfElements);
            final double[] plattB = readDoubles(in, numberOfElements);
            final int[] layerSizes = readInts(in, readLength(in, "number of layers"));
            final double[] parameters = readDoubles(in, readLength(in, "number of parameters"));
            final Layer[] layers = buildLayers(inputSize, layerSizes, parameters, new PlattSigmoidLayer(plattA, plattB));
            return new TrainedElementDetectionNetwork(npeaks, layers);
        }
    }

    /** Reads one {@code int} that must be a non-negative length or count. */
    private static int readLength(DataInputStream in, String what) throws IOException {
        final int value = in.readInt();
        if (value < 0) {
            throw new IOException("Corrupt network file: negative " + what + " (" + value + ")");
        }
        return value;
    }

    /** Reads {@code count} element ids and resolves each one through the {@link PeriodicTable}. */
    private static Element[] readElements(DataInputStream in, int count) throws IOException {
        final PeriodicTable periodicTable = PeriodicTable.getInstance();
        final Element[] elements = new Element[count];
        for (int k = 0; k < count; ++k) {
            elements[k] = periodicTable.get(in.readInt());
        }
        return elements;
    }

    /** Reads {@code count} consecutive {@code int}s. */
    private static int[] readInts(DataInputStream in, int count) throws IOException {
        final int[] values = new int[count];
        for (int k = 0; k < count; ++k) {
            values[k] = in.readInt();
        }
        return values;
    }

    /** Reads {@code count} consecutive {@code double}s. */
    private static double[] readDoubles(DataInputStream in, int count) throws IOException {
        final double[] values = new double[count];
        for (int k = 0; k < count; ++k) {
            values[k] = in.readDouble();
        }
        return values;
    }

    /**
     * Unpacks the flat parameter vector into layers.
     * <p>
     * Produces a preprocessing layer, then one fully connected layer per entry of
     * {@code layerSizes} (TANH except for the last, which is LINEAR), then {@code outputLayer}.
     * Parameters beyond those required are ignored.
     */
    private static Layer[] buildLayers(int inputSize, int[] layerSizes, double[] parameters, Layer outputLayer) throws IOException {
        checkParameterCount(inputSize, layerSizes, parameters.length);
        final Layer[] layers = new Layer[layerSizes.length + 2];
        int offset = 0;

        final double[] centering = slice(parameters, offset, inputSize);
        offset += inputSize;
        final double[] scaling = slice(parameters, offset, inputSize);
        offset += inputSize;
        layers[0] = new PreprocessingLayer(centering, scaling);

        int fanIn = inputSize;
        for (int k = 0; k < layerSizes.length; ++k) {
            final int fanOut = layerSizes[k];
            final double[][] weights = new double[fanOut][];
            for (int row = 0; row < fanOut; ++row) {
                weights[row] = slice(parameters, offset, fanIn);
                offset += fanIn;
            }
            final double[] bias = slice(parameters, offset, fanOut);
            offset += fanOut;
            final boolean isLast = k == layerSizes.length - 1;
            layers[k + 1] = new FullyConnectedLayer(weights, bias, isLast ? ActivationFunction.LINEAR : ActivationFunction.TANH);
            fanIn = fanOut;
        }

        layers[layerSizes.length + 1] = outputLayer;
        return layers;
    }

    /**
     * Fails with a descriptive {@link IOException} when a layer size is negative or the file holds
     * fewer parameters than the declared layout needs. Without this check a malformed file would
     * surface later as an array indexing exception.
     */
    private static void checkParameterCount(int inputSize, int[] layerSizes, int available) throws IOException {
        // long arithmetic: sizes come from the file and their products may exceed int range
        long required = 2L * inputSize;
        int fanIn = inputSize;
        for (int fanOut : layerSizes) {
            if (fanOut < 0) {
                throw new IOException("Corrupt network file: negative layer size (" + fanOut + ")");
            }
            required += (long) fanOut * fanIn + fanOut;
            fanIn = fanOut;
        }
        if (required > available) {
            throw new IOException("Corrupt network file: expected at least " + required
                    + " parameters but found " + available);
        }
    }

    /** Returns a copy of {@code length} values of {@code source} starting at {@code from}. */
    private static double[] slice(double[] source, int from, int length) {
        final double[] copy = new double[length];
        System.arraycopy(source, from, copy, 0, length);
        return copy;
    }

    protected TrainedElementDetectionNetwork(int i, Layer[] layerArr) {
        this.layers = layerArr;
        this.npeaks = i;
    }

    /**
     * Builds the feature vector of {@code simpleSpectrum} using {@link #numberOfPeaks()} peaks and
     * feeds it through every layer in order.
     *
     * @param simpleSpectrum the spectrum to evaluate
     * @return the output of the last layer
     */
    public double[] predict(SimpleSpectrum simpleSpectrum) {
        double[] values = new FeatureVector(simpleSpectrum, this.npeaks).getFeatureVector(this.npeaks);
        for (Layer layer : this.layers) {
            values = layer.activate(values);
        }
        return values;
    }
}
