diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index ca36cfe1..9228357e 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1,71 +1,5 @@ { "files": { - "./.run-pylint.py": [ - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 28, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 28, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 33, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 35, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 45, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 16, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 20, - "endColumn": 24, - "lineCount": 1 - } - } - ], "./examples/advection/surface.py": [ { "code": "reportUnannotatedClassAttribute", @@ -995,6 +929,30 @@ "lineCount": 1 } }, + { + "code": "reportUnannotatedClassAttribute", + "range": { + "startColumn": 4, + "endColumn": 14, + "lineCount": 1 + } + }, + { + "code": "reportImplicitOverride", + "range": { + "startColumn": 8, + "endColumn": 16, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 8, + "endColumn": 25, + "lineCount": 1 + } + }, { "code": "reportPossiblyUnboundVariable", "range": { @@ -1219,6 +1177,30 @@ } ], "./grudge/geometry/metrics.py": [ + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 8, + "endColumn": 31, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 42, + "endColumn": 65, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 48, + "endColumn": 64, + "lineCount": 1 + } + }, { "code": "reportArgumentType", "range": { diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 3a4f0a8f..c4981063 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -12,6 +12,7 @@ dependencies: - pyopencl - python=3 - gmsh +- jax # test scripts use ompi-specific arguments - openmpi diff --git a/doc/conf.py b/doc/conf.py index 1597d1c6..cca6d818 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -18,6 +18,7 @@ intersphinx_mapping = { "arraycontext": ("https://documen.tician.de/arraycontext/", None), "loopy": ("https://documen.tician.de/loopy/", None), + "jax": ("https://docs.jax.dev/en/latest/", None), "meshmode": ("https://documen.tician.de/meshmode/", None), "modepy": ("https://documen.tician.de/modepy/", None), "mpi4py": ("https://mpi4py.readthedocs.io/en/stable", None), @@ -33,7 +34,6 @@ os.environ["PYOPENCL_TEST"] = "port:cpu" nitpick_ignore_regex = [ - ["py:class", r"np\.ndarray"], ["py:data|py:class", r"arraycontext.*ContainerTc"], ] diff --git a/grudge/array_context.py b/grudge/array_context.py index 9f1074bc..583d2bb9 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -5,6 +5,7 @@ .. autoclass:: MPIPyOpenCLArrayContext .. autoclass:: MPINumpyArrayContext .. class:: MPIPytatoArrayContext +.. autoclass:: MPIPytatoJAXArrayContext .. autofunction:: get_reasonable_array_context_class """ @@ -75,11 +76,12 @@ _HAVE_FUSION_ACTX = False -from arraycontext import ArrayContext, NumpyArrayContext +from arraycontext import ArrayContext, NumpyArrayContext, PytatoJAXArrayContext from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller from arraycontext.pytest import ( _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, + _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, register_pytest_array_context_factory, ) @@ -443,6 +445,27 @@ def clone(self) -> Self: # }}} +# {{{ distributed + lazy jax + +class MPIPytatoJAXArrayContext(PytatoJAXArrayContext, MPIBasedArrayContext): # pyright: ignore[reportUnsafeMultipleInheritance] + """An array context for using distributed computation with :mod:`jax.numpy` + lazy evaluation. + + .. autofunction:: __init__ + """ + + def __init__(self, mpi_communicator: MPI.Intracomm) -> None: + super().__init__() + + self.mpi_communicator: MPI.Intracomm = mpi_communicator + + @override + def clone(self) -> Self: + return type(self)(self.mpi_communicator) + +# }}} + + # {{{ distributed + pytato array context subclasses class MPIBasePytatoPyOpenCLArrayContext( @@ -542,12 +565,23 @@ def __call__(self): return self.actx_class() +class PytestPytatoJAXArrayContextFactory(_PytestPytatoJaxArrayContextFactory): + actx_class = PytatoJAXArrayContext + + def __call__(self): + import jax + jax.config.update("jax_enable_x64", True) + return self.actx_class() + + register_pytest_array_context_factory("grudge.pyopencl", PytestPyOpenCLArrayContextFactory) register_pytest_array_context_factory("grudge.pytato-pyopencl", PytestPytatoPyOpenCLArrayContextFactory) register_pytest_array_context_factory("grudge.numpy", PytestNumpyArrayContextFactory) +register_pytest_array_context_factory("grudge.lazy-jax", + PytestPytatoJAXArrayContextFactory) # }}} diff --git a/grudge/geometry/metrics.py b/grudge/geometry/metrics.py index 04aa11b9..15e73e6b 100644 --- a/grudge/geometry/metrics.py +++ b/grudge/geometry/metrics.py @@ -597,22 +597,24 @@ def _signed_face_ones( dd_base.untrace(), dd_base ) assert isinstance(all_faces_conn, DirectDiscretizationConnection) - signed_ones = dcoll.discr_from_dd(dd.with_discr_tag(DISCR_TAG_BASE)).zeros( - actx, dtype=dcoll.real_dtype - ) + 1 - signed_face_ones_numpy = actx.to_numpy(signed_ones) + discr = dcoll.discr_from_dd(dd.with_discr_tag(DISCR_TAG_BASE)) + + new_group_arrays = [] + + for dgrp, grp in zip(discr.groups, all_faces_conn.groups, strict=True): + sign = np.ones((dgrp.nelements, dgrp.nunit_dofs), + dtype=discr.real_dtype) - for igrp, grp in enumerate(all_faces_conn.groups): for batch in grp.batches: assert batch.to_element_face is not None i = actx.to_numpy(actx.thaw(batch.to_element_indices)) - grp_field = signed_face_ones_numpy[igrp].reshape(-1) - grp_field[i] = ( # pyright: ignore[reportIndexIssue] - (2.0 * (batch.to_element_face % 2) - 1.0) * grp_field[i] - ) + sign[i, :] = 2.0 * (batch.to_element_face % 2) - 1.0 + + new_group_arrays.append(sign) - return actx.from_numpy(signed_face_ones_numpy) + from meshmode.dof_array import DOFArray + return actx.from_numpy(DOFArray(actx, tuple(new_group_arrays))) def parametrization_derivative( diff --git a/test/test_dt_utils.py b/test/test_dt_utils.py index 59e3d4db..f5974521 100644 --- a/test/test_dt_utils.py +++ b/test/test_dt_utils.py @@ -33,6 +33,7 @@ from grudge.array_context import ( PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, + PytestPytatoJAXArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, ) @@ -40,7 +41,8 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, - PytestNumpyArrayContextFactory]) + PytestNumpyArrayContextFactory, + PytestPytatoJAXArrayContextFactory]) import logging diff --git a/test/test_metrics.py b/test/test_metrics.py index 21a7934f..e7cbde6d 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -38,6 +38,7 @@ from grudge.array_context import ( PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, + PytestPytatoJAXArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, ) from grudge.discretization import make_discretization_collection @@ -47,7 +48,8 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, - PytestNumpyArrayContextFactory]) + PytestNumpyArrayContextFactory, + PytestPytatoJAXArrayContextFactory]) # {{{ inverse metric