/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.classification.linearsvc;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.classification.linearsvc.LinearSVCModel;
import org.apache.flink.ml.classification.linearsvc.LinearSVCModelData;
import org.apache.flink.ml.classification.linearsvc.LinearSVCParams;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.common.lossfunc.HingeLoss;
import org.apache.flink.ml.common.optimizer.SGD;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Preconditions;

/**
 * An {@link Estimator} that trains a binary linear support vector classifier.
 *
 * <p>{@link #fit(Table...)} reads a label column (values must be 0.0 or 1.0), a features column
 * holding a {@link Vector}, and an optional weight column (rows are weighted 1.0 when no weight
 * column is configured). It then optimizes the model coefficients with {@link SGD} under {@link
 * HingeLoss} and returns a {@link LinearSVCModel} carrying the resulting model data.
 */
public class LinearSVC
        implements Estimator<LinearSVC, LinearSVCModel>, LinearSVCParams<LinearSVC> {
    /** Backing store for this stage's parameters; filled with defaults in the constructor. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /** Creates an estimator whose parameters are initialized to their default values. */
    public LinearSVC() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Builds the training pipeline for a {@link LinearSVCModel} from exactly one input table.
     *
     * <p>Per-row validation runs inside the Flink job, not in this method. That validation
     * checks that labels are 0.0 or 1.0 and that all feature vectors have the same size. A
     * violation surfaces as an {@link IllegalStateException} when the job executes.
     *
     * @param inputs exactly one table containing the label, features and (optional) weight columns
     * @return a model whose model data is the SGD-optimized coefficient vector, with this
     *     estimator's parameters copied onto it via {@code ParamUtils.updateExistingParams}
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     */
    @Override
    public LinearSVCModel fit(Table... inputs) {
        Preconditions.checkArgument(
                inputs.length == 1,
                "LinearSVC expects exactly one input table, but got %s.",
                inputs.length);
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        // Resolve column names once, on the client. The map lambda below then captures only
        // these Strings rather than `this`, so the estimator is not dragged into the closure
        // Flink serializes, and the params are not re-read for every record.
        final String weightCol = getWeightCol();
        final String labelCol = getLabelCol();
        final String featuresCol = getFeaturesCol();

        DataStream<LabeledPointWithWeight> trainData =
                tEnv.toDataStream(inputs[0])
                        .map(
                                dataPoint -> {
                                    // No weight column configured => every row counts equally.
                                    double weight =
                                            weightCol == null
                                                    ? 1.0
                                                    : ((Number) dataPoint.getField(weightCol))
                                                            .doubleValue();
                                    double label =
                                            ((Number) dataPoint.getField(labelCol)).doubleValue();
                                    Preconditions.checkState(
                                            Double.compare(0.0, label) == 0
                                                    || Double.compare(1.0, label) == 0,
                                            "LinearSVC only supports binary classification. But detected label: %s.",
                                            label);
                                    DenseVector features =
                                            ((Vector) dataPoint.getField(featuresCol)).toDense();
                                    return new LabeledPointWithWeight(features, label, weight);
                                });

        // Reduce the per-row feature sizes to a single dimension, failing if any two rows
        // disagree. That dimension sizes the initial model vector handed to the optimizer.
        DataStream<DenseVector> initModelData =
                DataStreamUtils.reduce(
                                trainData.map(x -> x.getFeatures().size()),
                                (ReduceFunction<Integer>)
                                        (t0, t1) -> {
                                            Preconditions.checkState(
                                                    t0.equals(t1),
                                                    "The training data should all have same dimensions.");
                                            return t0;
                                        })
                        .map(DenseVector::new);

        SGD optimizer =
                new SGD(
                        getMaxIter(),
                        getLearningRate(),
                        getGlobalBatchSize(),
                        getTol(),
                        getReg(),
                        getElasticNet());
        DataStream<DenseVector> rawModelData =
                optimizer.optimize(initModelData, trainData, HingeLoss.INSTANCE);

        DataStream<LinearSVCModelData> modelData = rawModelData.map(LinearSVCModelData::new);
        LinearSVCModel model = new LinearSVCModel().setModelData(tEnv.fromDataStream(modelData));
        // NOTE(review): presumably copies only the params the model also declares; verify
        // against ParamUtils.updateExistingParams.
        ParamUtils.updateExistingParams(model, this.paramMap);
        return model;
    }

    /**
     * Saves this estimator's metadata to {@code path} via {@code ReadWriteUtils.saveMetadata}.
     *
     * @param path destination directory
     * @throws IOException if writing the metadata fails
     */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Loads a {@code LinearSVC} from metadata stored at {@code path} via {@code
     * ReadWriteUtils.loadStageParam}.
     *
     * @param tEnv not used by this implementation (presumably kept to match the library's static
     *     {@code load} convention — confirm before removing)
     * @param path directory previously written by {@link #save(String)}
     * @return the restored estimator
     * @throws IOException if reading the metadata fails
     */
    public static LinearSVC load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (LinearSVC) ReadWriteUtils.loadStageParam(path);
    }

    /**
     * Returns the live parameter map backing this estimator (not a copy).
     *
     * @return the mutable parameter map
     */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }
}

