package de.unijena.bioinf.confidence_score;

import de.unijena.bioinf.ChemistryBase.algorithm.scoring.Scored;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.ChemistryBase.fp.ProbabilityFingerprint;
import de.unijena.bioinf.ChemistryBase.ms.CollisionEnergy;
import de.unijena.bioinf.ChemistryBase.ms.Ms2Experiment;
import de.unijena.bioinf.ChemistryBase.ms.Ms2Spectrum;
import de.unijena.bioinf.ChemistryBase.ms.Peak;
import de.unijena.bioinf.chemdb.FingerprintCandidate;
import de.unijena.bioinf.confidence_score.parameters.SuperParameters;
import de.unijena.bioinf.confidence_score.svm.SVMPredict;
import de.unijena.bioinf.confidence_score.svm.SVMUtils;
import de.unijena.bioinf.confidence_score.svm.TrainedSVM;
import de.unijena.bioinf.fingerid.blast.BayesnetScoring;
import de.unijena.bioinf.fingerid.blast.CSIFingerIdScoring;
import de.unijena.bioinf.fingerid.blast.FingerblastScoringMethod;
import de.unijena.bioinf.fingerid.blast.ScoringMethodFactory;
import de.unijena.bioinf.sirius.IdentificationResult;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/unijena/bioinf/confidence_score/CSICovarianceConfidenceScorer.class */
/**
 * Confidence scorer that ranks candidate structures by CSI:FingerID score and by covariance
 * (Bayesian network) score, builds a combined feature vector from both rankings and evaluates it
 * with a trained SVM.
 * <p>
 * The SVM is looked up in {@link #trainedSVMs} by an id of the form
 * {@code <ce>_<db>[<distance>].svm}, e.g. {@code feLow_All.svm} or {@code feMed_BioNodist.svm}:
 * <ul>
 *     <li>{@code <ce>} is the collision-energy class computed by {@link #makeCeString(List)};</li>
 *     <li>{@code <db>} is {@value #DB_ALL_ID} when no search-database candidates are given and
 *     {@value #DB_BIO_ID} otherwise;</li>
 *     <li>{@code <distance>} is only present for {@value #DB_BIO_ID}: {@value #DISTANCE_ID} if the
 *     search-database candidates have more than one distinct fingerprint, else {@value #NO_DISTANCE_ID}.</li>
 * </ul>
 */
public class CSICovarianceConfidenceScorer implements ConfidenceScorer {
    public static final String NO_DISTANCE_ID = "Nodist";
    public static final String DISTANCE_ID = "dist";
    public static final String DB_ALL_ID = "All";
    public static final String DB_BIO_ID = "Bio";
    public static final String CE_LOW = "feLow";
    public static final String CE_MED = "feMed";
    public static final String CE_HIGH = "feHi";
    public static final String CE_RAMP = "feRAMP";

    /** Minimum number of all-database candidates required to compute a confidence. */
    private static final int MIN_ALL_DB_CANDIDATES = 5;

    private final Map<String, TrainedSVM> trainedSVMs;
    private final BayesnetScoring covarianceScoringMethod;
    private final ScoringMethodFactory.CSIFingerIdScoringMethod csiFingerIdScoringMethod;
    private Class<? extends FingerblastScoringMethod> scoringOfInput;

    /**
     * @param trainedSVMs              trained SVMs keyed by model id (see class documentation)
     * @param covarianceScoringMethod  scoring method used to (re-)compute covariance scores
     * @param csiFingerIdScoringMethod scoring method used to (re-)compute CSI:FingerID scores; its
     *                                 prediction performances are passed to the feature creators
     * @param scoringOfInput           scoring method the incoming candidate lists were already scored
     *                                 with; lists are only re-scored for the other method(s)
     */
    public CSICovarianceConfidenceScorer(@NotNull Map<String, TrainedSVM> trainedSVMs, @NotNull BayesnetScoring covarianceScoringMethod, @NotNull ScoringMethodFactory.CSIFingerIdScoringMethod csiFingerIdScoringMethod, Class<? extends FingerblastScoringMethod> scoringOfInput) {
        this.trainedSVMs = trainedSVMs;
        this.covarianceScoringMethod = covarianceScoringMethod;
        this.csiFingerIdScoringMethod = csiFingerIdScoringMethod;
        this.scoringOfInput = scoringOfInput;
    }

    /** @return the scoring method the incoming candidate lists are assumed to be scored with */
    public Class<? extends FingerblastScoringMethod> getScoringOfInput() {
        return this.scoringOfInput;
    }

    /**
     * Sets the scoring method the incoming candidate lists are assumed to be scored with.
     * <p>
     * Accepts the same type the getter returns, so {@code setScoringOfInput(getScoringOfInput())} compiles.
     */
    public void setScoringOfInput(Class<? extends FingerblastScoringMethod> scoringOfInput) {
        this.scoringOfInput = scoringOfInput;
    }

    /**
     * Computes the confidence, treating as search-database hits those candidates whose database
     * bitset shares at least one bit with {@code dbFilterFlags}.
     */
    public double computeConfidence(Ms2Experiment ms2Experiment, IdentificationResult<?> identificationResult, List<Scored<FingerprintCandidate>> allDbCandidates, long dbFilterFlags, ProbabilityFingerprint query) {
        return computeConfidence(ms2Experiment, identificationResult, allDbCandidates, query,
                candidate -> (candidate.getBitset() & dbFilterFlags) != 0);
    }

    /**
     * Computes the confidence using {@link #getScoringOfInput()} as the scoring of the input list.
     *
     * @param searchDbFilter selects the search-database candidates among {@code allDbCandidates};
     *                       {@code null} selects the all-database SVM
     */
    @Override
    public double computeConfidence(@NotNull Ms2Experiment ms2Experiment, @NotNull IdentificationResult<?> identificationResult, @NotNull List<Scored<FingerprintCandidate>> allDbCandidates, @NotNull ProbabilityFingerprint query, @Nullable Predicate<FingerprintCandidate> searchDbFilter) {
        return computeConfidence(ms2Experiment, identificationResult, allDbCandidates, this.scoringOfInput, query, searchDbFilter);
    }

    /**
     * Computes the confidence from separate all-database and search-database candidate lists,
     * using {@link #getScoringOfInput()} as the scoring of both input lists.
     */
    @Override
    public double computeConfidence(@NotNull Ms2Experiment ms2Experiment, @NotNull IdentificationResult<?> identificationResult, @NotNull List<Scored<FingerprintCandidate>> allDbCandidates, @NotNull List<Scored<FingerprintCandidate>> searchDbCandidates, @NotNull ProbabilityFingerprint query) {
        return computeConfidence(ms2Experiment, identificationResult, allDbCandidates, searchDbCandidates, this.scoringOfInput, query);
    }

    /**
     * Computes the confidence from separate all-database and search-database candidate lists.
     * <p>
     * Lists are used as-is for the scoring they were produced with ({@code scoringOfInput}) and
     * re-scored for the other method.
     */
    public double computeConfidence(@NotNull Ms2Experiment ms2Experiment, @NotNull IdentificationResult<?> identificationResult, @NotNull List<Scored<FingerprintCandidate>> allDbCandidates, @NotNull List<Scored<FingerprintCandidate>> searchDbCandidates, @NotNull Class<? extends FingerblastScoringMethod> scoringOfInput, @NotNull ProbabilityFingerprint query) {
        final List<Scored<FingerprintCandidate>> allDbCsi;
        final List<Scored<FingerprintCandidate>> searchDbCsi;
        if (scoringOfInput == ScoringMethodFactory.CSIFingerIdScoringMethod.class) {
            allDbCsi = allDbCandidates;
            searchDbCsi = searchDbCandidates;
        } else {
            final CSIFingerIdScoring csiScoring = this.csiFingerIdScoringMethod.getScoring();
            allDbCsi = calculateCSIScores(allDbCandidates, csiScoring, query);
            searchDbCsi = calculateCSIScores(searchDbCandidates, csiScoring, query);
        }

        final BayesnetScoring.Scorer covarianceScorer = prepareCovarianceScorer(query);
        final List<Scored<FingerprintCandidate>> allDbCov;
        final List<Scored<FingerprintCandidate>> searchDbCov;
        if (scoringOfInput == BayesnetScoring.class) {
            allDbCov = allDbCandidates;
            searchDbCov = searchDbCandidates;
        } else {
            allDbCov = calculateCovarianceScores(allDbCandidates, covarianceScorer, query);
            searchDbCov = calculateCovarianceScores(searchDbCandidates, covarianceScorer, query);
        }

        return computeConfidence(ms2Experiment, identificationResult,
                toScoredArray(allDbCov), toScoredArray(allDbCsi),
                toScoredArray(searchDbCov), toScoredArray(searchDbCsi),
                query, covarianceScorer, this.csiFingerIdScoringMethod.getPerformances());
    }

    /**
     * Computes the confidence from one all-database candidate list; the search-database candidates
     * are the subset accepted by {@code searchDbFilter}.
     *
     * @param searchDbFilter selects the search-database candidates; {@code null} selects the
     *                       all-database SVM
     */
    public double computeConfidence(Ms2Experiment ms2Experiment, IdentificationResult<?> identificationResult, List<Scored<FingerprintCandidate>> allDbCandidates, Class<? extends FingerblastScoringMethod> scoringOfInput, ProbabilityFingerprint query, @Nullable Predicate<FingerprintCandidate> searchDbFilter) {
        final List<Scored<FingerprintCandidate>> allDbCsi = scoringOfInput == ScoringMethodFactory.CSIFingerIdScoringMethod.class
                ? allDbCandidates
                : calculateCSIScores(allDbCandidates, this.csiFingerIdScoringMethod.getScoring(), query);
        final BayesnetScoring.Scorer covarianceScorer = prepareCovarianceScorer(query);
        final List<Scored<FingerprintCandidate>> allDbCov = scoringOfInput == BayesnetScoring.class
                ? allDbCandidates
                : calculateCovarianceScores(allDbCandidates, covarianceScorer, query);

        // The array overload takes the search-db rankings in the order (covariance, CSI), the same
        // order as the all-db rankings; passing them the other way round mixes up the features.
        return computeConfidence(ms2Experiment, identificationResult,
                toScoredArray(allDbCov), toScoredArray(allDbCsi),
                searchDbFilter != null ? filterToScoredArray(allDbCov, searchDbFilter) : null,
                searchDbFilter != null ? filterToScoredArray(allDbCsi, searchDbFilter) : null,
                query, covarianceScorer, this.csiFingerIdScoringMethod.getPerformances());
    }

    /**
     * Computes the confidence from already scored and ranked candidate arrays.
     *
     * @param allDbCov         all-database candidates ranked by covariance score
     * @param allDbCsi         the same candidates ranked by CSI:FingerID score; must have the same
     *                         length as {@code allDbCov}
     * @param searchDbCov      search-database candidates ranked by covariance score, or {@code null}
     * @param searchDbCsi      search-database candidates ranked by CSI:FingerID score, or {@code null};
     *                         if either search-db array is {@code null} the all-database SVM is used
     * @param query            predicted fingerprint of the compound
     * @param covarianceScorer covariance scorer already prepared for {@code query}
     * @param performances     prediction performances passed to the feature creator
     * @return the SVM confidence, or {@link Double#NaN} if there are fewer than 5 all-database
     *         candidates or an empty (non-null) search-database array
     * @throws IllegalArgumentException if the all-database arrays differ in length, or no SVM is
     *                                  registered for the derived model id
     */
    public double computeConfidence(Ms2Experiment ms2Experiment, IdentificationResult<?> identificationResult, Scored<FingerprintCandidate>[] allDbCov, Scored<FingerprintCandidate>[] allDbCsi, @Nullable Scored<FingerprintCandidate>[] searchDbCov, @Nullable Scored<FingerprintCandidate>[] searchDbCsi, ProbabilityFingerprint query, BayesnetScoring.Scorer covarianceScorer, PredictionPerformance[] performances) {
        if (allDbCov.length != allDbCsi.length) {
            throw new IllegalArgumentException("Covariance scored candidate list has different length from fingerid scored candidates list!");
        }
        if (allDbCov.length < MIN_ALL_DB_CANDIDATES) {
            LoggerFactory.getLogger(getClass()).debug("Cannot calculate confidence with less than 5 hits in \"PubChem\" database! Returning NaN. Instance: {}-{}-{}", ms2Experiment.getName(), ms2Experiment.getMolecularFormula(), ms2Experiment.getPrecursorIonType());
            return Double.NaN;
        }
        if (searchDbCov != null && searchDbCov.length == 0) {
            LoggerFactory.getLogger(getClass()).debug("Cannot calculate confidence with NO hit in \"Search\" database! Returning NaN. Instance: {}-{}-{}", ms2Experiment.getName(), ms2Experiment.getMolecularFormula(), ms2Experiment.getPrecursorIonType());
            return Double.NaN;
        }

        final String ceId = makeCeString(ms2Experiment.getMs2Spectra());
        final CombinedFeatureCreator featureCreator;
        final String dbId;
        final String distanceId;
        if (searchDbCov == null || searchDbCsi == null) {
            featureCreator = new CombinedFeatureCreatorALL(allDbCsi, allDbCov, performances, covarianceScorer);
            dbId = DB_ALL_ID;
            distanceId = null;
        } else if (moreThanOneUniqueFPs(searchDbCov)) {
            featureCreator = new CombinedFeatureCreatorBIODISTANCE(allDbCsi, allDbCov, searchDbCsi, searchDbCov, performances, covarianceScorer);
            dbId = DB_BIO_ID;
            distanceId = DISTANCE_ID;
        } else {
            featureCreator = new CombinedFeatureCreatorBIONODISTANCE(allDbCsi, allDbCov, searchDbCsi, searchDbCov, performances, covarianceScorer);
            dbId = DB_BIO_ID;
            distanceId = NO_DISTANCE_ID;
        }

        // The parameter object is passed as-is; it is not a CombinedFeatureCreator.
        final double[] features = featureCreator.computeFeatures(new SuperParameters.Default(query, identificationResult));
        return calculateConfidence(features, dbId, distanceId, ceId);
    }

    /** Returns {@code true} if at least two candidates have different fingerprints (by set-bit indices). */
    private static boolean moreThanOneUniqueFPs(Scored<FingerprintCandidate>[] candidates) {
        if (candidates.length < 2) {
            return false;
        }
        final short[] firstIndices = candidates[0].getCandidate().getFingerprint().toIndizesArray();
        for (int i = 1; i < candidates.length; i++) {
            if (!Arrays.equals(firstIndices, candidates[i].getCandidate().getFingerprint().toIndizesArray())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Standardizes {@code features} with the selected SVM's scales and returns its predicted confidence.
     *
     * @param features   feature vector; standardized via {@code SVMUtils.standardize_features}
     * @param dbId       {@link #DB_ALL_ID} or {@link #DB_BIO_ID}
     * @param distanceId {@link #DISTANCE_ID}, {@link #NO_DISTANCE_ID} or {@code null}
     * @param ceId       collision-energy class from {@link #makeCeString(List)}
     * @throws IllegalArgumentException if no SVM is registered under the derived id
     */
    private double calculateConfidence(@NotNull double[] features, @NotNull String dbId, @Nullable String distanceId, @NotNull String ceId) {
        final String svmId = ceId + "_" + dbId + (distanceId != null ? distanceId : "") + ".svm";
        final TrainedSVM svm = this.trainedSVMs.get(svmId);
        if (svm == null) {
            throw new IllegalArgumentException("Could not find confidence svm with ID: \"" + svmId + "\"");
        }
        // The SVM utilities work on a batch of feature vectors; wrap the single vector.
        final double[][] batch = {features};
        SVMUtils.standardize_features(batch, svm.scales);
        return new SVMPredict().predict_confidence(batch, svm)[0];
    }

    /** Obtains the covariance scorer and prepares it for {@code query}. */
    private BayesnetScoring.Scorer prepareCovarianceScorer(ProbabilityFingerprint query) {
        final BayesnetScoring.Scorer scorer = this.covarianceScoringMethod.getScoring();
        scorer.prepare(() -> query);
        return scorer;
    }

    /**
     * Re-scores the candidates with CSI:FingerID scoring (preparing it for {@code query} first) and
     * sorts them with {@link Comparator#reverseOrder()} (best first, assuming Scored orders by score).
     */
    private static List<Scored<FingerprintCandidate>> calculateCSIScores(List<Scored<FingerprintCandidate>> candidates, CSIFingerIdScoring scoring, ProbabilityFingerprint query) {
        scoring.prepare(() -> query);
        return candidates.stream()
                .map(Scored::getCandidate)
                .map(candidate -> new Scored<FingerprintCandidate>(candidate, scoring.score(query, candidate.getFingerprint())))
                .sorted(Comparator.reverseOrder())
                .collect(Collectors.toList());
    }

    /**
     * Re-scores the candidates with the covariance scorer, which must already be prepared for
     * {@code query}, and sorts them with {@link Comparator#reverseOrder()}.
     */
    private static List<Scored<FingerprintCandidate>> calculateCovarianceScores(List<Scored<FingerprintCandidate>> candidates, BayesnetScoring.Scorer scorer, ProbabilityFingerprint query) {
        return candidates.stream()
                .map(Scored::getCandidate)
                .map(candidate -> new Scored<FingerprintCandidate>(candidate, scorer.score(query, candidate.getFingerprint())))
                .sorted(Comparator.reverseOrder())
                .collect(Collectors.toList());
    }

    /** Copies the list into an array (generic array creation is unchecked by nature). */
    @SuppressWarnings("unchecked")
    private static Scored<FingerprintCandidate>[] toScoredArray(List<Scored<FingerprintCandidate>> candidates) {
        return candidates.toArray(Scored[]::new);
    }

    /** Returns, in list order, the entries whose candidate is accepted by {@code filter}. */
    @SuppressWarnings("unchecked")
    private static Scored<FingerprintCandidate>[] filterToScoredArray(List<Scored<FingerprintCandidate>> candidates, Predicate<FingerprintCandidate> filter) {
        return candidates.stream()
                .filter(scored -> filter.test(scored.getCandidate()))
                .toArray(Scored[]::new);
    }

    /**
     * Classifies the collision energies of the given spectra.
     * <p>
     * The result is {@link #CE_RAMP} if any spectrum has no collision energy, or if the energies
     * are not all one single value (within or across spectra). Otherwise that value is mapped to
     * {@link #CE_LOW} (at most 15), {@link #CE_MED} (below 30) or {@link #CE_HIGH}.
     * NOTE: an empty list yields {@link #CE_HIGH}.
     */
    public static String makeCeString(@NotNull List<Ms2Spectrum<Peak>> spectra) {
        double minEnergy = Double.MAX_VALUE;
        // Must start below every real energy. Double.MIN_VALUE is the smallest *positive* double,
        // so with it a fixed energy of 0 would differ from the running minimum and be reported as a ramp.
        double maxEnergy = Double.NEGATIVE_INFINITY;
        for (Ms2Spectrum<Peak> spectrum : spectra) {
            final CollisionEnergy energy = spectrum.getCollisionEnergy();
            if (energy == null) {
                return CE_RAMP;
            }
            maxEnergy = Math.max(maxEnergy, energy.getMaxEnergy());
            minEnergy = Math.min(minEnergy, energy.getMinEnergy());
            if (minEnergy != maxEnergy) {
                return CE_RAMP;
            }
        }
        if (minEnergy <= 15.0d) {
            return CE_LOW;
        }
        if (minEnergy < 30.0d) {
            return CE_MED;
        }
        return CE_HIGH;
    }
}
