Skip to content

fix(jit): infer DynDim on locals so dynamic shapes flow into deps#1536

Merged
lyfne123 merged 2 commits into
hw-native-sys:mainfrom
Hzfengsy:worktree-issue-1524
May 26, 2026
Merged

fix(jit): infer DynDim on locals so dynamic shapes flow into deps#1536
lyfne123 merged 2 commits into
hw-native-sys:mainfrom
Hzfengsy:worktree-issue-1524

Conversation

@Hzfengsy
Copy link
Copy Markdown
Member

Summary

  • Refactor JIT dynamic-shape model so dynamic dims live as first-class DynDim entries inside TensorMeta.shape; legacy parallel tables (dynamic_dims, dynvar_bindings, dynvar_literals) become derived accessors on SpecializeContext. Four old propagation helpers collapse into one leaf-first cascade (_compute_per_func_dyndim_maps).
  • Teach _extract_local_tensor_metas about var = pl.tensor.dim(P, k) aliases and direct DynVar references; the shape-element resolver returns int | DynDim so locals like pl.create_tensor([tokens, HIDDEN], ...) inherit the parent's dynamic dim.
  • Have _BodyTransformer.visit_Name substitute runtime DynVar references with pl.tensor.dim(P, k) at the first anchor site — keeps annotation-only DynVars out of runtime expressions so ConvertToSSA no longer reports "Variable 'M' used outside its defining scope".

Why

Closes #1524. The issue's repro pattern (tokens = pl.tensor.dim(hidden_states, 0); current = pl.create_tensor([tokens, HIDDEN], ...)) and the "alternative workaround" form (pl.create_tensor([M, HIDDEN], ...)) both failed before this change — the first with missing inferred tensor metadata, the second with the SSA "outside defining scope" error.

Test plan

  • tests/ut/jit/ — 177 passed (173 prior + 4 new in TestDynamicLocalTensorMetadata):
    • test_dim_alias_propagates_dyndim_to_localpl.tensor.dim alias stamps parent's DynDim onto the local.
    • test_dim_alias_static_dim_stays_int — aliasing a static parent dim yields an int.
    • test_dynvar_in_create_tensor_substituted — direct DynVar in shape is rewritten to pl.tensor.dim(P, k).
    • test_issue_1524_repro_compiles — the issue's verbatim repro.
  • ruff check + ruff format + pyright — all green via pre-commit hooks.
  • CI: ut suite, lint, build.

@pl.jit failed to compile patterns like

    tokens = pl.tensor.dim(hidden_states, 0)
    current = pl.create_tensor([tokens, HIDDEN], dtype=pl.BF16)
    current = layer(current, next_buf)

because _extract_local_tensor_metas could not resolve the local tensor's
shape, and the same DynVar reused directly in the shape leaked past SSA
conversion as "Variable 'M' used outside its defining scope".

Refactor the JIT dynamic-shape model so dynamic dims live as first-class
DynDim entries inside TensorMeta.shape. The legacy dynamic_dims set and
dynvar_bindings / dynvar_literals dicts are derived from the metas via
new SpecializeContext accessors; the four old propagation helpers
(_compute_per_func_dynamic_dims, _build_dynvar_bindings,
_backfill_dynvar_bindings, _merge_annotation_dynvars) are removed.
_compute_per_func_dyndim_maps does a single leaf-first cascade so the
entry's cache key stays DynDim-aware even when only a dep declares the
dim dynamic.

_extract_local_tensor_metas now prescans `var = pl.tensor.dim(P, k)`
aliases, builds an inverse anchor index of seeded DynDim references, and
its shape-element resolver returns int | DynDim. _BodyTransformer
substitutes runtime DynVar references with pl.tensor.dim(P, k) at the
first anchor site so the generated source never leaks an annotation-only
DynVar into a runtime expression. Closure variables are now consulted
alongside __globals__ for shape-element resolution.

Add four regression tests under TestDynamicLocalTensorMetadata covering
the dim-alias propagation, the static parent-dim path, the
DynVar-in-create_tensor substitution, and the issue's verbatim repro.

Fixes hw-native-sys#1524
Copilot AI review requested due to automatic review settings May 26, 2026 09:44
@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 26, 2026

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 0e183784-1ef3-4932-93ad-9169f5a55ea2

📥 Commits

Reviewing files that changed from the base of the PR and between 09bee27 and f90ddc3.

📒 Files selected for processing (3)
  • python/pypto/jit/decorator.py
  • python/pypto/jit/specializer.py
  • tests/ut/jit/test_decorator.py

📝 Walkthrough

Walkthrough

This PR refactors JIT dynamic-dimension handling to embed DynDim objects directly in TensorMeta.shape instead of using separate set-based representations. The change cascades through dep-graph computation, local tensor metadata inference, and the compilation pipeline, enabling inference of dynamic local tensors created from pl.tensor.dim() aliases (fixing issue #1524).

Changes

Dynamic-Dimension Refactor: DynDim-embedded TensorMeta

Layer / File(s) Summary
Core data model: TensorMeta, DynDim, and ShapeDim
python/pypto/jit/specializer.py
Introduces DynDim with name,literal,static_bound; `ShapeDim = int
SpecializeContext and annotation builder refactor
python/pypto/jit/specializer.py
Removes stored dynamic_dims, adds dep_names/py_globals, provides derived dynamic_dims, dynvar_for(), dynvar_literals(), and simplifies _build_tensor_annotation; updates Specializer/specialize() signatures and build_specialize_context behavior.
Body transformer: shape handling and dynvar runtime substitution
python/pypto/jit/specializer.py
Builds _dynvar_anchors, emits pl.tensor.dim(...) anchors for dynamic dims, branches on dynamic vs static targets when unpacking a.shape, and rewrites bare dynvar names to anchored expressions at runtime.
Dynamic-dimension discovery: AST parsing and per-function DynDim mapping
python/pypto/jit/decorator.py
Scans AST for param.bind_dynamic(...), builds per-function DynDim maps, unions annotation-embedded dynvars and literals/anchors, and computes effective per-function DynDim maps via leaf-first dep-graph propagation.
Local tensor metadata extraction with DynDim support
python/pypto/jit/decorator.py
_extract_tensor_meta accepts per-dim DynDim bindings and stamps static_bound; _extract_local_tensor_metas resolves shape elements to ints or DynDim, indexes dim aliases, and propagates DynDim through slicing.
Dep-call metadata resolution and dep-graph updates
python/pypto/jit/decorator.py
Change _resolve_dep_call_metadata to accept dep_dyn_map, overlay DynDim entries onto callee tensor metas (pinning static_bound), and update dep-graph caching/comments to remove per-func dynamic-dim set caching.
Cache key and compilation pipeline updates
python/pypto/jit/decorator.py, python/pypto/jit/specializer.py
Cache keys derive dynamic_dims from each TensorMeta.dynamic_dim_indices() and use static_shape() for tensor shapes; _compile/_compile_to_program and context building pass per-func DynDim maps and instantiate Specializer without external dynvar binding tables.
Return type inference and module-level dynamic declaration emission
python/pypto/jit/specializer.py
_infer_return_type now uses TensorMeta/DynDim directly; module-level pl.dynamic(...) declarations are emitted from SpecializeContext.dynvar_literals() deduplicated across contexts.
Test updates: API alignment and regression tests for issue #1524
tests/ut/jit/test_decorator.py, tests/ut/jit/test_specializer.py
Update test helpers to remove dynamic_dims args and rely on tensor_meta-embedded DynDim; rewrite dynvar expectations to construct DynDim objects; adapt dep-graph unpacking and per_func_dyn call sites; add TestDynamicLocalTensorMetadata regression suite.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • lyfne123
  • zhangqi-chen

Poem

A rabbit hops through shapes both wide and slim, 🐇
DynDim tucked inside each tensor's limb,
From pl.tensor.dim the anchors softly sing,
Local creates now keep the dynvar ring,
Compilation hums — the inline flows begin.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 73.33% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately and concisely describes the main change: refactoring to enable DynDim inference on local tensors so dynamic shapes flow into inline function dependencies.
Description check ✅ Passed The description is well-related to the changeset, clearly explaining the refactor of dynamic-shape model, specific enhancements to tensor metadata inference, and the DynVar substitution mechanism.
Linked Issues check ✅ Passed The PR directly addresses all coding requirements from issue #1524: propagating DynDim into local tensor metadata, ensuring runtime expressions use pl.tensor.dim anchors instead of annotation-only DynVar names, and enabling the decode-style pattern.
Out of Scope Changes check ✅ Passed All changes are directly scoped to addressing issue #1524: refactoring JIT dynamic-shape representation, enhancing local tensor metadata inference, and fixing SSA validation errors related to dynamic dimension handling.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the JIT specialization and dynamic dimension tracking mechanism by embedding dynamic dimension metadata (DynDim) directly into TensorMeta shapes, replacing the parallel dynamic dimension and dynvar binding tables. This change resolves issue #1524, allowing local tensors created via pl.create_tensor to correctly inherit dynamic dimensions and propagate them through dependency calls. Feedback on the changes suggests: (1) enhancing _scan_dim_aliases to support chained assignments by processing all assignment targets, (2) adding a bounds check when indexing self._meta[param_name].shape in visit_Assign to prevent potential IndexError crashes, and (3) using copy.deepcopy when substituting DynVar runtime references in visit_Name to avoid mutable AST node sharing across multiple locations.

Comment thread python/pypto/jit/decorator.py Outdated
Comment thread python/pypto/jit/specializer.py
Comment thread python/pypto/jit/specializer.py Outdated
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Refactors JIT dynamic-shape handling so dynamic dimensions are represented as first-class DynDim entries inside TensorMeta.shape, enabling dynamic shape metadata to propagate through locals (e.g., pl.tensor.dim(...) aliases) and into dependent @pl.jit.inline functions without SSA scope violations.

Changes:

  • Introduces DynDim and updates TensorMeta/SpecializeContext to derive legacy dynamic-dim views from embedded shape entries; removes legacy dynvar binding tables from specialization APIs.
  • Updates decorator-side metadata inference and propagation (_compute_per_func_dyndim_maps, _extract_local_tensor_metas) to carry DynDim through locals, deps, and cache keys.
  • Adds regression tests for issue #1524, covering pl.tensor.dim aliases and direct DynVar usage in pl.create_tensor shapes.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File Description
python/pypto/jit/specializer.py Adds DynDim/ShapeDim, embeds dynamic dims in TensorMeta.shape, simplifies specialization APIs, and updates AST rewriting to handle dynvar declarations/uses.
python/pypto/jit/decorator.py Builds per-function DynDim maps and enhances local tensor meta inference so dynamic shapes flow through locals and across dep graphs.
tests/ut/jit/test_specializer.py Updates unit tests to construct dynamic shapes via DynDim-embedded TensorMeta and to match new specializer signatures.
tests/ut/jit/test_decorator.py Adds new regression tests for dynamic local tensor metadata propagation and DynVar substitution to avoid SSA scope errors.

Comment thread python/pypto/jit/specializer.py
Comment thread python/pypto/jit/specializer.py Outdated
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
python/pypto/jit/specializer.py (1)

470-472: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't treat removed pl.dynamic(...) locals as real runtime bindings.

_used_names is preseeded from every Store target, and visit_Name() uses that to suppress DynVar anchoring. After visit_Assign() deletes rows = pl.dynamic("M"), rows is still considered “assigned”, so later runtime uses like pl.create_tensor([rows, HIDDEN], ...) will not rewrite to pl.tensor.dim(...) and can still leak the annotation-only DynVar into IR.

Also applies to: 581-591, 713-721

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/pypto/jit/specializer.py` around lines 470 - 472, The pre-seeding of
self._used_names with Store targets is treating removed annotation-only locals
(e.g., variables created via pl.dynamic(...) like rows) as real runtime
bindings, preventing DynVar anchoring in visit_Name; update the logic so that
when visit_Assign removes a dynamic-only binding (pl.dynamic(...)) you also
remove that target name from self._used_names (or track a separate set of
annotation-only names) so visit_Name will allow anchoring; locate and adjust the
handling in visit_Assign and the initialization of self._used_names (and
analogous code at the other sites around lines 581-591 and 713-721) to ensure
annotation-only Dynamic variables do not remain marked as "used".
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@python/pypto/jit/decorator.py`:
- Around line 371-390: The code stores DynDim entries under local temporary
names (e.g., "tmp") so subsequent hops using
_build_param_mapping(dep._param_names(), ...) only see real parameter names and
lose provenance; to fix, when mapping dep_param -> caller_arg inside the loop
that iterates dep_map.items(), detect if caller_arg is a local temporary (not
one of the caller function's parameter names) and resolve it back to the
originating caller parameter(s) before inserting into caller_map: use the
existing param_mapping, call_args_cache, and callers_by_dep_id to walk the
provenance (or invert param_mapping) and translate local aliases to original
parameter names, then set DynDim entries under those source param names in
caller_map (keep using dep._func, dep._param_names(), callers_by_dep_id,
call_args_cache, _build_param_mapping and out to locate and update the
mappings).
- Around line 464-481: In _func_name_lookup(), the current merge uses
out.setdefault(...) so module globals beat captured closure vars; change the
merge to assign/overwrite the key with the closure binding (e.g., out[fv_name] =
cell.cell_contents) while preserving the existing try/except that skips unbound
closure cells, and update the zip invocation to be explicit about strictness
(zip(co_freevars, closure, strict=True)) or add a targeted Ruff suppression for
B905 so the free-var-to-closure pairing is correct and lint warnings are
addressed.

In `@python/pypto/jit/specializer.py`:
- Around line 143-157: dynvar_literals currently silently overwrites
name→literal pairs across metas, but Specializer.specialize expects the first
declaration to win across contexts; change dynvar_literals to detect conflicting
mappings and fail fast: iterate self.tensor_meta.values() and for each DynDim
(in meta.shape) if d.name already in out and out[d.name] != d.literal raise a
clear ValueError (or a custom exception) explaining the conflicting dynvar name
and the two literals and include the originating meta/context identifier; keep
the function returning the mapping when no conflicts are found. This ensures
Specializer.specialize and dynvar_literals have consistent behaviour and
surfaces illegal reuse of the same DynVar name with different pl.dynamic()
literals.
- Around line 543-554: The code constructs raw ast.Name nodes for dynamic
dimensions (in the tuple-building block and in _shape_dim_node) which bypasses
visit_Name rewriting; replace those raw ast.Name(id=...) creations with anchored
expressions fetched from self._dynvar_anchors (e.g. use
self._dynvar_anchors[d.name]) or ensure the newly created nodes are passed
through self.visit() so visit_Name can emit the anchored form; update both the
tuple-construction loop that appends elts and the _shape_dim_node function to
return the anchored AST expression (or a visited node) whenever isinstance(d,
DynDim) rather than a bare ast.Name.

---

Outside diff comments:
In `@python/pypto/jit/specializer.py`:
- Around line 470-472: The pre-seeding of self._used_names with Store targets is
treating removed annotation-only locals (e.g., variables created via
pl.dynamic(...) like rows) as real runtime bindings, preventing DynVar anchoring
in visit_Name; update the logic so that when visit_Assign removes a dynamic-only
binding (pl.dynamic(...)) you also remove that target name from self._used_names
(or track a separate set of annotation-only names) so visit_Name will allow
anchoring; locate and adjust the handling in visit_Assign and the initialization
of self._used_names (and analogous code at the other sites around lines 581-591
and 713-721) to ensure annotation-only Dynamic variables do not remain marked as
"used".
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 6c375569-2ac6-4018-91c1-d7ed392d1da5

📥 Commits

Reviewing files that changed from the base of the PR and between e7e41f5 and 09bee27.

📒 Files selected for processing (4)
  • python/pypto/jit/decorator.py
  • python/pypto/jit/specializer.py
  • tests/ut/jit/test_decorator.py
  • tests/ut/jit/test_specializer.py

Comment thread python/pypto/jit/decorator.py
Comment thread python/pypto/jit/decorator.py
Comment thread python/pypto/jit/decorator.py
Comment thread python/pypto/jit/specializer.py
Comment thread python/pypto/jit/specializer.py Outdated
Address review feedback:

- specializer.py: _shape_tuple_node, _shape_dim_node, and the M, N = a.shape
  Case 2 emission now produce pl.tensor.dim(<param>, <dim_idx>) instead of
  ast.Name(<DynVar>). The previous code bypassed visit_Name (those nodes
  are returned directly by visit_Attribute / visit_Subscript / visit_Assign
  and not revisited), so DynVar references could still leak into runtime
  expressions and re-trigger the ConvertToSSA "Variable 'M' used outside
  its defining scope" error. _dyn_dim_expr centralizes the emission.
- specializer.py: visit_Name's DynVar substitution now rebuilds the anchor
  AST via _dyn_dim_expr (returns a fresh subtree per call) instead of
  reusing nested ast.Call / ast.Attribute children — avoids mutable AST
  node sharing across substitution sites.
- specializer.py: SpecializeContext.dynvar_literals docstring rewritten —
  the previous note about pl.dynamic() being singleton-cached was wrong
  (it isn't). The real constraint is on user source not rebinding a Python
  name to two different literals.
- decorator.py: _func_name_lookup now lets closure free vars override
  module globals, matching Python's own name-resolution order. The
  previous setdefault inversion silently returned the wrong constant when
  a nested function shadowed a module global. Adds strict=True to zip()
  (B905 — surfaced by CodeRabbit's ruff hook).
- decorator.py: _scan_dim_aliases is now reassignment-safe: any name
  rebound to a non-pl.tensor.dim value drops its alias entry, so patterns
  like tokens = pl.tensor.dim(x, 0); tokens = tokens - 1 don't stamp a
  stale DynDim onto downstream pl.create_tensor shapes.

Add two regression tests under TestDynamicLocalTensorMetadata:
- test_shape_attribute_emits_anchor_not_dynvar
- test_dim_alias_rebind_is_safe
@lyfne123 lyfne123 merged commit a73c056 into hw-native-sys:main May 26, 2026
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

[Bug] @pl.jit cannot pass dynamic pl.create_tensor intermediates to inline functions

3 participants