From ccb934c24c074ad5b98f725b5256b3e2be6d62a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Tue, 24 Feb 2026 17:13:23 +0100 Subject: [PATCH 01/12] Initial implementation --- src/gt4py/eve/utils.py | 39 +++++ src/gt4py/next/__init__.py | 6 + src/gt4py/next/_config.py | 330 +++++++++++++++++++++++++++++++++++++ 3 files changed, 375 insertions(+) create mode 100644 src/gt4py/next/_config.py diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index a0e48ae557..3c909973b6 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -527,6 +527,45 @@ def partial(self, *args: Any, **kwargs: Any) -> fluid_partial: return fluid_partial(self, *args, **kwargs) +if xtyping.TYPE_CHECKING: + + class TypeDispatcher(Generic[_P, _T]): + def __init__(self, func: Callable[_P, _T]) -> None: ... + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ... + + @overload + def register(self, arg_type: type) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... + + @overload + def register(self, arg_type: type, func: Callable[_P, _T]) -> Callable[_P, _T]: ... + + def register( + self, arg_type: type, func: Optional[Callable[_P, _T]] = None + ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]] | Callable[_P, _T]: ... + + +def type_dispatcher(func: Callable[_P, _T]) -> TypeDispatcher[_P, _T]: + """ + Decorator to create a single-dispatch generic function that dispatches of the first argument. + """ + # reuse the singledispatch() dispatching mechanism but change the wrapper + indirect_dispatcher = functools.singledispatch(func) + dispatch = indirect_dispatcher.dispatch + + def wrapper(*args: P.args, **kw: P.kwargs) -> _T: + if not args: + raise TypeError(f"{func.__name__} requires at least 1 positional argument") + + return dispatch(args[0])(*args, **kw) + + functools.update_wrapper(wrapper, func) + for attr in ("dispatch", "register", "registry", "_clear_cache"): + setattr(wrapper, attr, getattr(indirect_dispatcher, attr)) + + return wrapper + + @overload def with_fluid_partial( func: Literal[None] = None, *args: Any, **kwargs: Any diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 4dd5a195b0..3e3d4641cd 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -19,8 +19,13 @@ """ # ruff: noqa: F401 +from __future__ import annotations + from .._core.definitions import CUPY_DEVICE_TYPE, Device, DeviceType, is_scalar_type from . import common, ffront, iterator, program_processors, typing + +# reexport the actual configuration manager instance as a publice attribute +from ._config import config from .common import ( Connectivity, Dimension, @@ -52,6 +57,7 @@ __all__ = [ # submodules "common", + "config", "ffront", "iterator", "program_processors", diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py new file mode 100644 index 0000000000..5036ea3a5c --- /dev/null +++ b/src/gt4py/next/_config.py @@ -0,0 +1,330 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import contextlib +import contextvars +import dataclasses +import enum +import os +import pathlib +import sys +import types +from collections.abc import Callable, Generator +from typing import Any, Final, Generic, Literal, Protocol, TypeVar, cast, final + + +from gt4py.eve import utils +from gt4py.eve.extended_typing import Self + + +@final +class _UNSET_SENTINEL: ... + + +_UNSET: Final = _UNSET_SENTINEL() + +_T = TypeVar("_T") +_T_contra = TypeVar("_T_contra", contravariant=True) + + +@utils.type_dispatcher +def get_value_from_environment_var( + as_type: type[_T], var_name: str, *, default: _T | None = None +) -> _T | None: + """Convert the content of environment variable a typed value.""" + env_value = os.environ.get(var_name, None) + if env_value is None: + return default + try: + return as_type(env_value) + except Exception as e: + raise TypeError( + f"Unsupported conversion of GT4Py environment variable {var_name}: {env_value}) to type '{as_type.__name__}'." + ) from None + + +@get_value_from_environment_var.register(bool) +def _get_value_from_environment_var_as_bool( + as_type: type[bool], var_name: str, *, default: bool | None = None +) -> bool | None: + env_value = os.environ.get(var_name, None) + if env_value is None: + return default + match env_value.upper(): + case "0" | "FALSE" | "OFF": + return False + case "1" | "TRUE" | "ON": + return True + case _: + raise ValueError( + f"Invalid GT4Py environment flag value for {var_name}: use '0 | FALSE | OFF' or '1 | TRUE | ON'." + ) + + +class UpdateScope(str, enum.Enum): + GLOBAL = sys.intern("global") + CONTEXT = sys.intern("context") + + +class OptionUpdateCallback(Protocol[_T_contra]): + def __call__( + self, new_val: _T_contra, old_val: _T_contra | None, scope: UpdateScope + ) -> None: ... + + +ConfigRegistryT = TypeVar("ConfigRegistryT", bound="ConfigManager") + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class OptionDescriptor(Generic[_T, ConfigRegistryT]): + type: type[_T] + default: dataclasses.InitVar[_T | _UNSET_SENTINEL] = _UNSET + default_factory: Callable[[ConfigRegistryT], _T] | None = None + validator: Callable[[Any], Any] | Literal["type_check"] | None = "type_check" + update_callback: OptionUpdateCallback[_T] | None = None + env_prefix: str = "GT4PY_" + name: str = dataclasses.field(init=False) + + def __post_init__(self, default: _T | _UNSET_SENTINEL) -> None: + if self.validator == "type_check": + object.__setattr__(self, "validator", utils.isinstancechecker(self.type)) + assert self.validator is None or callable(self.validator) + + if default is not _UNSET: + if self.default_factory is not None: + raise ValueError( + "Cannot specify both default and default_factory for a config option descriptor." + ) + if self.validator is not None: + self.validator(default) + object.__setattr__(self, "default_factory", lambda _: default) + elif self.default_factory is None: + raise ValueError( + "Must specify either default or default_factory for a config option descriptor." + ) + + def __set_name__(self, owner: type, name: str) -> None: + object.__setattr__(self, "name", name) + + def __get__(self, instance: Any, owner: type | None = None) -> _T | Self: + try: + assert isinstance(instance, ConfigManager) + return instance.get(self.name) + except Exception as e: + if instance is None: + # Accessed on the class, return the descriptor itself (e.g. for help()) + return self + raise AttributeError(f"Error reading config option {self.name!r}") from e + + def __set__(self, instance: Any, value: _T) -> None: + assert isinstance(instance, ConfigManager) + instance.set(self.name, value) + + @property + def env_var_name(self) -> str: + return f"{self.env_prefix}{self.name}".upper() + + +class ConfigManager: + """Central configuration registry with attribute-style access.""" + + def __init__(self) -> None: + self._descriptors: dict[str, OptionDescriptor[Any, Config]] = { + name: attr + for name, attr in type(self).__dict__.items() + if isinstance(attr, OptionDescriptor) + } + self._keys = set(self._descriptors.keys()) + self._validators: dict[str, Callable[[Any], None]] = { + name: desc.validator + for name, desc in self._descriptors.items() + if callable(desc.validator) + } + self._hooks: dict[str, OptionUpdateCallback[Any]] = { + name: desc.update_callback + for name, desc in self._descriptors.items() + if desc.update_callback is not None + } + + # An instance-level ContextVar creates isolated context-local state per manager + # instance. Though discouraged in general (values bind to ContextVar identity + # and Context objects hold strong references to ContextVars, so they won't be + # GC'd even if the instance goes out of scope), in this case we really want + # per-registry isolation and we assume only very few ConfigRegistry instances + # will be ever created. + self._local_context_cvar = contextvars.ContextVar[types.MappingProxyType]( + f"{self.__class__.__name__}_cvar", default=types.MappingProxyType({}) + ) + + self._global_context: dict[str, Any] = {} + for name, desc in self._descriptors.items(): + assert desc.default_factory is not None # Guaranteed by __post_init__ + self._global_context[name] = get_value_from_environment_var( + desc.type, desc.env_var_name, default=desc.default_factory(self) + ) + + def get(self, name: str) -> Any: + if __debug__ and name not in self._keys: + raise AttributeError(f"Unrecognized config option: {name}") + if (val := self._local_context_cvar.get().get(name, _UNSET)) is _UNSET: + return self._global_context[name] + return val + + def set(self, name: str, val: Any) -> None: + if __debug__ and name not in self._keys: + raise AttributeError(f"Unrecognized config option: {name}") + if name in self._local_context_cvar.get(): + raise AttributeError( + f"Cannot set config option {name!r} while it is overridden in a context manager" + ) + old_val = self._global_context[name] + self._global_context[name] = val + if hook := self._hooks.get(name): + hook(val, old_val, UpdateScope.GLOBAL) + + @contextlib.contextmanager + def overrides(self, **overrides: Any) -> Generator[None, None, None]: + if __debug__ and overrides.keys() - self._keys: + raise AttributeError( + f"Unrecognized config options: {set(overrides.keys()) - self._keys}" + ) + for name in overrides.keys() & self._validators.keys(): + self._validators[name](overrides[name]) + old_context = self._local_context_cvar.get() + new_context = old_context | overrides + + token = self._local_context_cvar.set(new_context) + + try: + for name in overrides.keys() & self._hooks.keys(): + self._hooks[name]( + new_context[name], + old_context.get(name, self._global_context[name]), + UpdateScope.CONTEXT, + ) + + yield + + finally: + self._local_context_cvar.reset(token) + for name in overrides.keys() & old_context.keys() & self._hooks.keys(): + self._hooks[name](old_context.get(name), new_context.get(name), UpdateScope.CONTEXT) + + def as_dict(self) -> dict[str, Any]: + """Get the current effective configuration options as a dictionary.""" + # We use self._descriptors to preserve the order of options as defined in the class. + return {name: self.get(name) for name in self._descriptors.keys()} + + def _option_descriptors_(self) -> types.MappingProxyType[str, OptionDescriptor]: + """Get the option descriptors.""" + return types.MappingProxyType(self._descriptors) + + +class Config(ConfigManager): + """ + GT4Py configuration registry. + + This class is used to register configuration options for GT4Py. + """ + + ## -- Debug options -- + #: Master debug flag. It changes defaults for all the other options to be as helpful + #: for debugging as possible. + debug = OptionDescriptor(type=bool, default=False, validator=utils.isinstancechecker(bool)) + + #: Verbose flag for DSL compilation errors. + verbose_exceptions = OptionDescriptor[bool, "Config"]( + type=bool, default_factory=(lambda cfg: cast(bool, cfg.debug)) + ) + + ## -- Instrumentation options -- + #: User-defined level to enable metrics at lower or equal level. + #: Enabling metrics collection will do extra synchronization and will have + #: impact on runtime performance. + collect_metrics_level = OptionDescriptor(type=int, default=0) + + #: Add GPU trace markers (NVTX, ROC-TX) to the generated code, at compile time. + # FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. + add_gpu_trace_markers = OptionDescriptor(type=bool, default=False) + + ## -- Build options -- + class BuildCacheLifetime(enum.Enum): + SESSION = "session" + PERSISTENT = "persistent" + + #: Whether generated code projects should be kept around between runs. + #: - SESSION: generated code projects get destroyed when the interpreter shuts down + #: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs + build_cache_lifetime = OptionDescriptor[BuildCacheLifetime, "Config"]( + type=BuildCacheLifetime, + default_factory=( + lambda cfg: cfg.BuildCacheLifetime.PERSISTENT + if cfg.debug + else cfg.BuildCacheLifetime.SESSION + ), + ) + + #: Where generated code projects should be persisted. + #: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT + build_cache_dir_root = OptionDescriptor(type=pathlib.Path, default=pathlib.Path.cwd()) + + @property + def build_cache_dir(self) -> pathlib.Path: + assert isinstance(self.build_cache_dir_root, pathlib.Path) + return self.build_cache_dir_root / ".gt4py_cache" + + class CMakeBuildType(enum.Enum): + """ + CMake build types enum. + + Member values have to be valid CMake syntax. + """ + + DEBUG = "Debug" + RELEASE = "Release" + REL_WITH_DEB_INFO = "RelWithDebInfo" + MIN_SIZE_REL = "MinSizeRel" + + #: Build type to be used when CMake is used to compile generated code. + #: Might have no effect when CMake is not used as part of the toolchain. + # FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. + cmake_build_type = OptionDescriptor[CMakeBuildType, "Config"]( + type=CMakeBuildType, + default_factory=( + lambda cfg: cfg.CMakeBuildType.DEBUG if cfg.debug else cfg.CMakeBuildType.RELEASE + ), + ) + + #: Number of threads to use to use for compilation (0 = synchronous compilation). + #: Default: + #: - use os.cpu_count(), TODO(havogt): in Python >= 3.13 use `process_cpu_count()` + #: - if os.cpu_count() is None we are conservative and use 1 job, + #: - if the number is huge (e.g. HPC system) we limit to a smaller number + build_jobs = OptionDescriptor( + type=int, + default_factory=lambda ctx: min(os.cpu_count() or 1, 32), + ) + + ## -- Code-generation options -- + #: Experimental, use at your own risk: assume horizontal dimension has stride 1 + # FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. + unstructured_horizontal_has_unit_stride = OptionDescriptor(type=bool, default=False) + + #: The default for whether to allow jit-compilation for a compiled program. + #: This default can be overriden per program. + enable_jit_default = OptionDescriptor(type=bool, default=True) + + +config = Config() + +# if __name__ == "__main__": +# print(aa) +# self = sys.modules[__name__] +# print(self.aa) From c45069723ec1764b9172a5d0f566e03045fc60b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Tue, 24 Feb 2026 22:56:20 +0100 Subject: [PATCH 02/12] Fixes and cleanups --- src/gt4py/next/_config.py | 142 +++++++++++++++++++++++++++----------- 1 file changed, 103 insertions(+), 39 deletions(-) diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py index 5036ea3a5c..8e552fe5b9 100644 --- a/src/gt4py/next/_config.py +++ b/src/gt4py/next/_config.py @@ -6,6 +6,20 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +""" +Configuration system for GT4Py. + +Precedence of effective option values (highest to lowest): +1) Active context override (`ConfigManager.overrides`) +2) Global runtime value (`ConfigManager.set`) +3) Environment variable (`OptionDescriptor.env_var_name`) +4) Descriptor default/default_factory + +Notes: +- Context overrides are task-local via `contextvars`. +- `set()` is disallowed while the same option is context-overridden. +""" + from __future__ import annotations import contextlib @@ -16,29 +30,33 @@ import pathlib import sys import types -from collections.abc import Callable, Generator +from collections.abc import Callable, Generator, Mapping from typing import Any, Final, Generic, Literal, Protocol, TypeVar, cast, final - from gt4py.eve import utils from gt4py.eve.extended_typing import Self @final -class _UNSET_SENTINEL: ... +class Sentinel: + UNSET = enum.auto() -_UNSET: Final = _UNSET_SENTINEL() - _T = TypeVar("_T") _T_contra = TypeVar("_T_contra", contravariant=True) +_EnumT = TypeVar("_EnumT", bound=enum.Enum) @utils.type_dispatcher def get_value_from_environment_var( as_type: type[_T], var_name: str, *, default: _T | None = None ) -> _T | None: - """Convert the content of environment variable a typed value.""" + """ + Create an instance of the provided type from the value of an environment variable. + + The implementation uses a explicit type-dispatcher to allow custom parsing logic + for different types (e.g. bool, enums). + """ env_value = os.environ.get(var_name, None) if env_value is None: return default @@ -46,8 +64,9 @@ def get_value_from_environment_var( return as_type(env_value) except Exception as e: raise TypeError( - f"Unsupported conversion of GT4Py environment variable {var_name}: {env_value}) to type '{as_type.__name__}'." - ) from None + f"Unsupported conversion of GT4Py environment variable {var_name}: " + f"{env_value!r} to type '{as_type.__name__}'." + ) from e @get_value_from_environment_var.register(bool) @@ -57,7 +76,7 @@ def _get_value_from_environment_var_as_bool( env_value = os.environ.get(var_name, None) if env_value is None: return default - match env_value.upper(): + match env_value.strip().upper(): case "0" | "FALSE" | "OFF": return False case "1" | "TRUE" | "ON": @@ -68,36 +87,67 @@ def _get_value_from_environment_var_as_bool( ) +@get_value_from_environment_var.register(enum.Enum) +def _get_value_from_environment_var_as_enum( + as_type: type[_EnumT], var_name: str, *, default: _EnumT | None = None +) -> _EnumT | None: + """Create enum by member name.""" + env_value = os.environ.get(var_name, None) + if env_value is None: + return default + + try: + return as_type[env_value] + except Exception as e: + raise TypeError( + f"Invalid GT4Py enum value for {var_name}: {env_value!r}. Allowed: {[m.name for m in as_type]}." + ) from e + + class UpdateScope(str, enum.Enum): GLOBAL = sys.intern("global") CONTEXT = sys.intern("context") class OptionUpdateCallback(Protocol[_T_contra]): + """Callback invoked after an option changes (in a global or local context scope).""" + def __call__( self, new_val: _T_contra, old_val: _T_contra | None, scope: UpdateScope ) -> None: ... -ConfigRegistryT = TypeVar("ConfigRegistryT", bound="ConfigManager") +ConfigManagerT = TypeVar("ConfigManagerT", bound="ConfigManager") @dataclasses.dataclass(frozen=True, kw_only=True) -class OptionDescriptor(Generic[_T, ConfigRegistryT]): +class OptionDescriptor(Generic[_T, ConfigManagerT]): + """ + Descriptor for a configuration option. + + Instances of this class should be defined as class attributes of a + `ConfigManager` subclass. This class implements the descriptor protocol + to support the bare attribute-style access to the option value on the + manager instance (e.g. `config.debug`), which will be resolved properly + using the precedence rules defined in `ConfigManager.get()`. + """ + type: type[_T] - default: dataclasses.InitVar[_T | _UNSET_SENTINEL] = _UNSET - default_factory: Callable[[ConfigRegistryT], _T] | None = None + default: dataclasses.InitVar[_T | Literal[Sentinel.UNSET]] = Sentinel.UNSET + default_factory: Callable[[ConfigManagerT], _T] | None = None validator: Callable[[Any], Any] | Literal["type_check"] | None = "type_check" update_callback: OptionUpdateCallback[_T] | None = None env_prefix: str = "GT4PY_" name: str = dataclasses.field(init=False) - def __post_init__(self, default: _T | _UNSET_SENTINEL) -> None: + def __post_init__(self, default: _T | Literal[Sentinel.UNSET]) -> None: + # Initialize the validator if self.validator == "type_check": object.__setattr__(self, "validator", utils.isinstancechecker(self.type)) assert self.validator is None or callable(self.validator) - if default is not _UNSET: + # Initialize the default factory based on the provided default/default_factory + if default is not Sentinel.UNSET: if self.default_factory is not None: raise ValueError( "Cannot specify both default and default_factory for a config option descriptor." @@ -111,6 +161,7 @@ def __post_init__(self, default: _T | _UNSET_SENTINEL) -> None: ) def __set_name__(self, owner: type, name: str) -> None: + """Set the name of the option based on the attribute name in the owner class.""" object.__setattr__(self, "name", name) def __get__(self, instance: Any, owner: type | None = None) -> _T | Self: @@ -129,11 +180,18 @@ def __set__(self, instance: Any, value: _T) -> None: @property def env_var_name(self) -> str: + """Construct the name of the environment variable corresponding to this option.""" return f"{self.env_prefix}{self.name}".upper() class ConfigManager: - """Central configuration registry with attribute-style access.""" + """ + Central configuration manager with attribute-style access. + + Config options are defined as class attributes using `OptionDescriptor`. + The manager stores global values for all options and allows temporary + overrides in a context manager scope. + """ def __init__(self) -> None: self._descriptors: dict[str, OptionDescriptor[Any, Config]] = { @@ -159,62 +217,73 @@ def __init__(self) -> None: # GC'd even if the instance goes out of scope), in this case we really want # per-registry isolation and we assume only very few ConfigRegistry instances # will be ever created. - self._local_context_cvar = contextvars.ContextVar[types.MappingProxyType]( + self._local_context_cvar = contextvars.ContextVar[Mapping[str, Any]]( f"{self.__class__.__name__}_cvar", default=types.MappingProxyType({}) ) self._global_context: dict[str, Any] = {} for name, desc in self._descriptors.items(): assert desc.default_factory is not None # Guaranteed by __post_init__ - self._global_context[name] = get_value_from_environment_var( + init_value = get_value_from_environment_var( desc.type, desc.env_var_name, default=desc.default_factory(self) ) + if validator := self._validators.get(name): + validator(init_value) + self._global_context[name] = init_value def get(self, name: str) -> Any: if __debug__ and name not in self._keys: raise AttributeError(f"Unrecognized config option: {name}") - if (val := self._local_context_cvar.get().get(name, _UNSET)) is _UNSET: + if (val := self._local_context_cvar.get().get(name, Sentinel.UNSET)) is Sentinel.UNSET: return self._global_context[name] return val - def set(self, name: str, val: Any) -> None: + def set(self, name: str, value: Any) -> None: if __debug__ and name not in self._keys: raise AttributeError(f"Unrecognized config option: {name}") if name in self._local_context_cvar.get(): raise AttributeError( f"Cannot set config option {name!r} while it is overridden in a context manager" ) + if validator := self._validators.get(name): + validator(value) old_val = self._global_context[name] - self._global_context[name] = val + self._global_context[name] = value if hook := self._hooks.get(name): - hook(val, old_val, UpdateScope.GLOBAL) + hook(value, old_val, UpdateScope.GLOBAL) @contextlib.contextmanager def overrides(self, **overrides: Any) -> Generator[None, None, None]: - if __debug__ and overrides.keys() - self._keys: + if overrides.keys() - self._keys: raise AttributeError( f"Unrecognized config options: {set(overrides.keys()) - self._keys}" ) - for name in overrides.keys() & self._validators.keys(): - self._validators[name](overrides[name]) - old_context = self._local_context_cvar.get() - new_context = old_context | overrides + old_values = {} + changes = {} + for name, new_value in overrides.items(): + old_value = self.get(name) + if new_value != old_value: + old_values[name] = old_value + changes[name] = new_value + + for name in changes.keys() & self._validators.keys(): + self._validators[name](changes[name]) + + old_context = self._local_context_cvar.get() + new_context = types.MappingProxyType(**old_context, **changes) token = self._local_context_cvar.set(new_context) try: - for name in overrides.keys() & self._hooks.keys(): - self._hooks[name]( - new_context[name], - old_context.get(name, self._global_context[name]), - UpdateScope.CONTEXT, - ) + for name in changes.keys() & self._hooks.keys(): + self._hooks[name](new_context[name], old_values[name], UpdateScope.CONTEXT) yield finally: self._local_context_cvar.reset(token) - for name in overrides.keys() & old_context.keys() & self._hooks.keys(): + + for name in changes.keys() & old_context.keys() & self._hooks.keys(): self._hooks[name](old_context.get(name), new_context.get(name), UpdateScope.CONTEXT) def as_dict(self) -> dict[str, Any]: @@ -323,8 +392,3 @@ class CMakeBuildType(enum.Enum): config = Config() - -# if __name__ == "__main__": -# print(aa) -# self = sys.modules[__name__] -# print(self.aa) From 2c7da4c97a024086591d4f508cdc87b9cbcb3dce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 25 Feb 2026 10:11:54 +0100 Subject: [PATCH 03/12] First try to integrate the new system --- src/gt4py/next/_config.py | 10 +++++----- src/gt4py/next/{config.py => _old_config.py} | 0 src/gt4py/next/errors/excepthook.py | 2 +- src/gt4py/next/ffront/past_to_itir.py | 2 +- src/gt4py/next/instrumentation/metrics.py | 6 +++--- src/gt4py/next/iterator/runtime.py | 2 +- src/gt4py/next/otf/binding/nanobind.py | 2 +- src/gt4py/next/otf/compilation/cache.py | 2 +- src/gt4py/next/otf/compiled_program.py | 4 ++-- src/gt4py/next/otf/options.py | 2 +- .../runners/dace/lowering/gtir_to_sdfg.py | 2 +- .../runners/dace/workflow/backend.py | 2 +- .../runners/dace/workflow/compilation.py | 2 +- .../runners/dace/workflow/decoration.py | 2 +- .../runners/dace/workflow/factory.py | 6 +++--- .../runners/dace/workflow/translation.py | 4 ++-- src/gt4py/next/program_processors/runners/gtfn.py | 6 +++--- .../next/program_processors/runners/roundtrip.py | 2 +- tests/next_tests/__init__.py | 2 +- .../feature_tests/ffront_tests/test_decorator.py | 2 +- .../unit_tests/instrumentation_tests/test_metrics.py | 12 ++++++------ .../runners_tests/dace_tests/test_dace_backend.py | 2 +- .../dace_tests/test_dace_translation.py | 2 +- 23 files changed, 39 insertions(+), 39 deletions(-) rename src/gt4py/next/{config.py => _old_config.py} (100%) diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py index 8e552fe5b9..5583559847 100644 --- a/src/gt4py/next/_config.py +++ b/src/gt4py/next/_config.py @@ -31,7 +31,7 @@ import sys import types from collections.abc import Callable, Generator, Mapping -from typing import Any, Final, Generic, Literal, Protocol, TypeVar, cast, final +from typing import Any, Generic, Literal, Protocol, TypeVar, cast, final from gt4py.eve import utils from gt4py.eve.extended_typing import Self @@ -39,7 +39,7 @@ @final class Sentinel: - UNSET = enum.auto() + UNSET = object() _T = TypeVar("_T") @@ -259,8 +259,8 @@ def overrides(self, **overrides: Any) -> Generator[None, None, None]: f"Unrecognized config options: {set(overrides.keys()) - self._keys}" ) - old_values = {} - changes = {} + old_values: dict[str, Any] = {} + changes: dict[str, Any] = {} for name, new_value in overrides.items(): old_value = self.get(name) if new_value != old_value: @@ -271,7 +271,7 @@ def overrides(self, **overrides: Any) -> Generator[None, None, None]: self._validators[name](changes[name]) old_context = self._local_context_cvar.get() - new_context = types.MappingProxyType(**old_context, **changes) + new_context = types.MappingProxyType({**old_context, **changes}) token = self._local_context_cvar.set(new_context) try: diff --git a/src/gt4py/next/config.py b/src/gt4py/next/_old_config.py similarity index 100% rename from src/gt4py/next/config.py rename to src/gt4py/next/_old_config.py diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 6bd084eebd..dffc2de321 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -44,7 +44,7 @@ def compilation_error_hook( """ # in hard crashes of the interpreter, the `exceptions` module might be partially unloaded if exceptions.DSLError is not None and isinstance(value, exceptions.DSLError): - exc_strs = _format_uncaught_error(value, config.VERBOSE_EXCEPTIONS) + exc_strs = _format_uncaught_error(value, config.verbose_exceptions) print("".join(exc_strs), file=sys.stderr) else: fallback(type_, value, tb) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 6b19e7cc1f..39ac4b32d8 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -141,7 +141,7 @@ def past_to_gtir(inp: ConcretePASTProgramDef) -> definitions.CompilableProgramDe inp.args, args=args, kwargs=kwargs, column_axis=_column_axis(all_closure_vars) ) - if config.DEBUG or inp.data.debug: + if config.debug or inp.data.debug: devtools.debug(itir_program) return definitions.CompilableProgramDef(data=itir_program, args=compile_time_args) diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 908f99ed61..fb50cfcc05 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -52,17 +52,17 @@ def is_any_level_enabled() -> bool: """Check if any metrics collection level is enabled.""" - return config.COLLECT_METRICS_LEVEL > DISABLED + return config.collect_metrics_level > DISABLED def is_level_enabled(level: int) -> bool: """Check if a given metrics collection level is enabled.""" - return config.COLLECT_METRICS_LEVEL >= level + return config.collect_metrics_level >= level def get_current_level() -> int: """Retrieve the current metrics collection level (from the configuration module).""" - return config.COLLECT_METRICS_LEVEL + return config.collect_metrics_level @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index 18c4f9f897..5aa23e3b2f 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -67,7 +67,7 @@ def itir(self, *args): fencil_definition = trace_fencil_definition(self.definition, args) - if config.DEBUG: + if config.debug: devtools.debug(fencil_definition) return fencil_definition diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 15f4b1866c..1773d222d7 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -209,7 +209,7 @@ def make_argument(name: str, type_: ts.TypeSpec) -> str | BufferSID | Tuple: name=dim.value, static_stride=1 if ( - config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE + config.unstructured_horizontal_has_unit_stride and dim.kind == common.DimensionKind.HORIZONTAL ) else None, diff --git a/src/gt4py/next/otf/compilation/cache.py b/src/gt4py/next/otf/compilation/cache.py index b9d06a1e26..a7d2d95894 100644 --- a/src/gt4py/next/otf/compilation/cache.py +++ b/src/gt4py/next/otf/compilation/cache.py @@ -64,7 +64,7 @@ def get_cache_folder( case config.BuildCacheLifetime.SESSION: base_path = _session_cache_dir_path case config.BuildCacheLifetime.PERSISTENT: - base_path = config.BUILD_CACHE_DIR + base_path = config.build_cache_dir case _: raise ValueError("Unsupported caching lifetime.") diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index e46d8219a2..633f5ffdea 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -121,9 +121,9 @@ def compiled_program_call_context( def _init_async_compilation_pool() -> None: global _async_compilation_pool - if _async_compilation_pool is None and config.BUILD_JOBS > 0: + if _async_compilation_pool is None and config.build_jobs > 0: _async_compilation_pool = concurrent.futures.ThreadPoolExecutor( - max_workers=config.BUILD_JOBS + max_workers=config.build_jobs ) diff --git a/src/gt4py/next/otf/options.py b/src/gt4py/next/otf/options.py index 4f77d44586..48672bb36b 100644 --- a/src/gt4py/next/otf/options.py +++ b/src/gt4py/next/otf/options.py @@ -25,7 +25,7 @@ class CompilationOptions: #: to `compile` before calling. # Uses a factory to make changes to the config after module import time take effect. This is # mostly important for testing. Users should not rely on it. - enable_jit: bool = dataclasses.field(default_factory=lambda: config.ENABLE_JIT_DEFAULT) + enable_jit: bool = dataclasses.field(default_factory=lambda: config.enable_jit_default) #: If the user requests static params, they will be used later to initialize CompiledPrograms. #: By default the set of static params is set when compiling for the first time, e.g. on call diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index eb84de9185..cabd4eb657 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -709,7 +709,7 @@ def add_nested_sdfg( if dataname in data_args: # Uninitialized arguments should not be used inside the nested SDFG. if (arg_node := data_args[dataname]) is None: - inner_ctx.sdfg.remove_data(dataname, validate=gtx_config.DEBUG) + inner_ctx.sdfg.remove_data(dataname, validate=gtx_config.debug) else: input_memlets[dataname] = outer_ctx.sdfg.make_array_memlet( arg_node.dc_node.data diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index 32e3ba8a31..16fd10ca6d 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -115,7 +115,7 @@ def make_dace_backend( # Set `unit_strides_kind` based on the gt4py env configuration. optimization_args = optimization_args | { "unit_strides_kind": common.DimensionKind.HORIZONTAL - if config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE + if config.unstructured_horizontal_has_unit_stride else None } diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 7686ae097c..d392474882 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -96,7 +96,7 @@ def fast_call(self) -> None: "Argument vector was not set properly." ) self.sdfg_program.fast_call( - self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=config.DEBUG + self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=config.debug ) def __call__(self, **kwargs: Any) -> None: diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 103e7af33b..ad96ddd1fa 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -64,7 +64,7 @@ def decorated_program( filter_args=False, ) this_call_args |= { - gtx_wfdcommon.SDFG_ARG_METRIC_LEVEL: config.COLLECT_METRICS_LEVEL, + gtx_wfdcommon.SDFG_ARG_METRIC_LEVEL: config.collect_metrics_level, gtx_wfdcommon.SDFG_ARG_METRIC_COMPUTE_TIME: collect_time_arg, } fun.construct_arguments(**this_call_args) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 62febd0965..ef34918370 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -39,7 +39,7 @@ class Params: auto_optimize: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough - lambda: config.CMAKE_BUILD_TYPE + lambda: config.cmake_build_type ) cached_translation = factory.Trait( @@ -47,7 +47,7 @@ class Params: lambda o: workflow.CachedStep( o.bare_translation, hash_function=stages.fingerprint_compilable_program, - cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "translation_cache")), + cache=filecache.FileCache(str(config.build_cache_dir / "translation_cache")), ) ), ) @@ -68,7 +68,7 @@ class Params: compilation = factory.SubFactory( DaCeCompilationStepFactory, bind_func_name=_GT_DACE_BINDING_FUNCTION_NAME, - cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), + cache_lifetime=factory.LazyFunction(lambda: config.build_cache_lifetime), device_type=factory.SelfAttribute("..device_type"), cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 407b19d553..a9f8c343f0 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -39,7 +39,7 @@ def find_constant_symbols( """Helper function to find symbols to replace with constant values.""" constant_symbols: dict[str, int] = {} - if config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE: + if config.unstructured_horizontal_has_unit_stride: # Search the stride symbols corresponding to the horizontal dimension for p in ir.params: if isinstance(p.type, ts.FieldType): @@ -268,7 +268,7 @@ def add_instrumentation(sdfg: dace.SDFG, gpu: bool) -> None: dace.Memlet(f"{output}[0]"), ) - if gpu and _has_gpu_schedule(sdfg) and config.ADD_GPU_TRACE_MARKERS: + if gpu and _has_gpu_schedule(sdfg) and config.add_gpu_trace_markers: sdfg.instrument = dace.dtypes.InstrumentationType.GPU_TX_MARKERS for node, _ in sdfg.all_nodes_recursive(): if isinstance( diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index be97f9a25a..15b468d5e7 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -113,7 +113,7 @@ class Meta: class Params: device_type: core_defs.DeviceType = core_defs.DeviceType.CPU cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough - lambda: config.CMAKE_BUILD_TYPE + lambda: config.cmake_build_type ) builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( # type: ignore[assignment] # factory-boy typing not precise enough lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) @@ -124,7 +124,7 @@ class Params: lambda o: workflow.CachedStep( o.bare_translation, hash_function=stages.fingerprint_compilable_program, - cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), + cache=filecache.FileCache(str(config.build_cache_dir / "gtfn_cache")), ) ), ) @@ -141,7 +141,7 @@ class Params: ) compilation = factory.SubFactory( compiler.CompilerFactory, - cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), + cache_lifetime=factory.LazyFunction(lambda: config.build_cache_lifetime), builder_factory=factory.SelfAttribute("..builder_factory"), ) decoration = factory.LazyAttribute( diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 16916697a2..aae43f44d3 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -215,7 +215,7 @@ class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, stages.Execu transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` def __call__(self, inp: definitions.CompilableProgramDef) -> stages.ExecutableProgram: - debug = config.DEBUG if self.debug is None else self.debug + debug = config.debug if self.debug is None else self.debug fencil = fencil_generator( inp.data, diff --git a/tests/next_tests/__init__.py b/tests/next_tests/__init__.py index ba174aa6d1..7449b4dac8 100644 --- a/tests/next_tests/__init__.py +++ b/tests/next_tests/__init__.py @@ -16,7 +16,7 @@ __all__ = ["definitions", "get_processor_id"] -if config.BUILD_CACHE_LIFETIME is config.BuildCacheLifetime.PERSISTENT: +if config.build_cache_lifetime is config.BuildCacheLifetime.PERSISTENT: warnings.warn( "You are running GT4Py tests with BUILD_CACHE_LIFETIME set to PERSISTENT!", UserWarning ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index 4578576f02..2db15da661 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -60,7 +60,7 @@ def testee(a: cases.IField, out: cases.IField): testee_op(a, a, out=out) with ( - mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics_level), + mock.patch("gt4py.next.config.collect_metrics_level", metrics_level), mock.patch( "gt4py.next.instrumentation.metrics.sources", collections.defaultdict(metrics.Source) ), diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py index 7aee5580f7..2499a23546 100644 --- a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py @@ -58,7 +58,7 @@ def test_set_current_source_key_different_key_raises(self): class TestSourceKeyContextManager: def test_context_manager_sets_and_resets_key(self): - with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.MINIMAL): + with unittest.mock.patch("gt4py.next.config.collect_metrics_level", metrics.MINIMAL): metrics._source_key_cvar.set( metrics._NO_KEY_SET_MARKER_ ) # Reset context variable before test @@ -78,7 +78,7 @@ def test_context_manager_sets_and_resets_key(self): ) def test_context_manager_with_no_key(self): - with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.MINIMAL): + with unittest.mock.patch("gt4py.next.config.collect_metrics_level", metrics.MINIMAL): metrics._source_key_cvar.set("__BEFORE__MARKER__") # Reset context variable before test with metrics.SourceKeyContextManager(): @@ -92,7 +92,7 @@ def test_context_manager_with_no_key(self): assert metrics._source_key_cvar.get(metrics._NO_KEY_SET_MARKER_) == "__BEFORE__MARKER__" def test_context_manager_nested(self): - with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.MINIMAL): + with unittest.mock.patch("gt4py.next.config.collect_metrics_level", metrics.MINIMAL): metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) key1 = "outer_key" key2 = "inner_key" @@ -121,7 +121,7 @@ class TestCollector( ): ... metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) - with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.MINIMAL): + with unittest.mock.patch("gt4py.next.config.collect_metrics_level", metrics.MINIMAL): outer_key = "outer_key" metrics.set_current_source_key("outer_key") assert metrics.get_current_source_key() == outer_key @@ -140,7 +140,7 @@ class TestCollector( key = "test_disabled" metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) - with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.DISABLED): + with unittest.mock.patch("gt4py.next.config.collect_metrics_level", metrics.DISABLED): metrics.set_current_source_key(key) with TestCollector(key=key): @@ -161,7 +161,7 @@ class CustomCollector( key = "test_custom" metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) - with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.PERFORMANCE): + with unittest.mock.patch("gt4py.next.config.collect_metrics_level", metrics.PERFORMANCE): with CustomCollector(key=key): pass diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py index cd90208340..dbc4d87490 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py @@ -92,7 +92,7 @@ def mocked_gpu_transformation(*args, **kwargs) -> dace.SDFG: monkeypatch.setattr(gtx_transformations, "gt_auto_optimize", mocked_auto_optimize) monkeypatch.setattr(gtx_transformations, "gt_gpu_transformation", mocked_gpu_transformation) - with mock.patch("gt4py.next.config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE", on_gpu): + with mock.patch("gt4py.next.config.unstructured_horizontal_has_unit_stride", on_gpu): custom_backend = dace_wf_backend.make_dace_backend( gpu=on_gpu, cached=False, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py index 88eaffe345..20651021aa 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py @@ -97,7 +97,7 @@ def test_find_constant_symbols(has_unit_stride, disable_field_origin): ], ) - with mock.patch("gt4py.next.config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE", has_unit_stride): + with mock.patch("gt4py.next.config.unstructured_horizontal_has_unit_stride", has_unit_stride): sdfg = _translate_gtir_to_sdfg( ir=ir, offset_provider=SKIP_VALUE_MESH.offset_provider, From b94e4eab815e15539a99cd033d9cf7d13fbbcdb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 25 Feb 2026 15:59:46 +0100 Subject: [PATCH 04/12] More cleanups and refactorings (in utils mostly) --- src/gt4py/eve/utils.py | 91 +++-- src/gt4py/next/__init__.py | 7 +- src/gt4py/next/_config.py | 344 +++++++++++++----- tests/eve_tests/unit_tests/test_utils.py | 166 +++++++++ .../{test_config.py => test_old_config.py} | 0 5 files changed, 487 insertions(+), 121 deletions(-) rename tests/next_tests/unit_tests/{test_config.py => test_old_config.py} (100%) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 3c909973b6..3674780111 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -527,43 +527,82 @@ def partial(self, *args: Any, **kwargs: Any) -> fluid_partial: return fluid_partial(self, *args, **kwargs) -if xtyping.TYPE_CHECKING: +class TypeMapping(collections.abc.Mapping[type, _T]): + """ + A mapping from types to values supporting complex type-based dispatching. - class TypeDispatcher(Generic[_P, _T]): - def __init__(self, func: Callable[_P, _T]) -> None: ... + The mapping supports registering values for specific types, and retrieving + values based on the type of the key, including support for inheritance + exactly in the same way as `functools.singledispatch()` works. For example, + if a value is registered for a base class, it will be returned for + instances of derived classes unless a more specific type is registered. - def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ... + Examples: + >>> mapping = TypeMapping(lambda type_: f"Default for {type_}") + >>> mapping[int] = "Integer handler" + >>> mapping[int] + 'Integer handler' + >>> mapping[float] + "Default for " + + >>> import collections + >>> mapping[tuple] = "Tuple handler" + >>> mapping[tuple] + 'Tuple handler' + >>> mapping[collections.namedtuple("Point", ["x", "y"])] + 'Tuple handler' + """ - @overload - def register(self, arg_type: type) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... + def __init__(self, fallback_factory: Callable[[type], _T]) -> None: + self._fallback_factory = fallback_factory + self._dispatcher = functools.singledispatch(self._fallback_factory) - @overload - def register(self, arg_type: type, func: Callable[_P, _T]) -> Callable[_P, _T]: ... + def __getitem__(self, type_: type) -> _T: + dispatched = self._dispatcher.dispatch(type_) + return ( + self._fallback_factory(type_) + if dispatched is self._fallback_factory + else cast(_T, dispatched) + ) - def register( - self, arg_type: type, func: Optional[Callable[_P, _T]] = None - ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]] | Callable[_P, _T]: ... + def __setitem__(self, type_: type, value: _T) -> None: + self._dispatcher.register(type_, value) # type: ignore[call-overload] # abusine singledispatch to register any value, not just callables + self.clear_cache() + def __iter__(self) -> Iterator[type]: + return iter(self._dispatcher.registry) -def type_dispatcher(func: Callable[_P, _T]) -> TypeDispatcher[_P, _T]: - """ - Decorator to create a single-dispatch generic function that dispatches of the first argument. - """ - # reuse the singledispatch() dispatching mechanism but change the wrapper - indirect_dispatcher = functools.singledispatch(func) - dispatch = indirect_dispatcher.dispatch + def __len__(self) -> int: + return len(self._dispatcher.registry) + + def __contains__(self, type_: object) -> bool: + """Check if a type is registered in the mapping (including via superclasses).""" + return self._dispatcher.dispatch(type_) is not self._fallback_factory + + @overload + def register(self, type_: type, value: _T) -> _T: ... - def wrapper(*args: P.args, **kw: P.kwargs) -> _T: - if not args: - raise TypeError(f"{func.__name__} requires at least 1 positional argument") + @overload + def register(self, type_: type, value: NothingType = NOTHING) -> Callable[[_T], _T]: ... + + def register(self, type_: type, value: _T | NothingType = NOTHING) -> _T | Callable[[_T], _T]: + """Return a decorator to register a value for the given type.""" + + if value is not NOTHING: + assert not isinstance(value, NothingType) + self[type_] = value + return value + else: - return dispatch(args[0])(*args, **kw) + def _decorator(value: _T) -> _T: + self[type_] = value + return value - functools.update_wrapper(wrapper, func) - for attr in ("dispatch", "register", "registry", "_clear_cache"): - setattr(wrapper, attr, getattr(indirect_dispatcher, attr)) + return _decorator - return wrapper + def clear_cache(self) -> None: + """Clear the singledispatch cache.""" + self._dispatcher._clear_cache() @overload diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 3e3d4641cd..c6ca822737 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -21,11 +21,12 @@ # ruff: noqa: F401 from __future__ import annotations -from .._core.definitions import CUPY_DEVICE_TYPE, Device, DeviceType, is_scalar_type -from . import common, ffront, iterator, program_processors, typing # reexport the actual configuration manager instance as a publice attribute -from ._config import config +from ._config import config # ruff: isort: skip + +from .._core.definitions import CUPY_DEVICE_TYPE, Device, DeviceType, is_scalar_type +from . import common, ffront, iterator, program_processors, typing from .common import ( Connectivity, Dimension, diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py index 5583559847..a68d35300a 100644 --- a/src/gt4py/next/_config.py +++ b/src/gt4py/next/_config.py @@ -31,15 +31,29 @@ import sys import types from collections.abc import Callable, Generator, Mapping -from typing import Any, Generic, Literal, Protocol, TypeVar, cast, final +from typing import Any, Final, Generic, Literal, Protocol, TypeVar, final, overload from gt4py.eve import utils from gt4py.eve.extended_typing import Self @final -class Sentinel: - UNSET = object() +class _UnsetSentinel: + """Sentinel value for unset configuration options.""" + + __slots__ = () + _instance: _UnsetSentinel | None = None + + def __new__(cls) -> _UnsetSentinel: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __repr__(self) -> str: + return "" + + +UNSET: Final[_UnsetSentinel] = _UnsetSentinel() _T = TypeVar("_T") @@ -47,70 +61,66 @@ class Sentinel: _EnumT = TypeVar("_EnumT", bound=enum.Enum) -@utils.type_dispatcher -def get_value_from_environment_var( - as_type: type[_T], var_name: str, *, default: _T | None = None +def parse_env_var( + var_name: str, parser: Callable[[str], _T], *, default: _T | None = None ) -> _T | None: - """ - Create an instance of the provided type from the value of an environment variable. - - The implementation uses a explicit type-dispatcher to allow custom parsing logic - for different types (e.g. bool, enums). - """ - env_value = os.environ.get(var_name, None) - if env_value is None: + """Get a python value from an environment variable.""" + env_var_value = os.environ.get(var_name, None) + if env_var_value is None: return default + try: - return as_type(env_value) + return parser(env_var_value) except Exception as e: - raise TypeError( - f"Unsupported conversion of GT4Py environment variable {var_name}: " - f"{env_value!r} to type '{as_type.__name__}'." + raise RuntimeError( + f"Parsing '{var_name}' (value: '{env_var_value}') environment variable {var_name} failed!" ) from e -@get_value_from_environment_var.register(bool) -def _get_value_from_environment_var_as_bool( - as_type: type[bool], var_name: str, *, default: bool | None = None -) -> bool | None: - env_value = os.environ.get(var_name, None) - if env_value is None: - return default - match env_value.strip().upper(): +@utils.TypeMapping +def _parse_str(type_: type) -> Callable[[str], Any]: + """Default parser: the type string value as is.""" + match type_: + case enum.Enum() as enum_type: + return lambda value: enum_type[value] # parse enum values from their names + case _: + return lambda x: 1 # type constructor as parser + + +@_parse_str.register(bool) +def _parse_str_as_bool(value: str) -> bool: + match value.strip().upper(): case "0" | "FALSE" | "OFF": return False case "1" | "TRUE" | "ON": return True case _: raise ValueError( - f"Invalid GT4Py environment flag value for {var_name}: use '0 | FALSE | OFF' or '1 | TRUE | ON'." + f"{value} cannot be parsed as a boolean value. Use '0 | FALSE | OFF' or '1 | TRUE | ON'." ) -@get_value_from_environment_var.register(enum.Enum) -def _get_value_from_environment_var_as_enum( - as_type: type[_EnumT], var_name: str, *, default: _EnumT | None = None -) -> _EnumT | None: - """Create enum by member name.""" - env_value = os.environ.get(var_name, None) - if env_value is None: - return default - - try: - return as_type[env_value] - except Exception as e: - raise TypeError( - f"Invalid GT4Py enum value for {var_name}: {env_value!r}. Allowed: {[m.name for m in as_type]}." - ) from e +@_parse_str.register(pathlib.Path) +def _parse_str_as_path(value: str) -> pathlib.Path: + expanded = os.path.expandvars(os.path.expanduser(value)) + return pathlib.Path(expanded) class UpdateScope(str, enum.Enum): + """Scope of a configuration option update.""" + GLOBAL = sys.intern("global") CONTEXT = sys.intern("context") class OptionUpdateCallback(Protocol[_T_contra]): - """Callback invoked after an option changes (in a global or local context scope).""" + """ + Callback invoked after an option changes. + + Callbacks are invoked after both global (via set() or __setattr__) + and context-local (via overrides()) updates. This allows observers + to react to configuration changes. + """ def __call__( self, new_val: _T_contra, old_val: _T_contra | None, scope: UpdateScope @@ -130,24 +140,44 @@ class OptionDescriptor(Generic[_T, ConfigManagerT]): to support the bare attribute-style access to the option value on the manager instance (e.g. `config.debug`), which will be resolved properly using the precedence rules defined in `ConfigManager.get()`. + + Attributes: + type: The Python type of this configuration option. + default: Initial fallback value for this option. Mutually exclusive with default_factory. + default_factory: Callable to compute the default value given a ConfigManager instance. + Mutually exclusive with default. + validator: Callable that validates the option value, or "type_check" for isinstance checking. + Set to None to disable validation. + update_callback: Optional callback invoked after the option is updated (globally or in context). + env_prefix: Prefix for the environment variable name. + name: Name of the option (set automatically via __set_name__). + + Example: + >>> class Config(ConfigManager): + ... debug = OptionDescriptor( + ... type=bool, + ... default=False, + ... update_callback=lambda new, old, scope: print(f"Changed to {new}"), + ... ) """ - type: type[_T] - default: dataclasses.InitVar[_T | Literal[Sentinel.UNSET]] = Sentinel.UNSET + option_type: type[_T] + default: dataclasses.InitVar[_T | _UnsetSentinel] = UNSET default_factory: Callable[[ConfigManagerT], _T] | None = None + parser: Callable[[str], _T] | None = None validator: Callable[[Any], Any] | Literal["type_check"] | None = "type_check" update_callback: OptionUpdateCallback[_T] | None = None env_prefix: str = "GT4PY_" - name: str = dataclasses.field(init=False) + name: str = dataclasses.field(init=False, default="") - def __post_init__(self, default: _T | Literal[Sentinel.UNSET]) -> None: + def __post_init__(self, default: _T | _UnsetSentinel) -> None: # Initialize the validator if self.validator == "type_check": - object.__setattr__(self, "validator", utils.isinstancechecker(self.type)) + object.__setattr__(self, "validator", utils.isinstancechecker(self.option_type)) assert self.validator is None or callable(self.validator) # Initialize the default factory based on the provided default/default_factory - if default is not Sentinel.UNSET: + if not isinstance(default, _UnsetSentinel): if self.default_factory is not None: raise ValueError( "Cannot specify both default and default_factory for a config option descriptor." @@ -164,7 +194,26 @@ def __set_name__(self, owner: type, name: str) -> None: """Set the name of the option based on the attribute name in the owner class.""" object.__setattr__(self, "name", name) - def __get__(self, instance: Any, owner: type | None = None) -> _T | Self: + @overload + def __get__( + self, instance: ConfigManagerT, owner: type[ConfigManagerT] | None = None + ) -> _T: ... + + @overload + def __get__( + self, instance: Literal[None], owner: type[ConfigManagerT] | None = None + ) -> Self: ... + + def __get__( + self, instance: ConfigManagerT | None, owner: type[ConfigManagerT] | None = None + ) -> _T | Self: + """ + Get the configuration option value. + + If accessed on the class (instance is None), returns the descriptor itself. + If accessed on an instance, delegates to ConfigManager.get() to get the + effective current value. + """ try: assert isinstance(instance, ConfigManager) return instance.get(self.name) @@ -175,22 +224,43 @@ def __get__(self, instance: Any, owner: type | None = None) -> _T | Self: raise AttributeError(f"Error reading config option {self.name!r}") from e def __set__(self, instance: Any, value: _T) -> None: + """ + Set the global value of the configuration option. + + This delegates to ConfigManager.set() which handles global updates and validation. + """ assert isinstance(instance, ConfigManager) instance.set(self.name, value) @property def env_var_name(self) -> str: - """Construct the name of the environment variable corresponding to this option.""" + """Construct the name of the environment variable corresponding to this option. + + Returns the environment variable name by combining the prefix and option name in uppercase. + E.g., for env_prefix="GT4PY_" and name="debug", returns "GT4PY_DEBUG". + """ return f"{self.env_prefix}{self.name}".upper() class ConfigManager: - """ - Central configuration manager with attribute-style access. + """Central configuration manager with attribute-style access. Config options are defined as class attributes using `OptionDescriptor`. The manager stores global values for all options and allows temporary overrides in a context manager scope. + + The effective value of an option follows this precedence (highest to lowest): + 1. Active context override via the `overrides()` context manager + 2. Global runtime value set via the `set()` method + 3. Environment variable (if set) + 4. Descriptor default or default_factory result + + Example: + >>> config = ConfigManager() + >>> config.get("some_option") # Apply precedence rules + >>> config.set("some_option", value) # Set global value + >>> with config.overrides(some_option=value): # Temporary override + ... pass """ def __init__(self) -> None: @@ -224,22 +294,48 @@ def __init__(self) -> None: self._global_context: dict[str, Any] = {} for name, desc in self._descriptors.items(): assert desc.default_factory is not None # Guaranteed by __post_init__ - init_value = get_value_from_environment_var( - desc.type, desc.env_var_name, default=desc.default_factory(self) + init_value = parse_env_var( + desc.env_var_name, desc.parser or _parse_str[desc.option_type], default=None ) if validator := self._validators.get(name): validator(init_value) self._global_context[name] = init_value def get(self, name: str) -> Any: - if __debug__ and name not in self._keys: + """Get the effective value of a configuration option. + + Applies precedence rules: context override > global value > environment > default. + + Args: + name: The name of the configuration option. + + Returns: + The effective value of the option. + + Raises: + AttributeError: If the option name is not recognized. + """ + if name not in self._keys: raise AttributeError(f"Unrecognized config option: {name}") - if (val := self._local_context_cvar.get().get(name, Sentinel.UNSET)) is Sentinel.UNSET: - return self._global_context[name] - return val + if (val := self._local_context_cvar.get().get(name, UNSET)) is not UNSET: + return val + return self._global_context[name] def set(self, name: str, value: Any) -> None: - if __debug__ and name not in self._keys: + """Set the global value of a configuration option. + + Validates the value and invokes any registered callbacks. + + Args: + name: The name of the configuration option. + value: The new value for the option. + + Raises: + AttributeError: If the option name is not recognized, or if the option + is currently overridden in a context manager. + Validation error: If the value fails validation. + """ + if name not in self._keys: raise AttributeError(f"Unrecognized config option: {name}") if name in self._local_context_cvar.get(): raise AttributeError( @@ -254,6 +350,27 @@ def set(self, name: str, value: Any) -> None: @contextlib.contextmanager def overrides(self, **overrides: Any) -> Generator[None, None, None]: + """Context manager for temporary configuration overrides. + + Overrides are task-local (isolated per thread/async task) and automatically + reverted when exiting the context manager. Nested contexts are supported. + + Args: + **overrides: Configuration option names and their temporary values. + + Yields: + None + + Raises: + AttributeError: If any override name is not a recognized configuration option. + Validation error: If any override value fails validation. + + Example: + >>> with config.overrides(debug=True, verbose_exceptions=True): + ... # Use config with temporary overrides + ... pass + >>> # Overrides are automatically reverted here + """ if overrides.keys() - self._keys: raise AttributeError( f"Unrecognized config options: {set(overrides.keys()) - self._keys}" @@ -287,52 +404,79 @@ def overrides(self, **overrides: Any) -> Generator[None, None, None]: self._hooks[name](old_context.get(name), new_context.get(name), UpdateScope.CONTEXT) def as_dict(self) -> dict[str, Any]: - """Get the current effective configuration options as a dictionary.""" + """Get the current effective configuration options as a dictionary. + + Returns all configuration options with their effective values, preserving + the order they were defined in the class. + + Returns: + A dictionary mapping option names to their effective values. + """ # We use self._descriptors to preserve the order of options as defined in the class. return {name: self.get(name) for name in self._descriptors.keys()} def _option_descriptors_(self) -> types.MappingProxyType[str, OptionDescriptor]: - """Get the option descriptors.""" + """Get the option descriptors. + + Returns a read-only mapping of option names to their descriptors. + This is useful for introspection and documentation purposes. + + Returns: + A MappingProxyType mapping option names to OptionDescriptor instances. + """ return types.MappingProxyType(self._descriptors) class Config(ConfigManager): - """ - GT4Py configuration registry. + """GT4Py configuration registry. + + This class is used to register and manage all configuration options for GT4Py. + All publicly exposed options should be defined here as OptionDescriptor instances. - This class is used to register configuration options for GT4Py. + Options defined here can be configured via: + - Environment variables (GT4PY_OPTION_NAME format) + - Direct calls to config.set() + - Context manager overrides with config.overrides() """ ## -- Debug options -- #: Master debug flag. It changes defaults for all the other options to be as helpful - #: for debugging as possible. - debug = OptionDescriptor(type=bool, default=False, validator=utils.isinstancechecker(bool)) + #: for debugging as possible. Environment variable: GT4PY_DEBUG + debug = OptionDescriptor( + option_type=bool, default=False, validator=utils.isinstancechecker(bool) + ) - #: Verbose flag for DSL compilation errors. + #: Verbose flag for DSL compilation errors. Defaults to the value of debug. + #: Environment variable: GT4PY_VERBOSE_EXCEPTIONS verbose_exceptions = OptionDescriptor[bool, "Config"]( - type=bool, default_factory=(lambda cfg: cast(bool, cfg.debug)) + option_type=bool, default_factory=(lambda cfg: cfg.debug) ) ## -- Instrumentation options -- #: User-defined level to enable metrics at lower or equal level. #: Enabling metrics collection will do extra synchronization and will have - #: impact on runtime performance. - collect_metrics_level = OptionDescriptor(type=int, default=0) + #: impact on runtime performance. Environment variable: GT4PY_COLLECT_METRICS_LEVEL + collect_metrics_level = OptionDescriptor(option_type=int, default=0) #: Add GPU trace markers (NVTX, ROC-TX) to the generated code, at compile time. - # FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. - add_gpu_trace_markers = OptionDescriptor(type=bool, default=False) + #: Environment variable: GT4PY_ADD_GPU_TRACE_MARKERS + #: FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. + add_gpu_trace_markers = OptionDescriptor(option_type=bool, default=False) ## -- Build options -- class BuildCacheLifetime(enum.Enum): + """Build cache lifetime modes.""" + SESSION = "session" PERSISTENT = "persistent" #: Whether generated code projects should be kept around between runs. #: - SESSION: generated code projects get destroyed when the interpreter shuts down - #: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs + #: - PERSISTENT: generated code projects are written to build_cache_dir and persist between runs + #: Defaults to PERSISTENT in debug mode, SESSION otherwise. + #: Environment variable: GT4PY_BUILD_CACHE_LIFETIME build_cache_lifetime = OptionDescriptor[BuildCacheLifetime, "Config"]( - type=BuildCacheLifetime, + option_type=BuildCacheLifetime, default_factory=( lambda cfg: cfg.BuildCacheLifetime.PERSISTENT if cfg.debug @@ -340,9 +484,11 @@ class BuildCacheLifetime(enum.Enum): ), ) - #: Where generated code projects should be persisted. - #: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT - build_cache_dir_root = OptionDescriptor(type=pathlib.Path, default=pathlib.Path.cwd()) + #: Where generated code projects should be persisted when BUILD_CACHE_LIFETIME is PERSISTENT. + #: Supports ~ expansion and environment variable substitution ($VAR, ${VAR}). + #: The actual cache directory will be this path with '/.gt4py_cache' appended. + #: Environment variable: GT4PY_BUILD_CACHE_DIR_ROOT + build_cache_dir_root = OptionDescriptor(option_type=pathlib.Path, default=pathlib.Path.cwd()) @property def build_cache_dir(self) -> pathlib.Path: @@ -350,10 +496,15 @@ def build_cache_dir(self) -> pathlib.Path: return self.build_cache_dir_root / ".gt4py_cache" class CMakeBuildType(enum.Enum): - """ - CMake build types enum. + """CMake build types enum. Member values have to be valid CMake syntax. + + Attributes: + DEBUG: Debug build with symbols and no optimization. + RELEASE: Release build with optimization and no symbols. + REL_WITH_DEB_INFO: Release build with optimization and debug symbols. + MIN_SIZE_REL: Release build optimized for minimal size. """ DEBUG = "Debug" @@ -362,33 +513,42 @@ class CMakeBuildType(enum.Enum): MIN_SIZE_REL = "MinSizeRel" #: Build type to be used when CMake is used to compile generated code. + #: Defaults to DEBUG in debug mode, RELEASE otherwise. #: Might have no effect when CMake is not used as part of the toolchain. - # FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. + #: Environment variable: GT4PY_CMAKE_BUILD_TYPE + #: FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. cmake_build_type = OptionDescriptor[CMakeBuildType, "Config"]( - type=CMakeBuildType, + option_type=CMakeBuildType, default_factory=( lambda cfg: cfg.CMakeBuildType.DEBUG if cfg.debug else cfg.CMakeBuildType.RELEASE ), ) - #: Number of threads to use to use for compilation (0 = synchronous compilation). - #: Default: - #: - use os.cpu_count(), TODO(havogt): in Python >= 3.13 use `process_cpu_count()` - #: - if os.cpu_count() is None we are conservative and use 1 job, - #: - if the number is huge (e.g. HPC system) we limit to a smaller number + #: Number of threads to use for compilation (0 = synchronous compilation). + #: Default behavior: + #: - Uses os.cpu_count() if available (TODO: Python >= 3.13 use process_cpu_count()) + #: - Falls back to 1 if os.cpu_count() returns None + #: - Caps the value at 32 to avoid excessive resource usage on HPC systems + #: Environment variable: GT4PY_BUILD_JOBS build_jobs = OptionDescriptor( - type=int, + option_type=int, default_factory=lambda ctx: min(os.cpu_count() or 1, 32), ) ## -- Code-generation options -- #: Experimental, use at your own risk: assume horizontal dimension has stride 1 - # FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. - unstructured_horizontal_has_unit_stride = OptionDescriptor(type=bool, default=False) + #: Environment variable: GT4PY_UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE + #: FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. + unstructured_horizontal_has_unit_stride = OptionDescriptor(option_type=bool, default=False) #: The default for whether to allow jit-compilation for a compiled program. - #: This default can be overriden per program. - enable_jit_default = OptionDescriptor(type=bool, default=True) + #: This default can be overridden per program via their respective APIs. + #: Environment variable: GT4PY_ENABLE_JIT_DEFAULT + enable_jit_default = OptionDescriptor(option_type=bool, default=True) +#: Global singleton instance of the GT4Py configuration manager. +#: Use this to access and modify configuration options: config.debug, config.set(...), etc. config = Config() + +print(config.as_dict()) diff --git a/tests/eve_tests/unit_tests/test_utils.py b/tests/eve_tests/unit_tests/test_utils.py index ae8e938396..191afd3fb2 100644 --- a/tests/eve_tests/unit_tests/test_utils.py +++ b/tests/eve_tests/unit_tests/test_utils.py @@ -285,6 +285,172 @@ def func(a, b, c): assert fp3() == 6 +class TestTypeMapping: + """Unit tests for TypeMapping class.""" + + def test_basic_getitem(self): + """Test basic type-to-value mapping retrieval.""" + from gt4py.eve.utils import TypeMapping + + def fallback(type_): + return f"default_{type_.__name__}" + + mapping = TypeMapping(fallback) + + # Register some types + mapping[int] = "integer" + mapping[str] = "string" + + assert mapping[int] == "integer" + assert mapping[str] == "string" + + def test_fallback_factory(self): + """Test that fallback factory is used for unregistered types.""" + from gt4py.eve.utils import TypeMapping + + def fallback(type_): + return f"default_{type_.__name__}" + + mapping = TypeMapping(fallback) + mapping[int] = "integer" + + # Unregistered type should use fallback + assert mapping[str] == "default_str" + assert mapping[float] == "default_float" + assert mapping[list] == "default_list" + + def test_setitem_and_register(self): + """Test both __setitem__ and register methods.""" + from gt4py.eve.utils import TypeMapping + + mapping = TypeMapping(lambda t: None) + + # Using __setitem__ + mapping[int] = "via_setitem" + assert mapping[int] == "via_setitem" + + # Using register method + result = mapping.register(str, "via_register") + assert result == "via_register" + assert mapping[str] == "via_register" + + # Using register method in a decorator style + result = mapping.register(str)("via_register_decorator") + assert result == "via_register_decorator" + assert mapping[str] == "via_register_decorator" + + def test_callable_value_registration(self): + """Test registering callable objects as values.""" + from gt4py.eve.utils import TypeMapping + + def int_handler(a): + return f"int_handler: {a}" + + def str_handler(a): + return f"str_handler: {a}" + + def any_handler(a): + return f"{type(a).__name__}_handler: {a}" + + mapping = TypeMapping(lambda t: any_handler) + mapping[int] = int_handler + mapping[str] = str_handler + + assert callable(mapping[int]) + assert callable(mapping[str]) + assert mapping[int](1) == "int_handler: 1" + assert mapping[str](2) == "str_handler: 2" + assert mapping[tuple]((3, 4)) == "tuple_handler: (3, 4)" + + def test_multiple_type_registrations(self): + """Test registering and retrieving multiple types.""" + from gt4py.eve.utils import TypeMapping + + mapping = TypeMapping(lambda t: f"fallback_{t.__name__}") + + types_values = { + int: "number", + str: "text", + float: "decimal", + list: "sequence", + dict: "mapping", + set: "unique", + tuple: "immutable", + } + + for type_, value in types_values.items(): + mapping[type_] = value + + for type_, expected_value in types_values.items(): + assert mapping[type_] == expected_value + + def test_overwrite_registration(self): + """Test that re-registering a type overwrites the previous value.""" + from gt4py.eve.utils import TypeMapping + + mapping = TypeMapping(lambda t: None) + + mapping[int] = "first" + assert mapping[int] == "first" + + mapping[int] = "second" + assert mapping[int] == "second" + + mapping[int] = "third" + assert mapping[int] == "third" + + def test_subclass_dispatch(self): + """Test that singledispatch works with subclasses.""" + from gt4py.eve.utils import TypeMapping + + mapping = TypeMapping(lambda t: "default") + + class BaseClass: + pass + + class SubClass(BaseClass): + pass + + mapping[BaseClass] = "base" + + # Subclass should dispatch to BaseClass handler + assert mapping[SubClass] == "base" + assert mapping[BaseClass] == "base" + + def test_complex_fallback_factory(self): + """Test TypeMapping with a complex fallback factory function.""" + from gt4py.eve.utils import TypeMapping + + def complex_fallback(type_): + if hasattr(type_, "__len__"): + return f"sized_{type_.__name__}" + else: + return f"unsized_{type_.__name__}" + + mapping = TypeMapping(complex_fallback) + + assert "sized" in mapping[str] + assert "sized" in mapping[list] + assert "sized" in mapping[dict] + assert "unsized" in mapping[float] + assert "unsized" in mapping[int] + + def test_iteration(self): + """Test iteration over registered types.""" + from gt4py.eve.utils import TypeMapping + + mapping = TypeMapping(lambda t: None) + types_to_register = [int, str, float, list, dict] + + for i, type_ in enumerate(types_to_register): + mapping[type_] = f"value_{i}" + + # Check that all registered types are in iteration + registered_types = list(mapping) + for type_ in types_to_register: + assert type_ in registered_types + + def test_noninstantiable_class(): @eve.utils.noninstantiable class NonInstantiableClass(eve.datamodels.DataModel): diff --git a/tests/next_tests/unit_tests/test_config.py b/tests/next_tests/unit_tests/test_old_config.py similarity index 100% rename from tests/next_tests/unit_tests/test_config.py rename to tests/next_tests/unit_tests/test_old_config.py From 981c2b13fc84ee035a9049193e28edf7c3d7b11f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 25 Feb 2026 16:17:44 +0100 Subject: [PATCH 05/12] Fix typings in _config --- src/gt4py/next/_config.py | 41 +++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py index a68d35300a..bcc0bb1152 100644 --- a/src/gt4py/next/_config.py +++ b/src/gt4py/next/_config.py @@ -31,7 +31,7 @@ import sys import types from collections.abc import Callable, Generator, Mapping -from typing import Any, Final, Generic, Literal, Protocol, TypeVar, final, overload +from typing import Any, Final, Generic, Literal, Protocol, TypeVar, cast, final, overload from gt4py.eve import utils from gt4py.eve.extended_typing import Self @@ -82,6 +82,7 @@ def _parse_str(type_: type) -> Callable[[str], Any]: """Default parser: the type string value as is.""" match type_: case enum.Enum() as enum_type: + assert issubclass(enum_type, enum.Enum) return lambda value: enum_type[value] # parse enum values from their names case _: return lambda x: 1 # type constructor as parser @@ -127,11 +128,11 @@ def __call__( ) -> None: ... -ConfigManagerT = TypeVar("ConfigManagerT", bound="ConfigManager") +# ConfigManagerT = TypeVar("ConfigManagerT", bound="ConfigManager") @dataclasses.dataclass(frozen=True, kw_only=True) -class OptionDescriptor(Generic[_T, ConfigManagerT]): +class OptionDescriptor(Generic[_T]): """ Descriptor for a configuration option. @@ -163,7 +164,7 @@ class OptionDescriptor(Generic[_T, ConfigManagerT]): option_type: type[_T] default: dataclasses.InitVar[_T | _UnsetSentinel] = UNSET - default_factory: Callable[[ConfigManagerT], _T] | None = None + default_factory: Callable[[ConfigManager], _T] | None = None parser: Callable[[str], _T] | None = None validator: Callable[[Any], Any] | Literal["type_check"] | None = "type_check" update_callback: OptionUpdateCallback[_T] | None = None @@ -195,18 +196,14 @@ def __set_name__(self, owner: type, name: str) -> None: object.__setattr__(self, "name", name) @overload - def __get__( - self, instance: ConfigManagerT, owner: type[ConfigManagerT] | None = None - ) -> _T: ... + def __get__(self, instance: ConfigManager, owner: type[ConfigManager]) -> _T: ... @overload - def __get__( - self, instance: Literal[None], owner: type[ConfigManagerT] | None = None - ) -> Self: ... + def __get__(self, instance: None, owner: None) -> OptionDescriptor[_T]: ... def __get__( - self, instance: ConfigManagerT | None, owner: type[ConfigManagerT] | None = None - ) -> _T | Self: + self, instance: ConfigManager | None, owner: type[ConfigManager] | None = None + ) -> _T | OptionDescriptor[_T]: """ Get the configuration option value. @@ -264,7 +261,7 @@ class ConfigManager: """ def __init__(self) -> None: - self._descriptors: dict[str, OptionDescriptor[Any, Config]] = { + self._descriptors: dict[str, OptionDescriptor[Any]] = { name: attr for name, attr in type(self).__dict__.items() if isinstance(attr, OptionDescriptor) @@ -448,8 +445,8 @@ class Config(ConfigManager): #: Verbose flag for DSL compilation errors. Defaults to the value of debug. #: Environment variable: GT4PY_VERBOSE_EXCEPTIONS - verbose_exceptions = OptionDescriptor[bool, "Config"]( - option_type=bool, default_factory=(lambda cfg: cfg.debug) + verbose_exceptions = OptionDescriptor[bool]( + option_type=bool, default_factory=(lambda cfg: cast(Config, cfg).debug) ) ## -- Instrumentation options -- @@ -475,12 +472,12 @@ class BuildCacheLifetime(enum.Enum): #: - PERSISTENT: generated code projects are written to build_cache_dir and persist between runs #: Defaults to PERSISTENT in debug mode, SESSION otherwise. #: Environment variable: GT4PY_BUILD_CACHE_LIFETIME - build_cache_lifetime = OptionDescriptor[BuildCacheLifetime, "Config"]( + build_cache_lifetime = OptionDescriptor[BuildCacheLifetime]( option_type=BuildCacheLifetime, default_factory=( - lambda cfg: cfg.BuildCacheLifetime.PERSISTENT - if cfg.debug - else cfg.BuildCacheLifetime.SESSION + lambda cfg: cast(Config, cfg).BuildCacheLifetime.PERSISTENT + if cast(Config, cfg).debug + else cast(Config, cfg).BuildCacheLifetime.SESSION ), ) @@ -517,10 +514,12 @@ class CMakeBuildType(enum.Enum): #: Might have no effect when CMake is not used as part of the toolchain. #: Environment variable: GT4PY_CMAKE_BUILD_TYPE #: FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. - cmake_build_type = OptionDescriptor[CMakeBuildType, "Config"]( + cmake_build_type = OptionDescriptor[CMakeBuildType]( option_type=CMakeBuildType, default_factory=( - lambda cfg: cfg.CMakeBuildType.DEBUG if cfg.debug else cfg.CMakeBuildType.RELEASE + lambda cfg: cast(Config, cfg).CMakeBuildType.DEBUG + if cast(Config, cfg).debug + else cast(Config, cfg).CMakeBuildType.RELEASE ), ) From a331cb60fd8f4ee7db7aa483fe11e9a750bd4cb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 25 Feb 2026 17:23:38 +0100 Subject: [PATCH 06/12] Fixing more integration issues and adding tests --- src/gt4py/next/_config.py | 51 +- src/gt4py/next/_old_config.py | 121 ---- tests/next_tests/__init__.py | 3 +- .../instrumentation_tests/test_metrics.py | 13 +- .../build_systems_tests/test_compiledb.py | 2 +- .../runners_tests/test_gtfn.py | 20 +- tests/next_tests/unit_tests/test_config.py | 525 ++++++++++++++++++ .../next_tests/unit_tests/test_old_config.py | 48 -- 8 files changed, 565 insertions(+), 218 deletions(-) delete mode 100644 src/gt4py/next/_old_config.py create mode 100644 tests/next_tests/unit_tests/test_config.py delete mode 100644 tests/next_tests/unit_tests/test_old_config.py diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py index bcc0bb1152..c9e6728704 100644 --- a/src/gt4py/next/_config.py +++ b/src/gt4py/next/_config.py @@ -7,17 +7,20 @@ # SPDX-License-Identifier: BSD-3-Clause """ -Configuration system for GT4Py. +GT4Py configuration system. -Precedence of effective option values (highest to lowest): -1) Active context override (`ConfigManager.overrides`) -2) Global runtime value (`ConfigManager.set`) -3) Environment variable (`OptionDescriptor.env_var_name`) -4) Descriptor default/default_factory +This module defines a typed configuration framework based on descriptors: -Notes: -- Context overrides are task-local via `contextvars`. -- `set()` is disallowed while the same option is context-overridden. +- `OptionDescriptor`: declares one option (type, default/default_factory, parser, + validator, environment variable mapping, and optional update callback). +- `ConfigManager`: stores option values, resolves effective values using precedence, + and supports task-local temporary overrides. +- `Config`: concrete registry of GT4Py public options. + +Configuration can be changed globally via attribute assignment or `set()`, and +temporarily via `overrides()`. + +The public singleton instance is exposed as `gt4py.next.config`. """ from __future__ import annotations @@ -34,7 +37,6 @@ from typing import Any, Final, Generic, Literal, Protocol, TypeVar, cast, final, overload from gt4py.eve import utils -from gt4py.eve.extended_typing import Self @final @@ -128,9 +130,6 @@ def __call__( ) -> None: ... -# ConfigManagerT = TypeVar("ConfigManagerT", bound="ConfigManager") - - @dataclasses.dataclass(frozen=True, kw_only=True) class OptionDescriptor(Generic[_T]): """ @@ -231,11 +230,7 @@ def __set__(self, instance: Any, value: _T) -> None: @property def env_var_name(self) -> str: - """Construct the name of the environment variable corresponding to this option. - - Returns the environment variable name by combining the prefix and option name in uppercase. - E.g., for env_prefix="GT4PY_" and name="debug", returns "GT4PY_DEBUG". - """ + """Construct the name of the environment variable corresponding to this option.""" return f"{self.env_prefix}{self.name}".upper() @@ -292,7 +287,9 @@ def __init__(self) -> None: for name, desc in self._descriptors.items(): assert desc.default_factory is not None # Guaranteed by __post_init__ init_value = parse_env_var( - desc.env_var_name, desc.parser or _parse_str[desc.option_type], default=None + desc.env_var_name, + desc.parser or _parse_str[desc.option_type], + default=desc.default_factory(self), ) if validator := self._validators.get(name): validator(init_value) @@ -355,13 +352,6 @@ def overrides(self, **overrides: Any) -> Generator[None, None, None]: Args: **overrides: Configuration option names and their temporary values. - Yields: - None - - Raises: - AttributeError: If any override name is not a recognized configuration option. - Validation error: If any override value fails validation. - Example: >>> with config.overrides(debug=True, verbose_exceptions=True): ... # Use config with temporary overrides @@ -405,9 +395,6 @@ def as_dict(self) -> dict[str, Any]: Returns all configuration options with their effective values, preserving the order they were defined in the class. - - Returns: - A dictionary mapping option names to their effective values. """ # We use self._descriptors to preserve the order of options as defined in the class. return {name: self.get(name) for name in self._descriptors.keys()} @@ -417,15 +404,13 @@ def _option_descriptors_(self) -> types.MappingProxyType[str, OptionDescriptor]: Returns a read-only mapping of option names to their descriptors. This is useful for introspection and documentation purposes. - - Returns: - A MappingProxyType mapping option names to OptionDescriptor instances. """ return types.MappingProxyType(self._descriptors) class Config(ConfigManager): - """GT4Py configuration registry. + """ + GT4Py configuration manager. This class is used to register and manage all configuration options for GT4Py. All publicly exposed options should be defined here as OptionDescriptor instances. diff --git a/src/gt4py/next/_old_config.py b/src/gt4py/next/_old_config.py deleted file mode 100644 index 8b5c870db5..0000000000 --- a/src/gt4py/next/_old_config.py +++ /dev/null @@ -1,121 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import enum -import os -import pathlib -from typing import Final - - -class BuildCacheLifetime(enum.Enum): - SESSION = 1 - PERSISTENT = 2 - - -class CMakeBuildType(enum.Enum): - """ - CMake build types enum. - - Member values have to be valid CMake syntax. - """ - - DEBUG = "Debug" - RELEASE = "Release" - REL_WITH_DEB_INFO = "RelWithDebInfo" - MIN_SIZE_REL = "MinSizeRel" - - -def env_flag_to_bool(name: str, default: bool) -> bool: - """Convert environment variable string variable to a bool value.""" - flag_value = os.environ.get(name, None) - if flag_value is None: - return default - match flag_value.lower(): - case "0" | "false" | "off": - return False - case "1" | "true" | "on": - return True - case _: - raise ValueError( - "Invalid GT4Py environment flag value: use '0 | false | off' or '1 | true | on'." - ) - - -def env_flag_to_int(name: str, default: int) -> int: - """Convert environment variable string variable to an int value.""" - flag_value = os.environ.get(name, None) - if flag_value is None: - return default - try: - return int(flag_value) - except ValueError: - raise ValueError( - f"Invalid GT4Py environment flag value: {flag_value} is not an integer." - ) from None - - -#: Master debug flag -#: Changes defaults for all the other options to be as helpful for debugging as possible. -#: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=False) - - -#: Verbose flag for DSL compilation errors -VERBOSE_EXCEPTIONS: bool = env_flag_to_bool( - "GT4PY_VERBOSE_EXCEPTIONS", default=True if DEBUG else False -) - - -#: Where generated code projects should be persisted. -#: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT -BUILD_CACHE_DIR: pathlib.Path = ( - pathlib.Path(os.environ.get("GT4PY_BUILD_CACHE_DIR", pathlib.Path.cwd())) / ".gt4py_cache" -) - - -#: Whether generated code projects should be kept around between runs. -#: - SESSION: generated code projects get destroyed when the interpreter shuts down -#: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs -BUILD_CACHE_LIFETIME: BuildCacheLifetime = BuildCacheLifetime[ - os.environ.get("GT4PY_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() -] - -#: Build type to be used when CMake is used to compile generated code. -#: Might have no effect when CMake is not used as part of the toolchain. -# FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. -CMAKE_BUILD_TYPE: CMakeBuildType = CMakeBuildType[ - os.environ.get("GT4PY_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() -] - -#: Experimental, use at your own risk: assume horizontal dimension has stride 1 -# FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. -UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE: bool = env_flag_to_bool( - "GT4PY_UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE", default=False -) - -#: Add GPU trace markers (NVTX, ROC-TX) to the generated code, at compile time. -# FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. -ADD_GPU_TRACE_MARKERS: bool = env_flag_to_bool("GT4PY_ADD_GPU_TRACE_MARKERS", default=False) - -#: Number of threads to use to use for compilation (0 = synchronous compilation). -#: Default: -#: - use os.cpu_count(), TODO(havogt): in Python >= 3.13 use `process_cpu_count()` -#: - if os.cpu_count() is None we are conservative and use 1 job, -#: - if the number is huge (e.g. HPC system) we limit to a smaller number -BUILD_JOBS: int = int(os.environ.get("GT4PY_BUILD_JOBS", min(os.cpu_count() or 1, 32))) - -#: User-defined level to enable metrics at lower or equal level. -#: Enabling metrics collection will do extra synchronization and will have -#: impact on runtime performance. -COLLECT_METRICS_LEVEL: int = env_flag_to_int("GT4PY_COLLECT_METRICS_LEVEL", default=0) - -#: The default for whether to allow jit-compilation for a compiled program. -#: This default can be overriden per program. -ENABLE_JIT_DEFAULT: bool = env_flag_to_bool("GT4PY_ENABLE_JIT_DEFAULT", default=True) diff --git a/tests/next_tests/__init__.py b/tests/next_tests/__init__.py index 7449b4dac8..7a094466dd 100644 --- a/tests/next_tests/__init__.py +++ b/tests/next_tests/__init__.py @@ -18,7 +18,8 @@ if config.build_cache_lifetime is config.BuildCacheLifetime.PERSISTENT: warnings.warn( - "You are running GT4Py tests with BUILD_CACHE_LIFETIME set to PERSISTENT!", UserWarning + "You are running GT4Py tests with 'config.BuildCacheLifetime' set to PERSISTENT!", + UserWarning, ) diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py index 2499a23546..6e384a9004 100644 --- a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py @@ -15,6 +15,7 @@ import numpy as np import pytest +from gt4py.next import config from gt4py.next.instrumentation import metrics from gt4py.next.otf import arguments @@ -58,7 +59,7 @@ def test_set_current_source_key_different_key_raises(self): class TestSourceKeyContextManager: def test_context_manager_sets_and_resets_key(self): - with unittest.mock.patch("gt4py.next.config.collect_metrics_level", metrics.MINIMAL): + with config.overrides(collect_metrics_level=metrics.MINIMAL): metrics._source_key_cvar.set( metrics._NO_KEY_SET_MARKER_ ) # Reset context variable before test @@ -78,7 +79,7 @@ def test_context_manager_sets_and_resets_key(self): ) def test_context_manager_with_no_key(self): - with unittest.mock.patch("gt4py.next.config.collect_metrics_level", metrics.MINIMAL): + with config.overrides(collect_metrics_level=metrics.MINIMAL): metrics._source_key_cvar.set("__BEFORE__MARKER__") # Reset context variable before test with metrics.SourceKeyContextManager(): @@ -92,7 +93,7 @@ def test_context_manager_with_no_key(self): assert metrics._source_key_cvar.get(metrics._NO_KEY_SET_MARKER_) == "__BEFORE__MARKER__" def test_context_manager_nested(self): - with unittest.mock.patch("gt4py.next.config.collect_metrics_level", metrics.MINIMAL): + with config.overrides(collect_metrics_level=metrics.MINIMAL): metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) key1 = "outer_key" key2 = "inner_key" @@ -121,7 +122,7 @@ class TestCollector( ): ... metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) - with unittest.mock.patch("gt4py.next.config.collect_metrics_level", metrics.MINIMAL): + with config.overrides(collect_metrics_level=metrics.MINIMAL): outer_key = "outer_key" metrics.set_current_source_key("outer_key") assert metrics.get_current_source_key() == outer_key @@ -140,7 +141,7 @@ class TestCollector( key = "test_disabled" metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) - with unittest.mock.patch("gt4py.next.config.collect_metrics_level", metrics.DISABLED): + with config.overrides(collect_metrics_level=metrics.DISABLED): metrics.set_current_source_key(key) with TestCollector(key=key): @@ -161,7 +162,7 @@ class CustomCollector( key = "test_custom" metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) - with unittest.mock.patch("gt4py.next.config.collect_metrics_level", metrics.PERFORMANCE): + with config.overrides(collect_metrics_level=metrics.PERFORMANCE): with CustomCollector(key=key): pass diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_compiledb.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_compiledb.py index 7ff3525cf8..738e2e2972 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_compiledb.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_compiledb.py @@ -45,7 +45,7 @@ def test_compiledb_project_is_relocatable(compilable_source_example, clean_examp builder.build() - with tempfile.TemporaryDirectory(dir=config.BUILD_CACHE_DIR) as tmpdir: + with tempfile.TemporaryDirectory(dir=config.build_cache_dir) as tmpdir: # copy the project to a new location relocated_dir = pathlib.Path(tmpdir) / "relocated" shutil.copytree( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index 3d82dd8ee5..98cf6d14c0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -51,12 +51,14 @@ def test_backend_factory_trait_cached(): def test_backend_factory_build_cache_config(monkeypatch): - monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.SESSION) - session_version = gtfn.GTFNBackendFactory() - monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.PERSISTENT) - persistent_version = gtfn.GTFNBackendFactory() + with config.overrides(build_cache_lifetime=config.BuildCacheLifetime.SESSION): + session_version = gtfn.GTFNBackendFactory() assert session_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.SESSION + + with config.overrides(build_cache_lifetime=config.BuildCacheLifetime.PERSISTENT): + persistent_version = gtfn.GTFNBackendFactory() + assert ( persistent_version.executor.compilation.cache_lifetime is config.BuildCacheLifetime.PERSISTENT @@ -64,15 +66,17 @@ def test_backend_factory_build_cache_config(monkeypatch): def test_backend_factory_build_type_config(monkeypatch): - monkeypatch.setattr(config, "CMAKE_BUILD_TYPE", config.CMakeBuildType.RELEASE) - release_version = gtfn.GTFNBackendFactory() - monkeypatch.setattr(config, "CMAKE_BUILD_TYPE", config.CMakeBuildType.MIN_SIZE_REL) - min_size_version = gtfn.GTFNBackendFactory() + with config.overrides(cmake_build_type=config.CMakeBuildType.RELEASE): + release_version = gtfn.GTFNBackendFactory() assert ( release_version.executor.compilation.builder_factory.cmake_build_type is config.CMakeBuildType.RELEASE ) + + with config.overrides(cmake_build_type=config.CMakeBuildType.MIN_SIZE_REL): + min_size_version = gtfn.GTFNBackendFactory() + assert ( min_size_version.executor.compilation.builder_factory.cmake_build_type is config.CMakeBuildType.MIN_SIZE_REL diff --git a/tests/next_tests/unit_tests/test_config.py b/tests/next_tests/unit_tests/test_config.py new file mode 100644 index 0000000000..3b64eabb4b --- /dev/null +++ b/tests/next_tests/unit_tests/test_config.py @@ -0,0 +1,525 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import enum +import os +import pathlib +from typing import Any +from unittest import mock + +import pytest + +from gt4py.next._config import Config, ConfigManager, OptionDescriptor, UpdateScope + + +class TestOptionDescriptorBasics: + """Test basic OptionDescriptor functionality.""" + + def test_descriptor_with_default_value(self) -> None: + """Test that descriptor stores and returns default values.""" + + class TestConfig(ConfigManager): + name = OptionDescriptor(option_type=str, default="test") + + cfg = TestConfig() + assert cfg.name == "test" + + def test_descriptor_with_default_factory(self) -> None: + """Test that descriptor uses default_factory to compute defaults.""" + + class TestConfig(ConfigManager): + base = OptionDescriptor(option_type=int, default=10) + derived = OptionDescriptor( + option_type=int, default_factory=lambda cfg: cfg.get("base") * 2 + ) + + cfg = TestConfig() + assert cfg.derived == 20 + + def test_descriptor_attribute_access(self) -> None: + """Test attribute-style access to configuration options.""" + + class TestConfig(ConfigManager): + debug = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + assert cfg.debug is False + + def test_descriptor_get_method(self) -> None: + """Test get() method returns correct values.""" + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=42) + + cfg = TestConfig() + assert cfg.get("value") == 42 + + def test_descriptor_rejects_unrecognized_option(self) -> None: + """Test that get() raises AttributeError for unknown options.""" + + class TestConfig(ConfigManager): + opt = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + with pytest.raises(AttributeError, match="Unrecognized config option"): + cfg.get("nonexistent") + + +class TestConfigurationPrecedence: + """Test configuration value precedence rules.""" + + def test_environment_variable_overrides_default(self) -> None: + """Test that environment variables override descriptor defaults.""" + with mock.patch.dict(os.environ, {"GT4PY_VALUE": "999"}): + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=100) + + cfg = TestConfig() + assert cfg.value == 999 + + def test_context_override_takes_precedence(self) -> None: + """Test that context overrides take precedence over global values.""" + + class TestConfig(ConfigManager): + opt = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + assert cfg.opt == 10 + + with cfg.overrides(opt=20): + assert cfg.opt == 20 + + assert cfg.opt == 10 + + def test_context_override_precedence_chain(self) -> None: + """Test complete precedence: context > global > environment > default.""" + with mock.patch.dict(os.environ, {"GT4PY_NUM": "50"}): + + class TestConfig(ConfigManager): + num = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + assert cfg.num == 50 # Environment overrides default + + cfg.set("num", 100) + assert cfg.num == 100 # Global overrides environment + + with cfg.overrides(num=200): + assert cfg.num == 200 # Context overrides global + + assert cfg.num == 100 # Back to global after context + + +class TestSetMethod: + """Test ConfigManager.set() method.""" + + def test_set_changes_global_value(self) -> None: + """Test that set() changes the global configuration value.""" + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + cfg.set("value", 20) + assert cfg.value == 20 + + def test_set_persists_across_accesses(self) -> None: + """Test that set values persist across multiple accesses.""" + + class TestConfig(ConfigManager): + opt = OptionDescriptor(option_type=str, default="old") + + cfg = TestConfig() + cfg.set("opt", "new") + assert cfg.opt == "new" + assert cfg.get("opt") == "new" + + def test_set_rejects_unrecognized_option(self) -> None: + """Test that set() raises AttributeError for unknown options.""" + + class TestConfig(ConfigManager): + opt = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + with pytest.raises(AttributeError, match="Unrecognized config option"): + cfg.set("nonexistent", True) + + def test_set_blocked_during_context_override(self) -> None: + """Test that set() is blocked while option is overridden in context.""" + + class TestConfig(ConfigManager): + opt = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + with cfg.overrides(opt=20): + with pytest.raises(AttributeError, match="overridden in a context manager"): + cfg.set("opt", 30) + + def test_set_via_attribute_assignment(self) -> None: + """Test that setting via attribute assignment works.""" + + class TestConfig(ConfigManager): + debug = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + cfg.debug = True + assert cfg.debug is True + + +class TestValidation: + """Test configuration option validation.""" + + def test_validator_rejects_invalid_values(self) -> None: + """Test that validators reject invalid values.""" + + def positive_int(val: Any) -> None: + if not isinstance(val, int) or val <= 0: + raise ValueError("Must be positive") + + class TestConfig(ConfigManager): + count = OptionDescriptor(option_type=int, default=1, validator=positive_int) + + cfg = TestConfig() + with pytest.raises(ValueError, match="Must be positive"): + cfg.set("count", -5) + + def test_type_check_validator(self) -> None: + """Test that 'type_check' validator validates types.""" + + class TestConfig(ConfigManager): + name = OptionDescriptor(option_type=str, default="test", validator="type_check") + + cfg = TestConfig() + with pytest.raises(TypeError): + cfg.set("name", 123) + + def test_validator_accepts_valid_values(self) -> None: + """Test that validators accept valid values.""" + + def even_int(val: Any) -> None: + if not isinstance(val, int) or val % 2 != 0: + raise ValueError("Must be even") + + class TestConfig(ConfigManager): + num = OptionDescriptor(option_type=int, default=2, validator=even_int) + + cfg = TestConfig() + cfg.set("num", 42) + assert cfg.num == 42 + + def test_validator_applied_during_context_override(self) -> None: + """Test that validators are applied during context overrides.""" + + def positive(val: Any) -> None: + if val <= 0: + raise ValueError("Must be positive") + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=1, validator=positive) + + cfg = TestConfig() + with pytest.raises(ValueError, match="Must be positive"): + with cfg.overrides(value=-1): + pass + + +class TestContextOverrides: + """Test ConfigManager.overrides() context manager.""" + + def test_override_restores_original_value(self) -> None: + """Test that overrides are reverted when exiting context.""" + + class TestConfig(ConfigManager): + opt = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + original = cfg.opt + with cfg.overrides(opt=999): + pass + assert cfg.opt == original + + def test_multiple_option_override(self) -> None: + """Test overriding multiple options simultaneously.""" + + class TestConfig(ConfigManager): + opt1 = OptionDescriptor(option_type=int, default=1) + opt2 = OptionDescriptor(option_type=str, default="a") + opt3 = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + with cfg.overrides(opt1=10, opt2="b", opt3=True): + assert cfg.opt1 == 10 + assert cfg.opt2 == "b" + assert cfg.opt3 is True + + def test_nested_context_overrides(self) -> None: + """Test nested context overrides.""" + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=1) + + cfg = TestConfig() + with cfg.overrides(value=10): + assert cfg.value == 10 + with cfg.overrides(value=20): + assert cfg.value == 20 + assert cfg.value == 10 + assert cfg.value == 1 + + def test_override_rejects_unrecognized_options(self) -> None: + """Test that overrides reject unknown option names.""" + + class TestConfig(ConfigManager): + opt = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + with pytest.raises(AttributeError, match="Unrecognized config options"): + with cfg.overrides(nonexistent=True): + pass + + def test_override_no_change_for_same_value(self) -> None: + """Test that overriding with same value doesn't trigger unnecessary changes.""" + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + with cfg.overrides(value=10): + assert cfg.value == 10 + + +class TestUpdateCallbacks: + """Test option update callbacks.""" + + def test_callback_invoked_on_global_set(self) -> None: + """Test that callbacks are invoked when using set().""" + callback_calls: list[tuple[Any, Any, UpdateScope]] = [] + + def track_changes(new_val: Any, old_val: Any, scope: UpdateScope) -> None: + callback_calls.append((new_val, old_val, scope)) + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10, update_callback=track_changes) + + cfg = TestConfig() + cfg.set("value", 20) + + assert len(callback_calls) == 1 + assert callback_calls[0] == (20, 10, UpdateScope.GLOBAL) + + def test_callback_invoked_on_context_override(self) -> None: + """Test that callbacks are invoked during context overrides.""" + callback_calls: list[tuple[Any, Any, UpdateScope]] = [] + + def track_changes(new_val: Any, old_val: Any, scope: UpdateScope) -> None: + callback_calls.append((new_val, old_val, scope)) + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10, update_callback=track_changes) + + cfg = TestConfig() + with cfg.overrides(value=20): + pass + + # Should have one call on enter and one on exit + assert any(call[2] == UpdateScope.CONTEXT for call in callback_calls) + + def test_no_callback_for_no_change(self) -> None: + """Test that callbacks are not invoked when override value equals current value.""" + callback_calls: list[Any] = [] + + def track_changes(new_val: Any, old_val: Any, scope: UpdateScope) -> None: + callback_calls.append("called") + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10, update_callback=track_changes) + + cfg = TestConfig() + with cfg.overrides(value=10): # Same as default + pass + + assert len(callback_calls) == 0 + + +class TestStringValueParsing: + """Test environment variable parsing and configuration.""" + + @pytest.mark.parametrize( + "value,expected", + [ + ("False", False), + ("false", False), + ("0", False), + ("off", False), + ("True", True), + ("true", True), + ("1", True), + ("on", True), + ], + ) + def test_parse_bool(self, value, expected) -> None: + """Test parsing boolean environment variables.""" + with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": value}): + + class TestConfig(ConfigManager): + testing_opt = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + assert cfg.testing_opt is expected + + @pytest.mark.parametrize( + "value,expected", + [ + ("42", 42), + ("-5", -5), + ("0", 0), + ], + ) + def test_parse_int(self, value, expected) -> None: + """Test parsing integer environment variables.""" + with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": value}): + + class TestConfig(ConfigManager): + testing_opt = OptionDescriptor(option_type=int, default=0) + + cfg = TestConfig() + assert cfg.testing_opt == expected + + @pytest.mark.parametrize( + "value,expected", + [ + ("/tmp/test", pathlib.Path("/tmp/test")), + ("./relative/path", pathlib.Path("./relative/path")), + ("~/user/path", pathlib.Path(os.environ["HOME"]) / "user" / "path"), + ], + ) + def test_parse_path(self, value, expected) -> None: + """Test parsing pathlib.Path environment variables.""" + with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": value}): + + class TestConfig(ConfigManager): + testing_opt = OptionDescriptor(option_type=pathlib.Path, default=pathlib.Path("/")) + + cfg = TestConfig() + assert cfg.testing_opt == expected + + def test_custom_parser(self) -> None: + """Test custom parser for environment variables.""" + + def parse_list(s: str) -> list[str]: + return s.split(",") + + with mock.patch.dict(os.environ, {"GT4PY_ITEMS": "a,b,c"}): + + class TestConfig(ConfigManager): + items = OptionDescriptor(option_type=list, default=[], parser=parse_list) + + cfg = TestConfig() + assert cfg.items == ["a", "b", "c"] + + def test_invalid_environment_variable_raises_error(self) -> None: + """Test that invalid environment variables raise RuntimeError.""" + with mock.patch.dict(os.environ, {"GT4PY_NUM": "not_a_number"}): + with pytest.raises(RuntimeError, match="Parsing"): + + class TestConfig(ConfigManager): + num = OptionDescriptor(option_type=int, default=0) + + TestConfig() + + +class TestEnumOptions: + """Test configuration options with enum types.""" + + def test_enum_option_with_default(self) -> None: + """Test enum options work with default values.""" + + class Color(enum.Enum): + RED = "red" + BLUE = "blue" + + class TestConfig(ConfigManager): + color = OptionDescriptor(option_type=Color, default=Color.RED) + + cfg = TestConfig() + assert cfg.color == Color.RED + + def test_enum_option_from_environment(self) -> None: + """Test parsing enum options from environment variables.""" + + class Mode(enum.Enum): + DEBUG = "debug" + RELEASE = "release" + + with mock.patch.dict(os.environ, {"GT4PY_MODE": "DEBUG"}): + + class TestConfig(ConfigManager): + mode = OptionDescriptor(option_type=Mode, default=Mode.RELEASE) + + cfg = TestConfig() + assert cfg.mode == Mode.DEBUG + + +class TestAsDict: + """Test ConfigManager.as_dict() method.""" + + def test_as_dict_returns_all_options(self) -> None: + """Test that as_dict() returns all configuration options.""" + + class TestConfig(ConfigManager): + opt1 = OptionDescriptor(option_type=int, default=1) + opt2 = OptionDescriptor(option_type=str, default="test") + + cfg = TestConfig() + config_dict = cfg.as_dict() + assert config_dict["opt1"] == 1 + assert config_dict["opt2"] == "test" + + def test_as_dict_reflects_current_state(self) -> None: + """Test that as_dict() reflects current configuration state.""" + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + cfg.set("value", 20) + assert cfg.as_dict()["value"] == 20 + + def test_as_dict_reflects_context_overrides(self) -> None: + """Test that as_dict() reflects active context overrides.""" + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=10) + + cfg = TestConfig() + with cfg.overrides(value=99): + assert cfg.as_dict()["value"] == 99 + + +class TestRealConfigClass: + """Test the actual Config class.""" + + def test_config_singleton_works(self) -> None: + """Test that the Config singleton is accessible.""" + assert isinstance(Config, type) + cfg = Config() + assert "debug" in cfg._option_descriptors_() + + def test_debug_option_exists(self) -> None: + """Test that debug option exists and has correct type.""" + cfg = Config() + assert isinstance(cfg.debug, bool) + + def test_build_cache_dir_property(self) -> None: + """Test that build_cache_dir property works.""" + cfg = Config() + assert isinstance(cfg.build_cache_dir, pathlib.Path) + assert str(cfg.build_cache_dir).endswith(".gt4py_cache") diff --git a/tests/next_tests/unit_tests/test_old_config.py b/tests/next_tests/unit_tests/test_old_config.py deleted file mode 100644 index a33bd5734a..0000000000 --- a/tests/next_tests/unit_tests/test_old_config.py +++ /dev/null @@ -1,48 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import os - -import pytest - -from gt4py.next import config - - -@pytest.fixture -def env_var(): - """Just in case another test will ever use that environment variable.""" - env_var_name = "GT4PY_TEST_ENV_VAR" - saved = os.environ.get(env_var_name, None) - yield env_var_name - if saved is not None: - os.environ[env_var_name] = saved - else: - _ = os.environ.pop(env_var_name, None) - - -@pytest.mark.parametrize("value", ["False", "false", "0", "off"]) -def test_env_flag_to_bool_false(env_var, value): - os.environ[env_var] = value - assert config.env_flag_to_bool(env_var, default=True) is False - - -@pytest.mark.parametrize("value", ["True", "true", "1", "on"]) -def test_env_flag_to_bool_true(env_var, value): - os.environ[env_var] = value - assert config.env_flag_to_bool(env_var, default=False) is True - - -def test_env_flag_to_bool_invalid(env_var): - os.environ[env_var] = "invalid value" - with pytest.raises(ValueError): - config.env_flag_to_bool(env_var, default=False) - - -def test_env_flag_to_bool_unset(env_var): - _ = os.environ.pop(env_var, None) - assert config.env_flag_to_bool(env_var, default=False) is False From 7c7b8c518228eec7fe318bfb70e3a42d8865d5f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 25 Feb 2026 18:43:08 +0100 Subject: [PATCH 07/12] Fix merge --- src/gt4py/eve/utils.py | 20 +++--- src/gt4py/next/_config.py | 63 +++++++++++++------ src/gt4py/next/config.py | 0 src/gt4py/next/instrumentation/metrics.py | 23 ++++++- .../instrumentation_tests/test_metrics.py | 56 ++++++++--------- tests/next_tests/unit_tests/test_config.py | 2 +- 6 files changed, 102 insertions(+), 62 deletions(-) delete mode 100644 src/gt4py/next/config.py diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 3674780111..feef130982 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -88,7 +88,9 @@ def first(iterable: Iterable[T], *, default: Union[T, NothingType] = NOTHING) -> raise error -def isinstancechecker(type_info: Union[Type, Iterable[Type]]) -> Callable[[Any], bool]: +def isinstancechecker( + type_info: Union[Type, Iterable[Type], types.UnionType], +) -> Callable[[Any], bool]: """Return a callable object that checks if operand is an instance of `type_info`. Examples: @@ -101,18 +103,20 @@ def isinstancechecker(type_info: Union[Type, Iterable[Type]]) -> Callable[[Any], False """ - types: Tuple[Type, ...] = tuple() + all_types: Tuple[Type, ...] = tuple() if isinstance(type_info, type): - types = (type_info,) + all_types = (type_info,) + elif isinstance(type_info, types.UnionType): + all_types = type_info.__args__ elif not isinstance(type_info, tuple) and is_collection(type_info): - types = tuple(type_info) + all_types = tuple(type_info) else: - types = type_info # type:ignore # it is checked at run-time + all_types = type_info # type:ignore # it is checked at run-time - if not isinstance(types, tuple) or not all(isinstance(t, type) for t in types): - raise ValueError(f"Invalid type(s) definition: '{types}'.") + if not isinstance(all_types, tuple) or not all(isinstance(t, type) for t in all_types): + raise ValueError(f"Invalid type(s) definition: '{all_types}'.") - return lambda obj: isinstance(obj, types) + return lambda obj: isinstance(obj, all_types) def attrchecker(*names: str) -> Callable[[Any], bool]: diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py index c9e6728704..5aac3f65c4 100644 --- a/src/gt4py/next/_config.py +++ b/src/gt4py/next/_config.py @@ -29,6 +29,7 @@ import contextvars import dataclasses import enum +import functools import os import pathlib import sys @@ -60,7 +61,6 @@ def __repr__(self) -> str: _T = TypeVar("_T") _T_contra = TypeVar("_T_contra", contravariant=True) -_EnumT = TypeVar("_EnumT", bound=enum.Enum) def parse_env_var( @@ -109,6 +109,21 @@ def _parse_str_as_path(value: str) -> pathlib.Path: return pathlib.Path(expanded) +@functools.cache +def _type_check_validator(type_: type) -> Callable[[Any], None]: + """Generate a validator function that checks if a value is an instance of the given type.""" + + is_instance_checker = utils.isinstancechecker(type_) + + def validator(value: Any) -> None: + if not is_instance_checker(value): + raise TypeError( + f"Expected value of type '{type_}', got type '{type(value)}' (value: {value})" + ) + + return validator + + class UpdateScope(str, enum.Enum): """Scope of a configuration option update.""" @@ -149,7 +164,8 @@ class OptionDescriptor(Generic[_T]): validator: Callable that validates the option value, or "type_check" for isinstance checking. Set to None to disable validation. update_callback: Optional callback invoked after the option is updated (globally or in context). - env_prefix: Prefix for the environment variable name. + env_var_parser: Optional parser for environment variable values. + env_var_prefix: Prefix for the environment variable name. name: Name of the option (set automatically via __set_name__). Example: @@ -161,19 +177,19 @@ class OptionDescriptor(Generic[_T]): ... ) """ - option_type: type[_T] + option_type: type[_T] | Any default: dataclasses.InitVar[_T | _UnsetSentinel] = UNSET default_factory: Callable[[ConfigManager], _T] | None = None - parser: Callable[[str], _T] | None = None validator: Callable[[Any], Any] | Literal["type_check"] | None = "type_check" update_callback: OptionUpdateCallback[_T] | None = None - env_prefix: str = "GT4PY_" + env_var_parser: Callable[[str], _T] | None = None + env_var_prefix: str = "GT4PY_" name: str = dataclasses.field(init=False, default="") def __post_init__(self, default: _T | _UnsetSentinel) -> None: # Initialize the validator if self.validator == "type_check": - object.__setattr__(self, "validator", utils.isinstancechecker(self.option_type)) + object.__setattr__(self, "validator", _type_check_validator(self.option_type)) assert self.validator is None or callable(self.validator) # Initialize the default factory based on the provided default/default_factory @@ -231,7 +247,7 @@ def __set__(self, instance: Any, value: _T) -> None: @property def env_var_name(self) -> str: """Construct the name of the environment variable corresponding to this option.""" - return f"{self.env_prefix}{self.name}".upper() + return f"{self.env_var_prefix}{self.name}".upper() class ConfigManager: @@ -277,18 +293,19 @@ def __init__(self) -> None: # instance. Though discouraged in general (values bind to ContextVar identity # and Context objects hold strong references to ContextVars, so they won't be # GC'd even if the instance goes out of scope), in this case we really want - # per-registry isolation and we assume only very few ConfigRegistry instances + # per-registry isolation and we assume only very few ConfigManager instances # will be ever created. self._local_context_cvar = contextvars.ContextVar[Mapping[str, Any]]( f"{self.__class__.__name__}_cvar", default=types.MappingProxyType({}) ) + # Option values initialization with environment variable parsing and validation self._global_context: dict[str, Any] = {} for name, desc in self._descriptors.items(): assert desc.default_factory is not None # Guaranteed by __post_init__ init_value = parse_env_var( desc.env_var_name, - desc.parser or _parse_str[desc.option_type], + desc.env_var_parser or _parse_str[desc.option_type], default=desc.default_factory(self), ) if validator := self._validators.get(name): @@ -305,9 +322,6 @@ def get(self, name: str) -> Any: Returns: The effective value of the option. - - Raises: - AttributeError: If the option name is not recognized. """ if name not in self._keys: raise AttributeError(f"Unrecognized config option: {name}") @@ -323,11 +337,6 @@ def set(self, name: str, value: Any) -> None: Args: name: The name of the configuration option. value: The new value for the option. - - Raises: - AttributeError: If the option name is not recognized, or if the option - is currently overridden in a context manager. - Validation error: If the value fails validation. """ if name not in self._keys: raise AttributeError(f"Unrecognized config option: {name}") @@ -408,6 +417,15 @@ def _option_descriptors_(self) -> types.MappingProxyType[str, OptionDescriptor]: return types.MappingProxyType(self._descriptors) +def _parse_dump_metrics_filename(value: str) -> bool | pathlib.Path: + try: + return _parse_str[bool](value) + except Exception: + # If parsing as a bool fails, try parsing as a path. + # This allows users to specify a file path or a boolean value for this option. + return _parse_str[pathlib.Path](value) + + class Config(ConfigManager): """ GT4Py configuration manager. @@ -424,9 +442,7 @@ class Config(ConfigManager): ## -- Debug options -- #: Master debug flag. It changes defaults for all the other options to be as helpful #: for debugging as possible. Environment variable: GT4PY_DEBUG - debug = OptionDescriptor( - option_type=bool, default=False, validator=utils.isinstancechecker(bool) - ) + debug = OptionDescriptor(option_type=bool, default=False) #: Verbose flag for DSL compilation errors. Defaults to the value of debug. #: Environment variable: GT4PY_VERBOSE_EXCEPTIONS @@ -445,6 +461,13 @@ class Config(ConfigManager): #: FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. add_gpu_trace_markers = OptionDescriptor(option_type=bool, default=False) + #: File path to dump collected metrics at exit, if GT4PY_COLLECT_METRICS_LEVEL is enabled. + #: If set to a True value, it defaults to "gt4py_metrics_YYYYMMDD_HHMMSS.json" in + #: the current folder. + dump_metrics_at_exit = OptionDescriptor( + option_type=bool | pathlib.Path, default=False, env_var_parser=_parse_dump_metrics_filename + ) + ## -- Build options -- class BuildCacheLifetime(enum.Enum): """Build cache lifetime modes.""" diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 25492c13dd..9599019672 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -441,17 +441,34 @@ def dump_json( pathlib.Path(filename).write_text(dumps_json(metric_sources)) +def _init_dump_metrics_filename() -> pathlib.Path: + return pathlib.Path(f"gt4py_metrics_{time.strftime('%Y%m%d_%H%M%S')}.json") + + # Handler registration to automatically dump metrics at program exit if # the corresponding configuration flag is set. def _dump_metrics_at_exit() -> None: """Dump collected metrics to a file at program exit if required.""" # It is assumed that 'gt4py.next.config' is still alive at this point - if config.DUMP_METRICS_AT_EXIT and (is_any_level_enabled() or sources): + match config.dump_metrics_at_exit: + case False: + metrics_dump_file = None + case True: + metrics_dump_file = _init_dump_metrics_filename() + case pathlib.Path() as user_path: + metrics_dump_file = user_path + case _: + assert False, ( + f"Invalid type for 'dump_metrics_at_exit' config option: {config.dump_metrics_at_exit}" + f"({type(config.dump_metrics_at_exit)})" + ) + + if metrics_dump_file is not None and (is_any_level_enabled() or sources): try: - pathlib.Path(config.DUMP_METRICS_AT_EXIT).write_text(dumps_json()) + metrics_dump_file.write_text(dumps_json()) print( - f"--- atexit: GT4Py performance metrics saved at {config.DUMP_METRICS_AT_EXIT} ---", + f"--- atexit: GT4Py performance metrics saved at {metrics_dump_file} ---", file=sys.stderr, ) except Exception as e: diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py index a7d94e22ca..c249cf3590 100644 --- a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py @@ -15,7 +15,7 @@ import numpy as np import pytest -from gt4py.next import config +from gt4py.next import config as gt_config from gt4py.next.instrumentation import metrics from gt4py.next.otf import arguments @@ -59,7 +59,7 @@ def test_set_current_source_key_different_key_raises(self): class TestSourceKeyContextManager: def test_context_manager_sets_and_resets_key(self): - with config.overrides(collect_metrics_level=metrics.MINIMAL): + with gt_config.overrides(collect_metrics_level=metrics.MINIMAL): metrics._source_key_cvar.set( metrics._NO_KEY_SET_MARKER_ ) # Reset context variable before test @@ -79,7 +79,7 @@ def test_context_manager_sets_and_resets_key(self): ) def test_context_manager_with_no_key(self): - with config.overrides(collect_metrics_level=metrics.MINIMAL): + with gt_config.overrides(collect_metrics_level=metrics.MINIMAL): metrics._source_key_cvar.set("__BEFORE__MARKER__") # Reset context variable before test with metrics.SourceKeyContextManager(): @@ -93,7 +93,7 @@ def test_context_manager_with_no_key(self): assert metrics._source_key_cvar.get(metrics._NO_KEY_SET_MARKER_) == "__BEFORE__MARKER__" def test_context_manager_nested(self): - with config.overrides(collect_metrics_level=metrics.MINIMAL): + with gt_config.overrides(collect_metrics_level=metrics.MINIMAL): metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) key1 = "outer_key" key2 = "inner_key" @@ -122,7 +122,7 @@ class TestCollector( ): ... metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) - with config.overrides(collect_metrics_level=metrics.MINIMAL): + with gt_config.overrides(collect_metrics_level=metrics.MINIMAL): outer_key = "outer_key" metrics.set_current_source_key("outer_key") assert metrics.get_current_source_key() == outer_key @@ -141,7 +141,7 @@ class TestCollector( key = "test_disabled" metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) - with config.overrides(collect_metrics_level=metrics.DISABLED): + with gt_config.overrides(collect_metrics_level=metrics.DISABLED): metrics.set_current_source_key(key) with TestCollector(key=key): @@ -162,7 +162,7 @@ class CustomCollector( key = "test_custom" metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) - with config.overrides(collect_metrics_level=metrics.PERFORMANCE): + with gt_config.overrides(collect_metrics_level=metrics.PERFORMANCE): with CustomCollector(key=key): pass @@ -286,21 +286,17 @@ def sample_source_metrics(sample_source_metadata: dict[str, Any]) -> Mapping[str return { "program1": metrics.Source( metadata={"description": "Test program 1", **sample_source_metadata}, - metrics=metrics.MetricsCollection( - **{ - metrics.COMPUTE_METRIC: metrics.Metric(samples=[1.0, 2.0, 3.0]), - metrics.TOTAL_METRIC: metrics.Metric(samples=[4.0, 5.0, 6.0]), - } - ), + metrics=metrics.MetricsCollection(**{ + metrics.COMPUTE_METRIC: metrics.Metric(samples=[1.0, 2.0, 3.0]), + metrics.TOTAL_METRIC: metrics.Metric(samples=[4.0, 5.0, 6.0]), + }), ), "program2": metrics.Source( metadata={"description": "Test program 2", **sample_source_metadata}, - metrics=metrics.MetricsCollection( - **{ - metrics.COMPUTE_METRIC: metrics.Metric(samples=[10.0, 20.0, 30.0]), - metrics.TOTAL_METRIC: metrics.Metric(samples=[40.0, 50.0, 60.0]), - } - ), + metrics=metrics.MetricsCollection(**{ + metrics.COMPUTE_METRIC: metrics.Metric(samples=[10.0, 20.0, 30.0]), + metrics.TOTAL_METRIC: metrics.Metric(samples=[40.0, 50.0, 60.0]), + }), ), } @@ -430,7 +426,7 @@ def test_dump_json(sample_source_metrics: Mapping[str, metrics.Source], tmp_path class TestDumpMetricsAtExit: - @pytest.mark.parametrize("mode", ["explicit", "auto", None]) + @pytest.mark.parametrize("mode", ["explicit", "auto", False]) def test_dump_metrics_at_exit_enabled( self, sample_source_metrics: Mapping[str, metrics.Source], @@ -438,29 +434,29 @@ def test_dump_metrics_at_exit_enabled( mode: str | None, ): """Test _dump_metrics_at_exit writes to a file when enabled.""" - explicit_output_filename = str(tmp_path / "explicit_metrics.json") - auto_output_filename = str(tmp_path / gt_config._init_dump_metrics_filename()) + explicit_output_filename = tmp_path / "explicit_metrics.json" + auto_output_filename = tmp_path / metrics._init_dump_metrics_filename() if mode == "explicit": output_filename = explicit_output_filename elif mode == "auto": output_filename = auto_output_filename else: - output_filename = None + output_filename = False - with unittest.mock.patch("gt4py.next.config.DUMP_METRICS_AT_EXIT", output_filename): + with gt_config.overrides(dump_metrics_at_exit=output_filename): with unittest.mock.patch( "gt4py.next.instrumentation.metrics.sources", sample_source_metrics ): metrics._dump_metrics_at_exit() - assert (output_filename is None) == (mode is None) + assert (output_filename is False) == (mode is False) if output_filename: - assert pathlib.Path(output_filename).exists() - data = json.loads(pathlib.Path(output_filename).read_text()) + assert output_filename.exists() + data = json.loads(output_filename.read_text()) assert "program1" in data assert "program2" in data - pathlib.Path(output_filename).unlink() # Clean up after test + output_filename.unlink() # Clean up after test else: - assert not pathlib.Path(explicit_output_filename).exists() - assert not pathlib.Path(auto_output_filename).exists() + assert not explicit_output_filename.exists() + assert not auto_output_filename.exists() diff --git a/tests/next_tests/unit_tests/test_config.py b/tests/next_tests/unit_tests/test_config.py index 3b64eabb4b..d6ed668380 100644 --- a/tests/next_tests/unit_tests/test_config.py +++ b/tests/next_tests/unit_tests/test_config.py @@ -420,7 +420,7 @@ def parse_list(s: str) -> list[str]: with mock.patch.dict(os.environ, {"GT4PY_ITEMS": "a,b,c"}): class TestConfig(ConfigManager): - items = OptionDescriptor(option_type=list, default=[], parser=parse_list) + items = OptionDescriptor(option_type=list, default=[], env_var_parser=parse_list) cfg = TestConfig() assert cfg.items == ["a", "b", "c"] From 99e956a49a1755479430271c7c91588dea3207ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Thu, 26 Feb 2026 12:24:08 +0100 Subject: [PATCH 08/12] More test fixes --- src/gt4py/next/_config.py | 10 +- .../instrumentation_tests/test_metrics.py | 20 +- .../{test_languages.py => test_code_specs.py} | 0 tests/next_tests/unit_tests/test_config.py | 489 +++++++----------- 4 files changed, 216 insertions(+), 303 deletions(-) rename tests/next_tests/unit_tests/otf_tests/{test_languages.py => test_code_specs.py} (100%) diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py index 5aac3f65c4..a23be8fb61 100644 --- a/src/gt4py/next/_config.py +++ b/src/gt4py/next/_config.py @@ -82,12 +82,10 @@ def parse_env_var( @utils.TypeMapping def _parse_str(type_: type) -> Callable[[str], Any]: """Default parser: the type string value as is.""" - match type_: - case enum.Enum() as enum_type: - assert issubclass(enum_type, enum.Enum) - return lambda value: enum_type[value] # parse enum values from their names - case _: - return lambda x: 1 # type constructor as parser + if issubclass(type_, enum.Enum): + return lambda value: type_[value] # parse enum values from their names + + return lambda x: type_(x) # type constructor as parser @_parse_str.register(bool) diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py index c249cf3590..0a571e1f24 100644 --- a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py @@ -286,17 +286,21 @@ def sample_source_metrics(sample_source_metadata: dict[str, Any]) -> Mapping[str return { "program1": metrics.Source( metadata={"description": "Test program 1", **sample_source_metadata}, - metrics=metrics.MetricsCollection(**{ - metrics.COMPUTE_METRIC: metrics.Metric(samples=[1.0, 2.0, 3.0]), - metrics.TOTAL_METRIC: metrics.Metric(samples=[4.0, 5.0, 6.0]), - }), + metrics=metrics.MetricsCollection( + **{ + metrics.COMPUTE_METRIC: metrics.Metric(samples=[1.0, 2.0, 3.0]), + metrics.TOTAL_METRIC: metrics.Metric(samples=[4.0, 5.0, 6.0]), + } + ), ), "program2": metrics.Source( metadata={"description": "Test program 2", **sample_source_metadata}, - metrics=metrics.MetricsCollection(**{ - metrics.COMPUTE_METRIC: metrics.Metric(samples=[10.0, 20.0, 30.0]), - metrics.TOTAL_METRIC: metrics.Metric(samples=[40.0, 50.0, 60.0]), - }), + metrics=metrics.MetricsCollection( + **{ + metrics.COMPUTE_METRIC: metrics.Metric(samples=[10.0, 20.0, 30.0]), + metrics.TOTAL_METRIC: metrics.Metric(samples=[40.0, 50.0, 60.0]), + } + ), ), } diff --git a/tests/next_tests/unit_tests/otf_tests/test_languages.py b/tests/next_tests/unit_tests/otf_tests/test_code_specs.py similarity index 100% rename from tests/next_tests/unit_tests/otf_tests/test_languages.py rename to tests/next_tests/unit_tests/otf_tests/test_code_specs.py diff --git a/tests/next_tests/unit_tests/test_config.py b/tests/next_tests/unit_tests/test_config.py index d6ed668380..1be354c6c8 100644 --- a/tests/next_tests/unit_tests/test_config.py +++ b/tests/next_tests/unit_tests/test_config.py @@ -22,6 +22,15 @@ class TestOptionDescriptorBasics: """Test basic OptionDescriptor functionality.""" + def test_descriptor_attribute_access(self) -> None: + """Test attribute-style access to configuration options.""" + + class TestConfig(ConfigManager): + debug = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + assert cfg.debug is False + def test_descriptor_with_default_value(self) -> None: """Test that descriptor stores and returns default values.""" @@ -43,83 +52,111 @@ class TestConfig(ConfigManager): cfg = TestConfig() assert cfg.derived == 20 - def test_descriptor_attribute_access(self) -> None: - """Test attribute-style access to configuration options.""" - - class TestConfig(ConfigManager): - debug = OptionDescriptor(option_type=bool, default=False) - - cfg = TestConfig() - assert cfg.debug is False - def test_descriptor_get_method(self) -> None: - """Test get() method returns correct values.""" - - class TestConfig(ConfigManager): - value = OptionDescriptor(option_type=int, default=42) +class TestStringValueParsing: + """Test environment variable parsing and configuration.""" - cfg = TestConfig() - assert cfg.get("value") == 42 + @pytest.mark.parametrize( + "value,expected", + [ + ("False", False), + ("false", False), + ("0", False), + ("off", False), + ("True", True), + ("true", True), + ("1", True), + ("on", True), + ], + ) + def test_parse_bool(self, value, expected) -> None: + """Test parsing boolean environment variables.""" + with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": value}): - def test_descriptor_rejects_unrecognized_option(self) -> None: - """Test that get() raises AttributeError for unknown options.""" + class TestConfig(ConfigManager): + testing_opt = OptionDescriptor(option_type=bool, default=False) - class TestConfig(ConfigManager): - opt = OptionDescriptor(option_type=bool, default=False) + cfg = TestConfig() + assert cfg.testing_opt is expected - cfg = TestConfig() - with pytest.raises(AttributeError, match="Unrecognized config option"): - cfg.get("nonexistent") + @pytest.mark.parametrize( + "value,expected", + [ + ("42", 42), + ("-5", -5), + ("0", 0), + ], + ) + def test_parse_int(self, value, expected) -> None: + """Test parsing integer environment variables.""" + with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": value}): + class TestConfig(ConfigManager): + testing_opt = OptionDescriptor(option_type=int, default=0) -class TestConfigurationPrecedence: - """Test configuration value precedence rules.""" + cfg = TestConfig() + assert cfg.testing_opt == expected - def test_environment_variable_overrides_default(self) -> None: - """Test that environment variables override descriptor defaults.""" - with mock.patch.dict(os.environ, {"GT4PY_VALUE": "999"}): + @pytest.mark.parametrize( + "value,expected", + [ + ("/tmp/test", pathlib.Path("/tmp/test")), + ("./relative/path", pathlib.Path("./relative/path")), + ("~/user/path", pathlib.Path(os.environ["HOME"]) / "user" / "path"), + ], + ) + def test_parse_path(self, value, expected) -> None: + """Test parsing pathlib.Path environment variables.""" + with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": value}): class TestConfig(ConfigManager): - value = OptionDescriptor(option_type=int, default=100) + testing_opt = OptionDescriptor(option_type=pathlib.Path, default=pathlib.Path("/")) cfg = TestConfig() - assert cfg.value == 999 + assert cfg.testing_opt == expected - def test_context_override_takes_precedence(self) -> None: - """Test that context overrides take precedence over global values.""" + def test_parse_enum(self) -> None: + """Test parsing enum options from environment variables.""" - class TestConfig(ConfigManager): - opt = OptionDescriptor(option_type=int, default=10) + class Mode(enum.Enum): + DEBUG = "debug" + RELEASE = "release" - cfg = TestConfig() - assert cfg.opt == 10 + with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": "DEBUG"}): - with cfg.overrides(opt=20): - assert cfg.opt == 20 + class TestConfig(ConfigManager): + testing_opt = OptionDescriptor(option_type=Mode, default=Mode.RELEASE) - assert cfg.opt == 10 + cfg = TestConfig() + assert cfg.testing_opt == Mode.DEBUG - def test_context_override_precedence_chain(self) -> None: - """Test complete precedence: context > global > environment > default.""" - with mock.patch.dict(os.environ, {"GT4PY_NUM": "50"}): + def test_custom_parser(self) -> None: + """Test custom parser for environment variables.""" + + def parse_list(s: str) -> list[str]: + return s.split(",") + + with mock.patch.dict(os.environ, {"GT4PY_ITEMS": "a,b,c"}): class TestConfig(ConfigManager): - num = OptionDescriptor(option_type=int, default=10) + items = OptionDescriptor(option_type=list, default=[], env_var_parser=parse_list) cfg = TestConfig() - assert cfg.num == 50 # Environment overrides default + assert cfg.items == ["a", "b", "c"] - cfg.set("num", 100) - assert cfg.num == 100 # Global overrides environment + def test_invalid_environment_variable_raises_error(self) -> None: + """Test that invalid environment variables raise RuntimeError.""" + with mock.patch.dict(os.environ, {"GT4PY_NUM": "not_a_number"}): + with pytest.raises(RuntimeError, match="Parsing"): - with cfg.overrides(num=200): - assert cfg.num == 200 # Context overrides global + class TestConfig(ConfigManager): + num = OptionDescriptor(option_type=int, default=0) - assert cfg.num == 100 # Back to global after context + TestConfig() -class TestSetMethod: - """Test ConfigManager.set() method.""" +class TestConfigManagerBasics: + """Test ConfigManager basic functionality.""" def test_set_changes_global_value(self) -> None: """Test that set() changes the global configuration value.""" @@ -131,16 +168,25 @@ class TestConfig(ConfigManager): cfg.set("value", 20) assert cfg.value == 20 - def test_set_persists_across_accesses(self) -> None: - """Test that set values persist across multiple accesses.""" + def test_set_via_attribute_assignment(self) -> None: + """Test that setting via attribute assignment works.""" + + class TestConfig(ConfigManager): + debug = OptionDescriptor(option_type=bool, default=False) + + cfg = TestConfig() + cfg.debug = True + assert cfg.debug is True + + def test_get_rejects_unrecognized_option(self) -> None: + """Test that get() raises AttributeError for unknown options.""" class TestConfig(ConfigManager): - opt = OptionDescriptor(option_type=str, default="old") + opt = OptionDescriptor(option_type=bool, default=False) cfg = TestConfig() - cfg.set("opt", "new") - assert cfg.opt == "new" - assert cfg.get("opt") == "new" + with pytest.raises(AttributeError, match="Unrecognized config option"): + cfg.get("nonexistent") def test_set_rejects_unrecognized_option(self) -> None: """Test that set() raises AttributeError for unknown options.""" @@ -163,88 +209,59 @@ class TestConfig(ConfigManager): with pytest.raises(AttributeError, match="overridden in a context manager"): cfg.set("opt", 30) - def test_set_via_attribute_assignment(self) -> None: - """Test that setting via attribute assignment works.""" - - class TestConfig(ConfigManager): - debug = OptionDescriptor(option_type=bool, default=False) - - cfg = TestConfig() - cfg.debug = True - assert cfg.debug is True - - -class TestValidation: - """Test configuration option validation.""" - - def test_validator_rejects_invalid_values(self) -> None: - """Test that validators reject invalid values.""" - - def positive_int(val: Any) -> None: - if not isinstance(val, int) or val <= 0: - raise ValueError("Must be positive") + def test_as_dict_returns_all_options(self) -> None: + """Test that as_dict() returns all configuration options.""" class TestConfig(ConfigManager): - count = OptionDescriptor(option_type=int, default=1, validator=positive_int) + opt1 = OptionDescriptor(option_type=int, default=1) + opt2 = OptionDescriptor(option_type=str, default="test") cfg = TestConfig() - with pytest.raises(ValueError, match="Must be positive"): - cfg.set("count", -5) + config_dict = cfg.as_dict() + assert config_dict["opt1"] == 1 + assert config_dict["opt2"] == "test" - def test_type_check_validator(self) -> None: - """Test that 'type_check' validator validates types.""" + def test_as_dict_reflects_context_overrides(self) -> None: + """Test that as_dict() reflects active context overrides.""" class TestConfig(ConfigManager): - name = OptionDescriptor(option_type=str, default="test", validator="type_check") + value = OptionDescriptor(option_type=int, default=10) cfg = TestConfig() - with pytest.raises(TypeError): - cfg.set("name", 123) - - def test_validator_accepts_valid_values(self) -> None: - """Test that validators accept valid values.""" - - def even_int(val: Any) -> None: - if not isinstance(val, int) or val % 2 != 0: - raise ValueError("Must be even") + with cfg.overrides(value=99): + assert cfg.as_dict()["value"] == 99 - class TestConfig(ConfigManager): - num = OptionDescriptor(option_type=int, default=2, validator=even_int) - cfg = TestConfig() - cfg.set("num", 42) - assert cfg.num == 42 +class TestConfigurationPrecedence: + """Test configuration value precedence rules.""" - def test_validator_applied_during_context_override(self) -> None: - """Test that validators are applied during context overrides.""" + def test_environment_variable_overrides_default(self) -> None: + """Test that environment variables override descriptor defaults.""" + with mock.patch.dict(os.environ, {"GT4PY_VALUE": "999"}): - def positive(val: Any) -> None: - if val <= 0: - raise ValueError("Must be positive") + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=100) - class TestConfig(ConfigManager): - value = OptionDescriptor(option_type=int, default=1, validator=positive) + cfg = TestConfig() + assert cfg.value == 999 - cfg = TestConfig() - with pytest.raises(ValueError, match="Must be positive"): - with cfg.overrides(value=-1): - pass + def test_context_override_precedence_chain(self) -> None: + """Test complete precedence: context > global > environment > default.""" + with mock.patch.dict(os.environ, {"GT4PY_NUM": "50"}): + class TestConfig(ConfigManager): + num = OptionDescriptor(option_type=int, default=10) -class TestContextOverrides: - """Test ConfigManager.overrides() context manager.""" + cfg = TestConfig() + assert cfg.num == 50 # Environment overrides default - def test_override_restores_original_value(self) -> None: - """Test that overrides are reverted when exiting context.""" + cfg.set("num", 100) + assert cfg.num == 100 # Global overrides environment - class TestConfig(ConfigManager): - opt = OptionDescriptor(option_type=int, default=10) + with cfg.overrides(num=200): + assert cfg.num == 200 # Context overrides global - cfg = TestConfig() - original = cfg.opt - with cfg.overrides(opt=999): - pass - assert cfg.opt == original + assert cfg.num == 100 # Back to global after context def test_multiple_option_override(self) -> None: """Test overriding multiple options simultaneously.""" @@ -296,6 +313,63 @@ class TestConfig(ConfigManager): assert cfg.value == 10 +class TestValidation: + """Test configuration option validation.""" + + def test_validator_rejects_invalid_values(self) -> None: + """Test that validators reject invalid values.""" + + def positive_int(val: Any) -> None: + if not isinstance(val, int) or val <= 0: + raise ValueError("Must be positive") + + class TestConfig(ConfigManager): + count = OptionDescriptor(option_type=int, default=1, validator=positive_int) + + cfg = TestConfig() + with pytest.raises(ValueError, match="Must be positive"): + cfg.set("count", -5) + + def test_type_check_validator(self) -> None: + """Test that 'type_check' validator validates types.""" + + class TestConfig(ConfigManager): + name = OptionDescriptor(option_type=str, default="test", validator="type_check") + + cfg = TestConfig() + with pytest.raises(TypeError): + cfg.set("name", 123) + + def test_validator_accepts_valid_values(self) -> None: + """Test that validators accept valid values.""" + + def even_int(val: Any) -> None: + if not isinstance(val, int) or val % 2 != 0: + raise ValueError("Must be even") + + class TestConfig(ConfigManager): + num = OptionDescriptor(option_type=int, default=2, validator=even_int) + + cfg = TestConfig() + cfg.set("num", 42) + assert cfg.num == 42 + + def test_validator_applied_during_context_override(self) -> None: + """Test that validators are applied during context overrides.""" + + def positive(val: Any) -> None: + if val <= 0: + raise ValueError("Must be positive") + + class TestConfig(ConfigManager): + value = OptionDescriptor(option_type=int, default=1, validator=positive) + + cfg = TestConfig() + with pytest.raises(ValueError, match="Must be positive"): + with cfg.overrides(value=-1): + pass + + class TestUpdateCallbacks: """Test option update callbacks.""" @@ -349,177 +423,14 @@ class TestConfig(ConfigManager): assert len(callback_calls) == 0 -class TestStringValueParsing: - """Test environment variable parsing and configuration.""" - - @pytest.mark.parametrize( - "value,expected", - [ - ("False", False), - ("false", False), - ("0", False), - ("off", False), - ("True", True), - ("true", True), - ("1", True), - ("on", True), - ], - ) - def test_parse_bool(self, value, expected) -> None: - """Test parsing boolean environment variables.""" - with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": value}): - - class TestConfig(ConfigManager): - testing_opt = OptionDescriptor(option_type=bool, default=False) - - cfg = TestConfig() - assert cfg.testing_opt is expected - - @pytest.mark.parametrize( - "value,expected", - [ - ("42", 42), - ("-5", -5), - ("0", 0), - ], - ) - def test_parse_int(self, value, expected) -> None: - """Test parsing integer environment variables.""" - with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": value}): - - class TestConfig(ConfigManager): - testing_opt = OptionDescriptor(option_type=int, default=0) - - cfg = TestConfig() - assert cfg.testing_opt == expected - - @pytest.mark.parametrize( - "value,expected", - [ - ("/tmp/test", pathlib.Path("/tmp/test")), - ("./relative/path", pathlib.Path("./relative/path")), - ("~/user/path", pathlib.Path(os.environ["HOME"]) / "user" / "path"), - ], - ) - def test_parse_path(self, value, expected) -> None: - """Test parsing pathlib.Path environment variables.""" - with mock.patch.dict(os.environ, {"GT4PY_TESTING_OPT": value}): - - class TestConfig(ConfigManager): - testing_opt = OptionDescriptor(option_type=pathlib.Path, default=pathlib.Path("/")) - - cfg = TestConfig() - assert cfg.testing_opt == expected - - def test_custom_parser(self) -> None: - """Test custom parser for environment variables.""" - - def parse_list(s: str) -> list[str]: - return s.split(",") - - with mock.patch.dict(os.environ, {"GT4PY_ITEMS": "a,b,c"}): - - class TestConfig(ConfigManager): - items = OptionDescriptor(option_type=list, default=[], env_var_parser=parse_list) - - cfg = TestConfig() - assert cfg.items == ["a", "b", "c"] - - def test_invalid_environment_variable_raises_error(self) -> None: - """Test that invalid environment variables raise RuntimeError.""" - with mock.patch.dict(os.environ, {"GT4PY_NUM": "not_a_number"}): - with pytest.raises(RuntimeError, match="Parsing"): - - class TestConfig(ConfigManager): - num = OptionDescriptor(option_type=int, default=0) - - TestConfig() - - -class TestEnumOptions: - """Test configuration options with enum types.""" - - def test_enum_option_with_default(self) -> None: - """Test enum options work with default values.""" - - class Color(enum.Enum): - RED = "red" - BLUE = "blue" - - class TestConfig(ConfigManager): - color = OptionDescriptor(option_type=Color, default=Color.RED) - - cfg = TestConfig() - assert cfg.color == Color.RED - - def test_enum_option_from_environment(self) -> None: - """Test parsing enum options from environment variables.""" - - class Mode(enum.Enum): - DEBUG = "debug" - RELEASE = "release" - - with mock.patch.dict(os.environ, {"GT4PY_MODE": "DEBUG"}): - - class TestConfig(ConfigManager): - mode = OptionDescriptor(option_type=Mode, default=Mode.RELEASE) - - cfg = TestConfig() - assert cfg.mode == Mode.DEBUG - - -class TestAsDict: - """Test ConfigManager.as_dict() method.""" - - def test_as_dict_returns_all_options(self) -> None: - """Test that as_dict() returns all configuration options.""" - - class TestConfig(ConfigManager): - opt1 = OptionDescriptor(option_type=int, default=1) - opt2 = OptionDescriptor(option_type=str, default="test") - - cfg = TestConfig() - config_dict = cfg.as_dict() - assert config_dict["opt1"] == 1 - assert config_dict["opt2"] == "test" - - def test_as_dict_reflects_current_state(self) -> None: - """Test that as_dict() reflects current configuration state.""" - - class TestConfig(ConfigManager): - value = OptionDescriptor(option_type=int, default=10) - - cfg = TestConfig() - cfg.set("value", 20) - assert cfg.as_dict()["value"] == 20 - - def test_as_dict_reflects_context_overrides(self) -> None: - """Test that as_dict() reflects active context overrides.""" - - class TestConfig(ConfigManager): - value = OptionDescriptor(option_type=int, default=10) - - cfg = TestConfig() - with cfg.overrides(value=99): - assert cfg.as_dict()["value"] == 99 - - -class TestRealConfigClass: - """Test the actual Config class.""" +def test_gt4py_config_class() -> None: + """Test the actual Config class for GT4Py.""" - def test_config_singleton_works(self) -> None: - """Test that the Config singleton is accessible.""" - assert isinstance(Config, type) - cfg = Config() - assert "debug" in cfg._option_descriptors_() + assert isinstance(Config, type) + cfg = Config() - def test_debug_option_exists(self) -> None: - """Test that debug option exists and has correct type.""" - cfg = Config() - assert isinstance(cfg.debug, bool) + assert "debug" in cfg._option_descriptors_() + assert isinstance(cfg.debug, bool) - def test_build_cache_dir_property(self) -> None: - """Test that build_cache_dir property works.""" - cfg = Config() - assert isinstance(cfg.build_cache_dir, pathlib.Path) - assert str(cfg.build_cache_dir).endswith(".gt4py_cache") + assert isinstance(cfg.build_cache_dir, pathlib.Path) + assert str(cfg.build_cache_dir).endswith(".gt4py_cache") From 0257bc5d0feacc75925bbc65286af056a307044b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Thu, 26 Feb 2026 14:09:14 +0100 Subject: [PATCH 09/12] More typing fixes --- src/gt4py/eve/utils.py | 30 +++++----- src/gt4py/next/__init__.py | 5 +- src/gt4py/next/_config.py | 59 +++++++++++-------- src/gt4py/next/embedded/nd_array_field.py | 3 +- src/gt4py/next/instrumentation/metrics.py | 9 +-- .../otf/compilation/build_systems/cmake.py | 12 ++-- .../compilation/build_systems/compiledb.py | 19 +++--- src/gt4py/next/otf/compilation/cache.py | 7 ++- src/gt4py/next/otf/compilation/compiler.py | 11 ++-- .../transformations/concat_where_mapper.py | 2 +- .../runners/dace/workflow/common.py | 9 ++- .../runners/dace/workflow/compilation.py | 10 +++- .../runners/dace/workflow/factory.py | 7 ++- .../next/program_processors/runners/gtfn.py | 8 ++- 14 files changed, 116 insertions(+), 75 deletions(-) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index feef130982..10454b1b44 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -103,20 +103,22 @@ def isinstancechecker( False """ - all_types: Tuple[Type, ...] = tuple() + accepted_types: Tuple[Type, ...] = tuple() if isinstance(type_info, type): - all_types = (type_info,) + accepted_types = (type_info,) elif isinstance(type_info, types.UnionType): - all_types = type_info.__args__ + accepted_types = type_info.__args__ elif not isinstance(type_info, tuple) and is_collection(type_info): - all_types = tuple(type_info) + accepted_types = tuple(type_info) else: - all_types = type_info # type:ignore # it is checked at run-time + accepted_types = type_info # type:ignore # it is checked at run-time - if not isinstance(all_types, tuple) or not all(isinstance(t, type) for t in all_types): - raise ValueError(f"Invalid type(s) definition: '{all_types}'.") + if not isinstance(accepted_types, tuple) or not all( + isinstance(t, type) for t in accepted_types + ): + raise ValueError(f"Invalid type(s) definition: '{accepted_types}'.") - return lambda obj: isinstance(obj, all_types) + return lambda obj: isinstance(obj, accepted_types) def attrchecker(*names: str) -> Callable[[Any], bool]: @@ -535,11 +537,11 @@ class TypeMapping(collections.abc.Mapping[type, _T]): """ A mapping from types to values supporting complex type-based dispatching. - The mapping supports registering values for specific types, and retrieving - values based on the type of the key, including support for inheritance - exactly in the same way as `functools.singledispatch()` works. For example, - if a value is registered for a base class, it will be returned for - instances of derived classes unless a more specific type is registered. + The mapping supports registering values for specific types, and + retrieving values based on the type key, supporting subtyping + relationship exactly in the same way as `functools.singledispatch()` works. + For example, if a value is registered for a base class, it will be returned + for instances of derived classes unless a more specific type is registered. Examples: >>> mapping = TypeMapping(lambda type_: f"Default for {type_}") @@ -605,7 +607,7 @@ def _decorator(value: _T) -> _T: return _decorator def clear_cache(self) -> None: - """Clear the singledispatch cache.""" + """Clear the type dispatching cache.""" self._dispatcher._clear_cache() diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 9fb01c1386..0e9e5afec0 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -22,8 +22,8 @@ from __future__ import annotations -# reexport the actual configuration manager instance as a publice attribute -from ._config import config # ruff: isort: skip +# reexport the actual configuration manager instance as a public attribute +from ._config import Config as config_type, config # ruff: isort: skip from .._core.definitions import CUPY_DEVICE_TYPE, Device, DeviceType, is_scalar_type from . import common, ffront, iterator, program_processors, typing @@ -59,6 +59,7 @@ # submodules "common", "config", + "config_type", "ffront", "iterator", "program_processors", diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py index a23be8fb61..4f7fd51212 100644 --- a/src/gt4py/next/_config.py +++ b/src/gt4py/next/_config.py @@ -9,18 +9,18 @@ """ GT4Py configuration system. -This module defines a typed configuration framework based on descriptors: +This module defines a typed configuration framework based on these concepts: -- `OptionDescriptor`: declares one option (type, default/default_factory, parser, - validator, environment variable mapping, and optional update callback). +- `OptionDescriptor`: full description of an option (type, default/default_factory, + parser, validator, environment variable mapping, and optional update callback). - `ConfigManager`: stores option values, resolves effective values using precedence, and supports task-local temporary overrides. - `Config`: concrete registry of GT4Py public options. -Configuration can be changed globally via attribute assignment or `set()`, and -temporarily via `overrides()`. +Configuration can be changed globally in a ConfigManaget instance via attribute +assignment or `set()`, and temporarily via `overrides()`. -The public singleton instance is exposed as `gt4py.next.config`. +The global GT4Py ConfigManager instance is exposed as `gt4py.next.config`. """ from __future__ import annotations @@ -29,7 +29,6 @@ import contextvars import dataclasses import enum -import functools import os import pathlib import sys @@ -85,7 +84,7 @@ def _parse_str(type_: type) -> Callable[[str], Any]: if issubclass(type_, enum.Enum): return lambda value: type_[value] # parse enum values from their names - return lambda x: type_(x) # type constructor as parser + return lambda x: type_(x) # type: ignore[call-arg] # use type constructor as parser @_parse_str.register(bool) @@ -107,7 +106,6 @@ def _parse_str_as_path(value: str) -> pathlib.Path: return pathlib.Path(expanded) -@functools.cache def _type_check_validator(type_: type) -> Callable[[Any], None]: """Generate a validator function that checks if a value is an instance of the given type.""" @@ -249,7 +247,8 @@ def env_var_name(self) -> str: class ConfigManager: - """Central configuration manager with attribute-style access. + """ + Central configuration manager with attribute-style access. Config options are defined as class attributes using `OptionDescriptor`. The manager stores global values for all options and allows temporary @@ -281,7 +280,7 @@ def __init__(self) -> None: for name, desc in self._descriptors.items() if callable(desc.validator) } - self._hooks: dict[str, OptionUpdateCallback[Any]] = { + self._callbacks: dict[str, OptionUpdateCallback[Any]] = { name: desc.update_callback for name, desc in self._descriptors.items() if desc.update_callback is not None @@ -311,7 +310,8 @@ def __init__(self) -> None: self._global_context[name] = init_value def get(self, name: str) -> Any: - """Get the effective value of a configuration option. + """ + Get the effective value of a configuration option. Applies precedence rules: context override > global value > environment > default. @@ -328,7 +328,8 @@ def get(self, name: str) -> Any: return self._global_context[name] def set(self, name: str, value: Any) -> None: - """Set the global value of a configuration option. + """ + Set the global value of a configuration option. Validates the value and invokes any registered callbacks. @@ -346,12 +347,13 @@ def set(self, name: str, value: Any) -> None: validator(value) old_val = self._global_context[name] self._global_context[name] = value - if hook := self._hooks.get(name): - hook(value, old_val, UpdateScope.GLOBAL) + if callback := self._callbacks.get(name): + callback(value, old_val, UpdateScope.GLOBAL) @contextlib.contextmanager def overrides(self, **overrides: Any) -> Generator[None, None, None]: - """Context manager for temporary configuration overrides. + """ + Context manager for temporary configuration overrides. Overrides are task-local (isolated per thread/async task) and automatically reverted when exiting the context manager. Nested contexts are supported. @@ -386,19 +388,22 @@ def overrides(self, **overrides: Any) -> Generator[None, None, None]: token = self._local_context_cvar.set(new_context) try: - for name in changes.keys() & self._hooks.keys(): - self._hooks[name](new_context[name], old_values[name], UpdateScope.CONTEXT) + for name in changes.keys() & self._callbacks.keys(): + self._callbacks[name](new_context[name], old_values[name], UpdateScope.CONTEXT) yield finally: self._local_context_cvar.reset(token) - for name in changes.keys() & old_context.keys() & self._hooks.keys(): - self._hooks[name](old_context.get(name), new_context.get(name), UpdateScope.CONTEXT) + for name in changes.keys() & old_context.keys() & self._callbacks.keys(): + self._callbacks[name]( + old_context.get(name), new_context.get(name), UpdateScope.CONTEXT + ) def as_dict(self) -> dict[str, Any]: - """Get the current effective configuration options as a dictionary. + """ + Get the current effective configuration options as a dictionary. Returns all configuration options with their effective values, preserving the order they were defined in the class. @@ -407,7 +412,8 @@ def as_dict(self) -> dict[str, Any]: return {name: self.get(name) for name in self._descriptors.keys()} def _option_descriptors_(self) -> types.MappingProxyType[str, OptionDescriptor]: - """Get the option descriptors. + """ + Get the option descriptors. Returns a read-only mapping of option names to their descriptors. This is useful for introspection and documentation purposes. @@ -463,7 +469,9 @@ class Config(ConfigManager): #: If set to a True value, it defaults to "gt4py_metrics_YYYYMMDD_HHMMSS.json" in #: the current folder. dump_metrics_at_exit = OptionDescriptor( - option_type=bool | pathlib.Path, default=False, env_var_parser=_parse_dump_metrics_filename + option_type=bool | pathlib.Path, + default=False, + env_var_parser=_parse_dump_metrics_filename, # type: ignore[arg-type] # mypy gets confused with the overloaded return type of the parser ) ## -- Build options -- @@ -556,4 +564,7 @@ class CMakeBuildType(enum.Enum): #: Use this to access and modify configuration options: config.debug, config.set(...), etc. config = Config() -print(config.as_dict()) +if config.debug: + print("GT4Py configuration:") + for name, value in sorted(config.as_dict().items()): + print(f" - {name}: {value}") diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index fa99f5fabd..553332e150 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -994,7 +994,8 @@ def _concat_where( return cls_.from_array(result_array, domain=result_domain) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR +# TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) def _make_reduction( diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 9599019672..983e8cf3ae 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -29,7 +29,7 @@ import numpy as np from gt4py.eve import extended_typing as xtyping, utils -from gt4py.eve.extended_typing import Any, Final +from gt4py.eve.extended_typing import Any, Final, assert_never from gt4py.next import config from gt4py.next.otf import arguments @@ -458,11 +458,8 @@ def _dump_metrics_at_exit() -> None: metrics_dump_file = _init_dump_metrics_filename() case pathlib.Path() as user_path: metrics_dump_file = user_path - case _: - assert False, ( - f"Invalid type for 'dump_metrics_at_exit' config option: {config.dump_metrics_at_exit}" - f"({type(config.dump_metrics_at_exit)})" - ) + case _ as unreachable: + assert_never(unreachable) if metrics_dump_file is not None and (is_any_level_enabled() or sources): try: diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py index 1b79cad6e4..374f2b88e8 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py @@ -13,7 +13,7 @@ import pathlib import subprocess import warnings -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar from gt4py._core import definitions as core_defs from gt4py.next import config, errors @@ -22,6 +22,10 @@ from gt4py.next.otf.compilation.build_systems import cmake_lists +if TYPE_CHECKING: + from gt4py.next import config_type + + def get_device_arch() -> str | None: if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.CUDA: # use `cp` from core_defs to avoid trying to re-import cupy @@ -69,13 +73,13 @@ class CMakeFactory( """Create a CMakeProject from a ``CompilableSource`` stage object with given CMake settings.""" cmake_generator_name: str = "Ninja" - cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG + cmake_build_type: config_type.CMakeBuildType = config.CMakeBuildType.DEBUG cmake_extra_flags: list[str] = dataclasses.field(default_factory=list) def __call__( self, source: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - cache_lifetime: config.BuildCacheLifetime, + cache_lifetime: config_type.BuildCacheLifetime, ) -> CMakeProject: if not source.binding_source: raise NotImplementedError( @@ -128,7 +132,7 @@ class CMakeProject(stages.BuildSystemProject[CPPLikeCodeSpecT, code_specs.Python source_files: dict[str, str] program_name: str generator_name: str = "Ninja" - build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG + build_type: config_type.CMakeBuildType = config.CMakeBuildType.DEBUG extra_cmake_flags: list[str] = dataclasses.field(default_factory=list) def build(self) -> None: diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 347b0e25e9..4c13307c4a 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -14,7 +14,7 @@ import re import shutil import subprocess -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar from gt4py._core import locking from gt4py.next import config, errors @@ -24,6 +24,9 @@ from gt4py.next.otf.compilation.build_systems import cmake +if TYPE_CHECKING: + from gt4py.next import config_type + CPPLikeCodeSpecT = TypeVar("CPPLikeCodeSpecT", bound=code_specs.CPPLikeCodeSpec) @@ -39,14 +42,14 @@ class CompiledbFactory( and library dependencies. """ - cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG + cmake_build_type: config_type.CMakeBuildType = config.CMakeBuildType.DEBUG cmake_extra_flags: list[str] = dataclasses.field(default_factory=list) renew_compiledb: bool = False def __call__( self, source: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - cache_lifetime: config.BuildCacheLifetime, + cache_lifetime: config_type.BuildCacheLifetime, ) -> CompiledbProject: if not source.binding_source: raise NotImplementedError( @@ -244,7 +247,7 @@ def _cc_prototype_program_name( def _cc_prototype_program_source( deps: tuple[interface.LibraryDependency, ...], - build_type: config.CMakeBuildType, + build_type: config_type.CMakeBuildType, cmake_flags: list[str], code_spec: code_specs.CPPLikeCodeSpec, ) -> stages.ProgramSource: @@ -260,9 +263,9 @@ def _cc_prototype_program_source( def _cc_get_compiledb( renew_compiledb: bool, prototype_program_source: stages.ProgramSource, - build_type: config.CMakeBuildType, + build_type: config_type.CMakeBuildType, cmake_flags: list[str], - cache_lifetime: config.BuildCacheLifetime, + cache_lifetime: config_type.BuildCacheLifetime, ) -> pathlib.Path: cache_path = cache.get_cache_folder( stages.CompilableProject(prototype_program_source, None), cache_lifetime @@ -293,9 +296,9 @@ def _cc_find_compiledb(path: pathlib.Path) -> Optional[pathlib.Path]: def _cc_create_compiledb( prototype_program_source: stages.ProgramSource, - build_type: config.CMakeBuildType, + build_type: config_type.CMakeBuildType, cmake_flags: list[str], - cache_lifetime: config.BuildCacheLifetime, + cache_lifetime: config_type.BuildCacheLifetime, ) -> pathlib.Path: prototype_project = cmake.CMakeFactory( cmake_generator_name="Ninja", diff --git a/src/gt4py/next/otf/compilation/cache.py b/src/gt4py/next/otf/compilation/cache.py index c5bff5aca5..418aef5def 100644 --- a/src/gt4py/next/otf/compilation/cache.py +++ b/src/gt4py/next/otf/compilation/cache.py @@ -11,12 +11,17 @@ import hashlib import pathlib import tempfile +from typing import TYPE_CHECKING from gt4py.next import config from gt4py.next.otf import stages from gt4py.next.otf.binding import interface +if TYPE_CHECKING: + from gt4py.next import config_type + + _session_cache_dir = tempfile.TemporaryDirectory(prefix="gt4py_session_") _session_cache_dir_path = pathlib.Path(_session_cache_dir.name) @@ -50,7 +55,7 @@ def _cache_folder_name(source: stages.ProgramSource) -> str: def get_cache_folder( - compilable_source: stages.CompilableProject, lifetime: config.BuildCacheLifetime + compilable_source: stages.CompilableProject, lifetime: config_type.BuildCacheLifetime ) -> pathlib.Path: """ Construct the path to where the build system project artifact of a compilable source should be cached. diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 3748d95192..5877f8d06f 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -10,16 +10,19 @@ import dataclasses import pathlib -from typing import Protocol, TypeVar +from typing import TYPE_CHECKING, Protocol, TypeVar import factory from gt4py._core import locking -from gt4py.next import config from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import build_data, cache, importer +if TYPE_CHECKING: + from gt4py.next import config_type + + T = TypeVar("T") @@ -40,7 +43,7 @@ class BuildSystemProjectGenerator(Protocol[CodeSpecT, TargetCodeSpecT]): def __call__( self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT], - cache_lifetime: config.BuildCacheLifetime, + cache_lifetime: config_type.BuildCacheLifetime, ) -> stages.BuildSystemProject[CodeSpecT, TargetCodeSpecT]: ... @@ -58,7 +61,7 @@ class Compiler( ): """Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" - cache_lifetime: config.BuildCacheLifetime + cache_lifetime: config_type.BuildCacheLifetime builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] force_recompile: bool = False diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py b/src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py index 8052426f33..c6266d806e 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py @@ -529,7 +529,7 @@ def _process_descending_points_of_state( # nested SDFG and also delete its alias inside it. _cleanup_memlet_path(state, descending_point) descending_point.consumer.remove_in_connector(descending_point.edge.dst_conn) - nsdfg.sdfg.remove_data(descending_point.edge.dst_conn, validate=gtx_config.DEBUG) + nsdfg.sdfg.remove_data(descending_point.edge.dst_conn, validate=gtx_config.debug) return nb_applies diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/common.py b/src/gt4py/next/program_processors/runners/dace/workflow/common.py index 43887dac63..1efc46507f 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/common.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/common.py @@ -8,12 +8,15 @@ import contextlib import os -from typing import Any, Final, Generator, Optional +from typing import TYPE_CHECKING, Any, Final, Generator, Optional import dace from gt4py._core import definitions as core_defs -from gt4py.next import config + + +if TYPE_CHECKING: + from gt4py.next import config_type SDFG_ARG_METRIC_LEVEL: Final[str] = "gt_metrics_level" @@ -26,7 +29,7 @@ def set_dace_config( device_type: core_defs.DeviceType, - cmake_build_type: Optional[config.CMakeBuildType] = None, + cmake_build_type: Optional[config_type.CMakeBuildType] = None, ) -> None: """Set the DaCe configuration as required by GT4Py. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 8d802fe4b7..c9b015044f 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -12,7 +12,7 @@ import os import warnings from collections.abc import Callable, MutableSequence, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any import dace import factory @@ -24,6 +24,10 @@ from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon +if TYPE_CHECKING: + from gt4py.next import config_type + + class CompiledDaceProgram: sdfg_program: dace.CompiledSDFG @@ -129,9 +133,9 @@ class DaCeCompiler( """Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" bind_func_name: str - cache_lifetime: config.BuildCacheLifetime + cache_lifetime: config_type.BuildCacheLifetime device_type: core_defs.DeviceType - cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG + cmake_build_type: config_type.CMakeBuildType = config.CMakeBuildType.DEBUG def __call__( self, diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index ef34918370..8760a95878 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -9,7 +9,7 @@ from __future__ import annotations import functools -from typing import Final +from typing import TYPE_CHECKING, Final import factory @@ -28,6 +28,9 @@ ) +if TYPE_CHECKING: + from gt4py.next import config_type + _GT_DACE_BINDING_FUNCTION_NAME: Final[str] = "update_sdfg_args" @@ -38,7 +41,7 @@ class Meta: class Params: auto_optimize: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough + cmake_build_type: config_type.CMakeBuildType = factory.LazyFunction( lambda: config.cmake_build_type ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 072f96e749..37051c2f39 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import functools -from typing import Any +from typing import TYPE_CHECKING, Any import factory import numpy as np @@ -25,6 +25,10 @@ from gt4py.next.program_processors.codegens.gtfn import gtfn_module +if TYPE_CHECKING: + from gt4py.next import config_type + + def convert_arg(arg: Any) -> Any: # Note: this function is on the hot path and needs to have minimal overhead. if (origin := getattr(arg, "__gt_origin__", None)) is not None: @@ -112,7 +116,7 @@ class Meta: class Params: device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough + cmake_build_type: config_type.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough lambda: config.cmake_build_type ) builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( # type: ignore[assignment] # factory-boy typing not precise enough From 3ba2116caed04dfa1a01c80b1bb3fdbd1ed674f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Thu, 26 Feb 2026 14:51:41 +0100 Subject: [PATCH 10/12] Fix all typing and linting issues --- src/gt4py/next/_config.py | 8 ++++---- src/gt4py/next/embedded/nd_array_field.py | 2 +- src/gt4py/next/ffront/decorator.py | 4 ++-- src/gt4py/next/iterator/embedded.py | 8 ++++---- src/gt4py/next/otf/compilation/cache.py | 2 ++ .../program_processors/runners/dace/workflow/common.py | 2 ++ .../program_processors/runners/dace/workflow/factory.py | 2 +- src/gt4py/next/program_processors/runners/gtfn.py | 4 +++- 8 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py index 4f7fd51212..3cb957f27c 100644 --- a/src/gt4py/next/_config.py +++ b/src/gt4py/next/_config.py @@ -9,15 +9,15 @@ """ GT4Py configuration system. -This module defines a typed configuration framework based on these concepts: +This module defines a configuration system based on these concepts: - `OptionDescriptor`: full description of an option (type, default/default_factory, parser, validator, environment variable mapping, and optional update callback). -- `ConfigManager`: stores option values, resolves effective values using precedence, - and supports task-local temporary overrides. +- `ConfigManager`: stores option values, supports task-local temporary overrides, + and resolves effective values using precedence. - `Config`: concrete registry of GT4Py public options. -Configuration can be changed globally in a ConfigManaget instance via attribute +Configuration can be changed globally in a ConfigManager instance via attribute assignment or `set()`, and temporarily via `overrides()`. The global GT4Py ConfigManager instance is exposed as `gt4py.next.config`. diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 553332e150..2a22c85283 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -995,7 +995,7 @@ def _concat_where( # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[has-type] # mypy bug? mypy cannot see experimental.concat_where type here def _make_reduction( diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index a7b703564a..d5e9ce6a79 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -20,7 +20,7 @@ import typing import warnings from collections.abc import Callable -from typing import Any, Generic, Optional, Sequence, TypeAlias +from typing import Any, Final, Generic, Optional, Sequence, TypeAlias from gt4py import eve from gt4py._core import definitions as core_defs @@ -54,7 +54,7 @@ DEFAULT_BACKEND: next_backend.Backend | None = None -ProgramCallMetricsCollector = metrics.make_collector( +ProgramCallMetricsCollector: Final[type[metrics.BaseMetricsCollector]] = metrics.make_collector( # type: ignore[has-type] # mypy bug? mypy cannot see metrics.make_collector type here level=metrics.MINIMAL, metric_name=metrics.TOTAL_METRIC ) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b2f742b4f5..72748fb533 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1630,7 +1630,7 @@ def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProvider ) -@runtime.set_at.register(EMBEDDED) +@runtime.set_at.register(EMBEDDED) # type: ignore[has-type] # mypy bug? mypy cannot see runtime.set_at type here def set_at( expr: common.Field, domain_like: xtyping.MaybeNestedInTuple[common.DomainLike], @@ -1640,12 +1640,12 @@ def set_at( operators._tuple_assign_field(target, expr, domain) -@runtime.get_domain_range.register(EMBEDDED) +@runtime.get_domain_range.register(EMBEDDED) # type: ignore[has-type] # mypy bug? mypy cannot see runtime.get_domain_range type here def get_domain_range(field: common.Field, dim: common.Dimension) -> tuple[int, int]: return (field.domain[dim].unit_range.start, field.domain[dim].unit_range.stop) -@runtime.if_stmt.register(EMBEDDED) +@runtime.if_stmt.register(EMBEDDED) # type: ignore[has-type] # mypy bug? mypy cannot see runtime.if_stmt type here def if_stmt(cond: bool, true_branch: Callable[[], None], false_branch: Callable[[], None]) -> None: """ (Stateful) if statement. @@ -1665,7 +1665,7 @@ def if_stmt(cond: bool, true_branch: Callable[[], None], false_branch: Callable[ false_branch() -@runtime.temporary.register(EMBEDDED) +@runtime.temporary.register(EMBEDDED) # type: ignore[has-type] # mypy bug? mypy cannot see runtime.temporary type here def temporary(domain: runtime.CartesianDomain | runtime.UnstructuredDomain, dtype): type_ = runtime._dtypebuiltin_to_ts(dtype) new_domain = common.domain(domain) diff --git a/src/gt4py/next/otf/compilation/cache.py b/src/gt4py/next/otf/compilation/cache.py index 418aef5def..ebbe4b8601 100644 --- a/src/gt4py/next/otf/compilation/cache.py +++ b/src/gt4py/next/otf/compilation/cache.py @@ -8,6 +8,8 @@ """Caching for compiled backend artifacts.""" +from __future__ import annotations + import hashlib import pathlib import tempfile diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/common.py b/src/gt4py/next/program_processors/runners/dace/workflow/common.py index 1efc46507f..6ec0fe8f14 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/common.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/common.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import contextlib import os from typing import TYPE_CHECKING, Any, Final, Generator, Optional diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 8760a95878..5fbcb24676 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -41,7 +41,7 @@ class Meta: class Params: auto_optimize: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config_type.CMakeBuildType = factory.LazyFunction( + cmake_build_type: config_type.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factoryboy's type stubs seem incomplete lambda: config.cmake_build_type ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 37051c2f39..00bef6d9f8 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import functools from typing import TYPE_CHECKING, Any @@ -141,7 +143,7 @@ class Params: translation = factory.LazyAttribute(lambda o: o.bare_translation) bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] = ( - nanobind.bind_source + nanobind.bind_source # type: ignore[has-type] # mypy bug? mypy cannot see nanobind.bind_source type here ) compilation = factory.SubFactory( compiler.CompilerFactory, From 216d0106f8bfcff84fbedec5991550147e88e14f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 27 Feb 2026 10:03:03 +0100 Subject: [PATCH 11/12] Fix failing doctests --- docs/user/next/advanced/HackTheToolchain.md | 2 +- src/gt4py/next/_config.py | 32 ++++++++++----------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md index 785cc0b24d..803b4c7dd5 100644 --- a/docs/user/next/advanced/HackTheToolchain.md +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -64,7 +64,7 @@ class PureCpp2WorkflowFactory(gtx.program_processors.runners.gtfn.GTFNCompileWor ) -PureCpp2WorkflowFactory(cmake_build_type=gtx.config.CMAKE_BUILD_TYPE.DEBUG) +PureCpp2WorkflowFactory(cmake_build_type=gtx.config.cmake_build_type.DEBUG) ``` ## Invent new Workflow Types diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py index 3cb957f27c..764f401000 100644 --- a/src/gt4py/next/_config.py +++ b/src/gt4py/next/_config.py @@ -163,14 +163,6 @@ class OptionDescriptor(Generic[_T]): env_var_parser: Optional parser for environment variable values. env_var_prefix: Prefix for the environment variable name. name: Name of the option (set automatically via __set_name__). - - Example: - >>> class Config(ConfigManager): - ... debug = OptionDescriptor( - ... type=bool, - ... default=False, - ... update_callback=lambda new, old, scope: print(f"Changed to {new}"), - ... ) """ option_type: type[_T] | Any @@ -250,22 +242,28 @@ class ConfigManager: """ Central configuration manager with attribute-style access. - Config options are defined as class attributes using `OptionDescriptor`. - The manager stores global values for all options and allows temporary - overrides in a context manager scope. + Config options are defined as `OptionDescriptor` class attributes in a + concrete subclass of `ConfigManager`. The manager stores global values + for all options and allows temporary overrides in a context manager scope. The effective value of an option follows this precedence (highest to lowest): 1. Active context override via the `overrides()` context manager 2. Global runtime value set via the `set()` method - 3. Environment variable (if set) - 4. Descriptor default or default_factory result + 3. Default value from the environment variable (if set) + 4. Default value from the descriptor (either `default` or `default_factory`) Example: - >>> config = ConfigManager() + >>> class MyConfig(ConfigManager): + ... some_option = OptionDescriptor(option_type=str, default="default_value") + >>> config = MyConfig() + >>> config.get("some_option") # Default value from descriptor + 'default_value' + >>> config.set("some_option", "global_value") # Set global value >>> config.get("some_option") # Apply precedence rules - >>> config.set("some_option", value) # Set global value - >>> with config.overrides(some_option=value): # Temporary override - ... pass + 'global_value' + >>> with config.overrides(some_option="temporary_override"): # Temporary override + ... config.some_option + 'temporary_override' """ def __init__(self) -> None: From 689489af6db7e71aea21b1770a678a58bbbb9be5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 27 Feb 2026 11:53:53 +0100 Subject: [PATCH 12/12] Address copilot review issues --- src/gt4py/next/_config.py | 17 ++--------------- .../runners/dace/workflow/common.py | 1 + tests/next_tests/__init__.py | 2 +- .../dace_tests/test_dace_backend.py | 3 ++- .../dace_tests/test_dace_translation.py | 3 ++- tests/next_tests/unit_tests/test_config.py | 2 +- 6 files changed, 9 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/_config.py b/src/gt4py/next/_config.py index 764f401000..75681c0705 100644 --- a/src/gt4py/next/_config.py +++ b/src/gt4py/next/_config.py @@ -41,9 +41,6 @@ @final class _UnsetSentinel: - """Sentinel value for unset configuration options.""" - - __slots__ = () _instance: _UnsetSentinel | None = None def __new__(cls) -> _UnsetSentinel: @@ -51,9 +48,6 @@ def __new__(cls) -> _UnsetSentinel: cls._instance = super().__new__(cls) return cls._instance - def __repr__(self) -> str: - return "" - UNSET: Final[_UnsetSentinel] = _UnsetSentinel() @@ -394,10 +388,8 @@ def overrides(self, **overrides: Any) -> Generator[None, None, None]: finally: self._local_context_cvar.reset(token) - for name in changes.keys() & old_context.keys() & self._callbacks.keys(): - self._callbacks[name]( - old_context.get(name), new_context.get(name), UpdateScope.CONTEXT - ) + for name in changes.keys() & self._callbacks.keys(): + self._callbacks[name](old_values[name], new_context.get(name), UpdateScope.CONTEXT) def as_dict(self) -> dict[str, Any]: """ @@ -561,8 +553,3 @@ class CMakeBuildType(enum.Enum): #: Global singleton instance of the GT4Py configuration manager. #: Use this to access and modify configuration options: config.debug, config.set(...), etc. config = Config() - -if config.debug: - print("GT4Py configuration:") - for name, value in sorted(config.as_dict().items()): - print(f" - {name}: {value}") diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/common.py b/src/gt4py/next/program_processors/runners/dace/workflow/common.py index 4d0fe2157c..5e193db3a6 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/common.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/common.py @@ -15,6 +15,7 @@ import dace from gt4py._core import definitions as core_defs +from gt4py.next import config if TYPE_CHECKING: diff --git a/tests/next_tests/__init__.py b/tests/next_tests/__init__.py index 7a094466dd..b461b06153 100644 --- a/tests/next_tests/__init__.py +++ b/tests/next_tests/__init__.py @@ -18,7 +18,7 @@ if config.build_cache_lifetime is config.BuildCacheLifetime.PERSISTENT: warnings.warn( - "You are running GT4Py tests with 'config.BuildCacheLifetime' set to PERSISTENT!", + "You are running GT4Py tests with 'config.build_cache_lifetime' set to PERSISTENT!", UserWarning, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py index dbc4d87490..7e4419f03d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_backend.py @@ -16,6 +16,7 @@ from gt4py import next as gtx from gt4py._core import definitions as core_defs +from gt4py.next import config as gt_config from gt4py.next.program_processors.runners.dace.workflow import ( backend as dace_wf_backend, ) @@ -92,7 +93,7 @@ def mocked_gpu_transformation(*args, **kwargs) -> dace.SDFG: monkeypatch.setattr(gtx_transformations, "gt_auto_optimize", mocked_auto_optimize) monkeypatch.setattr(gtx_transformations, "gt_gpu_transformation", mocked_gpu_transformation) - with mock.patch("gt4py.next.config.unstructured_horizontal_has_unit_stride", on_gpu): + with gt_config.overrides(unstructured_horizontal_has_unit_stride=on_gpu): custom_backend = dace_wf_backend.make_dace_backend( gpu=on_gpu, cached=False, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py index 20651021aa..58cc3c8dcb 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py @@ -25,6 +25,7 @@ common as dace_wf_common, ) from gt4py.next.type_system import type_specifications as ts +from gt4py.next import config as gt_config from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( V2E, @@ -97,7 +98,7 @@ def test_find_constant_symbols(has_unit_stride, disable_field_origin): ], ) - with mock.patch("gt4py.next.config.unstructured_horizontal_has_unit_stride", has_unit_stride): + with gt_config.overrides(unstructured_horizontal_has_unit_stride=has_unit_stride): sdfg = _translate_gtir_to_sdfg( ir=ir, offset_provider=SKIP_VALUE_MESH.offset_provider, diff --git a/tests/next_tests/unit_tests/test_config.py b/tests/next_tests/unit_tests/test_config.py index 1be354c6c8..ebbb33eb48 100644 --- a/tests/next_tests/unit_tests/test_config.py +++ b/tests/next_tests/unit_tests/test_config.py @@ -102,7 +102,7 @@ class TestConfig(ConfigManager): [ ("/tmp/test", pathlib.Path("/tmp/test")), ("./relative/path", pathlib.Path("./relative/path")), - ("~/user/path", pathlib.Path(os.environ["HOME"]) / "user" / "path"), + ("~/user/path", pathlib.Path(os.path.expanduser("~/user/path"))), ], ) def test_parse_path(self, value, expected) -> None: