/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.model;

import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.CrossValidationProvider;
import dr.inference.model.MatrixParameterInterface;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Cross-validation provider that compares the observed trait data of an
 * {@link IntegratedFactorAnalysisLikelihood} with the trait values implied by the
 * current latent factors and loadings.
 *
 * <p>For taxon {@code i} the inferred trait vector is {@code L^T f_i}, where {@code L} is the
 * loadings viewed as an {@code nFac x nTrait} matrix and {@code f_i} is that taxon's block of
 * {@code nFac} factor values read from the tree trait {@code "tip." + traitName}.
 * Both {@link #getTrueValues()} and {@link #getInferredValues()} use the same taxon-major
 * layout: index {@code i * nTrait + j} holds trait {@code j} of taxon {@code i}.
 */
public class FactorValidationProvider
implements CrossValidationProvider {
    /** Observed data; sub-parameter {@code i} holds the {@code nTrait} values of taxon {@code i}. */
    private final CompoundParameter data;
    /** Factor loadings, copied into an {@code nFac x nTrait} matrix on every inference. */
    private final MatrixParameterInterface loadings;
    private final int nFac;
    private final int nTrait;
    private final int nTaxa;
    /** Tip trait supplying the latent factor values, {@code nFac} consecutive entries per taxon. */
    private final TreeTrait treeTrait;
    private final Tree tree;
    /** Prefix used for generated dimension names; never null (defaults to {@link #FACTOR_VALIDATION}). */
    private final String id;
    private static final String TRAIT = "traitName";
    private static final String FACTOR_VALIDATION = "factorValidation";
    /** Prefix that turns the user-supplied trait name into the tree-trait key. */
    private static final String TIP_PREFIX = "tip.";
    public static AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser(){

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            IntegratedFactorAnalysisLikelihood factorLikelihood =
                    (IntegratedFactorAnalysisLikelihood) xo.getChild(IntegratedFactorAnalysisLikelihood.class);
            TreeDataLikelihood treeLikelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);
            String traitName = xo.getStringAttribute(TRAIT);

            // Reject an unknown trait here, with a message naming it, instead of letting
            // getInferredValues() fail later on a null tree trait.
            // NOTE(review): assumes getTreeTrait returns null for an unknown key — confirm.
            if (treeLikelihood.getTreeTrait(TIP_PREFIX + traitName) == null) {
                throw new XMLParseException("Tree trait '" + TIP_PREFIX + traitName
                        + "' is not provided by the tree data likelihood");
            }

            String id = xo.hasId() ? xo.getId() : null;
            return new FactorValidationProvider(factorLikelihood, treeLikelihood, traitName, id);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return new XMLSyntaxRule[]{
                    new ElementRule(IntegratedFactorAnalysisLikelihood.class),
                    new ElementRule(TreeDataLikelihood.class),
                    AttributeRule.newStringRule(TRAIT)
            };
        }

        @Override
        public String getParserDescription() {
            return "Cross-validation between the latent factors and the observed data";
        }

        @Override
        public Class getReturnType() {
            return FactorValidationProvider.class;
        }

        @Override
        public String getParserName() {
            return FACTOR_VALIDATION;
        }
    };

    /**
     * Creates a provider over the given factor model and tree likelihood.
     *
     * @param factorLikelihood source of the observed data, loadings and model dimensions
     * @param treeLikelihood   source of the tree and of the tip trait {@code "tip." + traitName}
     * @param traitName        name of the latent-factor trait (without the {@code "tip."} prefix)
     * @param id               name prefix for reported dimensions; {@code null} selects
     *                         {@code "factorValidation"}
     */
    FactorValidationProvider(IntegratedFactorAnalysisLikelihood factorLikelihood,
                             TreeDataLikelihood treeLikelihood,
                             String traitName,
                             String id) {
        this.data = factorLikelihood.getParameter();
        this.loadings = factorLikelihood.getLoadings();
        this.nFac = factorLikelihood.getNumberOfFactors();
        this.nTrait = factorLikelihood.getNumberOfTraits();
        this.nTaxa = factorLikelihood.getNumberOfTaxa();
        this.treeTrait = treeLikelihood.getTreeTrait(TIP_PREFIX + traitName);
        this.tree = treeLikelihood.getTree();
        this.id = id == null ? FACTOR_VALIDATION : id;
    }

    /**
     * Returns the observed data flattened taxon-major.
     *
     * @return array of length {@code nTaxa * nTrait}; entry {@code i * nTrait + j} is trait
     *         {@code j} of taxon {@code i}
     */
    @Override
    public double[] getTrueValues() {
        double[] values = new double[this.nTrait * this.nTaxa];
        for (int taxon = 0; taxon < this.nTaxa; ++taxon) {
            int offset = this.nTrait * taxon;
            for (int trait = 0; trait < this.nTrait; ++trait) {
                values[offset + trait] = this.data.getParameter(taxon).getParameterValue(trait);
            }
        }
        return values;
    }

    /**
     * Returns the trait values implied by the current factors, {@code L^T f_i} for each taxon,
     * using the same taxon-major layout as {@link #getTrueValues()}.
     *
     * @return array of length {@code nTaxa * nTrait}
     */
    @Override
    public double[] getInferredValues() {
        double[] inferred = new double[this.nTrait * this.nTaxa];

        // EJML's DenseMatrix64F stores its data row-major, so the flat loadings fill an
        // nFac x nTrait matrix row by row.
        // NOTE(review): assumes the ordering of loadings.getParameterValues() matches that
        // layout — TODO confirm against the loadings parameter definition.
        DenseMatrix64F loadingsMatrix = new DenseMatrix64F(this.nFac, this.nTrait);
        System.arraycopy(this.loadings.getParameterValues(), 0, loadingsMatrix.data, 0, this.loadings.getDimension());

        // Factor values for all tips, nFac consecutive entries per taxon.
        double[] factors = (double[]) this.treeTrait.getTrait(this.tree, null);

        DenseMatrix64F taxonFactors = new DenseMatrix64F(this.nFac, 1);
        DenseMatrix64F taxonTraits = new DenseMatrix64F(this.nTrait, 1);
        for (int taxon = 0; taxon < this.nTaxa; ++taxon) {
            System.arraycopy(factors, this.nFac * taxon, taxonFactors.data, 0, this.nFac);
            CommonOps.multTransA(loadingsMatrix, taxonFactors, taxonTraits);
            System.arraycopy(taxonTraits.data, 0, inferred, this.nTrait * taxon, this.nTrait);
        }
        return inferred;
    }

    /**
     * Returns every dimension index, {@code 0 .. nTaxa * nTrait - 1}, in order.
     *
     * @return identity index array of length {@code nTaxa * nTrait}
     */
    @Override
    public int[] getRelevantDimensions() {
        int count = this.nTaxa * this.nTrait;
        int[] dimensions = new int[count];
        for (int i = 0; i < count; ++i) {
            dimensions[i] = i;
        }
        return dimensions;
    }

    /**
     * Returns the name of flat dimension {@code n}.
     *
     * @param n index into the taxon-major layout
     * @return {@code id + "." + taxonId + "." + traitIndex}
     */
    @Override
    public String getName(int n) {
        // id is always non-null (the constructor substitutes FACTOR_VALIDATION), so no fallback is needed.
        int taxon = n / this.nTrait;
        int trait = n - taxon * this.nTrait;
        return this.id + "." + this.tree.getTaxonId(taxon) + "." + trait;
    }

    /**
     * Returns the name used for the summed statistic.
     *
     * @param n ignored
     * @return {@code id + ".sum"}
     */
    @Override
    public String getNameSum(int n) {
        return this.id + ".sum";
    }
}

