140140 }
141141
142142{arg_parsing}
143-
144143{operation}
145-
146- Py_RETURN_NONE;
147144#else
148145 PyErr_SetString(PyExc_NotImplementedError, "AVX-512 is not supported on this architecture.");
149146 return nullptr;
@@ -631,15 +628,13 @@ def getCallCode(func: SIMDFunc, constantRequire: list[tuple[int, int, int]]) ->
631628 else :
632629 code = f" { func .name } ({ args } );"
633630
634- isOuter = True
635631 if len (constantRequire ) > 0 :
636632 if len (constantRequire ) > 4 :
637633 raise NotImplementedError (len (constantRequire ))
638634
639635 for i , (argId , immediateMin , immediateMax ) in enumerate (constantRequire ):
640636 if immediateMin == immediateMax :
641637 code = code .replace (f"arg{ argId } " , str (immediateMin ))
642- isOuter = False
643638 continue
644639
645640 result = f"switch (arg{ argId } ) {{\n "
@@ -664,16 +659,6 @@ def getCallCode(func: SIMDFunc, constantRequire: list[tuple[int, int, int]]) ->
664659
665660 result += "}"
666661
667- if isOuter :
668- result = f"""
669- #if defined(__clang__) || defined(__GNUC__)
670- { result }
671- #else
672- PyErr_SetString(PyExc_NotImplementedError, "AVX-512 is not supported on this architecture.");
673- return nullptr;
674- #endif
675- """
676-
677662 return result
678663 return code
679664
@@ -729,12 +714,25 @@ def main():
729714 if curImmediateGenerated != 1 :
730715 immediateGenerated += curImmediateGenerated
731716
717+ callCode = getCallCode (function , constantRequire )
718+ if len (constantRequire ) > 0 :
719+ argParseCode = f"#if defined(__clang__) || defined(__GNUC__)\n { argParseCode } "
720+ callCode = (f"{ formatCode (callCode )} \n \n "
721+ f" Py_RETURN_NONE;\n "
722+ f"#else\n "
723+ f" PyErr_SetString(PyExc_NotImplementedError, \" Target C Method require immediate numbers, "
724+ f"and this method is not supported in GCC/Clang now.\" );\n "
725+ f" return nullptr;\n "
726+ f"#endif" )
727+ else :
728+ callCode = f"{ callCode } \n \n Py_RETURN_NONE;"
729+
732730 funcCode = funcCode .replace ("{num_args}" , str (num_args ))
733731 funcCode = funcCode .replace ("{function_name}" , function .name )
734- funcCode = funcCode .replace ("{arg_parsing}" , argParseCode )
735732
736733 # operation
737- funcCode = funcCode .replace ("{operation}" , getCallCode (function , constantRequire ))
734+ funcCode = funcCode .replace ("{arg_parsing}" , argParseCode )
735+ funcCode = funcCode .replace ("{operation}" , callCode )
738736 functionsGenerated += 1
739737 function_def += funcCode
740738
@@ -799,5 +797,12 @@ def main():
799797 print (f"Generated '{ RESULT_PYI_FILE } ' with { size } bytes and { lines } lines." )
800798
801799
800+ def formatCode (code : str ) -> str :
801+ result = ""
802+ for line in code .split ("\n " ):
803+ result += f" { line } \n "
804+ return result
805+
806+
802807if __name__ == '__main__' :
803808 main ()
0 commit comments