Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions openfeature/provider/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def __init__(self) -> None:
def set_provider(self, domain: str, provider: FeatureProvider) -> None:
if provider is None:
raise GeneralError(error_message="No provider")
if domain is None:
raise GeneralError(error_message="No domain")
providers = self._providers
if domain in providers:
old_provider = providers[domain]
Expand Down
118 changes: 118 additions & 0 deletions tests/provider/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from unittest.mock import Mock

from openfeature.exception import GeneralError
from openfeature.provider._registry import ProviderRegistry
from openfeature.provider.no_op_provider import NoOpProvider


def test_registry_serves_noop_as_default():
registry = ProviderRegistry()

assert isinstance(registry.get_default_provider(), NoOpProvider)
assert isinstance(registry.get_provider("unknown domain"), NoOpProvider)


def test_setting_provider_requires_domain():
registry = ProviderRegistry()

try:
registry.set_provider(None, NoOpProvider()) # type: ignore[reportArgumentType]
raise AssertionError(
"Expected an exception when setting provider with None domain"
)
except GeneralError as e:
assert e.error_message == "No domain"
except Exception as e:
raise AssertionError("Expected GeneralError, got {type(e)}") from e
Comment thread
keelerm84 marked this conversation as resolved.
Outdated


def test_setting_provider_requires_provider():
registry = ProviderRegistry()

try:
registry.set_provider("domain", None) # type: ignore[reportArgumentType]
raise AssertionError("Expected an exception when setting provider to None")
except GeneralError as e:
assert e.error_message == "No provider"
except Exception as e:
raise AssertionError("Expected GeneralError, got {type(e)}") from e
Comment thread
keelerm84 marked this conversation as resolved.
Outdated


def test_can_register_provider_to_multiple_domains():
registry = ProviderRegistry()
provider = NoOpProvider()

registry.set_provider("domain1", provider)
registry.set_provider("domain2", provider)

assert registry.get_provider("domain1") is provider
assert registry.get_provider("domain2") is provider


def test_registering_provider_replaces_previous_provider():
"""Test that registering a provider replaces the previous provider and calls shutdown on the old one."""

registry = ProviderRegistry()
provider1 = Mock()
provider2 = Mock()

registry.set_provider("domain", provider1)
assert registry.get_provider("domain") is provider1

registry.set_provider("domain", provider2)
assert registry.get_provider("domain") is provider2

provider1.shutdown.assert_called_once()
provider2.shutdown.assert_not_called()


def test_registering_provider_for_first_time_initializes_it():
"""Test that registering a provider for the first time calls its initialize method."""

registry = ProviderRegistry()
provider = Mock()

registry.set_provider("domain1", provider)
registry.set_provider("domain2", provider)

provider.initialize.assert_called_once()


def test_setting_default_provider_requires_provider():
registry = ProviderRegistry()

try:
registry.set_default_provider(None) # type: ignore[reportArgumentType]
raise AssertionError(
"Expected an exception when setting default provider to None"
)
except GeneralError as e:
assert e.error_message == "No provider"
except Exception as e:
raise AssertionError("Expected GeneralError, got {type(e)}") from e
Comment thread
keelerm84 marked this conversation as resolved.
Outdated


def test_replacing_default_provider_shuts_down_old_one():
"""Test that replacing the default provider shuts down the old default provider."""

registry = ProviderRegistry()
default_provider1 = Mock()
default_provider2 = Mock()

registry.set_default_provider(default_provider1)
assert registry.get_default_provider() is default_provider1

registry.set_default_provider(default_provider2)
assert registry.get_default_provider() is default_provider2

default_provider1.shutdown.assert_called_once()
default_provider2.shutdown.assert_not_called()


def test_setting_default_provider_initializes_it():
registry = ProviderRegistry()
provider = Mock()

registry.set_default_provider(provider)

provider.initialize.assert_called_once()