/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators;

import dr.inference.distribution.DistributionLikelihood;
import dr.inference.distribution.GammaDistributionModel;
import dr.inference.distribution.GammaStatisticsProvider;
import dr.inference.distribution.LogNormalDistributionModel;
import dr.inference.distribution.NormalDistributionModel;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.inference.operators.repeatedMeasures.GammaGibbsProvider;
import dr.math.MathUtils;
import dr.math.distributions.Distribution;
import dr.math.distributions.GammaDistribution;
import dr.math.matrixAlgebra.Vector;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.Reportable;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import dr.xml.XORRule;

/**
 * Gibbs operator for the precision parameter(s) of a normal or log-normal model with a
 * conjugate gamma prior.
 *
 * <p>For each selected index {@code i} a new precision is drawn from a gamma distribution with
 * <pre>
 *   shape = a_i + beta * n_i / 2
 *   rate  = b_i + beta * SSE_i / 2
 * </pre>
 * where {@code n_i} and {@code SSE_i} are the observation count and sum of squared errors
 * reported by the {@link GammaGibbsProvider}, {@code beta} is the path parameter (1 by default),
 * and {@code (a_i, b_i)} are the prior shape and rate. When a working distribution is supplied,
 * {@code a_i} and {@code b_i} are instead the linear interpolation
 * {@code (1 - beta) * working + beta * prior} of the working and prior values.
 */
public class NormalGammaPrecisionGibbsOperator
extends SimpleMCMCOperator
implements GibbsOperator,
Reportable {
    public static final String OPERATOR_NAME = "normalGammaPrecisionGibbsOperator";
    public static final String LIKELIHOOD = "likelihood";
    public static final String PRIOR = "prior";
    private static final String WORKING = "workingDistribution";
    private static final String INDICES = "indices";

    /** XML parser that builds a {@link NormalGammaPrecisionGibbsOperator}. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                AttributeRule.newDoubleRule("weight"),
                AttributeRule.newStringRule(INDICES, true),
                // The data model is given either as a normal/log-normal likelihood or directly as a provider.
                new XORRule(
                        new ElementRule(LIKELIHOOD, new XMLSyntaxRule[]{
                                new ElementRule(DistributionLikelihood.class)}),
                        new ElementRule(GammaGibbsProvider.class)),
                new ElementRule(PRIOR, new XMLSyntaxRule[]{
                        new XORRule(
                                new ElementRule(DistributionLikelihood.class),
                                new ElementRule(GammaStatisticsProvider.class))}),
                new ElementRule(WORKING, new XMLSyntaxRule[]{
                        new XORRule(
                                new ElementRule(DistributionLikelihood.class),
                                new ElementRule(GammaStatisticsProvider.class))}, true)
        };

        @Override
        public String getParserName() {
            return NormalGammaPrecisionGibbsOperator.OPERATOR_NAME;
        }

        /**
         * Rejects any distribution likelihood whose distribution is not a gamma.
         *
         * @throws XMLParseException if the distribution is neither a GammaDistribution nor a GammaDistributionModel
         */
        private void checkGammaDistribution(DistributionLikelihood distributionLikelihood) throws XMLParseException {
            final Distribution distribution = distributionLikelihood.getDistribution();
            if (!(distribution instanceof GammaDistribution) && !(distribution instanceof GammaDistributionModel)) {
                throw new XMLParseException("Gibbs operator assumes normal-gamma model");
            }
        }

        /**
         * Converts a parsed child element into a source of gamma shape/rate values.
         *
         * <p>A gamma {@link DistributionLikelihood} is converted by moment matching its mean and
         * variance. A {@link GammaStatisticsProvider} is used as is.
         *
         * @param object the parsed child (prior or working distribution)
         * @return the gamma statistics provider
         * @throws XMLParseException if the child is not gamma-distributed
         */
        private GammaStatisticsProvider getGammaStatisticsProvider(Object object) throws XMLParseException {
            if (object instanceof DistributionLikelihood) {
                final DistributionLikelihood distributionLikelihood = (DistributionLikelihood) object;
                checkGammaDistribution(distributionLikelihood);
                // NOTE(review): mean/variance are captured once here. If the gamma model's own
                // parameters are sampled, this snapshot will not follow them. Confirm this is intended.
                return new GammaParametrization(distributionLikelihood.getDistribution());
            }
            if (object instanceof GammaStatisticsProvider) {
                return (GammaStatisticsProvider) object;
            }
            throw new XMLParseException("Prior must be gamma");
        }

        /**
         * Builds the operator from its XML element.
         *
         * <p>The optional {@code indices} attribute holds 1-based indices into the precision
         * parameter. When it is absent, every dimension is updated.
         *
         * @throws XMLParseException on a non-gamma prior/working distribution, a likelihood that is
         *         not normal/log-normal, or an index outside {@code [1, dimension]}
         */
        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            final double weight = xMLObject.getDoubleAttribute("weight");

            final GammaStatisticsProvider prior =
                    getGammaStatisticsProvider(xMLObject.getElementFirstChild(PRIOR));

            final GammaStatisticsProvider working = xMLObject.hasChildNamed(WORKING)
                    ? getGammaStatisticsProvider(xMLObject.getElementFirstChild(WORKING))
                    : null;

            final GammaGibbsProvider gammaGibbsProvider;
            if (xMLObject.hasChildNamed(LIKELIHOOD)) {
                final DistributionLikelihood likelihood =
                        (DistributionLikelihood) xMLObject.getElementFirstChild(LIKELIHOOD);
                final Distribution distribution = likelihood.getDistribution();
                if (!(distribution instanceof NormalDistributionModel) && !(distribution instanceof LogNormalDistributionModel)) {
                    throw new XMLParseException("Gibbs operator assumes normal-gamma model");
                }
                gammaGibbsProvider = new GammaGibbsProvider.Default(likelihood);
            } else {
                gammaGibbsProvider = (GammaGibbsProvider) xMLObject.getChild(GammaGibbsProvider.class);
            }

            final int dimension = gammaGibbsProvider.getPrecisionParameter().getDimension();
            final int[] indices;
            if (xMLObject.hasAttribute(INDICES)) {
                final int[] oneBased = xMLObject.getIntegerArrayAttribute(INDICES);
                // Copy into a fresh array so the parser-owned array is not mutated.
                indices = new int[oneBased.length];
                for (int i = 0; i < oneBased.length; ++i) {
                    final int zeroBased = oneBased[i] - 1;
                    if (zeroBased < 0 || zeroBased >= dimension) {
                        throw new XMLParseException("Index " + oneBased[i] + " in '" + INDICES
                                + "' is out of range [1, " + dimension + "]");
                    }
                    indices[i] = zeroBased;
                }
            } else {
                indices = new int[dimension];
                for (int i = 0; i < dimension; ++i) {
                    indices[i] = i;
                }
            }

            return new NormalGammaPrecisionGibbsOperator(gammaGibbsProvider, prior, working, indices, weight);
        }

        @Override
        public String getParserDescription() {
            return "This element returns a operator on the precision parameter of a normal model with gamma prior.";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /** Supplies the precision parameter and per-index sufficient statistics. */
    private final GammaGibbsProvider gammaGibbsProvider;
    /** Precision parameter, cached from {@code gammaGibbsProvider} at construction time. */
    private final Parameter precisionParameter;
    /** Gamma prior on each precision component. */
    private final GammaStatisticsProvider prior;
    /** Optional working distribution, interpolated with the prior via the path parameter; may be null. */
    private final GammaStatisticsProvider working;
    /** Path (power-posterior) parameter in [0, 1]; scales the data contribution. */
    private double pathParameter = 1.0;
    /** 0-based indices of the precision components updated by each operation. */
    private final int[] indices;

    /**
     * Creates an operator without a working distribution.
     *
     * @param gammaGibbsProvider source of the precision parameter and sufficient statistics
     * @param gammaStatisticsProvider gamma prior shape/rate per index
     * @param nArray 0-based indices to update
     * @param d operator weight
     */
    public NormalGammaPrecisionGibbsOperator(GammaGibbsProvider gammaGibbsProvider, GammaStatisticsProvider gammaStatisticsProvider, int[] nArray, double d) {
        this(gammaGibbsProvider, gammaStatisticsProvider, null, nArray, d);
    }

    /**
     * Creates an operator with an optional working distribution.
     *
     * @param gammaGibbsProvider source of the precision parameter and sufficient statistics
     * @param gammaStatisticsProvider gamma prior shape/rate per index
     * @param gammaStatisticsProvider2 working distribution shape/rate per index, or null
     * @param nArray 0-based indices to update
     * @param d operator weight
     */
    public NormalGammaPrecisionGibbsOperator(GammaGibbsProvider gammaGibbsProvider, GammaStatisticsProvider gammaStatisticsProvider, GammaStatisticsProvider gammaStatisticsProvider2, int[] nArray, double d) {
        this.gammaGibbsProvider = gammaGibbsProvider;
        this.precisionParameter = gammaGibbsProvider.getPrecisionParameter();
        this.prior = gammaStatisticsProvider;
        this.working = gammaStatisticsProvider2;
        this.indices = nArray;
        this.setWeight(d);
    }

    public String getPerformanceSuggestion() {
        return null;
    }

    @Override
    public String getOperatorName() {
        return OPERATOR_NAME;
    }

    /**
     * Reports the observation counts and sums of squared errors for every precision component.
     *
     * <p>Side effect: calls {@code gammaGibbsProvider.drawValues()} before reading the statistics.
     */
    @Override
    public String getReport() {
        final int dimension = this.precisionParameter.getDimension();
        final double[] counts = new double[dimension];
        final double[] sumsOfSquares = new double[dimension];

        this.gammaGibbsProvider.drawValues();
        for (int i = 0; i < dimension; ++i) {
            final GammaGibbsProvider.SufficientStatistics statistics = this.gammaGibbsProvider.getSufficientStatistics(i);
            counts[i] = statistics.observationCount;
            sumsOfSquares[i] = statistics.sumOfSquaredErrors;
        }

        final StringBuilder stringBuilder = new StringBuilder("normalGammaPrecisionGibbsOperator report:\n");
        stringBuilder.append("Observation counts:\t");
        stringBuilder.append(new Vector(counts));
        stringBuilder.append("\n");
        stringBuilder.append("Sum of squared errors:\t");
        stringBuilder.append(new Vector(sumsOfSquares));
        return stringBuilder.toString();
    }

    /**
     * Linearly interpolates between a working-distribution value and a prior value.
     *
     * @return {@code workingValue} at pathParameter = 0 and {@code priorValue} at pathParameter = 1
     */
    private double weigh(double workingValue, double priorValue) {
        return (1.0 - this.pathParameter) * workingValue + this.pathParameter * priorValue;
    }

    /**
     * Draws every selected precision component from its full conditional gamma distribution.
     *
     * @return 0, the Hastings ratio of a Gibbs move
     */
    @Override
    public double doOperation() {
        this.gammaGibbsProvider.drawValues();

        for (int index : this.indices) {
            final GammaGibbsProvider.SufficientStatistics statistics = this.gammaGibbsProvider.getSufficientStatistics(index);

            // Data contribution, tempered by the path parameter.
            double shape = this.pathParameter * (double) statistics.observationCount / 2.0;
            double rate = this.pathParameter * statistics.sumOfSquaredErrors / 2.0;

            if (this.working == null) {
                shape += this.prior.getShape(index);
                rate += this.prior.getRate(index);
            } else {
                // Interpolate each hyperparameter between the working distribution and the prior.
                // Shape is mixed with shape, and rate with rate.
                shape += this.weigh(this.working.getShape(index), this.prior.getShape(index));
                rate += this.weigh(this.working.getRate(index), this.prior.getRate(index));
            }

            // The second argument is built as a rate here. MathUtils.nextGamma is assumed to take (shape, rate).
            final double draw = MathUtils.nextGamma(shape, rate);
            this.precisionParameter.setParameterValue(index, draw);
        }
        return 0.0;
    }

    /**
     * Sets the power-posterior path parameter.
     *
     * @throws IllegalArgumentException if {@code d} is outside [0, 1]
     */
    @Override
    public void setPathParameter(double d) {
        if (d < 0.0 || d > 1.0) {
            throw new IllegalArgumentException("Invalid pathParameter value");
        }
        this.pathParameter = d;
    }

    public int getStepCount() {
        return 1;
    }

    /**
     * Constant gamma shape/rate obtained by moment matching a mean and a variance
     * (rate = mean / variance, shape = mean * rate). The same values are returned for every index.
     */
    static class GammaParametrization
    implements GammaStatisticsProvider {
        private final double rate;
        private final double shape;

        /**
         * @param mean distribution mean; a mean of exactly 0 yields rate 0 and shape -0.5
         *             (presumably intended as an improper/uninformative choice; confirm)
         * @param variance distribution variance
         */
        GammaParametrization(double mean, double variance) {
            if (mean == 0.0) {
                this.rate = 0.0;
                this.shape = -0.5;
            } else {
                this.rate = mean / variance;
                this.shape = mean * this.rate;
            }
        }

        GammaParametrization(Distribution distribution) {
            this(distribution.mean(), distribution.variance());
        }

        double getRate() {
            return this.rate;
        }

        double getShape() {
            return this.shape;
        }

        @Override
        public double getShape(int n) {
            return this.getShape();
        }

        @Override
        public double getRate(int n) {
            return this.getRate();
        }
    }
}

