diff --git a/Project.toml b/Project.toml index 2f1a691..3264bf2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "QuestBase" uuid = "7e80f742-43d6-403d-a9ea-981410111d43" authors = ["Orjan Ameye ", "Jan Kosata ", "Javier del Pino "] -version = "0.3.4" +version = "0.4.0" [deps] DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -12,8 +12,8 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [compat] DocStringExtensions = "0.9.4" -SymbolicUtils = "3.25" -Symbolics = "6.34" +SymbolicUtils = "4.13.1" +Symbolics = "7.8.0" julia = "1.10" Random = "1.10" LinearAlgebra = "1.10" @@ -21,7 +21,7 @@ Test = "1.10" OrderedCollections = "1.8" Aqua = "0.8.11" ExplicitImports = "1.11" -JET = "0.9.18, 0.10.0" +JET = "0.9.18, 0.10.0, 0.11" CheckConcreteStructs = "0.1.0" [extras] diff --git a/src/DifferentialEquation.jl b/src/DifferentialEquation.jl index fcda18c..4bd57f1 100644 --- a/src/DifferentialEquation.jl +++ b/src/DifferentialEquation.jl @@ -118,7 +118,7 @@ $(TYPEDSIGNATURES) Return the independent dependent variables of `diff_eom`. """ function get_independent_variables(diff_eom::DifferentialEquation) - return Num.(flatten(unique([x.val.arguments for x in keys(diff_eom.equations)]))) + return Num.(flatten(unique([arguments(x.val) for x in keys(diff_eom.equations)]))) end """ @@ -163,7 +163,10 @@ corresponding to second-order differential equations. function is_rearranged_standard(eom::DifferentialEquation, degree=2) tvar = get_independent_variables(eom)[1] D = Differential(tvar)^degree - return isequal(getfield.(values(eom.equations), :lhs), D.(get_variables(eom))) + lhs = getfield.(values(eom.equations), :lhs) + rhs = D.(get_variables(eom)) + diffs = Symbolics.simplify.(lhs .- rhs) + return all(is_literal_zero, diffs) end """ @@ -195,7 +198,7 @@ function rearrange!(eom::DifferentialEquation, new_lhs::Vector{Num}) return nothing end function get_variables_nums(vars::Vector{Num}) - unique(flatten([Num.(get_variables(x)) for x in vars])) + return strip_derivative.(vars) end # TODO: remove this function or at least better names """ diff --git a/src/HarmonicEquation.jl b/src/HarmonicEquation.jl index 4889146..b805093 100644 --- a/src/HarmonicEquation.jl +++ b/src/HarmonicEquation.jl @@ -62,11 +62,20 @@ end "Get the parameters (not time nor variables) of a HarmonicEquation" function _parameters(eom::HarmonicEquation) - all_symbols = flatten([ - cat(get_variables(eq.lhs), get_variables(eq.rhs); dims=1) for eq in eom.equations - ]) - # subtract the set of independent variables (i.e., time) from all free symbols - return setdiff(all_symbols, get_variables(eom), get_independent_variables(eom)) + symbols = Num[] + for eq in eom.equations + vars = union(Symbolics.get_variables(eq.lhs), Symbolics.get_variables(eq.rhs)) + vars = sort!(collect(vars); by=string) + for sym in vars + push!(symbols, strip_derivative(Num(sym))) + end + end + vars = Set(get_variables(eom)) + indep = Set(get_independent_variables(eom)) + params = filter(s -> !(s in vars || s in indep), symbols) + params = unique(params) + sort!(params; by=string) + return params end """ @@ -107,6 +116,7 @@ Base.show(eom::HarmonicEquation) = show_fields(eom) function substitute_all(eom::HarmonicEquation, rules::Union{Dict,Pair})::HarmonicEquation new_eom = deepcopy(eom) new_eom.equations = expand_derivatives.(substitute_all(eom.equations, rules)) + new_eom.variables = substitute_all(eom.variables, rules) return new_eom end diff --git a/src/HarmonicVariable.jl b/src/HarmonicVariable.jl index 0023311..c2f91b8 100644 --- a/src/HarmonicVariable.jl +++ b/src/HarmonicVariable.jl @@ -43,7 +43,7 @@ function _show_ansatz(var::HarmonicVariable) if isempty(var.type) return string(var.symbol) end - t = var.natural_variable.val.arguments + t = arguments(var.natural_variable.val) t = length(t) == 1 ? string(t[1]) : error("more than 1 independent variable") ω = string(var.ω) terms = Dict("u" => "*cos(" * ω * t * ")", "v" => "*sin(" * ω * t * ")", "a" => "") diff --git a/src/QuestBase.jl b/src/QuestBase.jl index 17d5b49..c8a41a6 100644 --- a/src/QuestBase.jl +++ b/src/QuestBase.jl @@ -7,15 +7,14 @@ using LinearAlgebra: LinearAlgebra using SymbolicUtils: SymbolicUtils, Postwalk, - Sym, BasicSymbolic, + unwrap, isterm, ispow, isadd, isdiv, ismul, add_with_div, - frac_maketerm, issym using SymbolicUtils.Unityper: @compactified @@ -23,7 +22,6 @@ using SymbolicUtils.Unityper: @compactified using Symbolics: Symbolics, Num, - unwrap, wrap, get_variables, Equation, diff --git a/src/Symbolics/Symbolics_utils.jl b/src/Symbolics/Symbolics_utils.jl index aee9122..a49133d 100644 --- a/src/Symbolics/Symbolics_utils.jl +++ b/src/Symbolics/Symbolics_utils.jl @@ -2,6 +2,14 @@ expand_all(x::Num) = Num(expand_all(x.val)) _apply_termwise(f, x::Num) = wrap(_apply_termwise(f, unwrap(x))) +# Symbolics v7 note: `unwrap` now always returns `BasicSymbolic`; use `Symbolics.value` +# when you want the underlying literal number (via `Const`). +_literal_number(x::Num) = (v=Symbolics.value(x); v isa Number ? v : nothing) +function _literal_number(x::BasicSymbolic) + (v=SymbolicUtils.unwrap_const(x); v isa Number ? v : nothing) +end +is_literal_zero(x) = (v=_literal_number(x); v !== nothing && iszero(v)) + "Expands using SymbolicUtils.expand and expand_exp_power (changes exp(x)^n to exp(x*n)" function expand_all(x) result = Postwalk(expand_exp_power)(SymbolicUtils.expand(x)) @@ -10,33 +18,49 @@ end expand_all(x::Complex{Num}) = expand_all(x.re) + im * expand_all(x.im) function expand_fraction(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Add => _apply_termwise(expand_fraction, x) - Mul => _apply_termwise(expand_fraction, x) - Div => sum([arg / x.den for arg in arguments(x.num)]) - _ => x + if isadd(x) + return _apply_termwise(expand_fraction, x) + elseif ismul(x) + return _apply_termwise(expand_fraction, x) + elseif isdiv(x) + # Only distribute division over addition in the numerator. + # In SymbolicUtils v4, `arguments(x.num)` for a multiplication returns factors, + # so splitting unconditionally would produce incorrect results (e.g. d*(a+b)/c). + num, den = x.num, x.den + if isadd(num) + return sum([expand_fraction(arg) / den for arg in arguments(num)]) + end + return expand_fraction(num) / expand_fraction(den) + else + return x end end expand_fraction(x::Num) = Num(expand_fraction(x.val)) "Apply a function f on every member of a sum or a product" function _apply_termwise(f, x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Add => sum([f(arg) for arg in arguments(x)]) - Mul => prod([f(arg) for arg in arguments(x)]) - Div => _apply_termwise(f, x.num) / _apply_termwise(f, x.den) - _ => f(x) + if isadd(x) + return sum([f(arg) for arg in arguments(x)]) + elseif ismul(x) + return prod([f(arg) for arg in arguments(x)]) + elseif isdiv(x) + return _apply_termwise(f, x.num) / _apply_termwise(f, x.den) + else + return f(x) end end simplify_complex(x::Complex) = isequal(x.im, 0) ? x.re : x.re + im * x.im simplify_complex(x) = x function simplify_complex(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Add => _apply_termwise(simplify_complex, x) - Mul => _apply_termwise(simplify_complex, x) - Div => _apply_termwise(simplify_complex, x) - _ => x + if isadd(x) + return _apply_termwise(simplify_complex, x) + elseif ismul(x) + return _apply_termwise(simplify_complex, x) + elseif isdiv(x) + return _apply_termwise(simplify_complex, x) + else + return x end end @@ -68,7 +92,10 @@ function substitute_all(x::Complex{Num}, rules::Collections) return substitute_all(x.re, rules) + im * substitute_all(x.im, rules) end -get_independent(x::Num, t::Num)::Num = wrap(get_independent(x.val, t)) +function get_independent(x::Num, t::Num) + result = get_independent(x.val, t) + return result isa Complex{Num} ? result : wrap(result) +end function get_independent(x::Complex{Num}, t::Num) return get_independent(x.re, t) + im * get_independent(x.im, t) end @@ -76,14 +103,21 @@ get_independent(v::Vector{Num}, t::Num) = [get_independent(el, t) for el in v] get_independent(x, t::Num) = x function get_independent(x::BasicSymbolic, t::Num) - @compactified x::BasicSymbolic begin - Add => sum([get_independent(arg, t) for arg in arguments(x)]) - Mul => prod([get_independent(arg, t) for arg in arguments(x)]) - Div => !is_function(x.den, t) ? get_independent(x.num, t) / x.den : 0 - Pow => !is_function(x.base, t) && !is_function(x.exp, t) ? x : 0 - Term => !is_function(x, t) ? x : 0 - Sym => !is_function(x, t) ? x : 0 - _ => x + if isadd(x) + return sum([get_independent(arg, t) for arg in arguments(x)]) + elseif ismul(x) + return prod([get_independent(arg, t) for arg in arguments(x)]) + elseif isdiv(x) + return !is_function(x.den, t) ? get_independent(x.num, t) / x.den : 0 + elseif ispow(x) + base, exponent = arguments(x) + return !is_function(base, t) && !is_function(exponent, t) ? x : 0 + elseif isterm(x) + return !is_function(x, t) ? x : 0 + elseif issym(x) + return !is_function(x, t) ? x : 0 + else + return x end end @@ -94,11 +128,14 @@ function get_all_terms(x::Equation) return unique(cat(get_all_terms(Num(x.lhs)), get_all_terms(Num(x.rhs)); dims=1)) end function _get_all_terms(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Add => vcat([_get_all_terms(term) for term in SymbolicUtils.arguments(x)]...) - Mul => SymbolicUtils.arguments(x) - Div => [_get_all_terms(x.num)..., _get_all_terms(x.den)...] - _ => [x] + if isadd(x) + return vcat([_get_all_terms(term) for term in SymbolicUtils.sorted_arguments(x)]...) + elseif ismul(x) + return SymbolicUtils.sorted_arguments(x) + elseif isdiv(x) + return [_get_all_terms(x.num)..., _get_all_terms(x.den)...] + else + return [x] end end _get_all_terms(x) = x @@ -109,11 +146,11 @@ function is_harmonic(x::Num, t::Num)::Bool isempty(t_terms) && return true trigs = is_trig.(t_terms) - if !prod(trigs) + if !all(trigs) return false else - powers = [max_power(first(term.val.arguments), t) for term in t_terms[trigs]] - return all(isone, powers) + powers = [max_power(first(arguments(term.val)), t) for term in t_terms[trigs]] + return all(isequal(1), powers) end end @@ -126,10 +163,19 @@ is_function(f, var) = any(isequal.(get_variables(f), var)) """ Counts the number of derivatives of a symbolic variable. """ -function count_derivatives(x::Symbolics.BasicSymbolic) - (Symbolics.isterm(x) || Symbolics.issym(x)) || +function count_derivatives(x::BasicSymbolic) + if Symbolics.is_derivative(x) + arg = first(arguments(x)) + ( + issym(arg) || + Symbolics.is_derivative(arg) || + (isterm(arg) && issym(operation(arg))) + ) || error("The input is not a single term or symbol") + D = operation(x) + return D.order + count_derivatives(arg) + end + (issym(x) || (isterm(x) && issym(operation(x)))) || error("The input is not a single term or symbol") - bool = Symbolics.is_derivative(x) - return bool ? 1 + count_derivatives(first(arguments(x))) : 0 + return 0 end count_derivatives(x::Num) = count_derivatives(Symbolics.unwrap(x)) diff --git a/src/Symbolics/drop_powers.jl b/src/Symbolics/drop_powers.jl index 3c88cfa..ad8409e 100644 --- a/src/Symbolics/drop_powers.jl +++ b/src/Symbolics/drop_powers.jl @@ -17,10 +17,16 @@ function drop_powers(expr::Num, vars::Vector{Num}, deg::Int) Symbolics.@variables ϵ subs_expr = deepcopy(expr) rules = Dict([var => ϵ * var for var in unique(vars)]) - subs_expr = Symbolics.expand(substitute_all(subs_expr, rules)) + subs_expr = Symbolics.expand( + substitute_all(subs_expr, rules; include_derivatives=false) + ) max_deg = max_power(subs_expr, ϵ) removal = Dict([ϵ^d => Num(0) for d in deg:max_deg]) - res = substitute_all(substitute_all(subs_expr, removal), Dict(ϵ => Num(1))) + res = substitute_all( + substitute_all(subs_expr, removal; include_derivatives=false), + Dict(ϵ => Num(1)); + include_derivatives=false, + ) return Symbolics.expand(res) end @@ -30,7 +36,7 @@ end # calls the above for various types of the first argument function drop_powers(eq::Equation, var::Vector{Num}, deg::Int) - return drop_powers(eq.lhs, var, deg) .~ drop_powers(eq.lhs, var, deg) + return drop_powers(eq.lhs, var, deg) .~ drop_powers(eq.rhs, var, deg) end function drop_powers(eqs::Vector{Equation}, var::Vector{Num}, deg::Int) return [ @@ -45,7 +51,14 @@ drop_powers(x, vars, deg::Int) = drop_powers(wrap(x), vars, deg) function max_power(x::Num, y::Num) terms = get_all_terms(x) powers = power_of.(terms, y) - return maximum(powers) + literal_powers = Int[] + for p in powers + pv = SymbolicUtils.unwrap_const(p) + pv isa Number || continue + push!(literal_powers, Int(pv)) + end + isempty(literal_powers) && return 0 + return maximum(literal_powers) end max_power(x::Vector{Num}, y::Num) = maximum(max_power.(x, y)) @@ -60,7 +73,8 @@ end function power_of(x::BasicSymbolic, y::BasicSymbolic) if ispow(x) && issym(y) - return isequal(x.base, y) ? x.exp : 0 + base, exponent = arguments(x) + return isequal(base, y) ? exponent : 0 elseif issym(x) && issym(y) return isequal(x, y) ? 1 : 0 else diff --git a/src/Symbolics/exponentials.jl b/src/Symbolics/exponentials.jl index 74423ec..e5265f2 100644 --- a/src/Symbolics/exponentials.jl +++ b/src/Symbolics/exponentials.jl @@ -2,14 +2,22 @@ expand_exp_power(expr::Num) = expand_exp_power(expr.val) simplify_exp_products(x::Num) = simplify_exp_products(x.val) "Returns true if expr is an exponential" -isexp(expr) = isterm(expr) && expr.f == exp +isexp(expr) = isterm(expr) && operation(expr) === exp "Expand powers of exponential such that exp(x)^n => exp(x*n) " function expand_exp_power(expr::BasicSymbolic) - @compactified expr::BasicSymbolic begin - Add => sum([expand_exp_power(arg) for arg in arguments(expr)]) - Mul => prod([expand_exp_power(arg) for arg in arguments(expr)]) - _ => ispow(expr) && isexp(expr.base) ? exp(expr.base.arguments[1] * expr.exp) : expr + if isadd(expr) + return sum([expand_exp_power(arg) for arg in arguments(expr)]) + elseif ismul(expr) + return prod([expand_exp_power(arg) for arg in arguments(expr)]) + else + if ispow(expr) + base, exponent = arguments(expr) + if isexp(base) + return exp(arguments(base)[1] * exponent) + end + end + return expr end end expand_exp_power(expr) = expr @@ -17,11 +25,14 @@ expand_exp_power(expr) = expr "Simplify products of exponentials such that exp(a)*exp(b) => exp(a+b) This is included in SymbolicUtils as of 17.0 but the method here avoid other simplify calls" function simplify_exp_products(expr::BasicSymbolic) - @compactified expr::BasicSymbolic begin - Add => _apply_termwise(simplify_exp_products, expr) - Div => _apply_termwise(simplify_exp_products, expr) - Mul => simplify_exp_products_mul(expr) - _ => expr + if isadd(expr) + return _apply_termwise(simplify_exp_products, expr) + elseif isdiv(expr) + return _apply_termwise(simplify_exp_products, expr) + elseif ismul(expr) + return simplify_exp_products_mul(expr) + else + return expr end end function simplify_exp_products(x::Complex{Num}) @@ -32,10 +43,10 @@ function simplify_exp_products_mul(expr) rest_ind = setdiff(1:length(arguments(expr)), ind) rest = isempty(rest_ind) ? 1 : prod(arguments(expr)[rest_ind]) total = isempty(ind) ? 0 : sum(getindex.(arguments.(arguments(expr)[ind]), 1)) - if SymbolicUtils.is_literal_number(total) - (total == 0 && return rest) - else - return rest * exp(total) + total_val = SymbolicUtils.unwrap_const(total) + if total_val isa Number + return iszero(total_val) ? rest : rest * exp(total) end + return rest * exp(total) end simplify_exp_products(x) = x diff --git a/src/Symbolics/fourier.jl b/src/Symbolics/fourier.jl index 02d3e59..cd2a1f2 100644 --- a/src/Symbolics/fourier.jl +++ b/src/Symbolics/fourier.jl @@ -13,14 +13,6 @@ This function performs the following steps: Returns the simplified expression as a `Num` type. """ - -""" - is_trig(f::Num) - -Check if the given expression `f` is a trigonometric function (sine or cosine). - -Returns `true` if `f` is either `sin` or `cos`, `false` otherwise. -""" function trig_reduce(x) x = add_div(x) # a/b + c/d = (ad + bc)/bd x = expand(x) # open all brackets @@ -28,17 +20,67 @@ function trig_reduce(x) x = expand_all(x) # expand products of exponentials x = simplify_exp_products(x) # simplify products of exps x = exp_to_trig(x) + x = _trig_expand_products(x) x = Num(simplify_complex(expand(x))) return x # simplify_fractions(x)# (a*c^2 + b*c)/c^2 = (a*c + b)/c end -"Return true if `f` is a sin or cos." +const _rw_trig_mul_to_sum = SymbolicUtils.Rewriters.Chain([ + SymbolicUtils.@rule(cos(~x) * cos(~y) => (cos(~x - ~y) + cos(~x + ~y)) / 2), + SymbolicUtils.@rule(sin(~x) * sin(~y) => (cos(~x - ~y) - cos(~x + ~y)) / 2), + SymbolicUtils.@rule(sin(~x) * cos(~y) => (sin(~x + ~y) + sin(~x - ~y)) / 2), + SymbolicUtils.@rule(cos(~x) * sin(~y) => (sin(~x + ~y) - sin(~x - ~y)) / 2), +]) + +const _rw_trig_expand = SymbolicUtils.Rewriters.Fixpoint( + SymbolicUtils.Rewriters.Postwalk( + SymbolicUtils.Rewriters.Chain([ + SymbolicUtils.@rule((cos(~x))^2 => (1 + cos(2 * ~x)) / 2), + SymbolicUtils.@rule((sin(~x))^2 => (1 - cos(2 * ~x)) / 2), + _rw_trig_mul_to_sum, + ]), + ), +) + +function _trig_expand_products(x::BasicSymbolic) + # Expand trig products/powers into sums so `get_independent` can isolate constants. + y = Postwalk( + ex -> begin + if ismul(ex) + # In SymbolicUtils v4, `arguments(ismul(...))` includes the numeric coefficient + # even though `ex.coeff` also stores it. Avoid double-counting it. + coeff = ex.coeff + factors = BasicSymbolic[ + f for f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number) + ] + rest = isempty(factors) ? 1 : prod(factors; init=1) + return coeff * _rw_trig_expand(rest) + end + return _rw_trig_expand(ex) + end, + )( + x + ) + return SymbolicUtils.expand(y) +end +_trig_expand_products(x::Num) = wrap(_trig_expand_products(unwrap(x))) +_trig_expand_products(x) = x + +""" + is_trig(f::Num) + +Check if the given expression `f` is a trigonometric function (sine or cosine). + +Returns `true` if `f` is either `sin` or `cos`, `false` otherwise. +""" is_trig(f::Num) = is_trig(f.val) is_trig(f) = false function is_trig(f::BasicSymbolic) - f = ispow(f) ? f.base : f - isterm(f) && SymbolicUtils.operation(f) ∈ [cos, sin] && return true - return false + if ispow(f) + base, _ = arguments(f) + f = base + end + return isterm(f) && SymbolicUtils.operation(f) ∈ [cos, sin] end """ @@ -52,6 +94,9 @@ Used in Fourier analysis to find the cosine components of a periodic function. - `ω`: The angular frequency - `t`: The time variable """ +function fourier_cos_term(x, ω, t) + return _fourier_term(x, ω, t, cos) +end """ add_div(x) @@ -61,12 +106,8 @@ Transforms expressions of the form a/b + c/d into (ad + bc)/bd. Returns the simplified fraction as a `Num` type. """ -function fourier_cos_term(x, ω, t) - return _fourier_term(x, ω, t, cos) -end - -"Simplify fraction a/b + c/d = (ad + bc)/bd" -add_div(x) = wrap(Postwalk(add_with_div; maketerm=frac_maketerm)(unwrap(x))) +# Simplify fraction a/b + c/d = (ad + bc)/bd +add_div(x) = wrap(Postwalk(add_with_div)(unwrap(x))) """ fourier_sin_term(x, ω, t) @@ -101,16 +142,138 @@ function _fourier_term(x::Equation, ω, t, f) return Equation(_fourier_term(x.lhs, ω, t, f), _fourier_term(x.rhs, ω, t, f)) end +_real_if_complex(v) = v isa Complex{Num} ? v.re : v + +_canonicalize(ft) = _real_if_complex(Symbolics.simplify(Symbolics.expand(ft))) + +function _cleanup_fourier_term(ft) + ft = _real_if_complex(ft) + ft = _strip_real_imag(ft) + ft = _canonicalize(ft) + ft = _simplify_trig_zero(Num(ft)) + ft = simplify_complex(Symbolics.expand(ft)) + ft = _real_if_complex(ft) + ft = _strip_real_imag(ft) + ft = _normalize_trig_signs(unwrap(ft)) + ft = _strip_zero_imag_literals(wrap(ft)) + ft = _canonicalize(ft) + ft = _real_if_complex(ft) + return ft +end + "Return the coefficient of f(ωt) in `x` where `f` is a cos or sin." function _fourier_term(x, ω, t, f) term = x * f(ω * t) term = trig_reduce(term) indep = get_independent(term, t) - ft = Num(simplify_complex(Symbolics.expand(indep))) + ft = simplify_complex(Symbolics.expand(indep)) ft = !isequal(ω, 0) ? 2 * ft : ft # extra factor in case ω = 0 ! - return Symbolics.expand(ft) + ft = _cleanup_fourier_term(ft) + return ft isa Num ? ft : Num(ft) +end + +function _strip_real_imag(x::Complex{Num}) + return _strip_real_imag(x.re) + im * _strip_real_imag(x.im) +end + +_postwalk(f, x::BasicSymbolic) = Postwalk(f)(x) +_postwalk(f, x::Num) = wrap(_postwalk(f, unwrap(x))) +_postwalk(f, x::Complex{Num}) = _postwalk(f, x.re) + im * _postwalk(f, x.im) +_postwalk(f, x) = x + +function _strip_zero_imag_literals(x) + _postwalk(ex -> begin + v = SymbolicUtils.unwrap_const(ex) + return (v isa Complex && iszero(imag(v))) ? real(v) : ex + end, x) +end + +function _strip_real_imag(x::BasicSymbolic) + function _real_of(ex::BasicSymbolic) + if isadd(ex) + return sum(_real_of.(arguments(ex))) + elseif ismul(ex) + coeff_val = SymbolicUtils.unwrap_const(ex.coeff) + if coeff_val isa Number + r = real(coeff_val) + rest = prod( + ( + f for + f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number) + ); + init=1, + ) + return iszero(r) ? 0 : r * rest + end + return ex + elseif isdiv(ex) + return _real_of(ex.num) / ex.den + else + v = SymbolicUtils.unwrap_const(ex) + return v isa Number ? real(v) : ex + end + end + + function _imag_of(ex::BasicSymbolic) + if isadd(ex) + return sum(_imag_of.(arguments(ex))) + elseif ismul(ex) + coeff_val = SymbolicUtils.unwrap_const(ex.coeff) + if coeff_val isa Number + i = imag(coeff_val) + rest = prod( + ( + f for + f in arguments(ex) if !(SymbolicUtils.unwrap_const(f) isa Number) + ); + init=1, + ) + return iszero(i) ? 0 : i * rest + end + return 0 + elseif isdiv(ex) + return _imag_of(ex.num) / ex.den + else + v = SymbolicUtils.unwrap_const(ex) + return v isa Number ? imag(v) : 0 + end + end + + return Postwalk(ex -> begin + if isterm(ex) + op = operation(ex) + if op === real || nameof(op) == :real + return _real_of(first(arguments(ex))) + elseif op === imag || nameof(op) == :imag + return _imag_of(first(arguments(ex))) + end + end + return ex + end)(x) end +function _strip_real_imag(x::Num) + return wrap(_strip_real_imag(unwrap(x))) +end + +function _simplify_trig_zero(x::BasicSymbolic) + return Postwalk(ex -> begin + if isterm(ex) + op = operation(ex) + if op === sin || op === cos + arg = first(arguments(ex)) + arg_val = SymbolicUtils.unwrap_const(arg) + if arg_val isa Number && iszero(arg_val) + return op === sin ? 0 : 1 + end + end + end + return ex + end)(x) +end + +_simplify_trig_zero(x::Num) = wrap(_simplify_trig_zero(unwrap(x))) + """ trig_to_exp(x::Num) @@ -120,35 +283,9 @@ using Euler's formula: ``\\exp(ix) = \\cos(x) + i*\\sin(x)``. Returns the converted expression as a `Num` type. """ function trig_to_exp(x::Num) - all_terms = get_all_terms(x) - trigs = filter(z -> is_trig(z), all_terms) - - rules = [] - for trig in trigs - is_pow = ispow(trig.val) # trig is either a trig or a power of a trig - power = is_pow ? trig.val.exp : 1 - arg = is_pow ? arguments(trig.val.base)[1] : arguments(trig.val)[1] - type = is_pow ? operation(trig.val.base) : operation(trig.val) - - if type == cos - term = Complex{Num}((exp(im * arg) + exp(-im * arg))^power * (1//2)^power, 0) - elseif type == sin - term = - (1 * im^power) * - Complex{Num}(((exp(-im * arg) - exp(im * arg)))^power * (1//2)^power, 0) - end - # avoid Complex{Num} where possible as this causes bugs - # instead, the Nums store SymbolicUtils Complex types - term = Num(Symbolics.expand(term.re.val + im * term.im.val)) - append!(rules, [trig => term]) - end - - result = Symbolics.substitute(x, Dict(rules)) - return convert_to_Num(result) + return Num(trig_to_exp(x.val)) end trig_to_exp(x::Complex{Num}) = trig_to_exp(x.re) + im * trig_to_exp(x.im) -convert_to_Num(x::Complex{Num})::Num = Num(first(x.re.val.arguments)) -convert_to_Num(x::Num)::Num = x """ trig_to_exp(x::BasicSymbolic) @@ -163,9 +300,16 @@ function trig_to_exp(x::BasicSymbolic) rules = [] for trig in trigs is_pow = ispow(trig) # trig is either a trig or a power of a trig - power = is_pow ? trig.exp : 1 - arg = is_pow ? arguments(trig.base)[1] : arguments(trig)[1] - type = is_pow ? operation(trig.base) : operation(trig) + if is_pow + base, exponent = arguments(trig) + power = exponent + arg = arguments(base)[1] + type = operation(base) + else + power = 1 + arg = arguments(trig)[1] + type = operation(trig) + end if type == cos term = (exp(im * arg) + exp(-im * arg))^power * (1 // 2)^power @@ -197,9 +341,9 @@ trigonometric arguments for consistent simplification. """ function exp_to_trig(x::BasicSymbolic) if isadd(x) || isdiv(x) || ismul(x) - return _apply_termwise(exp_to_trig, x) - elseif isterm(x) && x.f == exp - arg = first(x.arguments) + return Symbolics.simplify(_normalize_trig_signs(_apply_termwise(exp_to_trig, x))) + elseif isterm(x) && operation(x) === exp + arg = first(arguments(x)) trigarg = Symbolics.expand(-im * arg) # the argument of the to-be trig function trigarg = simplify_complex(trigarg) @@ -218,16 +362,56 @@ function exp_to_trig(x::BasicSymbolic) cos(trigarg) + im * sin(trigarg) end end - return if ismul(trigarg) && trigarg.coeff < 0 - cos(-trigarg) - im * sin(-trigarg) - else - cos(trigarg) + im * sin(trigarg) + if ismul(trigarg) + coeff = trigarg.coeff + coeff_val = SymbolicUtils.unwrap_const(coeff) + if coeff_val isa Real && coeff_val < 0 + return cos(-trigarg) - im * sin(-trigarg) + end end + return _normalize_trig_signs(cos(trigarg) + im * sin(trigarg)) else return x end end +function _normalize_trig_signs(x::BasicSymbolic) + if isadd(x) || ismul(x) + args = SymbolicUtils.sorted_arguments(x) + return (isadd(x) ? sum : prod)(_normalize_trig_signs.(args)) + elseif isdiv(x) + return _normalize_trig_signs(x.num) / _normalize_trig_signs(x.den) + elseif isterm(x) + op = operation(x) + if op === real || op === imag || nameof(op) == :real || nameof(op) == :imag + arg = first(arguments(x)) + return op(_normalize_trig_signs(arg)) + elseif op === sin || op === cos + arg = first(arguments(x)) + if SymbolicUtils.isnegative(arg) + new_arg = -arg + return if op === sin + -SymbolicUtils.term(sin, new_arg) + else + SymbolicUtils.term(cos, new_arg) + end + elseif ismul(arg) + coeff_val = SymbolicUtils.unwrap_const(arg.coeff) + # Handle coefficients that are complex with zero imaginary part, e.g. (-2 + 0im)θ. + if coeff_val isa Number && isreal(coeff_val) && real(coeff_val) < 0 + new_arg = -arg + return if op === sin + -SymbolicUtils.term(sin, new_arg) + else + SymbolicUtils.term(cos, new_arg) + end + end + end + end + end + return x +end + exp_to_trig(x) = x exp_to_trig(x::Num) = exp_to_trig(x.val) exp_to_trig(x::Complex{Num}) = exp_to_trig(x.re) + im * exp_to_trig(x.im) diff --git a/src/Variables.jl b/src/Variables.jl index 4b04e2f..099857f 100644 --- a/src/Variables.jl +++ b/src/Variables.jl @@ -24,8 +24,8 @@ end "Return the name of a variable (excluding independent variables)" function var_name(x::Num)::String - var = Symbolics._toexpr(x) - var = var isa Expr ? String(var.args[1]) : String(var) + var = string(x) + var = replace(var, r"\(.*\)$" => "") return String(replace(var, r"\\mathtt\{([^}]*)\}" => s"\1")) # ^ remove "\\mathtt{}" from the variable name coming from Symbolics # since Symbolics v6.14.1 (Symbolics#1305) diff --git a/src/utils.jl b/src/utils.jl index fe063ea..9bbc5d1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,6 +4,17 @@ end flatten(a) = collect(Iterators.flatten(a)) +"""Strip Symbolics derivative wrappers to recover the base dependent variable.""" +strip_derivative(x::Num) = wrap(strip_derivative(unwrap(x))) +function strip_derivative(x::BasicSymbolic) + y = x + while Symbolics.is_derivative(y) + y = first(arguments(y)) + end + return y +end +strip_derivative(x) = x + "Show fields of an object." function show_fields(object) for field in fieldnames(typeof(object)) # display every field @@ -12,13 +23,53 @@ function show_fields(object) end end +_is_symbolic_like(x) = x isa Num || x isa BasicSymbolic + +function _eqtest_symbolic_scalar(a, b) + diff = Symbolics.simplify(Symbolics.expand(a - b)) + return isequal(diff, 0) +end + +function _eqtest_equal(a::Complex, b::Complex) + _eqtest_equal(real(a), real(b)) && _eqtest_equal(imag(a), imag(b)) +end +_eqtest_equal(a::Complex, b) = _eqtest_equal(real(a), b) && _eqtest_equal(imag(a), 0) +_eqtest_equal(a, b::Complex) = _eqtest_equal(a, real(b)) && _eqtest_equal(0, imag(b)) + +function _eqtest_equal(a::AbstractArray, b::AbstractArray) + size(a) == size(b) || return false + return all(i -> _eqtest_equal(a[i], b[i]), eachindex(a, b)) +end + +function _eqtest_equal(a::Tuple, b::Tuple) + length(a) == length(b) || return false + return all(i -> _eqtest_equal(a[i], b[i]), eachindex(a)) +end + +function _eqtest_equal(a::Equation, b::Equation) + _eqtest_equal(a.lhs, b.lhs) && _eqtest_equal(a.rhs, b.rhs) +end + +function _eqtest_equal(a, b) + isequal(a, b) && return true + (_is_symbolic_like(a) || _is_symbolic_like(b)) || return false + return try + _eqtest_symbolic_scalar(a, b) + catch + false + end +end + +_eqtest_notequal(a, b) = !_eqtest_equal(a, b) + macro eqtest(expr) @assert expr.head == :call && expr.args[1] in [:(==), :(!=)] + qb = QuoteNode(QuestBase) return esc( if expr.args[1] == :(==) - :(@test isequal($(expr.args[2]), $(expr.args[3]))) + :(@test $qb._eqtest_equal($(expr.args[2]), $(expr.args[3]))) else - :(@test !isequal($(expr.args[2]), $(expr.args[3]))) + :(@test $qb._eqtest_notequal($(expr.args[2]), $(expr.args[3]))) end, ) end @@ -35,6 +86,6 @@ macro eqsym(expr) end is_identity(A::Matrix{Num}) = (@eqsym A == Matrix{Num}(LinearAlgebra.I, size(A)...)) -hasnan(x::Matrix{Num}) = any(my_isnan, unwrap.(x)) +hasnan(x::Matrix{Num}) = any(my_isnan, Symbolics.value.(x)) my_isnan(x) = isnan(x) my_isnan(x::BasicSymbolic) = false diff --git a/test/DifferentialEquations.jl b/test/DifferentialEquations.jl index 79b85bb..55559b1 100644 --- a/test/DifferentialEquations.jl +++ b/test/DifferentialEquations.jl @@ -41,7 +41,9 @@ using QuestBase: expr = d(x, t, 2) + ω0^2 * x diff_eq3 = DifferentialEquation(expr, x) @test length(diff_eq3.equations) == 1 - @test diff_eq3.equations[x].rhs == 0 + rhs_simplified = Symbolics.simplify(diff_eq3.equations[x].rhs) + rhs_val = Symbolics.value(rhs_simplified) + @test rhs_val isa Number && iszero(rhs_val) # Test empty constructor diff = DifferentialEquation() diff --git a/test/HarmonicEquation.jl b/test/HarmonicEquation.jl index 9bff61f..4cfb7e0 100644 --- a/test/HarmonicEquation.jl +++ b/test/HarmonicEquation.jl @@ -22,7 +22,7 @@ using QuestBase: # Setup common test variables @variables t, T -@variables x(t) y(t) u(T) v(T) +@variables x(t) y(t) u(T) v(T) w(T) D = Differential(T) # Create simple test equation @@ -78,6 +78,8 @@ end rules = Dict(u => a) subbed = substitute_all(heq, rules) @test !isequal(subbed.equations, heq.equations) + @test !isequal(subbed.variables, heq.variables) + @eqtest subbed.variables[1].symbol == a end @testset "Utility functions" begin @@ -96,3 +98,18 @@ end list = unique(filter(x -> !(x isa Real), Symbolics.unwrap.(reduce(vcat, list)))) @test all(map(x -> !hasproperty(x, :arguments), list)) end + +@testset "is_rearranged MF path" begin + # No derivative terms on either side => treated as arranged by construction (MF_bool). + eq_alg1 = u ~ u + v + eq_alg2 = v ~ u + heq_alg = HarmonicEquation([eq_alg1, eq_alg2], [hv1, hv2], nat_eq) + @test is_rearranged(heq_alg) +end + +@testset "Hopf variable filtered from ansatz" begin + hv_hopf = HarmonicVariable(w, "test", "Hopf", Num(1.0), x) + heq_hopf = HarmonicEquation([eq1, eq2], [hv1, hv2, hv_hopf], nat_eq) + ans = QuestBase._show_ansatz(heq_hopf) + @test !occursin(string(hv_hopf.symbol), ans) +end diff --git a/test/runtests.jl b/test/runtests.jl index 284cf5f..8ec9a27 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,10 @@ end include("symbolics.jl") end +@testset "Utils" begin + include("utils.jl") +end + @testset "DifferentialEquations" begin include("DifferentialEquations.jl") end diff --git a/test/symbolics.jl b/test/symbolics.jl index 56c1aa9..4f24a96 100644 --- a/test/symbolics.jl +++ b/test/symbolics.jl @@ -1,8 +1,8 @@ using Test using Symbolics -using SymbolicUtils: Fixpoint, Prewalk, PassThrough, BasicSymbolic +using SymbolicUtils: BasicSymbolic -using QuestBase: @eqtest +using QuestBase: @eqtest, trig_reduce @testset "exp(x)^n => exp(x*n)" begin using QuestBase: expand_all, expand_exp_power @@ -12,7 +12,6 @@ using QuestBase: @eqtest @eqtest simplify(exp(a)^3) == exp(3 * a) @eqtest simplify(exp(a)^n) == exp(n * a) @eqtest expand_all(exp(a)^3) == exp(3 * a) - @eqtest expand_all(exp(a)^3) == exp(3 * a) @eqtest expand_all(im * exp(a)^5) == im * exp(5 * a) end @@ -27,7 +26,7 @@ end @testset "euler" begin @variables a b - @eqtest cos(a) + im * sin(a) == exp(im * a) + @test isequal(trig_reduce(cos(a) + im * sin(a) - exp(im * a)), 0) @eqtest exp(a) * cos(b) + im * sin(b) * exp(a) == exp(a + im * b) end @@ -52,29 +51,35 @@ end # eq = drop_powers(a^2 + a ~ b, [a, b], 2) # broken @eqtest [eq.lhs, eq.rhs] == [a, a] eq = drop_powers(a^2 + a + b ~ a, a, 2) - @test string(eq.rhs) == "a" broken = true + @eqtest eq.lhs == a + b + @eqtest eq.rhs == a @eqtest drop_powers([a^2 + a + b, b], a, 2) == [a + b, b] @eqtest drop_powers([a^2 + a + b, b], [a, b], 2) == [a + b, b] + + @testset "Vector{Equation}" begin + eqs = [a^2 + a ~ a, b^3 + b ~ 0] + out = drop_powers(eqs, [a], 2) + @eqtest [out[1].lhs, out[1].rhs] == [a, a] + @eqtest out[2] == eqs[2] + end end @testset "trig_to_exp and trig_to_exp" begin using QuestBase: expand_all, trig_to_exp, exp_to_trig @testset "Num" begin @variables f t - cos_euler(x) = (exp(im * x) + exp(-im * x)) / 2 - sin_euler(x) = (exp(im * x) - exp(-im * x)) / (2 * im) - - # automatic conversion between trig and exp form + # Conversion between trig and exp form. + # We validate by substituting numeric values (robust across Symbolics canonicalization). trigs = [cos(f * t), sin(f * t)] - for (i, trig) in pairs(trigs) + samples = ((1.3, 0.7), (2.0, 0.1), (-0.4, 1.1)) + for trig in trigs z = trig_to_exp(trig) - @eqtest expand(exp_to_trig(z)) == trig - end - trigs′ = [cos_euler(f * t), sin_euler(f * t)] - for (i, trig) in pairs(trigs′) - z = trig_to_exp(trig) - @eqtest expand(exp_to_trig(z)) == trigs[i] + back = exp_to_trig(z) + for (fv, tv) in samples + d = Symbolics.substitute(back - trig, Dict(f => fv, t => tv)) + @test Symbolics.value(d) == 0 + end end end @@ -99,7 +104,6 @@ end @testset "harmonic" begin using QuestBase: is_harmonic - @variables a, b, c, t, x(t), f, y(t) @test is_harmonic(cos(f * t), t) @@ -117,7 +121,7 @@ end using QuestBase: fourier_cos_term, fourier_sin_term using QuestBase.Symbolics: expand - @variables f t θ a b + @variables f t θ a b c @eqtest fourier_cos_term(cos(f * t)^2, f, t) == 0 @eqtest fourier_sin_term(sin(f * t)^2, f, t) == 0 @@ -160,6 +164,40 @@ end @eqtest fourier_cos_term(cos(f * t)^3 + 1, 0, t) == 1 @eqtest fourier_cos_term(cos(f * t)^2 + 1, 0, t) == 3//2 @eqtest fourier_cos_term((cos(f * t)^2 + cos(f * t))^3, 0, t) == 23//16 + + function _check_fourier(term, specs) + for (ω, cos_expected, sin_expected) in specs + @eqtest fourier_cos_term(term, ω, t) == cos_expected + @eqtest fourier_sin_term(term, ω, t) == sin_expected + end + return nothing + end + + # more complex but closed-form cases + for (term, specs) in ( + ( + (a + b * cos(f * t))^2, + ((f, 2 * a * b, 0), (2 * f, b^2 / 2, 0), (0, a^2 + b^2 / 2, 0)), + ), + ( + (a + b * sin(f * t))^2, + ((f, 0, 2 * a * b), (2 * f, -b^2 / 2, 0), (0, a^2 + b^2 / 2, 0)), + ), + ( + (a + b * cos(f * t + θ)) * (a + b * cos(f * t - θ)), + ( + (f, 2 * a * b * cos(θ), 0), + (2 * f, b^2 / 2, 0), + (0, a^2 + b^2 / 2 * cos(2 * θ), 0), + ), + ), + ( + (a + b * cos(f * t))^3, + ((f, 3 * a^2 * b + 3//4 * b^3, 0), (0, a^3 + 3//2 * a * b^2, 0)), + ), + ) + _check_fourier(term, specs) + end end @testset "_apply_termwise" begin @@ -176,24 +214,26 @@ end using QuestBase: simplify_complex @variables a, b, c for z in Complex{Num}[a, a * b, a / b] - @test simplify_complex(z).val isa BasicSymbolic{Real} + @test simplify_complex(z).val isa BasicSymbolic end z = Complex{Num}(1 + 0 * im) - @test simplify_complex(z).val isa Int64 + z_val = SymbolicUtils.unwrap_const(simplify_complex(z).val) + @test z_val isa Number && z_val == 1 end @testset "get_all_terms" begin using QuestBase: get_all_terms @variables a, b, c - @eqtest get_all_terms(a + b + c) == [a, b, c] - @eqtest get_all_terms(a * b * c) == [a, b, c] - @eqtest get_all_terms(a / b) == [a, b] - @eqtest get_all_terms(a^2 + b^2 + c^2) == [b^2, a^2, c^2] - @eqtest get_all_terms(a^2 / b^2) == [a^2, b^2] - @eqtest get_all_terms(2 * b^2) == [2, b^2] - @eqtest get_all_terms(2 * b^2 ~ a) == [2, b^2, a] + @eqtest sort(get_all_terms(a + b + c); by=string) == sort([a, b, c]; by=string) + @eqtest sort(get_all_terms(a * b * c); by=string) == sort([a, b, c]; by=string) + @eqtest sort(get_all_terms(a / b); by=string) == sort([a, b]; by=string) + @eqtest sort(get_all_terms(a^2 + b^2 + c^2); by=string) == + sort([a^2, b^2, c^2]; by=string) + @eqtest sort(get_all_terms(a^2 / b^2); by=string) == sort([a^2, b^2]; by=string) + @eqtest sort(get_all_terms(2 * b^2); by=string) == sort([2, b^2]; by=string) + @eqtest sort(get_all_terms(2 * b^2 ~ a); by=string) == sort([2, b^2, a]; by=string) end @testset "get_independent" begin @@ -218,8 +258,9 @@ end @variables a, b, c, d @eqtest expand_fraction((a + b) / c) == a / c + b / c - @test string.(expand_fraction(d * (a + b) / c)) == "d*a /c + d*b / c + d / c" broken = - true + lhs = Symbolics.expand(expand_fraction(d * (a + b) / c)) + rhs = d * a / c + d * b / c + @eqtest lhs == rhs end @testset "count_derivatives" begin using QuestBase: count_derivatives, d @@ -243,4 +284,24 @@ end @eqtest substitute_all(a * b * c * d * e * f * g * h, rules) == b^2 * d^2 * f^2 * h^2 @eqtest substitute_all([a, c, e], rules) == [b, d, f] @eqtest substitute_all(a + b * im, rules) == b + b * im + + @testset "include_derivatives rewrites Differential(var)" begin + @variables t T x(t) + D = Differential(t) + expr = D(x) + + # Symbolics substitution currently rewrites arguments but does not rewrite the + # derivative operator itself (i.e., the `operation` remains `Differential(t)`). + expected_arg = Symbolics.unwrap(Symbolics.substitute(x, Dict(t => T))) + + out = substitute_all(expr, Dict(t => T); include_derivatives=true) + out_bs = Symbolics.unwrap(out) + @test Symbolics.operation(out_bs) == Differential(t) + @test isequal(first(Symbolics.arguments(out_bs)), expected_arg) + + out_no = substitute_all(expr, Dict(t => T); include_derivatives=false) + out_no_bs = Symbolics.unwrap(out_no) + @test Symbolics.operation(out_no_bs) == Differential(t) + @test isequal(first(Symbolics.arguments(out_no_bs)), expected_arg) + end end diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 0000000..7b2e5b7 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,46 @@ +using Test +using Symbolics +using LinearAlgebra + +using QuestBase: strip_derivative, dummy_symbolic_Jacobian, is_identity, hasnan, d, @eqtest + +@testset "strip_derivative" begin + @variables t x(t) y(t) + + @eqtest strip_derivative(x) == x + @eqtest strip_derivative(d(x, t)) == x + @eqtest strip_derivative(d(d(x, t), t)) == x + + # should leave non-derivative expressions untouched + @eqtest strip_derivative(x + y) == x + y +end + +@testset "eqtest_equal containers and complex" begin + @variables a b c + + @eqtest [a, b] == [a, b] + @test !QuestBase._eqtest_equal([a], [a, b]) + + @eqtest (a, b, c) == (a, b, c) + @test !QuestBase._eqtest_equal((a, b), (a, c)) + + z = a + im * b + @eqtest z == z + @eqtest (a + 0im) == a + @eqtest a == (a + 0im) +end + +@testset "identity and NaN helpers" begin + @variables a + + I2 = Matrix{Num}(LinearAlgebra.I, 2, 2) + @test is_identity(I2) + + A = Num[1 0; 0 (1 + a)] + @test !is_identity(A) + + J = dummy_symbolic_Jacobian(3) + @test hasnan(J) + @test !hasnan(I2) + @test !is_identity(J) +end