Skip to content

Commit

Permalink
ldiv
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Sep 5, 2023
1 parent 69aa51e commit 83a4fe5
Showing 1 changed file with 142 additions and 0 deletions.
142 changes: 142 additions & 0 deletions lib/mps/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,5 +261,147 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}

commit!(cmdbuf)

wait_completed(cmdbuf)

return B
end


function LinearAlgebra.:(\)(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
C = deepcopy(B)
LinearAlgebra.ldiv!(A, C)
return C
end


function LinearAlgebra.ldiv!(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
M, N = size(B, 1), size(B, 2)
dev = current_device()
queue = global_queue(dev)

At = similar(A.factors)
Bt = similar(B, (N, M))
P = reshape((A.ipiv .- UInt32(1)), (1, M))
X = similar(B, (N, M))

transpose!(At, A.factors)
transpose!(Bt, B)

mps_a = MPSMatrix(At)
mps_b = MPSMatrix(Bt)
mps_p = MPSMatrix(P)
mps_x = MPSMatrix(X)

MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveLU(dev, false, M, N)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
end

transpose!(B, X)
return B
end


function LinearAlgebra.ldiv!(A::UpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
M, N = size(B, 1), size(B, 2)
dev = current_device()
queue = global_queue(dev)

Ad = MtlMatrix(A')
Br = similar(B, (M, M))
X = similar(Br)

transpose!(Br, B)

mps_a = MPSMatrix(Ad)
mps_b = MPSMatrix(Br)
mps_x = MPSMatrix(X)

buf = MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, N, M, 1.0)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
end

wait_completed(buf)

copy!(B, X)
return B
end


function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
M, N = size(B, 1), size(B, 2)
dev = current_device()
queue = global_queue(dev)

Ad = MtlMatrix(A)
Br = reshape(B, (M, N))
X = similar(Br)

mps_a = MPSMatrix(Ad)
mps_b = MPSMatrix(Br)
mps_x = MPSMatrix(X)


buf = MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, true, false, false, true, M, N, 1.0)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
end

wait_completed(buf)

copy!(Br, X)
return B
end


function LinearAlgebra.ldiv!(A::LowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
M, N = size(B, 1), size(B, 2)
dev = current_device()
queue = global_queue(dev)

Ad = MtlMatrix(A)
Br = reshape(B, (M, N))
X = similar(Br)

mps_a = MPSMatrix(Ad)
mps_b = MPSMatrix(Br)
mps_x = MPSMatrix(X)


buf = MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, true, true, false, false, M, N, 1.0)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
end

wait_completed(buf)

copy!(Br, X)
return B
end


function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
M, N = size(B, 1), size(B, 2)
dev = current_device()
queue = global_queue(dev)

Ad = MtlMatrix(A)
Br = reshape(B, (M, N))
X = similar(Br)

mps_a = MPSMatrix(Ad)
mps_b = MPSMatrix(Br)
mps_x = MPSMatrix(X)


buf = MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, true, true, false, true, M, N, 1.0)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
end

wait_completed(buf)

copy!(Br, X)
return B
end

0 comments on commit 83a4fe5

Please sign in to comment.