Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement automul #8

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ julia = "1.6"
[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Quantics = "87f76fb3-a40a-40c9-a63c-29fcfe7b7547"

[targets]
test = ["Test", "Random"]
test = ["Test", "Random", "Quantics"]
11 changes: 7 additions & 4 deletions src/ProjMPSs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ module ProjMPSs

import OrderedCollections: OrderedSet, OrderedDict
using EllipsisNotation
import LinearAlgebra
using LinearAlgebra: LinearAlgebra

import ITensors: ITensors, Index, ITensor, dim, inds, qr, commoninds
import ITensorMPS: ITensorMPS, AbstractMPS, MPS, MPO, siteinds
import ITensors: ITensors, Index, ITensor, dim, inds, qr, commoninds, uniqueinds
import ITensorMPS: ITensorMPS, AbstractMPS, MPS, MPO, siteinds, findsites
import ITensors.TagSets: hastag, hastags

import FastMPOContractions as FMPOC


include("util.jl")
include("projector.jl")
include("projmps.jl")
Expand All @@ -18,4 +18,7 @@ include("patching.jl")
include("contract.jl")
include("adaptivemul.jl")

# Only for backward compatibility
include("automul.jl")

end
124 changes: 124 additions & 0 deletions src/automul.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
By default, elementwise multiplication will be performed.

This function is kind of deprecated and will be removed in the future.
"""
function automul(
M1::BlockedMPS,
M2::BlockedMPS;
tag_row::String="",
tag_shared::String="",
tag_col::String="",
alg="naive",
maxdim=typemax(Int),
cutoff=1e-25,
kwargs...,
)
all(length.(siteinds(M1)) .== 1) || error("M1 should have only 1 site index per site")
all(length.(siteinds(M2)) .== 1) || error("M2 should have only 1 site index per site")

sites_row = _findallsiteinds_by_tag(M1; tag=tag_row)
sites_shared = _findallsiteinds_by_tag(M1; tag=tag_shared)
sites_col = _findallsiteinds_by_tag(M2; tag=tag_col)
sites_matmul = Set(Iterators.flatten([sites_row, sites_shared, sites_col]))

sites1 = only.(siteinds(M1))
sites1_ewmul = setdiff(only.(siteinds(M1)), sites_matmul)
sites2_ewmul = setdiff(only.(siteinds(M2)), sites_matmul)
sites2_ewmul == sites1_ewmul || error("Invalid sites for elementwise multiplication")

M1 = _makesitediagonal(M1, sites1_ewmul; baseplev=1)
M2 = _makesitediagonal(M2, sites2_ewmul; baseplev=0)

sites_M1_diag = [collect(x) for x in siteinds(M1)]
sites_M2_diag = [collect(x) for x in siteinds(M2)]

M1 = rearrange_siteinds(M1, combinesites(sites_M1_diag, sites_row, sites_shared))

M2 = rearrange_siteinds(M2, combinesites(sites_M2_diag, sites_shared, sites_col))

M = contract(M1, M2; alg=alg, kwargs...)

M = extractdiagonal(M, sites1_ewmul)

ressites = Vector{eltype(siteinds(M1)[1])}[]
for s in siteinds(M)
s_ = unique(ITensors.noprime.(s))
if length(s_) == 1
push!(ressites, s_)
else
if s_[1] ∈ sites1
push!(ressites, [s_[1]])
push!(ressites, [s_[2]])
else
push!(ressites, [s_[2]])
push!(ressites, [s_[1]])
end
end
end
return truncate(rearrange_siteinds(M, ressites); cutoff=cutoff, maxdim=maxdim)
end

function combinesites(
sites::Vector{Vector{Index{IndsT}}},
site1::AbstractVector{Index{IndsT}},
site2::AbstractVector{Index{IndsT}},
) where {IndsT}
length(site1) == length(site2) || error("Length mismatch")
for (s1, s2) in zip(site1, site2)
sites = combinesites(sites, s1, s2)
end
return sites
end

function combinesites(
sites::Vector{Vector{Index{IndsT}}}, site1::Index, site2::Index
) where {IndsT}
sites = deepcopy(sites)
p1 = findfirst(x -> x[1] == site1, sites)
p2 = findfirst(x -> x[1] == site2, sites)
if p1 === nothing || p2 === nothing
error("Site not found")
end
if abs(p1 - p2) != 1
error("Sites are not adjacent")
end
deleteat!(sites, min(p1, p2))
deleteat!(sites, min(p1, p2))
insert!(sites, min(p1, p2), [site1, site2])
return sites
end

function _findallsiteinds_by_tag(M::BlockedMPS; tag=tag)
return findallsiteinds_by_tag(only.(siteinds(M)); tag=tag)
end

# The following code is copied from Quantics.jl

function findallsiteinds_by_tag(
sites::AbstractVector{Index{T}}; tag::String="x", maxnsites::Int=1000
) where {T}
_valid_tag(tag) || error("Invalid tag: $tag")
positions = findallsites_by_tag(sites; tag=tag, maxnsites=maxnsites)
return [sites[p] for p in positions]
end

function findallsites_by_tag(
sites::Vector{Index{T}}; tag::String="x", maxnsites::Int=1000
)::Vector{Int} where {T}
_valid_tag(tag) || error("Invalid tag: $tag")
result = Int[]
for n in 1:maxnsites
tag_ = tag * "=$n"
idx = findall(hastags(tag_), sites)
if length(idx) == 0
break
elseif length(idx) > 1
error("Found more than one site indices with $(tag_)!")
end
push!(result, idx[1])
end
return result
end

_valid_tag(tag::String)::Bool = !occursin("=", tag)
24 changes: 23 additions & 1 deletion src/blockedmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ function Base.values(obj::BlockedMPS)
return values(obj.data)
end


"""
Rearrange the site indices of the BlockedMPS according to the given order.
If nessecary, tensors are fused or split to match the new order.
Expand Down Expand Up @@ -144,3 +143,26 @@ end
function ITensorMPS.MPO(obj::BlockedMPS; cutoff=1e-25, maxdim=typemax(Int))::MPO
return MPO(collect(MPS(obj; cutoff=cutoff, maxdim=maxdim, kwargs...)))
end

"""
Make the BlockedMPS diagonal for a given site index `s` by introducing a dummy index `s'`.
"""
function makesitediagonal(obj::BlockedMPS, site)
return BlockedMPS([
_makesitediagonal(prjmps, site; baseplev=baseplev) for prjmps in values(obj)
])
end

function _makesitediagonal(obj::BlockedMPS, site; baseplev=0)
return BlockedMPS([
_makesitediagonal(prjmps, site; baseplev=baseplev) for prjmps in values(obj)
])
end

"""
Extract diagonal of the BlockedMPS for `s`, `s'`, ... for a given site index `s`,
where `s` must have a prime level of 0.
"""
function extractdiagonal(obj::BlockedMPS, site)
return BlockedMPS([extractdiagonal(prjmps, site) for prjmps in values(obj)])
end
82 changes: 82 additions & 0 deletions src/projmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,85 @@ end
function LinearAlgebra.norm(M::ProjMPS)
return _norm(MPS(M))
end

function _makesitediagonal(
projmps::ProjMPS, sites::AbstractVector{Index{IndsT}}; baseplev=0
) where {IndsT}
M_ = deepcopy(MPO(collect(MPS(projmps))))
for site in sites
target_site::Int = only(findsites(M_, site))
M_[target_site] = _asdiagonal(M_[target_site], site; baseplev=baseplev)
end
return project(M_, projmps.projector)
end

function _makesitediagonal(projmps::ProjMPS, site::Index; baseplev=0)
return _makesitediagonal(projmps, [site]; baseplev=baseplev)
end

function makesitediagonal(projmps::ProjMPS, site::Index)
return _makesitediagonal(projmps, site; baseplev=0)
end

function makesitediagonal(projmps::ProjMPS, sites::AbstractVector{Index})
return _makesitediagonal(projmps, sites; baseplev=0)
end

function makesitediagonal(projmps::ProjMPS, tag::String)
mps_diagonal = makesitediagonal(MPS(projmps), tag)
projmps_diagonal = ProjMPS(mps_diagonal)

target_sites = findallsiteinds_by_tag(
unique(ITensors.noprime.(Iterators.flatten(siteinds(projmps)))); tag=tag
)

newproj = deepcopy(projmps.projector)
for s in target_sites
if isprojectedat(projmps.projector, s)
newproj[ITensors.prime(s)] = newproj[s]
end
end

return project(projmps_diagonal, newproj)
end

# FIXME: may be type unstable
function _find_site_allplevs(tensor::ITensor, site::Index; maxplev=10)
ITensors.plev(site) == 0 || error("Site index must be unprimed.")
return [
ITensors.prime(site, plev) for
plev in 0:maxplev if ITensors.prime(site, plev) ∈ ITensors.inds(tensor)
]
end

function extractdiagonal(
projmps::ProjMPS, sites::AbstractVector{Index{IndsT}}
) where {IndsT}
tensors = collect(projmps.data)
for i in eachindex(tensors)
for site in intersect(sites, ITensors.inds(tensors[i]))
sitewithallplevs = _find_site_allplevs(tensors[i], site)
tensors[i] = if length(sitewithallplevs) > 1
tensors[i] = _extract_diagonal(tensors[i], sitewithallplevs...)
else
tensors[i]
end
end
end

projector = deepcopy(projmps.projector)
for site in sites
if site' in keys(projector.data)
delete!(projector.data, site')
end
end
return ProjMPS(MPS(tensors), projector)
end


function extractdiagonal(projmps::ProjMPS, tag::String)::ProjMPS
targetsites = findallsiteinds_by_tag(
unique(ITensors.noprime.(ProjMPSs._allsites(projmps))); tag=tag
)
return extractdiagonal(projmps, targetsites)
end
42 changes: 40 additions & 2 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ function _asdiagonal(t, site::Index{T}; baseplev=0)::ITensor where {T<:Number}
)
end


function rearrange_siteinds(M::AbstractMPS, sites::Vector{Vector{Index{T}}})::MPS where {T}
sitesold = siteinds(MPO(collect(M)))

Expand Down Expand Up @@ -104,5 +103,44 @@ function rearrange_siteinds(M::AbstractMPS, sites::Vector{Vector{Index{T}}})::MP
tensors[i], t, _ = qr(t, linds)
end
tensors[end] *= t
MPS(tensors)
return MPS(tensors)
end

function _extract_diagonal(t, site::Index{T}, site2::Index{T}) where {T<:Number}
dim(site) == dim(site2) || error("Dimension mismatch")
restinds = uniqueinds(inds(t), site, site2)
newdata = zeros(eltype(t), dim.(restinds)..., dim(site))
olddata = Array(t, restinds..., site, site2)
for i in 1:dim(site)
newdata[.., i] = olddata[.., i, i]
end
return ITensor(newdata, restinds..., site)
end


"""
Makes an MPS/MPO diagonal for a specified a site index `s`.
On return, the data will be deep copied and the target core tensor will be diagonalized with an additional site index `s'`.
"""
function makesitediagonal(M::AbstractMPS, site::Index{T})::MPS where {T}
M_ = deepcopy(MPO(collect(M)))

target_site::Int = only(findsites(M_, site))
M_[target_site] = _asdiagonal(M_[target_site], site)

return MPS(collect(M_))
end

function makesitediagonal(M::AbstractMPS, tag::String)::MPS
M_ = deepcopy(MPO(collect(M)))
sites = siteinds(M_)

target_positions = findallsites_by_tag(siteinds(M_); tag=tag)

for t in eachindex(target_positions)
i, j = target_positions[t]
M_[i] = _asdiagonal(M_[i], sites[i][j])
end

return MPS(collect(M_))
end
Loading
Loading