/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data;

import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnActivationDescriptor;
import jcuda.jcudnn.cudnnConvolutionDescriptor;
import jcuda.jcudnn.cudnnFilterDescriptor;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcudnn.cudnnPoolingDescriptor;
import jcuda.jcudnn.cudnnRNNDescriptor;
import jcuda.jcudnn.cudnnStatus;
import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.runtime.JCuda;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.gpu.context.CSRPointer;
import org.apache.sysds.runtime.instructions.gpu.context.ExecutionConfig;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysds.runtime.matrix.data.LibMatrixCuDNNConvolutionAlgorithm;
import org.apache.sysds.runtime.matrix.data.LibMatrixCuDNNInputRowFetcher;
import org.apache.sysds.runtime.matrix.data.LibMatrixCuDNNPoolingDescriptors;
import org.apache.sysds.runtime.matrix.data.LibMatrixCuDNNRnnAlgorithm;
import org.apache.sysds.runtime.matrix.data.LibMatrixCuMatMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNN;

public class LibMatrixCuDNN
extends LibMatrixCUDA {
    // NOTE(review): not read anywhere in the visible part of this class; the getNnz(...) calls
    // below pass a literal false, which is presumably this constant inlined by the compiler — confirm.
    private static final boolean RECOMPUTE_DENSE_NNZ = false;
    // NOTE(review): not read in the visible part of this class; presumably consumed by the
    // convolution-algorithm selection helpers — confirm before changing the value.
    protected static int CONVOLUTION_PREFERENCE = 0;
    // Logger used for the trace-level "GPU : ..." messages emitted by the operators in this class.
    private static final Log LOG = LogFactory.getLog((String)LibMatrixCuDNN.class.getName());

    /**
     * Convenience accessor for the cuDNN handle exposed by a GPU context.
     *
     * @param gCtx GPU context to query
     * @return whatever {@code gCtx.getCudnnHandle()} returns
     */
    protected static cudnnHandle getCudnnHandle(GPUContext gCtx) {
        final cudnnHandle handle = gCtx.getCudnnHandle();
        return handle;
    }

    /**
     * Runs {@link #conv2d} into {@code output} and then applies {@code biasAdd} to
     * {@code output} in place (the output is both the input and the result of the bias step).
     *
     * @param gCtx GPU context
     * @param instName instruction name forwarded to both steps
     * @param image convolution input
     * @param bias bias operand handed to {@code biasAdd}
     * @param filter convolution filter
     * @param output result matrix, written by the convolution and then updated by the bias step
     * @param N N dimension forwarded to conv2d
     * @param C C dimension forwarded to conv2d
     * @param H H dimension forwarded to conv2d
     * @param W W dimension forwarded to conv2d
     * @param K K dimension forwarded to conv2d
     * @param R R dimension forwarded to conv2d
     * @param S S dimension forwarded to conv2d
     * @param pad_h vertical padding forwarded to conv2d
     * @param pad_w horizontal padding forwarded to conv2d
     * @param stride_h vertical stride forwarded to conv2d
     * @param stride_w horizontal stride forwarded to conv2d
     * @param P P dimension forwarded to conv2d
     * @param Q Q dimension forwarded to conv2d
     * @param intermediateMemoryBudget memory budget forwarded to conv2d
     */
    public static void conv2dBiasAdd(GPUContext gCtx, String instName, MatrixObject image, MatrixObject bias, MatrixObject filter, MatrixObject output, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) {
        // Step 1: plain convolution into the output matrix.
        conv2d(gCtx, instName, image, filter, output,
                N, C, H, W, K, R, S,
                pad_h, pad_w, stride_h, stride_w,
                P, Q, intermediateMemoryBudget);
        // Step 2: add the bias on top of the convolution result, in place.
        biasAdd(gCtx, instName, output, bias, output);
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    /**
     * Launches the im2col kernel for {@code image} and returns the freshly allocated
     * device buffer holding the result.
     *
     * <p>The buffer is allocated with {@code C*R*S*N*P*Q*sizeOfDataType} bytes via
     * {@code gCtx.allocate}; this method does not free it.
     *
     * @param gCtx GPU context used for allocation and kernel launch
     * @param instName instruction name forwarded to allocation and pointer lookups
     * @param image input matrix
     * @param isSparseImage when true the CSR pointer of the image is used and the
     *        "sparse_dense_im2col" kernel is launched, otherwise "dense_dense_im2col"
     * @param N N dimension
     * @param C C dimension
     * @param H H dimension
     * @param W W dimension
     * @param R R dimension
     * @param S S dimension
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P P dimension
     * @param Q Q dimension
     * @return the im2col buffer, or {@code null} when the sparse image has zero non-zeros
     * @throws DMLRuntimeException if the sparse image reports a negative nnz
     */
    private static Pointer denseIm2col(GPUContext gCtx, String instName, MatrixObject image, boolean isSparseImage, long N, long C, long H, long W, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q) {
        // Pure arithmetic, all in long because C and N are long.
        final long im2colBytes = C * R * S * N * P * Q * (long)sizeOfDataType;
        final long npq = N * P * Q;
        // These two stay int on purpose: they are boxed as Integer for the kernel launch.
        final int pq = P * Q;
        final int rs = R * S;

        if (!isSparseImage) {
            Pointer denseOut = gCtx.allocate(instName, im2colBytes, false);
            Pointer denseIn = getDensePointerForCuDNN(gCtx, image, instName);
            final long numCells = N * C * H * W;
            getCudaKernels(gCtx).launchKernel("dense_dense_im2col",
                    ExecutionConfig.getConfigForSimpleVectorOperations(toInt(numCells)),
                    denseIn, denseOut, numCells, C * H * W, H * W, W, R, S, P, Q, pq, rs, npq,
                    stride_h, stride_w, pad_h, pad_w);
            return denseOut;
        }

        CSRPointer csr = getSparsePointer(gCtx, image, instName);
        if (csr.nnz < 0L) {
            throw new DMLRuntimeException("Unknown number of nonzeroes in denseIm2col");
        }
        if (csr.nnz == 0L) {
            // Nothing to expand; no buffer is allocated in this case.
            return null;
        }
        Pointer sparseOut = gCtx.allocate(instName, im2colBytes, false);
        getCudaKernels(gCtx).launchKernel("sparse_dense_im2col",
                ExecutionConfig.getConfigForSimpleVectorOperations(toInt(csr.nnz)),
                csr.val, csr.rowPtr, csr.colInd, sparseOut, csr.nnz, N, C * H * W, H * W, W, R, S, P, Q, pq, rs, npq,
                stride_h, stride_w, pad_h, pad_w);
        return sparseOut;
    }

    /**
     * Computes a 2D convolution of {@code image} with {@code filter} into {@code outputBlock}.
     *
     * <p>Behavior visible in this method:
     * <ul>
     *   <li>Returns immediately, without touching the output, if the filter or the image
     *       reports zero non-zeros.</li>
     *   <li>If the filter is sparse and the estimated im2col + matmult intermediates fit into
     *       {@code min(MAX_WORKSPACE_LIMIT_BYTES, intermediateMemoryBudget)}, the convolution is
     *       done as im2col, sparse-dense matmult and a "reorg_knpq" kernel.</li>
     *   <li>Otherwise cuDNN is used: on the whole batch when the estimated sparse-to-dense
     *       overhead fits the budget, else one input row at a time.</li>
     * </ul>
     *
     * @param gCtx GPU context
     * @param instName instruction name forwarded to allocations and pointer lookups
     * @param image input matrix
     * @param filter filter matrix
     * @param outputBlock output matrix
     * @param N N dimension (number of rows processed in the row-wise path)
     * @param C C dimension
     * @param H H dimension
     * @param W W dimension
     * @param K K dimension
     * @param R R dimension
     * @param S S dimension
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P P dimension
     * @param Q Q dimension
     * @param intermediateMemoryBudget budget (bytes) for intermediates
     * @throws DMLRuntimeException if any of N*C*H*W, N*K*P*Q or K*C*R*S is not below
     *         {@code maxNumElementsOfCuDNNTensor}
     */
    public static void conv2d(GPUContext gCtx, String instName, MatrixObject image, MatrixObject filter, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) {
        // Widen to long BEFORE multiplying: the int products can exceed Integer.MAX_VALUE
        // and would otherwise wrap around and defeat the size checks below.
        long CHW = (long)C * H * W;
        long KPQ = (long)K * P * Q;
        long CRS = (long)C * R * S;
        long NCHW = (long)N * CHW;
        long NKPQ = (long)N * KPQ;
        long KCRS = (long)K * CRS;
        long NPQ = (long)N * P * Q;
        boolean isSparseFilter = LibMatrixCuDNN.isInSparseFormat(gCtx, filter);
        long filterNnz = LibMatrixCuDNN.getNnz(gCtx, instName, filter, false);
        if (filterNnz == 0L) {
            return;
        }
        boolean isSparseImage = LibMatrixCuDNN.isInSparseFormat(gCtx, image);
        long imageNnz = LibMatrixCuDNN.getNnz(gCtx, instName, image, false);
        if (imageNnz == 0L) {
            return;
        }
        Pointer dstPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, outputBlock, instName);
        if (NCHW < maxNumElementsOfCuDNNTensor && NKPQ < maxNumElementsOfCuDNNTensor && KCRS < maxNumElementsOfCuDNNTensor) {
            if (isSparseFilter && (double)(OptimizerUtils.estimateSizeExactSparsity(CRS, NPQ, 1.0) + OptimizerUtils.estimateSizeExactSparsity((long)K, NPQ, 1.0)) < Math.min((double)LibMatrixCuDNNConvolutionAlgorithm.MAX_WORKSPACE_LIMIT_BYTES, intermediateMemoryBudget)) {
                // Sparse filter path: im2col followed by sparse-dense matrix multiplication.
                Pointer im2colPointer = LibMatrixCuDNN.denseIm2col(gCtx, instName, image, isSparseImage, N, C, H, W, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
                CSRPointer filterPointer = filter.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
                Pointer matmultOutputPointer = gCtx.allocate(instName, NKPQ * (long)sizeOfDataType, false);
                LibMatrixCuMatMult.sparseDenseMatMult(gCtx, instName, matmultOutputPointer, filterPointer, im2colPointer, K, CRS, CRS, NPQ, K, NPQ, false, false);
                gCtx.cudaFreeHelper(instName, im2colPointer, DMLScript.EAGER_CUDA_FREE);
                // Rearrange the matmult result into the layout of the output buffer.
                LibMatrixCuDNN.getCudaKernels(gCtx).launchKernel("reorg_knpq", ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCuDNN.toInt(NKPQ)), matmultOutputPointer, dstPointer, NKPQ, NPQ, KPQ, P * Q);
                gCtx.cudaFreeHelper(instName, matmultOutputPointer, DMLScript.EAGER_CUDA_FREE);
            } else {
                // Estimated cost of converting sparse operands to dense.
                double overhead = isSparseFilter ? (double)OptimizerUtils.estimateSizeExactSparsity((long)K, CRS, 1.0) : 0.0;
                overhead += isSparseImage ? (double)OptimizerUtils.estimateSizeExactSparsity((long)N, CHW, 1.0) : 0.0;
                Pointer filterPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, filter, instName);
                long workspaceLimit = (long)(intermediateMemoryBudget - overhead);
                // The descriptors inside algo are built for localN rows: all N or a single row.
                int localN = overhead <= intermediateMemoryBudget ? N : 1;
                try (LibMatrixCuDNNConvolutionAlgorithm algo = LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionForwardAlgorithm(gCtx, instName, localN, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, workspaceLimit)) {
                    if (localN == N) {
                        // Whole batch in one cuDNN call.
                        Pointer imagePointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, image, instName);
                        LibMatrixCuDNN.cudnnConv2d(gCtx, instName, imagePointer, filterPointer, dstPointer, algo);
                    } else {
                        // Row-wise fallback. Must be mutually exclusive with the batch call above,
                        // because algo's descriptors are sized for exactly localN rows.
                        try (LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image)) {
                            for (int n = 0; n < N; ++n) {
                                LibMatrixCuDNN.cudnnConv2d(gCtx, instName, imgFetcher.getNthRow(n), filterPointer, dstPointer.withByteOffset((long)n * KPQ * (long)sizeOfDataType), algo);
                            }
                        }
                    }
                }
            }
        } else {
            LibMatrixCuDNN.throwCuDNNDimensionError(N, CHW, K, CRS, N, KPQ);
        }
    }

    /**
     * Applies cuDNN's softmax forward pass to {@code in1} and writes the result to the
     * GPU matrix registered under {@code outputName}.
     *
     * <p>The input is described to cuDNN as a 4D tensor of shape
     * {@code (numRows, numColumns, 1, 1)}. The output is allocated with the same
     * dimensions as the input and zero-filled before the call.
     *
     * @param ec execution context used to look up and allocate the output
     * @param gCtx GPU context
     * @param instName instruction name forwarded to pointer lookups
     * @param in1 input matrix
     * @param outputName variable name of the output matrix
     */
    public static void softmax(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, String outputName) {
        if (LOG.isTraceEnabled()) {
            LOG.trace((Object)("GPU : softmax, GPUContext=" + gCtx));
        }
        cudnnTensorDescriptor tensorDesc = LibMatrixCuDNN.allocateTensorDescriptor(LibMatrixCuDNN.toInt(in1.getNumRows()), LibMatrixCuDNN.toInt(in1.getNumColumns()), 1, 1);
        try {
            Pointer srcPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, in1, instName);
            MatrixObject out = ec.getMatrixObject(outputName);
            ec.allocateGPUMatrixObject(outputName, in1.getNumRows(), in1.getNumColumns());
            out.getGPUObject(gCtx).allocateAndFillDense(0.0);
            Pointer dstPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, out, instName);
            // The two literal 1s are the softmax algorithm and mode selectors; per the cuDNN
            // enums these should be CUDNN_SOFTMAX_ACCURATE and CUDNN_SOFTMAX_MODE_CHANNEL.
            JCudnn.cudnnSoftmaxForward(gCtx.getCudnnHandle(), 1, 1, LibMatrixCuDNN.one(), tensorDesc, srcPointer, LibMatrixCuDNN.zero(), tensorDesc, dstPointer);
        } finally {
            // Release the native descriptor even if one of the steps above throws.
            JCudnn.cudnnDestroyTensorDescriptor(tensorDesc);
        }
    }

    /**
     * Creates a cuDNN tensor descriptor and configures it as a 4D tensor with the given
     * dimensions and the data type {@code LibMatrixCUDA.CUDNN_DATA_TYPE}.
     *
     * <p>The caller is responsible for destroying the returned descriptor.
     *
     * @param N first tensor dimension
     * @param C second tensor dimension
     * @param H third tensor dimension
     * @param W fourth tensor dimension
     * @return the configured descriptor
     */
    private static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) {
        // Tensor format selector 0; per the cuDNN enum this should be CUDNN_TENSOR_NCHW.
        final int tensorFormat = 0;
        final cudnnTensorDescriptor desc = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(desc);
        JCudnn.cudnnSetTensor4dDescriptor(desc, tensorFormat, LibMatrixCUDA.CUDNN_DATA_TYPE, N, C, H, W);
        return desc;
    }

    /**
     * Always throws a {@link DMLRuntimeException} reporting that a single-input operator
     * exceeds the cuDNN tensor size limit.
     *
     * @param dim1 first input dimension reported in the message
     * @param dim2 second input dimension reported in the message
     * @param dim3 first output dimension reported in the message
     * @param dim4 second output dimension reported in the message
     * @throws DMLRuntimeException unconditionally
     */
    private static void throwCuDNNDimensionError(long dim1, long dim2, long dim3, long dim4) {
        final StringBuilder msg = new StringBuilder();
        msg.append("The dimensions of input/output matrices is too large to execute a CuDNN kernel. Max CuDNN matrix size:");
        msg.append(maxNumElementsOfCuDNNTensor);
        msg.append(". Given input matrix dimensions: [").append(dim1).append(",").append(dim2);
        msg.append("]. Output dimension:  [").append(dim3).append(",").append(dim4);
        msg.append("].");
        throw new DMLRuntimeException(msg.toString());
    }

    /**
     * Always throws a {@link DMLRuntimeException} reporting that a two-input operator
     * exceeds the cuDNN tensor size limit.
     *
     * @param dim1 first dimension of the first input
     * @param dim2 second dimension of the first input
     * @param dim3 first dimension of the second input
     * @param dim4 second dimension of the second input
     * @param dim5 first output dimension
     * @param dim6 second output dimension
     * @throws DMLRuntimeException unconditionally
     */
    private static void throwCuDNNDimensionError(long dim1, long dim2, long dim3, long dim4, long dim5, long dim6) {
        final StringBuilder msg = new StringBuilder();
        msg.append("The dimensions of input/output matrices is too large to execute a CuDNN kernel. Max CuDNN matrix size:");
        msg.append(maxNumElementsOfCuDNNTensor);
        msg.append(". Given input matrix dimensions: [").append(dim1).append(",").append(dim2);
        msg.append("], [").append(dim3).append(",").append(dim4);
        msg.append("]. Output dimension: [").append(dim5).append(",").append(dim6);
        msg.append("]");
        throw new DMLRuntimeException(msg.toString());
    }

    /**
     * Issues a single {@code cudnnConvolutionForward} call using the descriptors, algorithm
     * and workspace held by {@code algo}. The call uses {@code one()} and {@code zero()} as
     * the two scaling arguments.
     *
     * @param gCtx GPU context providing the cuDNN handle
     * @param instName instruction name (not used by this method)
     * @param image device pointer of the input
     * @param filter device pointer of the filter
     * @param output device pointer of the result
     * @param algo pre-built descriptors, algorithm id and workspace
     * @throws DMLRuntimeException if cuDNN returns a non-zero status or a
     *         {@link CudaException} is raised
     */
    private static void cudnnConv2d(GPUContext gCtx, String instName, Pointer image, Pointer filter, Pointer output, LibMatrixCuDNNConvolutionAlgorithm algo) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : conv2d, GPUContext=" + gCtx);
        }
        try {
            final cudnnHandle handle = getCudnnHandle(gCtx);
            final int rc = JCudnn.cudnnConvolutionForward(
                    handle,
                    one(), algo.nchwTensorDesc, image,
                    algo.filterDesc, filter,
                    algo.convDesc, algo.algo,
                    algo.workSpace, algo.sizeInBytes,
                    zero(), algo.nkpqTensorDesc, output);
            if (rc != cudnnStatus.CUDNN_STATUS_SUCCESS) {
                throw new DMLRuntimeException("Could not executed cudnnConvolutionForward: " + cudnnStatus.stringFor(rc));
            }
        }
        catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Computes the filter gradient of a 2D convolution from {@code image} and {@code dout}
     * into {@code outputBlock}.
     *
     * <p>Returns immediately, without touching the output, if {@code dout} or {@code image}
     * reports zero non-zeros. Otherwise cuDNN is invoked either once on the whole batch
     * (when the estimated sparse-to-dense overhead fits {@code intermediateMemoryBudget})
     * or once per row, accumulating each per-row gradient into the output with the
     * "inplace_add" kernel.
     *
     * @param gCtx GPU context
     * @param instName instruction name forwarded to allocations and pointer lookups
     * @param image input matrix of the forward convolution
     * @param dout incoming gradient
     * @param outputBlock filter gradient (result)
     * @param N N dimension (number of rows processed in the row-wise path)
     * @param C C dimension
     * @param H H dimension
     * @param W W dimension
     * @param K K dimension
     * @param R R dimension
     * @param S S dimension
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P P dimension
     * @param Q Q dimension
     * @param intermediateMemoryBudget budget (bytes) for intermediates
     * @throws DMLRuntimeException if any of N*C*H*W, N*K*P*Q or K*C*R*S is not below
     *         {@code maxNumElementsOfCuDNNTensor}
     */
    public static void conv2dBackwardFilter(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) {
        // Widen to long BEFORE multiplying so large shapes cannot overflow int arithmetic.
        long CHW = (long)C * H * W;
        long KPQ = (long)K * P * Q;
        long CRS = (long)C * R * S;
        long NCHW = (long)N * CHW;
        long NKPQ = (long)N * KPQ;
        long KCRS = (long)K * CRS;
        boolean isSparseDout = LibMatrixCuDNN.isInSparseFormat(gCtx, dout);
        long doutNnz = LibMatrixCuDNN.getNnz(gCtx, instName, dout, false);
        if (doutNnz == 0L) {
            return;
        }
        boolean isSparseImage = LibMatrixCuDNN.isInSparseFormat(gCtx, image);
        long imageNnz = LibMatrixCuDNN.getNnz(gCtx, instName, image, false);
        if (imageNnz == 0L) {
            return;
        }
        if (NCHW < maxNumElementsOfCuDNNTensor && NKPQ < maxNumElementsOfCuDNNTensor && KCRS < maxNumElementsOfCuDNNTensor) {
            Pointer dwPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, outputBlock, instName);
            // Estimated cost of converting sparse operands to dense.
            double overhead = isSparseImage ? (double)OptimizerUtils.estimateSizeExactSparsity((long)N, CHW, 1.0) : 0.0;
            overhead += isSparseDout ? (double)OptimizerUtils.estimateSizeExactSparsity((long)N, KPQ, 1.0) : 0.0;
            long workspaceLimit = (long)(intermediateMemoryBudget - overhead);
            // The descriptors inside algo are built for localN rows: all N or a single row.
            int localN = overhead <= intermediateMemoryBudget ? N : 1;
            try (LibMatrixCuDNNConvolutionAlgorithm algo = LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionBackwardFilterAlgorithm(gCtx, instName, localN, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, workspaceLimit)) {
                if (localN == N) {
                    // Whole batch in one cuDNN call.
                    Pointer imagePointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, image, instName);
                    Pointer doutPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, dout, instName);
                    LibMatrixCuDNN.cudnnConv2dBackwardFilter(gCtx, instName, imagePointer, doutPointer, dwPointer, algo);
                } else {
                    // Row-wise fallback. Must be mutually exclusive with the batch call above:
                    // running both would add the per-row gradients on top of the full result.
                    try (LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
                         LibMatrixCuDNNInputRowFetcher doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout)) {
                        Pointer tempdwPointer = gCtx.allocate(instName, KCRS * (long)sizeOfDataType, false);
                        try {
                            for (int n = 0; n < N; ++n) {
                                // Reset the scratch gradient, compute row n, then accumulate into dw.
                                JCuda.cudaMemset(tempdwPointer, 0, KCRS * (long)sizeOfDataType);
                                LibMatrixCuDNN.cudnnConv2dBackwardFilter(gCtx, instName, imgFetcher.getNthRow(n), doutFetcher.getNthRow(n), tempdwPointer, algo);
                                LibMatrixCuDNN.getCudaKernels(gCtx).launchKernel("inplace_add", ExecutionConfig.getConfigForSimpleMatrixOperations(K, LibMatrixCuDNN.toInt(CRS)), tempdwPointer, dwPointer, K, LibMatrixCuDNN.toInt(CRS));
                            }
                        } finally {
                            // Free the scratch buffer even if one of the per-row calls fails.
                            gCtx.cudaFreeHelper(instName, tempdwPointer, true);
                        }
                    }
                }
            }
        } else {
            LibMatrixCuDNN.throwCuDNNDimensionError(N, CHW, N, KPQ, K, CRS);
        }
    }

    /**
     * Issues a single {@code cudnnConvolutionBackwardFilter} call using the descriptors,
     * algorithm and workspace held by {@code algo}. The call uses {@code one()} and
     * {@code zero()} as the two scaling arguments.
     *
     * @param gCtx GPU context providing the cuDNN handle
     * @param instName instruction name (not used by this method)
     * @param imagePointer device pointer of the forward input
     * @param doutPointer device pointer of the incoming gradient
     * @param dwPointer device pointer receiving the filter gradient
     * @param algo pre-built descriptors, algorithm id and workspace
     * @throws DMLRuntimeException if cuDNN returns a non-zero status or a
     *         {@link CudaException} is raised
     */
    private static void cudnnConv2dBackwardFilter(GPUContext gCtx, String instName, Pointer imagePointer, Pointer doutPointer, Pointer dwPointer, LibMatrixCuDNNConvolutionAlgorithm algo) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : conv2dBackwardFilter, GPUContext=" + gCtx);
        }
        try {
            final cudnnHandle handle = getCudnnHandle(gCtx);
            final int rc = JCudnn.cudnnConvolutionBackwardFilter(
                    handle,
                    one(), algo.nchwTensorDesc, imagePointer,
                    algo.nkpqTensorDesc, doutPointer,
                    algo.convDesc, algo.algo,
                    algo.workSpace, algo.sizeInBytes,
                    zero(), algo.filterDesc, dwPointer);
            if (rc != cudnnStatus.CUDNN_STATUS_SUCCESS) {
                throw new DMLRuntimeException("Could not executed cudnnConvolutionBackwardFilter: " + cudnnStatus.stringFor(rc));
            }
        }
        catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Computes the input gradient of a 2D convolution from {@code filter} and {@code dout}
     * into {@code output}.
     *
     * <p>Returns immediately, without touching the output, if {@code filter} or {@code dout}
     * reports zero non-zeros. Otherwise cuDNN is invoked either once on the whole batch
     * (when the estimated sparse-to-dense overhead fits {@code intermediateMemoryBudget})
     * or once per row of {@code dout}, writing row {@code n} of the result at byte offset
     * {@code n*C*H*W*sizeOfDataType}.
     *
     * @param gCtx GPU context
     * @param instName instruction name forwarded to pointer lookups
     * @param filter filter matrix
     * @param dout incoming gradient
     * @param output input gradient (result)
     * @param N N dimension (number of rows processed in the row-wise path)
     * @param C C dimension
     * @param H H dimension
     * @param W W dimension
     * @param K K dimension
     * @param R R dimension
     * @param S S dimension
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P P dimension
     * @param Q Q dimension
     * @param intermediateMemoryBudget budget (bytes) for intermediates
     * @throws DMLRuntimeException if any of N*C*H*W, N*K*P*Q or K*C*R*S is not below
     *         {@code maxNumElementsOfCuDNNTensor}
     */
    public static void conv2dBackwardData(GPUContext gCtx, String instName, MatrixObject filter, MatrixObject dout, MatrixObject output, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) {
        // Widen to long BEFORE multiplying so large shapes cannot overflow int arithmetic.
        long CHW = (long)C * H * W;
        long KPQ = (long)K * P * Q;
        long CRS = (long)C * R * S;
        long NCHW = (long)N * CHW;
        long NKPQ = (long)N * KPQ;
        long KCRS = (long)K * CRS;
        boolean isSparseFilter = LibMatrixCuDNN.isInSparseFormat(gCtx, filter);
        long filterNnz = LibMatrixCuDNN.getNnz(gCtx, instName, filter, false);
        if (filterNnz == 0L) {
            return;
        }
        boolean isSparseDout = LibMatrixCuDNN.isInSparseFormat(gCtx, dout);
        long doutNnz = LibMatrixCuDNN.getNnz(gCtx, instName, dout, false);
        if (doutNnz == 0L) {
            return;
        }
        if (NCHW < maxNumElementsOfCuDNNTensor && NKPQ < maxNumElementsOfCuDNNTensor && KCRS < maxNumElementsOfCuDNNTensor) {
            // Estimated cost of converting sparse operands to dense.
            double overhead = isSparseFilter ? (double)OptimizerUtils.estimateSizeExactSparsity((long)K, CRS, 1.0) : 0.0;
            overhead += isSparseDout ? (double)OptimizerUtils.estimateSizeExactSparsity((long)N, KPQ, 1.0) : 0.0;
            Pointer filterPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, filter, instName);
            Pointer dstPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, output, instName);
            long workspaceLimit = (long)(intermediateMemoryBudget - overhead);
            // The descriptors inside algo are built for localN rows: all N or a single row.
            int localN = overhead <= intermediateMemoryBudget ? N : 1;
            try (LibMatrixCuDNNConvolutionAlgorithm algo = LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionBackwardDataAlgorithm(gCtx, instName, localN, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, workspaceLimit)) {
                if (localN == N) {
                    // Whole batch in one cuDNN call.
                    Pointer doutPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, dout, instName);
                    LibMatrixCuDNN.cudnnConv2dBackwardData(gCtx, instName, filterPointer, doutPointer, dstPointer, algo);
                } else {
                    // Row-wise fallback, mutually exclusive with the batch call above.
                    try (LibMatrixCuDNNInputRowFetcher doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout)) {
                        for (int n = 0; n < N; ++n) {
                            // Argument order is (w, dy, dx): filter first, then the dout row.
                            LibMatrixCuDNN.cudnnConv2dBackwardData(gCtx, instName, filterPointer, doutFetcher.getNthRow(n), dstPointer.withByteOffset((long)n * CHW * (long)sizeOfDataType), algo);
                        }
                    }
                }
            }
        } else {
            LibMatrixCuDNN.throwCuDNNDimensionError(N, CHW, N, KPQ, K, CRS);
        }
    }

    /**
     * Issues a single {@code cudnnConvolutionBackwardData} call using the descriptors,
     * algorithm and workspace held by {@code algo}. The call uses {@code one()} and
     * {@code zero()} as the two scaling arguments.
     *
     * @param gCtx GPU context providing the cuDNN handle
     * @param instName instruction name (not used by this method)
     * @param w device pointer paired with {@code algo.filterDesc}
     * @param dy device pointer paired with {@code algo.nkpqTensorDesc}
     * @param dx device pointer paired with {@code algo.nchwTensorDesc}; receives the result
     * @param algo pre-built descriptors, algorithm id and workspace
     * @throws DMLRuntimeException if cuDNN returns a non-zero status or a
     *         {@link CudaException} is raised
     */
    private static void cudnnConv2dBackwardData(GPUContext gCtx, String instName, Pointer w, Pointer dy, Pointer dx, LibMatrixCuDNNConvolutionAlgorithm algo) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : conv2dBackwardData, GPUContext=" + gCtx);
        }
        try {
            final cudnnHandle handle = getCudnnHandle(gCtx);
            final int rc = JCudnn.cudnnConvolutionBackwardData(
                    handle,
                    one(), algo.filterDesc, w,
                    algo.nkpqTensorDesc, dy,
                    algo.convDesc, algo.algo,
                    algo.workSpace, algo.sizeInBytes,
                    zero(), algo.nchwTensorDesc, dx);
            if (rc != cudnnStatus.CUDNN_STATUS_SUCCESS) {
                throw new DMLRuntimeException("Could not executed cudnnConvolutionBackwardData: " + cudnnStatus.stringFor(rc));
            }
        }
        catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Runs the cuDNN pooling forward pass of {@code image} into {@code outputBlock}.
     *
     * <p>If the estimated sparse-to-dense overhead of the image fits
     * {@code intermediateMemoryBudget}, the whole batch is pooled in one call. Otherwise the
     * image is pooled one row at a time, writing row {@code n} of the result at byte offset
     * {@code n*C*P*Q*sizeOfDataType}.
     *
     * @param gCtx GPU context
     * @param instName instruction name forwarded to pointer lookups
     * @param image input matrix
     * @param outputBlock output matrix
     * @param N N dimension (number of rows processed in the row-wise path)
     * @param C C dimension
     * @param H H dimension
     * @param W W dimension
     * @param K K dimension, forwarded to the descriptor helper
     * @param R R dimension
     * @param S S dimension
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P P dimension
     * @param Q Q dimension
     * @param poolingType pooling variant forwarded to the descriptor helper
     * @param intermediateMemoryBudget budget (bytes) for intermediates
     * @throws DMLRuntimeException if N*C*H*W or N*C*P*Q is not below
     *         {@code maxNumElementsOfCuDNNTensor}
     */
    public static void pooling(GPUContext gCtx, String instName, MatrixObject image, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, LibMatrixDNN.PoolingType poolingType, double intermediateMemoryBudget) {
        // Widen to long BEFORE multiplying so large shapes cannot overflow int arithmetic.
        long CHW = (long)C * H * W;
        long CPQ = (long)C * P * Q;
        long NCHW = (long)N * CHW;
        long NCPQ = (long)N * CPQ;
        if (NCHW < maxNumElementsOfCuDNNTensor && NCPQ < maxNumElementsOfCuDNNTensor) {
            // Estimated cost of converting a sparse image to dense.
            long overhead = LibMatrixCuDNN.isInSparseFormat(gCtx, image) ? OptimizerUtils.estimateSizeExactSparsity((long)N, CHW, 1.0) : 0L;
            Pointer y = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, outputBlock, instName);
            if ((double)overhead <= intermediateMemoryBudget) {
                // Whole batch in one call.
                Pointer x = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, image, instName);
                LibMatrixCuDNN.cudnnPoolingHelper(gCtx, instName, x, y, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolingType);
            } else {
                // Row-wise fallback: each call pools a batch of size 1.
                try (LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image)) {
                    for (int n = 0; n < N; ++n) {
                        LibMatrixCuDNN.cudnnPoolingHelper(gCtx, instName, imgFetcher.getNthRow(n), y.withByteOffset((long)n * CPQ * (long)sizeOfDataType), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolingType);
                    }
                }
            }
        } else {
            LibMatrixCuDNN.throwCuDNNDimensionError(N, CHW, N, CPQ);
        }
    }

    /**
     * Builds the pooling descriptors for the given shape and issues one
     * {@code cudnnPoolingForward} call from {@code x} into {@code y}. The descriptors are
     * closed when the call finishes.
     *
     * @param gCtx GPU context providing the cuDNN handle
     * @param instName instruction name forwarded to the descriptor helper
     * @param x device pointer of the input
     * @param y device pointer of the output
     * @param N N dimension
     * @param C C dimension
     * @param H H dimension
     * @param W W dimension
     * @param K K dimension
     * @param R R dimension
     * @param S S dimension
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P P dimension
     * @param Q Q dimension
     * @param poolingType pooling variant
     * @throws DMLRuntimeException if cuDNN returns a non-zero status or a
     *         {@link CudaException} is raised
     */
    private static void cudnnPoolingHelper(GPUContext gCtx, String instName, Pointer x, Pointer y, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, LibMatrixDNN.PoolingType poolingType) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : perform pooling, GPUContext=" + gCtx);
        }
        try (LibMatrixCuDNNPoolingDescriptors poolDescs = LibMatrixCuDNNPoolingDescriptors.cudnnPoolingDescriptors(
                gCtx, instName, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolingType)) {
            final cudnnHandle handle = getCudnnHandle(gCtx);
            final int rc = JCudnn.cudnnPoolingForward(
                    handle, poolDescs.poolingDesc,
                    one(), poolDescs.xDesc, x,
                    zero(), poolDescs.yDesc, y);
            if (rc != cudnnStatus.CUDNN_STATUS_SUCCESS) {
                throw new DMLRuntimeException("Could not executed cudnnPoolingForward: " + cudnnStatus.stringFor(rc));
            }
        }
        catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Runs the cuDNN pooling backward pass, writing the input gradient to {@code outputBlock}.
     *
     * <p>If the estimated sparse-to-dense overhead of {@code image} and {@code dout} fits
     * {@code intermediateMemoryBudget}, the whole batch is processed in one call. Otherwise
     * the operands are processed one row at a time, writing row {@code n} of the result at
     * byte offset {@code n*C*H*W*sizeOfDataType}.
     *
     * @param gCtx GPU context
     * @param instName instruction name forwarded to pointer lookups
     * @param image input of the forward pooling
     * @param dout incoming gradient
     * @param maxpoolOutput output of the forward pooling, or {@code null}; when null the
     *        helper recomputes it
     * @param outputBlock input gradient (result)
     * @param N N dimension (number of rows processed in the row-wise path)
     * @param C C dimension
     * @param H H dimension
     * @param W W dimension
     * @param K K dimension, forwarded to the descriptor helper
     * @param R R dimension
     * @param S S dimension
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P P dimension
     * @param Q Q dimension
     * @param poolingType pooling variant forwarded to the descriptor helper
     * @param intermediateMemoryBudget budget (bytes) for intermediates
     * @throws DMLRuntimeException if N*C*H*W or N*C*P*Q is not below
     *         {@code maxNumElementsOfCuDNNTensor}
     */
    public static void poolingBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout, MatrixObject maxpoolOutput, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, LibMatrixDNN.PoolingType poolingType, double intermediateMemoryBudget) {
        // Widen to long BEFORE multiplying so large shapes cannot overflow int arithmetic.
        long CHW = (long)C * H * W;
        long CPQ = (long)C * P * Q;
        long NCHW = (long)N * CHW;
        long NCPQ = (long)N * CPQ;
        final boolean isMaxPoolOutputProvided = maxpoolOutput != null;
        if (NCHW < maxNumElementsOfCuDNNTensor && NCPQ < maxNumElementsOfCuDNNTensor) {
            // Estimated cost of converting sparse operands to dense.
            long overhead = LibMatrixCuDNN.isInSparseFormat(gCtx, image) ? OptimizerUtils.estimateSizeExactSparsity((long)N, CHW, 1.0) : 0L;
            overhead += LibMatrixCuDNN.isInSparseFormat(gCtx, dout) ? OptimizerUtils.estimateSizeExactSparsity((long)N, CPQ, 1.0) : 0L;
            Pointer dx = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, outputBlock, instName);
            if ((double)overhead <= intermediateMemoryBudget) {
                // Whole batch in one call.
                Pointer x = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, image, instName);
                Pointer dy = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, dout, instName);
                Pointer y = isMaxPoolOutputProvided ? LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, maxpoolOutput, instName) : null;
                LibMatrixCuDNN.cudnnPoolingBackwardHelper(gCtx, instName, x, dy, y, dx, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolingType);
            } else {
                // Row-wise fallback. try-with-resources closes every fetcher even when a row
                // fails; a null resource (no max-pool output supplied) is simply skipped.
                try (LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
                     LibMatrixCuDNNInputRowFetcher doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout);
                     LibMatrixCuDNNInputRowFetcher maxPoolOutFetcher = isMaxPoolOutputProvided ? new LibMatrixCuDNNInputRowFetcher(gCtx, instName, maxpoolOutput) : null) {
                    for (int n = 0; n < N; ++n) {
                        Pointer x = imgFetcher.getNthRow(n);
                        Pointer dy = doutFetcher.getNthRow(n);
                        Pointer y = isMaxPoolOutputProvided ? maxPoolOutFetcher.getNthRow(n) : null;
                        LibMatrixCuDNN.cudnnPoolingBackwardHelper(gCtx, instName, x, dy, y, dx.withByteOffset((long)n * CHW * (long)sizeOfDataType), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolingType);
                    }
                }
            }
        } else {
            LibMatrixCuDNN.throwCuDNNDimensionError(N, CHW, N, CPQ);
        }
    }

    /**
     * Builds the pooling-backward descriptors for the given shape and issues one
     * {@code cudnnPoolingBackward} call, writing the input gradient to {@code dx}.
     *
     * <p>When {@code y} is {@code null}, a scratch buffer of {@code N*C*P*Q*sizeOfDataType}
     * bytes is allocated, filled by a {@code cudnnPoolingForward} call on {@code x}, used as
     * the forward output for the backward call, and freed before returning.
     *
     * @param gCtx GPU context providing the cuDNN handle and allocator
     * @param instName instruction name forwarded to allocation and the descriptor helper
     * @param x device pointer of the forward input
     * @param dy device pointer of the incoming gradient
     * @param y device pointer of the forward output, or {@code null} to recompute it
     * @param dx device pointer receiving the input gradient
     * @param N N dimension
     * @param C C dimension
     * @param H H dimension
     * @param W W dimension
     * @param K K dimension
     * @param R R dimension
     * @param S S dimension
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P P dimension
     * @param Q Q dimension
     * @param poolingType pooling variant
     * @throws DMLRuntimeException if cuDNN returns a non-zero status or a
     *         {@link CudaException} is raised
     */
    private static void cudnnPoolingBackwardHelper(GPUContext gCtx, String instName, Pointer x, Pointer dy, Pointer y, Pointer dx, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, LibMatrixDNN.PoolingType poolingType) {
        if (LOG.isTraceEnabled()) {
            LOG.trace((Object)("GPU : maxpoolingBackward, GPUContext=" + gCtx));
        }
        final boolean isMaxPoolOutputProvided = y != null;
        // Scratch forward output; stays null unless this method allocates it, so the
        // cleanup below never tries to free a buffer that was never allocated.
        Pointer scratchY = null;
        try (LibMatrixCuDNNPoolingDescriptors desc = LibMatrixCuDNNPoolingDescriptors.cudnnPoolingBackwardDescriptors(gCtx, instName, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolingType)) {
            int status;
            Pointer forwardOut = y;
            if (!isMaxPoolOutputProvided) {
                long numBytes = (long)N * (long)C * (long)P * (long)Q * (long)sizeOfDataType;
                scratchY = gCtx.allocate(instName, numBytes, false);
                forwardOut = scratchY;
                // Recompute the forward output because the caller did not supply it.
                status = JCudnn.cudnnPoolingForward(LibMatrixCuDNN.getCudnnHandle(gCtx), desc.poolingDesc, LibMatrixCuDNN.one(), desc.xDesc, x, LibMatrixCuDNN.zero(), desc.yDesc, forwardOut);
                if (status != 0) {
                    throw new DMLRuntimeException("Could not executed cudnnPoolingForward before cudnnPoolingBackward: " + cudnnStatus.stringFor(status));
                }
            }
            status = JCudnn.cudnnPoolingBackward(LibMatrixCuDNN.getCudnnHandle(gCtx), desc.poolingDesc, LibMatrixCuDNN.one(), desc.yDesc, forwardOut, desc.dyDesc, dy, desc.xDesc, x, LibMatrixCuDNN.zero(), desc.dxDesc, dx);
            if (status != 0) {
                throw new DMLRuntimeException("Could not executed cudnnPoolingBackward: " + cudnnStatus.stringFor(status));
            }
        }
        catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
        finally {
            if (scratchY != null) {
                gCtx.cudaFreeHelper(instName, scratchY, DMLScript.EAGER_CUDA_FREE);
            }
        }
    }

    /**
     * Applies a cuDNN activation forward pass to {@code in}, writing the result to
     * {@code dstData}. The same tensor descriptor is used for source and destination.
     *
     * <p>The activation descriptor created here is destroyed before returning, including
     * when the activation call fails.
     *
     * @param gCtx GPU context providing the cuDNN handle
     * @param instName instruction name forwarded to the pointer lookup
     * @param in input matrix
     * @param dstData device pointer receiving the result
     * @param srcTensorDesc tensor descriptor describing both input and output
     * @throws DMLRuntimeException if a {@link CudaException} is raised
     */
    private static void cudnnReLU(GPUContext gCtx, String instName, MatrixObject in, Pointer dstData, cudnnTensorDescriptor srcTensorDesc) {
        // Set only after cudnnCreateActivationDescriptor succeeded, so that the cleanup
        // never destroys a descriptor that was not created.
        cudnnActivationDescriptor createdDescriptor = null;
        try {
            if (LOG.isTraceEnabled()) {
                LOG.trace((Object)("GPU : performCuDNNReLU, GPUContext=" + gCtx));
            }
            cudnnTensorDescriptor dstTensorDesc = srcTensorDesc;
            Pointer srcData = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, in, instName);
            cudnnActivationDescriptor activationDescriptor = new cudnnActivationDescriptor();
            JCudnn.cudnnCreateActivationDescriptor(activationDescriptor);
            createdDescriptor = activationDescriptor;
            double dummy = -1.0;
            // The two literal 1s are the activation mode and NaN-propagation selectors; per the
            // cuDNN enums these should be CUDNN_ACTIVATION_RELU and CUDNN_PROPAGATE_NAN.
            JCudnn.cudnnSetActivationDescriptor(activationDescriptor, 1, 1, dummy);
            JCudnn.cudnnActivationForward(LibMatrixCuDNN.getCudnnHandle(gCtx), activationDescriptor, LibMatrixCuDNN.one(), srcTensorDesc, srcData, LibMatrixCuDNN.zero(), dstTensorDesc, dstData);
        }
        catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
        finally {
            if (createdDescriptor != null) {
                JCudnn.cudnnDestroyActivationDescriptor(createdDescriptor);
            }
        }
    }

    /**
     * Computes ReLU of {@code in} into the dense output registered under {@code outputName}.
     *
     * <p>When {@code rows*cols} is at or above {@code maxNumElementsOfCuDNNTensor}, the
     * custom "relu" kernel is launched; otherwise the cuDNN path ({@code cudnnReLU}) is used
     * with a 4D tensor descriptor of shape {@code (rows, 1, 1, cols)}.
     *
     * @param ec execution context; its GPU context 0 must be {@code gCtx}
     * @param gCtx GPU context
     * @param instName instruction name forwarded to pointer lookups
     * @param in input matrix
     * @param outputName variable name of the output matrix
     * @throws DMLRuntimeException if {@code ec.getGPUContext(0)} is not {@code gCtx}
     */
    public static void relu(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in, String outputName) {
        if (ec.getGPUContext(0) != gCtx) {
            throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
        }
        long N = in.getNumRows();
        long CHW = in.getNumColumns();
        Pointer dstData = LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, outputName, in.getNumRows(), in.getNumColumns());
        if (N * CHW >= maxNumElementsOfCuDNNTensor) {
            // Too many elements for a cuDNN tensor: fall back to the custom kernel.
            if (LOG.isTraceEnabled()) {
                LOG.trace((Object)("GPU : relu custom kernel, GPUContext=" + gCtx));
            }
            Pointer srcData = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, in, instName);
            LibMatrixCuDNN.getCudaKernels(gCtx).launchKernel("relu", ExecutionConfig.getConfigForSimpleMatrixOperations(LibMatrixCuDNN.toInt(N), LibMatrixCuDNN.toInt(CHW)), srcData, dstData, LibMatrixCuDNN.toInt(N), LibMatrixCuDNN.toInt(CHW));
        } else {
            cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
            JCudnn.cudnnCreateTensorDescriptor(tensorDescriptor);
            try {
                // Tensor format selector 0; per the cuDNN enum this should be CUDNN_TENSOR_NCHW.
                JCudnn.cudnnSetTensor4dDescriptor(tensorDescriptor, 0, CUDNN_DATA_TYPE, LibMatrixCuDNN.toInt(N), 1, 1, LibMatrixCuDNN.toInt(CHW));
                LibMatrixCuDNN.cudnnReLU(gCtx, instName, in, dstData, tensorDescriptor);
            } finally {
                // Release the native descriptor even if the activation call throws.
                JCudnn.cudnnDestroyTensorDescriptor(tensorDescriptor);
            }
        }
    }

    /**
     * Fetches the matrix registered as GPU-instruction input under {@code inputName} and
     * returns its dense device pointer for the given dimensions.
     *
     * @param ec execution context used to look up the input
     * @param gCtx GPU context
     * @param instName instruction name
     * @param inputName variable name of the input matrix
     * @param numRows number of rows; converted with {@code toInt}
     * @param numCols number of columns; converted with {@code toInt}
     * @return dense device pointer of the input
     * @throws DMLRuntimeException propagated from the lookups
     */
    static Pointer getDenseInputPointer(ExecutionContext ec, GPUContext gCtx, String instName, String inputName, long numRows, long numCols) throws DMLRuntimeException {
        final MatrixObject input = ec.getMatrixInputForGPUInstruction(inputName, instName);
        return getDensePointerForCuDNN(gCtx, input, instName, toInt(numRows), toInt(numCols));
    }

    /**
     * Looks up the matrix registered under {@code outputName}, requests a dense GPU output
     * of the given dimensions for it, and returns its dense device pointer.
     *
     * @param ec execution context used to look up and allocate the output
     * @param gCtx GPU context
     * @param instName instruction name
     * @param outputName variable name of the output matrix
     * @param numRows number of rows of the output
     * @param numCols number of columns of the output
     * @return dense device pointer of the output
     * @throws DMLRuntimeException propagated from the lookups
     */
    static Pointer getDenseOutputPointer(ExecutionContext ec, GPUContext gCtx, String instName, String outputName, long numRows, long numCols) throws DMLRuntimeException {
        // The matrix object is looked up first, then its dense GPU output is requested.
        final MatrixObject result = ec.getMatrixObject(outputName);
        getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, numRows, numCols);
        return getDensePointerForCuDNN(gCtx, result, instName, toInt(numRows), toInt(numCols));
    }

    /**
     * Public entry point for the LSTM forward pass. Delegates to
     * {@code singleLayerUnidirectionalRNNForward} with the RNN mode fixed to {@code "lstm"},
     * passing {@code out0} as the initial hidden state and {@code c0} as the initial cell state.
     *
     * @param ec execution context
     * @param gCtx GPU context
     * @param instName name of the invoking instruction
     * @param X device pointer of the input sequence
     * @param wPointer device pointer of the weights
     * @param out0 device pointer of the initial hidden state
     * @param c0 device pointer of the initial cell state
     * @param return_sequences whether the output holds all time steps or only the last one
     * @param outputName variable name of the output matrix
     * @param cyName variable name of the final cell state output
     * @param N forwarded as {@code N} (presumably the batch size)
     * @param M forwarded as {@code M} (presumably the number of hidden units)
     * @param D forwarded as {@code D} (presumably the number of input features)
     * @param T forwarded as {@code T} (number of time steps passed to cuDNN)
     * @throws DMLRuntimeException if the forward pass fails
     */
    public static void lstm(ExecutionContext ec, GPUContext gCtx, String instName, Pointer X, Pointer wPointer, Pointer out0, Pointer c0, boolean return_sequences, String outputName, String cyName, int N, int M, int D, int T) throws DMLRuntimeException {
        final String rnnMode = "lstm";
        LibMatrixCuDNN.singleLayerUnidirectionalRNNForward(
            ec, gCtx, instName,
            X, out0, c0, wPointer,
            outputName, cyName,
            rnnMode, return_sequences,
            N, M, D, T);
    }

    /**
     * Runs the cuDNN training-mode forward pass of a single-layer RNN and writes the result into
     * the SystemDS output variables.
     *
     * <p>cuDNN always writes the output of all {@code T} time steps into a temporary buffer of
     * {@code N*T*M} cells. If {@code return_sequences} is false, the final hidden state is written
     * directly into {@code outputName} ([N, M]). Otherwise the final hidden state goes to a scratch
     * buffer and the full sequence is copied into {@code outputName} ([N, T*M]) by the
     * {@code prepare_lstm_output} kernel. For mode {@code "lstm"} the final cell state is written
     * into {@code cyName} ([N, M]); for any other mode an empty {@link Pointer} is passed.
     *
     * <p>Compared to the previous version, the cuDNN status is now checked, the element count handed
     * to the kernel is computed in {@code long} before narrowing, and the temporary device buffers
     * are released even if a step fails.
     *
     * @param ec execution context used to obtain the output matrices
     * @param gCtx GPU context
     * @param instName name of the invoking instruction
     * @param x device pointer of the input sequence
     * @param hx device pointer of the initial hidden state
     * @param cx device pointer of the initial cell state
     * @param wPointer device pointer of the weights
     * @param outputName variable name of the output matrix
     * @param cyName variable name of the final cell state (only used for mode "lstm")
     * @param rnnMode RNN mode; compared case-insensitively against "lstm"
     * @param return_sequences whether the output holds all time steps or only the last one
     * @param N first dimension of the outputs
     * @param M hidden-state width of the outputs
     * @param D passed to the cuDNN algorithm wrapper (presumably the number of input features)
     * @param T number of time steps
     * @throws DMLRuntimeException if cuDNN reports an error status or an output cannot be obtained
     */
    private static void singleLayerUnidirectionalRNNForward(ExecutionContext ec, GPUContext gCtx, String instName, Pointer x, Pointer hx, Pointer cx, Pointer wPointer, String outputName, String cyName, String rnnMode, boolean return_sequences, int N, int M, int D, int T) throws DMLRuntimeException {
        boolean hasCarry = rnnMode.equalsIgnoreCase("lstm");
        Pointer cudnnYPointer = gCtx.allocate(instName, (long)N * (long)T * (long)M * (long)sizeOfDataType, false);
        // Scratch buffer for the final hidden state; only allocated when the full sequence is returned.
        Pointer scratchHyPointer = null;
        try {
            Pointer hyPointer;
            if (return_sequences) {
                scratchHyPointer = gCtx.allocate(instName, (long)N * (long)M * (long)sizeOfDataType, false);
                hyPointer = scratchHyPointer;
            } else {
                hyPointer = LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, outputName, N, M);
            }
            Pointer cyPointer = hasCarry ? LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, cyName, N, M) : new Pointer();
            try (LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, rnnMode, N, T, M, D, true, wPointer)) {
                // Surface cuDNN failures instead of silently continuing with garbage output.
                LibMatrixCuDNN.checkStatus(JCudnn.cudnnRNNForwardTraining(
                    gCtx.getCudnnHandle(), algo.rnnDesc, T,
                    algo.xDesc, x,
                    algo.hxDesc, hx,
                    algo.cxDesc, cx,
                    algo.wDesc, wPointer,
                    algo.yDesc, cudnnYPointer,
                    algo.hyDesc, hyPointer,
                    algo.cyDesc, cyPointer,
                    algo.workSpace, algo.sizeInBytes,
                    algo.reserveSpace, algo.reserveSpaceSizeInBytes));
            }
            if (return_sequences) {
                // Release the scratch state before obtaining the (larger) output, as before.
                Pointer released = scratchHyPointer;
                scratchHyPointer = null;
                gCtx.cudaFreeHelper(instName, released, DMLScript.EAGER_CUDA_FREE);
                // Multiply in long so an oversized shape cannot wrap around in int arithmetic;
                // toInt is expected to reject values that do not fit into an int.
                int numOutputCells = LibMatrixCuDNN.toInt((long)N * (long)T * (long)M);
                Pointer sysdsYPointer = LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, outputName, N, (long)T * (long)M);
                LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_output",
                    ExecutionConfig.getConfigForSimpleVectorOperations(numOutputCells),
                    sysdsYPointer, cudnnYPointer, N, T, M, numOutputCells);
            }
        } finally {
            try {
                if (scratchHyPointer != null) {
                    gCtx.cudaFreeHelper(instName, scratchHyPointer, DMLScript.EAGER_CUDA_FREE);
                }
            } finally {
                gCtx.cudaFreeHelper(instName, cudnnYPointer, DMLScript.EAGER_CUDA_FREE);
            }
        }
    }

    /**
     * Computes the LSTM backward pass with cuDNN.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>The incoming gradient {@code doutName} is copied into a temporary {@code N*T*M} buffer by
     *       the {@code prepare_lstm_backward_gradients} kernel and the input is released.</li>
     *   <li>The forward pass is re-run via {@code cudnnRNNForwardTraining} into a temporary y buffer.</li>
     *   <li>{@code cudnnRNNBackwardData} writes into {@code dhxName} and {@code dcxName} ([N, M]) and
     *       into a temporary dx buffer, which {@code prepare_lstm_dinput} copies to {@code dxName}
     *       ([N, T*D]).</li>
     *   <li>{@code cudnnRNNBackwardWeights} writes into a temporary weight-gradient buffer, which
     *       {@code prepare_lstm_dweight} splits into {@code dwName} ([D+M, 4M]) and {@code dbName}
     *       ([1, 4M]).</li>
     * </ol>
     *
     * <p>Compared to the previous version, every cuDNN status is now checked, and all element counts
     * handed to kernels are computed in {@code long} and narrowed once up front, so an oversized
     * shape fails before any device memory is allocated instead of wrapping around in {@code int}.
     *
     * <p>NOTE(review): temporary device buffers and acquired inputs/outputs are still not released
     * when a step throws; this matches the previous behavior.
     *
     * @param ec execution context used to obtain inputs and outputs
     * @param gCtx GPU context
     * @param instName name of the invoking instruction
     * @param x device pointer of the input sequence
     * @param hx device pointer of the initial hidden state
     * @param cx device pointer of the initial cell state
     * @param wPointer device pointer of the weights
     * @param doutName variable name of the gradient w.r.t. the output
     * @param dcyName variable name of the gradient w.r.t. the final cell state
     * @param dxName variable name of the output gradient w.r.t. the input
     * @param dwName variable name of the output gradient w.r.t. the weights
     * @param dbName variable name of the output gradient w.r.t. the bias
     * @param dhxName variable name of the output gradient w.r.t. the initial hidden state
     * @param dcxName variable name of the output gradient w.r.t. the initial cell state
     * @param return_sequences whether {@code doutName} holds all time steps ([N, T*M]) or one ([N, M])
     * @param N first dimension of the inputs and outputs
     * @param M hidden-state width
     * @param D input width
     * @param T number of time steps
     * @throws DMLRuntimeException if cuDNN reports an error status or a matrix cannot be obtained
     */
    public static void lstmBackward(ExecutionContext ec, GPUContext gCtx, String instName, Pointer x, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName, String dxName, String dwName, String dbName, String dhxName, String dcxName, boolean return_sequences, int N, int M, int D, int T) throws DMLRuntimeException {
        // Narrow all kernel sizes once, from long arithmetic; toInt is expected to reject values
        // that do not fit into an int.
        int size = LibMatrixCuDNN.toInt(return_sequences ? (long)N * (long)T * (long)M : (long)N * (long)M);
        int dxCols = LibMatrixCuDNN.toInt((long)T * (long)D);
        int numDxCells = LibMatrixCuDNN.toInt((long)N * (long)T * (long)D);
        int fourM = LibMatrixCuDNN.toInt(4L * (long)M);
        long dwRows = (long)D + (long)M;
        long numDwCellsLong = (dwRows + 2L) * 4L * (long)M;
        int numDwCells = LibMatrixCuDNN.toInt(numDwCellsLong);

        // Step 1: bring the incoming gradient into the temporary dy buffer.
        Pointer dy = gCtx.allocate(instName, (long)N * (long)T * (long)M * (long)sizeOfDataType, false);
        Pointer doutPointer = LibMatrixCuDNN.getDenseInputPointer(ec, gCtx, instName, doutName, N, return_sequences ? (long)T * (long)M : (long)M);
        LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_backward_gradients",
            ExecutionConfig.getConfigForSimpleVectorOperations(size),
            doutPointer, dy, N, T, M, size, return_sequences ? 1 : 0);
        ec.releaseMatrixInputForGPUInstruction(doutName);

        Pointer yPointer = gCtx.allocate(instName, (long)N * (long)T * (long)M * (long)sizeOfDataType, false);
        try (LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, "lstm", N, T, M, D, true, wPointer)) {
            // Step 2: re-run the forward pass; the final hidden/cell states are not needed here.
            LibMatrixCuDNN.checkStatus(JCudnn.cudnnRNNForwardTraining(
                gCtx.getCudnnHandle(), algo.rnnDesc, T,
                algo.xDesc, x,
                algo.hxDesc, hx,
                algo.cxDesc, cx,
                algo.wDesc, wPointer,
                algo.yDesc, yPointer,
                algo.hyDesc, new Pointer(),
                algo.cyDesc, new Pointer(),
                algo.workSpace, algo.sizeInBytes,
                algo.reserveSpace, algo.reserveSpaceSizeInBytes));

            // Step 3: gradients w.r.t. the data (dx, dhx, dcx).
            Pointer cudnnDx = gCtx.allocate(instName, (long)N * (long)T * (long)D * (long)LibMatrixCUDA.sizeOfDataType, false);
            Pointer dcyPointer = LibMatrixCuDNN.getDenseInputPointer(ec, gCtx, instName, dcyName, N, M);
            Pointer dhxPointer = LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, dhxName, N, M);
            Pointer dcxPointer = LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, dcxName, N, M);
            LibMatrixCuDNN.checkStatus(JCudnn.cudnnRNNBackwardData(
                gCtx.getCudnnHandle(), algo.rnnDesc, T,
                algo.yDesc, yPointer,
                algo.dyDesc, dy,
                algo.dhyDesc, new Pointer(),
                algo.dcyDesc, dcyPointer,
                algo.wDesc, wPointer,
                algo.hxDesc, hx,
                algo.cxDesc, cx,
                algo.dxDesc, cudnnDx,
                algo.dhxDesc, dhxPointer,
                algo.dcxDesc, dcxPointer,
                algo.workSpace, algo.sizeInBytes,
                algo.reserveSpace, algo.reserveSpaceSizeInBytes));
            gCtx.cudaFreeHelper(instName, dy, DMLScript.EAGER_CUDA_FREE);
            ec.releaseMatrixInputForGPUInstruction(dcyName);
            ec.releaseMatrixOutputForGPUInstruction(dhxName);
            ec.releaseMatrixOutputForGPUInstruction(dcxName);

            Pointer smlDx = LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, dxName, N, dxCols);
            LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dinput",
                ExecutionConfig.getConfigForSimpleVectorOperations(numDxCells),
                smlDx, cudnnDx, N, D, dxCols, numDxCells);
            ec.releaseMatrixOutputForGPUInstruction(dxName);
            gCtx.cudaFreeHelper(instName, cudnnDx, DMLScript.EAGER_CUDA_FREE);

            // Step 4: gradients w.r.t. the weights, then split into dw and db.
            Pointer cudnnDwPointer = gCtx.allocate(instName, numDwCellsLong * (long)LibMatrixCUDA.sizeOfDataType, false);
            LibMatrixCuDNN.checkStatus(JCudnn.cudnnRNNBackwardWeights(
                gCtx.getCudnnHandle(), algo.rnnDesc, T,
                algo.xDesc, x,
                algo.hxDesc, hx,
                algo.yDesc, yPointer,
                algo.workSpace, algo.sizeInBytes,
                algo.dwDesc, cudnnDwPointer,
                algo.reserveSpace, algo.reserveSpaceSizeInBytes));
            Pointer dwPointer = LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, dwName, dwRows, fourM);
            Pointer dbPointer = LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, dbName, 1L, fourM);
            LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dweight",
                ExecutionConfig.getConfigForSimpleVectorOperations(numDwCells),
                dwPointer, dbPointer, cudnnDwPointer, D, M);
            gCtx.cudaFreeHelper(instName, cudnnDwPointer, DMLScript.EAGER_CUDA_FREE);
            ec.releaseMatrixOutputForGPUInstruction(dwName);
            ec.releaseMatrixOutputForGPUInstruction(dbName);
            gCtx.cudaFreeHelper(instName, yPointer, DMLScript.EAGER_CUDA_FREE);
        }
    }

    /**
     * Performs the training-mode forward pass of batch normalization via
     * {@code cudnnBatchNormalizationForwardTraining}.
     *
     * <p>The running mean and variance are first copied device-to-device into
     * {@code retRunningMean} / {@code retRunningVar}, which cuDNN then updates in place. The
     * normalized output is written to {@code ret}, and the batch statistics are written to
     * {@code resultSaveMean} / {@code resultSaveInvVariance}.
     *
     * <p>Compared to the previous version, both cuDNN tensor descriptors are destroyed when the
     * method exits (they were leaked before), and the copy size is computed in {@code long}.
     * This assumes {@code allocateTensorDescriptor} hands out a freshly created descriptor owned
     * by the caller, like {@code allocateNCHWDescriptors} does.
     *
     * @param gCtx GPU context
     * @param instName name of the invoking instruction
     * @param image input matrix; its row count is used as N and its column count as C*H*W
     * @param scale scale vector; its row count is used as C
     * @param bias bias vector
     * @param runningMean running mean before the update
     * @param runningVar running variance before the update
     * @param ret normalized output
     * @param retRunningMean updated running mean (output)
     * @param retRunningVar updated running variance (output)
     * @param epsilon epsilon passed to cuDNN
     * @param exponentialAverageFactor averaging factor passed to cuDNN for the running statistics
     * @param resultSaveMean batch mean saved by cuDNN (output)
     * @param resultSaveInvVariance batch inverse variance saved by cuDNN (output)
     * @throws DMLRuntimeException if the dimensions are invalid or cuDNN reports an error status
     */
    public static void batchNormalizationForwardTraining(GPUContext gCtx, String instName, MatrixObject image, MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar, MatrixObject ret, MatrixObject retRunningMean, MatrixObject retRunningVar, double epsilon, double exponentialAverageFactor, MatrixObject resultSaveMean, MatrixObject resultSaveInvVariance) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : batchNormalizationForwardTraining, GPUContext=" + gCtx);
        }
        int N = LibMatrixCuDNN.toInt(image.getNumRows());
        int C = LibMatrixCuDNN.toInt(scale.getNumRows());
        long CHW = image.getNumColumns();
        LibMatrixCuDNN.validateBatchNormalizationDimensions(scale, bias, runningMean, runningVar, C);
        cudnnTensorDescriptor nCHWDescriptor = LibMatrixCuDNN.allocateNCHWDescriptors(gCtx, N, C, CHW, new MatrixObject[]{image}, new MatrixObject[]{ret});
        cudnnTensorDescriptor scaleTensorDesc = null;
        try {
            scaleTensorDesc = LibMatrixCuDNN.allocateTensorDescriptor(1, C, 1, 1);
            Pointer imagePtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, image, instName);
            Pointer retPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, ret, instName);
            Pointer biasPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName);
            Pointer scalePtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, scale, instName);
            Pointer runningMeanPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, runningMean, instName);
            Pointer runningVarPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, runningVar, instName);
            Pointer retRunningMeanPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, retRunningMean, instName);
            Pointer retRunningVarPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, retRunningVar, instName);
            // Seed the outputs with the current running statistics; cuDNN updates them in place.
            // Copy kind 3 is cudaMemcpyDeviceToDevice.
            long statsBytes = (long)C * (long)sizeOfDataType;
            JCuda.cudaMemcpy(retRunningMeanPtr, runningMeanPtr, statsBytes, 3);
            JCuda.cudaMemcpy(retRunningVarPtr, runningVarPtr, statsBytes, 3);
            Pointer resultSaveMeanPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, resultSaveMean, instName);
            Pointer resultSaveInvVariancePtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, resultSaveInvVariance, instName);
            // Mode 1 is CUDNN_BATCHNORM_SPATIAL.
            LibMatrixCuDNN.checkStatus(JCudnn.cudnnBatchNormalizationForwardTraining(
                LibMatrixCuDNN.getCudnnHandle(gCtx), 1,
                LibMatrixCuDNN.one(), LibMatrixCuDNN.zero(),
                nCHWDescriptor, imagePtr,
                nCHWDescriptor, retPtr,
                scaleTensorDesc, scalePtr, biasPtr,
                exponentialAverageFactor, retRunningMeanPtr, retRunningVarPtr,
                epsilon, resultSaveMeanPtr, resultSaveInvVariancePtr));
        } finally {
            // Descriptors are native cuDNN objects; release them on every exit path.
            try {
                if (scaleTensorDesc != null) {
                    JCudnn.cudnnDestroyTensorDescriptor(scaleTensorDesc);
                }
            } finally {
                JCudnn.cudnnDestroyTensorDescriptor(nCHWDescriptor);
            }
        }
    }

    /**
     * Performs the inference-mode forward pass of batch normalization via
     * {@code cudnnBatchNormalizationForwardInference}, normalizing {@code image} with the given
     * running mean and variance and writing the result to {@code ret}.
     *
     * <p>Compared to the previous version, both cuDNN tensor descriptors are destroyed when the
     * method exits (they were leaked before). This assumes {@code allocateTensorDescriptor} hands
     * out a freshly created descriptor owned by the caller, like {@code allocateNCHWDescriptors}
     * does.
     *
     * @param gCtx GPU context
     * @param instName name of the invoking instruction
     * @param image input matrix; its row count is used as N and its column count as C*H*W
     * @param scale scale vector; its row count is used as C
     * @param bias bias vector
     * @param runningMean running mean used for normalization
     * @param runningVar running variance used for normalization
     * @param ret normalized output
     * @param epsilon epsilon passed to cuDNN
     * @throws DMLRuntimeException if the dimensions are invalid or cuDNN reports an error status
     */
    public static void batchNormalizationForwardInference(GPUContext gCtx, String instName, MatrixObject image, MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar, MatrixObject ret, double epsilon) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : batchNormalizationForwardInference, GPUContext=" + gCtx);
        }
        int N = LibMatrixCuDNN.toInt(image.getNumRows());
        int C = LibMatrixCuDNN.toInt(scale.getNumRows());
        long CHW = image.getNumColumns();
        LibMatrixCuDNN.validateBatchNormalizationDimensions(scale, bias, runningMean, runningVar, C);
        cudnnTensorDescriptor nCHWDescriptor = LibMatrixCuDNN.allocateNCHWDescriptors(gCtx, N, C, CHW, new MatrixObject[]{image}, new MatrixObject[]{ret});
        cudnnTensorDescriptor scaleTensorDesc = null;
        try {
            scaleTensorDesc = LibMatrixCuDNN.allocateTensorDescriptor(1, C, 1, 1);
            Pointer imagePtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, image, instName);
            Pointer retPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, ret, instName);
            Pointer biasPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName);
            Pointer scalePtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, scale, instName);
            Pointer runningMeanPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, runningMean, instName);
            Pointer runningVarPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, runningVar, instName);
            // Mode 1 is CUDNN_BATCHNORM_SPATIAL.
            LibMatrixCuDNN.checkStatus(JCudnn.cudnnBatchNormalizationForwardInference(
                LibMatrixCuDNN.getCudnnHandle(gCtx), 1,
                LibMatrixCuDNN.one(), LibMatrixCuDNN.zero(),
                nCHWDescriptor, imagePtr,
                nCHWDescriptor, retPtr,
                scaleTensorDesc, scalePtr, biasPtr,
                runningMeanPtr, runningVarPtr, epsilon));
        } finally {
            // Descriptors are native cuDNN objects; release them on every exit path.
            try {
                if (scaleTensorDesc != null) {
                    JCudnn.cudnnDestroyTensorDescriptor(scaleTensorDesc);
                }
            } finally {
                JCudnn.cudnnDestroyTensorDescriptor(nCHWDescriptor);
            }
        }
    }

    /**
     * Performs the backward pass of batch normalization via
     * {@code cudnnBatchNormalizationBackward}, writing the gradients into {@code dX},
     * {@code dScale} and {@code dBias}.
     *
     * <p>Compared to the previous version, both cuDNN tensor descriptors are destroyed when the
     * method exits (they were leaked before). This assumes {@code allocateTensorDescriptor} hands
     * out a freshly created descriptor owned by the caller, like {@code allocateNCHWDescriptors}
     * does.
     *
     * @param gCtx GPU context
     * @param instName name of the invoking instruction
     * @param image input matrix of the forward pass; its row count is used as N and its column
     *        count as C*H*W
     * @param dout gradient w.r.t. the output of the forward pass
     * @param scale scale vector; its row count is used as C
     * @param dX gradient w.r.t. the input (output)
     * @param dScale gradient w.r.t. the scale (output)
     * @param dBias gradient w.r.t. the bias (output)
     * @param epsilon epsilon passed to cuDNN
     * @param resultSaveMean saved mean passed to cuDNN
     * @param resultSaveInvVariance saved inverse variance passed to cuDNN
     * @throws DMLRuntimeException if a pointer cannot be obtained or cuDNN reports an error status
     */
    public static void batchNormalizationBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout, MatrixObject scale, MatrixObject dX, MatrixObject dScale, MatrixObject dBias, double epsilon, MatrixObject resultSaveMean, MatrixObject resultSaveInvVariance) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : batchNormalizationBackward, GPUContext=" + gCtx);
        }
        int N = LibMatrixCuDNN.toInt(image.getNumRows());
        int C = LibMatrixCuDNN.toInt(scale.getNumRows());
        long CHW = image.getNumColumns();
        cudnnTensorDescriptor nCHWDescriptor = LibMatrixCuDNN.allocateNCHWDescriptors(gCtx, N, C, CHW, new MatrixObject[]{image, dout}, new MatrixObject[]{dX});
        cudnnTensorDescriptor scaleTensorDesc = null;
        try {
            scaleTensorDesc = LibMatrixCuDNN.allocateTensorDescriptor(1, C, 1, 1);
            Pointer imagePtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, image, instName);
            Pointer doutPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, dout, instName);
            Pointer scalePtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, scale, instName);
            Pointer dXPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, dX, instName);
            Pointer dScalePtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, dScale, instName);
            Pointer dBiasPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, dBias, instName);
            Pointer resultSaveMeanPtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, resultSaveMean, instName);
            Pointer resultSaveInvVariancePtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, resultSaveInvVariance, instName);
            // Mode 1 is CUDNN_BATCHNORM_SPATIAL; the two one()/zero() pairs are the alpha/beta
            // blending factors for the data gradient and the parameter gradients respectively.
            LibMatrixCuDNN.checkStatus(JCudnn.cudnnBatchNormalizationBackward(
                LibMatrixCuDNN.getCudnnHandle(gCtx), 1,
                LibMatrixCuDNN.one(), LibMatrixCuDNN.zero(),
                LibMatrixCuDNN.one(), LibMatrixCuDNN.zero(),
                nCHWDescriptor, imagePtr,
                nCHWDescriptor, doutPtr,
                nCHWDescriptor, dXPtr,
                scaleTensorDesc, scalePtr, dScalePtr, dBiasPtr,
                epsilon, resultSaveMeanPtr, resultSaveInvVariancePtr));
        } finally {
            // Descriptors are native cuDNN objects; release them on every exit path.
            try {
                if (scaleTensorDesc != null) {
                    JCudnn.cudnnDestroyTensorDescriptor(scaleTensorDesc);
                }
            } finally {
                JCudnn.cudnnDestroyTensorDescriptor(nCHWDescriptor);
            }
        }
    }

    /**
     * Verifies that scale, bias, running mean and running variance are all column vectors with
     * {@code C} rows. The operands are checked in that order and the first mismatch is reported.
     *
     * @param scale scale vector
     * @param bias bias vector
     * @param runningMean running mean vector
     * @param runningVar running variance vector
     * @param C expected number of rows
     * @throws DMLRuntimeException if any operand is not of shape [C, 1]
     */
    private static void validateBatchNormalizationDimensions(MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar, int C) throws DMLRuntimeException {
        MatrixObject[] operands = {scale, bias, runningMean, runningVar};
        // One message prefix per operand, in the same order as the operands above.
        String[] messagePrefixes = {
            "Incorrect dimensions for scale. Expected a column vector of size ",
            "Incorrect dimensions for bias. Expected a column vector of size ",
            "Incorrect dimensions for running mean. Expected a column vector of size ",
            "Incorrect dimensions for running variance. Expected a column vector of size "
        };
        for (int i = 0; i < operands.length; i++) {
            MatrixObject operand = operands[i];
            boolean isColumnVectorOfSizeC = operand.getNumRows() == (long)C && operand.getNumColumns() == 1L;
            if (isColumnVectorOfSizeC) {
                continue;
            }
            throw new DMLRuntimeException(messagePrefixes[i] + C + ", but found [" + operand.getNumRows() + ", " + operand.getNumColumns() + "]");
        }
    }

    /**
     * Creates a 4-D cuDNN tensor descriptor of shape [N, C, H, W] for a matrix with {@code CHW}
     * columns. Only the product H*W = CHW / C is known, so the split is chosen heuristically:
     * if H*W is a perfect square then H = W = sqrt(H*W), otherwise H = H*W and W = 1.
     *
     * <p>Compared to the previous version, a non-positive {@code C} is rejected with a clear
     * message (it previously caused a misleading "image size" error or an
     * {@code ArithmeticException}), the cuDNN status of creating and configuring the descriptor is
     * checked, and the descriptor is destroyed again if configuring it fails.
     *
     * <p>The returned descriptor is newly created; the caller owns it and is responsible for
     * destroying it. {@code gCtx}, {@code input} and {@code output} are not used by this method.
     *
     * @param gCtx GPU context (unused)
     * @param N first tensor dimension
     * @param C second tensor dimension (number of channels); must be positive
     * @param CHW number of columns of the matrix, i.e. C*H*W
     * @param input input matrices (unused)
     * @param output output matrices (unused)
     * @return a newly created tensor descriptor
     * @throws DMLRuntimeException if {@code C} is not positive, H*W does not fit into an int, or
     *         cuDNN reports an error status
     */
    private static cudnnTensorDescriptor allocateNCHWDescriptors(GPUContext gCtx, int N, int C, long CHW, MatrixObject[] input, MatrixObject[] output) throws DMLRuntimeException {
        if (C <= 0) {
            throw new DMLRuntimeException("Number of channels should be positive, but found " + C);
        }
        if (CHW > Integer.MAX_VALUE * (long)C) {
            throw new DMLRuntimeException("image size (height*width) should be less than 2147483647");
        }
        int HW = (int)(CHW / (long)C);
        int H = HW;
        int W = 1;
        double potentialH = Math.sqrt(HW);
        if (potentialH == (double)((int)potentialH)) {
            H = (int)potentialH;
            W = H;
        }
        cudnnTensorDescriptor ret = new cudnnTensorDescriptor();
        LibMatrixCuDNN.checkStatus(JCudnn.cudnnCreateTensorDescriptor(ret));
        try {
            // Tensor format 0 is CUDNN_TENSOR_NCHW.
            LibMatrixCuDNN.checkStatus(JCudnn.cudnnSetTensor4dDescriptor(ret, 0, CUDNN_DATA_TYPE, N, C, H, W));
        } catch (RuntimeException e) {
            // Do not leak the native descriptor when it cannot be configured.
            JCudnn.cudnnDestroyTensorDescriptor(ret);
            throw e;
        }
        return ret;
    }

    /**
     * Returns the dense device pointer of the given matrix after checking that its number of
     * cells (rows * columns) does not exceed {@code maxNumElementsOfCuDNNTensor}.
     *
     * @param gCtx GPU context
     * @param image matrix whose dense pointer is requested
     * @param instName name of the invoking instruction
     * @return dense device pointer of the matrix
     * @throws DMLRuntimeException if the matrix has too many cells for cuDNN
     */
    protected static Pointer getDensePointerForCuDNN(GPUContext gCtx, MatrixObject image, String instName) {
        long rows = image.getNumRows();
        long cols = image.getNumColumns();
        long cellCount = rows * cols;
        if (cellCount <= maxNumElementsOfCuDNNTensor) {
            return LibMatrixCuDNN.getDensePointer(gCtx, image, instName);
        }
        throw new DMLRuntimeException("CuDNN restriction: the size of input tensor cannot have greater than 2 giga-elements, but has " + cellCount + " (i.e. [" + rows + " X " + cols + "]). Hint: try reducing the mini-batch size.");
    }

    /**
     * Returns the dense device pointer of the given matrix after three checks, in this order:
     * the matrix has exactly the expected shape, its number of cells does not exceed
     * {@code maxNumElementsOfCuDNNTensor}, and the size recorded by the memory manager for the
     * pointer equals cells * {@code sizeOfDataType}.
     *
     * @param gCtx GPU context
     * @param image matrix whose dense pointer is requested
     * @param instName name of the invoking instruction
     * @param numRows expected number of rows
     * @param numCols expected number of columns
     * @return dense device pointer of the matrix
     * @throws DMLRuntimeException if any of the checks fails
     */
    public static Pointer getDensePointerForCuDNN(GPUContext gCtx, MatrixObject image, String instName, int numRows, int numCols) throws DMLRuntimeException {
        long actualRows = image.getNumRows();
        long actualCols = image.getNumColumns();
        boolean shapeMatches = actualRows == (long)numRows && actualCols == (long)numCols;
        if (!shapeMatches) {
            throw new DMLRuntimeException("Expected input of size:[" + numRows + ", " + numCols + "], but found [" + actualRows + ", " + actualCols + "].");
        }
        long cellCount = actualRows * actualCols;
        if (cellCount > maxNumElementsOfCuDNNTensor) {
            throw new DMLRuntimeException("CuDNN restriction: the size of input tensor cannot have greater than 2 giga-elements, but has " + cellCount + " (i.e. [" + actualRows + " X " + actualCols + "]). Hint: try reducing the mini-batch size.");
        }
        Pointer densePtr = LibMatrixCuDNN.getDensePointer(gCtx, image, instName);
        long expectedBytes = cellCount * (long)sizeOfDataType;
        long allocatedBytes = gCtx.getMemoryManager().getSizeAllocatedGPUPointer(densePtr);
        if (allocatedBytes != expectedBytes) {
            throw new DMLRuntimeException("Incorrect pointer: expected size:" + expectedBytes + ", but found " + allocatedBytes);
        }
        return densePtr;
    }

    /**
     * Throws if the given cuDNN status code is not zero (zero is CUDNN_STATUS_SUCCESS).
     *
     * @param status status code returned by a cuDNN call
     * @throws DMLRuntimeException carrying the textual form of the status if it is non-zero
     */
    protected static void checkStatus(int status) {
        boolean succeeded = status == 0;
        if (succeeded) {
            return;
        }
        String description = cudnnStatus.stringFor(status);
        throw new DMLRuntimeException("Error status returned by CuDNN:" + description);
    }
}

