package de.unijena.bioinf.canopus;

import de.unijena.bioinf.ChemistryBase.chem.MolecularFormula;
import de.unijena.bioinf.ChemistryBase.fp.ArrayFingerprint;
import de.unijena.bioinf.ChemistryBase.fp.CdkFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.ClassyfireProperty;
import de.unijena.bioinf.ChemistryBase.fp.CustomFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.FPIter2;
import de.unijena.bioinf.ChemistryBase.fp.MaskedFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.ChemistryBase.fp.ProbabilityFingerprint;
import de.unijena.bioinf.ChemistryBase.fp.Tanimoto;
import de.unijena.bioinf.ChemistryBase.jobs.SiriusJobs;
import de.unijena.bioinf.ChemistryBase.ms.Ms2Experiment;
import de.unijena.bioinf.ChemistryBase.utils.FileUtils;
import de.unijena.bioinf.babelms.ms.JenaMsWriter;
import de.unijena.bioinf.canopus.TensorflowModel;
import de.unijena.bioinf.canopus.TrainingData;
import de.unijena.bioinf.fingerid.InputFeatures;
import de.unijena.bioinf.fingerid.KernelToNumpyConverter;
import de.unijena.bioinf.fingerid.Prediction;
import de.unijena.bioinf.fingerid.SpectralPreprocessor;
import de.unijena.bioinf.sirius.IdentificationResult;
import de.unijena.bioinf.sirius.Sirius;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.list.array.TShortArrayList;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.set.hash.TIntHashSet;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.logging.LogManager;
import java.util.regex.Pattern;
import org.tensorflow.Tensor;

/* loaded from: input_file:de/unijena/bioinf/canopus/Learn.class */
public class Learn {
    /**
     * Looks up the value of a command-line option.
     *
     * <p>The first argument starting with {@code str} is used. If it contains an {@code '='}, everything
     * after the first {@code '='} (trimmed) is the value, so values may themselves contain {@code '='}.
     * Otherwise the following argument is the value.
     *
     * @param strArr command-line arguments
     * @param str    option prefix, e.g. {@code "--independent"}
     * @return the option value, or {@code null} if the option is absent or is the last argument without
     *         an inline {@code =value}
     */
    public static String findArgWithValue(String[] strArr, String str) {
        for (int i = 0; i < strArr.length; i++) {
            String arg = strArr[i];
            if (!arg.startsWith(str)) {
                continue;
            }
            int eq = arg.indexOf('=');
            if (eq >= 0) {
                // split at the FIRST '=' only: regex values may contain '=' themselves
                return arg.substring(eq + 1).trim();
            }
            // value is the next argument; a trailing flag without value yields null instead of crashing
            return i + 1 < strArr.length ? strArr[i + 1] : null;
        }
        return null;
    }

    /**
     * Command-line entry point.
     *
     * <p>{@code --no-norm} flags are stripped first (see {@link #removeOpts}). The first remaining argument
     * selects the mode by prefix match ({@code startsWith}):
     * <ul>
     *   <li>{@code play-around modeldir model.tgz outputdir independentPattern}</li>
     *   <li>{@code evaluate modeldir model.tgz outputdir independentPattern}</li>
     *   <li>{@code continue <arg1> <arg2> [--independent pattern]} delegates to {@code continueModel}</li>
     *   <li>{@code prepare <arg>} delegates to {@code Prepare.prepare}</li>
     *   <li>{@code fix} delegates to {@code fix()}</li>
     *   <li>{@code decoy} draws 50 decoy experiments and writes them to {@code decoys/NNN.ms}</li>
     *   <li>anything else: {@code <tensorflow model> <number> [--independent pattern]} trains a model</li>
     * </ul>
     */
    public static void main(String[] strArr) {
        String[] removeOpts = removeOpts(strArr);
        if (removeOpts.length == 0) {
            System.err.println("Usage:\nplay-around|evaluate modeldir model.tgz outputdir independentPattern\n"
                    + "continue|prepare|fix|decoy ...\n"
                    + "<tensorflow model> <number> [--independent pattern]");
            return;
        }
        String mode = removeOpts[0];
        if (mode.startsWith("play-around")) {
            try {
                if (removeOpts.length != 5) {
                    System.err.println("Usage:\nplay-around modeldir model.tgz outputdir independentPattern");
                } else {
                    playAround(new File(removeOpts[1]), new File(removeOpts[2]), new File(removeOpts[3]), removeOpts[4]);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
            System.exit(0);
            return;
        }
        if (mode.startsWith("evaluate")) {
            try {
                if (removeOpts.length != 5) {
                    System.err.println("Usage:\nevaluate modeldir model.tgz outputdir independentPattern");
                } else {
                    getDecisionValueOutputAndPerformance(new File(removeOpts[1]), new File(removeOpts[2]), new File(removeOpts[3]), removeOpts[4]);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
            System.exit(0);
            return;
        }
        if (mode.startsWith("continue")) {
            try {
                continueModel(new File(removeOpts[1]), new File(removeOpts[2]), findArgWithValue(removeOpts, "--independent"));
            } catch (IOException e) {
                e.printStackTrace();
            }
            System.exit(0);
            return;
        }
        if (mode.startsWith("prepare")) {
            System.out.println("Prepare learning");
            try {
                Prepare.prepare(new File(removeOpts[1]));
            } catch (IOException e) {
                e.printStackTrace();
            }
            return;
        }
        if (mode.startsWith("fix")) {
            fix();
            return;
        }
        if (mode.startsWith("decoy")) {
            runDecoyGeneration();
            return;
        }
        runTraining(removeOpts);
    }

    /**
     * Draws 50 decoy experiments from {@code fingerid.data}, identifies each with SIRIUS, prints
     * statistics about the best identification, and writes each experiment to {@code decoys/NNN.ms}.
     * Shuts the global SIRIUS job manager down afterwards.
     */
    private static void runDecoyGeneration() {
        try {
            Prediction prediction = Prediction.loadFromFile(new File("fingerid.data"));
            DecoySpectrumGenerator generator = new DecoySpectrumGenerator(prediction);
            for (int i = 0; i < 50; i++) {
                Ms2Experiment experiment = generator.drawExperiment();
                Sirius.SiriusIdentificationJob job = generator.sirius.makeIdentificationJob(experiment);
                SiriusJobs.getGlobalJobManager().submitJob(job);
                List<?> results = (List<?>) job.takeResult();
                // an empty result list is treated like a null best hit instead of throwing
                IdentificationResult best = (results == null || results.isEmpty()) ? null : (IdentificationResult) results.get(0);
                if (best == null) {
                    System.out.println("Empty spectrum");
                    continue;
                }
                InputFeatures features = SpectralPreprocessor.preprocessFromSirius(generator.sirius, best, experiment);
                ProbabilityFingerprint fingerprint = prediction.predictProbabilityFingerprint(features);
                System.out.println("---------------------------");
                System.out.println("Tree: " + features.tree.numberOfVertices());
                System.out.println("Spectrum: " + features.spectrum.size());
                int present = 0;
                Iterator<?> it = fingerprint.presentFingerprints().iterator();
                while (it.hasNext()) {
                    it.next(); // advancing the iterator is required, otherwise this loop never terminates
                    present++;
                }
                System.out.println("Fingerprint: " + present);
                System.out.println("..............");
                System.out.println(Arrays.toString(fingerprint.toProbabilityArray()));
                try (BufferedWriter writer = FileUtils.getWriter(new File(String.format(Locale.US, "decoys/%03d.ms", i)))) {
                    new JenaMsWriter().write(writer, experiment);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        try {
            SiriusJobs.getGlobalJobManager().shutdown();
        } catch (InterruptedException e) {
            e.printStackTrace();
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Default mode: trains the TensorFlow model at {@code args[0]} for steps 0..30000 (inclusive).
     *
     * <ul>
     *   <li>Every 200 steps, {@code reportStuff} is called on the reference batches.</li>
     *   <li>Every 2000 steps, the model is evaluated on the crossvalidation batch and saved when its MCC
     *       improves. The simulated and resampled reference batches are then regenerated.</li>
     *   <li>If more than 1000 steps ran, a temporary model is saved on exit, also when an exception
     *       aborts training.</li>
     * </ul>
     *
     * @param args {@code [tensorflow model, number passed to save(), ...options]}
     */
    private static void runTraining(String[] args) {
        final int lastStep = 30000;
        int modelNumber = Integer.parseInt(args[1]);
        TrainingData.GROW = 1;
        String independentPattern = findArgWithValue(args, "--independent");
        try (TensorflowModel tensorflowModel = new TensorflowModel(new File(args[0]))) {
            TrainingData trainingData = new TrainingData(new File("."), independentPattern == null ? null : Pattern.compile(independentPattern));
            BatchGenerator batchGenerator = new BatchGenerator(trainingData, 20);
            int slash = tensorflowModel.loss.indexOf('/');
            System.out.println("Loss function: " + (slash < 0 ? tensorflowModel.loss : tensorflowModel.loss.substring(0, slash)));
            System.out.println("PLATT CENTERING: " + String.valueOf(TrainingData.PLATT_CENTERING));
            System.out.println("PLATT SCALING: " + String.valueOf(TrainingData.SCALE_BY_STD));
            System.out.println("VECTOR NORM: " + String.valueOf(TrainingData.VECNORM_SCALING));

            List<Thread> workers = new ArrayList<>();
            for (int w = 0; w < 2; w++) {
                workers.add(new Thread(batchGenerator));
            }
            for (Thread worker : workers) {
                worker.start();
            }
            TrainingBatch simulated = batchGenerator.poll(0);

            // independent instances whose InChIKey does not occur in the crossvalidation set
            List<EvaluationInstance> novelIndependent = new ArrayList<>();
            if (trainingData.independent != null) {
                HashSet<String> trainingKeys = new HashSet<>();
                for (EvaluationInstance instance : trainingData.crossvalidation) {
                    trainingKeys.add(instance.compound.inchiKey);
                }
                for (EvaluationInstance instance : trainingData.independent) {
                    if (!trainingKeys.contains(instance.compound.inchiKey)) {
                        novelIndependent.add(instance);
                    }
                }
            }
            TrainingBatch crossvalBatch = trainingData.generateBatch(trainingData.crossvalidation);
            TrainingBatch independentBatch = trainingData.independent == null ? null : trainingData.generateBatch(trainingData.independent);
            TrainingBatch novelBatch = trainingData.independent == null ? null : trainingData.generateBatch(novelIndependent);

            System.out.println("Resample");
            System.out.flush();
            TrainingBatch resampled = trainingData.resampleMultithreaded(trainingData.crossvalidation, (inst, idx) -> TrainingData.SamplingStrategy.CONDITIONAL);
            System.out.println("Start.");
            System.out.flush();

            double bestMcc = Double.NEGATIVE_INFINITY;
            int stepsRun = 0;
            try {
                for (int step = 0; step <= lastStep; step++) {
                    try (TrainingBatch batch = batchGenerator.poll(step)) {
                        if (step == 0) {
                            System.out.println("Batch size: ~" + batch.platts.shape()[0]);
                        }
                        stepsRun++;
                        long start = System.currentTimeMillis();
                        double[] result = tensorflowModel.train(batch);
                        double seconds = (System.currentTimeMillis() - start) / 1000.0d;
                        // NOTE(review): result[0] = loss, result[1] = l2 norm, inferred from the log labels
                        System.out.println(step + ".)\tloss = " + result[0] + "\tl2 norm = " + result[1] + "\t (" + seconds + " s)");
                        if (step % 200 == 0) {
                            reportStuff(Arrays.asList(simulated, crossvalBatch, resampled, independentBatch, novelBatch), Arrays.asList("simulated", "crossval", "resampled", "indep", "indepNovel"), tensorflowModel, step);
                        }
                        if (step > 0 && step % 2000 == 0) {
                            Report report = tensorflowModel.evaluate(crossvalBatch);
                            if (report.mcc > bestMcc) {
                                System.out.println("############ SAVE MODEL ##############");
                                tensorflowModel.save(trainingData, modelNumber, true, true, step >= lastStep);
                                bestMcc = report.mcc;
                            }
                            // refresh the reference batches unless this is the final evaluation
                            simulated.close();
                            if (step < lastStep * TrainingData.GROW) {
                                simulated = batchGenerator.poll(0);
                                resampled.close();
                                resampled = trainingData.resampleMultithreaded(trainingData.crossvalidation, (inst, idx) -> TrainingData.SamplingStrategy.CONDITIONAL);
                            }
                        }
                    }
                }
            } finally {
                // keep a temporary snapshot of longer runs, also when training aborts with an exception
                if (stepsRun > 1000) {
                    System.out.println("############ SAVE MODEL temp ##############");
                    tensorflowModel.save(trainingData, -modelNumber, true, true, false);
                }
            }
            batchGenerator.stop();
            crossvalBatch.close();
            for (Thread worker : workers) {
                worker.interrupt();
            }
            System.out.println("SHUTDOWN");
            resampled.close();
            if (independentBatch != null) {
                independentBatch.close();
            }
            if (novelBatch != null) {
                novelBatch.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Strips global flags from the command line.
     *
     * <p>Every argument starting with {@code --no-norm} is removed. Its presence disables all input
     * normalizations in {@link TrainingData}: std scaling, max scaling, Platt centering and vector-norm
     * scaling.
     *
     * @param strArr raw command-line arguments
     * @return the remaining arguments, in their original order
     */
    private static String[] removeOpts(String[] strArr) {
        final List<String> kept = new ArrayList<>(strArr.length);
        for (final String argument : strArr) {
            if (!argument.startsWith("--no-norm")) {
                kept.add(argument);
                continue;
            }
            TrainingData.SCALE_BY_STD = false;
            TrainingData.SCALE_BY_MAX = false;
            TrainingData.PLATT_CENTERING = false;
            TrainingData.VECNORM_SCALING = false;
        }
        return kept.toArray(new String[0]);
    }

    /**
     * Dumps the first crossvalidation instance and the model's prediction for it into {@code example/}.
     *
     * <p>Writes {@code formula.matrix}, {@code platts.matrix}, {@code labels.matrix},
     * {@code prediction.matrix} and {@code example.txt}, which contains the InChIKey, formula and
     * fingerprint, tab-separated. Then prints the prediction performance of this single example, with a
     * label counted as positive when its value is {@code > 0}.
     *
     * @throws IOException if writing any of the files fails
     */
    private static void writeExample(TensorflowModel tensorflowModel, TrainingData trainingData) throws IOException {
        EvaluationInstance instance = trainingData.crossvalidation.get(0);
        try (TrainingBatch batch = trainingData.generateBatch(Arrays.asList(instance))) {
            float[][] predict = tensorflowModel.predict(batch);
            if (!new File("example").exists()) {
                new File("example").mkdir();
            }
            // the batch holds exactly one row, so each tensor is copied into a 1 x width matrix
            float[][] formulas = (float[][]) batch.formulas.copyTo(new float[1][(int) batch.formulas.shape()[1]]);
            float[][] platts = (float[][]) batch.platts.copyTo(new float[1][(int) batch.platts.shape()[1]]);
            float[][] labels = (float[][]) batch.labels.copyTo(new float[1][(int) batch.labels.shape()[1]]);
            new KernelToNumpyConverter().writeToFile(new File("example/formula.matrix"), formulas);
            new KernelToNumpyConverter().writeToFile(new File("example/platts.matrix"), platts);
            new KernelToNumpyConverter().writeToFile(new File("example/labels.matrix"), labels);
            new KernelToNumpyConverter().writeToFile(new File("example/prediction.matrix"), predict);
            // try-with-resources: the writer is closed even if one of the writes fails
            try (BufferedWriter writer = KernelToNumpyConverter.getWriter(new File("example/example.txt"))) {
                writer.write(instance.compound.inchiKey);
                writer.write('\t');
                writer.write(instance.compound.formula.toString());
                writer.write('\t');
                writer.write(instance.fingerprint.toTabSeparatedString());
                writer.newLine();
            }
            PredictionPerformance.Modify modify = new PredictionPerformance().modify();
            for (int i = 0; i < labels[0].length; i++) {
                modify.update(labels[0][i] > 0.0f, predict[0][i] > 0.0f);
            }
            System.out.println("Example: " + modify.done().toString());
        }
    }

    /**
     * Builds a per-property prediction-performance report for a list of instances.
     *
     * <p>A property is evaluated if its absolute index is allowed by {@code maskedFingerprintVersion} and
     * also occurs among the {@link DummyMolecularProperty} entries of {@code customFingerprintVersion}.
     * For each evaluated property, {@code compound.fingerprint.isSet} (truth) is compared with
     * {@code fingerprint.isSet} (prediction) over all instances.
     *
     * @param list                     instances to evaluate
     * @param customFingerprintVersion version whose properties map absolute to relative indices
     * @param maskedFingerprintVersion mask restricting which absolute indices are evaluated
     * @param tIntHashSet              not read by this method
     * @param tIntHashSet2             output: receives the relative index of every evaluated property
     * @return a report with one performance entry per evaluated property, ordered by ascending absolute index
     */
    private static Report generateReportFromTrainData(List<EvaluationInstance> list, CustomFingerprintVersion customFingerprintVersion, MaskedFingerprintVersion maskedFingerprintVersion, TIntHashSet tIntHashSet, TIntHashSet tIntHashSet2) {
        final TIntIntHashMap absoluteToRelative = new TIntIntHashMap();
        final int propertyCount = customFingerprintVersion.size();
        for (int k = 0; k < propertyCount; ++k) {
            final DummyMolecularProperty property = (DummyMolecularProperty) customFingerprintVersion.getMolecularProperty(k);
            absoluteToRelative.put(property.absoluteIndex, property.relativeIndex);
        }

        // absolute indices that are allowed by the mask AND known to the custom version
        final TIntHashSet evaluated = new TIntHashSet();
        for (int allowed : maskedFingerprintVersion.allowedIndizes()) {
            evaluated.add(allowed);
        }
        evaluated.retainAll(absoluteToRelative.keySet());
        absoluteToRelative.retainEntries((absolute, relative) -> evaluated.contains(absolute));
        for (int rel : absoluteToRelative.values()) {
            tIntHashSet2.add(rel);
        }

        final int[] absoluteIndices = evaluated.toArray();
        Arrays.sort(absoluteIndices);
        final PredictionPerformance.Modify[] counters = new PredictionPerformance.Modify[absoluteIndices.length];
        for (int k = 0; k < counters.length; ++k) {
            counters[k] = new PredictionPerformance().modify();
        }

        for (EvaluationInstance instance : list) {
            final ArrayFingerprint truth = instance.compound.fingerprint;
            final ProbabilityFingerprint prediction = instance.fingerprint;
            for (int k = 0; k < absoluteIndices.length; ++k) {
                final int absolute = absoluteIndices[k];
                counters[k].update(truth.isSet(absolute), prediction.isSet(absolute));
            }
        }

        final PredictionPerformance[] performances = new PredictionPerformance[counters.length];
        for (int k = 0; k < counters.length; ++k) {
            performances[k] = counters[k].done();
        }
        return new Report(performances);
    }

    /* JADX WARN: Type inference failed for: r0v136, types: [double[], double[][]] */
    public static void playAround(File file, File file2, File file3, String str) throws IOException {
        TrainingData trainingData = new TrainingData(new File("."), str != null ? Pattern.compile(str) : null);
        trainingData.generateBatch(trainingData.crossvalidation);
        trainingData.generateBatch(trainingData.independent);
        Canopus loadFromFile = Canopus.loadFromFile(file2);
        HashSet hashSet = new HashSet(Arrays.asList("1-benzopyrans", "Alkyl aryl ethers", "Benzene and substituted derivatives", "Benzenoids", "Benzopyrans", "Chemical entities", "Ethers", "Flavans", "Flavonoids", "Hydrazines and derivatives", "Hydrocarbon derivatives", "Organic compounds", "Organic nitrogen compounds", "Organic oxygen compounds", "Organoheterocyclic compounds", "Organonitrogen compounds", "Organooxygen compounds", "Organopnictogen compounds", "Organosulfur compounds", "Oxacyclic compounds", "Phenylpropanoids and polyketides", "Thiosemicarbazides", "Thiosemicarbazones"));
        TShortArrayList tShortArrayList = new TShortArrayList();
        for (int i : trainingData.classyFireMask.allowedIndizes()) {
            String name = trainingData.classyFireFingerprintVersion.getMolecularProperty(i).getName();
            if (hashSet.contains(name)) {
                System.out.println("has " + name);
                tShortArrayList.add((short) i);
            }
        }
        MolecularFormula parseOrThrow = MolecularFormula.parseOrThrow("C23H21N3OS");
        LabeledCompound labeledCompound = new LabeledCompound("SUZFIXDOJRIJQR", parseOrThrow, trainingData.fingerprintVersion.mask(new ArrayFingerprint(CdkFingerprintVersion.getDefault(), new short[]{9, 11, 36, 38, 44, 45, 47, 48, 56, 72, 197, 215, 328, 329, 333, 341, 349, 354, 356, 361, 404, 408, 413, 418, 423, 430, 434, 438, 439, 440, 442, 449, 452, 455, 458, 459, 465, 466, 467, 472, 474, 480, 485, 486, 492, 493, 494, 498, 503, 505, 506, 512, 513, 517, 518, 519, 522, 523, 524, 525, 526, 528, 529, 530, 537, 538, 539, 540, 542, 543, 546, 561, 706, 707, 712, 713, 714, 720, 721, 727, 783, 785, 787, 811, 812, 813, 814, 821, 827, 828, 860, 861, 866, 867, 868, 869, 872, 873, 874, 879, 880, 883, 884, 893, 894, 898, 899, 903, 909, 910, 912, 918, 920, 921, 922, 926, 933, 944, 946, 949, 958, 962, 963, 969, 970, 974, 975, 992, 995, 998, 1004, 1018, 1026, 1036, 1039, 1044, 1045, 1048, 1052, 1056, 1066, 1068, 1069, 1070, 1076, 1077, 1080, 1083, 1084, 1088, 1092, 1093, 1098, 1101, 1102, 1104, 1106, 1110, 1112, 1117, 1120, 1122, 1123, 1127, 1131, 1132, 1134, 1135, 1136, 1141, 1142, 1146, 1147, 1148, 1154, 1156, 1160, 1161, 1162, 1163, 1165, 1166, 1168, 1169, 1171, 1172, 1179, 1182, 1183, 1188, 1192, 1194, 1196, 1205, 1206, 1207, 1208, 1211, 1216, 1217, 1224, 1225, 1226, 1236, 1237, 1238, 1284, 1347, 1409, 1441, 1442, 1569, 1705, 1706, 2085, 2086, 2344, 2345, 2369, 2533, 2972, 2974, 3733, 3955, 3956, 4103, 4383, 4421, 4632, 4823, 4838, 5014, 5048, 5090, 5112, 5148, 5152, 5158, 5181, 5196, 5290, 5334, 5357, 5358, 5364, 5408, 5488, 5521, 5617, 5629, 5646, 5695, 5701, 5709, 5710, 5724, 5739, 6221, 6251, 6286, 6550, 6603, 6606, 6617, 6712, 6771, 6840, 6862, 6878, 6954, 7118, 7145, 7211, 7305, 7380, 7594, 7626, 7660, 7718, 7834, 7922, 8095, 8186, 8277, 8297, 8327, 8348, 8376, 8403, 8406, 8469, 8491, 8512, 8590, 8639, 8767, 8793})).asArray(), trainingData.classyFireMask.mask(new ArrayFingerprint(trainingData.classyFireFingerprintVersion, tShortArrayList.toArray())), 
Canopus.getFormulaFeatures(parseOrThrow), null);
        trainingData.normalizeVector(labeledCompound);
        TensorflowModel tensorflowModel = new TensorflowModel(file);
        try {
            tensorflowModel.feedWeightMatrices(loadFromFile).resetWeights();
            List<TrainingData.SamplingStrategy> asList = Arrays.asList(TrainingData.SamplingStrategy.CONDITIONAL, TrainingData.SamplingStrategy.INDEPENDENT, TrainingData.SamplingStrategy.TEMPLATE, TrainingData.SamplingStrategy.DISTURBED_TEMPLATE, TrainingData.SamplingStrategy.PERFECT, null);
            for (TrainingData.SamplingStrategy samplingStrategy : asList) {
                if (samplingStrategy != null) {
                    ?? r0 = new double[100];
                    r0[0] = labeledCompound.fingerprint.asProbabilistic().toProbabilityArray();
                    for (int i2 = 1; i2 < 100; i2++) {
                        r0[i2] = trainingData.sampleFingerprintVector(labeledCompound, samplingStrategy);
                    }
                    FileUtils.writeDoubleMatrix(new File(samplingStrategy.name() + ".matrix"), (double[][]) r0);
                }
            }
            ExecutorService newFixedThreadPool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
            for (TrainingData.SamplingStrategy samplingStrategy2 : asList) {
                System.out.println("-------------- " + (samplingStrategy2 == null ? "real" : samplingStrategy2.name()) + "-------------- ");
                TDoubleArrayList tDoubleArrayList = new TDoubleArrayList();
                TDoubleArrayList tDoubleArrayList2 = new TDoubleArrayList();
                TDoubleArrayList tDoubleArrayList3 = new TDoubleArrayList();
                ArrayList arrayList = new ArrayList(trainingData.crossvalidation);
                Collections.shuffle(arrayList);
                List<EvaluationInstance> subList = arrayList.subList(0, Math.min(10000, arrayList.size()));
                ArrayList arrayList2 = new ArrayList();
                for (EvaluationInstance evaluationInstance : subList) {
                    if (samplingStrategy2 == null) {
                        arrayList2.add(newFixedThreadPool.submit(() -> {
                            return evaluationInstance.fingerprint;
                        }));
                    } else {
                        arrayList2.add(newFixedThreadPool.submit(() -> {
                            return trainingData.sampleFingerprint(evaluationInstance.compound, samplingStrategy2);
                        }));
                    }
                }
                double d = 0.0d;
                for (int i3 = 0; i3 < subList.size(); i3++) {
                    EvaluationInstance evaluationInstance2 = (EvaluationInstance) subList.get(i3);
                    try {
                        ProbabilityFingerprint probabilityFingerprint = (ProbabilityFingerprint) ((Future) arrayList2.get(i3)).get();
                        tDoubleArrayList2.add(Tanimoto.fastTanimoto(probabilityFingerprint, evaluationInstance2.compound.fingerprint));
                        tDoubleArrayList.add(Tanimoto.fastTanimoto(evaluationInstance2.fingerprint, evaluationInstance2.compound.fingerprint));
                        double d2 = 0.0d;
                        double d3 = 0.0d;
                        double d4 = 0.0d;
                        for (FPIter2 fPIter2 : probabilityFingerprint.foreachPair(evaluationInstance2.fingerprint)) {
                            double leftProbability = fPIter2.getLeftProbability();
                            double rightProbability = fPIter2.getRightProbability();
                            double d5 = d2 + (leftProbability * rightProbability);
                            double d6 = d3 + (leftProbability * leftProbability);
                            double d7 = d4 + (rightProbability * rightProbability);
                            double leftProbability2 = 1.0d - fPIter2.getLeftProbability();
                            double rightProbability2 = 1.0d - fPIter2.getRightProbability();
                            d2 = d5 + (leftProbability2 * rightProbability2);
                            d3 = d6 + (leftProbability2 * leftProbability2);
                            d4 = d7 + (rightProbability2 * rightProbability2);
                            d += fPIter2.getLeftProbability();
                        }
                        tDoubleArrayList3.add(d2 / Math.sqrt(d3 * d4));
                    } catch (InterruptedException | ExecutionException e) {
                        e.printStackTrace();
                        throw new RuntimeException(e);
                    }
                }
                System.out.printf("------\nAverage: real %f, simulated %f\tSimilarity predicted vs. simulated: %f\tavg. #1 = %f\n", Double.valueOf(tDoubleArrayList.sum() / tDoubleArrayList.size()), Double.valueOf(tDoubleArrayList2.sum() / tDoubleArrayList2.size()), Double.valueOf(tDoubleArrayList3.sum() / tDoubleArrayList3.size()), Double.valueOf(d / subList.size()));
            }
            newFixedThreadPool.shutdown();
            tensorflowModel.close();
        } catch (Throwable th) {
            try {
                tensorflowModel.close();
            } catch (Throwable th2) {
                th.addSuppressed(th2);
            }
            throw th;
        }
    }

    /**
     * Writes the model's decision values for three sets to {@code file3}.
     *
     * <p>The sets are crossvalidation, independent, and "simulated". The simulated set takes one random
     * compound per compound class, in class-index order, with a DISTURBED_TEMPLATE-sampled fingerprint.
     * Classes without compounds are reported on stderr.
     *
     * @param file  TensorFlow model to load
     * @param file2 Canopus model whose weights are fed into the TensorFlow model
     * @param file3 output directory passed to {@code writePredictOutput}
     * @param str   regex selecting independent compounds, or {@code null}
     * @throws IOException on I/O failures
     */
    public static void getDecisionValueOutputAndPerformance(File file, File file2, File file3, String str) throws IOException {
        TrainingData trainingData = new TrainingData(new File("."), str != null ? Pattern.compile(str) : null);
        // nested try-with-resources: every batch and the model are released on all paths. Closing runs in
        // reverse order: simulated batch, model, independent batch, crossvalidation batch.
        try (TrainingBatch crossvalBatch = trainingData.generateBatch(trainingData.crossvalidation);
             TrainingBatch independentBatch = trainingData.generateBatch(trainingData.independent)) {
            Canopus canopus = Canopus.loadFromFile(file2);
            try (TensorflowModel tensorflowModel = new TensorflowModel(file)) {
                tensorflowModel.feedWeightMatrices(canopus).resetWeights();
                float[][] crossvalPrediction = tensorflowModel.predict(crossvalBatch);
                float[][] independentPrediction = tensorflowModel.predict(independentBatch);
                writePredictOutput(file3, "crossvalidation", trainingData, trainingData.crossvalidation, crossvalPrediction);
                writePredictOutput(file3, "independent", trainingData, trainingData.independent, independentPrediction);

                Random random = new Random();
                List<EvaluationInstance> simulated = new ArrayList<>();
                List<CompoundClass> classes = new ArrayList<>(trainingData.compoundClasses.valueCollection());
                classes.sort(Comparator.comparingInt(compoundClass -> compoundClass.index));
                for (CompoundClass compoundClass : classes) {
                    if (compoundClass.compounds.isEmpty()) {
                        System.err.println("No example for " + compoundClass.ontology.getName());
                        continue;
                    }
                    LabeledCompound compound = compoundClass.compounds.get(random.nextInt(compoundClass.compounds.size()));
                    simulated.add(new EvaluationInstance(compoundClass.ontology.getName(), new ProbabilityFingerprint(trainingData.fingerprintVersion, trainingData.sampleFingerprintVector(compound, TrainingData.SamplingStrategy.DISTURBED_TEMPLATE)), compound));
                }
                try (TrainingBatch simulatedBatch = trainingData.generateBatch(simulated)) {
                    writePredictOutput(file3, "simulated", trainingData, simulated, tensorflowModel.predict(simulatedBatch));
                }
            }
        }
    }

    /**
     * Writes per-instance predictions and per-class performance statistics for one evaluation set.
     *
     * <p>The directory {@code file} is created if necessary. Two tab-separated files are written into it:
     * <ul>
     *   <li>{@code <str>_prediction.csv}: one line per instance with its name, InChIKey, the label as a
     *       0/1 string, followed by one raw prediction value per class.</li>
     *   <li>{@code <str>_stats.csv}: a header line, then one line per class with its index, name, ChemOnt id,
     *       parent name, parent ChemOnt id and the accumulated {@link PredictionPerformance} row.</li>
     * </ul>
     * A prediction value {@code >= 0} is counted as a positive call when accumulating statistics.
     * I/O errors are printed to stderr and are not propagated.
     *
     * @param file         output directory
     * @param str          prefix of both output file names (e.g. "crossvalidation")
     * @param trainingData provides the ClassyFire mask and the number of compound classes
     * @param list         evaluated instances; row {@code k} of {@code fArr} belongs to {@code list.get(k)}
     * @param fArr         prediction matrix, one row per instance
     */
    private static void writePredictOutput(File file, String str, TrainingData trainingData, List<EvaluationInstance> list, float[][] fArr) {
        file.mkdirs();
        final int numberOfClasses = trainingData.compoundClasses.size();
        final ClassyfireProperty[] properties = new ClassyfireProperty[numberOfClasses];
        final PredictionPerformance.Modify[] performances = new PredictionPerformance.Modify[numberOfClasses];
        int k = 0;
        for (int absoluteIndex : trainingData.classyFireMask.allowedIndizes()) {
            // one empty performance accumulator per (masked) class, in mask order
            performances[k] = new PredictionPerformance(0.0d, 0.0d, 0.0d, 0.0d, 0.0d, false).modify();
            properties[k] = (ClassyfireProperty) trainingData.classyFireMask.getMolecularProperty(absoluteIndex);
            k++;
        }
        // predictions must be written first: this pass also fills the performance accumulators
        writePredictionTable(file, str, list, fArr, performances);
        writeStatisticsTable(file, str, properties, performances);
    }

    /**
     * Writes {@code <str>_prediction.csv} and updates {@code performances} with every (label, prediction) pair.
     * The writer is closed exactly once, after all rows (the previous version closed it after the first row).
     */
    private static void writePredictionTable(File file, String str, List<EvaluationInstance> list, float[][] fArr, PredictionPerformance.Modify[] performances) {
        try (BufferedWriter writer = FileUtils.getWriter(new File(file, str + "_prediction.csv"))) {
            for (int row = 0; row < list.size(); row++) {
                final EvaluationInstance instance = list.get(row);
                writer.write(instance.name);
                writer.write('\t');
                writer.write(instance.compound.inchiKey);
                writer.write('\t');
                writer.write(instance.compound.label.toOneZeroString());
                final boolean[] truth = instance.compound.label.toBooleanArray();
                final float[] predicted = fArr[row];
                for (int c = 0; c < predicted.length; c++) {
                    writer.write('\t');
                    writer.write(String.valueOf(predicted[c]));
                    // a non-negative prediction counts as a positive call
                    performances[c].update(truth[c], predicted[c] >= 0.0f);
                }
                writer.newLine();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Writes {@code <str>_stats.csv}: one line per class with its index, ontology info and performance.
     * The header names every column, including the leading index column, so rows and header line up.
     */
    private static void writeStatisticsTable(File file, String str, ClassyfireProperty[] properties, PredictionPerformance.Modify[] performances) {
        try (BufferedWriter writer = FileUtils.getWriter(new File(file, str + "_stats.csv"))) {
            writer.write("index\tname\tid\tparent\tparentId\t" + PredictionPerformance.csvHeader());
            writer.newLine();
            for (int c = 0; c < properties.length; c++) {
                final ClassyfireProperty property = properties[c];
                // NOTE(review): assumes root classes may have no parent; they get empty parent columns
                final boolean hasParent = property.getParent() != null;
                writer.write(String.valueOf(c));
                writer.write('\t');
                writer.write(property.getName());
                writer.write('\t');
                writer.write(property.getChemontIdentifier());
                writer.write('\t');
                writer.write(hasParent ? property.getParent().getName() : "");
                writer.write('\t');
                writer.write(hasParent ? property.getParent().getChemontIdentifier() : "");
                writer.write('\t');
                writer.write(performances[c].done().toCsvRow());
                writer.newLine();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Continues training of an existing CANOPUS model.
     *
     * <p>The weights of the model stored in {@code file2} are fed into the TensorFlow graph at {@code file}.
     * Training then runs in two phases:
     * <ul>
     *   <li>100 steps mixing real data (crossvalidation and independent batches) with generated batches, with
     *       weight resets every 25 steps;</li>
     *   <li>1000 steps alternating crossvalidation, independent and generated batches.</li>
     * </ul>
     * The resulting model is saved together with the Platt estimates computed before training.
     *
     * @param file  location of the TensorFlow model passed to {@code TensorflowModel} (TODO confirm: graph directory)
     * @param file2 previously trained CANOPUS model loaded via {@code Canopus.loadFromFile}
     * @param str   regular expression handed to {@code TrainingData}
     * @throws IOException if loading or saving a model fails
     */
    public static void continueModel(File file, File file2, String str) throws IOException {
        final Canopus previousModel = Canopus.loadFromFile(file2);
        final TrainingData trainingData = new TrainingData(new File("."), Pattern.compile(str));
        final BatchGenerator batchGenerator = new BatchGenerator(trainingData, 20);
        batchGenerator.iterationNum.set(30000);
        // two producer threads share the same generator
        new Thread(batchGenerator).start();
        new Thread(batchGenerator).start();
        // resources are closed in reverse order: model first, then the fixed evaluation batches
        try (TrainingBatch simulatedBatch = batchGenerator.poll(0);
             TrainingBatch crossvalBatch = trainingData.generateBatch(trainingData.crossvalidation);
             TrainingBatch independentBatch = trainingData.generateBatch(trainingData.independent);
             TensorflowModel tensorflowModel = new TensorflowModel(file)) {
            final TensorflowModel.Resetter resetter = tensorflowModel.feedWeightMatrices(previousModel);
            resetter.resetWeights();
            final double[][] plattEstimate = tensorflowModel.plattEstimate(trainingData);
            final List<TrainingBatch> reportBatches = Arrays.asList(simulatedBatch, crossvalBatch, independentBatch);
            final List<String> reportNames = Arrays.asList("simulated", "crossval", "indep");
            reportStuff(reportBatches, reportNames, tensorflowModel, 0);

            // phase 1: every 10th step trains on real data, otherwise on generated batches
            for (int step = 0; step < 100; step++) {
                if (step % 10 == 0) {
                    final long start = System.currentTimeMillis();
                    final double[] result = tensorflowModel.train(crossvalBatch);
                    tensorflowModel.train(independentBatch);
                    logTrainingStep(step, result, start);
                } else {
                    try (TrainingBatch batch = batchGenerator.poll(30000 + step)) {
                        final long start = System.currentTimeMillis();
                        final double[] result = tensorflowModel.train(batch);
                        logTrainingStep(step, result, start);
                        // NOTE(review): each generated batch is trained on a second time; kept as before
                        tensorflowModel.train(batch);
                    }
                }
                if (step % 25 == 0) {
                    System.out.println("---> reset all weights.");
                    resetter.resetWeights();
                    reportStuff(reportBatches, reportNames, tensorflowModel, 0);
                }
            }
            reportStuff(reportBatches, reportNames, tensorflowModel, 0);
            resetter.resetWeights();
            System.out.println("---> reset all weights.");
            reportStuff(reportBatches, reportNames, tensorflowModel, 0);

            // phase 2: alternate crossvalidation, independent and generated batches
            int lastStep = 0;
            for (int k = 0; k < 1000; k++) {
                final int step = 30000 + k;
                lastStep = step;
                final long start = System.currentTimeMillis();
                if (k % 10 == 0) {
                    logTrainingStep(step, tensorflowModel.train(crossvalBatch), start);
                } else if (k % 10 == 1) {
                    logTrainingStep(step, tensorflowModel.train(independentBatch), start);
                } else {
                    try (TrainingBatch batch = batchGenerator.poll(step)) {
                        logTrainingStep(step, tensorflowModel.train(batch), start);
                    }
                }
            }
            reportStuff(reportBatches, reportNames, tensorflowModel, lastStep);
            tensorflowModel.saveWithoutPlattEstimate(trainingData, 100, true, true, false, plattEstimate[0], plattEstimate[1]);
        }
    }

    /**
     * Prints one training log line: step, loss ({@code result[0]}), l2 norm ({@code result[1]}) and the
     * seconds elapsed since {@code startMillis}.
     */
    private static void logTrainingStep(int step, double[] result, long startMillis) {
        final double seconds = (System.currentTimeMillis() - startMillis) / 1000.0d;
        System.out.println(step + ".)\tloss = " + result[0] + "\tl2 norm = " + result[1] + "\t (" + seconds + " s)");
    }

    /**
     * Prints the model's evaluation on up to four batches, each line tagged with the step number {@code i}.
     * The first two batches are always evaluated; the two "Indep." batches only when non-null.
     * The parameters {@code iArr}, {@code list}, {@code report} and {@code report2} are not used here.
     */
    private static void reportStuff(TrainingBatch trainingBatch, TrainingBatch trainingBatch2, TrainingBatch trainingBatch3, TrainingBatch trainingBatch4, int[] iArr, List<DummyMolecularProperty> list, Report report, Report report2, TensorflowModel tensorflowModel, int i) {
        final String stepLabel = i + ".) ";
        System.out.println("Evaluation " + stepLabel + tensorflowModel.evaluate(trainingBatch));
        System.out.println("Crossvalidation " + stepLabel + tensorflowModel.evaluate(trainingBatch2));
        final String[] optionalLabels = {"Indep. ", "Indep. Novel "};
        final TrainingBatch[] optionalBatches = {trainingBatch3, trainingBatch4};
        for (int k = 0; k < optionalBatches.length; k++) {
            if (optionalBatches[k] == null) {
                continue;
            }
            System.out.println(optionalLabels[k] + stepLabel + tensorflowModel.evaluate(optionalBatches[k]));
        }
    }

    /**
     * Prints one evaluation line per batch in {@code list}, labelled by the name at the same position in
     * {@code list2} and prefixed with the step number {@code i}.
     */
    private static void reportStuff(List<TrainingBatch> list, List<String> list2, TensorflowModel tensorflowModel, int i) {
        int position = 0;
        for (TrainingBatch batch : list) {
            final String name = list2.get(position++);
            System.out.println(i + ".) " + name + ":\t" + tensorflowModel.evaluate(batch));
        }
    }

    /**
     * One-off maintenance routine that re-masks a stored model.
     *
     * <p>It loads {@code canopus_1.data.gz} and sets its CDK fingerprint version to {@code TrainingData.VERSION}.
     * It then builds a mask enabling exactly the fingerprint indices listed in the first tab-separated column of
     * {@code trainable_indizes.csv}; the first line is skipped as a header and blank lines are ignored. The result
     * is written to {@code canopus_fp.data.gz}. I/O errors are printed to stderr.
     */
    private static void fix() {
        try {
            final Canopus canopus = Canopus.loadFromFile(new File("canopus_1.data.gz"));
            canopus.cdkFingerprintVersion = TrainingData.VERSION;
            final String[] lines = FileUtils.readLines(new File("trainable_indizes.csv"));
            final MaskedFingerprintVersion.Builder maskBuilder = MaskedFingerprintVersion.buildMaskFor(canopus.cdkFingerprintVersion);
            maskBuilder.disableAll();
            for (int lineNo = 1; lineNo < lines.length; lineNo++) {
                final String line = lines[lineNo];
                // a trailing/blank line would otherwise crash parseInt with an uncaught NumberFormatException
                if (line.trim().isEmpty()) {
                    continue;
                }
                maskBuilder.enable(Integer.parseInt(line.split("\t")[0].trim()));
            }
            canopus.cdkMask = maskBuilder.toMask();
            canopus.writeToFile(new File("canopus_fp.data.gz"));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Renders a rank-2 float tensor as text: {@code "{\n"}, then one tab-indented
     * {@code Arrays.toString} line per row, then {@code "}"}.
     *
     * @param tensor tensor whose first two shape dimensions give rows and columns
     * @return the textual representation of the tensor contents
     */
    public static String tensor2string(Tensor tensor) {
        final long[] shape = tensor.shape();
        final float[][] matrix = new float[(int) shape[0]][(int) shape[1]];
        tensor.copyTo(matrix);
        final StringBuilder out = new StringBuilder("{\n");
        for (final float[] row : matrix) {
            out.append('\t')
               .append(Arrays.toString(row))
               .append('\n');
        }
        return out.append('}').toString();
    }

    static {
        System.setProperty("org.apache.commons.logging.Log", "org.apache.commons.logging.impl.NoOpLog");
        System.setProperty("de.unijena.bioinf.ms.propertyLocations", "sirius.build.properties, csi_fingerid.build.properties");
        try {
            InputStream resourceAsStream = Learn.class.getResourceAsStream("/logging.properties");
            try {
                LogManager.getLogManager().readConfiguration(resourceAsStream);
                if (resourceAsStream != null) {
                    resourceAsStream.close();
                }
            } finally {
            }
        } catch (IOException e) {
            System.err.println("Could not read logging configuration.");
            e.printStackTrace();
        }
    }
}
