/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.query.aggregation.variance.sql;

import com.google.common.collect.ImmutableList;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import org.apache.druid.sql.calcite.table.RowSignatures;

/**
 * Shared SQL planning logic for the variance / standard-deviation family of aggregate functions.
 *
 * <p>Each nested subclass only selects which {@link SqlAggFunction} it exposes through
 * {@code calciteFunction()}; the translation into a {@link VarianceAggregatorFactory} (plus a
 * {@link StandardDeviationPostAggregator} for the standard-deviation functions) lives here.
 */
public abstract class BaseVarianceSqlAggregator
implements SqlAggregator {
    private static final String VARIANCE_NAME = "VARIANCE";
    private static final String STDDEV_NAME = "STDDEV";
    private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE = buildSqlVarianceAggFunction(VARIANCE_NAME);
    private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE = buildSqlVarianceAggFunction(SqlKind.VAR_POP.name());
    private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE = buildSqlVarianceAggFunction(SqlKind.VAR_SAMP.name());
    private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE = buildSqlVarianceAggFunction(STDDEV_NAME);
    private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE = buildSqlVarianceAggFunction(SqlKind.STDDEV_POP.name());
    private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE = buildSqlVarianceAggFunction(SqlKind.STDDEV_SAMP.name());

    /**
     * Translates the aggregate call into a variance aggregation.
     *
     * <p>The single argument of the call is converted to a Druid expression. A
     * {@link VarianceAggregatorFactory} is created over either the directly extracted column or a
     * virtual column registered for the expression. For the standard-deviation functions the
     * factory is named {@code "<name>:agg"} and a {@link StandardDeviationPostAggregator} named
     * {@code name} is attached on top of it.
     *
     * @param plannerContext        planner context used to convert the operand
     * @param virtualColumnRegistry registry that receives a virtual column when the operand is not
     *                              a simple extraction
     * @param name                  output name of the aggregation
     * @param aggregateCall         the call being translated; only its first argument is read
     * @param inputAccessor         gives access to the input fields and the input row signature
     * @param existingAggregations  not used by this implementation
     * @param finalizeAggregations  not used by this implementation
     * @return the aggregation, or {@code null} when the operand cannot be converted to a Druid
     *         expression
     * @throws IAE if the operand's column type is unknown, or is neither numeric nor the variance
     *             complex type
     */
    @Nullable
    @Override
    public Aggregation toDruidAggregation(PlannerContext plannerContext, VirtualColumnRegistry virtualColumnRegistry, String name, AggregateCall aggregateCall, InputAccessor inputAccessor, List<Aggregation> existingAggregations, boolean finalizeAggregations) {
        final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0));
        final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(plannerContext, inputAccessor.getInputRowSignature(), inputOperand);
        if (input == null) {
            return null;
        }

        final SqlAggFunction func = calciteFunction();
        final String funcName = func.getName();
        final RelDataType dataType = inputOperand.getType();
        final ColumnType inputType = Calcites.getColumnTypeForRelDataType(dataType);

        // Validate the type BEFORE it is handed to any dimension spec and before the registry is
        // touched: a rejected call must raise the intended IAE and must not leave a virtual column
        // behind.
        final String inputTypeName = resolveInputTypeName(func, inputType);

        final DimensionSpec dimensionSpec;
        if (input.isSimpleExtraction()) {
            dimensionSpec = input.getSimpleExtraction().toDimensionSpec(null, inputType);
        } else {
            final String virtualColumnName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(input, dataType);
            dimensionSpec = new DefaultDimensionSpec(virtualColumnName, null, inputType);
        }

        final boolean needsPostAggregator = isStandardDeviation(funcName);
        // Standard deviation is computed by a post-aggregator that takes the user-visible name, so
        // the underlying variance aggregator gets a derived name.
        final String aggName = needsPostAggregator ? StringUtils.format("%s:agg", name) : name;
        final String estimator = estimatorFor(funcName);

        final VarianceAggregatorFactory aggregatorFactory = new VarianceAggregatorFactory(aggName, dimensionSpec.getDimension(), estimator, inputTypeName);
        final StandardDeviationPostAggregator postAggregator = needsPostAggregator
            ? new StandardDeviationPostAggregator(name, aggregatorFactory.getName(), estimator)
            : null;
        return Aggregation.create(ImmutableList.of(aggregatorFactory), postAggregator);
    }

    /** Returns the input type name passed to the aggregator factory, or throws if the type is unsupported. */
    private static String resolveInputTypeName(SqlAggFunction func, @Nullable ColumnType inputType) {
        if (inputType == null) {
            throw new IAE("VarianceSqlAggregator[%s] has invalid inputType", func);
        }
        if (inputType.isNumeric()) {
            return StringUtils.toLowerCase(inputType.getType().name());
        }
        if (inputType.equals(VarianceAggregatorFactory.TYPE)) {
            return "variance";
        }
        throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType.asTypeString());
    }

    /** Tells whether the function name is one of STDDEV, STDDEV_POP or STDDEV_SAMP. */
    private static boolean isStandardDeviation(String funcName) {
        return funcName.equals(STDDEV_NAME)
            || funcName.equals(SqlKind.STDDEV_POP.name())
            || funcName.equals(SqlKind.STDDEV_SAMP.name());
    }

    /** Picks "population" for the *_POP functions and "sample" for every other function. */
    private static String estimatorFor(String funcName) {
        final boolean population = funcName.equals(SqlKind.VAR_POP.name()) || funcName.equals(SqlKind.STDDEV_POP.name());
        return population ? "population" : "sample";
    }

    /**
     * Builds an aggregate function with the given name that returns DOUBLE and accepts either a
     * numeric operand or an operand of the variance complex type.
     */
    private static SqlAggFunction buildSqlVarianceAggFunction(String name) {
        return OperatorConversions.aggregatorBuilder(name)
            .returnTypeInference(ReturnTypes.explicit(SqlTypeName.DOUBLE))
            .operandTypeChecker(OperandTypes.or(OperandTypes.NUMERIC, RowSignatures.complexTypeChecker(VarianceAggregatorFactory.TYPE)))
            .functionCategory(SqlFunctionCategory.NUMERIC)
            .build();
    }

    /** SQL aggregator for {@code STDDEV}; uses the "sample" estimator. */
    public static class StdDevSqlAggregator
    extends BaseVarianceSqlAggregator {
        @Override
        public SqlAggFunction calciteFunction() {
            return STDDEV_SQL_AGG_FUNC_INSTANCE;
        }
    }

    /** SQL aggregator for {@code STDDEV_SAMP}; uses the "sample" estimator. */
    public static class StdDevSampSqlAggregator
    extends BaseVarianceSqlAggregator {
        @Override
        public SqlAggFunction calciteFunction() {
            return STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE;
        }
    }

    /** SQL aggregator for {@code STDDEV_POP}; uses the "population" estimator. */
    public static class StdDevPopSqlAggregator
    extends BaseVarianceSqlAggregator {
        @Override
        public SqlAggFunction calciteFunction() {
            return STDDEV_POP_SQL_AGG_FUNC_INSTANCE;
        }
    }

    /** SQL aggregator for {@code VARIANCE}; uses the "sample" estimator. */
    public static class VarianceSqlAggregator
    extends BaseVarianceSqlAggregator {
        @Override
        public SqlAggFunction calciteFunction() {
            return VARIANCE_SQL_AGG_FUNC_INSTANCE;
        }
    }

    /** SQL aggregator for {@code VAR_SAMP}; uses the "sample" estimator. */
    public static class VarSampSqlAggregator
    extends BaseVarianceSqlAggregator {
        @Override
        public SqlAggFunction calciteFunction() {
            return VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE;
        }
    }

    /** SQL aggregator for {@code VAR_POP}; uses the "population" estimator. */
    public static class VarPopSqlAggregator
    extends BaseVarianceSqlAggregator {
        @Override
        public SqlAggFunction calciteFunction() {
            return VARIANCE_POP_SQL_AGG_FUNC_INSTANCE;
        }
    }
}

