Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/gt4py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
- `gt4py.next` supports structured and unstructured grid.
"""

from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__ # noqa: I001
from . import eve, storage
from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__


__all__ = [
Expand Down
20 changes: 14 additions & 6 deletions src/gt4py/eve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,16 @@ def is_noninstantiable(cls: Type[_T]) -> bool:
return "__noninstantiable__" in cls.__dict__


def singledispatcher(
    default: Callable[P, T],
    implementations: dict[type, Callable[[Any], Any]],
) -> Callable[P, T]:
    """
    Create a `functools.singledispatch` function from a mapping of implementations.

    Args:
        default: Fallback implementation, used when the type of the first
            argument matches none of the registered classes.
        implementations: Mapping from class to the implementation to be
            registered for that class (and, following the usual
            `singledispatch` rules, for its subclasses).

    Returns:
        The single-dispatch function wrapping `default` with all the
        `implementations` registered.
    """
    dispatcher = functools.singledispatch(default)
    for cls, func in implementations.items():
        # Functional form of `register`: equivalent to `register(cls)(func)`.
        dispatcher.register(cls, func)
    return dispatcher


def content_hash(
*args: Any,
hash_algorithm: str | xtyping.HashlibAlgorithm | None = None,
Expand Down Expand Up @@ -669,6 +679,7 @@ def content_hash(

def custom_pickler(
reducer: Callable[[Any], tuple | types.NotImplementedType],
*,
name: str | None = None,
) -> type[pickle.Pickler]:
"""
Expand All @@ -685,7 +696,9 @@ def custom_pickler(

def custom_pickler_from_reducers(
custom_reducers: dict[type, Callable[[Any], tuple | types.NotImplementedType]],
*,
name: str | None = None,
default_reducer: Callable[[Any], tuple | types.NotImplementedType] = lambda _: NotImplemented,
) -> type[pickle.Pickler]:
"""
Create a pickler with the provided reducers registered in reducer override.
Expand All @@ -696,13 +709,8 @@ def custom_pickler_from_reducers(
the `dispatch_table` dict, to allow easy pickle customization of entire class
hierarchies.
"""
reducer = functools.singledispatch(
cast(Callable[[Any], tuple | types.NotImplementedType], lambda _: NotImplemented)
)
for cls, func in custom_reducers.items():
reducer.register(cls)(func)

return custom_pickler(reducer, name=name)
return custom_pickler(singledispatcher(default_reducer, custom_reducers), name=name)


ddiff = deepdiff.diff.DeepDiff
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Generic

from gt4py._core import definitions as core_defs
from gt4py.next import custom_layout_allocators as next_allocators
from gt4py.next import custom_layout_allocators as next_allocators, utils
from gt4py.next.ffront import (
foast_to_gtir,
foast_to_past,
Expand Down Expand Up @@ -145,7 +145,7 @@ def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]:
# `transforms` -> `frontend_transforms`
# `executor` -> `backend_transforms`
@dataclasses.dataclass(frozen=True)
class Backend(Generic[core_defs.DeviceTypeT]):
class Backend(utils.CachedFingerprintedDataclass, Generic[core_defs.DeviceTypeT]):
name: str
executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram]
allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]
Expand Down
12 changes: 12 additions & 0 deletions src/gt4py/next/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ def env_flag_to_int(name: str, default: int) -> int:
]


def _get_build_cache_version_id() -> str:
    """
    Return the default build cache version ID: the `gt4py` version string.

    The `gt4py` import is deferred to call time instead of being done at
    module import time (presumably to avoid a circular import, since this
    module is part of the `gt4py` package -- TODO confirm).
    """
    from gt4py import __version__

    return __version__


#: Version ID for the build cache. It should only be overridden by advanced users
#: testing toolchain changes that are expected to break compatibility with previously cached builds.
#: The value is read from the `BUILD_CACHE_VERSION_ID` environment variable; when the
#: variable is unset or empty, it falls back to `_get_build_cache_version_id()`.
BUILD_CACHE_VERSION_ID: str = (
    os.environ.get("BUILD_CACHE_VERSION_ID") or _get_build_cache_version_id()
)

#: Build type to be used when CMake is used to compile generated code.
#: Might have no effect when CMake is not used as part of the toolchain.
# FIXME[#2447](egparedes): compile-time setting, should be included in the build cache key.
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
class NdArrayField(
common.MutableField[common.DimsT, core_defs.ScalarT],
common.FieldBuiltinFuncRegistry,
utils.MetadataBasedPickling,
utils.MetadataBasedPicklingMixin,
):
"""
Shared field implementation for NumPy-like fields.
Expand Down
24 changes: 3 additions & 21 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def embedded_program_call_context(


@dataclasses.dataclass(frozen=True)
class _CompilableGTEntryPointMixin(Generic[ffront_stages.DSLDefinitionT]):
class _CompilableGTEntryPointMixin(
utils.CachedFingerprintedDataclass, Generic[ffront_stages.DSLDefinitionT]
):
"""
Mixin used by program and program-like objects.
Expand Down Expand Up @@ -862,23 +864,3 @@ def scan_operator_inner(definition: Callable) -> FieldOperator:
)

return scan_operator_inner if definition is None else scan_operator_inner(definition)


@ffront_stages.add_content_to_fingerprint.register
def add_fieldop_to_fingerprint(obj: FieldOperator, hasher: xtyping.HashlibAlgorithm) -> None:
    """
    Feed a `FieldOperator` into `hasher`.

    Registered as the `add_content_to_fingerprint` handler for `FieldOperator`
    (the dispatch class is taken from the annotation of `obj`). The content
    added is the `definition_stage` followed by the `backend`, in that order.
    """
    ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher)
    ffront_stages.add_content_to_fingerprint(obj.backend, hasher)


@ffront_stages.add_content_to_fingerprint.register
def add_foast_fieldop_to_fingerprint(
    obj: FieldOperatorFromFoast, hasher: xtyping.HashlibAlgorithm
) -> None:
    """
    Feed a `FieldOperatorFromFoast` into `hasher`.

    Registered as the `add_content_to_fingerprint` handler for
    `FieldOperatorFromFoast` (the dispatch class is taken from the annotation
    of `obj`). The content added is the `foast_stage` followed by the
    `backend`, in that order.
    """
    ffront_stages.add_content_to_fingerprint(obj.foast_stage, hasher)
    ffront_stages.add_content_to_fingerprint(obj.backend, hasher)


@ffront_stages.add_content_to_fingerprint.register
def add_program_to_fingerprint(obj: Program, hasher: xtyping.HashlibAlgorithm) -> None:
    """
    Feed a `Program` into `hasher`.

    Registered as the `add_content_to_fingerprint` handler for `Program`
    (the dispatch class is taken from the annotation of `obj`). The content
    added is the `definition_stage` followed by the `backend`, in that order.
    """
    ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher)
    ffront_stages.add_content_to_fingerprint(obj.backend, hasher)
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def foast_to_gtir_factory(
"""Wrap `foast_to_gtir` into a chainable and, optionally, cached workflow step."""
wf = foast_to_gtir
if cached:
wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage)
wf = workflow.CachedStep(step=wf, key_function=ffront_stages.fingerprinter)
return wf


Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/foast_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,5 +190,5 @@ def operator_to_program_factory(
foast_to_itir_step or foast_to_gtir.adapted_foast_to_gtir_factory()
)
if cached:
wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage)
wf = workflow.CachedStep(wf, key_function=ffront_stages.fingerprinter)
return wf
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/func_to_foast.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def func_to_foast_factory(
"""Wrap `func_to_foast` in a chainable and optionally cached workflow step."""
wf = workflow.make_step(func_to_foast)
if cached:
wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage)
wf = workflow.CachedStep(step=wf, key_function=ffront_stages.fingerprinter)
return wf


Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/func_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSLProgramDef
"""
wf = workflow.make_step(func_to_past)
if cached:
wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage)
wf = workflow.CachedStep(wf, key_function=ffront_stages.fingerprinter)
return wf


Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/past_passes/linters.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def linter_factory(
) -> workflow.Workflow[PASTProgramDef, PASTProgramDef]:
wf = lint_misnamed_functions.chain(lint_undefined_symbols)
if cached:
wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage)
wf = workflow.CachedStep(step=wf, key_function=ffront_stages.fingerprinter)
return wf


Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/past_process_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def transform_program_args_factory(
) -> workflow.Workflow[ffront_stages.ConcretePASTProgramDef, ffront_stages.ConcretePASTProgramDef]:
wf = transform_program_args
if cached:
wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage)
wf = workflow.CachedStep(wf, key_function=ffront_stages.fingerprinter)
return wf


Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def past_to_gtir_factory(
) -> workflow.Workflow[ConcretePASTProgramDef, definitions.CompilableProgramDef]:
wf = workflow.make_step(past_to_gtir)
if cached:
wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage)
wf = workflow.CachedStep(wf, key_function=ffront_stages.fingerprinter)
return wf


Expand Down
132 changes: 37 additions & 95 deletions src/gt4py/next/ffront/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,51 @@

from __future__ import annotations

import collections.abc
import dataclasses
import functools
import hashlib
import types
import typing
from typing import Any, Optional, TypeVar

import xxhash

from gt4py.eve import extended_typing as xtyping
from gt4py.next import common
from gt4py.eve import utils as eve_utils
from gt4py.next import common, utils
from gt4py.next.ffront import field_operator_ast as foast, program_ast as past, source_utils
from gt4py.next.otf import arguments, toolchain


# Create a custom pickler for the BaseStage `fingerprinter` that handles
# `types.FunctionType` by using its source code and closure variables.
# This should be enough for the use case of GT4Py DSL definitions,
# which are expected to be pure functions without complicated closures.
#
# `utils.fingerprint_reducer` is passed as the default of the dispatcher and
# the mapping adds a dedicated implementation for plain Python functions.
fingerprint_reducer = eve_utils.singledispatcher(
    utils.fingerprint_reducer,
    {
        # Reduce-style tuple `(callable, args, state)`: the function object is
        # represented by an empty `tuple()` call whose state is the pair
        # (source definition, closure variables).
        # NOTE(review): presumably the resulting pickle is only hashed and never
        # loaded back (a `tuple` cannot take this state) -- confirm.
        types.FunctionType: lambda f: (
            tuple,
            (),
            (
                source_utils.make_source_definition_from_function(f),
                source_utils.get_closure_vars_from_function(f),
            ),
        )
    },
)

# `eve_utils.content_hash` pre-configured with a pickler class built from `fingerprint_reducer`.
fingerprinter = functools.partial(
    eve_utils.content_hash,
    pickler=eve_utils.custom_pickler(fingerprint_reducer, name="FFrontFingerprintPickler"),
)


@dataclasses.dataclass(frozen=True)
class BaseStage(utils.CachedFingerprintedDataclass):
    """
    Common base class of the frontend stages defined as frozen dataclasses.

    It derives from `utils.CachedFingerprintedDataclass` and sets the
    module-level `fingerprinter` as the class-level `fingerprinter`.
    """

    # Plain class attribute (no annotation, hence not a dataclass field).
    # `staticmethod` prevents the callable from being bound to the instance
    # when it is accessed through `self`.
    fingerprinter = staticmethod(fingerprinter)


@dataclasses.dataclass(frozen=True)
class DSLFieldOperatorDef:
class DSLFieldOperatorDef(BaseStage):
definition: types.FunctionType
node_class: type[foast.OperatorNode] = foast.FieldOperator
attributes: dict[str, Any] = dataclasses.field(default_factory=dict)
Expand All @@ -53,7 +80,7 @@ class DSLFieldOperatorDef:


@dataclasses.dataclass(frozen=True)
class FOASTOperatorDef:
class FOASTOperatorDef(BaseStage):
foast_node: foast.OperatorNode
closure_vars: dict[str, Any]
grid_type: Optional[common.GridType] = None
Expand All @@ -67,7 +94,7 @@ class FOASTOperatorDef:


@dataclasses.dataclass(frozen=True)
class DSLProgramDef:
class DSLProgramDef(BaseStage):
definition: types.FunctionType
grid_type: Optional[common.GridType] = None
debug: bool = False
Expand All @@ -79,7 +106,7 @@ class DSLProgramDef:


@dataclasses.dataclass(frozen=True)
class PASTProgramDef:
class PASTProgramDef(BaseStage):
past_node: past.Program
closure_vars: dict[str, Any]
grid_type: Optional[common.GridType] = None
Expand All @@ -92,88 +119,3 @@ class PASTProgramDef:

# Union of the two DSL-level definition stages (field operator or program).
DSLDefinition = DSLFieldOperatorDef | DSLProgramDef
# Type variable constrained to exactly one of the two DSL-level definition stages.
DSLDefinitionT = TypeVar("DSLDefinitionT", DSLFieldOperatorDef, DSLProgramDef)


def fingerprint_stage(obj: Any, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None) -> str:
    """
    Compute the hexadecimal fingerprint of `obj`.

    Args:
        obj: Object whose content is fed to the hasher through
            `add_content_to_fingerprint`.
        algorithm: Hash algorithm selection. A falsy value (the default)
            selects `xxhash.xxh64`, a string is forwarded to `hashlib.new()`,
            and any other value is used directly as the hasher object. In the
            last case the passed hasher is updated in place.

    Returns:
        The `hexdigest()` of the hasher after adding the content of `obj`.
    """
    hasher: xtyping.HashlibAlgorithm
    if not algorithm:
        hasher = xxhash.xxh64()  # type: ignore[assignment] # fixing this requires https://github.com/ifduyue/python-xxhash/issues/104
    elif isinstance(algorithm, str):
        hasher = hashlib.new(algorithm)
    else:
        hasher = algorithm

    add_content_to_fingerprint(obj, hasher)
    return hasher.hexdigest()


@functools.singledispatch
def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None:
    """
    Feed the content of `obj` into `hasher` (single-dispatch entry point).

    This default implementation is used for every type without a more
    specific registered handler and hashes the encoded `str(obj)`.

    Args:
        obj: Object to add to the fingerprint.
        hasher: Hasher object, updated in place.
    """
    hasher.update(str(obj).encode())


# Explicitly bind `str` and `int` to the default (`object`) implementation.
# `str` is an `Iterable`, so without its own entry it would be dispatched to
# a handler registered for `collections.abc.Iterable` and hashed per character.
# NOTE(review): `int` is not an `Iterable`; presumably registered only for
# explicitness -- confirm.
for t in (str, int):
    add_content_to_fingerprint.register(t, add_content_to_fingerprint.registry[object])


@add_content_to_fingerprint.register(DSLFieldOperatorDef)
@add_content_to_fingerprint.register(FOASTOperatorDef)
@add_content_to_fingerprint.register(DSLProgramDef)
@add_content_to_fingerprint.register(PASTProgramDef)
@add_content_to_fingerprint.register(toolchain.ConcreteArtifact)
@add_content_to_fingerprint.register(arguments.CompileTimeArgs)
def add_stage_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None:
    """
    Feed a dataclass-based stage into `hasher`.

    The class of `obj` is added first, followed by the value of every
    dataclass field in the order reported by `dataclasses.fields()`.
    Field names are not part of the fingerprint.

    Args:
        obj: Instance of one of the registered dataclass types.
        hasher: Hasher object, updated in place.
    """
    add_content_to_fingerprint(obj.__class__, hasher)
    for field in dataclasses.fields(obj):
        add_content_to_fingerprint(getattr(obj, field.name), hasher)


def add_jit_args_id_to_fingerprint(
    obj: arguments.JITArgs, hasher: xtyping.HashlibAlgorithm
) -> None:
    """
    Feed the identity of `obj` (not its content) into `hasher`.

    The value added is `str(id(obj))`, so two distinct objects with equal
    content contribute different values.

    NOTE(review): this function is not registered on
    `add_content_to_fingerprint` in the visible code; presumably it is
    registered by the code that needs it -- confirm.
    """
    add_content_to_fingerprint(str(id(obj)), hasher)


@add_content_to_fingerprint.register
def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgorithm) -> None:
    """
    Feed a Python function into `hasher`.

    The content added is every item obtained by iterating the
    `SourceDefinition` of the function, followed by its closure variables
    as `(name, value)` pairs sorted by name.

    Args:
        obj: Function to add to the fingerprint.
        hasher: Hasher object, updated in place.
    """
    sourcedef = source_utils.SourceDefinition.from_function(obj)
    for item in sourcedef:
        add_content_to_fingerprint(item, hasher)

    closure_vars = source_utils.get_closure_vars_from_function(obj)
    # Sort by variable name to make the result independent of the mapping order.
    for item in sorted(closure_vars.items(), key=lambda x: x[0]):
        add_content_to_fingerprint(item, hasher)


@add_content_to_fingerprint.register
def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None:
    """
    Feed a dictionary into `hasher`, independently of its insertion order.

    Keys are visited in sorted order and each key is added right before its
    value. Keys which are classes are sorted by `(module, qualified name)`.

    NOTE(review): keys which are not mutually orderable (e.g. a mix of
    classes and strings) make `sorted()` raise `TypeError`.

    Args:
        obj: Dictionary to add to the fingerprint.
        hasher: Hasher object, updated in place.
    """

    # just a small helper to additionally allow sorting types (by just using their name)
    def key_function(key: Any) -> Any:
        if isinstance(key, type):
            return key.__module__, key.__qualname__
        return key

    for key in sorted(obj, key=key_function):
        add_content_to_fingerprint(key, hasher)
        add_content_to_fingerprint(obj[key], hasher)


@add_content_to_fingerprint.register
def add_type_to_fingerprint(obj: type, hasher: xtyping.HashlibAlgorithm) -> None:
    """
    Feed a class into `hasher` using its module and qualified name.

    The bare `__name__` is not enough to identify a class: classes with the
    same name defined in different modules or scopes would contribute the
    same content. `__module__` plus `__qualname__` is also what the sorting
    helper for dictionary keys uses to identify classes.

    Args:
        obj: Class to add to the fingerprint.
        hasher: Hasher object, updated in place.
    """
    hasher.update(f"{obj.__module__}.{obj.__qualname__}".encode())


@add_content_to_fingerprint.register
def add_sequence_to_fingerprint(
    obj: collections.abc.Iterable, hasher: xtyping.HashlibAlgorithm
) -> None:
    """
    Feed every item of an iterable into `hasher`, in iteration order.

    NOTE(review): no delimiter or length is added between items, so
    different groupings of the same content (e.g. `("ab", "c")` and
    `("a", "bc")`) contribute the same bytes, and an empty iterable
    contributes nothing.

    Args:
        obj: Iterable whose items are added to the fingerprint.
        hasher: Hasher object, updated in place.
    """
    for item in obj:
        add_content_to_fingerprint(item, hasher)


@add_content_to_fingerprint.register
def add_foast_located_node_to_fingerprint(
    obj: foast.LocatedNode, hasher: xtyping.HashlibAlgorithm
) -> None:
    """
    Feed a FOAST located node into `hasher`.

    The content added is the `location` of the node followed by its
    string representation (`str(obj)`).

    Args:
        obj: Node to add to the fingerprint.
        hasher: Hasher object, updated in place.
    """
    add_content_to_fingerprint(obj.location, hasher)
    add_content_to_fingerprint(str(obj), hasher)
Loading
Loading