9898"""
9999$(TYPEDSIGNATURES)
100100
101+ Generate runtime code for a temporal specification `lhs(arg) := rhs`.
102+
103+ This function produces the Julia expression that will be evaluated at runtime
104+ to determine whether the specification represents a time-dependent function
105+ or a time grid, based on whether `arg` matches `time_name(ocp)`.
106+
107+ # Arguments
108+
109+ - `pref::Symbol`: backend module prefix (e.g. `:CTModels`).
110+ - `ocp`: symbolic OCP variable passed from the macro.
111+ - `arg::Symbol`: argument symbol used in the specification (e.g. `:t`, `:s`, `:T`).
112+ - `rhs`: right-hand side expression.
113+ - `arg_in_rhs::Bool`: whether `arg` appears in `rhs` (computed at parse-time via `has`).
114+
115+ # Returns
116+
117+ - `val_sym::Symbol`: generated symbol to store the computed value.
118+ - `code::Expr`: expression block to insert in the generated code.
119+
120+ # Notes
121+
122+ When `arg_in_rhs` is `true`, the specification is definitely a time-dependent
123+ function, so we validate that `arg == Symbol(time_name(ocp))` and throw an
124+ error if not. When `arg_in_rhs` is `false`, we generate a runtime conditional
125+ that checks whether `arg` matches the time name to decide between a constant
126+ function or a time grid.
127+ """
128+ function __gen_temporal_value (pref, ocp, arg, rhs, arg_in_rhs)
129+ val_sym = __symgen (:init_val )
130+ arg_quoted = QuoteNode (arg)
131+
132+ if arg_in_rhs
133+ # arg appears in rhs → must be a time-dependent function
134+ # Validate at runtime that arg matches time_name(ocp)
135+ code = quote
136+ let _expected = Symbol ($ pref. time_name ($ ocp))
137+ if $ arg_quoted != _expected
138+ error (
139+ " Incorrect time variable in @init: " *
140+ " used :" * string ($ arg_quoted) * " but time_name(ocp) is " *
141+ " \" " * $ pref. time_name ($ ocp) * " \" " *
142+ " (expected :" * string (_expected) * " ). " *
143+ " Please use :" * string (_expected) * " instead of :" * string ($ arg_quoted) * " " *
144+ " in your @init block."
145+ )
146+ end
147+ end
148+ $ val_sym = $ arg -> $ rhs
149+ end
150+ else
151+ # arg does NOT appear in rhs → ambiguous
152+ # Runtime check: if arg matches time_name → constant function, else → grid
153+ code = quote
154+ $ val_sym = if Symbol ($ pref. time_name ($ ocp)) == $ arg_quoted
155+ $ arg -> $ rhs # constant time function
156+ else
157+ ($ arg, $ rhs) # time grid
158+ end
159+ end
160+ end
161+
162+ return val_sym, code
163+ end
164+
165+ """
166+ $(TYPEDSIGNATURES)
167+
168+ Generate runtime code for a single initialisation specification.
169+
170+ This function dispatches based on the specification kind (`:constant` or
171+ `:temporal`) and delegates to the appropriate code generator.
172+
173+ # Arguments
174+
175+ - `pref::Symbol`: backend module prefix.
176+ - `ocp`: symbolic OCP variable.
177+ - `spec::Tuple`: specification tuple, either `(:constant, rhs)` or
178+ `(:temporal, arg, rhs, arg_in_rhs)`.
179+
180+ # Returns
181+
182+ - `val_sym::Symbol`: generated symbol to store the value.
183+ - `code::Expr`: expression to insert in the generated code.
184+ """
185+ function __gen_spec_value (pref, ocp, spec)
186+ kind = spec[1 ]
187+ if kind == :constant
188+ rhs = spec[2 ]
189+ val_sym = __symgen (:init_val )
190+ code = :($ val_sym = $ rhs)
191+ return val_sym, code
192+ elseif kind == :temporal
193+ arg, rhs, arg_in_rhs = spec[2 ], spec[3 ], spec[4 ]
194+ return __gen_temporal_value (pref, ocp, arg, rhs, arg_in_rhs)
195+ else
196+ error (" Unknown spec kind: $kind " )
197+ end
198+ end
199+
200+ """
201+ $(TYPEDSIGNATURES)
202+
203+ Format a single initialisation specification for logging.
204+
205+ This function produces a human-readable string representation of a
206+ specification, used when `log = true` is passed to `@init`.
207+
208+ # Arguments
209+
210+ - `key::Symbol`: component name (e.g. `:u`, `:x`).
211+ - `spec::Tuple`: specification tuple.
212+
213+ # Returns
214+
215+ - `String`: formatted string like `"u = t -> sin(t)"` or `"x = 1.0"`.
216+ """
217+ function __log_spec (key, spec)
218+ kind = spec[1 ]
219+ if kind == :constant
220+ rhs = spec[2 ]
221+ rhs_str = if rhs isa Expr
222+ sprint (Base. show_unquoted, Base. remove_linenums! (deepcopy (rhs)))
223+ else
224+ sprint (show, rhs)
225+ end
226+ return string (key, " = " , rhs_str)
227+ elseif kind == :temporal
228+ arg, rhs = spec[2 ], spec[3 ]
229+ rhs_clean = if rhs isa Expr
230+ Base. remove_linenums! (deepcopy (rhs))
231+ else
232+ rhs
233+ end
234+ rhs_str = sprint (Base. show_unquoted, rhs_clean)
235+ return string (key, " = " , arg, " -> " , rhs_str)
236+ else
237+ return string (key, " = ???" )
238+ end
239+ end
240+
241+ """
242+ $(TYPEDSIGNATURES)
243+
101244Internal helper that parses the body of an `@init` block.
102245
103246The function walks through the expression `ex` and splits it into
104247
105248- *alias statements*, which are left as ordinary Julia assignments and
106249 executed verbatim inside the generated block;
107250- *initialisation specifications* of the form `lhs := rhs` or
108- `lhs(t) := rhs` / `lhs(T) := rhs`, which are converted into keys and
109- values used to build a `NamedTuple`.
251+ `lhs(arg) := rhs`, which are converted into structured specification
252+ tuples.
253+
254+ For expressions of the form `lhs(arg) := rhs`, this function uses `has(rhs, arg)`
255+ to determine whether `arg` appears in the right-hand side. This information
256+ is stored in the specification tuple and used later to generate appropriate
257+ runtime code that distinguishes time-dependent functions from time grids.
110258
111259# Arguments
112260
113261- `ex::Any`: expression or block coming from the body of `@init`.
262+ - `lnum::Int`: line number for error reporting.
263+ - `line_str::String`: line string for error reporting.
114264
115265# Returns
116266
117267- `alias_stmts::Vector{Expr}`: ordinary statements to execute before
118268 building the initial guess.
119269- `keys::Vector{Symbol}`: names of the components being initialised
120270 (e.g. `:q`, `:v`, `:u`, `:tf`).
121- - `vals::Vector{Any}`: expressions representing the corresponding
122- values, functions or `(T, data)` pairs.
271+ - `specs::Vector{Tuple}`: specification tuples, either `(:constant, rhs)`
272+ for constant values or `(:temporal, arg, rhs, arg_in_rhs)` for temporal
273+ specifications where `arg_in_rhs` indicates whether `arg` appears in `rhs`.
123274"""
124- function _collect_init_specs (ex)
275+ function _collect_init_specs (ex, lnum :: Int , line_str :: String )
125276 alias_stmts = Expr[] # statements of the form a = ... or other Julia statements
126277 keys = Symbol[] # keys of the NamedTuple (q, v, x, u, tf, ...)
127- vals = Any [] # expressions for the associated values
278+ specs = Tuple [] # specification tuples
128279
129280 stmts = if ex isa Expr && ex. head == :block
130281 ex. args
@@ -141,25 +292,24 @@ function _collect_init_specs(ex)
141292 push! (alias_stmts, st)
142293 end
143294
144- # Forms q(t) := rhs (time-dependent function) or q(T) := rhs (time grid)
295+ # Forms q(arg) := rhs
296+ # Use has(rhs, arg) to determine if arg appears in rhs
145297 :($ lhs ($ arg) := $ rhs) => begin
146298 lhs isa Symbol || error (" Unsupported left-hand side in @init: $lhs " )
147- if arg == :t
148- # q(t) := rhs → time-dependent function
149- push! (keys, lhs)
150- push! (vals, :($ arg -> $ rhs))
151- else
152- # q(T) := rhs → (T, rhs) for build_initial_guess
153- push! (keys, lhs)
154- push! (vals, :(($ arg, $ rhs)))
155- end
299+ arg isa Symbol || error (" Unsupported argument in @init: $arg must be a symbol" )
300+
301+ # Check if arg appears in rhs using has() from utils.jl
302+ arg_in_rhs = has (rhs, arg)
303+
304+ push! (keys, lhs)
305+ push! (specs, (:temporal , arg, rhs, arg_in_rhs))
156306 end
157307
158308 # Constant / variable form: lhs := rhs
159309 :($ lhs := $ rhs) => begin
160310 lhs isa Symbol || error (" Unsupported left-hand side in @init: $lhs " )
161311 push! (keys, lhs)
162- push! (vals, rhs)
312+ push! (specs, ( :constant , rhs) )
163313 end
164314
165315 # Fallback: any other line is treated as an ordinary Julia statement
@@ -169,7 +319,7 @@ function _collect_init_specs(ex)
169319 end
170320 end
171321
172- return alias_stmts, keys, vals
322+ return alias_stmts, keys, specs
173323end
174324
175325"""
@@ -191,6 +341,8 @@ macro level.
191341
192342- `ocp`: symbolic optimal control problem built with `@def`.
193343- `e`: expression corresponding to the body of the `@init` block.
344+ - `lnum::Int`: line number for error reporting.
345+ - `line_str::String`: line string for error reporting.
194346
195347# Returns
196348
@@ -199,8 +351,8 @@ macro level.
199351- `code_expr::Expr`: block of Julia code that builds and validates the
200352 initial guess when executed.
201353"""
202- function init_fun (ocp, e)
203- alias_stmts, keys, vals = _collect_init_specs (e)
354+ function init_fun (ocp, e, lnum :: Int , line_str :: String )
355+ alias_stmts, keys, specs = _collect_init_specs (e, lnum, line_str )
204356 pref = init_prefix ()
205357
206358 # If there is no init specification, delegate to build_initial_guess/validate_initial_guess
@@ -215,47 +367,30 @@ function init_fun(ocp, e)
215367 return log_str, code_expr
216368 end
217369
218- # Build the NamedTuple type and its values for execution
370+ # Generate runtime code for each specification
371+ body_stmts = Any[]
372+ append! (body_stmts, alias_stmts)
373+
374+ val_syms = Symbol[]
375+ for spec in specs
376+ val_sym, code = __gen_spec_value (pref, ocp, spec)
377+ push! (val_syms, val_sym)
378+ push! (body_stmts, code)
379+ end
380+
381+ # Build the NamedTuple with the generated value symbols
219382 key_nodes = [QuoteNode (k) for k in keys]
220383 keys_tuple = Expr (:tuple , key_nodes... )
221- vals_tuple = Expr (:tuple , vals ... )
384+ vals_tuple = Expr (:tuple , val_syms ... )
222385 nt_expr = :(NamedTuple {$keys_tuple} ($ vals_tuple))
223386
224- body_stmts = Any[]
225- append! (body_stmts, alias_stmts)
226387 build_call = :($ pref. build_initial_guess ($ ocp, $ nt_expr))
227388 validate_call = :($ pref. validate_initial_guess ($ ocp, $ build_call))
228389 push! (body_stmts, validate_call)
229390 code_expr = Expr (:block , body_stmts... )
230391
231- # Build a pretty NamedTuple-like string for logging, of the form (q = ..., v = ..., ...)
232- pairs_str = String[]
233- for (k, v) in zip (keys, vals)
234- vc = v
235- if vc isa Expr
236- # Remove LineNumberNode noise and print without leading :( ... ) wrapper
237- vc_clean = Base. remove_linenums! (deepcopy (vc))
238- if vc_clean. head == :-> && length (vc_clean. args) == 2
239- arg_expr, body_expr = vc_clean. args
240- # Simplify body: strip trivial `begin ... end` with a single non-LineNumberNode expression
241- body_clean = body_expr
242- if body_clean isa Expr && body_clean. head == :block
243- filtered = [x for x in body_clean. args if ! (x isa LineNumberNode)]
244- if length (filtered) == 1
245- body_clean = filtered[1 ]
246- end
247- end
248- lhs_str = sprint (Base. show_unquoted, arg_expr)
249- rhs_body_str = sprint (Base. show_unquoted, body_clean)
250- rhs_str = string (lhs_str, " -> " , rhs_body_str)
251- else
252- rhs_str = sprint (Base. show_unquoted, vc_clean)
253- end
254- else
255- rhs_str = sprint (show, vc)
256- end
257- push! (pairs_str, string (k, " = " , rhs_str))
258- end
392+ # Build log string using __log_spec helper
393+ pairs_str = [__log_spec (k, s) for (k, s) in zip (keys, specs)]
259394 log_str = if length (pairs_str) == 1
260395 string (" (" , pairs_str[1 ], " ,)" )
261396 else
@@ -349,22 +484,25 @@ macro init(ocp, e, rest...)
349484 if opt isa Expr && opt. head == :(= ) && opt. args[1 ] == :log
350485 log_expr = opt. args[2 ]
351486 else
352- error (
353- " Unsupported trailing argument in @init. Use `log = true` or `log = false`."
487+ throw_expr = CTParser. __throw (
488+ " Unsupported trailing argument in @init. Use `log = true` or `log = false`." ,
489+ lnum, line_str
354490 )
491+ return esc (throw_expr)
355492 end
356493 elseif length (rest) > 1
357- error (
494+ throw_expr = CTParser . __throw (
358495 " Too many trailing arguments in @init. Only a single `log = ...` keyword is supported." ,
496+ lnum, line_str
359497 )
498+ return esc (throw_expr)
360499 end
361500
362501 log_str, code = try
363- init_fun (ocp, e)
502+ init_fun (ocp, e, lnum, line_str )
364503 catch err
365- # Treat unsupported DSL syntax as a static parsing error with proper line info.
366- if err isa ErrorException &&
367- occursin (" Unsupported left-hand side in @init" , err. msg)
504+ # Catch any ErrorException from parsing and convert to __throw
505+ if err isa ErrorException
368506 throw_expr = CTParser. __throw (err. msg, lnum, line_str)
369507 return esc (throw_expr)
370508 else
0 commit comments