package de.unijena.bioinf.fragmenter;

import com.fasterxml.jackson.databind.JsonNode;
import de.unijena.bioinf.ChemistryBase.data.JacksonDocument;
import gnu.trove.map.hash.TObjectDoubleHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.special.Gamma;
import org.openscience.cdk.interfaces.IBond;

/* loaded from: input_file:de/unijena/bioinf/fragmenter/ScoringParameterEstimator.class */
public class ScoringParameterEstimator {
    /** Directory whose entries are read as subtree files by {@code estimateParameters()}. */
    private final File subtreeDir;
    /** Names returned by {@code subtreeDir.list()} at construction time; one task is created per name. */
    private final String[] subtreeFileNames;
    /** Directory into which all result files are written. */
    private final File outputDir;
    /** Probability passed to the inverse CDF of the fitted gamma distribution when the pseudo fragment score is computed. */
    private final double peakExplanationPercentile;
    /** Threshold on the difference of exponentiated scores used in {@code postprocessBondScores}. */
    private final double bondScoreSignificanceValue;
    /** Final bond-name to score mapping; (re)built by {@code postprocessBondScores}. */
    private TObjectDoubleHashMap<String> directedBondName2Score;
    /** Result of {@code calculateWildcardScore}; set by {@code estimateParameters()}. */
    private double wildcardScore;
    /** Result of {@code estimateHydrogenRearrangementProbability}; set by {@code estimateParameters()}. */
    private double hydrogenRearrangementProb;
    /** Result of {@code calculatePseudoFragmentScore}; set by {@code estimateParameters()}. */
    private double pseudoFragmentScore;
    /** Shape of the gamma distribution fitted in {@code calculatePseudoFragmentScore}. */
    private double gammaShapeParameter;
    /** Scale of the gamma distribution fitted in {@code calculatePseudoFragmentScore}. */
    private double gammaScaleParameter;

    /**
     * Plain holder for the values extracted from a single subtree file.
     *
     * <p>The constructor stores the given references as-is; no copies are made.
     */
    public class ExtractedData {
        /** Hydrogen rearrangement values read for one instance. */
        protected final Collection<Integer> hydrogenRearrangements;
        /** Penalty values collected for one instance. */
        protected final Collection<Float> penalties;
        /** Per-instance break probability, keyed by directed bond name. */
        protected final TObjectDoubleHashMap<String> directedBondName2BreakProb;
        /** Per-instance cut direction probability, keyed by directed bond name. */
        protected final TObjectDoubleHashMap<String> directedBondName2cutDirProb;

        /**
         * @param collection            stored as {@link #hydrogenRearrangements}
         * @param collection2           stored as {@link #penalties}
         * @param tObjectDoubleHashMap  stored as {@link #directedBondName2BreakProb}
         * @param tObjectDoubleHashMap2 stored as {@link #directedBondName2cutDirProb}
         */
        public ExtractedData(Collection<Integer> collection, Collection<Float> collection2, TObjectDoubleHashMap<String> tObjectDoubleHashMap, TObjectDoubleHashMap<String> tObjectDoubleHashMap2) {
            this.directedBondName2cutDirProb = tObjectDoubleHashMap2;
            this.directedBondName2BreakProb = tObjectDoubleHashMap;
            this.penalties = collection2;
            this.hydrogenRearrangements = collection;
        }
    }

    /**
     * Creates an estimator that reads subtree files from one directory and writes its
     * results into another.
     *
     * @param file  directory containing the subtree files; its entry names are listed once, here
     * @param file2 existing directory that receives all output files
     * @param d     peak explanation percentile, must lie in {@code [0, 1]}
     * @param d2    bond score significance value, must lie in {@code [0, 1]}
     * @throws IllegalArgumentException if one of the files is not a directory or one of the
     *                                  numeric parameters is not a probability
     * @throws IllegalStateException    if the content of {@code file} cannot be listed
     */
    public ScoringParameterEstimator(File file, File file2, double d, double d2) {
        if (!file.isDirectory() || !file2.isDirectory()) {
            throw new IllegalArgumentException("The abstract path name denoted by the given File object does not exists or is not a directory.");
        }
        if (!isProbability(d) || !isProbability(d2)) {
            throw new IllegalArgumentException("The given parameters aren't probabilities.");
        }
        // File.list() returns null when an I/O error occurs, even for an existing directory.
        // Failing here avoids a NullPointerException much later in estimateParameters().
        String[] fileNames = file.list();
        if (fileNames == null) {
            throw new IllegalStateException("The content of the directory '" + file + "' could not be listed.");
        }
        this.subtreeDir = file;
        this.outputDir = file2;
        this.subtreeFileNames = fileNames;
        this.peakExplanationPercentile = d;
        this.bondScoreSignificanceValue = d2;
    }

    /**
     * Tells whether the given value lies in the closed interval {@code [0, 1]}.
     *
     * @param d the value to check
     * @return {@code true} iff {@code 0 <= d <= 1}; {@code NaN} yields {@code false}
     */
    private boolean isProbability(double d) {
        return 0.0d <= d && d <= 1.0d;
    }

    /**
     * Estimates the hydrogen rearrangement probability from the observed values.
     *
     * <p>With {@code n} observations summing to {@code s}, the estimate is {@code s / (s + n)}.
     * This is the maximum-likelihood estimate of {@code q} for the model
     * {@code P(X = k) = q^k * (1 - q)}.
     *
     * @param collection the observed hydrogen rearrangement values
     * @return {@code s / (s + n)} as a floating-point value, or {@code NaN} if there are no
     *         observations
     */
    private double estimateHydrogenRearrangementProbability(Collection<Integer> collection) {
        int size = collection.size();
        if (size == 0) {
            // 0 / 0: there is nothing to estimate from.
            return Double.NaN;
        }
        long sum = 0;
        for (Integer observation : collection) {
            sum += observation.intValue();
        }
        // The cast is essential: long / long would truncate every estimate to 0.
        return (double) sum / (sum + size);
    }

    /**
     * Fits a gamma distribution to the negated, non-zero penalties and derives the pseudo
     * fragment score from it.
     *
     * <p>Side effects: stores the fitted shape and scale in {@code gammaShapeParameter} and
     * {@code gammaScaleParameter}. Reads {@code peakExplanationPercentile} and
     * {@code hydrogenRearrangementProb}, so the latter must already be set.
     *
     * @param collection the observed penalties; zero entries are ignored
     * @return the {@code peakExplanationPercentile} quantile of the fitted distribution minus
     *         {@code p / (1 - p) * log(p)}, where {@code p} is {@code hydrogenRearrangementProb}
     * @throws IllegalArgumentException if there is no non-zero penalty to fit the distribution to
     */
    private double calculatePseudoFragmentScore(Collection<Float> collection) {
        // NOTE(review): the logarithm below requires positive samples, i.e. this assumes all
        // non-zero penalties are negative - confirm against the producer of these values.
        double[] samples = collection.stream()
                .filter(penalty -> penalty.floatValue() != 0.0f)
                .mapToDouble(penalty -> -penalty.floatValue())
                .toArray();
        if (samples.length == 0) {
            // Without samples every statistic below would be NaN and silently end up in the model.
            throw new IllegalArgumentException("Cannot fit a gamma distribution: no non-zero fragment penalties were observed.");
        }

        double sum = 0.0d;
        double logSum = 0.0d;
        for (double sample : samples) {
            sum += sample;
            logSum += Math.log(sample);
        }
        double mean = sum / samples.length;
        double meanLog = logSum / samples.length;
        double logMean = Math.log(mean);

        // Iterative estimation of the shape parameter: start from 0.5 / (log(mean) - mean(log))
        // and update the reciprocal of the shape until two successive values differ by < 1e-4.
        // NOTE(review): this appears to be Minka's generalized Newton update - confirm.
        double shape = 0.5d / (logMean - meanLog);
        double previousShape;
        do {
            double inverseShape = (1.0d / shape) + ((((meanLog - logMean) + Math.log(shape)) - Gamma.digamma(shape)) / (Math.pow(shape, 2.0d) * ((1.0d / shape) - Gamma.trigamma(shape))));
            previousShape = shape;
            shape = 1.0d / inverseShape;
        } while (Math.abs(previousShape - shape) >= 1.0E-4d);

        this.gammaShapeParameter = shape;
        // The scale is chosen so that shape * scale equals the sample mean.
        this.gammaScaleParameter = mean / this.gammaShapeParameter;

        double quantile = new GammaDistribution(this.gammaShapeParameter, this.gammaScaleParameter).inverseCumulativeProbability(this.peakExplanationPercentile);
        // NOTE(review): this term is NaN for hydrogenRearrangementProb == 0 or == 1.
        double rearrangementTerm = (this.hydrogenRearrangementProb / (1.0d - this.hydrogenRearrangementProb)) * Math.log(this.hydrogenRearrangementProb);
        return quantile - rearrangementTerm;
    }

    /**
     * Computes the wildcard score from the smallest value stored in the given map.
     *
     * @param tObjectDoubleHashMap map whose values are scanned for their minimum
     * @return {@code log(min) - log(2)}, where {@code min} is the smallest value of the map
     *         or {@code 0} if the map is empty
     */
    private double calculateWildcardScore(TObjectDoubleHashMap<String> tObjectDoubleHashMap) {
        double[] probabilities = tObjectDoubleHashMap.values();
        double smallest = 0.0d;
        if (probabilities.length > 0) {
            smallest = probabilities[0];
        }
        for (double probability : probabilities) {
            if (probability < smallest) {
                smallest = probability;
            }
        }
        double logOfTwo = Math.log(2.0d);
        return Math.log(smallest) - logOfTwo;
    }

    /**
     * Estimates, for one group of bonds, the break probability and the cut direction
     * probabilities within a single subtree and stores them in the given maps.
     *
     * <p>Two names are derived from the first bond of the list via
     * {@code DirectedBondTypeScoring.bondNameSpecific(bond, true/false)}; presumably these denote
     * the two directions of the bond type. Both names receive the same break probability: the
     * fraction of bonds in the list with at least one cut. The cut direction probabilities are
     * the relative cut counts per name. They are {@code 0.5} each if both names are equal or if
     * no bond of the list was cut.
     *
     * @param combinatorialSubtree  the subtree that provides the number of cuts per bond
     * @param arrayList             non-empty list of bonds; assumed to share one bond type
     * @param tObjectDoubleHashMap  receives the break probability for both names
     * @param tObjectDoubleHashMap2 receives the cut direction probability for both names
     */
    private void estimateProbabilitiesAndUpdate(CombinatorialSubtree combinatorialSubtree, ArrayList<IBond> arrayList, TObjectDoubleHashMap<String> tObjectDoubleHashMap, TObjectDoubleHashMap<String> tObjectDoubleHashMap2) {
        String directedName = DirectedBondTypeScoring.bondNameSpecific(arrayList.get(0), true);
        String reverseName = DirectedBondTypeScoring.bondNameSpecific(arrayList.get(0), false);

        int numBonds = arrayList.size();
        int numBrokenBonds = 0;
        int directedCuts = 0;
        int reverseCuts = 0;
        for (IBond bond : arrayList) {
            int[] numberOfCuts = combinatorialSubtree.getNumberOfCuts(bond);
            // A bond whose own name differs from the reference name contributes with
            // swapped indices.
            if (directedName.equals(DirectedBondTypeScoring.bondNameSpecific(bond, true))) {
                directedCuts += numberOfCuts[0];
                reverseCuts += numberOfCuts[1];
            } else {
                directedCuts += numberOfCuts[1];
                reverseCuts += numberOfCuts[0];
            }
            if (numberOfCuts[0] + numberOfCuts[1] > 0) {
                numBrokenBonds++;
            }
        }

        // Floating-point division: int / int would reduce every probability to 0 or 1.
        double breakProbability = (double) numBrokenBonds / numBonds;
        tObjectDoubleHashMap.put(directedName, breakProbability);
        tObjectDoubleHashMap.put(reverseName, breakProbability);

        double directedProbability = 0.5d;
        double reverseProbability = 0.5d;
        if (!directedName.equals(reverseName) && numBrokenBonds > 0) {
            int totalCuts = directedCuts + reverseCuts;
            directedProbability = (double) directedCuts / totalCuts;
            reverseProbability = (double) reverseCuts / totalCuts;
        }
        tObjectDoubleHashMap2.put(directedName, directedProbability);
        tObjectDoubleHashMap2.put(reverseName, reverseProbability);
    }

    /**
     * Builds {@code directedBondName2Score} from the unprocessed log scores.
     *
     * <p>All bond names are grouped by their generic name (see {@code getGenericBondName}) and
     * every generic name is assigned the average log score of its group. Afterwards, passes are
     * repeated until the score map stops growing. In each pass, a group member whose
     * exponentiated score exceeds the exponentiated group score by at least
     * {@code bondScoreSignificanceValue} is removed from its group and stored under its own
     * name. The group averages are recomputed at the end of every pass.
     *
     * @param tObjectDoubleHashMap unprocessed log score per bond name
     */
    private void postprocessBondScores(TObjectDoubleHashMap<String> tObjectDoubleHashMap) {
        HashMap<String, ArrayList<String>> groups = new HashMap<>();
        for (String specificName : tObjectDoubleHashMap.keySet()) {
            String genericName = getGenericBondName(specificName);
            ArrayList<String> members = groups.get(genericName);
            if (members == null) {
                members = new ArrayList<>();
                groups.put(genericName, members);
            }
            members.add(specificName);
        }

        this.directedBondName2Score = new TObjectDoubleHashMap<>();
        groups.forEach((genericName, members) ->
                this.directedBondName2Score.put(genericName, computeAverageLogScore(tObjectDoubleHashMap, members)));

        int sizeBeforePass = 0;
        while (this.directedBondName2Score.size() != sizeBeforePass) {
            sizeBeforePass = this.directedBondName2Score.size();
            for (String genericName : groups.keySet()) {
                // The group score is read once per group and not refreshed within the pass.
                double groupScore = this.directedBondName2Score.get(genericName);
                Iterator<String> members = groups.get(genericName).iterator();
                while (members.hasNext()) {
                    String specificName = members.next();
                    double specificScore = tObjectDoubleHashMap.get(specificName);
                    if (Math.exp(specificScore) - Math.exp(groupScore) >= this.bondScoreSignificanceValue) {
                        members.remove();
                        this.directedBondName2Score.put(specificName, specificScore);
                    }
                }
            }
            groups.forEach((genericName, members) ->
                    this.directedBondName2Score.put(genericName, computeAverageLogScore(tObjectDoubleHashMap, members)));
        }
    }

    /**
     * Averages the values stored in the map for the given names.
     *
     * @param tObjectDoubleHashMap map that is queried for every name
     * @param arrayList            the names whose values are averaged
     * @return the sum of the looked-up values divided by the number of names; {@code NaN} for
     *         an empty list
     */
    private double computeAverageLogScore(TObjectDoubleHashMap<String> tObjectDoubleHashMap, ArrayList<String> arrayList) {
        double total = 0.0d;
        for (String name : arrayList) {
            total += tObjectDoubleHashMap.get(name);
        }
        int count = arrayList.size();
        return total / count;
    }

    /**
     * Derives the generic name of a bond name.
     *
     * <p>The name is split at the characters {@code : - = # ?}. Of the first two parts only the
     * text before their first {@code '.'} is kept. The two shortened parts are joined by the
     * character that follows the first part in the original name.
     *
     * @param str the bond name; must contain one of the separator characters
     * @return the shortened first part, the separator character and the shortened second part
     */
    private String getGenericBondName(String str) {
        String[] parts = str.split("[:\\-=#?]");
        String firstPart = parts[0];
        // Evaluation order (first part, separator, second part) is kept so that malformed
        // names fail with the same exception as before.
        String firstGeneric = firstPart.split("\\.")[0];
        char separator = str.charAt(firstPart.length());
        String secondGeneric = parts[1].split("\\.")[0];
        return new StringBuilder(firstGeneric).append(separator).append(secondGeneric).toString();
    }

    /**
     * Writes all observations, the unprocessed bond scores and the final scoring model into
     * the output directory.
     *
     * <p>Files written: {@code H-rearrangement_observations.txt},
     * {@code observed_fragment_penalties.txt}, {@code unprocessed_bond_scores.csv} and
     * {@code scoring_model.txt}.
     *
     * @param collection            observed hydrogen rearrangement values
     * @param collection2           observed fragment penalties
     * @param tObjectDoubleHashMap  break probability per bond name
     * @param tObjectDoubleHashMap2 cut direction probability per bond name
     * @param tObjectDoubleHashMap3 log probability per bond name; its key set defines the CSV rows
     * @param tObjectIntHashMap     number of observations per bond name
     * @throws IOException if one of the files cannot be written
     */
    private void saveData(Collection<Integer> collection, Collection<Float> collection2, TObjectDoubleHashMap<String> tObjectDoubleHashMap, TObjectDoubleHashMap<String> tObjectDoubleHashMap2, TObjectDoubleHashMap<String> tObjectDoubleHashMap3, TObjectIntHashMap<String> tObjectIntHashMap) throws IOException {
        saveObservedValues(new File(this.outputDir, "H-rearrangement_observations.txt"), collection, false);
        saveObservedValues(new File(this.outputDir, "observed_fragment_penalties.txt"), collection2, true);

        File unprocessedScoresFile = new File(this.outputDir, "unprocessed_bond_scores.csv");
        try (BufferedWriter writer = Files.newBufferedWriter(unprocessedScoresFile.toPath())) {
            writer.write("directedBondTypeName,numObservations,breakProb,cutDirectionProb,logProb");
            for (String bondName : tObjectDoubleHashMap3.keySet()) {
                writer.newLine();
                int numObservations = tObjectIntHashMap.get(bondName);
                double breakProb = tObjectDoubleHashMap.get(bondName);
                double cutDirectionProb = tObjectDoubleHashMap2.get(bondName);
                double logProb = tObjectDoubleHashMap3.get(bondName);
                // One value per header column, in header order.
                writer.write(bondName + "," + numObservations + "," + breakProb + "," + cutDirectionProb + "," + logProb);
            }
        }

        // Flatten the final score map into parallel arrays, as expected by writeScoringToFile.
        File scoringModelFile = new File(this.outputDir, "scoring_model.txt");
        int numScores = this.directedBondName2Score.size();
        String[] bondNames = new String[numScores];
        double[] scores = new double[numScores];
        int index = 0;
        for (String bondName : this.directedBondName2Score.keySet()) {
            bondNames[index] = bondName;
            scores[index] = this.directedBondName2Score.get(bondName);
            index++;
        }
        DirectedBondTypeScoring.writeScoringToFile(scoringModelFile, bondNames, scores, this.wildcardScore, this.hydrogenRearrangementProb, this.pseudoFragmentScore);
    }

    /**
     * Writes a commented header describing the fitted distribution, followed by one observed
     * value per line.
     *
     * @param file       the file to write
     * @param collection the observed values, written via {@code toString()}
     * @param z          {@code true} to write the gamma distribution header (shape and scale),
     *                   {@code false} to write the geometric distribution header (probability)
     * @param <T>        numeric type of the observations
     * @throws IOException if the file cannot be written
     */
    private <T extends Number> void saveObservedValues(File file, Collection<T> collection, boolean z) throws IOException {
        try (BufferedWriter out = Files.newBufferedWriter(file.toPath())) {
            if (!z) {
                out.write("# Geometric distribution with parameter:");
                out.newLine();
                out.write("# probability " + this.hydrogenRearrangementProb);
            } else {
                out.write("# Gamma distribution with parameter:");
                out.newLine();
                out.write("# shape " + this.gammaShapeParameter);
                out.newLine();
                out.write("# scale " + this.gammaScaleParameter);
            }
            // Each value is preceded by a line break, so the file has no trailing newline.
            for (T observation : collection) {
                out.newLine();
                out.write(observation.toString());
            }
        }
    }

    /**
     * Runs the complete estimation: extracts the data of every subtree file in parallel,
     * aggregates it, estimates all scoring parameters and writes the results to the output
     * directory.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting for the
     *                              extraction tasks
     * @throws ExecutionException   if the extraction of one of the subtree files failed
     * @throws IOException          if the results cannot be written
     */
    public void estimateParameters() throws InterruptedException, ExecutionException, IOException {
        System.out.println("Initialize the ExecutorService and collect all tasks for each instance.");
        ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        List<Future<ExtractedData>> futures;
        try {
            List<Callable<ExtractedData>> tasks = new ArrayList<>(this.subtreeFileNames.length);
            for (String fileName : this.subtreeFileNames) {
                tasks.add(newExtractionTask(fileName));
            }
            System.out.println("All tasks will be submitted to the executor service.\nThe main thread will be stopped until all tasks have been completed.");
            futures = executor.invokeAll(tasks);
        } finally {
            // Also release the worker threads when invokeAll is interrupted.
            executor.shutdown();
        }

        System.out.println("Collect all extracted instance data.");
        List<Integer> hydrogenRearrangements = new ArrayList<>();
        List<Float> penalties = new ArrayList<>();
        TObjectDoubleHashMap<String> breakProbs = new TObjectDoubleHashMap<>();
        TObjectDoubleHashMap<String> cutDirProbs = new TObjectDoubleHashMap<>();
        TObjectIntHashMap<String> numObservations = new TObjectIntHashMap<>();
        for (Future<ExtractedData> future : futures) {
            ExtractedData extractedData = future.get();
            hydrogenRearrangements.addAll(extractedData.hydrogenRearrangements);
            penalties.addAll(extractedData.penalties);
            // Sum up the per-instance probabilities and count the instances per bond name.
            for (String bondName : extractedData.directedBondName2BreakProb.keySet()) {
                numObservations.adjustOrPutValue(bondName, 1, 1);
                double breakProb = extractedData.directedBondName2BreakProb.get(bondName);
                double cutDirProb = extractedData.directedBondName2cutDirProb.get(bondName);
                breakProbs.adjustOrPutValue(bondName, breakProb, breakProb);
                cutDirProbs.adjustOrPutValue(bondName, cutDirProb, cutDirProb);
            }
        }

        // Turn the sums into averages. The small constant added to numerator and denominator
        // keeps the averages strictly positive, so that their logarithm is finite.
        final double pseudoCount = 1.0E-6d;
        TObjectDoubleHashMap<String> logProbs = new TObjectDoubleHashMap<>(breakProbs.size(), 0.75f);
        for (String bondName : breakProbs.keySet()) {
            int count = numObservations.get(bondName);
            double averageBreakProb = (breakProbs.get(bondName) + pseudoCount) / (count + pseudoCount);
            double averageCutDirProb = (cutDirProbs.get(bondName) + pseudoCount) / (count + pseudoCount);
            breakProbs.put(bondName, averageBreakProb);
            cutDirProbs.put(bondName, averageCutDirProb);
            logProbs.put(bondName, Math.log(averageBreakProb) + Math.log(averageCutDirProb));
        }

        System.out.println("All extracted data was collected. Estimate all parameters.");
        // Order matters: the pseudo fragment score reads hydrogenRearrangementProb.
        this.hydrogenRearrangementProb = estimateHydrogenRearrangementProbability(hydrogenRearrangements);
        this.pseudoFragmentScore = calculatePseudoFragmentScore(penalties);
        postprocessBondScores(logProbs);
        this.wildcardScore = calculateWildcardScore(breakProbs);

        System.out.println("All parameters were estimated. Save unprocessed and processed data.");
        saveData(hydrogenRearrangements, penalties, breakProbs, cutDirProbs, logProbs, numObservations);
    }

    /**
     * Creates the task that reads one subtree file and extracts its hydrogen rearrangements,
     * penalties and per-bond-name probabilities.
     *
     * @param fileName name of the subtree file inside {@code subtreeDir}
     * @return a task producing the extracted data of that file
     */
    private Callable<ExtractedData> newExtractionTask(String fileName) {
        return () -> {
            File file = new File(this.subtreeDir, fileName);
            JacksonDocument jacksonDocument = new JacksonDocument();
            JsonNode docRoot;
            // Close the reader as soon as the document is parsed.
            try (FileReader reader = new FileReader(file)) {
                docRoot = jacksonDocument.fromReader(reader);
            }
            CombinatorialSubtree subtree = CombinatorialSubtreeCalculatorJsonReader.readTreeFromJson(docRoot, jacksonDocument);
            ArrayList<Integer> hydrogenRearrangements = CombinatorialSubtreeCalculatorJsonReader.getHydrogenRearrangements(docRoot, jacksonDocument);

            // One penalty per distinct source node of the terminal nodes' first incoming edge.
            HashMap<CombinatorialNode, Float> penaltyBySourceNode = new HashMap<>();
            for (CombinatorialNode terminalNode : subtree.getTerminalNodes()) {
                CombinatorialNode sourceNode = terminalNode.incomingEdges.get(0).source;
                penaltyBySourceNode.putIfAbsent(sourceNode, Float.valueOf(sourceNode.totalScore));
            }

            HashMap<String, ArrayList<IBond>> bondNames2BondList = subtree.getRoot().getFragment().parent.bondNames2BondList(true);
            TObjectDoubleHashMap<String> breakProbs = new TObjectDoubleHashMap<>(2 * bondNames2BondList.size());
            TObjectDoubleHashMap<String> cutDirProbs = new TObjectDoubleHashMap<>(2 * bondNames2BondList.size());
            for (ArrayList<IBond> bonds : bondNames2BondList.values()) {
                estimateProbabilitiesAndUpdate(subtree, bonds, breakProbs, cutDirProbs);
            }
            return new ExtractedData(hydrogenRearrangements, penaltyBySourceNode.values(), breakProbs, cutDirProbs);
        };
    }

    /**
     * Command line entry point. Estimates all scoring parameters and additionally records the
     * two numeric input parameters in {@code input_parameters.txt} inside the output directory.
     *
     * @param strArr {@code [0]} subtree directory, {@code [1]} output directory,
     *               {@code [2]} peak explanation percentile, {@code [3]} bond score
     *               significance value; additional arguments are ignored
     * @throws IllegalArgumentException if fewer than four arguments are given or a numeric
     *                                  argument cannot be parsed
     */
    public static void main(String[] strArr) {
        if (strArr.length < 4) {
            throw new IllegalArgumentException("Expected 4 arguments: <subtreeDir> <outputDir> <peakExplanationPercentile> <bondScoreSignificanceValue>");
        }
        try {
            File file = new File(strArr[0]);
            File file2 = new File(strArr[1]);
            double parseDouble = Double.parseDouble(strArr[2]);
            double parseDouble2 = Double.parseDouble(strArr[3]);
            new ScoringParameterEstimator(file, file2, parseDouble, parseDouble2).estimateParameters();
            // try-with-resources closes the writer even if one of the writes fails.
            try (BufferedWriter writer = Files.newBufferedWriter(new File(file2, "input_parameters.txt").toPath())) {
                writer.write("peakExplanationPercentile: " + parseDouble);
                writer.newLine();
                writer.write("bondScoresProbabilityDifference: " + parseDouble2);
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag for whoever runs this thread.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (IOException | ExecutionException e) {
            e.printStackTrace();
        }
    }
}
