Skip to content

Commit

Permalink
Fix global linear indexing (fill!) (#496)
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd authored Dec 16, 2024
1 parent 6ecb909 commit 8c119cf
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
4 changes: 3 additions & 1 deletion src/MetalKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ end
end

@device_override @inline function KA.__index_Global_Linear(ctx)
return thread_position_in_grid_1d()
I = @inbounds KA.expand(KA.__iterspace(ctx), threadgroup_position_in_grid_1d(), thread_position_in_threadgroup_1d())
# TODO: This is unfortunate, can we get the linear index cheaper
@inbounds LinearIndices(KA.__ndrange(ctx))[I]
end

@device_override @inline function KA.__index_Local_Cartesian(ctx)
Expand Down
14 changes: 4 additions & 10 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,12 @@ end

@testset "fill($T)" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
Float16, Float32]
broken466a = T [Int8,UInt8]
broken466b = (Base.JLOptions().check_bounds != 1 || shader_validation)

b = rand(T)

# Dims in tuple
let A = Metal.fill(b, (10, 10, 10, 1000))
B = fill(b, (10, 10, 10, 1000))
@test Array(A) == B broken=(broken466a && broken466b)
@test Array(A) == B
end

let M = Metal.fill(b, (10, 10))
Expand All @@ -253,7 +250,7 @@ end
#Dims already unpacked
let A = Metal.fill(b, 10, 1000, 1000)
B = fill(b, 10, 1000, 1000)
@test Array(A) == B broken=broken466a
@test Array(A) == B
end

let M = Metal.fill(b, 10, 10)
Expand All @@ -269,15 +266,12 @@ end

@testset "fill!($T)" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
Float16, Float32]
broken466a = T [Int8,UInt8]
broken466b = (Base.JLOptions().check_bounds != 1 || shader_validation)

b = rand(T)

# Dims in tuple
let A = MtlArray{T,3}(undef, (10, 1000, 1000))
fill!(A, b)
@test all(Array(A) .== b) broken=broken466a
@test all(Array(A) .== b)
end

let M = MtlMatrix{T}(undef, (10, 10))
Expand All @@ -293,7 +287,7 @@ end
# Dims already unpacked
let A = MtlArray{T,4}(undef, 10, 10, 10, 1000)
fill!(A, b)
@test all(Array(A) .== b) broken=(broken466a && broken466b)
@test all(Array(A) .== b)
end

let M = MtlMatrix{T}(undef, 10, 10)
Expand Down

0 comments on commit 8c119cf

Please sign in to comment.