/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.neural_network.NeuralEntity;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.neural_network.NeuralNetworkUtil;
import org.jpmml.rexp.Formula;
import org.jpmml.rexp.FormulaUtil;
import org.jpmml.rexp.ModelConverter;
import org.jpmml.rexp.RBooleanVector;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExp;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RStringVector;
import org.jpmml.rexp.XLevelsFormulaContext;

/**
 * Converts an R {@code nnet} object (a single-hidden-layer feed-forward network) into a PMML
 * {@link NeuralNetwork}.
 *
 * <p>The object's {@code n} element holds the three layer sizes (inputs, hidden units, outputs).
 * The {@code wts} element holds the flattened weight vector. Each neuron, in layer order,
 * consumes one bias followed by one weight per entity of the preceding layer.
 */
public class NNetConverter
extends ModelConverter<RGenericVector> {
    public NNetConverter(RGenericVector nnet) {
        super(nnet);
    }

    /**
     * Builds the label and the input features from the model formula stored in the
     * {@code nnet} object ({@code terms}, {@code xlevels}, {@code coefnames}, optional {@code lev}).
     *
     * @param encoder the encoder that receives the label and feature definitions
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector nnet = getObject();

        RStringVector lev = nnet.getStringElement("lev", false);
        RExp terms = nnet.getElement("terms");
        RGenericVector xlevels = nnet.getGenericElement("xlevels");
        RStringVector coefnames = nnet.getStringElement("coefnames");

        XLevelsFormulaContext context = new XLevelsFormulaContext(xlevels);

        Formula formula = FormulaUtil.createFormula(terms, context, encoder);

        FormulaUtil.setLabel(formula, terms, lev, encoder);
        FormulaUtil.addFeatures(formula, coefnames, true, encoder);
    }

    /**
     * Encodes the network as a PMML {@link NeuralNetwork}.
     *
     * <p>The mining function is regression when the object has no {@code lev} element and
     * classification otherwise. The optional hidden layer uses the logistic activation. The
     * output layer uses the network-level identity activation. A single-output classifier gets
     * an extra binary logistic transformation. A multi-output layer is softmax-normalized when
     * {@code softmax} is set.
     *
     * @param schema the schema produced by {@link #encodeSchema(RExpEncoder)}
     * @return the encoded neural network model
     * @throws IllegalArgumentException if the layer sizes, the output configuration or the
     *         length of the weight vector are not supported
     */
    @Override
    public Model encodeModel(Schema schema) {
        RGenericVector nnet = getObject();

        RDoubleVector n = nnet.getDoubleElement("n");
        RBooleanVector linout = nnet.getBooleanElement("linout", false);
        RBooleanVector softmax = nnet.getBooleanElement("softmax", false);
        RBooleanVector censored = nnet.getBooleanElement("censored", false);
        RDoubleVector wts = nnet.getDoubleElement("wts");
        RStringVector lev = nnet.getStringElement("lev", false);

        if (n.size() != 3) {
            throw new IllegalArgumentException("Expected 3 layer sizes, got " + n.size());
        }

        Label label = schema.getLabel();

        MiningFunction miningFunction;
        if (lev == null) {
            // Without class levels the response is numeric, which is only supported with a linear output unit
            if (linout != null && !linout.asScalar()) {
                throw new IllegalArgumentException("Regression networks must have a linear output (linout = TRUE)");
            }
            miningFunction = MiningFunction.REGRESSION;
        } else {
            miningFunction = MiningFunction.CLASSIFICATION;
        }

        int nInput = ValueUtil.asInt(n.getValue(0));
        int nHidden = ValueUtil.asInt(n.getValue(1));
        int nOutput = ValueUtil.asInt(n.getValue(2));

        SchemaUtil.checkSize(nInput, schema.getFeatures());

        if (nOutput < 1) {
            throw new IllegalArgumentException("Expected at least 1 output unit, got " + nOutput);
        }

        // Fail fast when the weight vector does not match the plain layered topology described by n.
        // Networks with additional connections (presumably e.g. skip-layer units) carry extra weights
        // that this converter does not map, and would otherwise be misread silently.
        int expectedWeights = (nHidden > 0)
            ? nHidden * (nInput + 1) + nOutput * (nHidden + 1)
            : nOutput * (nInput + 1);
        if (wts.size() != expectedWeights) {
            throw new IllegalArgumentException("Expected " + expectedWeights + " weights, got " + wts.size());
        }

        NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(schema.getFeatures(), DataType.DOUBLE);

        List<NeuralLayer> neuralLayers = new ArrayList<>();

        // The entities feeding the next layer; starts with the inputs and advances layer by layer
        List<? extends NeuralEntity> entities = neuralInputs.getNeuralInputs();

        int offset = 0;

        if (nHidden > 0) {
            NeuralLayer hiddenLayer = encodeNeuralLayer("hidden", nHidden, entities, wts, offset)
                .setActivationFunction(NeuralNetwork.ActivationFunction.LOGISTIC);

            offset += nHidden * (entities.size() + 1);

            neuralLayers.add(hiddenLayer);

            entities = hiddenLayer.getNeurons();
        }

        NeuralLayer outputLayer = encodeNeuralLayer("output", nOutput, entities, wts, offset);

        neuralLayers.add(outputLayer);

        entities = outputLayer.getNeurons();

        if (nOutput == 1) {
            // A single raw output of a classifier is mapped to a probability by an explicit logistic transformation
            if (miningFunction == MiningFunction.CLASSIFICATION) {
                List<NeuralLayer> transformationLayers = NeuralNetworkUtil.createBinaryLogisticTransformation(Iterables.getOnlyElement(entities));

                neuralLayers.addAll(transformationLayers);

                entities = Iterables.getLast(transformationLayers).getNeurons();
            }
        } else {
            if (softmax != null && softmax.asScalar()) {
                if (censored != null && censored.asScalar()) {
                    throw new IllegalArgumentException("Censored softmax networks are not supported");
                }

                outputLayer.setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
            }
        }

        NeuralNetwork neuralNetwork = new NeuralNetwork(miningFunction, NeuralNetwork.ActivationFunction.IDENTITY, ModelUtil.createMiningSchema(label), neuralInputs, neuralLayers);

        switch (miningFunction) {
            case REGRESSION:
                neuralNetwork.setNeuralOutputs(NeuralNetworkUtil.createRegressionNeuralOutputs(entities, (ContinuousLabel) label));
                break;
            case CLASSIFICATION:
                neuralNetwork.setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(entities, (CategoricalLabel) label))
                    .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, (CategoricalLabel) label));
                break;
            default:
                throw new IllegalArgumentException();
        }

        return neuralNetwork;
    }

    /**
     * Creates a fully connected layer of {@code n} neurons fed by {@code entities}.
     *
     * <p>Starting at {@code offset}, each neuron reads a contiguous run of
     * {@code entities.size() + 1} values from {@code wts}: the bias first, then one weight per
     * entity in order. Neuron ids are {@code prefix + "/" + k} with {@code k} counting from 1.
     *
     * @param prefix the id prefix of the created neurons
     * @param n the number of neurons
     * @param entities the entities of the preceding layer
     * @param wts the flattened weight vector
     * @param offset the index in {@code wts} of the first neuron's bias
     * @return the new layer
     */
    private static NeuralLayer encodeNeuralLayer(String prefix, int n, List<? extends NeuralEntity> entities, RDoubleVector wts, int offset) {
        NeuralLayer neuralLayer = new NeuralLayer();

        int stride = entities.size() + 1;

        List<Double> values = wts.getValues();

        for (int i = 0; i < n; i++) {
            int start = offset + i * stride;

            Double bias = wts.getValue(start);
            List<Double> weights = values.subList(start + 1, start + stride);

            Neuron neuron = NeuralNetworkUtil.createNeuron(entities, weights, bias)
                .setId(prefix + "/" + (i + 1));

            neuralLayer.addNeurons(neuron);
        }

        return neuralLayer;
    }
}

