/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.ensemble;

import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.regression.Regressor;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXRef;

/**
 * An {@link EnsembleCombiner} for {@link Regressor} outputs which averages the
 * ensemble members' predictions independently for each output dimension and
 * reports the sample variance of the members alongside the mean.
 * <p>
 * The dimension names and the {@link Example} attached to the combined
 * prediction are taken from the first prediction in the supplied list, so the
 * list must not be empty.
 */
public class AveragingCombiner
implements EnsembleCombiner<Regressor> {
    private static final long serialVersionUID = 1L;

    /**
     * Combines the predictions with equal weight.
     * <p>
     * The mean and the sum of squared deviations are accumulated in a single
     * pass using Welford's online algorithm. The variance is the unbiased
     * sample variance (divisor {@code n - 1}), or zero when there is only a
     * single prediction.
     *
     * @param outputInfo  the output domain; its size is the number of regression dimensions.
     * @param predictions the ensemble members' predictions, must contain at least one element.
     * @return a prediction holding the per-dimension mean and variance, the largest
     *         active feature count seen across the members, and the first member's example.
     */
    public Prediction<Regressor> combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions) {
        int numPredictions = predictions.size();
        int dimensions = outputInfo.size();
        int numUsed = 0;
        double[] mean = new double[dimensions];
        double[] variance = new double[dimensions];
        int count = 0;
        for (Prediction<Regressor> p : predictions) {
            numUsed = Math.max(numUsed, p.getNumActiveFeatures());
            double[] values = p.getOutput().getValues();
            count++;
            for (int i = 0; i < dimensions; i++) {
                double delta = values[i] - mean[i];
                // Welford update: the step towards the new value must shrink with the
                // number of samples seen, otherwise the "mean" is just the last value
                // and the variance term below is always zero.
                mean[i] += delta / count;
                variance[i] += delta * (values[i] - mean[i]);
            }
        }
        String[] names = predictions.get(0).getOutput().getNames();
        if (numPredictions > 1) {
            for (int i = 0; i < dimensions; i++) {
                variance[i] /= numPredictions - 1;
            }
        } else {
            Arrays.fill(variance, 0.0);
        }
        Example<Regressor> example = predictions.get(0).getExample();
        return new Prediction<>(new Regressor(names, mean, variance), numUsed, example);
    }

    /**
     * Combines the predictions using the supplied per-member weights.
     * <p>
     * The weighted mean and weighted sum of squared deviations are accumulated
     * incrementally. When more than one weight is supplied the accumulated sum
     * is divided by {@code sum(weights) - 1}; otherwise the variance is zero.
     * <p>
     * NOTE(review): the {@code sum(weights) - 1} divisor is zero when the weights
     * sum to exactly one, and the mean update divides by the running weight sum,
     * so all-zero leading weights produce non-finite values. Confirm the intended
     * weight semantics with callers before changing this.
     *
     * @param outputInfo  the output domain; its size is the number of regression dimensions.
     * @param predictions the ensemble members' predictions, must contain at least one element.
     * @param weights     one weight per prediction, in the same order.
     * @return a prediction holding the per-dimension weighted mean and variance, the largest
     *         active feature count seen across the members, and the first member's example.
     * @throws IllegalArgumentException if {@code predictions} and {@code weights} differ in length.
     */
    public Prediction<Regressor> combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions, float[] weights) {
        if (predictions.size() != weights.length) {
            throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()=" + predictions.size() + ", weights.length=" + weights.length);
        }
        int dimensions = outputInfo.size();
        int numUsed = 0;
        double[] mean = new double[dimensions];
        double[] variance = new double[dimensions];
        double weightSum = 0.0;
        for (int i = 0; i < weights.length; i++) {
            Prediction<Regressor> p = predictions.get(i);
            numUsed = Math.max(numUsed, p.getNumActiveFeatures());
            double[] values = p.getOutput().getValues();
            double weight = weights[i];
            weightSum += weight;
            for (int j = 0; j < dimensions; j++) {
                double delta = values[j] - mean[j];
                // Weighted incremental update: move the mean by this member's share
                // of the total weight seen so far.
                mean[j] += weight / weightSum * delta;
                variance[j] += weight * delta * (values[j] - mean[j]);
            }
        }
        String[] names = predictions.get(0).getOutput().getNames();
        if (weights.length > 1) {
            for (int j = 0; j < dimensions; j++) {
                variance[j] /= weightSum - 1.0;
            }
        } else {
            Arrays.fill(variance, 0.0);
        }
        Example<Regressor> example = predictions.get(0).getExample();
        return new Prediction<>(new Regressor(names, mean, variance), numUsed, example);
    }

    /**
     * Returns a fixed description of this combiner.
     * <p>
     * The text does not match the class name; it is a runtime string and is
     * deliberately left as is.
     */
    @Override
    public String toString() {
        return "MultipleOutputAveragingCombiner()";
    }

    /**
     * Builds the provenance for this combiner, tagged with the component
     * type {@code "EnsembleCombiner"}.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "EnsembleCombiner");
    }

    /**
     * Exports the unweighted combination as an ONNX {@code ReduceMean} over
     * axis 2 with {@code keepdims = 0}.
     *
     * @param input the node holding the ensemble members' outputs; axis 2 is
     *              presumably the ensemble-member axis - TODO confirm against the exporter.
     * @return the node producing the averaged output.
     */
    public ONNXNode exportCombiner(ONNXNode input) {
        HashMap<String, Object> attributes = new HashMap<>();
        attributes.put("axes", new int[]{2});
        attributes.put("keepdims", 0);
        return input.apply(ONNXOperators.REDUCE_MEAN, attributes);
    }

    /**
     * Exports the weighted combination as ONNX operations: the weights are
     * unsqueezed on axes 0 and 1, multiplied into the input, summed over
     * axis 2 (without keeping the reduced dimension) and divided by the sum
     * of the weights.
     *
     * @param input  the node holding the ensemble members' outputs; axis 2 is
     *               presumably the ensemble-member axis - TODO confirm against the exporter.
     * @param weight the reference holding the member weights.
     * @param <T>    the type of the weight reference.
     * @return the node producing the weighted average.
     */
    public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
        ONNXInitializer unsqueezeAxes = input.onnxContext().array("unsqueeze_ensemble_output", new long[]{0L, 1L});
        ONNXInitializer sumAxes = input.onnxContext().array("sum_across_ensemble_axes", new long[]{2L});

        // Add two leading unit axes so the weights broadcast against the input.
        ONNXNode unsqueezed = weight.apply(ONNXOperators.UNSQUEEZE, unsqueezeAxes);
        ONNXNode mulByWeights = input.apply(ONNXOperators.MUL, unsqueezed);
        ONNXNode weightSum = weight.apply(ONNXOperators.REDUCE_SUM);

        ONNXNode weightedSum = mulByWeights.apply(ONNXOperators.REDUCE_SUM, sumAxes, Collections.singletonMap("keepdims", 0));
        return weightedSum.apply(ONNXOperators.DIV, weightSum);
    }
}

