package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.Iterator;
import java.util.Set;

/**
 * Factory that trains {@link NaiveBayesClassifier}s from integer-valued feature data.
 *
 * <p>The training regime is selected by the {@code kind} constructor argument:
 * <ul>
 *   <li>{@link #JL}: closed-form estimation from counts with add-alpha smoothing, controlled by
 *       {@code alphaClass} (class priors) and {@code alphaFeature} (per-class feature values).</li>
 *   <li>{@link #CL}: numerical optimization of a {@code LogConditionalEqConstraintFunction} with
 *       {@link QNMinimizer}.</li>
 *   <li>{@link #UCL}: numerical optimization of a {@code LogConditionalObjectiveFunction} over
 *       one-hot (feature, value) indicators plus an always-on bias indicator, with
 *       {@link QNMinimizer}.</li>
 * </ul>
 * For {@code CL} and {@code UCL}, {@code sigma} and the prior type are handed to the objective
 * function and the smoothing constants are not used.
 *
 * <p>Internally, training data is dense: {@code data[i][j]} is the non-negative integer value of
 * feature {@code j} in datum {@code i}, and feature {@code j} is treated as taking values in
 * {@code 0 .. max_i data[i][j]}.
 */
public class NaiveBayesClassifierFactory<L, F> implements ClassifierFactory<L, F, NaiveBayesClassifier<L, F>> {
    private static final Redwood.RedwoodChannels logger = Redwood.channels(NaiveBayesClassifierFactory.class);
    private static final long serialVersionUID = -8164165428834534041L;

    /** Training kind: closed-form, add-alpha smoothed counting. */
    public static final int JL = 0;
    /** Training kind: optimize a {@code LogConditionalEqConstraintFunction}. */
    public static final int CL = 1;
    /** Training kind: optimize a {@code LogConditionalObjectiveFunction} over one-hot value indicators. */
    public static final int UCL = 2;

    /** One of {@link #JL}, {@link #CL}, {@link #UCL}. */
    private final int kind;
    /** Add-alpha smoothing constant for the class priors (JL only). */
    private final double alphaClass;
    /** Add-alpha smoothing constant for per-class feature-value estimates (JL only). */
    private final double alphaFeature;
    /** Prior parameter passed to the CL/UCL objective functions. */
    private final double sigma;
    /** Ordinal of a {@link LogPrior.LogPriorType}, passed to the CL/UCL objective functions. */
    private final int prior;
    // Indices built by the most recent call to trainClassifier(GeneralDataset, Set).
    private Index<L> labelIndex;
    private Index<F> featureIndex;

    /**
     * Trained parameters on a log scale: {@code priors[c]} is the score of class {@code c}, and
     * {@code weights[c][f][v]} is the score of feature {@code f} taking value {@code v} under
     * class {@code c}.
     */
    public static class NBWeights {
        double[] priors;
        double[][][] weights;

        NBWeights(double[] priors, double[][][] weights) {
            this.priors = priors;
            this.weights = weights;
        }

        /**
         * Unpacks the flat weight matrix produced by the UCL training path.
         *
         * @param wts {@code wts[k][c]}: row 0 holds the per-class bias (copied into the priors);
         *     row {@code offset(f) + v + 1} holds the weights for feature {@code f} taking value
         *     {@code v}, where {@code offset(f)} is the sum of {@code numValues[0..f-1]}
         * @param numValues number of distinct values of each feature
         */
        NBWeights(double[][] wts, int[] numValues) {
            int numClasses = wts[0].length;
            int numFeatures = numValues.length;
            this.priors = new double[numClasses];
            System.arraycopy(wts[0], 0, this.priors, 0, numClasses);
            this.weights = new double[numClasses][numFeatures][];
            int offset = 0;
            for (int f = 0; f < numFeatures; f++) {
                for (int c = 0; c < numClasses; c++) {
                    this.weights[c][f] = new double[numValues[f]];
                }
                for (int v = 0; v < numValues[f]; v++) {
                    // +1 skips the bias row at index 0.
                    double[] row = wts[offset + v + 1];
                    for (int c = 0; c < numClasses; c++) {
                        this.weights[c][f][v] = row[c];
                    }
                }
                offset += numValues[f];
            }
        }
    }

    /** Creates a {@link #JL} factory with zero smoothing, zero sigma and a {@code NULL} prior. */
    public NaiveBayesClassifierFactory() {
        this(0.0, 0.0, 0.0, LogPrior.LogPriorType.NULL.ordinal(), JL);
    }

    /**
     * Creates a factory with explicit settings.
     *
     * @param alphaClass add-alpha smoothing constant for the class priors (JL only)
     * @param alphaFeature add-alpha smoothing constant for feature values (JL only)
     * @param sigma prior parameter for the CL/UCL objective functions
     * @param priorType ordinal of a {@link LogPrior.LogPriorType}, for CL/UCL
     * @param kind one of {@link #JL}, {@link #CL}, {@link #UCL}
     */
    public NaiveBayesClassifierFactory(double alphaClass, double alphaFeature, double sigma, int priorType, int kind) {
        this.alphaClass = alphaClass;
        this.alphaFeature = alphaFeature;
        this.sigma = sigma;
        this.prior = priorType;
        this.kind = kind;
    }

    /**
     * Trains weights on dense integer data and packages them as a classifier keyed by label and
     * feature objects.
     *
     * @param data {@code data[d][f]} is the value of feature {@code f} in datum {@code d}
     * @param labels {@code labels[d]} is the class number of datum {@code d}
     * @param numFeatures number of feature columns used from {@code data}
     * @param numClasses number of classes
     * @param labelIndex maps class numbers to labels
     * @param featureIndex maps feature numbers to features
     */
    private NaiveBayesClassifier<L, F> trainClassifier(int[][] data, int[] labels, int numFeatures, int numClasses, Index<L> labelIndex, Index<F> featureIndex) {
        NBWeights nbWeights = trainWeights(data, labels, numFeatures, numClasses);

        Set<L> labelSet = Generics.newHashSet();
        Counter<L> priors = new ClassicCounter<>();
        double[] priorWeights = nbWeights.priors;
        for (int c = 0; c < priorWeights.length; c++) {
            L label = labelIndex.get(c);
            priors.incrementCount(label, priorWeights[c]);
            labelSet.add(label);
        }

        // Key: ((label, feature), value) -> weight.
        Counter<Pair<Pair<L, F>, Number>> weights = new ClassicCounter<>();
        double[][][] w = nbWeights.weights;
        for (int c = 0; c < numClasses; c++) {
            L label = labelIndex.get(c);
            for (int f = 0; f < numFeatures; f++) {
                Pair<L, F> labelFeature = new Pair<>(label, featureIndex.get(f));
                for (int v = 0; v < w[c][f].length; v++) {
                    Pair<Pair<L, F>, Number> key = new Pair<Pair<L, F>, Number>(labelFeature, Integer.valueOf(v));
                    weights.incrementCount(key, w[c][f][v]);
                }
            }
        }
        return new NaiveBayesClassifier<>(weights, priors, labelSet);
    }

    /**
     * Trains on {@code dataset}, using only the features in {@code featureSet}.
     *
     * <p>Each datum's feature counts (from {@code asFeaturesCounter()}) are truncated to
     * {@code int} and used as the feature's value. Features absent from a datum take value 0.
     * Features not in {@code featureSet} are ignored.
     *
     * @param dataset training data
     * @param featureSet the features to model
     * @return the trained classifier
     */
    public NaiveBayesClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset, Set<F> featureSet) {
        int numFeatures = featureSet.size();
        int numData = dataset.size();
        int[][] data = new int[numData][numFeatures];
        int[] labels = new int[numData];
        this.labelIndex = new HashIndex<>();
        this.featureIndex = new HashIndex<>();
        for (F feature : featureSet) {
            this.featureIndex.add(feature);
        }
        for (int d = 0; d < numData; d++) {
            RVFDatum<L, F> datum = dataset.getRVFDatum(d);
            Counter<F> counts = datum.asFeaturesCounter();
            for (F feature : counts.keySet()) {
                int fNo = this.featureIndex.indexOf(feature);
                if (fNo < 0) {
                    // No column exists for a feature outside featureSet; skip it rather than index with -1.
                    continue;
                }
                data[d][fNo] = (int) counts.getCount(feature);
            }
            this.labelIndex.add(datum.label());
            labels[d] = this.labelIndex.indexOf(datum.label());
        }
        return trainClassifier(data, labels, numFeatures, this.labelIndex.size(), this.labelIndex, this.featureIndex);
    }

    /**
     * Dispatches to the trainer selected by {@link #kind}.
     *
     * @throws IllegalStateException if {@code kind} is not one of JL, CL, UCL
     */
    private NBWeights trainWeights(int[][] data, int[] labels, int numFeatures, int numClasses) {
        switch (this.kind) {
            case JL:
                return trainWeightsJL(data, labels, numFeatures, numClasses);
            case UCL:
                return trainWeightsUCL(data, labels, numFeatures, numClasses);
            case CL:
                return trainWeightsCL(data, labels, numFeatures, numClasses);
            default:
                throw new IllegalStateException("Unknown NaiveBayes training kind: " + this.kind);
        }
    }

    /**
     * JL training: add-alpha smoothed log relative frequencies.
     *
     * <p>{@code weights[c][f][v] = log((count(c, f=v) + alphaFeature) / (count(c) + alphaFeature * numValues[f]))}
     * and {@code priors[c] = log((count(c) + alphaClass) / (N + alphaClass * numClasses))}.
     */
    private NBWeights trainWeightsJL(int[][] data, int[] labels, int numFeatures, int numClasses) {
        int[] numValues = numberValues(data, numFeatures);
        double[] priors = new double[numClasses];
        double[][][] weights = new double[numClasses][numFeatures][];
        for (int c = 0; c < numClasses; c++) {
            for (int f = 0; f < numFeatures; f++) {
                weights[c][f] = new double[numValues[f]];
            }
        }
        // Count class occurrences and (class, feature, value) co-occurrences.
        for (int d = 0; d < data.length; d++) {
            int c = labels[d];
            priors[c] += 1.0;
            for (int f = 0; f < numFeatures; f++) {
                weights[c][f][data[d][f]] += 1.0;
            }
        }
        // Convert counts to smoothed log probabilities; the class count is read before priors[c] is overwritten.
        for (int c = 0; c < numClasses; c++) {
            for (int f = 0; f < numFeatures; f++) {
                double denom = priors[c] + this.alphaFeature * numValues[f];
                for (int v = 0; v < numValues[f]; v++) {
                    weights[c][f][v] = Math.log((weights[c][f][v] + this.alphaFeature) / denom);
                }
            }
            priors[c] = Math.log((priors[c] + this.alphaClass) / (data.length + this.alphaClass * numClasses));
        }
        return new NBWeights(priors, weights);
    }

    /**
     * UCL training: a conditional log-linear model over one-hot (feature, value) indicators.
     *
     * <p>Indicator 0 is active for every datum (its per-class weights become the priors), and
     * indicator {@code offset(f) + v + 1} means "feature {@code f} has value {@code v}".
     */
    private NBWeights trainWeightsUCL(int[][] data, int[] labels, int numFeatures, int numClasses) {
        int[] numValues = numberValues(data, numFeatures);
        // offsets[f] = number of (feature, value) indicators belonging to features before f.
        int[] offsets = new int[numFeatures];
        int numIndicators = 0;
        for (int f = 0; f < numFeatures; f++) {
            offsets[f] = numIndicators;
            numIndicators += numValues[f];
        }
        int[][] encoded = new int[data.length][numFeatures + 1];
        for (int d = 0; d < data.length; d++) {
            encoded[d][0] = 0;
            for (int f = 0; f < numFeatures; f++) {
                encoded[d][f + 1] = offsets[f] + data[d][f] + 1;
            }
        }
        // +1 for the bias indicator; also well-defined when numFeatures == 0.
        int totalFeats = numIndicators + 1;
        logger.info("total feats " + totalFeats);
        LogConditionalObjectiveFunction objective = new LogConditionalObjectiveFunction(totalFeats, numClasses, encoded, labels, this.prior, this.sigma, 0.0d);
        double[][] wts = objective.to2D(new QNMinimizer().minimize(objective, 1.0E-4d, objective.initial()));
        logger.info("weights have dimension " + wts.length);
        return new NBWeights(wts, numValues);
    }

    /** CL training: optimizes a {@code LogConditionalEqConstraintFunction} directly on the dense data. */
    private NBWeights trainWeightsCL(int[][] data, int[] labels, int numFeatures, int numClasses) {
        LogConditionalEqConstraintFunction objective = new LogConditionalEqConstraintFunction(numFeatures, numClasses, data, labels, this.prior, this.sigma, 0.0d);
        double[] optimum = new QNMinimizer().minimize(objective, 1.0E-4d, objective.initial());
        return new NBWeights(objective.priors(optimum), objective.to3D(optimum));
    }

    /**
     * Returns, for each feature column, the larger of 0 and one more than the largest value seen in
     * that column of {@code data}.
     *
     * @param data dense value rows; a row longer than {@code numFeatures} throws
     *     {@link ArrayIndexOutOfBoundsException}
     * @param numFeatures length of the returned array
     * @return number of distinct values per feature, assuming values run from 0 upward
     */
    public static int[] numberValues(int[][] data, int numFeatures) {
        int[] numValues = new int[numFeatures];
        for (int[] row : data) {
            for (int f = 0; f < row.length; f++) {
                numValues[f] = Math.max(numValues[f], row[f] + 1);
            }
        }
        return numValues;
    }

    /**
     * Trains on a dataset using its own label and feature indices.
     *
     * <p>NOTE(review): {@code getDataArray()} is passed through as the dense value matrix described
     * on the class. If it instead holds per-datum lists of active feature indices, the model trained
     * here is not the intended one. Confirm this against {@code GeneralDataset}.
     *
     * @throws RuntimeException if {@code dataset} is an {@code RVFDataset}
     */
    @Override
    public NaiveBayesClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
        if (dataset instanceof RVFDataset) {
            throw new RuntimeException("Not sure if RVFDataset runs correctly in this method. Please update this code if it does.");
        }
        return trainClassifier(dataset.getDataArray(), dataset.labels, dataset.numFeatures(), dataset.numClasses(), dataset.labelIndex, dataset.featureIndex);
    }
}
