/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.epidemiology.casetocase.periodpriors;

import dr.evomodel.epidemiology.casetocase.periodpriors.AbstractPeriodPriorDistribution;
import dr.inference.loggers.LogColumn;
import dr.inference.model.Parameter;
import dr.math.distributions.NormalDistribution;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * Period prior for values modelled as normal with an unknown mean and a known standard
 * deviation {@code sigma}, where the unknown mean has a normal hyperprior.
 *
 * <p>Two evaluation paths are provided:
 * <ul>
 *   <li>{@link #calculateLogLikelihood(double[])} evaluates the closed-form marginal likelihood
 *       of a whole array of values. It also stores the posterior mean and variance of the unknown
 *       mean in two loggable parameters.</li>
 *   <li>{@link #calculateLogPosteriorProbability(double, double)} scores values one at a time.
 *       Each value is scored against the posterior predictive distribution implied by the values
 *       already seen since the last {@link #reset()}, and is then added to that posterior.</li>
 * </ul>
 */
public class KnownVarianceNormalPeriodPriorDistribution extends AbstractPeriodPriorDistribution {
    public static final String NORMAL = "knownVarianceNormalPeriodPriorDistribution";
    public static final String LOG = "log";
    public static final String ID = "id";
    public static final String MU_0 = "mu0";
    public static final String SIGMA = "sigma";
    public static final String SIGMA_0 = "sigma0";

    /** Normal hyperprior on the unknown mean. */
    private final NormalDistribution hyperprior;
    /** Posterior mean of the unknown mean, as set by the last calculateLogLikelihood call. */
    private final Parameter posteriorMean;
    /** Posterior variance of the unknown mean, as set by the last calculateLogLikelihood call. */
    private final Parameter posteriorVariance;
    /** Known standard deviation of each observed value. */
    private final double sigma;
    /** Number of values incorporated by update() since the last reset. */
    private int dataCount;
    /** Sum of the values incorporated by update() since the last reset. */
    private double dataSum;
    /**
     * Posterior of the unknown mean used by the sequential path:
     * index 0 = mean, index 1 = standard deviation.
     */
    private final double[] currentParameters = new double[2];

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                AttributeRule.newBooleanRule(KnownVarianceNormalPeriodPriorDistribution.LOG, true),
                AttributeRule.newStringRule(KnownVarianceNormalPeriodPriorDistribution.ID, false),
                AttributeRule.newDoubleRule(KnownVarianceNormalPeriodPriorDistribution.MU_0, false),
                AttributeRule.newDoubleRule(KnownVarianceNormalPeriodPriorDistribution.SIGMA, false),
                AttributeRule.newDoubleRule(KnownVarianceNormalPeriodPriorDistribution.SIGMA_0, false)
        };

        @Override
        public String getParserName() {
            return KnownVarianceNormalPeriodPriorDistribution.NORMAL;
        }

        /**
         * Builds the distribution from the {@code id}, {@code mu0}, {@code sigma}, {@code sigma0}
         * and optional {@code log} (default false) attributes.
         *
         * @throws XMLParseException if {@code sigma} or {@code sigma0} is not strictly positive
         */
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            String id = (String) xo.getAttribute(KnownVarianceNormalPeriodPriorDistribution.ID);
            double mu0 = xo.getDoubleAttribute(KnownVarianceNormalPeriodPriorDistribution.MU_0);
            double knownSigma = xo.getDoubleAttribute(KnownVarianceNormalPeriodPriorDistribution.SIGMA);
            double sigma0 = xo.getDoubleAttribute(KnownVarianceNormalPeriodPriorDistribution.SIGMA_0);
            boolean log = xo.hasAttribute(KnownVarianceNormalPeriodPriorDistribution.LOG)
                    && xo.getBooleanAttribute(KnownVarianceNormalPeriodPriorDistribution.LOG);

            // Non-positive scales make every density below NaN or infinite; fail at parse time instead.
            if (!(knownSigma > 0.0)) {
                throw new XMLParseException("Attribute " + KnownVarianceNormalPeriodPriorDistribution.SIGMA
                        + " must be positive, got " + knownSigma);
            }
            if (!(sigma0 > 0.0)) {
                throw new XMLParseException("Attribute " + KnownVarianceNormalPeriodPriorDistribution.SIGMA_0
                        + " must be positive, got " + sigma0);
            }
            return new KnownVarianceNormalPeriodPriorDistribution(id, log, knownSigma, mu0, sigma0);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        @Override
        public String getParserDescription() {
            return "Calculates the probability of a set of doubles being drawn from the prior posterior distributionof a normal distribution of unknown mean and known standard deviation sigma";
        }

        @Override
        public Class getReturnType() {
            return KnownVarianceNormalPeriodPriorDistribution.class;
        }
    };

    /**
     * Creates the distribution with an explicit hyperprior on the unknown mean.
     *
     * @param name       model name, passed to the superclass
     * @param log        flag passed to the superclass (set from the XML {@code log} attribute)
     * @param sigma      known standard deviation of each observed value
     * @param hyperprior normal hyperprior on the unknown mean
     */
    public KnownVarianceNormalPeriodPriorDistribution(String name, boolean log, double sigma,
                                                      NormalDistribution hyperprior) {
        super(name, log);
        this.hyperprior = hyperprior;
        this.posteriorVariance = new Parameter.Default(1);
        this.posteriorMean = new Parameter.Default(1);
        this.addVariable(this.posteriorVariance);
        this.addVariable(this.posteriorMean);
        this.sigma = sigma;
        // Start the sequential path at the hyperprior so it is usable even before the first reset().
        this.dataCount = 0;
        this.dataSum = 0.0;
        this.currentParameters[0] = hyperprior.getMean();
        this.currentParameters[1] = hyperprior.getSD();
    }

    /**
     * Creates the distribution with a {@code N(mu0, sigma0)} hyperprior on the unknown mean.
     *
     * @param name   model name, passed to the superclass
     * @param log    flag passed to the superclass
     * @param sigma  known standard deviation of each observed value
     * @param mu0    hyperprior mean
     * @param sigma0 hyperprior standard deviation
     */
    public KnownVarianceNormalPeriodPriorDistribution(String name, boolean log, double sigma,
                                                      double mu0, double sigma0) {
        this(name, log, sigma, new NormalDistribution(mu0, sigma0));
    }

    /**
     * Clears the sequential path. Previously incorporated values are forgotten, the posterior is
     * set back to the hyperprior, and the accumulated {@code logL} is set to 0.
     */
    @Override
    public void reset() {
        this.dataCount = 0;
        this.dataSum = 0.0;
        this.currentParameters[0] = this.hyperprior.getMean();
        this.currentParameters[1] = this.hyperprior.getSD();
        this.logL = 0.0;
    }

    /**
     * Scores one value against the current posterior predictive distribution, adds the score to
     * {@code logL}, and then incorporates the value into the posterior.
     *
     * @param value      the value to score
     * @param lowerBound if not {@code Double.NEGATIVE_INFINITY}, the log upper-tail predictive
     *                   probability at this point is subtracted, i.e. the density is conditioned
     *                   on the value exceeding {@code lowerBound}
     * @return the log predictive density (possibly truncated) of {@code value}
     */
    @Override
    public double calculateLogPosteriorProbability(double value, double lowerBound) {
        double logDensity = this.calculateLogPosteriorPredictiveProbability(value);
        if (lowerBound != Double.NEGATIVE_INFINITY) {
            logDensity -= this.calculateLogPosteriorPredictiveCDF(lowerBound, true);
        }
        this.logL += logDensity;
        this.update(value);
        return logDensity;
    }

    /** Delegates to {@link #calculateLogPosteriorPredictiveCDF(double, boolean)}. */
    @Override
    public double calculateLogPosteriorCDF(double value, boolean upperTail) {
        return this.calculateLogPosteriorPredictiveCDF(value, upperTail);
    }

    /**
     * Log density of {@code value} under the current posterior predictive distribution. That
     * distribution is normal with the current posterior mean and variance
     * (posterior SD)^2 + sigma^2.
     */
    public double calculateLogPosteriorPredictiveProbability(double value) {
        return NormalDistribution.logPdf(value, this.currentParameters[0], this.predictiveSD());
    }

    /**
     * Log CDF of the current posterior predictive distribution at {@code value}.
     *
     * @param upperTail if true, the standardized value is negated before evaluating, which gives
     *                  the upper-tail probability 1 - CDF(value)
     */
    public double calculateLogPosteriorPredictiveCDF(double value, boolean upperTail) {
        double z = (value - this.currentParameters[0]) / this.predictiveSD();
        // NOTE(review): assumes the boolean argument of standardCDF requests a log-scale result,
        // consistent with how callers use this value in log space — confirm against NormalDistribution.
        return NormalDistribution.standardCDF(upperTail ? -z : z, true);
    }

    /** Standard deviation of the posterior predictive distribution: sqrt(posteriorSD^2 + sigma^2). */
    private double predictiveSD() {
        return Math.sqrt(square(this.currentParameters[1]) + square(this.sigma));
    }

    /**
     * Adds one value to the running sufficient statistics and recomputes the conjugate posterior of
     * the unknown mean. The running count and sum replace re-summing the full history on every call.
     */
    private void update(double value) {
        this.dataCount++;
        this.dataSum += value;
        double priorMean = this.hyperprior.getMean();
        double priorVariance = square(this.hyperprior.getSD());
        double dataVariance = square(this.sigma);
        double postVariance = 1.0 / (this.dataCount / dataVariance + 1.0 / priorVariance);
        double postMean = postVariance * (priorMean / priorVariance + this.dataSum / dataVariance);
        this.currentParameters[0] = postMean;
        this.currentParameters[1] = Math.sqrt(postVariance);
    }

    /**
     * Closed-form log marginal likelihood of {@code values}. Each value is modelled as
     * N(mu, sigma^2), and mu is integrated out under the normal hyperprior. Also stores the
     * posterior mean and variance of mu in the loggable parameters, and sets {@code logL} to the
     * returned value.
     *
     * @param values observed values (may be empty, in which case the result is 0)
     * @return the log marginal likelihood
     */
    @Override
    public double calculateLogLikelihood(double[] values) {
        int n = values.length;
        double mu0 = this.hyperprior.getMean();
        double sigma0 = this.hyperprior.getSD();
        double dataVariance = square(this.sigma);
        double priorVariance = square(sigma0);

        double sum = 0.0;
        double sumOfSquares = 0.0;
        for (double x : values) {
            sum += x;
            sumOfSquares += x * x;
        }

        double posteriorPrecision = 1.0 / priorVariance + n / dataVariance;
        this.posteriorMean.setParameterValue(0, (mu0 / priorVariance + sum / dataVariance) / posteriorPrecision);
        this.posteriorVariance.setParameterValue(0, 1.0 / posteriorPrecision);

        double combinedVariance = n * priorVariance + dataVariance;
        // Expressed with the sum rather than n * mean, so an empty array yields 0 instead of NaN (0/0).
        double quadratic = (square(sigma0 * sum / this.sigma)
                + square(this.sigma * mu0 / sigma0)
                + 2.0 * sum * mu0) / (2.0 * combinedVariance);

        this.logL = Math.log(this.sigma)
                - n * Math.log(Math.sqrt(Math.PI * 2) * this.sigma)
                - 0.5 * Math.log(combinedVariance)
                - sumOfSquares / (2.0 * dataVariance)
                - square(mu0) / (2.0 * priorVariance)
                + quadratic;
        return this.logL;
    }

    /** Adds the posterior mean and variance of the unknown mean to the superclass's log columns. */
    @Override
    public LogColumn[] getColumns() {
        ArrayList<LogColumn> columns = new ArrayList<LogColumn>(Arrays.asList(super.getColumns()));
        columns.add(this.parameterColumn("_posteriorMean", this.posteriorMean));
        columns.add(this.parameterColumn("_posteriorVariance", this.posteriorVariance));
        return columns.toArray(new LogColumn[columns.size()]);
    }

    /** Builds a log column, labelled with the model name plus {@code suffix}, that logs element 0 of {@code parameter}. */
    private LogColumn parameterColumn(String suffix, final Parameter parameter) {
        return new LogColumn.Abstract(this.getModelName() + suffix) {
            @Override
            protected String getFormattedValue() {
                return String.valueOf(parameter.getParameterValue(0));
            }
        };
    }

    private static double square(double x) {
        return x * x;
    }
}

