package de.unijena.bioinf.fingerid.cli.tools.temp;

import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.ChemistryBase.utils.FileUtils;
import de.unijena.bioinf.fingerid.LogisticRegression;
import de.unijena.bioinf.fingerid.cli.CliTool;
import de.unijena.bioinf.fingerid.cli.Configuration;
import de.unijena.bioinf.fingerid.cli.Reporter;
import de.unijena.bioinf.fingerid.cli.ToolSet;

import java.io.File;
import java.io.IOException;
import java.util.Locale;
import java.util.Random;

/* loaded from: input_file:de/unijena/bioinf/fingerid/cli/tools/temp/TestLogReg.class */
/**
 * Temporary command-line experiment that exercises {@link LogisticRegression} on a small,
 * reproducible synthetic data set.
 *
 * <p>{@link #run} performs the following steps:
 * <ol>
 *   <li>Generates samples from a fixed random seed and labels each sample {@code 1.0} or
 *       {@code -1.0} by the sign of {@code x0 + 3*x1 - 2*x2}.</li>
 *   <li>Trains the model once per candidate training parameter and prints the mean
 *       {@code PredictionPerformance.getF()} for each candidate.</li>
 *   <li>Retrains with the best candidate and prints the per-column performance.</li>
 *   <li>Writes the learned weights and the data to the files {@code W}, {@code Y} and {@code X}
 *       in the current working directory.</li>
 * </ol>
 */
public class TestLogReg implements CliTool {

    /** Seed of the random generator, fixed so that every run sees the same data. */
    private static final long SEED = 1337L;

    /** Number of synthetic samples (rows of the feature and label matrices). */
    private static final int NUMBER_OF_SAMPLES = 40;

    /**
     * Three random features plus one column that is always 1.0 (presumably the intercept/bias
     * term for the model).
     */
    private static final int NUMBER_OF_FEATURES = 4;

    /** Parameter returned by the selection when no candidate reaches a mean score of at least 0. */
    private static final double DEFAULT_TRAINING_PARAMETER = 1.0d;

    /**
     * Runs the experiment. The tool set, configuration and reporter are not used.
     *
     * @throws IOException if an I/O error occurs, e.g. while writing the output matrices
     */
    @Override
    public void run(ToolSet toolSet, Configuration configuration, Reporter reporter) throws IOException {
        final double[][] features = new double[NUMBER_OF_SAMPLES][NUMBER_OF_FEATURES];
        final double[][] labels = new double[NUMBER_OF_SAMPLES][1];
        fillSyntheticData(features, labels, new Random(SEED));

        final LogisticRegression logisticRegression = new LogisticRegression(features, labels);
        final double bestParameter =
                selectTrainingParameter(logisticRegression, features, labels, new double[]{1.0d});

        System.out.println("----------------");
        // Retrain with the selected parameter so that the printed performance and the written
        // weights belong to it rather than to the last candidate tried during selection.
        logisticRegression.train(bestParameter);
        for (PredictionPerformance columnPerformance : performance(labels, logisticRegression.predict(features))) {
            System.out.println(columnPerformance);
        }
        FileUtils.writeDoubleMatrix(new File("W"), logisticRegression.getW());
        FileUtils.writeDoubleMatrix(new File("Y"), labels);
        FileUtils.writeDoubleMatrix(new File("X"), features);
    }

    /**
     * Fills the feature and label matrices with a linearly separable toy problem.
     *
     * <p>Each feature row receives three {@link Random#nextDouble()} draws in column order,
     * followed by a constant {@code 1.0}. The single label column is {@code 1.0} when
     * {@code x0 + 3*x1 - 2*x2 > 0} and {@code -1.0} otherwise.
     *
     * @param features matrix with at least four columns per row; overwritten in place
     * @param labels   matrix with at least one column and as many rows as {@code features};
     *                 column 0 is overwritten in place
     * @param random   source of the random feature values
     */
    private static void fillSyntheticData(double[][] features, double[][] labels, Random random) {
        for (int row = 0; row < features.length; row++) {
            final double[] x = features[row];
            x[0] = random.nextDouble();
            x[1] = random.nextDouble();
            x[2] = random.nextDouble();
            x[3] = 1.0d;
            labels[row][0] = (x[0] + (3.0d * x[1])) - (2.0d * x[2]) > 0.0d ? 1.0d : -1.0d;
        }
    }

    /**
     * Trains {@code model} with every candidate and returns the one with the highest mean F
     * score on the training data.
     *
     * <p>Each candidate is printed together with its score. Ties go to the later candidate.
     * Candidates scoring below 0 or {@code NaN} are never chosen; if none qualifies,
     * {@link #DEFAULT_TRAINING_PARAMETER} is returned. The model is left trained with the last
     * candidate.
     *
     * @param model      the model to train; presumably the candidate is a regularization
     *                   strength (TODO confirm against {@code LogisticRegression.train})
     * @param features   training features passed to {@code predict}
     * @param labels     training labels used for scoring
     * @param candidates parameters to try, in order
     * @return the best-scoring candidate, or the default if none qualifies
     */
    private static double selectTrainingParameter(LogisticRegression model, double[][] features,
                                                  double[][] labels, double[] candidates) {
        double bestParameter = DEFAULT_TRAINING_PARAMETER;
        double bestScore = 0.0d;
        for (double candidate : candidates) {
            model.train(candidate);
            final double score = meanF1(performance(labels, model.predict(features)));
            // Locale.ROOT keeps '.' as the decimal separator regardless of the default locale.
            System.out.printf(Locale.ROOT, "%2.3f: %f%n", candidate, score);
            if (score >= bestScore) {
                bestScore = score;
                bestParameter = candidate;
            }
        }
        return bestParameter;
    }

    /**
     * Centres every column of {@code matrix} in place by subtracting the column mean.
     *
     * <p>No scaling is applied despite the name. The method is not invoked by {@link #run}.
     * An empty matrix is left unchanged. The column count is taken from the first row.
     *
     * @param matrix row-major matrix to centre in place
     */
    private static void normalizeFeatureMatrix(double[][] matrix) {
        if (matrix.length == 0) {
            return;
        }
        for (int column = 0; column < matrix[0].length; column++) {
            double sum = 0.0d;
            for (double[] row : matrix) {
                sum += row[column];
            }
            final double mean = sum / matrix.length;
            for (double[] row : matrix) {
                row[column] -= mean;
            }
        }
    }

    /**
     * Returns the arithmetic mean of {@code getF()} over all given performances.
     *
     * @param performances per-column performances
     * @return the mean, or {@code NaN} for an empty array (the mean of nothing is undefined)
     */
    private static double meanF1(PredictionPerformance[] performances) {
        if (performances.length == 0) {
            return Double.NaN;
        }
        double sum = 0.0d;
        for (PredictionPerformance performance : performances) {
            sum += performance.getF();
        }
        return sum / performances.length;
    }

    /**
     * Builds the symmetric matrix of pairwise row dot products, i.e. {@code X * X^T}.
     *
     * <p>This is a linear-kernel Gram matrix. The method is not invoked by {@link #run}.
     *
     * @param matrix row-major input; all rows are expected to have the same length
     * @return an {@code n x n} matrix whose entry {@code [i][j]} is {@code matrix[i] . matrix[j]}
     */
    private static double[][] kernelize(double[][] matrix) {
        final double[][] gram = new double[matrix.length][matrix.length];
        for (int i = 0; i < gram.length; i++) {
            gram[i][i] = dotProduct(matrix[i], matrix[i]);
            // Fill only the lower triangle and mirror it; the matrix is symmetric.
            for (int j = 0; j < i; j++) {
                final double value = dotProduct(matrix[i], matrix[j]);
                gram[j][i] = value;
                gram[i][j] = value;
            }
        }
        return gram;
    }

    /**
     * Returns the dot product of two vectors of equal length.
     *
     * @throws IllegalArgumentException if the vectors differ in length
     */
    private static double dotProduct(double[] left, double[] right) {
        if (left.length != right.length) {
            throw new IllegalArgumentException(
                    "vector lengths differ: " + left.length + " vs " + right.length);
        }
        double sum = 0.0d;
        for (int i = 0; i < left.length; i++) {
            sum += left[i] * right[i];
        }
        return sum;
    }

    /**
     * Computes one {@link PredictionPerformance} per label column.
     *
     * <p>A value {@code > 0} counts as positive. For every row and column,
     * {@code update(labelPositive, predictionPositive)} is called on that column's counter.
     * Columns are counted from the first label row. Extra trailing entries in a row, or extra
     * prediction rows, are ignored.
     *
     * @param labels      expected values, one row per sample
     * @param predictions predicted values, with at least as many rows and columns as
     *                    {@code labels}
     * @return one performance per label column; an empty array for empty {@code labels}
     * @throws IllegalArgumentException if {@code predictions} or a label row is too short
     */
    private static PredictionPerformance[] performance(double[][] labels, double[][] predictions) {
        if (labels.length == 0) {
            return new PredictionPerformance[0];
        }
        if (predictions.length < labels.length) {
            throw new IllegalArgumentException("expected at least " + labels.length
                    + " prediction rows, got " + predictions.length);
        }
        final int columns = labels[0].length;
        final PredictionPerformance.Modify[] counters = new PredictionPerformance.Modify[columns];
        for (int column = 0; column < columns; column++) {
            counters[column] = new PredictionPerformance().modify();
        }
        for (int row = 0; row < labels.length; row++) {
            if (labels[row].length < columns || predictions[row].length < columns) {
                throw new IllegalArgumentException(
                        "row " + row + " has fewer than " + columns + " label/prediction columns");
            }
            for (int column = 0; column < columns; column++) {
                counters[column].update(labels[row][column] > 0.0d, predictions[row][column] > 0.0d);
            }
        }
        final PredictionPerformance[] result = new PredictionPerformance[columns];
        for (int column = 0; column < columns; column++) {
            result[column] = counters[column].done();
        }
        return result;
    }

    /** Returns the command name under which this tool is registered. */
    @Override
    public String getName() {
        return "logreg";
    }

    /** Returns the (currently empty) description of this tool. */
    @Override
    public String getDescription() {
        return "";
    }
}
