/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.druid.query.scan;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Function;
import com.google.common.base.Functions;
import com.google.common.collect.Iterables;
import com.google.inject.Inject;
import org.apache.druid.frame.allocation.MemoryAllocatorFactory;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.java.util.common.guava.BaseSequence;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.FrameSignaturePair;
import org.apache.druid.query.GenericQueryMetricsFactory;
import org.apache.druid.query.OrderBy;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryToolChest;
import org.apache.druid.query.aggregation.MetricManipulationFn;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.utils.CloseableUtils;

import javax.annotation.Nullable;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * {@link QueryToolChest} for {@link ScanQuery}. Merges results while applying the query's offset and limit at the
 * top of the stack, converts {@link ScanResultValue}s into arrays or frames, and supplies the cache strategy.
 */
public class ScanQueryQueryToolChest extends QueryToolChest<ScanResultValue, ScanQuery>
{
  // Leading byte of every cache key produced by this tool chest.
  private static final byte SCAN_QUERY = 0x13;
  // Version byte appended right after SCAN_QUERY in every cache key.
  private static final byte CACHE_STRATEGY_VERSION = 0x1;
  private static final TypeReference<ScanResultValue> TYPE_REFERENCE = new TypeReference<>() {};

  private final GenericQueryMetricsFactory queryMetricsFactory;

  @Inject
  public ScanQueryQueryToolChest(
      final GenericQueryMetricsFactory queryMetricsFactory
  )
  {
    this.queryMetricsFactory = queryMetricsFactory;
  }

  /**
   * Wraps {@code runner} so that the query's offset is applied here rather than pushed down. The downstream query is
   * rewritten with offset 0 and limit {@code limit + offset}; when that query is limited, rows are capped by a
   * {@link ScanQueryLimitRowIterator}, and the first {@code offset} rows are then dropped by a
   * {@link ScanQueryOffsetSequence}.
   *
   * The returned runner throws {@link ISE} if {@code limit + offset} would overflow a {@code long}.
   */
  @Override
  public QueryRunner<ScanResultValue> mergeResults(final QueryRunner<ScanResultValue> runner)
  {
    return (queryPlus, responseContext) -> {
      final ScanQuery originalQuery = ((ScanQuery) (queryPlus.getQuery()));
      ScanQuery.verifyOrderByForNativeExecution(originalQuery);

      // Remove "offset" and add it to the "limit" (we won't push the offset down, just apply it here, at the
      // merge at the top of the stack).
      final long newLimit;
      if (!originalQuery.isLimited()) {
        // Unlimited stays unlimited.
        newLimit = Long.MAX_VALUE;
      } else if (originalQuery.getScanRowsLimit() > Long.MAX_VALUE - originalQuery.getScanRowsOffset()) {
        throw new ISE(
            "Cannot apply limit[%d] with offset[%d] due to overflow",
            originalQuery.getScanRowsLimit(),
            originalQuery.getScanRowsOffset()
        );
      } else {
        newLimit = originalQuery.getScanRowsLimit() + originalQuery.getScanRowsOffset();
      }

      final ScanQuery queryToRun = originalQuery.withOffset(0)
                                                .withLimit(newLimit);

      final Sequence<ScanResultValue> results;

      if (!queryToRun.isLimited()) {
        results = runner.run(queryPlus.withQuery(queryToRun), responseContext);
      } else {
        // The limit iterator is created lazily on iteration and closed when the sequence is cleaned up.
        results = new BaseSequence<>(
            new BaseSequence.IteratorMaker<ScanResultValue, ScanQueryLimitRowIterator>()
            {
              @Override
              public ScanQueryLimitRowIterator make()
              {
                return new ScanQueryLimitRowIterator(runner, queryPlus.withQuery(queryToRun), responseContext);
              }

              @Override
              public void cleanup(ScanQueryLimitRowIterator iterFromMake)
              {
                CloseableUtils.closeAndWrapExceptions(iterFromMake);
              }
            });
      }

      if (originalQuery.getScanRowsOffset() > 0) {
        return new ScanQueryOffsetSequence(results, originalQuery.getScanRowsOffset());
      } else {
        return results;
      }
    };
  }

  /**
   * Creates query metrics using the injected {@link GenericQueryMetricsFactory}.
   */
  @Override
  public QueryMetrics<Query<?>> makeMetrics(ScanQuery query)
  {
    return queryMetricsFactory.makeMetrics(query);
  }

  /**
   * Scan results carry no aggregated metrics, so the manipulator is the identity function regardless of {@code fn}.
   */
  @Override
  public Function<ScanResultValue, ScanResultValue> makePreComputeManipulatorFn(
      ScanQuery query,
      MetricManipulationFn fn
  )
  {
    return Functions.identity();
  }

  @Override
  public TypeReference<ScanResultValue> getResultTypeReference()
  {
    return TYPE_REFERENCE;
  }

  /**
   * Scan queries apply no pre-merge decoration; the returned runner simply delegates to {@code runner}.
   */
  @Override
  public QueryRunner<ScanResultValue> preMergeQueryDecoration(final QueryRunner<ScanResultValue> runner)
  {
    return runner::run;
  }

  /**
   * Returns the query's own row signature, which defines the column order of array-based results.
   */
  @Override
  public RowSignature resultArraySignature(final ScanQuery query)
  {
    return query.getRowSignature();
  }

  /**
   * Batches consecutive {@link ScanResultValue}s that share a row signature into frames. In the best case this
   * produces a single frame; in the worst case it produces one frame per {@link ScanResultValue}. Rows are converted
   * to arrays with a mapper chosen from the query's result format and each batch's row signature.
   */
  @Override
  public Optional<Sequence<FrameSignaturePair>> resultsAsFrames(
      final ScanQuery query,
      final Sequence<ScanResultValue> resultSequence,
      MemoryAllocatorFactory memoryAllocatorFactory,
      boolean useNestedForUnknownTypes
  )
  {
    final RowSignature defaultRowSignature = resultArraySignature(query);
    return Optional.of(
        Sequences.simple(
            new ScanResultValueFramesIterable(
                resultSequence,
                memoryAllocatorFactory,
                useNestedForUnknownTypes,
                defaultRowSignature,
                rowSignature -> getResultFormatMapper(query.getResultFormat(), rowSignature.getColumnNames())
            )
        )
    );
  }

  /**
   * Flattens each {@link ScanResultValue}'s events into one {@code Object[]} per row, ordered by the columns of
   * {@link #resultArraySignature(ScanQuery)}.
   *
   * @throws UOE if the query's result format is not supported for array-based results
   */
  @Override
  public Sequence<Object[]> resultsAsArrays(final ScanQuery query, final Sequence<ScanResultValue> resultSequence)
  {
    final Function<?, Object[]> mapper = getResultFormatMapper(query.getResultFormat(), resultArraySignature(query).getColumnNames());

    return resultSequence.flatMap(
        result -> {
          // The mapper's input type depends on the result format (Map rows for RESULT_FORMAT_LIST, List rows for
          // RESULT_FORMAT_COMPACTED_LIST), so raw types are used to apply it to the untyped event list.
          final List rows = (List) result.getEvents();
          final Iterable arrays = Iterables.transform(rows, (Function) mapper);
          return Sequences.simple(arrays);
        }
    );
  }

  /**
   * Returns a cache strategy whose key covers the query's virtual columns, result format, offset, limit, filter,
   * columns, time order, and (when present) order-bys and column types. Cached values are stored unchanged.
   */
  @Override
  public CacheStrategy<ScanResultValue, ScanResultValue, ScanQuery> getCacheStrategy(
      final ScanQuery query,
      @Nullable final ObjectMapper objectMapper
  )
  {
    return new CacheStrategy<>()
    {
      @Override
      public boolean isCacheable(ScanQuery query, boolean willMergeRunners, boolean segmentLevel)
      {
        // Currently, there is no bijective mapping from ScanResultValue to Result<BySegmentResultValueClass<ScanResultValue>>.
        // This means queries will fail if:
        //   - A query is issued with bySegment:true
        //   - Segment-level cache is enabled on the broker (in which case it sends bySegment queries to data nodes).
        return !query.context().isBySegment() && (!segmentLevel || willMergeRunners);
      }

      @Override
      public byte[] computeCacheKey(ScanQuery query)
      {
        CacheKeyBuilder builder = new CacheKeyBuilder(SCAN_QUERY)
            .appendByte(CACHE_STRATEGY_VERSION)
            .appendCacheable(query.getVirtualColumns())
            .appendString(query.getResultFormat().toString())
            .appendLong(query.getScanRowsOffset())
            .appendLong(query.getScanRowsLimit())
            .appendCacheable(query.getFilter())
            .appendStrings(query.getColumns() != null ? query.getColumns() : List.of())
            .appendCacheable(query.getTimeOrder());

        List<OrderBy> orderBys = query.getOrderBys();
        if (orderBys != null) {
          builder.appendCacheables(orderBys);
        }

        List<ColumnType> columnTypes = query.getColumnTypes();
        if (columnTypes != null) {
          builder.appendCacheables(columnTypes);
        }

        return builder.build();
      }

      @Override
      public byte[] computeResultLevelCacheKey(ScanQuery query)
      {
        // Reuse the segment-level key: scan queries have no result-level transformations (such as aggregations)
        // that would need to be reflected in the key.
        return computeCacheKey(query);
      }

      @Override
      public TypeReference<ScanResultValue> getCacheObjectClazz()
      {
        return TYPE_REFERENCE;
      }

      @Override
      public Function<ScanResultValue, ScanResultValue> prepareForCache(boolean isResultLevelCache)
      {
        return input -> input;
      }

      @Override
      public Function<ScanResultValue, ScanResultValue> pullFromCache(boolean isResultLevelCache)
      {
        return input -> input;
      }
    };
  }

  /**
   * Builds a function that converts one scan event row into an {@code Object[]} ordered by {@code fields}.
   *
   * @param resultFormat the result format of the rows being converted
   * @param fields       column names, in the order the output array should use
   * @return for {@code RESULT_FORMAT_LIST}, a function over {@code Map<String, Object>} rows; for
   *         {@code RESULT_FORMAT_COMPACTED_LIST}, a function over {@code List<Object>} rows
   * @throws UOE if {@code resultFormat} is any other format
   */
  private static Function<?, Object[]> getResultFormatMapper(ScanQuery.ResultFormat resultFormat, List<String> fields)
  {
    switch (resultFormat) {
      case RESULT_FORMAT_LIST:
        return (Map<String, Object> row) -> {
          final Object[] rowArray = new Object[fields.size()];

          for (int i = 0; i < fields.size(); i++) {
            rowArray[i] = row.get(fields.get(i));
          }

          return rowArray;
        };
      case RESULT_FORMAT_COMPACTED_LIST:
        return (List<Object> row) -> {
          if (row.size() == fields.size()) {
            return row.toArray();
          } else if (fields.isEmpty()) {
            return new Object[0];
          } else {
            // Uh oh... mismatch in expected and actual field count. I don't think this should happen, so let's
            // throw an exception. If this really does happen, and there's a good reason for it, then we should remap
            // the result row here.
            throw new ISE("Mismatch in expected[%d] vs actual[%d] field count", fields.size(), row.size());
          }
        };
      default:
        throw new UOE("Unsupported resultFormat for array-based results: %s", resultFormat);
    }
  }
}
