Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 48 additions & 66 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
@@ -1,71 +1,5 @@
{
"files": {
"./.run-pylint.py": [
{
"code": "reportUnknownParameterType",
"range": {
"startColumn": 28,
"endColumn": 38,
"lineCount": 1
}
},
{
"code": "reportMissingParameterType",
"range": {
"startColumn": 28,
"endColumn": 38,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 33,
"endColumn": 43,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 35,
"endColumn": 39,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 45,
"endColumn": 49,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 16,
"endColumn": 27,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 12,
"endColumn": 23,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 20,
"endColumn": 24,
"lineCount": 1
}
}
],
"./examples/advection/surface.py": [
{
"code": "reportUnannotatedClassAttribute",
Expand Down Expand Up @@ -995,6 +929,30 @@
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
"startColumn": 4,
"endColumn": 14,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 16,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 8,
"endColumn": 25,
"lineCount": 1
}
},
{
"code": "reportPossiblyUnboundVariable",
"range": {
Expand Down Expand Up @@ -1219,6 +1177,30 @@
}
],
"./grudge/geometry/metrics.py": [
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 8,
"endColumn": 31,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 42,
"endColumn": 65,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 48,
"endColumn": 64,
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
Expand Down
1 change: 1 addition & 0 deletions .test-conda-env-py3.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies:
- pyopencl
- python=3
- gmsh
- jax

# test scripts use ompi-specific arguments
- openmpi
Expand Down
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
intersphinx_mapping = {
"arraycontext": ("https://documen.tician.de/arraycontext/", None),
"loopy": ("https://documen.tician.de/loopy/", None),
"jax": ("https://docs.jax.dev/en/latest/", None),
"meshmode": ("https://documen.tician.de/meshmode/", None),
"modepy": ("https://documen.tician.de/modepy/", None),
"mpi4py": ("https://mpi4py.readthedocs.io/en/stable", None),
Expand All @@ -33,7 +34,6 @@
os.environ["PYOPENCL_TEST"] = "port:cpu"

nitpick_ignore_regex = [
["py:class", r"np\.ndarray"],
["py:data|py:class", r"arraycontext.*ContainerTc"],
]

Expand Down
36 changes: 35 additions & 1 deletion grudge/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
.. autoclass:: MPIPyOpenCLArrayContext
.. autoclass:: MPINumpyArrayContext
.. class:: MPIPytatoArrayContext
.. autoclass:: MPIPytatoJAXArrayContext
.. autofunction:: get_reasonable_array_context_class
"""

Expand Down Expand Up @@ -75,11 +76,12 @@
_HAVE_FUSION_ACTX = False


from arraycontext import ArrayContext, NumpyArrayContext
from arraycontext import ArrayContext, NumpyArrayContext, PytatoJAXArrayContext
from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller
from arraycontext.pytest import (
_PytestNumpyArrayContextFactory,
_PytestPyOpenCLArrayContextFactoryWithClass,
_PytestPytatoJaxArrayContextFactory,
_PytestPytatoPyOpenCLArrayContextFactory,
register_pytest_array_context_factory,
)
Expand Down Expand Up @@ -443,6 +445,27 @@ def clone(self) -> Self:
# }}}


# {{{ distributed + lazy jax

class MPIPytatoJAXArrayContext(PytatoJAXArrayContext, MPIBasedArrayContext): # pyright: ignore[reportUnsafeMultipleInheritance]
"""An array context for using distributed computation with :mod:`jax.numpy`
lazy evaluation.

.. autofunction:: __init__
"""

def __init__(self, mpi_communicator: MPI.Intracomm) -> None:
super().__init__()

self.mpi_communicator: MPI.Intracomm = mpi_communicator

@override
def clone(self) -> Self:
return type(self)(self.mpi_communicator)

# }}}


# {{{ distributed + pytato array context subclasses

class MPIBasePytatoPyOpenCLArrayContext(
Expand Down Expand Up @@ -542,12 +565,23 @@ def __call__(self):
return self.actx_class()


class PytestPytatoJAXArrayContextFactory(_PytestPytatoJaxArrayContextFactory):
actx_class = PytatoJAXArrayContext

def __call__(self):
import jax
jax.config.update("jax_enable_x64", True)
return self.actx_class()


register_pytest_array_context_factory("grudge.pyopencl",
PytestPyOpenCLArrayContextFactory)
register_pytest_array_context_factory("grudge.pytato-pyopencl",
PytestPytatoPyOpenCLArrayContextFactory)
register_pytest_array_context_factory("grudge.numpy",
PytestNumpyArrayContextFactory)
register_pytest_array_context_factory("grudge.lazy-jax",
PytestPytatoJAXArrayContextFactory)

# }}}

Expand Down
22 changes: 12 additions & 10 deletions grudge/geometry/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,22 +597,24 @@ def _signed_face_ones(
dd_base.untrace(), dd_base
)
assert isinstance(all_faces_conn, DirectDiscretizationConnection)
signed_ones = dcoll.discr_from_dd(dd.with_discr_tag(DISCR_TAG_BASE)).zeros(
actx, dtype=dcoll.real_dtype
) + 1

signed_face_ones_numpy = actx.to_numpy(signed_ones)
discr = dcoll.discr_from_dd(dd.with_discr_tag(DISCR_TAG_BASE))

new_group_arrays = []

for dgrp, grp in zip(discr.groups, all_faces_conn.groups, strict=True):
sign = np.ones((dgrp.nelements, dgrp.nunit_dofs),
dtype=discr.real_dtype)

for igrp, grp in enumerate(all_faces_conn.groups):
for batch in grp.batches:
assert batch.to_element_face is not None
i = actx.to_numpy(actx.thaw(batch.to_element_indices))
grp_field = signed_face_ones_numpy[igrp].reshape(-1)
grp_field[i] = ( # pyright: ignore[reportIndexIssue]
(2.0 * (batch.to_element_face % 2) - 1.0) * grp_field[i]
)
sign[i, :] = 2.0 * (batch.to_element_face % 2) - 1.0

new_group_arrays.append(sign)

return actx.from_numpy(signed_face_ones_numpy)
from meshmode.dof_array import DOFArray
return actx.from_numpy(DOFArray(actx, tuple(new_group_arrays)))


def parametrization_derivative(
Expand Down
4 changes: 3 additions & 1 deletion test/test_dt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@
from grudge.array_context import (
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoJAXArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
)


pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory])
PytestNumpyArrayContextFactory,
PytestPytatoJAXArrayContextFactory])

import logging

Expand Down
4 changes: 3 additions & 1 deletion test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from grudge.array_context import (
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoJAXArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
)
from grudge.discretization import make_discretization_collection
Expand All @@ -47,7 +48,8 @@
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory])
PytestNumpyArrayContextFactory,
PytestPytatoJAXArrayContextFactory])


# {{{ inverse metric
Expand Down
Loading