Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import sys
import types
import unittest
from importlib.metadata import PackageNotFoundError
from unittest import mock

# Absolute path of the repo root (one level above this file's directory);
# worker.py is imported from there.
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
Expand Down Expand Up @@ -201,5 +203,52 @@ def test_warmup_light_when_not_requested(self):
self.assertNotIn("compress_user_messages", self.calls[-1]["kwargs"])


class KompressRepoTest(unittest.TestCase):
    """Version-aware selection of the Kompress weights repo.

    headroom-ai 0.24.0 switched its default weights from kompress-base to
    kompress-v2-base. The worker picks the repo from the installed package
    version, so one worker.py serves both generations, including hosts that
    are part-way through an upgrade.
    """

    def _repo_for(self, version_value):
        """Resolve the repo as if headroom-ai ``version_value`` were installed."""
        fake_version = mock.patch(
            "importlib.metadata.version", return_value=version_value
        )
        with fake_version:
            return worker._kompress_weights_repo()

    def test_legacy_versions_use_kompress_base(self):
        expected = "chopratejas/kompress-base"
        for installed in ("0.23.0", "0.23.5", "0.1.0"):
            self.assertEqual(self._repo_for(installed), expected, installed)

    def test_v024_plus_uses_v2_base(self):
        expected = "chopratejas/kompress-v2-base"
        for installed in ("0.24.0", "0.24.3", "0.25.0", "1.0.0"):
            self.assertEqual(self._repo_for(installed), expected, installed)

    def test_unreadable_version_defaults_to_v2(self):
        # The package metadata lookup itself fails (package not installed).
        with mock.patch(
            "importlib.metadata.version", side_effect=PackageNotFoundError
        ):
            resolved = worker._kompress_weights_repo()
        self.assertEqual(resolved, "chopratejas/kompress-v2-base")

    def test_odd_version_defaults_to_v2(self):
        # An unparseable version string must fall back to the current
        # default instead of raising.
        resolved = self._repo_for("unknown")
        self.assertEqual(resolved, "chopratejas/kompress-v2-base")

    def test_models_cached_tracks_resolved_repo(self):
        def cached_on_v024(cached_ids):
            """Run _models_cached() on a simulated 0.24.0 host whose HF cache
            holds exactly ``cached_ids``."""
            # Stand-in huggingface_hub module, so the real dependency is not
            # needed to run this test.
            hub = types.ModuleType("huggingface_hub")
            entries = [
                types.SimpleNamespace(repo_id=repo_id) for repo_id in cached_ids
            ]
            hub.scan_cache_dir = lambda: types.SimpleNamespace(repos=entries)
            with mock.patch.dict(sys.modules, {"huggingface_hub": hub}):
                with mock.patch(
                    "importlib.metadata.version", return_value="0.24.0"
                ):
                    return worker._models_cached()

        # Both repos that 0.24 needs are cached -> offline mode is safe.
        self.assertTrue(
            cached_on_v024(
                {"answerdotai/ModernBERT-base", "chopratejas/kompress-v2-base"}
            )
        )

        # Upgrade trap: 0.24 is installed but only the OLD weights are cached.
        # Reporting "cached" here would force offline mode and block the v2
        # download, so the answer must be False.
        self.assertFalse(
            cached_on_v024(
                {"answerdotai/ModernBERT-base", "chopratejas/kompress-base"}
            )
        )


# Run the suite when this file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()
36 changes: 32 additions & 4 deletions worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,47 @@
# up as "unauthenticated requests to the HF Hub", adds latency to the first
# request a worker serves, and risks anonymous rate-limiting across a pool. We
# tame it *before* importing headroom (so transformers sees the env at import).
_KOMPRESS_REPOS = ("answerdotai/ModernBERT-base", "chopratejas/kompress-base")
# NOTE(review): this tuple hard-codes chopratejas/kompress-base, which is the
# pre-0.24.0 default weights repo according to _kompress_weights_repo(). It
# looks superseded by _MODERNBERT_REPO + _kompress_weights_repo() - confirm
# that nothing should still depend on it.
#
# ModernBERT is the Kompress tokenizer/encoder base, unchanged across versions.
_MODERNBERT_REPO = "answerdotai/ModernBERT-base"


def _kompress_weights_repo() -> str:
"""HF repo holding the Kompress weights for the *installed* headroom-ai.

The default Kompress model changed in headroom-ai 0.24.0
(chopratejas/kompress-base -> chopratejas/kompress-v2-base). To support
hosts on either version (including ones mid-upgrade that still have only the
old model cached), we resolve the repo from the installed package version
rather than hardcoding one. We read the version via importlib.metadata and
never by importing headroom — importing it here would pull in transformers
before we've set the offline env, which is exactly what this module avoids.

On an unreadable/odd version we assume the current default (v2): the worst
case is then a stale guess that loses the offline optimization, never one
that forces offline against a model the installed version won't load."""
try:
from importlib.metadata import version

major, minor = (int(p) for p in version("headroom-ai").split(".")[:2])
if (major, minor) < (0, 24):
return "chopratejas/kompress-base"
except Exception: # noqa: BLE001 - missing/odd version -> assume current default
pass
return "chopratejas/kompress-v2-base"


def _models_cached() -> bool:
    """True if the Kompress models the installed headroom will load are already
    in the local HF cache, so it's safe to run transformers offline (no network
    needed to load them).

    The required set is the ModernBERT base repo plus whichever weights repo
    _kompress_weights_repo() resolves for the installed headroom-ai version.
    Returns False when huggingface_hub cannot be imported or the cache cannot
    be scanned.
    """
    try:
        # Imported lazily so a missing huggingface_hub just means "not cached".
        from huggingface_hub import scan_cache_dir

        repos = {r.repo_id for r in scan_cache_dir().repos}
    except Exception:  # noqa: BLE001 - hub missing/unscannable -> assume not cached
        return False
    # Check the version-resolved weights repo, not a fixed legacy one: with
    # only the old weights cached on a 0.24+ host this must be False, or
    # offline mode would block the v2 download.
    return {_MODERNBERT_REPO, _kompress_weights_repo()}.issubset(repos)


def _configure_hf_env() -> None:
Expand Down
Loading