/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import org.apache.lucene.internal.hppc.IntHashSet;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.HnswBuilder;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.HnswLock;
import org.apache.lucene.util.hnsw.HnswUtil;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

/**
 * Builds an {@link OnHeapHnswGraph} incrementally by inserting vector ordinals one at a time.
 *
 * <p>Each inserted node is assigned a random top level (see {@link #getRandomGraphLevel}). It is
 * then linked on every level from that top level down to level 0:
 *
 * <ul>
 *   <li>A greedy search first locates entry points.
 *   <li>A beam search then collects candidates.
 *   <li>A diversity heuristic selects the neighbors.
 * </ul>
 *
 * A node keeps up to {@code 2 * M} neighbors on level 0 and up to {@code M} on higher levels.
 *
 * <p>When an {@link HnswLock} is supplied, updates to the neighbor lists of already-inserted nodes
 * are performed while holding the lock returned by {@code hnswLock.write(level, nbr)}. Calling
 * {@link #getCompletedGraph()} freezes the builder. After that, insertions throw {@link
 * IllegalStateException}.
 */
public class HnswGraphBuilder
implements HnswBuilder {
    /** Default maximum number of connections per node ({@code M}); not read by this class. */
    public static final int DEFAULT_MAX_CONN = 16;
    /** Default beam width used when searching for neighbor candidates; not read by this class. */
    public static final int DEFAULT_BEAM_WIDTH = 100;
    private static final long DEFAULT_RAND_SEED = 42L;
    /** Component name under which this builder writes {@link InfoStream} messages. */
    public static final String HNSW_COMPONENT = "HNSW";
    /**
     * Mutable global seed, initialized to the default seed. It is not read inside this class.
     * Presumably callers pass it as the {@code seed} argument of {@code create} — TODO confirm.
     */
    public static long randSeed = DEFAULT_RAND_SEED;
    private final int M;
    /** Level normalization factor: a node's top level is {@code (int) (-ln(U) * ml)}. */
    private final double ml;
    private final SplittableRandom random;
    protected final RandomVectorScorerSupplier scorerSupplier;
    private final HnswGraphSearcher graphSearcher;
    /** Top-1 collector used for the greedy descent through levels above the node's level. */
    private final GraphBuilderKnnCollector entryCandidates;
    /** Beam-width collector used on the levels the new node is inserted into. */
    private final GraphBuilderKnnCollector beamCandidates;
    /** Smaller collector used on level 0 when the caller supplies level-0 entry points. */
    private final GraphBuilderKnnCollector beamCandidates0;
    protected final OnHeapHnswGraph hnsw;
    /** Optional lock guarding neighbor-list updates of existing nodes; may be {@code null}. */
    protected final HnswLock hnswLock;
    protected InfoStream infoStream = InfoStream.getDefault();
    /** Set by {@link #finish()}; once true, no more nodes may be added. */
    protected boolean frozen;

    /**
     * Creates a builder whose graph is created with an initial size of -1. How
     * {@link OnHeapHnswGraph} interprets that size is not shown here — presumably "unknown".
     *
     * @param scorerSupplier supplies scorers between vector ordinals; must not be null
     * @param M maximum connections per node on levels above 0 (level 0 allows {@code 2 * M})
     * @param beamWidth candidate-queue size used while searching for neighbors
     * @param seed seed for the random level assignment
     * @throws IllegalArgumentException if {@code M} or {@code beamWidth} is not positive
     */
    public static HnswGraphBuilder create(RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed) throws IOException {
        return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, -1);
    }

    /**
     * Same as {@link #create(RandomVectorScorerSupplier, int, int, long)}, but forwards {@code
     * graphSize} to the {@link OnHeapHnswGraph} constructor.
     */
    public static HnswGraphBuilder create(RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) throws IOException {
        return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize);
    }

    protected HnswGraphBuilder(RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) throws IOException {
        this(scorerSupplier, M, beamWidth, seed, new OnHeapHnswGraph(M, graphSize));
    }

    protected HnswGraphBuilder(RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, OnHeapHnswGraph hnsw) throws IOException {
        this(scorerSupplier, M, beamWidth, seed, hnsw, null, new HnswGraphSearcher(new NeighborQueue(beamWidth, true), new FixedBitSet(hnsw.size())));
    }

    /**
     * Full constructor.
     *
     * @param hnswLock optional lock; when non-null, it guards updates to existing nodes' neighbor
     *     lists
     * @param graphSearcher searcher used for all level searches performed by this builder
     * @throws IllegalArgumentException if {@code M} or {@code beamWidth} is not positive
     * @throws NullPointerException if {@code scorerSupplier} is null
     */
    protected HnswGraphBuilder(RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, OnHeapHnswGraph hnsw, HnswLock hnswLock, HnswGraphSearcher graphSearcher) throws IOException {
        if (M <= 0) {
            throw new IllegalArgumentException("M (max connections) must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.M = M;
        this.scorerSupplier = Objects.requireNonNull(scorerSupplier, "scorer supplier must not be null");
        // ml = 1 / ln(M). Special-case M == 1, where ln(1) == 0 would make ml infinite.
        this.ml = M == 1 ? 1.0 : 1.0 / Math.log(M);
        this.random = new SplittableRandom(seed);
        this.hnsw = hnsw;
        this.hnswLock = hnswLock;
        this.graphSearcher = graphSearcher;
        this.entryCandidates = new GraphBuilderKnnCollector(1);
        this.beamCandidates = new GraphBuilderKnnCollector(beamWidth);
        // beamWidth / 2 is 0 when beamWidth == 1. A zero-capacity collector can never hold a
        // candidate (its minCompetitiveSimilarity would read an empty queue), so keep at least 1.
        this.beamCandidates0 = new GraphBuilderKnnCollector(Math.max(1, Math.min(beamWidth / 2, M * 3)));
    }

    /**
     * Inserts ordinals {@code [0, maxOrd)}, then freezes the builder and returns the graph.
     *
     * @throws IllegalStateException if the builder is already frozen
     */
    @Override
    public OnHeapHnswGraph build(int maxOrd) throws IOException {
        if (this.frozen) {
            throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated");
        }
        if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
            this.infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors");
        }
        this.addVectors(maxOrd);
        return this.getCompletedGraph();
    }

    @Override
    public void setInfoStream(InfoStream infoStream) {
        this.infoStream = infoStream;
    }

    /** Freezes the builder (if it is not frozen yet) and returns the graph. */
    @Override
    public OnHeapHnswGraph getCompletedGraph() throws IOException {
        if (!this.frozen) {
            this.finish();
        }
        return this.getGraph();
    }

    /** Returns the graph under construction without freezing the builder. */
    @Override
    public OnHeapHnswGraph getGraph() {
        return this.hnsw;
    }

    /**
     * Inserts every ordinal in {@code [minOrd, maxOrd)} in increasing order. A single scorer is
     * reused for all nodes. Progress is logged every 10000 ordinals when the info stream is enabled.
     *
     * @throws IllegalStateException if the builder is frozen
     */
    protected void addVectors(int minOrd, int maxOrd) throws IOException {
        if (this.frozen) {
            throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated");
        }
        final long start = System.nanoTime();
        long t = start;
        if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
            this.infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")");
        }
        UpdateableRandomVectorScorer scorer = this.scorerSupplier.scorer();
        for (int node = minOrd; node < maxOrd; ++node) {
            // The scorer is shared across iterations; point it at the node being inserted.
            scorer.setScoringOrdinal(node);
            this.addGraphNode(node, scorer);
            if (node % 10000 == 0 && this.infoStream.isEnabled(HNSW_COMPONENT)) {
                t = this.printGraphBuildStatus(node, start, t);
            }
        }
    }

    private void addVectors(int maxOrd) throws IOException {
        this.addVectors(0, maxOrd);
    }

    /**
     * Inserts {@code node} using a caller-provided scorer.
     *
     * @param scorer scorer whose scoring ordinal the caller has set to {@code node}. This is how
     *     {@link #addVectors(int, int)} prepares it before calling this method.
     */
    public void addGraphNode(int node, UpdateableRandomVectorScorer scorer) throws IOException {
        this.addGraphNodeInternal(node, scorer, null);
    }

    /**
     * Core insertion routine: assigns {@code node} a random top level, finds candidate neighbors
     * on each level it lives on, and links them.
     *
     * <p>If another inserter raises the graph's top level while this node is being linked and
     * this node must become the new entry node, the loop repeats for the newly created levels.
     *
     * @param eps0 optional level-0 entry points; when non-empty, they replace the entry points
     *     found by the descent, and the smaller {@code beamCandidates0} collector is used on level 0
     * @throws IllegalStateException if the builder is frozen, or if the node can be neither
     *     promoted to entry node nor retried (the top level did not change)
     */
    private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer, IntHashSet eps0) throws IOException {
        if (this.frozen) {
            throw new IllegalStateException("Graph builder is already frozen");
        }
        final int nodeLevel = HnswGraphBuilder.getRandomGraphLevel(this.ml, this.random);
        // Register the node on every level it belongs to before wiring any edges.
        for (int level = nodeLevel; level >= 0; --level) {
            this.hnsw.addNode(level, node);
        }
        // If the node could be installed as the entry node (presumably because none was set yet —
        // TODO confirm trySetNewEntryNode semantics), no linking is done.
        if (this.hnsw.trySetNewEntryNode(node, nodeLevel)) {
            return;
        }
        // Levels [0, lowestUnsetLevel) have already been linked by an earlier pass of the loop.
        int lowestUnsetLevel = 0;
        while (true) {
            final int curMaxLevel = this.hnsw.numLevels() - 1;
            int[] eps = new int[]{this.hnsw.entryNode()};

            // Greedy descent: above nodeLevel only the single best node is kept as the next entry point.
            GraphBuilderKnnCollector candidates = this.entryCandidates;
            for (int level = curMaxLevel; level > nodeLevel; --level) {
                candidates.clear();
                this.graphSearcher.searchLevel(candidates, scorer, level, eps, this.hnsw, null);
                eps[0] = candidates.popNode();
            }

            // Beam search, top-down, on each not-yet-linked level the node lives on. Results are
            // buffered per level, and edges are only added once all searches are done.
            candidates = this.beamCandidates;
            NeighborArray[] scratchPerLevel = new NeighborArray[Math.min(nodeLevel, curMaxLevel) - lowestUnsetLevel + 1];
            for (int i = scratchPerLevel.length - 1; i >= 0; --i) {
                int level = i + lowestUnsetLevel;
                candidates.clear();
                if (level == 0 && eps0 != null && eps0.size() > 0) {
                    eps = eps0.toArray();
                    candidates = this.beamCandidates0;
                }
                this.graphSearcher.searchLevel(candidates, scorer, level, eps, this.hnsw, null);
                eps = candidates.popUntilNearestKNodes();
                scratchPerLevel[i] = new NeighborArray(Math.max(candidates.k(), this.M + 1), false);
                HnswGraphBuilder.popToScratch(candidates, scratchPerLevel[i]);
            }

            // Link bottom-up.
            for (int i = 0; i < scratchPerLevel.length; ++i) {
                this.addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], scorer);
            }

            // Keep this increment OUT of the assert below. Assert expressions are not evaluated
            // when assertions are disabled (the JVM default), which would leave lowestUnsetLevel
            // at 0 and make every insertion fall through to the IllegalStateException.
            lowestUnsetLevel += scratchPerLevel.length;
            assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1;
            if (lowestUnsetLevel > nodeLevel) {
                return;
            }
            assert lowestUnsetLevel == curMaxLevel + 1 && nodeLevel > curMaxLevel;
            if (this.hnsw.tryPromoteNewEntryNode(node, nodeLevel, curMaxLevel)) {
                return;
            }
            if (this.hnsw.numLevels() == curMaxLevel + 1) {
                throw new IllegalStateException("We're not able to promote node " + node + " at level " + nodeLevel + " as entry node. But the max graph level " + curMaxLevel + " has not changed while we are inserting the node.");
            }
            // The graph grew meanwhile; retry for the new levels. selectAndLinkDiverse re-targeted
            // the scorer at candidate nodes, so point it back at the node being inserted.
            scorer.setScoringOrdinal(node);
        }
    }

    /** Inserts {@code node} using a fresh scorer from the supplier. */
    @Override
    public void addGraphNode(int node) throws IOException {
        this.addGraphNodeInternal(node, this.newScorerFor(node), null);
    }

    /**
     * Inserts {@code node}. When {@code eps0} is non-null and non-empty, its contents are used as
     * the level-0 entry points instead of the ones found by descending the graph.
     */
    public void addGraphNodeWithEps(int node, IntHashSet eps0) throws IOException {
        this.addGraphNodeInternal(node, this.newScorerFor(node), eps0);
    }

    /** Returns a new scorer from the supplier whose scoring ordinal is {@code node}. */
    private UpdateableRandomVectorScorer newScorerFor(int node) throws IOException {
        UpdateableRandomVectorScorer scorer = this.scorerSupplier.scorer();
        scorer.setScoringOrdinal(node);
        return scorer;
    }

    /**
     * Logs the time since the previous report and since {@code start}.
     *
     * @return the current {@link System#nanoTime()}, to be passed as {@code t} next time
     */
    private long printGraphBuildStatus(int node, long start, long t) {
        long now = System.nanoTime();
        this.infoStream.message(HNSW_COMPONENT, String.format(Locale.ROOT, "built %d in %d/%d ms", node, TimeUnit.NANOSECONDS.toMillis(now - t), TimeUnit.NANOSECONDS.toMillis(now - start)));
        return now;
    }

    /**
     * Links the new {@code node} on {@code level}.
     *
     * <ol>
     *   <li>Selects its own neighbors from {@code candidates} with the diversity heuristic.
     *   <li>Adds a back-link from each selected neighbor via {@code addAndEnsureDiversity}.
     * </ol>
     *
     * <p>Back-links are made under {@code hnswLock} when one is configured.
     */
    private void addDiverseNeighbors(int level, int node, NeighborArray candidates, UpdateableRandomVectorScorer scorer) throws IOException {
        NeighborArray neighbors = this.hnsw.getNeighbors(level, node);
        assert neighbors.size() == 0;
        int maxConnOnLevel = level == 0 ? this.M * 2 : this.M;
        boolean[] mask = this.selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel, scorer);
        // Iterate the local candidates + mask rather than `neighbors`. Presumably this is because,
        // once back-links exist, concurrent inserters can discover and modify `neighbors` — confirm.
        for (int i = 0; i < candidates.size(); ++i) {
            if (!mask[i]) {
                continue;
            }
            int nbr = candidates.nodes()[i];
            if (this.hnswLock == null) {
                NeighborArray nbrsOfNbr = this.hnsw.getNeighbors(level, nbr);
                nbrsOfNbr.addAndEnsureDiversity(node, candidates.getScores(i), nbr, scorer);
                continue;
            }
            // The Lock returned by write() is released in finally. It is presumably already held
            // on return — TODO confirm against HnswLock.
            Lock lock = this.hnswLock.write(level, nbr);
            try {
                NeighborArray nbrsOfNbr = this.getGraph().getNeighbors(level, nbr);
                nbrsOfNbr.addAndEnsureDiversity(node, candidates.getScores(i), nbr, scorer);
            } finally {
                lock.unlock();
            }
        }
    }

    /**
     * Fills {@code neighbors} with up to {@code maxConnOnLevel} candidates. Candidates are visited
     * from the last index to the first. A candidate is kept only if it passes
     * {@link #diversityCheck}.
     *
     * <p>Side effect: re-targets {@code scorer} at each visited candidate.
     *
     * @return a mask parallel to {@code candidates} marking the selected entries
     */
    private boolean[] selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel, UpdateableRandomVectorScorer scorer) throws IOException {
        boolean[] mask = new boolean[candidates.size()];
        for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; --i) {
            int cNode = candidates.nodes()[i];
            float cScore = candidates.getScores(i);
            assert cNode <= this.hnsw.maxNodeId();
            scorer.setScoringOrdinal(cNode);
            if (this.diversityCheck(cScore, neighbors, scorer)) {
                mask[i] = true;
                neighbors.addInOrder(cNode, cScore);
            }
        }
        return mask;
    }

    /**
     * Drains {@code candidates} into {@code scratch} (after clearing it). Each node is stored with
     * the score that was at the top of the queue when it was popped.
     */
    private static void popToScratch(GraphBuilderKnnCollector candidates, NeighborArray scratch) {
        scratch.clear();
        int candidateCount = candidates.size();
        for (int i = 0; i < candidateCount; ++i) {
            float score = candidates.minimumScore();
            scratch.addInOrder(candidates.popNode(), score);
        }
    }

    /**
     * Returns false if any already-selected neighbor scores at least {@code score} against the
     * scorer's current ordinal (the candidate). Otherwise returns true. {@code score} is the
     * candidate's score against the new node.
     */
    private boolean diversityCheck(float score, NeighborArray neighbors, RandomVectorScorer scorer) throws IOException {
        for (int i = 0; i < neighbors.size(); ++i) {
            float neighborSimilarity = scorer.score(neighbors.nodes()[i]);
            if (neighborSimilarity >= score) {
                return false;
            }
        }
        return true;
    }

    /**
     * Draws a random top level as {@code (int) (-ln(U) * ml)} with {@code U} uniform in (0, 1).
     * A draw of exactly 0 is rejected because {@code ln(0)} is undefined.
     */
    private static int getRandomGraphLevel(double ml, SplittableRandom random) {
        double randDouble;
        do {
            randDouble = random.nextDouble();
        } while (randDouble == 0.0);
        return (int) (-Math.log(randDouble) * ml);
    }

    /** Marks the builder frozen. Note: {@link #connectComponents()} is not invoked here. */
    void finish() throws IOException {
        this.frozen = true;
    }

    /**
     * Attempts to join disconnected components on every level, logging failures and the total
     * time. Not called from within this class as written.
     */
    private void connectComponents() throws IOException {
        long start = System.nanoTime();
        for (int level = 0; level < this.hnsw.numLevels(); ++level) {
            if (!this.connectComponents(level) && this.infoStream.isEnabled(HNSW_COMPONENT)) {
                this.infoStream.message(HNSW_COMPONENT, "connectComponents failed on level " + level);
            }
        }
        if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
            this.infoStream.message(HNSW_COMPONENT, "connectComponents " + (System.nanoTime() - start) / 1000000L + " ms");
        }
    }

    /**
     * Links every component on {@code level} to the largest one.
     *
     * <p>For each other component, a two-wide beam search runs from the largest component's start
     * node towards the component's start node. The search is restricted to
     * {@code notFullyConnected}. The returned nodes that still have free slots are linked to the
     * start node.
     *
     * @return false if the largest component has no usable start node or some component could not
     *     be linked, true otherwise
     */
    private boolean connectComponents(int level) throws IOException {
        FixedBitSet notFullyConnected = new FixedBitSet(this.hnsw.size());
        int maxConn = level == 0 ? this.M * 2 : this.M;
        List<HnswUtil.Component> components = HnswUtil.components(this.hnsw, level, notFullyConnected, maxConn);
        if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
            this.infoStream.message(HNSW_COMPONENT, "connect " + components.size() + " components on level=" + level);
        }
        if (components.size() <= 1) {
            return true;
        }
        boolean result = true;
        HnswUtil.Component c0 = components.stream().max(Comparator.comparingInt(HnswUtil.Component::size)).get();
        // Integer.MAX_VALUE as start() presumably marks a component with no node that has a free
        // neighbor slot — TODO confirm against HnswUtil.components.
        if (c0.start() == Integer.MAX_VALUE) {
            return false;
        }
        GraphBuilderKnnCollector beam = new GraphBuilderKnnCollector(2);
        int[] eps = new int[1];
        UpdateableRandomVectorScorer scorer = this.scorerSupplier.scorer();
        for (HnswUtil.Component c : components) {
            if (c == c0 || c.start() == Integer.MAX_VALUE) {
                continue;
            }
            if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
                this.infoStream.message(HNSW_COMPONENT, "connect component " + c + " to " + c0);
            }
            beam.clear();
            eps[0] = c0.start();
            scorer.setScoringOrdinal(c.start());
            this.graphSearcher.searchLevel(beam, scorer, level, eps, this.hnsw, notFullyConnected);
            boolean linked = false;
            while (beam.size() > 0) {
                // Read the score of the node at the top BEFORE popping it. After the pop,
                // minimumScore() would report the next node's score (or an empty queue's top).
                float score = beam.minimumScore();
                int c0node = beam.popNode();
                if (c0node == c.start() || !notFullyConnected.get(c0node)) {
                    continue;
                }
                this.link(level, c0node, c.start(), score, notFullyConnected);
                linked = true;
                if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
                    this.infoStream.message(HNSW_COMPONENT, "connected ok " + c0node + " -> " + c.start());
                }
            }
            if (!linked) {
                if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
                    this.infoStream.message(HNSW_COMPONENT, "not connected; no free nodes found");
                }
                result = false;
            }
        }
        return result;
    }

    /**
     * Adds edge {@code n0 -> n1} (n0 must have room), and {@code n1 -> n0} if n1 has room. A node
     * is cleared from {@code notFullyConnected} once it reaches {@code nbr0.maxSize() - 1}
     * neighbors.
     */
    private void link(int level, int n0, int n1, float score, FixedBitSet notFullyConnected) {
        NeighborArray nbr0 = this.hnsw.getNeighbors(level, n0);
        NeighborArray nbr1 = this.hnsw.getNeighbors(level, n1);
        int maxConn = nbr0.maxSize() - 1;
        assert notFullyConnected.get(n0);
        assert nbr0.size() < maxConn : "node " + n0 + " is full, has " + nbr0.size() + " friends";
        nbr0.addOutOfOrder(n1, score);
        if (nbr0.size() == maxConn) {
            notFullyConnected.clear(n0);
        }
        if (nbr1.size() < maxConn) {
            nbr1.addOutOfOrder(n0, score);
            if (nbr1.size() == maxConn) {
                notFullyConnected.clear(n1);
            }
        }
    }

    /**
     * Bounded {@link KnnCollector} used during graph construction. It keeps at most {@code k}
     * (node, score) pairs in a {@link NeighborQueue}, never terminates early, and has no visit
     * limit. {@link #topDocs()} is not supported.
     */
    public static final class GraphBuilderKnnCollector
    implements KnnCollector {
        private final NeighborQueue queue;
        private final int k;
        private long visitedCount;

        /** @param k maximum number of candidates retained */
        public GraphBuilderKnnCollector(int k) {
            this.queue = new NeighborQueue(k, false);
            this.k = k;
        }

        /** Returns the number of candidates currently held. */
        public int size() {
            return this.queue.size();
        }

        /** Removes and returns the node at the top of the queue. */
        public int popNode() {
            return this.queue.pop();
        }

        /**
         * Pops from the top until at most {@code k} candidates remain. Returns
         * {@code queue.nodes()} (presumably a copy of the remaining nodes — confirm against
         * NeighborQueue).
         */
        public int[] popUntilNearestKNodes() {
            while (this.size() > this.k()) {
                this.queue.pop();
            }
            return this.queue.nodes();
        }

        /** Returns the score at the top of the queue, i.e. of the node the next pop would return. */
        float minimumScore() {
            return this.queue.topScore();
        }

        /** Empties the queue and resets the visited counter. */
        public void clear() {
            this.queue.clear();
            this.visitedCount = 0L;
        }

        @Override
        public boolean earlyTerminated() {
            return false;
        }

        @Override
        public void incVisitedCount(int count) {
            this.visitedCount += count;
        }

        @Override
        public long visitedCount() {
            return this.visitedCount;
        }

        @Override
        public long visitLimit() {
            return Long.MAX_VALUE;
        }

        @Override
        public int k() {
            return this.k;
        }

        @Override
        public boolean collect(int docId, float similarity) {
            return this.queue.insertWithOverflow(docId, similarity);
        }

        /** Top-of-queue score once the queue is full; negative infinity until then. */
        @Override
        public float minCompetitiveSimilarity() {
            return this.queue.size() >= this.k() ? this.queue.topScore() : Float.NEGATIVE_INFINITY;
        }

        @Override
        public TopDocs topDocs() {
            throw new IllegalArgumentException();
        }

        @Override
        public KnnSearchStrategy getSearchStrategy() {
            return null;
        }
    }
}

