package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.PrintStream;
import java.util.Collection;
import java.util.Iterator;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/classify/NaiveBayesClassifier.class */
public class NaiveBayesClassifier<L, F> implements Classifier<L, F>, RVFClassifier<L, F> {
    private static final long serialVersionUID = 1544820342684024068L;
    /** Weight for each ((label, feature), value) key; summed into a label's score. */
    private Counter<Pair<Pair<L, F>, Number>> weights;
    /** Per-label prior added to every score. */
    private Counter<L> priors;
    /** All known features; only required (non-null) when {@link #addZeroValued} is set. */
    private Set<F> features;
    /** Whether features absent from a datum are scored as if they had value 0. */
    private boolean addZeroValued;
    /** Per-label sum of the value-0 weights of every known feature (built by {@link #initZeros()}). */
    private Counter<L> priorZero;
    /** Labels this classifier can predict. */
    private Set<L> labels;
    /** The feature value used for "absent" features when {@link #addZeroValued} is set. */
    private final Integer zero;
    private static final Redwood.RedwoodChannels logger = Redwood.channels(NaiveBayesClassifier.class);

    /**
     * Creates a classifier from precomputed weights and priors.
     *
     * <p>Scores are formed by summing priors and weights, which suggests both are log-probabilities
     * — TODO(review): confirm against the trainer that builds them.
     *
     * @param weights weight for each ((label, feature), value) key
     * @param priors per-label prior score
     * @param labels labels this classifier can predict
     * @param features all known features; must be non-null when {@code addZeroValued} is true
     * @param addZeroValued if true, every known feature missing from a datum contributes its value-0
     *     weight to the score
     * @throws IllegalArgumentException if {@code addZeroValued} is true but {@code features} is null
     */
    public NaiveBayesClassifier(Counter<Pair<Pair<L, F>, Number>> weights, Counter<L> priors,
                                Set<L> labels, Set<F> features, boolean addZeroValued) {
        this.zero = 0;
        this.weights = weights;
        this.features = features;
        this.priors = priors;
        this.labels = labels;
        this.addZeroValued = addZeroValued;
        if (addZeroValued) {
            // initZeros iterates every known feature; fail with a clear message rather than an NPE.
            if (features == null) {
                throw new IllegalArgumentException("features must be non-null when addZeroValued is true");
            }
            initZeros();
        }
    }

    /**
     * Creates a classifier that scores only the features present in each datum.
     *
     * @param weights weight for each ((label, feature), value) key
     * @param priors per-label prior score
     * @param labels labels this classifier can predict
     */
    public NaiveBayesClassifier(Counter<Pair<Pair<L, F>, Number>> weights, Counter<L> priors, Set<L> labels) {
        this(weights, priors, labels, null, false);
    }

    /** Returns the labels this classifier can predict (the internal set, not a copy). */
    @Override // edu.stanford.nlp.classify.Classifier
    public Collection<L> labels() {
        return this.labels;
    }

    /** Returns the highest-scoring label for a real-valued datum. */
    @Override // edu.stanford.nlp.classify.RVFClassifier
    public L classOf(RVFDatum<L, F> example) {
        return Counters.argmax(scoresOf(example));
    }

    /**
     * Scores every label for a real-valued datum.
     *
     * <p>Each label's score is its prior, plus (if {@link #addZeroValued}) its precomputed
     * value-0 total, plus the sum of weight(label, feature, value) over the datum's features.
     * When zero-valued features are enabled, each present feature's value-0 weight is subtracted
     * so it is not counted on top of the precomputed total.
     *
     * @param example the datum to score
     * @return a new counter mapping each label to its score
     */
    @Override // edu.stanford.nlp.classify.RVFClassifier
    public ClassicCounter<L> scoresOf(RVFDatum<L, F> example) {
        ClassicCounter<L> scores = new ClassicCounter<>();
        Counters.addInPlace(scores, this.priors);
        if (this.addZeroValued) {
            Counters.addInPlace(scores, this.priorZero);
        }
        // The datum's features do not depend on the label, so fetch them once.
        Counter<F> featureValues = example.asFeaturesCounter();
        for (L label : this.labels) {
            double score = 0.0;
            for (F feature : featureValues.keySet()) {
                // Weight keys hold integer values, so real feature values are truncated to int.
                score += weight(label, feature, Integer.valueOf((int) featureValues.getCount(feature)));
                if (this.addZeroValued) {
                    score -= weight(label, feature, this.zero);
                }
            }
            scores.incrementCount(label, score);
        }
        return scores;
    }

    /** Returns the highest-scoring label for a datum, via its real-valued conversion. */
    @Override // edu.stanford.nlp.classify.Classifier
    public L classOf(Datum<L, F> datum) {
        RVFDatum<L, F> example = new RVFDatum<>(datum);
        return classOf(example);
    }

    /** Scores every label for a datum, via its real-valued conversion. */
    @Override // edu.stanford.nlp.classify.Classifier
    public ClassicCounter<L> scoresOf(Datum<L, F> datum) {
        RVFDatum<L, F> example = new RVFDatum<>(datum);
        return scoresOf(example);
    }

    /**
     * Computes the fraction of data whose predicted label equals its gold label, and logs the counts.
     *
     * @param it2 the labeled data to evaluate; it is consumed
     * @return correct / total as a float in [0, 1], or NaN if {@code it2} yields no data
     */
    public float accuracy(Iterator<RVFDatum<L, F>> it2) {
        int correct = 0;
        int total = 0;
        while (it2.hasNext()) {
            RVFDatum<L, F> example = it2.next();
            L guess = classOf(example);
            if (guess != null && guess.equals(example.label())) {
                correct++;
            }
            total++;
        }
        logger.info("correct " + correct + " out of " + total);
        // Cast before dividing: int / int would truncate the ratio to 0 or 1.
        return (float) correct / total;
    }

    /** Writes the priors and weights to the given stream. */
    public void print(PrintStream printStream) {
        printStream.println("priors ");
        printStream.println(this.priors.toString());
        printStream.println("weights ");
        printStream.println(this.weights.toString());
    }

    /** Writes the priors and weights to standard output. */
    public void print() {
        print(System.out);
    }

    /** Looks up the weight stored for ((label, feature), value). */
    private double weight(L label, F feature, Number value) {
        return this.weights.getCount(new Pair<>(new Pair<>(label, feature), value));
    }

    /** Precomputes, per label, the sum of the value-0 weights over all known features. */
    private void initZeros() {
        this.priorZero = new ClassicCounter<>();
        for (L label : this.labels) {
            double total = 0.0;
            for (F feature : this.features) {
                total += weight(label, feature, this.zero);
            }
            this.priorZero.setCount(label, total);
        }
    }
}
