Skip to content

Commit 85a2ffb

Browse files
authored
Narrow Zero arithmetic methods to reduce invalidations (#348)
The methods +(x::Any, ::Zero), *(::Zero, ::Any), etc. used untyped arguments, causing method invalidations by superseding the fundamental +(x, y) and *(x, y) fallbacks in Base. Narrow these to Number, AbstractArray, and AbstractMutable since these cover the types that participate in MutableArithmetics rewrites. Downstream packages with custom types can define their own +(::MyType, ::Zero) methods (as MultivariatePolynomials already does).
1 parent b7d82d6 commit 85a2ffb

2 files changed

Lines changed: 28 additions & 7 deletions

File tree

src/dispatch.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,16 @@
1313

1414
abstract type AbstractMutable end
1515

16+
# Zero arithmetic methods for AbstractMutable types.
17+
# The main Zero arithmetic is defined in rewrite.jl with Number/AbstractArray;
18+
# these methods extend it to AbstractMutable.
19+
Base.:*(z::Zero, ::AbstractMutable) = z
20+
Base.:*(::AbstractMutable, z::Zero) = z
21+
Base.:+(::Zero, x::AbstractMutable) = copy_if_mutable(x)
22+
Base.:+(x::AbstractMutable, ::Zero) = copy_if_mutable(x)
23+
Base.:-(::Zero, x::AbstractMutable) = operate(-, x)
24+
Base.:-(x::AbstractMutable, ::Zero) = copy_if_mutable(x)
25+
1626
function Base.sum(
1727
a::AbstractArray{T};
1828
dims = :,

src/rewrite.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,20 +58,31 @@ broadcast!!(::Union{typeof(add_mul),typeof(+)}, ::Zero, x) = copy_if_mutable(x)
5858
broadcast!!(::typeof(add_mul), ::Zero, x, y) = x * y
5959

6060
# Needed in `@rewrite(1 .+ sum(1 for i in 1:0) * 1^2)`
61-
Base.:*(z::Zero, ::Any) = z
62-
Base.:*(::Any, z::Zero) = z
61+
# These methods are narrowed to `Number` and `AbstractArray` to avoid invalidating
62+
# the very broad `+(x, y)`, `*(x, y)` fallbacks in Base, which causes thousands of
63+
# method invalidations across the ecosystem. Downstream packages that define custom
64+
# types participating in MutableArithmetics rewrites should define their own
65+
# `+(::MyType, ::Zero)` etc. methods.
66+
Base.:*(z::Zero, ::Number) = z
67+
Base.:*(::Number, z::Zero) = z
68+
Base.:*(z::Zero, ::AbstractArray) = z
69+
Base.:*(::AbstractArray, z::Zero) = z
6370
Base.:*(z::Zero, ::Zero) = z
64-
Base.:+(::Zero, x::Any) = x
65-
Base.:+(x::Any, ::Zero) = x
71+
Base.:+(::Zero, x::Number) = x
72+
Base.:+(x::Number, ::Zero) = x
73+
Base.:+(::Zero, x::AbstractArray) = x
74+
Base.:+(x::AbstractArray, ::Zero) = x
6675
Base.:+(z::Zero, ::Zero) = z
67-
Base.:-(::Zero, x::Any) = -x
68-
Base.:-(x::Any, ::Zero) = x
76+
Base.:-(::Zero, x::Number) = -x
77+
Base.:-(x::Number, ::Zero) = x
78+
Base.:-(::Zero, x::AbstractArray) = -x
79+
Base.:-(x::AbstractArray, ::Zero) = x
6980
Base.:-(z::Zero, ::Zero) = z
7081
Base.:-(z::Zero) = z
7182
Base.:+(z::Zero) = z
7283
Base.:*(z::Zero) = z
7384

74-
function Base.:/(z::Zero, x::Any)
85+
function Base.:/(z::Zero, x::Number)
7586
if iszero(x)
7687
throw(DivideError())
7788
else

0 commit comments

Comments
 (0)