package edu.stanford.nlp.parser.dvparser;

import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.parser.common.NoSuchParseException;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.metrics.TreeSpanScoring;
import edu.stanford.nlp.trees.DeepTree;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.IntPair;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:edu/stanford/nlp/parser/dvparser/DVParserCostAndGradient.class */
/**
 * Cost function and gradient used to train the compositional model of the DV parser.
 *
 * <p>For each gold tree in {@link #trainingBatch} the gold tree is scored, and the
 * highest-scoring candidate among its {@link #topParses} is selected. Candidates are
 * boosted by {@code trainOptions.deltaMargin * TRAIN_LAMBDA * (span errors vs. gold)}.
 * Unless the gold tree scores at least as high as that candidate (within {@code 1e-5}),
 * the following happens:
 * <ul>
 *   <li>the score difference is added to the cost;</li>
 *   <li>derivatives are accumulated for both trees;</li>
 *   <li>the gradient is taken as (hypothesis derivatives - gold derivatives).</li>
 * </ul>
 * The cost and gradient are averaged over the batch, and an L2 penalty with weight
 * {@code trainOptions.regCost} is added.
 */
public class DVParserCostAndGradient extends AbstractCachingDiffFunction {
    private static final Redwood.RedwoodChannels log = Redwood.channels(DVParserCostAndGradient.class);

    /** Gold trees whose cost and gradient are computed by {@link #calculate(double[])}. */
    List<Tree> trainingBatch;
    /** Candidate parses for each gold tree, looked up by object identity of the gold tree. */
    IdentityHashMap<Tree, List<Tree>> topParses;
    /** Model whose parameters are evaluated and differentiated. */
    DVModel dvModel;
    /** Parser options (training, lexicon and test settings are read from here). */
    Options op;
    /** Weight on the span-error margin used when picking the best hypothesis during training. */
    static final double TRAIN_LAMBDA = 1.0d;

    /**
     * Scores one gold tree and finds its best margin-augmented hypothesis.
     * It keeps no per-call state, so {@link #newInstance()} returns this same instance.
     */
    class ScoringProcessor implements ThreadsafeProcessor<Tree, Pair<DeepTree, DeepTree>> {
        ScoringProcessor() {
        }

        /** Returns the pair (gold tree with its node vectors and score, best hypothesis). */
        @Override
        public Pair<DeepTree, DeepTree> process(Tree tree) {
            IdentityHashMap<Tree, SimpleMatrix> nodeVectors = new IdentityHashMap<>();
            double goldScore = DVParserCostAndGradient.this.score(tree, nodeVectors);
            DeepTree gold = new DeepTree(tree, nodeVectors, goldScore);
            DeepTree hypothesis = DVParserCostAndGradient.this.getHighestScoringTree(tree, TRAIN_LAMBDA);
            return Pair.makePair(gold, hypothesis);
        }

        /**
         * Implements {@link ThreadsafeProcessor#newInstance()}.
         * The decompiled name {@code newInstance2} did not override anything and left the
         * interface method unimplemented.
         */
        @Override
        public ThreadsafeProcessor<Tree, Pair<DeepTree, DeepTree>> newInstance() {
            return this;
        }
    }

    /**
     * @param trainingBatch gold trees to train on
     * @param topParses candidate parses per gold tree (keyed by identity)
     * @param dvModel model being trained
     * @param op parser options
     */
    public DVParserCostAndGradient(List<Tree> trainingBatch, IdentityHashMap<Tree, List<Tree>> topParses, DVModel dvModel, Options op) {
        this.trainingBatch = trainingBatch;
        this.topParses = topParses;
        this.dvModel = dvModel;
        this.op = op;
    }

    /**
     * Returns the words in {@code tree}'s yield when {@code trainOptions.useContextWords}
     * is set; otherwise returns {@code null}. Returning null is harmless because the list
     * is only indexed when that option is on.
     */
    private List<String> getContextWords(Tree tree) {
        if (!this.op.trainOptions.useContextWords) {
            return null;
        }
        List<String> words = Generics.newArrayList();
        for (Label leaf : tree.yield()) {
            words.add(leaf.value());
        }
        return words;
    }

    /**
     * Appends a left and a right context word vector to {@code childVec}.
     * <ul>
     *   <li>The start vector is used when the span's source index is negative.</li>
     *   <li>The end vector is used when the target index is past the end of {@code words}.</li>
     * </ul>
     * NOTE(review): the words at the span's own source/target indices are used; confirm
     * against {@code Tree.getSpan()}'s indexing whether neighbours were intended.
     */
    private SimpleMatrix concatenateContextWords(SimpleMatrix childVec, IntPair span, List<String> words) {
        SimpleMatrix left = span.getSource() < 0
                ? this.dvModel.getStartWordVector()
                : this.dvModel.getWordVector(words.get(span.getSource()));
        SimpleMatrix right = span.getTarget() >= words.size()
                ? this.dvModel.getEndWordVector()
                : this.dvModel.getWordVector(words.get(span.getTarget()));
        return NeuralUtils.concatenate(childVec, left, right);
    }

    /** Logs the span of every node of {@code tree}, in pre-order. */
    public static void outputSpans(Tree tree) {
        log.info(tree.getSpan() + " ");
        for (Tree child : tree.children()) {
            outputSpans(child);
        }
    }

    /**
     * Forward-propagates {@code tree} through the model and returns its score.
     * The score is the sum of the per-node scores of all internal (non-preterminal) nodes.
     *
     * @param tree tree to score
     * @param nodeVectors output map; receives the vector computed for every non-leaf node
     * @return total tree score
     * @throws NoSuchParseException if the model lacks a parameter needed by some node
     */
    public double score(Tree tree, IdentityHashMap<Tree, SimpleMatrix> nodeVectors) {
        List<String> words = getContextWords(tree);
        IdentityHashMap<Tree, Double> nodeScores = new IdentityHashMap<>();
        try {
            forwardPropagateTree(tree, words, nodeVectors, nodeScores);
        } catch (AssertionError e) {
            log.info("Failed to correctly process tree " + tree);
            throw e;
        }
        double total = 0.0d;
        for (Double nodeScore : nodeScores.values()) {
            total += nodeScore;
        }
        return total;
    }

    /**
     * Recursively computes node vectors and node scores, bottom-up.
     * <ul>
     *   <li>Leaves produce nothing.</li>
     *   <li>A preterminal's vector is tanh of its word's vector; no score is recorded for it.</li>
     *   <li>Any other node's vector is tanh(W_node * input). The input is built from the
     *       children's vectors via {@code NeuralUtils.concatenateWithBias}, with context words
     *       appended when enabled. The node's score is the dot product of scoreW_node and the
     *       node vector.</li>
     * </ul>
     *
     * @throws NoSuchParseException if the model has no W or no scoreW for a node
     */
    private void forwardPropagateTree(Tree tree, List<String> words, IdentityHashMap<Tree, SimpleMatrix> nodeVectors, IdentityHashMap<Tree, Double> nodeScores) {
        if (tree.isLeaf()) {
            return;
        }
        Tree[] children = tree.children();
        if (tree.isPreTerminal()) {
            String word = children[0].label().value();
            nodeVectors.put(tree, NeuralUtils.elementwiseApplyTanh(this.dvModel.getWordVector(word)));
            return;
        }
        for (Tree child : children) {
            forwardPropagateTree(child, words, nodeVectors, nodeScores);
        }

        SimpleMatrix input = children.length == 2
                ? NeuralUtils.concatenateWithBias(nodeVectors.get(children[0]), nodeVectors.get(children[1]))
                : NeuralUtils.concatenateWithBias(nodeVectors.get(children[0]));
        if (this.op.trainOptions.useContextWords) {
            input = concatenateContextWords(input, tree.getSpan(), words);
        }

        SimpleMatrix w = this.dvModel.getWForNode(tree);
        if (w == null) {
            throw missingParameter("Could not find W for tree " + tree);
        }
        SimpleMatrix nodeVector = NeuralUtils.elementwiseApplyTanh(w.mult(input));
        nodeVectors.put(tree, nodeVector);

        SimpleMatrix scoreW = this.dvModel.getScoreWForNode(tree);
        if (scoreW == null) {
            throw missingParameter("Could not find scoreW for tree " + tree);
        }
        nodeScores.put(tree, scoreW.dot(nodeVector));
    }

    /**
     * Logs {@code message} when {@code testOptions.verbose} is set.
     * Returns (does not throw) an exception carrying that message.
     */
    private NoSuchParseException missingParameter(String message) {
        if (this.op.testOptions.verbose) {
            log.info(message);
        }
        return new NoSuchParseException(message);
    }

    /** Number of model parameters, i.e. the length of the vector passed to {@link #calculate}. */
    @Override
    public int domainDimension() {
        return this.dvModel.totalParamSize();
    }

    /**
     * Returns, for each tree in {@code trees}, its highest-scoring candidate from
     * {@link #topParses}. No span-error margin is applied (lambda = 0).
     */
    public List<DeepTree> getAllHighestScoringTreesTest(List<Tree> trees) {
        List<DeepTree> best = new ArrayList<>(trees.size());
        for (Tree tree : trees) {
            best.add(getHighestScoringTree(tree, 0.0d));
        }
        return best;
    }

    /**
     * Picks the candidate parse of {@code tree} with the highest score.
     * When {@code lambda != 0}, each candidate's score is augmented by
     * {@code deltaMargin * lambda * getMargin(tree, candidate)}.
     * Ties keep the earlier candidate.
     *
     * @param tree gold tree used as the key into {@link #topParses}
     * @param lambda margin weight; 0 disables the margin
     * @return the best candidate with its node vectors and (augmented) score
     * @throws AssertionError if no candidates exist for {@code tree}
     */
    public DeepTree getHighestScoringTree(Tree tree, double lambda) {
        List<Tree> candidates = this.topParses.get(tree);
        if (candidates == null || candidates.isEmpty()) {
            throw new AssertionError("Failed to get any hypothesis trees for " + tree);
        }
        double bestScore = Double.NEGATIVE_INFINITY;
        Tree bestTree = null;
        IdentityHashMap<Tree, SimpleMatrix> bestVectors = null;
        for (Tree candidate : candidates) {
            IdentityHashMap<Tree, SimpleMatrix> vectors = new IdentityHashMap<>();
            double candidateScore = score(candidate, vectors)
                    + (lambda != 0.0d ? this.op.trainOptions.deltaMargin * lambda * getMargin(tree, candidate) : 0.0d);
            if (bestTree == null || candidateScore > bestScore) {
                bestTree = candidate;
                bestScore = candidateScore;
                bestVectors = vectors;
            }
        }
        return new DeepTree(bestTree, bestVectors, bestScore);
    }

    /**
     * Loads {@code theta} into the model and computes the batch cost and gradient.
     * Results are stored in {@code value} and {@code derivative}.
     * An empty batch contributes zero data cost and zero data gradient; the
     * unguarded average would have produced NaN here. The L2 term is still added.
     */
    @Override
    public void calculate(double[] theta) {
        this.dvModel.vectorToParams(theta);
        double localValue = 0.0d;
        double[] localDerivative = new double[theta.length];

        // Derivative accumulators shaped like the model parameters.
        // One set is for the gold trees, one for the best hypotheses.
        TwoDimensionalMap<String, String, SimpleMatrix> binaryWGold = TwoDimensionalMap.treeMap();
        TwoDimensionalMap<String, String, SimpleMatrix> binaryWHyp = TwoDimensionalMap.treeMap();
        TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreGold = TwoDimensionalMap.treeMap();
        TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreHyp = TwoDimensionalMap.treeMap();
        Map<String, SimpleMatrix> unaryWGold = new TreeMap<>();
        Map<String, SimpleMatrix> unaryWHyp = new TreeMap<>();
        Map<String, SimpleMatrix> unaryScoreGold = new TreeMap<>();
        Map<String, SimpleMatrix> unaryScoreHyp = new TreeMap<>();
        Map<String, SimpleMatrix> wordVectorsGold = new TreeMap<>();
        Map<String, SimpleMatrix> wordVectorsHyp = new TreeMap<>();

        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : this.dvModel.binaryTransform) {
            int rows = entry.getValue().numRows();
            int cols = entry.getValue().numCols();
            binaryWGold.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(rows, cols));
            binaryWHyp.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(rows, cols));
            binaryScoreGold.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(1, rows));
            binaryScoreHyp.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(1, rows));
        }
        for (Map.Entry<String, SimpleMatrix> entry : this.dvModel.unaryTransform.entrySet()) {
            int rows = entry.getValue().numRows();
            int cols = entry.getValue().numCols();
            unaryWGold.put(entry.getKey(), new SimpleMatrix(rows, cols));
            unaryWHyp.put(entry.getKey(), new SimpleMatrix(rows, cols));
            unaryScoreGold.put(entry.getKey(), new SimpleMatrix(1, rows));
            unaryScoreHyp.put(entry.getKey(), new SimpleMatrix(1, rows));
        }
        if (this.op.trainOptions.trainWordVectors) {
            for (Map.Entry<String, SimpleMatrix> entry : this.dvModel.wordVectors.entrySet()) {
                int rows = entry.getValue().numRows();
                int cols = entry.getValue().numCols();
                wordVectorsGold.put(entry.getKey(), new SimpleMatrix(rows, cols));
                wordVectorsHyp.put(entry.getKey(), new SimpleMatrix(rows, cols));
            }
        }

        // Score every gold tree and find its best hypothesis, in parallel.
        Timing timing = new Timing();
        timing.doing("Scoring trees");
        int treeIndex = 0;
        MulticoreWrapper<Tree, Pair<DeepTree, DeepTree>> wrapper =
                new MulticoreWrapper<>(this.op.trainOptions.trainingThreads, new ScoringProcessor());
        for (Tree tree : this.trainingBatch) {
            wrapper.put(tree);
        }
        wrapper.join();
        timing.done();

        while (wrapper.peek()) {
            Pair<DeepTree, DeepTree> result = wrapper.poll();
            DeepTree gold = result.first;
            DeepTree hypothesis = result.second;
            boolean isDone = Math.abs(hypothesis.getScore() - gold.getScore()) <= 1.0E-5d
                    || gold.getScore() > hypothesis.getScore();
            log.info(String.format("Tree %6d Highest tree: %12.4f Correct tree: %12.4f %s",
                    treeIndex, hypothesis.getScore(), gold.getScore(), isDone ? "done" : ""));
            if (!isDone) {
                localValue += hypothesis.getScore() - gold.getScore();
                // Gold and hypothesis share a yield, so one word list serves both.
                List<String> words = getContextWords(gold.getTree());
                backpropDerivative(gold.getTree(), words, gold.getVectors(),
                        binaryWGold, unaryWGold, binaryScoreGold, unaryScoreGold, wordVectorsGold);
                backpropDerivative(hypothesis.getTree(), words, hypothesis.getVectors(),
                        binaryWHyp, unaryWHyp, binaryScoreHyp, unaryScoreHyp, wordVectorsHyp);
            }
            treeIndex++;
        }

        // Flatten the accumulators in the same order the model serializes its parameters.
        double[] gradGold;
        double[] gradHyp;
        if (this.op.trainOptions.trainWordVectors) {
            gradGold = NeuralUtils.paramsToVector(theta.length, binaryWGold.valueIterator(), unaryWGold.values().iterator(),
                    binaryScoreGold.valueIterator(), unaryScoreGold.values().iterator(), wordVectorsGold.values().iterator());
            gradHyp = NeuralUtils.paramsToVector(theta.length, binaryWHyp.valueIterator(), unaryWHyp.values().iterator(),
                    binaryScoreHyp.valueIterator(), unaryScoreHyp.values().iterator(), wordVectorsHyp.values().iterator());
        } else {
            gradGold = NeuralUtils.paramsToVector(theta.length, binaryWGold.valueIterator(), unaryWGold.values().iterator(),
                    binaryScoreGold.valueIterator(), unaryScoreGold.values().iterator());
            gradHyp = NeuralUtils.paramsToVector(theta.length, binaryWHyp.valueIterator(), unaryWHyp.values().iterator(),
                    binaryScoreHyp.valueIterator(), unaryScoreHyp.values().iterator());
        }
        for (int i = 0; i < gradGold.length; i++) {
            localDerivative[i] = gradHyp[i] - gradGold[i];
        }

        this.value = localValue;
        this.derivative = localDerivative;
        int batchSize = this.trainingBatch.size();
        if (batchSize > 0) {
            // Average over the batch; guarded because 1.0/0 would turn everything into NaN.
            this.value = (1.0d / batchSize) * this.value;
            ArrayMath.multiplyInPlace(this.derivative, 1.0d / batchSize);
        }

        // L2 regularization: cost += regCost/2 * ||theta||^2, gradient += regCost * theta.
        double[] params = this.dvModel.paramsToVector();
        double sumSquares = 0.0d;
        for (double p : params) {
            sumSquares += p * p;
        }
        this.value += this.op.trainOptions.regCost * 0.5d * sumSquares;
        ArrayMath.multiplyInPlace(params, this.op.trainOptions.regCost);
        ArrayMath.pairwiseAddInPlace(this.derivative, params);
    }

    /** Number of span errors of {@code guess} relative to {@code gold}. */
    public double getMargin(Tree gold, Tree guess) {
        return TreeSpanScoring.countSpanErrors(this.op.langpack(), gold, guess);
    }

    /**
     * Accumulates into the given maps the derivatives of {@code tree}'s score.
     * Starts with a zero upstream delta of size {@code lexOptions.numHid} x 1.
     */
    public void backpropDerivative(Tree tree, List<String> words, IdentityHashMap<Tree, SimpleMatrix> nodeVectors, TwoDimensionalMap<String, String, SimpleMatrix> binaryWDerivatives, Map<String, SimpleMatrix> unaryWDerivatives, TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivatives, Map<String, SimpleMatrix> unaryScoreDerivatives, Map<String, SimpleMatrix> wordVectorDerivatives) {
        backpropDerivative(tree, words, nodeVectors, binaryWDerivatives, unaryWDerivatives, binaryScoreDerivatives,
                unaryScoreDerivatives, wordVectorDerivatives, new SimpleMatrix(this.op.lexOptions.numHid, 1));
    }

    /**
     * Recursive backpropagation step.
     *
     * <p>{@code deltaUp} is the error signal passed down from the parent. For each
     * internal node, the local score's contribution is added to it:
     * {@code tanh'(vec) .* scoreW^T}. The result then updates:
     * <ul>
     *   <li>the node's W derivative, with {@code delta * input^T};</li>
     *   <li>the node's scoreW derivative, with {@code vec^T};</li>
     *   <li>the children, which receive slices of {@code W^T * delta} scaled by their own tanh'.</li>
     * </ul>
     * Preterminals add {@code deltaUp} to their word's vector derivative, but only when
     * word vectors are trained. Nodes with neither one nor two children contribute nothing
     * further.
     */
    public void backpropDerivative(Tree tree, List<String> words, IdentityHashMap<Tree, SimpleMatrix> nodeVectors, TwoDimensionalMap<String, String, SimpleMatrix> binaryWDerivatives, Map<String, SimpleMatrix> unaryWDerivatives, TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivatives, Map<String, SimpleMatrix> unaryScoreDerivatives, Map<String, SimpleMatrix> wordVectorDerivatives, SimpleMatrix deltaUp) {
        if (tree.isLeaf()) {
            return;
        }
        Tree[] children = tree.children();
        if (tree.isPreTerminal()) {
            if (this.op.trainOptions.trainWordVectors) {
                String vocabWord = this.dvModel.getVocabWord(children[0].label().value());
                wordVectorDerivatives.put(vocabWord, wordVectorDerivatives.get(vocabWord).plus(deltaUp));
            }
            return;
        }

        SimpleMatrix nodeVector = nodeVectors.get(tree);
        SimpleMatrix delta = deltaUp.plus(NeuralUtils.elementwiseApplyTanhDerivative(nodeVector)
                .elementMult(this.dvModel.getScoreWForNode(tree).transpose()));
        SimpleMatrix deltaFull = this.dvModel.getWForNode(tree).transpose().mult(delta);
        int numHid = delta.numRows();

        if (children.length == 2) {
            String leftCategory = this.dvModel.basicCategory(children[0].label().value());
            String rightCategory = this.dvModel.basicCategory(children[1].label().value());
            binaryScoreDerivatives.put(leftCategory, rightCategory,
                    binaryScoreDerivatives.get(leftCategory, rightCategory).plus(nodeVector.transpose()));

            SimpleMatrix leftVector = nodeVectors.get(children[0]);
            SimpleMatrix rightVector = nodeVectors.get(children[1]);
            SimpleMatrix input = NeuralUtils.concatenateWithBias(leftVector, rightVector);
            if (this.op.trainOptions.useContextWords) {
                input = concatenateContextWords(input, tree.getSpan(), words);
            }
            binaryWDerivatives.put(leftCategory, rightCategory,
                    binaryWDerivatives.get(leftCategory, rightCategory).plus(delta.mult(input.transpose())));

            SimpleMatrix leftTanhDerivative = NeuralUtils.elementwiseApplyTanhDerivative(leftVector);
            SimpleMatrix rightTanhDerivative = NeuralUtils.elementwiseApplyTanhDerivative(rightVector);
            SimpleMatrix leftDelta = deltaFull.extractMatrix(0, numHid, 0, 1);
            SimpleMatrix rightDelta = deltaFull.extractMatrix(numHid, numHid * 2, 0, 1);
            backpropDerivative(children[0], words, nodeVectors, binaryWDerivatives, unaryWDerivatives, binaryScoreDerivatives,
                    unaryScoreDerivatives, wordVectorDerivatives, leftTanhDerivative.elementMult(leftDelta));
            backpropDerivative(children[1], words, nodeVectors, binaryWDerivatives, unaryWDerivatives, binaryScoreDerivatives,
                    unaryScoreDerivatives, wordVectorDerivatives, rightTanhDerivative.elementMult(rightDelta));
        } else if (children.length == 1) {
            String childCategory = this.dvModel.basicCategory(children[0].label().value());
            unaryScoreDerivatives.put(childCategory, unaryScoreDerivatives.get(childCategory).plus(nodeVector.transpose()));

            SimpleMatrix childVector = nodeVectors.get(children[0]);
            SimpleMatrix input = NeuralUtils.concatenateWithBias(childVector);
            if (this.op.trainOptions.useContextWords) {
                input = concatenateContextWords(input, tree.getSpan(), words);
            }
            unaryWDerivatives.put(childCategory, unaryWDerivatives.get(childCategory).plus(delta.mult(input.transpose())));

            SimpleMatrix childDelta = NeuralUtils.elementwiseApplyTanhDerivative(childVector)
                    .elementMult(deltaFull.extractMatrix(0, numHid, 0, 1));
            backpropDerivative(children[0], words, nodeVectors, binaryWDerivatives, unaryWDerivatives, binaryScoreDerivatives,
                    unaryScoreDerivatives, wordVectorDerivatives, childDelta);
        }
    }
}
