Skip to content

Commit

Permalink
Refactor/cleanup and incorporate comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Jan 17, 2024
1 parent c49c286 commit 5df2dcd
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 127 deletions.
4 changes: 2 additions & 2 deletions ext/SUNRepresentationsLatexifyExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Latexify: @latexrecipe, LaTeXString
@latexrecipe function f(x::SUNIrrep)
## set parameters
env --> :inline

## convert into latex string
d, numprimes, conjugate = parse_dimname(dimname(x))
str_new = conjugate ? "\\overline{\\textbf{$d}}" : "\\textbf{$d}"
Expand All @@ -22,4 +22,4 @@ end

Base.show(io::IO, ::MIME"text/latex", x::SUNIrrep) = print(io, latexify(x))

end
end
6 changes: 3 additions & 3 deletions src/SUNRepresentations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ end
SUNIrrep(args::Vararg{Int,N}) where {N} = SUNIrrep{N}(args)
SUNIrrep{N}(args::Vararg{Int}) where {N} = SUNIrrep{N}(args)

SUNIrrep(a::Vector{Int}) = SUNIrrep{length(a)+1}(a)
SUNIrrep(a::Vector{Int}) = SUNIrrep{length(a) + 1}(a)
function SUNIrrep{N}(a::Vector{Int}) where {N}
@assert length(a) == N - 1
return SUNIrrep{N}(reverse(cumsum(reverse(a)))..., 0)
Expand All @@ -54,15 +54,15 @@ function SUNIrrep{N}(name::AbstractString) where {N}
name == generate_dimname(6, 0, false) && return SUNIrrep{N}(2, 0, 0)
name == generate_dimname(6, 0, true) && return SUNIrrep{N}(2, 2, 0)
end

d, numprimes, conjugate = parse_dimname(name)
max_dynkin = max_dynkin_label(SUNIrrep{N})

same_dim_irreps = irreps_by_dim(SUNIrrep{N}, d, max_dynkin)
same_dim_ids = unique!(index.(same_dim_irreps))
length(same_dim_ids) < numprimes + 1 &&
throw(ArgumentError("Either the name $name is not valid for SU{$N} or the irrep has at least one Dynkin label higher than $max_dynkin.\nYou can expand the search space with `SUNRepresentations.max_dynkin_label(SUNIrrep{$N}) = a`."))

id = same_dim_ids[numprimes + 1]
same_id_irreps = filter(x -> index(x) == id, same_dim_irreps)
@assert length(same_id_irreps) <= 2
Expand Down
184 changes: 77 additions & 107 deletions src/caching.jl
Original file line number Diff line number Diff line change
@@ -1,64 +1,47 @@
const CGCKEY{N} = NTuple{3,SUNIrrep{N}}
_string(key::CGCKEY) = "$(key[1].I)$(key[2].I)$(key[3].I)"
const CGCCache{N,T} = LRU{CGCKEY{N},SparseArray{T,4}}

struct CGCCache{N,T}
data::LRU{CGCKEY{N},SparseArray{T,4}} # RAM cached CGC tensors
function CGCCache{N,T}(; maxsize=10^5) where {N,T}
data = LRU{CGCKEY{N},SparseArray{T,4}}(; maxsize)
return new{N,T}(data)
end
end

function Base.show(io::IO, cache::CGCCache{N,T}) where {N,T}
println(io, typeof(cache))
println(io, " ", LRUCache.cache_info(cache.data))
fn = cache_path(N, T)
if isfile(fn)
println(io, " ", filesize(fn), " bytes on disk")
jldopen(fn, "r") do file
return println(io, " $(length(keys(file))) entries in disk cache")
end
else
println(io, " no disk cache")
end
end
# convert sector to string key
_key(s::SUNIrrep) = string(weight(s))

# List of CGC caches for each N and T
const CGC_CACHES = LRU{Any,CGCCache}(; maxsize=10)

const CGC_CACHE_PATH = @get_scratch!("CGC")
cache_path(N, T=Float64) = joinpath(CGC_CACHE_PATH, "$(N)_$(T)")

function Base.get!(cache::CGCCache{N,T}, (s1, s2, s3)::CGCKEY{N})::SparseArray{T,4} where {T,N}
return get!(cache.data, (s1, s2, s3)) do
# if the key is not in the cache, check if it is in a file
cachedir = joinpath(cache_path(N, T), "$(weight(s1))")
isdir(cachedir) || mkpath(cachedir)
fn = "$(weight(s1)) x $(weight(s2))"

# try reading data
if isfile(joinpath(cachedir, fn * ".jld2"))
try
return jldopen(joinpath(cachedir, fn * ".jld2"), "r";
parallel_read=true) do file
@debug "loaded CGC from disk: $s1$s2$s3"
return file[string(weight(s3))]::SparseArray{T,4}
end
catch
end
end
# if failed, create new data
CGCs = Dict(string(weight(s3′)) => _CGC(T, s1, s2, s3′)
for s3′ in s1 s2)
function cgc_cachepath(s1::SUNIrrep{N}, s2::SUNIrrep{N}, T=Float64) where {N}
return joinpath(CGC_CACHE_PATH, string(N), string(T), _key(s1), _key(s2))
end

function tryread(::Type{T}, s1::SUNIrrep{N}, s2::SUNIrrep{N}, s3::SUNIrrep{N}) where {T,N}
fn = cgc_cachepath(s1, s2, T) * ".jld2"
isfile(fn) || return nothing

# write CGCs to disk
mkpidlock(joinpath(cachedir, fn * ".pid")) do
return save(joinpath(cachedir, fn * ".jld2"), CGCs)
try
return jldopen(fn, "r"; parallel_read=true) do file
@debug "loaded CGC from disk: $s1$s2$s3"
return file[_key(s3)]::SparseArray{T,4}
end
catch

Check warning on line 24 in src/caching.jl

View check run for this annotation

Codecov / codecov/patch

src/caching.jl#L24

Added line #L24 was not covered by tests
end

# return CGC
return CGCs[string(weight(s3))]
return nothing

Check warning on line 27 in src/caching.jl

View check run for this annotation

Codecov / codecov/patch

src/caching.jl#L27

Added line #L27 was not covered by tests
end

#= wait at most 1 min before deciding to overwrite. This should avoid deadlocking if a
process started writing but got killed before removing the pidfile. =#
const _PID_STALE_AGE = 60.0

function generate_all_CGCs(::Type{T}, s1::SUNIrrep{N}, s2::SUNIrrep{N}) where {T,N}
@debug "Generating CGCs: $s1$s2"
CGCs = Dict(_key(s3) => _CGC(T, s1, s2, s3) for s3 in s1 s2)
fn = cgc_cachepath(s1, s2, T)
isdir(dirname(fn)) || mkpath(dirname(fn))

mkpidlock(fn * ".pid"; stale_age=_PID_STALE_AGE) do
return save(fn * ".jld2", CGCs)
end

return CGCs
end

"""
Expand All @@ -68,49 +51,31 @@ Populate the CGC cache for ``SU(N)`` with eltype `T` with all CGCs with Dynkin l
``a_max``.
Will not recompute CGCs that are already in the cache, unless ``force=true``.
"""
function precompute_disk_cache(N, a_max::Int=3, T::Type{<:Number}=Float64; force=false)
all_dynkinlabels = CartesianIndices(ntuple(_ -> (a_max + 1), N - 1))
all_irreps = [SUNIrrep(reverse(cumsum(I.I .- 1))..., 0) for I in all_dynkinlabels]

@sync begin
for s1 in all_irreps
cachedir = joinpath(cache_path(N, T), "$(weight(s1))")
isdir(cachedir) || mkpath(cachedir)
for s2 in all_irreps
if force || !isfile(joinpath(cachedir, "$(weight(s1)) x $(weight(s2)).jld2"))
Threads.@spawn _compute_disk_cache($s1, $s2, $T)
end
function precompute_disk_cache(N, a_max::Int=1, T::Type{<:Number}=Float64; force=false)
all_irreps = all_dynkin(SUNIrrep{N}, a_max)
@sync for s1 in all_irreps, s2 in all_irreps
if force || !isfile(cgc_cachepath(s1, s2, T) * ".jld2")
Threads.@spawn begin
generate_all_CGCs(T, s1, s2)
nothing
end
end
end

Check warning on line 63 in src/caching.jl

View check run for this annotation

Codecov / codecov/patch

src/caching.jl#L63

Added line #L63 was not covered by tests

cache_info()
return nothing
end

function _compute_disk_cache(s1::SUNIrrep{N}, s2::SUNIrrep{N}, T) where {N}
@info "Computing CGC: $s1$s2"
cachedir = joinpath(cache_path(N, T), "$(weight(s1))")
fn = "$(weight(s1)) x $(weight(s2))"
CGCs = Dict(string(weight(s3)) => _CGC(T, s1, s2, s3)
for s3 in s1 s2)

# write CGCs to disk
mkpidlock(joinpath(cachedir, fn * ".pid")) do
return save(joinpath(cachedir, fn * ".jld2"), CGCs)
end
end

"""
clear_disk_cache!(N, [T=Float64])
Remove the CGC cache for ``SU(N)`` with eltype `T` from disk.
"""
function clear_disk_cache!(N, T=Float64)
fn = cache_path(N, T)
if isfile(fn)
fldrname = joinpath(CGC_CACHE_PATH, string(N), string(T))
if isdir(fldrname)
@info "Removing disk cache SU($N): $T"
rm(fn)
rm(fldrname; recursive=true)

Check warning on line 78 in src/caching.jl

View check run for this annotation

Codecov / codecov/patch

src/caching.jl#L74-L78

Added lines #L74 - L78 were not covered by tests
end
return nothing

Check warning on line 80 in src/caching.jl

View check run for this annotation

Codecov / codecov/patch

src/caching.jl#L80

Added line #L80 was not covered by tests
end
Expand All @@ -121,46 +86,51 @@ end

_parse_filename(fn) = split(splitext(basename(fn))[1], "_")

Check warning on line 87 in src/caching.jl

View check run for this annotation

Codecov / codecov/patch

src/caching.jl#L87

Added line #L87 was not covered by tests

"""
cache_info()
Print information about the CGC cache.
"""
function cache_info(io::IO=stdout)
function ram_cache_info(io::IO=stdout)
println(io, "CGC RAM cache info:")
println(io, "===================")
for ((N, T), cache) in CGC_CACHES
println(io, "SU($N) - $T")
println(io, "------------------------")
println(io, cache)
println(io, "* ", LRUCache.cache_info(cache))
println(io)
end

Check warning on line 97 in src/caching.jl

View check run for this annotation

Codecov / codecov/patch

src/caching.jl#L93-L97

Added lines #L93 - L97 were not covered by tests
return nothing
end

println(io)
function disk_cache_info(io::IO=stdout)
println(io, "CGC disk cache info:")
println(io, "====================")

cache_dir = CGC_CACHE_PATH
isdir(cache_dir) || return nothing

for fldr in readdir(cache_dir; join=true)
isdir(fldr) || continue
N, T = _parse_filename(fldr)

n_bytes = 0
n_entries = 0

for (root, _, files) in walkdir(fldr)
for f in files
n_bytes += filesize(joinpath(root, f))
n_entries += jldopen(file -> length(keys(file)), joinpath(root, f), "r")
isdir(CGC_CACHE_PATH) || return nothing

for fldr_N in readdir(CGC_CACHE_PATH; join=true)
isdir(fldr_N) || continue
N = last(splitpath(fldr_N))
for fldr_T in readdir(fldr_N; join=true)
isdir(fldr_T) || continue
T = basename(fldr_T)
n_bytes = 0
n_entries = 0
for (root, _, files) in walkdir(fldr_T)
for f in files
n_bytes += filesize(joinpath(root, f))
n_entries += jldopen(file -> length(keys(file)), joinpath(root, f), "r")
end
end
println(io,
"* SU($N) - $T - $(n_entries) entries - $(Base.format_bytes(n_bytes))")
end

println(io, "SU($N) - $T")
println(io, " * ", n_entries, " entries")
println(io, " * ", Base.format_bytes(n_bytes))
println(io)
end
return nothing
end

"""
cache_info([io=stdout])
Print information about the CGC cache.
"""
function cache_info(io::IO=stdout)
ram_cache_info(io)
disk_cache_info(io)
return nothing
end
14 changes: 11 additions & 3 deletions src/clebschgordan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,22 @@ end

CGC(s1::I, s2::I, s3::I) where {I<:SUNIrrep} = CGC(Float64, s1, s2, s3)
function CGC(::Type{T}, s1::SUNIrrep{N}, s2::SUNIrrep{N}, s3::SUNIrrep{N}) where {T,N}
cache = get!(() -> CGCCache{N,T}(), CGC_CACHES, (N, T))::CGCCache{N,T}
return get!(cache, (s1, s2, s3))
cache = get!(() -> CGCCache{N,T}(; maxsize=100_000), CGC_CACHES, (N, T))::CGCCache{N,T}
return get!(cache, (s1, s2, s3)) do
# if the key is not in the cache, check if it is in a file
result = tryread(T, s1, s2, s3)
isnothing(result) || return result

# if not, compute it
CGCs = generate_all_CGCs(T, s1, s2)
return CGCs[_key(s3)]
end
end

function _CGC(T::Type{<:Real}, s1::I, s2::I, s3::I) where {I<:SUNIrrep}
CGC = highest_weight_CGC(T, s1, s2, s3)
lower_weight_CGC!(CGC, s1, s2, s3)
@debug "Computed CGC: $(s1.I)$(s2.I)$(s3.I)"
@debug "Computed CGC: $s1$s2$s3"
return CGC
end

Expand Down
19 changes: 8 additions & 11 deletions src/naming.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,13 @@ function index(s::SUNIrrep)
return numerator(id)
end

function irreps_by_dim(::Type{SUNIrrep{N}}, d::Int, maxdynkin::Int=3) where {N}
irreps = SUNIrrep{N}[]

all_dynkin = CartesianIndices(ntuple(k -> maxdynkin + 1, N - 1))
for a in all_dynkin
I = SUNIrrep(collect(a.I .- 1))
dim(I) == d && push!(irreps, I)
end
# @show index.(irreps) congruency.(irreps) dynkin_label.(irreps)
function all_dynkin(::Type{SUNIrrep{N}}, maxdynkin::Int=3) where {N}
return (SUNIrrep(collect(I.I .- 1))
for I in CartesianIndices(ntuple(k -> maxdynkin + 1, N - 1)))
end

function irreps_by_dim(::Type{SUNIrrep{N}}, d::Int, maxdynkin::Int=3) where {N}
irreps = [I for I in all_dynkin(SUNIrrep{N}, maxdynkin) if dim(I) == d]
return sort!(irreps; by=x -> (index(x), congruency(x), dynkin_label(x)))
end

Expand All @@ -120,7 +117,7 @@ function find_dimname(s::SUNIrrep{N}) where {N}
else
error("this should never happen")
end

return d, numprimes, conjugate
end

Expand Down Expand Up @@ -150,7 +147,7 @@ function dimname(s::SUNIrrep{N}) where {N}
# for some reason in SU{3}, the 6-dimensional irreps have switched duality
s == SUNIrrep(2, 0, 0) && return generate_dimname(6, 0, false)
s == SUNIrrep(2, 2, 0) && return generate_dimname(6, 0, true)

d, numprimes, conjugate = find_dimname(s)
return generate_dimname(d, numprimes, conjugate)
end
Expand Down
2 changes: 1 addition & 1 deletion test/caching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ println("Caching tests")
println("------------------------------------")

# Tests for caching of Clebsch-Gordan coefficients
import SUNRepresentations: cache_info, precompute_disk_cache, cache_path, clear_disk_cache!
import SUNRepresentations: cache_info, precompute_disk_cache, clear_disk_cache!

# only remove cache if running on CI
if get(ENV, "CI", false) == "true"
Expand Down

0 comments on commit 5df2dcd

Please sign in to comment.