diff --git a/src/winml/modelkit/commands/perf.py b/src/winml/modelkit/commands/perf.py index 0f83c871c..e0d5f039d 100644 --- a/src/winml/modelkit/commands/perf.py +++ b/src/winml/modelkit/commands/perf.py @@ -831,6 +831,8 @@ def _perf_modules( ep_options: dict[str, str] | None = None, precision: str = "auto", allow_unsupported_nodes: bool = False, + rebuild: bool = False, + ignore_cache: bool = False, ) -> None: """Run per-module build and benchmark for matching submodules. @@ -863,13 +865,21 @@ def _perf_modules( precision: Precision mode passed through to the build stage. allow_unsupported_nodes: If True, warn instead of failing the build when the analyzer reports unsupported nodes that persist. + rebuild: If True, overwrite cached per-module artifacts and re-run the + build (mirrors the single-model ``--rebuild``). + ignore_cache: If True, build each module in a throwaway temp dir and + always rebuild, discarding artifacts afterward (mirrors the + single-model ``--ignore-cache``). """ + import contextlib import difflib import json as json_mod import tempfile from ..build import build_hf_model + from ..cache import get_cache_dir, get_cache_key, get_model_dir from ..config import SubmoduleClassNotFoundError, generate_hf_build_config + from ..loader.task import get_task_abbrev from ..sysinfo import resolve_device, resolve_eps from .build import _instantiate_parent_model @@ -932,6 +942,18 @@ def _perf_modules( parent_loader_cfg, _, _, _resolution = resolve_loader_config(model_id=hf_model, task=task) parent_model = _instantiate_parent_model(model_type, task=parent_loader_cfg.task) + # Cache control mirrors auto.py / the single-model path: + # --ignore-cache -> build each module in a throwaway temp dir, always + # rebuild, discard afterward + # --rebuild -> reuse the persistent model dir but overwrite artifacts + # Each module's cache_key folds in loader.module_path (and its I/O shapes), + # so sibling instances of the same class get distinct keys and coexist in + # the shared model dir without colliding. + use_cache = not ignore_cache + force_rebuild = rebuild or ignore_cache + task_abbrev = get_task_abbrev(parent_loader_cfg.task) if parent_loader_cfg.task else "module" + cache_model_dir = get_model_dir(hf_model, cache_dir=get_cache_dir()) if use_cache else None + all_results: list[dict[str, Any]] = [] for i, cfg in enumerate(module_configs): module_path = cfg.loader.module_path @@ -960,12 +982,26 @@ def _perf_modules( if no_compile: cfg.compile = None - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir: + # Compute the cache key AFTER the quant/compile mutations above so it + # reflects what is actually built. + cache_key = get_cache_key(task_abbrev, cfg.generate_cache_key()) + + # Persistent model dir (reused across runs) when caching, else a + # throwaway temp dir that is removed when the with-block exits. + build_dir_ctx: Any = ( + contextlib.nullcontext(cache_model_dir) + if use_cache + else tempfile.TemporaryDirectory(ignore_cleanup_errors=True) + ) + with build_dir_ctx as build_dir_raw: + build_dir = Path(build_dir_raw) try: build_result = build_hf_model( config=cfg, - output_dir=Path(tmpdir), + output_dir=build_dir, pytorch_model=submodule, + rebuild=force_rebuild, + cache_key=cache_key, ep=ep, device=resolved_device, allow_unsupported_nodes=allow_unsupported_nodes, @@ -1681,6 +1717,8 @@ def perf( ep_options=ep_provider_options, precision=precision.lower(), allow_unsupported_nodes=allow_unsupported_nodes, + rebuild=rebuild, + ignore_cache=ignore_cache, ) return diff --git a/tests/unit/commands/test_perf_module.py b/tests/unit/commands/test_perf_module.py index 4240897a6..09b5cf319 100644 --- a/tests/unit/commands/test_perf_module.py +++ b/tests/unit/commands/test_perf_module.py @@ -543,3 +543,113 @@ def test_no_quantize_clears_only_quant(self, tmp_path: Path) -> None: cfg = self._run(tmp_path, ["--no-quantize", "--compile"]) assert cfg.quant is None assert cfg.compile is not None + + +class TestPerfModuleCache: + """--rebuild / --ignore-cache control the per-module build cache the same + way they do for the single-model path (mirrors auto.py). + + Regression guard: per-module builds previously always used a throwaway + temp dir and never passed rebuild/cache_key, so artifacts were rebuilt + every run and the cache flags were silently ignored. + """ + + @staticmethod + def _run_build_kwargs(tmp_path: Path, extra_args: list[str]) -> dict: + """Invoke ``perf --module`` and return the build_hf_model call kwargs. + + get_cache_dir is pinned to a known directory so the resolved + persistent build dir is deterministic. The benchmark is short-circuited + via a failing ``session.perf()`` — build_hf_model is already called by + then, so its kwargs are captured. + """ + cache_root = tmp_path / "cache" + + fake_cfg = MagicMock() + fake_cfg.loader.model_type = "bert" + fake_cfg.loader.module_path = "encoder.layer.0" + fake_cfg.generate_cache_key.return_value = "deadbeefdeadbeef" + + fake_build_result = MagicMock() + fake_build_result.final_onnx_path = tmp_path / "model.onnx" + + fake_session = MagicMock() + fake_session.perf.side_effect = RuntimeError("test-skip-benchmark") + + fake_loader_cfg = MagicMock() + fake_loader_cfg.task = "fill-mask" + + with ( + patch( + "winml.modelkit.sysinfo.resolve_device", + return_value=("cpu", ["cpu"]), + ), + patch( + "winml.modelkit.config.generate_hf_build_config", + return_value=[fake_cfg], + ), + patch( + "winml.modelkit.loader.resolve_loader_config", + return_value=(fake_loader_cfg, MagicMock(), MagicMock(), MagicMock()), + ), + patch( + "winml.modelkit.commands.build._instantiate_parent_model", + return_value=MagicMock(), + ), + patch( + "winml.modelkit.build.build_hf_model", + return_value=fake_build_result, + ) as mock_build, + patch( + "winml.modelkit.session.WinMLSession", + return_value=fake_session, + ), + # Pin the cache root so the resolved persistent build dir is + # deterministic. Patch the source attribute — _perf_modules binds + # the name via a function-local `from ..cache import get_cache_dir`. + patch( + "winml.modelkit.cache.get_cache_dir", + return_value=cache_root, + ), + ): + runner = CliRunner() + result = runner.invoke( + main, + [ + "perf", + "-m", + "fake/model", + "--module", + "BertLayer", + "--iterations", + "1", + "--warmup", + "0", + "-o", + str(tmp_path / "out.json"), + *extra_args, + ], + ) + assert result.exit_code == 0, result.output + return dict(mock_build.call_args.kwargs) + + def test_default_uses_persistent_cache_no_rebuild(self, tmp_path: Path) -> None: + kwargs = self._run_build_kwargs(tmp_path, []) + # Builds into the model's persistent cache dir (under the pinned root), + # not a temp dir, and does not force a rebuild. + assert kwargs["rebuild"] is False + assert (tmp_path / "cache") in kwargs["output_dir"].parents + # cache_key disambiguates instances within the shared model dir. + assert kwargs["cache_key"] + + def test_rebuild_forces_rebuild_in_cache_dir(self, tmp_path: Path) -> None: + kwargs = self._run_build_kwargs(tmp_path, ["--rebuild"]) + # Reuses the persistent cache dir but overwrites artifacts. + assert kwargs["rebuild"] is True + assert (tmp_path / "cache") in kwargs["output_dir"].parents + + def test_ignore_cache_uses_temp_dir_and_rebuilds(self, tmp_path: Path) -> None: + kwargs = self._run_build_kwargs(tmp_path, ["--ignore-cache"]) + # Throwaway temp dir (outside the pinned cache root) + forced rebuild. + assert kwargs["rebuild"] is True + assert (tmp_path / "cache") not in kwargs["output_dir"].parents