/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.coref.fastneural;

import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.coref.neural.CategoricalFeatureExtractor;
import edu.stanford.nlp.coref.neural.EmbeddingExtractor;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.stats.Counter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

public class FastNeuralCorefModel
implements Serializable {
    private static final long serialVersionUID = 8663264823377059140L;
    private final EmbeddingExtractor embeddingExtractor;
    private final Map<String, Integer> pairFeatureIds;
    private final Map<String, Integer> mentionFeatureIds;
    private SimpleMatrix anaphorKernel;
    private SimpleMatrix anaphorBias;
    private SimpleMatrix antecedentKernel;
    private SimpleMatrix antecedentBias;
    private SimpleMatrix pairFeaturesKernel;
    private SimpleMatrix pairFeaturesBias;
    private SimpleMatrix NARepresentation;
    private List<SimpleMatrix> networkLayers;

    public FastNeuralCorefModel(EmbeddingExtractor embeddingExtractor, Map<String, Integer> pairFeatureIds, Map<String, Integer> mentionFeatureIds, List<SimpleMatrix> weights) {
        this.embeddingExtractor = embeddingExtractor;
        this.pairFeatureIds = pairFeatureIds;
        this.mentionFeatureIds = mentionFeatureIds;
        this.anaphorKernel = weights.get(0);
        this.anaphorBias = weights.get(1);
        this.antecedentKernel = weights.get(2);
        this.antecedentBias = weights.get(3);
        this.pairFeaturesKernel = weights.get(4);
        this.pairFeaturesBias = weights.get(5);
        this.NARepresentation = weights.get(6);
        this.networkLayers = new ArrayList<SimpleMatrix>(weights.subList(7, weights.size()));
    }

    public EmbeddingExtractor getEmbeddingExtractor() {
        return this.embeddingExtractor;
    }

    public Map<String, Integer> getPairFeatureIds() {
        return Collections.unmodifiableMap(this.pairFeatureIds);
    }

    public Map<String, Integer> getMentionFeatureIds() {
        return Collections.unmodifiableMap(this.mentionFeatureIds);
    }

    public List<SimpleMatrix> getAllWeights() {
        ArrayList<SimpleMatrix> weights = new ArrayList<SimpleMatrix>();
        weights.add(this.anaphorKernel);
        weights.add(this.anaphorBias);
        weights.add(this.antecedentKernel);
        weights.add(this.anaphorBias);
        weights.add(this.pairFeaturesKernel);
        weights.add(this.pairFeaturesBias);
        weights.add(this.NARepresentation);
        weights.addAll(this.networkLayers);
        return Collections.unmodifiableList(weights);
    }

    public double score(Mention antecedent, Mention anaphor, Counter<String> antecedentFeatures, Counter<String> anaphorFeatures, Counter<String> pairFeatures, Map<Integer, SimpleMatrix> antecedentCache, Map<Integer, SimpleMatrix> anaphorCache) {
        SimpleMatrix anaphorVector;
        SimpleMatrix antecedentVector = this.NARepresentation;
        if (antecedent != null && (antecedentVector = antecedentCache.get(antecedent.mentionID)) == null) {
            antecedentVector = (SimpleMatrix)((SimpleMatrix)this.antecedentKernel.mult((SimpleBase)NeuralUtils.concatenate(this.embeddingExtractor.getMentionEmbeddingsForFast(antecedent), this.makeFeatureVector(antecedentFeatures, this.mentionFeatureIds)))).plus((SimpleBase)this.antecedentBias);
            antecedentCache.put(antecedent.mentionID, antecedentVector);
        }
        if ((anaphorVector = anaphorCache.get(anaphor.mentionID)) == null) {
            anaphorVector = (SimpleMatrix)((SimpleMatrix)this.anaphorKernel.mult((SimpleBase)NeuralUtils.concatenate(this.embeddingExtractor.getMentionEmbeddingsForFast(anaphor), this.makeFeatureVector(anaphorFeatures, this.mentionFeatureIds)))).plus((SimpleBase)this.anaphorBias);
            anaphorCache.put(anaphor.mentionID, anaphorVector);
        }
        SimpleMatrix pairFeaturesVector = (SimpleMatrix)((SimpleMatrix)this.pairFeaturesKernel.mult((SimpleBase)(pairFeatures == null ? new SimpleMatrix(this.pairFeatureIds.size() + 23, 1) : this.addDistanceFeatures(this.makeFeatureVector(pairFeatures, this.pairFeatureIds), antecedent, anaphor)))).plus((SimpleBase)this.pairFeaturesBias);
        SimpleMatrix pairVector = (SimpleMatrix)((SimpleMatrix)antecedentVector.concatRows(new SimpleBase[]{anaphorVector})).concatRows(new SimpleBase[]{pairFeaturesVector});
        pairVector = NeuralUtils.elementwiseApplyReLU(pairVector);
        for (int i = 0; i < this.networkLayers.size(); i += 2) {
            pairVector = (SimpleMatrix)((SimpleMatrix)this.networkLayers.get(i).mult((SimpleBase)pairVector)).plus((SimpleBase)this.networkLayers.get(i + 1));
            if (this.networkLayers.get(i).numRows() <= 1) continue;
            pairVector = NeuralUtils.elementwiseApplyReLU(pairVector);
        }
        return pairVector.elementSum();
    }

    private SimpleMatrix makeFeatureVector(Counter<String> features, Map<String, Integer> featureIds) {
        SimpleMatrix featureVector = new SimpleMatrix(featureIds.size(), 1);
        for (Map.Entry<String, Double> feature : features.entrySet()) {
            if (!featureIds.containsKey(feature.getKey())) continue;
            featureVector.set(featureIds.get(feature.getKey()).intValue(), feature.getValue().doubleValue());
        }
        return featureVector;
    }

    private SimpleMatrix addDistanceFeatures(SimpleMatrix featureVector, Mention antecedent, Mention anaphor) {
        return NeuralUtils.concatenate(featureVector, CategoricalFeatureExtractor.encodeDistance(anaphor.sentNum - antecedent.sentNum), CategoricalFeatureExtractor.encodeDistance(anaphor.mentionNum - antecedent.mentionNum - 1), new SimpleMatrix((double[][])new double[][]{{antecedent.sentNum == anaphor.sentNum && antecedent.endIndex > anaphor.startIndex ? 1.0 : 0.0}}));
    }

    public static FastNeuralCorefModel loadFromTextFiles(String path) {
        List<SimpleMatrix> weights = NeuralUtils.loadTextMatrices(path + "weights.txt");
        weights.set(weights.size() - 2, (SimpleMatrix)weights.get(weights.size() - 2).transpose());
        Embedding embeddings = new Embedding(path + "embeddings.txt");
        EmbeddingExtractor extractor = new EmbeddingExtractor(false, null, embeddings, "<missing>");
        Map<String, Integer> pairFeatureIds = FastNeuralCorefModel.loadMapFromTextFile(path + "pair_features.txt");
        Map<String, Integer> mentionFeatureIds = FastNeuralCorefModel.loadMapFromTextFile(path + "mention_features.txt");
        return new FastNeuralCorefModel(extractor, pairFeatureIds, mentionFeatureIds, weights);
    }

    public static Map<String, Integer> loadMapFromTextFile(String filename) {
        HashMap<String, Integer> dict = new HashMap<String, Integer>();
        for (String line : IOUtils.readLines(filename, "utf-8")) {
            String[] lineSplit = line.split("\\s+");
            assert (lineSplit.length == 2);
            dict.put(lineSplit[0], Integer.parseInt(lineSplit[1]));
        }
        return dict;
    }
}

