/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.BoundedSpace;
import dr.inference.model.GeneralBoundsProvider;
import dr.inference.model.GraphicalParameterBound;
import dr.inference.model.Parameter;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.hmc.HamiltonianMonteCarloOperator;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Transform;
import dr.xml.Reportable;
import java.util.ArrayList;

public class ReflectiveHamiltonianMonteCarloOperator
extends HamiltonianMonteCarloOperator
implements Reportable {
    // Bounds the reflective dynamics must respect; constructLeapFrogEngine dispatches on whether this
    // is a GraphicalParameterBound or (otherwise) a BoundedSpace.
    private final GeneralBoundsProvider parameterBound;
    // Value returned by the most recent ReflectionEvent.doReflection call in WithBounds.updatePosition;
    // it is passed back into nextEvent as its boolean argument on the following segment.
    private boolean isAtBoundary = false;
    // NOTE(review): not referenced anywhere in this class as shown; looks like a leftover debug switch.
    private static final boolean DEBUG = false;

    /**
     * Creates an HMC operator whose leapfrog position updates reflect off the supplied bounds.
     *
     * <p>All arguments except {@code generalBoundsProvider} are forwarded unchanged to the
     * {@link HamiltonianMonteCarloOperator} constructor. Once the bound is stored, the leapfrog engine
     * is (re)built with {@link #constructLeapFrogEngine(Transform)}; if {@code generalBoundsProvider}
     * is {@code null}, that method returns {@code null} and {@code leapFrogEngine} is left {@code null}.
     *
     * @param generalBoundsProvider bounds to reflect off; expected to be a
     *     {@link GraphicalParameterBound} or a {@link BoundedSpace}
     */
    public ReflectiveHamiltonianMonteCarloOperator(AdaptationMode adaptationMode, double d, GradientWrtParameterProvider gradientWrtParameterProvider, Parameter parameter, Transform transform, Parameter parameter2, HamiltonianMonteCarloOperator.Options options, MassPreconditioner massPreconditioner, GeneralBoundsProvider generalBoundsProvider) {
        super(adaptationMode, d, gradientWrtParameterProvider, parameter, transform, parameter2, options, massPreconditioner);
        this.parameterBound = generalBoundsProvider;
        // constructLeapFrogEngine returns null while parameterBound is unset. Presumably the superclass
        // constructor already invoked it (TODO confirm), so the engine is rebuilt now that the bound is known.
        this.leapFrogEngine = this.constructLeapFrogEngine(transform);
    }

    /**
     * Builds the bounded leapfrog engine that matches the runtime type of {@link #parameterBound}.
     *
     * <p>Returns {@code null} while {@code parameterBound} is unset, which is the case for any call
     * made before this class's constructor has stored the bound (or when a {@code null} bound was given).
     *
     * @param transform must be {@code null}; transformed parameters are not supported
     * @return a {@link WithGraphBounds} engine for a {@link GraphicalParameterBound}, a
     *     {@link WithMultivariateCurvedBounds} engine for a {@link BoundedSpace}, or {@code null}
     *     when no bound is set
     * @throws UnsupportedOperationException if {@code transform} is non-null
     * @throws IllegalArgumentException if the bound is neither a {@link GraphicalParameterBound}
     *     nor a {@link BoundedSpace}
     */
    @Override
    protected HamiltonianMonteCarloOperator.LeapFrogEngine constructLeapFrogEngine(Transform transform) {
        if (this.parameterBound == null) {
            return null;
        }
        if (transform != null) {
            // UnsupportedOperationException is a RuntimeException, so existing handlers still match.
            throw new UnsupportedOperationException("not yet implemented");
        }
        if (this.parameterBound instanceof GraphicalParameterBound) {
            return new WithGraphBounds(this.parameter, this.getDefaultInstabilityHandler(), this.preconditioning, this.mask, (GraphicalParameterBound)this.parameterBound);
        }
        if (this.parameterBound instanceof BoundedSpace) {
            return new WithMultivariateCurvedBounds(this.parameter, this.getDefaultInstabilityHandler(), this.preconditioning, this.mask, (BoundedSpace)this.parameterBound);
        }
        // Previously this fell through to an unchecked cast and an opaque ClassCastException.
        throw new IllegalArgumentException("Unsupported bounds provider for reflective HMC: "
                + this.parameterBound.getClass().getName());
    }

    /**
     * Returns a fixed description of this operator type.
     *
     * @return the constant report header for reflective HMC operators
     */
    @Override
    public String getReport() {
        return "operator type: reflectiveHamiltonianMonteCarloOperator\n\n";
    }

    /**
     * Names this operator after the parameter it moves.
     *
     * @return {@code "ReflectiveHMC(" + parameter name + ")"}
     */
    @Override
    public String getOperatorName() {
        final String parameterName = this.parameter.getParameterName();
        return new StringBuilder("ReflectiveHMC(")
                .append(parameterName)
                .append(")")
                .toString();
    }

    /**
     * Bounded leapfrog engine driven by a {@link GraphicalParameterBound}.
     *
     * <p>A position update is interrupted either when a coordinate would cross one of its fixed
     * lower/upper bounds ({@link ReflectionType#Reflection}) or when two connected coordinates would
     * swap order ({@link ReflectionType#Collision}). The earliest such event within the remaining
     * time is returned by {@link #nextEvent}.
     */
    class WithGraphBounds
    extends WithBounds {
        private final GraphicalParameterBound graphicalParameterBound;

        protected WithGraphBounds(Parameter parameter, HamiltonianMonteCarloOperator.InstabilityHandler instabilityHandler, MassPreconditioner massPreconditioner, double[] dArray, GraphicalParameterBound graphicalParameterBound) {
            super(parameter, instabilityHandler, massPreconditioner, dArray);
            this.graphicalParameterBound = graphicalParameterBound;
        }

        /**
         * Returns the earliest fixed-bound reflection or connected-pair collision within {@code d}.
         * If neither occurs, returns a {@link ReflectionType#None} event spanning all of {@code d}.
         * When both happen at the same time, the collision is returned.
         *
         * @param dArray current positions
         * @param wrappedVector momentum; converted to velocities by the preconditioner
         * @param d remaining integration time
         * @param bl unused by this engine
         */
        @Override
        protected ReflectionEvent nextEvent(double[] dArray, WrappedVector wrappedVector, double d, boolean bl) {
            // Both searches need the same velocities and end-of-segment positions; compute them once
            // instead of re-querying the preconditioner per coordinate and per connected pair.
            double[] velocity = this.getVelocities(dArray.length, wrappedVector);
            double[] intended = this.getIntendedPosition(dArray, velocity, d);
            ReflectionEvent reflection = this.firstReflectionAtFixedBounds(dArray, velocity, intended, d);
            ReflectionEvent collision = this.firstCollision(dArray, velocity, intended, d);
            return reflection.getEventTime() < collision.getEventTime() ? reflection : collision;
        }

        /** True when {@code bound} lies strictly between {@code current} and {@code intended}. */
        private boolean isReflected(double current, double intended, double bound) {
            return (bound - current) * (intended - bound) > 0.0;
        }

        /**
         * True when the order of coordinates a and b flips between their current and intended
         * positions, or when their intended positions coincide.
         */
        private boolean isCollision(double currentA, double intendedA, double currentB, double intendedB) {
            return (currentA - currentB) * (intendedA - intendedB) < 0.0 || intendedA == intendedB;
        }

        /** Queries the preconditioner once per coordinate for the velocity implied by {@code momentum}. */
        private double[] getVelocities(int dimension, ReadableVector momentum) {
            double[] velocity = new double[dimension];
            for (int i = 0; i < dimension; ++i) {
                velocity[i] = ReflectiveHamiltonianMonteCarloOperator.this.preconditioning.getVelocity(i, momentum);
            }
            return velocity;
        }

        /** Positions reached after moving for {@code time} at constant {@code velocity}, ignoring all bounds. */
        private double[] getIntendedPosition(double[] position, double[] velocity, double time) {
            double[] intended = new double[position.length];
            for (int i = 0; i < position.length; ++i) {
                intended[i] = position[i] + time * velocity[i];
            }
            return intended;
        }

        /**
         * Finds the earliest time before {@code time} at which two connected coordinates meet.
         * Each connected pair (i, j) is examined once, with j > i.
         *
         * @return a Collision event carrying the meeting point and both indices, or a None event
         *     lasting {@code time} (location -1, indices {-1, -1}) if no pair meets
         */
        private ReflectionEvent firstCollision(double[] position, double[] velocity, double[] intended, double time) {
            double earliest = time;
            double location = -1.0;
            ReflectionType type = ReflectionType.None;
            int first = -1;
            int second = -1;
            for (int i = 0; i < position.length; ++i) {
                int[] neighbours = this.graphicalParameterBound.getConnectedParameterIndices(i);
                if (neighbours == null) {
                    continue;
                }
                for (int j : neighbours) {
                    if (j <= i || !this.isCollision(position[i], intended[i], position[j], intended[j])) {
                        continue;
                    }
                    // Time at which the two straight-line trajectories intersect.
                    double meet = (position[j] - position[i]) / (velocity[i] - velocity[j]);
                    if (meet < earliest) {
                        earliest = meet;
                        location = meet * velocity[i] + position[i];
                        first = i;
                        second = j;
                        type = ReflectionType.Collision;
                    }
                }
            }
            return new ReflectionEvent(type, earliest, location, new int[]{first, second});
        }

        /**
         * Finds the earliest time before {@code time} at which a coordinate hits its fixed upper or
         * lower bound. The upper bound is checked first; if it is crossed, the lower bound is ignored.
         *
         * @return a Reflection event carrying the bound value and the coordinate index, or a None
         *     event lasting {@code time} (location -1, index -1) if no bound is hit
         * @throws RuntimeException if a crossing yields a negative hitting time (inconsistent with isReflected)
         */
        private ReflectionEvent firstReflectionAtFixedBounds(double[] position, double[] velocity, double[] intended, double time) {
            double earliest = time;
            double location = -1.0;
            ReflectionType type = ReflectionType.None;
            int index = -1;
            for (int i = 0; i < position.length; ++i) {
                double upper = this.graphicalParameterBound.getFixedUpperBound(i);
                double lower = this.graphicalParameterBound.getFixedLowerBound(i);
                double bound;
                if (this.isReflected(position[i], intended[i], upper)) {
                    bound = upper;
                } else if (this.isReflected(position[i], intended[i], lower)) {
                    bound = lower;
                } else {
                    continue;
                }
                double hit = (bound - position[i]) / velocity[i];
                if (hit < 0.0) {
                    throw new RuntimeException("Check isReflected() function plz.");
                }
                if (hit < earliest) {
                    earliest = hit;
                    type = ReflectionType.Reflection;
                    index = i;
                    location = bound;
                }
            }
            return new ReflectionEvent(type, earliest, location, new int[]{index});
        }
    }

    /**
     * Bounded leapfrog engine driven by a {@link BoundedSpace}.
     *
     * <p>The space reports how long the trajectory can run before reaching its boundary. If that
     * happens within the remaining time, a {@link ReflectionType#MultivariateReflection} event is
     * produced at the hitting point, using the space's normal vector there.
     */
    class WithMultivariateCurvedBounds
    extends WithBounds {
        private final BoundedSpace space;
        // Coordinates the multivariate reflection acts on. This is every index when the mask is null,
        // otherwise the indices whose mask entry is exactly 1.0.
        public final int[] defaultIndices;

        WithMultivariateCurvedBounds(Parameter parameter, HamiltonianMonteCarloOperator.InstabilityHandler instabilityHandler, MassPreconditioner massPreconditioner, double[] dArray, BoundedSpace boundedSpace) {
            // The explicit constructor invocation must be the first statement (JLS 8.8.7); the
            // decompiled original declared a local before it, which does not compile.
            super(parameter, instabilityHandler, massPreconditioner, dArray);
            this.space = boundedSpace;
            this.defaultIndices = this.unmaskedIndices(parameter.getDimension(), dArray);
        }

        /** Returns, in increasing order, the indices below {@code dimension} that are unmasked. */
        private int[] unmaskedIndices(int dimension, double[] mask) {
            int[] buffer = new int[dimension];
            int count = 0;
            for (int i = 0; i < dimension; ++i) {
                if (mask == null || mask[i] == 1.0) {
                    buffer[count++] = i;
                }
            }
            int[] indices = new int[count];
            System.arraycopy(buffer, 0, indices, 0, count);
            return indices;
        }

        /**
         * Returns a reflection at the boundary if it is reached within {@code d}, otherwise a None
         * event lasting {@code d}.
         *
         * @param dArray current positions
         * @param wrappedVector momentum; converted to velocities by the preconditioner
         * @param d remaining integration time
         * @param bl whether the previous event left the position on the boundary; forwarded to the space
         */
        @Override
        protected ReflectionEvent nextEvent(double[] dArray, WrappedVector wrappedVector, double d, boolean bl) throws HamiltonianMonteCarloOperator.NumericInstabilityException {
            double[] velocity = ReflectiveHamiltonianMonteCarloOperator.this.preconditioning.getVelocity(wrappedVector);
            double timeToBoundary = this.space.forwardDistanceToBoundary(dArray, velocity, bl);
            if (timeToBoundary > d) {
                return new ReflectionEvent(ReflectionType.None, d, Double.NaN, new int[0]);
            }
            double[] hitPoint = new double[dArray.length];
            for (int i = 0; i < dArray.length; ++i) {
                hitPoint[i] = dArray[i] + timeToBoundary * velocity[i];
            }
            double[] normal = this.space.getNormalVectorAtBoundary(hitPoint);
            return new ReflectionEvent(ReflectionType.MultivariateReflection, timeToBoundary, d - timeToBoundary, hitPoint, normal, this.defaultIndices);
        }
    }

    /**
     * Kinds of event that can end one segment of a bounded leapfrog position update.
     *
     * <p>Every kind first advances all positions by the event time using the current velocities.
     * The reflecting kinds then change the momentum and place the affected coordinates exactly at
     * the event location. {@link #isAtBoundary()} reports whether the segment ended on a boundary.
     */
    static enum ReflectionType {
        /**
         * Reflects the momentum restricted to {@code indices} about {@code normal}:
         * p_k <- p_k - 2 (p.n)/(n.n) n_k, with both dot products summed over {@code indices} only.
         * NOTE(review): this uses the plain Euclidean inner product; confirm this is intended for
         * non-identity mass matrices.
         */
        MultivariateReflection {
            @Override
            void doReflection(double[] position, MassPreconditioner preconditioner, WrappedVector momentum, double[] location, int[] indices, double[] normal, double time, double remainingTime) {
                this.updatePosition(position, preconditioner, momentum, time);

                double momentumDotNormal = 0.0;
                double normalSquared = 0.0;
                for (int k : indices) {
                    double component = normal[k];
                    momentumDotNormal += momentum.get(k) * component;
                    normalSquared += component * component;
                }

                double scale = 2.0 * momentumDotNormal / normalSquared;
                for (int k : indices) {
                    momentum.set(k, momentum.get(k) - scale * normal[k]);
                    position[k] = location[k];
                }
            }
        },

        /** Flips the sign of the momentum of coordinate {@code indices[0]} and pins it to {@code location[0]}. */
        Reflection {
            @Override
            void doReflection(double[] position, MassPreconditioner preconditioner, WrappedVector momentum, double[] location, int[] indices, double[] normal, double time, double remainingTime) {
                this.updatePosition(position, preconditioner, momentum, time);
                int k = indices[0];
                momentum.set(k, -momentum.get(k));
                position[k] = location[0];
            }
        },

        /**
         * Replaces the momentum of every coordinate in {@code indices} with the preconditioner's
         * collision result and places all of them at {@code location[0]}.
         */
        Collision {
            @Override
            void doReflection(double[] position, MassPreconditioner preconditioner, WrappedVector momentum, double[] location, int[] indices, double[] normal, double time, double remainingTime) {
                this.updatePosition(position, preconditioner, momentum, time);
                ReadableVector afterCollision = preconditioner.doCollision(indices, momentum);
                for (int k : indices) {
                    momentum.set(k, afterCollision.get(k));
                    position[k] = location[0];
                }
            }
        },

        /** Plain free flight: positions advance and the momentum is left untouched. */
        None {
            @Override
            void doReflection(double[] position, MassPreconditioner preconditioner, WrappedVector momentum, double[] location, int[] indices, double[] normal, double time, double remainingTime) {
                this.updatePosition(position, preconditioner, momentum, time);
            }

            @Override
            public boolean isAtBoundary() {
                return false;
            }
        };

        /** Moves every coordinate for {@code time} at the velocity the preconditioner derives from {@code momentum}. */
        void updatePosition(double[] position, MassPreconditioner preconditioner, WrappedVector momentum, double time) {
            for (int i = 0; i < position.length; ++i) {
                position[i] += preconditioner.getVelocity(i, momentum) * time;
            }
        }

        /**
         * Applies this event to {@code position} and {@code momentum} in place.
         * {@code remainingTime} is currently not used by any kind.
         */
        abstract void doReflection(double[] position, MassPreconditioner preconditioner, WrappedVector momentum, double[] location, int[] indices, double[] normal, double time, double remainingTime);

        /** Whether the segment ended on a boundary; true for every kind except {@link #None}. */
        public boolean isAtBoundary() {
            return true;
        }
    }

    /**
     * Immutable description of the next event on a trajectory segment.
     *
     * <p>It records the event kind, when the event happens, where it happens, which coordinates it
     * touches, and (for multivariate reflections) the boundary normal and the time left afterwards.
     */
    class ReflectionEvent {
        private final ReflectionType type;
        private final double eventTime;
        private final double[] eventLocation;
        private final int[] indices;
        private final double[] normalVector;
        private final double remainingTime;

        ReflectionEvent(ReflectionType type, double eventTime, double remainingTime, double[] eventLocation, double[] normalVector, int[] indices) {
            this.type = type;
            this.eventTime = eventTime;
            this.remainingTime = remainingTime;
            this.eventLocation = eventLocation;
            this.normalVector = normalVector;
            this.indices = indices;
        }

        /** Single-location event with no normal vector; the remaining time is recorded as NaN. */
        ReflectionEvent(ReflectionType type, double eventTime, double location, int[] indices) {
            this(type, eventTime, Double.NaN, new double[]{location}, null, indices);
        }

        /** Time from the start of the segment until this event. */
        public double getEventTime() {
            return this.eventTime;
        }

        public ReflectionType getType() {
            return this.type;
        }

        /**
         * Applies this event in place to {@code dArray} (positions) and {@code wrappedVector} (momentum).
         *
         * @return whether the segment ended on a boundary, as reported by the event type
         */
        public boolean doReflection(double[] dArray, WrappedVector wrappedVector) {
            ReflectionType kind = this.type;
            MassPreconditioner preconditioner = ReflectiveHamiltonianMonteCarloOperator.this.preconditioning;
            kind.doReflection(dArray, preconditioner, wrappedVector, this.eventLocation, this.indices, this.normalVector, this.eventTime, this.remainingTime);
            return kind.isAtBoundary();
        }
    }

    /**
     * Leapfrog engine whose position update is broken into segments that end at boundary events.
     *
     * <p>Subclasses supply {@link #nextEvent}; this class applies events one after another until
     * the requested time has elapsed, then writes the final positions back to the parameter.
     */
    abstract class WithBounds
    extends HamiltonianMonteCarloOperator.LeapFrogEngine.Default {
        WithBounds(Parameter parameter, HamiltonianMonteCarloOperator.InstabilityHandler instabilityHandler, MassPreconditioner massPreconditioner, double[] dArray) {
            super(parameter, instabilityHandler, massPreconditioner, dArray);
        }

        /**
         * Finds the first event within {@code remainingTime}.
         *
         * @param position current positions
         * @param momentum current momentum
         * @param remainingTime time left in this position update
         * @param atBoundary whether the previous event ended on a boundary
         */
        protected abstract ReflectionEvent nextEvent(double[] position, WrappedVector momentum, double remainingTime, boolean atBoundary) throws HamiltonianMonteCarloOperator.NumericInstabilityException;

        /**
         * Advances {@code position} for {@code duration}, applying every event that occurs on the way.
         * Events may change {@code momentum} in place. The final positions are stored with setParameter.
         */
        @Override
        public void updatePosition(double[] position, WrappedVector momentum, double duration) throws HamiltonianMonteCarloOperator.NumericInstabilityException {
            final ReflectiveHamiltonianMonteCarloOperator outer = ReflectiveHamiltonianMonteCarloOperator.this;
            double elapsed = 0.0;
            while (elapsed < duration) {
                ReflectionEvent event = this.nextEvent(position, momentum, duration - elapsed, outer.isAtBoundary);
                outer.isAtBoundary = event.doReflection(position, momentum);
                elapsed += event.getEventTime();
            }
            this.setParameter(position);
        }
    }
}

