|
76 | 76 | from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module |
77 | 77 | from .comm import gen_vllm_comm_module as gen_vllm_comm_module |
78 | 78 | from .comm import gen_nvshmem_module as gen_nvshmem_module |
| 79 | +from typing import Optional |
| 80 | + |
| 81 | + |
| 82 | +def find_loaded_library(lib_name) -> Optional[str]: |
| 83 | + """ |
| 84 | + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, |
| 85 | + the file `/proc/self/maps` contains the memory maps of the process, which includes the |
| 86 | + shared libraries loaded by the process. We can use this file to find the path of the |
| 87 | + a loaded library. |
| 88 | + """ |
| 89 | + found = False |
| 90 | + with open("/proc/self/maps") as f: |
| 91 | + for line in f: |
| 92 | + if lib_name in line: |
| 93 | + found = True |
| 94 | + break |
| 95 | + if not found: |
| 96 | + # the library is not loaded in the current process |
| 97 | + return None |
| 98 | + # if lib_name is libcudart, we need to match a line with: |
| 99 | + # address /path/to/libcudart-hash.so.11.0 |
| 100 | + start = line.index("/") |
| 101 | + path = line[start:].strip() |
| 102 | + filename = path.split("/")[-1] |
| 103 | + assert filename.rpartition(".so")[0].startswith(lib_name), ( |
| 104 | + f"Unexpected filename: {filename} for library {lib_name}" |
| 105 | + ) |
| 106 | + return path |
79 | 107 |
|
80 | 108 |
|
81 | 109 | cuda_lib_path = os.environ.get( |
82 | 110 | "CUDA_LIB_PATH", "/usr/local/cuda/targets/x86_64-linux/lib/" |
83 | 111 | ) |
84 | | -if os.path.exists(f"{cuda_lib_path}/libcudart.so.12"): |
| 112 | +process_cudart_path = find_loaded_library("libcudart") |
| 113 | +if process_cudart_path is not None: |
| 114 | + ctypes.CDLL(process_cudart_path, mode=ctypes.RTLD_GLOBAL) |
| 115 | +elif os.path.exists(f"{cuda_lib_path}/libcudart.so.12"): |
85 | 116 | ctypes.CDLL(f"{cuda_lib_path}/libcudart.so.12", mode=ctypes.RTLD_GLOBAL) |
0 commit comments