From bcf725389739b419a126c323b90a4decff283861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Thu, 7 May 2026 12:21:55 +0200 Subject: [PATCH 1/4] Refactoring ffront stages with fingerprinted protocol --- src/gt4py/__init__.py | 2 +- src/gt4py/eve/utils.py | 35 +++-- src/gt4py/next/backend.py | 4 +- src/gt4py/next/embedded/nd_array_field.py | 2 +- src/gt4py/next/ffront/foast_to_gtir.py | 2 +- src/gt4py/next/ffront/foast_to_past.py | 2 +- src/gt4py/next/ffront/func_to_foast.py | 2 +- src/gt4py/next/ffront/func_to_past.py | 2 +- src/gt4py/next/ffront/past_passes/linters.py | 2 +- src/gt4py/next/ffront/past_process_args.py | 2 +- src/gt4py/next/ffront/past_to_itir.py | 2 +- src/gt4py/next/ffront/stages.py | 130 ++++++------------- src/gt4py/next/utils.py | 111 +++++++++++++++- tests/next_tests/unit_tests/test_utils.py | 12 +- 14 files changed, 185 insertions(+), 125 deletions(-) diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 350f6ac5e8..0aef41bf2c 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -13,8 +13,8 @@ - `gt4py.next` supports structured and unstructured grid. """ +from . __about__ import __author__, __copyright__, __license__, __version__, __version_info__ # noqa: I001 from . import eve, storage -from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__ __all__ = [ diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 8da5cdfd22..c6126566a7 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -23,6 +23,7 @@ import pickle import pprint import re +import sys import types import typing @@ -630,6 +631,16 @@ def is_noninstantiable(cls: Type[_T]) -> bool: return "__noninstantiable__" in cls.__dict__ +def singledispatcher( + default: Callable[P, T], + implementations: dict[type, Callable[[Any], Any]], +) -> Callable[P, T]: + result = functools.singledispatch(default) + for cls, func in implementations.items(): + result.register(cls)(func) + return result + + def content_hash( *args: Any, hash_algorithm: str | xtyping.HashlibAlgorithm | None = None, @@ -669,6 +680,7 @@ def content_hash( def custom_pickler( reducer: Callable[[Any], tuple | types.NotImplementedType], + *, name: str | None = None, ) -> type[pickle.Pickler]: """ @@ -685,7 +697,9 @@ def custom_pickler( def custom_pickler_from_reducers( custom_reducers: dict[type, Callable[[Any], tuple | types.NotImplementedType]], + *, name: str | None = None, + default_reducer: Callable[[Any], tuple | types.NotImplementedType] = lambda _: NotImplemented, ) -> type[pickle.Pickler]: """ Create a pickler with the provided reducers registered in reducer override. @@ -696,13 +710,8 @@ def custom_pickler_from_reducers( the `dispatch_table` dict, to allow easy pickle customization of entire class hierarchies. """ - reducer = functools.singledispatch( - cast(Callable[[Any], tuple | types.NotImplementedType], lambda _: NotImplemented) - ) - for cls, func in custom_reducers.items(): - reducer.register(cls)(func) - return custom_pickler(reducer, name=name) + return custom_pickler(singledispatcher(default_reducer, custom_reducers), name=name) ddiff = deepdiff.diff.DeepDiff @@ -1226,14 +1235,12 @@ def getitem(self, *indices: Union[int, str], default: Any = NOTHING) -> XIterabl >>> list(it.getitem(0)) ['a', 'b', 'c'] - >>> it = xiter( - ... [ - ... dict(name="AA", age=20, country="US"), - ... dict(name="BB", age=30, country="UK"), - ... dict(name="CC", age=40, country="EU"), - ... dict(country="CH"), - ... ] - ... ) + >>> it = xiter([ + ... dict(name="AA", age=20, country="US"), + ... dict(name="BB", age=30, country="UK"), + ... dict(name="CC", age=40, country="EU"), + ... dict(country="CH"), + ... ]) >>> list(it.getitem("name", "age", default=None)) [('AA', 20), ('BB', 30), ('CC', 40), (None, None)] diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 6123da97e0..329d51775f 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -12,7 +12,7 @@ from typing import Generic from gt4py._core import definitions as core_defs -from gt4py.next import custom_layout_allocators as next_allocators +from gt4py.next import custom_layout_allocators as next_allocators, utils from gt4py.next.ffront import ( foast_to_gtir, foast_to_past, @@ -145,7 +145,7 @@ def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]: # `transforms` -> `frontend_transforms` # `executor` -> `backend_transforms` @dataclasses.dataclass(frozen=True) -class Backend(Generic[core_defs.DeviceTypeT]): +class Backend(utils.FingerprintedDataclass, Generic[core_defs.DeviceTypeT]): name: str executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram] allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e9aff84a15..fe8d06b450 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -106,7 +106,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: class NdArrayField( common.MutableField[common.DimsT, core_defs.ScalarT], common.FieldBuiltinFuncRegistry, - utils.MetadataBasedPickling, + utils.MetadataBasedPicklingMixin, ): """ Shared field implementation for NumPy-like fields. diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3825072cb7..c7db410440 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -47,7 +47,7 @@ def foast_to_gtir_factory( """Wrap `foast_to_gtir` into a chainable and, optionally, cached workflow step.""" wf = foast_to_gtir if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 05b080b70b..3d1611e623 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -190,5 +190,5 @@ def operator_to_program_factory( foast_to_itir_step or foast_to_gtir.adapted_foast_to_gtir_factory() ) if cached: - wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ced0ff3905..ccf94829a9 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -99,7 +99,7 @@ def func_to_foast_factory( """Wrap `func_to_foast` in a chainable and optionally cached workflow step.""" wf = workflow.make_step(func_to_foast) if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index ebc0de31b3..cfb2e4f6dd 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -80,7 +80,7 @@ def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSLProgramDef """ wf = workflow.make_step(func_to_past) if cached: - wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/past_passes/linters.py b/src/gt4py/next/ffront/past_passes/linters.py index 6d9fb9123b..8fc41a5868 100644 --- a/src/gt4py/next/ffront/past_passes/linters.py +++ b/src/gt4py/next/ffront/past_passes/linters.py @@ -51,7 +51,7 @@ def linter_factory( ) -> workflow.Workflow[PASTProgramDef, PASTProgramDef]: wf = lint_misnamed_functions.chain(lint_undefined_symbols) if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index ce794fd9dc..c6edaca324 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -41,7 +41,7 @@ def transform_program_args_factory( ) -> workflow.Workflow[ffront_stages.ConcretePASTProgramDef, ffront_stages.ConcretePASTProgramDef]: wf = transform_program_args if cached: - wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 6b19e7cc1f..af9641db6f 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -152,7 +152,7 @@ def past_to_gtir_factory( ) -> workflow.Workflow[ConcretePASTProgramDef, definitions.CompilableProgramDef]: wf = workflow.make_step(past_to_gtir) if cached: - wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index da05c68669..fc2dd226d7 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -25,21 +25,50 @@ import collections.abc import dataclasses import functools -import hashlib import types import typing from typing import Any, Optional, TypeVar -import xxhash -from gt4py.eve import extended_typing as xtyping -from gt4py.next import common +from gt4py.eve import extended_typing as xtyping, utils as eve_utils +from gt4py.next import common, utils from gt4py.next.ffront import field_operator_ast as foast, program_ast as past, source_utils from gt4py.next.otf import arguments, toolchain +# Create a custom pickler for the BaseStage `fingerprinter` that handles +# `types.FunctionType` by using its source code and closure variables. +# This should be enough for the use case of GT4Py DSL definitions, +# which are expected to be pure functions without complicated closures. +fingerprint_reducer = eve_utils.singledispatcher( + utils.fingerprint_reducer, + { + types.FunctionType: lambda f: ( + f.__class__, + (), + ( + source_utils.make_source_definition_from_function(f), + source_utils.get_closure_vars_from_function(f), + ), + ) + }, +) + +fingerprinter = functools.partial( + eve_utils.content_hash, + pickler=eve_utils.custom_pickler(fingerprint_reducer, name="FFrontFingerprintPickler"), +) + + +@dataclasses.dataclass(frozen=True) +class BaseStage(utils.FingerprintedDataclass): + """Base class for optimized Fingerprinted implementations in frozen dataclasses.""" + + fingerprinter = staticmethod(fingerprinter) + + @dataclasses.dataclass(frozen=True) -class DSLFieldOperatorDef: +class DSLFieldOperatorDef(BaseStage): definition: types.FunctionType node_class: type[foast.OperatorNode] = foast.FieldOperator attributes: dict[str, Any] = dataclasses.field(default_factory=dict) @@ -53,7 +82,7 @@ class DSLFieldOperatorDef: @dataclasses.dataclass(frozen=True) -class FOASTOperatorDef: +class FOASTOperatorDef(BaseStage): foast_node: foast.OperatorNode closure_vars: dict[str, Any] grid_type: Optional[common.GridType] = None @@ -67,7 +96,7 @@ class FOASTOperatorDef: @dataclasses.dataclass(frozen=True) -class DSLProgramDef: +class DSLProgramDef(BaseStage): definition: types.FunctionType grid_type: Optional[common.GridType] = None debug: bool = False @@ -79,7 +108,7 @@ class DSLProgramDef: @dataclasses.dataclass(frozen=True) -class PASTProgramDef: +class PASTProgramDef(BaseStage): past_node: past.Program closure_vars: dict[str, Any] grid_type: Optional[common.GridType] = None @@ -92,88 +121,3 @@ class PASTProgramDef: DSLDefinition = DSLFieldOperatorDef | DSLProgramDef DSLDefinitionT = TypeVar("DSLDefinitionT", DSLFieldOperatorDef, DSLProgramDef) - - -def fingerprint_stage(obj: Any, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None) -> str: - hasher: xtyping.HashlibAlgorithm - if not algorithm: - hasher = xxhash.xxh64() # type: ignore[assignment] # fixing this requires https://github.com/ifduyue/python-xxhash/issues/104 - elif isinstance(algorithm, str): - hasher = hashlib.new(algorithm) - else: - hasher = algorithm - - add_content_to_fingerprint(obj, hasher) - return hasher.hexdigest() - - -@functools.singledispatch -def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None: - hasher.update(str(obj).encode()) - - -for t in (str, int): - add_content_to_fingerprint.register(t, add_content_to_fingerprint.registry[object]) - - -@add_content_to_fingerprint.register(DSLFieldOperatorDef) -@add_content_to_fingerprint.register(FOASTOperatorDef) -@add_content_to_fingerprint.register(DSLProgramDef) -@add_content_to_fingerprint.register(PASTProgramDef) -@add_content_to_fingerprint.register(toolchain.ConcreteArtifact) -@add_content_to_fingerprint.register(arguments.CompileTimeArgs) -def add_stage_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None: - add_content_to_fingerprint(obj.__class__, hasher) - for field in dataclasses.fields(obj): - add_content_to_fingerprint(getattr(obj, field.name), hasher) - - -def add_jit_args_id_to_fingerprint( - obj: arguments.JITArgs, hasher: xtyping.HashlibAlgorithm -) -> None: - add_content_to_fingerprint(str(id(obj)), hasher) - - -@add_content_to_fingerprint.register -def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgorithm) -> None: - sourcedef = source_utils.SourceDefinition.from_function(obj) - for item in sourcedef: - add_content_to_fingerprint(item, hasher) - - closure_vars = source_utils.get_closure_vars_from_function(obj) - for item in sorted(closure_vars.items(), key=lambda x: x[0]): - add_content_to_fingerprint(item, hasher) - - -@add_content_to_fingerprint.register -def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: - # just a small helper to additionally allow sorting types (by just using their name) - def key_function(key: Any) -> Any: - if isinstance(key, type): - return key.__module__, key.__qualname__ - return key - - for key in sorted(obj.keys(), key=key_function): - add_content_to_fingerprint(key, hasher) - add_content_to_fingerprint(obj[key], hasher) - - -@add_content_to_fingerprint.register -def add_type_to_fingerprint(obj: type, hasher: xtyping.HashlibAlgorithm) -> None: - hasher.update(obj.__name__.encode()) - - -@add_content_to_fingerprint.register -def add_sequence_to_fingerprint( - obj: collections.abc.Iterable, hasher: xtyping.HashlibAlgorithm -) -> None: - for item in obj: - add_content_to_fingerprint(item, hasher) - - -@add_content_to_fingerprint.register -def add_foast_located_node_to_fingerprint( - obj: foast.LocatedNode, hasher: xtyping.HashlibAlgorithm -) -> None: - add_content_to_fingerprint(obj.location, hasher) - add_content_to_fingerprint(str(obj), hasher) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 942bdb8459..bd0948d11c 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -8,6 +8,7 @@ from __future__ import annotations +import abc import copyreg import dataclasses import functools @@ -15,12 +16,14 @@ import itertools import types from typing import ( + TYPE_CHECKING, Any, Callable, ClassVar, Final, Optional, ParamSpec, + Protocol, Sequence, TypeAlias, TypeGuard, @@ -30,6 +33,7 @@ ) from gt4py.eve import datamodels, utils as eve_utils +from gt4py.eve.extended_typing import Self GT4PY_CLASS_METADATA_NS: Final[str] = "GT4PY_META" @@ -111,7 +115,7 @@ def __getstate__(self: object) -> _StandardPickleState: return __getstate__ -class MetadataBasedPickling: +class MetadataBasedPicklingMixin: """ Mixin for adding metadata-based pickling to dataclass-like objects. @@ -138,6 +142,111 @@ def __getstate__(self) -> _StandardPickleState: # implementation. +class Fingerprinted(Protocol): + """ + Protocol for objects that can be fingerprinted with a custom function. + + The fingerprint should be a stable hash string representing the state of the object. + """ + + __slots__ = () + + @property + def fingerprint(self) -> str: + """Get the fingerprint of the object.""" + ... + + @property + def fingerprinter(self) -> Callable[[Fingerprinted], str]: + """Get the fingerprinting function for the object.""" + ... + + +class FingerprintedABC(abc.ABC): + """ + ABC of objects implementing the fingerprinting protocol. + + It provides a custom subclass hook to recognize classes implementing the + protocol without inheriting from the ABC, without the performance problems + of using `isinstance` checks on runtime-checkable protocols directly. + """ + + __slots__ = () + + @classmethod + def __subclasshook__(cls, subclass: type) -> bool | types.NotImplementedType: + if ( + cls is FingerprintedABC + and hasattr(subclass, "fingerprint") + and hasattr(subclass, "fingerprinter") + ): + return True + + return NotImplemented + + @staticmethod + @abc.abstractmethod + def fingerprinter(instance: FingerprintedABC) -> str: ... # to be implemented by subclasses + + @property + def fingerprint(self: Self) -> str: + return self.fingerprinter(self) + + +fingerprint_reducer = eve_utils.singledispatcher( + lambda _: NotImplemented, + {FingerprintedABC: lambda obj: (obj.__class__, (), (obj.fingerprint,))}, +) +fingerprint_pickler = eve_utils.custom_pickler(fingerprint_reducer, name="FingerprintPickler") + +fingerprint: Callable[[Any], str] = functools.partial( + eve_utils.content_hash, pickler=fingerprint_pickler +) +""" +Default fingerprinting function for GT4Py objects. + +It uses `eve_utils.content_hash` as fingerprinting function. If the object +is an instance of a class implementing the `FingerprintedProtocol`, it will +instead use the `fingerprint` property of the object as its content hash. +""" + + +class FingerprintedMixin: + """General mixin to add support for the Fingerprinted protocol to any class.""" + + __slots__ = () + + fingerprinter = staticmethod(fingerprint) + + @property + def fingerprint(self) -> str: + return self.fingerprinter(self) + + +assert issubclass(FingerprintedMixin, FingerprintedABC) +if TYPE_CHECKING: + _FM: type[Fingerprinted] = FingerprintedMixin + + +class CachedFingerprintedMixin: + """Mixin to add an optimized implementation of the Fingerprinted protocol to frozen classes.""" + + fingerprinter = staticmethod(fingerprint) + + @(functools.cached_property if not TYPE_CHECKING else property) + def fingerprint(self) -> str: + return self.fingerprinter(self) + + +assert issubclass(CachedFingerprintedMixin, FingerprintedABC) +if TYPE_CHECKING: + _CFM: type[Fingerprinted] = CachedFingerprintedMixin + + +class FingerprintedDataclass(CachedFingerprintedMixin, MetadataBasedPicklingMixin): + __slots__ = () + + class RecursionGuard: """ Context manager to guard against inifinite recursion. diff --git a/tests/next_tests/unit_tests/test_utils.py b/tests/next_tests/unit_tests/test_utils.py index 51ed0cf8db..ffa774d5f2 100644 --- a/tests/next_tests/unit_tests/test_utils.py +++ b/tests/next_tests/unit_tests/test_utils.py @@ -17,35 +17,35 @@ # Module-level classes so pickle can resolve them by qualified name. @dataclasses.dataclass -class _DataclassModel(utils.MetadataBasedPickling): +class _DataclassModel(utils.MetadataBasedPicklingMixin): value: int transient: str = dataclasses.field(default="skip", metadata=utils.gt4py_metadata(pickle=False)) @dataclasses.dataclass(slots=True) -class _SlottedDataclassModel(utils.MetadataBasedPickling): +class _SlottedDataclassModel(utils.MetadataBasedPicklingMixin): value: int transient: str = dataclasses.field(default="skip", metadata=utils.gt4py_metadata(pickle=False)) @datamodels.datamodel(slots=False) -class _DatamodelModel(utils.MetadataBasedPickling): +class _DatamodelModel(utils.MetadataBasedPicklingMixin): value: int transient: str = datamodels.field(default="skip", metadata=utils.gt4py_metadata(pickle=False)) @dataclasses.dataclass -class _EmptyDataclassModel(utils.MetadataBasedPickling): +class _EmptyDataclassModel(utils.MetadataBasedPicklingMixin): pass @dataclasses.dataclass(slots=True) -class _EmptySlottedDataclassModel(utils.MetadataBasedPickling): +class _EmptySlottedDataclassModel(utils.MetadataBasedPicklingMixin): pass @datamodels.datamodel(slots=False) -class _EmptyDatamodelModel(utils.MetadataBasedPickling): +class _EmptyDatamodelModel(utils.MetadataBasedPicklingMixin): pass From 6216d0c8bf390a356c1e108a2a465eeaa7f7ca03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Thu, 7 May 2026 17:35:23 +0200 Subject: [PATCH 2/4] Refactoring workflows and programs/field operators --- src/gt4py/__init__.py | 2 +- src/gt4py/eve/utils.py | 1 - src/gt4py/next/ffront/decorator.py | 24 ++------------ src/gt4py/next/ffront/foast_to_gtir.py | 2 +- src/gt4py/next/ffront/foast_to_past.py | 2 +- src/gt4py/next/ffront/func_to_foast.py | 2 +- src/gt4py/next/ffront/func_to_past.py | 2 +- src/gt4py/next/ffront/past_passes/linters.py | 2 +- src/gt4py/next/ffront/past_process_args.py | 2 +- src/gt4py/next/ffront/past_to_itir.py | 2 +- src/gt4py/next/ffront/stages.py | 4 +-- src/gt4py/next/otf/compilation/compiler.py | 3 +- src/gt4py/next/otf/toolchain.py | 4 +++ src/gt4py/next/otf/workflow.py | 32 +++++++++++++++---- .../codegens/gtfn/gtfn_module.py | 3 +- .../runners/dace/workflow/backend.py | 2 +- .../runners/dace/workflow/compilation.py | 3 +- .../runners/dace/workflow/factory.py | 2 +- .../runners/dace/workflow/translation.py | 3 +- .../next/program_processors/runners/gtfn.py | 4 +-- .../unit_tests/otf_tests/test_workflow.py | 2 +- 21 files changed, 54 insertions(+), 49 deletions(-) diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 0aef41bf2c..fc8d734cc9 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -13,7 +13,7 @@ - `gt4py.next` supports structured and unstructured grid. """ -from . __about__ import __author__, __copyright__, __license__, __version__, __version_info__ # noqa: I001 +from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__ # noqa: I001 from . import eve, storage diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index c6126566a7..f5c6eb9169 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -23,7 +23,6 @@ import pickle import pprint import re -import sys import types import typing diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index c749fcec01..db60062611 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -83,7 +83,9 @@ def embedded_program_call_context( @dataclasses.dataclass(frozen=True) -class _CompilableGTEntryPointMixin(Generic[ffront_stages.DSLDefinitionT]): +class _CompilableGTEntryPointMixin( + utils.FingerprintedDataclass, Generic[ffront_stages.DSLDefinitionT] +): """ Mixing used by program and program-like objects. @@ -862,23 +864,3 @@ def scan_operator_inner(definition: Callable) -> FieldOperator: ) return scan_operator_inner if definition is None else scan_operator_inner(definition) - - -@ffront_stages.add_content_to_fingerprint.register -def add_fieldop_to_fingerprint(obj: FieldOperator, hasher: xtyping.HashlibAlgorithm) -> None: - ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher) - ffront_stages.add_content_to_fingerprint(obj.backend, hasher) - - -@ffront_stages.add_content_to_fingerprint.register -def add_foast_fieldop_to_fingerprint( - obj: FieldOperatorFromFoast, hasher: xtyping.HashlibAlgorithm -) -> None: - ffront_stages.add_content_to_fingerprint(obj.foast_stage, hasher) - ffront_stages.add_content_to_fingerprint(obj.backend, hasher) - - -@ffront_stages.add_content_to_fingerprint.register -def add_program_to_fingerprint(obj: Program, hasher: xtyping.HashlibAlgorithm) -> None: - ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher) - ffront_stages.add_content_to_fingerprint(obj.backend, hasher) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index c7db410440..633e1b2cfe 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -47,7 +47,7 @@ def foast_to_gtir_factory( """Wrap `foast_to_gtir` into a chainable and, optionally, cached workflow step.""" wf = foast_to_gtir if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprinter) + wf = workflow.CachedStep(step=wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 3d1611e623..5507011ad2 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -190,5 +190,5 @@ def operator_to_program_factory( foast_to_itir_step or foast_to_gtir.adapted_foast_to_gtir_factory() ) if cached: - wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprinter) + wf = workflow.CachedStep(wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ccf94829a9..89fdb5a043 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -99,7 +99,7 @@ def func_to_foast_factory( """Wrap `func_to_foast` in a chainable and optionally cached workflow step.""" wf = workflow.make_step(func_to_foast) if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprinter) + wf = workflow.CachedStep(step=wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index cfb2e4f6dd..d026774d77 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -80,7 +80,7 @@ def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSLProgramDef """ wf = workflow.make_step(func_to_past) if cached: - wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprinter) + wf = workflow.CachedStep(wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/past_passes/linters.py b/src/gt4py/next/ffront/past_passes/linters.py index 8fc41a5868..3008ff7e99 100644 --- a/src/gt4py/next/ffront/past_passes/linters.py +++ b/src/gt4py/next/ffront/past_passes/linters.py @@ -51,7 +51,7 @@ def linter_factory( ) -> workflow.Workflow[PASTProgramDef, PASTProgramDef]: wf = lint_misnamed_functions.chain(lint_undefined_symbols) if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprinter) + wf = workflow.CachedStep(step=wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index c6edaca324..b014ae3007 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -41,7 +41,7 @@ def transform_program_args_factory( ) -> workflow.Workflow[ffront_stages.ConcretePASTProgramDef, ffront_stages.ConcretePASTProgramDef]: wf = transform_program_args if cached: - wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprinter) + wf = workflow.CachedStep(wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index af9641db6f..4292f80cf3 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -152,7 +152,7 @@ def past_to_gtir_factory( ) -> workflow.Workflow[ConcretePASTProgramDef, definitions.CompilableProgramDef]: wf = workflow.make_step(past_to_gtir) if cached: - wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprinter) + wf = workflow.CachedStep(wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index fc2dd226d7..25aec84432 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -22,15 +22,13 @@ from __future__ import annotations -import collections.abc import dataclasses import functools import types import typing from typing import Any, Optional, TypeVar - -from gt4py.eve import extended_typing as xtyping, utils as eve_utils +from gt4py.eve import utils as eve_utils from gt4py.next import common, utils from gt4py.next.ffront import field_operator_ast as foast, program_ast as past, source_utils from gt4py.next.otf import arguments, toolchain diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 3748d95192..56b96a6cda 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -15,7 +15,7 @@ import factory from gt4py._core import locking -from gt4py.next import config +from gt4py.next import config, utils from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import build_data, cache, importer @@ -55,6 +55,7 @@ class Compiler( stages.ExecutableProgram, ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], + utils.FingerprintedDataclass, ): """Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" diff --git a/src/gt4py/next/otf/toolchain.py b/src/gt4py/next/otf/toolchain.py index 0c816759ff..1c5088bf96 100644 --- a/src/gt4py/next/otf/toolchain.py +++ b/src/gt4py/next/otf/toolchain.py @@ -12,6 +12,7 @@ import typing from typing import Generic +from gt4py.next import utils from gt4py.next.otf import workflow @@ -32,6 +33,7 @@ class DataOnlyAdapter( workflow.ChainableWorkflowMixin, workflow.ReplaceEnabledWorkflowMixin, workflow.Workflow[ConcreteArtifact[S, ArgsT], ConcreteArtifact[T, ArgsT]], + utils.FingerprintedDataclass, Generic[ArgsT, S, T], ): step: workflow.Workflow[S, T] @@ -45,6 +47,7 @@ class ArgsOnlyAdapter( workflow.ChainableWorkflowMixin, workflow.ReplaceEnabledWorkflowMixin, workflow.Workflow[ConcreteArtifact[DefT, S], ConcreteArtifact[DefT, T]], + utils.FingerprintedDataclass, Generic[DefT, S, T], ): step: workflow.Workflow[S, T] @@ -58,6 +61,7 @@ class StripArgsAdapter( workflow.ChainableWorkflowMixin, workflow.ReplaceEnabledWorkflowMixin, workflow.Workflow[ConcreteArtifact[S, ArgsT], T], + utils.FingerprintedDataclass, Generic[ArgsT, S, T], ): step: workflow.Workflow[S, T] diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 08abadd7e5..740e258127 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -17,6 +17,7 @@ from typing_extensions import Self from gt4py.eve.extended_typing import OpaqueMutableMapping +from gt4py.next import utils StartT = TypeVar("StartT") @@ -92,7 +93,9 @@ def chain( @dataclasses.dataclass(frozen=True) class NamedStepSequence( - ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT] + ChainableWorkflowMixin[StartT, EndT], + ReplaceEnabledWorkflowMixin[StartT, EndT], + utils.FingerprintedDataclass, ): """ Workflow with linear succession of named steps. @@ -159,7 +162,9 @@ def step_order(self) -> list[str]: @dataclasses.dataclass(frozen=True) class MultiWorkflow( - ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT] + ChainableWorkflowMixin[StartT, EndT], + ReplaceEnabledWorkflowMixin[StartT, EndT], + utils.FingerprintedDataclass, ): """A flexible workflow, where the sequence of steps depends on the input type.""" @@ -175,7 +180,10 @@ def step_order(self, inp: StartT) -> list[str]: @dataclasses.dataclass(frozen=True) -class StepSequence(ChainableWorkflowMixin[StartT, EndT]): +class StepSequence( + ChainableWorkflowMixin[StartT, EndT], + utils.FingerprintedDataclass, +): """ Composable workflow of single input callables. @@ -227,6 +235,7 @@ def start(cls, first_step: Workflow[StartT, EndT]) -> ChainableWorkflowMixin[Sta class CachedStep( ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT], + utils.FingerprintedDataclass, Generic[StartT, EndT, HashT], ): """ @@ -254,22 +263,31 @@ class CachedStep( """ step: Workflow[StartT, EndT] - hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment] - cache: OpaqueMutableMapping[HashT, EndT] = dataclasses.field(repr=False, default_factory=dict) + key_function: Callable[[StartT], HashT] = dataclasses.field( + default=utils.fingerprint, metadata=utils.gt4py_metadata(pickle=False) + ) # type: ignore[assignment] + cache: OpaqueMutableMapping[HashT, EndT] = dataclasses.field( + repr=False, default_factory=dict, metadata=utils.gt4py_metadata(pickle=False) + ) # type: ignore[assignment] def __call__(self, inp: StartT) -> EndT: """Run the step only if the input is not cached, else return from cache.""" - hash_ = self.hash_function(inp) + hash_ = self.cache_key(inp) try: result = self.cache[hash_] except KeyError: result = self.cache[hash_] = self.step(inp) return result + def cache_key(self, inp: StartT) -> str: + return utils.fingerprint((self, self.key_function(inp))) + @dataclasses.dataclass(frozen=True) class SkippableStep( - ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT] + ChainableWorkflowMixin[StartT, EndT], + ReplaceEnabledWorkflowMixin[StartT, EndT], + utils.FingerprintedDataclass, ): step: Workflow[StartT, EndT] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 2135af7fbb..d6112b91a6 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -17,7 +17,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import codegen -from gt4py.next import common +from gt4py.next import common, utils from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import pass_manager @@ -46,6 +46,7 @@ class GTFNTranslationStep( definitions.CompilableProgramDef, stages.ProgramSource[code_specs.HeaderAndSourceCodeSpec], ], + utils.FingerprintedDataclass, ): code_spec: Optional[code_specs.HeaderAndSourceCodeSpec] = None # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index de6778a750..1bdd38b6c4 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -44,7 +44,7 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) + lambda o: workflow.CachedStep(o.otf_workflow, key_function=o.hash_function) ), name_cached="_cached", ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index e1747b7ac3..6491516ba7 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -18,7 +18,7 @@ import factory from gt4py._core import definitions as core_defs, locking -from gt4py.next import common, config +from gt4py.next import common, config, utils from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import cache as gtx_cache from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon @@ -125,6 +125,7 @@ class DaCeCompiler( CompiledDaceProgram, ], definitions.CompilationStep[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], + utils.FingerprintedDataclass, ): """Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 62febd0965..bbe3d43c14 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -46,7 +46,7 @@ class Params: translation=factory.LazyAttribute( lambda o: workflow.CachedStep( o.bare_translation, - hash_function=stages.fingerprint_compilable_program, + key_function=stages.fingerprint_compilable_program, cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "translation_cache")), ) ), diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index ad8e8ea04b..35a7f56417 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -15,7 +15,7 @@ import factory from gt4py._core import definitions as core_defs -from gt4py.next import common, config +from gt4py.next import common, config, utils from gt4py.next.instrumentation import metrics from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import code_specs, definitions, stages, workflow @@ -355,6 +355,7 @@ class DaCeTranslator( stages.ProgramSource[code_specs.SDFGCodeSpec], ], definitions.TranslationStep[code_specs.SDFGCodeSpec], + utils.FingerprintedDataclass, ): device_type: core_defs.DeviceType auto_optimize: bool diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index c1743dea6a..9d97791f45 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -123,7 +123,7 @@ class Params: translation=factory.LazyAttribute( lambda o: workflow.CachedStep( o.bare_translation, - hash_function=stages.fingerprint_compilable_program, + key_function=stages.fingerprint_compilable_program, cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), ) ), @@ -165,7 +165,7 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) + lambda o: workflow.CachedStep(o.otf_workflow, key_function=o.hash_function) ), name_cached="_cached", ) diff --git a/tests/next_tests/unit_tests/otf_tests/test_workflow.py b/tests/next_tests/unit_tests/otf_tests/test_workflow.py index d4717d34ef..853bf7a33c 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_workflow.py +++ b/tests/next_tests/unit_tests/otf_tests/test_workflow.py @@ -71,7 +71,7 @@ def test_cached_with_hashing(): def hashing(inp: list[int]) -> int: return hash(sum(inp)) - wf = workflow.CachedStep(step=lambda inp: [*inp, 1], hash_function=hashing) + wf = workflow.CachedStep(step=lambda inp: [*inp, 1], key_function=hashing) assert wf([1, 2, 3]) == [1, 2, 3, 1] assert wf([3, 2, 1]) == [1, 2, 3, 1] From e9cd6e2441e6313f4d78c6f3b514873e9a31cbe9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 8 May 2026 18:51:54 +0200 Subject: [PATCH 3/4] Refactoring executors and other fixes --- src/gt4py/eve/utils.py | 14 +- src/gt4py/next/backend.py | 2 +- src/gt4py/next/ffront/decorator.py | 2 +- src/gt4py/next/ffront/stages.py | 4 +- src/gt4py/next/otf/arguments.py | 15 +- src/gt4py/next/otf/compilation/compiler.py | 2 +- src/gt4py/next/otf/stages.py | 47 +----- src/gt4py/next/otf/toolchain.py | 8 +- src/gt4py/next/otf/workflow.py | 14 +- .../codegens/gtfn/gtfn_module.py | 2 +- .../runners/dace/workflow/backend.py | 8 +- .../runners/dace/workflow/compilation.py | 2 +- .../runners/dace/workflow/factory.py | 6 +- .../runners/dace/workflow/translation.py | 2 +- .../next/program_processors/runners/gtfn.py | 8 +- src/gt4py/next/utils.py | 21 ++- tests/eve_tests/unit_tests/test_utils.py | 23 +++ .../unit_tests/ffront_tests/test_stages.py | 142 +++--------------- tests/next_tests/unit_tests/test_utils.py | 54 +++++++ 19 files changed, 169 insertions(+), 207 deletions(-) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index f5c6eb9169..f00e394c9a 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -1234,12 +1234,14 @@ def getitem(self, *indices: Union[int, str], default: Any = NOTHING) -> XIterabl >>> list(it.getitem(0)) ['a', 'b', 'c'] - >>> it = xiter([ - ... dict(name="AA", age=20, country="US"), - ... dict(name="BB", age=30, country="UK"), - ... dict(name="CC", age=40, country="EU"), - ... dict(country="CH"), - ... ]) + >>> it = xiter( + ... [ + ... dict(name="AA", age=20, country="US"), + ... dict(name="BB", age=30, country="UK"), + ... dict(name="CC", age=40, country="EU"), + ... dict(country="CH"), + ... ] + ... ) >>> list(it.getitem("name", "age", default=None)) [('AA', 20), ('BB', 30), ('CC', 40), (None, None)] diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 329d51775f..80bc7e263a 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -145,7 +145,7 @@ def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]: # `transforms` -> `frontend_transforms` # `executor` -> `backend_transforms` @dataclasses.dataclass(frozen=True) -class Backend(utils.FingerprintedDataclass, Generic[core_defs.DeviceTypeT]): +class Backend(utils.CachedFingerprintedDataclass, Generic[core_defs.DeviceTypeT]): name: str executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram] allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index db60062611..bc59ce8069 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -84,7 +84,7 @@ def embedded_program_call_context( @dataclasses.dataclass(frozen=True) class _CompilableGTEntryPointMixin( - utils.FingerprintedDataclass, Generic[ffront_stages.DSLDefinitionT] + utils.CachedFingerprintedDataclass, Generic[ffront_stages.DSLDefinitionT] ): """ Mixing used by program and program-like objects. diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 25aec84432..d0a6700ace 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -42,7 +42,7 @@ utils.fingerprint_reducer, { types.FunctionType: lambda f: ( - f.__class__, + tuple, (), ( source_utils.make_source_definition_from_function(f), @@ -59,7 +59,7 @@ @dataclasses.dataclass(frozen=True) -class BaseStage(utils.FingerprintedDataclass): +class BaseStage(utils.CachedFingerprintedDataclass): """Base class for optimized Fingerprinted implementations in frozen dataclasses.""" fingerprinter = staticmethod(fingerprinter) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 9d8953ce5a..6cdeddb63d 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -140,7 +140,7 @@ def from_signature(cls, *args: Any, **kwargs: Any) -> Self: @dataclasses.dataclass(frozen=True) -class CompileTimeArgs: +class CompileTimeArgs(utils.CachedFingerprintedDataclass): """Compile-time standins for arguments to a GTX program to be used in ahead-of-time compilation.""" args: tuple[ts.TypeSpec, ...] @@ -175,6 +175,19 @@ def from_concrete(cls, *args: Any, **kwargs: Any) -> Self: def empty(cls) -> Self: return cls(tuple(), {}, {}, None, {}) + @staticmethod + def fingerprinter(instance: utils.Fingerprinted) -> str: + assert isinstance(instance, CompileTimeArgs) + return utils.fingerprint( + ( + instance.args, + sorted(instance.kwargs.items()), + sorted(instance.offset_provider.items()), + instance.column_axis, + instance.argument_descriptor_contexts, + ) + ) + # This is not really accurate, just an approximation NeedsValueExtraction: TypeAlias = MaybeNestedInTuple[named_collections.CustomNamedCollection] diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 56b96a6cda..9dc62b4682 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -55,7 +55,7 @@ class Compiler( stages.ExecutableProgram, ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - utils.FingerprintedDataclass, + utils.CachedFingerprintedDataclass, ): """Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index b6816b1cc3..ec61a4c2fe 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -12,54 +12,17 @@ from collections.abc import Callable from typing import Generic, Optional, Protocol, TypeAlias, TypeVar -from gt4py.eve import utils -from gt4py.next import common -from gt4py.next.iterator import ir as itir -from gt4py.next.otf import code_specs, definitions +from gt4py.next import utils +from gt4py.next.otf import code_specs from gt4py.next.otf.binding import interface -def compilation_hash(program_def: definitions.CompilableProgramDef) -> int: - """Given closure compute a hash uniquely determining if we need to recompile.""" - offset_provider = program_def.args.offset_provider - return hash( - ( - program_def.data, - # As the frontend types contain lists they are not hashable. As a workaround we just - # use content_hash here. - utils.content_hash(tuple(arg for arg in program_def.args.args)), - common.hash_offset_provider_items_by_id(offset_provider) if offset_provider else None, - program_def.args.column_axis, - ) - ) - - -def fingerprint_compilable_program(program_def: definitions.CompilableProgramDef) -> str: - """ - Generates a unique hash string for a stencil source program representing - the program, sorted offset_provider, and column_axis. - """ - program: itir.Program = program_def.data - offset_provider: common.OffsetProvider = program_def.args.offset_provider - column_axis: Optional[common.Dimension] = program_def.args.column_axis - - program_hash = utils.content_hash( - ( - program.fingerprint(), - sorted(offset_provider.items(), key=lambda el: el[0]), - column_axis, - ) - ) - - return program_hash - - CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) @dataclasses.dataclass(frozen=True) -class ProgramSource(Generic[CodeSpecT]): +class ProgramSource(utils.CachedFingerprintedDataclass, Generic[CodeSpecT]): """ Standalone source code translated from an IR along with information relevant for OTF compilation. @@ -76,7 +39,7 @@ class ProgramSource(Generic[CodeSpecT]): @dataclasses.dataclass(frozen=True) -class BindingSource(Generic[CodeSpecT, TargetCodeSpecT]): +class BindingSource(utils.CachedFingerprintedDataclass, Generic[CodeSpecT, TargetCodeSpecT]): """ Companion source code for translated program source code. @@ -92,7 +55,7 @@ class BindingSource(Generic[CodeSpecT, TargetCodeSpecT]): # TODO(ricoh): reconsider name in view of future backends producing standalone compilable ProgramSource code @dataclasses.dataclass(frozen=True) -class CompilableProject(Generic[CodeSpecT, TargetCodeSpecT]): +class CompilableProject(utils.CachedFingerprintedDataclass, Generic[CodeSpecT, TargetCodeSpecT]): """ Encapsulate all the source code required for OTF compilation. diff --git a/src/gt4py/next/otf/toolchain.py b/src/gt4py/next/otf/toolchain.py index 1c5088bf96..1c24a99fe8 100644 --- a/src/gt4py/next/otf/toolchain.py +++ b/src/gt4py/next/otf/toolchain.py @@ -23,7 +23,7 @@ @dataclasses.dataclass -class ConcreteArtifact(Generic[DefT, ArgsT]): +class ConcreteArtifact(utils.CachedFingerprintedDataclass, Generic[DefT, ArgsT]): data: DefT args: ArgsT @@ -33,7 +33,7 @@ class DataOnlyAdapter( workflow.ChainableWorkflowMixin, workflow.ReplaceEnabledWorkflowMixin, workflow.Workflow[ConcreteArtifact[S, ArgsT], ConcreteArtifact[T, ArgsT]], - utils.FingerprintedDataclass, + utils.CachedFingerprintedDataclass, Generic[ArgsT, S, T], ): step: workflow.Workflow[S, T] @@ -47,7 +47,7 @@ class ArgsOnlyAdapter( workflow.ChainableWorkflowMixin, workflow.ReplaceEnabledWorkflowMixin, workflow.Workflow[ConcreteArtifact[DefT, S], ConcreteArtifact[DefT, T]], - utils.FingerprintedDataclass, + utils.CachedFingerprintedDataclass, Generic[DefT, S, T], ): step: workflow.Workflow[S, T] @@ -61,7 +61,7 @@ class StripArgsAdapter( workflow.ChainableWorkflowMixin, workflow.ReplaceEnabledWorkflowMixin, workflow.Workflow[ConcreteArtifact[S, ArgsT], T], - utils.FingerprintedDataclass, + utils.CachedFingerprintedDataclass, Generic[ArgsT, S, T], ): step: workflow.Workflow[S, T] diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 740e258127..0d9f5ef561 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -95,7 +95,7 @@ def chain( class NamedStepSequence( ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT], - utils.FingerprintedDataclass, + utils.CachedFingerprintedDataclass, ): """ Workflow with linear succession of named steps. @@ -164,7 +164,7 @@ def step_order(self) -> list[str]: class MultiWorkflow( ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT], - utils.FingerprintedDataclass, + utils.CachedFingerprintedDataclass, ): """A flexible workflow, where the sequence of steps depends on the input type.""" @@ -182,7 +182,7 @@ def step_order(self, inp: StartT) -> list[str]: @dataclasses.dataclass(frozen=True) class StepSequence( ChainableWorkflowMixin[StartT, EndT], - utils.FingerprintedDataclass, + utils.CachedFingerprintedDataclass, ): """ Composable workflow of single input callables. @@ -235,7 +235,7 @@ def start(cls, first_step: Workflow[StartT, EndT]) -> ChainableWorkflowMixin[Sta class CachedStep( ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT], - utils.FingerprintedDataclass, + utils.CachedFingerprintedDataclass, Generic[StartT, EndT, HashT], ): """ @@ -266,9 +266,9 @@ class CachedStep( key_function: Callable[[StartT], HashT] = dataclasses.field( default=utils.fingerprint, metadata=utils.gt4py_metadata(pickle=False) ) # type: ignore[assignment] - cache: OpaqueMutableMapping[HashT, EndT] = dataclasses.field( + cache: OpaqueMutableMapping[str, EndT] = dataclasses.field( repr=False, default_factory=dict, metadata=utils.gt4py_metadata(pickle=False) - ) # type: ignore[assignment] + ) def __call__(self, inp: StartT) -> EndT: """Run the step only if the input is not cached, else return from cache.""" @@ -287,7 +287,7 @@ def cache_key(self, inp: StartT) -> str: class SkippableStep( ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT], - utils.FingerprintedDataclass, + utils.CachedFingerprintedDataclass, ): step: Workflow[StartT, EndT] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index d6112b91a6..b669219c3d 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -46,7 +46,7 @@ class GTFNTranslationStep( definitions.CompilableProgramDef, stages.ProgramSource[code_specs.HeaderAndSourceCodeSpec], ], - utils.FingerprintedDataclass, + utils.CachedFingerprintedDataclass, ): code_spec: Optional[code_specs.HeaderAndSourceCodeSpec] = None # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index 1bdd38b6c4..31fb7120a5 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -15,8 +15,8 @@ import gt4py.next.custom_layout_allocators as next_allocators from gt4py._core import definitions as core_defs -from gt4py.next import backend, common, config -from gt4py.next.otf import stages, workflow +from gt4py.next import backend, common, config, utils +from gt4py.next.otf import workflow from gt4py.next.program_processors.runners.dace.workflow.factory import DaCeWorkflowFactory @@ -44,12 +44,12 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: workflow.CachedStep(o.otf_workflow, key_function=o.hash_function) + lambda o: workflow.CachedStep(o.otf_workflow, key_function=o.key_function) ), name_cached="_cached", ) device_type = core_defs.DeviceType.CPU - hash_function = stages.compilation_hash + key_function = utils.fingerprint otf_workflow = factory.SubFactory( DaCeWorkflowFactory, device_type=factory.SelfAttribute("..device_type"), diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 6491516ba7..58580e7921 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -125,7 +125,7 @@ class DaCeCompiler( CompiledDaceProgram, ], definitions.CompilationStep[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - utils.FingerprintedDataclass, + utils.CachedFingerprintedDataclass, ): """Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index bbe3d43c14..8514808ed2 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -14,8 +14,8 @@ import factory from gt4py._core import definitions as core_defs, filecache -from gt4py.next import config -from gt4py.next.otf import recipes, stages, workflow +from gt4py.next import config, utils +from gt4py.next.otf import recipes, workflow from gt4py.next.program_processors.runners.dace.workflow import ( bindings as bindings_step, decoration as decoration_step, @@ -46,7 +46,7 @@ class Params: translation=factory.LazyAttribute( lambda o: workflow.CachedStep( o.bare_translation, - key_function=stages.fingerprint_compilable_program, + key_function=utils.fingerprint, cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "translation_cache")), ) ), diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 35a7f56417..e945095694 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -355,7 +355,7 @@ class DaCeTranslator( stages.ProgramSource[code_specs.SDFGCodeSpec], ], definitions.TranslationStep[code_specs.SDFGCodeSpec], - utils.FingerprintedDataclass, + utils.CachedFingerprintedDataclass, ): device_type: core_defs.DeviceType auto_optimize: bool diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 9d97791f45..31dbc0afdf 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -15,7 +15,7 @@ import gt4py._core.definitions as core_defs import gt4py.next.custom_layout_allocators as next_allocators from gt4py._core import filecache -from gt4py.next import backend, common, config, field_utils +from gt4py.next import backend, common, config, field_utils, utils from gt4py.next.embedded import nd_array_field from gt4py.next.instrumentation import metrics from gt4py.next.otf import recipes, stages, workflow @@ -123,7 +123,7 @@ class Params: translation=factory.LazyAttribute( lambda o: workflow.CachedStep( o.bare_translation, - key_function=stages.fingerprint_compilable_program, + key_function=utils.fingerprint, cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), ) ), @@ -165,12 +165,12 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: workflow.CachedStep(o.otf_workflow, key_function=o.hash_function) + lambda o: workflow.CachedStep(o.otf_workflow, key_function=o.key_function) ), name_cached="_cached", ) device_type = core_defs.DeviceType.CPU - hash_function = stages.compilation_hash + key_function = utils.fingerprint otf_workflow = factory.SubFactory( GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index bd0948d11c..24eb4b346b 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -195,7 +195,7 @@ def fingerprint(self: Self) -> str: fingerprint_reducer = eve_utils.singledispatcher( lambda _: NotImplemented, - {FingerprintedABC: lambda obj: (obj.__class__, (), (obj.fingerprint,))}, + {FingerprintedABC: lambda obj: (type(obj), (), (obj.fingerprint,))}, ) fingerprint_pickler = eve_utils.custom_pickler(fingerprint_reducer, name="FingerprintPickler") @@ -216,34 +216,33 @@ class FingerprintedMixin: __slots__ = () - fingerprinter = staticmethod(fingerprint) + @staticmethod + def fingerprinter(instance: Fingerprinted) -> str: + try: + return fingerprint(instance.__reduce__()) + except AttributeError: + return "" @property def fingerprint(self) -> str: return self.fingerprinter(self) -assert issubclass(FingerprintedMixin, FingerprintedABC) -if TYPE_CHECKING: - _FM: type[Fingerprinted] = FingerprintedMixin - - -class CachedFingerprintedMixin: +class CachedFingerprintedMixin(FingerprintedMixin): """Mixin to add an optimized implementation of the Fingerprinted protocol to frozen classes.""" - fingerprinter = staticmethod(fingerprint) - @(functools.cached_property if not TYPE_CHECKING else property) def fingerprint(self) -> str: return self.fingerprinter(self) +assert issubclass(FingerprintedMixin, FingerprintedABC) assert issubclass(CachedFingerprintedMixin, FingerprintedABC) if TYPE_CHECKING: _CFM: type[Fingerprinted] = CachedFingerprintedMixin -class FingerprintedDataclass(CachedFingerprintedMixin, MetadataBasedPicklingMixin): +class CachedFingerprintedDataclass(CachedFingerprintedMixin, MetadataBasedPicklingMixin): __slots__ = () diff --git a/tests/eve_tests/unit_tests/test_utils.py b/tests/eve_tests/unit_tests/test_utils.py index 2a5a712c6d..9f1d90d91b 100644 --- a/tests/eve_tests/unit_tests/test_utils.py +++ b/tests/eve_tests/unit_tests/test_utils.py @@ -93,6 +93,29 @@ class MyBaseClass(abc.ABC): ) +def test_singledispatcher(): + from gt4py.eve.utils import singledispatcher + + class Base: + pass + + class Derived(Base): + pass + + dispatcher = singledispatcher( + lambda _: "default", + { + Base: lambda _: "base", + Derived: lambda _: "derived", + }, + ) + + assert dispatcher(1) == "default" + assert dispatcher(Base()) == "base" + assert dispatcher(Derived()) == "derived" + assert dispatcher.registry.keys() == {object, Base, Derived} + + class ModelClass(eve.datamodels.DataModel): data: Any diff --git a/tests/next_tests/unit_tests/ffront_tests/test_stages.py b/tests/next_tests/unit_tests/ffront_tests/test_stages.py index 4ca940db83..164c3e5703 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_stages.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_stages.py @@ -6,134 +6,42 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import pytest - -from gt4py import next as gtx from gt4py.next.ffront import stages -from gt4py.next.otf import arguments, toolchain - - -@pytest.fixture -def idim(): - yield gtx.Dimension("I") - - -@pytest.fixture -def jdim(): - yield gtx.Dimension("J") - - -@pytest.fixture -def fieldop(idim): - @gtx.field_operator - def copy(a: gtx.Field[[idim], gtx.int32]) -> gtx.Field[[idim], gtx.int32]: - return a - - yield copy - - -@pytest.fixture -def samecode_fieldop(idim): - @gtx.field_operator - def copy(a: gtx.Field[[idim], gtx.int32]) -> gtx.Field[[idim], gtx.int32]: - return a - - yield copy - - -@pytest.fixture -def different_fieldop(jdim): - @gtx.field_operator - def copy(a: gtx.Field[[jdim], gtx.int32]) -> gtx.Field[[jdim], gtx.int32]: - return a - - yield copy - - -@pytest.fixture -def program(fieldop, idim): - copy = fieldop - - @gtx.program - def copy_program(a: gtx.Field[[idim], gtx.int32], out: gtx.Field[[idim], gtx.int32]): - copy(a, out=out) - - yield copy_program - - -@pytest.fixture -def samecode_program(samecode_fieldop, idim): - copy = samecode_fieldop - - @gtx.program - def copy_program(a: gtx.Field[[idim], gtx.int32], out: gtx.Field[[idim], gtx.int32]): - copy(a, out=out) - - yield copy_program - -@pytest.fixture -def different_program(different_fieldop, jdim): - copy = different_fieldop - @gtx.program - def copy_program(a: gtx.Field[[jdim], gtx.int32], out: gtx.Field[[jdim], gtx.int32]): - copy(a, out=out) +def _make_field_operator_definition(offset: int): + def copy(a): + return a + offset - yield copy_program + return copy -def test_fingerprint_stage_field_op_def(fieldop, samecode_fieldop, different_fieldop): - assert stages.fingerprint_stage(samecode_fieldop.definition_stage) != stages.fingerprint_stage( - fieldop.definition_stage - ) - assert stages.fingerprint_stage(different_fieldop.definition_stage) != stages.fingerprint_stage( - fieldop.definition_stage - ) +def _make_program_definition(offset: int): + def copy_program(a, out): + return a + out + offset + return copy_program -def test_fingerprint_stage_foast_op_def(fieldop, samecode_fieldop, different_fieldop): - foast = gtx.backend.DEFAULT_TRANSFORMS.func_to_foast( - toolchain.ConcreteArtifact(fieldop.definition_stage, arguments.CompileTimeArgs.empty()) - ).data - samecode = gtx.backend.DEFAULT_TRANSFORMS.func_to_foast( - toolchain.ConcreteArtifact( - samecode_fieldop.definition_stage, arguments.CompileTimeArgs.empty() - ) - ).data - different = gtx.backend.DEFAULT_TRANSFORMS.func_to_foast( - toolchain.ConcreteArtifact( - different_fieldop.definition_stage, arguments.CompileTimeArgs.empty() - ) - ).data - assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast) - assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast) +def test_fingerprinter_hashes_functions_by_source_and_closure(): + first = _make_field_operator_definition(1) + same = _make_field_operator_definition(1) + different = _make_field_operator_definition(2) + assert stages.fingerprinter(first) == stages.fingerprinter(same) + assert stages.fingerprinter(first) != stages.fingerprinter(different) -def test_fingerprint_stage_program_def(program, samecode_program, different_program): - assert stages.fingerprint_stage(samecode_program.definition_stage) != stages.fingerprint_stage( - program.definition_stage - ) - assert stages.fingerprint_stage(different_program.definition_stage) != stages.fingerprint_stage( - program.definition_stage - ) +def test_definition_stages_use_the_custom_fingerprinter(): + first_fieldop = stages.DSLFieldOperatorDef(definition=_make_field_operator_definition(1)) + same_fieldop = stages.DSLFieldOperatorDef(definition=_make_field_operator_definition(1)) + different_fieldop = stages.DSLFieldOperatorDef(definition=_make_field_operator_definition(2)) -def test_fingerprint_stage_past_def(program, samecode_program, different_program): - past = gtx.backend.DEFAULT_TRANSFORMS.func_to_past( - toolchain.ConcreteArtifact(program.definition_stage, arguments.CompileTimeArgs.empty()) - ) - samecode = gtx.backend.DEFAULT_TRANSFORMS.func_to_past( - toolchain.ConcreteArtifact( - samecode_program.definition_stage, arguments.CompileTimeArgs.empty() - ) - ) - different = gtx.backend.DEFAULT_TRANSFORMS.func_to_past( - toolchain.ConcreteArtifact( - different_program.definition_stage, arguments.CompileTimeArgs.empty() - ) - ) + first_program = stages.DSLProgramDef(definition=_make_program_definition(1)) + same_program = stages.DSLProgramDef(definition=_make_program_definition(1)) + different_program = stages.DSLProgramDef(definition=_make_program_definition(2)) - assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(past) - assert stages.fingerprint_stage(different) != stages.fingerprint_stage(past) + assert first_fieldop.fingerprint == same_fieldop.fingerprint + assert first_fieldop.fingerprint != different_fieldop.fingerprint + assert first_program.fingerprint == same_program.fingerprint + assert first_program.fingerprint != different_program.fingerprint diff --git a/tests/next_tests/unit_tests/test_utils.py b/tests/next_tests/unit_tests/test_utils.py index ffa774d5f2..ef34325aa0 100644 --- a/tests/next_tests/unit_tests/test_utils.py +++ b/tests/next_tests/unit_tests/test_utils.py @@ -91,6 +91,60 @@ def test_pickle_roundtrip(self, instance, expected_fields): assert getattr(restored, field_name) == expected_value +class _DuckFingerprinted: + @property + def fingerprint(self): + return "duck" + + @property + def fingerprinter(self): + return lambda _: "duck" + + +def test_fingerprintedabc_subclasshook_recognizes_duck_types(): + assert issubclass(_DuckFingerprinted, utils.FingerprintedABC) + + +class _TestFingerprinted(utils.FingerprintedABC): + def __init__(self, value: str, noise: str): + self.value = value + self.noise = noise + + @staticmethod + def fingerprinter(instance: "_TestFingerprinted") -> str: + return instance.value + + +def test_fingerprint_uses_fingerprinted_state_only(): + a = _TestFingerprinted("id", noise="left") + b = _TestFingerprinted("id", noise="right") + c = _TestFingerprinted("other", noise="left") + + assert utils.fingerprint(a) == utils.fingerprint(b) + assert utils.fingerprint(a) != utils.fingerprint(c) + + +class _CachedTestFingerprinted(utils.CachedFingerprintedMixin): + calls = 0 + + def __init__(self, value: str): + self.value = value + + @staticmethod + def fingerprinter(instance: "_CachedTestFingerprinted") -> str: + _CachedTestFingerprinted.calls += 1 + return instance.value + + +def test_cached_fingerprinted_mixin_computes_fingerprint_once_per_instance(): + _CachedTestFingerprinted.calls = 0 + instance = _CachedTestFingerprinted("stable") + + assert instance.fingerprint == "stable" + assert instance.fingerprint == "stable" + assert _CachedTestFingerprinted.calls == 1 + + def test_tree_map_default(): @utils.tree_map def testee(x): From 51c8eaaf9ca92f6ba1c6337f794fdfbbdb57b20f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Tue, 12 May 2026 09:38:06 +0200 Subject: [PATCH 4/4] Add build cache version id --- src/gt4py/next/config.py | 12 ++++++++++++ .../runners/dace/workflow/factory.py | 2 +- src/gt4py/next/program_processors/runners/gtfn.py | 2 +- src/gt4py/next/utils.py | 6 ++++++ 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index 03f84a50c1..4558209911 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -90,6 +90,18 @@ def env_flag_to_int(name: str, default: int) -> int: ] +def _get_build_cache_version_id() -> str: + from gt4py import __version__ + + return __version__ + + +#: Version ID for the build cache. It should only be overridden by advanced users +#: testing toolchain changes that are expected to break compatibility with previously cached builds. +BUILD_CACHE_VERSION_ID: str = ( + os.environ.get("BUILD_CACHE_VERSION_ID") or _get_build_cache_version_id() +) + #: Build type to be used when CMake is used to compile generated code. #: Might have no effect when CMake is not used as part of the toolchain. # FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 8514808ed2..bc2f98aea6 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -46,7 +46,7 @@ class Params: translation=factory.LazyAttribute( lambda o: workflow.CachedStep( o.bare_translation, - key_function=utils.fingerprint, + key_function=utils.versioned_fingerprint, cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "translation_cache")), ) ), diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 31dbc0afdf..da25510854 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -123,7 +123,7 @@ class Params: translation=factory.LazyAttribute( lambda o: workflow.CachedStep( o.bare_translation, - key_function=utils.fingerprint, + key_function=utils.versioned_fingerprint, cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), ) ), diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 24eb4b346b..62887335a5 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -34,6 +34,7 @@ from gt4py.eve import datamodels, utils as eve_utils from gt4py.eve.extended_typing import Self +from gt4py.next import config GT4PY_CLASS_METADATA_NS: Final[str] = "GT4PY_META" @@ -199,6 +200,7 @@ def fingerprint(self: Self) -> str: ) fingerprint_pickler = eve_utils.custom_pickler(fingerprint_reducer, name="FingerprintPickler") + fingerprint: Callable[[Any], str] = functools.partial( eve_utils.content_hash, pickler=fingerprint_pickler ) @@ -210,6 +212,10 @@ def fingerprint(self: Self) -> str: instead use the `fingerprint` property of the object as its content hash. """ +versioned_fingerprint: Callable[[Any], str] = functools.partial( + eve_utils.content_hash, config.BUILD_CACHE_VERSION_ID, pickler=fingerprint_pickler +) + class FingerprintedMixin: """General mixin to add support for the Fingerprinted protocol to any class."""