From 305f37bdbbf6d7b5d51241fb7d74704fc6a5be5e Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 27 May 2026 22:11:21 -0700 Subject: [PATCH 1/5] Implement fma for C-like targets. --- loopy/target/c/__init__.py | 34 +++++++++++++++++++++++++++++++++- loopy/target/opencl.py | 27 ++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py index 42a0b5730..3c0baffc7 100644 --- a/loopy/target/c/__init__.py +++ b/loopy/target/c/__init__.py @@ -756,6 +756,38 @@ def with_types(self, 0: NumpyType(dtype), -1: NumpyType(np.int32)})), clbl_inf_ctx) + elif name == "fma": + + if not all(-1 <= id <= 2 for id in arg_num_to_dtype): + raise LoopyError("fma takes exactly three arguments.") + + if not all(i in arg_num_to_dtype for i in range(3)): + return ( + self.copy(arg_id_to_dtype=constantdict(arg_num_to_dtype)), + clbl_inf_ctx) + + dtype = np.result_type(*[ + arg_num_to_dtype[i].numpy_dtype for i in range(3)]) + real_dtype = np.empty(0, dtype=dtype).real.dtype + + if dtype.kind == "c": + raise LoopyTypeError("fma does not support complex numbers.") + + if real_dtype == np.float64: + pass # fma + elif real_dtype == np.float32: + name = name + "f" # fmaf + elif (hasattr(np, "float128") and real_dtype == np.float128): + name = name + "l" # fmal + else: + raise LoopyTypeError(f"fma does not support type {dtype}.") + + dtype = NumpyType(dtype) + return ( + self.copy(name_in_target=name, + arg_id_to_dtype=constantdict( + {-1: dtype, 0: dtype, 1: dtype, 2: dtype})), + clbl_inf_ctx) # does not satisfy any of the conditions needed for specialization. # hence just returning a copy of the callable. @@ -866,7 +898,7 @@ def get_c_callables(): "sinh", "pow", "atan2", "tanh", "exp", "log", "log10", "sqrt", "ceil", "floor", "max", "min", "fmax", "fmin", "fabs", "tan", "erf", "erfc", "isnan", "real", "imag", - "conj"] + "conj", "fma"] return {id_: CMathCallable(id_) for id_ in cmath_ids} diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index ed05c7628..511978500 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -404,6 +404,31 @@ def with_types(self, })), clbl_inf_ctx) + elif name == "fma": + if not all(-1 <= id <= 2 for id in arg_num_to_dtype): + raise LoopyError("fma takes exactly three arguments.") + + if not all(i in arg_num_to_dtype for i in range(3)): + return ( + self.copy(arg_id_to_dtype=constantdict(arg_num_to_dtype)), + clbl_inf_ctx) + + dtype = np.result_type(*[ + arg_num_to_dtype[i].numpy_dtype for i in range(3)]) + + if dtype.kind == "c": + raise LoopyTypeError("fma does not support complex numbers.") + if dtype.kind not in "f": + raise LoopyTypeError("fma requires floating-point arguments, " + f"got '{dtype}'.") + + dtype = NumpyType(dtype) + return ( + self.copy(name_in_target="fma", + arg_id_to_dtype=constantdict( + {-1: dtype, 0: dtype, 1: dtype, 2: dtype})), + clbl_inf_ctx) + elif name in _CL_SIMPLE_MULTI_ARG_FUNCTIONS: num_args = _CL_SIMPLE_MULTI_ARG_FUNCTIONS[name] for id in arg_num_to_dtype: @@ -480,7 +505,7 @@ def get_opencl_callables(): "acos", "acosh", "asin", "asinh", "atan", "atanh", "atan2", "pow", "exp", "log", "log10", "sqrt", "ceil", "floor", "max", "min", "fmax", "fmin", - "fabs", "erf", "erfc"} + "fabs", "erf", "erfc", "fma"} | set(_CL_SIMPLE_MULTI_ARG_FUNCTIONS) | set(VECTOR_LITERAL_FUNCS)) From b0667004520d423513eb5dab2f0278b1d7c66822 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 27 May 2026 22:22:11 -0700 Subject: [PATCH 2/5] Test FMA support in loopy. --- test/test_target.py | 54 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/test/test_target.py b/test/test_target.py index c47d20fd9..5e50c8307 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -891,6 +891,60 @@ def test_argmax_ctarget_floating_point(): assert out_dict["max_ind"][0] == 2 +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("target_t", [lp.PyOpenCLTarget, lp.ExecutableCTarget]) +def test_fma_correctness(ctx_factory: cl.CtxFactory, dtype, target_t): + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + target = target_t() + + n = 1729 + rng = np.random.default_rng(seed=42) + + a = rng.random(n, dtype=dtype) + b = rng.random(n, dtype=dtype) + c = rng.random(n, dtype=dtype) + + knl = lp.make_kernel( + "{[i]: 0<=i 1: From 9afa8f55a078579207fe2b3defa4443e4199878f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 28 May 2026 10:41:09 -0500 Subject: [PATCH 3/5] Fix deprecated system() use --- contrib/c-integer-semantics.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/contrib/c-integer-semantics.py b/contrib/c-integer-semantics.py index 1c676f747..ae59ac48f 100644 --- a/contrib/c-integer-semantics.py +++ b/contrib/c-integer-semantics.py @@ -1,5 +1,5 @@ import ctypes -from os import system +import subprocess C_SRC = """ @@ -79,7 +79,8 @@ def main(): with open("int-experiments.c", "w") as outf: outf.write(C_SRC) - system("gcc -Wall -shared int-experiments.c -o int-experiments.so") + subprocess.run(["gcc", "-Wall", "-shared", "int-experiments.c", + "-o", "int-experiments.so"], check=True) int_exp = ctypes.CDLL("int-experiments.so") for func in [ From a57caea699ee2962cfb91b2968c023bbbd71c2de Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 28 May 2026 11:07:07 -0500 Subject: [PATCH 4/5] Fix some redundant imports (new ruff lint) --- loopy/kernel/creation.py | 2 -- loopy/target/pyopencl.py | 1 - 2 files changed, 3 deletions(-) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 9de077d7b..be7da3bce 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -66,7 +66,6 @@ InstructionBase, MultiAssignmentBase, ) -from loopy.options import Options from loopy.symbolic import ( IdentityMapper, Reduction, @@ -74,7 +73,6 @@ SubstitutionRuleExpander, WalkMapper, ) -from loopy.target import TargetBase from loopy.tools import Optional, intern_frozenset_of_ids from loopy.translation_unit import TranslationUnit, for_each_kernel from loopy.types import NumpyType diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index 9d88394c9..148b9982b 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -32,7 +32,6 @@ from constantdict import constantdict from typing_extensions import override -import genpy import pymbolic.primitives as p from cgen import ( Block, From 6e40595d756a1e97bf30370c543e1d1e46a028c7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 28 May 2026 10:41:14 -0500 Subject: [PATCH 5/5] Update baseline --- .basedpyright/baseline.json | 40 ++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index c631e45d5..6e2c50a7a 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -51923,6 +51923,30 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 25, + "endColumn": 54, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 25, + "endColumn": 60, + "lineCount": 1 + } + }, + { + "code": "reportAttributeAccessIssue", + "range": { + "startColumn": 50, + "endColumn": 54, + "lineCount": 1 + } + }, { "code": "reportIncompatibleMethodOverride", "range": { @@ -92327,22 +92351,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": {