Add user_vjp hook and custom run functions to allow overriding the internal vjp #3015
base: develop
Conversation
13 files reviewed, 6 comments
Diff Coverage (origin/develop...HEAD, staged and unstaged changes)

Summary

tidy3d/web/api/autograd/autograd.py, lines 472–503 (`!` marks lines flagged by the coverage report; line 494 is not shown in it):

```
if fn_arg is None:
    return fn_arg

if isinstance(fn_arg, base_type):
!   expanded = dict.fromkeys(sim_dict.keys(), fn_arg)
!   return expanded

expanded = {}
if not isinstance(fn_arg, type(orig_sim_arg)):
!   raise AdjointError(
        f"{fn_arg_name} type ({type(fn_arg)}) should match simulations type ({type(simulations)})"
    )

if isinstance(orig_sim_arg, dict):
    check_keys = fn_arg.keys() == sim_dict.keys()

    if not check_keys:
!       raise AdjointError(f"{fn_arg_name} keys do not match simulations keys")

    for key, val in fn_arg.items():
        if isinstance(val, base_type):
            expanded[key] = (val,)
        # ... (line 494 not shown in the report)
            expanded[key] = val

elif isinstance(orig_sim_arg, (list, tuple)):
    if not (len(fn_arg) == len(orig_sim_arg)):
!       raise AdjointError(
            f"{fn_arg_name} is not the same length as simulations ({len(expanded)} vs. {len(simulations)})"
        )

    for idx, key in enumerate(sim_dict.keys()):
```

tidy3d/web/api/autograd/autograd.py, lines 1142–1150:

```
# Compute VJP contribution
task_user_vjp = user_vjp.get(task_name)
if isinstance(task_user_vjp, UserVJPConfig):
!   task_user_vjp = (task_user_vjp,)

vjp_results[adj_task_name] = postprocess_adj(
    sim_data_adj=sim_data_adj,
    sim_data_orig=sim_data_orig,
```

tidy3d/web/api/autograd/backward.py, lines 259–285 (line 271 is not shown):

```
) -> ScalarFieldDataArray:
    # Return the simulation permittivity for eps_box after replacing the geometry
    # for this structure with a new geometry. This is helpful for carrying out finite
    # difference permittivity computations
!   sim_orig = sim_data_orig.simulation
!   sim_orig_grid_spec = td.components.grid.grid_spec.GridSpec.from_grid(sim_orig.grid)

!   update_sim = sim_orig.updated_copy(
        structures=[
            sim_orig.structures[idx].updated_copy(geometry=replacement_geometry)
            if idx == structure_index
            else sim_orig.structures[idx]
            # ... (line 271 not shown in the report)
        ],
        grid_spec=sim_orig_grid_spec,
    )

!   eps_by_f = [
        update_sim.epsilon(box=eps_box, coord_key="centers", freq=f)
        for f in adjoint_frequencies
    ]

!   return xr.concat(eps_by_f, dim="f").assign_coords(f=adjoint_frequencies)

# get chunk size - if None, process all frequencies as one chunk
freq_chunk_size = config.adjoint.solver_freq_chunk_size
n_freqs = len(adjoint_frequencies)
```

tidy3d/web/api/autograd/backward.py, lines 347–355:

```
    select_adjoint_freqs: typing.Optional[FreqDataArray] = select_adjoint_freqs,
    updated_epsilon_full: typing.Optional[typing.Callable] = updated_epsilon_full,
) -> ScalarFieldDataArray:
    # Get permittivity function for a subset of frequencies
!   return updated_epsilon_full(replacement_geometry).sel(f=select_adjoint_freqs)

common_kwargs = {
    "E_der_map": E_der_map_chunk,
    "D_der_map": D_der_map_chunk,
```

tidy3d/web/api/autograd/backward.py, lines 384–398:

```
vjp_chunk = structure._compute_derivatives(derivative_info_struct, vjp_fns=vjp_fns)

for path, value in vjp_chunk.items():
    if path in vjp_value_map:
!       existing = vjp_value_map[path]
!       if isinstance(existing, (list, tuple)) and isinstance(value, (list, tuple)):
!           vjp_value_map[path] = type(existing)(
                x + y for x, y in zip(existing, value)
            )
        else:
!           vjp_value_map[path] = existing + value
    else:
        vjp_value_map[path] = value

# store vjps in output map
```
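The chunk-accumulation loop in backward.py can be exercised in isolation. Below is a minimal, self-contained sketch of that pattern, with plain floats and tuples standing in for the real gradient values (the real code operates on autograd tracers and data arrays):

```python
def accumulate_vjp(vjp_value_map: dict, vjp_chunk: dict) -> dict:
    """Merge one frequency chunk's gradients into the running map,
    summing element-wise when both values are sequences (mirrors the
    loop in backward.py lines 386-396)."""
    for path, value in vjp_chunk.items():
        if path in vjp_value_map:
            existing = vjp_value_map[path]
            if isinstance(existing, (list, tuple)) and isinstance(value, (list, tuple)):
                # preserve the container type: a list stays a list, a tuple stays a tuple
                vjp_value_map[path] = type(existing)(
                    x + y for x, y in zip(existing, value)
                )
            else:
                vjp_value_map[path] = existing + value
        else:
            vjp_value_map[path] = value
    return vjp_value_map


vjps = {}
accumulate_vjp(vjps, {("geometry", "size"): (1.0, 2.0), ("medium", "permittivity"): 0.5})
accumulate_vjp(vjps, {("geometry", "size"): (0.5, 0.5), ("medium", "permittivity"): 0.25})
# vjps now holds the element-wise sums across both chunks
```

This is why the first-seen value is stored as-is and only subsequent chunks trigger the addition branch: the map doubles as both accumulator and initializer.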
yaugenst-flex left a comment
Thanks @groberts-flex this is very nice to have! As discussed, I think we should change the name to "custom" instead of "user" VJP. Left a couple of other comments but overall looks good!
Force-pushed from 116571c to 3a444ca
When you guys get a chance to take another look at this, it would be much appreciated! I rebased the changes I made last week, so it should be ready to go if things look good to you all.
marcorudolphflex left a comment
Looks much better now. Great to have that feature!
Just a few code styling comments/questions from my side.
CHANGELOG.md (Outdated)

```markdown
- Added more RF-specific mode characteristics to `MicrowaveModeData`, including propagation constants (alpha, beta, gamma), phase/group velocities, wave impedance, and automatic mode classification with configurable polarization thresholds in `MicrowaveModeSpec`.
- Introduce `tidy3d.rf` namespace to consolidate all RF classes.
- Added support for custom colormaps in `plot_field`.
- Added `custom_vjp` and new custom run functions that provide hooks into adjoint for custom gradient calculations.
```
move to unreleased
thanks!
tidy3d/web/api/autograd/types.py (Outdated)

```python
__all__ = [
    "SetupRunResult",
```
add `CustomVJPConfig`, or remove `__all__` entirely since it defaults?
thanks, will remove `__all__` here. don't need either of these public facing
tidy3d/web/api/autograd/backward.py (Outdated)

```python
# compute derivatives for chunk
vjp_chunk = structure._compute_derivatives(derivative_info)
common_kwargs = {
```
wouldn't use dict here if not necessary - or is it?
Would be better for linting/IDE error detection...
good point, it's not necessary! I removed it!
good point, it's not necessary! I added them back as explicit keyword args.
tidy3d/web/api/autograd/types.py (Outdated)

```python
@dataclass
class CustomVJPConfig:
    structure_index: typing.Union[int, type[GeometryType], type[MediumType]]
```
better naming here if also types are supported?
just structure? structure_index_or_type?
yeah, good point! I changed this to structure
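Since the renamed `structure` field accepts either an index or a geometry type, resolution has to branch on which one was given. A hypothetical sketch of that dispatch (the `Box`/`Sphere`/`FakeStructure` names are illustrative stand-ins, not tidy3d classes; the real resolution lives inside the autograd internals):

```python
from dataclasses import dataclass
from typing import Union


# Illustrative stand-ins for geometry classes and structures.
class Box: ...
class Sphere: ...


@dataclass
class FakeStructure:
    geometry: object


def resolve_structure_indices(structure: Union[int, type], structures: list) -> list:
    """Resolve a CustomVJPConfig-style `structure` field to the list of
    structure indices it applies to: a bare int selects one structure,
    a type selects every structure whose geometry is an instance of it."""
    if isinstance(structure, int):
        return [structure]
    return [i for i, s in enumerate(structures) if isinstance(s.geometry, structure)]


structures = [FakeStructure(Box()), FakeStructure(Sphere()), FakeStructure(Box())]
```

This makes the type-vs-index ambiguity explicit at one choke point rather than scattering `isinstance` checks through the gradient code.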
```python
aux_data_dict: dict[str, dict[str, typing.Any]],
local_gradient: bool,
max_num_adjoint_per_fwd: int,
custom_vjp: typing.Optional[dict[str, typing.Sequence[CustomVJPConfig]]] = None,
```
dict value can also be of type `CustomVJPConfig` (not a Sequence?)? see L1263
I think actually it should only be a Sequence, and L1263 is unnecessary: `_run_async_primitive` only accepts `custom_vjp: typing.Optional[dict[str, typing.Sequence[CustomVJPConfig]]]`, so this backward call from the custom vjp should get that same vjp argument passed to it.
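The normalization under discussion (a bare config wrapped into a one-element tuple, as in the diff at autograd.py line 1145) can be sketched as a small helper. The `UserVJPConfig` here is a minimal stand-in for the real dataclass, not its actual definition:

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Union


@dataclass
class UserVJPConfig:
    """Minimal stand-in for the real config dataclass."""
    structure: Union[int, type]


def normalize_user_vjp(
    value: Optional[Union[UserVJPConfig, Sequence[UserVJPConfig]]],
) -> Optional[Sequence[UserVJPConfig]]:
    """Wrap a bare config into a one-element tuple so downstream code can
    always iterate; None and existing sequences pass through unchanged."""
    if isinstance(value, UserVJPConfig):
        return (value,)
    return value
```

If the dict values were instead restricted to `Sequence[CustomVJPConfig]` up front, as the reply suggests, this wrapping step at the call site would become unnecessary.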
tidy3d/web/api/autograd/autograd.py (Outdated)

```python
lazy = True if lazy is None else bool(lazy)


def validate_and_expand(
    fn_arg: CustomVJPConfig,
```
isn't this the same type as for custom_vjp in run_async_custom?
yeah, good catch!
tidy3d/web/api/autograd/autograd.py (Outdated)

```python
pay_type: typing.Union[PayType, str] = PayType.AUTO,
priority: typing.Optional[int] = None,
lazy: typing.Optional[bool] = None,
custom_vjp: typing.Optional[
```
make own type for this? CustomVJPSpec?
good point! added a type for this
```python
lazy: typing.Optional[bool] = None,
) -> BatchData:
    """Wrapper for run_async_custom for usage without custom_vjp for public facing API."""
    return run_async_custom(
```
Why don't we have the run_async_custom code in here? Is it because we do not want to have custom_vjp in the public API?
same for run and run_custom
yeah, the reasoning was so that the custom version stays out of the public api. eventually we could expose it but for now, we are just keeping the public interface the same
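The thin-wrapper pattern being discussed (a public entry point that delegates to the custom-capable one without exposing `custom_vjp`) can be sketched in a few lines. Names and signatures below are illustrative only, not the actual tidy3d functions:

```python
def run_custom(simulation, *, custom_vjp=None, **kwargs):
    """Internal entry point: accepts the custom_vjp hook.
    Here it just echoes its inputs so the delegation is visible."""
    return {"simulation": simulation, "custom_vjp": custom_vjp, **kwargs}


def run(simulation, **kwargs):
    """Public wrapper: identical behavior, but custom_vjp never appears
    in the public signature, so the public API surface stays unchanged."""
    return run_custom(simulation, custom_vjp=None, **kwargs)
```

The upside is exactly what the reply states: the hook can later be promoted to the public API by widening `run`'s signature, without any internal refactoring.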
```python
@pytest.mark.parametrize("use_task_names", [True, False])
@pytest.mark.parametrize(
    "use_single_custom_vjp, specify_custom_vjp_by_type",
    [(True, True), (True, False), (False, False)],
```
doesn't [False, True] also make sense here?
yeah, I think so! trying to remember why I took that config out, but can't! adding it back in; will make sure the tests still pass
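For reference, the explicit pair list in the test above covers three of the four combinations; splitting the two flags into separate `@pytest.mark.parametrize` decorators would yield the full cross product instead. A quick illustration of the difference:

```python
from itertools import product

# The explicit pair list used in the test:
explicit_pairs = [(True, True), (True, False), (False, False)]

# Two stacked @pytest.mark.parametrize decorators (one per flag)
# would generate the full cross product of both booleans:
full_product = list(product([True, False], repeat=2))

# The combination the explicit list omits:
missing = set(full_product) - set(explicit_pairs)
```

Keeping the explicit list is still reasonable when some combinations are intentionally unsupported; the reviewer's question is whether (False, True) was excluded on purpose.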
yaugenst-flex left a comment
Looks great, don't have too much to add!
tidy3d/web/api/autograd/autograd.py (Outdated)

```python
            expanded_custom_vjp.append(updated_vjp_config)

        elif isinstance(vjp_config.structure_index, type) and issubclass(
            custom_vjp.structure_index, allowed_classes_medium
```
I'm wondering about custom_vjp.structure_index, wouldn't this be an AttributeError?
yes, great catch. that should mirror the geometry check above and be vjp_config instead of custom_vjp!
Dead code?
Force-pushed from 3a444ca to ca5e293
Commit message (truncated): "…p arguments to provide hook into gradient computation for custom vjp calculation."
Force-pushed from ca5e293 to fe881ee
@marcorudolphflex @yaugenst-flex
After the interface reconfiguration, I split the PR for these custom autograd hooks into two so that hopefully it's easier to review! This one is for the user_vjp which allows someone to override the internal vjp calculation for a structure geometry or medium. The other hook is done as well, but I'll save it for after this one is done with review!
Based on the other review, I updated the interface to ideally be a little more straightforward to use and less cumbersome. The specification of paths in the user_vjp is not required unless you want it to only apply to a specific path in the structure. It can also be specified as just a single user_vjp value in run_async_custom if you want the same one to apply to all of the simulations (instead of having to manually broadcast it). I think there are other helper functions that could be added in the future that might make things even easier like applying a certain user_vjp for all structures with a specific geometry type, but I'll leave those for a future upgrade.
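The broadcasting behavior described here (a single `user_vjp` applied to all simulations instead of manually broadcasting it) corresponds to the `dict.fromkeys` branch of `validate_and_expand` in the diff above. A minimal, self-contained sketch of that branch, using an illustrative stand-in class rather than the real `UserVJPConfig`:

```python
class VJPConfig:
    """Illustrative stand-in for UserVJPConfig."""


def expand_user_vjp(user_vjp, sim_dict):
    """Broadcast a single config to every simulation key; None and
    already-expanded dicts pass through unchanged (sketch of the
    dict.fromkeys branch of validate_and_expand)."""
    if user_vjp is None:
        return None
    if isinstance(user_vjp, VJPConfig):
        # the same config object is shared by every task key
        return dict.fromkeys(sim_dict.keys(), user_vjp)
    return user_vjp


cfg = VJPConfig()
expanded = expand_user_vjp(cfg, {"task_a": None, "task_b": None})
```

Note that `dict.fromkeys` shares one config object across all keys rather than copying it, which is fine as long as the config is treated as immutable downstream.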
Greptile Summary

- Adds a `user_vjp` parameter to the autograd run functions, enabling custom gradient calculations for specific structures
- Extends `DerivativeInfo` with an `updated_epsilon` helper for finite difference gradient computations in custom VJPs

Confidence Score: 4/5

- See `tidy3d/components/structure.py` and `tidy3d/web/api/autograd/backward.py` for the docstring formatting issues

Important Files Changed

- `UserVJPConfig` dataclass for custom gradient computation and `SetupRunResult` for run preparation
- `user_vjp` parameter threaded throughout run functions, with validation and broadcasting logic for single/batch simulations
- `updated_epsilon` helper function for finite difference gradient computations
- `_compute_derivatives` updated to accept an optional `vjp_fns` dict for custom gradient paths per geometry/medium field

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant run_custom
    participant _run_primitive
    participant setup_fwd
    participant _run_tidy3d
    participant _run_bwd
    participant postprocess_adj
    participant Structure
    participant UserVJP
    User->>run_custom: "call with simulation and user_vjp"
    run_custom->>_run_primitive: "pass user_vjp to primitive"
    _run_primitive->>setup_fwd: "setup forward simulation"
    setup_fwd-->>_run_primitive: "combined simulation"
    _run_primitive->>_run_tidy3d: "run forward simulation"
    _run_tidy3d-->>_run_primitive: "simulation data"
    Note over _run_bwd: Backward pass triggered
    _run_bwd->>postprocess_adj: "compute gradients with user_vjp"
    postprocess_adj->>postprocess_adj: "build user_vjp_lookup dict"
    postprocess_adj->>Structure: "_compute_derivatives with vjp_fns"
    alt user VJP exists for path
        Structure->>UserVJP: "call user-defined vjp function"
        UserVJP-->>Structure: "custom gradients"
    else default path
        Structure->>Structure: "call internal gradient method"
        Structure-->>Structure: "standard gradients"
    end
    Structure-->>postprocess_adj: "gradient values"
    postprocess_adj-->>_run_bwd: "VJP field map"
    _run_bwd-->>User: "gradients for optimization"
```