package de.unijena.bioinf.canopus;

import de.unijena.bioinf.ChemistryBase.chem.FormulaConstraints;
import de.unijena.bioinf.ChemistryBase.chem.MolecularFormula;
import de.unijena.bioinf.ChemistryBase.chem.PeriodicTable;
import de.unijena.bioinf.ChemistryBase.fp.ArrayFingerprint;
import de.unijena.bioinf.ChemistryBase.fp.CdkFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.ClassyFireFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.ClassyfireProperty;
import de.unijena.bioinf.ChemistryBase.fp.CustomFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.FPIter;
import de.unijena.bioinf.ChemistryBase.fp.MaskedFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.ProbabilityFingerprint;
import de.unijena.bioinf.canopus.BufferedTrainData;
import de.unijena.bioinf.fingerid.KernelToNumpyConverter;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.list.array.TShortArrayList;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.procedure.TIntIntProcedure;
import gnu.trove.set.hash.TIntHashSet;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.nio.FloatBuffer;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Function;
import java.util.regex.Pattern;
import org.tensorflow.Tensor;

/* loaded from: input_file:de/unijena/bioinf/canopus/TrainingData.class */
public class TrainingData {
    // Compile-time flags. Neither is read by the methods visible in this part of the file
    // (NOTE(review): SAMPLE_FROM_TEMPLATE_FINGERPRINTS only shows up as a decompiler artifact in case labels).
    public static final boolean INCLUDE_FINGERPRINT = false;
    public static final boolean SAMPLE_FROM_TEMPLATE_FINGERPRINTS = true;
    // ClassyFire ontology loaded in setupEnv, and the mask of the classes that are actually used as labels.
    protected ClassyFireFingerprintVersion classyFireFingerprintVersion;
    protected MaskedFingerprintVersion classyFireMask;
    // Mask over VERSION containing the indices enabled in setupEnv.
    protected MaskedFingerprintVersion fingerprintVersion;
    protected MaskedFingerprintVersion canopusFingerprint;
    protected MaskedFingerprintVersion withoutCanopus;
    protected MaskedFingerprintVersion canopusOnly;
    // Compound classes keyed by their index in classyFireFingerprintVersion.
    protected final TIntObjectHashMap<CompoundClass> compoundClasses;
    // Lookup of compound classes by ontology name (filled in setupEnv).
    protected final HashMap<String, CompoundClass> name2class;
    // All labeled compounds; balancedSample walks this list in windows.
    protected final List<LabeledCompound> compounds;
    // Keys (compared against LabeledCompound.inchiKey) that are never sampled by balancedSample.
    protected final HashSet<String> blacklist;
    protected CustomFingerprintVersion dummyFingerprintVersion;
    // Evaluation instances; both lists are merged in fillUpWithTrainData.
    protected List<EvaluationInstance> crossvalidation;
    protected List<EvaluationInstance> independent;
    // Per-feature offset and divisor applied to formula features in normalizeVector.
    public double[] formulaNorm;
    public double[] formulaScale;
    // Per-position offset and divisor applied to platt probabilities in addNormalizedPlatts.
    public double[] plattNorm;
    public double[] plattScale;
    // Produces sampled probability fingerprints from a compound's fingerprint.
    protected Sampler fingerprintSampler;
    // Pattern handed to the constructor; may be null.
    protected Pattern independentPattern;
    // Column counts of the platt, formula and label tensors built by the batch methods.
    protected int nplatts;
    protected int nformulas;
    protected int nlabels;
    protected TIntIntHashMap canopusFingerprintMapping;
    public static final CdkFingerprintVersion VERSION = CdkFingerprintVersion.getComplete();
    // Mutable global switches: SCALE_BY_MAX is read in setupEnv, the other three in addNormalizedPlatts.
    public static boolean SCALE_BY_MAX = false;
    public static boolean VECNORM_SCALING = false;
    public static boolean PLATT_CENTERING = true;
    public static boolean SCALE_BY_STD = false;
    // Read into a local in generateBatch(int, ...) but not used there.
    protected static int GROW = 1;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: de.unijena.bioinf.canopus.TrainingData$4, reason: invalid class name */
    /* loaded from: input_file:de/unijena/bioinf/canopus/TrainingData$4.class */
    /**
     * Decompiled switch-map holder for {@link SamplingStrategy}.
     * <p>
     * The array is indexed by {@code SamplingStrategy.ordinal()} and holds the case number
     * (1..6) that the {@code switch} statements in this file dispatch on. Entries that are
     * never assigned keep the array default of 0, which matches no case.
     * <p>
     * Each assignment is wrapped separately so that a {@link NoSuchFieldError} for one constant
     * is ignored and the remaining constants are still registered.
     */
    public static /* synthetic */ class AnonymousClass4 {
        static final /* synthetic */ int[] $SwitchMap$de$unijena$bioinf$canopus$TrainingData$SamplingStrategy = new int[SamplingStrategy.values().length];

        static {
            // PERFECT -> case 1
            try {
                $SwitchMap$de$unijena$bioinf$canopus$TrainingData$SamplingStrategy[SamplingStrategy.PERFECT.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            // INDEPENDENT -> case 2
            try {
                $SwitchMap$de$unijena$bioinf$canopus$TrainingData$SamplingStrategy[SamplingStrategy.INDEPENDENT.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            // INDEPENDENT_DISTURBED -> case 3
            try {
                $SwitchMap$de$unijena$bioinf$canopus$TrainingData$SamplingStrategy[SamplingStrategy.INDEPENDENT_DISTURBED.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            // TEMPLATE -> case 4
            try {
                $SwitchMap$de$unijena$bioinf$canopus$TrainingData$SamplingStrategy[SamplingStrategy.TEMPLATE.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            // DISTURBED_TEMPLATE -> case 5
            try {
                $SwitchMap$de$unijena$bioinf$canopus$TrainingData$SamplingStrategy[SamplingStrategy.DISTURBED_TEMPLATE.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            // CONDITIONAL -> case 6
            try {
                $SwitchMap$de$unijena$bioinf$canopus$TrainingData$SamplingStrategy[SamplingStrategy.CONDITIONAL.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:de/unijena/bioinf/canopus/TrainingData$SamplingStrategy.class */
    /**
     * How the platt (probability) vector of a compound is produced when a batch is built.
     * The per-constant notes describe the call made by {@code sampleFingerprint}.
     */
    public enum SamplingStrategy {
        /** The compound's own fingerprint, converted with {@code asProbabilistic()}. */
        PERFECT,
        /** {@code fingerprintSampler.sampleIndependently(fingerprint, false)}. */
        INDEPENDENT,
        /** {@code fingerprintSampler.sampleIndependently(fingerprint, true)}. */
        INDEPENDENT_DISTURBED,
        /** {@code fingerprintSampler.sample(fingerprint, false)}. */
        TEMPLATE,
        /** {@code fingerprintSampler.sample(fingerprint, true)}. */
        DISTURBED_TEMPLATE,
        /** {@code fingerprintSampler.sampleFromCovariance(fingerprint)}. */
        CONDITIONAL
    }

    /* loaded from: input_file:de/unijena/bioinf/canopus/TrainingData$SamplingStrategyFunction.class */
    /**
     * Callback that chooses the sampling strategy for a single evaluation instance.
     */
    public interface SamplingStrategyFunction {
        /**
         * @param evaluationInstance the instance that is about to be resampled
         * @param i                  counter supplied by the resample methods while they walk the instance list
         * @return the strategy to use for this instance
         */
        SamplingStrategy sample(EvaluationInstance evaluationInstance, int i);
    }

    /**
     * Creates the training data from the given directory without an independent-set pattern
     * (delegates to the two-argument constructor with a {@code null} pattern).
     *
     * @param file directory that contains the input files read by {@code setupEnv}
     * @throws IOException if reading the input files fails
     */
    public TrainingData(File file) throws IOException {
        this(file, null);
    }

    /**
     * Creates the training data: allocates the (pre-sized) containers and then loads
     * everything from disk through {@code setupEnv(file)}.
     *
     * @param file    directory that contains the input files read by {@code setupEnv}
     * @param pattern stored as {@code independentPattern}; may be {@code null}
     * @throws IOException if reading the input files fails
     */
    public TrainingData(File file, Pattern pattern) throws IOException {
        this.independentPattern = pattern;
        this.compoundClasses = new TIntObjectHashMap<>(4000);
        // diamond instead of the raw ArrayList so the assignment to List<LabeledCompound> is type-checked
        this.compounds = new ArrayList<>(1200000);
        this.blacklist = new HashSet<>(12000);
        this.name2class = new HashMap<>(4000);
        setupEnv(file);
    }

    /**
     * Computes the normalised formula features of a compound and stores them in
     * {@code labeledCompound.formulaFeaturesF}.
     * <p>
     * Every feature is shifted by {@code formulaNorm} and divided by {@code formulaScale}
     * at the same position; the arithmetic is done in double precision and narrowed to float.
     *
     * @param labeledCompound compound whose {@code formulaFeatures} are read and whose
     *                        {@code formulaFeaturesF} is replaced
     */
    public void normalizeVector(LabeledCompound labeledCompound) {
        final int featureCount = labeledCompound.formulaFeatures.length;
        final float[] scaled = new float[featureCount];
        int idx = 0;
        while (idx < featureCount) {
            final double shifted = labeledCompound.formulaFeatures[idx] - this.formulaNorm[idx];
            scaled[idx] = (float) (shifted / this.formulaScale[idx]);
            idx++;
        }
        labeledCompound.formulaFeaturesF = scaled;
    }

    /**
     * Builds one batch from all cross-validation and independent instances and tops up
     * under-represented labels with freshly sampled training examples.
     * <p>
     * For every label that occurs fewer than 20 times among the evaluation instances,
     * {@code 20 - count} examples are drawn from the corresponding compound class and added
     * with a fingerprint sampled via {@code fingerprintSampler.sample(fingerprint, false)}.
     * Labels of the added compounds are counted as well, so later labels see the updated counts.
     *
     * @return the batch generated from the combined instance list
     */
    public TrainingBatch fillUpWithTrainData() {
        // Target count per label. The skip threshold and the draw count use the same value so
        // that a non-positive number of examples is never requested.
        final int minExamplesPerLabel = 20;
        final Random random = new Random();
        final List<EvaluationInstance> instances = new ArrayList<>();
        instances.addAll(this.crossvalidation);
        instances.addAll(this.independent);

        final TIntIntHashMap labelCounts = new TIntIntHashMap();
        for (EvaluationInstance instance : instances) {
            for (short labelIndex : instance.compound.label.toIndizesArray()) {
                labelCounts.adjustOrPutValue(labelIndex, 1, 1);
            }
        }

        // Iterate over a snapshot of the keys: the map is updated inside the loop, which must
        // not happen while one of its own forEach* traversals is running.
        for (int labelIndex : labelCounts.keys()) {
            final int missing = minExamplesPerLabel - labelCounts.get(labelIndex);
            if (missing <= 0) {
                continue;
            }
            for (LabeledCompound drawn : this.compoundClasses.get(labelIndex).drawExamples(missing, random)) {
                instances.add(new EvaluationInstance("", this.fingerprintSampler.sample(drawn.fingerprint, false), drawn));
                for (short drawnLabel : drawn.label.toIndizesArray()) {
                    labelCounts.adjustOrPutValue(drawnLabel, 1, 1);
                }
            }
        }
        return generateBatch(instances);
    }

    /**
     * Appends a normalised copy of a platt (probability) vector to the given buffer.
     * <ul>
     *   <li>Without {@code VECNORM_SCALING}: if {@code PLATT_CENTERING} or {@code SCALE_BY_STD}
     *       is set, every value is shifted by {@code plattNorm} and divided by {@code plattScale};
     *       otherwise the values are written unchanged.</li>
     *   <li>With {@code VECNORM_SCALING}: the vector is optionally centred by {@code plattNorm}
     *       and then divided by its Euclidean norm; {@code nplatts} values are written.</li>
     * </ul>
     * The input array is never modified, so the same vector can safely be normalised more than once.
     *
     * @param floatBuffer destination buffer; values are written at its current position
     * @param dArr        platt values of one instance
     */
    protected void addNormalizedPlatts(FloatBuffer floatBuffer, double[] dArr) {
        if (!VECNORM_SCALING) {
            if (PLATT_CENTERING || SCALE_BY_STD) {
                for (int i = 0; i < dArr.length; i++) {
                    floatBuffer.put((float) ((dArr[i] - this.plattNorm[i]) / this.plattScale[i]));
                }
            } else {
                for (double value : dArr) {
                    floatBuffer.put((float) value);
                }
            }
            return;
        }
        // Work on a copy: centring the caller's array in place would centre it a second time
        // when the same vector is passed in again.
        final double[] centered = dArr.clone();
        double squaredSum = 0.0d;
        for (int i = 0; i < centered.length; i++) {
            if (PLATT_CENTERING) {
                centered[i] -= this.plattNorm[i];
            }
            squaredSum += centered[i] * centered[i];
        }
        double norm = Math.sqrt(squaredSum);
        if (norm == 0.0d) {
            // an all-zero vector stays all-zero instead of turning into NaN (0/0)
            norm = 1.0d;
        }
        for (int i = 0; i < this.nplatts; i++) {
            floatBuffer.put((float) (centered[i] / norm));
        }
    }

    /**
     * Packs the given instances into a training batch without any resampling.
     * <p>
     * One row per instance is written to each of three buffers: the normalised platt values
     * of the instance's own fingerprint, the compound's {@code formulaFeaturesF} and its label
     * vector. The buffers become tensors of shape {@code [rows, nplatts]},
     * {@code [rows, nformulas]} and {@code [rows, nlabels]}.
     *
     * @param list instances to pack, in row order
     * @return the batch holding the three tensors
     */
    public TrainingBatch generateBatch(List<EvaluationInstance> list) {
        final FloatBuffer plattValues = FloatBuffer.allocate(list.size() * this.nplatts);
        final FloatBuffer formulaValues = FloatBuffer.allocate(list.size() * this.nformulas);
        final FloatBuffer labelValues = FloatBuffer.allocate(list.size() * this.nlabels);

        final Iterator<EvaluationInstance> cursor = list.iterator();
        while (cursor.hasNext()) {
            final EvaluationInstance instance = cursor.next();
            final double[] probabilities = instance.fingerprint.toProbabilityArray();
            addNormalizedPlatts(plattValues, probabilities);
            formulaValues.put(instance.compound.formulaFeaturesF);
            labelValues.put(getLabelVector(instance.compound));
        }

        plattValues.rewind();
        formulaValues.rewind();
        labelValues.rewind();

        return new TrainingBatch(
                Tensor.create(new long[]{list.size(), this.nplatts}, plattValues),
                Tensor.create(new long[]{list.size(), this.nformulas}, formulaValues),
                Tensor.create(new long[]{list.size(), this.nlabels}, labelValues));
    }

    /**
     * Builds a batch in which the platt vector of every instance is re-sampled with a
     * strategy chosen per instance by the given function.
     *
     * @param list                     instances to pack, in row order
     * @param samplingStrategyFunction called once per instance with the instance and its
     *                                 zero-based position in {@code list}
     * @return the batch with tensors of shape {@code [rows, nplatts]}, {@code [rows, nformulas]}
     *         and {@code [rows, nlabels]}
     * @throws RuntimeException if the function returns a strategy that is not handled here
     */
    public TrainingBatch resample(List<EvaluationInstance> list, SamplingStrategyFunction samplingStrategyFunction) {
        FloatBuffer platts = FloatBuffer.allocate(list.size() * this.nplatts);
        FloatBuffer formulas = FloatBuffer.allocate(list.size() * this.nformulas);
        FloatBuffer labels = FloatBuffer.allocate(list.size() * this.nlabels);
        int index = 0;
        for (EvaluationInstance evaluationInstance : list) {
            // advance by exactly one per instance so the callback sees positions 0, 1, 2, ...
            final SamplingStrategy strategy = samplingStrategyFunction.sample(evaluationInstance, index++);
            switch (strategy) {
                case PERFECT:
                    addNormalizedPlatts(platts, evaluationInstance.compound.fingerprint.toProbabilityArray());
                    break;
                case INDEPENDENT:
                    addNormalizedPlatts(platts, this.fingerprintSampler.sampleIndependently(evaluationInstance.compound.fingerprint, false).toProbabilityArray());
                    break;
                case INDEPENDENT_DISTURBED:
                    addNormalizedPlatts(platts, this.fingerprintSampler.sampleIndependently(evaluationInstance.compound.fingerprint, true).toProbabilityArray());
                    break;
                case TEMPLATE:
                    addNormalizedPlatts(platts, this.fingerprintSampler.sample(evaluationInstance.compound.fingerprint, false).toProbabilityArray());
                    break;
                case DISTURBED_TEMPLATE:
                    addNormalizedPlatts(platts, this.fingerprintSampler.sample(evaluationInstance.compound.fingerprint, true).toProbabilityArray());
                    break;
                case CONDITIONAL:
                    addNormalizedPlatts(platts, this.fingerprintSampler.sampleFromCovariance(evaluationInstance.compound.fingerprint).toProbabilityArray());
                    break;
                default:
                    // skipping the row silently would shift all following platt rows against
                    // their formula and label rows
                    throw new RuntimeException("Unknown strategy: " + String.valueOf(strategy));
            }
            formulas.put(evaluationInstance.compound.formulaFeaturesF);
            labels.put(getLabelVector(evaluationInstance.compound));
        }
        platts.rewind();
        formulas.rewind();
        labels.rewind();
        return new TrainingBatch(Tensor.create(new long[]{list.size(), this.nplatts}, platts), Tensor.create(new long[]{list.size(), this.nformulas}, formulas), Tensor.create(new long[]{list.size(), this.nlabels}, labels));
    }

    /**
     * Re-samples the given instances using the strategy mix for iteration 1000
     * (delegates to {@code resample(list, 1000)}).
     *
     * @param list instances to pack, in row order
     * @return the resampled batch
     */
    public TrainingBatch resample(List<EvaluationInstance> list) {
        return resample(list, 1000);
    }

    /**
     * Same as {@code resample(list, samplingStrategyFunction)}, but the fingerprint sampling of
     * the individual instances is executed on a temporary pool of 40 threads.
     * <p>
     * Strategies are chosen on the calling thread, in list order; only
     * {@code sampleFingerprintVector} runs on the pool. The results are collected in list order,
     * so the row order of the batch equals the order of {@code list}.
     *
     * @param list                     instances to pack, in row order
     * @param samplingStrategyFunction called once per instance with the instance and its
     *                                 zero-based position in {@code list}
     * @return the batch with tensors of shape {@code [rows, nplatts]}, {@code [rows, nformulas]}
     *         and {@code [rows, nlabels]}
     * @throws RuntimeException if a sampling task fails or the calling thread is interrupted
     *                          while waiting (the interrupt flag is restored in that case)
     */
    public TrainingBatch resampleMultithreaded(List<EvaluationInstance> list, SamplingStrategyFunction samplingStrategyFunction) {
        final ExecutorService pool = Executors.newFixedThreadPool(40);
        try {
            final List<Future<double[]>> pending = new ArrayList<>(list.size());
            int index = 0;
            for (final EvaluationInstance evaluationInstance : list) {
                final SamplingStrategy strategy = samplingStrategyFunction.sample(evaluationInstance, index++);
                pending.add(pool.submit(new Callable<double[]>() {
                    @Override
                    public double[] call() {
                        return TrainingData.this.sampleFingerprintVector(evaluationInstance.compound, strategy);
                    }
                }));
            }
            FloatBuffer platts = FloatBuffer.allocate(list.size() * this.nplatts);
            FloatBuffer formulas = FloatBuffer.allocate(list.size() * this.nformulas);
            FloatBuffer labels = FloatBuffer.allocate(list.size() * this.nlabels);
            for (int row = 0; row < list.size(); row++) {
                final double[] sampled;
                try {
                    sampled = pending.get(row).get();
                } catch (InterruptedException e) {
                    // keep the interrupt visible to callers further up the stack
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                } catch (ExecutionException e) {
                    throw new RuntimeException(e);
                }
                addNormalizedPlatts(platts, sampled);
                formulas.put(list.get(row).compound.formulaFeaturesF);
                labels.put(getLabelVector(list.get(row).compound));
            }
            platts.rewind();
            formulas.rewind();
            labels.rewind();
            return new TrainingBatch(Tensor.create(new long[]{list.size(), this.nplatts}, platts), Tensor.create(new long[]{list.size(), this.nformulas}, formulas), Tensor.create(new long[]{list.size(), this.nlabels}, labels));
        } finally {
            // single shutdown point for both the normal and the exceptional path
            pool.shutdown();
        }
    }

    /**
     * Re-samples the given instances with the shuffled strategy mix that
     * {@code getSamplingStrategies(i)} returns for the given counter.
     *
     * @param list instances to pack, in row order
     * @param i    counter that selects the strategy mix (values below 200 select a different mix)
     * @return the resampled batch
     */
    public TrainingBatch resample(List<EvaluationInstance> list, int i) {
        return resample(list, getSamplingStrategies(i));
    }

    /**
     * Builds a batch in which the platt vector of every instance is re-sampled; the strategy of
     * the instance at position {@code k} is {@code list2.get(k % list2.size())}, i.e. the
     * strategy list is cycled through.
     *
     * @param list  instances to pack, in row order
     * @param list2 strategies to cycle through; must not be empty when {@code list} is not empty
     * @return the batch with tensors of shape {@code [rows, nplatts]}, {@code [rows, nformulas]}
     *         and {@code [rows, nlabels]}
     * @throws RuntimeException if {@code list2} contains a strategy that is not handled here
     */
    public TrainingBatch resample(List<EvaluationInstance> list, List<SamplingStrategy> list2) {
        FloatBuffer platts = FloatBuffer.allocate(list.size() * this.nplatts);
        FloatBuffer formulas = FloatBuffer.allocate(list.size() * this.nformulas);
        FloatBuffer labels = FloatBuffer.allocate(list.size() * this.nlabels);
        int index = 0;
        for (EvaluationInstance evaluationInstance : list) {
            final SamplingStrategy strategy = list2.get(index % list2.size());
            switch (strategy) {
                case PERFECT:
                    addNormalizedPlatts(platts, evaluationInstance.compound.fingerprint.toProbabilityArray());
                    break;
                case INDEPENDENT:
                    addNormalizedPlatts(platts, this.fingerprintSampler.sampleIndependently(evaluationInstance.compound.fingerprint, false).toProbabilityArray());
                    break;
                case INDEPENDENT_DISTURBED:
                    addNormalizedPlatts(platts, this.fingerprintSampler.sampleIndependently(evaluationInstance.compound.fingerprint, true).toProbabilityArray());
                    break;
                case TEMPLATE:
                    addNormalizedPlatts(platts, this.fingerprintSampler.sample(evaluationInstance.compound.fingerprint, false).toProbabilityArray());
                    break;
                case DISTURBED_TEMPLATE:
                    addNormalizedPlatts(platts, this.fingerprintSampler.sample(evaluationInstance.compound.fingerprint, true).toProbabilityArray());
                    break;
                case CONDITIONAL:
                    addNormalizedPlatts(platts, this.fingerprintSampler.sampleFromCovariance(evaluationInstance.compound.fingerprint).toProbabilityArray());
                    break;
                default:
                    // skipping the row silently would shift all following platt rows against
                    // their formula and label rows
                    throw new RuntimeException("Unknown strategy: " + String.valueOf(strategy));
            }
            formulas.put(evaluationInstance.compound.formulaFeaturesF);
            labels.put(getLabelVector(evaluationInstance.compound));
            index++;
        }
        platts.rewind();
        formulas.rewind();
        labels.rewind();
        return new TrainingBatch(Tensor.create(new long[]{list.size(), this.nplatts}, platts), Tensor.create(new long[]{list.size(), this.nformulas}, formulas), Tensor.create(new long[]{list.size(), this.nlabels}, labels));
    }

    /**
     * Draws a balanced sample of compounds, samples a platt vector for each of them on the given
     * executor and packs the result, in random row order, into a training batch.
     * <p>
     * Compounds come from {@code balancedSample(i, 5, 500, ...)}; the strategy of the k-th drawn
     * compound is the k-th entry (cyclically) of {@code getSamplingStrategies(i)}. If a
     * {@code bufferedTrainData} is given, every row is additionally appended to the buffer
     * returned by {@code getBuffer(rowNumber)} as long as that buffer is not full.
     *
     * @param i                 counter forwarded to {@code getSamplingStrategies} and {@code balancedSample}
     * @param bufferedTrainData optional sink that receives a copy of each row; may be {@code null}
     * @param executorService   executor on which the fingerprint sampling tasks are run
     * @return the batch with tensors of shape {@code [rows, nplatts]}, {@code [rows, nformulas]}
     *         and {@code [rows, nlabels]}
     * @throws RuntimeException if a sampling task fails or the calling thread is interrupted
     *                          while waiting (the interrupt flag is restored in that case)
     */
    public TrainingBatch generateBatch(int i, BufferedTrainData bufferedTrainData, final ExecutorService executorService) {
        final List<LabeledCompound> sampledCompounds = new ArrayList<>();
        final SamplingStrategy[] strategies = getSamplingStrategies(i).toArray(new SamplingStrategy[0]);
        final List<Future<double[]>> pending = new ArrayList<>();
        balancedSample(i, 5, 500, new Function<LabeledCompound, LabeledCompound>() {
            private int counter = 0;

            @Override
            public LabeledCompound apply(final LabeledCompound labeledCompound) {
                final SamplingStrategy strategy = strategies[this.counter % strategies.length];
                this.counter++;
                sampledCompounds.add(labeledCompound);
                pending.add(executorService.submit(new Callable<double[]>() {
                    @Override
                    public double[] call() throws Exception {
                        return TrainingData.this.sampleFingerprintVector(labeledCompound, strategy);
                    }
                }));
                return labeledCompound;
            }
        });
        FloatBuffer platts = FloatBuffer.allocate(sampledCompounds.size() * this.nplatts);
        FloatBuffer formulas = FloatBuffer.allocate(sampledCompounds.size() * this.nformulas);
        FloatBuffer labels = FloatBuffer.allocate(sampledCompounds.size() * this.nlabels);
        // random permutation of the row order
        TIntArrayList order = new TIntArrayList(sampledCompounds.size());
        for (int k = 0; k < sampledCompounds.size(); k++) {
            order.add(k);
        }
        order.shuffle(new Random());
        int bufferIndex = 0;
        for (int position : order.toArray()) {
            final LabeledCompound labeledCompound = sampledCompounds.get(position);
            final double[] sampled;
            try {
                sampled = pending.get(position).get();
            } catch (InterruptedException e) {
                // keep the interrupt visible to callers further up the stack
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (ExecutionException e) {
                throw new RuntimeException(e);
            }
            // A defensive copy is normalised here because the same vector is normalised a second
            // time for the buffer below, and normalisation may centre its argument in place.
            addNormalizedPlatts(platts, sampled.clone());
            formulas.put(labeledCompound.formulaFeaturesF);
            final float[] labelVector = getLabelVector(labeledCompound);
            labels.put(labelVector);
            if (bufferedTrainData != null) {
                BufferedTrainData.Buffer buffer = bufferedTrainData.getBuffer(bufferIndex++);
                synchronized (buffer) {
                    if (buffer.filled < buffer.size) {
                        addNormalizedPlatts(buffer.p, sampled);
                        buffer.f.put(labeledCompound.formulaFeaturesF);
                        buffer.l.put(labelVector);
                        buffer.fill();
                    }
                }
            }
        }
        platts.rewind();
        formulas.rewind();
        labels.rewind();
        return new TrainingBatch(Tensor.create(new long[]{sampledCompounds.size(), this.nplatts}, platts), Tensor.create(new long[]{sampledCompounds.size(), this.nformulas}, formulas), Tensor.create(new long[]{sampledCompounds.size(), this.nlabels}, labels));
    }

    /**
     * Returns a randomly shuffled list of sampling strategies with a fixed composition.
     * <ul>
     *   <li>{@code i < 200} (100 entries): 33 INDEPENDENT, 33 CONDITIONAL, 24 TEMPLATE,
     *       10 PERFECT.</li>
     *   <li>otherwise (200 entries): 20 INDEPENDENT, 30 INDEPENDENT_DISTURBED, 16 CONDITIONAL,
     *       94 TEMPLATE, 39 DISTURBED_TEMPLATE, 1 PERFECT.</li>
     * </ul>
     *
     * @param i counter that selects which of the two mixes is used
     * @return a new, mutable, shuffled list of strategies
     */
    private List<SamplingStrategy> getSamplingStrategies(int i) {
        final List<SamplingStrategy> mix;
        if (i >= 200) {
            mix = new ArrayList<>(200);
            mix.addAll(Collections.nCopies(20, SamplingStrategy.INDEPENDENT));
            mix.addAll(Collections.nCopies(30, SamplingStrategy.INDEPENDENT_DISTURBED));
            mix.addAll(Collections.nCopies(16, SamplingStrategy.CONDITIONAL));
            mix.addAll(Collections.nCopies(94, SamplingStrategy.TEMPLATE));
            mix.addAll(Collections.nCopies(39, SamplingStrategy.DISTURBED_TEMPLATE));
            mix.addAll(Collections.nCopies(1, SamplingStrategy.PERFECT));
        } else {
            mix = new ArrayList<>(100);
            mix.addAll(Collections.nCopies(33, SamplingStrategy.INDEPENDENT));
            mix.addAll(Collections.nCopies(33, SamplingStrategy.CONDITIONAL));
            mix.addAll(Collections.nCopies(24, SamplingStrategy.TEMPLATE));
            mix.addAll(Collections.nCopies(10, SamplingStrategy.PERFECT));
        }
        Collections.shuffle(mix);
        return mix;
    }

    /**
     * Draws a class-balanced sample of compounds and hands every drawn compound to
     * {@code function}.
     * <p>
     * For each compound class up to {@code i2} compounds are drawn: small classes (fewer than
     * {@code 10 * i2} compounds) are shuffled and scanned, large classes are probed with random
     * indices, with at most {@code 10 * i2 - 1} probes. Afterwards a window of {@code i3}
     * consecutive compounds, starting at offset {@code i * i3} (wrapping around), is added.
     * A compound is skipped when its {@code inchiKey} is blacklisted or was already drawn.
     *
     * @param i        window number for the trailing sequential sample
     * @param i2       number of compounds to draw per class
     * @param i3       size of the sequential window
     * @param function receives every drawn compound; its return value is ignored
     */
    private void balancedSample(int i, int i2, int i3, Function<LabeledCompound, LabeledCompound> function) {
        final Random random = new Random();
        final TIntHashSet probedIndices = new TIntHashSet();
        final HashSet<String> usedKeys = new HashSet<>();
        for (CompoundClass compoundClass : this.compoundClasses.valueCollection()) {
            if (compoundClass.compounds.size() < 10 * i2) {
                final List<LabeledCompound> shuffled = new ArrayList<>(compoundClass.compounds);
                Collections.shuffle(shuffled, random);
                int taken = 0;
                for (LabeledCompound candidate : shuffled) {
                    if (usedKeys.contains(candidate.inchiKey) || this.blacklist.contains(candidate.inchiKey)) {
                        continue;
                    }
                    usedKeys.add(candidate.inchiKey);
                    function.apply(candidate);
                    if (++taken >= i2) {
                        break;
                    }
                }
            } else {
                probedIndices.clear();
                final int maxAttempts = i2 * 10;
                int attempts = 0;
                int probed = 0;
                // Stop once the attempt budget is used up; without that condition the loop
                // would keep spinning without ever drawing again.
                while (probed < i2 && ++attempts < maxAttempts) {
                    final int index = random.nextInt(compoundClass.compounds.size());
                    if (probedIndices.contains(index)) {
                        continue;
                    }
                    probedIndices.add(index);
                    probed++;
                    final LabeledCompound candidate = compoundClass.compounds.get(index);
                    if (!usedKeys.contains(candidate.inchiKey) && !this.blacklist.contains(candidate.inchiKey)) {
                        usedKeys.add(candidate.inchiKey);
                        function.apply(candidate);
                    }
                }
            }
        }
        // long arithmetic + floorMod: i * i3 may exceed the int range, which would otherwise
        // produce a negative list index
        final long windowStart = (long) i * i3;
        for (int k = 0; k < i3; k++) {
            final int index = (int) Math.floorMod(windowStart + k, (long) this.compounds.size());
            final LabeledCompound candidate = this.compounds.get(index);
            if (!usedKeys.contains(candidate.inchiKey) && !this.blacklist.contains(candidate.inchiKey)) {
                function.apply(candidate);
                usedKeys.add(candidate.inchiKey);
            }
        }
    }

    /**
     * Returns the compound's own fingerprint as a probability array, without any sampling
     * (the fingerprint is converted with {@code asProbabilistic()} first).
     *
     * @param labeledCompound compound whose fingerprint is converted
     * @return the probability array of the converted fingerprint
     */
    private double[] sampleFingerprintVectorPerfectly(LabeledCompound labeledCompound) {
        return labeledCompound.fingerprint.asProbabilistic().toProbabilityArray();
    }

    /**
     * Produces a probability fingerprint for the compound according to the given strategy.
     * <ul>
     *   <li>PERFECT: the compound's own fingerprint, converted with {@code asProbabilistic()}</li>
     *   <li>INDEPENDENT / INDEPENDENT_DISTURBED: {@code fingerprintSampler.sampleIndependently}
     *       with the second argument {@code false} / {@code true}</li>
     *   <li>TEMPLATE / DISTURBED_TEMPLATE: {@code fingerprintSampler.sample} with the second
     *       argument {@code false} / {@code true}</li>
     *   <li>CONDITIONAL: {@code fingerprintSampler.sampleFromCovariance}</li>
     * </ul>
     *
     * @param labeledCompound  compound whose fingerprint is the sampling input
     * @param samplingStrategy strategy to apply; must not be {@code null}
     * @return the sampled (or converted) probability fingerprint
     * @throws RuntimeException if the strategy is not one of the constants listed above
     */
    public ProbabilityFingerprint sampleFingerprint(LabeledCompound labeledCompound, SamplingStrategy samplingStrategy) {
        // switch on the enum itself; the decompiled int switch used a boolean constant as a case label
        switch (samplingStrategy) {
            case PERFECT:
                return labeledCompound.fingerprint.asProbabilistic();
            case INDEPENDENT:
                return this.fingerprintSampler.sampleIndependently(labeledCompound.fingerprint, false);
            case INDEPENDENT_DISTURBED:
                return this.fingerprintSampler.sampleIndependently(labeledCompound.fingerprint, true);
            case TEMPLATE:
                return this.fingerprintSampler.sample(labeledCompound.fingerprint, false);
            case DISTURBED_TEMPLATE:
                return this.fingerprintSampler.sample(labeledCompound.fingerprint, true);
            case CONDITIONAL:
                return this.fingerprintSampler.sampleFromCovariance(labeledCompound.fingerprint);
            default:
                throw new RuntimeException("Unknown strategy: " + String.valueOf(samplingStrategy));
        }
    }

    /**
     * Samples a fingerprint with {@code sampleFingerprint} and returns it as a probability array.
     *
     * @param labeledCompound  compound whose fingerprint is the sampling input
     * @param samplingStrategy strategy to apply
     * @return the probability array of the sampled fingerprint
     */
    public double[] sampleFingerprintVector(LabeledCompound labeledCompound, SamplingStrategy samplingStrategy) {
        return sampleFingerprint(labeledCompound, samplingStrategy).toProbabilityArray();
    }

    /**
     * Encodes the label fingerprint of a compound as a float vector.
     * <p>
     * The vector has one entry per property of the label's fingerprint version. Every entry is
     * {@code -1}, except for the properties present in the label, which are set to {@code 1} at
     * their relative index.
     *
     * @param labeledCompound compound whose {@code label} is encoded
     * @return a new array containing only the values {@code -1f} and {@code 1f}
     */
    public static float[] getLabelVector(LabeledCompound labeledCompound) {
        final MaskedFingerprintVersion version = labeledCompound.label.getFingerprintVersion();
        final float[] encoded = new float[version.size()];
        Arrays.fill(encoded, -1.0f);
        for (Iterator<?> present = labeledCompound.label.presentFingerprints().iterator(); present.hasNext(); ) {
            final FPIter property = (FPIter) present.next();
            final int slot = version.getRelativeIndexOf(property.getIndex());
            encoded[slot] = 1.0f;
        }
        return encoded;
    }

    /**
     * Draws a class-balanced sample of compounds and returns it in random order.
     * <p>
     * For each compound class up to {@code i2} compounds are drawn: small classes (fewer than
     * {@code 10 * i2} compounds) are shuffled and scanned, large classes are probed with random
     * indices, with at most {@code 10 * i2 - 1} probes. Afterwards a window of {@code i3}
     * consecutive compounds, starting at offset {@code i * i3} (wrapping around), is added.
     * A compound is skipped when its {@code inchiKey} is blacklisted or was already drawn.
     *
     * @param i  window number for the trailing sequential sample
     * @param i2 number of compounds to draw per class
     * @param i3 size of the sequential window
     * @return a new, shuffled list of the drawn compounds
     */
    public List<LabeledCompound> balancedSample(int i, int i2, int i3) {
        final Random random = new Random();
        final List<LabeledCompound> drawn = new ArrayList<>(i2 * this.compoundClasses.size());
        final TIntHashSet acceptedIndices = new TIntHashSet();
        final HashSet<String> usedKeys = new HashSet<>();
        for (CompoundClass compoundClass : this.compoundClasses.valueCollection()) {
            if (compoundClass.compounds.size() < 10 * i2) {
                final List<LabeledCompound> shuffled = new ArrayList<>(compoundClass.compounds);
                Collections.shuffle(shuffled, random);
                int taken = 0;
                for (LabeledCompound candidate : shuffled) {
                    if (usedKeys.contains(candidate.inchiKey) || this.blacklist.contains(candidate.inchiKey)) {
                        continue;
                    }
                    usedKeys.add(candidate.inchiKey);
                    drawn.add(candidate);
                    if (++taken >= i2) {
                        break;
                    }
                }
            } else {
                acceptedIndices.clear();
                final int maxAttempts = i2 * 10;
                int attempts = 0;
                // Stop once the attempt budget is used up; without that condition the loop
                // would keep spinning without ever drawing again when too few compounds of
                // the class are eligible.
                while (acceptedIndices.size() < i2 && ++attempts < maxAttempts) {
                    final int index = random.nextInt(compoundClass.compounds.size());
                    final LabeledCompound candidate = compoundClass.compounds.get(index);
                    if (!acceptedIndices.contains(index) && !usedKeys.contains(candidate.inchiKey) && !this.blacklist.contains(candidate.inchiKey)) {
                        acceptedIndices.add(index);
                        usedKeys.add(candidate.inchiKey);
                        drawn.add(candidate);
                    }
                }
            }
        }
        // long arithmetic + floorMod: i * i3 may exceed the int range, which would otherwise
        // produce a negative list index
        final long windowStart = (long) i * i3;
        for (int k = 0; k < i3; k++) {
            final int index = (int) Math.floorMod(windowStart + k, (long) this.compounds.size());
            final LabeledCompound candidate = this.compounds.get(index);
            if (!usedKeys.contains(candidate.inchiKey) && !this.blacklist.contains(candidate.inchiKey)) {
                drawn.add(candidate);
                usedKeys.add(candidate.inchiKey);
            }
        }
        Collections.shuffle(drawn, random);
        return drawn;
    }

    public void setupEnv(File file) throws IOException {
        ArrayFingerprint arrayFingerprint;
        MaskedFingerprintVersion.Builder buildMaskFor = MaskedFingerprintVersion.buildMaskFor(VERSION);
        buildMaskFor.disableAll();
        Iterator<String> it = Files.readAllLines(new File(file, "fingerprint_indizes.txt").toPath(), Charset.forName("UTF-8")).iterator();
        while (it.hasNext()) {
            buildMaskFor.enable(Integer.parseInt(it.next()));
        }
        this.fingerprintVersion = buildMaskFor.toMask();
        HashSet hashSet = new HashSet();
        HashSet hashSet2 = new HashSet();
        int i = 0;
        Iterator<String> it2 = Files.readAllLines(new File(file, "klasses_with_indizes.csv").toPath(), Charset.forName("UTF-8")).iterator();
        while (it2.hasNext()) {
            String str = it2.next().split("\t")[0];
            String normalizeName = normalizeName(str);
            hashSet.add(str);
            hashSet.add(normalizeName);
            if (!hashSet2.add(normalizeName)) {
                System.out.println("Double class with name: " + normalizeName + " and " + str);
            }
            i++;
        }
        System.out.println("Number of classes: " + hashSet2.size() + " ( lines:  " + i + ")");
        this.classyFireFingerprintVersion = ClassyFireFingerprintVersion.loadClassyfire(new File(file, "chemont.csv.gz"));
        double[][] readFromFile = new KernelToNumpyConverter().readFromFile(new File("formula_normalized.txt"));
        if (SCALE_BY_MAX) {
            this.formulaNorm = readFromFile[0];
            this.formulaScale = readFromFile[1];
        }
        MaskedFingerprintVersion.Builder buildMaskFor2 = MaskedFingerprintVersion.buildMaskFor(this.classyFireFingerprintVersion);
        buildMaskFor2.disableAll();
        int size = this.classyFireFingerprintVersion.size();
        for (int i2 = 0; i2 < size; i2++) {
            ClassyfireProperty molecularProperty = this.classyFireFingerprintVersion.getMolecularProperty(i2);
            if (hashSet.contains(molecularProperty.getName())) {
                this.compoundClasses.put(i2, new CompoundClass((short) i2, molecularProperty));
                buildMaskFor2.enable(i2);
                hashSet2.remove(this.classyFireFingerprintVersion.getMolecularProperty(i2).getName());
            }
        }
        System.out.println("MISSING:\n" + hashSet2);
        this.classyFireMask = buildMaskFor2.toMask();
        System.out.println("Number of Labels: " + this.classyFireMask.size());
        for (CompoundClass compoundClass : this.compoundClasses.valueCollection()) {
            this.name2class.put(compoundClass.ontology.getName(), compoundClass);
            this.name2class.put(unnormalize(compoundClass.ontology.getName()), compoundClass);
        }
        HashMap hashMap = new HashMap();
        HashMap hashMap2 = new HashMap();
        HashMap hashMap3 = new HashMap();
        BufferedReader reader = KernelToNumpyConverter.getReader(new File(file, "fingerprints.csv"));
        while (true) {
            try {
                String readLine = reader.readLine();
                if (readLine == null) {
                    break;
                }
                String[] split = readLine.split("\t");
                String str2 = split[0];
                String str3 = split[1];
                MolecularFormula molecularFormula = (MolecularFormula) hashMap3.get(str3);
                if (molecularFormula == null) {
                    molecularFormula = MolecularFormula.parseOrThrow(str3);
                    hashMap3.put(str3, molecularFormula);
                }
                short[] sArr = new short[split.length - 2];
                for (int i3 = 0; i3 < sArr.length; i3++) {
                    sArr[i3] = Short.parseShort(split[i3 + 2]);
                }
                hashMap2.put(str2, molecularFormula);
                hashMap.put(str2, new ArrayFingerprint(this.fingerprintVersion, sArr));
            } finally {
            }
        }
        if (reader != null) {
            reader.close();
        }
        HashMap hashMap4 = new HashMap();
        FormulaConstraints formulaConstraints = new FormulaConstraints("CHNOPSClBrIFBSe");
        this.fingerprintSampler = new Sampler(this.fingerprintVersion);
        reader = KernelToNumpyConverter.getReader(new File(file, "compounds.csv"));
        while (true) {
            try {
                String readLine2 = reader.readLine();
                if (readLine2 == null) {
                    break;
                }
                String[] split2 = readLine2.split("\t");
                String substring = split2[0].substring(0, 14);
                MolecularFormula molecularFormula2 = (MolecularFormula) hashMap2.get(substring);
                if (molecularFormula2 != null && !formulaConstraints.isViolated(molecularFormula2, PeriodicTable.getInstance().neutralIonization()) && molecularFormula2.getMass() <= 1500.0d && (arrayFingerprint = (ArrayFingerprint) hashMap.get(substring)) != null) {
                    short[] sArr2 = new short[split2.length - 1];
                    int i4 = 0;
                    for (int i5 = 1; i5 < split2.length; i5++) {
                        CompoundClass compoundClass2 = this.name2class.get(split2[i5]);
                        if (compoundClass2 != null) {
                            int i6 = i4;
                            i4++;
                            sArr2[i6] = compoundClass2.index;
                        }
                    }
                    short[] copyOf = Arrays.copyOf(sArr2, i4);
                    Arrays.sort(copyOf);
                    LabeledCompound labeledCompound = new LabeledCompound(substring, molecularFormula2, arrayFingerprint, this.classyFireMask.mask(new ArrayFingerprint(this.classyFireMask.getMaskedFingerprintVersion(), copyOf)), getFormulaFeatures(molecularFormula2), (ArrayFingerprint) hashMap4.get(substring.substring(0, 14)));
                    this.compounds.add(labeledCompound);
                    for (int i7 = 0; i7 < i4; i7++) {
                        ((CompoundClass) this.compoundClasses.get(copyOf[i7])).compounds.add(labeledCompound);
                    }
                }
            } finally {
            }
        }
        if (reader != null) {
            reader.close();
        }
        Collections.shuffle(this.compounds);
        if (SCALE_BY_MAX) {
            for (LabeledCompound labeledCompound2 : this.compounds) {
                labeledCompound2.formulaFeaturesF = new float[labeledCompound2.formulaFeatures.length];
                for (int i8 = 0; i8 < labeledCompound2.formulaFeatures.length; i8++) {
                    double[] dArr = labeledCompound2.formulaFeatures;
                    int i9 = i8;
                    dArr[i9] = dArr[i9] - this.formulaNorm[i8];
                    double[] dArr2 = labeledCompound2.formulaFeatures;
                    int i10 = i8;
                    dArr2[i10] = dArr2[i10] / this.formulaScale[i8];
                    labeledCompound2.formulaFeaturesF[i8] = (float) labeledCompound2.formulaFeatures[i8];
                }
            }
        } else {
            scaleFormulaFeatures();
        }
        if (VECNORM_SCALING) {
            for (LabeledCompound labeledCompound3 : this.compounds) {
                double d = 0.0d;
                for (int i11 = 0; i11 < labeledCompound3.formulaFeatures.length; i11++) {
                    d += labeledCompound3.formulaFeatures[i11] * labeledCompound3.formulaFeatures[i11];
                }
                double sqrt = Math.sqrt(d);
                for (int i12 = 0; i12 < labeledCompound3.formulaFeatures.length; i12++) {
                    double[] dArr3 = labeledCompound3.formulaFeatures;
                    int i13 = i12;
                    dArr3[i13] = dArr3[i13] / sqrt;
                    labeledCompound3.formulaFeaturesF[i12] = (float) labeledCompound3.formulaFeatures[i12];
                }
            }
        }
        HashMap<String, LabeledCompound> hashMap5 = new HashMap<>();
        for (LabeledCompound labeledCompound4 : this.compounds) {
            hashMap5.put(labeledCompound4.inchiKey, labeledCompound4);
        }
        if (this.independentPattern == null) {
            this.independent = null;
            this.crossvalidation = this.fingerprintSampler.readCrossvalidation(new File(file, "prediction_prediction.csv"), hashMap5);
        } else {
            this.fingerprintSampler.setExclude(this.independentPattern);
            this.crossvalidation = this.fingerprintSampler.readCrossvalidation(new File(file, "prediction_prediction.csv"), hashMap5);
            Sampler sampler = new Sampler(this.fingerprintVersion);
            sampler.setInclude(this.independentPattern);
            this.independent = sampler.readCrossvalidation(new File(file, "prediction_prediction.csv"), hashMap5);
        }
        Iterator<EvaluationInstance> it3 = this.crossvalidation.iterator();
        while (it3.hasNext()) {
            this.blacklist.add(it3.next().compound.inchiKey);
        }
        if (this.independent != null) {
            Iterator<EvaluationInstance> it4 = this.independent.iterator();
            while (it4.hasNext()) {
                this.blacklist.add(it4.next().compound.inchiKey);
            }
        }
        this.nformulas = this.formulaNorm.length;
        this.nplatts = this.fingerprintVersion.size();
        this.nlabels = this.classyFireMask.size();
        if (PLATT_CENTERING || SCALE_BY_STD) {
            this.plattNorm = new double[this.nplatts];
            this.plattScale = new double[this.nplatts];
            this.fingerprintSampler.standardize(this.plattNorm, this.plattScale);
            if (!PLATT_CENTERING) {
                Arrays.fill(this.plattNorm, 0.0d);
            }
            if (!SCALE_BY_STD) {
                Arrays.fill(this.plattScale, 1.0d);
            }
        } else {
            this.plattScale = null;
            this.plattNorm = null;
        }
        for (int i14 : this.classyFireMask.allowedIndizes()) {
            ClassyfireProperty molecularProperty2 = this.classyFireMask.getMolecularProperty(i14);
            if (this.compoundClasses.containsKey(i14)) {
                int size2 = ((CompoundClass) this.compoundClasses.get(i14)).compounds.size();
                if (size2 < 300) {
                    System.err.println("We have less than " + size2 + " training examples for " + molecularProperty2.getName());
                }
            } else {
                System.err.println("Inconsistency with " + molecularProperty2.getName() + " which is part of the fingerprint but not part of the hash map");
            }
        }
        if (new File("treeWithCovariance.tree").exists()) {
            System.out.println("Build tree with covariance");
            this.fingerprintSampler.buildCovarianceTree(new File("treeWithCovariance.tree"));
        }
    }

    /**
     * Placeholder that performs no work in its current form.
     *
     * @param arrayFingerprint not read by the current implementation
     * @return always {@code null}
     */
    private ArrayFingerprint addFingerprintsAsLabels(ArrayFingerprint arrayFingerprint) {
        return null;
    }

    /**
     * Standardizes the formula features of every compound in {@code this.compounds} in place.
     *
     * <p>For each feature position the mean over all compounds is stored in
     * {@code formulaNorm} and the population standard deviation (sum of squared deviations
     * divided by the number of compounds) is stored in {@code formulaScale}. Each compound's
     * {@code formulaFeatures} is then centered and divided by that deviation, and a
     * {@code float} copy of the result is written to {@code formulaFeaturesF}.
     *
     * <p>A feature whose value is identical for all compounds has a deviation of zero; dividing
     * by it would turn the (already zero) centered values into NaN. For such features the stored
     * scale is set to {@code 1.0}, so the centered value {@code 0.0} is kept unchanged.
     *
     * <p>The feature count is taken from the first compound, so the list must not be empty
     * ({@code compounds.get(0)} fails otherwise).
     */
    private void scaleFormulaFeatures() {
        final int featureCount = this.compounds.get(0).formulaFeatures.length;
        final int compoundCount = this.compounds.size();
        this.formulaNorm = new double[featureCount];
        this.formulaScale = new double[featureCount];

        // Pass 1: per-feature mean.
        for (LabeledCompound compound : this.compounds) {
            final double[] features = compound.formulaFeatures;
            for (int k = 0; k < features.length; k++) {
                this.formulaNorm[k] += features[k];
            }
        }
        for (int k = 0; k < featureCount; k++) {
            this.formulaNorm[k] /= compoundCount;
        }

        // Pass 2: center every feature in place and accumulate the squared deviations.
        for (LabeledCompound compound : this.compounds) {
            final double[] features = compound.formulaFeatures;
            for (int k = 0; k < features.length; k++) {
                features[k] -= this.formulaNorm[k];
                this.formulaScale[k] += features[k] * features[k];
            }
        }
        for (int k = 0; k < featureCount; k++) {
            final double deviation = Math.sqrt(this.formulaScale[k] / compoundCount);
            // A constant feature has zero deviation; use a neutral scale instead of dividing 0 by 0.
            this.formulaScale[k] = (deviation == 0.0d) ? 1.0d : deviation;
        }

        // Pass 3: scale the centered features and publish the single-precision copy.
        for (LabeledCompound compound : this.compounds) {
            final double[] features = compound.formulaFeatures;
            final float[] featuresAsFloat = new float[features.length];
            for (int k = 0; k < features.length; k++) {
                features[k] /= this.formulaScale[k];
                featuresAsFloat[k] = (float) features[k];
            }
            compound.formulaFeaturesF = featuresAsFloat;
        }
    }

    /**
     * Decodes the two HTML entities {@code &#39;} and {@code &gt;} in a name.
     *
     * @param str the name to decode
     * @return {@code str} with every {@code &#39;} replaced by an apostrophe and every
     *         {@code &gt;} replaced by a greater-than sign
     */
    private String normalizeName(String str) {
        // Both search strings are plain literals, so no regular expression is needed.
        final String apostrophesDecoded = str.replace("&#39;", "'");
        return apostrophesDecoded.replace("&gt;", ">");
    }

    /**
     * Encodes apostrophes and greater-than signs in a name as HTML entities.
     *
     * @param str the name to encode
     * @return {@code str} with every apostrophe replaced by {@code &#39;} and every
     *         greater-than sign replaced by {@code &gt;}
     */
    private String unnormalize(String str) {
        // Both search strings are plain literals, so no regular expression is needed.
        final String apostrophesEncoded = str.replace("'", "&#39;");
        return apostrophesEncoded.replace(">", "&gt;");
    }

    /**
     * Builds a fingerprint over {@code this.canopusFingerprint} that combines the indices of a
     * base fingerprint with the mapped indices of a second fingerprint.
     *
     * <p>Every index reported by {@code arrayFingerprint2.presentFingerprints()} is looked up in
     * {@code this.canopusFingerprintMapping}; lookups that yield a negative value are skipped,
     * all others are appended (narrowed to {@code short}) after the base indices.
     *
     * @param arrayFingerprint  fingerprint whose index array forms the leading part of the result
     * @param arrayFingerprint2 fingerprint whose present indices are mapped and appended
     * @return a new {@code ArrayFingerprint} holding the combined index array
     */
    public ArrayFingerprint integrateCanopusFingerprint(ArrayFingerprint arrayFingerprint, ArrayFingerprint arrayFingerprint2) {
        final short[] baseIndices = arrayFingerprint.toIndizesArray();
        final TShortArrayList merged = new TShortArrayList(baseIndices);
        for (Iterator<?> cursor = arrayFingerprint2.presentFingerprints().iterator(); cursor.hasNext();) {
            final FPIter present = (FPIter) cursor.next();
            final int mapped = this.canopusFingerprintMapping.get(present.getIndex());
            if (mapped < 0) {
                continue;
            }
            merged.add((short) mapped);
        }
        // Only the appended tail is sorted; the leading base indices keep their original order.
        merged.sort(baseIndices.length, merged.size());
        return new ArrayFingerprint(this.canopusFingerprint, merged.toArray());
    }

    /**
     * Returns the formula feature vector for a molecular formula by delegating to
     * {@code Canopus.getFormulaFeatures}.
     *
     * @param molecularFormula the formula passed through to {@code Canopus.getFormulaFeatures}
     * @return the array produced by {@code Canopus.getFormulaFeatures}
     */
    public double[] getFormulaFeatures(MolecularFormula molecularFormula) {
        return Canopus.getFormulaFeatures(molecularFormula);
    }
}
