package de.unijena.bioinf.lcms.adducts;

import de.unijena.bioinf.ChemistryBase.math.Statistics;
import de.unijena.bioinf.ChemistryBase.ms.Deviation;
import de.unijena.bioinf.ChemistryBase.ms.Normalization;
import de.unijena.bioinf.ChemistryBase.ms.Spectrum;
import de.unijena.bioinf.ChemistryBase.ms.utils.SimpleSpectrum;
import de.unijena.bioinf.ChemistryBase.ms.utils.Spectrums;
import de.unijena.bioinf.ms.persistence.model.core.feature.AbstractAlignedFeatures;
import de.unijena.bioinf.ms.persistence.model.core.feature.AlignedFeatures;
import de.unijena.bioinf.ms.persistence.model.core.spectrum.MergedMSnSpectrum;
import de.unijena.bioinf.ms.persistence.model.core.trace.AbstractTrace;
import de.unijena.bioinf.ms.persistence.model.core.trace.MergedTrace;
import de.unijena.bioinf.ms.persistence.model.core.trace.SourceTrace;
import de.unijena.bioinf.ms.persistence.model.core.trace.TraceRef;
import de.unijena.bionf.spectral_alignment.ModifiedCosine;
import it.unimi.dsi.fastutil.Pair;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import java.util.List;
import java.util.Optional;

/* loaded from: input_file:de/unijena/bioinf/lcms/adducts/Scorer.class */
/**
 * Computes evidence scores for {@link AdductEdge}s connecting two aligned features.
 *
 * <p>Most results are written directly onto the edge ({@code ratioScore}, {@code correlationScore},
 * {@code representativeCorrelationScore}, {@code ms2score}) rather than returned.
 */
public class Scorer {
    /** Score returned by {@link #computeScore} for trace-less features with identical retention time. */
    private static final double DEFAULT_SCORE = 1.0d;
    public static final float SCORE_BONUS_FOR_SIMPLE_EDGES = 1.0f;

    /**
     * Scores an adduct edge between its left and right feature.
     *
     * <p>If both features carry a trace reference, correlation-based evidence is computed and stored
     * on the edge (see {@link #computeScoreFromCorrelation}). Otherwise the result only depends on
     * whether both features report the same retention time.
     *
     * @param traceProvider source of per-sample intensities and traces
     * @param edge          the edge to score; correlation results are written onto it
     * @return {@code 0.0} in the correlation case, {@code DEFAULT_SCORE} for equal retention times,
     *     {@link Double#NEGATIVE_INFINITY} otherwise
     */
    public double computeScore(ProjectSpaceTraceProvider traceProvider, AdductEdge edge) {
        AlignedFeatures left = edge.left.getFeature();
        AlignedFeatures right = edge.right.features;
        if (left.getTraceReference().isPresent() && right.getTraceReference().isPresent()) {
            return computeScoreFromCorrelation(traceProvider, edge, left, right);
        }
        return computeScoreWithoutCorrelation(left, right);
    }

    /**
     * Fallback score when no traces are available: {@code DEFAULT_SCORE} if both retention times
     * compare equal, {@link Double#NEGATIVE_INFINITY} otherwise.
     */
    private double computeScoreWithoutCorrelation(AlignedFeatures left, AlignedFeatures right) {
        if (left.getRetentionTime().compareTo(right.getRetentionTime()) == 0) {
            return DEFAULT_SCORE;
        }
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Computes correlation evidence for an edge and stores it on the edge.
     *
     * <p>Writes {@code ratioScore} (cross-sample intensity agreement, see
     * {@link #correlateAcrossSamples}) and {@code representativeCorrelationScore} (see
     * {@link #correlateRepresentatives}). It also writes {@code correlationScore}, the merged-trace
     * correlation, but only when both merged traces and both trace references exist.
     *
     * @return always {@code 0.0}; the actual evidence lives on the edge
     */
    public double computeScoreFromCorrelation(ProjectSpaceTraceProvider traceProvider, AdductEdge edge, AlignedFeatures left, AlignedFeatures right) {
        Long2DoubleMap leftIntensities = traceProvider.getIntensities(left);
        Long2DoubleMap rightIntensities = traceProvider.getIntensities(right);
        edge.ratioScore = (float) correlateAcrossSamples(leftIntensities, rightIntensities);

        Optional<MergedTrace> leftTrace = traceProvider.getMergeTrace(left);
        Optional<MergedTrace> rightTrace = traceProvider.getMergeTrace(right);
        // This method is public, so callers may skip the presence check done in computeScore.
        // Check the trace references too instead of calling get() on a possibly empty Optional.
        boolean tracesAvailable = leftTrace.isPresent() && rightTrace.isPresent()
                && left.getTraceReference().isPresent() && right.getTraceReference().isPresent();
        if (tracesAvailable) {
            edge.correlationScore = (float) correlateTraces(
                    (TraceRef) left.getTraceReference().get(), leftTrace.get(),
                    (TraceRef) right.getTraceReference().get(), rightTrace.get());
        }
        edge.representativeCorrelationScore = correlateRepresentatives(traceProvider, left, right, leftIntensities, rightIntensities);
        return 0.0d;
    }

    /**
     * Correlates the source traces of both features in the single key (presumably a sample/run id —
     * confirm) where the geometric mean of both intensities is largest.
     *
     * @return the Pearson correlation of the two source traces, or {@link Float#NaN} if no shared
     *     key has a positive geometric mean or either source trace is unavailable
     */
    public static float correlateRepresentatives(TraceProvider traceProvider, AbstractAlignedFeatures left, AbstractAlignedFeatures right, Long2DoubleMap leftIntensities, Long2DoubleMap rightIntensities) {
        LongOpenHashSet sharedKeys = new LongOpenHashSet(leftIntensities.keySet());
        sharedKeys.retainAll(rightIntensities.keySet());

        double bestGeometricMean = 0.0d;
        long bestKey = -1;
        for (LongIterator it = sharedKeys.iterator(); it.hasNext(); ) {
            long key = it.nextLong(); // primitive access, avoids boxing through next()
            double geometricMean = Math.sqrt(leftIntensities.get(key) * rightIntensities.get(key));
            if (geometricMean > bestGeometricMean) {
                bestKey = key;
                bestGeometricMean = geometricMean;
            }
        }
        if (bestGeometricMean <= 0.0d) {
            return Float.NaN;
        }

        Optional<Pair<TraceRef, SourceTrace>> leftSource = traceProvider.getSourceTrace(left, bestKey);
        Optional<Pair<TraceRef, SourceTrace>> rightSource = traceProvider.getSourceTrace(right, bestKey);
        if (leftSource.isPresent() && rightSource.isPresent()) {
            return (float) correlateTraces(
                    (TraceRef) leftSource.get().left(), (AbstractTrace) leftSource.get().right(),
                    (TraceRef) rightSource.get().left(), (AbstractTrace) rightSource.get().right());
        }
        return Float.NaN;
    }

    /**
     * Scores how well two intensity profiles agree across keys (presumably samples).
     *
     * <p>Both maps are aligned over the union of their keys; a missing key counts as 0. Each profile
     * is normalised to sum 1. Every position whose combined share is at least 0.01 contributes a
     * log-ratio of two Gaussian-shaped terms.
     *
     * @return {@link Double#NEGATIVE_INFINITY} if both maps are empty, or if at least half of the
     *     (normalised, max-per-position) intensity lies in positions where one side is zero.
     *     {@code 0.0} if there are at most two keys. Otherwise the summed log-ratio.
     */
    public static double correlateAcrossSamples(Long2DoubleMap left, Long2DoubleMap right) {
        DoubleArrayList leftValues = new DoubleArrayList();
        DoubleArrayList rightValues = new DoubleArrayList();
        // Aligned vectors over the union of keys; keys missing from one map contribute 0.
        left.keySet().forEach(key -> {
            leftValues.add(left.getOrDefault(key, 0.0d));
            rightValues.add(right.getOrDefault(key, 0.0d));
        });
        right.keySet().forEach(key -> {
            if (!left.containsKey(key)) {
                leftValues.add(0.0d);
                rightValues.add(right.getOrDefault(key, 0.0d));
            }
        });

        if (leftValues.size() <= 0 || rightValues.size() <= 0) {
            return Double.NEGATIVE_INFINITY;
        }
        if (leftValues.size() <= 2 || rightValues.size() <= 2) {
            return 0.0d;
        }

        // NOTE(review): a zero sum yields NaN shares, which then fail the >= 0.01 test below.
        // Such a profile therefore scores 0.0. Confirm this is the intended outcome.
        double leftSum = leftValues.doubleStream().sum();
        double rightSum = rightValues.doubleStream().sum();
        for (int i = 0; i < leftValues.size(); i++) {
            leftValues.set(i, leftValues.getDouble(i) / leftSum);
            rightValues.set(i, rightValues.getDouble(i) / rightSum);
        }

        double logRatioSum = 0.0d;
        double unmatchedShare = 0.0d;
        for (int i = 0; i < leftValues.size(); i++) {
            double a = leftValues.getDouble(i);
            double b = rightValues.getDouble(i);
            if (a + b < 0.01d) {
                continue; // ignore positions carrying negligible intensity on both sides
            }
            double max = Math.max(a, b);
            if (a <= 0.0d || b <= 0.0d) {
                unmatchedShare += max;
            }
            // Variance = 0.05^2 (exact double value below) + (0.1 * max)^2, same evaluation order as before.
            double variance = 0.0025000000000000005d + (((max * max) * 0.1d) * 0.1d);
            double diff = a - b;
            // NOTE(review): the observed term is normalised with 0.05 and the reference term with 0.1.
            // This adds a constant log(2) per position. Confirm the asymmetry is intentional.
            double observed = Math.exp((-(diff * diff)) / (2.0d * variance)) / (((6.283185307179586d * max) * 0.1d) * 0.05d);
            double tolerance = (max * 2.0d * 0.1d) + 0.1d;
            double reference = Math.exp((-(tolerance * tolerance)) / (2.0d * variance)) / (((6.283185307179586d * max) * 0.1d) * 0.1d);
            logRatioSum += Math.log(observed) - Math.log(reference);
        }
        if (unmatchedShare >= 0.5d) {
            return Double.NEGATIVE_INFINITY;
        }
        return logRatioSum;
    }

    /**
     * Pearson-correlates two traces over the scan window of the larger one.
     *
     * <p>The window starts as {@code [start, end]} of {@code largerRef}. It is narrowed to the span
     * around the apex whose intensity exceeds half the apex intensity, provided that span covers at
     * least 3 indices. Scan indices outside either trace contribute 0.
     *
     * @return the Pearson correlation, or {@link Double#NaN} if the smaller trace's apex falls
     *     outside the (possibly narrowed) window
     */
    private static double correlateLargerWithSmaller(TraceRef largerRef, AbstractTrace largerTrace, TraceRef smallerRef, AbstractTrace smallerTrace) {
        int start = largerRef.getStart();
        int end = largerRef.getEnd();
        int apex = largerRef.getApex();

        double halfApexIntensity = largerTrace.getIntensities().getFloat(apex) * 0.5d;
        int lower = apex;
        while (lower > start && largerTrace.getIntensities().getFloat(lower) > halfApexIntensity) {
            lower--;
        }
        int upper = apex;
        while (upper < end && largerTrace.getIntensities().getFloat(upper) > halfApexIntensity) {
            upper++;
        }
        if (upper - lower >= 3) {
            start = lower;
            end = upper;
        }

        int capacity = Math.max((end - start) + 1, (smallerRef.getEnd() - smallerRef.getStart()) + 1);
        DoubleArrayList largerValues = new DoubleArrayList(capacity);
        DoubleArrayList smallerValues = new DoubleArrayList(capacity);

        int largerOffset = largerRef.getScanIndexOffsetOfTrace();
        int smallerOffset = smallerRef.getScanIndexOffsetOfTrace();
        int firstScan = start + largerOffset;
        int lastScan = end + largerOffset;
        int smallerApexScan = smallerRef.getApex() + smallerOffset;
        if (smallerApexScan < firstScan || smallerApexScan > lastScan) {
            return Double.NaN;
        }
        for (int scan = firstScan; scan <= lastScan; scan++) {
            largerValues.add(intensityOrZero(largerTrace, scan - largerOffset));
            smallerValues.add(intensityOrZero(smallerTrace, scan - smallerOffset));
        }
        return Statistics.pearson(largerValues.toDoubleArray(), smallerValues.toDoubleArray());
    }

    /** Returns the trace intensity at {@code index}, or 0 if the index lies outside the trace. */
    private static double intensityOrZero(AbstractTrace trace, int index) {
        if (index < 0 || index >= trace.getIntensities().size()) {
            return 0.0d;
        }
        return trace.getIntensities().getFloat(index);
    }

    /**
     * Correlates two traces, using the one with the higher apex intensity as the reference window.
     * On ties the second trace is used as the reference.
     */
    public static double correlateTraces(TraceRef traceRef, AbstractTrace abstractTrace, TraceRef traceRef2, AbstractTrace abstractTrace2) {
        if (abstractTrace.getIntensities().getFloat(traceRef.getApex()) > abstractTrace2.getIntensities().getFloat(traceRef2.getApex())) {
            return correlateLargerWithSmaller(traceRef, abstractTrace, traceRef2, abstractTrace2);
        }
        return correlateLargerWithSmaller(traceRef2, abstractTrace2, traceRef, abstractTrace);
    }

    /**
     * Merges the given MS/MS spectra into one spectrum for cosine scoring.
     *
     * <p>Peaks are merged within 10 ppm. All peaks at or above {@code node.getMass() - 8} are
     * dropped, and the result is L2-normalised. The behaviour for an empty {@code spectra} list
     * depends on {@code Spectrums.mergeSpectra} and is not handled here.
     */
    public SimpleSpectrum prepareForCosine(AdductNode node, List<MergedMSnSpectrum> spectra) {
        SimpleSpectrum[] peakLists = spectra.stream()
                .map(MergedMSnSpectrum::getPeaks)
                .toArray(SimpleSpectrum[]::new);
        Spectrum merged = Spectrums.mergePeaksWithinSpectrum(Spectrums.mergeSpectra(peakLists), new Deviation(10.0d), true, false);
        int cutoff = Spectrums.getFirstPeakGreaterOrEqualThan(merged, node.getMass() - 8.0d);
        if (cutoff < merged.size()) {
            merged = Spectrums.subspectrum(merged, 0, cutoff);
        }
        return Spectrums.getNormalizedSpectrum(merged, Normalization.L2());
    }

    /**
     * Returns {@code true} if the spectrum meets two conditions. First, at least 3 peaks have at
     * least 5% of the maximal intensity. Second, the maximal intensity is at least 10 times the
     * minimal intensity.
     */
    public boolean hasMinimumMs2Quality(SimpleSpectrum simpleSpectrum) {
        double maxIntensity = Spectrums.getMaximalIntensity(simpleSpectrum);
        double minIntensity = maxIntensity;
        int intensePeaks = 0;
        for (int i = 0; i < simpleSpectrum.size(); i++) {
            double intensity = simpleSpectrum.getIntensityAt(i);
            if (intensity / maxIntensity >= 0.05d) {
                intensePeaks++;
            }
            minIntensity = Math.min(minIntensity, intensity);
        }
        return intensePeaks >= 3 && maxIntensity / minIntensity >= 10.0d;
    }

    /**
     * Stores the modified-cosine similarity (10 ppm tolerance) between the two spectra, using the
     * edge's left/right masses as precursor masses, in {@code edge.ms2score}.
     */
    public void computeMs2Score(AdductEdge adductEdge, SimpleSpectrum simpleSpectrum, SimpleSpectrum simpleSpectrum2) {
        // The trailing 1.0 is a plain ModifiedCosine.score parameter, not DEFAULT_SCORE.
        // NOTE(review): confirm its meaning against the ModifiedCosine.score signature.
        adductEdge.ms2score = (float) new ModifiedCosine(new Deviation(10.0d))
                .score(simpleSpectrum, simpleSpectrum2, adductEdge.getLeft().getMass(), adductEdge.getRight().getMass(), 1.0d)
                .similarity;
    }
}
