Skip to content

Commit

Permalink
Don't reverse dimensions automatically
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd committed Dec 18, 2024
1 parent 2a6d162 commit 8654f9f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
9 changes: 4 additions & 5 deletions lib/mps/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ function MPSNDArrayDescriptor(dataType::DataType, dimensionCount, dimensionSizes
end

function MPSNDArrayDescriptor(dataType::DataType, shape::DenseVector{T}) where {T<:Union{Int,UInt}}
revshape = collect(reverse(shape))
obj = GC.@preserve revshape begin
shapeptr = pointer(revshape)
MPSNDArrayDescriptor(dataType, length(revshape), shapeptr)
obj = GC.@preserve shape begin
shapeptr = pointer(shape)
MPSNDArrayDescriptor(dataType, length(shape), shapeptr)
end
return obj
end
Expand Down Expand Up @@ -135,7 +134,7 @@ end

function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
arrsize = size(arr)
@assert arrsize[end]*sizeof(T) % 16 == 0 "Final dimension of arr must have a byte size divisible by 16"
@assert arrsize[1]*sizeof(T) % 16 == 0 "First dimension of arr must have a byte size divisible by 16"
desc = MPSNDArrayDescriptor(T, arrsize)
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
end
Expand Down
9 changes: 5 additions & 4 deletions test/mps/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension
T = Float32
DT = convert(MPSDataType, T)

desc1 = MPSNDArrayDescriptor(T, 5,4,3,2,1)
desc1 = MPSNDArrayDescriptor(T,1,2,3,4,5)
@test desc1 isa MPSNDArrayDescriptor
@test desc1.dataType == DT
@test desc1.preferPackedRows == false
Expand All @@ -19,7 +19,7 @@ using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension
@test lengthOfDimension(desc1,4) == 4
@test lengthOfDimension(desc1,3) == 5

desc2 = MPSNDArrayDescriptor(T, (4,3,2,1))
desc2 = MPSNDArrayDescriptor(T, (1,2,3,4))
@test desc2 isa MPSNDArrayDescriptor
@test desc2.dataType == DT
@test desc2.numberOfDimensions == 4
Expand Down Expand Up @@ -51,6 +51,7 @@ using .MPS: MPSNDArray
@test ndarr1.label == "Test1"
@test ndarr1.numberOfDimensions == 5
@test ndarr1.parent === nothing
@test size(ndarr1) == (5,4,3,2,1)

ndarr2 = MPSNDArray(dev, 4)
@test ndarr2 isa MPSNDArray
Expand All @@ -63,9 +64,9 @@ using .MPS: MPSNDArray
@test ndarr2.parent === nothing

arr3 = MtlArray(ones(Float16, 2,3,4))
@test_throws "Final dimension of arr must have a byte size divisible by 16" MPSNDArray(arr3)
@test_throws "First dimension of arr must have a byte size divisible by 16" MPSNDArray(arr3)

arr4 = MtlArray(ones(Float16, 2,3,8))
arr4 = MtlArray(ones(Float16, 8,3,2))

@static if Metal.macos_version() >= v"15"
@test ndarr1.descriptor isa MPSNDArrayDescriptor
Expand Down

0 comments on commit 8654f9f

Please sign in to comment.