diff --git a/onnxscript/__init__.py b/onnxscript/__init__.py index 829ee212cf..183ed3acdb 100644 --- a/onnxscript/__init__.py +++ b/onnxscript/__init__.py @@ -14,7 +14,7 @@ "TracedOnnxFunction", "GraphBuilder", "OpBuilder", - "OpBuilderBase", + "BuilderBase", "TapeBuilder", "build_function", "build_graph", @@ -69,6 +69,7 @@ "opset_ai_onnx_ml4", "opset_ai_onnx_ml5", "DEBUG", + "BuilderFeature", ] import importlib.metadata @@ -135,7 +136,11 @@ from . import ir, nn, optimizer, rewriter, version_converter from ._internal.builder import GraphBuilder, OpBuilder, build_function, build_graph -from ._internal.tape_builder import OpBuilderBase, TapeBuilder +from ._internal.tape_builder import ( + BuilderBase, + BuilderFeature, + TapeBuilder, +) from ._internal.utils import external_tensor from ._internal.values import OnnxFunction, TracedOnnxFunction diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index f867d4332d..c94e207cfa 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -3,21 +3,24 @@ """Graph builder for constructing ONNX IR graphs imperatively. This module provides imperative builders for constructing ONNX IR graphs with automatic -constant promotion, type casting, and shape inference. The GraphBuilder class enables -programmatic construction of graphs with proper scoping, constant management, and node -creation. The OpBuilder class provides dynamic op dispatching via attribute access. +constant promotion, type casting, and shape inference. The GraphBuilder class inherits +from BuilderBase and enables programmatic construction of graphs with proper scoping, +constant management, and node creation. The OpBuilder class provides dynamic op +dispatching via attribute access. """ from __future__ import annotations from typing import Any, Callable, Mapping, Sequence, Union -import onnx import onnx_ir as ir -import onnxscript._internal._inference as inference -import onnxscript.optimizer -from onnxscript._internal import _inliner, param_manipulation +import onnxscript +from onnxscript._internal import _inliner +from onnxscript._internal.tape_builder import ( + BuilderBase, + BuilderFeature, +) # A permissible value for an op input, which can be converted to an ir.Value. VALUE_LIKE = Union[ @@ -418,10 +421,11 @@ def build_function( ) -class GraphBuilder: +class GraphBuilder(BuilderBase): """Imperative builder for constructing ONNX IR graphs with automatic constant promotion, type casting, and shape inference.""" def __init__(self, graph: ir.Graph, *, parent: GraphBuilder | None = None) -> None: + super().__init__(features=BuilderFeature.FULL) self._graph = graph self._parent = parent self._root: GraphBuilder = parent._root if parent is not None else self @@ -474,6 +478,71 @@ def graph(self) -> ir.Graph: def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]: return self._root._functions + # ------------------------------------------------------------------ + # BuilderBase abstract method implementations + # ------------------------------------------------------------------ + + def _add_node(self, node: ir.Node) -> None: + """Append a node to the graph.""" + self.graph.append(node) + + def _add_initializer(self, value: ir.Value) -> None: + """Register an initializer in the root graph.""" + self._root._graph.register_initializer(value) + + def _record_opset(self, domain: str, version: int | None) -> None: + # Graph already tracks opset imports; nothing to do. 
+ pass + + # ------------------------------------------------------------------ + # BuilderBase hook overrides + # ------------------------------------------------------------------ + + def _promote_constant(self, value: Any, dtype: ir.DataType | None) -> ir.Value: + """Cache-based constant promotion. + + Delegates to the root builder so that all constant initializers + live in the root graph (outer-scope initializers are visible to + subgraphs per the ONNX spec). + """ + return self._get_or_create_constant(value, dtype) + + def _generate_node_name(self, op_type: str) -> str: + count = self.graph.num_nodes() + return self._qualify_node_name(f"{op_type}_node_{count}") + + def _adapt_outputs( + self, outputs: int | Sequence[str | ir.Value], op_type: str + ) -> Sequence[ir.Value]: + """Pre-create named output ir.Value objects for the graph.""" + if isinstance(outputs, int): + count = self.graph.num_nodes() + if outputs < 0: + raise ValueError(f"Number of outputs must be non-negative, got {outputs}") + if outputs == 1: + name = f"{op_type}_{count}" if op_type else f"{count}" + return [ir.Value(name=self._qualify_value_name(name))] + else: + names = [ + (f"{op_type}_{count}_{i}" if op_type else f"{count}_{i}") + for i in range(outputs) + ] + return [ir.Value(name=self._qualify_value_name(n)) for n in names] + # Delegate to base class for Sequence[str | ir.Value] + result = super()._adapt_outputs(outputs, op_type) + assert result is not None + return result + + def _annotate_node(self, node: ir.Node) -> None: + """Attach scope metadata to the node.""" + node.metadata_props["namespace"] = self._build_namespace() + node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes()) + node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names()) + + # ------------------------------------------------------------------ + # GraphBuilder-specific public API + # ------------------------------------------------------------------ + def initializer( self, tensor: ir.TensorProtocol, name: str | None = None, *, qualify: bool = True ) -> ir.Value: @@ -594,158 +663,15 @@ def _get_or_create_constant( # TODO(rama): Consider caching for other tensor values. return self.initializer(ir.tensor(value, dtype=dtype)) - def _input_to_ir_value( - self, value: VALUE_LIKE, like_type: ir.Value | None = None - ) -> ir.Value | None: - """Convert a permissible input (for a call to an op) into an ir.Value. - - Permissible values include ir.Value as well as python constants that can be converted - into ONNX constant tensors. For constant values, the like_type is used to determine the - target onnx type. - """ - if isinstance(value, ir.Value): - return value - if value is None: - return value - dtype = ( - like_type.type.dtype - if like_type is not None and like_type.type is not None - else None - ) - needs_dynamic_cast = like_type is not None and dtype is None - ir_value = self._get_or_create_constant(value, dtype) - # If like_type is provided but its type is unknown, insert a dynamic CastLike - # so the constant is cast to match like_type's type at runtime. - # The CastLike node is created in THIS builder's graph (not root), - # so that it lives in the correct scope (subgraph or function body). 
- if needs_dynamic_cast: - ir_value = self.op.CastLike(ir_value, like_type) - return ir_value - - def _adapt_outputs( - self, outputs: int | Sequence[str | ir.Value], op_type: str = "" - ) -> Sequence[ir.Value]: - if isinstance(outputs, int): - count = self.graph.num_nodes() - if outputs < 0: - raise ValueError(f"Number of outputs must be non-negative, got {outputs}") - if outputs == 1: - name = f"{op_type}_{count}" if op_type else f"{count}" - return [ir.Value(name=self._qualify_value_name(name))] - else: - names = [ - (f"{op_type}_{count}_{i}" if op_type else f"{count}_{i}") - for i in range(outputs) - ] - return [ir.Value(name=self._qualify_value_name(n)) for n in names] - adapted_outputs = [] - for output in outputs: - if isinstance(output, ir.Value): - if output.name: - output.name = self._qualify_value_name(output.name) - adapted_outputs.append(output) - elif isinstance(output, str): - adapted_outputs.append(ir.Value(name=self._qualify_value_name(output))) - else: - raise TypeError("Output type not supported.") - return adapted_outputs - - def _get_schema( - self, op_type: str, domain: str, version: int | None - ) -> onnx.defs.OpSchema | None: - if version is not None: - try: - return onnx.defs.get_schema(op_type, version, domain) - except onnx.defs.SchemaError: - pass - return None - - def _partition_inputs_attributes( - self, - schema: onnx.defs.OpSchema | None, - inputs: Sequence[ir.Value | ir.TensorProtocol | None], - kwargs: dict[str, Any], - ) -> tuple[Sequence[ir.Value | ir.TensorProtocol], dict[str, Any]]: - if schema is None: - return inputs, kwargs - op_signature = ir.schemas.OpSignature.from_op_schema(schema) - return param_manipulation.separate_input_attributes_from_arguments( - op_signature, - list(inputs), - kwargs, - fill_defaults=False, - allow_extra_args=False, - ) + def add_node(self, node: ir.Node) -> None: + """Append a node to the graph, run constant propagation and shape inference. - def _cast_inputs( - self, - schema: onnx.defs.OpSchema | None, - inputs: Sequence[VALUE_LIKE], - ) -> Sequence[ir.Value | None]: - """Uses schema specification to support a limited form of auto-casting. - - * Scalars are promoted to tensors. - * Further. they are cast to the required type when used in ops with other - tensor inputs that are required to be of same type. - Thus, in "A+1" or "Add(A, 1)", the value 1 will be converted to the same - type as A. + This is a backward-compatible public method used by call_inline and + other code that creates nodes manually. """ - if schema is None: - return [self._input_to_ir_value(i) for i in inputs] - - expected_inputs = schema.inputs - # We make two passes. In the first pass, we identify known type-bindings for - # type-variables: eg., {'T1' : np.float32, 'T2' : np.int32}. - # In the second pass, we use these bindings to cast scalar-values to - # tensors of appropriate types. The two passes are needed to handle cases - # like "Add(1, X)" where 1 must be cast to the same type as X. 
- type_bindings: dict[str, ir.Value] = {} - args_typevars: list[tuple[ir.Value | None, str | None]] = [] - for i, x in enumerate(inputs): - if i < len(expected_inputs): - expected = expected_inputs[i] - elif expected_inputs and ( - expected_inputs[-1].option == onnx.defs.OpSchema.FormalParameterOption.Variadic - ): - expected = expected_inputs[-1] - if not expected.is_homogeneous: - args_typevars.append((x, None)) - continue - else: - raise ValueError( - f"Number of actual parameters {len(inputs)} " - f"exceeds number of formal parameters {len(expected_inputs)}." - ) - typevar = expected.type_str - if ("(" not in typevar) and (typevar not in type_bindings): - # typevar is an identifier, like "T" - if isinstance(x, ir.Value): - type_bindings[typevar] = x - args_typevars.append((x, typevar)) - - def adapt(x, typevar: str | None) -> ir.Value | None: - if x is None: - return None - if typevar is None: - return self._input_to_ir_value(x) - type_like = type_bindings.get(typevar) - return self._input_to_ir_value(x, type_like) - - return [adapt(x, typevar) for x, typevar in args_typevars] - - def _cast_attributes( - self, - schema: onnx.defs.OpSchema | None, - attributes: dict[str, Any], - ) -> dict[str, Any]: - del schema # Not implemented yet - return attributes if attributes is not None else {} - - def add_node(self, node: ir.Node) -> None: - """Append a node to the graph, run constant propagation and shape inference.""" - self.graph.append(node) - onnxscript.optimizer.basic_constant_propagation([node]) - inference.infer_outputs(node) + self._add_node(node) + self._constant_propagation(node) + self._infer_shapes(node) def subgraph( self, @@ -796,46 +722,6 @@ def subgraph( parent=self, ) - def call_op( - self, - op_type: str, - inputs: Sequence[ir.Value | ir.TensorProtocol | None], - kwargs: dict[str, Any], - /, - domain: str = "", - version: int | None = None, - outputs: int | Sequence[str | ir.Value] = 1, - ): - """Create an ONNX node and add it to the graph, returning its output value(s).""" - count = self.graph.num_nodes() - node_name = self._qualify_node_name(f"{op_type}_node_{count}") - - output_values = self._adapt_outputs(outputs, op_type) - - schema = self._get_schema(op_type, domain, version) - inputs, attributes = self._partition_inputs_attributes(schema, inputs, kwargs) - inputs = self._cast_inputs(schema, inputs) - attributes = self._cast_attributes(schema, attributes) - - node = ir.node( - op_type, - inputs, - attributes=attributes or None, - domain=domain, - outputs=output_values, - version=version, - name=node_name, - ) - - # Attach scope metadata to the node - node.metadata_props["namespace"] = self._build_namespace() - node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes()) - node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names()) - - self.add_node(node) - - return node.outputs if len(node.outputs) > 1 else node.outputs[0] - def call( self, function: ir.Function | onnxscript.OnnxFunction, diff --git a/onnxscript/_internal/tape_builder.py b/onnxscript/_internal/tape_builder.py index 384fc317e7..abdc1ef4d8 100644 --- a/onnxscript/_internal/tape_builder.py +++ b/onnxscript/_internal/tape_builder.py @@ -1,11 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""Op builder base class and tape-backed implementation. +"""Builder base class and tape-backed implementation. 
This module defines: -- ``OpBuilderBase``: Abstract base class for building ONNX IR nodes via a - dynamic dispatch interface (``op.Relu(x)``, ``op.op(...)``, ``op.initializer(...)``). +- ``BuilderBase``: Abstract base class for building ONNX IR nodes via a + dynamic dispatch interface (``op.Relu(x)``, ``op.op("Relu", x)``, + ``op.initializer(...)``). Subclasses implement the storage strategy by overriding ``_add_node``, ``_add_initializer``, and ``_record_opset``. @@ -13,27 +14,76 @@ (rewriter, optimizer, version converter) create an instance, pass it to a rule or evaluator, and harvest the accumulated nodes / initializers / opsets after it returns. + +- ``BuilderFeature``: Flag enum controlling optional processing steps + (schema partitioning, input casting, shape inference, etc.). """ from __future__ import annotations import abc -from typing import Any, Mapping, Optional, Sequence +import enum +from typing import Any, Optional, Sequence +import onnx import onnx_ir as ir from onnx_ir import _convenience +from onnxscript._internal import param_manipulation + UsedOpsets = set[tuple[str, Optional[int]]] -class OpBuilderBase(abc.ABC): +def _dtype_suffix(dtype: ir.DataType) -> str: + """Return a short type suffix for naming constants based on ir.DataType.""" + return dtype.short_name() + + +def _constant_name( + value: int | float | bool | str | Sequence, type_suffix: str, num: int = 0 +) -> str: + """Generate a descriptive name for a constant value.""" + if isinstance(value, str): + return f"const_str_{num}" + if isinstance(value, (int, float, bool)): + return f"const_{value}_{type_suffix}" if type_suffix else f"const_{value}" + return f"const_1d_{num}" + + +class BuilderFeature(enum.Flag): + """Features that can be enabled on BuilderBase.""" + + NONE = 0 + SCHEMA_PARTITION = enum.auto() + CAST_INPUTS = enum.auto() + CAST_ATTRIBUTES = enum.auto() + INFER_SHAPES = enum.auto() + CONSTANT_PROPAGATION = enum.auto() + + # Convenience combos + SCHEMA_AWARE = SCHEMA_PARTITION | CAST_INPUTS | CAST_ATTRIBUTES + FULL = SCHEMA_AWARE | INFER_SHAPES | CONSTANT_PROPAGATION + + @property + def any_schema_feature(self) -> bool: + """True if any schema-dependent feature is enabled.""" + return bool( + self + & ( + BuilderFeature.SCHEMA_PARTITION + | BuilderFeature.CAST_INPUTS + | BuilderFeature.CAST_ATTRIBUTES + ) + ) + + +class BuilderBase(abc.ABC): """Abstract base class for building ONNX IR nodes. - Supports three creation operations: + Supports two creation operations: - 1. **Dynamic op dispatch** — ``op.Relu(x)``, ``op.MatMul(a, b, _domain=...)``, etc. - 2. **Explicit op creation** — ``op.op("Conv", inputs, attrs, domain=...)``. - 3. **Initializer creation** — ``op.initializer(tensor, name=...)``. + 1. **Op creation** — ``op.op("Relu", x)`` or ``op.Relu(x)`` (syntactic sugar). + 2. **Initializer creation** — ``op.initializer(tensor, name=...)``. 
 Subclasses must implement the three protected methods that define where created nodes and initializers are stored:
@@ -43,6 +93,13 @@ class OpBuilderBase(abc.ABC):
     - :meth:`_record_opset`
     """
 
+    def __init__(self, *, features: BuilderFeature = BuilderFeature.NONE) -> None:
+        self._features = features
+
+    @property
+    def features(self) -> BuilderFeature:
+        return self._features
+
     # ------------------------------------------------------------------
     # Abstract storage interface (to be implemented by subclasses)
     # ------------------------------------------------------------------
@@ -62,6 +119,209 @@ def _record_opset(self, domain: str, version: int | None) -> None:
         """Record that an opset domain/version was referenced."""
         raise NotImplementedError
 
+    # ------------------------------------------------------------------
+    # Overridable hook methods
+    # ------------------------------------------------------------------
+
+    def _get_schema(
+        self, op_type: str, domain: str, version: int | None
+    ) -> onnx.defs.OpSchema | None:
+        """Look up the op schema.
+
+        Returns None if version is not provided or the schema is not found.
+        """
+        if version is not None:
+            try:
+                return onnx.defs.get_schema(op_type, version, domain)
+            except onnx.defs.SchemaError:
+                pass
+        return None
+
+    def _partition_inputs_attributes(
+        self,
+        schema: onnx.defs.OpSchema | None,
+        args: Sequence[Any],
+        kwargs: dict[str, Any],
+    ) -> tuple[Sequence[Any], dict[str, Any]]:
+        """Separate positional args into inputs and attributes using the schema."""
+        if schema is None:
+            return args, kwargs
+        op_signature = ir.schemas.OpSignature.from_op_schema(schema)
+        return param_manipulation.separate_input_attributes_from_arguments(
+            op_signature,
+            list(args),
+            kwargs,
+            fill_defaults=False,
+            allow_extra_args=False,
+        )
+
+    def _cast_inputs(
+        self,
+        schema: onnx.defs.OpSchema | None,
+        inputs: Sequence[Any],
+    ) -> Sequence[ir.Value | None]:
+        """Cast/promote inputs (e.g., scalars → tensors) using schema type info.
+
+        Uses the schema specification to support a limited form of auto-casting:
+        * Scalars are promoted to tensors via _input_to_ir_value.
+        * They are cast to the required type when used in ops with other
+          tensor inputs that are required to be of the same type.
+        Thus, in "A+1" or "Add(A, 1)", the value 1 will be converted to the same
+        type as A.
+        """
+        if schema is None:
+            return [self._input_to_ir_value(i) for i in inputs]
+
+        expected_inputs = schema.inputs
+        # We make two passes. In the first pass, we identify known type bindings for
+        # type variables: e.g., {'T1': np.float32, 'T2': np.int32}.
+        # In the second pass, we use these bindings to cast scalar values to
+        # tensors of appropriate types. The two passes are needed to handle cases
+        # like "Add(1, X)" where 1 must be cast to the same type as X.
+        type_bindings: dict[str, ir.Value] = {}
+        args_typevars: list[tuple[ir.Value | None, str | None]] = []
+        for i, x in enumerate(inputs):
+            if i < len(expected_inputs):
+                expected = expected_inputs[i]
+            elif expected_inputs and (
+                expected_inputs[-1].option == onnx.defs.OpSchema.FormalParameterOption.Variadic
+            ):
+                expected = expected_inputs[-1]
+                if not expected.is_homogeneous:
+                    args_typevars.append((x, None))
+                    continue
+            else:
+                raise ValueError(
+                    f"Number of actual parameters {len(inputs)} "
+                    f"exceeds number of formal parameters {len(expected_inputs)}."
+                )
+            typevar = expected.type_str
+            if ("(" not in typevar) and (typevar not in type_bindings):
+                # typevar is an identifier, like "T"
+                if isinstance(x, ir.Value):
+                    type_bindings[typevar] = x
+            args_typevars.append((x, typevar))
+
+        def adapt(x, typevar: str | None) -> ir.Value | None:
+            if x is None:
+                return None
+            if typevar is None:
+                return self._input_to_ir_value(x)
+            type_like = type_bindings.get(typevar)
+            return self._input_to_ir_value(x, type_like)
+
+        return [adapt(x, typevar) for x, typevar in args_typevars]
+
+    def _cast_attributes(
+        self,
+        schema: onnx.defs.OpSchema | None,
+        attributes: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Cast attributes using schema info.
+
+        Default: pass through unchanged.
+        """
+        del schema  # Not implemented yet
+        return attributes if attributes is not None else {}
+
+    def _input_to_ir_value(
+        self, value: Any, like_type: ir.Value | None = None
+    ) -> ir.Value | None:
+        """Convert a permissible input into an ir.Value.
+
+        Handles ir.Value (pass-through), None (pass-through), and Python
+        constants/sequences/tensors (promoted to constant values via
+        ``_promote_constant``). When *like_type* is provided but its dtype
+        is unknown at graph-construction time, a dynamic ``CastLike`` node
+        is inserted so the constant matches *like_type* at runtime.
+        """
+        if isinstance(value, ir.Value):
+            return value
+        if value is None:
+            return value
+        dtype = (
+            like_type.type.dtype
+            if like_type is not None and like_type.type is not None
+            else None
+        )
+        needs_dynamic_cast = like_type is not None and dtype is None
+        ir_value = self._promote_constant(value, dtype)
+        if needs_dynamic_cast:
+            ir_value = self.call_op("CastLike", [ir_value, like_type], {})
+        return ir_value
+
+    def _promote_constant(self, value: Any, dtype: ir.DataType | None) -> ir.Value:
+        """Convert a Python constant into an ir.Value via a Constant node.
+
+        Creates a ``Constant`` op node whose output carries the tensor value.
+        This avoids initializer-name collisions when the builder is used
+        inside the rewriter/optimizer.
+
+        GraphBuilder overrides this with a cache-based initializer strategy.
+        """
+        tensor = ir.tensor(value, dtype=dtype)
+        return self.call_op("Constant", [], {"value": tensor})
+
+    def _qualify_value_name(self, name: str) -> str:
+        """Qualify a value name with a scope prefix.
+
+        Default: identity (no qualification). Override in GraphBuilder
+        to add module scope prefixes.
+        """
+        return name
+
+    def _generate_node_name(self, op_type: str) -> str | None:
+        """Generate a node name. Default: None (no auto-naming)."""
+        return None
+
+    def _adapt_outputs(
+        self, outputs: int | Sequence[str | ir.Value], op_type: str
+    ) -> Sequence[ir.Value] | None:
+        """Pre-create output ir.Value objects.
+
+        Default returns ``None`` for int outputs (letting ir.Node create
+        anonymous outputs), and converts string/ir.Value sequences.
+        Override in GraphBuilder to always pre-create named outputs.
+        """
+        if isinstance(outputs, int):
+            return None
+        adapted_outputs = []
+        for output in outputs:
+            if isinstance(output, ir.Value):
+                if output.name:
+                    output.name = self._qualify_value_name(output.name)
+                adapted_outputs.append(output)
+            elif isinstance(output, str):
+                adapted_outputs.append(ir.Value(name=self._qualify_value_name(output)))
+            else:
+                raise TypeError("Output type not supported.")
+        return adapted_outputs
+
+    def _constant_propagation(self, node: ir.Node) -> None:
+        """Run basic constant propagation on a newly created node.
+
+        Called when CONSTANT_PROPAGATION feature is enabled.
+ """ + # Lazy import to avoid circular dependency at module level. + import onnxscript.optimizer # pylint: disable=import-outside-toplevel + + onnxscript.optimizer.basic_constant_propagation([node]) + + def _infer_shapes(self, node: ir.Node) -> None: + """Run shape/type inference on a newly created node. + + Called when INFER_SHAPES feature is enabled. + """ + from onnxscript._internal import _inference # pylint: disable=import-outside-toplevel + + _inference.infer_outputs(node) + + def _annotate_node(self, node: ir.Node) -> None: # noqa: B027 + """Attach metadata to a node after creation. + + Default: no-op. Override to add scope/namespace annotations. + """ + # ------------------------------------------------------------------ # Public API (concrete) # ------------------------------------------------------------------ @@ -69,85 +329,139 @@ def _record_opset(self, domain: str, version: int | None) -> None: def __getattr__(self, op_type: str) -> Any: """Dynamic op dispatch: ``op.Relu(x)``, ``op.MatMul(a, b)``, etc. + Syntactic sugar for ``op.op(op_type, ...)``. + Returns a callable that creates a node of the given ``op_type`` and records it via the subclass storage implementation. - - Supported keyword arguments on the returned callable: - _domain (str): Op domain (default ``""``). - _version (int | None): Opset version. - _outputs (int | list[str]): Number of outputs or explicit output names. - _name (str | None): Optional node name (must be unique). """ - return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) + return lambda *args, **kwargs: self.op(op_type, *args, **kwargs) - def _make_node( - self, op_type: str, inputs: Sequence[ir.Value | None], kwargs: dict[str, Any] + def op( + self, + op_type: str, + /, + *args: ir.Value | None, + _domain: str = "", + _version: int | None = None, + _outputs: int | Sequence[str] = 1, + _name: str | None = None, + **kwargs: Any, ) -> ir.Value | Sequence[ir.Value]: - """Create one or more output values by building an ``ir.Node``.""" - domain = kwargs.pop("_domain", "") - version = kwargs.pop("_version", None) - outputs = kwargs.pop("_outputs", 1) - name = kwargs.pop("_name", None) - - if isinstance(outputs, Sequence): - num_outputs = len(outputs) - else: - assert isinstance(outputs, int) - num_outputs = outputs - - attrs: Sequence[ir.Attr] = _convenience.convert_attributes(kwargs) if kwargs else () - node = ir.Node( - domain, + """Create an ONNX node. + + This is the single entry point for all node creation. + ``op.Relu(x)`` is equivalent to ``op.op("Relu", x)``. + + Args: + op_type: The operator type (e.g., ``"Relu"``, ``"Conv"``). + *args: Positional arguments — the node's input values. + _domain: Op domain (default ``""``). + _version: Opset version. + _outputs: Number of outputs or list of explicit output names. + _name: Optional node name (must be unique). + **kwargs: Keyword arguments — node attributes. + Values can be Python scalars/lists (auto-converted) or + ``ir.Attr`` instances (passed through). + + Returns: + A single ``ir.Value`` if the node has one output, otherwise + a sequence of ``ir.Value``. 
+ """ + return self.call_op( op_type, - inputs, - attributes=attrs, - num_outputs=num_outputs, - version=version, - name=name, + args, + kwargs, + domain=_domain, + version=_version, + outputs=_outputs, + name=_name, ) - self._add_node(node) - self._record_opset(domain, version) - if num_outputs == 1: - if isinstance(outputs, Sequence): - node.outputs[0].name = outputs[0] - return node.outputs[0] - - if isinstance(outputs, Sequence): - for value, output_name in zip(node.outputs, outputs): - value.name = output_name - return node.outputs - - def op( + def call_op( self, op_type: str, - inputs: Sequence[ir.Value | None], - attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, - *, + args: Sequence[Any], + kwargs: dict[str, Any], + /, domain: str = "", version: int | None = None, + outputs: int | Sequence[str | ir.Value] = 1, name: str | None = None, - ) -> ir.Value: - """Create a single-output node with an explicit op type. + ) -> ir.Value | Sequence[ir.Value]: + """Create an ONNX node and add it to the graph, returning its output value(s). - This is useful when the op type is determined dynamically or when - forwarding attributes from a matched node. + This is the core node-creation method. Both ``BuilderBase.op()`` and + ``OpBuilder.__getattr__`` delegate here. The processing steps are + controlled by :attr:`features` flags and overridable hook methods. """ - attrs: Sequence[ir.Attr] = ( - _convenience.convert_attributes(attributes) if attributes else () - ) - node = ir.Node( - domain, - op_type, - inputs, - attributes=attrs, - num_outputs=1, - version=version, - name=name, - ) + features = self._features + + # 1. Schema lookup (if any schema-dependent feature is enabled) + schema = None + if features.any_schema_feature: + schema = self._get_schema(op_type, domain, version) + + # 2. Partition args into inputs and attributes using schema + if features & BuilderFeature.SCHEMA_PARTITION: + args, kwargs = self._partition_inputs_attributes(schema, args, kwargs) + + # 3. Cast inputs (scalar→tensor promotion, type-variable matching) + if features & BuilderFeature.CAST_INPUTS: + args = self._cast_inputs(schema, args) + + # 4. Cast attributes using schema info + if features & BuilderFeature.CAST_ATTRIBUTES: + kwargs = self._cast_attributes(schema, kwargs) + + # 5. Convert remaining kwargs to ir.Attr list + attrs: Sequence[ir.Attr] = _convenience.convert_attributes(kwargs) if kwargs else () + + # 6. Determine outputs + output_values = self._adapt_outputs(outputs, op_type) + + # 7. Build the node + if name is None: + name = self._generate_node_name(op_type) + + if output_values is not None: + node = ir.Node( + domain, + op_type, + args, + attributes=attrs, + outputs=output_values, + version=version, + name=name, + ) + else: + num_outputs = len(outputs) if isinstance(outputs, Sequence) else outputs + node = ir.Node( + domain, + op_type, + args, + attributes=attrs, + num_outputs=num_outputs, + version=version, + name=name, + ) + + # 8. Annotate (metadata, scope) + self._annotate_node(node) + + # 9. Store self._add_node(node) self._record_opset(domain, version) - return node.outputs[0] + + # 10. Post-creation hooks (inference, const-prop) + if features & BuilderFeature.CONSTANT_PROPAGATION: + self._constant_propagation(node) + if features & BuilderFeature.INFER_SHAPES: + self._infer_shapes(node) + + # 11. 
Return + if len(node.outputs) == 1: + return node.outputs[0] + return node.outputs def initializer( self, @@ -166,7 +480,7 @@ def initializer( return value -class TapeBuilder(OpBuilderBase): +class TapeBuilder(BuilderBase): """Concrete builder backed by simple lists (tape-like storage). Engines (rewriter, optimizer, version converter) create an instance, @@ -175,7 +489,8 @@ class TapeBuilder(OpBuilderBase): ``used_opsets`` properties. """ - def __init__(self) -> None: + def __init__(self, *, features: BuilderFeature = BuilderFeature.NONE) -> None: + super().__init__(features=features) self._nodes: list[ir.Node] = [] self._initializers: list[ir.Value] = [] self._used_opsets: UsedOpsets = set() diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 6d34062fa3..da1a286cec 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -24,9 +24,9 @@ import onnx_ir as ir import onnxscript.utils.utils as utils -from onnxscript._internal.tape_builder import OpBuilderBase, TapeBuilder +from onnxscript._internal.tape_builder import BuilderBase, TapeBuilder -OptimizerContext = OpBuilderBase +OptimizerContext = BuilderBase DEFAULT_CONSTANT_FOLD_BLACKLIST = [ # ConstantOfShape is preserved to avoid increasing model size unnecessarily diff --git a/onnxscript/rewriter/_context.py b/onnxscript/rewriter/_context.py index e1cb1b9db9..ae53384ceb 100644 --- a/onnxscript/rewriter/_context.py +++ b/onnxscript/rewriter/_context.py @@ -2,24 +2,24 @@ # Licensed under the MIT License. """Rewriter-specific context aliases. -This module re-exports ``OpBuilderBase`` and ``TapeBuilder`` from +This module re-exports ``BuilderBase`` and ``TapeBuilder`` from :mod:`onnxscript._internal.tape_builder` and defines the ``RewriterContext`` alias used in rewrite-rule signatures. 
""" from __future__ import annotations -from onnxscript._internal.tape_builder import OpBuilderBase, TapeBuilder, UsedOpsets +from onnxscript._internal.tape_builder import BuilderBase, TapeBuilder, UsedOpsets # Alias used in rewrite rule signatures (the ``op`` parameter type) -RewriterContext = OpBuilderBase +RewriterContext = BuilderBase # Backward compatibility aliases TapeRewriterContext = TapeBuilder -OptimizerContext = OpBuilderBase +OptimizerContext = BuilderBase __all__ = [ - "OpBuilderBase", + "BuilderBase", "OptimizerContext", "RewriterContext", "TapeBuilder", diff --git a/onnxscript/rewriter/_context_test.py b/onnxscript/rewriter/_context_test.py index 08bce457d6..4b660dcc3e 100644 --- a/onnxscript/rewriter/_context_test.py +++ b/onnxscript/rewriter/_context_test.py @@ -85,7 +85,7 @@ def test_op_method_explicit(self): op = TapeBuilder() x = ir.Value(name="x") w = ir.Value(name="w") - result = op.op("Conv", inputs=[x, w], domain="", name="my_conv") + result = op.op("Conv", x, w, _domain="", _name="my_conv") self.assertIsInstance(result, ir.Value) self.assertEqual(op.nodes[0].op_type, "Conv") self.assertEqual(op.nodes[0].name, "my_conv") @@ -93,13 +93,13 @@ def test_op_method_explicit(self): def test_op_method_with_attributes(self): op = TapeBuilder() x = ir.Value(name="x") - result = op.op("Elu", inputs=[x], attributes={"alpha": 2.0}) + result = op.op("Elu", x, alpha=2.0) self.assertIsInstance(result, ir.Value) self.assertEqual(op.nodes[0].op_type, "Elu") self.assertIn("alpha", op.nodes[0].attributes) def test_op_method_with_attr_map(self): - """Verify that passing node.attributes (an Attributes mapping) works.""" + """Verify that passing **node.attributes (an Attributes mapping) works.""" op = TapeBuilder() source_node = ir.Node( "", @@ -108,7 +108,7 @@ def test_op_method_with_attr_map(self): attributes=[ir.AttrInt64s("pads", [1, 1, 1, 1])], num_outputs=1, ) - result = op.op("Conv", inputs=[ir.Value()], attributes=source_node.attributes) + result = op.op("Conv", ir.Value(), **source_node.attributes) self.assertIsInstance(result, ir.Value) self.assertIn("pads", op.nodes[0].attributes) diff --git a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py index fe5405641f..0754baf4f7 100644 --- a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py @@ -83,12 +83,10 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu return op.op( self.op_type, - inputs=[ - x, - op.initializer(fused_weights, name=inbound_node.inputs[1].name), - op.initializer(fused_bias, name=bias_name), - ], - attributes=inbound_node.attributes, + x, + op.initializer(fused_weights, name=inbound_node.inputs[1].name), + op.initializer(fused_bias, name=bias_name), + **inbound_node.attributes, ) def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> MatchResult: diff --git a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py index 6608dfc8b3..a50de6c4ff 100644 --- a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py +++ b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py @@ -85,10 +85,11 @@ def rewrite(self, op, x: ir.Value, pad: ir.Value, conv: ir.Value) -> ir.Value: return op.op( conv_node.op_type, - inputs=(x, *conv_node.inputs[1:]), - attributes=conv_attr, - domain=conv_node.domain, - name=conv_node.name, + x, + *conv_node.inputs[1:], + _domain=conv_node.domain, + 
_name=conv_node.name, + **conv_attr, ) def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult: @@ -214,10 +215,10 @@ def rewrite(self, op, conv: ir.Value, **__) -> ir.Value: return op.op( conv_node.op_type, - inputs=conv_node.inputs, - attributes=conv_attr, - domain=conv_node.domain, - name=conv_node.name, + *conv_node.inputs, + _domain=conv_node.domain, + _name=conv_node.name, + **conv_attr, ) def check(self, context, conv: ir.Value, **__) -> orp.MatchResult: diff --git a/onnxscript/rewriter/rules/common/_min_max_to_clip.py b/onnxscript/rewriter/rules/common/_min_max_to_clip.py index 88ae495dbc..c0aff9e8b9 100644 --- a/onnxscript/rewriter/rules/common/_min_max_to_clip.py +++ b/onnxscript/rewriter/rules/common/_min_max_to_clip.py @@ -65,7 +65,8 @@ def rewrite(self, op, x, out1, out2): return op.op( self.op_type, - inputs=[x, *initializers], + x, + *initializers, ) def _is_scalar(self, v: np.ndarray) -> bool: diff --git a/onnxscript/rewriter/rules/common/_remove_optional_bias.py b/onnxscript/rewriter/rules/common/_remove_optional_bias.py index dbcbe23459..db161f0756 100644 --- a/onnxscript/rewriter/rules/common/_remove_optional_bias.py +++ b/onnxscript/rewriter/rules/common/_remove_optional_bias.py @@ -19,8 +19,8 @@ def rewrite(self, op, out: ir.Value, **_) -> ir.Value: return op.op( self.op_type, - inputs=node.inputs[:-1], - attributes=node.attributes, + *node.inputs[:-1], + **node.attributes, ) def check(self, context, b: ir.Value, **_) -> MatchResult: diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index b5f66d4a0e..05830d47b4 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -13,7 +13,7 @@ import onnxscript.utils.metadata_merger as metadata_merger from onnxscript import ir -from onnxscript._internal.tape_builder import OpBuilderBase, TapeBuilder +from onnxscript._internal.tape_builder import BuilderBase, TapeBuilder logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class Replacement: # A version-adapter function takes a node, a VCContext and returns # a Replacement for the node or None (if no replacement is needed). -VCContext = OpBuilderBase +VCContext = BuilderBase ReturnValue = Union[Sequence[ir.Value], ir.Value, None] AdapterFunction = Callable[[ir.Node, VCContext], ReturnValue]
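
A minimal usage sketch of the unified node-creation API introduced above (illustrative, not part of the patch). It mirrors the updated tests in `onnxscript/rewriter/_context_test.py`; the `nodes` property is exercised there, and `TapeBuilder` now defaults to `BuilderFeature.NONE` per its new `__init__`.

```python
import onnx_ir as ir

from onnxscript._internal.tape_builder import TapeBuilder

op = TapeBuilder()  # tape-backed storage; features default to BuilderFeature.NONE

x = ir.Value(name="x")
w = ir.Value(name="w")

# Dynamic dispatch is now pure sugar for the explicit spelling:
y1 = op.Relu(x)        # equivalent to the next line
y2 = op.op("Relu", x)  # single public entry point; delegates to call_op()

# Attributes are plain keyword arguments; reserved parameters use a "_" prefix.
c = op.op("Conv", y2, w, _domain="", _name="my_conv")
e = op.op("Elu", c, alpha=2.0)

# Harvest the tape: four nodes were recorded, in creation order.
print([n.op_type for n in op.nodes])  # ['Relu', 'Relu', 'Conv', 'Elu']
```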
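And a short sketch of how the new `BuilderFeature` flags compose; the identities below follow directly from the enum definition in `tape_builder.py`. `GraphBuilder` enables `BuilderFeature.FULL`, while `TapeBuilder` stays at `BuilderFeature.NONE` unless callers opt in.

```python
from onnxscript._internal.tape_builder import BuilderFeature

# SCHEMA_AWARE is the union of the three schema-dependent processing steps.
assert BuilderFeature.SCHEMA_AWARE == (
    BuilderFeature.SCHEMA_PARTITION
    | BuilderFeature.CAST_INPUTS
    | BuilderFeature.CAST_ATTRIBUTES
)

# any_schema_feature gates the schema lookup in call_op().
assert BuilderFeature.SCHEMA_AWARE.any_schema_feature
assert not BuilderFeature.NONE.any_schema_feature

# FULL additionally runs shape inference and constant propagation.
assert BuilderFeature.INFER_SHAPES in BuilderFeature.FULL
assert BuilderFeature.CONSTANT_PROPAGATION in BuilderFeature.FULL
```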