From 3580a966004c2570455f7f4edecb874db0b4d5dd Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 25 Jul 2023 12:48:16 -0700 Subject: [PATCH 1/4] Add overlooked overload information on torchlib functions [ghstack-poisoned] --- .../function_libs/torch_lib/ops/core.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index b46d999b31..68a7f83f83 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -81,7 +81,7 @@ def aten_acosh(self: TFloat) -> TFloat: return op.Acosh(self) -@torch_op("aten::add") +@torch_op(("aten::add", "aten::add.Tensor")) def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" # TODO(microsoft/onnxruntime#15977): Improve fp16 precision @@ -1235,7 +1235,7 @@ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: return op.SplitToSequence(self, list_split, axis=dim) -@torch_op("aten::clamp", trace_only=True) +@torch_op(("aten::clamp", "aten::clamp.Tensor"), trace_only=True) def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None) -> TReal: """clamp(Tensor self, Tensor? min=None, Tensor? 
max=None) -> Tensor""" clamped = self @@ -2184,7 +2184,7 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType raise NotImplementedError() -@torch_op("aten::div") +@torch_op(("aten::div", "aten::div.Tensor")) def aten_div(self: TFloat, other: TFloat) -> TFloat: """div.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -2353,7 +2353,7 @@ def aten_empty_strided( return op.Expand(zero, size) -@torch_op("aten::eq") +@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar")) def aten_eq(self: TTensor, other: TTensor) -> BOOL: """eq.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -2563,7 +2563,7 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType raise NotImplementedError() -@torch_op("aten::fill") +@torch_op(("aten::fill", "aten::fill.Tensor")) def aten_fill(self: TTensor, value: TTensor) -> TTensor: """fill.Tensor(Tensor self, Tensor value) -> Tensor""" @@ -3595,7 +3595,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::le") +@torch_op(("aten::le", "aten::le.Tensor")) def aten_le(self: TReal, other: TReal) -> BOOL: """le.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -3884,7 +3884,7 @@ def aten_lstm_mps_backward( raise NotImplementedError() -@torch_op("aten::lt") +@torch_op(("aten::lt", "aten::lt.Scalar")) def aten_lt(self: TReal, other: TReal) -> BOOL: """lt.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -4883,7 +4883,7 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType: raise NotImplementedError() -@torch_op("aten::ne") +@torch_op(("aten::ne", "aten::ne.Scalar")) def aten_ne(self: TReal, other: TReal) -> BOOL: """ne.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -5756,7 +5756,7 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16: return op.Reciprocal(op.Sqrt(self)) -@torch_op("aten::rsub") +@torch_op(("aten::rsub", "aten::rsub.Scalar")) def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) 
-> TReal: """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" # FIXME(titaiwang): get rid of this when we have type_promotion @@ -5855,7 +5855,7 @@ def aten_segment_reduce( raise NotImplementedError() -@torch_op("aten::select") +@torch_op(("aten::select", "aten::select.int")) def aten_select(self: TTensor, dim: int, index: int) -> TTensor: """select(Tensor self, int dim, int index) -> Tensor""" @@ -5935,7 +5935,7 @@ def aten_sinh(self: TFloat) -> TFloat: return op.Sinh(self) -@torch_op("aten::slice", trace_only=True) +@torch_op(("aten::slice", "aten::slice.Tensor"), trace_only=True) def aten_slice( self: TTensor, dim: int = 0, @@ -6081,7 +6081,7 @@ def aten_sparse_mask(self: TensorType, mask: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::split") +@torch_op(("aten::split", "aten::split.Tensor")) def aten_split(self: TTensor, split_size: INT64, dim: int = 0) -> TTensor: """split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]""" @@ -6309,7 +6309,7 @@ def aten_stft( return result -@torch_op("aten::sub") +@torch_op(("aten::sub", "aten::sub.Tensor")) def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" alpha = op.CastLike(alpha, other) @@ -6634,7 +6634,7 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType: raise NotImplementedError() -@torch_op("aten::transpose", trace_only=True) +@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True) def aten_transpose(self, dim0: int, dim1: int): """transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)""" @@ -6729,7 +6729,7 @@ def aten_type_as(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::unbind") +@torch_op(("aten::unbind", "aten::unbind.int")) def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" @@ 
-7082,7 +7082,7 @@ def aten_vstack(tensors: Sequence[TTensor]) -> TTensor: return op.ConcatFromSequence(tensors, axis=0) -@torch_op("aten::where") +@torch_op(("aten::where", "aten::where.self")) def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor: """where.self(Tensor condition, Tensor self, Tensor other) -> Tensor""" From 8e3e9b2903313f2360e25c465003b38ea76b2989 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 25 Jul 2023 12:48:19 -0700 Subject: [PATCH 2/4] Tool to modify torchlib overload names via libcst [ghstack-poisoned] --- .../tools/torch_lib/modify_overload_names.py | 179 ++++++++++++++++++ .../function_libs/torch_lib/registration.py | 2 +- 2 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 onnxscript/function_libs/tools/torch_lib/modify_overload_names.py diff --git a/onnxscript/function_libs/tools/torch_lib/modify_overload_names.py b/onnxscript/function_libs/tools/torch_lib/modify_overload_names.py new file mode 100644 index 0000000000..5bc2f82200 --- /dev/null +++ b/onnxscript/function_libs/tools/torch_lib/modify_overload_names.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import enum +import os +import pathlib +from typing import Dict, List, Set, Tuple + +import libcst as cst +from libcst import matchers +from libcst._nodes.statement import FunctionDef + +from onnxscript.function_libs.torch_lib import registration + + +class _StatusEnum(enum.Enum): + SUCCESS = enum.auto() + """Success.""" + FAILURE_OVERLOAD_EXIST = enum.auto() + """Failure: overload name already exists.""" + FAILURE_OVERLOAD_INVALID = enum.auto() + """Failure: overload name is invalid.""" + FAILURE_OP_NOT_FOUND = enum.auto() + """Failure: op not found.""" + FAILURE_OP_MULTIPLE_IMPL = enum.auto() + """Failure: op has multiple implementations. 
Cannot decide which to add new overload name to.""" + + +def _cst_arg_to_overload_names(arg: cst.Arg) -> Tuple[str, ...]: + if matchers.matches(arg, matchers.Arg(value=matchers.SimpleString())): + overload_names = (cst.ensure_type(arg.value, cst.SimpleString).value,) + else: + overload_names = tuple( + cst.ensure_type(element.value, cst.SimpleString).value + for element in cst.ensure_type(arg.value, cst.Tuple).elements + ) + overload_names = tuple(name.replace('"', "") for name in overload_names) + return overload_names + + +def _overload_names_to_namespace_op(overload_names: Tuple[str, ...]) -> str: + match = registration._QUALIFIED_OPERATOR_NAME_REGEX.fullmatch(overload_names[0]) + assert match is not None + namespace = match.group("namespace") + name = match.group("name") + return f"{namespace}::{name}" + + +class _TorchlibOpOverloadCollector(cst.CSTVisitor): + def __init__(self): + self._op_overloads: Dict[str, List[Tuple[str, List[str]]]] = {} + self._stack: List[str] = [] + + def visit_FunctionDef(self, node: FunctionDef) -> bool | None: + self._stack.append(node.name.value) + + def leave_FunctionDef(self, node: FunctionDef) -> None: + self._stack.pop() + + def visit_Call(self, node: cst.Call) -> None: + if not matchers.matches(node.func, matchers.Name("torch_op")): + return + + function_name = self._stack[-1] + overload_names = _cst_arg_to_overload_names(node.args[0]) + namespace_op_name = _overload_names_to_namespace_op(overload_names) + + self._op_overloads.setdefault(namespace_op_name, []) + self._op_overloads[namespace_op_name].append((function_name, list(overload_names))) + + +class _TorchlibOpOverloadAdder(cst.CSTTransformer): + def __init__( + self, + overload_names: Dict[str, List[Tuple[str, List[str]]]], + new_overload_names: Set[str], + ): + self._overload_names = overload_names + self._results: Dict[str, _StatusEnum] = {} + + for new_overload_name in new_overload_names: + match = 
registration._QUALIFIED_OPERATOR_NAME_REGEX.fullmatch(new_overload_name) + if not match: + self._results[new_overload_name] = _StatusEnum.FAILURE_OVERLOAD_INVALID + continue + overload = match.group("overload") or "" + if overload == "default": + overload = "" + dot_overload = f".{overload}" if overload else "" + op_name = match.group("name") + namespace = match.group("namespace") + namespace_op_name = f"{namespace}::{op_name}" + qualified_name = f"{namespace_op_name}{dot_overload}" + + if namespace_op_name not in self._overload_names: + self._results[new_overload_name] = _StatusEnum.FAILURE_OP_NOT_FOUND + continue + + if len(self._overload_names[namespace_op_name]) > 1: + self._results[new_overload_name] = _StatusEnum.FAILURE_OP_MULTIPLE_IMPL + continue + + if qualified_name in self._overload_names[namespace_op_name][0][1]: + self._results[new_overload_name] = _StatusEnum.FAILURE_OVERLOAD_EXIST + continue + + self._overload_names[namespace_op_name][0][1].append(qualified_name) + self._results[new_overload_name] = _StatusEnum.SUCCESS + + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: + if not matchers.matches(original_node.func, matchers.Name("torch_op")): + return original_node + + original_overload_names = _cst_arg_to_overload_names(original_node.args[0]) + namespace_op_name = _overload_names_to_namespace_op(original_overload_names) + overload_names = self._overload_names[namespace_op_name][0][1] + if len(overload_names) == 1: + return original_node + return updated_node.with_changes( + args=[ + cst.Arg( + value=cst.Tuple( + elements=[ + cst.Element(cst.SimpleString(value=f'"{name}"')) + for name in overload_names + ] + ) + ), + *original_node.args[1:], + ], + ) + + +def add_overload_names( + module_path: pathlib.Path, overload_names: Set[str] +) -> Dict[str, _StatusEnum]: + """NOTE: This function assumes""" + source_tree = cst.parse_module(module_path.read_text()) + op_overload_collector = _TorchlibOpOverloadCollector() + 
source_tree.visit(op_overload_collector) + transformer = _TorchlibOpOverloadAdder(op_overload_collector._op_overloads, overload_names) + modified_tree = source_tree.visit(transformer) + module_path.write_text(modified_tree.code) + return transformer._results + + +def main(): + new_overload_names = { + "aten::add.Tensor", + "aten::clamp.Tensor", + "aten::div.Tensor", + "aten::eq.Scalar", + "aten::eq.Tensor", + "aten::fill.Tensor", + "aten::ge.Scalar", + "aten::ge.Tensor", + "aten::gt.Scalar", + "aten::le.Tensor", + "aten::lt.Scalar", + "aten::mul.Tensor", + "aten::ne.Scalar", + "aten::roll.default", + "aten::rsub.Scalar", + "aten::select.int", + "aten::slice.Tensor", + "aten::split.Tensor", + "aten::sub.Tensor", + "aten::transpose.int", + "aten::unbind.int", + "aten::where.self", + } + file_paths = [ + pathlib.Path(os.path.join(root, file)) + for root, dirs, files in os.walk("onnxscript/function_libs/torch_lib/ops") + for file in files + ] + for file_path in file_paths: + print(add_overload_names(file_path, new_overload_names)) + + +if __name__ == "__main__": + main() diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index 8fef3fd382..93aad59f5e 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -10,7 +10,7 @@ # Regex that will match "<namespace>::<op_name>[.<overload>]" _QUALIFIED_OPERATOR_NAME_REGEX = re.compile( - r"^(?P<namespace>[a-zA-Z0-9_]+)::(?P<name>[a-zA-Z0-9_]+)(?P<overload>\.[a-zA-Z0-9._]+)?$" + r"^(?P<namespace>\w+)::(?P<name>\w+)(?:\.(?P<overload>\w+))?$" ) From b88388a41e286970c23ba75b455e780df35b1e13 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 25 Jul 2023 12:52:39 -0700 Subject: [PATCH 3/4] Update base for Update on "Tool to modify torchlib overload names via libcst" [ghstack-poisoned] --- onnxscript/function_libs/torch_lib/ops/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py 
b/onnxscript/function_libs/torch_lib/ops/core.py index 68a7f83f83..46a7467f14 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2748,7 +2748,7 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::ge") +@torch_op(("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar")) def aten_ge(self: TReal, other: TReal) -> BOOL: """ge.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -2905,7 +2905,7 @@ def aten_gru_cell( raise NotImplementedError() -@torch_op("aten::gt") +@torch_op(("aten::gt", "aten::gt.Scalar")) def aten_gt(self: TReal, other: TReal) -> BOOL: """gt.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -4462,7 +4462,7 @@ def aten_msort(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::mul") +@torch_op(("aten::mul", "aten::mul.Tensor")) def aten_mul(self: TReal, other: TReal) -> TReal: """mul.Tensor(Tensor self, Tensor other) -> Tensor""" # FIXME(titaiwang): get rid of this when we have type_promotion @@ -4470,7 +4470,7 @@ def aten_mul(self: TReal, other: TReal) -> TReal: return op.Mul(self, other) -@torch_op("aten::mul") +@torch_op(("aten::mul", "aten::mul.Tensor")) def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: """ONNX Mul doesn't support Boolean, so use And as an equivalent operator.""" From 3521a1a40aa21bb93c8757263833c7667bf99598 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 25 Jul 2023 14:41:20 -0700 Subject: [PATCH 4/4] Update base for Update on "Tool to modify torchlib overload names via libcst" [ghstack-poisoned] --- onnxscript/function_libs/torch_lib/ops/core.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 46a7467f14..6aaddd65f7 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2299,7 +2299,7 @@ 
def aten_embedding_sparse_backward( raise NotImplementedError() -@torch_op("aten::empty") +@torch_op(("aten::empty", "aten::empty.memory_format")) def aten_empty(size: IntType, dtype: int = FLOAT.dtype) -> TTensor: # type: ignore[type-var] # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor @@ -3957,7 +3957,7 @@ def aten_margin_ranking_loss( raise NotImplementedError() -@torch_op("aten::masked_fill") +@torch_op(("aten::masked_fill", "aten::masked_fill.Scalar", "aten::masked_fill.Tensor")) def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor: """masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor""" # NOTE: Do not attempt to cast `mask` to BOOL because mask should not take any other types. @@ -4883,7 +4883,7 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType: raise NotImplementedError() -@torch_op(("aten::ne", "aten::ne.Scalar")) +@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor")) def aten_ne(self: TReal, other: TReal) -> BOOL: """ne.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -5223,7 +5223,7 @@ def aten_positive(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::pow") +@torch_op(("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar")) def aten_pow(self: TReal, exponent: TTensor) -> TReal: """pow(Tensor self, Tensor exponent) -> Tensor""" @@ -5785,7 +5785,7 @@ def aten_scatter_add( return op.ScatterElements(self, index, src, axis=dim, reduction="add") -@torch_op("aten::scatter_reduce", trace_only=True) +@torch_op(("aten::scatter_reduce", "aten::scatter_reduce.two"), trace_only=True) def aten_scatter_reduce( self: TReal, dim: int, # we have to use int here because ScatterElements() will use this attribute @@ -6324,7 +6324,7 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1.0) -> Te raise NotImplementedError() 
-@torch_op("aten::sum", trace_only=True) +@torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True) def aten_sum_dim_IntList( self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1 ) -> TReal: