/*
 * Decompiled with CFR 0.152.
 */
package sampler;

import cern.jet.random.tdouble.Gamma;
import cern.jet.random.tdouble.Normal;
import cern.jet.random.tdouble.engine.DoubleMersenneTwister;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Random;
import org.apache.commons.math3.distribution.BetaDistribution;
import utils.MathUtil;
import utils.ProgressPopup;

/**
 * Metropolis-within-Gibbs sampler over cause assignments, per-sub-population cause-fraction
 * parameters, and a symptom-given-cause probability table ({@link #probbase}).
 *
 * <p>Symptom indicators ({@code indic}) are read with this coding, which is used consistently by
 * {@link #CountCurrent} and {@link #pnb}:
 * <ul>
 *   <li>exactly {@code 1.0}: present</li>
 *   <li>exactly {@code 0.0}: absent</li>
 *   <li>any negative value: missing, ignored</li>
 * </ul>
 */
public class InsilicoSampler {
    /** Number of records (rows of {@code indic}, length of {@code subpop}). */
    public int N;
    /** Number of symptoms (rows of {@link #probbase}). */
    public int S;
    /** Number of causes (columns of {@link #probbase}). */
    public int C;
    /** Number of sub-populations; {@code subpop[n]} indexes into {@code [0, N_sub)}. */
    public int N_sub;
    /** Number of ordered probability levels; level 1 is the highest-probability level. */
    public int N_level;
    /** {@code probbase[s][c]}: probability that symptom s is present given cause c. Updated in place. */
    public double[][] probbase;
    /** {@code probbase_order[s][c]}: 1-based level of the (s, c) cell, stored as a double. */
    public double[][] probbase_order;
    /** Current value of each level; indexed {@code level - 1}. Only resampled in pooled mode. */
    public double[] level_values;
    /** cause -> level -> list of symptoms whose (s, cause) cell is at that level. */
    public HashMap<Integer, HashMap<Integer, ArrayList<Integer>>> probbase_level;
    /** {@code count_m[s][c]}: records assigned to cause c with symptom s present. */
    public int[][] count_m;
    /** {@code count_m_all[s][c]}: records assigned to cause c with symptom s observed (present or absent). */
    public int[][] count_m_all;
    /** {@code count_c[c]}: records assigned to cause c. */
    public int[] count_c;

    /**
     * Stores the model dimensions and tables, resets the count arrays, and builds
     * {@link #probbase_level} from {@code probbase_order}.
     *
     * <p>{@code subpop} is accepted for interface compatibility but is not used here.
     * The {@code probbase} array is stored by reference and is later updated in place.
     */
    public void initiate(int N, int S, int C, int N_sub, int N_level, int[] subpop, double[][] probbase, double[][] probbase_order, double[] level_values) {
        this.N = N;
        this.S = S;
        this.C = C;
        this.N_sub = N_sub;
        this.N_level = N_level;
        this.probbase = probbase;
        this.probbase_order = probbase_order;
        this.count_m = new int[S][C];
        this.count_m_all = new int[S][C];
        this.count_c = new int[C];
        this.level_values = level_values;
        this.probbase_level = new HashMap<Integer, HashMap<Integer, ArrayList<Integer>>>();
        this.levelize();
    }

    /** Groups symptoms by (cause, level); within each list symptoms appear in increasing order. */
    private void levelize() {
        for (int s = 0; s < this.S; ++s) {
            for (int c = 0; c < this.C; ++c) {
                int level = (int) this.probbase_order[s][c];
                HashMap<Integer, ArrayList<Integer>> byLevel = this.probbase_level.get(c);
                if (byLevel == null) {
                    byLevel = new HashMap<Integer, ArrayList<Integer>>();
                    this.probbase_level.put(c, byLevel);
                }
                ArrayList<Integer> members = byLevel.get(level);
                if (members == null) {
                    members = new ArrayList<Integer>();
                    byLevel.put(level, members);
                }
                members.add(s);
            }
        }
    }

    /**
     * Recomputes {@link #count_m}, {@link #count_m_all} and {@link #count_c} from scratch for the
     * cause assignment {@code ynew}.
     *
     * @param indic symptom indicators, {@code indic[n][s]}; 1.0 present, 0.0 absent, else skipped
     * @param ynew  cause index assigned to each of the first {@link #N} records
     */
    public void CountCurrent(double[][] indic, int[] ynew) {
        this.count_m = new int[this.S][this.C];
        this.count_m_all = new int[this.S][this.C];
        this.count_c = new int[this.C];
        for (int n = 0; n < this.N; ++n) {
            int cause = ynew[n];
            this.count_c[cause]++;
            for (int s = 0; s < this.S; ++s) {
                double value = indic[n][s];
                if (value == 1.0) {
                    this.count_m[s][cause]++;
                    this.count_m_all[s][cause]++;
                } else if (value == 0.0) {
                    this.count_m_all[s][cause]++;
                }
            }
        }
    }

    /**
     * Computes, for every record, the normalized probability of each cause.
     *
     * <p>The unnormalized value is {@code csmf_sub[subpop[n]][c]} multiplied, over every symptom
     * with a non-negative indicator, by {@code probbase[s][c]} (indicator &gt; 0) or
     * {@code 1 - probbase[s][c]} (indicator == 0). Each row is then passed through
     * {@code MathUtil.norm}. Missing (negative) indicators are always skipped, regardless of
     * {@code contains_missing}.
     *
     * <p>NOTE(review): the product is taken in linear space and may underflow to zero for many
     * observed symptoms; a log-space version would be more robust.
     *
     * @param contains_missing unused; kept for interface compatibility
     * @param indic            symptom indicators, one row per record
     * @param csmf_sub         per-sub-population cause probabilities
     * @param subpop           sub-population index of each record
     * @return an N x C array of per-record cause probabilities
     */
    public double[][] pnb(boolean contains_missing, double[][] indic, double[][] csmf_sub, int[] subpop) {
        double[][] nb = new double[this.N][this.C];
        for (int n = 0; n < this.N; ++n) {
            double[] row = indic[n];
            // Collect the indices of observed (non-missing) symptoms once per record.
            int[] observed = new int[row.length];
            int nObserved = 0;
            for (int s = 0; s < row.length; ++s) {
                if (row[s] >= 0.0) {
                    observed[nObserved++] = s;
                }
            }
            double[] prior = csmf_sub[subpop[n]];
            for (int c = 0; c < this.C; ++c) {
                double value = prior[c];
                for (int i = 0; i < nObserved; ++i) {
                    int s = observed[i];
                    value *= row[s] > 0.0 ? this.probbase[s][c] : 1.0 - this.probbase[s][c];
                }
                nb[n][c] = value;
            }
            nb[n] = MathUtil.norm(nb[n]);
        }
        return nb;
    }

    /**
     * Draws one cause per record from the categorical distribution in {@code pnb[n]}, using
     * exactly one {@code rand.nextDouble()} per record.
     *
     * <p>If rounding leaves the cumulative sum just below the uniform draw, the record is assigned
     * to the last cause with positive probability. Previously such records silently fell back to
     * cause 0.
     *
     * @param pnb  per-record cause probabilities (rows expected to sum to ~1)
     * @param rand uniform random source
     * @return sampled cause index for each of the first {@link #N} records
     */
    public int[] sampleY(double[][] pnb, Random rand) {
        int[] y = new int[this.N];
        for (int n = 0; n < this.N; ++n) {
            double u = rand.nextDouble();
            double cum = 0.0;
            int chosen = -1;
            int lastPositive = 0;
            for (int c = 0; c < this.C; ++c) {
                double p = pnb[n][c];
                if (p > 0.0) {
                    lastPositive = c;
                }
                cum += p;
                if (u < cum) {
                    chosen = c;
                    break;
                }
            }
            y[n] = chosen >= 0 ? chosen : lastPositive;
        }
        return y;
    }

    /**
     * Performs one block Metropolis-Hastings update of {@code theta}.
     *
     * <p>The proposal keeps {@code theta_new[0] = 1.0} fixed and draws each
     * {@code theta_new[c] ~ Normal(theta[c], jumprange)} for {@code c >= 1}. The log acceptance
     * ratio combines two terms:
     * <ul>
     *   <li>the multinomial log-likelihood of the counts {@code Y} under softmax(theta);</li>
     *   <li>independent Normal(mu[c], sigma2) log-priors.</li>
     * </ul>
     * The normalizing sums assume {@code theta[0] == 1.0}.
     *
     * @param jumprange proposal standard deviation, used for every coordinate
     * @param mu        prior means per cause
     * @param sigma2    prior variance
     * @param theta     current parameters (not modified)
     * @param Y         per-cause counts for this sub-population
     * @param jump_prop unused; adaptive proposal widths are not implemented
     * @param rngN      normal generator for the proposal
     * @param rand      uniform source for the accept/reject draw
     * @return a new array if the proposal is accepted, otherwise the very same {@code theta}
     *         reference (callers may rely on reference identity to count acceptances)
     */
    public double[] thetaBlockUpdate(double jumprange, double[] mu, double sigma2, double[] theta, int[] Y, boolean jump_prop, Normal rngN, Random rand) {
        double[] theta_new = new double[this.C];
        theta_new[0] = 1.0;
        double expsum = Math.exp(1.0);
        double expsum_new = Math.exp(1.0);
        for (int c = 1; c < this.C; ++c) {
            theta_new[c] = rngN.nextDouble(theta[c], jumprange);
            expsum += Math.exp(theta[c]);
            expsum_new += Math.exp(theta_new[c]);
        }
        double logNormRatio = Math.log(expsum_new / expsum);
        double logTrans = 0.0;
        for (int c = 0; c < this.C; ++c) {
            double diffquad = (theta_new[c] - mu[c]) * (theta_new[c] - mu[c]) - (theta[c] - mu[c]) * (theta[c] - mu[c]);
            logTrans += (double) Y[c] * (theta_new[c] - theta[c] - logNormRatio) - 1.0 / (2.0 * sigma2) * diffquad;
        }
        double u = Math.log(rand.nextDouble());
        // A NaN log ratio compares false and is therefore rejected.
        return logTrans >= u ? theta_new : theta;
    }

    /**
     * Resamples every (symptom, cause) cell of {@link #probbase} from a truncated Beta
     * distribution, cause by cause, respecting the level ordering.
     *
     * <p>For a cell in level {@code l}, the truncation interval is:
     * <ul>
     *   <li>lower bound: the maximum of {@code trunc_min} and the <em>current</em> values of the
     *       next existing level;</li>
     *   <li>upper bound: the minimum of {@code trunc_max} and the <em>freshly sampled</em> values
     *       of the previous existing level.</li>
     * </ul>
     * A cause with a single existing level is bounded only by {@code [trunc_min, trunc_max]}.
     * Previously this case threw {@code IndexOutOfBoundsException}. If the interval is empty, the
     * cell is set to its upper bound.
     *
     * <p>Beta parameters are {@code a = prior_a[l-1] + count} and
     * {@code b = prior_b + count_all - a}. NOTE(review): this treats {@code prior_b} as a total
     * pseudo-count (so a + b = prior_b + count_all); confirm with callers. {@code BetaDistribution}
     * rejects non-positive a or b.
     */
    public void TruncBeta(Random rand, double[] prior_a, double prior_b, double trunc_min, double trunc_max) {
        for (int c = 0; c < this.C; ++c) {
            HashMap<Integer, ArrayList<Integer>> levels_under_c = this.probbase_level.get(c);
            if (levels_under_c == null) {
                // Only possible when S == 0: no cell was levelized for this cause.
                continue;
            }
            double[] prob_under_c = MathUtil.grab2(this.probbase, c);
            double[] new_prob_under_c = new double[this.S];
            ArrayList<Integer> exist_levels = new ArrayList<Integer>();
            for (int l = 1; l <= this.N_level; ++l) {
                if (levels_under_c.get(l) != null) {
                    exist_levels.add(l);
                }
            }
            for (int index = 0; index < exist_levels.size(); ++index) {
                int l_current = exist_levels.get(index);
                // Bounds are invariant across the symptoms of one level: each symptom belongs to
                // exactly one level per cause, so the previous level's new values are fixed here.
                double lower = trunc_min;
                double upper = trunc_max;
                if (index + 1 < exist_levels.size()) {
                    int l_next = exist_levels.get(index + 1);
                    lower = Math.max(MathUtil.array_max(prob_under_c, levels_under_c.get(l_next)), trunc_min);
                }
                if (index > 0) {
                    int l_prev = exist_levels.get(index - 1);
                    upper = Math.min(MathUtil.array_min(new_prob_under_c, levels_under_c.get(l_prev)), trunc_max);
                }
                for (int s : levels_under_c.get(l_current)) {
                    if (lower >= upper) {
                        new_prob_under_c[s] = upper;
                        continue;
                    }
                    double a = prior_a[l_current - 1] + (double) this.count_m[s][c];
                    double b = prior_b + (double) this.count_m_all[s][c] - a;
                    BetaDistribution beta = new BetaDistribution(a, b, 1.0E-10);
                    new_prob_under_c[s] = MathUtil.truncbeta(beta, rand, lower, upper);
                }
            }
            for (int s = 0; s < this.S; ++s) {
                this.probbase[s][c] = new_prob_under_c[s];
            }
        }
    }

    /**
     * Pooled variant of {@link #TruncBeta}: resamples one value per level, using counts pooled
     * over all (symptom, cause) cells at that level, and then rewrites every
     * {@code probbase[s][c]} as {@code level_values[probbase_order[s][c] - 1]}.
     *
     * <p>Level {@code l} is bounded below by the current value of level {@code l+1} (if any) and
     * above by the freshly sampled value of level {@code l-1} (if any), clipped to
     * {@code [trunc_min, trunc_max]}. With {@code N_level == 1} the level is bounded only by the
     * clip range; previously this read past the end of {@code level_values}.
     */
    public void TruncBeta_pool(Random rand, double[] prior_a, double prior_b, double trunc_min, double trunc_max) {
        double[] new_level_values = new double[this.N_level];
        for (int l = 1; l <= this.N_level; ++l) {
            int count = 0;
            int count_all = 0;
            for (int c = 0; c < this.C; ++c) {
                HashMap<Integer, ArrayList<Integer>> levels_under_c = this.probbase_level.get(c);
                if (levels_under_c == null || levels_under_c.get(l) == null) {
                    continue;
                }
                for (int s : levels_under_c.get(l)) {
                    count += this.count_m[s][c];
                    count_all += this.count_m_all[s][c];
                }
            }
            double lower = trunc_min;
            double upper = trunc_max;
            if (l < this.N_level) {
                lower = Math.max(this.level_values[l], trunc_min);
            }
            if (l > 1) {
                upper = Math.min(new_level_values[l - 2], trunc_max);
            }
            if (lower >= upper) {
                new_level_values[l - 1] = upper;
                continue;
            }
            double a = prior_a[l - 1] + (double) count;
            double b = prior_b + (double) count_all - a;
            BetaDistribution beta = new BetaDistribution(a, b, 1.0E-10);
            new_level_values[l - 1] = MathUtil.truncbeta(beta, rand, lower, upper);
        }
        this.level_values = new_level_values;
        for (int s = 0; s < this.S; ++s) {
            for (int c = 0; c < this.C; ++c) {
                this.probbase[s][c] = this.level_values[(int) this.probbase_order[s][c] - 1];
            }
        }
    }

    /** Writes softmax(theta[0..C-1]) into p[0..C-1]. */
    private static void softmaxInto(double[] theta, double[] p, int C) {
        double expsum = 0.0;
        for (int c = 0; c < C; ++c) {
            expsum += Math.exp(theta[c]);
        }
        for (int c = 0; c < C; ++c) {
            p[c] = Math.exp(theta[c]) / expsum;
        }
    }

    /**
     * Runs the full sampler for {@code N_gibbs} iterations and returns all saved draws packed into
     * one flat array.
     *
     * <p>After {@code burn} iterations, every {@code thin}-th iteration is saved. Output layout, in
     * order:
     * <ol>
     *   <li>{@code N_thin};</li>
     *   <li>cause probabilities {@code p}, looped sub &gt; draw &gt; cause;</li>
     *   <li>per-record cause probabilities averaged over saved draws, looped record &gt; cause
     *       (NaN if {@code N_thin == 0});</li>
     *   <li>either the {@code probbase} draws (looped cause &gt; symptom &gt; draw) when
     *       {@code pool != 1}, or the level draws (looped draw &gt; level) when
     *       {@code pool == 1};</li>
     *   <li>final {@code mu} (sub &gt; cause);</li>
     *   <li>final {@code sigma2} (per sub);</li>
     *   <li>final {@code theta} (sub &gt; cause).</li>
     * </ol>
     *
     * <p>The caller's {@code mu}, {@code mu_continue} and {@code theta_continue} arrays are copied,
     * not mutated. Previously all sub-populations shared one {@code mu} array, so the returned
     * {@code mu} of every sub-population equalled the last one's. {@code probbase} is still
     * updated in place.
     *
     * @throws IllegalArgumentException if {@code thin < 1}
     */
    public static double[] Fit(int N, int S, int C, int N_sub, int N_level, double[][] probbase, double[][] probbase_order, double[] level_values, double[] prior_a, double prior_b, double jumprange, double trunc_min, double trunc_max, double[][] indic, int[] subpop, int contains_missing, int pool, int seed, int N_gibbs, int burn, int thin, double[] mu, double sigma2, boolean this_is_Unix, boolean useProbbase, boolean isAdded, double[][] mu_continue, double[] sigma2_continue, double[][] theta_continue) {
        if (thin < 1) {
            throw new IllegalArgumentException("thin must be at least 1, got " + thin);
        }
        boolean missing = contains_missing == 1;
        boolean pooled = pool == 1;
        InsilicoSampler insilico = new InsilicoSampler();
        insilico.initiate(N, S, C, N_sub, N_level, subpop, probbase, probbase_order, level_values);
        System.out.printf("Insilico Sampler initiated, %d iterations to sample\n", N_gibbs);
        DoubleMersenneTwister rngEngine = new DoubleMersenneTwister(seed);
        Normal rngN = new Normal(0.0, 1.0, rngEngine);
        Gamma rngG = new Gamma(1.0, 1.0, rngEngine);
        Random rand = new Random(seed);
        int N_thin = (int) ((double) (N_gibbs - burn) / (double) thin);
        int n_report = N_gibbs < 200 ? 50 : Math.max(N_gibbs / 20, 100);

        // Only the trace used by the chosen mode is allocated; the probbase trace can be large.
        double[][][] probbase_gibbs = pooled ? new double[0][][] : new double[N_thin][S][C];
        double[][] levels_gibbs = pooled ? new double[N_thin][N_level] : new double[0][];
        double[][][] p_gibbs = new double[N_thin][N_sub][C];
        double[][] pnb_mean = new double[N][C];
        int[] naccept = new int[N_sub];

        // Per-sub-population state; each row is its own array so updates never alias.
        double[][] mu_now = new double[N_sub][];
        double[] sigma2_now = new double[N_sub];
        double[][] theta_now = new double[N_sub][];
        double[][] p_now = new double[N_sub][C];
        for (int sub = 0; sub < N_sub; ++sub) {
            if (isAdded) {
                mu_now[sub] = mu_continue[sub].clone();
                sigma2_now[sub] = sigma2_continue[sub];
                theta_now[sub] = theta_continue[sub].clone();
            } else {
                mu_now[sub] = mu.clone();
                sigma2_now[sub] = sigma2;
                theta_now[sub] = new double[C];
                theta_now[sub][0] = 1.0;
                for (int c = 1; c < C; ++c) {
                    theta_now[sub][c] = Math.log(rand.nextDouble() * 100.0);
                }
            }
            softmaxInto(theta_now[sub], p_now[sub], C);
        }

        double[][] pnb = insilico.pnb(missing, indic, p_now, subpop);
        long start = System.currentTimeMillis();
        ProgressPopup popup = new ProgressPopup(this_is_Unix, N_gibbs);
        try {
            for (int k = 0; k < N_gibbs; ++k) {
                if (!this_is_Unix) {
                    popup.update(k);
                }
                int[] y_new = insilico.sampleY(pnb, rand);
                int[][] Y = new int[N_sub][C];
                for (int n = 0; n < N; ++n) {
                    Y[subpop[n]][y_new[n]]++;
                }
                for (int sub = 0; sub < N_sub; ++sub) {
                    double[] theta_prev = theta_now[sub];
                    double[] mu_sub = mu_now[sub];

                    // mu: common mean drawn from Normal(mean(theta), sigma2 / C).
                    double mu_mean = 0.0;
                    for (int c = 0; c < C; ++c) {
                        mu_mean += theta_prev[c];
                    }
                    mu_mean /= (double) C;
                    mu_mean = rngN.nextDouble(mu_mean, Math.sqrt(sigma2_now[sub] / (double) C));
                    for (int c = 0; c < C; ++c) {
                        mu_sub[c] = mu_mean;
                    }

                    // sigma2: reciprocal of a Gamma(shape, rate) draw.
                    double shape = ((double) C - 1.0) / 2.0;
                    double rate2 = 0.0;
                    for (int c = 0; c < C; ++c) {
                        rate2 += Math.pow(theta_prev[c] - mu_sub[c], 2.0);
                    }
                    sigma2_now[sub] = 1.0 / rngG.nextDouble(shape, rate2 / 2.0);

                    // theta: Metropolis-Hastings block update; a rejection returns the same array.
                    theta_now[sub] = insilico.thetaBlockUpdate(jumprange, mu_sub, sigma2_now[sub], theta_prev, Y[sub], false, rngN, rand);
                    if (theta_now[sub] != theta_prev) {
                        naccept[sub]++;
                    }
                    softmaxInto(theta_now[sub], p_now[sub], C);
                }
                if (!useProbbase) {
                    insilico.CountCurrent(indic, y_new);
                    if (pooled) {
                        insilico.TruncBeta_pool(rand, prior_a, prior_b, trunc_min, trunc_max);
                    } else {
                        insilico.TruncBeta(rand, prior_a, prior_b, trunc_min, trunc_max);
                    }
                }
                pnb = insilico.pnb(missing, indic, p_now, subpop);
                if (k % 10 == 0) {
                    System.out.print(".");
                }
                if (k != 0 && k % n_report == 0) {
                    long now = System.currentTimeMillis();
                    StringBuilder report = new StringBuilder(String.format("\nIteration: %d \n", k));
                    for (int sub = 0; sub < N_sub; ++sub) {
                        double ratio = (double) naccept[sub] / (double) k;
                        report.append(String.format("Sub-population %d acceptance ratio: %.2f \n", sub, ratio));
                    }
                    String message = report.toString();
                    System.out.print(message);
                    double elapsedMin = (double) (now - start) / 1000.0 / 60.0;
                    System.out.printf("%.2fmin elapsed, %.2fmin remaining \n", elapsedMin, elapsedMin / (double) k * (double) (N_gibbs - k));
                    if (!this_is_Unix) {
                        popup.message(k, message);
                    }
                }
                if (k < burn || (k - burn + 1) % thin != 0) {
                    continue;
                }
                int save = (k - burn + 1) / thin - 1;
                for (int n = 0; n < N; ++n) {
                    for (int c = 0; c < C; ++c) {
                        pnb_mean[n][c] += pnb[n][c];
                    }
                }
                for (int sub = 0; sub < N_sub; ++sub) {
                    System.arraycopy(p_now[sub], 0, p_gibbs[save][sub], 0, C);
                }
                if (pooled) {
                    System.arraycopy(insilico.level_values, 0, levels_gibbs[save], 0, N_level);
                } else {
                    for (int s = 0; s < S; ++s) {
                        System.arraycopy(insilico.probbase[s], 0, probbase_gibbs[save][s], 0, C);
                    }
                }
            }
        } finally {
            // Close the popup even if sampling fails part-way.
            if (!this_is_Unix) {
                popup.close();
            }
        }

        System.out.println("\nOverall acceptance ratio");
        for (int sub = 0; sub < N_sub; ++sub) {
            double ratio = (double) naccept[sub] / (double) N_gibbs;
            System.out.printf("Sub-population %d : %.4f \n", sub, ratio);
        }

        int N_out = 1 + N_sub * C * N_thin + N * C + N_sub * (C * 2 + 1);
        N_out += pooled ? N_level * N_thin : S * C * N_thin;
        double[] parameters = new double[N_out];
        parameters[0] = N_thin;
        int counter = 1;
        for (int sub = 0; sub < N_sub; ++sub) {
            for (int k = 0; k < N_thin; ++k) {
                for (int c = 0; c < C; ++c) {
                    parameters[counter++] = p_gibbs[k][sub][c];
                }
            }
        }
        for (int n = 0; n < N; ++n) {
            for (int c = 0; c < C; ++c) {
                parameters[counter++] = pnb_mean[n][c] / (double) N_thin;
            }
        }
        if (!pooled) {
            for (int c = 0; c < C; ++c) {
                for (int s = 0; s < S; ++s) {
                    for (int k = 0; k < N_thin; ++k) {
                        parameters[counter++] = probbase_gibbs[k][s][c];
                    }
                }
            }
        } else {
            for (int k = 0; k < N_thin; ++k) {
                for (int l = 0; l < N_level; ++l) {
                    parameters[counter++] = levels_gibbs[k][l];
                }
            }
        }
        for (int sub = 0; sub < N_sub; ++sub) {
            for (int c = 0; c < C; ++c) {
                parameters[counter++] = mu_now[sub][c];
            }
        }
        for (int sub = 0; sub < N_sub; ++sub) {
            parameters[counter++] = sigma2_now[sub];
        }
        for (int sub = 0; sub < N_sub; ++sub) {
            for (int c = 0; c < C; ++c) {
                parameters[counter++] = theta_now[sub][c];
            }
        }
        System.out.println("Organizing output, might take a moment...");
        return parameters;
    }
}

