Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 35b81ae

Browse files
committed
modify kernel c interface to a more trace friendly one
1 parent 59532ca commit 35b81ae

14 files changed

Lines changed: 186 additions & 216 deletions

include/matxscript/runtime/type_helper_macros.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ struct TypeAsHelper {
7777
return v.template As<TO_TYPE>();
7878
}
7979
bool state = v.template Is<TO_TYPE>();
80-
if (std::is_same<double, TO_TYPE>::value) {
80+
if (std::is_same<double, TO_TYPE>::value || std::is_floating_point<TO_TYPE>::value) {
8181
state |= v.template Is<int64_t>();
8282
}
8383
if (!state) {

python/matx/kernel/codegen/cpp_template/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,4 @@
1818
# under the License.
1919

2020

21-
def render(parser: 'KernelParser', lib_path: str) -> str:
22-
return ''
21+
from .matx_api import render_matx_api_code

python/matx/kernel/codegen/cpp_template/base.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@
1919
from dataclasses import dataclass
2020
from typing import List, TYPE_CHECKING
2121
import os
22+
import time
2223
import jinja2
24+
import numpy as np
2325

26+
import matx.kernel.typing.utils as typing_utils
27+
from matx.kernel.typing import STR_TO_PYTYPE
28+
from matx.kernel.parser.utils import FuncReturnKind
2429

2530
if TYPE_CHECKING:
26-
from matx.kernel.parser.utils import FuncReturnKind
27-
31+
from matx.kernel.kernel_parser import KernelParser
2832

2933
TEMPLATE_DIR = os.path.dirname(os.path.abspath(__file__))
3034
JINJA2_ENV = jinja2.Environment(loader=jinja2.FileSystemLoader(TEMPLATE_DIR))
@@ -59,9 +63,19 @@ class MatxInterfaceCodegenMetaData:
5963
return_ndim: int
6064
return_dtype: str
6165

62-
def __init__(self, file_name: str, line_no: int, lib_path: str,
63-
unique_id: int, py_func_name: str, func_return_kind: FuncReturnKind,
64-
arg_names: List[str], arg_types: List[str], rt_type: str, return_ndim: int, return_dtype: str) -> None:
66+
def __init__(
67+
self,
68+
file_name: str,
69+
line_no: int,
70+
lib_path: str,
71+
unique_id: int,
72+
py_func_name: str,
73+
func_return_kind: FuncReturnKind,
74+
arg_names: List[str],
75+
arg_types: List[str],
76+
rt_type: str,
77+
return_ndim: int,
78+
return_dtype: str) -> None:
6579
self.file_name = file_name
6680
self.line_no = line_no
6781
self.lib_path = lib_path
@@ -88,4 +102,66 @@ def __init__(self, file_name: str, line_no: int, lib_path: str,
88102
self.return_dtype = return_dtype
89103

90104

91-
105+
PYTYPE_TO_CPP_TYPE_STR = {
106+
bool: "bool",
107+
int: "int32_t",
108+
float: "float",
109+
np.bool_: "bool",
110+
np.int8: "int8_t",
111+
np.int16: "int16_t",
112+
np.int32: "int32_t",
113+
np.int64: "int64_t",
114+
np.intc: "int32_t",
115+
np.uint8: "uint8_t",
116+
np.uint16: "uint16_t",
117+
np.uint32: "uint32_t",
118+
np.uint64: "uint64_t",
119+
np.uintc: "uint32_t",
120+
# todo support float16
121+
# np.float16 has no corresponding python builtin ctypes
122+
np.float16: "__fp16",
123+
np.float32: "float",
124+
np.float64: "double",
125+
np.longlong: "int64_t",
126+
np.ulonglong: "uint64_t"
127+
}
128+
129+
130+
def cvt_to_cpp_type_str(t):
131+
if typing_utils.is_scalar_type(t):
132+
return PYTYPE_TO_CPP_TYPE_STR[t.dtype]
133+
elif typing_utils.is_ndarray_type(t):
134+
return "NDArray"
135+
elif typing_utils.is_symbol(t):
136+
return "int64_t"
137+
else:
138+
raise SyntaxError(f"Unsupported type {t}")
139+
140+
141+
def make_meta_data(parser: 'KernelParser', lib_path: str) -> MatxInterfaceCodegenMetaData:
142+
file_name: str = parser.file_name
143+
line_no: int = parser.line_no
144+
145+
_nanoseconds = int(time.time() * 1e9)
146+
unique_id: int = int(_nanoseconds / 100) + 0x01b21dd213814000
147+
python_func_name: str = parser.func_name
148+
func_return_kind: FuncReturnKind = parser.graph.func_return_kind
149+
150+
arg_names: List[str] = [k for k in parser.args.keys()]
151+
arg_types: List[str] = [cvt_to_cpp_type_str(t) for t in parser.arg_types]
152+
153+
if func_return_kind.is_void():
154+
return_type: str = "void"
155+
elif func_return_kind.is_scalar():
156+
return_type: str = PYTYPE_TO_CPP_TYPE_STR[STR_TO_PYTYPE[parser.graph.return_dtype_str]]
157+
elif func_return_kind.is_dynamic_tensor():
158+
return_type: str = "void"
159+
else:
160+
raise SyntaxError(f"Unsupported return type {func_return_kind}")
161+
162+
return_ndim: int = len(parser.graph.return_shape)
163+
return_dtype: str = parser.graph.return_dtype_str
164+
return MatxInterfaceCodegenMetaData(
165+
file_name, line_no, lib_path, unique_id, python_func_name, func_return_kind,
166+
arg_names, arg_types, return_type, return_ndim, return_dtype
167+
)

python/matx/kernel/codegen/cpp_template/function_meta_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def code(self):
7474
return_ndim=self.return_ndim,
7575
return_dtype=self.return_dtype,
7676
input_types=self.input_types,
77-
input_args = self.input_args,
77+
input_args=self.input_args,
7878
lib_path=self.lib_path,
7979
func_return_kind=self.func_return_kind,
8080
debug=self.debug)

python/matx/kernel/codegen/cpp_template/matx_api.py

Lines changed: 33 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -18,119 +18,45 @@
1818
# under the License.
1919

2020

21-
from dataclasses import dataclass
22-
from typing import List, TYPE_CHECKING
23-
import jinja2
24-
import os
25-
import time
26-
import numpy as np
27-
28-
import matx.kernel.typing.utils as typing_utils
29-
from matx.kernel.typing import STR_TO_PYTYPE
30-
31-
TEMPLATE_DIR = os.path.dirname(os.path.abspath(__file__))
21+
from .base import make_meta_data, MatxInterfaceCodegenMetaData, JINJA2_ENV
22+
from .matx_api_func import MatxAPIFuncCodegen
23+
from .matx_c_api_func import MatxCAPIFuncCodegen
24+
from typing import Tuple, TYPE_CHECKING
3225

3326
if TYPE_CHECKING:
34-
from matx.kernel.parser.utils import FuncReturnKind
3527
from matx.kernel.kernel_parser import KernelParser
3628

3729

38-
@dataclass
39-
class CInterfaceCodegenData:
30+
class MatxApiCodegen:
4031
"""Class for keeping track of data necessary for c++ interface codegen."""
41-
unique_id: int
42-
func_name: str
43-
return_type: str
44-
return_ndim: int
45-
return_dtype: str
46-
input_types: List[str]
47-
input_args: List[str]
48-
lib_path: str
49-
func_return_kind: 'FuncReturnKind'
50-
free_return: bool
51-
debug: bool
52-
53-
def __init__(self, unique_id: int, func_name: str, return_type: str, return_ndim: int,
54-
return_dtype: str, input_types: List[str], input_args: List[str], lib_path: str,
55-
func_return_kind: 'FuncReturnKind', debug: bool = False):
56-
self.env = jinja2.Environment(loader=jinja2.FileSystemLoader(TEMPLATE_DIR))
57-
self.template = self.env.get_template('cpp_header.txt')
58-
self.unique_id = unique_id
59-
self.func_name = func_name
60-
self.return_type = return_type
61-
self.return_ndim = return_ndim
62-
self.return_dtype = return_dtype
63-
self.input_types = input_types
64-
self.input_args = input_args
65-
self.lib_path = lib_path
66-
self.func_return_kind = func_return_kind
67-
68-
def code(self):
69-
output = self.template.render(unique_id=self.unique_id,
70-
func_name=self.func_name,
71-
return_type=self.return_type,
72-
return_ndim=self.return_ndim,
73-
return_dtype=self.return_dtype,
74-
input_types=self.input_types,
75-
input_args = self.input_args,
76-
lib_path=self.lib_path,
77-
func_return_kind=self.func_return_kind,
78-
debug=self.debug)
79-
return output
8032

33+
def __init__(self, meta_data: MatxInterfaceCodegenMetaData) -> None:
34+
self.meta_data = meta_data
35+
self.matx_api_func_code_gen = MatxAPIFuncCodegen(meta_data)
36+
self.matx_c_api_func_code_gen = MatxCAPIFuncCodegen(meta_data)
8137

82-
PYTYPE_TO_CPP_TYPE_STR = {
83-
bool: "bool",
84-
int: "int32_t",
85-
float: "float",
86-
np.bool_: "bool",
87-
np.int8: "int8_t",
88-
np.int16: "int16_t",
89-
np.int32: "int32_t",
90-
np.int64: "int64_t",
91-
np.intc: "int32_t",
92-
np.uint8: "uint8_t",
93-
np.uint16: "uint16_t",
94-
np.uint32: "uint32_t",
95-
np.uint64: "uint64_t",
96-
np.uintc: "uint32_t",
97-
# todo support float16
98-
# np.float16 has no corresponding python builtin ctypes
99-
np.float16: "__fp16",
100-
np.float32: "float",
101-
np.float64: "double",
102-
np.longlong: "int64_t",
103-
np.ulonglong: "uint64_t"
104-
}
38+
def func_name(self):
39+
return self.meta_data.python_func_name
10540

106-
107-
def cvt_to_cpp_type_str(t):
108-
if typing_utils.is_scalar_type(t):
109-
return PYTYPE_TO_CPP_TYPE_STR[t.dtype]
110-
elif typing_utils.is_ndarray_type(t):
111-
return "void *"
112-
elif typing_utils.is_symbol(t):
113-
return "int64_t"
114-
else:
115-
raise SyntaxError(f"Unsupported type {t}")
116-
117-
118-
def get_codegen_data(parser: 'KernelParser', lib_path: str) -> CInterfaceCodegenData:
119-
nanoseconds = int(time.time() * 1e9)
120-
unique_id: int = int(nanoseconds / 100) + 0x01b21dd213814000
121-
func_name: str = parser.func_name
122-
return_ndim: int = len(parser.graph.return_shape)
123-
return_dtype: str = parser.graph.return_dtype_str
124-
input_types: List[str] = [cvt_to_cpp_type_str(t) for t in parser.arg_types]
125-
input_args: List[str] = [k for k in parser.args.keys()]
126-
func_return_kind: 'FuncReturnKind' = parser.graph.func_return_kind
127-
if func_return_kind.is_void():
128-
return_type: str = "void"
129-
elif func_return_kind.is_scalar():
130-
return_type: str = PYTYPE_TO_CPP_TYPE_STR[STR_TO_PYTYPE[parser.graph.return_dtype_str]]
131-
elif func_return_kind.is_dynamic_tensor():
132-
return_type: str = "void"
133-
else:
134-
raise SyntaxError(f"Unsupported return type {func_return_kind}")
135-
return CInterfaceCodegenData(unique_id, func_name, return_type, return_ndim,
136-
return_dtype, input_types, input_args, lib_path, func_return_kind)
41+
def code(self):
42+
code_template = JINJA2_ENV.get_template('matx_api.txt')
43+
matx_api_declaration = self.matx_api_func_code_gen.func_declaration()
44+
matx_c_api_declaration = self.matx_c_api_func_code_gen.func_declaration()
45+
matx_api_definition = self.matx_api_func_code_gen.func_definition()
46+
matx_c_api_definition = self.matx_c_api_func_code_gen.func_definition()
47+
return code_template.render(
48+
matx_api_declaration=matx_api_declaration,
49+
matx_c_api_declaration=matx_c_api_declaration,
50+
matx_api_definition=matx_api_definition,
51+
matx_c_api_definition=matx_c_api_definition,
52+
c_interface_func_name=self.meta_data.c_api_func_name,
53+
py_func_name=self.meta_data.python_func_name
54+
)
55+
56+
57+
def render_matx_api_code(parser: 'KernelParser',
58+
lib_path: str) -> Tuple[str,
59+
MatxInterfaceCodegenMetaData]:
60+
meta_data = make_meta_data(parser, lib_path)
61+
codegen_class = MatxApiCodegen(meta_data)
62+
return codegen_class.code(), meta_data

python/matx/kernel/codegen/cpp_template/matx_api_func.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ def __init__(self, meta_data: MatxInterfaceCodegenMetaData) -> None:
4040

4141
self.gen_mlir_func_type()
4242
self.gen_ndarray_cvt_code()
43-
self.arg_type_and_name = ", ".join(
44-
[f"{t} {n}" for t, n in zip(self.meta_data.matx_arg_types, self.meta_data.matx_arg_names)]
45-
)
43+
self.gen_mlir_func_call_code()
44+
self.arg_type_and_name = ", ".join([f"{t} {n}" for t, n in zip(
45+
self.meta_data.matx_arg_types, self.meta_data.matx_arg_names)])
4646

4747
def func_declaration(self):
48-
return f"MATX_DLL {self.meta_data.matx_rt_type} {self.meta_data.matx_func_name}({', '.join(self.meta_data.arg_types)});"
48+
return f"MATX_DLL auto {self.meta_data.matx_func_name}({', '.join(self.meta_data.matx_arg_types)} handle_2_71828182846 =((void*)(int64_t)0));"
4949

5050
def gen_ndarray_cvt_code(self):
5151
i = 0
@@ -68,32 +68,33 @@ def gen_mlir_func_type(self):
6868
elif self.meta_data.func_return_kind.is_dynamic_tensor():
6969
self.mlir_func_type = f"void(*)({', '.join(['void *', *self.meta_data.mlir_arg_types])});"
7070
else:
71-
raise SyntaxError(f"function_return_kind({self.meta_data.func_return_kind}) is not supported")
71+
raise SyntaxError(
72+
f"function_return_kind({self.meta_data.func_return_kind}) is not supported")
7273

7374
def gen_mlir_func_call_code(self):
7475
if self.meta_data.func_return_kind.is_void():
7576
self.call_func_code_list = [
76-
f"casted_func_ptr({self.mlir_args});"
77+
f"casted_func_ptr({', '.join(self.mlir_args)});",
7778
"return None;"
7879
]
7980
elif self.meta_data.func_return_kind.is_scalar():
8081
self.call_func_code_list = [
81-
f"return casted_func_ptr({self.mlir_args});"
82+
f"return casted_func_ptr({', '.join(self.mlir_args)});"
8283
]
8384
elif self.meta_data.func_return_kind.is_dynamic_tensor():
8485
rt_shared_ptr = "_mlir_return_31905_shared_ptr_571"
8586
rt_ptr = "_mlir_return_31905_ptr_571"
8687
self.call_func_code_list = [
87-
f"auto && {rt_shared_ptr} = alloc_memref_descriptor_ptr({self.meta_data.return_ndim});"
88-
f"void * {rt_ptr} = {rt_shared_ptr}.get();"
89-
f"casted_func_ptr({', '.join([rt_ptr, *self.mlir_args])});"
90-
f"""return convert_to_ndarray({rt_shared_ptr}, {self.meta_data.return_ndim}, cvt_str_to_dl_dtype("{self.meta_data.return_dtype}"));"""
91-
]
88+
f"auto && {rt_shared_ptr} = alloc_memref_descriptor_ptr({self.meta_data.return_ndim});",
89+
f"void * {rt_ptr} = {rt_shared_ptr}.get();",
90+
f"casted_func_ptr({', '.join([rt_ptr, *self.mlir_args])});",
91+
f"""return convert_to_ndarray({rt_shared_ptr}, {self.meta_data.return_ndim}, cvt_str_to_dl_dtype("{self.meta_data.return_dtype}"));"""]
9292
else:
93-
raise SyntaxError(f"function_return_kind({self.meta_data.func_return_kind}) is not supported")
93+
raise SyntaxError(
94+
f"function_return_kind({self.meta_data.func_return_kind}) is not supported")
9495

9596
def func_definition(self) -> str:
96-
func_definition_template = JINJA2_ENV.get_template('matx_c_api_func.txt')
97+
func_definition_template = JINJA2_ENV.get_template('matx_api_func.txt')
9798
return func_definition_template.render(
9899
mlir_func_type=self.mlir_func_type,
99100
matx_func_name=self.meta_data.matx_func_name,
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
using _mlir_func_type = {{mlir_func_type}};
22

3-
MATX_DLL auto {{matx_func_name}}({arg_type_and_name}}) {
3+
MATX_DLL auto {{matx_func_name}}({{arg_type_and_name}}) {
44
static void * func_ptr = load_func("{{mlir_func_name}}", "{{lib_path}}");
55
static auto casted_func_ptr = reinterpret_cast<_mlir_func_type>(func_ptr);
66
//convert input types
7-
{{type_cvt_code_list | join("\n ")}}
7+
{{type_cvt_code_list | join("\n ")}}
88

99
// call funtion
10-
{{call_func_code_list | join(\n ")}}}
10+
{{call_func_code_list | join("\n ")}}
1111
}

0 commit comments

Comments
 (0)