package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/* loaded from: input_file:edu/stanford/nlp/parser/lexparser/ChineseSimWordAvgDepGrammar.class */
/**
 * An {@link MLEDependencyGrammar} whose {@link #scoreTB(IntDependency)} additionally mixes in
 * counts pooled over "similar" head words.
 *
 * <p>Two similar-word tables are loaded at construction time from the relative paths
 * {@code simWords/HeadArg.5} and {@code simWords/ArgHead.5}. Each table maps a
 * (word id, basic category) pair to a list of (similar word id, category, value) triples.
 */
public class ChineseSimWordAvgDepGrammar extends MLEDependencyGrammar {
    private static final Redwood.RedwoodChannels log = Redwood.channels(ChineseSimWordAvgDepGrammar.class);
    private static final long serialVersionUID = -1845503582705055342L;

    /** Pseudo-count given to the similar-word estimate in {@link #probSimilarWordAvg}. */
    private static final double simSmooth = 10.0d;
    private static final String argHeadFile = "simWords/ArgHead.5";
    private static final String headArgFile = "simWords/HeadArg.5";

    /** Line format of the similar-word files: {@code sim(w1/c1:w2/c2)=value}. */
    private static final Pattern SIM_LINE_PATTERN = Pattern.compile("sim\\((.+)/(.+):(.+)/(.+)\\)=(.+)");

    /** Pseudo-count given to the similar-head estimate in {@link #probTBwithSimWords}. */
    private static final double SIM_HEAD_SMOOTH = 17.7d;

    /** Pseudo-count given to the head-tag back-off estimate in {@link #probTBwithSimWords}. */
    private static final double HEAD_TAG_SMOOTH = 35.4d;

    /** Scores below this threshold are flushed to zero by {@link #probTBwithSimWords}. */
    private static final double MIN_PROBABILITY = 1.0E-40d;

    /** Multiplier applied to a triple's third component before exponentiating it into a weight. */
    private static final double SIM_WEIGHT_SCALE = -50.0d;

    /** Similar-word table loaded from {@link #argHeadFile}. */
    private final Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> simArgMap;

    /** Similar-word table loaded from {@link #headArgFile}. */
    private final Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> simHeadMap;

    /** When true, {@link #probTBwithSimWords} writes its intermediate quantities to stdout. */
    private static final boolean debug = true;
    private static final boolean verbose = false;

    /** Event counters updated by {@link #probSimilarWordAvg}; dumped by {@link #dumpSimWordAvgStats()}. */
    private final ClassicCounter<String> statsCounter;

    /**
     * Creates the grammar and loads both similar-word tables.
     *
     * <p>All arguments are forwarded unchanged to the superclass constructor.
     *
     * @throws RuntimeException if either similar-word file cannot be read
     */
    public ChineseSimWordAvgDepGrammar(TreebankLangParserParams treebankLangParserParams, boolean z, boolean z2, boolean z3, boolean z4, Options options, Index<String> index, Index<String> index2) {
        super(treebankLangParserParams, z, z2, z3, z4, options, index, index2);
        this.statsCounter = new ClassicCounter<>();
        this.simHeadMap = getMap(headArgFile);
        this.simArgMap = getMap(argHeadFile);
    }

    /**
     * Reads a similar-word file (UTF-8) into a map.
     *
     * <p>Each well-formed line has the shape {@code sim(w1/c1:w2/c2)=value}. The key is
     * (index of {@code w1}, {@code c1}); the triple (index of {@code w2}, {@code c2},
     * {@code value}) is appended to that key's list. Both words are added to the word index
     * if not already present. Lines that do not match are logged and skipped.
     *
     * @param str path of the file to read
     * @return the populated map (never {@code null})
     * @throws RuntimeException if the file cannot be opened or read; the I/O failure is the cause
     * @throws NumberFormatException if the value of a matching line is not a parsable double
     */
    public Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> getMap(String str) {
        Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> simMap = Generics.newHashMap();
        // try-with-resources so the file handle is released on every exit path.
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(str), "UTF-8"))) {
            String line;
            while ((line = reader.readLine()) != null) {
                Matcher matcher = SIM_LINE_PATTERN.matcher(line);
                if (!matcher.matches()) {
                    log.info("Ill-formed line in similar word map file: " + line);
                    continue;
                }
                Pair<Integer, String> key = new Pair<>(Integer.valueOf(this.wordIndex.addToIndex(matcher.group(1))), matcher.group(2));
                double value = Double.parseDouble(matcher.group(5));
                List<Triple<Integer, String, Double>> similarWords = simMap.get(key);
                if (similarWords == null) {
                    similarWords = new ArrayList<>();
                    simMap.put(key, similarWords);
                }
                similarWords.add(new Triple<>(Integer.valueOf(this.wordIndex.addToIndex(matcher.group(3))), matcher.group(4), Double.valueOf(value)));
            }
        } catch (IOException e) {
            // Keep the underlying I/O failure as the cause so the offending file can be diagnosed.
            throw new RuntimeException("Problem reading similar words file!", e);
        }
        return simMap;
    }

    /**
     * Scores a dependency as {@code depWeight * log(p)}, where {@code p} comes from
     * {@link #probTBwithSimWords(IntDependency)}.
     */
    @Override // edu.stanford.nlp.parser.lexparser.MLEDependencyGrammar, edu.stanford.nlp.parser.lexparser.DependencyGrammar
    public double scoreTB(IntDependency intDependency) {
        return this.op.testOptions.depWeight * Math.log(probTBwithSimWords(intDependency));
    }

    /** Sets the lexicon used by {@link #probSimilarWordAvg} to score tagged words. */
    public void setLex(Lexicon lexicon) {
        this.lex = lexicon;
    }

    /** Logs the event counters accumulated by {@link #probSimilarWordAvg}. */
    public void dumpSimWordAvgStats() {
        log.info("SimWordAvg stats:");
        log.info(this.statsCounter);
    }

    /**
     * Probability of a dependency, where the lexicalised estimate is smoothed both with the
     * head-tag back-off estimate and with counts pooled over words similar to the head.
     *
     * <p>While {@link #debug} is true the intermediate quantities are printed to stdout.
     *
     * @return the probability; NaN and values below {@link #MIN_PROBABILITY} are returned as 0
     */
    private double probTBwithSimWords(IntDependency intDependency) {
        boolean leftHeaded = intDependency.leftHeaded && this.directional;
        IntTaggedWord headTagOnly = new IntTaggedWord(-1, intDependency.head.tag);
        IntTaggedWord argTagOnly = new IntTaggedWord(-1, intDependency.arg.tag);
        short distance = intDependency.distance;
        IntTaggedWord arg = intDependency.arg;
        double stopProb = getStopProb(intDependency);
        boolean headIsRoot = rootTW(intDependency.head);

        // Argument word id -2: the answer is the stop probability (0 when the head is the root).
        if (intDependency.arg.word == -2) {
            return headIsRoot ? 0.0d : stopProb;
        }
        double goProb = headIsRoot ? 1.0d : 1.0d - stopProb;

        short valenceBin = valenceBin(distance);
        double cArgHead = this.argCounter.getCount(new IntDependency(intDependency.head, intDependency.arg, leftHeaded, valenceBin));
        double cArgTagHead = this.argCounter.getCount(new IntDependency(intDependency.head, argTagOnly, leftHeaded, valenceBin));
        double cHead = this.argCounter.getCount(new IntDependency(intDependency.head, this.wildTW, leftHeaded, valenceBin));
        double cArgHeadTag = this.argCounter.getCount(new IntDependency(headTagOnly, intDependency.arg, leftHeaded, valenceBin));
        double cArgTagHeadTag = this.argCounter.getCount(new IntDependency(headTagOnly, argTagOnly, leftHeaded, valenceBin));
        double cHeadTag = this.argCounter.getCount(new IntDependency(headTagOnly, this.wildTW, leftHeaded, valenceBin));
        double cArg = this.argCounter.getCount(new IntDependency(this.wildTW, intDependency.arg, false, -1));
        double cArgTag = this.argCounter.getCount(new IntDependency(this.wildTW, argTagOnly, false, -1));

        double pArgGivenHeadTag = cHeadTag > 0.0d ? cArgHeadTag / cHeadTag : 0.0d;
        double pArgTagGivenHeadTag = cHeadTag > 0.0d ? cArgTagHeadTag / cHeadTag : 0.0d;
        double pArgGivenArgTag = cArg > 0.0d ? cArg / cArgTag : 1.0d;
        double pArgTagGivenHead = (cArgTagHead + (this.smooth_aT_hTWd * pArgTagGivenHeadTag)) / (cHead + this.smooth_aT_hTWd);

        // Pool counts over every word listed as similar to the head (keeping the head's tag).
        List<Triple<Integer, String, Double>> similarHeads = this.simHeadMap.get(new Pair<Integer, String>(Integer.valueOf(intDependency.head.word), stringBasicCategory(intDependency.head.tag)));
        double simArgHeadCount = 0.0d;
        double simHeadCount = 0.0d;
        if (similarHeads != null) {
            for (Triple<Integer, String, Double> similarHead : similarHeads) {
                IntTaggedWord simHeadTW = new IntTaggedWord(similarHead.first.intValue(), intDependency.head.tag);
                // NOTE(review): these lookups use the raw leftHeaded/distance, whereas every other
                // lookup in this method uses the directional flag and the valence bin — confirm
                // this is intended before changing it.
                simArgHeadCount += this.argCounter.getCount(new IntDependency(simHeadTW, intDependency.arg, intDependency.leftHeaded, intDependency.distance));
                simHeadCount += this.argCounter.getCount(new IntDependency(simHeadTW, this.wildTW, intDependency.leftHeaded, intDependency.distance));
            }
        }
        double pArgGivenSimHeads = simHeadCount > 0.0d ? simArgHeadCount / simHeadCount : 0.0d;
        if (debug && pArgGivenSimHeads > 0.0d) {
            System.out.println(intDependency + LinearClassifier.TEXT_SERIALIZATION_DELIMITER + pArgGivenSimHeads);
        }

        double pArgGivenHead = ((cArgHead + (SIM_HEAD_SMOOTH * pArgGivenSimHeads)) + (HEAD_TAG_SMOOTH * pArgGivenHeadTag)) / ((cHead + SIM_HEAD_SMOOTH) + HEAD_TAG_SMOOTH);
        if (debug) {
            System.out.println(intDependency);
            System.out.println(cArgHead + " + 17.7 * " + pArgGivenSimHeads + " + 35.4 * " + pArgGivenHeadTag);
            System.out.println("--------------------------------  = " + pArgGivenHead);
            System.out.println(cHead + " + 17.7 + 35.4");
            System.out.println();
        }

        double score = ((this.interp * pArgGivenHead) + ((1.0d - this.interp) * pArgGivenArgTag * pArgTagGivenHead)) * goProb;
        if (this.op.testOptions.prunePunc && pruneTW(arg)) {
            return 1.0d;
        }
        if (Double.isNaN(score) || score < MIN_PROBABILITY) {
            score = 0.0d;
        }
        return score;
    }

    /**
     * Alternative estimate: interpolates {@code probTB} for the dependency with a weighted
     * average of {@code probTB} over dependencies in which the head and/or the argument is
     * replaced by a similar word. Each replacement is weighted by
     * {@code exp(-50 * third component of its triple)}.
     *
     * <p>Not called from within this class; {@link #scoreTB} uses {@link #probTBwithSimWords}.
     *
     * @return the smoothed probability, or plain {@code probTB} when no similar-word evidence
     *         contributes any weight
     */
    private double probSimilarWordAvg(IntDependency intDependency) {
        double regularProb = probTB(intDependency);
        this.statsCounter.incrementCount("total");
        List<Triple<Integer, String, Double>> similarArgs = this.simArgMap.get(new Pair<Integer, String>(Integer.valueOf(intDependency.arg.word), stringBasicCategory(intDependency.arg.tag)));
        List<Triple<Integer, String, Double>> similarHeads = this.simHeadMap.get(new Pair<Integer, String>(Integer.valueOf(intDependency.head.word), stringBasicCategory(intDependency.head.tag)));
        if (similarHeads == null && similarArgs == null) {
            return regularProb;
        }

        double scoreSum = 0.0d;
        double weightSum = 0.0d;
        int numTags = this.tagIndex.size();
        if (similarHeads == null) {
            // Only the argument has similar words.
            this.statsCounter.incrementCount("aSim");
            for (Triple<Integer, String, Double> simArg : similarArgs) {
                double weight = similarityWeight(simArg);
                for (int tag = 0; tag < numTags; tag++) {
                    if (!stringBasicCategory(tag).equals(simArg.second)) {
                        continue;
                    }
                    IntTaggedWord candidateArg = new IntTaggedWord(simArg.first.intValue(), tag);
                    IntDependency candidate = new IntDependency(intDependency.head, candidateArg, intDependency.leftHeaded, intDependency.distance);
                    double argProb = lexiconProb(candidateArg);
                    if (argProb == 0.0d) {
                        continue;
                    }
                    scoreSum += (probTB(candidate) * weight) / argProb;
                    weightSum += weight;
                }
            }
        } else if (similarArgs == null) {
            // Only the head has similar words.
            this.statsCounter.incrementCount("hSim");
            for (Triple<Integer, String, Double> simHead : similarHeads) {
                double weight = similarityWeight(simHead);
                for (int tag = 0; tag < numTags; tag++) {
                    if (!stringBasicCategory(tag).equals(simHead.second)) {
                        continue;
                    }
                    IntTaggedWord candidateHead = new IntTaggedWord(simHead.first.intValue(), tag);
                    scoreSum += probTB(new IntDependency(candidateHead, intDependency.arg, intDependency.leftHeaded, intDependency.distance)) * weight;
                    weightSum += weight;
                }
            }
        } else {
            // Both the head and the argument have similar words: average over all pairs.
            this.statsCounter.incrementCount("hSim");
            this.statsCounter.incrementCount("aSim");
            this.statsCounter.incrementCount("aSim&hSim");
            for (Triple<Integer, String, Double> simArg : similarArgs) {
                for (int argTag = 0; argTag < numTags; argTag++) {
                    if (!stringBasicCategory(argTag).equals(simArg.second)) {
                        continue;
                    }
                    IntTaggedWord candidateArg = new IntTaggedWord(simArg.first.intValue(), argTag);
                    double argProb = lexiconProb(candidateArg);
                    if (argProb == 0.0d) {
                        continue;
                    }
                    for (Triple<Integer, String, Double> simHead : similarHeads) {
                        for (int headTag = 0; headTag < numTags; headTag++) {
                            if (!stringBasicCategory(headTag).equals(simHead.second)) {
                                continue;
                            }
                            // The head candidate carries the head tag that just passed the
                            // category filter (not the argument's tag).
                            IntTaggedWord candidateHead = new IntTaggedWord(simHead.first.intValue(), headTag);
                            IntDependency candidate = new IntDependency(candidateHead, candidateArg, intDependency.leftHeaded, intDependency.distance);
                            double weight = similarityWeight(simHead) * similarityWeight(simArg);
                            scoreSum += (probTB(candidate) * weight) / argProb;
                            weightSum += weight;
                        }
                    }
                }
            }
        }

        if (weightSum == 0.0d) {
            // No candidate contributed any weight; dividing would give NaN, so fall back to the
            // regular estimate exactly as when no similar words are listed at all.
            return regularProb;
        }

        double headCount = this.argCounter.getCount(new IntDependency(intDependency.head, this.wildTW, intDependency.leftHeaded, intDependency.distance));
        double simProb = similarArgs == null
                ? scoreSum / weightSum
                : (lexiconProb(intDependency.arg) * scoreSum) / weightSum;
        if (simProb == 0.0d) {
            this.statsCounter.incrementCount("simProbZero");
        }
        if (regularProb == 0.0d) {
            this.statsCounter.incrementCount("regProbZero");
        }
        double smoothedProb = ((headCount * regularProb) + (simSmooth * simProb)) / (headCount + simSmooth);
        if (smoothedProb == 0.0d) {
            this.statsCounter.incrementCount("smoothProbZero");
        }
        return smoothedProb;
    }

    /** Weight of a similar-word entry: {@code exp(-50 * entry.third)}. */
    private static double similarityWeight(Triple<Integer, String, Double> entry) {
        return Math.exp(SIM_WEIGHT_SCALE * entry.third.doubleValue());
    }

    /** Exponentiated lexicon score of a tagged word. */
    private double lexiconProb(IntTaggedWord taggedWord) {
        return Math.exp(this.lex.score(taggedWord, 0, this.wordIndex.get(taggedWord.word), null));
    }

    /** Basic category, as computed by the language pack, of the tag with index {@code i}. */
    private String stringBasicCategory(int i) {
        return this.tlp.basicCategory(this.tagIndex.get(i));
    }
}
