Skip to content

Conversation

@groberts-flex
Copy link
Contributor

@groberts-flex groberts-flex commented Nov 19, 2025

@marcorudolphflex @yaugenst-flex

After the interface reconfiguration, I split the PR for these custom autograd hooks into two so that hopefully it's easier to review! This one is for the user_vjp which allows someone to override the internal vjp calculation for a structure geometry or medium. The other hook is done as well, but I'll save it for after this one is done with review!

Based on the other review, I updated the interface to ideally be a little more straightforward to use and less cumbersome. The specification of paths in the user_vjp is not required unless you want it to only apply to a specific path in the structure. It can also be specified as just a single user_vjp value in run_async_custom if you want the same one to apply to all of the simulations (instead of having to manually broadcast it). I think there are other helper functions that could be added in the future that might make things even easier like applying a certain user_vjp for all structures with a specific geometry type, but I'll leave those for a future upgrade.

Greptile Summary

  • Adds user_vjp parameter to autograd run functions enabling custom gradient calculations for specific structures
  • Implements VJP lookup mechanism in backward pass to route computation through user-defined functions when specified
  • Extends DerivativeInfo with updated_epsilon helper for finite difference gradient computations in custom VJPs

Confidence Score: 4/5

  • This PR is safe to merge with minor documentation improvements recommended
  • Score reflects solid implementation with comprehensive test coverage and proper error handling. The core logic correctly routes custom VJP functions through the gradient computation pipeline. Minor documentation issues with inline docstrings and incomplete comments don't impact functionality but should be addressed for maintainability.
  • Pay attention to tidy3d/components/structure.py and tidy3d/web/api/autograd/backward.py for the docstring formatting issues

Important Files Changed

Filename Overview
tidy3d/web/api/autograd/types.py New file introducing UserVJPConfig dataclass for custom gradient computation and SetupRunResult for run preparation
tidy3d/web/api/autograd/autograd.py Adds user_vjp parameter throughout run functions with validation and broadcasting logic for single/batch simulations
tidy3d/web/api/autograd/backward.py Implements user VJP lookup mechanism and updated_epsilon helper function for finite difference gradient computations
tidy3d/components/structure.py Extends _compute_derivatives to accept optional vjp_fns dict for custom gradient paths per geometry/medium field

Sequence Diagram

sequenceDiagram
    participant User
    participant run_custom
    participant _run_primitive
    participant setup_fwd
    participant _run_tidy3d
    participant _run_bwd
    participant postprocess_adj
    participant Structure
    participant UserVJP

    User->>run_custom: "call with simulation and user_vjp"
    run_custom->>_run_primitive: "pass user_vjp to primitive"
    _run_primitive->>setup_fwd: "setup forward simulation"
    setup_fwd-->>_run_primitive: "combined simulation"
    _run_primitive->>_run_tidy3d: "run forward simulation"
    _run_tidy3d-->>_run_primitive: "simulation data"
    
    Note over _run_bwd: Backward pass triggered
    _run_bwd->>postprocess_adj: "compute gradients with user_vjp"
    postprocess_adj->>postprocess_adj: "build user_vjp_lookup dict"
    postprocess_adj->>Structure: "_compute_derivatives with vjp_fns"
    
    alt user VJP exists for path
        Structure->>UserVJP: "call user-defined vjp function"
        UserVJP-->>Structure: "custom gradients"
    else default path
        Structure->>Structure: "call internal gradient method"
        Structure-->>Structure: "standard gradients"
    end
    
    Structure-->>postprocess_adj: "gradient values"
    postprocess_adj-->>_run_bwd: "VJP field map"
    _run_bwd-->>User: "gradients for optimization"
Loading

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

13 files reviewed, 6 comments

Edit Code Review Agent Settings | Greptile
React with 👍 or 👎 to share your feedback on this new summary format

@github-actions
Copy link
Contributor

Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

  • tidy3d/components/autograd/derivative_utils.py (100%)
  • tidy3d/components/geometry/primitives.py (100%)
  • tidy3d/components/structure.py (100%)
  • tidy3d/plugins/smatrix/run.py (100%)
  • tidy3d/web/api/autograd/autograd.py (91.9%): Missing lines 476-477,481,489,499,1146
  • tidy3d/web/api/autograd/backward.py (77.8%): Missing lines 263-264,266,276,281,351,388-390,394
  • tidy3d/web/api/autograd/types.py (100%)

Summary

  • Total: 172 lines
  • Missing: 16 lines
  • Coverage: 90%

tidy3d/web/api/autograd/autograd.py

Lines 472-485

  472         if fn_arg is None:
  473             return fn_arg
  474 
  475         if isinstance(fn_arg, base_type):
! 476             expanded = dict.fromkeys(sim_dict.keys(), fn_arg)
! 477             return expanded
  478 
  479         expanded = {}
  480         if not isinstance(fn_arg, type(orig_sim_arg)):
! 481             raise AdjointError(
  482                 f"{fn_arg_name} type ({type(fn_arg)}) should match simulations type ({type(simulations)})"
  483             )
  484 
  485         if isinstance(orig_sim_arg, dict):

Lines 485-493

  485         if isinstance(orig_sim_arg, dict):
  486             check_keys = fn_arg.keys() == sim_dict.keys()
  487 
  488             if not check_keys:
! 489                 raise AdjointError(f"{fn_arg_name} keys do not match simulations keys")
  490 
  491             for key, val in fn_arg.items():
  492                 if isinstance(val, base_type):
  493                     expanded[key] = (val,)

Lines 495-503

  495                     expanded[key] = val
  496 
  497         elif isinstance(orig_sim_arg, (list, tuple)):
  498             if not (len(fn_arg) == len(orig_sim_arg)):
! 499                 raise AdjointError(
  500                     f"{fn_arg_name} is not the same length as simulations ({len(expanded)} vs. {len(simulations)})"
  501                 )
  502 
  503             for idx, key in enumerate(sim_dict.keys()):

Lines 1142-1150

  1142 
  1143                 # Compute VJP contribution
  1144                 task_user_vjp = user_vjp.get(task_name)
  1145                 if isinstance(task_user_vjp, UserVJPConfig):
! 1146                     task_user_vjp = (task_user_vjp,)
  1147 
  1148                 vjp_results[adj_task_name] = postprocess_adj(
  1149                     sim_data_adj=sim_data_adj,
  1150                     sim_data_orig=sim_data_orig,

tidy3d/web/api/autograd/backward.py

Lines 259-270

  259         ) -> ScalarFieldDataArray:
  260             # Return the simulation permittivity for eps_box after replacing the geometry
  261             # for this structure with a new geometry. This is helpful for carrying out finite
  262             # difference permittivity computations
! 263             sim_orig = sim_data_orig.simulation
! 264             sim_orig_grid_spec = td.components.grid.grid_spec.GridSpec.from_grid(sim_orig.grid)
  265 
! 266             update_sim = sim_orig.updated_copy(
  267                 structures=[
  268                     sim_orig.structures[idx].updated_copy(geometry=replacement_geometry)
  269                     if idx == structure_index
  270                     else sim_orig.structures[idx]

Lines 272-285

  272                 ],
  273                 grid_spec=sim_orig_grid_spec,
  274             )
  275 
! 276             eps_by_f = [
  277                 update_sim.epsilon(box=eps_box, coord_key="centers", freq=f)
  278                 for f in adjoint_frequencies
  279             ]
  280 
! 281             return xr.concat(eps_by_f, dim="f").assign_coords(f=adjoint_frequencies)
  282 
  283         # get chunk size - if None, process all frequencies as one chunk
  284         freq_chunk_size = config.adjoint.solver_freq_chunk_size
  285         n_freqs = len(adjoint_frequencies)

Lines 347-355

  347                 select_adjoint_freqs: typing.Optional[FreqDataArray] = select_adjoint_freqs,
  348                 updated_epsilon_full: typing.Optional[typing.Callable] = updated_epsilon_full,
  349             ) -> ScalarFieldDataArray:
  350                 # Get permittivity function for a subset of frequencies
! 351                 return updated_epsilon_full(replacement_geometry).sel(f=select_adjoint_freqs)
  352 
  353             common_kwargs = {
  354                 "E_der_map": E_der_map_chunk,
  355                 "D_der_map": D_der_map_chunk,

Lines 384-398

  384                 vjp_chunk = structure._compute_derivatives(derivative_info_struct, vjp_fns=vjp_fns)
  385 
  386                 for path, value in vjp_chunk.items():
  387                     if path in vjp_value_map:
! 388                         existing = vjp_value_map[path]
! 389                         if isinstance(existing, (list, tuple)) and isinstance(value, (list, tuple)):
! 390                             vjp_value_map[path] = type(existing)(
  391                                 x + y for x, y in zip(existing, value)
  392                             )
  393                         else:
! 394                             vjp_value_map[path] = existing + value
  395                     else:
  396                         vjp_value_map[path] = value
  397 
  398         # store vjps in output map

Copy link
Collaborator

@yaugenst-flex yaugenst-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @groberts-flex this is very nice to have! As discussed, I think we should change the name to "custom" instead of "user" VJP. Left a couple of other comments but overall looks good!

@groberts-flex groberts-flex force-pushed the groberts-flex/custom_user_vjp_hook_FXC-3730 branch 3 times, most recently from 116571c to 3a444ca Compare December 2, 2025 19:49
@groberts-flex
Copy link
Contributor Author

when you guys get a chance to take another look at this, it would be much appreciated! I rebased the changes I made last week, so should be ready to go if things look good to you all

Copy link
Contributor

@marcorudolphflex marcorudolphflex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks much better now. Great to have that feature!
Just a few code styling comments/questions from my side.

CHANGELOG.md Outdated
- Added more RF-specific mode characteristics to `MicrowaveModeData`, including propagation constants (alpha, beta, gamma), phase/group velocities, wave impedance, and automatic mode classification with configurable polarization thresholds in `MicrowaveModeSpec`.
- Introduce `tidy3d.rf` namespace to consolidate all RF classes.
- Added support for custom colormaps in `plot_field`.
- Added `custom_vjp` and new custom run functions that provide hooks into adjoint for custom gradient calculations.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to unreleased

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!



__all__ = [
"SetupRunResult",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add CustomVJPConfig or remove __all__ entirely as it defaults?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, will remove all here. don't need either of these public facing


# compute derivatives for chunk
vjp_chunk = structure._compute_derivatives(derivative_info)
common_kwargs = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't use dict here if not necessary - or is it?
Would be better for linting/IDE error detection...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point, it's not necessary! I removed it!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point, it's not necessary! I added them back as explicit keyword args.


@dataclass
class CustomVJPConfig:
structure_index: typing.Union[int, type[GeometryType], type[MediumType]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better naming here if also types are supported?
just structure? structure_index_or_type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, good point! I changed this to structure

aux_data_dict: dict[str, dict[str, typing.Any]],
local_gradient: bool,
max_num_adjoint_per_fwd: int,
custom_vjp: typing.Optional[dict[str, typing.Sequence[CustomVJPConfig]]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dict value can also be of type CustomVJPConfig(no Sequence?)? see L1263

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think actually it should only be a Sequence and L1263 is unnecessary because _run_async_primitive only accepts custom_vjp: typing.Optional[dict[str, typing.Sequence[CustomVJPConfig]]]
and so this backward call from the custom vjp should get that same vjp argument passed to it

lazy = True if lazy is None else bool(lazy)

def validate_and_expand(
fn_arg: CustomVJPConfig,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this the same type as for custom_vjp in run_async_custom?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, good catch!

pay_type: typing.Union[PayType, str] = PayType.AUTO,
priority: typing.Optional[int] = None,
lazy: typing.Optional[bool] = None,
custom_vjp: typing.Optional[
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make own type for this? CustomVJPSpec?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point! added a type for this

lazy: typing.Optional[bool] = None,
) -> BatchData:
"""Wrapper for run_async_custom for usage without custom_vjp for public facing API."""
return run_async_custom(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we don't have the run_async_custom code in here? Is it since we do not want to have custom_vjp in the public API?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for run and run_custom

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, the reasoning was so that the custom version stays out of the public api. eventually we could expose it but for now, we are just keeping the public interface the same

@pytest.mark.parametrize("use_task_names", [True, False])
@pytest.mark.parametrize(
"use_single_custom_vjp, specify_custom_vjp_by_type",
[(True, True), (True, False), (False, False)],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't make [False, True] also sense here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I think so! trying to remember why I took that config out, but can't remember! adding it back in will make sure the tests still pass

Copy link
Collaborator

@yaugenst-flex yaugenst-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great don't have too much to add!

expanded_custom_vjp.append(updated_vjp_config)

elif isinstance(vjp_config.structure_index, type) and issubclass(
custom_vjp.structure_index, allowed_classes_medium
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering about custom_vjp.structure_index, wouldn't this be an AttributeError?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, great catch. that should mirror the geometry check above and be vjp_config instead of custom_vjp!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dead code?

@groberts-flex groberts-flex force-pushed the groberts-flex/custom_user_vjp_hook_FXC-3730 branch from 3a444ca to ca5e293 Compare December 9, 2025 21:51
…p arguments to provide hook into gradient computation for custom vjp calculation.
@groberts-flex groberts-flex force-pushed the groberts-flex/custom_user_vjp_hook_FXC-3730 branch from ca5e293 to fe881ee Compare December 9, 2025 21:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants