Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial support for MPSNDArray #499

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ include("images.jl")
include("matrix.jl")
include("vector.jl")
include("matrixrandom.jl")
include("ndarray.jl")
include("decomposition.jl")
include("copy.jl")

Expand Down
2 changes: 1 addition & 1 deletion lib/mps/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ end
@objcwrapper MPSKernel <: NSObject

@objcproperties MPSKernel begin
@autoproperty options::MPSKernelOptions setter=setOptions
@autoproperty device::id{MTLDevice}
@autoproperty label::id{NSString} setter=setLabel
@autoproperty options::MPSKernelOptions setter=setOptions
end

@autoreleasepool function Base.copy(kernel::K) where {K <: MPSKernel}
Expand Down
315 changes: 315 additions & 0 deletions lib/mps/ndarray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
#
# ndarray descriptor
#

export MPSNDArrayDescriptor

@objcwrapper immutable=false MPSNDArrayDescriptor <: NSObject

@objcproperties MPSNDArrayDescriptor begin
@autoproperty dataType::MPSDataType setter=setDataType
@autoproperty numberOfDimensions::NSUInteger setter=setNumberOfDimensions

# Both are officially available starting macOS 15, but they work in macOS 13/14
@autoproperty preferPackedRows::Bool setter=setPreferPackedRows # macOS 15+
@autoproperty getShape::id{NSArray} # macOS 15+
end

function MPSNDArrayDescriptor(dataType::DataType, dimensionCount, dimensionSizes::Ptr)
desc = @objc [MPSNDArrayDescriptor descriptorWithDataType:dataType::MPSDataType
dimensionCount:dimensionCount::NSUInteger
dimensionSizes:dimensionSizes::Ptr{NSUInteger}]::id{MPSNDArrayDescriptor}
obj = MPSNDArrayDescriptor(desc)
return obj
end

function MPSNDArrayDescriptor(dataType::DataType, shape::DenseVector{T}) where {T<:Union{Int,UInt}}
revshape = collect(reverse(shape))
obj = GC.@preserve revshape begin
shapeptr = pointer(revshape)
MPSNDArrayDescriptor(dataType, length(revshape), shapeptr)
end
return obj
end
MPSNDArrayDescriptor(dataType::DataType, shape::Tuple) = MPSNDArrayDescriptor(dataType, collect(shape))

MPSNDArrayDescriptor(dataType::DataType, dimensionSizes...) = @inline MPSNDArrayDescriptor(dataType, collect(dimensionSizes))

lengthOfDimension(desc::MPSNDArrayDescriptor, dim) = @objc [desc::id{MPSNDArrayDescriptor} lengthOfDimension:dim::UInt]::UInt

function transposeDimensionwithDimension(desc::MPSNDArrayDescriptor, dim1, dim2)
@objc [desc::id{MPSNDArrayDescriptor} transposeDimension:dim1::UInt
withDimension:dim2::UInt]::Cvoid
end

#
# ndarray object
#

export MPSNDArray

@objcwrapper immutable=false MPSNDArray <: NSObject

@static if Metal.macos_version() >= v"15"
@objcproperties MPSNDArray begin
@autoproperty dataType::MPSDataType
@autoproperty dataTypeSize::Csize_t
@autoproperty device::id{MTLDevice}
@autoproperty label::id{NSString} setter=setLabel
@autoproperty numberOfDimensions::NSUInteger
@autoproperty parent::id{MPSNDArray}

#Instance methods that act like properties
@autoproperty descriptor::id{MPSNDArrayDescriptor}
@autoproperty resourceSize::NSUInteger
@autoproperty userBuffer::id{MTLBuffer}
end
else
@objcproperties MPSNDArray begin
@autoproperty dataType::MPSDataType
@autoproperty dataTypeSize::Csize_t
@autoproperty device::id{MTLDevice}
@autoproperty label::id{NSString} setter=setLabel
@autoproperty numberOfDimensions::NSUInteger
@autoproperty parent::id{MPSNDArray}
end
end

@objcwrapper immutable=false MPSTemporaryNDArray <: MPSNDArray

@objcproperties MPSTemporaryNDArray begin
@autoproperty readCount::NSUInteger setter=setReadCount
end

function MPSTemporaryNDArray(cmdbuf::MTLCommandBuffer, descriptor::MPSNDArrayDescriptor)
@objc [MPSNDTemporaryNDArray temporaryNDArrayWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
descriptor:descriptor::id{MPSNDArrayDescriptor}]::id{MPSTemporaryNDArray}
return obj
end

"""
MPSNDArray([device::MTLDevice], arr::MtlArray)

Metal ndarray representation used in Performance Shaders.

May not contain more than 16 dimensions.
"""
function MPSNDArray(device::MTLDevice, desc::MPSNDArrayDescriptor)
arrayaddr = @objc [MPSNDArray alloc]::id{MPSNDArray}
obj = MPSNDArray(arrayaddr)
finalizer(release, obj)
@objc [obj::MPSNDArray initWithDevice:device::id{MTLDevice}
descriptor:desc::id{MPSNDArrayDescriptor}]::id{MPSNDArray}
return obj
end

function MPSNDArray(device::MTLDevice, scalar)
arrayaddr = @objc [MPSNDArray alloc]::id{MPSNDArray}
obj = MPSNDArray(arrayaddr)
finalizer(release, obj)
@objc [obj::MPSNDArray initWithDevice:device::id{MTLDevice}
scalar:scalar::Float64]::id{MPSNDArray}
return obj
end

@static if Metal.macos_version() >= v"15"
function MPSNDArray(buffer::MTLBuffer, offset::UInt, descriptor::MPSNDArrayDescriptor)
arrayaddr = @objc [MPSNDArray alloc]::id{MPSNDArray}
obj = MPSNDArray(arrayaddr)
finalizer(release, obj)
@objc [obj::MPSNDArray initWithBuffer:buffer::id{MTLBuffer}
offset:offset::NSUInteger
descriptor:descriptor::id{MPSNDArrayDescriptor}]::id{MPSNDArray}
return obj
end
else
function MPSNDArray(buffer::MTLBuffer, offset::UInt, descriptor::MPSNDArrayDescriptor)
@assert false "Creating an MPSNDArray that shares data with user-provided MTLBuffer is only supported in macOS v15+"
end
end

function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
arrsize = size(arr)
@assert arrsize[end]*sizeof(T) % 16 == 0 "Final dimension of arr must have a byte size divisible by 16"
desc = MPSNDArrayDescriptor(T, arrsize)
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
end

function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode)
ndims = Int(ndarr.numberOfDimensions)
arrsize = [lengthOfDimension(ndarr,i) for i in 0:ndims-1]
T = convert(DataType, ndarr.dataType)
arr = MtlArray{T,ndims,storage}(undef, reverse(arrsize)...)
dev = device(arr)

cmdBuf = MTLCommandBuffer(global_queue(dev))

exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, 0, collect(sizeof(T) .* reverse(strides(arr))))

commit!(cmdBuf)
wait_completed(cmdBuf)

return arr
end

# rowStrides in Bytes
exportDataWithCommandBuffer(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, toBuffer, destinationDataType, offset, rowStrides) =
GC.@preserve rowStrides @objc [ndarr::MPSNDArray exportDataWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
toBuffer:toBuffer::id{MTLBuffer}
destinationDataType:destinationDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:pointer(rowStrides)::Ptr{NSInteger}]::Nothing

# rowStrides in Bytes
importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBuffer, sourceDataType, offset, rowStrides) =
GC.@preserve rowStrides @objc [ndarr::MPSNDArray importDataWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
fromBuffer:fromBuffer::id{MTLBuffer}
sourceDataType:sourceDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:pointer(rowStrides)::Ptr{NSInteger}]::Nothing

# TODO
# exportDataWithCommandBuffer(toImages, offset)
# importDataWithCommandBuffer(fromImages, offset)

# 0-indexed
lengthOfDimension(ndarr::MPSNDArray, dimensionIndex) =
@objc [ndarr::MPSNDArray lengthOfDimension:dimensionIndex::NSUInteger]::UInt

# TODO
# readBytes(strideBytes)
# writeBytes(strideBytes)

synchronizeOnCommandBuffer(ndarr::MPSNDArray, q::MTLCommandBuffer) =
@objc [ndarr::MPSNDArray synchronizeOnCommandBuffer:q::id{MTLCommandBuffer}]::Nothing


export MPSNDArrayMultiaryBase

@objcwrapper immutable=false MPSNDArrayMultiaryBase <: MPSKernel

export MPSNDArrayMultiaryKernel

@objcwrapper immutable=false MPSNDArrayMultiaryKernel <: MPSNDArrayMultiaryBase

function MPSNDArrayMultiaryKernel(device, sourceCount)
kernel = @objc [MPSNDArrayMultiaryKernel alloc]::id{MPSNDArrayMultiaryKernel}
obj = MPSNDArrayMultiaryKernel(kernel)
finalizer(release, obj)
@objc [obj::id{MPSNDArrayMultiaryKernel} initWithDevice:device::id{MTLDevice}
sourceCount:sourceCount::NSUInteger]::id{MPSNDArrayMultiaryKernel}
return obj
end

function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArrays) where {K<:MPSNDArrayMultiaryKernel}
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
sourceArrays:sourceArrays::id{NSArray}]::id{MPSNDArray}
end
function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArrays, destinationArray) where {K<:MPSNDArrayMultiaryKernel}
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
sourceArrays:sourceArrays::id{NSArray}
destinationArray:destinationArray::id{MPSNDArray}]::Nothing
end
# TODO: MPSState is not implemented yet, so these don't work
# function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArrays, resultState, destinationArray) where {K<:MPSNDArrayMultiaryKernel}
# @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
# sourceArrays:sourceArrays::id{NSArray}
# resultState:resultState::id{MPSState}
# destinationArray:destinationArray::id{MPSNDArray}]::Nothing
# end
# function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArrays, resultState, outputStateIsTemporary::Bool) where {K<:MPSNDArrayMultiaryKernel}
# @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
# sourceArrays:sourceArrays::id{NSArray}
# resultState:resultState::id{MPSState}
# outputStateIsTemporary:outputStateIsTemporary::Bool]::MPSNDArray
# end
christiangnrd marked this conversation as resolved.
Show resolved Hide resolved

export MPSNDArrayUnaryKernel

@objcwrapper immutable=false MPSNDArrayUnaryKernel <: MPSNDArrayMultiaryBase

function MPSNDArrayUnaryKernel(device)
kernel = @objc [MPSNDArrayUnaryKernel alloc]::id{MPSNDArrayUnaryKernel}
obj = MPSNDArrayUnaryKernel(kernel)
finalizer(release, obj)
@objc [obj::id{MPSNDArrayUnaryKernel} initWithDevice:device::id{MTLDevice}]::id{MPSNDArrayUnaryKernel}
return obj
end

function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArray) where {K<:MPSNDArrayUnaryKernel}
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
sourceArray:sourceArray::id{MPSNDArray}]::id{MPSNDArray}
end
function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArray, destinationArray) where {K<:MPSNDArrayUnaryKernel}
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
sourceArray:sourceArray::id{MPSNDArray}
destinationArray:destinationArray::id{MPSNDArray}]::Nothing
end
# TODO: MPSState is not implemented yet, so these don't work
# function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArray, resultState, destinationArray) where {K<:MPSNDArrayUnaryKernel}
# @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
# sourceArray:sourceArray::id{MPSNDArray}
# resultState:resultState::id{MPSState}
# destinationArray:destinationArray::id{MPSNDArray}]::Nothing
# end
# function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArray, resultState, outputStateIsTemporary::Bool) where {K<:MPSNDArrayUnaryKernel}
# @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
# sourceArray:sourceArrays::id{MPSNDArray}
# resultState:resultState::id{MPSState}
# outputStateIsTemporary:outputStateIsTemporary::Bool]::MPSNDArray
# end

export MPSNDArrayBinaryKernel

@objcwrapper immutable=false MPSNDArrayBinaryKernel <: MPSNDArrayMultiaryBase

function MPSNDArrayBinaryKernel(device)
kernel = @objc [MPSNDArrayBinaryKernel alloc]::id{MPSNDArrayBinaryKernel}
obj = MPSNDArrayBinaryKernel(kernel)
finalizer(release, obj)
@objc [obj::id{MPSNDArrayBinaryKernel} initWithDevice:device::id{MTLDevice}]::id{MPSNDArrayBinaryKernel}
return obj
end

function encode!(cmdbuf::MTLCommandBuffer, kernel::K, primarySourceArray, secondarySourceArray) where {K<:MPSNDArrayBinaryKernel}
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
secondarySourceArray:secondarySourceArray::id{MPSNDArray}
primarySourceArray:primarySourceArray::id{MPSNDArray}]::id{MPSNDArray}
end
function encode!(cmdbuf::MTLCommandBuffer, kernel::K, primarySourceArray, secondarySourceArray, destinationArray) where {K<:MPSNDArrayBinaryKernel}
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
primarySourceArray:primarySourceArray::id{MPSNDArray}
secondarySourceArray:secondarySourceArray::id{MPSNDArray}
destinationArray:destinationArray::id{MPSNDArray}]::Nothing
end
# TODO: MPSState is not implemented yet, so these don't work
# function encode!(cmdbuf::MTLCommandBuffer, kernel::K, primarySourceArray, secondarySourceArray, resultState, destinationArray) where {K<:MPSNDArrayBinaryKernel}
# @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
# primarySourceArray:primarySourceArray::id{MPSNDArray}
# secondarySourceArray:secondarySourceArray::id{MPSNDArray}
# resultState:resultState::id{MPSState}
# destinationArray:destinationArray::id{MPSNDArray}]::Nothing
# end
# function encode!(cmdbuf::MTLCommandBuffer, kernel::K, primarySourceArray, secondarySourceArray, resultState, outputStateIsTemporary::Bool) where {K<:MPSNDArrayBinaryKernel}
# @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
# primarySourceArray:primarySourceArrays::id{MPSNDArray}
# secondarySourceArray:secondarySourceArray::id{MPSNDArray}
# resultState:resultState::id{MPSState}
# outputStateIsTemporary:outputStateIsTemporary::Bool]::MPSNDArray
# end

@objcwrapper immutable=false MPSNDArrayMatrixMultiplication <: MPSNDArrayMultiaryKernel

@objcproperties MPSNDArrayMatrixMultiplication begin
@autoproperty alpha::Float64 setter=setAlpha
@autoproperty beta::Float64 setter=setBeta
end

function MPSNDArrayMatrixMultiplication(device, sourceCount)
kernel = @objc [MPSNDArrayMatrixMultiplication alloc]::id{MPSNDArrayMatrixMultiplication}
obj = MPSNDArrayMatrixMultiplication(kernel)
finalizer(release, obj)
@objc [obj::id{MPSNDArrayMatrixMultiplication} initWithDevice:device::id{MTLDevice}
sourceCount:sourceCount::NSUInteger]::id{MPSNDArrayMatrixMultiplication}
return obj
end
Loading
Loading