Fix: Gradients don't propagate through array of structs #1207
Adityakk9031 wants to merge 8 commits into NVIDIA:main
Conversation
📝 Walkthrough

Special-cases assignments to struct fields inside array elements in the code generator, emitting reverse-pass adjoint accumulation to propagate gradients into the RHS for differentiable fields, and suppressing the non-differentiability warning. Adds tests validating gradient propagation through arrays of struct-backed fields and registers them in the test suites.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Kernel as Kernel (user)
    participant Codegen as Codegen
    participant Runtime as Runtime/Autodiff
    participant Adjoint as Adjoint Accumulator
    Kernel->>Codegen: emit forward assignment (y[i].field = x[i])
    Codegen->>Runtime: generate forward code and mark struct-array-field case
    Kernel->>Runtime: execute forward (compute loss)
    Runtime->>Runtime: start backward pass
    Runtime->>Codegen: request backward code for assignment
    Codegen->>Adjoint: emit adjoint accumulation for struct-field-in-array -> propagate gradient into RHS
    Adjoint->>Runtime: updated RHS adjoint values
    Runtime->>Kernel: backward complete (gradients available)
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Greptile Overview

Greptile Summary

This PR fixes a critical autodiff bug where gradients failed to propagate backward through arrays of structs. The fix adds manual reverse-mode adjoint code generation in `emit_Assign`.

Key changes: the implementation correctly identifies struct array field assignments and generates the necessary adjoint accumulation code.

Confidence Score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant Kernel as pack_kernel
    participant Store as store builtin
    participant ReverseAdjoint as Reverse Mode (NEW)
    participant Tape as Tape.backward()
    User->>Kernel: wp.launch(pack_kernel, x, y)
    Kernel->>Store: y[i].a = x[i]
    Note over Store: Forward pass: stores value
    Note over Store: builtin is_differentiable=False
    User->>Tape: tape.backward(loss)
    Tape->>ReverseAdjoint: Execute adjoint code
    Note over ReverseAdjoint: NEW: Manual gradient propagation<br/>x.adj += y[i].a.adj
    ReverseAdjoint-->>User: Gradients flow to x.grad
    Note over User,ReverseAdjoint: Before fix: gradient chain broke at store<br/>After fix: adj.add_reverse() closes the chain
```
warp/_src/codegen.py (outdated)

```python
if is_struct_array_field and adj.is_differentiable_value_type(strip_reference(rhs.type)):
    adj.add_reverse(f"{rhs.emit_adj()} += {attr.emit_adj()};")
```
Add a test case for this fix (e.g., the minimal example from the PR description) to prevent regression
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@warp/_src/codegen.py`:
- Around line 3296-3312: The check for struct types in the is_struct_array_field
predicate is using an invalid symbol (isinstance(aggregate_type,
warp._src.types.struct)) so it always fails; update the predicate to use the
proper type check helper type_is_struct(aggregate_type) from warp._src.types
(i.e., replace the isinstance call with type_is_struct(aggregate_type)) so
is_struct_array_field correctly detects struct array fields and the
adj.add_reverse call inside the block (which emits the adjoint propagation for
rhs via attr.emit_adj()) will run when appropriate.
warp/_src/codegen.py (outdated)

```python
if is_struct_array_field and adj.is_differentiable_value_type(strip_reference(rhs.type)):
    adj.add_reverse(f"{rhs.emit_adj()} += {attr.emit_adj()};")
```
Check that a test case was added for this fix (e.g., the minimal example from the PR description) to prevent regression.
warp/tests/test_grad_struct_array.py (outdated)

```python
add_function_test(TestGrad, "test_scalar_grad", test_scalar_grad, devices=devices)
add_function_test(TestGrad, "test_for_loop_grad", test_for_loop_grad, devices=devices)
add_function_test(TestGrad, "test_for_loop_graph_grad", test_for_loop_graph_grad, devices=devices, check_outputs=False)
add_function_test(TestGrad, "test_for_loop_nested_if_grad", test_for_loop_nested_if_grad, devices=devices)
add_function_test(TestGrad, "test_for_loop_nested_for_grad", test_for_loop_nested_for_grad, devices=devices)
add_function_test(TestGrad, "test_preserve_outputs_grad", test_preserve_outputs_grad, devices=devices)
add_function_test(TestGrad, "test_vector_math_grad", test_vector_math_grad, devices=devices)
add_function_test(TestGrad, "test_matrix_math_grad", test_matrix_math_grad, devices=devices)
add_function_test(TestGrad, "test_3d_math_grad", test_3d_math_grad, devices=devices)
add_function_test(TestGrad, "test_multi_valued_function_grad", test_multi_valued_function_grad, devices=devices)
add_function_test(TestGrad, "test_mesh_grad", test_mesh_grad, devices=devices)
add_function_test(TestGrad, "test_name_clash", test_name_clash, devices=devices)
add_function_test(TestGrad, "test_struct_attribute_gradient", test_struct_attribute_gradient, devices=devices)
add_function_test(TestGrad, "test_copy", test_copy, devices=devices)
add_function_test(TestGrad, "test_aliasing", test_aliasing, devices=devices)
add_function_test(TestGrad, "test_gradient_internal", test_gradient_internal, devices=devices)
add_function_test(TestGrad, "test_gradient_external", test_gradient_external, devices=devices)
add_function_test(TestGrad, "test_gradient_precedence", test_gradient_precedence, devices=devices)
add_function_test(TestGrad, "test_gradient_slice_2d", test_gradient_slice_2d, devices=devices)
add_function_test(TestGrad, "test_gradient_slice_3d_1d", test_gradient_slice_3d_1d, devices=devices)
add_function_test(TestGrad, "test_gradient_slice_3d_2d", test_gradient_slice_3d_2d, devices=devices)
```
These test functions (test_scalar_grad, test_for_loop_grad, etc.) are not defined in this file or imported from anywhere. They're defined in test_grad.py, not in unittest_utils. This will cause NameError when the test module is loaded.
Remove these undefined test registrations and keep only test_struct_array_gradient_propagation:
```diff
-add_function_test(TestGrad, "test_scalar_grad", test_scalar_grad, devices=devices)
-add_function_test(TestGrad, "test_for_loop_grad", test_for_loop_grad, devices=devices)
-add_function_test(TestGrad, "test_for_loop_graph_grad", test_for_loop_graph_grad, devices=devices, check_outputs=False)
-add_function_test(TestGrad, "test_for_loop_nested_if_grad", test_for_loop_nested_if_grad, devices=devices)
-add_function_test(TestGrad, "test_for_loop_nested_for_grad", test_for_loop_nested_for_grad, devices=devices)
-add_function_test(TestGrad, "test_preserve_outputs_grad", test_preserve_outputs_grad, devices=devices)
-add_function_test(TestGrad, "test_vector_math_grad", test_vector_math_grad, devices=devices)
-add_function_test(TestGrad, "test_matrix_math_grad", test_matrix_math_grad, devices=devices)
-add_function_test(TestGrad, "test_3d_math_grad", test_3d_math_grad, devices=devices)
-add_function_test(TestGrad, "test_multi_valued_function_grad", test_multi_valued_function_grad, devices=devices)
-add_function_test(TestGrad, "test_mesh_grad", test_mesh_grad, devices=devices)
-add_function_test(TestGrad, "test_name_clash", test_name_clash, devices=devices)
-add_function_test(TestGrad, "test_struct_attribute_gradient", test_struct_attribute_gradient, devices=devices)
-add_function_test(TestGrad, "test_copy", test_copy, devices=devices)
-add_function_test(TestGrad, "test_aliasing", test_aliasing, devices=devices)
-add_function_test(TestGrad, "test_gradient_internal", test_gradient_internal, devices=devices)
-add_function_test(TestGrad, "test_gradient_external", test_gradient_external, devices=devices)
-add_function_test(TestGrad, "test_gradient_precedence", test_gradient_precedence, devices=devices)
-add_function_test(TestGrad, "test_gradient_slice_2d", test_gradient_slice_2d, devices=devices)
-add_function_test(TestGrad, "test_gradient_slice_3d_1d", test_gradient_slice_3d_1d, devices=devices)
-add_function_test(TestGrad, "test_gradient_slice_3d_2d", test_gradient_slice_3d_2d, devices=devices)
 add_function_test(TestGrad, "test_struct_array_gradient_propagation", test_struct_array_gradient_propagation, devices=devices)
```
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@warp/tests/test_grad_struct_array.py`:
- Around line 43-60: The test function test_struct_array_gradient_propagation
has an unused parameter named test which triggers Ruff ARG001; rename that
parameter to _test (or _) in the function signature to mark it intentionally
unused and update any internal references if present, ensuring the function name
and behavior (including use of wp.ScopedDevice, x/y/loss setup, tape.backward,
and assertions) remain unchanged.
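In diff form, a sketch of the suggested rename (assuming the parameter is indeed unused in the body):

```diff
-def test_struct_array_gradient_propagation(test, device):
+def test_struct_array_gradient_propagation(_test, device):
```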
warp/tests/test_grad_struct_array.py (outdated)

```python
# limitations under the License.

import unittest
from typing import Any
```
Remove the unused import:

```diff
-from typing import Any
```
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@warp/tests/test_grad_struct_array.py`:
- Around line 62-69: Add the missing test class TestGradStructArray to the
default unittest suite by importing TestGradStructArray from its test module and
appending it to the test_classes list inside the default_suite() function;
specifically, update default_suite() (the function that builds test_classes) to
include the TestGradStructArray symbol so the
"test_struct_array_gradient_propagation" tests are executed in CI.
@tabo please check this

@shi-eric please check this
warp/tests/test_grad_struct_array.py (outdated)

```python
@wp.kernel
def loss_from_struct_array_kernel(y: wp.array(dtype=ScalarStruct), loss: wp.array(dtype=wp.float32)):
    i = wp.tid()
    loss[i] = y[i].a
```
Should use wp.atomic_add() here and launch N > 1 threads.
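A sketch of the suggested revision, assuming `loss` is treated as a single-element accumulator:

```python
@wp.kernel
def loss_from_struct_array_kernel(y: wp.array(dtype=ScalarStruct), loss: wp.array(dtype=wp.float32)):
    i = wp.tid()
    # Accumulate into one loss value so the kernel stays correct for dim > 1
    wp.atomic_add(loss, 0, y[i].a)
```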
warp/tests/test_grad_struct_array.py (outdated)

```python
tape = wp.Tape()
with tape:
    wp.launch(kernel=pack_struct_array_kernel, dim=1, inputs=[x], outputs=[y])
```
Try with dim > 1 here, to test a more realistic case.
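For example (n = 8 and the NumPy input values are assumptions, not taken from the PR):

```python
import numpy as np

n = 8
x = wp.array(np.arange(1, n + 1, dtype=np.float32), requires_grad=True)
y = wp.zeros(n, dtype=ScalarStruct, requires_grad=True)

tape = wp.Tape()
with tape:
    wp.launch(kernel=pack_struct_array_kernel, dim=n, inputs=[x], outputs=[y])
```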
```python
# Test for issue #1174: Gradients not propagating through array of structs
@wp.struct
class ScalarStruct:
    a: wp.float32
```
Can you add more dtypes here to test a wider range of use cases? E.g., vec3, mat22.
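For instance, a struct exercising more differentiable dtypes might look like this (hypothetical, not taken from the PR):

```python
# Hypothetical struct mixing scalar, vector, and matrix fields
@wp.struct
class MixedStruct:
    a: wp.float32
    v: wp.vec3
    m: wp.mat22
```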
I don't think this change supports arrays of structs that contain array fields. Could you test this case and add more guardrails in codegen.py? If there aren't any compiler errors, we should probably issue a warning if adj.used_by_backward_kernel is true and the user is attempting to write to an array field in a differentiable kernel.
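For illustration, the unsupported case being described might look like this (hypothetical):

```python
# Hypothetical struct with an array field -- the case the comment above
# says this change does not cover
@wp.struct
class ArrayFieldStruct:
    values: wp.array(dtype=wp.float32)
```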
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Description
This PR fixes a long-standing bug where gradients fail to propagate backward through arrays of structs in automatic differentiation. The gradient correctly accumulates into the struct array's adjoint but doesn't flow back to the source array.
Problem
When using autodiff with arrays of structs, an assignment like y[i].a = x[i] (where y is an array of structs) would not propagate gradients back to x, even though they correctly accumulated into y.grad.
Minimal Example

```python
@wp.struct
class Scalar:
    a: wp.float32

@wp.kernel
def pack_kernel(x: wp.array(dtype=wp.float32), y: wp.array(dtype=Scalar)):
    i = wp.tid()
    y[i].a = x[i]  # Gradient chain broke here
```

Before fix: x.grad = [0.] ❌
After fix: x.grad = [1.] ✅
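For completeness, a minimal end-to-end reproduction sketch (the loss kernel, array sizes, and input values here are assumptions added for illustration, not part of the original report):

```python
import numpy as np
import warp as wp

wp.init()

@wp.struct
class Scalar:
    a: wp.float32

@wp.kernel
def pack_kernel(x: wp.array(dtype=wp.float32), y: wp.array(dtype=Scalar)):
    i = wp.tid()
    y[i].a = x[i]

@wp.kernel
def loss_kernel(y: wp.array(dtype=Scalar), loss: wp.array(dtype=wp.float32)):
    i = wp.tid()
    # Reduce all struct fields into a single scalar loss
    wp.atomic_add(loss, 0, y[i].a)

x = wp.array(np.ones(1, dtype=np.float32), requires_grad=True)
y = wp.zeros(1, dtype=Scalar, requires_grad=True)
loss = wp.zeros(1, dtype=wp.float32, requires_grad=True)

tape = wp.Tape()
with tape:
    wp.launch(pack_kernel, dim=1, inputs=[x], outputs=[y])
    wp.launch(loss_kernel, dim=1, inputs=[y], outputs=[loss])

tape.backward(loss)
print(x.grad.numpy())  # [0.] before the fix, [1.] after
```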
Root Cause

In warp/_src/codegen.py, the emit_Assign function handles struct field assignments by calling the store builtin, but it failed to generate the reverse-mode adjoint code needed to propagate gradients from the RHS variable back through the struct field reference.
Changes

Modified Files

warp/_src/codegen.py
- Added detection for struct array field assignments using type_is_struct()
- Added the missing adj.add_reverse() call to generate gradient propagation code
- Updated the warning logic to exclude this case (now differentiable)

warp/tests/test_grad_struct_array.py (new file)
- Added the regression test TestGradStructArray

warp/tests/unittest_suites.py
- Registered TestGradStructArray in the default test suite for CI execution
Code Logic

```python
# Check if we're assigning to a struct field in an array element
is_struct_array_field = (
    is_reference(aggregate.type)
    and type_is_struct(aggregate_type)
)

if is_reference(attr.type):
    adj.add_builtin_call("store", [attr, rhs])

# Emit the reverse-pass adjoint accumulation for differentiable fields
# (from the diff shown earlier in this conversation)
if is_struct_array_field and adj.is_differentiable_value_type(strip_reference(rhs.type)):
    adj.add_reverse(f"{rhs.emit_adj()} += {attr.emit_adj()};")
```
Testing

Added a new test file, warp/tests/test_grad_struct_array.py, with the following test case:

```python
def test_struct_array_gradient_propagation(test, device):
    # ... setup x, y, loss ...
```
Summary by CodeRabbit

Bug Fixes
- Gradients now propagate through struct fields stored in arrays of structs.

Tests
- Added regression tests for gradient propagation through arrays of structs.