major update, thread safety, improved efficiency

This commit is contained in:
Jutho Haegeman 2021-06-11 15:50:25 +02:00
parent f06635b64b
commit b1303b9b79
10 changed files with 693 additions and 350 deletions

View file

@ -1,15 +1,15 @@
using Primes: isprime
import Base.divgcd
const primetable =
[2,3,5]
const factortable =
[UInt8[], UInt8[1], UInt8[0,1], UInt8[2], UInt8[0,0,1]]
const factorialtable =
[UInt32[], UInt32[], UInt32[1], UInt32[1,1], UInt32[3,1], UInt32[3,1,1]]
const bigprimetable =
[[big(2)], [big(3)], [big(5)]]
const bigone = Ref{BigInt}(big(1))
using Base.GMP.MPZ
const primetable = GrowingList([2, 3]; sizehint = 256)
const factortable = GrowingList([UInt8[], UInt8[1], UInt8[0,1]]; sizehint = 1024)
const factorialtable = GrowingList([UInt32[], UInt32[1], UInt32[1,1]]; sizehint = 1024)
const bigprimetable = GrowingList([GrowingList([big(2)]; sizehint = 512),
GrowingList([big(3)]; sizehint = 256)];
sizehint = 256)
const bigone = big(1)
# Make a prime iterator
struct PrimeIterator
@ -22,51 +22,71 @@ Base.eltype(::PrimeIterator) = Int
# Get the `n`th prime; store all primes up to the `n`th if not yet available
function prime(n::Int)
p = last(primetable)
while length(primetable) < n
k = min(length(primetable), length(bigprimetable))
p = primetable[k]
while k < n
p = p + 2
while !isprime(p)
p += 2
end
push!(primetable, p)
push!(bigprimetable, [big(p)])
k += 1
# these lines do not get but set new elements; provided no other task did so earlier
get!(primetable, k, p)
get!(bigprimetable, k, GrowingList([big(p)]; sizehint = 256))
k = min(length(primetable), length(bigprimetable))
# other threads might have inserted additional entries,
# make sure they are finished with both primetable and bigprimetable
end
@inbounds return primetable[n]
return primetable[n]
end
Base.iterate(::PrimeIterator, n = 1) = prime(n), n+1
# get primes and their powers as `BigInt`, also cache all results
function bigprime(n::Integer, e::Integer=1)
e == 0 && return bigone[]
e == 0 && return bigone
p = prime(n) # triggers computation of prime(n) if necessary
@inbounds l = length(bigprimetable[n])
powerlist = bigprimetable[n]
l = length(powerlist)
@inbounds while l < e
# compute next prime power as approximate square of existing results
k = (l+1)>>1
push!(bigprimetable[n], bigprimetable[n][k]*bigprimetable[n][l+1-k])
l += 1
k = l>>1
get!(powerlist, l, powerlist[k]*powerlist[l-k])
l = length(powerlist) # other threads might have inserted more powers
end
@inbounds return bigprimetable[n][e]
@inbounds return powerlist[e]
end
# A custom `Integer` subtype to store an integer as its prime factorization
struct PrimeFactorization{U<:Unsigned} <: Integer
# Integer represented by the exponents of its prime factors plus a sign.
# `powers[k]` is the exponent stored for the k-th prime; `sign` is kept as an `Int8`.
# The type is mutable so that `sign` can be reassigned in place.
mutable struct PrimeFactorization{U<:Unsigned} <: Integer
    powers::Vector{U}
    sign::Int8

    # Inner constructor: accepts any `Vector` and converts it to `Vector{U}`;
    # the sign defaults to `one(Int8)`.
    function PrimeFactorization{U}(powers::Vector, sign = one(Int8)) where {U<:Unsigned}
        return new{U}(convert(Vector{U}, powers), sign)
    end
end
# convenience constructor: normalizes powers to have last entry nonzero
PrimeFactorization(powers::Vector{U}, sign = one(Int8)) where {U<:Unsigned} =
PrimeFactorization{U}(_normalize_powers!(powers), sign)
# Strip trailing zero entries from `v` in place, so that the last entry (if any) is
# nonzero. A vector containing only zeros becomes empty. Returns the same vector `v`.
function _normalize_powers!(v::Vector{<:Integer})
    # index of the last nonzero entry, or 0 when there is none
    keep = something(findlast(!iszero, v), 0)
    keep == length(v) || resize!(v, keep)
    return v
end
PrimeFactorization(powers::Vector{U}) where {U<:Unsigned} =
PrimeFactorization{U}(powers, one(Int8))
# define our own factor function, returning an instance of PrimeFactorization
function primefactor(n::Integer)
iszero(n) && return PrimeFactorization(UInt8[], zero(Int8))
iszero(n) && return PrimeFactorization{UInt8}(UInt8[], zero(Int8))
sn = n < 0 ? -one(Int8) : one(Int8)
n = abs(n)
m = length(factortable)
while m < abs(n)
m += 1
powers = UInt8[] # should be sufficient for all integers up to 2^255
powers = UInt8[]
# should be sufficient for all integers up to 2^255
a = m
for p in primes()
f = 0
@ -79,16 +99,18 @@ function primefactor(n::Integer)
push!(powers, f)
a == 1 && break
end
push!(factortable, powers)
get!(factortable, m, powers)
m = length(factortable) # other threads may have inserted other entries
end
@inbounds return PrimeFactorization(copy(factortable[n]), sn)
@inbounds return PrimeFactorization{UInt8}(factortable[n], sn)
end
function primefactorial(n::Integer)
n < 0 && throw(DomainError(n))
m = length(factorialtable)-1
n < 0 && throw(DomainError(n,"primefactorial only works for non-negative numbers"))
n <= 1 && return PrimeFactorization{UInt32}(UInt32[], one(Int8))
m = length(factorialtable)
@inbounds while m < n
prevfactorial = factorialtable[m+1]
prevfactorial = factorialtable[m]
m += 1
f = primefactor(m).powers
powers = copy(prevfactorial)
@ -98,36 +120,48 @@ function primefactorial(n::Integer)
for k = 1:length(f)
powers[k] += f[k]
end
push!(factorialtable, powers)
get!(factorialtable, m, powers)
m = length(factorialtable) # other threads may have inserted other entries
end
@inbounds return PrimeFactorization(copy(factorialtable[n+1]))
@inbounds return PrimeFactorization{UInt32}(factorialtable[n])
end
# Methods for PrimeFactorization:
Base.copy(a::PrimeFactorization) = PrimeFactorization(copy(a.powers), a.sign)
# Overwrite `c` in place with the contents of `a`: first the sign, then the list of
# powers (the `powers` vector of `c` is reused, not replaced). Returns `c`.
Base.copy!(c::PrimeFactorization, a::PrimeFactorization) =
    (c.sign = a.sign; copy!(c.powers, a.powers); return c)
Base.one(::Type{PrimeFactorization{U}}) where {U<:Unsigned} =
PrimeFactorization(Vector{U}())
PrimeFactorization{U}(Vector{U}(), one(Int8))
Base.zero(::Type{PrimeFactorization{U}}) where {U<:Unsigned} =
PrimeFactorization(Vector{U}(), zero(Int8))
PrimeFactorization{U}(Vector{U}(), zero(Int8))
# Reset `c` in place to a positive sign and an empty list of powers. Returns `c`.
function one!(c::PrimeFactorization)
    c.sign = one(Int8)
    empty!(c.powers)
    return c
end
# Reset `c` in place to a zero sign and an empty list of powers. Returns `c`.
function zero!(c::PrimeFactorization)
    c.sign = zero(Int8)
    empty!(c.powers)
    return c
end
Base.promote_rule(P::Type{<:PrimeFactorization},::Type{<:Integer}) = P
Base.promote_rule(P::Type{<:PrimeFactorization},::Type{BigInt}) = BigInt
Base.promote_rule(::Type{<:PrimeFactorization},::Type{BigInt}) = BigInt
Base.promote_rule(::Type{PrimeFactorization{U1}},
::Type{PrimeFactorization{U2}}) where {U1<:Unsigned, U2<:Unsigned} = PrimeFactorization{promote_type(U1, U2)}
Base.convert(P::Type{<:PrimeFactorization}, n::Integer) = convert(P, primefactor(n))
function Base.convert(::Type{BigInt}, a::PrimeFactorization)
A = one(BigInt)
function _convert!(x::BigInt, a::PrimeFactorization)
MPZ.set!(x, bigone)
for (n, e) in enumerate(a.powers)
if !iszero(e)
MPZ.mul!(A, bigprime(n, e))
MPZ.mul!(x, bigprime(n, e))
end
end
return a.sign < 0 ? MPZ.neg!(A) : A
return a.sign < 0 ? MPZ.neg!(x) : x
end
Base.convert(::Type{BigInt}, a::PrimeFactorization) = _convert!(one(BigInt), a)
Base.convert(::Type{PrimeFactorization{U}}, a::PrimeFactorization{U}) where {U<:Unsigned} =
a
Base.convert(::Type{PrimeFactorization{U}}, a::PrimeFactorization) where {U<:Unsigned} =
PrimeFactorization(convert(Vector{U}, a.powers), a.sign)
PrimeFactorization{U}(convert(Vector{U}, a.powers), a.sign)
Base.:(==)(a::PrimeFactorization, b::PrimeFactorization) =
a.powers == b.powers && a.sign == b.sign
@ -138,7 +172,8 @@ function Base.:<(a::PrimeFactorization, b::PrimeFactorization)
return <(-b, -a)
else
ag, bg = divgcd(a, b)
if length(ag.powers) <= length(bg.powers) &&
ag == bg && return false
if length(ag.powers) <= length(bg.powers)
all(k->ag.powers[k]<bg.powers[k], 1:length(ag.powers))
return true
else
@ -151,35 +186,107 @@ end
# Addition and subtraction will require conversion to BigInt
Base.sign(a::PrimeFactorization) = a.sign
Base.:-(a::PrimeFactorization) = PrimeFactorization(a.powers, -a.sign)
function Base.:*(a::PrimeFactorization{T}, b::PrimeFactorization{T}) where {T}
if a.sign == 0
return a
elseif b.sign ==0
return b
neg!(a::PrimeFactorization) = (a.sign = -a.sign; return a)
function mul!(c::PrimeFactorization, a::PrimeFactorization, b::PrimeFactorization)
if a.sign == 0 || b.sign == 0
zero!(c)
else
return PrimeFactorization(_vadd!(copy(a.powers), b.powers), a.sign*b.sign)
c.sign = a.sign * b.sign
la = length(a.powers)
lb = length(b.powers)
lc = max(la, lb)
lc === length(c.powers) || resize!(c.powers, lc)
@inbounds for k = 1:min(la,lb)
c.powers[k] = +(a.powers[k], b.powers[k])
end
if c !== a
@inbounds for k = lb+1:la
c.powers[k] = a.powers[k]
end
end
@inbounds for k = la+1:lb
c.powers[k] = b.powers[k]
end
end
return c
end
function Base.gcd(a::PrimeFactorization{T}, b::PrimeFactorization{T}) where {T}
if a.sign == 0
return b
elseif b.sign ==0
return a
# unlike div, this one errors if the a is not divisible by b
function divexact!(c::PrimeFactorization, a::PrimeFactorization, b::PrimeFactorization)
if iszero(a.sign)
zero!(c)
elseif iszero(b.sign)
throw(DivideError())
else
return PrimeFactorization(_vmin!(copy(a.powers), b.powers))
c.sign = a.sign * b.sign
la = length(a.powers)
lb = length(b.powers)
if lb > la
throw(DivideError())
end
lc = la
if lb == lc
while lc > 0 && a.powers[lc] == b.powers[lc]
lc -= 1
end
end
lc == length(c.powers) || resize!(c.powers, lc)
@inbounds for k = 1:min(lb, lc)
if b.powers[k] > a.powers[k]
throw(DivideError())
end
c.powers[k] = a.powers[k] - b.powers[k]
end
if c !== a
@inbounds for k = lb+1:lc
c.powers[k] = a.powers[k]
end
end
end
return c
end
function Base.lcm(a::PrimeFactorization{T}, b::PrimeFactorization{T}) where {T}
function gcd!(c::PrimeFactorization, a::PrimeFactorization, b::PrimeFactorization)
if a.sign == 0
return a
copy!(c.powers, b.powers)
elseif b.sign ==0
return b
copy!(c.powers, a.powers)
else
return PrimeFactorization(_vmax!(copy(a.powers), b.powers))
c.sign = one(Int8)
la = length(a.powers)
lb = length(b.powers)
lc = min(la, lb)
lc === length(c.powers) || resize!(c.powers, lc)
@inbounds for k = 1:lc
c.powers[k] = min(a.powers[k], b.powers[k])
end
end
c.sign = one(Int8)
return c
end
# Store the least common multiple of `a` and `b` in `c` and return `c`.
# If either input has zero sign, `c` is reset with `zero!`. Otherwise `c` gets a positive
# sign and, for every prime, the larger of the two exponents.
# `c` may be the same object as `a`; in that case the tail of `a` is already in place.
function lcm!(c::PrimeFactorization, a::PrimeFactorization, b::PrimeFactorization)
    (a.sign == 0 || b.sign == 0) && return zero!(c)

    c.sign = one(Int8)
    apow, bpow, cpow = a.powers, b.powers, c.powers
    # lengths are read before `cpow` is resized, since `cpow` may alias `apow`
    na = length(apow)
    nb = length(bpow)
    nc = max(na, nb)
    nc === length(cpow) || resize!(cpow, nc)

    # primes present in both factorizations: take the larger exponent
    @inbounds for k = 1:min(na, nb)
        cpow[k] = max(apow[k], bpow[k])
    end
    # primes only present in `a` (skipped when `c` is `a` itself)
    if c !== a
        @inbounds for k = nb+1:na
            cpow[k] = apow[k]
        end
    end
    # primes only present in `b`
    @inbounds for k = na+1:nb
        cpow[k] = bpow[k]
    end
    return c
end
Base.divgcd(a::PrimeFactorization, b::PrimeFactorization) = divgcd!(copy(a), copy(b))
function divgcd!(a::PrimeFactorization, b::PrimeFactorization)
af, bf = a.powers, b.powers
for k = 1:min(length(af), length(bf))
@ -187,26 +294,50 @@ function divgcd!(a::PrimeFactorization, b::PrimeFactorization)
af[k] -= gk
bf[k] -= gk
end
while length(af) > 0 && iszero(last(af))
pop!(af)
end
while length(bf) > 0 && iszero(last(bf))
pop!(bf)
end
_normalize_powers!(a.powers)
_normalize_powers!(b.powers)
return a, b
end
mul!(a::PrimeFactorization, b::PrimeFactorization) = mul!(a, a, b)
divexact!(a::PrimeFactorization, b::PrimeFactorization) = divexact!(a, a, b)
gcd!(a::PrimeFactorization, b::PrimeFactorization) = gcd!(a, a, b)
lcm!(a::PrimeFactorization, b::PrimeFactorization) = lcm!(a, a, b)
Base.:-(a::PrimeFactorization) = neg!(copy(a))
# Product of two prime factorizations as a new object of the promoted type; neither
# argument is modified. The operand with the longer power list is copied (or converted
# to the promoted type) and the other one is multiplied into it with `mul!`.
function Base.:*(a::PrimeFactorization, b::PrimeFactorization)
    P = promote_type(typeof(a), typeof(b))
    if length(a.powers) < length(b.powers)
        acc = typeof(b) == P ? copy(b) : convert(P, b)
        return mul!(acc, a)
    end
    acc = typeof(a) == P ? copy(a) : convert(P, a)
    return mul!(acc, b)
end
# Least common multiple of two prime factorizations as a new object of the promoted
# type; neither argument is modified. The operand with the longer power list is copied
# (or converted to the promoted type) and then updated in place with `lcm!`.
function Base.lcm(a::PrimeFactorization, b::PrimeFactorization)
    P = promote_type(typeof(a), typeof(b))
    if length(a.powers) < length(b.powers)
        acc = typeof(b) == P ? copy(b) : convert(P, b)
        return lcm!(acc, a)
    end
    acc = typeof(a) == P ? copy(a) : convert(P, a)
    return lcm!(acc, b)
end
# Greatest common divisor of two prime factorizations as a new object of the promoted
# type; neither argument is modified.
# The operand with the shorter power list is copied (or converted to the promoted type)
# and then updated in place with `gcd!`. This must be `gcd!`, not `lcm!`: calling `lcm!`
# here would return the least common multiple instead.
function Base.gcd(a::PrimeFactorization, b::PrimeFactorization)
    P = promote_type(typeof(a), typeof(b))
    if length(a.powers) <= length(b.powers)
        g = typeof(a) == P ? gcd!(copy(a), b) : gcd!(convert(P, a), b)
    else
        g = typeof(b) == P ? gcd!(copy(b), a) : gcd!(convert(P, b), a)
    end
    # `gcd!` takes elementwise minima without trimming, so trailing zero powers can be
    # left behind (e.g. for coprime inputs); strip them so that `==` compares correctly.
    _normalize_powers!(g.powers)
    return g
end
Base.divgcd(a::PrimeFactorization, b::PrimeFactorization) = divgcd!(copy(a), copy(b))
# no promotion necessary, should be smaller than a
divexact(a::PrimeFactorization, b::PrimeFactorization) = divexact!(copy(a), b)
# split `a::PrimeFactorization` into a square `s` and a remainder `r`, such that
# `a = s^2 * r` and the powers in the prime factorization of `r` are zero or one
function splitsquare(a::PrimeFactorization)
r = PrimeFactorization(map(p->convert(UInt8, isodd(p)), a.powers), a.sign)
while length(r.powers) > 0 && iszero(last(r.powers))
pop!(r.powers)
end
s = PrimeFactorization(map(p->(p>>1), a.powers))
while length(s.powers) > 0 && iszero(last(s.powers))
pop!(s.powers)
end
return s, r
end
@ -215,13 +346,13 @@ end
function commondenominator!(nums::Vector{P}, dens::Vector{P}) where {P<:PrimeFactorization}
isempty(nums) && return one(P)
# accumulate lcm of denominator
den = PrimeFactorization(copy(dens[1].powers))
den = copy(dens[1])
for i = 2:length(dens)
_vmax!(den.powers, dens[i].powers)
lcm!(den, dens[i])
end
# rescale numerators
for i = 1:length(nums)
_vsub!(_vadd!(nums[i].powers, den.powers), dens[i].powers)
divexact!(mul!(nums[i], den), dens[i])
end
return den
end
@ -229,68 +360,17 @@ end
# auxiliary function to compute sums of a list of PrimeFactorizations as quickly as possible
function sumlist!(list::Vector{<:PrimeFactorization}, ind = 1:length(list))
# first compute gcd to take out common factors
g = PrimeFactorization(copy(list[ind[1]].powers))
g = copy(list[ind[1]])
for k in ind
_vmin!(g.powers, list[k].powers)
gcd!(g, list[k])
end
for k in ind
_vsub!(list[k].powers, g.powers)
divexact!(list[k], g)
end
L = length(ind)
if L > 32
l = L >> 1
s = sumlist!(list, first(ind).+(0:l-1)) + sumlist!(list, first(ind).+(l:L-1))
else
# do sum
s = big(0)
for k in ind
MPZ.add!(s, convert(BigInt, list[k]))
end
s = big(0)
i = big(1)
for p in list
MPZ.add!(s, _convert!(i, p))
end
return MPZ.mul!(s, convert(BigInt, g))
end
# Mutating vector methods that also grow and shrink as required
function _vmin!(af::Vector{U}, bf::Vector{U}) where {U<:Unsigned}
while length(af) > length(bf)
pop!(af)
end
@inbounds for k = 1:length(af)
af[k] = min(af[k], bf[k])
end
while length(af) > 0 && iszero(last(af))
pop!(af)
end
return af
end
function _vmax!(af::Vector{U}, bf::Vector{U}) where {U<:Unsigned}
while length(bf) > length(af)
push!(af, zero(U))
end
@inbounds for k = 1:length(bf)
af[k] = max(af[k], bf[k])
end
return af
end
function _vadd!(af::Vector{U}, bf::Vector{U}) where {U<:Unsigned}
while length(bf) > length(af)
push!(af, zero(U))
end
@inbounds for k = 1:length(bf)
af[k] = +(af[k], bf[k])
end
return af
end
function _vsub!(af::Vector{U}, bf::Vector{U}) where {U<:Unsigned}
if length(bf) > length(af)
throw(OverflowError())
end
@inbounds for k = 1:length(bf)
bf[k] > af[k] && throw(OverflowError())
af[k] -= bf[k]
end
while length(af) > 0 && iszero(last(af))
pop!(af)
end
return af
return MPZ.mul!(s, _convert!(i, g))
end