/*
 * Decompiled with CFR 0.152.
 */
package org.extratrees;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.Set;
import org.extratrees.AbstractTrees;
import org.extratrees.BinaryTree;
import org.extratrees.Matrix;
import org.extratrees.ShuffledIterator;
import org.extratrees.TaskCutResult;

/**
 * Extremely randomized trees (ExtraTrees) for regression, with optional multi-task support.
 *
 * <p>A candidate split is scored as {@code countLeft * varLeft + countRight * varRight}, where the
 * variances are those of the output values on each side. {@link #getTaskCut} keeps the
 * lowest-scoring task cut. Leaves predict the mean output of their samples. Internal nodes predict
 * the mean of their children weighted by the children's sample counts. When per-sample task ids are
 * supplied, a node can also be split by grouping tasks (see {@link #getTaskCut}).
 */
public class ExtraTrees
extends AbstractTrees<BinaryTree> {
    /** Variance below which a child is flagged as constant (about 1e-14). */
    private static final double CONSTANT_VARIANCE_EPS = 9.999999999999998E-15;

    /**
     * Pseudo-count used when scoring tasks. Each task's mean output is shrunk toward the node mean
     * as if this many extra samples at the node mean had been observed.
     */
    private static final double TASK_PRIOR_WEIGHT = 1.0;

    /** Regression targets, one per input row (stored by reference, not copied). */
    double[] output;
    /** Element-wise squares of {@link #output}, cached for the variance computations. */
    double[] outputSq;

    /**
     * Creates a model over the given data without task ids.
     *
     * @param matrix input features, one row per sample
     * @param dArray regression targets, one per row of {@code matrix}
     * @throws IllegalArgumentException if the row count and the target count differ
     */
    public ExtraTrees(Matrix matrix, double[] dArray) {
        this(matrix, dArray, null);
    }

    /**
     * Creates a model over the given data, optionally with one task id per sample.
     *
     * @param matrix input features, one row per sample
     * @param dArray regression targets, one per row of {@code matrix} (kept by reference)
     * @param nArray task id per row, or {@code null}; forwarded to {@code setTasks}
     * @throws IllegalArgumentException if the row, target or task counts differ
     */
    public ExtraTrees(Matrix matrix, double[] dArray, int[] nArray) {
        if (matrix.nrows != dArray.length) {
            throw new IllegalArgumentException("Input and output do not have same length.");
        }
        if (nArray != null && matrix.nrows != nArray.length) {
            throw new IllegalArgumentException("Input and tasks do not have the same number of data points.");
        }
        this.input = matrix;
        this.output = dArray;
        this.outputSq = new double[dArray.length];
        for (int row = 0; row < dArray.length; row++) {
            this.outputSq[row] = dArray[row] * dArray[row];
        }
        this.setTasks(nArray);
        // Every column is a split candidate.
        this.cols = new ArrayList<Integer>(matrix.ncols);
        for (int col = 0; col < matrix.ncols; col++) {
            this.cols.add(col);
        }
    }

    /**
     * Returns a new model over the same data that holds only the selected trees.
     *
     * @param blArray mask over {@code trees}; tree {@code i} is kept when {@code blArray[i]} is true
     * @return a new ExtraTrees sharing this model's input, output and task ids
     */
    public ExtraTrees selectTrees(boolean[] blArray) {
        // Pass the task ids through so the copy keeps the same multi-task configuration.
        ExtraTrees selected = new ExtraTrees(this.input, this.output, this.tasks);
        selected.trees = new ArrayList<BinaryTree>();
        for (int i = 0; i < blArray.length; i++) {
            if (blArray[i]) {
                selected.trees.add(this.trees.get(i));
            }
        }
        return selected;
    }

    /**
     * Builds {@code n3} trees from the samples whose row indices are listed in {@code nArray}.
     *
     * @param n first parameter forwarded to {@code buildTree}. {@link #buildTreeCV} passes
     *          minimum-node-size candidates here; presumably nmin, confirm in AbstractTrees.
     * @param n2 second parameter forwarded to {@code buildTree}; presumably the number of candidate
     *           columns per split (K), confirm in AbstractTrees
     * @param n3 number of trees to build
     * @param nArray row indices of the training samples
     * @return the newly built trees
     */
    public ArrayList<BinaryTree> buildTrees(int n, int n2, int n3, int[] nArray) {
        ArrayList<BinaryTree> built = new ArrayList<BinaryTree>(n3);
        // A single column iterator is shared by all trees built in this call.
        ShuffledIterator<Integer> columnIterator = new ShuffledIterator<Integer>(this.cols);
        for (int t = 0; t < n3; t++) {
            built.add((BinaryTree) this.buildTree(n, n2, nArray, columnIterator, getSequenceSet(this.nTasks)));
        }
        return built;
    }

    /**
     * Averages the predictions of the given trees for one input vector.
     *
     * @param arrayList trees to average; an empty list yields NaN
     * @param dArray input vector
     * @return mean of {@code BinaryTree.getValue(dArray)} over the trees
     */
    public static double getValue(ArrayList<BinaryTree> arrayList, double[] dArray) {
        double total = 0.0;
        for (BinaryTree tree : arrayList) {
            total += tree.getValue(dArray);
        }
        return total / arrayList.size();
    }

    /**
     * Averages {@code BinaryTree.getValue(dArray, n)} over the given trees.
     *
     * @param arrayList trees to average; an empty list yields NaN
     * @param dArray input vector
     * @param n forwarded to each tree. {@link #buildTreeCV} uses it for minimum-node-size
     *          candidates; presumably prediction stops at nodes smaller than n, confirm in BinaryTree.
     * @return the averaged prediction
     */
    public static double getValue(ArrayList<BinaryTree> arrayList, double[] dArray, int n) {
        double total = 0.0;
        for (BinaryTree tree : arrayList) {
            total += tree.getValue(dArray, n);
        }
        return total / arrayList.size();
    }

    /**
     * Returns every tree's prediction for every row.
     *
     * @param matrix inputs, one row per sample
     * @return a {@code nrows x trees.size()} matrix; entry (i, j) is tree j's prediction for row i
     */
    public Matrix getAllValues(Matrix matrix) {
        int treeCount = this.trees.size();
        Matrix result = new Matrix(matrix.nrows, treeCount);
        double[] row = new double[matrix.ncols];
        for (int i = 0; i < matrix.nrows; i++) {
            matrix.copyRow(i, row);
            for (int j = 0; j < treeCount; j++) {
                result.set(i, j, this.trees.get(j).getValue(row));
            }
        }
        return result;
    }

    /**
     * Predicts every row of {@code matrix} with this model's trees.
     *
     * @param matrix inputs, one row per sample
     * @return one averaged prediction per row
     */
    public double[] getValues(Matrix matrix) {
        return getValues(this.trees, matrix);
    }

    /**
     * Predicts every row of {@code matrix} by averaging the given trees.
     *
     * @param arrayList trees to average
     * @param matrix inputs, one row per sample
     * @return one averaged prediction per row
     */
    public static double[] getValues(ArrayList<BinaryTree> arrayList, Matrix matrix) {
        double[] predictions = new double[matrix.nrows];
        double[] row = new double[matrix.ncols];
        for (int i = 0; i < matrix.nrows; i++) {
            matrix.copyRow(i, row);
            predictions[i] = getValue(arrayList, row);
        }
        return predictions;
    }

    /**
     * Multi-task prediction for every row of {@code matrix}.
     *
     * @param matrix inputs, one row per sample
     * @param nArray task id for each row
     * @return one averaged multi-task prediction per row
     * @throws IllegalArgumentException if {@code nArray.length != matrix.nrows}
     */
    public double[] getValuesMT(Matrix matrix, int[] nArray) {
        if (matrix.nrows != nArray.length) {
            throw new IllegalArgumentException("Inputs and tasks do not have the same length.");
        }
        double[] predictions = new double[matrix.nrows];
        double[] row = new double[matrix.ncols];
        for (int i = 0; i < matrix.nrows; i++) {
            matrix.copyRow(i, row);
            predictions[i] = this.getValueMT(row, nArray[i]);
        }
        return predictions;
    }

    /**
     * Averages {@code BinaryTree.getValueMT(dArray, n)} over this model's trees.
     *
     * @param dArray input vector
     * @param n task id of the input
     * @return the averaged prediction (NaN when there are no trees)
     */
    public double getValueMT(double[] dArray, int n) {
        double total = 0.0;
        for (BinaryTree tree : this.trees) {
            total += tree.getValueMT(dArray, n);
        }
        return total / this.trees.size();
    }

    /**
     * Returns every tree's multi-task prediction for every row.
     *
     * @param matrix inputs, one row per sample
     * @param nArray task id for each row
     * @return a {@code nrows x trees.size()} matrix of per-tree predictions
     * @throws IllegalArgumentException if {@code nArray.length != matrix.nrows}
     */
    public Matrix getAllValuesMT(Matrix matrix, int[] nArray) {
        if (matrix.nrows != nArray.length) {
            throw new IllegalArgumentException("Inputs and tasks do not have the same length.");
        }
        int treeCount = this.trees.size();
        Matrix result = new Matrix(matrix.nrows, treeCount);
        double[] row = new double[matrix.ncols];
        for (int i = 0; i < matrix.nrows; i++) {
            matrix.copyRow(i, row);
            for (int j = 0; j < treeCount; j++) {
                result.set(i, j, this.trees.get(j).getValueMT(row, nArray[i]));
            }
        }
        return result;
    }

    /**
     * Creates an internal node that splits on {@code column < threshold}.
     *
     * @param binaryTree left child
     * @param binaryTree2 right child
     * @param n split column
     * @param d split threshold
     * @param n2 number of samples under this node
     * @return the new node; its value is the children's values weighted by their sample counts
     */
    @Override
    protected BinaryTree makeFilledTree(BinaryTree binaryTree, BinaryTree binaryTree2, int n, double d, int n2) {
        BinaryTree node = new BinaryTree();
        node.column = n;
        node.threshold = d;
        node.nSuccessors = n2;
        node.left = binaryTree;
        node.right = binaryTree2;
        double weightedSum = binaryTree.value * binaryTree.nSuccessors
                + binaryTree2.value * binaryTree2.nSuccessors;
        node.value = weightedSum / n2;
        return node;
    }

    /**
     * Scores splitting the given samples on {@code column n < d}.
     *
     * <p>{@code cutResult.countLeft/countRight} are incremented in place, so {@code cutResult} is
     * expected to arrive with zero counts.
     *
     * @param nArray row indices of the samples at this node
     * @param n column to split on
     * @param d threshold; values strictly below go left
     * @param cutResult receives the counts, the score and the constant-child flags
     */
    @Override
    protected void calculateCutScore(int[] nArray, int n, double d, AbstractTrees.CutResult cutResult) {
        double sumLeft = 0.0;
        double sumRight = 0.0;
        double sumSqLeft = 0.0;
        double sumSqRight = 0.0;
        for (int id : nArray) {
            if (this.input.get(id, n) < d) {
                cutResult.countLeft++;
                sumLeft += this.output[id];
                sumSqLeft += this.outputSq[id];
            } else {
                cutResult.countRight++;
                sumRight += this.output[id];
                sumSqRight += this.outputSq[id];
            }
        }
        this.cutResultFromSums(cutResult, sumLeft, sumRight, sumSqLeft, sumSqRight,
                cutResult.countLeft, cutResult.countRight);
    }

    /**
     * Fills the score and constant flags of {@code cutResult} from per-side sums.
     *
     * <p>The score is {@code countLeft * varLeft + countRight * varRight}, with the variance computed
     * as {@code E[y^2] - E[y]^2}. An empty side gives 0/0 = NaN, and a NaN score never satisfies a
     * {@code <} comparison, so such a cut is never chosen by {@link #getTaskCut}.
     */
    private void cutResultFromSums(AbstractTrees.CutResult cutResult, double sumLeft, double sumRight,
            double sumSqLeft, double sumSqRight, double countLeft, double countRight) {
        double meanLeft = sumLeft / countLeft;
        double meanRight = sumRight / countRight;
        double varLeft = sumSqLeft / countLeft - meanLeft * meanLeft;
        double varRight = sumSqRight / countRight - meanRight * meanRight;
        cutResult.score = cutResult.countLeft * varLeft + cutResult.countRight * varRight;
        cutResult.leftConst = varLeft < CONSTANT_VARIANCE_EPS;
        cutResult.rightConst = varRight < CONSTANT_VARIANCE_EPS;
    }

    /**
     * Searches for a split of the node's tasks into two groups.
     *
     * <p>Each task gets a score equal to its shrunken mean output (see {@link #TASK_PRIOR_WEIGHT}).
     * {@code numRandomTaskCuts} random thresholds are drawn between the smallest and largest score
     * among the tasks in {@code set}. Tasks scoring below a threshold go left.
     *
     * @param nArray row indices of the samples at this node
     * @param set task ids present at this node
     * @param d score to beat
     * @return the best task cut scoring strictly below {@code d}, or {@code null} if none does or
     *         the node does not contain samples from at least two tasks
     */
    @Override
    protected TaskCutResult getTaskCut(int[] nArray, Set<Integer> set, double d) {
        if (set.size() <= 1 || !this.hasAtLeast2Tasks(nArray)) {
            return null;
        }
        double mean = this.getOutputMean(nArray);
        int[] taskCounts = new int[this.nTasks];
        double[] taskCountsD = new double[this.nTasks];
        double[] taskSums = new double[this.nTasks];
        double[] taskSumsSq = new double[this.nTasks];
        double[] taskScores = this.getTaskScores(nArray, mean, set, taskCounts, taskCountsD, taskSums, taskSumsSq);
        // Restrict the range to tasks in `set`. Other entries of taskScores are unset zeros, and
        // thresholds outside the present tasks' span would always yield an empty side.
        double[] range = taskScoreRange(taskScores, set);
        TaskCutResult best = null;
        double bestScore = d;
        for (int attempt = 0; attempt < this.numRandomTaskCuts; attempt++) {
            double threshold = this.getRandom(range[0], range[1]);
            TaskCutResult candidate = new TaskCutResult();
            this.calculateTaskCutScore(taskScores, taskCounts, taskCountsD, taskSums, taskSumsSq,
                    mean, threshold, candidate, set);
            if (candidate.score < bestScore) {
                best = candidate;
                bestScore = candidate.score;
            }
        }
        return best;
    }

    /** Returns {@code {min, max}} of {@code scores} over the task ids in {@code set} (non-empty). */
    private static double[] taskScoreRange(double[] scores, Set<Integer> set) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (int task : set) {
            min = Math.min(min, scores[task]);
            max = Math.max(max, scores[task]);
        }
        return new double[] {min, max};
    }

    /**
     * Scores one task split: tasks with {@code taskScores[task] < threshold} go left.
     *
     * @param mean node output mean (currently unused)
     * @param result receives the left/right task sets, the sample counts and the score
     */
    private void calculateTaskCutScore(double[] taskScores, int[] taskCounts, double[] taskCountsD,
            double[] taskSums, double[] taskSumsSq, double mean, double threshold,
            TaskCutResult result, Set<Integer> set) {
        double sumLeft = 0.0;
        double sumRight = 0.0;
        double sumSqLeft = 0.0;
        double sumSqRight = 0.0;
        double countLeft = 0.0;
        double countRight = 0.0;
        result.leftTasks = new HashSet<Integer>();
        result.rightTasks = new HashSet<Integer>();
        result.countLeft = 0;
        result.countRight = 0;
        for (int task : set) {
            if (taskScores[task] < threshold) {
                result.leftTasks.add(task);
                result.countLeft += taskCounts[task];
                countLeft += taskCountsD[task];
                sumLeft += taskSums[task];
                sumSqLeft += taskSumsSq[task];
            } else {
                result.rightTasks.add(task);
                result.countRight += taskCounts[task];
                countRight += taskCountsD[task];
                sumRight += taskSums[task];
                sumSqRight += taskSumsSq[task];
            }
        }
        this.cutResultFromSums(result, sumLeft, sumRight, sumSqLeft, sumSqRight, countLeft, countRight);
    }

    /** Returns true when the given samples belong to at least two distinct tasks. */
    private boolean hasAtLeast2Tasks(int[] ids) {
        if (ids.length < 2) {
            return false;
        }
        int firstTask = this.tasks[ids[0]];
        for (int i = 1; i < ids.length; i++) {
            if (this.tasks[ids[i]] != firstTask) {
                return true;
            }
        }
        return false;
    }

    /**
     * Accumulates per-task statistics into the given arrays and returns each task's score.
     *
     * @param ids row indices of the node's samples
     * @param mean node output mean, used as the prior for the shrunken task means
     * @param set tasks to score
     * @param counts output: samples per task
     * @param countsD output: samples per task as doubles (filled for tasks in {@code set})
     * @param sums output: sum of outputs per task
     * @param sumsSq output: sum of squared outputs per task
     * @return per-task scores; only entries for tasks in {@code set} are set, others stay 0.0
     */
    private double[] getTaskScores(int[] ids, double mean, Set<Integer> set, int[] counts,
            double[] countsD, double[] sums, double[] sumsSq) {
        double[] scores = new double[this.nTasks];
        for (int id : ids) {
            int task = this.tasks[id];
            counts[task]++;
            sums[task] += this.output[id];
            sumsSq[task] += this.outputSq[id];
        }
        for (int task : set) {
            countsD[task] = counts[task];
            scores[task] = (sums[task] + mean * TASK_PRIOR_WEIGHT) / (counts[task] + TASK_PRIOR_WEIGHT);
        }
        return scores;
    }

    /** Mean output over the given row indices (NaN when empty). */
    private double getOutputMean(int[] ids) {
        double total = 0.0;
        for (int id : ids) {
            total += this.output[id];
        }
        return total / ids.length;
    }

    /**
     * Creates a leaf that predicts the mean output of its samples.
     *
     * @param nArray row indices of the leaf's samples
     * @param set tasks stored on the leaf
     * @return the new leaf
     */
    @Override
    public BinaryTree makeLeaf(int[] nArray, Set<Integer> set) {
        BinaryTree leaf = new BinaryTree();
        leaf.nSuccessors = nArray.length;
        leaf.tasks = set;
        double total = 0.0;
        for (int id : nArray) {
            total += this.output[id];
        }
        leaf.value = total / nArray.length;
        return leaf;
    }

    /**
     * Generates a synthetic regression problem with uniform random features.
     * Column 2 is overwritten with the constant 0.5, and the target is {@code x1 + 0.2 * x3}.
     *
     * @param n number of rows
     * @param n2 number of columns; must be at least 4 because column 3 is used
     * @return a model over the generated data
     * @throws IllegalArgumentException if {@code n2 < 4}
     */
    public static ExtraTrees getSampleData(int n, int n2) {
        if (n2 < 4) {
            throw new IllegalArgumentException("Sample data needs at least 4 columns, got " + n2 + ".");
        }
        double[] values = new double[n * n2];
        for (int i = 0; i < values.length; i++) {
            values[i] = Math.random();
        }
        Matrix matrix = new Matrix(values, n, n2);
        double[] targets = new double[n];
        for (int row = 0; row < n; row++) {
            matrix.set(row, 2, 0.5);
            targets[row] = matrix.get(row, 1) + 0.2 * matrix.get(row, 3);
        }
        return new ExtraTrees(matrix, targets);
    }

    /**
     * Mean squared error of the averaged tree predictions over all rows.
     *
     * @throws IllegalArgumentException if {@code dArray.length != matrix.nrows}
     */
    public static double getMeanSqError(ArrayList<BinaryTree> arrayList, Matrix matrix, double[] dArray) {
        if (matrix.nrows != dArray.length) {
            throw new IllegalArgumentException("Inputs and outputs do not have the same length.");
        }
        double total = 0.0;
        double[] row = new double[matrix.ncols];
        for (int i = 0; i < matrix.nrows; i++) {
            matrix.copyRow(i, row);
            double diff = getValue(arrayList, row) - dArray[i];
            total += diff * diff;
        }
        return total / dArray.length;
    }

    /**
     * Mean squared error over the rows listed in {@code nArray}, predicting with
     * {@link #getValue(ArrayList, double[], int)} and parameter {@code n}.
     */
    public static double getMeanSqError(ArrayList<BinaryTree> arrayList, Matrix matrix, double[] dArray, int n, int[] nArray) {
        double total = 0.0;
        double[] row = new double[matrix.ncols];
        for (int id : nArray) {
            matrix.copyRow(id, row);
            double diff = getValue(arrayList, row, n) - dArray[id];
            total += diff * diff;
        }
        return total / nArray.length;
    }

    /**
     * Mean absolute error of the averaged tree predictions over all rows.
     *
     * @throws IllegalArgumentException if {@code dArray.length != matrix.nrows}
     */
    public static double getMeanAbsError(ArrayList<BinaryTree> arrayList, Matrix matrix, double[] dArray) {
        if (matrix.nrows != dArray.length) {
            throw new IllegalArgumentException("Inputs and outputs do not have the same length.");
        }
        double[] predictions = getValues(arrayList, matrix);
        double total = 0.0;
        for (int i = 0; i < dArray.length; i++) {
            total += Math.abs(predictions[i] - dArray[i]);
        }
        return total / dArray.length;
    }

    /**
     * Picks the first tree-building parameter by a single random 2/3 : 1/3 hold-out split, then
     * builds the final trees.
     *
     * <p>Trees are built on the training part with first parameter 2. They are then evaluated on the
     * held-out part with each candidate in {2, 3, 5, 9, 14} passed to
     * {@link #getValue(ArrayList, double[], int)}. The candidate with the lowest MSE (ties: first)
     * is used for the final 3-argument {@code buildTrees} call.
     *
     * @param n forwarded as the second {@code buildTrees} argument
     * @param n2 forwarded as the third {@code buildTrees} argument (number of trees)
     * @return the final trees
     */
    public ArrayList<BinaryTree> buildTreeCV(int n, int n2) {
        int[] candidates = {2, 3, 5, 9, 14};
        int total = this.output.length;
        int trainSize = (int) (2.0 / 3.0 * total);
        Integer[] order = new Integer[total];
        for (int i = 0; i < total; i++) {
            order[i] = i;
        }
        Collections.shuffle(Arrays.asList(order));
        int[] trainIds = new int[trainSize];
        int[] validIds = new int[total - trainSize];
        for (int i = 0; i < trainIds.length; i++) {
            trainIds[i] = order[i];
        }
        for (int i = 0; i < validIds.length; i++) {
            validIds[i] = order[trainSize + i];
        }
        ArrayList<BinaryTree> cvTrees = this.buildTrees(2, n, n2, trainIds);
        double bestError = Double.POSITIVE_INFINITY;
        int bestCandidate = candidates[0];
        for (int candidate : candidates) {
            double error = getMeanSqError(cvTrees, this.input, this.output, candidate, validIds);
            if (error < bestError) {
                bestCandidate = candidate;
                bestError = error;
            }
        }
        return this.buildTrees(bestCandidate, n, n2);
    }

    /**
     * Demo: trains on synthetic data and prints predictions for a fresh synthetic test set.
     * It also prints the training time and two mean squared errors; the second uses the
     * 5-argument {@code getMeanSqError} with {@code n = 5}.
     */
    public static void main(String[] stringArray) {
        int trainRows = 10000;
        int columns = 7;
        int treeCount = 15;
        ExtraTrees model = getSampleData(trainRows, columns);
        long start = System.nanoTime();
        model.learnTrees(2, 6, treeCount);
        double trainSeconds = (System.nanoTime() - start) / 1e9;
        ArrayList<BinaryTree> trees = model.trees;
        ExtraTrees test = getSampleData(1000, columns);
        double[] predictions = getValues(trees, test.input);
        for (int i = 0; i < test.output.length; i++) {
            System.out.println(String.format("%d\t%1.3f %1.3f", i, test.output[i], predictions[i]));
        }
        System.out.println("Took: " + trainSeconds + "s");
        int[] allRows = new int[test.output.length];
        for (int i = 0; i < allRows.length; i++) {
            allRows[i] = i;
        }
        double mse = getMeanSqError(trees, test.input, test.output);
        double mseWithN5 = getMeanSqError(trees, test.input, test.output, 5, allRows);
        System.out.println("Error: " + mse);
        System.out.println("Error: " + mseWithN5);
    }
}

