-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_registry.py
More file actions
79 lines (62 loc) · 2.54 KB
/
test_registry.py
File metadata and controls
79 lines (62 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import pytest
from contextflow.core.registry import PluginRegistry
from contextflow.mode import mode_registry, MinimalMode, FullMode
from contextflow.compression import compressor_registry, StandardCompressor, DistillationCompressor
from contextflow.provider import provider_registry, MockProvider, OpenAIProvider
from contextflow.sources import source_registry, FileSource
from contextflow.ranking import scorer_registry, TimeDecayScorer
class DummyBase:
    """Empty base type used as the required plugin base in the generic registry tests."""
def test_generic_registry():
    """A class registered via the decorator is listed, resolvable, and instantiable."""
    reg = PluginRegistry("test", DummyBase)

    @reg.register("my_plugin")
    class MyPlugin(DummyBase):
        pass

    registered_names = reg.list_plugins()
    assert "my_plugin" in registered_names

    resolved_cls = reg.get_class("my_plugin")
    assert resolved_cls is MyPlugin

    # get() hands back an instance of the registered class, not the class itself.
    assert isinstance(reg.get("my_plugin"), MyPlugin)
def test_registry_type_checking():
    """Registering a class that does not derive from the registry's base raises TypeError."""
    reg = PluginRegistry("test", DummyBase)

    class InvalidPlugin:  # deliberately NOT a DummyBase subclass
        pass

    # Both obtaining the decorator and applying it happen inside the context,
    # exactly as with the @-syntax form.
    with pytest.raises(TypeError):
        reg.register("invalid")(InvalidPlugin)
def test_mode_registry():
    """Built-in context modes are registered under their expected names and classes."""
    plugins = mode_registry.list_plugins()
    assert "minimal" in plugins
    assert "full" in plugins
    # Resolve each name to its class so a mis-wired registration is caught,
    # not just a missing name.
    assert issubclass(mode_registry.get_class("minimal"), MinimalMode)
    assert issubclass(mode_registry.get_class("full"), FullMode)
def test_compressor_registry():
    """Built-in compressors are registered under their expected names and classes."""
    plugins = compressor_registry.list_plugins()
    assert "standard" in plugins
    assert "distillation" in plugins
    # Resolve each name to its class so a mis-wired registration is caught.
    assert issubclass(compressor_registry.get_class("standard"), StandardCompressor)
    assert issubclass(
        compressor_registry.get_class("distillation"), DistillationCompressor
    )
def test_provider_registry():
    """Built-in providers are registered under their expected names and classes."""
    plugins = provider_registry.list_plugins()
    assert "mock" in plugins
    assert "openai" in plugins
    # get_class only resolves the class (no instantiation), so no API credentials
    # are needed; this catches a name registered against the wrong implementation.
    assert issubclass(provider_registry.get_class("mock"), MockProvider)
    assert issubclass(provider_registry.get_class("openai"), OpenAIProvider)
def test_source_registry():
    """The built-in file source is registered under "file" and resolves to FileSource."""
    plugins = source_registry.list_plugins()
    assert "file" in plugins
    assert issubclass(source_registry.get_class("file"), FileSource)
def test_scorer_registry():
    """The built-in time-decay scorer is registered and resolves to TimeDecayScorer."""
    plugins = scorer_registry.list_plugins()
    assert "time_decay" in plugins
    assert issubclass(scorer_registry.get_class("time_decay"), TimeDecayScorer)
def test_auto_discovery(monkeypatch):
    """Plugins advertised in the ``contextflow.plugins.mode`` entry-point group are discovered.

    ``importlib.metadata.entry_points`` is replaced with a stub that advertises a
    single external mode; ``mode_registry.discover()`` must then register it under
    the entry point's name.
    """
    import importlib.metadata
    from contextflow.core.interfaces import ContextMode

    # discover() mutates the process-wide mode_registry. Swap each mutable
    # container attribute for a copy through monkeypatch so the stub plugin is
    # rolled back at teardown and cannot leak into tests that run later.
    # NOTE(review): assumes PluginRegistry keeps its state in plain dict/list/set
    # instance attributes — confirm against its implementation.
    for attr, value in list(getattr(mode_registry, "__dict__", {}).items()):
        if isinstance(value, (dict, list, set)):
            monkeypatch.setattr(mode_registry, attr, value.copy())

    class MockMode(ContextMode):
        def select(self, messages):
            return messages

    class MockEntryPoint:
        """Minimal stand-in for ``importlib.metadata.EntryPoint``: a name and ``load()``."""

        def __init__(self, name):
            self.name = name

        def load(self):
            return MockMode

    def mock_entry_points(group=None):
        # Only the mode group advertises a plugin; every other group is empty.
        if group == "contextflow.plugins.mode":
            return [MockEntryPoint("mock_external")]
        return []

    monkeypatch.setattr(importlib.metadata, "entry_points", mock_entry_points)
    mode_registry.discover()

    assert "mock_external" in mode_registry.list_plugins()
    assert mode_registry.get_class("mock_external") is MockMode