package de.unijena.bioinf.fingerid.blast;

import de.unijena.bioinf.ChemistryBase.fp.FPIter;
import de.unijena.bioinf.ChemistryBase.fp.FPIter2;
import de.unijena.bioinf.ChemistryBase.fp.Fingerprint;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.ChemistryBase.fp.ProbabilityFingerprint;
import de.unijena.bioinf.fingerid.blast.parameters.Parameters;

/* loaded from: input_file:de/unijena/bioinf/fingerid/blast/CSIFingerIdScoring.class */
/**
 * Scores a predicted {@link ProbabilityFingerprint} against a candidate {@link Fingerprint} by
 * summing per-bit log-score contributions.
 *
 * <p>Per-bit contribution tables are filled by {@link #prepare(Parameters.FP)} and then read by
 * {@link #score(ProbabilityFingerprint, Fingerprint)}. Only bits whose prediction performance
 * has an F value of at least {@link #getThreshold()} and a smaller-class size of at least
 * {@link #getMinSamples()} contribute to the score.
 *
 * <p>{@code prepare} overwrites per-instance arrays. An instance should therefore not be
 * prepared for one query while another thread is scoring with it.
 */
public class CSIFingerIdScoring implements FingerblastScoring<Parameters.FP> {
    /** Pseudocount applied to every incoming {@link PredictionPerformance}. */
    private static final double PSEUDO_COUNT = 0.25d;
    /** Weight of the (smoothed) predicted probability term in each per-bit contribution. */
    private static final double PREDICTION_WEIGHT = 0.75d;
    /** Weight of the recall/specificity-derived term in the tp/tn contributions. */
    private static final double PERFORMANCE_WEIGHT = 0.25d;

    private final PredictionPerformance[] performances;
    private final double[] tp;
    private final double[] fp;
    private final double[] tn;
    private final double[] fn;
    private final double[] logOneMinusRecall;
    private final double[] logOneMinusSpecificity;
    /** Additive smoothing constant used by {@link #laplaceSmoothing(double)}. */
    private final double alpha;
    private double threshold = 0.25d;
    private double minSamples = 25.0d;

    /**
     * Creates a scoring from per-bit prediction performances.
     *
     * <p>Each performance is stored with a pseudocount of {@value #PSEUDO_COUNT} applied. The
     * smoothing constant {@code alpha} is the reciprocal of the first (pseudocounted)
     * performance's {@code numberOfSamplesWithPseudocounts()}.
     *
     * @param predictionPerformanceArr one performance per fingerprint bit, in bit order
     * @throws IllegalArgumentException if {@code predictionPerformanceArr} is empty
     * @throws NullPointerException if {@code predictionPerformanceArr} is {@code null}
     */
    public CSIFingerIdScoring(PredictionPerformance[] predictionPerformanceArr) {
        final int n = predictionPerformanceArr.length;
        if (n == 0) {
            throw new IllegalArgumentException("at least one PredictionPerformance is required");
        }
        this.performances = new PredictionPerformance[n];
        this.tp = new double[n];
        this.fp = new double[n];
        this.tn = new double[n];
        this.fn = new double[n];
        this.logOneMinusRecall = new double[n];
        this.logOneMinusSpecificity = new double[n];
        for (int i = 0; i < n; i++) {
            this.performances[i] = predictionPerformanceArr[i].withPseudoCount(PSEUDO_COUNT);
            this.logOneMinusRecall[i] = Math.log(1.0d - this.performances[i].getRecall());
            this.logOneMinusSpecificity[i] = Math.log(1.0d - this.performances[i].getSpecitivity());
        }
        // Same value the original computed from a second withPseudoCount(0.25) call on element 0.
        this.alpha = 1.0d / this.performances[0].numberOfSamplesWithPseudocounts();
    }

    /**
     * Returns the internal, pseudocount-adjusted performance array (not a copy, and not the
     * array passed to the constructor).
     *
     * @return the per-bit performances used for filtering and scoring
     */
    public PredictionPerformance[] getPerfomances() {
        return this.performances;
    }

    /** Returns the minimum F value a bit's performance needs in order to contribute to a score. */
    @Override // de.unijena.bioinf.fingerid.blast.FingerblastScoring
    public double getThreshold() {
        return this.threshold;
    }

    /** Sets the minimum F value a bit's performance needs in order to contribute to a score. */
    @Override // de.unijena.bioinf.fingerid.blast.FingerblastScoring
    public void setThreshold(double d) {
        this.threshold = d;
    }

    /** Returns the minimum smaller-class size a bit's performance needs in order to contribute. */
    @Override // de.unijena.bioinf.fingerid.blast.FingerblastScoring
    public double getMinSamples() {
        return this.minSamples;
    }

    /** Sets the minimum smaller-class size a bit's performance needs in order to contribute. */
    @Override // de.unijena.bioinf.fingerid.blast.FingerblastScoring
    public void setMinSamples(double d) {
        this.minSamples = d;
    }

    /**
     * Precomputes the four per-bit contributions (tp/fp/tn/fn) for the given query fingerprint.
     *
     * <p>Each predicted probability is smoothed with {@link #laplaceSmoothing(double)} before
     * taking logarithms. The i-th bit produced by the iterator is written to index i of the
     * tables, so the fingerprint must not have more bits than there are performances.
     *
     * @param fp the query whose predicted fingerprint is scored against candidates afterwards
     */
    @Override // de.unijena.bioinf.fingerid.blast.FingerblastScoring
    public void prepare(Parameters.FP fp) {
        int i = 0;
        for (FPIter it = fp.getFP().iterator(); it.hasNext(); i++) {
            final double p = laplaceSmoothing(((FPIter) it.next()).getProbability());
            final double logP = Math.log(p);
            final double logOneMinusP = Math.log(1.0d - p);
            this.tp[i] = (PREDICTION_WEIGHT * logP) + (PERFORMANCE_WEIGHT * this.logOneMinusRecall[i]);
            this.fp[i] = PREDICTION_WEIGHT * logOneMinusP;
            this.tn[i] = (PREDICTION_WEIGHT * logOneMinusP) + (PERFORMANCE_WEIGHT * this.logOneMinusSpecificity[i]);
            this.fn[i] = PREDICTION_WEIGHT * logP;
        }
    }

    /**
     * Additive smoothing {@code (d + alpha) / (1 + 2 * alpha)}.
     *
     * <p>For {@code alpha > 0} and {@code d} in [0, 1], this keeps the result strictly inside
     * (0, 1), so the logarithms in {@code prepare} stay finite.
     */
    private double laplaceSmoothing(double d) {
        return (d + this.alpha) / (1.0d + (2.0d * this.alpha));
    }

    /**
     * Sums the precomputed per-bit contributions for {@code fingerprint}, skipping bits whose
     * performance does not pass the threshold / minimum-sample filter.
     *
     * <p>Uses the tables from the most recent {@link #prepare(Parameters.FP)} call.
     *
     * @param probabilityFingerprint the predicted fingerprint (left side of each pair)
     * @param fingerprint the candidate fingerprint (right side of each pair)
     * @return the accumulated log score
     * @throws IllegalArgumentException if the two fingerprints are not compatible
     */
    @Override // de.unijena.bioinf.fingerid.blast.FingerblastScoring
    public double score(ProbabilityFingerprint probabilityFingerprint, Fingerprint fingerprint) {
        if (!probabilityFingerprint.isCompatible(fingerprint)) {
            throw new IllegalArgumentException("Fingerprints are not compatible");
        }
        double total = 0.0d;
        int k = -1;
        for (FPIter2 pair : probabilityFingerprint.foreachPair(fingerprint)) {
            ++k;
            if (!isUsable(k)) {
                continue;
            }
            // Right = candidate bit, left = predicted bit (as returned by foreachPair).
            final boolean candidateSet = pair.isRightSet();
            final boolean predictedSet = pair.isLeftSet();
            if (candidateSet) {
                total += predictedSet ? this.tp[k] : this.fn[k];
            } else {
                total += predictedSet ? this.fp[k] : this.tn[k];
            }
        }
        return total;
    }

    /** Returns whether bit {@code k} passes the F-threshold and minimum-sample filters. */
    private boolean isUsable(int k) {
        return this.performances[k].getF() >= this.threshold
                && this.performances[k].getSmallerClassSize() >= this.minSamples;
    }
}
