package weka.classifiers.meta;

import de.metanome.algorithm_integration.ColumnIdentifier;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.classifiers.Sourcable;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.DecisionStump;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/* loaded from: input_file:weka/classifiers/meta/AdaBoostM1.class */
/**
 * Boosts a nominal-class base classifier with the AdaBoost.M1 method.
 *
 * <p>Each iteration trains a copy of the base classifier either directly on weighted data
 * (when the base learner handles instance weights and resampling is off) or on a sample drawn
 * according to the current weights. Misclassified instances are then up-weighted. Predictions
 * combine the base classifiers by a vote weighted with {@code log((1 - err) / err)}.
 */
public class AdaBoostM1 extends RandomizableIteratedSingleClassifierEnhancer implements WeightedInstancesHandler, Sourcable, TechnicalInformationHandler {
    static final long serialVersionUID = -7378107808933117974L;

    /** How often a resample is redrawn while the resulting model has zero training error. */
    private static final int MAX_NUM_RESAMPLING_ITERATIONS = 10;

    /** Vote weight {@code log((1 - err) / err)} of each built base classifier. */
    protected double[] m_Betas;

    /** Number of base classifiers actually built and used for voting. */
    protected int m_NumIterationsPerformed;

    /** Percentage of the weight mass used to train each base classifier (100 = all data). */
    protected int m_WeightThreshold = 100;

    /** Whether to resample by weight instead of passing weights to the base learner. */
    protected boolean m_UseResampling;

    /** Number of class values seen in the training data. */
    protected int m_NumClasses;

    /** Fallback model used when the data holds only the class attribute; null otherwise. */
    protected Classifier m_ZeroR;

    /** Creates a booster that uses a DecisionStump as its default base classifier. */
    public AdaBoostM1() {
        this.m_Classifier = new DecisionStump();
    }

    /**
     * Returns a description of this classifier for the GUI.
     *
     * @return the description, including the bibliographic reference
     */
    public String globalInfo() {
        return "Class for boosting a nominal class classifier using the Adaboost M1 method. Only nominal class problems can be tackled. Often dramatically improves performance, but sometimes overfits.\n\nFor more information, see\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for AdaBoost.M1.
     *
     * @return the technical information about the original paper
     */
    @Override // weka.core.TechnicalInformationHandler
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        info.setValue(TechnicalInformation.Field.AUTHOR, "Yoav Freund and Robert E. Schapire");
        info.setValue(TechnicalInformation.Field.TITLE, "Experiments with a new boosting algorithm");
        info.setValue(TechnicalInformation.Field.BOOKTITLE, "Thirteenth International Conference on Machine Learning");
        info.setValue(TechnicalInformation.Field.YEAR, "1996");
        info.setValue(TechnicalInformation.Field.PAGES, "148-156");
        info.setValue(TechnicalInformation.Field.PUBLISHER, "Morgan Kaufmann");
        info.setValue(TechnicalInformation.Field.ADDRESS, "San Francisco");
        return info;
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return the fully qualified DecisionStump class name
     */
    @Override // weka.classifiers.SingleClassifierEnhancer
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    /**
     * Selects the heaviest instances until their combined weight exceeds the given fraction of
     * the total weight. Instances whose weight ties with the last selected one are also taken,
     * so ties are never split.
     *
     * @param data the weighted instances to select from
     * @param quantile fraction of the total weight mass to keep, e.g. 0.9
     * @return a new dataset holding copies of the selected instances, heaviest first
     */
    protected Instances selectWeightQuantile(Instances data, double quantile) {
        int numInstances = data.numInstances();
        Instances selected = new Instances(data, numInstances);
        double[] weights = new double[numInstances];
        double totalWeight = 0.0d;
        for (int i = 0; i < numInstances; i++) {
            weights[i] = data.instance(i).weight();
            totalWeight += weights[i];
        }
        double weightMassToSelect = totalWeight * quantile;
        // Utils.sort yields indices in ascending weight order; walk from the end (heaviest).
        int[] sortedIndices = Utils.sort(weights);
        double selectedWeight = 0.0d;
        for (int i = numInstances - 1; i >= 0; i--) {
            int index = sortedIndices[i];
            selected.add((Instance) data.instance(index).copy());
            selectedWeight += weights[index];
            // Stop once enough mass is selected, but never between two equally weighted instances.
            if (selectedWeight > weightMassToSelect && i > 0 && weights[index] != weights[sortedIndices[i - 1]]) {
                break;
            }
        }
        if (this.m_Debug) {
            System.err.println("Selected " + selected.numInstances() + " out of " + numInstances);
        }
        return selected;
    }

    /**
     * Lists the options of this classifier, followed by the superclass options.
     *
     * @return an enumeration of Option objects
     */
    @Override // weka.classifiers.RandomizableIteratedSingleClassifierEnhancer, weka.classifiers.IteratedSingleClassifierEnhancer, weka.classifiers.SingleClassifierEnhancer, weka.classifiers.Classifier, weka.core.OptionHandler
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tPercentage of weight mass to base training on.\n\t(default 100, reduce to around 90 speed up)", "P", 1, "-P <num>"));
        options.addElement(new Option("\tUse resampling for boosting.", "Q", 0, "-Q"));
        Enumeration parentOptions = super.listOptions();
        while (parentOptions.hasMoreElements()) {
            options.addElement((Option) parentOptions.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses the options of this classifier, then hands the rest to the superclass.
     *
     * <p>{@code -P <num>} sets the weight threshold (100 when absent). {@code -Q} enables
     * resampling.
     *
     * @param options the option array; recognised entries are consumed
     * @throws Exception if an option is invalid, e.g. a non-integer {@code -P} value
     */
    @Override // weka.classifiers.RandomizableIteratedSingleClassifierEnhancer, weka.classifiers.IteratedSingleClassifierEnhancer, weka.classifiers.SingleClassifierEnhancer, weka.classifiers.Classifier, weka.core.OptionHandler
    public void setOptions(String[] options) throws Exception {
        String threshold = Utils.getOption('P', options);
        setWeightThreshold(threshold.length() != 0 ? Integer.parseInt(threshold) : 100);
        setUseResampling(Utils.getFlag('Q', options));
        super.setOptions(options);
    }

    /**
     * Returns the current settings of this classifier, followed by the superclass settings.
     *
     * @return the option strings
     */
    @Override // weka.classifiers.RandomizableIteratedSingleClassifierEnhancer, weka.classifiers.IteratedSingleClassifierEnhancer, weka.classifiers.SingleClassifierEnhancer, weka.classifiers.Classifier, weka.core.OptionHandler
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        if (getUseResampling()) {
            options.add("-Q");
        }
        options.add("-P");
        options.add("" + getWeightThreshold());
        for (String option : super.getOptions()) {
            options.add(option);
        }
        return options.toArray(new String[options.size()]);
    }

    /**
     * @return the GUI tip text for the weight threshold property
     */
    public String weightThresholdTipText() {
        return "Weight threshold for weight pruning.";
    }

    /**
     * @param threshold percentage of the weight mass to train each base classifier on
     */
    public void setWeightThreshold(int threshold) {
        this.m_WeightThreshold = threshold;
    }

    /**
     * @return percentage of the weight mass each base classifier is trained on
     */
    public int getWeightThreshold() {
        return this.m_WeightThreshold;
    }

    /**
     * @return the GUI tip text for the resampling property
     */
    public String useResamplingTipText() {
        return "Whether resampling is used instead of reweighting.";
    }

    /**
     * @param resampling true to train on weight-based resamples instead of weighted data
     */
    public void setUseResampling(boolean resampling) {
        this.m_UseResampling = resampling;
    }

    /**
     * @return whether weight-based resampling is used instead of reweighting
     */
    public boolean getUseResampling() {
        return this.m_UseResampling;
    }

    /**
     * Returns the base learner's capabilities, restricted to the nominal/binary class types
     * it supports.
     *
     * @return the capabilities of this classifier
     */
    @Override // weka.classifiers.SingleClassifierEnhancer, weka.classifiers.Classifier, weka.core.CapabilitiesHandler
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        // Query the unrestricted capabilities before class support is cleared below.
        boolean handlesNominal = capabilities.handles(Capabilities.Capability.NOMINAL_CLASS);
        boolean handlesBinary = capabilities.handles(Capabilities.Capability.BINARY_CLASS);
        capabilities.disableAllClasses();
        capabilities.disableAllClassDependencies();
        if (handlesNominal) {
            capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        }
        if (handlesBinary) {
            capabilities.enable(Capabilities.Capability.BINARY_CLASS);
        }
        return capabilities;
    }

    /**
     * Builds the boosted model.
     *
     * <p>Instances with a missing class are dropped first. If only the class attribute remains,
     * a ZeroR model is built instead. Otherwise resampling is used when requested or when the
     * base learner does not handle instance weights; reweighting is used in all other cases.
     *
     * @param data the training data; it is not modified
     * @throws Exception if the data is unsuitable or a base classifier fails to build
     */
    @Override // weka.classifiers.IteratedSingleClassifierEnhancer, weka.classifiers.Classifier
    public void buildClassifier(Instances data) throws Exception {
        super.buildClassifier(data);
        getCapabilities().testWithFail(data);
        Instances training = new Instances(data);
        training.deleteWithMissingClass();
        if (training.numAttributes() == 1) {
            System.err.println("Cannot build model (only class attribute present in data!), using ZeroR model instead!");
            this.m_ZeroR = new ZeroR();
            this.m_ZeroR.buildClassifier(training);
            return;
        }
        this.m_ZeroR = null;
        this.m_NumClasses = training.numClasses();
        if (this.m_UseResampling || !(this.m_Classifier instanceof WeightedInstancesHandler)) {
            buildClassifierUsingResampling(training);
        } else {
            buildClassifierWithWeights(training);
        }
    }

    /**
     * Boosts by training each base classifier on a sample drawn according to the current
     * instance weights.
     *
     * @param data the training data (missing-class instances already removed)
     * @throws Exception if a base classifier cannot be built or evaluated
     */
    protected void buildClassifierUsingResampling(Instances data) throws Exception {
        int numInstances = data.numInstances();
        Random random = new Random(this.m_Seed);
        this.m_Betas = new double[this.m_Classifiers.length];
        this.m_NumIterationsPerformed = 0;

        // Work on a copy so reweighting does not touch the caller's instances; normalise to sum 1.
        Instances training = new Instances(data, 0, numInstances);
        double totalWeight = training.sumOfWeights();
        for (int i = 0; i < training.numInstances(); i++) {
            training.instance(i).setWeight(training.instance(i).weight() / totalWeight);
        }

        while (this.m_NumIterationsPerformed < this.m_Classifiers.length) {
            if (this.m_Debug) {
                System.err.println("Training classifier " + (this.m_NumIterationsPerformed + 1));
            }
            Instances trainData = this.m_WeightThreshold < 100
                    ? selectWeightQuantile(training, this.m_WeightThreshold / 100.0d)
                    : new Instances(training);
            double[] weights = new double[trainData.numInstances()];
            for (int i = 0; i < weights.length; i++) {
                weights[i] = trainData.instance(i).weight();
            }

            Classifier current = this.m_Classifiers[this.m_NumIterationsPerformed];
            double errorRate;
            int attempts = 0;
            // Redraw the sample while the model has zero error on the training set, up to a limit.
            do {
                current.buildClassifier(trainData.resampleWithWeights(random, weights));
                Evaluation evaluation = new Evaluation(data);
                evaluation.evaluateModel(current, training, new Object[0]);
                errorRate = evaluation.errorRate();
                attempts++;
            } while (Utils.eq(errorRate, 0.0d) && attempts < MAX_NUM_RESAMPLING_ITERATIONS);

            if (stopBoosting(errorRate)) {
                return;
            }
            finishIteration(training, errorRate);
        }
    }

    /**
     * Multiplies the weight of every instance misclassified by the current base classifier by
     * {@code reweight}, then rescales all weights so their total is unchanged.
     *
     * <p>Kept public as emitted by the decompiler (its note says it was protected originally).
     *
     * @param training the instances whose weights are updated in place
     * @param reweight factor applied to misclassified instances
     * @throws Exception if the current classifier cannot classify an instance
     */
    public void setWeights(Instances training, double reweight) throws Exception {
        double weightBefore = training.sumOfWeights();
        Enumeration instancesToUpdate = training.enumerateInstances();
        while (instancesToUpdate.hasMoreElements()) {
            Instance instance = (Instance) instancesToUpdate.nextElement();
            if (!Utils.eq(this.m_Classifiers[this.m_NumIterationsPerformed].classifyInstance(instance), instance.classValue())) {
                instance.setWeight(instance.weight() * reweight);
            }
        }
        double weightAfter = training.sumOfWeights();
        Enumeration instancesToRescale = training.enumerateInstances();
        while (instancesToRescale.hasMoreElements()) {
            Instance instance = (Instance) instancesToRescale.nextElement();
            instance.setWeight((instance.weight() * weightBefore) / weightAfter);
        }
    }

    /**
     * Boosts by passing the current instance weights directly to the base learner.
     *
     * @param data the training data (missing-class instances already removed)
     * @throws Exception if a base classifier cannot be built or evaluated
     */
    protected void buildClassifierWithWeights(Instances data) throws Exception {
        int numInstances = data.numInstances();
        Random random = new Random(this.m_Seed);
        this.m_Betas = new double[this.m_Classifiers.length];
        this.m_NumIterationsPerformed = 0;
        // Work on a copy so reweighting does not touch the caller's instances.
        Instances training = new Instances(data, 0, numInstances);

        while (this.m_NumIterationsPerformed < this.m_Classifiers.length) {
            if (this.m_Debug) {
                System.err.println("Training classifier " + (this.m_NumIterationsPerformed + 1));
            }
            Instances trainData = this.m_WeightThreshold < 100
                    ? selectWeightQuantile(training, this.m_WeightThreshold / 100.0d)
                    : new Instances(training, 0, numInstances);

            Classifier current = this.m_Classifiers[this.m_NumIterationsPerformed];
            if (current instanceof Randomizable) {
                ((Randomizable) current).setSeed(random.nextInt());
            }
            current.buildClassifier(trainData);
            Evaluation evaluation = new Evaluation(data);
            evaluation.evaluateModel(current, training, new Object[0]);
            double errorRate = evaluation.errorRate();

            if (stopBoosting(errorRate)) {
                return;
            }
            finishIteration(training, errorRate);
        }
    }

    /**
     * Decides whether boosting must stop after evaluating the current base classifier.
     *
     * <p>Boosting stops when the error is at least 0.5, because the vote weight
     * {@code log((1 - err) / err)} would be non-positive. It also stops when the error is 0,
     * because the weight update factor would be infinite. If this happens on the first
     * iteration, that classifier is still kept so that a model exists.
     *
     * @param errorRate the current classifier's error on the weighted training data
     * @return true if no further iterations should be performed
     */
    private boolean stopBoosting(double errorRate) {
        if (Utils.grOrEq(errorRate, 0.5d) || Utils.eq(errorRate, 0.0d)) {
            if (this.m_NumIterationsPerformed == 0) {
                this.m_NumIterationsPerformed = 1;
            }
            return true;
        }
        return false;
    }

    /**
     * Completes one boosting iteration.
     *
     * <p>Stores the current classifier's vote weight, reweights the training data and advances
     * the iteration counter.
     *
     * @param training the weighted training data, updated in place
     * @param errorRate the current classifier's error, strictly between 0 and 0.5
     * @throws Exception if reweighting fails
     */
    private void finishIteration(Instances training, double errorRate) throws Exception {
        double reweight = (1.0d - errorRate) / errorRate;
        this.m_Betas[this.m_NumIterationsPerformed] = Math.log(reweight);
        if (this.m_Debug) {
            System.err.println("\terror rate = " + errorRate + "  beta = " + this.m_Betas[this.m_NumIterationsPerformed]);
        }
        setWeights(training, reweight);
        this.m_NumIterationsPerformed++;
    }

    /**
     * Computes the class distribution for an instance.
     *
     * <p>The result comes from the beta-weighted votes of the base classifiers, or from the
     * single or ZeroR model when boosting was not possible.
     *
     * @param instance the instance to classify
     * @return the class membership probabilities
     * @throws Exception if no model has been built
     */
    @Override // weka.classifiers.Classifier
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_ZeroR != null) {
            return this.m_ZeroR.distributionForInstance(instance);
        }
        if (this.m_NumIterationsPerformed == 0) {
            throw new Exception("No model built");
        }
        if (this.m_NumIterationsPerformed == 1) {
            return this.m_Classifiers[0].distributionForInstance(instance);
        }
        double[] votes = new double[instance.numClasses()];
        for (int i = 0; i < this.m_NumIterationsPerformed; i++) {
            double predicted = this.m_Classifiers[i].classifyInstance(instance);
            // Weka marks "no prediction" with its missing value (NaN). Casting NaN to int gives 0,
            // which would silently credit class 0, so such votes are skipped.
            if (Double.isNaN(predicted)) {
                continue;
            }
            votes[(int) predicted] += this.m_Betas[i];
        }
        return Utils.logs2probs(votes);
    }

    /**
     * Generates Java source for the built model.
     *
     * <p>The output is a class with a static {@code classify} method, followed by the source of
     * every base classifier it references.
     *
     * @param className name of the generated class; base classifiers become
     *     {@code className_0}, {@code className_1}, ...
     * @return the generated source code
     * @throws Exception if no model is built or the base learner is not Sourcable
     */
    @Override // weka.classifiers.Sourcable
    public String toSource(String className) throws Exception {
        // distributionForInstance() defers to the ZeroR fallback, so the generated code must too.
        if (this.m_ZeroR instanceof Sourcable) {
            return ((Sourcable) this.m_ZeroR).toSource(className);
        }
        if (this.m_NumIterationsPerformed == 0) {
            throw new Exception("No model built yet");
        }
        if (!(this.m_Classifiers[0] instanceof Sourcable)) {
            throw new Exception("Base learner " + this.m_Classifier.getClass().getName() + " is not Sourcable");
        }
        StringBuilder text = new StringBuilder("class ");
        text.append(className).append(" {\n\n");
        text.append("  public static double classify(Object[] i) {\n");
        if (this.m_NumIterationsPerformed == 1) {
            text.append("    return " + className + "_0.classify(i);\n");
        } else {
            text.append("    double [] sums = new double [" + this.m_NumClasses + "];\n");
            for (int i = 0; i < this.m_NumIterationsPerformed; i++) {
                text.append("    sums[(int) " + className + '_' + i + ".classify(i)] += " + this.m_Betas[i] + ";\n");
            }
            text.append("    double maxV = sums[0];\n    int maxI = 0;\n    for (int j = 1; j < " + this.m_NumClasses + "; j++) {\n      if (sums[j] > maxV) { maxV = sums[j]; maxI = j; }\n    }\n    return (double) maxI;\n");
        }
        text.append("  }\n}\n");
        // Only the classifiers that were actually built are referenced above; later slots of
        // m_Classifiers may be untrained when boosting stopped early.
        for (int i = 0; i < this.m_NumIterationsPerformed; i++) {
            text.append(((Sourcable) this.m_Classifiers[i]).toSource(className + '_' + i));
        }
        return text.toString();
    }

    /**
     * Returns a textual description of the model: the ZeroR fallback, the single classifier,
     * or every boosted classifier together with its vote weight.
     *
     * @return the model description
     */
    @Override
    public String toString() {
        if (this.m_ZeroR != null) {
            String simpleName = getClass().getName().replaceAll(".*\\.", "");
            StringBuilder text = new StringBuilder();
            text.append(simpleName + "\n");
            // Underline the name: the regex "." matches every character.
            text.append(simpleName.replaceAll(".", "=") + "\n\n");
            text.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
            text.append(this.m_ZeroR.toString());
            return text.toString();
        }
        StringBuilder text = new StringBuilder();
        if (this.m_NumIterationsPerformed == 0) {
            text.append("AdaBoostM1: No model built yet.\n");
        } else if (this.m_NumIterationsPerformed == 1) {
            text.append("AdaBoostM1: No boosting possible, one classifier used!\n");
            text.append(this.m_Classifiers[0].toString() + "\n");
        } else {
            text.append("AdaBoostM1: Base classifiers and their weights: \n\n");
            for (int i = 0; i < this.m_NumIterationsPerformed; i++) {
                text.append(this.m_Classifiers[i].toString() + "\n\n");
                text.append("Weight: " + Utils.roundDouble(this.m_Betas[i], 2) + "\n\n");
            }
            text.append("Number of performed Iterations: " + this.m_NumIterationsPerformed + "\n");
        }
        return text.toString();
    }

    /**
     * @return the revision string of this class
     */
    @Override // weka.classifiers.Classifier, weka.core.RevisionHandler
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.40 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        runClassifier(new AdaBoostM1(), args);
    }
}
