diff --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp index 6ea7f57ad3c22..fb9f14ddbc0d1 100644 --- a/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp @@ -110,7 +110,11 @@ SerializeToCubinPass::serializeISA(const std::string &isa) { CUdevice device; RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0)); CUcontext context; - RETURN_ON_CUDA_ERROR(cuCtxCreate(&context, 0, device)); + // Use the primary context. + RETURN_ON_CUDA_ERROR(cuDevicePrimaryCtxRetain(&context, device)); + // Push the primary context so that the next CUDA operations + // actually use it. + RETURN_ON_CUDA_ERROR(cuCtxPushCurrent(context)); CUlinkState linkState; CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER, @@ -146,7 +150,10 @@ SerializeToCubinPass::serializeISA(const std::string &isa) { // This will also destroy the cubin data. RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState)); - RETURN_ON_CUDA_ERROR(cuCtxDestroy(context)); + // Pop and release the primary context. + CUcontext poppedContext; + RETURN_ON_CUDA_ERROR(cuCtxPopCurrent(&poppedContext)); + RETURN_ON_CUDA_ERROR(cuDevicePrimaryCtxRelease(device)); return result; }