Skip to content

Commit b65a6aa

Browse files
committed
Fix compat issue
1 parent 8995634 commit b65a6aa

2 files changed

Lines changed: 5946 additions & 6883 deletions

File tree

scripts/GenerateAVX512Bindings.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,7 @@
140140
}
141141
142142
{arg_parsing}
143-
144143
{operation}
145-
146-
Py_RETURN_NONE;
147144
#else
148145
PyErr_SetString(PyExc_NotImplementedError, "AVX-512 is not supported on this architecture.");
149146
return nullptr;
@@ -631,15 +628,13 @@ def getCallCode(func: SIMDFunc, constantRequire: list[tuple[int, int, int]]) ->
631628
else:
632629
code = f" {func.name}({args});"
633630

634-
isOuter = True
635631
if len(constantRequire) > 0:
636632
if len(constantRequire) > 4:
637633
raise NotImplementedError(len(constantRequire))
638634

639635
for i, (argId, immediateMin, immediateMax) in enumerate(constantRequire):
640636
if immediateMin == immediateMax:
641637
code = code.replace(f"arg{argId}", str(immediateMin))
642-
isOuter = False
643638
continue
644639

645640
result = f"switch (arg{argId}) {{\n"
@@ -664,16 +659,6 @@ def getCallCode(func: SIMDFunc, constantRequire: list[tuple[int, int, int]]) ->
664659

665660
result += "}"
666661

667-
if isOuter:
668-
result = f"""
669-
#if defined(__clang__) || defined(__GNUC__)
670-
{result}
671-
#else
672-
PyErr_SetString(PyExc_NotImplementedError, "AVX-512 is not supported on this architecture.");
673-
return nullptr;
674-
#endif
675-
"""
676-
677662
return result
678663
return code
679664

@@ -729,12 +714,25 @@ def main():
729714
if curImmediateGenerated != 1:
730715
immediateGenerated += curImmediateGenerated
731716

717+
callCode = getCallCode(function, constantRequire)
718+
if len(constantRequire) > 0:
719+
argParseCode = f"#if defined(__clang__) || defined(__GNUC__)\n{argParseCode}"
720+
callCode = (f"{formatCode(callCode)}\n\n"
721+
f" Py_RETURN_NONE;\n"
722+
f"#else\n"
723+
f" PyErr_SetString(PyExc_NotImplementedError, \"Target C Method require immediate numbers, "
724+
f"and this method is not supported in GCC/Clang now.\");\n"
725+
f" return nullptr;\n"
726+
f"#endif")
727+
else:
728+
callCode = f"{callCode}\n\n Py_RETURN_NONE;"
729+
732730
funcCode = funcCode.replace("{num_args}", str(num_args))
733731
funcCode = funcCode.replace("{function_name}", function.name)
734-
funcCode = funcCode.replace("{arg_parsing}", argParseCode)
735732

736733
# operation
737-
funcCode = funcCode.replace("{operation}", getCallCode(function, constantRequire))
734+
funcCode = funcCode.replace("{arg_parsing}", argParseCode)
735+
funcCode = funcCode.replace("{operation}", callCode)
738736
functionsGenerated += 1
739737
function_def += funcCode
740738

@@ -799,5 +797,12 @@ def main():
799797
print(f"Generated '{RESULT_PYI_FILE}' with {size} bytes and {lines} lines.")
800798

801799

800+
def formatCode(code: str) -> str:
801+
result = ""
802+
for line in code.split("\n"):
803+
result += f" {line}\n"
804+
return result
805+
806+
802807
if __name__ == '__main__':
803808
main()

0 commit comments

Comments
 (0)