various updates for efficiency; restore recursive sumlist

This commit is contained in:
Jutho Haegeman 2021-06-14 17:00:19 +02:00
parent 231078a159
commit 04098cb2c2
6 changed files with 107 additions and 50 deletions

View file

@ -10,7 +10,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
version: version:
- '1.0' - '1.4'
- '1' - '1'
os: os:
- ubuntu-latest - ubuntu-latest

View file

@ -10,11 +10,11 @@ Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
[compat] [compat]
RationalRoots = "0.1, 0.2" RationalRoots = "0.1 - 1"
HalfIntegers = "1" HalfIntegers = "1"
Primes = "0.4, 0.5" Primes = "0.4 - 1"
LRUCache = "1.3" LRUCache = "1.3"
julia = "1" julia = "1.4"
[extras] [extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

View file

@ -1,4 +1,3 @@
__precompile__(true)
module WignerSymbols module WignerSymbols
export δ, Δ, clebschgordan, wigner3j, wigner6j, racahV, racahW, HalfInteger export δ, Δ, clebschgordan, wigner3j, wigner6j, racahV, racahW, HalfInteger
@ -9,18 +8,23 @@ const RRBig = RationalRoot{BigInt}
import RationalRoots: _convert import RationalRoots: _convert
include("growinglist.jl") include("growinglist.jl")
include("bigint.jl") # additional GMP BigInt functionality not wrapped in Base.GMP.MPZ
include("primefactorization.jl") include("primefactorization.jl")
convert(BigInt, primefactorial(401)) # trigger compilation and generate some fixed data convert(BigInt, primefactorial(401)) # trigger compilation and generate some fixed data
const Key3j = Tuple{UInt,UInt,UInt,Int,Int} const Key3j = Tuple{UInt,UInt,UInt,Int,Int}
const Key6j = NTuple{6,UInt} const Key6j = NTuple{6,UInt}
# const Wigner3j = Dict{Key3j,Tuple{Rational{BigInt},Rational{BigInt}}}()
# const Wigner6j = Dict{Key6j,Tuple{Rational{BigInt},Rational{BigInt}}}()
#
const Wigner3j = LRU{Key3j,Tuple{Rational{BigInt},Rational{BigInt}}}(; maxsize = 10^6) const Wigner3j = LRU{Key3j,Tuple{Rational{BigInt},Rational{BigInt}}}(; maxsize = 10^6)
const Wigner6j = LRU{Key6j,Tuple{Rational{BigInt},Rational{BigInt}}}(; maxsize = 10^6) const Wigner6j = LRU{Key6j,Tuple{Rational{BigInt},Rational{BigInt}}}(; maxsize = 10^6)
function set_buffer3j_size!(; maxsize)
resize!(Wigner3j; maxsize = maxsize)
end
function set_buffer6j_size!(; maxsize)
resize!(Wigner6j; maxsize = maxsize)
end
# check integerness and correctness of (j,m) angular momentum # check integerness and correctness of (j,m) angular momentum
ϵ(j, m) = (abs(m) <= j && ishalfinteger(j) && isinteger(j-m) && isinteger(j+m)) ϵ(j, m) = (abs(m) <= j && ishalfinteger(j) && isinteger(j-m) && isinteger(j+m))
@ -361,4 +365,19 @@ function compute6jseries(β₁, β₂, β₃, α₁, α₂, α₃, α₄)
return Base.unsafe_rational(totalnum, totalden) return Base.unsafe_rational(totalnum, totalden)
end end
function _precompile_()
@assert precompile(prime, (Int,))
@assert precompile(primefactor, (Int,))
@assert precompile(primefactorial, (Int,))
@assert precompile(wigner3j, (Type{Float64}, Int, Int, Int, Int, Int, Int))
@assert precompile(wigner6j, (Type{Float64}, Int, Int, Int, Int, Int, Int))
@assert precompile(wigner3j, (Type{BigFloat}, HalfInt, HalfInt, HalfInt, HalfInt, HalfInt, HalfInt))
@assert precompile(wigner6j, (Type{BigFloat}, HalfInt, HalfInt, HalfInt, HalfInt, HalfInt, HalfInt))
@assert precompile(getindex, (GrowingList{Int}, Int))
@assert precompile(getindex, (GrowingList{BigInt}, Int))
@assert precompile(get!, (GrowingList{Int}, Int, Int))
@assert precompile(get!, (GrowingList{BigInt}, Int, BigInt))
end
_precompile_()
end # module end # module

19
src/bigint.jl Normal file
View file

@ -0,0 +1,19 @@
# additional bigint functionality
using Base.GMP.MPZ
using Base.GMP.MPZ: gmpz, mpz_t
divexact!(x::BigInt, a::BigInt, b::BigInt) =
(ccall((:__gmpz_divexact, :libgmp), Cvoid, (mpz_t, mpz_t, mpz_t), x, a, b); x)
divexact(a::BigInt, b::BigInt) = divexact!(BigInt(), a, b)
divexact!(a::BigInt, b::BigInt) = divexact!(a, a, b)
const TMP_BIG = BigInt(0)
function mul!(x::Rational{BigInt}, a::Rational{BigInt}, b::Rational{BigInt})
MPZ.mul!(x.num, a.num, b.num)
MPZ.mul!(x.den, a.den, b.den)
g = MPZ.gcd!(TMP_BIG, x.num, x.den)
divexact!(x.num, g)
divexact!(x.den, g)
return x
end

View file

@ -66,7 +66,7 @@ The list is grown by adding new segments using a linked list data structure. Thi
""" """
mutable struct GrowingList{T} <: AbstractVector{T} mutable struct GrowingList{T} <: AbstractVector{T}
first::ListSegment{T} first::ListSegment{T}
totallength::Atomic{Int} totallength::Int
growthfactor::Float64 growthfactor::Float64
lock::SpinLock lock::SpinLock
function GrowingList{T}(iter; function GrowingList{T}(iter;
@ -88,7 +88,7 @@ mutable struct GrowingList{T} <: AbstractVector{T}
_unsafe_getindex(first, i, val, ceil(Int, (i-1)*growthfactor)) _unsafe_getindex(first, i, val, ceil(Int, (i-1)*growthfactor))
next = iterate(iter, state) next = iterate(iter, state)
end end
return new{T}(first, Atomic{Int}(i), growthfactor, SpinLock()) return new{T}(first, i, growthfactor, SpinLock())
end end
end end
GrowingList(v::Vector{T}; sizehint = max(16, length(v)), growthfactor = 2.) where {T} = GrowingList(v::Vector{T}; sizehint = max(16, length(v)), growthfactor = 2.) where {T} =
@ -100,7 +100,11 @@ GrowingList{T}(; sizehint = 16, growthfactor = 2.) where {T} =
GrowingList(; sizehint = 16, growthfactor = 2.) = GrowingList(; sizehint = 16, growthfactor = 2.) =
GrowingList{Any}((); sizehint = sizehint, growthfactor = growthfactor) GrowingList{Any}((); sizehint = sizehint, growthfactor = growthfactor)
Base.length(l::GrowingList) = l.totallength[] Base.length(l::GrowingList) = l.totallength
function _raise_length!(l::GrowingList)
l.totallength += 1
end
Base.size(l::GrowingList) = (length(l),) Base.size(l::GrowingList) = (length(l),)
@inline function Base.getindex(l::GrowingList, n::Int) @inline function Base.getindex(l::GrowingList, n::Int)
@ -109,46 +113,51 @@ Base.size(l::GrowingList) = (length(l),)
end end
function Base.get!(l::GrowingList, n::Int, default) function Base.get!(l::GrowingList, n::Int, default)
if n <= l.totallength[] if n <= length(l)
return _unsafe_getindex(l.first, n) return _unsafe_getindex(l.first, n)
else else
lock(l.lock) ll = l.lock
lock(ll)
len = length(l) len = length(l)
nextlen = len + 1
if n <= len # try again, maybe already ok now if n <= len # try again, maybe already ok now
unlock(l.lock) unlock(ll)
return _unsafe_getindex(l.first, n) return _unsafe_getindex(l.first, n)
elseif n == len+1 elseif n == nextlen
_unsafe_get!(l.first, n, default, ceil(Int, (l.growthfactor-1)*len)) _unsafe_get!(l.first, n, default, ceil(Int, (l.growthfactor-1)*len))
Base.Threads.atomic_add!(l.totallength, 1) _raise_length!(l)
unlock(l.lock) unlock(ll)
return default return default
else else
@show Base.Threads.threadid(), l.totallength[], n unlock(ll)
unlock(l.lock) _inserterror(nextlen)
throw(ArgumentError("can only insert new element at next index: $(len+1)"))
end end
end end
end end
@noinline _inserterror(len::Int) =
throw(ArgumentError("can only insert new element at next index: " * string(len)))
function Base.get!(default::Base.Callable, l::GrowingList, n::Int) function Base.get!(default::Base.Callable, l::GrowingList, n::Int)
if n <= l.totallength[] if n <= l.totallength[]
return _unsafe_getindex(l.first, n) return _unsafe_getindex(l.first, n)
else else
v = default() v = default()
lock(l.lock) ll = l.lock
len = l.totallength[] lock(ll)
len = length(l)
nextlen = len + 1
if n <= len # try again, maybe already ok now if n <= len # try again, maybe already ok now
unlock(l.lock) unlock(ll)
return _unsafe_getindex(l.first, n) return _unsafe_getindex(l.first, n)
elseif n == len+1 elseif n == nextlen
_unsafe_get!(l.first, n, v, ceil(Int, (l.growthfactor-1)*len)) _unsafe_get!(l.first, n, v, ceil(Int, (l.growthfactor-1)*len))
Base.Threads.atomic_add!(l.totallength, 1) _raise_length!(l)
unlock(l.lock) unlock(ll)
return v return v
else else
@show Base.Threads.threadid(), l.totallength[], n unlock(ll)
unlock(l.lock) _inserterror(nextlen)
throw(ArgumentError("can only insert new element at next index: $(len+1)"))
end end
end end
end end

View file

@ -1,14 +1,12 @@
using Primes: isprime using Primes: isprime
import Base.divgcd import Base.divgcd
using Base.GMP.MPZ const primetable = GrowingList([2, 3]; sizehint = 1024)
const factortable = GrowingList([UInt8[], UInt8[1], UInt8[0,1]]; sizehint = 4096)
const primetable = GrowingList([2, 3]; sizehint = 256) const factorialtable = GrowingList([UInt32[], UInt32[1], UInt32[1,1]]; sizehint = 4096)
const factortable = GrowingList([UInt8[], UInt8[1], UInt8[0,1]]; sizehint = 1024) const bigprimetable = GrowingList([GrowingList([big(2)]; sizehint = 2048),
const factorialtable = GrowingList([UInt32[], UInt32[1], UInt32[1,1]]; sizehint = 1024) GrowingList([big(3)]; sizehint = 1024)];
const bigprimetable = GrowingList([GrowingList([big(2)]; sizehint = 512), sizehint = 1024)
GrowingList([big(3)]; sizehint = 256)];
sizehint = 256)
const bigone = big(1) const bigone = big(1)
# Make a prime iterator # Make a prime iterator
@ -23,8 +21,8 @@ Base.eltype(::PrimeIterator) = Int
# Get the `n`th prime; store all primes up to the `n`th if not yet available # Get the `n`th prime; store all primes up to the `n`th if not yet available
function prime(n::Int) function prime(n::Int)
k = min(length(primetable), length(bigprimetable)) k = min(length(primetable), length(bigprimetable))
p = primetable[k]
while k < n while k < n
@inbounds p = primetable[k]
p = p + 2 p = p + 2
while !isprime(p) while !isprime(p)
p += 2 p += 2
@ -32,12 +30,14 @@ function prime(n::Int)
k += 1 k += 1
# these lines do not get but set new elements; provided no other task did so earlier # these lines do not get but set new elements; provided no other task did so earlier
get!(primetable, k, p) get!(primetable, k, p)
get!(bigprimetable, k, GrowingList([big(p)]; sizehint = 256)) bp = big(p)
bpf = GrowingList{BigInt}((big(p),); sizehint = 4)
get!(bigprimetable, k, bpf)
k = min(length(primetable), length(bigprimetable)) k = min(length(primetable), length(bigprimetable))
# other threads might have inserted additional entries, # other threads might have inserted additional entries,
# make sure they are finished with both primetable and bigprimetable # make sure they are finished with both primetable and bigprimetable
end end
return primetable[n] @inbounds return primetable[n]
end end
Base.iterate(::PrimeIterator, n = 1) = prime(n), n+1 Base.iterate(::PrimeIterator, n = 1) = prime(n), n+1
@ -46,13 +46,14 @@ Base.iterate(::PrimeIterator, n = 1) = prime(n), n+1
function bigprime(n::Integer, e::Integer=1) function bigprime(n::Integer, e::Integer=1)
e == 0 && return bigone e == 0 && return bigone
p = prime(n) # triggers computation of prime(n) if necessary p = prime(n) # triggers computation of prime(n) if necessary
powerlist = bigprimetable[n] @inbounds powerlist = bigprimetable[n]
l = length(powerlist) l = length(powerlist)
@inbounds while l < e @inbounds while l < e
# compute next prime power as approximate square of existing results # compute next prime power as approximate square of existing results
l += 1 l += 1
k = l>>1 k = l>>1
get!(powerlist, l, powerlist[k]*powerlist[l-k]) newpower = powerlist[k]*powerlist[l-k]
get!(powerlist, l, newpower)
l = length(powerlist) # other threads might have inserted more powers l = length(powerlist) # other threads might have inserted more powers
end end
@inbounds return powerlist[e] @inbounds return powerlist[e]
@ -85,8 +86,7 @@ function primefactor(n::Integer)
m = length(factortable) m = length(factortable)
while m < abs(n) while m < abs(n)
m += 1 m += 1
powers = UInt8[] powers = UInt8[] # should be sufficient for all integers up to 2^255
# should be sufficient for all integers up to 2^255
a = m a = m
for p in primes() for p in primes()
f = 0 f = 0
@ -113,9 +113,12 @@ function primefactorial(n::Integer)
prevfactorial = factorialtable[m] prevfactorial = factorialtable[m]
m += 1 m += 1
f = primefactor(m).powers f = primefactor(m).powers
if length(f) > length(prevfactorial) # can at most be 1 larger
powers = similar(prevfactorial, length(f))
powers[1:end-1] = prevfactorial
powers[end] = 0
else
powers = copy(prevfactorial) powers = copy(prevfactorial)
if length(f) > length(powers) # can at most be 1 larger
push!(powers, 0)
end end
for k = 1:length(f) for k = 1:length(f)
powers[k] += f[k] powers[k] += f[k]
@ -367,10 +370,17 @@ function sumlist!(list::Vector{<:PrimeFactorization}, ind = 1:length(list))
for k in ind for k in ind
divexact!(list[k], g) divexact!(list[k], g)
end end
s = big(0) L = length(ind)
i = big(1) i = big(1)
for p in list if L > 32
MPZ.add!(s, _convert!(i, p)) l = L >> 1
s = sumlist!(list, first(ind).+(0:l-1))
s = MPZ.add!(s, sumlist!(list, first(ind).+(l:L-1)))
else # do sum, add to s
s = big(0)
for k in ind
MPZ.add!(s, _convert!(i, list[k]))
end
end end
return MPZ.mul!(s, _convert!(i, g)) return MPZ.mul!(s, _convert!(i, g))
end end