/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.inference.messagepassing;

import eu.amidst.core.distribution.UnivariateDistribution;
import eu.amidst.core.exponentialfamily.EF_BayesianNetwork;
import eu.amidst.core.exponentialfamily.EF_ConditionalDistribution;
import eu.amidst.core.exponentialfamily.EF_UnivariateDistribution;
import eu.amidst.core.inference.InferenceAlgorithm;
import eu.amidst.core.inference.messagepassing.Message;
import eu.amidst.core.inference.messagepassing.Node;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.utils.Vector;
import eu.amidst.core.variables.Assignment;
import eu.amidst.core.variables.HashMapAssignment;
import eu.amidst.core.variables.Variable;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for message-passing inference over an {@link EF_BayesianNetwork}.
 *
 * <p>The class builds one {@link Node} per distribution of the exponential-family model, links
 * every node to the nodes of its conditioning variables (parents) and registers it as a child of
 * each of them, and drives the sweep loop in {@link #runInference()}. Subclasses decide how
 * messages are created ({@link #newSelfMessage}, {@link #newMessageToParent}), how a node consumes
 * its combined message ({@link #updateCombinedMessage}) and when to stop ({@link #testConvergence}).
 *
 * <p>{@code nodes} and {@code variablesToNode} are {@code transient} and this class does not
 * rebuild them after Java deserialization; call {@link #setModel}, {@link #setEFModel} or
 * {@link #setNodes} again before running inference on a deserialized instance.
 *
 * @param <E> the {@link Vector} type parameter of the {@link Message}s exchanged between nodes
 */
public abstract class MessagePassingAlgorithm<E extends Vector>
        implements InferenceAlgorithm, Serializable {
    /** Logger for this class; package-private as before, so classes in this package may use it. */
    static final Logger logger = LoggerFactory.getLogger(MessagePassingAlgorithm.class);
    private static final long serialVersionUID = 4107783324901370839L;

    /** Model passed to {@link #setModel}; returned by {@link #getOriginalModel()}. */
    protected BayesianNetwork model;
    /** Exponential-family model whose distributions back the nodes built by {@link #setEFModel}. */
    protected EF_BayesianNetwork ef_model;
    /** Current evidence; forwarded to every node by {@link #setEvidence}. */
    protected Assignment assignment = new HashMapAssignment(0);
    /** Nodes visited, in list order, by each sweep of {@link #runInference()}. */
    protected transient List<Node> nodes;
    /** Lookup from a variable to its node; backs {@link #getNodeOfVar}. */
    protected transient Map<Variable, Node> variablesToNode;
    /** Copy of {@link #local_elbo} taken at the end of {@link #runInference()}; NaN before any run. */
    protected double probOfEvidence = Double.NaN;
    /** Randomness passed to each node by {@link #resetQs()}; recreated by {@link #setSeed}. */
    protected Random random = new Random(0L);
    /** Last seed given to {@link #setSeed}. */
    protected int seed = 0;
    /** Maximum number of sweeps performed by {@link #runInference()}. */
    protected int maxIter = 1000;
    /** Convergence threshold; not read in this class (presumably used by subclasses' testConvergence). */
    protected double threshold = 1.0E-6;
    /** When true, {@link #runInference()} prints and logs the sweep count and final local_elbo. */
    protected boolean output = false;
    /** Number of sweeps performed by the last {@link #runInference()}. */
    protected int nIter = 0;
    /**
     * Running objective value. This class only resets it (to negative infinity) at the start of
     * {@link #runInference()} and copies it into {@link #probOfEvidence} at the end; presumably
     * subclasses update it during the sweeps (e.g. in {@link #testConvergence()}) — confirm there.
     */
    protected double local_elbo = -Double.MAX_VALUE;
    /** Sweep counter of the current run; equals k while the k-th sweep (1-based) executes. */
    protected int local_iter = 0;

    /** Enables or disables the end-of-run report printed by {@link #runInference()}. */
    public void setOutput(boolean output) {
        this.output = output;
    }

    /** Returns whether {@link #runInference()} prints its end-of-run report. */
    public boolean isOutput() {
        return this.output;
    }

    /** Sets the convergence threshold stored in {@link #threshold}. */
    public void setThreshold(double threshold) {
        this.threshold = threshold;
    }

    /** Returns the convergence threshold stored in {@link #threshold}. */
    public double getThreshold() {
        return this.threshold;
    }

    /** Sets the maximum number of sweeps; a value {@code <= 0} makes {@link #runInference()} do none. */
    public void setMaxIter(int maxIter) {
        this.maxIter = maxIter;
    }

    /** Returns the maximum number of sweeps performed by {@link #runInference()}. */
    public int getMaxIter() {
        return this.maxIter;
    }

    /** Asks every node to reset its Q distribution, passing the shared {@link #random}. */
    public void resetQs() {
        this.nodes.forEach(node -> node.resetQDist(this.random));
    }

    /** Stores the seed and replaces {@link #random} with a generator seeded by it. */
    @Override
    public void setSeed(int seed) {
        this.seed = seed;
        this.random = new Random(seed);
    }

    /**
     * Runs the message-passing schedule.
     *
     * <p>Each sweep visits every node that is active and not observed, in {@link #nodes} order. For
     * each such node it builds the self message ({@link #newSelfMessage}), combines into it the
     * messages from the node's active children ({@link #newMessageToParent}, reduced with
     * {@code Message.combineNonStateless}), and passes the result to {@link #updateCombinedMessage}.
     *
     * <p>Iteration stops when {@link #testConvergence()} returns true, when every visited node
     * reports {@code isDone()}, or after {@link #maxIter} sweeps. Afterwards
     * {@link #getLogProbabilityOfEvidence()} returns {@link #local_elbo} and
     * {@link #getNumberOfIterations()} returns the number of sweeps actually performed.
     */
    @Override
    public void runInference() {
        this.nIter = 0;
        this.local_elbo = Double.NEGATIVE_INFINITY;
        this.local_iter = 0;
        boolean convergence = false;
        // The counter is advanced inside the body (not in the loop condition) so that it never
        // overshoots: previously a run that exhausted maxIter reported maxIter + 1 sweeps.
        while (!convergence && this.local_iter < this.maxIter) {
            this.local_iter++;
            boolean allNodesDone = true;
            for (Node node : this.nodes) {
                if (!node.isActive() || node.isObserved()) {
                    continue;
                }
                Message<E> selfMessage = this.newSelfMessage(node);
                Optional<Message<E>> childrenMessage = node.getChildren().stream()
                        .filter(Node::isActive)
                        .map(child -> this.newMessageToParent(child, node))
                        .reduce(Message::combineNonStateless);
                childrenMessage.ifPresent(selfMessage::combine);
                this.updateCombinedMessage(node, selfMessage);
                allNodesDone &= node.isDone();
            }
            // testConvergence() is evaluated first so it still runs on every sweep, even when all
            // nodes are already done; subclasses may depend on that call having side effects.
            convergence = this.testConvergence() || allNodesDone;
        }
        this.probOfEvidence = this.local_elbo;
        if (this.output) {
            System.out.println("N Iter: " + this.local_iter + ", elbo:" + this.local_elbo);
            logger.info("N Iter: {}, elbo: {}", this.local_iter, this.local_elbo);
        }
        this.nIter = this.local_iter;
    }

    /** Returns the number of sweeps performed by the last {@link #runInference()} call. */
    public int getNumberOfIterations() {
        return this.nIter;
    }

    /** Stores the model and rebuilds the nodes from its exponential-family counterpart. */
    @Override
    public void setModel(BayesianNetwork model_) {
        this.model = model_;
        this.setEFModel(new EF_BayesianNetwork(this.model));
    }

    /**
     * Replaces the exponential-family model: creates one {@link Node} per distribution, indexes it
     * by the distribution's variable, and links parents/children between the new nodes.
     */
    public void setEFModel(EF_BayesianNetwork model) {
        this.ef_model = model;
        this.variablesToNode = new ConcurrentHashMap<>();
        // Nodes are indexed as they are created so the linking step can resolve every
        // conditioning variable to its node.
        this.nodes = this.ef_model.getDistributionList().stream()
                .map(dist -> {
                    Node node = new Node((EF_ConditionalDistribution) dist);
                    this.variablesToNode.put(dist.getVariable(), node);
                    return node;
                })
                .collect(Collectors.toList());
        this.linkParentsAndChildren();
    }

    /** Returns the exponential-family model last set via {@link #setEFModel} or {@link #setModel}. */
    public EF_BayesianNetwork getEFModel() {
        return this.ef_model;
    }

    /** Returns the node registered for {@code variable}, or {@code null} if there is none. */
    public Node getNodeOfVar(Variable variable) {
        return this.variablesToNode.get(variable);
    }

    /** Returns the live node list (not a copy). */
    public List<Node> getNodes() {
        return this.nodes;
    }

    /**
     * Replaces the node list and re-indexes it by each node's main variable. Parent/child links
     * are not touched; call {@link #updateChildrenAndParents()} to establish them.
     */
    public void setNodes(List<Node> nodes) {
        this.nodes = nodes;
        this.variablesToNode = new ConcurrentHashMap<>();
        for (Node node : nodes) {
            this.variablesToNode.put(node.getMainVariable(), node);
        }
    }

    /**
     * Sets each node's parents from its conditioning variables and appends the node to each
     * parent's children. NOTE: children lists are appended to, never cleared, so calling this
     * twice on the same nodes registers every child twice.
     */
    public void updateChildrenAndParents() {
        this.linkParentsAndChildren();
    }

    /**
     * Shared wiring step for {@link #setEFModel} and {@link #updateChildrenAndParents()}: resolves
     * each node's conditioning variables through {@link #getNodeOfVar}, stores them as the node's
     * parents and adds the node to every parent's children.
     */
    private void linkParentsAndChildren() {
        for (Node node : this.nodes) {
            node.setParents(node.getPDist().getConditioningVariables().stream()
                    .map(this::getNodeOfVar)
                    .collect(Collectors.toList()));
            node.getPDist().getConditioningVariables()
                    .forEach(var -> this.getNodeOfVar(var).getChildren().add(node));
        }
    }

    /** Returns the model passed to {@link #setModel}. */
    @Override
    public BayesianNetwork getOriginalModel() {
        return this.model;
    }

    /** Stores the evidence and forwards it to every node. */
    @Override
    public void setEvidence(Assignment assignment_) {
        this.assignment = assignment_;
        this.nodes.forEach(node -> node.setAssignment(this.assignment));
    }

    /**
     * Returns the Q distribution of {@code var}'s node, converted to a univariate distribution.
     *
     * @throws NullPointerException if no node is registered for {@code var}
     */
    @Override
    public <D extends UnivariateDistribution> D getPosterior(Variable var) {
        return this.requireNodeOfVar(var).getQDist().toUnivariateDistribution();
    }

    /** Returns the value stored by the last {@link #runInference()}; NaN if it never ran. */
    @Override
    public double getLogProbabilityOfEvidence() {
        return this.probOfEvidence;
    }

    /**
     * Returns the exponential-family Q distribution of {@code var}'s node, cast to the caller's
     * expected type (the cast is unchecked).
     *
     * @throws NullPointerException if no node is registered for {@code var}
     */
    @SuppressWarnings("unchecked")
    public <D extends EF_UnivariateDistribution> D getEFPosterior(Variable var) {
        return (D) this.requireNodeOfVar(var).getQDist();
    }

    /** Like {@link #getNodeOfVar} but fails with a descriptive NullPointerException when absent. */
    private Node requireNodeOfVar(Variable variable) {
        Node node = this.getNodeOfVar(variable);
        if (node == null) {
            throw new NullPointerException("No node registered for variable " + variable);
        }
        return node;
    }

    /**
     * Creates {@code node}'s own message; {@link #runInference()} combines the messages from the
     * node's active children into it.
     */
    public abstract Message<E> newSelfMessage(Node node);

    /** Creates the message sent from {@code child} to its active parent {@code parent}. */
    public abstract Message<E> newMessageToParent(Node child, Node parent);

    /** Consumes {@code node}'s self message combined with its children's messages. */
    public abstract void updateCombinedMessage(Node node, Message<E> message);

    /** Called once after every sweep of {@link #runInference()}; returning true stops iterating. */
    public abstract boolean testConvergence();

    /** Not called in this class; presumably computes the value subclasses keep in local_elbo. */
    public abstract double computeLogProbabilityOfEvidence();
}

