Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -51923,6 +51923,30 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 25,
"endColumn": 54,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 25,
"endColumn": 60,
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 50,
"endColumn": 54,
"lineCount": 1
}
},
{
"code": "reportIncompatibleMethodOverride",
"range": {
Expand Down Expand Up @@ -92327,22 +92351,6 @@
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 30,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 16,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand Down
5 changes: 3 additions & 2 deletions contrib/c-integer-semantics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ctypes
from os import system
import subprocess


C_SRC = """
Expand Down Expand Up @@ -79,7 +79,8 @@ def main():
with open("int-experiments.c", "w") as outf:
outf.write(C_SRC)

system("gcc -Wall -shared int-experiments.c -o int-experiments.so")
subprocess.run(["gcc", "-Wall", "-shared", "int-experiments.c",
"-o", "int-experiments.so"], check=True)

int_exp = ctypes.CDLL("int-experiments.so")
for func in [
Expand Down
2 changes: 0 additions & 2 deletions loopy/kernel/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,13 @@
InstructionBase,
MultiAssignmentBase,
)
from loopy.options import Options
from loopy.symbolic import (
IdentityMapper,
Reduction,
SubArrayRef,
SubstitutionRuleExpander,
WalkMapper,
)
from loopy.target import TargetBase
from loopy.tools import Optional, intern_frozenset_of_ids
from loopy.translation_unit import TranslationUnit, for_each_kernel
from loopy.types import NumpyType
Expand Down
34 changes: 33 additions & 1 deletion loopy/target/c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,38 @@ def with_types(self,
0: NumpyType(dtype),
-1: NumpyType(np.int32)})),
clbl_inf_ctx)
elif name == "fma":

if not all(-1 <= id <= 2 for id in arg_num_to_dtype):
raise LoopyError("fma takes exactly three arguments.")

if not all(i in arg_num_to_dtype for i in range(3)):
return (
self.copy(arg_id_to_dtype=constantdict(arg_num_to_dtype)),
clbl_inf_ctx)

dtype = np.result_type(*[
arg_num_to_dtype[i].numpy_dtype for i in range(3)])
real_dtype = np.empty(0, dtype=dtype).real.dtype

if dtype.kind == "c":
raise LoopyTypeError("fma does not support complex numbers.")

if real_dtype == np.float64:
pass # fma
elif real_dtype == np.float32:
name = name + "f" # fmaf
elif (hasattr(np, "float128") and real_dtype == np.float128):
name = name + "l" # fmal
else:
raise LoopyTypeError(f"fma does not support type {dtype}.")

dtype = NumpyType(dtype)
return (
self.copy(name_in_target=name,
arg_id_to_dtype=constantdict(
{-1: dtype, 0: dtype, 1: dtype, 2: dtype})),
clbl_inf_ctx)

# does not satisfy any of the conditions needed for specialization.
# hence just returning a copy of the callable.
Expand Down Expand Up @@ -866,7 +898,7 @@ def get_c_callables():
"sinh", "pow", "atan2", "tanh", "exp", "log", "log10",
"sqrt", "ceil", "floor", "max", "min", "fmax", "fmin",
"fabs", "tan", "erf", "erfc", "isnan", "real", "imag",
"conj"]
"conj", "fma"]

return {id_: CMathCallable(id_) for id_ in cmath_ids}

Expand Down
27 changes: 26 additions & 1 deletion loopy/target/opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,31 @@ def with_types(self,
})),
clbl_inf_ctx)

elif name == "fma":
if not all(-1 <= id <= 2 for id in arg_num_to_dtype):
raise LoopyError("fma takes exactly three arguments.")

if not all(i in arg_num_to_dtype for i in range(3)):
return (
self.copy(arg_id_to_dtype=constantdict(arg_num_to_dtype)),
clbl_inf_ctx)

dtype = np.result_type(*[
arg_num_to_dtype[i].numpy_dtype for i in range(3)])

if dtype.kind == "c":
raise LoopyTypeError("fma does not support complex numbers.")
if dtype.kind not in "f":
raise LoopyTypeError("fma requires floating-point arguments, "
f"got '{dtype}'.")

dtype = NumpyType(dtype)
return (
self.copy(name_in_target="fma",
arg_id_to_dtype=constantdict(
{-1: dtype, 0: dtype, 1: dtype, 2: dtype})),
clbl_inf_ctx)

elif name in _CL_SIMPLE_MULTI_ARG_FUNCTIONS:
num_args = _CL_SIMPLE_MULTI_ARG_FUNCTIONS[name]
for id in arg_num_to_dtype:
Expand Down Expand Up @@ -480,7 +505,7 @@ def get_opencl_callables():
"acos", "acosh", "asin", "asinh", "atan", "atanh", "atan2",
"pow", "exp", "log", "log10", "sqrt", "ceil", "floor",
"max", "min", "fmax", "fmin",
"fabs", "erf", "erfc"}
"fabs", "erf", "erfc", "fma"}
| set(_CL_SIMPLE_MULTI_ARG_FUNCTIONS)
| set(VECTOR_LITERAL_FUNCS))

Expand Down
1 change: 0 additions & 1 deletion loopy/target/pyopencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from constantdict import constantdict
from typing_extensions import override

import genpy
import pymbolic.primitives as p
from cgen import (
Block,
Expand Down
54 changes: 54 additions & 0 deletions test/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,60 @@ def test_argmax_ctarget_floating_point():
assert out_dict["max_ind"][0] == 2


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("target_t", [lp.PyOpenCLTarget, lp.ExecutableCTarget])
def test_fma_correctness(ctx_factory: cl.CtxFactory, dtype, target_t):
ctx = ctx_factory()
cq = cl.CommandQueue(ctx)
target = target_t()

n = 1729
rng = np.random.default_rng(seed=42)

a = rng.random(n, dtype=dtype)
b = rng.random(n, dtype=dtype)
c = rng.random(n, dtype=dtype)

knl = lp.make_kernel(
"{[i]: 0<=i<n}",
"out[i] = fma(a[i], b[i], c[i])",
[lp.GlobalArg("a,b,c,out", dtype, shape="n"),
lp.ValueArg("n", np.int32)],
target=target,
)

if isinstance(target, lp.PyOpenCLTarget):
_, (out,) = knl(cq, a=a, b=b, c=c, n=n)
else:
assert isinstance(target, lp.ExecutableCTarget)
_, (out,) = knl(a=a, b=b, c=c, n=n)

np.testing.assert_allclose(out, a*b + c,
rtol=1e-5 if dtype == np.float32 else 1e-12)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_fma_codegen(dtype):
targets_and_expected = [
(lp.CTarget(), "fmaf(" if dtype == np.float32 else "fma("),
(lp.OpenCLTarget(), "fma("),
(lp.CudaTarget(), "fmaf(" if dtype == np.float32 else "fma("),
]

for target, expected_substr in targets_and_expected:
knl = lp.make_kernel(
"{[i]: 0<=i<n}",
"out[i] = fma(a[i], b[i], c[i])",
[lp.GlobalArg("a,b,c,out", dtype, shape="n"),
lp.ValueArg("n", np.int32)],
target=target,
)
code = lp.generate_code_v2(knl).device_code()
assert expected_substr in code, (
f"{target}, {dtype}: expected '{expected_substr}' in code"
)


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down
Loading