From 408ce43e846f274f5be0a3b83c56ff7bdb037ff6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 28 May 2025 17:54:44 -0500 Subject: [PATCH 1/2] add MPIPytatoJAXArrayContext --- grudge/array_context.py | 41 ++++++++++++++++++++++++++++++++++++++++- test/test_dt_utils.py | 4 +++- test/test_metrics.py | 4 +++- 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/grudge/array_context.py b/grudge/array_context.py index 36b4c121..e12a91a5 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -5,6 +5,8 @@ .. autoclass:: MPIPyOpenCLArrayContext .. autoclass:: MPINumpyArrayContext .. class:: MPIPytatoArrayContext +.. autoclass:: MPIEagerJAXArrayContext +.. autoclass:: MPIPytatoJAXArrayContext .. autofunction:: get_reasonable_array_context_class """ @@ -76,13 +78,19 @@ _HAVE_FUSION_ACTX = False -from arraycontext import ArrayContext, EagerJAXArrayContext, NumpyArrayContext +from arraycontext import ( + ArrayContext, + EagerJAXArrayContext, + NumpyArrayContext, + PytatoJAXArrayContext, +) from arraycontext.container import ArrayContainer from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller from arraycontext.pytest import ( _PytestEagerJaxArrayContextFactory, _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, + _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, register_pytest_array_context_factory, ) @@ -449,6 +457,26 @@ def clone(self) -> Self: # }}} +# {{{ distributed + lazy jax + +class MPIPytatoJAXArrayContext(PytatoJAXArrayContext, MPIBasedArrayContext): + """An array context for using distributed computation with :mod:`jax` + lazy evaluation. + + .. autofunction:: __init__ + """ + + def __init__(self, mpi_communicator) -> None: + super().__init__() + + self.mpi_communicator = mpi_communicator + + def clone(self) -> Self: + return type(self)(self.mpi_communicator) + +# }}} + + # {{{ distributed + pytato array context subclasses class MPIBasePytatoPyOpenCLArrayContext( @@ -551,6 +579,15 @@ def __call__(self): return self.actx_class() +class PytestPytatoJAXArrayContextFactory(_PytestPytatoJaxArrayContextFactory): + actx_class = PytatoJAXArrayContext + + def __call__(self): + import jax + jax.config.update("jax_enable_x64", True) + return self.actx_class() + + register_pytest_array_context_factory("grudge.pyopencl", PytestPyOpenCLArrayContextFactory) register_pytest_array_context_factory("grudge.pytato-pyopencl", @@ -559,6 +596,8 @@ def __call__(self): PytestNumpyArrayContextFactory) register_pytest_array_context_factory("grudge.eager-jax", PytestEagerJAXArrayContextFactory) +register_pytest_array_context_factory("grudge.lazy-jax", + PytestPytatoJAXArrayContextFactory) # }}} diff --git a/test/test_dt_utils.py b/test/test_dt_utils.py index 81b1d09d..a0c11ab0 100644 --- a/test/test_dt_utils.py +++ b/test/test_dt_utils.py @@ -30,6 +30,7 @@ PytestEagerJAXArrayContextFactory, PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, + PytestPytatoJAXArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, ) @@ -38,7 +39,8 @@ [PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, PytestNumpyArrayContextFactory, - PytestEagerJAXArrayContextFactory]) + PytestEagerJAXArrayContextFactory, + PytestPytatoJAXArrayContextFactory]) import logging diff --git a/test/test_metrics.py b/test/test_metrics.py index 5b21f15b..ad6e5ded 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -36,6 +36,7 @@ PytestEagerJAXArrayContextFactory, PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, + PytestPytatoJAXArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, ) from grudge.discretization import make_discretization_collection @@ -46,7 +47,8 @@ [PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, PytestNumpyArrayContextFactory, - PytestEagerJAXArrayContextFactory]) + PytestEagerJAXArrayContextFactory, + PytestPytatoJAXArrayContextFactory]) # {{{ inverse metric From 8938f8b90ec9e9d85ec13f2ed0642149b6e586d6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 5 Sep 2025 10:35:55 -0500 Subject: [PATCH 2/2] fix merge error --- grudge/array_context.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/grudge/array_context.py b/grudge/array_context.py index b687e3eb..d170502e 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -106,14 +106,8 @@ import pyopencl import pyopencl.array as cl_array - from arraycontext.container import - - - - - - - + from arraycontext.container import ArrayContainer + from pytools.tag import Tag