diff --git a/dagrt/codegen/fortran.py b/dagrt/codegen/fortran.py index 9361a59..3327855 100644 --- a/dagrt/codegen/fortran.py +++ b/dagrt/codegen/fortran.py @@ -2407,6 +2407,42 @@ def codegen_builtin_len(results, function, args, arg_kinds, code_generator.emit("") +class AbsComputer(TypeVisitorWithResult): + def visit_BuiltinType(self, fortran_type, fortran_expr_str, index_expr_map): + self.code_generator.emit( + "{result} = abs({result})" + .format( + result=self.result_expr)) + + +def codegen_builtin_elementwise_abs(results, function, args, arg_kinds, + code_generator): + result, = results + + from dagrt.data import Scalar, Array, UserType + x_kind = arg_kinds[0] + if isinstance(x_kind, Scalar): + if x_kind.is_real_valued: + ftype = BuiltinType("real*8") + else: + ftype = BuiltinType("complex*16") + elif isinstance(x_kind, UserType): + ftype = code_generator.user_type_map[x_kind.identifier] + elif isinstance(x_kind, Array): + code_generator.emit("{result} = abs({arg})".format( + result=result, + arg=args[0])) + return + else: + raise TypeError("unsupported kind for elementwise_abs argument: %s" % x_kind) + + code_generator.emit(f"{result} = 0") + code_generator.emit("") + + AbsComputer(code_generator, result)(ftype, args[0], {}) + code_generator.emit("") + + class IsNaNComputer(TypeVisitorWithResult): def visit_BuiltinType(self, fortran_type, fortran_expr_str, index_expr_map): self.code_generator.emit( diff --git a/dagrt/data.py b/dagrt/data.py index 298a10c..3df73fe 100644 --- a/dagrt/data.py +++ b/dagrt/data.py @@ -371,7 +371,12 @@ def map_generic_call(self, function_id, arg_dict, single_return_only=True): except UnableToInferKind: arg_kinds[key] = None - z = func.get_result_kinds(arg_kinds, self.check) + try: + z = func.get_result_kinds(arg_kinds, self.check) + except Exception: + raise UnableToInferKind( + "function '%s' needs more info about arguments" + % function_id) if single_return_only: if len(z) != 1: diff --git a/dagrt/function_registry.py b/dagrt/function_registry.py index 409fd90..ef1c857 100644 --- a/dagrt/function_registry.py +++ b/dagrt/function_registry.py @@ -60,6 +60,7 @@ .. autoclass:: Norm1 .. autoclass:: Norm2 .. autoclass:: NormInf +.. autoclass:: ElementwiseAbs .. autoclass:: DotProduct .. autoclass:: Len .. autoclass:: IsNaN @@ -300,6 +301,32 @@ class NormInf(_NormBase): identifier = "norm_inf" +class ElementwiseAbs(Function): + """``elementwise_abs(x)`` takes the elementwise absolute value of *x*. + *x* is a user type, array, or scalar. + """ + + result_names = ("result",) + identifier = "elementwise_abs" + arg_names = ("x") + default_dict = {} + + def get_result_kinds(self, arg_kinds, check): + x_kind, = self.resolve_args(arg_kinds) + + if check and not isinstance(x_kind, (Scalar, Array, UserType)): + raise TypeError("argument 'x' of 'elementwise_abs' is not a user type") + + if isinstance(x_kind, UserType): + return (UserType(identifier=x_kind.identifier),) + elif isinstance(x_kind, Array): + return (Array(is_real_valued=True),) + elif isinstance(x_kind, Scalar): + return (Scalar(is_real_valued=True),) + else: + raise TypeError("argument 'x' of 'elementwise_abs' undetermined") + + class DotProduct(Function): """``dot_product(x, y)`` return the dot product of *x* and *y*. The complex conjugate of *x* is taken first, if applicable. @@ -539,6 +566,7 @@ def _make_bfr(): (Norm1(), "self._builtin_norm_1({args})"), (Norm2(), "self._builtin_norm_2({args})"), (NormInf(), "self._builtin_norm_inf({args})"), + (ElementwiseAbs(), "{numpy}.abs({args})"), (DotProduct(), "{numpy}.vdot({args})"), (Len(), "{numpy}.size({args})"), (IsNaN(), "{numpy}.isnan({args})"), @@ -561,6 +589,8 @@ def _make_bfr(): bfr = bfr.register_codegen(Norm2.identifier, "fortran", f.codegen_builtin_norm_2) + bfr = bfr.register_codegen(ElementwiseAbs.identifier, "fortran", + f.codegen_builtin_elementwise_abs) bfr = bfr.register_codegen(Len.identifier, "fortran", f.codegen_builtin_len) bfr = bfr.register_codegen(IsNaN.identifier, "fortran", diff --git a/test/test_codegen_fortran.py b/test/test_codegen_fortran.py index 420efd3..9d3fd3d 100755 --- a/test/test_codegen_fortran.py +++ b/test/test_codegen_fortran.py @@ -150,6 +150,49 @@ def test_self_dep_in_loop(): fortran_libraries=["lapack", "blas"]) +def test_elementwise_abs(): + with CodeBuilder(name="primary") as cb: + cb("y", "f(0, ytype)") + cb("ytype", "y") + # Test new builtin on a usertype. + cb("z", "elementwise_abs(ytype)") + cb("i", "array(20)") + cb("i[j]", "-j", + loops=(("j", 0, 20),)) + # Test new builtin on an array type. + cb("k", "elementwise_abs(i)") + # Test new builtin on a scalar. + cb("l", "elementwise_abs(-20)") + + code = create_DAGCode_with_steady_phase(cb.statements) + + rhs_function = "f" + + from dagrt.function_registry import ( + base_function_registry, register_ode_rhs) + freg = register_ode_rhs(base_function_registry, "ytype", + identifier=rhs_function, + input_names=("y",)) + freg = freg.register_codegen(rhs_function, "fortran", + f.CallCode(""" + ${result} = -2*${y} + """)) + + codegen = f.CodeGenerator( + "element_abs_test", + function_registry=freg, + user_type_map={"ytype": f.ArrayType((100,), f.BuiltinType("real*8"))}, + timing_function="second") + + code_str = codegen(code) + + run_fortran([ + ("element_abs.f90", code_str), + ("test_element_abs.f90", read_file("test_element_abs.f90")), + ], + fortran_libraries=["lapack", "blas"]) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) diff --git a/test/test_element_abs.f90 b/test/test_element_abs.f90 new file mode 100644 index 0000000..a5129a4 --- /dev/null +++ b/test/test_element_abs.f90 @@ -0,0 +1,31 @@ +program test_element_abs + + use element_abs_test, only: dagrt_state_type, & + timestep_initialize => initialize, & + timestep_run => run, & + timestep_shutdown => shutdown + + implicit none + + type(dagrt_state_type), target :: dagrt_state + type(dagrt_state_type), pointer :: dagrt_state_ptr + + real*8, dimension(100) :: y0 + + integer i + + ! start code ---------------------------------------------------------------- + + dagrt_state_ptr => dagrt_state + + + do i = 1, 100 + y0 = i + end do + + call timestep_initialize(dagrt_state=dagrt_state_ptr, state_ytype=y0) + call timestep_run(dagrt_state=dagrt_state_ptr) + call timestep_shutdown(dagrt_state=dagrt_state_ptr) + +end program +