/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.gpu.context;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas2;
import jcuda.jcudnn.JCudnn;
import jcuda.jcusparse.JCusparse;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaDeviceProp;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.lineage.LineageGPUCacheEviction;
import org.apache.sysds.utils.GPUStatistics;

/**
 * Static holder for the {@link GPUContext} objects created for the CUDA devices
 * selected through {@link #AVAILABLE_GPUS}.
 *
 * <p>All state is kept in static fields. Every method that touches that state either
 * is {@code synchronized} on this class or first calls the synchronized
 * {@link #initializeGPU()}, so readers never observe a half-initialized pool.
 */
public class GPUContextPool {
    protected static final Log LOG = LogFactory.getLog(GPUContextPool.class.getName());

    /**
     * Device selection string handed to {@link #parseListString(String, int)}.
     * This class never assigns it; it is expected to be set by the caller before
     * {@link #initializeGPU()} runs. A {@code null} or unparsable value makes
     * initialization fall back to all devices.
     */
    public static String AVAILABLE_GPUS;

    /** Smallest {@code getAvailableMemory()} seen across the pool at initialization; -1 before that. */
    private static long INITIAL_GPU_MEMORY_BUDGET = -1L;

    /** Set as soon as {@link #initializeGPU()} starts doing work (see the comment there). */
    static boolean initialized = false;

    /** Device count reported by {@code cuDeviceGetCount}; -1 before initialization. */
    static int deviceCount = -1;

    /** Indexed by device number; entries for devices that were not selected stay {@code null}. */
    static cudaDeviceProp[] deviceProperties;

    /** One context per selected device, in selection order. */
    static List<GPUContext> pool = new LinkedList<GPUContext>();

    /** Whether {@link #reserveAllGPUContexts()} has been called without a matching free. */
    static boolean reserved = false;

    /**
     * Initializes the CUDA libraries, queries the device count, creates a
     * {@link GPUContext} for every selected device and records the initial memory
     * budget. Subsequent calls return immediately.
     *
     * <p>If {@link #AVAILABLE_GPUS} cannot be parsed, a warning is logged and all
     * devices are used.
     */
    public static synchronized void initializeGPU() {
        if (initialized) {
            return;
        }
        // The flag is raised before the work is done on purpose: this method calls
        // initialGPUMemBudget() further down, which would otherwise re-enter
        // initializeGPU() on the same thread and recurse.
        initialized = true;
        GPUContext.LOG.info("Initializing CUDA");
        long start = System.nanoTime();

        JCuda.setExceptionsEnabled(true);
        JCudnn.setExceptionsEnabled(true);
        JCublas2.setExceptionsEnabled(true);
        JCusparse.setExceptionsEnabled(true);
        JCudaDriver.setExceptionsEnabled(true);
        JCudaDriver.cuInit(0);

        int[] deviceCountArray = new int[]{0};
        JCudaDriver.cuDeviceGetCount(deviceCountArray);
        deviceCount = deviceCountArray[0];
        deviceProperties = new cudaDeviceProp[deviceCount];

        // Only the parsing step is guarded. Keeping device registration outside the
        // try block guarantees that a failure during registration can never trigger
        // the fallback and register the same device twice.
        List<Integer> selectedDevices;
        try {
            selectedDevices = GPUContextPool.parseListString(AVAILABLE_GPUS, deviceCount);
        }
        catch (IllegalArgumentException e) {
            LOG.warn("Invalid setting for setting systemds.gpu.availableGPUs, defaulting to use ALL GPUs");
            selectedDevices = new ArrayList<Integer>();
            for (int i = 0; i < deviceCount; ++i) {
                selectedDevices.add(i);
            }
        }
        for (int device : selectedDevices) {
            GPUContextPool.registerDevice(device);
        }

        long minAvailableMemory = Long.MAX_VALUE;
        for (GPUContext gCtx : pool) {
            gCtx.initializeThread();
            minAvailableMemory = Math.min(minAvailableMemory, gCtx.getAvailableMemory());
        }
        INITIAL_GPU_MEMORY_BUDGET = minAvailableMemory;

        GPUContext.LOG.info("Total number of GPUs on the machine: " + deviceCount);
        GPUContext.LOG.info("GPUs being used: " + AVAILABLE_GPUS);
        GPUContext.LOG.info("Initial GPU memory: " + GPUContextPool.initialGPUMemBudget());
        GPUStatistics.cudaInitTime = System.nanoTime() - start;
        LineageGPUCacheEviction.setGPULineageCacheLimit();
    }

    /**
     * Queries and stores the properties of one device and adds a context for it to the pool.
     *
     * @param device device number, must be a valid index into {@link #deviceProperties}
     */
    private static void registerDevice(int device) {
        cudaDeviceProp properties = new cudaDeviceProp();
        JCuda.cudaGetDeviceProperties(properties, device);
        deviceProperties[device] = properties;
        pool.add(new GPUContext(device));
    }

    /**
     * Parses a device selection string into a list of device numbers.
     *
     * <p>Accepted forms (surrounding whitespace is ignored):
     * <ul>
     *   <li>{@code "-1"} - every number in {@code [0, max)}</li>
     *   <li>{@code "a-b"} - the inclusive range {@code a..b}</li>
     *   <li>{@code "a,b,c"} - the listed numbers, in the given order</li>
     *   <li>{@code "a"} - a single number</li>
     * </ul>
     *
     * @param str selection string
     * @param max exclusive upper bound for every parsed number
     * @return the parsed numbers, each within {@code [0, max)}
     * @throws IllegalArgumentException if {@code str} is {@code null}, malformed
     *         (this includes {@link NumberFormatException}), describes a reversed
     *         range, or contains a number outside {@code [0, max)}
     */
    public static ArrayList<Integer> parseListString(String str, int max) {
        if (str == null) {
            // Reported as IllegalArgumentException (not NullPointerException) so that
            // callers handling bad settings also handle a missing setting.
            throw new IllegalArgumentException("Invalid string to parse to a list of numbers : " + str);
        }
        ArrayList<Integer> result = new ArrayList<Integer>();
        str = str.trim();
        if (str.equalsIgnoreCase("-1")) {
            for (int i = 0; i < max; ++i) {
                result.add(i);
            }
        } else if (str.contains("-")) {
            String[] numbersStr = str.split("-");
            if (numbersStr.length != 2) {
                throw new IllegalArgumentException("Invalid string to parse to a list of numbers : " + str);
            }
            int begin = Integer.parseInt(numbersStr[0].trim());
            int end = Integer.parseInt(numbersStr[1].trim());
            // Validate before expanding: an unchecked upper bound could allocate an
            // enormous list (or never terminate for Integer.MAX_VALUE), and a reversed
            // range would silently select no device at all.
            if (begin > end || begin < 0 || end >= max) {
                throw new IllegalArgumentException("Invalid string (" + str + ") describes the range " + begin
                        + "-" + end + " which is empty or exceeds the maximum range : " + max);
            }
            for (int i = begin; i <= end; ++i) {
                result.add(i);
            }
        } else if (str.contains(",")) {
            for (String number : str.split(",")) {
                result.add(Integer.parseInt(number.trim()));
            }
        } else {
            result.add(Integer.parseInt(str));
        }
        for (int n : result) {
            if (n < 0 || n >= max) {
                throw new IllegalArgumentException("Invalid string (" + str + ") parsed to a list of numbers ("
                        + result + ") which exceeds the maximum range : " + max);
            }
        }
        return result;
    }

    /**
     * Marks the pool as reserved and returns it, initializing the GPUs first if needed.
     *
     * @return the pool list itself (not a copy)
     * @throws DMLRuntimeException if the pool is already reserved
     */
    public static synchronized List<GPUContext> reserveAllGPUContexts() {
        if (reserved) {
            throw new DMLRuntimeException("Trying to re-reserve GPUs");
        }
        GPUContextPool.initializeGPU();
        reserved = true;
        LOG.trace("GPU : Reserved all GPUs");
        return pool;
    }

    /**
     * Returns the current size of the pool. Does not trigger initialization.
     *
     * @return number of contexts in the pool
     */
    public static synchronized int getAvailableCount() {
        return pool.size();
    }

    /**
     * Returns the stored properties of a device, initializing the GPUs first if needed.
     *
     * @param device device number
     * @return the stored properties, or {@code null} if the device was not selected
     */
    static cudaDeviceProp getGPUProperties(int device) {
        // Always go through the synchronized initializer instead of peeking at the
        // 'initialized' flag: the flag is raised at the start of initialization, so an
        // unsynchronized check could let another thread read deviceProperties too early.
        GPUContextPool.initializeGPU();
        return deviceProperties[device];
    }

    /**
     * Returns the number of devices reported by the driver, initializing the GPUs first if needed.
     *
     * @return the device count
     */
    public static int getDeviceCount() {
        // See getGPUProperties for why the flag is not checked here.
        GPUContextPool.initializeGPU();
        return deviceCount;
    }

    /**
     * Clears the reserved flag. The contexts stay in the pool.
     *
     * @throws DMLRuntimeException if the pool is not currently reserved
     */
    public static synchronized void freeAllGPUContexts() {
        if (!reserved) {
            throw new DMLRuntimeException("Trying to free unreserved GPUs");
        }
        reserved = false;
        LOG.trace("GPU : Unreserved all GPUs");
    }

    /**
     * Returns the memory budget recorded during initialization, initializing the GPUs first if needed.
     *
     * @return the minimum available memory observed across the pooled contexts at initialization
     * @throws RuntimeException wrapping any {@link DMLRuntimeException} raised during initialization
     */
    public static synchronized long initialGPUMemBudget() throws RuntimeException {
        try {
            if (!initialized) {
                GPUContextPool.initializeGPU();
            }
            return INITIAL_GPU_MEMORY_BUDGET;
        }
        catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
    }
}

