/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.tree.ClassifierNode;
import org.jpmml.rexp.DecorationUtil;
import org.jpmml.rexp.Formula;
import org.jpmml.rexp.FormulaContext;
import org.jpmml.rexp.FormulaUtil;
import org.jpmml.rexp.RBooleanVector;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExp;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RExpUtil;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RIntegerVector;
import org.jpmml.rexp.RStringVector;
import org.jpmml.rexp.RVector;
import org.jpmml.rexp.TreeModelConverter;

/**
 * Converts an R {@code party} object into a PMML {@link TreeModel}.
 *
 * <p>The object is read through its {@code data}, {@code fitted}, {@code terms} and {@code node}
 * elements, plus a {@code predicted} decoration that holds per-node {@code (response)} values and,
 * for classification, a {@code (prob)} matrix. NOTE(review): these element names appear to match
 * the partykit {@code party}/{@code partynode} structures; confirm against the R side.
 */
public class PartyConverter
extends TreeModelConverter<RGenericVector> {
    public PartyConverter(RGenericVector party) {
        super(party);
    }

    /**
     * Encodes the label and the active features from the model's {@code terms} formula.
     *
     * <p>Variable data is looked up in the {@code data} element first. The response variable falls
     * back to the {@code (response)} column of the {@code fitted} element.
     *
     * @param encoder the encoder that receives the label and the features
     * @throws IllegalArgumentException if the formula does not declare a response variable
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector party = (RGenericVector) this.getObject();
        final RGenericVector data = party.getGenericElement("data");
        final RGenericVector fitted = party.getGenericElement("fitted");
        RExp terms = party.getElement("terms");

        RIntegerVector factors = terms.getIntegerAttribute("factors");
        RIntegerVector response = terms.getIntegerAttribute("response");
        RStringVector variableRows = factors.dimnames(0);
        RStringVector termColumns = factors.dimnames(1);

        int responseIndex = response.asScalar();
        // An index of 0 means there is no response variable; a tree cannot be encoded without
        // one, so fail before any formula work touches the encoder.
        if (responseIndex == 0) {
            throw new IllegalArgumentException("Expected a formula with a response variable");
        }
        // The index is 1-based into the row names of the factors matrix
        final String responseVariable = variableRows.getDequotedValue(responseIndex - 1);

        FormulaContext context = new FormulaContext() {

            @Override
            public List<String> getCategories(String variable) {
                RVector<?> variableData = this.getData(variable);
                if (variableData != null && RExpUtil.isFactor(variableData)) {
                    RIntegerVector factor = (RIntegerVector) variableData;
                    return factor.getLevelValues();
                }
                return null;
            }

            @Override
            public RVector<?> getData(String variable) {
                if (data.hasElement(variable)) {
                    return data.getVectorElement(variable);
                }
                if (variable.equals(responseVariable)) {
                    return fitted.getVectorElement("(response)");
                }
                return null;
            }
        };

        Formula formula = FormulaUtil.createFormula(terms, context, encoder);

        // A factor response passes its levels to the label; any other response passes null
        RVector<?> responseData = context.getData(responseVariable);
        RIntegerVector levels = null;
        if (responseData != null && RExpUtil.isFactor(responseData)) {
            levels = (RIntegerVector) responseData;
        }

        FormulaUtil.setLabel(formula, terms, levels, encoder);
        FormulaUtil.addFeatures(formula, termColumns, false, encoder);
    }

    /**
     * Encodes the tree rooted at the {@code node} element.
     *
     * <p>A factor {@code (response)} produces a classification model with a probability output.
     * Any other response produces a regression model.
     *
     * @param schema the schema produced by {@link #encodeSchema(RExpEncoder)}
     * @return the encoded tree model
     * @throws IllegalArgumentException if a classification tree lacks the {@code (prob)} element
     */
    @Override
    public Model encodeModel(Schema schema) {
        RGenericVector party = (RGenericVector) this.getObject();
        RGenericVector partyNode = party.getGenericElement("node");
        RGenericVector predicted = DecorationUtil.getGenericElement(party, "predicted");

        RVector<?> response = predicted.getVectorElement("(response)");
        // Optional element: it is only needed for classification
        RDoubleVector prob = predicted.getDoubleElement("(prob)", false);

        boolean classification = RExpUtil.isFactor(response);
        if (classification && prob == null) {
            throw new IllegalArgumentException("Expected a (prob) element for a classification tree");
        }

        Node root = this.encodeNode(new True(), partyNode, response, prob, schema);

        if (classification) {
            CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
            return new TreeModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), root)
                .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
        }
        return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), root);
    }

    /**
     * Recursively encodes an R tree node and its descendants.
     *
     * @param predicate the predicate under which the parent selects this node
     * @param partyNode the R node; reads {@code id}, {@code split}, {@code kids} and {@code surrogates}
     * @param response per-node response values, indexed by {@code id - 1}
     * @param prob per-node class probabilities, read as a Fortran-ordered matrix with
     *     {@code response.size()} rows and one column per class (only read for a factor response)
     * @param schema the model schema; supplies the label and the features indexed by {@code varid - 1}
     * @return the encoded node
     * @throws IllegalArgumentException if the node uses surrogates or an unsupported split layout
     */
    private Node encodeNode(Predicate predicate, RGenericVector partyNode, RVector<?> response, RDoubleVector prob, Schema schema) {
        RIntegerVector id = partyNode.getIntegerElement("id");
        RGenericVector split = partyNode.getGenericElement("split");
        RGenericVector kids = partyNode.getGenericElement("kids");
        RGenericVector surrogates = partyNode.getGenericElement("surrogates");

        if (surrogates != null) {
            throw new IllegalArgumentException("Surrogate splits are not supported");
        }

        Label label = schema.getLabel();
        List<? extends Feature> features = schema.getFeatures();

        Integer nodeId = id.asScalar();
        // R node ids are 1-based; the response/prob rows are 0-based
        int row = nodeId - 1;

        boolean factorResponse = RExpUtil.isFactor(response);

        Node result;
        if (factorResponse) {
            result = new ClassifierNode();
        } else if (kids == null) {
            result = new LeafNode();
        } else {
            result = new BranchNode();
        }
        result.setId(nodeId).setPredicate(predicate);

        if (factorResponse) {
            RIntegerVector factor = (RIntegerVector) response;
            result.setScore(factor.getFactorValue(row));

            CategoricalLabel categoricalLabel = (CategoricalLabel) label;
            List<Double> probabilities = FortranMatrixUtil.getRow(prob.getValues(), response.size(), categoricalLabel.size(), row);

            List<ScoreDistribution> scoreDistributions = result.getScoreDistributions();
            for (int i = 0; i < categoricalLabel.size(); i++) {
                Object value = categoricalLabel.getValue(i);
                Double probability = probabilities.get(i);
                scoreDistributions.add(new ScoreDistribution(value, probability.doubleValue()));
            }
        } else {
            result.setScore(response.getValue(row));
        }

        if (kids == null) {
            return result;
        }

        RIntegerVector varid = split.getIntegerElement("varid");
        RDoubleVector breaks = split.getDoubleElement("breaks");
        RIntegerVector index = split.getIntegerElement("index");
        RBooleanVector right = split.getBooleanElement("right");

        // varid is 1-based into the schema's feature list
        Feature feature = features.get(varid.asScalar() - 1);

        if (breaks != null && index == null) {
            // Binary split of a continuous variable at a single break point
            if (!(feature instanceof ContinuousFeature)) {
                throw new IllegalArgumentException("Expected a continuous feature for a breaks-based split");
            }
            if (kids.size() != 2) {
                throw new IllegalArgumentException("Expected 2 children, got " + kids.size());
            }
            if (breaks.size() != 1) {
                throw new IllegalArgumentException("Expected 1 break point, got " + breaks.size());
            }

            Double value = breaks.asScalar();

            Predicate leftPredicate;
            Predicate rightPredicate;
            if (right.asScalar()) {
                // Right-closed intervals: (-inf, value] goes left, (value, +inf) goes right
                leftPredicate = this.createSimplePredicate(feature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = this.createSimplePredicate(feature, SimplePredicate.Operator.GREATER_THAN, value);
            } else {
                // Left-closed intervals: (-inf, value) goes left, [value, +inf) goes right
                leftPredicate = this.createSimplePredicate(feature, SimplePredicate.Operator.LESS_THAN, value);
                rightPredicate = this.createSimplePredicate(feature, SimplePredicate.Operator.GREATER_OR_EQUAL, value);
            }

            Node leftChild = this.encodeNode(leftPredicate, (RGenericVector) kids.getValue(0), response, prob, schema);
            Node rightChild = this.encodeNode(rightPredicate, (RGenericVector) kids.getValue(1), response, prob, schema);
            result.addNodes(leftChild, rightChild);
        } else if (breaks == null && index != null) {
            // Multi-way split of a categorical variable: index maps each level to a 1-based child
            if (!(feature instanceof CategoricalFeature)) {
                throw new IllegalArgumentException("Expected a categorical feature for an index-based split");
            }
            if (kids.size() < 2) {
                throw new IllegalArgumentException("Expected at least 2 children, got " + kids.size());
            }
            // Only right = TRUE is accepted for categorical splits (checked once, before any child)
            if (!right.asScalar()) {
                throw new IllegalArgumentException("Expected right = TRUE for an index-based split");
            }

            List<?> values = ((CategoricalFeature) feature).getValues();
            for (int i = 0; i < kids.size(); i++) {
                Predicate childPredicate = this.createSimpleSetPredicate(feature, PartyConverter.selectValues(values, index, i + 1));
                Node child = this.encodeNode(childPredicate, (RGenericVector) kids.getValue(i), response, prob, schema);
                result.addNodes(child);
            }
        } else {
            throw new IllegalArgumentException("Expected exactly one of breaks or index in the split");
        }
        return result;
    }

    /**
     * Selects the category values whose split index equals {@code flag}.
     *
     * @param values the feature's category values, positionally aligned with {@code index}
     * @param index per-value child assignment
     * @param flag the 1-based child number to select
     * @return the values assigned to child {@code flag}, in their original order
     * @throws IllegalArgumentException if {@code values} and {@code index} differ in length
     */
    private static List<Object> selectValues(List<?> values, RIntegerVector index, int flag) {
        if (values.size() != index.size()) {
            throw new IllegalArgumentException("Expected " + values.size() + " index entries, got " + index.size());
        }
        List<Object> result = new ArrayList<>();
        for (int i = 0; i < values.size(); i++) {
            Integer kid = index.getValue(i);
            // Guard against a missing (null) entry, which assigns the value to no child.
            // NOTE(review): assumes R NA surfaces as null here; confirm against RIntegerVector.
            if (kid != null && kid == flag) {
                result.add(values.get(i));
            }
        }
        return result;
    }
}

