|
| 1 | +import builtins |
| 2 | +import ctypes |
| 3 | +import sys |
| 4 | +import types |
| 5 | +from dis import Bytecode |
| 6 | +from inspect import Parameter |
| 7 | +from types import FunctionType |
| 8 | +from typing import Callable, Any, Optional, Iterable |
| 9 | + |
| 10 | +from pyfastutil.objects import ObjectArrayList |
| 11 | + |
# Inclusive bounds of the platform's C `int`, derived from its size in bytes.
INT_MAX = (1 << (8 * ctypes.sizeof(ctypes.c_int) - 1)) - 1
INT_MIN = ~INT_MAX  # two's complement: ~x == -x - 1
| 14 | + |
| 15 | + |
def intToDigits(value: int) -> tuple[int, int, list[int]]:
    """
    Decompose a Python int into the digit information needed by _PyLong_FromDigits.

    The digit width is taken from ``sys.int_info.bits_per_digit`` so that it
    always matches the running interpreter (30 bits on typical 64-bit builds,
    15 bits otherwise).

    :param value: the integer to decompose; may be negative or zero
    :returns:
        - sign: 0 for a non-negative value, 1 for a negative value
        - num_digits: length of the digit array (at least 1)
        - digits: the digits of ``abs(value)``, least significant first
    """
    digitBits = sys.int_info.bits_per_digit
    digitMask = (1 << digitBits) - 1

    sign = 1 if value < 0 else 0
    magnitude = abs(value)

    digits: list[int] = []
    while magnitude:
        digits.append(magnitude & digitMask)
        magnitude >>= digitBits

    # Zero is still reported as a single zero digit.
    if not digits:
        digits = [0]

    return sign, len(digits), digits
| 39 | + |
| 40 | + |
| 41 | +class _BytecodeTranslator: |
| 42 | + def __init__(self, func: FunctionType, arguments: Iterable[Parameter], bytecode: Bytecode): |
| 43 | + self.bytecode = bytecode |
| 44 | + self.vars: tuple[str, ...] = func.__code__.co_varnames |
| 45 | + self.code: list[str] = ObjectArrayList(100) |
| 46 | + self.delayActions: list[Callable[[], None]] = [] |
| 47 | + |
| 48 | + self.constsCount: int = 0 |
| 49 | + self.constants: dict[object, str] = {} |
| 50 | + |
| 51 | + self.regCount: int = func.__code__.co_stacksize |
| 52 | + self.regUsed: int = 0 |
| 53 | + self.regAutoclose: list[bool] = [False] * self.regCount |
| 54 | + |
| 55 | + # helper context |
| 56 | + self.kwNames: list[str] = [] # to support kwargs call |
| 57 | + self.forDepth: int = 0 # to support for |
| 58 | + self.bytecodeOffsets: dict[int, int] = {} # to support JUMP_BACKWARD |
| 59 | + """key: bytecode bytes offset (f->lastI), value: c code offset (line)""" |
| 60 | + self.varsUnused: set[str] = set(self.vars) |
| 61 | + |
| 62 | + # define local variables (include args) |
| 63 | + for name in self.varsUnused: |
| 64 | + self.append(f"PyObject *var_{name} = nullptr;") |
| 65 | + |
| 66 | + # define registers |
| 67 | + for i in range(self.regCount): |
| 68 | + self.append(f"[[maybe_unused]] PyObject *tmp{i + 1};") |
| 69 | + self.append("") |
| 70 | + |
| 71 | + # parse args |
| 72 | + for i, param in enumerate(arguments): |
| 73 | + if param.kind == Parameter.KEYWORD_ONLY: |
| 74 | + raise NotImplementedError(f"Param '{param.name}'") |
| 75 | + |
| 76 | + if param.default == Parameter.empty: |
| 77 | + self.assign(param.name, f"*(args + {i})") |
| 78 | + else: |
| 79 | + expr, shouldDecref = self.fromConstant(param.default) |
| 80 | + self.assign(param.name, f"nargs > {i + 1} ? *(args + {i}) : {expr}") |
| 81 | + |
| 82 | + def append(self, code: str) -> None: |
| 83 | + self.code.append(code) |
| 84 | + |
| 85 | + def handleDelayActions(self) -> None: |
| 86 | + for action in self.delayActions: |
| 87 | + action() |
| 88 | + self.delayActions.clear() |
| 89 | + |
| 90 | + def pushCall(self, expr: str, autoclose: bool) -> None: |
| 91 | + self.append(f"res = {expr};") |
| 92 | + self.delayActions.append(lambda: self.push("res", autoclose)) |
| 93 | + |
| 94 | + def push(self, expr: str, autoclose: bool) -> None: |
| 95 | + assert 0 <= self.regUsed <= self.regCount - 1 |
| 96 | + if autoclose: |
| 97 | + self.regAutoclose[self.regUsed] = True |
| 98 | + |
| 99 | + self.regUsed += 1 |
| 100 | + self.append(f"tmp{self.regUsed} = {expr};") |
| 101 | + |
| 102 | + def pop(self, noAutoClose: bool = False) -> str: |
| 103 | + regName = self.back() |
| 104 | + |
| 105 | + self.regUsed -= 1 |
| 106 | + if not noAutoClose and self.regAutoclose[self.regUsed]: |
| 107 | + self.delayActions.append(lambda: self.append(f"PyFast_DECREF({regName});")) |
| 108 | + return regName |
| 109 | + |
| 110 | + def back(self) -> str: |
| 111 | + assert 1 <= self.regUsed <= self.regCount |
| 112 | + return f"tmp{self.regUsed}" |
| 113 | + |
| 114 | + def assign(self, name: str, expr: str) -> None: |
| 115 | + if name not in self.varsUnused: |
| 116 | + self.append(f"PyFast_DECREF(var_{name});") |
| 117 | + else: |
| 118 | + self.varsUnused.remove(name) |
| 119 | + name = f"var_{name}" |
| 120 | + self.append(f"{name} = {expr};") |
| 121 | + self.append(f"PyFast_INCREF({name});") |
| 122 | + |
| 123 | + def name(self, name: str) -> Optional[str]: |
| 124 | + if name in self.varsUnused: |
| 125 | + return None |
| 126 | + return f"var_{name}" |
| 127 | + |
| 128 | + def assignConstant(self, expr: str, value: object = None) -> str: |
| 129 | + if value is not None and value in self.constants: |
| 130 | + return self.constants[value] |
| 131 | + |
| 132 | + self.constsCount += 1 |
| 133 | + name = f"constant{self.constsCount}" |
| 134 | + self.append(f"static auto {name} = {expr};") |
| 135 | + |
| 136 | + if value is not None: |
| 137 | + self.constants[value] = name |
| 138 | + return name |
| 139 | + |
| 140 | + @staticmethod |
| 141 | + def call(name: str, *args: Any) -> str: |
| 142 | + formattedArgs = ", ".join(map(str, args)) |
| 143 | + return f"{name}({formattedArgs})" |
| 144 | + |
| 145 | + def returnVal(self, expr: str) -> None: |
| 146 | + self.handleDelayActions() |
| 147 | + for name in self.vars: |
| 148 | + name = self.name(name) |
| 149 | + if name is None: |
| 150 | + continue |
| 151 | + self.append(f"PyFast_XDECREF({name});") |
| 152 | + self.append(f"return {expr};") |
| 153 | + |
| 154 | + def fromConstant(self, value: Any, forceConstant: bool = False) -> tuple[str, bool]: |
| 155 | + """expr code, should decref""" |
| 156 | + match type(value): |
| 157 | + case builtins.int: |
| 158 | + if INT_MIN <= value <= INT_MAX: |
| 159 | + return self.assignConstant(self.call(f"PyFast_FromInt", value), value), False |
| 160 | + elif sys.getsizeof(value) <= 16 or forceConstant: |
| 161 | + sign, num_digits, digits = intToDigits(value) |
| 162 | + formattedDigits = "{" + ", ".join(map(str, digits)) + "}" |
| 163 | + return self.assignConstant( |
| 164 | + self.call(f"PyFast_FromDigits", value, sign, num_digits, formattedDigits), value |
| 165 | + ), False |
| 166 | + else: |
| 167 | + sign, num_digits, digits = intToDigits(value) |
| 168 | + formattedDigits = "{" + ", ".join(map(str, digits)) + "}" |
| 169 | + return self.call(f"PyFast_FromDigits", value, sign, num_digits, formattedDigits), True |
| 170 | + case builtins.bool: |
| 171 | + return "Py_True" if value else "Py_False", False |
| 172 | + case types.NoneType: |
| 173 | + return "Py_None", False |
| 174 | + case builtins.str: |
| 175 | + if len(value) <= 16 or forceConstant: |
| 176 | + return self.assignConstant(self.call("PyUnicode_FromString", f"\"{value}\""), value), False |
| 177 | + return self.call("PyUnicode_FromString", f"\"{value}\""), True |
| 178 | + case _: |
| 179 | + raise NotImplementedError(value) |
| 180 | + |
| 181 | + def run(self) -> list[str]: |
| 182 | + for instr in self.bytecode: |
| 183 | + op = instr.opname |
| 184 | + arg = instr.arg |
| 185 | + argVal = instr.argval |
| 186 | + |
| 187 | + self.append("") |
| 188 | + self.append(f"// {op}({arg}, {argVal})") |
| 189 | + |
| 190 | + self.bytecodeOffsets[instr.offset] = len(self.code) |
| 191 | + |
| 192 | + self.visit(op, arg, argVal) |
| 193 | + |
| 194 | + self.handleDelayActions() |
| 195 | + |
| 196 | + return self.code |
| 197 | + |
| 198 | + def visit(self, op: str, arg: Optional[int], argVal: Any) -> None: |
| 199 | + match op: |
| 200 | + case "RESUME": |
| 201 | + # We needn't PyGIL_Ensure |
| 202 | + ... |
| 203 | + |
| 204 | + case "NOP": |
| 205 | + ... |
| 206 | + |
| 207 | + case "LOAD_FAST": |
| 208 | + self.push(f"var_{argVal}", True) |
| 209 | + self.append(f"PyFast_INCREF({self.back()});") |
| 210 | + |
| 211 | + case "STORE_FAST": |
| 212 | + self.assign(argVal, self.pop()) |
| 213 | + |
| 214 | + case "LOAD_CONST": |
| 215 | + self.push(*self.fromConstant(argVal)) |
| 216 | + |
| 217 | + case "LOAD_GLOBAL": |
| 218 | + name, _ = self.fromConstant(argVal, forceConstant=True) |
| 219 | + hashVal = self.assignConstant(self.call("PyObject_Hash", name)) |
| 220 | + self.pushCall(self.call("PyFast_LoadGlobal", name, hashVal), True) |
| 221 | + |
| 222 | + case "BINARY_OP": |
| 223 | + right = self.pop() |
| 224 | + left = self.pop() |
| 225 | + |
| 226 | + match arg: |
| 227 | + case 0: # + |
| 228 | + self.pushCall(self.call("PyNumber_Add", left, right), True) |
| 229 | + case 5: # * |
| 230 | + self.pushCall(self.call("PyNumber_Multiply", left, right), True) |
| 231 | + case 13: # += |
| 232 | + self.pushCall(self.call("PyNumber_Add", left, right), True) |
| 233 | + case _: |
| 234 | + raise NotImplementedError(op, arg, argVal) |
| 235 | + |
| 236 | + case "RETURN_VALUE": |
| 237 | + while self.regUsed > 1: |
| 238 | + self.pop() |
| 239 | + self.returnVal(self.pop(noAutoClose=True)) |
| 240 | + |
| 241 | + case "RETURN_CONST": |
| 242 | + self.returnVal(self.fromConstant(argVal, forceConstant=True)[0]) |
| 243 | + |
| 244 | + case "KW_NAMES": # cpython 3.11+ only |
| 245 | + self.kwNames.append(argVal[0]) |
| 246 | + |
| 247 | + case "CALL": |
| 248 | + args: list[str] = [self.pop() for _ in range(argVal)] |
| 249 | + args.reverse() |
| 250 | + |
| 251 | + if len(args) == 0: |
| 252 | + self.pushCall(self.call("PyFast_CallNoArgs", self.pop()), True) |
| 253 | + elif len(self.kwNames) == 0: |
| 254 | + self.pushCall(self.call("PyFast_CallNoKwargs", self.pop(), "{" + ", ".join(args) + "}"), True) |
| 255 | + else: |
| 256 | + posArgsCount: int = len(args) - len(self.kwNames) |
| 257 | + |
| 258 | + kwNamesLst = kwNamesHashLst = [] |
| 259 | + for name in self.kwNames: |
| 260 | + nameObj, _ = self.fromConstant(name, forceConstant=True) |
| 261 | + kwNamesLst.append(nameObj) |
| 262 | + hashObj = self.assignConstant(self.call("PyObject_Hash", nameObj)) |
| 263 | + kwNamesHashLst.append(hashObj) |
| 264 | + |
| 265 | + kwNames = self.assignConstant(self.call("PyFast_TuplePack", "{" + ", ".join(kwNamesLst) + "}")) |
| 266 | + kwNamesHash = "{" + ", ".join(kwNamesHashLst) + "}" |
| 267 | + |
| 268 | + self.pushCall(self.call( |
| 269 | + "PyFast_Call", |
| 270 | + self.pop(), |
| 271 | + "{" + ", ".join(args) + "}", |
| 272 | + posArgsCount, kwNames, kwNamesHash |
| 273 | + ), True) |
| 274 | + |
| 275 | + case "POP_TOP": |
| 276 | + self.pop() |
| 277 | + |
| 278 | + case "LOAD_ATTR": |
| 279 | + self.pushCall(self.call( |
| 280 | + "PyObject_GetAttr", |
| 281 | + self.pop(), |
| 282 | + self.fromConstant(argVal, forceConstant=True)[0] |
| 283 | + ), True) |
| 284 | + |
| 285 | + case "GET_ITER": |
| 286 | + self.pushCall(self.call("PyFast_GetIter", self.pop()), True) |
| 287 | + |
| 288 | + case "FOR_ITER": |
| 289 | + self.forDepth += 1 |
| 290 | + self.pushCall(self.call("PyFast_Next", self.back()), True) |
| 291 | + self.append(f"if (UNLIKELY(res == nullptr)) {{") |
| 292 | + # We needn't to decref the result, because it's nullptr |
| 293 | + self.append(f" PyErr_Clear();") |
| 294 | + self.append(f" goto endFor{self.forDepth};") |
| 295 | + self.append(f"}}") |
| 296 | + |
| 297 | + case "END_FOR": |
| 298 | + self.append(f"endFor{self.forDepth}:") |
| 299 | + self.pop() # pop iter |
| 300 | + self.forDepth -= 1 |
| 301 | + |
| 302 | + case "JUMP_BACKWARD": |
| 303 | + codePos = self.bytecodeOffsets[argVal] |
| 304 | + label = f"jumpLastI{argVal}" |
| 305 | + if self.code[codePos] != label: |
| 306 | + self.code.insert(codePos, f"{label}:") |
| 307 | + self.append(f"goto {label};") |
| 308 | + |
| 309 | + case _: |
| 310 | + raise NotImplementedError(op, arg, argVal) |
| 311 | + |
| 312 | + |
# noinspection PyUnusedLocal
def toC(func: FunctionType, arguments: Iterable[Parameter], bytecode: Bytecode, c_int: bool) -> list[str]:
    """
    Translate ``func`` into generated code lines.

    :param func: the function being translated
    :param arguments: the function's parameters
    :param bytecode: the bytecode of ``func``
    :param c_int: accepted but not used by this implementation
    :returns: the lines produced by the bytecode translator
    """
    translator = _BytecodeTranslator(func, arguments, bytecode)
    return translator.run()
0 commit comments