Skip to content

Commit 9304906

Browse files
authored
Merge pull request #241 from control-toolbox/init
Init
2 parents 64073fc + f15196b commit 9304906

4 files changed

Lines changed: 357 additions & 61 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ docs/site/
2727
Manifest.toml
2828

2929
# Local reports (analysis, status reports, previews) should not be tracked
30-
reports/
30+
.reports/
31+
.windsurf/
3132

3233
# claude
3334
CLAUDE.local.md

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "CTParser"
22
uuid = "32681960-a1b1-40db-9bff-a1ca817385d1"
3-
version = "0.8.8"
3+
version = "0.8.9-beta"
44
authors = ["Jean-Baptiste Caillau <jean-baptiste.caillau@univ-cotedazur.fr>"]
55

66
[deps]

src/initial_guess.jl

Lines changed: 197 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -98,33 +98,184 @@ end
9898
"""
9999
$(TYPEDSIGNATURES)
100100
101+
Generate runtime code for a temporal specification `lhs(arg) := rhs`.
102+
103+
This function produces the Julia expression that will be evaluated at runtime
104+
to determine whether the specification represents a time-dependent function
105+
or a time grid, based on whether `arg` matches `time_name(ocp)`.
106+
107+
# Arguments
108+
109+
- `pref::Symbol`: backend module prefix (e.g. `:CTModels`).
110+
- `ocp`: symbolic OCP variable passed from the macro.
111+
- `arg::Symbol`: argument symbol used in the specification (e.g. `:t`, `:s`, `:T`).
112+
- `rhs`: right-hand side expression.
113+
- `arg_in_rhs::Bool`: whether `arg` appears in `rhs` (computed at parse-time via `has`).
114+
115+
# Returns
116+
117+
- `val_sym::Symbol`: generated symbol to store the computed value.
118+
- `code::Expr`: expression block to insert in the generated code.
119+
120+
# Notes
121+
122+
When `arg_in_rhs` is `true`, the specification is definitely a time-dependent
123+
function, so we validate that `arg == Symbol(time_name(ocp))` and throw an
124+
error if not. When `arg_in_rhs` is `false`, we generate a runtime conditional
125+
that checks whether `arg` matches the time name to decide between a constant
126+
function or a time grid.
127+
"""
128+
function __gen_temporal_value(pref, ocp, arg, rhs, arg_in_rhs)
129+
val_sym = __symgen(:init_val)
130+
arg_quoted = QuoteNode(arg)
131+
132+
if arg_in_rhs
133+
# arg appears in rhs → must be a time-dependent function
134+
# Validate at runtime that arg matches time_name(ocp)
135+
code = quote
136+
let _expected = Symbol($pref.time_name($ocp))
137+
if $arg_quoted != _expected
138+
error(
139+
"Incorrect time variable in @init: " *
140+
"used :" * string($arg_quoted) * " but time_name(ocp) is " *
141+
"\"" * $pref.time_name($ocp) * "\" " *
142+
"(expected :" * string(_expected) * "). " *
143+
"Please use :" * string(_expected) * " instead of :" * string($arg_quoted) * " " *
144+
"in your @init block."
145+
)
146+
end
147+
end
148+
$val_sym = $arg -> $rhs
149+
end
150+
else
151+
# arg does NOT appear in rhs → ambiguous
152+
# Runtime check: if arg matches time_name → constant function, else → grid
153+
code = quote
154+
$val_sym = if Symbol($pref.time_name($ocp)) == $arg_quoted
155+
$arg -> $rhs # constant time function
156+
else
157+
($arg, $rhs) # time grid
158+
end
159+
end
160+
end
161+
162+
return val_sym, code
163+
end
164+
165+
"""
166+
$(TYPEDSIGNATURES)
167+
168+
Generate runtime code for a single initialisation specification.
169+
170+
This function dispatches based on the specification kind (`:constant` or
171+
`:temporal`) and delegates to the appropriate code generator.
172+
173+
# Arguments
174+
175+
- `pref::Symbol`: backend module prefix.
176+
- `ocp`: symbolic OCP variable.
177+
- `spec::Tuple`: specification tuple, either `(:constant, rhs)` or
178+
`(:temporal, arg, rhs, arg_in_rhs)`.
179+
180+
# Returns
181+
182+
- `val_sym::Symbol`: generated symbol to store the value.
183+
- `code::Expr`: expression to insert in the generated code.
184+
"""
185+
function __gen_spec_value(pref, ocp, spec)
186+
kind = spec[1]
187+
if kind == :constant
188+
rhs = spec[2]
189+
val_sym = __symgen(:init_val)
190+
code = :($val_sym = $rhs)
191+
return val_sym, code
192+
elseif kind == :temporal
193+
arg, rhs, arg_in_rhs = spec[2], spec[3], spec[4]
194+
return __gen_temporal_value(pref, ocp, arg, rhs, arg_in_rhs)
195+
else
196+
error("Unknown spec kind: $kind")
197+
end
198+
end
199+
200+
"""
201+
$(TYPEDSIGNATURES)
202+
203+
Format a single initialisation specification for logging.
204+
205+
This function produces a human-readable string representation of a
206+
specification, used when `log = true` is passed to `@init`.
207+
208+
# Arguments
209+
210+
- `key::Symbol`: component name (e.g. `:u`, `:x`).
211+
- `spec::Tuple`: specification tuple.
212+
213+
# Returns
214+
215+
- `String`: formatted string like `"u = t -> sin(t)"` or `"x = 1.0"`.
216+
"""
217+
function __log_spec(key, spec)
218+
kind = spec[1]
219+
if kind == :constant
220+
rhs = spec[2]
221+
rhs_str = if rhs isa Expr
222+
sprint(Base.show_unquoted, Base.remove_linenums!(deepcopy(rhs)))
223+
else
224+
sprint(show, rhs)
225+
end
226+
return string(key, " = ", rhs_str)
227+
elseif kind == :temporal
228+
arg, rhs = spec[2], spec[3]
229+
rhs_clean = if rhs isa Expr
230+
Base.remove_linenums!(deepcopy(rhs))
231+
else
232+
rhs
233+
end
234+
rhs_str = sprint(Base.show_unquoted, rhs_clean)
235+
return string(key, " = ", arg, " -> ", rhs_str)
236+
else
237+
return string(key, " = ???")
238+
end
239+
end
240+
241+
"""
242+
$(TYPEDSIGNATURES)
243+
101244
Internal helper that parses the body of an `@init` block.
102245
103246
The function walks through the expression `ex` and splits it into
104247
105248
- *alias statements*, which are left as ordinary Julia assignments and
106249
executed verbatim inside the generated block;
107250
- *initialisation specifications* of the form `lhs := rhs` or
108-
`lhs(t) := rhs` / `lhs(T) := rhs`, which are converted into keys and
109-
values used to build a `NamedTuple`.
251+
`lhs(arg) := rhs`, which are converted into structured specification
252+
tuples.
253+
254+
For expressions of the form `lhs(arg) := rhs`, this function uses `has(rhs, arg)`
255+
to determine whether `arg` appears in the right-hand side. This information
256+
is stored in the specification tuple and used later to generate appropriate
257+
runtime code that distinguishes time-dependent functions from time grids.
110258
111259
# Arguments
112260
113261
- `ex::Any`: expression or block coming from the body of `@init`.
262+
- `lnum::Int`: line number for error reporting.
263+
- `line_str::String`: line string for error reporting.
114264
115265
# Returns
116266
117267
- `alias_stmts::Vector{Expr}`: ordinary statements to execute before
118268
building the initial guess.
119269
- `keys::Vector{Symbol}`: names of the components being initialised
120270
(e.g. `:q`, `:v`, `:u`, `:tf`).
121-
- `vals::Vector{Any}`: expressions representing the corresponding
122-
values, functions or `(T, data)` pairs.
271+
- `specs::Vector{Tuple}`: specification tuples, either `(:constant, rhs)`
272+
for constant values or `(:temporal, arg, rhs, arg_in_rhs)` for temporal
273+
specifications where `arg_in_rhs` indicates whether `arg` appears in `rhs`.
123274
"""
124-
function _collect_init_specs(ex)
275+
function _collect_init_specs(ex, lnum::Int, line_str::String)
125276
alias_stmts = Expr[] # statements of the form a = ... or other Julia statements
126277
keys = Symbol[] # keys of the NamedTuple (q, v, x, u, tf, ...)
127-
vals = Any[] # expressions for the associated values
278+
specs = Tuple[] # specification tuples
128279

129280
stmts = if ex isa Expr && ex.head == :block
130281
ex.args
@@ -141,25 +292,24 @@ function _collect_init_specs(ex)
141292
push!(alias_stmts, st)
142293
end
143294

144-
# Forms q(t) := rhs (time-dependent function) or q(T) := rhs (time grid)
295+
# Forms q(arg) := rhs
296+
# Use has(rhs, arg) to determine if arg appears in rhs
145297
:($lhs($arg) := $rhs) => begin
146298
lhs isa Symbol || error("Unsupported left-hand side in @init: $lhs")
147-
if arg == :t
148-
# q(t) := rhs → time-dependent function
149-
push!(keys, lhs)
150-
push!(vals, :($arg -> $rhs))
151-
else
152-
# q(T) := rhs → (T, rhs) for build_initial_guess
153-
push!(keys, lhs)
154-
push!(vals, :(($arg, $rhs)))
155-
end
299+
arg isa Symbol || error("Unsupported argument in @init: $arg must be a symbol")
300+
301+
# Check if arg appears in rhs using has() from utils.jl
302+
arg_in_rhs = has(rhs, arg)
303+
304+
push!(keys, lhs)
305+
push!(specs, (:temporal, arg, rhs, arg_in_rhs))
156306
end
157307

158308
# Constant / variable form: lhs := rhs
159309
:($lhs := $rhs) => begin
160310
lhs isa Symbol || error("Unsupported left-hand side in @init: $lhs")
161311
push!(keys, lhs)
162-
push!(vals, rhs)
312+
push!(specs, (:constant, rhs))
163313
end
164314

165315
# Fallback: any other line is treated as an ordinary Julia statement
@@ -169,7 +319,7 @@ function _collect_init_specs(ex)
169319
end
170320
end
171321

172-
return alias_stmts, keys, vals
322+
return alias_stmts, keys, specs
173323
end
174324

175325
"""
@@ -191,6 +341,8 @@ macro level.
191341
192342
- `ocp`: symbolic optimal control problem built with `@def`.
193343
- `e`: expression corresponding to the body of the `@init` block.
344+
- `lnum::Int`: line number for error reporting.
345+
- `line_str::String`: line string for error reporting.
194346
195347
# Returns
196348
@@ -199,8 +351,8 @@ macro level.
199351
- `code_expr::Expr`: block of Julia code that builds and validates the
200352
initial guess when executed.
201353
"""
202-
function init_fun(ocp, e)
203-
alias_stmts, keys, vals = _collect_init_specs(e)
354+
function init_fun(ocp, e, lnum::Int, line_str::String)
355+
alias_stmts, keys, specs = _collect_init_specs(e, lnum, line_str)
204356
pref = init_prefix()
205357

206358
# If there is no init specification, delegate to build_initial_guess/validate_initial_guess
@@ -215,47 +367,30 @@ function init_fun(ocp, e)
215367
return log_str, code_expr
216368
end
217369

218-
# Build the NamedTuple type and its values for execution
370+
# Generate runtime code for each specification
371+
body_stmts = Any[]
372+
append!(body_stmts, alias_stmts)
373+
374+
val_syms = Symbol[]
375+
for spec in specs
376+
val_sym, code = __gen_spec_value(pref, ocp, spec)
377+
push!(val_syms, val_sym)
378+
push!(body_stmts, code)
379+
end
380+
381+
# Build the NamedTuple with the generated value symbols
219382
key_nodes = [QuoteNode(k) for k in keys]
220383
keys_tuple = Expr(:tuple, key_nodes...)
221-
vals_tuple = Expr(:tuple, vals...)
384+
vals_tuple = Expr(:tuple, val_syms...)
222385
nt_expr = :(NamedTuple{$keys_tuple}($vals_tuple))
223386

224-
body_stmts = Any[]
225-
append!(body_stmts, alias_stmts)
226387
build_call = :($pref.build_initial_guess($ocp, $nt_expr))
227388
validate_call = :($pref.validate_initial_guess($ocp, $build_call))
228389
push!(body_stmts, validate_call)
229390
code_expr = Expr(:block, body_stmts...)
230391

231-
# Build a pretty NamedTuple-like string for logging, of the form (q = ..., v = ..., ...)
232-
pairs_str = String[]
233-
for (k, v) in zip(keys, vals)
234-
vc = v
235-
if vc isa Expr
236-
# Remove LineNumberNode noise and print without leading :( ... ) wrapper
237-
vc_clean = Base.remove_linenums!(deepcopy(vc))
238-
if vc_clean.head == :-> && length(vc_clean.args) == 2
239-
arg_expr, body_expr = vc_clean.args
240-
# Simplify body: strip trivial `begin ... end` with a single non-LineNumberNode expression
241-
body_clean = body_expr
242-
if body_clean isa Expr && body_clean.head == :block
243-
filtered = [x for x in body_clean.args if !(x isa LineNumberNode)]
244-
if length(filtered) == 1
245-
body_clean = filtered[1]
246-
end
247-
end
248-
lhs_str = sprint(Base.show_unquoted, arg_expr)
249-
rhs_body_str = sprint(Base.show_unquoted, body_clean)
250-
rhs_str = string(lhs_str, " -> ", rhs_body_str)
251-
else
252-
rhs_str = sprint(Base.show_unquoted, vc_clean)
253-
end
254-
else
255-
rhs_str = sprint(show, vc)
256-
end
257-
push!(pairs_str, string(k, " = ", rhs_str))
258-
end
392+
# Build log string using __log_spec helper
393+
pairs_str = [__log_spec(k, s) for (k, s) in zip(keys, specs)]
259394
log_str = if length(pairs_str) == 1
260395
string("(", pairs_str[1], ",)")
261396
else
@@ -349,22 +484,25 @@ macro init(ocp, e, rest...)
349484
if opt isa Expr && opt.head == :(=) && opt.args[1] == :log
350485
log_expr = opt.args[2]
351486
else
352-
error(
353-
"Unsupported trailing argument in @init. Use `log = true` or `log = false`."
487+
throw_expr = CTParser.__throw(
488+
"Unsupported trailing argument in @init. Use `log = true` or `log = false`.",
489+
lnum, line_str
354490
)
491+
return esc(throw_expr)
355492
end
356493
elseif length(rest) > 1
357-
error(
494+
throw_expr = CTParser.__throw(
358495
"Too many trailing arguments in @init. Only a single `log = ...` keyword is supported.",
496+
lnum, line_str
359497
)
498+
return esc(throw_expr)
360499
end
361500

362501
log_str, code = try
363-
init_fun(ocp, e)
502+
init_fun(ocp, e, lnum, line_str)
364503
catch err
365-
# Treat unsupported DSL syntax as a static parsing error with proper line info.
366-
if err isa ErrorException &&
367-
occursin("Unsupported left-hand side in @init", err.msg)
504+
# Catch any ErrorException from parsing and convert to __throw
505+
if err isa ErrorException
368506
throw_expr = CTParser.__throw(err.msg, lnum, line_str)
369507
return esc(throw_expr)
370508
else

0 commit comments

Comments
 (0)