/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.translator;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.adapter.druid.DruidQuery;
import org.apache.calcite.adapter.jdbc.JdbcConvention;
import org.apache.calcite.adapter.jdbc.JdbcRel;
import org.apache.calcite.adapter.jdbc.JdbcRules;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.Spool;
import org.apache.calcite.rel.core.Window;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Pair;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelShuttleImpl;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAntiJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSemiJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortExchange;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortLimit;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableFunctionScan;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveValues;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.jdbc.HiveJdbcConverter;
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveRelColumnsAlignment;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.PlanModifierUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.SqlFunctionConverter;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Rewrites a Calcite plan into a shape suitable for AST conversion.
 *
 * <p>Wherever an operator cannot sit directly under its parent, it is wrapped in an identity
 * {@link HiveProject} (a "derived table"). The plan is mostly modified in place via
 * {@link RelNode#replaceInput}. The top-level entry point also disambiguates self-joins,
 * optionally aligns columns, and renames the top-level select to match the expected result
 * schema.
 */
public class PlanModifierForASTConv {
    private static final Logger LOG = LoggerFactory.getLogger(PlanModifierForASTConv.class);

    /**
     * Prepares an optimized plan for AST conversion.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>introduce a derived table at the root unless it is a Project, Sort or Exchange;</li>
     *   <li>recursively fix nested operators;</li>
     *   <li>disambiguate self-joins;</li>
     *   <li>optionally align columns;</li>
     *   <li>fix the top ORDER BY schema;</li>
     *   <li>rename the top-level select after {@code resultSchema}.</li>
     * </ol>
     *
     * @param rel root of the plan to rewrite
     * @param resultSchema expected output columns; their names become the top-level aliases
     * @param alignColumns whether to run {@link HiveRelColumnsAlignment} over the plan
     * @return the (possibly new) root of the rewritten plan
     * @throws CalciteSemanticException if the top-level select does not match {@code resultSchema}
     */
    public static RelNode convertOpTree(RelNode rel, List<FieldSchema> resultSchema, boolean alignColumns)
            throws CalciteSemanticException {
        if (rel instanceof HiveValues) {
            // A VALUES root only needs its field names replaced by the result schema's names.
            List<String> fieldNames = resultSchema.stream().map(FieldSchema::getName).collect(Collectors.toList());
            return ((HiveValues) rel).copy(fieldNames);
        }
        RelNode newTopNode = rel;
        logPlan("Original plan for PlanModifier", newTopNode);

        if (!(newTopNode instanceof Project || newTopNode instanceof Sort || newTopNode instanceof Exchange)) {
            newTopNode = introduceDerivedTable(newTopNode);
            logPlan("Plan after top-level introduceDerivedTable", newTopNode);
        }

        // The root is guaranteed to have no parent, so the nested pass starts with parent == null.
        convertOpTree(newTopNode, null);
        logPlan("Plan after nested convertOpTree", newTopNode);

        newTopNode = newTopNode.accept(new SelfJoinHandler());
        logPlan("Plan after self-join disambiguation", newTopNode);

        if (alignColumns) {
            HiveRelColumnsAlignment propagator =
                    new HiveRelColumnsAlignment(HiveRelFactories.HIVE_BUILDER.create(newTopNode.getCluster(), null));
            newTopNode = propagator.align(newTopNode);
            logPlan("Plan after propagating order", newTopNode);
        }

        Pair<RelNode, RelNode> topSelparentPair = HiveCalciteUtil.getTopLevelSelect(newTopNode);
        PlanModifierUtil.fixTopOBSchema(newTopNode, topSelparentPair, resultSchema, true);
        logPlan("Plan after fixTopOBSchema", newTopNode);

        // fixTopOBSchema may have changed the tree, so look up the top-level select again.
        topSelparentPair = HiveCalciteUtil.getTopLevelSelect(newTopNode);
        newTopNode = renameTopLevelSelectInResultSchema(newTopNode, topSelparentPair, resultSchema);
        logPlan("Final plan after modifier", newTopNode);
        return newTopNode;
    }

    /**
     * Logs {@code plan} at DEBUG level under {@code header}.
     *
     * <p>Rendering the plan is skipped unless DEBUG is enabled.
     */
    private static void logPlan(String header, RelNode plan) {
        if (LOG.isDebugEnabled()) {
            LOG.debug(header + "\n " + RelOptUtil.toString(plan));
        }
    }

    /**
     * Recursively walks {@code rel} and its inputs, inserting derived tables wherever {@code rel}
     * is not valid under {@code parent}.
     *
     * <p>The plan is modified in place through {@link RelNode#replaceInput}.
     *
     * @param rel the node being visited
     * @param parent the node that has {@code rel} as an input, or {@code null} for the root
     * @throws RuntimeException if the plan contains a {@link HepRelVertex}, {@link MultiJoin}
     *     or {@link RelSubset}
     */
    private static void convertOpTree(RelNode rel, RelNode parent) {
        if (rel instanceof HepRelVertex) {
            throw new RuntimeException("Found HepRelVertex");
        }
        if (rel instanceof Join) {
            if (!validJoinParent(rel, parent)) {
                introduceDerivedTable(rel, parent);
            }
        } else if (rel instanceof MultiJoin) {
            throw new RuntimeException("Found MultiJoin");
        } else if (rel instanceof RelSubset) {
            throw new RuntimeException("Found RelSubset");
        } else if (rel instanceof SetOp) {
            if (!validSetopParent(rel, parent)) {
                introduceDerivedTable(rel, parent);
            }
            SetOp setop = (SetOp) rel;
            // Iterate over a snapshot: replaceInput swaps setop's input list, not this one.
            for (RelNode inputRel : setop.getInputs()) {
                if (!validSetopChild(inputRel)) {
                    introduceDerivedTable(inputRel, setop);
                }
            }
        } else if (rel instanceof HiveTableFunctionScan) {
            if (!validTableFunctionScanChild((HiveTableFunctionScan) rel)) {
                introduceDerivedTable(rel.getInput(0), rel);
            }
        } else if (rel instanceof SingleRel) {
            if (rel instanceof HiveJdbcConverter) {
                // A JDBC converter is always wrapped, regardless of its parent.
                introduceDerivedTable(rel, parent);
            } else if (rel instanceof Filter) {
                if (!validFilterParent(rel, parent)) {
                    introduceDerivedTable(rel, parent);
                }
            } else if (rel instanceof HiveSortLimit) {
                HiveSortLimit sortLimit = (HiveSortLimit) rel;
                if (!validSortParent(rel, parent)) {
                    introduceDerivedTable(rel, parent);
                }
                if (!validSortChild(sortLimit)) {
                    introduceDerivedTable(sortLimit.getInput(), rel);
                }
            } else if (rel instanceof HiveSortExchange) {
                HiveSortExchange sortExchange = (HiveSortExchange) rel;
                if (!validExchangeChild(sortExchange)) {
                    introduceDerivedTable(sortExchange.getInput(), rel);
                }
            } else if (rel instanceof HiveAggregate) {
                RelNode newParent = parent;
                if (!validGBParent(rel, parent)) {
                    newParent = introduceDerivedTable(rel, parent);
                }
                if (isEmptyGrpAggr(rel)) {
                    // The aggregate has been replaced as input 0 of newParent; continue the walk
                    // from the replacement instead of the detached original.
                    replaceEmptyGroupAggr(rel, newParent);
                    rel = newParent.getInputs().get(0);
                }
            } else if (rel instanceof Spool) {
                Spool spool = (Spool) rel;
                // Re-project the spool's input with the field names of the spool's table.
                RelBuilder builder = HiveRelFactories.HIVE_BUILDER.create(spool.getCluster(), null);
                builder.push(spool.getInput());
                builder.project(builder.fields(), spool.getTable().getRowType().getFieldNames(), true);
                spool.replaceInput(0, builder.build());
            }
        }

        List<RelNode> childNodes = rel.getInputs();
        if (childNodes != null) {
            for (RelNode child : childNodes) {
                convertOpTree(child, rel);
            }
        }
    }

    /**
     * Replaces the top-level {@link HiveProject} with a copy whose aliases are taken from
     * {@code resultSchema}.
     *
     * <p>Duplicate names are made unique by appending {@code _1}, {@code _2}, ...
     *
     * @param rootRel root of the plan
     * @param topSelparentPair (parent, top-level project) pair; the value must be a {@link HiveProject}
     * @param resultSchema expected output columns, one per project expression
     * @return {@code rootRel}, or the replacement project if the original project was the root
     * @throws CalciteSemanticException if the number of project expressions differs from the
     *     size of {@code resultSchema}
     */
    public static RelNode renameTopLevelSelectInResultSchema(RelNode rootRel, Pair<RelNode, RelNode> topSelparentPair,
            List<FieldSchema> resultSchema) throws CalciteSemanticException {
        RelNode parentOforiginalProjRel = topSelparentPair.getKey();
        HiveProject originalProjRel = (HiveProject) topSelparentPair.getValue();
        List<RexNode> rootChildExps = originalProjRel.getProjects();
        if (resultSchema.size() != rootChildExps.size()) {
            LOG.error(PlanModifierUtil.generateInvalidSchemaMessage(originalProjRel, resultSchema, 0));
            throw new CalciteSemanticException("Result Schema didn't match Optimized Op Tree Schema");
        }

        // The ordered list feeds the new project; the set makes collision checks O(1).
        List<String> newSelAliases = new ArrayList<>(rootChildExps.size());
        Set<String> usedAliases = new HashSet<>();
        for (int i = 0; i < rootChildExps.size(); ++i) {
            String colAlias = getNewColAlias(usedAliases, resultSchema.get(i).getName());
            newSelAliases.add(colAlias);
            usedAliases.add(colAlias);
        }

        HiveProject replacementProjectRel =
                HiveProject.create(originalProjRel.getInput(), originalProjRel.getProjects(), newSelAliases);
        if (rootRel == originalProjRel) {
            return replacementProjectRel;
        }
        // Assumes the top-level project is input 0 of its parent.
        parentOforiginalProjRel.replaceInput(0, replacementProjectRel);
        return rootRel;
    }

    /**
     * Returns {@code colAlias} if it is not already used, otherwise the first of
     * {@code colAlias_1}, {@code colAlias_2}, ... that is not in {@code usedAliases}.
     */
    private static String getNewColAlias(Set<String> usedAliases, String colAlias) {
        int index = 1;
        String newColAlias = colAlias;
        while (usedAliases.contains(newColAlias)) {
            newColAlias = colAlias + "_" + index++;
        }
        return newColAlias;
    }

    /**
     * Wraps {@code rel} in an identity project (a derived table) that has the same row type.
     *
     * <p>If {@code rel} is a {@link JdbcRel}, the project is converted into {@code rel}'s JDBC
     * convention (presumably to keep the subtree in that convention; verify against callers).
     *
     * @return the new project node; it is not attached to any parent
     */
    private static RelNode introduceDerivedTable(RelNode rel) {
        List<RexNode> projectList = HiveCalciteUtil.getProjsFromBelowAsInputRef(rel);
        RelNode select = HiveProject.create(rel.getCluster(), rel, projectList, rel.getRowType(),
                Collections.emptyList());
        if (rel instanceof JdbcRel) {
            select = JdbcRules.JdbcProjectRule.create((JdbcConvention) rel.getConvention()).convert(select);
        }
        return select;
    }

    /**
     * Wraps {@code rel} in a derived table and puts the wrapper in {@code rel}'s slot among
     * {@code parent}'s inputs.
     *
     * @return the inserted wrapper node
     * @throws RuntimeException if {@code rel} is not (by identity) an input of {@code parent}
     */
    private static RelNode introduceDerivedTable(RelNode rel, RelNode parent) {
        List<RelNode> childList = parent.getInputs();
        int pos = -1;
        for (int i = 0; i < childList.size(); ++i) {
            // Identity, not equals(): sibling inputs may be structurally equal.
            if (childList.get(i) == rel) {
                pos = i;
                break;
            }
        }
        if (pos == -1) {
            throw new RuntimeException("Couldn't find child node in parent's inputs");
        }
        RelNode select = introduceDerivedTable(rel);
        parent.replaceInput(pos, select);
        return select;
    }

    /**
     * Returns whether {@code joinNode} may sit directly under {@code parent}.
     *
     * <p>It may not when {@code parent} is a {@link SetOp}. It also may not when {@code joinNode}
     * is the right input of a parent {@link Join} and either the parent's left input is also a
     * join or the parent is a {@link HiveSemiJoin}/{@link HiveAntiJoin}.
     */
    private static boolean validJoinParent(RelNode joinNode, RelNode parent) {
        if (parent instanceof Join) {
            Join parentJoin = (Join) parent;
            boolean isRightInput = parentJoin.getRight() == joinNode;
            boolean needsWrap = parentJoin.getLeft() instanceof Join
                    || parent instanceof HiveSemiJoin
                    || parent instanceof HiveAntiJoin;
            return !(isRightInput && needsWrap);
        }
        return !(parent instanceof SetOp);
    }

    /**
     * Returns whether {@code filterNode} may sit directly under {@code parent}.
     *
     * <p>It may not under a Filter, Join or SetOp, nor under an Aggregate when the filter's own
     * input is also an Aggregate.
     */
    private static boolean validFilterParent(RelNode filterNode, RelNode parent) {
        return !(parent instanceof Filter
                || parent instanceof Join
                || parent instanceof SetOp
                || (parent instanceof Aggregate && filterNode.getInputs().get(0) instanceof Aggregate));
    }

    /**
     * Returns whether the aggregate {@code gbNode} may sit directly under {@code parent}.
     *
     * <p>It may not under a Join, SetOp or Aggregate, nor under a Filter when it has an empty
     * group set. A parent Project that has a top-level windowed expression ({@link RexOver} or
     * {@link Window.RexWinAggCall}) among its projects is also treated as invalid.
     */
    private static boolean validGBParent(RelNode gbNode, RelNode parent) {
        if (parent instanceof Project) {
            for (RexNode child : ((Project) parent).getProjects()) {
                if (child instanceof RexOver || child instanceof Window.RexWinAggCall) {
                    return false;
                }
            }
        }
        return !(parent instanceof Join
                || parent instanceof SetOp
                || parent instanceof Aggregate
                || (parent instanceof Filter && ((Aggregate) gbNode).getGroupSet().isEmpty()));
    }

    /**
     * Returns whether the sort {@code sortNode} may sit directly under {@code parent}.
     *
     * <p>It may when it is the root, under a Project, or when a pure-limit parent sits on a
     * pure-order sort.
     */
    private static boolean validSortParent(RelNode sortNode, RelNode parent) {
        return parent == null
                || parent instanceof Project
                || (HiveCalciteUtil.pureLimitRelNode(parent) && HiveCalciteUtil.pureOrderRelNode(sortNode));
    }

    /**
     * Returns whether the input of {@code sortNode} may stay as is.
     *
     * <p>It may when the input is a Project, or when a pure-limit sort sits on a pure-order input.
     */
    private static boolean validSortChild(HiveSortLimit sortNode) {
        RelNode child = sortNode.getInput();
        return child instanceof Project
                || (HiveCalciteUtil.pureLimitRelNode(sortNode) && HiveCalciteUtil.pureOrderRelNode(child));
    }

    /** Returns whether the input of a sort-exchange is a Project. */
    private static boolean validExchangeChild(HiveSortExchange sortNode) {
        return sortNode.getInput() instanceof Project;
    }

    /** Returns whether a table function scan has exactly one input that is a Project or table scan. */
    private static boolean validTableFunctionScanChild(HiveTableFunctionScan htfsNode) {
        return htfsNode.getInputs().size() == 1
                && (htfsNode.getInput(0) instanceof Project || htfsNode.getInput(0) instanceof HiveTableScan);
    }

    /** Returns whether a set operation may sit under {@code parent}: only at the root or under a Project. */
    private static boolean validSetopParent(RelNode setop, RelNode parent) {
        return parent == null || parent instanceof Project;
    }

    /** Returns whether an input of a set operation is a Project. */
    private static boolean validSetopChild(RelNode setopChild) {
        return setopChild instanceof Project;
    }

    /** Returns whether the aggregate has neither grouping keys nor aggregate calls. */
    private static boolean isEmptyGrpAggr(RelNode gbNode) {
        Aggregate aggrnode = (Aggregate) gbNode;
        return aggrnode.getGroupSet().isEmpty() && aggrnode.getAggCallList().isEmpty();
    }

    /**
     * Replaces an empty aggregate (no keys, no calls) under {@code parent} with a derived table.
     *
     * <p>The derived table wraps a copy of the aggregate that carries a single dummy
     * {@code count} call over input column 0.
     *
     * @param rel the empty {@link HiveAggregate}; it must be input 0 of {@code parent}
     * @param parent its parent; if it is a Project, all of its expressions must be constants
     * @throws RuntimeException if a parent Project expression is not a constant
     */
    private static void replaceEmptyGroupAggr(RelNode rel, RelNode parent) {
        List<RexNode> exps = parent instanceof Project
                ? ((Project) parent).getProjects()
                : Collections.<RexNode>emptyList();
        for (RexNode rexNode : exps) {
            if (!rexNode.accept(new HiveCalciteUtil.ConstantFinder())) {
                throw new RuntimeException("We expect " + parent.toString() + " to contain only constants. However, "
                        + rexNode.toString() + " is " + String.valueOf(rexNode.getKind()));
            }
        }

        HiveAggregate oldAggRel = (HiveAggregate) rel;
        RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        RelDataType longType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, typeFactory);
        RelDataType intType = TypeConverter.convert(TypeInfoFactory.intTypeInfo, typeFactory);
        SqlAggFunction countFn =
                SqlFunctionConverter.getCalciteAggFn("count", false, ImmutableList.of(intType), longType);
        // NOTE(review): column 0 is used as the dummy argument without checking the input's row type.
        List<Integer> argList = ImmutableList.of(0);
        AggregateCall dummyCall = new AggregateCall(countFn, false, argList, longType, null);

        Aggregate newAggRel = oldAggRel.copy(oldAggRel.getTraitSet(), oldAggRel.getInput(), false,
                oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), ImmutableList.of(dummyCall));
        RelNode select = introduceDerivedTable(newAggRel);
        parent.replaceInput(0, select);
    }

    /**
     * Shuttle that wraps the left input of a {@link HiveJoin} in a derived table when the left
     * and right subtrees expose an overlapping table alias (compared lower-cased).
     *
     * <p>Aliases come from table scans, JDBC converters and Druid queries. A {@link HiveProject}
     * clears the aliases collected beneath it, so they do not leak upward.
     */
    private static class SelfJoinHandler
    extends HiveRelShuttleImpl {
        // Lower-cased table aliases visible from the subtree visited so far.
        private final Set<String> aliases = new HashSet<String>();

        private SelfJoinHandler() {
        }

        @Override
        public RelNode visit(HiveJoin join) {
            // Visit each side with its own handler so the two alias sets can be compared.
            SelfJoinHandler leftHandler = new SelfJoinHandler();
            RelNode newL = join.getLeft().accept(leftHandler);
            SelfJoinHandler rightHandler = new SelfJoinHandler();
            RelNode newR = join.getRight().accept(rightHandler);
            if (Sets.intersection(leftHandler.aliases, rightHandler.aliases).isEmpty()) {
                this.aliases.addAll(leftHandler.aliases);
                this.aliases.addAll(rightHandler.aliases);
            } else {
                // Hide the left side's aliases behind a derived table; only the right side's stay visible.
                this.aliases.addAll(rightHandler.aliases);
                newL = introduceDerivedTable(newL);
            }
            if (newL == join.getLeft() && newR == join.getRight()) {
                return join;
            }
            return join.copy(join.getTraitSet(), Arrays.asList(newL, newR));
        }

        @Override
        public RelNode visit(HiveProject project) {
            RelNode rel = super.visit(project);
            this.aliases.clear();
            return rel;
        }

        @Override
        public RelNode visit(HiveTableScan scan) {
            this.aliases.add(scan.getTableAlias().toLowerCase());
            return scan;
        }

        @Override
        public RelNode visit(HiveJdbcConverter conv) {
            this.aliases.add(conv.getTableScan().getHiveTableScan().getTableAlias().toLowerCase());
            return conv;
        }

        @Override
        public RelNode visit(RelNode rel) {
            if (rel instanceof DruidQuery) {
                DruidQuery dq = (DruidQuery) rel;
                this.aliases.add(((HiveTableScan) dq.getTableScan()).getTableAlias().toLowerCase());
                return dq;
            }
            return super.visit(rel);
        }
    }
}

