package de.unijena.bioinf.fingerid.cli.tools;

import com.google.common.base.Joiner;
import de.unijena.bioinf.ChemistryBase.chem.InChI;
import de.unijena.bioinf.ChemistryBase.fp.AbstractFingerprint;
import de.unijena.bioinf.ChemistryBase.fp.ArrayFingerprint;
import de.unijena.bioinf.ChemistryBase.fp.BooleanFingerprint;
import de.unijena.bioinf.ChemistryBase.fp.CdkFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.Fingerprint;
import de.unijena.bioinf.ChemistryBase.fp.MaskedFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.ChemistryBase.fp.ProbabilityFingerprint;
import de.unijena.bioinf.ChemistryBase.fp.Tanimoto;
import de.unijena.bioinf.ChemistryBase.ms.Ms2Experiment;
import de.unijena.bioinf.ChemistryBase.ms.ft.FTree;
import de.unijena.bioinf.chemdb.ChemicalDatabase;
import de.unijena.bioinf.fingerid.Fingerprinter;
import de.unijena.bioinf.fingerid.KernelCentering;
import de.unijena.bioinf.fingerid.KernelMatrix;
import de.unijena.bioinf.fingerid.KernelToNumpyConverter;
import de.unijena.bioinf.fingerid.MatrixUtils;
import de.unijena.bioinf.fingerid.Prediction;
import de.unijena.bioinf.fingerid.Predictor;
import de.unijena.bioinf.fingerid.SpectralPreprocessor;
import de.unijena.bioinf.fingerid.Train;
import de.unijena.bioinf.fingerid.TrainResult;
import de.unijena.bioinf.fingerid.TrainedCSIFingerId;
import de.unijena.bioinf.fingerid.cli.CliTool;
import de.unijena.bioinf.fingerid.cli.Compound;
import de.unijena.bioinf.fingerid.cli.Configuration;
import de.unijena.bioinf.fingerid.cli.Reporter;
import de.unijena.bioinf.fingerid.cli.ToolSet;
import de.unijena.bioinf.sirius.IdentificationResult;
import de.unijena.bioinf.sirius.Sirius;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.json.JSONException;
import org.openscience.cdk.exception.CDKException;

/* loaded from: input_file:de/unijena/bioinf/fingerid/cli/tools/EvaluateSingleKernel.class */
/**
 * CLI tool {@code evaluate-kernel}: cross-validates fingerprint predictors trained on a weighted
 * sum of one or more precomputed kernel matrices and writes the per-fingerprint performance to
 * {@code kernelEval_<label>.txt}. With {@code --independent <dir>} the trained model is also applied
 * to every {@code .ms} file in that directory whose 2D InChI key is not part of the training set.
 */
public class EvaluateSingleKernel implements CliTool {

    /** Default kernel preprocessing: {@code MatrixUtils.normalize}, then centering. */
    private static final int MODE_DEFAULT = 0;
    /** Kernel preprocessing selected by {@code --centerNormalize}. */
    private static final int MODE_CENTER_NORMALIZE = 1;
    /** Kernel preprocessing selected by {@code --no-norm}: scale by the largest entry, then center. */
    private static final int MODE_NO_NORM = 2;

    /** Kernel preprocessing mode of the current {@link #run}; one of the {@code MODE_*} values. */
    protected int MODE = MODE_DEFAULT;

    /**
     * One kernel matrix, preprocessed in place, together with what is needed to rebuild a
     * {@code KernelMatrix} for prediction (centering, raw diagonal, weight).
     */
    public static class KernelM {
        /** Kernel name as given on the command line. */
        protected String name;
        /** The kernel matrix; preprocessed in place by the constructor. */
        protected double[][] matrix;
        /** Centering fitted on the training matrix. */
        protected KernelCentering centering;
        /** Diagonal of the matrix captured before any preprocessing. */
        protected double[] norm;
        /** Weight of this kernel in the combined kernel. */
        protected double weight;

        /**
         * Preprocesses {@code matrix} in place according to {@code mode}.
         *
         * @param name   kernel name
         * @param mode   0: normalize, then center; 1: center with {@code KernelCentering(m, true)}
         *               (stored), then apply a second {@code KernelCentering(m, false)} fitted on the
         *               result (not stored); any other value: scale by the largest entry, then center
         * @param matrix square kernel matrix; modified in place
         * @param weight weight of this kernel in the combined kernel
         */
        public KernelM(String name, int mode, double[][] matrix, double weight) {
            this.name = name;
            this.matrix = matrix;
            // Keep the raw diagonal before preprocessing changes the matrix.
            this.norm = new double[matrix.length];
            for (int k = 0; k < matrix.length; k++) {
                this.norm[k] = matrix[k][k];
            }
            if (mode == MODE_DEFAULT) {
                MatrixUtils.normalize(matrix);
                this.centering = new KernelCentering(matrix, false);
                this.centering.applyToTrainMatrix(matrix);
            } else if (mode == MODE_CENTER_NORMALIZE) {
                this.centering = new KernelCentering(matrix, true);
                this.centering.applyToTrainMatrix(matrix);
                new KernelCentering(matrix, false).applyToTrainMatrix(matrix);
            } else {
                double max = 0.0d;
                for (double[] row : matrix) {
                    for (double value : row) {
                        max = Math.max(value, max);
                    }
                }
                // Without a positive maximum, 1/max would fill the matrix with Inf/NaN.
                if (max > 0.0d) {
                    MatrixUtils.applyScale(matrix, 1.0d / max);
                }
                this.centering = new KernelCentering(matrix, false);
                this.centering.applyToTrainMatrix(matrix);
            }
            this.weight = weight;
        }
    }

    /**
     * Trains on the configured compounds with the combined kernel and writes the report.
     *
     * <p>Arguments taken from {@code configuration.getArgs()}:
     * <ul>
     *   <li>bare arguments are kernel names, optionally {@code name:weight} (weight defaults to 1.0);</li>
     *   <li>{@code --name <label>} replaces the joined kernel names in the report file name;</li>
     *   <li>{@code --weights <file>} adds kernels from lines of the form {@code name weight};</li>
     *   <li>{@code --alignf} reads {@code mkl/ALIGNF.weights} instead of any {@code --weights} file.</li>
     * </ul>
     * Flags read through {@code configuration}: {@code --centerNormalize} / {@code --no-norm}
     * (preprocessing mode), {@code --maccs} (intersect the fingerprint mask with the MACCS
     * fingerprint), {@code --folds <n>} (cross-validation folds, default 5) and
     * {@code --independent <dir>} (independent evaluation).
     *
     * <p>NOTE(review): only the {@code --name}/{@code --weights} values are stripped from the argument
     * list here; presumably {@code Configuration.getArgs()} does not contain the values of
     * {@code --folds}/{@code --independent} — confirm, otherwise they would be read as kernel names.
     *
     * @throws IllegalArgumentException if no kernel is given or {@code --name}/{@code --weights}
     *                                  has no value
     */
    @Override // de.unijena.bioinf.fingerid.cli.CliTool
    public void run(ToolSet toolSet, final Configuration configuration, Reporter reporter) {
        final boolean maccsOnly = configuration.hasArg("--maccs");
        this.MODE = MODE_DEFAULT;
        if (configuration.hasArg("--centerNormalize")) {
            this.MODE = MODE_CENTER_NORMALIZE;
        } else if (configuration.hasArg("--no-norm")) {
            this.MODE = MODE_NO_NORM;
        }
        // Captured once so the loader tasks do not read the mutable field from pool threads.
        final int mode = this.MODE;

        List<String> args = new ArrayList<>(Arrays.asList(configuration.getArgs()));
        String reportName = extractOptionValue(args, "--name");
        String weightsPath = extractOptionValue(args, "--weights");
        File weightsFile = weightsPath == null ? null : new File(weightsPath);
        if (configuration.hasArg("--alignf")) {
            weightsFile = new File("mkl/ALIGNF.weights");
        }
        if (weightsFile != null) {
            appendWeightsFile(weightsFile, args);
        }
        System.out.println(args);

        final List<String> kernelNames = new ArrayList<>();
        final List<Double> kernelWeights = new ArrayList<>();
        for (String arg : args) {
            if (arg.startsWith("--")) {
                continue;
            }
            if (arg.contains(":")) {
                String[] parts = arg.split(":");
                kernelNames.add(parts[0]);
                kernelWeights.add(Double.parseDouble(parts[1]));
            } else {
                kernelNames.add(arg);
                kernelWeights.add(1.0d);
            }
        }
        if (kernelNames.isEmpty()) {
            throw new IllegalArgumentException(
                    "No kernel given; pass at least one kernel name (optionally as name:weight)");
        }

        // availableProcessors() / 2 is 0 on a single-CPU machine, which newFixedThreadPool rejects.
        ExecutorService pool =
                Executors.newFixedThreadPool(Math.max(1, Runtime.getRuntime().availableProcessors() / 2));
        final List<Future<KernelM>> kernelFutures = new ArrayList<>();
        for (int k = 0; k < kernelNames.size(); k++) {
            final String kernelName = kernelNames.get(k);
            final double kernelWeight = kernelWeights.get(k);
            System.out.println("Process " + kernelName);
            kernelFutures.add(pool.submit(new Callable<KernelM>() {
                @Override
                public KernelM call() throws Exception {
                    return new KernelM(kernelName, mode,
                            new KernelToNumpyConverter().readFromFile(configuration.getKernelFile(kernelName)),
                            kernelWeight);
                }
            }));
        }
        Future<List<Compound>> compoundsFuture = pool.submit(new Callable<List<Compound>>() {
            @Override
            public List<Compound> call() throws Exception {
                return configuration.getCompounds(true);
            }
        });
        pool.shutdown();

        MaskedFingerprintVersion fingerprintVersion = configuration.getMaskedFingerprintVersion();
        if (maccsOnly) {
            fingerprintVersion = fingerprintVersion.getIntersection(configuration.getFingerprintVersion()
                    .getMaskFor(new CdkFingerprintVersion.USED_FINGERPRINTS[]{CdkFingerprintVersion.USED_FINGERPRINTS.MACCS}));
        }
        try {
            List<Compound> compounds = compoundsFuture.get();
            int size = compounds.size();
            InChI[] inchis = new InChI[size];
            ArrayFingerprint[] fingerprints = new ArrayFingerprint[size];
            // 2D InChI keys of the training structures; used to exclude them from independent data.
            HashSet<String> trainingKeys = new HashSet<>();
            for (int k = 0; k < size; k++) {
                Compound compound = compounds.get(k);
                fingerprints[k] = fingerprintVersion.mask(configuration.getFingerprint(compound)).asArray();
                inchis[k] = compound.getInchi();
                trainingKeys.add(compound.getInchi().key2D());
            }

            // Futures are read by position so duplicate kernel names keep their own matrix/weight.
            KernelM[] kernels = new KernelM[kernelFutures.size()];
            for (int k = 0; k < kernels.length; k++) {
                kernels[k] = kernelFutures.get(k).get();
            }
            double[][] combined = MatrixUtils.scale(kernels[0].matrix, kernels[0].weight);
            for (int k = 1; k < kernels.length; k++) {
                MatrixUtils.applyWeightedSum(combined, kernels[k].matrix, kernels[k].weight);
            }

            Train train = new Train(inchis, fingerprints, combined);
            train.setCSelections(configuration.getCSelection());
            String folds = configuration.getArg("--folds");
            train.sequentialCrossvalidation(folds == null ? 5 : Integer.parseInt(folds));
            TrainResult result = train.startTraining();

            String label = reportName == null ? Joiner.on('_').join(kernelNames) : reportName;
            try (PrintStream report = new PrintStream("kernelEval_" + label + ".txt")) {
                writeSummary(report, kernelNames, result.predictors);
                if (configuration.getArg("--independent") != null) {
                    evaluateIndependent(configuration, fingerprintVersion, compounds, fingerprints,
                            kernelNames, kernels, result, trainingKeys, report);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (ExecutionException e) {
            e.printStackTrace();
        } catch (JSONException e) {
            e.printStackTrace();
        }
    }

    /**
     * Removes the first occurrence of {@code option} and the argument following it from {@code args}.
     *
     * @return the value that followed the option, or {@code null} if the option is absent
     * @throws IllegalArgumentException if the option is the last argument and so has no value
     */
    private static String extractOptionValue(List<String> args, String option) {
        int index = args.indexOf(option);
        if (index < 0) {
            return null;
        }
        if (index + 1 >= args.size()) {
            throw new IllegalArgumentException(option + " requires a value");
        }
        // Remove by index: removing by value could hit an earlier, identical kernel name.
        String value = args.remove(index + 1);
        args.remove(index);
        return value;
    }

    /**
     * Appends one kernel spec per non-blank line of {@code file} to {@code kernelSpecs}. A line
     * {@code name weight} becomes {@code name:weight}; a line with only a name is appended as is.
     * Read errors are printed and otherwise ignored.
     */
    private static void appendWeightsFile(File file, List<String> kernelSpecs) {
        try {
            for (String line : Files.readAllLines(file.toPath(), Charset.forName("UTF-8"))) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue; // a blank line has no weight column to read
                }
                String[] fields = trimmed.split("\\s+", 2);
                kernelSpecs.add(fields.length == 2 ? fields[0] + ":" + fields[1] : fields[0]);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Writes the used kernels, the mean F score over predictors with at least 25 and at least 100
     * examples (counted as tp + fn), and every predictor's performance.
     */
    private static void writeSummary(PrintStream report, List<String> kernelNames, Predictor[] predictors) {
        report.println("Used Kernels: " + Joiner.on(", ").join(kernelNames));
        report.print("Average performance over all fingerprints with at least 25 samples: ");
        double sumF25 = 0.0d;
        double sumF100 = 0.0d;
        int count25 = 0;
        int count100 = 0;
        for (Predictor predictor : predictors) {
            if (predictor.getTp() + predictor.getFn() >= 25.0d) {
                sumF25 += predictor.getPerformance().getF();
                count25++;
                if (predictor.getTp() + predictor.getFn() >= 100.0d) {
                    count100++;
                    sumF100 += predictor.getPerformance().getF();
                }
            }
        }
        report.println(sumF25 / count25);
        report.print("Average performance over all fingerprints with at least 100 samples: " + (sumF100 / count100));
        report.println("\nDetails:");
        for (Predictor predictor : predictors) {
            report.println(predictor.getPerformance());
        }
    }

    /**
     * Builds a {@code TrainedCSIFingerId} from the training data and trained predictors, predicts
     * every {@code .ms} file in the {@code --independent} directory whose annotated InChI is not in
     * {@code trainingKeys}, and appends the mean F score and mean probabilistic Tanimoto to
     * {@code report}. Failures on single files are printed and the file is skipped.
     *
     * @throws IOException if the independent directory cannot be listed
     */
    private static void evaluateIndependent(Configuration configuration, MaskedFingerprintVersion fingerprintVersion,
            List<Compound> compounds, ArrayFingerprint[] fingerprints, List<String> kernelNames, KernelM[] kernels,
            TrainResult result, HashSet<String> trainingKeys, PrintStream report)
            throws IOException, InterruptedException, ExecutionException, JSONException {
        TrainedCSIFingerId model = new TrainedCSIFingerId(fingerprintVersion, compounds.size(), kernelNames.size());
        for (int k = 0; k < compounds.size(); k++) {
            Compound compound = compounds.get(k);
            model.getInchis()[k] = compound.getInchi().in2D;
            model.getNames()[k] = compound.getName();
            model.getPrecursorMz()[k] = configuration.getPrecursorMass(compound);
            model.getTrainingFingerprints()[k] = fingerprints[k];
            model.getTrainingTrees()[k] = configuration.getCompoundTree(compound);
            model.getTrainingSpectra()[k] = configuration.getSpectrum(compound);
        }
        for (int k = 0; k < kernels.length; k++) {
            KernelM kernel = kernels[k];
            model.getKernels()[k] = new KernelMatrix(kernelNames.get(k), kernel.centering, kernel.norm, kernel.weight);
        }
        Arrays.fill(model.getKernelNormalizationVector(), 1.0d);
        for (int k = 0; k < fingerprintVersion.size(); k++) {
            model.getPredictors()[k] = result.predictors[k];
        }

        PredictionPerformance.Modify[] perFingerprint = new PredictionPerformance.Modify[result.predictors.length];
        for (int k = 0; k < perFingerprint.length; k++) {
            perFingerprint[k] = new PredictionPerformance().modify();
        }
        double tanimotoSum = 0.0d;
        int evaluated = 0;

        Prediction prediction = new Prediction(model);
        try {
            Sirius sirius = configuration.getSirius();
            SpectralPreprocessor preprocessor = new SpectralPreprocessor(sirius.getMs2Analyzer());
            Fingerprinter fingerprinter = configuration.getFingerprinter();
            ChemicalDatabase fingerprintDb = configuration.getFingerprintDb();
            try {
                File directory = new File(configuration.getArg("--independent"));
                File[] files = directory.listFiles();
                if (files == null) {
                    throw new IOException("Cannot list independent data directory: " + directory);
                }
                for (File file : files) {
                    if (!file.getName().endsWith(".ms")) {
                        continue;
                    }
                    try {
                        Ms2Experiment experiment = (Ms2Experiment) sirius.parseExperiment(file).next();
                        InChI inchi = (InChI) experiment.getAnnotation(InChI.class);
                        // Skip spectra without a structure and structures seen during training.
                        if (inchi == null || trainingKeys.contains(inchi.key2D())) {
                            continue;
                        }
                        IdentificationResult computation = sirius.compute(experiment, experiment.getMolecularFormula());
                        sirius.beautifyTree(computation, experiment);
                        AbstractFingerprint reference = fingerprintDb.lookupFingerprintByInChI(inchi);
                        if (reference == null) {
                            try {
                                reference = new BooleanFingerprint(fingerprintVersion.getMaskedFingerprintVersion(),
                                        fingerprinter.fingerprintsToBooleans(fingerprinter.computeFingerprints(
                                                fingerprinter.convertInchi2Mol(inchi.in2D))));
                            } catch (CDKException e) {
                                e.printStackTrace();
                            }
                        }
                        if (reference == null) {
                            continue; // no reference fingerprint; the CDK error was printed above
                        }
                        Fingerprint expected = fingerprintVersion.mask(reference);
                        FTree tree = computation.getBeautifulTree();
                        preprocessor.preprocessTrees(tree);
                        ProbabilityFingerprint predicted = prediction.predictProbabilityFingerprint(
                                preprocessor.preprocess(experiment, tree), tree, preprocessor.getPrecursorMass(tree));
                        double[] probabilities = predicted.toProbabilityArray();
                        boolean[] truth = expected.toBooleanArray();
                        for (int k = 0; k < perFingerprint.length; k++) {
                            perFingerprint[k].update(truth[k], probabilities[k] >= 0.5d);
                        }
                        double tanimoto = Tanimoto.probabilisticTanimoto(expected, predicted).expectationValue();
                        tanimotoSum += tanimoto;
                        evaluated++;
                        System.out.println("Predict " + file.getName() + " with tanimoto = " + tanimoto);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            } finally {
                fingerprintDb.close();
            }
        } finally {
            prediction.shutdown();
        }

        double fSum = 0.0d;
        for (PredictionPerformance.Modify modify : perFingerprint) {
            fSum += modify.done().getF();
        }
        report.println("Average F1 for independent data: " + (fSum / perFingerprint.length)
                + " with average tanimoto: " + (tanimotoSum / evaluated));
        for (PredictionPerformance.Modify modify : perFingerprint) {
            report.println(modify.done().toString());
        }
    }

    @Override // de.unijena.bioinf.fingerid.cli.CliTool
    public String getName() {
        return "evaluate-kernel";
    }

    @Override // de.unijena.bioinf.fingerid.cli.CliTool
    public String getDescription() {
        return "";
    }
}
