package de.unijena.bioinf.fingerid.cli.tools.wc_cross_validation;

import gnu.trove.iterator.TIntIterator;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/unijena/bioinf/fingerid/cli/tools/wc_cross_validation/IterativeStratification.class */
/**
 * Partitions sparse fingerprints into a fixed number of disjoint batches so that every
 * fingerprint dimension ("label") is spread as evenly as possible over the batches (the package
 * name suggests these batches serve as cross-validation folds).
 *
 * <p>The procedure is greedy: it repeatedly selects the dimension with the smallest positive
 * total over the examples not yet placed, and hands every unplaced example with a positive value
 * in that dimension to the batch that currently holds the least of it. Ties are broken by the
 * smaller batch and then uniformly at random, so results differ between runs.
 *
 * <p>Configure the instance with {@link #setFingerPrints(List, int)} and
 * {@link #setNumOfBatches(int)}, then either call {@link #call()} or execute it as a
 * {@link Runnable} and read {@link #getResult()}. Instances hold mutable, unsynchronized state.
 */
public class IterativeStratification implements Runnable, Callable<TIntSet[]> {
    /** One sparse fingerprint per example, mapping dimension index to value. */
    private List<TIntIntMap> fingerPrints;
    /** Number of dimensions; fingerprint keys outside {@code [0, dimension)} are ignored. */
    private int dimension;
    /** Number of batches to produce. */
    private int numOfBatches;
    /** Result of the last successful {@link #run()}; {@code null} before that. */
    private TIntSet[] batches;

    /**
     * Runs {@link #call()} and stores its outcome for {@link #getResult()}. Any exception is
     * logged and leaves the previously stored result in place.
     */
    @Override // java.lang.Runnable
    public void run() {
        try {
            this.batches = call();
        } catch (Exception e) {
            LoggerFactory.getLogger(getClass()).error("Iterative stratification fails with an error", e);
        }
    }

    /**
     * Computes the batches.
     *
     * <p>Examples whose fingerprint map is empty are left out of every batch (a warning is logged
     * for each). Every other example ends up in exactly one batch. The returned sets contain
     * indices into the list passed to {@link #setFingerPrints(List, int)}, so skipped examples do
     * not shift the indices of the remaining ones.
     *
     * @return {@code numOfBatches} disjoint sets of example indices
     * @throws IllegalStateException if no fingerprints were set
     * @throws IllegalArgumentException if the number of batches is not positive, the dimension is
     *     negative, or there are fewer non-empty examples than batches
     */
    @Override // java.util.concurrent.Callable
    public TIntSet[] call() throws Exception {
        if (fingerPrints == null) {
            throw new IllegalStateException("Fingerprints have not been set");
        }
        if (numOfBatches <= 0) {
            throw new IllegalArgumentException("Number of batches must be positive: " + numOfBatches);
        }
        if (dimension < 0) {
            throw new IllegalArgumentException("Dimension must not be negative: " + dimension);
        }
        TIntArrayList usable = nonEmptyExamples();
        if (usable.size() < numOfBatches) {
            throw new IllegalArgumentException("Number of examples lower than number of batches! "
                    + usable.size() + " < " + numOfBatches);
        }

        TIntSet[] result = new TIntSet[numOfBatches];
        int expectedBatchSize = (usable.size() + numOfBatches - 1) / numOfBatches;
        for (int b = 0; b < numOfBatches; b++) {
            result[b] = new TIntHashSet(expectedBatchSize);
        }
        // batchLabelCounts[b][d]: sum of dimension d over the examples already placed in batch b.
        int[][] batchLabelCounts = new int[numOfBatches][dimension];
        // remaining[d]: sum of dimension d over the examples not yet placed in any batch.
        int[] remaining = new int[dimension];
        // examplesWithLabel[d]: unplaced examples with a positive value in dimension d.
        TIntSet[] examplesWithLabel = new TIntSet[dimension];
        for (int d = 0; d < dimension; d++) {
            examplesWithLabel[d] = new TIntHashSet();
        }
        indexLabels(usable, remaining, examplesWithLabel);

        Random random = new Random();
        boolean[] placed = new boolean[fingerPrints.size()];
        int numPlaced = 0;
        while (numPlaced < usable.size()) {
            int label = rarestRemainingLabel(remaining);
            if (label < 0) {
                // No dimension has a positive remaining total (e.g. the unplaced examples only
                // carry zero/negative values or keys >= dimension), so there is nothing left to
                // stratify on: distribute them by batch size alone.
                LoggerFactory.getLogger(getClass()).warn(
                        "{} example(s) have no positive remaining label value; assigning them to the smallest batches",
                        usable.size() - numPlaced);
                for (int i = 0; i < usable.size(); i++) {
                    int leftover = usable.get(i);
                    if (!placed[leftover]) {
                        place(leftover, chooseBatch(result, batchLabelCounts, -1, random),
                                result, batchLabelCounts, remaining, examplesWithLabel);
                        placed[leftover] = true;
                    }
                }
                break;
            }
            // Iterate over a snapshot: placing an example removes it from examplesWithLabel.
            for (int example : examplesWithLabel[label].toArray()) {
                int target = chooseBatch(result, batchLabelCounts, label, random);
                place(example, target, result, batchLabelCounts, remaining, examplesWithLabel);
                placed[example] = true;
                numPlaced++;
            }
        }
        return result;
    }

    /** Returns the indices of all examples whose fingerprint map has at least one entry. */
    private TIntArrayList nonEmptyExamples() {
        TIntArrayList usable = new TIntArrayList(fingerPrints.size());
        int index = 0;
        for (TIntIntMap fingerPrint : fingerPrints) {
            if (fingerPrint.isEmpty()) {
                LoggerFactory.getLogger(getClass()).warn("Trivial zero only example removed!");
            } else {
                usable.add(index);
            }
            index++;
        }
        return usable;
    }

    /**
     * Adds every example's values to {@code remaining} and records, per dimension, which
     * examples have a positive value there. Keys outside {@code [0, remaining.length)} are ignored.
     */
    private void indexLabels(TIntArrayList examples, int[] remaining, TIntSet[] examplesWithLabel) {
        for (int i = 0; i < examples.size(); i++) {
            int example = examples.get(i);
            fingerPrints.get(example).forEachEntry((d, v) -> {
                if (d >= 0 && d < remaining.length) {
                    remaining[d] += v;
                    if (v > 0) {
                        examplesWithLabel[d].add(example);
                    }
                }
                return true;
            });
        }
    }

    /**
     * Returns the dimension with the smallest positive remaining total (first one on ties), or
     * {@code -1} if no dimension has a positive remaining total.
     */
    private static int rarestRemainingLabel(int[] remaining) {
        int best = -1;
        int bestCount = Integer.MAX_VALUE;
        for (int d = 0; d < remaining.length; d++) {
            int count = remaining[d];
            if (count > 0 && count < bestCount) {
                bestCount = count;
                best = d;
            }
        }
        return best;
    }

    /**
     * Picks the batch that minimizes (count of {@code label} in the batch, batch size) in that
     * lexicographic order, choosing uniformly at random among exact ties. A negative
     * {@code label} compares batches by size only.
     */
    private static int chooseBatch(TIntSet[] batches, int[][] batchLabelCounts, int label, Random random) {
        TIntArrayList best = new TIntArrayList(batches.length);
        int bestLabelCount = Integer.MAX_VALUE;
        int bestSize = Integer.MAX_VALUE;
        for (int b = 0; b < batches.length; b++) {
            int labelCount = label < 0 ? 0 : batchLabelCounts[b][label];
            int size = batches[b].size();
            if (labelCount < bestLabelCount || (labelCount == bestLabelCount && size < bestSize)) {
                bestLabelCount = labelCount;
                bestSize = size;
                best.clear();
            }
            if (labelCount == bestLabelCount && size == bestSize) {
                best.add(b);
            }
        }
        return best.size() == 1 ? best.get(0) : best.get(random.nextInt(best.size()));
    }

    /**
     * Puts {@code example} into batch {@code target} and updates the per-batch counts, the
     * remaining totals and the per-dimension sets of unplaced examples accordingly.
     */
    private void place(int example, int target, TIntSet[] result, int[][] batchLabelCounts,
                       int[] remaining, TIntSet[] examplesWithLabel) {
        result[target].add(example);
        int[] counts = batchLabelCounts[target];
        fingerPrints.get(example).forEachEntry((d, v) -> {
            if (d >= 0 && d < remaining.length) {
                counts[d] += v;
                remaining[d] -= v;
                examplesWithLabel[d].remove(example);
            }
            return true;
        });
    }

    /**
     * Sets the examples to stratify.
     *
     * @param fingerPrints one sparse fingerprint per example (stored by reference, not copied)
     * @param dimension number of dimensions to stratify on; keys outside {@code [0, dimension)}
     *     are ignored
     */
    public void setFingerPrints(List<TIntIntMap> fingerPrints, int dimension) {
        this.fingerPrints = fingerPrints;
        this.dimension = dimension;
    }

    /** Sets the number of batches {@link #call()} produces. */
    public void setNumOfBatches(int numOfBatches) {
        this.numOfBatches = numOfBatches;
    }

    /**
     * Returns the batches computed by the last successful {@link #run()}, or {@code null} if
     * there was none. Calling {@link #call()} directly does not update this value.
     */
    public TIntSet[] getResult() {
        return this.batches;
    }
}
