package de.unijena.bioinf.canopus;

import de.unijena.bioinf.ChemistryBase.fp.MaskedFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.NPCFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.ChemistryBase.utils.FileUtils;
import de.unijena.bioinf.canopus.dnn.ActivationFunction;
import de.unijena.bioinf.canopus.dnn.FullyConnectedLayer;
import de.unijena.bioinf.canopus.dnn.PlattLayer;
import java.io.BufferedWriter;
import java.io.Closeable;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPOutputStream;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

/* loaded from: input_file:de/unijena/bioinf/canopus/TensorflowModel.class */
/**
 * Wrapper around a TensorFlow SavedModel of the CANOPUS network.
 *
 * <p>Provides prediction, evaluation and training on {@link TrainingBatch}es, Platt-parameter
 * estimation, and export of the trained weights (TensorFlow checkpoint, plain-text matrices and a
 * gzipped {@link Canopus} dump). Instances own native TensorFlow resources (graph, session and a few
 * feed tensors) and must be {@link #close() closed}.
 */
public class TensorflowModel implements AutoCloseable, Closeable {
    /** Graph operation producing the class predictions (thresholded at 0 during evaluation). */
    protected static final String OUTPUT_LAYER = "final_output";

    /** Number of extra training rounds run before export when {@code z3} is set. */
    private static final int FINAL_TRAINING_ROUNDS = 200;

    /** First batch id passed to {@code TrainingData.generateBatch(int, ...)} during those rounds. */
    private static final int FIRST_GENERATED_BATCH_ID = 35000;

    /** Names of the weight/bias variables in graph layer order; filled by {@link #readTrainableLayers()}. */
    protected String[] TRAINABLE_VARIABLES;
    /** One-element float tensor fed to the {@code regstren} placeholder; mirrors {@link #regularizerStrength}. */
    protected Tensor regStren;
    /** Layer counts used to split the exported layers when rebuilding a {@link Canopus} model. */
    protected static int nformulaLayers = 2;
    protected static int nplattLayers = 1;
    protected static int ninnerLayers = 2;
    /** 1 if the graph contains {@code npc/weights} and {@code npc/biases}, 0 otherwise. */
    protected int npcLayers;
    protected final Graph graph;
    protected final Session session;
    /** Whether the graph has an {@code in_training} placeholder that has to be fed. */
    protected final boolean HAS_TRAINING_FLAG;
    /** Name of the loss operation detected in the graph. */
    protected String loss;
    /** Name of the optimizer operation detected in the graph. */
    protected String optimizer;
    protected Tensor in_training_tensor;
    protected Tensor not_in_training_tensor;
    /** Number of {@code fully_connected*} layers found in the graph. */
    protected int numberOfLayers = 0;
    protected float regularizerStrength = 0.0f;

    /** Restores the graph variables to the values captured by {@link #feedWeightMatrices(Canopus)}. */
    public interface Resetter {
        void resetWeights();
    }

    /**
     * Loads the SavedModel in {@code file} and detects its trainable variables, loss and optimizer.
     *
     * @throws RuntimeException if no known loss function or optimizer operation exists in the graph;
     *     native resources acquired so far are released before the exception propagates
     * @throws IllegalArgumentException if the activation function cannot be detected
     */
    public TensorflowModel(File file) {
        SavedModelBundle bundle = SavedModelBundle.load(file.getAbsolutePath(), new String[0]);
        this.graph = bundle.graph();
        this.session = bundle.session();
        this.in_training_tensor = Tensor.create(true);
        this.regStren = Tensor.create(new float[]{0.0f});
        this.not_in_training_tensor = Tensor.create(false);
        this.HAS_TRAINING_FLAG = this.graph.operation("in_training") != null;
        try {
            readTrainableLayers();
            System.out.println("IN TRAINING? " + this.HAS_TRAINING_FLAG);
            System.out.println("Use Activiation: " + getActivationFunction().getClass().getSimpleName());
            System.out.println("Number of Layers: " + this.numberOfLayers);
            this.loss = detectLoss();
            this.optimizer = detectOptimizer();
        } catch (RuntimeException e) {
            // Without this, a model with an unknown loss/optimizer leaks the session, graph and tensors.
            releaseResources();
            throw e;
        }
    }

    /** Returns the first loss operation present in the graph, checked in a fixed priority order. */
    private String detectLoss() {
        String[] candidates = {"hinge_loss/value", "sigmoid_cross_entropy_loss/value", "loss"};
        for (String candidate : candidates) {
            if (this.graph.operation(candidate) != null) {
                return candidate;
            }
        }
        throw new RuntimeException("Unknown loss function!");
    }

    /** Returns the optimizer operation present in the graph ({@code Adam} preferred over {@code Momentum}). */
    private String detectOptimizer() {
        if (this.graph.operation("Adam") != null) {
            return "Adam";
        }
        if (this.graph.operation("Momentum") != null) {
            return "Momentum";
        }
        throw new RuntimeException("Unknown optimizer");
    }

    public float getRegularizerStrength() {
        return this.regularizerStrength;
    }

    /** Sets the regularizer strength and replaces the feed tensor when the value changes. */
    public void setRegularizerStrength(float f) {
        if (this.regularizerStrength != f) {
            this.regularizerStrength = f;
            this.regStren.close();
            this.regStren = Tensor.create(new float[]{f});
        }
    }

    /**
     * Collects the names of all weight/bias variables of the graph.
     *
     * <p>Layers are named {@code fully_connected}, {@code fully_connected_1}, {@code fully_connected_2},
     * ... up to the first missing one; an optional {@code npc} layer is appended last.
     */
    protected void readTrainableLayers() {
        List<String> variables = new ArrayList<>();
        variables.add("fully_connected/weights");
        variables.add("fully_connected/biases");
        this.numberOfLayers = 1;
        while (this.numberOfLayers < Integer.MAX_VALUE) {
            String weights = "fully_connected_" + this.numberOfLayers + "/weights";
            if (this.graph.operation(weights) == null) {
                break;
            }
            variables.add(weights);
            variables.add("fully_connected_" + this.numberOfLayers + "/biases");
            this.numberOfLayers++;
        }
        if (this.graph.operation("npc/weights") != null && this.graph.operation("npc/biases") != null) {
            variables.add("npc/weights");
            variables.add("npc/biases");
            System.out.println("USE NPC in DNN");
            this.npcLayers = 1;
        }
        this.TRAINABLE_VARIABLES = variables.toArray(new String[variables.size()]);
    }

    /**
     * Adds {@code Assign} operations that copy the weights of {@code canopus} into the graph variables.
     *
     * <p>The layers are taken in the order formula, fingerprint, inner, output and matched
     * positionally against {@link #TRAINABLE_VARIABLES}. NOTE(review): the operation names are fixed
     * ({@code Assign/<var>}, {@code MyConst/<i>}), so calling this twice on the same graph presumably
     * fails with duplicate names — confirm.
     *
     * @return a {@link Resetter} that runs the assign operations on demand
     * @throws IllegalStateException if {@code canopus} provides fewer weight/bias arrays than the
     *     graph has trainable variables
     */
    public Resetter feedWeightMatrices(Canopus canopus) {
        List<float[]> values = new ArrayList<>();
        List<long[]> shapes = new ArrayList<>();
        for (FullyConnectedLayer layer : canopus.formulaLayers) {
            collectLayer(layer, values, shapes);
        }
        for (FullyConnectedLayer layer : canopus.fingerprintLayers) {
            collectLayer(layer, values, shapes);
        }
        for (FullyConnectedLayer layer : canopus.innerLayers) {
            collectLayer(layer, values, shapes);
        }
        collectLayer(canopus.outputLayer, values, shapes);
        // Fail before touching the graph instead of half-way through with an IndexOutOfBoundsException.
        if (values.size() < this.TRAINABLE_VARIABLES.length) {
            throw new IllegalStateException("Canopus model provides " + values.size()
                    + " weight/bias arrays but the graph has " + this.TRAINABLE_VARIABLES.length
                    + " trainable variables");
        }
        final List<Operation> assignOps = new ArrayList<>(this.TRAINABLE_VARIABLES.length);
        for (int i = 0; i < this.TRAINABLE_VARIABLES.length; i++) {
            String variable = this.TRAINABLE_VARIABLES[i];
            Tensor value = Tensor.create(shapes.get(i), FloatBuffer.wrap(values.get(i)));
            try {
                Operation constant = this.graph.opBuilder("Const", "MyConst/" + i)
                        .setAttr("dtype", value.dataType())
                        .setAttr("value", value)
                        .build();
                assignOps.add(this.graph.opBuilder("Assign", "Assign/" + variable)
                        .addInput(this.graph.operation(variable).output(0))
                        .addInput(constant.output(0))
                        .build());
            } finally {
                value.close();
            }
        }
        return new Resetter() {
            @Override
            public void resetWeights() {
                for (Operation assign : assignOps) {
                    // The fetched tensor is not needed, but it must be released.
                    closeAll(session.runner().fetch(assign.output(0)).run());
                }
            }
        };
    }

    /** Appends weights, biases and their shapes ([in, out] and [out]) of one layer. */
    private static void collectLayer(FullyConnectedLayer layer, List<float[]> values, List<long[]> shapes) {
        values.add(layer.getWeightMatrixCopy());
        values.add(layer.getBiasVectorCopy());
        shapes.add(new long[]{layer.getInputSize(), layer.getOutputSize()});
        shapes.add(new long[]{layer.getOutputSize()});
    }

    /** Runs the network in inference mode; the caller owns (and must close) the returned tensor. */
    public Tensor predictTensor(TrainingBatch trainingBatch) {
        return (Tensor) feedTraining(this.session.runner(), false)
                .feed("input_platts", trainingBatch.platts)
                .feed("input_formulas", trainingBatch.formulas)
                .fetch(OUTPUT_LAYER, 0)
                .run()
                .get(0);
    }

    /** Feeds the {@code in_training} flag if the graph has one; otherwise returns {@code runner} unchanged. */
    protected Session.Runner feedTraining(Session.Runner runner, boolean z) {
        if (this.HAS_TRAINING_FLAG) {
            return runner.feed("in_training", 0, z ? this.in_training_tensor : this.not_in_training_tensor);
        }
        return runner;
    }

    /** Returns the class predictions for {@code trainingBatch} as a rows-by-classes matrix. */
    public float[][] predict(TrainingBatch trainingBatch) {
        Tensor output = predictTensor(trainingBatch);
        try {
            return toMatrix(output);
        } finally {
            output.close();
        }
    }

    /** Returns the {@code npc_output} predictions for {@code trainingBatch} as a matrix. */
    public float[][] predictNPC(TrainingBatch trainingBatch) {
        return fetchMatrix(trainingBatch, "npc_output");
    }

    /** Runs {@code outputOp} in inference mode on {@code batch} and copies its 2-D result. */
    private float[][] fetchMatrix(TrainingBatch batch, String outputOp) {
        Tensor output = (Tensor) feedTraining(this.session.runner(), false)
                .feed("input_platts", batch.platts)
                .feed("input_formulas", batch.formulas)
                .fetch(outputOp, 0)
                .run()
                .get(0);
        try {
            return toMatrix(output);
        } finally {
            output.close();
        }
    }

    /** Copies a 2-D float tensor into a Java matrix. */
    private static float[][] toMatrix(Tensor tensor) {
        long[] shape = tensor.shape();
        float[][] matrix = new float[(int) shape[0]][(int) shape[1]];
        tensor.copyTo(matrix);
        return matrix;
    }

    /** Closes every tensor of a {@code Session.Runner.run()} result list. */
    private static void closeAll(List<?> tensors) {
        for (Object tensor : tensors) {
            ((Tensor) tensor).close();
        }
    }

    /** Reads the first {@code count} results as scalars and closes all result tensors. */
    private static double[] readScalarsAndClose(List<?> results, int count) {
        try {
            double[] values = new double[count];
            for (int k = 0; k < count; k++) {
                values[k] = ((Tensor) results.get(k)).floatValue();
            }
            return values;
        } finally {
            closeAll(results);
        }
    }

    private PredictionPerformance[] evaluatePerformance(TrainingBatch trainingBatch) {
        return evaluateOutput(trainingBatch, OUTPUT_LAYER, trainingBatch.labels);
    }

    private PredictionPerformance[] evaluateNPCPerformance(TrainingBatch trainingBatch) {
        return evaluateOutput(trainingBatch, "npc_output", trainingBatch.npcLabels);
    }

    /**
     * Computes per-column prediction performance of {@code outputOp} against {@code labelTensor}.
     * Both labels and predictions count as positive when {@code >= 0}.
     */
    private PredictionPerformance[] evaluateOutput(TrainingBatch batch, String outputOp, Tensor labelTensor) {
        float[][] predictions = fetchMatrix(batch, outputOp);
        float[][] labels = toMatrix(labelTensor);
        // An empty batch yields a [0][k] matrix, so take the column count from the label shape then.
        int columns = predictions.length > 0 ? predictions[0].length : (int) labelTensor.shape()[1];
        PredictionPerformance.Modify[] counters = new PredictionPerformance.Modify[columns];
        for (int col = 0; col < columns; col++) {
            counters[col] = new PredictionPerformance(0.0d, 0.0d, 0.0d, 0.0d, 0.0d).modify();
        }
        for (int row = 0; row < predictions.length; row++) {
            float[] truth = labels[row];
            for (int col = 0; col < truth.length; col++) {
                counters[col].update(truth[col] >= 0.0f, predictions[row][col] >= 0.0f);
            }
        }
        PredictionPerformance[] result = new PredictionPerformance[columns];
        for (int col = 0; col < columns; col++) {
            result[col] = counters[col].done();
        }
        return result;
    }

    public Report evaluate(TrainingBatch trainingBatch) {
        return new Report(evaluatePerformance(trainingBatch));
    }

    public Report evaluateNPC(TrainingBatch trainingBatch) {
        return new Report(evaluateNPCPerformance(trainingBatch));
    }

    /**
     * Evaluates the batch and splits the per-column performances.
     *
     * <p>The last {@code list.size()} output columns are treated as fingerprint properties.
     *
     * @param iArr indices into the fingerprint columns to report separately
     * @return {class report, fingerprint report, report of the selected fingerprint columns}
     */
    public Report[] evaluateWithFingerprints(TrainingBatch trainingBatch, List<DummyMolecularProperty> list, int[] iArr) {
        PredictionPerformance[] all = evaluatePerformance(trainingBatch);
        int classCount = all.length - list.size();
        PredictionPerformance[] classes = Arrays.copyOfRange(all, 0, classCount);
        PredictionPerformance[] fingerprints = Arrays.copyOfRange(all, classCount, all.length);
        PredictionPerformance[] selected = new PredictionPerformance[iArr.length];
        for (int k = 0; k < iArr.length; k++) {
            selected[k] = fingerprints[iArr[k]];
        }
        return new Report[]{new Report(classes), new Report(fingerprints), new Report(selected)};
    }

    /**
     * Detects the hidden-layer activation from the first layer's operations.
     *
     * @throws IllegalArgumentException if neither Relu, Tanh nor Selu is found
     */
    public ActivationFunction getActivationFunction() {
        if (this.graph.operation("fully_connected/Relu") != null) {
            return new ActivationFunction.ReLu();
        }
        if (this.graph.operation("fully_connected/Tanh") != null) {
            return new ActivationFunction.Tanh();
        }
        if (this.graph.operation("fully_connected/Selu") != null) {
            return new ActivationFunction.SELU();
        }
        throw new IllegalArgumentException("Unknown activation function");
    }

    public double[] train(TrainingBatch trainingBatch) {
        return train(trainingBatch.platts, trainingBatch.formulas, trainingBatch.labels);
    }

    /**
     * Runs one optimizer step.
     *
     * @return {loss, regularization}
     */
    public double[] train(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        List<?> results = feedTraining(this.session.runner(), true)
                .feed("input_platts", tensor)
                .feed("input_formulas", tensor2)
                .feed("input_labels", tensor3)
                .feed("regstren", this.regStren)
                .fetch(this.loss, 0)
                .fetch("regularization", 0)
                .fetch(this.optimizer)
                .run();
        return readScalarsAndClose(results, 2);
    }

    /**
     * Runs one optimizer step and also fetches {@code gradient_sum}.
     *
     * @return {loss, regularization, gradient sum}
     */
    public double[] trainWithGradient(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        List<?> results = feedTraining(this.session.runner(), true)
                .feed("input_platts", tensor)
                .feed("input_formulas", tensor2)
                .feed("input_labels", tensor3)
                .feed("regstren", this.regStren)
                .fetch(this.loss, 0)
                .fetch("regularization", 0)
                .fetch("gradient_sum", 0)
                .fetch(this.optimizer)
                .run();
        return readScalarsAndClose(results, 3);
    }

    /**
     * Runs one {@code npc_op} training step.
     *
     * @return {npc loss, class loss, npc regularization}
     */
    public double[] train_npc(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        List<?> results = feedTraining(this.session.runner(), true)
                .feed("input_platts", tensor)
                .feed("input_formulas", tensor2)
                .feed("input_labels", tensor3)
                .feed("npc_labels", tensor4)
                .feed("regstren", this.regStren)
                .fetch("npc_loss", 0)
                .fetch(this.loss, 0)
                .fetch("npc_regularization", 0)
                .fetch("npc_op")
                .run();
        return readScalarsAndClose(results, 3);
    }

    /** Same as the tagged overload, using the tag {@code "final"} if {@code z3} else {@code "notrained"}. */
    public void saveWithoutPlattEstimate(TrainingData trainingData, int i, boolean z, boolean z2, boolean z3, double[] dArr, double[] dArr2, double[] dArr3, double[] dArr4) throws IOException {
        saveWithoutPlattEstimate(z3 ? "final" : "notrained", trainingData, i, z, z2, z3, dArr, dArr2, dArr3, dArr4);
    }

    /**
     * Saves the model using previously estimated Platt parameters.
     *
     * <p>The TensorFlow checkpoint is always written. It is written before the optional extra
     * training, so it does not include that training.
     *
     * @param str tag used in the directory name {@code saved_model_<str>_<i>}
     * @param i model id; fed to {@code model_id} and used in the output file names
     * @param z write every trainable variable to {@code <index>.matrix} in that directory
     * @param z2 build a {@link Canopus} model and dump it to {@code canopus_[final_]<i>.data.gz}
     * @param z3 first run extra training rounds on all compounds (affects only the Canopus dump)
     * @param dArr first class Platt parameter array (from {@code plattEstimate(...)[0]})
     * @param dArr2 second class Platt parameter array
     * @param dArr3 first NPC Platt parameter array; only used when {@code trainingData.isNPC()}
     * @param dArr4 second NPC Platt parameter array; only used when {@code trainingData.isNPC()}
     */
    public void saveWithoutPlattEstimate(String str, TrainingData trainingData, int i, boolean z, boolean z2, boolean z3, double[] dArr, double[] dArr2, double[] dArr3, double[] dArr4) throws IOException {
        File directory = new File("saved_model_" + str + "_" + i);
        if (!directory.exists()) {
            directory.mkdirs();
        }
        saveCheckpoint(i);
        if (z3) {
            trainOnAllCompounds(trainingData);
        }
        List<float[][]> matrices = new ArrayList<>();
        List<float[]> vectors = new ArrayList<>();
        if (z || z2) {
            exportVariables(directory, z, z2, matrices, vectors);
        }
        if (z2) {
            Canopus canopus = assembleCanopus(trainingData, matrices, vectors, dArr, dArr2, dArr3, dArr4);
            File target = new File("canopus_" + (z3 ? "final_" : "") + i + ".data.gz");
            try (FileOutputStream fileOut = new FileOutputStream(target);
                 GZIPOutputStream gzipOut = new GZIPOutputStream(fileOut)) {
                canopus.dump(gzipOut);
            }
        }
    }

    /** Runs the graph's {@code save_model_with_id} / {@code save/control_dependency} checkpoint ops. */
    private void saveCheckpoint(int modelId) {
        Tensor idTensor = Tensor.create(new long[0], IntBuffer.wrap(new int[]{modelId}));
        try {
            Tensor savePath = (Tensor) this.session.runner()
                    .fetch("save_model_with_id", 0)
                    .feed("model_id", idTensor)
                    .run()
                    .get(0);
            try {
                closeAll(this.session.runner()
                        .fetch("save/control_dependency", 0)
                        .feed("save/Const", savePath)
                        .run());
            } finally {
                savePath.close();
            }
        } finally {
            idTensor.close();
        }
    }

    /**
     * Runs {@link #FINAL_TRAINING_ROUNDS} extra rounds on the crossvalidation (and independent)
     * compounds, plus one freshly generated batch per round and, for NPC models, an NPC step.
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // element type of the compound lists is declared in TrainingData
    private void trainOnAllCompounds(TrainingData trainingData) throws IOException {
        System.out.println("TRAIN MISSING COMPOUNDS (canopus.data only)");
        ArrayList compounds = new ArrayList();
        compounds.addAll(trainingData.crossvalidation);
        if (trainingData.independent != null) {
            compounds.addAll(trainingData.independent);
        }
        ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        try (TrainingBatch npcBatch = trainingData.generateNPCBatch(trainingData.npcInstances);
             TrainingBatch allCompounds = trainingData.generateBatch(compounds)) {
            int nextBatchId = FIRST_GENERATED_BATCH_ID;
            for (int round = 0; round < FINAL_TRAINING_ROUNDS; round++) {
                train(allCompounds);
                try (TrainingBatch generated = trainingData.generateBatch(nextBatchId++, null, pool)) {
                    train(generated);
                }
                if (trainingData.isNPC()) {
                    train_npc(npcBatch.platts, npcBatch.formulas, npcBatch.labels, npcBatch.npcLabels);
                }
            }
        } finally {
            // Always stop the pool; its non-daemon threads would otherwise keep the JVM alive.
            pool.shutdown();
        }
    }

    /**
     * Fetches every trainable variable.
     *
     * <p>When {@code writeFiles} is set, each variable is written to {@code <index>.matrix} in
     * {@code directory}. When {@code collect} is set, 2-D values go to {@code matrices} and all other
     * values (read as 1-D) go to {@code vectors}.
     */
    private void exportVariables(File directory, boolean writeFiles, boolean collect, List<float[][]> matrices, List<float[]> vectors) throws IOException {
        for (int index = 0; index < this.TRAINABLE_VARIABLES.length; index++) {
            Tensor value = (Tensor) this.session.runner().fetch(this.TRAINABLE_VARIABLES[index], 0).run().get(0);
            try {
                File target = new File(directory, index + ".matrix");
                if (value.shape().length == 2) {
                    float[][] matrix = toMatrix(value);
                    if (writeFiles) {
                        FileUtils.writeFloatMatrix(target, matrix);
                    }
                    if (collect) {
                        matrices.add(matrix);
                    }
                } else {
                    float[] vector = new float[(int) value.shape()[0]];
                    value.copyTo(vector);
                    if (writeFiles) {
                        writeVector(target, vector);
                    }
                    if (collect) {
                        vectors.add(vector);
                    }
                }
            } finally {
                value.close();
            }
        }
    }

    /** Writes one float per line. */
    private static void writeVector(File target, float[] vector) throws IOException {
        try (BufferedWriter writer = FileUtils.getWriter(target)) {
            for (float entry : vector) {
                writer.write(String.valueOf(entry));
                writer.newLine();
            }
        }
    }

    /**
     * Rebuilds a {@link Canopus} model from the exported weight matrices and bias vectors.
     *
     * <p>Hidden layers use the detected activation. The layer at index
     * {@code size - 1 - npcLayers} gets the identity activation. For NPC models, an additional
     * identity layer is built from the last matrix/vector pair.
     */
    private Canopus assembleCanopus(TrainingData trainingData, List<float[][]> matrices, List<float[]> vectors, double[] classPlatt0, double[] classPlatt1, double[] npcPlatt0, double[] npcPlatt1) {
        ActivationFunction hidden = getActivationFunction();
        int outputIndex = (matrices.size() - 1) - this.npcLayers;
        List<FullyConnectedLayer> layers = new ArrayList<>(matrices.size());
        for (int k = 0; k < matrices.size(); k++) {
            ActivationFunction activation = k == outputIndex ? new ActivationFunction.Identity() : hidden;
            layers.add(new FullyConnectedLayer(matrices.get(k), vectors.get(k), activation));
        }
        Iterator<FullyConnectedLayer> remaining = layers.iterator();
        FullyConnectedLayer[] formulaLayers = takeLayers(remaining, nformulaLayers);
        FullyConnectedLayer[] plattLayers = takeLayers(remaining, nplattLayers);
        FullyConnectedLayer[] innerLayers = takeLayers(remaining, ninnerLayers);
        FullyConnectedLayer outputLayer = remaining.next();

        double[] plattNorm = trainingData.plattNorm.clone();
        double[] plattScale;
        if (TrainingData.VECNORM_SCALING) {
            plattScale = trainingData.plattScale.clone();
        } else {
            plattScale = new double[trainingData.nplatts];
            Arrays.fill(plattScale, 1.0d);
        }

        boolean npc = trainingData.isNPC();
        MaskedFingerprintVersion npcVersion = npc ? MaskedFingerprintVersion.allowAll(NPCFingerprintVersion.get()) : null;
        FullyConnectedLayer npcLayer = npc
                ? new FullyConnectedLayer(matrices.get(matrices.size() - 1), vectors.get(vectors.size() - 1), new ActivationFunction.Identity())
                : null;
        PlattLayer npcPlatt = npc ? new PlattLayer(npcPlatt0, npcPlatt1) : null;

        return new Canopus(formulaLayers, plattLayers, innerLayers, outputLayer,
                new PlattLayer(classPlatt0, classPlatt1),
                trainingData.formulaNorm, trainingData.formulaScale, plattNorm, plattScale,
                trainingData.classyFireMask, (MaskedFingerprintVersion) null,
                npcVersion, npcLayer, npcPlatt);
    }

    /** Takes the next {@code count} layers from {@code layers}. */
    private static FullyConnectedLayer[] takeLayers(Iterator<FullyConnectedLayer> layers, int count) {
        FullyConnectedLayer[] taken = new FullyConnectedLayer[count];
        for (int k = 0; k < count; k++) {
            taken[k] = layers.next();
        }
        return taken;
    }

    public double[][] plattEstimate(TrainingData trainingData) {
        return plattEstimate(trainingData, true);
    }

    /**
     * Fits Platt parameters for every NPC output column.
     *
     * @param z forwarded to {@code TrainingData.fillUpWithTrainDataNPC}
     * @return {first parameter per column, second parameter per column} as returned by
     *     {@code PlattLayer.sigmoid_train}
     */
    public double[][] plattEstimateForNPC(TrainingData trainingData, boolean z) {
        try (TrainingBatch batch = trainingData.fillUpWithTrainDataNPC(z)) {
            return fitPlattParameters(predictNPC(batch), batch.npcLabels);
        }
    }

    /**
     * Fits Platt parameters for every class output column.
     *
     * @param z forwarded to {@code TrainingData.fillUpWithTrainData}
     * @return {first parameter per column, second parameter per column} as returned by
     *     {@code PlattLayer.sigmoid_train}
     */
    public double[][] plattEstimate(TrainingData trainingData, boolean z) {
        try (TrainingBatch batch = trainingData.fillUpWithTrainData(z)) {
            return fitPlattParameters(predict(batch), batch.labels);
        }
    }

    /**
     * Runs {@code PlattLayer.sigmoid_train} per column in parallel.
     *
     * <p>Expects a non-empty {@code predictions} matrix and a label tensor of the same shape.
     */
    private static double[][] fitPlattParameters(final float[][] predictions, Tensor labelTensor) {
        final int rows = predictions.length;
        final int columns = predictions[0].length;
        final float[][] labels = new float[rows][columns];
        labelTensor.copyTo(labels);
        ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        try {
            List<Future<double[]>> futures = new ArrayList<>(columns);
            for (int c = 0; c < columns; c++) {
                final int column = c;
                futures.add(pool.submit(new Callable<double[]>() {
                    @Override
                    public double[] call() throws Exception {
                        double[] scores = new double[rows];
                        double[] targets = new double[rows];
                        for (int r = 0; r < rows; r++) {
                            scores[r] = predictions[r][column];
                            targets[r] = labels[r][column];
                        }
                        return PlattLayer.sigmoid_train(scores, targets);
                    }
                }));
            }
            double[] first = new double[columns];
            double[] second = new double[columns];
            for (int c = 0; c < columns; c++) {
                double[] fitted = futures.get(c).get();
                first[c] = fitted[0];
                second[c] = fitted[1];
            }
            return new double[][]{first, second};
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        } finally {
            pool.shutdown();
        }
    }

    /** Estimates Platt parameters with {@code fillUp... (false)} and saves without the extra training. */
    public void saveWithPlattOnCrossval(TrainingData trainingData, int i, boolean z, boolean z2) throws IOException {
        double[] dArr = null;
        double[] dArr2 = null;
        double[] dArr3 = null;
        double[] dArr4 = null;
        if (z2) {
            double[][] classPlatt = plattEstimate(trainingData, false);
            dArr = classPlatt[0];
            dArr2 = classPlatt[1];
            if (trainingData.isNPC()) {
                double[][] npcPlatt = plattEstimateForNPC(trainingData, false);
                dArr3 = npcPlatt[0];
                dArr4 = npcPlatt[1];
            }
        }
        saveWithoutPlattEstimate(trainingData, i, z, z2, false, dArr, dArr2, dArr3, dArr4);
    }

    /** Estimates Platt parameters (when {@code z2}) and saves; {@code z3} enables the extra training. */
    public void save(TrainingData trainingData, int i, boolean z, boolean z2, boolean z3) throws IOException {
        double[] dArr = null;
        double[] dArr2 = null;
        double[] dArr3 = null;
        double[] dArr4 = null;
        if (z2) {
            double[][] classPlatt = plattEstimate(trainingData);
            dArr = classPlatt[0];
            dArr2 = classPlatt[1];
            if (trainingData.isNPC()) {
                double[][] npcPlatt = plattEstimateForNPC(trainingData, true);
                dArr3 = npcPlatt[0];
                dArr4 = npcPlatt[1];
            }
        }
        saveWithoutPlattEstimate(trainingData, i, z, z2, z3, dArr, dArr2, dArr3, dArr4);
    }

    /** Releases the feed tensors, the session and the graph. */
    @Override
    public void close() {
        releaseResources();
    }

    /** Non-overridable cleanup shared by {@link #close()} and the failing constructor path. */
    private void releaseResources() {
        this.regStren.close();
        this.in_training_tensor.close();
        this.not_in_training_tensor.close();
        try {
            this.session.close();
        } finally {
            // The graph comes from the SavedModelBundle and is otherwise never released.
            this.graph.close();
        }
    }

    /** Returns the current value of the {@code regularization} operation. */
    public float regularizerTerm() {
        List<?> results = this.session.runner().fetch("regularization", 0).run();
        try {
            return ((Tensor) results.get(0)).floatValue();
        } finally {
            closeAll(results);
        }
    }
}
