/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.regression;

import java.io.Serializable;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.regression.AdamWUpdater;
import org.apache.spark.ml.regression.BaseFactorizationMachinesGradient;
import org.apache.spark.ml.regression.LogisticFactorizationMachinesGradient;
import org.apache.spark.ml.regression.MSEFactorizationMachinesGradient;
import org.apache.spark.mllib.optimization.SquaredL2Updater;
import org.apache.spark.mllib.optimization.Updater;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.ArrayOps$;
import scala.collection.immutable.Seq;
import scala.package$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.ModuleSerializationProxy;
import scala.runtime.RichInt$;
import scala.runtime.java8.JFunction1;
import scala.runtime.java8.JFunction2;

/**
 * Singleton holder (decompiled Scala companion object) of helpers shared by the
 * factorization machines estimators: solver/loss name constants, factories that turn
 * those names into updater/gradient objects, and utilities for packing and unpacking
 * the flat coefficient vector.
 *
 * <p>Flat coefficient layout, as established by {@link #splitCoefficients} and
 * {@link #combineCoefficients}:
 * <ol>
 *   <li>{@code numFeatures * factorSize} factor values,</li>
 *   <li>{@code numFeatures} linear weights (only when {@code fitLinear}),</li>
 *   <li>one intercept value in the last slot (only when {@code fitIntercept}).</li>
 * </ol>
 */
public final class FactorizationMachines$
implements Serializable {
    public static final FactorizationMachines$ MODULE$ = new FactorizationMachines$();
    private static final String GD = "gd";
    private static final String AdamW = "adamW";
    private static final String[] supportedSolvers = {GD, AdamW};
    private static final String LogisticLoss = "logisticLoss";
    private static final String SquaredError = "squaredError";
    private static final String[] supportedRegressorLosses = {SquaredError};
    private static final String[] supportedClassifierLosses = {LogisticLoss};
    // Regressor losses first, then classifier losses.
    private static final String[] supportedLosses = (String[])ArrayOps$.MODULE$.$plus$plus$extension(Predef$.MODULE$.refArrayOps((Object[])supportedRegressorLosses), (Object)supportedClassifierLosses, ClassTag$.MODULE$.apply(String.class));

    /** @return the name of the gradient descent solver, {@code "gd"} */
    public String GD() {
        return GD;
    }

    /** @return the name of the AdamW solver, {@code "adamW"} */
    public String AdamW() {
        return AdamW;
    }

    /** @return the solver names accepted by {@link #parseSolver}; the backing array is shared, do not mutate */
    public String[] supportedSolvers() {
        return supportedSolvers;
    }

    /** @return the name of the logistic loss, {@code "logisticLoss"} */
    public String LogisticLoss() {
        return LogisticLoss;
    }

    /** @return the name of the squared error loss, {@code "squaredError"} */
    public String SquaredError() {
        return SquaredError;
    }

    /** @return the loss names listed for regressors; the backing array is shared, do not mutate */
    public String[] supportedRegressorLosses() {
        return supportedRegressorLosses;
    }

    /** @return the loss names listed for classifiers; the backing array is shared, do not mutate */
    public String[] supportedClassifierLosses() {
        return supportedClassifierLosses;
    }

    /** @return regressor losses followed by classifier losses; the backing array is shared, do not mutate */
    public String[] supportedLosses() {
        return supportedLosses;
    }

    /**
     * Creates the updater that corresponds to a solver name.
     *
     * @param solver           solver name, either {@link #GD()} or {@link #AdamW()}
     * @param coefficientsSize size handed to the {@code AdamWUpdater} constructor
     * @return a {@code SquaredL2Updater} for {@code "gd"}, an {@code AdamWUpdater} for {@code "adamW"}
     * @throws MatchError if {@code solver} is {@code null} or not a supported name
     */
    public Updater parseSolver(String solver, int coefficientsSize) {
        // The constants are non-null, so constant.equals(arg) is a null-safe comparison.
        if (this.GD().equals(solver)) {
            return new SquaredL2Updater();
        }
        if (this.AdamW().equals(solver)) {
            return new AdamWUpdater(coefficientsSize);
        }
        throw new MatchError(solver);
    }

    /**
     * Creates the gradient implementation that corresponds to a loss name.
     *
     * @param lossFunc     loss name, either {@link #LogisticLoss()} or {@link #SquaredError()}
     * @param factorSize   forwarded unchanged to the gradient constructor
     * @param fitIntercept forwarded unchanged to the gradient constructor
     * @param fitLinear    forwarded unchanged to the gradient constructor
     * @param numFeatures  forwarded unchanged to the gradient constructor
     * @return a logistic or MSE factorization machines gradient
     * @throws IllegalArgumentException if {@code lossFunc} is {@code null} or not a supported name
     */
    public BaseFactorizationMachinesGradient parseLoss(String lossFunc, int factorSize, boolean fitIntercept, boolean fitLinear, int numFeatures) {
        if (this.LogisticLoss().equals(lossFunc)) {
            return new LogisticFactorizationMachinesGradient(factorSize, fitIntercept, fitLinear, numFeatures);
        }
        if (this.SquaredError().equals(lossFunc)) {
            return new MSEFactorizationMachinesGradient(factorSize, fitIntercept, fitLinear, numFeatures);
        }
        throw new IllegalArgumentException("loss function type " + lossFunc + " is invalid");
    }

    /**
     * Splits a flat coefficient vector into intercept, linear weights and factor matrix.
     *
     * @param coefficients flat vector laid out as described on this class
     * @param numFeatures  number of features (rows of the factor matrix)
     * @param factorSize   number of factors per feature (columns of the factor matrix)
     * @param fitIntercept whether the last slot of {@code coefficients} holds the intercept
     * @param fitLinear    whether {@code coefficients} contains a linear segment
     * @return {@code (intercept, linear, factors)}; the intercept is {@code 0.0} when
     *         {@code fitIntercept} is false and the linear vector is an empty sparse vector
     *         of size {@code numFeatures} when {@code fitLinear} is false
     * @throws IllegalArgumentException (via {@code Predef.require}) if the size of
     *         {@code coefficients} does not match the size implied by the other arguments
     */
    public Tuple3<Object, Vector, Matrix> splitCoefficients(Vector coefficients, int numFeatures, int factorSize, boolean fitIntercept, boolean fitLinear) {
        int numFactorValues = numFeatures * factorSize;
        int coefficientsSize = numFactorValues + (fitLinear ? numFeatures : 0) + (fitIntercept ? 1 : 0);
        Predef$.MODULE$.require(coefficientsSize == coefficients.size(), (Function0 & Serializable)() -> "coefficients.size did not match the expected size " + coefficientsSize);

        double intercept = fitIntercept ? coefficients.apply(coefficients.size() - 1) : 0.0;

        // Materialise the values once; both slices below read from the same array.
        double[] values = coefficients.toArray();

        // Declared as Vector because the two branches produce different concrete types.
        Vector linear;
        if (fitLinear) {
            linear = new DenseVector((double[])ArrayOps$.MODULE$.slice$extension(Predef$.MODULE$.doubleArrayOps(values), numFactorValues, numFactorValues + numFeatures));
        } else {
            linear = Vectors$.MODULE$.sparse(numFeatures, (Seq)package$.MODULE$.Seq().empty());
        }

        // NOTE(review): the trailing `true` is DenseMatrix's fourth constructor argument,
        // presumably isTransposed (row-major values) -- confirm against DenseMatrix.
        DenseMatrix factors = new DenseMatrix(numFeatures, factorSize, (double[])ArrayOps$.MODULE$.slice$extension(Predef$.MODULE$.doubleArrayOps(values), 0, numFactorValues), true);

        return new Tuple3<Object, Vector, Matrix>(BoxesRunTime.boxToDouble(intercept), linear, factors);
    }

    /**
     * Packs intercept, linear weights and factor matrix into one flat dense vector; the
     * inverse direction of {@link #splitCoefficients}.
     *
     * @param intercept    intercept value, written to the last slot when {@code fitIntercept}
     * @param linear       linear weights, only read when {@code fitLinear}
     * @param factors      factor matrix; its {@code toDense().values()} form the first segment
     * @param fitIntercept whether to append the intercept
     * @param fitLinear    whether to include the linear weights
     * @return a new {@code DenseVector} holding factor values, then linear weights, then intercept
     */
    public Vector combineCoefficients(double intercept, Vector linear, Matrix factors, boolean fitIntercept, boolean fitLinear) {
        double[] factorValues = factors.toDense().values();
        double[] linearValues = fitLinear ? linear.toArray() : new double[0];
        int interceptSlots = fitIntercept ? 1 : 0;

        // Single allocation instead of two successive array concatenations.
        double[] coefficients = new double[factorValues.length + linearValues.length + interceptSlots];
        System.arraycopy(factorValues, 0, coefficients, 0, factorValues.length);
        System.arraycopy(linearValues, 0, coefficients, factorValues.length, linearValues.length);
        if (fitIntercept) {
            coefficients[coefficients.length - 1] = intercept;
        }
        return new DenseVector(coefficients);
    }

    /**
     * Computes the raw factorization machines prediction
     * {@code intercept + features.dot(linear) + sum over factor columns f of
     * 0.5 * ((sum_i v[i][f] * x[i])^2 - sum_i (v[i][f] * x[i])^2)},
     * where {@code i} ranges over the non-zero entries of {@code features}.
     *
     * @param features  feature vector {@code x}
     * @param intercept intercept term
     * @param linear    linear weights
     * @param factors   factor matrix {@code v}, indexed as {@code (featureIndex, factorIndex)}
     * @return the raw (untransformed) prediction
     */
    public double getRawPrediction(Vector features, double intercept, Vector linear, Matrix factors) {
        double rawPrediction = intercept + features.dot(linear);
        int numFactors = factors.numCols();
        for (int f = 0; f < numFactors; ++f) {
            // Lambdas may only capture effectively final locals, hence the copy and the refs.
            final int factorIndex = f;
            DoubleRef sumSquare = DoubleRef.create(0.0);
            DoubleRef sum = DoubleRef.create(0.0);
            features.foreachNonZero((Function2)(JFunction2.mcVID.sp & Serializable)(index, value) -> {
                double vx = factors.apply(index, factorIndex) * value;
                sumSquare.elem += vx * vx;
                sum.elem += vx;
            });
            rawPrediction += 0.5 * (sum.elem * sum.elem - sumSquare.elem);
        }
        return rawPrediction;
    }

    /** Serialization hook: writes a proxy so deserialization resolves back to {@link #MODULE$}. */
    private Object writeReplace() {
        return new ModuleSerializationProxy(FactorizationMachines$.class);
    }

    /** Singleton: the only instance is {@link #MODULE$}. */
    private FactorizationMachines$() {
    }
}

