/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.common.datastream;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.common.datastream.AllReduceImpl;
import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
import org.apache.flink.ml.common.datastream.sort.CoGroupOperator;
import org.apache.flink.ml.common.window.CountTumblingWindows;
import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
import org.apache.flink.ml.common.window.GlobalWindows;
import org.apache.flink.ml.common.window.ProcessingTimeTumblingWindows;
import org.apache.flink.ml.common.window.Windows;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.streaming.api.datastream.AllWindowedStream;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
import org.apache.flink.streaming.api.functions.windowing.ProcessAllWindowFunction;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.InternalTimer;
import org.apache.flink.streaming.api.operators.InternalTimerService;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.TimestampedCollector;
import org.apache.flink.streaming.api.operators.Triggerable;
import org.apache.flink.streaming.api.windowing.assigners.EventTimeSessionWindows;
import org.apache.flink.streaming.api.windowing.assigners.ProcessingTimeSessionWindows;
import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindows;
import org.apache.flink.streaming.api.windowing.assigners.TumblingProcessingTimeWindows;
import org.apache.flink.streaming.api.windowing.assigners.WindowAssigner;
import org.apache.flink.streaming.api.windowing.time.Time;
import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
import org.apache.flink.streaming.api.windowing.windows.Window;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.TableException;
import org.apache.flink.util.Collector;

@Internal
public class DataStreamUtils {
    /**
     * Runs the all-reduce-sum implementation on a stream of {@code double[]} records by
     * delegating to {@code AllReduceImpl.allReduceSum}.
     *
     * <p>NOTE(review): the exact reduction semantics live in {@code AllReduceImpl}, which is
     * not part of this file; confirm there before relying on ordering or array-length rules.
     *
     * @param input the stream of arrays handed to {@code AllReduceImpl}
     * @return the stream returned by {@code AllReduceImpl.allReduceSum}
     */
    public static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
        final DataStream<double[]> reduced = AllReduceImpl.allReduceSum(input);
        return reduced;
    }

    /**
     * Applies a {@link MapPartitionFunction} to the input, deriving the output type from the
     * function via {@link TypeExtractor}.
     *
     * @param input the stream to process
     * @param func the function invoked on the records of each parallel partition
     * @param <IN> input record type
     * @param <OUT> output record type
     * @return the stream of records produced by {@code func}
     */
    public static <IN, OUT> DataStream<OUT> mapPartition(
            DataStream<IN> input, MapPartitionFunction<IN, OUT> func) {
        final TypeInformation<IN> inType = input.getType();
        // No function name is supplied (null) and missing type information is tolerated (true).
        final TypeInformation<OUT> outType =
                TypeExtractor.getMapPartitionReturnTypes(func, inType, null, true);
        return DataStreamUtils.mapPartition(input, func, outType);
    }

    /**
     * Applies a {@link MapPartitionFunction} to the input using an explicitly supplied output
     * type. The work is done by a {@code MapPartitionOperator} that runs with the same
     * parallelism as the input.
     *
     * @param input the stream to process
     * @param func the function invoked on the records of each parallel partition
     * @param outType type information of the produced records
     * @param <IN> input record type
     * @param <OUT> output record type
     * @return the stream of records produced by {@code func}
     */
    public static <IN, OUT> DataStream<OUT> mapPartition(
            DataStream<IN> input,
            MapPartitionFunction<IN, OUT> func,
            TypeInformation<OUT> outType) {
        // Closure-clean the user function before it is shipped inside the operator.
        final MapPartitionFunction<IN, OUT> cleaned = input.getExecutionEnvironment().clean(func);
        final SingleOutputStreamOperator<OUT> mapped =
                input.transform("mapPartition", outType, new MapPartitionOperator<>(cleaned));
        return mapped.setParallelism(input.getParallelism());
    }

    /**
     * Reduces all records of a bounded stream to at most one record, reusing the input's type
     * information for the result.
     *
     * @param input the stream to reduce
     * @param func the reduce function
     * @param <T> record type
     * @return a stream holding the reduced record
     */
    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {
        final TypeInformation<T> recordType = input.getType();
        return DataStreamUtils.reduce(input, func, recordType);
    }

    /**
     * Reduces all records of a bounded stream in up to two stages: first inside every parallel
     * subtask of the input, then (only when that stage ran with parallelism other than 1) in a
     * single final subtask.
     *
     * @param input the stream to reduce
     * @param func the reduce function
     * @param outType type information of the reduced record
     * @param <T> record type
     * @return a stream holding the reduced record
     */
    public static <T> DataStream<T> reduce(
            DataStream<T> input, ReduceFunction<T> func, TypeInformation<T> outType) {
        final ReduceFunction<T> cleaned = input.getExecutionEnvironment().clean(func);

        // Stage 1: one partial result per parallel subtask of the input.
        final SingleOutputStreamOperator<T> partial =
                input.transform("reduce", outType, new ReduceOperator<>(cleaned))
                        .setParallelism(input.getParallelism());

        // Stage 2: merge the partial results in one subtask, unless there is just one already.
        if (partial.getParallelism() != 1) {
            return partial.transform("reduce", outType, new ReduceOperator<>(cleaned))
                    .setParallelism(1);
        }
        return partial;
    }

    /**
     * Reduces the records of every key of a keyed stream, reusing the input's type information
     * for the result.
     *
     * @param input the keyed stream to reduce
     * @param func the reduce function applied per key
     * @param <T> record type
     * @param <K> key type
     * @return a stream with the reduced record of each key
     */
    public static <T, K> DataStream<T> reduce(KeyedStream<T, K> input, ReduceFunction<T> func) {
        final TypeInformation<T> recordType = input.getType();
        return DataStreamUtils.reduce(input, func, recordType);
    }

    /**
     * Reduces the records of every key of a keyed stream with a {@code KeyedReduceOperator}
     * that runs with the same parallelism as the input.
     *
     * @param input the keyed stream to reduce
     * @param func the reduce function applied per key
     * @param outType type information of the reduced records; also used to create the
     *     serializer of the per-key state
     * @param <T> record type
     * @param <K> key type
     * @return a stream with the reduced record of each key
     */
    public static <T, K> DataStream<T> reduce(
            KeyedStream<T, K> input, ReduceFunction<T> func, TypeInformation<T> outType) {
        final ReduceFunction<T> cleaned = input.getExecutionEnvironment().clean(func);
        final TypeSerializer<T> stateSerializer =
                outType.createSerializer(input.getExecutionConfig());
        final SingleOutputStreamOperator<T> reduced =
                input.transform(
                        "Keyed Reduce",
                        outType,
                        new KeyedReduceOperator<>(cleaned, stateSerializer));
        return reduced.setParallelism(input.getParallelism());
    }

    /**
     * Aggregates all records of a bounded stream, deriving accumulator and result types from
     * the function via {@link TypeExtractor}.
     *
     * @param input the stream to aggregate
     * @param func the aggregate function
     * @param <IN> input record type
     * @param <ACC> accumulator type
     * @param <OUT> result type
     * @return a stream holding the aggregation result
     */
    public static <IN, ACC, OUT> DataStream<OUT> aggregate(
            DataStream<IN> input, AggregateFunction<IN, ACC, OUT> func) {
        final TypeInformation<IN> inType = input.getType();
        // No function name is supplied (null) and missing type information is tolerated (true).
        final TypeInformation<ACC> accType =
                TypeExtractor.getAggregateFunctionAccumulatorType(func, inType, null, true);
        final TypeInformation<OUT> outType =
                TypeExtractor.getAggregateFunctionReturnType(func, inType, null, true);
        return DataStreamUtils.aggregate(input, func, accType, outType);
    }

    /**
     * Aggregates all records of a bounded stream in two stages: a {@code
     * PartialAggregateOperator} folds records into an accumulator per subtask, and an {@code
     * AggregateOperator} with parallelism 1 merges those accumulators and emits the result.
     *
     * @param input the stream to aggregate
     * @param func the aggregate function
     * @param accType type information of the accumulator
     * @param outType type information of the result
     * @param <IN> input record type
     * @param <ACC> accumulator type
     * @param <OUT> result type
     * @return a stream holding the aggregation result
     */
    public static <IN, ACC, OUT> DataStream<OUT> aggregate(
            DataStream<IN> input,
            AggregateFunction<IN, ACC, OUT> func,
            TypeInformation<ACC> accType,
            TypeInformation<OUT> outType) {
        final AggregateFunction<IN, ACC, OUT> cleaned =
                input.getExecutionEnvironment().clean(func);

        final DataStream<ACC> partials =
                input.transform(
                        "partialAggregate",
                        accType,
                        new PartialAggregateOperator<>(cleaned, accType));

        final DataStream<OUT> result =
                partials.transform(
                        "aggregate", outType, new AggregateOperator<>(cleaned, accType));
        // The final merge must see every partial accumulator, so it runs in a single subtask.
        result.getTransformation().setParallelism(1);
        return result;
    }

    /**
     * Samples up to {@code numSamples} records from a bounded stream in two rounds: every
     * parallel subtask first keeps a bounded local sample, then a single subtask samples from
     * the union of those local samples. The result is redistributed to the input's parallelism.
     *
     * @param input the stream to sample from
     * @param numSamples maximum number of records in the result
     * @param randomSeed seed handed to both sampling rounds
     * @param <T> record type
     * @return a stream containing the sampled records
     */
    public static <T> DataStream<T> sample(DataStream<T> input, int numSamples, long randomSeed) {
        final int parallelism = input.getParallelism();
        final TypeInformation<T> recordType = input.getType();

        // NOTE(review): the extra "+ parallelism" slack presumably covers the uneven record
        // distribution that rebalance() may leave between subtasks; confirm.
        final int localCap = Math.min(numSamples / parallelism + parallelism, numSamples);

        final SingleOutputStreamOperator<T> localSamples =
                input.rebalance()
                        .transform(
                                "firstRoundSampling",
                                recordType,
                                new SamplingOperator<>(localCap, randomSeed))
                        .setParallelism(parallelism);

        final SingleOutputStreamOperator<T> globalSample =
                localSamples
                        .transform(
                                "secondRoundSampling",
                                recordType,
                                new SamplingOperator<>(numSamples, randomSeed))
                        .setParallelism(1);

        // Identity map whose only purpose is to restore the original parallelism.
        return globalSample.map(x -> x, recordType).setParallelism(parallelism);
    }

    /**
     * Declares operator-scope managed memory for the transformation behind {@code dataStream}.
     * Does nothing when {@code memoryBytes} is not positive.
     *
     * @param dataStream the stream whose transformation receives the declaration
     * @param memoryBytes requested amount in bytes; converted to whole mebibytes, at least 1
     * @param <T> record type
     * @throws TableException if a weight was already declared for this transformation
     */
    public static <T> void setManagedMemoryWeight(DataStream<T> dataStream, long memoryBytes) {
        if (memoryBytes <= 0L) {
            return;
        }
        // Shifting right by 20 bits converts bytes to mebibytes; never declare less than 1.
        final int weightInMebibyte = Math.max(1, (int) (memoryBytes >> 20));
        final Optional<Integer> previousWeight =
                dataStream
                        .getTransformation()
                        .declareManagedMemoryUseCaseAtOperatorScope(
                                ManagedMemoryUseCase.OPERATOR, weightInMebibyte);
        if (previousWeight.isPresent()) {
            throw new TableException("Managed memory weight has been set, this should not happen.");
        }
    }

    /**
     * Groups all records of the input into non-keyed windows described by {@code windows} and
     * applies {@code function} to every window.
     *
     * @param input the stream to window
     * @param windows the window strategy
     * @param function the function applied to each window
     * @param <IN> input record type
     * @param <OUT> output record type
     * @param <W> window type
     * @return the stream of records produced by {@code function}
     */
    public static <IN, OUT, W extends Window> SingleOutputStreamOperator<OUT> windowAllAndProcess(
            DataStream<IN> input,
            Windows windows,
            ProcessAllWindowFunction<IN, OUT, W> function) {
        final ProcessAllWindowFunction<IN, OUT, W> cleaned =
                input.getExecutionEnvironment().clean(function);
        final AllWindowedStream<IN, W> windowed =
                DataStreamUtils.getAllWindowedStream(input, windows);
        return windowed.process(cleaned);
    }

    /**
     * Groups all records of the input into non-keyed windows described by {@code windows} and
     * applies {@code function} to every window, using an explicitly supplied output type.
     *
     * @param input the stream to window
     * @param windows the window strategy
     * @param function the function applied to each window
     * @param outType type information of the produced records
     * @param <IN> input record type
     * @param <OUT> output record type
     * @param <W> window type
     * @return the stream of records produced by {@code function}
     */
    public static <IN, OUT, W extends Window> SingleOutputStreamOperator<OUT> windowAllAndProcess(
            DataStream<IN> input,
            Windows windows,
            ProcessAllWindowFunction<IN, OUT, W> function,
            TypeInformation<OUT> outType) {
        final ProcessAllWindowFunction<IN, OUT, W> cleaned =
                input.getExecutionEnvironment().clean(function);
        final AllWindowedStream<IN, W> windowed =
                DataStreamUtils.getAllWindowedStream(input, windows);
        return windowed.process(cleaned, outType);
    }

    /**
     * Co-groups two streams on a common key using {@code CoGroupOperator}. The operator runs
     * with the larger of the two input parallelisms and gets a managed memory declaration.
     *
     * @param input1 the first input
     * @param input2 the second input
     * @param keySelector1 extracts the key from records of {@code input1}
     * @param keySelector2 extracts the key from records of {@code input2}
     * @param outTypeInformation type information of the produced records
     * @param func the function applied to the records sharing a key
     * @param <IN1> record type of the first input
     * @param <IN2> record type of the second input
     * @param <KEY> key type
     * @param <OUT> output record type
     * @return the stream of records produced by {@code func}
     */
    public static <IN1, IN2, KEY extends Serializable, OUT> DataStream<OUT> coGroup(
            DataStream<IN1> input1,
            DataStream<IN2> input2,
            KeySelector<IN1, KEY> keySelector1,
            KeySelector<IN2, KEY> keySelector2,
            TypeInformation<OUT> outTypeInformation,
            CoGroupFunction<IN1, IN2, OUT> func) {
        final CoGroupFunction<IN1, IN2, OUT> cleaned =
                input1.getExecutionEnvironment().clean(func);

        final SingleOutputStreamOperator<OUT> coGrouped =
                input1.connect(input2)
                        .keyBy(keySelector1, keySelector2)
                        .transform(
                                "CoGroupOperator",
                                outTypeInformation,
                                new CoGroupOperator<>(cleaned));
        final DataStream<OUT> result =
                coGrouped.setParallelism(
                        Math.max(input1.getParallelism(), input2.getParallelism()));

        // 100 bytes is below one mebibyte, so this ends up declaring the minimum weight of 1.
        DataStreamUtils.setManagedMemoryWeight(result, 100L);
        return result;
    }

    /**
     * Turns the input into an {@link AllWindowedStream} according to the given window strategy.
     * {@link CountTumblingWindows} map to count windows; every other strategy is translated to
     * a time-based window assigner by {@code getDataStreamTimeWindowAssigner}.
     *
     * <p>The concrete window type is fixed by the strategy ({@code GlobalWindow} for count
     * windows, {@code TimeWindow} otherwise), while callers choose {@code W} through the window
     * function they pass along. The two cannot be related statically, hence the unchecked
     * casts; without them the method does not type-check.
     *
     * @param input the stream to window
     * @param windows the window strategy
     * @param <IN> record type
     * @param <W> window type expected by the caller
     * @return the windowed stream
     * @throws UnsupportedOperationException if the strategy is not a supported subclass
     */
    @SuppressWarnings("unchecked")
    private static <IN, W extends Window> AllWindowedStream<IN, W> getAllWindowedStream(
            DataStream<IN> input, Windows windows) {
        if (windows instanceof CountTumblingWindows) {
            long countWindowSize = ((CountTumblingWindows) windows).getSize();
            return (AllWindowedStream<IN, W>) input.countWindowAll(countWindowSize);
        }
        return (AllWindowedStream<IN, W>)
                input.windowAll(DataStreamUtils.getDataStreamTimeWindowAssigner(windows));
    }

    /**
     * Translates a Flink ML {@link Windows} description into the matching DataStream window
     * assigner.
     *
     * @param windows the window strategy; global, tumbling and session windows are supported
     * @return the window assigner for the strategy
     * @throws UnsupportedOperationException if the strategy is none of the supported subclasses
     */
    private static WindowAssigner<Object, TimeWindow> getDataStreamTimeWindowAssigner(
            Windows windows) {
        if (windows instanceof GlobalWindows) {
            // A single window covering the whole input.
            return EndOfStreamWindows.get();
        } else if (windows instanceof EventTimeTumblingWindows) {
            EventTimeTumblingWindows tumbling = (EventTimeTumblingWindows) windows;
            return TumblingEventTimeWindows.of(
                    DataStreamUtils.getStreamWindowTime(tumbling.getSize()));
        } else if (windows instanceof ProcessingTimeTumblingWindows) {
            ProcessingTimeTumblingWindows tumbling = (ProcessingTimeTumblingWindows) windows;
            return TumblingProcessingTimeWindows.of(
                    DataStreamUtils.getStreamWindowTime(tumbling.getSize()));
        } else if (windows instanceof org.apache.flink.ml.common.window.EventTimeSessionWindows) {
            org.apache.flink.ml.common.window.EventTimeSessionWindows session =
                    (org.apache.flink.ml.common.window.EventTimeSessionWindows) windows;
            return EventTimeSessionWindows.withGap(
                    DataStreamUtils.getStreamWindowTime(session.getGap()));
        } else if (windows
                instanceof org.apache.flink.ml.common.window.ProcessingTimeSessionWindows) {
            org.apache.flink.ml.common.window.ProcessingTimeSessionWindows session =
                    (org.apache.flink.ml.common.window.ProcessingTimeSessionWindows) windows;
            return ProcessingTimeSessionWindows.withGap(
                    DataStreamUtils.getStreamWindowTime(session.getGap()));
        }
        String message =
                String.format("Unsupported Windows subclass: %s", windows.getClass().getName());
        throw new UnsupportedOperationException(message);
    }

    /**
     * Converts a common-API {@code Time} into the streaming-API {@link Time} with the same size
     * and unit.
     *
     * @param time the duration to convert
     * @return the equivalent streaming window time
     */
    private static Time getStreamWindowTime(org.apache.flink.api.common.time.Time time) {
        final long size = time.getSize();
        final TimeUnit unit = time.getUnit();
        return Time.of(size, unit);
    }

    /**
     * Cuts the input into global batches of {@code batchSize} records and spreads every batch
     * over the downstream subtasks.
     *
     * <p>Each count window of {@code batchSize} records is collected into one array, split into
     * {@code downStreamParallelism} contiguous chunks and routed so that chunk {@code i} goes
     * to partition {@code i}.
     *
     * <p>NOTE(review): the chunk index is used as the partition index without a modulo, so the
     * consuming operator presumably must run with parallelism {@code downStreamParallelism};
     * confirm against callers.
     *
     * @param inputData the stream to batch
     * @param downStreamParallelism number of chunks every global batch is split into
     * @param batchSize number of records per global batch
     * @param <T> record type
     * @return a stream of per-partition chunks
     */
    public static <T> DataStream<T[]> generateBatchData(
            DataStream<T> inputData, int downStreamParallelism, int batchSize) {
        return inputData
                .countWindowAll(batchSize)
                .apply(new GlobalBatchCreator<>())
                .flatMap(new GlobalBatchSplitter<>(downStreamParallelism))
                // The chunk id doubles as the target partition index.
                .partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f0)
                // Kept as an anonymous class rather than a lambda so that the generic
                // signature stays available to Flink's type extraction.
                .map(
                        new MapFunction<Tuple2<Integer, T[]>, T[]>() {
                            public T[] map(Tuple2<Integer, T[]> chunk) throws Exception {
                                return chunk.f1;
                            }
                        });
    }

    /**
     * Operator that keeps at most {@code numSamples} records of its bounded input using
     * reservoir sampling and emits them once the input has ended.
     *
     * <p>The current sample and the number of records seen so far are stored in operator list
     * state so that sampling can continue after a restore.
     *
     * @param <T> record type
     */
    private static class SamplingOperator<T> extends AbstractStreamOperator<T>
            implements OneInputStreamOperator<T, T>, BoundedOneInput {
        /** Maximum number of records retained. */
        private final int numSamples;

        /** Source of randomness, seeded in the constructor. */
        private final Random random;

        private ListState<T> samplesState;
        /** The records currently retained. */
        private List<T> samples;

        private ListState<Integer> countState;
        /** Number of records processed so far. */
        private int count;

        SamplingOperator(int numSamples, long randomSeed) {
            this.numSamples = numSamples;
            this.random = new Random(randomSeed);
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);

            TypeSerializer<T> recordSerializer =
                    getOperatorConfig().getTypeSerializerIn(0, getClass().getClassLoader());
            ListStateDescriptor<T> samplesDescriptor =
                    new ListStateDescriptor<>("samplesState", recordSerializer);
            samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
            samples = new ArrayList<>(numSamples);
            samplesState.get().forEach(samples::add);

            ListStateDescriptor<Integer> countDescriptor =
                    new ListStateDescriptor<>("countState", IntSerializer.INSTANCE);
            countState = context.getOperatorStateStore().getListState(countDescriptor);
            // An empty state means nothing was restored, so counting starts from zero.
            Iterator<Integer> countIterator = countState.get().iterator();
            count = countIterator.hasNext() ? countIterator.next() : 0;
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            samplesState.update(samples);
            countState.update(Collections.singletonList(count));
        }

        public void processElement(StreamRecord<T> streamRecord) throws Exception {
            T value = streamRecord.getValue();
            count++;
            if (samples.size() < numSamples) {
                // The reservoir is not full yet: keep every record.
                samples.add(value);
                return;
            }
            // Replace a random slot with probability numSamples / count.
            int index = random.nextInt(count);
            if (index < numSamples) {
                samples.set(index, value);
            }
        }

        public void endInput() throws Exception {
            for (T sample : samples) {
                output.collect(new StreamRecord<>(sample));
            }
        }
    }

    /**
     * Splits an array into {@code downStreamParallelism} contiguous chunks whose lengths differ
     * by at most one, and emits every chunk tagged with its index.
     *
     * @param <T> element type
     */
    private static class GlobalBatchSplitter<T>
            implements FlatMapFunction<T[], Tuple2<Integer, T[]>> {
        /** Number of chunks produced for every input array. */
        private final int downStreamParallelism;

        public GlobalBatchSplitter(int downStreamParallelism) {
            this.downStreamParallelism = downStreamParallelism;
        }

        /**
         * Emits exactly {@code downStreamParallelism} chunks. The first {@code values.length %
         * downStreamParallelism} chunks carry one extra element; chunks may be empty when the
         * array is shorter than the number of chunks.
         *
         * @param values the array to split
         * @param collector receives {@code (chunkIndex, chunk)} pairs in index order
         */
        public void flatMap(T[] values, Collector<Tuple2<Integer, T[]>> collector) {
            int div = values.length / downStreamParallelism;
            int mod = values.length % downStreamParallelism;

            int offset = 0;
            for (int i = 0; i < downStreamParallelism; i++) {
                // The remainder is handed out one element at a time to the leading chunks.
                int size = i < mod ? div + 1 : div;
                collector.collect(Tuple2.of(i, Arrays.copyOfRange(values, offset, offset + size)));
                offset += size;
            }
        }
    }

    /**
     * Window function that collects all records of a window into a single array.
     *
     * @param <T> record type
     */
    private static class GlobalBatchCreator<T> implements AllWindowFunction<T, T[], GlobalWindow> {
        private GlobalBatchCreator() {
        }

        /**
         * Emits one array holding the window's records in iteration order.
         *
         * @param timeWindow the window being evaluated (unused)
         * @param iterable the records of the window
         * @param collector receives the array
         */
        @SuppressWarnings("unchecked")
        public void apply(GlobalWindow timeWindow, Iterable<T> iterable, Collector<T[]> collector) {
            List<T> points = IteratorUtils.toList(iterable.iterator());
            // The element type is erased at runtime, so the array is created as Object[];
            // the cast only satisfies the compiler.
            collector.collect(points.toArray((T[]) new Object[0]));
        }
    }

    /**
     * Second stage of {@code aggregate}: merges incoming accumulators into one and emits the
     * function's result for it when the input ends.
     *
     * @param <IN> type of the records originally aggregated
     * @param <ACC> accumulator type, which is also this operator's input type
     * @param <OUT> result type
     */
    private static class AggregateOperator<IN, ACC, OUT>
            extends AbstractUdfStreamOperator<OUT, AggregateFunction<IN, ACC, OUT>>
            implements OneInputStreamOperator<ACC, OUT>, BoundedOneInput {
        private final TypeInformation<ACC> accType;

        /** Merged accumulator so far; {@code null} until the first accumulator arrives. */
        private ACC acc;

        private ListState<ACC> accState;

        public AggregateOperator(
                AggregateFunction<IN, ACC, OUT> userFunction, TypeInformation<ACC> accType) {
            super(userFunction);
            this.accType = accType;
        }

        public void endInput() {
            output.collect(new StreamRecord<>(userFunction.getResult(acc)));
        }

        public void processElement(StreamRecord<ACC> streamRecord) throws Exception {
            ACC incoming = streamRecord.getValue();
            // The first accumulator is adopted as-is; later ones are merged into it.
            acc = acc == null ? incoming : userFunction.merge(incoming, acc);
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            accState =
                    context.getOperatorStateStore()
                            .getListState(new ListStateDescriptor<>("accState", accType));
            acc = OperatorStateUtils.getUniqueElement(accState, "accState").orElse(null);
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            accState.clear();
            if (acc != null) {
                accState.add(acc);
            }
        }
    }

    /**
     * First stage of {@code aggregate}: folds every input record into an accumulator and emits
     * that accumulator when the input ends. An accumulator is emitted even if no record was
     * received.
     *
     * @param <IN> input record type
     * @param <ACC> accumulator type, which is also this operator's output type
     * @param <OUT> result type of the aggregate function (not produced by this stage)
     */
    private static class PartialAggregateOperator<IN, ACC, OUT>
            extends AbstractUdfStreamOperator<ACC, AggregateFunction<IN, ACC, OUT>>
            implements OneInputStreamOperator<IN, ACC>, BoundedOneInput {
        private final TypeInformation<ACC> accType;

        /** Accumulator of all records seen so far. */
        private ACC acc;

        private ListState<ACC> accState;

        public PartialAggregateOperator(
                AggregateFunction<IN, ACC, OUT> userFunction, TypeInformation<ACC> accType) {
            super(userFunction);
            this.accType = accType;
        }

        public void endInput() {
            output.collect(new StreamRecord<>(acc));
        }

        public void processElement(StreamRecord<IN> streamRecord) throws Exception {
            acc = userFunction.add(streamRecord.getValue(), acc);
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            accState =
                    context.getOperatorStateStore()
                            .getListState(new ListStateDescriptor<>("accState", accType));
            // Only build a fresh accumulator when nothing was restored.
            acc =
                    OperatorStateUtils.getUniqueElement(accState, "accState")
                            .orElseGet(userFunction::createAccumulator);
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            accState.clear();
            accState.add(acc);
        }
    }

    /**
     * Reduces the records of every key into a single value held in keyed state and emits that
     * value from an event-time timer registered at {@code Long.MAX_VALUE}.
     *
     * <p>NOTE(review): the timer presumably fires when the final watermark of a bounded input
     * arrives; confirm against the runtime's end-of-input behaviour.
     *
     * @param <IN> record type
     * @param <KEY> key type
     */
    private static class KeyedReduceOperator<IN, KEY>
            extends AbstractUdfStreamOperator<IN, ReduceFunction<IN>>
            implements OneInputStreamOperator<IN, IN>, Triggerable<KEY, VoidNamespace> {
        private static final long serialVersionUID = 1L;
        private static final String STATE_NAME = "_op_state";

        /** Reduced value of the current key. */
        private transient ValueState<IN> values;

        private final TypeSerializer<IN> serializer;

        /** Created in {@link #open()}; a runtime-only handle, hence transient like the state. */
        private transient InternalTimerService<VoidNamespace> timerService;

        public KeyedReduceOperator(ReduceFunction<IN> reducer, TypeSerializer<IN> serializer) {
            super(reducer);
            this.serializer = serializer;
        }

        public void open() throws Exception {
            super.open();
            ValueStateDescriptor<IN> stateId = new ValueStateDescriptor<>(STATE_NAME, serializer);
            values = getPartitionedState(stateId);
            timerService =
                    getInternalTimerService("end-key-timers", new VoidNamespaceSerializer(), this);
        }

        public void processElement(StreamRecord<IN> element) throws Exception {
            IN value = element.getValue();
            IN currentValue = values.value();
            if (currentValue == null) {
                // First record of this key: schedule the emission of its final value.
                timerService.registerEventTimeTimer(VoidNamespace.INSTANCE, Long.MAX_VALUE);
            } else {
                value = userFunction.reduce(currentValue, value);
            }
            values.update(value);
        }

        public void onEventTime(InternalTimer<KEY, VoidNamespace> timer) throws Exception {
            IN currentValue = values.value();
            if (currentValue != null) {
                output.collect(new StreamRecord<>(currentValue, Long.MAX_VALUE));
            }
        }

        public void onProcessingTime(InternalTimer<KEY, VoidNamespace> timer) throws Exception {
            // Only event-time timers are registered; nothing to do here.
        }
    }

    /**
     * Reduces all records received by one subtask into a single value and emits it when the
     * input ends. Nothing is emitted if no record was received.
     *
     * @param <T> record type
     */
    private static class ReduceOperator<T> extends AbstractUdfStreamOperator<T, ReduceFunction<T>>
            implements OneInputStreamOperator<T, T>, BoundedOneInput {
        /** Reduced value so far; {@code null} until the first record arrives. */
        private T result;

        private ListState<T> state;

        public ReduceOperator(ReduceFunction<T> userFunction) {
            super(userFunction);
        }

        public void endInput() {
            if (result != null) {
                output.collect(new StreamRecord<>(result));
            }
        }

        public void processElement(StreamRecord<T> streamRecord) throws Exception {
            T incoming = streamRecord.getValue();
            // The first record is adopted as-is; later ones are reduced into the running value.
            result = result == null ? incoming : userFunction.reduce(incoming, result);
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            TypeSerializer<T> recordSerializer =
                    getOperatorConfig().getTypeSerializerIn(0, getClass().getClassLoader());
            state =
                    context.getOperatorStateStore()
                            .getListState(new ListStateDescriptor<>("state", recordSerializer));
            result = OperatorStateUtils.getUniqueElement(state, "state").orElse(null);
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            state.clear();
            if (result != null) {
                state.add(result);
            }
        }
    }

    /**
     * Buffers every record of a subtask in a {@link ListStateWithCache} and hands the complete
     * buffer to the {@link MapPartitionFunction} once the input has ended.
     *
     * @param <IN> input record type
     * @param <OUT> output record type
     */
    private static class MapPartitionOperator<IN, OUT>
            extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, OUT>>
            implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
        /** All records received so far. */
        private ListStateWithCache<IN> valuesState;

        public MapPartitionOperator(MapPartitionFunction<IN, OUT> mapPartitionFunc) {
            super(mapPartitionFunc);
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            TypeSerializer<IN> inputSerializer =
                    getOperatorConfig().getTypeSerializerIn(0, getClass().getClassLoader());
            valuesState =
                    new ListStateWithCache<>(
                            inputSerializer,
                            getContainingTask(),
                            getRuntimeContext(),
                            context,
                            config.getOperatorID());
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            valuesState.snapshotState(context);
        }

        public void processElement(StreamRecord<IN> input) throws Exception {
            IN value = input.getValue();
            valuesState.add(value);
        }

        public void endInput() throws Exception {
            Iterable<IN> buffered = valuesState.get();
            userFunction.mapPartition(buffered, new TimestampedCollector<>(output));
            // The buffer is no longer needed once the function has consumed it.
            valuesState.clear();
        }
    }
}

