package de.unijena.bioinf.confidence_score_train;

import de.unijena.bioinf.ChemistryBase.algorithm.scoring.Scored;
import de.unijena.bioinf.confidence_score.CombinedFeatureCreator;
import de.unijena.bioinf.confidence_score.features.PvalueScoreUtils;
import de.unijena.bioinf.confidence_score.svm.LibLinearImpl;
import de.unijena.bioinf.confidence_score.svm.SVMInterface;
import de.unijena.bioinf.confidence_score.svm.SVMPredict;
import de.unijena.bioinf.confidence_score.svm.SVMScales;
import de.unijena.bioinf.confidence_score.svm.SVMUtils;
import de.unijena.bioinf.confidence_score.svm.TrainedSVM;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.libsvm.SVM;

/* loaded from: input_file:de/unijena/bioinf/confidence_score_train/TrainConfidenceScore.class */
public class TrainConfidenceScore {
    CombinedFeatureCreator featureCreator;
    double[][] featureMatrix;
    double[] labels;
    ArrayList<double[][]> cvFeatureMatrix;
    ArrayList<double[]> cvLabel;
    LibLinearImpl imp;

    /**
     * Runs a 10-fold cross-validation of the linear SVM over the given samples and
     * reports the pooled AUC of the held-out predictions.
     *
     * <p>For every fold the held-out slice becomes the evaluation set and all other
     * samples plus every row of {@code dArr3} (tagged with the id {@code "synth"})
     * become the training set. The fold model is produced by
     * {@link #trainLinearSVM}, written to disk via {@link #writeFold}, and its
     * predictions on the held-out slice are pooled. After the last fold the pooled
     * scores are written via {@link #writeScores} and the pooled AUC is printed.
     *
     * <p>Side effects: stores {@code dArr}/{@code dArr2} in the
     * {@code featureMatrix}/{@code labels} fields, passes {@code dArr} to
     * {@code SVMUtils.standardize_features} (presumably an in-place
     * standardization, since its result is not captured — confirm), and
     * creates/truncates the hard-coded {@code features.csv} file.
     *
     * @param dArr    feature matrix, one row per sample
     * @param dArr2   labels for {@code dArr}; {@code 1.0} is treated as the positive class
     * @param strArr  per-sample identifiers, parallel to {@code dArr}
     * @param dArr3   additional rows added to the training set of every fold; they are
     *                not standardized by this method
     * @param dArr4   labels for {@code dArr3}
     * @param strArr2 forwarded unchanged to the {@code TrainedSVM} constructor
     *                (presumably feature names — confirm)
     * @param str     first name component of the per-fold output files
     * @param str2    second name component of the per-fold output files
     * @return a placeholder {@code TrainedSVM} built from {@code null} arguments; the
     *         per-fold models are only available through the files written by
     *         {@link #writeFold}
     * @throws IOException if the hard-coded {@code features.csv} file cannot be opened
     */
    public TrainedSVM trainLinearSVMWithCV(double[][] dArr, double[] dArr2, String[] strArr, double[][] dArr3, double[] dArr4, String[] strArr2, String str, String str2) throws IOException {
        // Nothing is ever written to this file. Opening and closing it right away keeps
        // the create/truncate side effect (and the IOException on an unwritable path)
        // without leaking the handle when a later step throws.
        new FileWriter("/vol/clusterdata/fingerid_martin/exp2_nfp/features.csv").close();
        this.featureMatrix = dArr;
        this.labels = dArr2;
        // Instance is unused; kept because the constructor is not visible from this file.
        new SVMUtils();
        System.out.println("starting scale calc");
        SVMScales scales = SVMUtils.calculateScales(dArr);
        System.out.println("starting standard");
        SVMUtils.standardize_features(dArr, scales);

        final int foldCount = 10;
        final int sampleCount = dArr.length;
        final int foldSize = sampleCount / foldCount;
        List<Double> pooledScores = new ArrayList<>();
        List<Boolean> pooledTruth = new ArrayList<>();
        int start = 0;
        for (int fold = 0; fold < foldCount; fold++) {
            // The last fold also takes the sampleCount % foldCount remainder so that
            // every sample is held out exactly once.
            int end = (fold == foldCount - 1) ? sampleCount : start + foldSize;
            System.out.println("in fold " + fold + " - start: " + start + " - end: " + end);

            List<double[]> testRows = new ArrayList<>();
            List<Double> testLabelList = new ArrayList<>();
            ArrayList<String> testIds = new ArrayList<>();
            List<double[]> trainRows = new ArrayList<>();
            List<Double> trainLabelList = new ArrayList<>();
            List<String> trainIdList = new ArrayList<>();
            for (int idx = 0; idx < sampleCount; idx++) {
                // Held-out slice is [start, end): the start index is included, matching
                // the "start"/"end" values logged above.
                if (idx >= start && idx < end) {
                    testRows.add(dArr[idx].clone());
                    testLabelList.add(dArr2[idx]);
                    testIds.add(strArr[idx]);
                } else {
                    trainRows.add(dArr[idx].clone());
                    trainIdList.add(strArr[idx]);
                    trainLabelList.add(dArr2[idx]);
                }
            }
            // The additional rows go into every training fold; the "synth" id lets
            // trainLinearSVM exclude them from the sigmoid fit.
            for (int s = 0; s < dArr3.length; s++) {
                System.out.println(Arrays.toString(dArr3[s]));
                trainRows.add(dArr3[s].clone());
                trainLabelList.add(dArr4[s]);
                trainIdList.add("synth");
            }

            double[][] testFeatures = testRows.toArray(new double[0][]);
            double[] testLabels = testLabelList.stream().mapToDouble(Double::doubleValue).toArray();
            double[][] trainFeatures = trainRows.toArray(new double[0][]);
            double[] trainLabels = trainLabelList.stream().mapToDouble(Double::doubleValue).toArray();
            String[] trainIds = trainIdList.toArray(new String[0]);

            System.out.println("train size: " + trainFeatures.length);
            TrainedSVM foldModel = trainLinearSVM(trainFeatures, trainLabels, testFeatures, testLabels, strArr2, scales, String.valueOf(fold), trainIds);
            writeFold(testIds, foldModel, String.valueOf(fold), str, str2);
            for (double weight : foldModel.weights) {
                System.out.print(weight + " , ");
            }
            System.out.println();

            double[] foldScores = new SVMPredict().predict_confidence(testFeatures, foldModel);
            // Instance is unused; kept because the constructor is not visible from this file.
            new PvalueScoreUtils();
            for (int k = 0; k < foldScores.length; k++) {
                pooledScores.add(foldScores[k]);
                pooledTruth.add(testLabels[k] == 1.0d);
            }
            start = end;
        }

        double[] scores = new double[pooledScores.size()];
        boolean[] truth = new boolean[pooledTruth.size()];
        for (int k = 0; k < scores.length; k++) {
            scores[k] = pooledScores.get(k);
            truth[k] = pooledTruth.get(k);
        }
        writeScores(scores, truth);
        System.out.println(new Stats(scores, truth).getAUC());
        return new TrainedSVM((SVMScales) null, (double[]) null, (String[]) null);
    }

    /**
     * Trains a linear SVM on the given training rows, selecting {@code C} and
     * {@code eps} by a small grid search that maximizes the AUC on the evaluation set.
     *
     * <p>For each grid point a model is trained on all training rows, sigmoid
     * parameters are fitted via {@link #trainSigmoid} on the training rows whose id
     * is not {@code "synth"}, and the AUC of the model's confidences on
     * {@code dArr3} is computed. The weights and sigmoid parameters of the best
     * grid point are returned.
     *
     * @param dArr      training feature matrix, one row per sample
     * @param dArr2     training labels, parallel to {@code dArr}
     * @param dArr3     evaluation feature matrix
     * @param dArr4     evaluation labels; {@code 1.0} is treated as the positive class
     * @param strArr    forwarded unchanged to the {@code TrainedSVM} constructor
     *                  (presumably feature names — confirm)
     * @param sVMScales scales stored in the returned model
     * @param str       fold label, used only in log output
     * @param strArr2   per-row ids parallel to {@code dArr}; rows with id
     *                  {@code "synth"} are excluded from the sigmoid fit
     * @return the model of the best grid point, with {@code probAB} set to its sigmoid parameters
     * @throws IllegalStateException if no grid point yields an AUC greater than -1
     *                               (for example when every AUC is NaN)
     */
    public TrainedSVM trainLinearSVM(double[][] dArr, double[] dArr2, double[][] dArr3, double[] dArr4, String[] strArr, SVMScales sVMScales, String str, String[] strArr2) {
        this.imp = new LibLinearImpl();
        // Instance is unused; kept because the constructor is not visible from this file.
        new SVMUtils();

        // Raw lists: the node type returned by createSVM_Node is not visible from this file.
        ArrayList trainingNodes = new ArrayList();
        for (double[] row : dArr) {
            ArrayList rowNodes = new ArrayList();
            for (int col = 0; col < row.length; col++) {
                // Feature indices handed to the SVM are 1-based.
                rowNodes.add(this.imp.createSVM_Node(col + 1, row[col]));
            }
            trainingNodes.add(rowNodes);
        }

        // Rows used for the sigmoid fit: every training row that is not tagged "synth".
        // A separate write cursor is used so the result is correct wherever the
        // "synth" rows sit in the input.
        int realCount = 0;
        for (int k = 0; k < dArr.length; k++) {
            if (!"synth".equals(strArr2[k])) {
                realCount++;
            }
        }
        double[][] realFeatures = new double[realCount][];
        double[] realLabels = new double[realCount];
        int cursor = 0;
        for (int k = 0; k < dArr.length; k++) {
            if (!"synth".equals(strArr2[k])) {
                realFeatures[cursor] = dArr[k];
                realLabels[cursor] = dArr2[k];
                cursor++;
            }
        }

        boolean[] evalTruth = new boolean[dArr4.length];
        for (int k = 0; k < dArr4.length; k++) {
            evalTruth[k] = dArr4[k] == 1.0d;
        }

        System.out.println("before gridsearch");
        double[] cGrid = {10.0d, 100.0d, 1000.0d};
        double[] epsGrid = {1.0E-6d, 1.0E-5d, 1.0E-4d, 0.001d, 0.01d, 0.1d, 1.0d};
        double bestAuc = -1.0d;
        double[] bestProbAB = new double[2];
        LibLinearImpl.svm_model bestModel = null;
        for (double c : cGrid) {
            for (double eps : epsGrid) {
                System.out.println("Computing C: " + c + "  Computing epsilon: " + eps);
                LibLinearImpl.svm_problemImpl problem = this.imp.createSVM_Problem();
                problem.svm_problem.bias = 0.0d;
                problem.setX(trainingNodes);
                problem.setY(dArr2);
                problem.setL(trainingNodes.size());
                SVMInterface.svm_parameter params = new SVMInterface.svm_parameter();
                params.C = c;
                params.kernel_type = 0;
                params.eps = eps;
                params.weight = new double[]{1.0d, 1.0d};
                params.weight_label = new int[]{-1, 1};
                LibLinearImpl.svm_model model = this.imp.svm_train(problem, params);
                System.out.println("trained");

                SVMPredict predictor = new SVMPredict();
                TrainedSVM candidate = new TrainedSVM(sVMScales, model.getModel().getFeatureWeights(), strArr);
                double[] probAB = new double[2];
                System.out.println("reduced: " + realFeatures.length + " - orig: " + dArr.length);
                trainSigmoid(realFeatures, realLabels, candidate, probAB, false);
                System.out.println("trained sigmoid - " + probAB[0] + " - " + probAB[1]);

                double[] evalScores = predictor.predict_confidence(dArr3, candidate);
                // Fresh copy per candidate, as Stats' handling of its inputs is not visible here.
                double auc = new Stats(evalScores, evalTruth.clone()).getAUC();
                if (auc > bestAuc) {
                    bestAuc = auc;
                    bestModel = model;
                    bestProbAB = probAB;
                }
            }
        }
        if (bestModel == null) {
            throw new IllegalStateException("grid search produced no candidate with a comparable AUC in fold " + str);
        }

        TrainedSVM best = new TrainedSVM(sVMScales, bestModel.getModel().getFeatureWeights(), strArr);
        best.probAB = bestProbAB;
        // Result is discarded; the call is retained because SVMPredict's side effects
        // are not visible from this file.
        new SVMPredict().predict_confidence(dArr3, best);
        System.out.println("fold complete" + str);
        // Best AUC, followed by the C and eps of the selected model.
        System.out.println(bestAuc + " - " + bestModel.getParam().C + " - " + bestModel.getParam().eps);
        return best;
    }

    /**
     * Writes every value of {@code dArr} on its own line to the hard-coded
     * {@code bogus_dist.txt} file, replacing any previous content.
     *
     * <p>Best effort: any failure is printed to standard error and otherwise ignored.
     *
     * @param dArr scores to write
     * @param zArr unused
     */
    public void writeBogusScores(double[] dArr, boolean[] zArr) {
        // try-with-resources closes the writer even when a write fails part-way.
        try (FileWriter writer = new FileWriter(new File("/vol/clusterdata/fingerid_martin/exp2/bogus_dist.txt"))) {
            for (double score : dArr) {
                writer.write(score + "\n");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a single linear SVM with fixed parameters and no held-out evaluation.
     *
     * <p>Scales are computed from {@code dArr}, which is then passed to
     * {@code SVMUtils.standardize_features}. The SVM is trained on the rows of
     * {@code dArr} followed by the rows of {@code dArr3}; the sigmoid parameters are
     * fitted on {@code dArr}/{@code dArr2} only. The learned weights are printed to
     * standard output.
     *
     * @param dArr   primary feature matrix
     * @param dArr2  labels for {@code dArr}
     * @param strArr forwarded unchanged to the {@code TrainedSVM} constructor
     * @param dArr3  additional training rows, appended after {@code dArr}
     * @param dArr4  labels for {@code dArr3}
     * @param dArr5  unused
     * @param d      value assigned to the SVM parameter {@code C}
     * @param d2     value assigned to the SVM parameter {@code eps}
     * @return the trained model with {@code probAB} set
     */
    public TrainedSVM trainLinearSVMNoEval(double[][] dArr, double[] dArr2, String[] strArr, double[][] dArr3, double[] dArr4, double[] dArr5, double d, double d2) {
        this.imp = new LibLinearImpl();
        new SVMUtils();
        SVMScales scales = SVMUtils.calculateScales(dArr);
        SVMUtils.standardize_features(dArr, scales);

        // Primary rows first, then the additional rows, matching the label order below.
        ArrayList problemRows = new ArrayList();
        for (double[] row : dArr) {
            problemRows.add(nodeRowOf(row));
        }
        for (double[] row : dArr3) {
            problemRows.add(nodeRowOf(row));
        }

        double[] allLabels = Arrays.copyOf(dArr2, dArr2.length + dArr4.length);
        System.arraycopy(dArr4, 0, allLabels, dArr2.length, dArr4.length);

        LibLinearImpl.svm_problemImpl problem = this.imp.createSVM_Problem();
        problem.svm_problem.bias = 0.0d;
        problem.setX(problemRows);
        problem.setY(allLabels);
        problem.setL(problemRows.size());

        SVMInterface.svm_parameter params = new SVMInterface.svm_parameter();
        params.C = d;
        params.kernel_type = 0;
        params.eps = d2;
        params.weight = new double[]{1.0d, 1.0d};
        params.weight_label = new int[]{-1, 1};

        double[] weights = this.imp.svm_train(problem, params).getModel().getFeatureWeights();
        TrainedSVM result = new TrainedSVM(scales, weights, strArr);
        double[] probAB = new double[2];
        trainSigmoid(dArr, dArr2, result, probAB, false);
        result.probAB = probAB;

        for (double weight : result.weights) {
            System.out.print(weight + " , ");
        }
        System.out.println();
        return result;
    }

    /** Converts one feature row into a list of SVM nodes with 1-based feature indices. */
    private ArrayList nodeRowOf(double[] row) {
        ArrayList nodes = new ArrayList();
        for (int col = 0; col < row.length; col++) {
            nodes.add(this.imp.createSVM_Node(col + 1, row[col]));
        }
        return nodes;
    }

    /**
     * Fits sigmoid parameters for the given model by calling
     * {@code SVM.sigmoid_train} on the model's confidences.
     *
     * <p>When {@code bool} is true the fit uses the rows returned by
     * {@link #getUpperFeatures} (requested with fraction {@code 1.0}) together with
     * their labels; otherwise it uses {@code dArr}/{@code dArr2} as given. The
     * ranked rows are computed in both cases.
     *
     * @param dArr       feature matrix
     * @param dArr2      labels, parallel to {@code dArr}
     * @param trainedSVM model whose confidences are fitted
     * @param dArr3      array handed to {@code SVM.sigmoid_train} as its last argument;
     *                   callers read the fitted parameters from it afterwards
     * @param bool       selects the ranked-row variant when true
     */
    public void trainSigmoid(double[][] dArr, double[] dArr2, TrainedSVM trainedSVM, double[] dArr3, Boolean bool) {
        // Arrays hash by identity, so this maps each row object to its own label.
        HashMap<double[], Double> labelByRow = new HashMap<>();
        for (int row = 0; row < dArr.length; row++) {
            labelByRow.put(dArr[row], dArr2[row]);
        }

        double[][] rankedRows = getUpperFeatures(dArr, 1.0d, trainedSVM);
        double[] rankedLabels = new double[rankedRows.length];
        for (int row = 0; row < rankedRows.length; row++) {
            rankedLabels[row] = labelByRow.get(rankedRows[row]);
        }

        SVMPredict predictor = new SVMPredict();
        if (bool) {
            double[] rankedScores = predictor.predict_confidence(rankedRows, trainedSVM);
            SVM.sigmoid_train(rankedRows.length, rankedScores, rankedLabels, dArr3);
        } else {
            double[] scores = predictor.predict_confidence(dArr, trainedSVM);
            SVM.sigmoid_train(dArr.length, scores, dArr2, dArr3);
        }
    }

    /**
     * Writes the held-out ids of one cross-validation fold and the fold's model to
     * the hard-coded {@code cv_folds/fold<str>} directory.
     *
     * <p>The ids go, one per line, into {@code testids_<str2>_<str3>_<str>}; the
     * model is exported as JSON to {@code svm_<str2>_<str3>_<str>}. Best effort: any
     * failure is printed to standard error and otherwise ignored, and the model is
     * not exported if writing the ids fails.
     *
     * @param arrayList  ids of the held-out samples
     * @param trainedSVM model to export
     * @param str        fold label; selects the directory and is the file-name suffix
     * @param str2       first file-name component
     * @param str3       second file-name component
     */
    public void writeFold(ArrayList<String> arrayList, TrainedSVM trainedSVM, String str, String str2, String str3) {
        try {
            System.out.println("writing fold" + str);
            File foldDir = new File("/vol/clusterdata/fingerid_martin/fingerid_confidence_120/cv_folds/fold" + str);
            File idsFile = new File(foldDir + "/testids_" + str2 + "_" + str3 + "_" + str);
            // try-with-resources closes the writer even when a write fails part-way.
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(idsFile))) {
                for (String id : arrayList) {
                    writer.write(id + "\n");
                }
            }
            trainedSVM.exportAsJSON(new File(foldDir + "/svm_" + str2 + "_" + str3 + "_" + str));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns the best-scoring fraction of the given rows, ordered from highest to
     * lowest model confidence.
     *
     * @param dArr       feature matrix; must contain at least one row
     * @param d          fraction of rows to keep; {@code round(dArr.length * d)} rows are returned
     * @param trainedSVM model used to score the rows
     * @return the selected rows (the same array objects as in {@code dArr}), best first
     */
    public double[][] getUpperFeatures(double[][] dArr, double d, TrainedSVM trainedSVM) {
        SVMPredict predictor = new SVMPredict();
        final int keep = (int) Math.round(dArr.length * d);
        double[][] selected = new double[keep][dArr[0].length];
        double[] scores = predictor.predict_confidence(dArr, trainedSVM);

        ArrayList ranked = new ArrayList();
        int row = 0;
        for (double score : scores) {
            ranked.add(new Scored(dArr[row], score));
            row++;
        }
        // Sort, then flip, so the best-scoring row comes first.
        Collections.sort(ranked);
        Collections.reverse(ranked);

        for (int pos = 0; pos < keep; pos++) {
            Scored entry = (Scored) ranked.get(pos);
            selected[pos] = (double[]) entry.getCandidate();
        }
        return selected;
    }

    /**
     * Splits the scores by their truth flag and writes them, one per line, to the
     * hard-coded {@code scores_true.txt} (flag set) and {@code scores_bogus.txt}
     * (flag not set) files, replacing any previous content.
     *
     * <p>Best effort: any failure is printed to standard error and otherwise ignored.
     *
     * @param dArr scores to write
     * @param zArr truth flags, parallel to {@code dArr}
     */
    public void writeScores(double[] dArr, boolean[] zArr) {
        // try-with-resources closes both writers even when a write fails part-way.
        try (FileWriter trueWriter = new FileWriter("/vol/clusterdata/fingerid_martin/fingerid_confidence_120/scores_true.txt");
             FileWriter bogusWriter = new FileWriter("/vol/clusterdata/fingerid_martin/fingerid_confidence_120/scores_bogus.txt")) {
            for (int i = 0; i < dArr.length; i++) {
                FileWriter target = zArr[i] ? trueWriter : bogusWriter;
                target.write(dArr[i] + "\n");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Counts the labels that are exactly {@code 1.0}.
     *
     * @param dArr labels to inspect
     * @return number of entries equal to {@code 1.0}
     */
    public int posExampleAmount(double[] dArr) {
        return (int) Arrays.stream(dArr).filter(label -> label == 1.0d).count();
    }

    /**
     * Counts the labels that are exactly {@code -1.0}.
     *
     * @param dArr labels to inspect
     * @return number of entries equal to {@code -1.0}
     */
    public int negExampleAmount(double[] dArr) {
        int negatives = 0;
        for (int idx = 0; idx < dArr.length; idx++) {
            negatives += (dArr[idx] == -1.0d) ? 1 : 0;
        }
        return negatives;
    }
}
