@@ -38,6 +38,12 @@ def get_device_type():
3838 if torch .cuda .is_available ():
3939 return "cuda"
4040
41+ try :
42+ import ixformer
43+ return "ilu"
44+ except ImportError :
45+ pass
46+
4147 try :
4248 import torch_mlu
4349 if torch .mlu .is_available ():
@@ -143,6 +149,14 @@ def get_torch_mlu_root_path():
143149 except ImportError :
144150 return None
145151
152+ def get_ixformer_root_path ():
153+ try :
154+ import ixformer
155+ import os
156+ return os .path .dirname (os .path .abspath (ixformer .__file__ ))
157+ except ImportError :
158+ return None
159+
146160def get_nccl_root_path ():
147161 try :
148162 from nvidia import nccl
@@ -253,7 +267,14 @@ def set_cuda_envs():
253267 os .environ ["LIBTORCH_ROOT" ] = get_torch_root_path ()
254268 os .environ ["PYTORCH_INSTALL_PATH" ] = get_torch_root_path ()
255269 os .environ ["CUDA_TOOLKIT_ROOT_DIR" ] = "/usr/local/cuda"
256-
270+
271+ def set_ilu_envs ():
272+ os .environ ["PYTHON_INCLUDE_PATH" ] = get_python_include_path ()
273+ os .environ ["PYTHON_LIB_PATH" ] = get_torch_root_path ()
274+ os .environ ["LIBTORCH_ROOT" ] = get_torch_root_path ()
275+ os .environ ["PYTORCH_INSTALL_PATH" ] = get_torch_root_path ()
276+ os .environ ["IXFORMER_INSTALL_PATH" ] = get_ixformer_root_path ()
277+
257278class CMakeExtension (Extension ):
258279 def __init__ (self , name : str , path : str , sourcedir : str = "" ) -> None :
259280 super ().__init__ (name , sources = [])
@@ -304,8 +325,7 @@ def run(self):
304325 for ext in self .extensions :
305326 self .build_extension (ext )
306327 except Exception as e :
307- print ("ERROR: Build failed." )
308- print (f"Details: { e } " )
328+ print ("Build failed." )
309329 exit (1 )
310330
311331 def build_extension (self , ext : CMakeExtension ):
@@ -337,7 +357,7 @@ def build_extension(self, ext: CMakeExtension):
337357 f"-DDEVICE_ARCH={ self .arch .upper ()} " ,
338358 f"-DINSTALL_XLLM_KERNELS={ 'ON' if self .install_xllm_kernels else 'OFF' } " ,
339359 ]
340-
360+
341361 if self .device == "a2" or self .device == "a3" :
342362 cmake_args += ["-DUSE_NPU=ON" ]
343363 # set npu environment variables
@@ -352,6 +372,9 @@ def build_extension(self, ext: CMakeExtension):
352372 f"-DCMAKE_CUDA_ARCHITECTURES={ cuda_architectures } " ]
353373 # set cuda environment variables
354374 set_cuda_envs ()
375+ elif self .device == "ilu" :
376+ cmake_args += ["-DUSE_ILU=ON" ]
377+ set_ilu_envs ()
355378 else :
356379 raise ValueError ("Please set --device to a2 or a3 or mlu or cuda." )
357380
@@ -375,6 +398,7 @@ def build_extension(self, ext: CMakeExtension):
375398
376399 build_args = ["--config" , build_type ]
377400 max_jobs = os .getenv ("MAX_JOBS" , str (os .cpu_count ()))
401+ # max_jobs="2"
378402 build_args += ["-j" + max_jobs ]
379403
380404 env = os .environ .copy ()
@@ -604,9 +628,9 @@ def parse_arguments():
604628 parser .add_argument (
605629 '--device' ,
606630 type = str .lower ,
607- choices = ['auto' , 'a2' , 'a3' , 'mlu' , 'cuda' ],
631+ choices = ['auto' , 'a2' , 'a3' , 'mlu' , 'cuda' , 'ilu' ],
608632 default = 'auto' ,
609- help = 'Device type: a2, a3, mlu, or cuda (case-insensitive)'
633+ help = 'Device type: a2, a3, mlu, ilu or cuda (case-insensitive)'
610634 )
611635
612636 parser .add_argument (
0 commit comments