@@ -38,6 +38,12 @@ def get_device_type():
3838 if torch .cuda .is_available ():
3939 return "cuda"
4040
41+ try :
42+ import ixformer
43+ return "ilu"
44+ except ImportError :
45+ pass
46+
4147 try :
4248 import torch_mlu
4349 if torch .mlu .is_available ():
@@ -143,6 +149,14 @@ def get_torch_mlu_root_path():
143149 except ImportError :
144150 return None
145151
152+ def get_ixformer_root_path ():
153+ try :
154+ import ixformer
155+ import os
156+ return os .path .dirname (os .path .abspath (ixformer .__file__ ))
157+ except ImportError :
158+ return None
159+
146160def get_nccl_root_path ():
147161 try :
148162 from nvidia import nccl
@@ -253,7 +267,14 @@ def set_cuda_envs():
253267 os .environ ["LIBTORCH_ROOT" ] = get_torch_root_path ()
254268 os .environ ["PYTORCH_INSTALL_PATH" ] = get_torch_root_path ()
255269 os .environ ["CUDA_TOOLKIT_ROOT_DIR" ] = "/usr/local/cuda"
256-
270+
271+ def set_ilu_envs ():
272+ os .environ ["PYTHON_INCLUDE_PATH" ] = get_python_include_path ()
273+ os .environ ["PYTHON_LIB_PATH" ] = get_torch_root_path ()
274+ os .environ ["LIBTORCH_ROOT" ] = get_torch_root_path ()
275+ os .environ ["PYTORCH_INSTALL_PATH" ] = get_torch_root_path ()
276+ os .environ ["IXFORMER_INSTALL_PATH" ] = get_ixformer_root_path ()
277+
257278class CMakeExtension (Extension ):
258279 def __init__ (self , name : str , path : str , sourcedir : str = "" ) -> None :
259280 super ().__init__ (name , sources = [])
@@ -337,7 +358,7 @@ def build_extension(self, ext: CMakeExtension):
337358 f"-DDEVICE_ARCH={ self .arch .upper ()} " ,
338359 f"-DINSTALL_XLLM_KERNELS={ 'ON' if self .install_xllm_kernels else 'OFF' } " ,
339360 ]
340-
361+
341362 if self .device == "a2" or self .device == "a3" :
342363 cmake_args += ["-DUSE_NPU=ON" ]
343364 # set npu environment variables
@@ -352,6 +373,9 @@ def build_extension(self, ext: CMakeExtension):
352373 f"-DCMAKE_CUDA_ARCHITECTURES={ cuda_architectures } " ]
353374 # set cuda environment variables
354375 set_cuda_envs ()
376+ elif self .device == "ilu" :
377+ cmake_args += ["-DUSE_ILU=ON" ]
378+ set_ilu_envs ()
355379 else :
356380 raise ValueError ("Please set --device to a2 or a3 or mlu or cuda." )
357381
@@ -375,6 +399,7 @@ def build_extension(self, ext: CMakeExtension):
375399
376400 build_args = ["--config" , build_type ]
377401 max_jobs = os .getenv ("MAX_JOBS" , str (os .cpu_count ()))
402+ # max_jobs="2"
378403 build_args += ["-j" + max_jobs ]
379404
380405 env = os .environ .copy ()
@@ -604,9 +629,9 @@ def parse_arguments():
604629 parser .add_argument (
605630 '--device' ,
606631 type = str .lower ,
607- choices = ['auto' , 'a2' , 'a3' , 'mlu' , 'cuda' ],
632+ choices = ['auto' , 'a2' , 'a3' , 'mlu' , 'cuda' , 'ilu' ],
608633 default = 'auto' ,
609- help = 'Device type: a2, a3, mlu, or cuda (case-insensitive)'
634+ help = 'Device type: a2, a3, mlu, ilu or cuda (case-insensitive)'
610635 )
611636
612637 parser .add_argument (
0 commit comments