/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * AODEsr: Averaged One-Dependence Estimators (AODE) with Subsumption Resolution.
 *
 * <p>At classification time, pairs of attribute values where one value is a specialization
 * of another are detected from the training counts, and the generalization value is left
 * out of the estimate (see {@link #distributionForInstance(Instance)}). Every remaining
 * attribute value whose training frequency reaches the frequency limit acts as a
 * super-parent, and the per-class estimates of all super-parents are averaged. If no
 * attribute qualifies as a parent, a naive Bayes estimate is used instead.
 *
 * <p>Valid options (see {@link #listOptions()}):
 * <pre>
 * -D         debug flag (stored, but not otherwise used by this class)
 * -C num     critical value for the specialization-generalization relationship (default 50)
 * -F num     frequency limit for super-parents (default 1)
 * -L         use Laplace estimation instead of m-estimation
 * -M num     weight for m-estimation (default 1.0); rejected together with -L
 * </pre>
 *
 * <p>Only nominal attributes and a nominal class are supported; missing values are allowed.
 * The model can be extended incrementally through {@link #updateClassifier(Instance)}.
 */
public class AODEsr
extends Classifier
implements OptionHandler,
WeightedInstancesHandler,
UpdateableClassifier,
TechnicalInformationHandler {
    static final long serialVersionUID = 5602143019183068848L;

    /*
     * Value-slot layout: each non-class attribute owns numValues + 1 consecutive slots,
     * starting at m_StartAttIndex[att]. The last slot of an attribute records missing values.
     * Field names are kept unchanged for serialization compatibility.
     */

    /** Weighted joint counts indexed [class][value slot][value slot]. */
    private double[][][] m_CondiCounts;
    /** Weighted joint counts indexed [value slot][value slot], ignoring the class. */
    private double[][] m_CondiCountsNoClass;
    /** Total training weight per class. */
    private double[] m_ClassCounts;
    /** Per [class][attribute]: training weight of instances whose value for the attribute is present. */
    private double[][] m_SumForCounts;
    private int m_NumClasses;
    private int m_NumAttributes;
    /** Number of training instances (with a class value) absorbed into the counts. */
    private int m_NumInstances;
    private int m_ClassIndex;
    /** Header of the training data; null until a model is built, holds no instances afterwards. */
    private Instances m_Instances;
    /** Total number of value slots over all non-class attributes. */
    private int m_TotalAttValues;
    /** Offset of the first value slot of each non-class attribute. */
    private int[] m_StartAttIndex;
    /** Number of values per attribute (the number of classes for the class attribute). */
    private int[] m_NumAttValues;
    /** Weighted training frequency of each value slot. */
    private double[] m_Frequencies;
    /** Total training weight of instances with a class value. */
    private double m_SumInstances;
    /** Minimum frequency an attribute value needs to act as a super-parent. */
    private int m_Limit = 1;
    /** Debug flag set by -D; not otherwise read by this class. */
    private boolean m_Debug = false;
    /** The m weight used by m-estimation. */
    protected double m_MWeight = 1.0;
    /** Whether Laplace estimation is used instead of m-estimation. */
    private boolean m_Laplace = false;
    /** A value must occur more often than this to be treated as a specialization of another. */
    private int m_Critical = 50;

    /**
     * Returns a description of the classifier for display in the GUI.
     *
     * @return the description, including the bibliographic reference
     */
    public String globalInfo() {
        return "AODEsr augments AODE with Subsumption Resolution.AODEsr detects specializations between two attribute values at classification time and deletes the generalization attribute value.\nFor more information, see:\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the method (Zheng &amp; Webb, ICML 2006).
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Fei Zheng and Geoffrey I. Webb");
        result.setValue(TechnicalInformation.Field.YEAR, "2006");
        result.setValue(TechnicalInformation.Field.TITLE, "Efficient Lazy Elimination for Averaged-One Dependence Estimators");
        result.setValue(TechnicalInformation.Field.PAGES, "1113-1120");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Proceedings of the Twenty-third International Conference on Machine  Learning (ICML 2006)");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "ACM Press");
        result.setValue(TechnicalInformation.Field.ISBN, "1-59593-383-2");
        return result;
    }

    /**
     * Returns the capabilities of this classifier.
     *
     * <p>Supported: nominal attributes, a nominal class, and missing attribute and class values.
     * Zero training instances are accepted.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.setMinimumNumberInstances(0);
        return result;
    }

    /**
     * Builds the count tables from the training data.
     *
     * <p>Instances with a missing class are dropped. Afterwards only the dataset header is
     * retained.
     *
     * @param instances the training data
     * @throws Exception if the data is not supported by {@link #getCapabilities()}
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);
        this.m_Instances = new Instances(instances);
        this.m_Instances.deleteWithMissingClass();
        this.m_SumInstances = 0.0;
        this.m_ClassIndex = instances.classIndex();
        this.m_NumInstances = this.m_Instances.numInstances();
        this.m_NumAttributes = instances.numAttributes();
        this.m_NumClasses = instances.numClasses();
        this.m_StartAttIndex = new int[this.m_NumAttributes];
        this.m_NumAttValues = new int[this.m_NumAttributes];
        this.m_TotalAttValues = 0;
        for (int att = 0; att < this.m_NumAttributes; att++) {
            if (att == this.m_ClassIndex) {
                this.m_NumAttValues[att] = this.m_NumClasses;
            } else {
                this.m_StartAttIndex[att] = this.m_TotalAttValues;
                this.m_NumAttValues[att] = this.m_Instances.attribute(att).numValues();
                // one extra slot per attribute accumulates the weight of missing values
                this.m_TotalAttValues += this.m_NumAttValues[att] + 1;
            }
        }
        this.m_CondiCounts = new double[this.m_NumClasses][this.m_TotalAttValues][this.m_TotalAttValues];
        this.m_ClassCounts = new double[this.m_NumClasses];
        this.m_SumForCounts = new double[this.m_NumClasses][this.m_NumAttributes];
        this.m_Frequencies = new double[this.m_TotalAttValues];
        this.m_CondiCountsNoClass = new double[this.m_TotalAttValues][this.m_TotalAttValues];
        for (int k = 0; k < this.m_NumInstances; k++) {
            this.addToCounts(this.m_Instances.instance(k));
        }
        // keep only the header; the counts hold everything needed for classification
        this.m_Instances = new Instances(this.m_Instances, 0);
    }

    /**
     * Adds one more training instance to the model.
     *
     * <p>Instances with a missing class value are ignored. Otherwise the instance count
     * reported by {@link #toString()} is incremented as well.
     *
     * @param instance the new training instance
     */
    @Override
    public void updateClassifier(Instance instance) {
        if (!instance.classIsMissing()) {
            // addToCounts() skips missing-class instances, so count only the ones it absorbs
            ++this.m_NumInstances;
        }
        this.addToCounts(instance);
    }

    /**
     * Adds the weight of one instance to all count tables.
     *
     * <p>Instances with a missing class are ignored.
     *
     * @param instance the instance to count
     */
    private void addToCounts(Instance instance) {
        if (instance.classIsMissing()) {
            return;
        }
        int classVal = (int) instance.classValue();
        double weight = instance.weight();
        this.m_ClassCounts[classVal] += weight;
        this.m_SumInstances += weight;

        // value slot of every attribute: -1 for the class, the trailing slot for a missing value
        int[] attIndex = new int[this.m_NumAttributes];
        for (int att = 0; att < this.m_NumAttributes; att++) {
            if (att == this.m_ClassIndex) {
                attIndex[att] = -1;
            } else if (instance.isMissing(att)) {
                attIndex[att] = this.m_StartAttIndex[att] + this.m_NumAttValues[att];
            } else {
                attIndex[att] = this.m_StartAttIndex[att] + (int) instance.value(att);
            }
        }

        for (int att1 = 0; att1 < this.m_NumAttributes; att1++) {
            int slot1 = attIndex[att1];
            if (slot1 == -1) {
                continue;
            }
            this.m_Frequencies[slot1] += weight;
            if (!instance.isMissing(att1)) {
                this.m_SumForCounts[classVal][att1] += weight;
            }
            double[] countsPointer = this.m_CondiCounts[classVal][slot1];
            double[] countsNoClassPointer = this.m_CondiCountsNoClass[slot1];
            for (int att2 = 0; att2 < this.m_NumAttributes; att2++) {
                int slot2 = attIndex[att2];
                if (slot2 != -1) {
                    countsPointer[slot2] += weight;
                    countsNoClassPointer[slot2] += weight;
                }
            }
        }
    }

    /**
     * Computes the class membership probabilities for an instance.
     *
     * <p>The computation has two stages:
     * <ol>
     * <li><b>Subsumption resolution.</b> An attribute value i is marked as the generalization
     * of another value j when all of the following hold:
     * <ul>
     * <li>j's training frequency exceeds the critical value;</li>
     * <li>j never occurs without i (the joint count equals j's count).</li>
     * </ul>
     * If i and j are mutual specializations, only the value with the larger attribute index
     * is marked, so that one of them is always kept.</li>
     * <li><b>Super-parent averaging.</b> Every unmarked, present value whose frequency is at
     * least the frequency limit acts as a super-parent. The per-class products of smoothed
     * estimates are averaged over all super-parents. Without any parent, the naive Bayes
     * estimate is used.</li>
     * </ol>
     * The result is normalized.
     *
     * @param instance the instance to classify
     * @return the normalized class distribution
     * @throws Exception if the distribution cannot be computed
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] probs = new double[this.m_NumClasses];

        // value slot of every attribute; -1 marks the class attribute and missing values
        int[] attIndex = new int[this.m_NumAttributes];
        for (int att = 0; att < this.m_NumAttributes; att++) {
            attIndex[att] = instance.isMissing(att) || att == this.m_ClassIndex
                    ? -1
                    : this.m_StartAttIndex[att] + (int) instance.value(att);
        }

        // specializedBy[i] = j means value i generalizes value j and is therefore dropped
        int[] specializedBy = new int[this.m_NumAttributes];
        for (int i = 0; i < this.m_NumAttributes; i++) {
            specializedBy[i] = -1;
        }
        for (int i = 0; i < this.m_NumAttributes; i++) {
            if (attIndex[i] == -1) {
                continue;
            }
            double[] countsForAtti = this.m_CondiCountsNoClass[attIndex[i]];
            for (int j = 0; j < this.m_NumAttributes; j++) {
                // skip missing/class values, i itself, and values already generalizing i
                if (attIndex[j] == -1 || i == j || specializedBy[j] == i) {
                    continue;
                }
                double freqJ = this.m_CondiCountsNoClass[attIndex[j]][attIndex[j]];
                if (!(freqJ > (double) this.m_Critical)) {
                    continue;
                }
                // j must never occur without i
                if (freqJ != countsForAtti[attIndex[j]]) {
                    continue;
                }
                // mutual specializations: keep the lower-indexed value
                if (freqJ == countsForAtti[attIndex[i]] && i < j) {
                    continue;
                }
                specializedBy[i] = j;
                break;
            }
        }

        for (int classVal = 0; classVal < this.m_NumClasses; classVal++) {
            int parentCount = 0;
            double[][] countsForClass = this.m_CondiCounts[classVal];
            for (int parent = 0; parent < this.m_NumAttributes; parent++) {
                int pIndex = attIndex[parent];
                if (pIndex == -1 || this.m_Frequencies[pIndex] < (double) this.m_Limit
                        || specializedBy[parent] != -1) {
                    continue;
                }
                double[] countsForClassParent = countsForClass[pIndex];
                // block the parent from being its own child; restored below
                attIndex[parent] = -1;
                ++parentCount;
                double classParentFreq = countsForClassParent[pIndex];
                double missingForParentAtt = this.m_Frequencies[this.m_StartAttIndex[parent] + this.m_NumAttValues[parent]];
                // estimate of P(parent value, class)
                double x = this.smoothedEstimate(classParentFreq, this.m_SumInstances - missingForParentAtt,
                        this.m_NumClasses * this.m_NumAttValues[parent]);
                for (int child = 0; child < this.m_NumAttributes; child++) {
                    if (attIndex[child] == -1 || specializedBy[child] != -1) {
                        continue;
                    }
                    double missingForParentAndChildAtt =
                            countsForClassParent[this.m_StartAttIndex[child] + this.m_NumAttValues[child]];
                    // estimate of P(child value | parent value, class)
                    x *= this.smoothedEstimate(countsForClassParent[attIndex[child]],
                            classParentFreq - missingForParentAndChildAtt, this.m_NumAttValues[child]);
                }
                probs[classVal] += x;
                attIndex[parent] = pIndex;
            }
            if (parentCount < 1) {
                // no usable super-parent: fall back to naive Bayes
                probs[classVal] = this.NBconditionalProb(instance, classVal);
            } else {
                probs[classVal] /= (double) parentCount;
            }
        }
        Utils.normalize(probs);
        return probs;
    }

    /**
     * Computes the unnormalized naive Bayes score of a class for an instance.
     *
     * <p>The score is the smoothed class prior times the smoothed estimate for every present,
     * non-class attribute value.
     *
     * @param instance the instance
     * @param classVal the class index
     * @return the unnormalized score
     * @throws Exception declared for compatibility; not thrown by this implementation
     */
    public double NBconditionalProb(Instance instance, int classVal) throws Exception {
        double prob = this.smoothedEstimate(this.m_ClassCounts[classVal], this.m_SumInstances, this.m_NumClasses);
        double[][] pointer = this.m_CondiCounts[classVal];
        for (int att = 0; att < this.m_NumAttributes; att++) {
            if (att == this.m_ClassIndex || instance.isMissing(att)) {
                continue;
            }
            int slot = this.m_StartAttIndex[att] + (int) instance.value(att);
            // the diagonal entry holds the count of (value, class)
            prob *= this.smoothedEstimate(pointer[slot][slot], this.m_SumForCounts[classVal][att], this.m_NumAttValues[att]);
        }
        return prob;
    }

    /**
     * Applies the configured smoothing.
     *
     * <p>Uses {@link #LaplaceEstimate} when Laplace estimation is enabled and
     * {@link #MEstimate} otherwise.
     */
    private double smoothedEstimate(double frequency, double total, double numValues) {
        return this.m_Laplace
                ? this.LaplaceEstimate(frequency, total, numValues)
                : this.MEstimate(frequency, total, numValues);
    }

    /**
     * Returns the m-estimate (frequency + m / numValues) / (total + m).
     *
     * @param frequency the observed count
     * @param total the total count
     * @param numValues the number of possible values
     * @return the m-estimate
     */
    public double MEstimate(double frequency, double total, double numValues) {
        return (frequency + this.m_MWeight / numValues) / (total + this.m_MWeight);
    }

    /**
     * Returns the Laplace estimate (frequency + 1) / (total + numValues).
     *
     * @param frequency the observed count
     * @param total the total count
     * @param numValues the number of possible values
     * @return the Laplace estimate
     */
    public double LaplaceEstimate(double frequency, double total, double numValues) {
        return (frequency + 1.0) / (total + numValues);
    }

    /**
     * Lists the available command-line options.
     *
     * @return an enumeration of {@link Option}s
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>(5);
        // The third Option argument is the number of arguments the flag takes (not a position):
        // -D and -L are plain flags, while -C, -F and -M each take one value.
        newVector.addElement(new Option("\tOutput debugging information\n", "D", 0, "-D"));
        newVector.addElement(new Option("\tImpose a critcal value for specialization-generalization relationship\n\t(default is 50)", "C", 1, "-C"));
        newVector.addElement(new Option("\tImpose a frequency limit for superParents\n\t(default is 1)", "F", 1, "-F"));
        newVector.addElement(new Option("\tUsing Laplace estimation\n\t(default is m-esimation (m=1))", "L", 0, "-L"));
        newVector.addElement(new Option("\tWeight value for m-estimation\n\t(default is 1.0)", "M", 1, "-M"));
        return newVector.elements();
    }

    /**
     * Parses the options; every option that is not given is reset to its default.
     *
     * @param options the options to parse
     * @throws Exception if -M is combined with -L, a numeric value cannot be parsed,
     *     or unknown options remain
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.m_Debug = Utils.getFlag('D', options);
        String critical = Utils.getOption('C', options);
        this.m_Critical = critical.length() != 0 ? Integer.parseInt(critical) : 50;
        String freq = Utils.getOption('F', options);
        this.m_Limit = freq.length() != 0 ? Integer.parseInt(freq) : 1;
        this.m_Laplace = Utils.getFlag('L', options);
        String mWeight = Utils.getOption('M', options);
        if (mWeight.length() != 0) {
            if (this.m_Laplace) {
                throw new Exception("weight for m-estimate is pointless if using laplace estimation!");
            }
            this.m_MWeight = Double.parseDouble(mWeight);
        } else {
            this.m_MWeight = 1.0;
        }
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings as command-line options.
     *
     * @return the options
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (this.m_Debug) {
            result.add("-D");
        }
        result.add("-F");
        result.add("" + this.m_Limit);
        if (this.m_Laplace) {
            result.add("-L");
        } else {
            result.add("-M");
            result.add("" + this.m_MWeight);
        }
        result.add("-C");
        result.add("" + this.m_Critical);
        return result.toArray(new String[result.size()]);
    }

    /** @return the tip text for the m-estimate weight property */
    public String mestWeightTipText() {
        return "Set the weight for m-estimate.";
    }

    /**
     * Sets the m-estimate weight.
     *
     * <p>A message is printed and the call is ignored when Laplace estimation is active or
     * the weight is not positive.
     *
     * @param w the new weight
     */
    public void setMestWeight(double w) {
        if (this.getUseLaplace()) {
            System.out.println("Weight is only used in conjunction with m-estimate - ignored!");
        } else if (w > 0.0) {
            this.m_MWeight = w;
        } else {
            System.out.println("M-Estimate Weight must be greater than 0!");
        }
    }

    /** @return the m-estimate weight */
    public double getMestWeight() {
        return this.m_MWeight;
    }

    /** @return the tip text for the Laplace property */
    public String useLaplaceTipText() {
        return "Use Laplace correction instead of m-estimation.";
    }

    /** @return whether Laplace estimation is used */
    public boolean getUseLaplace() {
        return this.m_Laplace;
    }

    /** @param value whether to use Laplace estimation instead of m-estimation */
    public void setUseLaplace(boolean value) {
        this.m_Laplace = value;
    }

    /** @return the tip text for the frequency-limit property */
    public String frequencyLimitTipText() {
        return "Attributes with a frequency in the train set below this value aren't used as parents.";
    }

    /** @param f the minimum frequency for an attribute value to act as a super-parent */
    public void setFrequencyLimit(int f) {
        this.m_Limit = f;
    }

    /** @return the minimum frequency for an attribute value to act as a super-parent */
    public int getFrequencyLimit() {
        return this.m_Limit;
    }

    /** @return the tip text for the critical-value property */
    public String criticalValueTipText() {
        return "Specify critical value for specialization-generalization relationship (default 50).";
    }

    /** @param c the critical value for the specialization-generalization relationship */
    public void setCriticalValue(int c) {
        this.m_Critical = c;
    }

    /** @return the critical value for the specialization-generalization relationship */
    public int getCriticalValue() {
        return this.m_Critical;
    }

    /**
     * Describes the model.
     *
     * <p>The output contains class priors, dataset statistics and settings. The priors
     * shown here are always Laplace-smoothed, whatever estimation mode is configured.
     *
     * @return a textual description of the model
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("The AODEsr Classifier");
        if (this.m_Instances == null) {
            text.append(": No model built yet.");
            return text.toString();
        }
        try {
            for (int i = 0; i < this.m_NumClasses; i++) {
                text.append("\nClass " + this.m_Instances.classAttribute().value(i) + ": Prior probability = " + Utils.doubleToString((this.m_ClassCounts[i] + 1.0) / (this.m_SumInstances + (double) this.m_NumClasses), 4, 2) + "\n\n");
            }
            text.append("Dataset: " + this.m_Instances.relationName() + "\n" + "Instances: " + this.m_NumInstances + "\n" + "Attributes: " + this.m_NumAttributes + "\n" + "Frequency limit for superParents: " + this.m_Limit + "\n" + "Critical value for the specializtion-generalization " + "relationship: " + this.m_Critical + "\n");
            if (this.m_Laplace) {
                text.append("Using LapLace estimation.");
            } else {
                text.append("Using m-estimation, m = " + this.m_MWeight);
            }
        }
        catch (Exception ex) {
            text.append(ex.getMessage());
        }
        return text.toString();
    }

    /** @return the revision string */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5516 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        AODEsr.runClassifier(new AODEsr(), argv);
    }
}

