Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "QuestBase"
uuid = "7e80f742-43d6-403d-a9ea-981410111d43"
authors = ["Orjan Ameye <orjan.ameye@hotmail.com>", "Jan Kosata <kosataj@phys.ethz.ch>", "Javier del Pino <jdelpino@phys.ethz.ch>"]
version = "0.3.4"
version = "0.4.0"

[deps]
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -12,16 +12,16 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[compat]
DocStringExtensions = "0.9.4"
SymbolicUtils = "3.25"
Symbolics = "6.34"
SymbolicUtils = "4.13.1"
Symbolics = "7.8.0"
julia = "1.10"
Random = "1.10"
LinearAlgebra = "1.10"
Test = "1.10"
OrderedCollections = "1.8"
Aqua = "0.8.11"
ExplicitImports = "1.11"
JET = "0.9.18, 0.10.0"
JET = "0.9.18, 0.10.0, 0.11"
CheckConcreteStructs = "0.1.0"

[extras]
Expand Down
9 changes: 6 additions & 3 deletions src/DifferentialEquation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ $(TYPEDSIGNATURES)
Return the independent dependent variables of `diff_eom`.
"""
function get_independent_variables(diff_eom::DifferentialEquation)
return Num.(flatten(unique([x.val.arguments for x in keys(diff_eom.equations)])))
return Num.(flatten(unique([arguments(x.val) for x in keys(diff_eom.equations)])))
end

"""
Expand Down Expand Up @@ -163,7 +163,10 @@ corresponding to second-order differential equations.
function is_rearranged_standard(eom::DifferentialEquation, degree=2)
tvar = get_independent_variables(eom)[1]
D = Differential(tvar)^degree
return isequal(getfield.(values(eom.equations), :lhs), D.(get_variables(eom)))
lhs = getfield.(values(eom.equations), :lhs)
rhs = D.(get_variables(eom))
diffs = Symbolics.simplify.(lhs .- rhs)
return all(is_literal_zero, diffs)
end

"""
Expand Down Expand Up @@ -195,7 +198,7 @@ function rearrange!(eom::DifferentialEquation, new_lhs::Vector{Num})
return nothing
end
function get_variables_nums(vars::Vector{Num})
unique(flatten([Num.(get_variables(x)) for x in vars]))
return strip_derivative.(vars)
end # TODO: remove this function or at least better names

"""
Expand Down
20 changes: 15 additions & 5 deletions src/HarmonicEquation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,20 @@ end

"Get the parameters (not time nor variables) of a HarmonicEquation"
function _parameters(eom::HarmonicEquation)
all_symbols = flatten([
cat(get_variables(eq.lhs), get_variables(eq.rhs); dims=1) for eq in eom.equations
])
# subtract the set of independent variables (i.e., time) from all free symbols
return setdiff(all_symbols, get_variables(eom), get_independent_variables(eom))
symbols = Num[]
for eq in eom.equations
vars = union(Symbolics.get_variables(eq.lhs), Symbolics.get_variables(eq.rhs))
vars = sort!(collect(vars); by=string)
for sym in vars
push!(symbols, strip_derivative(Num(sym)))
end
end
vars = Set(get_variables(eom))
indep = Set(get_independent_variables(eom))
params = filter(s -> !(s in vars || s in indep), symbols)
params = unique(params)
sort!(params; by=string)
return params
end

"""
Expand Down Expand Up @@ -107,6 +116,7 @@ Base.show(eom::HarmonicEquation) = show_fields(eom)
function substitute_all(eom::HarmonicEquation, rules::Union{Dict,Pair})::HarmonicEquation
new_eom = deepcopy(eom)
new_eom.equations = expand_derivatives.(substitute_all(eom.equations, rules))
new_eom.variables = substitute_all(eom.variables, rules)
return new_eom
end

Expand Down
2 changes: 1 addition & 1 deletion src/HarmonicVariable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function _show_ansatz(var::HarmonicVariable)
if isempty(var.type)
return string(var.symbol)
end
t = var.natural_variable.val.arguments
t = arguments(var.natural_variable.val)
t = length(t) == 1 ? string(t[1]) : error("more than 1 independent variable")
ω = string(var.ω)
terms = Dict("u" => "*cos(" * ω * t * ")", "v" => "*sin(" * ω * t * ")", "a" => "")
Expand Down
4 changes: 1 addition & 3 deletions src/QuestBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,21 @@ using LinearAlgebra: LinearAlgebra
using SymbolicUtils:
SymbolicUtils,
Postwalk,
Sym,
BasicSymbolic,
unwrap,
isterm,
ispow,
isadd,
isdiv,
ismul,
add_with_div,
frac_maketerm,
issym

using SymbolicUtils.Unityper: @compactified

using Symbolics:
Symbolics,
Num,
unwrap,
wrap,
get_variables,
Equation,
Expand Down
118 changes: 82 additions & 36 deletions src/Symbolics/Symbolics_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
expand_all(x::Num) = Num(expand_all(x.val))
_apply_termwise(f, x::Num) = wrap(_apply_termwise(f, unwrap(x)))

# Symbolics v7 note: `unwrap` now always returns `BasicSymbolic`; use `Symbolics.value`
# when you want the underlying literal number (via `Const`).
_literal_number(x::Num) = (v=Symbolics.value(x); v isa Number ? v : nothing)
function _literal_number(x::BasicSymbolic)
(v=SymbolicUtils.unwrap_const(x); v isa Number ? v : nothing)
end
is_literal_zero(x) = (v=_literal_number(x); v !== nothing && iszero(v))

"Expands using SymbolicUtils.expand and expand_exp_power (changes exp(x)^n to exp(x*n)"
function expand_all(x)
result = Postwalk(expand_exp_power)(SymbolicUtils.expand(x))
Expand All @@ -10,33 +18,49 @@ end
expand_all(x::Complex{Num}) = expand_all(x.re) + im * expand_all(x.im)

function expand_fraction(x::BasicSymbolic)
@compactified x::BasicSymbolic begin
Add => _apply_termwise(expand_fraction, x)
Mul => _apply_termwise(expand_fraction, x)
Div => sum([arg / x.den for arg in arguments(x.num)])
_ => x
if isadd(x)
return _apply_termwise(expand_fraction, x)
elseif ismul(x)
return _apply_termwise(expand_fraction, x)
elseif isdiv(x)
# Only distribute division over addition in the numerator.
# In SymbolicUtils v4, `arguments(x.num)` for a multiplication returns factors,
# so splitting unconditionally would produce incorrect results (e.g. d*(a+b)/c).
num, den = x.num, x.den
if isadd(num)
return sum([expand_fraction(arg) / den for arg in arguments(num)])
end
return expand_fraction(num) / expand_fraction(den)
else
return x
end
end
expand_fraction(x::Num) = Num(expand_fraction(x.val))

"Apply a function f on every member of a sum or a product"
function _apply_termwise(f, x::BasicSymbolic)
@compactified x::BasicSymbolic begin
Add => sum([f(arg) for arg in arguments(x)])
Mul => prod([f(arg) for arg in arguments(x)])
Div => _apply_termwise(f, x.num) / _apply_termwise(f, x.den)
_ => f(x)
if isadd(x)
return sum([f(arg) for arg in arguments(x)])
elseif ismul(x)
return prod([f(arg) for arg in arguments(x)])
elseif isdiv(x)
return _apply_termwise(f, x.num) / _apply_termwise(f, x.den)
else
return f(x)
end
end

simplify_complex(x::Complex) = isequal(x.im, 0) ? x.re : x.re + im * x.im
simplify_complex(x) = x
function simplify_complex(x::BasicSymbolic)
@compactified x::BasicSymbolic begin
Add => _apply_termwise(simplify_complex, x)
Mul => _apply_termwise(simplify_complex, x)
Div => _apply_termwise(simplify_complex, x)
_ => x
if isadd(x)
return _apply_termwise(simplify_complex, x)
elseif ismul(x)
return _apply_termwise(simplify_complex, x)
elseif isdiv(x)
return _apply_termwise(simplify_complex, x)
else
return x
end
end

Expand Down Expand Up @@ -68,22 +92,32 @@ function substitute_all(x::Complex{Num}, rules::Collections)
return substitute_all(x.re, rules) + im * substitute_all(x.im, rules)
end

get_independent(x::Num, t::Num)::Num = wrap(get_independent(x.val, t))
function get_independent(x::Num, t::Num)
result = get_independent(x.val, t)
return result isa Complex{Num} ? result : wrap(result)
end
function get_independent(x::Complex{Num}, t::Num)
return get_independent(x.re, t) + im * get_independent(x.im, t)
end
get_independent(v::Vector{Num}, t::Num) = [get_independent(el, t) for el in v]
get_independent(x, t::Num) = x

function get_independent(x::BasicSymbolic, t::Num)
@compactified x::BasicSymbolic begin
Add => sum([get_independent(arg, t) for arg in arguments(x)])
Mul => prod([get_independent(arg, t) for arg in arguments(x)])
Div => !is_function(x.den, t) ? get_independent(x.num, t) / x.den : 0
Pow => !is_function(x.base, t) && !is_function(x.exp, t) ? x : 0
Term => !is_function(x, t) ? x : 0
Sym => !is_function(x, t) ? x : 0
_ => x
if isadd(x)
return sum([get_independent(arg, t) for arg in arguments(x)])
elseif ismul(x)
return prod([get_independent(arg, t) for arg in arguments(x)])
elseif isdiv(x)
return !is_function(x.den, t) ? get_independent(x.num, t) / x.den : 0
elseif ispow(x)
base, exponent = arguments(x)
return !is_function(base, t) && !is_function(exponent, t) ? x : 0
elseif isterm(x)
return !is_function(x, t) ? x : 0
elseif issym(x)
return !is_function(x, t) ? x : 0
else
return x
end
end

Expand All @@ -94,11 +128,14 @@ function get_all_terms(x::Equation)
return unique(cat(get_all_terms(Num(x.lhs)), get_all_terms(Num(x.rhs)); dims=1))
end
function _get_all_terms(x::BasicSymbolic)
@compactified x::BasicSymbolic begin
Add => vcat([_get_all_terms(term) for term in SymbolicUtils.arguments(x)]...)
Mul => SymbolicUtils.arguments(x)
Div => [_get_all_terms(x.num)..., _get_all_terms(x.den)...]
_ => [x]
if isadd(x)
return vcat([_get_all_terms(term) for term in SymbolicUtils.sorted_arguments(x)]...)
elseif ismul(x)
return SymbolicUtils.sorted_arguments(x)
elseif isdiv(x)
return [_get_all_terms(x.num)..., _get_all_terms(x.den)...]
else
return [x]
end
end
_get_all_terms(x) = x
Expand All @@ -109,11 +146,11 @@ function is_harmonic(x::Num, t::Num)::Bool
isempty(t_terms) && return true
trigs = is_trig.(t_terms)

if !prod(trigs)
if !all(trigs)
return false
else
powers = [max_power(first(term.val.arguments), t) for term in t_terms[trigs]]
return all(isone, powers)
powers = [max_power(first(arguments(term.val)), t) for term in t_terms[trigs]]
return all(isequal(1), powers)
end
end

Expand All @@ -126,10 +163,19 @@ is_function(f, var) = any(isequal.(get_variables(f), var))
"""
Counts the number of derivatives of a symbolic variable.
"""
function count_derivatives(x::Symbolics.BasicSymbolic)
(Symbolics.isterm(x) || Symbolics.issym(x)) ||
function count_derivatives(x::BasicSymbolic)
if Symbolics.is_derivative(x)
arg = first(arguments(x))
(
issym(arg) ||
Symbolics.is_derivative(arg) ||
(isterm(arg) && issym(operation(arg)))
) || error("The input is not a single term or symbol")
D = operation(x)
return D.order + count_derivatives(arg)
end
(issym(x) || (isterm(x) && issym(operation(x)))) ||
error("The input is not a single term or symbol")
bool = Symbolics.is_derivative(x)
return bool ? 1 + count_derivatives(first(arguments(x))) : 0
return 0
end
count_derivatives(x::Num) = count_derivatives(Symbolics.unwrap(x))
24 changes: 19 additions & 5 deletions src/Symbolics/drop_powers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@ function drop_powers(expr::Num, vars::Vector{Num}, deg::Int)
Symbolics.@variables ϵ
subs_expr = deepcopy(expr)
rules = Dict([var => ϵ * var for var in unique(vars)])
subs_expr = Symbolics.expand(substitute_all(subs_expr, rules))
subs_expr = Symbolics.expand(
substitute_all(subs_expr, rules; include_derivatives=false)
)
max_deg = max_power(subs_expr, ϵ)
removal = Dict([ϵ^d => Num(0) for d in deg:max_deg])
res = substitute_all(substitute_all(subs_expr, removal), Dict(ϵ => Num(1)))
res = substitute_all(
substitute_all(subs_expr, removal; include_derivatives=false),
Dict(ϵ => Num(1));
include_derivatives=false,
)
return Symbolics.expand(res)
end

Expand All @@ -30,7 +36,7 @@ end

# calls the above for various types of the first argument
function drop_powers(eq::Equation, var::Vector{Num}, deg::Int)
return drop_powers(eq.lhs, var, deg) .~ drop_powers(eq.lhs, var, deg)
return drop_powers(eq.lhs, var, deg) .~ drop_powers(eq.rhs, var, deg)
end
function drop_powers(eqs::Vector{Equation}, var::Vector{Num}, deg::Int)
return [
Expand All @@ -45,7 +51,14 @@ drop_powers(x, vars, deg::Int) = drop_powers(wrap(x), vars, deg)
function max_power(x::Num, y::Num)
terms = get_all_terms(x)
powers = power_of.(terms, y)
return maximum(powers)
literal_powers = Int[]
for p in powers
pv = SymbolicUtils.unwrap_const(p)
pv isa Number || continue
push!(literal_powers, Int(pv))
end
isempty(literal_powers) && return 0
return maximum(literal_powers)
end

max_power(x::Vector{Num}, y::Num) = maximum(max_power.(x, y))
Expand All @@ -60,7 +73,8 @@ end

function power_of(x::BasicSymbolic, y::BasicSymbolic)
if ispow(x) && issym(y)
return isequal(x.base, y) ? x.exp : 0
base, exponent = arguments(x)
return isequal(base, y) ? exponent : 0
elseif issym(x) && issym(y)
return isequal(x, y) ? 1 : 0
else
Expand Down
Loading
Loading