diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl index 04c08f87..10427d1a 100644 --- a/lib/mps/MPS.jl +++ b/lib/mps/MPS.jl @@ -24,4 +24,7 @@ include("decomposition.jl") # matrix copy include("copy.jl") +# solver +include("solve.jl") + end diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl index 66944356..e8a952a7 100644 --- a/lib/mps/linalg.jl +++ b/lib/mps/linalg.jl @@ -192,3 +192,108 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T} return B end + + +function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T} + # TODO +end + +function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T} + M,N = size(B) + dev = current_device() + queue = global_queue(dev) + cmdbuf = MTLCommandBuffer(queue) + enqueue!(cmdbuf) + + X = MtlMatrix{T}(undef, size(B)) + + mps_a = MPSMatrix(A) + mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector + mps_x = MPSMatrix(X) + + solve_kernel = MPSMatrixSolveTriangular(dev, false, false, false, false, M, N, 1.0) # TODO: likely N, M is the correct order + encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x) + commit!(cmdbuf) + + return X +end + +function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T} + M,N = size(B) + dev = current_device() + queue = global_queue(dev) + cmdbuf = MTLCommandBuffer(queue) + enqueue!(cmdbuf) + + Bh = reshape(B, ) + X = MtlMatrix{T}(undef, size(B)) + + mps_a = MPSMatrix(A) + mps_b = MPSMatrix(Bh) # TODO reshape to matrix if B is a vector + mps_x = MPSMatrix(X) + + solve_kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0) + encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x) + commit!(cmdbuf) + + return X +end + +function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T} + M,N = size(B) + dev = current_device() + queue = global_queue(dev) + cmdbuf = MTLCommandBuffer(queue) + enqueue!(cmdbuf) + + X = MtlMatrix{T}(undef, size(B)) + + mps_a = MPSMatrix(A) + mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector + mps_x = MPSMatrix(X) + + solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0) + encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x) + commit!(cmdbuf) + + return X +end + +function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T} + M,N = size(B) + dev = current_device() + queue = global_queue(dev) + cmdbuf = MTLCommandBuffer(queue) + enqueue!(cmdbuf) + + X = MtlMatrix{T}(undef, size(B)) + + mps_a = MPSMatrix(A) + mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector + mps_x = MPSMatrix(X) + + solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0) + encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x) + commit!(cmdbuf) + + return X +end + +# function (\)(A::AbstractMatrix, B::AbstractVecOrMat) +# require_one_based_indexing(A, B) +# m, n = size(A) +# if m == n +# if istril(A) +# if istriu(A) +# return Diagonal(A) \ B +# else +# return LowerTriangular(A) \ B +# end +# end +# if istriu(A) +# return UpperTriangular(A) \ B +# end +# return lu(A) \ B +# end +# return qr(A, ColumnNorm()) \ B +# end \ No newline at end of file diff --git a/lib/mps/solve.jl b/lib/mps/solve.jl new file mode 100644 index 00000000..c5978e26 --- /dev/null +++ b/lib/mps/solve.jl @@ -0,0 +1,72 @@ +export MPSMatrixSolveLU + +@objcwrapper immutable=false MPSMatrixSolveLU <: MPSMatrixBinaryKernel + +function MPSMatrixSolveLU(device, transpose, order, numberOfRightHandSides) + kernel = @objc [MPSMatrixSolveLU alloc]::id{MPSMatrixSolveLU} + obj = MPSMatrixSolveLU(kernel) + finalizer(release, obj) + @objc [obj::id{MPSMatrixSolveLU} initWithDevice:device::id{MTLDevice} + transpose:transpose::Bool + order:order::NSUInteger + numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveLU} + return obj +end + +function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveLU, sourceMatrix, rightHandSideMatrix, pivotIndices, solutionMatrix) + @objc [kernel::id{MPSMatrixSolveLU} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer} + sourceMatrix:sourceMatrix::id{MPSMatrix} + rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix} + pivotIndices:pivotIndices::id{MPSMatrix} + solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing +end + + +export MPSMatrixSolveTriangular + +@objcwrapper immutable=false MPSMatrixSolveTriangular <: MPSMatrixBinaryKernel + +function MPSMatrixSolveTriangular(device, right, upper, transpose, unit, order, numberOfRightHandSides, alpha) + kernel = @objc [MPSMatrixSolveTriangular alloc]::id{MPSMatrixSolveTriangular} + obj = MPSMatrixSolveTriangular(kernel) + finalizer(release, obj) + @objc [obj::id{MPSMatrixSolveTriangular} initWithDevice:device::id{MTLDevice} + right:right::Bool + upper:upper::Bool + transpose:transpose::Bool + unit:unit::Bool + order:order::NSUInteger + numberOfRightHandSides:numberOfRightHandSides::NSUInteger + alpha:alpha::Cdouble]::id{MPSMatrixSolveTriangular} + return obj +end + +function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveTriangular, sourceMatrix, rightHandSideMatrix, solutionMatrix) + @objc [kernel::id{MPSMatrixSolveTriangular} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer} + sourceMatrix:sourceMatrix::id{MPSMatrix} + rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix} + solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing +end + + +export MPSMatrixSolveCholesky + +@objcwrapper immutable=false MPSMatrixSolveCholesky <: MPSMatrixBinaryKernel + +function MPSMatrixSolveCholesky(device, upper, order, numberOfRightHandSides) + kernel = @objc [MPSMatrixSolveCholesky alloc]::id{MPSMatrixSolveCholesky} + obj = MPSMatrixSolveCholesky(kernel) + finalizer(release, obj) + @objc [obj::id{MPSMatrixSolveCholesky} initWithDevice:device::id{MTLDevice} + upper:upper::Bool + order:order::NSUInteger + numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveCholesky} + return obj +end + +function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveCholesky, sourceMatrix, rightHandSideMatrix, solutionMatrix) + @objc [kernel::id{MPSMatrixSolveCholesky} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer} + sourceMatrix:sourceMatrix::id{MPSMatrix} + rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix} + solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing +end