diff --git a/docker/Dockerfile.metax b/docker/Dockerfile.metax
new file mode 100644
index 000000000..7e30dd806
--- /dev/null
+++ b/docker/Dockerfile.metax
@@ -0,0 +1,7 @@
+# docker pull from https://sw-download.metax-tech.com/docker
+FROM cr.metax-tech.com/public-ai-release/maca/vllm:maca.ai3.1.0.7-torch2.6-py310-ubuntu22.04-amd64
+
+ENV MACA_PATH=/opt/maca
+ENV PATH=/opt/conda/bin:/opt/conda/condabin:${PATH}
+COPY . /lightllm
+RUN pip install -r /lightllm/requirements_metax.txt && pip install -e /lightllm --no-cache-dir
diff --git a/lightllm/common/basemodel/attention_vit/xformers/fp.py b/lightllm/common/basemodel/attention_vit/xformers/fp.py
index 361b5db05..643fc7d4b 100644
--- a/lightllm/common/basemodel/attention_vit/xformers/fp.py
+++ b/lightllm/common/basemodel/attention_vit/xformers/fp.py
@@ -1,3 +1,4 @@
+from lightllm.utils.device_utils import is_metax
 import torch
 import torch.nn.functional as F
 
@@ -34,7 +35,11 @@ def _vit_att_fwd(
     if max_seqlen:
         assert max(seqlens) <= max_seqlen
 
-    attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens, device=q.device)
+    # The xformers version on MetaX is 0.0.22 (0.0.32.post1 on NVIDIA), which has no device param
+    if is_metax():
+        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+    else:
+        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens, device=q.device)
 
     q_ = q.unsqueeze(0)  # [1, T, H, D]
     k_ = k.unsqueeze(0)  # [1, T, H, D]
diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py
index 5ba6d0beb..6c3c297ad 100644
--- a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py
+++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py
@@ -7,7 +7,7 @@
 import math
 
 import torch.nn.functional as F
-from lightllm.utils.device_utils import is_tesla
+from lightllm.utils.device_utils import is_metax, is_tesla
 
 
 @triton.jit
@@ -123,7 +123,9 @@ def _fwd_kernel(
 def context_attention_fwd(
     q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs
 ):
-    BLOCK_M = 128 if not is_tesla() else 64
+    BLOCK_M = 128
+    if is_tesla() or is_metax():
+        BLOCK_M = 64
     # shape constraints
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py
index a1ed6ed95..719c8017a 100644
--- a/lightllm/utils/device_utils.py
+++ b/lightllm/utils/device_utils.py
@@ -31,6 +31,11 @@ def is_4090():
     return "4090" in torch.cuda.get_device_name(0) or "RTX 4090" in torch.cuda.get_device_name(0)
 
 
+@lru_cache(maxsize=None)
+def is_metax():
+    return torch.cuda.is_available() and "MetaX" in torch.cuda.get_device_name(0)
+
+
 @lru_cache(maxsize=None)
 def get_device_sm_count():
     import triton
diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py
index 65ac401d4..ec78e9f18 100644
--- a/lightllm/utils/dist_utils.py
+++ b/lightllm/utils/dist_utils.py
@@ -1,3 +1,4 @@
+from lightllm.utils.device_utils import is_metax
 import torch.distributed as dist
 import os
 import torch
@@ -65,6 +66,11 @@ def init_vision_distributed_env(kvargs):
     device_id = visual_gpu_ids[kvargs["vit_rank_id"]]
     set_current_device_id(device_id)
     torch.cuda.set_device(device_id)
+
+    # The process group can't be initialized on a device twice, so skip initializing it in the vision env.
+    if is_metax():
+        return
+
     dist.init_process_group(
         "nccl",
         init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}',
@@ -99,8 +105,12 @@ def init_distributed_env(kvargs):
     device_id = kvargs["rank_id"] % get_node_world_size()
     set_current_device_id(device_id)
     torch.cuda.set_device(device_id)
+    backend = "nccl"
+    # Works around an NCCL internal error when using 8 or 16 GPUs.
+    if is_metax():
+        backend = "cpu:gloo,cuda:nccl"
     dist.init_process_group(
-        "nccl",
+        backend=backend,
         init_method=f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}',
         rank=kvargs["rank_id"],
         world_size=kvargs["world_size"],
diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py
index 3256fdd1f..6377e29e3 100644
--- a/lightllm/utils/kv_cache_utils.py
+++ b/lightllm/utils/kv_cache_utils.py
@@ -1,3 +1,4 @@
+from lightllm.utils.device_utils import is_metax
 import torch
 import ctypes
 import dataclasses
@@ -265,7 +266,33 @@ def _worker():
             assert host_ptr.value == device_ptr.value
         handle.tasks_finished.set()
 
-    th = threading.Thread(target=_worker, name=f"cpu_cache_register_{shm_ptr}", daemon=True)
+    def _metax_worker():
+        mc = ctypes.CDLL(os.path.join(os.getenv("MACA_PATH", "/opt/maca"), "lib/libmcruntime.so"))
+        mc.mcHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint]
+        mc.mcHostRegister.restype = ctypes.c_int
+        mc.mcHostGetDevicePointer.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_int]
+        mc.mcHostGetDevicePointer.restype = ctypes.c_int
+
+        cudaHostRegisterFlag = 3
+
+        torch.cuda.set_device(get_current_device_id())
+        # TODO: verify that registering the shared memory in segments here is valid and reasonable.
+        for offset, seg_len in tasks:
+            ptr = ctypes.c_void_p(shm_ptr + offset)
+            r = mc.mcHostRegister(ptr, ctypes.c_size_t(seg_len), cudaHostRegisterFlag)
+            if r != 0:
+                raise Exception(f"cudaHostRegister failed with error code {r}, consider using hugetlb")
+            handle.task_count += 1
+
+        device_ptr = ctypes.c_void_p()
+        host_ptr = ctypes.c_void_p(shm_ptr)
+        res = mc.mcHostGetDevicePointer(ctypes.byref(device_ptr), host_ptr, 0)
+        if res != 0:
+            raise Exception(f"cudaHostGetDevicePointer failed with error code {res}")
+        handle.tasks_finished.set()
+
+    _worker_func = _metax_worker if is_metax() else _worker
+    th = threading.Thread(target=_worker_func, name=f"cpu_cache_register_{shm_ptr}", daemon=True)
     handle.thread = th
     th.start()
     return handle
diff --git a/requirements_metax.txt b/requirements_metax.txt
new file mode 100644
index 000000000..55c5e93dc
--- /dev/null
+++ b/requirements_metax.txt
@@ -0,0 +1,8 @@
+rpyc==6.0.2
+setproctitle==1.3.7
+easydict==1.13
+atomics==1.0.3
+sortedcontainers==2.4.0
+librosa==0.11.0
+gunicorn==24.0.0
+ujson==5.11.0
diff --git a/test/benchmark/service/benchmark_client.py b/test/benchmark/service/benchmark_client.py
index 09009fc9e..57b16abcf 100644
--- a/test/benchmark/service/benchmark_client.py
+++ b/test/benchmark/service/benchmark_client.py
@@ -96,7 +96,7 @@ def post_stream_lightllm(url: str, text_input: str, max_new_tokens: int) -> List
 def post_stream_openai(url: str, text_input: str, max_new_tokens: int) -> List[float]:
     data = {
         "model": model_name[0],
-        "prompt": text_input,
+        "messages": [{"role": "user", "content": text_input}],
         "n": 1,
         "ignore_eos": True,
         "max_tokens": max_new_tokens,
@@ -115,7 +115,7 @@ def post_stream_openai(url: str, text_input: str, max_new_tokens: int) -> List[f
            if line == "[DONE]":
                continue
            data = json.loads(line)
-            if not data["choices"][0]["text"]:
+            if not data["choices"][0]["delta"]["content"]:
                continue
            current_time = time.time()
            elapsed_time = current_time - last_time