Expose n_recycle as inference sampler parameter #259
Ubiquinone-dot merged 2 commits into production
Pull request overview
This PR exposes n_recycle as an inference-time sampler/config parameter for RFD3 so users can override the checkpoint default recycling count via CLI, and updates related documentation/configuration. It also includes a repo-wide ruff format pass that primarily reformats assertions and a few expressions without changing behavior.
Changes:
- Add `inference_sampler.n_recycle` (nullable) and thread it through the RFD3 inference sampler into the diffusion module's recycle loop.
- Update RFD3 inference config (`rfdiffusion3.yaml`) and docs (`docs/input.md`, README) to surface `n_recycle` and `num_timesteps`.
- Apply `ruff format` across multiple packages/tests (mostly assert formatting).
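As a minimal sketch of the override behaviour described in the first change (not the code from this PR), a nullable setting can be resolved against the checkpoint default before the recycle loop runs. The names `resolve_n_recycle` and `run_recycle_loop` are hypothetical and exist only for this illustration.

```python
from __future__ import annotations

from typing import Callable, TypeVar

State = TypeVar("State")


def resolve_n_recycle(override: int | None, checkpoint_default: int) -> int:
    """Return the override when one is given, otherwise the checkpoint default.

    Hypothetical helper: `None` means "no override was supplied".
    """
    # Compare against None explicitly so a value such as 0 is not silently
    # swapped for the default (which `override or checkpoint_default` would do).
    return checkpoint_default if override is None else override


def run_recycle_loop(step: Callable[[State], State], state: State, n_recycle: int) -> State:
    """Apply `step` to `state` exactly `n_recycle` times and return the result."""
    for _ in range(n_recycle):
        state = step(state)
    return state


if __name__ == "__main__":
    # With no override, the default of 2 is used.
    n = resolve_n_recycle(None, checkpoint_default=2)
    print(run_recycle_loop(lambda s: s + 1, 0, n))
```

Testing `override is None` rather than relying on truthiness keeps an invalid value visible to later validation instead of masking it with the default.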
Reviewed changes
Copilot reviewed 63 out of 63 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| src/foundry/utils/logging.py | Ruff formatting (assert formatting). |
| src/foundry/utils/ddp.py | Ruff formatting (assert formatting). |
| src/foundry/utils/datasets.py | Ruff formatting (assert formatting). |
| src/foundry/utils/alignment.py | Ruff formatting (assert formatting). |
| src/foundry/trainers/fabric.py | Ruff formatting (assert formatting). |
| src/foundry/metrics/metric.py | Ruff formatting (assert formatting). |
| src/foundry/metrics/losses.py | Ruff formatting (assert formatting). |
| src/foundry/inference_engines/base.py | Ruff formatting + touched assertion messages. |
| src/foundry/callbacks/metrics_logging.py | Ruff formatting (assert formatting). |
| models/rfd3/tests/transforms/test_pipeline_regression.py | Ruff formatting (assert formatting). |
| models/rfd3/tests/test_unindexing.py | Ruff formatting (assert formatting). |
| models/rfd3/tests/test_tokenization.py | Ruff formatting (assert formatting). |
| models/rfd3/tests/test_symmetry.py | Ruff formatting (assert formatting). |
| models/rfd3/tests/test_selections.py | Ruff formatting (assert formatting). |
| models/rfd3/tests/test_metrics.py | Ruff formatting (assert formatting). |
| models/rfd3/tests/test_glycines.py | Ruff formatting (assert formatting). |
| models/rfd3/tests/test_conditioning.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/utils/io.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/utils/inference.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/transforms/virtual_atoms.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/transforms/util_transforms.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/transforms/ppi_transforms.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/transforms/pipelines.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/transforms/conditioning_base.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/trainer/fabric_trainer.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/trainer/dump_validation_structures.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/train.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/model/RFD3_diffusion_module.py | Allow inference-time override of n_recycle when provided. |
| models/rfd3/src/rfd3/model/layers/chunked_pairwise.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/model/layers/block_utils.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/model/inference_sampler.py | Add n_recycle to sampler config and pass into diffusion module calls. |
| models/rfd3/src/rfd3/metrics/losses.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/metrics/hbonds_metrics.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/inference/symmetry/symmetry_utils.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/inference/symmetry/frames.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/inference/symmetry/checks.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/inference/legacy_input_parsing.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/inference/input_parsing.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/inference/datasets.py | Ruff formatting (assert formatting). |
| models/rfd3/src/rfd3/callbacks.py | Ruff formatting (assert formatting). |
| models/rfd3/README.md | Make InputSpecification reference more prominent and link to docs/input.md. |
| models/rfd3/docs/input.md | Document num_timesteps and n_recycle under CLI options. |
| models/rfd3/configs/inference_engine/rfdiffusion3.yaml | Add inference_sampler.n_recycle: null to allow CLI override. |
| models/rf3/tests/test_inference_regression.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/validate.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/utils/predicted_error.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/utils/io.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/train.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/symmetry/resolve.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/model/RF3.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/model/layers/mlff.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/model/layers/attention.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/diffusion_samplers/inference_sampler.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/data/pipelines.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/data/ground_truth_template.py | Ruff formatting (line wrapping). |
| models/rf3/src/rf3/data/extra_xforms.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/callbacks/metrics_logging.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/callbacks/dump_validation_structures.py | Ruff formatting (assert formatting). |
| models/rf3/src/rf3/alignment.py | Ruff formatting (assert formatting). |
| models/mpnn/tests/test_utils.py | Ruff formatting (assert formatting). |
| models/mpnn/tests/test_samplers.py | Ruff formatting (assert formatting). |
| models/mpnn/tests/test_pipeline.py | Ruff formatting (assert formatting / wrapping). |
| models/mpnn/tests/test_integration.py | Ruff formatting (assert formatting). |
    # Recycling
    n_recycle: int | None = None  # Override model default n_recycle for inference
n_recycle is user-configurable now, but there’s no validation that it’s a positive integer. If a user sets inference_sampler.n_recycle=0 (or a negative / non-int), forward_with_recycle() will run zero iterations and downstream code will error when expected outputs are missing. Add an explicit validation (e.g., in a __post_init__ for SampleDiffusionConfig or in ConditionalDiffusionSampler.__init__) to require n_recycle is None or n_recycle >= 1 (and ideally an int).
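One way the suggested check could look is sketched below, assuming the config is a dataclass. `SamplerConfigSketch` is a hypothetical stand-in that models only the `n_recycle` field; it is not the real `SampleDiffusionConfig`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class SamplerConfigSketch:
    """Hypothetical stand-in for the sampler config; only n_recycle is modeled."""

    n_recycle: int | None = None  # None means "use the model default"

    def __post_init__(self) -> None:
        if self.n_recycle is None:
            return
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(self.n_recycle, bool) or not isinstance(self.n_recycle, int):
            raise TypeError(
                f"n_recycle must be an int or None, got {type(self.n_recycle).__name__}"
            )
        if self.n_recycle < 1:
            raise ValueError(f"n_recycle must be >= 1, got {self.n_recycle}")
```

Raising at construction time surfaces a bad CLI value before any sampling starts, rather than as a missing-output error deep inside the recycle loop.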
    assert name in REGISTERED_CHECKPOINTS, (
        "Checkpoint provided not and not in registered checkpoints"
    )
The assertion message has broken grammar ("Checkpoint provided not and not in registered checkpoints"), which makes this failure hard to understand. Please reword it to clearly indicate the checkpoint name isn’t registered and (optionally) list valid keys.
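A possible rewording is sketched here, under the assumption that the registry is a dict keyed by checkpoint name. The helper names and the example registry are hypothetical and are not taken from the repository.

```python
from typing import Mapping


def unregistered_checkpoint_message(name: str, registry: Mapping[str, object]) -> str:
    """Build an error message that names the bad key and lists the valid ones."""
    return (
        f"Checkpoint {name!r} is not a registered checkpoint. "
        f"Valid names: {sorted(registry)}"
    )


def check_registered(name: str, registry: Mapping[str, object]) -> None:
    """Assert that `name` is a key of `registry`, with a readable failure message."""
    assert name in registry, unregistered_checkpoint_message(name, registry)


if __name__ == "__main__":
    # Hypothetical registry contents, for illustration only.
    example_registry = {"example_a": "path/to/a", "example_b": "path/to/b"}
    check_registered("example_a", example_registry)  # passes silently
    print(unregistered_checkpoint_message("missing", example_registry))
```

Because `assert` statements are skipped when Python runs with `-O`, a check that must always fire may be better expressed as an explicit `raise ValueError(...)`.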
The n_recycle parameter was previously hardcoded in the diffusion module config and not overridable at inference time. This exposes it through the inference sampler so users can control recycling iterations via CLI (e.g. inference_sampler.n_recycle=3).

Also adds num_timesteps and n_recycle to the "Other CLI Options" docs section, and makes the InputSpecification reference more prominent in the README.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
7291158 to f49348f
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Summary
- Expose `n_recycle` through the inference sampler config so users can override the model default (2) at inference time via CLI: `inference_sampler.n_recycle=3`
- Add `num_timesteps` and `n_recycle` to the "Other CLI Options" section in `docs/input.md` (they were missing from that list)
- Make the InputSpecification reference more prominent in the README (`[!NOTE]` to `[!IMPORTANT]`, added direct link to `docs/input.md`)
- Apply a `ruff format` pass (pre-commit hook required it)

Test plan
- Run inference with defaults (no `n_recycle` override): should behave identically to before (uses model default of 2)
- Run with `inference_sampler.n_recycle=1`: verify fewer recycles are used (check logs)
- Run with `inference_sampler.n_recycle=3`: verify more recycles are used

🤖 Generated with Claude Code