diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 7ba618042..c631e45d5 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -21553,22 +21553,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 40, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 40, - "endColumn": 45, - "lineCount": 1 - } - }, { "code": "reportArgumentType", "range": { @@ -21585,14 +21569,6 @@ "lineCount": 3 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 38, - "endColumn": 43, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -30781,14 +30757,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 29, - "endColumn": 61, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -30917,14 +30885,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 27, - "endColumn": 59, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -79651,14 +79611,6 @@ } ], "./loopy/transform/precompute.py": [ - { - "code": "reportPrivateLocalImportUsage", - "range": { - "startColumn": 4, - "endColumn": 14, - "lineCount": 1 - } - }, { "code": "reportMissingTypeArgument", "range": { @@ -80692,962 +80644,346 @@ } }, { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 58, - "endColumn": 74, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 58, - "endColumn": 74, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 25, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportMissingTypeArgument", - "range": { - "startColumn": 29, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 23, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 45, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 45, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 57, - "endColumn": 66, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 57, - "endColumn": 66, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 68, - "endColumn": 74, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 68, - "endColumn": 74, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 12, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 12, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 32, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 12, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 12, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 32, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 12, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 12, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 12, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 12, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 28, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 28, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 45, - "endColumn": 59, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 45, - "endColumn": 59, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 12, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 12, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 25, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 8, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 29, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 29, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 35, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 35, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 40, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 40, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 51, - "endColumn": 61, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 51, - "endColumn": 61, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 20, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 20, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 20, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 47, - "endColumn": 61, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 20, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 26, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 31, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 42, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 20, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 42, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 53, - "endColumn": 75, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 53, - "endColumn": 75, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 35, - "endColumn": 52, - "lineCount": 2 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 20, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 15, - "endColumn": 68, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 20, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 26, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 31, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 42, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 19, - "endColumn": 28, + "startColumn": 58, + "endColumn": 74, "lineCount": 1 } }, { "code": "reportUnknownArgumentType", "range": { - "startColumn": 16, - "endColumn": 39, + "startColumn": 58, + "endColumn": 74, "lineCount": 1 } }, { - "code": "reportUnknownArgumentType", + "code": "reportArgumentType", "range": { - "startColumn": 16, - "endColumn": 41, + "startColumn": 25, + "endColumn": 29, "lineCount": 1 } }, { - "code": "reportUnknownMemberType", + "code": "reportUnknownParameterType", "range": { "startColumn": 16, - "endColumn": 40, + "endColumn": 34, "lineCount": 1 } }, { - "code": "reportUnknownArgumentType", + "code": "reportMissingParameterType", "range": { "startColumn": 16, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 31, - "endColumn": 41, + "endColumn": 34, "lineCount": 1 } }, { - "code": "reportUnknownArgumentType", + "code": "reportUnknownParameterType", "range": { - "startColumn": 40, - "endColumn": 63, + "startColumn": 16, + "endColumn": 30, "lineCount": 1 } }, { - "code": "reportUnknownMemberType", + "code": "reportMissingParameterType", "range": { - "startColumn": 12, - "endColumn": 33, + "startColumn": 16, + "endColumn": 30, "lineCount": 1 } }, { - "code": "reportUnknownArgumentType", + "code": "reportUnknownParameterType", "range": { - "startColumn": 29, - "endColumn": 48, + "startColumn": 16, + "endColumn": 36, "lineCount": 1 } }, { - "code": "reportUnknownArgumentType", + "code": "reportMissingParameterType", "range": { - "startColumn": 50, - "endColumn": 64, + "startColumn": 16, + "endColumn": 36, "lineCount": 1 } }, { - "code": "reportUnannotatedClassAttribute", + "code": "reportUnknownParameterType", "range": { - "startColumn": 13, - "endColumn": 31, + "startColumn": 16, + "endColumn": 39, "lineCount": 1 } }, { - "code": "reportUninitializedInstanceVariable", + "code": "reportMissingParameterType", "range": { - "startColumn": 13, - "endColumn": 31, + "startColumn": 16, + "endColumn": 39, "lineCount": 1 } }, { - "code": "reportUnknownMemberType", + "code": "reportUnknownParameterType", "range": { - "startColumn": 39, - "endColumn": 56, + "startColumn": 16, + "endColumn": 38, "lineCount": 1 } }, { - "code": "reportUnknownArgumentType", + "code": "reportMissingParameterType", "range": { - "startColumn": 39, - "endColumn": 56, + "startColumn": 16, + "endColumn": 38, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 53, - "endColumn": 70, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 53, - "endColumn": 70, + "startColumn": 8, + "endColumn": 31, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 53, - "endColumn": 75, + "startColumn": 8, + "endColumn": 27, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 53, - "endColumn": 78, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 53, - "endColumn": 78, + "startColumn": 8, + "endColumn": 33, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 12, - "endColumn": 40, + "startColumn": 8, + "endColumn": 36, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 12, - "endColumn": 44, + "startColumn": 8, + "endColumn": 35, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 48, - "endColumn": 65, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 48, - "endColumn": 65, + "startColumn": 8, + "endColumn": 36, "lineCount": 1 } }, { - "code": "reportUnknownMemberType", + "code": "reportArgumentType", "range": { - "startColumn": 48, - "endColumn": 70, - "lineCount": 1 + "startColumn": 35, + "endColumn": 52, + "lineCount": 2 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 48, - "endColumn": 73, + "startColumn": 20, + "endColumn": 45, "lineCount": 1 } }, { "code": "reportUnknownArgumentType", "range": { - "startColumn": 48, - "endColumn": 73, + "startColumn": 20, + "endColumn": 45, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 8, - "endColumn": 18, + "startColumn": 15, + "endColumn": 34, "lineCount": 1 } }, { - "code": "reportIncompatibleMethodOverride", + "code": "reportUnknownMemberType", "range": { - "startColumn": 8, - "endColumn": 18, + "startColumn": 15, + "endColumn": 68, "lineCount": 1 } }, { - "code": "reportImplicitOverride", + "code": "reportUnknownMemberType", "range": { - "startColumn": 8, - "endColumn": 18, + "startColumn": 14, + "endColumn": 33, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 25, - "endColumn": 31, + "startColumn": 16, + "endColumn": 41, "lineCount": 1 } }, { - "code": "reportMissingParameterType", + "code": "reportUnknownArgumentType", "range": { - "startColumn": 25, - "endColumn": 31, + "startColumn": 16, + "endColumn": 41, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 20, - "endColumn": 39, + "startColumn": 16, + "endColumn": 40, "lineCount": 1 } }, { "code": "reportUnknownArgumentType", "range": { - "startColumn": 61, - "endColumn": 67, + "startColumn": 16, + "endColumn": 40, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 16, - "endColumn": 32, + "startColumn": 31, + "endColumn": 59, "lineCount": 1 } }, { - "code": "reportUnknownMemberType", + "code": "reportUnknownArgumentType", "range": { - "startColumn": 19, - "endColumn": 52, + "startColumn": 31, + "endColumn": 41, "lineCount": 1 } }, { - "code": "reportUnknownLambdaType", + "code": "reportUnknownArgumentType", "range": { - "startColumn": 27, - "endColumn": 31, + "startColumn": 40, + "endColumn": 63, "lineCount": 1 } }, { - "code": "reportUnknownArgumentType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 38, - "endColumn": 42, + "startColumn": 12, + "endColumn": 33, "lineCount": 1 } }, { "code": "reportUnknownArgumentType", "range": { - "startColumn": 44, - "endColumn": 50, + "startColumn": 50, + "endColumn": 64, "lineCount": 1 } }, { - "code": "reportUnknownArgumentType", + "code": "reportUninitializedInstanceVariable", "range": { - "startColumn": 52, - "endColumn": 56, + "startColumn": 13, + "endColumn": 31, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 23, - "endColumn": 32, + "startColumn": 12, + "endColumn": 40, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 28, - "endColumn": 43, + "startColumn": 12, + "endColumn": 44, "lineCount": 1 } }, { - "code": "reportUnknownMemberType", + "code": "reportIncompatibleMethodOverride", "range": { - "startColumn": 16, - "endColumn": 40, + "startColumn": 8, + "endColumn": 18, "lineCount": 1 } }, { - "code": "reportUnknownMemberType", + "code": "reportImplicitOverride", "range": { - "startColumn": 41, - "endColumn": 48, + "startColumn": 8, + "endColumn": 18, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 27, - "endColumn": 42, + "startColumn": 16, + "endColumn": 32, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 31, - "endColumn": 48, + "startColumn": 23, + "endColumn": 32, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 34, - "endColumn": 61, + "startColumn": 16, + "endColumn": 40, "lineCount": 1 } }, { - "code": "reportUnknownArgumentType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 34, - "endColumn": 63, + "startColumn": 30, + "endColumn": 57, "lineCount": 1 } }, @@ -81667,14 +81003,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 32, - "endColumn": 47, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -81691,14 +81019,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 15, - "endColumn": 26, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -82043,22 +81363,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 12, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 28, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index 7fabaface..95fb1bfab 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -29,6 +29,7 @@ import numpy as np from constantdict import constantdict +from typing_extensions import final, override import islpy as isl from pymbolic import ArithmeticExpression, var @@ -45,6 +46,7 @@ ) from loopy.symbolic import ( CombineMapper, + ExpansionState, RuleAwareIdentityMapper, RuleAwareSubstitutionMapper, SubstitutionRuleMappingContext, @@ -60,7 +62,7 @@ from loopy.translation_unit import CallablesTable, TranslationUnit from loopy.types import LoopyType, ToLoopyTypeConvertible, to_loopy_type from loopy.typing import ( - Expression, + InsnId, auto, integer_expr_or_err, integer_or_err, @@ -69,12 +71,13 @@ if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Collection, Sequence, Set as AbstractSet + from pymbolic.typing import Expression from pytools.tag import Tag from loopy.kernel import LoopKernel - from loopy.match import ToStackMatchConvertible + from loopy.match import StackMatch, ToStackMatchConvertible # {{{ contains_subst_rule_invocation @@ -234,13 +237,21 @@ def map_subst_rule(self, name, tag, arguments, expn_state): # {{{ replace rule invocation -class RuleInvocationReplacer(RuleAwareIdentityMapper): - def __init__(self, rule_mapping_context, subst_name, subst_tag, within, - access_descriptors, array_base_map, - storage_axis_names, storage_axis_sources, - non1_storage_axis_names, - temporary_name, compute_insn_id, compute_dep_id, - compute_read_variables): +@final +class RuleInvocationReplacer(RuleAwareIdentityMapper[[]]): + def __init__(self, + rule_mapping_context: SubstitutionRuleMappingContext, + subst_name: str, + subst_tag: Tag | None, + within: StackMatch, + access_descriptors, + array_base_map, + storage_axis_names: Collection[str], + storage_axis_sources, + non1_storage_axis_names, + temporary_name: str, + compute_dep_ids: AbstractSet[InsnId], + compute_read_variables): super().__init__(rule_mapping_context) self.subst_name = subst_name @@ -255,22 +266,26 @@ def __init__(self, rule_mapping_context, subst_name, subst_tag, within, self.non1_storage_axis_names = non1_storage_axis_names self.temporary_name = temporary_name - self.compute_insn_id = compute_insn_id - self.compute_dep_id = compute_dep_id + self.compute_dep_ids = compute_dep_ids self.compute_read_variables = compute_read_variables self.compute_insn_depends_on = set() - def map_subst_rule(self, name, tag, arguments, expn_state): + @override + def map_subst_rule(self, + name: str, + tags: AbstractSet[Tag] | None, + arguments: Sequence[Expression], + expn_state: ExpansionState): if not ( name == self.subst_name and self.within( expn_state.kernel, expn_state.instruction, expn_state.stack) - and (self.subst_tag is None or self.subst_tag == tag)): + and (self.subst_tag is None or self.subst_tag == tags)): return super().map_subst_rule( - name, tag, arguments, expn_state) + name, tags, arguments, expn_state) # {{{ check if in footprint @@ -285,7 +300,7 @@ def map_subst_rule(self, name, tag, arguments, expn_state): if not self.array_base_map.is_access_descriptor_in_footprint(accdesc): return super().map_subst_rule( - name, tag, arguments, expn_state) + name, tags, arguments, expn_state) # }}} @@ -336,10 +351,9 @@ def map_subst_rule(self, name, tag, arguments, expn_state): return new_outer_expr - def map_kernel(self, kernel): + def map_kernel(self, kernel: LoopKernel): new_insns = [] - excluded_insn_ids = {self.compute_insn_id, self.compute_dep_id} # precomputed_in_insns: set of insn ids in which the subst rule was # precomputed. precomputed_in_insns = set() @@ -359,20 +373,19 @@ def map_kernel(self, kernel): if self.replaced_something: insn = insn.copy( depends_on=( - insn.depends_on - | frozenset([self.compute_dep_id]))) + insn.depends_on | self.compute_dep_ids)) precomputed_in_insns.add(insn.id) for dep in insn.depends_on: - if dep in excluded_insn_ids: + if dep in self.compute_dep_ids: continue dep_insn = kernel.id_to_insn[dep] if (frozenset(dep_insn.assignee_var_names()) & self.compute_read_variables): self.compute_insn_depends_on.update( - insn.depends_on - excluded_insn_ids) + insn.depends_on - self.compute_dep_ids) new_insns.append(insn) @@ -967,7 +980,7 @@ def add_assumptions(d): expression=compute_expression, # within_inames determined below ) - compute_dep_id = compute_insn_id + compute_dep_ids = {compute_insn_id} added_compute_insns: list[InstructionBase] = [compute_insn] if temporary_address_space == AddressSpace.GLOBAL: @@ -979,7 +992,7 @@ def add_assumptions(d): depends_on=frozenset([compute_insn_id]), synchronization_kind="global", mem_kind="global") - compute_dep_id = barrier_insn_id + compute_dep_ids.add(barrier_insn_id) added_compute_insns.append(barrier_insn) @@ -995,7 +1008,7 @@ def add_assumptions(d): access_descriptors, abm, storage_axis_names, storage_axis_sources, non1_storage_axis_names, - temporary_name, compute_insn_id, compute_dep_id, + temporary_name, frozenset(compute_dep_ids), compute_read_variables=get_dependencies(expander(compute_expression))) kernel = invr.map_kernel(kernel)