From 83a4fe503be32aa9891d9206cbd077a1aed9c9b3 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Tue, 5 Sep 2023 16:02:51 +0200 Subject: [PATCH] ldiv --- lib/mps/linalg.jl | 142 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl index cf2c8f360..43e4da978 100644 --- a/lib/mps/linalg.jl +++ b/lib/mps/linalg.jl @@ -261,5 +261,147 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T} commit!(cmdbuf) + wait_completed(cmdbuf) + + return B +end + + +function LinearAlgebra.:(\)(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + C = deepcopy(B) + LinearAlgebra.ldiv!(A, C) + return C +end + + +function LinearAlgebra.ldiv!(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = current_device() + queue = global_queue(dev) + + At = similar(A.factors) + Bt = similar(B, (N, M)) + P = reshape((A.ipiv .- UInt32(1)), (1, M)) + X = similar(B, (N, M)) + + transpose!(At, A.factors) + transpose!(Bt, B) + + mps_a = MPSMatrix(At) + mps_b = MPSMatrix(Bt) + mps_p = MPSMatrix(P) + mps_x = MPSMatrix(X) + + MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveLU(dev, false, M, N) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x) + end + + transpose!(B, X) + return B +end + + +function LinearAlgebra.ldiv!(A::UpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = current_device() + queue = global_queue(dev) + + Ad = MtlMatrix(A') + Br = similar(B, (M, M)) + X = similar(Br) + + transpose!(Br, B) + + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) + mps_x = MPSMatrix(X) + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, N, M, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + wait_completed(buf) + + copy!(B, X) + return B +end + + +function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = current_device() + queue = global_queue(dev) + + Ad = MtlMatrix(A) + Br = reshape(B, (M, N)) + X = similar(Br) + + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) + mps_x = MPSMatrix(X) + + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, true, false, false, true, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + wait_completed(buf) + + copy!(Br, X) + return B +end + + +function LinearAlgebra.ldiv!(A::LowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = current_device() + queue = global_queue(dev) + + Ad = MtlMatrix(A) + Br = reshape(B, (M, N)) + X = similar(Br) + + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) + mps_x = MPSMatrix(X) + + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, true, true, false, false, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + wait_completed(buf) + + copy!(Br, X) return B end + + +function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = current_device() + queue = global_queue(dev) + + Ad = MtlMatrix(A) + Br = reshape(B, (M, N)) + X = similar(Br) + + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) + mps_x = MPSMatrix(X) + + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, true, true, false, true, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + wait_completed(buf) + + copy!(Br, X) + return B +end \ No newline at end of file