/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
import org.apache.sysds.runtime.compress.CompressionStatistics;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.cocode.CoCoderFactory;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupValue;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.cost.ACostEstimate;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.cost.CostEstimatorBuilder;
import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory;
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
import org.apache.sysds.runtime.compress.cost.MemoryCostEstimator;
import org.apache.sysds.runtime.compress.estim.AComEst;
import org.apache.sysds.runtime.compress.estim.ComEstFactory;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.workload.WTreeRoot;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.utils.DMLCompressionStatistics;

/**
 * Factory that compresses a {@link MatrixBlock} into a {@link CompressedMatrixBlock}.
 * <p>
 * Every static {@code compress} call creates a fresh factory instance that runs the pipeline once. The phases, counted
 * by the {@code phase} field and advanced by {@link #logPhase()}, are: 0 classify, 1 co-code (grouping), 2 transpose,
 * 3 compress and 4 finalize (cleanup). When the cost estimates or the final compression ratio indicate that
 * compression does not pay off, the input block is returned instead (transposed back if it was transposed).
 */
public class CompressedMatrixBlockFactory {
    private static final Log LOG = LogFactory.getLog(CompressedMatrixBlockFactory.class.getName());

    /** Phase timer; {@code time.stop()} supplies the duration recorded for each phase. */
    private final Timing time = new Timing(true);
    /** Statistics collected during compression and returned together with the result. */
    private final CompressionStatistics _stats = new CompressionStatistics();
    /** Degree of parallelism forwarded to estimation, co-coding, transposition and compression. */
    private final int k;
    private final CompressionSettings compSettings;
    private final ACostEstimate costEstimator;
    /** Duration of the most recently finished phase. */
    private double lastPhase;
    /** Input block; replaced by its transpose when the transpose phase decides to transpose. */
    private MatrixBlock mb;
    /** Compressed result; set to {@code null} when the finalize phase aborts. */
    private CompressedMatrixBlock res;
    /** Index of the current phase, incremented at the end of every {@link #logPhase()} call. */
    private int phase = 0;
    private AComEst informationExtractor;
    /** Column group plan; set to {@code null} when classify or co-code decides to abort. */
    private CompressedSizeInfo compressionGroups;

    private CompressedMatrixBlockFactory(MatrixBlock mb, int k, CompressionSettingsBuilder compSettings, ACostEstimate costEstimator) {
        this(mb, k, compSettings.create(), costEstimator);
    }

    private CompressedMatrixBlockFactory(MatrixBlock mb, int k, CompressionSettings compSettings, ACostEstimate costEstimator) {
        this.mb = mb;
        this.k = k;
        this.compSettings = compSettings;
        this.costEstimator = costEstimator;
    }

    /**
     * Compresses {@code mb} single-threaded with default settings and no workload information.
     *
     * @param mb block to compress
     * @return the resulting block (compressed, or the input if compression was aborted) and the statistics
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb) {
        return CompressedMatrixBlockFactory.compress(mb, 1, new CompressionSettingsBuilder(), (WTreeRoot) null);
    }

    /** Compresses {@code mb} single-threaded, deriving the cost estimator from the workload tree {@code root}. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, WTreeRoot root) {
        return CompressedMatrixBlockFactory.compress(mb, 1, new CompressionSettingsBuilder(), root);
    }

    /** Compresses {@code mb} single-threaded using a cost estimator created from {@code csb}. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, CostEstimatorBuilder csb) {
        return CompressedMatrixBlockFactory.compress(mb, 1, new CompressionSettingsBuilder(), csb);
    }

    /**
     * Compresses {@code mb} single-threaded using instruction counts; a {@code null} counter falls back to the
     * default cost estimation.
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, InstructionTypeCounter ins) {
        if (ins == null) {
            return CompressedMatrixBlockFactory.compress(mb, 1, new CompressionSettingsBuilder());
        }
        return CompressedMatrixBlockFactory.compress(mb, 1, new CompressionSettingsBuilder(), new CostEstimatorBuilder(ins));
    }

    /** Compresses {@code mb} single-threaded with caller-provided settings. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, CompressionSettingsBuilder customSettings) {
        return CompressedMatrixBlockFactory.compress(mb, 1, customSettings, (WTreeRoot) null);
    }

    /** Compresses {@code mb} with parallelism {@code k} and default settings. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k) {
        return CompressedMatrixBlockFactory.compress(mb, k, new CompressionSettingsBuilder(), (WTreeRoot) null);
    }

    /** Compresses {@code mb} with parallelism {@code k}, deriving the cost estimator from {@code root}. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, WTreeRoot root) {
        return CompressedMatrixBlockFactory.compress(mb, k, new CompressionSettingsBuilder(), root);
    }

    /** Compresses {@code mb} with parallelism {@code k} using a cost estimator created from {@code csb}. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CostEstimatorBuilder csb) {
        return CompressedMatrixBlockFactory.compress(mb, k, new CompressionSettingsBuilder(), csb);
    }

    /**
     * Compresses {@code mb} with parallelism {@code k} using instruction counts; a {@code null} counter falls back to
     * the default cost estimation.
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, InstructionTypeCounter ins) {
        if (ins == null) {
            // Honor the requested parallelism even when no workload information is available
            // (previously this path silently fell back to k = 1).
            return CompressedMatrixBlockFactory.compress(mb, k, new CompressionSettingsBuilder());
        }
        return CompressedMatrixBlockFactory.compress(mb, k, new CompressionSettingsBuilder(), new CostEstimatorBuilder(ins));
    }

    /** Compresses {@code mb} single-threaded with an explicit cost estimator. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, ACostEstimate costEstimator) {
        return CompressedMatrixBlockFactory.compress(mb, 1, new CompressionSettingsBuilder(), costEstimator);
    }

    /** Compresses {@code mb} with parallelism {@code k} and an explicit cost estimator. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, ACostEstimate costEstimator) {
        return CompressedMatrixBlockFactory.compress(mb, k, new CompressionSettingsBuilder(), costEstimator);
    }

    /** Compresses {@code mb} with parallelism {@code k} and caller-provided settings. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CompressionSettingsBuilder compSettings) {
        return CompressedMatrixBlockFactory.compress(mb, k, compSettings, (WTreeRoot) null);
    }

    /** Asynchronously compresses the matrix bound to {@code varName}; see the three-argument overload. */
    public static void compressAsync(ExecutionContext ec, String varName) {
        CompressedMatrixBlockFactory.compressAsync(ec, varName, null);
    }

    /**
     * Asynchronously compresses the matrix bound to {@code varName} in {@code ec} and, if the result is compressed,
     * replaces the variable's block with the compressed one.
     * <p>
     * The task is submitted with {@link CompletableFuture#runAsync(Runnable)} and the future is discarded, so the
     * caller is neither blocked nor informed about failures inside the task.
     *
     * @param ec      execution context holding the variable
     * @param varName name of the variable to compress; non-matrix variables are ignored
     * @param ins     optional instruction counts used for cost estimation (may be {@code null})
     */
    public static void compressAsync(ExecutionContext ec, String varName, InstructionTypeCounter ins) {
        CompletableFuture.runAsync(() -> {
            CacheableData<?> data = ec.getCacheableData(varName);
            if (!(data instanceof MatrixObject)) {
                return;
            }
            MatrixObject mo = (MatrixObject) data;
            MatrixBlock mb = (MatrixBlock) mo.acquireReadAndRelease();
            // Compress the block that was just read rather than acquiring it a second time.
            MatrixBlock mbc = CompressedMatrixBlockFactory.compress(mb, ins).getLeft();
            if (mbc instanceof CompressedMatrixBlock) {
                // NOTE(review): the created cacheable data object is discarded; kept in case the call has side
                // effects that are needed here -- confirm and remove if it is not.
                ExecutionContext.createCacheableData(mb);
                mo.acquireModify(mbc);
                mo.release();
            }
        });
    }

    /**
     * Compresses {@code mb} with parallelism {@code k}; the cost estimator is built from the workload tree
     * {@code root} when one is given, otherwise from the settings alone.
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CompressionSettingsBuilder compSettings, WTreeRoot root) {
        CostEstimatorBuilder csb = root == null ? null : new CostEstimatorBuilder(root);
        return CompressedMatrixBlockFactory.compress(mb, k, compSettings, csb);
    }

    /**
     * Compresses {@code mb} with parallelism {@code k}, creating the cost estimator from {@code csb} (which may be
     * {@code null}) and the dimensions and sparsity of {@code mb}.
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CompressionSettingsBuilder compSettings, CostEstimatorBuilder csb) {
        CompressionSettings cs = compSettings.create();
        ACostEstimate ice = CostEstimatorFactory.create(cs, csb, mb.getNumRows(), mb.getNumColumns(), mb.getSparsity());
        CompressedMatrixBlockFactory cmbf = new CompressedMatrixBlockFactory(mb, k, cs, ice);
        return cmbf.compressMatrix();
    }

    /** Compresses {@code mb} with parallelism {@code k}, the given settings and an explicit cost estimator. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CompressionSettingsBuilder compSettings, ACostEstimate costEstimator) {
        CompressedMatrixBlockFactory cmbf = new CompressedMatrixBlockFactory(mb, k, compSettings, costEstimator);
        return cmbf.compressMatrix();
    }

    /**
     * Wraps {@code mb} into a compressed block consisting of a single uncompressed column group.
     *
     * @param mb block to wrap
     * @return compressed block with the same dimensions and non-zero count as {@code mb}
     */
    public static CompressedMatrixBlock genUncompressedCompressedMatrixBlock(MatrixBlock mb) {
        CompressedMatrixBlock ret = new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns());
        AColGroup cg = ColGroupUncompressed.create(mb);
        ret.allocateColGroup(cg);
        ret.setNonZeros(mb.getNonZeros());
        return ret;
    }

    /**
     * Creates a compressed block of the given size where every cell holds {@code value}.
     *
     * @param numRows number of rows; must not be 0
     * @param numCols number of columns; must not be 0
     * @param value   constant cell value
     * @return the constant compressed block
     * @throws DMLCompressionException if either dimension is 0
     */
    public static CompressedMatrixBlock createConstant(int numRows, int numCols, double value) {
        // Validate before allocating anything, instead of checking the already-built block.
        if (numRows == 0 || numCols == 0) {
            throw new DMLCompressionException("Invalid size of allocated constant compressed matrix block");
        }
        CompressedMatrixBlock block = new CompressedMatrixBlock(numRows, numCols);
        AColGroup cg = ColGroupConst.create(numCols, value);
        block.allocateColGroup(cg);
        block.recomputeNonZeros();
        return block;
    }

    /** Runs the full compression pipeline on {@link #mb}. */
    private Pair<MatrixBlock, CompressionStatistics> compressMatrix() {
        if (mb instanceof CompressedMatrixBlock) {
            return returnSelf();
        }
        _stats.denseSize = MatrixBlock.estimateSizeInMemory(mb.getNumRows(), mb.getNumColumns(), 1.0);
        _stats.originalSize = mb.getInMemorySize();
        _stats.originalCost = costEstimator.getCost(mb);
        if (mb.isEmpty()) {
            return createEmpty();
        }
        res = new CompressedMatrixBlock(mb);
        classifyPhase();
        if (compressionGroups == null) {
            return abortCompression();
        }
        // Presumably releases estimation state that is not needed once the column plan is fixed -- verify.
        compressionGroups.clearMaps();
        informationExtractor.clearNNZ();
        transposePhase();
        compressPhase();
        finalizePhase();
        if (res == null) {
            return abortCompression();
        }
        return new ImmutablePair<>(res, _stats);
    }

    /**
     * Phase 0: estimates per-column compression and decides whether to continue with co-coding.
     * Sets {@link #compressionGroups} to {@code null} when the single-column estimate is not cheaper than the input.
     */
    private void classifyPhase() {
        informationExtractor = ComEstFactory.createEstimator(mb, compSettings, k);
        compressionGroups = informationExtractor.computeCompressedSizeInfos(k);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Logging all individual columns estimated cost:");
            for (CompressedSizeInfoColGroup g : compressionGroups.getInfo()) {
                LOG.trace(String.format("Cost: %8.0f Size: %16d %15s", costEstimator.getCost(g), g.getMinSize(), g.getColumns()));
            }
        }
        _stats.estimatedSizeCols = compressionGroups.memoryEstimate();
        _stats.estimatedCostCols = costEstimator.getCost(compressionGroups);
        logPhase();

        final int nCols = compSettings.transposed ? mb.getNumRows() : mb.getNumColumns();
        // For computation-cost estimation the single-column cost is divided by nCols / 2 before comparing,
        // presumably to leave room for the gains co-coding can achieve -- TODO confirm the rationale.
        final double scale = costEstimator instanceof ComputationCostEstimator ? (double) nCols / 2.0 : 1.0;
        final double threshold = _stats.estimatedCostCols / scale;
        if (threshold < _stats.originalCost) {
            if (nCols > 1) {
                coCodePhase();
            } else {
                // Nothing to co-code, but still advance the phase counter.
                logPhase();
            }
            return;
        }
        compressionGroups = null;
        if (LOG.isInfoEnabled()) {
            LOG.info("Aborting before co-code, because the compression looks bad");
            LOG.info("Threshold was set to : " + threshold + " but it was above original " + _stats.originalCost);
            LOG.info("Original size       : " + _stats.originalSize);
            LOG.info("single col size     : " + _stats.estimatedSizeCols);
            if (!(costEstimator instanceof MemoryCostEstimator)) {
                LOG.info("original cost      : " + _stats.originalCost);
                LOG.info("single col cost    : " + _stats.estimatedCostCols);
            }
        }
    }

    /**
     * Phase 1: groups columns via {@link CoCoderFactory#findCoCodesByPartitioning}. Sets {@link #compressionGroups}
     * to {@code null} when the co-coded estimate costs more than the input.
     */
    private void coCodePhase() {
        compressionGroups = CoCoderFactory.findCoCodesByPartitioning(informationExtractor, compressionGroups, k, costEstimator, compSettings);
        _stats.estimatedSizeCoCoded = compressionGroups.memoryEstimate();
        _stats.estimatedCostCoCoded = costEstimator.getCost(compressionGroups);
        logPhase();
        if (_stats.estimatedCostCoCoded <= _stats.originalCost) {
            return;
        }
        compressionGroups = null;
        if (LOG.isInfoEnabled()) {
            LOG.info("Aborting after co-code, because the compression looks bad");
            LOG.info("co-code size      : " + _stats.estimatedSizeCoCoded);
            LOG.info("original size     : " + _stats.originalSize);
            if (!(costEstimator instanceof MemoryCostEstimator)) {
                LOG.info("original cost    : " + _stats.originalCost);
                LOG.info("single col cost  : " + _stats.estimatedCostCols);
                LOG.info("co-code cost     : " + _stats.estimatedCostCoCoded);
            }
        }
    }

    /**
     * Phase 2: transposes {@link #mb} when the input is not already transposed, the JVM reports more free memory
     * than twice the block's estimated size, and {@link #transposeHeuristics()} asks for it.
     */
    private void transposePhase() {
        final boolean haveMemory = Runtime.getRuntime().freeMemory() - mb.estimateSizeInMemory() * 2L > 0L;
        if (!compSettings.transposed && haveMemory) {
            transposeHeuristics();
            if (compSettings.transposed) {
                final boolean sparse = mb.isInSparseFormat();
                mb = LibMatrixReorg.transpose(mb, new MatrixBlock(mb.getNumColumns(), mb.getNumRows(), sparse), k, true);
                mb.evalSparseFormatInMemory();
            }
        }
        logPhase();
    }

    /**
     * Sets {@code compSettings.transposed} from the {@code transposeInput} setting: {@code "true"} and
     * {@code "false"} force the decision; any other value applies a heuristic for sparse inputs (dense inputs are
     * never transposed by the heuristic).
     */
    private void transposeHeuristics() {
        switch (compSettings.transposeInput) {
            case "true":
                compSettings.transposed = true;
                break;
            case "false":
                compSettings.transposed = false;
                break;
            default:
                if (mb.isInSparseFormat()) {
                    final boolean haveManyColumns = mb.getNumColumns() > 10000;
                    final boolean isNnzLowAndVerySparse = mb.getNonZeros() < 1000L && mb.getSparsity() < 0.4;
                    final boolean isAboveRowNumbers = mb.getNumRows() > 500000 && mb.getSparsity() < 0.4;
                    final boolean isAboveThreadToColumnRatio = compressionGroups.getNumberColGroups() > mb.getNumColumns() / 30;
                    compSettings.transposed = haveManyColumns || isNnzLowAndVerySparse
                        || (isAboveRowNumbers && isAboveThreadToColumnRatio);
                } else {
                    compSettings.transposed = false;
                }
                break;
        }
    }

    /** Phase 3: builds the column groups according to {@link #compressionGroups} and stores them in {@link #res}. */
    private void compressPhase() {
        List<AColGroup> c = ColGroupFactory.compressColGroups(mb, compressionGroups, compSettings, costEstimator, k);
        res.allocateColGroupList(c);
        _stats.compressedInitialSize = res.getInMemorySize();
        logPhase();
    }

    /**
     * Phase 4: cleans up {@link #res}, records final statistics and non-zeros. Sets {@link #res} to {@code null}
     * when the compression ratio is below 1 and the dense ratio is below 100.
     */
    private void finalizePhase() {
        res.cleanupBlock(true, true);
        _stats.compressedSize = res.getInMemorySize();
        _stats.compressedCost = costEstimator.getCost(res.getColGroups(), res.getNumRows());
        final double ratio = _stats.getRatio();
        final double denseRatio = _stats.getDenseRatio();
        if (ratio < 1.0 && denseRatio < 100.0) {
            LOG.info("--dense size:        " + _stats.denseSize);
            LOG.info("--original size:     " + _stats.originalSize);
            LOG.info("--compressed size:   " + _stats.compressedSize);
            LOG.info("--compression ratio: " + ratio);
            LOG.info("Abort block compression because compression ratio is less than 1.");
            res = null;
            // Record the time of this phase without advancing the phase counter.
            setNextTimePhase(time.stop());
            DMLCompressionStatistics.addCompressionTime(getLastTimePhase(), phase);
            return;
        }
        _stats.setColGroupsCounts(res.getColGroups());
        if (compSettings.isInSparkInstruction) {
            res.clearSoftReferenceToDecompressed();
        }
        // Reuse the input's non-zero count when it is known; otherwise compute it from the compressed block.
        final long oldNNZ = mb.getNonZeros();
        if (oldNNZ <= 0L) {
            res.recomputeNonZeros();
        } else {
            res.setNonZeros(oldNNZ);
        }
        logPhase();
    }

    /**
     * Gives up on compression and returns the input block (transposed back in place if it was transposed)
     * together with the statistics gathered so far.
     */
    private Pair<MatrixBlock, CompressionStatistics> abortCompression() {
        LOG.warn("Compression aborted at phase: " + phase);
        if (compSettings.transposed) {
            LibMatrixReorg.transposeInPlace(mb, k);
        }
        return new ImmutablePair<>(mb, _stats);
    }

    /**
     * Records the time of the phase that just finished, emits its debug output, and advances the phase counter.
     * Inside Spark instructions only the full statistics of phase 4 are logged.
     */
    private void logPhase() {
        setNextTimePhase(time.stop());
        DMLCompressionStatistics.addCompressionTime(getLastTimePhase(), phase);
        if (LOG.isDebugEnabled()) {
            if (compSettings.isInSparkInstruction) {
                if (phase == 4) {
                    LOG.debug(_stats);
                }
            } else {
                switch (phase) {
                    case 0:
                        LOG.debug("--Seed used for comp : " + compSettings.seed);
                        LOG.debug("--compression phase " + phase + " Classify  : " + getLastTimePhase());
                        LOG.debug("--Individual Columns Estimated Compression: " + _stats.estimatedSizeCols);
                        break;
                    case 1:
                        LOG.debug("--compression phase " + phase + " Grouping  : " + getLastTimePhase());
                        LOG.debug("Grouping using: " + compSettings.columnPartitioner);
                        LOG.debug("Cost Calculated using: " + costEstimator);
                        LOG.debug("--Cocoded Columns estimated Compression:" + _stats.estimatedSizeCoCoded);
                        if (compressionGroups.getInfo().size() < 1000) {
                            LOG.debug("--Cocoded Columns estimated nr distinct:" + compressionGroups.getEstimatedDistinct());
                            LOG.debug("--Cocoded Columns nr columns           :" + compressionGroups.getNrColumnsString());
                        } else {
                            LOG.debug("--CoCoded produce many columns but the first says:\n" + compressionGroups.getInfo().get(0));
                        }
                        break;
                    case 2:
                        LOG.debug("--compression phase " + phase + " Transpose : " + getLastTimePhase());
                        LOG.debug("Did transpose: " + compSettings.transposed);
                        break;
                    case 3:
                        LOG.debug("--compression phase " + phase + " Compress  : " + getLastTimePhase());
                        LOG.debug("--compressed initial actual size:" + _stats.compressedInitialSize);
                        break;
                    case 4:
                        logFinalizeDetails();
                        break;
                    default:
                        break;
                }
            }
        }
        phase++;
    }

    /** Emits the debug (and optionally trace) summary of the finished compression for phase 4. */
    private void logFinalizeDetails() {
        final List<AColGroup> groups = res.getColGroups();
        LOG.debug("--num col groups:    " + groups.size());
        LOG.debug("--compression phase  " + phase + " Cleanup   : " + getLastTimePhase());
        LOG.debug("--col groups types   " + _stats.getGroupsTypesString());
        LOG.debug("--col groups sizes   " + _stats.getGroupsSizesString());
        LOG.debug(String.format("--dense size:        %16d", _stats.denseSize));
        LOG.debug(String.format("--original size:     %16d", _stats.originalSize));
        LOG.debug(String.format("--compressed size:   %16d", _stats.compressedSize));
        LOG.debug(String.format("--compression ratio: %4.3f", _stats.getRatio()));
        LOG.debug(String.format("--Dense       ratio: %4.3f", _stats.getDenseRatio()));
        if (!(costEstimator instanceof MemoryCostEstimator)) {
            LOG.debug(String.format("--original cost:     %5.2E", _stats.originalCost));
            LOG.debug(String.format("--single col cost:   %5.2E", _stats.estimatedCostCols));
            LOG.debug(String.format("--cocode cost:       %5.2E", _stats.estimatedCostCoCoded));
            LOG.debug(String.format("--actual cost:       %5.2E", _stats.compressedCost));
            LOG.debug(String.format("--relative cost:     %1.4f", _stats.compressedCost / _stats.originalCost));
        }
        if (compressionGroups != null && compressionGroups.getInfo().size() < 1000) {
            final int[] lengths = new int[groups.size()];
            int i = 0;
            for (AColGroup colGroup : groups) {
                lengths[i++] = colGroup.getNumValues();
            }
            LOG.debug("--compressed colGroup dictionary sizes: " + Arrays.toString(lengths));
            LOG.debug("--compressed colGroup nr columns      : " + CompressedMatrixBlockFactory.constructNrColumnString(groups));
        }
        if (!LOG.isTraceEnabled()) {
            return;
        }
        for (AColGroup colGroup : groups) {
            if (colGroup.estimateInMemorySize() < 1000L) {
                // Small groups are printed in full; large ones only as a one-line summary.
                LOG.trace(colGroup);
            } else {
                LOG.trace("--colGroups type       : " + colGroup.getClass().getSimpleName() + " size: "
                    + colGroup.estimateInMemorySize()
                    + (colGroup instanceof AColGroupValue ? "  numValues :" + ((AColGroupValue) colGroup).getNumValues() : "")
                    + "  colIndexes : " + colGroup.getColIndices());
            }
        }
    }

    private void setNextTimePhase(double time) {
        this.lastPhase = time;
    }

    private double getLastTimePhase() {
        return this.lastPhase;
    }

    /**
     * Builds the result for an all-zero input: a compressed block holding a single empty column group.
     * Jumps the phase counter to 4 so that the final statistics are logged.
     */
    private Pair<MatrixBlock, CompressionStatistics> createEmpty() {
        LOG.info("Empty input to compress, returning a compressed Matrix block with empty column group");
        res = new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns());
        ColGroupEmpty cg = ColGroupEmpty.create(mb.getNumColumns());
        res.allocateColGroup(cg);
        res.setNonZeros(0L);
        _stats.compressedSize = res.getInMemorySize();
        _stats.compressedCost = costEstimator.getCost(res.getColGroups(), res.getNumRows());
        _stats.setColGroupsCounts(res.getColGroups());
        phase = 4;
        logPhase();
        return new ImmutablePair<>(res, _stats);
    }

    /** Returns the already-compressed input unchanged; the statistics half of the pair is {@code null}. */
    private Pair<MatrixBlock, CompressionStatistics> returnSelf() {
        LOG.info("MatrixBlock already compressed or is Empty");
        return new ImmutablePair<>(mb, null);
    }

    /**
     * Formats the column count of every group as {@code "[a, b, c]"}; an empty list yields {@code "[]"}.
     *
     * @param cg column groups to describe
     * @return bracketed, comma-separated column counts
     */
    private static String constructNrColumnString(List<AColGroup> cg) {
        final StringBuilder sb = new StringBuilder("[");
        for (int id = 0; id < cg.size(); id++) {
            if (id > 0) {
                sb.append(", ");
            }
            sb.append(cg.get(id).getNumCols());
        }
        return sb.append(']').toString();
    }
}

