Skip to content

Commit

Permalink
Move MPSKernels into a dedicated file
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Mar 28, 2023
1 parent 9d500f6 commit 5e52e1a
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 36 deletions.
3 changes: 3 additions & 0 deletions lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import GPUArrays

is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice},), dev)

# MPS kernel base clases
include("kernel.jl")

# high-level wrappers
include("matrix.jl")

Expand Down
12 changes: 0 additions & 12 deletions lib/mps/decomposition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,6 @@
end


export MPSMatrixUnaryKernel

@objcwrapper immutable=false MPSMatrixUnaryKernel <: MPSKernel

@objcproperties MPSMatrixUnaryKernel begin
@autoproperty sourceMatrixOrigin::id{MTLOrigin} setter=setSourceMatrixOrigin
@autoproperty resultMatrixOrigin::id{MTLOrigin} setter=setResultMatrixOrigin
@autoproperty batchStart::NSUInteger setter=setBatchStart
@autoproperty batchSize::NSUInteger setter=setBatchSize
end


export MPSMatrixDecompositionLU

@objcwrapper immutable=false MPSMatrixDecompositionLU <: MPSMatrixUnaryKernel
Expand Down
42 changes: 42 additions & 0 deletions lib/mps/kernel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#
# kernels
#

@cenum MPSKernelOptions::NSUInteger begin
MPSKernelOptionsNone = 0
MPSKernelOptionsSkipAPIValidation = 1 << 0
MPSKernelOptionsAllowReducedPrecision = 1 << 1
MPSKernelOptionsDisableInternalTiling = 1 << 2
MPSKernelOptionsInsertDebugGroups = 1 << 3
MPSKernelOptionsVerbose = 1 << 4
end


@objcwrapper MPSKernel <: NSObject

@objcproperties MPSKernel begin
@autoproperty device::id{MTLDevice}
@autoproperty label::id{NSString} setter=setLabel
@autoproperty options::MPSKernelOptions setter=setOptions
end


@objcwrapper immutable=false MPSMatrixUnaryKernel <: MPSKernel

@objcproperties MPSMatrixUnaryKernel begin
@autoproperty sourceMatrixOrigin::id{MTLOrigin} setter=setSourceMatrixOrigin
@autoproperty resultMatrixOrigin::id{MTLOrigin} setter=setResultMatrixOrigin
@autoproperty batchStart::NSUInteger setter=setBatchStart
@autoproperty batchSize::NSUInteger setter=setBatchSize
end


@objcwrapper immutable=false MPSMatrixBinaryKernel <: MPSKernel

@objcproperties MPSMatrixUnaryKernel begin
@autoproperty primarySourceMatrixOrigin::id{MTLOrigin} setter=setPrimarySourceMatrixOrigin
@autoproperty secondarySourceMatrixOrigin::id{MTLOrigin} setter=setSecondarySourceMatrixOrigin
@autoproperty resultMatrixOrigin::id{MTLOrigin} setter=setResultMatrixOrigin
@autoproperty batchStart::NSUInteger setter=setBatchStart
@autoproperty batchSize::NSUInteger setter=setBatchSize
end
24 changes: 0 additions & 24 deletions lib/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,6 @@ end
## bitwise operations lose type information, so allow conversions
Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)

@cenum MPSKernelOptions::NSUInteger begin
MPSKernelOptionsNone = 0
MPSKernelOptionsSkipAPIValidation = 1 << 0
MPSKernelOptionsAllowReducedPrecision = 1 << 1
MPSKernelOptionsDisableInternalTiling = 1 << 2
MPSKernelOptionsInsertDebugGroups = 1 << 3
MPSKernelOptionsVerbose = 1 << 4
end


#
# matrix descriptor
#
Expand Down Expand Up @@ -87,20 +77,6 @@ function MPSMatrix(arr::MtlMatrix{T}) where T
return obj
end


#
# kernels
#

@objcwrapper MPSKernel <: NSObject

@objcproperties MPSKernel begin
@autoproperty device::id{MTLDevice}
@autoproperty label::id{NSString} setter=setLabel
@autoproperty options::MPSKernelOptions setter=setOptions
end


#
# matrix multiplication
#
Expand Down

0 comments on commit 5e52e1a

Please sign in to comment.