/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.ensemble;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.ensemble.VotingCombiner;
import org.tribuo.dataset.DatasetView;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.ensemble.WeightedEnsembleModel;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

/**
 * Trains a boosted ensemble of {@link Label} classifiers.
 * <p>
 * Each ensemble member is built by {@link #innerTrainer}. After a member is trained it is
 * scored on the training data, its vote weight is set to
 * {@code ln(accuracy / error) + ln(numClasses - 1)} (the multi-class SAMME form of AdaBoost),
 * and the weights of the examples it misclassified are multiplied by {@code exp(alpha)}
 * before being renormalised.
 * <p>
 * If the inner trainer is an instance of {@link WeightedExamples} the example weights are
 * written onto a copy of the dataset; otherwise each member is trained on a weighted
 * bootstrap sample of the dataset.
 * <p>
 * The RNG and the invocation counter are only mutated while holding this instance's monitor.
 */
public class AdaBoostTrainer implements Trainer<Label> {
    private static final Logger logger = Logger.getLogger(AdaBoostTrainer.class.getName());

    /**
     * Sentinel passed as {@code invocationCount} meaning "do not reset the RNG state,
     * just continue from the current invocation count".
     */
    private static final int NO_INVOCATION_OVERRIDE = -1;

    @Config(mandatory=true, description="The trainer to use to build each weak learner.")
    protected Trainer<Label> innerTrainer;
    @Config(mandatory=true, description="The number of ensemble members to train.")
    protected int numMembers;
    @Config(mandatory=true, description="The seed for the RNG.")
    protected long seed;
    /** Source of per-invocation RNGs; split once for every call to train. */
    protected SplittableRandom rng;
    /** Number of RNG splits performed since the RNG was last (re)seeded. */
    protected int trainInvocationCounter;

    /**
     * No-arg constructor which leaves every field unset.
     * Presumably used by the configuration system, which populates the {@code @Config}
     * fields and then calls {@link #postConfig()} - TODO confirm.
     */
    private AdaBoostTrainer() {
    }

    /**
     * Constructs an AdaBoost trainer using the default seed of 12345.
     *
     * @param trainer The trainer used to build each ensemble member.
     * @param numMembers The maximum number of ensemble members to train.
     */
    public AdaBoostTrainer(Trainer<Label> trainer, int numMembers) {
        this(trainer, numMembers, 12345L);
    }

    /**
     * Constructs an AdaBoost trainer.
     *
     * @param trainer The trainer used to build each ensemble member.
     * @param numMembers The maximum number of ensemble members to train.
     * @param seed The seed for the RNG.
     */
    public AdaBoostTrainer(Trainer<Label> trainer, int numMembers, long seed) {
        this.innerTrainer = trainer;
        this.numMembers = numMembers;
        this.seed = seed;
        this.postConfig();
    }

    /**
     * Creates the RNG from the configured seed.
     */
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(this.seed);
    }

    /**
     * Returns a description of the form
     * {@code AdaBoostTrainer(innerTrainer=...,numMembers=...,seed=...)}.
     */
    @Override
    public String toString() {
        return "AdaBoostTrainer("
                + "innerTrainer=" + this.innerTrainer.toString()
                + ",numMembers=" + this.numMembers
                + ",seed=" + this.seed
                + ")";
    }

    /**
     * Trains a boosted ensemble, continuing from the current invocation count.
     *
     * @param examples The training data.
     * @param runProvenance Provenance recorded on the returned model.
     * @return The trained ensemble.
     */
    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
        return this.train(examples, runProvenance, NO_INVOCATION_OVERRIDE);
    }

    /**
     * Trains a boosted ensemble.
     * <p>
     * Training stops early, returning the members built so far, as soon as a member
     * classifies every training example correctly; that member is given a weight of 1.
     * <p>
     * NOTE(review): if a member scores a weighted accuracy of 0, {@code alpha} evaluates to
     * negative infinity and the reweighting step zeroes every misclassified example.
     * Confirm whether that case needs an explicit guard.
     *
     * @param examples The training data.
     * @param runProvenance Provenance recorded on the returned model.
     * @param invocationCount The invocation count to reset the RNG state to before
     *                        training, or -1 to continue from the current state.
     * @return The trained ensemble.
     * @throws IllegalArgumentException If the dataset contains unknown outputs, or if
     *                                  {@code invocationCount} is less than -1.
     */
    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }

        // Take this invocation's RNG and provenance snapshot atomically with the counter update.
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (invocationCount != NO_INVOCATION_OVERRIDE) {
                this.setInvocationCount(invocationCount);
            }
            localRNG = this.rng.split();
            trainerProvenance = this.getProvenance();
            ++this.trainInvocationCounter;
        }

        boolean weighted = this.innerTrainer instanceof WeightedExamples;
        ImmutableFeatureMap featureIDs = examples.getFeatureIDMap();
        ImmutableOutputInfo<Label> labelIDs = examples.getOutputIDInfo();
        int numClasses = labelIDs.size();
        logger.log(Level.INFO, "NumClasses = " + numClasses);

        List<Model<Label>> models = new ArrayList<>();
        float[] modelWeights = new float[this.numMembers];
        float[] exampleWeights = Util.generateUniformFloatVector(examples.size(), 1.0f / examples.size());

        if (weighted) {
            logger.info("Using weighted Adaboost.");
            // Work on a copy so the weights are not written onto the caller's dataset object.
            examples = ImmutableDataset.copyDataset(examples);
            for (int i = 0; i < examples.size(); ++i) {
                Example<Label> e = examples.getExample(i);
                e.setWeight(exampleWeights[i]);
            }
        } else {
            logger.info("Using sampling Adaboost.");
        }

        for (int i = 0; i < this.numMembers; ++i) {
            logger.info("Building model " + i);
            Model<Label> newModel;
            if (weighted) {
                newModel = this.innerTrainer.train(examples);
            } else {
                DatasetView<Label> bag = DatasetView.createWeightedBootstrapView(examples, examples.size(), localRNG.nextLong(), exampleWeights, featureIDs, labelIDs);
                newModel = this.innerTrainer.train(bag);
            }

            // Score the new member on the full training set under the current example weights.
            List<Prediction<Label>> predictions = newModel.predict(examples);
            float accuracy = AdaBoostTrainer.accuracy(predictions, examples, exampleWeights);
            float error = 1.0f - accuracy;
            float alpha = (float) (Math.log(accuracy / error) + Math.log(numClasses - 1));
            models.add(newModel);
            modelWeights[i] = alpha;

            if (accuracy + 1.0E-10 > 1.0) {
                // Perfect accuracy: alpha is not finite, so the final member's weight is set
                // to 1 and the ensemble is truncated to the members built so far.
                float[] newModelWeights = Arrays.copyOf(modelWeights, models.size());
                newModelWeights[models.size() - 1] = 1.0f;
                logger.log(Level.FINE, "Perfect accuracy reached on iteration " + i + ", returning current model.");
                return AdaBoostTrainer.buildEnsemble(examples, trainerProvenance, runProvenance, featureIDs, labelIDs, models, newModelWeights);
            }

            // Scale the weight of every misclassified example by exp(alpha), then renormalise.
            for (int j = 0; j < predictions.size(); ++j) {
                if (!predictions.get(j).getOutput().equals(examples.getExample(j).getOutput())) {
                    exampleWeights[j] = (float) (exampleWeights[j] * Math.exp(alpha));
                }
            }
            Util.inplaceNormalizeToDistribution(exampleWeights);

            if (weighted) {
                // Push the updated weights onto the examples seen by the next member.
                for (int j = 0; j < examples.size(); ++j) {
                    examples.getExample(j).setWeight(exampleWeights[j]);
                }
            }
        }

        return AdaBoostTrainer.buildEnsemble(examples, trainerProvenance, runProvenance, featureIDs, labelIDs, models, modelWeights);
    }

    /**
     * Logs the member weights and wraps the trained members in a voting ensemble model.
     *
     * @param examples The dataset whose provenance is recorded on the model.
     * @param trainerProvenance The provenance of this trainer captured at the start of training.
     * @param runProvenance Provenance supplied by the caller of train.
     * @param featureIDs The feature domain of the training data.
     * @param labelIDs The output domain of the training data.
     * @param models The trained ensemble members.
     * @param weights The vote weight of each member, one entry per element of {@code models}.
     * @return The ensemble model.
     */
    private static Model<Label> buildEnsemble(Dataset<Label> examples, TrainerProvenance trainerProvenance, Map<String, Provenance> runProvenance, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<Label> labelIDs, List<Model<Label>> models, float[] weights) {
        logger.log(Level.FINE, "Model weights:");
        Util.logVector(logger, Level.FINE, weights);
        EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance) examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
        EnsembleCombiner<Label> combiner = new VotingCombiner();
        return new WeightedEnsembleModel<>("boosted-ensemble", provenance, featureIDs, labelIDs, models, combiner, weights);
    }

    /**
     * Returns the number of RNG splits performed since the RNG was last (re)seeded.
     * <p>
     * Synchronized because the counter is only ever written while holding this instance's
     * monitor; an unsynchronized read could observe a stale value.
     *
     * @return The invocation count.
     */
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Re-seeds the RNG and splits it {@code invocationCount} times, leaving the trainer in
     * the state it has after that many RNG splits from a fresh seed.
     *
     * @param invocationCount The number of invocations to replay.
     * @throws IllegalArgumentException If {@code invocationCount} is negative.
     */
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.rng = new SplittableRandom(this.seed);
        for (this.trainInvocationCounter = 0; this.trainInvocationCounter < invocationCount; ++this.trainInvocationCounter) {
            // The split result is discarded; the call is made only to advance the parent RNG.
            this.rng.split();
        }
    }

    /**
     * Computes the weighted accuracy of the predictions: the summed weight of the correctly
     * predicted examples divided by the summed weight of all examples.
     *
     * @param predictions The predictions, in the same order as the examples.
     * @param examples The dataset supplying the ground truth outputs.
     * @param weights The weight of each example, indexed like {@code predictions}.
     * @return The weighted accuracy.
     */
    private static float accuracy(List<Prediction<Label>> predictions, Dataset<Label> examples, float[] weights) {
        float correctSum = 0.0f;
        float total = 0.0f;
        for (int i = 0; i < predictions.size(); ++i) {
            if (predictions.get(i).getOutput().equals(examples.getExample(i).getOutput())) {
                correctSum += weights[i];
            }
            total += weights[i];
        }
        logger.log(Level.FINEST, "Correct count = " + correctSum + " size = " + examples.size());
        return correctSum / total;
    }

    /**
     * Builds a provenance object describing this trainer.
     *
     * @return The trainer provenance.
     */
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}

