package de.unijena.bioinf.canopus;

import de.unijena.bioinf.ChemistryBase.fp.ArrayFingerprint;
import de.unijena.bioinf.ChemistryBase.fp.ClassyfireProperty;
import de.unijena.bioinf.ChemistryBase.fp.CustomFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.MaskedFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.NPCFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.ChemistryBase.fp.ProbabilityFingerprint;
import de.unijena.bioinf.ChemistryBase.utils.FileUtils;
import de.unijena.bioinf.canopus.TensorflowModel;
import de.unijena.bioinf.canopus.TrainingData;
import de.unijena.bioinf.chemdb.ChemicalDatabase;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.set.hash.TIntHashSet;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.logging.LogManager;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.tensorflow.Tensor;

/* loaded from: input_file:de/unijena/bioinf/canopus/Learn.class */
public class Learn {
    /**
     * Position of the "Flavonoids" class among the ClassyFire mask's allowed indices, or -1 while
     * it has not been resolved yet. The field must start at -1 rather than the implicit default 0:
     * {@code evalFl} uses a negative value to decide whether the lookup still has to be performed.
     */
    private static int FLINDEX = -1;
    /** Like {@link #FLINDEX}, but for the "Flavonoid glycosides" class; -1 while unresolved. */
    private static int FLGINDEX = -1;
    /**
     * L2 regularization strength. It is set by the {@code --l2=<value>} command-line option and
     * applied to the model from training iteration 500 on. It stays 0.0 unless that option is given.
     */
    private static double REGSTREN;

    /**
     * Looks up the value of a command-line option.
     *
     * <p>Both {@code --opt=value} and {@code --opt value} are accepted. The first argument that
     * starts with {@code str} is used. In the {@code =} form, everything after the first {@code =}
     * is returned (trimmed), so values may themselves contain {@code =} (for example regex
     * lookaheads).
     *
     * @param strArr the command-line arguments to search
     * @param str the option prefix, e.g. {@code "--independent"}
     * @return the option value, or {@code null} if the option is absent or is the last argument
     *     without an attached value
     */
    public static String findArgWithValue(String[] strArr, String str) {
        for (int i = 0; i < strArr.length; i++) {
            final String arg = strArr[i];
            if (!arg.startsWith(str)) {
                continue;
            }
            final int separator = arg.indexOf('=');
            if (separator >= 0) {
                // substring instead of split("=")[1]: split dropped everything after a second '='
                // and threw for "--opt=" with an empty value.
                return arg.substring(separator + 1).trim();
            }
            // "--opt value" form; a trailing flag with no value used to throw an index error.
            return i + 1 < strArr.length ? strArr[i + 1] : null;
        }
        return null;
    }

    /**
     * Command-line entry point for training and evaluating CANOPUS models.
     *
     * <p>Global options ({@code --no-norm}, {@code --l2=<value>}) are stripped first by
     * {@code removeOpts}. The first remaining argument selects the action:
     * <ul>
     *   <li>{@code evaluate-indep modeldir model.tgz outputdir independentPattern}</li>
     *   <li>{@code evaluate modeldir model.tgz outputdir independentPattern}</li>
     *   <li>{@code continue <file1> <file2> [--independent pattern]} and {@code finalize ...},
     *       both delegating to {@code continueModel}</li>
     *   <li>{@code sample <file>}, {@code prepare <file>}, {@code fix}</li>
     * </ul>
     * Any other first argument is treated as a TensorFlow model directory and starts a new training
     * run: {@code <modeldir> <index> [--independent pattern]}.
     *
     * <p>The evaluate/continue/finalize actions terminate the JVM with {@code System.exit(0)}.
     *
     * @param strArr raw command-line arguments
     */
    public static void main(String[] strArr) {
        System.out.println("Use fingerprints from " + ChemicalDatabase.FINGERPRINT_TABLE + " with ID " + ChemicalDatabase.FINGERPRINT_ID);
        System.out.println("Uptodate version 5");
        System.out.println("CLIPPING? " + TrainingData.CLIPPING);
        final String[] args = removeOpts(strArr);
        if (args.length == 0) {
            System.err.println("Usage:\nmodeldir index [--independent pattern]\nor one of: evaluate, evaluate-indep, continue, finalize, sample, prepare, fix");
            return;
        }
        final String command = args[0];
        // "evaluate-indep" must be tested before "evaluate": it also starts with "evaluate", so the
        // previous order made this branch unreachable.
        if (command.startsWith("evaluate-indep")) {
            try {
                if (args.length != 5) {
                    System.err.println("Usage:\nevaluate-indep modeldir model.tgz outputdir independentPattern");
                } else {
                    getDecisionValueOutputAndPerformanceOnIndep(new File(args[1]), new File(args[2]), new File(args[3]), args[4]);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
            System.exit(0);
            return;
        }
        if (command.startsWith("evaluate")) {
            try {
                if (args.length != 5) {
                    System.err.println("Usage:\nevaluate modeldir model.tgz outputdir independentPattern");
                } else {
                    getDecisionValueOutputAndPerformance(new File(args[1]), new File(args[2]), new File(args[3]), args[4]);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
            System.exit(0);
            return;
        }
        if (command.startsWith("continue") || command.startsWith("finalize")) {
            final boolean finalizeModel = command.startsWith("finalize");
            if (args.length < 3) {
                System.err.println("Usage:\n" + (finalizeModel ? "finalize" : "continue") + " <file1> <file2> [--independent pattern]");
            } else {
                try {
                    continueModel(new File(args[1]), new File(args[2]), findArgWithValue(args, "--independent"), finalizeModel);
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            System.exit(0);
            return;
        }
        if (command.startsWith("sample")) {
            if (args.length < 2) {
                System.err.println("Usage:\nsample <file>");
                return;
            }
            sample(new File(args[1]));
            return;
        }
        if (command.startsWith("prepare")) {
            if (args.length < 2) {
                System.err.println("Usage:\nprepare <file>");
                return;
            }
            System.out.println("Prepare learning");
            try {
                Prepare.prepare(new File(args[1]));
            } catch (IOException e) {
                e.printStackTrace();
            }
            return;
        }
        if (command.startsWith("fix")) {
            fix();
            return;
        }
        if (args.length < 2) {
            System.err.println("Usage:\nmodeldir index [--independent pattern]");
            return;
        }
        final int modelIndex = Integer.parseInt(args[1]);
        TrainingData.GROW = 1;
        final String independentPattern = findArgWithValue(args, "--independent");
        try (TensorflowModel tensorflowModel = new TensorflowModel(new File(args[0]))) {
            trainNewModel(tensorflowModel, modelIndex, independentPattern);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Runs a fresh training session on {@code tensorflowModel} using the training data in the
     * working directory.
     *
     * <p>Up to 40000 iterations are run on batches from background {@link BatchGenerator} threads.
     * NPC batches are interleaved when the data supports it. Performance is reported every 400
     * iterations. From iteration 18000 on, the first time the cross-validation score exceeds
     * negative infinity, predictions are written to {@code canopus_final_model_<modelIndex>}, the
     * model is saved, and training stops. Generator threads are always stopped and all evaluation
     * batches closed, even if training fails.
     *
     * @param tensorflowModel the model to train; it is closed by the caller
     * @param modelIndex index used for the output directory and the saved models
     * @param independentPattern optional regex selecting the independent evaluation set, or null
     * @throws IOException if loading data, training or saving fails
     */
    private static void trainNewModel(TensorflowModel tensorflowModel, int modelIndex, String independentPattern) throws IOException {
        final TrainingData trainingData = new TrainingData(new File("."), independentPattern == null ? null : Pattern.compile(independentPattern));
        final BatchGenerator batchGenerator = new BatchGenerator(trainingData, 20);
        System.out.println("Loss function: " + tensorflowModel.loss);
        System.out.println("PLATT CENTERING: " + String.valueOf(TrainingData.PLATT_CENTERING));
        System.out.println("PLATT SCALING: " + String.valueOf(TrainingData.SCALE_BY_STD));
        System.out.println("VECTOR NORM: " + String.valueOf(TrainingData.VECNORM_SCALING));

        // Two producer threads share one generator to keep the batch queue filled.
        final List<Thread> generatorThreads = new ArrayList<>();
        for (int i = 0; i < 2; i++) {
            generatorThreads.add(new Thread(batchGenerator));
        }
        for (Thread thread : generatorThreads) {
            thread.start();
        }
        BatchGenerator npcGenerator = null;
        Thread npcThread = null;
        if (trainingData.isNPC()) {
            npcGenerator = new BatchGenerator(trainingData, 4);
            npcGenerator.npc = true;
            npcThread = new Thread(npcGenerator);
            npcThread.start();
        }
        try {
            // Independent instances whose InChIKey does not occur in the cross-validation set.
            final List<EvaluationInstance> novelIndependent = new ArrayList<>();
            if (trainingData.independent != null) {
                final HashSet<String> crossvalKeys = new HashSet<>();
                for (EvaluationInstance known : trainingData.crossvalidation) {
                    crossvalKeys.add(known.compound.inchiKey);
                }
                for (EvaluationInstance candidate : trainingData.independent) {
                    if (!crossvalKeys.contains(candidate.compound.inchiKey)) {
                        novelIndependent.add(candidate);
                    }
                }
            }
            try (TrainingBatch simulatedBatch = batchGenerator.poll(0);
                 TrainingBatch crossvalBatch = trainingData.generateBatch(trainingData.crossvalidation);
                 TrainingBatch independentBatch = trainingData.independent == null ? null : trainingData.generateBatch(trainingData.independent);
                 TrainingBatch npcEvalBatch = trainingData.generateNPCBatch(trainingData.npcInstances);
                 TrainingBatch novelIndependentBatch = trainingData.independent == null ? null : trainingData.generateBatch(novelIndependent)) {
                System.out.println("Resample");
                System.out.flush();
                try (TrainingBatch resampledBatch = trainingData.resampleMultithreaded(trainingData.crossvalidation, (ignoredInstance, ignoredIndex) -> TrainingData.SamplingStrategy.CONDITIONAL)) {
                    System.out.println("Start.");
                    System.out.flush();
                    tensorflowModel.setRegularizerStrength(0.0f);
                    // NPC training runs every npcInterval iterations; it is thinned out as training progresses.
                    int npcInterval = 1;
                    for (int iteration = 0; iteration <= 40000; iteration++) {
                        if (iteration == 500) {
                            System.out.println("Set regularization to " + REGSTREN);
                            tensorflowModel.setRegularizerStrength((float) REGSTREN);
                        }
                        try (TrainingBatch batch = batchGenerator.poll(iteration)) {
                            if (iteration == 0) {
                                System.out.println("Batch size: ~" + batch.platts.shape()[0]);
                            }
                            final long start = System.currentTimeMillis();
                            if (iteration % 10 == 0) {
                                final double[] stats = tensorflowModel.trainWithGradient(batch.platts, batch.formulas, batch.labels);
                                final double seconds = (System.currentTimeMillis() - start) / 1000.0d;
                                System.out.println(iteration + ".)\tloss = " + stats[0] + "\tl2 norm = " + stats[1] + "\tgradient = " + stats[2] + "\t (" + seconds + " s)");
                            } else {
                                final double[] stats = tensorflowModel.train(batch.platts, batch.formulas, batch.labels);
                                final double seconds = (System.currentTimeMillis() - start) / 1000.0d;
                                System.out.println(iteration + ".)\tloss = " + stats[0] + "\tl2 norm = " + stats[1] + "\t (" + seconds + " s)");
                            }
                            if (npcGenerator != null && iteration % npcInterval == 0) {
                                final long npcStart = System.currentTimeMillis();
                                // Separate variable: reusing the main batch variable leaked that batch and double-closed this one.
                                try (TrainingBatch npcBatch = npcGenerator.poll(iteration)) {
                                    final double[] npcStats = tensorflowModel.train_npc(npcBatch.platts, npcBatch.formulas, npcBatch.labels, npcBatch.npcLabels);
                                    final double seconds = (System.currentTimeMillis() - npcStart) / 1000.0d;
                                    System.out.println(iteration + ".)\tnpcloss = " + npcStats[0] + "\tloss = " + npcStats[1] + "\tl2 norm = " + npcStats[2] + "\t (" + seconds + " s)");
                                }
                                if (iteration > 2000) {
                                    npcInterval = 4;
                                } else if (iteration > 500) {
                                    npcInterval = 2;
                                }
                            }
                            if (iteration % 400 == 0) {
                                reportStuff(Arrays.asList(simulatedBatch, crossvalBatch, resampledBatch, independentBatch, novelIndependentBatch), Arrays.asList("simulated", "crossval", "resampled", "indep", "indepNovel"), tensorflowModel, iteration);
                                if (trainingData.isNPC()) {
                                    evalNPC(npcEvalBatch, tensorflowModel);
                                }
                            }
                            if (iteration >= 18000 && tensorflowModel.evaluate(crossvalBatch).score() > Double.NEGATIVE_INFINITY) {
                                System.out.println("############ SAVE MODEL ##############");
                                final File outputDir = new File("canopus_final_model_" + modelIndex);
                                final float[][] crossvalPrediction = tensorflowModel.predict(crossvalBatch);
                                writePredictOutput(outputDir, "crossvalidation", trainingData, trainingData.crossvalidation, crossvalPrediction);
                                // Without an independent set there is nothing to predict.
                                if (independentBatch != null) {
                                    writePredictOutput(outputDir, "independent", trainingData, trainingData.independent, tensorflowModel.predict(independentBatch));
                                }
                                if (trainingData.isNPC()) {
                                    writeNPCPredictOutput(outputDir, "crossvalidation", trainingData, trainingData.npcInstances, tensorflowModel.predictNPC(npcEvalBatch));
                                }
                                tensorflowModel.saveWithPlattOnCrossval(trainingData, -modelIndex, true, true);
                                tensorflowModel.save(trainingData, modelIndex, true, true, true);
                                // The model is final: stop here instead of retraining and re-saving forever.
                                break;
                            }
                        }
                    }
                }
            }
        } finally {
            batchGenerator.stop();
            if (npcGenerator != null) {
                npcGenerator.stop();
            }
            for (Thread thread : generatorThreads) {
                thread.interrupt();
            }
            if (npcThread != null) {
                npcThread.interrupt();
            }
            System.out.println("SHUTDOWN");
        }
    }

    /**
     * Dumps sampled fingerprints for a random subset of the cross-validation data.
     *
     * <p>All TrainingData normalization flags (Platt centering, clipping, max/std scaling, vector
     * norm) are switched off globally first. Then, for every TrainingData.SamplingStrategy whose
     * output file {@code sample_<STRATEGY>.csv} does not already exist in the working directory,
     * one tab-separated row is written per instance in a shuffled subset of at most 5000
     * cross-validation instances. Each row contains: name, InChIKey, the compound fingerprint as a
     * 0/1 string, the instance's probability fingerprint, and the values returned by
     * {@code trainingData.sampleBy}. IOExceptions are printed, not propagated.
     *
     * <p>NOTE(review): {@code file} is currently unused; output always goes to the working directory.
     *
     * @param file unused
     */
    private static void sample(File file) {
        TrainingData.PLATT_CENTERING = false;
        TrainingData.CLIPPING = false;
        TrainingData.SCALE_BY_MAX = false;
        TrainingData.SCALE_BY_STD = false;
        TrainingData.VECNORM_SCALING = false;
        final int maxInstances = 5000;
        try {
            final TrainingData trainingData = new TrainingData(new File("."), null);
            final List<EvaluationInstance> shuffled = new ArrayList<>(trainingData.crossvalidation);
            Collections.shuffle(shuffled);
            // Clamp to the data size: subList(0, 5000) threw for smaller training sets.
            final List<EvaluationInstance> subset = new ArrayList<>(shuffled.subList(0, Math.min(maxInstances, shuffled.size())));
            for (TrainingData.SamplingStrategy strategy : TrainingData.SamplingStrategy.values()) {
                final File outputFile = new File("sample_" + strategy.name() + ".csv");
                if (outputFile.exists()) {
                    continue;
                }
                final List<double[]> samples = subset.stream()
                        .map(instance -> trainingData.sampleBy(instance, strategy))
                        .collect(Collectors.toList());
                try (BufferedWriter writer = FileUtils.getWriter(outputFile)) {
                    for (int i = 0; i < subset.size(); i++) {
                        final EvaluationInstance instance = subset.get(i);
                        writer.write(instance.name);
                        writer.write('\t');
                        writer.write(instance.compound.inchiKey);
                        writer.write('\t');
                        writer.write(instance.compound.fingerprint.toOneZeroString());
                        writer.write('\t');
                        writer.write(instance.fingerprint.toTabSeparatedString());
                        for (double value : samples.get(i)) {
                            writer.write('\t');
                            writer.write(String.valueOf(value));
                        }
                        writer.newLine();
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Prints the per-class performance of the "Flavonoids" and "Flavonoid glycosides" ClassyFire
     * classes for each report.
     *
     * <p>A class's position is its rank among {@code classyFireMask.allowedIndizes()}, which is
     * how {@code Report.performancePerClass} is indexed here. The positions are resolved on every
     * call and stored in {@link #FLINDEX}/{@link #FLGINDEX} (-1 if a class is absent). The old
     * {@code FLINDEX < 0} cache check never fired with the fields' default of 0, so index 0 was
     * always printed. A class that is not present in the mask is skipped.
     *
     * @param trainingData supplies the ClassyFire mask and fingerprint version
     * @param reportArr reports whose flavonoid entries are printed
     */
    private static void evalFl(TrainingData trainingData, Report[] reportArr) {
        int flavonoids = -1;
        int glycosides = -1;
        int position = 0;
        for (int absoluteIndex : trainingData.classyFireMask.allowedIndizes()) {
            final String name = trainingData.classyFireFingerprintVersion.getMolecularProperty(absoluteIndex).getName();
            if (name.equals("Flavonoids")) {
                flavonoids = position;
            }
            if (name.equals("Flavonoid glycosides")) {
                glycosides = position;
            }
            position++;
        }
        FLINDEX = flavonoids;
        FLGINDEX = glycosides;
        for (Report report : reportArr) {
            if (flavonoids >= 0) {
                System.out.println("Flavonoids: " + report.performancePerClass[flavonoids]);
            }
            if (glycosides >= 0) {
                System.out.println("Flavonoid glycosides: " + report.performancePerClass[glycosides]);
            }
        }
    }

    /**
     * Consumes the global options that configure static training state and returns the rest.
     *
     * <ul>
     *   <li>{@code --no-norm...}: disables std/max scaling, Platt centering and vector-norm scaling
     *       in {@code TrainingData}.</li>
     *   <li>{@code --l2=<value>}: sets {@code REGSTREN} and echoes the value to stdout.</li>
     * </ul>
     *
     * @param strArr raw command-line arguments
     * @return the arguments that are not one of the options above, in their original order
     */
    private static String[] removeOpts(String[] strArr) {
        final List<String> remaining = new ArrayList<>(strArr.length);
        for (final String arg : strArr) {
            if (arg.startsWith("--l2=")) {
                REGSTREN = Double.parseDouble(arg.split("=")[1]);
                System.out.println("Multiply the l2 norm with " + REGSTREN);
                continue;
            }
            if (arg.startsWith("--no-norm")) {
                TrainingData.SCALE_BY_STD = false;
                TrainingData.SCALE_BY_MAX = false;
                TrainingData.PLATT_CENTERING = false;
                TrainingData.VECNORM_SCALING = false;
                continue;
            }
            remaining.add(arg);
        }
        return remaining.toArray(new String[0]);
    }

    /**
     * Writes the first cross-validation instance and its prediction to the {@code example}
     * directory, then prints a prediction-performance summary for that single instance.
     *
     * <p>Written files: {@code formula.matrix}, {@code platts.matrix}, {@code labels.matrix},
     * {@code prediction.matrix} (via {@code FileUtils.writeFloatMatrix}) and {@code example.txt}
     * (InChIKey, formula, probability fingerprint, tab-separated). A label or prediction counts as
     * positive when it is greater than 0.
     *
     * @param tensorflowModel the model used for the prediction
     * @param trainingData source of the example; its cross-validation list must not be empty
     * @throws IOException if the directory cannot be created or writing fails
     */
    private static void writeExample(TensorflowModel tensorflowModel, TrainingData trainingData) throws IOException {
        final EvaluationInstance instance = trainingData.crossvalidation.get(0);
        try (TrainingBatch batch = trainingData.generateBatch(Arrays.asList(instance))) {
            final float[][] prediction = tensorflowModel.predict(batch);
            final File exampleDir = new File("example");
            if (!exampleDir.isDirectory() && !exampleDir.mkdirs()) {
                throw new IOException("Could not create directory " + exampleDir);
            }
            final float[][] formulas = (float[][]) batch.formulas.copyTo(new float[1][(int) batch.formulas.shape()[1]]);
            final float[][] platts = (float[][]) batch.platts.copyTo(new float[1][(int) batch.platts.shape()[1]]);
            final float[][] labels = (float[][]) batch.labels.copyTo(new float[1][(int) batch.labels.shape()[1]]);
            FileUtils.writeFloatMatrix(new File("example/formula.matrix"), formulas);
            FileUtils.writeFloatMatrix(new File("example/platts.matrix"), platts);
            FileUtils.writeFloatMatrix(new File("example/labels.matrix"), labels);
            FileUtils.writeFloatMatrix(new File("example/prediction.matrix"), prediction);
            // try-with-resources: the writer was previously left open if any write failed.
            try (BufferedWriter writer = FileUtils.getWriter(new File("example/example.txt"))) {
                writer.write(instance.compound.inchiKey);
                writer.write('\t');
                writer.write(instance.compound.formula.toString());
                writer.write('\t');
                writer.write(instance.fingerprint.toTabSeparatedString());
                writer.newLine();
            }
            final PredictionPerformance.Modify performance = new PredictionPerformance().modify();
            for (int i = 0; i < labels[0].length; i++) {
                performance.update(labels[0][i] > 0.0f, prediction[0][i] > 0.0f);
            }
            System.out.println("Example: " + performance.done().toString());
        }
    }

    /**
     * Builds a per-property prediction report from instances that already carry a predicted
     * fingerprint.
     *
     * <p>Only absolute fingerprint indices that are allowed by {@code maskedFingerprintVersion}
     * AND present in {@code customFingerprintVersion} are evaluated, in ascending order. For each of
     * them, the compound's true fingerprint bit is compared with the predicted bit
     * ({@code isSet}) over all instances.
     *
     * @param list instances to evaluate
     * @param customFingerprintVersion version whose properties are cast to DummyMolecularProperty
     *     (absolute/relative index pairs)
     * @param maskedFingerprintVersion mask whose allowed indices restrict the evaluation
     * @param tIntHashSet unused
     * @param tIntHashSet2 output: receives the relative indices of the evaluated properties
     * @return a report with one PredictionPerformance per evaluated index, sorted by absolute index
     */
    private static Report generateReportFromTrainData(List<EvaluationInstance> list, CustomFingerprintVersion customFingerprintVersion, MaskedFingerprintVersion maskedFingerprintVersion, TIntHashSet tIntHashSet, TIntHashSet tIntHashSet2) {
        // absolute fingerprint index -> relative index in the custom version
        final TIntIntHashMap absoluteToRelative = new TIntIntHashMap();
        for (int k = 0, n = customFingerprintVersion.size(); k < n; ++k) {
            final DummyMolecularProperty property = (DummyMolecularProperty) customFingerprintVersion.getMolecularProperty(k);
            absoluteToRelative.put(property.absoluteIndex, property.relativeIndex);
        }

        // absolute indices allowed by the mask and known to the custom version
        final TIntHashSet shared = new TIntHashSet();
        for (int allowed : maskedFingerprintVersion.allowedIndizes()) {
            shared.add(allowed);
        }
        shared.retainAll(absoluteToRelative.keySet());
        absoluteToRelative.retainEntries((absolute, relative) -> shared.contains(absolute));
        for (int relative : absoluteToRelative.values()) {
            tIntHashSet2.add(relative);
        }

        final int[] indices = shared.toArray();
        Arrays.sort(indices);
        final PredictionPerformance.Modify[] counters = new PredictionPerformance.Modify[indices.length];
        for (int k = 0; k < counters.length; ++k) {
            counters[k] = new PredictionPerformance().modify();
        }

        for (EvaluationInstance instance : list) {
            final ArrayFingerprint truth = instance.compound.fingerprint;
            final ProbabilityFingerprint predicted = instance.fingerprint;
            for (int k = 0; k < indices.length; ++k) {
                final int absolute = indices[k];
                counters[k].update(truth.isSet(absolute), predicted.isSet(absolute));
            }
        }

        final PredictionPerformance[] performances = new PredictionPerformance[counters.length];
        for (int k = 0; k < counters.length; ++k) {
            performances[k] = counters[k].done();
        }
        return new Report(performances);
    }

    /**
     * Evaluates a stored CANOPUS model and writes prediction outputs for three data sets.
     *
     * <p>The weights from {@code file2} (loaded with {@code Canopus.loadFromFile}) are fed into the
     * TensorFlow graph in {@code file}. Predictions are then written to {@code file3} under the
     * names "crossvalidation", "independent" and "simulated". The simulated set contains up to 20
     * randomly chosen compounds per compound class, sampled with
     * {@code SamplingStrategy.DISTURBED_TEMPLATE}. All batches and the model are closed, also on
     * failure.
     *
     * @param file TensorFlow model location
     * @param file2 stored CANOPUS model providing the weights
     * @param file3 output directory for the prediction files
     * @param str optional regex selecting the independent set, or null
     * @throws IOException if loading or writing fails
     */
    public static void getDecisionValueOutputAndPerformance(File file, File file2, File file3, String str) throws IOException {
        final TrainingData trainingData = new TrainingData(new File("."), str != null ? Pattern.compile(str) : null);
        try (TrainingBatch crossvalBatch = trainingData.generateBatch(trainingData.crossvalidation);
             TrainingBatch independentBatch = trainingData.generateBatch(trainingData.independent)) {
            final Canopus canopus = Canopus.loadFromFile(file2);
            try (TensorflowModel tensorflowModel = new TensorflowModel(file)) {
                tensorflowModel.feedWeightMatrices(canopus).resetWeights();
                final float[][] crossvalPrediction = tensorflowModel.predict(crossvalBatch);
                final float[][] independentPrediction = tensorflowModel.predict(independentBatch);
                writePredictOutput(file3, "crossvalidation", trainingData, trainingData.crossvalidation, crossvalPrediction);
                writePredictOutput(file3, "independent", trainingData, trainingData.independent, independentPrediction);
                final List<EvaluationInstance> simulated = simulateInstancesPerClass(trainingData, 20);
                try (TrainingBatch simulatedBatch = trainingData.generateBatch(simulated)) {
                    writePredictOutput(file3, "simulated", trainingData, simulated, tensorflowModel.predict(simulatedBatch));
                }
            }
        }
    }

    /**
     * Samples up to {@code perClass} simulated evaluation instances per compound class, in
     * parallel.
     *
     * <p>Classes are visited in ascending {@code index} order. The compounds of each class are
     * shuffled with an unseeded Random. Classes without compounds are reported on stderr and
     * skipped. The thread pool is always shut down, even when sampling fails.
     *
     * @param trainingData source of compound classes and the fingerprint sampler
     * @param perClass maximal number of instances drawn per class
     * @return the sampled instances in submission order
     * @throws RuntimeException wrapping the cause if a sampling task fails or the thread is
     *     interrupted (the interrupt flag is restored)
     */
    private static List<EvaluationInstance> simulateInstancesPerClass(TrainingData trainingData, int perClass) {
        final Random random = new Random();
        final List<CompoundClass> classes = new ArrayList<>(trainingData.compoundClasses.valueCollection());
        classes.sort(Comparator.comparingInt(compoundClass -> compoundClass.index));
        final ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        try {
            final List<Future<EvaluationInstance>> futures = new ArrayList<>();
            for (CompoundClass compoundClass : classes) {
                if (compoundClass.compounds.isEmpty()) {
                    System.err.println("No example for " + compoundClass.ontology.getName());
                    continue;
                }
                final List<LabeledCompound> candidates = new ArrayList<>(compoundClass.compounds);
                Collections.shuffle(candidates, random);
                for (int i = 0, n = Math.min(candidates.size(), perClass); i < n; i++) {
                    final LabeledCompound compound = candidates.get(i);
                    futures.add(pool.submit(() -> new EvaluationInstance(compoundClass.ontology.getName(), new ProbabilityFingerprint(trainingData.fingerprintVersion, trainingData.sampleFingerprintVector(compound, TrainingData.SamplingStrategy.DISTURBED_TEMPLATE)), compound)));
                }
            }
            final List<EvaluationInstance> instances = new ArrayList<>(futures.size());
            for (Future<EvaluationInstance> future : futures) {
                try {
                    instances.add(future.get());
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                } catch (ExecutionException e) {
                    throw new RuntimeException(e);
                }
            }
            return instances;
        } finally {
            pool.shutdown();
        }
    }

    /**
     * Evaluates a stored CANOPUS model on the cross-validation and independent sets.
     *
     * <p>The weights from {@code file2} are fed into the TensorFlow graph in {@code file}. A Platt
     * calibration is estimated with {@code plattEstimate(trainingData, false)} and set on the
     * loaded Canopus instance, and predictions are written to {@code file3} as "crossvalidation" and
     * "independent".
     *
     * <p>NOTE(review): the calibrated Canopus instance is neither saved nor used afterwards in this
     * method.
     *
     * <p>Resources are managed with try-with-resources. The previous hand-expanded nesting closed
     * already-closed batches or the model a second time if a close call failed.
     *
     * @param file TensorFlow model location
     * @param file2 stored CANOPUS model providing the weights
     * @param file3 output directory for the prediction files
     * @param str optional regex selecting the independent set, or null
     * @throws IOException if loading or writing fails
     */
    public static void getDecisionValueOutputAndPerformanceOnIndep(File file, File file2, File file3, String str) throws IOException {
        final Pattern independentPattern = str != null ? Pattern.compile(str) : null;
        final TrainingData trainingData = new TrainingData(new File("."), independentPattern);
        try (TrainingBatch crossvalBatch = trainingData.generateBatch(trainingData.crossvalidation);
             TrainingBatch independentBatch = trainingData.generateBatch(trainingData.independent)) {
            final Canopus canopus = Canopus.loadFromFile(file2);
            try (TensorflowModel tensorflowModel = new TensorflowModel(file)) {
                tensorflowModel.feedWeightMatrices(canopus).resetWeights();
                final double[][] platt = tensorflowModel.plattEstimate(trainingData, false);
                canopus.setPlattCalibration(platt[0], platt[1]);
                final float[][] crossvalPrediction = tensorflowModel.predict(crossvalBatch);
                final float[][] independentPrediction = tensorflowModel.predict(independentBatch);
                writePredictOutput(file3, "crossvalidation", trainingData, trainingData.crossvalidation, crossvalPrediction);
                writePredictOutput(file3, "independent", trainingData, trainingData.independent, independentPrediction);
            }
        }
    }

    /**
     * Writes the ClassyFire (CANOPUS) predictions of {@code list} together with per-class
     * prediction statistics into the directory {@code file}.
     *
     * <p>Two tab-separated files are produced:
     * <ul>
     *   <li>{@code <str>_prediction.csv}: one line per instance containing its name, its
     *       compound's InChIKey, the compound label as one/zero string, and one raw prediction
     *       value per class.</li>
     *   <li>{@code <str>_stats.csv}: a header, then per class its index, name, ChemOnt
     *       identifier, parent name and parent ChemOnt identifier (both empty when there is no
     *       parent) and the accumulated {@link PredictionPerformance} as CSV row.</li>
     * </ul>
     * A prediction value {@code >= 0} is counted as a positive prediction when accumulating the
     * statistics. I/O failures are printed and not propagated; a failure while writing the first
     * file does not prevent the attempt to write the second one.
     *
     * @param file         output directory; created if it does not exist
     * @param str          file name prefix
     * @param trainingData source of the ClassyFire mask and of the number of compound classes
     * @param list         evaluated instances; instance {@code k} belongs to row {@code fArr[k]}
     * @param fArr         prediction values, one row per instance and one column per class
     */
    private static void writePredictOutput(File file, String str, TrainingData trainingData, List<EvaluationInstance> list, float[][] fArr) {
        file.mkdirs();
        final int numberOfClasses = trainingData.compoundClasses.size();
        final ClassyfireProperty[] classyfirePropertyArr = new ClassyfireProperty[numberOfClasses];
        final PredictionPerformance.Modify[] modifyArr = new PredictionPerformance.Modify[numberOfClasses];
        int i = 0;
        for (int i2 : trainingData.classyFireMask.allowedIndizes()) {
            modifyArr[i] = new PredictionPerformance(0.0d, 0.0d, 0.0d, 0.0d, 0.0d).modify();
            classyfirePropertyArr[i] = (ClassyfireProperty) trainingData.classyFireMask.getMolecularProperty(i2);
            i++;
        }
        // try-with-resources keeps the writer open for the whole loop. The previous version closed
        // it in a per-iteration finally, so every instance after the first failed with
        // "Stream closed" (and its catch block called th.addSuppressed(th), which itself throws).
        try (BufferedWriter writer = FileUtils.getWriter(new File(file, str + "_prediction.csv"))) {
            for (int i3 = 0; i3 < list.size(); i3++) {
                final EvaluationInstance evaluationInstance = list.get(i3);
                writer.write(evaluationInstance.name);
                writer.write('\t');
                writer.write(evaluationInstance.compound.inchiKey);
                writer.write('\t');
                writer.write(evaluationInstance.compound.label.toOneZeroString());
                final boolean[] booleanArray = evaluationInstance.compound.label.toBooleanArray();
                final float[] fArr2 = fArr[i3];
                for (int i4 = 0; i4 < fArr2.length; i4++) {
                    writer.write('\t');
                    writer.write(String.valueOf(fArr2[i4]));
                    // non-negative prediction value == predicted positive
                    modifyArr[i4].update(booleanArray[i4], fArr2[i4] >= 0.0f);
                }
                writer.newLine();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        try (BufferedWriter writer = FileUtils.getWriter(new File(file, str + "_stats.csv"))) {
            // NOTE(review): no newLine() is written after the header or after each row; presumably
            // csvHeader()/toCsvRow() terminate their own lines — verify.
            writer.write("index\tname\tid\tparent\tparentId\t" + PredictionPerformance.csvHeader());
            for (int i5 = 0; i5 < classyfirePropertyArr.length; i5++) {
                final ClassyfireProperty classyfireProperty = classyfirePropertyArr[i5];
                writer.write(String.valueOf(i5));
                writer.write('\t');
                writer.write(classyfireProperty.getName());
                writer.write('\t');
                writer.write(classyfireProperty.getChemontIdentifier());
                writer.write('\t');
                writer.write(classyfireProperty.getParent() == null ? "" : classyfireProperty.getParent().getName());
                writer.write('\t');
                writer.write(classyfireProperty.getParent() == null ? "" : classyfireProperty.getParent().getChemontIdentifier());
                writer.write('\t');
                writer.write(modifyArr[i5].done().toCsvRow());
            }
        } catch (IOException e2) {
            e2.printStackTrace();
        }
    }

    /**
     * Writes the NPC (natural product classifier) predictions of {@code list} together with
     * per-class prediction statistics into the directory {@code file}.
     *
     * <p>Two tab-separated files are produced:
     * <ul>
     *   <li>{@code <str>npc_prediction.csv}: one line per instance containing its name, its
     *       compound's InChIKey, the compound's NPC label as one/zero string, and one raw
     *       prediction value per NPC class.</li>
     *   <li>{@code <str>npc_stats.csv}: a header, then per class its index, name, NPC index,
     *       level name and the accumulated {@link PredictionPerformance} as CSV row.</li>
     * </ul>
     * A prediction value {@code >= 0} is counted as a positive prediction. I/O failures are
     * printed and not propagated.
     * NOTE(review): unlike writePredictOutput, the file names have no '_' between prefix and
     * suffix; kept unchanged in case downstream tooling relies on them.
     *
     * @param file         output directory; created if it does not exist
     * @param str          file name prefix
     * @param trainingData source of the NPC fingerprint version (class list)
     * @param list         evaluated instances; instance {@code k} belongs to row {@code fArr[k]}
     * @param fArr         prediction values, one row per instance and one column per NPC class
     */
    private static void writeNPCPredictOutput(File file, String str, TrainingData trainingData, List<EvaluationInstance> list, float[][] fArr) {
        file.mkdirs();
        final int numberOfClasses = trainingData.NPCVersion.size();
        final NPCFingerprintVersion.NPCProperty[] nPCPropertyArr = new NPCFingerprintVersion.NPCProperty[numberOfClasses];
        final PredictionPerformance.Modify[] modifyArr = new PredictionPerformance.Modify[numberOfClasses];
        for (int k = 0; k < numberOfClasses; k++) {
            modifyArr[k] = new PredictionPerformance(0.0d, 0.0d, 0.0d, 0.0d, 0.0d).modify();
            nPCPropertyArr[k] = (NPCFingerprintVersion.NPCProperty) trainingData.NPCVersion.getMolecularProperty(k);
        }
        // try-with-resources keeps the writer open for the whole loop. The previous version closed
        // it in a per-iteration finally, so every instance after the first failed with
        // "Stream closed" (and its catch block called th.addSuppressed(th), which itself throws).
        try (BufferedWriter writer = FileUtils.getWriter(new File(file, str + "npc_prediction.csv"))) {
            for (int i3 = 0; i3 < list.size(); i3++) {
                final EvaluationInstance evaluationInstance = list.get(i3);
                writer.write(evaluationInstance.name);
                writer.write('\t');
                writer.write(evaluationInstance.compound.inchiKey);
                writer.write('\t');
                writer.write(evaluationInstance.compound.npcLabel.toOneZeroString());
                final boolean[] booleanArray = evaluationInstance.compound.npcLabel.toBooleanArray();
                final float[] fArr2 = fArr[i3];
                for (int i4 = 0; i4 < fArr2.length; i4++) {
                    writer.write('\t');
                    writer.write(String.valueOf(fArr2[i4]));
                    // non-negative prediction value == predicted positive
                    modifyArr[i4].update(booleanArray[i4], fArr2[i4] >= 0.0f);
                }
                writer.newLine();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        try (BufferedWriter writer = FileUtils.getWriter(new File(file, str + "npc_stats.csv"))) {
            // NOTE(review): no newLine() is written after the header or after each row; presumably
            // csvHeader()/toCsvRow() terminate their own lines — verify.
            writer.write("index\tname\tid\ttype\t" + PredictionPerformance.csvHeader());
            for (int i5 = 0; i5 < nPCPropertyArr.length; i5++) {
                final NPCFingerprintVersion.NPCProperty nPCProperty = nPCPropertyArr[i5];
                writer.write(String.valueOf(i5));
                writer.write('\t');
                writer.write(nPCProperty.name);
                writer.write('\t');
                writer.write(String.valueOf(nPCProperty.npcIndex));
                writer.write('\t');
                writer.write(nPCProperty.level.name);
                writer.write('\t');
                writer.write(modifyArr[i5].done().toCsvRow());
            }
        } catch (IOException e2) {
            e2.printStackTrace();
        }
    }

    /**
     * Continues training of a previously stored CANOPUS model and saves the result.
     *
     * <p>Procedure as implemented:
     * <ul>
     *   <li>Load the model from {@code file2} and feed its weight matrices into the TensorFlow
     *       model from {@code file}.</li>
     *   <li>Compute Platt estimates.</li>
     *   <li>Run 400 training steps on generated batches. Every 50 steps, report, reset the weights
     *       and report again.</li>
     *   <li>Reset the weights once more and run 1000 further steps. Every 10th of these steps
     *       trains on the combined crossvalidation + independent batch instead of a generated
     *       one.</li>
     *   <li>Report and save via {@code saveWithoutPlattEstimate}.</li>
     * </ul>
     *
     * <p>Two threads are started on the same {@code BatchGenerator}; they are not joined here.
     *
     * @param file  location of the TensorFlow model
     * @param file2 stored CANOPUS model passed to {@code Canopus.loadFromFile}
     * @param str   regular expression handed to {@code TrainingData}
     * @param z     currently unused
     * @throws IOException if loading the model or the training data fails
     */
    public static void continueModel(File file, File file2, String str, boolean z) throws IOException {
        final Canopus loadFromFile = Canopus.loadFromFile(file2);
        final TrainingData trainingData = new TrainingData(new File("."), Pattern.compile(str));
        final BatchGenerator batchGenerator = new BatchGenerator(trainingData, 20);
        batchGenerator.iterationNum.set(30000);
        final Thread thread = new Thread(batchGenerator);
        final Thread thread2 = new Thread(batchGenerator);
        thread.start();
        thread2.start();
        // NOTE(review): the result of this first poll is discarded (and not closed) as before.
        batchGenerator.poll(0);
        final ArrayList<EvaluationInstance> allInstances = new ArrayList<>(trainingData.crossvalidation);
        allInstances.addAll(trainingData.independent);
        // Resources close in reverse order: the model first, then the evaluation batch — the same
        // order as before, but now also on failure.
        try (TrainingBatch generateBatch = trainingData.generateBatch(allInstances);
             TensorflowModel tensorflowModel = new TensorflowModel(file)) {
            final TensorflowModel.Resetter feedWeightMatrices = tensorflowModel.feedWeightMatrices(loadFromFile);
            feedWeightMatrices.resetWeights();
            final double[][] plattEstimate = tensorflowModel.plattEstimate(trainingData);
            final double[][] plattEstimateForNPC = tensorflowModel.plattEstimateForNPC(trainingData, true);
            final List<TrainingBatch> evaluationBatches = Arrays.asList(generateBatch);
            final List<String> evaluationNames = Arrays.asList("all");
            reportStuff(evaluationBatches, evaluationNames, tensorflowModel, 0);
            for (int i2 = 0; i2 < 400; i2++) {
                try (TrainingBatch poll = batchGenerator.poll(30000 + i2)) {
                    continueModelTrainStep(tensorflowModel, poll, i2);
                }
                if (i2 % 50 == 0) {
                    reportStuff(evaluationBatches, evaluationNames, tensorflowModel, 0);
                    System.out.println("---> reset all weights.");
                    feedWeightMatrices.resetWeights();
                    reportStuff(evaluationBatches, evaluationNames, tensorflowModel, 0);
                }
            }
            feedWeightMatrices.resetWeights();
            System.out.println("---> reset all weights.");
            int i = 0;
            for (int i3 = 0; i3 < 1000; i3++) {
                i = 30000 + i3;
                if (i3 % 10 == 0) {
                    // every 10th step trains on the full crossvalidation + independent batch
                    continueModelTrainStep(tensorflowModel, generateBatch, i);
                } else {
                    try (TrainingBatch poll = batchGenerator.poll(i)) {
                        continueModelTrainStep(tensorflowModel, poll, i);
                    }
                }
            }
            reportStuff(evaluationBatches, evaluationNames, tensorflowModel, i);
            tensorflowModel.saveWithoutPlattEstimate(trainingData, 100, true, true, false, plattEstimate[0], plattEstimate[1], plattEstimateForNPC[0], plattEstimateForNPC[1]);
        }
    }

    /**
     * Runs one training step of {@code tensorflowModel} on {@code batch} and prints the first two
     * values returned by {@code train} (labelled loss and l2 norm) plus the wall-clock seconds.
     * Replaces the former print, which appended the {@code PrintStream} object as "l2 norm" and
     * the l2 norm as duration.
     */
    private static void continueModelTrainStep(TensorflowModel tensorflowModel, TrainingBatch batch, int iteration) throws IOException {
        final long start = System.currentTimeMillis();
        final double[] losses = tensorflowModel.train(batch);
        final double seconds = (System.currentTimeMillis() - start) / 1000.0d;
        System.out.println(iteration + ".)\tloss = " + losses[0] + "\tl2 norm = " + losses[1] + "\t (" + seconds + " s)");
    }

    /**
     * Prints the evaluation of up to four batches at iteration {@code i}: the evaluation and
     * crossvalidation batches always, the two independent batches only when they are non-null.
     * The parameters {@code iArr}, {@code list}, {@code report} and {@code report2} are not used.
     */
    private static void reportStuff(TrainingBatch trainingBatch, TrainingBatch trainingBatch2, TrainingBatch trainingBatch3, TrainingBatch trainingBatch4, int[] iArr, List<DummyMolecularProperty> list, Report report, Report report2, TensorflowModel tensorflowModel, int i) {
        final String iterationTag = i + ".) ";
        System.out.println("Evaluation " + iterationTag + tensorflowModel.evaluate(trainingBatch));
        System.out.println("Crossvalidation " + iterationTag + tensorflowModel.evaluate(trainingBatch2));
        if (trainingBatch3 == null && trainingBatch4 == null) {
            return;
        }
        if (trainingBatch3 != null) {
            System.out.println("Indep. " + iterationTag + tensorflowModel.evaluate(trainingBatch3));
        }
        if (trainingBatch4 != null) {
            System.out.println("Indep. Novel " + iterationTag + tensorflowModel.evaluate(trainingBatch4));
        }
    }

    /** Evaluates the NPC predictions of {@code tensorflowModel} on {@code trainingBatch} and prints the report. */
    private static void evalNPC(TrainingBatch trainingBatch, TensorflowModel tensorflowModel) {
        final Report npcReport = tensorflowModel.evaluateNPC(trainingBatch);
        final PrintStream out = System.out;
        out.print("NPC Evaluation:\t");
        out.println(npcReport);
    }

    /**
     * Evaluates every batch in {@code list}, prints each report prefixed by iteration {@code i}
     * and the matching label from {@code list2}, and returns the reports in input order.
     *
     * @param list      batches to evaluate
     * @param list2     labels, parallel to {@code list}
     * @param tensorflowModel model used for evaluation
     * @param i         iteration number shown in the output
     * @return one report per batch
     */
    private static Report[] reportStuff(List<TrainingBatch> list, List<String> list2, TensorflowModel tensorflowModel, int i) {
        final Report[] reports = new Report[list.size()];
        for (int idx = 0; idx < reports.length; ++idx) {
            final TrainingBatch batch = list.get(idx);
            final String label = list2.get(idx);
            reports[idx] = tensorflowModel.evaluate(batch);
            System.out.println(i + ".) " + label + ":\t" + reports[idx]);
        }
        return reports;
    }

    /**
     * One-off maintenance routine for stored model files.
     *
     * <p>It loads {@code canopus_1.data.gz}, sets its CDK fingerprint version to
     * {@code TrainingData.VERSION} and replaces its CDK mask. The new mask enables only the
     * indices found in the first tab-separated column of {@code trainable_indizes.csv}; the
     * header line and blank lines are skipped. The result is written to
     * {@code canopus_fp.data.gz}. I/O errors are printed, not rethrown.
     */
    private static void fix() {
        try {
            final Canopus loadFromFile = Canopus.loadFromFile(new File("canopus_1.data.gz"));
            loadFromFile.cdkFingerprintVersion = TrainingData.VERSION;
            final String[] readLines = FileUtils.readLines(new File("trainable_indizes.csv"));
            final MaskedFingerprintVersion.Builder buildMaskFor = MaskedFingerprintVersion.buildMaskFor(loadFromFile.cdkFingerprintVersion);
            buildMaskFor.disableAll();
            // line 0 is the header
            for (int i = 1; i < readLines.length; i++) {
                final String line = readLines[i];
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank (e.g. trailing) lines instead of failing in parseInt
                }
                buildMaskFor.enable(Integer.parseInt(line.split("\t")[0].trim()));
            }
            loadFromFile.cdkMask = buildMaskFor.toMask();
            loadFromFile.writeToFile(new File("canopus_fp.data.gz"));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Renders a rank-2 float tensor as text: an opening "{", one tab-indented line per row
     * (formatted by {@link Arrays#toString(float[])}), and a closing "}".
     *
     * @param tensor tensor whose first two shape dimensions are used as rows and columns
     * @return the textual matrix representation
     */
    public static String tensor2string(Tensor tensor) {
        final long[] shape = tensor.shape();
        final float[][] matrix = new float[(int) shape[0]][(int) shape[1]];
        tensor.copyTo(matrix);
        final StringBuilder builder = new StringBuilder("{\n");
        for (int row = 0; row < matrix.length; ++row) {
            builder.append('\t').append(Arrays.toString(matrix[row])).append('\n');
        }
        return builder.append('}').toString();
    }

    static {
        System.setProperty("org.apache.commons.logging.Log", "org.apache.commons.logging.impl.NoOpLog");
        System.setProperty("de.unijena.bioinf.ms.propertyLocations", "sirius.build.properties, csi_fingerid.build.properties");
        try {
            InputStream resourceAsStream = Learn.class.getResourceAsStream("/logging.properties");
            try {
                LogManager.getLogManager().readConfiguration(resourceAsStream);
                if (resourceAsStream != null) {
                    resourceAsStream.close();
                }
            } finally {
            }
        } catch (IOException e) {
            System.err.println("Could not read logging configuration.");
            e.printStackTrace();
        }
        FLINDEX = -1;
        FLGINDEX = -1;
        REGSTREN = 1.0d;
    }
}
