From afa08f230034534f81d10c761aa5c939fe6e6b5b Mon Sep 17 00:00:00 2001 From: Orjan Ameye Date: Sat, 24 Jan 2026 10:35:10 +0100 Subject: [PATCH 1/8] work on symbolics v7 --- Project.toml | 4 +- src/DifferentialEquation.jl | 2 +- src/HarmonicVariable.jl | 2 +- src/QuestBase.jl | 4 -- src/Symbolics/Symbolics_utils.jl | 100 ++++++++++++++++++++----------- src/Symbolics/drop_powers.jl | 9 ++- src/Symbolics/exponentials.jl | 36 +++++++---- src/Symbolics/fourier.jl | 68 +++++++++------------ test/symbolics.jl | 22 +++---- 9 files changed, 138 insertions(+), 109 deletions(-) diff --git a/Project.toml b/Project.toml index 2f1a691..3749d35 100644 --- a/Project.toml +++ b/Project.toml @@ -12,8 +12,8 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [compat] DocStringExtensions = "0.9.4" -SymbolicUtils = "3.25" -Symbolics = "6.34" +SymbolicUtils = "3.25, 4" +Symbolics = "6.34, 7" julia = "1.10" Random = "1.10" LinearAlgebra = "1.10" diff --git a/src/DifferentialEquation.jl b/src/DifferentialEquation.jl index fcda18c..cb97dab 100644 --- a/src/DifferentialEquation.jl +++ b/src/DifferentialEquation.jl @@ -118,7 +118,7 @@ $(TYPEDSIGNATURES) Return the independent dependent variables of `diff_eom`. """ function get_independent_variables(diff_eom::DifferentialEquation) - return Num.(flatten(unique([x.val.arguments for x in keys(diff_eom.equations)]))) + return Num.(flatten(unique([arguments(x.val) for x in keys(diff_eom.equations)]))) end """ diff --git a/src/HarmonicVariable.jl b/src/HarmonicVariable.jl index 0023311..c2f91b8 100644 --- a/src/HarmonicVariable.jl +++ b/src/HarmonicVariable.jl @@ -43,7 +43,7 @@ function _show_ansatz(var::HarmonicVariable) if isempty(var.type) return string(var.symbol) end - t = var.natural_variable.val.arguments + t = arguments(var.natural_variable.val) t = length(t) == 1 ? string(t[1]) : error("more than 1 independent variable") ω = string(var.ω) terms = Dict("u" => "*cos(" * ω * t * ")", "v" => "*sin(" * ω * t * ")", "a" => "") diff --git a/src/QuestBase.jl b/src/QuestBase.jl index 5159792..0c87448 100644 --- a/src/QuestBase.jl +++ b/src/QuestBase.jl @@ -7,7 +7,6 @@ using LinearAlgebra: LinearAlgebra using SymbolicUtils: SymbolicUtils, Postwalk, - Sym, BasicSymbolic, isterm, ispow, @@ -15,8 +14,6 @@ using SymbolicUtils: isdiv, ismul, add_with_div, - frac_maketerm, - @compactified, issym using Symbolics: @@ -27,7 +24,6 @@ using Symbolics: get_variables, Equation, Differential, - @variables, arguments, substitute, term, diff --git a/src/Symbolics/Symbolics_utils.jl b/src/Symbolics/Symbolics_utils.jl index aee9122..a7a9f86 100644 --- a/src/Symbolics/Symbolics_utils.jl +++ b/src/Symbolics/Symbolics_utils.jl @@ -10,33 +10,42 @@ end expand_all(x::Complex{Num}) = expand_all(x.re) + im * expand_all(x.im) function expand_fraction(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Add => _apply_termwise(expand_fraction, x) - Mul => _apply_termwise(expand_fraction, x) - Div => sum([arg / x.den for arg in arguments(x.num)]) - _ => x + if isadd(x) + return _apply_termwise(expand_fraction, x) + elseif ismul(x) + return _apply_termwise(expand_fraction, x) + elseif isdiv(x) + return sum([arg / x.den for arg in arguments(x.num)]) + else + return x end end expand_fraction(x::Num) = Num(expand_fraction(x.val)) "Apply a function f on every member of a sum or a product" function _apply_termwise(f, x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Add => sum([f(arg) for arg in arguments(x)]) - Mul => prod([f(arg) for arg in arguments(x)]) - Div => _apply_termwise(f, x.num) / _apply_termwise(f, x.den) - _ => f(x) + if isadd(x) + return sum([f(arg) for arg in arguments(x)]) + elseif ismul(x) + return prod([f(arg) for arg in arguments(x)]) + elseif isdiv(x) + return _apply_termwise(f, x.num) / _apply_termwise(f, x.den) + else + return f(x) end end simplify_complex(x::Complex) = isequal(x.im, 0) ? x.re : x.re + im * x.im simplify_complex(x) = x function simplify_complex(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Add => _apply_termwise(simplify_complex, x) - Mul => _apply_termwise(simplify_complex, x) - Div => _apply_termwise(simplify_complex, x) - _ => x + if isadd(x) + return _apply_termwise(simplify_complex, x) + elseif ismul(x) + return _apply_termwise(simplify_complex, x) + elseif isdiv(x) + return _apply_termwise(simplify_complex, x) + else + return x end end @@ -68,7 +77,10 @@ function substitute_all(x::Complex{Num}, rules::Collections) return substitute_all(x.re, rules) + im * substitute_all(x.im, rules) end -get_independent(x::Num, t::Num)::Num = wrap(get_independent(x.val, t)) +function get_independent(x::Num, t::Num) + result = get_independent(x.val, t) + return result isa Complex{Num} ? result : wrap(result) +end function get_independent(x::Complex{Num}, t::Num) return get_independent(x.re, t) + im * get_independent(x.im, t) end @@ -76,14 +88,21 @@ get_independent(v::Vector{Num}, t::Num) = [get_independent(el, t) for el in v] get_independent(x, t::Num) = x function get_independent(x::BasicSymbolic, t::Num) - @compactified x::BasicSymbolic begin - Add => sum([get_independent(arg, t) for arg in arguments(x)]) - Mul => prod([get_independent(arg, t) for arg in arguments(x)]) - Div => !is_function(x.den, t) ? get_independent(x.num, t) / x.den : 0 - Pow => !is_function(x.base, t) && !is_function(x.exp, t) ? x : 0 - Term => !is_function(x, t) ? x : 0 - Sym => !is_function(x, t) ? x : 0 - _ => x + if isadd(x) + return sum([get_independent(arg, t) for arg in arguments(x)]) + elseif ismul(x) + return prod([get_independent(arg, t) for arg in arguments(x)]) + elseif isdiv(x) + return !is_function(x.den, t) ? get_independent(x.num, t) / x.den : 0 + elseif ispow(x) + base, exponent = arguments(x) + return !is_function(base, t) && !is_function(exponent, t) ? x : 0 + elseif isterm(x) + return !is_function(x, t) ? x : 0 + elseif issym(x) + return !is_function(x, t) ? x : 0 + else + return x end end @@ -94,11 +113,14 @@ function get_all_terms(x::Equation) return unique(cat(get_all_terms(Num(x.lhs)), get_all_terms(Num(x.rhs)); dims=1)) end function _get_all_terms(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Add => vcat([_get_all_terms(term) for term in SymbolicUtils.arguments(x)]...) - Mul => SymbolicUtils.arguments(x) - Div => [_get_all_terms(x.num)..., _get_all_terms(x.den)...] - _ => [x] + if isadd(x) + return vcat([_get_all_terms(term) for term in SymbolicUtils.sorted_arguments(x)]...) + elseif ismul(x) + return SymbolicUtils.sorted_arguments(x) + elseif isdiv(x) + return [_get_all_terms(x.num)..., _get_all_terms(x.den)...] + else + return [x] end end _get_all_terms(x) = x @@ -112,8 +134,8 @@ function is_harmonic(x::Num, t::Num)::Bool if !prod(trigs) return false else - powers = [max_power(first(term.val.arguments), t) for term in t_terms[trigs]] - return all(isone, powers) + powers = [max_power(first(arguments(term.val)), t) for term in t_terms[trigs]] + return all(isequal(1), powers) end end @@ -126,10 +148,18 @@ is_function(f, var) = any(isequal.(get_variables(f), var)) """ Counts the number of derivatives of a symbolic variable. """ -function count_derivatives(x::Symbolics.BasicSymbolic) - (Symbolics.isterm(x) || Symbolics.issym(x)) || +function count_derivatives(x::BasicSymbolic) + if Symbolics.is_derivative(x) + arg = first(arguments(x)) + (issym(arg) || + Symbolics.is_derivative(arg) || + (isterm(arg) && issym(operation(arg)))) || + error("The input is not a single term or symbol") + D = operation(x) + return D.order + count_derivatives(arg) + end + (issym(x) || (isterm(x) && issym(operation(x)))) || error("The input is not a single term or symbol") - bool = Symbolics.is_derivative(x) - return bool ? 1 + count_derivatives(first(arguments(x))) : 0 + return 0 end count_derivatives(x::Num) = count_derivatives(Symbolics.unwrap(x)) diff --git a/src/Symbolics/drop_powers.jl b/src/Symbolics/drop_powers.jl index 3c88cfa..ce42a0a 100644 --- a/src/Symbolics/drop_powers.jl +++ b/src/Symbolics/drop_powers.jl @@ -45,7 +45,11 @@ drop_powers(x, vars, deg::Int) = drop_powers(wrap(x), vars, deg) function max_power(x::Num, y::Num) terms = get_all_terms(x) powers = power_of.(terms, y) - return maximum(powers) + literal_powers = Int[ + Int(SymbolicUtils.unwrap_const(p)) for p in powers if SymbolicUtils.is_literal_number(p) + ] + isempty(literal_powers) && return 0 + return maximum(literal_powers) end max_power(x::Vector{Num}, y::Num) = maximum(max_power.(x, y)) @@ -60,7 +64,8 @@ end function power_of(x::BasicSymbolic, y::BasicSymbolic) if ispow(x) && issym(y) - return isequal(x.base, y) ? x.exp : 0 + base, exponent = arguments(x) + return isequal(base, y) ? exponent : 0 elseif issym(x) && issym(y) return isequal(x, y) ? 1 : 0 else diff --git a/src/Symbolics/exponentials.jl b/src/Symbolics/exponentials.jl index 74423ec..71203e0 100644 --- a/src/Symbolics/exponentials.jl +++ b/src/Symbolics/exponentials.jl @@ -2,14 +2,22 @@ expand_exp_power(expr::Num) = expand_exp_power(expr.val) simplify_exp_products(x::Num) = simplify_exp_products(x.val) "Returns true if expr is an exponential" -isexp(expr) = isterm(expr) && expr.f == exp +isexp(expr) = isterm(expr) && operation(expr) === exp "Expand powers of exponential such that exp(x)^n => exp(x*n) " function expand_exp_power(expr::BasicSymbolic) - @compactified expr::BasicSymbolic begin - Add => sum([expand_exp_power(arg) for arg in arguments(expr)]) - Mul => prod([expand_exp_power(arg) for arg in arguments(expr)]) - _ => ispow(expr) && isexp(expr.base) ? exp(expr.base.arguments[1] * expr.exp) : expr + if isadd(expr) + return sum([expand_exp_power(arg) for arg in arguments(expr)]) + elseif ismul(expr) + return prod([expand_exp_power(arg) for arg in arguments(expr)]) + else + if ispow(expr) + base, exponent = arguments(expr) + if isexp(base) + return exp(arguments(base)[1] * exponent) + end + end + return expr end end expand_exp_power(expr) = expr @@ -17,11 +25,14 @@ expand_exp_power(expr) = expr "Simplify products of exponentials such that exp(a)*exp(b) => exp(a+b) This is included in SymbolicUtils as of 17.0 but the method here avoid other simplify calls" function simplify_exp_products(expr::BasicSymbolic) - @compactified expr::BasicSymbolic begin - Add => _apply_termwise(simplify_exp_products, expr) - Div => _apply_termwise(simplify_exp_products, expr) - Mul => simplify_exp_products_mul(expr) - _ => expr + if isadd(expr) + return _apply_termwise(simplify_exp_products, expr) + elseif isdiv(expr) + return _apply_termwise(simplify_exp_products, expr) + elseif ismul(expr) + return simplify_exp_products_mul(expr) + else + return expr end end function simplify_exp_products(x::Complex{Num}) @@ -33,9 +44,8 @@ function simplify_exp_products_mul(expr) rest = isempty(rest_ind) ? 1 : prod(arguments(expr)[rest_ind]) total = isempty(ind) ? 0 : sum(getindex.(arguments.(arguments(expr)[ind]), 1)) if SymbolicUtils.is_literal_number(total) - (total == 0 && return rest) - else - return rest * exp(total) + return iszero(SymbolicUtils.unwrap_const(total)) ? rest : rest * exp(total) end + return rest * exp(total) end simplify_exp_products(x) = x diff --git a/src/Symbolics/fourier.jl b/src/Symbolics/fourier.jl index 02d3e59..99751c4 100644 --- a/src/Symbolics/fourier.jl +++ b/src/Symbolics/fourier.jl @@ -36,9 +36,11 @@ end is_trig(f::Num) = is_trig(f.val) is_trig(f) = false function is_trig(f::BasicSymbolic) - f = ispow(f) ? f.base : f - isterm(f) && SymbolicUtils.operation(f) ∈ [cos, sin] && return true - return false + if ispow(f) + base, _ = arguments(f) + f = base + end + return isterm(f) && SymbolicUtils.operation(f) ∈ [cos, sin] end """ @@ -66,7 +68,7 @@ function fourier_cos_term(x, ω, t) end "Simplify fraction a/b + c/d = (ad + bc)/bd" -add_div(x) = wrap(Postwalk(add_with_div; maketerm=frac_maketerm)(unwrap(x))) +add_div(x) = wrap(Postwalk(add_with_div)(unwrap(x))) """ fourier_sin_term(x, ω, t) @@ -120,35 +122,9 @@ using Euler's formula: ``\\exp(ix) = \\cos(x) + i*\\sin(x)``. Returns the converted expression as a `Num` type. """ function trig_to_exp(x::Num) - all_terms = get_all_terms(x) - trigs = filter(z -> is_trig(z), all_terms) - - rules = [] - for trig in trigs - is_pow = ispow(trig.val) # trig is either a trig or a power of a trig - power = is_pow ? trig.val.exp : 1 - arg = is_pow ? arguments(trig.val.base)[1] : arguments(trig.val)[1] - type = is_pow ? operation(trig.val.base) : operation(trig.val) - - if type == cos - term = Complex{Num}((exp(im * arg) + exp(-im * arg))^power * (1//2)^power, 0) - elseif type == sin - term = - (1 * im^power) * - Complex{Num}(((exp(-im * arg) - exp(im * arg)))^power * (1//2)^power, 0) - end - # avoid Complex{Num} where possible as this causes bugs - # instead, the Nums store SymbolicUtils Complex types - term = Num(Symbolics.expand(term.re.val + im * term.im.val)) - append!(rules, [trig => term]) - end - - result = Symbolics.substitute(x, Dict(rules)) - return convert_to_Num(result) + return Num(trig_to_exp(x.val)) end trig_to_exp(x::Complex{Num}) = trig_to_exp(x.re) + im * trig_to_exp(x.im) -convert_to_Num(x::Complex{Num})::Num = Num(first(x.re.val.arguments)) -convert_to_Num(x::Num)::Num = x """ trig_to_exp(x::BasicSymbolic) @@ -163,9 +139,16 @@ function trig_to_exp(x::BasicSymbolic) rules = [] for trig in trigs is_pow = ispow(trig) # trig is either a trig or a power of a trig - power = is_pow ? trig.exp : 1 - arg = is_pow ? arguments(trig.base)[1] : arguments(trig)[1] - type = is_pow ? operation(trig.base) : operation(trig) + if is_pow + base, exponent = arguments(trig) + power = exponent + arg = arguments(base)[1] + type = operation(base) + else + power = 1 + arg = arguments(trig)[1] + type = operation(trig) + end if type == cos term = (exp(im * arg) + exp(-im * arg))^power * (1 // 2)^power @@ -198,8 +181,8 @@ trigonometric arguments for consistent simplification. function exp_to_trig(x::BasicSymbolic) if isadd(x) || isdiv(x) || ismul(x) return _apply_termwise(exp_to_trig, x) - elseif isterm(x) && x.f == exp - arg = first(x.arguments) + elseif isterm(x) && operation(x) === exp + arg = first(arguments(x)) trigarg = Symbolics.expand(-im * arg) # the argument of the to-be trig function trigarg = simplify_complex(trigarg) @@ -218,11 +201,16 @@ function exp_to_trig(x::BasicSymbolic) cos(trigarg) + im * sin(trigarg) end end - return if ismul(trigarg) && trigarg.coeff < 0 - cos(-trigarg) - im * sin(-trigarg) - else - cos(trigarg) + im * sin(trigarg) + if ismul(trigarg) + coeff = trigarg.coeff + if SymbolicUtils.is_literal_number(coeff) + coeff_val = SymbolicUtils.unwrap_const(coeff) + if coeff_val isa Real && coeff_val < 0 + return cos(-trigarg) - im * sin(-trigarg) + end + end end + return cos(trigarg) + im * sin(trigarg) else return x end diff --git a/test/symbolics.jl b/test/symbolics.jl index 56c1aa9..b833245 100644 --- a/test/symbolics.jl +++ b/test/symbolics.jl @@ -69,12 +69,12 @@ end trigs = [cos(f * t), sin(f * t)] for (i, trig) in pairs(trigs) z = trig_to_exp(trig) - @eqtest expand(exp_to_trig(z)) == trig + @eqtest simplify(expand(exp_to_trig(z))) == trig end trigs′ = [cos_euler(f * t), sin_euler(f * t)] for (i, trig) in pairs(trigs′) z = trig_to_exp(trig) - @eqtest expand(exp_to_trig(z)) == trigs[i] + @eqtest simplify(expand(exp_to_trig(z))) == trigs[i] end end @@ -176,24 +176,24 @@ end using QuestBase: simplify_complex @variables a, b, c for z in Complex{Num}[a, a * b, a / b] - @test simplify_complex(z).val isa BasicSymbolic{Real} + @test simplify_complex(z).val isa BasicSymbolic end z = Complex{Num}(1 + 0 * im) - @test simplify_complex(z).val isa Int64 + @test SymbolicUtils.is_literal_number(simplify_complex(z).val) end @testset "get_all_terms" begin using QuestBase: get_all_terms @variables a, b, c - @eqtest get_all_terms(a + b + c) == [a, b, c] - @eqtest get_all_terms(a * b * c) == [a, b, c] - @eqtest get_all_terms(a / b) == [a, b] - @eqtest get_all_terms(a^2 + b^2 + c^2) == [b^2, a^2, c^2] - @eqtest get_all_terms(a^2 / b^2) == [a^2, b^2] - @eqtest get_all_terms(2 * b^2) == [2, b^2] - @eqtest get_all_terms(2 * b^2 ~ a) == [2, b^2, a] + @eqtest sort(get_all_terms(a + b + c); by=string) == sort([a, b, c]; by=string) + @eqtest sort(get_all_terms(a * b * c); by=string) == sort([a, b, c]; by=string) + @eqtest sort(get_all_terms(a / b); by=string) == sort([a, b]; by=string) + @eqtest sort(get_all_terms(a^2 + b^2 + c^2); by=string) == sort([a^2, b^2, c^2]; by=string) + @eqtest sort(get_all_terms(a^2 / b^2); by=string) == sort([a^2, b^2]; by=string) + @eqtest sort(get_all_terms(2 * b^2); by=string) == sort([2, b^2]; by=string) + @eqtest sort(get_all_terms(2 * b^2 ~ a); by=string) == sort([2, b^2, a]; by=string) end @testset "get_independent" begin From 5115ec18ded746f2b7e5435870c474e8852bd554 Mon Sep 17 00:00:00 2001 From: Orjan Ameye Date: Sat, 24 Jan 2026 20:26:26 +0100 Subject: [PATCH 2/8] feat: update to symbolics V7 and SymbolicUtils v4 --- .github/workflows/Tests.yml | 1 - Project.toml | 2 +- src/DifferentialEquation.jl | 17 ++- src/HarmonicEquation.jl | 22 ++- src/Symbolics/Symbolics_utils.jl | 18 ++- src/Symbolics/drop_powers.jl | 19 ++- src/Symbolics/exponentials.jl | 5 +- src/Symbolics/fourier.jl | 232 +++++++++++++++++++++++++++++-- src/Variables.jl | 4 +- src/utils.jl | 61 +++++++- test/DifferentialEquations.jl | 4 +- test/symbolics.jl | 73 +++++++--- 12 files changed, 403 insertions(+), 55 deletions(-) diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index bea1792..2ca70fc 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -40,7 +40,6 @@ jobs: fail-fast: false matrix: version: - - 'pre' - 'lts' - '1' os: diff --git a/Project.toml b/Project.toml index 3749d35..c60bd29 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,7 @@ Test = "1.10" OrderedCollections = "1.8" Aqua = "0.8.11" ExplicitImports = "1.11" -JET = "0.9.18, 0.10.0" +JET = "0.9.18, 0.10.0, 0.11" CheckConcreteStructs = "0.1.0" [extras] diff --git a/src/DifferentialEquation.jl b/src/DifferentialEquation.jl index cb97dab..ea4f2e3 100644 --- a/src/DifferentialEquation.jl +++ b/src/DifferentialEquation.jl @@ -163,7 +163,10 @@ corresponding to second-order differential equations. function is_rearranged_standard(eom::DifferentialEquation, degree=2) tvar = get_independent_variables(eom)[1] D = Differential(tvar)^degree - return isequal(getfield.(values(eom.equations), :lhs), D.(get_variables(eom))) + lhs = getfield.(values(eom.equations), :lhs) + rhs = D.(get_variables(eom)) + diffs = Symbolics.simplify.(lhs .- rhs) + return all(is_literal_zero, diffs) end """ @@ -195,7 +198,17 @@ function rearrange!(eom::DifferentialEquation, new_lhs::Vector{Num}) return nothing end function get_variables_nums(vars::Vector{Num}) - unique(flatten([Num.(get_variables(x)) for x in vars])) + # Symbolics v7: `get_variables(Differential(t, n)(x(t)))` returns the derivative term + # itself, so we must explicitly strip derivatives to recover the dependent variable. + out = Num[] + for expr in vars + sym = Symbolics.unwrap(expr) + while Symbolics.is_derivative(sym) + sym = first(Symbolics.arguments(sym)) + end + push!(out, Num(sym)) + end + return out end # TODO: remove this function or at least better names """ diff --git a/src/HarmonicEquation.jl b/src/HarmonicEquation.jl index 4889146..c014a0c 100644 --- a/src/HarmonicEquation.jl +++ b/src/HarmonicEquation.jl @@ -62,11 +62,23 @@ end "Get the parameters (not time nor variables) of a HarmonicEquation" function _parameters(eom::HarmonicEquation) - all_symbols = flatten([ - cat(get_variables(eq.lhs), get_variables(eq.rhs); dims=1) for eq in eom.equations - ]) - # subtract the set of independent variables (i.e., time) from all free symbols - return setdiff(all_symbols, get_variables(eom), get_independent_variables(eom)) + symbols = Num[] + for eq in eom.equations + vars = union(Symbolics.get_variables(eq.lhs), Symbolics.get_variables(eq.rhs)) + vars = sort!(collect(vars); by=string) + for sym in vars + if Symbolics.is_derivative(sym) + sym = first(Symbolics.arguments(sym)) + end + push!(symbols, Num(sym)) + end + end + vars = Set(get_variables(eom)) + indep = Set(get_independent_variables(eom)) + params = filter(s -> !(s in vars || s in indep), symbols) + params = unique(params) + sort!(params; by=string) + return params end """ diff --git a/src/Symbolics/Symbolics_utils.jl b/src/Symbolics/Symbolics_utils.jl index a7a9f86..32ce732 100644 --- a/src/Symbolics/Symbolics_utils.jl +++ b/src/Symbolics/Symbolics_utils.jl @@ -2,6 +2,13 @@ expand_all(x::Num) = Num(expand_all(x.val)) _apply_termwise(f, x::Num) = wrap(_apply_termwise(f, unwrap(x))) +# Symbolics v7 note: `unwrap` now always returns `BasicSymbolic`; use `Symbolics.value` +# when you want the underlying literal number (via `Const`). +_literal_number(x::Num) = (v = Symbolics.value(x); v isa Number ? v : nothing) +_literal_number(x::BasicSymbolic) = + (v = SymbolicUtils.unwrap_const(x); v isa Number ? v : nothing) +is_literal_zero(x) = (v = _literal_number(x); v !== nothing && iszero(v)) + "Expands using SymbolicUtils.expand and expand_exp_power (changes exp(x)^n to exp(x*n)" function expand_all(x) result = Postwalk(expand_exp_power)(SymbolicUtils.expand(x)) @@ -15,7 +22,14 @@ function expand_fraction(x::BasicSymbolic) elseif ismul(x) return _apply_termwise(expand_fraction, x) elseif isdiv(x) - return sum([arg / x.den for arg in arguments(x.num)]) + # Only distribute division over addition in the numerator. + # In SymbolicUtils v4, `arguments(x.num)` for a multiplication returns factors, + # so splitting unconditionally would produce incorrect results (e.g. d*(a+b)/c). + num, den = x.num, x.den + if isadd(num) + return sum([expand_fraction(arg) / den for arg in arguments(num)]) + end + return expand_fraction(num) / expand_fraction(den) else return x end @@ -131,7 +145,7 @@ function is_harmonic(x::Num, t::Num)::Bool isempty(t_terms) && return true trigs = is_trig.(t_terms) - if !prod(trigs) + if !all(trigs) return false else powers = [max_power(first(arguments(term.val)), t) for term in t_terms[trigs]] diff --git a/src/Symbolics/drop_powers.jl b/src/Symbolics/drop_powers.jl index ce42a0a..a6c560d 100644 --- a/src/Symbolics/drop_powers.jl +++ b/src/Symbolics/drop_powers.jl @@ -17,10 +17,14 @@ function drop_powers(expr::Num, vars::Vector{Num}, deg::Int) Symbolics.@variables ϵ subs_expr = deepcopy(expr) rules = Dict([var => ϵ * var for var in unique(vars)]) - subs_expr = Symbolics.expand(substitute_all(subs_expr, rules)) + subs_expr = Symbolics.expand(substitute_all(subs_expr, rules; include_derivatives=false)) max_deg = max_power(subs_expr, ϵ) removal = Dict([ϵ^d => Num(0) for d in deg:max_deg]) - res = substitute_all(substitute_all(subs_expr, removal), Dict(ϵ => Num(1))) + res = substitute_all( + substitute_all(subs_expr, removal; include_derivatives=false), + Dict(ϵ => Num(1)); + include_derivatives=false, + ) return Symbolics.expand(res) end @@ -30,7 +34,7 @@ end # calls the above for various types of the first argument function drop_powers(eq::Equation, var::Vector{Num}, deg::Int) - return drop_powers(eq.lhs, var, deg) .~ drop_powers(eq.lhs, var, deg) + return drop_powers(eq.lhs, var, deg) .~ drop_powers(eq.rhs, var, deg) end function drop_powers(eqs::Vector{Equation}, var::Vector{Num}, deg::Int) return [ @@ -45,9 +49,12 @@ drop_powers(x, vars, deg::Int) = drop_powers(wrap(x), vars, deg) function max_power(x::Num, y::Num) terms = get_all_terms(x) powers = power_of.(terms, y) - literal_powers = Int[ - Int(SymbolicUtils.unwrap_const(p)) for p in powers if SymbolicUtils.is_literal_number(p) - ] + literal_powers = Int[] + for p in powers + pv = SymbolicUtils.unwrap_const(p) + pv isa Number || continue + push!(literal_powers, Int(pv)) + end isempty(literal_powers) && return 0 return maximum(literal_powers) end diff --git a/src/Symbolics/exponentials.jl b/src/Symbolics/exponentials.jl index 71203e0..e5265f2 100644 --- a/src/Symbolics/exponentials.jl +++ b/src/Symbolics/exponentials.jl @@ -43,8 +43,9 @@ function simplify_exp_products_mul(expr) rest_ind = setdiff(1:length(arguments(expr)), ind) rest = isempty(rest_ind) ? 1 : prod(arguments(expr)[rest_ind]) total = isempty(ind) ? 0 : sum(getindex.(arguments.(arguments(expr)[ind]), 1)) - if SymbolicUtils.is_literal_number(total) - return iszero(SymbolicUtils.unwrap_const(total)) ? rest : rest * exp(total) + total_val = SymbolicUtils.unwrap_const(total) + if total_val isa Number + return iszero(total_val) ? rest : rest * exp(total) end return rest * exp(total) end diff --git a/src/Symbolics/fourier.jl b/src/Symbolics/fourier.jl index 99751c4..83eabed 100644 --- a/src/Symbolics/fourier.jl +++ b/src/Symbolics/fourier.jl @@ -28,10 +28,73 @@ function trig_reduce(x) x = expand_all(x) # expand products of exponentials x = simplify_exp_products(x) # simplify products of exps x = exp_to_trig(x) + x = _trig_expand_products(x) x = Num(simplify_complex(expand(x))) return x # simplify_fractions(x)# (a*c^2 + b*c)/c^2 = (a*c + b)/c end +function _is_sin_cos(ex::BasicSymbolic) + return isterm(ex) && (operation(ex) === sin || operation(ex) === cos) +end + +function _trig_mul_to_sum(a::BasicSymbolic, b::BasicSymbolic) + op1, op2 = operation(a), operation(b) + x = first(arguments(a)) + y = first(arguments(b)) + if op1 === cos && op2 === cos + return (cos(x - y) + cos(x + y)) / 2 + elseif op1 === sin && op2 === sin + return (cos(x - y) - cos(x + y)) / 2 + elseif op1 === sin && op2 === cos + return (sin(x + y) + sin(x - y)) / 2 + elseif op1 === cos && op2 === sin + return (sin(x + y) - sin(x - y)) / 2 + end + return nothing +end + +function _trig_expand_products(x::BasicSymbolic) + # Expand trig products/powers into sums so `get_independent` can isolate constants. + y = Postwalk(ex -> begin + if ispow(ex) + base, exponent = arguments(ex) + exp_val = SymbolicUtils.unwrap_const(exponent) + if exp_val isa Integer && exp_val == 2 && _is_sin_cos(base) + arg = first(arguments(base)) + if operation(base) === cos + return (1 + cos(2 * arg)) / 2 + else + return (1 - cos(2 * arg)) / 2 + end + end + elseif ismul(ex) + # In SymbolicUtils v4, `arguments(ismul(...))` includes the numeric coefficient + # even though `ex.coeff` also stores it. Avoid double-counting it. + factors = BasicSymbolic[ + f for f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number) + ] + trig_idx = findall(_is_sin_cos, factors) + if length(trig_idx) >= 2 + i, j = trig_idx[1], trig_idx[2] + repl = _trig_mul_to_sum(factors[i], factors[j]) + if repl !== nothing + others = BasicSymbolic[] + for (k, f) in pairs(factors) + (k == i || k == j) && continue + push!(others, f) + end + coeff = ex.coeff + return coeff * prod(others; init=1) * repl + end + end + end + return ex + end)(x) + return SymbolicUtils.expand(y) +end +_trig_expand_products(x::Num) = wrap(_trig_expand_products(unwrap(x))) +_trig_expand_products(x) = x + "Return true if `f` is a sin or cos." is_trig(f::Num) = is_trig(f.val) is_trig(f) = false @@ -103,14 +166,138 @@ function _fourier_term(x::Equation, ω, t, f) return Equation(_fourier_term(x.lhs, ω, t, f), _fourier_term(x.rhs, ω, t, f)) end +_real_if_complex(v) = v isa Complex{Num} ? v.re : v + +_canonicalize(ft) = _real_if_complex(Symbolics.simplify(Symbolics.expand(ft))) + +function _cleanup_fourier_term(ft) + ft = _real_if_complex(ft) + ft = _strip_real_imag(ft) + ft = _canonicalize(ft) + ft = _simplify_trig_zero(Num(ft)) + ft = simplify_complex(Symbolics.expand(ft)) + ft = _real_if_complex(ft) + ft = _strip_real_imag(ft) + ft = _normalize_trig_signs(unwrap(ft)) + ft = _strip_zero_imag_literals(wrap(ft)) + ft = _canonicalize(ft) + ft = _real_if_complex(ft) + return ft +end + "Return the coefficient of f(ωt) in `x` where `f` is a cos or sin." function _fourier_term(x, ω, t, f) term = x * f(ω * t) term = trig_reduce(term) indep = get_independent(term, t) - ft = Num(simplify_complex(Symbolics.expand(indep))) + ft = simplify_complex(Symbolics.expand(indep)) ft = !isequal(ω, 0) ? 2 * ft : ft # extra factor in case ω = 0 ! - return Symbolics.expand(ft) + ft = _cleanup_fourier_term(ft) + return ft isa Num ? ft : Num(ft) +end + +function _strip_real_imag(x::Complex{Num}) + return _strip_real_imag(x.re) + im * _strip_real_imag(x.im) +end + +function _strip_zero_imag_literals(x::BasicSymbolic) + return Postwalk(ex -> begin + v = SymbolicUtils.unwrap_const(ex) + if v isa Complex && iszero(imag(v)) + return real(v) + end + return ex + end)(x) +end + +function _strip_zero_imag_literals(x::Num) + return wrap(_strip_zero_imag_literals(unwrap(x))) +end + +function _strip_zero_imag_literals(x::Complex{Num}) + return _strip_zero_imag_literals(x.re) + im * _strip_zero_imag_literals(x.im) +end + +function _strip_real_imag(x::BasicSymbolic) + function _real_of(ex::BasicSymbolic) + if isadd(ex) + return sum(_real_of.(arguments(ex))) + elseif ismul(ex) + coeff_val = SymbolicUtils.unwrap_const(ex.coeff) + if coeff_val isa Number + r = real(coeff_val) + rest = prod( + (f for f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number)); + init=1, + ) + return iszero(r) ? 0 : r * rest + end + return ex + elseif isdiv(ex) + return _real_of(ex.num) / ex.den + else + v = SymbolicUtils.unwrap_const(ex) + return v isa Number ? real(v) : ex + end + end + + function _imag_of(ex::BasicSymbolic) + if isadd(ex) + return sum(_imag_of.(arguments(ex))) + elseif ismul(ex) + coeff_val = SymbolicUtils.unwrap_const(ex.coeff) + if coeff_val isa Number + i = imag(coeff_val) + rest = prod( + (f for f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number)); + init=1, + ) + return iszero(i) ? 0 : i * rest + end + return 0 + elseif isdiv(ex) + return _imag_of(ex.num) / ex.den + else + v = SymbolicUtils.unwrap_const(ex) + return v isa Number ? imag(v) : 0 + end + end + + return Postwalk(ex -> begin + if isterm(ex) + op = operation(ex) + if op === real || nameof(op) == :real + return _real_of(first(arguments(ex))) + elseif op === imag || nameof(op) == :imag + return _imag_of(first(arguments(ex))) + end + end + return ex + end)(x) +end + +function _strip_real_imag(x::Num) + return wrap(_strip_real_imag(unwrap(x))) +end + +function _simplify_trig_zero(x::BasicSymbolic) + return Postwalk(ex -> begin + if isterm(ex) + op = operation(ex) + if op === sin || op === cos + arg = first(arguments(ex)) + arg_val = SymbolicUtils.unwrap_const(arg) + if arg_val isa Number && iszero(arg_val) + return op === sin ? 0 : 1 + end + end + end + return ex + end)(x) +end + +function _simplify_trig_zero(x::Num) + return wrap(_simplify_trig_zero(unwrap(x))) end """ @@ -180,7 +367,7 @@ trigonometric arguments for consistent simplification. """ function exp_to_trig(x::BasicSymbolic) if isadd(x) || isdiv(x) || ismul(x) - return _apply_termwise(exp_to_trig, x) + return Symbolics.simplify(_normalize_trig_signs(_apply_termwise(exp_to_trig, x))) elseif isterm(x) && operation(x) === exp arg = first(arguments(x)) trigarg = Symbolics.expand(-im * arg) # the argument of the to-be trig function @@ -203,19 +390,46 @@ function exp_to_trig(x::BasicSymbolic) end if ismul(trigarg) coeff = trigarg.coeff - if SymbolicUtils.is_literal_number(coeff) - coeff_val = SymbolicUtils.unwrap_const(coeff) - if coeff_val isa Real && coeff_val < 0 - return cos(-trigarg) - im * sin(-trigarg) - end + coeff_val = SymbolicUtils.unwrap_const(coeff) + if coeff_val isa Real && coeff_val < 0 + return cos(-trigarg) - im * sin(-trigarg) end end - return cos(trigarg) + im * sin(trigarg) + return _normalize_trig_signs(cos(trigarg) + im * sin(trigarg)) else return x end end +function _normalize_trig_signs(x::BasicSymbolic) + if isadd(x) || ismul(x) + args = SymbolicUtils.sorted_arguments(x) + return (isadd(x) ? sum : prod)(_normalize_trig_signs.(args)) + elseif isdiv(x) + return _normalize_trig_signs(x.num) / _normalize_trig_signs(x.den) + elseif isterm(x) + op = operation(x) + if op === real || op === imag || nameof(op) == :real || nameof(op) == :imag + arg = first(arguments(x)) + return op(_normalize_trig_signs(arg)) + elseif op === sin || op === cos + arg = first(arguments(x)) + if SymbolicUtils.isnegative(arg) + new_arg = -arg + return op === sin ? -sin(new_arg) : cos(new_arg) + elseif ismul(arg) + coeff_val = SymbolicUtils.unwrap_const(arg.coeff) + # Handle coefficients that are complex with zero imaginary part, e.g. (-2 + 0im)θ. + if coeff_val isa Number && isreal(coeff_val) && real(coeff_val) < 0 + new_arg = -arg + return op === sin ? -sin(new_arg) : cos(new_arg) + end + end + end + end + return x +end + exp_to_trig(x) = x exp_to_trig(x::Num) = exp_to_trig(x.val) exp_to_trig(x::Complex{Num}) = exp_to_trig(x.re) + im * exp_to_trig(x.im) diff --git a/src/Variables.jl b/src/Variables.jl index 4b04e2f..099857f 100644 --- a/src/Variables.jl +++ b/src/Variables.jl @@ -24,8 +24,8 @@ end "Return the name of a variable (excluding independent variables)" function var_name(x::Num)::String - var = Symbolics._toexpr(x) - var = var isa Expr ? String(var.args[1]) : String(var) + var = string(x) + var = replace(var, r"\(.*\)$" => "") return String(replace(var, r"\\mathtt\{([^}]*)\}" => s"\1")) # ^ remove "\\mathtt{}" from the variable name coming from Symbolics # since Symbolics v6.14.1 (Symbolics#1305) diff --git a/src/utils.jl b/src/utils.jl index fe063ea..5ece8c0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -12,13 +12,68 @@ function show_fields(object) end end +_is_symbolic_like(x) = x isa Num || x isa BasicSymbolic +_is_symbolic_like(x::Complex) = _is_symbolic_like(real(x)) || _is_symbolic_like(imag(x)) + +function _eqtest_symbolic_scalar(a, b) + diff = Symbolics.simplify(Symbolics.expand(a - b)) + return isequal(diff, 0) +end + +function _eqtest_symbolic_scalar(a::Complex, b::Complex) + return _eqtest_equal(real(a), real(b)) && _eqtest_equal(imag(a), imag(b)) +end + +function _eqtest_symbolic_scalar(a::Complex, b) + return _eqtest_equal(real(a), b) && _eqtest_equal(imag(a), 0) +end + +function _eqtest_symbolic_scalar(a, b::Complex) + return _eqtest_equal(a, real(b)) && _eqtest_equal(0, imag(b)) +end + +function _eqtest_equal(a::AbstractArray, b::AbstractArray) + size(a) == size(b) || return false + for (aa, bb) in zip(a, b) + _eqtest_equal(aa, bb) || return false + end + return true +end + +function _eqtest_equal(a::Tuple, b::Tuple) + length(a) == length(b) || return false + for (aa, bb) in zip(a, b) + _eqtest_equal(aa, bb) || return false + end + return true +end + +function _eqtest_equal(a::Equation, b::Equation) + return _eqtest_equal(a.lhs, b.lhs) && _eqtest_equal(a.rhs, b.rhs) +end + +function _eqtest_equal(a, b) + isequal(a, b) && return true + if _is_symbolic_like(a) || _is_symbolic_like(b) + try + return _eqtest_symbolic_scalar(a, b) + catch + return false + end + end + return false +end + +_eqtest_notequal(a, b) = !_eqtest_equal(a, b) + macro eqtest(expr) @assert expr.head == :call && expr.args[1] in [:(==), :(!=)] + qb = QuoteNode(QuestBase) return esc( if expr.args[1] == :(==) - :(@test isequal($(expr.args[2]), $(expr.args[3]))) + :(@test $qb._eqtest_equal($(expr.args[2]), $(expr.args[3]))) else - :(@test !isequal($(expr.args[2]), $(expr.args[3]))) + :(@test $qb._eqtest_notequal($(expr.args[2]), $(expr.args[3]))) end, ) end @@ -35,6 +90,6 @@ macro eqsym(expr) end is_identity(A::Matrix{Num}) = (@eqsym A == Matrix{Num}(LinearAlgebra.I, size(A)...)) -hasnan(x::Matrix{Num}) = any(my_isnan, unwrap.(x)) +hasnan(x::Matrix{Num}) = any(my_isnan, Symbolics.value.(x)) my_isnan(x) = isnan(x) my_isnan(x::BasicSymbolic) = false diff --git a/test/DifferentialEquations.jl b/test/DifferentialEquations.jl index 79b85bb..9f1427e 100644 --- a/test/DifferentialEquations.jl +++ b/test/DifferentialEquations.jl @@ -41,7 +41,9 @@ using QuestBase: expr = d(x, t, 2) + ω0^2 * x diff_eq3 = DifferentialEquation(expr, x) @test length(diff_eq3.equations) == 1 - @test diff_eq3.equations[x].rhs == 0 + rhs_simplified = Symbolics.simplify(diff_eq3.equations[x].rhs) + rhs_val = Symbolics.value(rhs_simplified) + @test rhs_val isa Number && iszero(rhs_val) # Test empty constructor diff = DifferentialEquation() diff --git a/test/symbolics.jl b/test/symbolics.jl index b833245..9920f1d 100644 --- a/test/symbolics.jl +++ b/test/symbolics.jl @@ -1,8 +1,8 @@ using Test using Symbolics -using SymbolicUtils: Fixpoint, Prewalk, PassThrough, BasicSymbolic +using SymbolicUtils: BasicSymbolic -using QuestBase: @eqtest +using QuestBase: @eqtest, trig_reduce @testset "exp(x)^n => exp(x*n)" begin using QuestBase: expand_all, expand_exp_power @@ -27,7 +27,7 @@ end @testset "euler" begin @variables a b - @eqtest cos(a) + im * sin(a) == exp(im * a) + @test isequal(trig_reduce(cos(a) + im * sin(a) - exp(im * a)), 0) @eqtest exp(a) * cos(b) + im * sin(b) * exp(a) == exp(a + im * b) end @@ -52,7 +52,8 @@ end # eq = drop_powers(a^2 + a ~ b, [a, b], 2) # broken @eqtest [eq.lhs, eq.rhs] == [a, a] eq = drop_powers(a^2 + a + b ~ a, a, 2) - @test string(eq.rhs) == "a" broken = true + @eqtest eq.lhs == a + b + @eqtest eq.rhs == a @eqtest drop_powers([a^2 + a + b, b], a, 2) == [a + b, b] @eqtest drop_powers([a^2 + a + b, b], [a, b], 2) == [a + b, b] @@ -65,16 +66,17 @@ end cos_euler(x) = (exp(im * x) + exp(-im * x)) / 2 sin_euler(x) = (exp(im * x) - exp(-im * x)) / (2 * im) - # automatic conversion between trig and exp form + # Conversion between trig and exp form. + # We validate by substituting numeric values (robust across Symbolics canonicalization). trigs = [cos(f * t), sin(f * t)] - for (i, trig) in pairs(trigs) + samples = ((1.3, 0.7), (2.0, 0.1), (-0.4, 1.1)) + for trig in trigs z = trig_to_exp(trig) - @eqtest simplify(expand(exp_to_trig(z))) == trig - end - trigs′ = [cos_euler(f * t), sin_euler(f * t)] - for (i, trig) in pairs(trigs′) - z = trig_to_exp(trig) - @eqtest simplify(expand(exp_to_trig(z))) == trigs[i] + back = exp_to_trig(z) + for (fv, tv) in samples + d = Symbolics.substitute(back - trig, Dict(f => fv, t => tv)) + @test Symbolics.value(d) == 0 + end end end @@ -99,7 +101,6 @@ end @testset "harmonic" begin using QuestBase: is_harmonic - @variables a, b, c, t, x(t), f, y(t) @test is_harmonic(cos(f * t), t) @@ -117,7 +118,7 @@ end using QuestBase: fourier_cos_term, fourier_sin_term using QuestBase.Symbolics: expand - @variables f t θ a b + @variables f t θ a b c @eqtest fourier_cos_term(cos(f * t)^2, f, t) == 0 @eqtest fourier_sin_term(sin(f * t)^2, f, t) == 0 @@ -147,10 +148,11 @@ end # try something harder! term = (a + b * cos(f * t + θ)^2)^3 * sin(f * t) - @eqtest fourier_sin_term(term, f, t) == expand( - a^3 + a^2 * b * 3//2 + 9//8 * a * b^2 + 5//16 * b^3 - - 3//64 * b * (16 * a^2 + 16 * a * b + 5 * b^2) * cos(2 * θ), - ) + @eqtest fourier_sin_term(term, f, t) == + expand( + a^3 + a^2 * b * 3//2 + 9//8 * a * b^2 + 5//16 * b^3 - + 3//64 * b * (16 * a^2 + 16 * a * b + 5 * b^2) * cos(2 * θ), + ) @eqtest fourier_cos_term(term, f, t) == expand(-3//64 * b * (16 * a^2 + 16 * a * b + 5 * b^2) * sin(2 * θ)) @@ -160,6 +162,33 @@ end @eqtest fourier_cos_term(cos(f * t)^3 + 1, 0, t) == 1 @eqtest fourier_cos_term(cos(f * t)^2 + 1, 0, t) == 3//2 @eqtest fourier_cos_term((cos(f * t)^2 + cos(f * t))^3, 0, t) == 23//16 + + # more complex but closed-form cases + term = (a + b * cos(f * t))^2 + @eqtest fourier_cos_term(term, f, t) == 2 * a * b + @eqtest fourier_sin_term(term, f, t) == 0 + @eqtest fourier_cos_term(term, 2 * f, t) == b^2 / 2 + @eqtest fourier_sin_term(term, 2 * f, t) == 0 + @eqtest fourier_cos_term(term, 0, t) == a^2 + b^2 / 2 + + term = (a + b * sin(f * t))^2 + @eqtest fourier_cos_term(term, f, t) == 0 + @eqtest fourier_sin_term(term, f, t) == 2 * a * b + @eqtest fourier_cos_term(term, 2 * f, t) == -b^2 / 2 + @eqtest fourier_sin_term(term, 2 * f, t) == 0 + @eqtest fourier_cos_term(term, 0, t) == a^2 + b^2 / 2 + + term = (a + b * cos(f * t + θ)) * (a + b * cos(f * t - θ)) + @eqtest fourier_cos_term(term, f, t) == 2 * a * b * cos(θ) + @eqtest fourier_sin_term(term, f, t) == 0 + @eqtest fourier_cos_term(term, 2 * f, t) == b^2 / 2 + @eqtest fourier_sin_term(term, 2 * f, t) == 0 + @eqtest fourier_cos_term(term, 0, t) == a^2 + b^2 / 2 * cos(2 * θ) + + term = (a + b * cos(f * t))^3 + @eqtest fourier_cos_term(term, f, t) == 3 * a^2 * b + 3//4 * b^3 + @eqtest fourier_sin_term(term, f, t) == 0 + @eqtest fourier_cos_term(term, 0, t) == a^3 + 3//2 * a * b^2 end @testset "_apply_termwise" begin @@ -180,7 +209,8 @@ end end z = Complex{Num}(1 + 0 * im) - @test SymbolicUtils.is_literal_number(simplify_complex(z).val) + z_val = SymbolicUtils.unwrap_const(simplify_complex(z).val) + @test z_val isa Number && z_val == 1 end @testset "get_all_terms" begin @@ -218,8 +248,9 @@ end @variables a, b, c, d @eqtest expand_fraction((a + b) / c) == a / c + b / c - @test string.(expand_fraction(d * (a + b) / c)) == "d*a /c + d*b / c + d / c" broken = - true + lhs = Symbolics.expand(expand_fraction(d * (a + b) / c)) + rhs = d * a / c + d * b / c + @eqtest lhs == rhs end @testset "count_derivatives" begin using QuestBase: count_derivatives, d From 6c2a2cbaec71f6f0071fa770e5752991f9d284a0 Mon Sep 17 00:00:00 2001 From: Orjan Ameye Date: Sat, 24 Jan 2026 20:26:43 +0100 Subject: [PATCH 3/8] format --- src/Symbolics/Symbolics_utils.jl | 20 +++++---- src/Symbolics/drop_powers.jl | 4 +- src/Symbolics/fourier.jl | 76 ++++++++++++++++++-------------- test/DifferentialEquations.jl | 6 +-- test/symbolics.jl | 12 ++--- 5 files changed, 66 insertions(+), 52 deletions(-) diff --git a/src/Symbolics/Symbolics_utils.jl b/src/Symbolics/Symbolics_utils.jl index 32ce732..a49133d 100644 --- a/src/Symbolics/Symbolics_utils.jl +++ b/src/Symbolics/Symbolics_utils.jl @@ -4,10 +4,11 @@ _apply_termwise(f, x::Num) = wrap(_apply_termwise(f, unwrap(x))) # Symbolics v7 note: `unwrap` now always returns `BasicSymbolic`; use `Symbolics.value` # when you want the underlying literal number (via `Const`). -_literal_number(x::Num) = (v = Symbolics.value(x); v isa Number ? v : nothing) -_literal_number(x::BasicSymbolic) = - (v = SymbolicUtils.unwrap_const(x); v isa Number ? v : nothing) -is_literal_zero(x) = (v = _literal_number(x); v !== nothing && iszero(v)) +_literal_number(x::Num) = (v=Symbolics.value(x); v isa Number ? v : nothing) +function _literal_number(x::BasicSymbolic) + (v=SymbolicUtils.unwrap_const(x); v isa Number ? v : nothing) +end +is_literal_zero(x) = (v=_literal_number(x); v !== nothing && iszero(v)) "Expands using SymbolicUtils.expand and expand_exp_power (changes exp(x)^n to exp(x*n)" function expand_all(x) @@ -145,7 +146,7 @@ function is_harmonic(x::Num, t::Num)::Bool isempty(t_terms) && return true trigs = is_trig.(t_terms) - if !all(trigs) + if !all(trigs) return false else powers = [max_power(first(arguments(term.val)), t) for term in t_terms[trigs]] @@ -165,10 +166,11 @@ Counts the number of derivatives of a symbolic variable. function count_derivatives(x::BasicSymbolic) if Symbolics.is_derivative(x) arg = first(arguments(x)) - (issym(arg) || - Symbolics.is_derivative(arg) || - (isterm(arg) && issym(operation(arg)))) || - error("The input is not a single term or symbol") + ( + issym(arg) || + Symbolics.is_derivative(arg) || + (isterm(arg) && issym(operation(arg))) + ) || error("The input is not a single term or symbol") D = operation(x) return D.order + count_derivatives(arg) end diff --git a/src/Symbolics/drop_powers.jl b/src/Symbolics/drop_powers.jl index a6c560d..ad8409e 100644 --- a/src/Symbolics/drop_powers.jl +++ b/src/Symbolics/drop_powers.jl @@ -17,7 +17,9 @@ function drop_powers(expr::Num, vars::Vector{Num}, deg::Int) Symbolics.@variables ϵ subs_expr = deepcopy(expr) rules = Dict([var => ϵ * var for var in unique(vars)]) - subs_expr = Symbolics.expand(substitute_all(subs_expr, rules; include_derivatives=false)) + subs_expr = Symbolics.expand( + substitute_all(subs_expr, rules; include_derivatives=false) + ) max_deg = max_power(subs_expr, ϵ) removal = Dict([ϵ^d => Num(0) for d in deg:max_deg]) res = substitute_all( diff --git a/src/Symbolics/fourier.jl b/src/Symbolics/fourier.jl index 83eabed..bf32d9b 100644 --- a/src/Symbolics/fourier.jl +++ b/src/Symbolics/fourier.jl @@ -55,41 +55,45 @@ end function _trig_expand_products(x::BasicSymbolic) # Expand trig products/powers into sums so `get_independent` can isolate constants. - y = Postwalk(ex -> begin - if ispow(ex) - base, exponent = arguments(ex) - exp_val = SymbolicUtils.unwrap_const(exponent) - if exp_val isa Integer && exp_val == 2 && _is_sin_cos(base) - arg = first(arguments(base)) - if operation(base) === cos - return (1 + cos(2 * arg)) / 2 - else - return (1 - cos(2 * arg)) / 2 + y = Postwalk( + ex -> begin + if ispow(ex) + base, exponent = arguments(ex) + exp_val = SymbolicUtils.unwrap_const(exponent) + if exp_val isa Integer && exp_val == 2 && _is_sin_cos(base) + arg = first(arguments(base)) + if operation(base) === cos + return (1 + cos(2 * arg)) / 2 + else + return (1 - cos(2 * arg)) / 2 + end end - end - elseif ismul(ex) - # In SymbolicUtils v4, `arguments(ismul(...))` includes the numeric coefficient - # even though `ex.coeff` also stores it. Avoid double-counting it. - factors = BasicSymbolic[ - f for f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number) - ] - trig_idx = findall(_is_sin_cos, factors) - if length(trig_idx) >= 2 - i, j = trig_idx[1], trig_idx[2] - repl = _trig_mul_to_sum(factors[i], factors[j]) - if repl !== nothing - others = BasicSymbolic[] - for (k, f) in pairs(factors) - (k == i || k == j) && continue - push!(others, f) + elseif ismul(ex) + # In SymbolicUtils v4, `arguments(ismul(...))` includes the numeric coefficient + # even though `ex.coeff` also stores it. Avoid double-counting it. + factors = BasicSymbolic[ + f for f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number) + ] + trig_idx = findall(_is_sin_cos, factors) + if length(trig_idx) >= 2 + i, j = trig_idx[1], trig_idx[2] + repl = _trig_mul_to_sum(factors[i], factors[j]) + if repl !== nothing + others = BasicSymbolic[] + for (k, f) in pairs(factors) + (k == i || k == j) && continue + push!(others, f) + end + coeff = ex.coeff + return coeff * prod(others; init=1) * repl end - coeff = ex.coeff - return coeff * prod(others; init=1) * repl end end - end - return ex - end)(x) + return ex + end, + )( + x + ) return SymbolicUtils.expand(y) end _trig_expand_products(x::Num) = wrap(_trig_expand_products(unwrap(x))) @@ -227,7 +231,10 @@ function _strip_real_imag(x::BasicSymbolic) if coeff_val isa Number r = real(coeff_val) rest = prod( - (f for f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number)); + ( + f for + f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number) + ); init=1, ) return iszero(r) ? 0 : r * rest @@ -249,7 +256,10 @@ function _strip_real_imag(x::BasicSymbolic) if coeff_val isa Number i = imag(coeff_val) rest = prod( - (f for f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number)); + ( + f for + f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number) + ); init=1, ) return iszero(i) ? 0 : i * rest diff --git a/test/DifferentialEquations.jl b/test/DifferentialEquations.jl index 9f1427e..55559b1 100644 --- a/test/DifferentialEquations.jl +++ b/test/DifferentialEquations.jl @@ -41,9 +41,9 @@ using QuestBase: expr = d(x, t, 2) + ω0^2 * x diff_eq3 = DifferentialEquation(expr, x) @test length(diff_eq3.equations) == 1 - rhs_simplified = Symbolics.simplify(diff_eq3.equations[x].rhs) - rhs_val = Symbolics.value(rhs_simplified) - @test rhs_val isa Number && iszero(rhs_val) + rhs_simplified = Symbolics.simplify(diff_eq3.equations[x].rhs) + rhs_val = Symbolics.value(rhs_simplified) + @test rhs_val isa Number && iszero(rhs_val) # Test empty constructor diff = DifferentialEquation() diff --git a/test/symbolics.jl b/test/symbolics.jl index 9920f1d..2d957e1 100644 --- a/test/symbolics.jl +++ b/test/symbolics.jl @@ -148,11 +148,10 @@ end # try something harder! term = (a + b * cos(f * t + θ)^2)^3 * sin(f * t) - @eqtest fourier_sin_term(term, f, t) == - expand( - a^3 + a^2 * b * 3//2 + 9//8 * a * b^2 + 5//16 * b^3 - - 3//64 * b * (16 * a^2 + 16 * a * b + 5 * b^2) * cos(2 * θ), - ) + @eqtest fourier_sin_term(term, f, t) == expand( + a^3 + a^2 * b * 3//2 + 9//8 * a * b^2 + 5//16 * b^3 - + 3//64 * b * (16 * a^2 + 16 * a * b + 5 * b^2) * cos(2 * θ), + ) @eqtest fourier_cos_term(term, f, t) == expand(-3//64 * b * (16 * a^2 + 16 * a * b + 5 * b^2) * sin(2 * θ)) @@ -220,7 +219,8 @@ end @eqtest sort(get_all_terms(a + b + c); by=string) == sort([a, b, c]; by=string) @eqtest sort(get_all_terms(a * b * c); by=string) == sort([a, b, c]; by=string) @eqtest sort(get_all_terms(a / b); by=string) == sort([a, b]; by=string) - @eqtest sort(get_all_terms(a^2 + b^2 + c^2); by=string) == sort([a^2, b^2, c^2]; by=string) + @eqtest sort(get_all_terms(a^2 + b^2 + c^2); by=string) == + sort([a^2, b^2, c^2]; by=string) @eqtest sort(get_all_terms(a^2 / b^2); by=string) == sort([a^2, b^2]; by=string) @eqtest sort(get_all_terms(2 * b^2); by=string) == sort([2, b^2]; by=string) @eqtest sort(get_all_terms(2 * b^2 ~ a); by=string) == sort([2, b^2, a]; by=string) From c63df32bc53e9d3b271321baa01ca401ec967def Mon Sep 17 00:00:00 2001 From: Orjan Ameye Date: Sat, 24 Jan 2026 20:45:44 +0100 Subject: [PATCH 4/8] rule based trigo rules --- src/Symbolics/fourier.jl | 83 ++++++++++++++-------------------------- 1 file changed, 29 insertions(+), 54 deletions(-) diff --git a/src/Symbolics/fourier.jl b/src/Symbolics/fourier.jl index bf32d9b..e29e5cb 100644 --- a/src/Symbolics/fourier.jl +++ b/src/Symbolics/fourier.jl @@ -37,63 +37,38 @@ function _is_sin_cos(ex::BasicSymbolic) return isterm(ex) && (operation(ex) === sin || operation(ex) === cos) end -function _trig_mul_to_sum(a::BasicSymbolic, b::BasicSymbolic) - op1, op2 = operation(a), operation(b) - x = first(arguments(a)) - y = first(arguments(b)) - if op1 === cos && op2 === cos - return (cos(x - y) + cos(x + y)) / 2 - elseif op1 === sin && op2 === sin - return (cos(x - y) - cos(x + y)) / 2 - elseif op1 === sin && op2 === cos - return (sin(x + y) + sin(x - y)) / 2 - elseif op1 === cos && op2 === sin - return (sin(x + y) - sin(x - y)) / 2 - end - return nothing -end +const _rw_trig_mul_to_sum = SymbolicUtils.Rewriters.Chain([ + SymbolicUtils.@rule(cos(~x) * cos(~y) => (cos(~x - ~y) + cos(~x + ~y)) / 2), + SymbolicUtils.@rule(sin(~x) * sin(~y) => (cos(~x - ~y) - cos(~x + ~y)) / 2), + SymbolicUtils.@rule(sin(~x) * cos(~y) => (sin(~x + ~y) + sin(~x - ~y)) / 2), + SymbolicUtils.@rule(cos(~x) * sin(~y) => (sin(~x + ~y) - sin(~x - ~y)) / 2), +]) + +const _rw_trig_expand = SymbolicUtils.Rewriters.Fixpoint( + SymbolicUtils.Rewriters.Postwalk( + SymbolicUtils.Rewriters.Chain([ + SymbolicUtils.@rule((cos(~x))^2 => (1 + cos(2 * ~x)) / 2), + SymbolicUtils.@rule((sin(~x))^2 => (1 - cos(2 * ~x)) / 2), + _rw_trig_mul_to_sum, + ]), + ), +) function _trig_expand_products(x::BasicSymbolic) # Expand trig products/powers into sums so `get_independent` can isolate constants. - y = Postwalk( - ex -> begin - if ispow(ex) - base, exponent = arguments(ex) - exp_val = SymbolicUtils.unwrap_const(exponent) - if exp_val isa Integer && exp_val == 2 && _is_sin_cos(base) - arg = first(arguments(base)) - if operation(base) === cos - return (1 + cos(2 * arg)) / 2 - else - return (1 - cos(2 * arg)) / 2 - end - end - elseif ismul(ex) - # In SymbolicUtils v4, `arguments(ismul(...))` includes the numeric coefficient - # even though `ex.coeff` also stores it. Avoid double-counting it. - factors = BasicSymbolic[ - f for f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number) - ] - trig_idx = findall(_is_sin_cos, factors) - if length(trig_idx) >= 2 - i, j = trig_idx[1], trig_idx[2] - repl = _trig_mul_to_sum(factors[i], factors[j]) - if repl !== nothing - others = BasicSymbolic[] - for (k, f) in pairs(factors) - (k == i || k == j) && continue - push!(others, f) - end - coeff = ex.coeff - return coeff * prod(others; init=1) * repl - end - end - end - return ex - end, - )( - x - ) + y = Postwalk(ex -> begin + if ismul(ex) + # In SymbolicUtils v4, `arguments(ismul(...))` includes the numeric coefficient + # even though `ex.coeff` also stores it. Avoid double-counting it. + coeff = ex.coeff + factors = BasicSymbolic[ + f for f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number) + ] + rest = isempty(factors) ? 1 : prod(factors; init=1) + return coeff * _rw_trig_expand(rest) + end + return _rw_trig_expand(ex) + end)(x) return SymbolicUtils.expand(y) end _trig_expand_products(x::Num) = wrap(_trig_expand_products(unwrap(x))) From cc061f8ce685a149f4bf729089fb20aadc43bdb9 Mon Sep 17 00:00:00 2001 From: Orjan Ameye Date: Sat, 24 Jan 2026 20:56:49 +0100 Subject: [PATCH 5/8] fix lts JET fail --- src/QuestBase.jl | 2 +- src/Symbolics/fourier.jl | 42 ++++++++++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/QuestBase.jl b/src/QuestBase.jl index 0c87448..057c796 100644 --- a/src/QuestBase.jl +++ b/src/QuestBase.jl @@ -8,6 +8,7 @@ using SymbolicUtils: SymbolicUtils, Postwalk, BasicSymbolic, + unwrap, isterm, ispow, isadd, @@ -19,7 +20,6 @@ using SymbolicUtils: using Symbolics: Symbolics, Num, - unwrap, wrap, get_variables, Equation, diff --git a/src/Symbolics/fourier.jl b/src/Symbolics/fourier.jl index e29e5cb..b4d6f08 100644 --- a/src/Symbolics/fourier.jl +++ b/src/Symbolics/fourier.jl @@ -56,19 +56,23 @@ const _rw_trig_expand = SymbolicUtils.Rewriters.Fixpoint( function _trig_expand_products(x::BasicSymbolic) # Expand trig products/powers into sums so `get_independent` can isolate constants. - y = Postwalk(ex -> begin - if ismul(ex) - # In SymbolicUtils v4, `arguments(ismul(...))` includes the numeric coefficient - # even though `ex.coeff` also stores it. Avoid double-counting it. - coeff = ex.coeff - factors = BasicSymbolic[ - f for f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number) - ] - rest = isempty(factors) ? 1 : prod(factors; init=1) - return coeff * _rw_trig_expand(rest) - end - return _rw_trig_expand(ex) - end)(x) + y = Postwalk( + ex -> begin + if ismul(ex) + # In SymbolicUtils v4, `arguments(ismul(...))` includes the numeric coefficient + # even though `ex.coeff` also stores it. Avoid double-counting it. + coeff = ex.coeff + factors = BasicSymbolic[ + f for f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number) + ] + rest = isempty(factors) ? 1 : prod(factors; init=1) + return coeff * _rw_trig_expand(rest) + end + return _rw_trig_expand(ex) + end, + )( + x + ) return SymbolicUtils.expand(y) end _trig_expand_products(x::Num) = wrap(_trig_expand_products(unwrap(x))) @@ -401,13 +405,21 @@ function _normalize_trig_signs(x::BasicSymbolic) arg = first(arguments(x)) if SymbolicUtils.isnegative(arg) new_arg = -arg - return op === sin ? -sin(new_arg) : cos(new_arg) + return if op === sin + -SymbolicUtils.term(sin, new_arg) + else + SymbolicUtils.term(cos, new_arg) + end elseif ismul(arg) coeff_val = SymbolicUtils.unwrap_const(arg.coeff) # Handle coefficients that are complex with zero imaginary part, e.g. (-2 + 0im)θ. if coeff_val isa Number && isreal(coeff_val) && real(coeff_val) < 0 new_arg = -arg - return op === sin ? -sin(new_arg) : cos(new_arg) + return if op === sin + -SymbolicUtils.term(sin, new_arg) + else + SymbolicUtils.term(cos, new_arg) + end end end end From c26f6dd6e439dac8f07920351d0b1d54fa9aad27 Mon Sep 17 00:00:00 2001 From: Orjan Ameye Date: Sat, 24 Jan 2026 21:14:39 +0100 Subject: [PATCH 6/8] update version to 0.4.0 and compatibility for SymbolicUtils and Symbolics --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index c60bd29..3264bf2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "QuestBase" uuid = "7e80f742-43d6-403d-a9ea-981410111d43" authors = ["Orjan Ameye ", "Jan Kosata ", "Javier del Pino "] -version = "0.3.4" +version = "0.4.0" [deps] DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -12,8 +12,8 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [compat] DocStringExtensions = "0.9.4" -SymbolicUtils = "3.25, 4" -Symbolics = "6.34, 7" +SymbolicUtils = "4.13.1" +Symbolics = "7.8.0" julia = "1.10" Random = "1.10" LinearAlgebra = "1.10" From 2d3e5bd1ed10e2427b863d0cd85e27f74217652c Mon Sep 17 00:00:00 2001 From: Orjan Ameye Date: Sat, 24 Jan 2026 21:30:40 +0100 Subject: [PATCH 7/8] try to reduce new lines --- src/DifferentialEquation.jl | 12 +------- src/HarmonicEquation.jl | 5 +-- src/Symbolics/fourier.jl | 60 ++++++++++++++---------------------- src/utils.jl | 53 ++++++++++++++------------------ test/symbolics.jl | 61 +++++++++++++++++++------------------ 5 files changed, 80 insertions(+), 111 deletions(-) diff --git a/src/DifferentialEquation.jl b/src/DifferentialEquation.jl index ea4f2e3..4bd57f1 100644 --- a/src/DifferentialEquation.jl +++ b/src/DifferentialEquation.jl @@ -198,17 +198,7 @@ function rearrange!(eom::DifferentialEquation, new_lhs::Vector{Num}) return nothing end function get_variables_nums(vars::Vector{Num}) - # Symbolics v7: `get_variables(Differential(t, n)(x(t)))` returns the derivative term - # itself, so we must explicitly strip derivatives to recover the dependent variable. - out = Num[] - for expr in vars - sym = Symbolics.unwrap(expr) - while Symbolics.is_derivative(sym) - sym = first(Symbolics.arguments(sym)) - end - push!(out, Num(sym)) - end - return out + return strip_derivative.(vars) end # TODO: remove this function or at least better names """ diff --git a/src/HarmonicEquation.jl b/src/HarmonicEquation.jl index c014a0c..ca8043a 100644 --- a/src/HarmonicEquation.jl +++ b/src/HarmonicEquation.jl @@ -67,10 +67,7 @@ function _parameters(eom::HarmonicEquation) vars = union(Symbolics.get_variables(eq.lhs), Symbolics.get_variables(eq.rhs)) vars = sort!(collect(vars); by=string) for sym in vars - if Symbolics.is_derivative(sym) - sym = first(Symbolics.arguments(sym)) - end - push!(symbols, Num(sym)) + push!(symbols, strip_derivative(Num(sym))) end end vars = Set(get_variables(eom)) diff --git a/src/Symbolics/fourier.jl b/src/Symbolics/fourier.jl index b4d6f08..f049900 100644 --- a/src/Symbolics/fourier.jl +++ b/src/Symbolics/fourier.jl @@ -13,14 +13,6 @@ This function performs the following steps: Returns the simplified expression as a `Num` type. """ - -""" - is_trig(f::Num) - -Check if the given expression `f` is a trigonometric function (sine or cosine). - -Returns `true` if `f` is either `sin` or `cos`, `false` otherwise. -""" function trig_reduce(x) x = add_div(x) # a/b + c/d = (ad + bc)/bd x = expand(x) # open all brackets @@ -33,10 +25,6 @@ function trig_reduce(x) return x # simplify_fractions(x)# (a*c^2 + b*c)/c^2 = (a*c + b)/c end -function _is_sin_cos(ex::BasicSymbolic) - return isterm(ex) && (operation(ex) === sin || operation(ex) === cos) -end - const _rw_trig_mul_to_sum = SymbolicUtils.Rewriters.Chain([ SymbolicUtils.@rule(cos(~x) * cos(~y) => (cos(~x - ~y) + cos(~x + ~y)) / 2), SymbolicUtils.@rule(sin(~x) * sin(~y) => (cos(~x - ~y) - cos(~x + ~y)) / 2), @@ -78,7 +66,13 @@ end _trig_expand_products(x::Num) = wrap(_trig_expand_products(unwrap(x))) _trig_expand_products(x) = x -"Return true if `f` is a sin or cos." +""" + is_trig(f::Num) + +Check if the given expression `f` is a trigonometric function (sine or cosine). + +Returns `true` if `f` is either `sin` or `cos`, `false` otherwise. +""" is_trig(f::Num) = is_trig(f.val) is_trig(f) = false function is_trig(f::BasicSymbolic) @@ -100,6 +94,9 @@ Used in Fourier analysis to find the cosine components of a periodic function. - `ω`: The angular frequency - `t`: The time variable """ +function fourier_cos_term(x, ω, t) + return _fourier_term(x, ω, t, cos) +end """ add_div(x) @@ -109,11 +106,7 @@ Transforms expressions of the form a/b + c/d into (ad + bc)/bd. Returns the simplified fraction as a `Num` type. """ -function fourier_cos_term(x, ω, t) - return _fourier_term(x, ω, t, cos) -end - -"Simplify fraction a/b + c/d = (ad + bc)/bd" +# Simplify fraction a/b + c/d = (ad + bc)/bd add_div(x) = wrap(Postwalk(add_with_div)(unwrap(x))) """ @@ -183,23 +176,18 @@ function _strip_real_imag(x::Complex{Num}) return _strip_real_imag(x.re) + im * _strip_real_imag(x.im) end -function _strip_zero_imag_literals(x::BasicSymbolic) - return Postwalk(ex -> begin - v = SymbolicUtils.unwrap_const(ex) - if v isa Complex && iszero(imag(v)) - return real(v) - end - return ex - end)(x) -end - -function _strip_zero_imag_literals(x::Num) - return wrap(_strip_zero_imag_literals(unwrap(x))) -end +_postwalk(f, x::BasicSymbolic) = Postwalk(f)(x) +_postwalk(f, x::Num) = wrap(_postwalk(f, unwrap(x))) +_postwalk(f, x::Complex{Num}) = _postwalk(f, x.re) + im * _postwalk(f, x.im) +_postwalk(f, x) = x -function _strip_zero_imag_literals(x::Complex{Num}) - return _strip_zero_imag_literals(x.re) + im * _strip_zero_imag_literals(x.im) -end +_strip_zero_imag_literals(x) = _postwalk( + ex -> begin + v = SymbolicUtils.unwrap_const(ex) + return (v isa Complex && iszero(imag(v))) ? real(v) : ex + end, + x, +) function _strip_real_imag(x::BasicSymbolic) function _real_of(ex::BasicSymbolic) @@ -285,9 +273,7 @@ function _simplify_trig_zero(x::BasicSymbolic) end)(x) end -function _simplify_trig_zero(x::Num) - return wrap(_simplify_trig_zero(unwrap(x))) -end +_simplify_trig_zero(x::Num) = wrap(_simplify_trig_zero(unwrap(x))) """ trig_to_exp(x::Num) diff --git a/src/utils.jl b/src/utils.jl index 5ece8c0..a88af06 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,6 +4,17 @@ end flatten(a) = collect(Iterators.flatten(a)) +"""Strip Symbolics derivative wrappers to recover the base dependent variable.""" +strip_derivative(x::Num) = wrap(strip_derivative(unwrap(x))) +function strip_derivative(x::BasicSymbolic) + y = x + while Symbolics.is_derivative(y) + y = first(arguments(y)) + end + return y +end +strip_derivative(x) = x + "Show fields of an object." function show_fields(object) for field in fieldnames(typeof(object)) # display every field @@ -13,55 +24,37 @@ function show_fields(object) end _is_symbolic_like(x) = x isa Num || x isa BasicSymbolic -_is_symbolic_like(x::Complex) = _is_symbolic_like(real(x)) || _is_symbolic_like(imag(x)) function _eqtest_symbolic_scalar(a, b) diff = Symbolics.simplify(Symbolics.expand(a - b)) return isequal(diff, 0) end -function _eqtest_symbolic_scalar(a::Complex, b::Complex) - return _eqtest_equal(real(a), real(b)) && _eqtest_equal(imag(a), imag(b)) -end - -function _eqtest_symbolic_scalar(a::Complex, b) - return _eqtest_equal(real(a), b) && _eqtest_equal(imag(a), 0) -end - -function _eqtest_symbolic_scalar(a, b::Complex) - return _eqtest_equal(a, real(b)) && _eqtest_equal(0, imag(b)) -end +_eqtest_equal(a::Complex, b::Complex) = + _eqtest_equal(real(a), real(b)) && _eqtest_equal(imag(a), imag(b)) +_eqtest_equal(a::Complex, b) = _eqtest_equal(real(a), b) && _eqtest_equal(imag(a), 0) +_eqtest_equal(a, b::Complex) = _eqtest_equal(a, real(b)) && _eqtest_equal(0, imag(b)) function _eqtest_equal(a::AbstractArray, b::AbstractArray) size(a) == size(b) || return false - for (aa, bb) in zip(a, b) - _eqtest_equal(aa, bb) || return false - end - return true + return all(i -> _eqtest_equal(a[i], b[i]), eachindex(a, b)) end function _eqtest_equal(a::Tuple, b::Tuple) length(a) == length(b) || return false - for (aa, bb) in zip(a, b) - _eqtest_equal(aa, bb) || return false - end - return true + return all(i -> _eqtest_equal(a[i], b[i]), eachindex(a)) end -function _eqtest_equal(a::Equation, b::Equation) - return _eqtest_equal(a.lhs, b.lhs) && _eqtest_equal(a.rhs, b.rhs) -end +_eqtest_equal(a::Equation, b::Equation) = _eqtest_equal(a.lhs, b.lhs) && _eqtest_equal(a.rhs, b.rhs) function _eqtest_equal(a, b) isequal(a, b) && return true - if _is_symbolic_like(a) || _is_symbolic_like(b) - try - return _eqtest_symbolic_scalar(a, b) - catch - return false - end + (_is_symbolic_like(a) || _is_symbolic_like(b)) || return false + return try + _eqtest_symbolic_scalar(a, b) + catch + false end - return false end _eqtest_notequal(a, b) = !_eqtest_equal(a, b) diff --git a/test/symbolics.jl b/test/symbolics.jl index 2d957e1..14783b5 100644 --- a/test/symbolics.jl +++ b/test/symbolics.jl @@ -12,7 +12,6 @@ using QuestBase: @eqtest, trig_reduce @eqtest simplify(exp(a)^3) == exp(3 * a) @eqtest simplify(exp(a)^n) == exp(n * a) @eqtest expand_all(exp(a)^3) == exp(3 * a) - @eqtest expand_all(exp(a)^3) == exp(3 * a) @eqtest expand_all(im * exp(a)^5) == im * exp(5 * a) end @@ -63,9 +62,6 @@ end using QuestBase: expand_all, trig_to_exp, exp_to_trig @testset "Num" begin @variables f t - cos_euler(x) = (exp(im * x) + exp(-im * x)) / 2 - sin_euler(x) = (exp(im * x) - exp(-im * x)) / (2 * im) - # Conversion between trig and exp form. # We validate by substituting numeric values (robust across Symbolics canonicalization). trigs = [cos(f * t), sin(f * t)] @@ -162,32 +158,39 @@ end @eqtest fourier_cos_term(cos(f * t)^2 + 1, 0, t) == 3//2 @eqtest fourier_cos_term((cos(f * t)^2 + cos(f * t))^3, 0, t) == 23//16 + function _check_fourier(term, specs) + for (ω, cos_expected, sin_expected) in specs + @eqtest fourier_cos_term(term, ω, t) == cos_expected + @eqtest fourier_sin_term(term, ω, t) == sin_expected + end + return nothing + end + # more complex but closed-form cases - term = (a + b * cos(f * t))^2 - @eqtest fourier_cos_term(term, f, t) == 2 * a * b - @eqtest fourier_sin_term(term, f, t) == 0 - @eqtest fourier_cos_term(term, 2 * f, t) == b^2 / 2 - @eqtest fourier_sin_term(term, 2 * f, t) == 0 - @eqtest fourier_cos_term(term, 0, t) == a^2 + b^2 / 2 - - term = (a + b * sin(f * t))^2 - @eqtest fourier_cos_term(term, f, t) == 0 - @eqtest fourier_sin_term(term, f, t) == 2 * a * b - @eqtest fourier_cos_term(term, 2 * f, t) == -b^2 / 2 - @eqtest fourier_sin_term(term, 2 * f, t) == 0 - @eqtest fourier_cos_term(term, 0, t) == a^2 + b^2 / 2 - - term = (a + b * cos(f * t + θ)) * (a + b * cos(f * t - θ)) - @eqtest fourier_cos_term(term, f, t) == 2 * a * b * cos(θ) - @eqtest fourier_sin_term(term, f, t) == 0 - @eqtest fourier_cos_term(term, 2 * f, t) == b^2 / 2 - @eqtest fourier_sin_term(term, 2 * f, t) == 0 - @eqtest fourier_cos_term(term, 0, t) == a^2 + b^2 / 2 * cos(2 * θ) - - term = (a + b * cos(f * t))^3 - @eqtest fourier_cos_term(term, f, t) == 3 * a^2 * b + 3//4 * b^3 - @eqtest fourier_sin_term(term, f, t) == 0 - @eqtest fourier_cos_term(term, 0, t) == a^3 + 3//2 * a * b^2 + for (term, specs) in ( + ( + (a + b * cos(f * t))^2, + ((f, 2 * a * b, 0), (2 * f, b^2 / 2, 0), (0, a^2 + b^2 / 2, 0)), + ), + ( + (a + b * sin(f * t))^2, + ((f, 0, 2 * a * b), (2 * f, -b^2 / 2, 0), (0, a^2 + b^2 / 2, 0)), + ), + ( + (a + b * cos(f * t + θ)) * (a + b * cos(f * t - θ)), + ( + (f, 2 * a * b * cos(θ), 0), + (2 * f, b^2 / 2, 0), + (0, a^2 + b^2 / 2 * cos(2 * θ), 0), + ), + ), + ( + (a + b * cos(f * t))^3, + ((f, 3 * a^2 * b + 3//4 * b^3, 0), (0, a^3 + 3//2 * a * b^2, 0)), + ), + ) + _check_fourier(term, specs) + end end @testset "_apply_termwise" begin From c63fce0c051e358450772c0d6ec1b561b23bbcb8 Mon Sep 17 00:00:00 2001 From: Orjan Ameye Date: Sat, 24 Jan 2026 22:30:47 +0100 Subject: [PATCH 8/8] format and more test --- src/HarmonicEquation.jl | 1 + src/Symbolics/fourier.jl | 9 ++++---- src/utils.jl | 7 ++++-- test/HarmonicEquation.jl | 19 ++++++++++++++++- test/runtests.jl | 4 ++++ test/symbolics.jl | 27 +++++++++++++++++++++++ test/utils.jl | 46 ++++++++++++++++++++++++++++++++++++++++ 7 files changed, 105 insertions(+), 8 deletions(-) create mode 100644 test/utils.jl diff --git a/src/HarmonicEquation.jl b/src/HarmonicEquation.jl index ca8043a..b805093 100644 --- a/src/HarmonicEquation.jl +++ b/src/HarmonicEquation.jl @@ -116,6 +116,7 @@ Base.show(eom::HarmonicEquation) = show_fields(eom) function substitute_all(eom::HarmonicEquation, rules::Union{Dict,Pair})::HarmonicEquation new_eom = deepcopy(eom) new_eom.equations = expand_derivatives.(substitute_all(eom.equations, rules)) + new_eom.variables = substitute_all(eom.variables, rules) return new_eom end diff --git a/src/Symbolics/fourier.jl b/src/Symbolics/fourier.jl index f049900..cd2a1f2 100644 --- a/src/Symbolics/fourier.jl +++ b/src/Symbolics/fourier.jl @@ -181,13 +181,12 @@ _postwalk(f, x::Num) = wrap(_postwalk(f, unwrap(x))) _postwalk(f, x::Complex{Num}) = _postwalk(f, x.re) + im * _postwalk(f, x.im) _postwalk(f, x) = x -_strip_zero_imag_literals(x) = _postwalk( - ex -> begin +function _strip_zero_imag_literals(x) + _postwalk(ex -> begin v = SymbolicUtils.unwrap_const(ex) return (v isa Complex && iszero(imag(v))) ? real(v) : ex - end, - x, -) + end, x) +end function _strip_real_imag(x::BasicSymbolic) function _real_of(ex::BasicSymbolic) diff --git a/src/utils.jl b/src/utils.jl index a88af06..9bbc5d1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -30,8 +30,9 @@ function _eqtest_symbolic_scalar(a, b) return isequal(diff, 0) end -_eqtest_equal(a::Complex, b::Complex) = +function _eqtest_equal(a::Complex, b::Complex) _eqtest_equal(real(a), real(b)) && _eqtest_equal(imag(a), imag(b)) +end _eqtest_equal(a::Complex, b) = _eqtest_equal(real(a), b) && _eqtest_equal(imag(a), 0) _eqtest_equal(a, b::Complex) = _eqtest_equal(a, real(b)) && _eqtest_equal(0, imag(b)) @@ -45,7 +46,9 @@ function _eqtest_equal(a::Tuple, b::Tuple) return all(i -> _eqtest_equal(a[i], b[i]), eachindex(a)) end -_eqtest_equal(a::Equation, b::Equation) = _eqtest_equal(a.lhs, b.lhs) && _eqtest_equal(a.rhs, b.rhs) +function _eqtest_equal(a::Equation, b::Equation) + _eqtest_equal(a.lhs, b.lhs) && _eqtest_equal(a.rhs, b.rhs) +end function _eqtest_equal(a, b) isequal(a, b) && return true diff --git a/test/HarmonicEquation.jl b/test/HarmonicEquation.jl index 9bff61f..4cfb7e0 100644 --- a/test/HarmonicEquation.jl +++ b/test/HarmonicEquation.jl @@ -22,7 +22,7 @@ using QuestBase: # Setup common test variables @variables t, T -@variables x(t) y(t) u(T) v(T) +@variables x(t) y(t) u(T) v(T) w(T) D = Differential(T) # Create simple test equation @@ -78,6 +78,8 @@ end rules = Dict(u => a) subbed = substitute_all(heq, rules) @test !isequal(subbed.equations, heq.equations) + @test !isequal(subbed.variables, heq.variables) + @eqtest subbed.variables[1].symbol == a end @testset "Utility functions" begin @@ -96,3 +98,18 @@ end list = unique(filter(x -> !(x isa Real), Symbolics.unwrap.(reduce(vcat, list)))) @test all(map(x -> !hasproperty(x, :arguments), list)) end + +@testset "is_rearranged MF path" begin + # No derivative terms on either side => treated as arranged by construction (MF_bool). + eq_alg1 = u ~ u + v + eq_alg2 = v ~ u + heq_alg = HarmonicEquation([eq_alg1, eq_alg2], [hv1, hv2], nat_eq) + @test is_rearranged(heq_alg) +end + +@testset "Hopf variable filtered from ansatz" begin + hv_hopf = HarmonicVariable(w, "test", "Hopf", Num(1.0), x) + heq_hopf = HarmonicEquation([eq1, eq2], [hv1, hv2, hv_hopf], nat_eq) + ans = QuestBase._show_ansatz(heq_hopf) + @test !occursin(string(hv_hopf.symbol), ans) +end diff --git a/test/runtests.jl b/test/runtests.jl index 284cf5f..8ec9a27 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,10 @@ end include("symbolics.jl") end +@testset "Utils" begin + include("utils.jl") +end + @testset "DifferentialEquations" begin include("DifferentialEquations.jl") end diff --git a/test/symbolics.jl b/test/symbolics.jl index 14783b5..4f24a96 100644 --- a/test/symbolics.jl +++ b/test/symbolics.jl @@ -56,6 +56,13 @@ end @eqtest drop_powers([a^2 + a + b, b], a, 2) == [a + b, b] @eqtest drop_powers([a^2 + a + b, b], [a, b], 2) == [a + b, b] + + @testset "Vector{Equation}" begin + eqs = [a^2 + a ~ a, b^3 + b ~ 0] + out = drop_powers(eqs, [a], 2) + @eqtest [out[1].lhs, out[1].rhs] == [a, a] + @eqtest out[2] == eqs[2] + end end @testset "trig_to_exp and trig_to_exp" begin @@ -277,4 +284,24 @@ end @eqtest substitute_all(a * b * c * d * e * f * g * h, rules) == b^2 * d^2 * f^2 * h^2 @eqtest substitute_all([a, c, e], rules) == [b, d, f] @eqtest substitute_all(a + b * im, rules) == b + b * im + + @testset "include_derivatives rewrites Differential(var)" begin + @variables t T x(t) + D = Differential(t) + expr = D(x) + + # Symbolics substitution currently rewrites arguments but does not rewrite the + # derivative operator itself (i.e., the `operation` remains `Differential(t)`). + expected_arg = Symbolics.unwrap(Symbolics.substitute(x, Dict(t => T))) + + out = substitute_all(expr, Dict(t => T); include_derivatives=true) + out_bs = Symbolics.unwrap(out) + @test Symbolics.operation(out_bs) == Differential(t) + @test isequal(first(Symbolics.arguments(out_bs)), expected_arg) + + out_no = substitute_all(expr, Dict(t => T); include_derivatives=false) + out_no_bs = Symbolics.unwrap(out_no) + @test Symbolics.operation(out_no_bs) == Differential(t) + @test isequal(first(Symbolics.arguments(out_no_bs)), expected_arg) + end end diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 0000000..7b2e5b7 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,46 @@ +using Test +using Symbolics +using LinearAlgebra + +using QuestBase: strip_derivative, dummy_symbolic_Jacobian, is_identity, hasnan, d, @eqtest + +@testset "strip_derivative" begin + @variables t x(t) y(t) + + @eqtest strip_derivative(x) == x + @eqtest strip_derivative(d(x, t)) == x + @eqtest strip_derivative(d(d(x, t), t)) == x + + # should leave non-derivative expressions untouched + @eqtest strip_derivative(x + y) == x + y +end + +@testset "eqtest_equal containers and complex" begin + @variables a b c + + @eqtest [a, b] == [a, b] + @test !QuestBase._eqtest_equal([a], [a, b]) + + @eqtest (a, b, c) == (a, b, c) + @test !QuestBase._eqtest_equal((a, b), (a, c)) + + z = a + im * b + @eqtest z == z + @eqtest (a + 0im) == a + @eqtest a == (a + 0im) +end + +@testset "identity and NaN helpers" begin + @variables a + + I2 = Matrix{Num}(LinearAlgebra.I, 2, 2) + @test is_identity(I2) + + A = Num[1 0; 0 (1 + a)] + @test !is_identity(A) + + J = dummy_symbolic_Jacobian(3) + @test hasnan(J) + @test !hasnan(I2) + @test !is_identity(J) +end