This repository was archived by the owner on Nov 4, 2024. It is now read-only.
Draft pull request
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.1.1"
version = "1.2.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -17,6 +17,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -33,6 +34,7 @@ MLDataDevicesFillArraysExt = "FillArrays"
MLDataDevicesGPUArraysExt = "GPUArrays"
MLDataDevicesMLUtilsExt = "MLUtils"
MLDataDevicesMetalExt = ["GPUArrays", "Metal"]
MLDataDevicesOpenCLExt = ["GPUArrays", "OpenCL"]
MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools"
MLDataDevicesReverseDiffExt = "ReverseDiff"
MLDataDevicesSparseArraysExt = "SparseArrays"
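Note on the mechanism: the new `OpenCL` entry and the `MLDataDevicesOpenCLExt = ["GPUArrays", "OpenCL"]` mapping under `[extensions]` mean the extension module only loads once both trigger packages are present (assuming the `OpenCL`/`GPUArrays` entries sit under `[weakdeps]`; that section header falls outside the visible hunks). A minimal sketch of activating it (an illustrative session, not part of the PR):

```julia
using MLDataDevices                 # core package; no GPU backend active yet
MLDataDevices.loaded(OpenCLDevice)  # false: generic fallback, extension not loaded

using GPUArrays, OpenCL             # both triggers present -> MLDataDevicesOpenCLExt loads

MLDataDevices.loaded(OpenCLDevice)  # true, via the method the extension defines
```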
1 change: 1 addition & 0 deletions README.md
@@ -21,6 +21,7 @@ Currently we provide support for the following backends:
 2. `AMDGPU.jl` for AMD ROCM GPUs.
 3. `Metal.jl` for Apple Metal GPUs. **(Experimental)**
 4. `oneAPI.jl` for Intel GPUs. **(Experimental)**
+5. `OpenCL.jl` for OpenCL devices. **(Extremely Experimental)**

 ## Updating to v1.0
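Since `GPU_DEVICES` in `src/public.jl` (below) appends `OpenCLDevice` last and its order is significant, `gpu_device()` should only fall back to OpenCL when no higher-priority backend is functional. A hedged sketch (hypothetical session; assumes an OpenCL runtime and no other GPU trigger packages loaded):

```julia
using MLDataDevices, GPUArrays, OpenCL

# With no CUDA/AMDGPU/Metal/oneAPI loaded, this is expected to resolve to OpenCLDevice().
dev = gpu_device()
```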
36 changes: 36 additions & 0 deletions ext/MLDataDevicesOpenCLExt.jl
@@ -0,0 +1,36 @@
+module MLDataDevicesOpenCLExt
+
+using Adapt: Adapt
+using MLDataDevices: MLDataDevices, Internal, OpenCLDevice, reset_gpu_device!
+using GPUArrays: GPUArrays
+using OpenCL: OpenCL, CLArray
+
+__init__() = reset_gpu_device!()
+
+MLDataDevices.loaded(::Union{OpenCLDevice, Type{<:OpenCLDevice}}) = true
+# TODO: Check if OpenCL can provide a `functional` function.
+MLDataDevices.functional(::Union{OpenCLDevice, Type{<:OpenCLDevice}}) = true
+
+# Default RNG
+MLDataDevices.default_device_rng(::OpenCLDevice) = GPUArrays.default_rng(CLArray)
+
+# Query Device from Array
+Internal.get_device(::CLArray) = OpenCLDevice()
+
+Internal.get_device_type(::CLArray) = OpenCLDevice
+
+# unsafe_free!
+function Internal.unsafe_free_internal!(::Type{OpenCLDevice}, ::AbstractArray)
+    # TODO: Implement this
+    @warn "Support for `unsafe_free!` for OpenCL is not implemented yet. This is a no-op." maxlog=1
+    return
+end
+
+# Device Transfer
+Adapt.adapt_storage(::OpenCLDevice, x::AbstractArray) = CLArray(x)
+
+# TODO: Eventually we want to do robust device management, since it is possible users
+# change the device after creating the OpenCLDevice and that might cause unwanted
+# behavior.
+
+end
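Combined with the dispatch changes in `src/public.jl` below, this gives OpenCL the standard device-transfer workflow. A sketch of the intended usage (assumes a working OpenCL runtime; note that `functional` unconditionally returns `true` for now, per the TODO above):

```julia
using MLDataDevices, GPUArrays, OpenCL

dev = OpenCLDevice()
x = rand(Float32, 4, 4)

x_cl = dev(x)               # Adapt.adapt_storage(::OpenCLDevice, x) -> CLArray(x)
get_device(x_cl)            # OpenCLDevice()
get_device_type(x_cl)       # OpenCLDevice

x_back = cpu_device()(x_cl) # adapts back to a plain Array
```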
4 changes: 3 additions & 1 deletion src/MLDataDevices.jl
@@ -6,6 +6,7 @@ using Preferences: @delete_preferences!, @load_preference, @set_preferences!
 using Random: AbstractRNG, Random

 abstract type AbstractDevice <: Function end
+abstract type AbstractCPUDevice <: AbstractDevice end
 abstract type AbstractGPUDevice <: AbstractDevice end

 include("public.jl")
@@ -16,7 +17,8 @@ export gpu_backend!, supported_gpu_backends, reset_gpu_device!
 export default_device_rng
 export gpu_device, cpu_device

-export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice
+export CPUDevice
+export CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice
 export get_device, get_device_type

 export DeviceIterator
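The resulting type relationships follow directly from these definitions and the structs in `src/public.jl`:

```julia
CPUDevice <: AbstractCPUDevice       # true: CPUDevice gains its own abstract supertype
OpenCLDevice <: AbstractGPUDevice    # true: the new backend sits beside Metal/oneAPI
AbstractCPUDevice <: AbstractDevice  # true
AbstractGPUDevice <: AbstractDevice  # true
```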
14 changes: 8 additions & 6 deletions src/internal.jl
@@ -5,10 +5,10 @@ using Preferences: load_preference
 using Random: AbstractRNG

 using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
-                       MetalDevice, oneAPIDevice, supported_gpu_backends, GPU_DEVICES,
-                       loaded, functional
+                       MetalDevice, oneAPIDevice, OpenCLDevice, supported_gpu_backends,
+                       GPU_DEVICES, loaded, functional

-for dev in (CPUDevice, MetalDevice, oneAPIDevice)
+for dev in (CPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice)
     msg = "`device_id` is not applicable for `$dev`."
     @eval begin
         with_device(::Type{$dev}, ::Nothing) = $dev()
@@ -19,7 +19,7 @@ end
     end
 end

-for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
+for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :OpenCL)
     tpkg = name === :CPU ? "" : string(name)
     ldev = Symbol(name, :Device)
     @eval begin
@@ -28,7 +28,8 @@ for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
     end
 end

-for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, MetalDevice, oneAPIDevice)
+for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing},
+          MetalDevice, oneAPIDevice, OpenCLDevice)
     @eval get_device_id(::$(T)) = nothing
 end

@@ -93,7 +94,8 @@ function get_gpu_device(; force_gpu_usage::Bool)
         a. `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support.
         b. `AMDGPU.jl` for AMD GPU ROCM Support.
         c. `Metal.jl` for Apple Metal GPU Support. (Experimental)
-        d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1
+        d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)
+        e. `OpenCL.jl` for OpenCL Support. (Extremely Experimental)""" maxlog=1
         return CPUDevice
     end

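Because `OpenCLDevice` joins each generated-method loop, it picks up the same scaffolding as Metal and oneAPI. Roughly what the loops expand to for the new backend (a sketch of the generated methods, not verbatim source; handling of non-`nothing` device ids falls outside the visible hunk):

```julia
# From the first loop: OpenCL takes no device id.
with_device(::Type{OpenCLDevice}, ::Nothing) = OpenCLDevice()

# From the `get_device_id` loop: single-instance device types report no id.
get_device_id(::OpenCLDevice) = nothing
```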
11 changes: 6 additions & 5 deletions src/public.jl
@@ -1,4 +1,4 @@
-struct CPUDevice <: AbstractDevice end
+struct CPUDevice <: AbstractCPUDevice end
 @kwdef struct CUDADevice{D} <: AbstractGPUDevice
     device::D = nothing
 end
@@ -7,6 +7,7 @@ end
 end
 struct MetalDevice <: AbstractGPUDevice end
 struct oneAPIDevice <: AbstractGPUDevice end
+struct OpenCLDevice <: AbstractGPUDevice end

 """
     functional(x::AbstractDevice) -> Bool
@@ -36,7 +37,7 @@ loaded(x) = false
 loaded(::Union{CPUDevice, Type{<:CPUDevice}}) = true

 # Order is important here
-const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
+const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice)

 const GPU_DEVICE = Ref{Union{Nothing, AbstractDevice}}(nothing)

@@ -292,7 +293,7 @@ end
 # Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability
 # For all other types we rely on fmap which means we lose type stability.
 # For Lux, typically models only has these 3 datastructures so we should be mostly fine.
-for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
+for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :OpenCL)
     ldev = Symbol(dev, :Device)
     @eval begin
         function (D::$(ldev))(x::AbstractArray{T}) where {T}
@@ -318,7 +319,7 @@ end
 Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x)
 Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng

-for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice)
+for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, OpenCLDevice)
     @eval begin
         function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG)
             return default_device_rng(to)
@@ -330,6 +331,6 @@ end
 Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x
 # Prevent Ambiguity
 for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
-          CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
+          CUDADevice{Nothing}, MetalDevice, oneAPIDevice, OpenCLDevice)
     @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
 end
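These `Adapt.adapt_storage` methods give OpenCL the same RNG and range semantics as the other GPU backends. A sketch of the resulting behavior (assumes the extension is loaded):

```julia
using Adapt, Random

dev = OpenCLDevice()

# TaskLocalRNG is replaced by the backend's default RNG:
Adapt.adapt_storage(dev, Random.TaskLocalRNG())  # default_device_rng(dev)

# Ranges are materialized before transfer (the "Prevent Ambiguity" methods above):
Adapt.adapt_storage(dev, 1:10)                   # CLArray(collect(1:10))
```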