/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AMorphingMMColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.FORUtil;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

/**
 * Column group of compression type {@code DDCFOR}.
 *
 * <p>Every row stores one index into the dictionary (held in {@link #_data}) and every column of the group has one
 * reference value (held in {@link #_reference}). The value of a cell is the dictionary value selected by the row
 * index plus the reference of the column, see {@link #getIdx(int, int)}.
 */
public class ColGroupDDCFOR
extends AMorphingMMColGroup {
    private static final long serialVersionUID = -5769772089913918987L;

    /** Row to dictionary-entry index mapping. */
    protected AMapToData _data;

    /** Per-column reference value, added to the dictionary value of a cell when it is read. */
    protected double[] _reference;

    /**
     * Constructor that only forwards the number of rows to the super class; {@link #_data} and {@link #_reference}
     * are left unset (they are populated by {@link #readFields(DataInput)}).
     *
     * @param numRows number of rows in the column group
     */
    protected ColGroupDDCFOR(int numRows) {
        super(numRows);
    }

    /**
     * Full constructor, use {@link #create} instead so that degenerate inputs are mapped to simpler groups.
     *
     * @param colIndexes   column indexes covered by this group
     * @param numRows      number of rows in the group
     * @param dict         dictionary of values (excluding the reference)
     * @param reference    per-column reference values
     * @param data         row to dictionary index mapping
     * @param cachedCounts cached counts per dictionary entry, forwarded to the super class
     * @throws DMLCompressionException if the unique count of {@code data} differs from the number of dictionary values
     */
    private ColGroupDDCFOR(int[] colIndexes, int numRows, ADictionary dict, double[] reference, AMapToData data, int[] cachedCounts) {
        super(colIndexes, numRows, dict, cachedCounts);
        if (data.getUnique() != dict.getNumberOfValues(colIndexes.length)) {
            throw new DMLCompressionException("Invalid construction of DDC group " + data.getUnique() + " vs. " + dict.getNumberOfValues(colIndexes.length));
        }
        _zeros = false;
        _data = data;
        _reference = reference;
    }

    /**
     * Factory that picks the simplest column group able to represent the given content.
     *
     * <ul>
     * <li>no dictionary and an all-zero reference: {@link ColGroupEmpty}</li>
     * <li>no dictionary and a non-zero reference: {@link ColGroupConst} built from the reference</li>
     * <li>a dictionary and an all-zero reference: {@link ColGroupDDC}</li>
     * <li>otherwise: {@link ColGroupDDCFOR}</li>
     * </ul>
     *
     * @param colIndexes   column indexes covered by the group
     * @param numRows      number of rows in the group
     * @param dict         dictionary of values, may be {@code null}
     * @param data         row to dictionary index mapping
     * @param cachedCounts cached counts per dictionary entry
     * @param reference    per-column reference values
     * @return the created column group
     */
    protected static AColGroup create(int[] colIndexes, int numRows, ADictionary dict, AMapToData data, int[] cachedCounts, double[] reference) {
        final boolean allZero = FORUtil.allZero(reference);
        if (dict == null && allZero) {
            return new ColGroupEmpty(colIndexes);
        }
        if (dict == null) {
            return ColGroupConst.create(colIndexes, reference);
        }
        if (allZero) {
            return ColGroupDDC.create(colIndexes, numRows, dict, data, cachedCounts);
        }
        return new ColGroupDDCFOR(colIndexes, numRows, dict, reference, data, cachedCounts);
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.DDCFOR;
    }

    /**
     * Reads a single cell: the dictionary value selected by the row's index plus the column's reference.
     *
     * @param r      row index
     * @param colIdx column index local to this group
     * @return the cell value
     */
    @Override
    public double getIdx(int r, int colIdx) {
        return _dict.getValue(_data.getIndex(r) * _colIndexes.length + colIdx) + _reference[colIdx];
    }

    /** Adds, for each row in {@code [rl, ru)}, the pre-aggregated value of the row's dictionary entry to {@code c}. */
    @Override
    protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
        for (int rix = rl; rix < ru; ++rix) {
            c[rix] += preAgg[_data.getIndex(rix)];
        }
    }

    /** Combines, for each row in {@code [rl, ru)}, {@code c} with the pre-aggregated value through {@code builtin}. */
    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
        for (int i = rl; i < ru; ++i) {
            c[i] = builtin.execute(c[i], preAgg[_data.getIndex(i)]);
        }
    }

    @Override
    public int[] getCounts(int[] counts) {
        return _data.getCounts(counts);
    }

    /**
     * Left matrix multiplication without pre-aggregation, dispatching on the number of columns in the group.
     *
     * <p>Only dictionary values are multiplied here; {@link #_reference} is not read by this code path.
     *
     * @param matrix left hand side matrix
     * @param result dense output matrix that is accumulated into
     * @param rl     first row (inclusive) of {@code matrix} to process
     * @param ru     last row (exclusive) of {@code matrix} to process
     * @param cl     first column (inclusive) of {@code matrix} to process
     * @param cu     last column (exclusive) of {@code matrix} to process
     */
    @Override
    public void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        if (_colIndexes.length == 1) {
            leftMultByMatrixNoPreAggSingleCol(matrix, result, rl, ru, cl, cu);
        } else {
            lmMatrixNoPreAggMultiCol(matrix, result, rl, ru, cl, cu);
        }
    }

    /** Single-column variant: dispatches on the storage format (sparse/dense) of the left matrix. */
    private void leftMultByMatrixNoPreAggSingleCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        final double[] retV = result.getDenseBlockValues();
        final int nColM = matrix.getNumColumns();
        final int nColRet = result.getNumColumns();
        final double[] dictVals = _dict.getValues();
        if (matrix.isInSparseFormat()) {
            lmSparseMatrixNoPreAggSingleCol(matrix.getSparseBlock(), nColM, retV, nColRet, dictVals, rl, ru, cl, cu);
        } else {
            lmDenseMatrixNoPreAggSingleCol(matrix.getDenseBlockValues(), nColM, retV, nColRet, dictVals, rl, ru, cl, cu);
        }
    }

    /**
     * Sparse left matrix, single output column. Each non-zero at column index {@code aix[i]} of the left matrix is
     * multiplied with the dictionary value of row {@code aix[i]} of this group.
     *
     * <p>Note: {@code cl} and {@code cu} are not consulted here; every non-zero of the row is processed.
     */
    private void lmSparseMatrixNoPreAggSingleCol(SparseBlock sb, int nColM, double[] retV, int nColRet, double[] vals, int rl, int ru, int cl, int cu) {
        final int colOut = _colIndexes[0];
        for (int r = rl; r < ru; ++r) {
            if (sb.isEmpty(r)) {
                continue;
            }
            final int apos = sb.pos(r);
            final int alen = sb.size(r) + apos;
            final int[] aix = sb.indexes(r);
            final double[] aval = sb.values(r);
            final int offR = r * nColRet;
            for (int i = apos; i < alen; ++i) {
                retV[offR + colOut] += aval[i] * vals[_data.getIndex(aix[i])];
            }
        }
    }

    /**
     * Dense left matrix, single output column. Column {@code c} of the left matrix is paired with row {@code c} of
     * this group, so the dictionary entry is looked up via {@code c} (not via the left matrix row {@code r}),
     * consistent with the sparse and multi-column variants.
     */
    private void lmDenseMatrixNoPreAggSingleCol(double[] mV, int nColM, double[] retV, int nColRet, double[] vals, int rl, int ru, int cl, int cu) {
        final int colOut = _colIndexes[0];
        for (int r = rl; r < ru; ++r) {
            final int offL = r * nColM;
            final int offR = r * nColRet;
            for (int c = cl; c < cu; ++c) {
                retV[offR + colOut] += mV[offL + c] * vals[_data.getIndex(c)];
            }
        }
    }

    /** Multi-column variant: dispatches on the storage format (sparse/dense) of the left matrix. */
    private void lmMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        if (matrix.isInSparseFormat()) {
            lmSparseMatrixNoPreAggMultiCol(matrix, result, rl, ru, cl, cu);
        } else {
            lmDenseMatrixNoPreAggMultiCol(matrix, result, rl, ru, cl, cu);
        }
    }

    /**
     * Sparse left matrix, multiple output columns; delegates the per-entry scaling to the dictionary.
     *
     * <p>Note: {@code cl} and {@code cu} are not consulted here; every non-zero of the row is processed.
     */
    private void lmSparseMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        final double[] retV = result.getDenseBlockValues();
        final int nColRet = result.getNumColumns();
        final SparseBlock sb = matrix.getSparseBlock();
        for (int r = rl; r < ru; ++r) {
            if (sb.isEmpty(r)) {
                continue;
            }
            final int apos = sb.pos(r);
            final int alen = sb.size(r) + apos;
            final int[] aix = sb.indexes(r);
            final double[] aval = sb.values(r);
            final int offR = r * nColRet;
            for (int i = apos; i < alen; ++i) {
                _dict.multiplyScalar(aval[i], retV, offR, _data.getIndex(aix[i]), _colIndexes);
            }
        }
    }

    /** Dense left matrix, multiple output columns; delegates the per-entry scaling to the dictionary. */
    private void lmDenseMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        final double[] retV = result.getDenseBlockValues();
        final int nColM = matrix.getNumColumns();
        final int nColRet = result.getNumColumns();
        final double[] mV = matrix.getDenseBlockValues();
        for (int r = rl; r < ru; ++r) {
            final int offL = r * nColM;
            final int offR = r * nColRet;
            for (int c = cl; c < cu; ++c) {
                _dict.multiplyScalar(mV[offL + c], retV, offR, _data.getIndex(c), _colIndexes);
            }
        }
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.DDCFOR;
    }

    /**
     * @return the super class estimate plus the size reported by the index map plus 8 bytes per column (one double
     *         per reference value)
     */
    @Override
    public long estimateInMemorySize() {
        long size = super.estimateInMemorySize();
        size += _data.getInMemorySize();
        // widen before multiplying so the byte count is computed in long arithmetic
        size += 8L * _colIndexes.length;
        return size;
    }

    /**
     * Applies a scalar operator. The reference is always transformed; the dictionary is reused for Plus/Minus,
     * transformed independently for Multiply/Divide, and transformed together with the reference otherwise.
     *
     * @param op scalar operator to apply
     * @return the resulting column group
     */
    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        final double[] newRef = new double[_reference.length];
        for (int i = 0; i < _reference.length; ++i) {
            newRef[i] = op.executeScalar(_reference[i]);
        }
        // NOTE(review): if op can carry the scalar on the left (s - x, s / x), the two shortcuts below do not look
        // equivalent to applying op on (dictionary value + reference) -- confirm which operators reach this method.
        if (op.fn instanceof Plus || op.fn instanceof Minus) {
            // only the reference is shifted, the dictionary is kept as is
            return ColGroupDDCFOR.create(_colIndexes, _numRows, _dict, _data, getCachedCounts(), newRef);
        }
        if (op.fn instanceof Multiply || op.fn instanceof Divide) {
            // dictionary and reference are scaled independently
            final ADictionary scaledDict = _dict.applyScalarOp(op);
            return ColGroupDDCFOR.create(_colIndexes, _numRows, scaledDict, _data, getCachedCounts(), newRef);
        }
        // generic case: the dictionary operation is given both the old and the new reference
        final ADictionary newDict = _dict.applyScalarOpWithReference(op, _reference, newRef);
        return ColGroupDDCFOR.create(_colIndexes, _numRows, newDict, _data, getCachedCounts(), newRef);
    }

    /**
     * Applies a unary operator to the reference and to the dictionary (with knowledge of old and new reference).
     *
     * @param op unary operator to apply
     * @return the resulting column group
     */
    @Override
    public AColGroup unaryOperation(UnaryOperator op) {
        final double[] newRef = FORUtil.unaryOperator(op, _reference);
        final ADictionary newDict = _dict.applyUnaryOpWithReference(op, _reference, newRef);
        return ColGroupDDCFOR.create(_colIndexes, _numRows, newDict, _data, getCachedCounts(), newRef);
    }

    /**
     * Applies a binary row operation with the row vector {@code v} as the left operand.
     *
     * @param op        binary operator
     * @param v         row vector, indexed by global column index
     * @param isRowSafe not used by this implementation
     * @return the resulting column group
     */
    @Override
    public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
        final double[] newRef = new double[_reference.length];
        for (int i = 0; i < _reference.length; ++i) {
            newRef[i] = op.fn.execute(v[_colIndexes[i]], _reference[i]);
        }
        // NOTE(review): with v on the left, v - (d + r) equals (v - r) - d, yet the Minus shortcut keeps the
        // dictionary unchanged; similarly v / (d + r) is not a sum of two independent divisions. Verify whether
        // Minus/Divide should use the generic *WithReference path here.
        if (op.fn instanceof Plus || op.fn instanceof Minus) {
            return ColGroupDDCFOR.create(_colIndexes, _numRows, _dict, _data, getCachedCounts(), newRef);
        }
        if (op.fn instanceof Multiply || op.fn instanceof Divide) {
            final ADictionary scaledDict = _dict.binOpLeft(op, v, _colIndexes);
            return ColGroupDDCFOR.create(_colIndexes, _numRows, scaledDict, _data, getCachedCounts(), newRef);
        }
        final ADictionary newDict = _dict.binOpLeftWithReference(op, v, _colIndexes, _reference, newRef);
        return ColGroupDDCFOR.create(_colIndexes, _numRows, newDict, _data, getCachedCounts(), newRef);
    }

    /**
     * Applies a binary row operation with the row vector {@code v} as the right operand.
     *
     * @param op        binary operator
     * @param v         row vector, indexed by global column index
     * @param isRowSafe not used by this implementation
     * @return the resulting column group
     */
    @Override
    public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) {
        final double[] newRef = new double[_reference.length];
        for (int i = 0; i < _reference.length; ++i) {
            newRef[i] = op.fn.execute(_reference[i], v[_colIndexes[i]]);
        }
        if (op.fn instanceof Plus || op.fn instanceof Minus) {
            // (d + r) op v == d + (r op v): only the reference changes
            return ColGroupDDCFOR.create(_colIndexes, _numRows, _dict, _data, getCachedCounts(), newRef);
        }
        if (op.fn instanceof Multiply || op.fn instanceof Divide) {
            // (d + r) op v == (d op v) + (r op v): dictionary and reference are scaled independently
            final ADictionary scaledDict = _dict.binOpRight(op, v, _colIndexes);
            return ColGroupDDCFOR.create(_colIndexes, _numRows, scaledDict, _data, getCachedCounts(), newRef);
        }
        final ADictionary newDict = _dict.binOpRightWithReference(op, v, _colIndexes, _reference, newRef);
        return ColGroupDDCFOR.create(_colIndexes, _numRows, newDict, _data, getCachedCounts(), newRef);
    }

    /** Writes the super class content, then the index map, then one double per reference value. */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        _data.write(out);
        for (double d : _reference) {
            out.writeDouble(d);
        }
    }

    /** Reads the content in the order produced by {@link #write(DataOutput)}; one reference double per column. */
    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        _data = MapToFactory.readIn(in);
        _reference = new double[_colIndexes.length];
        for (int i = 0; i < _colIndexes.length; ++i) {
            _reference[i] = in.readDouble();
        }
    }

    /**
     * @return the super class size plus the on-disk size of the index map plus 8 bytes per column for the reference
     */
    @Override
    public long getExactSizeOnDisk() {
        long ret = super.getExactSizeOnDisk();
        ret += _data.getExactSizeOnDisk();
        // widen before multiplying so the byte count is computed in long arithmetic
        ret += 8L * _colIndexes.length;
        return ret;
    }

    @Override
    public double getCost(ComputationCostEstimator e, int nRows) {
        final int nVals = getNumValues();
        final int nCols = getNumCols();
        return e.getCost(nRows, nRows, nCols, nVals, _dict.getSparsity());
    }

    /**
     * Replaces {@code pattern} with {@code replace} in the dictionary, taking the reference into account.
     *
     * @param pattern value to search for
     * @param replace value to substitute
     * @return the resulting column group
     * @throws NotImplementedException if {@code pattern} is equal ({@code ==}) to one of the reference values
     */
    @Override
    public AColGroup replace(double pattern, double replace) {
        boolean patternInReference = false;
        for (double d : _reference) {
            if (pattern == d) {
                patternInReference = true;
                break;
            }
        }
        if (patternInReference) {
            throw new NotImplementedException("Not Implemented replace where a value in reference should be replaced");
        }
        final ADictionary newDict = _dict.replaceWithReference(pattern, replace, _reference);
        return ColGroupDDCFOR.create(_colIndexes, _numRows, newDict, _data, getCachedCounts(), _reference);
    }

    @Override
    protected double computeMxx(double c, Builtin builtin) {
        return _dict.aggregateWithReference(c, builtin, _reference, false);
    }

    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        _dict.aggregateColsWithReference(c, builtin, _colIndexes, _reference, false);
    }

    /** Adds the super class sum and then the sum of the reference multiplied by {@code nRows} to {@code c[0]}. */
    @Override
    protected void computeSum(double[] c, int nRows) {
        super.computeSum(c, nRows);
        final double refSum = FORUtil.refSum(_reference);
        c[0] += refSum * (double)nRows;
    }

    /** Adds the super class column sums and then {@code reference[i] * nRows} for every column of the group. */
    @Override
    public void computeColSums(double[] c, int nRows) {
        super.computeColSums(c, nRows);
        for (int i = 0; i < _colIndexes.length; ++i) {
            c[_colIndexes[i]] += _reference[i] * (double)nRows;
        }
    }

    /**
     * Adds the sum of squares to {@code c[0]}: the dictionary contribution (with reference) plus the squared
     * reference sum for the {@code _numRows - _data.size()} rows not covered by the index map.
     */
    @Override
    protected void computeSumSq(double[] c, int nRows) {
        c[0] += _dict.sumSqWithReference(getCounts(), _reference);
        final double refSumSq = FORUtil.refSumSq(_reference);
        c[0] += refSumSq * (double)(_numRows - _data.size());
    }

    /**
     * Adds the column-wise sum of squares to {@code c}. Note that this replaces {@link #_dict} with the result of
     * {@code getMBDict} before aggregating.
     */
    @Override
    protected void computeColSumsSq(double[] c, int nRows) {
        _dict = _dict.getMBDict(_colIndexes.length);
        _dict.colSumSqWithReference(c, getCounts(), _colIndexes, _reference);
        final int uncovered = _numRows - _data.size();
        for (int i = 0; i < _colIndexes.length; ++i) {
            c[_colIndexes[i]] += _reference[i] * _reference[i] * (double)uncovered;
        }
    }

    @Override
    protected double[] preAggSumRows() {
        return _dict.sumAllRowsToDoubleWithReference(_reference);
    }

    @Override
    protected double[] preAggSumSqRows() {
        return _dict.sumAllRowsToDoubleSqWithReference(_reference);
    }

    /** Not supported by this column group; always throws {@link NotImplementedException}. */
    @Override
    protected double[] preAggProductRows() {
        throw new NotImplementedException();
    }

    @Override
    protected double[] preAggBuiltinRows(Builtin builtin) {
        return _dict.aggregateRowsWithReference(builtin, _reference);
    }

    @Override
    protected void computeProduct(double[] c, int nRows) {
        final int count = _numRows - _data.size();
        _dict.productWithReference(c, getCounts(), _reference, count);
    }

    /** Not supported by this column group; always throws {@link NotImplementedException}. */
    @Override
    protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
        throw new NotImplementedException("Not Implemented PFOR");
    }

    /** Not supported by this column group; always throws {@link NotImplementedException}. */
    @Override
    protected void computeColProduct(double[] c, int nRows) {
        throw new NotImplementedException("Not Implemented PFOR");
    }

    /** Slices via the super class and then restricts the reference to the selected column. */
    @Override
    protected AColGroup sliceSingleColumn(int idx) {
        final ColGroupDDCFOR ret = (ColGroupDDCFOR)super.sliceSingleColumn(idx);
        ret._reference = new double[] {_reference[idx]};
        return ret;
    }

    /** Slices via the super class and then restricts the reference to the range {@code [idStart, idEnd)}. */
    @Override
    protected AColGroup sliceMultiColumns(int idStart, int idEnd, int[] outputCols) {
        final ColGroupDDCFOR ret = (ColGroupDDCFOR)super.sliceMultiColumns(idStart, idEnd, outputCols);
        final int len = idEnd - idStart;
        ret._reference = new double[len];
        System.arraycopy(_reference, idStart, ret._reference, 0, len);
        return ret;
    }

    /**
     * Checks whether the group contains {@code pattern}. NaN and infinite patterns are looked up separately in the
     * reference and in the raw dictionary; all other patterns are looked up in the dictionary with the reference.
     *
     * @param pattern value to search for
     * @return {@code true} if the value is contained
     */
    @Override
    public boolean containsValue(double pattern) {
        if (pattern == 0.0 && _zeros) {
            return true;
        }
        if (Double.isNaN(pattern) || Double.isInfinite(pattern)) {
            return FORUtil.containsInfOrNan(pattern, _reference) || _dict.containsValue(pattern);
        }
        return _dict.containsValueWithReference(pattern, _reference);
    }

    /**
     * Computes the number of non-zero cells, capped at the total number of cells of the group.
     *
     * <p>If every reference value is non-zero the total number of cells is returned directly. Otherwise the
     * dictionary count is combined with {@code nRows} for each non-zero reference column.
     *
     * @param nRows number of rows in the group
     * @return number of non-zeros, at most {@code numCols * nRows}
     */
    @Override
    public long getNumberNonZeros(int nRows) {
        int refCount = 0;
        for (double ref : _reference) {
            if (ref != 0.0) {
                ++refCount;
            }
        }
        final long cells = (long)_colIndexes.length * (long)nRows;
        if (refCount == _colIndexes.length) {
            return cells;
        }
        long nnz = _dict.getNumberNonZerosWithReference(getCounts(), _reference, nRows);
        // multiply in long: refCount * nRows can exceed Integer.MAX_VALUE for large groups
        nnz += (long)refCount * (long)nRows;
        return Math.min(cells, nnz);
    }

    /**
     * Moves the reference into {@code constV} (added at the global column positions) and returns a
     * {@link ColGroupDDC} over the same dictionary and index map.
     *
     * @param constV accumulator indexed by global column index
     * @return column group without the reference
     */
    @Override
    public AColGroup extractCommon(double[] constV) {
        for (int i = 0; i < _colIndexes.length; ++i) {
            constV[_colIndexes[i]] += _reference[i];
        }
        return ColGroupDDC.create(_colIndexes, _numRows, _dict, _data, getCounts());
    }

    /**
     * Expands the single column of this group into {@code max} indicator columns, using {@code _reference[0]} as
     * the default value.
     *
     * @param max    number of output columns
     * @param ignore if {@code true}, a non-positive default is tolerated instead of raising an error
     * @param cast   forwarded to the dictionary expansion
     * @param nRows  number of rows of the output group
     * @return the expanded column group
     * @throws DMLRuntimeException if the default is non-positive, the dictionary is non-empty and {@code ignore}
     *                             is {@code false}
     */
    @Override
    public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
        final double def = _reference[0];
        final ADictionary d = _dict.rexpandColsWithReference(max, ignore, cast, def);
        if (d == null) {
            // no dictionary left: the result depends on the default only
            if (def <= 0.0 || def > (double)max) {
                return ColGroupEmpty.create(max);
            }
            final double[] retDef = new double[max];
            retDef[(int)def - 1] = 1.0;
            return ColGroupConst.create(retDef);
        }
        final int[] outCols = Util.genColsIndices(max);
        if (def <= 0.0) {
            if (ignore) {
                return ColGroupDDC.create(outCols, nRows, d, _data, getCachedCounts());
            }
            throw new DMLRuntimeException("Invalid content of zero in rexpand");
        }
        if (def > (double)max) {
            // default falls outside the output range, so there is no reference to keep
            return ColGroupDDC.create(outCols, nRows, d, _data, getCachedCounts());
        }
        final double[] retDef = new double[max];
        retDef[(int)def - 1] = 1.0;
        return ColGroupDDCFOR.create(outCols, nRows, d, _data, getCachedCounts(), retDef);
    }

    /**
     * Computes the central moment from the dictionary using {@code _reference[0]}, then adds the reference itself
     * for the {@code _numRows - _data.size()} rows not covered by the index map.
     */
    @Override
    public CM_COV_Object centralMoment(CMOperator op, int nRows) {
        final CM_COV_Object ret = _dict.centralMomentWithReference(op.fn, getCounts(), _reference[0], nRows);
        final int count = _numRows - _data.size();
        op.fn.execute(ret, _reference[0], count);
        return ret;
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(String.format("\n%15s ", "Data: "));
        sb.append(_data);
        sb.append(String.format("\n%15s", "Reference:"));
        sb.append(Arrays.toString(_reference));
        return sb.toString();
    }
}

