package edu.stanford.nlp.optimization;

import de.metanome.algorithm_integration.ColumnCondition;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.Function;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:edu/stanford/nlp/optimization/StochasticMinimizer.class */
public abstract class StochasticMinimizer<T extends Function> implements Minimizer<T>, HasEvaluators {
    /** Current point; {@code minimize} aliases it to the caller's initial array and returns it. */
    protected double[] x;
    /** Candidate next point; allocated in {@code minimize}, checked for finiteness after {@code takeStep} and then copied into {@link #x}. */
    protected double[] newX;
    /** Smoothed gradient: the element-wise mean of {@link #gradList}, recomputed every iteration. */
    protected double[] grad;
    /** Buffer receiving the latest gradient returned by {@code derivativeAt}. */
    protected double[] newGrad;
    /** Passed as the second argument to {@code derivativeAt}; never assigned in this class. */
    protected double[] v;
    /** Batches per pass: {@code dataDimension() / bSize} (integer division). */
    protected int numBatches;
    /** Index of the current iteration inside {@code minimize}. */
    protected int k;
    /** Evaluators run by {@code doEvaluation}; may be {@code null}. */
    private Evaluator[] evaluators;
    private static Redwood.RedwoodChannels log = Redwood.channels(StochasticMinimizer.class);
    /** Number format used for the values written to the log. */
    protected static final NumberFormat nf = new DecimalFormat("0.000E0");
    /** When {@code true}, {@code minimize} writes {@code <name>.output} and {@code <name>.info} files. */
    public boolean outputIterationsToFile = false;
    /** Output spacing; NOTE(review): {@code minimize} overwrites this field on every call. */
    public int outputFrequency = 1000;
    /** Gain parameter; set by the tuning methods and by {@code setGain}. Not read in this class apart from logging. */
    public double gain = 0.1d;
    /** Batch size handed to {@code derivativeAt} and used to derive {@link #numBatches}. */
    protected int bSize = 15;
    /** When {@code true}, {@code say} and {@code sayln} log nothing. */
    protected boolean quiet = false;
    /** Most recent gradients, oldest first; re-created at the start of {@code minimize}. */
    protected List<double[]> gradList = null;
    /** Size of {@link #gradList} at which the oldest gradient buffer starts being recycled. */
    protected int memory = 10;
    /** Number of passes used, together with the {@code minimize} argument, to bound the iteration count. */
    protected int numPasses = -1;
    /** Random generator with fixed seed 1; not used within this class. */
    protected Random gen = new Random(1);
    /** Writer for {@code <name>.output}; {@code null} until {@code initFiles} opens it. */
    protected PrintWriter file = null;
    /** Writer for {@code <name>.info}; {@code null} until {@code initFiles} opens it. */
    protected PrintWriter infoFile = null;
    /** Time budget compared against {@code Timing.report()}; presumably milliseconds (the tuning log prints "ms"). */
    protected long maxTime = Long.MAX_VALUE;
    /** Evaluators run every this many iterations; values {@code <= 0} disable evaluation. */
    private int evaluateIters = 0;

    /* loaded from: input_file:edu/stanford/nlp/optimization/StochasticMinimizer$PropertySetter.class */
    /**
     * Callback through which {@code tuneDouble} assigns each candidate value
     * to the property being tuned.
     *
     * @param <T1> type of the value handed to {@link #set}
     */
    public interface PropertySetter<T1> {
        /**
         * Applies the given value to the underlying property.
         *
         * @param t1 the value to apply
         */
        void set(T1 t1);
    }

    /* loaded from: input_file:edu/stanford/nlp/optimization/StochasticMinimizer$setGain.class */
    /**
     * {@link PropertySetter} that stores the supplied value in the enclosing
     * minimizer's {@code gain} field.
     */
    private class setGain implements PropertySetter<Double> {
        /** Minimizer passed to the constructor; kept for reference, not read by {@link #set(Double)}. */
        final StochasticMinimizer<T> parent;

        /**
         * @param stochasticMinimizer the minimizer remembered in {@link #parent}
         */
        public setGain(StochasticMinimizer<T> stochasticMinimizer) {
            this.parent = stochasticMinimizer;
        }

        /**
         * Writes the value into the enclosing instance's {@code gain}.
         *
         * @param d the new gain; must not be {@code null}
         */
        @Override // edu.stanford.nlp.optimization.StochasticMinimizer.PropertySetter
        public void set(Double d) {
            StochasticMinimizer.this.gain = d.doubleValue();
        }
    }

    /**
     * Switches the minimizer to quiet mode, silencing {@code say} and {@code sayln}.
     */
    public void shutUp() {
        quiet = true;
    }

    /**
     * Supplies the base name for the files opened by {@code initFiles}, which
     * appends {@code ".output"} and {@code ".info"} to it.
     *
     * @return the base name of this minimizer's output files
     */
    protected abstract String getName();

    /**
     * Performs one update step. {@code minimize} calls this once per iteration,
     * right after refreshing {@code grad}, and afterwards checks {@code newX}
     * for finiteness and copies it into {@code x}.
     * NOTE(review): implementations presumably write the next point into
     * {@code newX}; confirm against the subclasses.
     *
     * @param abstractStochasticCachingDiffFunction the function being minimized
     */
    protected abstract void takeStep(AbstractStochasticCachingDiffFunction abstractStochasticCachingDiffFunction);

    /**
     * Registers the evaluators to run during minimization and how often to run them.
     *
     * @param i             run the evaluators every {@code i} iterations; values {@code <= 0} disable them
     * @param evaluatorArr  the evaluators to invoke; may be {@code null}
     */
    @Override // edu.stanford.nlp.optimization.HasEvaluators
    public void setEvaluators(int i, Evaluator[] evaluatorArr) {
        evaluators = evaluatorArr;
        evaluateIters = i;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Computes the decay factor {@code d / (d + i)}.
     *
     * @param i the iteration index
     * @param d the decay constant
     * @return {@code d} divided by the sum of {@code d} and {@code i}
     */
    public static double gainSchedule(int i, double d) {
        final double denominator = d + i;
        return d / denominator;
    }

    /**
     * Averages the given vectors element-wise.
     *
     * @param list the vectors to average; must contain at least one element,
     *             whose length determines the length of the result
     * @return a newly allocated array holding the mean of all vectors in {@code list}
     */
    protected static double[] smooth(List<double[]> list) {
        final double[] mean = new double[list.get(0).length];
        for (double[] vector : list) {
            ArrayMath.pairwiseAddInPlace(mean, vector);
        }
        final double scale = 1.0d / list.size();
        ArrayMath.multiplyInPlace(mean, scale);
        return mean;
    }

    /**
     * Opens the {@code <name>.output} and {@code <name>.info} writers (both
     * auto-flushing) when file output is enabled; does nothing otherwise.
     * An {@link IOException} while opening is logged and terminates the JVM
     * with exit status 1.
     */
    private void initFiles() {
        if (!this.outputIterationsToFile) {
            return;
        }
        final String outputName = getName() + ".output";
        final String infoName = getName() + ".info";
        try {
            this.file = new PrintWriter(new FileOutputStream(outputName), true);
            this.infoFile = new PrintWriter(new FileOutputStream(infoName), true);
        } catch (IOException e) {
            log.info("Caught IOException outputting data to file: " + e.getMessage());
            System.exit(1);
        }
    }

    /**
     * Tunes this minimizer's parameters for the given function.
     * NOTE(review): the result is presumably (batch size, gain), as in the
     * list-based {@code tune} overload of this class; confirm in the subclasses.
     *
     * @param function the function to tune against
     * @param dArr     the initial point
     * @param j        time budget per test; presumably milliseconds
     * @return the selected pair of parameters
     */
    public abstract Pair<Integer, Double> tune(Function function, double[] dArr, long j);

    /**
     * Tunes a double-valued property over {@code [d, d2]} using a tolerance of
     * one thousandth of the interval width.
     *
     * @param function       the function to minimize in each trial
     * @param dArr           the initial point
     * @param j              time budget per trial
     * @param propertySetter receives each candidate value
     * @param d              lower end of the search interval
     * @param d2             upper end of the search interval
     * @return the candidate value that produced the smallest function value
     */
    public double tuneDouble(Function function, double[] dArr, long j, PropertySetter<Double> propertySetter, double d, double d2) {
        final double tolerance = Math.abs(d2 - d) * 0.001d;
        return tuneDouble(function, dArr, j, propertySetter, d, d2, tolerance);
    }

    /**
     * Searches {@code [d, d2]} for the value of a double-valued property that
     * gives the smallest function value after a time-limited minimization.
     * The two end points are tried first; after that the search keeps
     * bisecting the bracket around the best value seen until the bracket is
     * narrower than {@code d3}.
     *
     * <p>Side effects: sets {@code maxTime} to {@code j}, sets
     * {@code numPasses} to 10000, and leaves the property at the last
     * candidate tested (not necessarily the best one).
     *
     * @param function       the function to minimize; must be an
     *                       {@link AbstractStochasticCachingDiffFunction}
     * @param dArr           the initial point; copied before each trial and never modified
     * @param j              time budget per trial
     * @param propertySetter receives each candidate value
     * @param d              lower end of the search interval
     * @param d2             upper end of the search interval
     * @param d3             bracket width below which the search stops
     * @return the candidate value with the smallest observed function value
     * @throws UnsupportedOperationException if {@code function} is not an
     *                                       {@link AbstractStochasticCachingDiffFunction}
     */
    public double tuneDouble(Function function, double[] dArr, long j, PropertySetter<Double> propertySetter, double d, double d2, double d3) {
        double[] xtest = new double[dArr.length];
        this.maxTime = j;
        if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
            throw new UnsupportedOperationException();
        }
        AbstractStochasticCachingDiffFunction dfunction = (AbstractStochasticCachingDiffFunction) function;
        List<Pair<Double, Double>> results = new ArrayList<>();
        // Each pair is (candidate value, function value reached with it).
        // "best" starts at the lower bound because that is the first candidate tested.
        Pair<Double, Double> best = new Pair<>(Double.valueOf(d), Double.valueOf(Double.POSITIVE_INFINITY));
        Pair<Double, Double> low = new Pair<>(Double.valueOf(d), Double.valueOf(Double.POSITIVE_INFINITY));
        Pair<Double, Double> high = new Pair<>(Double.valueOf(d2), Double.valueOf(Double.POSITIVE_INFINITY));
        Pair<Double, Double> cur = new Pair<>();
        Pair<Double, Double> tmp = new Pair<>();
        // Candidates that must be tried before falling back to bisection.
        List<Double> queue = new ArrayList<>();
        queue.add(Double.valueOf(d));
        queue.add(Double.valueOf(d2));
        boolean toContinue = true;
        this.numPasses = 10000;
        do {
            System.arraycopy(dArr, 0, xtest, 0, dArr.length);
            if (!queue.isEmpty()) {
                cur.first = queue.remove(0);
            } else {
                cur.first = Double.valueOf(0.5d * (low.first() + high.first()));
            }
            propertySetter.set(cur.first());
            log.info("");
            log.info("About to test with batch size:  " + this.bSize + "  gain: " + this.gain + " and  " + propertySetter.toString() + " set to  " + cur.first());
            xtest = minimize(function, 1.0E-100d, xtest);
            // minimize() fills the point with NaN when the run diverged; score that as +infinity.
            if (Double.isNaN(xtest[0])) {
                cur.second = Double.valueOf(Double.POSITIVE_INFINITY);
            } else {
                cur.second = Double.valueOf(dfunction.valueAt(xtest));
            }
            if (cur.second() < best.second()) {
                copyPair(best, tmp);
                copyPair(cur, best);
                // The previous best becomes a bound on whichever side of the new best it lies.
                if (tmp.first() > best.first()) {
                    copyPair(tmp, high);
                } else {
                    copyPair(tmp, low);
                }
                queue.add(Double.valueOf(0.5d * (cur.first() + high.first())));
            } else if (cur.first() < best.first()) {
                copyPair(cur, low);
            } else if (cur.first() > best.first()) {
                copyPair(cur, high);
            }
            if (Math.abs(low.first() - high.first()) < d3) {
                toContinue = false;
            }
            results.add(new Pair<>(cur.first(), cur.second()));
            log.info("");
            log.info("Final value is: " + nf.format(cur.second()));
            log.info("Optimal so far using " + propertySetter.toString() + " is: " + best.first());
        } while (toContinue);
        log.info("-------------");
        log.info(" RESULTS          ");
        log.info(propertySetter.getClass().toString());
        log.info("-------------");
        log.info("  val    ,    function after " + j + " ms");
        for (Pair<Double, Double> result : results) {
            log.info(result.first() + "    ,    " + result.second());
        }
        log.info("");
        log.info("");
        return best.first().doubleValue();
    }

    /**
     * Copies both components of one pair into another.
     *
     * @param pair  the pair to read from
     * @param pair2 the pair whose {@code first} and {@code second} fields are overwritten
     */
    private static void copyPair(Pair<Double, Double> pair, Pair<Double, Double> pair2) {
        final Double firstValue = pair.first();
        pair2.first = firstValue;
        final Double secondValue = pair.second();
        pair2.second = secondValue;
    }

    /**
     * Tunes the {@code gain} field over {@code [d, d2]} by delegating to
     * {@code tuneDouble} with a setter that writes to {@code gain}.
     *
     * @param function the function to minimize in each trial
     * @param dArr     the initial point
     * @param j        time budget per trial
     * @param d        lower end of the gain interval
     * @param d2       upper end of the gain interval
     * @return the gain that produced the smallest function value
     */
    public double tuneGain(Function function, double[] dArr, long j, double d, double d2) {
        final PropertySetter<Double> gainSetter = new setGain(this);
        return tuneDouble(function, dArr, j, gainSetter, d, d2);
    }

    /**
     * Searches for a batch size by starting at {@code i} and doubling it for
     * as long as the function value reached keeps improving on the previous
     * trial.
     *
     * <p>The search stops at the first trial that does not improve on the
     * previous one. That includes a value equal to the previous one and a NaN
     * value; the original code looped forever in both cases because neither
     * {@code <} nor {@code >} holds and the batch size was never changed.
     *
     * <p>Side effects: sets {@code maxTime} to {@code j}, switches the
     * minimizer to quiet mode, and leaves {@code bSize} at the last size tested.
     *
     * @param function the function to minimize; must be an
     *                 {@link AbstractStochasticCachingDiffFunction}
     * @param dArr     the initial point; copied before each trial and never modified
     * @param j        time budget per trial
     * @param i        the first batch size to try
     * @return the batch size with the smallest observed function value, or 0
     *         if no trial produced a value below positive infinity
     * @throws UnsupportedOperationException if {@code function} is not an
     *                                       {@link AbstractStochasticCachingDiffFunction}
     */
    public int tuneBatch(Function function, double[] dArr, long j, int i) {
        double[] xtest = new double[dArr.length];
        int bestBatch = 0;
        double min = Double.POSITIVE_INFINITY;
        this.maxTime = j;
        double prev = Double.POSITIVE_INFINITY;
        if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
            throw new UnsupportedOperationException();
        }
        AbstractStochasticCachingDiffFunction dfunction = (AbstractStochasticCachingDiffFunction) function;
        int batch = i;
        boolean toContinue = true;
        do {
            System.arraycopy(dArr, 0, xtest, 0, dArr.length);
            log.info("");
            log.info("Testing with batch size:  " + batch);
            this.bSize = batch;
            shutUp();
            minimize(function, 1.0E-5d, xtest);
            double result = dfunction.valueAt(xtest);
            if (result < min) {
                min = result;
                bestBatch = this.bSize;
                batch *= 2;
                prev = result;
            } else if (result < prev) {
                batch *= 2;
                prev = result;
            } else {
                // No improvement (worse, equal, or NaN): stop instead of re-testing the same size forever.
                toContinue = false;
            }
            log.info("");
            log.info("Final value is: " + nf.format(result));
            log.info("Optimal so far is:  batch size: " + bestBatch);
        } while (toContinue);
        return bestBatch;
    }

    /**
     * Grid search over batch sizes and gains: runs one quiet, time-limited
     * minimization for every (batch size, gain) combination and reports the
     * combination with the smallest function value.
     *
     * <p>Side effects: sets {@code maxTime} to {@code j}, switches to quiet
     * mode, and leaves {@code bSize} and {@code gain} at the last combination tested.
     *
     * @param function the function to minimize
     * @param dArr     the initial point; copied before each trial and never modified
     * @param j        time budget per trial
     * @param list     the batch sizes to try
     * @param list2    the gains to try
     * @return the best (batch size, gain) pair, or (0, 0.0) if no trial
     *         produced a value below positive infinity
     */
    public Pair<Integer, Double> tune(Function function, double[] dArr, long j, List<Integer> list, List<Double> list2) {
        final double[] scratch = new double[dArr.length];
        int bestBatch = 0;
        double bestGain = 0.0d;
        double bestValue = Double.POSITIVE_INFINITY;
        final double[][] values = new double[list.size()][list2.size()];
        this.maxTime = j;
        for (int row = 0; row < list.size(); row++) {
            for (int col = 0; col < list2.size(); col++) {
                System.arraycopy(dArr, 0, scratch, 0, dArr.length);
                this.bSize = list.get(row);
                this.gain = list2.get(col);
                log.info("");
                log.info("Testing with batch size: " + this.bSize + "    gain:  " + nf.format(this.gain));
                this.quiet = true;
                minimize(function, 1.0E-100d, scratch);
                final double value = function.valueAt(scratch);
                values[row][col] = value;
                if (value < bestValue) {
                    bestValue = value;
                    bestBatch = this.bSize;
                    bestGain = this.gain;
                }
                log.info("");
                log.info("Final value is: " + nf.format(value));
                log.info("Optimal so far is:  batch size: " + bestBatch + "   gain:  " + nf.format(bestGain));
            }
        }
        return new Pair<>(Integer.valueOf(bestBatch), Double.valueOf(bestGain));
    }

    /**
     * Hook invoked by {@code minimize} after its working arrays have been
     * allocated and before the first iteration. This default implementation
     * does nothing.
     *
     * @param abstractStochasticCachingDiffFunction the function about to be minimized
     */
    protected void init(AbstractStochasticCachingDiffFunction abstractStochasticCachingDiffFunction) {
    }

    /**
     * Runs every registered evaluator on the given point, announcing each one
     * through {@code sayln} first. Does nothing when no evaluators are set.
     *
     * @param dArr the point to evaluate
     */
    private void doEvaluation(double[] dArr) {
        final Evaluator[] registered = this.evaluators;
        if (registered != null) {
            for (int idx = 0; idx < registered.length; idx++) {
                final Evaluator current = registered[idx];
                sayln("  Evaluating: " + current.toString());
                current.evaluate(dArr);
            }
        }
    }

    /**
     * Minimizes without an explicit iteration limit by delegating to the
     * four-argument overload with a limit of {@code -1}.
     *
     * @param function the function to minimize
     * @param d        the tolerance, forwarded unchanged
     * @param dArr     the initial point
     * @return the array returned by the four-argument overload
     */
    @Override // edu.stanford.nlp.optimization.Minimizer
    public double[] minimize(Function function, double d, double[] dArr) {
        final int noExplicitLimit = -1;
        return minimize(function, d, dArr, noExplicitLimit);
    }

    /**
     * Runs the stochastic minimization loop. Each iteration fetches a batch
     * gradient, averages it with the recent gradients, delegates the update
     * to {@code takeStep} and copies {@code newX} into {@code x}.
     *
     * <p>The loop ends when the iteration limit
     * ({@code max(i, numPasses) * numBatches}) is reached, when the elapsed
     * time reaches {@code maxTime}, or when a gradient or step contains a
     * non-finite element. In the last case every element of the current point
     * is set to NaN and the loop is left; the original code caught the
     * exception inside the loop and retried the same iteration indefinitely.
     *
     * <p>NOTE(review): {@code outputFrequency} is overwritten on every call,
     * so repeated calls compound; kept because the field is public.
     *
     * @param function the function to minimize; must be an
     *                 {@link AbstractStochasticCachingDiffFunction}
     * @param d        the tolerance; not used by this implementation
     * @param dArr     the initial point; updated in place while iterating
     * @param i        maximum number of passes; values {@code <= 0} defer to {@code numPasses}
     * @return the final point: {@code dArr} itself unless the run stopped on
     *         the time limit, in which case the internal {@code newX} array
     * @throws UnsupportedOperationException if {@code function} has the wrong
     *                                       type, or neither {@code i} nor
     *                                       {@code numPasses} is positive
     */
    @Override // edu.stanford.nlp.optimization.Minimizer
    public double[] minimize(Function function, double d, double[] dArr, int i) {
        if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
            throw new UnsupportedOperationException();
        }
        AbstractStochasticCachingDiffFunction dfunction = (AbstractStochasticCachingDiffFunction) function;
        dfunction.method = StochasticCalculateMethods.GradientOnly;
        this.x = dArr;
        this.grad = new double[this.x.length];
        this.newX = new double[this.x.length];
        this.gradList = new ArrayList<>();
        this.numBatches = dfunction.dataDimension() / this.bSize;
        // Divide in floating point so that ceil() has an effect, and never let the
        // result drop to 0: it is used as a modulus further down.
        this.outputFrequency = Math.max(1, (int) Math.ceil(((double) this.numBatches) / this.outputFrequency));
        init(dfunction);
        initFiles();
        if (i <= 0 && this.numPasses <= 0) {
            throw new UnsupportedOperationException("No maximum number of iterations has been specified.");
        }
        int max = Math.max(i, this.numPasses) * this.numBatches;
        sayln("       Batchsize of: " + this.bSize);
        sayln("       Data dimension of: " + dfunction.dataDimension());
        sayln("       Batches per pass through data:  " + this.numBatches);
        sayln("       Max iterations is = " + max);
        if (this.outputIterationsToFile) {
            this.infoFile.println(function.domainDimension() + "; DomainDimension ");
            this.infoFile.println(this.bSize + "; batchSize ");
            this.infoFile.println(max + "; maxIterations");
            this.infoFile.println(this.numBatches + "; numBatches ");
            this.infoFile.println(this.outputFrequency + "; outputFrequency");
        }
        Timing total = new Timing();
        Timing current = new Timing();
        total.start();
        current.start();
        this.k = 0;
        try {
            while (this.k < max) {
                if (this.k > 0 && this.evaluateIters > 0 && this.k % this.evaluateIters == 0) {
                    doEvaluation(this.x);
                }
                say("Iter: " + this.k + " pass " + (this.k / this.numBatches) + " batch " + (this.k % this.numBatches));
                // Allocate gradient buffers until the history is full, then recycle the oldest one.
                if (this.k <= 0 || this.gradList.size() < this.memory) {
                    this.newGrad = new double[this.grad.length];
                } else {
                    this.newGrad = this.gradList.remove(0);
                }
                dfunction.hasNewVals = true;
                System.arraycopy(dfunction.derivativeAt(this.x, this.v, this.bSize), 0, this.newGrad, 0, this.newGrad.length);
                ArrayMath.assertFinite(this.newGrad, "newGrad");
                this.gradList.add(this.newGrad);
                this.grad = smooth(this.gradList);
                takeStep(dfunction);
                ArrayMath.assertFinite(this.newX, "newX");
                if (this.outputIterationsToFile && this.k % this.outputFrequency == 0 && this.k != 0) {
                    double valueAt = dfunction.valueAt(this.x);
                    say(" TrueValue{ " + valueAt + " } ");
                    this.file.println(this.k + " , " + valueAt + " , " + total.report());
                }
                // Only reachable if takeStep() advanced k itself; kept from the original.
                if (this.k >= max) {
                    sayln("Stochastic Optimization complete.  Stopped after max iterations");
                    this.x = this.newX;
                    break;
                }
                if (total.report() >= this.maxTime) {
                    sayln("Stochastic Optimization complete.  Stopped after max time");
                    this.x = this.newX;
                    break;
                }
                System.arraycopy(this.newX, 0, this.x, 0, this.x.length);
                // NOTE(review): OPEN_BRACKET is presumably "[" (it pairs with the "}] " below); confirm.
                say(ColumnCondition.OPEN_BRACKET + (total.report() / 1000.0d) + " s ");
                say("{" + (current.restart() / 1000.0d) + " s}] ");
                say(" " + dfunction.lastValue());
                if (this.quiet) {
                    log.info(".");
                } else {
                    sayln("");
                }
                this.k++;
            }
        } catch (ArrayMath.InvalidElementException e) {
            // A non-finite gradient or step invalidates the run: mark the point as NaN and stop iterating.
            log.info(e.toString());
            for (int idx = 0; idx < this.x.length; idx++) {
                this.x[idx] = Double.NaN;
            }
        }
        if (this.evaluateIters > 0) {
            doEvaluation(this.x);
        }
        if (this.outputIterationsToFile) {
            this.infoFile.println(this.k + "; Iterations");
            this.infoFile.println((total.report() / 1000.0d) + "; Completion Time");
            this.infoFile.println(dfunction.valueAt(this.x) + "; Finalvalue");
            this.infoFile.close();
            this.file.close();
            log.info("Output Files Closed");
        }
        say("Completed in: " + (total.report() / 1000.0d) + " s");
        return this.x;
    }

    /**
     * Logs the message at info level unless the minimizer is in quiet mode.
     *
     * @param str the message to log
     */
    protected void sayln(String str) {
        if (!this.quiet) {
            log.info(str);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Logs the message at info level unless the minimizer is in quiet mode.
     *
     * @param str the message to log
     */
    public void say(String str) {
        if (!this.quiet) {
            log.info(str);
        }
    }
}
