/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.classification.logisticregression;

import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil;
import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionParams;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/**
 * An Estimator which trains a logistic regression model online using FTRL-Proximal updates.
 *
 * <p>The initial model data must be supplied via {@link #setInitialModelData(Table)} before
 * {@link #fit(Table...)} or {@link #save(String)} is called. Training runs as an unbounded
 * iteration: the input points are cut into global batches of {@code globalBatchSize} rows, each
 * global batch produces one model update, and every update is emitted as a new {@link
 * LogisticRegressionModelData} carrying an increasing model version.
 */
public class OnlineLogisticRegression
        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();
    private Table initModelDataTable;

    public OnlineLogisticRegression() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Builds the online training pipeline and returns a model backed by the stream of updated
     * model data.
     *
     * @param inputs exactly one table holding the features, label and (optional) weight columns
     * @return a model whose model data is the continuously updated coefficient stream and whose
     *     parameters are copied from this estimator
     * @throws IllegalArgumentException if not exactly one input table is given
     * @throws NullPointerException if no initial model data has been set
     */
    @Override
    public OnlineLogisticRegressionModel fit(Table... inputs) {
        Preconditions.checkArgument(
                inputs.length == 1, "Expected exactly one input table, got %s.", inputs.length);
        Preconditions.checkNotNull(
                initModelDataTable,
                "Initial model data must be set via setInitialModelData() before fit().");
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        DataStream<LogisticRegressionModelData> modelDataStream =
                LogisticRegressionModelDataUtil.getModelDataStream(initModelDataTable);

        // Points are reduced to (features, label) or (features, label, weight) rows.
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        TypeInformation<Row> pointTypeInfo;
        if (getWeightCol() == null) {
            pointTypeInfo =
                    Types.ROW(
                            inputTypeInfo.getTypeAt(getFeaturesCol()),
                            inputTypeInfo.getTypeAt(getLabelCol()));
        } else {
            pointTypeInfo =
                    Types.ROW(
                            inputTypeInfo.getTypeAt(getFeaturesCol()),
                            inputTypeInfo.getTypeAt(getLabelCol()),
                            inputTypeInfo.getTypeAt(getWeightCol()));
        }
        DataStream<Row> points =
                tEnv.toDataStream(inputs[0])
                        .map(
                                new FeaturesLabelExtractor(
                                        getFeaturesCol(), getLabelCol(), getWeightCol()),
                                pointTypeInfo);

        // The initial coefficients seed the iteration's variable stream.
        DataStream<DenseVector> initModelData = modelDataStream.map(value -> value.coefficient);
        initModelData.getTransformation().setParallelism(1);

        FtrlIterationBody body =
                new FtrlIterationBody(
                        getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());
        DataStream<LogisticRegressionModelData> onlineModelData =
                Iterations.iterateUnboundedStreams(
                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
                        .get(0);

        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
        OnlineLogisticRegressionModel model =
                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
        ParamUtils.updateExistingParams(model, paramMap);
        return model;
    }

    /**
     * Saves this estimator's parameters and its initial model data under {@code path}.
     *
     * @throws NullPointerException if no initial model data has been set (checked before anything
     *     is written, so no partial output is left behind)
     */
    @Override
    public void save(String path) throws IOException {
        Preconditions.checkNotNull(
                initModelDataTable,
                "Initial model data must be set via setInitialModelData() before save().");
        ReadWriteUtils.saveMetadata(this, path);
        ReadWriteUtils.saveModelData(
                LogisticRegressionModelDataUtil.getModelDataStream(initModelDataTable),
                path,
                new LogisticRegressionModelDataUtil.ModelDataEncoder());
    }

    /**
     * Loads an estimator previously written by {@link #save(String)}, restoring both its
     * parameters and its initial model data.
     */
    public static OnlineLogisticRegression load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        OnlineLogisticRegression onlineLogisticRegression =
                (OnlineLogisticRegression) ReadWriteUtils.loadStageParam(path);
        Table modelDataTable =
                ReadWriteUtils.loadModelData(
                        tEnv, path, new LogisticRegressionModelDataUtil.ModelDataDecoder());
        onlineLogisticRegression.setInitialModelData(modelDataTable);
        return onlineLogisticRegression;
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /** Sets the table holding the model data that training starts from. Returns {@code this}. */
    public OnlineLogisticRegression setInitialModelData(Table initModelDataTable) {
        this.initModelDataTable = initModelDataTable;
        return this;
    }

    /**
     * Computes, per local batch, the weighted sum of logistic-loss gradients and the per-coordinate
     * sum of sample weights.
     *
     * <p>Input 1 carries local batches of points; input 2 carries the (broadcast) current
     * coefficients. Each model version is consumed by exactly one batch. The emitted array is
     * {@code [gradientSum, weightSum, model]}, where {@code model} is only set on subtask 0 so the
     * downstream reduce receives it exactly once.
     */
    private static final class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>
            implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> {
        private ListState<DenseVector> modelDataState;
        private ListState<Row[]> localBatchDataState;

        private CalculateLocalGradient() {}

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            modelDataState =
                    context.getOperatorStateStore()
                            .getListState(new ListStateDescriptor<>("modelData", DenseVector.class));
            TypeInformation<Row[]> batchType =
                    ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
            localBatchDataState =
                    context.getOperatorStateStore()
                            .getListState(new ListStateDescriptor<>("localBatch", batchType));
        }

        @Override
        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
            localBatchDataState.add(pointsRecord.getValue());
            calculateGradient();
        }

        @Override
        public void processElement2(StreamRecord<DenseVector> modelDataRecord) throws Exception {
            modelDataState.add(modelDataRecord.getValue());
            calculateGradient();
        }

        /** Processes the oldest pending batch once both a batch and a model are available. */
        @SuppressWarnings("unchecked")
        private void calculateGradient() throws Exception {
            if (!modelDataState.get().iterator().hasNext()
                    || !localBatchDataState.get().iterator().hasNext()) {
                return;
            }
            DenseVector modelData =
                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
            modelDataState.clear();

            List<Row[]> pendingBatches =
                    IteratorUtils.toList(localBatchDataState.get().iterator());
            Row[] points = pendingBatches.remove(0);
            localBatchDataState.update(pendingBatches);

            // Fresh buffers per batch: the emitted DenseVectors wrap these arrays, so they must not
            // be reused or zeroed afterwards. Sizing by the model also keeps an empty batch safe.
            double[] gradient = new double[modelData.size()];
            double[] weightSum = new double[modelData.size()];

            for (Row point : points) {
                Vector vec = point.getFieldAs(0);
                double label = ((Number) point.getFieldAs(1)).doubleValue();
                double weight =
                        point.getArity() == 2
                                ? 1.0
                                : ((Number) point.getFieldAs(2)).doubleValue();
                double p = 1.0 / (1.0 + Math.exp(-BLAS.dot(modelData, vec)));
                // Weighted gradient: sum(w * (p - y) * x) / sum(w) once normalized downstream.
                double scaledError = weight * (p - label);
                if (vec instanceof DenseVector) {
                    double[] values = ((DenseVector) vec).values;
                    for (int i = 0; i < modelData.size(); i++) {
                        gradient[i] += scaledError * values[i];
                        weightSum[i] += weight;
                    }
                } else {
                    SparseVector svec = (SparseVector) vec;
                    for (int i = 0; i < svec.indices.length; i++) {
                        int idx = svec.indices[i];
                        gradient[idx] += scaledError * svec.values[i];
                        weightSum[idx] += weight;
                    }
                }
            }

            if (points.length > 0) {
                DenseVector modelToForward =
                        getRuntimeContext().getIndexOfThisSubtask() == 0 ? modelData : null;
                output.collect(
                        new StreamRecord<>(
                                new DenseVector[] {
                                    new DenseVector(gradient),
                                    new DenseVector(weightSum),
                                    modelToForward
                                }));
            }
        }
    }

    /**
     * Applies one FTRL-Proximal step to the coefficients using the aggregated gradient info
     * {@code [gradientSum, weightSum, model]} and emits the new coefficients.
     *
     * <p>The per-coordinate accumulators {@code n} and {@code z} are kept in operator state and
     * restored on recovery.
     */
    private static final class UpdateModel extends AbstractStreamOperator<DenseVector>
            implements OneInputStreamOperator<DenseVector[], DenseVector> {
        private ListState<double[]> nParamState;
        private ListState<double[]> zParamState;
        private final double alpha;
        private final double beta;
        private final double l1;
        private final double l2;
        private double[] nParam;
        private double[] zParam;

        public UpdateModel(double alpha, double beta, double l1, double l2) {
            this.alpha = alpha;
            this.beta = beta;
            this.l1 = l1;
            this.l2 = l2;
        }

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            nParamState =
                    context.getOperatorStateStore()
                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
            zParamState =
                    context.getOperatorStateStore()
                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
            // Restore accumulators after a failover; both stay null on a fresh start.
            nParam = OperatorStateUtils.getUniqueElement(nParamState, "nParamState").orElse(null);
            zParam = OperatorStateUtils.getUniqueElement(zParamState, "zParamState").orElse(null);
        }

        @Override
        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
            DenseVector[] gradientInfo = streamRecord.getValue();
            Preconditions.checkNotNull(
                    gradientInfo[2], "Aggregated gradient info does not carry the current model.");
            double[] coefficient = gradientInfo[2].values;
            double[] g = gradientInfo[0].values;
            double[] weightSum = gradientInfo[1].values;

            // Normalize the summed gradient by the per-coordinate weight (skip untouched ones).
            for (int i = 0; i < g.length; i++) {
                if (weightSum[i] != 0.0) {
                    g[i] /= weightSum[i];
                }
            }

            if (zParam == null || nParam == null) {
                zParam = new double[g.length];
                nParam = new double[g.length];
            }

            for (int i = 0; i < zParam.length; i++) {
                double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
                zParam[i] += g[i] - sigma * coefficient[i];
                nParam[i] += g[i] * g[i];
                if (Math.abs(zParam[i]) <= l1) {
                    coefficient[i] = 0.0;
                } else {
                    double sign = zParam[i] < 0.0 ? -1.0 : 1.0;
                    coefficient[i] =
                            (sign * l1 - zParam[i]) / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
                }
            }

            // Store the accumulators explicitly instead of relying on in-place mutation of a
            // reference previously handed to the state backend.
            nParamState.update(Collections.singletonList(nParam));
            zParamState.update(Collections.singletonList(zParam));

            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
        }
    }

    /**
     * Wraps each coefficient vector into {@link LogisticRegressionModelData} with a model version
     * that starts at 1, increases by one per update, and survives failover via operator state.
     */
    private static final class CreateLrModelData
            implements MapFunction<DenseVector, LogisticRegressionModelData>, CheckpointedFunction {
        private long modelVersion = 1L;
        private transient ListState<Long> modelVersionState;

        private CreateLrModelData() {}

        @Override
        public LogisticRegressionModelData map(DenseVector denseVector) throws Exception {
            return new LogisticRegressionModelData(denseVector, modelVersion++);
        }

        @Override
        public void snapshotState(FunctionSnapshotContext functionSnapshotContext)
                throws Exception {
            modelVersionState.update(Collections.singletonList(modelVersion));
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            modelVersionState =
                    context.getOperatorStateStore()
                            .getListState(new ListStateDescriptor<>("modelVersionState", Long.class));
            // Continue numbering from the checkpointed version instead of restarting at 1.
            modelVersion =
                    OperatorStateUtils.getUniqueElement(modelVersionState, "modelVersionState")
                            .orElse(1L);
        }
    }

    /**
     * The iteration body: batches points, computes local gradients against the broadcast model,
     * aggregates them across subtasks, applies the FTRL update, and feeds the new coefficients
     * back while also emitting them as versioned model data.
     */
    private static final class FtrlIterationBody implements IterationBody {
        private final int batchSize;
        private final double alpha;
        private final double beta;
        private final double l1;
        private final double l2;

        /** Splits {@code reg} into L1/L2 strengths according to {@code elasticNet}. */
        public FtrlIterationBody(
                int batchSize, double alpha, double beta, double reg, double elasticNet) {
            this.batchSize = batchSize;
            this.alpha = alpha;
            this.beta = beta;
            this.l1 = elasticNet * reg;
            this.l2 = (1.0 - elasticNet) * reg;
        }

        @Override
        public IterationBodyResult process(
                DataStreamList variableStreams, DataStreamList dataStreams) {
            DataStream<DenseVector> modelData = variableStreams.get(0);
            DataStream<Row> points = dataStreams.get(0);
            int parallelism = points.getParallelism();
            Preconditions.checkState(
                    parallelism <= batchSize,
                    "There are more subtasks in the training process than the number of elements in each batch. Some subtasks might be idling forever.");

            // One record per subtask per round; countWindowAll(parallelism) merges a round.
            DataStream<DenseVector[]> newGradient =
                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
                            .connect(modelData.broadcast())
                            .transform(
                                    "LocalGradientCalculator",
                                    TypeInformation.of(DenseVector[].class),
                                    new CalculateLocalGradient())
                            .setParallelism(parallelism)
                            .countWindowAll(parallelism)
                            .reduce(
                                    (gradientInfo, newGradientInfo) -> {
                                        BLAS.axpy(1.0, gradientInfo[0], newGradientInfo[0]);
                                        BLAS.axpy(1.0, gradientInfo[1], newGradientInfo[1]);
                                        // Only subtask 0 forwards the model; keep whichever has it.
                                        if (newGradientInfo[2] == null) {
                                            newGradientInfo[2] = gradientInfo[2];
                                        }
                                        return newGradientInfo;
                                    });

            DataStream<DenseVector> feedbackModelData =
                    newGradient
                            .transform(
                                    "ModelDataUpdater",
                                    TypeInformation.of(DenseVector.class),
                                    new UpdateModel(alpha, beta, l1, l2))
                            .setParallelism(1);

            DataStream<LogisticRegressionModelData> outputModelData =
                    feedbackModelData.map(new CreateLrModelData()).setParallelism(1);

            return new IterationBodyResult(
                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
        }
    }

    /**
     * Projects an input row onto {@code (features, label)} or, when a weight column is configured,
     * {@code (features, label, weight)}.
     */
    private static final class FeaturesLabelExtractor implements MapFunction<Row, Row> {
        private final String featuresCol;
        private final String labelCol;
        private final String weightCol;

        private FeaturesLabelExtractor(String featuresCol, String labelCol, String weightCol) {
            this.featuresCol = featuresCol;
            this.labelCol = labelCol;
            this.weightCol = weightCol;
        }

        @Override
        public Row map(Row row) throws Exception {
            if (weightCol == null) {
                return Row.of(row.getField(featuresCol), row.getField(labelCol));
            }
            return Row.of(
                    row.getField(featuresCol), row.getField(labelCol), row.getField(weightCol));
        }
    }
}

