Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions src/winml/modelkit/commands/perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,8 @@ def _perf_modules(
ep_options: dict[str, str] | None = None,
precision: str = "auto",
allow_unsupported_nodes: bool = False,
rebuild: bool = False,
ignore_cache: bool = False,
) -> None:
"""Run per-module build and benchmark for matching submodules.

Expand Down Expand Up @@ -863,13 +865,21 @@ def _perf_modules(
precision: Precision mode passed through to the build stage.
allow_unsupported_nodes: If True, warn instead of failing the build when
the analyzer reports unsupported nodes that persist.
rebuild: If True, overwrite cached per-module artifacts and re-run the
build (mirrors the single-model ``--rebuild``).
ignore_cache: If True, build each module in a throwaway temp dir and
always rebuild, discarding artifacts afterward (mirrors the
single-model ``--ignore-cache``).
"""
import contextlib
import difflib
import json as json_mod
import tempfile

from ..build import build_hf_model
from ..cache import get_cache_dir, get_cache_key, get_model_dir
from ..config import SubmoduleClassNotFoundError, generate_hf_build_config
from ..loader.task import get_task_abbrev
from ..sysinfo import resolve_device, resolve_eps
from .build import _instantiate_parent_model

Expand Down Expand Up @@ -932,6 +942,18 @@ def _perf_modules(
parent_loader_cfg, _, _, _resolution = resolve_loader_config(model_id=hf_model, task=task)
parent_model = _instantiate_parent_model(model_type, task=parent_loader_cfg.task)

# Cache control mirrors auto.py / the single-model path:
# --ignore-cache -> build each module in a throwaway temp dir, always
# rebuild, discard afterward
# --rebuild -> reuse the persistent model dir but overwrite artifacts
# Each module's cache_key folds in loader.module_path (and its I/O shapes),
# so sibling instances of the same class get distinct keys and coexist in
# the shared model dir without colliding.
use_cache = not ignore_cache
force_rebuild = rebuild or ignore_cache
task_abbrev = get_task_abbrev(parent_loader_cfg.task) if parent_loader_cfg.task else "module"
cache_model_dir = get_model_dir(hf_model, cache_dir=get_cache_dir()) if use_cache else None

all_results: list[dict[str, Any]] = []
for i, cfg in enumerate(module_configs):
module_path = cfg.loader.module_path
Expand Down Expand Up @@ -960,12 +982,26 @@ def _perf_modules(
if no_compile:
cfg.compile = None

with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
# Compute the cache key AFTER the quant/compile mutations above so it
# reflects what is actually built.
cache_key = get_cache_key(task_abbrev, cfg.generate_cache_key())

# Persistent model dir (reused across runs) when caching, else a
# throwaway temp dir that is removed when the with-block exits.
build_dir_ctx: Any = (
contextlib.nullcontext(cache_model_dir)
if use_cache
else tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
)
with build_dir_ctx as build_dir_raw:
build_dir = Path(build_dir_raw)
try:
build_result = build_hf_model(
config=cfg,
output_dir=Path(tmpdir),
output_dir=build_dir,
pytorch_model=submodule,
rebuild=force_rebuild,
cache_key=cache_key,
ep=ep,
device=resolved_device,
allow_unsupported_nodes=allow_unsupported_nodes,
Expand Down Expand Up @@ -1681,6 +1717,8 @@ def perf(
ep_options=ep_provider_options,
precision=precision.lower(),
allow_unsupported_nodes=allow_unsupported_nodes,
rebuild=rebuild,
ignore_cache=ignore_cache,
)
return

Expand Down
110 changes: 110 additions & 0 deletions tests/unit/commands/test_perf_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,3 +543,113 @@ def test_no_quantize_clears_only_quant(self, tmp_path: Path) -> None:
cfg = self._run(tmp_path, ["--no-quantize", "--compile"])
assert cfg.quant is None
assert cfg.compile is not None


class TestPerfModuleCache:
"""--rebuild / --ignore-cache control the per-module build cache the same
way they do for the single-model path (mirrors auto.py).

Regression guard: per-module builds previously always used a throwaway
temp dir and never passed rebuild/cache_key, so artifacts were rebuilt
every run and the cache flags were silently ignored.
"""

@staticmethod
def _run_build_kwargs(tmp_path: Path, extra_args: list[str]) -> dict:
"""Invoke ``perf --module`` and return the build_hf_model call kwargs.

get_cache_dir is pinned to a known directory so the resolved
persistent build dir is deterministic. The benchmark is short-circuited
via a failing ``session.perf()`` — build_hf_model is already called by
then, so its kwargs are captured.
"""
cache_root = tmp_path / "cache"

fake_cfg = MagicMock()
fake_cfg.loader.model_type = "bert"
fake_cfg.loader.module_path = "encoder.layer.0"
fake_cfg.generate_cache_key.return_value = "deadbeefdeadbeef"

fake_build_result = MagicMock()
fake_build_result.final_onnx_path = tmp_path / "model.onnx"

fake_session = MagicMock()
fake_session.perf.side_effect = RuntimeError("test-skip-benchmark")

fake_loader_cfg = MagicMock()
fake_loader_cfg.task = "fill-mask"

with (
patch(
"winml.modelkit.sysinfo.resolve_device",
return_value=("cpu", ["cpu"]),
),
patch(
"winml.modelkit.config.generate_hf_build_config",
return_value=[fake_cfg],
),
patch(
"winml.modelkit.loader.resolve_loader_config",
return_value=(fake_loader_cfg, MagicMock(), MagicMock(), MagicMock()),
),
patch(
"winml.modelkit.commands.build._instantiate_parent_model",
return_value=MagicMock(),
),
patch(
"winml.modelkit.build.build_hf_model",
return_value=fake_build_result,
) as mock_build,
patch(
"winml.modelkit.session.WinMLSession",
return_value=fake_session,
),
# Pin the cache root so the resolved persistent build dir is
# deterministic. Patch the source attribute — _perf_modules binds
# the name via a function-local `from ..cache import get_cache_dir`.
patch(
"winml.modelkit.cache.get_cache_dir",
return_value=cache_root,
),
):
runner = CliRunner()
result = runner.invoke(
main,
[
"perf",
"-m",
"fake/model",
"--module",
"BertLayer",
"--iterations",
"1",
"--warmup",
"0",
"-o",
str(tmp_path / "out.json"),
*extra_args,
],
)
assert result.exit_code == 0, result.output
return dict(mock_build.call_args.kwargs)

def test_default_uses_persistent_cache_no_rebuild(self, tmp_path: Path) -> None:
kwargs = self._run_build_kwargs(tmp_path, [])
# Builds into the model's persistent cache dir (under the pinned root),
# not a temp dir, and does not force a rebuild.
assert kwargs["rebuild"] is False
assert (tmp_path / "cache") in kwargs["output_dir"].parents
# cache_key disambiguates instances within the shared model dir.
assert kwargs["cache_key"]

def test_rebuild_forces_rebuild_in_cache_dir(self, tmp_path: Path) -> None:
kwargs = self._run_build_kwargs(tmp_path, ["--rebuild"])
# Reuses the persistent cache dir but overwrites artifacts.
assert kwargs["rebuild"] is True
assert (tmp_path / "cache") in kwargs["output_dir"].parents

def test_ignore_cache_uses_temp_dir_and_rebuilds(self, tmp_path: Path) -> None:
kwargs = self._run_build_kwargs(tmp_path, ["--ignore-cache"])
# Throwaway temp dir (outside the pinned cache root) + forced rebuild.
assert kwargs["rebuild"] is True
assert (tmp_path / "cache") not in kwargs["output_dir"].parents
Loading