package edu.stanford.nlp.coref.neural;

import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.neural.NeuralUtils;
import java.io.Serializable;
import java.util.List;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:edu/stanford/nlp/coref/neural/NeuralCorefModel.class */
public class NeuralCorefModel implements Serializable {
    // Keep unchanged: previously serialized models must remain loadable.
    private static final long serialVersionUID = 2139427931784505653L;

    // NOTE(review): field names and types are part of the serialized form; do not rename.
    private final SimpleMatrix antecedentMatrix;
    private final SimpleMatrix anaphorMatrix;
    private final SimpleMatrix pairFeaturesMatrix;
    private final SimpleMatrix pairwiseFirstLayerBias;
    /** Alternating weight/bias matrices scored by {@link #getAnaphoricityScore}. */
    private final List<SimpleMatrix> anaphoricityModel;
    /** Alternating weight/bias matrices scored by {@link #getPairwiseScore}. */
    private final List<SimpleMatrix> pairwiseModel;
    private final Embedding wordEmbeddings;

    /**
     * Creates a model from its parameter matrices.
     *
     * <p>The lists are stored as given (no defensive copy). Each model list is expected to
     * alternate weight and bias matrices: {@code [W0, b0, W1, b1, ...]}.
     *
     * @param antecedentMatrix projection applied by {@link #getAntecedentEmbedding}
     * @param anaphorMatrix projection applied by {@link #getAnaphorEmbedding}
     * @param pairFeaturesMatrix projection applied to pair features in {@link #getPairwiseScore}
     * @param pairwiseFirstLayerBias bias added before the first ReLU in {@link #getPairwiseScore}
     * @param anaphoricityModel alternating weight/bias matrices for anaphoricity scoring
     * @param pairwiseModel alternating weight/bias matrices for pairwise scoring
     * @param wordEmbeddings embeddings returned by {@link #getWordEmbeddings}
     */
    public NeuralCorefModel(SimpleMatrix antecedentMatrix, SimpleMatrix anaphorMatrix,
            SimpleMatrix pairFeaturesMatrix, SimpleMatrix pairwiseFirstLayerBias,
            List<SimpleMatrix> anaphoricityModel, List<SimpleMatrix> pairwiseModel,
            Embedding wordEmbeddings) {
        this.antecedentMatrix = antecedentMatrix;
        this.anaphorMatrix = anaphorMatrix;
        this.pairFeaturesMatrix = pairFeaturesMatrix;
        this.pairwiseFirstLayerBias = pairwiseFirstLayerBias;
        this.anaphoricityModel = anaphoricityModel;
        this.pairwiseModel = pairwiseModel;
        this.wordEmbeddings = wordEmbeddings;
    }

    /**
     * Scores a single mention with the anaphoricity network.
     *
     * <p>The two inputs are joined with {@code NeuralUtils.concatenate} and fed through the
     * anaphoricity layers.
     *
     * <p>NOTE(review): presumably the inputs are column vectors (e.g. a mention embedding and
     * a mention-feature vector); confirm against callers.
     *
     * @param first first input matrix, placed before {@code second} in the concatenation
     * @param second second input matrix
     * @return the sum of the elements of the final layer's output
     * @throws IllegalArgumentException if the anaphoricity layer list has odd length
     */
    public double getAnaphoricityScore(SimpleMatrix first, SimpleMatrix second) {
        return score(NeuralUtils.concatenate(first, second), this.anaphoricityModel);
    }

    /**
     * Scores a candidate (antecedent, anaphor) pair with the pairwise network.
     *
     * <p>Computes {@code ReLU(antecedent + anaphor + pairFeaturesMatrix * pairFeatures + bias)}
     * and feeds the result through the pairwise layers.
     *
     * <p>NOTE(review): the first two arguments are presumably the outputs of
     * {@link #getAntecedentEmbedding} and {@link #getAnaphorEmbedding}; verify against callers.
     *
     * @param antecedentEmbedding projected antecedent representation
     * @param anaphorEmbedding projected anaphor representation
     * @param pairFeatures pair feature matrix, multiplied by the pair-features projection
     * @return the sum of the elements of the final layer's output
     * @throws IllegalArgumentException if the pairwise layer list has odd length
     */
    public double getPairwiseScore(SimpleMatrix antecedentEmbedding, SimpleMatrix anaphorEmbedding,
            SimpleMatrix pairFeatures) {
        SimpleMatrix firstLayer = antecedentEmbedding
                .plus(anaphorEmbedding)
                .plus(this.pairFeaturesMatrix.mult(pairFeatures))
                .plus(this.pairwiseFirstLayerBias);
        return score(NeuralUtils.elementwiseApplyReLU(firstLayer), this.pairwiseModel);
    }

    /**
     * Runs {@code input} through a stack of affine layers and returns the element sum.
     *
     * <p>Each consecutive pair {@code (W, b)} in {@code layers} maps the activation to
     * {@code W * x + b}. ReLU is applied afterwards only when {@code W} has more than one row,
     * so a single-row (output) layer stays linear. An empty list returns the element sum of
     * {@code input}.
     *
     * @param input initial activation
     * @param layers alternating weight and bias matrices
     * @return the sum of all elements of the final activation
     * @throws IllegalArgumentException if {@code layers} does not contain whole weight/bias pairs
     */
    private static double score(SimpleMatrix input, List<SimpleMatrix> layers) {
        // Fail fast with a clear message instead of an IndexOutOfBoundsException midway through.
        if (layers.size() % 2 != 0) {
            throw new IllegalArgumentException(
                    "Layer list must alternate weight and bias matrices, but has odd size "
                            + layers.size());
        }
        SimpleMatrix activation = input;
        for (int i = 0; i < layers.size(); i += 2) {
            SimpleMatrix weights = layers.get(i);
            SimpleMatrix bias = layers.get(i + 1);
            activation = weights.mult(activation).plus(bias);
            if (weights.numRows() > 1) {
                activation = NeuralUtils.elementwiseApplyReLU(activation);
            }
        }
        return activation.elementSum();
    }

    /**
     * Projects an input with the anaphor matrix.
     *
     * @param input matrix to project (left-multiplied by the anaphor matrix)
     * @return {@code anaphorMatrix * input}
     */
    public SimpleMatrix getAnaphorEmbedding(SimpleMatrix input) {
        return this.anaphorMatrix.mult(input);
    }

    /**
     * Projects an input with the antecedent matrix.
     *
     * @param input matrix to project (left-multiplied by the antecedent matrix)
     * @return {@code antecedentMatrix * input}
     */
    public SimpleMatrix getAntecedentEmbedding(SimpleMatrix input) {
        return this.antecedentMatrix.mult(input);
    }

    /**
     * Returns the word embeddings supplied at construction (not copied).
     *
     * @return the stored {@link Embedding} instance
     */
    public Embedding getWordEmbeddings() {
        return this.wordEmbeddings;
    }
}
