/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.indexes.IIterate;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Combines a map of matrix blocks, keyed by 1-based block indexes, into one matrix block.
 *
 * <p>If every block is a non-overlapping {@link CompressedMatrixBlock}, each column uses the same
 * compression type in every block row, and every block row has the same column-group layout, the
 * column groups are appended directly. The result is a new {@link CompressedMatrixBlock}.
 * Otherwise all blocks are decompressed into a single output block.
 */
public class CLALibCombine {
    protected static final Log LOG = LogFactory.getLog(CLALibCombine.class.getName());

    /**
     * Combines the blocks in {@code m}, discovering the output dimensions from the map itself.
     *
     * <p>The block size is the larger dimension of block (1,1). The row count is the sum of the
     * block row counts down block column 1, and the column count is the sum of the block column
     * counts along block row 1. Both walks stop at the first missing block. NOTE(review): this
     * presumably assumes every non-trailing block is exactly {@code blen} in each dimension.
     *
     * @param m blocks keyed by 1-based block indexes
     * @param k degree of parallelism
     * @return the combined matrix
     * @throws DMLCompressionException if {@code m} has no block at index (1,1)
     */
    public static MatrixBlock combine(Map<MatrixIndexes, MatrixBlock> m, int k) {
        MatrixIndexes lookup = new MatrixIndexes(1L, 1L);
        MatrixBlock b = m.get(lookup);
        if (b == null) {
            throw new DMLCompressionException("Failed combining: no block at index (1,1)");
        }
        int blen = Math.max(b.getNumColumns(), b.getNumRows());
        long rows = 0L;
        while ((b = m.get(lookup)) != null) {
            rows += b.getNumRows();
            lookup.setIndexes(lookup.getRowIndex() + 1L, 1L);
        }
        lookup.setIndexes(1L, 1L);
        long cols = 0L;
        while ((b = m.get(lookup)) != null) {
            cols += b.getNumColumns();
            lookup.setIndexes(1L, lookup.getColumnIndex() + 1L);
        }
        return combine(m, lookup, (int) rows, (int) cols, blen, k);
    }

    /**
     * Combines the blocks in {@code m} into a matrix of known size.
     *
     * @param m    blocks keyed by 1-based block indexes
     * @param rlen number of rows of the output
     * @param clen number of columns of the output
     * @param blen block size used to map block indexes to row/column offsets
     * @param k    degree of parallelism
     * @return the combined matrix
     */
    public static MatrixBlock combine(Map<MatrixIndexes, MatrixBlock> m, int rlen, int clen, int blen, int k) {
        MatrixIndexes lookup = new MatrixIndexes();
        return combine(m, lookup, rlen, clen, blen, k);
    }

    /**
     * Decides whether the blocks can be combined at column-group level, and combines them.
     *
     * <p>Every block must be compressed and non-overlapping, and each output column must have the
     * same compression type in every block row. If any block fails these checks, the method falls
     * back to decompression. Validation now also runs when there is only one block row; previously
     * that case cast blocks without checking them.
     *
     * @param lookup reusable index object; it is mutated
     */
    private static MatrixBlock combine(Map<MatrixIndexes, MatrixBlock> m, MatrixIndexes lookup, int rlen, int clen, int blen, int k) {
        AColGroup.CompressionType[] colTypes = new AColGroup.CompressionType[clen];
        // First block row: validate every block and record the compression type of each column.
        for (int bc = 0; bc * blen < clen; bc++) {
            lookup.setIndexes(1L, bc + 1);
            CompressedMatrixBlock cmb = asCombinable(m.get(lookup));
            if (cmb == null) {
                return combineViaDecompression(m, rlen, clen, blen, k);
            }
            int off = bc * blen;
            for (AColGroup g : cmb.getColGroups()) {
                try {
                    IIterate cols = g.getColIndices().iterator();
                    AColGroup.CompressionType t = g.getCompType();
                    while (cols.hasNext()) {
                        colTypes[cols.next() + off] = t;
                    }
                }
                catch (Exception e) {
                    throw new DMLCompressionException("Failed combining: " + g.toString(), e);
                }
            }
        }
        // Remaining block rows: validate every block and require matching per-column types.
        for (int br = 1; br * blen < rlen; br++) {
            for (int bc = 0; bc * blen < clen; bc++) {
                lookup.setIndexes(br + 1, bc + 1);
                CompressedMatrixBlock cmb = asCombinable(m.get(lookup));
                if (cmb == null) {
                    return combineViaDecompression(m, rlen, clen, blen, k);
                }
                int off = bc * blen;
                for (AColGroup g : cmb.getColGroups()) {
                    IIterate cols = g.getColIndices().iterator();
                    AColGroup.CompressionType t = g.getCompType();
                    while (cols.hasNext()) {
                        if (colTypes[cols.next() + off] != t) {
                            LOG.warn("Not supported different types of column groups to combine.Falling back to decompression of all blocks");
                            return combineViaDecompression(m, rlen, clen, blen, k);
                        }
                    }
                }
            }
        }
        return CombiningColumnGroups(m, lookup, rlen, clen, blen, k);
    }

    /**
     * Checks that a block can take part in a column-group level combine.
     *
     * @param b the block; may be {@code null}
     * @return {@code b} as a compressed block, or {@code null} (after logging a warning) if it is
     *         missing, uncompressed, or overlapping
     */
    private static CompressedMatrixBlock asCombinable(MatrixBlock b) {
        if (!(b instanceof CompressedMatrixBlock)) {
            LOG.warn("Found uncompressed matrix in Map of matrices, this is not supported in combine therefore falling back to decompression");
            return null;
        }
        CompressedMatrixBlock cmb = (CompressedMatrixBlock) b;
        if (cmb.isOverlapping()) {
            LOG.warn("Not supporting overlapping combine yet falling back to decompression");
            return null;
        }
        return cmb;
    }

    /**
     * Fallback: copies every non-null block into a dense output at its block offset.
     *
     * <p>Afterwards the non-zero count is reset to unknown and the sparsity is re-examined.
     * {@code k} is currently unused here.
     */
    private static MatrixBlock combineViaDecompression(Map<MatrixIndexes, MatrixBlock> m, int rlen, int clen, int blen, int k) {
        MatrixBlock out = new MatrixBlock(rlen, clen, false);
        out.allocateDenseBlock();
        for (Map.Entry<MatrixIndexes, MatrixBlock> e : m.entrySet()) {
            MatrixIndexes ix = e.getKey();
            MatrixBlock block = e.getValue();
            if (block == null) {
                continue;
            }
            int rowOffset = (int) (ix.getRowIndex() - 1L) * blen;
            int colOffset = (int) (ix.getColumnIndex() - 1L) * blen;
            block.putInto(out, rowOffset, colOffset, false);
        }
        out.setNonZeros(-1L);
        out.examSparsity(true);
        return out;
    }

    /**
     * Appends the column groups of all block rows, column group by column group.
     *
     * <p>Groups are keyed by their first (shifted) column index, using block row 0 as the
     * reference layout. If a later block row has a group that does not match a block-row-0 group
     * exactly, or leaves a reference group without a counterpart, the method falls back to
     * decompression instead of failing with an NPE or mis-appending. The same fallback is used if
     * any append returns {@code null}.
     *
     * <p>Callers must ensure every block is a non-overlapping {@link CompressedMatrixBlock}.
     */
    private static MatrixBlock CombiningColumnGroups(Map<MatrixIndexes, MatrixBlock> m, MatrixIndexes lookup, int rlen, int clen, int blen, int k) {
        AColGroup[][] finalCols = new AColGroup[clen][];
        int blocksInColumn = (rlen - 1) / blen + 1;
        for (int br = 0; br * blen < rlen; br++) {
            for (int bc = 0; bc * blen < clen; bc++) {
                lookup.setIndexes(br + 1, bc + 1);
                CompressedMatrixBlock cmb = (CompressedMatrixBlock) m.get(lookup);
                for (AColGroup g : cmb.getColGroups()) {
                    AColGroup gc = bc > 0 ? g.shiftColIndices(bc * blen) : g;
                    int c = gc.getColIndices().get(0);
                    if (br == 0) {
                        finalCols[c] = new AColGroup[blocksInColumn];
                    }
                    else if (finalCols[c] == null || !sameColumns(finalCols[c][0], gc)) {
                        // The group layout differs from block row 0, so the parts cannot be appended.
                        LOG.warn("Combining via decompression. Column groups are not aligned across row blocks");
                        return combineViaDecompression(m, rlen, clen, blen, k);
                    }
                    finalCols[c][br] = gc;
                }
            }
        }
        // A null part means some block row has no group matching a block-row-0 group.
        for (AColGroup[] parts : finalCols) {
            if (parts == null) {
                continue;
            }
            for (AColGroup part : parts) {
                if (part == null) {
                    LOG.warn("Combining via decompression. A column group is missing in some row block");
                    return combineViaDecompression(m, rlen, clen, blen, k);
                }
            }
        }
        ExecutorService pool = CommonThreadPool.get(Math.max(Math.min(clen / 500, k), 1));
        List<AColGroup> finalGroups;
        try {
            // The parallel stream is started from inside the pool so that it runs on that pool's workers.
            finalGroups = pool.submit(() -> Arrays.stream(finalCols)
                    .filter(x -> x != null)
                    .parallel()
                    .map(CLALibCombine::combineN)
                    .collect(Collectors.toList())).get();
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("Failed to combine column groups", e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException("Failed to combine column groups", e);
        }
        finally {
            pool.shutdown();
        }
        if (finalGroups.contains(null)) {
            LOG.warn("Combining via decompression. There was a column group that did not append ");
            return combineViaDecompression(m, rlen, clen, blen, k);
        }
        return new CompressedMatrixBlock(rlen, clen, -1L, false, finalGroups);
    }

    /** Returns true if both groups cover exactly the same column indexes, in the same order. */
    private static boolean sameColumns(AColGroup a, AColGroup b) {
        IIterate ia = a.getColIndices().iterator();
        IIterate ib = b.getColIndices().iterator();
        while (ia.hasNext() && ib.hasNext()) {
            if (ia.next() != ib.next()) {
                return false;
            }
        }
        return !ia.hasNext() && !ib.hasNext();
    }

    /**
     * Appends the per-block-row parts of one column group using {@link AColGroup#appendN}.
     * The caller treats a {@code null} result as "could not append".
     */
    private static AColGroup combineN(AColGroup[] groups) {
        return AColGroup.appendN(groups);
    }
}

