Skip to content

Commit 7e1dc44

Browse files
committed
Refactor @init alias system to use compile-time substitution
- Replace runtime alias execution with parse-time substitution via subs() - Enable time-dependent aliases like phi = 2π * t - Support accumulated aliases like a = t; s = a - Handle literal grid specifications after alias expansion - Add strict mode for unrecognized statements - Add comprehensive tests for new functionality - Align @init alias mechanism with @def implementation All 157 tests pass successfully.
1 parent 9304906 commit 7e1dc44

2 files changed

Lines changed: 200 additions & 25 deletions

File tree

src/initial_guess.jl

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ or a time grid, based on whether `arg` matches `time_name(ocp)`.
108108
109109
- `pref::Symbol`: backend module prefix (e.g. `:CTModels`).
110110
- `ocp`: symbolic OCP variable passed from the macro.
111-
- `arg::Symbol`: argument symbol used in the specification (e.g. `:t`, `:s`, `:T`).
111+
- `arg`: argument used in the specification (e.g. `:t`, `:s`, or a literal array after alias expansion).
112112
- `rhs`: right-hand side expression.
113113
- `arg_in_rhs::Bool`: whether `arg` appears in `rhs` (computed at parse-time via `has`).
114114
@@ -119,14 +119,25 @@ or a time grid, based on whether `arg` matches `time_name(ocp)`.
119119
120120
# Notes
121121
122-
When `arg_in_rhs` is `true`, the specification is definitely a time-dependent
123-
function, so we validate that `arg == Symbol(time_name(ocp))` and throw an
124-
error if not. When `arg_in_rhs` is `false`, we generate a runtime conditional
122+
When `arg` is not a Symbol (e.g., a literal array after alias expansion like
123+
`[0.0, 0.5, 1.0]`), it is always treated as a time grid specification.
124+
125+
When `arg` is a Symbol and `arg_in_rhs` is `true`, the specification is definitely
126+
a time-dependent function, so we validate that `arg == Symbol(time_name(ocp))` and
127+
throw an error if not. When `arg_in_rhs` is `false`, we generate a runtime conditional
125128
that checks whether `arg` matches the time name to decide between a constant
126129
function or a time grid.
127130
"""
128131
function __gen_temporal_value(pref, ocp, arg, rhs, arg_in_rhs)
129132
val_sym = __symgen(:init_val)
133+
134+
# Early return: if arg is not a Symbol (e.g., literal array after alias expansion),
135+
# it's always a grid specification
136+
if !(arg isa Symbol)
137+
code = :($val_sym = ($arg, $rhs))
138+
return val_sym, code
139+
end
140+
130141
arg_quoted = QuoteNode(arg)
131142

132143
if arg_in_rhs
@@ -232,7 +243,15 @@ function __log_spec(key, spec)
232243
rhs
233244
end
234245
rhs_str = sprint(Base.show_unquoted, rhs_clean)
235-
return string(key, " = ", arg, " -> ", rhs_str)
246+
247+
# If arg is not a Symbol (e.g., literal array after alias expansion),
248+
# format as grid specification
249+
if arg isa Symbol
250+
return string(key, " = ", arg, " -> ", rhs_str)
251+
else
252+
arg_str = sprint(Base.show_unquoted, arg)
253+
return string(key, " = (", arg_str, ", ", rhs_str, ")")
254+
end
236255
else
237256
return string(key, " = ???")
238257
end
@@ -245,17 +264,21 @@ Internal helper that parses the body of an `@init` block.
245264
246265
The function walks through the expression `ex` and splits it into
247266
248-
- *alias statements*, which are left as ordinary Julia assignments and
249-
executed verbatim inside the generated block;
267+
- *alias statements* of the form `lhs = rhs`, which are stored in a dictionary
268+
and substituted into subsequent statements at parse-time using `subs`;
250269
- *initialisation specifications* of the form `lhs := rhs` or
251270
`lhs(arg) := rhs`, which are converted into structured specification
252-
tuples.
271+
tuples after alias expansion.
253272
254273
For expressions of the form `lhs(arg) := rhs`, this function uses `has(rhs, arg)`
255274
to determine whether `arg` appears in the right-hand side. This information
256275
is stored in the specification tuple and used later to generate appropriate
257276
runtime code that distinguishes time-dependent functions from time grids.
258277
278+
Alias substitution happens before each statement is matched, enabling
279+
time-dependent aliases like `phi = 2π * t` and accumulated aliases like
280+
`a = t; s = a`.
281+
259282
# Arguments
260283
261284
- `ex::Any`: expression or block coming from the body of `@init`.
@@ -264,16 +287,14 @@ runtime code that distinguishes time-dependent functions from time grids.
264287
265288
# Returns
266289
267-
- `alias_stmts::Vector{Expr}`: ordinary statements to execute before
268-
building the initial guess.
269290
- `keys::Vector{Symbol}`: names of the components being initialised
270291
(e.g. `:q`, `:v`, `:u`, `:tf`).
271292
- `specs::Vector{Tuple}`: specification tuples, either `(:constant, rhs)`
272293
for constant values or `(:temporal, arg, rhs, arg_in_rhs)` for temporal
273294
specifications where `arg_in_rhs` indicates whether `arg` appears in `rhs`.
274295
"""
275296
function _collect_init_specs(ex, lnum::Int, line_str::String)
276-
alias_stmts = Expr[] # statements of the form a = ... or other Julia statements
297+
aliases = OrderedCollections.OrderedDict{Union{Symbol,Expr}, Any}()
277298
keys = Symbol[] # keys of the NamedTuple (q, v, x, u, tf, ...)
278299
specs = Tuple[] # specification tuples
279300

@@ -286,20 +307,27 @@ function _collect_init_specs(ex, lnum::Int, line_str::String)
286307
for st in stmts
287308
st isa LineNumberNode && continue
288309

310+
# Substitute all known aliases before matching
311+
for a in Base.keys(aliases)
312+
st = subs(st, a, aliases[a])
313+
end
314+
289315
@match st begin
290-
# Alias / ordinary Julia assignments left as-is
316+
# Alias: store for future substitution
291317
:($lhs = $rhs) => begin
292-
push!(alias_stmts, st)
318+
lhs isa Symbol || error("Unsupported alias left-hand side in @init: $lhs (only Symbol allowed)")
319+
aliases[lhs] = rhs
293320
end
294321

295322
# Forms q(arg) := rhs
296-
# Use has(rhs, arg) to determine if arg appears in rhs
323+
# After alias expansion, arg may be a Symbol or an Expr (literal grid)
297324
:($lhs($arg) := $rhs) => begin
298325
lhs isa Symbol || error("Unsupported left-hand side in @init: $lhs")
299-
arg isa Symbol || error("Unsupported argument in @init: $arg must be a symbol")
300326

301327
# Check if arg appears in rhs using has() from utils.jl
302-
arg_in_rhs = has(rhs, arg)
328+
# Note: if arg is not a Symbol (e.g., after alias expansion to a literal array),
329+
# has() will return false, which is correct for grid specifications
330+
arg_in_rhs = (arg isa Symbol) && has(rhs, arg)
303331

304332
push!(keys, lhs)
305333
push!(specs, (:temporal, arg, rhs, arg_in_rhs))
@@ -312,14 +340,14 @@ function _collect_init_specs(ex, lnum::Int, line_str::String)
312340
push!(specs, (:constant, rhs))
313341
end
314342

315-
# Fallback: any other line is treated as an ordinary Julia statement
343+
# Fallback: strict mode - reject unrecognized statements
316344
_ => begin
317-
push!(alias_stmts, st)
345+
error("Unrecognized statement in @init block: $st. Only alias assignments (a = expr) and specifications (lhs := rhs or lhs(arg) := rhs) are allowed.")
318346
end
319347
end
320348
end
321349

322-
return alias_stmts, keys, specs
350+
return keys, specs
323351
end
324352

325353
"""
@@ -352,24 +380,20 @@ macro level.
352380
initial guess when executed.
353381
"""
354382
function init_fun(ocp, e, lnum::Int, line_str::String)
355-
alias_stmts, keys, specs = _collect_init_specs(e, lnum, line_str)
383+
keys, specs = _collect_init_specs(e, lnum, line_str)
356384
pref = init_prefix()
357385

358386
# If there is no init specification, delegate to build_initial_guess/validate_initial_guess
359387
if isempty(keys)
360-
body_stmts = Any[]
361-
append!(body_stmts, alias_stmts)
362388
build_call = :($pref.build_initial_guess($ocp, ()))
363389
validate_call = :($pref.validate_initial_guess($ocp, $build_call))
364-
push!(body_stmts, validate_call)
365-
code_expr = Expr(:block, body_stmts...)
390+
code_expr = validate_call
366391
log_str = "()"
367392
return log_str, code_expr
368393
end
369394

370395
# Generate runtime code for each specification
371396
body_stmts = Any[]
372-
append!(body_stmts, alias_stmts)
373397

374398
val_syms = Symbol[]
375399
for spec in specs

test/test_initial_guess.jl

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,4 +692,155 @@ function test_initial_guess() # debug
692692
@test occursin("\"s\"", err_msg)
693693
@test occursin(":s", err_msg)
694694
end
695+
696+
@testset "time-dependent alias (phi = 2pi * t)" begin
697+
ocp_circle = @def begin
698+
t [0, 1], time
699+
x R², state
700+
u R², control
701+
x(0) == [0, 0]
702+
x(1) == [1, 1]
703+
(t) == u(t)
704+
(u(t)' * u(t)) min
705+
end
706+
707+
ig = @init ocp_circle begin
708+
phi = 2π * t # Alias depending on time variable
709+
u(t) := [cos(phi), sin(phi)]
710+
end
711+
712+
@test ig isa CTModels.AbstractInitialGuess
713+
CTModels.validate_initial_guess(ocp_circle, ig)
714+
715+
ufun = CTModels.control(ig)
716+
u0 = ufun(0.0)
717+
u1 = ufun(0.5)
718+
719+
@test u0[1] cos(0.0)
720+
@test u0[2] sin(0.0)
721+
@test u1[1] cos(π)
722+
@test u1[2] sin(π) atol=1e-10
723+
end
724+
725+
@testset "time variable substitution (s = t)" begin
726+
ocp_circle = @def begin
727+
t [0, 1], time
728+
x R², state
729+
u R², control
730+
x(0) == [0, 0]
731+
x(1) == [1, 1]
732+
(t) == u(t)
733+
(u(t)' * u(t)) min
734+
end
735+
736+
ig = @init ocp_circle begin
737+
s = t # Alias for time variable
738+
u(s) := [cos(s), sin(s)]
739+
end
740+
741+
@test ig isa CTModels.AbstractInitialGuess
742+
CTModels.validate_initial_guess(ocp_circle, ig)
743+
744+
ufun = CTModels.control(ig)
745+
u0 = ufun(0.0)
746+
u1 = ufun(0.5)
747+
748+
@test u0[1] cos(0.0)
749+
@test u0[2] sin(0.0)
750+
@test u1[1] cos(0.5)
751+
@test u1[2] sin(0.5)
752+
end
753+
754+
@testset "grid aliases (T, X, U as local variables)" begin
755+
ig = @init ocp_fixed begin
756+
T = [0.0, 0.5, 1.0]
757+
X = [[-1.0, 0.0], [0.0, 0.5], [0.0, 0.0]]
758+
U = [0.0, 0.0, 1.0]
759+
x(T) := X
760+
u(T) := U
761+
end
762+
763+
@test ig isa CTModels.AbstractInitialGuess
764+
CTModels.validate_initial_guess(ocp_fixed, ig)
765+
766+
xfun = CTModels.state(ig)
767+
ufun = CTModels.control(ig)
768+
769+
x0 = xfun(0.0)
770+
x1 = xfun(1.0)
771+
u0 = ufun(0.0)
772+
u1 = ufun(1.0)
773+
774+
@test x0[1] -1.0
775+
@test x0[2] 0.0
776+
@test x1[1] 0.0
777+
@test x1[2] 0.0
778+
@test u0 0.0
779+
@test u1 1.0
780+
end
781+
782+
@testset "accumulated aliases (a = t, s = a)" begin
783+
ocp_circle = @def begin
784+
t [0, 1], time
785+
x R², state
786+
u R², control
787+
x(0) == [0, 0]
788+
x(1) == [1, 1]
789+
(t) == u(t)
790+
(u(t)' * u(t)) min
791+
end
792+
793+
ig = @init ocp_circle begin
794+
a = t
795+
s = a
796+
u(s) := [cos(s), sin(s)]
797+
end
798+
799+
@test ig isa CTModels.AbstractInitialGuess
800+
CTModels.validate_initial_guess(ocp_circle, ig)
801+
802+
ufun = CTModels.control(ig)
803+
u0 = ufun(0.0)
804+
u1 = ufun(0.5)
805+
806+
@test u0[1] cos(0.0)
807+
@test u0[2] sin(0.0)
808+
@test u1[1] cos(0.5)
809+
@test u1[2] sin(0.5)
810+
end
811+
812+
@testset "grid aliases with literal arrays" begin
813+
ig = @init ocp_fixed begin
814+
X = [[-1.0, 0.0], [0.0, 0.5], [0.0, 0.0]]
815+
U = [0.0, 0.0, 1.0]
816+
x([0.0, 0.5, 1.0]) := X
817+
u([0.0, 0.5, 1.0]) := U
818+
end
819+
820+
@test ig isa CTModels.AbstractInitialGuess
821+
CTModels.validate_initial_guess(ocp_fixed, ig)
822+
823+
xfun = CTModels.state(ig)
824+
ufun = CTModels.control(ig)
825+
826+
x0 = xfun(0.0)
827+
x1 = xfun(1.0)
828+
u0 = ufun(0.0)
829+
u1 = ufun(1.0)
830+
831+
@test x0[1] -1.0
832+
@test x0[2] 0.0
833+
@test x1[1] 0.0
834+
@test x1[2] 0.0
835+
@test u0 0.0
836+
@test u1 1.0
837+
end
838+
839+
@testset "strict mode: unrecognized statement error" begin
840+
@test_throws CTBase.ParsingError Base.redirect_stdout(Base.devnull) do
841+
@init ocp_fixed begin
842+
println("This should fail")
843+
end
844+
end
845+
end
695846
end

0 commit comments

Comments
 (0)