diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 350f6ac5e8..fc8d734cc9 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -13,8 +13,8 @@ - `gt4py.next` supports structured and unstructured grid. """ +from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__ # noqa: I001 from . import eve, storage -from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__ __all__ = [ diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 8da5cdfd22..f00e394c9a 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -630,6 +630,16 @@ def is_noninstantiable(cls: Type[_T]) -> bool: return "__noninstantiable__" in cls.__dict__ +def singledispatcher( + default: Callable[P, T], + implementations: dict[type, Callable[[Any], Any]], +) -> Callable[P, T]: + result = functools.singledispatch(default) + for cls, func in implementations.items(): + result.register(cls)(func) + return result + + def content_hash( *args: Any, hash_algorithm: str | xtyping.HashlibAlgorithm | None = None, @@ -669,6 +679,7 @@ def content_hash( def custom_pickler( reducer: Callable[[Any], tuple | types.NotImplementedType], + *, name: str | None = None, ) -> type[pickle.Pickler]: """ @@ -685,7 +696,9 @@ def custom_pickler( def custom_pickler_from_reducers( custom_reducers: dict[type, Callable[[Any], tuple | types.NotImplementedType]], + *, name: str | None = None, + default_reducer: Callable[[Any], tuple | types.NotImplementedType] = lambda _: NotImplemented, ) -> type[pickle.Pickler]: """ Create a pickler with the provided reducers registered in reducer override. @@ -696,13 +709,8 @@ def custom_pickler_from_reducers( the `dispatch_table` dict, to allow easy pickle customization of entire class hierarchies. 
""" - reducer = functools.singledispatch( - cast(Callable[[Any], tuple | types.NotImplementedType], lambda _: NotImplemented) - ) - for cls, func in custom_reducers.items(): - reducer.register(cls)(func) - return custom_pickler(reducer, name=name) + return custom_pickler(singledispatcher(default_reducer, custom_reducers), name=name) ddiff = deepdiff.diff.DeepDiff diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 6123da97e0..80bc7e263a 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -12,7 +12,7 @@ from typing import Generic from gt4py._core import definitions as core_defs -from gt4py.next import custom_layout_allocators as next_allocators +from gt4py.next import custom_layout_allocators as next_allocators, utils from gt4py.next.ffront import ( foast_to_gtir, foast_to_past, @@ -145,7 +145,7 @@ def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]: # `transforms` -> `frontend_transforms` # `executor` -> `backend_transforms` @dataclasses.dataclass(frozen=True) -class Backend(Generic[core_defs.DeviceTypeT]): +class Backend(utils.CachedFingerprintedDataclass, Generic[core_defs.DeviceTypeT]): name: str executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram] allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index 03f84a50c1..4558209911 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -90,6 +90,18 @@ def env_flag_to_int(name: str, default: int) -> int: ] +def _get_build_cache_version_id() -> str: + from gt4py import __version__ + + return __version__ + + +#: Version ID for the build cache. It should only be overridden by advanced users +#: testing toolchain changes that are expected to break compatibility with previously cached builds. 
+BUILD_CACHE_VERSION_ID: str = ( + os.environ.get("BUILD_CACHE_VERSION_ID") or _get_build_cache_version_id() +) + #: Build type to be used when CMake is used to compile generated code. #: Might have no effect when CMake is not used as part of the toolchain. # FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key. diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e9aff84a15..fe8d06b450 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -106,7 +106,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: class NdArrayField( common.MutableField[common.DimsT, core_defs.ScalarT], common.FieldBuiltinFuncRegistry, - utils.MetadataBasedPickling, + utils.MetadataBasedPicklingMixin, ): """ Shared field implementation for NumPy-like fields. diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index c749fcec01..bc59ce8069 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -83,7 +83,9 @@ def embedded_program_call_context( @dataclasses.dataclass(frozen=True) -class _CompilableGTEntryPointMixin(Generic[ffront_stages.DSLDefinitionT]): +class _CompilableGTEntryPointMixin( + utils.CachedFingerprintedDataclass, Generic[ffront_stages.DSLDefinitionT] +): """ Mixing used by program and program-like objects. 
@@ -862,23 +864,3 @@ def scan_operator_inner(definition: Callable) -> FieldOperator: ) return scan_operator_inner if definition is None else scan_operator_inner(definition) - - -@ffront_stages.add_content_to_fingerprint.register -def add_fieldop_to_fingerprint(obj: FieldOperator, hasher: xtyping.HashlibAlgorithm) -> None: - ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher) - ffront_stages.add_content_to_fingerprint(obj.backend, hasher) - - -@ffront_stages.add_content_to_fingerprint.register -def add_foast_fieldop_to_fingerprint( - obj: FieldOperatorFromFoast, hasher: xtyping.HashlibAlgorithm -) -> None: - ffront_stages.add_content_to_fingerprint(obj.foast_stage, hasher) - ffront_stages.add_content_to_fingerprint(obj.backend, hasher) - - -@ffront_stages.add_content_to_fingerprint.register -def add_program_to_fingerprint(obj: Program, hasher: xtyping.HashlibAlgorithm) -> None: - ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher) - ffront_stages.add_content_to_fingerprint(obj.backend, hasher) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3825072cb7..633e1b2cfe 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -47,7 +47,7 @@ def foast_to_gtir_factory( """Wrap `foast_to_gtir` into a chainable and, optionally, cached workflow step.""" wf = foast_to_gtir if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(step=wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 05b080b70b..5507011ad2 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -190,5 +190,5 @@ def operator_to_program_factory( foast_to_itir_step or foast_to_gtir.adapted_foast_to_gtir_factory() ) if cached: - wf = workflow.CachedStep(wf, 
hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ced0ff3905..89fdb5a043 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -99,7 +99,7 @@ def func_to_foast_factory( """Wrap `func_to_foast` in a chainable and optionally cached workflow step.""" wf = workflow.make_step(func_to_foast) if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(step=wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index ebc0de31b3..d026774d77 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -80,7 +80,7 @@ def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSLProgramDef """ wf = workflow.make_step(func_to_past) if cached: - wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/past_passes/linters.py b/src/gt4py/next/ffront/past_passes/linters.py index 6d9fb9123b..3008ff7e99 100644 --- a/src/gt4py/next/ffront/past_passes/linters.py +++ b/src/gt4py/next/ffront/past_passes/linters.py @@ -51,7 +51,7 @@ def linter_factory( ) -> workflow.Workflow[PASTProgramDef, PASTProgramDef]: wf = lint_misnamed_functions.chain(lint_undefined_symbols) if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(step=wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index ce794fd9dc..b014ae3007 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ 
b/src/gt4py/next/ffront/past_process_args.py @@ -41,7 +41,7 @@ def transform_program_args_factory( ) -> workflow.Workflow[ffront_stages.ConcretePASTProgramDef, ffront_stages.ConcretePASTProgramDef]: wf = transform_program_args if cached: - wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 6b19e7cc1f..4292f80cf3 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -152,7 +152,7 @@ def past_to_gtir_factory( ) -> workflow.Workflow[ConcretePASTProgramDef, definitions.CompilableProgramDef]: wf = workflow.make_step(past_to_gtir) if cached: - wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) + wf = workflow.CachedStep(wf, key_function=ffront_stages.fingerprinter) return wf diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index da05c68669..d0a6700ace 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -22,24 +22,51 @@ from __future__ import annotations -import collections.abc import dataclasses import functools -import hashlib import types import typing from typing import Any, Optional, TypeVar -import xxhash - -from gt4py.eve import extended_typing as xtyping -from gt4py.next import common +from gt4py.eve import utils as eve_utils +from gt4py.next import common, utils from gt4py.next.ffront import field_operator_ast as foast, program_ast as past, source_utils from gt4py.next.otf import arguments, toolchain +# Create a custom pickler for the BaseStage `fingerprinter` that handles +# `types.FunctionType` by using its source code and closure variables. +# This should be enough for the use case of GT4Py DSL definitions, +# which are expected to be pure functions without complicated closures. 
+fingerprint_reducer = eve_utils.singledispatcher( + utils.fingerprint_reducer, + { + types.FunctionType: lambda f: ( + tuple, + (), + ( + source_utils.make_source_definition_from_function(f), + source_utils.get_closure_vars_from_function(f), + ), + ) + }, +) + +fingerprinter = functools.partial( + eve_utils.content_hash, + pickler=eve_utils.custom_pickler(fingerprint_reducer, name="FFrontFingerprintPickler"), +) + + +@dataclasses.dataclass(frozen=True) +class BaseStage(utils.CachedFingerprintedDataclass): + """Base class for optimized Fingerprinted implementations in frozen dataclasses.""" + + fingerprinter = staticmethod(fingerprinter) + + @dataclasses.dataclass(frozen=True) -class DSLFieldOperatorDef: +class DSLFieldOperatorDef(BaseStage): definition: types.FunctionType node_class: type[foast.OperatorNode] = foast.FieldOperator attributes: dict[str, Any] = dataclasses.field(default_factory=dict) @@ -53,7 +80,7 @@ class DSLFieldOperatorDef: @dataclasses.dataclass(frozen=True) -class FOASTOperatorDef: +class FOASTOperatorDef(BaseStage): foast_node: foast.OperatorNode closure_vars: dict[str, Any] grid_type: Optional[common.GridType] = None @@ -67,7 +94,7 @@ class FOASTOperatorDef: @dataclasses.dataclass(frozen=True) -class DSLProgramDef: +class DSLProgramDef(BaseStage): definition: types.FunctionType grid_type: Optional[common.GridType] = None debug: bool = False @@ -79,7 +106,7 @@ class DSLProgramDef: @dataclasses.dataclass(frozen=True) -class PASTProgramDef: +class PASTProgramDef(BaseStage): past_node: past.Program closure_vars: dict[str, Any] grid_type: Optional[common.GridType] = None @@ -92,88 +119,3 @@ class PASTProgramDef: DSLDefinition = DSLFieldOperatorDef | DSLProgramDef DSLDefinitionT = TypeVar("DSLDefinitionT", DSLFieldOperatorDef, DSLProgramDef) - - -def fingerprint_stage(obj: Any, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None) -> str: - hasher: xtyping.HashlibAlgorithm - if not algorithm: - hasher = xxhash.xxh64() # type: 
ignore[assignment] # fixing this requires https://github.com/ifduyue/python-xxhash/issues/104 - elif isinstance(algorithm, str): - hasher = hashlib.new(algorithm) - else: - hasher = algorithm - - add_content_to_fingerprint(obj, hasher) - return hasher.hexdigest() - - -@functools.singledispatch -def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None: - hasher.update(str(obj).encode()) - - -for t in (str, int): - add_content_to_fingerprint.register(t, add_content_to_fingerprint.registry[object]) - - -@add_content_to_fingerprint.register(DSLFieldOperatorDef) -@add_content_to_fingerprint.register(FOASTOperatorDef) -@add_content_to_fingerprint.register(DSLProgramDef) -@add_content_to_fingerprint.register(PASTProgramDef) -@add_content_to_fingerprint.register(toolchain.ConcreteArtifact) -@add_content_to_fingerprint.register(arguments.CompileTimeArgs) -def add_stage_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None: - add_content_to_fingerprint(obj.__class__, hasher) - for field in dataclasses.fields(obj): - add_content_to_fingerprint(getattr(obj, field.name), hasher) - - -def add_jit_args_id_to_fingerprint( - obj: arguments.JITArgs, hasher: xtyping.HashlibAlgorithm -) -> None: - add_content_to_fingerprint(str(id(obj)), hasher) - - -@add_content_to_fingerprint.register -def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgorithm) -> None: - sourcedef = source_utils.SourceDefinition.from_function(obj) - for item in sourcedef: - add_content_to_fingerprint(item, hasher) - - closure_vars = source_utils.get_closure_vars_from_function(obj) - for item in sorted(closure_vars.items(), key=lambda x: x[0]): - add_content_to_fingerprint(item, hasher) - - -@add_content_to_fingerprint.register -def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: - # just a small helper to additionally allow sorting types (by just using their name) - def key_function(key: Any) -> Any: - if isinstance(key, 
type): - return key.__module__, key.__qualname__ - return key - - for key in sorted(obj.keys(), key=key_function): - add_content_to_fingerprint(key, hasher) - add_content_to_fingerprint(obj[key], hasher) - - -@add_content_to_fingerprint.register -def add_type_to_fingerprint(obj: type, hasher: xtyping.HashlibAlgorithm) -> None: - hasher.update(obj.__name__.encode()) - - -@add_content_to_fingerprint.register -def add_sequence_to_fingerprint( - obj: collections.abc.Iterable, hasher: xtyping.HashlibAlgorithm -) -> None: - for item in obj: - add_content_to_fingerprint(item, hasher) - - -@add_content_to_fingerprint.register -def add_foast_located_node_to_fingerprint( - obj: foast.LocatedNode, hasher: xtyping.HashlibAlgorithm -) -> None: - add_content_to_fingerprint(obj.location, hasher) - add_content_to_fingerprint(str(obj), hasher) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 9d8953ce5a..6cdeddb63d 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -140,7 +140,7 @@ def from_signature(cls, *args: Any, **kwargs: Any) -> Self: @dataclasses.dataclass(frozen=True) -class CompileTimeArgs: +class CompileTimeArgs(utils.CachedFingerprintedDataclass): """Compile-time standins for arguments to a GTX program to be used in ahead-of-time compilation.""" args: tuple[ts.TypeSpec, ...] 
@@ -175,6 +175,19 @@ def from_concrete(cls, *args: Any, **kwargs: Any) -> Self: def empty(cls) -> Self: return cls(tuple(), {}, {}, None, {}) + @staticmethod + def fingerprinter(instance: utils.Fingerprinted) -> str: + assert isinstance(instance, CompileTimeArgs) + return utils.fingerprint( + ( + instance.args, + sorted(instance.kwargs.items()), + sorted(instance.offset_provider.items()), + instance.column_axis, + instance.argument_descriptor_contexts, + ) + ) + # This is not really accurate, just an approximation NeedsValueExtraction: TypeAlias = MaybeNestedInTuple[named_collections.CustomNamedCollection] diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 3748d95192..9dc62b4682 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -15,7 +15,7 @@ import factory from gt4py._core import locking -from gt4py.next import config +from gt4py.next import config, utils from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import build_data, cache, importer @@ -55,6 +55,7 @@ class Compiler( stages.ExecutableProgram, ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], + utils.CachedFingerprintedDataclass, ): """Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index b6816b1cc3..ec61a4c2fe 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -12,54 +12,17 @@ from collections.abc import Callable from typing import Generic, Optional, Protocol, TypeAlias, TypeVar -from gt4py.eve import utils -from gt4py.next import common -from gt4py.next.iterator import ir as itir -from gt4py.next.otf import code_specs, definitions +from gt4py.next import utils +from gt4py.next.otf import code_specs from gt4py.next.otf.binding import interface 
-def compilation_hash(program_def: definitions.CompilableProgramDef) -> int: - """Given closure compute a hash uniquely determining if we need to recompile.""" - offset_provider = program_def.args.offset_provider - return hash( - ( - program_def.data, - # As the frontend types contain lists they are not hashable. As a workaround we just - # use content_hash here. - utils.content_hash(tuple(arg for arg in program_def.args.args)), - common.hash_offset_provider_items_by_id(offset_provider) if offset_provider else None, - program_def.args.column_axis, - ) - ) - - -def fingerprint_compilable_program(program_def: definitions.CompilableProgramDef) -> str: - """ - Generates a unique hash string for a stencil source program representing - the program, sorted offset_provider, and column_axis. - """ - program: itir.Program = program_def.data - offset_provider: common.OffsetProvider = program_def.args.offset_provider - column_axis: Optional[common.Dimension] = program_def.args.column_axis - - program_hash = utils.content_hash( - ( - program.fingerprint(), - sorted(offset_provider.items(), key=lambda el: el[0]), - column_axis, - ) - ) - - return program_hash - - CodeSpecT = TypeVar("CodeSpecT", bound=code_specs.SourceCodeSpec) TargetCodeSpecT = TypeVar("TargetCodeSpecT", bound=code_specs.SourceCodeSpec) @dataclasses.dataclass(frozen=True) -class ProgramSource(Generic[CodeSpecT]): +class ProgramSource(utils.CachedFingerprintedDataclass, Generic[CodeSpecT]): """ Standalone source code translated from an IR along with information relevant for OTF compilation. @@ -76,7 +39,7 @@ class ProgramSource(Generic[CodeSpecT]): @dataclasses.dataclass(frozen=True) -class BindingSource(Generic[CodeSpecT, TargetCodeSpecT]): +class BindingSource(utils.CachedFingerprintedDataclass, Generic[CodeSpecT, TargetCodeSpecT]): """ Companion source code for translated program source code. 
@@ -92,7 +55,7 @@ class BindingSource(Generic[CodeSpecT, TargetCodeSpecT]): # TODO(ricoh): reconsider name in view of future backends producing standalone compilable ProgramSource code @dataclasses.dataclass(frozen=True) -class CompilableProject(Generic[CodeSpecT, TargetCodeSpecT]): +class CompilableProject(utils.CachedFingerprintedDataclass, Generic[CodeSpecT, TargetCodeSpecT]): """ Encapsulate all the source code required for OTF compilation. diff --git a/src/gt4py/next/otf/toolchain.py b/src/gt4py/next/otf/toolchain.py index 0c816759ff..1c24a99fe8 100644 --- a/src/gt4py/next/otf/toolchain.py +++ b/src/gt4py/next/otf/toolchain.py @@ -12,6 +12,7 @@ import typing from typing import Generic +from gt4py.next import utils from gt4py.next.otf import workflow @@ -22,7 +23,7 @@ @dataclasses.dataclass -class ConcreteArtifact(Generic[DefT, ArgsT]): +class ConcreteArtifact(utils.CachedFingerprintedDataclass, Generic[DefT, ArgsT]): data: DefT args: ArgsT @@ -32,6 +33,7 @@ class DataOnlyAdapter( workflow.ChainableWorkflowMixin, workflow.ReplaceEnabledWorkflowMixin, workflow.Workflow[ConcreteArtifact[S, ArgsT], ConcreteArtifact[T, ArgsT]], + utils.CachedFingerprintedDataclass, Generic[ArgsT, S, T], ): step: workflow.Workflow[S, T] @@ -45,6 +47,7 @@ class ArgsOnlyAdapter( workflow.ChainableWorkflowMixin, workflow.ReplaceEnabledWorkflowMixin, workflow.Workflow[ConcreteArtifact[DefT, S], ConcreteArtifact[DefT, T]], + utils.CachedFingerprintedDataclass, Generic[DefT, S, T], ): step: workflow.Workflow[S, T] @@ -58,6 +61,7 @@ class StripArgsAdapter( workflow.ChainableWorkflowMixin, workflow.ReplaceEnabledWorkflowMixin, workflow.Workflow[ConcreteArtifact[S, ArgsT], T], + utils.CachedFingerprintedDataclass, Generic[ArgsT, S, T], ): step: workflow.Workflow[S, T] diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 08abadd7e5..0d9f5ef561 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -17,6 +17,7 @@ from typing_extensions 
import Self from gt4py.eve.extended_typing import OpaqueMutableMapping +from gt4py.next import utils StartT = TypeVar("StartT") @@ -92,7 +93,9 @@ def chain( @dataclasses.dataclass(frozen=True) class NamedStepSequence( - ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT] + ChainableWorkflowMixin[StartT, EndT], + ReplaceEnabledWorkflowMixin[StartT, EndT], + utils.CachedFingerprintedDataclass, ): """ Workflow with linear succession of named steps. @@ -159,7 +162,9 @@ def step_order(self) -> list[str]: @dataclasses.dataclass(frozen=True) class MultiWorkflow( - ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT] + ChainableWorkflowMixin[StartT, EndT], + ReplaceEnabledWorkflowMixin[StartT, EndT], + utils.CachedFingerprintedDataclass, ): """A flexible workflow, where the sequence of steps depends on the input type.""" @@ -175,7 +180,10 @@ def step_order(self, inp: StartT) -> list[str]: @dataclasses.dataclass(frozen=True) -class StepSequence(ChainableWorkflowMixin[StartT, EndT]): +class StepSequence( + ChainableWorkflowMixin[StartT, EndT], + utils.CachedFingerprintedDataclass, +): """ Composable workflow of single input callables. 
@@ -227,6 +235,7 @@ def start(cls, first_step: Workflow[StartT, EndT]) -> ChainableWorkflowMixin[Sta class CachedStep( ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT], + utils.CachedFingerprintedDataclass, Generic[StartT, EndT, HashT], ): """ @@ -254,22 +263,31 @@ class CachedStep( """ step: Workflow[StartT, EndT] - hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment] - cache: OpaqueMutableMapping[HashT, EndT] = dataclasses.field(repr=False, default_factory=dict) + key_function: Callable[[StartT], HashT] = dataclasses.field( + default=utils.fingerprint, metadata=utils.gt4py_metadata(pickle=False) + ) # type: ignore[assignment] + cache: OpaqueMutableMapping[str, EndT] = dataclasses.field( + repr=False, default_factory=dict, metadata=utils.gt4py_metadata(pickle=False) + ) def __call__(self, inp: StartT) -> EndT: """Run the step only if the input is not cached, else return from cache.""" - hash_ = self.hash_function(inp) + hash_ = self.cache_key(inp) try: result = self.cache[hash_] except KeyError: result = self.cache[hash_] = self.step(inp) return result + def cache_key(self, inp: StartT) -> str: + return utils.fingerprint((self, self.key_function(inp))) + @dataclasses.dataclass(frozen=True) class SkippableStep( - ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT] + ChainableWorkflowMixin[StartT, EndT], + ReplaceEnabledWorkflowMixin[StartT, EndT], + utils.CachedFingerprintedDataclass, ): step: Workflow[StartT, EndT] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 2135af7fbb..b669219c3d 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -17,7 +17,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import codegen -from gt4py.next import common +from 
gt4py.next import common, utils from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import pass_manager @@ -46,6 +46,7 @@ class GTFNTranslationStep( definitions.CompilableProgramDef, stages.ProgramSource[code_specs.HeaderAndSourceCodeSpec], ], + utils.CachedFingerprintedDataclass, ): code_spec: Optional[code_specs.HeaderAndSourceCodeSpec] = None # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index de6778a750..31fb7120a5 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -15,8 +15,8 @@ import gt4py.next.custom_layout_allocators as next_allocators from gt4py._core import definitions as core_defs -from gt4py.next import backend, common, config -from gt4py.next.otf import stages, workflow +from gt4py.next import backend, common, config, utils +from gt4py.next.otf import workflow from gt4py.next.program_processors.runners.dace.workflow.factory import DaCeWorkflowFactory @@ -44,12 +44,12 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) + lambda o: workflow.CachedStep(o.otf_workflow, key_function=o.key_function) ), name_cached="_cached", ) device_type = core_defs.DeviceType.CPU - hash_function = stages.compilation_hash + key_function = utils.fingerprint otf_workflow = factory.SubFactory( DaCeWorkflowFactory, device_type=factory.SelfAttribute("..device_type"), diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index e1747b7ac3..58580e7921 100644 --- 
a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -18,7 +18,7 @@ import factory from gt4py._core import definitions as core_defs, locking -from gt4py.next import common, config +from gt4py.next import common, config, utils from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import cache as gtx_cache from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon @@ -125,6 +125,7 @@ class DaCeCompiler( CompiledDaceProgram, ], definitions.CompilationStep[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], + utils.CachedFingerprintedDataclass, ): """Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 62febd0965..bc2f98aea6 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -14,8 +14,8 @@ import factory from gt4py._core import definitions as core_defs, filecache -from gt4py.next import config -from gt4py.next.otf import recipes, stages, workflow +from gt4py.next import config, utils +from gt4py.next.otf import recipes, workflow from gt4py.next.program_processors.runners.dace.workflow import ( bindings as bindings_step, decoration as decoration_step, @@ -46,7 +46,7 @@ class Params: translation=factory.LazyAttribute( lambda o: workflow.CachedStep( o.bare_translation, - hash_function=stages.fingerprint_compilable_program, + key_function=utils.versioned_fingerprint, cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "translation_cache")), ) ), diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py 
b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index ad8e8ea04b..e945095694 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -15,7 +15,7 @@ import factory from gt4py._core import definitions as core_defs -from gt4py.next import common, config +from gt4py.next import common, config, utils from gt4py.next.instrumentation import metrics from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import code_specs, definitions, stages, workflow @@ -355,6 +355,7 @@ class DaCeTranslator( stages.ProgramSource[code_specs.SDFGCodeSpec], ], definitions.TranslationStep[code_specs.SDFGCodeSpec], + utils.CachedFingerprintedDataclass, ): device_type: core_defs.DeviceType auto_optimize: bool diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index c1743dea6a..da25510854 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -15,7 +15,7 @@ import gt4py._core.definitions as core_defs import gt4py.next.custom_layout_allocators as next_allocators from gt4py._core import filecache -from gt4py.next import backend, common, config, field_utils +from gt4py.next import backend, common, config, field_utils, utils from gt4py.next.embedded import nd_array_field from gt4py.next.instrumentation import metrics from gt4py.next.otf import recipes, stages, workflow @@ -123,7 +123,7 @@ class Params: translation=factory.LazyAttribute( lambda o: workflow.CachedStep( o.bare_translation, - hash_function=stages.fingerprint_compilable_program, + key_function=utils.versioned_fingerprint, cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), ) ), @@ -165,12 +165,12 @@ class Params: ) cached = factory.Trait( executor=factory.LazyAttribute( - lambda o: workflow.CachedStep(o.otf_workflow, 
hash_function=o.hash_function) + lambda o: workflow.CachedStep(o.otf_workflow, key_function=o.key_function) ), name_cached="_cached", ) device_type = core_defs.DeviceType.CPU - hash_function = stages.compilation_hash + key_function = utils.fingerprint otf_workflow = factory.SubFactory( GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 942bdb8459..62887335a5 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -8,6 +8,7 @@ from __future__ import annotations +import abc import copyreg import dataclasses import functools @@ -15,12 +16,14 @@ import itertools import types from typing import ( + TYPE_CHECKING, Any, Callable, ClassVar, Final, Optional, ParamSpec, + Protocol, Sequence, TypeAlias, TypeGuard, @@ -30,6 +33,8 @@ ) from gt4py.eve import datamodels, utils as eve_utils +from gt4py.eve.extended_typing import Self +from gt4py.next import config GT4PY_CLASS_METADATA_NS: Final[str] = "GT4PY_META" @@ -111,7 +116,7 @@ def __getstate__(self: object) -> _StandardPickleState: return __getstate__ -class MetadataBasedPickling: +class MetadataBasedPicklingMixin: """ Mixin for adding metadata-based pickling to dataclass-like objects. @@ -138,6 +143,115 @@ def __getstate__(self) -> _StandardPickleState: # implementation. +class Fingerprinted(Protocol): + """ + Protocol for objects that can be fingerprinted with a custom function. + + The fingerprint should be a stable hash string representing the state of the object. + """ + + __slots__ = () + + @property + def fingerprint(self) -> str: + """Get the fingerprint of the object.""" + ... + + @property + def fingerprinter(self) -> Callable[[Fingerprinted], str]: + """Get the fingerprinting function for the object.""" + ... + + +class FingerprintedABC(abc.ABC): + """ + ABC of objects implementing the fingerprinting protocol. 
+ + It provides a custom subclass hook to recognize classes implementing the + protocol without inheriting from the ABC, without the performance problems + of using `isinstance` checks on runtime-checkable protocols directly. + """ + + __slots__ = () + + @classmethod + def __subclasshook__(cls, subclass: type) -> bool | types.NotImplementedType: + if ( + cls is FingerprintedABC + and hasattr(subclass, "fingerprint") + and hasattr(subclass, "fingerprinter") + ): + return True + + return NotImplemented + + @staticmethod + @abc.abstractmethod + def fingerprinter(instance: FingerprintedABC) -> str: ... # to be implemented by subclasses + + @property + def fingerprint(self: Self) -> str: + return self.fingerprinter(self) + + +fingerprint_reducer = eve_utils.singledispatcher( + lambda _: NotImplemented, + {FingerprintedABC: lambda obj: (type(obj), (), (obj.fingerprint,))}, +) +fingerprint_pickler = eve_utils.custom_pickler(fingerprint_reducer, name="FingerprintPickler") + + +fingerprint: Callable[[Any], str] = functools.partial( + eve_utils.content_hash, pickler=fingerprint_pickler +) +""" +Default fingerprinting function for GT4Py objects. + +It uses `eve_utils.content_hash` as fingerprinting function. If the object +is an instance of a class implementing the `Fingerprinted` protocol, it will +instead use the `fingerprint` property of the object as its content hash. 
+""" + +versioned_fingerprint: Callable[[Any], str] = functools.partial( + eve_utils.content_hash, config.BUILD_CACHE_VERSION_ID, pickler=fingerprint_pickler +) + + +class FingerprintedMixin: + """General mixin to add support for the Fingerprinted protocol to any class.""" + + __slots__ = () + + @staticmethod + def fingerprinter(instance: Fingerprinted) -> str: + try: + return fingerprint(instance.__reduce__()) + except AttributeError: + return "" + + @property + def fingerprint(self) -> str: + return self.fingerprinter(self) + + +class CachedFingerprintedMixin(FingerprintedMixin): + """Mixin to add an optimized implementation of the Fingerprinted protocol to frozen classes.""" + + @(functools.cached_property if not TYPE_CHECKING else property) + def fingerprint(self) -> str: + return self.fingerprinter(self) + + +assert issubclass(FingerprintedMixin, FingerprintedABC) +assert issubclass(CachedFingerprintedMixin, FingerprintedABC) +if TYPE_CHECKING: + _CFM: type[Fingerprinted] = CachedFingerprintedMixin + + +class CachedFingerprintedDataclass(CachedFingerprintedMixin, MetadataBasedPicklingMixin): + __slots__ = () + + class RecursionGuard: """ Context manager to guard against inifinite recursion. 
diff --git a/tests/eve_tests/unit_tests/test_utils.py b/tests/eve_tests/unit_tests/test_utils.py index 2a5a712c6d..9f1d90d91b 100644 --- a/tests/eve_tests/unit_tests/test_utils.py +++ b/tests/eve_tests/unit_tests/test_utils.py @@ -93,6 +93,29 @@ class MyBaseClass(abc.ABC): ) +def test_singledispatcher(): + from gt4py.eve.utils import singledispatcher + + class Base: + pass + + class Derived(Base): + pass + + dispatcher = singledispatcher( + lambda _: "default", + { + Base: lambda _: "base", + Derived: lambda _: "derived", + }, + ) + + assert dispatcher(1) == "default" + assert dispatcher(Base()) == "base" + assert dispatcher(Derived()) == "derived" + assert dispatcher.registry.keys() == {object, Base, Derived} + + class ModelClass(eve.datamodels.DataModel): data: Any diff --git a/tests/next_tests/unit_tests/ffront_tests/test_stages.py b/tests/next_tests/unit_tests/ffront_tests/test_stages.py index 4ca940db83..164c3e5703 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_stages.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_stages.py @@ -6,134 +6,42 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -import pytest - -from gt4py import next as gtx from gt4py.next.ffront import stages -from gt4py.next.otf import arguments, toolchain - - -@pytest.fixture -def idim(): - yield gtx.Dimension("I") - - -@pytest.fixture -def jdim(): - yield gtx.Dimension("J") - - -@pytest.fixture -def fieldop(idim): - @gtx.field_operator - def copy(a: gtx.Field[[idim], gtx.int32]) -> gtx.Field[[idim], gtx.int32]: - return a - - yield copy - - -@pytest.fixture -def samecode_fieldop(idim): - @gtx.field_operator - def copy(a: gtx.Field[[idim], gtx.int32]) -> gtx.Field[[idim], gtx.int32]: - return a - - yield copy - - -@pytest.fixture -def different_fieldop(jdim): - @gtx.field_operator - def copy(a: gtx.Field[[jdim], gtx.int32]) -> gtx.Field[[jdim], gtx.int32]: - return a - - yield copy - - -@pytest.fixture -def program(fieldop, idim): - copy = fieldop - - @gtx.program - def copy_program(a: gtx.Field[[idim], gtx.int32], out: gtx.Field[[idim], gtx.int32]): - copy(a, out=out) - - yield copy_program - - -@pytest.fixture -def samecode_program(samecode_fieldop, idim): - copy = samecode_fieldop - - @gtx.program - def copy_program(a: gtx.Field[[idim], gtx.int32], out: gtx.Field[[idim], gtx.int32]): - copy(a, out=out) - - yield copy_program - -@pytest.fixture -def different_program(different_fieldop, jdim): - copy = different_fieldop - @gtx.program - def copy_program(a: gtx.Field[[jdim], gtx.int32], out: gtx.Field[[jdim], gtx.int32]): - copy(a, out=out) +def _make_field_operator_definition(offset: int): + def copy(a): + return a + offset - yield copy_program + return copy -def test_fingerprint_stage_field_op_def(fieldop, samecode_fieldop, different_fieldop): - assert stages.fingerprint_stage(samecode_fieldop.definition_stage) != stages.fingerprint_stage( - fieldop.definition_stage - ) - assert stages.fingerprint_stage(different_fieldop.definition_stage) != stages.fingerprint_stage( - fieldop.definition_stage - ) +def _make_program_definition(offset: int): + 
def copy_program(a, out): + return a + out + offset + return copy_program -def test_fingerprint_stage_foast_op_def(fieldop, samecode_fieldop, different_fieldop): - foast = gtx.backend.DEFAULT_TRANSFORMS.func_to_foast( - toolchain.ConcreteArtifact(fieldop.definition_stage, arguments.CompileTimeArgs.empty()) - ).data - samecode = gtx.backend.DEFAULT_TRANSFORMS.func_to_foast( - toolchain.ConcreteArtifact( - samecode_fieldop.definition_stage, arguments.CompileTimeArgs.empty() - ) - ).data - different = gtx.backend.DEFAULT_TRANSFORMS.func_to_foast( - toolchain.ConcreteArtifact( - different_fieldop.definition_stage, arguments.CompileTimeArgs.empty() - ) - ).data - assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast) - assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast) +def test_fingerprinter_hashes_functions_by_source_and_closure(): + first = _make_field_operator_definition(1) + same = _make_field_operator_definition(1) + different = _make_field_operator_definition(2) + assert stages.fingerprinter(first) == stages.fingerprinter(same) + assert stages.fingerprinter(first) != stages.fingerprinter(different) -def test_fingerprint_stage_program_def(program, samecode_program, different_program): - assert stages.fingerprint_stage(samecode_program.definition_stage) != stages.fingerprint_stage( - program.definition_stage - ) - assert stages.fingerprint_stage(different_program.definition_stage) != stages.fingerprint_stage( - program.definition_stage - ) +def test_definition_stages_use_the_custom_fingerprinter(): + first_fieldop = stages.DSLFieldOperatorDef(definition=_make_field_operator_definition(1)) + same_fieldop = stages.DSLFieldOperatorDef(definition=_make_field_operator_definition(1)) + different_fieldop = stages.DSLFieldOperatorDef(definition=_make_field_operator_definition(2)) -def test_fingerprint_stage_past_def(program, samecode_program, different_program): - past = gtx.backend.DEFAULT_TRANSFORMS.func_to_past( - 
toolchain.ConcreteArtifact(program.definition_stage, arguments.CompileTimeArgs.empty()) - ) - samecode = gtx.backend.DEFAULT_TRANSFORMS.func_to_past( - toolchain.ConcreteArtifact( - samecode_program.definition_stage, arguments.CompileTimeArgs.empty() - ) - ) - different = gtx.backend.DEFAULT_TRANSFORMS.func_to_past( - toolchain.ConcreteArtifact( - different_program.definition_stage, arguments.CompileTimeArgs.empty() - ) - ) + first_program = stages.DSLProgramDef(definition=_make_program_definition(1)) + same_program = stages.DSLProgramDef(definition=_make_program_definition(1)) + different_program = stages.DSLProgramDef(definition=_make_program_definition(2)) - assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(past) - assert stages.fingerprint_stage(different) != stages.fingerprint_stage(past) + assert first_fieldop.fingerprint == same_fieldop.fingerprint + assert first_fieldop.fingerprint != different_fieldop.fingerprint + assert first_program.fingerprint == same_program.fingerprint + assert first_program.fingerprint != different_program.fingerprint diff --git a/tests/next_tests/unit_tests/otf_tests/test_workflow.py b/tests/next_tests/unit_tests/otf_tests/test_workflow.py index d4717d34ef..853bf7a33c 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_workflow.py +++ b/tests/next_tests/unit_tests/otf_tests/test_workflow.py @@ -71,7 +71,7 @@ def test_cached_with_hashing(): def hashing(inp: list[int]) -> int: return hash(sum(inp)) - wf = workflow.CachedStep(step=lambda inp: [*inp, 1], hash_function=hashing) + wf = workflow.CachedStep(step=lambda inp: [*inp, 1], key_function=hashing) assert wf([1, 2, 3]) == [1, 2, 3, 1] assert wf([3, 2, 1]) == [1, 2, 3, 1] diff --git a/tests/next_tests/unit_tests/test_utils.py b/tests/next_tests/unit_tests/test_utils.py index 51ed0cf8db..ef34325aa0 100644 --- a/tests/next_tests/unit_tests/test_utils.py +++ b/tests/next_tests/unit_tests/test_utils.py @@ -17,35 +17,35 @@ # Module-level classes so pickle can 
resolve them by qualified name. @dataclasses.dataclass -class _DataclassModel(utils.MetadataBasedPickling): +class _DataclassModel(utils.MetadataBasedPicklingMixin): value: int transient: str = dataclasses.field(default="skip", metadata=utils.gt4py_metadata(pickle=False)) @dataclasses.dataclass(slots=True) -class _SlottedDataclassModel(utils.MetadataBasedPickling): +class _SlottedDataclassModel(utils.MetadataBasedPicklingMixin): value: int transient: str = dataclasses.field(default="skip", metadata=utils.gt4py_metadata(pickle=False)) @datamodels.datamodel(slots=False) -class _DatamodelModel(utils.MetadataBasedPickling): +class _DatamodelModel(utils.MetadataBasedPicklingMixin): value: int transient: str = datamodels.field(default="skip", metadata=utils.gt4py_metadata(pickle=False)) @dataclasses.dataclass -class _EmptyDataclassModel(utils.MetadataBasedPickling): +class _EmptyDataclassModel(utils.MetadataBasedPicklingMixin): pass @dataclasses.dataclass(slots=True) -class _EmptySlottedDataclassModel(utils.MetadataBasedPickling): +class _EmptySlottedDataclassModel(utils.MetadataBasedPicklingMixin): pass @datamodels.datamodel(slots=False) -class _EmptyDatamodelModel(utils.MetadataBasedPickling): +class _EmptyDatamodelModel(utils.MetadataBasedPicklingMixin): pass @@ -91,6 +91,60 @@ def test_pickle_roundtrip(self, instance, expected_fields): assert getattr(restored, field_name) == expected_value +class _DuckFingerprinted: + @property + def fingerprint(self): + return "duck" + + @property + def fingerprinter(self): + return lambda _: "duck" + + +def test_fingerprintedabc_subclasshook_recognizes_duck_types(): + assert issubclass(_DuckFingerprinted, utils.FingerprintedABC) + + +class _TestFingerprinted(utils.FingerprintedABC): + def __init__(self, value: str, noise: str): + self.value = value + self.noise = noise + + @staticmethod + def fingerprinter(instance: "_TestFingerprinted") -> str: + return instance.value + + +def test_fingerprint_uses_fingerprinted_state_only(): + a 
= _TestFingerprinted("id", noise="left") + b = _TestFingerprinted("id", noise="right") + c = _TestFingerprinted("other", noise="left") + + assert utils.fingerprint(a) == utils.fingerprint(b) + assert utils.fingerprint(a) != utils.fingerprint(c) + + +class _CachedTestFingerprinted(utils.CachedFingerprintedMixin): + calls = 0 + + def __init__(self, value: str): + self.value = value + + @staticmethod + def fingerprinter(instance: "_CachedTestFingerprinted") -> str: + _CachedTestFingerprinted.calls += 1 + return instance.value + + +def test_cached_fingerprinted_mixin_computes_fingerprint_once_per_instance(): + _CachedTestFingerprinted.calls = 0 + instance = _CachedTestFingerprinted("stable") + + assert instance.fingerprint == "stable" + assert instance.fingerprint == "stable" + assert _CachedTestFingerprinted.calls == 1 + + def test_tree_map_default(): @utils.tree_map def testee(x):