/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.math.ArrayMath;
import java.io.Serializable;

/**
 * A log-prior (regularization penalty) over a weight vector. {@link #compute} returns the penalty
 * value for {@code x} and adds its gradient into {@code grad}.
 *
 * <p>Supported forms (see {@link LogPriorType}):
 * <ul>
 *   <li>{@code NULL}: no penalty (value 0, gradient untouched).</li>
 *   <li>{@code QUADRATIC}: {@code sum_i x_i^2 / (2 sigma^2)}.</li>
 *   <li>{@code HUBER}: per coordinate, quadratic inside {@code [-epsilon, epsilon]} and linear
 *       outside, scaled by {@code 1 / sigma^2}.</li>
 *   <li>{@code QUARTIC}: {@code sum_i x_i^4 / (2 sigma^4)}.</li>
 *   <li>{@code COSH}: {@code log cosh(||x||_1 / sigma^2)}.</li>
 *   <li>{@code ADAPT}: another prior evaluated at {@code x - means}; sigma/epsilon accessors
 *       delegate to that prior.</li>
 *   <li>{@code MULTIPLE_QUADRATIC}: quadratic with a separate sigma squared per coordinate.</li>
 * </ul>
 *
 * <p>Instances are mutable through the setters and are not synchronized. Note that
 * {@link #computeStochastic} temporarily changes sigma during the call.
 */
public class LogPrior
implements Serializable {
    private static final long serialVersionUID = 7826853908892790965L;
    /** ADAPT only: offset subtracted from the weights before delegating to {@link #otherPrior}. */
    private double[] means = null;
    /** ADAPT only: the prior that is evaluated at {@code x - means}. */
    private LogPrior otherPrior = null;
    /** The functional form of this prior. */
    public final LogPriorType type;
    /** MULTIPLE_QUADRATIC only: per-coordinate sigma squared. */
    private double[] sigmaSqM = null;
    /** MULTIPLE_QUADRATIC only: per-coordinate sigma to the fourth, kept in sync with sigmaSqM. */
    private double[] sigmaQuM = null;
    /** Sigma squared for the scalar-sigma forms. */
    private double sigmaSq;
    /** Sigma to the fourth (sigmaSq * sigmaSq), used by QUARTIC. */
    private double sigmaQu;
    /** Width of the quadratic region of the HUBER prior. */
    private double epsilon;

    /**
     * Parses a prior name, ignoring case, into a {@link LogPriorType}.
     * Accepted names are "null", "quadratic", "huber", "quartic" and "cosh".
     *
     * @param name the prior name
     * @return the matching type
     * @throws RuntimeException if the name is not one of the accepted names
     */
    public static LogPriorType getType(String name) {
        if (name.equalsIgnoreCase("null")) {
            return LogPriorType.NULL;
        }
        if (name.equalsIgnoreCase("quadratic")) {
            return LogPriorType.QUADRATIC;
        }
        if (name.equalsIgnoreCase("huber")) {
            return LogPriorType.HUBER;
        }
        if (name.equalsIgnoreCase("quartic")) {
            return LogPriorType.QUARTIC;
        }
        if (name.equalsIgnoreCase("cosh")) {
            return LogPriorType.COSH;
        }
        throw new RuntimeException("Unknown LogPriorType: " + name);
    }

    /**
     * Creates an ADAPT prior that evaluates {@code otherPrior} at {@code x - means}.
     *
     * @param means the offset subtracted from the weights (stored by reference)
     * @param otherPrior the prior applied to the shifted weights
     * @return the new ADAPT prior
     */
    public static LogPrior getAdaptationPrior(double[] means, LogPrior otherPrior) {
        LogPrior lp = new LogPrior(LogPriorType.ADAPT);
        lp.means = means;
        lp.otherPrior = otherPrior;
        return lp;
    }

    /** Returns the functional form of this prior. */
    public LogPriorType getType() {
        return this.type;
    }

    /** Creates a QUADRATIC prior with sigma 1.0 and epsilon 0.1. */
    public LogPrior() {
        this(LogPriorType.QUADRATIC);
    }

    /**
     * Creates a prior of the type whose ordinal is {@code intPrior}, with sigma 1.0 and
     * epsilon 0.1.
     */
    public LogPrior(int intPrior) {
        this(intPrior, 1.0, 0.1);
    }

    /** Creates a prior of the given type with sigma 1.0 and epsilon 0.1. */
    public LogPrior(LogPriorType type) {
        this(type, 1.0, 0.1);
    }

    /**
     * Maps an enum ordinal to its {@link LogPriorType}.
     *
     * @throws IllegalArgumentException if {@code intPrior} is not a valid ordinal
     */
    private static LogPriorType intToType(int intPrior) {
        LogPriorType[] values = LogPriorType.values();
        if (intPrior < 0 || intPrior >= values.length) {
            throw new IllegalArgumentException(intPrior + " is not a legal LogPrior.");
        }
        return values[intPrior];
    }

    /**
     * Creates a prior of the type whose ordinal is {@code intPrior}.
     *
     * @param intPrior ordinal of the desired {@link LogPriorType}
     * @param sigma the prior's sigma
     * @param epsilon the HUBER epsilon
     * @throws IllegalArgumentException if {@code intPrior} is not a valid ordinal
     */
    public LogPrior(int intPrior, double sigma, double epsilon) {
        this(LogPrior.intToType(intPrior), sigma, epsilon);
    }

    /**
     * Creates a prior of the given type.
     * For ADAPT, sigma and epsilon are ignored here because those accessors delegate to the
     * (not yet set) other prior.
     *
     * @param type the prior type
     * @param sigma the prior's sigma
     * @param epsilon the HUBER epsilon
     */
    public LogPrior(LogPriorType type, double sigma, double epsilon) {
        this.type = type;
        if (type != LogPriorType.ADAPT) {
            this.setSigma(sigma);
            this.setEpsilon(epsilon);
        }
    }

    /**
     * Creates a MULTIPLE_QUADRATIC prior whose per-coordinate sigma squared is {@code 1 / C[i]}.
     *
     * @param C per-coordinate values; each entry's reciprocal becomes that coordinate's sigma squared
     */
    public LogPrior(double[] C) {
        this.type = LogPriorType.MULTIPLE_QUADRATIC;
        double[] sigmaSq = new double[C.length];
        for (int i = 0; i < C.length; i++) {
            sigmaSq[i] = 1.0 / C[i];
        }
        // setSigmaSquaredM stores a defensive copy and keeps sigmaQuM in sync.
        this.setSigmaSquaredM(sigmaSq);
    }

    /** Returns sigma, the square root of sigma squared; ADAPT delegates to the other prior. */
    public double getSigma() {
        if (this.type == LogPriorType.ADAPT) {
            return this.otherPrior.getSigma();
        }
        return Math.sqrt(this.sigmaSq);
    }

    /** Returns sigma squared; ADAPT delegates to the other prior. */
    public double getSigmaSquared() {
        if (this.type == LogPriorType.ADAPT) {
            return this.otherPrior.getSigmaSquared();
        }
        return this.sigmaSq;
    }

    /**
     * Returns the per-coordinate sigma squared array. This is the internal array itself, not a
     * copy. ADAPT delegates to the other prior.
     *
     * @throws RuntimeException if this is neither MULTIPLE_QUADRATIC nor ADAPT over one
     */
    public double[] getSigmaSquaredM() {
        if (this.type == LogPriorType.ADAPT) {
            return this.otherPrior.getSigmaSquaredM();
        }
        if (this.type == LogPriorType.MULTIPLE_QUADRATIC) {
            return this.sigmaSqM;
        }
        throw new RuntimeException("LogPrior.getSigmaSquaredM is undefined for any prior but MULTIPLE_QUADRATIC" + this);
    }

    /** Returns the HUBER epsilon; ADAPT delegates to the other prior. */
    public double getEpsilon() {
        if (this.type == LogPriorType.ADAPT) {
            return this.otherPrior.getEpsilon();
        }
        return this.epsilon;
    }

    /** Sets sigma (also updating sigma^4); ADAPT delegates to the other prior. */
    public void setSigma(double sigma) {
        if (this.type == LogPriorType.ADAPT) {
            this.otherPrior.setSigma(sigma);
        } else {
            this.sigmaSq = sigma * sigma;
            this.sigmaQu = this.sigmaSq * this.sigmaSq;
        }
    }

    /** Sets sigma squared (also updating sigma^4); ADAPT delegates to the other prior. */
    public void setSigmaSquared(double sigmaSq) {
        if (this.type == LogPriorType.ADAPT) {
            this.otherPrior.setSigmaSquared(sigmaSq);
        } else {
            this.sigmaSq = sigmaSq;
            this.sigmaQu = sigmaSq * sigmaSq;
        }
    }

    /**
     * Sets the per-coordinate sigma squared. The array is copied, and the per-coordinate sigma^4
     * is recomputed. ADAPT delegates to the other prior.
     *
     * @param sigmaSq per-coordinate sigma squared
     * @throws RuntimeException if this is neither MULTIPLE_QUADRATIC nor ADAPT
     */
    public void setSigmaSquaredM(double[] sigmaSq) {
        if (this.type == LogPriorType.ADAPT) {
            this.otherPrior.setSigmaSquaredM(sigmaSq);
            // Previously fell through to the throw below after delegating.
            return;
        }
        if (this.type != LogPriorType.MULTIPLE_QUADRATIC) {
            throw new RuntimeException("LogPrior.setSigmaSquaredM is undefined for any prior but MULTIPLE_QUADRATIC" + this);
        }
        double[] sigmaSqCopy = sigmaSq.clone();
        double[] sigmaQuartic = new double[sigmaSqCopy.length];
        for (int i = 0; i < sigmaSqCopy.length; i++) {
            sigmaQuartic[i] = sigmaSqCopy[i] * sigmaSqCopy[i];
        }
        this.sigmaSqM = sigmaSqCopy;
        this.sigmaQuM = sigmaQuartic;
    }

    /** Sets the HUBER epsilon; ADAPT delegates to the other prior. */
    public void setEpsilon(double epsilon) {
        if (this.type == LogPriorType.ADAPT) {
            this.otherPrior.setEpsilon(epsilon);
        } else {
            this.epsilon = epsilon;
        }
    }

    /**
     * Like {@link #compute}, but with sigma squared (or every per-coordinate sigma squared)
     * divided by {@code fractionOfData} for the duration of the call. The original sigma is
     * restored afterwards, even if {@code compute} throws.
     *
     * @param x the weights
     * @param grad accumulator the prior's gradient is added into
     * @param fractionOfData divisor applied to sigma squared
     * @return the prior value
     */
    public double computeStochastic(double[] x, double[] grad, double fractionOfData) {
        if (this.type == LogPriorType.ADAPT) {
            double[] newX = ArrayMath.pairwiseSubtract(x, this.means);
            return this.otherPrior.computeStochastic(newX, grad, fractionOfData);
        }
        if (this.type == LogPriorType.MULTIPLE_QUADRATIC) {
            double[] sigmaSquaredOld = this.getSigmaSquaredM();
            double[] sigmaSquaredTemp = new double[sigmaSquaredOld.length];
            for (int i = 0; i < sigmaSquaredOld.length; i++) {
                sigmaSquaredTemp[i] = sigmaSquaredOld[i] / fractionOfData;
            }
            this.setSigmaSquaredM(sigmaSquaredTemp);
            try {
                return this.compute(x, grad);
            } finally {
                this.setSigmaSquaredM(sigmaSquaredOld);
            }
        }
        double sigmaSquaredOld = this.getSigmaSquared();
        this.setSigmaSquared(sigmaSquaredOld / fractionOfData);
        try {
            return this.compute(x, grad);
        } finally {
            this.setSigmaSquared(sigmaSquaredOld);
        }
    }

    /**
     * Computes the prior's value at {@code x} and adds its gradient into {@code grad}.
     *
     * @param x the weights
     * @param grad accumulator the gradient is added into (indexed like {@code x})
     * @return the prior value
     */
    public double compute(double[] x, double[] grad) {
        double val = 0.0;
        switch (this.type) {
            case NULL:
                return val;
            case QUADRATIC:
                for (int i = 0; i < x.length; i++) {
                    val += x[i] * x[i] / 2.0 / this.sigmaSq;
                    grad[i] += x[i] / this.sigmaSq;
                }
                return val;
            case HUBER:
                // Quadratic inside [-epsilon, epsilon], linear outside; value and slope are
                // continuous at +/-epsilon.
                for (int i = 0; i < x.length; i++) {
                    if (x[i] < -this.epsilon) {
                        val += (-x[i] - this.epsilon / 2.0) / this.sigmaSq;
                        grad[i] += -1.0 / this.sigmaSq;
                    } else if (x[i] < this.epsilon) {
                        val += x[i] * x[i] / 2.0 / this.epsilon / this.sigmaSq;
                        grad[i] += x[i] / this.epsilon / this.sigmaSq;
                    } else {
                        val += (x[i] - this.epsilon / 2.0) / this.sigmaSq;
                        grad[i] += 1.0 / this.sigmaSq;
                    }
                }
                return val;
            case QUARTIC:
                for (int i = 0; i < x.length; i++) {
                    double xSq = x[i] * x[i];
                    val += xSq * xSq / 2.0 / this.sigmaQu;
                    // d/dx [x^4 / (2 sigma^4)] = 2 x^3 / sigma^4 (was x / sigma^4, which did not
                    // match the value above).
                    grad[i] += 2.0 * xSq * x[i] / this.sigmaQu;
                }
                return val;
            case ADAPT: {
                double[] newX = ArrayMath.pairwiseSubtract(x, this.means);
                return val + this.otherPrior.compute(newX, grad);
            }
            case COSH: {
                double d;
                double norm = ArrayMath.norm_1(x) / this.sigmaSq;
                if (norm > 30.0) {
                    // log cosh(n) ~= n - log 2 for large n; avoids overflow in cosh.
                    val = norm - Math.log(2.0);
                    d = 1.0 / this.sigmaSq;
                } else {
                    val = Math.log(Math.cosh(norm));
                    // 2 / (exp(-2n) + 1) - 1 == tanh(n), the derivative of log cosh(n).
                    d = (2.0 * (1.0 / (Math.exp(-2.0 * norm) + 1.0)) - 1.0) / this.sigmaSq;
                }
                for (int i = 0; i < x.length; i++) {
                    grad[i] += Math.signum(x[i]) * d;
                }
                return val;
            }
            case MULTIPLE_QUADRATIC:
                for (int i = 0; i < x.length; i++) {
                    val += x[i] * x[i] / 2.0 / this.sigmaSqM[i];
                    grad[i] += x[i] / this.sigmaSqM[i];
                }
                return val;
        }
        throw new RuntimeException("LogPrior.valueAt is undefined for prior of type " + this);
    }

    /**
     * The functional forms a {@link LogPrior} can take.
     * Do not reorder: {@code LogPrior(int)} selects a type by its ordinal.
     */
    public enum LogPriorType {
        NULL,
        QUADRATIC,
        HUBER,
        QUARTIC,
        COSH,
        ADAPT,
        MULTIPLE_QUADRATIC;

    }
}

