package de.unijena.bioinf.fingerid.cli.tools;

import de.unijena.bioinf.ChemistryBase.chem.CompoundWithAbstractFP;
import de.unijena.bioinf.ChemistryBase.chem.InChI;
import de.unijena.bioinf.ChemistryBase.fp.MaskedFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.ChemistryBase.fp.ProbabilityFingerprint;
import de.unijena.bioinf.ConfidenceScore.EvalConfidenceScore;
import de.unijena.bioinf.ConfidenceScore.TrainConfidenceScore;
import de.unijena.bioinf.chemdb.BioFilter;
import de.unijena.bioinf.chemdb.ChemicalDatabase;
import de.unijena.bioinf.chemdb.DatabaseException;
import de.unijena.bioinf.fingerid.Mask;
import de.unijena.bioinf.fingerid.TrainedCSIFingerId;
import de.unijena.bioinf.fingerid.cli.CliTool;
import de.unijena.bioinf.fingerid.cli.Configuration;
import de.unijena.bioinf.fingerid.cli.Reporter;
import de.unijena.bioinf.fingerid.cli.ToolSet;
import gnu.trove.set.hash.TIntHashSet;

import java.io.BufferedReader;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

/* loaded from: input_file:de/unijena/bioinf/fingerid/cli/tools/ConfidenceScore.class */
/**
 * CLI tool {@code confidence}: trains a confidence-score model, applies it ("predict"), or
 * cross-validates it, using CSI:FingerID fingerprint predictions read from a tab-separated query file.
 *
 * <p>Each command either takes its paths explicitly on the command line or falls back to the paths
 * supplied by {@link Configuration}. If the last argument is {@code --pubchem}, the fingerprint
 * database uses {@link BioFilter#ALL} instead of {@link BioFilter#ONLY_BIO}.
 */
public class ConfidenceScore implements CliTool {
    private static final String helpMessage = "usage: \ncommand 'train': train a model \n       train <queryFile> <path to outputModel> <solve 'primal' or 'dual'> [<maskfile>]\ncommand 'predict': predict using trained model\n       predict <queryFile> <modelFile> <outputFile> [<maskfile>]\ncommand 'crossvalidation': crossvalidation on a dataset \n       crossvalidation <queryFile> <outputFile> <solve 'primal' or 'dual'> [<maskfile>]\nor simply: 'train [primal|dual]' or 'predict' or 'crossvalidation [primal|dual]'\nadding --pubchem as last command causes the tool to search in whole pubchem";

    /** Usage text reported when a command receives an unsupported number of arguments. */
    private static final String WRONG_ARGUMENTS_MESSAGE = "Wrong number of arguments.\n" + helpMessage;

    /**
     * Value of the optional 5th {@code train} argument that builds a mask disabling every
     * fingerprint index not listed in {@link TrainedCSIFingerId#getFingerprintIndizes()}.
     */
    private static final String NO_MASK_OPTION = "--nomask";

    @Override
    public void run(ToolSet toolSet, Configuration configuration, Reporter reporter) {
        String[] args = configuration.getArgs();
        if (args.length == 0 || args[0].matches("-*help")) {
            reporter.report(this, helpMessage);
            return;
        }
        try {
            final ChemicalDatabase fingerprintDb = configuration.getFingerprintDb();
            // Must be checked before the database is configured below, otherwise a missing
            // database surfaces as a NullPointerException instead of this error report.
            if (fingerprintDb == null) {
                reporter.error(this, new RuntimeException("no Chemical database found"));
                return;
            }
            if (args[args.length - 1].equalsIgnoreCase("--pubchem")) {
                fingerprintDb.setBioFilter(BioFilter.ALL);
                args = Arrays.copyOf(args, args.length - 1);
            } else {
                fingerprintDb.setBioFilter(BioFilter.ONLY_BIO);
            }
            // "--pubchem" alone leaves no command to run.
            if (args.length == 0) {
                reporter.report(this, helpMessage);
                return;
            }
            try {
                final TrainedCSIFingerId fingerId = TrainedCSIFingerId.load(configuration.fingeridFile());
                final Mask configuredMask = configuration.getMask();
                try {
                    // Locale.ROOT: locale-sensitive lower-casing (e.g. Turkish) would break matching.
                    final String command = args[0].toLowerCase(Locale.ROOT);
                    if (command.equals("train")) {
                        runTrain(args, configuration, reporter, fingerId, configuredMask, fingerprintDb);
                    } else if (command.equals("predict")) {
                        runPredict(args, configuration, reporter, fingerId, fingerprintDb);
                    } else if (command.equals("crossvalidation")) {
                        runCrossvalidation(args, configuration, reporter, fingerId, configuredMask, fingerprintDb);
                    } else {
                        reporter.report(this, helpMessage);
                    }
                } catch (Exception e) {
                    reporter.error(this, e);
                }
            } catch (IOException e) {
                reporter.error(this, e);
            }
        } catch (DatabaseException e) {
            reporter.error(this, e);
        }
    }

    /**
     * Handles {@code train [primal|dual]} (paths from the configuration) and
     * {@code train <queryFile> <outputModel> <solver> [<maskfile>|--nomask]}.
     *
     * @param args           command-line arguments with {@code --pubchem} already stripped
     * @param configuredMask mask from the configuration, used unless a 5th argument overrides it
     */
    private void runTrain(String[] args, Configuration configuration, Reporter reporter,
                          TrainedCSIFingerId fingerId, Mask configuredMask,
                          ChemicalDatabase fingerprintDb) throws Exception {
        final Path queryFile;
        final Path modelFile;
        final String solver;
        Mask mask = configuredMask;
        if (args.length == 2 && isSolverName(args[1])) {
            queryFile = configuration.getCrossvalidationPredictionFile().toPath();
            modelFile = configuration.confidenceScoreModel().toPath();
            solver = args[1];
        } else {
            if (args.length != 4 && args.length != 5) {
                reporter.report(this, WRONG_ARGUMENTS_MESSAGE);
                return;
            }
            queryFile = Paths.get(args[1]);
            modelFile = Paths.get(args[2]);
            if (!isSolverName(args[3])) {
                reporter.report(this, helpMessage);
                return;
            }
            solver = args[3];
            if (args.length == 5) {
                mask = NO_MASK_OPTION.equals(args[4])
                        ? maskForTrainedFingerprints(fingerId)
                        : readMaskFile(Paths.get(args[4]));
            }
        }
        if (!Files.exists(queryFile)) {
            reporter.error(this, new RuntimeException("no query file found: " + queryFile.toString()));
            return;
        }
        // NOTE(review): the flag handed to EvalConfidenceScore is "not dual", presumably selecting
        // the primal formulation -- confirm against EvalConfidenceScore.
        final boolean primal = !isDualSolver(solver);
        final MaskedFingerprintVersion version = fingerId.getMaskedFingerprintVersion();
        final PredictionPerformance[] performances = fingerId.getPredictionPerformances();
        final List<CompoundWithAbstractFP<ProbabilityFingerprint>> queries = parseQueries(queryFile, version, mask);
        if (fingerprintDb.getBioFilter() == BioFilter.ALL) {
            EvalConfidenceScore.train(queries, performances, version, modelFile, fingerprintDb, TrainConfidenceScore.AllLong(primal));
        } else {
            EvalConfidenceScore.train(queries, performances, version, modelFile, primal, fingerprintDb);
        }
    }

    /**
     * Handles {@code predict} (paths from the configuration) and
     * {@code predict <queryFile> <modelFile> <outputFile> [<maskfile>]}.
     */
    private void runPredict(String[] args, Configuration configuration, Reporter reporter,
                            TrainedCSIFingerId fingerId, ChemicalDatabase fingerprintDb) throws Exception {
        final Path queryFile;
        final Path modelFile;
        final Path outputFile;
        // NOTE(review): unlike train/crossvalidation, the configuration-default variant applies no
        // mask (not configuration.getMask()); kept as-is -- confirm this is intended.
        Mask mask = null;
        if (args.length == 1) {
            queryFile = configuration.getIndependentOutputFile().toPath();
            modelFile = configuration.confidenceScoreModel().toPath();
            outputFile = configuration.confidenceScoreDir().toPath().resolve("independent_confidenceScore.csv");
        } else {
            if (args.length != 4 && args.length != 5) {
                reporter.report(this, WRONG_ARGUMENTS_MESSAGE);
                return;
            }
            queryFile = Paths.get(args[1]);
            modelFile = Paths.get(args[2]);
            outputFile = Paths.get(args[3]);
            if (args.length == 5) {
                mask = readMaskFile(Paths.get(args[4]));
            }
        }
        // Same early, readable error as the other commands instead of a NoSuchFileException.
        if (!Files.exists(queryFile)) {
            reporter.error(this, new RuntimeException("no query file found: " + queryFile.toString()));
            return;
        }
        final MaskedFingerprintVersion version = fingerId.getMaskedFingerprintVersion();
        EvalConfidenceScore.predict(parseQueries(queryFile, version, mask), version, modelFile, outputFile, fingerprintDb);
    }

    /**
     * Handles {@code crossvalidation [primal|dual]} (paths from the configuration) and
     * {@code crossvalidation <queryFile> <outputFile> <solver> [<maskfile>]}.
     */
    private void runCrossvalidation(String[] args, Configuration configuration, Reporter reporter,
                                    TrainedCSIFingerId fingerId, Mask configuredMask,
                                    ChemicalDatabase fingerprintDb) throws Exception {
        final Path queryFile;
        final Path outputFile;
        final String solver;
        Mask mask = null;
        if (args.length == 2 && isSolverName(args[1])) {
            queryFile = configuration.getCrossvalidationPredictionFile().toPath();
            mask = configuredMask;
            outputFile = configuration.confidenceScoreDir().toPath().resolve("crossvalidation_confidenceScore.csv");
            solver = args[1];
        } else {
            if (args.length != 4 && args.length != 5) {
                reporter.report(this, WRONG_ARGUMENTS_MESSAGE);
                return;
            }
            queryFile = Paths.get(args[1]);
            outputFile = Paths.get(args[2]);
            if (!isSolverName(args[3])) {
                reporter.report(this, helpMessage);
                return;
            }
            solver = args[3];
            if (args.length == 5) {
                mask = readMaskFile(Paths.get(args[4]));
            }
        }
        if (!Files.exists(queryFile)) {
            reporter.error(this, new RuntimeException("no query file found: " + queryFile.toString()));
            return;
        }
        final boolean primal = !isDualSolver(solver);
        final MaskedFingerprintVersion version = fingerId.getMaskedFingerprintVersion();
        final PredictionPerformance[] performances = fingerId.getPredictionPerformances();
        final List<CompoundWithAbstractFP<ProbabilityFingerprint>> queries = parseQueries(queryFile, version, mask);
        // Exactly one cross-validation run (previously an extra unconditional run preceded this).
        if (fingerprintDb.getBioFilter() == BioFilter.ALL) {
            EvalConfidenceScore.crossvalidation(queries, performances, version, outputFile, fingerprintDb, TrainConfidenceScore.AllLong(primal));
        } else {
            EvalConfidenceScore.crossvalidation(queries, performances, version, outputFile, primal, fingerprintDb);
        }
    }

    /** Returns true if {@code arg} names a solver, i.e. contains "dual" or "prim" (case-insensitive). */
    private static boolean isSolverName(String arg) {
        final String lower = arg.toLowerCase(Locale.ROOT);
        return lower.contains("dual") || lower.contains("prim");
    }

    /** Returns true if {@code arg} selects the dual solver; "dual" wins over "prim" if both appear. */
    private static boolean isDualSolver(String arg) {
        return arg.toLowerCase(Locale.ROOT).contains("dual");
    }

    /**
     * Reads a mask from the first line of {@code maskFile} (whitespace-separated tokens).
     *
     * @throws IOException if the file cannot be read or is empty
     */
    private static Mask readMaskFile(Path maskFile) throws IOException {
        try (BufferedReader reader = Files.newBufferedReader(maskFile, Charset.defaultCharset())) {
            final String line = reader.readLine();
            if (line == null) {
                throw new IOException("empty mask file: " + maskFile);
            }
            return Mask.fromString(line.split("\\s+"));
        }
    }

    /**
     * Builds a mask over the full fingerprint version of {@code fingerId} that disables every
     * index not contained in {@link TrainedCSIFingerId#getFingerprintIndizes()}.
     */
    private static Mask maskForTrainedFingerprints(TrainedCSIFingerId fingerId) {
        final int size = fingerId.getMaskedFingerprintVersion().getMaskedFingerprintVersion().size();
        final Mask mask = new Mask(size);
        final TIntHashSet kept = new TIntHashSet(fingerId.getFingerprintIndizes());
        for (int index = 0; index < size; index++) {
            if (!kept.contains(index)) {
                mask.disableFingerprint(index);
            }
        }
        return mask;
    }

    @Override
    public String getName() {
        return "confidence";
    }

    @Override
    public String getDescription() {
        return "learns a linear SVM for predicting the confidence of a csi-fingerid prediction.";
    }

    /**
     * Parses a UTF-8, tab-separated query file into compounds with probability fingerprints.
     *
     * <p>The column layout is detected by {@link #searchColumns}. It needs an InChI column, an
     * InChIKey in the column directly before it, and a 0/1 fingerprint column. Every column after
     * the 0/1 column is parsed as a probability. If the first line does not match this layout, it
     * is treated as a header and the layout is detected from the second line. Blank lines are skipped.
     *
     * @param mask if non-null, probabilities are passed through {@code mask.unapply} and
     *             {@code maskedFingerprintVersion.mask}; otherwise they are used as-is
     * @throws RuntimeException if the file is empty, the layout cannot be detected, or a line is too short
     */
    private static List<CompoundWithAbstractFP<ProbabilityFingerprint>> parseQueries(
            Path path, MaskedFingerprintVersion maskedFingerprintVersion, Mask mask) throws IOException {
        final List<CompoundWithAbstractFP<ProbabilityFingerprint>> queries = new ArrayList<>();
        try (BufferedReader reader = Files.newBufferedReader(path, StandardCharsets.UTF_8)) {
            String line = reader.readLine();
            if (line == null) {
                throw new RuntimeException("Empty query file");
            }
            final int[] columns = {-1, -1, -1};
            if (!searchColumns(line.split("\t"), columns)) {
                // The first line may be a header: detect the layout from the second one instead.
                line = reader.readLine();
                if (line == null || !searchColumns(line.split("\t"), columns)) {
                    throw new RuntimeException("Expect a tab separated file with an InChI column, an InChI-key in the column before that and a fingerprint (01) column");
                }
            }
            final int inchiColumn = columns[0];
            final int keyColumn = columns[1];
            final int firstProbabilityColumn = columns[2] + 1;
            final int minimumColumns = Math.max(inchiColumn + 1, firstProbabilityColumn);
            for (; line != null; line = reader.readLine()) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                final String[] cells = line.split("\t");
                if (cells.length < minimumColumns) {
                    throw new RuntimeException("Malformed line in query file " + path + ": expected at least "
                            + minimumColumns + " columns but found " + cells.length);
                }
                final double[] probabilities = new double[cells.length - firstProbabilityColumn];
                for (int k = 0; k < probabilities.length; k++) {
                    probabilities[k] = Double.parseDouble(cells[firstProbabilityColumn + k]);
                }
                final InChI inchi = new InChI(cells[keyColumn], cells[inchiColumn]);
                final ProbabilityFingerprint fingerprint = mask == null
                        ? new ProbabilityFingerprint(maskedFingerprintVersion, probabilities)
                        : maskedFingerprintVersion.mask(mask.unapply(probabilities));
                queries.add(new CompoundWithAbstractFP<ProbabilityFingerprint>(inchi, fingerprint));
            }
        }
        return queries;
    }

    /**
     * Detects the column layout of one tab-split line.
     *
     * <p>On success, writes {inchiColumn, inchiKeyColumn, fingerprintColumn} into {@code columns}.
     * The InChI column is the first cell starting with "InChI=", and the cell directly before it must
     * look like an InChIKey. The fingerprint column is the first non-empty cell consisting only of
     * '0'/'1' that is not that InChI column.
     *
     * @return true if both the InChI and the fingerprint column were found; false (with
     *         {@code columns} possibly untouched) otherwise
     */
    private static boolean searchColumns(String[] cells, int[] columns) {
        int inchiColumn = -1;
        int keyColumn = -1;
        int fingerprintColumn = -1;
        for (int col = 0; col < cells.length; col++) {
            final String cell = cells[col];
            if (inchiColumn < 0 && cell.startsWith("InChI=")) {
                if (col == 0 || !isInChIKey(cells[col - 1])) {
                    return false;
                }
                inchiColumn = col;
                keyColumn = col - 1;
            } else if (fingerprintColumn < 0) {
                if (isBinaryString(cell)) {
                    fingerprintColumn = col;
                }
            } else if (inchiColumn >= 0) {
                break; // both columns found
            }
        }
        columns[0] = inchiColumn;
        columns[1] = keyColumn;
        columns[2] = fingerprintColumn;
        return inchiColumn >= 0 && fingerprintColumn >= 0;
    }

    /**
     * Returns true for upper-case InChIKey shapes: 14 letters, 14+'-'+10 letters (25 chars), or
     * 14+'-'+10+'-'+1 letters (27 chars).
     */
    private static boolean isInChIKey(String candidate) {
        final int length = candidate.length();
        if (length != 14 && length != 25 && length != 27) {
            return false;
        }
        for (int i = 0; i < length; i++) {
            final char c = candidate.charAt(i);
            // Index 25 only exists in the 27-char form, index 14 only in the 25/27-char forms.
            if (i == 14 || i == 25) {
                if (c != '-') {
                    return false;
                }
            } else if (!Character.isAlphabetic(c) || !Character.isUpperCase(c)) {
                return false;
            }
        }
        return true;
    }

    /** Returns true if {@code cell} is non-empty and consists only of '0' and '1' characters. */
    private static boolean isBinaryString(String cell) {
        if (cell.isEmpty()) {
            return false;
        }
        for (int i = 0; i < cell.length(); i++) {
            final char c = cell.charAt(i);
            if (c != '0' && c != '1') {
                return false;
            }
        }
        return true;
    }
}
