package phylo.tree.treetools;

import cern.jet.math.Arithmetic;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.Vector;
import phylo.tree.model.Tree;
import phylo.tree.model.TreeNode;

/* loaded from: input_file:phylo/tree/treetools/BCNWithLikelihoods.class */
/**
 * Compares the internal nodes of two trees by a likelihood-style score.
 *
 * <p>For every pair (internal node of {@code tree}, internal node of {@code compareTree}) a score
 * is computed with {@link #computeValueWithDouble} (default) or {@link #computeValue} (when
 * {@link #chooseDouble} is {@code false}) from the number of taxa, the two leaf-set sizes and the
 * number of common leaves. The score matrix is printed to {@code System.out}. Then, for every
 * internal node of either tree, the partner with the lowest score is selected. Ties are broken
 * in favour of the higher {@link #computeBcn} value.
 */
public class BCNWithLikelihoods {
    /** Significant digits / rounding used for the final division in {@link #computeValue}. */
    private static final MathContext RESULT_CONTEXT = new MathContext(20, RoundingMode.UP);

    private final Tree tree;
    private final Tree compareTree;
    private Vector<TreeNode> internalNodesTree;
    private Vector<TreeNode> internalNodesCompareTree;
    /** Score matrix (rows: internal nodes of tree, columns: of compareTree) for the BigDecimal mode. */
    private BigDecimal[][] matrix;
    /** Score matrix (rows: internal nodes of tree, columns: of compareTree) for the double mode. */
    private double[][] matrixDouble;
    /** {@code true}: use double arithmetic; {@code false}: use BigDecimal arithmetic. */
    public boolean chooseDouble = true;
    private Vector<NodeSet> solutionTree = new Vector<>();
    private Vector<NodeSet> solutionCompareTree = new Vector<>();
    /** Last matrix cell computed in BigDecimal mode. */
    private BigDecimal overallResult = BigDecimal.TEN;
    /** Last matrix cell computed in double mode. */
    private double overallResultDouble = -1.0d;

    /**
     * Creates a comparison between two trees; call {@link #run()} to compute it.
     *
     * @param tree the reference tree (its taxon count is used as {@code n})
     * @param tree2 the tree to compare against
     */
    public BCNWithLikelihoods(Tree tree, Tree tree2) {
        this.tree = tree;
        this.compareTree = tree2;
    }

    /**
     * Collects all non-leaf nodes of a tree in depth-first iteration order.
     *
     * @param tree the tree to traverse
     * @return the internal nodes, in the order produced by the root's depth-first iterator
     */
    public static Vector<TreeNode> getInternalNodes(Tree tree) {
        TreeNode root = tree.getRoot();
        Vector<TreeNode> internalNodes = new Vector<>();
        for (TreeNode node : root.depthFirstIterator()) {
            if (!node.isLeaf()) {
                internalNodes.add(node);
            }
        }
        return internalNodes;
    }

    /**
     * Computes, in high precision, {@code C(n,a) * C(n,b) / 4^n} times the hypergeometric upper
     * tail {@code sum_{x >= minCommon} C(a,x) * C(n-a,b-x) / C(n,b)}.
     *
     * <p>For two independent, uniformly random subsets of an n-element set this is the
     * probability that they have sizes {@code a} and {@code b} and share at least
     * {@code minCommon} elements. The {@code C(n,b)} factor cancels, and the expression is
     * symmetric in {@code a} and {@code b}. The numerator is therefore computed as an exact
     * integer and divided by {@code 4^n} once. The result is rounded up to 20 significant digits,
     * so tiny probabilities keep their relative precision and large {@code n} no longer overflows.
     *
     * @param n number of taxa
     * @param sizeA leaf count of the first cluster
     * @param sizeB leaf count of the second cluster
     * @param minCommon lower bound of the summation (observed number of common leaves)
     * @return the score, rounded up to 20 significant digits
     */
    public static BigDecimal computeValue(int n, int sizeA, int sizeB, int minCommon) {
        int larger = Math.max(sizeA, sizeB);
        int smaller = Math.min(sizeA, sizeB);
        BigDecimal tailCount = BigDecimal.ZERO;
        for (int x = minCommon; x <= smaller; x++) {
            tailCount = tailCount.add(binomial(larger, x).multiply(binomial(n - larger, smaller - x)));
        }
        BigDecimal numerator = binomial(n, larger).multiply(tailCount);
        BigDecimal denominator = BigDecimal.valueOf(4L).pow(n);
        return numerator.divide(denominator, RESULT_CONTEXT);
    }

    /**
     * Double-precision counterpart of {@link #computeValue}.
     *
     * <p>NOTE: {@code Math.pow(2, n) * Math.pow(2, n)} overflows to infinity for {@code n >= 512},
     * so the result degrades (0 or NaN) for large taxon counts. Use the BigDecimal mode there.
     *
     * @param n number of taxa
     * @param sizeA leaf count of the first cluster
     * @param sizeB leaf count of the second cluster
     * @param minCommon lower bound of the summation (observed number of common leaves)
     * @return the score as a double
     */
    public static double computeValueWithDouble(int n, int sizeA, int sizeB, int minCommon) {
        double prefactor = (Arithmetic.binomial(n, sizeA) * Arithmetic.binomial(n, sizeB)) / (Math.pow(2.0d, n) * Math.pow(2.0d, n));
        double tail = 0.0d;
        // The evaluation order is kept as-is: run() compares these doubles with ==.
        if (sizeA > sizeB) {
            double norm = Arithmetic.binomial(n, sizeB);
            for (int x = minCommon; x <= sizeB; x++) {
                tail += (Arithmetic.binomial(sizeA, x) * Arithmetic.binomial(n - sizeA, sizeB - x)) / norm;
            }
        } else {
            double norm = Arithmetic.binomial(n, sizeA);
            for (int x = minCommon; x <= sizeA; x++) {
                tail += (Arithmetic.binomial(sizeB, x) * Arithmetic.binomial(n - sizeB, sizeA - x)) / norm;
            }
        }
        return prefactor * tail;
    }

    /**
     * Counts how many nodes of the first array have an {@code equalsNode} match in the second.
     *
     * @param treeNodeArr nodes to count
     * @param treeNodeArr2 nodes to search for matches
     * @return number of elements of {@code treeNodeArr} found in {@code treeNodeArr2}
     */
    public static int countCommonLeaves(TreeNode[] treeNodeArr, TreeNode[] treeNodeArr2) {
        int common = 0;
        for (TreeNode node : treeNodeArr) {
            if (containsEqualNode(treeNodeArr2, node)) {
                common++;
            }
        }
        return common;
    }

    /** Returns whether {@code node.equalsNode(c)} holds for some {@code c} in {@code candidates}. */
    private static boolean containsEqualNode(TreeNode[] candidates, TreeNode node) {
        for (TreeNode candidate : candidates) {
            if (node.equalsNode(candidate)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the score matrix, prints it, and selects best matches in both directions.
     *
     * <p>For each internal node of {@code tree}, one {@link NodeSet} with its lowest-scoring partner
     * is appended to {@link #getSolutionTree()}. Then the same is done for each internal node of
     * {@code compareTree}, appending to {@link #getSolutionCompareTree()}. Results are appended
     * to the existing vectors, which are not cleared. Progress is written to {@code System.out}.
     */
    public void run() {
        this.internalNodesTree = getInternalNodes(this.tree);
        this.internalNodesCompareTree = getInternalNodes(this.compareTree);
        fillScoreMatrix();
        printScoreMatrix();
        for (int row = 0; row < this.internalNodesTree.size(); row++) {
            NodeSet best = findBestMatch(row, true);
            if (best != null) {
                this.solutionTree.add(best);
            }
        }
        System.out.println(" ------------------------ change \"direction\" ---------------------------");
        for (int col = 0; col < this.internalNodesCompareTree.size(); col++) {
            NodeSet best = findBestMatch(col, false);
            if (best != null) {
                this.solutionCompareTree.add(best);
            }
        }
    }

    /** Allocates the matrix for the selected mode and fills every (row, column) score. */
    private void fillScoreMatrix() {
        int numTaxa = this.tree.getNumTaxa();
        int rows = this.internalNodesTree.size();
        int cols = this.internalNodesCompareTree.size();
        if (this.chooseDouble) {
            this.matrixDouble = new double[rows][cols];
        } else {
            this.matrix = new BigDecimal[rows][cols];
        }
        for (int row = 0; row < rows; row++) {
            TreeNode node = this.internalNodesTree.get(row);
            int leafCount = node.leafCount();
            TreeNode[] leaves = node.getLeaves();
            for (int col = 0; col < cols; col++) {
                TreeNode other = this.internalNodesCompareTree.get(col);
                int common = countCommonLeaves(leaves, other.getLeaves());
                if (this.chooseDouble) {
                    this.overallResultDouble = computeValueWithDouble(numTaxa, leafCount, other.leafCount(), common);
                    this.matrixDouble[row][col] = this.overallResultDouble;
                } else {
                    this.overallResult = computeValue(numTaxa, leafCount, other.leafCount(), common);
                    this.matrix[row][col] = this.overallResult;
                }
            }
        }
    }

    /** Prints one line per row: the row's leaves followed by its scores. */
    private void printScoreMatrix() {
        for (int row = 0; row < this.internalNodesTree.size(); row++) {
            System.out.print(listLeaves(this.internalNodesTree.get(row).getLeaves()) + "\t\t");
            for (int col = 0; col < this.internalNodesCompareTree.size(); col++) {
                if (this.chooseDouble) {
                    System.out.print(this.matrixDouble[row][col] + "       \t");
                } else {
                    System.out.print(this.matrix[row][col] + " ");
                }
            }
            System.out.println();
        }
    }

    /**
     * Finds the lowest-scoring partner of one internal node.
     *
     * @param fixedIndex index of the node whose partner is searched
     * @param fixedInTree {@code true} if {@code fixedIndex} is a row (node of {@code tree}) and the
     *     candidates are columns; {@code false} for the opposite direction
     * @return the best pair, or {@code null} if the other tree has no internal nodes
     */
    private NodeSet findBestMatch(int fixedIndex, boolean fixedInTree) {
        int candidateCount = fixedInTree ? this.internalNodesCompareTree.size() : this.internalNodesTree.size();
        if (candidateCount == 0) {
            return null;
        }
        NodeSet best = createNodeSet(fixedInTree ? fixedIndex : 0, fixedInTree ? 0 : fixedIndex);
        // Starts at 0 (the initial candidate) to keep the original evaluation and print order.
        for (int k = 0; k < candidateCount; k++) {
            int row = fixedInTree ? fixedIndex : k;
            int col = fixedInTree ? k : fixedIndex;
            if (isBetterMatch(row, col, best)) {
                best.setNodeTree(this.internalNodesTree.get(row));
                best.setNodeCompareTree(this.internalNodesCompareTree.get(col));
                if (this.chooseDouble) {
                    best.setScore(this.matrixDouble[row][col]);
                } else {
                    best.setScore2(this.matrix[row][col]);
                }
            }
        }
        return best;
    }

    /** Creates a {@link NodeSet} for matrix cell (row, col) using the active score type. */
    private NodeSet createNodeSet(int row, int col) {
        TreeNode node = this.internalNodesTree.get(row);
        TreeNode other = this.internalNodesCompareTree.get(col);
        return this.chooseDouble
                ? new NodeSet(node, other, this.matrixDouble[row][col])
                : new NodeSet(node, other, this.matrix[row][col]);
    }

    /** Lower score wins; on an exact tie, the pair with the higher BCN score wins. */
    private boolean isBetterMatch(int row, int col, NodeSet best) {
        if (this.chooseDouble) {
            double score = this.matrixDouble[row][col];
            if (score < best.getScore()) {
                return true;
            }
            return score == best.getScore() && hasHigherBcn(row, col, best);
        }
        int order = this.matrix[row][col].compareTo(best.getScore2());
        return order < 0 || (order == 0 && hasHigherBcn(row, col, best));
    }

    /** Returns whether cell (row, col) has a strictly higher BCN score than {@code best} (prints both). */
    private boolean hasHigherBcn(int row, int col, NodeSet best) {
        return computeBcn(this.internalNodesTree.get(row).getLeaves(), this.internalNodesCompareTree.get(col).getLeaves())
                > computeBcn(best.getNodeTree().getLeaves(), best.getNodeCompareTree().getLeaves());
    }

    /**
     * Computes {@code common / (|a| + |b| - common)}, where {@code common} comes from
     * {@link #countCommonLeaves}. This is intersection over union, assuming each array holds
     * distinct leaves. The two leaf lists and the score are printed to {@code System.out}.
     *
     * @param treeNodeArr leaves of the first cluster
     * @param treeNodeArr2 leaves of the second cluster
     * @return the BCN score (NaN if both arrays are empty)
     */
    public static double computeBcn(TreeNode[] treeNodeArr, TreeNode[] treeNodeArr2) {
        double common = countCommonLeaves(treeNodeArr, treeNodeArr2);
        double score = common / ((treeNodeArr.length + treeNodeArr2.length) - common);
        for (TreeNode node : treeNodeArr) {
            System.out.print(node + " ");
        }
        System.out.print(" mit ");
        for (TreeNode node : treeNodeArr2) {
            System.out.print(node + " ");
        }
        System.out.println("haben BCN-Score = " + score);
        return score;
    }

    /**
     * Computes {@code n!} iteratively.
     *
     * @param bigDecimal a non-negative integral value
     * @return {@code n!} ({@code 1} for 0 and 1)
     * @throws IllegalArgumentException if the argument is negative or not integral
     */
    public static BigDecimal factorial(BigDecimal bigDecimal) {
        if (bigDecimal.signum() < 0 || bigDecimal.remainder(BigDecimal.ONE).signum() != 0) {
            throw new IllegalArgumentException("factorial requires a non-negative integer: " + bigDecimal);
        }
        BigDecimal result = BigDecimal.ONE;
        for (BigDecimal k = bigDecimal; k.compareTo(BigDecimal.ONE) > 0; k = k.subtract(BigDecimal.ONE)) {
            result = result.multiply(k);
        }
        return result;
    }

    /** Delegates to Colt's {@code Arithmetic.binomial(double, long)}. */
    public static double binomial(double d, long j) {
        return Arithmetic.binomial(d, j);
    }

    /**
     * Exact binomial coefficient {@code C(n, k)}, computed multiplicatively.
     *
     * @param i n
     * @param i2 k
     * @return {@code C(n, k)} as an integral BigDecimal; {@code 0} if {@code k < 0} or {@code k > n}
     */
    public static BigDecimal binomial(int i, int i2) {
        if (i2 < 0 || i2 > i) {
            return BigDecimal.ZERO;
        }
        int k = Math.min(i2, i - i2);
        BigDecimal result = BigDecimal.ONE;
        for (int j = 1; j <= k; j++) {
            // result == C(i-k+j-1, j-1) here, so multiplying then dividing by j stays integral.
            result = result.multiply(BigDecimal.valueOf((long) i - k + j)).divide(BigDecimal.valueOf(j));
        }
        return result;
    }

    /** Returns the best matches found for the internal nodes of {@code tree} (live vector). */
    public Vector<NodeSet> getSolutionTree() {
        return this.solutionTree;
    }

    /** Replaces the vector that {@link #run()} appends the tree-side matches to. */
    public void setSolutionTree(Vector<NodeSet> vector) {
        this.solutionTree = vector;
    }

    /** Returns the best matches found for the internal nodes of {@code compareTree} (live vector). */
    public Vector<NodeSet> getSolutionCompareTree() {
        return this.solutionCompareTree;
    }

    /**
     * Concatenates the string forms of the given nodes without separators.
     *
     * @param treeNodeArr nodes to list
     * @return the concatenated {@code toString()} values
     */
    public String listLeaves(TreeNode[] treeNodeArr) {
        StringBuilder sb = new StringBuilder();
        for (TreeNode node : treeNodeArr) {
            sb.append(node);
        }
        return sb.toString();
    }
}
