diff --git a/dagrt/builtins_python.py b/dagrt/builtins_python.py index ab8eca8..7904f5f 100644 --- a/dagrt/builtins_python.py +++ b/dagrt/builtins_python.py @@ -73,6 +73,16 @@ def builtin_array(n): return np.empty(n, dtype=np.float64) +def builtin_array_utype(n, x): + import numpy as np + if n != np.floor(n): + raise ValueError("array() argument n is not an integer") + n = int(n) + + #return np.empty((n, x.size), dtype=x.dtype) + return np.zeros((n, x.size), dtype=x.dtype) + + def builtin_matmul(a, b, a_cols, b_cols): import numpy as np if a_cols != np.floor(a_cols): @@ -148,6 +158,7 @@ def builtin_print(arg): "dot_product": builtin_dot_product, "elementwise_abs": builtin_elementwise_abs, "array": builtin_array, + "array_utype": builtin_array_utype, "matmul": builtin_matmul, "transpose": builtin_transpose, "linear_solve": builtin_linear_solve, diff --git a/dagrt/codegen/expressions.py b/dagrt/codegen/expressions.py index 8a8b8d2..fa5e0c3 100644 --- a/dagrt/codegen/expressions.py +++ b/dagrt/codegen/expressions.py @@ -123,6 +123,24 @@ def map_logical_and(self, expr, enclosing_prec): " .and. ", expr.children, PREC_LOGICAL_AND), enclosing_prec, PREC_LOGICAL_AND) + _comparison_to_fortran = { + "==": ".eq.", + "!=": ".ne.", + "<": ".lt.", + ">": ".gt.", + "<=": ".le.", + ">=": ".ge.", + } + + def map_comparison(self, expr, enclosing_prec, *args, **kwargs): + from pymbolic.mapper.stringifier import PREC_COMPARISON + return self.parenthesize_if_needed( + self.format("%s %s %s", + self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), + self._comparison_to_fortran[expr.operator], + self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)), + enclosing_prec, PREC_COMPARISON) + # }}} diff --git a/dagrt/codegen/fortran.py b/dagrt/codegen/fortran.py index 63a5c32..9b7937e 100644 --- a/dagrt/codegen/fortran.py +++ b/dagrt/codegen/fortran.py @@ -29,7 +29,7 @@ from dagrt.codegen.expressions import FortranExpressionMapper from dagrt.codegen.codegen_base import StructuredCodeGenerator from dagrt.utils import is_state_variable -from dagrt.data import UserType +from dagrt.data import UserType, UserTypeArray from pytools.py_codegen import ( # It's the same code. So sue me. PythonCodeGenerator as FortranEmitterBase) @@ -308,9 +308,18 @@ def find_sym_kind(self, expr): def transform(self, expr): raise NotImplementedError + def transform_utype_array(self, expr, identifier): + raise NotImplementedError + def map_variable(self, expr): if isinstance(self.find_sym_kind(expr), UserType): return self.transform(expr) + # We also need to catch UserTypeArrays here, and operate + # on their elements. + elif isinstance(self.find_sym_kind(expr), UserTypeArray): + identifier = self.find_sym_kind(expr).identifier + expr = self.transform_utype_array(expr, identifier) + return self.transform(expr) else: return expr @@ -327,6 +336,24 @@ def transform(self, expr): return expr.attr(self.component) +class UserTypeArrayAppender(UserTypeReferenceTransformer): + def __init__(self, code_generator, component): + super().__init__(code_generator) + self.component = component + + def map_variable(self, expr): + if isinstance(self.find_sym_kind(expr), UserTypeArray): + return self.transform(expr) + else: + return expr + + def transform(self, expr): + return expr.attr(self.component) + + map_lookup = map_variable + map_subscript = map_variable + + class ArraySubscriptAppender(UserTypeReferenceTransformer): def __init__(self, code_generator, subscript): super().__init__(code_generator) @@ -335,6 +362,9 @@ def __init__(self, code_generator, subscript): def transform(self, expr): return expr[self.subscript] + def transform_utype_array(self, expr, identifier): + return expr.attr(identifier) + # }}} @@ -987,6 +1017,10 @@ def __init__(self, module_name, self.current_function = None self.used = False + self.in_loop = 0 + self.deinit_in_loop = False + self.deinit_emitter = FortranEmitter() + # }}} # {{{ utilities @@ -1091,7 +1125,8 @@ def process_ast(ast, print_ast=False): forced_kinds=[ (fd.name, loop_var, Integer()) for fd in fdescrs - for loop_var in LoopVariableFinder()(fd.ast)]) + for loop_var in LoopVariableFinder()(fd.ast)], + user_type_map=self.user_type_map) from dagrt.codegen.analysis import ( collect_ode_component_names_from_dag, @@ -1278,6 +1313,30 @@ def begin_emit(self, dag): # }}} + # {{{ for usertype arrays, should they be required + + # Emit a derived type for each UserTypeArray that we have. + from dagrt.data import collect_user_type_arrays + usertype_arrays = collect_user_type_arrays(self.sym_kind_table) + for sym_id in usertype_arrays: + with FortranTypeEmitter( + self.emitter, + "dagrt_{}_array".format(sym_id), + self) as emit: + + self.emit_variable_decl( + self.name_manager.name_global(sym_id), + sym_kind=UserType(sym_id), is_argument=False, + refcount_name=self.name_manager.name_refcount( + sym_id, qualified_with_state=False)) + # Store the length of the UserTypeArray in the type itself. + from dagrt.data import Integer + self.emit_variable_decl( + "array_size", sym_kind=Integer(), + is_argument=False) + + # }}} + # {{{ state type with FortranTypeEmitter( @@ -1365,7 +1424,6 @@ def begin_emit(self, dag): # If the refcount is 1, then nobody else is referring to # the memory, and we might as well repurpose/overwrite it, # so there's nothing more to do in that case. - with FortranIfEmitter( self.emitter, "refcount.ne.1", self) as emit_if: @@ -1451,6 +1509,19 @@ def get_fortran_type_for_user_type(self, type_identifier, is_argument=False): return ftype + def get_fortran_type_for_user_type_array(self, type_identifier, + is_argument=False): + # The type is a structure type, which has two members: + # a UserType and the (integer) refcount. + ftype = StructureType("dagrt_{}_array".format(type_identifier), ( + (type_identifier, + self.get_fortran_type_for_user_type(type_identifier)), + (self.name_manager.name_refcount(type_identifier, + qualified_with_state=False), + PointerType(BuiltinType("integer"))))) + + return ftype + def emit_variable_decl(self, fortran_name, sym_kind, is_argument=False, other_specifiers=(), emit=None, refcount_name=None): @@ -1461,7 +1532,7 @@ def emit_variable_decl(self, fortran_name, sym_kind, type_specifiers = other_specifiers - from dagrt.data import UserType + from dagrt.data import UserType, UserTypeArray if isinstance(sym_kind, Boolean): type_name = "logical" @@ -1495,6 +1566,14 @@ def emit_variable_decl(self, fortran_name, sym_kind, "integer, pointer :: {refcount_name}".format( refcount_name=refcount_name)) + elif isinstance(sym_kind, UserTypeArray): + ftype = self.get_fortran_type_for_user_type_array(sym_kind.identifier, + is_argument=is_argument) + + # This is a 1D array of UserTypes. + type_name = "type(dagrt_{}_array)".format(sym_kind.identifier) + type_specifiers += ("allocatable, dimension(:)",) + else: raise ValueError("unknown variable kind: %s" % type(sym_kind).__name__) @@ -1518,20 +1597,52 @@ def emit_variable_deinit(self, name, sym_kind): fortran_name = self.name_manager[name] refcnt_name = self.name_manager.name_refcount(name) - from dagrt.data import UserType - if not isinstance(sym_kind, UserType): + if isinstance(sym_kind, UserType): + self.emit( + "call {var_deinit_name}({args})" + .format( + var_deinit_name=self.get_var_deinit_name( + sym_kind.identifier), + args=", ".join( + self.extra_arguments + + (fortran_name, refcnt_name)) + )) + elif isinstance(sym_kind, UserTypeArray): + # For UserTypeArrays we need to loop through and deinit as well. + with FortranIfEmitter( + self.emitter, "allocated({})".format( + fortran_name), self): + ident = self.name_manager.make_unique_fortran_name("uarray_i") + self.declaration_emitter("integer %s" % ident) + em = FortranDoEmitter( + self.emitter, ident, + "0, {}".format( + fortran_name + "(0)%array_size-1"), + self) + em.__enter__() + uarray_entry = fortran_name + \ + "(int({}))".format(ident) + uarray_entry += "%{}".format(sym_kind.identifier) + refcnt_name = fortran_name \ + + "(int({}))".format(ident) + refcnt_name += "%{}".format( + self.name_manager.name_refcount( + sym_kind.identifier, + qualified_with_state=False)) + self.emit( + "call {var_deinit_name}({args})" + .format( + var_deinit_name=self.get_var_deinit_name( + sym_kind.identifier), + args=", ".join( + self.extra_arguments + + (uarray_entry, refcnt_name)) + )) + self.emitter.__exit__(None, None, None) + self.emit("deallocate({})".format(fortran_name)) + else: return - self.emit( - "call {var_deinit_name}({args})" - .format( - var_deinit_name=self.get_var_deinit_name( - sym_kind.identifier), - args=", ".join( - self.extra_arguments - + (fortran_name, refcnt_name)) - )) - def emit_refcounted_allocation(self, sym, sym_kind): fortran_name = self.name_manager[sym] @@ -1569,26 +1680,131 @@ def emit_allocation_check(self, sym, sym_kind): )) def emit_user_type_move(self, assignee_sym, assignee_fortran_name, - sym_kind, expr): - self.emit_variable_deinit(assignee_sym, sym_kind) + sym_kind, expression, tgt_refcnt, refcnt, deinit=True): + if deinit: + self.emit_variable_deinit(assignee_sym, sym_kind) self.emit_traceable( "{name} => {expr}" .format( name=assignee_fortran_name, - expr=self.name_manager[expr.name])) + expr=expression)) self.emit_traceable( "{tgt_refcnt} => {refcnt}" .format( - tgt_refcnt=self.name_manager.name_refcount(assignee_sym), - refcnt=self.name_manager.name_refcount(expr.name))) + tgt_refcnt=tgt_refcnt, + refcnt=refcnt)) self.emit_traceable( "{tgt_refcnt} = {tgt_refcnt} + 1" .format( - tgt_refcnt=self.name_manager.name_refcount(assignee_sym))) + tgt_refcnt=tgt_refcnt)) self.emit("") - def emit_assign_expr_inner(self, + def emit_user_type_array_assignment(self, assignee_sym, assignee_fortran_name, + sym_kind, expr, subscript_str): + if (subscript_str or isinstance(expr, (Subscript, Lookup))): + transformer = UserTypeArrayAppender( + self, sym_kind.identifier) + expression = self.expr(transformer(expr)) + else: + if isinstance(expr, Variable): + expr_kind = self.sym_kind_table.get(self.current_function, + expr.name) + expression = self.expr(expr) + if isinstance(expr_kind, UserTypeArray): + # We need to loop through the UserTypeArray here, + # and attach the structure entrance to both the LHS and + # the RHS... + # For temps, we also need to check if the assignee is + # allocated, and if not, allocate it. + with FortranIfEmitter( + self.emitter, ".not.allocated({})".format( + assignee_fortran_name), self): + self.emit("allocate({}(0:{}-1))".format( + assignee_fortran_name, + self.expr(expr) + "(0)%array_size")) + alloc_check_name = self.get_alloc_check_name( + sym_kind.identifier) + ident = self.name_manager.make_unique_fortran_name( + "uarray_i") + self.declaration_emitter("integer %s" % ident) + em = FortranDoEmitter( + self.emitter, ident, + "0, {}".format( + self.expr(expr) + "(0)%array_size-1"), + self) + em.__enter__() + uarray_entry = assignee_fortran_name + \ + "(int({}))".format(ident) + self.emit( + "{out}%array_size = {existing}".format( + out=uarray_entry, + existing=self.expr(expr) + + "(0)%array_size")) + uarray_entry += "%{}".format(sym_kind.identifier) + refcnt_name = assignee_fortran_name \ + + "(int({}))".format(ident) + refcnt_name += "%{}".format( + self.name_manager.name_refcount( + sym_kind.identifier, + qualified_with_state=False)) + self.emit("allocate({})".format(refcnt_name)) + self.emit( + "call {alloc_check_name}({args})" + .format( + alloc_check_name=alloc_check_name, + args=", ".join( + self.extra_arguments + + (uarray_entry, refcnt_name)) + )) + self.emitter.__exit__(None, None, None) + ident = self.name_manager.make_unique_fortran_name( + "uarray_i") + expression = self.expr(expr) + \ + "({})%".format(ident) + expr_kind.identifier + subscript_str += "({})%".format(ident) \ + + expr_kind.identifier + tgt_refcnt_name = assignee_fortran_name \ + + "(int({}))".format(ident) + tgt_refcnt_name += "%{}".format( + self.name_manager.name_refcount( + sym_kind.identifier, + qualified_with_state=False)) + refcnt_name = self.expr(expr) \ + + "(int({}))".format(ident) + refcnt_name += "%{}".format( + self.name_manager.name_refcount( + sym_kind.identifier, + qualified_with_state=False)) + self.declaration_emitter("integer %s" % ident) + em = FortranDoEmitter( + self.emitter, ident, + "0, {}".format( + self.expr(expr) + "(0)%array_size-1"), + self) + em.__enter__() + # Per-element user type moves. + assignee_loop_name = assignee_fortran_name \ + + subscript_str + self.emit_user_type_move(assignee_sym, + assignee_loop_name, + sym_kind, expression, + tgt_refcnt_name, refcnt_name, + deinit=False) + self.emitter.__exit__(None, None, None) + self.emit("") + return + + self.emit( + "{name}{subscript_str} = {expr}" + .format( + name=assignee_fortran_name, + subscript_str=subscript_str, + expr=expression)) + + self.emit("") + + def emit_assign_expr_inner(self, assignee_sym, assignee_fortran_name, assignee_subscript, expr, sym_kind): if assignee_subscript: subscript_str = "(%s)" % ( @@ -1598,11 +1814,18 @@ def emit_assign_expr_inner(self, else: subscript_str = "" + # Special treatment for usertype arrays needed here, + # since usertype arrays are actually intermediate structures. + if isinstance(sym_kind, UserTypeArray): + if assignee_subscript: + subscript_str += "%" + sym_kind.identifier + + expression = self.expr(expr) + if isinstance(expr, (Call, CallWithKwargs)): # These are supposed to have been transformed to AssignFunctionCall. raise RuntimeError("bare Call/CallWithKwargs encountered in " "Fortran code generator") - else: self.emit_trace("{assignee_fortran_name}{subscript_str} = {expr}..." .format( @@ -1610,19 +1833,24 @@ def emit_assign_expr_inner(self, subscript_str=subscript_str, expr=str(expr)[:50])) - from dagrt.data import UserType - if not isinstance(sym_kind, UserType): - self.emit( - "{name}{subscript_str} = {expr}" - .format( - name=assignee_fortran_name, - subscript_str=subscript_str, - expr=self.expr(expr))) - else: + if isinstance(sym_kind, UserType): ftype = self.get_fortran_type_for_user_type(sym_kind.identifier) AssignmentEmitter(self)( ftype, assignee_fortran_name, {}, expr, is_rhs_target=True) + self.emit("") + return + elif isinstance(sym_kind, UserTypeArray): + self.emit_user_type_array_assignment(assignee_sym, + assignee_fortran_name, + sym_kind, expr, subscript_str) + return + self.emit( + "{name}{subscript_str} = {expr}" + .format( + name=assignee_fortran_name, + subscript_str=subscript_str, + expr=expression)) self.emit("") @@ -2089,10 +2317,18 @@ def emit_for_begin(self, loop_var_name, lbound, ubound): self.expr(lbound), self.expr(ubound-1)), code_generator=self) + self.in_loop += 1 em.__enter__() def emit_for_end(self, loop_var_name): self.emitter.__exit__(None, None, None) + self.in_loop -= 1 + # if we are out of all loops, emitter-juggle to + # perform any deinits we have hanging around + if self.deinit_in_loop: + if self.in_loop == 0: + self.emitter.incorporate(self.deinit_emitter) + self.deinit_in_loop = False def emit_assign_expr(self, assignee_sym, assignee_subscript, expr): from dagrt.data import UserType, Array @@ -2102,22 +2338,47 @@ def emit_assign_expr(self, assignee_sym, assignee_subscript, expr): sym_kind = self.sym_kind_table.get( self.current_function, assignee_sym) - if assignee_subscript and not isinstance(sym_kind, Array): - raise TypeError("only arrays support subscripted assignment") + if (assignee_subscript and not isinstance(sym_kind, Array) + and not isinstance(sym_kind, UserTypeArray)): + raise TypeError("only arrays and UserTypeArrays support" + " subscripted assignment") return if not isinstance(sym_kind, UserType): self.emit_assign_expr_inner( - assignee_fortran_name, assignee_subscript, expr, sym_kind) + assignee_sym, assignee_fortran_name, assignee_subscript, + expr, sym_kind) return if assignee_subscript: raise ValueError("User types do not support subscripting") + # Incorporate possibility for assigning UserTypeArray elements to + # UserTypes. + tgt_refcnt = self.name_manager.name_refcount(assignee_sym) if isinstance(expr, Variable): + expression = self.name_manager[expr.name] + refcnt = self.name_manager.name_refcount(expr.name) self.emit_user_type_move( - assignee_sym, assignee_fortran_name, sym_kind, expr) + assignee_sym, assignee_fortran_name, sym_kind, + expression, tgt_refcnt, refcnt) return + elif isinstance(expr, (Subscript, Lookup)): + expr_kind = self.sym_kind_table.get(self.current_function, + expr.aggregate.name) + if isinstance(expr_kind, UserTypeArray): + transformer = UserTypeArrayAppender( + self, sym_kind.identifier) + expression = self.expr(transformer(expr)) + refcnt = self.expr(expr) \ + + "%{}".format( + self.name_manager.name_refcount( + sym_kind.identifier, + qualified_with_state=False)) + self.emit_user_type_move( + assignee_sym, assignee_fortran_name, sym_kind, + expression, tgt_refcnt, refcnt) + return from pymbolic import var from pymbolic.mapper.dependency import DependencyMapper @@ -2129,7 +2390,8 @@ def emit_assign_expr(self, assignee_sym, assignee_subscript, expr): self.emit_allocation_check(assignee_sym, sym_kind) self.emit_assign_expr_inner( - assignee_fortran_name, assignee_subscript, expr, sym_kind) + assignee_sym, assignee_fortran_name, assignee_subscript, + expr, sym_kind) def lower_inst(self, inst): """Emit the code for an statement.""" @@ -2225,6 +2487,37 @@ def emit_inst_AssignFunctionCall(self, inst): + assignee_fortran_names ))) + # If we just built a UserTypeArray, we need to loop and call the + # appropriate allocation check on the elements. + if "array_utype" in fortran_func_name: + alloc_check_name = self.get_alloc_check_name(sym_kind.identifier) + ident = self.name_manager.make_unique_fortran_name("uarray_i") + self.declaration_emitter("integer %s" % ident) + em = FortranDoEmitter( + self.emitter, ident, + "0, {}".format( + assignee_fortran_names[0] + "(0)%array_size-1"), + self) + em.__enter__() + uarray_entry = assignee_fortran_names[0] + "(int({}))".format(ident) + uarray_entry += "%{}".format(sym_kind.identifier) + refcnt_name = assignee_fortran_names[0] + "(int({}))".format(ident) + refcnt_name += "%{}".format(self.name_manager.name_refcount( + sym_kind.identifier, qualified_with_state=False)) + # Ensure allocations will be performed by setting refcounts to + # non-unity. + self.emit("allocate({})".format(refcnt_name)) + self.emit("{} = 2".format(refcnt_name)) + self.emit( + "call {alloc_check_name}({args})" + .format( + alloc_check_name=alloc_check_name, + args=", ".join( + self.extra_arguments + + (uarray_entry, refcnt_name)) + )) + self.emitter.__exit__(None, None, None) + self.emit_deinit_for_last_usage_of_vars(inst) # }}} @@ -2287,7 +2580,18 @@ def emit_deinit_for_last_usage_of_vars(self, inst): last_used_stmt_id = self.last_used_stmt_table[ variable, self.current_function] if inst.id == last_used_stmt_id and not is_state_variable(variable): - self.emit_variable_deinit(variable, var_kind) + # Check if we are in a loop or not. + if self.in_loop == 0: + self.emit_variable_deinit(variable, var_kind) + else: + # If we are in a loop, we need to pop the deinit outside + # of it. + self.deinit_emitter = FortranEmitter() + self.emitters.append(self.deinit_emitter) + self.emit_variable_deinit(variable, var_kind) + self.emit("") + self.deinit_in_loop = True + self.emitters.pop() def emit_inst_Raise(self, inst): # FIXME: Reenable emitting full error message @@ -2490,6 +2794,27 @@ def codegen_builtin_isnan(results, function, args, arg_kinds, """) +builtin_array_utype = CallCode(""" + <% + i = declare_new("integer", "i") + %> + if (int(${n}).ne.${n}) then + write(dagrt_stderr,*) 'argument to array_utype() is not an integer' + stop + endif + + if (allocated(${result})) then + deallocate(${result}) + endif + + allocate(${result}(0:int(${n})-1)) + + do ${i} = 0, int(${n})-1 + ${result}(int(${i}))%array_size = int(${n}) + end do + """) + + UTIL_MACROS = """ <%def name="write_matrix(mat_array, rows_var)" > <% @@ -2545,11 +2870,12 @@ def kind_to_fortran(kind): if kind.is_real_valued: return "real (kind=%s)" % real_scalar_kind else: - return "compelx (kind=%s)" % complex_scalar_kind + return "complex (kind=%s)" % complex_scalar_kind %> """ + builtin_matmul = CallCode(UTIL_MACROS + """ <% a_rows = declare_new("integer", "a_rows") diff --git a/dagrt/data.py b/dagrt/data.py index 3df73fe..6d817bb 100644 --- a/dagrt/data.py +++ b/dagrt/data.py @@ -46,6 +46,7 @@ .. autoclass:: Scalar .. autoclass:: Array .. autoclass:: UserType +.. autoclass:: UserTypeArray Symbol kind inference ^^^^^^^^^^^^^^^^^^^^^ @@ -59,6 +60,7 @@ .. autofunction:: infer_kinds .. autofunction:: collect_user_types +.. autofunction:: collect_user_type_arrays """ @@ -133,6 +135,22 @@ def __getinitargs__(self): return (self.is_real_valued,) +class UserTypeArray(SymbolKind): + """A variable-sized one-dimensional array + of user types. + + .. attribute:: identifier + + A unique identifier for this type. + """ + + def __init__(self, identifier): + super().__init__(identifier=identifier) + + def __getinitargs__(self): + return (self.identifier,) + + class UserType(SymbolKind): """Represents user state belonging to a normed vector space. @@ -429,13 +447,16 @@ def map_max(self, expr): def map_subscript(self, expr): agg_kind = self.rec(expr.aggregate) - if self.check and not isinstance(agg_kind, Array): - raise ValueError( - "only arrays can be subscripted, not '%s' " - "which is a '%s'" - % (expr.aggregate, type(agg_kind).__name__)) - - return Scalar(is_real_valued=agg_kind.is_real_valued) + if isinstance(agg_kind, Array): + return Scalar(is_real_valued=agg_kind.is_real_valued) + elif isinstance(agg_kind, UserTypeArray): + return UserType(agg_kind.identifier) + else: + if self.check: + raise ValueError( + "only arrays or UserTypeArrays can be " + "subscripted, not '%s' which is a '%s'" + % (expr.aggregate, type(agg_kind).__name__)) # }}} @@ -449,7 +470,7 @@ class SymbolKindFinder: def __init__(self, function_registry): self.function_registry = function_registry - def __call__(self, names, phases, forced_kinds=None): + def __call__(self, names, phases, forced_kinds=None, user_type_map=None): """Infer the kinds of all the symbols in a program. :arg names: a list of phase names @@ -472,6 +493,15 @@ def __call__(self, names, phases, forced_kinds=None): for phase_name, ident, kind in forced_kinds: result.set(phase_name, ident, kind=kind) + # If a UserType map is given, set the global symbol + # kind table accordingly. + # FIXME: are UserType identifiers guaranteed to + # match component_ids? + if user_type_map is not None: + for name in user_type_map: + result.set(names[0], "{}".format(name), + UserType(identifier=name)) + def make_kim(phase_name, check): return KindInferenceMapper( result.global_table, @@ -655,4 +685,28 @@ def collect_user_types(skt): # }}} +# {{{ collect user types + + +def collect_user_type_arrays(skt): + """Collect all of the of :class:`UserTypeArray` identifiers in a table. + + :arg skt: a :class:`SymbolKindTable` + :returns: a set of strings + """ + result = set() + + for kind in skt.global_table.values(): + if isinstance(kind, UserTypeArray): + result.add(kind.identifier) + + for tbl in skt.per_phase_table.values(): + for kind in tbl.values(): + if isinstance(kind, UserTypeArray): + result.add(kind.identifier) + + return result + +# }}} + # vim: foldmethod=marker diff --git a/dagrt/function_registry.py b/dagrt/function_registry.py index ef1c857..4c3ab6c 100644 --- a/dagrt/function_registry.py +++ b/dagrt/function_registry.py @@ -26,7 +26,7 @@ from pytools import RecordWithoutPickling from dagrt.data import ( - UserType, Integer, Boolean, Scalar, Array, UnableToInferKind) + UserType, Integer, Boolean, Scalar, Array, UserTypeArray, UnableToInferKind) NoneType = type(None) @@ -65,6 +65,7 @@ .. autoclass:: Len .. autoclass:: IsNaN .. autoclass:: Array_ +.. autoclass:: ArrayUType_ .. autoclass:: MatMul .. autoclass:: Transpose .. autoclass:: LinearSolve @@ -407,6 +408,27 @@ def get_result_kinds(self, arg_kinds, check): return (Array(is_real_valued=True),) +class ArrayUType_(Function): # noqa + """``array_utype(n, x)`` returns an empty array with *n* entries in it, + in which each entry is a user type. *n* must be an integer. + """ + + result_names = ("result",) + identifier = "array_utype" + arg_names = ("n", "x") + default_dict = {} + + def get_result_kinds(self, arg_kinds, check): + n_kind, x_kind = self.resolve_args(arg_kinds) + + if check and not isinstance(n_kind, Scalar): + raise TypeError("argument 'n' of 'array_utype' is not a scalar") + if check and not isinstance(x_kind, UserType): + raise TypeError("argument 'x' of 'array_utype' is not a user type") + + return (UserTypeArray(identifier=x_kind.identifier),) + + class MatMul(Function): """``matmul(a, b, a_cols, b_cols)`` returns a 1D array containing the matrix resulting from multiplying the arrays *a* and *b* (both interpreted @@ -571,6 +593,7 @@ def _make_bfr(): (Len(), "{numpy}.size({args})"), (IsNaN(), "{numpy}.isnan({args})"), (Array_(), "self._builtin_array({args})"), + (ArrayUType_(), "self._builtin_array_utype({args})"), (MatMul(), "self._builtin_matmul({args})"), (Transpose(), "self._builtin_transpose({args})"), (LinearSolve(), "self._builtin_linear_solve({args})"), @@ -597,6 +620,8 @@ def _make_bfr(): f.codegen_builtin_isnan) bfr = bfr.register_codegen(Array_.identifier, "fortran", f.builtin_array) + bfr = bfr.register_codegen(ArrayUType_.identifier, "fortran", + f.builtin_array_utype) bfr = bfr.register_codegen(MatMul.identifier, "fortran", f.builtin_matmul) bfr = bfr.register_codegen(Transpose.identifier, "fortran", diff --git a/test/test_codegen_fortran.py b/test/test_codegen_fortran.py index a83b89d..4d0e5fd 100755 --- a/test/test_codegen_fortran.py +++ b/test/test_codegen_fortran.py @@ -206,6 +206,43 @@ def test_elementwise_abs(): fortran_libraries=["lapack", "blas"]) +def test_usertype_arrays(): + with CodeBuilder(name="primary") as cb: + cb("y_array", "array_utype(5, ytype)") + cb("y_array[i]", "ytype", + loops=(("i", 0, 5),)) + cb("y_array[i]", "i*f(0, y_array[i])", + loops=(("i", 0, 5),)) + + code = create_DAGCode_with_steady_phase(cb.statements) + + rhs_function = "f" + + from dagrt.function_registry import ( + base_function_registry, register_ode_rhs) + freg = register_ode_rhs(base_function_registry, "ytype", + identifier=rhs_function, + input_names=("y",)) + freg = freg.register_codegen(rhs_function, "fortran", + f.CallCode(""" + ${result} = -2*${y} + """)) + + codegen = f.CodeGenerator( + "utype_arrays", + function_registry=freg, + user_type_map={"ytype": f.ArrayType((100,), f.BuiltinType("real*8"))}, + timing_function="second") + + code_str = codegen(code) + + run_fortran([ + ("utype_arrays.f90", code_str), + ("test_utype_arrays.f90", read_file("test_utype_arrays.f90")), + ], + fortran_libraries=["lapack", "blas"]) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) diff --git a/test/test_utype_arrays.f90 b/test/test_utype_arrays.f90 new file mode 100644 index 0000000..5eb10e5 --- /dev/null +++ b/test/test_utype_arrays.f90 @@ -0,0 +1,31 @@ +program test_utype_arrays + + use utype_arrays, only: dagrt_state_type, & + timestep_initialize => initialize, & + timestep_run => run, & + timestep_shutdown => shutdown + + implicit none + + type(dagrt_state_type), target :: dagrt_state + type(dagrt_state_type), pointer :: dagrt_state_ptr + + real*8, dimension(100) :: y0 + + integer i + + ! start code ---------------------------------------------------------------- + + dagrt_state_ptr => dagrt_state + + + do i = 1, 100 + y0(i) = i + end do + + call timestep_initialize(dagrt_state=dagrt_state_ptr, state_ytype=y0) + call timestep_run(dagrt_state=dagrt_state_ptr) + call timestep_shutdown(dagrt_state=dagrt_state_ptr) + +end program +