Skip to content

Commit

Permalink
Add KernelAbstractions.jl back-end (#131)
Browse files Browse the repository at this point in the history
Co-authored-by: Tim Besard <[email protected]>
Co-authored-by: Valentin Churavy <[email protected]>
  • Loading branch information
3 people authored Mar 28, 2023
1 parent a498c30 commit 9d500f6
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 1 deletion.
48 changes: 47 additions & 1 deletion Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.8.5"
manifest_format = "2.0"
project_hash = "7c57dde9c295c98177ad608871c4ba51950909d4"
project_hash = "71a30cde2fde6edfe35f1a8de14ed5e60d0dafb6"

[[deps.Adapt]]
deps = ["LinearAlgebra", "Requires"]
Expand All @@ -17,6 +17,12 @@ version = "1.1.1"
[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[deps.Atomix]]
deps = ["UnsafeAtomics"]
git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be"
uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
version = "0.1.0"

[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

Expand Down Expand Up @@ -75,6 +81,12 @@ git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.4.1"

[[deps.KernelAbstractions]]
deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "SnoopPrecompile", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "350a880e80004f4d5d82a17f737d8fcdc56c3462"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
version = "0.9.1"

[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "f044a2796a9e18e0531b9b3072b0019a61f264bc"
Expand Down Expand Up @@ -120,6 +132,12 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.10"

[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
Expand Down Expand Up @@ -195,13 +213,30 @@ version = "0.7.0"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[deps.SnoopPrecompile]]
deps = ["Preferences"]
git-tree-sha1 = "e760a70afdcd461cf01a575947738d359234665c"
uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c"
version = "1.0.3"

[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[deps.SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
git-tree-sha1 = "2d7d9e1ddadc8407ffd460e24218e37ef52dd9a3"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.5.16"

[[deps.StaticArraysCore]]
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
version = "1.4.0"

[[deps.Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down Expand Up @@ -229,6 +264,17 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
version = "0.2.1"

[[deps.UnsafeAtomicsLLVM]]
deps = ["LLVM", "UnsafeAtomics"]
git-tree-sha1 = "33af9d2031d0dc09e2be9a0d4beefec4466def8e"
uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
version = "0.1.0"

[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Metal_LLVM_Tools_jll = "0418c028-ff8c-56b8-a53e-0f9676ed36fc"
ObjectiveC = "e86c9b32-1129-44ac-8ea0-90d5bb39ded9"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
Adapt = "3"
Expand All @@ -26,4 +28,5 @@ LLVM = "4.15"
Metal_LLVM_Tools_jll = "~0.3"
ObjectiveC = "0.1"
Reexport = "1.0"
KernelAbstractions = "0.9.1"
julia = "1.8"
5 changes: 5 additions & 0 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,9 @@ include("gpuarrays.jl")
include("../lib/mps/MPS.jl")
export MPS

# KernelAbstractions
include("MetalKernels.jl")
import .MetalKernels: MetalBackend
export MetalBackend

end # module
191 changes: 191 additions & 0 deletions src/MetalKernels.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
module MetalKernels

import KernelAbstractions
import Metal
import StaticArrays
import GPUCompiler

struct MetalBackend <: KernelAbstractions.GPU
end

export MetalBackend

KernelAbstractions.allocate(::MetalBackend, ::Type{T}, dims::Tuple) where T = Metal.MtlArray{T}(undef, dims)
KernelAbstractions.zeros(::MetalBackend, ::Type{T}, dims::Tuple) where T = Metal.zeros(T, dims)
KernelAbstractions.ones(::MetalBackend, ::Type{T}, dims::Tuple) where T = Metal.ones(T, dims)

# Import through parent
import KernelAbstractions: StaticArrays, Adapt
import .StaticArrays: MArray

KernelAbstractions.get_backend(::Metal.MtlArray) = MetalBackend()
KernelAbstractions.synchronize(::MetalBackend) = Metal.synchronize()
KernelAbstractions.supports_float64(::MetalBackend) = false
KernelAbstractions.supports_atomics(::MetalBackend) = false

Adapt.adapt_storage(::MetalBackend, a::Array) = Adapt.adapt(Metal.MtlArray, a)
Adapt.adapt_storage(::MetalBackend, a::Metal.MtlArray) = a
Adapt.adapt_storage(::KernelAbstractions.CPU, a::Metal.MtlArray) = convert(Array, a)

function KernelAbstractions.copyto!(::MetalBackend, A::Metal.MtlArray{T}, B::Metal.MtlArray{T}) where T
if Metal.device(dest) == Metal.device(src)
GC.@preserve A B unsafe_copyto!(Metal.device(A), pointer(A), pointer(B), length(A); async=true)
return A
else
error("Copy between different devices not implemented")
end
end

function KernelAbstractions.copyto!(::MetalBackend, A::Array{T}, B::Metal.MtlArray{T}) where T
GC.@preserve A B unsafe_copyto!(Metal.device(B), pointer(A), pointer(B), length(A); async=true)
return A
end

function KernelAbstractions.copyto!(::MetalBackend, A::Metal.MtlArray{T}, B::Array{T}) where T
GC.@preserve A B unsafe_copyto!(Metal.device(A), pointer(A), pointer(B), length(A); async=true)
return A
end

import KernelAbstractions: Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config

###
# Kernel launch
###
function launch_config(kernel::Kernel{MetalBackend}, ndrange, workgroupsize)
if ndrange isa Integer
ndrange = (ndrange,)
end
if workgroupsize isa Integer
workgroupsize = (workgroupsize, )
end

# partition checked that the ndrange's agreed
if KernelAbstractions.ndrange(kernel) <: StaticSize
ndrange = nothing
end

iterspace, dynamic = if KernelAbstractions.workgroupsize(kernel) <: DynamicSize &&
workgroupsize === nothing
# use ndrange as preliminary workgroupsize for autotuning
partition(kernel, ndrange, ndrange)
else
partition(kernel, ndrange, workgroupsize)
end

return ndrange, workgroupsize, iterspace, dynamic
end

function threads_to_workgroupsize(threads, ndrange)
total = 1
return map(ndrange) do n
x = min(div(threads, total), n)
total *= x
return x
end
end

function (obj::Kernel{MetalBackend})(args...; ndrange=nothing, workgroupsize=nothing)
ndrange, workgroupsize, iterspace, dynamic = launch_config(obj, ndrange, workgroupsize)
# this might not be the final context, since we may tune the workgroupsize
ctx = mkcontext(obj, ndrange, iterspace)
kernel = Metal.@metal launch=false obj.f(ctx, args...)

is_dynamic =
KernelAbstractions.workgroupsize(obj) <: DynamicSize &&
isnothing(workgroupsize)
if is_dynamic
groupsize = kernel.pipeline.maxTotalThreadsPerThreadgroup
new_workgroupsize = threads_to_workgroupsize(groupsize, ndrange)
iterspace, dynamic = partition(obj, ndrange, new_workgroupsize)
ctx = mkcontext(obj, ndrange, iterspace)
end

nblocks = length(blocks(iterspace))
threads = length(workitems(iterspace))

if nblocks == 0
return nothing
end

# Launch kernel
kernel(ctx, args...; threads=threads, groups=nblocks)
return nothing
end

####################################################################################################

import KernelAbstractions: CompilerMetadata, DynamicCheck, LinearIndices
import KernelAbstractions: __index_Local_Linear, __index_Group_Linear, __index_Global_Linear, __index_Local_Cartesian, __index_Group_Cartesian, __index_Global_Cartesian, __validindex, __print
import KernelAbstractions: mkcontext, expand, __iterspace, __ndrange, __dynamic_checkbounds

function mkcontext(kernel::Kernel{MetalBackend}, _ndrange, iterspace)
CompilerMetadata{KernelAbstractions.ndrange(kernel), DynamicCheck}(_ndrange, iterspace)
end
function mkcontext(kernel::Kernel{MetalBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
end

@Metal.device_override @inline function __index_Local_Linear(ctx)
return Metal.thread_position_in_threadgroup_1d()
end

@Metal.device_override @inline function __index_Group_Linear(ctx)
return Metal.threadgroup_position_in_grid_1d()
end

@Metal.device_override @inline function __index_Global_Linear(ctx)
return Metal.thread_position_in_grid_1d()
end

@Metal.device_override @inline function __index_Local_Cartesian(ctx)
@inbounds workitems(__iterspace(ctx))[Metal.thread_position_in_threadgroup_1d()]
end

@Metal.device_override @inline function __index_Group_Cartesian(ctx)
@inbounds blocks(__iterspace(ctx))[Metal.threadgroup_position_in_grid_1d()]
end

@Metal.device_override @inline function __index_Global_Cartesian(ctx)
return @inbounds expand(__iterspace(ctx), Metal.threadgroup_position_in_grid_1d(), Metal.thread_position_in_threadgroup_1d())
end

@Metal.device_override @inline function __validindex(ctx)
if __dynamic_checkbounds(ctx)
I = @inbounds expand(__iterspace(ctx), Metal.threadgroup_position_in_grid_1d(), Metal.thread_position_in_threadgroup_1d())
return I in __ndrange(ctx)
else
return true
end
end

import KernelAbstractions: groupsize, __groupsize, __workitems_iterspace
import KernelAbstractions: SharedMemory, Scratchpad, __synchronize, __size

###
# GPU implementation of shared memory
###
@Metal.device_override @inline function SharedMemory(::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
ptr = Metal.emit_threadgroup_memory(T, Val(prod(Dims)))
Metal.MtlDeviceArray(Dims, ptr)
end

###
# GPU implementation of scratch memory
# - private memory for each workitem
###

@Metal.device_override @inline function Scratchpad(ctx, ::Type{T}, ::Val{Dims}) where {T, Dims}
StaticArrays.MArray{__size(Dims), T}(undef)
end

@Metal.device_override @inline function __synchronize()
Metal.threadgroup_barrier(Metal.MemoryFlagThreadGroup)
end

@Metal.device_override @inline function __print(args...)
# TODO
end

KernelAbstractions.argconvert(::Kernel{MetalBackend}, arg) = Metal.mtlconvert(arg)

end
5 changes: 5 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Metal_LLVM_Tools_jll = "0418c028-ff8c-56b8-a53e-0f9676ed36fc"
ObjectiveC = "e86c9b32-1129-44ac-8ea0-90d5bb39ded9"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cmt_jll = "65323cdd-17ec-5719-9643-72016a7f97e3"
10 changes: 10 additions & 0 deletions test/kernelabstractions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import KernelAbstractions
using Metal.MetalKernels

include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))

Testsuite.testsuite(()->MetalBackend(), "Metal", Metal, MtlArray, Metal.MtlDeviceArray; skip_tests=Set([
"Convert", # depends on https://github.com/JuliaGPU/Metal.jl/issues/69
"SpecialFunctions", # no equivalent Metal intrinsics for gamma, erf, etc
"sparse", # not supported yet
]))

0 comments on commit 9d500f6

Please sign in to comment.