7 changes: 7 additions & 0 deletions docker/Dockerfile.metax
@@ -0,0 +1,7 @@
# docker pull from https://sw-download.metax-tech.com/docker
FROM cr.metax-tech.com/public-ai-release/maca/vllm:maca.ai3.1.0.7-torch2.6-py310-ubuntu22.04-amd64

ENV MACA_PATH=/opt/maca
ENV PATH=/opt/conda/bin:/opt/conda/condabin:${PATH}
COPY . /lightllm
RUN pip install -r /lightllm/requirements_metax.txt && pip install -e /lightllm --no-cache-dir
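
For local testing, the image would typically be built from the repository root with something like `docker build -f docker/Dockerfile.metax -t lightllm:metax .`; the `lightllm:metax` tag is only an illustrative choice, not part of this change.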
7 changes: 6 additions & 1 deletion lightllm/common/basemodel/attention_vit/xformers/fp.py
@@ -1,3 +1,4 @@
from lightllm.utils.device_utils import is_metax
import torch
import torch.nn.functional as F

@@ -34,7 +35,11 @@ def _vit_att_fwd(
if max_seqlen:
assert max(seqlens) <= max_seqlen

attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens, device=q.device)
# The xformers version on MetaX is 0.0.22 (0.0.32.post1 on NVIDIA), so from_seqlens has no device param
if is_metax():
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
else:
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens, device=q.device)

q_ = q.unsqueeze(0) # [1, T, H, D]
k_ = k.unsqueeze(0) # [1, T, H, D]
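
For context, a minimal sketch of how a block-diagonal mask built this way is consumed, using the public xformers API; the packed `[1, T, H, D]` layout mirrors the surrounding code, and the `is_metax()` guard is assumed importable as above.

```python
# Sketch: the MetaX wheel ships xformers 0.0.22, whose from_seqlens has no
# `device` kwarg, while newer NVIDIA builds accept it.
import xformers.ops as xops
from xformers.ops import fmha
from lightllm.utils.device_utils import is_metax


def packed_vit_attention(q, k, v, seqlens):
    # q, k, v: [T, H, D], several images packed along T; seqlens sums to T.
    if is_metax():
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
    else:
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens, device=q.device)
    out = xops.memory_efficient_attention(
        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_bias=attn_bias
    )  # [1, T, H, D]
    return out.squeeze(0)
```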
@@ -7,7 +7,7 @@
import math
import torch.nn.functional as F

from lightllm.utils.device_utils import is_tesla
from lightllm.utils.device_utils import is_metax, is_tesla


@triton.jit
@@ -123,7 +123,9 @@ def _fwd_kernel(
def context_attention_fwd(
q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs
):
BLOCK_M = 128 if not is_tesla() else 64
BLOCK_M = 128
if is_tesla() or is_metax():
BLOCK_M = 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
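
The smaller tile mostly shows up at launch time; a hedged sketch of how a `BLOCK_M` picked this way typically feeds the grid (the grid layout below is illustrative, not the kernel's actual one):

```python
import triton

from lightllm.utils.device_utils import is_metax, is_tesla


def pick_block_m() -> int:
    # 128 is the default query-dimension tile; Tesla and MetaX parts drop to
    # 64, presumably to stay within shared-memory/occupancy limits.
    return 64 if is_tesla() or is_metax() else 128


def launch_grid(max_input_len: int, batch: int, heads: int):
    BLOCK_M = pick_block_m()
    # one program instance per (query tile, batch * head) -- illustrative
    return (triton.cdiv(max_input_len, BLOCK_M), batch * heads)
```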
5 changes: 5 additions & 0 deletions lightllm/utils/device_utils.py
@@ -31,6 +31,11 @@ def is_4090():
return "4090" in torch.cuda.get_device_name(0) or "RTX 4090" in torch.cuda.get_device_name(0)


@lru_cache(maxsize=None)
def is_metax():
return torch.cuda.is_available() and "MetaX" in torch.cuda.get_device_name(0)


@lru_cache(maxsize=None)
def get_device_sm_count():
import triton
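
A quick way to see what the cached detector reports on a given box (sketch; the exact device-name string produced by the MACA stack is an assumption):

```python
import torch

from lightllm.utils.device_utils import is_metax

if torch.cuda.is_available():
    # MACA devices are expected to expose a name containing "MetaX".
    print(torch.cuda.get_device_name(0))
    print("is_metax:", is_metax())
```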
12 changes: 11 additions & 1 deletion lightllm/utils/dist_utils.py
@@ -1,3 +1,4 @@
from lightllm.utils.device_utils import is_metax
import torch.distributed as dist
import os
import torch
@@ -65,6 +66,11 @@ def init_vision_distributed_env(kvargs):
device_id = visual_gpu_ids[kvargs["vit_rank_id"]]
set_current_device_id(device_id)
torch.cuda.set_device(device_id)

# The process group can't be initialized twice on the same device, so we skip creating it for the vision env on MetaX
if is_metax():
return

dist.init_process_group(
"nccl",
init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}',
@@ -99,8 +105,12 @@ def init_distributed_env(kvargs):
device_id = kvargs["rank_id"] % get_node_world_size()
set_current_device_id(device_id)
torch.cuda.set_device(device_id)
backend = "nccl"
# NCCL alone hits an internal error with 8 or 16 GPUs, so also register gloo for CPU tensors.
if is_metax():
backend = "cpu:gloo,cuda:nccl"
dist.init_process_group(
"nccl",
backend=backend,
init_method=f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}',
rank=kvargs["rank_id"],
world_size=kvargs["world_size"],
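
`"cpu:gloo,cuda:nccl"` is PyTorch's per-device-type backend spec: one process group, with CPU tensors served by gloo and CUDA tensors by NCCL. A self-contained single-rank sketch (address and port are placeholders):

```python
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")  # placeholder port

# One group, two transports: gloo handles CPU tensors, NCCL handles CUDA ones.
dist.init_process_group("cpu:gloo,cuda:nccl", rank=0, world_size=1)

dist.all_reduce(torch.ones(4))  # routed through gloo
if torch.cuda.is_available():
    dist.all_reduce(torch.ones(4, device="cuda"))  # routed through NCCL
dist.destroy_process_group()
```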
29 changes: 28 additions & 1 deletion lightllm/utils/kv_cache_utils.py
@@ -1,3 +1,4 @@
from lightllm.utils.device_utils import is_metax
import torch
import ctypes
import dataclasses
@@ -265,7 +266,33 @@ def _worker():
assert host_ptr.value == device_ptr.value
handle.tasks_finished.set()

th = threading.Thread(target=_worker, name=f"cpu_cache_register_{shm_ptr}", daemon=True)
def _metax_worker():
mc = ctypes.CDLL(os.path.join(os.getenv("MACA_PATH", "/opt/maca"), "lib/libmcruntime.so"))
mc.mcHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint]
mc.mcHostRegister.restype = ctypes.c_int
mc.mcHostGetDevicePointer.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_int]
mc.mcHostGetDevicePointer.restype = ctypes.c_int

cudaHostRegisterFlag = 3  # cudaHostRegisterPortable | cudaHostRegisterMapped

torch.cuda.set_device(get_current_device_id())
# TODO: verify that registering the shared memory in per-segment chunks like this is valid and reasonable.
for offset, seg_len in tasks:
ptr = ctypes.c_void_p(shm_ptr + offset)
r = mc.mcHostRegister(ptr, ctypes.c_size_t(seg_len), cudaHostRegisterFlag)
if r != 0:
raise Exception(f"cudaHostRegister failed with error code {r}; consider using hugetlb instead")
handle.task_count += 1

device_ptr = ctypes.c_void_p()
host_ptr = ctypes.c_void_p(shm_ptr)
res = mc.mcHostGetDevicePointer(ctypes.byref(device_ptr), host_ptr, 0)
if res != 0:
raise Exception(f"cudaHostGetDevicePointer failed with error code {res}")
handle.tasks_finished.set()

_worker_func = _metax_worker if is_metax() else _worker
th = threading.Thread(target=_worker_func, name=f"cpu_cache_register_{shm_ptr}", daemon=True)
handle.thread = th
th.start()
return handle
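
`_metax_worker` mirrors the existing CUDA path through MACA's `libmcruntime` instead of `libcudart`. For comparison, a sketch of the same register-then-map sequence against the standard CUDA runtime signatures (library name may need a version suffix on some systems):

```python
import ctypes

cudart = ctypes.CDLL("libcudart.so")  # may be libcudart.so.12, etc.
cudart.cudaHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint]
cudart.cudaHostRegister.restype = ctypes.c_int
cudart.cudaHostGetDevicePointer.argtypes = [
    ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_uint
]
cudart.cudaHostGetDevicePointer.restype = ctypes.c_int


def register_and_map(shm_ptr: int, length: int) -> int:
    # Pin an existing host allocation and fetch its device-visible address.
    flags = 3  # cudaHostRegisterPortable (1) | cudaHostRegisterMapped (2)
    assert cudart.cudaHostRegister(ctypes.c_void_p(shm_ptr), length, flags) == 0
    dev_ptr = ctypes.c_void_p()
    assert cudart.cudaHostGetDevicePointer(ctypes.byref(dev_ptr), ctypes.c_void_p(shm_ptr), 0) == 0
    return dev_ptr.value
```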
8 changes: 8 additions & 0 deletions requirements_metax.txt
@@ -0,0 +1,8 @@
rpyc==6.0.2
setproctitle==1.3.7
easydict==1.13
atomics==1.0.3
sortedcontainers==2.4.0
librosa==0.11.0
gunicorn==24.0.0
ujson==5.11.0
4 changes: 2 additions & 2 deletions test/benchmark/service/benchmark_client.py
@@ -96,7 +96,7 @@ def post_stream_lightllm(url: str, text_input: str, max_new_tokens: int) -> List[float]
def post_stream_openai(url: str, text_input: str, max_new_tokens: int) -> List[float]:
data = {
"model": model_name[0],
"prompt": text_input,
"messages": [{"role": "user", "content": text_input}],
"n": 1,
"ignore_eos": True,
"max_tokens": max_new_tokens,
@@ -115,7 +115,7 @@ def post_stream_openai(url: str, text_input: str, max_new_tokens: int) -> List[f
if line == "[DONE]":
continue
data = json.loads(line)
if not data["choices"][0]["text"]:
if not data["choices"][0]["delta"]["content"]:
continue
current_time = time.time()
elapsed_time = current_time - last_time
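
The switch from `prompt`/`choices[0].text` to `messages`/`choices[0].delta.content` moves the benchmark from the completions API to the chat-completions API, whose streamed chunks carry incremental `delta` objects. A hedged sketch of that parsing shape (URL and model name are placeholders; `.get` is used here to tolerate role-only chunks):

```python
import json
import time

import requests


def stream_chat_latencies(url, model, text_input, max_new_tokens):
    data = {
        "model": model,
        "messages": [{"role": "user", "content": text_input}],
        "max_tokens": max_new_tokens,
        "stream": True,
    }
    latencies, last_time = [], time.time()
    with requests.post(url, json=data, stream=True) as resp:
        for raw in resp.iter_lines():
            if not raw:
                continue
            line = raw.decode("utf-8").removeprefix("data: ").strip()
            if line == "[DONE]":
                continue
            chunk = json.loads(line)
            # Chat streaming delivers new text under delta.content; skip
            # role-only or empty chunks.
            if not chunk["choices"][0]["delta"].get("content"):
                continue
            now = time.time()
            latencies.append(now - last_time)
            last_time = now
    return latencies
```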