diff --git a/lib/mps/ndarray.jl b/lib/mps/ndarray.jl index 8012202e..4384ff1e 100644 --- a/lib/mps/ndarray.jl +++ b/lib/mps/ndarray.jl @@ -24,10 +24,9 @@ function MPSNDArrayDescriptor(dataType::DataType, dimensionCount, dimensionSizes end function MPSNDArrayDescriptor(dataType::DataType, shape::DenseVector{T}) where {T<:Union{Int,UInt}} - revshape = collect(reverse(shape)) - obj = GC.@preserve revshape begin - shapeptr = pointer(revshape) - MPSNDArrayDescriptor(dataType, length(revshape), shapeptr) + obj = GC.@preserve shape begin + shapeptr = pointer(shape) + MPSNDArrayDescriptor(dataType, length(shape), shapeptr) end return obj end @@ -135,7 +134,7 @@ end function MPSNDArray(arr::MtlArray{T,N}) where {T,N} arrsize = size(arr) - @assert arrsize[end]*sizeof(T) % 16 == 0 "Final dimension of arr must have a byte size divisible by 16" + @assert arrsize[1]*sizeof(T) % 16 == 0 "First dimension of arr must have a byte size divisible by 16" desc = MPSNDArrayDescriptor(T, arrsize) return MPSNDArray(arr.data[], UInt(arr.offset), desc) end diff --git a/test/mps/ndarray.jl b/test/mps/ndarray.jl index 5a2d2d0d..edd8af94 100644 --- a/test/mps/ndarray.jl +++ b/test/mps/ndarray.jl @@ -7,7 +7,7 @@ using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension T = Float32 DT = convert(MPSDataType, T) - desc1 = MPSNDArrayDescriptor(T, 5,4,3,2,1) + desc1 = MPSNDArrayDescriptor(T,1,2,3,4,5) @test desc1 isa MPSNDArrayDescriptor @test desc1.dataType == DT @test desc1.preferPackedRows == false @@ -19,7 +19,7 @@ using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension @test lengthOfDimension(desc1,4) == 4 @test lengthOfDimension(desc1,3) == 5 - desc2 = MPSNDArrayDescriptor(T, (4,3,2,1)) + desc2 = MPSNDArrayDescriptor(T, (1,2,3,4)) @test desc2 isa MPSNDArrayDescriptor @test desc2.dataType == DT @test desc2.numberOfDimensions == 4 @@ -51,6 +51,7 @@ using .MPS: MPSNDArray @test ndarr1.label == "Test1" @test ndarr1.numberOfDimensions == 5 @test ndarr1.parent === nothing + @test size(ndarr1) == (5,4,3,2,1) ndarr2 = MPSNDArray(dev, 4) @test ndarr2 isa MPSNDArray @@ -63,9 +64,9 @@ using .MPS: MPSNDArray @test ndarr2.parent === nothing arr3 = MtlArray(ones(Float16, 2,3,4)) - @test_throws "Final dimension of arr must have a byte size divisible by 16" MPSNDArray(arr3) + @test_throws "First dimension of arr must have a byte size divisible by 16" MPSNDArray(arr3) - arr4 = MtlArray(ones(Float16, 2,3,8)) + arr4 = MtlArray(ones(Float16, 8,3,2)) @static if Metal.macos_version() >= v"15" @test ndarr1.descriptor isa MPSNDArrayDescriptor