fix(jit): infer DynDim on locals so dynamic shapes flow into deps#1536
Conversation
@pl.jit failed to compile patterns like
tokens = pl.tensor.dim(hidden_states, 0)
current = pl.create_tensor([tokens, HIDDEN], dtype=pl.BF16)
current = layer(current, next_buf)
because _extract_local_tensor_metas could not resolve the local tensor's
shape, and the same DynVar reused directly in the shape leaked past SSA
conversion as "Variable 'M' used outside its defining scope".
Refactor the JIT dynamic-shape model so dynamic dims live as first-class
DynDim entries inside TensorMeta.shape. The legacy dynamic_dims set and
dynvar_bindings / dynvar_literals dicts are derived from the metas via
new SpecializeContext accessors; the four old propagation helpers
(_compute_per_func_dynamic_dims, _build_dynvar_bindings,
_backfill_dynvar_bindings, _merge_annotation_dynvars) are removed.
_compute_per_func_dyndim_maps does a single leaf-first cascade so the
entry's cache key stays DynDim-aware even when only a dep declares the
dim dynamic.
_extract_local_tensor_metas now prescans `var = pl.tensor.dim(P, k)`
aliases, builds an inverse anchor index of seeded DynDim references, and
its shape-element resolver returns int | DynDim. _BodyTransformer
substitutes runtime DynVar references with pl.tensor.dim(P, k) at the
first anchor site so the generated source never leaks an annotation-only
DynVar into a runtime expression. Closure variables are now consulted
alongside __globals__ for shape-element resolution.
Add four regression tests under TestDynamicLocalTensorMetadata covering
the dim-alias propagation, the static parent-dim path, the
DynVar-in-create_tensor substitution, and the issue's verbatim repro.
Fixes hw-native-sys#1524
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (3)
📝 WalkthroughWalkthroughThis PR refactors JIT dynamic-dimension handling to embed DynDim objects directly in TensorMeta.shape instead of using separate set-based representations. The change cascades through dep-graph computation, local tensor metadata inference, and the compilation pipeline, enabling inference of dynamic local tensors created from pl.tensor.dim() aliases (fixing issue ChangesDynamic-Dimension Refactor: DynDim-embedded TensorMeta
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request refactors the JIT specialization and dynamic dimension tracking mechanism by embedding dynamic dimension metadata (DynDim) directly into TensorMeta shapes, replacing the parallel dynamic dimension and dynvar binding tables. This change resolves issue #1524, allowing local tensors created via pl.create_tensor to correctly inherit dynamic dimensions and propagate them through dependency calls. Feedback on the changes suggests: (1) enhancing _scan_dim_aliases to support chained assignments by processing all assignment targets, (2) adding a bounds check when indexing self._meta[param_name].shape in visit_Assign to prevent potential IndexError crashes, and (3) using copy.deepcopy when substituting DynVar runtime references in visit_Name to avoid mutable AST node sharing across multiple locations.
There was a problem hiding this comment.
Pull request overview
Refactors JIT dynamic-shape handling so dynamic dimensions are represented as first-class DynDim entries inside TensorMeta.shape, enabling dynamic shape metadata to propagate through locals (e.g., pl.tensor.dim(...) aliases) and into dependent @pl.jit.inline functions without SSA scope violations.
Changes:
- Introduces
DynDimand updatesTensorMeta/SpecializeContextto derive legacy dynamic-dim views from embedded shape entries; removes legacy dynvar binding tables from specialization APIs. - Updates decorator-side metadata inference and propagation (
_compute_per_func_dyndim_maps,_extract_local_tensor_metas) to carryDynDimthrough locals, deps, and cache keys. - Adds regression tests for issue #1524, covering
pl.tensor.dimaliases and direct DynVar usage inpl.create_tensorshapes.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
python/pypto/jit/specializer.py |
Adds DynDim/ShapeDim, embeds dynamic dims in TensorMeta.shape, simplifies specialization APIs, and updates AST rewriting to handle dynvar declarations/uses. |
python/pypto/jit/decorator.py |
Builds per-function DynDim maps and enhances local tensor meta inference so dynamic shapes flow through locals and across dep graphs. |
tests/ut/jit/test_specializer.py |
Updates unit tests to construct dynamic shapes via DynDim-embedded TensorMeta and to match new specializer signatures. |
tests/ut/jit/test_decorator.py |
Adds new regression tests for dynamic local tensor metadata propagation and DynVar substitution to avoid SSA scope errors. |
There was a problem hiding this comment.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
python/pypto/jit/specializer.py (1)
470-472:⚠️ Potential issue | 🟠 Major | ⚡ Quick winDon't treat removed
pl.dynamic(...)locals as real runtime bindings.
_used_namesis preseeded from everyStoretarget, andvisit_Name()uses that to suppress DynVar anchoring. Aftervisit_Assign()deletesrows = pl.dynamic("M"),rowsis still considered “assigned”, so later runtime uses likepl.create_tensor([rows, HIDDEN], ...)will not rewrite topl.tensor.dim(...)and can still leak the annotation-only DynVar into IR.Also applies to: 581-591, 713-721
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/pypto/jit/specializer.py` around lines 470 - 472, The pre-seeding of self._used_names with Store targets is treating removed annotation-only locals (e.g., variables created via pl.dynamic(...) like rows) as real runtime bindings, preventing DynVar anchoring in visit_Name; update the logic so that when visit_Assign removes a dynamic-only binding (pl.dynamic(...)) you also remove that target name from self._used_names (or track a separate set of annotation-only names) so visit_Name will allow anchoring; locate and adjust the handling in visit_Assign and the initialization of self._used_names (and analogous code at the other sites around lines 581-591 and 713-721) to ensure annotation-only Dynamic variables do not remain marked as "used".
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@python/pypto/jit/decorator.py`:
- Around line 371-390: The code stores DynDim entries under local temporary
names (e.g., "tmp") so subsequent hops using
_build_param_mapping(dep._param_names(), ...) only see real parameter names and
lose provenance; to fix, when mapping dep_param -> caller_arg inside the loop
that iterates dep_map.items(), detect if caller_arg is a local temporary (not
one of the caller function's parameter names) and resolve it back to the
originating caller parameter(s) before inserting into caller_map: use the
existing param_mapping, call_args_cache, and callers_by_dep_id to walk the
provenance (or invert param_mapping) and translate local aliases to original
parameter names, then set DynDim entries under those source param names in
caller_map (keep using dep._func, dep._param_names(), callers_by_dep_id,
call_args_cache, _build_param_mapping and out to locate and update the
mappings).
- Around line 464-481: In _func_name_lookup(), the current merge uses
out.setdefault(...) so module globals beat captured closure vars; change the
merge to assign/overwrite the key with the closure binding (e.g., out[fv_name] =
cell.cell_contents) while preserving the existing try/except that skips unbound
closure cells, and update the zip invocation to be explicit about strictness
(zip(co_freevars, closure, strict=True)) or add a targeted Ruff suppression for
B905 so the free-var-to-closure pairing is correct and lint warnings are
addressed.
In `@python/pypto/jit/specializer.py`:
- Around line 143-157: dynvar_literals currently silently overwrites
name→literal pairs across metas, but Specializer.specialize expects the first
declaration to win across contexts; change dynvar_literals to detect conflicting
mappings and fail fast: iterate self.tensor_meta.values() and for each DynDim
(in meta.shape) if d.name already in out and out[d.name] != d.literal raise a
clear ValueError (or a custom exception) explaining the conflicting dynvar name
and the two literals and include the originating meta/context identifier; keep
the function returning the mapping when no conflicts are found. This ensures
Specializer.specialize and dynvar_literals have consistent behaviour and
surfaces illegal reuse of the same DynVar name with different pl.dynamic()
literals.
- Around line 543-554: The code constructs raw ast.Name nodes for dynamic
dimensions (in the tuple-building block and in _shape_dim_node) which bypasses
visit_Name rewriting; replace those raw ast.Name(id=...) creations with anchored
expressions fetched from self._dynvar_anchors (e.g. use
self._dynvar_anchors[d.name]) or ensure the newly created nodes are passed
through self.visit() so visit_Name can emit the anchored form; update both the
tuple-construction loop that appends elts and the _shape_dim_node function to
return the anchored AST expression (or a visited node) whenever isinstance(d,
DynDim) rather than a bare ast.Name.
---
Outside diff comments:
In `@python/pypto/jit/specializer.py`:
- Around line 470-472: The pre-seeding of self._used_names with Store targets is
treating removed annotation-only locals (e.g., variables created via
pl.dynamic(...) like rows) as real runtime bindings, preventing DynVar anchoring
in visit_Name; update the logic so that when visit_Assign removes a dynamic-only
binding (pl.dynamic(...)) you also remove that target name from self._used_names
(or track a separate set of annotation-only names) so visit_Name will allow
anchoring; locate and adjust the handling in visit_Assign and the initialization
of self._used_names (and analogous code at the other sites around lines 581-591
and 713-721) to ensure annotation-only Dynamic variables do not remain marked as
"used".
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 6c375569-2ac6-4018-91c1-d7ed392d1da5
📒 Files selected for processing (4)
python/pypto/jit/decorator.pypython/pypto/jit/specializer.pytests/ut/jit/test_decorator.pytests/ut/jit/test_specializer.py
Address review feedback: - specializer.py: _shape_tuple_node, _shape_dim_node, and the M, N = a.shape Case 2 emission now produce pl.tensor.dim(<param>, <dim_idx>) instead of ast.Name(<DynVar>). The previous code bypassed visit_Name (those nodes are returned directly by visit_Attribute / visit_Subscript / visit_Assign and not revisited), so DynVar references could still leak into runtime expressions and re-trigger the ConvertToSSA "Variable 'M' used outside its defining scope" error. _dyn_dim_expr centralizes the emission. - specializer.py: visit_Name's DynVar substitution now rebuilds the anchor AST via _dyn_dim_expr (returns a fresh subtree per call) instead of reusing nested ast.Call / ast.Attribute children — avoids mutable AST node sharing across substitution sites. - specializer.py: SpecializeContext.dynvar_literals docstring rewritten — the previous note about pl.dynamic() being singleton-cached was wrong (it isn't). The real constraint is on user source not rebinding a Python name to two different literals. - decorator.py: _func_name_lookup now lets closure free vars override module globals, matching Python's own name-resolution order. The previous setdefault inversion silently returned the wrong constant when a nested function shadowed a module global. Adds strict=True to zip() (B905 — surfaced by CodeRabbit's ruff hook). - decorator.py: _scan_dim_aliases is now reassignment-safe: any name rebound to a non-pl.tensor.dim value drops its alias entry, so patterns like tokens = pl.tensor.dim(x, 0); tokens = tokens - 1 don't stamp a stale DynDim onto downstream pl.create_tensor shapes. Add two regression tests under TestDynamicLocalTensorMetadata: - test_shape_attribute_emits_anchor_not_dynvar - test_dim_alias_rebind_is_safe
Summary
DynDimentries insideTensorMeta.shape; legacy parallel tables (dynamic_dims,dynvar_bindings,dynvar_literals) become derived accessors onSpecializeContext. Four old propagation helpers collapse into one leaf-first cascade (_compute_per_func_dyndim_maps)._extract_local_tensor_metasaboutvar = pl.tensor.dim(P, k)aliases and direct DynVar references; the shape-element resolver returnsint | DynDimso locals likepl.create_tensor([tokens, HIDDEN], ...)inherit the parent's dynamic dim._BodyTransformer.visit_Namesubstitute runtime DynVar references withpl.tensor.dim(P, k)at the first anchor site — keeps annotation-only DynVars out of runtime expressions soConvertToSSAno longer reports "Variable 'M' used outside its defining scope".Why
Closes #1524. The issue's repro pattern (
tokens = pl.tensor.dim(hidden_states, 0); current = pl.create_tensor([tokens, HIDDEN], ...)) and the "alternative workaround" form (pl.create_tensor([M, HIDDEN], ...)) both failed before this change — the first withmissing inferred tensor metadata, the second with the SSA "outside defining scope" error.Test plan
tests/ut/jit/— 177 passed (173 prior + 4 new inTestDynamicLocalTensorMetadata):test_dim_alias_propagates_dyndim_to_local—pl.tensor.dimalias stamps parent's DynDim onto the local.test_dim_alias_static_dim_stays_int— aliasing a static parent dim yields anint.test_dynvar_in_create_tensor_substituted— direct DynVar in shape is rewritten topl.tensor.dim(P, k).test_issue_1524_repro_compiles— the issue's verbatim repro.ruff check+ruff format+pyright— all green via pre-commit hooks.