/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.ie.crf.CRFCliqueTree;
import edu.stanford.nlp.ie.crf.CRFLabel;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffUpdateFunction;
import edu.stanford.nlp.util.Index;
import java.util.Arrays;

/**
 * Objective function used to fit CRF weights: the value is the negated sum of the
 * per-position conditional log probabilities reported by a {@link CRFCliqueTree},
 * optionally plus a prior (regularization) term, and the derivative is
 * "expected feature counts minus empirical feature counts" plus the prior's gradient.
 *
 * <p>Weights are exchanged with the optimizer as one flat {@code double[]}; internally
 * they are viewed as a ragged 2D array indexed {@code [feature][cliqueLabel]}, where the
 * number of clique labels of feature {@code f} is {@code labelIndices[map[f]].size()}.
 *
 * <p>Results are written to the {@code value} and {@code derivative} fields inherited from
 * the superclass (declared outside this file). Instances are not synchronized.
 */
public class CRFLogConditionalObjectiveFunction
extends AbstractStochasticCachingDiffUpdateFunction {
    /** Prior type code: no prior term is added. */
    public static final int NO_PRIOR = 0;
    /** Prior type code: adds {@code w^2 / (2 sigma^2)} per weight. */
    public static final int QUADRATIC_PRIOR = 1;
    /** Prior type code: quadratic for {@code |w| < epsilon}, linear beyond that. */
    public static final int HUBER_PRIOR = 2;
    /** Prior type code: adds {@code w^4 / (2 sigma^4)} per weight. */
    public static final int QUARTIC_PRIOR = 3;
    /** One of the {@code *_PRIOR} codes; any other value behaves like {@link #NO_PRIOR}. */
    private final int prior;
    /** Scale parameter of the prior term. */
    private final double sigma;
    /** Half-width of the quadratic region of the Huber prior. */
    private final double epsilon = 0.1;
    /** {@code labelIndices[j]} indexes the label tuples of cliques spanning {@code j + 1} positions. */
    private final Index<CRFLabel>[] labelIndices;
    private final Index<String> classIndex;
    /** Stored as given; not read anywhere in this class. */
    private final Index featureIndex;
    /** Empirical feature counts, indexed {@code [feature][cliqueLabel]}; filled once by the constructor. */
    private final double[][] Ehat;
    /** Maximum clique size; the label history passed to the clique tree has {@code window - 1} entries. */
    private final int window;
    private final int numClasses;
    /** {@code map[f]} is the position in {@code labelIndices} used by feature {@code f}. */
    private final int[] map;
    /** Feature ids, indexed {@code [document][position][cliqueSize - 1][n]}. */
    private final int[][][][] data;
    /** Label ids, indexed {@code [document][position]}; a row may be longer than its document. */
    private final int[][] labels;
    /** Length of the flat weight vector. */
    private final int domainDimension;
    /** Lazily built by {@link #getWeightIndices()}. */
    private int[][] weightIndices;
    /** Class label used to pad the label history before the first position of a document. */
    private final String backgroundSymbol;
    /** When {@code true}, per-position probabilities and derivatives are printed to {@code System.err}. */
    public static boolean VERBOSE = false;

    /**
     * Translates a prior name into its integer code.
     *
     * @param priorTypeStr one of "QUADRATIC", "HUBER", "QUARTIC", "NONE" (case-insensitive),
     *                     or {@code null} to get the default
     * @return the matching {@code *_PRIOR} constant; {@link #QUADRATIC_PRIOR} for {@code null}
     * @throws IllegalArgumentException if the name is not recognized
     */
    public static int getPriorType(String priorTypeStr) {
        if (priorTypeStr == null || "QUADRATIC".equalsIgnoreCase(priorTypeStr)) {
            return QUADRATIC_PRIOR;
        }
        if ("HUBER".equalsIgnoreCase(priorTypeStr)) {
            return HUBER_PRIOR;
        }
        if ("QUARTIC".equalsIgnoreCase(priorTypeStr)) {
            return QUARTIC_PRIOR;
        }
        if ("NONE".equalsIgnoreCase(priorTypeStr)) {
            return NO_PRIOR;
        }
        throw new IllegalArgumentException("Unknown prior type: " + priorTypeStr);
    }

    /** Uses {@link #QUADRATIC_PRIOR} with {@code sigma = 1.0}. */
    CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, Index featureIndex, int window, Index<String> classIndex, Index[] labelIndices, int[] map, String backgroundSymbol) {
        this(data, labels, featureIndex, window, classIndex, labelIndices, map, QUADRATIC_PRIOR, backgroundSymbol);
    }

    /** Uses {@link #QUADRATIC_PRIOR} with the given {@code sigma}. */
    CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, Index featureIndex, int window, Index<String> classIndex, Index[] labelIndices, int[] map, String backgroundSymbol, double sigma) {
        this(data, labels, featureIndex, window, classIndex, labelIndices, map, QUADRATIC_PRIOR, backgroundSymbol, sigma);
    }

    /** Uses the given prior with {@code sigma = 1.0}. */
    CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, Index featureIndex, int window, Index<String> classIndex, Index[] labelIndices, int[] map, int prior, String backgroundSymbol) {
        this(data, labels, featureIndex, window, classIndex, labelIndices, map, prior, backgroundSymbol, 1.0);
    }

    /**
     * Stores the training data and precomputes the empirical feature counts and the
     * length of the flat weight vector.
     *
     * @param data             feature ids per document, position and clique size
     * @param labels           label ids per document
     * @param featureIndex     stored but not otherwise used by this class
     * @param window           maximum clique size
     * @param classIndex       index of class labels; its size is the number of classes
     * @param labelIndices     one index of clique labels per clique size
     * @param map              for each feature, the entry of {@code labelIndices} it uses
     * @param prior            one of the {@code *_PRIOR} codes
     * @param backgroundSymbol class label used to pad the label history
     * @param sigma            scale of the prior
     */
    CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, Index featureIndex, int window, Index<String> classIndex, Index[] labelIndices, int[] map, int prior, String backgroundSymbol, double sigma) {
        this.featureIndex = featureIndex;
        this.window = window;
        this.classIndex = classIndex;
        this.numClasses = classIndex.size();
        this.labelIndices = labelIndices;
        this.map = map;
        this.data = data;
        this.labels = labels;
        this.prior = prior;
        this.backgroundSymbol = backgroundSymbol;
        this.sigma = sigma;
        // empty2D() and empiricalCounts() read the fields assigned above, so they must run last.
        this.Ehat = empty2D();
        empiricalCounts(data, labels);
        int total = 0;
        for (int dim : map) {
            total += labelIndices[dim].size();
        }
        this.domainDimension = total;
    }

    /** @return the length of the flat weight vector */
    @Override
    public int domainDimension() {
        return domainDimension;
    }

    /**
     * Splits a flat weight vector into one row per feature.
     *
     * @param weights      flat weights, rows laid out back to back
     * @param labelIndices clique-label indices; row {@code i} gets {@code labelIndices[map[i]].size()} entries
     * @param map          feature-to-label-index map
     * @return a freshly allocated ragged array holding copies of the weights
     */
    public static double[][] to2D(double[] weights, Index[] labelIndices, int[] map) {
        double[][] rows = new double[map.length][];
        int offset = 0;
        for (int i = 0; i < map.length; i++) {
            int rowLength = labelIndices[map[i]].size();
            rows[i] = new double[rowLength];
            System.arraycopy(weights, offset, rows[i], 0, rowLength);
            offset += rowLength;
        }
        return rows;
    }

    /** Instance form of {@link #to2D(double[], Index[], int[])} using this function's indices. */
    public double[][] to2D(double[] weights) {
        return to2D(weights, labelIndices, map);
    }

    /**
     * Concatenates the rows of a ragged array into one flat vector.
     *
     * @param weights         rows to concatenate, in order
     * @param domainDimension length of the result; must be at least the total row length
     * @return a freshly allocated flat array
     */
    public static double[] to1D(double[][] weights, int domainDimension) {
        double[] flat = new double[domainDimension];
        int offset = 0;
        for (double[] row : weights) {
            System.arraycopy(row, 0, flat, offset, row.length);
            offset += row.length;
        }
        return flat;
    }

    /** Instance form of {@link #to1D(double[][], int)} using {@link #domainDimension()}. */
    public double[] to1D(double[][] weights) {
        return to1D(weights, domainDimension());
    }

    /**
     * Returns, for every {@code [feature][cliqueLabel]} pair, its position in the flat weight
     * vector. The table is built on first use and the same array is returned afterwards.
     */
    public int[][] getWeightIndices() {
        if (weightIndices == null) {
            weightIndices = new int[map.length][];
            int next = 0;
            for (int i = 0; i < map.length; i++) {
                int rowLength = labelIndices[map[i]].size();
                weightIndices[i] = new int[rowLength];
                for (int j = 0; j < rowLength; j++) {
                    weightIndices[i][j] = next++;
                }
            }
        }
        return weightIndices;
    }

    /** Allocates a zero-filled ragged array with the same shape as the 2D weights. */
    private double[][] empty2D() {
        double[][] d = new double[map.length][];
        for (int i = 0; i < map.length; i++) {
            d[i] = new double[labelIndices[map[i]].size()];
        }
        return d;
    }

    /**
     * Fills {@code history} with the background class id and, when the document carries
     * more labels than positions, overwrites it with the leading labels.
     */
    private void seedHistory(int[] history, int[] docLabels, int docLength) {
        Arrays.fill(history, classIndex.indexOf(backgroundSymbol));
        if (docLabels.length > docLength) {
            System.arraycopy(docLabels, 0, history, 0, history.length);
        }
    }

    /**
     * Returns the labels aligned with the document positions: the last {@code docLength}
     * entries when there are surplus labels, otherwise the array itself.
     */
    private static int[] trimToDocument(int[] docLabels, int docLength) {
        if (docLabels.length <= docLength) {
            return docLabels;
        }
        return Arrays.copyOfRange(docLabels, docLabels.length - docLength, docLabels.length);
    }

    /** Drops the oldest label and appends {@code label}; does nothing for an empty history. */
    private static void shiftHistory(int[] history, int label) {
        if (history.length == 0) {
            // window == 1: there is no history to maintain, and arraycopy would reject length -1.
            return;
        }
        System.arraycopy(history, 1, history, 0, history.length - 1);
        history[history.length - 1] = label;
    }

    /**
     * Counts, for every feature firing in the training data, how often it co-occurs with
     * each observed clique label, and adds the counts into {@code Ehat}.
     */
    private void empiricalCounts(int[][][][] data, int[][] labels) {
        for (int m = 0; m < data.length; m++) {
            int[][][] docData = data[m];
            int[] windowLabels = new int[window];
            seedHistory(windowLabels, labels[m], docData.length);
            int[] docLabels = trimToDocument(labels[m], docData.length);
            for (int i = 0; i < docData.length; i++) {
                shiftHistory(windowLabels, docLabels[i]);
                for (int j = 0; j < docData[i].length; j++) {
                    // The clique of size j + 1 ends at the current position.
                    int[] cliqueLabel = new int[j + 1];
                    System.arraycopy(windowLabels, window - 1 - j, cliqueLabel, 0, j + 1);
                    int observed = labelIndices[j].indexOf(new CRFLabel(cliqueLabel));
                    for (int feature : docData[i][j]) {
                        Ehat[feature][observed] += 1.0;
                    }
                }
            }
        }
    }

    /**
     * Adds {@code condLogProbGivenPrevious} of every gold label of one document to
     * {@code total}, one position at a time so the summation order across documents
     * is preserved, and advances {@code given} as it goes.
     *
     * @return {@code total} plus this document's contributions
     */
    private double accumulateLogLikelihood(double total, CRFCliqueTree<String> cliqueTree, int[] docLabels, int[] given, int docLength) {
        for (int i = 0; i < docLength; i++) {
            int label = docLabels[i];
            double p = cliqueTree.condLogProbGivenPrevious(i, label, given);
            if (VERBOSE) {
                System.err.println("P(" + label + "|" + ArrayMath.toString(given) + ")=" + p);
            }
            total += p;
            shiftHistory(given, label);
        }
        return total;
    }

    /**
     * Adds {@code cliqueTree.prob(position, cliqueLabel)} into {@code E[feature][cliqueLabel]}
     * for every feature firing at every position and clique size of one document.
     */
    private void accumulateExpectedCounts(CRFCliqueTree<String> cliqueTree, int[][][] docData, double[][] E) {
        for (int i = 0; i < docData.length; i++) {
            for (int j = 0; j < docData[i].length; j++) {
                Index<CRFLabel> labelIndex = labelIndices[j];
                for (int k = 0; k < labelIndex.size(); k++) {
                    double p = cliqueTree.prob(i, labelIndex.get(k).getLabel());
                    for (int feature : docData[i][j]) {
                        E[feature][k] += p;
                    }
                }
            }
        }
    }

    /**
     * Writes {@code E - empiricalScale * Ehat} into the flat {@code derivative} array.
     */
    private void storeGradient(double[][] E, double empiricalScale) {
        int index = 0;
        for (int i = 0; i < E.length; i++) {
            for (int j = 0; j < E[i].length; j++) {
                derivative[index++] = E[i][j] - empiricalScale * Ehat[i][j];
                if (VERBOSE) {
                    System.err.println("deriv(" + i + "," + j + ") = " + E[i][j] + " - " + Ehat[i][j] + " = " + derivative[index - 1]);
                }
            }
        }
    }

    /**
     * Adds the prior term, multiplied by {@code scale}, to {@code value} and its gradient
     * to {@code derivative}. Unknown prior codes add nothing.
     *
     * @param x     current flat weights
     * @param scale 1.0 for a full pass, or the fraction of the data covered by a batch
     */
    private void applyPrior(double[] x, double scale) {
        if (prior == QUADRATIC_PRIOR) {
            double sigmaSq = sigma * sigma;
            for (int i = 0; i < x.length; i++) {
                double w = x[i];
                value += scale * w * w / 2.0 / sigmaSq;
                derivative[i] += scale * w / sigmaSq;
            }
        } else if (prior == HUBER_PRIOR) {
            double sigmaSq = sigma * sigma;
            for (int i = 0; i < x.length; i++) {
                double w = x[i];
                double wabs = Math.abs(w);
                if (wabs < epsilon) {
                    value += scale * w * w / 2.0 / epsilon / sigmaSq;
                    derivative[i] += scale * w / epsilon / sigmaSq;
                } else {
                    // Linear branch, offset by epsilon / 2 so it joins the quadratic branch at |w| = epsilon.
                    value += scale * (wabs - epsilon / 2.0) / sigmaSq;
                    derivative[i] += scale * (w < 0.0 ? -1.0 : 1.0) / sigmaSq;
                }
            }
        } else if (prior == QUARTIC_PRIOR) {
            double sigmaQu = sigma * sigma * sigma * sigma;
            for (int i = 0; i < x.length; i++) {
                double w = x[i];
                value += scale * w * w * w * w / 2.0 / sigmaQu;
                // d/dw [w^4 / (2 sigma^4)] = 2 w^3 / sigma^4; the gradient must match the value above.
                derivative[i] += 2.0 * scale * w * w * w / sigmaQu;
            }
        }
    }

    /**
     * Evaluates the objective and its gradient over all documents and stores them in
     * {@code value} and {@code derivative}. Documents without labels are skipped.
     *
     * @param x flat weight vector
     * @throws RuntimeException if the accumulated log probability is NaN
     */
    @Override
    public void calculate(double[] x) {
        double prob = 0.0;
        double[][] weights = to2D(x);
        double[][] E = empty2D();
        for (int m = 0; m < data.length; m++) {
            int[][][] docData = data[m];
            int[] docLabels = labels[m];
            if (docLabels.length == 0) {
                continue;
            }
            CRFCliqueTree<String> cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(weights, docData, labelIndices, numClasses, classIndex, backgroundSymbol);
            int[] given = new int[window - 1];
            seedHistory(given, docLabels, docData.length);
            docLabels = trimToDocument(docLabels, docData.length);
            prob = accumulateLogLikelihood(prob, cliqueTree, docLabels, given, docData.length);
            accumulateExpectedCounts(cliqueTree, docData, E);
        }
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate() - this may well indicate numeric underflow due to overly long documents.");
        }
        value = -prob;
        if (VERBOSE) {
            System.err.println("value is " + value);
        }
        storeGradient(E, 1.0);
        applyPrior(x, 1.0);
    }

    /** Delegates to {@link #calculateStochasticGradientOnly(double[], int[])}; {@code v} is ignored. */
    @Override
    public void calculateStochastic(double[] x, double[] v, int[] batch) {
        calculateStochasticGradientOnly(x, batch);
    }

    /** @return the number of training documents */
    @Override
    public int dataDimension() {
        return data.length;
    }

    /**
     * Evaluates the objective and gradient over the documents listed in {@code batch}.
     * The empirical counts and the prior are scaled by {@code batch.length / dataDimension()}.
     *
     * @param x     flat weight vector
     * @param batch indices of the documents to use
     * @throws RuntimeException if the accumulated log probability is NaN
     */
    public void calculateStochasticGradientOnly(double[] x, int[] batch) {
        double prob = 0.0;
        double[][] weights = to2D(x);
        double batchScale = (double) batch.length / (double) dataDimension();
        double[][] E = empty2D();
        for (int ind : batch) {
            int[][][] docData = data[ind];
            int[] docLabels = labels[ind];
            CRFCliqueTree<String> cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(weights, docData, labelIndices, numClasses, classIndex, backgroundSymbol);
            int[] given = new int[window - 1];
            seedHistory(given, docLabels, docData.length);
            docLabels = trimToDocument(docLabels, docData.length);
            prob = accumulateLogLikelihood(prob, cliqueTree, docLabels, given, docData.length);
            accumulateExpectedCounts(cliqueTree, docData, E);
        }
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculateStochasticGradientOnly()");
        }
        value = -prob;
        storeGradient(E, batchScale);
        applyPrior(x, batchScale);
    }

    /**
     * Processes the documents in {@code batch} and updates {@code x} in place: for each
     * document, {@code gscale} is added to the weights of the observed (feature, clique label)
     * pairs and {@code p * gscale} is subtracted for every (feature, clique label) pair, where
     * {@code p} comes from the document's clique tree. No prior is applied here.
     *
     * @param x      flat weight vector, modified in place
     * @param xscale passed through to {@code CRFCliqueTree.getCalibratedCliqueTree}
     * @param batch  indices of the documents to use
     * @param gscale step applied to each unit of count
     * @return the negated sum of conditional log probabilities over the batch
     * @throws RuntimeException if the accumulated log probability is NaN
     */
    @Override
    public double calculateStochasticUpdate(double[] x, double xscale, int[] batch, double gscale) {
        double prob = 0.0;
        int[][] wis = getWeightIndices();
        int[] given = new int[window - 1];
        // Scratch buffers, one per clique size, reused for every position.
        int[][] docCliqueLabels = new int[window][];
        for (int j = 0; j < window; j++) {
            docCliqueLabels[j] = new int[j + 1];
        }
        for (int ind : batch) {
            int[][][] docData = data[ind];
            int[] docLabels = labels[ind];
            CRFCliqueTree<String> cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(x, xscale, wis, docData, labelIndices, numClasses, classIndex, backgroundSymbol);
            seedHistory(given, docLabels, docData.length);
            docLabels = trimToDocument(docLabels, docData.length);
            for (int i = 0; i < docData.length; i++) {
                int label = docLabels[i];
                double p = cliqueTree.condLogProbGivenPrevious(i, label, given);
                if (VERBOSE) {
                    System.err.println("P(" + label + '|' + ArrayMath.toString(given) + ")=" + p);
                }
                prob += p;
                for (int j = 0; j < docData[i].length; j++) {
                    // Clique label = last j history entries followed by the current label.
                    if (j > 0) {
                        System.arraycopy(given, window - j - 1, docCliqueLabels[j], 0, j);
                    }
                    docCliqueLabels[j][j] = label;
                    int correctLabelIndex = labelIndices[j].indexOf(new CRFLabel(docCliqueLabels[j]));
                    for (int feature : docData[i][j]) {
                        x[wis[feature][correctLabelIndex]] += gscale;
                    }
                }
                shiftHistory(given, label);
            }
            for (int i = 0; i < docData.length; i++) {
                for (int j = 0; j < docData[i].length; j++) {
                    Index<CRFLabel> labelIndex = labelIndices[j];
                    for (int k = 0; k < labelIndex.size(); k++) {
                        double p = cliqueTree.prob(i, labelIndex.get(k).getLabel());
                        for (int feature : docData[i][j]) {
                            x[wis[feature][k]] -= p * gscale;
                        }
                    }
                }
            }
        }
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculateStochasticUpdate()");
        }
        value = -prob;
        return value;
    }

    /**
     * Computes the negated sum of conditional log probabilities over the documents in
     * {@code batch} without touching the gradient or the weights. No prior is applied.
     *
     * @param x      flat weight vector
     * @param xscale passed through to {@code CRFCliqueTree.getCalibratedCliqueTree}
     * @param batch  indices of the documents to use
     * @return the computed value, also stored in {@code value}
     * @throws RuntimeException if the accumulated log probability is NaN
     */
    @Override
    public double valueAt(double[] x, double xscale, int[] batch) {
        double prob = 0.0;
        int[][] wis = getWeightIndices();
        int[] given = new int[window - 1];
        for (int ind : batch) {
            int[][][] docData = data[ind];
            int[] docLabels = labels[ind];
            CRFCliqueTree<String> cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(x, xscale, wis, docData, labelIndices, numClasses, classIndex, backgroundSymbol);
            seedHistory(given, docLabels, docData.length);
            docLabels = trimToDocument(docLabels, docData.length);
            prob = accumulateLogLikelihood(prob, cliqueTree, docLabels, given, docData.length);
        }
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.valueAt()");
        }
        value = -prob;
        return value;
    }
}

