diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index d6abd0948a..e789039efd 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -86,8 +86,8 @@ jobs: with: images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} - # Build and push default image (cuda12.8.0) - - name: Build and push Docker image (default cuda12.8.0) + # Build and push default image (cuda13.0.0) + - name: Build and push Docker image (default cuda13.0.0) id: build-and-push uses: docker/build-push-action@ac9327eae2b366085ac7f6a2d02df8aa8ead720a with: @@ -97,10 +97,11 @@ jobs: tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} build-args: | - CUDA_VERSION=12.8.0 + CUDA_VERSION=13.0.0 ENABLE_DEEPEP=1 ENABLE_NIXL=1 ENABLE_CACHE=1 + ENABLE_SM100=0 cache-from: type=gha cache-to: type=gha,mode=max @@ -117,4 +118,4 @@ jobs: DIGEST: ${{ steps.build-and-push.outputs.digest }} # This step uses the identity token to provision an ephemeral certificate # against the sigstore community Fulcio instance. 
- run: echo "${TAGS}" | xargs -I {} cosign sign --yes {}@${DIGEST} \ No newline at end of file + run: echo "${TAGS}" | xargs -I {} cosign sign --yes {}@${DIGEST} diff --git a/.gitignore b/.gitignore index 63408699f4..9b69e2eb4c 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist .vscode tmp/ requirements-musa.txt +logs/ \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index 439ecddb34..313c4c72a5 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,14 +1,17 @@ -ARG CUDA_VERSION=12.8.0 +ARG CUDA_VERSION=13.0.0 FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 ARG PYTHON_VERSION=3.10 ARG MAMBA_VERSION=24.7.1-0 -ARG VLLM_VERSION=0.16.0 +ARG VLLM_VERSION=0.21.0 +ARG NIXL_REF=v1.2.0 ARG FLASH_MLA_REF=47c35a7 +ARG DEEPGEMM_REF=891d57b4db1071624b5c8fa0d1e51cb317fa709f ARG TARGETPLATFORM ARG ENABLE_DEEPEP=1 ARG ENABLE_NIXL=1 ARG ENABLE_CACHE=1 +ARG ENABLE_SM100=0 ENV PATH=/opt/conda/bin:$PATH \ CONDA_PREFIX=/opt/conda @@ -44,13 +47,20 @@ WORKDIR /root COPY ./requirements.txt /lightllm/requirements.txt RUN pip install -U pip -RUN pip install -r /lightllm/requirements.txt --no-cache-dir -RUN pip install --no-cache-dir vllm==${VLLM_VERSION} -RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \ +RUN pip install --no-cache-dir \ + -i https://pypi.org/simple \ + --extra-index-url https://download.pytorch.org/whl/cu130 \ + vllm==${VLLM_VERSION} +RUN pip install -r /lightllm/requirements.txt --no-cache-dir \ + -i https://pypi.org/simple \ + --extra-index-url https://download.pytorch.org/whl/cu130 +RUN export CPATH=/usr/local/cuda/targets/x86_64-linux/include/cccl:/usr/local/cuda/targets/x86_64-linux/include${CPATH:+:${CPATH}} && \ + git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \ cd /root/FlashMLA && \ git checkout ${FLASH_MLA_REF} && \ git submodule update --init --recursive && \ - FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir . 
+ FLASH_MLA_DISABLE_SM100="$(if [ "${ENABLE_SM100}" = "1" ]; then echo 0; else echo 1; fi)" \ + pip install --no-cache-dir . RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* @@ -78,27 +88,20 @@ RUN if [ "${ENABLE_NIXL}" = "1" ] || [ "${ENABLE_DEEPEP}" = "1" ]; then \ RUN if [ "${ENABLE_DEEPEP}" = "1" ]; then \ set -e; \ ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so; \ - NVSHMEM_VERSION=3.3.9; \ - CUDA_ARCHS=90; \ - wget https://developer.download.nvidia.com/compute/redist/nvshmem/${NVSHMEM_VERSION}/source/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \ - && tar -xf nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz && mv nvshmem_src nvshmem \ - && cd nvshmem \ - && rm -f /root/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \ - && NVSHMEM_SHMEM_SUPPORT=0 \ - NVSHMEM_UCX_SUPPORT=0 \ - NVSHMEM_USE_NCCL=0 \ - NVSHMEM_MPI_SUPPORT=0 \ - NVSHMEM_IBGDA_SUPPORT=1 \ - NVSHMEM_PMIX_SUPPORT=0 \ - NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ - NVSHMEM_USE_GDRCOPY=1 \ - cmake -S . 
-B build/ -DCMAKE_INSTALL_PREFIX=/root/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHS} \ - && cmake --build build --target install -j64; \ - DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58; \ - cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd ..; \ - cd /root/DeepEP && NVSHMEM_DIR=/root/nvshmem/install python setup.py install; \ + python -m pip install --upgrade --no-deps \ + "nvidia-nccl-cu13==2.30.4" \ + "nvidia-nvshmem-cu13==3.6.5"; \ + cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout b306af06afd412c88e51e71802951606e40b7358; \ + ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so.3 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so; \ + ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so.2 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so; \ + pip install --no-build-isolation .; \ fi +RUN cd /root && git clone https://github.com/deepseek-ai/DeepGEMM.git && \ + cd DeepGEMM && git checkout ${DEEPGEMM_REF} && \ + git submodule update --init --recursive && \ + pip install --no-build-isolation . 
+ RUN if [ "${ENABLE_NIXL}" = "1" ]; then \ apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \ DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \ @@ -126,7 +129,7 @@ RUN if [ "${ENABLE_NIXL}" = "1" ]; then \ apt-get update && apt-get install -y pkg-config tmux net-tools && \ cd /usr/local/src; \ pip install --upgrade meson pybind11 patchelf; \ - git clone https://github.com/ai-dynamo/nixl.git -b main && \ + git clone https://github.com/ai-dynamo/nixl.git -b ${NIXL_REF} && \ cd nixl && \ rm -rf build && \ mkdir build && \ diff --git a/docker/scripts/build.sh b/docker/scripts/build.sh index 355d6c65b3..bc1fd73da3 100644 --- a/docker/scripts/build.sh +++ b/docker/scripts/build.sh @@ -18,21 +18,23 @@ set -euo pipefail # --no-nixl Disable NIXL (default: enabled) # --no-cache Disable cache (default: enabled) # --lite Disable DEEPEP, NIXL and cache in one shot -# --cuda-version CUDA version (default: 12.8.0) +# --cuda-version CUDA version (default: 13.0.0) # --image-prefix Image prefix (default: lightllm) # --image-tag Image tag (default: generated from enabled features) +# --enable-sm100 Enable SM100 support (default: disabled) # -h / --help Show help ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." 
&& pwd)" cd "${ROOT_DIR}" IMAGE_PREFIX="${IMAGE_PREFIX:-lightllm}" -CUDA_VERSION="${CUDA_VERSION:-12.8.0}" +CUDA_VERSION="${CUDA_VERSION:-13.0.0}" IMAGE_TAG="${IMAGE_TAG:-}" ENABLE_DEEPEP="${ENABLE_DEEPEP:-1}" ENABLE_NIXL="${ENABLE_NIXL:-1}" ENABLE_CACHE="${ENABLE_CACHE:-1}" +ENABLE_SM100="${ENABLE_SM100:-0}" print_help() { sed -n '1,80p' "$0" | sed 's/^# \{0,1\}//' @@ -43,6 +45,7 @@ while [[ $# -gt 0 ]]; do --no-deepep) ENABLE_DEEPEP=0 ;; --no-nixl) ENABLE_NIXL=0 ;; --no-cache) ENABLE_CACHE=0 ;; + --enable-sm100) ENABLE_SM100=1 ;; --lite) ENABLE_DEEPEP=0 ENABLE_NIXL=0 @@ -78,13 +81,16 @@ done # - Other combos: composed from enabled feature names if [[ -z "${IMAGE_TAG}" ]]; then tag_parts=() + if [[ "${ENABLE_SM100}" -eq 1 ]]; then + tag_parts+=("sm100") + fi if [[ "${ENABLE_NIXL}" -eq 1 ]]; then tag_parts+=("nixl") fi if [[ "${ENABLE_DEEPEP}" -eq 1 ]]; then tag_parts+=("deepep") fi - if [[ "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then + if [[ "${ENABLE_SM100}" -eq 0 && "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then IMAGE_TAG="cuda${CUDA_VERSION}" else prefix="" @@ -100,6 +106,6 @@ DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile \ --build-arg ENABLE_DEEPEP="${ENABLE_DEEPEP}" \ --build-arg ENABLE_NIXL="${ENABLE_NIXL}" \ --build-arg ENABLE_CACHE="${ENABLE_CACHE}" \ + --build-arg ENABLE_SM100="${ENABLE_SM100}" \ --progress=plain \ -t "${IMAGE_PREFIX}:${IMAGE_TAG}" . 
- diff --git a/docs/CN/source/cookbook/qwen35_deployment.rst b/docs/CN/source/cookbook/qwen35_deployment.rst index 4cb6bf93e4..5e25ef9d9c 100644 --- a/docs/CN/source/cookbook/qwen35_deployment.rst +++ b/docs/CN/source/cookbook/qwen35_deployment.rst @@ -74,6 +74,17 @@ Qwen3.5-397B-A17B(8×H200) - ``--graph_max_batch_size 128``: CUDA graph 最大批处理大小(显存不足时可减小) - ``--reasoning_parser qwen3``: 启用 Qwen3 推理解析器,支持思考模式 +线性注意力缓存调参说明 +~~~~~~~~~~~~~~~~~~~~~~ + +Qwen3.5 使用混合注意力架构,在涉及线性注意力缓存复用时,建议关注以下参数: + +- ``--linear_att_hash_page_size``: 小块粒度(每个 hash bucket 的 token 数) +- ``--linear_att_page_block_num``: 块级匹配相关配置。可将块大小近似理解为 ``linear_att_page_block_num * linear_att_hash_page_size``。 +- 当 ``linear_att_page_block_num * linear_att_hash_page_size > max_req_total_len`` 时,radix cache 的块级匹配能力会近似关闭,更多依赖请求级小块匹配(小块大小为 ``linear_att_hash_page_size``)。 +- 在高负载下,小块数量不足叠加内部 LRU 淘汰,可能导致命中率下降。此时可调大 ``--linear_att_cache_size`` 提升命中率,但会增加内存占用。 +- 开启 ``--enable_cpu_cache`` 时,CPU cache 的 page 大小会被强制设置为 ``linear_att_page_block_num * linear_att_hash_page_size``,以满足内部复用约束。 + 纯文本模式(节省显存) ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/CN/source/tutorial/api_server_args.rst b/docs/CN/source/tutorial/api_server_args.rst index ebe7b3ff89..8e7f9d78e8 100644 --- a/docs/CN/source/tutorial/api_server_args.rst +++ b/docs/CN/source/tutorial/api_server_args.rst @@ -18,6 +18,16 @@ APIServer 参数详解 * ``pd_master``: pd 主节点模式(用于 pd 分离运行模式) * ``config_server``: 配置服务器模式(用于 pd 分离模式,用于注册 pd_master 节点并获取 pd_master 节点列表),专门为大规模、高并发场景设计,当 `pd_master` 遇到显著的 CPU 瓶颈时使用。 +.. option:: --performance_mode, --p_mode + + 不同场景的性能模式,可选值: + + * ``None``: 不应用性能模式(默认) + * ``personal``: 私有化个人运行模式,自动设置: + - ``running_max_req_size`` 为 3 + - ``batch_max_tokens`` 为 2048 (2k) + - ``chunked_prefill_size`` 为 1024 (1k) + .. option:: --host 服务器监听地址,默认为 ``127.0.0.1`` @@ -122,7 +132,10 @@ PD 分离模式参数 .. 
option:: --max_req_total_len - 请求输入长度 + 请求输出长度的最大值,默认为 ``16384`` + 请求输入长度 + 请求输出长度的最大值。若未显式设置,将从模型配置自动推导, + 若推导失败则回退到 ``16384``。 + 对于部分 RoPE 类型(如 ``yarn/dynamic/su/llama3``),推导不会直接用 ``rope_scaling.factor`` + 去乘以 ``max_position_embeddings``,以避免过度估算最大长度。 .. option:: --eos_id @@ -201,6 +214,16 @@ PD 分离模式参数 激进调度可能导致解码期间频繁的预填充中断。禁用它可以让 router_max_wait_tokens 参数更有效地工作。 +.. option:: --enable_prefill_decode_mixed + + 在同一次推理调度步骤中混合执行 prefill 与 decode。 + + 仅支持 ``--run_mode`` 为 ``normal`` 时开启。当同时存在 prefill 与 decode 请求时,调度器会在同一步内 + 先执行 prefill、再执行 decode,而不是在激进调度下只执行 prefill、阻塞 decode,从而在有新 prefill + 请求时也能推进 decode,提升整体吞吐。 + + 不能与 ``--enable_prefill_microbatch_overlap`` 或 ``--enable_decode_microbatch_overlap`` 同时使用。 + .. option:: --disable_dynamic_prompt_cache 禁用kv cache 缓存 @@ -259,6 +282,18 @@ PD 分离模式参数 多模态资源的缓存服务器容量,默认为 ``200`` +.. option:: --max_image_token_count + + 单张图片在转换为 token 后允许的最大 token 数量,默认为 ``6128`` + + 当任意图片超过该阈值时,请求会被拒绝。 + +.. option:: --max_image_pixels + + 单张图片在预处理缩放前允许的最大像素数量,默认为 ``8294400``(约等于 4K 图片像素总量)。 + + 当输入图片超过该阈值时,LightLLM 会先自动将其缩放到该像素预算内,再继续后续流程。 + .. option:: --visual_infer_batch_size 每次推理批次中处理的图像数量,默认为 ``1`` @@ -293,13 +328,13 @@ PD 分离模式参数 性能优化参数 ------------ -.. option:: --disable_custom_allreduce +.. option:: --disable_symm_mem_allreduce - 是否禁用自定义 allreduce + 禁用默认开启的 SymmMem all-reduce 快路径,并回退到 NCCL -.. option:: --enable_custom_allgather +.. option:: --disable_flashinfer_allreduce - 是否启用自定义 allgather + 禁用默认开启的 FlashInfer all-reduce 快路径,并回退到 SymmMem / NCCL .. option:: --enable_tpsp_mix_mode @@ -342,6 +377,41 @@ PD 分离模式参数 - ``fp8kv_sph``: FP8 静态按 head 量化,对应 fa3 后端 - ``fp8kv_spt``: FP8 静态按 tensor 量化,对应 flashinfer 后端 +.. option:: --linear_att_hash_page_size + + 线性注意力的哈希页大小,默认为 ``512``。 + + 该参数控制每个哈希桶中的 token 数量,会影响 radix cache 的复用效果。 + +.. 
option:: --linear_att_page_block_num + + 线性注意力状态存储使用的块数量,默认为 ``10000000``。 + + 该参数控制用于保存注意力状态的可用页数,会影响内存占用和多轮对话性能。 + 在当前实现中,可将块大小近似理解为 + ``linear_att_page_block_num * linear_att_hash_page_size``。 + 当 ``linear_att_page_block_num * linear_att_hash_page_size > max_req_total_len`` 时, + radix cache 的块级匹配能力会近似被关闭,此时更依赖请求级别的小块匹配(小块大小为 ``linear_att_hash_page_size``)。 + 如果负载较高,小块数量不足叠加内部 LRU 淘汰机制,可能导致 cache 命中率下降。 + + 当开启 ``--enable_cpu_cache`` 时,cpu cache 的 page 大小会被强制设置为 + ``linear_att_page_block_num * linear_att_hash_page_size``,以满足内部复用约束。 + +.. option:: --linear_att_cache_size + + 线性注意力缓存大小。 + + 不指定时会根据缓存相关配置自动计算。 + 当高负载下出现小块缓存命中不足(例如受小块数量和 LRU 淘汰影响)时, + 可以调大该参数以提升命中率,但会增加内存占用。 + +.. option:: --linear_att_ssm_data_type + + 线性注意力 SSM 状态的数据类型,可选值: + + * ``bfloat16`` + * ``float32``(默认) + .. option:: --disable_cudagraph 禁用解码阶段的 cudagraph @@ -394,6 +464,14 @@ PD 分离模式参数 示例可以在 test/advanced_config/mixed_quantization/llamacls-mix-down.yaml 中找到。 +.. option:: --expert_dtype + + EP MoE 专家量化类型,可选值: + + * ``fp8`` + * ``fp4``,仅支持 SM100 GPU + * ``None`` (默认) + .. option:: --vit_quant_type ViT 量化方法,可选值: @@ -426,14 +504,6 @@ PD 分离模式参数 使用奖励模型 -.. option:: --long_truncation_mode - - 当 input_token_len + max_new_tokens > max_req_total_len 时的处理方式,可选值: - - * ``None``: 抛出异常(默认) - * ``head``: 移除一些头部 token 使 input_token_len + max_new_tokens <= max_req_total_len - * ``center``: 移除中心位置的一些 token 使 input_token_len + max_new_tokens <= max_req_total_len - .. option:: --use_tgi_api 使用 tgi 输入和输出格式 @@ -509,4 +579,4 @@ DeepSeek 冗余专家参数 .. 
option:: --enable_monitor_auth - 是否为 push_gateway 开启身份验证 \ No newline at end of file + 是否为 push_gateway 开启身份验证 diff --git a/docs/CN/source/tutorial/deepseek_deployment.rst b/docs/CN/source/tutorial/deepseek_deployment.rst index de7ecc84c3..00cf12da0c 100644 --- a/docs/CN/source/tutorial/deepseek_deployment.rst +++ b/docs/CN/source/tutorial/deepseek_deployment.rst @@ -175,6 +175,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 # PD prefill 模式 for DeepSeek-R1 (DP+EP) on H200 # 使用方法: sh pd_prefill.sh + # 默认使用 NIXL 传输;如需使用 NCCL 数据面,可设置 LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl # nvidia-cuda-mps-control -d,运行MPS(可选, 有mps支持性能会好特别多,但是部分显卡和驱动环境开启mps会容易出现错误,建议升级驱动到较高版本,特别是H系列卡) export host=$1 @@ -201,6 +202,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 # PD decode 模式 for DeepSeek-R1 (DP+EP) on H200 # 使用方法: sh pd_decode.sh + # 默认使用 NIXL 传输;如需使用 NCCL 数据面,可设置 LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d @@ -336,4 +338,4 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --tokenizer_path /path/DeepSeek-R1/ \ --url http://127.0.0.1:8088/generate_stream -以上所有脚本可以参考 `test/start_scripts/multi_pd_master/` 目录下的脚本。 \ No newline at end of file +以上所有脚本可以参考 `test/start_scripts/multi_pd_master/` 目录下的脚本。 diff --git a/docs/EN/source/cookbook/qwen35_deployment.rst b/docs/EN/source/cookbook/qwen35_deployment.rst index 6b3b56252d..36e1288eed 100644 --- a/docs/EN/source/cookbook/qwen35_deployment.rst +++ b/docs/EN/source/cookbook/qwen35_deployment.rst @@ -74,6 +74,17 @@ Deploy the full multimodal MoE model on 8 GPUs: - ``--graph_max_batch_size 128``: Maximum batch size for CUDA graph optimization (reduce if OOM) - ``--reasoning_parser qwen3``: Enable Qwen3 reasoning parser for thinking mode +Linear-Attention Cache Tuning Notes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Qwen3.5 uses a hybrid attention architecture. 
For linear-attention cache reuse, pay attention to: + +- ``--linear_att_hash_page_size``: small-page granularity (tokens per hash bucket) +- ``--linear_att_page_block_num``: block-level matching related setting. Block size can be approximated as ``linear_att_page_block_num * linear_att_hash_page_size``. +- When ``linear_att_page_block_num * linear_att_hash_page_size > max_req_total_len``, block-level matching in radix cache is effectively disabled, and request-level small-page matching (small page size is ``linear_att_hash_page_size``) becomes dominant. +- Under high load, limited small-page capacity plus internal LRU eviction can reduce hit rate. In this case, increasing ``--linear_att_cache_size`` can improve hit rate, at the cost of more memory usage. +- When ``--enable_cpu_cache`` is enabled, CPU cache page size is forced to ``linear_att_page_block_num * linear_att_hash_page_size`` to satisfy internal reuse constraints. + Text-only Mode (Save Memory) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/EN/source/index.rst b/docs/EN/source/index.rst index 808f432892..d4a36385cf 100755 --- a/docs/EN/source/index.rst +++ b/docs/EN/source/index.rst @@ -53,6 +53,7 @@ Documentation List Multimodal Deployment Reward Model Deployment OpenAI api Usage + Anthropic Messages API Function Calling Reasoning Parser APIServer Parameters diff --git a/docs/EN/source/tutorial/anthropic.rst b/docs/EN/source/tutorial/anthropic.rst new file mode 100644 index 0000000000..f9f8146ab9 --- /dev/null +++ b/docs/EN/source/tutorial/anthropic.rst @@ -0,0 +1,73 @@ +.. _anthropic_api: + +Anthropic Messages API (Experimental) +===================================== + +LightLLM can expose a ``/v1/messages`` endpoint that speaks the Anthropic +Messages API wire protocol. This is useful if you have client code written +against the Anthropic Python/TypeScript SDK and want to point it at a locally +hosted open-source model without rewriting the client. 
+ +Enabling +-------- + +The ``/v1/messages`` endpoint is always exposed; no extra flag is needed: + +.. code-block:: bash + + python -m lightllm.server.api_server \ + --model_dir /path/to/model \ + --port 8088 + +Using it from the Anthropic SDK +------------------------------- + +.. code-block:: python + + import anthropic + + client = anthropic.Anthropic( + base_url="http://localhost:8088", + api_key="dummy", + ) + resp = client.messages.create( + model="any-name", # echoed back; LightLLM serves the loaded model + max_tokens=1024, + messages=[{"role": "user", "content": "hello"}], + ) + print(resp.content[0].text) + +Streaming works the same way the Anthropic SDK expects: + +.. code-block:: python + + with client.messages.stream( + model="any-name", + max_tokens=256, + messages=[{"role": "user", "content": "Count from 1 to 5."}], + ) as stream: + for text in stream.text_stream: + print(text, end="", flush=True) + +Supported features +------------------ + +- Text generation (streaming and non-streaming) +- System prompts +- Tool use / function calling +- Multi-turn conversations +- Vision (image inputs) via Anthropic content blocks + +Known limitations +----------------- + +- Prompt caching (``cache_control``) is accepted but ignored; ``cache_*`` + fields in ``usage`` are always zero. +- Extended thinking (``thinking`` parameter) is not supported. +- The Batch API (``/v1/messages/batches``) and Files API are not implemented. +- Model name is accepted but ignored; LightLLM always serves the model + loaded via ``--model_dir`` and echoes the requested name back in the response. +- On the streaming path, ``message_start.message.usage.input_tokens`` is + always ``0`` because the upstream usage chunk arrives after all content + chunks. Clients that need an accurate prompt-token count should read + ``message_delta.usage`` at the end of the stream. 
diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst index c31bba4903..84785de3b7 100644 --- a/docs/EN/source/tutorial/api_server_args.rst +++ b/docs/EN/source/tutorial/api_server_args.rst @@ -18,6 +18,16 @@ Basic Configuration Parameters * ``pd_master``: pd master node mode (for pd disaggregation running mode) * ``config_server``: Configuration server mode (for pd disaggregation mode, used to register pd_master nodes and get pd_master node list), specifically designed for large-scale, high-concurrency scenarios, used when `pd_master` encounters significant CPU bottlenecks. +.. option:: --performance_mode, --p_mode + + Performance mode for different scenarios, optional values: + + * ``None``: No performance mode applied (default) + * ``personal``: Private personal running mode, automatically sets: + - ``running_max_req_size`` to 3 + - ``batch_max_tokens`` to 2048 (2k) + - ``chunked_prefill_size`` to 1024 (1k) + .. option:: --host Server listening address, default is ``127.0.0.1`` @@ -122,7 +132,10 @@ Memory and Batch Processing Parameters .. option:: --max_req_total_len - Maximum value of request input length + request output length, default is ``16384`` + Maximum value of request input length + request output length. If not set, it will be + automatically derived from model config.json and fall back to ``16384`` if derivation fails. + For some RoPE types (like ``yarn/dynamic/su/llama3``), the derivation does not multiply + ``rope_scaling.factor`` by ``max_position_embeddings`` to avoid over-estimating the max length. .. option:: --eos_id @@ -200,6 +213,17 @@ Scheduling Parameters Aggressive scheduling may cause frequent prefill interruptions during decoding. Disabling it can make the router_max_wait_tokens parameter work more effectively. +.. option:: --enable_prefill_decode_mixed + + Enable mixed prefill and decode scheduling in the same inference step. + + Only supported when ``--run_mode`` is ``normal``. 
When both prefill and decode requests are pending, + the scheduler runs prefill first and then decode in one scheduling step, instead of running only + prefill under aggressive scheduling. This improves decode throughput when new prefill requests arrive. + + Cannot be used together with ``--enable_prefill_microbatch_overlap`` or + ``--enable_decode_microbatch_overlap``. + .. option:: --disable_dynamic_prompt_cache Disable kv cache caching @@ -257,6 +281,18 @@ Multimodal Parameters Cache server capacity for multimodal resources, default is ``200`` +.. option:: --max_image_token_count + + Maximum allowed token count for a single image after tokenization, default is ``6128`` + + Requests are rejected when any image exceeds this limit. + +.. option:: --max_image_pixels + + Maximum allowed pixel count for a single image before preprocessing resize, default is ``8294400`` (about 4K image pixels). + + If an input image exceeds this threshold, LightLLM automatically resizes it down to this pixel budget before continuing. + .. option:: --visual_infer_batch_size Number of images processed in each inference batch, default is ``1`` @@ -291,13 +327,13 @@ Multimodal Parameters Performance Optimization Parameters ----------------------------------- -.. option:: --disable_custom_allreduce +.. option:: --disable_symm_mem_allreduce - Whether to disable custom allreduce + Disable the default SymmMem all-reduce fast path and fall back to NCCL -.. option:: --enable_custom_allgather +.. option:: --disable_flashinfer_allreduce - Whether to enable custom allgather + Disable the default FlashInfer all-reduce fast path and fall back to SymmMem / NCCL .. option:: --enable_tpsp_mix_mode @@ -343,6 +379,42 @@ Performance Optimization Parameters * ``fp8kv_sph``: FP8 static per-head quantization, uses fa3 backend * ``fp8kv_spt``: FP8 static per-tensor quantization, uses flashinfer backend +.. option:: --linear_att_hash_page_size + + Hash page size for linear attention, default is ``512``. 
+ + This controls the number of tokens per hash bucket, which can affect radix cache reuse. + +.. option:: --linear_att_page_block_num + + Number of blocks used for linear-attention state storage, default is ``10000000``. + + This controls the available pages for attention state data, which can affect memory usage and multi-turn chat performance. + In current behavior, block size can be approximated as + ``linear_att_page_block_num * linear_att_hash_page_size``. + When ``linear_att_page_block_num * linear_att_hash_page_size > max_req_total_len``, + block-level matching in radix cache is effectively disabled, and request-level small-page matching + (small page size is ``linear_att_hash_page_size``) becomes dominant. + Under high load, limited small-page capacity plus internal LRU eviction can reduce cache hit rate. + + When ``--enable_cpu_cache`` is enabled, CPU cache page size is forced to + ``linear_att_page_block_num * linear_att_hash_page_size`` to satisfy internal reuse constraints. + +.. option:: --linear_att_cache_size + + Size of linear-attention cache. + + If not specified, it will be automatically derived from cache-related settings. + If small-page cache hits are poor under high load (for example, due to limited small-page count and LRU eviction), + increasing this value can improve cache hit rate, at the cost of more memory usage. + +.. option:: --linear_att_ssm_data_type + + Data type of linear-attention SSM state, optional values: + + * ``bfloat16`` + * ``float32`` (default) + .. option:: --disable_cudagraph Disable cudagraph in the decoding phase @@ -393,6 +465,14 @@ Quantization Parameters Examples can be found in test/advanced_config/mixed_quantization/llamacls-mix-down.yaml. +.. option:: --expert_dtype + + Expert quantization dtype for EP MoE, optional values: + + * ``fp8`` + * ``fp4``: SM100 GPUs only + * ``None`` (default) + .. 
option:: --vit_quant_type ViT quantization method, optional values: @@ -425,14 +505,6 @@ Sampling and Generation Parameters Use reward model -.. option:: --long_truncation_mode - - How to handle when input_token_len + max_new_tokens > max_req_total_len, optional values: - - * ``None``: Throw exception (default) - * ``head``: Remove some head tokens to make input_token_len + max_new_tokens <= max_req_total_len - * ``center``: Remove some tokens at the center position to make input_token_len + max_new_tokens <= max_req_total_len - .. option:: --use_tgi_api Use tgi input and output format @@ -508,4 +580,4 @@ Monitoring and Logging Parameters .. option:: --enable_monitor_auth - Whether to enable authentication for push_gateway \ No newline at end of file + Whether to enable authentication for push_gateway diff --git a/docs/EN/source/tutorial/deepseek_deployment.rst b/docs/EN/source/tutorial/deepseek_deployment.rst index 4c5a121dd6..3a968b8948 100755 --- a/docs/EN/source/tutorial/deepseek_deployment.rst +++ b/docs/EN/source/tutorial/deepseek_deployment.rst @@ -175,6 +175,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for # PD prefill mode for DeepSeek-R1 (DP+EP) on H200 # Usage: sh pd_prefill.sh + # NIXL is used by default. To use NCCL as the data-plane backend, set LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl. # nvidia-cuda-mps-control -d, run MPS (optional, performance will be much better with mps support, but some GPUs may encounter errors when enabling mps, it's recommended to upgrade to a higher driver version, especially for H-series cards) export host=$1 @@ -198,6 +199,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for # PD decode mode for DeepSeek-R1 (DP+EP) on H200 # Usage: sh pd_decode.sh + # NIXL is used by default. To use NCCL as the data-plane backend, set LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl. 
export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d @@ -333,4 +335,4 @@ Supports multiple PD Master nodes, providing better load balancing and high avai --tokenizer_path /path/DeepSeek-R1/ \ --url http://127.0.0.1:8088/generate_stream -All the above scripts can be referenced from the scripts in the `test/start_scripts/multi_pd_master/` directory. \ No newline at end of file +All the above scripts can be referenced from the scripts in the `test/start_scripts/multi_pd_master/` directory. diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 2c4a34d325..594e81a9b4 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -70,11 +70,16 @@ def _auto_select_backend( ) -> type: """Auto-select the best available backend with validation. - Priority: FA3 > FlashInfer > Triton + Priority follows the provided priority_list. Each backend is validated in a subprocess with ground truth checks. 
""" backend_map = kv_type_to_backend + args = get_env_start_args() + if args.enable_ep_moe: + logger.info("Expert parallelism with MoE enabled, excluding flashinfer attention backend") + priority_list = [name for name in priority_list if name != "flashinfer"] + for backend_name in priority_list: if backend_name in backend_map[llm_dtype] and validate(backend_name): logger.info(f"Auto-selected {backend_name} backend (validated)") @@ -95,7 +100,7 @@ def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashi return _auto_select_backend(llm_dtype, kv_type_to_backend=data_type_to_backend, priority_list=priority_list) -def get_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: +def get_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type backend_str = args.llm_decode_att_backend[index] @@ -115,7 +120,7 @@ def get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "fl return _auto_select_backend(llm_dtype, kv_type_to_backend=mla_data_type_to_backend, priority_list=priority_list) -def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: +def get_mla_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type backend_str = args.llm_decode_att_backend[index] diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 91a004ec2e..e0c44e8ed5 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -4,6 +4,7 @@ from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id from ...triton_kernel.repack_kv_index import repack_kv_index from .env_utils import 
set_flashinfer_envs +from .utils import should_init_decode_wrapper class FlashInferAttBackend(BaseAttBackend): @@ -126,6 +127,9 @@ class FlashInferDecodeAttState(BaseDecodeAttState): kv_starts: torch.Tensor = None decode_wrapper: object = None + def _should_init_decode_wrapper(self) -> bool: + return should_init_decode_wrapper(self.backend.model, self.infer_state) + def init_state(self): import flashinfer @@ -156,6 +160,10 @@ def init_state(self): self.kv_indices, ) self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + if not self._should_init_decode_wrapper(): + # 处于 graph replay 回放阶段,不需要特殊初始化 decode wrapper。 + return + assert self.decode_wrapper is None self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( self.backend.workspace_buffer, @@ -182,18 +190,36 @@ def init_state(self): def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): super().copy_for_decode_cuda_graph(new_state) - self.decode_wrapper.plan( - new_state.kv_starts, - new_state.kv_indices, - new_state.kv_last_page_len_buffer, - new_state.backend.tp_q_head_num, - new_state.backend.tp_kv_head_num, - new_state.backend.head_dim, - 1, - q_data_type=new_state.backend.q_data_type, - kv_data_type=new_state.backend.kv_data_type, + self._refresh_cuda_graph_decode_plan(new_state.infer_state.max_kv_seq_len) + return + + def _refresh_cuda_graph_decode_plan(self, max_kv_len: int): + from flashinfer.decode import fast_decode_plan + + uniform_kv_indptr_cpu = ( + torch.arange( + self.infer_state.batch_size + 1, + dtype=torch.int32, + device="cpu", + ) + * max_kv_len + ) + + fast_decode_plan( + self.decode_wrapper, + indptr=self.kv_starts, + indices=self.kv_indices, + last_page_len=self.kv_last_page_len_buffer, + num_qo_heads=self.backend.tp_q_head_num, + num_kv_heads=self.backend.tp_kv_head_num, + head_dim=self.backend.head_dim, + page_size=1, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, non_blocking=True, + 
global_override_indptr_cpu=uniform_kv_indptr_cpu, ) + return def decode_att( self, diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 84b44dc45a..a71dd4d464 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -3,8 +3,10 @@ from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id from ...triton_kernel.repack_kv_index import repack_kv_index +from ...triton_kernel.flashinfer_mla_plan import fill_mla_decode_plan_for_cuda_graph from typing import Tuple from .env_utils import set_flashinfer_envs +from .utils import should_init_decode_wrapper class MlaFlashInferAttBackend(BaseAttBackend): @@ -113,8 +115,12 @@ def _mla_prefill_att( class MlaFlashInferDecodeAttState(BaseDecodeAttState): kv_indices: torch.Tensor = None kv_starts: torch.Tensor = None + q_indptr_host: torch.Tensor = None decode_wrapper: object = None + def _should_init_decode_wrapper(self) -> bool: + return should_init_decode_wrapper(self.backend.model, self.infer_state) + def init_state(self): import flashinfer @@ -126,6 +132,7 @@ def init_state(self): self.kv_starts = self.infer_state.b1_cu_kv_seq_len self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") + self.q_indptr_host = torch.arange(batch_size + 1, dtype=torch.int32, device="cpu") if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch: self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][ : batch_size * self.backend.max_seq_length @@ -145,6 +152,10 @@ def init_state(self): self.infer_state.max_kv_seq_len, self.kv_indices, ) + + if not self._should_init_decode_wrapper(): + return + assert self.decode_wrapper is None self.decode_wrapper = 
flashinfer.mla.BatchMLAPagedAttentionWrapper( @@ -173,19 +184,18 @@ def init_state(self): def copy_for_decode_cuda_graph(self, new_state: "MlaFlashInferDecodeAttState"): super().copy_for_decode_cuda_graph(new_state) - self.decode_wrapper.plan( - new_state.q_indptr, - new_state.kv_starts, - new_state.kv_indices, - new_state.infer_state.b_seq_len, - new_state.backend.tp_q_head_num, - new_state.backend.kv_lora_rank, - new_state.backend.qk_rope_head_dim, - 1, - False, # causal - new_state.backend.softmax_scale, - new_state.backend.q_data_type, - new_state.backend.kv_data_type, + self._refresh_cuda_graph_decode_plan(new_state.infer_state.max_kv_seq_len) + return + + def _refresh_cuda_graph_decode_plan(self, max_kv_len: int): + # Prefer the GPU-generated split plan for long decode; use exact non-split for + # short or unsupported graph shapes. + fill_mla_decode_plan_for_cuda_graph( + self.decode_wrapper, + self.kv_starts, + self.infer_state.batch_size, + self.backend.tp_q_head_num, + max_kv_len, ) def decode_att( diff --git a/lightllm/common/basemodel/attention/flashinfer/utils.py b/lightllm/common/basemodel/attention/flashinfer/utils.py new file mode 100644 index 0000000000..b68542e25f --- /dev/null +++ b/lightllm/common/basemodel/attention/flashinfer/utils.py @@ -0,0 +1,16 @@ +def should_init_decode_wrapper(model, infer_state) -> bool: + graph = getattr(model, "graph", None) + if graph is None: + # Cuda graph is disabled, so this state owns a normal decode wrapper. + return True + + if infer_state.is_cuda_graph: + # This is the captured graph state; it must create the wrapper captured by replay. + return True + + if not graph.can_run(infer_state.batch_size, infer_state.max_kv_seq_len): + # Cuda graph is enabled, but this input falls outside graph limits and runs normally. + return True + + # This is a temporary replay state. Its tensors are copied into the captured graph state. 
+ return False diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index 673b5896d8..c3456f4b7a 100644 --- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -79,17 +79,17 @@ def _nsa_prefill_att( from sgl_kernel.flash_mla import flash_mla_sparse_fwd nsa_dict = att_control.nsa_prefill_dict - topk_indices = nsa_dict["topk_indices"] + topk_mem_indices = nsa_dict["topk_mem_indices"] softmax_scale = nsa_dict["softmax_scale"] kv_lora_rank = nsa_dict["kv_lora_rank"] - if topk_indices.ndim == 2: - topk_indices = topk_indices.unsqueeze(1) + if topk_mem_indices.ndim == 2: + topk_mem_indices = topk_mem_indices.unsqueeze(1) mla_out, _, _ = flash_mla_sparse_fwd( q=q, kv=kv, - indices=topk_indices, + indices=topk_mem_indices, sm_scale=softmax_scale, d_v=kv_lora_rank, ) diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index d29f15ec3b..a1370a7045 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -25,12 +25,12 @@ def prefill_att( att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - assert att_control.use_sliding_window is False and att_control.use_att_sink is False if att_control.use_alibi: + assert att_control.use_sliding_window is False, "alibi + sliding_window not supported" assert att_control.tp_alibi is not None return self._alibi_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) else: - return self._nomarl_prefill_att(q=q, k=k, v=v, alloc_func=alloc_func) + return self._nomarl_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) def _alibi_prefill_att( self, @@ -59,9 +59,21 @@ def _alibi_prefill_att( ) return out - def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, 
alloc_func=torch.empty): + def _nomarl_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd + if att_control.use_sliding_window: + sliding_window = att_control.sliding_window + else: + sliding_window = (-1, -1) + out = alloc_func(q.shape, q.dtype) context_attention_fwd( q, @@ -74,6 +86,7 @@ def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, self.infer_state.b_ready_cache_len, self.infer_state.max_q_seq_len, self.infer_state.req_manager.req_to_token_indexs, + sliding_window=sliding_window, ) return out @@ -94,17 +107,20 @@ def decode_att( att_control: AttControl = AttControl(), alloc_func=torch.empty, ): - assert att_control.use_sliding_window is False and att_control.use_att_sink is False if att_control.use_alibi: + assert att_control.use_sliding_window is False, "alibi + sliding_window not supported" assert att_control.tp_alibi is not None return self._alibi_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) else: q_head_num = q.shape[1] k_head_num = k.shape[1] if q_head_num == k_head_num: + assert att_control.use_sliding_window is False, "sliding_window not supported in non-gqa attention yet" return self._normal_decode_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func) elif q_head_num > k_head_num: - return self._normal_decode_gqa_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func) + return self._normal_decode_gqa_flash_decoding_att( + q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func + ) else: raise NotImplementedError("error") @@ -163,12 +179,18 @@ def _normal_decode_gqa_flash_decoding_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + att_control: AttControl = AttControl(), alloc_func=torch.empty, ): from ...triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding import ( 
gqa_token_decode_attention_flash_decoding, ) + if att_control.use_sliding_window: + sliding_window = att_control.sliding_window + else: + sliding_window = (-1, -1) + out = alloc_func(q.shape, q.dtype) gqa_token_decode_attention_flash_decoding( @@ -178,6 +200,7 @@ def _normal_decode_gqa_flash_decoding_att( cache_v=v, out=out, alloc_tensor_func=alloc_func, + sliding_window=sliding_window, ) return out diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 1d36c72d0b..e83de684a7 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -6,12 +6,12 @@ import json import torch import torch.nn.functional as F +import triton from typing import final, List from tqdm import tqdm from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.basemodel.infer_struct import InferStateInfo -from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.req_manager import ReqManager @@ -22,7 +22,7 @@ from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph from lightllm.common.quantization import Quantcfg -from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token +from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token, gather_token_prefill_decode_mixed from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num @@ -54,9 +54,6 @@ class TpPartBaseModel: # infer state class infer_state_class = InferStateInfo - # radix cache class - radix_cache_class = RadixCache - def __init__(self, kvargs): self.args = 
get_env_start_args() self.run_mode = kvargs["run_mode"] @@ -68,8 +65,6 @@ def __init__(self, kvargs): self.finetune_config = kvargs.get("finetune_config", None) self.max_req_num = kvargs.get("max_req_num", 1000) self.max_seq_length = kvargs.get("max_seq_length", 1024 * 5) - # 用于等待外围的一些模块的初始化完成(如 CPU KV Cache 注册完成) - self.wait_events = kvargs.get("wait_events", []) # is_token_healing 和 return_all_prompt_logics 是有排斥关系的两个模式,只能单独有一个生效 # 主要是在prefill阶段返回多少个token的用于后续处理相关。 self.is_token_healing = kvargs.get("is_token_healing", False) @@ -90,6 +85,7 @@ def __init__(self, kvargs): self.disable_cudagraph = kvargs.get("disable_cudagraph", False) self.quant_type = kvargs.get("quant_type", "none") self.quant_cfg_path = kvargs.get("quant_cfg", None) + self.expert_dtype = kvargs.get("expert_dtype", None) self.mem_fraction = kvargs.get("mem_fraction", 0.9) self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode @@ -108,16 +104,17 @@ def __init__(self, kvargs): self._init_quant() self._init_weights() + self._init_req_manager() self._init_mem_manager() - self._init_kv_move_buffer() + # 因为类似 qwen3.5 的linear 架构的模型,其 req_manager 会存储运行时使用的大量 linear state + # 这可能会占用大量的显存,所以,req_manger 中保存的 mem_manger 是mem manager 初始化后再赋值 + self.req_manager.mem_manager = self.mem_manager + self._check_mem_size() - self._init_req_manager() self._init_infer_layer() self._init_some_value() self._init_custom() self._load_hf_weights() - # wait必须在init cudagraph 之前,避免错误捕获 - self._wait_other_modules_ready() self._init_att_backend() self._init_att_backend1() @@ -137,11 +134,6 @@ def __init__(self, kvargs): set_model_init_status(True) return - def _wait_other_modules_ready(self): - for event in self.wait_events: - event.wait() - return - def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: self.config = json.load(json_file) @@ -164,7 +156,7 @@ def _verify_params(self): return def _init_quant(self): - 
self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path) + self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path, self.expert_dtype) logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}") def _init_weights(self, start_layer_index=0): @@ -204,14 +196,25 @@ def _init_mem_manager(self): ) return - def _init_kv_move_buffer(self): - # p d 分离的推理模式下才需要做这一步初始化 - if self.run_mode in ["prefill", "decode"]: - self.mem_manager.alloc_kv_move_buffer(self.mem_manager.size) - def _check_mem_size(self): self.max_total_token_num = self.mem_manager.size - assert self.max_seq_length <= self.max_total_token_num + + assert ( + self.max_total_token_num > self.batch_max_tokens + ), "max_total_token_num must be greater than batch_max_tokens" + + # 非个人性能模式下,需要保证 max_seq_length 小于等于 max_total_token_num, + # 这样才能得到完整的上下文长度的支持。个人模式主要是私有化场景,显卡显存不是 + # 特别大,可能能分配的 kv 容量有限,无法支持 max_seq_length 的推理。所以个人模式下 + # 可以适当放宽这个限制,不做这个校验。 + if self.args.performance_mode != "personal": + assert self.max_seq_length <= self.max_total_token_num, ( + f"max_total_token_num must be >= max_seq_length, " + f"got max_total_token_num={self.max_total_token_num}, " + f"max_seq_length={self.max_seq_length}. " + f"Try set --max_req_total_len a smaller value < {self.max_total_token_num}." 
+ ) + return def _init_req_manager(self): @@ -222,7 +225,7 @@ def _init_req_manager(self): if self.max_seq_length is not None: create_max_seq_len = max(create_max_seq_len, self.max_seq_length) - self.req_manager = ReqManager(self.max_req_num, create_max_seq_len, self.mem_manager) + self.req_manager = ReqManager(self.max_req_num, create_max_seq_len, None) return def _init_infer_layer(self, start_layer_index=0): @@ -257,7 +260,13 @@ def _init_att_backend1(self): def _init_cudagraph(self): self.graph = ( - None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch) + None + if self.disable_cudagraph + else CudaGraph( + max_batch_size=self.graph_max_batch_size, + max_len_in_batch=self.graph_max_len_in_batch, + tp_world_size=self.tp_world_size_, + ) ) if self.graph is not None: if get_env_start_args().enable_decode_microbatch_overlap: @@ -269,7 +278,7 @@ def _init_prefill_cuda_graph(self): self.prefill_graph = ( None if not get_env_start_args().enable_prefill_cudagraph - else PrefillCudaGraph(decode_cuda_graph=self.graph) + else PrefillCudaGraph(decode_cuda_graph=self.graph, tp_world_size=self.tp_world_size_) ) if self.prefill_graph is not None: if get_env_start_args().enable_prefill_microbatch_overlap: @@ -305,6 +314,8 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0] infer_state.b_req_idx = model_input.b_req_idx infer_state.b_seq_len = model_input.b_seq_len + infer_state.b_mtp_index = model_input.b_mtp_index + infer_state.b_position_delta = model_input.b_position_delta if model_input.is_prefill: if model_input.b_ready_cache_len is not None: infer_state.b_ready_cache_len = model_input.b_ready_cache_len @@ -357,7 +368,14 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s new_model_input.b_req_idx = F.pad( new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", 
value=self.req_manager.HOLD_REQUEST_ID ) + new_model_input.b_mtp_index = F.pad( + new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0 + ) new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2) + if new_model_input.b_position_delta is not None: + new_model_input.b_position_delta = F.pad( + new_model_input.b_position_delta, (0, padded_batch_size), mode="constant", value=0 + ) new_model_input.mem_indexes = F.pad( new_model_input.mem_indexes, (0, padded_batch_size), @@ -389,6 +407,9 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s return new_model_input def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle_token_num: int): + if model_input.total_token_num - model_input.prefix_total_token_num == new_handle_token_num: + return model_input + assert model_input.total_token_num - model_input.prefix_total_token_num < new_handle_token_num padded_token_num = new_handle_token_num - (model_input.total_token_num - model_input.prefix_total_token_num) @@ -446,14 +467,16 @@ def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_ba return new_model_output - def _create_unpad_prefill_model_output(self, padded_model_output: ModelOutput, origin_handle_token_num: int): + def _create_unpad_prefill_model_output( + self, padded_model_output: ModelOutput, origin_handle_token_num: int, origin_batch_size: int + ): if self.return_all_prompt_logics: new_model_output = copy.copy(padded_model_output) new_model_output.logits = new_model_output.logits[0:origin_handle_token_num] else: new_model_output = copy.copy(padded_model_output) # 移除多余的pad 的那个 req 对应的 logics - new_model_output.logits = new_model_output.logits[0:-1] + new_model_output.logits = new_model_output.logits[0:origin_batch_size] # 特殊模型,特殊模式的特殊变量的特殊 unpad if new_model_output.mtp_main_output_hiddens is not None: @@ -466,18 +489,32 @@ def _prefill( self, model_input: 
ModelInput, ): + if self.args.enable_prefill_decode_mixed and model_input.b_is_decode_req is not None: + gather_token_prefill_decode_mixed( + input_ids=model_input.input_ids, + req_to_next_token_ids=self.req_manager.req_sampling_params_manager.req_to_next_token_ids, + b_req_idx=model_input.b_req_idx, + b_mtp_index=model_input.b_mtp_index, + b_is_decode_req=model_input.b_is_decode_req, + b_prefill_start_loc=model_input.b_prefill_start_loc, + ) + origin_handle_token_num = model_input.total_token_num - model_input.prefix_total_token_num + origin_batch_size = model_input.batch_size - is_padded_model_input = False - if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=origin_handle_token_num): - finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num( - handle_token_num=origin_handle_token_num + if self.args.enable_tpsp_mix_mode: + infer_handle_token_num = triton.cdiv(origin_handle_token_num, self.tp_world_size_) * self.tp_world_size_ + else: + infer_handle_token_num = origin_handle_token_num + + if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=infer_handle_token_num): + infer_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num( + handle_token_num=infer_handle_token_num ) - if finded_handle_token_num != origin_handle_token_num: - is_padded_model_input = True - model_input = self._create_padded_prefill_model_input( - model_input=model_input, new_handle_token_num=finded_handle_token_num - ) + + model_input = self._create_padded_prefill_model_input( + model_input=model_input, new_handle_token_num=infer_handle_token_num + ) infer_state = self._create_inferstate(model_input) init_req_to_token_indexes( @@ -495,10 +532,12 @@ def _prefill( infer_state.init_some_extra_state(self) infer_state.init_att_state() model_output = self._context_forward(infer_state) - if is_padded_model_input: - model_output = self._create_unpad_prefill_model_output( - model_output, 
origin_handle_token_num=origin_handle_token_num - ) + + model_output = self._create_unpad_prefill_model_output( + padded_model_output=model_output, + origin_handle_token_num=origin_handle_token_num, + origin_batch_size=origin_batch_size, + ) model_output.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event return model_output @@ -514,10 +553,22 @@ def _decode( model_input.b_mtp_index, ) - if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_kv_seq_len): - find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) - padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) - infer_state = self._create_inferstate(padded_model_input) + origin_batch_size = model_input.batch_size + if self.args.enable_tpsp_mix_mode: + infer_batch_size = triton.cdiv(model_input.batch_size, self.tp_world_size_) * self.tp_world_size_ + else: + infer_batch_size = model_input.batch_size + + if self.graph is not None and self.graph.can_run( + batch_size=infer_batch_size, max_len_in_batch=model_input.max_kv_seq_len + ): + infer_batch_size = self.graph.find_closest_graph_batch_size(batch_size=infer_batch_size) + model_input = self._create_padded_decode_model_input( + model_input=model_input, new_batch_size=infer_batch_size + ) + infer_state = self._create_inferstate(model_input) + need_capture = self.graph.need_capture(infer_batch_size) + infer_state.is_cuda_graph = need_capture copy_kv_index_to_req( self.req_manager.req_to_token_indexs, infer_state.b_req_idx, @@ -527,16 +578,16 @@ def _decode( infer_state.init_some_extra_state(self) infer_state.init_att_state() - if self.graph.need_capture(find_graph_batch_size): - infer_state.is_cuda_graph = True + if need_capture: model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state) else: model_output: ModelOutput = self.graph.replay(infer_state) - model_output = self._create_unpad_decode_model_output( - 
model_output, origin_batch_size=model_input.batch_size - ) + model_output = self._create_unpad_decode_model_output(model_output, origin_batch_size=origin_batch_size) else: + model_input = self._create_padded_decode_model_input( + model_input=model_input, new_batch_size=infer_batch_size + ) infer_state = self._create_inferstate(model_input) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -547,33 +598,36 @@ def _decode( infer_state.init_some_extra_state(self) infer_state.init_att_state() model_output = self._token_forward(infer_state) + model_output = self._create_unpad_decode_model_output(model_output, origin_batch_size=origin_batch_size) return model_output @final def _context_forward(self, infer_state: InferStateInfo): - run_mode_index = 1 if self.enable_tpsp_mix_mode else 0 - input_ids = infer_state.input_ids - cuda_input_ids = input_ids - pre_method = (self.pre_infer.context_forward, self.pre_infer.tpsp_context_forward)[run_mode_index] - input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight) + input_embs = self.pre_infer.context_forward(infer_state.input_ids, infer_state, self.pre_post_weight) + if self.args.enable_dp_prefill_balance: + assert not self.args.enable_prefill_cudagraph, "not support now" + infer_state.prepare_prefill_dp_balance() + input_embs = infer_state._all_to_all_balance_get(data=input_embs) + + input_embs = self.pre_infer._tpsp_sp_split(input=input_embs, infer_state=infer_state) input_tensors = [input_embs] def prefill_func(input_tensors, infer_state): _input_embs = input_tensors[0] for i in range(self.layers_num): layer = self.layers_infer[i] - layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index] - _input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i]) + _input_embs = layer.context_forward(_input_embs, infer_state, self.trans_layers_weight[i]) return [_input_embs] - handle_token_num = input_ids.shape[0] + handle_token_num = 
infer_state.input_ids.shape[0] if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num): finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num( handle_token_num=handle_token_num ) + assert finded_handle_token_num == handle_token_num if self.prefill_graph.need_capture(handle_token_num=finded_handle_token_num): output_tensors: List[torch.Tensor] = self.prefill_graph.capture_prefill( prefill_func=prefill_func, @@ -591,13 +645,20 @@ def prefill_func(input_tensors, infer_state): g_cache_manager.cache_env_out() input_embs = output_tensors[0] - post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index] - predict_logits = post_method(input_embs, infer_state, self.pre_post_weight) + + last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + if infer_state.need_dp_prefill_balance: + last_input_embs = infer_state._all_to_all_unbalance_get(data=last_input_embs) + + predict_logits = self.post_infer.token_forward(last_input_embs, infer_state, self.pre_post_weight) model_output = ModelOutput(logits=predict_logits) # 特殊模型特殊模式的额外输出 if self.is_mtp_mode: - model_output.mtp_main_output_hiddens = input_embs + input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + if infer_state.need_dp_prefill_balance: + input_embs = infer_state._all_to_all_unbalance_get(data=input_embs) + model_output.mtp_main_output_hiddens = input_embs.contiguous() # 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候 # 该调用没有实际意义 @@ -606,27 +667,26 @@ def prefill_func(input_tensors, infer_state): @final def _token_forward(self, infer_state: InferStateInfo): - run_mode_index = 1 if self.enable_tpsp_mix_mode else 0 input_ids = infer_state.input_ids cuda_input_ids = input_ids - pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index] - input_embs = pre_method(cuda_input_ids, infer_state, 
self.pre_post_weight) + input_embs = self.pre_infer.token_forward(cuda_input_ids, infer_state, self.pre_post_weight) + input_embs = self.pre_infer._tpsp_sp_split(input=input_embs, infer_state=infer_state) + for i in range(self.layers_num): layer = self.layers_infer[i] - layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index] - input_embs: torch.Tensor = layer_method(input_embs, infer_state, self.trans_layers_weight[i]) - - post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index] - predict_logits: torch.Tensor = post_method(input_embs, infer_state, self.pre_post_weight) + input_embs: torch.Tensor = layer.token_forward(input_embs, infer_state, self.trans_layers_weight[i]) - if self.is_mtp_mode: - graph_out_hiddens = input_embs.contiguous() + last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + predict_logits: torch.Tensor = self.post_infer.token_forward( + last_input_embs, infer_state=infer_state, layer_weight=self.pre_post_weight + ) model_output = ModelOutput(logits=predict_logits.contiguous()) # 特殊模型特殊模式的额外输出 if self.is_mtp_mode: - model_output.mtp_main_output_hiddens = graph_out_hiddens + input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + model_output.mtp_main_output_hiddens = input_embs.contiguous() # 在 cuda graph 模式下,输出需要转为 no ref tensor, 加强mem pool 的复用,降低显存的使用。 if infer_state.is_cuda_graph: @@ -639,8 +699,44 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod model_input0.to_cuda() model_input1.to_cuda() + if self.args.enable_prefill_decode_mixed and model_input0.b_is_decode_req is not None: + gather_token_prefill_decode_mixed( + input_ids=model_input0.input_ids, + req_to_next_token_ids=self.req_manager.req_sampling_params_manager.req_to_next_token_ids, + b_req_idx=model_input0.b_req_idx, + b_mtp_index=model_input0.b_mtp_index, + b_is_decode_req=model_input0.b_is_decode_req, + 
b_prefill_start_loc=model_input0.b_prefill_start_loc, + ) + + if self.args.enable_prefill_decode_mixed and model_input1.b_is_decode_req is not None: + gather_token_prefill_decode_mixed( + input_ids=model_input1.input_ids, + req_to_next_token_ids=self.req_manager.req_sampling_params_manager.req_to_next_token_ids, + b_req_idx=model_input1.b_req_idx, + b_mtp_index=model_input1.b_mtp_index, + b_is_decode_req=model_input1.b_is_decode_req, + b_prefill_start_loc=model_input1.b_prefill_start_loc, + ) + assert model_input0.mem_indexes.is_cuda assert model_input1.mem_indexes.is_cuda + + assert self.args.enable_tpsp_mix_mode + origin_handle_token_num0 = model_input0.total_token_num - model_input0.prefix_total_token_num + origin_handle_token_num1 = model_input1.total_token_num - model_input1.prefix_total_token_num + infer_handle_token_num0 = triton.cdiv(origin_handle_token_num0, self.tp_world_size_) * self.tp_world_size_ + infer_handle_token_num1 = triton.cdiv(origin_handle_token_num1, self.tp_world_size_) * self.tp_world_size_ + origin_batch_size0 = model_input0.batch_size + origin_batch_size1 = model_input1.batch_size + + model_input0 = self._create_padded_prefill_model_input( + model_input=model_input0, new_handle_token_num=infer_handle_token_num0 + ) + model_input1 = self._create_padded_prefill_model_input( + model_input=model_input1, new_handle_token_num=infer_handle_token_num1 + ) + infer_state0 = self._create_inferstate(model_input0, 0) init_req_to_token_indexes( req_to_token_indexs=self.req_manager.req_to_token_indexs, @@ -672,6 +768,16 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod model_output0, model_output1 = self._overlap_tpsp_context_forward(infer_state0, infer_state1=infer_state1) + model_output0 = self._create_unpad_prefill_model_output( + padded_model_output=model_output0, + origin_handle_token_num=origin_handle_token_num0, + origin_batch_size=origin_batch_size0, + ) + model_output1 = 
self._create_unpad_prefill_model_output( + padded_model_output=model_output1, + origin_handle_token_num=origin_handle_token_num1, + origin_batch_size=origin_batch_size1, + ) # 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候 # 该调用没有实际意义 dist_group_manager.clear_deepep_buffer() @@ -683,6 +789,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput): model_input0.to_cuda() model_input1.to_cuda() + assert self.args.enable_tpsp_mix_mode if model_input0.input_ids is None: model_input0.input_ids = gather_token( @@ -703,14 +810,17 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode origin_batch_size = model_input0.batch_size max_len_in_batch = max(model_input0.max_kv_seq_len, model_input1.max_kv_seq_len) + infer_batch_size = triton.cdiv(origin_batch_size, self.tp_world_size_) * self.tp_world_size_ - if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch): - find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size) + if self.graph is not None and self.graph.can_run(infer_batch_size, max_len_in_batch): + infer_batch_size = self.graph.find_closest_graph_batch_size(infer_batch_size) + need_capture = self.graph.need_capture(infer_batch_size) # TODO 如果支持动态步数的 mtp,在不同的mtp步上,model_input0 和 model_input1 的内部batch size可能不 # 一致,需要按照较高 batch size 进行graph的寻找,同时,进行有效的恢复。 - padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size) - padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size) + padded_model_input0 = self._create_padded_decode_model_input(model_input0, infer_batch_size) + padded_model_input1 = self._create_padded_decode_model_input(model_input1, infer_batch_size) infer_state0 = self._create_inferstate(padded_model_input0, 0) + infer_state0.is_cuda_graph = need_capture copy_kv_index_to_req( 
self.req_manager.req_to_token_indexs, infer_state0.b_req_idx, @@ -721,6 +831,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state0.init_att_state() infer_state1 = self._create_inferstate(padded_model_input1, 1) + infer_state1.is_cuda_graph = need_capture copy_kv_index_to_req( self.req_manager.req_to_token_indexs, infer_state1.b_req_idx, @@ -730,10 +841,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.init_some_extra_state(self) infer_state1.init_att_state() - if self.graph.need_capture(find_graph_batch_size): - infer_state0.is_cuda_graph = True - infer_state1.is_cuda_graph = True - + if need_capture: model_output0, model_output1 = self.graph.capture_decode( self._overlap_tpsp_token_forward, infer_state0, @@ -749,6 +857,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode model_output0 = self._create_unpad_decode_model_output(model_output0, origin_batch_size=origin_batch_size) model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size) else: + model_input0 = self._create_padded_decode_model_input(model_input0, infer_batch_size) + model_input1 = self._create_padded_decode_model_input(model_input1, infer_batch_size) infer_state0 = self._create_inferstate(model_input0, 0) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -770,20 +880,47 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.init_att_state() model_output0, model_output1 = self._overlap_tpsp_token_forward(infer_state0, infer_state1=infer_state1) + model_output0 = self._create_unpad_decode_model_output(model_output0, origin_batch_size=origin_batch_size) + model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size) + return model_output0, model_output1 @final def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, 
infer_state1: InferStateInfo): g_cache_manager.cache_env_in() + input_embs, input_embs1 = self.pre_infer.overlap_tpsp_context_forward( infer_state.input_ids, infer_state1.input_ids, infer_state, infer_state1, self.pre_post_weight ) + + # 决定是否进行 dp balance 优化,可以提升dp > 1 时的 prefill 效率。 + if get_env_start_args().enable_dp_prefill_balance: + assert not self.args.enable_prefill_cudagraph, "not support now" + infer_state.prepare_prefill_dp_balance() + infer_state1.prepare_prefill_dp_balance() + input_embs = infer_state._all_to_all_balance_get(data=input_embs) + input_embs1 = infer_state1._all_to_all_balance_get(data=input_embs1) + + input_embs = self.pre_infer._tpsp_sp_split(input=input_embs, infer_state=infer_state) + input_embs1 = self.pre_infer._tpsp_sp_split(input=input_embs1, infer_state=infer_state1) + for i in range(self.layers_num): input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_context_forward( input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] ) + + # 折叠模式调用完infer_state 和 infer_state1 上的hook函数后,input_embs 和 input_embs1 才具备正确的运算数据。 + infer_state.call_overlap_hook() + infer_state1.call_overlap_hook() + + last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + last_input_embs1 = self.post_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1) + if infer_state.need_dp_prefill_balance: + last_input_embs = infer_state._all_to_all_unbalance_get(data=last_input_embs) + last_input_embs1 = infer_state1._all_to_all_unbalance_get(data=last_input_embs1) + predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( - input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight + last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight ) g_cache_manager.cache_env_out() @@ -791,6 +928,11 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state model_output1 = 
ModelOutput(logits=predict_logits1.contiguous()) if self.is_mtp_mode: + input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + input_embs1 = self.pre_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1) + if infer_state.need_dp_prefill_balance: + input_embs = infer_state._all_to_all_unbalance_get(data=input_embs) + input_embs1 = infer_state1._all_to_all_unbalance_get(data=input_embs1) model_output.mtp_main_output_hiddens = input_embs.contiguous() model_output1.mtp_main_output_hiddens = input_embs1.contiguous() @@ -801,26 +943,33 @@ def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward( infer_state.input_ids, infer_state1.input_ids, infer_state, infer_state1, self.pre_post_weight ) + input_embs = self.pre_infer._tpsp_sp_split(input=input_embs, infer_state=infer_state) + input_embs1 = self.pre_infer._tpsp_sp_split(input=input_embs1, infer_state=infer_state1) for i in range(self.layers_num): input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_token_forward( input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] ) + # 折叠模式调用完infer_state 上的hook函数后,input_embs 和 input_embs 才具备正确的运算数据。 + infer_state.call_overlap_hook() + infer_state1.call_overlap_hook() + + last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + last_input_embs1 = self.post_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1) + predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( - input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight + last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight ) - if self.is_mtp_mode: - graph_out_hiddens = input_embs.contiguous() - graph_out_hiddens1 = input_embs1.contiguous() - model_output = ModelOutput(logits=predict_logits.contiguous()) model_output1 = 
ModelOutput(logits=predict_logits1.contiguous()) if self.is_mtp_mode: - model_output.mtp_main_output_hiddens = graph_out_hiddens - model_output1.mtp_main_output_hiddens = graph_out_hiddens1 + input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + input_embs1 = self.pre_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1) + model_output.mtp_main_output_hiddens = input_embs.contiguous() + model_output1.mtp_main_output_hiddens = input_embs1.contiguous() if infer_state.is_cuda_graph: model_output.to_no_ref_tensor() @@ -942,6 +1091,9 @@ def _autotune_warmup(self): is_prefill=True, b_ready_cache_len=b_ready_cache_len, b_prefill_start_loc=b_prefill_start_loc, + b_prefill_has_output_cpu=[ + False, + ], multimodal_params=[{"images": [], "audios": []}], **self._gen_special_model_input(total_token_num), ) @@ -1003,6 +1155,9 @@ def _init_padded_req(self): b_seq_len=b_seq_len, b_ready_cache_len=b_ready_cache_len, b_prefill_start_loc=b_prefill_start_loc, + b_prefill_has_output_cpu=[ + False, + ], is_prefill=True, multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], **self._gen_special_model_input(total_token_num), diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 758c0b5194..303e6d06ce 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -20,6 +20,12 @@ class ModelInput: b_req_idx: torch.Tensor = None b_mtp_index: torch.Tensor = None b_seq_len: torch.Tensor = None + # 在 prefill 阶段,用于在 enable_prefill_decode_mixed 开启下, + # 用于标识请求是否为 decode 请求混合在 prefill 请求中。 + # 其对应的 input_ids 需要特殊处理, 从 req_to_next_token_ids 中获取。 + + b_is_decode_req: torch.Tensor = None + # 只会在 diverse_mode 下的 decode 阶段真正被使用的参数, 用于记录共享的radix cache中的长度 b_shared_seq_len: torch.Tensor = None # 只会在 diverse_mode 下的 decode 阶段真正被使用的参数, 用于记录请求间的共享关系。 @@ -32,6 +38,9 @@ class ModelInput: mem_indexes: torch.Tensor = None is_prefill: bool = False 
b_ready_cache_len: torch.Tensor = None + # 只会在继承 Qwen2VLInferStateInfo 的 MRoPE 模型 decode 阶段使用,如 + # Qwen2/2.5-VL、Qwen3-VL/MOE/Omni、Qwen3.5;普通模型不会使用。 + b_position_delta: torch.Tensor = None b_prefill_start_loc: torch.Tensor = None multimodal_params: list = None # cpu 变量 @@ -52,11 +61,22 @@ def to_cuda(self): self.input_ids = self.input_ids.cuda(non_blocking=True) if self.mem_indexes is None: self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True) + + if self.b_is_decode_req is not None: + self.b_is_decode_req = self.b_is_decode_req.cuda(non_blocking=True) + assert self.is_prefill + self.b_req_idx = self.b_req_idx.cuda(non_blocking=True) self.b_seq_len = self.b_seq_len.cuda(non_blocking=True) self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True) if self.b_ready_cache_len is not None: self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True) + if self.b_position_delta is not None: + self.b_position_delta = self.b_position_delta.cuda(non_blocking=True) + assert self.is_prefill is False, "b_position_delta should only be used in decode phase." + else: + assert self.is_prefill is True, "decode ModelInput should provide b_position_delta." 
+ if self.b_prefill_start_loc is not None: self.b_prefill_start_loc = self.b_prefill_start_loc.cuda(non_blocking=True) if not self.is_prefill and enable_diverse_mode_gqa_decode_fast_kernel(): diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 5e8301015c..d0ac8ead10 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -2,10 +2,11 @@ import torch import copy import bisect +import triton from typing import Optional from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args -from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup +from lightllm.distributed import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput from .infer_struct import InferStateInfo @@ -16,8 +17,9 @@ class CudaGraph: # CudaGraph forward pass for the decoding stage. - def __init__(self, max_batch_size=8, max_len_in_batch=8192): + def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int = 1): self.graph = {} + self.tp_world_size = tp_world_size self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None self.args = get_env_start_args() self.mtp_step = self.args.mtp_step @@ -40,6 +42,11 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192): batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size])) batch_sizes.append(max_batch_size) batch_sizes.sort() + if self.args.enable_tpsp_mix_mode: + batch_sizes = [triton.cdiv(e, self.tp_world_size) * self.tp_world_size for e in batch_sizes] + batch_sizes = list(set(batch_sizes)) + batch_sizes.sort() + self.cuda_graph_batch_sizes = batch_sizes assert batch_sizes[-1] == self.max_batch_size logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") @@ -63,7 +70,6 @@ def find_closest_graph_batch_size(self, batch_size): return None def _capture_decode(self, 
decode_func, infer_state: InferStateInfo): - dist_group: CustomProcessGroup = infer_state.dist_group graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids batch_size = input_ids.shape[0] @@ -88,9 +94,8 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): if param_name not in pure_para_set: delattr(infer_state, param_name) - with lightllm_capture_graph(dist_group): - with torch.cuda.graph(graph_obj, pool=self.mempool): - model_output = decode_func(infer_state) + with torch.cuda.graph(graph_obj, pool=self.mempool): + model_output = decode_func(infer_state) self.graph[batch_size] = (graph_obj, infer_state, model_output) graph_obj.replay() return model_output @@ -101,8 +106,6 @@ def _capture_decode_overlap( infer_state: InferStateInfo, infer_state1: InferStateInfo, ): - dist_group: CustomProcessGroup = infer_state.dist_group - dist_group1 = infer_state1.dist_group graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids batch_size = input_ids.shape[0] @@ -125,10 +128,8 @@ def _capture_decode_overlap( if para_name not in pure_para_set1: delattr(infer_state1, para_name) - with lightllm_capture_graph(dist_group1): - with lightllm_capture_graph(dist_group): - with torch.cuda.graph(graph_obj, pool=self.mempool): - model_output, model_output1 = decode_func(infer_state, infer_state1) + with torch.cuda.graph(graph_obj, pool=self.mempool): + model_output, model_output1 = decode_func(infer_state, infer_state1) self.graph[batch_size] = ( graph_obj, infer_state, @@ -219,6 +220,7 @@ def warmup(self, model): b_req_idx=b_req_idx, b_seq_len=b_seq_len, b_mtp_index=b_mtp_index, + b_position_delta=torch.zeros(batch_size, dtype=torch.int32, device="cuda"), is_prefill=False, multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], **model._gen_special_model_input(batch_size), @@ -278,6 +280,7 @@ def warmup_overlap(self, model): mem_indexes=mem_indexes, b_req_idx=b_req_idx, b_seq_len=b_seq_len, + 
b_position_delta=torch.zeros(batch_size, dtype=torch.int32, device="cuda"), multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], **model._gen_special_model_input(batch_size), ) diff --git a/lightllm/common/basemodel/infer_lock.py b/lightllm/common/basemodel/infer_lock.py deleted file mode 100644 index 9da027e662..0000000000 --- a/lightllm/common/basemodel/infer_lock.py +++ /dev/null @@ -1,138 +0,0 @@ -# 这不是一个很好的设计但是不是很好找到更好更简单对架构入侵更小的实现方法。 -# 这个地方声明的锁和计数,主要是用来解决在 PD 分离模式下,kv_move_manager 进程中会出现 -# 通过rpyc调用操作 radix cache 和 mem_manager 中的数据的问题,这可能导致严重的数据同步 -# 问题,主要原因是各个tp的推理进程运行到的位置节点并没有严格的保证,导致radix cache 和 -# mem manager 中的数据出现各个进程间不一致的问题。 -# 下面的实现中,通过一个锁和计数对象, 配合使用的方式,来解决这个问题。 -from dataclasses import dataclass -import numpy as np -import threading -from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray -import torch.distributed as dist -import time -import torch.multiprocessing as mp -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class InferStateLock: - def __init__(self, name, rank_in_dp: int, dp_rank_in_node: int, dp_world_size: int): - self.infer_lock = threading.Lock() - self.dp_rank_in_node = dp_rank_in_node - # sync_world_size 应该是 min(dp_world_size, node_world_size) - self.dp_world_size = dp_world_size - self.rank_in_dp = rank_in_dp - # 默认开 128 tp 的空间, 现在应该没什么卡能开这么大的tp 吧 - self.lock_tp_infos = SharedArray( - f"{name}_dp_rank_{str(self.dp_rank_in_node)}_lock_tp_infos", shape=(self.dp_world_size + 1,), dtype=np.int64 - ) - self.lock_tp_infos.arr[:] = 0 - - def add_cur_mark(self): - self.lock_tp_infos.arr[self.rank_in_dp] += 1 - - def get_cur_mark(self): - return self.lock_tp_infos.arr[self.rank_in_dp] - - def get_max_mark_in_group(self): - return np.max(self.lock_tp_infos.arr[0 : self.dp_world_size]) - - def judge_cur_mark_equal_max_mark_in_group(self): - return self.get_cur_mark() == self.get_max_mark_in_group() - - def judge_mark_in_group_all_same(self): - marks = 
self.lock_tp_infos.arr[0 : self.dp_world_size] - return bool(np.all(marks == marks[0])) - - def acquire_lock_and_update_cur_mark(self): - self.infer_lock.acquire() - self.add_cur_mark() - - def release_lock(self): - self.infer_lock.release() - - def set_group_wait_mark(self): - if self.rank_in_dp == 0: - self.lock_tp_infos.arr[-1] = 1 - - def unset_group_wait_mark(self): - if self.rank_in_dp == 0: - self.lock_tp_infos.arr[-1] = 0 - - def get_group_wait_mark(self): - return self.lock_tp_infos.arr[-1] - - -@dataclass -class G_Infer_Lock: - obj: InferStateLock = None - dp_world_size: int = None - - def acquire(self): - if self.obj is not None: - # 当遇到有同步请求的时候,同时自己的mark已经是最大的mark的时候,就在这里休眠, - # 不去竞争锁, 因为 wait_mark == 1 的时候, 说明acquire_lock_until_ready被调用, - # 有推理进程在申请同步点操作 - while self.obj.get_group_wait_mark() == 1 and self.obj.judge_cur_mark_equal_max_mark_in_group(): - time.sleep(0) - - self.obj.acquire_lock_and_update_cur_mark() - - def release(self): - if self.obj is not None: - self.obj.release_lock() - - -# 后续由 backend 对象来对obj进行初始化赋值,方便进行全局调用 -g_infer_state_lock = G_Infer_Lock() - - -# 下面两个函数需要配对使用 -def acquire_lock_until_ready(nccl_group): - # 单卡一tp不用过度加锁 - if g_infer_state_lock.dp_world_size == 1: - g_infer_state_lock.obj.infer_lock.acquire() - return - - g_infer_state_lock.obj.set_group_wait_mark() - while True: - g_infer_state_lock.obj.infer_lock.acquire() - dist.barrier(nccl_group) - judge_ans = g_infer_state_lock.obj.judge_mark_in_group_all_same() - dist.barrier(nccl_group) - - if judge_ans is not True: - # 释放锁进行重试 - g_infer_state_lock.obj.infer_lock.release() - time.sleep(0.001) - logger.info("wait get locks sleep 1ms") - else: - break - - g_infer_state_lock.obj.unset_group_wait_mark() - return - - -def release_acquired_lock(): - g_infer_state_lock.obj.infer_lock.release() - - -@dataclass -class G_Router_Lock: - """ - 保护pd分离模式下, 一些调度相关信息数据的操作。 - """ - - obj = None # 进程锁对象 - - def acquire(self): - if self.obj is not None: - self.obj.acquire() - - def 
release(self): - if self.obj is not None: - self.obj.release() - - -g_router_lock = G_Router_Lock() diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 75856b1086..89e0508608 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -10,7 +10,7 @@ from .triton_kernel.multimodal_emb import mark_multimodal_obj from .batch_objs import ModelInput from lightllm.utils.envs_utils import get_env_start_args -from lightllm.utils.dist_utils import get_global_dp_rank +from lightllm.utils.dist_utils import get_global_dp_rank, get_dp_world_size from .attention import BasePrefillAttState, BaseDecodeAttState @@ -37,6 +37,10 @@ def __init__(self): self.b_shared_seq_len: torch.Tensor = None # only for diverse mode used in decode phase. self.b_mark_shared_group: torch.Tensor = None # only for diverse mode used in decode phase. + self.b_mtp_index: torch.Tensor = None + # only for mrope model in decode phase used. 
+ self.b_position_delta: torch.Tensor = None + self.b_seq_len: torch.Tensor = None # max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度 self.max_cache_len: int = None @@ -144,7 +148,19 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"): self.decode_att_state1.copy_for_decode_cuda_graph(new_infer_state.decode_att_state1) return - def prefill_dp_balance(self, input_ids: torch.Tensor): + def call_overlap_hook(self): + """ + overlap_hook 是在 microbatch overlap 的运行模式下,用于调用绑定在inferstate上的hook函数,用于 + 实现折叠通信和计算的效果,普通模式并不调用这个函数。这是个快速调用函数,用于减少重复代码。 + """ + if getattr(self, "hook", None) is not None: + self.hook() + self.hook = None + return + + ##### prefill dp balance 相关的函数 ##### + + def prepare_prefill_dp_balance(self): """ 在prefill的时候, 对于处于 dp 模式下的时候,对输入的数据进行重新的调整和分配,降低各个dp处理数据量过于不一致的时候,导致 的prefill 推理性能下降 @@ -152,6 +168,8 @@ def prefill_dp_balance(self, input_ids: torch.Tensor): assert self.is_prefill import torch.distributed as dist + input_ids = self.input_ids # 原始输入的input_ids + self.need_dp_prefill_balance = True args = get_env_start_args() @@ -165,19 +183,31 @@ def prefill_dp_balance(self, input_ids: torch.Tensor): group=self.dist_group.dp_prefill_balance_group, async_op=False, ) - dp_input_lens = dp_input_lens.detach().cpu() - self.dp_origin_lens = dp_input_lens.tolist() - sum_input_len = dp_input_lens.sum().item() - dp_handle_lens = [sum_input_len // args.dp for _ in range(args.dp)] - for i in range(sum_input_len % args.dp): - dp_handle_lens[i] += 1 + dp_input_lens = dp_input_lens.detach().cpu().tolist() + self.dp_origin_lens = dp_input_lens.copy() + sum_input_len = sum(dp_input_lens) + if not args.enable_tpsp_mix_mode: + dp_handle_lens = [sum_input_len // args.dp for _ in range(args.dp)] + for i in range(sum_input_len % args.dp): + dp_handle_lens[i] += 1 + else: + # tpsp mix mode 需要让每个dp 的处理长度是 tp_world_size 的整数倍 + tp_world_size = get_dp_world_size() + assert all(e % tp_world_size == 0 for e in dp_input_lens) + _dp_input_lens = [e // tp_world_size for 
e in dp_input_lens] + _sum_input_len = sum(_dp_input_lens) + _dp_input_lens = [_sum_input_len // args.dp for _ in range(args.dp)] + for i in range(_sum_input_len % args.dp): + _dp_input_lens[i] += 1 + dp_handle_lens = [_dp_input_len * tp_world_size for _dp_input_len in _dp_input_lens] + assert sum(dp_handle_lens) == sum_input_len self.dp_handle_lens = dp_handle_lens.copy() dest_dp_inputs = [[] for _ in range(args.dp)] # 分配每个dp 的原始输入和分配后的原始输入 origin_datas = collections.deque() - for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens.numpy()): + for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens): handle_len = dp_handle_lens[origin_dp_index] if origin_dp_input_len > handle_len: origin_datas.append((origin_dp_index, handle_len, origin_dp_input_len)) @@ -216,29 +246,49 @@ def prefill_dp_balance(self, input_ids: torch.Tensor): self.dp_input_split_sizes = dp_input_split_sizes self.dp_output_split_sizes = dp_output_split_sizes - new_input_ids = self._all_to_all_balance_get(input_ids) if hasattr(self, "position_ids") and self.position_ids is not None: # deepseekv2 mla 特殊模型需要保留原始的 position_ids, 用于减少通信量 self._unbalance_position_ids = self.position_ids + self._balance_position_ids = self._all_to_all_balance_get(self.position_ids, change_state=False) - self.position_ids = self._all_to_all_balance_get(self.position_ids) if hasattr(self, "position_cos") and self.position_cos is not None: # deepseekv2 mla 特殊模型需要保留原始的 position_cos, 用于减少通信量 self._unbalance_position_cos = self.position_cos + self._balance_position_cos = self._all_to_all_balance_get(self.position_cos, change_state=False) - self.position_cos = self._all_to_all_balance_get(self.position_cos) if hasattr(self, "position_sin") and self.position_sin is not None: # deepseekv2 mla 特殊模型需要保留原始的 position_sin, 用于减少通信量 self._unbalance_position_sin = self.position_sin - - self.position_sin = self._all_to_all_balance_get(self.position_sin) + self._balance_position_sin = 
self._all_to_all_balance_get(self.position_sin, change_state=False) self._unbalance_input_ids = self.input_ids - self.input_ids = new_input_ids + self._balance_input_ids = self._all_to_all_balance_get(input_ids, change_state=False) + + return - return new_input_ids + def __change_to_unbalance(self): + self.input_ids = self._unbalance_input_ids + if hasattr(self, "position_ids"): + self.position_ids = self._unbalance_position_ids + if hasattr(self, "position_cos"): + self.position_cos = self._unbalance_position_cos + if hasattr(self, "position_sin"): + self.position_sin = self._unbalance_position_sin + return - def _all_to_all_balance_get(self, data: torch.Tensor): + def __change_to_balance(self): + self.input_ids = self._balance_input_ids + if hasattr(self, "position_ids"): + self.position_ids = self._balance_position_ids + if hasattr(self, "position_cos"): + self.position_cos = self._balance_position_cos + if hasattr(self, "position_sin"): + self.position_sin = self._balance_position_sin + return + + def _all_to_all_balance_get(self, data: torch.Tensor, change_state: bool = True): + if change_state: + self.__change_to_balance() dp_rank = get_global_dp_rank() import torch.distributed as dist from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager @@ -266,7 +316,10 @@ def _all_to_all_balance_get(self, data: torch.Tensor): ) return dest_data.view(-1, *old_shape[1:]) - def _all_to_all_unbalance_get(self, data: torch.Tensor): + def _all_to_all_unbalance_get(self, data: torch.Tensor, change_state: bool = True): + if change_state: + self.__change_to_unbalance() + dp_rank = get_global_dp_rank() import torch.distributed as dist from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager @@ -293,7 +346,7 @@ def _all_to_all_unbalance_get(self, data: torch.Tensor): ) return origin_data.view(-1, *old_shape[1:]) - # 用于 prefll cuda graph 的专用功能接口 + # 用于 prefill cuda graph 的专用功能接口 def 
prefill_cuda_graph_create_graph_obj(self): if not hasattr(self, "prefill_cuda_graph_exe_list"): self.prefill_cuda_graph_exe_list = [] diff --git a/lightllm/common/basemodel/layer_infer/base_layer_infer.py b/lightllm/common/basemodel/layer_infer/base_layer_infer.py index c61e4d9e44..354920f142 100644 --- a/lightllm/common/basemodel/layer_infer/base_layer_infer.py +++ b/lightllm/common/basemodel/layer_infer/base_layer_infer.py @@ -1,9 +1,13 @@ import torch +import torch.distributed as dist from typing import Dict, Iterable, Literal, Tuple, Union, List from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size from .cache_tensor_manager import g_cache_manager +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor, all_reduce +from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy class BaseLayerInfer: @@ -26,12 +30,6 @@ def alloc_tensor( """ """ return g_cache_manager.alloc_tensor(shape, dtype, device=device) - def tpsp_context_forward(self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BaseLayerWeight): - raise Exception("need to impl") - - def tpsp_token_forward(self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BaseLayerWeight): - raise Exception("need to impl") - def overlap_tpsp_token_forward( self, input0: torch.Tensor, @@ -51,3 +49,50 @@ def overlap_tpsp_context_forward( layer_weight: BaseLayerWeight, ): raise Exception("need to impl") + + def _tpsp_allgather(self, input: torch.Tensor, infer_state: InferStateInfo): + if self.tp_world_size_ > 1 and get_env_start_args().enable_tpsp_mix_mode: + sp_token_num, hidden_dim = input.shape + gather_input = self.alloc_tensor( + (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, 
device=input.device + ) + all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) + return gather_input + return input + + def _tpsp_reduce(self, input: torch.Tensor, infer_state: InferStateInfo): + """ + 函数内部会根据当前的启动参数决定是进行reduce scatter还是all reduce + """ + if self.tp_world_size_ > 1 and get_env_start_args().enable_tpsp_mix_mode: + sp_token_num = input.shape[0] // self.tp_world_size_ + assert input.shape[0] % self.tp_world_size_ == 0 + hidden_dim = input.view(input.shape[0], -1).shape[1] + reduce_o_tensor = self.alloc_tensor((sp_token_num, hidden_dim), dtype=input.dtype, device=input.device) + reduce_scatter_tensor( + output=reduce_o_tensor, + input=input, + op=dist.ReduceOp.SUM, + group=infer_state.dist_group, + async_op=False, + ) + return reduce_o_tensor + elif self.tp_world_size_ > 1: + all_reduce(input, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return input + return input + + def _tpsp_sp_split(self, input: torch.Tensor, infer_state: InferStateInfo): + """ + 根据当前的启动参数决定是否将请求进行sp分割,如果需要分割,则进行分割,并返回分割后的结果 + 如果不需要分割,则返回原始请求, 举列说明,如果input shape 为【16, 1024】,tp_world_size为4,则分割后返回的shape为【4, 1024】 + """ + if self.tp_world_size_ > 1 and get_env_start_args().enable_tpsp_mix_mode: + input = sp_pad_copy( + in_tensor=input, + sp_rank_id=self.tp_rank_, + sp_world_size=self.tp_world_size_, + alloc_func=self.alloc_tensor, + ) + return input + return input diff --git a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py index 7889e8090e..8bcf99b992 100644 --- a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py +++ b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py @@ -33,6 +33,7 @@ class BufNode: inner_tensor: torch.Tensor shape_key: Tuple[int, torch.dtype] storage_weak_ptr: int + free_use_count_bias: int = 0 shape_to_tensor: Dict[Union[torch.Size, Iterable[int]], torch.Tensor] = field(default_factory=dict) def 
__del__(self): @@ -99,7 +100,8 @@ def alloc_tensor( # 回收可能消亡的 tensor for ptr in self.changed_ptr: t_buf_node = self.ptr_to_bufnode[ptr] - if self.use_count(ptr) == 1 + len(t_buf_node.shape_to_tensor): + free_use_count = t_buf_node.free_use_count_bias + 1 + len(t_buf_node.shape_to_tensor) + if self.use_count(ptr) <= free_use_count: self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node) self.changed_ptr.clear() @@ -131,6 +133,7 @@ def alloc_tensor( self.ptr_to_bufnode[storage_weak_ptr] = buf_node if shape not in buf_node.shape_to_tensor: buf_node.shape_to_tensor[shape] = buf_node.inner_tensor.view(shape) + buf_node.free_use_count_bias = self.use_count(storage_weak_ptr) - (1 + len(buf_node.shape_to_tensor)) mark_tensor = buf_node.shape_to_tensor[shape] ans = mark_tensor.data # 返回一个新的引用, 否则引用计数会无法判断 ans.storage_weak_ptr = buf_node.storage_weak_ptr diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 646f998642..f0cc129c09 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -32,12 +32,9 @@ def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.T def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: raise Exception("need to impl") - def _tpsp_get_qkv(self, input, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: - raise Exception("need to impl") - def _post_cache_kv(self, cache_kv, infer_state: InferStateInfo, layer_weight): mem_manager = infer_state.mem_manager - mem_manager.copy_kv_to_mem_manager( + mem_manager.operator.copy_kv_to_mem_manager( layer_index=self.layer_num_, mem_index=infer_state.mem_index, kv=cache_kv, @@ -53,15 +50,9 @@ def _token_attention_kernel(self, q, 
infer_state: InferStateInfo, layer_weight, def _get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: raise Exception("need to impl") - def _tpsp_get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: raise Exception("need to impl") - def _tpsp_ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) @@ -70,8 +61,7 @@ def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, ) q = None o = self._get_o(o, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): @@ -83,8 +73,7 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None - if self.tp_world_size_ > 1: - all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings @@ -94,8 +83,7 @@ def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, l o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): @@ -106,50 +94,7 @@ def 
token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) - return input_embdings - - def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): - q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight) - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_wrapper_run( - q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight - ) - q = None - o = self._tpsp_get_o(o, infer_state, layer_weight) - return o - - def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - o = self.tpsp_context_attention_forward(input1, infer_state, layer_weight) - input_embdings.add_(o.view(-1, self.embed_dim_)) - o = None - - input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) - ffn_out = self._tpsp_ffn(input1, infer_state, layer_weight) - input1 = None - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) - return input_embdings - def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): - q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight) - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._token_attention_kernel(q, infer_state, layer_weight) - q = None - o = self._tpsp_get_o(o, infer_state, layer_weight) - return o - - def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - o = 
self.tpsp_token_attention_forward(input1, infer_state, layer_weight) - input_embdings.add_(o.view(-1, self.embed_dim_)) - o = None - - input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) - ffn_out = self._tpsp_ffn(input1, infer_state, layer_weight) - input1 = None input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings @@ -158,11 +103,13 @@ def _context_attention_wrapper_run( ) -> torch.Tensor: if torch.cuda.is_current_stream_capturing(): q = q.contiguous() - cache_kv = cache_kv.contiguous() - _q, _cache_kv = ( - tensor_to_no_ref_tensor(q), - tensor_to_no_ref_tensor(cache_kv), - ) + # cache_kv is None for layers that own no K/V slot (e.g. gemma4 + # KV-shared layers, which read K/V from a prior layer's cache and + # ignore this arg in _context_attention_kernel). Skip the + # graph-input plumbing for it instead of crashing on None. + cache_kv = cache_kv.contiguous() if cache_kv is not None else None + _q = tensor_to_no_ref_tensor(q) + _cache_kv = tensor_to_no_ref_tensor(cache_kv) if cache_kv is not None else None pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph() pre_capture_graph.__exit__(None, None, None) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 8f54e14a72..26f2b338b7 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -33,12 +33,16 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 0, network_config: Dict[str, Any] = None, + per_expert_scale_name: str = "", ) -> None: super().__init__(data_type=data_type) self.w1_weight_name = gate_proj_name self.w2_weight_name = down_proj_name self.w3_weight_name = up_proj_name self.e_score_correction_bias_name = e_score_correction_bias_name + # gemma4 的专家计算出的值都需要一个 scale 
值,每个专家有自己独立的scale参数 + # per_expert_scale_name 是专家的scale参数权重的名称, 为 "" 表示没有专家独立的scale参数 + self.per_expert_scale_name = per_expert_scale_name self.weight_prefix = weight_prefix self.layer_num_ = layer_num self.global_rank_ = get_global_rank() @@ -130,6 +134,7 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -145,6 +150,8 @@ def experts( topk_group=topk_group, num_expert_group=num_expert_group, is_prefill=is_prefill, + per_expert_scale=self.per_expert_scale, + shared_expert_gate=shared_expert_gate, ) def low_latency_dispatch( @@ -261,18 +268,26 @@ def combine( def load_hf_weights(self, weights): # Load bias - if self.e_score_correction_bias_name in weights: - self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) + self._load_e_score_correction_bias(weights) + self._load_per_expert_scale(weights) self._load_weight(self.expert_idx_to_local_idx, weights) if self.redundancy_expert_num > 0: self._load_weight(self.redundancy_expert_idx_to_local_idx, weights) def verify_load(self): - return all(all(_weight_pack.load_ok) for _weight_pack in self.w1_list + self.w2_list + self.w3_list) + weight_load_ok = all(all(_weight_pack.load_ok) for _weight_pack in self.w1_list + self.w2_list + self.w3_list) + per_expert_scale_load_ok = ( + True if self.per_expert_scale is None else getattr(self.per_expert_scale, "load_ok", False) + ) + e_score_correction_bias_load_ok = ( + True if self.e_score_correction_bias is None else getattr(self.e_score_correction_bias, "load_ok", False) + ) + return weight_load_ok and per_expert_scale_load_ok and e_score_correction_bias_load_ok def _create_weight(self): intermediate_size = self.split_inter_size self.e_score_correction_bias = None + self.per_expert_scale = None # Create e_score_correction_bias if 
self.e_score_correction_bias_name: self.e_score_correction_bias = torch.empty( @@ -280,6 +295,15 @@ def _create_weight(self): dtype=self.data_type_, device=f"cuda:{self.device_id_}", ) + self.e_score_correction_bias.load_ok = False + + if self.per_expert_scale_name: + self.per_expert_scale = torch.empty( + (self.n_routed_experts,), + dtype=torch.float32, + device=f"cuda:{self.device_id_}", + ) + self.per_expert_scale.load_ok = False self.w13, w13_param_list = self.quant_method.create_moe_weight( out_dims=[intermediate_size, intermediate_size], @@ -299,6 +323,16 @@ def _create_weight(self): self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1]) self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2) + def _load_e_score_correction_bias(self, weights: Dict[str, torch.Tensor]): + if self.e_score_correction_bias_name and self.e_score_correction_bias_name in weights: + self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) + self.e_score_correction_bias.load_ok = True + + def _load_per_expert_scale(self, weights: Dict[str, torch.Tensor]): + if self.per_expert_scale_name and self.per_expert_scale_name in weights: + self.per_expert_scale.copy_(weights[self.per_expert_scale_name].to(self.per_expert_scale.dtype)) + self.per_expert_scale.load_ok = True + def _get_expert_weight_list(self, weight_pack: WeightPack): weight_list = [] for idx in range(self.local_n_routed_experts): @@ -307,7 +341,6 @@ def _get_expert_weight_list(self, weight_pack: WeightPack): return weight_list def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]): - # Load each expert with TP slicing for expert_idx, local_expert_idx in expert_idx_to_local_idx.items(): with self.lock: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gemma4_packed_fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gemma4_packed_fused_moe_weight.py new file mode 
100644 index 0000000000..dd0d270ba9 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gemma4_packed_fused_moe_weight.py @@ -0,0 +1,38 @@ +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight import FusedMoeWeight + + +class Gemma4PackedFusedMoeWeight(FusedMoeWeight): + def load_hf_weights(self, weights): + # 将权重名称的格式对齐基类的统一加载格式。 + gate_up_name = f"{self.weight_prefix}.gate_up_proj" + down_name = f"{self.weight_prefix}.down_proj" + assert not self.enable_ep_moe, "Gemma-4 packed MoE currently supports TP mode only." + moe_intermediate_size = self.moe_intermediate_size + + if gate_up_name in weights: + gate_up_weight = weights[gate_up_name] + + for expert_idx in range(self.n_routed_experts): + expert_gate_weight = gate_up_weight[expert_idx, :moe_intermediate_size, :] + expert_up_weight = gate_up_weight[expert_idx, moe_intermediate_size:, :] + + weights[ + f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}" + ] = expert_gate_weight + weights[ + f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}" + ] = expert_up_weight + + del weights[gate_up_name] + + if down_name in weights: + down_weight = weights[down_name] + for expert_idx in range(self.n_routed_experts): + expert_down_weight = down_weight[expert_idx, :, :] + weights[ + f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}" + ] = expert_down_weight + del weights[down_name] + + super().load_hf_weights(weights) + return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index 6ed0cef0b4..90ce5761c3 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py +++ 
b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -144,7 +144,9 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): + assert shared_expert_gate is None, "shared_expert_gate is not supported by GPT-OSS fused MoE" topk_weights, topk_ids = self._router(router_logits, top_k) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index 00587ac185..8467c328da 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -62,5 +62,8 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + per_expert_scale: Optional[torch.Tensor] = None, + # Qwen3.5 uses this gate to control fused shared expert aggregation weights. 
+ shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index bdd86eb51e..5da17c57e1 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -4,11 +4,16 @@ from lightllm.distributed import dist_group_manager from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.common.quantization.quantize_method import WeightPack -from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank +from lightllm.utils.envs_utils import ( + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, +) from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( - fused_experts_impl, + fused_experts, + get_ep_num_sms, masked_group_gemm, - _deepgemm_grouped_fp8_nt_contiguous, + deepgemm_grouped_fp8_nt_contiguous, + quantize_fused_experts_input, ) from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( per_token_group_quant_fp8, @@ -31,8 +36,11 @@ def _select_experts( topk_group: int, num_expert_group: int, scoring_func: str, + per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): """Select experts and return topk weights and ids.""" + assert shared_expert_gate is None, "fused shared expert as MoE is not supported by DeepGEMM fused MoE" from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts topk_weights, topk_ids = select_experts( @@ -48,6 +56,8 @@ def _select_experts( ) if self.routed_scaling_factor != 1.0: topk_weights.mul_(self.routed_scaling_factor) + if per_expert_scale is not None: + topk_weights = topk_weights * 
per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype) if self.redundancy_expert_num > 0: redundancy_topk_ids_repair( topk_ids=topk_ids, @@ -69,24 +79,15 @@ def _fused_experts( router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, ): - - w13_weight, w13_scale = w13.weight, w13.weight_scale - w2_weight, w2_scale = w2.weight, w2.weight_scale - use_fp8_w8a8 = self.quant_method.method_name != "none" - output = fused_experts_impl( + output = fused_experts( hidden_states=input_tensor, - w1=w13_weight, - w2=w2_weight, + w13=w13, + w2=w2, topk_weights=topk_weights, topk_idx=topk_ids.to(torch.long), num_experts=self.total_expert_num_contain_redundancy, # number of all experts contain redundancy - buffer=dist_group_manager.ep_buffer, + quant_method=self.quant_method, is_prefill=is_prefill, - use_fp8_w8a8=use_fp8_w8a8, - use_fp8_all2all=use_fp8_w8a8, - use_int8_w8a16=False, # default to False - w1_scale=w13_scale, - w2_scale=w2_scale, previous_event=None, # for overlap ) return output @@ -116,13 +117,13 @@ def low_latency_dispatch( ) topk_idx = topk_idx.to(torch.long) - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() use_fp8_w8a8 = self.quant_method.method_name != "none" - recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch( - hidden_states, - topk_idx, - num_max_dispatch_tokens_per_rank, - self.total_expert_num_contain_redundancy, + recv_x, masked_m, handle, event, hook = dist_group_manager.ep_low_latency_buffer.low_latency_dispatch( + topk_idx=topk_idx, + x=hidden_states, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, + num_experts=self.total_expert_num_contain_redundancy, use_fp8=use_fp8_w8a8, async_finish=False, return_recv_hook=True, @@ -153,13 +154,8 @@ def select_experts_and_quant_input( num_expert_group=n_group, scoring_func=scoring_func, ) - 
w13_weight, w13_scale = w13.weight, w13.weight_scale - block_size_k = 0 - if w13_weight.ndim == 3: - block_size_k = w13_weight.shape[2] // w13_scale.shape[2] - assert block_size_k == 128, "block_size_k must be 128" - qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w13_weight.dtype) - return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale) + qinput_tensor = quantize_fused_experts_input(hidden_states, w13, self.quant_method) + return topk_weights, topk_idx.to(torch.long), qinput_tensor def dispatch( self, @@ -169,38 +165,26 @@ def dispatch( overlap_event: Optional[Any] = None, ): buffer = dist_group_manager.ep_buffer - # get_dispatch_layout - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - topk_idx, - self.total_expert_num_contain_redundancy, - previous_event=overlap_event, - async_finish=True, - allocate_on_comm_stream=True, - ) - recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( + num_max_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill() + recv_x, recv_topk_idx, recv_topk_weights, handle, event = buffer.dispatch( qinput_tensor, topk_idx=topk_idx, topk_weights=topk_weights, - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=previous_event, - async_finish=True, - allocate_on_comm_stream=True, + num_experts=self.total_expert_num_contain_redundancy, + num_max_tokens_per_rank=num_max_tokens_per_rank, expert_alignment=128, + num_sms=get_ep_num_sms(), + previous_event=overlap_event, + async_with_compute_stream=True, + allocate_on_comm_stream=True, + do_cpu_sync=True, + do_handle_copy=False, ) def hook(): event.current_stream_wait() - return recv_x, recv_topk_idx, recv_topk_weights, 
num_recv_tokens_per_expert_list, handle, hook + return recv_x, recv_topk_idx, recv_topk_weights, handle.num_recv_tokens_per_expert_list, handle, hook def masked_group_gemm( self, @@ -214,7 +198,14 @@ def masked_group_gemm( w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale return masked_group_gemm( - recv_x, masked_m, dtype, w13_weight, w13_scale, w2_weight, w2_scale, expected_m=expected_m + recv_x, + masked_m, + dtype, + w13_weight, + w13_scale, + w2_weight, + w2_scale, + expected_m=expected_m, ) def prefilled_group_gemm( @@ -272,7 +263,7 @@ def prefilled_group_gemm( # groupgemm (contiguous layout) gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype) - _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices) + deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices) # silu_and_mul_fwd + qaunt # TODO fused kernel @@ -286,7 +277,7 @@ def prefilled_group_gemm( # groupgemm (contiguous layout) gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype) - _deepgemm_grouped_fp8_nt_contiguous( + deepgemm_grouped_fp8_nt_contiguous( (qsilu_out, qsilu_out_scale), (w2_weight, w2_scale), gemm_out_b, m_indices ) # gather and local reduce @@ -310,7 +301,7 @@ def low_latency_combine( topk_weights: torch.Tensor, handle: Any, ): - combined_x, event_overlap, hook = dist_group_manager.ep_buffer.low_latency_combine( + combined_x, event_overlap, hook = dist_group_manager.ep_low_latency_buffer.low_latency_combine( gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=True ) return combined_x, hook @@ -326,8 +317,9 @@ def combine( gemm_out_b, handle, topk_weights=None, - async_finish=True, + num_sms=get_ep_num_sms(), previous_event=overlap_event, + async_with_compute_stream=True, allocate_on_comm_stream=True, ) diff --git 
a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 6391a10800..0094b09b1c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -7,6 +7,7 @@ AWQMARLINW4A16QuantizationMethod, ) from typing import Optional +from lightllm.utils.config_utils import ffn_use_tanh_approximate_gelu class FuseMoeMarlin(FuseMoeTriton): @@ -38,6 +39,8 @@ def _fused_experts( self.quant_method: AWQMARLINW4A16QuantizationMethod = self.quant_method + activation = "silu" if not ffn_use_tanh_approximate_gelu() else "gelu" + fused_marlin_moe( input_tensor, w1_weight, @@ -52,6 +55,7 @@ def _fused_experts( quant_type_id=self.quant_method.vllm_quant_type.id, apply_router_weight_on_input=False, global_num_experts=-1, + activation=activation, expert_map=None, w1_zeros=w1_zero_point, w2_zeros=w2_zero_point, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index d6e923a115..110a83094b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -42,6 +42,8 @@ def _select_experts( topk_group: int, num_expert_group: int, scoring_func: str, + per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): """Select experts and return topk weights and ids.""" from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts @@ -59,27 +61,20 @@ def _select_experts( ) if self.routed_scaling_factor != 1.0: topk_weights.mul_(self.routed_scaling_factor) + if per_expert_scale is not None: + topk_weights = topk_weights * 
per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype) if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts, - end=self.n_routed_experts + self.num_fused_shared_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, + from lightllm.common.basemodel.triton_kernel.fused_moe.append_shared_expert_topk import ( + append_fused_shared_experts, ) - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) + topk_weights, topk_ids = append_fused_shared_experts( + topk_weights=topk_weights, + topk_ids=topk_ids, + shared_expert_start_id=self.n_routed_experts, + num_fused_shared_experts=self.num_fused_shared_experts, + shared_expert_gate=shared_expert_gate, + ) return topk_weights, topk_ids def _fused_experts( @@ -125,6 +120,8 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -136,6 +133,8 @@ def __call__( topk_group=topk_group, num_expert_group=num_expert_group, scoring_func=scoring_func, + per_expert_scale=per_expert_scale, + shared_expert_gate=shared_expert_gate, ) output = self._fused_experts( input_tensor=input_tensor, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index fb50398368..6f3f9fe62c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py 
@@ -53,13 +53,8 @@ def __init__( ) -> None: self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() - self.repeat_times = 1 - assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( - f"kv_head_num must be divisible by tp_world_size_ or " - f"tp_world_size_ must be divisible by kv_head_num, " - f"but found: {kv_head_num} % {self.tp_world_size_}" - ) - kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim + self.repeat_times = self._get_repeat_times(kv_head_num) + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num, self.repeat_times) * head_dim out_dims = [kv_hidden_size, kv_hidden_size] super().__init__( in_dim=in_dim, @@ -78,18 +73,19 @@ def __init__( repeat_times=self.repeat_times, ) - def _get_tp_padded_head_num(self, head_num: int): - if head_num % self.tp_world_size_ == 0: - return head_num // self.tp_world_size_ - elif self.tp_world_size_ % head_num == 0: - self.repeat_times = self.tp_world_size_ // head_num - return self.repeat_times * head_num // self.tp_world_size_ + def _get_repeat_times(self, kv_head_num: int) -> int: + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or " + f"tp_world_size_ must be divisible by kv_head_num, " + f"but found: {kv_head_num} % {self.tp_world_size_}" + ) + if kv_head_num % self.tp_world_size_ == 0: + return 1 else: - raise ValueError( - f"head_num must be divisible by tp_world_size_ or " - f"tp_world_size_ must be divisible by head_num, " - f"but found: {head_num} % {self.tp_world_size_}" - ) + return self.tp_world_size_ // kv_head_num + + def _get_tp_padded_head_num(self, head_num: int, repeat_times: int) -> int: + return repeat_times * head_num // self.tp_world_size_ class QKVROWNMMWeight(MMWeightTpl): @@ -109,17 +105,12 @@ def __init__( self.tp_rank_ = tp_rank if 
tp_rank is not None else get_current_rank_in_dp() self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() self.q_repeat_times = 1 - self.kv_repeat_times = 1 + self.kv_repeat_times = self._get_kv_repeat_times(kv_head_num) assert q_head_num % self.tp_world_size_ == 0, ( f"q_head_num must be divisible by tp_world_size_, " f"but found: {q_head_num} % {self.tp_world_size_}" ) - assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( - f"kv_head_num must be divisible by tp_world_size_ or " - f"tp_world_size_ must be divisible by kv_head_num, " - f"but found: {kv_head_num} % {self.tp_world_size_}" - ) q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim - kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num, self.kv_repeat_times) * head_dim out_dims = [q_hidden_size, kv_hidden_size, kv_hidden_size] super().__init__( in_dim=in_dim, @@ -157,18 +148,19 @@ def _get_param_slicer(self, sub_child_index: int): else: return self.kv_param_slicer - def _get_tp_padded_head_num(self, head_num: int): - if head_num % self.tp_world_size_ == 0: - return head_num // self.tp_world_size_ - elif self.tp_world_size_ % head_num == 0: - self.kv_repeat_times = self.tp_world_size_ // head_num - return self.kv_repeat_times * head_num // self.tp_world_size_ + def _get_kv_repeat_times(self, kv_head_num: int) -> int: + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or " + f"tp_world_size_ must be divisible by kv_head_num, " + f"but found: {kv_head_num} % {self.tp_world_size_}" + ) + if kv_head_num % self.tp_world_size_ == 0: + return 1 else: - raise ValueError( - f"head_num must be divisible by tp_world_size_ or " - f"tp_world_size_ must be divisible by head_num, " - f"but found: {head_num} % {self.tp_world_size_}" - ) + return self.tp_world_size_ // 
kv_head_num + + def _get_tp_padded_head_num(self, head_num: int, repeat_times: int) -> int: + return repeat_times * head_num // self.tp_world_size_ class ROWBMMWeight(BMMWeightTpl): diff --git a/lightllm/common/basemodel/prefill_cuda_graph.py b/lightllm/common/basemodel/prefill_cuda_graph.py index a8b2616418..a70c5ce63b 100644 --- a/lightllm/common/basemodel/prefill_cuda_graph.py +++ b/lightllm/common/basemodel/prefill_cuda_graph.py @@ -2,12 +2,13 @@ import torch import copy import bisect +import triton from typing import List, Tuple from typing import Optional from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor -from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup +from lightllm.distributed import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput from .infer_struct import InferStateInfo from .cuda_graph import CudaGraph @@ -18,9 +19,9 @@ class PrefillCudaGraph: # CudaGraph forward pass for the decoding stage. 
- def __init__(self, decode_cuda_graph: CudaGraph): + def __init__(self, decode_cuda_graph: CudaGraph, tp_world_size: int): self.graph = {} - + self.tp_world_size = tp_world_size if decode_cuda_graph is not None: self.mempool = decode_cuda_graph.mempool # prefill 和 decode 共享一个 mempool else: @@ -28,14 +29,30 @@ def __init__(self, decode_cuda_graph: CudaGraph): self.args = get_env_start_args() self.enable_prefill_microbatch_overlap = self.args.enable_prefill_microbatch_overlap - self.max_handle_token_num = self.args.prefll_cudagraph_max_handle_token - - graph_handle_token_nums = [] - for i in range(2048): - token_num = int(2 ** (2 * i)) - if 1 < token_num < self.max_handle_token_num: - graph_handle_token_nums.append(token_num) + self.max_handle_token_num = self.args.prefill_cudagraph_max_handle_token + if self.args.batch_max_tokens is not None: + self.max_handle_token_num = min(self.max_handle_token_num, self.args.batch_max_tokens) + + graph_handle_token_nums = ( + list(range(4, 33, 4)) + + list(range(48, 257, 16)) + + list(range(288, 513, 32)) + + list(range(576, 1024 + 1, 64)) + + list(range(1280, 4096 + 1, 256)) + + list(range(4608, self.max_handle_token_num + 1, 512)) + ) + graph_handle_token_nums = [e for e in graph_handle_token_nums if e <= self.max_handle_token_num] graph_handle_token_nums.append(self.max_handle_token_num) + + graph_handle_token_nums = list(set[int](graph_handle_token_nums)) + graph_handle_token_nums.sort() + if self.args.enable_tpsp_mix_mode: + graph_handle_token_nums = [ + triton.cdiv(e, self.tp_world_size) * self.tp_world_size for e in graph_handle_token_nums + ] + graph_handle_token_nums = list(set(graph_handle_token_nums)) + graph_handle_token_nums.sort() + self.graph_handle_token_nums = graph_handle_token_nums logger.info(f"prefill cuda graph graph_handle_token_nums: {self.graph_handle_token_nums}") @@ -61,15 +78,13 @@ def _capture_prefill( self, prefill_func, input_tensors: List[torch.Tensor], infer_state: InferStateInfo ) -> 
List[torch.Tensor]: handle_token_num = infer_state.total_token_num - infer_state.prefix_total_token_num - dist_group: CustomProcessGroup = infer_state.dist_group - with lightllm_capture_graph(dist_group): - infer_state.mem_pool = self.mempool - infer_state.prefill_cuda_graph_create_graph_obj() - infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__() - graph_input_tensors: List[torch.Tensor] = [torch.empty_like(e) for e in input_tensors] - graph_out_tensors: List[torch.Tensor] = prefill_func(graph_input_tensors, infer_state) - graph_out_tensors = [e.contiguous() for e in graph_out_tensors] - infer_state.prefill_cuda_graph_get_current_capture_graph().__exit__(None, None, None) + infer_state.mem_pool = self.mempool + infer_state.prefill_cuda_graph_create_graph_obj() + infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__() + graph_input_tensors: List[torch.Tensor] = [torch.empty_like(e) for e in input_tensors] + graph_out_tensors: List[torch.Tensor] = prefill_func(graph_input_tensors, infer_state) + graph_out_tensors = [e.contiguous() for e in graph_out_tensors] + infer_state.prefill_cuda_graph_get_current_capture_graph().__exit__(None, None, None) graph_input_tensors = [tensor_to_no_ref_tensor(e) for e in graph_input_tensors] graph_out_tensors = [tensor_to_no_ref_tensor(e) for e in graph_out_tensors] diff --git a/lightllm/common/basemodel/triton_kernel/alibi_att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/alibi_att/context_flashattention_nopad.py index 03ed864475..69d7a71966 100644 --- a/lightllm/common/basemodel/triton_kernel/alibi_att/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/alibi_att/context_flashattention_nopad.py @@ -132,7 +132,7 @@ def context_attention_fwd( # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} sm_scale = 1.0 / (Lq ** 0.5) batch, 
head = b_seq_len.shape[0], q.shape[1] diff --git a/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_att1.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_att1.py index 174aea6c1f..88bebd5cf8 100644 --- a/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_att1.py +++ b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_att1.py @@ -73,11 +73,11 @@ def _fwd_kernel_token_att1( @torch.no_grad() def token_att_fwd(q, k, att_out, alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_len_in_batch): - BLOCK = 32 # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} + BLOCK = 16 if Lk == 256 else 32 sm_scale = 1.0 / (Lk ** 0.5) batch, head_num = B_req_idx.shape[0], q.shape[1] diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py index e549298e3b..59a7d4f742 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py @@ -2,7 +2,13 @@ def gqa_token_decode_attention_flash_decoding( - q: torch.Tensor, infer_state, cache_k: torch.Tensor, cache_v: torch.Tensor, out=None, alloc_tensor_func=torch.empty + q: torch.Tensor, + infer_state, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + out=None, + alloc_tensor_func=torch.empty, + sliding_window=(-1, -1), ): batch_size = infer_state.batch_size q_head_num, head_dim = q.shape[1], q.shape[2] @@ -39,6 +45,7 @@ def gqa_token_decode_attention_flash_decoding( mid_out=mid_o, mid_out_logsumexp=mid_o_logexpsum, block_seq=BLOCK_SEQ, + sliding_window=sliding_window, ) flash_decode_stage2( mid_out=mid_o, @@ -46,5 +53,6 @@ def 
gqa_token_decode_attention_flash_decoding( B_Seqlen=infer_state.b_seq_len, out=o_tensor.view(calcu_shape1), block_seq=BLOCK_SEQ, + sliding_window=sliding_window, ) return o_tensor diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py index 339088e753..cae913b4bc 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py @@ -39,6 +39,8 @@ def _fwd_kernel_flash_decode_stage1( BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + LEFT_SLIDING_WINDOW_SIZE: tl.constexpr, ): cur_batch = tl.program_id(0) cur_kv_head = tl.program_id(1) @@ -46,6 +48,12 @@ def _fwd_kernel_flash_decode_stage1( grid_block_num = tl.num_programs(2) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + if USE_SLIDING_WINDOW: + kv_start_index = tl.maximum(cur_batch_seq_len - 1 - LEFT_SLIDING_WINDOW_SIZE, 0) + cur_batch_seq_len = cur_batch_seq_len - kv_start_index + else: + kv_start_index = 0 + req_total_block_num = tl.cdiv(cur_batch_seq_len, BLOCK_SEQ) if block_index >= req_total_block_num: return @@ -77,7 +85,7 @@ def _fwd_kernel_flash_decode_stage1( offs_n_new = start_n * BLOCK_N + offs_n n_mask = offs_n_new < cur_batch_end_index k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + kv_start_index + offs_n_new, mask=n_mask, other=0, ).to(tl.int64) @@ -110,14 +118,8 @@ def _fwd_kernel_flash_decode_stage1( + offs_d[None, :] ) off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + block_index - tl.store( - Mid_O + off_mid_o, - acc / sum_exp[:, None], - ) - tl.store( - 
Mid_O_LogExpSum + off_mid_o_logexpsum, - max_logic + tl.log(sum_exp), - ) + tl.store(Mid_O + off_mid_o, acc / sum_exp[:, None]) + tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) return @@ -170,6 +172,7 @@ def flash_decode_stage1( mid_out, mid_out_logsumexp, block_seq, + sliding_window=(-1, -1), run_config: Optional[dict] = None, ): """ """ @@ -185,12 +188,21 @@ def flash_decode_stage1( # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256, 512} + if Lk >= 256: + BLOCK_N = min(BLOCK_N, 16) + assert BLOCK_SEQ % BLOCK_N == 0 sm_scale = 1.0 / (Lk ** 0.5) batch, kv_head_num = B_req_idx.shape[0], k.shape[1] block_num = mid_out.shape[2] grid = (batch, kv_head_num, block_num) gqa_group_size = q.shape[1] // k.shape[1] + sliding_window_left = int(sliding_window[0]) + use_sliding_window = sliding_window_left >= 0 + + # 当前 不支持 right sliding window + if use_sliding_window: + assert sliding_window[1] == 0 _fwd_kernel_flash_decode_stage1[grid]( q, @@ -225,6 +237,8 @@ def flash_decode_stage1( BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK_N, + USE_SLIDING_WINDOW=use_sliding_window, + LEFT_SLIDING_WINDOW_SIZE=sliding_window_left, num_warps=num_warps, num_stages=num_stages, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py index 4eff53c3ac..d97a1f2522 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py @@ -22,12 +22,17 @@ def _fwd_kernel_flash_decode_stage2( block_num, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + LEFT_SLIDING_WINDOW_SIZE: tl.constexpr, ): cur_batch = 
tl.program_id(0) cur_head = tl.program_id(1) offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + if USE_SLIDING_WINDOW: + kv_start_index = tl.maximum(cur_batch_seq_len - 1 - LEFT_SLIDING_WINDOW_SIZE, 0) + cur_batch_seq_len = cur_batch_seq_len - kv_start_index block_num = tl.minimum(tl.cdiv(cur_batch_seq_len, BLOCK_SEQ), block_num) @@ -54,12 +59,17 @@ def _fwd_kernel_flash_decode_stage2( @torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): +def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq, sliding_window=(-1, -1)): Lk = mid_out.shape[-1] - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256, 512} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = (batch, head_num) block_num = mid_out.shape[2] + sliding_window_left = int(sliding_window[0]) + use_sliding_window = sliding_window_left >= 0 + # 当前 不支持 right sliding window + if use_sliding_window: + assert sliding_window[1] == 0 _fwd_kernel_flash_decode_stage2[grid]( B_Seqlen, @@ -79,6 +89,8 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): block_num, BLOCK_SEQ=block_seq, BLOCK_DMODEL=Lk, + USE_SLIDING_WINDOW=use_sliding_window, + LEFT_SLIDING_WINDOW_SIZE=sliding_window_left, num_warps=4, num_stages=2, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/gqa_decode_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/gqa_decode_flashattention_nopad.py index afc67a84e8..5ce5d193f5 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/gqa_decode_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/gqa_decode_flashattention_nopad.py @@ -116,11 +116,11 @@ def _fwd_kernel( @torch.no_grad() def gqa_decode_attention_fwd(q, k, v, o, req_to_tokens, b_req_idx, b_seq_len): - BLOCK = 32 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq 
== Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} + BLOCK = 16 if Lk == 256 else 32 sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 batch = b_req_idx.shape[0] diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py index d6e3628e55..00e884352b 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py @@ -205,7 +205,10 @@ def int4kv_flash_decode_stage1( # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] * 2 assert Lq == Lk - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} + if Lk == 256: + BLOCK_N = min(BLOCK_N, 16) + assert BLOCK_SEQ % BLOCK_N == 0 sm_scale = 1.0 / (Lk ** 0.5) batch, kv_head_num = B_req_idx.shape[0], k.shape[1] grid_block_num = mid_out.shape[2] diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py index ad6a8b5b3a..b40251b8f8 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py @@ -1,6 +1,5 @@ # 为 diverse mode 定制设计的 int8kv flash decoding attention 实现,可以实现更高效的多样性采样 import torch -from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops from lightllm.common.basemodel.infer_struct import InferStateInfo from .int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 from .int8kv_flash_decoding_diverse_stage2 import flash_decode_stage2 diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py 
b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py index 295ae66ab3..fb122a79da 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py @@ -276,7 +276,10 @@ def flash_decode_stage1( # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} + if Lk == 256: + BLOCK_N = min(BLOCK_N, 16) + assert BLOCK_SEQ % BLOCK_N == 0 sm_scale = 1.0 / (Lk ** 0.5) batch, kv_head_num = B_req_idx.shape[0], k.shape[1] grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py index f5c0b9c395..963139b36b 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py @@ -250,7 +250,10 @@ def flash_decode_stage2( # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} + if Lk == 256: + BLOCK_N = min(BLOCK_N, 16) + assert BLOCK_SEQ % BLOCK_N == 0 sm_scale = 1.0 / (Lk ** 0.5) batch, kv_head_num = B_req_idx.shape[0], k.shape[1] grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage3.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage3.py index a82af03349..f96f0cd5d6 100644 --- 
a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage3.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage3.py @@ -67,7 +67,7 @@ def flash_diverse_decode_stage3( block_seq: int, ): Lk = mid_out.shape[-1] - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = (batch, head_num) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py index 9a8d0f5cbd..d241d31c07 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py @@ -213,7 +213,10 @@ def flash_decode_stage1( # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} + if Lk == 256: + BLOCK_N = min(BLOCK_N, 16) + assert BLOCK_SEQ % BLOCK_N == 0 sm_scale = 1.0 / (Lk ** 0.5) batch, kv_head_num = B_req_idx.shape[0], k.shape[1] gqa_group_size = q.shape[1] // k.shape[1] diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage2.py index 43dc6051e2..e211ae584f 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage2.py @@ -62,7 +62,7 @@ def flash_decode_stage2( block_seq: int, ): Lk = mid_out.shape[-1] - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = 
(batch, head_num) block_num = mid_out.shape[2] diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py deleted file mode 100644 index f51d611661..0000000000 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch -from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops - - -def token_decode_attention_flash_decoding( - q, - infer_state, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=None, - alloc_tensor_func=torch.empty, -): - BLOCK_SEQ = 256 - q_head_num, head_dim = q.shape[1], q.shape[2] - batch_size = infer_state.batch_size - max_kv_seq_len = infer_state.max_kv_seq_len - calcu_shape1 = (batch_size, q_head_num, head_dim) - - from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 - - o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out - - mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=q.dtype, device="cuda" - ) - mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda" - ) - - light_ops.group8_int8kv_flashdecoding_stage1( - BLOCK_SEQ, - mid_o, - mid_o_logexpsum, - 1.0 / (head_dim ** 0.5), - q.view(calcu_shape1), - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_kv_seq_len, - ) - - flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) - return o_tensor diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py index 
f41a5c8fde..196ded61a5 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py @@ -112,7 +112,7 @@ def flash_decode_stage1( # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} sm_scale = 1.0 / (Lk ** 0.5) batch, head_num = B_req_idx.shape[0], q.shape[1] grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage2.py index 101e99dde5..a25fcdb431 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage2.py @@ -55,7 +55,7 @@ def _fwd_kernel_flash_decode_stage2( @torch.no_grad() def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): Lk = mid_out.shape[-1] - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = (batch, head_num) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py deleted file mode 100644 index b0a9b6245c..0000000000 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch -from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops - - -def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty): - BLOCK_SEQ = 256 - batch_size = infer_state.batch_size - 
q_head_num = q.shape[1] - head_dim = q.shape[2] - max_kv_seq_len = infer_state.max_kv_seq_len - calcu_shape1 = (batch_size, q_head_num, head_dim) - - from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 - - o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out - - mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" - ) - mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" - ) - - light_ops.fp16_flashdecoding_stage1( - BLOCK_SEQ, - mid_o, - mid_o_logexpsum, - 1.0 / (head_dim ** 0.5), - q.view(calcu_shape1), - cache_k, - cache_v, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_kv_seq_len, - ) - - flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) - return o_tensor diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py index 5ba6d0beb6..ef38e39d6c 100644 --- a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py @@ -41,6 +41,8 @@ def _fwd_kernel( BLOCK_DMODEL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + SLIDING_WINDOW_LEFT: tl.constexpr, ): start_m = tl.program_id(0) cur_bh = tl.program_id(1) @@ -60,6 +62,7 @@ def _fwd_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = block_start_loc + tl.arange(0, BLOCK_M) + q_pos = offs_m + prompt_cache_len off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh @@ -76,20 +79,33 @@ def _fwd_kernel( block_mask = 
tl.where(block_start_loc < cur_batch_seq_len, 1, 0) block_end_loc = tl.minimum(block_start_loc + BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) - # causal mask - for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + if USE_SLIDING_WINDOW: + kv_start_index = block_start_loc + prompt_cache_len - SLIDING_WINDOW_LEFT + kv_start_index = tl.maximum(kv_start_index, 0) + block_kv_len = block_end_loc - kv_start_index + else: + kv_start_index = 0 + block_kv_len = block_end_loc + + # causal (+ sliding-window) mask + for start_n in range(0, block_mask * block_kv_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) + k_pos = kv_start_index + start_n + offs_n # -- compute qk ---- kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), - mask=(start_n + offs_n) < block_end_loc, + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, + mask=k_pos < block_end_loc, other=0, ).to(tl.int64) off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) + k = tl.load(K + off_k, mask=k_pos[None, :] < block_end_loc, other=0.0) qk = tl.dot(q, k) - mask = offs_m[:, None] + prompt_cache_len >= (start_n + offs_n[None, :]) + if USE_SLIDING_WINDOW: + # FA-style left inclusive offset + causal (right=0). 
+ mask = ((q_pos[:, None] - k_pos[None, :]) <= SLIDING_WINDOW_LEFT) & (q_pos[:, None] >= k_pos[None, :]) + else: + mask = q_pos[:, None] >= k_pos[None, :] qk = tl.where(mask, qk * sm_scale, -1.0e8) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -103,7 +119,7 @@ def _fwd_kernel( acc = acc * alpha[:, None] # update acc off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < block_end_loc, other=0.0) + v = tl.load(V + off_v, mask=k_pos[:, None] < block_end_loc, other=0.0) p = p.to(v.dtype) acc = tl.dot(p, v, acc) # update m_i and l_i @@ -121,13 +137,30 @@ def _fwd_kernel( @torch.no_grad() def context_attention_fwd( - q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs + q, + k, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, + sliding_window=(-1, -1), ): BLOCK_M = 128 if not is_tesla() else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 128, 256, 512} + # Larger head_dim needs smaller tiles to fit in SM shared memory. + # H100/H200 has ~228KB shared memory per SM; a 128x512 bf16 tile already + # consumes 128KB, leaving no room for K/V/scores buffers. 
+ if Lk >= 512: + BLOCK_M = min(BLOCK_M, 32) + elif Lk >= 256: + BLOCK_M = min(BLOCK_M, 64) # 计算scale系数, 并乘以 1/log(2) = 1.4426950408889634, # 算子内部使用 tl.math.exp2 来使计算与标准attention等价。 @@ -141,6 +174,14 @@ def context_attention_fwd( num_warps = 4 if Lk <= 64 else 8 num_stages = 1 + if sliding_window == (-1, -1): + use_sliding_window = False + sliding_window_left = -1 + else: + use_sliding_window = True + assert int(sliding_window[1]) == 0, "sliding_window right must be 0" + sliding_window_left = int(sliding_window[0]) + _fwd_kernel[grid]( q, k, @@ -171,6 +212,8 @@ def context_attention_fwd( BLOCK_DMODEL=Lk, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + USE_SLIDING_WINDOW=use_sliding_window, + SLIDING_WINDOW_LEFT=sliding_window_left, num_warps=num_warps, num_stages=num_stages, ) @@ -291,7 +334,14 @@ def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, ma # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 128, 256, 512} + # Larger head_dim needs smaller tiles to fit in SM shared memory. + # H100/H200 has ~228KB shared memory per SM; a 128x512 bf16 tile already + # consumes 128KB, leaving no room for K/V/scores buffers. + if Lk >= 512: + BLOCK_M = min(BLOCK_M, 32) + elif Lk >= 256: + BLOCK_M = min(BLOCK_M, 64) # 计算scale系数, 并乘以 1/log(2) = 1.4426950408889634, # 算子内部使用 tl.math.exp2 来使计算与标准attention等价。 @@ -463,7 +513,14 @@ def context_attention_fwd_contiguous_kv( # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 128, 256, 512} + # Larger head_dim needs smaller tiles to fit in SM shared memory. + # H100/H200 has ~228KB shared memory per SM; a 128x512 bf16 tile already + # consumes 128KB, leaving no room for K/V/scores buffers. 
+ if Lk >= 512: + BLOCK_M = min(BLOCK_M, 32) + elif Lk >= 256: + BLOCK_M = min(BLOCK_M, 64) # 计算scale系数, 并乘以 1/log(2) = 1.4426950408889634, # 算子内部使用 tl.math.exp2 来使计算与标准attention等价。 diff --git a/lightllm/common/basemodel/triton_kernel/embedding.py b/lightllm/common/basemodel/triton_kernel/embedding.py index 8c88f9fd23..0c43de6809 100644 --- a/lightllm/common/basemodel/triton_kernel/embedding.py +++ b/lightllm/common/basemodel/triton_kernel/embedding.py @@ -29,7 +29,7 @@ def embedding_kernel( n_ctx_mask = offs_seq < n_ctx token_ids = tl.load(input_ids + offs_seq, mask=n_ctx_mask, other=vob_end_id) id_mask = (token_ids >= vob_start_id) & (token_ids < vob_end_id) - token_ids = token_ids - vob_start_id + token_ids = (token_ids - vob_start_id).to(tl.int64) dim_mask = offs_d < hiden_size load_mask = id_mask[:, None] & dim_mask[None, :] store_mask = n_ctx_mask[:, None] & dim_mask[None, :] @@ -130,4 +130,4 @@ def embedding_old(input_ids, wte_weight, vob_start_id, vob_end_id): t2 = 0 MFLOPS = int(DIM * N_CTX * TEST_COUNT / t1 / 1000 / 1000) - print(f"TP={TP}, Diff={max_diff}, old_t:{t2:.5f}, new_t:{t1:.5f}, MFLOPS={MFLOPS}, SP={t2/t1:.5f}") + print(f"TP={TP}, Diff={max_diff}, old_t:{t2:.5f}, new_t:{t1:.5f}, MFLOPS={MFLOPS}, SP={t2 / t1:.5f}") diff --git a/lightllm/common/basemodel/triton_kernel/flashinfer_mla_plan.py b/lightllm/common/basemodel/triton_kernel/flashinfer_mla_plan.py new file mode 100644 index 0000000000..ed836800d1 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/flashinfer_mla_plan.py @@ -0,0 +1,421 @@ +import torch + +import triton +import triton.language as tl + + +# 这个文件不实现 MLA attention 本身,只在 decode CUDA graph replay 时重新生成 +# FlashInfer MLA 的 plan 表。FlashInfer run kernel 会从 +# decode_wrapper._int_workspace_buffer 中读取这些表: +# +# q_indptr / kv_indptr / q_len / kv_len / q_start / kv_start / kv_end: +# attention 阶段消费的 work item 描述。 +# partial_indptr: +# -1 表示直接写最终输出;否则写到 partial 输出。 +# merge_*: +# kernel 内 split-K merge 阶段消费的描述。 +# work_indptr: +# 
每个 cluster 在 work-item 表中的范围。 +# +# 这些数组的 offset 由 FlashInfer 首次 CPU plan 生成,并保存在 +# decode_wrapper._plan_info 中。这里保持相同的 buffer layout,只用 Triton 覆写 +# 数组内容。 + + +@triton.jit +def _fill_exact_mla_decode_plan_kernel( + int_buf_i32, + kv_indptr, + q_indptr_off: tl.constexpr, + kv_indptr_off: tl.constexpr, + partial_indptr_off: tl.constexpr, + merge_start_off: tl.constexpr, + merge_end_off: tl.constexpr, + merge_partial_start_off: tl.constexpr, + merge_partial_end_off: tl.constexpr, + merge_stride_off: tl.constexpr, + q_len_off: tl.constexpr, + kv_len_off: tl.constexpr, + q_start_off: tl.constexpr, + kv_start_off: tl.constexpr, + kv_end_off: tl.constexpr, + work_indptr_off: tl.constexpr, + batch_size: tl.constexpr, + num_clusters: tl.constexpr, + total_ctas: tl.constexpr, + BLOCK_B: tl.constexpr, + BLOCK_C: tl.constexpr, +): + # exact non-split 路径:每个 request 只有一个 work item。它直接使用真实 + # kv_indptr,也不会写 partial 输出,是最简单、最稳的 CUDA graph replay plan。 + cluster_offsets = tl.arange(0, BLOCK_C) + base_count = batch_size // num_clusters + extra_count = batch_size - base_count * num_clusters + work_indptr = cluster_offsets * base_count + tl.minimum(cluster_offsets, extra_count) + tl.store( + int_buf_i32 + work_indptr_off + cluster_offsets, + work_indptr, + mask=cluster_offsets <= num_clusters, + ) + + # 这条路径没有 work 会写 partial 输出,所以需要把所有 merge CTA 的 range + # 都置零,相当于禁用 merge 阶段。 + zeros = tl.full((BLOCK_C,), 0, tl.int32) + cta_offsets = tl.arange(0, BLOCK_C) + tl.store(int_buf_i32 + merge_start_off + cta_offsets, zeros, mask=cta_offsets < total_ctas) + tl.store(int_buf_i32 + merge_end_off + cta_offsets, zeros, mask=cta_offsets < total_ctas) + tl.store( + int_buf_i32 + merge_partial_start_off + cta_offsets, + zeros, + mask=cta_offsets < total_ctas, + ) + tl.store( + int_buf_i32 + merge_partial_end_off + cta_offsets, + zeros, + mask=cta_offsets < total_ctas, + ) + tl.store(int_buf_i32 + merge_stride_off + cta_offsets, zeros, mask=cta_offsets < total_ctas) + + batch_offsets = 
tl.arange(0, BLOCK_B) + valid_batch = batch_offsets < batch_size + cluster = batch_offsets % num_clusters + rank_in_cluster = batch_offsets // num_clusters + # 按 cluster 维度拉平 work item,这样 work_indptr[cluster] 可以指向一段 + # 连续 work range,和 FlashInfer scheduler 的约定保持一致。 + record_index = cluster * base_count + tl.minimum(cluster, extra_count) + rank_in_cluster + + kv_start = tl.load(kv_indptr + batch_offsets, mask=valid_batch, other=0) + kv_next = tl.load(kv_indptr + batch_offsets + 1, mask=valid_batch, other=0) + kv_len = kv_next - kv_start + + tl.store(int_buf_i32 + q_indptr_off + record_index, batch_offsets, mask=valid_batch) + tl.store(int_buf_i32 + kv_indptr_off + record_index, kv_start, mask=valid_batch) + tl.store(int_buf_i32 + partial_indptr_off + record_index, -1, mask=valid_batch) + tl.store(int_buf_i32 + q_len_off + record_index, 1, mask=valid_batch) + tl.store(int_buf_i32 + kv_len_off + record_index, kv_len, mask=valid_batch) + tl.store(int_buf_i32 + q_start_off + record_index, 0, mask=valid_batch) + tl.store(int_buf_i32 + kv_start_off + record_index, 0, mask=valid_batch) + tl.store(int_buf_i32 + kv_end_off + record_index, kv_len, mask=valid_batch) + + +@triton.jit +def _fill_fixed_chunk_mla_decode_plan_kernel( + int_buf_i32, + kv_indptr, + q_indptr_off: tl.constexpr, + kv_indptr_off: tl.constexpr, + partial_indptr_off: tl.constexpr, + merge_start_off: tl.constexpr, + merge_end_off: tl.constexpr, + merge_partial_start_off: tl.constexpr, + merge_partial_end_off: tl.constexpr, + merge_stride_off: tl.constexpr, + q_len_off: tl.constexpr, + kv_len_off: tl.constexpr, + q_start_off: tl.constexpr, + kv_start_off: tl.constexpr, + kv_end_off: tl.constexpr, + work_indptr_off: tl.constexpr, + batch_size: tl.constexpr, + num_heads: tl.constexpr, + cluster_size: tl.constexpr, + num_clusters: tl.constexpr, + total_ctas: tl.constexpr, + min_chunk_size: tl.constexpr, + BLOCK_B: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_C: tl.constexpr, + BLOCK_M: tl.constexpr, +): + # 
fixed-chunk split 路径,仅面向 decode q_len=1。只要 + # num_heads <= cluster_size * 64,每个 request 就只有一个 Q tile。长 KV + # 会被拆成多个 work item,再由 FlashInfer 已有的 kernel 内 persistent merge + # 合并。 + batch_offsets = tl.arange(0, BLOCK_B) + valid_batch = batch_offsets < batch_size + + kv_start = tl.load(kv_indptr + batch_offsets, mask=valid_batch, other=0) + kv_next = tl.load(kv_indptr + batch_offsets + 1, mask=valid_batch, other=0) + kv_len = kv_next - kv_start + + # 使用 GPU 上真实 kv_indptr 来模拟 FlashInfer CPU scheduler: + # kv_len_limit = f(total_kv_len / num_clusters) + # 这是它和 conservative max-kv plan 的关键区别:不依赖 host meta,也能适配 + # 长短混合 batch。 + total_kv_len = tl.load(kv_indptr + batch_size) - tl.load(kv_indptr) + avg_kv_len = (total_kv_len + num_clusters - 1) // num_clusters + chunk_hint = tl.where( + avg_kv_len <= 8, + 32, + tl.where( + avg_kv_len <= 16, + 64, + tl.where(avg_kv_len <= 32, 128, tl.where(avg_kv_len <= 64, 192, ((avg_kv_len + 255) // 256) * 256)), + ), + ) + chunk_size = tl.maximum(chunk_hint, min_chunk_size) + chunks = (kv_len + chunk_size - 1) // chunk_size + chunks = tl.where(valid_batch, tl.maximum(chunks, 1), 0) + + # work item 按 request-major 排列:request 0 的所有 chunk 在前,然后是 + # request 1,以此类推。下面的 work_indptr 会把这个扁平 work range 均匀 + # 切给 FlashInfer 的各个 cluster。 + chunk_prefix = tl.cumsum(chunks, 0) - chunks + total_chunks = tl.sum(chunks, 0) + + # partial_indptr 的单位是 packed Q row。decode q_len=1 时 row tile size + # 就是 num_heads;split request 的每个 KV chunk 都需要在 partial output + # workspace 中占用一个 row_tile_size 切片。 + row_tile_size: tl.constexpr = num_heads + partial_rows = tl.where(chunks > 1, chunks * row_tile_size, 0) + partial_base = tl.cumsum(partial_rows, 0) - partial_rows + + cluster_offsets = tl.arange(0, BLOCK_C) + work_start = (total_chunks * cluster_offsets) // num_clusters + tl.store( + int_buf_i32 + work_indptr_off + cluster_offsets, + work_start, + mask=cluster_offsets <= num_clusters, + ) + + # 先清空所有 merge CTA。只有 split request 会在下面写入有效 merge 描述。 + # FlashInfer 
会把 zero-length merge range 当成 no-op。 + zeros = tl.full((BLOCK_C,), 0, tl.int32) + cta_offsets = tl.arange(0, BLOCK_C) + tl.store(int_buf_i32 + merge_start_off + cta_offsets, zeros, mask=cta_offsets < total_ctas) + tl.store(int_buf_i32 + merge_end_off + cta_offsets, zeros, mask=cta_offsets < total_ctas) + tl.store( + int_buf_i32 + merge_partial_start_off + cta_offsets, + zeros, + mask=cta_offsets < total_ctas, + ) + tl.store( + int_buf_i32 + merge_partial_end_off + cta_offsets, + zeros, + mask=cta_offsets < total_ctas, + ) + tl.store(int_buf_i32 + merge_stride_off + cta_offsets, zeros, mask=cta_offsets < total_ctas) + + is_split = valid_batch & (chunks > 1) + split_count = tl.sum(is_split.to(tl.int32), 0) + # FlashInfer 启动 total_ctas = num_blks_x * num_blks_y 个 CTA。merge 表 + # 每个 CTA 只有一个 entry,所以 split request 需要共享这份容量。如果只有 + # 一个 request 很长,它可以使用较多 merge CTA;如果很多 request 都很长, + # 每个 request 分到的 row chunk 就会减少。 + merge_capacity = total_ctas // tl.maximum(split_count, 1) + # merge work 沿 packed row/head 方向切分。merge_chunks 越大,表示越多 + # CTA 并行 merge 不同 head range。 + merge_chunks = tl.minimum(num_heads, tl.minimum(chunks * cluster_size, merge_capacity)) + merge_chunks_for_div = tl.maximum(merge_chunks, 1) + row_chunk_size = (num_heads + merge_chunks_for_div - 1) // merge_chunks_for_div + merge_chunks = (num_heads + row_chunk_size - 1) // row_chunk_size + merge_chunks = tl.where(is_split, merge_chunks, 0) + merge_base = tl.cumsum(merge_chunks, 0) - merge_chunks + merge_offsets = tl.arange(0, BLOCK_M) + valid_merge = merge_offsets[None, :] < merge_chunks[:, None] + local_merge_start = merge_offsets[None, :] * row_chunk_size[:, None] + local_merge_end = tl.minimum(local_merge_start + row_chunk_size[:, None], row_tile_size) + merge_index = merge_base[:, None] + merge_offsets[None, :] + tl.store( + int_buf_i32 + merge_start_off + merge_index, + batch_offsets[:, None] * num_heads + local_merge_start, + mask=valid_merge, + ) + tl.store( + int_buf_i32 + merge_end_off + 
merge_index, + batch_offsets[:, None] * num_heads + local_merge_end, + mask=valid_merge, + ) + tl.store( + int_buf_i32 + merge_partial_start_off + merge_index, + partial_base[:, None] + local_merge_start, + mask=valid_merge, + ) + tl.store( + int_buf_i32 + merge_partial_end_off + merge_index, + partial_base[:, None] + chunks[:, None] * row_tile_size, + mask=valid_merge, + ) + tl.store(int_buf_i32 + merge_stride_off + merge_index, row_tile_size, mask=valid_merge) + + # 写入真正的 attention work-item 表。非 split request 保持 + # partial_indptr=-1,直接写最终输出。split request 把每个 chunk 的 partial + # 写到: + # partial_base + chunk_idx * row_tile_size + # 上面的 merge 表随后会归并这些 partial row。 + chunk_offsets = tl.arange(0, BLOCK_K) + work_id = chunk_prefix[:, None] + chunk_offsets[None, :] + valid_chunk = valid_batch[:, None] & (chunk_offsets[None, :] < chunks[:, None]) + chunk_start = chunk_offsets[None, :] * chunk_size + chunk_end = tl.minimum(chunk_start + chunk_size, kv_len[:, None]) + partial_indptr = tl.where( + chunks[:, None] > 1, + partial_base[:, None] + chunk_offsets[None, :] * row_tile_size, + -1, + ) + + tl.store(int_buf_i32 + q_indptr_off + work_id, batch_offsets[:, None], mask=valid_chunk) + tl.store(int_buf_i32 + kv_indptr_off + work_id, kv_start[:, None], mask=valid_chunk) + tl.store(int_buf_i32 + partial_indptr_off + work_id, partial_indptr, mask=valid_chunk) + tl.store(int_buf_i32 + q_len_off + work_id, 1, mask=valid_chunk) + tl.store(int_buf_i32 + kv_len_off + work_id, kv_len[:, None], mask=valid_chunk) + tl.store(int_buf_i32 + q_start_off + work_id, 0, mask=valid_chunk) + tl.store(int_buf_i32 + kv_start_off + work_id, chunk_start, mask=valid_chunk) + tl.store(int_buf_i32 + kv_end_off + work_id, chunk_end, mask=valid_chunk) + + +@torch.no_grad() +def fill_exact_mla_decode_plan(decode_wrapper, kv_indptr: torch.Tensor, batch_size: int) -> None: + plan_info = [int(v) for v in decode_wrapper._plan_info] + int_buf_i32 = decode_wrapper._int_workspace_buffer.view(torch.int32) + # 
FlashInfer plan offset 是 byte offset;Triton 这里按 int32 元素写入。 + offsets = [v // 4 for v in plan_info] + # plan_info[0] 是 grid.x:一个 cluster 内的 CTA 数。 + # plan_info[1] 是 grid.y:cluster 数。 + num_blks_x = plan_info[0] + num_blks_y = plan_info[1] + block_b = triton.next_power_of_2(max(batch_size, 1)) + block_c = triton.next_power_of_2(max(num_blks_x * num_blks_y, num_blks_y + 1, 1)) + + _fill_exact_mla_decode_plan_kernel[(1,)]( + int_buf_i32=int_buf_i32, + kv_indptr=kv_indptr, + q_indptr_off=offsets[2], + kv_indptr_off=offsets[3], + partial_indptr_off=offsets[4], + merge_start_off=offsets[5], + merge_end_off=offsets[6], + merge_partial_start_off=offsets[7], + merge_partial_end_off=offsets[8], + merge_stride_off=offsets[9], + q_len_off=offsets[10], + kv_len_off=offsets[11], + q_start_off=offsets[12], + kv_start_off=offsets[13], + kv_end_off=offsets[14], + work_indptr_off=offsets[15], + batch_size=batch_size, + num_clusters=num_blks_y, + total_ctas=num_blks_x * num_blks_y, + BLOCK_B=block_b, + BLOCK_C=block_c, + num_warps=8, + ) + return + + +def _mla_kv_len_limit_hint(avg_kv_len: int) -> int: + # 保持和 FlashInfer MLAPlan scheduler 相同的分段函数。这个值对齐后, + # 生成的 split plan 在长上下文性能上才会接近原生 CPU plan。 + if avg_kv_len <= 8: + return 32 + if avg_kv_len <= 16: + return 64 + if avg_kv_len <= 32: + return 128 + if avg_kv_len <= 64: + return 192 + return triton.cdiv(avg_kv_len, 256) * 256 + + +@torch.no_grad() +def fill_fixed_chunk_mla_decode_plan( + decode_wrapper, + kv_indptr: torch.Tensor, + batch_size: int, + num_heads: int, + max_kv_len: int, +) -> bool: + plan_info = [int(v) for v in decode_wrapper._plan_info] + # FlashInfer 启动 MLA run kernel 时使用 grid=(num_blks_x, num_blks_y)。 + # num_blks_x 同时也是 scheduler 中的 cluster_size。每个 CTA 覆盖 64 个 + # packed Q row,所以一个 cluster 覆盖 cluster_tile_q 个 row。 + num_blks_x = plan_info[0] + num_blks_y = plan_info[1] + num_ctas = num_blks_x * num_blks_y + cluster_tile_q = num_blks_x * 64 + row_tile_size = num_heads + + # 计入 plan 生成开销后,极短 decode 使用 exact 
non-split 更快。当前 split + # kernel 也有意限制在 q_len=1 且每个 request 只有一个 Q tile 的场景。 + if batch_size <= 0 or max_kv_len <= 512 or row_tile_size <= 0: + return False + if num_heads > cluster_tile_q or batch_size > num_ctas: + return False + + # 这个上界用于确定 BLOCK_K,同时保护 FlashInfer 固定的 + # max_total_num_works=16384 表容量。实际 chunk size 仍会在 Triton kernel + # 内根据真实 GPU kv_indptr 重新计算。 + min_chunk_size = _mla_kv_len_limit_hint(triton.cdiv(max_kv_len, num_blks_y)) + if min_chunk_size >= max_kv_len: + return False + + # FlashInfer MLAPlan 里 work-item 相关数组固定按 max_total_num_works=16384 + # 分配。这里用 graph shape 的最坏情况估算每个 request 最多会被拆成多少 + # chunk,如果 batch_size * max_chunks_per_request 超过 16384,就不能安全 + # 覆写这个 plan layout,必须回退 exact non-split。 + max_chunks_per_request = triton.cdiv(max_kv_len, min_chunk_size) + if batch_size * max_chunks_per_request > 16384: + return False + + int_buf_i32 = decode_wrapper._int_workspace_buffer.view(torch.int32) + # FlashInfer plan offset 是 byte offset;Triton 这里按 int32 元素写入。 + offsets = [v // 4 for v in plan_info] + block_b = triton.next_power_of_2(max(batch_size, 1)) + block_k = triton.next_power_of_2(max(max_chunks_per_request, 1)) + block_c = triton.next_power_of_2(max(num_ctas, num_blks_y + 1, 1)) + + _fill_fixed_chunk_mla_decode_plan_kernel[(1,)]( + int_buf_i32=int_buf_i32, + kv_indptr=kv_indptr, + q_indptr_off=offsets[2], + kv_indptr_off=offsets[3], + partial_indptr_off=offsets[4], + merge_start_off=offsets[5], + merge_end_off=offsets[6], + merge_partial_start_off=offsets[7], + merge_partial_end_off=offsets[8], + merge_stride_off=offsets[9], + q_len_off=offsets[10], + kv_len_off=offsets[11], + q_start_off=offsets[12], + kv_start_off=offsets[13], + kv_end_off=offsets[14], + work_indptr_off=offsets[15], + batch_size=batch_size, + num_heads=num_heads, + cluster_size=num_blks_x, + num_clusters=num_blks_y, + total_ctas=num_ctas, + min_chunk_size=min_chunk_size, + BLOCK_B=block_b, + BLOCK_K=block_k, + BLOCK_C=block_c, + 
BLOCK_M=triton.next_power_of_2(max(num_heads, 1)), + num_warps=8, + ) + return True + + +@torch.no_grad() +def fill_mla_decode_plan_for_cuda_graph( + decode_wrapper, + kv_indptr: torch.Tensor, + batch_size: int, + num_heads: int, + max_kv_len: int, +) -> str: + # 长 decode 优先使用 split plan,因为它能匹配 FlashInfer 的 split-K 并行度。 + # 短序列或当前不支持的 graph shape 回退到 exact plan 来保证正确性。 + use_fixed_chunk_split = fill_fixed_chunk_mla_decode_plan( + decode_wrapper, + kv_indptr, + batch_size, + num_heads, + max_kv_len, + ) + if use_fixed_chunk_split: + return "fixed_chunk_split" + + fill_exact_mla_decode_plan(decode_wrapper, kv_indptr, batch_size) + return "exact_non_split" diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/append_shared_expert_topk.py b/lightllm/common/basemodel/triton_kernel/fused_moe/append_shared_expert_topk.py new file mode 100644 index 0000000000..03099612bb --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/append_shared_expert_topk.py @@ -0,0 +1,151 @@ +import torch +import triton +import triton.language as tl +from typing import Optional, Tuple +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _append_fused_shared_experts_kernel( + topk_weights_ptr, + topk_ids_ptr, + shared_expert_gate_ptr, + out_topk_weights_ptr, + out_topk_ids_ptr, + token_num, + topk_num: tl.constexpr, + out_topk_num: tl.constexpr, + shared_expert_start_id: tl.constexpr, + num_fused_shared_experts: tl.constexpr, + shared_expert_gate_stride_0: tl.constexpr, + HAS_SHARED_EXPERT_GATE: tl.constexpr, + BLOCK_TOKEN: tl.constexpr, + TOPK_BLOCK: tl.constexpr, + SHARED_BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + token_offsets = pid * BLOCK_TOKEN + tl.arange(0, BLOCK_TOKEN) + token_mask = token_offsets < token_num + topk_offsets = tl.arange(0, TOPK_BLOCK) + topk_mask = topk_offsets < topk_num + shared_expert_offsets = tl.arange(0, SHARED_BLOCK) + shared_expert_mask = shared_expert_offsets < num_fused_shared_experts + + 
topk_in_offsets = token_offsets[:, None] * topk_num + topk_offsets[None, :] + topk_out_offsets = token_offsets[:, None] * out_topk_num + topk_offsets[None, :] + topk_valid_mask = token_mask[:, None] & topk_mask[None, :] + topk_ids = tl.load(topk_ids_ptr + topk_in_offsets, mask=topk_valid_mask, other=0) + topk_weights = tl.load(topk_weights_ptr + topk_in_offsets, mask=topk_valid_mask, other=0.0) + tl.store(out_topk_ids_ptr + topk_out_offsets, topk_ids, mask=topk_valid_mask) + tl.store(out_topk_weights_ptr + topk_out_offsets, topk_weights, mask=topk_valid_mask) + + shared_out_offsets = token_offsets[:, None] * out_topk_num + topk_num + shared_expert_offsets[None, :] + shared_valid_mask = token_mask[:, None] & shared_expert_mask[None, :] + shared_ids = shared_expert_start_id + shared_expert_offsets + tl.store(out_topk_ids_ptr + shared_out_offsets, shared_ids[None, :], mask=shared_valid_mask) + + shared_weights = tl.full((BLOCK_TOKEN, SHARED_BLOCK), 1.0, tl.float32) + if HAS_SHARED_EXPERT_GATE: + gate_offsets = token_offsets[:, None] * shared_expert_gate_stride_0 + shared_expert_offsets[None, :] + gate_vals = tl.load(shared_expert_gate_ptr + gate_offsets, mask=shared_valid_mask, other=0.0).to(tl.float32) + shared_weights = tl.sigmoid(gate_vals) + tl.store(out_topk_weights_ptr + shared_out_offsets, shared_weights, mask=shared_valid_mask) + + +def _get_append_fused_shared_experts_static_key( + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_fused_shared_experts: int, + shared_expert_gate: Optional[torch.Tensor] = None, +) -> dict: + return { + "topk_num": topk_ids.shape[1], + "num_fused_shared_experts": num_fused_shared_experts, + "has_shared_expert_gate": shared_expert_gate is not None, + "topk_weights_dtype": str(topk_weights.dtype), + "topk_ids_dtype": str(topk_ids.dtype), + } + + +def _get_append_fused_shared_experts_configs(): + block_token_choices = (4, 8, 16, 32, 64, 128, 256) + num_warps_choices = (1, 2, 4, 8) + return [ + {"BLOCK_TOKEN": block_token, 
"num_warps": num_warps} + for block_token in block_token_choices + for num_warps in num_warps_choices + ] + + +@torch.no_grad() +@autotune( + kernel_name="append_fused_shared_experts:v1", + configs_gen_func=_get_append_fused_shared_experts_configs, + static_key_func=_get_append_fused_shared_experts_static_key, + run_key_func=lambda topk_ids: topk_ids.shape[0], +) +def append_fused_shared_experts( + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_expert_start_id: int, + num_fused_shared_experts: int, + shared_expert_gate: Optional[torch.Tensor] = None, + run_config: Optional[dict] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert topk_weights.dim() == 2 and topk_ids.dim() == 2 + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert num_fused_shared_experts > 0 + + topk_weights = topk_weights.contiguous() + topk_ids = topk_ids.contiguous() + token_num, topk_num = topk_ids.shape + out_topk_num = topk_num + num_fused_shared_experts + out_topk_weights = torch.empty((token_num, out_topk_num), dtype=topk_weights.dtype, device=topk_weights.device) + out_topk_ids = torch.empty((token_num, out_topk_num), dtype=topk_ids.dtype, device=topk_ids.device) + + has_shared_expert_gate = shared_expert_gate is not None + if has_shared_expert_gate: + shared_expert_gate = shared_expert_gate.contiguous().view(token_num, -1) + assert shared_expert_gate.shape[1] == num_fused_shared_experts, "shared_expert_gate shape mismatch" + shared_expert_gate_stride_0 = shared_expert_gate.stride(0) + assert shared_expert_gate.stride(1) == 1, "shared_expert_gate last dim must be contiguous" + else: + shared_expert_gate_stride_0 = 0 + + if run_config is None: + if token_num <= 1: + run_config = {"BLOCK_TOKEN": 4, "num_warps": 2} + elif token_num <= 4: + run_config = {"BLOCK_TOKEN": 8, "num_warps": 4} + elif token_num <= 65536: + run_config = {"BLOCK_TOKEN": 32, "num_warps": 8} + elif token_num <= 131072: + run_config = {"BLOCK_TOKEN": 64, "num_warps": 8} + elif 
token_num <= 262144: + run_config = {"BLOCK_TOKEN": 128, "num_warps": 8} + else: + run_config = {"BLOCK_TOKEN": 256, "num_warps": 8} + + block_token = run_config["BLOCK_TOKEN"] + num_warps = run_config["num_warps"] + grid_num = triton.cdiv(token_num, block_token) + grid = (grid_num,) + _append_fused_shared_experts_kernel[grid]( + topk_weights, + topk_ids, + shared_expert_gate, + out_topk_weights, + out_topk_ids, + token_num, + topk_num=topk_num, + out_topk_num=out_topk_num, + shared_expert_start_id=shared_expert_start_id, + num_fused_shared_experts=num_fused_shared_experts, + shared_expert_gate_stride_0=shared_expert_gate_stride_0, + HAS_SHARED_EXPERT_GATE=has_shared_expert_gate, + BLOCK_TOKEN=block_token, + TOPK_BLOCK=triton.next_power_of_2(topk_num), + SHARED_BLOCK=triton.next_power_of_2(num_fused_shared_experts), + num_warps=num_warps, + ) + return out_topk_weights, out_topk_ids diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 638abbd6ca..c6eeb781dc 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -214,90 +214,188 @@ def moe_align1( @triton.jit -def moe_align_fused_kernel( +def moe_align_fused_small_token_kernel( topk_ids_ptr, # [token_num, topk] topk_weights_ptr, # [token_num, topk] expert_to_token_index_ptr, # [expert_num, token_num * topk] expert_to_weight_ptr, # [expert_num, token_num * topk] expert_token_num_ptr, # [expert_num] - token_num, - topk_num: tl.constexpr, + token_num_mul_topk, BLOCK_SIZE: tl.constexpr, + NUM_STAGE: tl.constexpr, ): - token_block = tl.program_id(0) - offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offs < token_num * topk_num + expert_id = tl.program_id(0) + block_offs = tl.arange(0, BLOCK_SIZE) + token_count = tl.full((), 0, tl.int32) + + for start in tl.range(0, token_num_mul_topk, 
BLOCK_SIZE, num_stages=NUM_STAGE): + raw_offs = start + block_offs + valid = raw_offs < token_num_mul_topk + load_offs = tl.where(valid, raw_offs, 0) + expert_ids = tl.load(topk_ids_ptr + load_offs, mask=valid, other=-1) + weights = tl.load(topk_weights_ptr + load_offs, mask=valid, other=0.0) + + expert_mask = (expert_ids == expert_id) & valid + expert_hits = tl.where(expert_mask, 1, 0) + write_pos = token_count + tl.cumsum(expert_hits, axis=0) - 1 + tl.store( + expert_to_token_index_ptr + expert_id * token_num_mul_topk + write_pos, + raw_offs, + mask=expert_mask, + ) + tl.store( + expert_to_weight_ptr + expert_id * token_num_mul_topk + write_pos, + weights, + mask=expert_mask, + ) + token_count += tl.sum(expert_hits, axis=0) - expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0) - weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0) - - # 用 atomic_add 给 expert 分配写位置 - write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask) - - # 按 token 顺序写 index 和 weight - tl.store( - expert_to_token_index_ptr + expert_ids * (token_num * topk_num) + write_pos, - offs, - mask=mask, - ) - tl.store( - expert_to_weight_ptr + expert_ids * (token_num * topk_num) + write_pos, - weights, - mask=mask, - ) + tl.store(expert_token_num_ptr + expert_id, token_count) def _get_moe_align_fused_static_key( + expert_token_num: torch.Tensor, topk_weights: torch.Tensor, ) -> dict: topk_num = topk_weights.shape[1] + expert_num = expert_token_num.shape[0] return { "topk_num": topk_num, + "expert_num": expert_num, } -def _get_moe_align_fused_configs(): - return [ +@autotune( + kernel_name="moe_align_fused_small:v2", + configs_gen_func=lambda: [ { "BLOCK_SIZE": bt, "num_warps": nw, + "NUM_STAGE": ns, } + for ns in [1, 2, 4, 6] for nw in [1, 2, 4, 8] - for bt in [128, 256, 512, 1024, 2048] - ] + for bt in [8, 16, 32, 64, 128, 256, 512, 1024, 2048] + ], + static_key_func=_get_moe_align_fused_static_key, + run_key_func=lambda topk_ids: topk_ids.shape[0], + 
mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"], +) +def _moe_align_fused_small_token( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + run_config: Optional[dict] = None, +): + if run_config is None: + token_num = topk_ids.shape[0] + if token_num <= 2: + run_config = {"BLOCK_SIZE": 16, "num_warps": 1, "NUM_STAGE": 1} + elif token_num <= 8: + run_config = {"BLOCK_SIZE": 64, "num_warps": 1, "NUM_STAGE": 1} + elif token_num <= 16: + run_config = {"BLOCK_SIZE": 128, "num_warps": 1, "NUM_STAGE": 1} + elif token_num < 32: + run_config = {"BLOCK_SIZE": 256, "num_warps": 2, "NUM_STAGE": 1} + elif token_num <= 64: + run_config = {"BLOCK_SIZE": 512, "num_warps": 4, "NUM_STAGE": 1} + elif token_num <= 128: + run_config = {"BLOCK_SIZE": 1024, "num_warps": 8, "NUM_STAGE": 1} + elif token_num <= 192: + run_config = {"BLOCK_SIZE": 512, "num_warps": 4, "NUM_STAGE": 1} + else: + run_config = {"BLOCK_SIZE": 2048, "num_warps": 8, "NUM_STAGE": 1} + token_num_mul_topk = topk_ids.numel() + expert_num = expert_token_num.shape[0] + block_size = run_config["BLOCK_SIZE"] + + moe_align_fused_small_token_kernel[(expert_num,)]( + topk_ids, + topk_weights, + expert_to_token_index, + expert_to_weight, + expert_token_num, + token_num_mul_topk, + BLOCK_SIZE=block_size, + NUM_STAGE=run_config["NUM_STAGE"], + num_warps=run_config["num_warps"], + ) + return expert_to_token_index, expert_to_weight, expert_token_num + + +@triton.jit +def moe_align_fused_atomic_kernel( + topk_ids_ptr, # [token_num, topk] + topk_weights_ptr, # [token_num, topk] + expert_to_token_index_ptr, # [expert_num, token_num * topk] + expert_to_weight_ptr, # [expert_num, token_num * topk] + expert_token_num_ptr, # [expert_num] + token_num_mul_topk, + BLOCK_SIZE: tl.constexpr, +): + block_id = tl.program_id(0) + offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + valid = offs < token_num_mul_topk + expert_id = tl.load(topk_ids_ptr + offs, mask=valid, 
other=0) + weight = tl.load(topk_weights_ptr + offs, mask=valid, other=0.0) + write_pos = tl.atomic_add(expert_token_num_ptr + expert_id, 1, mask=valid) + tl.store(expert_to_token_index_ptr + expert_id * token_num_mul_topk + write_pos, offs, mask=valid) + tl.store(expert_to_weight_ptr + expert_id * token_num_mul_topk + write_pos, weight, mask=valid) @autotune( - kernel_name="moe_align_fused:v1", - configs_gen_func=_get_moe_align_fused_configs, + kernel_name="moe_align_fused_atomic:v1", + configs_gen_func=lambda: [ + { + "BLOCK_SIZE": block_size, + "num_warps": num_warps, + } + for num_warps in [1, 2, 4, 8] + for block_size in [128, 256, 512, 1024, 2048] + ], static_key_func=_get_moe_align_fused_static_key, run_key_func=lambda topk_ids: topk_ids.shape[0], mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"], ) -def moe_align_fused( - expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None +def _moe_align_fused_atomic_token( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + run_config: Optional[dict] = None, ): - token_num, topk_num = topk_ids.shape if run_config is None: - run_config = {} - BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256) - num_warps = run_config.get("num_warps", 4) + run_config = {"BLOCK_SIZE": 128, "num_warps": 4} - grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),) - moe_align_fused_kernel[grid]( + token_num_mul_topk = topk_ids.numel() + expert_token_num.zero_() + moe_align_fused_atomic_kernel[(triton.cdiv(token_num_mul_topk, run_config["BLOCK_SIZE"]),)]( topk_ids, topk_weights, expert_to_token_index, expert_to_weight, expert_token_num, - token_num, - topk_num, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, + token_num_mul_topk, + BLOCK_SIZE=run_config["BLOCK_SIZE"], + num_warps=run_config["num_warps"], ) return expert_to_token_index, expert_to_weight, expert_token_num +def moe_align_fused(expert_to_token_index, 
expert_to_weight, expert_token_num, topk_ids, topk_weights): + token_num = topk_ids.shape[0] + if token_num <= 128: + _moe_align_fused_small_token(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights) + else: + # Expert rows may be unordered, but grouped matmul reuses this same + # mapping for up/down projections and writes back to original topk slots. + _moe_align_fused_atomic_token(expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights) + return expert_to_token_index, expert_to_weight, expert_token_num + + @triton.jit def moe_align2_kernel( experts_token_num_ptr, # [expert_num,] @@ -503,7 +601,7 @@ def grouped_matmul_kernel( else: a_scale_ptrs = token_scale_ptr + (a_m_index // topk_num)[:, None] - a_scale = tl.load(a_scale_ptrs, eviction_policy="evict_last") + a_scale = tl.load(a_scale_ptrs, mask=token_mask[:, None], other=0.0, eviction_policy="evict_last") b_scale = tl.load( weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bn[None, :] * weight_scale_stride1, eviction_policy="evict_last", @@ -957,7 +1055,7 @@ def fused_experts_impl( expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda") expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda") - expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda") + expert_to_token_num = torch.empty((E,), dtype=torch.int32, device="cuda") moe_align_fused( expert_to_token_index=expert_to_tokens, expert_to_weight=expert_to_weights, diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 2c6d013bd5..cb2e370cb9 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -1,10 +1,8 @@ """Fused MoE kernel.""" -import os import torch import triton 
-import triton.language as tl from typing import Any, Callable, Dict, Optional, Tuple -import torch.distributed as dist +from lightllm.distributed import dist_group_manager from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul_mix_quant_ep import ( @@ -15,11 +13,16 @@ tma_align_input_scale, ) from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather -from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank +from lightllm.utils.envs_utils import ( + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, +) from lightllm.common.triton_utils.autotuner import Autotuner -import numpy as np +from lightllm.utils.device_utils import is_sm100_gpu logger = init_logger(__name__) +_MEGA_MOE_STATES: Dict[Tuple[int, int, int, int], Dict[str, Any]] = {} +SUPPORTED_EP_EXPERT_DTYPES = ("deepgemm-fp8w8a8-b128", "deepgemm-fp4fp8-b32") try: from deep_ep import Buffer, EventOverlap @@ -31,6 +34,29 @@ HAS_DEEPGEMM = False +def get_ep_num_sms() -> int: + return getattr(dist_group_manager, "ep_num_sms", None) or 0 + + +def use_sm100_mega_moe(quant_method: Any) -> bool: + return is_sm100_gpu() and quant_method.method_name == "deepgemm-fp4fp8-b32" + + +def check_ep_expert_dtype(quant_method: Any): + expert_dtype = getattr(quant_method, "method_name", None) + if expert_dtype not in SUPPORTED_EP_EXPERT_DTYPES: + raise ValueError( + "EP MoE requires --expert_dtype to be one of ['fp8', 'fp4'], " + f"but the resolved fused_moe quant method is `{expert_dtype}`. " + "Please start with --expert_dtype fp8 or --expert_dtype fp4. " + "Note that --expert_dtype fp4 is only supported on SM100 GPUs." 
+ ) + if expert_dtype == "deepgemm-fp4fp8-b32" and not is_sm100_gpu(): + raise RuntimeError( + "--expert_dtype fp4 requires an SM100 GPU for EP MoE; " "please use --expert_dtype fp8 on non-SM100 GPUs." + ) + + def masked_group_gemm( recv_x: Tuple[torch.Tensor, torch.Tensor], masked_m: torch.Tensor, @@ -59,6 +85,138 @@ def masked_group_gemm( return gemm_out_b +def _get_mega_moe_cache_state(w13: Any, w2: Any): + state_key = ( + w13.weight.data_ptr(), + w13.weight_scale.data_ptr(), + w2.weight.data_ptr(), + w2.weight_scale.data_ptr(), + ) + return _MEGA_MOE_STATES.setdefault(state_key, {}) + + +def _get_mega_moe_weights(w13: Any, w2: Any, state: Dict[str, Any]): + if "weight_cache" not in state: + state["weight_cache"] = deep_gemm.transform_weights_for_mega_moe( + (w13.weight, w13.weight_scale), + (w2.weight, w2.weight_scale), + ) + return state["weight_cache"] + + +def _get_mega_moe_cumulative_stats(num_local_experts: int, device: torch.device, state: Dict[str, Any]): + stats = state.get("stats") + if stats is None or stats.numel() != num_local_experts or stats.device != device: + stats = torch.zeros((num_local_experts,), device=device, dtype=torch.int32) + state["stats"] = stats + return stats + + +def mega_moe_impl( + hidden_states: torch.Tensor, + w13: Any, + w2: Any, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + quant_method: Any, +): + if not (HAS_DEEPGEMM and hasattr(deep_gemm, "fp8_fp4_mega_moe")): + raise RuntimeError("deep_gemm does not provide fp8-fp4 Mega MoE kernel") + + from deep_gemm.utils import per_token_cast_to_fp8 + + buffer = getattr(dist_group_manager, "ep_mega_moe_buffer", None) + if buffer is None: + raise RuntimeError("SM100 Mega MoE requires dist_group_manager.ep_mega_moe_buffer to be initialized") + + num_tokens = hidden_states.shape[0] + if num_tokens > buffer.num_max_tokens_per_rank: + raise RuntimeError( + f"Mega MoE got {num_tokens} tokens, exceeding num_max_tokens_per_rank={buffer.num_max_tokens_per_rank}" + ) + + 
qinput_tensor = per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=quant_method.block_size, + use_packed_ue8m0=True, + ) + state = _get_mega_moe_cache_state(w13, w2) + l1_weights, l2_weights = _get_mega_moe_weights(w13, w2, state) + stats = _get_mega_moe_cumulative_stats(w13.weight.shape[0], hidden_states.device, state) + buffer.x[:num_tokens].copy_(qinput_tensor[0]) + buffer.x_sf[:num_tokens].copy_(qinput_tensor[1]) + buffer.topk_idx[:num_tokens].copy_(topk_ids) + buffer.topk_weights[:num_tokens].copy_(topk_weights) + + output = torch.empty_like(hidden_states) + deep_gemm.fp8_fp4_mega_moe( + output, + l1_weights, + l2_weights, + buffer, + cumulative_local_expert_recv_stats=stats, + ) + return output + + +def quantize_fused_experts_input( + hidden_states: torch.Tensor, + w13: Any, + quant_method: Any, +): + check_ep_expert_dtype(quant_method) + if use_sm100_mega_moe(quant_method): + from deep_gemm.utils import per_token_cast_to_fp8 + + return per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=quant_method.block_size, + use_packed_ue8m0=True, + ) + + block_size_k = 0 + if w13.weight.ndim == 3: + block_size_k = w13.weight.shape[2] // w13.weight_scale.shape[2] + assert block_size_k == 128, "block_size_k must be 128" + return per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w13.weight.dtype) + + +def fused_experts( + hidden_states: torch.Tensor, + w13: Any, + w2: Any, + topk_weights: torch.Tensor, + topk_idx: torch.Tensor, + num_experts: int, + quant_method: Any, + is_prefill: Optional[bool], + previous_event: Optional[Any] = None, +): + check_ep_expert_dtype(quant_method) + if use_sm100_mega_moe(quant_method): + return mega_moe_impl(hidden_states, w13, w2, topk_weights, topk_idx, quant_method) + + buffer = dist_group_manager.ep_buffer if is_prefill else dist_group_manager.ep_low_latency_buffer + return fused_experts_impl( + hidden_states=hidden_states, + w1=w13.weight, + w2=w2.weight, + topk_weights=topk_weights, + 
topk_idx=topk_idx, + num_experts=num_experts, + buffer=buffer, + is_prefill=is_prefill, + use_fp8_w8a8=True, + use_fp8_all2all=True, + use_int8_w8a16=False, + w1_scale=w13.weight_scale, + w2_scale=w2.weight_scale, + previous_event=previous_event, + ) + + def fused_experts_impl( hidden_states: torch.Tensor, # [M, K] w1: torch.Tensor, # [group, N, K] @@ -66,14 +224,14 @@ def fused_experts_impl( topk_weights: torch.Tensor, # [M, topk] topk_idx: torch.Tensor, # [M, topk] num_experts: int, - buffer: "Buffer", + buffer: Any, is_prefill: bool, use_fp8_w8a8: bool = False, use_fp8_all2all: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, - previous_event: Optional["EventOverlap"] = None, + previous_event: Optional[Any] = None, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -99,39 +257,27 @@ def fused_experts_impl( combined_x = None if is_prefill: qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w1.dtype) - - # get_dispatch_layout - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - topk_idx, num_experts, previous_event=previous_event, async_finish=False, allocate_on_comm_stream=False - ) - + allocate_on_comm_stream = previous_event is not None # normal dispatch # recv_x [recive_num_tokens, hidden] recv_x_scale [recive_num_tokens, hidden // block_size] # recv_topk_idx [recive_num_tokens, topk_num] # recv_topk_weights [recive_num_tokens, topk_num] # num_recv_tokens_per_expert_list list [cur_node_expert_num] padding with expert_alignment=128 - recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( + recv_x, recv_topk_idx, recv_topk_weights, handle, _ = buffer.dispatch( (qinput_tensor, input_scale), topk_idx=topk_idx, topk_weights=topk_weights, - 
num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=previous_event, - async_finish=False, - allocate_on_comm_stream=False, + num_experts=num_experts, + num_max_tokens_per_rank=get_deepep_num_max_dispatch_tokens_per_rank_prefill(), expert_alignment=128, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream, + do_cpu_sync=True, + do_handle_copy=False, ) # scatter - all_tokens = sum(num_recv_tokens_per_expert_list) # calcu padding all nums. + all_tokens = sum(handle.num_recv_tokens_per_expert_list) # calcu padding all nums. # gather_out shape [recive_num_tokens, hidden] gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype) if all_tokens > 0: @@ -149,7 +295,7 @@ def fused_experts_impl( output_index = torch.empty_like(recv_topk_idx) num_recv_tokens_per_expert = torch.tensor( - num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" + handle.num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" ).cuda(non_blocking=True) expert_start_loc = torch.empty_like(num_recv_tokens_per_expert) @@ -169,7 +315,7 @@ def fused_experts_impl( # groupgemm (contiguous layout) gemm_out_a = torch.empty((all_tokens, N), device=hidden_states.device, dtype=hidden_states.dtype) input_tensor[1] = tma_align_input_scale(input_tensor[1]) - _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) + deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) # silu_and_mul_fwd + qaunt # TODO fused kernel @@ -183,7 +329,7 @@ def fused_experts_impl( # groupgemm (contiguous layout) gemm_out_b = torch.empty((all_tokens, K), device=hidden_states.device, dtype=hidden_states.dtype) - _deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, 
m_indices) + deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices) # gather and local reduce ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out) @@ -202,13 +348,12 @@ def fused_experts_impl( gather_out, handle, topk_weights=None, - async_finish=False, previous_event=previous_event, - allocate_on_comm_stream=False, + allocate_on_comm_stream=allocate_on_comm_stream, ) else: # low latency dispatch - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() expected_m = triton.cdiv(hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1], num_experts) recv_x, masked_m, handle, event, hook = buffer.low_latency_dispatch( hidden_states, @@ -228,7 +373,7 @@ def fused_experts_impl( return combined_x -def _deepgemm_grouped_fp8_nt_contiguous( +def deepgemm_grouped_fp8_nt_contiguous( input_tuple: Tuple[torch.Tensor, torch.Tensor], w_tuple: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor, @@ -255,3 +400,22 @@ def _deepgemm_grouped_fp8_nt_masked( if hasattr(deep_gemm, "m_grouped_gemm_fp8_fp8_bf16_nt_masked"): return deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(input_tuple, w_tuple, out, masked_m, expected_m) raise RuntimeError("deep_gemm does not provide grouped_gemm_fp8 NT contiguous GEMM kernel in this version") + + +def deepgemm_grouped_fp8_fp4_nt_contiguous( + input_tuple: Tuple[torch.Tensor, torch.Tensor], + w_tuple: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor, + grouped_layout: torch.Tensor, + use_psum_layout: bool = False, +): + if HAS_DEEPGEMM and hasattr(deep_gemm, "m_grouped_fp8_fp4_gemm_nt_contiguous"): + return deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous( + input_tuple, + w_tuple, + out, + grouped_layout, + use_psum_layout=use_psum_layout, + recipe=(1, 1, 32), + ) + raise RuntimeError("deep_gemm does not provide grouped fp8-fp4 NT contiguous GEMM 
kernel") diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py index d7bcc17743..a63d92692e 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py @@ -3,6 +3,7 @@ import triton import triton.language as tl from lightllm.common.triton_utils.autotuner import autotune +from lightllm.utils.config_utils import ffn_use_tanh_approximate_gelu @triton.jit @@ -23,6 +24,7 @@ def _silu_and_mul_kernel_fast( NEED_MASK: tl.constexpr, layout: tl.constexpr = "blocked", # "blocked" or "interleaved" USE_LIMIT_AND_ALPHA: tl.constexpr = False, + USE_TANH_APPROXIMATE_GELU: tl.constexpr = False, ): stride_input_m = tl.cast(stride_input_m, dtype=tl.int64) stride_output_m = tl.cast(stride_output_m, dtype=tl.int64) @@ -74,7 +76,14 @@ def _silu_and_mul_kernel_fast( mask=mask, ) else: - gate = gate / (1 + tl.exp(-gate)) + if USE_TANH_APPROXIMATE_GELU: + # tanh-approx GELU, matching Gemma's gelu_pytorch_tanh MLP. 
+ gate_cubed = gate * gate * gate + tanh_arg = 0.7978845608028654 * (gate + 0.044715 * gate_cubed) + tanh_val = 2.0 / (1.0 + tl.exp(-2.0 * tanh_arg)) - 1.0 + gate = 0.5 * gate * (1.0 + tanh_val) + else: + gate = gate / (1 + tl.exp(-gate)) gate = gate.to(input_ptr.dtype.element_ty) tl.store( @@ -106,9 +115,14 @@ def _get_silu_and_mul_static_key(input: torch.Tensor, output: torch.Tensor): mutates_args=["output"], ) def silu_and_mul_fwd( - input: torch.Tensor, output: torch.Tensor, layout="blocked", limit=None, alpha=None, run_config=None + input: torch.Tensor, + output: torch.Tensor, + layout="blocked", + limit=None, + alpha=None, + run_config=None, ): - assert input.is_contiguous() + assert input.stride(-1) == 1 assert output.is_contiguous() assert (limit is None and alpha is None) or (limit is not None and alpha is not None) @@ -157,5 +171,6 @@ def silu_and_mul_fwd( num_warps=num_warps, layout=layout, USE_LIMIT_AND_ALPHA=USE_LIMIT_AND_ALPHA, + USE_TANH_APPROXIMATE_GELU=ffn_use_tanh_approximate_gelu(), ) return diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py index d2c44b2953..aa91f15ed9 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py @@ -3,6 +3,8 @@ import triton import triton.language as tl +from lightllm.utils.config_utils import ffn_use_tanh_approximate_gelu + @triton.jit def _silu_and_mul_post_quant_kernel( @@ -24,6 +26,7 @@ def _silu_and_mul_post_quant_kernel( fp8_min, BLOCK_N: tl.constexpr, NUM_STAGE: tl.constexpr, + USE_TANH_APPROXIMATE_GELU: tl.constexpr = False, ): expert_id = tl.program_id(2) token_id = tl.program_id(1) @@ -48,7 +51,13 @@ def _silu_and_mul_post_quant_kernel( for token_index in tl.range(token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE): gate = 
tl.load(input_ptr_offs + token_index * stride_input_1, mask=offs_in_d < size_n, other=0.0).to(tl.float32) up = tl.load(input_ptr_offs + token_index * stride_input_1 + size_n, mask=offs_in_d < size_n, other=0.0) - gate = gate / (1 + tl.exp(-gate)) + if USE_TANH_APPROXIMATE_GELU: + gate_cubed = gate * gate * gate + tanh_arg = 0.7978845608028654 * (gate + 0.044715 * gate_cubed) + tanh_val = 2.0 / (1.0 + tl.exp(-2.0 * tanh_arg)) - 1.0 + gate = 0.5 * gate * (1.0 + tanh_val) + else: + gate = gate / (1 + tl.exp(-gate)) gate = gate.to(input_ptr.dtype.element_ty) gate_up = up * gate _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) @@ -66,7 +75,11 @@ def _silu_and_mul_post_quant_kernel( def silu_and_mul_masked_post_quant_fwd( - input: torch.Tensor, output: torch.Tensor, output_scale: torch.Tensor, quant_group_size: int, masked_m: torch.Tensor + input: torch.Tensor, + output: torch.Tensor, + output_scale: torch.Tensor, + quant_group_size: int, + masked_m: torch.Tensor, ): """ input shape [expert_num, token_num_padded, hidden_dim] @@ -122,6 +135,7 @@ def silu_and_mul_masked_post_quant_fwd( fp8_min, BLOCK_N=BLOCK_N, NUM_STAGE=NUM_STAGES, + USE_TANH_APPROXIMATE_GELU=ffn_use_tanh_approximate_gelu(), num_warps=num_warps, ) return diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py index 59d1f825a3..1c01cbd638 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py @@ -17,16 +17,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os import torch from lightllm.utils.sgl_utils import sgl_ops -from lightllm.utils.light_utils import light_ops from typing import Callable, List, Optional, Tuple from lightllm.common.basemodel.triton_kernel.fused_moe.softmax_topk import softmax_topk from lightllm.common.triton_utils.autotuner import Autotuner -use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"] - def fused_topk( hidden_states: torch.Tensor, @@ -127,44 +123,6 @@ def biased_grouped_topk( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) -# This is used by the Deepseek-V2 model -def cuda_grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, - scoring_func: str = "softmax", -): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - assert light_ops is not None, "lightllm_kernel is not installed." 
- - num_tokens = gating_output.shape[0] - topk_weights = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.float32) - topk_indices = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.int32) - token_expert_indices = torch.empty(num_tokens, topk_group, device=hidden_states.device, dtype=torch.int32) - group_scores = torch.empty(num_tokens, num_expert_group, device=hidden_states.device, dtype=torch.float32) - if correction_bias is None: - correction_bias = torch.zeros_like(gating_output, dtype=torch.float32) - light_ops.grouped_topk( - topk_weights, - correction_bias, - topk_indices, - token_expert_indices, - gating_output.float(), - num_expert_group, - topk_group, - topk, - renormalize, - scoring_func, - group_scores, - ) - - return topk_weights, topk_indices - - def select_experts( hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -184,34 +142,22 @@ def select_experts( if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None - if use_cuda_grouped_topk: - topk_weights, topk_ids = cuda_grouped_topk( - hidden_states=hidden_states, - gating_output=router_logits, - correction_bias=correction_bias, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - ) - else: - group_score_topk_num = 1 - # for deepseek v3 - if topk_group == 4 and num_expert_group == 8 and top_k == 8: - group_score_topk_num = 2 - - topk_weights, topk_ids = triton_grouped_topk( - hidden_states=hidden_states, - gating_output=router_logits, - correction_bias=correction_bias, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - group_score_used_topk_num=group_score_topk_num, - ) + group_score_topk_num = 1 + # for deepseek v3 + if topk_group == 4 and num_expert_group == 8 and top_k == 8: + group_score_topk_num = 2 + + topk_weights, topk_ids = triton_grouped_topk( + 
hidden_states=hidden_states, + gating_output=router_logits, + correction_bias=correction_bias, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + group_score_used_topk_num=group_score_topk_num, + ) elif custom_routing_function is None: topk_weights, topk_ids = fused_topk( diff --git a/lightllm/common/basemodel/triton_kernel/gather_token_id.py b/lightllm/common/basemodel/triton_kernel/gather_token_id.py index f8181d73c0..16c7528b33 100644 --- a/lightllm/common/basemodel/triton_kernel/gather_token_id.py +++ b/lightllm/common/basemodel/triton_kernel/gather_token_id.py @@ -141,6 +141,75 @@ def gather_token(req_to_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, b return output +@triton.jit +def _fwd_kernel_gather_prefill_decode_mixed( + input_ids, + req_to_next_token_ids, + req_to_next_token_ids_stride, + req_to_next_token_ids_stride_1, + b_req_idx, + b_mtp_index, + b_is_decode_req, + b_prefill_start_loc, + num_size, + BLOCK: tl.constexpr, +): + block_index = tl.program_id(0) + block_range = block_index * BLOCK + tl.arange(0, BLOCK) + block_mask = block_range < num_size + cur_req_idx = tl.load(b_req_idx + block_range, mask=block_mask) + cur_mtp_index = tl.load(b_mtp_index + block_range, mask=block_mask) + cur_next_token_id = tl.load( + req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index, mask=block_mask + ) + cur_is_decode_req = tl.load(b_is_decode_req + block_range, mask=block_mask, other=False) + cur_prefill_start_loc = tl.load(b_prefill_start_loc + block_range, mask=block_mask, other=-1) + + tl.store(input_ids + cur_prefill_start_loc, cur_next_token_id, mask=block_mask & cur_is_decode_req) + return + + +def gather_token_prefill_decode_mixed( + input_ids: torch.Tensor, + req_to_next_token_ids: torch.Tensor, + b_req_idx: torch.Tensor, + b_mtp_index: torch.Tensor, + b_is_decode_req: torch.Tensor, + b_prefill_start_loc: torch.Tensor, +): + """ + This 
function is used to gather the token_info(CPU tensor) to the token_info(GPU tensor). + Args: + input_ids: (batch_size,) + req_to_next_token_ids: (max_req_num, max_mtp_step) + b_req_idx: (batch_size,) + b_mtp_index: (batch_size,) + b_is_decode_req: (batch_size,) + b_prefill_start_loc: (batch_size,) + Returns: + input_ids: + """ + batch_size = b_req_idx.shape[0] + BLOCK = 256 + grid = (triton.cdiv(batch_size, BLOCK),) + num_warps = 1 + _fwd_kernel_gather_prefill_decode_mixed[grid]( + input_ids=input_ids, + req_to_next_token_ids=req_to_next_token_ids, + req_to_next_token_ids_stride=req_to_next_token_ids.stride(0), + req_to_next_token_ids_stride_1=req_to_next_token_ids.stride(1), + b_req_idx=b_req_idx, + b_mtp_index=b_mtp_index, + b_is_decode_req=b_is_decode_req, + b_prefill_start_loc=b_prefill_start_loc, + num_size=batch_size, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return input_ids + + def test_scatter_token_to_cpu(): batch_size = 30 req_to_token_info = torch.zeros((1000, 1), dtype=torch.float32, pin_memory=True) @@ -166,6 +235,172 @@ def test_gather_token(): print("test_gather_token passed") +def _ref_gather_token_prefill_decode_mixed( + input_ids: torch.Tensor, + req_to_next_token_ids: torch.Tensor, + b_req_idx: torch.Tensor, + b_mtp_index: torch.Tensor, + b_is_decode_req: torch.Tensor, + b_prefill_start_loc: torch.Tensor, +) -> torch.Tensor: + out = input_ids.clone() + table = req_to_next_token_ids.detach().cpu() + req_idx_cpu = b_req_idx.detach().cpu() + mtp_cpu = b_mtp_index.detach().cpu() + is_decode_cpu = b_is_decode_req.detach().cpu() + start_loc_cpu = b_prefill_start_loc.detach().cpu() + for i in range(req_idx_cpu.shape[0]): + if is_decode_cpu[i].item(): + rid = int(req_idx_cpu[i].item()) + mid = int(mtp_cpu[i].item()) + loc = int(start_loc_cpu[i].item()) + out[loc] = table[rid, mid] + return out + + +def _run_gather_token_prefill_decode_mixed_case( + input_ids: torch.Tensor, + req_to_next_token_ids: torch.Tensor, + b_req_idx: 
torch.Tensor, + b_mtp_index: torch.Tensor, + b_is_decode_req: torch.Tensor, + b_prefill_start_loc: torch.Tensor, +): + input_cuda = input_ids.clone().cuda() + req_table = req_to_next_token_ids.cuda() + b_req_idx_cuda = b_req_idx.cuda() + b_mtp_index_cuda = b_mtp_index.cuda() + b_is_decode_cuda = b_is_decode_req.cuda() + b_start_loc_cuda = b_prefill_start_loc.cuda() + + expected = _ref_gather_token_prefill_decode_mixed( + input_cuda, + req_table, + b_req_idx_cuda, + b_mtp_index_cuda, + b_is_decode_cuda, + b_start_loc_cuda, + ) + gather_token_prefill_decode_mixed( + input_cuda, + req_table, + b_req_idx_cuda, + b_mtp_index_cuda, + b_is_decode_cuda, + b_start_loc_cuda, + ) + diff = (input_cuda - expected).abs().max() + assert diff < 1e-6, f"max diff {diff.item()}" + + +def test_gather_token_prefill_decode_mixed_decode_only(): + """仅 decode 行:按 b_prefill_start_loc 写入 req_to_next_token_ids 中的 next token。""" + req_to_next_token_ids = torch.zeros((32, 4), dtype=torch.int64, device="cuda") + req_to_next_token_ids[3, 0] = 42 + req_to_next_token_ids[7, 0] = 99 + req_to_next_token_ids[11, 2] = 17 + + input_ids = torch.tensor([0, 0, 0], dtype=torch.int64, device="cuda") + b_req_idx = torch.tensor([3, 7, 11], dtype=torch.int32, device="cuda") + b_mtp_index = torch.tensor([0, 0, 2], dtype=torch.int32, device="cuda") + b_is_decode_req = torch.tensor([True, True, True], dtype=torch.bool, device="cuda") + b_prefill_start_loc = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda") + + _run_gather_token_prefill_decode_mixed_case( + input_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_is_decode_req, b_prefill_start_loc + ) + print("test_gather_token_prefill_decode_mixed_decode_only passed") + + +def test_gather_token_prefill_decode_mixed_mixed_batch(): + """prefill + decode 混合:仅 decode 位置被覆盖,prefill token 保持不变。""" + req_to_next_token_ids = torch.zeros((16, 2), dtype=torch.int64, device="cuda") + req_to_next_token_ids[5, 0] = 9001 + + # prefill [10,11,12] | decode placeholder 
| prefill [20,21] + input_ids = torch.tensor([10, 11, 12, -1, 20, 21], dtype=torch.int64, device="cuda") + b_req_idx = torch.tensor([0, 5, 1], dtype=torch.int32, device="cuda") + b_mtp_index = torch.tensor([0, 0, 0], dtype=torch.int32, device="cuda") + b_is_decode_req = torch.tensor([False, True, False], dtype=torch.bool, device="cuda") + b_q_seq_len = torch.tensor([3, 1, 2], dtype=torch.int32, device="cuda") + b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len + + _run_gather_token_prefill_decode_mixed_case( + input_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_is_decode_req, b_prefill_start_loc + ) + print("test_gather_token_prefill_decode_mixed_mixed_batch passed") + + +def test_gather_token_prefill_decode_mixed_prefill_only_unchanged(): + """无 decode 行时 input_ids 不应被修改。""" + req_to_next_token_ids = torch.full((8, 1), 777, dtype=torch.int64, device="cuda") + input_ids = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda") + b_req_idx = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda") + b_mtp_index = torch.zeros(3, dtype=torch.int32, device="cuda") + b_is_decode_req = torch.zeros(3, dtype=torch.bool, device="cuda") + b_q_seq_len = torch.tensor([2, 1, 1], dtype=torch.int32, device="cuda") + b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len + + before = input_ids.clone() + gather_token_prefill_decode_mixed( + input_ids, + req_to_next_token_ids, + b_req_idx, + b_mtp_index, + b_is_decode_req, + b_prefill_start_loc, + ) + assert torch.equal(input_ids, before) + print("test_gather_token_prefill_decode_mixed_prefill_only_unchanged passed") + + +def test_gather_token_prefill_decode_mixed_large_batch(): + """batch_size > 256,覆盖多 block 的 triton grid。""" + batch_size = 300 + max_req = 400 + req_to_next_token_ids = torch.arange(max_req * 2, dtype=torch.int64, device="cuda").view(max_req, 2) + input_ids = torch.zeros(batch_size, dtype=torch.int64, device="cuda") + b_req_idx = 
torch.arange(10, 10 + batch_size, dtype=torch.int32, device="cuda") + b_mtp_index = (b_req_idx % 2).to(torch.int32) + b_is_decode_req = torch.ones(batch_size, dtype=torch.bool, device="cuda") + b_prefill_start_loc = torch.arange(batch_size, dtype=torch.int32, device="cuda") + + _run_gather_token_prefill_decode_mixed_case( + input_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_is_decode_req, b_prefill_start_loc + ) + print("test_gather_token_prefill_decode_mixed_large_batch passed") + + +def test_gather_token_prefill_decode_mixed_roundtrip_with_scatter(): + """scatter_token 写入后,mixed gather 能读回同一 next token。""" + batch_size = 16 + req_to_next_token_ids = torch.zeros((64, 3), dtype=torch.float32, pin_memory=True) + token_info = torch.arange(100, 100 + batch_size, dtype=torch.float32, device="cuda") + b_req_idx = torch.arange(4, 4 + batch_size, dtype=torch.int32, device="cuda") + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + scatter_token(token_info, req_to_next_token_ids, b_req_idx, b_mtp_index) + + input_ids = torch.zeros(batch_size, dtype=torch.int64, device="cuda") + b_is_decode_req = torch.ones(batch_size, dtype=torch.bool, device="cuda") + b_prefill_start_loc = torch.arange(batch_size, dtype=torch.int32, device="cuda") + + gather_token_prefill_decode_mixed( + input_ids, + req_to_next_token_ids, + b_req_idx, + b_mtp_index, + b_is_decode_req, + b_prefill_start_loc, + ) + assert torch.equal(input_ids, token_info.to(torch.int64)) + print("test_gather_token_prefill_decode_mixed_roundtrip_with_scatter passed") + + if __name__ == "__main__": test_scatter_token_to_cpu() test_gather_token() + test_gather_token_prefill_decode_mixed_decode_only() + test_gather_token_prefill_decode_mixed_mixed_batch() + test_gather_token_prefill_decode_mixed_prefill_only_unchanged() + test_gather_token_prefill_decode_mixed_large_batch() + test_gather_token_prefill_decode_mixed_roundtrip_with_scatter() diff --git 
a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py index 8f9172b552..f066766b15 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py @@ -19,7 +19,7 @@ def _gen_cumsum_pad0_kernel( for start_index in range(0, size, BLOCK): current_offs = start_index + offs - in_data = tl.load(b_q_seq_len + offs, mask=current_offs < size, other=0) + in_data = tl.load(b_q_seq_len + current_offs, mask=current_offs < size, other=0) in_data = tl.cumsum(in_data) + start_value start_value = tl.max(in_data, 0) tl.store(b1_cu_q_seq_len + current_offs + 1, in_data, mask=current_offs < size) @@ -30,7 +30,7 @@ def _gen_cumsum_pad0_kernel( start_value = tl.cast(0, tl.int64) for start_index in range(0, size, BLOCK): current_offs = start_index + offs - in_data = tl.load(b_kv_seq_len + offs * b_kv_seq_len_stride_0, mask=current_offs < size, other=0) + in_data = tl.load(b_kv_seq_len + current_offs * b_kv_seq_len_stride_0, mask=current_offs < size, other=0) in_data = tl.cumsum(in_data) + start_value start_value = tl.max(in_data, 0) tl.store(b1_cu_kv_seq_len + current_offs + 1, in_data, mask=current_offs < size) diff --git a/lightllm/common/basemodel/triton_kernel/kv_move.py b/lightllm/common/basemodel/triton_kernel/kv_move.py new file mode 100644 index 0000000000..d2ce0764b7 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/kv_move.py @@ -0,0 +1,68 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _copy_kv_buffer_to_kv_buffer( + mem_num, + src_mem_index, + dst_mem_index, + kv_buffer, + kv_buffer_stride_l, + kv_buffer_stride_s, + kv_buffer_stride_d, + kv_buffer_tail_dim, + BLOCK: tl.constexpr, +): + layer_index = tl.program_id(0).to(tl.int64) + start_index = tl.program_id(1).to(tl.int64) + grid_num = tl.num_programs(1).to(tl.int64) + + kv_buffer_stride_l = tl.cast(kv_buffer_stride_l, 
dtype=tl.int64) + kv_buffer_stride_s = tl.cast(kv_buffer_stride_s, dtype=tl.int64) + kv_buffer_stride_d = tl.cast(kv_buffer_stride_d, dtype=tl.int64) + + for i in range(start_index, mem_num, grid_num): + src_mem = tl.load(src_mem_index + i) + dst_mem = tl.load(dst_mem_index + i) + for j in range(tl.cdiv(kv_buffer_tail_dim, BLOCK)): + offs = j * BLOCK + tl.arange(0, BLOCK) + mask = offs < kv_buffer_tail_dim + kv_buffer_data = tl.load( + kv_buffer + layer_index * kv_buffer_stride_l + src_mem * kv_buffer_stride_s + offs, mask=mask + ) + tl.store( + kv_buffer + layer_index * kv_buffer_stride_l + dst_mem * kv_buffer_stride_s + offs, + kv_buffer_data, + mask=mask, + ) + return + + +def copy_kv_buffer_to_kv_buffer( + src_mem_index: torch.Tensor, + dst_mem_index: torch.Tensor, + kv_buffer: torch.Tensor, +): + assert len(src_mem_index) == len(dst_mem_index) + assert src_mem_index.is_cuda and dst_mem_index.is_cuda and kv_buffer.is_cuda + kv_buffer = kv_buffer.view(kv_buffer.shape[0], kv_buffer.shape[1], -1).view(dtype=torch.uint8) + BLOCK = 4096 + layer_num = kv_buffer.shape[0] + grid = ( + layer_num, + 1024, + ) + _copy_kv_buffer_to_kv_buffer[grid]( + mem_num=len(src_mem_index), + src_mem_index=src_mem_index, + dst_mem_index=dst_mem_index, + kv_buffer=kv_buffer, + kv_buffer_stride_l=kv_buffer.stride(0), + kv_buffer_stride_s=kv_buffer.stride(1), + kv_buffer_stride_d=kv_buffer.stride(2), + kv_buffer_tail_dim=kv_buffer.shape[-1], + BLOCK=BLOCK, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/linear_att_copy.py b/lightllm/common/basemodel/triton_kernel/linear_att_copy.py new file mode 100644 index 0000000000..d9f631cbd0 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/linear_att_copy.py @@ -0,0 +1,133 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _copy_linear_att_state_to_kv_buffer( + gpu_conv_ptr, # [linear_layer_num, size_num, xdim] + gpu_ssm_ptr, # [linear_layer_num, size_num, xxdim] + cpu_kv_conv_ptr, # 
[size, linear_layer_num, xdim] + cpu_kv_ssm_ptr, # [size, linear_layer_num, xxdim] + b_req_idx, # [batch_size,] + big_page_buffer_ids, # [batch_size,] + gpu_conv_stride_l, + gpu_conv_stride_s, + gpu_conv_stride_d, + gpu_ssm_stride_l, + gpu_ssm_stride_s, + gpu_ssm_stride_d, + cpu_kv_conv_stride_s, + cpu_kv_conv_stride_l, + cpu_kv_conv_stride_d, + cpu_kv_ssm_stride_s, + cpu_kv_ssm_stride_l, + cpu_kv_ssm_stride_d, + mtp_step, + gpu_conv_tail_dim, + gpu_ssm_tail_dim, + BLOCK: tl.constexpr, +): + cur_layer = tl.program_id(0).to(tl.int64) + cur_batch = tl.program_id(1).to(tl.int64) + cpu_kv_conv_stride_s = tl.cast(cpu_kv_conv_stride_s, dtype=tl.int64) + cpu_kv_ssm_stride_s = tl.cast(cpu_kv_ssm_stride_s, dtype=tl.int64) + gpu_conv_stride_s = tl.cast(gpu_conv_stride_s, dtype=tl.int64) + gpu_ssm_stride_s = tl.cast(gpu_ssm_stride_s, dtype=tl.int64) + + big_page_buffer_idx = tl.load(big_page_buffer_ids + cur_batch) + if big_page_buffer_idx == -1: + return + + cur_req_idx = tl.load(b_req_idx + cur_batch).to(tl.int64) + cur_state_req_idx = (cur_req_idx * (mtp_step + 1)).to(tl.int64) + + for i in range(tl.cdiv(gpu_conv_tail_dim, BLOCK)): + gpu_start_off = i * BLOCK + tl.arange(0, BLOCK) + mask = gpu_start_off < gpu_conv_tail_dim + conv_data = tl.load( + gpu_conv_ptr + cur_layer * gpu_conv_stride_l + cur_state_req_idx * gpu_conv_stride_s + gpu_start_off, + mask=mask, + ) + dest_conv_ptr = ( + cpu_kv_conv_ptr + + big_page_buffer_idx * cpu_kv_conv_stride_s + + cur_layer * cpu_kv_conv_stride_l + + gpu_start_off + ) + tl.store(dest_conv_ptr, conv_data, mask=mask) + + for i in range(tl.cdiv(gpu_ssm_tail_dim, BLOCK)): + gpu_start_off = i * BLOCK + tl.arange(0, BLOCK) + mask = gpu_start_off < gpu_ssm_tail_dim + ssm_data = tl.load( + gpu_ssm_ptr + cur_layer * gpu_ssm_stride_l + cur_state_req_idx * gpu_ssm_stride_s + gpu_start_off, + mask=mask, + ) + dest_ssm_ptr = ( + cpu_kv_ssm_ptr + big_page_buffer_idx * cpu_kv_ssm_stride_s + cur_layer * cpu_kv_ssm_stride_l + gpu_start_off + ) + 
tl.store(dest_ssm_ptr, ssm_data, mask=mask) + + return + + +def copy_linear_att_state_to_kv_buffer( + b_req_idx: torch.Tensor, + big_page_buffer_ids: torch.Tensor, + gpu_conv_state: torch.Tensor, # [linear_layer_num, s, ...] + gpu_ssm_state: torch.Tensor, # [linear_layer_num, s, ...] + cpu_kv_conv_state: torch.Tensor, # [s, linear_layer_num, ...] + cpu_kv_ssm_state: torch.Tensor, # [s, linear_layer_num, ...] + mtp_step: int, +): + assert len(b_req_idx) == big_page_buffer_ids.shape[0] + BLOCK = 4096 + gpu_conv_state = gpu_conv_state.view(gpu_conv_state.shape[0], gpu_conv_state.shape[1], -1).view(dtype=torch.uint8) + gpu_ssm_state = gpu_ssm_state.view(gpu_ssm_state.shape[0], gpu_ssm_state.shape[1], -1).view(dtype=torch.uint8) + cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], -1).view( + dtype=torch.uint8 + ) + cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], cpu_kv_ssm_state.shape[1], -1).view( + dtype=torch.uint8 + ) + assert gpu_conv_state.shape[-1] == cpu_kv_conv_state.shape[-1] + assert gpu_ssm_state.shape[-1] == cpu_kv_ssm_state.shape[-1] + assert ( + gpu_conv_state.stride(-1) + == gpu_ssm_state.stride(-1) + == cpu_kv_conv_state.stride(-1) + == cpu_kv_ssm_state.stride(-1) + ) + + gpu_conv_tail_dim = gpu_conv_state.shape[-1] + gpu_ssm_tail_dim = gpu_ssm_state.shape[-1] + + layer_num = gpu_conv_state.shape[0] + + grid = (layer_num, b_req_idx.shape[0]) + + _copy_linear_att_state_to_kv_buffer[grid]( + gpu_conv_ptr=gpu_conv_state, + gpu_ssm_ptr=gpu_ssm_state, + cpu_kv_conv_ptr=cpu_kv_conv_state, + cpu_kv_ssm_ptr=cpu_kv_ssm_state, + b_req_idx=b_req_idx, + big_page_buffer_ids=big_page_buffer_ids, + gpu_conv_stride_l=gpu_conv_state.stride(0), + gpu_conv_stride_s=gpu_conv_state.stride(1), + gpu_conv_stride_d=gpu_conv_state.stride(2), + gpu_ssm_stride_l=gpu_ssm_state.stride(0), + gpu_ssm_stride_s=gpu_ssm_state.stride(1), + gpu_ssm_stride_d=gpu_ssm_state.stride(2), + 
cpu_kv_conv_stride_s=cpu_kv_conv_state.stride(0), + cpu_kv_conv_stride_l=cpu_kv_conv_state.stride(1), + cpu_kv_conv_stride_d=cpu_kv_conv_state.stride(2), + cpu_kv_ssm_stride_s=cpu_kv_ssm_state.stride(0), + cpu_kv_ssm_stride_l=cpu_kv_ssm_state.stride(1), + cpu_kv_ssm_stride_d=cpu_kv_ssm_state.stride(2), + mtp_step=mtp_step, + gpu_conv_tail_dim=gpu_conv_tail_dim, + gpu_ssm_tail_dim=gpu_ssm_tail_dim, + BLOCK=BLOCK, + ) diff --git a/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py b/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py new file mode 100644 index 0000000000..37b27cadb2 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py @@ -0,0 +1,536 @@ +import torch +import triton +import triton.language as tl +from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig + + +@triton.jit +def _copy_kv_buffer_to_cpu_cache( + page_num, + mem_indexes_ptr, # [move_token_num] + page_indexes_ptr, # [page_num], + page_readies_ptr, # [page_num], + big_page_buffer_ids, # [page_num] + cpu_cache_full_att, # [all_page_num, head, xdim] + cpu_cache_full_att_stride_p, + cpu_cache_full_att_stride_h, + cpu_cache_full_att_stride_d, + cpu_cache_conv, # [all_page_num, tp_world_size, xxdim] + cpu_cache_conv_stride_p, + cpu_cache_conv_stride_t, + cpu_cache_conv_stride_d, + cpu_cache_ssm, # [all_page_num, tp_world_size, xxxdim] + cpu_cache_ssm_stride_p, + cpu_cache_ssm_stride_t, + cpu_cache_ssm_stride_d, + gpu_kv_full_att_state, # [token_size, full_att_layer_num, xdim] + gpu_kv_full_att_stride_s, + gpu_kv_full_att_stride_l, + gpu_kv_full_att_stride_d, + cpu_kv_conv_state, # [buffer_count, xxxxxdim] + cpu_kv_conv_stride_s, + cpu_kv_conv_stride_d, + cpu_kv_ssm_state, # [buffer_count, xxxxxxxdim] + cpu_kv_ssm_stride_s, + cpu_kv_ssm_stride_d, + gpu_full_att_tail_dim, + cpu_kv_conv_tail_dim, + cpu_kv_ssm_tail_dim, + tp_rank, + full_att_layer_num, + big_page_token_num, + head_scale_size, + BLOCK: 
tl.constexpr, +): + split_index_start = tl.program_id(0) + grid_num = tl.num_programs(0) + # 将 所有stride 切成 tl.int64 + cpu_cache_full_att_stride_p = tl.cast(cpu_cache_full_att_stride_p, tl.int64) + cpu_cache_full_att_stride_h = tl.cast(cpu_cache_full_att_stride_h, tl.int64) + cpu_cache_full_att_stride_d = tl.cast(cpu_cache_full_att_stride_d, tl.int64) + cpu_cache_conv_stride_p = tl.cast(cpu_cache_conv_stride_p, tl.int64) + cpu_cache_conv_stride_t = tl.cast(cpu_cache_conv_stride_t, tl.int64) + cpu_cache_conv_stride_d = tl.cast(cpu_cache_conv_stride_d, tl.int64) + cpu_cache_ssm_stride_p = tl.cast(cpu_cache_ssm_stride_p, tl.int64) + cpu_cache_ssm_stride_t = tl.cast(cpu_cache_ssm_stride_t, tl.int64) + cpu_cache_ssm_stride_d = tl.cast(cpu_cache_ssm_stride_d, tl.int64) + gpu_kv_full_att_stride_s = tl.cast(gpu_kv_full_att_stride_s, tl.int64) + gpu_kv_full_att_stride_l = tl.cast(gpu_kv_full_att_stride_l, tl.int64) + gpu_kv_full_att_stride_d = tl.cast(gpu_kv_full_att_stride_d, tl.int64) + cpu_kv_conv_stride_s = tl.cast(cpu_kv_conv_stride_s, tl.int64) + cpu_kv_conv_stride_d = tl.cast(cpu_kv_conv_stride_d, tl.int64) + cpu_kv_ssm_stride_s = tl.cast(cpu_kv_ssm_stride_s, tl.int64) + cpu_kv_ssm_stride_d = tl.cast(cpu_kv_ssm_stride_d, tl.int64) + + for block_index in range(page_num): + cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64) + run_flag = 1 + if cpu_page_index == -1: + run_flag = 0 + ready_state = tl.load(page_readies_ptr + block_index) + if ready_state: + run_flag = 0 + if tp_rank % head_scale_size == 0: + head_flag = 1 + else: + head_flag = 0 + + mem_start_ptr = mem_indexes_ptr + big_page_token_num * block_index + for i in range(split_index_start, tl.cdiv(gpu_full_att_tail_dim, BLOCK) * run_flag * head_flag, grid_num): + gpu_start_i = i * BLOCK + tl.arange(0, BLOCK) + mask = gpu_start_i < gpu_full_att_tail_dim + per_token_size = gpu_full_att_tail_dim // big_page_token_num + per_layer_size = per_token_size // full_att_layer_num + mem_offs = gpu_start_i 
// (per_token_size) + mem_index = tl.load(mem_start_ptr + mem_offs, mask=mask, other=-1) + layer_index = (gpu_start_i // (per_layer_size)) % full_att_layer_num + dim_index = gpu_start_i % per_layer_size + gpu_full_att_data = tl.load( + gpu_kv_full_att_state + + mem_index * gpu_kv_full_att_stride_s + + layer_index * gpu_kv_full_att_stride_l + + dim_index * gpu_kv_full_att_stride_d, + mask=mask & (mem_index != -1), + other=0, + ) + dest_cpu_cache_full_att_ptr = ( + cpu_cache_full_att + + cpu_page_index * cpu_cache_full_att_stride_p + + (tp_rank // head_scale_size) * cpu_cache_full_att_stride_h + + gpu_start_i + ) + tl.store(dest_cpu_cache_full_att_ptr, gpu_full_att_data, mask=mask & (mem_index != -1)) + + big_page_idx = tl.load(big_page_buffer_ids + block_index) + + for i in range(split_index_start, tl.cdiv(cpu_kv_conv_tail_dim, BLOCK) * run_flag, grid_num): + gpu_start_i = i * BLOCK + tl.arange(0, BLOCK) + mask = gpu_start_i < cpu_kv_conv_tail_dim + cpu_kv_conv_data = tl.load( + cpu_kv_conv_state + big_page_idx * cpu_kv_conv_stride_s + gpu_start_i, + mask=mask, + other=0, + ) + dest_cpu_cache_conv_ptr = ( + cpu_cache_conv + + cpu_page_index * cpu_cache_conv_stride_p + + tp_rank * cpu_cache_conv_stride_t + + gpu_start_i + ) + tl.store(dest_cpu_cache_conv_ptr, cpu_kv_conv_data, mask=mask) + + for i in range(split_index_start, tl.cdiv(cpu_kv_ssm_tail_dim, BLOCK) * run_flag, grid_num): + gpu_start_i = i * BLOCK + tl.arange(0, BLOCK) + mask = gpu_start_i < cpu_kv_ssm_tail_dim + + cpu_kv_ssm_data = tl.load( + cpu_kv_ssm_state + big_page_idx * cpu_kv_ssm_stride_s + gpu_start_i, + mask=mask, + other=0, + ) + dest_cpu_cache_ssm_ptr = ( + cpu_cache_ssm + cpu_page_index * cpu_cache_ssm_stride_p + tp_rank * cpu_cache_ssm_stride_t + gpu_start_i + ) + tl.store(dest_cpu_cache_ssm_ptr, cpu_kv_ssm_data, mask=mask) + + return + + +def copy_kv_buffer_to_cpu_cache( + mem_indexes: torch.Tensor, + page_indexes: torch.Tensor, + page_readies: torch.Tensor, + big_page_buffer_ids: 
torch.Tensor, + gpu_kv_full_att_state: torch.Tensor, # [full_att_layer_num, s, head_num, head_dim] + cpu_kv_conv_state: torch.Tensor, # [s, linear_layer_num, dim] + cpu_kv_ssm_state: torch.Tensor, # [s, linear_layer_num, xdim] + cpu_cache_tensor: torch.Tensor, # [page_num, 1, 1, 1, xxdim] + tp_rank: int, + tp_world_size: int, + big_page_token_num: int, + linear_config: LinearAttCacheConfig, + grid_num: int = 12, +): + assert len(page_indexes) == len(page_readies) == len(big_page_buffer_ids) + assert len(mem_indexes) % len(page_indexes) == 0 + + BLOCK = 4096 + if linear_config.full_att_all_num_kv_heads % tp_world_size == 0: + # tp world size 不比 kv 的 head 多时 + head_scale_size = 1 + else: + head_scale_size = tp_world_size // linear_config.full_att_all_num_kv_heads + + cpu_page_num = cpu_cache_tensor.shape[0] + cpu_cache_tensor = cpu_cache_tensor.view(cpu_page_num, -1).view(dtype=torch.uint8) + a = linear_config.get_cpu_cache_full_att_bytes() + b = linear_config.get_cpu_cache_conv_bytes() + c = linear_config.get_cpu_cache_ssm_bytes() + + if head_scale_size == 1: + cpu_cache_full_att = cpu_cache_tensor[:, 0:a].view(cpu_page_num, tp_world_size, -1) + else: + cpu_cache_full_att = cpu_cache_tensor[:, 0:a].view(cpu_page_num, linear_config.full_att_all_num_kv_heads, -1) + + cpu_cache_full_att = cpu_cache_full_att.view(dtype=torch.uint64) + + cpu_cache_conv = cpu_cache_tensor[:, a : (a + b)].view(cpu_page_num, tp_world_size, -1).view(dtype=torch.uint64) + cpu_cache_ssm = ( + cpu_cache_tensor[:, (a + b) : (a + b + c)].view(cpu_page_num, tp_world_size, -1).view(dtype=torch.uint64) + ) + + gpu_kv_full_att_state = gpu_kv_full_att_state.view( + gpu_kv_full_att_state.shape[0], gpu_kv_full_att_state.shape[1], -1 + ).view(dtype=torch.uint64) + + gpu_kv_full_att_state = gpu_kv_full_att_state.permute(1, 0, 2) # [s, layer_num, xxdim] + + cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], -1).view(dtype=torch.uint64) + cpu_kv_ssm_state = 
cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], -1).view(dtype=torch.uint64) + + gpu_full_att_tail_dim = gpu_kv_full_att_state.shape[-1] * gpu_kv_full_att_state.shape[-2] * big_page_token_num + cpu_kv_conv_tail_dim = cpu_kv_conv_state.shape[-1] + cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1] + full_att_layer_num = gpu_kv_full_att_state.shape[-2] + + assert ( + full_att_layer_num + == (linear_config.all_layer_num // linear_config.full_attention_interval) + == (linear_config.all_layer_num - linear_config.linear_layer_num) + ) + assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1] + assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1] + assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1] + assert gpu_kv_full_att_state.stride(2) == 1 + assert ( + gpu_full_att_tail_dim % big_page_token_num == 0 + and (gpu_full_att_tail_dim // big_page_token_num) % full_att_layer_num == 0 + ) + assert (tp_rank // head_scale_size) < linear_config.full_att_all_num_kv_heads + + grid = (grid_num,) + _copy_kv_buffer_to_cpu_cache[grid]( + page_num=len(page_indexes), + mem_indexes_ptr=mem_indexes, + page_indexes_ptr=page_indexes, + page_readies_ptr=page_readies, + big_page_buffer_ids=big_page_buffer_ids, + cpu_cache_full_att=cpu_cache_full_att, + cpu_cache_full_att_stride_p=cpu_cache_full_att.stride(0), + cpu_cache_full_att_stride_h=cpu_cache_full_att.stride(1), + cpu_cache_full_att_stride_d=cpu_cache_full_att.stride(2), + cpu_cache_conv=cpu_cache_conv, + cpu_cache_conv_stride_p=cpu_cache_conv.stride(0), + cpu_cache_conv_stride_t=cpu_cache_conv.stride(1), + cpu_cache_conv_stride_d=cpu_cache_conv.stride(2), + cpu_cache_ssm=cpu_cache_ssm, + cpu_cache_ssm_stride_p=cpu_cache_ssm.stride(0), + cpu_cache_ssm_stride_t=cpu_cache_ssm.stride(1), + cpu_cache_ssm_stride_d=cpu_cache_ssm.stride(2), + gpu_kv_full_att_state=gpu_kv_full_att_state, + gpu_kv_full_att_stride_s=gpu_kv_full_att_state.stride(0), + gpu_kv_full_att_stride_l=gpu_kv_full_att_state.stride(1), + 
gpu_kv_full_att_stride_d=gpu_kv_full_att_state.stride(2), + cpu_kv_conv_state=cpu_kv_conv_state, + cpu_kv_conv_stride_s=cpu_kv_conv_state.stride(0), + cpu_kv_conv_stride_d=cpu_kv_conv_state.stride(1), + cpu_kv_ssm_state=cpu_kv_ssm_state, + cpu_kv_ssm_stride_s=cpu_kv_ssm_state.stride(0), + cpu_kv_ssm_stride_d=cpu_kv_ssm_state.stride(1), + gpu_full_att_tail_dim=gpu_full_att_tail_dim, + cpu_kv_conv_tail_dim=cpu_kv_conv_tail_dim, + cpu_kv_ssm_tail_dim=cpu_kv_ssm_tail_dim, + tp_rank=tp_rank, + full_att_layer_num=full_att_layer_num, + big_page_token_num=big_page_token_num, + head_scale_size=head_scale_size, + BLOCK=BLOCK, + ) + + +@triton.jit +def _copy_cpu_cache_to_kv_buffer( + page_num, + mem_indexes_ptr, # [move_token_num] + page_indexes_ptr, # [page_num], + big_page_buffer_ids, # [page_num] + cpu_cache_full_att, # [all_page_num, head, xdim] + cpu_cache_full_att_stride_p, + cpu_cache_full_att_stride_h, + cpu_cache_full_att_stride_d, + cpu_cache_conv, # [all_page_num, tp_world_size, xxdim] + cpu_cache_conv_stride_p, + cpu_cache_conv_stride_t, + cpu_cache_conv_stride_d, + cpu_cache_ssm, # [all_page_num, tp_world_size, xxxdim] + cpu_cache_ssm_stride_p, + cpu_cache_ssm_stride_t, + cpu_cache_ssm_stride_d, + gpu_kv_full_att_state, # [token_size, full_att_layer_num, xdim] + gpu_kv_full_att_stride_s, + gpu_kv_full_att_stride_l, + gpu_kv_full_att_stride_d, + cpu_kv_conv_state, # [buffer_count, xxxxxdim] + cpu_kv_conv_stride_s, + cpu_kv_conv_stride_d, + cpu_kv_ssm_state, # [buffer_count, xxxxxxxdim] + cpu_kv_ssm_stride_s, + cpu_kv_ssm_stride_d, + gpu_full_att_tail_dim, + cpu_kv_conv_tail_dim, + cpu_kv_ssm_tail_dim, + tp_rank, + full_att_layer_num, + big_page_token_num, + head_scale_size, + BLOCK: tl.constexpr, +): + split_index_start = tl.program_id(0) + grid_num = tl.num_programs(0) + # 将 所有stride 切成 tl.int64 + cpu_cache_full_att_stride_p = tl.cast(cpu_cache_full_att_stride_p, tl.int64) + cpu_cache_full_att_stride_h = tl.cast(cpu_cache_full_att_stride_h, tl.int64) + 
cpu_cache_full_att_stride_d = tl.cast(cpu_cache_full_att_stride_d, tl.int64) + cpu_cache_conv_stride_p = tl.cast(cpu_cache_conv_stride_p, tl.int64) + cpu_cache_conv_stride_t = tl.cast(cpu_cache_conv_stride_t, tl.int64) + cpu_cache_conv_stride_d = tl.cast(cpu_cache_conv_stride_d, tl.int64) + cpu_cache_ssm_stride_p = tl.cast(cpu_cache_ssm_stride_p, tl.int64) + cpu_cache_ssm_stride_t = tl.cast(cpu_cache_ssm_stride_t, tl.int64) + cpu_cache_ssm_stride_d = tl.cast(cpu_cache_ssm_stride_d, tl.int64) + gpu_kv_full_att_stride_s = tl.cast(gpu_kv_full_att_stride_s, tl.int64) + gpu_kv_full_att_stride_l = tl.cast(gpu_kv_full_att_stride_l, tl.int64) + gpu_kv_full_att_stride_d = tl.cast(gpu_kv_full_att_stride_d, tl.int64) + cpu_kv_conv_stride_s = tl.cast(cpu_kv_conv_stride_s, tl.int64) + cpu_kv_conv_stride_d = tl.cast(cpu_kv_conv_stride_d, tl.int64) + cpu_kv_ssm_stride_s = tl.cast(cpu_kv_ssm_stride_s, tl.int64) + cpu_kv_ssm_stride_d = tl.cast(cpu_kv_ssm_stride_d, tl.int64) + + for block_index in range(page_num): + cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64) + + mem_start_ptr = mem_indexes_ptr + big_page_token_num * block_index + for i in range(split_index_start, tl.cdiv(gpu_full_att_tail_dim, BLOCK), grid_num): + gpu_start_i = i * BLOCK + tl.arange(0, BLOCK) + mask = gpu_start_i < gpu_full_att_tail_dim + per_token_size = gpu_full_att_tail_dim // big_page_token_num + per_layer_size = per_token_size // full_att_layer_num + mem_offs = gpu_start_i // (per_token_size) + mem_index = tl.load(mem_start_ptr + mem_offs, mask=mask, other=-1) + layer_index = (gpu_start_i // (per_layer_size)) % full_att_layer_num + dim_index = gpu_start_i % per_layer_size + + src_cpu_cache_full_att_ptr = ( + cpu_cache_full_att + + cpu_page_index * cpu_cache_full_att_stride_p + + (tp_rank // head_scale_size) * cpu_cache_full_att_stride_h + + gpu_start_i + ) + cpu_full_att_data = tl.load(src_cpu_cache_full_att_ptr, mask=mask & (mem_index != -1), other=0) + + tl.store( + 
gpu_kv_full_att_state + + mem_index * gpu_kv_full_att_stride_s + + layer_index * gpu_kv_full_att_stride_l + + dim_index * gpu_kv_full_att_stride_d, + cpu_full_att_data, + mask=mask & (mem_index != -1), + ) + + big_page_idx = tl.load(big_page_buffer_ids + block_index) + + for i in range(split_index_start, tl.cdiv(cpu_kv_conv_tail_dim, BLOCK), grid_num): + gpu_start_i = i * BLOCK + tl.arange(0, BLOCK) + mask = gpu_start_i < cpu_kv_conv_tail_dim + + src_cpu_cache_conv_ptr = ( + cpu_cache_conv + + cpu_page_index * cpu_cache_conv_stride_p + + tp_rank * cpu_cache_conv_stride_t + + gpu_start_i + ) + cpu_kv_conv_data = tl.load(src_cpu_cache_conv_ptr, mask=mask, other=0) + + tl.store( + cpu_kv_conv_state + big_page_idx * cpu_kv_conv_stride_s + gpu_start_i, + cpu_kv_conv_data, + mask=mask, + ) + + for i in range(split_index_start, tl.cdiv(cpu_kv_ssm_tail_dim, BLOCK), grid_num): + gpu_start_i = i * BLOCK + tl.arange(0, BLOCK) + mask = gpu_start_i < cpu_kv_ssm_tail_dim + + src_cpu_cache_ssm_ptr = ( + cpu_cache_ssm + cpu_page_index * cpu_cache_ssm_stride_p + tp_rank * cpu_cache_ssm_stride_t + gpu_start_i + ) + cpu_kv_ssm_data = tl.load(src_cpu_cache_ssm_ptr, mask=mask, other=0) + + tl.store( + cpu_kv_ssm_state + big_page_idx * cpu_kv_ssm_stride_s + gpu_start_i, + cpu_kv_ssm_data, + mask=mask, + ) + + return + + +def copy_cpu_cache_to_kv_buffer( + mem_indexes: torch.Tensor, + big_page_buffer_ids: torch.Tensor, + page_indexes: torch.Tensor, + gpu_full_att_kv_state: torch.Tensor, # [layer_num, s, head_num, head_dim] + cpu_kv_conv_state: torch.Tensor, # [layer_num, s, dim] + cpu_kv_ssm_state: torch.Tensor, # [layer_num, s, xdim] + cpu_cache_tensor: torch.Tensor, # [page_num, 1, 1, tp_world_size, xxdim] + tp_rank: int, + tp_world_size: int, + big_page_token_num: int, + linear_config: LinearAttCacheConfig, + grid_num: int = 12, +): + + assert len(mem_indexes) % len(page_indexes) == 0 + + BLOCK = 4096 + if linear_config.full_att_all_num_kv_heads % tp_world_size == 0: + head_scale_size 
= 1 + else: + head_scale_size = tp_world_size // linear_config.full_att_all_num_kv_heads + + cpu_page_num = cpu_cache_tensor.shape[0] + cpu_cache_tensor = cpu_cache_tensor.view(cpu_page_num, -1).view(dtype=torch.uint8) + a = linear_config.get_cpu_cache_full_att_bytes() + b = linear_config.get_cpu_cache_conv_bytes() + c = linear_config.get_cpu_cache_ssm_bytes() + + if head_scale_size == 1: + cpu_cache_full_att = cpu_cache_tensor[:, 0:a].view(cpu_page_num, tp_world_size, -1) + else: + cpu_cache_full_att = cpu_cache_tensor[:, 0:a].view(cpu_page_num, linear_config.full_att_all_num_kv_heads, -1) + + cpu_cache_full_att = cpu_cache_full_att.view(dtype=torch.uint64) + + cpu_cache_conv = cpu_cache_tensor[:, a : (a + b)].view(cpu_page_num, tp_world_size, -1).view(dtype=torch.uint64) + cpu_cache_ssm = ( + cpu_cache_tensor[:, (a + b) : (a + b + c)].view(cpu_page_num, tp_world_size, -1).view(dtype=torch.uint64) + ) + + gpu_full_att_kv_state = gpu_full_att_kv_state.view( + gpu_full_att_kv_state.shape[0], gpu_full_att_kv_state.shape[1], -1 + ).view(dtype=torch.uint64) + gpu_full_att_kv_state = gpu_full_att_kv_state.permute(1, 0, 2) # [s, layer_num, xxdim] + + cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], -1).view(dtype=torch.uint64) + cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], -1).view(dtype=torch.uint64) + + gpu_full_att_tail_dim = gpu_full_att_kv_state.shape[-1] * gpu_full_att_kv_state.shape[-2] * big_page_token_num + cpu_kv_conv_tail_dim = cpu_kv_conv_state.shape[-1] + cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1] + full_att_layer_num = gpu_full_att_kv_state.shape[-2] + + assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1] + assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1] + assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1] + assert gpu_full_att_kv_state.stride(2) == 1 + + assert (tp_rank // head_scale_size) < linear_config.full_att_all_num_kv_heads + + grid = (grid_num,) + 
_copy_cpu_cache_to_kv_buffer[grid]( + page_num=len(page_indexes), + mem_indexes_ptr=mem_indexes, + page_indexes_ptr=page_indexes, + big_page_buffer_ids=big_page_buffer_ids, + cpu_cache_full_att=cpu_cache_full_att, + cpu_cache_full_att_stride_p=cpu_cache_full_att.stride(0), + cpu_cache_full_att_stride_h=cpu_cache_full_att.stride(1), + cpu_cache_full_att_stride_d=cpu_cache_full_att.stride(2), + cpu_cache_conv=cpu_cache_conv, + cpu_cache_conv_stride_p=cpu_cache_conv.stride(0), + cpu_cache_conv_stride_t=cpu_cache_conv.stride(1), + cpu_cache_conv_stride_d=cpu_cache_conv.stride(2), + cpu_cache_ssm=cpu_cache_ssm, + cpu_cache_ssm_stride_p=cpu_cache_ssm.stride(0), + cpu_cache_ssm_stride_t=cpu_cache_ssm.stride(1), + cpu_cache_ssm_stride_d=cpu_cache_ssm.stride(2), + gpu_kv_full_att_state=gpu_full_att_kv_state, + gpu_kv_full_att_stride_s=gpu_full_att_kv_state.stride(0), + gpu_kv_full_att_stride_l=gpu_full_att_kv_state.stride(1), + gpu_kv_full_att_stride_d=gpu_full_att_kv_state.stride(2), + cpu_kv_conv_state=cpu_kv_conv_state, + cpu_kv_conv_stride_s=cpu_kv_conv_state.stride(0), + cpu_kv_conv_stride_d=cpu_kv_conv_state.stride(1), + cpu_kv_ssm_state=cpu_kv_ssm_state, + cpu_kv_ssm_stride_s=cpu_kv_ssm_state.stride(0), + cpu_kv_ssm_stride_d=cpu_kv_ssm_state.stride(1), + gpu_full_att_tail_dim=gpu_full_att_tail_dim, + cpu_kv_conv_tail_dim=cpu_kv_conv_tail_dim, + cpu_kv_ssm_tail_dim=cpu_kv_ssm_tail_dim, + tp_rank=tp_rank, + full_att_layer_num=full_att_layer_num, + big_page_token_num=big_page_token_num, + head_scale_size=head_scale_size, + BLOCK=BLOCK, + ) + + +@triton.jit +def _copy_linear_att_state_to_linear_att_state( + src_conv_state, + dst_conv_state, + src_ssm_state, + dst_ssm_state, + conv_size, + ssm_size, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + grid_num = tl.num_programs(0) + + # copy conv state + num_conv_blocks = tl.cdiv(conv_size, BLOCK) + for i in range(pid, num_conv_blocks, grid_num): + start = i * BLOCK + tl.arange(0, BLOCK) + mask = start < conv_size + data 
= tl.load(src_conv_state + start, mask=mask, other=0) + tl.store(dst_conv_state + start, data, mask=mask) + + # copy ssm state + num_ssm_blocks = tl.cdiv(ssm_size, BLOCK) + for i in range(pid, num_ssm_blocks, grid_num): + start = i * BLOCK + tl.arange(0, BLOCK) + mask = start < ssm_size + data = tl.load(src_ssm_state + start, mask=mask, other=0) + tl.store(dst_ssm_state + start, data, mask=mask) + + +def copy_linear_att_state_to_linear_att_state( + src_conv_state: torch.Tensor, + src_ssm_state: torch.Tensor, + dst_conv_state: torch.Tensor, + dst_ssm_state: torch.Tensor, + grid_num: int = 16, +): + assert src_conv_state.shape == dst_conv_state.shape + assert src_ssm_state.shape == dst_ssm_state.shape + + BLOCK = 4096 + + src_conv_flat = src_conv_state.view(-1).view(dtype=torch.uint8) + dst_conv_flat = dst_conv_state.view(-1).view(dtype=torch.uint8) + src_ssm_flat = src_ssm_state.view(-1).view(dtype=torch.uint8) + dst_ssm_flat = dst_ssm_state.view(-1).view(dtype=torch.uint8) + + conv_size = src_conv_flat.shape[0] + ssm_size = src_ssm_flat.shape[0] + + grid = (grid_num,) + _copy_linear_att_state_to_linear_att_state[grid]( + src_conv_state=src_conv_flat, + dst_conv_state=dst_conv_flat, + src_ssm_state=src_ssm_flat, + dst_ssm_state=dst_ssm_flat, + conv_size=conv_size, + ssm_size=ssm_size, + BLOCK=BLOCK, + ) diff --git a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py index e2d4aea587..05d678e41b 100644 --- a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py +++ b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py @@ -23,6 +23,8 @@ def _fwd_kernel( tp_text_end_token_id, hidden_size, tp_world_size, + APPLY_TEXT_EMBED_SCALE: tl.constexpr, + TEXT_EMBED_SCALE: tl.constexpr, BLOCK_HIDDEN_DIM: tl.constexpr, ): @@ -43,6 +45,8 @@ def _fwd_kernel( mask=off_d < hidden_size, other=0, ) + if APPLY_TEXT_EMBED_SCALE: + load_emb *= TEXT_EMBED_SCALE tl.store(Out + stride_out_s * seq_index + stride_out_d 
* off_d, load_emb, mask=off_d < hidden_size) img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0) @@ -84,9 +88,12 @@ def multimodal_emb( tp_text_start_token_id: int, tp_text_end_token_id: int, tp_world_size: int, + text_embed_scale: float = 1.0, ): total_len = prompt_ids.shape[0] BLOCK = triton.next_power_of_2(out.shape[1]) + text_embed_scale = float(text_embed_scale) + apply_text_embed_scale = text_embed_scale != 1.0 # print(len(img_token_lens)) grid = (total_len, len(img_token_lens) + 1) num_warps = 1 @@ -109,6 +116,8 @@ def multimodal_emb( tp_text_end_token_id=tp_text_end_token_id, hidden_size=out.shape[1], tp_world_size=float(tp_world_size), + APPLY_TEXT_EMBED_SCALE=apply_text_embed_scale, + TEXT_EMBED_SCALE=text_embed_scale, BLOCK_HIDDEN_DIM=BLOCK, num_warps=num_warps, num_stages=1, diff --git a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py index 89db5e00cb..c62c5eb5d2 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py @@ -16,7 +16,6 @@ def gated_rmsnorm_forward_kernel( W, # pointer to the weights B, # pointer to the biases Z, # pointer to the other branch (required, not optional) - Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_z_row, @@ -33,7 +32,6 @@ def gated_rmsnorm_forward_kernel( X += row * stride_x_row + group * N Y += row * stride_y_row + group * N Z += row * stride_z_row + group * N - Rstd += group * M W += group * N if HAS_BIAS: B += group * N @@ -47,7 +45,6 @@ def gated_rmsnorm_forward_kernel( xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) @@ -128,9 +125,6 @@ def 
gated_rmsnorm_forward( else: out = torch.empty_like(x) assert out.stride(-1) == 1 - # For RMS norm, we still need rstd for the kernel - rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) - # Default heuristic when autotune is disabled or no config provided if not run_config: # Less than 64KB per feature: enqueue fused kernel @@ -160,7 +154,6 @@ def gated_rmsnorm_forward( weight, bias, z, - rstd, x.stride(0), out.stride(0), z.stride(0), diff --git a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py index ca8f9a1c81..8dc8558922 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py @@ -18,6 +18,7 @@ def _rms_norm_fwd_fused( y_stride1, N, # number of columns in X eps, # epsilon to avoid division by zero + HAS_WEIGHT: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. @@ -32,14 +33,17 @@ def _rms_norm_fwd_fused( _var += x * x var = tl.sum(_var, axis=0) / N rstd = 1 / tl.sqrt(var + eps) - # Normalize and apply linear transformation + # Normalize and optionally apply linear transformation for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) x_hat = x * rstd - y = x_hat * w + y = x_hat + if HAS_WEIGHT: + y = x_hat * w # Write output tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) @@ -50,7 +54,9 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) # reshape input data into 2D tensor x_arg = x.view(-1, x.shape[-1]) y_arg = y.view(-1, x.shape[-1]) - assert x_arg.shape[-1] == weight.shape[0] and x_arg.shape == y_arg.shape + assert x_arg.shape == y_arg.shape + if weight is not None: + assert 
x_arg.shape[-1] == weight.shape[0] assert y.data_ptr() == y_arg.data_ptr() M, N = x_arg.shape # Less than 64KB per feature: enqueue fused kernel @@ -73,6 +79,7 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) y_arg.stride(1), N, eps, + HAS_WEIGHT=weight is not None, BLOCK_SIZE=BLOCK_SIZE, num_warps=rmsnorm_num_warps, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/__init__.py b/lightllm/common/basemodel/triton_kernel/post_process/__init__.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/__init__.py rename to lightllm/common/basemodel/triton_kernel/post_process/__init__.py diff --git a/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py new file mode 100644 index 0000000000..353affd8ed --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py @@ -0,0 +1,36 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_apply_invalid_token( + Logits, + invalid_token_ids, + cu_invalid_token_num, + stride_logit_b, +): + cur_batch = tl.program_id(0) + start_index = tl.load(cu_invalid_token_num + cur_batch) + end_index = tl.load(cu_invalid_token_num + cur_batch + 1) + for i in range(start_index, end_index): + cur_invalid_token_id = tl.load(invalid_token_ids + i) + cur_logit_ptr = Logits + cur_batch * stride_logit_b + cur_invalid_token_id + tl.store(cur_logit_ptr, float("-inf")) + return + + +def apply_invalid_token_ids( + Logits: torch.Tensor, + invalid_token_ids: torch.Tensor, + cu_invalid_token_num: torch.Tensor, +): + batch_size = Logits.shape[0] + grid = (batch_size,) + _fwd_kernel_apply_invalid_token[grid]( + Logits=Logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + stride_logit_b=Logits.stride(0), + ) + return diff --git 
a/lightllm/common/basemodel/triton_kernel/apply_penalty.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/apply_penalty.py rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py diff --git a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py index e86d2e819e..c218d15e08 100644 --- a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py +++ b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py @@ -36,8 +36,6 @@ def _fwd_kernel_repack_kv_index( @torch.no_grad() def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): batch_size = req_index.shape[0] - # flashinfer requires out_kv_index to be zeroed before use - out_kv_index.zero_() BLOCK = 64 grid = ( batch_size, diff --git a/lightllm/common/basemodel/triton_kernel/sp_pad_copy.py b/lightllm/common/basemodel/triton_kernel/sp_pad_copy.py index 7c463c2ca7..60fa0ebac3 100644 --- a/lightllm/common/basemodel/triton_kernel/sp_pad_copy.py +++ b/lightllm/common/basemodel/triton_kernel/sp_pad_copy.py @@ -48,7 +48,9 @@ def sp_pad_copy(in_tensor: torch.Tensor, sp_rank_id: int, sp_world_size: int, al start = sp_rank_id * split_size end = start + split_size return in_tensor[start:end, :] - + assert ( + in_token_num % sp_world_size == 0 + ), f"in_token_num % sp_world_size != 0, in_token_num: {in_token_num}, sp_world_size: {sp_world_size}" out_token_num = triton.cdiv(in_token_num, sp_world_size) * sp_world_size // sp_world_size out_tensor = 
alloc_func((out_token_num, hidden_dim), dtype=in_tensor.dtype, device=in_tensor.device) out_token_start_index = out_token_num * sp_rank_id diff --git a/lightllm/common/cpu_cache/creator.py b/lightllm/common/cpu_cache/creator.py index 7d03f0c89c..e5d86a555e 100644 --- a/lightllm/common/cpu_cache/creator.py +++ b/lightllm/common/cpu_cache/creator.py @@ -2,7 +2,7 @@ import torch import numpy as np from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Tuple from lightllm.utils.kv_cache_utils import attach_shm_kv_cache_ptr, create_shm_kv_cache_ptr, register_shm_ptr_to_pin @@ -18,25 +18,21 @@ class CpuCacheCreator: def __init__(self, tensor_spec: CpuCacheTensorSpec): self.tensor_spec = tensor_spec - def create_or_attach( - self, init_shm_data: bool, pin: bool, pin_no_blocking: bool - ) -> Tuple[Optional[torch.Tensor], Optional[object]]: + def create_or_attach(self, init_shm_data: bool, pin: bool) -> torch.Tensor: if init_shm_data: shm_ptr = create_shm_kv_cache_ptr(key=self.tensor_spec.shm_key, size=self.tensor_spec.size_bytes) else: shm_ptr = attach_shm_kv_cache_ptr(key=self.tensor_spec.shm_key, size=self.tensor_spec.size_bytes) if pin: - attach_handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.tensor_spec.size_bytes) - # 是否阻塞等待pin 完成 - if not pin_no_blocking: - attach_handle.wait() - cpu_cache_tensor = self._build_tensor_view(shm_ptr=shm_ptr) - assert shm_ptr == cpu_cache_tensor.data_ptr() - return cpu_cache_tensor, attach_handle + device_ptr = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.tensor_spec.size_bytes) + cpu_cache_tensor = self._build_tensor_view(shm_ptr=device_ptr) + assert device_ptr == cpu_cache_tensor.data_ptr() else: cpu_cache_tensor = self._build_tensor_view(shm_ptr=shm_ptr) - return cpu_cache_tensor, None + assert shm_ptr == cpu_cache_tensor.data_ptr() + + return cpu_cache_tensor def _build_tensor_view(self, shm_ptr: int) -> torch.Tensor: numpy_array = np.frombuffer( diff --git 
a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index 79e75b3485..05544e149a 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -1,3 +1,4 @@ +from .allocator import KvCacheAllocator from .mem_manager import MemoryManager, ReadOnlyStaticsMemoryManager from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager @@ -6,8 +7,10 @@ from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager +from .qwen3next_mem_manager import Qwen3NextMemManager __all__ = [ + "KvCacheAllocator", "MemoryManager", "ReadOnlyStaticsMemoryManager", "PPLINT4KVMemoryManager", @@ -17,4 +20,5 @@ "FP8PerTokenGroupQuantDeepseek3_2MemoryManager", "FP8StaticPerHeadQuantMemManager", "FP8StaticPerTensorQuantMemManager", + "Qwen3NextMemManager", ] diff --git a/lightllm/common/kv_cache_mem_manager/allocator.py b/lightllm/common/kv_cache_mem_manager/allocator.py new file mode 100644 index 0000000000..850c158778 --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/allocator.py @@ -0,0 +1,106 @@ +import torch +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.utils.dist_utils import get_current_rank_in_node +from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.log_utils import init_logger +from typing import Union, List + +logger = init_logger(__name__) + + +class KvCacheAllocator: + def __init__(self, size: int) -> None: + self.size = size + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_start = 0 + self.mark_end = self.size + + self._mem_state_return = 
torch.arange( + 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._return_start = 0 + + self.can_use_mem_size = self.size + + rank_in_node = get_current_rank_in_node() + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + self.shared_can_use_token_num = SharedInt( + f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + ) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + return + + def alloc(self, need_size) -> torch.Tensor: + if need_size > self.mark_end - self.mark_start: + logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") + assert False, "error alloc state" + + start = self.mark_start + end = self.mark_start + need_size + self.mark_start += need_size + + self.can_use_mem_size -= need_size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + # 利用缓冲区返回,避免异步情况下的内存竞争 + if self._return_start + need_size > self._mem_state_return.shape[0]: + self._return_start = 0 + ans = self._mem_state_return[self._return_start : self._return_start + need_size] + ans.copy_(self.mem_state[start:end]) + self._return_start += need_size + return ans + + def free(self, free_index: Union[torch.Tensor, List[int]]): + """_summary_ + + Args: + free_index (torch.Tensor): _description_ + """ + + end = self.mark_start + start = self.mark_start - len(free_index) + assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" + + if isinstance(free_index, list): + self.mem_state.numpy()[start:end] = free_index + else: + # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 + self.mem_state[start:end] = free_index + + self.mark_start -= len(free_index) + + self.can_use_mem_size += len(free_index) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + if self.can_use_mem_size == len(self.mem_state): + logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") + return + + def free_all(self): + self.mem_state.numpy()[:] = 
list(range(0, len(self.mem_state))) + self.mark_start = 0 + self.mark_end = len(self.mem_state) + self.can_use_mem_size = len(self.mem_state) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + return + + def resize(self, new_size: int) -> None: + """ + just for test code + """ + self.size = new_size + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_start = 0 + self.mark_end = self.size + + self._mem_state_return = torch.arange( + 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._return_start = 0 + + self.can_use_mem_size = self.size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index 3d93e1b070..9eb02b963c 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -1,41 +1,22 @@ import torch import os import torch.distributed as dist -from lightllm.server.pd_io_struct import KVMoveTask from .mem_manager import MemoryManager from typing import List, Union, Any from lightllm.utils.log_utils import init_logger -from lightllm.common.kv_trans_kernel.kv_trans import kv_trans -from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node -from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io +from .operator import Deepseek2MemOperator logger = init_logger(__name__) class Deepseek2MemoryManager(MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) - def copy_kv_to_mem_manager(self, layer_index: int, mem_index: 
torch.Tensor, kv: torch.Tensor): - """ - 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 - """ - from ..basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv + operator_class = Deepseek2MemOperator - rope_dim = 64 - kv_lora_rank = kv.shape[2] - rope_dim - assert kv_lora_rank + rope_dim == self.kv_buffer.shape[-1] - - destindex_copy_kv( - kv[:, :, :kv_lora_rank], - kv[:, :, kv_lora_rank:], - mem_index, - self.kv_buffer[layer_index][:, :, :kv_lora_rank], - self.kv_buffer[layer_index][:, :, kv_lora_rank:], - ) - return + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) def get_att_input_params(self, layer_index: int) -> Any: kv = self.kv_buffer[layer_index] @@ -47,14 +28,6 @@ def get_cell_size(self): def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") - def alloc_kv_move_buffer(self, max_req_total_len): - self.kv_move_buffer = torch.empty( - (1, max_req_total_len + 8, self.head_num, self.head_dim), dtype=self.dtype, device="cuda" - ) - self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda") - self.token_dim_size = self.kv_move_buffer.shape[-1] * self.kv_move_buffer.shape[-2] - return - def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: self.kv_move_buffer = torch.empty( (page_num, page_size, self.layer_num, self.head_num, self.head_dim), dtype=self.dtype, device="cuda" @@ -71,7 +44,10 @@ def write_mem_to_page_kv_move_buffer( dp_index: int, mem_managers: List["MemoryManager"], dp_world_size: int, + page_kind: str = "kv", + req_idx: int = None, ): + assert page_kind == "kv", f"{type(self).__name__} does not support page_kind={page_kind}" cur_page = self.kv_move_buffer[page_index] pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 
: len(mem_indexes)] pin_mem_indexes.numpy()[:] = mem_indexes @@ -92,7 +68,10 @@ def read_page_kv_move_buffer_to_mem( dp_index: int, mem_managers: List["MemoryManager"], dp_world_size: int, + page_kind: str = "kv", + req_idx: int = None, ): + assert page_kind == "kv", f"{type(self).__name__} does not support page_kind={page_kind}" cur_page = self.kv_move_buffer[page_index] pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] pin_mem_indexes.numpy()[:] = mem_indexes @@ -105,180 +84,3 @@ def read_page_kv_move_buffer_to_mem( kv_buffer=mem.kv_buffer, mode="read", ) - - def send_to_decode_node( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["Deepseek2MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - assert dp_size_in_node == 1 - - # 先将数据发送到指定的一张卡上的buffer,再发送。 - move_token_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) - - cur_device_index = self.kv_buffer.get_device() - cur_mem = mem_managers[cur_device_index] - for layer_index in range(cur_mem.layer_num): - move_buffer = cur_mem._get_kv_move_data(move_token_indexes, layer_index) - nccl_comm.send(move_buffer, dst=1) - return - - def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): - move_size = self.token_dim_size * len(token_indexes) - move_buffer = self.kv_move_buffer.view(-1)[0:move_size].view( - 1, len(token_indexes), self.head_num, self.head_dim - ) - move_buffer[:, :, :, :] = self.kv_buffer[layer_index, token_indexes, :, :] - return move_buffer - - def receive_from_prefill_node( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - assert dp_size_in_node == 1 - - # 先将数据接受到指定的一张卡上的buffer,再复制到其他的卡上。 - move_token_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - 
move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) - - cur_device_index = self.kv_buffer.get_device() - token_num = len(move_token_indexes) - move_size = self.token_dim_size * token_num - recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, self.head_num, self.head_dim) - for layer_index in range(self.layer_num): - nccl_comm.recv(recive_buffer, src=0) - for i, mem in enumerate(mem_managers): - if i == cur_device_index: - mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index) - else: - new_recive_buffer = mem.kv_move_buffer.view(-1)[0:move_size].view(recive_buffer.shape) - from torch.cuda import comm - - comm.broadcast(recive_buffer, out=[new_recive_buffer]) - mem._write_kv_move_data(move_token_indexes, new_recive_buffer, layer_index) - return - - def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index): - self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor - return - - def send_to_decode_node_p2p( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - """ - 使用 p2p triton kernel 进行数据复制和传输的实现方式。 - """ - if not hasattr(self, "mem_ptrs_dict"): - self.mem_ptrs_dict = {} - for layer_index in range(self.layer_num): - mems_ptr = [] - for i in range(0, len(mem_managers), len(mem_managers) // dp_size_in_node): - mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr()) - mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda") - self.mem_ptrs_dict[layer_index] = mems_ptr - - move_token_indexes = [] - token_dp_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) - token_dp_indexes.extend([task.prefill_dp_index for _ in range(task.move_kv_len)]) - - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") 
- token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") - for layer_index in range(self.layer_num): - move_buffer = self._get_kv_move_data_p2p( - move_token_indexes, token_dp_indexes, layer_index, self.kv_move_buffer, dp_size_in_node - ) - nccl_comm.send(move_buffer, dst=1) - return - - def _get_kv_move_data_p2p( - self, - token_indexes: torch.Tensor, - token_dp_indexes: torch.Tensor, - layer_index: int, - kv_move_buffer: torch.Tensor, - dp_size_in_node: int, - ): - move_token_num = len(token_indexes) - move_size = self.token_dim_size * move_token_num - move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, self.head_num, self.head_dim) - kv_trans_v2_for_p_node( - input_mems=self.mem_ptrs_dict[layer_index], - input_idx=token_indexes, - input_dp_idx=token_dp_indexes, - output=move_buffer, - output_idx=self.kv_move_buf_indexes[0:move_token_num], - dp_size_in_node=dp_size_in_node, - ) - return move_buffer - - def receive_from_prefill_node_p2p( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - if not hasattr(self, "mem_ptrs_dict"): - self.mem_ptrs_dict = {} - for layer_index in range(self.layer_num): - mems_ptr = [] - for i in range(0, len(mem_managers)): - mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr()) - mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda") - self.mem_ptrs_dict[layer_index] = mems_ptr - - move_token_indexes = [] - token_dp_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) - token_dp_indexes.extend([task.decode_dp_index for _ in range(task.move_kv_len)]) - - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") - token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") - - token_num = len(move_token_indexes) - move_size 
= self.token_dim_size * token_num - recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim) - for layer_index in range(self.layer_num): - nccl_comm.recv(recive_buffer, src=0) - self._write_kv_move_data_p2p( - move_token_indexes, token_dp_indexes, recive_buffer, layer_index, dp_size_in_node - ) - return - - def _write_kv_move_data_p2p( - self, - token_indexes: torch.Tensor, - token_dp_indexes: torch.Tensor, - buffer_tensor: torch.Tensor, - layer_index, - dp_size_in_node: int, - ): - move_token_num = len(token_indexes) - kv_trans_v2_for_d_node( - output_mems=self.mem_ptrs_dict[layer_index], - output_idx=token_indexes, - output_dp_idx=token_dp_indexes, - input=buffer_tensor, - input_idx=self.kv_move_buf_indexes[0:move_token_num], - dp_size_in_node=dp_size_in_node, - ) - return diff --git a/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py index 66f37a16f1..4034bcc8fa 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py @@ -1,9 +1,13 @@ import torch from typing import Any from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager +from .operator import Deepseek3_2MemOperator class Deepseek3_2MemoryManager(Deepseek2MemoryManager): + + operator_class = Deepseek3_2MemOperator + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): assert dtype in [torch.bfloat16, torch.float16] # 因为V3.2 使用了NSA 稀疏的缘故,所以其head_dim 会比原始的kv 多 128 + 4 = 132 个字节 (128 fp8 + 4byte float32 scale), @@ -12,25 +16,6 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False # 所以在子类中定制为其pad上,对外使用的接口,需要进行重载区别。 super().__init__(size, dtype, head_num, head_dim + (144 // 2), layer_num, always_copy, mem_fraction) - def copy_kv_to_mem_manager(self, layer_index: int, mem_index: 
torch.Tensor, kv: torch.Tensor): - """ - 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 - """ - from ..basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv - - rope_dim = 64 - kv_lora_rank = kv.shape[2] - rope_dim - assert kv_lora_rank + rope_dim == self.kv_buffer.shape[-1] - (144 // 2) - - destindex_copy_kv( - kv[:, :, :kv_lora_rank], - kv[:, :, kv_lora_rank:], - mem_index, - self.kv_buffer[layer_index][:, :, :kv_lora_rank], - self.kv_buffer[layer_index][:, :, kv_lora_rank : (kv_lora_rank + rope_dim)], - ) - return - def get_att_input_params(self, layer_index: int) -> Any: kv = self.kv_buffer[layer_index][:, :, : (self.head_dim - (144 // 2))] return kv diff --git a/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py index b4464cd12d..b72587545f 100644 --- a/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py @@ -2,9 +2,13 @@ from typing import Any from .deepseek2_mem_manager import Deepseek2MemoryManager +from .operator import FP8PerTokenGroupQuantDeepseek3_2MemOperator class FP8PerTokenGroupQuantDeepseek3_2MemoryManager(Deepseek2MemoryManager): + + operator_class = FP8PerTokenGroupQuantDeepseek3_2MemOperator + kv_nope_dim = 512 kv_rope_dim = 64 # 576 = 512 + 64 @@ -34,27 +38,6 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.prefill_dtype = dtype super().__init__(size, torch.uint8, head_num, self.total_bytes_per_token, layer_num, always_copy, mem_fraction) - def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): - from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_kv_flashmla_fp8 import ( - destindex_copy_kv_flashmla_fp8, - ) - - rope_dim = 64 - kv_lora_rank = kv.shape[2] - rope_dim - assert kv_lora_rank == 
512, f"Expected kv_lora_rank=512, got {kv_lora_rank}" - - o_nope = self.kv_buffer[layer_index][:, :, :512].view(torch.float8_e4m3fn) - o_scale = self.kv_buffer[layer_index][:, :, 512:528].view(torch.float32) - o_rope = self.kv_buffer[layer_index][:, :, 528 : self.flashmla_bytes_per_token].view(torch.bfloat16) - destindex_copy_kv_flashmla_fp8( - kv[:, :, :kv_lora_rank], - kv[:, :, kv_lora_rank:], - mem_index, - o_nope, - o_scale, - o_rope, - ) - def get_att_input_params(self, layer_index: int) -> Any: return self.kv_buffer[layer_index][:, :, : self.flashmla_bytes_per_token] diff --git a/lightllm/common/kv_cache_mem_manager/fp8_static_per_head_quant_mem_manager.py b/lightllm/common/kv_cache_mem_manager/fp8_static_per_head_quant_mem_manager.py index 52fd31fcd8..7980ca2dd7 100755 --- a/lightllm/common/kv_cache_mem_manager/fp8_static_per_head_quant_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/fp8_static_per_head_quant_mem_manager.py @@ -8,11 +8,15 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp from .mem_manager import MemoryManager +from .operator import FP8StaticPerHeadQuantMemOperator logger = init_logger(__name__) class FP8StaticPerHeadQuantMemManager(MemoryManager): + + operator_class = FP8StaticPerHeadQuantMemOperator + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): # 这里用uint8存储量化后的kv,方便兼容各种torch算子。fp8量化目前采用离线方案,kv_buffer不存储scale super().__init__(size, torch.uint8, head_num, head_dim, layer_num, always_copy, mem_fraction) @@ -72,21 +76,6 @@ def _load_and_check_config(self): f"kv_quant_calibration_config {get_env_start_args().kv_quant_calibration_config_path} not found" ) - def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): - """ - 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 - """ - from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import 
destindex_copy_kv_fp8 - - scales = self.scales - destindex_copy_kv_fp8( - kv, - mem_index, - scales[layer_index], - self.kv_buffer[layer_index].view(torch.float8_e4m3fn), - ) - return - def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: k = self.kv_buffer[layer_index][:, : self.head_num, :] v = self.kv_buffer[layer_index][:, self.head_num :, :] diff --git a/lightllm/common/kv_cache_mem_manager/fp8_static_per_tensor_quant_mem_manager.py b/lightllm/common/kv_cache_mem_manager/fp8_static_per_tensor_quant_mem_manager.py index 4fa71415b8..06338808bb 100755 --- a/lightllm/common/kv_cache_mem_manager/fp8_static_per_tensor_quant_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/fp8_static_per_tensor_quant_mem_manager.py @@ -7,11 +7,15 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from .mem_manager import MemoryManager +from .operator import FP8StaticPerTensorQuantMemOperator logger = init_logger(__name__) class FP8StaticPerTensorQuantMemManager(MemoryManager): + + operator_class = FP8StaticPerTensorQuantMemOperator + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): # 这里用uint8存储量化后的kv,方便兼容各种torch算子。fp8量化目前采用离线方案,kv_buffer不存储scale super().__init__(size, torch.uint8, head_num, head_dim, layer_num, always_copy, mem_fraction) @@ -60,21 +64,6 @@ def _load_and_check_config(self): f"kv_quant_calibration_config {get_env_start_args().kv_quant_calibration_config_path} not found" ) - def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): - """ - 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 - """ - from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 - - destindex_copy_kv_fp8( - kv, - mem_index, - self.scales[layer_index], - self.kv_buffer[layer_index].view(torch.float8_e4m3fn), - is_per_tensor_quant=True, - ) - return - def get_att_input_params(self, layer_index: int) -> 
Tuple[Any, Any]: k = self.kv_buffer[layer_index][:, : self.head_num, :] v = self.kv_buffer[layer_index][:, self.head_num :, :] diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 1203cbdec7..658d3e899c 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -3,16 +3,13 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from typing import List, Union, Tuple, Any -from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp -from lightllm.server.pd_io_struct import KVMoveTask +from typing import List, Tuple, Any, Union from lightllm.utils.log_utils import init_logger from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from .allocator import KvCacheAllocator from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory -from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args -from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.config_utils import get_num_key_value_heads from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io @@ -20,12 +17,15 @@ from lightllm.utils.shm_utils import create_or_link_shm from multiprocessing.reduction import ForkingPickler from filelock import FileLock - +from .operator import BaseMemManagerOperator, NormalMemOperator logger = init_logger(__name__) class MemoryManager: + + operator_class = NormalMemOperator + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): self.size = size self.head_num = head_num @@ -36,27 +36,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, 
always_copy=False # profile the max total token num if the size is None self.profile_size(mem_fraction) - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._mem_state_return = torch.arange( - 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._return_start = 0 - self.mark_start = 0 - self.mark_end = self.size - - self.can_use_mem_size = self.size - - # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - from lightllm.utils.envs_utils import get_unique_server_name - - rank_in_node = get_current_rank_in_node() - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) + self.allocator = KvCacheAllocator(self.size) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( self.size, dtype, @@ -66,14 +47,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False ) self.HOLD_TOKEN_MEMINDEX = self.size - def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): - """ - 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 - """ - from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv - - destindex_copy_kv(kv, mem_index, self.kv_buffer[layer_index]) - return + # 构建对外的操作类接口 + self.operator: BaseMemManagerOperator = self.operator_class(self) def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: k = self.kv_buffer[layer_index][:, : self.head_num, :] @@ -87,9 +62,9 @@ def profile_size(self, mem_fraction): if self.size is not None: return + torch.cuda.empty_cache() world_size = dist.get_world_size() - total_memory = get_total_gpu_memory() - available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction) + available_memory = get_available_gpu_memory(world_size) - get_total_gpu_memory() * (1 - mem_fraction) cell_size = 
self.get_cell_size() self.size = int(available_memory * 1024 ** 3 / cell_size) if world_size > 1: @@ -110,23 +85,7 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): # 成员变量中,其与 req_manager 中的HOLD_REQUEST_ID具有类似的作用和意义。 self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda") - def alloc_kv_move_buffer(self, max_req_total_len): - """ - pd 分离模式使用的特殊接口 - """ - if isinstance(self, MemoryManager) and type(self) is not MemoryManager: - raise NotImplementedError("subclass need reimpl this method") - self.kv_move_buffer = torch.empty( - (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda" - ) - self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda") - self.token_dim_size = self.kv_move_buffer.shape[-2] * self.kv_move_buffer.shape[-1] - return - def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: - if isinstance(self, MemoryManager) and type(self) is not MemoryManager: - raise NotImplementedError("subclass need reimpl this method") - num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir) self.kv_move_buffer = torch.empty( (page_num, page_size, self.layer_num, 2 * num_kv_head, self.head_dim), dtype=self.dtype, device="cuda" @@ -143,7 +102,10 @@ def write_mem_to_page_kv_move_buffer( dp_index: int, mem_managers: List["MemoryManager"], dp_world_size: int, + page_kind: str = "kv", + req_idx: int = None, ): + assert page_kind == "kv", f"{type(self).__name__} does not support page_kind={page_kind}" cur_page = self.kv_move_buffer[page_index] pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] pin_mem_indexes.numpy()[:] = mem_indexes @@ -172,7 +134,10 @@ def read_page_kv_move_buffer_to_mem( dp_index: int, mem_managers: List["MemoryManager"], dp_world_size: int, + page_kind: str = "kv", + req_idx: int = None, ): + assert page_kind == "kv", f"{type(self).__name__} 
does not support page_kind={page_kind}" cur_page = self.kv_move_buffer[page_index] pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] pin_mem_indexes.numpy()[:] = mem_indexes @@ -194,205 +159,17 @@ def read_page_kv_move_buffer_to_mem( # logger.info(f"dst token tensor {self.kv_buffer[:, mem_indexes[0], 0, 0]}") # logger.info(f"dst page token tensor {cur_page[0, :, 0, 0]}") - def send_to_decode_node( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - assert dp_size_in_node == 1 - - # 先将数据发送到指定的一张卡上的buffer,再发送。 - - move_token_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) - - cur_device_index = self.kv_buffer.get_device() - cur_mem = mem_managers[cur_device_index] - for i, mem in enumerate(mem_managers): - for layer_index in range(mem.layer_num): - move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index) - if i == cur_device_index: - nccl_comm.send(move_buffer, dst=1) - else: - move_size = move_buffer.numel() - new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape) - from torch.cuda import comm - - comm.broadcast(move_buffer, out=[new_move_buffer]) - nccl_comm.send(new_move_buffer, dst=1) - return - - def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): - move_size = self.token_dim_size * len(token_indexes) - move_buffer = self.kv_move_buffer.view(-1)[0:move_size].view( - 1, len(token_indexes), 2 * self.head_num, self.head_dim - ) - move_buffer[:, :, :, :] = self.kv_buffer[layer_index, token_indexes, :, :] - return move_buffer - - def receive_from_prefill_node( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - assert dp_size_in_node == 1 - - # 先将数据接受到指定的一张卡上的buffer,再复制到其他的卡上。 - - 
move_token_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) - - cur_device_index = self.kv_buffer.get_device() - token_num = len(move_token_indexes) - move_size = self.token_dim_size * token_num - recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim) - for i, mem in enumerate(mem_managers): - for layer_index in range(mem.layer_num): - nccl_comm.recv(recive_buffer, src=0) - if i == cur_device_index: - mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index) - else: - new_recive_buffer = mem.kv_move_buffer.view(-1)[0:move_size].view(recive_buffer.shape) - from torch.cuda import comm - - comm.broadcast(recive_buffer, out=[new_recive_buffer]) - mem._write_kv_move_data(move_token_indexes, new_recive_buffer, layer_index) - return - - def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index): - self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor - return - - def send_to_decode_node_p2p( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - """ - 使用 p2p triton kernel 进行数据复制和传输的实现方式。 - """ - assert dp_size_in_node == 1 - - # 先将数据发送到指定的一张卡上的buffer,再发送。 - - move_token_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) - - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") - for i, mem in enumerate(mem_managers): - for layer_index in range(mem.layer_num): - move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer) - nccl_comm.send(move_buffer, dst=1) - return - - def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor): - move_token_num = 
len(token_indexes) - move_size = self.token_dim_size * move_token_num - move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, 2 * self.head_num, self.head_dim) - kv_trans( - self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num] - ) - return move_buffer - - def receive_from_prefill_node_p2p( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - assert dp_size_in_node == 1 - - # 先将数据接受到指定的一张卡上的buffer,再复制到其他的卡上。 - - move_token_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) - - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") - - token_num = len(move_token_indexes) - move_size = self.token_dim_size * token_num - recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim) - for i, mem in enumerate(mem_managers): - for layer_index in range(mem.layer_num): - nccl_comm.recv(recive_buffer, src=0) - mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index) - return - - def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index): - move_token_num = len(token_indexes) - kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes) - return - def _free_buffers(self): self.kv_buffer = None def alloc(self, need_size) -> torch.Tensor: - if need_size > self.mark_end - self.mark_start: - logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") - assert False, "error alloc state" - - start = self.mark_start - end = self.mark_start + need_size - self.mark_start += need_size - - self.can_use_mem_size -= need_size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - # 
利用缓冲区返回,避免异步情况下的内存竞争 - if self._return_start + need_size > self._mem_state_return.shape[0]: - self._return_start = 0 - ans = self._mem_state_return[self._return_start : self._return_start + need_size] - ans.copy_(self.mem_state[start:end]) - self._return_start += need_size - return ans - - def free(self, free_index: Union[torch.Tensor, List[int]]): - """_summary_ - - Args: - free_index (torch.Tensor): _description_ - """ + return self.allocator.alloc(need_size) - end = self.mark_start - start = self.mark_start - len(free_index) - assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" - - if isinstance(free_index, list): - self.mem_state.numpy()[start:end] = free_index - else: - # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 - self.mem_state[start:end] = free_index - - self.mark_start -= len(free_index) - - self.can_use_mem_size += len(free_index) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - if self.can_use_mem_size == len(self.mem_state): - logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") - return + def free(self, free_index: Union[torch.Tensor, List[int]]) -> None: + self.allocator.free(free_index) def free_all(self): - self.can_use_mem_size = len(self.mem_state) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) - self.mark_start = 0 - self.mark_end = len(self.mem_state) + self.allocator.free_all() def resize_mem(self, new_size): """ @@ -405,13 +182,8 @@ def resize_mem(self, new_size): layer_num = self.layer_num self.size = new_size - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self.mark_start = 0 - self.mark_end = self.size - self.can_use_mem_size = self.size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.allocator.resize(new_size) + self.HOLD_TOKEN_MEMINDEX = self.size self._free_buffers() self._init_buffers(size, dtype, 
head_num, head_dim, layer_num) return @@ -422,39 +194,12 @@ def get_index_kv_buffer(self, index): def load_index_kv_buffer(self, index, load_tensor_dict): self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) - def copy_kv_from_other_dp_ranks( - self, - mem_managers: List["MemoryManager"], - move_token_indexes: torch.Tensor, - token_dp_indexes: torch.Tensor, - mem_indexes: torch.Tensor, - dp_size_in_node: int, - rank_in_dp: int, - ): - if not hasattr(self, "mem_ptrs_tensor"): - # 构建一个2D tensor,shape为(layer_num, mem_num) - mems_ptr_list = [] - for i in range(0, len(mem_managers)): - mems_ptr_list.append(mem_managers[i].kv_buffer.data_ptr()) - self.mem_ptrs_tensor = torch.tensor(mems_ptr_list, dtype=torch.uint64, device="cpu", pin_memory=True) - - # 一次性传输所有层 - kv_trans_for_dp( - input_mems=self.mem_ptrs_tensor.cuda(non_blocking=True), - input_idx=move_token_indexes, - input_dp_idx=token_dp_indexes, - output=self.kv_buffer, - output_idx=mem_indexes, - dp_size_in_node=dp_size_in_node, - rank_in_dp=rank_in_dp, - ) - def write_to_shm(self, req_manager): """ 将 mem manager 写入到 shm中,方便pd分离等特性直接从中读取,不依赖进程间队列。 """ if kv_trans_use_p2p(): - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor + from lightllm.server.router.model_infer.mode_backend.pd.p2p_fix import reduce_tensor mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ diff --git a/lightllm/common/kv_cache_mem_manager/operator/__init__.py b/lightllm/common/kv_cache_mem_manager/operator/__init__.py new file mode 100644 index 0000000000..85c37ad39b --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/operator/__init__.py @@ -0,0 +1,13 @@ +from .base import BaseMemManagerOperator +from .normal import NormalMemOperator +from .quant import QuantScaleMemOperator, PPLInt4KVMemOperator, PPLInt8KVMemOperator +from .linear_att import LinearAttMemOperator +from .deepseek import ( + Deepseek2MemOperator, + Deepseek3_2MemOperator, + 
FP8PerTokenGroupQuantDeepseek3_2MemOperator, +) +from .fp8_quant import ( + FP8StaticPerHeadQuantMemOperator, + FP8StaticPerTensorQuantMemOperator, +) diff --git a/lightllm/common/kv_cache_mem_manager/operator/base.py b/lightllm/common/kv_cache_mem_manager/operator/base.py new file mode 100644 index 0000000000..682b8a5d60 --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/operator/base.py @@ -0,0 +1,48 @@ +import torch +from abc import ABC, abstractmethod +from typing import List, TYPE_CHECKING + +if TYPE_CHECKING: + from ..mem_manager import MemoryManager + from lightllm.server.router.model_infer.infer_batch import InferReq + +# 定义一个抽象基类 +class BaseMemManagerOperator(ABC): + def __init__(self, mem_manager: "MemoryManager") -> None: + super().__init__() + self.mem_manager = mem_manager + + @abstractmethod + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + pass + + def copy_mem_to_mem(self, src_mem_index: torch.Tensor, dst_mem_index: torch.Tensor): + raise NotImplementedError() + + # cpu cache 的相关操作接口 + def load_cpu_cache_to_gpu( + self, mem_indexes: torch.Tensor, page_indexes: torch.Tensor, cpu_cache_client, req: "InferReq" + ): + raise NotImplementedError() + + def offload_gpu_kv_to_cpu_cache( + self, + mem_indexes: torch.Tensor, + page_indexes: torch.Tensor, + page_readies: torch.Tensor, + cpu_cache_client, + req: "InferReq", + ): + raise NotImplementedError() + + # dp 间共享 kv 的操作接口, 提升dp 模式下的kv 共享效率,降低调度的难度 + def copy_kv_from_other_dp_ranks( + self, + mem_managers: List["MemoryManager"], + move_token_indexes: torch.Tensor, + token_dp_indexes: torch.Tensor, + mem_indexes: torch.Tensor, + dp_size_in_node: int, + rank_in_dp: int, + ): + raise NotImplementedError() diff --git a/lightllm/common/kv_cache_mem_manager/operator/deepseek.py b/lightllm/common/kv_cache_mem_manager/operator/deepseek.py new file mode 100644 index 0000000000..6e05b96e10 --- /dev/null +++ 
b/lightllm/common/kv_cache_mem_manager/operator/deepseek.py @@ -0,0 +1,80 @@ +import torch +from .normal import NormalMemOperator +from .base import BaseMemManagerOperator +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Deepseek2MemOperator(NormalMemOperator): + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager + + mem_manager: Deepseek2MemoryManager = self.mem_manager + + from ...basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv + + rope_dim = 64 + kv_lora_rank = kv.shape[2] - rope_dim + assert kv_lora_rank + rope_dim == mem_manager.kv_buffer.shape[-1] + + destindex_copy_kv( + kv[:, :, :kv_lora_rank], + kv[:, :, kv_lora_rank:], + mem_index, + mem_manager.kv_buffer[layer_index][:, :, :kv_lora_rank], + mem_manager.kv_buffer[layer_index][:, :, kv_lora_rank:], + ) + return + + +class Deepseek3_2MemOperator(Deepseek2MemOperator): + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + from lightllm.common.kv_cache_mem_manager.deepseek3_2mem_manager import Deepseek3_2MemoryManager + + mem_manager: Deepseek3_2MemoryManager = self.mem_manager + from ...basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv + + rope_dim = 64 + kv_lora_rank = kv.shape[2] - rope_dim + assert kv_lora_rank + rope_dim == mem_manager.kv_buffer.shape[-1] - (144 // 2) + + destindex_copy_kv( + kv[:, :, :kv_lora_rank], + kv[:, :, kv_lora_rank:], + mem_index, + mem_manager.kv_buffer[layer_index][:, :, :kv_lora_rank], + mem_manager.kv_buffer[layer_index][:, :, kv_lora_rank : (kv_lora_rank + rope_dim)], + ) + return + + +class FP8PerTokenGroupQuantDeepseek3_2MemOperator(BaseMemManagerOperator): + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + from 
lightllm.common.kv_cache_mem_manager.fp8_per_token_group_quant_deepseek3_2mem_manager import ( + FP8PerTokenGroupQuantDeepseek3_2MemoryManager, + ) + + mem_manager: FP8PerTokenGroupQuantDeepseek3_2MemoryManager = self.mem_manager + from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_kv_flashmla_fp8 import ( + destindex_copy_kv_flashmla_fp8, + ) + + rope_dim = 64 + kv_lora_rank = kv.shape[2] - rope_dim + assert kv_lora_rank == 512, f"Expected kv_lora_rank=512, got {kv_lora_rank}" + + flashmla_bytes_per_token = mem_manager.flashmla_bytes_per_token + + o_nope = mem_manager.kv_buffer[layer_index][:, :, :512].view(torch.float8_e4m3fn) + o_scale = mem_manager.kv_buffer[layer_index][:, :, 512:528].view(torch.float32) + o_rope = mem_manager.kv_buffer[layer_index][:, :, 528:flashmla_bytes_per_token].view(torch.bfloat16) + destindex_copy_kv_flashmla_fp8( + kv[:, :, :kv_lora_rank], + kv[:, :, kv_lora_rank:], + mem_index, + o_nope, + o_scale, + o_rope, + ) + return diff --git a/lightllm/common/kv_cache_mem_manager/operator/fp8_quant.py b/lightllm/common/kv_cache_mem_manager/operator/fp8_quant.py new file mode 100644 index 0000000000..4b1f0ac207 --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/operator/fp8_quant.py @@ -0,0 +1,48 @@ +import torch +from typing import TYPE_CHECKING +from .base import BaseMemManagerOperator +from lightllm.utils.log_utils import init_logger + +if TYPE_CHECKING: + from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient +logger = init_logger(__name__) + + +class FP8StaticPerHeadQuantMemOperator(BaseMemManagerOperator): + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + from lightllm.common.kv_cache_mem_manager.fp8_static_per_head_quant_mem_manager import ( + FP8StaticPerHeadQuantMemManager, + ) + + mem_manager: FP8StaticPerHeadQuantMemManager = self.mem_manager + from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import ( + 
destindex_copy_kv_fp8, + ) + + scales = mem_manager.scales + destindex_copy_kv_fp8( + kv, + mem_index, + scales[layer_index], + mem_manager.kv_buffer[layer_index].view(torch.float8_e4m3fn), + ) + return + + +class FP8StaticPerTensorQuantMemOperator(BaseMemManagerOperator): + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager + + mem_manager: MemoryManager = self.mem_manager + from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import ( + destindex_copy_kv_fp8, + ) + + destindex_copy_kv_fp8( + kv, + mem_index, + mem_manager.scales[layer_index], + mem_manager.kv_buffer[layer_index].view(torch.float8_e4m3fn), + is_per_tensor_quant=True, + ) + return diff --git a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py new file mode 100644 index 0000000000..147b43f697 --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py @@ -0,0 +1,211 @@ +import torch +import triton +from typing import List +from typing import TYPE_CHECKING +from .base import BaseMemManagerOperator +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size +from lightllm.utils.log_utils import init_logger +from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig + +if TYPE_CHECKING: + from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient + from lightllm.server.router.model_infer.infer_batch import InferReq + +logger = init_logger(__name__) + + +class LinearAttMemOperator(BaseMemManagerOperator): + """ + 只用于非量化的linear att 混合 full att的模型,列入 qwen3.5 + """ + + def __init__(self, mem_manager): + super().__init__(mem_manager) + self.linear_config = LinearAttCacheConfig.load_from_args() + + def load_cpu_cache_to_gpu( + self, + mem_indexes: 
torch.Tensor, + page_indexes: torch.Tensor, + cpu_cache_client: "CpuKvCacheClient", + req: "InferReq", + ): + assert mem_indexes.is_cuda and page_indexes.is_cuda + args = get_env_start_args() + assert triton.cdiv(len(mem_indexes), args.cpu_cache_token_page_size) == len(page_indexes) + assert len(mem_indexes) % args.linear_att_hash_page_size == 0 + assert args.cpu_cache_token_page_size == args.linear_att_hash_page_size * args.linear_att_page_block_num + from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextMemManager + + mem_manager: Qwen3NextMemManager = self.mem_manager + + big_page_num = len(mem_indexes) // args.cpu_cache_token_page_size + max_kv_len = (req.cur_kv_len // args.cpu_cache_token_page_size) * args.cpu_cache_token_page_size + assert max_kv_len % args.cpu_cache_token_page_size == 0 + + big_page_buffer_ids_cpu = [] + for i in range(big_page_num): + page_id = mem_manager.linear_att_big_page_buffers.alloc_one_state_cache() + assert page_id is not None + req.linear_att_len_to_big_page_id[max_kv_len] = page_id + big_page_buffer_ids_cpu.append(page_id) + max_kv_len -= args.cpu_cache_token_page_size + assert max_kv_len % args.cpu_cache_token_page_size == 0 + + big_page_buffer_ids_cpu.reverse() + + # 碎页情况的处理 + has_tail_page = len(mem_indexes) % args.cpu_cache_token_page_size != 0 + if has_tail_page: + padded_token_num = triton.cdiv( + len(mem_indexes), args.cpu_cache_token_page_size + ) * args.cpu_cache_token_page_size - len(mem_indexes) + mem_indexes = torch.nn.functional.pad(mem_indexes, (0, padded_token_num), mode="constant", value=-1) + + # 将对应的小叶数据拷贝到临时的大页上,再从大页上拷贝到对应的运行态页面上 + big_page_buffer_ids_cpu.append(mem_manager.CPU_CACHE_BIG_PAGE_LOAD_TEMP_BUFFER_ID) + + big_page_buffer_ids_gpu = torch.tensor(big_page_buffer_ids_cpu, dtype=torch.int64, device="cpu").cuda( + non_blocking=True + ) + + assert len(big_page_buffer_ids_gpu) == len(page_indexes) + + from lightllm.common.basemodel.triton_kernel.linear_att_cpu_cache_copy import ( + 
def offload_gpu_kv_to_cpu_cache(
    self,
    mem_indexes: torch.Tensor,
    page_indexes: torch.Tensor,
    page_readies: torch.Tensor,
    cpu_cache_client: "CpuKvCacheClient",
    req: "InferReq",
):
    """Offload a request's GPU KV tokens plus linear-attention big pages to the CPU cache.

    Args:
        mem_indexes: CUDA tensor of token slots; its length must be a multiple of
            ``linear_att_hash_page_size``.
        page_indexes: CUDA tensor with one entry per CPU-cache page
            (``cdiv(len(mem_indexes), cpu_cache_token_page_size)`` entries).
        page_readies: CUDA tensor with the same length as ``page_indexes``.
        cpu_cache_client: provides ``cpu_kv_cache_tensor``, the copy destination.
        req: request whose big-page ids (and, for a partial last page, tail
            small-page state) are read.

    The actual data movement is delegated to ``copy_kv_buffer_to_cpu_cache``.
    """
    args = get_env_start_args()
    page_token_num = args.cpu_cache_token_page_size

    # Staging buffers are created once, on first use, and reused afterwards.
    if not hasattr(self, "big_page_ids_buffer_store"):
        self.big_page_ids_buffer_store = torch.empty((1024 * 1024 * 4,), dtype=torch.int64, device="cuda")
        # Reserve three extra CPU-cache pages worth of slots for the partial-page
        # case, where the token index list is padded up to a whole page.
        self.mem_indexes_buffer = torch.empty(
            (args.max_req_total_len + 3 * page_token_num,), dtype=torch.int32, device="cuda"
        )

    assert all(tensor.is_cuda for tensor in (mem_indexes, page_indexes, page_readies))

    token_num = len(mem_indexes)
    assert token_num % args.linear_att_hash_page_size == 0
    assert triton.cdiv(token_num, page_token_num) == len(page_indexes)

    from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextMemManager

    mem_manager: Qwen3NextMemManager = self.mem_manager

    from lightllm.server.router.model_infer.infer_batch import g_infer_context

    full_page_num, tail_token_num = divmod(token_num, page_token_num)
    covered_kv_len = full_page_num * page_token_num

    # Start from the big pages recorded for the shared radix node, then append the
    # pages this request registered itself, keyed by sequence length.
    # NOTE(review): the list returned by get_big_page_ids_by_node is extended in place,
    # exactly as before - confirm the radix cache hands out a private copy.
    big_page_ids = g_infer_context.radix_cache.get_big_page_ids_by_node(req.shared_kv_node)
    first_own_len = (len(big_page_ids) + 1) * page_token_num
    for kv_len in range(first_own_len, covered_kv_len + 1, page_token_num):
        big_page_ids.append(req.linear_att_len_to_big_page_id[kv_len])

    if tail_token_num != 0:
        # A partial last page exists: pad the token indexes with -1 up to a whole page
        # and stage the tail linear-attention state in the dedicated temp big page.
        padded_len = (full_page_num + 1) * page_token_num
        assert padded_len <= self.mem_indexes_buffer.shape[0]
        padded_indexes = self.mem_indexes_buffer[0:padded_len]
        padded_indexes.fill_(-1)
        padded_indexes[0:token_num].copy_(mem_indexes, non_blocking=True)
        mem_indexes = padded_indexes

        assert req.tail_linear_att_small_page_buffer_id is not None
        from lightllm.common.basemodel.triton_kernel.linear_att_cpu_cache_copy import (
            copy_linear_att_state_to_linear_att_state,
        )

        tail_conv, tail_ssm = g_infer_context.radix_cache.linear_att_small_page_buffers.get_state_cache(
            buffer_idx=req.tail_linear_att_small_page_buffer_id
        )
        temp_conv, temp_ssm = mem_manager.linear_att_big_page_buffers.get_state_cache(
            buffer_idx=mem_manager.CPU_CACHE_BIG_PAGE_OFFLOAD_TEMP_BUFFER_ID,
        )
        copy_linear_att_state_to_linear_att_state(
            src_conv_state=tail_conv,
            src_ssm_state=tail_ssm,
            dst_conv_state=temp_conv,
            dst_ssm_state=temp_ssm,
        )
        big_page_ids.append(mem_manager.CPU_CACHE_BIG_PAGE_OFFLOAD_TEMP_BUFFER_ID)

    assert len(big_page_ids) == len(page_indexes) == len(page_readies)

    # Ship the id list to the GPU through pinned host memory into the reusable store.
    pinned_ids = torch.tensor(big_page_ids, dtype=torch.int64, device="cpu", pin_memory=True)
    assert len(pinned_ids) <= self.big_page_ids_buffer_store.shape[0]
    gpu_ids = self.big_page_ids_buffer_store[0 : len(pinned_ids)]
    gpu_ids.copy_(pinned_ids, non_blocking=True)

    from lightllm.common.basemodel.triton_kernel.linear_att_cpu_cache_copy import (
        copy_kv_buffer_to_cpu_cache,
    )

    big_pages = mem_manager.linear_att_big_page_buffers
    copy_kv_buffer_to_cpu_cache(
        mem_indexes=mem_indexes,
        page_indexes=page_indexes,
        page_readies=page_readies,
        big_page_buffer_ids=gpu_ids,
        gpu_kv_full_att_state=mem_manager.kv_buffer,
        cpu_kv_conv_state=big_pages.conv_state_cache.buffer,
        cpu_kv_ssm_state=big_pages.ssm_state_cache.buffer,
        cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor,
        tp_rank=get_current_rank_in_dp(),
        tp_world_size=get_dp_world_size(),
        big_page_token_num=page_token_num,
        linear_config=self.linear_config,
    )
    return


def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
    """Write one layer's freshly produced ``kv`` into the KV buffer at ``mem_index``.

    The incoming model layer index is divided by ``full_attention_interval`` to get
    the index into ``kv_buffer``.
    """
    from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager
    from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import (
        destindex_copy_kv,
    )

    manager: MemoryManager = self.mem_manager
    buffer_layer = layer_index // self.linear_config.full_attention_interval
    destindex_copy_kv(kv, mem_index, manager.kv_buffer[buffer_layer])
    return


def copy_mem_to_mem(self, src_mem_index: torch.Tensor, dst_mem_index: torch.Tensor):
    """Copy KV entries from ``src_mem_index`` slots to ``dst_mem_index`` slots in ``kv_buffer``."""
    from lightllm.common.basemodel.triton_kernel.kv_move import copy_kv_buffer_to_kv_buffer

    src_on_gpu = src_mem_index.cuda(non_blocking=True)
    dst_on_gpu = dst_mem_index.cuda(non_blocking=True)
    copy_kv_buffer_to_kv_buffer(src_on_gpu, dst_on_gpu, self.mem_manager.kv_buffer)
    return
class NormalMemOperator(BaseMemManagerOperator):
    """Operator for memory managers that store plain KV (every scale argument is ``None``)."""

    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
        """Copy the ``kv`` produced by layer ``layer_index`` into ``kv_buffer`` at ``mem_index``."""
        from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager

        mem_manager: MemoryManager = self.mem_manager
        from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import (
            destindex_copy_kv,
        )

        destindex_copy_kv(kv, mem_index, mem_manager.kv_buffer[layer_index])
        return

    def load_cpu_cache_to_gpu(
        self,
        mem_indexes: torch.Tensor,
        page_indexes: torch.Tensor,
        cpu_cache_client: "CpuKvCacheClient",
        req: "InferReq",
    ):
        """Load CPU-cache pages ``page_indexes`` into the GPU token slots ``mem_indexes``.

        Both index tensors must live on CUDA and ``mem_indexes`` must cover whole
        CPU-cache pages. ``req`` is not used by this operator.
        """
        assert mem_indexes.is_cuda and page_indexes.is_cuda
        args = get_env_start_args()
        assert len(mem_indexes) % args.cpu_cache_token_page_size == 0
        from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager

        mem_manager: MemoryManager = self.mem_manager
        from lightllm.common.basemodel.triton_kernel.kv_cache_offload import load_cpu_kv_to_gpu

        load_cpu_kv_to_gpu(
            gpu_mem_indexes=mem_indexes,
            gpu_kv_cache=mem_manager.kv_buffer,
            gpu_kv_cache_scale=None,
            cpu_kv_cache=cpu_cache_client.cpu_kv_cache_tensor,
            cpu_kv_cache_scale=None,
            page_indexes=page_indexes,
            tp_index=get_current_rank_in_dp(),
            tp_world_size=get_dp_world_size(),
            grid_num=16,
        )
        return

    def offload_gpu_kv_to_cpu_cache(
        self,
        mem_indexes: torch.Tensor,
        page_indexes: torch.Tensor,
        page_readies: torch.Tensor,
        cpu_cache_client: "CpuKvCacheClient",
        req: "InferReq",
    ):
        """Offload the GPU token slots ``mem_indexes`` into CPU-cache pages ``page_indexes``.

        All three tensors must live on CUDA; ``mem_indexes`` must cover exactly
        ``len(page_indexes)`` whole CPU-cache pages. ``req`` is not used by this operator.
        """
        # page_readies is checked too, matching the quantized operator's offload path.
        assert mem_indexes.is_cuda and page_indexes.is_cuda and page_readies.is_cuda
        args = get_env_start_args()
        assert len(mem_indexes) % args.cpu_cache_token_page_size == 0
        assert len(mem_indexes) // args.cpu_cache_token_page_size == len(page_indexes)
        from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager

        mem_manager: MemoryManager = self.mem_manager

        from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu

        offload_gpu_kv_to_cpu(
            token_indexes=mem_indexes,
            gpu_kv_cache=mem_manager.kv_buffer,
            gpu_kv_cache_scale=None,
            cpu_kv_cache=cpu_cache_client.cpu_kv_cache_tensor,
            cpu_kv_cache_scale=None,
            page_indexes=page_indexes,
            page_readies=page_readies,
            tp_index=get_current_rank_in_dp(),
            tp_world_size=get_dp_world_size(),
            grid_num=16,
        )
        return

    def copy_kv_from_other_dp_ranks(
        self,
        mem_managers: List,
        move_token_indexes: torch.Tensor,
        token_dp_indexes: torch.Tensor,
        mem_indexes: torch.Tensor,
        dp_size_in_node: int,
        rank_in_dp: int,
    ):
        """Pull KV for ``move_token_indexes`` from other managers' buffers into ``mem_indexes``.

        The base pointers of every manager's ``kv_buffer`` are collected on the first
        call and cached on the operator, so later calls reuse the first call's
        ``mem_managers``.
        """
        if not hasattr(self, "mem_ptrs_tensor"):
            # 1-D tensor holding one kv_buffer base pointer per memory manager.
            mems_ptr_list = [mem.kv_buffer.data_ptr() for mem in mem_managers]
            self.mem_ptrs_tensor = torch.tensor(mems_ptr_list, dtype=torch.uint64, device="cpu", pin_memory=True)

        from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp

        # The whole kv_buffer (all layers) is handed over in a single call.
        kv_trans_for_dp(
            input_mems=self.mem_ptrs_tensor.cuda(non_blocking=True),
            input_idx=move_token_indexes,
            input_dp_idx=token_dp_indexes,
            output=self.mem_manager.kv_buffer,
            output_idx=mem_indexes,
            dp_size_in_node=dp_size_in_node,
            rank_in_dp=rank_in_dp,
        )
        return
class QuantScaleMemOperator(BaseMemManagerOperator):
    """
    Operator for memory managers whose KV cache keeps a separate scale value for
    every token (``scale_buffer`` next to ``kv_buffer``).
    """

    def _split_cpu_cache_tensor(self, cpu_cache_client: "CpuKvCacheClient"):
        """Split the CPU cache tensor along its last axis into (payload, scale) views.

        The first ``head_dim`` entries are the quantized payload; the remainder is
        reinterpreted with the dtype of the manager's ``scale_buffer``.
        """
        head_dim = cpu_cache_client.kv_cache_tensor_meta.head_dim
        payload = cpu_cache_client.cpu_kv_cache_tensor[:, :, :, :, 0:head_dim]
        scale = cpu_cache_client.cpu_kv_cache_tensor[:, :, :, :, head_dim:].view(self.mem_manager.scale_buffer.dtype)
        return payload, scale

    def load_cpu_cache_to_gpu(
        self,
        mem_indexes: torch.Tensor,
        page_indexes: torch.Tensor,
        cpu_cache_client: "CpuKvCacheClient",
        req: "InferReq",
    ):
        """Load CPU-cache pages (payload and scale) into the GPU slots ``mem_indexes``.

        Both index tensors must be on CUDA and ``mem_indexes`` must cover whole
        CPU-cache pages. ``req`` is not used here.
        """
        assert mem_indexes.is_cuda and page_indexes.is_cuda
        page_token_num = get_env_start_args().cpu_cache_token_page_size
        assert len(mem_indexes) % page_token_num == 0
        from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager

        manager: MemoryManager = self.mem_manager
        cpu_payload, cpu_scale = self._split_cpu_cache_tensor(cpu_cache_client)

        from lightllm.common.basemodel.triton_kernel.kv_cache_offload import load_cpu_kv_to_gpu

        load_cpu_kv_to_gpu(
            gpu_mem_indexes=mem_indexes,
            gpu_kv_cache=manager.kv_buffer,
            gpu_kv_cache_scale=manager.scale_buffer,
            cpu_kv_cache=cpu_payload,
            cpu_kv_cache_scale=cpu_scale,
            page_indexes=page_indexes,
            tp_index=get_current_rank_in_dp(),
            tp_world_size=get_dp_world_size(),
            grid_num=16,
        )
        return

    def offload_gpu_kv_to_cpu_cache(
        self,
        mem_indexes: torch.Tensor,
        page_indexes: torch.Tensor,
        page_readies: torch.Tensor,
        cpu_cache_client: "CpuKvCacheClient",
        req: "InferReq",
    ):
        """Offload GPU slots ``mem_indexes`` (payload and scale) into CPU-cache pages.

        All three tensors must be on CUDA and ``mem_indexes`` must cover exactly
        ``len(page_indexes)`` whole CPU-cache pages. ``req`` is not used here.
        """
        assert mem_indexes.is_cuda and page_indexes.is_cuda and page_readies.is_cuda
        page_token_num = get_env_start_args().cpu_cache_token_page_size
        assert len(mem_indexes) % page_token_num == 0
        assert len(mem_indexes) // page_token_num == len(page_indexes)
        from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager

        manager: MemoryManager = self.mem_manager
        cpu_payload, cpu_scale = self._split_cpu_cache_tensor(cpu_cache_client)

        from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu

        offload_gpu_kv_to_cpu(
            token_indexes=mem_indexes,
            gpu_kv_cache=manager.kv_buffer,
            gpu_kv_cache_scale=manager.scale_buffer,
            cpu_kv_cache=cpu_payload,
            cpu_kv_cache_scale=cpu_scale,
            page_indexes=page_indexes,
            page_readies=page_readies,
            tp_index=get_current_rank_in_dp(),
            tp_world_size=get_dp_world_size(),
            grid_num=16,
        )
        return


class PPLInt4KVMemOperator(QuantScaleMemOperator):
    """Quant-scale operator whose per-layer write goes through ``destindex_copy_int4kv``."""

    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
        """Write ``kv`` of layer ``layer_index`` into the kv and scale buffers at ``mem_index``."""
        from lightllm.common.kv_cache_mem_manager.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
        from ...basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import (
            destindex_copy_int4kv,
        )

        manager: PPLINT4KVMemoryManager = self.mem_manager
        layer_kv = manager.kv_buffer[layer_index]
        layer_scale = manager.scale_buffer[layer_index]
        destindex_copy_int4kv(kv, mem_index, layer_kv, layer_scale, quant_group_size=manager.group_quant_size)
        return


class PPLInt8KVMemOperator(QuantScaleMemOperator):
    """Quant-scale operator whose per-layer write goes through ``destindex_copy_quantize_kv``."""

    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
        """Write ``kv`` of layer ``layer_index`` into the kv and scale buffers at ``mem_index``."""
        from lightllm.common.kv_cache_mem_manager.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
        from ...basemodel.triton_kernel.kv_copy.ppl_int8kv_copy_kv import (
            destindex_copy_quantize_kv,
        )

        manager: PPLINT8KVMemoryManager = self.mem_manager
        layer_kv = manager.kv_buffer[layer_index]
        layer_scale = manager.scale_buffer[layer_index]
        destindex_copy_quantize_kv(kv, mem_index, layer_kv, layer_scale, quant_group_dim=manager.group_quant_size)
        return
class Qwen3NextMemManager(MemoryManager):
    """Memory manager that adds linear-attention big-page state buffers to the base KV buffers.

    Only the full-attention layers get a slot in ``kv_buffer``; model layer indexes
    are divided by ``linear_config.full_attention_interval`` before indexing.
    """

    operator_class = LinearAttMemOperator

    def __init__(
        self,
        size,
        dtype,
        num_kv_heads,
        head_dim,
        full_att_layer_num,
        linear_config: LinearAttCacheConfig,
        always_copy=False,
        mem_fraction=0.9,
    ):
        # linear_config must be set first: the base constructor is given only the
        # number of full-attention layers, and _init_buffers below relies on it.
        self.linear_config = linear_config

        super().__init__(
            size,
            dtype,
            num_kv_heads,
            head_dim,
            full_att_layer_num,
            always_copy,
            mem_fraction,
        )

    def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]:
        """Return the base-class attention inputs for the buffer slot of ``layer_index``."""
        buffer_layer = layer_index // self.linear_config.full_attention_interval
        return super().get_att_input_params(buffer_layer)

    def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
        """Create the base KV buffers, then the linear-attention big-page buffers."""
        super()._init_buffers(size, dtype, head_num, head_dim, layer_num)
        # TODO: initialise the buffers that belong to linear attention.
        self._init_linear_att_buffers()
        return

    def _init_linear_att_buffers(self):
        """Allocate the big-page state cache and reserve its last two slots as temp pages."""
        page_block_num = get_env_start_args().linear_att_page_block_num
        hash_page_size = get_env_start_args().linear_att_hash_page_size
        big_page_token_num = page_block_num * hash_page_size
        # Two extra linear-attention states are allocated on top of what the token
        # capacity needs. In theory they are never handed out by alloc; they only serve
        # the CPU-cache path as transit storage when copying a partially filled page.
        self.linear_att_big_page_buffers = LinearAttCacheManager(
            size=triton.cdiv(self.size, big_page_token_num) + 2,
            linear_config=self.linear_config,
            keep_num=2,
        )

        slot_count = self.linear_att_big_page_buffers.size
        self.CPU_CACHE_BIG_PAGE_LOAD_TEMP_BUFFER_ID = slot_count - 2
        self.CPU_CACHE_BIG_PAGE_OFFLOAD_TEMP_BUFFER_ID = slot_count - 1
        return

    def _free_buffers(self):
        """Release the base buffers, then the linear-attention buffers."""
        super()._free_buffers()
        self._free_linear_att_buffers()
        return

    def _free_linear_att_buffers(self):
        """Drop the reference to the big-page state cache."""
        self.linear_att_big_page_buffers = None
        return

    def write_to_shm(self, req_manager):
        """Remember the request manager's conv/ssm state holders, then defer to the base class."""
        self.req_to_conv_state = req_manager.req_to_conv_state
        self.req_to_ssm_state = req_manager.req_to_ssm_state
        return super().write_to_shm(req_manager)

    def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
        """Allocate the paged move buffer and check one page can hold a full linear-att state."""
        move_buffer = super().alloc_paged_kv_move_buffer(page_num, page_size)
        Qwen3NextLinearAttPageHelper(self).assert_page_size()
        return move_buffer

    def write_mem_to_page_kv_move_buffer(
        self,
        mem_indexes,
        page_index: int,
        dp_index: int,
        mem_managers,
        dp_world_size: int,
        page_kind: str = "kv",
        req_idx: int = None,
    ):
        """Fill move-buffer page ``page_index``.

        ``page_kind == "kv"`` is handled by the base class. ``"linear_att_state"``
        writes the conv/ssm state of request ``req_idx`` (required) into the page.
        Any other kind fails an assertion.
        """
        if page_kind != "kv":
            assert page_kind == "linear_att_state", f"unknown page_kind={page_kind}"
            assert req_idx is not None
            page_helper = Qwen3NextLinearAttPageHelper(self)
            group_mems = page_helper.get_dp_mems(mem_managers, dp_index, dp_world_size)
            page_helper.write_req_to_page(page_index=page_index, req_idx=req_idx, dp_mems=group_mems)
            return

        return super().write_mem_to_page_kv_move_buffer(
            mem_indexes=mem_indexes,
            page_index=page_index,
            dp_index=dp_index,
            mem_managers=mem_managers,
            dp_world_size=dp_world_size,
            page_kind=page_kind,
            req_idx=req_idx,
        )

    def read_page_kv_move_buffer_to_mem(
        self,
        mem_indexes,
        page_index: int,
        dp_index: int,
        mem_managers,
        dp_world_size: int,
        page_kind: str = "kv",
        req_idx: int = None,
    ):
        """Consume move-buffer page ``page_index``.

        ``page_kind == "kv"`` is handled by the base class. ``"linear_att_state"``
        copies the page into the conv/ssm state of request ``req_idx`` (required).
        Any other kind fails an assertion.
        """
        if page_kind != "kv":
            assert page_kind == "linear_att_state", f"unknown page_kind={page_kind}"
            assert req_idx is not None
            page_helper = Qwen3NextLinearAttPageHelper(self)
            group_mems = page_helper.get_dp_mems(mem_managers, dp_index, dp_world_size)
            page_helper.read_page_to_req(page_index=page_index, req_idx=req_idx, dp_mems=group_mems)
            return

        return super().read_page_kv_move_buffer_to_mem(
            mem_indexes=mem_indexes,
            page_index=page_index,
            dp_index=dp_index,
            mem_managers=mem_managers,
            dp_world_size=dp_world_size,
            page_kind=page_kind,
            req_idx=req_idx,
        )
class Qwen3NextLinearAttPageHelper:
    """Packs and unpacks one request's linear-attention state into a kv-move-buffer page.

    A page is viewed as raw bytes: the conv state first, then (at ``ssm_offset``) the
    ssm state. Both regions use the *global* head layout; each rank's local slice is
    mapped into it by ``_get_head_slice``.
    """

    def __init__(self, mem_manager: "Qwen3NextMemManager"):
        self.mem_manager = mem_manager
        self.linear_config = mem_manager.linear_config
        self.req_to_conv_state = mem_manager.req_to_conv_state
        self.req_to_ssm_state = mem_manager.req_to_ssm_state
        self.global_linear_k_heads = self.linear_config.global_linear_k_heads
        self.global_linear_v_heads = self.linear_config.global_linear_v_heads

        self.global_q_dim = self.global_linear_k_heads * self.linear_config.head_linear_k_dim
        self.global_k_dim = self.global_q_dim
        self.global_v_heads = self.global_linear_v_heads
        self.global_v_dim = self.global_v_heads * self.linear_config.head_linear_v_dim
        # conv state follows mixed_qkv layout: [q, k, v], each as a flat channel block.
        self.conv_shape = (
            self.linear_config.linear_layer_num,
            self.global_q_dim + self.global_k_dim + self.global_v_dim,
            self.linear_config.conv_kernel_size - 1,
        )
        self.ssm_shape = (
            self.linear_config.linear_layer_num,
            self.global_v_heads,
            self.linear_config.head_linear_k_dim,
            self.linear_config.head_linear_v_dim,
        )

        self.conv_nbytes = (
            self.conv_shape[0] * self.conv_shape[1] * self.conv_shape[2] * self.req_to_conv_state.buffer.element_size()
        )
        ssm_alignment = self.req_to_ssm_state.buffer.element_size()
        # Round the ssm start up to a multiple of its element size so the byte slice
        # taken at that offset is aligned for the ssm dtype.
        self.ssm_offset = ((self.conv_nbytes + ssm_alignment - 1) // ssm_alignment) * ssm_alignment
        self.ssm_nbytes = (
            self.ssm_shape[0]
            * self.ssm_shape[1]
            * self.ssm_shape[2]
            * self.ssm_shape[3]
            * self.req_to_ssm_state.buffer.element_size()
        )
        self.state_nbytes = self.ssm_offset + self.ssm_nbytes

    def assert_page_size(self):
        """Assert that one kv-move-buffer page has room for a full global state."""
        kv_move_buffer = self.mem_manager.kv_move_buffer
        page_nbytes = kv_move_buffer[0].numel() * kv_move_buffer.element_size()
        assert (
            page_nbytes >= self.state_nbytes
        ), f"nixl kv move page bytes {page_nbytes} is smaller than global linear att state bytes {self.state_nbytes}"
        return

    def get_dp_mems(self, mem_managers: List["Qwen3NextMemManager"], dp_index: int, dp_world_size: int):
        """Return the ``dp_world_size`` managers of group ``dp_index`` after config sanity checks."""
        dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
        assert len(dp_mems) == dp_world_size
        for mem in dp_mems:
            assert hasattr(mem, "req_to_conv_state") and hasattr(mem, "req_to_ssm_state")
            assert mem.linear_config.linear_layer_num == self.linear_config.linear_layer_num
            assert mem.linear_config.conv_kernel_size == self.linear_config.conv_kernel_size
            assert mem.linear_config.head_linear_k_dim == self.linear_config.head_linear_k_dim
            assert mem.linear_config.head_linear_v_dim == self.linear_config.head_linear_v_dim
            assert mem.linear_config.num_linear_k_heads == self.linear_config.num_linear_k_heads
            assert mem.linear_config.num_linear_v_heads == self.linear_config.num_linear_v_heads
        return dp_mems

    def view_page_to_linear_att_state(self, page_index: int):
        """Return ``(conv_page, ssm_page)`` views over page ``page_index`` (no copy)."""
        page_bytes = self.mem_manager.kv_move_buffer[page_index].view(torch.uint8).reshape(-1)
        conv_page = page_bytes[0 : self.conv_nbytes].view(self.req_to_conv_state.buffer.dtype).view(self.conv_shape)
        ssm_page = (
            page_bytes[self.ssm_offset : self.ssm_offset + self.ssm_nbytes]
            .view(self.req_to_ssm_state.buffer.dtype)
            .view(self.ssm_shape)
        )
        return conv_page, ssm_page

    def write_req_to_page(
        self,
        page_index: int,
        req_idx: int,
        dp_mems: List["Qwen3NextMemManager"],
    ):
        """Gather request ``req_idx``'s state from every rank in ``dp_mems`` into the page."""
        conv_page, ssm_page = self.view_page_to_linear_att_state(page_index)
        # Each request owns (mtp_step + 1) consecutive state slots; the first one is used.
        req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1)
        for tp_index, mem in enumerate(dp_mems):
            self._write_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page)
        return

    def read_page_to_req(
        self,
        page_index: int,
        req_idx: int,
        dp_mems: List["Qwen3NextMemManager"],
    ):
        """Scatter the page into request ``req_idx``'s state on every rank in ``dp_mems``."""
        conv_page, ssm_page = self.view_page_to_linear_att_state(page_index)
        req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1)
        for tp_index, mem in enumerate(dp_mems):
            self._read_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page)
        return

    def _write_one_rank(
        self,
        mem: "Qwen3NextMemManager",
        tp_index: int,
        req_buffer_idx: int,
        conv_page: torch.Tensor,
        ssm_page: torch.Tensor,
    ):
        """Copy one rank's conv and ssm state for slot ``req_buffer_idx`` into the page."""
        conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...]
        ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...]
        self._copy_conv_state_to_page(conv_state, conv_page, mem, tp_index)
        self._copy_ssm_state_to_page(ssm_state, ssm_page, mem, tp_index)
        return

    def _read_one_rank(
        self,
        mem: "Qwen3NextMemManager",
        tp_index: int,
        req_buffer_idx: int,
        conv_page: torch.Tensor,
        ssm_page: torch.Tensor,
    ):
        """Copy the page into one rank's conv and ssm state for slot ``req_buffer_idx``."""
        conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...]
        ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...]
        self._copy_page_to_conv_state(conv_page, conv_state, mem, tp_index)
        self._copy_page_to_ssm_state(ssm_page, ssm_state, mem, tp_index)
        return

    def _copy_conv_state_to_page(
        self,
        conv_state: torch.Tensor,
        conv_page: torch.Tensor,
        mem: "Qwen3NextMemManager",
        tp_index: int,
    ):
        """Local conv state -> global conv page."""
        self._transfer_conv(conv_state, conv_page, mem, tp_index, is_write=True)
        return

    def _copy_page_to_conv_state(
        self,
        conv_page: torch.Tensor,
        conv_state: torch.Tensor,
        mem: "Qwen3NextMemManager",
        tp_index: int,
    ):
        """Global conv page -> local conv state."""
        self._transfer_conv(conv_state, conv_page, mem, tp_index, is_write=False)
        return

    def _copy_ssm_state_to_page(
        self,
        ssm_state: torch.Tensor,
        ssm_page: torch.Tensor,
        mem: "Qwen3NextMemManager",
        tp_index: int,
    ):
        """Local ssm state -> global ssm page."""
        self._transfer_ssm(ssm_state, ssm_page, mem, tp_index, is_write=True)
        return

    def _copy_page_to_ssm_state(
        self,
        ssm_page: torch.Tensor,
        ssm_state: torch.Tensor,
        mem: "Qwen3NextMemManager",
        tp_index: int,
    ):
        """Global ssm page -> local ssm state."""
        self._transfer_ssm(ssm_state, ssm_page, mem, tp_index, is_write=False)
        return

    def _transfer_conv(
        self,
        conv_state: torch.Tensor,
        conv_page: torch.Tensor,
        mem: "Qwen3NextMemManager",
        tp_index: int,
        is_write: bool,
    ):
        """Move the q, k and v channel blocks between a rank's conv state and the page.

        ``is_write=True`` copies state -> page, ``False`` copies page -> state.
        """
        config = mem.linear_config
        head_k_dim = config.head_linear_k_dim
        head_v_dim = config.head_linear_v_dim
        local_qk_dim = config.num_linear_k_heads * head_k_dim

        # Both sides are laid out as [q | k | v] along the channel axis.
        local_q = conv_state[:, 0:local_qk_dim, :]
        local_k = conv_state[:, local_qk_dim : 2 * local_qk_dim, :]
        local_v = conv_state[:, 2 * local_qk_dim :, :]
        global_q = conv_page[:, 0 : self.global_q_dim, :]
        global_k = conv_page[:, self.global_q_dim : self.global_q_dim + self.global_k_dim, :]
        global_v = conv_page[:, self.global_q_dim + self.global_k_dim :, :]

        qk_head_slice = self._get_head_slice(
            tp_index, config.num_linear_k_heads, self.global_linear_k_heads, config.tp_world_size, is_write=is_write
        )
        if qk_head_slice is not None:
            self._copy_head_channels(local_q, global_q, qk_head_slice, head_k_dim, is_write)
            self._copy_head_channels(local_k, global_k, qk_head_slice, head_k_dim, is_write)

        v_head_slice = self._get_head_slice(
            tp_index, config.num_linear_v_heads, self.global_linear_v_heads, config.tp_world_size, is_write=is_write
        )
        if v_head_slice is not None:
            self._copy_head_channels(local_v, global_v, v_head_slice, head_v_dim, is_write)
        return

    @staticmethod
    def _copy_head_channels(
        local_block: torch.Tensor,
        global_block: torch.Tensor,
        head_slice,
        head_dim: int,
        is_write: bool,
    ):
        """Copy the channels of the heads named by ``head_slice`` between two blocks."""
        local_head_start, local_head_end, global_head_start, global_head_end = head_slice
        local_part = local_block[:, local_head_start * head_dim : local_head_end * head_dim, :]
        global_part = global_block[:, global_head_start * head_dim : global_head_end * head_dim, :]
        if is_write:
            global_part.copy_(local_part, non_blocking=True)
        else:
            local_part.copy_(global_part, non_blocking=True)
        return

    def _transfer_ssm(
        self,
        ssm_state: torch.Tensor,
        ssm_page: torch.Tensor,
        mem: "Qwen3NextMemManager",
        tp_index: int,
        is_write: bool,
    ):
        """Move this rank's v heads between its ssm state and the page (direction per ``is_write``)."""
        head_slice = self._get_head_slice(
            tp_index,
            mem.linear_config.num_linear_v_heads,
            self.global_linear_v_heads,
            mem.linear_config.tp_world_size,
            is_write=is_write,
        )
        if head_slice is None:
            return
        local_head_start, local_head_end, global_head_start, global_head_end = head_slice
        local_part = ssm_state[:, local_head_start:local_head_end, :, :]
        global_part = ssm_page[:, global_head_start:global_head_end, :, :]
        if is_write:
            global_part.copy_(local_part, non_blocking=True)
        else:
            local_part.copy_(global_part, non_blocking=True)
        return

    def _get_head_slice(
        self,
        tp_index: int,
        local_heads: int,
        global_heads: int,
        tp_world_size: int,
        is_write: bool,
    ):
        """Map rank ``tp_index``'s local heads onto the global head axis.

        Returns ``(local_start, local_end, global_start, global_end)`` or ``None``
        when the rank has nothing to copy. When the ranks together hold more heads
        than exist globally (``repeat_count > 1``), only the first rank of each group
        of ``repeat_count`` writes, while every rank of the group reads the same slice.
        """
        if local_heads == 0 or global_heads == 0:
            return None
        total_local_heads = local_heads * tp_world_size
        repeat_count = max(1, total_local_heads // global_heads)
        if is_write and repeat_count > 1 and tp_index % repeat_count != 0:
            return None
        unique_tp_index = tp_index // repeat_count
        global_head_start = unique_tp_index * local_heads
        global_head_end = min(global_head_start + local_heads, global_heads)
        local_head_start = 0
        local_head_end = global_head_end - global_head_start
        if local_head_end <= local_head_start:
            return None
        return local_head_start, local_head_end, global_head_start, global_head_end
k_stride_layer_num - + mem_index * k_stride_size - + kv_head_id * k_stride_head - + off_dim * k_stride_dim, - mask=mask, - ) - v_tensor = tl.load( - v_ptr - + layer_index * v_stride_layer_num - + mem_index * v_stride_size - + kv_head_id * v_stride_head - + off_dim * v_stride_dim, - mask=mask, - ) - tl.store( - k_page_ptr - + tid * k_page_stride_size - + layer_index * k_page_stride_layer_num - + page_head_id * k_page_stride_head - + off_dim * k_page_stride_dim, - k_tensor, - mask=mask, - ) - tl.store( - v_page_ptr - + tid * v_page_stride_size - + layer_index * v_page_stride_layer_num - + page_head_id * v_page_stride_head - + off_dim * v_page_stride_dim, - v_tensor, - mask=mask, - ) - else: - k_page_tensor = tl.load( - k_page_ptr - + tid * k_page_stride_size - + layer_index * k_page_stride_layer_num - + page_head_id * k_page_stride_head - + off_dim * k_page_stride_dim, - mask=mask, - ) - v_page_tensor = tl.load( - v_page_ptr - + tid * v_page_stride_size - + layer_index * v_page_stride_layer_num - + page_head_id * v_page_stride_head - + off_dim * v_page_stride_dim, - mask=mask, - ) - tl.store( - k_ptr - + layer_index * k_stride_layer_num - + mem_index * k_stride_size - + kv_head_id * k_stride_head - + off_dim * k_stride_dim, - k_page_tensor, - mask=mask, - ) - tl.store( - v_ptr - + layer_index * v_stride_layer_num - + mem_index * v_stride_size - + kv_head_id * v_stride_head - + off_dim * v_stride_dim, - v_page_tensor, - mask=mask, - ) + page_head_id = page_head_start + kv_head_id + mem_index = tl.load(mem_index_ptr + tid) + off_dim = tl.arange(0, HEAD_DIM_BLOCK) + if NEED_MASK: + mask = off_dim < head_dim + else: + mask = None + + for layer_index in tl.range(layer_num, num_stages=3): + if IS_WRITE: + k_tensor = tl.load( + k_ptr + + layer_index * k_stride_layer_num + + mem_index * k_stride_size + + kv_head_id * k_stride_head + + off_dim, + mask=mask, + ) + v_tensor = tl.load( + v_ptr + + layer_index * v_stride_layer_num + + mem_index * v_stride_size + + kv_head_id * 
v_stride_head + + off_dim, + mask=mask, + ) + tl.store( + k_page_ptr + + tid * k_page_stride_size + + layer_index * k_page_stride_layer_num + + page_head_id * k_page_stride_head + + off_dim, + k_tensor, + mask=mask, + ) + tl.store( + v_page_ptr + + tid * v_page_stride_size + + layer_index * v_page_stride_layer_num + + page_head_id * v_page_stride_head + + off_dim, + v_tensor, + mask=mask, + ) + else: + k_page_tensor = tl.load( + k_page_ptr + + tid * k_page_stride_size + + layer_index * k_page_stride_layer_num + + page_head_id * k_page_stride_head + + off_dim, + mask=mask, + ) + v_page_tensor = tl.load( + v_page_ptr + + tid * v_page_stride_size + + layer_index * v_page_stride_layer_num + + page_head_id * v_page_stride_head + + off_dim, + mask=mask, + ) + tl.store( + k_ptr + + layer_index * k_stride_layer_num + + mem_index * k_stride_size + + kv_head_id * k_stride_head + + off_dim, + k_page_tensor, + mask=mask, + ) + tl.store( + v_ptr + + layer_index * v_stride_layer_num + + mem_index * v_stride_size + + kv_head_id * v_stride_head + + off_dim, + v_page_tensor, + mask=mask, + ) return @@ -169,10 +174,17 @@ def page_io( page_head_start = tp_index * (page_write_head_num) token_num = len(mem_indexes) - grid = (token_num, page_write_head_num) + grid = (128,) + + assert k_page_tensor.stride(3) == 1 + assert v_page_tensor.stride(3) == 1 + assert k_buffer.stride(3) == 1 + assert v_buffer.stride(3) == 1 _page_io[grid]( mem_index_ptr=mem_indexes, + token_num=token_num, + page_write_head_num=page_write_head_num, k_page_ptr=k_page_tensor, k_page_stride_size=k_page_tensor.stride(0), k_page_stride_layer_num=k_page_tensor.stride(1), @@ -207,6 +219,7 @@ def page_io( @triton.jit def _mla_page_io( mem_index_ptr, + token_num, page_ptr, page_stride_size, page_stride_layer_num, @@ -227,52 +240,54 @@ def _mla_page_io( kv_stride_layer_num = tl.cast(kv_stride_layer_num, dtype=tl.int64) kv_stride_size = tl.cast(kv_stride_size, dtype=tl.int64) - tid = tl.program_id(0) + start_index = 
tl.program_id(0) + grid_num = tl.num_programs(0) - mem_index = tl.load(mem_index_ptr + tid) - off_dim = tl.arange(0, HEAD_DIM_BLOCK) - if NEED_MASK: - mask = off_dim < head_dim - else: - mask = None - - for layer_index in tl.range(layer_num, num_stages=3): - if IS_WRITE: - kv_tensor = tl.load( - kv_ptr - + layer_index * kv_stride_layer_num - + mem_index * kv_stride_size - + 0 * kv_stride_head - + off_dim * kv_stride_dim, - mask=mask, - ) - tl.store( - page_ptr - + tid * page_stride_size - + layer_index * page_stride_layer_num - + 0 * page_stride_head - + off_dim * page_stride_dim, - kv_tensor, - mask=mask, - ) + for tid in tl.range(start_index, token_num, step=grid_num): + mem_index = tl.load(mem_index_ptr + tid) + off_dim = tl.arange(0, HEAD_DIM_BLOCK) + if NEED_MASK: + mask = off_dim < head_dim else: - page_tensor = tl.load( - page_ptr - + tid * page_stride_size - + layer_index * page_stride_layer_num - + 0 * page_stride_head - + off_dim * page_stride_dim, - mask=mask, - ) - tl.store( - kv_ptr - + layer_index * kv_stride_layer_num - + mem_index * kv_stride_size - + 0 * kv_stride_head - + off_dim * kv_stride_dim, - page_tensor, - mask=mask, - ) + mask = None + + for layer_index in tl.range(layer_num, num_stages=3): + if IS_WRITE: + kv_tensor = tl.load( + kv_ptr + + layer_index * kv_stride_layer_num + + mem_index * kv_stride_size + + 0 * kv_stride_head + + off_dim * kv_stride_dim, + mask=mask, + ) + tl.store( + page_ptr + + tid * page_stride_size + + layer_index * page_stride_layer_num + + 0 * page_stride_head + + off_dim * page_stride_dim, + kv_tensor, + mask=mask, + ) + else: + page_tensor = tl.load( + page_ptr + + tid * page_stride_size + + layer_index * page_stride_layer_num + + 0 * page_stride_head + + off_dim * page_stride_dim, + mask=mask, + ) + tl.store( + kv_ptr + + layer_index * kv_stride_layer_num + + mem_index * kv_stride_size + + 0 * kv_stride_head + + off_dim * kv_stride_dim, + page_tensor, + mask=mask, + ) return @@ -290,10 +305,11 @@ def 
mla_page_io(mem_indexes: torch.Tensor, page_tensor: torch.Tensor, kv_buffer: assert page_head_num == kv_head_num == 1 token_num = len(mem_indexes) - grid = (token_num,) + grid = (64,) _mla_page_io[grid]( mem_index_ptr=mem_indexes, + token_num=token_num, page_ptr=page_tensor, page_stride_size=page_tensor.stride(0), page_stride_layer_num=page_tensor.stride(1), diff --git a/lightllm/common/linear_att_cache_manager/__init__.py b/lightllm/common/linear_att_cache_manager/__init__.py new file mode 100644 index 0000000000..ab3c8e2cd9 --- /dev/null +++ b/lightllm/common/linear_att_cache_manager/__init__.py @@ -0,0 +1,3 @@ +from .linear_att_buffer_manager import LinearAttCacheManager +from .config_objs import LinearAttCacheConfig +from .layer_cache import LayerCache diff --git a/lightllm/common/linear_att_cache_manager/config_objs.py b/lightllm/common/linear_att_cache_manager/config_objs.py new file mode 100644 index 0000000000..bc39067069 --- /dev/null +++ b/lightllm/common/linear_att_cache_manager/config_objs.py @@ -0,0 +1,116 @@ +import torch +import dataclasses +import triton +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.log_utils import init_logger +from lightllm.utils.torch_dtype_utils import get_torch_dtype + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class LinearAttCacheConfig: + tp_world_size: int + # full att 的参数 + full_att_all_num_kv_heads: int + full_att_dtype: torch.dtype + full_att_num_kv_heads: int # 这个是 tp 后的head头数量 + full_att_head_dim: int + + # linear att 的参数 + global_linear_k_heads: int + global_linear_v_heads: int + num_linear_k_heads: int + num_linear_v_heads: int + head_linear_k_dim: int + head_linear_v_dim: int + conv_kernel_size: int + linear_layer_num: int + conv_state_dtype: torch.dtype + ssm_state_dtype: torch.dtype + full_attention_interval: int + all_layer_num: int # 包括 linear att 和 full att 的层加起来的层数 + + def get_conv_dim(self): + # 第一项对应q的参数,第二项对应k的参数,第三项对应v的参数 + # 由于 k_dim = q_dim, k_heads = 
q_heads, 所以第一项和第二项的计算 + # 形式相同,但是实际内在含义是不同的。 + return ( + self.head_linear_k_dim * self.num_linear_k_heads + + self.head_linear_k_dim * self.num_linear_k_heads + + self.head_linear_v_dim * self.num_linear_v_heads + ) + + def get_conv_state_shape(self): + return (self.get_conv_dim(), self.conv_kernel_size - 1) + + def get_ssm_state_shape(self): + return (self.num_linear_v_heads, self.head_linear_k_dim, self.head_linear_v_dim) + + def get_conv_state_bytes_per_layer(self): + return self.get_conv_dim() * (self.conv_kernel_size - 1) * self.conv_state_dtype.itemsize + + def get_ssm_state_bytes_per_layer(self): + return self.num_linear_v_heads * self.head_linear_k_dim * self.head_linear_v_dim * self.ssm_state_dtype.itemsize + + def get_cpu_cache_big_page_bytes(self): + a = self.get_cpu_cache_full_att_bytes() + b = self.get_cpu_cache_conv_bytes() + c = self.get_cpu_cache_ssm_bytes() + + return triton.cdiv(a + b + c, 16) * 16 + + def get_cpu_cache_full_att_bytes(self): + big_page_token_num = ( + get_env_start_args().linear_att_page_block_num * get_env_start_args().linear_att_hash_page_size + ) + assert big_page_token_num == get_env_start_args().cpu_cache_token_page_size + full_att_bytes = 2 * self.full_att_all_num_kv_heads * self.full_att_head_dim * self.full_att_dtype.itemsize + a = full_att_bytes * (self.all_layer_num - self.linear_layer_num) * big_page_token_num + return a + + def get_cpu_cache_conv_bytes(self): + b = self.get_conv_state_bytes_per_layer() * self.linear_layer_num * self.tp_world_size + return b + + def get_cpu_cache_ssm_bytes(self): + c = self.get_ssm_state_bytes_per_layer() * self.linear_layer_num * self.tp_world_size + return c + + @staticmethod + def load_from_args() -> "LinearAttCacheConfig": + args = get_env_start_args() + model_path = args.model_dir + from transformers.configuration_utils import PretrainedConfig + + model_cfg, _ = PretrainedConfig.get_config_dict(model_path) + model_type = model_cfg["model_type"] + assert model_type in ["qwen3_5", 
"qwen3_5_moe", "qwen3_5_text", "qwen3_5_moe_text"] + llm_config = model_cfg + try: + llm_config = llm_config["text_config"] + except: + pass + + n_layer = llm_config["num_hidden_layers"] + + tp_world_size = get_env_start_args().tp // get_env_start_args().dp + return LinearAttCacheConfig( + tp_world_size=tp_world_size, + full_att_all_num_kv_heads=llm_config["num_key_value_heads"], + full_att_dtype=get_torch_dtype(args.data_type), + full_att_num_kv_heads=max(1, llm_config["num_key_value_heads"] // tp_world_size), + full_att_head_dim=llm_config["head_dim"], + global_linear_k_heads=llm_config["linear_num_key_heads"], + global_linear_v_heads=llm_config["linear_num_value_heads"], + num_linear_k_heads=max(1, llm_config["linear_num_key_heads"] // tp_world_size), + num_linear_v_heads=max(1, llm_config["linear_num_value_heads"] // tp_world_size), + head_linear_k_dim=llm_config["linear_key_head_dim"], + head_linear_v_dim=llm_config["linear_value_head_dim"], + conv_kernel_size=llm_config["linear_conv_kernel_dim"], + linear_layer_num=n_layer - (n_layer // llm_config["full_attention_interval"]), + conv_state_dtype=get_torch_dtype(args.data_type), + ssm_state_dtype=get_torch_dtype(args.linear_att_ssm_data_type), + full_attention_interval=llm_config["full_attention_interval"], + all_layer_num=n_layer, + ) diff --git a/lightllm/common/linear_att_cache_manager/layer_cache.py b/lightllm/common/linear_att_cache_manager/layer_cache.py new file mode 100644 index 0000000000..9ac72acca3 --- /dev/null +++ b/lightllm/common/linear_att_cache_manager/layer_cache.py @@ -0,0 +1,37 @@ +from typing import Tuple +import torch +import numpy as np +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class LayerCache: + def __init__( + self, + size: int, + dtype: torch.dtype, + shape: Tuple[int, ...], + layer_num: int, + device: torch.device, + size_first: bool = False, + ): + self.size_first = size_first + self.size = size + self.dtype = dtype + self.shape = shape 
+ self.layer_num = layer_num + self.device = device + if not self.size_first: + if device == "cpu": + self.buffer = torch.zeros((self.layer_num, size, *shape), dtype=dtype, device="cpu", pin_memory=True) + else: + self.buffer = torch.zeros((self.layer_num, size, *shape), dtype=dtype, device=device) + else: + if device == "cpu": + self.buffer = torch.zeros((size, self.layer_num, *shape), dtype=dtype, device="cpu", pin_memory=True) + else: + self.buffer = torch.zeros((size, self.layer_num, *shape), dtype=dtype, device=device) + + def get_cell_size(self): + return np.prod(self.shape) * self.layer_num * torch._utils._element_size(self.dtype) diff --git a/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py b/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py new file mode 100644 index 0000000000..30dc4d937c --- /dev/null +++ b/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py @@ -0,0 +1,83 @@ +import torch +import collections +from lightllm.utils.log_utils import init_logger +from .layer_cache import LayerCache +from typing import List, Optional, Tuple, Union +from .config_objs import LinearAttCacheConfig + +logger = init_logger(__name__) + + +class LinearAttCacheManager: + def __init__( + self, + size: int, + linear_config: LinearAttCacheConfig, + keep_num: int = 0, # 用于记录需要保留的缓存数量,用于支持含有 linear_att 的如qwen3.5 模型的cpu cache的碎页处理。 + ): + # init the mem state + self.size = size + self.linear_config = linear_config + self.keep_num = keep_num + assert 0 <= self.keep_num <= self.size, f"invalid keep_num {self.keep_num} for size {self.size}" + # init the layer cache + self.conv_state_cache = LayerCache( + size=self.size, + dtype=self.linear_config.conv_state_dtype, + shape=self.linear_config.get_conv_state_shape(), + layer_num=self.linear_config.linear_layer_num, + device="cpu", + size_first=True, + ) + self.ssm_state_cache = LayerCache( + size=self.size, + dtype=self.linear_config.ssm_state_dtype, + 
shape=self.linear_config.get_ssm_state_shape(), + layer_num=self.linear_config.linear_layer_num, + device="cpu", + size_first=True, + ) + self.clear_to_init_state() + return + + def get_state_cache(self, buffer_idx: int): + return self.conv_state_cache.buffer[buffer_idx, ...], self.ssm_state_cache.buffer[buffer_idx, ...] + + def alloc_one_state_cache(self) -> Optional[int]: + if len(self.free_list) == 0: + return None + + alloc_index = self.free_list.popleft() + return alloc_index + + def alloc_state_cache(self, need_size: int) -> Optional[List[int]]: + if need_size > len(self.free_list): + logger.error(f"warn no enough cache need_size {need_size} free_size {len(self.free_list)}") + return None + + alloc_indexes = [self.free_list.popleft() for _ in range(need_size)] + return alloc_indexes + + def free_state_cache(self, free_indexes: List[int]): + alloc_upper_bound = self.size - self.keep_num + for idx in free_indexes: + assert 0 <= idx < alloc_upper_bound, ( + f"free index {idx} out of alloc range [0, {alloc_upper_bound}), " f"reserved tail num {self.keep_num}" + ) + self.free_list.extend(free_indexes) + assert ( + len(self.free_list) <= alloc_upper_bound + ), f"free cache num {len(self.free_list)} should not be larger than alloc size {alloc_upper_bound}" + return + + def get_free_cache_num(self): + return len(self.free_list) + + def get_used_cache_num(self): + return self.size - len(self.free_list) + + def clear_to_init_state(self): + self.conv_state_cache.buffer.zero_() + self.ssm_state_cache.buffer.zero_() + self.free_list = collections.deque(range(self.size - self.keep_num)) + return diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py deleted file mode 100644 index 8602f2e67e..0000000000 --- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py +++ /dev/null @@ -1,283 +0,0 @@ -from typing import List, Tuple, Union - -import torch -import numpy as np - -from 
lightllm.utils.dist_utils import get_current_rank_in_node -from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args -from lightllm.common.basemodel.triton_kernel.mamba_buffer_copy import copy_mamba_buffer, fork_mamba_buffer -from lightllm.utils.log_utils import init_logger -from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt - -logger = init_logger(__name__) - -MAMBA_CACHE_CAN_USE_NUM_SHM_NAME = f"{get_unique_server_name()}_mamba_cache_can_use_num" - - -class LayerCache: - def __init__(self, size: int, dtype: torch.dtype, shape: Tuple[int, ...], layer_num: int): - self.size = size - self.dtype = dtype - self.shape = shape - self.layer_num = layer_num - self.buffer = torch.zeros((self.layer_num, size + 1, *shape), dtype=dtype, device="cuda") - - def get_cell_size(self): - return np.prod(self.shape) * self.layer_num * torch._utils._element_size(self.dtype) - - -class MambaCacheManager: - def __init__( - self, - size: int, - layer_num: int, - conv_state_dtype: torch.dtype, - ssm_state_dtype: torch.dtype, - conv_kernel_size: int, - num_linear_k_heads: int, - num_linear_v_heads: int, - head_linear_k_dim: int, - head_linear_v_dim: int, - ): - # init the mem state - self.size = size - self.num_linear_k_heads = num_linear_k_heads - self.num_linear_v_heads = num_linear_v_heads - self.head_linear_k_dim = head_linear_k_dim - self.head_linear_v_dim = head_linear_v_dim - self.conv_dim = ( - self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads - ) - self.layer_num = layer_num - self.conv_kernel_size = conv_kernel_size - conv_state_shape = (self.conv_dim, conv_kernel_size - 1) - ssm_state_shape = ( - self.num_linear_v_heads, - self.head_linear_k_dim, - self.head_linear_v_dim, - ) - self.ssm_state_dtype = ssm_state_dtype - self.conv_state_dtype = conv_state_dtype - self.profile_size() - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, 
pin_memory=True - ) - self._mem_state_return = torch.arange( - 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._return_start = 0 - self.mark_start = 0 - self.mark_end = self.size - self.can_use_mem_size = self.size - self.shared_can_use_token_num = SharedInt(f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}") - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - # init the layer cache - self.conv_state_cache = LayerCache(self.size, conv_state_dtype, conv_state_shape, layer_num) - self.ssm_state_cache = LayerCache(self.size, ssm_state_dtype, ssm_state_shape, layer_num) - self.HOLD_BUFFER_INDEX = self.size - - def get_mamba_cache(self, layer_idx: int): - conv_state = self.conv_state_cache.buffer[layer_idx] - ssm_state = self.ssm_state_cache.buffer[layer_idx] - return conv_state, ssm_state - - def copy_state_buffers(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor): - copy_mamba_buffer( - self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes - ) - copy_mamba_buffer( - self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes - ) - - def fork_state_buffers(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): - fork_mamba_buffer( - self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_index, dst_buffer_indexes - ) - fork_mamba_buffer( - self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes - ) - - def fork_ssm_buffers(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): - """ - Fork ONLY SSM states (not conv states) from source indices to destination indices. - - This is used for MTP mode where each buffer maintains its own independent conv state, - but SSM states need to be synchronized. 
- """ - fork_mamba_buffer( - self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes - ) - - def alloc(self, need_size) -> torch.Tensor: - if need_size > self.mark_end - self.mark_start: - logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") - assert False, "error alloc state" - - start = self.mark_start - end = self.mark_start + need_size - self.mark_start += need_size - - self.can_use_mem_size -= need_size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - # 利用缓冲区返回,避免异步情况下的内存竞争 - if self._return_start + need_size > self._mem_state_return.shape[0]: - self._return_start = 0 - ans = self._mem_state_return[self._return_start : self._return_start + need_size] - ans.copy_(self.mem_state[start:end]) - self._return_start += need_size - return ans - - def free(self, free_index: Union[torch.Tensor, List[int]]): - """ - Free the allocated cache buffers and clear them. - - Args: - free_index: Buffer indices to free (tensor or list of ints) - """ - # Convert to tensor if needed for indexing - if isinstance(free_index, list): - free_index_tensor = torch.tensor(free_index, dtype=torch.long, device="cuda") - else: - free_index_tensor = free_index.to(device="cuda", dtype=torch.long) - - # Clear the buffers for the freed indices - # Shape: [layer_num, buffer_index, *shape] - self.conv_state_cache.buffer[:, free_index_tensor, ...] = 0 - self.ssm_state_cache.buffer[:, free_index_tensor, ...] 
= 0 - - # update the mem state - end = self.mark_start - start = self.mark_start - len(free_index) - assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" - - if isinstance(free_index, list): - free_index_tensor = torch.tensor(free_index, dtype=self.mem_state.dtype, device=self.mem_state.device) - self.mem_state[start:end] = free_index_tensor - else: - # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 - self.mem_state[start:end] = free_index - - self.mark_start -= len(free_index) - - self.can_use_mem_size += len(free_index) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - if self.can_use_mem_size == len(self.mem_state): - logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") - - return - - def free_all(self): - self.conv_state_cache.buffer.fill_(0) - self.ssm_state_cache.buffer.fill_(0) - self.can_use_mem_size = len(self.mem_state) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) - self.mark_start = 0 - self.mark_end = len(self.mem_state) - - return - - def resize_mem(self, new_size): - """ - just for test code - """ - self.size = new_size - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self.mark_start = 0 - self.mark_end = self.size - self.can_use_mem_size = self.size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - return - - def profile_size( - self, - ): - start_args = get_env_start_args() - if self.size is not None and not start_args.disable_dynamic_prompt_cache: - assert self.size < start_args.running_max_req_size * 2, ( - f"error mamba_cache_size({self.size}), ", - f"mamba_cache_size should be at least running_max_req_size * 2", - f"({start_args.running_max_req_size * 2}), ", - f"you can add `--disable_dynamic_prompt_cache` to avoid this error.", - ) - return - from lightllm.utils.profile_max_tokens import get_available_gpu_memory, 
get_total_gpu_memory - import torch.distributed as dist - - mem_fraction = start_args.mem_fraction - world_size = dist.get_world_size() - total_memory = get_total_gpu_memory() - available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction) - conv_cell_size = ( - self.layer_num - * self.conv_dim - * (self.conv_kernel_size - 1) - * torch._utils._element_size(self.conv_state_dtype) - ) - ssm_cell_size = ( - self.layer_num - * (self.num_linear_v_heads) - * self.head_linear_k_dim - * self.head_linear_v_dim - * torch._utils._element_size(self.ssm_state_dtype) - ) - total_cell_size = conv_cell_size + ssm_cell_size - mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5 - mamba_memory_gb = available_memory * mamba_cache_ratio - mamba_cache_size = int(mamba_memory_gb * 1024 ** 3 / total_cell_size) - - if mamba_cache_size < start_args.running_max_req_size * 2: - ratio = mamba_cache_ratio if mamba_cache_ratio is not None else 0.5 - raise ValueError( - f"Insufficient memory for mamba cache allocation!\n\n" - f"mamba_cache_size should be at least running_max_req_size * 2\n" - f"Calculated mamba_cache_size ({mamba_cache_size}) < " - f"running_max_req_size * 2 ({start_args.running_max_req_size * 2})\n\n" - f"Memory budget:\n" - f" Available for mamba cache: {mamba_memory_gb:.2f} GB\n" - f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n" - f" Calculated buffers: {mamba_cache_size}\n" - f" Required buffers: {start_args.running_max_req_size}\n\n" - f"Solutions:\n" - f" 1. Reduce --running_max_req_size to {mamba_cache_size} or lower\n" - f" 2. Increase --mamba_cache_ratio from {ratio} to " - f"{start_args.running_max_req_size / mamba_cache_size * ratio:.3f} or higher\n" - f" 3. 
Increase --mem_fraction to leave more memory for caches\n" - ) - - logger.info( - f"Mamba cache allocation:\n" - f" Available memory: {mamba_memory_gb:.2f} GB\n" - f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n" - f" Calculated mamba_cache_size: {mamba_cache_size}" - ) - self.size = mamba_cache_size - return - - -class ReadOnlyStaticsMambaCacheManager: - """ - 读取一些统计信息 - """ - - def __init__(self) -> None: - args = get_env_start_args() - self.global_world_size = args.tp - self.node_world_size = args.tp // args.nnodes - self.dp_world_size = self.global_world_size // args.dp - # 兼容多机 dp size=1 纯 tp 模式的情况 - self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 - self.shared_tp_can_use_token_nums = [ - SharedInt(f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{rank_in_node}") - for rank_in_node in range(0, self.node_world_size, self.dp_world_size) - ] - - def get_unrefed_token_num(self, dp_rank_in_node: int): - if self.is_multinode_tp: - return self.shared_tp_can_use_token_nums[0].get_value() - return self.shared_tp_can_use_token_nums[dp_rank_in_node].get_value() diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 1f08432c6a..cd534d53ec 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -7,18 +7,36 @@ from .awq import * from .no_quant import * from lightllm.utils.log_utils import init_logger +from lightllm.utils.device_utils import is_sm100_gpu logger = init_logger(__name__) +EXPERT_DTYPE_TO_QUANT_TYPE = { + "fp8": "deepgemm-fp8w8a8-b128", + "fp4": "deepgemm-fp4fp8-b32", +} +SUPPORTED_EXPERT_DTYPES = tuple(EXPERT_DTYPE_TO_QUANT_TYPE) + class Quantcfg: - def __init__(self, network_config, quant_type="none", custom_cfg_path=None): + def __init__(self, network_config, quant_type="none", custom_cfg_path=None, expert_dtype=None): self.layer_num = network_config["n_layer"] self.quant_type = quant_type + self.expert_dtype = expert_dtype self.network_config_ = 
network_config self._parse_custom_cfg(custom_cfg_path) self._parse_network_config(network_config) + def _get_expert_quant_type(self, expert_dtype): + quant_type = EXPERT_DTYPE_TO_QUANT_TYPE.get(expert_dtype) + if quant_type is None: + raise ValueError( + f"unsupported expert_dtype `{expert_dtype}`; expected one of {list(SUPPORTED_EXPERT_DTYPES)}" + ) + if expert_dtype == "fp4" and not is_sm100_gpu(): + raise RuntimeError("expert_dtype `fp4` requires an SM100 GPU; please use `fp8` on non-SM100 GPUs.") + return quant_type + def _parse_network_config(self, network_config): hf_quantization_config = network_config.get("quantization_config", None) if hf_quantization_config is None: @@ -44,6 +62,19 @@ def _mapping_quant_method(self): else: self.quant_type = "vllm-fp8w8a8-b128" logger.info(f"select fp8w8a8-b128 quant way: {self.quant_type}") + + # fp8 量化下,部分 MoE 模型(如 DeepSeek-V4),可以单独声明 expert 权重精度, + # 按其值给 fused_moe 选用对应的 deepgemm 量化方法。 + expert_dtype = self.expert_dtype or self.network_config_.get("expert_dtype", None) + if expert_dtype is None: + return + target = self._get_expert_quant_type(expert_dtype) + for layer_num in range(self.layer_num): + if self.expert_dtype is not None: + self.quant_cfg[layer_num]["fused_moe"] = target + else: + self.quant_cfg[layer_num].setdefault("fused_moe", target) + logger.info(f"select fused_moe quant way from expert_dtype=`{expert_dtype}`: {target}") elif self.hf_quantization_method == "awq": self.quant_type = "awq" if is_awq_marlin_compatible(self.hf_quantization_config): diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index 137455a821..ec1ee90fd4 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -126,6 +126,78 @@ def _create_weight( return mm_param, mm_param_list +@QUANTMETHODS.register(["deepgemm-fp4fp8-b32"], platform="cuda") +class DeepGEMMFP8FP4B32QuantizationMethod(DeepGEMMBaseQuantizationMethod): + def __init__(self): + 
super().__init__() + self.block_size = 32 + self.weight_suffix = "weight" + self.weight_zero_point_suffix = None + self.weight_scale_suffix = None + self.has_weight_scale = True + self.has_weight_zero_point = False + + @property + def method_name(self): + return "deepgemm-fp4fp8-b32" + + def quantize(self, weight: torch.Tensor, output: WeightPack): + from deep_gemm.utils import per_token_cast_to_fp4 + import deep_gemm + + weight = weight.cuda(output.weight.device) + if weight.dim() == 2: + n, k = weight.shape + packed_weight, weight_scale = per_token_cast_to_fp4(weight, use_ue8m0=True, gran_k=self.block_size) + weight_scale = deep_gemm.transform_sf_into_required_layout(weight_scale, n, k, (1, self.block_size), None) + else: + num_groups, n, k = weight.shape + packed_weight = torch.empty((num_groups, n, k // 2), device=weight.device, dtype=torch.int8) + weight_scale = torch.empty((num_groups, n, k // self.block_size), device=weight.device, dtype=torch.float32) + for i in range(num_groups): + packed_weight[i], weight_scale[i] = per_token_cast_to_fp4( + weight[i], use_ue8m0=True, gran_k=self.block_size + ) + weight_scale = deep_gemm.transform_sf_into_required_layout( + weight_scale, n, k, (1, self.block_size), num_groups + ) + output.weight.copy_(packed_weight) + output.weight_scale.copy_(weight_scale) + return + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "WeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError("deepgemm-fp4fp8-b32 is only implemented for fused MoE expert weights") + + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims + assert in_dim % 2 == 0, "FP4 packed weight requires 
even input dimension" + assert in_dim % self.block_size == 0, "FP4 scale dimension must be divisible by block_size" + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim // 2), dtype=torch.int8).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.int32).cuda( + device_id + ) + mm_param = WeightPack(weight=weight, weight_scale=weight_scale) + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + weight_scale_out_dims=out_dims, + weight_scale_split_dim=-2, + ) + return mm_param, mm_param_list + + def _deepgemm_fp8_nt(a_tuple, b_tuple, out): if HAS_DEEPGEMM: if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"): diff --git a/lightllm/common/quantization/w8a8.py b/lightllm/common/quantization/w8a8.py index b3d29b0527..65ec6cd145 100644 --- a/lightllm/common/quantization/w8a8.py +++ b/lightllm/common/quantization/w8a8.py @@ -8,19 +8,12 @@ from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8 from lightllm.common.basemodel.triton_kernel.quantization.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm -from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops from .quantize_method import WeightPack -if HAS_LIGHTLLM_KERNEL: - - def scaled_fp8_quant(tensor, *args, **kwargs): - return light_ops.per_token_quant_bf16_fp8(tensor) - -else: - if HAS_VLLM: - scaled_fp8_quant = vllm_ops.scaled_fp8_quant +if HAS_VLLM: + scaled_fp8_quant = vllm_ops.scaled_fp8_quant LIGHTLLM_USE_TRITON_FP8_SCALED_MM = os.getenv("LIGHTLLM_USE_TRITON_FP8_SCALED_MM", "False").upper() in [ "ON", diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 169e6ac2d8..01e9c4ad35 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,14 +1,20 @@ 
import torch import collections +from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig + from lightllm.utils.log_utils import init_logger from .kv_cache_mem_manager import MemoryManager -from typing import List, Optional - +from typing import List, Optional, TYPE_CHECKING from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args from lightllm.utils.config_utils import get_vocab_size from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager +from lightllm.common.linear_att_cache_manager.layer_cache import LayerCache +from lightllm.common.linear_att_cache_manager.linear_att_buffer_manager import LinearAttCacheManager + +if TYPE_CHECKING: + from lightllm.server.router.model_infer.infer_batch import InferReq logger = init_logger(__name__) @@ -128,11 +134,7 @@ def __init__(self, max_request_num): (max_request_num + 1, self.vocab_size), dtype=torch.int32, device="cpu", pin_memory=True ) - def init_req_sampling_params(self, req): - # fix cycle loop import - from lightllm.server.router.model_infer.infer_batch import InferReq - - req: InferReq = req + def init_req_sampling_params(self, req: "InferReq"): shm_param = req.sampling_param.shm_param self.req_to_next_token_ids[req.req_idx][0:1].fill_(req.get_last_gen_token()) @@ -186,12 +188,8 @@ def update_reqs_out_token_counter_gpu( return def update_reqs_token_counter( - self, req_objs: List, next_token_ids: List[int], accept_mark: Optional[List[List[bool]]] = None + self, req_objs: List["InferReq"], next_token_ids: List[int], accept_mark: Optional[List[List[bool]]] = None ): - from lightllm.server.router.model_infer.infer_batch import InferReq - - req_objs: List[InferReq] = req_objs - if self.penalty_counter_mode != "cpu_counter": return @@ -200,13 +198,9 @@ def 
update_reqs_token_counter( req_obj.out_token_id_count[next_token_id] += 1 return - def gen_cpu_out_token_counter_sampling_params(self, req_objs: List): + def gen_cpu_out_token_counter_sampling_params(self, req_objs: List["InferReq"]): assert self.penalty_counter_mode == "cpu_counter" - from lightllm.server.router.model_infer.infer_batch import InferReq - - req_objs: List[InferReq] = req_objs - p_token_ids: List[int] = [] p_token_counts: List[int] = [] p_cumsum_seq_len: List[int] = [ @@ -236,25 +230,72 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List): class ReqManagerForMamba(ReqManager): - def __init__(self, max_request_num, max_sequence_length, mem_manager): - from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager - + def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_config: LinearAttCacheConfig): super().__init__(max_request_num, max_sequence_length, mem_manager) self.mtp_step = get_env_start_args().mtp_step - self.buffer_mem_manager: MambaCacheManager = self.mem_manager.mamba_cache_mem_manager - self.req_to_buffer_index = torch.zeros( - (self.max_request_num + 1, self.mtp_step + 1), dtype=torch.int32, device="cuda" + self.big_page_token_num = ( + get_env_start_args().linear_att_page_block_num * get_env_start_args().linear_att_hash_page_size + ) + assert ( + self.mtp_step == 0 + ), "currently only support mtp_step 0 for simplicity, more mtp_step support will be added in the future" + self.linear_config = linear_config + + self.req_to_conv_state = LayerCache( + size=(max_request_num + 1) * (self.mtp_step + 1), + dtype=self.linear_config.conv_state_dtype, + shape=self.linear_config.get_conv_state_shape(), + layer_num=self.linear_config.linear_layer_num, + device="cuda", + ) + self.req_to_ssm_state = LayerCache( + size=(max_request_num + 1) * (self.mtp_step + 1), + dtype=self.linear_config.ssm_state_dtype, + shape=self.linear_config.get_ssm_state_shape(), + 
layer_num=self.linear_config.linear_layer_num, + device="cuda", ) - self.req_to_buffer_index[self.HOLD_REQUEST_ID, :] = self.buffer_mem_manager.HOLD_BUFFER_INDEX + return + + def init_linear_att_state(self, req: "InferReq"): + index = req.req_idx * (self.mtp_step + 1) + conv_state = self.req_to_conv_state.buffer[:, index, ...] + ssm_state = self.req_to_ssm_state.buffer[:, index, ...] + conv_state.fill_(0) + ssm_state.fill_(0) + return + + def get_mamba_cache(self, layer_idx_in_all: int): + assert ( + 0 <= layer_idx_in_all < self.linear_config.all_layer_num + ), f"invalid transformer layer index {layer_idx_in_all}" + layer_idx_in_linear = layer_idx_in_all - (layer_idx_in_all // self.linear_config.full_attention_interval) + conv_states = self.req_to_conv_state.buffer[layer_idx_in_linear] + ssm_states = self.req_to_ssm_state.buffer[layer_idx_in_linear] + return conv_states, ssm_states + + def copy_big_page_buffer_to_linear_att_state(self, big_page_buffer_idx: int, req: "InferReq"): - def free_buffer(self, free_buffer_indexes: List[int]): - self.buffer_mem_manager.free(free_buffer_indexes) + from .linear_att_cache_manager import LinearAttCacheManager + + big_page_buffers: LinearAttCacheManager = self.mem_manager.linear_att_big_page_buffers + + conv_state, ssm_state = big_page_buffers.get_state_cache(buffer_idx=big_page_buffer_idx) + dest_req_idx = req.req_idx * (self.mtp_step + 1) + + self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state + self.req_to_ssm_state.buffer[:, dest_req_idx, ...] 
= ssm_state return - def alloc_buffer_for_req(self, req_index: torch.Tensor): - num_reqs = req_index.shape[0] - num_buffers_per_req = self.mtp_step + 1 - buffer_indexes = self.buffer_mem_manager.alloc(num_reqs * num_buffers_per_req) - if not buffer_indexes.is_cuda: - buffer_indexes = buffer_indexes.cuda() - self.req_to_buffer_index[req_index] = buffer_indexes.view(num_reqs, num_buffers_per_req) + def copy_small_page_buffer_to_linear_att_state( + self, req: "InferReq", linear_att_small_page_buffers: LinearAttCacheManager + ): + conv_state, ssm_state = linear_att_small_page_buffers.get_state_cache( + buffer_idx=req.shared_kv_node.small_page_buffer_idx + ) + dest_req_idx = req.req_idx * (self.mtp_step + 1) + # TODO 下面这个从 cpu cache 拷贝数据的 gpu的操作,是否是阻塞的操作。 + # 同时,非连续对象的拷贝,可能存在效率问题。 + self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state + self.req_to_ssm_state.buffer[:, dest_req_idx, ...] = ssm_state + return diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..3835d4703f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "131072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..5a0729cfcf --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 5, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 5, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 5, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 5, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..1301ae36b0 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 1, + "num_warps": 8 + }, + "100": { + "num_stages": 4, + "num_warps": 2 + }, + "1024": { + "num_stages": 4, + "num_warps": 1 + }, + "128": { + "num_stages": 4, + "num_warps": 2 + }, + "16": { + "num_stages": 4, + "num_warps": 1 + }, + "16384": { + "num_stages": 4, + "num_warps": 1 + }, + "2048": { + "num_stages": 1, + "num_warps": 2 + }, + "256": { + "num_stages": 4, + "num_warps": 2 + }, + "32": { + "num_stages": 1, + "num_warps": 2 + }, + "4096": { + 
"num_stages": 4, + "num_warps": 4 + }, + "64": { + "num_stages": 4, + "num_warps": 1 + }, + "8": { + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json index 9f3a8dcb25..c377e8b898 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -1,4 +1,16 @@ { + "1": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 2, + "num_warps": 4 + }, "1024": { "BLOCK_M": 1, "BLOCK_N": 128, @@ -17,6 +29,12 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, "16384": { "BLOCK_M": 8, "BLOCK_N": 128, @@ -35,12 +53,24 @@ "NUM_STAGES": 1, "num_warps": 1 }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, "32768": { "BLOCK_M": 8, "BLOCK_N": 128, "NUM_STAGES": 1, "num_warps": 1 }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "512": { "BLOCK_M": 1, "BLOCK_N": 128, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2816,N=1408,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2816,N=1408,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..40e7b7917f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2816,N=1408,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + 
"num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=704,N=2816,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=704,N=2816,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..0de06a11d7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=704,N=2816,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 
64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "65536": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json index ea17f7f5ae..0cee014398 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json @@ -31,6 +31,10 @@ "BLOCK_SIZE": 256, "num_warps": 4 }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, "24": 
{ "BLOCK_SIZE": 128, "num_warps": 8 @@ -43,6 +47,10 @@ "BLOCK_SIZE": 128, "num_warps": 8 }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, "48": { "BLOCK_SIZE": 128, "num_warps": 8 @@ -59,6 +67,10 @@ "BLOCK_SIZE": 128, "num_warps": 8 }, + "8192": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, "96": { "BLOCK_SIZE": 128, "num_warps": 8 diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2816,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2816,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json new file mode 100644 index 0000000000..4213010b0f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2816,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "1024": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "8192": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + } +} \ 
No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2112,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2112,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..356559b99e --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2112,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_M": 256, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "8192": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=21504,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=21504,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 
0000000000..53b3bdb311 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=21504,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "100": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=704,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=704,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..abc156d395 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=704,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + 
"16384": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "65536": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/append_fused_shared_experts:v1/{has_shared_expert_gate=true,num_fused_shared_experts=1,topk_ids_dtype=torch.int64,topk_num=8,topk_weights_dtype=torch.float32}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/append_fused_shared_experts:v1/{has_shared_expert_gate=true,num_fused_shared_experts=1,topk_ids_dtype=torch.int64,topk_num=8,topk_weights_dtype=torch.float32}_NVIDIA_H200.json new file mode 100644 index 0000000000..f2a0176db5 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/append_fused_shared_experts:v1/{has_shared_expert_gate=true,num_fused_shared_experts=1,topk_ids_dtype=torch.int64,topk_num=8,topk_weights_dtype=torch.float32}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "100": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "1024": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "128": { + "BLOCK_TOKEN": 4, + "num_warps": 8 + }, + "16": { + 
"BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "16384": { + "BLOCK_TOKEN": 32, + "num_warps": 2 + }, + "2048": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "256": { + "BLOCK_TOKEN": 4, + "num_warps": 8 + }, + "32": { + "BLOCK_TOKEN": 4, + "num_warps": 8 + }, + "4096": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "64": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + }, + "8": { + "BLOCK_TOKEN": 16, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..5497f5e6c3 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..5497f5e6c3 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..e9918f6ad0 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "2": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..d037521c38 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2": { + "BV": 32, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..e922d888fc --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 4 + }, + "100": { + "num_warps": 4 + }, + "1024": { + "num_warps": 4 + }, + "128": { + "num_warps": 4 + }, + "16": { + "num_warps": 4 + }, + "16384": { + "num_warps": 4 + }, + "2048": { + "num_warps": 4 + }, + "256": { + "num_warps": 4 + }, + "32": { + "num_warps": 4 + }, + "4096": { + "num_warps": 4 + }, + "64": { + "num_warps": 4 + }, + "8": { + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..9459f41fa7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2": { + "BK": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..dff8ac4d00 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 4, + "num_warps": 2 + }, + "100": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "128": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "16": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "256": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "32": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "8": { + "BLK_HEADS": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..1fcfa30e97 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1024": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "128": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "131072": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "16384": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "2048": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "256": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "32768": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "512": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "64": { + "BLOCK_N": 512, + "num_warps": 4 + }, + "8": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "800": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "8192": { + "BLOCK_N": 128, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=3072,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=3072,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..978754ec99 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=3072,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1152": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "144": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "147456": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "18432": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2304": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "288": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "36864": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "576": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "72": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "9": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "900": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + 
"NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "9216": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=3072,N=256,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=3072,N=256,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..2af76b5496 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/grouped_matmul:v1/{K=3072,N=256,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + 
"num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_atomic:v1/{expert_num=257,topk_num=9}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_atomic:v1/{expert_num=257,topk_num=9}_NVIDIA_H200.json new file mode 100644 index 0000000000..271548dd65 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_atomic:v1/{expert_num=257,topk_num=9}_NVIDIA_H200.json @@ -0,0 +1,22 @@ +{ + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16384": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_small:v2/{expert_num=257,topk_num=9}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_small:v2/{expert_num=257,topk_num=9}_NVIDIA_H200.json new file mode 100644 index 0000000000..42cc5c38a5 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_align_fused_small:v2/{expert_num=257,topk_num=9}_NVIDIA_H200.json @@ -0,0 +1,37 @@ +{ + "1": { + "BLOCK_SIZE": 32, + "NUM_STAGE": 2, + "num_warps": 2 + }, + "100": { + "BLOCK_SIZE": 1024, + "NUM_STAGE": 4, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE": 256, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "16": { + "BLOCK_SIZE": 256, + "NUM_STAGE": 2, + "num_warps": 2 + }, + "32": { + "BLOCK_SIZE": 512, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE": 256, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "8": { + "BLOCK_SIZE": 128, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json new file mode 100644 index 0000000000..bd78d49c49 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 8, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 2, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + 
"num_warps": 4 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 16 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 4, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..f1882ef5dc --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 4, + "num_warps": 1 + }, + "100": { + "num_stages": 4, + "num_warps": 1 + }, + "1024": { + "num_stages": 4, + "num_warps": 1 + }, + "128": { + "num_stages": 4, + "num_warps": 1 + }, + "16": { + "num_stages": 4, + "num_warps": 1 + }, + "16384": { + "num_stages": 4, + "num_warps": 2 + }, + "2048": { + "num_stages": 3, + "num_warps": 2 + }, + "256": { + "num_stages": 4, + "num_warps": 1 + }, + "32": { + "num_stages": 4, + "num_warps": 1 + }, + "4096": { + "num_stages": 5, + "num_warps": 1 + }, + "64": { + "num_stages": 4, + "num_warps": 1 + }, + "8": { + "num_stages": 4, + 
"num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..c3cabb161a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.6.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1152": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "144": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "147456": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "18432": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2304": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "288": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "36864": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "576": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "72": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "9": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "900": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "9216": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index 52d4e61da8..f15badde25 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -27,26 +27,21 @@ from lightllm.utils.device_utils import has_nvlink from lightllm.utils.envs_utils 
import ( get_env_start_args, - get_deepep_num_max_dispatch_tokens_per_rank, + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, get_redundancy_expert_num, ) from lightllm.utils.dist_utils import ( get_global_world_size, get_dp_world_size, - get_global_rank, - get_current_rank_in_dp, create_new_group_for_current_dp, create_dp_special_inter_group, ) -from lightllm.utils.device_utils import get_device_sm_count -from lightllm.utils.sgl_utils import HAS_SGL_KERNEL -from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL -from contextlib import nullcontext, contextmanager +from lightllm.utils.device_utils import get_device_sm_count, is_sm100_gpu +from lightllm.utils.torch_dtype_utils import get_torch_dtype logger = init_logger(__name__) -from .custom_all_reduce import CustomAllreduce -from .custom_all_gather import CustomAllgather try: import deep_ep @@ -59,8 +54,8 @@ class CustomProcessGroup: def __init__(self): - self.custom_reduce = None - self.custom_gather = None + self.symm_mem_reduce = None + self.flashinfer_reduce = None self.dp_world_size = get_dp_world_size() self.device_group = create_new_group_for_current_dp("nccl") if get_env_start_args().enable_dp_prefill_balance: @@ -70,62 +65,64 @@ def __init__(self): self.autotune_group = dist.new_group([i for i in range(get_global_world_size())], backend="gloo") - def init_custom_reduce(self) -> None: - if not HAS_SGL_KERNEL or not has_nvlink() or self.dp_world_size not in [2, 4, 6, 8]: - return - args = get_env_start_args() - if args.disable_custom_allreduce: - return - cpu_group = create_new_group_for_current_dp("gloo") - self.custom_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device()) - logger.info("Enable Custom ALLReduce. 
You can disable it by settting --disable_custom_allreduce.") + def _support_custom_allreduce(self) -> bool: + return has_nvlink() and self.dp_world_size in [2, 4, 6, 8] - def init_custom_gather(self) -> None: - if not HAS_LIGHTLLM_KERNEL or not has_nvlink() or self.dp_world_size not in [2, 4, 6, 8]: + def init_symm_mem_reduce(self) -> None: + if not self._support_custom_allreduce(): return + from .symm_mem_all_reduce import SymmMemAllreduce - args = get_env_start_args() - if not args.enable_custom_allgather: + data_type = get_torch_dtype(get_env_start_args().data_type) + symm = SymmMemAllreduce(self.device_group, torch.cuda.current_device(), dtype=data_type) + if not symm.disabled: + self.symm_mem_reduce = symm + logger.info("Enable SymmMem ALLReduce.") + + def init_flashinfer_reduce(self) -> None: + if not self._support_custom_allreduce(): return + from .flashinfer_all_reduce import FlashInferAllReduce - cpu_group = create_new_group_for_current_dp("gloo") - self.custom_gather = CustomAllgather(cpu_group, torch.cuda.current_device()) - logger.info("Enable Custom ALLGather. You can disable it by settting --disable_custom_allgather") + fi_cpu_group = create_new_group_for_current_dp("gloo") + fi = FlashInferAllReduce(fi_cpu_group, torch.cuda.current_device()) + if not fi.disabled: + self.flashinfer_reduce = fi + logger.info("Enable FlashInfer ALLReduce.") def all_reduce(self, input_: torch.Tensor) -> None: - if self.custom_reduce is not None and self.custom_reduce.should_custom_ar(input_): - input_.data = self.custom_reduce.custom_all_reduce(input_) + # Dispatch chain: FlashInfer -> SymmMem -> NCCL. 
+ if self.flashinfer_reduce is not None and self.flashinfer_reduce.should_use(input_): + input_.data = self.flashinfer_reduce.all_reduce(input_) return - else: - return dist.all_reduce(input_, group=self.device_group) - - def all_gather_into_tensor(self, output_: torch.Tensor, input_: torch.Tensor, async_op: bool = False) -> None: - if self.custom_gather is not None and self.custom_gather.should_custom_ar(input_): - self.custom_gather.custom_all_gather(output_, input_) + if self.symm_mem_reduce is not None and self.symm_mem_reduce.should_use(input_): + self.symm_mem_reduce.all_reduce(input_) return - else: - return dist.all_gather_into_tensor(output_, input_, group=self.device_group, async_op=async_op) + return dist.all_reduce(input_, group=self.device_group) - -@contextmanager -def lightllm_capture_graph(group: CustomProcessGroup = None): - with group.custom_reduce.capture() if group and group.custom_reduce else nullcontext(): - with group.custom_gather.capture() if group and group.custom_gather else nullcontext(): - yield + def all_gather_into_tensor(self, output_: torch.Tensor, input_: torch.Tensor, async_op: bool = False) -> None: + return dist.all_gather_into_tensor(output_, input_, group=self.device_group, async_op=async_op) class DistributeGroupManager: def __init__(self): self.groups = [] + self.ep_buffer = None + self.ep_low_latency_buffer = None + self.ep_mega_moe_buffer = None + self.ep_num_sms = None def __len__(self): return len(self.groups) def create_groups(self, group_size: int): + args = get_env_start_args() for i in range(group_size): group = CustomProcessGroup() - group.init_custom_gather() - group.init_custom_reduce() + if not args.disable_symm_mem_allreduce: + group.init_symm_mem_reduce() + if not args.disable_flashinfer_allreduce: + group.init_flashinfer_reduce() self.groups.append(group) return @@ -135,52 +132,92 @@ def get_default_group(self) -> CustomProcessGroup: def get_group(self, group_index: int) -> CustomProcessGroup: return 
self.groups[group_index] - def new_deepep_group(self, n_routed_experts, hidden_size): + def new_deepep_group( + self, + n_routed_experts, + hidden_size, + num_experts_per_tok: int = 1, + moe_intermediate_size: Optional[int] = None, + ): enable_ep_moe = get_env_start_args().enable_ep_moe - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + prefill_num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill() + decode_num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() if not enable_ep_moe: self.ep_buffer = None + self.ep_low_latency_buffer = None + self.ep_mega_moe_buffer = None + self.ep_num_sms = None return assert HAS_DEEPEP, "deep_ep is required for expert parallelism" - self._set_num_sms_for_deep_gemm() global_world_size = get_global_world_size() deepep_group = dist.new_group(list(range(global_world_size))) - low_latency_mode, num_rdma_bytes = True, 0 - if low_latency_mode: - self.ll_num_tokens, self.ll_hidden = num_max_dispatch_tokens_per_rank, hidden_size - self.ll_num_experts = n_routed_experts + get_redundancy_expert_num() * global_world_size - num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( - self.ll_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts - ) - self.ep_buffer = deep_ep.Buffer( + self.ll_num_tokens = prefill_num_max_dispatch_tokens_per_rank + self.ll_decode_num_tokens = decode_num_max_dispatch_tokens_per_rank + self.ll_hidden = hidden_size + self.ll_num_experts = n_routed_experts + get_redundancy_expert_num() * global_world_size + self.ep_buffer = deep_ep.ElasticBuffer( deepep_group, - int(1e9), - num_rdma_bytes, - low_latency_mode=low_latency_mode, - num_qps_per_rank=(self.ll_num_experts // global_world_size if low_latency_mode else 1), + num_max_tokens_per_rank=self.ll_num_tokens, + hidden=self.ll_hidden, + num_topk=num_experts_per_tok, + use_fp8_dispatch=True, + allow_multiple_reduction=False, ) + 
self.ep_mega_moe_buffer = None + self.ep_low_latency_buffer = None + if not is_sm100_gpu(): + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + self.ll_decode_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts + ) + self.ep_low_latency_buffer = deep_ep.Buffer( + deepep_group, + int(1e9), + num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=(self.ll_num_experts // global_world_size), + ) + else: + if moe_intermediate_size is None: + raise ValueError("SM100 Mega MoE requires moe_intermediate_size or intermediate_size in model config") + + import deep_gemm + + self.ep_mega_moe_buffer = deep_gemm.get_symm_buffer_for_mega_moe( + deepep_group, + self.ll_num_experts, + self.ll_num_tokens, + num_experts_per_tok, + self.ll_hidden, + moe_intermediate_size, + ) + theoretical_sms = self.ep_buffer.get_theoretical_num_sms(self.ll_num_experts, num_experts_per_tok) + self._set_num_sms_for_deep_gemm(theoretical_sms) - def _set_num_sms_for_deep_gemm(self): + def _set_num_sms_for_deep_gemm(self, deepep_sms: int): try: try: from deep_gemm.jit_kernels.utils import set_num_sms except: from deep_gemm import set_num_sms - deepep_sms = int(os.getenv("DEEPEP_SMS", deep_ep.Buffer.num_sms)) device_sms = get_device_sm_count() - deep_ep.Buffer.set_num_sms(deepep_sms) - set_num_sms(device_sms - deepep_sms) + deepep_sms = max(0, min(deepep_sms, max(device_sms - 2, 0))) + self.ep_num_sms = deepep_sms + if self.ep_low_latency_buffer is not None: + deep_ep.Buffer.set_num_sms(deepep_sms - deepep_sms % 2) + set_num_sms(max(device_sms - deepep_sms, 2)) except BaseException as e: logger.warning(f"set num sms for deep_gemm failed: {e}") def clear_deepep_buffer(self): """ - prefill 之后需要clean 一下,ep buffer 才能正常执行 decode。 + Prefill after using ElasticBuffer may leave the legacy low-latency buffer dirty for decode. 
""" - if hasattr(self, "ep_buffer") and self.ep_buffer is not None: - self.ep_buffer.clean_low_latency_buffer(self.ll_num_tokens, self.ll_hidden, self.ll_num_experts) + if self.ep_low_latency_buffer is not None: + self.ep_low_latency_buffer.clean_low_latency_buffer( + self.ll_decode_num_tokens, self.ll_hidden, self.ll_num_experts + ) def all_reduce( @@ -192,9 +229,10 @@ def all_reduce( if _is_single_group(group=group): return if isinstance(group, CustomProcessGroup): - return group.all_reduce(input_) - else: - return dist.all_reduce(input_, op, group, async_op) + if op == ReduceOp.SUM: + return group.all_reduce(input_) + return dist.all_reduce(input_, op, group.device_group, async_op) + return dist.all_reduce(input_, op, group, async_op) def all_gather_into_tensor( diff --git a/lightllm/distributed/custom_all_gather.py b/lightllm/distributed/custom_all_gather.py deleted file mode 100644 index 44c72fcdaa..0000000000 --- a/lightllm/distributed/custom_all_gather.py +++ /dev/null @@ -1,252 +0,0 @@ -# Adapted from -# https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/distributed/device_communicators/custom_all_gather.py -# of the vllm-project/vllm GitHub repository. -# -# Copyright 2023 ModelTC Team -# Copyright 2023 vLLM Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import ctypes -from contextlib import contextmanager -from typing import List, Optional, Union - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup - -from lightllm.common.cuda_wrapper import CudaRTLibrary -from lightllm.utils.log_utils import init_logger -from lightllm.utils.device_utils import has_nvlink -from lightllm.utils.light_utils import light_ops - - -try: - if light_ops is not None: - light_ops.meta_size() -except: - pass - -logger = init_logger(__name__) - - -def is_weak_contiguous(inp: torch.Tensor): - return inp.is_contiguous() or ( - inp.storage().nbytes() - inp.storage_offset() * inp.element_size() == inp.numel() * inp.element_size() - ) - - -class CustomAllgather: - - _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] - - # max_size: max supported allgather size - def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], max_size=8192 * 1024 * 10) -> None: - """ - Args: - group: the process group to work on. If None, it will use the - default process group. - device: the device to bind the CustomAllgather to. If None, - it will be bind to f"cuda:{local_rank}". - It is the caller's responsibility to make sure each communicator - is bind to a unique device, and all communicators in this group - are in the same node. - """ - self._IS_CAPTURING = False - self.disabled = True - - if light_ops is None: - # disable because of missing custom allgather library - # e.g. in a non-cuda environment - return - self.group = group - assert dist.get_backend(group) != dist.Backend.NCCL, "CustomAllgather should be attached to a non-NCCL group." - - rank = dist.get_rank(group=self.group) - world_size = dist.get_world_size(group=self.group) - if world_size == 1: - # No need to initialize custom allgather for single GPU case. - return - - if world_size not in CustomAllgather._SUPPORTED_WORLD_SIZES: - logger.warning( - "Custom allgather is disabled due to an unsupported world" - " size: %d. Supported world sizes: %s. 
To silence this " - "warning, specify disable_custom_all_gather=True explicitly.", - world_size, - str(CustomAllgather._SUPPORTED_WORLD_SIZES), - ) - return - - if isinstance(device, int): - device = torch.device(f"cuda:{device}") - elif isinstance(device, str): - device = torch.device(device) - # now `device` is a `torch.device` object - assert isinstance(device, torch.device) - self.device = device - - cuda_visible_devices = None - if cuda_visible_devices: - device_ids = list(map(int, cuda_visible_devices.split(","))) - else: - device_ids = list(range(torch._C._cuda_getDeviceCount())) - - physical_device_id = device_ids[device.index] - tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") - gather_list = [torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)] - dist.all_gather(gather_list, tensor, group=self.group) - # physical_device_ids = [t.item() for t in gather_list] - - full_nvlink = has_nvlink() - if world_size > 2 and not full_nvlink: - logger.warning( - "Custom allgather is disabled because it's not supported on" - " more than two PCIe-only GPUs. To silence this warning, " - "specify disable_custom_all_gather=True explicitly." - ) - return - - self.disabled = False - # Buffers memory are owned by this Python class and passed to C++. - # Meta data is for synchronization - self.meta_ptrs = self.create_shared_buffer(light_ops.meta_size(), group=group) - # This is a pre-registered IPC buffer. In eager mode, input tensors - # are first copied into this buffer before allgather is performed - self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) - # This is a buffer for storing the tuples of pointers pointing to - # IPC buffers from all ranks. Each registered tuple has size of - # 8*world_size bytes where world_size is at most 8. Allocating 8MB - # is enough for 131072 such tuples. The largest model I've seen only - # needs less than 10000 of registered tuples. 
- self.rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=self.device) - self.max_size = max_size - self.rank = rank - self.world_size = world_size - self.full_nvlink = full_nvlink - self._ptr = light_ops.init_custom_gather_ar(self.meta_ptrs, self.rank_data, rank, self.full_nvlink) - light_ops.allgather_register_buffer(self._ptr, self.buffer_ptrs) - - @staticmethod - def create_shared_buffer(size_in_bytes: int, group: Optional[ProcessGroup] = None) -> List[int]: - """ - Creates a shared buffer and returns a list of pointers - representing the buffer on all processes in the group. - """ - lib = CudaRTLibrary() - pointer = lib.cudaMalloc(size_in_bytes) - handle = lib.cudaIpcGetMemHandle(pointer) - world_size = dist.get_world_size(group=group) - rank = dist.get_rank(group=group) - handles = [None] * world_size - dist.all_gather_object(handles, handle, group=group) - - pointers: List[int] = [] - for i, h in enumerate(handles): - if i == rank: - pointers.append(pointer.value) # type: ignore - else: - pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore - - return pointers - - @staticmethod - def free_shared_buffer(pointers: List[int], group: Optional[ProcessGroup] = None) -> None: - rank = dist.get_rank(group=group) - lib = CudaRTLibrary() - lib.cudaFree(ctypes.c_void_p(pointers[rank])) - - @contextmanager - def capture(self): - """ - The main responsibility of this context manager is the - `register_graph_buffers` call at the end of the context. - It records all the buffer addresses used in the CUDA graph. - """ - try: - self._IS_CAPTURING = True - yield - finally: - self._IS_CAPTURING = False - if not self.disabled: - self.register_graph_buffers() - - def register_graph_buffers(self): - handle, offset = light_ops.allgather_get_graph_buffer_ipc_meta(self._ptr) - # We cannot directly use `dist.all_gather_object` here - # because it is incompatible with `gloo` backend under inference mode. 
- # see https://github.com/pytorch/pytorch/issues/126032 for details. - all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))] - all_data[self.rank] = [handle, offset] - ranks = sorted(dist.get_process_group_ranks(group=self.group)) - for i, rank in enumerate(ranks): - dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu") - # Unpack list of tuples to tuple of lists. - handles = [d[0] for d in all_data] # type: ignore - offsets = [d[1] for d in all_data] # type: ignore - light_ops.allgather_register_graph_buffers(self._ptr, handles, offsets) - - def should_custom_ar(self, inp: torch.Tensor): - if self.disabled: - return False - inp_size = inp.numel() * inp.element_size() - # custom allgather requires input byte size to be multiples of 16 - if inp_size % 16 != 0: - return False - if not is_weak_contiguous(inp): - return False - if self.world_size == 2 or self.full_nvlink: - return inp_size < self.max_size - return False - - def all_gather(self, out: torch.Tensor, inp: torch.Tensor, registered: bool = False): - """Performs an out-of-place all gather. - - If registered is True, this assumes inp's pointer is already - IPC-registered. Otherwise, inp is first copied into a pre-registered - buffer. - """ - if registered: - light_ops.all_gather(self._ptr, inp, out, 0, 0) - else: - light_ops.all_gather(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size) - return out - - def custom_all_gather(self, output: torch.Tensor, input: torch.Tensor) -> Optional[torch.Tensor]: - """The main allgather API that provides support for cuda graph.""" - # When custom allgather is disabled, this will be None. - if self.disabled or not self.should_custom_ar(input): - return - if self._IS_CAPTURING: - if torch.cuda.is_current_stream_capturing(): - self.all_gather(output, input, registered=True) - return - else: - # If warm up, mimic the allocation pattern since custom - # allgather is out-of-place. 
- return - else: - # Note: outside of cuda graph context, custom allgather incurs a - # cost of cudaMemcpy, which should be small (<=1% of overall - # latency) compared to the performance gain of using custom kernels - self.all_gather(output, input, registered=False) - return - - def close(self): - if not self.disabled and self._ptr: - light_ops.allgather_dispose(self._ptr) - self._ptr = 0 - self.free_shared_buffer(self.meta_ptrs) - self.free_shared_buffer(self.buffer_ptrs) - - def __del__(self): - self.close() diff --git a/lightllm/distributed/custom_all_reduce.py b/lightllm/distributed/custom_all_reduce.py deleted file mode 100644 index 690cc4061e..0000000000 --- a/lightllm/distributed/custom_all_reduce.py +++ /dev/null @@ -1,267 +0,0 @@ -# Adapted from -# https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/distributed/device_communicators/custom_all_reduce.py -# of the vllm-project/vllm GitHub repository. -# -# Copyright 2023 ModelTC Team -# Copyright 2023 vLLM Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import ctypes -from contextlib import contextmanager -from typing import List, Optional, Union -import os -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup - -from lightllm.common.cuda_wrapper import CudaRTLibrary -from lightllm.utils.log_utils import init_logger -from lightllm.utils.device_utils import has_nvlink -from lightllm.utils.sgl_utils import sgl_allreduce_ops -from lightllm.utils.vllm_utils import vllm_ops - -logger = init_logger(__name__) - -use_vllm_custom_allreduce = os.getenv("LIGHTLLM_USE_VLLM_CUSTOM_ALLREDUCE", "0").upper() in ["ON", "TRUE", "1"] -if use_vllm_custom_allreduce: - # Use vllm custom allreduce - ops = vllm_ops -else: - # Use sgl custom allreduce - ops = sgl_allreduce_ops - -if ops is not None: - ops.meta_size() - - -def is_weak_contiguous(inp: torch.Tensor): - return inp.is_contiguous() or ( - inp.storage().nbytes() - inp.storage_offset() * inp.element_size() == inp.numel() * inp.element_size() - ) - - -class CustomAllreduce: - - _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] - - # max_size: max supported allreduce size - def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], max_size=8192 * 1024) -> None: - """ - Args: - group: the process group to work on. If None, it will use the - default process group. - device: the device to bind the CustomAllreduce to. If None, - it will be bind to f"cuda:{local_rank}". - It is the caller's responsibility to make sure each communicator - is bind to a unique device, and all communicators in this group - are in the same node. - """ - self._IS_CAPTURING = False - self.disabled = True - - if ops is None: - # disable because of missing custom allreduce library - # e.g. in a non-cuda environment - return - self.group = group - assert dist.get_backend(group) != dist.Backend.NCCL, "CustomAllreduce should be attached to a non-NCCL group." 
- - rank = dist.get_rank(group=self.group) - world_size = dist.get_world_size(group=self.group) - if world_size == 1: - # No need to initialize custom allreduce for single GPU case. - return - - if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: - logger.warning( - "Custom allreduce is disabled due to an unsupported world" - " size: %d. Supported world sizes: %s. To silence this " - "warning, specify disable_custom_all_reduce=True explicitly.", - world_size, - str(CustomAllreduce._SUPPORTED_WORLD_SIZES), - ) - return - - if isinstance(device, int): - device = torch.device(f"cuda:{device}") - elif isinstance(device, str): - device = torch.device(device) - # now `device` is a `torch.device` object - assert isinstance(device, torch.device) - self.device = device - - cuda_visible_devices = None - if cuda_visible_devices: - device_ids = list(map(int, cuda_visible_devices.split(","))) - else: - device_ids = list(range(torch._C._cuda_getDeviceCount())) - - physical_device_id = device_ids[device.index] - tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") - gather_list = [torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)] - dist.all_gather(gather_list, tensor, group=self.group) - # physical_device_ids = [t.item() for t in gather_list] - - full_nvlink = has_nvlink() - if world_size > 2 and not full_nvlink: - logger.warning( - "Custom allreduce is disabled because it's not supported on" - " more than two PCIe-only GPUs. To silence this warning, " - "specify disable_custom_all_reduce=True explicitly." - ) - return - - self.disabled = False - # Buffers memory are owned by this Python class and passed to C++. - # Meta data composes of two parts: meta data for synchronization and a - # temporary buffer for storing intermediate allreduce results. - self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, group=group) - # This is a pre-registered IPC buffer. 
In eager mode, input tensors - # are first copied into this buffer before allreduce is performed - self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) - # This is a buffer for storing the tuples of pointers pointing to - # IPC buffers from all ranks. Each registered tuple has size of - # 8*world_size bytes where world_size is at most 8. Allocating 8MB - # is enough for 131072 such tuples. The largest model I've seen only - # needs less than 10000 of registered tuples. - self.rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=self.device) - self.max_size = max_size - self.rank = rank - self.world_size = world_size - self.full_nvlink = full_nvlink - self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank, self.full_nvlink) - ops.register_buffer(self._ptr, self.buffer_ptrs) - - @staticmethod - def create_shared_buffer(size_in_bytes: int, group: Optional[ProcessGroup] = None) -> List[int]: - """ - Creates a shared buffer and returns a list of pointers - representing the buffer on all processes in the group. - """ - lib = CudaRTLibrary() - pointer = lib.cudaMalloc(size_in_bytes) - handle = lib.cudaIpcGetMemHandle(pointer) - world_size = dist.get_world_size(group=group) - rank = dist.get_rank(group=group) - handles = [None] * world_size - dist.all_gather_object(handles, handle, group=group) - - pointers: List[int] = [] - for i, h in enumerate(handles): - if i == rank: - pointers.append(pointer.value) # type: ignore - else: - pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore - - return pointers - - @staticmethod - def free_shared_buffer(pointers: List[int], group: Optional[ProcessGroup] = None) -> None: - rank = dist.get_rank(group=group) - lib = CudaRTLibrary() - lib.cudaFree(ctypes.c_void_p(pointers[rank])) - - @contextmanager - def capture(self): - """ - The main responsibility of this context manager is the - `register_graph_buffers` call at the end of the context. 
- It records all the buffer addresses used in the CUDA graph. - """ - try: - self._IS_CAPTURING = True - yield - finally: - self._IS_CAPTURING = False - if not self.disabled: - self.register_graph_buffers() - - def register_graph_buffers(self): - handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) - # We cannot directly use `dist.all_gather_object` here - # because it is incompatible with `gloo` backend under inference mode. - # see https://github.com/pytorch/pytorch/issues/126032 for details. - all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))] - all_data[self.rank] = [handle, offset] - ranks = sorted(dist.get_process_group_ranks(group=self.group)) - for i, rank in enumerate(ranks): - dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu") - # Unpack list of tuples to tuple of lists. - handles = [d[0] for d in all_data] # type: ignore - offsets = [d[1] for d in all_data] # type: ignore - ops.register_graph_buffers(self._ptr, handles, offsets) - - def should_custom_ar(self, inp: torch.Tensor): - if self.disabled: - return False - inp_size = inp.numel() * inp.element_size() - # custom allreduce requires input byte size to be multiples of 16 - if inp_size % 16 != 0: - return False - if not is_weak_contiguous(inp): - return False - # for 4 or more non NVLink-capable GPUs, custom allreduce provides - # little performance improvement over NCCL. - if self.world_size == 2 or self.full_nvlink: - return inp_size < self.max_size - return False - - def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False): - """Performs an out-of-place all reduce. - - If registered is True, this assumes inp's pointer is already - IPC-registered. Otherwise, inp is first copied into a pre-registered - buffer. 
- """ - if out is None: - # fix circle import - from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - - out = g_cache_manager.alloc_tensor(inp.shape, inp.dtype, device=inp.device) - if registered: - ops.all_reduce(self._ptr, inp, out, 0, 0) - else: - ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size) - return out - - def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: - """The main allreduce API that provides support for cuda graph.""" - # When custom allreduce is disabled, this will be None. - if self.disabled or not self.should_custom_ar(input): - return None - if self._IS_CAPTURING: - if torch.cuda.is_current_stream_capturing(): - return self.all_reduce(input, registered=True) - else: - # If warm up, mimic the allocation pattern since custom - # allreduce is out-of-place. - # fix circle import - from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - - out = g_cache_manager.alloc_tensor(input.shape, input.dtype, device=input.device) - return out - else: - # Note: outside of cuda graph context, custom allreduce incurs a - # cost of cudaMemcpy, which should be small (<=1% of overall - # latency) compared to the performance gain of using custom kernels - return self.all_reduce(input, registered=False) - - def close(self): - if not self.disabled and self._ptr: - ops.dispose(self._ptr) - self._ptr = 0 - self.free_shared_buffer(self.meta_ptrs) - self.free_shared_buffer(self.buffer_ptrs) - - def __del__(self): - self.close() diff --git a/lightllm/distributed/flashinfer_all_reduce.py b/lightllm/distributed/flashinfer_all_reduce.py new file mode 100644 index 0000000000..35cb0aaedf --- /dev/null +++ b/lightllm/distributed/flashinfer_all_reduce.py @@ -0,0 +1,136 @@ +import os +import random +from typing import Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from lightllm.utils.log_utils import 
init_logger + +logger = init_logger(__name__) + +try: + import flashinfer.comm as flashinfer_comm + from flashinfer.comm.mnnvl import TorchDistBackend + + _FI_OK = hasattr(flashinfer_comm, "allreduce_fusion") and hasattr( + flashinfer_comm, "create_allreduce_fusion_workspace" + ) +except ImportError: + flashinfer_comm = None + TorchDistBackend = None + _FI_OK = False + +_MiB = 1024 * 1024 +# Default upper bound for the FlashInfer fast path (oneshot lamport regime). +# Used when (compute_cap, world_size) is not in the table below. Above the +# resolved bound, dispatch falls through to SymmMem multimem / NCCL. +FI_ALLREDUCE_DEFAULT_MAX_BYTES = 256 * 1024 + +_FI_ALLREDUCE_MAX_BYTES = { + "9.0": {2: 512 * 1024, 4: 256 * 1024, 6: 256 * 1024, 8: 128 * 1024}, + "10.0": {2: 1024 * 1024, 4: 512 * 1024, 6: 256 * 1024, 8: 256 * 1024}, + "10.3": {2: 1024 * 1024, 4: 512 * 1024, 6: 512 * 1024, 8: 256 * 1024}, +} + +_FI_WORKSPACE_MAX_SIZE_MB = { + "9.0": {2: 2.0, 4: 1.0, 6: 1.0, 8: 0.5}, + "10.0": {2: 2.0, 4: 2.0, 6: 1.0, 8: 1.0}, + "10.3": {2: 2.0, 4: 2.0, 6: 2.0, 8: 1.0}, +} + + +class FlashInferAllReduce: + """Small-message all-reduce via flashinfer trtllm oneshot lamport. + + Out-of-place: callers assign back via ``t.data = fi.all_reduce(t)``. 
+ """ + + def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]) -> None: + self.disabled = True + self._workspace = None + self._ws_hidden_dim = None + self._ws_dtype = None + self._ws_max_token_num = 0 + + if not _FI_OK or not torch.cuda.is_available(): + return + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + self.device = device + self.group = group + self.world_size = dist.get_world_size(group=group) + self.rank = dist.get_rank(group=group) + if self.world_size == 1: + return + + cap = torch.cuda.get_device_capability(device) + cap_str = f"{cap[0]}.{cap[1]}" + ws_table = _FI_WORKSPACE_MAX_SIZE_MB.get(cap_str) + if ws_table is None or self.world_size not in ws_table: + return + self.max_workspace_size = int(ws_table[self.world_size] * _MiB) + self.max_bytes = _FI_ALLREDUCE_MAX_BYTES.get(cap_str, {}).get(self.world_size, FI_ALLREDUCE_DEFAULT_MAX_BYTES) + assert self.max_bytes <= self.max_workspace_size, ( + "FlashInferAllReduce config mismatch: " + f"max_bytes={self.max_bytes} exceeds max_workspace_size={self.max_workspace_size}" + ) + self.disabled = False + + def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool: + if self._workspace is not None and self._ws_hidden_dim == hidden_dim and self._ws_dtype == dtype: + return True + element_size = torch.tensor([], dtype=dtype).element_size() + max_token_num = max(1, self.max_workspace_size // (hidden_dim * element_size)) + if self._workspace is not None: + try: + self._workspace.destroy() + except Exception: + pass + self._workspace = None + rng_state = random.getstate() + try: + random.seed(int.from_bytes(os.urandom(16), byteorder="big")) + self._workspace = flashinfer_comm.create_allreduce_fusion_workspace( + backend="trtllm", + world_size=self.world_size, + rank=self.rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + 
comm_backend=TorchDistBackend(group=self.group), + ) + except Exception as e: + logger.warning("FlashInferAllReduce workspace init failed: %s. Disabling.", e) + self.disabled = True + self._workspace = None + return False + finally: + random.setstate(rng_state) + self._ws_hidden_dim = hidden_dim + self._ws_dtype = dtype + self._ws_max_token_num = max_token_num + return True + + def should_use(self, inp: torch.Tensor) -> bool: + if self.disabled or not inp.is_cuda or not inp.is_contiguous(): + return False + if inp.dtype not in (torch.bfloat16, torch.float16) or inp.dim() != 2: + return False + if inp.numel() * inp.element_size() >= self.max_bytes: + return False + _, hidden_dim = inp.shape + if not self._ensure_workspace(hidden_dim, inp.dtype): + return False + return True + + def all_reduce(self, inp: torch.Tensor) -> torch.Tensor: + return flashinfer_comm.allreduce_fusion( + input=inp, + workspace=self._workspace, + pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce, + # launch_with_pdl=True, # TODO: learn pdl and ensure no other side effects. 
+ ) diff --git a/lightllm/distributed/symm_mem_all_reduce.py b/lightllm/distributed/symm_mem_all_reduce.py new file mode 100644 index 0000000000..2256a5093f --- /dev/null +++ b/lightllm/distributed/symm_mem_all_reduce.py @@ -0,0 +1,92 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +try: + import torch.distributed._symmetric_memory as torch_symm_mem + + _SYMM_MEM_OK = True +except ImportError: + torch_symm_mem = None + _SYMM_MEM_OK = False + +_MiB = 1024 * 1024 + +# Adopted from vLLM's benchmark-tuned SymmMem max-size table: +# vllm/distributed/device_communicators/all_reduce_utils.py +SYMM_MEM_ALL_REDUCE_MAX_SIZES = { + "9.0": {2: 64 * _MiB, 4: 32 * _MiB, 6: 64 * _MiB, 8: 64 * _MiB}, + "10.0": {2: 8 * _MiB, 4: 32 * _MiB, 6: 128 * _MiB, 8: 128 * _MiB}, + "10.3": {2: 4 * _MiB, 4: 32 * _MiB, 6: 32 * _MiB, 8: 64 * _MiB}, +} + +# Adopted from vLLM's multimem-vs-two_shot world-size split. +# World sizes for which multimem (NVLS hardware reduce) beats two_shot. 
+_WORLD_SIZES_MULTIMEM = {"9.0": [4, 6, 8], "10.0": [6, 8], "10.3": [6, 8]} + + +class SymmMemAllreduce: + """In-place all-reduce via torch symmetric memory (NVLink SHARP / NVLS).""" + + def __init__(self, group: ProcessGroup, device, dtype: torch.dtype = torch.bfloat16) -> None: + self.disabled = True + if not _SYMM_MEM_OK: + return + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + self.device = device + self.group = group + self.dtype = dtype + self.world_size = dist.get_world_size(group=group) + + cap = torch.cuda.get_device_capability(device) + cap_str = f"{cap[0]}.{cap[1]}" + if cap_str not in SYMM_MEM_ALL_REDUCE_MAX_SIZES: + return + if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[cap_str]: + return + self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[cap_str][self.world_size] + self.use_multimem = self.world_size in _WORLD_SIZES_MULTIMEM.get(cap_str, []) + + try: + self.buffer = torch_symm_mem.empty(self.max_size // dtype.itemsize, device=device, dtype=dtype) + handle = torch_symm_mem.rendezvous(self.buffer, group.group_name) + except RuntimeError as e: + logger.warning("SymmMemAllreduce: rendezvous failed (%s). Disabling.", e) + return + # multimem and two_shot both require a multicast pointer. + if getattr(handle, "multicast_ptr", 0) == 0: + logger.warning("SymmMemAllreduce: multicast pointer unavailable; disabling.") + return + self.disabled = False + logger.info( + "SymmMemAllreduce enabled: world_size=%d, max_size=%d, multimem=%s", + self.world_size, + self.max_size, + self.use_multimem, + ) + + def should_use(self, inp: torch.Tensor) -> bool: + if self.disabled or inp.dtype != self.dtype or not inp.is_contiguous(): + return False + nbytes = inp.numel() * inp.element_size() + if nbytes % 4 != 0: + return False + # Lower bound is implicitly handled by the dispatch order in + # CustomProcessGroup.all_reduce: FlashInfer claims small messages first. 
+ return nbytes < self.max_size + + def all_reduce(self, inp: torch.Tensor) -> None: + n = inp.numel() + self.buffer[:n].copy_(inp.view(-1)) + if self.use_multimem: + torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:n], "sum", self.group.group_name) + else: + torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:n], "sum", self.group.group_name) + inp.view(-1).copy_(self.buffer[:n]) diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 2caee91709..f619b1d88f 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -33,6 +33,7 @@ from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel from lightllm.models.gemma3.model import Gemma3TpPartModel +from lightllm.models.gemma4.model import Gemma4TpPartModel from lightllm.models.tarsier2.model import ( Tarsier2Qwen2TpPartModel, Tarsier2Qwen2VLTpPartModel, diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index 60d584eebd..c8a78e6b62 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -71,19 +71,26 @@ def _ffn_norm( def _get_qkv( self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight ) -> Tuple[torch.Tensor, torch.Tensor]: + input = self._tpsp_allgather(input, infer_state) q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_)) cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + assert infer_state.need_dp_prefill_balance is False, "bloom does not support dp prefill balance" return q, cache_kv def _get_o(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: + assert infer_state.need_dp_prefill_balance is False, "bloom does not support dp prefill balance" o_tensor = 
layer_weight.o_proj.mm(input.view(-1, self.tp_o_head_num_ * self.head_dim_)) + o_tensor = self._tpsp_reduce(o_tensor, infer_state) return o_tensor def _ffn(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: - ffn1_out = layer_weight.gate_up_proj.mm(input.view(-1, self.embed_dim_)) + input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + ffn1_out = layer_weight.gate_up_proj.mm(input) input = None gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh") ffn1_out = None ffn2_out = layer_weight.down_proj.mm(gelu_out) gelu_out = None + ffn2_out = self._tpsp_reduce(ffn2_out, infer_state) return ffn2_out diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 98cc7c229e..be819c94a0 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -1,6 +1,5 @@ import os import torch -import torch.distributed as dist import triton from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight from lightllm.common.basemodel.attention.base_att import AttControl @@ -8,9 +7,9 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import use_sm100_mega_moe from functools import partial from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale -from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import 
get_global_world_size from lightllm.utils.log_utils import init_logger @@ -65,14 +64,11 @@ def _bind_ffn(self): if self.is_moe: enable_ep_moe = get_env_start_args().enable_ep_moe if enable_ep_moe: - self._ffn = self._moe_ffn_edp - self._tpsp_ffn = self._tpsp_ffn_ep + self._ffn = self._ffn_ep_impl else: - self._ffn = self._moe_ffn - self._tpsp_ffn = self._tpsp_ffn_tp + self._ffn = self._ffn_tp_impl else: self._ffn = partial(LlamaTransformerLayerInfer._ffn, self) - self._tpsp_ffn = self._tpsp_ffn_tp def _context_attention_kernel( self, @@ -151,51 +147,12 @@ def _decompress_kv( return sampled_k_nope, sampled_k_rope, sampled_v def _get_qkv( - self, - input: torch.Tensor, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - ) -> torch.Tensor: - input = input.view(-1, self.embed_dim_) - - if self.q_lora_rank is None: - q = layer_weight.q_weight_.mm(input) - cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) - else: - q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 - ) - q = layer_weight.q_a_layernorm_(input=q, eps=self.eps_, alloc_func=self.alloc_tensor) - q = layer_weight.q_b_proj_.mm(q) - cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) - q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) - q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - layer_weight.kv_a_layernorm_( - cache_kv[:, :, : self.kv_lora_rank], eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank] - ) - - rotary_emb_fwd( - q_rope, - cache_kv[:, :, self.kv_lora_rank :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv - - def _tpsp_get_qkv( self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) 
if self.q_lora_rank is None: # q_lora_rank is None 的时候,当前不支持低rank通信优化。 - if self.tp_world_size_ > 1: - sp_token_num, hidden_dim = input.shape - gather_input = self.alloc_tensor( - (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device - ) - all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) - input = gather_input[0 : len(infer_state.input_ids), :] + input = self._tpsp_allgather(input=input, infer_state=infer_state) input = input.view(-1, self.embed_dim_) q = layer_weight.q_weight_.mm(input) @@ -214,26 +171,16 @@ def _tpsp_get_qkv( if infer_state.need_dp_prefill_balance: q = infer_state._all_to_all_unbalance_get(data=q) cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) + return q, cache_kv else: input = input.view(-1, self.embed_dim_) qkv = layer_weight.qkv_a_proj_with_mqa_.mm(input) # 在 lora rank 之后,进行通信,可以减少通信量。 - if self.tp_world_size_ > 1: - sp_token_num, qkv_dim = qkv.shape - gather_qkv = self.alloc_tensor( - (sp_token_num * self.tp_world_size_, qkv_dim), dtype=qkv.dtype, device=qkv.device - ) - all_gather_into_tensor(gather_qkv, qkv, group=infer_state.dist_group, async_op=False) - qkv = gather_qkv[0 : len(infer_state.input_ids), :] + qkv = self._tpsp_allgather(input=qkv, infer_state=infer_state) if infer_state.need_dp_prefill_balance: qkv = infer_state._all_to_all_unbalance_get(data=qkv) - position_cos = infer_state._unbalance_position_cos - position_sin = infer_state._unbalance_position_sin - else: - position_cos = infer_state.position_cos - position_sin = infer_state.position_sin q, cache_kv = qkv.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1) q = layer_weight.q_a_layernorm_(input=q, eps=self.eps_, alloc_func=self.alloc_tensor) @@ -247,51 +194,24 @@ def _tpsp_get_qkv( rotary_emb_fwd( q_rope, cache_kv[:, :, self.kv_lora_rank :], - position_cos, - position_sin, + infer_state.position_cos, + infer_state.position_sin, ) return q, cache_kv def 
_get_o( self, input: torch.Tensor, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight - ) -> torch.Tensor: - if input.shape[2] == self.kv_lora_rank: - input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) - o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim)) - return o_tensor - - def _tpsp_get_o( - self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight ) -> torch.Tensor: if infer_state.need_dp_prefill_balance: input = infer_state._all_to_all_balance_get(data=input) if input.shape[2] == self.kv_lora_rank: input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) - - input = input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim) - dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_ - o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device) - layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.input_ids), :]) - e_o_tensor = o_tensor[len(infer_state.input_ids) :, :] - if e_o_tensor.shape[0] > 0: - e_o_tensor.fill_(0) - - if self.tp_world_size_ > 1: - sp_token_num = o_tensor.shape[0] // self.tp_world_size_ - reduce_o_tensor = self.alloc_tensor((sp_token_num, self.embed_dim_), dtype=input.dtype, device=input.device) - reduce_scatter_tensor( - output=reduce_o_tensor, - input=o_tensor, - op=dist.ReduceOp.SUM, - group=infer_state.dist_group, - async_op=False, - ) - o_tensor = reduce_o_tensor - + o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim)) + o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) return o_tensor - def _moe_ffn( + def _moe_ffn_tp( self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight ) -> torch.Tensor: @@ -300,7 +220,7 @@ def _moe_ffn( # if fused_shared_experts is not enabled, compute shared_output if 
self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: - shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) + shared_output = LlamaTransformerLayerInfer._ffn_tp(self, hidden_states, infer_state, layer_weight) moe_gate_dtype = layer_weight.moe_gate.data_type_ router_logits = layer_weight.moe_gate.mm(hidden_states.to(moe_gate_dtype)) @@ -326,7 +246,7 @@ def _moe_ffn_edp( hidden_states = input token_num, hidden_dim = hidden_states.shape if self.n_shared_experts is not None: - shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) + shared_output = LlamaTransformerLayerInfer._ffn_tp(self, hidden_states, infer_state, layer_weight) moe_gate_dtype = layer_weight.moe_gate.data_type_ router_logits = layer_weight.moe_gate.mm(hidden_states.to(moe_gate_dtype)) @@ -347,40 +267,24 @@ def _moe_ffn_edp( ep_output = ep_output.view(token_num, hidden_dim) return ep_output - def _tpsp_ffn(self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight): - raise Exception("need bind to real impl") - - def _tpsp_ffn_tp( + def _ffn_tp_impl( self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) - if self.tp_world_size_ > 1: - sp_token_num, hidden_dim = input.shape - gather_input = self.alloc_tensor( - (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device - ) - all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) - input = gather_input + input = self._tpsp_allgather(input=input, infer_state=infer_state) + ffn2_out = self._moe_ffn_tp(input=input, infer_state=infer_state, layer_weight=layer_weight) - ffn2_out = self._ffn(input=input, infer_state=infer_state, layer_weight=layer_weight) + ffn2_out = self._tpsp_reduce(input=ffn2_out, infer_state=infer_state) - if 
self.tp_world_size_ > 1: - sp_token_num = ffn2_out.shape[0] // self.tp_world_size_ - reduce_o_tensor = self.alloc_tensor( - (sp_token_num, self.embed_dim_), dtype=ffn2_out.dtype, device=ffn2_out.device - ) - reduce_scatter_tensor( - reduce_o_tensor, ffn2_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False - ) - ffn2_out = reduce_o_tensor return ffn2_out - def _tpsp_ffn_ep( + def _ffn_ep_impl( self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight ) -> torch.Tensor: + # ep 本身就是一种 sp 兼容,所以不需要再进行 allgather 和 reduce input = input.view(-1, self.embed_dim_) - ffn2_out = self._ffn(input=input, infer_state=infer_state, layer_weight=layer_weight) + ffn2_out = self._moe_ffn_edp(input=input, infer_state=infer_state, layer_weight=layer_weight) return ffn2_out @@ -392,18 +296,18 @@ def overlap_tpsp_token_forward( infer_state1: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method): return super().overlap_tpsp_token_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) # 0 attention _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight) - _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, infer_state, layer_weight) + _0_q, _0_cache_kv = self._get_qkv(_0_input1, infer_state, layer_weight) _0_input1 = None self._post_cache_kv(_0_cache_kv, infer_state, layer_weight) _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight) _0_q = None - _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight) + _0_o = self._get_o(_0_o, infer_state, layer_weight) input_embdings.add_(_0_o.view(-1, self.embed_dim_)) _0_o = None _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) @@ -416,7 +320,7 @@ def overlap_tpsp_token_forward( # 0 shared expert if self.n_shared_experts is not None: - _0_shared_output = LlamaTransformerLayerInfer._ffn(self, 
_0_input1, infer_state, layer_weight) + _0_shared_output = LlamaTransformerLayerInfer._ffn_tp(self, _0_input1, infer_state, layer_weight) # 0 dispatch ( @@ -431,12 +335,12 @@ def overlap_tpsp_token_forward( # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) - _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, infer_state1, layer_weight) + _1_q, _1_cache_kv = self._get_qkv(_1_input1, infer_state1, layer_weight) _1_input1 = None self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight) _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight) _1_q = None - _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight) + _1_o = self._get_o(_1_o, infer_state1, layer_weight) input_embdings1.add_(_1_o.view(-1, self.embed_dim_)) _1_o = None _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight) @@ -451,7 +355,7 @@ def overlap_tpsp_token_forward( # 1 shared expert if self.n_shared_experts is not None: - _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight) + _1_shared_output = LlamaTransformerLayerInfer._ffn_tp(self, _1_input1, infer_state1, layer_weight) # 1 dispatch ( @@ -518,18 +422,18 @@ def overlap_tpsp_context_forward( infer_state1: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method): return super().overlap_tpsp_context_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) # 0 attention _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight) - _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, infer_state, layer_weight) + _0_q, _0_cache_kv = self._get_qkv(_0_input1, infer_state, layer_weight) _0_input1 = None self._post_cache_kv(_0_cache_kv, infer_state, layer_weight) _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight) _0_q = None - _0_o = self._tpsp_get_o(_0_o, 
infer_state, layer_weight) + _0_o = self._get_o(_0_o, infer_state, layer_weight) input_embdings.add_(_0_o.view(-1, self.embed_dim_)) _0_o = None _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) @@ -544,18 +448,18 @@ def overlap_tpsp_context_forward( _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _0_input1, _0_router_logits ) - from deep_ep import Buffer + from deep_ep import ElasticBuffer - _0_overlap_event = Buffer.capture() + _0_overlap_event = ElasticBuffer.capture() # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) - _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, infer_state1, layer_weight) + _1_q, _1_cache_kv = self._get_qkv(_1_input1, infer_state1, layer_weight) _1_input1 = None self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight) _1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight) _1_q = None - _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight) + _1_o = self._get_o(_1_o, infer_state1, layer_weight) input_embdings1.add_(_1_o.view(-1, self.embed_dim_)) _1_o = None _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight) @@ -583,16 +487,15 @@ def overlap_tpsp_context_forward( _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _1_input1, _1_router_logits ) - - _1_overlap_event = Buffer.capture() + _1_overlap_event = ElasticBuffer.capture() # 0 shared expert if self.n_shared_experts is not None: - _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight) + _0_shared_output = LlamaTransformerLayerInfer._ffn_tp(self, _0_input1, infer_state, layer_weight) # 1 shared expert if self.n_shared_experts is not None: - _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight) + _1_shared_output = LlamaTransformerLayerInfer._ffn_tp(self, _1_input1, infer_state1, layer_weight) # 
0 moe calu _0_moe_out = layer_weight.experts.prefilled_group_gemm( @@ -615,7 +518,7 @@ def overlap_tpsp_context_forward( infer_state1.hook() infer_state1.hook = None - _0_combine_event = Buffer.capture() + _0_combine_event = ElasticBuffer.capture() # 0 combine execute _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event) infer_state.hook = _0_hook @@ -630,7 +533,7 @@ def overlap_tpsp_context_forward( infer_state.hook() infer_state.hook = None - _1_combine_event = Buffer.capture() + _1_combine_event = ElasticBuffer.capture() if self.n_shared_experts is not None: _0_ffn_out.add_(_0_shared_output) diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 3eb09f9176..cff020ea40 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -37,9 +37,9 @@ def _parse_config(self): self.num_attention_heads = self.network_config_["num_attention_heads"] self.kv_lora_rank = self.network_config_["kv_lora_rank"] self.num_fused_shared_experts = 0 - if get_env_start_args().enable_fused_shared_experts and self.is_moe: - # enable_fused_shared_experts can only work with tensor parallelism - assert not get_env_start_args().enable_ep_moe, "enable_fused_shared_experts can only work with tp mode." 
+ start_args = get_env_start_args() + if start_args.enable_fused_shared_experts and not start_args.enable_ep_moe and self.is_moe: + # fused shared experts can only work with tensor parallelism self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) self.n_embed = self.network_config_["hidden_size"] self.n_inter = self.network_config_["intermediate_size"] diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e596eed97c..ea6620b4e4 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -48,7 +48,12 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() - dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/deepseek3_2/__init__.py b/lightllm/models/deepseek3_2/__init__.py index e69de29bb2..7c89d8f54c 100644 --- a/lightllm/models/deepseek3_2/__init__.py +++ b/lightllm/models/deepseek3_2/__init__.py @@ -0,0 +1,25 @@ +"""Make HuggingFace transformers recognize the ``deepseek_v32`` model_type. + +DeepSeek-V3.2 ships ``config.json`` with ``model_type="deepseek_v32"``, which +transformers (>=5.x) does not know. ``AutoTokenizer``/``AutoConfig`` then fall +back to the base ``PreTrainedConfig`` and crash during RoPE standardization +(``'PreTrainedConfig' object has no attribute 'max_position_embeddings'``). + +V3.2 is architecturally a V3 variant, so we alias its config to +``DeepseekV3Config``. 
lightllm uses its own model implementation and reads +``config.json`` directly; this registration only fixes loading the HF tokenizer +through ``AutoTokenizer`` (see ``lightllm/server/tokenizer.py``). +""" +from transformers import AutoConfig + +try: + from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config + + class DeepseekV32Config(DeepseekV3Config): + model_type = "deepseek_v32" + + AutoConfig.register("deepseek_v32", DeepseekV32Config, exist_ok=True) +except Exception: + # Older transformers without deepseek_v3, or a build that already + # supports deepseek_v32 natively. Nothing to do in either case. + pass diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 58c544e820..d6eaebe2fd 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -30,6 +30,9 @@ def _get_qkv( layer_weight: Deepseek3_2TransformerLayerWeight, ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + if infer_state.need_dp_prefill_balance: + input = infer_state._all_to_all_unbalance_get(data=input) q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 @@ -224,7 +227,15 @@ def _get_indices( import deep_gemm - logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) + logits = deep_gemm.fp8_mqa_logits( + q_fp8, + (k_fp8_, k_scale_), + weights.squeeze(-1), + ks, + ke, + clean_logits=False, + max_seqlen_k=infer_state.max_kv_seq_len, + ) from sgl_kernel import fast_topk_v2 @@ -232,7 +243,6 @@ def _get_indices( score=logits, lengths=lengths, topk=self.index_topk, - row_starts=ks, ) b_topk_index = torch.where(b_topk_index != -1, b_topk_index + ks.view(-1, 1), -1) # 将 topk index 转化为 mem 
index @@ -248,7 +258,7 @@ def _get_indices( @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 - from sgl_kernel import hadamard_transform + from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform hidden_size = x.size(-1) assert (hidden_size & (hidden_size - 1)) == 0, "Hidden size must be a power of 2 for Hadamard transform." diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py index d0f8b45f81..f02fc30942 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -112,4 +112,4 @@ def extract_indexer_ks( num_stages=1, ) - return O_fp8, O_scale + return O_fp8, O_scale.squeeze(-1) diff --git a/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py b/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py new file mode 100644 index 0000000000..eabf703f56 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py @@ -0,0 +1,80 @@ +import functools + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _butterfly_stage(x, GROUPS: tl.constexpr, STEP: tl.constexpr, BLOCK_R: tl.constexpr, BLOCK_N: tl.constexpr): + x_grouped = tl.reshape(x, (BLOCK_R, GROUPS, 2, STEP)) + x_grouped = tl.permute(x_grouped, (0, 1, 3, 2)) + left, right = tl.split(x_grouped) + x_pair = tl.join(left + right, left - right) + x_pair = tl.permute(x_pair, (0, 1, 3, 2)) + return tl.reshape(x_pair, (BLOCK_R, BLOCK_N)) + + +@triton.jit +def _hadamard_transform_kernel( + X, + Y, + n_rows, + scale: tl.constexpr, + BLOCK_R: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + rows = pid * BLOCK_R + tl.arange(0, BLOCK_R) + mask = rows[:, None] < n_rows + offsets = rows[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + x = tl.load(X + offsets, mask=mask, 
other=0.0).to(tl.float32) + + x = _butterfly_stage(x, 64, 1, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 32, 2, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 16, 4, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 8, 8, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 4, 16, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 2, 32, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 1, 64, BLOCK_R, BLOCK_N) + + tl.store(Y + offsets, x * scale, mask=mask) + + +@functools.lru_cache(maxsize=None) +def _target_programs(device_index: int) -> int: + return torch.cuda.get_device_properties(device_index).multi_processor_count * 2 + + +def _pick_block_r(rows: int, device_index: int) -> int: + block_r = triton.next_power_of_2(max(1, rows // _target_programs(device_index))) + return max(1, min(128, block_r)) + + +def _hadamard_transform_triton(x: torch.Tensor, scale: float) -> torch.Tensor: + original_shape = x.shape + hidden_size = x.size(-1) + if not x.is_contiguous(): + x = x.contiguous() + rows = x.numel() // hidden_size + out = torch.empty_like(x) + BLOCK_R = _pick_block_r(rows, x.device.index) + grid = (triton.cdiv(rows, BLOCK_R),) + _hadamard_transform_kernel[grid]( + x, + out, + rows, + scale, + BLOCK_R=BLOCK_R, + BLOCK_N=hidden_size, + num_warps=4, + ) + return out.view(original_shape) + + +def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + assert x.is_cuda, "hadamard_transform only supports CUDA tensors" + assert x.dtype == torch.bfloat16, "hadamard_transform expects bfloat16 input" + assert x.size(-1) == 128, "DeepSeek-V3.2 Hadamard transform expects hidden size 128" + + return _hadamard_transform_triton(x, scale) diff --git a/lightllm/models/deepseek_mtp/model.py b/lightllm/models/deepseek_mtp/model.py index d9ffdb0e31..e2b2a56137 100644 --- a/lightllm/models/deepseek_mtp/model.py +++ b/lightllm/models/deepseek_mtp/model.py @@ -7,6 +7,9 @@ class Deepseek3MTPModel(Deepseek2TpPartModel): + # MTP draft model marker (consumed by the decode CUDA-graph / padding 
paths). + is_mtp_draft_model = True + pre_and_post_weight_class = Deepseek3MTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer diff --git a/lightllm/models/gemma3/gemma3_visual.py b/lightllm/models/gemma3/gemma3_visual.py index b2f7a6b779..b2cdf1ec54 100644 --- a/lightllm/models/gemma3/gemma3_visual.py +++ b/lightllm/models/gemma3/gemma3_visual.py @@ -127,7 +127,7 @@ def encode(self, images: List[ImageItem]): t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"] img_tensors.append(t) else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + raise Exception("Unsupported input types: {} for {}".format(type(img), img)) cur_num = img_tensors[-1].shape[0] valid_ids.append([valid_id, valid_id + cur_num]) diff --git a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py index 183c4d8d45..86f00cfbb4 100644 --- a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py @@ -1,8 +1,6 @@ import torch -import torch.distributed as dist import torch.nn as nn from lightllm.common.basemodel.infer_struct import InferStateInfo -from lightllm.distributed import all_reduce from lightllm.models.gemma3.layer_weights.transformer_layer_weight import Gemma3TransformerLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer @@ -24,6 +22,7 @@ def __init__(self, layer_num, network_config): def _get_qkv( self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma3TransformerLayerWeight ) -> torch.Tensor: + input = self._tpsp_allgather(input=input, infer_state=infer_state) q = layer_weight.q_proj.mm(input) # kv = layer_weight.kv_proj.mm(input) # kv = kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) @@ -58,17 +57,22 @@ def 
_get_qkv( infer_state.position_cos_global.to(q.dtype), infer_state.position_sin_global.to(q.dtype), ) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) return q, cache_kv def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma3TransformerLayerWeight) -> torch.Tensor: input = input.view(-1, self.embed_dim_) - gate = layer_weight.gate_proj.mm(input.view(-1, self.embed_dim_)) - up = layer_weight.up_proj.mm(input.view(-1, self.embed_dim_)) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + gate = layer_weight.gate_proj.mm(input) + up = layer_weight.up_proj.mm(input) # gelu_and_mul_fwd(up_gate_out, ffn1_out) ffn1_out = nn.functional.gelu(gate, approximate="tanh") * up input = None ffn2_out = layer_weight.down_proj.mm(ffn1_out) ffn1_out = None + ffn2_out = self._tpsp_reduce(input=ffn2_out, infer_state=infer_state) return ffn2_out def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma3TransformerLayerWeight): @@ -82,8 +86,6 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -94,8 +96,6 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None - if self.tp_world_size_ > 1: - all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) ffn_out = layer_weight.post_feedforward_layernorm_weight_( input=ffn_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor @@ 
-115,8 +115,6 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -127,8 +125,6 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None - if self.tp_world_size_ > 1: - all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) ffn_out = layer_weight.post_feedforward_layernorm_weight_( input=ffn_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor diff --git a/lightllm/common/mamba_cache_mem_manager/__init__.py b/lightllm/models/gemma4/__init__.py similarity index 100% rename from lightllm/common/mamba_cache_mem_manager/__init__.py rename to lightllm/models/gemma4/__init__.py diff --git a/lightllm/models/gemma4/gemma4_visual.py b/lightllm/models/gemma4/gemma4_visual.py new file mode 100644 index 0000000000..7ed64108b3 --- /dev/null +++ b/lightllm/models/gemma4/gemma4_visual.py @@ -0,0 +1,146 @@ +import json +import os +from io import BytesIO +from typing import List + +import torch +from PIL import Image +from safetensors import safe_open +from transformers import AutoConfig, AutoProcessor + +from lightllm.server.embed_cache.utils import get_shm_name_data, read_shm +from lightllm.server.multimodal_params import ImageItem +from lightllm.utils.log_utils import init_logger +from lightllm.utils.torch_dtype_utils import get_torch_dtype + + +logger = init_logger(__name__) + + +class Gemma4VisionModel: + def __init__(self, data_type="bfloat16"): + self.vision_tower = None + self.embed_vision = None + self.image_processor = None + self.data_type = 
data_type if isinstance(data_type, torch.dtype) else get_torch_dtype(data_type) + self.device = torch.device("cpu") + + def _weight_files(self, weight_dir): + index_path = os.path.join(weight_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + weight_map = json.load(f)["weight_map"] + return sorted(set(weight_map.values())) + return sorted(f for f in os.listdir(weight_dir) if f.endswith(".safetensors")) + + def _load_prefix_state_dict(self, weight_dir, prefix): + state_dict = {} + for file_name in self._weight_files(weight_dir): + file_path = os.path.join(weight_dir, file_name) + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith(prefix): + state_dict[key[len(prefix) :]] = f.get_tensor(key) + return state_dict + + def load_model(self, weight_dir): + try: + from transformers.models.gemma4.modeling_gemma4 import ( + Gemma4MultimodalEmbedder, + Gemma4VisionModel as HFGemma4VisionModel, + ) + except ImportError as e: + raise ImportError("Gemma-4 vision requires a transformers build with Gemma4 support.") from e + + config = AutoConfig.from_pretrained(weight_dir, trust_remote_code=True) + if config.vision_config is None: + raise ValueError("Gemma-4 checkpoint does not contain vision_config") + + processor = AutoProcessor.from_pretrained(weight_dir) + self.image_processor = processor.image_processor + self.vision_tower = HFGemma4VisionModel(config.vision_config).eval() + self.embed_vision = Gemma4MultimodalEmbedder(config.vision_config, config.text_config).eval() + + vision_state = self._load_prefix_state_dict(weight_dir, "model.vision_tower.") + embed_state = self._load_prefix_state_dict(weight_dir, "model.embed_vision.") + missing, unexpected = self.vision_tower.load_state_dict(vision_state, strict=False) + if missing or unexpected: + raise RuntimeError(f"Gemma-4 vision_tower weight mismatch: missing={missing}, unexpected={unexpected}") + missing, unexpected = 
self.embed_vision.load_state_dict(embed_state, strict=False) + if missing or unexpected: + raise RuntimeError(f"Gemma-4 embed_vision weight mismatch: missing={missing}, unexpected={unexpected}") + + return self + + def cuda(self): + self.device = torch.device("cuda") + self.vision_tower = self.vision_tower.cuda() + self.embed_vision = self.embed_vision.cuda() + return self + + def forward(self, pixel_values, image_position_ids): + pixel_values = pixel_values.to(self.device, non_blocking=True) + image_position_ids = image_position_ids.to(self.device, non_blocking=True) + pooling_k = self.vision_tower.config.pooling_kernel_size + pooling_k2 = pooling_k * pooling_k + + # Per-image vision-tower call. `output_length` MUST match the per-image + # num_soft_tokens the image processor declared; otherwise HF's pooler + # falls back to config.image_seq_length and silently emits a different + # token count than what `valid_ids` expects. + per_image_hidden = [] + for i in range(pixel_values.shape[0]): + pv = pixel_values[i : i + 1] + pp = image_position_ids[i : i + 1] + output_length = pv.shape[1] // pooling_k2 + per_image_hidden.append( + self.vision_tower( + pixel_values=pv, + pixel_position_ids=pp, + output_length=output_length, + ).last_hidden_state + ) + + # embed_vision is token-independent (RMSNorm + Linear); cat once and + # project once instead of looping like vllm — same numerics, fewer + # Python launches, lines up naturally with our flat embed-cache output. 
+ flat_hidden = torch.cat(per_image_hidden, dim=0) + target_dtype = self.embed_vision.embedding_projection.weight.dtype + image_features = self.embed_vision(inputs_embeds=flat_hidden.unsqueeze(0).to(target_dtype)).squeeze(0) + return image_features.to(self.data_type) + + @torch.inference_mode() + def encode(self, images: List[ImageItem]): + pil_images = [] + uuids = [] + for img in images: + if not isinstance(img, ImageItem): + raise TypeError(f"Unsupported Gemma-4 image input type: {type(img)}") + uuids.append(img.uuid) + image_data = read_shm(get_shm_name_data(img.uuid)) + with Image.open(BytesIO(image_data)) as image: + pil_images.append(image.convert("RGB")) + + if not pil_images: + return None + + image_inputs = self.image_processor(pil_images, return_tensors="pt") + token_nums = image_inputs.pop("num_soft_tokens_per_image") + pixel_values = image_inputs["pixel_values"] + image_position_ids = image_inputs["image_position_ids"] + + valid_ids = [] + valid_start = 0 + for img, token_num in zip(images, token_nums): + token_num = int(token_num) + if img.token_num != token_num: + raise ValueError(f"Gemma-4 image token mismatch: allocated={img.token_num}, encoded={token_num}") + valid_ids.append([valid_start, valid_start + token_num]) + valid_start += token_num + + all_img_embeds = self.forward(pixel_values, image_position_ids) + if all_img_embeds.shape[0] != valid_start: + raise ValueError( + f"Gemma-4 image embed length mismatch: embeds={all_img_embeds.shape[0]}, tokens={valid_start}" + ) + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/gemma4/infer_struct.py b/lightllm/models/gemma4/infer_struct.py new file mode 100644 index 0000000000..89ad34acbd --- /dev/null +++ b/lightllm/models/gemma4/infer_struct.py @@ -0,0 +1,78 @@ +import torch +from lightllm.common.basemodel import InferStateInfo +from lightllm.models.gemma4.triton_kernel.build_b_image_token_end import build_b_image_token_end + + +class Gemma4InferStateInfo(InferStateInfo): + def 
__init__(self): + super().__init__() + # Gemma-4 uses two RoPE frequency tables (one per layer type): + # * sliding_attention layers: theta=10000, full rotation over head_dim=256 + # * full_attention layers: theta=1_000_000, partial rotation (first 25% of head_dim=512) + self.position_cos_sliding = None + self.position_sin_sliding = None + self.position_cos_full = None + self.position_sin_full = None + # b_image_token_end 用于标记每个 token 在 att 计算时,可以看到的对应的最大长度位置,用于 + # 对于文本token 和 image token 是区别对待的, + # 文本token 对应的位置一定是 0, image token 对应的位置,是该token能看到的最远image token位置。 + # 相当于 image token 部分是双向 att,text token 还是 causal att。 + # 对应一个请求 token list 为 [t, t, i, i, t] 的一个token序列, + # 则对应的 b_image_token_end 为 [0, 0, 4, 4, 0], + # image token 可以看到自己当前这个token以及后面的 image token。 + self.b_image_token_end = None + + def init_some_extra_state(self, model): + super().init_some_extra_state(model) + position_ids = self.position_ids + self.position_cos_sliding = torch.index_select(model._cos_cached_sliding, 0, position_ids).view( + position_ids.shape[0], -1 + ) + self.position_sin_sliding = torch.index_select(model._sin_cached_sliding, 0, position_ids).view( + position_ids.shape[0], -1 + ) + self.position_cos_full = torch.index_select(model._cos_cached_full, 0, position_ids).view( + position_ids.shape[0], -1 + ) + self.position_sin_full = torch.index_select(model._sin_cached_full, 0, position_ids).view( + position_ids.shape[0], -1 + ) + if self.is_prefill: + self.max_seq_len = self.max_kv_seq_len + self._build_b_image_token_end() + return + + def _build_b_image_token_end(self): + device = self.position_ids.device + self.b_image_token_end = torch.zeros(self.position_ids.shape[0], dtype=torch.int32, device=device) + + if not self.multimodal_params: + return + + b_image_start_idx = [] + b_image_len = [] + b_image_nums = [] + b_image_start_num = [] + image_start_num = 0 + for params in self.multimodal_params: + b_image_start_num.append(image_start_num) + images = params.get("images", []) 
+ b_image_nums.append(len(images)) + for img in images: + b_image_start_idx.append(img["start_idx"]) + b_image_len.append(img["token_num"]) + image_start_num += 1 + + if image_start_num == 0: + return + + build_b_image_token_end( + b_image_start_idx=torch.tensor(b_image_start_idx, dtype=torch.int32).cuda(non_blocking=True), + b_image_len=torch.tensor(b_image_len, dtype=torch.int32).cuda(non_blocking=True), + b_image_nums=torch.tensor(b_image_nums, dtype=torch.int32).cuda(non_blocking=True), + b_image_start_num=torch.tensor(b_image_start_num, dtype=torch.int32).cuda(non_blocking=True), + b_q_start_loc=self.b_q_start_loc, + b_ready_cache_len=self.b_ready_cache_len, + b_q_seq_len=self.b_q_seq_len, + b_image_token_end=self.b_image_token_end, + ) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/__init__.py b/lightllm/models/gemma4/layer_infer/__init__.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/continues_batch/__init__.py rename to lightllm/models/gemma4/layer_infer/__init__.py diff --git a/lightllm/models/gemma4/layer_infer/post_layer_infer.py b/lightllm/models/gemma4/layer_infer/post_layer_infer.py new file mode 100644 index 0000000000..22bcf0508d --- /dev/null +++ b/lightllm/models/gemma4/layer_infer/post_layer_infer.py @@ -0,0 +1,20 @@ +import torch +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer + + +class Gemma4PostLayerInfer(LlamaPostLayerInfer): + """ + Same final RMSNorm + tied lm_head path as Llama, with an extra tanh-based + logit softcap at the end: logits = softcap * tanh(logits / softcap). 
+ """ + + def __init__(self, network_config): + super().__init__(network_config) + self.final_logit_softcapping = float(network_config.get("final_logit_softcapping")) + + def token_forward(self, input_embdings, infer_state, layer_weight): + logits = super().token_forward(input_embdings, infer_state, layer_weight) + if self.final_logit_softcapping is not None and self.final_logit_softcapping > 0: + cap = self.final_logit_softcapping + logits = torch.tanh(logits / cap) * cap + return logits diff --git a/lightllm/models/gemma4/layer_infer/pre_layer_infer.py b/lightllm/models/gemma4/layer_infer/pre_layer_infer.py new file mode 100644 index 0000000000..5de3036358 --- /dev/null +++ b/lightllm/models/gemma4/layer_infer/pre_layer_infer.py @@ -0,0 +1,99 @@ +import math +import torch +import torch.distributed as dist +from lightllm.distributed.communication_op import all_reduce +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb + + +class Gemma4PreLayerInfer(LlamaMultimodalPreLayerInfer): + def __init__(self, network_config): + super().__init__(network_config) + self.embed_scale = float(network_config["hidden_size"]) ** 0.5 + self.multimodal_text_embed_scale_ = self.embed_scale + self.pad_token_id_ = network_config.get("pad_token_id", 0) + + self.has_ple = bool(network_config.get("hidden_size_per_layer_input")) + if self.has_ple: + self.num_layers_ = network_config["num_hidden_layers"] + self.ple_dim_ = network_config["hidden_size_per_layer_input"] + self.ple_embed_scale_ = math.sqrt(self.ple_dim_) + self.ple_proj_scale_ = float(network_config["hidden_size"]) ** -0.5 + self.ple_combine_scale_ = 2.0 ** -0.5 + self.rms_norm_eps_ = network_config.get("rms_norm_eps", 1e-6) + self.ple_static_buffer = None + + def 
_compute_per_layer_embeds(self, input_ids_for_ple, input_embdings, infer_state, layer_weight): + # 查表 PLE。 + ple_embeds = layer_weight.embed_tokens_per_layer_weight_(input_ids_for_ple) + if self.tp_world_size_ > 1: + all_reduce(ple_embeds, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + ple_embeds = ple_embeds * self.ple_embed_scale_ + + # 这个分支本质上只对多模态token存在建模上的意义。 + ple_proj = layer_weight.per_layer_model_projection_weight_.mm(input_embdings) + ple_proj = ple_proj * self.ple_proj_scale_ + ple_proj = ple_proj.reshape(*ple_proj.shape[:-1], self.num_layers_, self.ple_dim_) + ple_proj = layer_weight.per_layer_projection_norm_weight_( + input=ple_proj, eps=self.rms_norm_eps_, alloc_func=self.alloc_tensor + ) + ple_embeds = ple_embeds.reshape(*ple_embeds.shape[:-1], self.num_layers_, self.ple_dim_) + + handle_len = input_embdings.shape[0] + torch.add(ple_proj, ple_embeds, out=self.ple_static_buffer[:handle_len]) + self.ple_static_buffer[:handle_len].mul_(self.ple_combine_scale_) + return + + def context_forward(self, input_ids, infer_state, layer_weight): + input_embdings = LlamaMultimodalPreLayerInfer.context_forward(self, input_ids, infer_state, layer_weight) + if self.has_ple: + input_ids_for_ple = input_ids.masked_fill(infer_state.b_image_token_end != 0, self.pad_token_id_) + self._compute_per_layer_embeds(input_ids_for_ple, input_embdings, infer_state, layer_weight) + return input_embdings + + def token_forward(self, input_ids, infer_state, layer_weight): + input_embdings = LlamaPreLayerInfer.token_forward(self, input_ids, infer_state, layer_weight) + input_embdings = input_embdings * self.embed_scale + if self.has_ple: + self._compute_per_layer_embeds(input_ids, input_embdings, infer_state, layer_weight) + return input_embdings + + def _tpsp_sp_split(self, input: torch.Tensor, infer_state): + if self.tp_world_size_ > 1 and get_env_start_args().enable_tpsp_mix_mode: + # SP would need a per-rank slice (N/world_size tokens), but the + # PLE 
static buffer is sized/written for the full N tokens. If you + # ever need SP + PLE, refactor _compute_per_layer_embeds to do an + # sp_pad_copy into a per-rank buffer. + assert not self.has_ple, "gemma4 PLE + enable_tpsp_mix_mode not implemented" + return super()._tpsp_sp_split(input=input, infer_state=infer_state) + return input + + def _multimodal_emb( + self, + out: torch.Tensor, + input_ids: torch.Tensor, + layer_weight, + embed_cache: torch.Tensor, + img_token_lens: torch.Tensor, + img_start_token_ids: torch.Tensor, + img_start_locs_in_cache: torch.Tensor, + ) -> torch.Tensor: + """ + 修改多模态的 embed 计算的细节实现方式,调用本地的 multimodal_text_embed_scale_ 参数。 + """ + multimodal_emb( + out=out, + prompt_ids=input_ids, + text_weight_embs=layer_weight.wte_weight_.weight, + embed_cache=embed_cache, + img_token_lens=img_token_lens, + img_start_token_ids=img_start_token_ids, + img_start_locs_in_cache=img_start_locs_in_cache, + tp_text_start_token_id=layer_weight.wte_weight_.tp_vocab_start_id, + tp_text_end_token_id=layer_weight.wte_weight_.tp_vocab_end_id, + tp_world_size=self.tp_world_size_, + text_embed_scale=self.multimodal_text_embed_scale_, + ) + return diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..015b526fbc --- /dev/null +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -0,0 +1,381 @@ +import math +import torch +import torch.nn as nn + +from lightllm.common.basemodel.attention.base_att import AttControl +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward +from lightllm.models.gemma4.layer_weights.transformer_layer_weight import Gemma4TransformerLayerWeight +from lightllm.models.gemma4.triton_kernel.context_attention_fwd_gemma4_mm import ( + context_attention_fwd_gemma4_mm, +) +from 
lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd + + +class Gemma4TransformerLayerInfer(LlamaTransformerLayerInfer): + """ + Gemma-4 decoder block. Per-layer heterogeneity (sliding vs full attention) + is handled by switching shape / RoPE table / sliding-window flag at init + time. The KV cache layout is uniform (sliding shape: num_kv_heads=16, + head_dim=256); full-attention layers pack their (4, 512) tensor into the + first 8 heads of the 16-head slot at cache-write time, then reshape on + read. See Gemma4TpPartModel._init_mem_manager for context. + """ + + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + self.eps_ = network_config.get("rms_norm_eps", 1e-6) + self.embed_dim_ = network_config["hidden_size"] + self.is_moe = bool(network_config.get("enable_moe_block", False)) + self.num_experts_per_tok = network_config.get("num_experts_per_tok", network_config.get("top_k_experts", 0)) + self.norm_topk_prob = network_config.get("norm_topk_prob", True) + self.router_root_scale = self.embed_dim_ ** -0.5 + + layer_type = network_config["layer_types"][layer_num] + self.is_sliding = layer_type == "sliding_attention" + + # Some E-series checkpoints leave num_global_key_value_heads = null; + # HF treats that as "fall back to num_key_value_heads". + num_global_kv = network_config.get("num_global_key_value_heads") or network_config["num_key_value_heads"] + + # Override parent's head_dim_ (hidden_size/num_heads = 224 on 31B, wrong + # for Gemma-4 — actual is 256 sliding / 512 full). 
+ if self.is_sliding: + self.head_dim_ = network_config["head_dim"] + total_kv_heads = network_config["num_key_value_heads"] + self.k_eq_v = False + else: + self.head_dim_ = network_config["global_head_dim"] + total_kv_heads = num_global_kv + self.k_eq_v = network_config.get("attention_k_eq_v", True) + + # TP shard counts for this layer + self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ + self.tp_k_head_num_ = max(total_kv_heads // self.tp_world_size_, 1) + self.tp_v_head_num_ = self.tp_k_head_num_ + self.tp_o_head_num_ = self.tp_q_head_num_ + + self.kv_cache_slot_dim_ = network_config["head_dim"] + sliding_total = network_config["num_key_value_heads"] * network_config["head_dim"] + full_total = num_global_kv * network_config["global_head_dim"] + per_token_k_width = max(sliding_total, full_total) + assert ( + per_token_k_width % self.kv_cache_slot_dim_ == 0 + ), f"per-token K width {per_token_k_width} not aligned to kv_cache_slot_dim {self.kv_cache_slot_dim_}" + self.kv_cache_slot_num_ = (per_token_k_width // self.kv_cache_slot_dim_) // self.tp_world_size_ + + # Sliding window (None on full-attn layers) + if self.is_sliding: + sw = network_config.get("sliding_window", 0) + self.sliding_window_ = int(sw) if sw else 0 + else: + self.sliding_window_ = 0 + + # E-series Per-Layer Embeddings gate (HF: config.hidden_size_per_layer_input, + # absent or 0 on 31B). + self.has_ple_ = bool(network_config.get("hidden_size_per_layer_input")) + if self.has_ple_: + self.ple_dim_ = network_config["hidden_size_per_layer_input"] + + # HF: config.num_kv_shared_layers (may be missing or null on non-E + # checkpoints — treat as 0). 
+ kv_shared_count = network_config.get("num_kv_shared_layers") or 0 + total_layers = network_config["num_hidden_layers"] + self.is_kv_shared_ = kv_shared_count > 0 and layer_num >= total_layers - kv_shared_count + self.kv_share_target_layer_ = None + if self.is_kv_shared_: + cutoff = total_layers - kv_shared_count + for j in range(cutoff - 1, -1, -1): + if network_config["layer_types"][j] == layer_type: + self.kv_share_target_layer_ = j + break + assert self.kv_share_target_layer_ is not None, ( + f"layer {layer_num} ({layer_type}) is KV-shared but no earlier non-shared " + f"layer of the same type found below cutoff={cutoff}" + ) + + # Always 1.0: NoPE dims for full-attn layers are zero-padded into + # cos/sin (cos=1, sin=0 → identity), so the kernel walks the whole + # head_dim. Don't change to 0.25 — that double-counts with the table. + self.partial_rotary_factor_ = 1.0 + + self.ple_static_buffer = None + + def _rope_cos_sin(self, infer_state): + # Tables are built in the model dtype (Gemma4TpPartModel._init_to_get_rotary_gemma4), + # so they already match q/k dtype — no cast needed. + if self.is_sliding: + return infer_state.position_cos_sliding, infer_state.position_sin_sliding + return infer_state.position_cos_full, infer_state.position_sin_full + + def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: + input = self._tpsp_allgather(input=input, infer_state=infer_state) + + head_dim = self.head_dim_ + q_heads = self.tp_q_head_num_ + kv_heads = self.tp_k_head_num_ + + q = layer_weight.q_proj.mm(input).view(-1, q_heads, head_dim) + q = layer_weight.q_norm_weight_(input=q, eps=self.eps_, alloc_func=self.alloc_tensor) + + cos, sin = self._rope_cos_sin(infer_state) + + if self.is_kv_shared_: + # K/V come from target layer's already-rotated, already-normed cache. 
+ rotary_emb_fwd(q, None, cos, sin, partial_rotary_factor=self.partial_rotary_factor_) + q = q * math.sqrt(head_dim) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + return q, None + + # ---- non-shared: full K/V path ---- + k = layer_weight.k_proj.mm(input).view(-1, kv_heads, head_dim) + if self.k_eq_v: + # Full-attn k_eq_v variant (e.g. 31B): K weights serve as V. + v = k + else: + v = layer_weight.v_proj.mm(input).view(-1, kv_heads, head_dim) + + k = layer_weight.k_norm_weight_(input=k, eps=self.eps_, alloc_func=self.alloc_tensor) + + # V-norm: unweighted RMSNorm over head_dim (matches vllm's Gemma4 has_weight=False). + v = rmsnorm_forward( + x=v, + weight=None, + eps=self.eps_, + out=self.alloc_tensor(v.shape, dtype=v.dtype, device=v.device), + ) + + rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=self.partial_rotary_factor_) + + # Gemma-4 uses scaling=1.0 in attention. The attention kernel hardcodes + # sm_scale = 1/sqrt(head_dim); pre-scale Q by sqrt(head_dim) so the + # kernel's division cancels out, yielding scores = Q @ K^T. + q = q * math.sqrt(head_dim) + + # Pack into the uniform KV-cache layout (N, 2*slot_num, slot_dim). + # K occupies slots [0, used_slots); V occupies + # [slot_num, slot_num + used_slots). If this layer's K/V width is + # smaller than the allocated cache slot width, pad with zeros. 
+ cache_slot_num = self.kv_cache_slot_num_ + cache_slot_dim = self.kv_cache_slot_dim_ + N = k.shape[0] + k_packed = k.reshape(N, -1, cache_slot_dim) + v_packed = v.reshape(N, -1, cache_slot_dim) + used_cache_slots = k_packed.shape[1] + if used_cache_slots == cache_slot_num: + cache_kv = torch.cat([k_packed, v_packed], dim=1) + else: + cache_kv = self.alloc_tensor((N, 2 * cache_slot_num, cache_slot_dim), dtype=k.dtype) + cache_kv.zero_() + cache_kv[:, :used_cache_slots, :] = k_packed + cache_kv[:, cache_slot_num : cache_slot_num + used_cache_slots, :] = v_packed + + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) + + return q, cache_kv + + def _post_cache_kv(self, cache_kv, infer_state, layer_weight): + if self.is_kv_shared_ or cache_kv is None: + return + return super()._post_cache_kv(cache_kv, infer_state, layer_weight) + + # ----- Attention kernels (sliding window + per-layer KV reshape) --- + + def _att_control(self): + if self.is_sliding and self.sliding_window_ > 0: + w = self.sliding_window_ - 1 + return AttControl(use_sliding_window=True, sliding_window=(w, 0)) + return AttControl(use_sliding_window=False, sliding_window=(-1, -1)) + + def _get_layer_kv(self, infer_state: InferStateInfo): + # KV-shared layers read from the target layer's cache slot. + layer_idx = self.kv_share_target_layer_ if self.is_kv_shared_ else self.layer_num_ + _k_raw, _v_raw = infer_state.mem_manager.get_att_input_params(layer_index=layer_idx) + # _k_raw / _v_raw shape (S, cache_slot_num, cache_slot_dim). Use .view + # (not .reshape) so any non-contiguous layout from a future mem_manager + # backend fails loudly instead of silently copying — slice + view is + # O(1) on the standard MemoryManager layout (inner (kv_heads, head_dim) + # span is contiguous). 
+ kv_heads = self.tp_k_head_num_ + head_dim = self.head_dim_ + cache_slot_dim = self.kv_cache_slot_dim_ + used_cache_slots = kv_heads * head_dim // cache_slot_dim + if used_cache_slots == _k_raw.shape[1]: + # Layout already matches this layer's natural shape. + return _k_raw.view(-1, kv_heads, head_dim), _v_raw.view(-1, kv_heads, head_dim) + # Otherwise the K/V live in the first used_cache_slots; the rest is zero pad. + _k = _k_raw[:, :used_cache_slots, :].view(-1, kv_heads, head_dim) + _v = _v_raw[:, :used_cache_slots, :].view(-1, kv_heads, head_dim) + return _k, _v + + def _context_attention_kernel( + self, + q: torch.Tensor, + kv, + infer_state: InferStateInfo, + layer_weight: Gemma4TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + _k, _v = self._get_layer_kv(infer_state) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + if self.is_sliding: + # Sliding layers always go through the gemma4_mm Triton kernel: it + # handles SWA + image bidirectional masking in one pass. + o_tensor = self.alloc_tensor(_q.shape, q.dtype) + sw = (self.sliding_window_ - 1, 0) if self.sliding_window_ > 0 else (-1, -1) + context_attention_fwd_gemma4_mm( + _q, + _k, + _v, + o_tensor, + infer_state.b_req_idx, + infer_state.b_q_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_q_seq_len, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_image_token_end, + sliding_window=sw, + ) + return o_tensor.view(q.shape) + + # Full-attn layers: head_dim=512, no SWA, no image bidi — standard + # triton via backend1. 
+ o_tensor = infer_state.prefill_att_state1.prefill_att( + q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor + ) + return o_tensor.view(q.shape) + + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: InferStateInfo, + layer_weight: Gemma4TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + _k, _v = self._get_layer_kv(infer_state) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + att_state = infer_state.decode_att_state if self.is_sliding else infer_state.decode_att_state1 + o_tensor = att_state.decode_att(q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor) + return o_tensor.view(q.shape) + + # ----- FFN (Gemma gelu-tanh, fused gate_up + down) ----------------- + + def _ffn_dense( + self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + gate_up = layer_weight.gate_up_proj.mm(input) + ffn1 = self.alloc_tensor((input.size(0), gate_up.size(1) // 2), input.dtype) + silu_and_mul_fwd(gate_up, ffn1) + gate_up = None + ffn2 = layer_weight.down_proj.mm(ffn1) + ffn1 = None + ffn2 = self._tpsp_reduce(input=ffn2, infer_state=infer_state) + return ffn2 + + def _router_logits(self, residual, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: + # Mirrors vllm Gemma4Router: unweighted RMSNorm -> 1/sqrt(hidden) -> + # per-channel scale -> bf16xbf16 -> fp32 gate matmul for stable top-k. 
+ x = residual.view(-1, self.embed_dim_) + x = rmsnorm_forward(x=x, weight=None, eps=self.eps_, out=self.alloc_tensor(x.shape, dtype=x.dtype)) + x = x * self.router_root_scale * layer_weight.router_input_scale_.weight + return layer_weight.moe_gate.mm(x.to(torch.float32)) + + def _ffn_moe(self, input, router_logits, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + moe_out = layer_weight.experts.experts( + input, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + is_prefill=infer_state.is_prefill, + ) + moe_out = self._tpsp_reduce(input=moe_out, infer_state=infer_state) + return moe_out + + def _ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + residual = input_embdings + dense_input = layer_weight.pre_feedforward_layernorm_weight_( + input=residual, eps=self.eps_, alloc_func=self.alloc_tensor + ) + dense_out = self._ffn_dense(dense_input, infer_state, layer_weight) + dense_input = None + + if self.is_moe: + dense_out = layer_weight.post_feedforward_layernorm_1_weight_( + input=dense_out, eps=self.eps_, alloc_func=self.alloc_tensor + ) + + router_logits = self._router_logits(residual, layer_weight) + moe_input = layer_weight.pre_feedforward_layernorm_2_weight_( + input=residual, eps=self.eps_, alloc_func=self.alloc_tensor + ) + moe_out = self._ffn_moe(moe_input, router_logits, infer_state, layer_weight) + moe_input = None + router_logits = None + moe_out = layer_weight.post_feedforward_layernorm_2_weight_( + input=moe_out, eps=self.eps_, alloc_func=self.alloc_tensor + ) + dense_out.add_(moe_out) + moe_out = None + + ffn_out = layer_weight.post_feedforward_layernorm_weight_( + input=dense_out, eps=self.eps_, alloc_func=self.alloc_tensor + ) + dense_out = None + 
input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + # ----- block-level forwards (PLE fusion + layer_scalar at the end) ---- + + def _block_epilogue(self, hidden_states, infer_state, layer_weight): + if self.has_ple_: + flat = hidden_states.view(-1, self.embed_dim_) + N = flat.shape[0] + ple_slice = self.ple_static_buffer[:N, self.layer_num_, :] + gate = layer_weight.per_layer_input_gate_.mm(flat) + gated = nn.functional.gelu(gate, approximate="tanh") * ple_slice + contrib = layer_weight.per_layer_projection_.mm(gated) + contrib = layer_weight.post_per_layer_input_norm_weight_( + input=contrib, eps=self.eps_, alloc_func=self.alloc_tensor + ) + flat.add_(contrib) + hidden_states.mul_(layer_weight.layer_scalar_.weight) + return hidden_states + + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_), infer_state, layer_weight) + o = self.context_attention_forward(input1, infer_state, layer_weight) + input1 = None + # Gemma sandwich norm: post_attention_layernorm on the attn branch + # before the residual add, not on the post-add residual stream. 
+ o = self._ffn_norm(o, infer_state, layer_weight) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input_embdings = self._ffn(input_embdings, infer_state, layer_weight) + + return self._block_epilogue(input_embdings, infer_state, layer_weight) + + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_), infer_state, layer_weight) + o = self.token_attention_forward(input1, infer_state, layer_weight) + input1 = None + o = self._ffn_norm(o, infer_state, layer_weight) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input_embdings = self._ffn(input_embdings, infer_state, layer_weight) + + return self._block_epilogue(input_embdings, infer_state, layer_weight) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/__init__.py b/lightllm/models/gemma4/layer_weights/__init__.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/__init__.py rename to lightllm/models/gemma4/layer_weights/__init__.py diff --git a/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..22a2fc4dc7 --- /dev/null +++ b/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,64 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + ROWMMWeight, + RMSNormWeight, +) + + +class Gemma4PreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] + + self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, + 
weight_name="model.language_model.embed_tokens.weight", + data_type=self.data_type_, + ) + # lm_head is tied to input embedding for Gemma-4 (no separate lm_head.weight). + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="lm_head.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_, + ) + + # Gemma-4 uses standard RMSNorm (not the gemma2/3 (1+w) variant). + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name="model.language_model.norm.weight", + data_type=self.data_type_, + ) + + if network_config.get("hidden_size_per_layer_input"): + num_layers = network_config["num_hidden_layers"] + ple_dim = network_config["hidden_size_per_layer_input"] + ple_vocab = network_config.get("vocab_size_per_layer_input", vocab_size) + self.embed_tokens_per_layer_weight_ = EmbeddingWeight( + dim=num_layers * ple_dim, + vocab_size=ple_vocab, + weight_name="model.language_model.embed_tokens_per_layer.weight", + data_type=self.data_type_, + ) + # nn.Linear(in=hidden_size, out=num_layers*ple_dim); HF storage is + # (out, in). Replicated across TP ranks. + self.per_layer_model_projection_weight_ = ROWMMWeight( + in_dim=hidden_size, + out_dims=[num_layers * ple_dim], + weight_names="model.language_model.per_layer_model_projection.weight", + data_type=self.data_type_, + tp_rank=0, + tp_world_size=1, + ) + # RMSNorm over the ple_dim of the projection output. 
+ self.per_layer_projection_norm_weight_ = RMSNormWeight( + dim=ple_dim, + weight_name="model.language_model.per_layer_projection_norm.weight", + data_type=self.data_type_, + ) + return diff --git a/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..977b889f2e --- /dev/null +++ b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py @@ -0,0 +1,275 @@ +import torch +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight, COLMMWeight +from lightllm.common.basemodel.layer_weights.meta_weights import RMSNormWeight, ParameterWeight +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.gemma4_packed_fused_moe_weight import ( + Gemma4PackedFusedMoeWeight, +) +from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight +from lightllm.utils.envs_utils import get_env_start_args + + +class Gemma4TransformerLayerWeight(LlamaTransformerLayerWeight): + def __init__( + self, + layer_num, + data_type, + network_config, + quant_cfg=None, + ): + self._pre_parse_layer_shape(layer_num, network_config) + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _pre_parse_layer_shape(self, layer_num, network_config): + self._is_moe = bool(network_config.get("enable_moe_block", False)) + layer_type = network_config["layer_types"][layer_num] + self._is_sliding = layer_type == "sliding_attention" + # Some E-series checkpoints leave num_global_key_value_heads = null; + # HF treats that as "fall back to num_key_value_heads". 
+ num_global_kv = network_config.get("num_global_key_value_heads") or network_config["num_key_value_heads"] + if self._is_sliding: + self._layer_head_dim = network_config["head_dim"] + self._layer_kv_head_num = network_config["num_key_value_heads"] + self._layer_k_eq_v = False + else: + self._layer_head_dim = network_config["global_head_dim"] + self._layer_kv_head_num = num_global_kv + self._layer_k_eq_v = network_config.get("attention_k_eq_v", True) + + def _parse_config(self): + self.n_head = self.network_config_["num_attention_heads"] + self.q_head_num_ = self.network_config_["num_attention_heads"] + self.k_head_num_ = self._layer_kv_head_num + self.v_head_num_ = self._layer_kv_head_num + self.o_head_num_ = self.q_head_num_ + self.head_dim = self._layer_head_dim + self.n_embed = self.network_config_["hidden_size"] + self.n_inter = self.network_config_["intermediate_size"] + + def _init_weight_names(self): + prefix = f"model.language_model.layers.{self.layer_num_}" + self._q_weight_name = f"{prefix}.self_attn.q_proj.weight" + self._q_bias_name = None + self._k_weight_name = f"{prefix}.self_attn.k_proj.weight" + self._k_bias_name = None + self._v_weight_name = f"{prefix}.self_attn.v_proj.weight" + self._v_bias_name = None + self._o_weight_name = f"{prefix}.self_attn.o_proj.weight" + self._o_bias_name = None + + self._q_norm_weight_name = f"{prefix}.self_attn.q_norm.weight" + self._k_norm_weight_name = f"{prefix}.self_attn.k_norm.weight" + + self._gate_weight_name = f"{prefix}.mlp.gate_proj.weight" + self._up_weight_name = f"{prefix}.mlp.up_proj.weight" + self._down_weight_name = f"{prefix}.mlp.down_proj.weight" + + self._att_norm_weight_name = f"{prefix}.input_layernorm.weight" + self._ffn_norm_weight_name = f"{prefix}.post_attention_layernorm.weight" + self._pre_feedforward_layernorm_name = f"{prefix}.pre_feedforward_layernorm.weight" + self._post_feedforward_layernorm_name = f"{prefix}.post_feedforward_layernorm.weight" + self._post_feedforward_layernorm_1_name 
= f"{prefix}.post_feedforward_layernorm_1.weight" + self._pre_feedforward_layernorm_2_name = f"{prefix}.pre_feedforward_layernorm_2.weight" + self._post_feedforward_layernorm_2_name = f"{prefix}.post_feedforward_layernorm_2.weight" + + self._router_input_scale_name = f"{prefix}.router.scale" + self._router_weight_name = f"{prefix}.router.proj.weight" + + self._layer_scalar_name = f"{prefix}.layer_scalar" + + # E-series Per-Layer Embeddings names (only loaded when PLE enabled). + self._per_layer_input_gate_name = f"{prefix}.per_layer_input_gate.weight" + self._per_layer_projection_name = f"{prefix}.per_layer_projection.weight" + self._post_per_layer_input_norm_name = f"{prefix}.post_per_layer_input_norm.weight" + + def _init_weight(self): + self._init_qkv() + self._init_o() + self._init_ffn() + if self._is_moe: + self._init_moe() + self._init_norm() + if self.network_config_.get("hidden_size_per_layer_input"): + self._init_ple() + + def _init_ple(self): + ple_dim = self.network_config_["hidden_size_per_layer_input"] + hidden_size = self.network_config_["hidden_size"] + self.per_layer_input_gate_ = ROWMMWeight( + in_dim=hidden_size, + out_dims=[ple_dim], + weight_names=self._per_layer_input_gate_name, + data_type=self.data_type_, + tp_rank=0, + tp_world_size=1, + ) + self.per_layer_projection_ = ROWMMWeight( + in_dim=ple_dim, + out_dims=[hidden_size], + weight_names=self._per_layer_projection_name, + data_type=self.data_type_, + tp_rank=0, + tp_world_size=1, + ) + self.post_per_layer_input_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_per_layer_input_norm_name, + data_type=self.data_type_, + ) + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + kv_out_dim = self.k_head_num_ * self.head_dim + + self.q_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=self._q_weight_name, + data_type=self.data_type_, + bias_names=self._q_bias_name, + 
quant_method=self.get_quant_method("q_proj"), + ) + self.k_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[kv_out_dim], + weight_names=self._k_weight_name, + data_type=self.data_type_, + bias_names=self._k_bias_name, + quant_method=self.get_quant_method("k_proj"), + ) + if not self._layer_k_eq_v: + self.v_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[kv_out_dim], + weight_names=self._v_weight_name, + data_type=self.data_type_, + bias_names=self._v_bias_name, + quant_method=self.get_quant_method("v_proj"), + ) + # For k_eq_v layers HF checkpoint has no v_proj weight; the inference + # code aliases v = k at compute time, so no weight object is created. + + def _init_o(self): + in_dim = self.o_head_num_ * self.head_dim + out_dim = self.n_embed + self.o_proj = COLMMWeight( + in_dim=in_dim, + out_dims=[out_dim], + weight_names=self._o_weight_name, + data_type=self.data_type_, + bias_names=self._o_bias_name, + quant_method=self.get_quant_method("o_proj"), + ) + + def _init_ffn(self): + # Packed gate+up: ROWMMWeight stitches `gate_proj` and `up_proj` weights + # along the output dim so the dense FFN runs one matmul + a fused + # gelu*mul kernel (mirrors llama's gate_up_proj path). + self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter, self.n_inter], + weight_names=[self._gate_weight_name, self._up_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("gate_up_proj"), + ) + self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], + weight_names=self._down_weight_name, + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("down_proj"), + ) + + def _init_moe(self): + enable_ep_moe = get_env_start_args().enable_ep_moe + assert not enable_ep_moe, "Gemma-4 MoE packed expert weights currently support TP mode only." 
+ + self.router_input_scale_ = ParameterWeight( + weight_name=self._router_input_scale_name, + data_type=self.data_type_, + weight_shape=(self.n_embed,), + ) + self.moe_gate = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.network_config_["num_experts"]], + weight_names=self._router_weight_name, + data_type=torch.float32, + bias_names=None, + quant_method=self.get_quant_method("moe_gate"), + tp_rank=0, + tp_world_size=1, + ) + self.experts = Gemma4PackedFusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"model.language_model.layers.{self.layer_num_}.experts", + n_routed_experts=self.network_config_["num_experts"], + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=self.network_config_["moe_intermediate_size"], + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + per_expert_scale_name=f"model.language_model.layers.{self.layer_num_}.router.per_expert_scale", + ) + + def _init_norm(self): + hidden_size = self.network_config_["hidden_size"] + # Gemma-4 uses standard RMSNorm (x * rsqrt(var+eps) * w), NOT the + # gemma2/3 (1+w) variant - do not swap in NoTpGEMMANormWeight. 
+ self.q_norm_weight_ = RMSNormWeight( + dim=self._layer_head_dim, + weight_name=self._q_norm_weight_name, + data_type=self.data_type_, + ) + self.k_norm_weight_ = RMSNormWeight( + dim=self._layer_head_dim, + weight_name=self._k_norm_weight_name, + data_type=self.data_type_, + ) + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + ) + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + self.pre_feedforward_layernorm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._pre_feedforward_layernorm_name, + data_type=self.data_type_, + ) + self.post_feedforward_layernorm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_feedforward_layernorm_name, + data_type=self.data_type_, + ) + if self._is_moe: + self.post_feedforward_layernorm_1_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_feedforward_layernorm_1_name, + data_type=self.data_type_, + ) + self.pre_feedforward_layernorm_2_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._pre_feedforward_layernorm_2_name, + data_type=self.data_type_, + ) + self.post_feedforward_layernorm_2_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_feedforward_layernorm_2_name, + data_type=self.data_type_, + ) + self.layer_scalar_ = ParameterWeight( + weight_name=self._layer_scalar_name, + data_type=self.data_type_, + weight_shape=(1,), + ) diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py new file mode 100644 index 0000000000..10b1958b0e --- /dev/null +++ b/lightllm/models/gemma4/model.py @@ -0,0 +1,216 @@ +import os +import json +import torch +from lightllm.models.registry import ModelRegistry +from lightllm.common.basemodel.attention.triton.fp import TritonAttBackend +from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from 
@ModelRegistry("gemma4", is_multimodal=True)
class Gemma4TpPartModel(LlamaTpPartModel):
    """Tensor-parallel Gemma-4 model registered under the ``gemma4`` key.

    Gemma-4 mixes sliding-window layers and full-attention layers that use
    different head dimensions and KV head counts, so this class overrides
    config loading, parameter verification, KV-cache sizing, attention
    backend selection and rotary-table construction of the Llama base.
    """

    # weight classes
    pre_and_post_weight_class = Gemma4PreAndPostLayerWeight
    transformer_weight_class = Gemma4TransformerLayerWeight

    # inference classes
    pre_layer_infer_class = Gemma4PreLayerInfer
    transformer_layer_infer_class = Gemma4TransformerLayerInfer
    post_layer_infer_class = Gemma4PostLayerInfer

    infer_state_class = Gemma4InferStateInfo

    def __init__(self, kvargs):
        # head_dim_ is used by the default _init_to_get_rotary which we
        # override; still set it to the sliding-layer head_dim for consistency
        # with the mem manager and any generic helpers.
        self.head_dim_ = 256
        super().__init__(kvargs)
        return

    def _init_config(self):
        """Load ``config.json``, flatten ``text_config`` and fill MoE defaults."""
        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
            self.config = json.load(json_file)
        # The shipped checkpoint is a multimodal config wrapping a Gemma4TextConfig
        # under text_config; flatten it so downstream code sees text-model fields
        # at the top level (mirrors the gemma3 approach).
        if "text_config" in self.config:
            self.config = self.config["text_config"].copy()

        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
        self._reset_num_key_value_heads()

        if self.finetune_config:
            self.config["vocab_size"] = self.finetune_config.vocab_size

        if self.config.get("enable_moe_block", False):
            # LightLLM's MoE helpers use Qwen/DeepSeek-style field names.
            # Gemma-4 checkpoints expose equivalent values as top_k_experts
            # and moe_intermediate_size.
            # Only read top_k_experts when the alias is actually missing, so a
            # config that already carries num_experts_per_tok (and no
            # top_k_experts) does not fail with a KeyError.
            if "num_experts_per_tok" not in self.config:
                self.config["num_experts_per_tok"] = self.config["top_k_experts"]
            self.config.setdefault("norm_topk_prob", True)
            self.config.setdefault("scoring_func", "softmax")
        return

    def _verify_params(self):
        """Check load format and that every head count divides the TP world size."""
        assert self.load_way == "HF", "Gemma-4 only supports HF format."
        assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
        assert self.config["num_key_value_heads"] % self.tp_world_size_ == 0
        # Use `or` rather than the dict.get default: E4B-style configs ship
        # `num_global_key_value_heads: null`, which the default form would
        # leave as None.
        num_global_kv = self.config.get("num_global_key_value_heads") or self.config["num_key_value_heads"]
        assert (
            num_global_kv % self.tp_world_size_ == 0
        ), f"num_global_key_value_heads={num_global_kv} must be divisible by tp={self.tp_world_size_}"
        kv_shared = self.config.get("num_kv_shared_layers") or 0
        assert 0 <= kv_shared < self.config["num_hidden_layers"], (
            f"num_kv_shared_layers={kv_shared} out of range for "
            f"num_hidden_layers={self.config['num_hidden_layers']}"
        )
        return

    def _init_mem_manager(self):
        """Create the KV-cache manager with one uniform per-layer slot layout."""
        # Uniform per-layer KV cache layout. The per-layer cache slot must fit
        # whichever layer type has the largest per-token K/V width: sliding
        # (num_key_value_heads * head_dim) or full
        # (num_global_kv * global_head_dim). Keep cache_slot_dim = head_dim
        # and pick cache_slot_num = max-width / head_dim. For 31B this
        # collapses to num_key_value_heads; for E4B the full-attn shape wins
        # (2*512 > 2*256), so it uses 4 storage slots of 256 dims.
        # Gemma4TransformerLayerInfer.__init__ computes the same value and
        # uses it to pack/unpack K/V at write/read time.
        head_dim = self.config["head_dim"]
        num_global_kv = self.config.get("num_global_key_value_heads") or self.config["num_key_value_heads"]
        sliding_total = self.config["num_key_value_heads"] * head_dim
        full_total = num_global_kv * self.config["global_head_dim"]
        per_token_k_width = max(sliding_total, full_total)
        # Floor division below would silently under-size the slot if the widest
        # layer does not pack evenly; fail loudly instead.
        assert (
            per_token_k_width % head_dim == 0
        ), f"per-token K width {per_token_k_width} is not a multiple of head_dim={head_dim}"
        total_slots = per_token_k_width // head_dim
        assert (
            total_slots % self.tp_world_size_ == 0
        ), f"KV slot count {total_slots} must be divisible by tp={self.tp_world_size_}"
        head_num_per_rank = total_slots // self.tp_world_size_
        self.mem_manager = select_mem_manager_class()(
            self.max_total_token_num,
            dtype=self.data_type,
            head_num=head_num_per_rank,
            head_dim=head_dim,
            layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(),
            mem_fraction=self.mem_fraction,
        )
        return

    def _init_att_backend(self):
        # Gemma-4 has per-layer heterogeneous attention: sliding layers use
        # (head_dim=256, kv_heads=16); full-attn layers use (head_dim=512,
        # kv_heads=4, k_eq_v). FA3 caps head_dim at 256 and flashinfer plans
        # once per infer_state on a single shape — both unworkable for the
        # heterogeneous layout. Both layer kinds go through triton.
        #
        # Primary backend = sliding layers. Sliding prefill bypasses the
        # backend and calls gemma4_mm directly (SWA + image bidi in one
        # pass); the prefill_att_state created here is unused but the
        # framework requires prefill_att_backend to be non-None.
        self.prefill_att_backend = TritonAttBackend(model=self)
        self.decode_att_backend = TritonAttBackend(model=self)

    def _init_att_backend1(self):
        # Secondary backend = full-attn layers (head_dim=512, plain causal).
        self.prefill_att_backend1 = TritonAttBackend(model=self)
        self.decode_att_backend1 = TritonAttBackend(model=self)

    def _init_custom(self):
        """Build rotary tables, the DeepEP group (MoE only) and the PLE buffer."""
        self._init_to_get_rotary_gemma4()
        if self.config.get("enable_moe_block", False):
            dist_group_manager.new_deepep_group(
                self.config["num_experts"],
                self.config["hidden_size"],
                self.config.get("num_experts_per_tok", self.config.get("top_k_experts", 1)),
                self.config.get("moe_intermediate_size", self.config.get("intermediate_size")),
            )
        self._init_ple_static_buffer()

    def _init_ple_static_buffer(self):
        """Allocate one zeroed buffer shared by the pre-layer and every layer.

        Does nothing when ``hidden_size_per_layer_input`` is absent, null or
        non-positive.
        """
        ple_dim = self.config.get("hidden_size_per_layer_input") or 0
        if ple_dim <= 0:
            return
        args = get_env_start_args()
        # Size the token axis by the largest of the three configured bounds.
        max_tokens = max(
            int(self.batch_max_tokens or 0),
            int(self.graph_max_batch_size or 0),
            int(getattr(args, "prefill_cudagraph_max_handle_token", 0) or 0),
        )
        assert max_tokens > 0, "PLE static buffer needs a positive max-token bound"
        num_layers = self.config["num_hidden_layers"]
        buf = torch.zeros((max_tokens, num_layers, ple_dim), dtype=self.data_type, device="cuda")
        self.pre_infer.ple_static_buffer = buf
        for layer_infer in self.layers_infer:
            layer_infer.ple_static_buffer = buf
        logger.info(
            f"Allocated PLE static buffer: tokens={max_tokens}, layers={num_layers}, "
            f"ple_dim={ple_dim}, dtype={self.data_type}"
        )

    def _init_to_get_rotary_gemma4(self):
        """Precompute separate cos/sin tables for sliding and full-attention layers."""
        # Gemma-4 does not currently support dp prefill balance.
        assert self.args.enable_dp_prefill_balance is False, "Gemma-4 does not support dp prefill balance"

        rope_params = self.config["rope_parameters"]

        # Cap the rotary table at something we can fit in memory — Gemma-4's
        # advertised max_position_embeddings is 262144 which would require
        # ~200MB per table in fp32. Rely on the server's max_seq_length instead.
        max_seq_len = max(self.max_seq_length + 1024, 16384)

        t = torch.arange(max_seq_len, dtype=torch.float32, device="cpu")

        # Sliding layers: default RoPE, theta=10000, full rotation over head_dim=256.
        sliding_params = rope_params["sliding_attention"]
        sliding_head_dim = self.config["head_dim"]
        sliding_theta = sliding_params["rope_theta"]
        sliding_partial = sliding_params.get("partial_rotary_factor", 1.0)
        sliding_rot_dim = int(sliding_head_dim * sliding_partial)
        inv_freq_sliding = 1.0 / (
            sliding_theta ** (torch.arange(0, sliding_rot_dim, 2, dtype=torch.float32) / sliding_rot_dim)
        )
        freqs_s = torch.outer(t, inv_freq_sliding)
        self._cos_cached_sliding = torch.cos(freqs_s).to(self.data_type).cuda()
        self._sin_cached_sliding = torch.sin(freqs_s).to(self.data_type).cuda()

        # Full-attention layers: proportional RoPE, theta=1_000_000,
        # partial_rotary_factor=0.25 over global_head_dim=512.
        # Proportional semantics (HF transformers):
        #   rope_angles = int(partial * head_dim // 2)  -> 64
        #   inv_freq[0:rope_angles] = 1 / base ** (arange(0, 2*rope_angles, 2) / head_dim)
        #   inv_freq[rope_angles:head_dim//2] = 0  (identity rotation for "no-pe" dims)
        full_params = rope_params["full_attention"]
        full_head_dim = self.config["global_head_dim"]
        full_theta = full_params["rope_theta"]
        full_partial = full_params.get("partial_rotary_factor", 1.0)
        rope_type = full_params.get("rope_type", "default")
        if rope_type == "proportional":
            rope_angles = int(full_partial * full_head_dim // 2)
            inv_freq_rot = 1.0 / (
                full_theta ** (torch.arange(0, 2 * rope_angles, 2, dtype=torch.float32) / full_head_dim)
            )
            nope_angles = full_head_dim // 2 - rope_angles
            if nope_angles > 0:
                inv_freq_full = torch.cat([inv_freq_rot, torch.zeros(nope_angles, dtype=torch.float32)])
            else:
                inv_freq_full = inv_freq_rot
        else:
            full_rot_dim = int(full_head_dim * full_partial)
            inv_freq_full = 1.0 / (full_theta ** (torch.arange(0, full_rot_dim, 2, dtype=torch.float32) / full_rot_dim))

        freqs_f = torch.outer(t, inv_freq_full)
        self._cos_cached_full = torch.cos(freqs_f).to(self.data_type).cuda()
        self._sin_cached_full = torch.sin(freqs_f).to(self.data_type).cuda()
        return
class Gemma4Tokenizer(BaseMultiModalTokenizer):
    """Multimodal tokenizer wrapper that expands Gemma-4 image placeholders.

    Every run of image placeholder tokens in the prompt is replaced by
    ``<boi>`` + the image's reserved token-id range + ``<eoi>``.
    """

    def __init__(self, tokenizer, model_cfg, image_processor=None):
        super().__init__(tokenizer)
        self.image_processor = image_processor

        # Special token ids, with fallbacks when the config omits them.
        self.image_token_index = model_cfg.get("image_token_id", 258880)
        self.boi_token_index = model_cfg.get("boi_token_id", 255999)
        self.eoi_token_index = model_cfg.get("eoi_token_id", 258882)

        # Image sizing parameters; processor attributes win over defaults.
        self.image_length = model_cfg.get("vision_soft_tokens_per_image", 280)
        self.patch_size = getattr(self.image_processor, "patch_size", 16)
        self.pooling_kernel_size = getattr(self.image_processor, "pooling_kernel_size", 3)
        self.max_soft_tokens = getattr(self.image_processor, "max_soft_tokens", self.image_length)

        # HF Gemma-4 tokenizer does not prepend BOS even with add_special_tokens=True.
        self.bos_token_id = tokenizer.bos_token_id

    def init_imageitem_extral_params(
        self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
    ):
        # Nothing extra to attach to image items.
        return

    def init_audioitem_extral_params(
        self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
    ):
        raise NotImplementedError

    def get_image_token_length(self, img: ImageItem):
        """Return the number of soft tokens the image occupies in the prompt.

        Falls back to the fixed ``image_length`` when there is no image
        processor or the image has a non-positive width/height.
        """
        if self.image_processor is None or img.image_w <= 0 or img.image_h <= 0:
            return self.image_length

        patch = self.patch_size
        kernel = self.pooling_kernel_size
        cell = patch * kernel  # side length of one pooled cell, in pixels
        patches_before = (img.image_h / patch) * (img.image_w / patch)
        scale = math.sqrt(self.max_soft_tokens * kernel ** 2 / patches_before)

        # Scale each side, then snap down to a whole number of cells (min 1).
        resized = []
        for side in (img.image_h, img.image_w):
            resized.append(max(cell, int(math.floor(side * scale / cell)) * cell))
        target_h, target_w = resized

        pooled = ((target_h // patch) * (target_w // patch)) // kernel ** 2
        return min(pooled, self.max_soft_tokens)

    def get_audio_token_length(self, audio: AudioItem):
        raise NotImplementedError

    def encode(self, prompt, multimodal_params: MultimodalParams = None, add_special_tokens=False):
        """Tokenize ``prompt`` and splice each image's token-id range in place.

        Raises:
            ValueError: when the prompt has more placeholder runs than images,
                or fewer placeholder runs than images.
        """
        origin_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids
        needs_bos = (
            add_special_tokens
            and self.bos_token_id is not None
            and (len(origin_ids) == 0 or origin_ids[0] != self.bos_token_id)
        )
        if needs_bos:
            origin_ids = [self.bos_token_id] + origin_ids

        if multimodal_params is None:
            images = []
        else:
            images = getattr(multimodal_params, "images", [])
        if not images:
            return origin_ids

        total = len(origin_ids)
        merged = []
        consumed = 0  # number of images already spliced in
        cursor = 0  # next unread position in origin_ids
        while True:
            try:
                run_start = origin_ids.index(self.image_token_index, cursor)
            except ValueError:
                break

            # Copy the text before the placeholder run, then measure the run.
            merged.extend(origin_ids[cursor:run_start])
            run_end = run_start + 1
            while run_end < total and origin_ids[run_end] == self.image_token_index:
                run_end += 1
            if consumed >= len(images):
                raise ValueError("image token error")

            img = images[consumed]
            # Add <boi> unless the prompt already supplied it just before.
            if not merged or merged[-1] != self.boi_token_index:
                merged.append(self.boi_token_index)
            img.start_idx = len(merged)
            merged.extend(range(img.token_id, img.token_id + img.token_num))
            merged.append(self.eoi_token_index)

            # Swallow an <eoi> the prompt already carried after the run.
            if run_end < total and origin_ids[run_end] == self.eoi_token_index:
                run_end += 1
            cursor = run_end
            consumed += 1

        merged.extend(origin_ids[cursor:])
        if len(images) != consumed:
            raise ValueError(f"invalid image tag num: {len(images)} vs {consumed}!")
        return merged
@triton.jit
def _build_b_image_token_end_kernel(
    B_Image_Start_Idx,  # (num_imgs,) int32, image span start in absolute request position
    B_Image_Len,  # (num_imgs,) int32, image token count
    B_Image_Nums,  # (batch,) int32, per-batch image count
    B_Image_Start_Num,  # (batch,) int32, prefix-sum offset into flat per-image arrays
    B_Q_Start_Loc,  # (batch,) int32, per-batch start in flat layout
    B_Ready_Cache_Len,  # (batch,) int32, per-batch prompt-cache length
    B_Q_Seq_Len,  # (batch,) int32, per-batch new-token count
    B_Image_Token_End,  # (sum_q,) int32, output scatter target
    BLOCK_SIZE: tl.constexpr,
):
    # One program per batch entry: for each of its images, write the image's
    # end index into every flat Q slot that the image covers in this step.
    pid = tl.program_id(0)
    cached = tl.load(B_Ready_Cache_Len + pid)
    new_len = tl.load(B_Q_Seq_Len + pid)
    n_images = tl.load(B_Image_Nums + pid)
    first_image = tl.load(B_Image_Start_Num + pid)
    q_base = tl.load(B_Q_Start_Loc + pid)

    for img in range(n_images):
        span_start = tl.load(B_Image_Start_Idx + first_image + img)
        span_len = tl.load(B_Image_Len + first_image + img)
        span_end = span_start + span_len

        for base in range(0, span_len, BLOCK_SIZE):
            lane = base + tl.arange(0, BLOCK_SIZE)
            # Offset of each image token relative to this step's first new token.
            rel = span_start - cached + lane
            # Keep lanes that are inside the image and inside the new-token
            # range, i.e. neither already cached nor past this chunk's end.
            keep = (lane < span_len) & (rel >= 0) & (rel < new_len)
            tl.store(B_Image_Token_End + q_base + rel, span_end, mask=keep)
def build_b_image_token_end(
    b_image_start_idx: torch.Tensor,
    b_image_len: torch.Tensor,
    b_image_nums: torch.Tensor,
    b_image_start_num: torch.Tensor,
    b_q_start_loc: torch.Tensor,
    b_ready_cache_len: torch.Tensor,
    b_q_seq_len: torch.Tensor,
    b_image_token_end: torch.Tensor,
):
    """Scatter image-end markers into ``b_image_token_end`` in place.

    Launches one kernel program per batch entry. The output tensor is only
    written at positions covered by an image; all other entries keep whatever
    value the caller initialised them with.

    Args:
        b_image_start_idx: per-image span start (absolute request position).
        b_image_len: per-image token count.
        b_image_nums: per-batch image count.
        b_image_start_num: per-batch offset into the flat per-image arrays.
        b_q_start_loc: per-batch start in the flat new-token layout.
        b_ready_cache_len: per-batch prompt-cache length.
        b_q_seq_len: per-batch new-token count.
        b_image_token_end: output tensor, modified in place.
    """
    batch_size = b_q_start_loc.shape[0]
    assert b_image_nums.shape[0] == batch_size
    # Nothing to scatter: skip the launch entirely. This also avoids asking
    # the runtime for a zero-sized grid when the batch is empty.
    if batch_size == 0 or b_image_start_idx.numel() == 0:
        return
    grid = (batch_size,)
    BLOCK_SIZE = 64
    _build_b_image_token_end_kernel[grid](
        b_image_start_idx,
        b_image_len,
        b_image_nums,
        b_image_start_num,
        b_q_start_loc,
        b_ready_cache_len,
        b_q_seq_len,
        b_image_token_end,
        BLOCK_SIZE=BLOCK_SIZE,
    )


# ---------------------------------------------------------------------------
# Standalone correctness check
# ---------------------------------------------------------------------------


def _reference(
    multimodal_params,
    b_q_start_loc_cpu,
    b_ready_cache_len_cpu,
    b_q_seq_len_cpu,
    sum_q,
):
    """CPU reference: return the expected ``b_image_token_end`` tensor.

    ``multimodal_params`` is a per-batch list of dicts with an optional
    ``"images"`` list whose items carry ``"start_idx"`` and ``"token_num"``.
    """
    out = torch.zeros((sum_q,), dtype=torch.int32)
    for batch_idx, params in enumerate(multimodal_params):
        cache_len = b_ready_cache_len_cpu[batch_idx]
        new_len = b_q_seq_len_cpu[batch_idx]
        flat_start = b_q_start_loc_cpu[batch_idx]
        for img in params.get("images", []):
            image_start_idx = img["start_idx"]
            image_end_idx = image_start_idx + img["token_num"]
            # Clip the image span to this step's new-token range, then fill
            # the surviving slice in one assignment.
            lo = max(image_start_idx - cache_len, 0)
            hi = min(image_end_idx - cache_len, new_len)
            if lo < hi:
                out[flat_start + lo : flat_start + hi] = image_end_idx
    return out
def _check():
    """Compare the GPU scatter against the CPU reference on a fixed case."""
    device = "cuda"
    # Two batches. b0 has 1 image overlapping new tokens; b1 has 2 images, one
    # fully cached and one in the new-token range.
    multimodal = [
        {"images": [{"start_idx": 5, "token_num": 4}]},  # b0: image at req[5..9)
        {
            "images": [
                {"start_idx": 0, "token_num": 3},  # fully cached
                {"start_idx": 8, "token_num": 5},  # in new tokens
            ]
        },
    ]
    b_q_start_loc = torch.tensor([0, 6], dtype=torch.int32)  # b0 new=6, b1 new=10
    b_ready_cache_len = torch.tensor([2, 5], dtype=torch.int32)
    b_q_seq_len = torch.tensor([6, 10], dtype=torch.int32)
    sum_q = int(b_q_seq_len.sum().item())

    ref = _reference(
        multimodal,
        b_q_start_loc.tolist(),
        b_ready_cache_len.tolist(),
        b_q_seq_len.tolist(),
        sum_q,
    )

    # Flatten the per-batch image metadata into the arrays the kernel reads.
    counts = [len(params["images"]) for params in multimodal]
    offsets = []
    running = 0
    for count in counts:
        offsets.append(running)
        running += count
    starts = [img["start_idx"] for params in multimodal for img in params["images"]]
    lengths = [img["token_num"] for params in multimodal for img in params["images"]]

    def to_dev(values):
        return torch.tensor(values, dtype=torch.int32, device=device)

    out_gpu = torch.zeros((sum_q,), dtype=torch.int32, device=device)
    build_b_image_token_end(
        b_image_start_idx=to_dev(starts),
        b_image_len=to_dev(lengths),
        b_image_nums=to_dev(counts),
        b_image_start_num=to_dev(offsets),
        b_q_start_loc=b_q_start_loc.to(device),
        b_ready_cache_len=b_ready_cache_len.to(device),
        b_q_seq_len=b_q_seq_len.to(device),
        b_image_token_end=out_gpu,
    )

    out_cpu = out_gpu.cpu()
    assert torch.equal(out_cpu, ref), f"\n got {out_cpu.tolist()}\n ref {ref.tolist()}"
    print("ok", out_cpu.tolist())


if __name__ == "__main__":
    if torch.cuda.is_available():
        _check()
    else:
        print("No CUDA, skip.")
0000000000..dee10e96d3 --- /dev/null +++ b/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py @@ -0,0 +1,488 @@ +"""Gemma-4 prefill attention kernel with image bidirectional masking. + +Gemma-4 was trained with bidirectional attention inside each image span on its +sliding-window layers (matches HF/vllm `use_bidirectional_attention="vision"`). +Other lightllm multimodal models use causal attention on image tokens, so the +shared prefill kernel does not need this — keep the modification scoped to +this gemma4-private file rather than the common path. + +The kernel mirrors `context_flashattention_nopad._fwd_kernel` (paged KV via +req_to_token_indexs, prompt_cache_len for chunked prefill, sliding window +support, head_dim=256/512 with BLOCK_M reduction) and adds two ideas borrowed +from `lightllm-neo/.../context_attention_fwd_neo`: + +1. Per-Q `b_image_token_end` tensor of shape (sum_q,). For Q tokens inside an + image span it carries the span's end index; for text tokens it is 0. + The attention mask becomes `local_or_causal_mask | (k_pos < q_image_end)`. +2. K/V iteration upper bound is extended to `max(local_end, block_image_end)` + so a Q tile in the middle of an image span actually loads K/V tiles past + its causal end. Without this, the bidi mask in the original diff was a + no-op on every tile but the last one of the image span. + +The standalone `reference_attention` and `check_once` are runnable as a script +for unit testing image bidi correctness. 
# Prefill attention over paged KV with optional sliding window and per-Q image
# bidirectional masking. Grid: (cdiv(max_new_tokens, BLOCK_M), batch * H).
# Softmax is evaluated with exp2, so `sm_scale` must already include log2(e).
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    Out,
    B_Start_Loc,
    B_Seqlen,
    Req_to_tokens,
    B_req_idx,
    B_Image_Token_End,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_req_to_tokens_b,
    stride_req_to_tokens_s,
    kv_group_num,
    b_prompt_cache_len,
    H: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_SLIDING_WINDOW: tl.constexpr,
    SLIDING_WINDOW_LEFT: tl.constexpr,
):
    start_m = tl.program_id(0)
    cur_bh = tl.program_id(1)
    cur_batch = cur_bh // H
    cur_head = cur_bh % H

    cur_kv_head = cur_head // kv_group_num

    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)
    total_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_seq_len = total_len - prompt_cache_len  # new tokens this step
    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)

    block_start_loc = BLOCK_M * start_m
    if block_start_loc >= cur_batch_seq_len:
        return

    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = block_start_loc + tl.arange(0, BLOCK_M)
    q_valid = offs_m < cur_batch_seq_len

    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :] * stride_qd
    )
    q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0)

    # Per-Q image_end. 0 for non-image tokens, image-span end for image tokens.
    q_image_end = tl.load(
        B_Image_Token_End + cur_batch_in_all_start_index + offs_m,
        mask=q_valid,
        other=0,
    ).to(tl.int32)

    # Absolute position in the request (prompt_cache_len + offset within new tokens).
    q_pos = prompt_cache_len + offs_m  # [M]

    # Online-softmax running state: row max, row sum, weighted-V accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # K/V upper bound for this Q tile: the causal end, extended to the furthest
    # image end seen in the tile, but never past the KV that actually exists.
    causal_end = tl.minimum(prompt_cache_len + block_start_loc + BLOCK_M, total_len)
    block_image_end = tl.minimum(tl.max(q_image_end, axis=0), total_len)
    block_end_loc = tl.maximum(causal_end, block_image_end)

    if USE_SLIDING_WINDOW:
        kv_start_index = block_start_loc + prompt_cache_len - SLIDING_WINDOW_LEFT
        kv_start_index = tl.maximum(kv_start_index, 0)
        block_kv_len = block_end_loc - kv_start_index
    else:
        kv_start_index = 0
        block_kv_len = block_end_loc

    for start_n in range(0, block_kv_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        k_pos = kv_start_index + start_n + offs_n  # [N]
        k_valid = k_pos < block_end_loc

        kv_loc = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos,
            mask=k_valid,
            other=0,
        ).to(tl.int64)

        off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
        k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0)
        qk = tl.dot(q, k)

        if USE_SLIDING_WINDOW:
            # Sliding window: FA-style left inclusive offset + causal (right=0).
            local_mask = ((q_pos[:, None] - k_pos[None, :]) <= SLIDING_WINDOW_LEFT) & (
                q_pos[:, None] >= k_pos[None, :]
            )
        else:
            local_mask = q_pos[:, None] >= k_pos[None, :]
        # Image bidi: a Q in image span [_, e) attends to all K with k_pos < e.
        # For text Q (q_image_end == 0) this is k_pos < 0 = always False, so
        # the union with local_mask leaves text-attention unchanged.
        image_mask = k_pos[None, :] < q_image_end[:, None]
        # q_image_end is not clamped to total_len, so when an image span runs
        # past the KV available in this step the image mask alone would admit
        # padded lanes (k loaded as 0, qk == 0) and inflate the softmax
        # denominator. Gate the whole mask on k_valid.
        mask = (local_mask | image_mask) & k_valid[None, :]

        qk = tl.where(mask, qk * sm_scale, -1.0e8)

        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)

        # Rescale the running sum/accumulator to the new row max.
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha[:, None]

        off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
        v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0)
        p = p.to(v.dtype)
        acc = tl.dot(p, v, acc)

        m_i = m_ij

    acc = acc / l_i[:, None]
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :] * stride_od
    )
    tl.store(Out + off_o, acc, mask=q_valid[:, None])
@torch.no_grad()
def context_attention_fwd_gemma4_mm(
    q,
    k,
    v,
    o,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    b_prompt_cache_len,
    max_input_len,
    req_to_token_indexs,
    b_image_token_end,
    sliding_window=(-1, -1),
):
    """Prefill attention with image bidirectional masking on sliding layers.

    Writes the result into ``o`` in place and returns nothing.

    Args:
        q, k, v, o: tensors whose last dimension is the head dim; all three of
            q/k/v must share it and it must be one of 16/32/64/128/256/512.
        b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len: per-batch
            request index, flat Q start, total length and cached length.
        max_input_len: upper bound on new tokens per batch, used for the grid.
        req_to_token_indexs: request-position to KV-slot table.
        sliding_window: ``(-1, -1)`` disables SWA; otherwise ``(left, 0)`` with
            FA-style inclusive left offset and causal right bound (right must
            be 0). Any two-element sequence (tuple or list) is accepted.
        b_image_token_end: int32 tensor of shape (sum_q,). For each Q token
            position (in the flattened new-token layout), value is the image
            span's end index (in absolute request position) if the token is
            inside an image span, else 0.
    """
    BLOCK_M = 128 if not is_tesla() else 64
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128, 256, 512}
    # Larger head dims get smaller Q tiles.
    if Lk >= 512:
        BLOCK_M = min(BLOCK_M, 32)
    elif Lk >= 256:
        BLOCK_M = min(BLOCK_M, 64)

    # 1/sqrt(d) folded together with log2(e): the kernel uses exp2.
    sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634
    batch, head = b_seq_len.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k.shape[1]

    grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1)
    BLOCK_N = BLOCK_M
    num_warps = 4 if Lk <= 64 else 8
    num_stages = 1

    # Normalise first: a list such as [-1, -1] never compares equal to the
    # tuple (-1, -1) and would otherwise be misread as an enabled window.
    window = tuple(int(x) for x in sliding_window)
    if window == (-1, -1):
        use_sliding_window = False
        sliding_window_left = -1
    else:
        use_sliding_window = True
        assert window[1] == 0, "sliding_window right must be 0"
        sliding_window_left = window[0]

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        o,
        b_start_loc,
        b_seq_len,
        req_to_token_indexs,
        b_req_idx,
        b_image_token_end,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        req_to_token_indexs.stride(0),
        req_to_token_indexs.stride(1),
        kv_group_num=kv_group_num,
        b_prompt_cache_len=b_prompt_cache_len,
        H=head,
        BLOCK_DMODEL=Lk,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        USE_SLIDING_WINDOW=use_sliding_window,
        SLIDING_WINDOW_LEFT=sliding_window_left,
        num_warps=num_warps,
        num_stages=num_stages,
    )
def reference_attention(
    q,
    k,
    v,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    b_prompt_cache_len,
    req_to_token_indexs,
    b_image_token_end,
    sliding_window=(-1, -1),
):
    """Slow torch reference for the gemma4 mm prefill kernel.

    `sliding_window` is (left, 0) using FA-style inclusive left offset with causal
    right bound. (-1, -1) disables SWA. Any two-element sequence (tuple or
    list) is accepted.

    Returns a tensor shaped like ``q``. Rows of ``q`` that no batch entry
    covers are left uninitialised.
    """
    device = q.device
    dtype = q.dtype
    sum_q, Hq, D = q.shape
    Hk = k.shape[1]
    kv_group_num = Hq // Hk

    out = torch.empty_like(q)
    scale = 1.0 / math.sqrt(D)

    # Normalise first: a list such as [-1, -1] never compares equal to the
    # tuple (-1, -1) and would otherwise be misread as an enabled window.
    window = tuple(int(x) for x in sliding_window)
    if window == (-1, -1):
        use_sliding_window = False
        sliding_window_left = 0
    else:
        use_sliding_window = True
        sliding_window_left = window[0]
        assert window[1] == 0, "sliding_window right must be 0"

    batch = b_seq_len.shape[0]
    for b in range(batch):
        req = int(b_req_idx[b].item())
        total_len = int(b_seq_len[b].item())
        prompt_len = int(b_prompt_cache_len[b].item())
        new_len = total_len - prompt_len
        q_start = int(b_start_loc[b].item())

        q_blk = q[q_start : q_start + new_len]  # [M, Hq, D]
        q_image_end = b_image_token_end[q_start : q_start + new_len].to(torch.int64)  # [M]

        # Gather this request's K/V rows through the paging table.
        token_locs = req_to_token_indexs[req, :total_len].to(torch.int64)
        k_blk = k[token_locs]
        v_blk = v[token_locs]

        # Expand KV heads so every Q head has a matching K/V head.
        k_hq = k_blk.repeat_interleave(kv_group_num, dim=1)
        v_hq = v_blk.repeat_interleave(kv_group_num, dim=1)

        q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64)
        k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64)

        if use_sliding_window:
            causal = ((q_pos[:, None] - k_pos[None, :]) <= sliding_window_left) & (q_pos[:, None] >= k_pos[None, :])
        else:
            causal = k_pos[None, :] <= q_pos[:, None]
        # Image tokens additionally see every key before their span's end.
        image = k_pos[None, :] < q_image_end[:, None]
        allow = causal | image

        q_t = q_blk.permute(1, 0, 2).to(torch.float32)
        k_t = k_hq.permute(1, 2, 0).to(torch.float32)
        scores = torch.matmul(q_t, k_t) * scale

        neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32)
        scores = torch.where(allow[None, :, :], scores, neg)
        p = torch.softmax(scores, dim=-1)
        v_t = v_hq.permute(1, 0, 2).to(torch.float32)
        out_hq = torch.matmul(p, v_t)
        out[q_start : q_start + new_len] = out_hq.permute(1, 0, 2).to(dtype)

    return out
def make_test_case(
    device="cuda",
    dtype=torch.bfloat16,
    batch=3,
    Hq=8,
    Hk=4,
    D=256,
    seed=0,
    base_index=50000,
    sliding_window=(-1, -1),
):
    """Build one seeded random prefill case for the kernel/reference comparison.

    Returns a 12-tuple: q, k, v, o, b_req_idx, b_start_loc, b_seq_len,
    b_prompt_cache_len, max new-token length, req_to_token_indexs,
    b_image_token_end and the ``sliding_window`` argument passed through.
    """
    torch.manual_seed(seed)

    cached = torch.randint(low=0, high=8, size=(batch,), device=device)
    fresh = torch.randint(low=4, high=24, size=(batch,), device=device)
    b_seq_len = (cached + fresh).to(torch.int32)
    longest_total = int(b_seq_len.max().item())
    longest_new = int(fresh.max().item())

    # Exclusive prefix sum of the new-token counts gives each flat Q start.
    b_start_loc = (torch.cumsum(fresh, dim=0) - fresh).to(torch.int32)
    sum_q = int(fresh.sum().item())

    b_prompt_cache_len = cached.to(torch.int32)
    b_req_idx = torch.arange(batch, device=device, dtype=torch.int32)

    # Draw distinct KV slots, shifted by base_index, for every token.
    sum_kv = int(b_seq_len.sum().item())
    kv_size = base_index + sum_kv + 1024
    pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index

    req_to_token_indexs = torch.zeros((batch, longest_total), device=device, dtype=torch.int32)
    cursor = 0
    for row in range(batch):
        length = int(b_seq_len[row].item())
        req_to_token_indexs[row, :length] = pool[cursor : cursor + length].to(torch.int32)
        cursor += length

    # Inject one image span per batch into the new-token region with prob 0.7.
    b_image_token_end = torch.zeros((sum_q,), device=device, dtype=torch.int32)
    for row in range(batch):
        n_new = int(fresh[row].item())
        n_cached = int(cached[row].item())
        flat = int(b_start_loc[row].item())
        if n_new < 4:
            continue
        if torch.rand((), device=device).item() <= 0.3:
            continue
        span_start = int(torch.randint(0, n_new - 2, (1,), device=device).item())
        span_len = int(torch.randint(2, max(3, n_new - span_start + 1), (1,), device=device).item())
        span_stop = min(n_new, span_start + span_len)
        # image_end is absolute (request-position) = prompt_len + new-offset
        b_image_token_end[flat + span_start : flat + span_stop] = n_cached + span_stop

    q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype)
    k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype)
    v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype)
    o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype)

    return (
        q,
        k,
        v,
        o,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_prompt_cache_len,
        longest_new,
        req_to_token_indexs,
        b_image_token_end,
        sliding_window,
    )
def check_once(seed=0, dtype=torch.bfloat16, sliding_window=(-1, -1), D=256):
    """Run the kernel and the torch reference on one random case and compare.

    Prints the max absolute / relative error and asserts the absolute error
    stays below 5e-2.
    """
    (
        q,
        k,
        v,
        o,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_prompt_cache_len,
        max_new_len,
        req_to_token_indexs,
        b_image_token_end,
        sliding_window,
    ) = make_test_case(seed=seed, dtype=dtype, sliding_window=sliding_window, D=D)

    # Kernel under test: writes into `o` in place.
    context_attention_fwd_gemma4_mm(
        q,
        k,
        v,
        o,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_prompt_cache_len,
        max_new_len,
        req_to_token_indexs,
        b_image_token_end,
        sliding_window=sliding_window,
    )

    expected = reference_attention(
        q,
        k,
        v,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_prompt_cache_len,
        req_to_token_indexs,
        b_image_token_end,
        sliding_window=sliding_window,
    )

    max_abs = (o - expected).abs().max().item()
    # Small epsilon keeps the ratio finite if the reference is all zeros.
    max_rel = max_abs / (expected.abs().max().item() + 1e-6)
    has_image = (b_image_token_end > 0).any().item()
    print(
        f"seed={seed} dtype={dtype} D={D} sw={sliding_window} has_image={has_image} "
        f"max_abs={max_abs:.4e} max_rel={max_rel:.4e}"
    )
    assert max_abs < 5e-2, f"max_abs too large: {max_abs}"
" + f"max_abs={max_abs:.4e} max_rel={max_rel:.4e}" + ) + assert max_abs < 5e-2, f"max_abs too large: {max_abs}" + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("No CUDA, skip.") + else: + # Vary D, sliding window, and image presence. + for seed in (0, 1, 2): + check_once(seed=seed, D=128, sliding_window=(-1, -1)) + check_once(seed=seed, D=128, sliding_window=(4096, 0)) + check_once(seed=seed, D=256, sliding_window=(4096, 0)) + print("ok") diff --git a/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py index 2ed325659d..f535fa7d66 100644 --- a/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py @@ -25,13 +25,14 @@ def __init__(self, layer_num, network_config): def _ffn( self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bTransformerLayerWeight ) -> torch.Tensor: - up_gate_out = layer_weight.gate_up_proj.mm(input.view(-1, self.embed_dim_)) + input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + up_gate_out = layer_weight.gate_up_proj.mm(input) ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) gelu_and_mul_fwd(up_gate_out, ffn1_out) input = None up_gate_out = None - ffn2_out = layer_weight.down_proj.mm( - ffn1_out, - ) + ffn2_out = layer_weight.down_proj.mm(ffn1_out) ffn1_out = None + ffn2_out = self._tpsp_reduce(input=ffn2_out, infer_state=infer_state) return ffn2_out diff --git a/lightllm/models/glm4_moe_lite/model.py b/lightllm/models/glm4_moe_lite/model.py index a8fe49ac5e..1e31306aea 100644 --- a/lightllm/models/glm4_moe_lite/model.py +++ b/lightllm/models/glm4_moe_lite/model.py @@ -25,7 +25,12 @@ def _init_config(self): def _init_custom(self): self._init_to_get_yarn_rotary() - dist_group_manager.new_deepep_group(self.config["n_routed_experts"], 
self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) def _init_to_get_yarn_rotary(self): rope_scaling = self.config.get("rope_scaling") diff --git a/lightllm/models/glm4_moe_lite_mtp/model.py b/lightllm/models/glm4_moe_lite_mtp/model.py index 549bf7ce41..2e4ba5c86b 100644 --- a/lightllm/models/glm4_moe_lite_mtp/model.py +++ b/lightllm/models/glm4_moe_lite_mtp/model.py @@ -10,6 +10,9 @@ class Glm4MoeLiteMTPModel(Glm4MoeLiteTpPartModel): + # MTP draft model marker (consumed by the decode CUDA-graph / padding paths). + is_mtp_draft_model = True + pre_and_post_weight_class = Glm4MoeLiteMTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index d80eefd16e..b27ea8fd2d 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -41,6 +41,7 @@ def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6): def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor: hidden_states = input.view(-1, self.embed_dim_) + hidden_states = self._tpsp_allgather(input=hidden_states, infer_state=infer_state) num_tokens, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) hidden_states = layer_weight.experts.experts( @@ -52,7 +53,8 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) - topk_group=None, num_expert_group=None, ) - return hidden_states.view(num_tokens, hidden_dim) + hidden_states = hidden_states.view(num_tokens, hidden_dim) + return self._tpsp_reduce(input=hidden_states, infer_state=infer_state) def 
_context_attention_kernel( self, @@ -63,7 +65,7 @@ def _context_attention_kernel( out=None, ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": - window_size = (self.sliding_window - 1, self.sliding_window - 1) + window_size = (self.sliding_window - 1, 0) use_sliding_window = True else: window_size = (-1, -1) @@ -90,7 +92,7 @@ def _token_attention_kernel( self, q: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": - window_size = (self.sliding_window - 1, self.sliding_window - 1) + window_size = (self.sliding_window - 1, 0) use_sliding_window = True else: window_size = (-1, -1) diff --git a/lightllm/models/internvl/internvl_visual.py b/lightllm/models/internvl/internvl_visual.py index 093ad2b5d1..8add1568f3 100644 --- a/lightllm/models/internvl/internvl_visual.py +++ b/lightllm/models/internvl/internvl_visual.py @@ -58,7 +58,7 @@ def encode(self, images: List[ImageItem]): t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) img_tensors.append(t) else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + raise Exception("Unsupported input types: {} for {}".format(type(img), img)) cur_num = img_tensors[-1].shape[0] valid_ids.append([valid_id, valid_id + cur_num]) diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index 7714164151..50dc0109e2 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -34,8 +34,8 @@ def _slice_get_last_input(self, input_embdings: torch.Tensor, infer_state: Llama start_index += cur_len select_token_num += 1 - last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device) - last_input = self.alloc_tensor((select_token_num, embed_dim_), 
dtype=input_embdings.dtype) + last_index = torch.tensor(select_index, dtype=torch.long, device="cpu").cuda(non_blocking=True) + last_input = self.alloc_tensor((select_token_num, embed_dim_), dtype=input_embdings.dtype, device="cuda") last_input[:, :] = input_embdings[last_index, :] return last_input, select_token_num @@ -58,7 +58,7 @@ def _slice_get_last_input(self, input_embdings: torch.Tensor, infer_state: Llama assert False, "Error State" - def token_forward( + def _token_forward( self, input_embdings: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight ): last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) @@ -89,28 +89,11 @@ def token_forward( gather_data = None return ans_logics - def tpsp_token_forward( + def token_forward( self, input_embdings: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight ): - if self.tp_world_size_ > 1: - assert len(input_embdings.shape) == 2 - token_num, hidden_dim = input_embdings.shape - gather_data = torch.empty( - (self.tp_world_size_ * token_num, hidden_dim), device=input_embdings.device, dtype=input_embdings.dtype - ) - all_gather( - [gather_data[i * token_num : (i + 1) * token_num, :] for i in range(self.tp_world_size_)], - input_embdings, - group=infer_state.dist_group, - async_op=False, - ) - # len(infer_state.input_ids) 获取真实输入长度 - input_embdings = gather_data[0 : len(infer_state.input_ids)] - if infer_state.need_dp_prefill_balance: - input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings) - - return self.token_forward(input_embdings=input_embdings, infer_state=infer_state, layer_weight=layer_weight) + return self._token_forward(input_embdings=input_embdings, infer_state=infer_state, layer_weight=layer_weight) def overlap_tpsp_token_forward( self, @@ -120,16 +103,9 @@ def overlap_tpsp_token_forward( infer_state1: LlamaInferStateInfo, layer_weight: BaseLayerWeight, ): - if getattr(infer_state, "hook", 
None) is not None: - infer_state.hook() - infer_state.hook = None - - logics = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight) - if getattr(infer_state1, "hook", None) is not None: - infer_state1.hook() - infer_state1.hook = None + logics = self.token_forward(input_embdings, infer_state, layer_weight=layer_weight) - logics1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight) + logics1 = self.token_forward(input_embdings1, infer_state1, layer_weight=layer_weight) return logics, logics1 diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index 63a2fe4d14..edeb764ec9 100644 --- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -15,6 +15,7 @@ def __init__(self, network_config): return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) @@ -26,25 +27,6 @@ def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weigh all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return input_embdings - def tpsp_context_forward( - self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight - ): - if get_env_start_args().enable_dp_prefill_balance: - input_ids = infer_state.prefill_dp_balance(input_ids=input_ids) - - input_embdings = self.context_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) - from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy - - padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) - 
return padded_input_embdings - - def tpsp_token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = self.token_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) - from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy - - padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) - return padded_input_embdings - def overlap_tpsp_token_forward( self, input_ids: torch.Tensor, @@ -55,18 +37,9 @@ def overlap_tpsp_token_forward( ): input_embdings = self.token_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) - from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy - - padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) - input_embdings1 = self.token_forward(input_ids=input_ids1, infer_state=infer_state1, layer_weight=layer_weight) - from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy - padded_input_embdings1 = sp_pad_copy( - input_embdings1, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_ - ) - - return padded_input_embdings, padded_input_embdings1 + return input_embdings, input_embdings1 def overlap_tpsp_context_forward( self, @@ -76,24 +49,10 @@ def overlap_tpsp_context_forward( infer_state1: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight, ): - if get_env_start_args().enable_dp_prefill_balance: - input_ids = infer_state.prefill_dp_balance(input_ids=input_ids) input_embdings = self.context_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) - from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy - - padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) - - if get_env_start_args().enable_dp_prefill_balance: - input_ids1 = 
infer_state1.prefill_dp_balance(input_ids=input_ids1) - input_embdings1 = self.context_forward( input_ids=input_ids1, infer_state=infer_state1, layer_weight=layer_weight ) - from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy - - padded_input_embdings1 = sp_pad_copy( - input_embdings1, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_ - ) - return padded_input_embdings, padded_input_embdings1 + return input_embdings, input_embdings1 diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index dc6f10be59..69acffaa4d 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -79,28 +79,7 @@ def _ffn_norm( def _get_qkv( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: - q = layer_weight.q_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv - - def _tpsp_get_qkv( - self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight - ) -> torch.Tensor: - if self.tp_world_size_ > 1: - sp_token_num, hidden_dim = input.shape - gather_input = self.alloc_tensor( - (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device - ) - all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) - input = gather_input[0 : len(infer_state.input_ids), :] - + input = self._tpsp_allgather(input, infer_state) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) @@ -119,62 +98,27 @@ def 
_tpsp_get_qkv( def _get_o( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight - ) -> torch.Tensor: - input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - o_tensor = layer_weight.o_proj.mm(input) - return o_tensor - - def _tpsp_get_o( - self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: if infer_state.need_dp_prefill_balance: input = infer_state._all_to_all_balance_get(data=input) input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_ - o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device) - layer_weight.o_proj.mm(input, out=o_tensor[0 : len(infer_state.input_ids), :]) - e_o_tensor = o_tensor[len(infer_state.input_ids) :, :] - if e_o_tensor.shape[0] > 0: - e_o_tensor.fill_(0) - - if self.tp_world_size_ > 1: - sp_token_num = o_tensor.shape[0] // self.tp_world_size_ - reduce_o_tensor = self.alloc_tensor((sp_token_num, self.embed_dim_), dtype=input.dtype, device=input.device) - reduce_scatter_tensor( - output=reduce_o_tensor, - input=o_tensor, - op=dist.ReduceOp.SUM, - group=infer_state.dist_group, - async_op=False, - ) - o_tensor = reduce_o_tensor + o_tensor = layer_weight.o_proj.mm(input) + o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) return o_tensor def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight) -> torch.Tensor: input = input.view(-1, self.embed_dim_) - up_gate_out = layer_weight.gate_up_proj.mm(input) - ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) - silu_and_mul_fwd(up_gate_out, ffn1_out) - input = None - up_gate_out = None - ffn2_out = layer_weight.down_proj.mm(ffn1_out) - ffn1_out = None + input = self._tpsp_allgather(input=input, infer_state=infer_state) + ffn2_out = self._ffn_tp(input=input, 
infer_state=infer_state, layer_weight=layer_weight) + ffn2_out = self._tpsp_reduce(input=ffn2_out, infer_state=infer_state) return ffn2_out - def _tpsp_ffn( + def _ffn_tp( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) - if self.tp_world_size_ > 1: - sp_token_num, hidden_dim = input.shape - gather_input = self.alloc_tensor( - (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device - ) - all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) - input = gather_input - up_gate_out = layer_weight.gate_up_proj.mm(input) ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) silu_and_mul_fwd(up_gate_out, ffn1_out) @@ -182,15 +126,6 @@ def _tpsp_ffn( up_gate_out = None ffn2_out = layer_weight.down_proj.mm(ffn1_out) ffn1_out = None - if self.tp_world_size_ > 1: - sp_token_num = ffn2_out.shape[0] // self.tp_world_size_ - reduce_o_tensor = self.alloc_tensor( - (sp_token_num, self.embed_dim_), dtype=ffn2_out.dtype, device=ffn2_out.device - ) - reduce_scatter_tensor( - reduce_o_tensor, ffn2_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False - ) - ffn2_out = reduce_o_tensor return ffn2_out # # keep code @@ -213,8 +148,8 @@ def overlap_tpsp_token_forward( infer_state1: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight, ): - input_embdings = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight) - input_embdings1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight) + input_embdings = self.token_forward(input_embdings, infer_state, layer_weight=layer_weight) + input_embdings1 = self.token_forward(input_embdings1, infer_state1, layer_weight=layer_weight) return input_embdings, input_embdings1 def overlap_tpsp_context_forward( @@ -225,6 +160,6 @@ def overlap_tpsp_context_forward( infer_state1: 
LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight, ): - input_embdings = self.tpsp_context_forward(input_embdings, infer_state, layer_weight=layer_weight) - input_embdings1 = self.tpsp_context_forward(input_embdings1, infer_state1, layer_weight=layer_weight) + input_embdings = self.context_forward(input_embdings, infer_state, layer_weight=layer_weight) + input_embdings1 = self.context_forward(input_embdings1, infer_state1, layer_weight=layer_weight) return input_embdings, input_embdings1 diff --git a/lightllm/models/llama/triton_kernel/rotary_emb.py b/lightllm/models/llama/triton_kernel/rotary_emb.py index c6d4f3010d..f87b9d9e02 100755 --- a/lightllm/models/llama/triton_kernel/rotary_emb.py +++ b/lightllm/models/llama/triton_kernel/rotary_emb.py @@ -23,6 +23,7 @@ def _rotary_kernel( max_total_len, HEAD_Q, HEAD_K, # N_CTX 代表要计算的上下文长度 + HAS_K: tl.constexpr, BLOCK_HEAD: tl.constexpr, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, @@ -73,55 +74,59 @@ def _rotary_kernel( Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) ) - off_k0 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range0[None, None, :] * stride_kd - ) - off_k1 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range1[None, None, :] * stride_kd - ) - - off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd - - k0 = tl.load( - K + off_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - k1 = tl.load( - K + off_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < 
max_total_len, other=0.0) - - out_k0 = k0 * cos - k1 * sin - out_k1 = k0 * sin + k1 * cos - - tl.store( - K + off_k0, - out_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - tl.store( - K + off_k1, - out_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) + if HAS_K: + off_k0 = ( + cur_seq_range[:, None, None] * stride_kbs + + cur_head_range[None, :, None] * stride_kh + + dim_range0[None, None, :] * stride_kd + ) + off_k1 = ( + cur_seq_range[:, None, None] * stride_kbs + + cur_head_range[None, :, None] * stride_kh + + dim_range1[None, None, :] * stride_kd + ) + + off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd + + k0 = tl.load( + K + off_k0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + other=0.0, + ) + k1 = tl.load( + K + off_k1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + other=0.0, + ) + cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + + out_k0 = k0 * cos - k1 * sin + out_k1 = k0 * sin + k1 * cos + + tl.store( + K + off_k0, + out_k0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + ) + tl.store( + K + off_k1, + out_k1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + ) return @torch.no_grad() -def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.): +def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.0): total_len = q.shape[0] - head_num_q, head_num_k = q.shape[1], k.shape[1] + has_k = k is not None + head_num_q = q.shape[1] + head_num_k = k.shape[1] if has_k else 0 head_dim = int(q.shape[2] * partial_rotary_factor) assert 
q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" + if has_k: + assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" BLOCK_SEQ = 16 BLOCK_HEAD = 4 @@ -139,9 +144,9 @@ def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.): q.stride(0), q.stride(1), q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), + k.stride(0) if has_k else 0, + k.stride(1) if has_k else 0, + k.stride(2) if has_k else 0, cos.stride(0), cos.stride(1), sin.stride(0), @@ -149,6 +154,7 @@ def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.): total_len, head_num_q, head_num_k, + HAS_K=has_k, BLOCK_HEAD=BLOCK_HEAD, BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=head_dim, diff --git a/lightllm/models/llava/llava_visual.py b/lightllm/models/llava/llava_visual.py index 293bcd4450..d4310a66db 100644 --- a/lightllm/models/llava/llava_visual.py +++ b/lightllm/models/llava/llava_visual.py @@ -138,7 +138,7 @@ def encode(self, images: List[ImageItem]): t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"] img_tensors.append(t) else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + raise Exception("Unsupported input types: {} for {}".format(type(img), img)) cur_num = img_tensors[-1].shape[0] valid_ids.append([valid_id, valid_id + cur_num]) diff --git a/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py index 6d72ae2c38..aee5dc2446 100644 --- a/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py @@ -1,12 +1,5 @@ -import torch.functional as F -import torch.distributed as dist -import numpy as np from lightllm.common.basemodel.infer_struct import InferStateInfo 
from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer -from lightllm.distributed.communication_op import all_reduce -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) class MistralMTPTransformerLayerInfer(MistralTransformerLayerInfer): @@ -18,8 +11,6 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None - if self.tp_world_size_ > 1: - all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings @@ -27,7 +18,5 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None - if self.tp_world_size_ > 1: - all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings diff --git a/lightllm/models/mistral_mtp/model.py b/lightllm/models/mistral_mtp/model.py index 7c64625ca8..f17bc0a383 100644 --- a/lightllm/models/mistral_mtp/model.py +++ b/lightllm/models/mistral_mtp/model.py @@ -10,6 +10,9 @@ class MistralMTPModel(MistralTpPartModel): + # MTP draft model marker (consumed by the decode CUDA-graph / padding paths). 
+ is_mtp_draft_model = True + pre_and_post_weight_class = MistralMTPPreAndPostLayerWeight pre_layer_infer_class = MistralMTPPreLayerInfer diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index 44e66cff2d..0cf651598a 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -17,9 +17,10 @@ def __init__(self, layer_num, network_config): def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransformerLayerWeight) -> torch.Tensor: hidden_states = input.view(-1, self.embed_dim_) + hidden_states = self._tpsp_allgather(input=hidden_states, infer_state=infer_state) num_tokens, hidden_dim = hidden_states.shape - router_logits = layer_weight.moe_gate.mm(input.view(-1, self.embed_dim_)) + router_logits = layer_weight.moe_gate.mm(hidden_states) topk_weights, topk_ids = fused_topk( hidden_states=hidden_states, gating_output=router_logits, @@ -29,7 +30,7 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor ) from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl - return fused_experts_impl( + ffn2_out = fused_experts_impl( hidden_states=hidden_states, w1=layer_weight.experts.w1[0], w2=layer_weight.experts.w2[0], @@ -41,3 +42,4 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor w2_scale=None, alloc_tensor_func=self.alloc_tensor, ) + return self._tpsp_reduce(input=ffn2_out, infer_state=infer_state) diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py index fd3d05e426..fddb14cfe5 100755 --- a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py @@ -12,14 +12,18 @@ def __init__(self, layer_num, network_config): return def _get_qkv(self, input_emb, 
infer_state: LlamaInferStateInfo, layer_weight: Phi3TransformerLayerWeight): - q = layer_weight.q_proj.mm(input_emb.view(-1, self.embed_dim_)) - cache_kv = layer_weight.kv_proj.mm( - input_emb.view(-1, self.embed_dim_), - ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + input_emb = self._tpsp_allgather(input=input_emb.view(-1, self.embed_dim_), infer_state=infer_state) + q = layer_weight.q_proj.mm(input_emb) + cache_kv = layer_weight.kv_proj.mm(input_emb).view( + -1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_ + ) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), cache_kv[:, 0 : self.tp_k_head_num_, :], infer_state.position_cos, infer_state.position_sin, ) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) return q, cache_kv diff --git a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py index 333870eb9d..003188d088 100755 --- a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py @@ -1,5 +1,4 @@ import torch -from typing import Tuple from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.qwen.layer_weights.transformer_layer_weight import QwenTransformerLayerWeight @@ -14,6 +13,7 @@ def __init__(self, layer_num, network_config): return def _get_qkv(self, input_emb, infer_state: QwenInferStateInfo, layer_weight: QwenTransformerLayerWeight): + input_emb = self._tpsp_allgather(input_emb, infer_state) q = layer_weight.q_proj.mm(input_emb) cache_kv = layer_weight.kv_proj.mm(input_emb).view( -1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_ @@ -27,8 +27,7 @@ def _get_qkv(self, input_emb, infer_state: QwenInferStateInfo, 
layer_weight: Qwe ) if infer_state.logn_values is not None: q.mul_(infer_state.logn_values.view(-1, 1)) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) return q, cache_kv - - def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: - # TODO - raise Exception("not impl") diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 7156a5ce23..1b3a5f0db7 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -221,7 +221,7 @@ def _init_datatype(self): elif self.data_type in ["fp32", "float32"]: self.data_type = torch.float32 else: - raise ValueError(f"Unsupport datatype {self.data_type}!") + raise ValueError(f"Unsupported datatype {self.data_type}!") return def rot_pos_emb(self, grid_thw): @@ -346,7 +346,7 @@ def load_image(self, img: List[ImageItem]): image_data = resize_image(image_data) pixel_values, image_grid_thw = self.processor.preprocess(image_data) else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + raise Exception("Unsupported input types: {} for {}".format(type(img), img)) return pixel_values.to(dtype=self.data_type), image_grid_thw def load_model(self, weight_dir): @@ -387,7 +387,7 @@ def encode(self, images: List[ImageItem]): img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + raise Exception("Unsupported input types: {} for {}".format(type(img), img)) # must devide merge_length cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2) diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py index 747be932d9..04f7bc3895 100644 --- a/lightllm/models/qwen2_vl/infer_struct.py +++ b/lightllm/models/qwen2_vl/infer_struct.py @@ 
-20,13 +20,8 @@ def init_some_extra_state(self, model): if self.is_prefill: self.position_ids = self.get_mrope_position(self.multimodal_params) else: - b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] - for batch_idx, p in enumerate(self.multimodal_params): - position_delta = 0 - for image in p["images"]: - position_delta += image["grid_thwd"][3] - b_position_delta[batch_idx] = position_delta - position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) + b_position_delta = self.b_position_delta.to(dtype=self.position_ids.dtype) + position_ids = self.position_ids + b_position_delta self.position_ids = position_ids.unsqueeze(0).expand(3, -1) self.position_ids = self.position_ids.contiguous() diff --git a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py index 298a77044c..ae6861071e 100755 --- a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py @@ -1,5 +1,4 @@ import torch -from typing import Tuple from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer @@ -11,6 +10,7 @@ def __init__(self, layer_num, network_config): self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda") def _get_qkv(self, input, infer_state, layer_weight): + input = self._tpsp_allgather(input, infer_state) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) mrope_triton_fused( @@ -21,8 +21,7 @@ def _get_qkv(self, input, infer_state, layer_weight): self.mrope_section, is_interleaved=False, ) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) 
return q, cache_kv - - def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: - # TODO - raise Exception("not impl") diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 6076756043..e02c3d9aa3 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -235,7 +235,7 @@ def _init_datatype(self): elif self.data_type in ["fp32", "float32"]: self.data_type = torch.float32 else: - raise ValueError(f"Unsupport datatype {self.data_type}!") + raise ValueError(f"Unsupported datatype {self.data_type}!") return def load_model(self, weight_dir): @@ -319,7 +319,7 @@ def encode(self, images: List[ImageItem]): img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + raise Exception("Unsupported input types: {} for {}".format(type(img), img)) # must devide merge_length cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2) diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py index 725b0cc02e..3a66d506ca 100644 --- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py @@ -22,6 +22,7 @@ def _get_qkv( layer_weight: Qwen3TransformerLayerWeight, ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input, infer_state) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input) layer_weight.qk_norm_weight_( @@ -36,8 +37,7 @@ def _get_qkv( infer_state.position_cos, infer_state.position_sin, ) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) return q, cache_kv - - def _tpsp_get_qkv(self, input, infer_state, 
layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: - # TODO - raise Exception("not impl") diff --git a/lightllm/models/qwen3_5/infer_struct.py b/lightllm/models/qwen3_5/infer_struct.py index d837c4d291..d23475c1cf 100644 --- a/lightllm/models/qwen3_5/infer_struct.py +++ b/lightllm/models/qwen3_5/infer_struct.py @@ -13,5 +13,7 @@ def __init__(self): def init_some_extra_state(self, model): super().init_some_extra_state(model) self.b_att_seq_len = self.b_seq_len - self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous() + mtp_step = get_env_start_args().mtp_step + + self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index return diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py index d0657bcbe8..649db03b11 100644 --- a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py @@ -1,5 +1,4 @@ import torch -import torch.distributed as dist from typing import Tuple from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( @@ -10,9 +9,6 @@ ) from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) class Qwen35TransformerLayerInfer(Qwen3NextTransformerLayerInfer): @@ -30,15 +26,21 @@ def _get_qkv( layer_weight: Qwen35TransformerLayerWeight, ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) - - qkv_out = layer_weight.qkv_proj.mm(input) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + + qkv_gate_out = layer_weight.qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + 
dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 ) - o_gate = layer_weight._o_gate_proj.mm(input) - # In-place sigmoid for gate - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_logics_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -55,4 +57,7 @@ def _get_qkv( is_interleaved=True, # Qwen3 uses interleaved mrope partial_rotary_factor=self.partial_rotary_factor, ) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) return q, cache_kv diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 721893a4cd..8879aa2d27 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -1,24 +1,15 @@ -import os import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np import triton +from functools import partial from typing import Tuple -from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd -from functools import partial -from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import use_sm100_mega_moe from lightllm.utils.dist_utils import get_global_world_size 
-from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor from lightllm.utils.envs_utils import get_env_start_args -logger = init_logger(__name__) - class Qwen3MOETransformerLayerInfer(LlamaTransformerLayerInfer): def __init__(self, layer_num, network_config): @@ -45,14 +36,11 @@ def _bind_ffn(self): if self.is_moe: enable_ep_moe = get_env_start_args().enable_ep_moe if enable_ep_moe: - self._ffn = partial(Qwen3MOETransformerLayerInfer._moe_ffn_edp, self) - self._tpsp_ffn = self._tpsp_ffn_ep + self._ffn = self._ffn_ep_impl else: - self._ffn = partial(Qwen3MOETransformerLayerInfer._moe_ffn, self) - self._tpsp_ffn = self._tpsp_ffn_tp + self._ffn = self._ffn_tp_impl else: self._ffn = partial(LlamaTransformerLayerInfer._ffn, self) - self._tpsp_ffn = self._tpsp_ffn_tp def _get_qkv( self, @@ -61,6 +49,8 @@ def _get_qkv( layer_weight: Qwen3MOETransformerLayerWeight, ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + qkv = layer_weight.qkv_proj.mm(input) q, cache_kv = qkv.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 @@ -77,49 +67,14 @@ def _get_qkv( infer_state.position_cos, infer_state.position_sin, ) - return q, cache_kv - - def _tpsp_get_qkv( - self, - input: torch.Tensor, - infer_state: LlamaInferStateInfo, - layer_weight: Qwen3MOETransformerLayerWeight, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if self.tp_world_size_ > 1: - sp_token_num, hidden_dim = input.shape - gather_input = self.alloc_tensor( - (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device - ) - all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) - input = gather_input[0 : len(infer_state.input_ids), :] - - input = input.view(-1, self.embed_dim_) - q = layer_weight.q_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input) - 
layer_weight.qk_norm_weight_( - q, - cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - eps=self.eps_, - ) - cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - if infer_state.need_dp_prefill_balance: q = infer_state._all_to_all_unbalance_get(data=q) cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) - return q, cache_kv - def _moe_ffn( + def _moe_ffn_tp( self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight ) -> torch.Tensor: - hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -137,7 +92,6 @@ def _moe_ffn( def _moe_ffn_edp( self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight ) -> torch.Tensor: - hidden_states = input token_num, hidden_dim = hidden_states.shape @@ -156,44 +110,21 @@ def _moe_ffn_edp( ep_output = ep_output.view(token_num, hidden_dim) return ep_output - def _tpsp_ffn( - self, input: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight - ): - raise Exception("need bind to real impl") - - def _tpsp_ffn_tp( - self, input: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight + def _ffn_tp_impl( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) - if self.tp_world_size_ > 1: - sp_token_num, hidden_dim = input.shape - gather_input = self.alloc_tensor( - (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device - ) - all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) - input = gather_input - - ffn2_out = 
self._ffn(input=input, infer_state=infer_state, layer_weight=layer_weight) - - if self.tp_world_size_ > 1: - sp_token_num = ffn2_out.shape[0] // self.tp_world_size_ - reduce_o_tensor = self.alloc_tensor( - (sp_token_num, self.embed_dim_), dtype=ffn2_out.dtype, device=ffn2_out.device - ) - reduce_scatter_tensor( - reduce_o_tensor, ffn2_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False - ) - ffn2_out = reduce_o_tensor + input = self._tpsp_allgather(input=input, infer_state=infer_state) + ffn2_out = self._moe_ffn_tp(input=input, infer_state=infer_state, layer_weight=layer_weight) + ffn2_out = self._tpsp_reduce(input=ffn2_out, infer_state=infer_state) return ffn2_out - def _tpsp_ffn_ep( + def _ffn_ep_impl( self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight ) -> torch.Tensor: + # ep 本身就是一种 sp 兼容,所以不需要再进行 allgather 和 reduce input = input.view(-1, self.embed_dim_) - - ffn2_out = self._ffn(input=input, infer_state=infer_state, layer_weight=layer_weight) - - return ffn2_out + return self._moe_ffn_edp(input=input, infer_state=infer_state, layer_weight=layer_weight) def overlap_tpsp_token_forward( self, @@ -203,18 +134,18 @@ def overlap_tpsp_token_forward( infer_state1: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method): return super().overlap_tpsp_token_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) # 0 attention _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight) - _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, infer_state, layer_weight) + _0_q, _0_cache_kv = self._get_qkv(_0_input1, infer_state, layer_weight) _0_input1 = None self._post_cache_kv(_0_cache_kv, infer_state, layer_weight) _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight) _0_q = None - _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight) + _0_o = 
self._get_o(_0_o, infer_state, layer_weight) input_embdings.add_(_0_o.view(-1, self.embed_dim_)) _0_o = None _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) @@ -237,12 +168,12 @@ def overlap_tpsp_token_forward( # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) - _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, infer_state1, layer_weight) + _1_q, _1_cache_kv = self._get_qkv(_1_input1, infer_state1, layer_weight) _1_input1 = None self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight) _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight) _1_q = None - _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight) + _1_o = self._get_o(_1_o, infer_state1, layer_weight) input_embdings1.add_(_1_o.view(-1, self.embed_dim_)) _1_o = None _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight) @@ -315,18 +246,18 @@ def overlap_tpsp_context_forward( infer_state1: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method): return super().overlap_tpsp_context_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) # 0 attention _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight) - _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, infer_state, layer_weight) + _0_q, _0_cache_kv = self._get_qkv(_0_input1, infer_state, layer_weight) _0_input1 = None self._post_cache_kv(_0_cache_kv, infer_state, layer_weight) _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight) _0_q = None - _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight) + _0_o = self._get_o(_0_o, infer_state, layer_weight) input_embdings.add_(_0_o.view(-1, self.embed_dim_)) _0_o = None _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) @@ -340,18 +271,18 @@ def overlap_tpsp_context_forward( _0_topk_weight, _0_topk_idx, 
_0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _0_input1, _0_router_logits ) - from deep_ep import Buffer + from deep_ep import ElasticBuffer - _0_overlap_event = Buffer.capture() + _0_overlap_event = ElasticBuffer.capture() # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) - _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, infer_state1, layer_weight) + _1_q, _1_cache_kv = self._get_qkv(_1_input1, infer_state1, layer_weight) _1_input1 = None self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight) _1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight) _1_q = None - _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight) + _1_o = self._get_o(_1_o, infer_state1, layer_weight) input_embdings1.add_(_1_o.view(-1, self.embed_dim_)) _1_o = None _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight) @@ -378,8 +309,7 @@ def overlap_tpsp_context_forward( _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _1_input1, _1_router_logits ) - - _1_overlap_event = Buffer.capture() + _1_overlap_event = ElasticBuffer.capture() # 0 moe calu _0_moe_out = layer_weight.experts.prefilled_group_gemm( @@ -402,7 +332,7 @@ def overlap_tpsp_context_forward( infer_state1.hook() infer_state1.hook = None - _0_combine_event = Buffer.capture() + _0_combine_event = ElasticBuffer.capture() # 0 combine execute _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event) infer_state.hook = _0_hook @@ -417,7 +347,7 @@ def overlap_tpsp_context_forward( infer_state.hook() infer_state.hook = None - _1_combine_event = Buffer.capture() + _1_combine_event = ElasticBuffer.capture() input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_)) diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index b71d7f4878..0d4b45bfe6 100644 --- a/lightllm/models/qwen3_moe/model.py +++ 
b/lightllm/models/qwen3_moe/model.py @@ -27,4 +27,9 @@ def _init_custom(self): super()._init_custom() # Only initialize DeepEP group for MoE models with num_experts if "num_experts" in self.config and self.config["num_experts"] > 0: - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["num_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py index 4e2b65d743..9840986dce 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py @@ -1,16 +1,5 @@ -import os -import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np -import triton -from typing import Tuple from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer -from lightllm.distributed.communication_op import all_reduce -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) class Qwen3MOEMTPTransformerLayerInfer(Qwen3MOETransformerLayerInfer): @@ -22,8 +11,6 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None - if self.tp_world_size_ > 1: - all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings @@ -31,7 +18,5 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh input1 = 
self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None - if self.tp_world_size_ > 1: - all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index 9f83832a7e..d9854250e2 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -9,6 +9,9 @@ class Qwen3MOEMTPModel(Qwen3MOEModel): + # MTP draft model marker (consumed by the decode CUDA-graph / padding paths). + is_mtp_draft_model = True + pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer diff --git a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py index 194914d455..67fd49cd1f 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py +++ b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py @@ -1,3 +1,4 @@ +import os import torch import numpy as np from typing import TYPE_CHECKING, Any, Optional, Union, Tuple @@ -7,6 +8,8 @@ from transformers.utils import TensorType from functools import lru_cache +MAX_AUDIO_DURATION_SECONDS = int(os.getenv("LIGHTLLM_MAX_AUDIO_DURATION_SECONDS", "3600")) + class WhisperFeatureExtractor(SequenceFeatureExtractor): @@ -47,6 +50,7 @@ def __init__( norm="slaney", mel_scale="slaney", ) + self.max_audio_len = MAX_AUDIO_DURATION_SECONDS * sampling_rate @lru_cache(maxsize=12) def get_hann_window(self, device: Union[str, torch.device]): @@ -140,7 +144,7 @@ def _preprocess( padded_inputs = self.pad( batched_speech, padding=padding, - max_length=max_length if max_length else self.n_samples, + max_length=max_length if max_length else self.max_audio_len, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of, 
return_attention_mask=return_attention_mask or do_normalize, diff --git a/lightllm/models/qwen3_omni_moe_thinker/model.py b/lightllm/models/qwen3_omni_moe_thinker/model.py index 1b8fa0110d..45fb16f701 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/model.py +++ b/lightllm/models/qwen3_omni_moe_thinker/model.py @@ -18,6 +18,7 @@ ) from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel +from lightllm.models.qwen3_omni_moe_thinker.audio_process import MAX_AUDIO_DURATION_SECONDS from lightllm.models.qwen3_omni_moe_thinker.infer_struct import Qwen3OmniMOEInferStateInfo from lightllm.models.qwen3_vl.model import QWen3VLTokenizer from lightllm.server.core.objs import SamplingParams @@ -44,6 +45,7 @@ def __init__(self, tokenizer=None, processor=None, **kwargs): self.sampling_rate = self.audio_processor.sampling_rate self.n_samples = self.audio_processor.n_samples self.hop_length = self.audio_processor.hop_length + self.max_audio_len = MAX_AUDIO_DURATION_SECONDS * self.sampling_rate self.image_start_id = kwargs["model_cfg"]["vision_start_token_id"] self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"] @@ -59,13 +61,14 @@ def init_audioitem_extral_params( return def get_audio_token_length(self, audio: AudioItem): - # 这里得处理对应奖语音长度按照 30 进行限制,后续处理中,超过30的会被截断。 - if audio.audio_length > self.n_samples: - logger.warning(f"audio length {audio.audio_length} exceed max length {self.n_samples}, will be truncated.") + # 这里得处理对应奖语音长度按照 默认值1h 进行限制,后续处理中,超过 1h 的会被截断。 + if audio.audio_length > self.max_audio_len: + logger.warning( + f"audio length {audio.audio_length} exceed max length {self.max_audio_len}, will be truncated." 
+ ) - length = min(audio.audio_length, int(self.n_samples)) + length = min(audio.audio_length, int(self.max_audio_len)) token_num = self._caclu_audio_token_num(length) - # print(f"token_num is {token_num} n_samples is {self.n_samples} hop_length is {self.hop_length}") return token_num @lru_cache(maxsize=128) diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py index 03c57126ff..b9ea66fb15 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py @@ -10,10 +10,15 @@ from transformers.activations import ACT2FN from lightllm.server.multimodal_params import AudioItem +from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.models.qwen3_omni_moe_thinker.audio_process import WhisperFeatureExtractor +QWEN3_OMNI_CONV_CHUNKSIZE = int(os.getenv("LIGHTLLM_QWEN3_OMNI_CONV_CHUNKSIZE", 200)) + +logger = init_logger(__name__) + def _get_feat_extract_output_lengths(input_lengths): """ @@ -156,7 +161,7 @@ def __init__( activation_function="gelu", output_dim=2048, n_window_infer=800, - conv_chunksize=500, + conv_chunksize=QWEN3_OMNI_CONV_CHUNKSIZE, encoder_attention_heads=20, attention_dropout=0, activation_dropout=0, @@ -212,7 +217,7 @@ def _init_datatype(self): elif self.data_type in ["fp32", "float32"]: self.data_type = torch.float32 else: - raise ValueError(f"Unsupport datatype {self.data_type}!") + raise ValueError(f"Unsupported datatype {self.data_type}!") return def _freeze_parameters(self): @@ -259,6 +264,7 @@ def load_model(self, weight_dir, config): self.load_state_dict(weight_dict) + @torch.inference_mode() def forward( self, input_features, @@ -268,6 +274,7 @@ def forward( aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) 
chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + # TODO: Avoid constructing CUDA tensors directly from Python list data. chunk_lengths = torch.tensor( [self.n_window * 2] * chunk_num.sum(), dtype=torch.long, @@ -311,6 +318,7 @@ def forward( remainder = cnn_len % window_aftercnn if remainder != 0: cu_chunk_lens += [remainder] + # TODO: Avoid constructing CUDA tensors directly from Python list data. cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() for encoder_layer in self.layers: @@ -327,6 +335,7 @@ def forward( hidden_states = self.proj2(hidden_states) return hidden_states + @torch.inference_mode() def encode(self, audio_items: List[AudioItem]): uuids = [] items: List[AudioItem] = [] @@ -363,3 +372,23 @@ def encode(self, audio_items: List[AudioItem]): all_embeds.append(cur_embed) return all_embeds, audio_items + + @torch.inference_mode() + def check_long_audio_infer(self): + """Exercise forward with mel length chosen so the conv loop runs once with batch dim == conv_chunksize.""" + params = next(self.parameters()) + device = params.device + dtype = params.dtype + frame_len = self.conv_chunksize * (self.n_window * 2) + logger.info( + "check_long_audio_infer: start frame_len=%s conv_chunksize=%s n_window=%s device=%s dtype=%s", + frame_len, + self.conv_chunksize, + self.n_window, + device, + dtype, + ) + input_features = torch.zeros(self.num_mel_bins, frame_len, device=device, dtype=dtype) + feature_lens = torch.tensor([frame_len], device=device, dtype=torch.long) + out = self.forward(input_features, feature_lens=feature_lens) + logger.info("check_long_audio_infer: done output_shape=%s", tuple(out.shape)) diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py index 0276724749..85b1352b58 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py 
+++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py @@ -204,7 +204,7 @@ def _init_datatype(self): elif self.data_type in ["fp32", "float32"]: self.data_type = torch.float32 else: - raise ValueError(f"Unsupport datatype {self.data_type}!") + raise ValueError(f"Unsupported datatype {self.data_type}!") return def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature_lists, valid_ids): @@ -321,6 +321,7 @@ def fast_pos_embed_interpolate(self, grid_thw): idx_list[i].extend(indices[i].tolist()) weight_list[i].extend(weights[i].tolist()) + # TODO: Avoid constructing CUDA tensors directly from Python list data. idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) weight_tensor = torch.tensor( weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device @@ -388,7 +389,7 @@ def encode(self, images: List[ImageItem]): img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + raise Exception("Unsupported input types: {} for {}".format(type(img), img)) # must devide merge_length cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2) diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index 8951a04381..6567eb57cc 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -1,12 +1,10 @@ import torch -import torch.distributed as dist from typing import Tuple from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from 
lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo -from lightllm.distributed import all_reduce from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor @@ -27,6 +25,7 @@ def _get_qkv( layer_weight: Qwen3TransformerLayerWeight, ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input) layer_weight.qk_norm_weight_( @@ -43,6 +42,9 @@ def _get_qkv( self.mrope_section, is_interleaved=True, ) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) return q, cache_kv def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, layer_weight): @@ -53,22 +55,22 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la o = self._context_attention_wrapper_run(q, cache_kv, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None - if self.tp_world_size_ > 1: - all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + + input_embdings = self._tpsp_allgather(input=input_embdings, infer_state=infer_state) self._apply_deepstack_features_wrapper_run( input_embeddings=input_embdings, infer_state=infer_state, layer_num=self.layer_num_, ) + 
input_embdings = self._tpsp_sp_split(input=input_embdings, infer_state=infer_state) + return input_embdings def _apply_deepstack_features_wrapper_run( diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index bed8898115..bab0800f26 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -199,7 +199,7 @@ def _init_datatype(self): elif self.data_type in ["fp32", "float32"]: self.data_type = torch.float32 else: - raise ValueError(f"Unsupport datatype {self.data_type}!") + raise ValueError(f"Unsupported datatype {self.data_type}!") return def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature_lists, valid_ids): @@ -316,6 +316,7 @@ def fast_pos_embed_interpolate(self, grid_thw): idx_list[i].extend(indices[i].tolist()) weight_list[i].extend(weights[i].tolist()) + # TODO: Avoid constructing CUDA tensors directly from Python list data. idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) weight_tensor = torch.tensor( weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device @@ -386,7 +387,7 @@ def encode(self, images: List[ImageItem]): img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + raise Exception("Unsupported input types: {} for {}".format(type(img), img)) # must devide merge_length cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2) diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index 40d4bbc0ad..edf1f8cecf 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -1,12 +1,10 @@ import torch -import torch.distributed as dist from typing import Tuple from 
lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo -from lightllm.distributed import all_reduce from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor @@ -25,6 +23,7 @@ def _get_qkv( layer_weight: Qwen3MOETransformerLayerWeight, ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) qkv = layer_weight.qkv_proj.mm(input) q, cache_kv = qkv.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 @@ -43,6 +42,9 @@ def _get_qkv( self.mrope_section, is_interleaved=True, ) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) return q, cache_kv def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, layer_weight): @@ -53,22 +55,20 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la o = self._context_attention_wrapper_run(q, cache_kv, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None - if self.tp_world_size_ > 1: - all_reduce(ffn_out, op=dist.ReduceOp.SUM, 
group=infer_state.dist_group, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + input_embdings = self._tpsp_allgather(input=input_embdings, infer_state=infer_state) self._apply_deepstack_features_wrapper_run( input_embeddings=input_embdings, infer_state=infer_state, layer_num=self.layer_num_, ) + input_embdings = self._tpsp_sp_split(input=input_embdings, infer_state=infer_state) return input_embdings def _apply_deepstack_features_wrapper_run( diff --git a/lightllm/models/qwen3next/infer_struct.py b/lightllm/models/qwen3next/infer_struct.py index cd7c8d908d..0006a682f1 100644 --- a/lightllm/models/qwen3next/infer_struct.py +++ b/lightllm/models/qwen3next/infer_struct.py @@ -11,6 +11,6 @@ def __init__(self): def init_some_extra_state(self, model): super().init_some_extra_state(model) self.b_att_seq_len = self.b_seq_len - self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous() - + mtp_step = get_env_start_args().mtp_step + self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index return diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index dce5e96b31..e6f40125f9 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -1,4 +1,3 @@ -import os import torch import torch.distributed as dist @@ -7,13 +6,14 @@ ) from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo -from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl from lightllm.utils.log_utils import init_logger -from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd +from 
lightllm.utils.tensor_utils import tensor_to_no_ref_tensor +from lightllm.common.kv_cache_mem_manager import Qwen3NextMemManager from typing import Tuple -from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.gdn_decode_pack import conv_pack_gdn_decode_inputs +from lightllm.models.qwen3next.triton_kernel.shared_expert_gate import sigmoid_mul_ from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule from lightllm.distributed import all_reduce @@ -72,7 +72,7 @@ def _init_linear_layer_metadata(self, layer_num, network_config): # SSM state dtype optimization ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} start_args = get_env_start_args() - self.ssm_state_dtype = ssm_dtype_dict.get(start_args.mamba_ssm_data_type, torch.bfloat16) + self.ssm_state_dtype = ssm_dtype_dict.get(start_args.linear_att_ssm_data_type, torch.bfloat16) # Pre-compute whether dtype conversion is needed # GDN kernel output dtype is self.data_type @@ -87,33 +87,46 @@ def _bind_func(self): def _bind_ffn(self): if self.is_moe: - moe_mode = os.environ.get("MOE_MODE", "TP") - if moe_mode == "EP": - self._ffn = partial(Qwen3NextTransformerLayerInfer._moe_ffn_edp, self) + enable_ep_moe = get_env_start_args().enable_ep_moe + if enable_ep_moe: + self._ffn = self._ffn_ep_impl else: - self._ffn = partial(Qwen3NextTransformerLayerInfer._moe_ffn, self) + self._ffn = self._ffn_tp_impl else: - self._ffn = partial(Qwen3NextTransformerLayerInfer._ffn, self) + self._ffn = partial(LlamaTransformerLayerInfer._ffn, self) return + def _ffn_tp_impl( + self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, 
layer_weight: Qwen3NextTransformerLayerWeight + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + ffn2_out = self._moe_ffn_tp(input=input, infer_state=infer_state, layer_weight=layer_weight) + return self._tpsp_reduce(input=ffn2_out, infer_state=infer_state) + + def _ffn_ep_impl( + self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ) -> torch.Tensor: + # ep 本身就是一种 sp 兼容,所以不需要再进行 allgather 和 reduce + input = input.view(-1, self.embed_dim_) + return self._moe_ffn_edp(input=input, infer_state=infer_state, layer_weight=layer_weight) + def _compute_shared_expert( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): input = input.view(-1, self.embed_dim_) - shared_expert_out = super()._ffn(input, infer_state, layer_weight) - gate = layer_weight.ffn_gate.mm(input).sigmoid_() - shared_expert_out.mul_(gate) + shared_expert_out = LlamaTransformerLayerInfer._ffn_tp(self, input, infer_state, layer_weight) + gate = layer_weight.shared_expert_gate.mm(input) + sigmoid_mul_(shared_expert_out, gate) return shared_expert_out - def _moe_ffn( + def _moe_ffn_tp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) - hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) + shared_expert_gate = layer_weight.shared_expert_gate.mm(hidden_states) layer_weight.experts.experts( hidden_states, router_logits=router_logits, @@ -122,9 +135,9 @@ def _moe_ffn( use_grouped_topk=False, topk_group=None, num_expert_group=None, + shared_expert_gate=shared_expert_gate, ) hidden_states = hidden_states.view(num_tokens, hidden_dim) - 
hidden_states.add_(shared_expert_out) return hidden_states def _moe_ffn_edp( @@ -155,13 +168,20 @@ def _get_qkv( layer_weight: Qwen3NextTransformerLayerWeight, ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) - qkv_out = layer_weight.qkv_proj.mm(input) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + qkv_gate_out = layer_weight.qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ * 2 + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_ * 2, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1, ) - o_gate = layer_weight._o_gate_proj.mm(input) - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_logics_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -175,6 +195,9 @@ def _get_qkv( infer_state.position_sin, partial_rotary_factor=self.partial_rotary_factor, ) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) return q, cache_kv def _get_o( @@ -184,10 +207,13 @@ def _get_o( layer_weight: Qwen3NextTransformerLayerWeight, ) -> torch.Tensor: """Output projection with gating (in-place multiply to save one allocation).""" + if infer_state.need_dp_prefill_balance: + input = infer_state._all_to_all_balance_get(data=input) input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - input.mul_(infer_state.gate_value) - infer_state.gate_value = None + sigmoid_mul_(input, infer_state.gate_logics_value) + infer_state.gate_logics_value = None o_tensor = layer_weight.o_proj.mm(input) + o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) return o_tensor # ==================== GDN Helper Methods ==================== @@ -227,21 +253,19 
@@ def gdn_forward( layer_weight: Qwen3NextTransformerLayerWeight, is_prefill: bool, ): - assert isinstance(infer_state.mem_manager, Qwen3NextHybridMemManager) + assert isinstance(infer_state.mem_manager, Qwen3NextMemManager) input = input.view(-1, self.embed_dim_) - conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_) - mixed_qkvzba = layer_weight.linear_in_proj.mm(input) - mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill) if is_prefill: - core_attn_out = self._gdn_prefill_kernel( - mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight - ) + core_attn_out, z = self._gdn_prefill_wrapper_run(mixed_qkvzba, infer_state, layer_weight) else: - core_attn_out = self._gdn_decode_kernel( + mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) + conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) + core_attn_out, z = self._gdn_decode_kernel( mixed_qkv, + z, conv_states, ssm_states, a, @@ -258,7 +282,55 @@ def gdn_forward( output = layer_weight.linear_out_proj.mm(core_attn_out) return output - def _split_qkvzba(self, mixed_qkvzba, is_decode=False): + def _gdn_prefill_wrapper_run( + self, + mixed_qkvzba: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if torch.cuda.is_current_stream_capturing(): + mixed_qkvzba = mixed_qkvzba.contiguous() + _mixed_qkvzba = tensor_to_no_ref_tensor(mixed_qkvzba) + pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph() + pre_capture_graph.__exit__(None, None, None) + + # _gdn_prefill_kernel returns the pre-projection value stream. Its + # logical size is num_tokens * local value heads * value head dim. + # We avoid a dry-run because FlashQLA may do host-side syncs while + # preparing varlen chunk metadata, which is illegal during capture. 
+ num_tokens = mixed_qkvzba.shape[0] + o_shape = (num_tokens, self.tp_num_v_heads, self.head_v_dim) + o_dtype = mixed_qkvzba.dtype + o_device = mixed_qkvzba.device + z_shape = o_shape + + infer_state.prefill_cuda_graph_create_graph_obj() + infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__() + o = torch.empty(o_shape, dtype=o_dtype, device=o_device) + _o = tensor_to_no_ref_tensor(o) + z = torch.empty(z_shape, dtype=o_dtype, device=o_device) + _z = tensor_to_no_ref_tensor(z) + + def gdn_prefill_func(new_infer_state: Qwen3NextInferStateInfo): + conv_states, ssm_states = new_infer_state.req_manager.get_mamba_cache(self.layer_num_) + mixed_qkv, tmp_z, b, a = self._split_qkvzba(_mixed_qkvzba) + _z.copy_(tmp_z) + tmp_o = self._gdn_prefill_kernel( + mixed_qkv, conv_states, ssm_states, a, b, new_infer_state, layer_weight + ) + tmp_o = tmp_o.view(_o.shape) + _o.copy_(tmp_o) + return + + infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=gdn_prefill_func, after_graph=pre_capture_graph) + return o, z + + conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) + mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) + core_attn_out = self._gdn_prefill_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight) + return core_attn_out, z + + def _split_qkvzba(self, mixed_qkvzba): qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim z_end = qkv_dim + self.tp_value_dim b_end = z_end + self.tp_num_v_heads @@ -341,6 +413,7 @@ def _gdn_prefill_kernel( def _gdn_decode_kernel( self, mixed_qkv: torch.Tensor, + z: torch.Tensor, conv_states: torch.Tensor, ssm_states: torch.Tensor, a: torch.Tensor, @@ -348,18 +421,25 @@ def _gdn_decode_kernel( infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, ): - mixed_qkv = causal_conv1d_update( + # Recurrent processing with fused gating. 
Decode uses a specialized + # conv+pack kernel to avoid materializing the post-conv qkv tensor + # before immediately splitting it into q/k/v. + query, key, value, z, a, b = conv_pack_gdn_decode_inputs( mixed_qkv, + z, + a, + b, conv_states, layer_weight.linear_conv1d.mm_param.weight, - bias=layer_weight.linear_conv1d.bias, - activation=self.activation, - conv_state_indices=infer_state.b_buffer_idx, + layer_weight.linear_conv1d.bias, + infer_state.b_buffer_idx, + self.activation, + self.conv_kernel_dim, + self.tp_num_k_heads, + self.head_k_dim, + self.tp_num_v_heads, + self.head_v_dim, ) - - # Recurrent processing with fused gating - # FusedRecurrentFunction.forward calls .contiguous() on q/k/v/a/b internally - query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) core_attn_out, _ = fused_recurrent_gated_delta_rule( q=query, k=key, @@ -373,4 +453,4 @@ def _gdn_decode_kernel( a_raw=a, b_raw=b, ) - return core_attn_out + return core_attn_out, z diff --git a/lightllm/models/qwen3next/layer_weights/qkv_gated_rowmm_weight.py b/lightllm/models/qwen3next/layer_weights/qkv_gated_rowmm_weight.py new file mode 100644 index 0000000000..c920b23fbd --- /dev/null +++ b/lightllm/models/qwen3next/layer_weights/qkv_gated_rowmm_weight.py @@ -0,0 +1,75 @@ +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import get_row_slice_mixin +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size + + +class QKVGatedROWNMMWeight(MMWeightTpl): + def __init__( + self, + in_dim, + q_head_num, + kv_head_num, + head_dim, + weight_names, + data_type, + bias_names=None, + quant_method=None, + tp_rank=None, + tp_world_size=None, + ): + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + self.q_repeat_times = 1 + 
self.kv_repeat_times = self._get_kv_repeat_times(kv_head_num) + assert ( + q_head_num % self.tp_world_size_ == 0 + ), f"q_head_num must be divisible by tp_world_size_, found {q_head_num} % {self.tp_world_size_}" + q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num, self.kv_repeat_times) * head_dim + super().__init__( + in_dim=in_dim, + out_dims=[q_hidden_size, kv_hidden_size, kv_hidden_size, q_hidden_size], + weight_names=weight_names, + bias_names=bias_names, + data_type=data_type, + quant_method=quant_method, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + ) + self.q_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.q_repeat_times, + ) + self.kv_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.kv_repeat_times, + ) + + def _get_param_slicer(self, sub_child_index): + if sub_child_index == 0 or sub_child_index == 3: + return self.q_param_slicer + return self.kv_param_slicer + + def load_hf_weights(self, weights): + super().load_hf_weights(weights) + if self.bias_names is not None: + for sub_child_index, bias_name in enumerate(self.bias_names): + if bias_name is None: + self.bias_list[sub_child_index].zero_() + self.bias_list[sub_child_index].load_ok = True + + def _get_kv_repeat_times(self, kv_head_num): + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or vice versa, " + f"found {kv_head_num} % {self.tp_world_size_}" + ) + if kv_head_num % self.tp_world_size_ == 0: + return 1 + return self.tp_world_size_ // kv_head_num + + def _get_tp_padded_head_num(self, head_num, repeat_times): + return repeat_times * head_num // self.tp_world_size_ diff --git 
a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index 60a75f306d..60901ad6b9 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -1,5 +1,6 @@ import torch from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, COLMMWeight, @@ -9,7 +10,9 @@ TpParameterWeight, QKVROWNMMWeight, QKGEMMANormWeight, + FusedMoeWeight, ) +from lightllm.models.qwen3next.layer_weights.qkv_gated_rowmm_weight import QKVGatedROWNMMWeight class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): @@ -21,25 +24,17 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): def _init_qkv(self): in_dim = self.n_embed - q_out_dim = self.q_head_num_ * self.head_dim - self.qkv_proj = QKVROWNMMWeight( + self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" + qkv_quant = self.get_quant_method("qkv_proj") + self.qkvo_gate_proj = QKVGatedROWNMMWeight( in_dim=in_dim, q_head_num=self.q_head_num_, kv_head_num=self.k_head_num_, head_dim=self.head_dim, - weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name, self._o_gate_weight_name], data_type=self.data_type_, - bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], - quant_method=self.get_quant_method("qkv_proj"), - ) - self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" - self._o_gate_proj = ROWMMWeight( - in_dim=in_dim, - out_dims=[q_out_dim], - weight_names=[self._o_gate_weight_name], - data_type=self.data_type_, - bias_names=None, - 
quant_method=self.get_quant_method("o_gate_proj"), + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name, None], + quant_method=qkv_quant, ) def _init_weight(self): @@ -56,8 +51,47 @@ def _init_weight(self): self._init_norm() def _init_moe(self): - super()._init_moe() - self._init_gated_ffn() + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], + weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + enable_ep_moe = get_env_start_args().enable_ep_moe + # Fused shared expert is only supported in TP mode. EP keeps the shared + # expert as a separate FFN and adds its output after routed MoE. + self.num_fused_shared_experts = 0 if enable_ep_moe else 1 + self.shared_expert_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + num_fused_shared_experts=self.num_fused_shared_experts, + layer_num=self.layer_num_, + network_config=self.network_config_, + ) + if enable_ep_moe: + self._init_moe_shared_expert_ffn() return def _init_norm(self): @@ -80,10 +114,8 @@ def _init_norm(self): data_type=self.data_type_, ) - def _init_gated_ffn(self): + def 
_init_moe_shared_expert_ffn(self): hidden_size = self.network_config_["hidden_size"] - if "shared_expert_intermediate_size" not in self.network_config_: - return prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" inter_size = self.network_config_["shared_expert_intermediate_size"] self.gate_up_proj = ROWMMWeight( @@ -92,6 +124,8 @@ def _init_gated_ffn(self): weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], data_type=self.data_type_, quant_method=self.get_quant_method("gate_up_proj"), + tp_rank=0, + tp_world_size=1, ) self.down_proj = COLMMWeight( in_dim=inter_size, @@ -99,14 +133,6 @@ def _init_gated_ffn(self): weight_names=f"{prefix}.down_proj.weight", data_type=self.data_type_, quant_method=self.get_quant_method("down_proj"), - ) - self.ffn_gate = ROWMMWeight( - in_dim=hidden_size, - out_dims=[1], - weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", - data_type=self.data_type_, - bias_names=None, - quant_method=None, tp_rank=0, tp_world_size=1, ) @@ -121,6 +147,29 @@ def _split_q_with_gate(self, weights): weights[self._q_weight_name] = _q_proj weights[self._o_gate_weight_name] = _gate_proj + def _rename_shared_expert_to_moe_expert(self, weights): + if self.num_fused_shared_experts != 1: + return + assert not get_env_start_args().enable_ep_moe, "fused shared expert is only supported in TP mode" + assert self.num_fused_shared_experts == 1, "only one fused shared expert is supported" + + # When the shared expert is fused into MoE, load it as the last routed expert. + # The fused MoE kernel then treats expert id n_routed_experts as this shared expert. 
+ old_prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + new_prefix = f"model.layers.{self.layer_num_}.mlp.experts.{self.n_routed_experts}" + suffixes = [ + self.experts.quant_method.weight_suffix, + self.experts.quant_method.weight_scale_suffix, + self.experts.quant_method.weight_zero_point_suffix, + ] + for proj_name in ("gate_proj", "up_proj", "down_proj"): + for suffix in suffixes: + if suffix is None: + continue + old_name = f"{old_prefix}.{proj_name}.{suffix}" + if old_name in weights: + weights[f"{new_prefix}.{proj_name}.{suffix}"] = weights[old_name] + def _parse_config(self): super()._parse_config() self.linear_num_v_heads = self.network_config_["linear_num_value_heads"] @@ -266,6 +315,8 @@ def _parse_linear_conv1d(self, weight): def load_hf_weights(self, weights): self._split_q_with_gate(weights) + if self.is_moe: + self._rename_shared_expert_to_moe_expert(weights) if self.is_linear_attention_layer: self._preprocess_weight(weights) super().load_hf_weights(weights) diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py deleted file mode 100644 index 12a6d56b8c..0000000000 --- a/lightllm/models/qwen3next/mem_manager.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -from typing import Tuple -from lightllm.utils.log_utils import init_logger -from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager -from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager -from lightllm.server.core.objs.start_args_type import StartArgs - -logger = init_logger(__name__) - - -class Qwen3NextHybridMemManager(MemoryManager): - def __init__( - self, - full_attn_cache_size, - linear_attn_cache_size, - dtype, - num_kv_heads, - head_dim, - layer_num, - full_attention_interval: int, - conv_state_dtype: torch.dtype, - ssm_state_dtype: torch.dtype, - conv_kernel_size: int, - num_linear_k_heads: int, - num_linear_v_heads: int, - head_linear_k_dim: int, - head_linear_v_dim: int, - max_req_num: 
int, - always_copy=False, - mem_fraction=0.9, - network_config: dict = None, - ): - - self.full_attention_interval = full_attention_interval - assert layer_num % full_attention_interval == 0 - self.layer_num = layer_num - self.full_attn_layer_num = layer_num // full_attention_interval - self.linear_attn_layer_num = layer_num - self.full_attn_layer_num - - self.mamba_cache_mem_manager = MambaCacheManager( - size=linear_attn_cache_size, - layer_num=self.linear_attn_layer_num, - conv_state_dtype=conv_state_dtype, - ssm_state_dtype=ssm_state_dtype, - conv_kernel_size=conv_kernel_size, - num_linear_k_heads=num_linear_k_heads, - num_linear_v_heads=num_linear_v_heads, - head_linear_k_dim=head_linear_k_dim, - head_linear_v_dim=head_linear_v_dim, - ) - - super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction) - - def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - # KV buffer layout: [None, None, None, kv_cache, None, None, None, kv_cache, ..., - # None, kv_cache, mtp_kv_cache, mtp_kv_cache] - # Only full attention layers have KV cache. 
- self.kv_buffer = [None for _ in range(self.layer_num)] - for layer_id in range(self.full_attn_layer_num): - self.kv_buffer[(layer_id + 1) * self.full_attention_interval - 1] = torch.empty( - (size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda" - ) - - def free_all(self): - super().free_all() - self.mamba_cache_mem_manager.free_all() - return - - def get_cell_size(self): - # Only full attention layers and MTP layers have KV cache - kv_cache_layer_num = self.full_attn_layer_num - return 2 * self.head_num * self.head_dim * kv_cache_layer_num * torch._utils._element_size(self.dtype) - - def get_mamba_cache(self, layer_idx: int): - layer_idx_in_linear = layer_idx - (layer_idx // self.full_attention_interval) - return self.mamba_cache_mem_manager.get_mamba_cache(layer_idx_in_linear) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 4af692df7c..9b5e9b7a50 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -12,12 +12,11 @@ ) from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo from lightllm.utils.log_utils import init_logger -from lightllm.distributed.communication_op import dist_group_manager from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager +from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextMemManager from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.common.req_manager import ReqManagerForMamba -from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache +from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig logger = init_logger(__name__) @@ -35,12 +34,11 @@ class Qwen3NextTpPartModel(Qwen3MOEModel): # infer state class infer_state_class = Qwen3NextInferStateInfo - # radix cache class - radix_cache_class = HybridRadixCache - def __init__(self, kvargs) -> None: 
- self.mem_manager: Qwen3NextHybridMemManager = None + self._init_triton() + super().__init__(kvargs) + def _init_triton(self): def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch.Tensor: return torch.empty(size, device="cuda", dtype=torch.int8) @@ -48,7 +46,7 @@ def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch # This is required for kernels in qwen3next/triton_kernel/fla/ops/solve_tril.py triton.set_allocator(_triton_allocator) logger.info("Triton allocator set for Qwen3Next model") - super().__init__(kvargs) + return def autotune_layers(self): return self.config["full_attention_interval"] @@ -57,37 +55,38 @@ def _init_config(self): super()._init_config() self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) - def _init_custom(self): - super()._init_custom() - # Only initialize DeepEP group for MoE models with num_experts - if "num_experts" in self.config and self.config["num_experts"] > 0: - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) - def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 start_args: StartArgs = get_env_start_args() - self.num_linear_k_heads = self.config["linear_num_key_heads"] // self.tp_world_size_ - self.num_linear_v_heads = self.config["linear_num_value_heads"] // self.tp_world_size_ - self.head_linear_k_dim = self.config["linear_key_head_dim"] - self.head_linear_v_dim = self.config["linear_value_head_dim"] - conv_kernel_size = self.config["linear_conv_kernel_dim"] ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} - self.mem_manager = Qwen3NextHybridMemManager( - full_attn_cache_size=self.max_total_token_num, - linear_attn_cache_size=start_args.mamba_cache_size, + self.linear_config = LinearAttCacheConfig( + tp_world_size=self.tp_world_size_, + full_att_all_num_kv_heads=self.config["num_key_value_heads"], + full_att_dtype=self.data_type, 
+ full_att_num_kv_heads=self.num_kv_heads, + full_att_head_dim=self.config["head_dim"], + global_linear_k_heads=self.config["linear_num_key_heads"], + global_linear_v_heads=self.config["linear_num_value_heads"], + num_linear_k_heads=max(1, self.config["linear_num_key_heads"] // self.tp_world_size_), + num_linear_v_heads=max(1, self.config["linear_num_value_heads"] // self.tp_world_size_), + head_linear_k_dim=self.config["linear_key_head_dim"], + head_linear_v_dim=self.config["linear_value_head_dim"], + conv_kernel_size=self.config["linear_conv_kernel_dim"], + linear_layer_num=self.config["n_layer"] + - (self.config["n_layer"] // self.config["full_attention_interval"]), + conv_state_dtype=self.data_type, + ssm_state_dtype=ssm_dtype_dict[start_args.linear_att_ssm_data_type], + full_attention_interval=self.config["full_attention_interval"], + all_layer_num=self.config["n_layer"], + ) + + self.mem_manager = Qwen3NextMemManager( + size=self.max_total_token_num, dtype=self.data_type, num_kv_heads=self.num_kv_heads, head_dim=self.config["head_dim"], - layer_num=self.config["n_layer"], - full_attention_interval=self.config["full_attention_interval"], - conv_state_dtype=self.data_type, - ssm_state_dtype=ssm_dtype_dict[start_args.mamba_ssm_data_type], - conv_kernel_size=conv_kernel_size, - num_linear_k_heads=self.num_linear_k_heads, - num_linear_v_heads=self.num_linear_v_heads, - head_linear_k_dim=self.head_linear_k_dim, - head_linear_v_dim=self.head_linear_v_dim, - max_req_num=self.max_req_num, + full_att_layer_num=self.linear_config.all_layer_num - self.linear_config.linear_layer_num, + linear_config=self.linear_config, mem_fraction=self.mem_fraction, ) @@ -99,4 +98,7 @@ def _init_req_manager(self): if self.max_seq_length is not None: create_max_seq_len = max(create_max_seq_len, self.max_seq_length) - self.req_manager = ReqManagerForMamba(self.max_req_num, create_max_seq_len, self.mem_manager) + self.req_manager = ReqManagerForMamba( + self.max_req_num, create_max_seq_len, 
None, linear_config=LinearAttCacheConfig.load_from_args() + ) + return diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py index 22a93a2c99..b0dc41a3c1 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py @@ -54,6 +54,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + stride_q_tok: tl.constexpr, + stride_k_tok: tl.constexpr, + stride_v_tok: tl.constexpr, + stride_a_tok: tl.constexpr, + stride_b_tok: tl.constexpr, stride_init_state_token: tl.constexpr, stride_final_state_token: tl.constexpr, stride_indices_seq: tl.constexpr, @@ -94,15 +99,15 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( o_k = i_k * BK + tl.arange(0, BK) o_v = i_v * BV + tl.arange(0, BV) - p_q = q + (bos * H + i_h) * K + o_k - p_k = k + (bos * H + i_h) * K + o_k - p_v = v + (bos * HV + i_hv) * V + o_v + p_q = q + bos * stride_q_tok + i_h * K + o_k + p_k = k + bos * stride_k_tok + i_h * K + o_k + p_v = v + bos * stride_v_tok + i_hv * V + o_v if FUSE_GATING: # Fused gating: load per-head constants once, compute g/beta inline per token b_A_log = tl.load(A_log + i_hv).to(tl.float32) b_dt_bias = tl.load(dt_bias + i_hv).to(tl.float32) - p_a_raw = a_raw + bos * HV + i_hv - p_b_raw = b_raw + bos * HV + i_hv + p_a_raw = a_raw + bos * stride_a_tok + i_hv + p_b_raw = b_raw + bos * stride_b_tok + i_hv else: if IS_BETA_HEADWISE: p_beta = beta + (bos * HV + i_hv) * V + o_v @@ -193,13 +198,13 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) - p_q += H * K - p_k += H * K + p_q += stride_q_tok + p_k += stride_k_tok p_o += HV * V - p_v += HV * V + p_v += stride_v_tok if FUSE_GATING: - p_a_raw += HV - p_b_raw += HV + p_a_raw += stride_a_tok 
+ p_b_raw += stride_b_tok else: if not IS_KDA: p_g += HV @@ -208,6 +213,34 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_beta += HV * (V if IS_BETA_HEADWISE else 1) +def _ensure_qkv_token_strided(x: torch.Tensor, inner_numel: int): + """Return q/k/v and token stride, copying only when needed.""" + if x is None: + return None, 0 + + # Decode layout must be [tokens, 1, head, dim]. + assert x.shape[1] == 1, "q/k/v must use decode layout [tokens, 1, head, dim]" + + # Packed tail [head, dim] means the last two strides are [dim, 1]. + tail_contiguous = x.stride()[-2:] == (x.shape[-1], 1) + if not tail_contiguous: + x = x.contiguous() + return x, inner_numel + else: + return x, x.stride(0) + + +def _ensure_gate_token_strided(x: torch.Tensor, inner_numel: int): + """Return a_raw/b_raw and token stride, copying only when needed.""" + if x is None: + return None, 0 + # a_raw/b_raw are 2D [tokens, HV]; the tail HV dimension must be packed. + if x.stride(1) != 1: + x = x.contiguous() + return x, inner_numel + return x, x.stride(0) + + def fused_recurrent_gated_delta_rule_fwd( q: torch.Tensor, k: torch.Tensor, @@ -231,7 +264,16 @@ def fused_recurrent_gated_delta_rule_fwd( ) -> tuple[torch.Tensor, torch.Tensor]: B, T, H, K, V = *k.shape, v.shape[-1] HV = v.shape[2] - N = B if cu_seqlens is None else len(cu_seqlens) - 1 + # In LightLLM's Qwen3Next inference path this fused recurrent kernel is + # used only for decode. Prefill/varlen requests are handled by + # chunk_gated_delta_rule, so keep cu_seqlens out of this strided-view path. 
+ assert cu_seqlens is None, "cu_seqlens is not supported by the decode-only fused recurrent kernel" + N = B + q, stride_q_tok = _ensure_qkv_token_strided(q, H * K) + k, stride_k_tok = _ensure_qkv_token_strided(k, H * K) + v, stride_v_tok = _ensure_qkv_token_strided(v, HV * V) + a_raw, stride_a_tok = _ensure_gate_token_strided(a_raw, HV) + b_raw, stride_b_tok = _ensure_gate_token_strided(b_raw, HV) BK = triton.next_power_of_2(K) if T == 1: # Decode path: use larger BV to reduce kernel instances (4 blocks instead of 16) @@ -261,20 +303,23 @@ def fused_recurrent_gated_delta_rule_fwd( stride_init_state_token = initial_state.stride(0) stride_final_state_token = final_state.stride(0) - # Strides for read indices + # Strides for read indices. The kernel advances along a row with `+ i_t` + # (token stride 1), so 2D index tensors must have contiguous rows. if ssm_state_indices is None: stride_indices_seq, stride_indices_tok = 1, 1 elif ssm_state_indices.ndim == 1: stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 else: + assert ssm_state_indices.stride(-1) == 1, "2D ssm_state_indices must have contiguous rows" stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() - # Strides for write indices (if provided) + # Strides for write indices (if provided); same contiguous-row requirement if ssm_state_write_indices is None: stride_write_indices_seq, stride_write_indices_tok = 1, 1 elif ssm_state_write_indices.ndim == 1: stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1 else: + assert ssm_state_write_indices.stride(-1) == 1, "2D ssm_state_write_indices must have contiguous rows" stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride() grid = (NK, NV, N * HV) @@ -305,6 +350,11 @@ def fused_recurrent_gated_delta_rule_fwd( V=V, BK=BK, BV=BV, + stride_q_tok=stride_q_tok, + stride_k_tok=stride_k_tok, + stride_v_tok=stride_v_tok, + stride_a_tok=stride_a_tok, + 
stride_b_tok=stride_b_tok, stride_init_state_token=stride_init_state_token, stride_final_state_token=stride_final_state_token, stride_indices_seq=stride_indices_seq, @@ -348,10 +398,12 @@ def forward( b_raw: torch.Tensor | None = None, out: torch.Tensor | None = None, ): + # q/k/v/a_raw/b_raw may be non-contiguous column views of one projection + # output; the kernel handles them via per-token strides (no copies). o, final_state = fused_recurrent_gated_delta_rule_fwd( - q=q.contiguous(), - k=k.contiguous(), - v=v.contiguous(), + q=q, + k=k, + v=v, g=g.contiguous() if g is not None else None, beta=beta.contiguous() if beta is not None else None, scale=scale, @@ -364,8 +416,8 @@ def forward( use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, A_log=A_log, dt_bias=dt_bias, - a_raw=a_raw.contiguous() if a_raw is not None else None, - b_raw=b_raw.contiguous() if b_raw is not None else None, + a_raw=a_raw, + b_raw=b_raw, out=out, ) @@ -417,8 +469,9 @@ def fused_recurrent_gated_delta_rule( Whether to store the final state in-place to save memory. Default: `True`. cu_seqlens (torch.LongTensor): - Cumulative sequence lengths of shape `[N+1]` used for variable-length training, - consistent with the FlashAttention API. + Must be `None`. In LightLLM this fused recurrent kernel is used only + by the Qwen3Next decode path; prefill/varlen requests use + `chunk_gated_delta_rule`. ssm_state_indices (Optional[torch.Tensor]): Indices to map the input sequences to the initial/final states. 
num_accepted_tokens (Optional[torch.Tensor]): @@ -433,10 +486,9 @@ def fused_recurrent_gated_delta_rule( Examples:: >>> import torch >>> import torch.nn.functional as F - >>> from einops import rearrange >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule - # inputs with equal lengths - >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + # decode inputs + >>> B, T, H, HV, K, V = 4, 1, 4, 8, 512, 512 >>> q = torch.randn(B, T, H, K, device='cuda') >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) >>> v = torch.randn(B, T, HV, V, device='cuda') @@ -447,21 +499,10 @@ def fused_recurrent_gated_delta_rule( q, k, v, g, beta, initial_state=h0, ) - # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required - >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) - # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected - >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, ht_var = fused_gated_recurrent_delta_rule( - q, k, v, g, beta, - initial_state=h0, - cu_seqlens=cu_seqlens - ) """ - if cu_seqlens is not None and q.shape[0] != 1: - raise ValueError( - f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing." - ) + # This wrapper is only used for Qwen3Next decode inference in LightLLM. + # Keep varlen/prefill inputs on chunk_gated_delta_rule instead. 
+ assert cu_seqlens is None, "cu_seqlens is not supported by the decode-only fused recurrent kernel" if scale is None: scale = k.shape[-1] ** -0.5 else: diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py new file mode 100644 index 0000000000..c4efec47ea --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py @@ -0,0 +1,177 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _conv_pack_gdn_decode_kernel( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q_out, + k_out, + v_out, + z_out, + a_out, + b_out, + stride_m_b: tl.constexpr, + stride_m_d: tl.constexpr, + stride_z_b: tl.constexpr, + stride_z_h: tl.constexpr, + stride_z_d: tl.constexpr, + stride_a_b: tl.constexpr, + stride_a_d: tl.constexpr, + stride_b_b: tl.constexpr, + stride_b_d: tl.constexpr, + stride_s_b: tl.constexpr, + stride_s_d: tl.constexpr, + stride_s_w: tl.constexpr, + stride_w_d: tl.constexpr, + stride_w_w: tl.constexpr, + q_dim: tl.constexpr, + k_dim: tl.constexpr, + v_dim: tl.constexpr, + gate_dim: tl.constexpr, + conv_dim: tl.constexpr, + KERNEL_SIZE: tl.constexpr, + HAS_BIAS: tl.constexpr, + APPLY_SILU: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offs = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < conv_dim + state_idx = tl.load(conv_state_indices + row) + + x = tl.load(mixed_qkv + row * stride_m_b + offs * stride_m_d, mask=mask, other=0.0).to(tl.float32) + # KERNEL_SIZE is a constexpr, so Triton fully unrolls these loops for each conv size. 
+ y = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + for i in tl.static_range(0, KERNEL_SIZE - 1): + s = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + i * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + w = tl.load(conv_weight + offs * stride_w_d + i * stride_w_w, mask=mask, other=0.0).to(tl.float32) + y += s * w + + w = tl.load(conv_weight + offs * stride_w_d + (KERNEL_SIZE - 1) * stride_w_w, mask=mask, other=0.0).to(tl.float32) + y += x * w + if HAS_BIAS: + bias = tl.load(conv_bias + offs, mask=mask, other=0.0).to(tl.float32) + y += bias + if APPLY_SILU: + y = y * tl.sigmoid(y) + + for i in tl.static_range(0, KERNEL_SIZE - 2): + next_s = tl.load( + conv_state + state_idx * stride_s_b + offs * stride_s_d + (i + 1) * stride_s_w, mask=mask, other=0.0 + ) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + i * stride_s_w, next_s, mask=mask) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + (KERNEL_SIZE - 2) * stride_s_w, x, mask=mask) + + q_mask = offs < q_dim + k_mask = (offs >= q_dim) & (offs < q_dim + k_dim) + v_mask = (offs >= q_dim + k_dim) & (offs < conv_dim) + tl.store(q_out + row * q_dim + offs, y, mask=q_mask) + tl.store(k_out + row * k_dim + (offs - q_dim), y, mask=k_mask) + tl.store(v_out + row * v_dim + (offs - q_dim - k_dim), y, mask=v_mask) + + z_mask = offs < v_dim + z_vals = tl.load(z_raw + row * stride_z_b + offs, mask=z_mask, other=0.0) + tl.store(z_out + row * v_dim + offs, z_vals, mask=z_mask) + + gate_mask = offs < gate_dim + a_vals = tl.load(a_raw + row * stride_a_b + offs * stride_a_d, mask=gate_mask, other=0.0) + b_vals = tl.load(b_raw + row * stride_b_b + offs * stride_b_d, mask=gate_mask, other=0.0) + tl.store(a_out + row * gate_dim + offs, a_vals, mask=gate_mask) + tl.store(b_out + row * gate_dim + offs, b_vals, mask=gate_mask) + + +@torch.no_grad() +def conv_pack_gdn_decode_inputs( + mixed_qkv: torch.Tensor, + z_raw: torch.Tensor, + a_raw: torch.Tensor, + b_raw: torch.Tensor, + 
conv_state: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + conv_state_indices: torch.Tensor, + activation: str, + conv_size: int, + num_k_heads: int, + head_k_dim: int, + num_v_heads: int, + head_v_dim: int, +): + batch = mixed_qkv.shape[0] + q_dim = num_k_heads * head_k_dim + k_dim = q_dim + v_dim = num_v_heads * head_v_dim + gate_dim = num_v_heads + conv_dim = q_dim + k_dim + v_dim + + assert conv_size >= 2, f"conv kernel size must be at least 2, got {conv_size}" + assert mixed_qkv.shape[1] == conv_dim, f"mixed_qkv shape mismatch: {mixed_qkv.shape[1]} != {conv_dim}" + assert conv_weight.shape[0] == conv_dim, f"conv_weight shape mismatch: {conv_weight.shape[0]} != {conv_dim}" + assert conv_weight.shape[1] == conv_size, f"conv_weight kernel mismatch: {conv_weight.shape[1]} != {conv_size}" + assert conv_state.shape[1] == conv_dim, f"conv_state shape mismatch: {conv_state.shape[1]} != {conv_dim}" + assert ( + conv_state.shape[2] >= conv_size - 1 + ), f"conv_state width must be at least conv_size - 1, got {conv_state.shape[2]} and {conv_size}" + + q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + k = torch.empty_like(q) + v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) + a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) + b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) + + block_size = 256 + grid = (batch, triton.cdiv(conv_dim, block_size)) + _conv_pack_gdn_decode_kernel[grid]( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q, + k, + v, + z, + a, + b, + mixed_qkv.stride(0), + mixed_qkv.stride(1), + z_raw.stride(0), + z_raw.stride(1), + z_raw.stride(2), + a_raw.stride(0), + a_raw.stride(1), + b_raw.stride(0), + b_raw.stride(1), + 
conv_state.stride(0), + conv_state.stride(1), + conv_state.stride(2), + conv_weight.stride(0), + conv_weight.stride(1), + q_dim, + k_dim, + v_dim, + gate_dim, + conv_dim, + conv_size, + HAS_BIAS=conv_bias is not None, + APPLY_SILU=activation in ["silu", "swish"], + BLOCK_SIZE=block_size, + num_warps=8, + ) + return q, k, v, z, a, b diff --git a/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py new file mode 100644 index 0000000000..8b73cfd74d --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py @@ -0,0 +1,50 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _sigmoid_mul_kernel( + x, + gate, + stride_x_m: tl.constexpr, + stride_x_n: tl.constexpr, + stride_g_m: tl.constexpr, + stride_g_n: tl.constexpr, + N: tl.constexpr, + GATE_N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_N) + mask = offs < N + x_ptrs = x + row * stride_x_m + offs * stride_x_n + x_vals = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + if GATE_N == 1: + gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) + else: + gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) + gate_vals = tl.sigmoid(gate_vals) + tl.store(x_ptrs, (x_vals * gate_vals).to(x.dtype.element_ty), mask=mask) + + +@torch.no_grad() +def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + x_arg = x.view(-1, x.shape[-1]) + gate_arg = gate.view(-1, gate.shape[-1]) + assert gate_arg.shape[0] == x_arg.shape[0] and gate_arg.shape[1] in (1, x_arg.shape[1]) + _, n = x_arg.shape + block_n = triton.next_power_of_2(n) + _sigmoid_mul_kernel[(x_arg.shape[0],)]( + x=x_arg, + gate=gate_arg, + stride_x_m=x_arg.stride(0), + stride_x_n=x_arg.stride(1), + stride_g_m=gate_arg.stride(0), + stride_g_n=gate_arg.stride(1), + N=n, + GATE_N=gate_arg.shape[1], + BLOCK_N=block_n, + num_warps=8, 
+ ) + return x diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 9b9fe2569c..fa84ce6376 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -70,11 +70,37 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_start_locs_in_cache, dtype=torch.long, device="cpu", pin_memory=True ).cuda(non_blocking=True) + self._multimodal_emb( + out=out, + input_ids=input_ids, + layer_weight=layer_weight, + embed_cache=cpu_embed_cache_tensor, + img_token_lens=img_token_lens, + img_start_token_ids=img_start_token_ids, + img_start_locs_in_cache=img_start_locs_in_cache, + ) + if self.tp_world_size_ > 1: + all_reduce(out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) + return out + + def _multimodal_emb( + self, + out: torch.Tensor, + input_ids: torch.Tensor, + layer_weight: LlamaPreAndPostLayerWeight, + embed_cache: torch.Tensor, + img_token_lens: torch.Tensor, + img_start_token_ids: torch.Tensor, + img_start_locs_in_cache: torch.Tensor, + ) -> torch.Tensor: + """ + 方便子类继承修改多模态的embed计算的细节实现方式。 + """ multimodal_emb( out=out, prompt_ids=input_ids, text_weight_embs=layer_weight.wte_weight_.weight, - embed_cache=cpu_embed_cache_tensor, + embed_cache=embed_cache, img_token_lens=img_token_lens, img_start_token_ids=img_start_token_ids, img_start_locs_in_cache=img_start_locs_in_cache, @@ -82,6 +108,4 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei tp_text_end_token_id=layer_weight.wte_weight_.tp_vocab_end_id, tp_world_size=self.tp_world_size_, ) - if self.tp_world_size_ > 1: - all_reduce(out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - return out + return diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py index 07a7412020..e12cdb5745 100644 --- 
a/lightllm/models/qwen_vl/qwen_visual.py +++ b/lightllm/models/qwen_vl/qwen_visual.py @@ -427,7 +427,7 @@ def encode(self, image_uuids: List): t = self.image_transform(image_data) img_tensors.append(t) else: - raise Exception("Unsupport input types: {} for {}".format(type(item), item)) + raise Exception("Unsupported input types: {} for {}".format(type(item), item)) valid_ids.append([valid_id, valid_id + 1]) valid_id += 1 diff --git a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py index 55848ce66a..f34619b1f8 100755 --- a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py @@ -1,6 +1,5 @@ import torch from functools import partial -from typing import Tuple from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.stablelm.layer_weights.transformer_layer_weight import StablelmTransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer @@ -21,6 +20,7 @@ def _bind_norm(self): def _get_qkv( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: + input = self._tpsp_allgather(input, infer_state) q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_)) cache_kv = layer_weight.kv_proj.mm( input.view(-1, self.embed_dim_), @@ -32,24 +32,22 @@ def _get_qkv( infer_state.position_sin, self.partial_rotary_factor, ) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) return q, cache_kv - def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: - # TODO - raise Exception("not impl") - def _get_o( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: + if 
infer_state.need_dp_prefill_balance: + input = infer_state._all_to_all_balance_get(data=input) o_tensor = layer_weight.o_proj.mm( input.view(-1, self.tp_o_head_num_ * self.head_dim_), ) + o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) return o_tensor - def _tpsp_get_o(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: - # TODO - raise Exception("not impl") - def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: diff --git a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py index 3e32682ecb..a4347e4a06 100644 --- a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py @@ -25,10 +25,13 @@ def _ffn_norm( def _ffn( self, input, infer_state: LlamaInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight ) -> torch.Tensor: - ffn1_out = layer_weight.up_proj.mm(input.view(-1, self.embed_dim_)) + input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + ffn1_out = layer_weight.up_proj.mm(input) input = None gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh") ffn1_out = None ffn2_out = layer_weight.down_proj.mm(gelu_out) gelu_out = None + ffn2_out = self._tpsp_reduce(input=ffn2_out, infer_state=infer_state) return ffn2_out diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index 9deaf08575..72ff301711 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -187,7 +187,7 @@ def __init__( projector_hidden_act, ) elif projection_head == "auto_map": - raise Exception("Unsupport projection_head auto_map") + raise Exception("Unsupported projection_head auto_map") elif projection_head is None: self.multi_modal_projector = 
lambda x, *args, **kwargs: x self.llm_model_type = text_config["model_type"] @@ -259,7 +259,7 @@ def encode(self, images: List[ImageItem]): img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + raise Exception("Unsupported input types: {} for {}".format(type(img), img)) # must devide merge_length cur_num = img_tensors[-1].shape[0] // (self.merge_size ** 2) diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 13f8e2827f..0befb50166 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -155,7 +155,7 @@ def _init_datatype(self): elif self.data_type in ["fp32", "float32"]: self.data_type = torch.float32 else: - raise ValueError(f"Unsupport datatype {self.data_type}!") + raise ValueError(f"Unsupported datatype {self.data_type}!") @torch.no_grad() def forward(self, pixel_values): @@ -181,7 +181,7 @@ def encode(self, images: List[ImageItem]): t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) img_tensors.append(t) else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + raise Exception("Unsupported input types: {} for {}".format(type(img), img)) cur_num = img.token_num valid_ids.append([valid_id, valid_id + cur_num]) diff --git a/lightllm/models/whisper/whisper_audio.py b/lightllm/models/whisper/whisper_audio.py index aaa29e1c71..b98d9f0e9e 100644 --- a/lightllm/models/whisper/whisper_audio.py +++ b/lightllm/models/whisper/whisper_audio.py @@ -135,6 +135,7 @@ def load_weight(self, weight_dir): def forward(self, audio_values, audio_lens_after_cnn): audio_values = audio_values.to(self.data_type).to(device=self.device) audio_values = audio_values.squeeze(1) + # TODO: Avoid constructing CUDA tensors directly from Python or NumPy data. 
audio_lens_after_cnn = torch.tensor(audio_lens_after_cnn).cuda() max_len_in_batch = torch.max(audio_lens_after_cnn).item() @@ -223,3 +224,6 @@ def encode(self, audio_items: List[AudioItem]): ans_embeds.append(cur_embed) return ans_embeds, audio_items + + def check_long_audio_infer(self): + pass diff --git a/lightllm/server/api_anthropic.py b/lightllm/server/api_anthropic.py new file mode 100644 index 0000000000..60e7f6746f --- /dev/null +++ b/lightllm/server/api_anthropic.py @@ -0,0 +1,597 @@ +"""Anthropic Messages API compatibility layer. + +Translates incoming /v1/messages requests into LightLLM's internal chat +completions pipeline by delegating the hard parts (content-block parsing, +tool schema normalisation, stop-reason mapping) to LiteLLM's adapter. + +The streaming path intercepts the OpenAI-format SSE stream from +chat_completions_impl and re-emits it as the Anthropic event sequence +(message_start, content_block_*, message_delta, message_stop). +""" +from __future__ import annotations + +import uuid +import ujson as json +from http import HTTPStatus +from typing import Any, Dict, Tuple + +from fastapi import Request +from fastapi.responses import JSONResponse, Response + +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +_cached_adapter: Any = None + + +def get_anthropic_messages_adapter() -> Any: + """Return a cached instance of LiteLLM's Anthropic<->OpenAI adapter. + + The returned object exposes ``translate_anthropic_to_openai`` and + ``translate_openai_response_to_anthropic`` methods. + """ + global _cached_adapter + if _cached_adapter is not None: + return _cached_adapter + + try: + from litellm.llms.anthropic.experimental_pass_through.adapters.transformation import ( + LiteLLMAnthropicMessagesAdapter, + ) + except ImportError as exc: + raise RuntimeError( + "The Anthropic Messages API (/v1/messages) requires the 'litellm' package. 
" + "Install it with: pip install 'lightllm[anthropic]' " + "(or directly: pip install 'litellm>=1.52.0,<1.85'). " + f"Original error: {exc}" + ) from exc + + _cached_adapter = LiteLLMAnthropicMessagesAdapter() + return _cached_adapter + + +# --------------------------------------------------------------------------- +# Request translation +# --------------------------------------------------------------------------- + + +def _anthropic_to_chat_request(anthropic_body: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, str]]: + """Translate an Anthropic Messages request body into a dict suitable + for constructing a LightLLM ``ChatCompletionRequest``. + + Returns ``(chat_request_dict, tool_name_mapping)``. The mapping must + be passed back to ``_chat_response_to_anthropic`` so that tool names + truncated by LiteLLM's 64-character limit can be restored. + """ + adapter = get_anthropic_messages_adapter() + + openai_request, tool_name_mapping = adapter.translate_anthropic_to_openai(anthropic_body) + + if hasattr(openai_request, "model_dump"): + openai_dict = openai_request.model_dump(exclude_none=True) + else: + openai_dict = dict(openai_request) + + if "max_tokens" not in openai_dict and "max_completion_tokens" not in openai_dict: + if "max_tokens" in anthropic_body: + openai_dict["max_tokens"] = anthropic_body["max_tokens"] + + # Forward LightLLM-specific fields nested under ``extra_body`` (OpenAI SDK + # convention) so clients hitting /v1/messages can reach ChatCompletionRequest + # options Anthropic's own schema does not expose — notably chat_template_kwargs + # for models with optional thinking modes (Qwen3, DeepSeek). Fields already + # produced by the Anthropic->OpenAI translation take precedence; unknown keys + # are silently dropped by Pydantic (extra='ignore'). 
+ extra_body = anthropic_body.get("extra_body") + if isinstance(extra_body, dict): + for k, v in extra_body.items(): + openai_dict.setdefault(k, v) + + _UNKNOWN_FIELDS = {"extra_body", "metadata", "anthropic_version", "cache_control"} + dropped = [k for k in anthropic_body if k in _UNKNOWN_FIELDS] + if dropped: + logger.debug("Dropping Anthropic-only fields not forwarded to chat pipeline: %s", dropped) + for key in dropped: + openai_dict.pop(key, None) + + return openai_dict, tool_name_mapping + + +# --------------------------------------------------------------------------- +# Response translation +# --------------------------------------------------------------------------- + + +_FINISH_REASON_TO_STOP_REASON = { + "stop": "end_turn", + "length": "max_tokens", + "tool_calls": "tool_use", + None: "end_turn", +} + + +def _chat_response_to_anthropic( + chat_response: Any, + tool_name_mapping: Dict[str, str], + requested_model: str, +) -> Dict[str, Any]: + """Wrap a LightLLM ``ChatCompletionResponse`` into an Anthropic + Messages response dict. + + LiteLLM's ``translate_openai_response_to_anthropic`` requires a + ``litellm.ModelResponse`` object (discovered via Task 3's characterisation + test). We construct one from the LightLLM response's dict form. + """ + adapter = get_anthropic_messages_adapter() + if hasattr(chat_response, "model_dump"): + openai_dict = chat_response.model_dump(exclude_none=True) + else: + openai_dict = dict(chat_response) + + try: + # Lazy import so this module stays importable when litellm is absent. 
+ from litellm import ModelResponse # type: ignore + + model_response = ModelResponse(**openai_dict) + anthropic_obj = adapter.translate_openai_response_to_anthropic(model_response, tool_name_mapping) + except Exception as exc: + logger.warning("LiteLLM response translation failed (%s); using fallback", exc) + return _fallback_openai_to_anthropic(openai_dict, requested_model) + + if hasattr(anthropic_obj, "model_dump"): + result = anthropic_obj.model_dump(exclude_none=True) + else: + result = dict(anthropic_obj) + + return _normalize_anthropic_response(result, requested_model) + + +def _normalize_anthropic_response(result: Dict[str, Any], requested_model: str) -> Dict[str, Any]: + """Cosmetic clean-ups applied to every non-streaming Anthropic response: + + - echo the client-supplied model name (LiteLLM sometimes emits the + upstream model id instead); + - force the Anthropic ``msg_`` id prefix (LiteLLM passes LightLLM's + raw numeric request id through, which confuses strict clients); + - set default ``type`` / ``role`` / ``stop_sequence`` when missing; + - drop empty text blocks (LiteLLM sometimes produces a leading + ``{"type":"text","text":""}`` before a tool_use block); + - strip the LiteLLM-specific ``provider_specific_fields`` leak from + every content block. 
+ """ + result["model"] = requested_model + + if not str(result.get("id", "")).startswith("msg_"): + result["id"] = f"msg_{uuid.uuid4().hex[:24]}" + result.setdefault("type", "message") + result.setdefault("role", "assistant") + result.setdefault("stop_sequence", None) + + cleaned_content = [] + for block in result.get("content") or []: + if not isinstance(block, dict): + cleaned_content.append(block) + continue + if block.get("type") == "text" and not block.get("text"): + continue + block.pop("provider_specific_fields", None) + cleaned_content.append(block) + result["content"] = cleaned_content + + return result + + +def _fallback_openai_to_anthropic(openai_dict: Dict[str, Any], requested_model: str) -> Dict[str, Any]: + """Minimal hand-built OpenAI->Anthropic translation for text-only responses. + + Used only when LiteLLM's adapter raises on the response path. Does + not support tool_use; errors out loudly if tool calls are present + since silently dropping them would corrupt the response. 
+ """ + choice = (openai_dict.get("choices") or [{}])[0] + message = choice.get("message") or {} + if message.get("tool_calls"): + raise RuntimeError("Fallback translator cannot handle tool_calls; LiteLLM adapter path is required.") + text = message.get("content") or "" + usage = openai_dict.get("usage") or {} + finish_reason = choice.get("finish_reason") + return { + "id": f"msg_{uuid.uuid4().hex[:24]}", + "type": "message", + "role": "assistant", + "model": requested_model, + "content": [{"type": "text", "text": text}], + "stop_reason": _FINISH_REASON_TO_STOP_REASON.get(finish_reason, "end_turn"), + "stop_sequence": None, + "usage": { + "input_tokens": int(usage.get("prompt_tokens", 0)), + "output_tokens": int(usage.get("completion_tokens", 0)), + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + }, + } + + +# --------------------------------------------------------------------------- +# Streaming bridge +# --------------------------------------------------------------------------- + + +def _sse_event(event_type: str, data_obj: Dict[str, Any]) -> bytes: + """Encode an Anthropic-style SSE event.""" + return f"event: {event_type}\ndata: {json.dumps(data_obj)}\n\n".encode("utf-8") + + +async def _openai_sse_to_anthropic_events( + openai_body_iterator, + requested_model: str, + message_id: str, +): + """Async generator: consume OpenAI-format SSE bytes and yield + Anthropic-format SSE event bytes. + + Handles both text deltas (emitted as text_delta content blocks) and + tool-call deltas (emitted as tool_use content blocks whose arguments + stream as input_json_delta events). Anthropic's protocol opens one + content block at a time — when switching between a text block and a + tool_use block (or between tool_use blocks) the current block is + closed before the next is opened. + """ + message_started = False + next_content_index = 0 + + # Currently open content block, if any. 
+ # current_open is either None or a tuple ("text"|"tool_use", anthropic_index). + current_open = None + + text_block_index = None # Anthropic index of the active text block. + + # Per-tool-call state keyed by OpenAI streaming tool_calls[i].index. + # Each entry: {anthropic_index, id, name, started, buffered_args} + tool_state: Dict[int, Dict[str, Any]] = {} + + final_stop_reason = "end_turn" + final_output_tokens = 0 + final_input_tokens = 0 + + _OPENAI_TO_ANTHROPIC_STOP = { + "stop": "end_turn", + "length": "max_tokens", + "tool_calls": "tool_use", + } + + async for raw_chunk in openai_body_iterator: + if not raw_chunk: + continue + # chat_completions_impl yields str ("data: {...}\n\n"); some callers or + # middlewares may hand us bytes. Normalise to str so the splitter below + # does not have to branch on type. + if isinstance(raw_chunk, (bytes, bytearray)): + raw_chunk = raw_chunk.decode("utf-8", errors="replace") + # A single StreamingResponse chunk may contain multiple SSE lines. + for line in raw_chunk.split("\n"): + line = line.strip() + if not line or not line.startswith("data: "): + continue + payload = line[len("data: ") :] + if payload == "[DONE]": + continue + try: + chunk = json.loads(payload) + except Exception: + logger.debug("Skipping non-JSON SSE payload: %r", payload) + continue + + # final_output_tokens is sourced exclusively from the trailing usage + # chunk emitted by chat_completions_impl; we intentionally do not + # estimate it per delta because that would diverge from the + # tokenizer-accurate count on any upstream change. + usage = chunk.get("usage") + if usage: + final_input_tokens = int(usage.get("prompt_tokens", 0)) + final_output_tokens = int(usage.get("completion_tokens", final_output_tokens)) + + choices = chunk.get("choices") or [] + if not choices: + continue + choice = choices[0] + delta = choice.get("delta") or {} + finish_reason = choice.get("finish_reason") + + # Emit message_start the first time we see any content. 
+ # NOTE: The upstream usage chunk arrives AFTER all content chunks, so + # final_input_tokens is still 0 here. message_start.message.usage.input_tokens + # will always be 0 on this path — Anthropic clients that care about prompt + # token counts should read message_delta.usage instead. Fixing this would + # require buffering until the usage chunk arrives, trading streaming + # latency for accurate prompt-token reporting at message_start time. + if not message_started: + message_started = True + yield _sse_event( + "message_start", + { + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "model": requested_model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": { + "input_tokens": final_input_tokens, + "output_tokens": 0, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + }, + }, + }, + ) + + # ---- Text delta ---- + content_piece = delta.get("content") + if content_piece: + if current_open is None or current_open[0] != "text": + if current_open is not None: + yield _sse_event( + "content_block_stop", + {"type": "content_block_stop", "index": current_open[1]}, + ) + text_block_index = next_content_index + next_content_index += 1 + current_open = ("text", text_block_index) + yield _sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": text_block_index, + "content_block": {"type": "text", "text": ""}, + }, + ) + yield _sse_event( + "content_block_delta", + { + "type": "content_block_delta", + "index": text_block_index, + "delta": {"type": "text_delta", "text": content_piece}, + }, + ) + + # ---- Tool-call deltas ---- + for tc in delta.get("tool_calls") or []: + tc_idx = tc.get("index", 0) + fn = tc.get("function") or {} + state = tool_state.setdefault( + tc_idx, + { + "anthropic_index": None, + "id": None, + "name": None, + "started": False, + "buffered_args": "", + }, + ) + if tc.get("id"): + state["id"] = tc["id"] + if fn.get("name"): + 
state["name"] = fn["name"] + new_args = fn.get("arguments") or "" + + if not state["started"]: + # Buffer args until we know the tool name (required for + # content_block_start). + state["buffered_args"] += new_args + if not state["name"]: + continue + # Close whatever block is currently open (text or a + # previous tool_use) before opening this one. + if current_open is not None: + yield _sse_event( + "content_block_stop", + {"type": "content_block_stop", "index": current_open[1]}, + ) + state["anthropic_index"] = next_content_index + next_content_index += 1 + current_open = ("tool_use", state["anthropic_index"]) + state["started"] = True + yield _sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": state["anthropic_index"], + "content_block": { + "type": "tool_use", + "id": state["id"] or f"toolu_{uuid.uuid4().hex[:24]}", + "name": state["name"], + "input": {}, + }, + }, + ) + if state["buffered_args"]: + yield _sse_event( + "content_block_delta", + { + "type": "content_block_delta", + "index": state["anthropic_index"], + "delta": { + "type": "input_json_delta", + "partial_json": state["buffered_args"], + }, + }, + ) + state["buffered_args"] = "" + else: + # Already started. A delta for this tool-call index may + # arrive after a later tool-call has opened its own block. + # Anthropic's protocol forbids emitting deltas against a + # non-open index, so close whatever is currently open and + # reopen THIS block before emitting. 
+ if new_args: + if current_open is None or current_open != ("tool_use", state["anthropic_index"]): + if current_open is not None: + yield _sse_event( + "content_block_stop", + {"type": "content_block_stop", "index": current_open[1]}, + ) + current_open = ("tool_use", state["anthropic_index"]) + yield _sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": state["anthropic_index"], + "content_block": { + "type": "tool_use", + "id": state["id"] or f"toolu_{uuid.uuid4().hex[:24]}", + "name": state["name"], + "input": {}, + }, + }, + ) + yield _sse_event( + "content_block_delta", + { + "type": "content_block_delta", + "index": state["anthropic_index"], + "delta": { + "type": "input_json_delta", + "partial_json": new_args, + }, + }, + ) + + if finish_reason: + final_stop_reason = _OPENAI_TO_ANTHROPIC_STOP.get(finish_reason, "end_turn") + + # Close any still-open content block. + if current_open is not None: + yield _sse_event( + "content_block_stop", + {"type": "content_block_stop", "index": current_open[1]}, + ) + + # message_delta carries the final stop_reason and cumulative output_tokens. + if message_started: + yield _sse_event( + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": final_stop_reason, "stop_sequence": None}, + "usage": {"input_tokens": final_input_tokens, "output_tokens": final_output_tokens}, + }, + ) + yield _sse_event("message_stop", {"type": "message_stop"}) + + +# --------------------------------------------------------------------------- +# Error response helper +# --------------------------------------------------------------------------- + + +# HTTP status → Anthropic error type. Derived from +# https://docs.anthropic.com/en/api/errors ; values outside this map fall +# back to "api_error". 
_STATUS_TO_ERROR_TYPE = {
    400: "invalid_request_error",
    401: "authentication_error",
    403: "permission_error",
    404: "not_found_error",
    413: "request_too_large",
    429: "rate_limit_error",
    500: "api_error",
    529: "overloaded_error",
}


def _anthropic_error_response(status: int, message: str) -> JSONResponse:
    """Return an Anthropic-shaped error envelope.

    Anthropic clients (including Claude Code) parse the {"type":"error",
    "error":{"type":..., "message":...}} shape; the OpenAI-style envelope
    from create_error_response hides the real message from them.

    Args:
        status: HTTP status as an ``HTTPStatus`` member or a plain ``int``.
            Plain ints are accepted so codes that ``HTTPStatus`` does not
            define (for example 529, which is in the table above) still work.
        message: Human-readable error text placed in ``error.message``.

    Returns:
        A ``JSONResponse`` carrying the envelope and the same status code.
    """
    code = int(status)
    err_type = _STATUS_TO_ERROR_TYPE.get(code, "api_error")
    return JSONResponse(
        {"type": "error", "error": {"type": err_type, "message": message}},
        status_code=code,
    )


def _rewrap_openai_error_as_anthropic(resp: JSONResponse) -> JSONResponse:
    """Convert an OpenAI-format JSONResponse produced by create_error_response
    into Anthropic's error envelope.

    Best-effort: if the body cannot be decoded, or does not have the expected
    mapping shape, the response is returned unchanged so the caller still
    sees something.
    """
    try:
        body = json.loads(bytes(resp.body).decode("utf-8")) or {}
        inner = body.get("error") or {}
        # Prefer the nested OpenAI-style message; fall back to a flat
        # {"message": ...} body before giving up on a specific message.
        message = inner.get("message") or body.get("message") or "request failed"
    except Exception:
        return resp
    # Pass the raw integer through: HTTPStatus(resp.status_code) would raise
    # ValueError for codes the enum does not define.
    return _anthropic_error_response(resp.status_code, message)


# ---------------------------------------------------------------------------
# HTTP entry point
# ---------------------------------------------------------------------------


async def anthropic_messages_impl(raw_request: Request) -> Response:
    """Serve an Anthropic Messages request by delegating to the OpenAI
    chat-completions implementation.

    Steps: parse the JSON body, translate it with
    ``_anthropic_to_chat_request``, build a ``ChatCompletionRequest``, call
    ``chat_completions_impl``, then translate the result back (a streamed
    SSE bridge when the client asked for ``stream``, otherwise a single
    JSON document).

    Args:
        raw_request: The incoming request; its JSON body is the Anthropic
            payload.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) for streaming
        requests, a ``JSONResponse`` otherwise. Failures detected here are
        returned as Anthropic-shaped error envelopes; a downstream response
        that is neither the expected type nor a ``JSONResponse`` is passed
        through unchanged.
    """
    # Lazy imports to avoid pulling in heavy server deps at module import time.
    from .api_models import ChatCompletionRequest, ChatCompletionResponse
    from .api_openai import chat_completions_impl

    try:
        raw_body = await raw_request.json()
    except Exception as exc:
        return _anthropic_error_response(HTTPStatus.BAD_REQUEST, f"Invalid JSON body: {exc}")

    if not isinstance(raw_body, dict):
        return _anthropic_error_response(HTTPStatus.BAD_REQUEST, "Request body must be a JSON object")

    requested_model = raw_body.get("model", "default")
    is_stream = bool(raw_body.get("stream"))

    try:
        chat_dict, tool_name_mapping = _anthropic_to_chat_request(raw_body)
    except Exception as exc:
        logger.exception("Failed to translate Anthropic request")
        return _anthropic_error_response(HTTPStatus.BAD_REQUEST, f"Request translation failed: {exc}")

    # Force the downstream path to stream if the client asked for stream.
    chat_dict["stream"] = is_stream

    try:
        chat_request = ChatCompletionRequest(**chat_dict)
    except Exception as exc:
        logger.exception("Failed to build ChatCompletionRequest")
        return _anthropic_error_response(HTTPStatus.BAD_REQUEST, f"Invalid request after translation: {exc}")

    try:
        downstream = await chat_completions_impl(chat_request, raw_request)
    except ValueError as exc:
        # The OpenAI chat/completions routes report ValueError as a 400;
        # do the same here, in the envelope Anthropic clients can parse.
        return _anthropic_error_response(HTTPStatus.BAD_REQUEST, str(exc))

    if is_stream:
        from fastapi.responses import StreamingResponse

        if not isinstance(downstream, StreamingResponse):
            # chat_completions_impl returned an OpenAI-format error — rewrap it.
            if isinstance(downstream, JSONResponse):
                return _rewrap_openai_error_as_anthropic(downstream)
            return downstream

        message_id = f"msg_{uuid.uuid4().hex[:24]}"
        anthropic_stream = _openai_sse_to_anthropic_events(
            downstream.body_iterator, requested_model=requested_model, message_id=message_id
        )
        return StreamingResponse(anthropic_stream, media_type="text/event-stream")

    if not isinstance(downstream, ChatCompletionResponse):
        if isinstance(downstream, JSONResponse):
            return _rewrap_openai_error_as_anthropic(downstream)
        return downstream

    try:
        anthropic_dict = _chat_response_to_anthropic(downstream, tool_name_mapping, requested_model)
    except Exception as exc:
        # logger.exception keeps the traceback, matching the two handlers above.
        logger.exception("Failed to translate response to Anthropic format: %s", exc)
        return _anthropic_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, str(exc))
    return JSONResponse(anthropic_dict)
+ personal: private personal running mode, automatically sets running_max_req_size to 3.""", + ) parser.add_argument("--host", type=str, default="127.0.0.1") parser.add_argument("--port", type=int, default=8000) parser.add_argument("--httpserver_workers", type=int, default=1) @@ -45,12 +53,6 @@ def make_argument_parser() -> argparse.ArgumentParser: default=1212, help="when run_mode set to prefill or decode, you need set this pd_mater_port", ) - parser.add_argument( - "--pd_decode_rpyc_port", - type=int, - default=None, - help="p d mode, decode node used for kv move manager rpyc server port", - ) parser.add_argument( "--select_p_d_node_strategy", type=str, @@ -79,17 +81,17 @@ def make_argument_parser() -> argparse.ArgumentParser: proxy module use config server to find remote vit infer nodes to infer img""", ) parser.add_argument( - "--nixl_pd_kv_page_num", + "--pd_kv_page_num", type=int, default=16, - help="nixl pd mode, kv move page_num", + help="pd mode, kv move page_num", ) parser.add_argument( - "--nixl_pd_kv_page_size", + "--pd_kv_page_size", type=int, default=1024, - help="nixl pd mode, kv page size.", + help="pd mode, kv page size.", ) parser.add_argument( @@ -135,8 +137,8 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--mem_fraction", type=float, - default=0.9, - help="""Memory usage ratio, default is 0.9, you can specify a smaller value if OOM occurs at runtime. + default=0.8, + help="""Memory usage ratio, default is 0.8, you can specify a smaller value if OOM occurs at runtime. 
If max_total_token_num is not specified, it will be calculated automatically based on this value.""", ) parser.add_argument( @@ -183,6 +185,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "step3", "nano_v3", "interns1", + "gemma4", ], default=None, help="reasoning parser type", @@ -236,8 +239,10 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--max_req_total_len", type=int, - default=16384, + default=None, help="Maximum allowed length for a request (input tokens + output tokens). " + "If None, it will be automatically derived from the model config.json, " + "and fall back to 16384 if derivation fails. " "In PD (Prefill-Decode) mode, this value must be synchronized across the " "PD master, prefill, and decode nodes.", ) @@ -300,6 +305,12 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""aggressive schedule can lead to frequent prefill interruptions during decode. disabling it allows the router_max_wait_tokens parameter to work more effectively.""", ) + parser.add_argument( + "--enable_prefill_decode_mixed", + action="store_true", + help="""when run_mode is normal, allow prefill and decode requests to run in the same + scheduling step when both exist, improving throughput under aggressive schedule.""", + ) parser.add_argument( "--use_dynamic_prompt_cache", action="store_true", help="This argument is deprecated and no longer in use." @@ -339,8 +350,16 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service." 
) - parser.add_argument("--disable_custom_allreduce", action="store_true", help="Whether to disable cutom allreduce.") - parser.add_argument("--enable_custom_allgather", action="store_true", help="Whether to enable cutom allgather.") + parser.add_argument( + "--disable_symm_mem_allreduce", + action="store_true", + help="Disable the default SymmMem all-reduce fast path and fall back to NCCL.", + ) + parser.add_argument( + "--disable_flashinfer_allreduce", + action="store_true", + help="Disable the default FlashInfer all-reduce fast path and fall back to SymmMem / NCCL.", + ) parser.add_argument( "--enable_tpsp_mix_mode", action="store_true", @@ -382,7 +401,7 @@ def make_argument_parser() -> argparse.ArgumentParser: default=["auto"], help="""decode attention kernel used in llm. auto: automatically select best backend based on GPU and available packages - (priority: fa3 > flashinfer > triton)""", + (priority: flashinfer > fa3 > triton)""", ) parser.add_argument( "--vit_att_backend", @@ -422,6 +441,18 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources" ) + parser.add_argument( + "--max_image_token_count", + type=int, + default=8192, + help="maximum allowed token count for one image after tokenization", + ) + parser.add_argument( + "--max_image_pixels", + type=int, + default=8294400, + help="maximum allowed pixel count for one image before resize preprocessing", + ) parser.add_argument( "--embed_cache_storage_size", type=float, @@ -439,16 +470,6 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument("--use_reward_model", action="store_true", help="use reward model") - parser.add_argument( - "--long_truncation_mode", - type=str, - choices=[None, "head", "center"], - default=None, - help="""use to select the handle way when input_token_len + max_new_tokens > max_req_total_len. 
- None : raise Exception - head : remove some head tokens to make input_token_len + max_new_tokens <= max_req_total_len - center : remove some tokens in center loc to make input_token_len + max_new_tokens <= max_req_total_len""", - ) parser.add_argument("--use_tgi_api", action="store_true", help="use tgi input and ouput format") parser.add_argument( "--health_monitor", action="store_true", help="check the health of service and restart when error" @@ -538,7 +559,10 @@ def make_argument_parser() -> argparse.ArgumentParser: " currently only for llama and qwen model, not support ep moe model", ) parser.add_argument( - "--prefll_cudagraph_max_handle_token", type=int, default=512, help="max handle token num for prefill cudagraph" + "--prefill_cudagraph_max_handle_token", + type=int, + default=8192, + help="max handle token num for prefill cudagraph", ) parser.add_argument( @@ -591,6 +615,14 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""Path of quantization config. It can be used for mixed quantization. Examples can be found in test/advanced_config/mixed_quantization/llamacls-mix-down.yaml.""", ) + parser.add_argument( + "--expert_dtype", + type=str, + default=None, + choices=["fp8", "fp4"], + help="""Expert quantization dtype for EP MoE. Supported values are + fp8 and fp4. Note that fp4 is only supported on SM100 GPUs.""", + ) parser.add_argument( "--vit_quant_type", type=str, @@ -607,10 +639,10 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--sampling_backend", type=str, - choices=["triton", "sglang_kernel"], + choices=["triton", "flashinfer"], default="triton", help="""sampling used impl. 
'triton' is use torch and triton kernel, - sglang_kernel use sglang_kernel impl""", + flashinfer use flashinfer sampling impl""", ) parser.add_argument( "--penalty_counter_mode", @@ -651,7 +683,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--enable_fused_shared_experts", action="store_true", - help="""Whether to enable fused shared experts for deepseekv3 model. only work when tensor parallelism""", + help="""Whether to enable fused shared experts for supported MoE models. It is auto-enabled when supported.""", ) parser.add_argument( "--mtp_mode", @@ -717,7 +749,9 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--enable_cpu_cache", action="store_true", - help="""enable cpu cache to store kv cache. prefer to use hugepages for better performance.""", + help="""enable cpu cache to store kv cache. prefer to use hugepages for better performance. + For linear attention cache reuse constraints, cpu cache token page size will be forced to + linear_att_page_block_num * linear_att_hash_page_size when cpu cache is enabled.""", ) parser.add_argument( "--cpu_cache_storage_size", @@ -748,31 +782,45 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""", ) parser.add_argument( - "--mamba_cache_size", + "--linear_att_hash_page_size", type=int, - default=None, - help="""The size of linear attn cache. If not specified, will be calculated - automatically based on mamba_cache_ratio or max_total_token_num.""", + default=512, + help="""The hash page size for linear attention. 
+ It controls the number of tokens in each hash bucket, which can affect radix cache reused""", ) parser.add_argument( - "--mamba_cache_ratio", - type=lambda v: float(v) - if 0.0 <= (_ := float(v)) <= 1.0 - else (_ for _ in ()).throw( - argparse.ArgumentTypeError(f"--mamba_cache_ratio must be between 0.0 and 1.0, got {v}") - ), - default=0.5, - help="""Ratio of mamba cache to total cache memory (mamba + KV). - Only effective when both mamba_cache_size and max_total_token_num are not set. - Default is 0.5 (50%% mamba cache, 50%% KV cache). - Example: 0.3 -> 30%% mamba, 70%% KV; 0.7 -> 70%% mamba, 30%% KV.""", + "--linear_att_page_block_num", + type=int, + default=10000000, + help="""The number of blocks for linear attention state storage. + It controls the number of pages used for storing the attention state, + which can affect memory usage and mutiturn chat performance. + Block size is linear_att_page_block_num * linear_att_hash_page_size. + When this value multiplied by linear_att_hash_page_size is greater than max_req_total_len, + block-level matching in radix cache is effectively disabled and request-level small-page + matching (linear_att_hash_page_size) may dominate.""", ) parser.add_argument( - "--mamba_ssm_data_type", + "--linear_att_cache_size", + type=int, + default=None, + help="""The size of linear attn cache. + If radix cache hit rate is low under high load due to limited small-page capacity and LRU + eviction, increasing linear_att_cache_size can improve hit rate at the cost of more memory.""", + ) + parser.add_argument( + "--linear_att_ssm_data_type", type=str, choices=["bfloat16", "float32"], default="float32", - help="the data type of the model weight", + help="the data type of linear att smm data type", + ) + parser.add_argument( + "--disable_linear_att_small_page_cpu_cache", + action="store_true", + default=False, + help="""Disable storing linear attention small page data in CPU cache. 
# ANSI colour per HTTP status class (status // 100) used by the access log.
_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
_ACCESS_LOG_RESET = "\033[0m"


class _AccessLogMiddleware:
    """ASGI middleware that logs one ``METHOD PATH STATUS`` line per HTTP request.

    Websocket scopes are passed through the same wrapper but produce no log
    line; any other scope type is forwarded to the wrapped app untouched.
    """

    def __init__(self, app):
        # The wrapped ASGI application.
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        # Stays 0 when the app never sends "http.response.start".
        status_holder = {"status": 0}

        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                status_holder["status"] = message["status"]
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            # Log in `finally` so a request whose handler raised is still recorded.
            if scope["type"] == "http":
                status = status_holder["status"]
                msg = f"{scope['method']} {scope['path']} {status}"
                color = _ACCESS_LOG_STATUS_COLORS.get(status // 100, "")
                if color:
                    msg = color + msg + _ACCESS_LOG_RESET
                logger.info(msg)


app.add_middleware(_AccessLogMiddleware)


def create_error_response(
    status_code: HTTPStatus, message: str, err_type: str = None, param: str = None
) -> JSONResponse:
    """Build an error response in the ``{"error": {...}}`` envelope and count the failure.

    Args:
        status_code: HTTP status to return; its integer value is also
            written to ``error.code``.
        message: Human-readable error text.
        err_type: Value for ``error.type``. When omitted it is derived from
            the status: "InternalServerError" for 5xx, "NotFoundError" for
            404, otherwise "BadRequestError".
        param: Optional value for ``error.param``.

    Returns:
        A ``JSONResponse`` with the envelope and ``status_code.value``.
    """
    if err_type is None:
        if status_code.value >= 500:
            err_type = "InternalServerError"
        elif status_code == HTTPStatus.NOT_FOUND:
            err_type = "NotFoundError"
        else:
            err_type = "BadRequestError"

    g_objs.metric_client.counter_inc("lightllm_request_failure")
    return JSONResponse(
        {"error": {"message": message, "type": err_type, "param": param, "code": status_code.value}},
        status_code=status_code.value,
    )
{"message": "Ok" if health_obj.is_health() else "Error"}, status_code=200 if health_obj.is_health() else 503 + {"message": "Ok" if is_healthy else "Error"}, + status_code=200 if is_healthy else 503, ) @@ -184,7 +231,7 @@ async def token_load(request: Request): @app.post("/generate") async def generate(request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -194,6 +241,11 @@ async def generate(request: Request) -> Response: except ServerBusyError as e: logger.error("%s", str(e), exc_info=True) return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e)) + except ValueError as e: + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + except ClientDisconnected as e: + logger.warning(str(e)) + return Response(status_code=499) except Exception as e: logger.error("An error occurred: %s", str(e), exc_info=True) return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) @@ -201,7 +253,7 @@ async def generate(request: Request) -> Response: @app.post("/generate_stream") async def generate_stream(request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -211,6 +263,11 @@ async def generate_stream(request: Request) -> Response: except ServerBusyError as e: logger.error("%s", str(e), exc_info=True) return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e)) + except ValueError as e: + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + except ClientDisconnected as e: + logger.warning(str(e)) + return Response(status_code=499) except Exception 
as e: logger.error("An error occurred: %s", str(e), exc_info=True) return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) @@ -218,20 +275,23 @@ async def generate_stream(request: Request) -> Response: @app.post("/get_score") async def get_score(request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) try: return await lightllm_get_score(request, g_objs.httpserver_manager) + except ClientDisconnected as e: + logger.warning(str(e)) + return Response(status_code=499) except Exception as e: return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) @app.post("/") async def compat_generate(request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -246,31 +306,58 @@ async def compat_generate(request: Request) -> Response: @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) - resp = await chat_completions_impl(request, raw_request) + try: + resp = await chat_completions_impl(request, raw_request) + except ValueError as e: + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + except ClientDisconnected as e: + logger.warning(str(e)) + return Response(status_code=499) return 
resp @app.post("/v1/completions", response_model=CompletionResponse) async def completions(request: CompletionRequest, raw_request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) - resp = await completions_impl(request, raw_request) + try: + resp = await completions_impl(request, raw_request) + except ValueError as e: + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + except ClientDisconnected as e: + logger.warning(str(e)) + return Response(status_code=499) return resp +@app.post("/v1/messages") +async def anthropic_messages(raw_request: Request) -> Response: + if get_env_start_args().run_mode in ["prefill", "decode"]: + return create_error_response( + HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" + ) + from .api_anthropic import anthropic_messages_impl + + try: + return await anthropic_messages_impl(raw_request) + except ClientDisconnected as e: + logger.warning(str(e)) + return Response(status_code=499) + + @app.get("/v1/models", response_model=ModelListResponse) -@app.post("/v1/models", response_model=ModelListResponse) async def get_models(raw_request: Request): model_name = g_objs.args.model_name - max_model_len = g_objs.args.max_req_total_len + max_model_len = g_objs.httpserver_manager.get_real_supported_max_req_total_len() + if model_name == "default_model_name" and g_objs.args.model_dir: model_name = os.path.basename(g_objs.args.model_dir.rstrip("/")) @@ -280,7 +367,7 @@ async def get_models(raw_request: Request): id=model_name, created=g_objs.model_created, max_model_len=max_model_len, - owned_by=g_objs.args.model_owner, + owned_by=g_objs.args.model_owner or "lightllm", ) ] ) @@ -309,6 +396,9 @@ async def tokens(request: Request): }, status_code=200, ) + 
except ClientDisconnected as e: + logger.warning(str(e)) + return Response(status_code=499) except Exception as e: return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") @@ -365,6 +455,24 @@ async def kv_move_status(websocket: WebSocket): return +@app.get("/profiler_start") +async def profiler_start() -> Response: + if g_objs.args.enable_profiling: + await g_objs.httpserver_manager.profiler_cmd("start") + return JSONResponse({"status": "ok"}) + else: + return JSONResponse({"message": "Profiling support not enabled"}, status_code=400) + + +@app.get("/profiler_stop") +async def profiler_stop() -> Response: + if g_objs.args.enable_profiling: + await g_objs.httpserver_manager.profiler_cmd("stop") + return JSONResponse({"status": "ok"}) + else: + return JSONResponse({"message": "Profiling support not enabled"}, status_code=400) + + @app.on_event("shutdown") async def shutdown(): logger.info("Received signal to shutdown. Performing graceful shutdown...") diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f54..39a5808aab 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -148,5 +148,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]: yield ("data:" + json.dumps(ret, ensure_ascii=False) + "\n\n").encode("utf-8") + from .api_openai import _safe_stream_wrapper + background_tasks = BackgroundTasks() - return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) + return StreamingResponse( + _safe_stream_wrapper(stream_results()), media_type="text/event-stream", background=background_tasks + ) diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index ed4e053f79..1737d2774d 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -33,7 +33,6 @@ class Function(BaseModel): name: Optional[str] = None description: Optional[str] = Field(default=None, examples=[None]) parameters: 
Optional[dict] = None - response: Optional[dict] = None class Tool(BaseModel): @@ -96,6 +95,7 @@ class ChatCompletionMessageGenericParam(BaseModel): content: Union[str, List[MessageContent], None] = Field(default=None) tool_call_id: Optional[str] = None name: Optional[str] = None + reasoning: Optional[str] = None reasoning_content: Optional[str] = None tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) @@ -121,7 +121,7 @@ class CompletionRequest(BaseModel): prompt: Union[str, List[str], List[int], List[List[int]]] suffix: Optional[str] = None max_tokens: Optional[int] = Field( - default=16384, deprecated="max_tokens is deprecated, please use max_completion_tokens instead" + default=65536, deprecated="max_tokens is deprecated, please use max_completion_tokens instead" ) max_completion_tokens: Optional[int] = None temperature: Optional[float] = 1.0 @@ -197,7 +197,7 @@ class ChatCompletionRequest(BaseModel): stream_options: Optional[StreamOptions] = None stop: Optional[Union[str, List[str]]] = None max_tokens: Optional[int] = Field( - default=16384, deprecated="max_tokens is deprecated, please use max_completion_tokens instead" + default=65536, deprecated="max_tokens is deprecated, please use max_completion_tokens instead" ) max_completion_tokens: Optional[int] = None presence_penalty: Optional[float] = 0.0 @@ -221,6 +221,7 @@ class ChatCompletionRequest(BaseModel): parallel_tool_calls: Optional[bool] = True # OpenAI parameters for reasoning and others + reasoning_effort: Optional[Literal["low", "medium", "high"]] = None chat_template_kwargs: Optional[Dict] = None separate_reasoning: Optional[bool] = True stream_reasoning: Optional[bool] = False @@ -278,15 +279,22 @@ def sync_thinking_chat_template_kwargs(self): return self +class PromptTokensDetails(BaseModel): + cached_tokens: int = 0 + audio_tokens: int = 0 + + class UsageInfo(BaseModel): prompt_tokens: int = 0 completion_tokens: Optional[int] = 0 total_tokens: int = 0 + 
prompt_tokens_details: Optional[PromptTokensDetails] = None class ChatMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + reasoning: Optional[str] = None reasoning_content: Optional[str] = None tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) @@ -314,6 +322,7 @@ class DeltaMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) + reasoning: Optional[str] = None reasoning_content: Optional[str] = None diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 1a17691a95..0d934c44c9 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -30,6 +30,7 @@ from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster from .api_lightllm import lightllm_get_score from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size +from lightllm.utils.error_utils import ClientDisconnected from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient @@ -47,6 +48,7 @@ FunctionResponse, ToolCall, UsageInfo, + PromptTokensDetails, ChatMessage, ChatCompletionResponseChoice, ChatCompletionResponse, @@ -58,11 +60,51 @@ logger = init_logger(__name__) -def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse: +async def _safe_stream_wrapper(stream_generator): + """Wrap a streaming generator to catch ValueError (e.g. 
input too long) and yield an SSE error + event instead of letting the exception propagate to Starlette which prints a long traceback.""" + try: + async for item in stream_generator: + yield item + except ValueError as e: + error_data = json.dumps({"error": {"message": str(e), "type": "invalid_request_error"}}, ensure_ascii=False) + yield f"data: {error_data}\n\n" + except ClientDisconnected as e: + logger.warning(str(e)) + # Client is gone — there's no point yielding more SSE chunks. Stop quietly. + return + + +def _serialize_sse_chunk(chunk, choice_nulls=(), response_nulls=()): + """Serialize a streaming chunk, explicitly including specified null fields.""" + d = chunk.model_dump(exclude_none=True) + if choice_nulls and d.get("choices"): + for choice in d["choices"]: + for field in choice_nulls: + choice[field] = None + for field in response_nulls: + d[field] = None + return json.dumps(d, ensure_ascii=False) + + +def create_error_response( + status_code: HTTPStatus, message: str, err_type: str = None, param: str = None +) -> JSONResponse: from .api_http import g_objs + if err_type is None: + if status_code.value >= 500: + err_type = "InternalServerError" + elif status_code == HTTPStatus.NOT_FOUND: + err_type = "NotFoundError" + else: + err_type = "BadRequestError" + g_objs.metric_client.counter_inc("lightllm_request_failure") - return JSONResponse({"message": message}, status_code=status_code.value) + return JSONResponse( + {"error": {"message": message, "type": err_type, "param": param, "code": status_code.value}}, + status_code=status_code.value, + ) def _process_tool_call_id( @@ -111,15 +153,22 @@ def _get_history_tool_calls_cnt(request: ChatCompletionRequest) -> int: return idx -def _get_reasoning_from_request(request: ChatCompletionRequest) -> bool: - """Judge whether the request needs reasoning""" +def _is_force_thinking_mode(request: ChatCompletionRequest) -> bool: + """Whether this request uses forced thinking / reasoning (parser + template).""" + from 
.build_prompt import tokenizer_supports_force_thinking + + if not tokenizer_supports_force_thinking(): + return False + reasoning_parser = get_env_start_args().reasoning_parser if not reasoning_parser: return False + if reasoning_parser in ["qwen3-thinking", "gpt-oss", "minimax"]: + return True if reasoning_parser in ["deepseek-v3"]: return request.chat_template_kwargs is not None and request.chat_template_kwargs.get("thinking") is True - if reasoning_parser in ["qwen3", "glm45", "nano_v3", "interns1"]: - # qwen3, glm45, nano_v3, and interns1 are reasoning by default + if reasoning_parser in ["qwen3", "glm45", "nano_v3", "interns1", "gemma4"]: + # qwen3, glm45, nano_v3, interns1, and gemma4 are reasoning by default; return not request.chat_template_kwargs or request.chat_template_kwargs.get("enable_thinking", True) is True return True # default @@ -133,7 +182,7 @@ def _process_reasoning_stream( ) -> tuple[Optional[str], str]: """Process reasoning content in streaming response""" if index not in reasoning_parser_dict: - request_enable_reasoning = _get_reasoning_from_request(request) + request_enable_reasoning = _is_force_thinking_mode(request) reasoning_parser_dict[index] = ReasoningParser( get_env_start_args().reasoning_parser, request.stream_reasoning, @@ -237,14 +286,20 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req tools = None if request.tools and request.tool_choice != "none": # request.skip_special_tokens = False + # exclude_none=True so optional default-None fields (e.g. ``response``) + # don't surface in the chat-template render — Function.model_dump() + # otherwise emits {"response": None}, which chat.jinja's + # render_extra_keys turns into ``null`` and adds + # ~7 tokens per tool, drifting prompts away from other engines/clients + # that pass tools without that field. 
if not isinstance(request.tool_choice, str): tools = [ - item.function.model_dump() + item.function.model_dump(exclude_none=True) for item in request.tools if item.function.name == request.tool_choice.function.name ] else: - tools = [item.function.model_dump() for item in request.tools] + tools = [item.function.model_dump(exclude_none=True) for item in request.tools] prompt = await build_prompt(request, tools) sampling_params_dict = { @@ -262,6 +317,16 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req "seed": request.seed, } + # Gemma-4's reasoning delimiters (<|channel>=100, =101) are + # special tokens. The default skip_special_tokens=True would drop them + # from the decoded stream and the Gemma4Detector would be unable to + # find the reasoning boundary. Mirrors vllm's + # Gemma4ReasoningParser.adjust_request behaviour. Only applied when no + # explicit value is supplied so callers can still opt back into the + # default if they want. + if get_env_start_args().reasoning_parser == "gemma4" and "skip_special_tokens" not in sampling_params_dict: + sampling_params_dict["skip_special_tokens"] = False + if request.max_completion_tokens is not None: sampling_params_dict["max_new_tokens"] = request.max_completion_tokens elif request.max_tokens is not None: @@ -295,6 +360,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req count_output_tokens_dict = collections.defaultdict(lambda: 0) finish_reason_dict = {} prompt_tokens_dict = {} + prompt_cache_len_dict = {} completion_tokens = 0 async for sub_req_id, request_output, metadata, finish_status in results_generator: from .req_id_generator import convert_sub_id_to_group_id @@ -305,27 +371,29 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status.get_finish_reason() prompt_tokens_dict[sub_req_id] = metadata["prompt_tokens"] + 
prompt_cache_len_dict[sub_req_id] = metadata.get("prompt_cache_len", 0) choices = [] sub_ids = list(final_output_dict.keys())[: request.n] for i in range(request.n): sub_req_id = sub_ids[i] prompt_tokens = prompt_tokens_dict[sub_req_id] completion_tokens = count_output_tokens_dict[sub_req_id] + cached_tokens = prompt_cache_len_dict.get(sub_req_id, 0) usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, + prompt_tokens_details=PromptTokensDetails(cached_tokens=cached_tokens), ) finish_reason = finish_reason_dict[sub_req_id] text = "".join(final_output_dict[sub_req_id]) - full_text = text # Handle reasoning content reasoning_text = None reasoning_parser = get_env_start_args().reasoning_parser - if reasoning_parser and request.separate_reasoning: - request_enable_reasoning = _get_reasoning_from_request(request) + if reasoning_parser: + request_enable_reasoning = _is_force_thinking_mode(request) try: parser = ReasoningParser( model_type=reasoning_parser, @@ -339,17 +407,20 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req HTTPStatus.BAD_REQUEST, "Failed to parse fc related info to json format!", ) + if not request.separate_reasoning: + text = (reasoning_text or "") + (text or "") + reasoning_text = None # Handle tool_calls parsing tool_calls = None tool_choice = request.tool_choice tools = request.tools - if tool_choice != "none" and any([i in full_text for i in TOOLS_TAG_LIST]): + if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]): try: # 为 tool_call_parser 提供默认值 tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3" parser = FunctionCallParser(tools, tool_parser) - full_normal_text, call_info_list = parser.parse_non_stream(full_text) + text, call_info_list = parser.parse_non_stream(text) tool_calls = [] history_tool_calls_cnt = _get_history_tool_calls_cnt(request) for call_info in call_info_list: @@ -370,12 +441,11 @@ 
async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req ) if tool_calls and finish_reason == "stop": finish_reason = "tool_calls" - text = "" chat_message = ChatMessage( role="assistant", content=text if text else "", tool_calls=tool_calls, - reasoning_content=reasoning_text if reasoning_text else "", + reasoning=reasoning_text if reasoning_text else "", ) choice = ChatCompletionResponseChoice( index=i, @@ -391,16 +461,28 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req parser_dict = {} reasoning_parser_dict = {} + # Pre-generate a UUID-style request ID (matching the 36888 service format) + chat_completion_id = f"chatcmpl-{uuid.uuid4().hex}" + + # Common null fields to include in every streamed choice chunk + _choice_nulls = ("logprobs", "token_ids", "finish_reason") + _first_choice_nulls = ("logprobs", "finish_reason") + _final_choice_nulls = ("logprobs", "token_ids", "stop_reason") + _first_resp_nulls = ("prompt_token_ids",) + # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: has_emitted_tool_calls: Dict[int, bool] = collections.defaultdict(bool) + has_emitted_first_chunk: Dict[int, bool] = collections.defaultdict(bool) stream_tool_call_ids: Dict[Tuple[int, int], str] = {} from .req_id_generator import convert_sub_id_to_group_id prompt_tokens = 0 completion_tokens = 0 + cached_tokens = 0 async for sub_req_id, request_output, metadata, finish_status in results_generator: prompt_tokens = metadata["prompt_tokens"] + cached_tokens = metadata.get("prompt_cache_len", 0) completion_tokens += 1 group_request_id = convert_sub_id_to_group_id(sub_req_id) choice_index = sub_req_id - group_request_id @@ -408,24 +490,44 @@ async def stream_results() -> AsyncGenerator[bytes, None]: delta = request_output current_finish_reason = finish_status.get_finish_reason() + # Emit the initial role-only chunk once per choice, as required by the + # OpenAI SSE spec: role appears only in the first delta 
with content="". + if not has_emitted_first_chunk[choice_index]: + has_emitted_first_chunk[choice_index] = True + first_choice = ChatCompletionStreamResponseChoice( + index=choice_index, + delta=DeltaMessage(role="assistant", content=""), + finish_reason=None, + ) + first_chunk = ChatCompletionStreamResponse( + id=chat_completion_id, + created=created_time, + model=request.model, + choices=[first_choice], + ) + yield f"data: {_serialize_sse_chunk(first_chunk, _first_choice_nulls, _first_resp_nulls)}\n\n" + # Handle reasoning content - if get_env_start_args().reasoning_parser and request.separate_reasoning: + if get_env_start_args().reasoning_parser: reasoning_text, delta = _process_reasoning_stream( choice_index, delta, reasoning_parser_dict, request_output, request ) if reasoning_text: - choice_data = ChatCompletionStreamResponseChoice( - index=choice_index, - delta=DeltaMessage(role="assistant", reasoning_content=reasoning_text), - finish_reason=None, - ) - chunk = ChatCompletionStreamResponse( - id=group_request_id, - created=created_time, - choices=[choice_data], - model=request.model, - ) - yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + if request.separate_reasoning: + choice_data = ChatCompletionStreamResponseChoice( + index=choice_index, + delta=DeltaMessage(reasoning=reasoning_text), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=chat_completion_id, + created=created_time, + choices=[choice_data], + model=request.model, + ) + yield f"data: {_serialize_sse_chunk(chunk, _choice_nulls)}\n\n" + else: + delta = reasoning_text + (delta or "") if request.tool_choice != "none" and request.tools: # parse_increment => returns (normal_text, calls) @@ -437,16 +539,16 @@ async def stream_results() -> AsyncGenerator[bytes, None]: if normal_text and (normal_text.strip() or not has_emitted_tool_calls[sub_req_id]): choice_data = ChatCompletionStreamResponseChoice( index=choice_index, - delta=DeltaMessage(role="assistant", 
content=normal_text), + delta=DeltaMessage(content=normal_text), finish_reason=None, ) chunk = ChatCompletionStreamResponse( - id=group_request_id, + id=chat_completion_id, created=created_time, choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + yield f"data: {_serialize_sse_chunk(chunk, _choice_nulls)}\n\n" # 2) if we found calls, we output them as separate chunk(s) history_tool_calls_cnt = _get_history_tool_calls_cnt(request) @@ -504,12 +606,12 @@ async def stream_results() -> AsyncGenerator[bytes, None]: finish_reason=None, ) head_chunk = ChatCompletionStreamResponse( - id=group_request_id, + id=chat_completion_id, created=created_time, choices=[head_choice], model=request.model, ) - yield f"data: {head_chunk.model_dump_json(exclude_none=True)}\n\n" + yield f"data: {_serialize_sse_chunk(head_chunk, _choice_nulls)}\n\n" for arg_delta in _split_tool_argument_delta(call_item.parameters): arg_tool_call = ToolCall( @@ -522,12 +624,12 @@ async def stream_results() -> AsyncGenerator[bytes, None]: finish_reason=None, ) arg_chunk = ChatCompletionStreamResponse( - id=group_request_id, + id=chat_completion_id, created=created_time, choices=[arg_choice], model=request.model, ) - yield f"data: {arg_chunk.model_dump_json(exclude_none=True)}\n\n" + yield f"data: {_serialize_sse_chunk(arg_chunk, _choice_nulls)}\n\n" else: tool_call = ToolCall( id=tool_call_id if is_tool_head else None, @@ -548,27 +650,87 @@ async def stream_results() -> AsyncGenerator[bytes, None]: finish_reason=None, ) chunk = ChatCompletionStreamResponse( - id=group_request_id, + id=chat_completion_id, created=created_time, choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + yield f"data: {_serialize_sse_chunk(chunk, _choice_nulls)}\n\n" else: - delta_message = DeltaMessage(role="assistant", content=delta) - stream_choice = ChatCompletionStreamResponseChoice( - index=choice_index, 
delta=delta_message, finish_reason=None - ) - stream_resp = ChatCompletionStreamResponse( - id=group_request_id, - created=created_time, - model=request.model, - choices=[stream_choice], - ) - yield f"data: {stream_resp.model_dump_json(exclude_none=True)}\n\n" + if delta: + # If this is the final token, merge content with finish_reason + if current_finish_reason is not None: + if has_emitted_tool_calls[sub_req_id] and current_finish_reason == "stop": + current_finish_reason = "tool_calls" + delta_message = DeltaMessage(content=delta) + stream_choice = ChatCompletionStreamResponseChoice( + index=choice_index, delta=delta_message, finish_reason=current_finish_reason + ) + stream_resp = ChatCompletionStreamResponse( + id=chat_completion_id, + created=created_time, + model=request.model, + choices=[stream_choice], + ) + yield f"data: {_serialize_sse_chunk(stream_resp, _final_choice_nulls)}\n\n" + # Skip the separate final-chunk logic below + continue + else: + delta_message = DeltaMessage(content=delta) + stream_choice = ChatCompletionStreamResponseChoice( + index=choice_index, delta=delta_message, finish_reason=None + ) + stream_resp = ChatCompletionStreamResponse( + id=chat_completion_id, + created=created_time, + model=request.model, + choices=[stream_choice], + ) + yield f"data: {_serialize_sse_chunk(stream_resp, _choice_nulls)}\n\n" - # Emit a per-choice final empty chunk with finish_reason. + # Emit a per-choice final chunk with finish_reason (for tool_calls path + # or when no delta was emitted alongside finish_reason). if current_finish_reason is not None: + # Flush any buffered reasoning content that was never released + # (e.g., max_completion_tokens hit before was seen). 
+ if get_env_start_args().reasoning_parser: + parser = reasoning_parser_dict.get(choice_index) + if parser is not None: + flush_reasoning, flush_text = parser.flush() + if flush_reasoning: + if request.separate_reasoning: + flush_choice = ChatCompletionStreamResponseChoice( + index=choice_index, + delta=DeltaMessage(reasoning=flush_reasoning), + finish_reason=None, + ) + else: + # vLLM compat: emit buffered thinking as content + flush_choice = ChatCompletionStreamResponseChoice( + index=choice_index, + delta=DeltaMessage(content=flush_reasoning), + finish_reason=None, + ) + flush_chunk = ChatCompletionStreamResponse( + id=chat_completion_id, + created=created_time, + model=request.model, + choices=[flush_choice], + ) + yield f"data: {_serialize_sse_chunk(flush_chunk, _choice_nulls)}\n\n" + if flush_text: + flush_choice = ChatCompletionStreamResponseChoice( + index=choice_index, + delta=DeltaMessage(content=flush_text), + finish_reason=None, + ) + flush_chunk = ChatCompletionStreamResponse( + id=chat_completion_id, + created=created_time, + model=request.model, + choices=[flush_choice], + ) + yield f"data: {_serialize_sse_chunk(flush_chunk, _choice_nulls)}\n\n" if has_emitted_tool_calls[sub_req_id] and current_finish_reason == "stop": current_finish_reason = "tool_calls" final_choice = ChatCompletionStreamResponseChoice( @@ -577,30 +739,34 @@ async def stream_results() -> AsyncGenerator[bytes, None]: finish_reason=current_finish_reason, ) final_chunk = ChatCompletionStreamResponse( - id=group_request_id, + id=chat_completion_id, created=created_time, model=request.model, choices=[final_choice], ) - yield f"data: {final_chunk.model_dump_json(exclude_none=True)}\n\n" + yield f"data: {_serialize_sse_chunk(final_chunk, _final_choice_nulls)}\n\n" - if request.stream_options and request.stream_options.include_usage: - usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - usage_chunk 
= ChatCompletionStreamResponse( - id=group_request_id, - created=created_time, - choices=[], # Empty choices array as per OpenAI spec - model=request.model, - usage=usage, - ) - yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" + usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + prompt_tokens_details=PromptTokensDetails(cached_tokens=cached_tokens), + ) + usage_chunk = ChatCompletionStreamResponse( + id=chat_completion_id, + created=created_time, + choices=[], # Empty choices array as per OpenAI spec + model=request.model, + usage=usage, + ) + yield f"data: {json.dumps(usage_chunk.model_dump(exclude_none=True), ensure_ascii=False)}\n\n" + + yield "data: [DONE]\n\n".encode("utf-8") background_tasks = BackgroundTasks() - return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) + return StreamingResponse( + _safe_stream_wrapper(stream_results()), media_type="text/event-stream", background=background_tasks + ) async def completions_impl(request: CompletionRequest, raw_request: Request) -> Response: @@ -755,11 +921,13 @@ async def stream_results() -> AsyncGenerator[bytes, None]: prompt_tokens = 0 completion_tokens = 0 + cached_tokens = 0 async for sub_req_id, request_output, metadata, finish_status in results_generator: group_request_id = convert_sub_id_to_group_id(sub_req_id) choice_index = sub_req_id - group_request_id prompt_tokens = metadata["prompt_tokens"] + cached_tokens = metadata.get("prompt_cache_len", 0) completion_tokens += 1 current_finish_reason = None if finish_status.is_finished(): @@ -784,27 +952,29 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, choices=[stream_choice], ) - yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") + yield f"data: {json.dumps(stream_resp.model_dump(), ensure_ascii=False)}\n\n" - yield "data: 
[DONE]\n\n".encode("utf-8") + usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + prompt_tokens_details=PromptTokensDetails(cached_tokens=cached_tokens), + ) + usage_chunk = CompletionStreamResponse( + id=group_request_id, + created=created_time, + choices=[], # Empty choices array as per OpenAI spec + model=request.model, + usage=usage, + ) + yield f"data: {json.dumps(usage_chunk.model_dump(), ensure_ascii=False)}\n\n" - if request.stream_options and request.stream_options.include_usage: - usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - usage_chunk = CompletionStreamResponse( - id=group_request_id, - created=created_time, - choices=[], # Empty choices array as per OpenAI spec - model=request.model, - usage=usage, - ) - yield f"data: {usage_chunk.model_dump_json()}\n\n" + yield "data: [DONE]\n\n" background_tasks = BackgroundTasks() - return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) + return StreamingResponse( + _safe_stream_wrapper(stream_results()), media_type="text/event-stream", background=background_tasks + ) async def _collect_generation_results( @@ -814,6 +984,7 @@ async def _collect_generation_results( count_output_tokens = 0 finish_reason = None prompt_tokens = 0 + prompt_cache_len = 0 token_infos = [] if request.logprobs is not None else None prompt_logprobs = None prompt_token_ids = None @@ -839,6 +1010,7 @@ async def _collect_generation_results( if finish_status.is_finished(): finish_reason = finish_status.get_finish_reason() prompt_tokens = metadata["prompt_tokens"] + prompt_cache_len = metadata.get("prompt_cache_len", 0) # 处理停止序列剔除 final_text = "".join(final_output) @@ -856,6 +1028,7 @@ async def _collect_generation_results( "text": final_text, "finish_reason": finish_reason, "prompt_tokens": prompt_tokens, + 
"prompt_cache_len": prompt_cache_len, "completion_tokens": count_output_tokens, "token_infos": token_infos, "prompt_logprobs": prompt_logprobs, @@ -870,6 +1043,7 @@ def _build_completion_response(results: List[Dict], request: CompletionRequest, choices = [] total_prompt_tokens = 0 total_completion_tokens = 0 + total_cached_tokens = 0 for result in results: text = result["text"] @@ -888,11 +1062,13 @@ def _build_completion_response(results: List[Dict], request: CompletionRequest, total_prompt_tokens += result["prompt_tokens"] total_completion_tokens += result["completion_tokens"] + total_cached_tokens += result.get("prompt_cache_len", 0) usage = UsageInfo( prompt_tokens=total_prompt_tokens, completion_tokens=total_completion_tokens, total_tokens=total_prompt_tokens + total_completion_tokens, + prompt_tokens_details=PromptTokensDetails(cached_tokens=total_cached_tokens), ) if is_batch: @@ -915,8 +1091,6 @@ def _build_logprobs_data(result: Dict, request: CompletionRequest, tokenizer) -> offset = 0 def add_tokens_to_logprobs(token_ids=None, token_infos=None, logprob_map=None): - nonlocal offset - def add_single_token(token_text: str, logprob: float): nonlocal offset all_tokens.append(token_text) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index cfb59d18bb..c9b82e1e0c 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -4,6 +4,7 @@ import uuid import subprocess import signal +import math from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker from lightllm.utils.start_utils import process_manager, kill_recursive from .metrics.manager import start_metric_manager @@ -17,7 +18,14 @@ from lightllm.utils.multinode_utils import send_and_receive_node_ip from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size -from lightllm.utils.config_utils import has_audio_module, has_vision_module +from lightllm.utils.config_utils import ( + 
has_audio_module, + has_vision_module, + is_linear_att_mixed_model, + auto_set_max_req_total_len, + auto_set_fused_shared_experts, +) +from lightllm.utils.dist_check_utils import auto_configure_allreduce_flags_from_args logger = init_logger(__name__) @@ -68,6 +76,8 @@ def normal_or_p_d_start(args): args: StartArgs = args + auto_set_max_req_total_len(args) + auto_set_fused_shared_experts(args) set_unique_server_name(args) if args.enable_mps: @@ -75,7 +85,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "visual_only"]: + if args.run_mode not in ["normal", "prefill", "decode", "visual_only"]: return # 通过模型的参数判断是否是多模态模型,包含哪几种模态, 并设置是否启动相应得模块 @@ -91,7 +101,7 @@ def normal_or_p_d_start(args): args.disable_audio = True # pd 分离模式下,不启动多模态的模块 - if args.run_mode in ["decode", "nixl_decode"]: + if args.run_mode == "decode": args.disable_audio = True args.disable_vision = True @@ -118,6 +128,19 @@ def normal_or_p_d_start(args): if args.diverse_mode: assert args.router_token_ratio == 0.0 + # performance_mode 参数处理 + if args.performance_mode == "personal": + args.running_max_req_size = 6 + args.batch_max_tokens = 2048 + args.chunked_prefill_size = 1024 + args.embed_cache_storage_size = 0.8 + args.graph_max_batch_size = 6 + logger.info( + f"performance_mode is personal, set running_max_req_size to 3," + f"batch_max_tokens to 2048, chunked_prefill_size to 1024," + f"graph_max_batch_size to 32" + ) + if not args.disable_shm_warning: check_recommended_shm_size(args) @@ -165,9 +188,28 @@ def normal_or_p_d_start(args): args.kv_quant_calibration_config_path is not None ), "fp8kv inference mode requires --kv_quant_calibration_config_path. 
" + if args.enable_prefill_microbatch_overlap or args.enable_decode_microbatch_overlap: + args.enable_tpsp_mix_mode = True + + if args.enable_prefill_decode_mixed: + assert args.run_mode == "normal", "--enable_prefill_decode_mixed only supports run_mode normal" + if args.enable_dp_prefill_balance: assert args.enable_tpsp_mix_mode and args.dp > 1, "need set --enable_tpsp_mix_mode firstly and --dp > 1" + if args.enable_ep_moe: + allowed_ep_att_backends = {"auto", "fa3", "triton"} + for backend in args.llm_prefill_att_backend: + assert backend in allowed_ep_att_backends, ( + "When --enable_ep_moe is enabled, --llm_prefill_att_backend must be one of " + f"{sorted(allowed_ep_att_backends)}; flashinfer is not supported." + ) + for backend in args.llm_decode_att_backend: + assert backend in allowed_ep_att_backends, ( + "When --enable_ep_moe is enabled, --llm_decode_att_backend must be one of " + f"{sorted(allowed_ep_att_backends)}; flashinfer is not supported." + ) + # mtp params check if args.mtp_mode is not None: assert args.mtp_draft_model_dir is not None @@ -249,6 +291,18 @@ def normal_or_p_d_start(args): ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size, " f"but got {args.batch_max_tokens}, {args.chunked_prefill_size}" + # linear att cache 参数自动设置 + if args.linear_att_cache_size is None: + # linear_att_cache_size 只会在 qwen3.5 等混合线性层模型中生效。 + default_cache_size = args.running_max_req_size * 2 + dp_size_in_node = max(1, args.dp // args.nnodes) + per_dp_cache_size = max(1, math.ceil(args.running_max_req_size / dp_size_in_node) * 2) + args.linear_att_cache_size = min(default_cache_size, per_dp_cache_size) + + if args.enable_cpu_cache and is_linear_att_mixed_model(args.model_dir): + args.cpu_cache_token_page_size = args.linear_att_hash_page_size * args.linear_att_page_block_num + logger.info(f"set cpu_cache_token_page_size to {args.cpu_cache_token_page_size} for linear hybrid att model") + # help to manage data stored on Ceph if "s3://" in 
args.model_dir: from lightllm.utils.petrel_helper import s3_model_prepare @@ -261,6 +315,22 @@ def normal_or_p_d_start(args): args.eos_id = get_eos_token_ids(args.model_dir) + # 如果 tool_call_parser 是 None,尝试根据模型类型自动设置 + if args.tool_call_parser is None: + from lightllm.utils.config_utils import get_tool_call_parser_for_model + + args.tool_call_parser = get_tool_call_parser_for_model(args.model_dir) + if args.tool_call_parser: + logger.info(f"Auto set tool_call_parser to {args.tool_call_parser} based on model type") + + # 如果 reasoning_parser 是 None,尝试根据模型类型自动设置 + if args.reasoning_parser is None: + from lightllm.utils.config_utils import get_reasoning_parser_for_model + + args.reasoning_parser = get_reasoning_parser_for_model(args.model_dir) + if args.reasoning_parser: + logger.info(f"Auto set reasoning_parser to {args.reasoning_parser} based on model type") + if args.data_type is None: from lightllm.utils.config_utils import get_dtype @@ -270,8 +340,6 @@ def normal_or_p_d_start(args): already_uesd_ports = [args.port] if args.nccl_port is not None: already_uesd_ports.append(args.nccl_port) - if args.pd_decode_rpyc_port is not None: - already_uesd_ports.append(args.pd_decode_rpyc_port) if args.visual_nccl_ports is not None: already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) if not args.disable_audio and args.audio_nccl_ports is not None: @@ -291,6 +359,7 @@ def normal_or_p_d_start(args): ( nccl_port, router_port, + router_profiler_port, detokenization_port, http_server_port, visual_port, @@ -298,7 +367,6 @@ def normal_or_p_d_start(args): cache_port, metric_port, multi_level_kv_cache_port, - pd_decode_rpyc_port, ) = can_use_ports[0:10] can_use_ports = can_use_ports[10:] @@ -317,9 +385,8 @@ def normal_or_p_d_start(args): # 将申请好的端口放入args参数中 if args.nccl_port is None: args.nccl_port = nccl_port - if args.pd_decode_rpyc_port is None: - args.pd_decode_rpyc_port = pd_decode_rpyc_port args.router_port = router_port + args.router_profiler_port = 
router_profiler_port args.detokenization_port = detokenization_port args.http_server_port = http_server_port args.visual_port = visual_port @@ -331,10 +398,6 @@ def normal_or_p_d_start(args): args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] # p d 分离模式下用于标识节点的id args.pd_node_id = uuid.uuid4().int - # p 节点用来建立torch kv 传输分布组的可用端口范围 - args.pd_p_allowed_port_min = 20000 - args.pd_p_allowed_port_max = 30000 - # p d 分离模式下,decode节点的调度间隙是0 if args.run_mode == "decode": args.router_max_wait_tokens = 0 @@ -348,6 +411,8 @@ def normal_or_p_d_start(args): overriding enable_dp_prompt_cache_fetch to False""" ) + auto_configure_allreduce_flags_from_args(args) + set_env_start_args(args) logger.info(f"all start args:{args}") @@ -463,6 +528,8 @@ def pd_master_start(args): if args.run_mode != "pd_master": return + auto_set_max_req_total_len(args) + # when use config_server to support multi pd_master node, we # need generate unique node id for each pd_master node. # otherwise, we use the 0 for single pd_master node. 
diff --git a/lightllm/server/api_tgi.py b/lightllm/server/api_tgi.py index bd72ea695d..f4a7cf6a5a 100755 --- a/lightllm/server/api_tgi.py +++ b/lightllm/server/api_tgi.py @@ -187,7 +187,11 @@ async def stream_results() -> AsyncGenerator[bytes, None]: "prompt_tokens": metadata.get("prompt_tokens", 0), } - yield ("data:" + json.dumps(ret, ensure_ascii=False) + "\n\n").encode("utf-8") + yield "data:" + json.dumps(ret, ensure_ascii=False) + "\n\n" + + from .api_openai import _safe_stream_wrapper background_tasks = BackgroundTasks() - return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) + return StreamingResponse( + _safe_stream_wrapper(stream_results()), media_type="text/event-stream", background=background_tasks + ) diff --git a/lightllm/server/audioserver/model_infer/model_rpc.py b/lightllm/server/audioserver/model_infer/model_rpc.py index 39a7e06ac3..82919856d9 100644 --- a/lightllm/server/audioserver/model_infer/model_rpc.py +++ b/lightllm/server/audioserver/model_infer/model_rpc.py @@ -51,6 +51,7 @@ def exposed_init_model(self, kvargs): self.model.load_model(weight_dir, model_cfg) self.model = self.model.cuda() + self.model.check_long_audio_infer() self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index 21c1cde678..54d22a0d0d 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -2,6 +2,8 @@ import json from lightllm.server.tokenizer import get_tokenizer from lightllm.utils.log_utils import init_logger +from functools import lru_cache +from lightllm.utils.config_utils import get_model_type_v1 logger = init_logger(__name__) @@ -45,6 +47,32 @@ def init_tokenizer(args): return +@lru_cache(maxsize=1) +def tokenizer_supports_force_thinking() -> bool: + """Whether this 
tokenizer supports thinking / reasoning.""" + + assert tokenizer is not None + + try: + ans = "thinking" in tokenizer.chat_template or "enable_thinking" in tokenizer.chat_template + logger.debug(f"chat_template: {tokenizer.chat_template}") + logger.info(f"tokenizer_supports_force_thinking : {ans}") + return ans + except: + pass + + try: + ans = "thinking" in tokenizer.tokenizer.chat_template or "enable_thinking" in tokenizer.tokenizer.chat_template + logger.debug(f"tokenizer.tokenizer.chat_template: {tokenizer.tokenizer.chat_template}") + logger.info(f"tokenizer_supports_force_thinking : {ans}") + return ans + except: + pass + + logger.info("tokenizer_supports_force_thinking : False") + return False + + def _normalize_tool_call_arguments(messages: list) -> None: # Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility # Qwen35's chat template expects arguments to be a dict (uses |items filter) @@ -63,11 +91,49 @@ def _normalize_tool_call_arguments(messages: list) -> None: pass +def _alias_reasoning_to_reasoning_content(messages: list) -> None: + # Clients (OpenRouter-style, claw-eval, and others) replay prior thinking on + # assistant messages as `reasoning`, but Qwen3/Qwen3.5 chat templates read + # `message.reasoning_content`. Without this alias the template falls back to + # rendering every recent assistant turn as `\n` (empty think), + # which teaches the model in-context to skip thinking on the current turn. + for msg in messages: + if msg.get("role") != "assistant": + continue + if msg.get("reasoning_content"): + continue + reasoning = msg.get("reasoning") + if isinstance(reasoning, str) and reasoning: + msg["reasoning_content"] = reasoning + + +def _normalize_multimodal_content_types(messages: list) -> None: + # OpenAI requests use content part types like `image_url` and `audio_url`. 
+ # Model chat templates generally render modality tokens from `image` and + # `audio` parts while the raw media payload is carried separately in + # MultimodalParams. Preserve the original fields and normalize only the + # template-facing type to keep prompt tags aligned with media counts. + for msg in messages: + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if not isinstance(part, dict): + continue + if part.get("type") == "image_url": + part["type"] = "image" + elif part.get("type") == "audio_url": + part["type"] = "audio" + + async def build_prompt(request, tools) -> str: - global tokenizer # pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别 messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages] _normalize_tool_call_arguments(messages) + _alias_reasoning_to_reasoning_content(messages) + if get_model_type_v1() == "gemma4": + # gemma4 的 tokenizer 不支持 multimodal 内容类型,所以需要手动转换 + _normalize_multimodal_content_types(messages) kwargs = {"conversation": messages} if request.character_settings: @@ -78,9 +144,21 @@ async def build_prompt(request, tools) -> str: if request.chat_template_kwargs: kwargs.update(request.chat_template_kwargs) + # 修复一些parser类型是默认打开thinking,但是 tokenizer有时候不知道打开了thinking。导致 + # 构建的reasoning parser 和 tokenizer 的行为不对齐导致的问题。 + from .api_openai import _is_force_thinking_mode + + thinking = _is_force_thinking_mode(request) + + kwargs["thinking"] = thinking + kwargs["enable_thinking"] = thinking + + # TODO thinking 模式应该是3种,一种是强制思考,一种是强制不思考,一种是模型自己决定的自适应 + # 的思考模式。当前的代码只是实现了强制思考和强制不思考两种模式。后续要根据模型的情况,从tokenizer + # 上判断能支持的思考模式种类,再进行设置,才能具备更完备的处理。 + try: input_str = tokenizer.apply_chat_template(**kwargs, tokenize=False, add_generation_prompt=True, tools=tools) - except BaseException as e: - logger.error(f"Failed to build prompt: {e}") - raise e + except Exception as e: + raise ValueError(f"Failed to build prompt: {e}") from None return input_str diff 
--git a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py index 5c015f234c..3ce39bb6e6 100644 --- a/lightllm/server/config_server/api_http.py +++ b/lightllm/server/config_server/api_http.py @@ -81,7 +81,7 @@ async def visual_websocket_endpoint(websocket: WebSocket): client_ip, client_port = websocket.client logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}") registered_visual_server_obj: VIT_Obj = pickle.loads(await websocket.receive_bytes()) - logger.info(f"recieved registered_visual_server_obj {registered_visual_server_obj}") + logger.info(f"received registered_visual_server_obj {registered_visual_server_obj}") with registered_visual_server_obj_lock: registered_visual_server_objs[registered_visual_server_obj.node_id] = registered_visual_server_obj diff --git a/lightllm/server/core/objs/out_token_circlequeue.py b/lightllm/server/core/objs/out_token_circlequeue.py index ea99dae5f6..7664fd7c48 100644 --- a/lightllm/server/core/objs/out_token_circlequeue.py +++ b/lightllm/server/core/objs/out_token_circlequeue.py @@ -2,6 +2,11 @@ import ctypes from typing import Tuple +from lightllm.utils.log_utils import init_logger + + +logger = init_logger(__name__) + LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 1280)) LIGHTLLM_OUT_TOKEN_QUEUE_SIZE = int(os.getenv("LIGHTLLM_OUT_TOKEN_QUEUE_SIZE", 8)) @@ -24,9 +29,13 @@ def __init__(self): def set(self, token_str: str, src_index: int, special: bool, count_output_tokens: int): str_bytes = token_str.encode("utf-8") - assert ( - len(str_bytes) <= LIGHTLLM_TOKEN_MAX_BYTES - ), f"Token string {len(str_bytes)} exceeds maximum length of {LIGHTLLM_TOKEN_MAX_BYTES} bytes." 
+ if len(str_bytes) > LIGHTLLM_TOKEN_MAX_BYTES: + old_len = len(str_bytes) + str_bytes = str_bytes[:LIGHTLLM_TOKEN_MAX_BYTES].decode("utf-8", errors="ignore").encode("utf-8") + logger.warning( + f"Token string {old_len} exceeds maximum length of {LIGHTLLM_TOKEN_MAX_BYTES} bytes, " + f"truncated to {len(str_bytes)} bytes." + ) ctypes.memmove(self.data, str_bytes, len(str_bytes)) self.data_len = len(str_bytes) self.src_index = src_index diff --git a/lightllm/server/core/objs/nixl_params.py b/lightllm/server/core/objs/pd_kv_trans_params.py similarity index 51% rename from lightllm/server/core/objs/nixl_params.py rename to lightllm/server/core/objs/pd_kv_trans_params.py index 8b64554f84..290a4ca287 100644 --- a/lightllm/server/core/objs/nixl_params.py +++ b/lightllm/server/core/objs/pd_kv_trans_params.py @@ -2,13 +2,13 @@ import ctypes from typing import Optional -LIGHTLLM_NIXL_PARAM_OBJ_MAX_BYTES = int(os.getenv("LIGHTLLM_NIXL_PARAM_OBJ_MAX_BYTES", 8 * 1024)) +LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES = int(os.getenv("LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES", 8 * 1024)) -class NIXLParamObj(ctypes.Structure): +class PDKVTransParamObj(ctypes.Structure): _pack_ = 4 _fields_ = [ - ("data", ctypes.c_ubyte * LIGHTLLM_NIXL_PARAM_OBJ_MAX_BYTES), + ("data", ctypes.c_ubyte * LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES), ("data_len", ctypes.c_int), ] @@ -20,9 +20,10 @@ def set(self, obj_bytes: Optional[bytes]): self.data_len = 0 return - assert ( - len(obj_bytes) <= LIGHTLLM_NIXL_PARAM_OBJ_MAX_BYTES - ), f"NIXL_PARAM_OBJ bytes len {len(obj_bytes)} exceeds length of {LIGHTLLM_NIXL_PARAM_OBJ_MAX_BYTES} bytes." + assert len(obj_bytes) <= LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES, ( + f"PD_KV_TRANS_PARAM_OBJ bytes len {len(obj_bytes)} exceeds length of " + f"{LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES} bytes." 
+ ) ctypes.memmove(self.data, obj_bytes, len(obj_bytes)) self.data_len = len(obj_bytes) return diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index 0fa95d4174..cbc63c898d 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ b/lightllm/server/core/objs/py_sampling_params.py @@ -38,7 +38,7 @@ def __init__( top_k: int = None, # -1 is for all ignore_eos: bool = False, image_max_patch_num: int = -1, - max_new_tokens: int = 16384, + max_new_tokens: int = 65535, min_new_tokens: int = 1, stop_sequences: Optional[Union[str, List[str], List[List[int]]]] = None, # 停止句子条件 skip_special_tokens: bool = True, # whether to skip special tokens when decoding @@ -54,10 +54,10 @@ def __init__( # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. allowed_token_ids: Optional[List[int]] = None, + # if provided, the invalid token ids will be ignored during generation + invalid_token_ids: Optional[List[int]] = None, # p d mode used params group_request_id: Optional[int] = None, - # move kv to deocde node, only used in pd mode - move_kv_to_decode_node: Optional[dict] = None, # suggest dp index, deepseekv2 dp mode, use to suggest used dp_index suggested_dp_index: Optional[int] = None, seed: Optional[int] = -1, @@ -89,8 +89,8 @@ def __init__( self.guided_grammar = guided_grammar self.guided_json = guided_json self.allowed_token_ids = allowed_token_ids + self.invalid_token_ids = invalid_token_ids self.group_request_id = group_request_id - self.move_kv_to_decode_node = move_kv_to_decode_node self.suggested_dp_index = suggested_dp_index self.seed = seed if self.do_sample is False: @@ -189,9 +189,6 @@ def verify(self): if not (self.group_request_id is None or isinstance(self.group_request_id, int)): raise ValueError(f"group_request_id must be None or int ,but get {self.group_request_id}") - if not 
(self.move_kv_to_decode_node is None or isinstance(self.move_kv_to_decode_node, dict)): - raise ValueError(f"move_kv_to_decode_node must be None or dict, but get {self.move_kv_to_decode_node}") - if not (self.suggested_dp_index is None or isinstance(self.suggested_dp_index, int)): raise ValueError(f"suggested_dp_index must be None or int, but get {self.suggested_dp_index}") @@ -269,7 +266,7 @@ def to_dict(self): ret["guided_grammar"] = self.guided_grammar ret["guided_json"] = self.guided_json ret["allowed_token_ids"] = self.allowed_token_ids - ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node + ret["invalid_token_ids"] = self.invalid_token_ids ret["seed"] = self.seed return ret diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 8905248bf8..7f2b697091 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -6,10 +6,11 @@ from .sampling_params import SamplingParams from .out_token_circlequeue import CircularQueue from .shm_array import ShmArray -from .token_chunck_hash_list import TokenHashList, CpuCachePageList +from .token_chunck_hash_list import TokenHashList, CpuCachePageList, TokenPageLenList from lightllm.server.req_id_generator import convert_sub_id_to_group_id from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.config_utils import is_linear_att_mixed_model from lightllm.utils.kv_cache_utils import compute_token_list_hash from typing import List, Any, Union from lightllm.utils.log_utils import init_logger @@ -116,12 +117,14 @@ class Req(ctypes.Structure): # 当 stop_str_matched 条件满足的时候,对应的最后一个生成 token 所在的index位置。 # 该变量为 detokenization 进程写入,http_server 读取 ("stop_str_matched_token_index", ctypes.c_int), + # 用于在 包含linear att 混合模型中,进行输入的提前hash,方便在对应的page radix tree中进行快速操作。 + ("linear_att_token_hash_list", TokenHashList), # 用于在开启cpu cache 或者 硬盘 cache时,预先计算,分块输入token的hash值。 ("token_hash_list", 
TokenHashList), + # 用于存储每个cpu cache 页面对应的真实token数量,用于linear att的qwen3.5等模型的碎片化处理最后一个页面的问题 + ("token_hash_page_len_list", TokenPageLenList), # 用于保存查找匹配到的可以被复用的cpu cache 页面信息。 ("cpu_cache_match_page_indexes", CpuCachePageList), - # 分块hash的块大小 - ("cpu_cache_token_page_size", ctypes.c_int), ] def get_str(self): @@ -182,21 +185,68 @@ def init( self.post_init() - self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size - if get_env_start_args().enable_cpu_cache: - self._fill_input_token_hash() + args = get_env_start_args() + if is_linear_att_mixed_model(args.model_dir): + self._fill_linear_att_token_hash() + if args.enable_cpu_cache: + cpu_cache_hash_list, cpu_cache_page_len_list = self._calcu_linear_att_cpu_cache_page_len_list() + self.token_hash_list = TokenHashList() + self.token_hash_list.clear() + self.token_hash_list.fill(cpu_cache_hash_list) + self.token_hash_page_len_list = TokenPageLenList() + self.token_hash_page_len_list.clear() + self.token_hash_page_len_list.fill(cpu_cache_page_len_list) + self.cpu_cache_match_page_indexes = CpuCachePageList() + else: + if args.enable_cpu_cache: + self._fill_input_token_hash() + page_num = self.token_hash_list.size + cpu_cache_page_len_list = [args.cpu_cache_token_page_size * (i + 1) for i in range(page_num)] + self.token_hash_page_len_list = TokenPageLenList() + self.token_hash_page_len_list.clear() + self.token_hash_page_len_list.fill(cpu_cache_page_len_list) + self.cpu_cache_match_page_indexes = CpuCachePageList() + return def post_init(self): # 子类继承进行一些额外的初始化操作 pass + def _calcu_linear_att_cpu_cache_page_len_list(self): + token_hash_list = self.linear_att_token_hash_list.get_all() + linear_att_hash_page_size = get_env_start_args().linear_att_hash_page_size + block_num = get_env_start_args().linear_att_page_block_num + cpu_cache_page_size = get_env_start_args().cpu_cache_token_page_size + assert cpu_cache_page_size == linear_att_hash_page_size * block_num + cpu_cache_hash_list = [] + 
cpu_cache_page_len_list = [] + cum_sum_len = 0 + for i in range(len(token_hash_list)): + if i % block_num == (block_num - 1): + cpu_cache_hash_list.append(token_hash_list[i]) + cum_sum_len += cpu_cache_page_size + cpu_cache_page_len_list.append(cum_sum_len) + elif i == len(token_hash_list) - 1: + cpu_cache_hash_list.append(token_hash_list[len(token_hash_list) - 1]) + page_num = (i % block_num) + 1 + cum_sum_len += page_num * linear_att_hash_page_size + cpu_cache_page_len_list.append(cum_sum_len) + + return cpu_cache_hash_list, cpu_cache_page_len_list + def _fill_input_token_hash(self): self.token_hash_list = TokenHashList() self.token_hash_list.clear() - hash_values = compute_token_list_hash(self.get_prompt_ids(), self.cpu_cache_token_page_size) + hash_values = compute_token_list_hash(self.get_prompt_ids(), get_env_start_args().cpu_cache_token_page_size) self.token_hash_list.fill(hash_values) - self.cpu_cache_match_page_indexes = CpuCachePageList() + return + + def _fill_linear_att_token_hash(self): + self.linear_att_token_hash_list = TokenHashList() + self.linear_att_token_hash_list.clear() + hash_values = compute_token_list_hash(self.get_prompt_ids(), get_env_start_args().linear_att_hash_page_size) + self.linear_att_token_hash_list.fill(hash_values) return def create_prompt_ids_shm_array(self): diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index ba30d3716b..c39559f5f6 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -3,7 +3,7 @@ from typing import Optional, List, Tuple, Union from transformers import GenerationConfig from lightllm.server.req_id_generator import MAX_BEST_OF -from .nixl_params import NIXLParamObj +from .pd_kv_trans_params import PDKVTransParamObj _SAMPLING_EPS = 1e-5 DEFAULT_INPUT_PENALTY = os.getenv("INPUT_PENALTY", "False").upper() in ["ON", "TRUE", "1"] @@ -17,6 +17,7 @@ REGULAR_CONSTRAINT_MAX_LENGTH = 
int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048)) GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048)) JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048)) +INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10)) class StopSequence(ctypes.Structure): @@ -205,6 +206,25 @@ def to_list(self): return list(self.ids[: self.size]) +class InvalidTokenIds(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("ids", ctypes.c_int * INVALID_TOKEN_IDS_MAX_LENGTH), + ("size", ctypes.c_int), + ] + + def initialize(self, ids: List[int]): + self.size = len(ids) + assert ( + self.size <= INVALID_TOKEN_IDS_MAX_LENGTH + ), f"Too many invalid token IDs {self.size} > {INVALID_TOKEN_IDS_MAX_LENGTH}." + self.ids[: self.size] = ids[:] + return + + def to_list(self): + return list(self.ids[: self.size]) + + class ExponentialDecayLengthPenalty(ctypes.Structure): _pack_ = 4 _fields_ = [ @@ -240,44 +260,6 @@ def get(self) -> int: return (self.node_id_high << 64) | self.node_id_low -class DecodeNode(ctypes.Structure): - _pack_ = 4 - _fields_ = [ - ("exists", ctypes.c_bool), - ("node_id", NodeUUId), - ("ip", ctypes.c_int32 * 4), - ("rpyc_port", ctypes.c_int), - ("max_new_tokens", ctypes.c_int), - ] - - def initialize(self, data_dict): - if data_dict is None: - self.exists = False - return - - self.exists = True - - pd_node_id = data_dict["node_id"] - self.node_id = NodeUUId() - self.node_id.initialize(pd_node_id) - - ip_parts = [int(part) for part in data_dict["ip"].split(".")] - self.ip = (ctypes.c_int32 * 4)(*ip_parts) - - self.rpyc_port = data_dict["rpyc_port"] - self.max_new_tokens = data_dict["max_new_tokens"] - - def to_dict(self): - if not self.exists: - return None - return { - "node_id": self.node_id.get(), - "ip": ".".join(str(self.ip[i]) for i in range(4)), - "rpyc_port": self.rpyc_port, - "max_new_tokens": self.max_new_tokens, - } - - class 
SamplingParams(ctypes.Structure): _pack_ = 4 _fields_ = [ @@ -304,15 +286,16 @@ class SamplingParams(ctypes.Structure): # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. ("allowed_token_ids", AllowedTokenIds), + # if provided, the invalid token ids will be ignored during generation + ("invalid_token_ids", InvalidTokenIds), ("stop_sequences", StopSequenceGroups), ("exponential_decay_length_penalty", ExponentialDecayLengthPenalty), ("group_request_id", ctypes.c_int64), # p d mode used params ("suggested_dp_index", ctypes.c_int), # suggest dp index, deepseekv2 dp mode, use to suggest used dp_index - ("move_kv_to_decode_node", DecodeNode), # move kv to deocde node, only used in pd mode # in pd split mode, use to keep the id of pd master ("pd_master_node_id", NodeUUId), - # nixl params object, only used in nixl pd mode, used to build nixl connection in p and d - ("nixl_params", NIXLParamObj), + # pd params object, only used in pd mode, used to build kv transport connection in prefill and decode + ("pd_kv_trans_params", PDKVTransParamObj), ("skip_special_tokens", ctypes.c_bool), # whether to skip special tokens when decoding ("add_special_tokens", ctypes.c_bool), # whether to add special tokens when encoding ( @@ -345,7 +328,7 @@ def init(self, tokenizer, **kwargs): self.top_k = kwargs.get("top_k", SamplingParams._top_k) self.ignore_eos = kwargs.get("ignore_eos", False) self.image_max_patch_num = kwargs.get("image_max_patch_num", -1) - self.max_new_tokens = kwargs.get("max_new_tokens", 16384) + self.max_new_tokens = kwargs.get("max_new_tokens", 65535) self.min_new_tokens = kwargs.get("min_new_tokens", 1) self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY) self.group_request_id = kwargs.get("group_request_id", -1) @@ -362,10 +345,8 @@ def init(self, tokenizer, **kwargs): self.exponential_decay_length_penalty = 
ExponentialDecayLengthPenalty() self.exponential_decay_length_penalty.initialize(kwargs.get("exponential_decay_length_penalty", (1, 1.0))) - self.move_kv_to_decode_node = DecodeNode() - self.move_kv_to_decode_node.initialize(kwargs.get("move_kv_to_decode_node", None)) - self.nixl_params = NIXLParamObj() - self.nixl_params.set(kwargs.get("nixl_params", None)) + self.pd_kv_trans_params = PDKVTransParamObj() + self.pd_kv_trans_params.set(kwargs.get("pd_kv_trans_params", None)) self.pd_master_node_id = NodeUUId() self.pd_master_node_id.initialize(kwargs.get("pd_master_node_id", 0)) @@ -394,6 +375,11 @@ def init(self, tokenizer, **kwargs): self.allowed_token_ids = AllowedTokenIds() self.allowed_token_ids.initialize(allowed_token_ids) + # Initialize invalid_token_ids + invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys()) + self.invalid_token_ids = InvalidTokenIds() + self.invalid_token_ids.initialize(list[int](invalid_token_ids)) + if self.do_sample is False: self.temperature = 1.0 self.top_p = 1.0 @@ -493,8 +479,8 @@ def to_dict(self): "guided_grammar": self.guided_grammar.to_str(), "guided_json": self.guided_json.to_str(), "allowed_token_ids": self.allowed_token_ids.to_list(), + "invalid_token_ids": self.invalid_token_ids.to_list(), "group_request_id": self.group_request_id, - "move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(), "skip_special_tokens": self.skip_special_tokens, "add_special_tokens": self.add_special_tokens, "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens, diff --git a/lightllm/server/core/objs/shm_req_manager.py b/lightllm/server/core/objs/shm_req_manager.py index aa3641afc2..fd9106d59c 100644 --- a/lightllm/server/core/objs/shm_req_manager.py +++ b/lightllm/server/core/objs/shm_req_manager.py @@ -114,6 +114,10 @@ def release_req_index(self, req_index_in_mem): async def async_release_req_index(self, req_index_in_mem): return self.release_req_index(req_index_in_mem) + def is_idle(self) -> bool: + 
"""True when no request slot is currently allocated in shared memory.""" + return int(np.sum(self.alloc_state_shm.arr)) == 0 + # get_req_obj_by_index 和 put_back_req_obj 是 分配好后,进行对象获取和 # 管理的接口,主要是要进行引用计数的管理。 def get_req_obj_by_index(self, req_index_in_mem): diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 018c022860..bfc03cd542 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -8,9 +8,7 @@ class StartArgs: run_mode: str = field( default="normal", - metadata={ - "choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode", "visual_only"] - }, + metadata={"choices": ["normal", "pd_master", "prefill", "decode", "config_server", "visual_only"]}, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) @@ -25,14 +23,13 @@ class StartArgs: config_server_visual_redis_port: int = field(default=None) afs_image_embed_dir: str = field(default=None) afs_embed_capacity: int = field(default=250000) - pd_decode_rpyc_port: int = field(default=None) select_p_d_node_strategy: str = field(default=None) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) tokenizer_mode: str = field(default="slow") load_way: str = field(default="HF") max_total_token_num: Optional[int] = field(default=None) - mem_fraction: float = field(default=0.9) + mem_fraction: float = field(default=0.8) batch_max_tokens: Optional[int] = field(default=None) eos_id: List[int] = field(default_factory=list) tool_call_parser: Optional[str] = field( @@ -65,7 +62,8 @@ class StartArgs: dp: int = field(default=1) nnodes: int = field(default=1) node_rank: int = field(default=0) - max_req_total_len: int = field(default=2048 + 1024) + # If None, will be automatically derived from model config in `lightllm.server.api_start`. 
+ max_req_total_len: Optional[int] = field(default=None) nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=None) use_config_server_to_init_nccl: bool = field(default=False) @@ -76,6 +74,7 @@ class StartArgs: router_token_ratio: float = field(default=0.0) router_max_wait_tokens: int = field(default=1) disable_aggressive_schedule: bool = field(default=False) + enable_prefill_decode_mixed: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) chunked_prefill_size: int = field(default=None) disable_chunked_prefill: bool = field(default=False) @@ -87,20 +86,27 @@ class StartArgs: disable_vision: Optional[bool] = field(default=None) disable_audio: Optional[bool] = field(default=None) visual_use_proxy_mode: bool = field(default=False) + disable_symm_mem_allreduce: bool = field(default=False) + disable_flashinfer_allreduce: bool = field(default=False) enable_tpsp_mix_mode: bool = field(default=False) enable_dp_prefill_balance: bool = field(default=False) enable_decode_microbatch_overlap: bool = field(default=False) enable_prefill_microbatch_overlap: bool = field(default=False) cache_capacity: int = field(default=200) + max_image_token_count: int = field(default=8192) + max_image_pixels: int = field(default=8294400) embed_cache_storage_size: float = field(default=4) data_type: Optional[str] = field( default=None, metadata={"choices": ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]} ) return_all_prompt_logprobs: bool = field(default=False) use_reward_model: bool = field(default=False) - long_truncation_mode: Optional[str] = field(default=None, metadata={"choices": [None, "head", "center"]}) use_tgi_api: bool = field(default=False) health_monitor: bool = field(default=False) + enable_profiling: Optional[str] = field( + default=None, + metadata={"choices": ["torch_profiler", "nvtx"]}, + ) metric_gateway: Optional[str] = field(default=None) job_name: str = field(default="lightllm") grouping_key: List[str] = 
field(default_factory=list) @@ -121,13 +127,14 @@ class StartArgs: enable_monitor_auth: bool = field(default=False) disable_cudagraph: bool = field(default=False) enable_prefill_cudagraph: bool = field(default=False) - prefll_cudagraph_max_handle_token: int = field(default=512) + prefill_cudagraph_max_handle_token: int = field(default=8192) graph_max_batch_size: int = field(default=256) graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) graph_max_len_in_batch: int = field(default=0) quant_type: Optional[str] = field(default=None) quant_cfg: Optional[str] = field(default=None) + expert_dtype: Optional[str] = field(default=None, metadata={"choices": ["fp8", "fp4"]}) vit_quant_type: Optional[str] = field(default=None) vit_quant_cfg: Optional[str] = field(default=None) llm_prefill_att_backend: List[str] = field( @@ -143,13 +150,14 @@ class StartArgs: default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]} ) llm_kv_quant_group_size: int = field(default=8) - sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) + sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "flashinfer"]}) penalty_counter_mode: str = field( default="gpu_counter", metadata={"choices": ["cpu_counter", "pin_mem_counter", "gpu_counter"]} ) enable_ep_moe: bool = field(default=False) ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) + enable_fused_shared_experts: bool = field(default=False) mtp_mode: Optional[str] = field( default=None, metadata={ @@ -167,8 +175,8 @@ class StartArgs: mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) kv_quant_calibration_config_path: Optional[str] = field(default=None) - nixl_pd_kv_page_num: int = field(default=16) - nixl_pd_kv_page_size: int = field(default=1024) + pd_kv_page_num: int = 
field(default=16) + pd_kv_page_size: int = field(default=1024) pd_node_id: int = field(default=-1) enable_cpu_cache: bool = field(default=False) cpu_cache_storage_size: float = field(default=2) @@ -179,6 +187,7 @@ class StartArgs: enable_dp_prompt_cache_fetch: bool = field(default=False) # zmp ports router_port: int = field(default=None) + router_profiler_port: int = field(default=None) detokenization_port: int = field(default=None) http_server_port: int = field(default=None) visual_port: int = field(default=None) @@ -189,6 +198,8 @@ class StartArgs: multi_level_kv_cache_port: int = field(default=None) # hybrid attention model (Qwen3Next) - mamba_cache_size: Optional[int] = field(default=None) - mamba_cache_ratio: Optional[float] = field(default=0.5) - mamba_ssm_data_type: Optional[str] = field(default="float32", metadata={"choices": ["bfloat16", "float32"]}) + linear_att_hash_page_size: int = field(default=512) + linear_att_page_block_num: int = field(default=10000000) + disable_linear_att_small_page_cpu_cache: bool = field(default=False) + linear_att_cache_size: Optional[int] = field(default=None) + linear_att_ssm_data_type: Optional[str] = field(default="float32", metadata={"choices": ["bfloat16", "float32"]}) diff --git a/lightllm/server/core/objs/token_chunck_hash_list.py b/lightllm/server/core/objs/token_chunck_hash_list.py index a79ff48a85..23de10353b 100644 --- a/lightllm/server/core/objs/token_chunck_hash_list.py +++ b/lightllm/server/core/objs/token_chunck_hash_list.py @@ -88,3 +88,12 @@ def clear(self): def get_all(self): return list(self.items[0 : self.size]) + + +class TokenPageLenList(CpuCachePageList): + """ + 用于记录cpu cache 每个 page 对应的真实prefix token数量, 用于支持含有 linear_att 的如qwen3.5 模型的cpu cache的 + 的最后一个页面的非满页面的碎片化处理。 + """ + + pass diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 389171ba8a..8c213914c7 100644 --- a/lightllm/server/detokenization/manager.py +++ 
b/lightllm/server/detokenization/manager.py @@ -39,7 +39,7 @@ def __init__( self.req_id_to_out: Dict[int, DecodeReq] = {} self.eos_id = args.eos_id self._init_get_token_id_to_token_str() - self.is_pd_decode_mode = self.args.run_mode == "decode" + self.is_pd_decode_mode = False self.shm_req_manager = ShmReqManager() def _init_get_token_id_to_token_str(self): diff --git a/lightllm/server/embed_cache/embed_cache_client.py b/lightllm/server/embed_cache/embed_cache_client.py index 2d62cb73e5..92609775c0 100644 --- a/lightllm/server/embed_cache/embed_cache_client.py +++ b/lightllm/server/embed_cache/embed_cache_client.py @@ -35,10 +35,9 @@ def __init__(self, create_meta_data: bool, init_shm_data: bool, pin_shm: bool = size_bytes=self.embed_cache_tensor_meta.calcu_size(), ) cache_tensor_creator = CpuCacheCreator(tensor_spec=cache_tensor_spec) - self.cpu_embed_cache_tensor, _ = cache_tensor_creator.create_or_attach( + self.cpu_embed_cache_tensor = cache_tensor_creator.create_or_attach( init_shm_data=init_shm_data, pin=pin_shm, - pin_no_blocking=False, ) return diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 74208ac4b3..dfcb2f8d9e 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -13,6 +13,7 @@ import ast import json +import os import orjson import logging import re @@ -29,6 +30,7 @@ from .api_models import Tool logger = logging.getLogger(__name__) +ENABLE_TOOL_NAME_CHECK = os.getenv("LIGHTLLM_ENABLE_TOOL_NAME_CHECK", "False").upper() in ["ON", "TRUE", "1"] TOOLS_TAG_LIST = [ "<|plugin|>", @@ -156,7 +158,7 @@ def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]: results = [] for act in action: name = act.get("name") - if name and name in tool_indices: + if name and (not ENABLE_TOOL_NAME_CHECK or name in tool_indices): results.append( ToolCallItem( tool_index=-1, # Caller should update this based on the actual tools array called diff --git 
a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index c9822ff618..0f1b873111 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -13,11 +13,11 @@ from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -from typing import Union, List, Tuple, Dict, Optional, AsyncGenerator +from typing import Literal, Union, List, Tuple, Dict, Optional, AsyncGenerator from websockets import ClientConnection from fastapi import Request from ..tokenizer import get_tokenizer -from ..pd_io_struct import NodeRole, ObjType, NIXLDecodeNodeInfo +from ..pd_io_struct import NodeRole, ObjType, PDDecodeNodeInfo from ..embed_cache.utils import get_shm_name_data, create_shm from ..multimodal_params import AudioItem, MultimodalParams, ImageItem from ..req_id_generator import ReqIDGenerator @@ -34,7 +34,7 @@ from lightllm.utils.statics_utils import MovingAverage from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken +from lightllm.utils.error_utils import ClientDisconnected, PDPrefillNodeStopGenToken from rpyc.utils.classic import obtain logger = init_logger(__name__) @@ -52,8 +52,9 @@ def __init__( self.multinode_req_manager = None self.nnodes = args.nnodes - self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 1) + self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 2) self._resource_lock = AsyncLock(self._shm_lock_pool.get_lock_context(0)) + self._run_reqs_count_lock = AsyncLock(self._shm_lock_pool.get_lock_context(1)) self.node_rank = args.node_rank self.disable_abort = args.nnodes > 1 and args.dp == 1 # mulitnode dp=1 mode, disable abort self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 @@ -111,17 +112,25 @@ def __init__( self.metric_client = 
MetricClient(args.metric_port) self.pd_mode: NodeRole = NodeRole(self.args.run_mode) - assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL, NodeRole.NP, NodeRole.ND] + assert self.pd_mode in [NodeRole.NORMAL, NodeRole.P, NodeRole.D] self.id_gen = ReqIDGenerator() self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() # 有的模型的vocab size 读取tokenizer和config.json中不一致 self.vocab_size = max(get_vocab_size(args.model_dir), self.tokenizer.vocab_size) - # The timemark of the latest inference(prefill/decode) which is used to check the health status of the system. - # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend. + # Timemark of the latest successful inference, used by passive /health checks. self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + + self.run_reqs_count_mark = SharedInt(f"{get_unique_server_name()}_run_reqs_count_mark") + self.run_reqs_count_mark.set_value(0) + + # 用于记录真实的--max_total_token_num 参数,当这个参数在启动参数中没有设置的时候,其是在推理进程中被分析出来的, + # 这个时候如果 --max_req_total_len > --max_total_token_num 时,如果httpserver放过一些非法的输入进入后续的模块可能 + # 会触发整个系统崩溃,所以httpserver需要知道真实的 max_total_token_num的数据,用于提前拦截非法请求等参数。 + # router 进程会在启动后向这个共享内存写入正确的max_total_token_num 参数,用于后续的请求控制。 + self.shm_max_total_token_num = SharedInt(f"{get_unique_server_name()}_shm_max_total_token_num") return def _log_stage_timing(self, group_request_id: int, start_time: float, stage: str, **kwargs): @@ -175,8 +184,19 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas): self.cache_client.root.set_items_data(update_data_ids) return + def _assert_image_token_count(self, token_num: int): + if token_num > self.args.max_image_token_count: + err_msg = ( + f"single image token count {token_num} exceeds max_image_token_count {self.args.max_image_token_count}." 
+ f"You can increase this limit by setting --max_image_token_count to a larger value when starting " + f"LightLLM. Warning: increasing this limit raises runtime OOM risk." + ) + logger.warning(err_msg) + raise ValueError(err_msg) + return + async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams): - # 只有 P 和 NORMAL 节点需要真的管理多模态资源 + # 只有 prefill 和 NORMAL 节点需要真的管理多模态资源 if self.pd_mode.is_P_or_NORMAL(): items, md5sums, tokens_nums, datas = [], [], [], [] for img in multimodal_params.images: @@ -184,6 +204,7 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, data = img.read() # must after init_imageitem_extral_params token_num = self.tokenizer.get_image_token_length(img) + self._assert_image_token_count(token_num) md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params))) md5sums.append(md5sum) img.md5 = md5sum @@ -205,7 +226,7 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, return async def _release_multimodal_resources(self, multimodal_params: MultimodalParams): - # 只有 P 和 NORMAL 节点需要真的管理多模态资源 + # 只有 prefill 和 NORMAL 节点需要真的管理多模态资源 if self.pd_mode.is_P_or_NORMAL(): if multimodal_params is not None: ids_to_release = [] @@ -239,7 +260,9 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar for img in multimodal_params.images: img_count += 1 self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params) - image_tokens += self.tokenizer.get_image_token_length(img) + token_num = self.tokenizer.get_image_token_length(img) + self._assert_image_token_count(token_num) + image_tokens += token_num for audio in multimodal_params.audios: audio_count += 1 self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params) @@ -263,12 +286,9 @@ async def generate_wrapper(results_generator): asyncio.create_task(generate_wrapper(results_generator)) return - def 
alloc_req_id(self, sampling_params, is_health_req: bool = False): + def alloc_req_id(self, sampling_params): # 请求的 id 可以由外部传入,也可以由内部生成,但是由外部传入的时候,要自己保证全局唯一性 # 否则会造成异常问题。目前限制 NORMAL 模式都使用内部id替换, P 和 D 模式按需设置 - # health 请求 request_id 为负数,直接返回 - if is_health_req: - return sampling_params.group_request_id if self.pd_mode.is_normal(): if not self.is_multinode_tp: group_request_id = self.id_gen.generate_id() @@ -292,15 +312,15 @@ async def generate( sampling_params: SamplingParams, multimodal_params: MultimodalParams, request: Request, - is_health_req: bool = False, - # 该参数只会在 nixl pd mode 中使用,用于上报一些信息给 pd_master - nixl_pd_upload_websocket: ClientConnection = None, + # 该参数只会在 pd mode 中使用,用于上报一些信息给 pd_master + pd_upload_websocket: ClientConnection = None, # 用于等待 pd_master 下发的交换信息 - nixl_pd_event: asyncio.Event = None, + pd_event: asyncio.Event = None, ) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]: + start_time = time.time() request_headers = request.headers if request is not None else {} - group_request_id = self.alloc_req_id(sampling_params, is_health_req) + group_request_id = self.alloc_req_id(sampling_params) audio_count = len(multimodal_params.audios) if multimodal_params is not None else 0 image_count = len(multimodal_params.images) if multimodal_params is not None else 0 self._log_stage_timing( @@ -311,6 +331,9 @@ async def generate( image_count=image_count, ) + async with self._run_reqs_count_lock: + self.run_reqs_count_mark.set_value(self.run_reqs_count_mark.get_value() + 1) + try: original_multimodal_params = None if self.is_multinode_tp_master: @@ -335,40 +358,40 @@ async def generate( ) prompt_tokens = len(prompt_ids) - # 监控 - if group_request_id > 0: - self.metric_client.counter_inc("lightllm_request_count") - self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens) - self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens) prompt_ids = await 
self._check_and_repair_length(prompt_ids, sampling_params) + # 监控 + self.metric_client.counter_inc("lightllm_request_count") + self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens) + self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens) + self._log_stage_timing( group_request_id, start_time, "check_and_repair_length_done", ) - if nixl_pd_upload_websocket is not None and not is_health_req and self.pd_mode.is_NP(): - # 在 nixl pd 模式下的 p 节点, 为了更好的兼容多模态的推理流程,np 节点需要先上报其 encode 好的 prompt ids 信息,然后 - # 再等待 pd_master 传输下来的对应的进行 decode 节点的decode信息,然后再执行后续的流程 + if pd_upload_websocket is not None and self.pd_mode.is_P(): + # 在 pd 模式下的 prefill 节点,为了兼容多模态推理流程,需要先上报 encode 好的 prompt ids, + # 再等待 pd_master 下发对应请求的 decode 节点信息,然后执行后续流程。 logger.info( - f"nixl prefill node upload group_req_id {group_request_id} prompt ids len : {len(prompt_ids)}" + f"pd prefill node upload group_req_id {group_request_id} prompt ids len : {len(prompt_ids)}" ) - await nixl_pd_upload_websocket.send( - pickle.dumps((ObjType.NIXL_UPLOAD_NP_PROMPT_IDS, group_request_id, prompt_ids)) + await pd_upload_websocket.send( + pickle.dumps((ObjType.PD_UPLOAD_PREFILL_PROMPT_IDS, group_request_id, prompt_ids)) ) try: - await asyncio.wait_for(nixl_pd_event.wait(), timeout=80) + await asyncio.wait_for(pd_event.wait(), timeout=180) except asyncio.TimeoutError: - logger.error(f"nixl np node wait nixl_pd_event 36s time out, group_req_id {group_request_id}") - raise Exception(f"group_req_id {group_request_id} wait nixl_pd_event time out") + logger.error(f"pd prefill node wait pd_event 180s time out, group_req_id {group_request_id}") + raise Exception(f"group_req_id {group_request_id} wait pd_event time out") - decode_node_info: NIXLDecodeNodeInfo = nixl_pd_event.decode_node_info - sampling_params.nixl_params.set(pickle.dumps(decode_node_info)) + decode_node_info: PDDecodeNodeInfo = pd_event.decode_node_info + 
sampling_params.pd_kv_trans_params.set(pickle.dumps(decode_node_info)) if decode_node_info.ready_kv_len == len(prompt_ids) - 1: # 如果 decode 节点的 ready_kv_len 和 prefill encode 的 len(prompt ids) -1 相等,说明不需要进行 prefill - # 直接 raise NixlPrefillNodeStopGenToken - raise NixlPrefillNodeStopGenToken(group_request_id=group_request_id) + # 直接 raise PDPrefillNodeStopGenToken + raise PDPrefillNodeStopGenToken(group_request_id=group_request_id) # 申请资源并存储 alloced_req_indexes = [] @@ -444,8 +467,12 @@ async def generate( yield sub_req_id, request_output, metadata, finish_status - except Exception as e: - logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") + except (ClientDisconnected, Exception) as e: + logger.warning(f"group_request_id: {group_request_id} has exception {str(e)}") + + if isinstance(e, ClientDisconnected): + logger.warning(f"group_request_id: {group_request_id} {e.reason}") + # error need to release multimodel resources. # 对于还没有形成正式请求对象管理的多模态资源,需要单独自己释放 # 已经放入到 req_id_to_out_inf 中的请求对象,由统一的回收循环 @@ -454,6 +481,9 @@ async def generate( await self._release_multimodal_resources(multimodal_params) await self.abort(group_request_id) raise e + finally: + async with self._run_reqs_count_lock: + self.run_reqs_count_mark.set_value(self.run_reqs_count_mark.get_value() - 1) return def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]: @@ -487,6 +517,15 @@ async def _encode( self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams, sampling_params: SamplingParams ): if isinstance(prompt, str): + # pre-verify prompt length + # The average character length per token is always less than 8 + # TODO: automatically calculate the average character length per token + max_prompt_chars = self.max_req_total_len * 8 + if len(prompt) > max_prompt_chars: + raise ValueError( + f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, " + f"the request is rejected before tokenization." 
+ ) if self.enable_multimodal: assert ( len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity @@ -494,15 +533,22 @@ async def _encode( if multimodal_params.audios: assert not self.args.disable_audio, "audio multimodal not enabled" await self._alloc_multimodal_resources(multimodal_params, sampling_params) - prompt_ids = self.tokenizer.encode( - prompt, multimodal_params, add_special_tokens=sampling_params.add_special_tokens + prompt_ids = await asyncio.to_thread( + self.tokenizer.encode, + prompt, + multimodal_params, + add_special_tokens=sampling_params.add_special_tokens, ) else: - prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=sampling_params.add_special_tokens) + prompt_ids = await asyncio.to_thread( + self.tokenizer.encode, + prompt, + add_special_tokens=sampling_params.add_special_tokens, + ) if self.args.detail_log: logger.debug( - f"req_id: {sampling_params.group_request_id} prompt: {prompt},\n" + f"req_id: {sampling_params.group_request_id} prompt: {prompt}\n" f"samplingparmas: {sampling_params.to_dict()}\n" f"token_ids: {prompt_ids}" ) @@ -521,37 +567,34 @@ async def _encode( raise ValueError(f"prompt format error, get type{type(prompt)}") return + def get_real_supported_max_req_total_len(self): + # 得到系统真正能支持的最大长度,同时收到启动参数中模型支持长度的限制,也收到token容量的限制。 + return min(self.shm_max_total_token_num.get_value() - 36, self.max_req_total_len) + async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params: SamplingParams): if not prompt_ids: raise ValueError("prompt_ids is empty") prompt_tokens = len(prompt_ids) - if prompt_tokens + sampling_params.max_new_tokens > self.max_req_total_len: - # use long_truncation_mode to truncate long input len req. 
- if self.args.long_truncation_mode is None: - # 修改默认逻辑,如果 prompt_tokens + max_new_tokens 长度超过总的允许长度,则将 - # 修改 max_new_tokens 的值,使其满足合法约束。 - new_max_new_tokens = self.max_req_total_len - prompt_tokens - if new_max_new_tokens > 0: - logger.debug( - f"the input prompt token len {prompt_tokens} + max_new_tokens" - f"{sampling_params.max_new_tokens} > {self.max_req_total_len}," - f"so change max_new_tokens to {new_max_new_tokens}" - ) - sampling_params.max_new_tokens = new_max_new_tokens - else: - raise ValueError( - f"the input prompt token len {prompt_tokens} + max_new_tokens \ - {sampling_params.max_new_tokens} > {self.max_req_total_len}" - ) - elif self.args.long_truncation_mode == "head": - prompt_ids = prompt_ids[-(self.max_req_total_len - sampling_params.max_new_tokens) :] - elif self.args.long_truncation_mode == "center": - req_input_len = self.max_req_total_len - sampling_params.max_new_tokens - prompt_ids = prompt_ids[0 : req_input_len // 2] + prompt_ids[-(req_input_len - req_input_len // 2) :] - prompt_tokens = len(prompt_ids) - assert prompt_tokens == req_input_len + # 这里 -36 是保留一些不可预知的边界余量,防止系统出错 + real_supported_max_req_total_len = self.get_real_supported_max_req_total_len() + + if prompt_tokens + sampling_params.max_new_tokens > real_supported_max_req_total_len: + + # 修改默认逻辑,如果 prompt_tokens + max_new_tokens 长度超过总的允许长度,则将 + # 修改 max_new_tokens 的值,使其满足合法约束。 + new_max_new_tokens = real_supported_max_req_total_len - prompt_tokens + if new_max_new_tokens > 0: + logger.debug( + f"the input prompt token len {prompt_tokens} + max_new_tokens" + f"{sampling_params.max_new_tokens} > {real_supported_max_req_total_len}," + f"so change max_new_tokens to {new_max_new_tokens}" + ) + sampling_params.max_new_tokens = new_max_new_tokens else: - assert False, "error args" + raise ValueError( + f"the input prompt token len {prompt_tokens} + max_new_tokens \ + {sampling_params.max_new_tokens} > {real_supported_max_req_total_len}" + ) # last repaired req_total_len = 
len(prompt_ids) + sampling_params.max_new_tokens @@ -647,7 +690,9 @@ async def _wait_to_token_package( if not self.disable_abort and request is not None and await request.is_disconnected(): await self.abort(group_request_id) - raise Exception(f"req_id {group_request_id} disconnected") + raise ClientDisconnected( + group_request_id=group_request_id, reason="_wait_to_token_package check network disconnected" + ) async with req_status.lock: event.clear() @@ -661,9 +706,10 @@ async def _wait_to_token_package( if self.pd_mode.is_P() and is_first_token: metadata["prompt_ids"] = prompt_ids - prompt_cache_len = metadata.pop("prompt_cache_len", 0) + gpu_prompt_cache_len = metadata.pop("prompt_cache_len", 0) cpu_prompt_cache_len = metadata.pop("cpu_prompt_cache_len", 0) disk_prompt_cache_len = metadata.pop("disk_prompt_cache_len", 0) + metadata["prompt_cache_len"] = gpu_prompt_cache_len + cpu_prompt_cache_len + disk_prompt_cache_len sub_req_id_to_mtp_accepted_token_num[sub_req_id] = metadata.get("mtp_accepted_token_num", 0) if is_first_token: @@ -687,9 +733,12 @@ async def _wait_to_token_package( self.per_token_costs.add(mean_per_token_cost_time_ms) x_request_id = request.headers.get("X-Request-Id", "") if request is not None else "" x_session_id = request.headers.get("X-Session-Id", "") if request is not None else "" - prompt_cache_ratio = prompt_cache_len / prompt_tokens + gpu_prompt_cache_ratio = gpu_prompt_cache_len / prompt_tokens cpu_prompt_cache_ratio = cpu_prompt_cache_len / prompt_tokens disk_prompt_cache_ratio = disk_prompt_cache_len / prompt_tokens + prompt_cache_len = gpu_prompt_cache_len + cpu_prompt_cache_len + disk_prompt_cache_len + prompt_cache_ratio = prompt_cache_len / prompt_tokens + generation_throughput = out_token_counter / max(total_cost_time_ms / 1000.0, 1e-6) mtp_avg_token_per_step = out_token_counter / max( (out_token_counter - sum(sub_req_id_to_mtp_accepted_token_num.values())), 1 @@ -702,9 +751,9 @@ async def _wait_to_token_package( 
f"total_cost_time:{total_cost_time_ms}ms,out_token_counter:{out_token_counter} " f"mean_per_token_cost_time: {mean_per_token_cost_time_ms}ms " f"prompt_token_num:{prompt_tokens} " - f"gpu cache hit: {prompt_cache_len > 0} " - f"gpu_prompt_cache_len:{prompt_cache_len} " - f"gpu_prompt_cache_ratio:{prompt_cache_ratio} " + f"gpu cache hit: {gpu_prompt_cache_ratio > 0} " + f"gpu_prompt_cache_len:{gpu_prompt_cache_len} " + f"gpu_prompt_cache_ratio:{gpu_prompt_cache_ratio} " f"cpu cache hit: {cpu_prompt_cache_len > 0} " f"cpu_prompt_cache_len:{cpu_prompt_cache_len} " f"cpu_prompt_cache_ratio:{cpu_prompt_cache_ratio} " @@ -713,11 +762,13 @@ async def _wait_to_token_package( f"disk_prompt_cache_ratio:{disk_prompt_cache_ratio} " f"mtp_avg_token_per_step:{mtp_avg_token_per_step} " ) - if group_request_id < 0: - # health 探测请求,不记录日志和监控 - return + self.metric_client.histogram_observe("lightllm_cache_length", prompt_cache_len) self.metric_client.histogram_observe("lightllm_cache_ratio", prompt_cache_ratio) + self.metric_client.counter_inc_by("lightllm_prompt_tokens_total", prompt_tokens) + self.metric_client.counter_inc_by("lightllm_generation_tokens_total", out_token_counter) + self.metric_client.gauge_set("lightllm_cache_hit_rate", prompt_cache_ratio) + self.metric_client.gauge_set("lightllm_gen_throughput", generation_throughput) self.metric_client.histogram_observe( "lightllm_request_inference_duration", total_cost_time_ms / 1000.0 ) @@ -749,6 +800,23 @@ async def abort(self, group_req_id: int) -> bool: logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}") return True + def _get_router_profiler_client(self): + router_profiler_client = getattr(self, "router_profiler_client", None) + if router_profiler_client is None or getattr(router_profiler_client, "closed", False): + from lightllm.utils.retry_utils import retry + + self.router_profiler_client = retry(max_attempts=20, wait_time=0.5)(rpyc.connect)( + "localhost", + self.args.router_profiler_port, + 
config={"allow_pickle": True}, + ) + self.router_profiler_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return self.router_profiler_client + + async def profiler_cmd(self, cmd: Literal["start", "stop"]): + client = self._get_router_profiler_client() + client.root.profiler_cmd(cmd) + async def recycle_resource_loop(self): pre_time_mark = time.time() diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index a646d4f4cc..dcf0c89fed 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -6,6 +6,9 @@ import httpx import base64 import weakref +import os +import signal +import sys from typing import Dict, Optional, Union, List from websockets import ClientConnection from lightllm.server.pd_io_struct import NodeRole, ObjType @@ -17,7 +20,7 @@ from ..pd_io_struct import PD_Master_Obj from lightllm.server.core.objs import StartArgs from lightllm.server.core.objs import SamplingParams -from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken +from lightllm.utils.error_utils import PDPrefillNodeStopGenToken logger = init_logger(__name__) @@ -31,7 +34,12 @@ async def timer_log(manager: HttpServerManager): async def pd_handle_loop(manager: HttpServerManager): - assert manager.args.host not in ["127.0.0.1", "localhost"], "pd mode must specify host ip" + if manager.args.host in ["127.0.0.1", "localhost"]: + logger.error("pd mode must specify host ip, not use 127.0.0.1 or localhost") + # kill father process to trigger graceful exit, avoid orphan process + os.kill(os.getppid(), signal.SIGINT) + sys.exit(-1) + if manager.args.host in ["0.0.0.0"]: manager.host_ip = get_hostname_ip() else: @@ -107,8 +115,8 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O if obj[0] == ObjType.REQ: prompt, sampling_params, multimodal_params = obj[1] group_req_id = sampling_params.group_request_id - nixl_pd_event = asyncio.Event() - 
group_req_id_to_event[group_req_id] = nixl_pd_event + pd_event = asyncio.Event() + group_req_id_to_event[group_req_id] = pd_event asyncio.create_task( _pd_process_generate( manager=manager, @@ -116,8 +124,8 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O sampling_params=sampling_params, multimodal_params=multimodal_params, forwarding_queue=forwarding_queue, - nixl_pd_upload_websocket=websocket, - nixl_pd_event=nixl_pd_event, + pd_upload_websocket=websocket, + pd_event=pd_event, ) ) elif obj[0] == ObjType.ABORT: @@ -133,14 +141,14 @@ async def delayed_abort_task(group_req_id, retry_count): asyncio.create_task(delayed_abort_task(group_req_id=group_req_id, retry_count=4)) - elif obj[0] == ObjType.NIXL_REQ_DECODE_NODE_INFO: + elif obj[0] == ObjType.PD_REQ_DECODE_NODE_INFO: _, group_req_id, decode_node_info = obj - nixl_pd_event = group_req_id_to_event.pop(group_req_id, None) - if nixl_pd_event is None: - logger.error(f"error in find nixl_pd_event, info: {obj}") + pd_event = group_req_id_to_event.pop(group_req_id, None) + if pd_event is None: + logger.error(f"error in find pd_event, info: {obj}") continue - nixl_pd_event.decode_node_info = decode_node_info - nixl_pd_event.set() + pd_event.decode_node_info = decode_node_info + pd_event.set() else: logger.error(f"recevie error obj {str(obj)}") @@ -201,8 +209,8 @@ async def _pd_process_generate( sampling_params: SamplingParams, multimodal_params: Dict, forwarding_queue: AsyncQueue, - nixl_pd_upload_websocket: ClientConnection, - nixl_pd_event: asyncio.Event, + pd_upload_websocket: ClientConnection, + pd_event: asyncio.Event, ): try: async for sub_req_id, request_output, metadata, finish_status in manager.generate( @@ -210,16 +218,13 @@ async def _pd_process_generate( sampling_params=sampling_params, multimodal_params=multimodal_params, request=None, - nixl_pd_upload_websocket=nixl_pd_upload_websocket, - nixl_pd_event=nixl_pd_event, + pd_upload_websocket=pd_upload_websocket, + 
pd_event=pd_event, ): - # p d 模式下,将 token 数据放入到转发队列中, 请求id 小于0的请求是health探测请求,不用转发。 - is_health_check_req = sub_req_id < 0 - if not is_health_check_req: - metadata["node_mode"] = manager.args.run_mode - await forwarding_queue.put((sub_req_id, request_output, metadata, finish_status)) - except NixlPrefillNodeStopGenToken as e: - logger.info(f"nixl prefill node stop gen token for group_request_id {e.group_request_id}") + metadata["node_mode"] = manager.args.run_mode + await forwarding_queue.put((sub_req_id, request_output, metadata, finish_status)) + except PDPrefillNodeStopGenToken as e: + logger.info(f"pd prefill node stop gen token for group_request_id {e.group_request_id}") except BaseException as e: logger.error(str(e)) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index d6a1a58b05..104da9f26e 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -9,7 +9,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict, Optional from lightllm.server.core.objs import FinishStatus -from ..pd_io_struct import PD_Client_Obj, UpKVStatus, NixlUpKVStatus, ObjType, NodeRole, NIXLDecodeNodeInfo +from ..pd_io_struct import PD_Client_Obj, PDUpKVStatus, ObjType, PDDecodeNodeInfo from lightllm.server.core.objs import SamplingParams, StartArgs from ..multimodal_params import MultimodalParams from ..tokenizer import get_tokenizer @@ -19,7 +19,7 @@ from lightllm.server.metrics.manager import MetricClient from lightllm.utils.statics_utils import MovingAverage from lightllm.server.httpserver.manager import AsyncQueue -from lightllm.utils.error_utils import ServerBusyError +from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError from lightllm.utils.envs_utils import get_pd_split_max_new_tokens from .pd_selector import create_selector @@ -48,6 +48,11 @@ def __init__( 
self.per_token_costs = MovingAverage() return + def get_real_supported_max_req_total_len(self): + # HttpServerManager.generate 会借用 _check_and_repair_length(self, ...),其中会调用本方法。 + # PD master 无本地 token 池 shm 计数;上限与启动参数及子节点对齐的 max_req_total_len 一致。 + return self.max_req_total_len + async def register_pd(self, pd_info_json, websocket): self.pd_manager.register_pd(pd_info_json, websocket) return @@ -56,7 +61,7 @@ async def remove_pd(self, pd_info_json): self.pd_manager.remove_pd(pd_info_json) return - async def update_req_status(self, upkv_status: Union[UpKVStatus, NixlUpKVStatus]): + async def update_req_status(self, upkv_status: PDUpKVStatus): try: group_request_id = convert_sub_id_to_group_id(upkv_status.group_request_id) up_status_event = self.req_id_to_out_inf[group_request_id].up_status_event @@ -76,7 +81,16 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar for img in multimodal_params.images: img_count += 1 self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params) - image_tokens += self.tokenizer.get_image_token_length(img) + token_num = self.tokenizer.get_image_token_length(img) + if token_num > self.args.max_image_token_count: + err_msg = ( + f"the image token count {token_num} > max_image_token_count {self.args.max_image_token_count}. " + f"You can increase this limit by setting --max_image_token_count to a larger value when starting " + f"LightLLM. Warning: increasing this limit raises runtime OOM risk." 
+ ) + logger.warning(err_msg) + raise ValueError(err_msg) + image_tokens += token_num for audio in multimodal_params.audios: audio_count += 1 self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params) @@ -163,8 +177,12 @@ async def generate( await self.remove_req(group_request_id=block_group_request_id) - except BaseException as e: + except (ClientDisconnected, BaseException) as e: logger.error(f"has exception {str(e)}") + + if isinstance(e, ClientDisconnected): + logger.warning(f"group_request_id: {origin_group_request_id} {e.reason}") + try: await self.abort(block_group_request_id, p_node=p_node, d_node=d_node) except: @@ -186,7 +204,7 @@ async def _log_req_header(self, request: Request, group_request_id: int): ) return - async def fetch_stream( + async def fetch_pd_stream( self, p_node: PD_Client_Obj, d_node: PD_Client_Obj, @@ -202,104 +220,25 @@ async def fetch_stream( self.req_id_to_out_inf[group_request_id] = req_status up_status_event = req_status.up_status_event - - d_start_args = d_node.start_args - decode_node_dict = { - "node_id": d_start_args["pd_node_id"], - "ip": d_start_args["host"], - "rpyc_port": d_start_args["pd_decode_rpyc_port"], - "max_new_tokens": sampling_params.max_new_tokens - 1, - } + prefill_prompt_ids_event = req_status.prefill_prompt_ids_event old_max_new_tokens = sampling_params.max_new_tokens sampling_params.max_new_tokens = 1 - sampling_params.move_kv_to_decode_node.initialize(decode_node_dict if old_max_new_tokens != 1 else None) - sampling_params.suggested_dp_index = -1 - await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) - while True: - await req_status.wait_to_ready() - if await request.is_disconnected(): - raise Exception(f"req_id {group_request_id} disconnected") - - if await req_status.can_read(self.req_id_to_out_inf): - token_list = await req_status.pop_all_tokens() - for sub_req_id, request_output, metadata, finish_status in token_list: - if 
old_max_new_tokens != 1: - finish_status = FinishStatus(FinishStatus.NO_FINISH) - else: - finish_status = FinishStatus(FinishStatus.FINISHED_LENGTH) - # 得到 p 节点返回的 prompt_ids 信息 - if metadata.get("prompt_ids", None) is not None: - prompt_ids = metadata.get("prompt_ids") - prompt_ids.append(metadata.get("id")) - yield sub_req_id, request_output, metadata, finish_status - break - - # 如果只需要一个输出 token,prefill 完就直接结束掉吧 - if old_max_new_tokens == 1: - return - try: - await asyncio.wait_for(up_status_event.wait(), timeout=60) + await asyncio.wait_for(prefill_prompt_ids_event.wait(), timeout=60) except asyncio.TimeoutError: - logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.") - raise ServerBusyError() - - sampling_params.move_kv_to_decode_node.initialize(None) - sampling_params.max_new_tokens = old_max_new_tokens - 1 - upkv_status: UpKVStatus = up_status_event.upkv_status - sampling_params.suggested_dp_index = upkv_status.dp_index - - await d_node.websocket.send_bytes( - pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams()))) - ) - - while True: - await req_status.wait_to_ready() - if await request.is_disconnected(): - raise Exception(f"req_id {group_request_id} disconnected") - if await req_status.can_read(self.req_id_to_out_inf): - token_list = await req_status.pop_all_tokens() - for sub_req_id, request_output, metadata, finish_status in token_list: - yield sub_req_id, request_output, metadata, finish_status - - return - - async def fetch_nixl_stream( - self, - p_node: PD_Client_Obj, - d_node: PD_Client_Obj, - prompt: Union[str, List[int]], - sampling_params: SamplingParams, - multimodal_params: MultimodalParams, - request: Request, - ): - group_request_id = sampling_params.group_request_id - sampling_params.pd_master_node_id.initialize(self.args.pd_node_id) - - req_status = ReqStatus(group_request_id, p_node, d_node) - self.req_id_to_out_inf[group_request_id] = req_status - - up_status_event = 
req_status.up_status_event - nixl_np_up_prompt_ids_event = req_status.nixl_np_up_prompt_ids_event - - old_max_new_tokens = sampling_params.max_new_tokens - sampling_params.max_new_tokens = 1 - await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) - - try: - await asyncio.wait_for(nixl_np_up_prompt_ids_event.wait(), timeout=60) - except asyncio.TimeoutError: - logger.warning(f"group_request_id: {group_request_id} wait np up prompt ids time out") + logger.warning(f"group_request_id: {group_request_id} wait prefill prompt ids time out") raise ServerBusyError() if await request.is_disconnected(): - raise Exception(f"req_id {group_request_id} disconnected") + raise ClientDisconnected( + group_request_id=group_request_id, reason="fetch_pd_stream prefill period check network disconnected" + ) - prompt_ids = nixl_np_up_prompt_ids_event.prompt_ids - logger.info(f"group_request_id: {group_request_id} get np up prompt ids len {len(prompt_ids)}") + prompt_ids = prefill_prompt_ids_event.prompt_ids + logger.info(f"group_request_id: {group_request_id} get prefill prompt ids len {len(prompt_ids)}") sampling_params.max_new_tokens = old_max_new_tokens await d_node.websocket.send_bytes( @@ -307,34 +246,37 @@ async def fetch_nixl_stream( ) try: - await asyncio.wait_for(up_status_event.wait(), timeout=60) + await asyncio.wait_for(up_status_event.wait(), timeout=180) except asyncio.TimeoutError: logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.") raise ServerBusyError() # 将 decode 节点上报的当前请求使用的decode节点的信息下发给 p 节点,这样 p 节点才知道将 kv 传输给那个 d 节点。 - upkv_status: NixlUpKVStatus = up_status_event.upkv_status - nixl_params: bytes = upkv_status.nixl_params - decode_node_info: NIXLDecodeNodeInfo = pickle.loads(nixl_params) + upkv_status: PDUpKVStatus = up_status_event.upkv_status + pd_kv_trans_params: bytes = upkv_status.pd_kv_trans_params + decode_node_info: PDDecodeNodeInfo = 
pickle.loads(pd_kv_trans_params) await p_node.websocket.send_bytes( - pickle.dumps((ObjType.NIXL_REQ_DECODE_NODE_INFO, group_request_id, decode_node_info)) + pickle.dumps((ObjType.PD_REQ_DECODE_NODE_INFO, group_request_id, decode_node_info)) ) first_token_gen = False while True: await req_status.wait_to_ready() if await request.is_disconnected(): - raise Exception(f"req_id {group_request_id} disconnected") + raise ClientDisconnected( + group_request_id=group_request_id, + reason="fetch_pd_stream decode period check network disconnected", + ) if await req_status.can_read(self.req_id_to_out_inf): token_list = await req_status.pop_all_tokens() for sub_req_id, request_output, metadata, finish_status in token_list: output_index = metadata.get("count_output_tokens") - # 因为 nixl 的 prefill 和 decode 节点都有可能上报首token,所以需要做一下过滤。 + # 因为 pd 的 prefill 和 decode 节点都有可能上报首token,所以需要做一下过滤。 if output_index == 1: if first_token_gen is False: first_token_gen = True node_run_mode = metadata.pop("node_mode", None) - if node_run_mode == "nixl_prefill": + if node_run_mode == "prefill": if old_max_new_tokens != 1 and finish_status.is_finished_length(): finish_status = FinishStatus(FinishStatus.NO_FINISH) yield sub_req_id, request_output, metadata, finish_status @@ -365,15 +307,13 @@ async def _wait_to_token_package( is_first_token = True sub_req_id_to_mtp_accepted_token_num: Dict[int, int] = {} - client_mode: NodeRole = NodeRole(d_node.mode) - - fetch_stream = self.fetch_nixl_stream if client_mode.is_NP_or_ND() else self.fetch_stream - - async for sub_req_id, out_str, metadata, finish_status in fetch_stream( + async for sub_req_id, out_str, metadata, finish_status in self.fetch_pd_stream( p_node, d_node, prompt, sampling_params, multimodal_params, request ): if await request.is_disconnected(): - raise Exception(f"req_id {group_request_id} disconnected") + raise ClientDisconnected( + group_request_id=group_request_id, reason="_wait_to_token_package check network disconnected" + ) prompt_tokens 
= metadata["prompt_tokens"] out_token_counter += 1 @@ -491,16 +431,16 @@ async def handle_loop(self): req_status.event.set() except: pass - elif obj[0] == ObjType.NIXL_UPLOAD_NP_PROMPT_IDS: + elif obj[0] == ObjType.PD_UPLOAD_PREFILL_PROMPT_IDS: _, group_req_id, prompt_ids = obj try: req_status: ReqStatus = self.req_id_to_out_inf[group_req_id] async with req_status.lock: - req_status.nixl_np_up_prompt_ids_event.prompt_ids = prompt_ids - req_status.nixl_np_up_prompt_ids_event.set() + req_status.prefill_prompt_ids_event.prompt_ids = prompt_ids + req_status.prefill_prompt_ids_event.set() except: logger.error( - f"NIXL_UPLOAD_NP_PROMPT_IDS fail find req status for group_req_id: {group_req_id}" + f"PD_UPLOAD_PREFILL_PROMPT_IDS fail find req status for group_req_id: {group_req_id}" ) else: logger.error(f"recevie error obj {obj}") @@ -523,7 +463,7 @@ def __init__(self, req_id, p_node, d_node) -> None: self.lock = asyncio.Lock() self.event = asyncio.Event() self.up_status_event = asyncio.Event() - self.nixl_np_up_prompt_ids_event = asyncio.Event() + self.prefill_prompt_ids_event = asyncio.Event() self.out_token_info_list: List[Tuple[int, str, dict, FinishStatus]] = [] self.p_node: PD_Client_Obj = p_node self.d_node: PD_Client_Obj = d_node @@ -573,10 +513,10 @@ def register_pd(self, pd_info_json, websocket): pd_client.websocket = websocket self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client - if pd_client.mode in ["prefill", "nixl_prefill"]: + if pd_client.mode == "prefill": self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.prefill_nodes.append(pd_client) - elif pd_client.mode in ["decode", "nixl_decode"]: + elif pd_client.mode == "decode": self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes.append(pd_client) else: diff --git a/lightllm/server/metrics/manager.py b/lightllm/server/metrics/manager.py index f3b1a5275b..a95ddc0236 100644 --- 
a/lightllm/server/metrics/manager.py +++ b/lightllm/server/metrics/manager.py @@ -48,6 +48,9 @@ def on_disconnect(self, conn): def exposed_counter_inc(self, name: str, label: str = None) -> None: return self.monitor.counter_inc(name, label) + def exposed_counter_inc_by(self, name: str, amount: float) -> None: + return self.monitor.counter_inc_by(name, amount) + def exposed_histogram_observe(self, name: str, value: float, label: str = None) -> None: return self.monitor.histogram_observe(name, value, label) @@ -106,6 +109,13 @@ def inner_func(): self._append_task(inner_func) return + def counter_inc_by(self, *args, **kwargs): + def inner_func(): + return self.conn.root.counter_inc_by(*args, **kwargs) + + self._append_task(inner_func) + return + def histogram_observe(self, *args, **kwargs): def inner_func(): return self.conn.root.histogram_observe(*args, **kwargs) diff --git a/lightllm/server/metrics/metrics.py b/lightllm/server/metrics/metrics.py index 130f32c7a7..0d42462c3f 100644 --- a/lightllm/server/metrics/metrics.py +++ b/lightllm/server/metrics/metrics.py @@ -27,6 +27,11 @@ "lightllm_cache_ratio": "cache length / input_length", "lightllm_batch_current_max_tokens": "dynamic max token used for current batch", "lightllm_request_mtp_avg_token_per_step": "Average number of tokens per step", + "lightllm_prompt_tokens_total": "Total number of prefill tokens processed", + "lightllm_generation_tokens_total": "Total number of generation tokens processed", + "lightllm_cache_hit_rate": "Prefix cache hit rate of latest completed request", + "lightllm_gen_throughput": "Generation throughput of latest completed request (tokens/s)", + "lightllm_num_running_reqs": "Number of running requests", } @@ -60,6 +65,7 @@ def __init__(self, args): self.init_metrics(args) def init_metrics(self, args): + self.model_name = args.model_name self.create_histogram("lightllm_request_duration", self.duration_buckets) self.create_histogram("lightllm_request_validation_duration", 
self.duration_buckets) @@ -100,40 +106,43 @@ def init_metrics(self, args): mtp_avg_token_per_step_buckets = [1.0, 2.0] self.create_histogram("lightllm_request_mtp_avg_token_per_step", mtp_avg_token_per_step_buckets) + self.create_counter("lightllm_prompt_tokens_total") + self.create_counter("lightllm_generation_tokens_total") + self.create_gauge("lightllm_cache_hit_rate") + self.create_gauge("lightllm_gen_throughput") + self.create_gauge("lightllm_num_running_reqs") + def create_histogram(self, name, buckets, labelnames=None): - if labelnames is None: - histogram = Histogram(name, MONITOR_INFO[name], buckets=buckets, registry=self.registry) - else: - histogram = Histogram( - name, MONITOR_INFO[name], labelnames=labelnames, buckets=buckets, registry=self.registry - ) + all_labels = ["model_name"] + (labelnames or []) + histogram = Histogram(name, MONITOR_INFO[name], labelnames=all_labels, buckets=buckets, registry=self.registry) self.monitor_registry[name] = histogram def create_counter(self, name, labelnames=None): - if labelnames is None: - histogram = Counter(name, MONITOR_INFO[name], registry=self.registry) - else: - histogram = Counter(name, MONITOR_INFO[name], labelnames=labelnames, registry=self.registry) - self.monitor_registry[name] = histogram + all_labels = ["model_name"] + (labelnames or []) + counter = Counter(name, MONITOR_INFO[name], labelnames=all_labels, registry=self.registry) + self.monitor_registry[name] = counter def create_gauge(self, name): - gauge = Gauge(name, MONITOR_INFO[name], registry=self.registry) + gauge = Gauge(name, MONITOR_INFO[name], labelnames=["model_name"], registry=self.registry) self.monitor_registry[name] = gauge def counter_inc(self, name, label=None): if label is None: - self.monitor_registry[name].inc() + self.monitor_registry[name].labels(model_name=self.model_name).inc() else: - self.monitor_registry[name].labels(method=label).inc() + self.monitor_registry[name].labels(model_name=self.model_name, method=label).inc() + + 
def counter_inc_by(self, name, amount): + self.monitor_registry[name].labels(model_name=self.model_name).inc(amount) def histogram_observe(self, name, value, label=None): if label is None: - self.monitor_registry[name].observe(value) + self.monitor_registry[name].labels(model_name=self.model_name).observe(value) else: - self.monitor_registry[name].labels(method=label).observe(value) + self.monitor_registry[name].labels(model_name=self.model_name, method=label).observe(value) def gauge_set(self, name, value): - self.monitor_registry[name].set(value) + self.monitor_registry[name].labels(model_name=self.model_name).set(value) def push_metrices(self): if self.gateway_url is not None: diff --git a/lightllm/server/multi_level_kv_cache/cpu_cache_client.py b/lightllm/server/multi_level_kv_cache/cpu_cache_client.py index e4f37c0480..33da63ab56 100644 --- a/lightllm/server/multi_level_kv_cache/cpu_cache_client.py +++ b/lightllm/server/multi_level_kv_cache/cpu_cache_client.py @@ -37,10 +37,9 @@ def __init__(self, only_create_meta_data: bool, init_shm_data: bool): size_bytes=self.kv_cache_tensor_meta.calcu_size(), ) tensor_creator = CpuCacheCreator(tensor_spec=tensor_spec) - self.cpu_kv_cache_tensor, self.attach_shm_handle = tensor_creator.create_or_attach( + self.cpu_kv_cache_tensor = tensor_creator.create_or_attach( init_shm_data=init_shm_data, pin=not init_shm_data, - pin_no_blocking=True, ) return @@ -116,7 +115,16 @@ def update_pages_status_to_ready( page_list: List[int], deref: bool = True, disk_offload_enable: bool = False, + token_num_in_page_list: Optional[int] = None, ): + """ + token_num_in_page_list, 只有在 disk_offload_enable 为True时, 需要传入,用于 + 判断当前请求的长度是否适合将其卸载到 disk 中, 避免往disk cache中卸载过短的数据 + 照成性能下降。 + """ + if disk_offload_enable and token_num_in_page_list is None: + raise ValueError("token_num_in_page_list must be provided when disk_offload_enable is True") + offload_candidates: List[int] = [] page_items = self.page_items.linked_items not_exist_none_page = True @@ 
-143,11 +151,7 @@ def update_pages_status_to_ready( # 控制prompt长度,较短的prompt不进行disk offload limit_length = get_disk_cache_prompt_limit_length() - if ( - disk_offload_enable - and offload_candidates - and len(page_list) * self.args.cpu_cache_token_page_size >= limit_length - ): + if disk_offload_enable and offload_candidates and token_num_in_page_list >= limit_length: # 加引用计数,落盘成功后再减掉 for offload_page_index in offload_candidates: offload_page_item: _CpuPageStatus = page_items[offload_page_index] diff --git a/lightllm/server/multi_level_kv_cache/manager.py b/lightllm/server/multi_level_kv_cache/manager.py index 1de1b502c9..0a7dec0005 100644 --- a/lightllm/server/multi_level_kv_cache/manager.py +++ b/lightllm/server/multi_level_kv_cache/manager.py @@ -176,7 +176,27 @@ def _handle_group_req_multi_cache_match(self, group_req_indexes: GroupReqIndexes else: # 匹配 disk cache并load到cpu cache finded_page_indexes, disk_page_num = self._disk_cache_match(token_hash_list, all_pages) - req.disk_prompt_cache_len = disk_page_num * self.args.cpu_cache_token_page_size + + try: + token_hash_page_len_list = req.token_hash_page_len_list.get_all() + + if disk_page_num == 0 or len(finded_page_indexes) == 0: + req.disk_prompt_cache_len = 0 + else: + all_page_num = len(finded_page_indexes) + cpu_match_page_num = all_page_num - disk_page_num + + if cpu_match_page_num == 0: + cpu_match_page_len = 0 + else: + cpu_match_page_len = token_hash_page_len_list[cpu_match_page_num - 1] + + req.disk_prompt_cache_len = token_hash_page_len_list[all_page_num - 1] - cpu_match_page_len + except Exception as e: + # 因为不清楚上面的代码是否存在边界 bug,调用者是多线程的,自己打日志记录,避免 + # 日志无法记录, 无法排查问题。 + logger.exception(f"calculate disk prompt cache len has exception {str(e)}") + raise e while not self.cpu_cache_client.check_allpages_ready(finded_page_indexes): time.sleep(0.01) @@ -236,6 +256,7 @@ def start_multi_level_kv_cache_manager(args, pipe_writer): args=args, ) except Exception as e: + logger.exception(f"start 
multi_level_kv_cache_manager has exception {str(e)}") pipe_writer.send(str(e)) raise diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 6210628751..9541e434c8 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -4,13 +4,16 @@ import librosa import base64 import numpy as np -from typing import List +from typing import List, Tuple, Optional from io import BytesIO -from PIL import Image +from concurrent.futures import ThreadPoolExecutor +from PIL import Image, ImageFile from fastapi import Request from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data +from lightllm.utils.error_utils import ClientDisconnected from lightllm.utils.multimodal_utils import fetch_resource from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_env_start_args logger = init_logger(__name__) @@ -63,6 +66,10 @@ async def preload(self, request: Request): self._preload_data = audio_values.tobytes() return + except ClientDisconnected as e: + # Preserve client-disconnect signal so the API layer can return 499 + # without the noisy 'Failed to read audio' error logs. 
+ raise e except Exception as e: raise ValueError(f"Failed to read audio type={self._type}, data[:100]={self._data[:100]}: {e}!") @@ -125,6 +132,9 @@ def __init__(self, **kwargs): self.extra_params = {} async def preload(self, request: Request): + + max_image_pixels = get_env_start_args().max_image_pixels + try: if self._type == "url": timeout = int(os.getenv("REQUEST_TIMEOUT", "5")) @@ -135,19 +145,48 @@ async def preload(self, request: Request): elif self._type == "image_size": # image_size 代表直接传入图片的 width,height,主要是用于一些场景 # 的 token 计数判断, 所以只需要图片长宽信息,不需要具体图片的内容信息 - self.image_w = self._data[0] - self.image_h = self._data[1] + src_w = self._data[0] + src_h = self._data[1] + self.image_w, self.image_h = _resize_image_dimensions_if_needed(src_w, src_h, max_image_pixels) + if (self.image_w, self.image_h) != (src_w, src_h): + logger.warning( + f"image_size pixels {src_w * src_h} exceed max_image_pixels={max_image_pixels}, " + f"resized to {self.image_w}x{self.image_h}" + ) return else: raise ValueError(f"cannot read image which type is {self._type}!") - with Image.open(BytesIO(img_data)) as image: - self.image_w, self.image_h = image.size - image.verify() # verify后会失效 + # Do pixel-level decoding verification in a thread pool to avoid blocking the event loop; + # Decoding is mainly done in the C libraries (libjpeg/libpng/libwebp), which releases the GIL, + # and multiple threads can achieve true parallelism. + loop = asyncio.get_running_loop() + # 1) Verify original input bytes first. + src_w, src_h = await loop.run_in_executor(_IMAGE_VERIFY_POOL, _verify_image_bytes, img_data) + # 2) Resize (or no-op) after verification. 
+ img_data, resized_w, resized_h = await loop.run_in_executor( + _IMAGE_VERIFY_POOL, + _resize_image_bytes_if_needed, + img_data, + src_w, + src_h, + max_image_pixels, + ) + self.image_w, self.image_h = resized_w, resized_h + + if (resized_w, resized_h) != (src_w, src_h): + logger.warning( + f"image pixels {src_w * src_h} exceed max_image_pixels={max_image_pixels}," + f" resized to {self.image_w}x{self.image_h}" + ) self._preload_data = img_data return + except ClientDisconnected as e: + # Preserve client-disconnect signal so the API layer can return 499 + # without the noisy 'Failed to read image' error logs. + raise e except Exception as e: raise ValueError(f"Failed to read image type={self._type}, data[:100]={self._data[:100]}: {e}!") @@ -211,3 +250,67 @@ def to_origin_dict(self): ret["images"] = [i.to_origin_dict() for i in self.images] ret["audios"] = [a.to_origin_dict() for a in self.audios] return ret + + +_IMAGE_VERIFY_POOL = ThreadPoolExecutor( + max_workers=int(os.getenv("LIGHTLLM_IMAGE_VERIFY_WORKERS", 4)), + thread_name_prefix="img-verify", +) + + +def _verify_image_bytes(img_data: bytes) -> Tuple[int, int]: + """ + Verify image bytes in a thread pool to find truncated/corrupted images. + image.verify() only does header-level verification and cannot find truncated images; + image.load() reads the entire pixel data and truncated images will raise OSError. + """ + # Disable PIL's truncated image loading tolerance to make truncated images raise OSError in load() + # so that the frontend can intercept it and avoid crashing in the subsequent encode/preprocess stage. + ImageFile.LOAD_TRUNCATED_IMAGES = False + + with Image.open(BytesIO(img_data)) as image: + w, h = image.size + image.load() + return w, h + + +def _resize_image_bytes_if_needed( + img_data: bytes, src_w: int, src_h: int, max_image_pixels: int +) -> Tuple[bytes, int, int]: + """ + Resize image bytes to satisfy max pixel constraint and return resized bytes with size. 
+ """ + new_w, new_h = _resize_image_dimensions_if_needed(src_w, src_h, max_image_pixels) + if (new_w, new_h) == (src_w, src_h): + return img_data, src_w, src_h + + with Image.open(BytesIO(img_data)) as image: + resampling = Image.Resampling.LANCZOS if hasattr(Image, "Resampling") else Image.LANCZOS + resized_image = image.resize((new_w, new_h), resampling).convert("RGB") + + buffer = BytesIO() + resized_image.save(buffer, format="JPEG", quality=96, optimize=True) + return buffer.getvalue(), new_w, new_h + + +def _resize_image_dimensions_if_needed(src_w: int, src_h: int, max_image_pixels: int) -> Tuple[int, int]: + """ + Compute resized (w, h) under a max pixel budget while preserving aspect ratio. + """ + old_pixels = src_w * src_h + if old_pixels <= max_image_pixels: + return src_w, src_h + + scale = (max_image_pixels / old_pixels) ** 0.5 + new_w = max(1, int(src_w * scale)) + new_h = max(1, int(src_h * scale)) + + # Avoid overflow from integer rounding. + while new_w * new_h > max_image_pixels: + if new_w >= new_h: + new_w = max(1, new_w - 1) + else: + new_h = max(1, new_h - 1) + + assert new_w > 0 and new_h > 0, "resized image dimensions must be positive" + return new_w, new_h diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 5002e4f1cb..1d68f81a9e 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -2,7 +2,7 @@ import time import copy from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Union, Set +from typing import Dict, List, Optional from lightllm.server.req_id_generator import convert_sub_id_to_group_id from fastapi import WebSocket @@ -14,24 +14,14 @@ class NodeRole(enum.Enum): P = "prefill" D = "decode" - - NP = "nixl_prefill" - ND = "nixl_decode" - NORMAL = "normal" PD_MASTER = "pd_master" def is_D(self): - return self == NodeRole.D or self == NodeRole.ND + return self == NodeRole.D def is_P(self): - return self == NodeRole.P or self == NodeRole.NP 
- - def is_NP(self): - return self == NodeRole.NP - - def is_ND(self): - return self == NodeRole.ND + return self == NodeRole.P def is_normal(self): return self == NodeRole.NORMAL @@ -42,16 +32,13 @@ def is_P_or_NORMAL(self): def is_P_or_D(self): return self.is_P() or self.is_D() - def is_NP_or_ND(self): - return self == NodeRole.NP or self == NodeRole.ND - class ObjType(enum.Enum): ABORT = 1 REQ = 2 TOKEN_PACKS = 3 - NIXL_UPLOAD_NP_PROMPT_IDS = 4 # nixl p 节点上报生成的 prompt ids 信息。 - NIXL_REQ_DECODE_NODE_INFO = 5 # nixl pd master 节点下发给 nixl p 节点的对应请求对应的 d 节点的信息。 + PD_UPLOAD_PREFILL_PROMPT_IDS = 4 # prefill 节点上报生成的 prompt ids 信息。 + PD_REQ_DECODE_NODE_INFO = 5 # pd master 节点下发给 prefill 节点的请求对应的 decode 节点信息。 @dataclass @@ -69,8 +56,8 @@ class PD_Client_Obj: run_status: _PD_Client_RunStatus = field(default_factory=_PD_Client_RunStatus) def __post_init__(self): - if self.mode not in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: - error_info = f"""mode must in ["prefill", "decode", "nixl_prefill", "nixl_decode"], but get {self.mode}""" + if self.mode not in ["prefill", "decode"]: + error_info = f"""mode must in ["prefill", "decode"], but get {self.mode}""" logger.error(error_info) raise ValueError(error_info) return @@ -88,120 +75,14 @@ def to_log_str(self): return f"PD_MASTER host_ip_port: {self.host_ip_port} node_id: {self.node_id}" -@dataclass -class UpKVStatus: - group_request_id: int - # The identifier of the pd_master node handling the request. 
- pd_master_node_id: int - # decode node dp_index to handle this request - dp_index: int - - def __post_init__(self): - if not isinstance(self.group_request_id, int): - error_info = "group_request_id only can be int" - logger.error(error_info) - raise ValueError(error_info) - - if not isinstance(self.pd_master_node_id, int): - error_info = "pd_master_node_id only can be int" - logger.error(error_info) - raise ValueError(error_info) - return - - -@dataclass -class DecodeNodeInfo: - node_id: int - ip: str - rpyc_port: str - max_new_tokens: int +####### 下边是 pd kv 传输使用的对象 ######## @dataclass -class PDTransJoinInfo: - decode_id: int - decode_device_id: int - prefill_id: int - prefill_device_id: int - pd_prefill_nccl_ip: str - pd_prefill_nccl_port: int - # 用于标识一次唯一的连接,prefill_id 和 decode_id 相同时,可能因为网络原因重连,为了更好的区分 - # 一次连接,使用一个 uuid 为其标识 - connect_id: str - - -@dataclass -class PDTransLeaveInfo: - decode_id: int - prefill_id: int - # 用于标识一次唯一的连接,prefill_id 和 decode_id 相同时,可能因为网络原因重连,为了更好的区分 - # 一次连接,使用一个 uuid 为其标识 - connect_id: str - - -@dataclass -class KVMoveTask: +class PDUpKVStatus: group_request_id: int - input_tokens: List[int] # 代表输入的token_id 序列 - prefill_token_indexes: List[int] # 在prefill节点上 mem manager kv buffer中的token index - # 在decode节点上 mem manager kv buffer中的token index, 其代表的是真实占用的额外token,并不与prefill_token_indexes 一样长 - decode_token_indexes: List[int] - move_kv_len: int # 因为 prompt cache 的原因,当prefill节点和decode节点沟通后,传输的kv的数量可能少于 prefill_value 的长度 - prefill_node_id: int - decode_node: DecodeNodeInfo - # 保存prefill 和 decode 节点对应处理的dp_index, 如果是普通tp模式,这个值一定是0, - # 如果是deepseekv2的tp dp 混合模式, 才有真正的意义。 - prefill_dp_index: int - decode_dp_index: int pd_master_node_id: int - mark_start_time: float = None - # 标记任务使用某个连接id进行传输 - connect_id: str = None - - def __post_init__(self): - if len(self.input_tokens) <= 0: - error_info = "key must len >= 1" - logger.error(error_info) - raise ValueError(error_info) - - def to_prefill_log_info(self): - v_len = None if 
self.prefill_token_indexes is None else len(self.prefill_token_indexes) - d_i = self.prefill_dp_index - id = self.group_request_id - log = f"id: {id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len} dp_index:{d_i}" - return log + f" connect_id: {self.connect_id}" - - def to_decode_log_info(self): - v_len = None if self.decode_token_indexes is None else len(self.decode_token_indexes) - d_i = self.decode_dp_index - id = self.group_request_id - log = f"id: {id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len} dp_index:{d_i}" - return log + f" connect_id: {self.connect_id}" - - def id(self): - return self.group_request_id - - def get_cost_time(self): - if self.mark_start_time is not None: - return time.time() - self.mark_start_time - else: - return 100000000000 - - -@dataclass -class KVMoveTaskGroup: - tasks: List[KVMoveTask] - connect_id: str - - -####### 下边是 NIXL模式下使用的特定对象 ######## - - -@dataclass -class NixlUpKVStatus: - group_request_id: int - pd_master_node_id: int - nixl_params: bytes # nixl 建立连接所使用的元数据对象 + pd_kv_trans_params: bytes # pd kv 传输建立连接所使用的元数据对象 def __post_init__(self): @@ -219,11 +100,14 @@ def __post_init__(self): def __str__(self): req_id = self.group_request_id pd_m_id = self.pd_master_node_id - return f"group_request_id: {req_id} pd_master_node_id: {pd_m_id} nixl_params_len: {len(self.nixl_params)}" + return ( + f"group_request_id: {req_id} pd_master_node_id: {pd_m_id} " + f"pd_kv_trans_params_len: {len(self.pd_kv_trans_params)}" + ) @dataclass -class NIXLDecodeNodeInfo: +class PDDecodeNodeInfo: decode_node_id: int pd_master_node_id: int @@ -237,7 +121,7 @@ class NIXLDecodeNodeInfo: @dataclass -class NixlAgentMetadata: +class PDAgentMetadata: agent_name: str agent_metadata: bytes num_pages: int @@ -246,7 +130,7 @@ class NixlAgentMetadata: @dataclass -class NIXLChunckedTransTask: +class PDChunckedTransTask: request_id: int start_kv_index: int end_kv_index: int @@ -273,9 +157,11 @@ class 
NIXLChunckedTransTask: first_gen_token_id: Optional[int] first_gen_token_logprob: Optional[float] + write_stage: Optional[str] = None + # transfer params - nixl_src_page_index: Optional[int] = None - nixl_dst_page_index: Optional[int] = None + src_page_index: Optional[int] = None + dst_page_index: Optional[int] = None # xfer_handle xfer_handle: Optional[int] = None @@ -284,13 +170,23 @@ class NIXLChunckedTransTask: start_trans_time: float = None # 用于标记传输开始的时间。同时标记是否正在传输中 error_info: Optional[str] = None + transfer_time_out_secs: int = 66 + page_kind: str = "kv" + # Only valid for the local task owner; remote notify copies may carry the sender-local req_idx. + req_idx: Optional[int] = None def __post_init__(self): if self.start_kv_index < 0 or self.end_kv_index < self.start_kv_index: error_info = "start_kv_index must >=0 and end_kv_index > start_kv_index" logger.error(error_info) raise ValueError(error_info) - assert len(self.mem_indexes) == (self.end_kv_index - self.start_kv_index) + if self.page_kind == "kv": + assert len(self.mem_indexes) == (self.end_kv_index - self.start_kv_index) + elif self.page_kind == "linear_att_state": + assert self.start_kv_index == self.end_kv_index + assert len(self.mem_indexes) == 0 + else: + raise ValueError(f"unknown PD trans page kind {self.page_kind}") self.create_time = time.time() return @@ -300,7 +196,7 @@ def time_out(self) -> bool: return True return False else: - if time.time() - self.start_trans_time > self.time_out_secs + 88: + if time.time() - self.start_trans_time > self.transfer_time_out_secs: return True else: return False @@ -312,10 +208,10 @@ def transfer_time(self): return time.time() - self.start_trans_time def get_key(self) -> str: - return f"{self.request_id}_{self.start_kv_index}_{self.end_kv_index}" + return f"{self.request_id}_{self.page_kind}_{self.start_kv_index}_{self.end_kv_index}" def to_str(self): - obj: NIXLChunckedTransTask = copy.copy(self) + obj: PDChunckedTransTask = copy.copy(self) obj.mem_indexes 
= None if obj.decode_agent_metadata is not None: obj.decode_agent_metadata = b"xxx" @@ -328,10 +224,15 @@ def to_str(self): return obj.__str__() def transfer_kv_num(self): + if self.page_kind != "kv": + return 0 return self.end_kv_index - self.start_kv_index - def createRetObj(self) -> "NIXLChunckedTransTaskRet": - ret = NIXLChunckedTransTaskRet( + def need_transfer_page(self): + return self.page_kind != "kv" or self.transfer_kv_num() != 0 + + def createRetObj(self) -> "PDChunckedTransTaskRet": + ret = PDChunckedTransTaskRet( request_id=self.request_id, start_kv_index=self.start_kv_index, end_kv_index=self.end_kv_index, @@ -342,16 +243,16 @@ def createRetObj(self) -> "NIXLChunckedTransTaskRet": ) return ret - def create_prefill_agent_obj(self) -> NixlAgentMetadata: - return NixlAgentMetadata( + def create_prefill_agent_obj(self) -> PDAgentMetadata: + return PDAgentMetadata( agent_name=self.prefill_agent_name, agent_metadata=self.prefill_agent_metadata, num_pages=self.prefill_num_pages, page_reg_desc=self.prefill_page_reg_desc, ) - def create_decode_agent_obj(self) -> NixlAgentMetadata: - return NixlAgentMetadata( + def create_decode_agent_obj(self) -> PDAgentMetadata: + return PDAgentMetadata( agent_name=self.decode_agent_name, agent_metadata=self.decode_agent_metadata, num_pages=self.decode_num_pages, @@ -360,7 +261,7 @@ def create_decode_agent_obj(self) -> NixlAgentMetadata: @dataclass -class NIXLChunckedTransTaskRet: +class PDChunckedTransTaskRet: request_id: int start_kv_index: int end_kv_index: int @@ -374,11 +275,11 @@ def get_key(self) -> str: @dataclass -class NIXLChunckedTransTaskGroup: - task_list: List[NIXLChunckedTransTask] = field(default_factory=list) +class PDChunckedTransTaskGroup: + task_list: List[PDChunckedTransTask] = field(default_factory=list) @dataclass -class NIXLAbortReq: +class PDAbortReq: request_id: int device_id: int diff --git a/lightllm/server/reasoning_parser.py b/lightllm/server/reasoning_parser.py index 21e1fc3e4f..8a8d07355b 
100644 --- a/lightllm/server/reasoning_parser.py +++ b/lightllm/server/reasoning_parser.py @@ -622,20 +622,38 @@ def detect_and_parse(self, text: str) -> StreamingParseResult: return StreamingParseResult(normal_text=normal_text, reasoning_text=reasoning_text) + def flush(self) -> StreamingParseResult: + """ + Flush any remaining buffered content when generation ends prematurely + (e.g., max_completion_tokens reached before is seen). + Returns buffered content as reasoning_text (if still in reasoning block) + or normal_text (if in normal content block). + """ + if not self._buffer: + return StreamingParseResult() + remaining = self._buffer + self._buffer = "" + if self._in_reasoning: + return StreamingParseResult(reasoning_text=remaining) + else: + return StreamingParseResult(normal_text=remaining) + def parse_streaming_increment(self, new_text: str) -> StreamingParseResult: """ Streaming incremental parsing for reasoning content. Handles partial reasoning tags and content. - If stream_reasoning is False: - Accumulates reasoning content until the end tag is found - If stream_reasoning is True: - Streams reasoning content as it arrives + Reasoning tokens are always streamed immediately as they arrive, + regardless of stream_reasoning setting (aligns with vLLM behavior). + The only exception is when the buffer holds a partial tag prefix + (e.g. ""), in which case we + keep buffering until the tag is confirmed or refuted. """ self._buffer += new_text current_text = self._buffer # If the current text is a prefix of the think token, keep buffering + # until we can confirm or refute the tag. 
if any( token.startswith(current_text) and token != current_text for token in [self.think_start_token, self.think_end_token] @@ -660,14 +678,11 @@ def parse_streaming_increment(self, new_text: str) -> StreamingParseResult: return StreamingParseResult(normal_text=normal_text, reasoning_text=reasoning_text.rstrip()) - # Continue with reasoning content + # Always stream reasoning content immediately. + # stream_reasoning flag is ignored for streaming responses. if self._in_reasoning: - if self.stream_reasoning: - # Stream the content immediately - self._buffer = "" - return StreamingParseResult(reasoning_text=current_text) - else: - return StreamingParseResult() + self._buffer = "" + return StreamingParseResult(reasoning_text=current_text) # If we're not in a reasoning block return as normal text if not self._in_reasoning: @@ -847,6 +862,33 @@ def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False) ) +class Gemma4Detector(BaseReasoningFormatDetector): + """ + Detector for Google Gemma-4 thinking models. + + Format: ``<|channel>thought\\n...reasoning...\\nanswer``. + Role label ``thought\\n`` is baked into the start token (cf. + GptOssDetector) so the base class strips it for free. + + Note: ``<|channel>`` and ```` are special tokens (ids 100/101). + The API layer forces ``skip_special_tokens=False`` when this parser is + active so the delimiters survive decoding (see ``api_openai.py``). + """ + + THINK_START_TOKEN = "<|channel>thought\n" + THINK_END_TOKEN = "" + + def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False): + # force_reasoning ignored: Gemma-4's template never starts generation + # inside an open channel (ReasoningParser pins it to False too). 
+ super().__init__( + self.THINK_START_TOKEN, + self.THINK_END_TOKEN, + force_reasoning=False, + stream_reasoning=stream_reasoning, + ) + + class ReasoningParser: """ Parser that handles both streaming and non-streaming scenarios for extracting @@ -872,6 +914,7 @@ class ReasoningParser: "step3": DeepSeekR1Detector, "nano_v3": NanoV3Detector, "interns1": Qwen3Detector, + "gemma4": Gemma4Detector, } def __init__( @@ -887,9 +930,12 @@ def __init__( if not detector_class: raise ValueError(f"Unsupported model type: {model_type}") - # Special cases where we override force_reasoning - if model_type.lower() in {"qwen3-thinking", "gpt-oss", "minimax"}: - force_reasoning = True + elif model_type.lower() == "gemma4": + # Gemma-4's chat template never positions generation inside an open + # channel — see Gemma4Detector docstring. Pin to False so a + # request_enable_reasoning=True from the caller can't accidentally + # mark the parser as already inside reasoning. + force_reasoning = False # Only pass force_reasoning if explicitly set, let detectors use their defaults kwargs = {"stream_reasoning": stream_reasoning} @@ -907,3 +953,8 @@ def parse_stream_chunk(self, chunk_text: str) -> Tuple[Optional[str], Optional[s """Streaming call: incremental parsing""" ret = self.detector.parse_streaming_increment(chunk_text) return ret.reasoning_text, ret.normal_text + + def flush(self) -> Tuple[Optional[str], Optional[str]]: + """Flush remaining buffered content when generation ends prematurely.""" + ret = self.detector.flush() + return ret.reasoning_text, ret.normal_text diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py deleted file mode 100644 index 44bb269ed8..0000000000 --- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py +++ /dev/null @@ -1,143 +0,0 @@ -from typing import Set, Protocol, List, Optional, Tuple - -import torch -from sortedcontainers import SortedSet - -from 
lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode -from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class HybridRadixCache(RadixCache): - def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager): - super().__init__(unique_name, total_token_num, rank_in_node, kv_cache_mem_manager) - assert hasattr(kv_cache_mem_manager, "mamba_cache_mem_manager") - self.buffer_mem_manager: MambaCacheManager = kv_cache_mem_manager.mamba_cache_mem_manager - self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: (x.buffer_time,)) - - def match_prefix(self, key, update_refs=False): - assert len(key) != 0 - ans_value_list = [] - tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) - evict_token_list = [] - kv_len = tree_node.node_prefix_total_len - while tree_node != self.root_node and tree_node.buffer_idx is None: - if tree_node.is_leaf(): - self.evict_tree_set.discard(tree_node) - - # Only update ref_counter when update_refs is True to maintain consistency - # with _match_prefix_helper which only increments ref_counter when update_refs=True - if update_refs: - if tree_node.ref_counter == 1: - self.refed_tokens_num.arr[0] -= len(tree_node.token_mem_index_value) - tree_node.ref_counter -= 1 - kv_len -= len(ans_value_list.pop()) - if tree_node.is_leaf(): - self.evict_tree_set.add(tree_node) - tree_node = tree_node.parent - - if len(evict_token_list) > 0: - evict_token_value = torch.concat(evict_token_list) - self.mem_manager.free(evict_token_value) - - if tree_node == self.root_node: - return None, kv_len, None - - update_node = tree_node - while update_node != self.root_node: - if update_node.buffer_idx is not None: - self.evict_buffer_set.discard(update_node) - update_node.update_buffer_time() - self.evict_buffer_set.add(update_node) - update_node = 
update_node.parent - - value = torch.concat(ans_value_list) - return tree_node, kv_len, value - - def add_buffer_idx_to_node(self, node: TreeNode, buffer_idx: int): - """Set buffer_idx for a node and add it to evict_buffer_set.""" - self.evict_buffer_set.discard(node) - if node.is_leaf(): - self.evict_tree_set.discard(node) - if node.buffer_idx is not None: - self.buffer_mem_manager.free([node.buffer_idx]) - node.buffer_idx = buffer_idx - node.update_buffer_time() - self.evict_buffer_set.add(node) - if node.is_leaf(): - self.evict_tree_set.add(node) - return - - def free_radix_cache_to_get_enough_buffer(self, need_buffer_num): - if need_buffer_num > self.buffer_mem_manager.can_use_mem_size: - need_evict_buffer_num = need_buffer_num - self.buffer_mem_manager.can_use_mem_size - release_buffers = [] - - def release_buffer(buffer_idx): - release_buffers.append(buffer_idx) - return - - self._evict_buffer(need_evict_buffer_num, release_buffer) - if len(release_buffers) > 0: - self.buffer_mem_manager.free(release_buffers) - return - - def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback): - while need_evict_buffer_num > 0: - node = self.evict_buffer_set.pop(0) - assert node.buffer_idx is not None - evict_buffer_callback(node.buffer_idx) - node.buffer_idx = None - need_evict_buffer_num -= 1 - return - - def free_radix_cache_to_get_enough_token(self, need_token_num): - assert self.mem_manager is not None - if need_token_num > self.mem_manager.can_use_mem_size: - need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size - release_mems = [] - - def release_mem(mem_index): - release_mems.append(mem_index) - return - - release_buffers = [] - - def release_buffer(buffer_idx): - release_buffers.append(buffer_idx) - return - - self.evict(need_evict_token_num, release_buffer, release_mem) - mem_index = torch.concat(release_mems) - self.mem_manager.free(mem_index) - if len(release_buffers) > 0: - self.buffer_mem_manager.free(release_buffers) - return 
- - def evict(self, need_remove_tokens, evict_buffer_callback, evict_callback): - if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: - assert False, f"""can not free tree tokens {need_remove_tokens}, - tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, - refed_tokens_num {self.refed_tokens_num.arr[0]}""" - num_evicted = 0 - while num_evicted < need_remove_tokens: - node: TreeNode = self.evict_tree_set.pop(0) - assert ( - node.ref_counter == 0 and len(node.children) == 0 and node != self.root_node - ), f"error evict tree node state: {node.ref_counter}, {len(node.children)}" - num_evicted += len(node.token_mem_index_value) - evict_callback(node.token_mem_index_value) - if node.buffer_idx is not None: - self.evict_buffer_set.discard(node) - evict_buffer_callback(node.buffer_idx) - node.buffer_idx = None - # update total token num - self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) - parent_node: TreeNode = node.parent - parent_node.remove_child(node) - if parent_node.is_leaf(): - self.evict_tree_set.add(parent_node) - - return diff --git a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py new file mode 100644 index 0000000000..bf07e121e6 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py @@ -0,0 +1,636 @@ +import torch +import numpy as np +from typing import Tuple, Dict, Set, List, Optional +from sortedcontainers import SortedSet, SortedDict +from lightllm.common.linear_att_cache_manager import LinearAttCacheManager +from .shared_arr import SharedArray +from .radix_cache import time_gen + + +class LinearAttPagedTreeNode: + def __init__(self, hash_page_size: int, big_page_num: int): + self.hash_page_size = hash_page_size + self.big_page_num = big_page_num + + # children are keyed by the last ``block_hash`` of each child + self.children: Dict[int, "LinearAttPagedTreeNode"] = {} + self.parent: 
"LinearAttPagedTreeNode" = None + + # Hash of the last page in this node (None for the empty root). + self.page_num = None # 页面数量,只能是 1 或者 big_page_num + self.page_hash: Optional[int] = None + + # token-level data for this node; length == num_pages * hash_page_size + self.token_id_key: torch.Tensor = None + self.token_mem_index_value: torch.Tensor = None + + self.ref_counter = 0 + self.time_id = time_gen.generate_time_id() + + self.node_value_len = 0 + self.node_prefix_total_len = 0 + + # Kept for parity with ``TreeNode`` (used by hybrid attention models). + self.small_page_buffer_idx = None # 这个是对应小页的buffer_id, 可能存在可能不存在 + self.big_page_buffer_idx = None # 当初始化后,如果该页面是大页,则该buffer_id 必然存在不是None + + def is_big_page_node(self): + assert self.node_prefix_total_len % self.hash_page_size == 0 + return self.node_prefix_total_len % (self.hash_page_size * self.big_page_num) == 0 + + def get_compare_key(self): + assert len(self.children) == 0 + if self.is_big_page_node(): + keya = 1 + else: + if self.small_page_buffer_idx is None: + keya = 0 + else: + keya = 1 + # 对于叶节点,非大页节点,如果不存在buffer_idx 的时候,说明无法被复用了,所以应该提前被回收掉,放在evict_tree_set的前面。 + return (0 if self.ref_counter == 0 else 1, keya, self.time_id) + + def get_compare_key_for_buffer_idx(self): + assert self.is_big_page_node() is False + # 对于有 buffer_id 的节点的回收处理比较器 + assert self.small_page_buffer_idx is not None + return (self.time_id,) + + def add_and_return_new_child( + self, + token_id_key: torch.Tensor, + token_mem_index_value: torch.Tensor, + block_hash: int, + small_page_buffer_idx: Optional[int], + ) -> "LinearAttPagedTreeNode": + assert len(token_id_key) == self.hash_page_size == len(token_mem_index_value) + child = LinearAttPagedTreeNode(hash_page_size=self.hash_page_size, big_page_num=self.big_page_num) + child.page_hash = block_hash + child.small_page_buffer_idx = small_page_buffer_idx + child.token_id_key = token_id_key + child.token_mem_index_value = token_mem_index_value + child.page_num = 1 + assert 
child.page_hash not in self.children, "duplicate last block hash in children" + self.children[child.page_hash] = child + child.parent = self + + new_len = len(child.token_mem_index_value) + child.node_value_len = new_len + child.node_prefix_total_len = child.parent.node_prefix_total_len + new_len + return child + + def add_and_return_new_big_page_child( + self, token_id_key: torch.Tensor, token_mem_index_value: torch.Tensor, block_hash: int, big_page_buffer_idx: int + ) -> "LinearAttPagedTreeNode": + assert len(token_id_key) == self.hash_page_size * self.big_page_num == len(token_mem_index_value) + child = LinearAttPagedTreeNode(hash_page_size=self.hash_page_size, big_page_num=self.big_page_num) + child.page_hash = block_hash + child.token_id_key = token_id_key + child.token_mem_index_value = token_mem_index_value + child.page_num = self.big_page_num + child.big_page_buffer_idx = big_page_buffer_idx + assert child.big_page_buffer_idx is not None + assert child.page_hash not in self.children, "duplicate last block hash in children" + child.parent = self + self.children[child.page_hash] = child + new_len = len(child.token_mem_index_value) + child.node_value_len = new_len + child.node_prefix_total_len = child.parent.node_prefix_total_len + new_len + return child + + def remove_child(self, child_node: "LinearAttPagedTreeNode"): + del self.children[child_node.page_hash] + child_node.parent = None + + def update_time(self): + self.time_id = time_gen.generate_time_id() + + def is_leaf(self): + return len(self.children) == 0 + + +class LinearAttPagedRadixCache: + def __init__( + self, + unique_name: str, + total_token_num: int, + rank_in_node: int, + hash_page_size: int, + big_page_num: int, + kv_cache_mem_manager=None, + linear_att_small_page_buffers=None, + ): + from lightllm.common.kv_cache_mem_manager import MemoryManager + + assert hash_page_size >= 1, "hash_page_size must be >= 1" + assert big_page_num >= 1, "big_page_num must be >= 1" + + self.hash_page_size = 
hash_page_size + self.big_page_num = big_page_num + self.big_page_tokens = hash_page_size * big_page_num + self.total_token_num = total_token_num + + self.mem_manager: MemoryManager = kv_cache_mem_manager + + self.linear_att_big_page_buffers: LinearAttCacheManager = self.mem_manager.linear_att_big_page_buffers + self._key_dtype = torch.int64 + self._value_dtype = torch.int64 + + self.root_node = LinearAttPagedTreeNode(hash_page_size=hash_page_size, big_page_num=big_page_num) + self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) + self.root_node.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + self.root_node.ref_counter = 1 + self.root_node.page_num = self.big_page_num + + self._evict_tree_set: Set[LinearAttPagedTreeNode] = SortedSet(key=lambda x: x.get_compare_key()) + self._evict_tree_set_for_linear_att: Set[LinearAttPagedTreeNode] = SortedSet( + key=lambda x: x.get_compare_key_for_buffer_idx() + ) + + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.refed_tokens_num.arr[0] = 0 + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + self.tree_total_tokens_num.arr[0] = 0 + self.linear_att_small_page_buffers: LinearAttCacheManager = linear_att_small_page_buffers + + def _discard_node(self, node: LinearAttPagedTreeNode): + if node.is_leaf(): + self._evict_tree_set.discard(node) + if node.small_page_buffer_idx is not None: + self._evict_tree_set_for_linear_att.discard(node) + return + + def _add_node(self, node: LinearAttPagedTreeNode): + # root 永远不参与回收:当树为空时 root 自身也满足 is_leaf(),若加入 _evict_tree_set, + # 会与 _evict 中 "node is not self.root_node" 的断言相矛盾(当前仅靠 root 的 ref_counter>=1 + # 和回收水位 guard 掩盖)。这里显式排除,使数据结构与回收逻辑的意图一致。 + if node.is_leaf() and node is not self.root_node: + self._evict_tree_set.add(node) + if node.small_page_buffer_idx is not None: + 
self._evict_tree_set_for_linear_att.add(node) + return + + def insert( + self, + key: torch.Tensor, + value: Optional[torch.Tensor] = None, + block_hashs: Optional[List[int]] = None, + block_linear_idxs: Optional[List[int]] = None, + len_to_big_page_id: Optional[SortedDict] = None, + ) -> Tuple[int, Optional[LinearAttPagedTreeNode]]: + assert key is not None + if value is None: + value = key + assert len(key) == len(value) + if block_hashs is None: + block_hashs = [] + if block_linear_idxs is None: + block_linear_idxs = [] + if len_to_big_page_id is None: + len_to_big_page_id = SortedDict() + + assert (len(block_hashs) // self.big_page_num) >= len(len_to_big_page_id) + + assert ( + len(key) == len(block_hashs) * self.hash_page_size + ), f"key length {len(key)} does not match block_hashs length {len(block_hashs)} * {self.hash_page_size}" + assert len(block_hashs) == len( + block_linear_idxs + ), f"block_hashs length {len(block_hashs)} does not match block_linear_idxs length {len(block_linear_idxs)}" + + if len(block_hashs) == 0: + return 0, None + + if len(block_hashs) % self.big_page_num == 0: + assert all( + e is None for e in block_linear_idxs + ), "all block_linear_idxs must be None when block_hashs length is a multiple of big_page_num" + else: + # TODO, test stable then to delete this assertion + assert all( + e is None for e in block_linear_idxs[:-1] + ), "only the last block_linear_idx can be non-None, for compatibility with non-paged radix cache" + assert ( + block_linear_idxs[-1] is not None + ), "the last block_linear_idx must not be None, for compatibility with non-paged radix cache" + + ans = self._insert_helper(self.root_node, key, value, block_hashs, block_linear_idxs, len_to_big_page_id) + assert len(len_to_big_page_id) == 0 + return ans + + def _insert_helper( + self, + node: LinearAttPagedTreeNode, + key: torch.Tensor, + value: torch.Tensor, + block_hashs: List[int], + block_linear_idxs: List[int], + len_to_big_page_id: SortedDict, + ) -> Tuple[int, 
Optional[LinearAttPagedTreeNode]]: + self._discard_node(node) + node.update_time() + + try: + if len(block_hashs) == 0: + return 0, node + # 先看是不是能插入一个大页节点 + if len(block_hashs) >= self.big_page_num: + # 插入大叶节点 + big_page_block_hash = block_hashs[self.big_page_num - 1] + big_page_token_id_key = key[: self.big_page_tokens] + big_page_token_mem_index_value = value[: self.big_page_tokens] + if big_page_block_hash in node.children: + assert node.is_big_page_node() + child = node.children[big_page_block_hash] + assert child.is_big_page_node() + + # 提前释放 len_to_big_page_id 对应的buffer资源 + new_big_page_buffer_id = len_to_big_page_id.pop(child.node_prefix_total_len, None) + if new_big_page_buffer_id is not None: + # 因为节点已经存在,所以无法插入,但是要释放对应的buffer_id 节点 + self.linear_att_big_page_buffers.free_state_cache([new_big_page_buffer_id]) + + # 已经存在了 + sub_prefix_len, ans_node = self._insert_helper( + child, + key[self.big_page_tokens :], + value[self.big_page_tokens :], + block_hashs[self.big_page_num :], + block_linear_idxs[self.big_page_num :], + len_to_big_page_id, + ) + return self.big_page_tokens + sub_prefix_len, ans_node + else: + # 不存在,则新建一个大页节点 + assert node.is_big_page_node() + new_big_page_buffer_id = len_to_big_page_id.pop( + node.node_prefix_total_len + self.big_page_tokens, None + ) + assert new_big_page_buffer_id is not None + + new_child = node.add_and_return_new_big_page_child( + big_page_token_id_key, + big_page_token_mem_index_value, + big_page_block_hash, + new_big_page_buffer_id, + ) + self.tree_total_tokens_num.arr[0] += self.big_page_tokens + assert new_child.is_big_page_node() + assert new_child.page_num == self.big_page_num + _, ans_node = self._insert_helper( + new_child, + key[self.big_page_tokens :], + value[self.big_page_tokens :], + block_hashs[self.big_page_num :], + block_linear_idxs[self.big_page_num :], + len_to_big_page_id, + ) + return 0, ans_node + else: + # 插入小页节点的情况 + assert len(block_hashs) < self.big_page_num + + # 是否已经存在了。 + if block_hashs[0] 
in node.children: + child = node.children[block_hashs[0]] + + if block_linear_idxs[0] is not None: + assert len(block_hashs) == 1 == len(block_linear_idxs) + if child.small_page_buffer_idx is None: + # 将这个buffer id 移交给这个存在的节点。 + self._discard_node(child) + child.small_page_buffer_idx = block_linear_idxs[0] + self._add_node(child) + else: + # 说明节点已经存在了,直接提前移除掉这个节点占用的线性缓存,外部不用处理这个细节了 + self.linear_att_small_page_buffers.free_state_cache(free_indexes=[block_linear_idxs[0]]) + + sub_prefix_len, ans_node = self._insert_helper( + child, + key[self.hash_page_size :], + value[self.hash_page_size :], + block_hashs[1:], + block_linear_idxs[1:], + len_to_big_page_id, + ) + return self.hash_page_size + sub_prefix_len, ans_node + else: + new_node = node.add_and_return_new_child( + key[: self.hash_page_size], + value[: self.hash_page_size], + block_hashs[0], + block_linear_idxs[0], + ) + assert not new_node.is_big_page_node() + assert new_node.page_num == 1 + self.tree_total_tokens_num.arr[0] += self.hash_page_size + _, ans_node = self._insert_helper( + new_node, + key[self.hash_page_size :], + value[self.hash_page_size :], + block_hashs[1:], + block_linear_idxs[1:], + len_to_big_page_id, + ) + return 0, ans_node + + finally: + self._add_node(node) + + def match_prefix( + self, + key: torch.Tensor, + block_hashs: Optional[List[int]] = None, + update_refs: bool = False, + ): + assert update_refs is True, "update_refs must be True" + assert key is not None, "key must not be None" + if block_hashs is None: + block_hashs = [] + + assert ( + len(key) == len(block_hashs) * self.hash_page_size + ), f"key length {len(key)} does not match block_hashs length {len(block_hashs)} * {self.hash_page_size}" + + if len(block_hashs) == 0 or len(key) == 0: + return None, 0, None + + ans_node_list: List[LinearAttPagedTreeNode] = [] + self._match_prefix_helper( + self.root_node, + key=key, + block_hashs=block_hashs, + ans_node_list=ans_node_list, + update_refs=update_refs, + ) + if 
len(ans_node_list) == 0: + return None, 0, None + + # 判定真正可以用的匹配节点。 + ans_node_list = self._trim_unusable_match_tail(ans_node_list) + if len(ans_node_list) == 0: + return None, 0, None + + ans_node = ans_node_list[-1] + mem_value = torch.concat([e.token_mem_index_value for e in ans_node_list]) + assert len(mem_value) == ans_node.node_prefix_total_len + + return ans_node, len(mem_value), mem_value + + def _match_prefix_helper( + self, + node: LinearAttPagedTreeNode, + key: torch.Tensor, + block_hashs: Optional[List[int]], + ans_node_list: list, + update_refs: bool = False, + ): + self._discard_node(node) + node.update_time() + + try: + if update_refs: + node.ref_counter += 1 + # from 0 to 1 need update refs token num + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + + if len(block_hashs) == 0: + return + + if len(block_hashs) >= self.big_page_num: + # 大页的匹配 + big_page_block_hash = block_hashs[self.big_page_num - 1] + if big_page_block_hash in node.children: + child = node.children[big_page_block_hash] + assert child.is_big_page_node() + ans_node_list.append(child) + self._match_prefix_helper( + child, + key[self.big_page_tokens :], + block_hashs[self.big_page_num :], + ans_node_list, + update_refs, + ) + return + + # 小页匹配的情况 + if block_hashs[0] in node.children: + child = node.children[block_hashs[0]] + ans_node_list.append(child) + self._match_prefix_helper( + child, + key[(self.hash_page_size) :], + block_hashs[1:], + ans_node_list, + update_refs, + ) + return + else: + return + + finally: + self._add_node(node) + + def _trim_unusable_match_tail(self, nodes: List[LinearAttPagedTreeNode]) -> List[LinearAttPagedTreeNode]: + removed_list = [] + for node in reversed(nodes): + if node.is_big_page_node(): + break + elif node.small_page_buffer_idx is not None: + assert not node.is_big_page_node() + break + else: + removed_list.append(node) + + for node in removed_list: + self._discard_node(node) + # dec ref + node.ref_counter 
-= 1 + if node.ref_counter == 0: + self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value) + + self._add_node(node) + + if len(removed_list) == 0: + return nodes + else: + return nodes[: -len(removed_list)] + + def _try_merge(self, child_node: LinearAttPagedTreeNode) -> Optional[LinearAttPagedTreeNode]: + raise NotImplementedError() + + def merge_unreferenced_nodes(self): + raise NotImplementedError() + + def clear_tree_nodes(self): + """Only used in tests.""" + self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num) + return + + def deref_to_first_big_page_node(self, node: LinearAttPagedTreeNode) -> Optional[LinearAttPagedTreeNode]: + assert not node.is_big_page_node() + iter_node = node + while not iter_node.is_big_page_node(): + self._discard_node(iter_node) + + if iter_node.ref_counter == 1: + self.refed_tokens_num.arr[0] -= len(iter_node.token_mem_index_value) + iter_node.ref_counter -= 1 + + self._add_node(iter_node) + + iter_node = iter_node.parent + + if iter_node is self.root_node: + return None + else: + return iter_node + + def dec_node_ref_counter(self, node: LinearAttPagedTreeNode): + if node is None: + return + old_node = node + self._discard_node(old_node) + + while node is not None: + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value) + node.ref_counter -= 1 + node = node.parent + + self._add_node(old_node) + return + + def add_node_ref_counter(self, node: LinearAttPagedTreeNode): + if node is None: + return + old_node = node + self._discard_node(old_node) + + while node is not None: + if node.ref_counter == 0: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + node.ref_counter += 1 + node = node.parent + + self._add_node(old_node) + return + + def get_mem_index_value_by_node(self, node: LinearAttPagedTreeNode) -> Optional[torch.Tensor]: + if node is None: + return None + + ans_list = [] + while node is not None: + ans_list.append(node.token_mem_index_value) + 
node = node.parent + + ans_list.reverse() + return torch.concat(ans_list, dim=0) + + def get_big_page_ids_by_node(self, node: LinearAttPagedTreeNode) -> List[int]: + if node is None: + return [] + if node is self.root_node: + return [] + + ans_list = [] + while node is not self.root_node: + if node.is_big_page_node(): + ans_list.append(node.big_page_buffer_idx) + node = node.parent + + ans_list.reverse() + return ans_list + + def get_refed_tokens_num(self): + return self.refed_tokens_num.arr[0] + + def get_tree_total_tokens_num(self): + return self.tree_total_tokens_num.arr[0] + + def print_self(self, indent=0): + self._print_helper(self.root_node, indent) + + def _print_helper(self, node: LinearAttPagedTreeNode, indent): + print( + " " * indent, + f"hash_info: {node.page_hash} " + f"k: {node.token_id_key[0:10] if node.token_id_key is not None else None} " + f"v: {node.token_mem_index_value[0:10] if node.token_mem_index_value is not None else None} " + f"refs: {node.ref_counter} time_id: {node.time_id} " + f"prefix_total_len: {node.node_prefix_total_len} " + f"node_value_len: {node.node_value_len} buffer_idx: {node.small_page_buffer_idx}", + ) + for _, child in node.children.items(): + self._print_helper(child, indent=indent + 2) + return + + def free_radix_cache_to_get_enough_token(self, need_token_num): + assert self.mem_manager is not None + if need_token_num > self.mem_manager.allocator.can_use_mem_size: + need_evict_token_num = need_token_num - self.mem_manager.allocator.can_use_mem_size + release_mems = [] + small_page_buffer_ids = [] + + def release_mem(mem_index, linear_att_small_page_id): + release_mems.append(mem_index) + small_page_buffer_ids.append(linear_att_small_page_id) + return + + self._evict(need_evict_token_num, release_mem) + mem_index = torch.concat(release_mems) + self.mem_manager.free(mem_index) + small_page_buffer_ids = [idx for idx in small_page_buffer_ids if idx is not None] + if len(small_page_buffer_ids) > 0: + 
self.linear_att_small_page_buffers.free_state_cache(small_page_buffer_ids) + return + + def free_one_small_page_linear_att_buffer(self): + if self.linear_att_small_page_buffers is None: + return + if self.linear_att_small_page_buffers.get_free_cache_num() > 0: + return + if len(self._evict_tree_set_for_linear_att) == 0: + return + + node: LinearAttPagedTreeNode = self._evict_tree_set_for_linear_att.pop(0) + self._discard_node(node) + + assert node.small_page_buffer_idx is not None + self.linear_att_small_page_buffers.free_state_cache(free_indexes=[node.small_page_buffer_idx]) + node.small_page_buffer_idx = None + + self._add_node(node) + return + + def _evict(self, need_remove_tokens, evict_callback): + if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: + assert False, f"""can not free tree tokens {need_remove_tokens}, + tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, + refed_tokens_num {self.refed_tokens_num.arr[0]}""" + num_evicted = 0 + while num_evicted < need_remove_tokens: + node: LinearAttPagedTreeNode = self._evict_tree_set.pop(0) + self._discard_node(node) + + assert ( + node.ref_counter == 0 and len(node.children) == 0 and node is not self.root_node + ), "error evict tree node state" + num_evicted += len(node.token_mem_index_value) + + if node.is_big_page_node(): + assert node.big_page_buffer_idx is not None + self.linear_att_big_page_buffers.free_state_cache([node.big_page_buffer_idx]) + + evict_callback(node.token_mem_index_value, node.small_page_buffer_idx) + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: LinearAttPagedTreeNode = node.parent + parent_node.remove_child(node) + + self._add_node(parent_node) + + return diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index afb1d8e4b2..21e26c5854 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ 
b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -31,12 +31,6 @@ def __init__(self): self.node_value_len = 0 self.node_prefix_total_len = 0 - # Used by hybrid attention models (e.g., Qwen3Next) to track - # a per-request buffer_idx alongside the token-level KV cache. - # Pure attention models keep buffer_idx as None. - self.buffer_idx = None - self.buffer_time = time_gen.generate_time_id() - def get_compare_key(self): return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) @@ -84,9 +78,6 @@ def remove_child(self, child_node: "TreeNode"): def update_time(self): self.time_id = time_gen.generate_time_id() - def update_buffer_time(self): - self.buffer_time = time_gen.generate_time_id() - def is_leaf(self): return len(self.children) == 0 @@ -112,10 +103,10 @@ class RadixCache: unique_name 主要用于解决单机,多实列部署时的shm冲突 """ - def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager=None): + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): from lightllm.common.kv_cache_mem_manager import MemoryManager - self.mem_manager: MemoryManager = kv_cache_mem_manager + self.mem_manager: MemoryManager = mem_manager self._key_dtype = torch.int64 self._value_dtype = torch.int64 @@ -354,22 +345,20 @@ def evict(self, need_remove_tokens, evict_callback): def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: """ - merge condition: - 1. parent_node is not root node. - 2. parent_node's ref_counter is 0, for hybrid attention models (e.g., Qwen35), - parent_node's buffer_idx is None. - 3. child_node's ref_counter is 0. - 4. parent_node has only one child node (i.e. child_node). + 合并条件: + 1. 父节点不是根节点。 + 2. 父节点的引用计数为 0。 + 3. 子节点的引用计数为 0。 + 4. 
父节点只有一个子节点 (即 child_node)。 """ parent_node = child_node.parent - # condition check + # 条件检查 if ( parent_node is None or parent_node == self.root_node or parent_node.ref_counter != 0 or len(parent_node.children) != 1 or child_node.ref_counter != 0 - or parent_node.buffer_idx is not None ): return None @@ -412,12 +401,6 @@ def merge_unreferenced_nodes(self): if merged_node: worklist.append(merged_node) - def assert_leafs_is_right(self): - for node in self.evict_tree_set: - if node.is_leaf() and node.ref_counter == 0: - a = node.token_mem_index_value.cuda() - assert (self.mem_manager.mem_state[a] == 1).sum().item() == len(a) - def clear_tree_nodes(self): """ 该函数只在测试时调用 @@ -500,7 +483,7 @@ def _print_helper(self, node: TreeNode, indent): " " * indent, f"k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \ time_id: {node.time_id} prefix_total_len: {node.node_prefix_total_len} \ - node_value_len: {node.node_value_len} buffer_idx: {node.buffer_idx}", + node_value_len: {node.node_value_len}", ) for _, child in node.children.items(): self._print_helper(child, indent=indent + 2) @@ -508,8 +491,8 @@ def _print_helper(self, node: TreeNode, indent): def free_radix_cache_to_get_enough_token(self, need_token_num): assert self.mem_manager is not None - if need_token_num > self.mem_manager.can_use_mem_size: - need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size + if need_token_num > self.mem_manager.allocator.can_use_mem_size: + need_evict_token_num = need_token_num - self.mem_manager.allocator.can_use_mem_size release_mems = [] def release_mem(mem_index): diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index f5e0b8df9a..dfb8866601 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -12,7 +12,7 @@ import torch.multiprocessing as mp import torch.distributed as dist import multiprocessing -from typing import Dict, List, Optional, Union +from typing 
import Dict, List, Optional from .batch import Batch, Req from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue @@ -26,15 +26,16 @@ from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.utils.log_utils import init_logger, log_time_ready +from lightllm.utils.profiler import ProfilerCmd from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient -from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.common.kv_cache_mem_manager import ReadOnlyStaticsMemoryManager from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt from .stats import RouterStatics - +from .profiler_service import RouterProfilerCmdQueue, start_router_profiler_server logger = init_logger(__name__) @@ -60,6 +61,8 @@ def __init__(self, args: StartArgs): self.is_safe_schedule = args.router_token_ratio == 0.0 self.load_way = args.load_way self.max_total_token_num = args.max_total_token_num + # 存储在共享内存中的真实token容量数据 + self.shm_max_total_token_num = SharedInt(f"{get_unique_server_name()}_shm_max_total_token_num") self.shm_req_manager = ShmReqManager() # 用共享内存进行共享,router 模块读取进行精确的调度估计 self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager() @@ -70,7 +73,6 @@ def __init__(self, args: StartArgs): self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) for dp_index in range(self.dp_size_in_node): self.shared_token_load.set_estimated_peak_token_count(0, dp_index) - self.shared_token_load.set_frozened_token_count(0, dp_index) self.shared_token_load.set_current_load(0.0, dp_index) 
self.shared_token_load.set_logical_max_load(0.0, dp_index) self.shared_token_load.set_dynamic_max_load(0.0, dp_index) @@ -92,13 +94,8 @@ def __init__(self, args: StartArgs): ) self.metric_client = MetricClient(args.metric_port) - self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"] - self.is_pd_decode_mode = self.args.run_mode in ["decode", "nixl_decode"] - # p d 分离模式下,需要调度锁来同步调度端和推理端的一些数据操作 - # 主要是为了防止调度失误,造成 OOM 等错误 - self.router_lock = mp.Lock() - g_router_lock.obj = self.router_lock - + self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode"] + self.is_pd_decode_mode = self.args.run_mode == "decode" self.shm_reqs_io_buffer = ShmObjsIOBuffer() self.cpu_cache_client = ( @@ -107,6 +104,8 @@ def __init__(self, args: StartArgs): else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False) ) self.router_statics = RouterStatics(self.args) + self.profiler_cmd_queue = RouterProfilerCmdQueue() + return async def wait_to_model_ready(self): @@ -132,7 +131,6 @@ async def wait_to_model_ready(self): rank_in_node=rank_in_node, node_world_size=node_world_size, info_queue=self.info_queue, - router_lock=self.router_lock, ) ) tasks.append(task) @@ -168,6 +166,7 @@ async def wait_to_model_ready(self): "batch_max_tokens": self.args.batch_max_tokens, "quant_type": self.args.quant_type, "quant_cfg": self.args.quant_cfg, + "expert_dtype": self.args.expert_dtype, "pd_rpyc_ports": self.args.pd_node_infer_rpyc_ports, # 非 pd 模式可以不设置 } @@ -185,6 +184,10 @@ async def wait_to_model_ready(self): assert max(_nums) == min(_nums), "all rank must have same token num" self.max_total_token_num = _nums[0] self.args.max_total_token_num = self.max_total_token_num + + self.shm_max_total_token_num.set_value(self.max_total_token_num) + logger.info(f"set shm_max_total_token_num value to {self.shm_max_total_token_num.get_value()}") + if not self.args.disable_dynamic_prompt_cache: self.radix_cache_client = RadixCacheReadOnlyClient( 
get_unique_server_name(), @@ -196,30 +199,14 @@ async def wait_to_model_ready(self): logger.info(f"use req queue {self.req_queue.__class__.__name__}") if self.args.run_mode == "prefill": - # 启动 prefill kv move 管理进程 - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.prefill_node_impl import ( - start_prefill_kv_move_manager_process, - ) - - start_prefill_kv_move_manager_process(self.args, self.info_queue) - - if self.args.run_mode == "nixl_prefill": - from lightllm.server.router.model_infer.mode_backend.pd_nixl.prefill_node_impl import ( + from lightllm.server.router.model_infer.mode_backend.pd.prefill_node_impl import ( start_prefill_kv_move_manager_process, ) start_prefill_kv_move_manager_process(self.args, self.info_queue) if self.args.run_mode == "decode": - # 启动 decode kv move 管理进程 - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import ( - start_decode_kv_move_manager_process, - ) - - start_decode_kv_move_manager_process(self.args, self.info_queue) - - if self.args.run_mode == "nixl_decode": - from lightllm.server.router.model_infer.mode_backend.pd_nixl.decode_node_impl import ( + from lightllm.server.router.model_infer.mode_backend.pd.decode_node_impl import ( start_decode_kv_move_manager_process, ) @@ -247,22 +234,21 @@ async def loop_for_fwd( - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index) ) / self.max_total_token_num d_i = dp_index - frozen_token_num = self.shared_token_load.get_frozened_token_count(d_i) estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i) paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i) logger.debug( f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n" f"dp_i {d_i} paused req num: {paused_req_num} \n" - f"dp_i {d_i} frozen token num: {frozen_token_num} \n" f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n" f"dp_i {d_i} token used ratio: {token_ratio1} 
not contain prompt cache tree unrefed token\n" f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token" ) logger.debug(self.router_statics.log_str()) - self.metric_client.gauge_set("lightllm_batch_pause_size", paused_req_num) + self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num()) # pd decode mode need to update token_load more frequently self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode) self.metric_client.gauge_set("lightllm_batch_current_size", len(self.running_batch.reqs)) + self.metric_client.gauge_set("lightllm_num_running_reqs", len(self.running_batch.reqs)) self.metric_client.gauge_set("lightllm_queue_size", self.req_queue.get_wait_req_num()) self.metric_client.gauge_set( "lightllm_batch_current_max_tokens", @@ -275,15 +261,14 @@ async def loop_for_fwd( self.req_queue.update_token_load(self.running_batch, force_update=True) if counter_count % 300 == 0: self.metric_client.gauge_set("lightllm_batch_current_size", 0.0) + self.metric_client.gauge_set("lightllm_num_running_reqs", 0.0) self.metric_client.gauge_set("lightllm_batch_pause_size", 0.0) self.metric_client.gauge_set("lightllm_queue_size", 0.0) self.metric_client.gauge_set("lightllm_batch_current_max_tokens", 0.0) # 60s print once - if log_time_ready("frozen_info", 60): + if log_time_ready("token_load_info", 60): for dp_i in range(self.dp_size_in_node): - frozen_token_num = self.shared_token_load.get_frozened_token_count(dp_i) estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(dp_i) - logger.debug(f"dp_i {dp_i} frozen token num: {frozen_token_num} \n") logger.debug(f"dp_i {dp_i} estimated_peak_token_count: {estimated_peak_token_count} \n") await asyncio.sleep(self._get_schedule_time_interval()) @@ -294,6 +279,7 @@ async def _step(self): """ # 接受新请求,并尝试调度 await self._recv_new_reqs_and_schedule() + await self._write_profiler_cmds() # 判断是否有新请求加入推理 # 激进调度满足,有新的推理batch就需要进行加入。 # 
或者延迟step的步数满足了当前条件,也需要进行新的推理batch的加入。 @@ -322,6 +308,17 @@ async def _add_batch(self, batch: Batch): logger.debug(f"Prefill Batch: {batch.simple_log()} \n") return + async def _write_profiler_cmds(self): + cmd = self.profiler_cmd_queue.pop() + if cmd is None: + return + + while not self.shm_reqs_io_buffer.is_empty(): + await asyncio.sleep(0.001) + self.shm_reqs_io_buffer.write_obj([ProfilerCmd(cmd)]) + self.shm_reqs_io_buffer.set_ready() + return + async def _aborted_reqs(self, aborted_reqs: List[Req]): cmds = [AbortedReqCmd(req_id=r.request_id) for r in aborted_reqs] while not self.shm_reqs_io_buffer.is_empty(): @@ -429,9 +426,11 @@ def _generate_new_batch(self): new_batch = self.req_queue.generate_new_batch( Batch.merge_two_batch(self.running_batch, self.schedule_new_batch) ) + + if new_batch is not None and len(new_batch.reqs) > 0: + logger.info(f"generate new batch, {new_batch.simple_log()}") + self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch) - if self.schedule_new_batch is not None: - logger.info(f"gen new batch, {self.schedule_new_batch.simple_log()}") return def _multinode_tp_generate_new_batch(self): @@ -554,6 +553,10 @@ def handle_exception(loop, context): ) loop.run_until_complete(router.wait_to_model_ready()) + router.profiler_rpyc_server, router.profiler_rpyc_thread = start_router_profiler_server( + args, + router.profiler_cmd_queue, + ) except: import traceback import sys diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index ceccebd8f4..5c2d0d45fb 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -5,20 +5,23 @@ import collections import pickle +from sortedcontainers import SortedDict from dataclasses import dataclass, field -from typing import List, Dict, Tuple, Optional, Callable, Any +from typing import List, Dict, Tuple, Optional, Callable, Any, Union from 
lightllm.common.req_manager import ReqManager, ReqManagerForMamba from lightllm.utils.infer_utils import mark_start, mark_end from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode -from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache +from lightllm.server.router.dynamic_prompt.linear_att_radix_cache import ( + LinearAttPagedRadixCache, + LinearAttPagedTreeNode, +) from lightllm.utils.log_utils import init_logger from lightllm.server.req_id_generator import convert_sub_id_to_group_id -from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.server.multimodal_params import MultimodalParams from lightllm.utils.custom_kernel_utis import custom_cat from lightllm.utils.envs_utils import get_env_start_args -from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo +from lightllm.server.pd_io_struct import PDDecodeNodeInfo from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient logger = init_logger(__name__) @@ -26,8 +29,8 @@ @dataclass class InferenceContext: - req_manager: ReqManager = None # gpu 请求管理 - radix_cache: RadixCache = None + req_manager: Union[ReqManager, ReqManagerForMamba] = None # gpu 请求管理 + radix_cache: Union[LinearAttPagedRadixCache, RadixCache] = None shm_req_manager: ShmReqManager = None # 共享内存请求对象管理 requests_mapping: Dict[int, "InferReq"] = None infer_req_ids = None @@ -36,13 +39,13 @@ class InferenceContext: overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream - has_recurrent_state: bool = False # for + is_linear_att_mixed_model: bool = False # 标记模型是否是full att 混合 linear att 的混合模型。 def register( self, backend, - req_manager: ReqManager, - radix_cache: RadixCache, + req_manager: Union[ReqManager, ReqManagerForMamba], + radix_cache: Union[LinearAttPagedRadixCache, 
RadixCache], shm_req_manager: ShmReqManager, vocab_size: int, ): @@ -60,7 +63,7 @@ def register( self.vocab_size = vocab_size - self.has_recurrent_state = isinstance(self.req_manager, ReqManagerForMamba) + self.is_linear_att_mixed_model = isinstance(self.req_manager, ReqManagerForMamba) return @@ -78,27 +81,6 @@ def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream: self.cpu_kv_cache_stream = torch.cuda.Stream() return self.cpu_kv_cache_stream - def _alloc_and_copy_req_buffers( - self, req_manager: ReqManagerForMamba, radix_cache: HybridRadixCache, req_objs: List["InferReq"] - ) -> None: - if not req_objs: - return - - if radix_cache is not None: - radix_cache.free_radix_cache_to_get_enough_buffer(len(req_objs)) - - req_idx_gpu = torch.tensor([r.req_idx for r in req_objs], device="cuda", dtype=torch.int64) - req_manager.alloc_buffer_for_req(req_idx_gpu) - - if radix_cache is not None: - fork_req_ids = [r.req_idx for r in req_objs if r.shared_kv_node is not None] - if fork_req_ids: - src_buf_ids = [r.shared_kv_node.buffer_idx for r in req_objs if r.shared_kv_node is not None] - req_tensor = torch.tensor(fork_req_ids, device="cuda", dtype=torch.int32) - src_tensor = torch.tensor(src_buf_ids, device="cuda", dtype=torch.int32) - dst_buffers = req_manager.req_to_buffer_index[req_tensor[:], 0].view(-1, 1) - req_manager.buffer_mem_manager.fork_state_buffers(src_tensor, dst_buffers) - def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]: req_objs = [] request_ids = [] @@ -137,63 +119,135 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: slave_req: InferReq = slave_req slave_req.related_master_req = master_req - if isinstance(self.req_manager, ReqManagerForMamba): - self._alloc_and_copy_req_buffers(self.req_manager, self.radix_cache, req_objs) - return req_objs def free_a_req_mem(self, free_token_index: List, req: "InferReq"): if self.radix_cache is None: 
free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) else: + if not self.is_linear_att_mixed_model: + self._full_att_free_req(free_token_index=free_token_index, req=req) + else: + self._linear_att_free_req(free_token_index=free_token_index, req=req) + assert len(req.linear_att_len_to_big_page_id) == 0 + req.cur_kv_len = 0 + req.shm_req.shm_cur_kv_len = req.cur_kv_len + return + + def _full_att_free_req(self, free_token_index: List, req: "InferReq"): + input_token_ids = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") + # .cpu() 是 流内阻塞操作 + value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() + + prefix_len, _ = self.radix_cache.insert(key, value) + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) + if req.shared_kv_node is not None: + assert req.shared_kv_node.node_prefix_total_len <= prefix_len + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + req.shared_kv_node = None + return + + def _linear_att_free_req(self, free_token_index: List, req: "InferReq"): + assert g_infer_context.is_linear_att_mixed_model is True + args = get_env_start_args() + hash_page_size = args.linear_att_hash_page_size + big_page_num = args.linear_att_page_block_num + shared_kv_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len + tail_big_page_token_num = ( + req.linear_att_cache_len // (hash_page_size * big_page_num) * (hash_page_size * big_page_num) + ) + page_num = req.linear_att_cache_len // hash_page_size + assert req.linear_att_cache_len >= shared_kv_len + if req.tail_linear_att_small_page_buffer_id is not None: + assert req.linear_att_cache_len <= req.cur_kv_len + + if req.cur_kv_len == 0: + return + + if req.linear_att_cache_len <= req.cur_kv_len and 
req.tail_linear_att_small_page_buffer_id is not None: + # 只有小页可以有 tail_linear_att_small_page_buffer_id,然后进行小页插入。 + assert page_num % big_page_num != 0 + free_token_index.append( + self.req_manager.req_to_token_indexs[req.req_idx][req.linear_att_cache_len : req.cur_kv_len] + ) + req.cur_kv_len = req.linear_att_cache_len input_token_ids = req.get_input_token_ids() key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") - # .cpu() 是 流内阻塞操作 value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - - prefix_len, node = self.radix_cache.insert(key, value) - + block_hashs = req.shm_req.linear_att_token_hash_list.get_all()[:page_num] + linear_idxs = [None for _ in range(page_num)] + linear_idxs[-1] = req.tail_linear_att_small_page_buffer_id + req.tail_linear_att_small_page_buffer_id = None + prefix_len, _ = self.radix_cache.insert( + key, + value, + block_hashs=block_hashs, + block_linear_idxs=linear_idxs, + len_to_big_page_id=req.linear_att_len_to_big_page_id, + ) old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: assert req.shared_kv_node.node_prefix_total_len <= prefix_len self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None + return - def free_a_req_mem_for_mamba(self, free_token_index: List, req: "InferReq") -> bool: - if self.radix_cache is None: - free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) - else: + if shared_kv_len < tail_big_page_token_num <= req.cur_kv_len: + free_token_index.append( + self.req_manager.req_to_token_indexs[req.req_idx][tail_big_page_token_num : req.cur_kv_len] + ) + req.cur_kv_len = tail_big_page_token_num + + assert req.tail_linear_att_small_page_buffer_id is None input_token_ids = req.get_input_token_ids() key = 
torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - - prefix_len, node = self.radix_cache.insert(key, value) + cur_page_num = tail_big_page_token_num // hash_page_size + assert tail_big_page_token_num % hash_page_size == 0 + block_hashs = req.shm_req.linear_att_token_hash_list.get_all()[:cur_page_num] + linear_idxs = [None for _ in range(cur_page_num)] + prefix_len, _ = self.radix_cache.insert( + key, + value, + block_hashs=block_hashs, + block_linear_idxs=linear_idxs, + len_to_big_page_id=req.linear_att_len_to_big_page_id, + ) old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: assert req.shared_kv_node.node_prefix_total_len <= prefix_len self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None + return - # 请求可能在排队时就被终止,导致node可能为None - if node is not None and node.buffer_idx is None: - req_to_buffer_index = self.req_manager.req_to_buffer_index - buffer_idx = req_to_buffer_index[req.req_idx, 0].item() - self.radix_cache.add_buffer_idx_to_node(node, buffer_idx) - # 该请求的 buffer 已经被插入到 radix cache 中,不需要手动释放 - return False - return True - - def _free_req_mem_and_buffers(self, free_token_index: List, free_buffer_index: List, req: "InferReq"): - """释放请求的 KV cache 和 buffer 内存""" - if self.has_recurrent_state: - need_free_base_buffer = self.free_a_req_mem_for_mamba(free_token_index, req) - req_to_buffer_index = self.req_manager.req_to_buffer_index - if need_free_base_buffer: - free_buffer_index.extend(req_to_buffer_index[req.req_idx, :].tolist()) - else: - self.free_a_req_mem(free_token_index, req) + if shared_kv_len <= req.cur_kv_len: + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][shared_kv_len : req.cur_kv_len]) + # 
该分支不会把 prefill 阶段累积的 big page id 插入 radix cache(典型为 pause/abort + # 在 prefill 跨过 big page 边界后、到达末尾前触发),需在此显式释放,避免泄漏。 + + # 释放本请求 prefill 阶段在 big page 边界上申请、但尚未插入 radix cache 的 big page + # state buffer。仅当请求未走 insert 分支(小页/大页插入)就被释放时才会有残留,典型场景: + # big page 模式下请求在 prefill 跨过 big page 边界后、到达末尾前被 pause / abort。 + # 若不释放,会泄漏 big page state slot,并触发 free_a_req_mem 中 dict 为空的断言。 + if req.linear_att_len_to_big_page_id: + self.radix_cache.linear_att_big_page_buffers.free_state_cache( + list(req.linear_att_len_to_big_page_id.values()) + ) + req.linear_att_len_to_big_page_id.clear() + + req.cur_kv_len = shared_kv_len + assert req.tail_linear_att_small_page_buffer_id is None + if req.shared_kv_node is not None: + assert req.shared_kv_node.node_prefix_total_len == req.cur_kv_len + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + req.shared_kv_node = None + return + + assert False, f"error state: cur_kv_len: {req.cur_kv_len}" def _save_promptcache_kvbuffer(self): """ @@ -216,23 +270,19 @@ def _filter(self, finished_request_ids: List[int]): free_req_index = [] free_token_index = [] - free_buffer_index = [] for request_id in finished_request_ids: req: InferReq = self.requests_mapping.pop(request_id) if self.args.diverse_mode: req.clear_master_slave_state() - self._free_req_mem_and_buffers(free_token_index, free_buffer_index, req) + self.free_a_req_mem(free_token_index, req) + free_req_index.append(req.req_idx) # logger.info(f"infer release req id {req.shm_req.request_id}") req.shm_req.shm_infer_released = True self.shm_req_manager.put_back_req_obj(req.shm_req) - if len(free_token_index) != 0: - free_token_index = custom_cat(free_token_index) - self.req_manager.free(free_req_index, free_token_index) - - if len(free_buffer_index) != 0 and isinstance(self.req_manager, ReqManagerForMamba): - self.req_manager.free_buffer(free_buffer_index) + free_token_index = custom_cat(free_token_index) + self.req_manager.free(free_req_index, free_token_index) finished_req_ids_set = 
set(finished_request_ids) self.infer_req_ids = [_id for _id in self.infer_req_ids if _id not in finished_req_ids_set] @@ -242,33 +292,27 @@ def _filter(self, finished_request_ids: List[int]): f"free a batch state:\n" f"radix refed token num {self.radix_cache.get_refed_tokens_num()}\n" f"radix hold token num {self.radix_cache.get_tree_total_tokens_num()}\n" - f"mem manager can alloc token num {self.req_manager.mem_manager.can_use_mem_size}\n" - f"mem manager total size {self.req_manager.mem_manager.size}" + f"mem manager can alloc token num {self.req_manager.mem_manager.allocator.can_use_mem_size}\n" + f"mem manager total size {self.req_manager.mem_manager.allocator.size}\n" ) return def filter_reqs(self, finished_reqs: List["InferReq"]): if finished_reqs: - g_infer_state_lock.acquire() self._filter([req.req_id for req in finished_reqs]) - g_infer_state_lock.release() return @torch.no_grad() def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): if pause_reqs: - g_infer_state_lock.acquire() free_token_index = [] - free_buffer_index = [] for req in pause_reqs: if self.args.diverse_mode: # 发生暂停的时候,需要清除 diverse 模式下的主从关系 req.clear_master_slave_state() - self._free_req_mem_and_buffers(free_token_index, free_buffer_index, req) - req.cur_kv_len = 0 - req.shm_req.shm_cur_kv_len = req.cur_kv_len + self.free_a_req_mem(free_token_index, req) assert req.wait_pause is True req.wait_pause = False req.paused = True @@ -279,32 +323,27 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): if len(free_token_index) != 0: free_token_index = custom_cat(free_token_index) self.req_manager.free_token(free_token_index) - - if len(free_buffer_index) != 0 and isinstance(self.req_manager, ReqManagerForMamba): - self.req_manager.free_buffer(free_buffer_index) - - g_infer_state_lock.release() return self def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int): if paused_reqs: - 
g_infer_state_lock.acquire() - recovered_reqs = [] + for req in paused_reqs: prefill_need_token_num = req.get_cur_total_len() if prefill_need_token_num > can_alloc_token_num: break - req._match_radix_cache() + + if g_infer_context.is_linear_att_mixed_model: + req._linear_match_radix_cache() + else: + req._match_radix_cache() + assert req.paused is True req.paused = False if is_master_in_dp: req.shm_req.is_paused = False logger.debug(f"infer recover paused req id {req.req_id}") can_alloc_token_num -= prefill_need_token_num - recovered_reqs.append(req) - if isinstance(self.req_manager, ReqManagerForMamba): - self._alloc_and_copy_req_buffers(self.req_manager, self.radix_cache, recovered_reqs) - g_infer_state_lock.release() return def get_can_alloc_token_num(self): @@ -313,7 +352,73 @@ def get_can_alloc_token_num(self): radix_cache_unref_token_num = ( self.radix_cache.get_tree_total_tokens_num() - self.radix_cache.get_refed_tokens_num() ) - return self.req_manager.mem_manager.can_use_mem_size + radix_cache_unref_token_num + return self.req_manager.mem_manager.allocator.can_use_mem_size + radix_cache_unref_token_num + + def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: List["InferReq"]): + """ + 该函数用于在线性混合模型prefill后,如果存在大页匹配的情况下,将线性层状态复制到 + """ + if not self.is_linear_att_mixed_model: + return + + # 大页对应的 linear att 的拷贝 + big_page_token_num = self.args.linear_att_hash_page_size * self.args.linear_att_page_block_num + big_page_buffer_ids = [] + for req in reqs: + cur_input_len = req.get_chuncked_input_token_len() + if cur_input_len % big_page_token_num == 0 and cur_input_len <= req.linear_att_cache_len: + big_page_id = self.radix_cache.linear_att_big_page_buffers.alloc_one_state_cache() + assert big_page_id is not None + big_page_buffer_ids.append(big_page_id) + assert cur_input_len not in req.linear_att_len_to_big_page_id + req.linear_att_len_to_big_page_id[cur_input_len] = big_page_id + else: + big_page_buffer_ids.append(-1) + + assert 
len(b_req_idx) == len(big_page_buffer_ids) + if any(buffer_id != -1 for buffer_id in big_page_buffer_ids): + big_page_buffer_ids = torch.tensor( + big_page_buffer_ids, dtype=torch.int32, requires_grad=False, device="cpu" + ) + big_page_buffer_ids = big_page_buffer_ids.cuda(non_blocking=True) + + from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer + + copy_linear_att_state_to_kv_buffer( + b_req_idx=b_req_idx, + big_page_buffer_ids=big_page_buffer_ids, + gpu_conv_state=self.req_manager.req_to_conv_state.buffer, + gpu_ssm_state=self.req_manager.req_to_ssm_state.buffer, + cpu_kv_conv_state=self.radix_cache.linear_att_big_page_buffers.conv_state_cache.buffer, + cpu_kv_ssm_state=self.radix_cache.linear_att_big_page_buffers.ssm_state_cache.buffer, + mtp_step=self.args.mtp_step, + ) + + assert not self.args.disable_chunked_prefill, "chunked prefill mode must be enabled for linear att mixed model" + + # tail small page 的linear att 状态的存储 + for req in reqs: + # 判断本次prefill 完以后 kv 的长度是否到达linear att 块存储的临界点。 + if req.get_chuncked_input_token_len() == req.linear_att_cache_len: + assert req.tail_linear_att_small_page_buffer_id is None + if req.linear_att_cache_len % big_page_token_num != 0: + self.radix_cache.free_one_small_page_linear_att_buffer() + req.tail_linear_att_small_page_buffer_id = ( + self.radix_cache.linear_att_small_page_buffers.alloc_one_state_cache() + ) + if req.tail_linear_att_small_page_buffer_id is not None: + src_buffer_idx = req.req_idx * (self.args.mtp_step + 1) + gpu_conv_state = self.req_manager.req_to_conv_state.buffer[:, src_buffer_idx, ...] + gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, src_buffer_idx, ...] 
+ dst_buffer_idx = req.tail_linear_att_small_page_buffer_id + + dst_conv_state, dst_ssm_state = self.radix_cache.linear_att_small_page_buffers.get_state_cache( + buffer_idx=dst_buffer_idx + ) + # TODO 对于非连续对象调用 copy_ 效率并不高 + dst_conv_state.copy_(gpu_conv_state, non_blocking=True) + dst_ssm_state.copy_(gpu_ssm_state, non_blocking=True) + return g_infer_context = InferenceContext() @@ -346,11 +451,8 @@ def __init__( if len(self.allowed_token_ids) == 0: self.allowed_token_ids = None - # p d mode use params - if self.shm_param.move_kv_to_decode_node.exists: - self.move_kv_to_decode_node = self.shm_param.move_kv_to_decode_node.to_dict() - else: - self.move_kv_to_decode_node = None + # if provided, invalid_token_ids are masked to -inf during sampling (see generic_post_process.sample) + self.invalid_token_ids = self.shm_param.invalid_token_ids.to_list() # this check is not very good to placed here. to do... if self.allowed_token_ids is not None: @@ -358,11 +460,16 @@ def __init__( logger.error("allowed_token_ids contain tokenid >= vobsize, we remove these token ids") self.allowed_token_ids = [e for e in self.allowed_token_ids if e < vocab_size] - # nixl decode node information - if self.shm_param.nixl_params.data_len > 0: - self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get()) + if len(self.invalid_token_ids) > 0: + if not all(e < vocab_size for e in self.invalid_token_ids): + logger.error("invalid_token_ids contain tokenid >= vobsize, we remove these token ids") + self.invalid_token_ids = [e for e in self.invalid_token_ids if e < vocab_size] + + # pd decode node information + if self.shm_param.pd_kv_trans_params.data_len > 0: + self.pd_decode_node: PDDecodeNodeInfo = pickle.loads(self.shm_param.pd_kv_trans_params.get()) else: - self.nixl_decode_node: NIXLDecodeNodeInfo = None + self.pd_decode_node: PDDecodeNodeInfo = None # only pd mode used. 
self.pd_master_node_id: int = self.shm_param.pd_master_node_id.get() @@ -401,6 +508,7 @@ def __init__( vocab_size: int = -1, init_prefix_cache: bool = True, ): + self.args = get_env_start_args() self.req_id = req_id self.req_idx = req_idx self.shm_index = shm_index @@ -421,12 +529,23 @@ def __init__( self.slave_reqs: List[InferReq] = [] self.related_master_req: InferReq = None - # nixl pd 分离模式使用的变量, 普通模式下这些变量没有具体用途 - self.nixl_trans_kv_start_index: int = 0 - self.nixl_pd_task_num: int = 0 - self.nixl_pd_task_sunccess_num: int = 0 - self.nixl_pd_task_failed_num: int = 0 - self.nixl_trans_device_id: int = -1 + # pd 分离模式使用的变量, 普通模式下这些变量没有具体用途 + self.pd_trans_kv_start_index: int = 0 + self.pd_task_num: int = 0 + self.pd_task_success_num: int = 0 + self.pd_task_failed_num: int = 0 + self.pd_trans_device_id: int = -1 + + # 类似 qwen3.5 这种混合linear att 模型使用的状态,记录申请来用于保存对应的线性att缓存的 buffer id + # 当 prefill 阶段结束后, 对应长度的 linear att state 会写入到申请 buffer id 对应的块中, 方便插入到 radix cache中 + # 方便被后续的请求使用,因为这种资源是有限的,也可能不存在的情况,申请不到时, 为None,则这种小块对应长度的 kv 无法 + # 在后续被插入到radix cache中. 这个id 是对应radix cache中的small page的buffer. 
+ # 对应请求最尾巴上那一个块,对应的 small page buffer id + self.tail_linear_att_small_page_buffer_id: Optional[int] = None + # linear cache 对应的长度位置。 + self.linear_att_cache_len: Optional[int] = None + # 存储对应长度位置的大页buffer_id + self.linear_att_len_to_big_page_id: Optional[SortedDict] = None # 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache # 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态 @@ -440,6 +559,10 @@ def __init__( else: self.decode_need_token_num = self._normal_decode_need_token_num + if g_infer_context.is_linear_att_mixed_model: + self.get_chuncked_input_token_len = self.get_chuncked_input_token_len_for_linear_att + self.get_chuncked_input_token_ids = self.get_chuncked_input_token_ids_for_linear_att + self._init_all_state() self.generator = None @@ -448,7 +571,10 @@ def __init__( self.generator.manual_seed(self.sampling_param.shm_param.seed) if init_prefix_cache: - self._match_radix_cache() + if g_infer_context.is_linear_att_mixed_model: + self._linear_match_radix_cache() + else: + self._match_radix_cache() return def _init_all_state(self): @@ -457,9 +583,9 @@ def _init_all_state(self): self.shm_req.link_logprobs_shm_array() self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size) - # 更新 nixl pd 分离模式下, prefill 节点需要开始传输的起始位置 - if self.sampling_param.nixl_decode_node is not None: - self.nixl_trans_kv_start_index = self.sampling_param.nixl_decode_node.ready_kv_len + # 更新 pd 分离模式下, prefill 节点需要开始传输的起始位置 + if self.sampling_param.pd_decode_node is not None: + self.pd_trans_kv_start_index = self.sampling_param.pd_decode_node.ready_kv_len self.cur_kv_len = 0 self.cur_output_len = 0 @@ -473,15 +599,24 @@ def _init_all_state(self): else: self.prefix_token_ids = [] self.multimodal_params = self.multimodal_params.to_dict() - self.shared_kv_node: TreeNode = None + self.shared_kv_node: Union[TreeNode, LinearAttPagedTreeNode] = None self.finish_status = FinishStatus() + + # 申请线性att混合模型使用的缓存资源 + if g_infer_context.is_linear_att_mixed_model: + linear_block_num = 
self.shm_req.linear_att_token_hash_list.size + self.linear_att_cache_len = linear_block_num * self.args.linear_att_hash_page_size + self.linear_att_len_to_big_page_id = SortedDict() + return def _match_radix_cache(self): - if self.sampling_param.disable_prompt_cache: - return - if g_infer_context.radix_cache is not None and self.get_cur_total_len() > 1 and self.cur_kv_len == 0: + assert ( + g_infer_context.is_linear_att_mixed_model is False + ), "current _match_radix_cache does not support linear att hybrid model, to do..." + enable_prompt_cache = (not self.sampling_param.disable_prompt_cache) and g_infer_context.radix_cache is not None + if enable_prompt_cache and self.get_cur_total_len() > 1 and self.cur_kv_len == 0: input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()] key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 @@ -497,6 +632,120 @@ def _match_radix_cache(self): self.shm_req.shm_cur_kv_len = self.cur_kv_len return + def _linear_match_radix_cache(self): + assert ( + g_infer_context.is_linear_att_mixed_model is True + ), "current _linear_match_radix_cache only support linear att hybrid model, to do..." 
+ enable_prompt_cache = (not self.sampling_param.disable_prompt_cache) and g_infer_context.radix_cache is not None + linear_hash_list = self.shm_req.linear_att_token_hash_list.get_all() + linear_att_hash_page_size = self.args.linear_att_hash_page_size + match_tokens = min(len(linear_hash_list) * linear_att_hash_page_size, self.get_cur_total_len() - 1) + match_tokens = max(0, match_tokens) + match_tokens = (match_tokens // linear_att_hash_page_size) * linear_att_hash_page_size + match_block_num = match_tokens // linear_att_hash_page_size + linear_hash_list = linear_hash_list[:match_block_num] + assert len(linear_hash_list) == self.shm_req.linear_att_token_hash_list.size + big_page_token_num = linear_att_hash_page_size * self.args.linear_att_page_block_num + big_page_is_disable = big_page_token_num > self.args.max_req_total_len + if enable_prompt_cache and match_tokens > 1 and len(linear_hash_list) > 0 and self.cur_kv_len == 0: + input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()] + key = torch.tensor(input_token_ids[0:match_tokens], dtype=torch.int64, device="cpu") + assert len(key) == len(linear_hash_list) * linear_att_hash_page_size + share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix( + key, block_hashs=linear_hash_list, update_refs=True + ) + if share_node is not None: + assert self.tail_linear_att_small_page_buffer_id is None + if share_node.is_big_page_node(): + # 大页匹配 + self.shared_kv_node = share_node + ready_cache_len = share_node.node_prefix_total_len + # 从 cpu 到 gpu 是流内阻塞操作 + g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor + self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 + self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 + assert self.tail_linear_att_small_page_buffer_id is None + # 恢复linear att 状态 + g_infer_context.req_manager.copy_big_page_buffer_to_linear_att_state( + 
big_page_buffer_idx=share_node.big_page_buffer_idx, req=self + ) + else: + # 小页匹配 + if big_page_is_disable: + # 如果 大页本质是被禁用的,可以直接使用小页的匹配结果 + self.shared_kv_node = share_node + ready_cache_len = share_node.node_prefix_total_len + # 从 cpu 到 gpu 是流内阻塞操作 + g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor + self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 + self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 + assert self.tail_linear_att_small_page_buffer_id is None + # 恢复linear att 状态 + g_infer_context.req_manager.copy_small_page_buffer_to_linear_att_state( + req=self, + linear_att_small_page_buffers=g_infer_context.radix_cache.linear_att_small_page_buffers, + ) + else: + # 如果 大页本质是被启用的,则需要使用小页的匹配结果, 将小页的kv 复制到的新申请的kv位置,同时释放 + # 对应的小页对应的节点,递归找到对应最近的大页节点进行返回,然后赋值到req.shared_node 对象上 + shared_kv_len = share_node.node_prefix_total_len + cur_big_page_tokens = (shared_kv_len // big_page_token_num) * big_page_token_num + need_tokens = shared_kv_len - cur_big_page_tokens + radix_cache = g_infer_context.radix_cache + if g_infer_context.get_can_alloc_token_num() > need_tokens: + # 有充足的token 容量时 + radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_tokens) + tail_mems = radix_cache.mem_manager.alloc(need_size=need_tokens) + g_infer_context.req_manager.req_to_token_indexs[ + self.req_idx, 0:cur_big_page_tokens + ] = value_tensor[0:cur_big_page_tokens] + g_infer_context.req_manager.req_to_token_indexs[ + self.req_idx, cur_big_page_tokens:shared_kv_len + ] = tail_mems + + # 将 对应的 value_tensors 中的 kv 数据 拷贝到 tail_mems 中对应的数据去 + radix_cache.mem_manager.operator.copy_mem_to_mem( + value_tensor[cur_big_page_tokens:shared_kv_len], tail_mems + ) + + self.shared_kv_node = share_node # 只是为了保证 copy_small_page_buffer_to_linear_att_state 正确调用 + g_infer_context.req_manager.copy_small_page_buffer_to_linear_att_state( + req=self, +
linear_att_small_page_buffers=g_infer_context.radix_cache.linear_att_small_page_buffers, + ) + self.shared_kv_node = None + + big_page_shared_node = radix_cache.deref_to_first_big_page_node(node=share_node) + self.shared_kv_node = big_page_shared_node + self.cur_kv_len = int(shared_kv_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 + self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 + else: + # 没有充足的token 容量时, 直接找到最接近的大页,进行大页恢复 + share_node = radix_cache.deref_to_first_big_page_node(node=share_node) + if share_node is not None: + assert share_node.is_big_page_node() + # 大页匹配 + self.shared_kv_node = share_node + ready_cache_len = share_node.node_prefix_total_len + # 从 cpu 到 gpu 是流内阻塞操作 + g_infer_context.req_manager.req_to_token_indexs[ + self.req_idx, 0:ready_cache_len + ] = value_tensor[0:ready_cache_len] + self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 + self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 + assert self.tail_linear_att_small_page_buffer_id is None + # 恢复linear att 状态 + g_infer_context.req_manager.copy_big_page_buffer_to_linear_att_state( + big_page_buffer_idx=share_node.big_page_buffer_idx, req=self + ) + + self.shm_req.shm_cur_kv_len = self.cur_kv_len + + if self.cur_kv_len == 0: + # 说明没有任何命中 + g_infer_context.req_manager.init_linear_att_state(req=self) + return + def is_master_req(self): """ diverse 模式下,判断当前请求是否为独立主请求,其进行prefill后,将 @@ -542,14 +791,40 @@ def get_input_token_ids(self): def get_chuncked_input_token_ids(self): chunked_start = self.cur_kv_len - chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) + chunked_end = min(self.get_cur_total_len(), chunked_start + self.args.chunked_prefill_size) return self.shm_req.shm_prompt_ids.arr[0:chunked_end] + def get_chuncked_input_token_ids_for_linear_att(self): + big_page_token_num = self.args.linear_att_hash_page_size * self.args.linear_att_page_block_num + + chunked_start = self.cur_kv_len 
+ chunked_end = chunked_start + self.args.chunked_prefill_size + big_page_end = ((chunked_start // big_page_token_num) + 1) * big_page_token_num + total_end = self.get_cur_total_len() + end = min(total_end, chunked_end, big_page_end) + + if chunked_start < self.linear_att_cache_len < end: + # linear att cache 对应需要存储的部分。 + end = self.linear_att_cache_len + + return self.shm_req.shm_prompt_ids.arr[0:end] + def get_chuncked_input_token_len(self): chunked_start = self.cur_kv_len - chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) + chunked_end = min(self.get_cur_total_len(), chunked_start + self.args.chunked_prefill_size) return chunked_end + def get_chuncked_input_token_len_for_linear_att(self): + big_page_token_num = self.args.linear_att_hash_page_size * self.args.linear_att_page_block_num + chunked_start = self.cur_kv_len + chunked_end = chunked_start + self.args.chunked_prefill_size + big_page_end = ((chunked_start // big_page_token_num) + 1) * big_page_token_num + total_end = self.get_cur_total_len() + end = min(total_end, chunked_end, big_page_end) + if chunked_start < self.linear_att_cache_len < end: + end = self.linear_att_cache_len + return end + def set_next_gen_token_id(self, next_token_id: int, logprob: float, output_len: int): index = self.shm_req.input_len + output_len self.shm_req.shm_prompt_ids.arr[index - 1] = next_token_id @@ -624,12 +899,12 @@ def handle( eos_ids: List[int], extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]], is_master_in_dp: bool, - nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None, + pd_prefill_chunked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None, ): - # nixl_prefill_chuncked_handle_func 主要是为了处理 nixl prefill 模式下 + # pd_prefill_chunked_handle_func 主要是为了处理 pd prefill 模式下 # 分块 prefill 后,形成对应的pd 分块传输处理。 - if nixl_prefill_chuncked_handle_func is not None: - 
nixl_prefill_chuncked_handle_func(self.req_obj, next_token_id, next_token_logprob, self.output_len) + if pd_prefill_chunked_handle_func is not None: + pd_prefill_chunked_handle_func(self.req_obj, next_token_id, next_token_logprob, self.output_len) if self.output_len <= 0: return diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 82f3a8ddf4..1a4bf6c020 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -10,11 +10,7 @@ from .diverse_backend.impl import DiversehBackend # pd mode backend -from .continues_batch.pd_mode.decode_node_impl.decode_impl import DecodeNode -from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode -from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ChunckedPrefillForPrefillNode -from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_for_dp import DPChunkedForPrefillNode -from .pd_nixl.prefill_node_impl.prefill_impl import NIXLChunckedPrefillForPrefillNode -from .pd_nixl.prefill_node_impl.prefill_impl_for_dp import NIXLDPChunkedForPrefillNode -from .pd_nixl.decode_node_impl.decode_impl import NIXLDecodeNode -from .pd_nixl.decode_node_impl.decode_impl_for_dp import NIXLDPForDecodeNode +from .pd.prefill_node_impl.prefill_impl import PDChunkedPrefillForPrefillNode +from .pd.prefill_node_impl.prefill_impl_for_dp import PDDPChunkedForPrefillNode +from .pd.decode_node_impl.decode_impl import PDDecodeNode +from .pd.decode_node_impl.decode_impl_for_dp import PDDPForDecodeNode diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 49a113b1ba..a65dfb1bbb 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -11,8 +11,11 @@ from 
lightllm.models import get_model from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad -from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.req_manager import ReqManagerForMamba +from lightllm.common.linear_att_cache_manager import LinearAttCacheManager +from lightllm.server.router.dynamic_prompt.linear_att_radix_cache import LinearAttPagedRadixCache +from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify from lightllm.utils.dist_utils import init_distributed_env @@ -44,8 +47,9 @@ from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token -from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from lightllm.server.pd_io_struct import PDChunckedTransTaskRet from .multi_level_kv_cache import MultiLevelKvCacheModule +from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd class ModeBackend: @@ -70,8 +74,8 @@ def __init__(self) -> None: self.classed_req_no_decode = False self.classed_req_strict_prefill = True - # nixl pd mode callback func - self.nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None + # pd mode callback func + self.pd_prefill_chunked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None # counter self._radix_tree_merge_counter: int = 0 @@ -101,8 +105,8 @@ def init_model(self, kvargs): self.eos_id: List[int] = kvargs.get("eos_id", [2]) self.disable_cudagraph = self.args.disable_cudagraph self.is_multinode_tp = self.args.nnodes 
> 1 and self.args.dp == 1 - self.is_nixl_pd_mode = self.run_mode in ["nixl_prefill", "nixl_decode"] - self.is_nixl_decode_mode = self.run_mode == "nixl_decode" + self.is_pd_mode = self.run_mode in ["prefill", "decode"] + self.is_pd_decode_mode = self.run_mode == "decode" self.logger = init_logger(__name__) @@ -120,29 +124,6 @@ def init_model(self, kvargs): self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) - # 为 p d 分离模式添加的全局锁管理,用于做一些同步操作。 一定需要在 - # init_process_group 之后调用 - g_infer_state_lock.obj = ( - InferStateLock( - name=get_unique_server_name(), - rank_in_dp=self.rank_in_dp, - dp_rank_in_node=self.dp_rank_in_node, - dp_world_size=self.dp_world_size, - ) - if self.run_mode in ["prefill", "decode"] - else None - ) - g_infer_state_lock.dp_world_size = self.dp_world_size - self.infer_state_lock = g_infer_state_lock - # 防止InferStateLock 中的全局共享信息被重复异常初始化,导致同步异常的问题。 - # 所以做一次barrier等待 - dist.barrier() - - wait_events = [] - if self.args.enable_cpu_cache: - self.multi_level_cache_module = MultiLevelKvCacheModule(self) - wait_events.append(self.multi_level_cache_module) - if self.args.enable_multimodal: g_infer_context.init_cpu_embed_cache_client() @@ -165,24 +146,42 @@ def init_model(self, kvargs): "batch_max_tokens": kvargs.get("batch_max_tokens", None), "quant_type": kvargs.get("quant_type", None), "quant_cfg": kvargs.get("quant_cfg", None), + "expert_dtype": kvargs.get("expert_dtype", None), "run_mode": self.run_mode, - "wait_events": wait_events, } self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) + self.is_linear_att_mixed_model = isinstance(self.model.req_manager, ReqManagerForMamba) - radix_cache_class = self.model.radix_cache_class - self.radix_cache = ( - radix_cache_class( - get_unique_server_name(), - self.model.mem_manager.size, - self.rank_in_node, - kv_cache_mem_manager=self.model.mem_manager, + 
if self.is_linear_att_mixed_model: + self.linear_att_cache_manager = LinearAttCacheManager( + size=self.args.linear_att_cache_size, + linear_config=self.model.req_manager.linear_config, ) + else: + self.linear_att_cache_manager = None + + if not self.use_dynamic_prompt_cache: + self.radix_cache = None + else: + if self.is_linear_att_mixed_model: + self.radix_cache = LinearAttPagedRadixCache( + unique_name=get_unique_server_name(), + total_token_num=self.model.mem_manager.size, + rank_in_node=self.rank_in_node, + hash_page_size=self.args.linear_att_hash_page_size, + big_page_num=self.args.linear_att_page_block_num, + kv_cache_mem_manager=self.model.mem_manager, + linear_att_small_page_buffers=self.linear_att_cache_manager, + ) + else: + self.radix_cache = RadixCache( + unique_name=get_unique_server_name(), + total_token_num=self.model.mem_manager.size, + rank_in_node=self.rank_in_node, + mem_manager=self.model.mem_manager, + ) if "prompt_cache_kv_buffer" in model_cfg: assert self.use_dynamic_prompt_cache @@ -220,10 +219,7 @@ def init_model(self, kvargs): [rank for rank in range(self.global_world_size)], backend="nccl" ) - if ( - self.args.run_mode in ["nixl_prefill", "nixl_decode", "prefill", "decode"] - or self.args.enable_dp_prompt_cache_fetch - ): + if self.args.run_mode in ["prefill", "decode"] or self.args.enable_dp_prompt_cache_fetch: # 如果存在需要跨进程使用mem manger的特性,则将mem manager写入到 shm中,方便 # 读取 self.model.mem_manager.write_to_shm(req_manager=self.model.req_manager) @@ -235,13 +231,20 @@ def init_model(self, kvargs): self.init_dp_kv_shared() self.shm_reqs_io_buffer = ShmObjsIOBuffer() - # 只会在 nixl pd 模式下才会使用,用于上传分块传输任务是否成功。 - self.shm_nixl_trans_io_buffer = ShmObjsIOBuffer(tail_str="nixl") + # 只会在 pd 模式下才会使用,用于上传分块传输任务是否成功。 + self.shm_pd_trans_io_buffer = ShmObjsIOBuffer(tail_str="pd") # 开启 mtp 模式,需要完成mtp model的初始化 if self.args.mtp_mode: self.init_mtp_draft_model(kvargs) + if self.args.enable_cpu_cache: +
self.multi_level_cache_module = MultiLevelKvCacheModule(self) + + prof_name = f"lightllm-model_backend-node{self.node_rank}_dev{get_current_device_id()}" + prof_mode = self.args.enable_profiling + self.profiler = ProcessProfiler(mode=prof_mode, name=prof_name, use_multi_thread=True) if prof_mode else None + # 启动infer_loop_thread, 启动两个线程进行推理,对于具备双batch推理折叠得场景 # 可以降低 cpu overhead,大幅提升gpu得使用率。 self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True) @@ -261,7 +264,6 @@ def init_dp_kv_shared(self): self.dp_kv_shared_module = DPKVSharedMoudle( max_req_num=self.args.running_max_req_size, - max_req_seq_len=self.args.max_req_total_len + 8, dp_size_in_node=self.dp_size_in_node, backend=self, ) @@ -320,6 +322,7 @@ def init_mtp_draft_model(self, main_kvargs: dict): "batch_max_tokens": main_kvargs.get("batch_max_tokens", None), "quant_type": main_kvargs.get("quant_type", None), "quant_cfg": main_kvargs.get("quant_cfg", None), + "expert_dtype": main_kvargs.get("expert_dtype", None), "run_mode": "normal", "main_model": self.model, "mtp_previous_draft_models": self.draft_models.copy(), @@ -365,6 +368,10 @@ def _try_read_new_reqs(self): self._try_read_new_reqs_multinode_tp() else: self._try_read_new_reqs_normal() + + # on each loop thread + if self.profiler is not None: + self.profiler.multi_thread_helper() return def _try_read_new_reqs_normal(self): @@ -380,10 +387,10 @@ def _try_read_new_reqs_normal(self): if new_buffer_is_ready: self._read_reqs_buffer_and_init_reqs() - # nixl pd mode 从 shm_nixl_trans_io_buffer 读取分块传输的完成进度。 - if self.is_nixl_pd_mode: + # pd mode 从 shm_pd_trans_io_buffer 读取分块传输的完成进度。 + if self.is_pd_mode: if self.is_master_in_node: - if self.shm_nixl_trans_io_buffer.is_ready(): + if self.shm_pd_trans_io_buffer.is_ready(): self.node_broadcast_tensor.fill_(1) else: self.node_broadcast_tensor.fill_(0) @@ -392,7 +399,7 @@ def _try_read_new_reqs_normal(self): broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, 
async_op=False) new_buffer_is_ready = self.node_broadcast_tensor.detach().item() if new_buffer_is_ready: - self._read_nixl_trans_io_buffer_and_update_req_status() + self._read_pd_trans_io_buffer_and_update_req_status() return def _try_read_new_reqs_multinode_tp(self): @@ -415,7 +422,7 @@ def _try_read_new_reqs_multinode_tp(self): if new_buffer_is_ready: self._read_reqs_buffer_and_init_reqs() - assert self.is_nixl_pd_mode is False + assert self.is_pd_mode is False return def _read_reqs_buffer_and_init_reqs(self): @@ -430,28 +437,29 @@ def _read_reqs_buffer_and_init_reqs(self): if obj.req_id in g_infer_context.requests_mapping: req: InferReq = g_infer_context.requests_mapping[obj.req_id] req.infer_aborted = True + elif isinstance(obj, ProfilerCmd): + if self.profiler is not None: + self.profiler.cmd(obj) else: assert False, f"error type {type(obj)}" if init_reqs: - req_ids = self._init_reqs(reqs=init_reqs) - if self.args.enable_cpu_cache and req_ids: - self._load_cpu_cache_to_reqs(req_ids=req_ids) + self._init_reqs(reqs=init_reqs) return - def _read_nixl_trans_io_buffer_and_update_req_status(self): - cmds: List[NIXLChunckedTransTaskRet] = self.shm_nixl_trans_io_buffer.read_obj() - self.shm_nixl_trans_io_buffer.sub_state() + def _read_pd_trans_io_buffer_and_update_req_status(self): + cmds: List[PDChunckedTransTaskRet] = self.shm_pd_trans_io_buffer.read_obj() + self.shm_pd_trans_io_buffer.sub_state() if cmds: for obj in cmds: if obj.request_id in g_infer_context.requests_mapping: req: InferReq = g_infer_context.requests_mapping[obj.request_id] if obj.has_error: - req.nixl_pd_task_failed_num += 1 + req.pd_task_failed_num += 1 else: - req.nixl_pd_task_sunccess_num += 1 - # nixl decode 节点需要预填充 prefill 节点发送过来的产生的首token信息,以使 + req.pd_task_success_num += 1 + # pd decode 节点需要预填充 prefill 节点发送过来的产生的首token信息,以使 # 推理过程可以继续。 - if self.is_nixl_decode_mode: + if self.is_pd_decode_mode: if obj.first_gen_token_id is not None: assert req.cur_output_len == 0 req.cur_output_len += 1 @@ 
-467,7 +475,7 @@ def _read_nixl_trans_io_buffer_and_update_req_status(self): eos_ids=self.eos_id, extra_post_req_handle_func=None, is_master_in_dp=self.is_master_in_dp, - nixl_prefill_chuncked_handle_func=None, + pd_prefill_chunked_handle_func=None, ) return @@ -483,24 +491,23 @@ def _init_reqs(self, reqs: List[Tuple]): if self.dp_size_in_node != 1: dp_rank_in_node = self.dp_rank_in_node reqs = [req for req in reqs if req[3] == dp_rank_in_node] - - g_infer_state_lock.acquire() g_infer_context.add_reqs(reqs) - g_infer_state_lock.release() req_ids = [e[0] for e in reqs] + + if self.args.enable_cpu_cache: + self._load_cpu_cache_to_reqs(req_ids=req_ids) + return req_ids def _load_cpu_cache_to_reqs(self, req_ids): req_objs: List[InferReq] = [g_infer_context.requests_mapping[req_id] for req_id in req_ids] - g_infer_state_lock.acquire() self.multi_level_cache_module.load_cpu_cache_to_reqs(reqs=req_objs) - g_infer_state_lock.release() return def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: """ 将错误请求从 req_ids 中过滤出来, 然后让 _get_classed_reqs 进行处理。 该函数 - 主要用于在 nixl pd 分离模式下, 由子类继承重载, prefill 和 decode 节点过滤 kv 传输错误,或者 kv + 主要用于在 pd 分离模式下, 由子类继承重载, prefill 和 decode 节点过滤 kv 传输错误,或者 kv 传输没有完成的请求。 """ return [g_infer_context.requests_mapping[request_id] for request_id in req_ids] @@ -512,13 +519,11 @@ def _timer_merge_radix_tree(self): and (self._radix_tree_merge_counter % self._radix_tree_merge_update_delta == 0) and self.radix_cache is not None ): - g_infer_state_lock.acquire() start = time.time() self.radix_cache.merge_unreferenced_nodes() self.logger.info( f"radix tree merge_unreferenced_nodes cost time {time.time() - start} s in rank {self.global_rank}" ) - g_infer_state_lock.release() return # 一些可以复用的通用功能函数 @@ -576,9 +581,6 @@ def _get_classed_reqs( wait_pause_count = 0 prefill_tokens = 0 - # 因为会使用到 radix cache 和 mem_manager 的计数信息 - # 所以需要加锁保护。 - g_infer_state_lock.acquire() can_alloc_token_num = g_infer_context.get_can_alloc_token_num() for req_obj in 
ready_reqs: @@ -639,8 +641,6 @@ def _get_classed_reqs( req_obj.wait_pause = True wait_pause_count += 1 - g_infer_state_lock.release() - self._pre_handle_finished_reqs(finished_reqs=finished_reqs) # 如果使能了 cpu cache 功能,对于已经完成的请求,进行 gpu kv 卸载到 cpu cache的操作。 if self.args.enable_cpu_cache: @@ -658,6 +658,19 @@ def _get_classed_reqs( paused_reqs=paused_reqs, is_master_in_dp=self.is_master_in_dp, can_alloc_token_num=can_alloc_token_num ) + # 在 enable_prefill_decode_mixed 模式下,如果存在 prefill 请求和 decode 请求, + # 并且 prefill 请求需要的 token 数量 + decode 请求需要的 token 数量小于等于 batch_max_tokens, + # 则将 decode 请求合并到 prefill 请求中。 + if self.args.enable_prefill_decode_mixed and len(prefill_reqs) > 0 and len(decode_reqs) > 0: + if prefill_tokens + len(decode_reqs) <= self.batch_max_tokens: + for decode_req in decode_reqs: + # 给 decode req 添加一个属性标签,标识其为混合prefill的请求。 + # 在 prefill 阶段,会根据这个属性标签, 对这些请求的处理进行一些 + # 特殊化,主要是构建获取input_ids 的方式。 + decode_req.is_decode_req_mixed_in_prefill = True + prefill_reqs.append(decode_req) + decode_reqs = [] + return prefill_reqs, decode_reqs def _pre_handle_finished_reqs(self, finished_reqs: List[InferReq]): @@ -702,7 +715,7 @@ def _post_handle( next_token_logprobs: List[float], run_reqs_update_packs: List[InferReqUpdatePack], extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, - nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None, + pd_prefill_chunked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None, ): """ extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于 @@ -719,7 +732,7 @@ def _post_handle( eos_ids=self.eos_id, extra_post_req_handle_func=extra_post_req_handle_func, is_master_in_dp=self.is_master_in_dp, - nixl_prefill_chuncked_handle_func=nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=pd_prefill_chunked_handle_func, ) g_infer_context.req_manager.req_sampling_params_manager.update_reqs_token_counter( @@ -730,9 +743,7 @@ def
_post_handle( # 一些可以复用的通用功能函数 def _filter_reqs(self, reqs: List[InferReq]): if reqs: - g_infer_state_lock.acquire() g_infer_context.filter_reqs(reqs) - g_infer_state_lock.release() return # 一些可以复用的通用功能函数 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index a8a5224ebc..792a10a788 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -15,7 +15,6 @@ from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.server.router.model_infer.infer_batch import g_infer_context from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager -from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.common.basemodel.triton_kernel.mtp_utils import ( @@ -118,6 +117,10 @@ def prefill_normal( b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu, mask_func=self.prefill_mask_func, ) + g_infer_context.copy_linear_att_state_to_cache_buffer( + b_req_idx=model_input.b_req_idx, + reqs=run_reqs, + ) sync_event = torch.cuda.Event() sync_event.record() @@ -134,7 +137,7 @@ def prefill_normal( next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=self.pd_prefill_chunked_handle_func, ) # 第四阶段 event_pack.notify_pre_post_handle() @@ -199,6 +202,10 @@ def prefill_mtp( self._draft_prefill_forward( model_input=model_input, model_output=model_output, next_token_ids=next_token_ids ) + g_infer_context.copy_linear_att_state_to_cache_buffer( + 
b_req_idx=model_input.b_req_idx, + reqs=run_reqs, + ) sync_event = torch.cuda.Event() sync_event.record() @@ -216,7 +223,7 @@ def prefill_mtp( next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=self.pd_prefill_chunked_handle_func, ) # 第四阶段 @@ -308,9 +315,7 @@ def decode_mtp( ) if len(need_free_mem_indexes) > 0: - g_infer_state_lock.acquire() g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) - g_infer_state_lock.release() # 第四阶段 event_pack.notify_pre_post_handle() @@ -376,11 +381,9 @@ def _draft_decode_eagle( ): batch_size = main_model_input.batch_size num_reqs = batch_size // (self.mtp_step + 1) - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(num_reqs * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(num_reqs * self.mtp_step) - g_infer_state_lock.release() eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) # share some inference info with the main model diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py deleted file mode 100644 index b367a66a75..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py +++ /dev/null @@ -1,101 +0,0 @@ -import os -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -import threading -from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ChunkedPrefillBackend -from typing import List, Tuple -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, 
g_infer_state_lock -from lightllm.server.core.objs import FinishStatus -from lightllm.utils.log_utils import init_logger -from rpyc.utils.server import ThreadedServer -from lightllm.common.basemodel.infer_lock import g_router_lock -from .decode_task_cache import g_success_kv_move_task_cache, KVMoveTask -from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.utils.dist_utils import create_new_group_for_current_dp - -logger = init_logger(__name__) - - -class DecodeNode(ChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue) -> None: - super().__init__() - self.info_queue: mp.Queue = info_queue - self.classed_req_strict_prefill = False - - def init_custom(self): - - self.lock_nccl_group = create_new_group_for_current_dp("gloo") - logger.info(f"lock_nccl_group ranks {dist.get_rank(self.lock_nccl_group)}") - - from .decode_infer_rpyc import PDDecodeInferRpcServer - - socket_path = f"/tmp/{get_unique_server_name()}_decode_node_infer_rpyc_{self.pd_rpyc_ports[self.rank_in_node]}" - if os.path.exists(socket_path): - os.remove(socket_path) - - t = ThreadedServer( - PDDecodeInferRpcServer(self), socket_path=socket_path, protocol_config={"allow_pickle": True} - ) - threading.Thread(target=lambda: t.start(), daemon=True).start() - return - - def _init_reqs(self, reqs: List[Tuple]): - """ - 替换请求初始化操作,替换为 Decode 节点独有的一些特殊初始化流程 - """ - if self.dp_size_in_node != 1: - dp_rank_in_node = self.dp_rank_in_node - reqs = [req for req in reqs if req[3] == dp_rank_in_node] - - g_infer_state_lock.acquire() - - uninit_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=False) - # 匹配radix cache,并更新一些资源的管理。 - self._post_init_reqs(uninit_reqs=uninit_reqs) - - g_infer_state_lock.release() - req_ids = [e[0] for e in reqs] - return req_ids - - def _post_init_reqs(self, uninit_reqs: List[InferReq]): - """ - 检查请求的 kv len 将可能有问题的请求立即结束掉 - """ - if len(uninit_reqs) == 0: - return - - remove_count = 0 - estimated_peak_token_count = 0 - for req_obj in uninit_reqs: - 
req_obj: InferReq = req_obj # for easy typing - request_id = req_obj.req_id - if request_id in g_success_kv_move_task_cache: - task, share_node, _ = g_success_kv_move_task_cache.pop(request_id) - task: KVMoveTask = task # for easy typing - self.radix_cache.dec_node_ref_counter(share_node) - req_all_len = len(task.input_tokens) + task.decode_node.max_new_tokens - remove_count += req_all_len - estimated_peak_token_count += req_all_len - req_obj._match_radix_cache() - else: - # 对于不合法的请求,直接模拟将其finished掉 - req_obj.cur_output_len += 1 - req_obj.set_next_gen_token_id(0, 0.0, 1) - req_obj.finish_status.set_status(FinishStatus.FINISHED_STOP) - - if self.is_master_in_dp: - req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len - req_obj.shm_req.shm_cur_output_len = req_obj.cur_output_len - req_obj.shm_req.finish_token_index = req_obj.get_cur_total_len() - 1 - req_obj.shm_req.finish_status.set_status(FinishStatus.FINISHED_STOP) - req_obj.shm_req.candetoken_out_len = req_obj.cur_output_len - - req_id = req_obj.shm_req.request_id - logger.error(f"req_id: {req_id} forced to finished, it not in g_success_kv_move_task_cache") - - if self.is_master_in_dp: - with g_router_lock.obj: - self.shared_token_load.add_frozened_token_count(-remove_count, self.dp_rank_in_node) - self.shared_token_load.add_estimated_peak_token_count(estimated_peak_token_count, self.dp_rank_in_node) - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py deleted file mode 100644 index 8dc9ad1a6d..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch.multiprocessing as mp -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.utils.log_utils import init_logger -from 
typing import List, Tuple -from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend -from .decode_impl import DecodeNode - -logger = init_logger(__name__) - - -class DPForDecodeNode(DPChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue) -> None: - super().__init__() - self.info_queue: mp.Queue = info_queue - self.classed_req_strict_prefill = False - return - - def init_custom(self): - DecodeNode.init_custom(self) - return - - def _init_reqs(self, reqs: List[Tuple]): - DecodeNode._init_reqs(self, reqs=reqs) - return - - def _post_init_reqs(self, uninit_reqs: List[InferReq]): - DecodeNode._post_init_reqs(self, uninit_reqs=uninit_reqs) - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py deleted file mode 100644 index 696452b419..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py +++ /dev/null @@ -1,209 +0,0 @@ -import torch -import torch.distributed as dist -import rpyc -import time -from typing import Dict, List, Tuple, Optional, Union -from rpyc.utils.classic import obtain -from .decode_impl import DecodeNode -from lightllm.common.basemodel.infer_lock import acquire_lock_until_ready, release_acquired_lock, g_router_lock -from .decode_task_cache import g_kv_move_task_cache, g_success_kv_move_task_cache -from lightllm.server.pd_io_struct import KVMoveTask -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class PDDecodeInferRpcServer(rpyc.Service): - def __init__(self, backend: DecodeNode) -> None: - super().__init__() - self.backend = backend - self.device_id = self.backend.current_device_id - self.dp_rank_in_node = self.backend.dp_rank_in_node - self.is_master_in_dp = self.backend.is_master_in_dp - return - - 
def on_connect(self, conn): - torch.cuda.set_device(f"cuda:{self.device_id}") - return - - def judge_token_is_ok(self, key_len, max_new_token): - # 多 dp 单卡模式下, 每个 dp 各自处理自己的, 不需要同步 - if self.backend.dp_world_size == 1: - with g_router_lock.obj: - shared_token_load = self.backend.shared_token_load - peak_num = shared_token_load.get_estimated_peak_token_count(self.dp_rank_in_node) - peak_num += shared_token_load.get_frozened_token_count(self.dp_rank_in_node) - peak_num += key_len + max_new_token - - if peak_num < self.backend.get_max_total_token_num(): - object_list = [True] - shared_token_load.add_frozened_token_count(key_len + max_new_token, self.dp_rank_in_node) - else: - object_list = [False] - return object_list[0] - - # 普通单dp模式下, 只有主 rank 处理信息,并将数据同步到其他rank上 - if self.is_master_in_dp: - with g_router_lock.obj: - shared_token_load = self.backend.shared_token_load - peak_num = shared_token_load.get_estimated_peak_token_count(self.dp_rank_in_node) - peak_num += shared_token_load.get_frozened_token_count(self.dp_rank_in_node) - peak_num += key_len + max_new_token - - if peak_num < self.backend.get_max_total_token_num(): - object_list = [True] - shared_token_load.add_frozened_token_count(key_len + max_new_token, self.dp_rank_in_node) - else: - object_list = [False] - dist.broadcast_object_list(object_list, src=0, group=self.backend.lock_nccl_group) - else: - object_list = [None] - dist.broadcast_object_list(object_list, src=0, group=self.backend.lock_nccl_group) - return object_list[0] - - def recover_frozen_token(self, key_len, max_new_token): - if self.is_master_in_dp: - with g_router_lock.obj: - shared_token_load = self.backend.shared_token_load - shared_token_load.add_frozened_token_count(-(key_len + max_new_token), self.dp_rank_in_node) - return - - def _alloc_to_frozen_some_tokens(self, move_task: KVMoveTask): - is_ok = self.judge_token_is_ok(len(move_task.input_tokens), move_task.decode_node.max_new_tokens) - if not is_ok: - if self.is_master_in_dp: - 
logger.info(f"req_id: {move_task.to_decode_log_info()} alloc token failed") - shared_token_load = self.backend.shared_token_load - dp_rank = self.dp_rank_in_node - frozen_token_num = shared_token_load.get_frozened_token_count(dp_rank) - estimated_peak_token_num = shared_token_load.get_estimated_peak_token_count(dp_rank) - logger.debug( - f"radix refed token num {self.backend.radix_cache.get_refed_tokens_num()}\n" - f"radix hold token num {self.backend.radix_cache.get_tree_total_tokens_num()}\n" - f"mem manager can alloc token num {self.backend.model.mem_manager.can_use_mem_size}\n" - f"mem manager total size {self.backend.model.mem_manager.size}" - f"frozened token num {frozen_token_num}\n" - f"estimated peak token num {estimated_peak_token_num}\n" - ) - return None - - key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu") - tree_node, kv_len, fused_token_indexes = self.backend.radix_cache.match_prefix(key, update_refs=True) - # 如果没匹配到,说明长度是0, 将fused_token_indexes做一下转换 - fused_token_indexes = [] if fused_token_indexes is None else fused_token_indexes.tolist() - need_len = len(move_task.input_tokens) - kv_len - if need_len == 0: - alloc_token_indexes = [] - else: - self.backend.radix_cache.free_radix_cache_to_get_enough_token(need_len) - alloc_token_indexes = self.backend.model.mem_manager.alloc(need_len) - if alloc_token_indexes is not None: - alloc_token_indexes = alloc_token_indexes.tolist() - - if alloc_token_indexes is None: - self.backend.radix_cache.dec_node_ref_counter(tree_node) - self.recover_frozen_token(len(move_task.input_tokens), move_task.decode_node.max_new_tokens) - return None - - move_task.decode_token_indexes = alloc_token_indexes - move_task.move_kv_len = need_len - - g_kv_move_task_cache[move_task.group_request_id] = (move_task, tree_node, fused_token_indexes) - return move_task.decode_token_indexes - - # 返回 None 代表服务繁忙已经无法调度新的请求进入了 - def exposed_alloc_to_frozen_some_tokens(self, move_tasks: List[KVMoveTask]) -> 
List[Optional[List[int]]]: - move_tasks = obtain(move_tasks) - acquire_lock_until_ready(self.backend.lock_nccl_group) - try: - ans_list = [] - for move_task in move_tasks: - ans_list.append(self._alloc_to_frozen_some_tokens(move_task)) - return ans_list - except BaseException as e: - logger.exception(str(e)) - return None - finally: - release_acquired_lock() - - def _put_kv_received_to_radix_cache(self, group_req_id: int): - move_task, tree_node, fused_token_indexes = g_kv_move_task_cache.pop(group_req_id) - radix_cache = self.backend.radix_cache - key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu") - value = torch.tensor(fused_token_indexes + move_task.decode_token_indexes, dtype=torch.int64, device="cpu") - prefix_len, _ = radix_cache.insert(key, value) - assert len(fused_token_indexes) <= prefix_len - self.backend.model.mem_manager.free(value[len(fused_token_indexes) : prefix_len]) - self.backend.radix_cache.dec_node_ref_counter(tree_node) - - # 申请一段key,把 radix cache 锁住,防止极端情况下被刷掉, decode 端通过减两次引用计数来修正。 - tree_node, kv_len, _ = self.backend.radix_cache.match_prefix(key, update_refs=True) - assert len(key) == kv_len - g_success_kv_move_task_cache[group_req_id] = (move_task, tree_node, time.time()) - return - - def exposed_put_kv_received_to_radix_cache(self, group_req_ids: List[int]): - group_req_ids = obtain(group_req_ids) - acquire_lock_until_ready(self.backend.lock_nccl_group) - for group_req_id in group_req_ids: - self._put_kv_received_to_radix_cache(group_req_id) - release_acquired_lock() - return - - def _fail_to_realese_forzen_tokens(self, group_req_id: int): - move_task, tree_node, fused_token_indexes = g_kv_move_task_cache.pop(group_req_id) - value = torch.tensor(move_task.decode_token_indexes, dtype=torch.int64, device="cpu") - self.backend.model.mem_manager.free(value) - self.backend.radix_cache.dec_node_ref_counter(tree_node) - self.recover_frozen_token(len(move_task.input_tokens), move_task.decode_node.max_new_tokens) - return 
- - def exposed_fail_to_realese_forzen_tokens(self, group_req_ids: List[int]): - group_req_ids = obtain(group_req_ids) - acquire_lock_until_ready(self.backend.lock_nccl_group) - for group_req_id in group_req_ids: - self._fail_to_realese_forzen_tokens(group_req_id) - release_acquired_lock() - return - - def exposed_unfrozen_time_out_reqs_tokens(self): - acquire_lock_until_ready(self.backend.lock_nccl_group) - if self.backend.dp_world_size == 1: - need_release_reqs = self._get_time_out_reqs() - logger.info(f"kv time out reqs: {need_release_reqs}") - remove_tokens = self._remove_time_out_reqs(need_release_reqs) - if remove_tokens != 0: - with g_router_lock.obj: - self.backend.shared_token_load.add_frozened_token_count(-remove_tokens, self.dp_rank_in_node) - else: - if self.is_master_in_dp: - need_release_reqs = self._get_time_out_reqs() - logger.info(f"kv time out reqs: {need_release_reqs}") - dist.broadcast_object_list([need_release_reqs], src=0, group=self.backend.lock_nccl_group) - else: - receive_objs = [None] - dist.broadcast_object_list(receive_objs, src=0, group=self.backend.lock_nccl_group) - need_release_reqs = receive_objs[0] - remove_tokens = self._remove_time_out_reqs(need_release_reqs) - if self.is_master_in_dp and remove_tokens != 0: - with g_router_lock.obj: - self.backend.shared_token_load.add_frozened_token_count(-remove_tokens, self.dp_rank_in_node) - - release_acquired_lock() - return - - def _get_time_out_reqs(self): - need_release_reqs = [] - for req_id, (_, _, time_mark) in g_success_kv_move_task_cache.items(): - # 6s 这个请求都没有被调度使用,就会主动被删除掉锁定,释放其锁定的token - if time.time() - time_mark > 6: - need_release_reqs.append(req_id) - return need_release_reqs - - def _remove_time_out_reqs(self, need_release_reqs: List[int]) -> int: - remove_tokens = 0 - for req_id in need_release_reqs: - task, tree_node, _ = g_success_kv_move_task_cache.pop(req_id) - self.backend.radix_cache.dec_node_ref_counter(tree_node) - remove_tokens += len(task.input_tokens) + 
task.decode_node.max_new_tokens - return remove_tokens diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py deleted file mode 100644 index 4733a141bf..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ /dev/null @@ -1,383 +0,0 @@ -import rpyc -import random -import asyncio -import os -import signal -import collections -import time -import psutil -import threading -import inspect -import setproctitle -from rpyc.utils.classic import obtain -from dataclasses import dataclass -from typing import List, Dict, Optional, Tuple, Union -from rpyc import ThreadedServer -from lightllm.utils.log_utils import init_logger -from .decode_infer_rpyc import PDDecodeInferRpcServer -from ..task_queue import TaskQueue -import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import KVMoveTask, UpKVStatus, PDTransJoinInfo, PDTransLeaveInfo -from lightllm.utils.retry_utils import retry -import numpy as np -from rpyc import AsyncResult -from lightllm.utils.device_utils import kv_trans_use_p2p -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.utils.envs_utils import get_unique_server_name - -logger = init_logger(__name__) - -thread_local_data = threading.local() - -KV_MOVE_MAX_NUM = 16 - - -class DecodeKVMoveManager(rpyc.Service): - def __init__(self, args, info_queue: mp.Queue): - super().__init__() - self.args = args - # args.dp // args.nnodes 在跨机tp的场景下,可能为0 - self.dp_size_in_node = max(1, args.dp // args.nnodes) - self.node_world_size = args.tp // args.nnodes - self.dp_world_size = args.tp // args.dp - # 不支持跨机tp的pd 分离策略 - assert self.dp_world_size <= self.node_world_size - - self.info_queue = info_queue - self.infer_rpyc_lock = threading.Lock() - self.infer_rpyc_objs: 
List[PDDecodeInferRpcServer] = [] - - from .decode_trans_obj import KVTransConnectObj - - self.connect_id_to_trans_obj: Dict[str, KVTransConnectObj] = {} - for port in self.args.pd_node_infer_rpyc_ports: - socket_path = f"/tmp/{get_unique_server_name()}_decode_node_infer_rpyc_{port}" - from rpyc.utils.factory import unix_connect - - con = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) - self.infer_rpyc_objs.append(con.root) - logger.info(f"rpyc connect to port: {port} ok") - - from .up_status import start_up_kv_status_process - - self.up_status_in_queue = mp.Queue() - self.up_status_out_queue = mp.Queue() - start_up_kv_status_process(self.args, self.up_status_in_queue, self.up_status_out_queue) - - # fail release queue - self.fail_to_release_queue = TaskQueue(get_func=lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=None) - self.fail_to_release_thread = threading.Thread(target=self.handle_fail_release_task_loop, daemon=True) - self.fail_to_release_thread.start() - - # 在不使用p2p 复制kv 的方案时,需要全局的传输锁进行控制。这个时候kv传输的效率会下降。 - self.kv_trans_lock = threading.Lock() - - from .decode_trans_obj import KVTransProcess - - self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size - for device_id in range(self.node_world_size): - self.kv_trans_processes[device_id] = KVTransProcess() - assert self.kv_trans_processes[device_id].init_all(device_id, self) - - return - - # ================================================================================== - # _dp_alloc_to_frozen_some_tokens - # _put_kv_received_to_radix_cache - # _fail_to_realese_forzen_tokens - # _unfrozen_time_out_reqs_tokens - # 上述接口都是 kv move manager 与推理进程进行交互的接口,主要用于申请锁定kv资源或者释放 - # kv资源的接口 - # ================================================================================== - - async def wait_all_future_finish(self, futures: List[AsyncResult]): - await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures]) - return - - def 
_dp_alloc_to_frozen_some_tokens(self, dp_tasks: List[List[KVMoveTask]]) -> List[List[Optional[List[int]]]]: - with self.infer_rpyc_lock: - futures = [] - for dp_index in range(self.dp_size_in_node): - conn_start = dp_index * self.dp_world_size - conn_end = (dp_index + 1) * self.dp_world_size - conns = self.infer_rpyc_objs[conn_start:conn_end] - for conn in conns: - futures.append(rpyc.async_(conn.alloc_to_frozen_some_tokens)(dp_tasks[dp_index])) - - asyncio.run(self.wait_all_future_finish(futures)) - ans_values = [ - obtain(futures[dp_index * self.dp_world_size].value) for dp_index in range(self.dp_size_in_node) - ] - return ans_values - - def _put_kv_received_to_radix_cache(self, tasks: List[KVMoveTask]) -> None: - with self.infer_rpyc_lock: - dp_to_tasks = collections.defaultdict(list) - for task in tasks: - dp_to_tasks[task.decode_dp_index].append(task) - futures: List[AsyncResult] = [] - for decode_dp_index, _tasks in dp_to_tasks.items(): - conn_start = decode_dp_index * self.dp_world_size - conn_end = (decode_dp_index + 1) * self.dp_world_size - conns = self.infer_rpyc_objs[conn_start:conn_end] - for conn in conns: - futures.append( - rpyc.async_(conn.put_kv_received_to_radix_cache)([task.group_request_id for task in _tasks]) - ) - asyncio.run(self.wait_all_future_finish(futures)) - return - - def _fail_to_realese_forzen_tokens(self, tasks: List[KVMoveTask]) -> None: - with self.infer_rpyc_lock: - dp_to_tasks = collections.defaultdict(list) - for task in tasks: - dp_to_tasks[task.decode_dp_index].append(task) - futures: List[AsyncResult] = [] - for decode_dp_index, _tasks in dp_to_tasks.items(): - conn_start = decode_dp_index * self.dp_world_size - conn_end = (decode_dp_index + 1) * self.dp_world_size - conns = self.infer_rpyc_objs[conn_start:conn_end] - for conn in conns: - futures.append( - rpyc.async_(conn.fail_to_realese_forzen_tokens)([task.group_request_id for task in _tasks]) - ) - asyncio.run(self.wait_all_future_finish(futures)) - return - - def 
_unfrozen_time_out_reqs_tokens(self) -> None: - # 这个接口比较特殊,可以不区分 dp 的具体模式 - with self.infer_rpyc_lock: - futures: List[AsyncResult] = [] - for conn in self.infer_rpyc_objs: - futures.append(rpyc.async_(conn.unfrozen_time_out_reqs_tokens)()) - asyncio.run(self.wait_all_future_finish(futures)) - return - - # ================================================================================== - # put_to_fail_release_task_queue 将因为一些原因失败,需要释放锁定的kv资源的请求放入到 - # 对应的处理队列中,handle_fail_release_task_loop 是一个循环的线程,专门处理这些失败的请求 - # 通过调用与推理进程交互的接口,释放掉申请锁定的 kv 资源。 - # ================================================================================== - - def put_to_fail_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): - if isinstance(task, KVMoveTask): - self.fail_to_release_queue.put(task) - elif isinstance(task, list): - self.fail_to_release_queue.put_list(task) - else: - assert False, "error input" - return - - def handle_fail_release_task_loop(self): - while True: - handle_list: List[KVMoveTask] = self.fail_to_release_queue.get_tasks(log_tag="fail_to_release_queue") - if len(handle_list) == 0: - time.sleep(0.01) - else: - self._fail_to_realese_forzen_tokens(handle_list) - return - - # ================================================================================== - # on_connect - # on_disconnect - # exposed_check_alive - # exposed_build_trans_process - # exposed_request_data_transfer - # 上述接口是decode kv move manager 暴露的 rpyc 调用接口,用于 prefill kv move manager - # 进行连接,进行一些元数据资源的交互。 - # ================================================================================== - - def on_connect(self, conn): - # 用于处理连接断开的时候,自动删除资源 - thread_local_data.connect_id = None - pass - - def on_disconnect(self, conn): - # 用于处理连接断开的时候,自动删除资源 - if thread_local_data.connect_id is not None: - self.remove_trans_obj(thread_local_data.connect_id) - logger.info(f"connect id {thread_local_data.connect_id} disconnect") - import gc - - gc.collect() - pass - - def exposed_check_alive(self): 
- # 用于 prefill node check 通信连接的状态。 - return - - def exposed_build_trans_connect( - self, prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num, connect_id - ): - prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num = list( - map(obtain, [prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num]) - ) - connect_id = obtain(connect_id) - thread_local_data.connect_id = connect_id - - logger.info(f"build trans infos {prefill_node_id} {pd_prefill_nccl_ip} {pd_prefill_nccl_port} {connect_id}") - - from .decode_trans_obj import KVTransConnectObj - - tran_obj = KVTransConnectObj() - tran_obj.create(connect_id, prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, self) - self.connect_id_to_trans_obj[connect_id] = tran_obj - return min(prefill_node_max_kv_trans_num, self.args.max_total_token_num) - - # 返回 None 代表繁忙, 放弃该任务的 kv 传送 - def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optional[int]]: - tasks: List[KVMoveTask] = obtain(tasks) - alloc_tokened_tasks = [] - ans_list = [] - try: - for task in tasks: - logger.info(f"exposed_request_data_transfer in {task.to_decode_log_info()}, type {type(task)}") - - trans_obj = self.get_trans_obj(tasks[0]) - assert trans_obj is not None - - id_to_test_range = {} - for task in tasks: - test_dp_indexes = list(range(self.dp_size_in_node)) - random.shuffle(test_dp_indexes) - id_to_test_range[task.group_request_id] = test_dp_indexes - - id_has_result = {} - for test_index in range(self.dp_size_in_node): - dp_tasks = [[] for _ in range(self.dp_size_in_node)] - for task in tasks: - if task.group_request_id not in id_has_result: - test_dp_index = id_to_test_range[task.group_request_id][test_index] - dp_tasks[test_dp_index].append(task) - if not all(len(t) == 0 for t in dp_tasks): - dp_tasks_ans = self._dp_alloc_to_frozen_some_tokens(dp_tasks) - for dp_index in range(self.dp_size_in_node): - for task, 
decode_token_indexes in zip(dp_tasks[dp_index], dp_tasks_ans[dp_index]): - if decode_token_indexes is not None: - id_has_result[task.group_request_id] = (dp_index, decode_token_indexes) - for task in tasks: - if task.group_request_id in id_has_result: - task.decode_dp_index = id_has_result[task.group_request_id][0] - task.decode_token_indexes = id_has_result[task.group_request_id][1] - task.move_kv_len = len(task.decode_token_indexes) - ans_list.append(task.move_kv_len) - alloc_tokened_tasks.append(task) - else: - logger.info(f"req id {task.id()} request_data_transfer fail, server is busy") - ans_list.append(None) - - except BaseException as e: - self.put_to_fail_release_task_queue(alloc_tokened_tasks) - alloc_tokened_tasks = [] - self.remove_trans_obj(tasks[0].connect_id) - logger.exception(str(e)) - raise e - - if alloc_tokened_tasks: - trans_obj.ready_to_move_queue.put( - alloc_tokened_tasks, error_handle_func=self.put_to_fail_release_task_queue - ) - - return ans_list - - # ================================================================================== - # 定时检测kv 传输成功,但是长时间没有pd master来触发推理的请求, - # 释放这些超时请求占用的kv资源 - # ================================================================================== - - def timer_loop(self): - try: - while True: - self._unfrozen_time_out_reqs_tokens() - time.sleep(3.5) - except (BaseException, RuntimeError) as e: - logger.exception(str(e)) - raise e - - # ================================================================================== - # 定时检测传输进程的健康状态,出现问题拉崩整个系统触发重启 - # ================================================================================== - - def check_trans_process_loop(self): - try: - while True: - for device_id in range(self.node_world_size): - if not self.kv_trans_processes[device_id].is_trans_process_health(): - raise Exception(f"device_id {device_id} kv process is unhealth") - - time.sleep(10.0) - except (BaseException, RuntimeError) as e: - logger.exception(str(e)) - - for device_id in 
range(self.node_world_size): - self.kv_trans_processes[device_id].killself() - - # 杀掉当前进程的父进程(router), 触发全局崩溃 - os.kill(os.getppid(), signal.SIGKILL) - os.kill(os.getpid(), signal.SIGKILL) - raise e - - # ================================================================================== - # 常用辅助功能函数 - # ================================================================================== - def get_next_device_index(self): - counts = [0 for _ in range(self.node_world_size)] - for obj in self.connect_id_to_trans_obj.values(): - counts[obj.device_index] += 1 - device_index = int(np.argmin(counts)) - return device_index - - def get_trans_obj(self, task: KVMoveTask): - self.__remove_dead_trans_obj() - return self.connect_id_to_trans_obj[task.connect_id] - - def __remove_dead_trans_obj(self): - del_connect_ids = [] - for connect_id, t_obj in self.connect_id_to_trans_obj.items(): - if t_obj.has_error_status(): - del_connect_ids.append(connect_id) - - for connect_id in del_connect_ids: - self.connect_id_to_trans_obj.pop(connect_id, None) - - if del_connect_ids: - import gc - - gc.collect() - return - - def remove_trans_obj(self, connect_id): - if connect_id in self.connect_id_to_trans_obj: - trans_obj = self.connect_id_to_trans_obj.pop(connect_id, None) - if trans_obj is not None: - trans_obj.set_has_error() - return - - -def _init_env(args, info_queue: mp.Queue, event: mp.Event): - import lightllm.utils.rpyc_fix_utils as _ - - # 注册graceful 退出的处理 - graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::decode_kv_move_manager") - - manager = DecodeKVMoveManager(args, info_queue) - t = ThreadedServer(manager, port=args.pd_decode_rpyc_port, protocol_config={"allow_pickle": True}) - threading.Thread(target=lambda: t.start(), daemon=True).start() - - kv_trans_process_check = threading.Thread(target=manager.check_trans_process_loop, daemon=True) - kv_trans_process_check.start() - - event.set() - 
manager.timer_loop() - return - - -def start_decode_kv_move_manager_process(args, info_queue: mp.Queue): - event = mp.Event() - proc = mp.Process(target=_init_env, args=(args, info_queue, event)) - proc.start() - event.wait() - assert proc.is_alive() - logger.info("decode kv move manager process started") - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_task_cache.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_task_cache.py deleted file mode 100644 index 48df4b86fb..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_task_cache.py +++ /dev/null @@ -1,10 +0,0 @@ -# 这个里面声明了一个全局变量,主要用于推理进程缓存发送给其他进程的Kv move 任务的缓存数据 -# 为了减少一些调用时候的序列化开销。有些调用就只需要传输一个请求id就可以了,不用传输特别的 -# 数据了,提升rpyc 调用的速度, 只用在 decode_impl.py 和 decode_infer_rpyc.py 文件中 -from typing import Dict, List, Tuple -from lightllm.server.pd_io_struct import KVMoveTask -from lightllm.server.router.dynamic_prompt.radix_cache import TreeNode - -g_kv_move_task_cache: Dict[int, Tuple[KVMoveTask, TreeNode, List[int]]] = {} - -g_success_kv_move_task_cache: Dict[int, Tuple[KVMoveTask, TreeNode, float]] = {} # 第三个float代表的是时间,用于判断过期条件。 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py deleted file mode 100644 index 939f065fb6..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py +++ /dev/null @@ -1,305 +0,0 @@ -import time -import psutil -import threading -from typing import List -from dataclasses import dataclass -from lightllm.utils.log_utils import init_logger -from ..task_queue import TaskQueue -import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import KVMoveTask, 
UpKVStatus, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup -from lightllm.utils.device_utils import kv_trans_use_p2p -from .decode_kv_move_manager import DecodeKVMoveManager -from lightllm.utils.time_utils import TimeChecker -from ..utils import join_if_alive, clear_queue - -logger = init_logger(__name__) - -KV_MOVE_MAX_NUM = 16 - - -@dataclass -class KVTransConnectObj: - connect_id: str = None - prefill_node_id: int = None - kv_trans_process: "KVTransProcess" = None - pd_prefill_nccl_ip: str = None - pd_prefill_nccl_port: int = None - device_index: int = None - manager: "DecodeKVMoveManager" = None - has_error: bool = False - ready_to_move_queue: TaskQueue = None - kv_move_thread: threading.Thread = None - move_finished_queue: TaskQueue = None - put_to_radix_thread: threading.Thread = None - timer_checker: TimeChecker = None - - def create( - self, - connect_id: str, - prefill_node_id: str, - pd_prefill_nccl_ip: str, - pd_prefill_nccl_port: int, - manager: "DecodeKVMoveManager", - ): - self.connect_id = connect_id - self.device_index = manager.get_next_device_index() - self.kv_trans_process = manager.kv_trans_processes[self.device_index] - decode_node_id = manager.args.pd_node_id - self.prefill_node_id = prefill_node_id - self.decode_node_id = decode_node_id - self.pd_prefill_nccl_ip = pd_prefill_nccl_ip - self.pd_prefill_nccl_port = pd_prefill_nccl_port - - self.manager = manager - self.timer_checker = TimeChecker(6) - - with self.kv_trans_process.device_lock: - clear_queue(self.kv_trans_process.task_out_queue) - self.kv_trans_process.task_in_queue.put( - PDTransJoinInfo( - prefill_id=prefill_node_id, - prefill_device_id=-1, - pd_prefill_nccl_ip=pd_prefill_nccl_ip, - pd_prefill_nccl_port=pd_prefill_nccl_port, - decode_id=decode_node_id, - decode_device_id=self.device_index, - connect_id=self.connect_id, - ) - ) - assert self.kv_trans_process.task_out_queue.get(timeout=60) == "nccl_ok" - - self.ready_to_move_queue = TaskQueue( - get_func=lambda datas: 
datas[0:1], fail_func=self.manager.put_to_fail_release_task_queue - ) - self.kv_move_thread = threading.Thread(target=self.kv_move_loop, daemon=True) - self.kv_move_thread.start() - - self.move_finished_queue = TaskQueue( - get_func=lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=self.manager.put_to_fail_release_task_queue - ) - self.put_to_radix_thread = threading.Thread(target=self.put_to_radix_loop, daemon=True) - self.put_to_radix_thread.start() - return - - # ================================================================================== - # 处理接受所有进行 kv 传输的请求,完成后,将请求放入到 move_finished_queue 中 - # ================================================================================== - - def _transfer_kv(self, move_tasks: List[KVMoveTask]): - with self.kv_trans_process.device_lock: - clear_queue(self.kv_trans_process.task_out_queue) - kv_move_group = KVMoveTaskGroup(tasks=move_tasks.copy(), connect_id=self.connect_id) - kv_move_group.connect_id = self.connect_id - self.kv_trans_process.task_in_queue.put(kv_move_group, timeout=10) - assert self.kv_trans_process.task_out_queue.get(timeout=60) == "ok" - logger.info(f"_transfer_kv ok {move_tasks[0].to_decode_log_info()}") - - # 标记 decode 接收到 kv cache 的时间 - for move_task in move_tasks: - move_task.mark_start_time = time.time() - - self.move_finished_queue.put_list(move_tasks) - move_tasks.clear() - - def kv_move_loop(self): - func_name = self.kv_move_loop.__name__ - while not self.has_error: - move_tasks: List[List[KVMoveTask]] = self.ready_to_move_queue.get_tasks(log_tag="ready_to_move_queue") - if len(move_tasks) == 0: - time.sleep(0.01) - continue - - if len(move_tasks) != 1: - logger.error(f"error get need 1, but get {len(move_tasks)}") - assert False - - move_tasks: List[KVMoveTask] = move_tasks[0] - for task in move_tasks: - logger.info(f"{func_name} get task {task.to_decode_log_info()}") - - try: - self.timer_to_check_status(raise_exception=True) - if not kv_trans_use_p2p(): - with self.manager.kv_trans_lock: 
- self._transfer_kv(move_tasks) - else: - self._transfer_kv(move_tasks) - - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.ready_to_move_queue.clear_tasks() - - finally: - self.manager.put_to_fail_release_task_queue(move_tasks) - - logger.error(f"{func_name} thread quit") - return - - # ================================================================================== - # 将传输完成的请求,放入到 radix cache 中进行管理。 - # ================================================================================== - - def put_to_radix_loop(self): - func_name = self.put_to_radix_loop.__name__ - while not self.has_error: - move_tasks: List[KVMoveTask] = self.move_finished_queue.get_tasks(log_tag="move_finished_queue") - if len(move_tasks) == 0: - time.sleep(0.01) - continue - - for task in move_tasks: - logger.info(f"{func_name} get put radix task {task.to_decode_log_info()}") - - try: - self.timer_to_check_status(raise_exception=True) - # random to check stats - self.manager._put_kv_received_to_radix_cache(move_tasks.copy()) - for task in move_tasks.copy(): - logger.info( - f"{func_name} put kv to radix cache ok, req_id: {task.id()} cost_time {task.get_cost_time()} s" - ) - self.manager.up_status_in_queue.put( - UpKVStatus( - group_request_id=task.group_request_id, - dp_index=task.decode_dp_index, - pd_master_node_id=task.pd_master_node_id, - ) - ) - logger.info(f"{func_name} up kv status req_id: {task.id()} finished") - move_tasks.clear() - - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.move_finished_queue.clear_tasks() - - finally: - self.manager.put_to_fail_release_task_queue(move_tasks) - - logger.error(f"{func_name} thread quit, info: {self.to_log_info()}") - return - - # ================================================================================== - # 错误处理检测操作的一些通用函数 - # ================================================================================== - - def timer_to_check_status(self, 
raise_exception=True): - if self.timer_checker.has_exceeded(): - try: - assert self.kv_trans_process.is_trans_process_health() - except BaseException as e: - logger.error(f"pid {self.kv_trans_process.process.pid} check failed") - logger.exception(str(e)) - - self.set_has_error() - if raise_exception: - raise e - return - - def has_error_status(self): - try: - assert self.has_error is False - assert self.kv_move_thread.is_alive() - assert self.put_to_radix_thread.is_alive() - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - return True - - return False - - def set_has_error(self): - self.has_error = True - - if self.ready_to_move_queue is not None: - self.ready_to_move_queue.has_error = True - - if self.move_finished_queue is not None: - self.move_finished_queue.has_error = True - - if self.manager is not None: - self.manager.remove_trans_obj(self.connect_id) - return - - def __del__(self): - logger.error(f"trans obj del start, info: {self.to_log_info()}") - - try: - self.set_has_error() - - join_if_alive(self.kv_move_thread) - join_if_alive(self.put_to_radix_thread) - - if self.connect_id is not None and self.kv_trans_process is not None: - self.kv_trans_process.task_in_queue.put( - PDTransLeaveInfo( - decode_id=self.decode_node_id, prefill_id=self.prefill_node_id, connect_id=self.connect_id - ) - ) - - if self.ready_to_move_queue is not None: - self.ready_to_move_queue.clear_tasks() - if self.move_finished_queue is not None: - self.move_finished_queue.clear_tasks() - - except BaseException as e: - logger.exception(str(e)) - - logger.error(f"trans obj deled, info: {self.to_log_info()}") - - def to_log_info(self): - log = f"connect_id: {self.connect_id} " - log += f"decode_node_id: {self.decode_node_id} " - log += f"prefill_node_id: {self.prefill_node_id} " - log += f"device_index: {self.device_index} " - return log - - -@dataclass -class KVTransProcess: - process: mp.Process = None - # 需要每个卡有一个锁来规划每次只能有一个 connection obj 操作对应显卡上的传输任务。 - 
device_lock: threading.Lock = None - task_in_queue: mp.Queue = None - task_out_queue: mp.Queue = None - device_id: int = None - - def init_all(self, device_id: int, manager: "DecodeKVMoveManager"): - self.device_lock = threading.Lock() - self.device_id = device_id - self.task_in_queue = mp.Queue() - self.task_out_queue = mp.Queue() - - try: - from .decode_trans_process import start_decode_trans_process - - self.process = start_decode_trans_process( - manager.args, - device_id, - self.task_in_queue, - self.task_out_queue, - ) - assert self.task_out_queue.get(timeout=30) == "proc_start" - assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" - - return True - - except Exception as e: - logger.warning(f"Failed start kv trans process for device {device_id}: {e}") - logger.exception(str(e)) - return False - - def is_trans_process_health(self): - try: - process = psutil.Process(self.process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - logger.error(f"kv trans process for device: {self.device_id} dead!!!") - return False - else: - return True - except: - return False - - def killself(self): - self.process.kill() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py deleted file mode 100644 index cdca638873..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ /dev/null @@ -1,155 +0,0 @@ -import torch -import time -import sys -import inspect -import threading -import setproctitle -import torch.multiprocessing as mp -from torch.distributed import TCPStore -from datetime import timedelta -from typing import List, Dict, Union -from lightllm.utils.log_utils import init_logger -from lightllm.common.kv_cache_mem_manager import MemoryManager -from 
lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup -from lightllm.utils.device_utils import kv_trans_use_p2p -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.distributed.pynccl import PyNcclCommunicator, StatelessP2PProcessGroup -from lightllm.utils.envs_utils import get_unique_server_name - -logger = init_logger(__name__) - - -def _handle_kvmove_task( - move_tasks: List[KVMoveTask], - task_out_queue: mp.Queue, - mem_managers: List[MemoryManager], - connect_id_to_comm: Dict[str, PyNcclCommunicator], - connect_id: str, - dp_size_in_node: int, -): - total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) - try: - device_index = connect_id_to_comm[connect_id].device.index - start = time.time() - if total_move_kv_len != 0: - cur_mem = mem_managers[device_index] - logger.info(f"trans start: {move_tasks[0].to_decode_log_info()}") - if kv_trans_use_p2p(): - cur_mem.receive_from_prefill_node_p2p( - move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] - ) - else: - cur_mem.receive_from_prefill_node( - move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] - ) - logger.info(f"trans finished: {move_tasks[0].to_decode_log_info()} move len: {total_move_kv_len}") - torch.cuda.synchronize() - logger.info(f"trans cost time: {(time.time() - start)}, {move_tasks[0].to_decode_log_info()}") - task_out_queue.put("ok") - except BaseException as e: - logger.exception(str(e)) - task_out_queue.put("fail") - raise e - - -def _handle_prefill_join( - node_info: PDTransJoinInfo, task_out_queue: mp.Queue, connect_id_to_comm: Dict[str, PyNcclCommunicator] -): - try: - logger.info(f"connect start {node_info}") - store_client = TCPStore( - host_name=node_info.pd_prefill_nccl_ip, - port=node_info.pd_prefill_nccl_port, - is_master=False, - use_libuv=True, - timeout=timedelta(seconds=30), - ) - src_id = node_info.prefill_id - dest_id = node_info.connect_id - 
logger.info(f"connect src_id {src_id} dest_id {dest_id}") - - result_list = [] - - def async_connect(): - torch.cuda.set_device(node_info.decode_device_id) - group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=False, store=store_client) - comm = PyNcclCommunicator(group, node_info.decode_device_id) - result_list.append(comm) - return - - connect_task = threading.Thread(target=async_connect, daemon=True) - connect_task.start() - connect_task.join(timeout=36) - if connect_task.is_alive(): - raise Exception(f"{node_info} connect time out") - - connect_id_to_comm[node_info.connect_id] = result_list[0] - logger.info(f"{node_info} kv trans connected") - task_out_queue.put("nccl_ok") - except Exception as e: - task_out_queue.put("nccl_fail") - logger.warning(f"error while connect to prefill node: {e}") - - -def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue): - import os - - # os.environ["NCCL_DEBUG"] = "INFO" - os.environ["NCCL_MAX_NCHANNELS"] = "2" - os.environ["NCCL_NSOCKS_PER_CHANNEL"] = "1" - os.environ["NCCL_SOCKET_NTHREADS"] = "1" - torch.backends.cudnn.enabled = False - - dp_size_in_node = max(1, args.dp // args.nnodes) - - setproctitle.setproctitle( - f"lightllm::{get_unique_server_name()}::decode_trans:Device{device_id}_DpSizeInNode{dp_size_in_node}" - ) - - try: - torch.cuda.set_device(device_id) - graceful_registry(inspect.currentframe().f_code.co_name) - task_out_queue.put("proc_start") - - # 从共享内存读取所有rank的mem_manager - node_world_size = args.tp // args.nnodes - mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) - ] - - task_out_queue.put("get_mem_managers_ok") - connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} - while True: - task: Union[KVMoveTaskGroup, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get() - if isinstance(task, KVMoveTaskGroup): - _handle_kvmove_task( - task.tasks, task_out_queue, mem_managers, 
connect_id_to_comm, task.connect_id, dp_size_in_node - ) - elif isinstance(task, PDTransJoinInfo): - _handle_prefill_join(task, task_out_queue, connect_id_to_comm) - elif isinstance(task, PDTransLeaveInfo): - if task.connect_id in connect_id_to_comm: - connect_id_to_comm[task.connect_id].destroy() - logger.info(f"destory {task} nccl communicator.") - else: - logger.info(f"no connect_id {task.connect_id} found in connect_id_to_comm") - - else: - logger.warning(f"unexpected task type: {task}") - - except Exception as e: - logger.error(f"Fatal error happened in kv trans process: {e} in device {device_id}") - raise - - -def start_decode_trans_process( - args, - device_id: int, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, -): - proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue)) - proc.start() - assert proc.is_alive() - logger.info(f"decode trans kv process for device: {device_id} start!") - return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py deleted file mode 100644 index 833ffecc89..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py +++ /dev/null @@ -1,127 +0,0 @@ -import time -import json -import asyncio -import threading -import websockets -import inspect -import setproctitle -import pickle - -from typing import Dict -from dataclasses import asdict -from lightllm.server.pd_io_struct import UpKVStatus -from lightllm.utils.log_utils import init_logger -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.server.pd_io_struct import PD_Master_Obj -import torch.multiprocessing as mp -from lightllm.utils.envs_utils import get_unique_server_name - -logger = init_logger(__name__) - - -class UpStatusManager: - def __init__(self, args, task_in_queue: mp.Queue, 
task_out_queue: mp.Queue): - self.args = args - self.task_queue: mp.Queue[UpKVStatus] = task_in_queue - self.task_out_queue = task_out_queue - self.daemon_thread = threading.Thread(target=self.thread_loop, daemon=True) - self.daemon_thread.start() - - def thread_loop(self): - asyncio.run(self.task_loop()) - - async def task_loop(self): - - self.id_to_handle_task: Dict[int, asyncio.Task] = {} - self.id_to_handle_queue: Dict[int, asyncio.Queue] = {} - - asyncio.create_task(self.dispatch_task_loop()) - - while True: - try: - from lightllm.server.httpserver.pd_loop import _get_pd_master_objs - - id_to_pd_master_obj = await _get_pd_master_objs(self.args) - logger.info(f"get pd_master_objs {id_to_pd_master_obj}") - - if id_to_pd_master_obj is not None: - for node_id, pd_master_obj in self.id_to_handle_task.items(): - if node_id not in id_to_pd_master_obj: - self.id_to_handle_task[node_id].cancel() - self.id_to_handle_task.pop(node_id, None) - self.id_to_handle_queue.pop(node_id, None) - logger.info(f"up_kv_status_task {pd_master_obj} cancelled") - - for node_id, pd_master_obj in id_to_pd_master_obj.items(): - if node_id not in self.id_to_handle_task: - self.id_to_handle_queue[node_id] = asyncio.Queue() - self.id_to_handle_task[node_id] = asyncio.create_task(self.up_kv_status_task(pd_master_obj)) - - await asyncio.sleep(30) - - except Exception as e: - logger.exception(str(e)) - await asyncio.sleep(10) - - async def dispatch_task_loop(self): - while True: - try: - loop = asyncio.get_event_loop() - upkv_status: UpKVStatus = await loop.run_in_executor(None, self.task_queue.get) - if upkv_status.pd_master_node_id in self.id_to_handle_queue: - await self.id_to_handle_queue[upkv_status.pd_master_node_id].put(upkv_status) - else: - logger.warning(f"upstatus {upkv_status} no connection to pd_master, drop it") - except BaseException as e: - logger.exception(str(e)) - await asyncio.sleep(10) - - async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj): - while True: - try: 
- uri = f"ws://{pd_master_obj.host_ip_port}/kv_move_status" - async with websockets.connect(uri) as websocket: - import socket - - sock = websocket.transport.get_extra_info("socket") - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - while True: - try: - if pd_master_obj.node_id in self.id_to_handle_queue: - task_queue = self.id_to_handle_queue[pd_master_obj.node_id] - upkv_status: UpKVStatus = await task_queue.get() - await websocket.send(pickle.dumps(upkv_status)) - logger.info(f"up status: {upkv_status}") - else: - await asyncio.sleep(3) - except BaseException as e: - logger.error(str(e)) - raise e - except asyncio.CancelledError: - logger.info(f"up_kv_status_task {pd_master_obj} cancelled") - return - - except Exception as e: - logger.error(f"connetion to pd_master {pd_master_obj} has error: {str(e)}") - logger.exception(str(e)) - await asyncio.sleep(10) - logger.info("reconnection to pd_master") - - -def _init_env(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue): - graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::up_kv_status") - up_kv_manager = UpStatusManager(args, task_in_queue, task_out_queue) - logger.info(f"up kv manager {str(up_kv_manager)} start ok") - while True: - time.sleep(666) - return - - -def start_up_kv_status_process(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue): - proc = mp.Process(target=_init_env, args=(args, task_in_queue, task_out_queue)) - proc.start() - assert proc.is_alive() - logger.info("up_kv_status_process start") - return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py deleted file mode 100644 index 8e7bddc64e..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py +++ 
/dev/null @@ -1,121 +0,0 @@ -import os -import time -import threading -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from typing import List, Tuple -from lightllm.server.router.model_infer.infer_batch import InferReq -from lightllm.server.pd_io_struct import KVMoveTask, DecodeNodeInfo -from lightllm.utils.log_utils import init_logger -from lightllm.common.basemodel.infer_lock import g_router_lock, g_infer_state_lock -from rpyc.utils.server import ThreadedServer -from .prefill_task_cache import g_kv_move_task_cache -from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.utils.dist_utils import create_new_group_for_current_dp -from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ChunkedPrefillBackend - -logger = init_logger(__name__) - - -class ChunckedPrefillForPrefillNode(ChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue) -> None: - super().__init__() - self.support_overlap = False - self.info_queue: mp.Queue = info_queue - self.classed_req_no_decode = True - - def init_custom(self): - - self.lock_nccl_group = create_new_group_for_current_dp("gloo") - logger.info(f"lock_nccl_group ranks {dist.get_rank(self.lock_nccl_group)}") - - from .prefill_infer_rpyc import PDPrefillInferRpcServer - - socket_path = f"/tmp/{get_unique_server_name()}_prefill_node_infer_rpyc_{self.pd_rpyc_ports[self.rank_in_node]}" - if os.path.exists(socket_path): - os.remove(socket_path) - - t = ThreadedServer( - PDPrefillInferRpcServer(self), socket_path=socket_path, protocol_config={"allow_pickle": True} - ) - threading.Thread(target=lambda: t.start(), daemon=True).start() - return - - def _pre_handle_finished_reqs(self, finished_reqs): - self._prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(finished_reqs=finished_reqs) - return - - def _prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(self, finished_reqs: List[InferReq]): - if len(finished_reqs) == 0: - return - - # 提前在radix 
cache中回收相关的信息,并添加引用进行锁定,方便传输进程传输kv。 - if self.is_master_in_dp: - logger.info("prefill_req_handle_and_frozen_tokens") - - g_infer_state_lock.acquire() - try: - for req in finished_reqs: - - # 区分abort 和 正常结束的请求,正常结束的请求才发起kv传输任务。 - if not req.finish_status.is_finished(): - continue - - req: InferReq = req - key = req.get_input_token_ids()[0 : req.cur_kv_len] - key = torch.tensor(key, dtype=torch.int64, device="cpu") - value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - prefix_len, new_shared_kv_node = self.radix_cache.insert(key, value) - old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len - self.model.mem_manager.free( - self.model.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len] - ) - # 将原有共享节点替换为新共享节点,新共享节点对应的长度为当前的cur_kv_len - - self.radix_cache.dec_node_ref_counter(req.shared_kv_node) - self.radix_cache.add_node_ref_counter(new_shared_kv_node) - req.shared_kv_node = new_shared_kv_node - - _kv_len = req.cur_kv_len - _value = self.radix_cache.get_mem_index_value_by_node(new_shared_kv_node) - assert len(_value) == _kv_len - self.model.req_manager.req_to_token_indexs[req.req_idx][0:_kv_len] = _value - - assert new_shared_kv_node.node_prefix_total_len == req.cur_kv_len - - if req.shm_req.sample_params.move_kv_to_decode_node.exists: - # 注意兼容纯tp 和 tp dp 混合模式的逻辑 - if self.is_master_in_dp: - g_router_lock.acquire() - self.shared_token_load.add_frozened_token_count(len(key), self.dp_rank_in_node) - g_router_lock.release() - - share_node, kv_len, value = self.radix_cache.match_prefix(key, update_refs=True) - assert len(key) == len(value) - # 将下面的请求放入到任务队列中, 注意要使用raidx cache 返回的value - decode_node_info = DecodeNodeInfo(**req.shm_req.sample_params.move_kv_to_decode_node.to_dict()) - task = KVMoveTask( - group_request_id=req.shm_req.group_req_id, - input_tokens=key.tolist(), - prefill_token_indexes=value.tolist(), - decode_token_indexes=None, - 
prefill_node_id=self.args.pd_node_id, - decode_node=decode_node_info, - move_kv_len=None, - prefill_dp_index=self.dp_rank_in_node, - decode_dp_index=None, - pd_master_node_id=req.shm_req.sample_params.pd_master_node_id.get(), - mark_start_time=time.time(), - ) - g_kv_move_task_cache[task.group_request_id] = (task, share_node) - - # 注意兼容纯 tp 和 tp dp 混合模式的逻辑 - if self.is_master_in_dp: - self.info_queue.put(task) - except BaseException as e: - logger.exception(str(e)) - g_infer_state_lock.release() - if self.is_master_in_dp: - logger.info("prefill_req_handle_and_frozen_tokens end") - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py deleted file mode 100644 index 2897f71412..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch.multiprocessing as mp -from typing import List, Tuple -from lightllm.server.router.model_infer.infer_batch import InferReq -from lightllm.utils.log_utils import init_logger -from .prefill_impl import ChunckedPrefillForPrefillNode -from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend - -logger = init_logger(__name__) - - -class DPChunkedForPrefillNode(DPChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue) -> None: - super().__init__() - self.support_overlap = False - self.info_queue: mp.Queue = info_queue - self.classed_req_no_decode = True - - def init_custom(self): - ChunckedPrefillForPrefillNode.init_custom(self) - return - - def _pre_handle_finished_reqs(self, finished_reqs): - self._prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(finished_reqs=finished_reqs) - return - - def _prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(self, finished_reqs: 
List[InferReq]): - ChunckedPrefillForPrefillNode._prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue( - self, finished_reqs=finished_reqs - ) - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py deleted file mode 100644 index 1f2dd52c5a..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -import torch.distributed as dist -import rpyc -from typing import Dict, List, Tuple -from rpyc.utils.classic import obtain -from .prefill_impl import ChunckedPrefillForPrefillNode -from lightllm.common.basemodel.infer_lock import g_router_lock, acquire_lock_until_ready, release_acquired_lock -from .prefill_task_cache import g_kv_move_task_cache -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class PDPrefillInferRpcServer(rpyc.Service): - def __init__(self, backend: ChunckedPrefillForPrefillNode) -> None: - super().__init__() - self.backend = backend - self.device_id = self.backend.current_device_id - self.dp_rank_in_node = self.backend.dp_rank_in_node - self.is_master_in_dp = self.backend.is_master_in_dp - return - - def on_connect(self, conn): - torch.cuda.set_device(f"cuda:{self.device_id}") - return - - # pd 分离模式会使用的一些接口,用于做一些全局信息管理 - def exposed_remove_req_refs_from_prompt_cache(self, group_req_ids: List[int]): - group_req_ids = obtain(group_req_ids) - acquire_lock_until_ready(self.backend.lock_nccl_group) - for group_req_id in group_req_ids: - if group_req_id in g_kv_move_task_cache: - task, share_node = g_kv_move_task_cache.pop(group_req_id) - if share_node is not None: - self.backend.radix_cache.dec_node_ref_counter(share_node) - # 减少日志数量 - if self.is_master_in_dp: - logger.info(f"unfrozen tokens for req id: 
{group_req_id}") - - # 更新调度元数据 - if self.is_master_in_dp: - with g_router_lock.obj: - self.backend.shared_token_load.add_frozened_token_count( - -len(task.input_tokens), self.dp_rank_in_node - ) - release_acquired_lock() - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py deleted file mode 100644 index bd5af98ee6..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ /dev/null @@ -1,241 +0,0 @@ -import asyncio -import time -import rpyc -import sys -import os -import gc -import signal -import copy -import numpy as np -import psutil -import threading -import inspect -import collections -import setproctitle -from typing import List, Dict, Union -from lightllm.utils.log_utils import init_logger -from .prefill_infer_rpyc import PDPrefillInferRpcServer -import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import KVMoveTask -from lightllm.utils.retry_utils import retry -from rpyc import AsyncResult -from lightllm.utils.net_utils import get_hostname_ip -from ..task_queue import TaskQueue -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.utils.envs_utils import get_unique_server_name - -KV_MOVE_MAX_NUM = 16 - -logger = init_logger(__name__) - - -class PrefillKVMoveManager: - def __init__(self, args, info_queue: mp.Queue): - self.args = args - # args.dp // args.nnodes 在跨机tp的场景下,可能为0 - self.dp_size_in_node = max(1, args.dp // args.nnodes) - self.node_world_size = args.tp // args.nnodes - self.dp_world_size = args.tp // args.dp - # 不支持跨机tp的pd 分离策略 - assert self.dp_world_size <= self.node_world_size - - self.info_queue = info_queue - self.infer_rpyc_objs: List[PDPrefillInferRpcServer] = [] - - from .prefill_trans_obj import KVTransConnectObj - - 
self.connect_id_to_trans_obj: Dict[str, KVTransConnectObj] = {} - - for port in self.args.pd_node_infer_rpyc_ports: - socket_path = f"/tmp/{get_unique_server_name()}_prefill_node_infer_rpyc_{port}" - from rpyc.utils.factory import unix_connect - - con = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) - self.infer_rpyc_objs.append(con.root) - logger.info(f"rpyc connect to infer rpyc port: {port} ok") - self.host_ip = get_hostname_ip() - if self.host_ip is None: - self.host_ip = args.host - - self.infer_rpyc_lock = threading.Lock() - - self.kv_trans_lock = threading.Lock() - # 释放token的task队列 - self.release_task_queue = TaskQueue(lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=None) - self.release_tasks_thread = threading.Thread(target=self.handle_release_task_loop, daemon=True) - self.release_tasks_thread.start() - - from .prefill_trans_obj import KVTransProcess - - self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size - for device_id in range(self.node_world_size): - self.kv_trans_processes[device_id] = KVTransProcess() - assert self.kv_trans_processes[device_id].init_all(device_id, self) - - return - - # ================================================================================== - # 主任务循环,接收需要进行kv传输的请求进行处理 - # ================================================================================== - - def task_dispatcher_loop(self): - try: - # 获取任务,并分发给相关卡的处理队列 - while True: - move_task: KVMoveTask = self.info_queue.get() - try: - trans_obj = self.__get_trans_obj(move_task) - trans_obj.request_kv_trans_task_queue.put(move_task) - except BaseException as e: - logger.exception(str(e)) - self.put_to_release_task_queue(move_task) - finally: - trans_obj = None - - except (BaseException, RuntimeError) as e: - logger.exception(str(e)) - raise e - - # ================================================================================== - # 请求出错或者完成kv传输后的处理队列和线程loop - # 
================================================================================== - - def put_to_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): - if isinstance(task, KVMoveTask): - self.release_task_queue.put(task) - elif isinstance(task, list): - self.release_task_queue.put_list(task) - else: - logger.error("error input in put_to_release_task_queue func") - return - - def handle_release_task_loop(self): - while True: - handle_list: List[KVMoveTask] = self.release_task_queue.get_tasks(log_tag="release_task_queue") - if len(handle_list) == 0: - time.sleep(0.01) - else: - self._remove_req_refs_from_prompt_cache(handle_list) - return - - # ================================================================================== - # 定时检测传输进程的健康状态,出现问题拉崩整个系统触发重启 - # ================================================================================== - - def check_trans_process_loop(self): - try: - while True: - for device_id in range(self.node_world_size): - if not self.kv_trans_processes[device_id].is_trans_process_health(): - raise Exception(f"device_id {device_id} kv process is unhealth") - - time.sleep(10.0) - except (BaseException, RuntimeError) as e: - logger.exception(str(e)) - - for device_id in range(self.node_world_size): - self.kv_trans_processes[device_id].killself() - - # 杀掉当前进程的父进程(router), 触发全局崩溃 - os.kill(os.getppid(), signal.SIGKILL) - os.kill(os.getpid(), signal.SIGKILL) - raise e - - # ================================================================================== - # 与推理进程交互接口, _remove_req_refs_from_prompt_cache - # ================================================================================== - - def _remove_req_refs_from_prompt_cache(self, tasks: List[KVMoveTask]): - with self.infer_rpyc_lock: - dp_to_tasks = collections.defaultdict(list) - for task in tasks: - dp_to_tasks[task.prefill_dp_index].append(task) - futures: List[AsyncResult] = [] - for prefill_dp_index, _tasks in dp_to_tasks.items(): - conn_start = prefill_dp_index 
* self.dp_world_size - conn_end = (prefill_dp_index + 1) * self.dp_world_size - conns = self.infer_rpyc_objs[conn_start:conn_end] - for conn in conns: - futures.append( - rpyc.async_(conn.remove_req_refs_from_prompt_cache)([task.group_request_id for task in _tasks]) - ) - asyncio.run(self.wait_all_future_finish(futures)) - return - - async def wait_all_future_finish(self, futures: List[AsyncResult]): - await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures]) - return - - # ================================================================================== - # 辅助功能接口 - # ================================================================================== - - def get_next_device_index(self): - counts = [0 for _ in range(self.node_world_size)] - for obj in self.connect_id_to_trans_obj.values(): - counts[obj.device_index] += 1 - device_index = int(np.argmin(counts)) - return device_index - - def remove_trans_obj(self, connect_id): - if connect_id in self.connect_id_to_trans_obj: - trans_obj = self.connect_id_to_trans_obj.pop(connect_id, None) - if trans_obj is not None: - trans_obj.set_has_error() - logger.error(f"remove tran obj decode_node_id {trans_obj.decode_node_id}") - return - - def __get_trans_obj(self, task: KVMoveTask): - self.__remove_dead_trans_obj() - # 如果已经存在连接对象,直接返回 - for obj in self.connect_id_to_trans_obj.values(): - if obj.decode_node_id == task.decode_node.node_id: - return obj - - # 如果不存在连接对象,创建新的连接对象 - gc.collect() - from .prefill_trans_obj import KVTransConnectObj - - trans_obj = KVTransConnectObj() - trans_obj.create(task.decode_node.node_id, task.decode_node.ip, task.decode_node.rpyc_port, self) - self.connect_id_to_trans_obj[trans_obj.connect_id] = trans_obj - return trans_obj - - def __remove_dead_trans_obj(self): - del_connect_ids = [] - for connect_id, t_obj in self.connect_id_to_trans_obj.items(): - if t_obj.has_error_status(): - del_connect_ids.append(connect_id) - - for connect_id in del_connect_ids: - 
self.connect_id_to_trans_obj.pop(connect_id, None) - - if del_connect_ids: - gc.collect() - return - - -def _init_env(args, info_queue: mp.Queue, event: mp.Event): - import lightllm.utils.rpyc_fix_utils as _ - - # 注册graceful 退出的处理 - graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::prefill_kv_move_manager") - - manager = PrefillKVMoveManager(args, info_queue) - kv_trans_process_check = threading.Thread(target=manager.check_trans_process_loop, daemon=True) - kv_trans_process_check.start() - event.set() - # 进入主循环 - manager.task_dispatcher_loop() - return - - -def start_prefill_kv_move_manager_process(args, info_queue: mp.Queue): - event = mp.Event() - proc = mp.Process(target=_init_env, args=(args, info_queue, event)) - proc.start() - event.wait() - assert proc.is_alive() - logger.info("prefill kv move manager process started") - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_task_cache.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_task_cache.py deleted file mode 100644 index afa8e87f44..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_task_cache.py +++ /dev/null @@ -1,8 +0,0 @@ -# 这个里面声明了一个全局变量,主要用于推理进程缓存发送给其他进程的Kv move 任务的缓存数据 -# 为了减少一些调用时候的序列化开销。有些调用就只需要传输一个请求id就可以了,不用传输特别的 -# 数据了,提升rpyc 调用的速度, 只用在 prefill_impl.py 和 prefill_infer_rpyc.py 文件中 -from typing import Dict, Tuple -from lightllm.server.pd_io_struct import KVMoveTask -from lightllm.server.router.dynamic_prompt.radix_cache import TreeNode - -g_kv_move_task_cache: Dict[int, Tuple[KVMoveTask, TreeNode]] = {} diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py 
deleted file mode 100644 index 022be45591..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py +++ /dev/null @@ -1,378 +0,0 @@ -import time -import rpyc -import copy -import uuid -import numpy as np -import psutil -import threading -from dataclasses import dataclass -from typing import List, Dict, Union -from lightllm.utils.log_utils import init_logger -import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup -from rpyc.utils.classic import obtain -from ..task_queue import TaskQueue -from lightllm.utils.device_utils import kv_trans_use_p2p -from lightllm.utils.time_utils import TimeChecker -from .prefill_kv_move_manager import PrefillKVMoveManager -from lightllm.utils.net_utils import find_available_port -from ..utils import join_if_alive, clear_queue - -logger = init_logger(__name__) - - -@dataclass -class KVTransConnectObj: - connect_id: str = None - decode_node_id: int = None - rpyc_conn: object = None # rpyc_con 的连接对象 - kv_trans_process: "KVTransProcess" = None - device_index: int = None # 使用的gpu序号 - manager: "PrefillKVMoveManager" = None - has_error: bool = False - request_kv_trans_task_queue: TaskQueue = None - request_thread: threading.Thread = None - ready_kv_trans_task_queue: TaskQueue = None - kv_trans_thread: threading.Thread = None - timer_checker: TimeChecker = None - - # ================================================================================== - # 构建传输通信对象 - # ================================================================================== - - def create( - self, decode_node_id: int, decode_node_ip: str, decode_node_rpyc_port: int, manager: "PrefillKVMoveManager" - ): - device_index = manager.get_next_device_index() # 分配使用的显卡index - self.kv_trans_process = manager.kv_trans_processes[device_index] - prefill_node_id = manager.args.pd_node_id - self.connect_id = str(uuid.uuid4()) - 
self.decode_node_id = decode_node_id - self.prefill_node_id = prefill_node_id - self.device_index = device_index - self.manager = manager - self.timer_checker = TimeChecker(6) - - con = rpyc.connect( - host=decode_node_ip, - port=decode_node_rpyc_port, - config={"allow_pickle": True, "sync_request_timeout": 60}, - keepalive=True, - ) - - self.rpyc_conn = con - - # 创建 nccl 连接 - with self.kv_trans_process.device_lock: - clear_queue(self.kv_trans_process.task_out_queue) - - self.kv_trans_process.task_in_queue.put( - PDTransJoinInfo( - prefill_id=prefill_node_id, - prefill_device_id=device_index, - pd_prefill_nccl_ip=manager.host_ip, - pd_prefill_nccl_port=self.kv_trans_process.kv_trans_port, - decode_id=decode_node_id, - decode_device_id=-1, - connect_id=self.connect_id, - ) - ) - - # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 - max_kv_trans_token_num = obtain( - con.root.build_trans_connect( - prefill_node_id, - manager.host_ip, - self.kv_trans_process.kv_trans_port, - manager.args.max_total_token_num, - self.connect_id, - ) - ) - self.max_kv_trans_token_num = max_kv_trans_token_num - assert self.kv_trans_process.task_out_queue.get(timeout=60) == "nccl_ok" - - self.request_kv_trans_task_queue = TaskQueue( - get_func=self._get_request_tasks, fail_func=self.manager.put_to_release_task_queue - ) - self.request_thread = threading.Thread(target=self.request_kv_trans_loop, daemon=True) - self.request_thread.start() - - self.ready_kv_trans_task_queue = TaskQueue(lambda datas: datas[0:1], self.manager.put_to_release_task_queue) - self.kv_trans_thread = threading.Thread(target=self.kv_trans_handle_loop, daemon=True) - self.kv_trans_thread.start() - - logger.info(f"create KVTransConnectObj success: {self.to_log_info()}") - return - - def _get_request_tasks(self, datas: List[KVMoveTask]): - """ - 根据可以p和d节点间协商得到的 max_kv_trans_token_num 限制,将排队等待 - 传输的请求打包成一个可以传输的list组。 - """ - ans_list = [] - token_num = 0 - for task in datas: - if token_num + len(task.prefill_token_indexes) <= 
self.max_kv_trans_token_num: - ans_list.append(task) - token_num += len(task.prefill_token_indexes) - else: - break - return ans_list - - # ================================================================================== - # 与 decode 节点进行元数据交互,申请锁定资源准备进行kv的传输 - # ================================================================================== - def request_kv_trans_loop(self): - func_name = self.request_kv_trans_loop.__name__ - - while not self.has_error: - move_tasks: List[KVMoveTask] = self.request_kv_trans_task_queue.get_tasks( - log_tag="request_kv_trans_task_queue" - ) - if len(move_tasks) == 0: - self.timer_check_status(raise_exception=False) - time.sleep(0.01) - continue - try: - self.timer_check_status(raise_exception=True) - for move_task in move_tasks: - move_task.connect_id = self.connect_id - logger.info( - f"{func_name} get task {move_task.to_prefill_log_info()} " - f"queue time {move_task.get_cost_time()} s " - ) - - trans_move_tasks = [copy.copy(move_task) for move_task in move_tasks] - for trans_move_task in trans_move_tasks: - trans_move_task.prefill_token_indexes = None - - mark_start = time.time() - move_kv_lens = self.rpyc_conn.root.request_data_transfer(trans_move_tasks) - move_kv_lens = obtain(move_kv_lens) - request_data_transfer_cost_time = time.time() - mark_start - - logger.info( - f"{func_name} request_data_transfer ok, {move_tasks[0].to_prefill_log_info()}" - f" cost time: {request_data_transfer_cost_time} s" - ) - - ok_trans_list = [] - for i, move_task in enumerate(move_tasks.copy()): - if move_kv_lens[i] is not None: - move_task.move_kv_len = move_kv_lens[i] - ok_trans_list.append(move_task) - move_tasks.remove(move_task) - else: - logger.info(f"prefill node kv move task req_id: {move_task.id()} not send, decode is busy") - - if ok_trans_list: - self.ready_kv_trans_task_queue.put( - ok_trans_list, error_handle_func=self.manager.put_to_release_task_queue - ) - - except BaseException as e: - logger.exception(str(e)) - 
self.set_has_error() - self.request_kv_trans_task_queue.clear_tasks() - - finally: - # 将没有申请成功的请求放入到释放队列中 - self.manager.put_to_release_task_queue(move_tasks) - - logger.error(f"{func_name}, {self.to_log_info()} thread quit") - return - - # ================================================================================== - # 将准备好 kv 传输的请求进行 kv 传输 - # ================================================================================== - def _transfer_kv(self, move_tasks: List[KVMoveTask]): - with self.kv_trans_process.device_lock: - clear_queue(self.kv_trans_process.task_out_queue) - kv_move_group = KVMoveTaskGroup(tasks=move_tasks.copy(), connect_id=self.connect_id) - self.kv_trans_process.task_in_queue.put(kv_move_group, timeout=10) - assert self.kv_trans_process.task_out_queue.get(timeout=60) == "ok" - self.manager.put_to_release_task_queue(move_tasks) - - logger.info( - f"_transfer_kv data ok, req_id: {move_tasks[0].id()}" - f" cost total time: {move_tasks[0].get_cost_time()} s" - ) - move_tasks.clear() - - def kv_trans_handle_loop(self): - func_name = self.kv_trans_handle_loop.__name__ - while not self.has_error: - move_tasks: List[List[KVMoveTask]] = self.ready_kv_trans_task_queue.get_tasks( - log_tag="ready_kv_trans_task_queue" - ) - if len(move_tasks) == 0: - self.timer_check_status(raise_exception=False) - time.sleep(0.01) - continue - - if len(move_tasks) != 1: - logger.error(f"error get kv trans move_tasks, must be 1, get {len(move_tasks)}") - assert len(move_tasks) == 1 - - move_tasks: List[KVMoveTask] = move_tasks[0] - - try: - self.timer_check_status(raise_exception=True) - for move_task in move_tasks: - logger.info( - f"{func_name} get task {move_task.to_prefill_log_info()} to start kv move" - f"queue time {move_task.get_cost_time()} s " - ) - - if not kv_trans_use_p2p(): - with self.manager.kv_trans_lock: - self._transfer_kv(move_tasks) - else: - self._transfer_kv(move_tasks) - except BaseException as e: - logger.exception(str(e)) - 
self.set_has_error() - self.ready_kv_trans_task_queue.clear_tasks() - finally: - self.manager.put_to_release_task_queue(move_tasks) - - logger.error(f"trans kv thread, {self.to_log_info()} thread quit") - return - - # ================================================================================== - # 错误处理检测操作的一些通用函数 - # ================================================================================== - - def has_error_status(self): - try: - assert self.has_error is False - assert self.request_thread.is_alive() - assert self.kv_trans_thread.is_alive() - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - return True - - return False - - def timer_check_status(self, raise_exception=True): - if self.timer_checker.has_exceeded(): - try: - self.rpyc_conn.root.check_alive() - assert self.kv_trans_process.is_trans_process_health() - except BaseException as e: - logger.error(f"pid {self.kv_trans_process.process.pid} check failed") - logger.exception(str(e)) - - self.set_has_error() - if raise_exception: - raise e - - return - - def set_has_error(self): - """ - 将当前传输对象标记为有错误,这样可以防止请求放入到处理队列中 - """ - self.has_error = True - - if self.request_kv_trans_task_queue is not None: - self.request_kv_trans_task_queue.has_error = True - - if self.ready_kv_trans_task_queue is not None: - self.ready_kv_trans_task_queue.has_error = True - - if self.manager is not None: - self.manager.remove_trans_obj(self.connect_id) - return - - def __del__(self): - """ - 函数中有很多判断是否是None的操作,主要是为了避免一些异常流程的del行为不报错。 - """ - logger.error(f"trans obj del start, info: {self.to_log_info()}") - - try: - self.set_has_error() - - join_if_alive(self.request_thread) - join_if_alive(self.kv_trans_thread) - - # 将未处理的请求,清理掉,clear_tasks 会将没处理完的请求 - # 放入到 manager 资源释放队列中 - if self.request_kv_trans_task_queue is not None: - self.request_kv_trans_task_queue.clear_tasks() - if self.ready_kv_trans_task_queue is not None: - self.ready_kv_trans_task_queue.clear_tasks() - - # 传输进程清理掉 nccl 连接 - 
if self.connect_id is not None: - self.kv_trans_process.task_in_queue.put( - PDTransLeaveInfo( - decode_id=self.decode_node_id, prefill_id=self.prefill_node_id, connect_id=self.connect_id - ) - ) - - except BaseException as e: - logger.exception(str(e)) - - logger.error(f"trans obj deled, info: {self.to_log_info()}") - - def to_log_info(self): - log = f"connect_id: {self.connect_id} " - log += f"decode_node_id: {self.decode_node_id} " - log += f"prefill_node_id: {self.prefill_node_id} " - log += f"device_index: {self.device_index} " - return log - - -@dataclass -class KVTransProcess: - process: mp.Process = None - # 需要每个卡有一个锁来规划每次只能有一个 connection obj 操作对应显卡上的传输任务。 - device_lock: threading.Lock = None - task_in_queue: mp.Queue = None - task_out_queue: mp.Queue = None - device_id: int = None - kv_trans_port: int = None - - def init_all(self, device_id: int, manager: "PrefillKVMoveManager"): - self.device_id = device_id - self.device_lock = threading.Lock() - self.task_in_queue = mp.Queue() - self.task_out_queue = mp.Queue() - self.kv_trans_port = find_available_port(manager.args.pd_p_allowed_port_min, manager.args.pd_p_allowed_port_max) - - try: - from .prefill_trans_process import start_prefill_trans_process - - self.process = start_prefill_trans_process( - manager.args, - manager.host_ip, - self.kv_trans_port, - device_id, - self.task_in_queue, - self.task_out_queue, - ) - assert self.task_out_queue.get(timeout=30) == "proc_start" - assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" - - return True - except Exception as e: - logger.warning(f"Failed start kv trans process for device {device_id}: {e}") - logger.exception(str(e)) - return False - - def is_trans_process_health(self): - try: - process = psutil.Process(self.process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - logger.error(f"kv trans process for device: {self.device_id} dead!!!") - return False - else: - return True - except: - return False - - def 
killself(self): - self.process.kill() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py deleted file mode 100644 index a328e3e080..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ /dev/null @@ -1,162 +0,0 @@ -import torch -import time -import sys -import inspect -import threading -import setproctitle -import torch.multiprocessing as mp -from torch.distributed import TCPStore -from datetime import timedelta -from typing import List, Dict, Union -from lightllm.utils.log_utils import init_logger -from lightllm.common.kv_cache_mem_manager import MemoryManager -from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup -from lightllm.utils.device_utils import kv_trans_use_p2p -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.distributed.pynccl import StatelessP2PProcessGroup, PyNcclCommunicator -from lightllm.utils.envs_utils import get_unique_server_name - - -logger = init_logger(__name__) - - -def _handle_kvmove_task( - move_tasks: List[KVMoveTask], - task_out_queue: mp.Queue, - mem_managers: List[MemoryManager], - connect_id_to_comm: Dict[str, PyNcclCommunicator], - connect_id: str, - dp_size_in_node: int, -): - total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) - try: - device_index = connect_id_to_comm[connect_id].device.index - start = time.time() - if total_move_kv_len != 0: - logger.info(f"trans start: {move_tasks[0].to_prefill_log_info()}") - cur_mem = mem_managers[device_index] - if kv_trans_use_p2p(): - cur_mem.send_to_decode_node_p2p( - move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] - ) - else: - cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node, 
connect_id_to_comm[connect_id]) - logger.info(f"trans finished: {move_tasks[0].to_prefill_log_info()} move len: {total_move_kv_len}") - torch.cuda.synchronize() - logger.info( - f"trans cost time: {(time.time() - start)}," - f"move_total_kv_len: {total_move_kv_len}, {move_tasks[0].to_prefill_log_info()}" - ) - task_out_queue.put("ok") - except BaseException as e: - logger.exception(str(e)) - task_out_queue.put("fail") - - -def _handle_decode_join( - node_info: PDTransJoinInfo, - task_out_queue: mp.Queue, - connect_id_to_comm: Dict[str, PyNcclCommunicator], - store: TCPStore, -): - try: - logger.info(f"connect start {node_info}") - src_id = node_info.prefill_id - dest_id = node_info.connect_id - logger.info(f"connect src_id {src_id} dest_id {dest_id}") - result_list = [] - - def async_connect(): - torch.cuda.set_device(node_info.prefill_device_id) - group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=True, store=store) - comm = PyNcclCommunicator(group, node_info.prefill_device_id) - result_list.append(comm) - return - - connect_task = threading.Thread(target=async_connect, daemon=True) - connect_task.start() - connect_task.join(timeout=36) - if connect_task.is_alive(): - raise Exception(f"{node_info} connect time out") - - connect_id_to_comm[node_info.connect_id] = result_list[0] - logger.info(f"{node_info} kv trans connected!") - task_out_queue.put("nccl_ok") - except Exception as e: - task_out_queue.put("nccl_fail") - logger.warning(f"error while connect to decode node: {e} node_info {node_info}") - - -def _init_env( - args, - store_ip, - store_port, - device_id, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, -): - import os - - # os.environ["NCCL_DEBUG"] = "INFO" - os.environ["NCCL_MAX_NCHANNELS"] = "2" - os.environ["NCCL_NSOCKS_PER_CHANNEL"] = "1" - os.environ["NCCL_SOCKET_NTHREADS"] = "1" - torch.backends.cudnn.enabled = False - - dp_size_in_node = max(1, args.dp // args.nnodes) - setproctitle.setproctitle( - 
f"lightllm::{get_unique_server_name()}::prefill_trans:Device{device_id}_DpSizeInNode{dp_size_in_node}" - ) - - try: - torch.cuda.set_device(device_id) - graceful_registry(inspect.currentframe().f_code.co_name) - master_store = TCPStore( - host_name=store_ip, port=store_port, is_master=True, use_libuv=True, timeout=timedelta(seconds=30) - ) - task_out_queue.put("proc_start") - - # 从共享内存读取所有rank的mem_manager - node_world_size = args.tp // args.nnodes - mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) - ] - task_out_queue.put("get_mem_managers_ok") - connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} - - while True: - task: Union[KVMoveTaskGroup, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get() - if isinstance(task, KVMoveTaskGroup): - _handle_kvmove_task( - task.tasks, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node - ) - elif isinstance(task, PDTransJoinInfo): - _handle_decode_join(task, task_out_queue, connect_id_to_comm, master_store) - elif isinstance(task, PDTransLeaveInfo): - if task.connect_id in connect_id_to_comm: - connect_id_to_comm[task.connect_id].destroy() - connect_id_to_comm.pop(task.connect_id, None) - logger.info(f"destory {task} nccl communicator.") - else: - logger.error(f"connect id {task.connect_id} dont exist in connect_id_to_comm") - else: - logger.warning(f"unexpected task type: {task}") - - except Exception as e: - logger.error(f"Fatal error happened in kv trans process: {e}") - pass - - -def start_prefill_trans_process( - args, - store_ip, - store_port, - device_id, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, -): - proc = mp.Process(target=_init_env, args=(args, store_ip, store_port, device_id, task_in_queue, task_out_queue)) - proc.start() - assert proc.is_alive() - logger.info(f"prefill trans kv process for device: {device_id} started!") - return proc diff --git 
a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py deleted file mode 100644 index 7b856e54a0..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py +++ /dev/null @@ -1,48 +0,0 @@ -import threading -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class TaskQueue: - def __init__(self, get_func, fail_func): - self.lock = threading.Lock() - self.datas = [] - self.get_func = get_func - self.fail_func = fail_func - self.has_error = False - - def size(self): - return len(self.datas) - - def put(self, obj, error_handle_func=None): - if self.has_error: - if error_handle_func is not None: - error_handle_func(obj) - raise Exception("has error") - - with self.lock: - self.datas.append(obj) - - def put_list(self, objs): - if self.has_error: - raise Exception("has error") - - with self.lock: - self.datas.extend(objs) - - def get_tasks(self, log_tag=None): - with self.lock: - ans = self.get_func(self.datas) - self.datas = self.datas[len(ans) :] - if len(self.datas) != 0: - logger.info(f"queue {log_tag} left size: {len(self.datas)}") - return ans - - def clear_tasks(self): - with self.lock: - if len(self.datas) != 0: - for obj in self.datas: - self.fail_func(obj) - self.datas = [] - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py deleted file mode 100644 index cd1360fd0a..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -import threading -import torch.multiprocessing as mp -from queue import Empty - - -def join_if_alive(thread: threading.Thread): - if thread is not None and thread.is_alive(): - try: - thread.join() - except Exception: - pass - return - - -def 
clear_queue(queue: mp.Queue): - while not queue.empty(): - try: - queue.get_nowait() - except Empty: - break diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 5a179cb620..f1681eda52 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -14,7 +14,6 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from ..chunked_prefill.impl import ChunkedPrefillBackend -from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.utils.envs_utils import get_env_start_args @@ -140,7 +139,7 @@ def _diverse_pre_post_handle(self, run_reqs: List[InferReq], is_chuncked_mode: b pack = InferReqUpdatePack(req_obj=req_obj, output_len=0) update_func_objs.append(pack) pre_master_req_pack = pack - # TODO 如果 diverse mode 需要支持 nixl pd 分离,则应该每个分块prefill后都进行相关的复制, + # TODO 如果 diverse mode 需要支持 pd 分离,则应该每个分块prefill后都进行相关的复制, # 暂时不支持 diverse mode 和 pd 模式的混合 continue @@ -167,7 +166,6 @@ def _diverse_pre_post_handle(self, run_reqs: List[InferReq], is_chuncked_mode: b return update_func_objs def _master_req_to_radix_cache(self, master_req: InferReq): - g_infer_state_lock.acquire() key = master_req.get_input_token_ids()[0 : master_req.cur_kv_len] key = torch.tensor(key, dtype=torch.int64, device="cpu") value = self.model.req_manager.req_to_token_indexs[master_req.req_idx][: master_req.cur_kv_len].detach().cpu() @@ -189,11 +187,9 @@ def _master_req_to_radix_cache(self, master_req: InferReq): share_node, kv_len, value = self.radix_cache.match_prefix(key, update_refs=False) assert share_node == new_shared_kv_node and kv_len == master_req.cur_kv_len self.model.req_manager.req_to_token_indexs[master_req.req_idx][0 : master_req.cur_kv_len] = value - 
g_infer_state_lock.release() return def _copy_master_req_to_slave_req(self, slave_req: InferReq): - g_infer_state_lock.acquire() master_req = slave_req.related_master_req assert master_req is not None @@ -213,6 +209,4 @@ def _copy_master_req_to_slave_req(self, slave_req: InferReq): slave_req.shm_req.shm_cur_kv_len = slave_req.cur_kv_len assert kv_len <= slave_req.shm_req.input_len - - g_infer_state_lock.release() return diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py index 5de90bef66..2fa2c9cb9a 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py @@ -18,12 +18,11 @@ class DPKVSharedMoudle: _KV_LEN_INDEX = 0 _REQ_IDX_INDEX = 1 - def __init__(self, max_req_num: int, max_req_seq_len: int, dp_size_in_node: int, backend): + def __init__(self, max_req_num: int, dp_size_in_node: int, backend): from .impl import DPChunkedPrefillBackend self.backend: DPChunkedPrefillBackend = backend self.max_req_num = max_req_num - self.max_req_seq_len = max_req_seq_len # 0 代表 kv_len, 1 代表 radix_cache_len self.shared_req_infos = ShmArray( @@ -111,7 +110,7 @@ def kv_trans(self, trans_tasks: List["TransTask"]): max_kv_len_mem_indexes_tensor = torch.cat(max_kv_len_mem_indexes).to(dtype=torch.int64, device="cuda") max_kv_len_dp_ranks_tensor = torch.tensor(max_kv_len_dp_ranks, dtype=torch.int32, device="cuda") mem_indexes_tensor = torch.cat(mem_indexes).to(dtype=torch.int64, device="cuda") - self.backend.model.mem_manager.copy_kv_from_other_dp_ranks( + self.backend.model.mem_manager.operator.copy_kv_from_other_dp_ranks( mem_managers=self.backend.mem_managers, move_token_indexes=max_kv_len_mem_indexes_tensor, token_dp_indexes=max_kv_len_dp_ranks_tensor, diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py 
b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index bb0e848e76..e6b9d1c18d 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -14,7 +14,6 @@ padded_overlap_prepare_decode_inputs, ) from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventPack -from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.server.router.model_infer.mode_backend.mtp_pre_process import ( prepare_mtp_prefill_inputs, ) @@ -70,8 +69,6 @@ def _init_reqs(self, reqs: List[Tuple]): current_dp_reqs = [req for req in reqs if req[3] == dp_rank_in_node] other_dp_reqs = [req for req in reqs if req[3] != dp_rank_in_node] - g_infer_state_lock.acquire() - infer_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=True) req_dp_ranks = [req[3] for req in reqs] self.dp_kv_shared_module.fill_reqs_info(reqs=infer_reqs) @@ -79,9 +76,12 @@ def _init_reqs(self, reqs: List[Tuple]): self.dp_kv_shared_module.kv_trans(trans_tasks=trans_taskes) g_infer_context._filter(finished_request_ids=[req[0] for req in other_dp_reqs]) - g_infer_state_lock.release() req_ids = [e[0] for e in current_dp_reqs] + + if self.args.enable_cpu_cache: + self._load_cpu_cache_to_reqs(req_ids=req_ids) + return req_ids def infer_loop(self): @@ -114,12 +114,18 @@ def infer_loop(self): ) if run_way.is_prefill(): + # 进行一次流同步,保证 _try_read_new_reqs 中的一些算子操作,必然已经完成。 + # 防止后续的推理流程读取到显存中可能存在错误的数据。 + g_infer_context.get_overlap_stream().wait_stream(torch.cuda.current_stream()) self.prefill( event_pack=event_pack, prefill_reqs=prefill_reqs, ) continue elif run_way.is_decode(): + # 进行一次流同步,保证 _try_read_new_reqs 中的一些算子操作,必然已经完成。 + # 防止后续的推理流程读取到显存中可能存在错误的数据。 + g_infer_context.get_overlap_stream().wait_stream(torch.cuda.current_stream()) self.decode( event_pack=event_pack, decode_reqs=decode_reqs, @@ -155,6 +161,10 @@ def prefill_normal( 
b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu[:run_reqs_num], mask_func=None, ) + g_infer_context.copy_linear_att_state_to_cache_buffer( + b_req_idx=model_input.b_req_idx[:run_reqs_num], + reqs=run_reqs, + ) sync_event = torch.cuda.Event() sync_event.record() @@ -172,7 +182,7 @@ def prefill_normal( next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=self.pd_prefill_chunked_handle_func, ) # 第四阶段 event_pack.notify_pre_post_handle() @@ -263,6 +273,10 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer b_prefill_has_output_cpu=b_has_out_cpu, mask_func=None, ) + + if g_infer_context.is_linear_att_mixed_model: + g_infer_context.copy_linear_att_state_to_cache_buffer(b_req_idx=b_req_idx, reqs=run_reqs) + sync_event = torch.cuda.Event() sync_event.record() @@ -281,7 +295,7 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=self.pd_prefill_chunked_handle_func, ) # 第四阶段 event_pack.notify_pre_post_handle() @@ -384,6 +398,9 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] model_output=model_output, next_token_ids=draft_next_token_ids_gpu, ) + if req_num > 0: + g_infer_context.copy_linear_att_state_to_cache_buffer(b_req_idx=b_req_idx, reqs=run_reqs) + sync_event = torch.cuda.Event() sync_event.record() @@ -403,7 +420,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, 
extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=self.pd_prefill_chunked_handle_func, ) # 第四阶段 @@ -499,9 +516,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): extra_post_req_handle_func=self.extra_post_req_handle_func, ) if len(need_free_mem_indexes) > 0: - g_infer_state_lock.acquire() g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) - g_infer_state_lock.release() # 第四阶段 event_pack.notify_pre_post_handle() @@ -573,12 +588,9 @@ def _draft_decode_eagle( real_req_num = req_num // (self.mtp_step + 1) padded_req_num = model_input.batch_size // (self.mtp_step + 1) - real_req_num eagle_mem_indexes_cpu = None - - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(real_req_num * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(real_req_num * self.mtp_step) - g_infer_state_lock.release() eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) # process the draft model output @@ -687,6 +699,10 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I draft_next_token_ids_gpu0 = self._gen_argmax_token_ids(draft_model_output0) draft_next_token_ids_gpu1 = self._gen_argmax_token_ids(draft_model_output1) + if req_num0 + req_num1 > 0 and g_infer_context.is_linear_att_mixed_model: + _b_req_idx = torch.cat((model_input0.b_req_idx[0:req_num0], model_input1.b_req_idx[0:req_num1]), dim=0) + g_infer_context.copy_linear_att_state_to_cache_buffer(b_req_idx=_b_req_idx, reqs=run_reqs) + sync_event = torch.cuda.Event() sync_event.record() @@ -703,7 +719,7 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, 
extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=self.pd_prefill_chunked_handle_func, ) event_pack.notify_pre_post_handle() else: @@ -817,9 +833,7 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf extra_post_req_handle_func=self.extra_post_req_handle_func, ) if len(need_free_mem_indexes) > 0: - g_infer_state_lock.acquire() g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) - g_infer_state_lock.release() event_pack.notify_pre_post_handle() else: event_pack.notify_post_handle_and_wait_pre_post_handle() @@ -932,11 +946,9 @@ def _draft_decode_eagle_overlap( real_req_num = real_req_num0 + real_req_num1 padded_req_num0 = model_input0.batch_size // (self.mtp_step + 1) - real_req_num0 padded_req_num1 = model_input1.batch_size // (self.mtp_step + 1) - real_req_num1 - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(real_req_num * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(real_req_num * self.mtp_step) - g_infer_state_lock.release() eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) eagle_mem_indexes0 = eagle_mem_indexes[0 : real_req_num0 * self.mtp_step] eagle_mem_indexes1 = eagle_mem_indexes[real_req_num0 * self.mtp_step : real_req_num * self.mtp_step] diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 03ac4cfb05..68af30b505 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -7,8 +7,8 @@ from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq from lightllm.utils.infer_utils import 
calculate_time from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from .generic_pre_process import build_b_position_delta def padded_prepare_prefill_inputs( @@ -36,6 +36,7 @@ def padded_prepare_prefill_inputs( b_ready_cache_len = [] b_mtp_index = [] b_prefill_has_output = [] + b_is_decode_req = [] for req in req_objs: @@ -57,6 +58,14 @@ def padded_prepare_prefill_inputs( b_ready_cache_len.append(req.cur_kv_len) b_mtp_index.append(0) + # enable_prefill_decode_mixed 模式下,decode 请求混合在 prefill 请求中。 + # 需要的特殊标记。 + if hasattr(req, "is_decode_req_mixed_in_prefill"): + b_is_decode_req.append(True) + del req.is_decode_req_mixed_in_prefill + else: + b_is_decode_req.append(False) + # padding fake req for prefill for _ in range(padded_req_num): input_ids.append([1]) @@ -69,6 +78,7 @@ def padded_prepare_prefill_inputs( total_token_num += 1 prefix_total_token_num += 0 batch_multimodal_params.append({"images": [], "audios": []}) + b_is_decode_req.append(False) max_kv_seq_len = max(b_seq_len) max_cache_len = max(b_ready_cache_len) @@ -78,17 +88,16 @@ def padded_prepare_prefill_inputs( input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cpu") b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") + b_is_decode_req = torch.tensor(b_is_decode_req, dtype=torch.bool, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu") b_q_seq_len = torch.tensor(b_q_seq_len, dtype=torch.int32, device="cpu") b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: 
g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num) - g_infer_state_lock.release() if padded_req_num > 0: mem_indexes = F.pad( @@ -110,6 +119,7 @@ def padded_prepare_prefill_inputs( b_req_idx=b_req_idx, b_mtp_index=b_mtp_index, b_seq_len=b_seq_len, + b_is_decode_req=b_is_decode_req, b_ready_cache_len=b_ready_cache_len, b_prefill_start_loc=b_prefill_start_loc, is_prefill=True, @@ -187,14 +197,13 @@ def padded_prepare_decode_inputs( b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_position_delta = build_b_position_delta(batch_multimodal_params) # dynamic prompt cache 准备 token padded_mem_indexes_num = padded_req_num * (args_mtp_step + 1) - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_mem_indexes_num) mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_mem_indexes_num) - g_infer_state_lock.release() if padded_mem_indexes_num > 0: mem_indexes = F.pad( @@ -214,6 +223,7 @@ def padded_prepare_decode_inputs( b_req_idx=b_req_idx, b_mtp_index=b_mtp_index, b_seq_len=b_seq_len, + b_position_delta=b_position_delta, is_prefill=False, multimodal_params=batch_multimodal_params, ) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index f3ad03662e..5b29ea0510 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -1,7 +1,8 @@ import torch from typing import List, Tuple -from 
lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty -from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache +from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty import apply_penalty +from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty_gpu_cache import apply_penalty_gpu_cache +from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import apply_invalid_token_ids from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.utils.envs_utils import get_env_start_args @@ -15,7 +16,10 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): b_top_ks, b_length_penalty_param, b_mask_eos_reqs, + invalid_token_ids, + cu_invalid_token_num, is_all_greedy, + has_invalid_token_ids, skip_top_k, skip_top_p, exist_req_use_random_seed, @@ -63,6 +67,14 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): eos_ids=eos_ids, sampling_params_manager=sampling_params_manager, ) + + if has_invalid_token_ids: + apply_invalid_token_ids( + Logits=logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + ) + logits.div_(b_temperatures.view((-1, 1))) probs = torch.softmax(logits, dim=-1) @@ -102,7 +114,9 @@ def _top_p_top_k_sample( b_top_ks: torch.Tensor, exist_req_use_random_seed: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: - if get_env_start_args().sampling_backend == "triton": + sampling_backend = get_env_start_args().sampling_backend + + if sampling_backend == "triton": probs_sort, probs_idx = _top_p_top_k(probs, b_top_ps, b_top_ks) if not exist_req_use_random_seed: sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True) @@ -112,8 +126,8 @@ def _top_p_top_k_sample( next_token_logprobs = torch.log(torch.gather(probs_sort, dim=1, 
index=sampled_index)) return next_token_ids.view(-1), next_token_logprobs.view(-1) - elif get_env_start_args().sampling_backend == "sglang_kernel": - from sgl_kernel import top_k_top_p_sampling_from_probs + elif sampling_backend == "flashinfer": + from flashinfer.sampling import top_k_top_p_sampling_from_probs batch_next_token_ids = top_k_top_p_sampling_from_probs( probs, @@ -152,6 +166,12 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_p = True exist_req_use_random_seed = False + # invalid token ids + invalid_token_ids: List[int] = [] + has_invalid_token_ids = False + cu_invalid_token_num = [0] + invalid_token_num_start = 0 + for i, req_obj in enumerate(reqs): sample_param = req_obj.sampling_param shm_param = sample_param.shm_param @@ -173,6 +193,11 @@ def _get_post_sample_tensors(reqs: List[InferReq]): if req_obj.generator is not None: exist_req_use_random_seed = True req_idxes.append(req_obj.req_idx) + invalid_token_num_start += len(req_obj.sampling_param.invalid_token_ids) + cu_invalid_token_num.append(invalid_token_num_start) + if len(req_obj.sampling_param.invalid_token_ids) > 0: + has_invalid_token_ids = True + invalid_token_ids.extend(req_obj.sampling_param.invalid_token_ids) req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32) temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32) @@ -183,6 +208,14 @@ def _get_post_sample_tensors(reqs: List[InferReq]): ) mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool) + if has_invalid_token_ids: + invalid_token_ids_cpu = g_pin_mem_manager.gen_from_list( + key="invalid_token_ids", data=invalid_token_ids, dtype=torch.int32 + ) + cu_invalid_token_num_cpu = g_pin_mem_manager.gen_from_list( + key="cu_invalid_token_num", data=cu_invalid_token_num, dtype=torch.int32 + ) + return ( req_idxes_cpu.cuda(non_blocking=True), 
temperatures_cpu.cuda(non_blocking=True), @@ -190,7 +223,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]): top_ks_cpu.cuda(non_blocking=True), length_penalty_param_cpu.cuda(non_blocking=True), mask_eos_reqs_cpu.cuda(non_blocking=True), + invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, + cu_invalid_token_num_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, is_all_greedy, + has_invalid_token_ids, skip_top_k, skip_top_p, exist_req_use_random_seed, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 4eb8c7e1e6..ae294544ce 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -2,7 +2,6 @@ import numpy as np from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context -from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.common.basemodel.batch_objs import ModelInput from lightllm.utils.envs_utils import ( enable_diverse_mode_gqa_decode_fast_kernel, @@ -22,6 +21,7 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> b_ready_cache_len = [] b_mtp_index = [] b_prefill_has_output = [] + b_is_decode_req = [] for req in req_objs: run_reqs.append(req) @@ -47,6 +47,11 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> prefix_total_token_num += req.cur_kv_len b_ready_cache_len.append(req.cur_kv_len) b_mtp_index.append(0) + if hasattr(req, "is_decode_req_mixed_in_prefill"): + b_is_decode_req.append(True) + del req.is_decode_req_mixed_in_prefill + else: + b_is_decode_req.append(False) max_kv_seq_len = max(b_seq_len) max_cache_len = max(b_ready_cache_len) @@ -56,17 +61,16 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> input_ids = 
torch.tensor(input_ids, dtype=torch.int64, device="cpu") b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") + b_is_decode_req = torch.tensor(b_is_decode_req, dtype=torch.bool, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu") b_q_seq_len = torch.tensor(b_q_seq_len, dtype=torch.int32, device="cpu") b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) - g_infer_state_lock.release() model_input = ModelInput( batch_size=b_seq_len.shape[0], @@ -79,6 +83,7 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> b_req_idx=b_req_idx, b_mtp_index=b_mtp_index, b_seq_len=b_seq_len, + b_is_decode_req=b_is_decode_req, b_ready_cache_len=b_ready_cache_len, b_prefill_start_loc=b_prefill_start_loc, is_prefill=True, @@ -125,6 +130,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_position_delta = build_b_position_delta(multimodal_params) if enable_diverse_mode_gqa_decode_fast_kernel(): b_shared_seq_len, b_mark_shared_group = build_diverse_shared_group_infos(run_reqs=run_reqs) @@ -133,11 +139,9 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_mark_shared_group = None # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: 
g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0]) mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0]) - g_infer_state_lock.release() model_input = ModelInput( batch_size=b_seq_len.shape[0], @@ -149,6 +153,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_req_idx=b_req_idx, b_mtp_index=b_mtp_index, b_seq_len=b_seq_len, + b_position_delta=b_position_delta, b_shared_seq_len=b_shared_seq_len, b_mark_shared_group=b_mark_shared_group, is_prefill=False, @@ -157,6 +162,18 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In return model_input, run_reqs +def build_b_position_delta(multimodal_params: List[dict]) -> torch.Tensor: + b_position_delta = [] + for params in multimodal_params: + position_delta = 0 + for image in params.get("images", []): + grid_thwd = image.get("grid_thwd") + if grid_thwd is not None: + position_delta += grid_thwd[3] + b_position_delta.append(position_delta) + return torch.tensor(b_position_delta, dtype=torch.int32, device="cpu") + + def build_diverse_shared_group_infos(run_reqs: List[InferReq]) -> Tuple[torch.Tensor, torch.Tensor]: # b_shared_seq_len 和 b_mark_shared_group 只会在 diverse_mode 下的 decode 阶段真正被使用的参数, # 用于记录请求间的共享关系。 diff --git a/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py b/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py index dbce73a94a..3ef0395431 100644 --- a/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py @@ -7,7 +7,12 @@ def prepare_mtp_prefill_inputs( model_input: ModelInput, b_next_token_ids: torch.Tensor, mtp_draft_input_hiddens: torch.Tensor ): + # enable_prefill_decode_mixed 模式下,decode 请求混合在 prefill 请求中。 + # 但是mtp的input_ids已经是恢复ok,已经是正常的input_ids, 所以移除掉 b_is_decode_req。 + # 防止在 forward 阶段,因为 b_is_decode_req 不为空,导致 input_ids 被特殊处理。 new_model_input = 
copy.copy(model_input) + new_model_input.b_is_decode_req = None + new_input_ids = gen_mtp_new_input_ids( input_ids=model_input.input_ids, b_next_token_ids=b_next_token_ids, diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index 7c4168a937..d0025a03c1 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -2,10 +2,12 @@ import torch.distributed as dist import torch import dataclasses +import bisect from functools import lru_cache from typing import Optional, List, Deque from collections import deque from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient +from lightllm.utils.config_utils import is_linear_att_mixed_model from lightllm.utils.envs_utils import get_env_start_args from ..infer_batch import InferReq from lightllm.utils.dist_utils import create_new_group_for_current_dp @@ -26,6 +28,9 @@ def __init__(self, backend): self.filter_group = create_new_group_for_current_dp("gloo") self.init_sync_group = create_new_group_for_current_dp("nccl") dist.barrier(group=self.init_sync_group) + self.offload_sync_group = create_new_group_for_current_dp("nccl") + dist.barrier(group=self.offload_sync_group) + self.offload_sync_tensor = torch.empty((1,), dtype=torch.int32, device="cuda") self.page_index_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.int32, device="cuda") self.page_ready_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.bool, device="cuda") @@ -33,14 +38,6 @@ def __init__(self, backend): self.cpu_cache_handle_queue: Deque[TransTask] = deque() self.cpu_cache_client = CpuKvCacheClient(only_create_meta_data=False, init_shm_data=False) - def wait(self): - """ - 等待 cpu cache 相关页面注册完成 - """ - attach_shm_handle = self.cpu_cache_client.attach_shm_handle - if attach_shm_handle is not None: - attach_shm_handle.wait() - 
@lru_cache() def need_sync_compute_stream(self) -> bool: """ @@ -63,12 +60,19 @@ def need_sync_compute_stream(self) -> bool: def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): idle_token_num = g_infer_context.get_can_alloc_token_num() - token_page_size = self.args.cpu_cache_token_page_size all_page_list = [] is_master_in_dp = self.backend.is_master_in_dp for req in reqs: page_list = req.shm_req.cpu_cache_match_page_indexes.get_all() - match_tokens = len(page_list) * token_page_size + page_len_list = req.shm_req.token_hash_page_len_list.get_all() + page_len_start_list = [0] + page_len_list + assert len(page_list) <= len(page_len_list) + + if page_list: + match_tokens = page_len_list[len(page_list) - 1] + else: + match_tokens = 0 + # 更新命中的 cpu kv cache 长度, 减去radix cache和disk cache的部分. if is_master_in_dp: req.shm_req.cpu_prompt_cache_len = max( @@ -83,57 +87,55 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num) # 计算需要加载的页面(只加载未匹配的部分) - cur_kv_pages = req.cur_kv_len // token_page_size - need_pages = page_list[cur_kv_pages:] # 只取需要的页面 + ready_page_num = bisect.bisect_right(page_len_list, req.cur_kv_len) + assert ready_page_num <= len(page_list) + need_pages = page_list[ready_page_num:] # 只取需要的页面 mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num) if self.need_sync_compute_stream(): # TODO fa3 现在必须使用同步模式, 未来需要移除 - g_infer_context.get_overlap_stream().synchronize() - - # TODO 更有效的分配策略。 - grid_num = 16 + torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream()) + # g_infer_context.get_overlap_stream().synchronize() mem_manager = self.backend.model.mem_manager - if hasattr(mem_manager, "scale_buffer") and mem_manager.scale_buffer is not None: - cpu_cache_meta = self.cpu_cache_client.kv_cache_tensor_meta - cpu_kv_cache = self.cpu_cache_client.cpu_kv_cache_tensor[ - :, :, :, :, 0 : cpu_cache_meta.head_dim - ] - 
cpu_kv_cache_scale = self.cpu_cache_client.cpu_kv_cache_tensor[ - :, :, :, :, cpu_cache_meta.head_dim : - ].view(mem_manager.scale_buffer.dtype) - gpu_kv_cache_scale = mem_manager.scale_buffer - else: - cpu_kv_cache = self.cpu_cache_client.cpu_kv_cache_tensor - cpu_kv_cache_scale = None - gpu_kv_cache_scale = None + req_manager = self.backend.model.req_manager mem_indexes_cuda = mem_indexes.cuda(non_blocking=True) page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda( non_blocking=True ) - # 将 cpu page 的内容拷贝到 gpu 页面中 - load_cpu_kv_to_gpu( - gpu_mem_indexes=mem_indexes_cuda, - gpu_kv_cache=mem_manager.kv_buffer, - gpu_kv_cache_scale=gpu_kv_cache_scale, - cpu_kv_cache=cpu_kv_cache, - cpu_kv_cache_scale=cpu_kv_cache_scale, + # 因为在支持 linear att 以后,所有的页面加载必须要按照 page页面的整数倍来做, + # 不然可能导致页面数据不完整,导致无法从kv中恢复完整的 linear att状态,所以 + # 这里需要进行pad操作,使操作的页面是完整的。 + _start = page_len_start_list[ready_page_num] + + _end = req.cur_kv_len + assert 0 <= _start <= _end, f"invalid pad range [{_start}, {_end}]" + mem_indexes_cuda = torch.cat( + [req_manager.req_to_token_indexs[req.req_idx, _start:_end], mem_indexes_cuda] + ) + + assert ( + len(mem_indexes_cuda) == page_len_list[len(page_list) - 1] - page_len_start_list[ready_page_num] + ) + + # 更新 req 状态。 + idle_token_num -= need_token_num + g_infer_context.req_manager.req_to_token_indexs[ + req.req_idx, req.cur_kv_len : (req.cur_kv_len + need_token_num) + ] = mem_indexes + req.cur_kv_len = req.cur_kv_len + need_token_num + + mem_manager.operator.load_cpu_cache_to_gpu( + mem_indexes=mem_indexes_cuda, page_indexes=page_indexes_cuda, - tp_index=self.backend.rank_in_dp, - tp_world_size=self.backend.dp_world_size, - grid_num=grid_num, + cpu_cache_client=self.cpu_cache_client, + req=req, ) torch.cuda.current_stream().synchronize() - idle_token_num -= need_token_num - g_infer_context.req_manager.req_to_token_indexs[ - req.req_idx, req.cur_kv_len : (req.cur_kv_len + need_token_num) - ] = mem_indexes - req.cur_kv_len = 
req.cur_kv_len + need_token_num if self.backend.is_master_in_dp: req.shm_req.shm_cur_kv_len = req.cur_kv_len @@ -164,10 +166,12 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L continue # 过滤不适合进行 kv 卸载到 cpu cache 的请求。 - if ( - req.cur_kv_len < self.args.cpu_cache_token_page_size - or req.shm_req.input_len <= self.args.cpu_cache_token_page_size - ): + if g_infer_context.is_linear_att_mixed_model: + offload_limit_size = self.args.linear_att_hash_page_size + else: + offload_limit_size = self.args.cpu_cache_token_page_size + + if req.cur_kv_len < offload_limit_size or req.shm_req.input_len <= offload_limit_size: true_finished_reqs.append(req) continue @@ -205,11 +209,20 @@ def _start_kv_cache_offload_task( self, req: InferReq, cpu_kv_cache_stream: torch.cuda.Stream ) -> Optional["TransTask"]: with torch.cuda.stream(cpu_kv_cache_stream): + # 综合考虑后只对prompt做缓存管理,不包含decode内容,这里与radix cache不一致 + token_hash_list = req.shm_req.token_hash_list.get_all() + page_len_list = req.shm_req.token_hash_page_len_list.get_all() + assert len(token_hash_list) == len(page_len_list) + if self.backend.is_master_in_dp: - # 综合考虑后只对prompt做缓存管理,不包含decode内容,这里与radix cache不一致 - token_hash_list = req.shm_req.token_hash_list.get_all() - block_size = req.cur_kv_len // self.args.cpu_cache_token_page_size - move_block_size = min(block_size, len(token_hash_list)) + + find_index = bisect.bisect_right(page_len_list, req.cur_kv_len) + move_block_size = find_index + + # 对于 linear att 模型, 如果最后一个页面是碎页,需要做特殊处理,判断该碎页是否满足卸载条件。 + move_block_size = self._handle_linear_att_last_page( + req=req, move_block_size=move_block_size, page_len_list=page_len_list + ) if move_block_size == 0: dist.broadcast_object_list([0], group=self.gloo_group, group_src=0) @@ -252,49 +265,55 @@ def _start_kv_cache_offload_task( cuda_page_indexes.copy_(page_indexes, non_blocking=True) cuda_page_readies.copy_(page_readies, non_blocking=True) - move_token_num = item_size * self.args.cpu_cache_token_page_size 
- assert req.cur_kv_len >= item_size * self.args.cpu_cache_token_page_size + move_token_num = page_len_list[item_size - 1] + assert req.cur_kv_len >= move_token_num token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0:move_token_num] - # TODO 更有效的分配策略。 - grid_num = 16 - mem_manager = self.backend.model.mem_manager - if hasattr(mem_manager, "scale_buffer") and mem_manager.scale_buffer is not None: - cpu_cache_meta = self.cpu_cache_client.kv_cache_tensor_meta - cpu_kv_cache = self.cpu_cache_client.cpu_kv_cache_tensor[:, :, :, :, 0 : cpu_cache_meta.head_dim] - cpu_kv_cache_scale = self.cpu_cache_client.cpu_kv_cache_tensor[ - :, :, :, :, cpu_cache_meta.head_dim : - ].view(mem_manager.scale_buffer.dtype) - gpu_kv_cache_scale = mem_manager.scale_buffer - else: - cpu_kv_cache = self.cpu_cache_client.cpu_kv_cache_tensor - cpu_kv_cache_scale = None - gpu_kv_cache_scale = None - - # assert max(page_list) < self.cpu_cache_client.cpu_kv_cache_tensor.shape[0] - offload_gpu_kv_to_cpu( - token_indexes=token_indexes, - gpu_kv_cache=mem_manager.kv_buffer, - gpu_kv_cache_scale=gpu_kv_cache_scale, - cpu_kv_cache=cpu_kv_cache, - cpu_kv_cache_scale=cpu_kv_cache_scale, + + mem_manager.operator.offload_gpu_kv_to_cpu_cache( + mem_indexes=token_indexes, page_indexes=cuda_page_indexes, page_readies=cuda_page_readies, - tp_index=self.backend.rank_in_dp, - tp_world_size=self.backend.dp_world_size, - grid_num=grid_num, + cpu_cache_client=self.cpu_cache_client, + req=req, ) + # 这个操作只是为了在offload 对应的cuda stream中,同步标记下对应的kv cache offload 操作已经完成, + if self.backend.dp_world_size > 1: + dist.all_reduce(self.offload_sync_tensor, op=dist.ReduceOp.MAX, group=self.offload_sync_group) + sync_event = torch.cuda.Event() sync_event.record() req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING trans_task = TransTask( - page_indexes=page_indexes, page_readies=page_readies, req_obj=req, sync_event=sync_event + move_token_num=move_token_num, + page_indexes=page_indexes, + 
page_readies=page_readies, + req_obj=req, + sync_event=sync_event, ) return trans_task + def _handle_linear_att_last_page(self, req: InferReq, move_block_size: int, page_len_list: List[int]) -> int: + if not g_infer_context.is_linear_att_mixed_model: + return move_block_size + + if move_block_size == 0: + return 0 + + if move_block_size == len(page_len_list): + tail_len = page_len_list[move_block_size - 1] + if tail_len % self.args.cpu_cache_token_page_size != 0: + # 全局关闭了碎页的cpu cache 存储功能。 + if self.args.disable_linear_att_small_page_cpu_cache: + return move_block_size - 1 + # 说明是碎页,碎页需要判定是否满足cpu cache 的offload条件。 + if req.tail_linear_att_small_page_buffer_id is None: + return move_block_size - 1 + return move_block_size + def update_cpu_cache_task_states(self): if self.backend.is_master_in_dp: trans_ok_tasks = [] @@ -315,12 +334,16 @@ def update_cpu_cache_task_states(self): if item_size > 0: page_array_list = [task.page_indexes.tolist() for task in trans_ok_tasks] + move_token_nums = [task.move_token_num for task in trans_ok_tasks] if self.backend.is_master_in_dp: self.cpu_cache_client.lock.acquire_sleep1ms() # 分组update,避免不同请求的page交叉,导致disk cache hash不一致 - for pages in page_array_list: + for pages, move_token_num in zip(page_array_list, move_token_nums): self.cpu_cache_client.update_pages_status_to_ready( - page_list=pages, deref=True, disk_offload_enable=self.args.enable_disk_cache + page_list=pages, + deref=True, + disk_offload_enable=self.args.enable_disk_cache, + token_num_in_page_list=move_token_num, ) self.cpu_cache_client.lock.release() for task in trans_ok_tasks: @@ -330,6 +353,7 @@ def update_cpu_cache_task_states(self): @dataclasses.dataclass class TransTask: + move_token_num: int page_indexes: torch.Tensor page_readies: torch.Tensor req_obj: InferReq diff --git a/lightllm/server/router/model_infer/mode_backend/pd/__init__.py b/lightllm/server/router/model_infer/mode_backend/pd/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/lightllm/server/router/model_infer/mode_backend/pd_nixl/base_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/pd/base_kv_move_manager.py similarity index 89% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/base_kv_move_manager.py rename to lightllm/server/router/model_infer/mode_backend/pd/base_kv_move_manager.py index eb11728029..125edede25 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/base_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/base_kv_move_manager.py @@ -6,7 +6,7 @@ import torch.multiprocessing as mp from typing import List, Dict, Union, Callable, Optional from lightllm.utils.log_utils import init_logger -from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from lightllm.server.pd_io_struct import PDChunckedTransTaskRet from lightllm.server.core.objs import StartArgs from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from .trans_process_obj import KVTransProcess @@ -47,7 +47,7 @@ def __init__( threading.Thread(target=self.task_ret_handle_loop, args=(trans_process,), daemon=True).start() # 通过 io buffer 将命令写入到推理进程中 - self.shm_nixl_trans_io_buffer = ShmObjsIOBuffer(tail_str="nixl") + self.shm_pd_trans_io_buffer = ShmObjsIOBuffer(tail_str="pd") for func in [self.task_dispatcher_loop, self.task_ret_upload_loop, self.check_trans_process_loop]: threading.Thread(target=func, daemon=True).start() @@ -66,15 +66,15 @@ def task_dispatcher_loop(self): @log_exception def task_ret_upload_loop(self): while True: - ret_obj: NIXLChunckedTransTaskRet = self.ret_obj_queue.get() - ret_objs: List[NIXLChunckedTransTaskRet] = [ret_obj] + ret_obj: PDChunckedTransTaskRet = self.ret_obj_queue.get() + ret_objs: List[PDChunckedTransTaskRet] = [ret_obj] ret_objs.extend(self._collect_return_objects()) while True: - if self.shm_nixl_trans_io_buffer.is_empty(): + if self.shm_pd_trans_io_buffer.is_empty(): # to do, 这里写入的数量,可能会超过共享管道的大小。 - 
self.shm_nixl_trans_io_buffer.write_obj(ret_objs) - self.shm_nixl_trans_io_buffer.set_ready() + self.shm_pd_trans_io_buffer.write_obj(ret_objs) + self.shm_pd_trans_io_buffer.set_ready() break else: time.sleep(0.01) @@ -97,7 +97,7 @@ def _collect_return_objects(self): @log_exception def task_ret_handle_loop(self, trans_process: KVTransProcess): while True: - ret_obj: NIXLChunckedTransTaskRet = trans_process.task_out_queue.get() + ret_obj: PDChunckedTransTaskRet = trans_process.task_out_queue.get() self.ret_obj_queue.put(ret_obj) # ================================================================================== diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/__init__.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/__init__.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/__init__.py rename to lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/__init__.py diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_impl.py similarity index 72% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py rename to lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_impl.py index f1309ca9cc..242e1089e7 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_impl.py @@ -1,9 +1,9 @@ import random import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLChunckedTransTaskGroup, NIXLAbortReq +from lightllm.server.pd_io_struct import PDChunckedTransTask, PDChunckedTransTaskGroup, PDAbortReq from 
lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ChunkedPrefillBackend from typing import List, Tuple -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, g_infer_state_lock +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq from lightllm.server.core.objs import FinishStatus from lightllm.utils.log_utils import init_logger from lightllm.utils.device_utils import kv_trans_use_p2p @@ -11,7 +11,7 @@ logger = init_logger(__name__) -class NIXLDecodeNode(ChunkedPrefillBackend): +class PDDecodeNode(ChunkedPrefillBackend): def __init__(self, info_queue: mp.Queue) -> None: super().__init__() self.info_queue: mp.Queue = info_queue @@ -31,13 +31,13 @@ def _init_reqs(self, reqs: List[Tuple]): dp_rank_in_node = self.dp_rank_in_node reqs = [req for req in reqs if req[3] == dp_rank_in_node] - g_infer_state_lock.acquire() - - uninit_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=False) + uninit_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=True) # 匹配radix cache,并更新一些资源的管理。 self._post_init_reqs(uninit_reqs=uninit_reqs) - g_infer_state_lock.release() + # pd nixl 的 decode 节点模式下当前不支持 cpu cache, 未来可能会支持。 + assert not self.args.enable_cpu_cache + req_ids = [e[0] for e in reqs] return req_ids @@ -50,26 +50,9 @@ def _post_init_reqs(self, uninit_reqs: List[InferReq]): for req_obj in uninit_reqs: req_obj: InferReq = req_obj # for easy typing - request_id = req_obj.req_id - if request_id > 0: - req_obj._match_radix_cache() - # 构建 chuncked trans task - self._decode_node_gen_trans_tasks(req_obj=req_obj) - else: - # 对于不合法的请求, 主要是health请求,直接模拟将其finished掉 - req_obj.cur_output_len += 1 - req_obj.set_next_gen_token_id(0, 0.0, 1) - req_obj.finish_status.set_status(FinishStatus.FINISHED_STOP) - - if self.is_master_in_dp: - req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len - req_obj.shm_req.shm_cur_output_len = req_obj.cur_output_len - req_obj.shm_req.finish_token_index = 
req_obj.get_cur_total_len() - 1 - req_obj.shm_req.finish_status.set_status(FinishStatus.FINISHED_STOP) - req_obj.shm_req.candetoken_out_len = req_obj.cur_output_len - - req_id = req_obj.shm_req.request_id - logger.error(f"req_id: {req_id} forced to finished") + # 构建 chuncked trans task + self._decode_node_gen_trans_tasks(req_obj=req_obj) + return def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: @@ -78,18 +61,17 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: 主要用于在 nixl pd 分离模式下, 由子类继承重载, prefill 和 decode 节点过滤 kv 传输错误,或者 kv 传输没有完成的请求。 """ - g_infer_state_lock.acquire() ans_list: List[InferReq] = [] for request_id in req_ids: req_obj: InferReq = g_infer_context.requests_mapping[request_id] - if self.is_master_in_dp and req_obj.infer_aborted and req_obj.nixl_pd_task_num != 0: - self.info_queue.put(NIXLAbortReq(request_id=req_obj.req_id, device_id=req_obj.nixl_trans_device_id)) + if self.is_master_in_dp and req_obj.infer_aborted and req_obj.pd_task_num != 0: + self.info_queue.put(PDAbortReq(request_id=req_obj.req_id, device_id=req_obj.pd_trans_device_id)) - if req_obj.nixl_pd_task_num != (req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num): + if req_obj.pd_task_num != (req_obj.pd_task_failed_num + req_obj.pd_task_success_num): continue - if req_obj.nixl_pd_task_failed_num > 0: + if req_obj.pd_task_failed_num > 0: # 强制停止 if not req_obj.finish_status.is_finished(): req_obj.cur_output_len += 1 @@ -122,20 +104,18 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len ans_list.append(req_obj) - - g_infer_state_lock.release() return ans_list def _decode_node_gen_trans_tasks(self, req_obj: InferReq): """ decode node 生成所有的传输任务对象。 """ - group = NIXLChunckedTransTaskGroup() + group = PDChunckedTransTaskGroup() input_len = req_obj.shm_req.input_len # 当 decode 节点不能匹配足够的kv的时候,才进行真实的 kv 传输。 if input_len - req_obj.cur_kv_len > 1: - page_size = 
self.args.nixl_pd_kv_page_size - req_obj.nixl_trans_kv_start_index = req_obj.cur_kv_len + page_size = self.args.pd_kv_page_size + req_obj.pd_trans_kv_start_index = req_obj.cur_kv_len need_mem_size = input_len - req_obj.cur_kv_len if need_mem_size > 0: @@ -147,13 +127,13 @@ def _decode_node_gen_trans_tasks(self, req_obj: InferReq): req_obj.req_idx, req_obj.cur_kv_len : (req_obj.cur_kv_len + need_mem_size) ] = mem_indexes - while req_obj.nixl_trans_kv_start_index < input_len: - cur_page_size = min(page_size, input_len - req_obj.nixl_trans_kv_start_index) + while req_obj.pd_trans_kv_start_index < input_len: + cur_page_size = min(page_size, input_len - req_obj.pd_trans_kv_start_index) # 生成页面传输任务, 放入kv move manager 的处理队列中 - start_index = req_obj.nixl_trans_kv_start_index - end_index = req_obj.nixl_trans_kv_start_index + cur_page_size + start_index = req_obj.pd_trans_kv_start_index + end_index = req_obj.pd_trans_kv_start_index + cur_page_size page_mem_indexes = mem_indexes[start_index - req_obj.cur_kv_len : end_index - req_obj.cur_kv_len] - self._create_nixl_trans_task( + self._create_pd_trans_task( req_obj=req_obj, mem_indexes=page_mem_indexes.tolist(), kv_start_index=start_index, @@ -161,16 +141,27 @@ def _decode_node_gen_trans_tasks(self, req_obj: InferReq): group=group, ) # update - req_obj.nixl_trans_kv_start_index += cur_page_size + req_obj.pd_trans_kv_start_index += cur_page_size req_obj.cur_kv_len += len(mem_indexes) + + # 如果当前是linear att 混合模型,则需要创建一个linear att 状态的传输任务 + if g_infer_context.is_linear_att_mixed_model: + self._create_pd_trans_task( + req_obj=req_obj, + mem_indexes=[], + kv_start_index=input_len, + kv_end_index=input_len, + group=group, + page_kind="linear_att_state", + ) else: assert req_obj.cur_kv_len == input_len - 1 if not group.task_list: # 需要上报一个包含 0 长度的trans task,触发 kv move manager 给 pd master 上报 # upkv_status 状态,使推理流程完整。 - self._create_nixl_trans_task( + self._create_pd_trans_task( req_obj=req_obj, mem_indexes=[], 
kv_start_index=req_obj.cur_kv_len, @@ -182,29 +173,40 @@ def _decode_node_gen_trans_tasks(self, req_obj: InferReq): self.info_queue.put(group) return - def _create_nixl_trans_task( + def _create_pd_trans_task( self, req_obj: InferReq, mem_indexes: List[int], kv_start_index: int, kv_end_index: int, - group: NIXLChunckedTransTaskGroup, + group: PDChunckedTransTaskGroup, + page_kind: str = "kv", ): # 确定传输设备 - if req_obj.nixl_trans_device_id == -1: + if req_obj.pd_trans_device_id == -1: + if not hasattr(self, "pd_iter_device_id"): + self.pd_iter_device_id = 0 + req_obj.pd_trans_device_id = self.pd_iter_device_id # only self.is_master_in_dp will be used. - req_obj.nixl_trans_device_id = random.randint(0, self.node_world_size - 1) + self.pd_iter_device_id = (self.pd_iter_device_id + 1) % self.node_world_size + + if page_kind == "kv": + req_idx = None + elif page_kind == "linear_att_state": + req_idx = req_obj.req_idx + else: + raise ValueError(f"unknown PD trans page kind {page_kind}") - trans_task = NIXLChunckedTransTask( + trans_task = PDChunckedTransTask( request_id=req_obj.req_id, start_kv_index=kv_start_index, end_kv_index=kv_end_index, - time_out_secs=80, + time_out_secs=180, pd_master_node_id=req_obj.sampling_param.pd_master_node_id, prefill_dp_index=None, decode_dp_index=self.dp_rank_in_node, src_device_id=None, - dst_device_id=req_obj.nixl_trans_device_id, + dst_device_id=req_obj.pd_trans_device_id, mem_indexes=mem_indexes, prefill_agent_name=None, prefill_agent_metadata=None, @@ -216,7 +218,9 @@ def _create_nixl_trans_task( decode_page_reg_desc=None, first_gen_token_id=None, first_gen_token_logprob=None, + page_kind=page_kind, + req_idx=req_idx, ) group.task_list.append(trans_task) - req_obj.nixl_pd_task_num += 1 + req_obj.pd_task_num += 1 return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_impl_for_dp.py similarity index 
64% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py rename to lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_impl_for_dp.py index 8bf0dd7c51..87af300003 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_impl_for_dp.py @@ -3,12 +3,12 @@ from lightllm.utils.log_utils import init_logger from typing import List, Tuple from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend -from .decode_impl import NIXLDecodeNode, NIXLChunckedTransTaskGroup +from .decode_impl import PDDecodeNode, PDChunckedTransTaskGroup logger = init_logger(__name__) -class NIXLDPForDecodeNode(DPChunkedPrefillBackend): +class PDDPForDecodeNode(DPChunkedPrefillBackend): def __init__(self, info_queue: mp.Queue) -> None: super().__init__() self.info_queue: mp.Queue = info_queue @@ -16,33 +16,35 @@ def __init__(self, info_queue: mp.Queue) -> None: return def init_custom(self): - return NIXLDecodeNode.init_custom(self) + return PDDecodeNode.init_custom(self) def _init_reqs(self, reqs: List[Tuple]): - return NIXLDecodeNode._init_reqs(self, reqs=reqs) + return PDDecodeNode._init_reqs(self, reqs=reqs) def _post_init_reqs(self, uninit_reqs: List[InferReq]): - return NIXLDecodeNode._post_init_reqs(self, uninit_reqs=uninit_reqs) + return PDDecodeNode._post_init_reqs(self, uninit_reqs=uninit_reqs) def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: - return NIXLDecodeNode._filter_not_ready_reqs(self, req_ids=req_ids) + return PDDecodeNode._filter_not_ready_reqs(self, req_ids=req_ids) def _decode_node_gen_trans_tasks(self, req_obj: InferReq): - return NIXLDecodeNode._decode_node_gen_trans_tasks(self, req_obj=req_obj) + return PDDecodeNode._decode_node_gen_trans_tasks(self, req_obj=req_obj) - def _create_nixl_trans_task( + def 
_create_pd_trans_task( self, req_obj: InferReq, mem_indexes: List[int], kv_start_index: int, kv_end_index: int, - group: NIXLChunckedTransTaskGroup, + group: PDChunckedTransTaskGroup, + page_kind: str = "kv", ): - return NIXLDecodeNode._create_nixl_trans_task( + return PDDecodeNode._create_pd_trans_task( self, req_obj=req_obj, mem_indexes=mem_indexes, kv_start_index=kv_start_index, kv_end_index=kv_end_index, group=group, + page_kind=page_kind, ) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_kv_move_manager.py similarity index 84% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py rename to lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_kv_move_manager.py index 877c5c12db..41bdcd361c 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_kv_move_manager.py @@ -1,15 +1,17 @@ import inspect import pickle +import setproctitle import torch.multiprocessing as mp import time from typing import List, Dict, Optional, Tuple, Union, Callable from lightllm.utils.log_utils import init_logger -from lightllm.server.pd_io_struct import NIXLChunckedTransTaskGroup, NIXLAbortReq +from lightllm.server.pd_io_struct import PDChunckedTransTaskGroup, PDAbortReq from lightllm.server.core.objs import StartArgs from lightllm.utils.graceful_utils import graceful_registry from ..trans_process_obj import KVTransProcess from ..base_kv_move_manager import BaseKVMoveManager from lightllm.utils.error_utils import log_exception +from lightllm.utils.envs_utils import get_unique_server_name logger = init_logger(__name__) @@ -29,6 +31,7 @@ def _init_env(args, info_queue: mp.Queue, event: mp.Event): # 注册graceful 退出的处理 
graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::decode_kv_move_manager") from .up_status import start_up_kv_status_process @@ -73,11 +76,11 @@ def __init__( def task_dispatcher_loop(self): # 获取任务,并分发给相关卡的处理队列 while True: - task_group: Union[NIXLChunckedTransTaskGroup, NIXLAbortReq] = self.info_queue.get() + task_group: Union[PDChunckedTransTaskGroup, PDAbortReq] = self.info_queue.get() - if isinstance(task_group, NIXLChunckedTransTaskGroup): + if isinstance(task_group, PDChunckedTransTaskGroup): device_id = task_group.task_list[0].dst_device_id - elif isinstance(task_group, NIXLAbortReq): + elif isinstance(task_group, PDAbortReq): device_id = task_group.device_id else: assert False, f"error obj {task_group}" @@ -85,7 +88,7 @@ def task_dispatcher_loop(self): try: trans_process: KVTransProcess = self.kv_trans_processes[device_id] trans_process.task_in_queue.put(task_group) - if isinstance(task_group, NIXLChunckedTransTaskGroup): + if isinstance(task_group, PDChunckedTransTaskGroup): logger.info( f"kv move manager dispatch task group {task_group.task_list[0].to_str()} to device {device_id}" ) diff --git a/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_trans_process.py new file mode 100644 index 0000000000..036c6f162b --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_trans_process.py @@ -0,0 +1,452 @@ +import torch +import time +import inspect +import threading +import setproctitle +import torch.multiprocessing as mp +import queue +import pickle +from typing import List, Dict, Union, Optional +from lightllm.utils.log_utils import init_logger +from lightllm.common.kv_cache_mem_manager import MemoryManager +from lightllm.server.pd_io_struct import ( + PDChunckedTransTask, + PDChunckedTransTaskGroup, + PDUpKVStatus, + PDAbortReq, +) 
def start_decode_trans_process(
    args,
    device_id,
    task_in_queue: mp.Queue,
    task_out_queue: mp.Queue,
    up_status_in_queue: Optional[mp.SimpleQueue],
):
    """Spawn the decode-side KV transfer worker process for one device.

    The child process runs ``_init_env`` with exactly the arguments given here.

    Args:
        args: start arguments forwarded to the child process.
        device_id: index of the device the child process is started for.
        task_in_queue: queue forwarded to the child as its task input queue.
        task_out_queue: queue forwarded to the child as its task output queue.
        up_status_in_queue: optional queue forwarded to the child for status uploads.

    Returns:
        The started ``mp.Process`` object.

    Raises:
        AssertionError: if the child is not alive right after ``start()``.
    """
    proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, up_status_in_queue))
    proc.start()
    assert proc.is_alive()
    # This launcher starts the decode trans process; the message used to say "prefill",
    # which made decode-node logs indistinguishable from prefill-node logs.
    logger.info(f"decode trans kv process for device: {device_id} started!")
    return proc
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::decode_trans:Device{device_id}") + + try: + torch.cuda.set_device(device_id) + graceful_registry(inspect.currentframe().f_code.co_name) + + task_out_queue.put("proc_start") + + # 从共享内存读取所有rank的mem_manager + node_world_size = args.tp // args.nnodes + mem_managers: List[MemoryManager] = [ + MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) + ] + + task_out_queue.put("get_mem_managers_ok") + + manager = _DecodeTransModule( + args=args, + device_id=device_id, + task_in_queue=task_in_queue, + task_out_queue=task_out_queue, + mem_managers=mem_managers, + up_status_in_queue=up_status_in_queue, + ) + assert manager is not None + + while True: + time.sleep(100) + + except Exception as e: + logger.exception(str(e)) + logger.error(f"Fatal error happened in kv trans process: {e}") + pass + + +class _DecodeTransModule: + def __init__( + self, + args: StartArgs, + device_id: int, + task_in_queue: mp.Queue, + task_out_queue: mp.Queue, + mem_managers: List[MemoryManager], + up_status_in_queue: Optional[mp.SimpleQueue], + ): + self.args = args + self.dp_world_size = self.args.tp // self.args.dp + self.device_id = device_id + self.task_in_queue = task_in_queue + self.task_out_queue = task_out_queue + self.mem_managers = mem_managers + self.up_status_in_queue = up_status_in_queue + cur_mem_manager: MemoryManager = self.mem_managers[device_id] + kv_move_buffer = cur_mem_manager.alloc_paged_kv_move_buffer( + page_num=self.args.pd_kv_page_num, page_size=self.args.pd_kv_page_size + ) + self.copy_cuda_stream = torch.cuda.Stream(priority=-1) + self.transporter = create_kv_transporter( + args=self.args, + node_id=self.args.pd_node_id, + tp_idx=device_id, + kv_move_buffer=kv_move_buffer, + ) + self.recv_task_group_queue = queue.Queue() + self.waiting_dict_lock = threading.Lock() + self.waiting_dict: Dict[str, PDChunckedTransTask] = {} + self.request_page_task_queue = queue.Queue() + 
def _warmup(self):
    """Issue one dummy page copy per local dp index before real transfers start.

    Each copy reads page 0 of the kv move buffer into mem index 0 on
    ``self.copy_cuda_stream``, so that the first real copy does not stall on
    one-time setup. The method returns only after those copies have finished.
    """
    cur_mem = self.mem_managers[self.device_id]
    for dp_index in range(self.args.dp // self.args.nnodes):
        with torch.cuda.stream(stream=self.copy_cuda_stream):
            cur_mem.read_page_kv_move_buffer_to_mem(
                mem_indexes=[0],
                page_index=0,
                dp_index=dp_index,
                mem_managers=self.mem_managers,
                dp_world_size=self.dp_world_size,
            )
    # Synchronize the stream the copies were queued on. The previous
    # torch.cuda.current_stream().synchronize() depends on which stream is
    # current at that point, and outside the stream context that is not the
    # copy stream, so the warm-up copies could still be in flight on return.
    self.copy_cuda_stream.synchronize()
    return
self.waiting_dict_lock: + for task in trans_task_group.task_list: + if task.need_transfer_page(): + self.waiting_dict[task.get_key()] = task + else: + task.start_trans_time = time.time() + self.success_queue.put((None, None, task)) + + # up status + task = trans_task_group.task_list[0] + + decode_node_info = PDDecodeNodeInfo( + decode_node_id=self.args.pd_node_id, + pd_master_node_id=task.pd_master_node_id, + agent_name=self.transporter.agent_name, + agent_metadata=self.transporter.agent_metadata, + num_pages=self.transporter.num_pages, + page_reg_desc=self.transporter.local_page_mem_desc, + request_id=task.request_id, + ready_kv_len=task.start_kv_index, + ) + + up_status = PDUpKVStatus( + group_request_id=task.request_id, + pd_master_node_id=task.pd_master_node_id, + pd_kv_trans_params=pickle.dumps(decode_node_info), + ) + + self.up_status_in_queue.put(up_status) + + @log_exception + def accept_peer_task_loop( + self, + ): + torch.cuda.set_device(self.device_id) + while True: + # notify update + try: + notifies_dict = self.transporter.get_new_notifs() + except BaseException as e: + logger.error(f"get new notifies failed: {str(e)}") + logger.exception(str(e)) + notifies_dict = {} + + if notifies_dict: + for remote_agent_name, _notify_list in notifies_dict.items(): + for notify in _notify_list: + try: + notify_obj = pickle.loads(notify) + except: + notify_obj = None + + if not isinstance(notify_obj, PDChunckedTransTask): + continue + + # 请求有错误 + if notify_obj.error_info is not None: + # 直接清理掉所有的相关请求。 + with self.waiting_dict_lock: + local_trans_task = self.waiting_dict.pop(notify_obj.get_key(), None) + if local_trans_task is not None: + local_trans_task.error_info = notify_obj.error_info + # 软性的调整超时时间,防止一些特殊情况,过快的释放task + # 占用的page 页面,导致多p 复写引起脏内容的问题。 + local_trans_task.transfer_time_out_secs = 12 + self.failed_queue.put(local_trans_task) + + self._abort( + request_id=notify_obj.request_id, + error_info=notify_obj.error_info, + ) + continue + + # 到了请求页面的阶段 + 
remote_trans_task = notify_obj + if remote_trans_task.write_stage == "request": + with self.waiting_dict_lock: + local_trans_task = self.waiting_dict.pop(remote_trans_task.get_key(), None) + if local_trans_task is not None: + local_trans_task.prefill_agent_name = remote_trans_task.prefill_agent_name + local_trans_task.prefill_agent_metadata = remote_trans_task.prefill_agent_metadata + local_trans_task.prefill_num_pages = remote_trans_task.prefill_num_pages + local_trans_task.prefill_page_reg_desc = remote_trans_task.prefill_page_reg_desc + self.request_page_task_queue.put(local_trans_task) + logger.info(f"recv WRITE request from prefill: {remote_trans_task.to_str()}") + else: + # This does not necessarily mean the WRITE protocol state is corrupted. + # A common benign case is: decode has already received an abort for this + # request and removed its waiting task, while prefill's NIXL WRITE request + # notify arrives later. Keep the original cleanup/error path so true + # missing-task bugs are still visible, but make the log explicit enough + # to avoid misclassifying abort-after-cleanup as a transfer failure. 
+ logger.warning( + "can not find waiting WRITE task for request notify, " + "possibly because request was already aborted and cleaned on decode side: " + f"{remote_trans_task.to_str()}" + ) + # 发一个error信息回去给 prefill 节点,让其可以知道这边有问题了,它可以选择其他清理掉请求。 + remote_trans_task.error_info = "can not find waiting WRITE task for request notify" + self.transporter.send_error_info_to_prefill_node(trans_task=remote_trans_task) + + continue + + # prefill 写完数据到了 done 阶段 + if remote_trans_task.write_stage == "done": + with self.waiting_dict_lock: + local_trans_task = self.waiting_dict.pop(remote_trans_task.get_key(), None) + if local_trans_task is not None: + local_trans_task.first_gen_token_id = remote_trans_task.first_gen_token_id + local_trans_task.first_gen_token_logprob = remote_trans_task.first_gen_token_logprob + self.ready_page_task_queue.put(local_trans_task) + logger.info(f"recv WRITE done from prefill: {remote_trans_task.to_str()}") + else: + # Same race as the WRITE request stage: decode may have cleaned the + # waiting task because the request was aborted, then a late done notify + # arrives from prefill. Preserve the original error path, but make the + # diagnostic tell future readers this can be abort-related noise. 
+ logger.warning( + "can not find waiting WRITE task for done notify, " + "possibly because request was already aborted and cleaned on decode side: " + f"{remote_trans_task.to_str()}" + ) + # 发一个error信息回去给 prefill 节点,让其可以知道这边有问题了,它可以选择其他清理掉请求。 + remote_trans_task.error_info = "can not find waiting WRITE task for done notify" + self.transporter.send_error_info_to_prefill_node(trans_task=remote_trans_task) + continue + + self._check_tasks_time_out() + if not notifies_dict: + time.sleep(0.001) + + def _check_tasks_time_out(self): + with self.waiting_dict_lock: + timeout_tasks = [] + for key, trans_task in list(self.waiting_dict.items()): + if trans_task.time_out(): + timeout_tasks.append(self.waiting_dict.pop(key)) + + for trans_task in timeout_tasks: + trans_task.error_info = "time out in accept_peer_task_loop" + self.failed_queue.put(trans_task) + return + + @log_exception + def request_page_loop(self): + torch.cuda.set_device(self.device_id) + while True: + dst_page_index = self.page_index_queue.get() + trans_task: PDChunckedTransTask = self.request_page_task_queue.get() + trans_task.dst_page_index = dst_page_index + trans_task.start_trans_time = time.time() + key = trans_task.get_key() + try: + with self.waiting_dict_lock: + self.waiting_dict[key] = trans_task + self.transporter.send_write_ready_task_to_prefill_node(trans_task=trans_task) + except BaseException as e: + with self.waiting_dict_lock: + self.waiting_dict.pop(key, None) + logger.error(f"send write ready task to prefill node failed: {trans_task.to_str()}") + logger.exception(str(e)) + self.transporter.remove_remote_agent(peer_name=trans_task.prefill_agent_name) + trans_task.error_info = f"send write ready task to prefill node failed: {str(e)}" + self.failed_queue.put(trans_task) + continue + + return + + @log_exception + def read_page_to_mems_loop(self): + torch.cuda.set_device(self.device_id) + while True: + trans_task: PDChunckedTransTask = self.ready_page_task_queue.get() + copy_start_event = 
@log_exception
def success_loop(self):
    """Finalize successfully transferred tasks, forever.

    Each item taken from ``self.success_queue`` is a
    ``(copy_end_event, copy_start_event, trans_task)`` tuple; both events are
    ``None`` for tasks that were queued without a page copy. For every item the
    loop waits for the copy (if any), returns the destination page index to
    ``self.page_index_queue``, releases the transfer handle, and puts the
    task's result object on ``self.task_out_queue``.
    """
    torch.cuda.set_device(self.device_id)
    while True:
        end_evt, start_evt, done_task = self.success_queue.get()

        # -1.0 marks "no GPU copy was timed for this task".
        if end_evt is None:
            gpu_copy_ms = -1.0
        else:
            end_evt.synchronize()
            gpu_copy_ms = start_evt.elapsed_time(end_evt)

        # Give the staging page back so another transfer can use it.
        page_index = done_task.dst_page_index
        if page_index is not None:
            self.page_index_queue.put(page_index)

        handle = done_task.xfer_handle
        if handle is not None:
            self.transporter.release_xfer_handle(handle)

        ret = done_task.createRetObj()
        self.task_out_queue.put(ret)

        if done_task.start_trans_time is None:
            logger.info(f"trans task ret success:{ret}")
        else:
            logger.info(
                f"trans task ret success:{ret} cost time: {done_task.transfer_time()} s "
                f"read_page_gpu_time: {gpu_copy_ms:.3f} ms"
            )
None: + self.transporter.release_xfer_handle(trans_task.xfer_handle) + + ret = trans_task.createRetObj() + self.task_out_queue.put(ret) + logger.info(f"trans task ret fail:{ret}") + + if trans_task.error_info is not None: + # 提前终结所有有问题的属于同一个请求的任务。 + self._abort( + request_id=trans_task.request_id, + error_info=trans_task.error_info, + ) + self.transporter.send_error_info_to_prefill_node(trans_task=trans_task) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/up_status.py similarity index 90% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py rename to lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/up_status.py index f79fb4ea2c..bc1d00f384 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/up_status.py @@ -5,14 +5,16 @@ import websockets import inspect import pickle +import setproctitle -from typing import Dict, Union +from typing import Dict from dataclasses import asdict -from lightllm.server.pd_io_struct import UpKVStatus, NixlUpKVStatus +from lightllm.server.pd_io_struct import PDUpKVStatus from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.pd_io_struct import PD_Master_Obj import torch.multiprocessing as mp +from lightllm.utils.envs_utils import get_unique_server_name logger = init_logger(__name__) @@ -20,7 +22,7 @@ class UpStatusManager: def __init__(self, args, task_in_queue: mp.SimpleQueue): self.args = args - self.task_queue: mp.SimpleQueue[Union[UpKVStatus, NixlUpKVStatus]] = task_in_queue + self.task_queue: mp.SimpleQueue[PDUpKVStatus] = task_in_queue self.daemon_thread = threading.Thread(target=self.thread_loop, daemon=True) self.daemon_thread.start() @@ -64,7 +66,7 @@ 
def create_kv_transporter(args: StartArgs, node_id: int, tp_idx: int, kv_move_buffer: Tensor):
    """Build the PD KV transporter selected by ``LIGHTLLM_PD_KV_TRANSPORT_BACKEND``.

    The environment variable is read case-insensitively and defaults to
    ``"nixl"``. For ``"nccl"`` the control port window is 100 ports wide,
    starts at ``_NCCL_CONTROL_PORT_MIN + tp_idx * 100`` and is capped at
    ``_NCCL_CONTROL_PORT_MAX``; when the window would start beyond the maximum,
    the whole ``[_NCCL_CONTROL_PORT_MIN, _NCCL_CONTROL_PORT_MAX]`` range is used.

    Args:
        args: start arguments; ``args.host`` is the fallback host ip for NCCL.
        node_id: id of this pd node, forwarded to the transporter.
        tp_idx: tp index of this process, forwarded to the transporter.
        kv_move_buffer: paged kv staging buffer, forwarded to the transporter.

    Returns:
        A ``NixlKVTransporter`` or ``NcclKVTransporter`` instance.

    Raises:
        ValueError: if the configured backend is neither ``nixl`` nor ``nccl``.
    """
    backend = os.getenv("LIGHTLLM_PD_KV_TRANSPORT_BACKEND", "nixl").lower()
    if backend not in ("nixl", "nccl"):
        raise ValueError(f"unsupported LIGHTLLM_PD_KV_TRANSPORT_BACKEND={backend}")

    if backend == "nixl":
        from .nixl_kv_transporter import NixlKVTransporter

        return NixlKVTransporter(node_id=node_id, tp_idx=tp_idx, kv_move_buffer=kv_move_buffer)

    from .nccl_kv_transporter import NcclKVTransporter

    logger.info("Use NCCL as pd KV transporter backend")
    window_start = _NCCL_CONTROL_PORT_MIN + tp_idx * 100
    if window_start > _NCCL_CONTROL_PORT_MAX:
        # No dedicated window left for this tp_idx: share the full range.
        port_min, port_max = _NCCL_CONTROL_PORT_MIN, _NCCL_CONTROL_PORT_MAX
    else:
        port_min = window_start
        port_max = min(_NCCL_CONTROL_PORT_MAX, window_start + 99)

    return NcclKVTransporter(
        node_id=node_id,
        tp_idx=tp_idx,
        kv_move_buffer=kv_move_buffer,
        host_ip=get_hostname_ip() or args.host,
        control_port_min=port_min,
        control_port_max=port_max,
    )
def get_new_notifs(self) -> Dict[str, List[bytes]]:
    """Return the pending control notifications grouped by sender agent name.

    Each raw notify from ``self.control_channel.get_notifs()`` is attributed to
    a sender via ``self._get_notify_source_agent_name``. A notify whose sender
    cannot be determined (that helper unpickles the payload and asserts on its
    content) is logged and skipped instead of aborting the whole call.

    Returns:
        Mapping of sender agent name to the raw notify payloads it sent, in
        arrival order.
    """
    notifs: Dict[str, List[bytes]] = {}
    for notify in self.control_channel.get_notifs():
        try:
            source_agent_name = self._get_notify_source_agent_name(notify)
        except Exception as e:
            # One bad message must not discard the valid ones returned in the
            # same batch. NOTE(review): presumably get_notifs() hands out each
            # message only once, so raising here would lose them; confirm.
            logger.error(f"drop NCCL control notify with unknown source: {str(e)}")
            continue
        notifs.setdefault(source_agent_name, []).append(notify)
    return notifs
def send_error_info_to_prefill_node(self, trans_task: PDChunckedTransTask):
    """Send an ``"error"`` stage notification for ``trans_task`` to its prefill agent.

    Nothing is sent when ``trans_task.prefill_agent_name`` is ``None``.
    Otherwise a copy made by ``self._copy_notify_task`` is stamped with this
    transporter's agent name, metadata, page count and page description in the
    ``decode_*`` fields and handed to ``self._send_task_notif``.
    """
    target_agent = trans_task.prefill_agent_name
    if target_agent is not None:
        error_notify = self._copy_notify_task(trans_task)
        error_notify.write_stage = "error"
        # Identify this (decode) side so the receiver knows who reported the error.
        error_notify.decode_agent_name = self.agent_name
        error_notify.decode_agent_metadata = self.agent_metadata
        error_notify.decode_num_pages = self.num_pages
        error_notify.decode_page_reg_desc = self.local_page_mem_desc
        self._send_task_notif(target_agent, error_notify)
    return
def shutdown(self):
    """Close every peer, forget all remote agents and close the control channel.

    The peer table is emptied under ``self._peer_lock``; the peers themselves
    are closed afterwards, in insertion order, followed by the control channel.
    """
    with self._peer_lock:
        detached_peers = [*self._peers.values()]
        self._peers.clear()

    for open_peer in detached_peers:
        open_peer.close()

    self.remote_agents.clear()
    self.control_channel.close()
    return
if self.is_prefill_node: + assert notify_obj.decode_agent_name is not None + return notify_obj.decode_agent_name + else: + assert notify_obj.prefill_agent_name is not None + return notify_obj.prefill_agent_name + + if notify_obj.write_stage == "request": + assert notify_obj.prefill_agent_name is not None + return notify_obj.prefill_agent_name + + if notify_obj.write_stage in ["ready", "done"]: + assert notify_obj.decode_agent_name is not None + return notify_obj.decode_agent_name + + raise AssertionError(f"unexpected notify stage: {notify_obj.write_stage}") + + +@dataclass +class _NcclXferHandle: + peer_name: str + event: torch.cuda.Event + status: str = "PROC" + error_info: Optional[str] = None + + def check_status(self) -> str: + if self.status != "PROC": + return self.status + + try: + if self.event.query(): + self.status = "DONE" + except BaseException as e: + self.status = "ERR" + self.error_info = str(e) + return self.status + + +class _NcclPeer: + def __init__(self, transporter: NcclKVTransporter, peer_name: str): + self.transporter = transporter + self.peer_name = peer_name + self.comm: Optional[PyNcclCommunicator] = None + self.stream: Optional[torch.cuda.Stream] = None + self.recv_queue: Optional["queue.Queue[Optional[PDChunckedTransTask]]"] = None + self._lock = threading.Lock() + + def send_page(self, trans_task: PDChunckedTransTask) -> _NcclXferHandle: + assert trans_task.src_page_index is not None and trans_task.dst_page_index is not None + page_tensor = self.transporter.kv_move_buffer[trans_task.src_page_index] + comm = self._ensure_comm(is_server=True) + stream = self._get_stream() + + comm.send(page_tensor, dst=1, stream=stream) + event = torch.cuda.Event() + event.record(stream) + + logger.info( + f"NCCL send page posted request_id={trans_task.request_id} " + f"src_page={trans_task.src_page_index} dst_agent={self.peer_name}" + ) + return _NcclXferHandle(peer_name=self.peer_name, event=event) + + def start_recv(self, trans_task: 
PDChunckedTransTask): + self._get_recv_queue().put(copy.copy(trans_task)) + return + + def close(self): + with self._lock: + recv_queue = self.recv_queue + self.recv_queue = None + comm = self.comm + self.comm = None + self.stream = None + + if recv_queue is not None: + recv_queue.put(None) + if comm is not None: + comm.destroy() + return + + def _get_stream(self) -> torch.cuda.Stream: + with self._lock: + if self.stream is None: + torch.cuda.set_device(self.transporter.tp_idx) + self.stream = torch.cuda.Stream() + return self.stream + + def _get_recv_queue(self) -> "queue.Queue[Optional[PDChunckedTransTask]]": + with self._lock: + if self.recv_queue is not None: + return self.recv_queue + + self.recv_queue = queue.Queue() + threading.Thread(target=self._recv_page_loop, args=(self.recv_queue,), daemon=True).start() + return self.recv_queue + + def _recv_page_loop(self, recv_queue: "queue.Queue[Optional[PDChunckedTransTask]]"): + torch.cuda.set_device(self.transporter.tp_idx) + while True: + trans_task = recv_queue.get() + if trans_task is None: + return + self._recv_page(trans_task) + + def _recv_page(self, trans_task: PDChunckedTransTask): + try: + page_tensor = self.transporter.kv_move_buffer[trans_task.dst_page_index] + comm = self._ensure_comm(is_server=False) + stream = self._get_stream() + comm.recv(page_tensor, src=0, stream=stream) + logger.info( + f"NCCL recv page done request_id={trans_task.request_id} " f"dst_page={trans_task.dst_page_index}" + ) + except BaseException as e: + trans_task.error_info = str(e) + logger.exception(str(e)) + self._drop_comm() + self.transporter.send_error_info_to_prefill_node(trans_task) + return + + def _ensure_comm(self, is_server: bool) -> PyNcclCommunicator: + with self._lock: + if self.comm is not None: + return self.comm + + if is_server: + src_id = self.transporter.agent_name + dest_id = self.peer_name + else: + src_id = self.peer_name + dest_id = self.transporter.agent_name + + group = StatelessP2PProcessGroup.create( 
+ src_id=src_id, + dest_id=dest_id, + is_server=is_server, + store=_NcclControlStore(self.transporter, self.peer_name), + ) + self.comm = PyNcclCommunicator(group, self.transporter.tp_idx) + logger.info(f"Created NCCL communicator with peer {self.peer_name}") + return self.comm + + def _drop_comm(self): + with self._lock: + comm = self.comm + self.comm = None + + if comm is not None: + comm.destroy() + logger.warning(f"Dropped NCCL communicator with peer {self.peer_name}") + return + + +class _NcclControlService(rpyc.Service): + def __init__(self, channel: "_NcclControlChannel"): + super().__init__() + self.channel = channel + + def exposed_push_notif(self, payload: bytes): + payload = obtain(payload) + self.channel.notif_queue.put(payload) + return + + def exposed_set_value(self, key: str, value: bytes): + key = obtain(key) + value = obtain(value) + self.channel.add_store_value(key, value) + return + + +class _NcclControlChannel: + def __init__( + self, + host_ip: str, + port_min: int, + port_max: int, + ): + self.notif_queue: "queue.Queue[bytes]" = queue.Queue() + self._store_values: Dict[str, bytes] = {} + self._store_cond = threading.Condition() + self._conn_lock = threading.Lock() + self._conns: Dict[tuple[str, str, int], rpyc.Connection] = {} + self._server, self.port = self._start_server(host_ip, port_min, port_max) + + def _start_server(self, host_ip: str, port_min: int, port_max: int) -> tuple[ThreadedServer, int]: + last_error = None + for cur_port in range(port_min, port_max + 1): + try: + server = ThreadedServer( + _NcclControlService(self), + hostname=host_ip, + port=cur_port, + protocol_config={ + "allow_pickle": True, + "allow_all_attrs": True, + "allow_getattr": True, + "allow_setattr": True, + }, + ) + threading.Thread(target=server.start, daemon=True).start() + logger.info(f"NCCL RPyC control channel listen on {host_ip}:{cur_port}") + return server, cur_port + except OSError as e: + last_error = e + if e.errno == errno.EADDRINUSE: + 
logger.info(f"NCCL RPyC control port {host_ip}:{cur_port} is in use, try next port") + else: + logger.warning(f"Create NCCL RPyC control channel on {host_ip}:{cur_port} failed: {e}") + raise RuntimeError(f"can not allocate NCCL control port in [{port_min}, {port_max}]") from last_error + + def close(self): + with self._conn_lock: + for conn in self._conns.values(): + try: + conn.close() + except Exception: + pass + self._conns.clear() + self._server.close() + return + + def add_store_value(self, key: str, value: bytes): + with self._store_cond: + self._store_values[key] = value + self._store_cond.notify_all() + return + + def wait_store_value(self, key: str, timeout: float = 30.0) -> bytes: + with self._store_cond: + ok = self._store_cond.wait_for(lambda: key in self._store_values, timeout=timeout) + if not ok: + raise TimeoutError(f"wait timeout after {int(timeout * 1000)}ms, key: {key}") + return self._store_values.pop(key) + + def get_notifs(self) -> List[bytes]: + notifs = [] + while True: + try: + notifs.append(self.notif_queue.get_nowait()) + except queue.Empty: + break + return notifs + + def send_notif(self, peer_name: str, host_ip: str, port: int, payload: bytes): + self._call(peer_name, host_ip, port, "push_notif", payload) + return + + def send_store_value(self, peer_name: str, host_ip: str, port: int, key: str, value: bytes): + self._call(peer_name, host_ip, port, "set_value", key, value) + return + + def _call(self, peer_name: str, host_ip: str, port: int, method: str, *args): + conn_key = (peer_name, host_ip, port) + with self._conn_lock: + conn = self._conns.get(conn_key) + if conn is None: + conn = rpyc.connect( + host_ip, + port, + config={ + "allow_pickle": True, + "allow_all_attrs": True, + "allow_getattr": True, + "allow_setattr": True, + }, + ) + self._conns[conn_key] = conn + try: + getattr(conn.root, method)(*args) + except Exception as e: + self._conns.pop(conn_key, None) + try: + conn.close() + except Exception: + pass + raise 
RuntimeError(f"NCCL control RPC {method} to {peer_name} failed") from e + return + + +class _NcclControlStore: + def __init__(self, transporter: "NcclKVTransporter", remote_agent_name: str): + self.transporter = transporter + self.remote_agent_name = remote_agent_name + + def set(self, key: str, value: bytes): + remote_metadata = self.transporter._get_remote_metadata(self.remote_agent_name) + self.transporter.control_channel.send_store_value( + self.remote_agent_name, + remote_metadata.host_ip, + remote_metadata.control_port, + self._send_key(key), + bytes(value), + ) + return + + def get(self, key: str) -> bytes: + return self.transporter.control_channel.wait_store_value(self._recv_key(key)) + + def _send_key(self, key: str) -> str: + return f"{self.transporter.agent_name}->{self.remote_agent_name}:{key}" + + def _recv_key(self, key: str) -> str: + return f"{self.remote_agent_name}->{self.transporter.agent_name}:{key}" diff --git a/lightllm/server/router/model_infer/mode_backend/pd/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd/nixl_kv_transporter.py new file mode 100644 index 0000000000..bd5e11f05d --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd/nixl_kv_transporter.py @@ -0,0 +1,288 @@ +import pickle +import copy +import os +import time +from dataclasses import dataclass +from typing import Dict +from torch import Tensor +from lightllm.server.pd_io_struct import PDChunckedTransTask, PDAgentMetadata +from lightllm.utils.log_utils import init_logger + + +logger = init_logger(__name__) + +try: + from nixl._api import nixl_agent as NixlWrapper + from nixl._api import nixlBind + from nixl._api import nixl_agent_config + + logger.info("Nixl is available") +except ImportError: + logger.warning("nixl is not installed, which is required for pd disagreggation!!!") + NixlWrapper = None + + +class NixlKVTransporter: + def __init__(self, node_id: int, tp_idx: int, kv_move_buffer: Tensor): + self.node_id = node_id + 
self.tp_idx = tp_idx + self.capture_telemetry = os.getenv("LIGHTLLM_NIXL_CAPTURE_TELEMETRY", "0").lower() in ( + "1", + "true", + "yes", + "on", + ) + conf = None + if self.capture_telemetry: + conf = nixl_agent_config() + conf.capture_telemetry = True + logger.info("NIXL telemetry enabled") + self.nixl_agent = NixlWrapper(self.agent_name, conf) + self._register_kv_move_buffer(kv_move_buffer=kv_move_buffer) + self.remote_agents: Dict[str, PDAgentMetadata] = {} + return + + @property + def agent_name(self) -> str: + return f"{self.node_id}_{self.tp_idx}" + + @property + def agent_metadata(self): + return self.nixl_agent.get_agent_metadata() + + @property + def local_page_mem_desc(self): + return self.nixl_agent.get_serialized_descs(self.page_reg_desc) + + def get_new_notifs(self) -> Dict[str, list[bytes]]: + return self.nixl_agent.get_new_notifs() + + def _register_kv_move_buffer(self, kv_move_buffer: Tensor): + self.num_pages, self.page_size, self.num_layers, self.kv_head_num, self.head_dims = kv_move_buffer.shape + self.dtype_byte_size = kv_move_buffer.element_size() + self.page_len = self.page_size * self.num_layers * self.kv_head_num * self.head_dims * self.dtype_byte_size + self.page_reg_desc = self.nixl_agent.register_memory(kv_move_buffer) + self.page_local_xfer_handles = self._create_paged_xfer_handles(self.page_reg_desc, self.num_pages) + + def _create_paged_xfer_handles(self, reg_desc: "nixlBind.nixlRegDList", page_num: int, agent_name: str = ""): + base_addr, _, device_id, _ = reg_desc[0] + pages_data = [] + for page_id in range(page_num): + pages_data.append((base_addr + page_id * self.page_len, self.page_len, device_id)) + descs = self.nixl_agent.get_xfer_descs(pages_data, "VRAM") + return self.nixl_agent.prep_xfer_dlist(agent_name, descs, "VRAM") + + def connect_add_remote_agent(self, remote_agent: PDAgentMetadata): + if remote_agent.agent_name in self.remote_agents: + return + + start_time = time.time() + + peer_name = 
self.nixl_agent.add_remote_agent(remote_agent.agent_metadata) + if isinstance(peer_name, bytes): + peer_name = peer_name.decode() + + assert ( + peer_name == remote_agent.agent_name + ), f"Peer name {peer_name} does not match remote name {remote_agent.agent_name}" + + page_mem_desc = self.nixl_agent.deserialize_descs(remote_agent.page_reg_desc) + kv_page_xfer_handles = self._create_paged_xfer_handles( + page_mem_desc, remote_agent.num_pages, agent_name=peer_name + ) + remote_agent.page_xfer_handles = kv_page_xfer_handles + + logger.info( + f"Added remote agent {peer_name} with mem desc {page_mem_desc} cost time: {time.time() - start_time} s" + ) + + self.remote_agents[remote_agent.agent_name] = remote_agent + return + + def remove_remote_agent(self, peer_name: str): + if peer_name in self.remote_agents: + try: + remote_agent: PDAgentMetadata = self.remote_agents.pop(peer_name, None) + assert remote_agent.agent_name == peer_name + self.nixl_agent.remove_remote_agent(remote_agent.agent_name) + if remote_agent.page_xfer_handles is not None: + self.nixl_agent.release_dlist_handle(remote_agent.page_xfer_handles) + except BaseException as e: + logger.error(f"remove remote agent {peer_name} failed") + logger.exception(str(e)) + else: + logger.warning(f"try to remove remote agent, but peer name {peer_name} agent did not exist") + + def send_write_done_task_to_decode_node(self, trans_task: PDChunckedTransTask): + decode_agent_name = trans_task.decode_agent_name + if decode_agent_name not in self.remote_agents: + logger.warning(f"decode_agent_name {decode_agent_name} not exist") + _remote_agent = trans_task.create_decode_agent_obj() + self.connect_add_remote_agent(_remote_agent) + + new_trans_task: PDChunckedTransTask = copy.copy(trans_task) + new_trans_task.write_stage = "done" + new_trans_task.mem_indexes = None + new_trans_task.xfer_handle = None + new_trans_task.decode_agent_metadata = None + new_trans_task.decode_page_reg_desc = None + new_trans_task.prefill_agent_name 
= self.agent_name + new_trans_task.prefill_agent_metadata = self.agent_metadata + new_trans_task.prefill_num_pages = self.num_pages + new_trans_task.prefill_page_reg_desc = self.local_page_mem_desc + self.nixl_agent.send_notif( + remote_agent_name=decode_agent_name, + notif_msg=pickle.dumps(new_trans_task), + ) + return + + def send_write_request_task_to_decode_node(self, trans_task: PDChunckedTransTask): + decode_agent_name = trans_task.decode_agent_name + if decode_agent_name not in self.remote_agents: + logger.warning(f"decode_agent_name {decode_agent_name} not exist") + _remote_agent = trans_task.create_decode_agent_obj() + self.connect_add_remote_agent(_remote_agent) + + new_trans_task: PDChunckedTransTask = copy.copy(trans_task) + new_trans_task.write_stage = "request" + new_trans_task.mem_indexes = None + new_trans_task.xfer_handle = None + new_trans_task.prefill_agent_name = self.agent_name + new_trans_task.prefill_agent_metadata = self.agent_metadata + new_trans_task.prefill_num_pages = self.num_pages + new_trans_task.prefill_page_reg_desc = self.local_page_mem_desc + self.nixl_agent.send_notif( + remote_agent_name=decode_agent_name, + notif_msg=pickle.dumps(new_trans_task), + ) + return + + def send_write_ready_task_to_prefill_node(self, trans_task: PDChunckedTransTask): + prefill_agent_name = trans_task.prefill_agent_name + if prefill_agent_name not in self.remote_agents: + logger.warning(f"prefill_agent_name {prefill_agent_name} not exist") + _remote_agent = trans_task.create_prefill_agent_obj() + self.connect_add_remote_agent(_remote_agent) + + new_trans_task: PDChunckedTransTask = copy.copy(trans_task) + new_trans_task.write_stage = "ready" + new_trans_task.mem_indexes = None + new_trans_task.xfer_handle = None + new_trans_task.decode_agent_name = self.agent_name + new_trans_task.decode_agent_metadata = self.agent_metadata + new_trans_task.decode_num_pages = self.num_pages + new_trans_task.decode_page_reg_desc = self.local_page_mem_desc + 
self.nixl_agent.send_notif( + remote_agent_name=prefill_agent_name, + notif_msg=pickle.dumps(new_trans_task), + ) + return + + def send_error_info_to_prefill_node(self, trans_task: PDChunckedTransTask): + # decode node 主动发送错误信息给 prefill node, 但是只有到达一定阶段的任务才有对端的信息 + # 才能发送 + if trans_task.prefill_agent_name is None: + return + + try: + prefill_agent_name = trans_task.prefill_agent_name + if prefill_agent_name not in self.remote_agents: + logger.warning(f"prefill_agent_name {prefill_agent_name} not exist") + _remote_agent = trans_task.create_prefill_agent_obj() + self.connect_add_remote_agent(_remote_agent) + assert trans_task.error_info is not None + new_trans_task: PDChunckedTransTask = copy.copy(trans_task) + new_trans_task.write_stage = "error" + new_trans_task.mem_indexes = None + new_trans_task.xfer_handle = None + new_trans_task.decode_agent_name = self.agent_name + new_trans_task.decode_agent_metadata = self.agent_metadata + new_trans_task.decode_num_pages = self.num_pages + new_trans_task.decode_page_reg_desc = self.local_page_mem_desc + self.nixl_agent.send_notif( + remote_agent_name=prefill_agent_name, + notif_msg=pickle.dumps(new_trans_task), + ) + except BaseException as e: + logger.error(f"send error info to prefill node failed: {trans_task.to_str()}") + logger.exception(str(e)) + self.remove_remote_agent(peer_name=prefill_agent_name) + return + + def send_error_info_to_decode_node(self, trans_task: PDChunckedTransTask): + try: + decode_agent_name = trans_task.decode_agent_name + if decode_agent_name not in self.remote_agents: + logger.warning(f"decode_agent_name {decode_agent_name} not exist") + _remote_agent = trans_task.create_decode_agent_obj() + self.connect_add_remote_agent(_remote_agent) + assert trans_task.error_info is not None + new_trans_task: PDChunckedTransTask = copy.copy(trans_task) + new_trans_task.write_stage = "error" + new_trans_task.mem_indexes = None + new_trans_task.xfer_handle = None + new_trans_task.prefill_agent_name = 
self.agent_name + new_trans_task.prefill_agent_metadata = self.agent_metadata + new_trans_task.prefill_num_pages = self.num_pages + new_trans_task.prefill_page_reg_desc = self.local_page_mem_desc + self.nixl_agent.send_notif( + remote_agent_name=decode_agent_name, + notif_msg=pickle.dumps(new_trans_task), + ) + except BaseException as e: + logger.error(f"send error info to decode node failed: {trans_task.to_str()}") + logger.exception(str(e)) + self.remove_remote_agent(peer_name=decode_agent_name) + return + + def write_blocks_paged( + self, + trans_task: PDChunckedTransTask, + ) -> int: + """ + prefill node call this function to write kv blocks into decode node pages + """ + decode_agent_name = trans_task.decode_agent_name + if decode_agent_name not in self.remote_agents: + logger.warning(f"decode_agent_name {decode_agent_name} not exist") + _remote_agent = trans_task.create_decode_agent_obj() + self.connect_add_remote_agent(_remote_agent) + + assert trans_task.src_page_index is not None and trans_task.dst_page_index is not None + remote_agent: PDAgentMetadata = self.remote_agents[decode_agent_name] + src_handle = self.page_local_xfer_handles + dst_handle = remote_agent.page_xfer_handles + handle = self.nixl_agent.make_prepped_xfer( + "WRITE", + src_handle, + [trans_task.src_page_index], + dst_handle, + [trans_task.dst_page_index], + b"", + ) + if not handle: + raise RuntimeError(f"make_prepped_xfer failed for task: {trans_task.to_str()}") + + self.nixl_agent.transfer(handle) + + return handle + + def check_task_status(self, trans_task: PDChunckedTransTask) -> str: + assert trans_task.xfer_handle is not None + handle = trans_task.xfer_handle + xfer_state = self.nixl_agent.check_xfer_state(handle) + if xfer_state == "ERR": + logger.warning(f"Transfer failed with trans task {trans_task.to_str()} for handle {handle}") + return xfer_state + + def release_xfer_handle(self, handle): + self.nixl_agent.release_xfer_handle(handle=handle) + return + + def shutdown(self): + 
self.nixl_agent.deregister_memory(self.page_reg_desc) + self.nixl_agent.release_dlist_handle(self.page_local_xfer_handles) + agent_names = list(self.remote_agents.keys()) + for agent_name in agent_names: + self.remove_remote_agent(agent_name) + return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/p2p_fix.py b/lightllm/server/router/model_infer/mode_backend/pd/p2p_fix.py similarity index 98% rename from lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/p2p_fix.py rename to lightllm/server/router/model_infer/mode_backend/pd/p2p_fix.py index 62609c4c91..5a737c2fc6 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/p2p_fix.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/p2p_fix.py @@ -120,7 +120,7 @@ def reduce_tensor(tensor): shared_cache[handle] = StorageWeakRef(storage) # _backward_hooks purposely omitted here, see # Note [Don't serialize hooks] - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import ( + from lightllm.server.router.model_infer.mode_backend.pd.p2p_fix import ( p2p_fix_rebuild_cuda_tensor, ) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/__init__.py b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/__init__.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/__init__.py rename to lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/__init__.py diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl.py similarity index 55% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py rename to lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl.py 
index 6f5a6e17d8..2a501f509b 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl.py @@ -1,8 +1,8 @@ import torch.multiprocessing as mp import random -from typing import List, Tuple, Optional +from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq -from lightllm.server.pd_io_struct import NIXLChunckedTransTask +from lightllm.server.pd_io_struct import PDChunckedTransTask from lightllm.utils.log_utils import init_logger from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.server.router.model_infer.infer_batch import g_infer_context @@ -11,13 +11,13 @@ logger = init_logger(__name__) -class NIXLChunckedPrefillForPrefillNode(ChunkedPrefillBackend): +class PDChunkedPrefillForPrefillNode(ChunkedPrefillBackend): def __init__(self, info_queue: mp.Queue) -> None: super().__init__() self.support_overlap = False self.info_queue: mp.Queue = info_queue self.classed_req_no_decode = True - self.nixl_prefill_chuncked_handle_func = self._prefill_chuncked_handle_func + self.pd_prefill_chunked_handle_func = self._prefill_chuncked_handle_func def init_custom(self): assert kv_trans_use_p2p() @@ -35,13 +35,11 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: prefill_finished = req_obj.shm_req.input_len <= req_obj.cur_kv_len if prefill_finished: # 等待所有传输任务都已经完成。 - if req_obj.nixl_pd_task_num == (req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num): + if req_obj.pd_task_num == (req_obj.pd_task_failed_num + req_obj.pd_task_success_num): ans_list.append(req_obj) else: if req_obj.infer_aborted: - if req_obj.nixl_pd_task_num == ( - req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num - ): + if req_obj.pd_task_num == (req_obj.pd_task_failed_num + req_obj.pd_task_success_num): ans_list.append(req_obj) else: continue @@ -55,30 +53,36 @@ def 
_prefill_chuncked_handle_func( """ 在每一步chuncked prefill 后,尝试生成chuncked 传输任务,发个 kv_move_manager 进行处理。 """ - # 系统内部的 health 请求不创建 kv 传输任务。 - if req_obj.req_id < 0: - return assert req_obj.cur_kv_len <= req_obj.shm_req.input_len input_len = req_obj.shm_req.input_len - page_size = self.args.nixl_pd_kv_page_size + page_size = self.args.pd_kv_page_size prefill_finished = req_obj.cur_kv_len == input_len - trans_task_list: List[NIXLChunckedTransTask] = [] - while req_obj.nixl_trans_kv_start_index < req_obj.cur_kv_len: - cur_page_size = min(page_size, req_obj.cur_kv_len - req_obj.nixl_trans_kv_start_index) + trans_task_list: List[PDChunckedTransTask] = [] + while req_obj.pd_trans_kv_start_index < req_obj.cur_kv_len: + cur_page_size = min(page_size, req_obj.cur_kv_len - req_obj.pd_trans_kv_start_index) # 生成页面传输任务, 放入kv move manager 的处理队列中 if cur_page_size == page_size or prefill_finished: - trans_task = self._create_nixl_trans_task( + trans_task = self._create_pd_trans_task( req_obj=req_obj, - kv_start_index=req_obj.nixl_trans_kv_start_index, - kv_end_index=req_obj.nixl_trans_kv_start_index + cur_page_size, + kv_start_index=req_obj.pd_trans_kv_start_index, + kv_end_index=req_obj.pd_trans_kv_start_index + cur_page_size, ) - req_obj.nixl_trans_kv_start_index += cur_page_size + req_obj.pd_trans_kv_start_index += cur_page_size trans_task_list.append(trans_task) else: break if prefill_finished and len(trans_task_list) != 0 and output_len == 1: + if g_infer_context.is_linear_att_mixed_model: + trans_task_list.append( + self._create_pd_trans_task( + req_obj=req_obj, + kv_start_index=input_len, + kv_end_index=input_len, + page_kind="linear_att_state", + ) + ) trans_task_list[-1].first_gen_token_id = next_token_id trans_task_list[-1].first_gen_token_logprob = next_token_prob @@ -87,41 +91,57 @@ def _prefill_chuncked_handle_func( self.info_queue.put(trans_task) return - def _create_nixl_trans_task( - self, req_obj: InferReq, kv_start_index: int, kv_end_index: int - ) -> 
NIXLChunckedTransTask: + def _create_pd_trans_task( + self, + req_obj: InferReq, + kv_start_index: int, + kv_end_index: int, + page_kind: str = "kv", + ) -> PDChunckedTransTask: # 确定传输设备 - if req_obj.nixl_trans_device_id == -1: - req_obj.nixl_trans_device_id = random.randint(0, self.node_world_size - 1) + if req_obj.pd_trans_device_id == -1: + if not hasattr(self, "pd_iter_device_id"): + self.pd_iter_device_id = 0 + req_obj.pd_trans_device_id = self.pd_iter_device_id + self.pd_iter_device_id = (self.pd_iter_device_id + 1) % self.node_world_size - nixl_decode_node_info = req_obj.sampling_param.nixl_decode_node - mem_indexes = ( - self.model.req_manager.req_to_token_indexs[req_obj.req_idx, kv_start_index:kv_end_index] - .detach() - .cpu() - .tolist() - ) - trans_task = NIXLChunckedTransTask( + pd_decode_node_info = req_obj.sampling_param.pd_decode_node + if page_kind == "kv": + mem_indexes = ( + self.model.req_manager.req_to_token_indexs[req_obj.req_idx, kv_start_index:kv_end_index] + .detach() + .cpu() + .tolist() + ) + req_idx = None + elif page_kind == "linear_att_state": + mem_indexes = [] + req_idx = req_obj.req_idx + else: + raise ValueError(f"unknown PD trans page kind {page_kind}") + trans_task = PDChunckedTransTask( request_id=req_obj.req_id, start_kv_index=kv_start_index, end_kv_index=kv_end_index, - time_out_secs=82, + time_out_secs=182, pd_master_node_id=req_obj.sampling_param.pd_master_node_id, prefill_dp_index=self.dp_rank_in_node, decode_dp_index=None, - src_device_id=req_obj.nixl_trans_device_id, + src_device_id=req_obj.pd_trans_device_id, dst_device_id=None, mem_indexes=mem_indexes, prefill_agent_name=None, prefill_agent_metadata=None, prefill_num_pages=None, prefill_page_reg_desc=None, - decode_agent_name=nixl_decode_node_info.agent_name, - decode_agent_metadata=nixl_decode_node_info.agent_metadata, - decode_num_pages=nixl_decode_node_info.num_pages, - decode_page_reg_desc=nixl_decode_node_info.page_reg_desc, + 
decode_agent_name=pd_decode_node_info.agent_name, + decode_agent_metadata=pd_decode_node_info.agent_metadata, + decode_num_pages=pd_decode_node_info.num_pages, + decode_page_reg_desc=pd_decode_node_info.page_reg_desc, first_gen_token_id=None, first_gen_token_logprob=None, + page_kind=page_kind, + req_idx=req_idx, ) - req_obj.nixl_pd_task_num += 1 + req_obj.pd_task_num += 1 return trans_task diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl_for_dp.py similarity index 54% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py rename to lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl_for_dp.py index eed98399e7..5f4465a730 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl_for_dp.py @@ -1,38 +1,42 @@ import torch.multiprocessing as mp -from typing import List, Tuple, Optional +from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq from lightllm.utils.log_utils import init_logger -from .prefill_impl import NIXLChunckedPrefillForPrefillNode, NIXLChunckedTransTask +from .prefill_impl import PDChunkedPrefillForPrefillNode, PDChunckedTransTask from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend logger = init_logger(__name__) -class NIXLDPChunkedForPrefillNode(DPChunkedPrefillBackend): +class PDDPChunkedForPrefillNode(DPChunkedPrefillBackend): def __init__(self, info_queue: mp.Queue) -> None: super().__init__() self.support_overlap = False self.info_queue: mp.Queue = info_queue self.classed_req_no_decode = True - self.nixl_prefill_chuncked_handle_func = self._prefill_chuncked_handle_func + 
self.pd_prefill_chunked_handle_func = self._prefill_chuncked_handle_func def init_custom(self): - NIXLChunckedPrefillForPrefillNode.init_custom(self) + PDChunkedPrefillForPrefillNode.init_custom(self) return def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: - return NIXLChunckedPrefillForPrefillNode._filter_not_ready_reqs(self, req_ids) + return PDChunkedPrefillForPrefillNode._filter_not_ready_reqs(self, req_ids) def _prefill_chuncked_handle_func( self, req_obj: InferReq, next_token_id: int, next_token_prob: float, output_len: int ): - return NIXLChunckedPrefillForPrefillNode._prefill_chuncked_handle_func( + return PDChunkedPrefillForPrefillNode._prefill_chuncked_handle_func( self, req_obj=req_obj, next_token_id=next_token_id, next_token_prob=next_token_prob, output_len=output_len ) - def _create_nixl_trans_task( - self, req_obj: InferReq, kv_start_index: int, kv_end_index: int - ) -> NIXLChunckedTransTask: - return NIXLChunckedPrefillForPrefillNode._create_nixl_trans_task( - self, req_obj=req_obj, kv_start_index=kv_start_index, kv_end_index=kv_end_index + def _create_pd_trans_task( + self, req_obj: InferReq, kv_start_index: int, kv_end_index: int, page_kind: str = "kv" + ) -> PDChunckedTransTask: + return PDChunkedPrefillForPrefillNode._create_pd_trans_task( + self, + req_obj=req_obj, + kv_start_index=kv_start_index, + kv_end_index=kv_end_index, + page_kind=page_kind, ) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_kv_move_manager.py similarity index 88% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py rename to lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_kv_move_manager.py index ac8026e58e..7bda07d54d 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_kv_move_manager.py @@ -1,14 +1,16 @@ import inspect +import setproctitle import torch.multiprocessing as mp import time from typing import List, Dict, Union, Callable from lightllm.utils.log_utils import init_logger -from lightllm.server.pd_io_struct import NIXLChunckedTransTask +from lightllm.server.pd_io_struct import PDChunckedTransTask from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.core.objs import StartArgs from ..trans_process_obj import KVTransProcess from ..base_kv_move_manager import BaseKVMoveManager from lightllm.utils.error_utils import log_exception +from lightllm.utils.envs_utils import get_unique_server_name logger = init_logger(__name__) @@ -28,6 +30,7 @@ def _init_env(args, info_queue: mp.Queue, event: mp.Event): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::prefill_kv_move_manager") from .prefill_trans_process import start_prefill_trans_process @@ -53,7 +56,7 @@ def __init__(self, args: StartArgs, info_queue: mp.Queue, start_trans_process_fu def task_dispatcher_loop(self): # 获取任务,并分发给相关卡的处理队列 while True: - task: NIXLChunckedTransTask = self.info_queue.get() + task: PDChunckedTransTask = self.info_queue.get() device_id = task.src_device_id try: diff --git a/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_trans_process.py new file mode 100644 index 0000000000..e286a10f96 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_trans_process.py @@ -0,0 +1,414 @@ +import torch +import time +import inspect +import threading +import setproctitle +import torch.multiprocessing as 
mp +import queue +import pickle +from typing import List, Dict, Optional +from lightllm.utils.log_utils import init_logger +from lightllm.common.kv_cache_mem_manager import MemoryManager +from lightllm.server.pd_io_struct import PDChunckedTransTask +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.server.core.objs import StartArgs +from ..kv_transporter import create_kv_transporter +from lightllm.utils.error_utils import log_exception +from lightllm.utils.envs_utils import get_unique_server_name + + +logger = init_logger(__name__) + + +def start_prefill_trans_process( + args, + device_id, + task_in_queue: mp.Queue, + task_out_queue: mp.Queue, + up_status_in_queue: Optional[mp.SimpleQueue] = None, +): + proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue)) + proc.start() + assert proc.is_alive() + logger.info(f"prefill trans kv process for device: {device_id} started!") + return proc + + +def _init_env( + args: StartArgs, + device_id: int, + task_in_queue: mp.Queue, + task_out_queue: mp.Queue, +): + import lightllm.utils.rpyc_fix_utils as _ + + import os + + # prefill source-side page copy and UCX progress are on the request critical path. 
+ os.environ["CUDA_MPS_CLIENT_PRIORITY"] = "0" + + torch.backends.cudnn.enabled = False + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::prefill_trans:Device{device_id}") + + try: + torch.cuda.set_device(device_id) + graceful_registry(inspect.currentframe().f_code.co_name) + task_out_queue.put("proc_start") + + # 从共享内存读取所有rank的mem_manager + node_world_size = args.tp // args.nnodes + mem_managers: List[MemoryManager] = [ + MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) + ] + + task_out_queue.put("get_mem_managers_ok") + + manager = _PrefillTransModule( + args=args, + device_id=device_id, + task_in_queue=task_in_queue, + task_out_queue=task_out_queue, + mem_managers=mem_managers, + ) + assert manager is not None + + while True: + time.sleep(100) + + except Exception as e: + logger.exception(str(e)) + logger.error(f"Fatal error happened in kv trans process: {e}") + pass + + +class _PrefillTransModule: + def __init__( + self, + args: StartArgs, + device_id: int, + task_in_queue: mp.Queue, + task_out_queue: mp.Queue, + mem_managers: List[MemoryManager], + ) -> None: + self.args = args + self.dp_world_size = self.args.tp // self.args.dp + self.device_id = device_id + self.task_in_queue = task_in_queue + self.task_out_queue = task_out_queue + self.mem_managers = mem_managers + + cur_mem_manager: MemoryManager = self.mem_managers[device_id] + kv_move_buffer = cur_mem_manager.alloc_paged_kv_move_buffer( + page_num=self.args.pd_kv_page_num, page_size=self.args.pd_kv_page_size + ) + self.copy_cuda_stream = torch.cuda.Stream(priority=-1) + self.transporter = create_kv_transporter( + args=self.args, + node_id=self.args.pd_node_id, + tp_idx=device_id, + kv_move_buffer=kv_move_buffer, + ) + self.waiting_dict_lock = threading.Lock() + self.waiting_dict: Dict[str, PDChunckedTransTask] = {} + + self.local_copy_kv_queue = queue.Queue() + self.ready_transfer_queue = queue.Queue() + self.write_peer_kv_queue = queue.Queue() + 
self.success_queue = queue.Queue() + self.failed_queue = queue.Queue() + + self.page_index_queue = queue.Queue() + for page_index in range(self.args.pd_kv_page_num): + self.page_index_queue.put(page_index) + + # warmup 预先执行一次 kv 写入 page buffer,避免第一次拷贝时出现卡顿。 + self._warmup() + + for func in [ + self.recv_task_loop, + self.local_copy_kv_loop, + self.ready_transfer_loop, + self.accept_decode_write_task_loop, + self.write_peer_kv_loop, + self.update_task_status_loop, + self.success_loop, + self.fail_loop, + ]: + threading.Thread(target=func, daemon=True).start() + return + + def _warmup(self): + for dp_index in range(self.args.dp // self.args.nnodes): + with torch.cuda.stream(stream=self.copy_cuda_stream): + cur_mem = self.mem_managers[self.device_id] + cur_mem.write_mem_to_page_kv_move_buffer( + mem_indexes=[0], + page_index=0, + dp_index=dp_index, + mem_managers=self.mem_managers, + dp_world_size=self.dp_world_size, + ) + torch.cuda.current_stream().synchronize() + return + + def _abort(self, request_id: int, error_info: str = "aborted req"): + aborted_tasks = [] + with self.waiting_dict_lock: + for key, trans_task in list(self.waiting_dict.items()): + if trans_task.request_id == request_id: + aborted_tasks.append(self.waiting_dict.pop(key)) + + for trans_task in aborted_tasks: + trans_task.error_info = error_info + self.failed_queue.put(trans_task) + return + + @log_exception + def recv_task_loop(self): + torch.cuda.set_device(self.device_id) + + while True: + page_index = self.page_index_queue.get() + trans_task: PDChunckedTransTask = self.task_in_queue.get() + trans_task.src_page_index = page_index + + # 初次校验 time out + if trans_task.time_out(): + trans_task.error_info = "time out in recv_task_loop" + self.failed_queue.put(trans_task) + else: + self.local_copy_kv_queue.put(trans_task) + + @log_exception + def local_copy_kv_loop(self): + torch.cuda.set_device(self.device_id) + while True: + trans_task: PDChunckedTransTask = self.local_copy_kv_queue.get() + + # 将kv 
数据拷贝到 page 上,然后传输给 decode node,让其进行读取。 + with torch.cuda.stream(stream=self.copy_cuda_stream): + cur_mem = self.mem_managers[self.device_id] + cur_mem.write_mem_to_page_kv_move_buffer( + trans_task.mem_indexes, + page_index=trans_task.src_page_index, + dp_index=trans_task.prefill_dp_index, + mem_managers=self.mem_managers, + dp_world_size=self.dp_world_size, + page_kind=trans_task.page_kind, + req_idx=trans_task.req_idx, + ) + sync_event = torch.cuda.Event() + sync_event.record() + + self.ready_transfer_queue.put((sync_event, trans_task)) + return + + @log_exception + def ready_transfer_loop(self): + torch.cuda.set_device(self.device_id) + while True: + sync_event, trans_task = self.ready_transfer_queue.get() + trans_task: PDChunckedTransTask = trans_task + sync_event: torch.cuda.Event = sync_event + sync_event.synchronize() + key = trans_task.get_key() + try: + with self.waiting_dict_lock: + self.waiting_dict[key] = trans_task + self.transporter.send_write_request_task_to_decode_node(trans_task) + logger.info(f"send WRITE request to decode: {key}") + except BaseException as e: + with self.waiting_dict_lock: + self.waiting_dict.pop(key, None) + logger.error(f"send WRITE request to decode failed: {trans_task.to_str()}") + logger.exception(str(e)) + trans_task.error_info = f"send WRITE request to decode failed: {str(e)}" + self.transporter.remove_remote_agent(peer_name=trans_task.decode_agent_name) + self.failed_queue.put(trans_task) + continue + return + + @log_exception + def accept_decode_write_task_loop(self): + while True: + try: + notifies_dict = self.transporter.get_new_notifs() + except BaseException as e: + logger.error(f"get new notifies failed: {str(e)}") + logger.exception(str(e)) + notifies_dict = {} + + if notifies_dict: + for _, _notify_list in notifies_dict.items(): + for notify in _notify_list: + try: + notify_obj = pickle.loads(notify) + except BaseException: + notify_obj = None + + if not isinstance(notify_obj, PDChunckedTransTask): + continue + + 
if notify_obj.error_info is not None: + logger.warning(f"recv WRITE error from decode: {notify_obj.to_str()}") + self._abort(request_id=notify_obj.request_id, error_info=notify_obj.error_info) + continue + + if notify_obj.write_stage == "ready": + key = notify_obj.get_key() + with self.waiting_dict_lock: + trans_task = self.waiting_dict.pop(key, None) + if trans_task is not None: + trans_task.dst_page_index = notify_obj.dst_page_index + self.write_peer_kv_queue.put(trans_task) + logger.info( + f"recv WRITE ready from decode request_id={trans_task.request_id} " + f"kv=[{trans_task.start_kv_index},{trans_task.end_kv_index}) " + f"srcpage={trans_task.src_page_index} dstpage={trans_task.dst_page_index}" + ) + else: + logger.warning( + f"can not find pending WRITE request for ready notify: {notify_obj.to_str()}" + ) + # 发一个error信息回去给 decode 节点,让其可以知道这边有问题了,它可以选择其他清理掉请求。 + notify_obj.error_info = "can not find pending WRITE request for ready notify" + self.transporter.send_error_info_to_decode_node(trans_task=notify_obj) + continue + else: + logger.error(f"ignore unknown WRITE notify stage: {notify_obj.to_str()}") + continue + + self._check_tasks_time_out() + + if not notifies_dict: + time.sleep(0.001) + return + + def _check_tasks_time_out(self): + with self.waiting_dict_lock: + timeout_tasks = [] + for key, trans_task in list(self.waiting_dict.items()): + if trans_task.time_out(): + timeout_tasks.append(self.waiting_dict.pop(key)) + + for trans_task in timeout_tasks: + trans_task.error_info = "time out waiting decode WRITE ready" + self.failed_queue.put(trans_task) + return + + @log_exception + def write_peer_kv_loop(self): + torch.cuda.set_device(self.device_id) + while True: + trans_task = self.write_peer_kv_queue.get() + trans_task: PDChunckedTransTask = trans_task + + try: + xfer_handle = self.transporter.write_blocks_paged(trans_task=trans_task) + trans_task.xfer_handle = xfer_handle + trans_task.start_trans_time = time.time() + with self.waiting_dict_lock: + 
self.waiting_dict[trans_task.get_key()] = trans_task + logger.info(f"start WRITE to decode node: {trans_task.to_str()}") + continue + except BaseException as e: + logger.error(f"write_blocks_paged failed: {trans_task.to_str()}") + logger.exception(str(e)) + self.transporter.remove_remote_agent(peer_name=trans_task.decode_agent_name) + trans_task.error_info = f"write_blocks_paged failed: {str(e)}" + self.failed_queue.put(trans_task) + continue + return + + @log_exception + def update_task_status_loop( + self, + ): + while True: + if len(self.waiting_dict) == 0: + time.sleep(0.001) + continue + + with self.waiting_dict_lock: + tasks = list(self.waiting_dict.values()) + for trans_task in tasks: + if trans_task.xfer_handle is None: + continue + + # 传输任务状态检查 + ret = self.transporter.check_task_status(trans_task=trans_task) + if ret == "DONE": + trans_task = self.waiting_dict.pop(trans_task.get_key(), None) + if self.transporter.capture_telemetry: + telem = self.transporter.nixl_agent.get_xfer_telemetry(trans_task.xfer_handle) + total_us = telem.xferDuration + post_us = telem.postDuration + backend_us = telem.xferDuration - telem.postDuration + nixl_backend = self.transporter.nixl_agent.query_xfer_backend(trans_task.xfer_handle) + logger.info( + f"write trans task request_id={trans_task.request_id} " + f"kv=[{trans_task.start_kv_index},{trans_task.end_kv_index}) " + f"src_page={trans_task.src_page_index} dst_page={trans_task.dst_page_index} " + f"xfer time: {total_us:.3f} us, " + f"post time: {post_us:.3f} us, backend time: {backend_us:.3f} us, " + f"nixl_backend: {nixl_backend}, total_bytes: {telem.totalBytes}" + ) + self.transporter.send_write_done_task_to_decode_node(trans_task) + logger.info( + f"send WRITE done nixl notify " + f"request_id={trans_task.request_id} " + f"kv=[{trans_task.start_kv_index},{trans_task.end_kv_index}) " + f"src_page={trans_task.src_page_index} dst_page={trans_task.dst_page_index}" + ) + self.success_queue.put(trans_task) + elif ret == 
"ERR": + trans_task = self.waiting_dict.pop(trans_task.get_key(), None) + trans_task.error_info = "xfer error" + self.failed_queue.put(trans_task) + elif trans_task.time_out(): + trans_task = self.waiting_dict.pop(trans_task.get_key(), None) + trans_task.error_info = "time out in update_task_status_loop" + self.failed_queue.put(trans_task) + + time.sleep(0.001) + + @log_exception + def success_loop(self): + torch.cuda.set_device(self.device_id) + while True: + trans_task: PDChunckedTransTask = self.success_queue.get() + # 写回后,回收页面 + if trans_task.src_page_index is not None: + self.page_index_queue.put(trans_task.src_page_index) + if trans_task.xfer_handle is not None: + self.transporter.release_xfer_handle(trans_task.xfer_handle) + + ret = trans_task.createRetObj() + ret.first_gen_token_id = None + ret.first_gen_token_logprob = None + self.task_out_queue.put(ret) + + if trans_task.start_trans_time is not None: + logger.info(f"trans task ret success:{ret} cost time: {trans_task.transfer_time()}s") + else: + logger.info(f"trans task ret success:{ret}") + + @log_exception + def fail_loop(self): + torch.cuda.set_device(self.device_id) + while True: + trans_task: PDChunckedTransTask = self.failed_queue.get() + + # 回收页面 + if trans_task.src_page_index is not None: + self.page_index_queue.put(trans_task.src_page_index) + if trans_task.xfer_handle is not None: + self.transporter.release_xfer_handle(trans_task.xfer_handle) + + ret = trans_task.createRetObj() + self.task_out_queue.put(ret) + logger.info(f"trans task ret fail:{ret}") + + if trans_task.error_info is not None: + self._abort(request_id=trans_task.request_id, error_info=trans_task.error_info) + self.transporter.send_error_info_to_decode_node(trans_task=trans_task) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/trans_process_obj.py b/lightllm/server/router/model_infer/mode_backend/pd/trans_process_obj.py similarity index 100% rename from 
lightllm/server/router/model_infer/mode_backend/pd_nixl/trans_process_obj.py rename to lightllm/server/router/model_infer/mode_backend/pd/trans_process_obj.py diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/__init__.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/__init__.py deleted file mode 100644 index 4b40544fe9..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .decode_kv_move_manager import start_decode_kv_move_manager_process -from .decode_trans_process import start_decode_trans_process diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py deleted file mode 100644 index b04cbb900a..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ /dev/null @@ -1,384 +0,0 @@ -import torch -import time -import inspect -import threading -import torch.multiprocessing as mp -import collections -import queue -import pickle -from typing import List, Dict, Union, Deque, Optional -from lightllm.utils.log_utils import init_logger -from lightllm.common.kv_cache_mem_manager import MemoryManager -from lightllm.server.pd_io_struct import ( - NIXLChunckedTransTask, - NIXLChunckedTransTaskGroup, - NIXLChunckedTransTaskRet, - NixlUpKVStatus, - NIXLAbortReq, -) -from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo -from lightllm.utils.device_utils import kv_trans_use_p2p -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.server.core.objs import StartArgs -from ..nixl_kv_transporter import NixlKVTransporter -from lightllm.utils.error_utils import log_exception - -logger = init_logger(__name__) - - -def start_decode_trans_process( - args, - device_id, - task_in_queue: mp.Queue, - 
task_out_queue: mp.Queue, - up_status_in_queue: Optional[mp.SimpleQueue], -): - proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, up_status_in_queue)) - proc.start() - assert proc.is_alive() - logger.info(f"prefill trans kv process for device: {device_id} started!") - return proc - - -def _init_env( - args: StartArgs, - device_id: int, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, - up_status_in_queue: Optional[mp.SimpleQueue], -): - torch.backends.cudnn.enabled = False - - try: - torch.cuda.set_device(device_id) - graceful_registry(inspect.currentframe().f_code.co_name) - - task_out_queue.put("proc_start") - - # 从共享内存读取所有rank的mem_manager - node_world_size = args.tp // args.nnodes - mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) - ] - - task_out_queue.put("get_mem_managers_ok") - - manager = _DecodeTransModule( - args=args, - device_id=device_id, - task_in_queue=task_in_queue, - task_out_queue=task_out_queue, - mem_managers=mem_managers, - up_status_in_queue=up_status_in_queue, - ) - assert manager is not None - - while True: - time.sleep(100) - - except Exception as e: - logger.exception(str(e)) - logger.error(f"Fatal error happened in kv trans process: {e}") - pass - - -class _DecodeTransModule: - def __init__( - self, - args: StartArgs, - device_id: int, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, - mem_managers: List[MemoryManager], - up_status_in_queue: Optional[mp.SimpleQueue], - ): - self.args = args - self.dp_world_size = self.args.tp // self.args.dp - self.device_id = device_id - self.task_in_queue = task_in_queue - self.task_out_queue = task_out_queue - self.mem_managers = mem_managers - self.up_status_in_queue = up_status_in_queue - cur_mem_manager: MemoryManager = self.mem_managers[device_id] - kv_move_buffer = cur_mem_manager.alloc_paged_kv_move_buffer( - page_num=self.args.nixl_pd_kv_page_num, 
page_size=self.args.nixl_pd_kv_page_size - ) - self.copy_cuda_stream = torch.cuda.Stream() - self.transporter = NixlKVTransporter( - node_id=self.args.pd_node_id, tp_idx=device_id, kv_move_buffer=kv_move_buffer - ) - self.recv_task_group_queue = queue.Queue() - self.waiting_dict_lock = threading.Lock() - self.waiting_dict: Dict[str, NIXLChunckedTransTask] = {} - self.read_peer_kv_queue = queue.Queue() - self.update_status_task_queue = queue.Queue() - self.ready_page_task_queue = queue.Queue() - self.success_queue = queue.Queue() - self.failed_queue = queue.Queue() - - self.page_index_queue = queue.Queue() - for page_index in range(self.args.nixl_pd_kv_page_num): - self.page_index_queue.put(page_index) - - for func in [ - self.recv_task_loop, - self.dispatch_task_loop, - self.accept_peer_task_loop, - self.read_peer_kv_loop, - self.update_task_status_loop, - self.read_page_to_mems_loop, - self.success_loop, - self.fail_loop, - ]: - threading.Thread(target=func, daemon=True).start() - return - - @log_exception - def recv_task_loop(self): - while True: - obj: Union[NIXLChunckedTransTaskGroup, NIXLAbortReq] = self.task_in_queue.get() - if isinstance(obj, NIXLChunckedTransTaskGroup): - self.recv_task_group_queue.put(obj) - elif isinstance(obj, NIXLAbortReq): - self._abort(cmd=obj) - else: - assert False, f"recv error obj {obj}" - - def _abort(self, cmd: NIXLAbortReq): - # check time_out update - with self.waiting_dict_lock: - keys = list(self.waiting_dict.keys()) - - for key in keys: - with self.waiting_dict_lock: - trans_task = self.waiting_dict.pop(key, None) - - if trans_task is not None and trans_task.request_id == cmd.request_id: - trans_task.error_info = "aborted req" - self.failed_queue.put(trans_task) - continue - - if trans_task is not None: - with self.waiting_dict_lock: - self.waiting_dict[trans_task.get_key()] = trans_task - return - - @log_exception - def dispatch_task_loop(self): - while True: - trans_task_group: NIXLChunckedTransTaskGroup = 
self.recv_task_group_queue.get() - - with self.waiting_dict_lock: - for task in trans_task_group.task_list: - if task.transfer_kv_num() != 0: - self.waiting_dict[task.get_key()] = task - else: - task.start_trans_time = time.time() - self.success_queue.put((None, task)) - - # up status - task = trans_task_group.task_list[0] - - decode_node_info = NIXLDecodeNodeInfo( - decode_node_id=self.args.pd_node_id, - pd_master_node_id=task.pd_master_node_id, - agent_name=self.transporter.agent_name, - agent_metadata=self.transporter.agent_metadata, - num_pages=self.transporter.num_pages, - page_reg_desc=self.transporter.local_page_mem_desc, - request_id=task.request_id, - ready_kv_len=task.start_kv_index, - ) - - up_status = NixlUpKVStatus( - group_request_id=task.request_id, - pd_master_node_id=task.pd_master_node_id, - nixl_params=pickle.dumps(decode_node_info), - ) - - self.up_status_in_queue.put(up_status) - - @log_exception - def accept_peer_task_loop( - self, - ): - torch.cuda.set_device(self.device_id) - while True: - if len(self.waiting_dict) == 0: - time.sleep(0.001) - continue - - # notify update - try: - notifies_dict = self.transporter.get_new_notifs() - except BaseException as e: - logger.error(f"get new notifies failed: {str(e)}") - logger.exception(str(e)) - notifies_dict = {} - - if notifies_dict: - for remote_agent_name, _notify_list in notifies_dict.items(): - for notify in _notify_list: - try: - notify_obj = pickle.loads(notify) - except: - notify_obj = None - - if isinstance(notify_obj, NIXLChunckedTransTask): - remote_trans_task = notify_obj - key = remote_trans_task.get_key() - logger.info(f"recv peer trans task {remote_trans_task.to_str()}") - with self.waiting_dict_lock: - local_trans_task: NIXLChunckedTransTask = self.waiting_dict.pop(key, None) - - if local_trans_task is None: - remote_trans_task.error_info = "peer not find" - try: - self.transporter.send_notify_to_prefill_node( - prefill_agent_name=remote_agent_name, - 
notify=pickle.dumps(remote_trans_task.createRetObj()), - ) - except BaseException as e: - logger.error(f"send notify to prefill node failed: {str(e)}") - logger.exception(str(e)) - self.transporter.remove_remote_agent(peer_name=remote_agent_name) - else: - local_trans_task.nixl_src_page_index = remote_trans_task.nixl_src_page_index - - local_trans_task.prefill_agent_name = remote_trans_task.prefill_agent_name - local_trans_task.prefill_agent_metadata = remote_trans_task.prefill_agent_metadata - local_trans_task.prefill_num_pages = remote_trans_task.prefill_num_pages - local_trans_task.prefill_page_reg_desc = remote_trans_task.prefill_page_reg_desc - - local_trans_task.first_gen_token_id = remote_trans_task.first_gen_token_id - local_trans_task.first_gen_token_logprob = remote_trans_task.first_gen_token_logprob - - self.read_peer_kv_queue.put(local_trans_task) - - self._check_tasks_time_out() - - def _check_tasks_time_out(self): - # check time_out update - with self.waiting_dict_lock: - keys = list(self.waiting_dict.keys()) - - for key in keys: - with self.waiting_dict_lock: - trans_task = self.waiting_dict.pop(key, None) - - if trans_task is not None and trans_task.time_out(): - trans_task.error_info = "time out in accept_peer_task_loop" - self.failed_queue.put(trans_task) - continue - - if trans_task is not None: - with self.waiting_dict_lock: - self.waiting_dict[trans_task.get_key()] = trans_task - return - - @log_exception - def read_peer_kv_loop(self): - torch.cuda.set_device(self.device_id) - while True: - page_index = self.page_index_queue.get() - local_trans_task = self.read_peer_kv_queue.get() - local_trans_task: NIXLChunckedTransTask = local_trans_task - local_trans_task.nixl_dst_page_index = page_index - - if local_trans_task.time_out(): - local_trans_task.error_info = "time out in read_peer_kv_loop" - self.failed_queue.put(local_trans_task) - continue - - try: - xfer_handle = self.transporter.read_blocks_paged(trans_task=local_trans_task) - 
local_trans_task.xfer_handle = xfer_handle - local_trans_task.start_trans_time = time.time() - self.update_status_task_queue.put(local_trans_task) - except BaseException as e: - logger.error(f"read_blocks_paged node failed: {local_trans_task.to_str()}") - logger.exception(str(e)) - self.transporter.remove_remote_agent(peer_name=local_trans_task.prefill_agent_name) - local_trans_task.error_info = f"read_blocks_paged failed: {str(e)}" - self.failed_queue.put(local_trans_task) - continue - - @log_exception - def update_task_status_loop( - self, - ): - while True: - trans_task: NIXLChunckedTransTask = self.update_status_task_queue.get() - - while True: - ret = self.transporter.check_task_status(trans_task=trans_task) - if ret == "DONE": - self.ready_page_task_queue.put(trans_task) - break - elif ret == "ERR": - trans_task.error_info = "xfer error" - self.failed_queue.put(trans_task) - break - elif trans_task.time_out(): - trans_task.error_info = "time out in update_task_status_loop" - self.failed_queue.put(trans_task) - break - - time.sleep(0.001) - - @log_exception - def read_page_to_mems_loop(self): - torch.cuda.set_device(self.device_id) - while True: - trans_task: NIXLChunckedTransTask = self.ready_page_task_queue.get() - # 将数据写回 mem manger - with torch.cuda.stream(stream=self.copy_cuda_stream): - cur_mem = self.mem_managers[self.device_id] - cur_mem.read_page_kv_move_buffer_to_mem( - mem_indexes=trans_task.mem_indexes, - page_index=trans_task.nixl_dst_page_index, - dp_index=trans_task.decode_dp_index, - mem_managers=self.mem_managers, - dp_world_size=self.dp_world_size, - ) - sync_event = torch.cuda.Event() - sync_event.record() - - self.success_queue.put((sync_event, trans_task)) - return - - @log_exception - def success_loop(self): - torch.cuda.set_device(self.device_id) - while True: - sync_event, trans_task = self.success_queue.get() - trans_task: NIXLChunckedTransTask = trans_task - sync_event: Optional[torch.cuda.Event] = sync_event - # 兼容传输kv 数量为0的时候, 
sync_event 为 None的情况。 - if sync_event is not None: - sync_event.synchronize() - - if trans_task.nixl_dst_page_index is not None: - self.page_index_queue.put(trans_task.nixl_dst_page_index) - - if trans_task.xfer_handle is not None: - self.transporter.release_xfer_handle(trans_task.xfer_handle) - - ret = trans_task.createRetObj() - self.task_out_queue.put(ret) - logger.info(f"trans task ret success:{ret} cost time: {trans_task.transfer_time()} s") - - @log_exception - def fail_loop(self): - torch.cuda.set_device(self.device_id) - while True: - trans_task: NIXLChunckedTransTask = self.failed_queue.get() - - # 回收页面 - if trans_task.nixl_dst_page_index is not None: - self.page_index_queue.put(trans_task.nixl_dst_page_index) - if trans_task.xfer_handle is not None: - self.transporter.release_xfer_handle(trans_task.xfer_handle) - ret = trans_task.createRetObj() - self.task_out_queue.put(ret) - logger.info(f"trans task ret fail:{ret}") diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py deleted file mode 100644 index 134fbd5027..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py +++ /dev/null @@ -1,195 +0,0 @@ -import pickle -import copy -from dataclasses import dataclass -from collections import defaultdict -from typing import Dict, List, Any, Optional, Tuple -from torch import Tensor -from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NixlAgentMetadata, NIXLChunckedTransTaskRet -from lightllm.utils.log_utils import init_logger - - -logger = init_logger(__name__) - -try: - from nixl._api import nixl_agent as NixlWrapper - from nixl._api import nixlBind - - logger.info("Nixl is available") -except ImportError: - logger.warning("nixl is not installed, which is required for pd disagreggation!!!") - NixlWrapper = None - - -class NixlKVTransporter: - def __init__(self, node_id: int, tp_idx: int, 
kv_move_buffer: Tensor): - self.node_id = node_id - self.tp_idx = tp_idx - self.nixl_agent = NixlWrapper(self.agent_name, None) - self._register_kv_move_buffer(kv_move_buffer=kv_move_buffer) - self.remote_agents: Dict[str, NixlAgentMetadata] = {} - return - - @property - def agent_name(self) -> str: - return f"{self.node_id}_{self.tp_idx}" - - @property - def agent_metadata(self): - return self.nixl_agent.get_agent_metadata() - - @property - def local_page_mem_desc(self): - return self.nixl_agent.get_serialized_descs(self.page_reg_desc) - - def get_new_notifs(self) -> Dict[str, list[bytes]]: - return self.nixl_agent.get_new_notifs() - - def _register_kv_move_buffer(self, kv_move_buffer: Tensor): - self.num_pages, self.page_size, self.num_layers, self.kv_head_num, self.head_dims = kv_move_buffer.shape - self.dtype_byte_size = kv_move_buffer.element_size() - self.page_len = self.page_size * self.num_layers * self.kv_head_num * self.head_dims * self.dtype_byte_size - self.page_reg_desc = self.nixl_agent.register_memory(kv_move_buffer) - self.page_local_xfer_handles = self._create_paged_xfer_handles(self.page_reg_desc, self.num_pages) - - def _create_paged_xfer_handles(self, reg_desc: "nixlBind.nixlRegDList", page_num: int, agent_name: str = ""): - base_addr, _, device_id, _ = reg_desc[0] - pages_data = [] - for page_id in range(page_num): - pages_data.append((base_addr + page_id * self.page_len, self.page_len, device_id)) - descs = self.nixl_agent.get_xfer_descs(pages_data, "VRAM") - return self.nixl_agent.prep_xfer_dlist(agent_name, descs, "VRAM") - - def connect_add_remote_agent(self, remote_agent: NixlAgentMetadata): - if remote_agent.agent_name in self.remote_agents: - return - - peer_name = self.nixl_agent.add_remote_agent(remote_agent.agent_metadata) - if isinstance(peer_name, bytes): - peer_name = peer_name.decode() - - assert ( - peer_name == remote_agent.agent_name - ), f"Peer name {peer_name} does not match remote name {remote_agent.agent_name}" - - 
page_mem_desc = self.nixl_agent.deserialize_descs(remote_agent.page_reg_desc) - kv_page_xfer_handles = self._create_paged_xfer_handles( - page_mem_desc, remote_agent.num_pages, agent_name=peer_name - ) - remote_agent.page_xfer_handles = kv_page_xfer_handles - - logger.info(f"Added remote agent {peer_name} with mem desc {page_mem_desc}") - - self.remote_agents[remote_agent.agent_name] = remote_agent - return - - def remove_remote_agent(self, peer_name: str): - if peer_name in self.remote_agents: - try: - remote_agent: NixlAgentMetadata = self.remote_agents.pop(peer_name, None) - assert remote_agent.agent_name == peer_name - self.nixl_agent.remove_remote_agent(remote_agent.agent_name) - if remote_agent.page_xfer_handles is not None: - self.nixl_agent.release_dlist_handle(remote_agent.page_xfer_handles) - except BaseException as e: - logger.error(f"remove remote agent {peer_name} failed") - logger.exception(str(e)) - else: - logger.warning(f"try to remove remote agent, but peer name {peer_name} agent did not exist") - - def send_readtask_to_decode_node(self, trans_task: NIXLChunckedTransTask): - """ - prefill node call this function to send read task to decode node - """ - decode_agent_name = trans_task.decode_agent_name - if decode_agent_name not in self.remote_agents: - logger.warning(f"decode_agent_name {decode_agent_name} not exist") - _remote_agent = trans_task.create_decode_agent_obj() - self.connect_add_remote_agent(_remote_agent) - - # 将页面读取任务发送给 decode 节点 - remote_agent: NixlAgentMetadata = self.remote_agents[decode_agent_name] - assert trans_task.nixl_src_page_index is not None - new_trans_task: NIXLChunckedTransTask = copy.copy(trans_task) - - new_trans_task.decode_agent_name = None - new_trans_task.decode_agent_metadata = None - new_trans_task.decode_num_pages = None - new_trans_task.decode_page_reg_desc = None - - new_trans_task.prefill_agent_name = self.agent_name - new_trans_task.prefill_agent_metadata = self.agent_metadata - 
new_trans_task.prefill_num_pages = self.num_pages - new_trans_task.prefill_page_reg_desc = self.local_page_mem_desc - - # 不需要传输细节的 mem_indexes 信息 - new_trans_task.mem_indexes = None - self.nixl_agent.send_notif( - remote_agent.agent_name, - pickle.dumps(new_trans_task), - ) - return - - def send_notify_to_prefill_node(self, prefill_agent_name: str, notify: bytes): - self.nixl_agent.send_notif(remote_agent_name=prefill_agent_name, notif_msg=notify) - return - - def read_blocks_paged( - self, - trans_task: NIXLChunckedTransTask, - ) -> int: - """ - decode node call this function to read kv blocks from prefill node - """ - prefill_agent_name = trans_task.prefill_agent_name - if prefill_agent_name not in self.remote_agents: - logger.warning(f"prefill_agent_name {prefill_agent_name} not exist") - _remote_agent = trans_task.create_prefill_agent_obj() - self.connect_add_remote_agent(_remote_agent) - - assert trans_task.nixl_src_page_index is not None and trans_task.nixl_dst_page_index is not None - remote_agent: NixlAgentMetadata = self.remote_agents[prefill_agent_name] - src_handle = remote_agent.page_xfer_handles - dst_handle = self.page_local_xfer_handles - notify_obj = NIXLChunckedTransTaskRet( - request_id=trans_task.request_id, - start_kv_index=trans_task.start_kv_index, - end_kv_index=trans_task.end_kv_index, - has_error=False, - error_info=None, - first_gen_token_id=None, - first_gen_token_logprob=None, - ) - handle = self.nixl_agent.make_prepped_xfer( - "READ", - dst_handle, - [trans_task.nixl_dst_page_index], - src_handle, - [trans_task.nixl_src_page_index], - pickle.dumps(notify_obj), - ) - if not handle: - raise RuntimeError(f"make_prepped_xfer failed for task: {trans_task.to_str()}") - - self.nixl_agent.transfer(handle) - - return handle - - def check_task_status(self, trans_task: NIXLChunckedTransTask) -> str: - assert trans_task.xfer_handle is not None - handle = trans_task.xfer_handle - xfer_state = self.nixl_agent.check_xfer_state(handle) - if xfer_state 
== "ERR": - logger.warning(f"Transfer failed with trans task {trans_task.to_str()} for handle {handle}") - return xfer_state - - def release_xfer_handle(self, handle): - self.nixl_agent.release_xfer_handle(handle=handle) - return - - def shutdown(self): - self.nixl_agent.deregister_memory(self.page_reg_desc) - self.nixl_agent.release_dlist_handle(self.page_local_xfer_handles) - agent_names = list(self.remote_agents.keys()) - for agent_name in agent_names: - self.remove_remote_agent(agent_name) - return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/__init__.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/__init__.py deleted file mode 100644 index 4100e14eda..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .prefill_trans_process import start_prefill_trans_process -from .prefill_kv_move_manager import start_prefill_kv_move_manager_process diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py deleted file mode 100644 index 063ce5c6a9..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ /dev/null @@ -1,280 +0,0 @@ -import torch -import time -import inspect -import threading -import torch.multiprocessing as mp -import collections -import queue -import pickle -from typing import List, Dict, Union, Deque, Optional -from lightllm.utils.log_utils import init_logger -from lightllm.common.kv_cache_mem_manager import MemoryManager -from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLChunckedTransTaskRet -from lightllm.utils.device_utils import kv_trans_use_p2p -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.server.core.objs import StartArgs -from 
..nixl_kv_transporter import NixlKVTransporter -from lightllm.utils.error_utils import log_exception - - -logger = init_logger(__name__) - - -def start_prefill_trans_process( - args, - device_id, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, - up_status_in_queue: Optional[mp.SimpleQueue] = None, -): - proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue)) - proc.start() - assert proc.is_alive() - logger.info(f"prefill trans kv process for device: {device_id} started!") - return proc - - -def _init_env( - args: StartArgs, - device_id: int, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, -): - torch.backends.cudnn.enabled = False - - try: - torch.cuda.set_device(device_id) - graceful_registry(inspect.currentframe().f_code.co_name) - task_out_queue.put("proc_start") - - # 从共享内存读取所有rank的mem_manager - node_world_size = args.tp // args.nnodes - mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) - ] - - task_out_queue.put("get_mem_managers_ok") - - manager = _PrefillTransModule( - args=args, - device_id=device_id, - task_in_queue=task_in_queue, - task_out_queue=task_out_queue, - mem_managers=mem_managers, - ) - assert manager is not None - - while True: - time.sleep(100) - - except Exception as e: - logger.exception(str(e)) - logger.error(f"Fatal error happened in kv trans process: {e}") - pass - - -class _PrefillTransModule: - def __init__( - self, - args: StartArgs, - device_id: int, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, - mem_managers: List[MemoryManager], - ) -> None: - self.args = args - self.dp_world_size = self.args.tp // self.args.dp - self.device_id = device_id - self.task_in_queue = task_in_queue - self.task_out_queue = task_out_queue - self.mem_managers = mem_managers - - cur_mem_manager: MemoryManager = self.mem_managers[device_id] - kv_move_buffer = cur_mem_manager.alloc_paged_kv_move_buffer( - 
page_num=self.args.nixl_pd_kv_page_num, page_size=self.args.nixl_pd_kv_page_size - ) - self.copy_cuda_stream = torch.cuda.Stream() - self.transporter = NixlKVTransporter( - node_id=self.args.pd_node_id, tp_idx=device_id, kv_move_buffer=kv_move_buffer - ) - self.waiting_dict_lock = threading.Lock() - self.waiting_dict: Dict[str, NIXLChunckedTransTask] = {} - - self.local_copy_kv_queue = queue.Queue() - self.notify_peer_read_kv_queue = queue.Queue() - self.success_queue = queue.Queue() - self.failed_queue = queue.Queue() - - self.page_index_queue = queue.Queue() - for page_index in range(self.args.nixl_pd_kv_page_num): - self.page_index_queue.put(page_index) - - for func in [ - self.recv_task_loop, - self.local_copy_kv_loop, - self.notify_peer_to_read_kv_loop, - self.update_task_status_loop, - self.success_loop, - self.fail_loop, - ]: - threading.Thread(target=func, daemon=True).start() - return - - @log_exception - def recv_task_loop(self): - torch.cuda.set_device(self.device_id) - - while True: - page_index = self.page_index_queue.get() - trans_task: NIXLChunckedTransTask = self.task_in_queue.get() - trans_task.nixl_src_page_index = page_index - - # 初次校验 time out - if trans_task.time_out(): - trans_task.error_info = "time out in recv_task_loop" - self.failed_queue.put(trans_task) - else: - self.local_copy_kv_queue.put(trans_task) - - @log_exception - def local_copy_kv_loop(self): - torch.cuda.set_device(self.device_id) - while True: - trans_task: NIXLChunckedTransTask = self.local_copy_kv_queue.get() - - # 将kv 数据拷贝到 page 上,然后传输给 decode node,让其进行读取。 - with torch.cuda.stream(stream=self.copy_cuda_stream): - cur_mem = self.mem_managers[self.device_id] - cur_mem.write_mem_to_page_kv_move_buffer( - trans_task.mem_indexes, - page_index=trans_task.nixl_src_page_index, - dp_index=trans_task.prefill_dp_index, - mem_managers=self.mem_managers, - dp_world_size=self.dp_world_size, - ) - sync_event = torch.cuda.Event() - sync_event.record() - - 
self.notify_peer_read_kv_queue.put((sync_event, trans_task)) - return - - @log_exception - def notify_peer_to_read_kv_loop(self): - torch.cuda.set_device(self.device_id) - while True: - sync_event, trans_task = self.notify_peer_read_kv_queue.get() - trans_task: NIXLChunckedTransTask = trans_task - sync_event: torch.cuda.Event = sync_event - - sync_event.synchronize() - - trans_task.start_trans_time = time.time() - with self.waiting_dict_lock: - self.waiting_dict[trans_task.get_key()] = trans_task - - try: - self.transporter.send_readtask_to_decode_node(trans_task=trans_task) - except BaseException as e: - logger.error(f"send readtask to decode node failed: {trans_task.to_str()}") - logger.exception(str(e)) - self.transporter.remove_remote_agent(peer_name=trans_task.decode_agent_name) - - with self.waiting_dict_lock: - trans_task = self.waiting_dict.pop(trans_task.get_key(), None) - - if trans_task is not None: - trans_task.error_info = f"send readtask to decode node failed: {str(e)}" - self.failed_queue.put(trans_task) - continue - - logger.info(f"send readtask to decode: {trans_task.to_str()}") - return - - @log_exception - def update_task_status_loop( - self, - ): - while True: - if len(self.waiting_dict) == 0: - time.sleep(0.001) - continue - - # notify update - try: - notifies_dict = self.transporter.get_new_notifs() - except BaseException as e: - logger.error(f"get new notifies failed: {str(e)}") - logger.exception(str(e)) - notifies_dict = {} - - if notifies_dict: - for _, _notify_list in notifies_dict.items(): - for notify in _notify_list: - try: - notify_obj = pickle.loads(notify) - except: - notify_obj = None - - if isinstance(notify_obj, NIXLChunckedTransTaskRet): - key = notify_obj.get_key() - with self.waiting_dict_lock: - trans_task = self.waiting_dict.pop(key, None) - - if trans_task is not None: - trans_task.error_info = notify_obj.error_info - if trans_task.error_info is not None: - self.failed_queue.put(trans_task) - else: - 
self.success_queue.put(trans_task) - else: - logger.warning(f"can not find trans task for ret: {notify_obj}") - - # check time_out update - self._check_tasks_time_out() - - def _check_tasks_time_out(self): - with self.waiting_dict_lock: - keys = list(self.waiting_dict.keys()) - - for key in keys: - with self.waiting_dict_lock: - trans_task = self.waiting_dict.pop(key, None) - - if trans_task is not None and trans_task.time_out(): - trans_task.error_info = "time out in update_task_status_loop" - self.failed_queue.put(trans_task) - continue - - if trans_task is not None: - with self.waiting_dict_lock: - self.waiting_dict[trans_task.get_key()] = trans_task - return - - @log_exception - def success_loop(self): - torch.cuda.set_device(self.device_id) - while True: - trans_task: NIXLChunckedTransTask = self.success_queue.get() - # 写回后,回收页面 - if trans_task.nixl_src_page_index is not None: - self.page_index_queue.put(trans_task.nixl_src_page_index) - - ret = trans_task.createRetObj() - ret.first_gen_token_id = None - ret.first_gen_token_logprob = None - self.task_out_queue.put(ret) - logger.info(f"trans task ret success:{ret} cost time: {trans_task.transfer_time()}s") - - @log_exception - def fail_loop(self): - torch.cuda.set_device(self.device_id) - while True: - trans_task: NIXLChunckedTransTask = self.failed_queue.get() - - # 回收页面 - if trans_task.nixl_src_page_index is not None: - self.page_index_queue.put(trans_task.nixl_src_page_index) - - ret = trans_task.createRetObj() - self.task_out_queue.put(ret) - logger.info(f"trans task ret fail:{ret}") diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 408b173371..864a7405b7 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -22,14 +22,10 @@ XgrammarBackend, DPChunkedPrefillBackend, DiversehBackend, - DecodeNode, - DPForDecodeNode, - ChunckedPrefillForPrefillNode, - DPChunkedForPrefillNode, - 
NIXLChunckedPrefillForPrefillNode, - NIXLDPChunkedForPrefillNode, - NIXLDecodeNode, - NIXLDPForDecodeNode, + PDChunkedPrefillForPrefillNode, + PDDPChunkedForPrefillNode, + PDDecodeNode, + PDDPForDecodeNode, ) from lightllm.server.router.model_infer.mode_backend.redundancy_expert_manager import RedundancyExpertManager from lightllm.server.core.objs.start_args_type import StartArgs @@ -69,31 +65,18 @@ def exposed_init_model(self, kvargs): assert not (is_outlines_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true" is_prefill_node = self.args.run_mode == "prefill" is_decode_node = self.args.run_mode == "decode" - is_nixl_prefill_node = self.args.run_mode == "nixl_prefill" - is_nixl_decode_node = self.args.run_mode == "nixl_decode" if is_prefill_node: if self.args.dp > 1: - self.backend = DPChunkedForPrefillNode(self.info_queue) + self.backend = PDDPChunkedForPrefillNode(self.info_queue) else: - self.backend = ChunckedPrefillForPrefillNode(self.info_queue) - elif is_nixl_prefill_node: - if self.args.dp > 1: - self.backend = NIXLDPChunkedForPrefillNode(self.info_queue) - else: - self.backend = NIXLChunckedPrefillForPrefillNode(self.info_queue) + self.backend = PDChunkedPrefillForPrefillNode(self.info_queue) elif is_decode_node: if self.args.dp > 1: - self.backend = DPForDecodeNode(self.info_queue) + self.backend = PDDPForDecodeNode(self.info_queue) else: - self.backend = DecodeNode(self.info_queue) - - elif is_nixl_decode_node: - if self.args.dp > 1: - self.backend = NIXLDPForDecodeNode(self.info_queue) - else: - self.backend = NIXLDecodeNode(self.info_queue) + self.backend = PDDecodeNode(self.info_queue) elif self.args.dp > 1: self.backend = DPChunkedPrefillBackend() @@ -169,7 +152,6 @@ def _init_env( rank_in_node, node_world_size, info_queue, - router_lock, socket_path, success_event, ): @@ -180,11 +162,6 @@ def _init_env( setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::model_infer:RANK{rank}") 
start_parent_check_thread() - # 将调度锁注册到全局的共享变量中 - from lightllm.common.basemodel.infer_lock import g_router_lock - - g_router_lock.obj = router_lock - model_rpc_server = ModelRpcServer(args, rank, rank_in_node, node_world_size, info_queue) # Start rpyc server with Unix socket t = ThreadedServer(model_rpc_server, socket_path=socket_path, protocol_config={"allow_pickle": True}) @@ -200,7 +177,6 @@ async def start_model_process( rank_in_node, node_world_size, info_queue: mp.Queue, - router_lock, ): import lightllm.utils.rpyc_fix_utils as _ @@ -217,7 +193,6 @@ async def start_model_process( rank_in_node, node_world_size, info_queue, - router_lock, socket_path, success_event, ), diff --git a/lightllm/server/router/profiler_service.py b/lightllm/server/router/profiler_service.py new file mode 100644 index 0000000000..dd27d8d399 --- /dev/null +++ b/lightllm/server/router/profiler_service.py @@ -0,0 +1,52 @@ +import threading + +import rpyc + +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class RouterProfilerCmdQueue: + def __init__(self): + self.cmds = [] + self.lock = threading.Lock() + + def append(self, cmd: str): + with self.lock: + self.cmds.append(cmd) + return + + def pop(self): + with self.lock: + if not self.cmds: + return None + return self.cmds.pop(0) + + +class RouterProfilerService(rpyc.Service): + def __init__(self, profiler_cmd_queue: RouterProfilerCmdQueue): + super().__init__() + self.profiler_cmd_queue = profiler_cmd_queue + + def exposed_profiler_cmd(self, cmd: str): + self.profiler_cmd_queue.append(cmd) + return + + +def start_router_profiler_server(args, profiler_cmd_queue: RouterProfilerCmdQueue): + if not args.enable_profiling: + return None, None + + from rpyc.utils.server import ThreadedServer + import lightllm.utils.rpyc_fix_utils as _ + + server = ThreadedServer( + RouterProfilerService(profiler_cmd_queue), + port=args.router_profiler_port, + protocol_config={"allow_pickle": True}, + ) + thread = 
threading.Thread(target=server.start, daemon=True) + thread.start() + logger.info(f"router profiler rpyc server started on port {args.router_profiler_port}") + return server, thread diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py index 067332d945..eb991bb4b9 100644 --- a/lightllm/server/router/req_queue/__init__.py +++ b/lightllm/server/router/req_queue/__init__.py @@ -1,7 +1,6 @@ -from .chunked_prefill.impl_for_pd_decode import QueueForPDDecode from .chunked_prefill.impl import ChunkedPrefillQueue from .chunked_prefill.beam_impl import ChunkedBeamContinuesBatchQueue -from .chunked_prefill.impl_for_nixl_pd import NIXLPDQueue +from .chunked_prefill.impl_for_pd import PDQueue from .dp_base_queue import DpQueue @@ -14,12 +13,8 @@ def _get_req_queue_class(args, router, dp_size_in_node: int): return ChunkedPrefillQueue if args.first_token_constraint_mode: return ChunkedPrefillQueue - if args.run_mode in ["decode"]: - return QueueForPDDecode - if args.run_mode in ["prefill"]: - return ChunkedPrefillQueue - if args.run_mode in ["nixl_prefill", "nixl_decode"]: - return NIXLPDQueue + if args.run_mode in ["prefill", "decode"]: + return PDQueue if args.disable_chunked_prefill: # 虽然也使用chuncked prefill queue 但是由于 args.chunked_prefill_size = args.max_req_total_len diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py index 73113a59b8..0d1ffe6967 100644 --- a/lightllm/server/router/req_queue/base_queue.py +++ b/lightllm/server/router/req_queue/base_queue.py @@ -2,7 +2,6 @@ from lightllm.utils.infer_utils import calculate_time from ..batch import Batch, Req from lightllm.server.core.objs import FinishStatus -from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.utils.config_utils import get_fixed_kv_len from lightllm.server.core.objs import StartArgs @@ -46,9 +45,7 @@ def is_busy(self): # 计算当前所有的token使用量, 如果使用了dynamic prompt cache, 
使用的token量中不包含,cache tree 中未被引用的数据。 cur_all_used_tokens = self.router.get_used_tokens(self.dp_index) # 判断当前服务是否处于token使用率过高的状态,过高的情况下,调度要偏向保守 - cur_token_ratio = ( - cur_all_used_tokens + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - ) / self.max_total_tokens + cur_token_ratio = cur_all_used_tokens / self.max_total_tokens is_busy = cur_token_ratio >= self.router_token_ratio return is_busy @@ -71,7 +68,7 @@ def generate_new_batch(self, current_batch: Batch): def calcu_batch_token_load(self, current_batch: Batch): if current_batch is None: - return 0, self.router.shared_token_load.get_frozened_token_count(self.dp_index) / self.max_total_tokens + return 0, 0.0 else: return self._calcu_batch_token_load_batch_not_none(current_batch) @@ -82,8 +79,7 @@ def update_token_load(self, current_batch: Batch, force_update=False): if self.router.shared_token_load.need_update_dynamic_max_load() or force_update: estimated_peak_token_count, dynamic_max_load = self.calcu_batch_token_load(current_batch) token_ratio1 = self.router.get_used_tokens(self.dp_index) / self.router.max_total_token_num - with g_router_lock.obj: - self.router.shared_token_load.set_current_load(token_ratio1, self.dp_index) - self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, self.dp_index) - self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, self.dp_index) + self.router.shared_token_load.set_current_load(token_ratio1, self.dp_index) + self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, self.dp_index) + self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, self.dp_index) return diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py index 23f94de704..63084d9d3b 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py 
@@ -49,10 +49,7 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new # prefill token 计算, 因为对beam的prefill计算过程是共享的,所以只计算一个请求对应的token数量 new_batch_first_router_need_tokens += req.get_first_router_need_tokens() - ok_token_num = ( - need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - < self.max_total_tokens - ) + ok_token_num = need_max_token_num < self.max_total_tokens ok_req_num = len(self.cache_len_list) <= self.running_max_req_size @@ -62,8 +59,7 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new if ok_token_num and ok_req_num and ok_prefill: self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num, self.dp_index) self.router.shared_token_load.set_dynamic_max_load( - (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) - / self.max_total_tokens, + need_max_token_num / self.max_total_tokens, self.dp_index, ) return True, new_batch_first_router_need_tokens @@ -167,6 +163,5 @@ def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): need_max_token_num = max(need_max_token_num, cumsum_len + index * cur_ouput_len) return ( need_max_token_num, - (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) - / self.max_total_tokens, + need_max_token_num / self.max_total_tokens, ) diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 884b5930b0..e82cc7e181 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -2,7 +2,6 @@ import numpy as np from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue -from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -37,27 +36,22 @@ def 
_can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens size_array = np.arange(1, len(self.cache_len_list) + 1, 1) need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - with g_router_lock.obj: - ok_token_num = ( - need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - < self.max_total_tokens - ) + ok_token_num = need_max_token_num < self.max_total_tokens - ok_req_num = len(self.cache_len_list) <= self.running_max_req_size + ok_req_num = len(self.cache_len_list) <= self.running_max_req_size - new_batch_first_router_need_tokens += req.get_first_router_need_tokens() - ok_prefill = new_batch_first_router_need_tokens <= self.batch_max_tokens + new_batch_first_router_need_tokens += req.get_first_router_need_tokens() + ok_prefill = new_batch_first_router_need_tokens <= self.batch_max_tokens - if ok_token_num and ok_req_num and ok_prefill: - self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num, self.dp_index) - self.router.shared_token_load.set_dynamic_max_load( - (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) - / self.max_total_tokens, - self.dp_index, - ) - return True, new_batch_first_router_need_tokens - else: - return False, new_batch_first_router_need_tokens + if ok_token_num and ok_req_num and ok_prefill: + self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num, self.dp_index) + self.router.shared_token_load.set_dynamic_max_load( + need_max_token_num / self.max_total_tokens, + self.dp_index, + ) + return True, new_batch_first_router_need_tokens + else: + return False, new_batch_first_router_need_tokens # @calculate_time(show=True, min_cost_ms=10) def generate_new_batch(self, current_batch: Batch): @@ -121,9 +115,7 @@ def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): else: need_max_token_num = 0 - with g_router_lock.obj: - return ( - need_max_token_num, - 
(need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) - / self.max_total_tokens, - ) + return ( + need_max_token_num, + need_max_token_num / self.max_total_tokens, + ) diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py similarity index 98% rename from lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py rename to lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py index 3b831c92a6..5ec09f5760 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py @@ -3,13 +3,12 @@ from typing import Tuple from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue -from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) -class NIXLPDQueue(BaseQueue): +class PDQueue(BaseQueue): def __init__(self, args, router, dp_index, dp_size_in_node) -> None: super().__init__(args, router, dp_index, dp_size_in_node) diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py deleted file mode 100644 index 4c2ebf7c00..0000000000 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py +++ /dev/null @@ -1,82 +0,0 @@ -import time -import uuid -import numpy as np -from typing import List -from lightllm.utils.infer_utils import calculate_time -from ...batch import Batch, Req -from lightllm.server.router.req_queue.base_queue import BaseQueue -from lightllm.common.basemodel.infer_lock import g_router_lock -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class QueueForPDDecode(BaseQueue): - def __init__(self, args, router, dp_index, dp_size_in_node) -> None: - 
super().__init__(args, router, dp_index, dp_size_in_node) - - def _init_cache_list(self, current_batch: Batch, is_busy): - if current_batch is not None: - self.cache_len_list = [ - req.get_tuple_tokens(is_busy, self.router.router_statics.ema_req_out_len) - for req in current_batch.reqs - if req.sample_params.suggested_dp_index == self.dp_index - ] - else: - self.cache_len_list = [] - return - - # @calculate_time(show=True, min_cost_ms=10) - def generate_new_batch(self, current_batch: Batch): - if len(self.waiting_req_list) == 0: - return None - - # 如果当前已经被调度的请求数量超过了上限,直接不调度新的请求了。 - exist_req_num = self.get_batch_dp_req_size(current_batch) - req_is_full = exist_req_num >= self.running_max_req_size - if req_is_full: - return None - - can_run_list = [] - abort_req_list = [] - aborted_count = 0 - for req in self.waiting_req_list: - if req.is_aborted: - # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. - # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token和管理req对象的泄漏 - aborted_count += 1 - abort_req_list.append(req) - continue - if exist_req_num + len(can_run_list) + 1 <= self.batch_max_tokens: - can_run_list.append(req) - else: - break - new_batch = None - if len(can_run_list) != 0: - new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) - for req in abort_req_list: - req: Req = req - logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}") - self.free_aborted_req_cpu_cache_pages(req) - self.router.shm_req_manager.put_back_req_obj(req) - self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] - return new_batch - - def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): - is_busy = self.is_busy() - self._init_cache_list(current_batch, is_busy) - if len(self.cache_len_list) != 0: - self.cache_len_list.sort(key=lambda x: -x[1]) - left_out_len_array = np.array([e[1] for e in self.cache_len_list]) - has_run_len_array = np.array([e[0] for e in self.cache_len_list]) 
- cum_run_len_array = np.cumsum(has_run_len_array) - size_array = np.arange(1, len(self.cache_len_list) + 1, 1) - need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - else: - need_max_token_num = 0 - with g_router_lock.obj: - return ( - need_max_token_num, - (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) - / self.max_total_tokens, - ) diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index a73823b8b7..866e1b9f42 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -3,7 +3,6 @@ from ..batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue from lightllm.server.router.req_queue.dp_balancer import get_dp_balancer -from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -70,8 +69,7 @@ def update_token_load(self, current_batch: Batch, force_update=False): current_batch ) token_ratio1 = self.router.get_used_tokens(dp_index) / self.router.max_total_token_num - with g_router_lock.obj: - self.router.shared_token_load.set_current_load(token_ratio1, dp_index) - self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, dp_index) - self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, dp_index) + self.router.shared_token_load.set_current_load(token_ratio1, dp_index) + self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, dp_index) + self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, dp_index) return diff --git a/lightllm/server/router/token_load.py b/lightllm/server/router/token_load.py index e4ce4b8352..45fa34d5da 100644 --- a/lightllm/server/router/token_load.py +++ b/lightllm/server/router/token_load.py @@ -20,7 +20,7 @@ def __init__(self, name, 
dp_size_in_node) -> None: f"{name}_ext_infos", shape=( self.dp_size_in_node, - 2, + 1, ), dtype=np.int64, ) @@ -40,19 +40,6 @@ def add_estimated_peak_token_count(self, value: int, index: int): def get_estimated_peak_token_count(self, index: int) -> int: return self.shared_token_infos.arr[index, 0] - # 记录系统被临时固定的不能被使用的token数,主要在于 pd 分离的模式下 - # 推理系统需要在 kv 传输时临时固定一些 token, 防止调度系统估计失误,导致调度问题 - def set_frozened_token_count(self, obj: int, index: int): - self.shared_token_infos.arr[index, 1] = obj - return - - def get_frozened_token_count(self, index: int) -> int: - return self.shared_token_infos.arr[index, 1] - - def add_frozened_token_count(self, value: int, index: int): - self.shared_token_infos.arr[index, 1] += value - return - # current_load 当前使用token量,估计的负载 def set_current_load(self, value, index: int): self.shared_token_load.arr[index, 0] = value diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 25726b2578..e1a4e421d1 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -31,7 +31,9 @@ from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer from ..models.gemma3.model import Gemma3Tokenizer +from ..models.gemma4.tokenizer import Gemma4Tokenizer from ..models.qwen3_omni_moe_thinker.model import QWen3OmniTokenizer +from ..models import deepseek3_2 # noqa: F401 # registers the deepseek_v32 config with transformers # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. 
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" @@ -130,5 +132,13 @@ def get_tokenizer( tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) + elif model_type == "gemma4": + image_processor = None + if "vision_config" in model_cfg and model_cfg["vision_config"] is not None: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(tokenizer_name) + image_processor = processor.image_processor + tokenizer = Gemma4Tokenizer(tokenizer, model_cfg, image_processor=image_processor) return tokenizer diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index a165be78f2..1dffdaf681 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -84,7 +84,7 @@ async def wait_to_model_ready(self): "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], "quant_type": self.args.vit_quant_type, "quant_cfg": self.args.vit_quant_cfg, - "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + "max_batch_size": max(self.infer_batch_size // self.vit_dp, 1), "vit_attn_backend": self.vit_attn_backend, } init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index ae3c4204db..3e74793634 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ b/lightllm/server/visualserver/model_infer/__init__.py @@ -4,12 +4,13 @@ import uuid import os import multiprocessing +import setproctitle from lightllm.utils.retry_utils import retry from rpyc.utils.factory import unix_connect from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer from lightllm.utils.graceful_utils import graceful_registry -from lightllm.utils.envs_utils import get_env_start_args +from 
lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name from .model_rpc_client import VisualModelRpcClient from .model_rpc import VisualModelRpcServer from ..objs import rpyc_config @@ -18,6 +19,7 @@ def _init_env(socket_path: str, success_event): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_model_infer") import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 55f4704a31..50bc12fd23 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -13,6 +13,7 @@ from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel +from lightllm.models.gemma4.gemma4_visual import Gemma4VisionModel from lightllm.models.vit.model import VisionTransformer from lightllm.server.multimodal_params import MultimodalParams, ImageItem from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel @@ -47,7 +48,7 @@ def exposed_init_model(self, kvargs): # "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], # "quant_type": self.args.vit_quant_type, # "quant_cfg": self.args.vit_quant_cfg, - # "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + # "max_batch_size": max(self.infer_batch_size // self.vit_dp, 1), # "vit_attn_backend": self.vit_attn_backend, # } @@ -97,6 +98,8 @@ def exposed_init_model(self, kvargs): # self.model = InternVLVisionModel() elif self.model_type == "gemma3": self.model = Gemma3VisionModel() + elif self.model_type == "gemma4": + self.model = Gemma4VisionModel(data_type=kvargs["data_type"]) elif ( model_cfg.get("thinker_config", {}).get("vision_config", {}).get("model_type") == 
"qwen3_omni_moe_vision_encoder" diff --git a/lightllm/server/visualserver/proxy_manager.py b/lightllm/server/visualserver/proxy_manager.py index 2cf02d19e6..0c977b2aa9 100644 --- a/lightllm/server/visualserver/proxy_manager.py +++ b/lightllm/server/visualserver/proxy_manager.py @@ -211,7 +211,7 @@ def start_visual_process(args, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_proxy_server") start_parent_check_thread() try: visualserver = ProxyVisualManager(args=args) diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index 27275c1e8c..b06713d87c 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -116,7 +116,7 @@ async def wait_to_model_ready(self): "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], "quant_type": self.args.vit_quant_type, "quant_cfg": self.args.vit_quant_cfg, - "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + "max_batch_size": max(self.infer_batch_size // self.vit_dp, 1), "vit_attn_backend": self.vit_attn_backend, } init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) @@ -177,7 +177,7 @@ def start_visual_process(args: StartArgs, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_only_server") start_parent_check_thread() try: diff --git a/lightllm/utils/backend_validator.py b/lightllm/utils/backend_validator.py index 6c5fe90309..ab5c0a88a1 100644 --- a/lightllm/utils/backend_validator.py +++ b/lightllm/utils/backend_validator.py @@ -196,12 +196,15 @@ def 
def _derive_max_req_total_len_from_model_config(model_dir: str) -> Optional[int]:
    """
    Derive `max_req_total_len` from the model's config.json.

    Keys are looked up in the top-level config first, then in the nested
    `llm_config`, `text_config` and `thinker_config.text_config` dicts
    (first non-None value wins).

    Derivation order:
    - If `max_sequence_length` is a positive integer: use it directly.
    - Otherwise use `max_position_embeddings`, multiplied by
      `rope_scaling.factor` (defaults to 1.0) unless the rope type is one of
      `yarn` / `dynamic` / `su` / `llama3`, in which case the factor is ignored.

    Args:
        model_dir: directory passed to `get_config_json`.

    Returns:
        The derived positive length, or None when config.json cannot be loaded
        or no usable value is present.
    """

    try:
        cfg = get_config_json(model_dir)
    except Exception as e:
        logger.warning(f"failed to load config.json for max_req_total_len derive: {e}")
        return None

    # Search order: top-level config, then the nested sub-configs.
    candidates = [cfg]
    for nested_key in ("llm_config", "text_config"):
        nested_cfg = cfg.get(nested_key)
        if isinstance(nested_cfg, dict):
            candidates.append(nested_cfg)

    thinker_cfg = cfg.get("thinker_config")
    if isinstance(thinker_cfg, dict):
        thinker_text_cfg = thinker_cfg.get("text_config")
        if isinstance(thinker_text_cfg, dict):
            candidates.append(thinker_text_cfg)

    def _find_key(key: str):
        # Return the first non-None value for `key` among the candidate dicts.
        for c in candidates:
            if isinstance(c, dict) and c.get(key) is not None:
                return c[key]
        return None

    max_sequence_length = _find_key("max_sequence_length")
    if max_sequence_length is not None:
        try:
            val = int(max_sequence_length)
        except (TypeError, ValueError, OverflowError):
            # An unparsable value is treated like a non-positive one: fall
            # through to the max_position_embeddings based derivation instead
            # of giving up.
            val = 0
        if val > 0:
            return val

    max_position_embeddings = _find_key("max_position_embeddings")
    if max_position_embeddings is None:
        return None

    rope_scaling = _find_key("rope_scaling")
    if not isinstance(rope_scaling, dict):
        rope_scaling = {}

    rope_type = None
    for k in ("rope_type", "type", "__type"):
        v = rope_scaling.get(k)
        if isinstance(v, str) and v.strip():
            rope_type = v.strip().lower()
            break

    # Align with `lightllm/models/llama/model.py` RoPE initialization:
    # - `yarn/dynamic/su/llama3`: do NOT multiply by `rope_scaling.factor` for max length.
    # - `default/mrope` (and unknown): multiply by factor when present.
    no_factor_types = {"yarn", "dynamic", "su", "llama3"}
    multiply_factor = rope_type not in no_factor_types

    try:
        factor_raw = rope_scaling.get("factor", 1.0)
        factor = 1.0 if factor_raw is None else float(factor_raw)
    except (TypeError, ValueError, OverflowError):
        factor = 1.0

    try:
        max_pos = float(max_position_embeddings)
        val = int(max_pos * factor) if multiply_factor else int(max_pos)
    except (TypeError, ValueError, OverflowError):
        return None

    if val <= 0:
        return None

    logger.info(
        "auto set max_req_total_len=%s (rope_type=%s,max_position_embeddings=%s,factor=%s, multiply_factor=%s)",
        val,
        rope_type,
        max_position_embeddings,
        factor,
        multiply_factor,
    )
    return val
+ """ + + if args.enable_fused_shared_experts: + logger.info("skip auto setting fused shared experts: already enabled") + return + + if args.enable_ep_moe: + logger.info("do not enable fused shared experts: EP MoE uses a separate implementation") + return + + model_dir = args.model_dir + if not model_dir: + logger.info("do not enable fused shared experts: model_dir is empty") + return + + model_type = get_model_type(model_dir) + supported_model_types = { + "deepseek_v3", + "deepseek_v31", + "deepseek_v32", + "qwen3_next", + "qwen3_5", + "qwen3_5_text", + "qwen3_5_moe", + "qwen3_5_moe_text", + } + if model_type not in supported_model_types: + logger.info(f"do not enable fused shared experts: unsupported model_type={model_type}") + return + + args.enable_fused_shared_experts = True + logger.info(f"auto enable fused shared experts for model_type={model_type}") + + def _get_config_llm_keyvalue(model_path: str, key_name: list[str]): config_json = get_config_json(model_path) for key in key_name: @@ -79,6 +249,23 @@ def get_layer_num(model_path: str) -> int: def get_eos_token_ids(model_path: str) -> Optional[List[int]]: + # gemma4 special eos_token_id + try: + model_type = get_model_type(model_path) + assert model_type == "gemma4" + + generation_config_path = os.path.join(model_path, "generation_config.json") + with open(generation_config_path, "r") as file: + eos_token_id = json.load(file).get("eos_token_id") + + assert eos_token_id is not None + if isinstance(eos_token_id, int): + return [eos_token_id] + elif isinstance(eos_token_id, list): + return list(eos_token_id) + except: + pass + try: # qwen3-omini special eos_token_id config_json = get_config_json(model_path) @@ -89,16 +276,36 @@ def get_eos_token_ids(model_path: str) -> Optional[List[int]]: # Qwen3.5 checkpoints can have an eos_token_id in config that differs from # tokenizer.eos_token_id. In practice tokenizer.eos_token_id is the reliable - # stop id (<|im_end|>) for detokenization/stop behavior. 
@lru_cache(maxsize=None)
def is_linear_att_mixed_model(model_path: str) -> bool:
    """
    Tell whether the model at `model_path` belongs to the Qwen3.5 family, which
    this function treats as linear-attention mixed models.

    The decision is based only on the top-level `model_type` of the config
    dict returned by `PretrainedConfig.get_config_dict`. Any failure (import
    error, unreadable config, missing `model_type`) is logged and reported as
    False. The result is cached per `model_path`.
    """
    try:
        from transformers.configuration_utils import PretrainedConfig

        model_cfg, _ = PretrainedConfig.get_config_dict(model_path)
        return model_cfg["model_type"] in ("qwen3_5", "qwen3_5_moe", "qwen3_5_text", "qwen3_5_moe_text")
    except Exception:
        # `Exception` rather than a bare `except`, so KeyboardInterrupt /
        # SystemExit propagate instead of being cached as a False answer.
        logger.info(f"model path: {model_path} does not has linear hybrid attention")
        return False
"qwen3_vl_moe_text", + "qwen3_5", + "qwen3_5_moe", + "qwen3_5_text", + "qwen3_5_moe_text", + ]: + return "qwen3" + + # DeepSeek V3 + if model_type in ["deepseek_v3", "deepseek_v31", "deepseek_v32"]: + return "deepseek-v3" + + # DeepSeek R1 + if model_type == "deepseek_r1": + return "deepseek-r1" + + # Gemma-4 (all variants share the same Harmony-like <|channel>... format) + if model_type == "gemma4": + return "gemma4" + + return None + + +@lru_cache(maxsize=None) +def ffn_use_tanh_approximate_gelu() -> bool: + try: + start_args = get_env_start_args() + model_type = get_model_type(start_args.model_dir) + if model_type in ["gemma4"]: + logger.info("Gemma4 uses tanh-approximate-gelu for FFN") + return True + except: + pass + + return False diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 43b10ec88b..58bff90560 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -40,6 +40,11 @@ def get_device_sm_count(): return properties["multiprocessor_count"] +@lru_cache(maxsize=None) +def is_sm100_gpu(): + return torch.cuda.get_device_capability()[0] == 10 + + @lru_cache(maxsize=None) def get_device_sm_regs_num(): import triton diff --git a/lightllm/utils/dist_check_utils.py b/lightllm/utils/dist_check_utils.py new file mode 100644 index 0000000000..12b0b81993 --- /dev/null +++ b/lightllm/utils/dist_check_utils.py @@ -0,0 +1,173 @@ +""" +通过双卡 NCCL 任务,对可选的 all-reduce 快路径做环境探测。 + +每个后端用 ``torch.multiprocessing.spawn(..., nprocs=2)`` 起两个子进程: +初始化分布式后做一次真实集合通信,再退出。 +""" + +import socket +import threading +from typing import TYPE_CHECKING, Callable + +from lightllm.utils.log_utils import init_logger + +if TYPE_CHECKING: + from lightllm.server.core.objs.start_args_type import StartArgs + +logger = init_logger(__name__) + +_CUSTOM_ALLREDUCE_WORLD_SIZES = (2, 4, 6, 8) +_TWO_GPU_CHECK_TIMEOUT_SECONDS = 600.0 + + +def _start_two_gpu_check_timeout_watchdog(backend_name: str) -> threading.Event: + """Each spawned rank runs its 
own watchdog thread; exits the process if the check does not finish in time.""" + + import os + import time + + probe_finished = threading.Event() + + def watchdog_main() -> None: + time.sleep(_TWO_GPU_CHECK_TIMEOUT_SECONDS) + if not probe_finished.is_set(): + logger.warning( + "%s 2-GPU all-reduce capability check timed out after %.0fs; force exit.", + backend_name, + _TWO_GPU_CHECK_TIMEOUT_SECONDS, + ) + os._exit(1) + + watchdog_thread = threading.Thread(target=watchdog_main, daemon=True) + watchdog_thread.start() + return probe_finished + + +def _should_run_allreduce_capability_check(args: "StartArgs") -> bool: + if args.hardware_platform != "cuda": + return False + + return (args.tp // args.dp) in _CUSTOM_ALLREDUCE_WORLD_SIZES + + +def _pick_free_tcp_port() -> int: + socket_handle = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + socket_handle.bind(("127.0.0.1", 0)) + free_port = socket_handle.getsockname()[1] + socket_handle.close() + return int(free_port) + + +def _flashinfer_two_gpu_check_worker(process_rank: int, init_tcp_port: int) -> None: + probe_finished_event = _start_two_gpu_check_timeout_watchdog("FlashInfer") + try: + import torch + import torch.distributed as dist + + cuda_device = torch.device(f"cuda:{process_rank}") + torch.cuda.set_device(cuda_device) + dist.init_process_group( + "nccl", + init_method=f"tcp://127.0.0.1:{init_tcp_port}", + world_size=2, + rank=process_rank, + device_id=cuda_device, + ) + try: + gloo_process_group = dist.new_group([0, 1], backend="gloo") + from lightllm.distributed.flashinfer_all_reduce import FlashInferAllReduce + + flashinfer_all_reduce = FlashInferAllReduce(gloo_process_group, cuda_device) + if flashinfer_all_reduce.disabled: + raise RuntimeError("FlashInferAllReduce disabled") + if process_rank == 0: + input_tensor = torch.zeros(2, 64, device=cuda_device, dtype=torch.bfloat16) + else: + input_tensor = torch.ones(2, 64, device=cuda_device, dtype=torch.bfloat16) + if not 
flashinfer_all_reduce.should_use(input_tensor): + raise RuntimeError("FlashInferAllReduce unsupported for probe tensor") + output_tensor = flashinfer_all_reduce.all_reduce(input_tensor) + dist.barrier() + expected_reduced = torch.ones(2, 64, device=cuda_device, dtype=torch.bfloat16) + if not torch.allclose(output_tensor, expected_reduced): + raise RuntimeError("FlashInfer allreduce value mismatch") + finally: + if dist.is_initialized(): + dist.destroy_process_group() + finally: + probe_finished_event.set() + + +def _symm_mem_two_gpu_check_worker(process_rank: int, init_tcp_port: int) -> None: + probe_finished_event = _start_two_gpu_check_timeout_watchdog("SymmMem") + try: + import torch + import torch.distributed as dist + + cuda_device = torch.device(f"cuda:{process_rank}") + torch.cuda.set_device(cuda_device) + dist.init_process_group( + "nccl", + init_method=f"tcp://127.0.0.1:{init_tcp_port}", + world_size=2, + rank=process_rank, + device_id=cuda_device, + ) + try: + nccl_process_group = dist.new_group([0, 1], backend="nccl") + from lightllm.distributed.symm_mem_all_reduce import SymmMemAllreduce + + symm_mem_all_reduce = SymmMemAllreduce(nccl_process_group, cuda_device, dtype=torch.bfloat16) + if symm_mem_all_reduce.disabled: + raise RuntimeError("SymmMemAllreduce disabled") + if process_rank == 0: + activation_tensor = torch.zeros(8, 32, device=cuda_device, dtype=torch.bfloat16) + else: + activation_tensor = torch.ones(8, 32, device=cuda_device, dtype=torch.bfloat16) + symm_mem_all_reduce.all_reduce(activation_tensor) + dist.barrier() + expected_reduced = torch.ones(8, 32, device=cuda_device, dtype=torch.bfloat16) + if not torch.allclose(activation_tensor, expected_reduced): + raise RuntimeError("SymmMem allreduce value mismatch") + finally: + if dist.is_initialized(): + dist.destroy_process_group() + finally: + probe_finished_event.set() + + +def _check_ok_two_gpu_all_reduce(worker_entry: Callable[[int, int], None], init_tcp_port: int) -> bool: + import 
torch.multiprocessing as torch_mp + + try: + torch_mp.spawn(worker_entry, args=(init_tcp_port,), nprocs=2, join=True) + return True + except Exception as error: + error_str = str(error) + error_str = error_str[-66:].replace("\n", "") + logger.warning("2-GPU all-reduce capability check failed for %s: %s", worker_entry.__name__, error_str) + return False + + +def auto_configure_allreduce_flags_from_args(args: "StartArgs") -> None: + """ + 用户若已通过 ``--disable_*`` 关闭某后端,则不再处理该后端。 + + 否则会按环境与并行规模,对每个后端做一次双进程 NCCL 探测;失败则将对应 ``disable_*`` 设为 True。 + + 会就地修改 ``args.disable_flashinfer_allreduce`` / ``args.disable_symm_mem_allreduce``。 + """ + if not _should_run_allreduce_capability_check(args): + return + + if not args.disable_flashinfer_allreduce: + if not _check_ok_two_gpu_all_reduce(_flashinfer_two_gpu_check_worker, _pick_free_tcp_port()): + logger.info( + "Auto-set disable_flashinfer_allreduce=True (2-GPU FlashInfer all-reduce capability check failed)." + ) + args.disable_flashinfer_allreduce = True + + if not args.disable_symm_mem_allreduce: + if not _check_ok_two_gpu_all_reduce(_symm_mem_two_gpu_check_worker, _pick_free_tcp_port()): + logger.info("Auto-set disable_symm_mem_allreduce=True (2-GPU SymmMem all-reduce capability check failed).") + args.disable_symm_mem_allreduce = True diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 350507e897..773320273c 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -11,7 +11,7 @@ def set_unique_server_name(args): - node_uuid = uuid.uuid1().hex[0:8] + node_uuid = uuid.uuid4().hex[0:16] if args.run_mode == "pd_master": os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(node_uuid) + "_pd_master" @@ -59,7 +59,7 @@ def get_llm_data_type() -> torch.dtype: elif data_type in ["fp32", "float32"]: data_type = torch.float32 else: - raise ValueError(f"Unsupport datatype {data_type}!") + raise ValueError(f"Unsupported datatype {data_type}!") return data_type @@ -69,9 +69,22 @@ def 
def _round_up_to_multiple_of_8(value: int) -> int:
    """Round `value` up to the next multiple of 8 (multiples of 8 are returned unchanged)."""
    return ((int(value) + 7) // 8) * 8


@lru_cache(maxsize=None)
def get_deepep_num_max_dispatch_tokens_per_rank_prefill():
    """
    Max number of tokens dispatched per rank for DeepEP during prefill.

    The value must be larger than the per-GPU max batch size and be a multiple
    of 8. It directly drives GPU memory usage: the larger the value, the more
    memory is used.

    Resolution order:
    - `NUM_MAX_DISPATCH_TOKENS_PER_RANK_PREFILL` env var, when set.
    - Otherwise the start args' `batch_max_tokens` (256 when unset), so the
      default at least covers this process's `batch_max_tokens` and DeepEP V2
      does not fail during autotune warmup or large prefill batches because
      the buffer upper bound is too small.

    Both paths are rounded up to a multiple of 8. The result is cached for the
    lifetime of the process.
    """
    configured = os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK_PREFILL", None)
    if configured is not None:
        # Apply the same multiple-of-8 rounding as the default path below.
        return _round_up_to_multiple_of_8(int(configured))

    batch_max_tokens = get_env_start_args().batch_max_tokens or 256
    return _round_up_to_multiple_of_8(batch_max_tokens)


@lru_cache(maxsize=None)
def get_deepep_num_max_dispatch_tokens_per_rank_decode():
    """
    Max number of tokens dispatched per rank for DeepEP during decode.

    Read from `NUM_MAX_DISPATCH_TOKENS_PER_RANK_DECODE` (default 256) and
    rounded up to a multiple of 8. The value must be larger than the per-GPU
    max batch size; it directly drives GPU memory usage, so try lowering it if
    GPU memory runs out. The result is cached for the lifetime of the process.
    """
    return _round_up_to_multiple_of_8(int(os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK_DECODE", 256)))
visual proxy failures) must NOT raise this — + they should surface as real server errors.""" + + def __init__(self, group_request_id: Optional[int] = None, reason: str = "client disconnected"): + prefix = f"req_id {group_request_id} " if group_request_id is not None else "" + super().__init__(f"{prefix}{reason}") + self.group_request_id = group_request_id + self.reason = reason + + +class PDPrefillNodeStopGenToken(Exception): + def __init__(self, group_request_id, message="PD prefill node stop gen token"): """ - Initialize the NixlPrefillNodeStopGenToken + Initialize the PDPrefillNodeStopGenToken Args: message (str): Error message to display diff --git a/lightllm/utils/health_check.py b/lightllm/utils/health_check.py index f6c52bdb38..d2a776b862 100644 --- a/lightllm/utils/health_check.py +++ b/lightllm/utils/health_check.py @@ -1,105 +1,52 @@ import os import time -import asyncio -import numpy as np from dataclasses import dataclass -from lightllm.server.core.objs import SamplingParams -from lightllm.server.multimodal_params import MultimodalParams -from lightllm.server.httpserver.manager import HttpServerManager +from typing import TYPE_CHECKING + from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt -from fastapi import Request from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.utils.envs_utils import get_unique_server_name + +if TYPE_CHECKING: + from lightllm.server.core.objs.shm_req_manager import ShmReqManager logger = init_logger(__name__) @dataclass class HealthObj: - _is_health: bool = False - _is_health_checking: bool = False - _failure_count: int = 0 - _failure_threshold: int = int(os.getenv("HEALTH_FAILURE_THRESHOLD", 3)) - timeout: int = int(os.getenv("HEALTH_TIMEOUT", 100)) - dynamic_timeout: int = int(os.getenv("HEALTH_TIMEOUT", 100)) - latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") 
- - def begin_check(self): - self._is_health_checking = True - - def end_check(self): - self._is_health_checking = False - - def set_unhealth(self): - self._failure_count += 1 - self.dynamic_timeout += self.timeout - if self._failure_count > self._failure_threshold: - self._is_health = False - - def set_health(self): - self._is_health = True - self._failure_count = 0 - self.dynamic_timeout = self.timeout - - def is_health(self): - return self._is_health + grace_timeout: int = int(os.getenv("HEALTH_TIMEOUT", "200")) - def is_checking(self): - return self._is_health_checking + def __post_init__(self): + uid = get_unique_server_name() + self.latest_success_infer_time_mark = SharedInt(f"{uid}_latest_success_infer_time_mark") + self.run_reqs_count_mark = SharedInt(f"{uid}_run_reqs_count_mark") - def has_latest_inference(self): - last_timemark = self.latest_success_infer_time_mark.get_value() - time_diff = time.time() - last_timemark - return time_diff < self.timeout + def check(self, shm_req_manager: "ShmReqManager") -> bool: + """On-the-fly health check: recent success is ok; otherwise require no in-flight shm requests.""" + try: + now = time.time() + last_success_time = self.latest_success_infer_time_mark.get_value() + + # 如果最近一次成功推理的时间距离现在小于 grace_timeout,则认为系统健康 + if now - last_success_time <= self.grace_timeout: + return True + elif self.run_reqs_count_mark.get_value() == 0 and shm_req_manager.is_idle(): + # 如果最近一次成功推理的时间距离现在大于 grace_timeout,并且没有在推理的请求,则认为系统健康 + return True + else: + logger.warning( + "Health check failed: no success for %ss and in-flight shm requests remain", + int(now - last_success_time), + ) + return False + except Exception as e: + logger.exception(str(e)) + return False health_obj = HealthObj() -async def health_check(args, httpserver_manager: HttpServerManager, request: Request): - if health_obj.is_checking(): - return health_obj.is_health() - - if health_obj.is_health() and health_obj.has_latest_inference(): - return health_obj.is_health() 
- - health_obj.begin_check() - try: - request_dict = {"inputs": "你好!", "parameters": {"do_sample": True, "temperature": 0.8, "max_new_tokens": 2}} - if args.run_mode in ["prefill", "nixl_prefill"]: - request_dict["parameters"]["max_new_tokens"] = 1 - prompt = request_dict.pop("inputs") - sample_params_dict = request_dict["parameters"] - sampling_params = SamplingParams() - sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) - sampling_params.verify() - - if get_env_start_args().run_mode == "pd_master": - # Since the id assigned by pd master needs to be passed to prefill and decode nodes for inference, - # a normal request id is required instead of a negative id. - sampling_params.group_request_id = httpserver_manager.id_gen.generate_id() - else: - sampling_params.group_request_id = -httpserver_manager.id_gen.generate_id() # health monitor 的 id 是负的 - multimodal_params_dict = request_dict.get("multimodal_params", {}) - multimodal_params = MultimodalParams(**multimodal_params_dict) - results_generator = httpserver_manager.generate( - prompt, sampling_params, multimodal_params, request, is_health_req=True - ) - - async def check_timeout(results_generator): - async for _, _, _, _ in results_generator: - pass - - try: - await asyncio.wait_for(check_timeout(results_generator), timeout=health_obj.dynamic_timeout) - health_obj.set_health() - except asyncio.TimeoutError: - health_obj.set_unhealth() - logger.warning(f"Health check timeout! 
The failure count is: {str(health_obj._failure_count)}") - return health_obj.is_health() - except Exception as e: - logger.exception(str(e)) - health_obj.set_unhealth() - return health_obj.is_health() - finally: - health_obj.end_check() +def health_check(shm_req_manager: "ShmReqManager") -> bool: + return health_obj.check(shm_req_manager) diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 10764e24b0..494908cb10 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -4,7 +4,7 @@ import os import xxhash import threading -import time +import concurrent.futures import numpy as np import triton from functools import lru_cache @@ -15,19 +15,21 @@ get_added_mtp_kv_layer_num, ) from lightllm.utils.log_utils import init_logger -from lightllm.utils.config_utils import get_num_key_value_heads, get_head_dim, get_layer_num +from lightllm.utils.config_utils import get_num_key_value_heads, get_head_dim, get_layer_num, is_linear_att_mixed_model from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.kv_cache_mem_manager import ( MemoryManager, PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, + Qwen3NextMemManager, ) from typing import List, Tuple, Optional from tqdm import tqdm from lightllm.utils.auto_shm_cleanup import register_sysv_shm_for_cleanup from lightllm.utils.dist_utils import get_current_device_id +from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig logger = init_logger(__name__) @@ -61,8 +63,25 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": args = get_env_start_args() assert args.enable_cpu_cache - mem_manager_class = select_mem_manager_class() - if mem_manager_class is Deepseek2MemoryManager: + if is_linear_att_mixed_model(args.model_dir): + # 对于 qwen3.5 等 linear att 混合模型的特殊处理。 + mem_manager_class = Qwen3NextMemManager + else: + mem_manager_class = select_mem_manager_class() + + if 
mem_manager_class is Qwen3NextMemManager: + linear_config = LinearAttCacheConfig.load_from_args() + cpu_cache_meta = CpuKVCacheMeta( + page_num=0, + token_page_size=1, + layer_num=1, + num_heads=1, + head_dim=linear_config.get_cpu_cache_big_page_bytes(), + data_type=torch.uint8, + scale_head_dim=0, + scale_data_type=get_llm_data_type(), + ) + elif mem_manager_class is Deepseek2MemoryManager: cpu_cache_meta = CpuKVCacheMeta( page_num=0, token_page_size=args.cpu_cache_token_page_size, @@ -101,6 +120,7 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": if args.mtp_mode is not None: # TODO 可能会存在不同mtp模式的精度问题 + assert is_linear_att_mixed_model(args.model_dir) is False, "linear att mixed model does not support mtp mode" cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num() cpu_cache_page_num = int( @@ -215,7 +235,16 @@ def _get_default_hugepage_size() -> int: def _pre_warm_memory(): page_size = _get_default_hugepage_size() if use_hugetlb else 4096 arr = np.ctypeslib.as_array(ctypes.cast(shm_addr, ctypes.POINTER(ctypes.c_uint8)), shape=(size_to_alloc,)) - volatile_sum = int(arr[::page_size].sum()) + worker_num = 8 + chunk_size = triton.cdiv(size_to_alloc, worker_num * page_size) * page_size + + def _warm_range(worker_id: int): + start = worker_id * chunk_size + end = min(size_to_alloc, start + chunk_size) + return int(arr[start:end:page_size].sum()) + + with concurrent.futures.ThreadPoolExecutor(max_workers=worker_num) as executor: + volatile_sum = sum(executor.map(_warm_range, range(worker_num))) logger.info(f"pre warmed shared memory pages successfully, checksum={volatile_sum})") th = threading.Thread(target=_pre_warm_memory, name=f"cpu_cache_pre_warm_{key}", daemon=True) @@ -225,8 +254,8 @@ def _pre_warm_memory(): @lru_cache(maxsize=None) -def register_shm_ptr_to_pin(shm_ptr: int, size: int) -> "AsyncRegistrationHandle": - """Start async cudaHostRegister on the given [shm_ptr, shm_ptr+size) and return a handle.""" +def register_shm_ptr_to_pin(shm_ptr: int, size: int) 
-> int: + """Synchronously cudaHostRegister the given [shm_ptr, shm_ptr+size).""" chunk_bytes = 128 * 1024 * 1024 # 128M性能最好 tasks: list[tuple[int, int]] = [] offset = 0 @@ -235,74 +264,43 @@ def register_shm_ptr_to_pin(shm_ptr: int, size: int) -> "AsyncRegistrationHandle tasks.append((offset, seg_len)) offset += seg_len - handle = AsyncRegistrationHandle(total_tasks=len(tasks)) - - def _worker(): - cuda = ctypes.CDLL("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so") - cuda.cudaHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint] - cuda.cudaHostRegister.restype = ctypes.c_int - cuda.cudaHostGetDevicePointer.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_int] - cuda.cudaHostGetDevicePointer.restype = ctypes.c_int - - cudaHostRegisterFlag = 3 - - torch.cuda.set_device(get_current_device_id()) - # TODO 这个地方的分块注册是否具备合法性和合理性。 - for offset, seg_len in tasks: - ptr = ctypes.c_void_p(shm_ptr + offset) - r = cuda.cudaHostRegister(ptr, ctypes.c_size_t(seg_len), cudaHostRegisterFlag) - if r != 0: - raise Exception(f"cudaHostRegister failed with error code {r}, prefer to use hugetlb") - handle.task_count += 1 - - device_ptr = ctypes.c_void_p() - host_ptr = ctypes.c_void_p(shm_ptr) - res = cuda.cudaHostGetDevicePointer(ctypes.byref(device_ptr), host_ptr, 0) - if res != 0: - raise Exception(f"cudaHostGetDevicePointer failed with error code {res}") - assert host_ptr.value == device_ptr.value - handle.tasks_finished.set() - - th = threading.Thread(target=_worker, name=f"cpu_cache_register_{shm_ptr}", daemon=True) - handle.thread = th - th.start() - return handle - - -class AsyncRegistrationHandle: - """A handle for async host memory registration. - - - wait(): blocks until registration finishes, prints tqdm progress, and returns device pointer (int). 
- """ - - def __init__(self, total_tasks: int): - self.total_tasks = total_tasks - self.task_count = 0 - self.thread: Optional[threading.Thread] = None - self.tasks_finished = threading.Event() - - def wait(self): - """Block until the async registration completes. Only here we print tqdm progress.""" - last_count = 0 - desc = f"pid {os.getpid()} Registering pinned host memory (async)" - with tqdm(total=self.total_tasks, desc=desc) as pbar: - while not self.tasks_finished.is_set(): - cur = self.task_count - if cur > last_count: - pbar.update(cur - last_count) - last_count = cur - time.sleep(0.01) - # final update - cur = self.task_count - if cur > last_count: - pbar.update(cur - last_count) - last_count = cur - - if self.thread is not None and self.thread.is_alive(): - self.thread.join() - + cuda = ctypes.CDLL("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so") + cuda.cudaHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint] + cuda.cudaHostRegister.restype = ctypes.c_int + cuda.cudaHostGetDevicePointer.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_int] + cuda.cudaHostGetDevicePointer.restype = ctypes.c_int + + cudaHostRegisterFlag = 3 + + device_id = get_current_device_id() + torch.cuda.set_device(device_id) + desc = f"pid {os.getpid()} Registering pinned host memory" + + def _register_one_segment(task: Tuple[int, int]): + offset, seg_len = task + torch.cuda.set_device(device_id) + ptr = ctypes.c_void_p(shm_ptr + offset) + r = cuda.cudaHostRegister(ptr, ctypes.c_size_t(seg_len), cudaHostRegisterFlag) + if r != 0: + raise Exception(f"cudaHostRegister failed with error code {r}, prefer to use hugetlb") return + # worker_num的数值需要与_pre_warm_memory一致,不然会丢失warmup的效果 + if tasks: + worker_num = min(8, len(tasks)) + with concurrent.futures.ThreadPoolExecutor(max_workers=worker_num) as executor: + futures = [executor.submit(_register_one_segment, task) for task in tasks] + for future in 
tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc=desc): + future.result() + + device_ptr = ctypes.c_void_p() + host_ptr = ctypes.c_void_p(shm_ptr) + res = cuda.cudaHostGetDevicePointer(ctypes.byref(device_ptr), host_ptr, 0) + if res != 0: + raise Exception(f"cudaHostGetDevicePointer failed with error code {res}") + logger.info(f"cudaHostGetDevicePointer success, host_ptr={host_ptr.value}, device_ptr={device_ptr.value}") + return device_ptr.value + @lru_cache(maxsize=None) def attach_shm_kv_cache_ptr(key: int, size: int) -> int: diff --git a/lightllm/utils/light_utils.py b/lightllm/utils/light_utils.py deleted file mode 100644 index 944a0fe15f..0000000000 --- a/lightllm/utils/light_utils.py +++ /dev/null @@ -1,13 +0,0 @@ -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) -try: - # TODO: lightllm_kernel release - import lightllm_kernel - - light_ops = getattr(lightllm_kernel, "ops", lightllm_kernel) - HAS_LIGHTLLM_KERNEL = True -except: - light_ops = None - HAS_LIGHTLLM_KERNEL = False - logger.warning("lightllm_kernel is not installed, you can't use the api of it.") diff --git a/lightllm/utils/multimodal_utils.py b/lightllm/utils/multimodal_utils.py index 4b49ea8891..876283b931 100644 --- a/lightllm/utils/multimodal_utils.py +++ b/lightllm/utils/multimodal_utils.py @@ -6,6 +6,7 @@ from io import BytesIO from fastapi import Request from functools import lru_cache +from lightllm.utils.error_utils import ClientDisconnected from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -53,7 +54,7 @@ async def fetch_resource(url, request: Request, timeout, proxy=None): async for chunk in response.aiter_bytes(chunk_size=1024 * 1024): if request is not None and await request.is_disconnected(): await response.aclose() - raise Exception("Request disconnected. 
User cancelled download.") + raise ClientDisconnected(reason=f"client disconnected during download of {url}") ans_bytes.append(chunk) # 接收的数据不能大于128M if len(ans_bytes) > 128: diff --git a/lightllm/utils/profiler.py b/lightllm/utils/profiler.py new file mode 100644 index 0000000000..6ed23dcedf --- /dev/null +++ b/lightllm/utils/profiler.py @@ -0,0 +1,227 @@ +from dataclasses import dataclass +import os +import threading +import traceback +from typing import Any, Literal, Optional +import torch + +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +@dataclass +class ProfilerCmd: + cmd: Literal["start", "stop"] + + +def _get_thread_id() -> int: + # Get native thread ID (LWP) for correlation with system tools like htop/nsys + if hasattr(threading, "get_native_id"): + return threading.get_native_id() + return threading.get_ident() + + +class ProcessProfiler: + def __init__( + self, + mode: Literal["torch_profiler", "nvtx"], + name: Optional[str] = None, + use_multi_thread: bool = False, + torch_profiler_with_stack: bool = True, + ) -> None: + """ + Process Level Profiler Manager. + For multi-threading, set `use_multi_thread=True` + and call `.multi_thread_helper()` regularly in each worker thread. 
+ """ + self.mode = mode + self.name = name or "unnamed" + self.use_multi_thread = use_multi_thread + self.torch_profiler_with_stack = torch_profiler_with_stack + + self.is_active: bool = False # Process-level logical state + self._threadlocal = threading.local() + + # make sure only one active torch.profiler per process + self._lock = threading.Lock() + self._process_torch_profiler_active_tid: int | None = None + + if self.mode == "torch_profiler": + self._trace_dir = os.getenv("LIGHTLLM_TRACE_DIR", "./trace") + os.makedirs(self._trace_dir, exist_ok=True) + elif self.mode == "nvtx": + self._nvtx_toplevel_mark = "LIGHTLLM_PROFILE" + else: + raise ValueError("invalid profiler mode") + + self._log_init_info() + + @property + def _local(self): + """Lazy initialization of thread-local storage.""" + if not hasattr(self._threadlocal, "initialized"): + self._threadlocal.initialized = True + self._threadlocal.is_active = False + self._threadlocal.profiler_obj = None + self._threadlocal.nvtx_range_id = None + return self._threadlocal + + def _log_init_info(self): + logger.warning("-" * 50) + logger.warning( + f"[pid={os.getpid()} tid={_get_thread_id()}] Profiler <{self.name}> initialized with mode: {self.mode}" + ) + if self.mode == "torch_profiler": + logger.warning( + "Profiler support for torch.profiler enabled (--enable_profiling=torch_profiler), " + "trace files will be saved to %s (change it with LIGHTLLM_TRACE_DIR env var)", + self._trace_dir, + ) + elif self.mode == "nvtx": + logger.warning( + "Profiler support for NVTX enabled (--enable_profiling=nvtx), toplevel NVTX mark is '%s'\n" + "you can use it with external profiling tools like NVIDIA Nsight Systems.", + self._nvtx_toplevel_mark, + ) + logger.warning( + "e.g. 
nsys profile --capture-range=nvtx --nvtx-capture=%s --trace=cuda,nvtx " + "-e NSYS_NVTX_PROFILER_REGISTER_ONLY=0 [other nsys options] " + "python -m lightllm.server.api_server --enable_profiling=nvtx [other lightllm options]", + self._nvtx_toplevel_mark, + ) + logger.warning("Use /profiler_start and /profiler_stop HTTP GET APIs to start/stop profiling") + logger.warning("DO NOT enable this feature in production environment") + logger.warning("-" * 50) + + def _torch_profiler_start(self) -> None: + with self._lock: + if self._process_torch_profiler_active_tid is not None: + return + self._process_torch_profiler_active_tid = _get_thread_id() + + torch.cuda.synchronize() + worker_name = f"{self.name}_tid{_get_thread_id()}" if self.use_multi_thread else self.name + + trace_handler = torch.profiler.tensorboard_trace_handler( + self._trace_dir, + worker_name=worker_name, + use_gzip=True, + ) + + p = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=None, + with_stack=self.torch_profiler_with_stack, + record_shapes=True, + on_trace_ready=trace_handler, + ) + + self._local.profiler_obj = p + p.start() + torch.cuda.synchronize() + + def _nvtx_start(self) -> None: + torch.cuda.synchronize() + self._local.nvtx_range_id = torch.cuda.nvtx.range_start(self._nvtx_toplevel_mark) + torch.cuda.synchronize() + + def _thread_start(self) -> None: + if self._local.is_active: + return + + try: + logger.info(f"[{self.name} @ tid={_get_thread_id()}] Start Profiler.") + if self.mode == "torch_profiler": + self._torch_profiler_start() + elif self.mode == "nvtx": + self._nvtx_start() + + self._local.is_active = True + except Exception as e: + logger.error( + f"[{self.name} @ tid={_get_thread_id()}] Failed to start profiler in thread {_get_thread_id()}: {e}" + ) + traceback.print_exc() + # Reset state on failure to prevent infinite retry loops + self._local.is_active = False + + def _torch_profiler_stop(self) 
-> None: + if self._process_torch_profiler_active_tid != _get_thread_id(): + return + + torch.cuda.synchronize() + logger.info(f"[{self.name} @ tid={_get_thread_id()}] Saving trace (blocking)...") + try: + if self._local.profiler_obj: + self._local.profiler_obj.stop() + except Exception as e: + logger.error(f"[{self.name} @ tid={_get_thread_id()}] Error stopping torch profiler: {e}") + finally: + self._local.profiler_obj = None # Explicitly release reference to allow GC + self._process_torch_profiler_active_tid = None + + torch.cuda.synchronize() + + def _nvtx_stop(self) -> None: + torch.cuda.synchronize() + if self._local.nvtx_range_id is not None: + torch.cuda.nvtx.range_end(self._local.nvtx_range_id) + self._local.nvtx_range_id = None + torch.cuda.synchronize() + + def _thread_stop(self) -> None: + if not self._local.is_active: + return + + try: + if self.mode == "torch_profiler": + self._torch_profiler_stop() + elif self.mode == "nvtx": + self._nvtx_stop() + logger.info(f"[{self.name} @ tid={_get_thread_id()}] Profiler stopped.") + except Exception as e: + logger.error(f"[{self.name} @ tid={_get_thread_id()}] Failed to stop profiler: {e}") + finally: + # Mark inactive regardless of success to avoid repeated errors + self._local.is_active = False + + def start(self) -> None: + self.is_active = True + if not self.use_multi_thread: + self._thread_start() + + def stop(self) -> None: + self.is_active = False + if not self.use_multi_thread: + self._thread_stop() + + def multi_thread_helper(self) -> None: + """ + **only for multi-threading use cases** + Worker polling method. Must be called within the inference loop. 
+ """ + if not self.use_multi_thread: + return + + # Catch-all to prevent profiler errors from crashing inference logic + try: + local_active = self._local.is_active + + if self.is_active and not local_active: + self._thread_start() + elif not self.is_active and local_active: + self._thread_stop() + except Exception: + pass + + def cmd(self, cmd_obj: ProfilerCmd) -> None: + if cmd_obj.cmd == "start": + self.start() + elif cmd_obj.cmd == "stop": + self.stop() + else: + raise ValueError(f"Invalid profiler cmd: {cmd_obj.cmd}") diff --git a/lightllm/utils/torch_dtype_utils.py b/lightllm/utils/torch_dtype_utils.py new file mode 100644 index 0000000000..05071e566b --- /dev/null +++ b/lightllm/utils/torch_dtype_utils.py @@ -0,0 +1,12 @@ +import torch + + +def get_torch_dtype(data_type: str) -> torch.dtype: + if data_type in ["fp16", "float16"]: + return torch.float16 + elif data_type in ["bf16", "bfloat16"]: + return torch.bfloat16 + elif data_type in ["fp32", "float32"]: + return torch.float32 + else: + raise ValueError(f"Unsupported datatype {data_type}!") diff --git a/requirements.txt b/requirements.txt index 5331227586..603d0e488f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,7 +33,7 @@ mpmath==1.3.0 multiprocessing-logging==0.3.4 networkx==3.1 ninja==1.11.1 -numpy==1.25.1 +numpy==2.1.3 packaging==24.2 pip==23.0.1 pluggy==1.2.0 @@ -47,7 +47,7 @@ python-dateutil==2.8.2 python-dotenv==1.0.0 PyYAML==6.0.1 pyzmq==25.1.1b2 -regex==2023.6.3 +regex==2026.5.9 requests==2.31.0 rpyc==5.3.1 ruamel.yaml==0.17.32 @@ -59,9 +59,9 @@ six==1.16.0 sniffio==1.3.0 sortedcontainers==2.4.0 toolz==0.12.0 -torch==2.9.1 +torch==2.11.0 tqdm==4.65.0 -transformers==4.57.1 +transformers==5.8.0 tokenizers==0.22.1 urllib3==1.26.16 uvicorn==0.19.0 @@ -71,7 +71,7 @@ zstandard==0.23.0 safetensors==0.4.5 Pillow==10.4.0 tiktoken==0.7.0 -matplotlib==3.8.2 +matplotlib==3.10.0 psutil==5.9.4 prometheus_client==0.20.0 cchardet==2.1.7 @@ -80,19 +80,21 @@ frozendict==2.4.6 atomics==1.0.3 
easydict==1.13 hypercorn==0.18.0 -flashinfer-python==0.6.3 -sgl-kernel==0.3.21 +flashinfer-python==0.6.12 +flashinfer-cubin==0.6.12 +sglang-kernel==0.4.2.post1 httpx==0.28.1 librosa==0.11.0 -cuda_bindings==12.9.0 +cuda_bindings==13.2.0 orjson==3.11.2 setproctitle==1.3.6 xxhash==3.6.0 -torchvision==0.24.1 +torchvision==0.26.0 interegular==0.3.3 partial_json_parser==0.2.1.1.post6 websockets==15.0.1 -cupy-cuda12x==13.6.0 -nixl==0.8.0 -xformers==0.0.33.post2 +cupy-cuda13x==14.0.1 +nixl==1.2.0 +xformers==0.0.35 redis==7.3.0 +litellm>=1.52.0,<1.85 diff --git a/skills/lightllm-profiler-control/SKILL.md b/skills/lightllm-profiler-control/SKILL.md new file mode 100644 index 0000000000..2832a628bd --- /dev/null +++ b/skills/lightllm-profiler-control/SKILL.md @@ -0,0 +1,46 @@ +--- +name: lightllm-profiler-control +description: LightLLM profiler 使用说明。用于需要启动或停止 LightLLM 的 torch_profiler / nvtx profiling 功能时,尤其是查看 --enable_profiling、/profiler_start、/profiler_stop 的使用方法。 +--- + +# LightLLM Profiler 使用说明 + +## 使用场景 + +当用户需要使用 LightLLM profiler 功能时使用本 skill,包括: + +- 启动服务时打开 profiler 能力。 +- 通过 HTTP 接口控制 profiler start / stop。 + +## 启动方式 + +服务启动时增加 `--enable_profiling`: + +```bash +python -m lightllm.server.api_server \ + --model_dir /path/to/model \ + --enable_profiling torch_profiler +``` + +支持值: + +- `torch_profiler`:启用 PyTorch profiler,trace 默认写入 `./trace`,也可通过 `LIGHTLLM_TRACE_DIR` 指定目录。 +- `nvtx`:启用 NVTX range,配合 NVIDIA Nsight Systems 等外部工具采集。 + +未设置 `--enable_profiling` 时,`/profiler_start` 和 `/profiler_stop` 会返回未启用提示。 + +## HTTP 控制接口 + +启动 profiler: + +```bash +curl http://127.0.0.1:8000/profiler_start +``` + +停止 profiler: + +```bash +curl http://127.0.0.1:8000/profiler_stop +``` + +端口 `8000` 替换为服务启动时的 `--port`。 diff --git a/skills/test_model/SKILL.md b/skills/test_model/SKILL.md new file mode 100644 index 0000000000..b77432b6ad --- /dev/null +++ b/skills/test_model/SKILL.md @@ -0,0 +1,59 @@ +--- +name: test-model-common +description: >- + Common override guidance for all 
skills/test_model sub-skills. Applies to + LightLLM model accuracy/speed tests that use lm_eval or lmms_eval, especially + local-completions GSM8K runs. +--- + +# Test Model 通用覆盖规则 + +本目录下所有子 skill 默认继承这些规则。若子 skill 中的命令与这里冲突,优先按这里执行; +只有在用户明确要求在线拉取数据/模型,或本地缓存缺失时,才临时关闭对应离线变量。 + +## lm_eval 启动加速 + +`lm_eval` 每次新进程启动都会加载 task、dataset、tokenizer 和 HuggingFace 相关模块。 +实测 `local-completions + gsm8k --limit 1` 时,默认在线探测模式会在 tokenizer/dataset +初始化阶段等待很久;强制使用本地缓存后,启动耗时明显下降。 + +执行所有 `lm_eval` 精度测试时,默认在命令前加: + +```bash +export HF_ALLOW_CODE_EVAL=1 +export HF_DATASETS_OFFLINE=1 +export TRANSFORMERS_OFFLINE=1 +export HF_HUB_OFFLINE=1 +export http_proxy= +export https_proxy= +export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${HOST:-127.0.0.1},${PD_MASTER_IP:-127.0.0.1} +export NO_PROXY="${no_proxy}" +``` + +然后再执行子 skill 中的 `lm_eval` 命令,例如: + +```bash +lm_eval --model local-completions \ + --model_args "model=${MODEL_NAME},base_url=${BASE_URL},num_concurrent=64,max_retries=3,tokenized_requests=False,tokenizer=${MODEL_DIR}" \ + --tasks gsm8k \ + --batch_size 64 \ + --confirm_run_unsafe_code +``` + +## 使用前检查 + +- 先确认对应数据集和 tokenizer 已经在本地缓存中;如果离线模式报缓存缺失,再切回在线模式补齐缓存。 +- 精度评测前仍然要先做一次 `curl` warmup,确认服务端已经可用。 +- 如果只是压测吞吐,不要用 `lm_eval`;使用轻量 benchmark client,避免 `lm_eval` 的 task/dataset/metric 初始化成本。 +- 记录结果时要把是否启用了离线缓存写入 summary/log,方便比较不同轮次。 + +## 已验证现象 + +在本机 Qwen3.5-0.8B 普通服务上,`lm_eval --limit 1` 实测: + +| 模式 | 耗时 | +|---|---:| +| 默认在线探测 | 约 123s | +| 离线缓存模式 | 约 20s | + +因此,除非有明确理由,`skills/test_model` 下的 `lm_eval` 测试都应默认启用离线缓存变量。 diff --git a/skills/test_model/deepseekr1-base-tp/SKILL.md b/skills/test_model/deepseekr1-base-tp/SKILL.md new file mode 100644 index 0000000000..37be61db6d --- /dev/null +++ b/skills/test_model/deepseekr1-base-tp/SKILL.md @@ -0,0 +1,102 @@ +--- +name: test-model-deepseekr1-base-tp +description: >- + Runs LightLLM DeepSeek-R1 baseline TP gsm8k: single api_server with --tp 8 and + --batch_max_tokens only, no MTP draft, no --dp, no EP MoE (distinct from 
deepseekr1-mtp-tp + which adds MTP). GSM8K lm_eval on localhost port 8089. Requires a dedicated log directory, + api_server and eval logs under that tree, summary.txt as consolidated report, tokenizer + aligned with MODEL_DIR. Use for baseline R1 tensor-parallel accuracy runs without MTP/EP. +--- + +# DeepSeek-R1 **Base–TP**(无 MTP / 无 EP,`--tp 8`)本地 GSM8K 评测 + +**测试标识**:仅 **`--tp 8`** 与 **`--batch_max_tokens`** 等基础推理参数,**无** MTP、`--dp`、`--enable_ep_moe`。用于与 **MTP–TP**(含 MTP 草稿)、**MTP–EP**(EP MoE + TP+DP)等流程区分。 + +启动一组 `api_server`,待端口就绪后对同一服务执行一次 `lm_eval`(任务 **`gsm8k`**,`batch_size` **500**)。整轮产物须落在**同一日志目录**内归档日志与 **`summary.txt`**(见「日志目录」);具体操作见「启动说明」。 + +## 日志目录(含 `summary.txt`) + +- 每次评测先选定或新建**一个日志目录**(例如带时间戳或任务名),与其它测试轮次分开,便于区分管理。 +- **`api_server` 进程的标准输出/错误**须写入该目录下文件(示例同级命名 **`server_base_tp.log`**;也可按变体或日期分子目录,团队任选其一,保持可追溯)。 +- **`summary.txt` 固定放在该日志目录下**,写入本轮启动参数摘要、`lm_eval` 关键结果、失败原因或简要结论;**不再**把「最终总结」散落在当前工作目录或其它路径。 +- `lm_eval` 终端输出建议单独落盘(如 **`eval_gsm8k.log`**);**`summary.txt`** 仍承担**总览结论**角色。 + +## 启动说明 + +本节包含:启动前检查 → 启动服务的命令模板(可变项说明)→ 一条完整 server 命令 → 评测命令。 + +### 启动前检查 + +开跑前先确认资源可用;**不满足则先清理相关进程**,再启动服务与评测。 + +1. **显卡占用**:用 `nvidia-smi`(或与集群一致的占用查看方式)检查目标 GPU 是否被无关任务占满;若有冲突进程,结束后再启动本评测。 +2. 
**端口**:服务固定 **`8089`**(与下文 `lm_eval` 的 `base_url` 一致);用 `ss -tlnp`、`lsof -i :8089` 等确认**无进程监听**该端口;若已被占用,查出 PID 并结束占用进程后再启动。 + +### 启动服务的命令模板(可变项) + +下列命令中出现的可变项含义如下(其余为固定写法): + +| 可变项 | 含义 | +|--------|------| +| `LOG_DIR` | 本轮评测日志目录,建议**绝对路径**;执行前 `export LOG_DIR=…`。 | +| `MODEL_DIR` | 主模型目录,对应 `--model_dir`;与 `lm_eval` 的 `tokenizer` 必须一致。 | +| `server_*.log`、`eval_*.log` | 仅文件名示例,可按任务重命名。 | + +开跑前在同一 shell 中导出路径(将引号内整段替换为本机绝对路径;**勿写死下文未给出的机器路径**): + +```bash +export LOG_DIR='〈日志根目录〉' +export MODEL_DIR='〈主模型目录,对应 --model_dir〉' +``` + +首次试跑可用的**默认 `MODEL_DIR`** 见「执行约定」;与当前环境不符时再改为用户提供的目录。 + +### 一条 server 启动命令(后台落盘) + +本条为 **Base–TP** 固定形态:**`LOADWORKER=18`**,**`--batch_max_tokens 6000`**,**`--tp 8`**,**`--port 8089`**,无 MTP / 无 `--dp` / 无 EP。以下为**可直接执行**的后台启动形式(已含 `nohup` 与日志重定向);若暂时不需落盘,可自行去掉 `nohup`、`>> … 2>&1 &` 并在前台调试。命令中 **`${MODEL_DIR}`、`${LOG_DIR}`** 须已由上文 `export` 赋值。 + +```bash +LOADWORKER=18 \ +nohup python -m lightllm.server.api_server \ + --batch_max_tokens 6000 \ + --model_dir "${MODEL_DIR}" --tp 8 --port 8089 \ + >> "${LOG_DIR}/server_base_tp.log" 2>&1 & +``` + +### 评测命令(服务就绪后执行一次) + +服务就绪后执行(本地回环走代理时用 `no_proxy` / `NO_PROXY` 排除本机)。**`model_args` 中 `tokenizer` 必须与本次 server 的 `--model_dir`(即 **`${MODEL_DIR}`**)为同一字符串路径**。以下为带日志落盘的**完整命令**(`--model_args` 使用双引号以便展开 **`${MODEL_DIR}`**): + +```bash +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 \ +no_proxy=127.0.0.1,localhost,::1 \ +lm_eval --model local-completions \ + --model_args "{\"model\":\"deepseek-ai/DeepSeek-R1\", \"base_url\":\"http://localhost:8089/v1/completions\", \"max_length\": 16384, \"tokenizer\":\"${MODEL_DIR}\"}" \ + --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ + >> "${LOG_DIR}/eval_gsm8k.log" 2>&1 +``` + +- **`LOG_DIR`**:与启动服务一节相同;若仅调试不重定向,去掉 `\` 续行及最后的 `>> "${LOG_DIR}/eval_gsm8k.log" 2>&1` 即可在前台查看输出。 +- **`MODEL_DIR`**:须与 server 启动命令中的 `--model_dir` 一致;路径随环境变化时的默认试跑与向用户确认见「执行约定」。 +- 若环境需要,可同时设置 `NO_PROXY=127.0.0.1,localhost,::1`(或与团队约定一致的列表)。 + +## 执行约定(不要额外写“专用启动脚本”) + 
+**模型目录(随环境变化)**:`MODEL_DIR` 在不同机器上路径不同。**首轮试跑**可先用下列默认(与常见本地部署对应;若本机不存在则跳过默认、直接执行下一步「向用户确认」): + +```bash +export MODEL_DIR=/mtc/models/DeepSeek-R1 +``` + +若按默认路径 **export** 后仍无法启动服务,或日志中出现**明确的模型路径 / 权重加载 / 文件不存在**等错误,**不要反复盲试**:根据日志判断为路径问题时,**请用户提供**当前环境下实际的主模型目录,更新 `export MODEL_DIR=…` 后再执行(且保证 **`MODEL_DIR` 与 `lm_eval` 的 `tokenizer` 仍为同一路径**)。 + +1. **后台启动 server**:用 shell 后台或终端任务跑 `python -m lightllm.server.api_server ...`,**并将该进程输出重定向到本轮日志目录下的日志文件**(见上文「日志目录(含 summary.txt)」);排查问题时 **tail** 该文件,而不是依赖未落盘的终端缓冲。 +2. **不要用 health 接口** 判断就绪;改为探测 **端口 8089 是否处于 listen**(例如 `ss -tlnp` / `lsof -i :8089` 等,与系统一致即可)。 +3. **等待启动**:若端口未就绪,约 **每 20 秒** 查看一次**服务日志文件**,区分仍在启动还是已报错退出;报错则写入日志目录下的 **`summary.txt`**(或先写服务日志再在 `summary.txt` 引用)并停止,不要继续盲等。 +4. **维护 `summary.txt`**:位于**日志目录**;记录**本条使用的完整启动命令**(或等价摘要)、**端口检测结果**、**`lm_eval` 关键输出**;全部结束后在该文件内写**最终汇总**(是否成功、主要指标或失败原因)。可与用户口头摘要对照,但以日志目录中 **`summary.txt`** 为归档准绳。 +5. **全部完成后**:确认日志目录下的 **`summary.txt`** 已包含完整最终总结;原始 server / eval 日志保留在同目录(或子目录)中备查。 + +## 输出文件 + +- **`summary.txt`**:仅位于**本轮日志目录**,作为本次 **Base–TP** 评测的**最终总结**文档。 +- **服务与评测日志**:全部落在**同一日志目录**(建议按任务命名文件或分子目录),不得与未指定目录混写。 diff --git a/skills/test_model/deepseekr1-mtp-ep/SKILL.md b/skills/test_model/deepseekr1-mtp-ep/SKILL.md new file mode 100644 index 0000000000..dd6da77c3b --- /dev/null +++ b/skills/test_model/deepseekr1-mtp-ep/SKILL.md @@ -0,0 +1,149 @@ +--- +name: test-model-deepseekr1-mtp-ep +description: >- + Runs LightLLM DeepSeek-R1 EP MoE + MTP (EAGLE) server variants and GSM8K lm_eval + against localhost. Requires each full run to use a dedicated log directory: persist every + api_server process log under that tree (per-variant subdirectories recommended), + write the consolidated summary to summary.txt in that same log directory, and keep artifacts + separated from other test runs. Use when running DeepSeek-R1 MTP EP accuracy workflows + or when the user asks to run these four server configurations one-by-one with logged results. 
+--- + +# DeepSeek-R1 MTP + EP MoE 串行评测流程 + +按固定顺序依次启动四种 `api_server` 配置;每次待服务就绪后执行 `lm_eval`。整轮评测须落在**同一日志目录**内归档日志与最终结论(见「日志目录」);具体操作见「启动说明」。 + +## 日志目录(含 `summary.txt`) + +- 每次完整评测(四种变体串行)先选定或新建**一个日志目录**(例如带时间戳或任务名),与其它测试轮次分开,便于区分管理。 +- **所有 `api_server` 进程的标准输出/错误**须写入该目录下文件(建议每种变体单独子目录,如 `variant_01_baseline/`、`variant_02_tpsp_mix/`;或同级命名 `server_01_baseline.log` 等,团队任选其一,保持可追溯)。 +- **`summary.txt` 固定放在该日志目录下**,汇总整轮测试:各变体启动参数摘要、`lm_eval` 关键结果、失败原因与最终对比;**不再**把「最终总结」散落在当前工作目录或其它路径。 +- `lm_eval` 终端输出也要有单独的日志文件(如 `eval_gsm8k.log`),**`summary.txt`** 仍承担**总览结论**角色。 + +## 启动说明 + +本节包含:启动前检查 → 启动服务的命令模板(可变项说明)→ 四种完整 server 命令 → 评测命令。 + +### 启动前检查 + +开跑前先确认资源可用;**不满足则先清理相关进程,再进入后续变体**。 + +1. **显卡占用**:用 `nvidia-smi`(或与集群一致的占用查看方式)检查目标 GPU 是否被无关任务占满;若有冲突进程,结束后再启动本评测。 +2. **端口**:服务固定 **`8089`**;用 `ss -tlnp`、`lsof -i :8089` 等确认**无进程监听**该端口;若已被占用,查出 PID 并结束占用进程后再启动。 + +### 启动服务的命令模板(可变项) + +下列命令中出现的可变项含义如下(其余为固定写法): + +| 可变项 | 含义 | +|--------|------| +| `LOG_DIR` | 本轮评测日志目录,建议**绝对路径**;执行前 `export LOG_DIR=…`。 | +| `MODEL_DIR` | 主模型目录,对应 `--model_dir`;与 `lm_eval` 的 `tokenizer` 必须一致。 | +| `MTP_DRAFT_DIR` | MTP 草稿模型目录,对应 `--mtp_draft_model_dir`。 | +| `server_*.log`、`eval_*.log` | 仅文件名示例,可按变体重命名。 | + +开跑前在同一 shell 中导出三类路径(将引号内整段替换为本机绝对路径;**勿写死下文未给出的机器路径**): + +```bash +export LOG_DIR='〈日志根目录〉' +export MODEL_DIR='〈主模型目录,对应 --model_dir〉' +export MTP_DRAFT_DIR='〈MTP 草稿目录,对应 --mtp_draft_model_dir〉' +``` + +首次试跑可用的**默认路径组合**见「执行约定」;与当前环境不符时再改为用户提供的目录。 + +### 四种 server 启动命令(按顺序逐个测) + +每条 **单独** 跑完「启动 → 等就绪 → 评测 → 写入日志目录下的日志 → 停服务」再进入下一条,不要并行多个 server。以下为**可直接执行**的后台启动形式(已含 `nohup` 与日志重定向);若暂时不需落盘,可自行去掉 `nohup`、`>> … 2>&1 &` 并在前台调试。命令中 **`${MODEL_DIR}`、`${MTP_DRAFT_DIR}`** 须已由上文 `export` 赋值。 + +#### 变体 1:基线(EP + MTP) + +```bash +LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 \ +nohup python -m lightllm.server.api_server \ + --enable_ep_moe --model_dir "${MODEL_DIR}" --tp 8 --dp 8 --port 8089 \ + --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 \ + 
--max_req_total_len 56000 \ + --mtp_mode eagle_with_att --mtp_draft_model_dir "${MTP_DRAFT_DIR}" --mtp_step 2 \ + >> "${LOG_DIR}/server_01_baseline.log" 2>&1 & +``` + +#### 变体 2:`--enable_tpsp_mix_mode` + +```bash +LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 \ +nohup python -m lightllm.server.api_server \ + --enable_ep_moe --model_dir "${MODEL_DIR}" --tp 8 --dp 8 --port 8089 \ + --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 \ + --max_req_total_len 56000 \ + --mtp_mode eagle_with_att --mtp_draft_model_dir "${MTP_DRAFT_DIR}" --mtp_step 2 \ + --enable_tpsp_mix_mode \ + >> "${LOG_DIR}/server_02_tpsp_mix.log" 2>&1 & +``` + +#### 变体 3:prefill / decode microbatch overlap + +```bash +LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 \ +nohup python -m lightllm.server.api_server \ + --enable_ep_moe --model_dir "${MODEL_DIR}" --tp 8 --dp 8 --port 8089 \ + --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 \ + --max_req_total_len 56000 \ + --mtp_mode eagle_with_att --mtp_draft_model_dir "${MTP_DRAFT_DIR}" --mtp_step 2 \ + --enable_prefill_microbatch_overlap --enable_decode_microbatch_overlap \ + >> "${LOG_DIR}/server_03_overlap.log" 2>&1 & +``` + +#### 变体 4:overlap + `--enable_dp_prefill_balance` + +```bash +LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 \ +nohup python -m lightllm.server.api_server \ + --enable_ep_moe --model_dir "${MODEL_DIR}" --tp 8 --dp 8 --port 8089 \ + --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 \ + --max_req_total_len 56000 \ + --mtp_mode eagle_with_att --mtp_draft_model_dir "${MTP_DRAFT_DIR}" --mtp_step 2 \ + --enable_prefill_microbatch_overlap --enable_decode_microbatch_overlap \ + --enable_dp_prefill_balance \ + >> "${LOG_DIR}/server_04_overlap_dp_balance.log" 2>&1 & +``` + +### 评测命令(每个变体各执行一次) + +服务就绪后执行(本地回环走代理时用 `no_proxy` / `NO_PROXY` 排除本机)。**`model_args` 中 `tokenizer` 必须与本次 server 的 `--model_dir`(即 
**`${MODEL_DIR}`**)为同一字符串路径**。以下为带日志落盘的**完整命令**(`--model_args` 使用双引号以便展开 **`${MODEL_DIR}`**): + +```bash +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 \ +no_proxy=127.0.0.1,localhost,::1 \ +lm_eval --model local-completions \ + --model_args "{\"model\":\"deepseek-ai/DeepSeek-R1\", \"base_url\":\"http://localhost:8089/v1/completions\", \"max_length\": 16384, \"tokenizer\":\"${MODEL_DIR}\"}" \ + --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code \ + >> "${LOG_DIR}/eval_gsm8k.log" 2>&1 +``` + +- **`LOG_DIR`**:与启动服务一节相同;若仅调试不重定向,去掉 `\` 续行及最后的 `>> "${LOG_DIR}/eval_gsm8k.log" 2>&1` 即可在前台查看输出。 +- **`MODEL_DIR`**:须与 server 启动命令中的 `--model_dir` 一致;路径随环境变化时的默认试跑与向用户确认见「执行约定」。 +- 若环境需要,可同时设置 `NO_PROXY=127.0.0.1,localhost,::1`(或与团队约定一致的列表)。 + +## 执行约定(不要额外写“专用启动脚本”) + +**模型与 MTP 目录(随环境变化)**:`MODEL_DIR`(主模型)、`MTP_DRAFT_DIR`(MTP 草稿)在不同机器上路径不同。**首轮试跑**可先用下列默认组合(与本文档常见部署对应;若本机不存在则跳过默认、直接执行下一步「向用户确认」): + +```bash +export MODEL_DIR=/mtc/models/DeepSeek-R1 +export MTP_DRAFT_DIR=/mtc/models/DeepSeek-R1-NextN +``` + +若按默认路径 **export** 后仍无法启动服务,或日志中出现**明确的模型路径 / 权重加载 / 文件不存在**等错误,**不要反复盲试**:根据日志判断为路径问题时,**请用户提供**当前环境下实际的主模型目录与 MTP 草稿目录,更新 `export MODEL_DIR=…`、`export MTP_DRAFT_DIR=…` 后再执行(且保证 **`MODEL_DIR` 与 `lm_eval` 的 `tokenizer` 仍为同一路径**)。 + +1. **后台启动 server**:用 shell 后台或终端任务跑当前变体的 `python -m lightllm.server.api_server ...`,**并将该进程输出重定向到本轮日志目录下的日志文件**(见上文「日志目录(含 summary.txt)」);排查问题时 tail 该文件,而不是依赖未落盘的终端缓冲。 +2. **不要用 health 接口** 判断就绪;改为探测 **端口 8089 是否处于 listen**(例如 `ss -tlnp` / `lsof -i :8089` 等,与系统一致即可)。 +3. **等待启动**:若端口未就绪,约 **每 20 秒** 查看一次**该变体对应的服务日志文件**,区分仍在启动还是已报错退出;报错则写入日志目录下的 `summary.txt`(或先写变体日志再在该汇总文件中引用)并停止该变体,不要继续盲等。 +4. **维护 `summary.txt`**:位于**日志目录**;随进度追加每个变体的标记块——**本条使用的完整启动命令**(或等价摘要)、**端口检测结果**、**lm_eval 关键输出**;全部结束后在该文件内写**最终汇总**(各配置成败、指标对比或失败原因)。可与用户口头摘要对照,但以日志目录中 **`summary.txt`** 为归档准绳。 +5. **变体之间**:停止上一进程的 server,再启动下一变体(避免端口占用)。 +6. 
**全部完成后**:确认日志目录下的 **`summary.txt`** 已包含完整最终总结;原始 server / eval 日志保留在同目录(或子目录)中备查。 + +## 输出文件 + +- **`summary.txt`**:仅位于**本轮日志目录**,作为整次四变体测试的**最终总结**文档。 +- **服务与评测日志**:全部落在**同一日志目录**(建议按变体分子目录或分文件名),不得与未指定目录混写。 diff --git a/skills/test_model/deepseekr1-mtp-tp/SKILL.md b/skills/test_model/deepseekr1-mtp-tp/SKILL.md new file mode 100644 index 0000000000..d6cbef5271 --- /dev/null +++ b/skills/test_model/deepseekr1-mtp-tp/SKILL.md @@ -0,0 +1,106 @@ +--- +name: test-model-deepseekr1-mtp-tp +description: >- + DeepSeek-R1 MTP-TP test: LightLLM api_server with MTP (EAGLE) draft, tensor parallel + only (--tp 8, no --dp, no EP MoE), plus GSM8K lm_eval on localhost. Distinct from the + MTP-EP-TPDP skill which uses --tp 8 --dp 8 and EP MoE. Requires a dedicated log directory, + summary.txt, tokenizer aligned with MODEL_DIR. Use for TP-only MTP gsm8k accuracy runs. +--- + +# DeepSeek-R1 **MTP–TP**(仅张量并行 `--tp 8`,无 DP / 无 EP)本地 GSM8K 评测 + +**测试标识**:并行方式为 **`--tp 8` 单路 TP**,不包含 **`--dp`** 与 **`--enable_ep_moe`**。用于与 **MTP–EP–TPDP**(`--tp 8 --dp 8` + EP MoE)流水线区分。 + +启动一组 `api_server`(`eagle_with_att` MTP),待就绪后对同一进程执行一次 `lm_eval`(任务 `gsm8k`)。全过程产物落在**同一日志目录**(见「日志目录」);命令与流程见「启动说明」。 + +## 日志目录(含 `summary.txt`) + +- 先选定或新建**一个日志目录**(例如带时间戳或任务名),与其它测试轮次分开。 +- **`api_server` 的标准输出/错误**写入该目录下文件(示例文件名 `server_mtp_tp.log`;可按团队习惯改名或分子目录)。 +- **`summary.txt` 固定放在该日志目录下**,写入本轮启动参数摘要、`lm_eval` 关键结果与简要结论。 +- `lm_eval` 终端输出建议单独落盘(如 `eval_gsm8k.log`);**`summary.txt`** 仍为整次任务的**总览结论**。 + +## 启动说明 + +本节包含:启动前检查 → 启动服务的命令模板(可变项说明)→ 一条完整 server 命令 → 评测命令。 + +### 启动前检查 + +开跑前先确认资源可用;**不满足则先清理相关进程,再启动**。 + +1. **显卡独占**:用 `nvidia-smi` 检查 **8 张 GPU 均无其它推理任务占用**(显存应基本空闲);若有冲突进程,结束后再启动。本评测 `--tp 8` 需占满 8 卡,勿与其它 `api_server` 同卡混跑。 +2. 
**端口独占**:服务固定 **`8089`**;用 `ss -tlnp`、`lsof -i :8089` 等确认 **无进程监听** 该端口;若已被占用,结束占用进程后再启动。 + +### 启动服务的命令模板(可变项) + +下列符号与 EP–TPDP 版评测共用含义: + +| 可变项 | 含义 | +|--------|------| +| `LOG_DIR` | 本轮评测日志目录,建议**绝对路径**;执行前 `export LOG_DIR=…`。 | +| `MODEL_DIR` | 主模型目录,对应 `--model_dir`;与 `lm_eval` 的 `tokenizer` 必须一致。 | +| `MTP_DRAFT_DIR` | MTP 草稿模型目录,对应 `--mtp_draft_model_dir`。 | + +开跑前在同一 shell 中导出路径(引号内替换为本机绝对路径): + +```bash +export LOG_DIR='〈日志根目录〉' +export MODEL_DIR='〈主模型目录,对应 --model_dir〉' +export MTP_DRAFT_DIR='〈MTP 草稿目录,对应 --mtp_draft_model_dir〉' +``` + +首次试跑可用的**默认路径组合**见「执行约定」。 + +### 一条 server 启动命令(后台落盘) + +以下为 **MTP–TP** 固定形态:**`--tp 8`**,**无 `--dp`**。可直接执行的后台形式(已含 `nohup` 与日志重定向);调试时可去掉 `nohup` 与 `>> … 2>&1 &` 改前台。**`${MODEL_DIR}`、`${MTP_DRAFT_DIR}`、`${LOG_DIR}`** 须已由上文 `export` 赋值。 + +`--mem_fraction` 使用 **0.65**(较 0.75 更省显存,MTP 加载主模型与草稿时不易 OOM)。 + +```bash +LOADWORKER=18 \ +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" --tp 8 --port 8089 \ + --mem_fraction 0.65 --batch_max_tokens 6000 \ + --mtp_mode eagle_with_att --mtp_draft_model_dir "${MTP_DRAFT_DIR}" --mtp_step 2 \ + >> "${LOG_DIR}/server_mtp_tp.log" 2>&1 & +``` + +### 评测命令(服务就绪后执行一次) + +本地回环需排除代理:`no_proxy` / `NO_PROXY`。**`tokenizer` 与 `--model_dir`(`${MODEL_DIR}`)须为同一路径**。以下为带日志落盘的**完整命令**: + +```bash +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 \ +no_proxy=127.0.0.1,localhost,::1 \ +lm_eval --model local-completions \ + --model_args "{\"model\":\"deepseek-ai/DeepSeek-R1\", \"base_url\":\"http://localhost:8089/v1/completions\", \"max_length\": 16384, \"tokenizer\":\"${MODEL_DIR}\"}" \ + --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ + >> "${LOG_DIR}/eval_gsm8k.log" 2>&1 +``` + +- **`LOG_DIR`**:与 server 一节一致;若仅调试不重定向,可去掉末尾 `>> "${LOG_DIR}/eval_gsm8k.log" 2>&1`。 +- **`MODEL_DIR`**:与 server 的 `--model_dir` 一致;默认试跑与用户确认路径见「执行约定」。 +- 若环境需要,可同时设置 `NO_PROXY=127.0.0.1,localhost,::1`。 + +## 执行约定(不要额外写“专用启动脚本”) + +**模型与 MTP 目录(随环境变化)**:`MODEL_DIR`、`MTP_DRAFT_DIR` 
在不同机器上路径不同。**首轮试跑**可先使用: + +```bash +export MODEL_DIR=/mtc/models/DeepSeek-R1 +export MTP_DRAFT_DIR=/mtc/models/DeepSeek-R1-NextN +``` + +若默认路径不存在或服务报错指向路径/权重加载失败,**请用户提供**本机实际目录并更新两个 `export`;**保持 `MODEL_DIR` 与 `lm_eval` 中 `tokenizer` 一致**。 + +1. **后台启动 server**:将 `api_server` 输出重定向到日志目录下文件(见「日志目录」);排查时用 `tail` 查看该日志。 +2. **不要用 health 接口** 判断就绪;改为探测 **端口 8089 是否 listen**(例如 `ss -tlnp` / `lsof -i :8089`)。 +3. **等待启动**:端口未就绪时约 **每 20 秒** 查看服务日志,区分仍在启动或已报错;路径类错误按上文向用户确认目录。 +4. **维护 `summary.txt`**:记录完整启动命令摘要(须能看出 **`--tp 8`、无 `--dp`**)、端口检测结果、`lm_eval` 关键输出与最终结论。 +5. **全部完成后**:确认 **`summary.txt`** 完整;server / eval 原始日志保留在同一日志目录备查。 + +## 输出文件 + +- **`summary.txt`**:位于**本轮日志目录**,作为本次 **MTP–TP** 评测的**最终总结**。 +- **服务与评测日志**:与 **`summary.txt`** 落在**同一日志目录**。 diff --git a/skills/test_model/deepseekv32-ep/SKILL.md b/skills/test_model/deepseekv32-ep/SKILL.md new file mode 100644 index 0000000000..0acc78057e --- /dev/null +++ b/skills/test_model/deepseekv32-ep/SKILL.md @@ -0,0 +1,122 @@ +--- +name: test-model-deepseekv32-ep +description: >- + Runs LightLLM DeepSeek-V3.2 EP MoE gsm8k: api_server with --tp 8 --dp 8 --enable_ep_moe, + tool_call_parser deepseekv32, reasoning_parser deepseek-v3, graph_max_batch_size 32, + mem_fraction 0.8, LOADWORKER 14, port 8000 aligned with lm_eval base_url. Requires a + dedicated log directory, api_server and eval logs, summary.txt consolidated report. + lm_eval uses tokenizer_backend=null (server-side tokenization) because local + transformers does not recognize model_type deepseek_v32. Distinct from R1 MTP/Base + flows. Use for V3.2 EP MoE gsm8k accuracy on LightLLM. 
+--- + +# DeepSeek-V3.2 **EP**(`--tp 8`、`--dp 8`、`--enable_ep_moe`)本地 GSM8K 评测 + +**测试标识**:本流程针对 **DeepSeek-V3.2**,启用 **EP MoE**(**`--enable_ep_moe`**)与 **TP+DP**(**`--tp 8 --dp 8`**),并包含 **`tool_call_parser deepseekv32`**、**`reasoning_parser deepseek-v3`**、**`graph_max_batch_size 32`**、**`mem_fraction 0.8`** 等与推理栈相关的参数。与 **Base–R1**、**MTP–TP / MTP–EP**(R1 系列)区分。 + +**监听端口**:`api_server` 与 `lm_eval` 的 **`base_url` 必须使用同一端口**;本流程固定为 **`8000`**(下文 server 命令含 **`--port 8000`**,评测 URL 为 `http://localhost:8000/v1/completions`)。 + +启动一组 `api_server`,待端口就绪后执行一次 `lm_eval`(任务 **`gsm8k`**,`batch_size` **500**)。整轮产物须落在**同一日志目录**内归档日志与 **`summary.txt`**(见「日志目录」);具体操作见「启动说明」。 + +## 日志目录(含 `summary.txt`) + +- 每次评测先选定或新建**一个日志目录**(例如带时间戳或任务名),与其它测试轮次分开,便于区分管理。 +- **所有 `api_server` 进程的标准输出/错误**须写入该目录下文件(示例同级命名 **`server_v32_ep.log`**;也可分子目录,团队任选其一,保持可追溯)。 +- **`summary.txt` 固定放在该日志目录下**,写入本轮启动参数摘要、`lm_eval` 关键结果、失败原因或简要对比;**不再**把「最终总结」散落在当前工作目录或其它路径。 +- `lm_eval` 终端输出也要有单独的日志文件(如 **`eval_gsm8k.log`**);**`summary.txt`** 仍承担**总览结论**角色。 + +## 启动说明 + +本节包含:启动前检查 → 启动服务的命令模板(可变项说明)→ 一条完整 server 命令 → 评测命令。 + +### 启动前检查 + +开跑前先确认资源可用;**不满足则先清理相关进程**,再启动服务与评测。 + +1. **显卡占用**:用 `nvidia-smi`(或与集群一致的占用查看方式)检查目标 GPU 是否被无关任务占满;若有冲突进程,结束后再启动本评测(本配置为 **TP+DP**,需足够 GPU 资源)。 +2. 
**端口**:服务固定 **`8000`**(与下文 `lm_eval` 的 `base_url` 端口一致);用 `ss -tlnp`、`lsof -i :8000` 等确认**无进程监听**该端口;若已被占用,查出 PID 并结束占用进程后再启动。 + +### 启动服务的命令模板(可变项) + +下列命令中出现的可变项含义如下(其余为固定写法): + +| 可变项 | 含义 | +|--------|------| +| `LOG_DIR` | 本轮评测日志目录,建议**绝对路径**;执行前 `export LOG_DIR=…`。 | +| `MODEL_DIR` | 主模型目录,对应 `--model_dir`;本流程 `lm_eval` 使用 `tokenizer_backend=null`(分词在 server 端完成),评测命令不引用该路径。 | +| `server_*.log`、`eval_*.log` | 仅文件名示例,可按任务重命名。 | + +开跑前在同一 shell 中导出路径(将引号内整段替换为本机绝对路径;**勿写死下文未给出的机器路径**): + +```bash +export LOG_DIR='〈日志根目录〉' +export MODEL_DIR='〈主模型目录,对应 --model_dir〉' +``` + +首次试跑可用的**默认 `MODEL_DIR`** 见「执行约定」;与当前环境不符时再改为用户提供的目录。 + +### 一条 server 启动命令(后台落盘) + +本条为 **DeepSeek-V3.2 EP** 固定形态:**`LOADWORKER=14`**,**`--tp 8 --dp 8 --enable_ep_moe`**,**`--port 8000`**,以及 **`tool_call_parser` / `reasoning_parser` / `graph_max_batch_size` / `mem_fraction`** 等与脚本一致的参数。以下为**可直接执行**的后台启动形式(已含 `nohup` 与日志重定向);若暂时不需落盘,可自行去掉 `nohup`、`>> … 2>&1 &` 并在前台调试。命令中 **`${MODEL_DIR}`、`${LOG_DIR}`** 须已由上文 `export` 赋值。 + +```bash +LOADWORKER=14 \ +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" --tp 8 \ + --graph_max_batch_size 32 \ + --tool_call_parser deepseekv32 \ + --mem_fraction 0.8 \ + --reasoning_parser deepseek-v3 \ + --dp 8 --enable_ep_moe \ + --port 8000 \ + >> "${LOG_DIR}/server_v32_ep.log" 2>&1 & +``` + +### 评测命令(服务就绪后执行一次) + +服务就绪后执行(本地回环走代理时用 `no_proxy` / `NO_PROXY` 排除本机)。**`base_url` 中的端口须为 `8000`,与 `api_server` 的 `--port` 一致。** 以下为带日志落盘的**完整命令**: + +```bash +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 \ +no_proxy=127.0.0.1,localhost,::1 \ +lm_eval --model local-completions \ + --model_args '{"model":"deepseek-ai/DeepSeek-V3.2", "base_url":"http://localhost:8000/v1/completions", "tokenizer_backend":null, "eos_string":"<|end▁of▁sentence|>"}' \ + --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ + >> "${LOG_DIR}/eval_gsm8k.log" 2>&1 +``` + +> **为什么用 `tokenizer_backend=null` 而非 `tokenizer=${MODEL_DIR}`**:`local-completions` 默认会用 
`transformers.AutoTokenizer.from_pretrained(${MODEL_DIR})` 在本地加载 HF tokenizer,但当前环境的 **transformers 不识别 `model_type: deepseek_v32`**(`KeyError: 'deepseek_v32'` → rope `AttributeError`),评测在加载 tokenizer 阶段即崩溃,根本跑不到推理。设 **`tokenizer_backend=null`** 后 lm_eval 不再本地加载 tokenizer,直接把 **prompt 文本**发给 server,由 lightllm 服务端用真正的 deepseek_v32 tokenizer 分词——更贴合实际且无需本地 HF 适配。`eos_string` 显式给出 DeepSeek 的结束符以消除 “Cannot determine EOS string” 告警(gsm8k 本身也带 stop 序列)。`tokenized_requests` 会被自动关闭、不再做 context 长度校验(gsm8k 5-shot prompt 很短,无需截断)。 +> 若哪天升级了能识别 `deepseek_v32` 的 transformers,可改回 `"tokenizer":"${MODEL_DIR}"` 形式(届时 tokenizer 须与 `--model_dir` 同一路径;另外上文评测命令的 `--model_args` 使用单引号,shell 不会展开 `${MODEL_DIR}`,改回时需换成双引号并转义内部引号,或直接写入实际路径)。 + +- **`LOG_DIR`**:与启动服务一节相同;若仅调试不重定向,去掉 `\` 续行及最后的 `>> "${LOG_DIR}/eval_gsm8k.log" 2>&1` 即可在前台查看输出。 +- **tokenizer**:本命令用 `tokenizer_backend=null`,评测端不再依赖 `MODEL_DIR` 下的 HF tokenizer(分词在 server 端完成),故 `MODEL_DIR` 路径变化不影响评测命令;server 启动命令中的 `--model_dir` 仍按「执行约定」处理。 +- 若环境需要,可同时设置 `NO_PROXY=127.0.0.1,localhost,::1`(或与团队约定一致的列表)。 + +## 执行约定(不要额外写“专用启动脚本”) + +**模型目录(随环境变化)**:`MODEL_DIR` 在不同机器上路径不同。**首轮试跑**可先用下列默认(与本文档常见部署对应;若本机不存在则跳过默认、直接执行下一步「向用户确认」): + +```bash +export MODEL_DIR=/mtc/models/DeepSeek-V3.2 +``` + +若按默认路径 **export** 后仍无法启动服务,或日志中出现**明确的模型路径 / 权重加载 / 文件不存在**等错误,**不要反复盲试**:根据日志判断为路径问题时,**请用户提供**当前环境下实际的主模型目录,更新 `export MODEL_DIR=…` 后再执行(本流程 `lm_eval` 使用 `tokenizer_backend=null`,评测命令不依赖 `MODEL_DIR`,无需同步修改;仅当改回 `"tokenizer":"${MODEL_DIR}"` 形式时,才须保证 **`MODEL_DIR` 与 `lm_eval` 的 `tokenizer` 为同一路径**)。 + +1. **后台启动 server**:用 shell 后台或终端任务跑 `python -m lightllm.server.api_server ...`,**并将该进程输出重定向到本轮日志目录下的日志文件**(见上文「日志目录(含 summary.txt)」);排查问题时 **tail** 该文件,而不是依赖未落盘的终端缓冲。 +2. **不要用 health 接口** 判断就绪;改为探测 **端口 8000 是否处于 listen**(例如 `ss -tlnp` / `lsof -i :8000` 等,与系统一致即可)。 +3. **等待启动**:若端口未就绪,约 **每 20 秒** 查看一次**服务日志文件**,区分仍在启动还是已报错退出;报错则写入日志目录下的 **`summary.txt`**(或先写服务日志再在 `summary.txt` 引用)并停止,不要继续盲等。 +4. **维护 `summary.txt`**:位于**日志目录**;记录**本条使用的完整启动命令**(须能看出 **`--tp 8`、`--dp 8`、EP MoE**)、**端口检测结果**、**`lm_eval` 关键输出**;全部结束后在该文件内写**最终汇总**(是否成功、主要指标或失败原因)。可与用户口头摘要对照,但以日志目录中 **`summary.txt`** 为归档准绳。 +5. 
**全部完成后**:确认日志目录下的 **`summary.txt`** 已包含完整最终总结;原始 server / eval 日志保留在同目录(或子目录)中备查。 + +### 服务启动 OK 判定经验(本流程补充) + +- **不要只看“主进程在不在”**:`python -m lightllm.server.api_server` 进程存活不代表可用;必须至少满足“`8000` 已 listen”再进入评测。 +- **长时间加载不等于失败**:DeepSeek-V3.2 EP 首次加载可能持续数分钟。若日志持续出现 `Loading model weights ...` 进度推进,视为“仍在启动”,继续按 20 秒间隔观察。 +- **判定“启动 OK”建议三要素**:① `8000` 端口监听;② 服务日志无新的 `OutOfMemoryError`/Traceback;③ 用一条最小请求(如 1 条 completions/chat 请求)拿到 200 或有效响应,再跑 `lm_eval`。 +- **出现 OOM 要先清残留再重试**:一旦日志出现 `torch.OutOfMemoryError`,先结束该轮 `api_server` 及其派生进程(含 `hypercorn`/`lightllm::...` 子进程),确认 `8000` 释放后再重启,避免“旧进程占资源导致假失败”。 +- **重试优先调启动参数而非盲等**:若 OOM 发生在权重加载阶段,优先降低加载/显存压力(例如使用更保守的 `mem_fraction`),并在 `summary.txt` 记录“失败参数 -> 重试参数 -> 结果”。 + +## 输出文件 + +- **`summary.txt`**:仅位于**本轮日志目录**,作为本次 **DeepSeek-V3.2 EP** 评测的**最终总结**文档。 +- **服务与评测日志**:全部落在**同一日志目录**(建议按任务命名文件或分子目录),不得与未指定目录混写。 diff --git a/skills/test_model/qwen2.5-14b-fp8kv-gsm8k/SKILL.md b/skills/test_model/qwen2.5-14b-fp8kv-gsm8k/SKILL.md new file mode 100644 index 0000000000..6c932e43e2 --- /dev/null +++ b/skills/test_model/qwen2.5-14b-fp8kv-gsm8k/SKILL.md @@ -0,0 +1,137 @@ +--- +name: test-model-qwen2.5-14b-fp8kv-gsm8k +description: >- + LightLLM Qwen2.5-14B-Instruct GSM8K with FP8 KV cache quantization: either fp8kv_sph + (per-head calibration JSON) or fp8kv_spt (per-tensor calibration JSON). Single api_server + tp 2 fixed HTTP port 8089 (not configurable), lm_eval local-completions. Assign GPUs via nvidia-smi then export + CUDA_VISIBLE_DEVICES. Before starting api_server, cwd must be LightLLM repo root; pass + --kv_quant_calibration_config_path as the repo-relative path from the table row that matches + --llm_kv_type (fp8kv_sph with per-head JSON only; fp8kv_spt with per-tensor JSON only; no absolute path, + no REPO_ROOT/CALIB_JSON shell concatenation). If default MODEL_DIR path is missing or + load fails with path errors, ask the user for the correct MODEL_DIR. 
LOG_DIR, + summary.txt, port listen checks (not health), no_proxy, background server with log redirect. + Two variants documented in one skill. +--- + +# Qwen2.5-14B-Instruct **FP8 KV Cache(`fp8kv_sph` / `fp8kv_spt`)** GSM8K 评测 + +**测试标识**:同一 **`Qwen2.5-14B-Instruct`** 权重下,用 **`api_server`** 跑 **单机 TP=2**;通过 **`--llm_kv_type`** 区分两种 **FP8 KV 量化形态**,每种形态对应 **不同的标定 JSON**(**per-head** vs **per-tensor**)。**每一轮只选其中一种形态**跑通:先起服务,再 **`lm_eval`**。 + +| 形态 | `--llm_kv_type` | 标定配置(相对 LightLLM 仓库根目录) | +|------|-----------------|--------------------------------------| +| **SPH**(per-head) | **`fp8kv_sph`** | **`test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen2.5_14b.json`** | +| **SPT**(per-tensor) | **`fp8kv_spt`** | **`test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen2.5_14b.json`** | + +**配对规则(必守)**:**`--llm_kv_type` 与 `--kv_quant_calibration_config_path` 必须取自上表同一行**。 **`fp8kv_sph` 只能** 搭配 **per_head** 标定 JSON;**`fp8kv_spt` 只能** 搭配 **per_tensor** 标定 JSON。只改其一会导致启动或运行时报错。**不要**在命令里写 **`--llm_kv_type "${LLM_KV_TYPE}"`** 却**固定**另一条 `--kv_quant_calibration_config_path`(二者会漂移);应像下文 **按形态分块**:每一块内两条参数**字面一致、成对出现**。 + +**端口**:**固定 `8089`**,**不可改**(与脚本一致;**`--port`** 与 **`lm_eval` 的 `base_url`** 均须为 **`8089`**)。**`--tp 2`** 需要 **2 张 GPU**。 + +整轮产物落在**同一日志目录**:**`summary.txt`**、**`server.log`**、**`eval_gsm8k.log`**;**不要**写复杂聚合启动脚本,按下面块**手动**或复制为独立命令执行。 + +## 日志目录(含 `summary.txt`) + +- 每次评测新建或选定 **`LOG_DIR`**(建议带任务名与时间戳,例如 `…/qwen25_fp8kv_sph_〈时间〉` 与 `…/qwen25_fp8kv_spt_〈时间〉` **分开**,便于对比两种形态)。 +- **`api_server`** 标准输出/错误 → **`"${LOG_DIR}/server.log"`**(后台 **`nohup … >> … 2>&1 &`**)。 +- **`lm_eval`** → **`"${LOG_DIR}/eval_gsm8k.log"`**。 +- **`summary.txt`**:本轮 **`--llm_kv_type`(`fp8kv_sph` / `fp8kv_spt`)**、启动命令摘要、端口检测结果、**`lm_eval` 要点**、失败原因与结论。 + +## 启动前检查 + +1. **显卡**:**`--tp 2`**,需 **2 张物理 GPU**。**不要写死卡号**:先 **`nvidia-smi`**,再 **`export CUDA_VISIBLE_DEVICES='i,j'`**。 +2. 
**端口**:**`8089`** 未被占用(**`ss -tlnp`** / **`lsof -i :8089`**)。 +3. **标定文件与 KV 形态**:**`--kv_quant_calibration_config_path`** 须为 **与本轮 `--llm_kv_type` 上表同一行** 的相对路径(**不要**写成磁盘绝对路径;**不要** `fp8kv_sph` 配 per_tensor 或 `fp8kv_spt` 配 per_head)。启动 **`python -m lightllm.server.api_server` 时,shell 当前目录须已是仓库根**(先由 Agent **`cd` 到检出根** 再执行 `nohup`;或在一行里 **`cd '…根…' && nohup python …`**)。确认 `os.path.exists` 意义下该相对路径可读。**禁止** `export CALIB_JSON="${REPO_ROOT}/…"` 这类环境变量拼接。 +4. **模型目录 `MODEL_DIR`**:启动前确认路径存在(例如 **`test -d "${MODEL_DIR}"`**)且内含权重;默认可用 **`/mtc/models/Qwen2.5-14B-Instruct`**。若默认不存在、或服务 / 日志出现 **找不到模型目录、权重文件缺失、路径类加载失败** 等,**不要盲换路径重试**:**向用户询问**本机正确的 **`MODEL_DIR` 绝对路径**,待用户回复后更新 **`export MODEL_DIR=…`**,并在 **`summary.txt`** 中记录最终采用的路径;**`--model_dir` 与 `lm_eval` 的 `tokenizer` 必须始终为同一字符串**。 +5. **代理**:启动 server 前 **`export http_proxy=`**、**`export https_proxy=`**;评测时设置 **`no_proxy`**(见评测命令)。 + +## 可变项 + +| 变量 | 含义 | +|------|------| +| `LOG_DIR` | 本轮日志目录(绝对路径)。 | +| `MODEL_DIR` | **`--model_dir`** 与 **`lm_eval` 的 `tokenizer`**,须为**同一路径**。默认试跑 **`/mtc/models/Qwen2.5-14B-Instruct`**;**不可用或报错时向用户询问**正确目录后再 `export`(见「执行约定」与启动前检查第 4 条)。 | +| `LLM_KV_TYPE` | 即 **`--llm_kv_type`**:**`fp8kv_sph`**(上表 **SPH / per-head**)或 **`fp8kv_spt`**(上表 **SPT / per-tensor**);本轮只选其一;**须与下一行的标定文件同表同行成对**。 | +| 标定 JSON(**相对仓库根**) | **`--kv_quant_calibration_config_path`**:仅允许为上表中 **与当前 `LLM_KV_TYPE` 同一行** 的那一个相对路径;依赖 **`cd` 到仓库根** 后的 cwd,**不要**写绝对路径,勿用 **`${REPO_ROOT}/…`** 拼接;**禁止**与 **`--llm_kv_type` 交叉混用**(见上文配对规则)。 | +| `CUDA_VISIBLE_DEVICES` | 两张卡,**`nvidia-smi` 后 export**。 | + +**标定路径写法**:**`--kv_quant_calibration_config_path`** 里 **只写相对路径**,且 **必须是上表与本轮 `--llm_kv_type` 同一行的那一个**。由 Agent 保证 **`nohup` 所在进程的工作目录为 LightLLM 根**(先 `cd` 再启动,或与 `nohup` 写在同一行的 `cd … &&`)。 + +## 启动服务(后台 + 日志) + +**不要用 health 接口**判断就绪;以 **端口 listen**(**`8089`**)结合 **`server.log`** 为准;约 **每 20 秒**看日志直至就绪或报错。 + +以下为 **两种成对配置**,**每次整段复制其一**;**不要**混用「变量展开的 `--llm_kv_type` + 写死的标定路径」以免与上表不一致。 + +### 形态 
A:**SPH**(`fp8kv_sph` + per-head 标定) + +```bash +export http_proxy= +export https_proxy= + +export LOG_DIR='〈本轮日志目录〉' +export MODEL_DIR='/mtc/models/Qwen2.5-14B-Instruct' + +# 〈LightLLM 仓库根〉由 Agent 改为本机检出目录的绝对路径,仅用于 cd +cd '〈LightLLM 仓库根〉' + +LOADWORKER=18 CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES}" \ +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --port 8089 \ + --llm_kv_type fp8kv_sph \ + --kv_quant_calibration_config_path test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen2.5_14b.json \ + >> "${LOG_DIR}/server.log" 2>&1 & +``` + +### 形态 B:**SPT**(`fp8kv_spt` + per-tensor 标定) + +```bash +export http_proxy= +export https_proxy= + +export LOG_DIR='〈本轮日志目录〉' +export MODEL_DIR='/mtc/models/Qwen2.5-14B-Instruct' + +cd '〈LightLLM 仓库根〉' + +LOADWORKER=18 CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES}" \ +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --port 8089 \ + --llm_kv_type fp8kv_spt \ + --kv_quant_calibration_config_path test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen2.5_14b.json \ + >> "${LOG_DIR}/server.log" 2>&1 & +``` + +(**`--kv_quant_calibration_config_path`** 均为 **相对仓库根**;**不要**写成绝对路径。) + +- **`lm_eval` 的 `base_url`**:本 skill 约定 **`http://127.0.0.1:8089/v1/completions`**(**端口固定**,评测与 **`no_proxy`** 均按 **`127.0.0.1`**);**`api_server` 须 `--port 8089`**(默认不显式 **`--host`** 时一般为 `0.0.0.0`,本机访问用 **`127.0.0.1`** 即可)。 + +## 评测命令(服务就绪后) + +**`tokenizer` 与 `MODEL_DIR` 对齐**(与其它 test_model skill 一致): + +```bash +export http_proxy= +export https_proxy= + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 \ +no_proxy=localhost,127.0.0.1,0.0.0.0,::1 \ +lm_eval --model local-completions \ + --model_args "{\"model\":\"Qwen/Qwen2.5-14B-Instruct\", \"base_url\":\"http://127.0.0.1:8089/v1/completions\", \"max_length\": 16384, \"tokenizer\":\"${MODEL_DIR}\"}" \ + --tasks gsm8k --batch_size 64 --confirm_run_unsafe_code \ + >> 
"${LOG_DIR}/eval_gsm8k.log" 2>&1 +``` + +## 执行约定 + +### 模型目录(`MODEL_DIR`) + +- 不同机器上路径不同。**首轮**可 **`export MODEL_DIR=/mtc/models/Qwen2.5-14B-Instruct`**(与命令模板一致)。 +- **在启动 `api_server` 之前**:若该路径**不存在**,或启动后日志明确为 **模型路径 / 权重 / 文件不存在** 等问题,**停止盲试**,**向用户询问**当前环境下 **Qwen2.5-14B-Instruct** 的实际目录绝对路径;用户给出后更新 **`export MODEL_DIR='…用户提供的绝对路径…'`**,并保证后续 **`--model_dir`** 与 **`lm_eval` 的 `tokenizer`** 使用该同一变量;将最终采用的 **`MODEL_DIR`** 写入 **`summary.txt`**。 + +1. **两种形态分两轮测**:先完整跑 **形态 A(SPH)**(含 **`summary.txt`**),再换 **`LOG_DIR`** 并完整使用 **形态 B(SPT)** 启动块(**`--llm_kv_type fp8kv_spt` 与 per-tensor 标定路径须同时来自上表 SPT 行**);不要混在同一 **`server.log`** 里,也不要只改 **`--llm_kv_type`** 而不换标定 JSON。 +2. **端口**:确认 **`8089`** 进入 **LISTEN** 后再跑 **`lm_eval`**(**勿改端口**)。 +3. **结束后**:关闭 **`api_server`** 进程,释放 GPU 与端口。 +4. **错误**:将摘要写入 **`summary.txt`**,并在对话中说明关键日志行。 diff --git a/skills/test_model/qwen3-8b-gsm8k-scenarios/SKILL.md b/skills/test_model/qwen3-8b-gsm8k-scenarios/SKILL.md new file mode 100644 index 0000000000..632e3c4121 --- /dev/null +++ b/skills/test_model/qwen3-8b-gsm8k-scenarios/SKILL.md @@ -0,0 +1,196 @@ +--- +name: test-model-qwen3-8b-gsm8k-scenarios +description: >- + LightLLM Qwen3-8B GSM8K multi-scenario regression: seven isolated api_server configs + (baseline, vllm-fp8w8a8 quant, tpsp mix, tpsp with dp2 and dp prefill balance, cpu cache, + int8kv on top of cpu cache, disk cache with LIGHTLLM_DISK_CACHE_PROMPT_LIMIT_LENGTH). + Each scenario then lm_eval gsm8k batch 500. Scenarios 5–7 run lm_eval twice for cache + hit. Per-scenario LOG_DIR, server.log, eval logs, summary.txt. Default MODEL_DIR + /mtc/models/qwen3-8b; DISK_CACHE_DIR /mtc/test/tmp/ for scenario 7; ask user if paths + invalid. Fixed HTTP port 8089 (not configurable). nvidia-smi GPUs, port listen not health, + clear proxies and no_proxy. 
+--- + +# Qwen3-8B **多场景 GSM8K 回归** + +同一 **`MODEL_DIR`(Qwen3-8B 权重)** 下,按 **七种 `api_server` 配置** 依次各跑一轮:**启动服务 → 端口与日志就绪 → `lm_eval`**。场景 **5、6、7** 在相同服务配置下 **`lm_eval` 连续执行两次**(缓存预热与命中后效率/精度对照,与历史脚本注释一致)。 + +**端口**:**固定 `8089`**(与 **`--port`**、**`lm_eval` 的 `base_url`** 一致;**不作为环境变量**)。 + +**评测**:**`lm_eval`**,**`tasks gsm8k`**,**`batch_size 500`**,**`model`:`qwen/qwen3-8b`**。**`tokenizer` 与 `MODEL_DIR` 须为同一目录路径**。 + +## 场景总览 + +| # | 名称 | `api_server` 要点 | `lm_eval` | +|---|------|-------------------|-----------| +| 1 | 基线 | **`--tp 2`**,无额外开关 | 1 次 | +| 2 | FP8 量化 | **`--quant_type vllm-fp8w8a8`**(在场景 1 基础上) | 1 次 | +| 3 | TP-SP 混合 | **`--enable_tpsp_mix_mode`**(**`--tp 2`**,无 **`--dp`**) | 1 次 | +| 4 | TP-SP + DP2 + DP prefill 均衡 | **`--tp 2 --dp 2`**、**`--enable_tpsp_mix_mode`**、**`--enable_dp_prefill_balance`** | 1 次 | +| 5 | CPU Cache | **`--tp 2 --dp 2`**,**`--max_total_token_num 200000`**,**`--enable_cpu_cache`**,**`--cpu_cache_storage_size 128`**,**`--cpu_cache_token_page_size 128`** | **2 次** | +| 6 | CPU Cache + INT8 KV | 在场景 5 基础上增加 **`--llm_kv_type int8kv`** | **2 次** | +| 7 | Disk Cache | **`LIGHTLLM_DISK_CACHE_PROMPT_LIMIT_LENGTH=128`**;**`--tp 2 --dp 2`**,**`--max_total_token_num 200000`**,CPU cache 与 **disk cache** 一组参数(见命令块) | **2 次** | + +**算力说明**:历史脚本使用 **`CUDA_VISIBLE_DEVICES`** 指向 **2 张卡**;场景 **4–7** 含 **`--dp 2`**。执行前须结合 **`nvidia-smi`** 与 **LightLLM 对 tp/dp 的资源说明** 确认本机 GPU 数与映射是否满足;不满足时 **向用户询问** 正确启动方式,**不要盲试**。 + +## 日志目录(含 `summary.txt`) + +- **每个场景**使用独立 **`LOG_DIR`**。 +- **`api_server`** → **`"${LOG_DIR}/server.log"`**(推荐 **`nohup … >> … 2>&1 &`**)。 +- **`lm_eval`**:第一次 **`"${LOG_DIR}/eval_gsm8k.log"`**;第二次(场景 5–7)**`"${LOG_DIR}/eval_gsm8k_run2.log"`**。 +- **`summary.txt`**:本场景完整启动参数、**`lm_eval` 摘要**、端口与日志就绪情况、两轮评测说明(若适用)、结论与失败原因。 + +## 启动前检查 + +1. **显卡**:**`nvidia-smi`** 后 **`export CUDA_VISIBLE_DEVICES`**;**不要写死卡号**;**`--tp` / `--dp`** 与卡数须匹配本机规范。 +2. **端口**:每轮前确认 **`8089`** 空闲;上一轮结束后 **终止 `api_server`** 再启下一轮。 +3. 
**`MODEL_DIR`**:见 **「路径约定」**;**`test -d "${MODEL_DIR}"`**。 +4. **`DISK_CACHE_DIR`(仅场景 7)**:见 **「路径约定」**;**`mkdir -p`** 后须可写。 +5. **代理**:**`api_server` / `lm_eval` 前** 置空 **`http_proxy` / `https_proxy`**;**`lm_eval`** 配置 **`no_proxy`**(见评测块)。 + +## 路径约定(`MODEL_DIR` 与 `DISK_CACHE_DIR`) + +- **`MODEL_DIR`**:**首轮试跑默认** **`/mtc/models/qwen3-8b`**。若目录不存在或加载失败,**向用户询问** 本机正确路径;**`--model_dir` 与 `lm_eval` 的 `tokenizer` 保持一致**。 +- **`DISK_CACHE_DIR`(场景 7)**:**默认** **`/mtc/test/tmp/`**;不可写或不存在时 **向用户询问** 可写目录;**`summary.txt`** 记录最终路径。 + +## 可变项 + +| 变量 | 含义 | +|------|------| +| `LOG_DIR` | 当前场景日志根目录。 | +| `MODEL_DIR` | **`--model_dir`**;**`lm_eval` 的 `tokenizer`**。 | +| `BIND_URL_HOST` | **`base_url` 主机**;常用 **`127.0.0.1`**。 | +| `CUDA_VISIBLE_DEVICES` | 由 **`nvidia-smi`** 决定;与 tp/dp 组合须匹配环境。 | +| `DISK_CACHE_DIR` | 场景 7 的 **`--disk_cache_dir`**;默认 **`/mtc/test/tmp/`**。 | + +**开跑前导出示例**: + +```bash +export LOG_DIR='〈本场景日志目录〉' +export MODEL_DIR='/mtc/models/qwen3-8b' +export DISK_CACHE_DIR='/mtc/test/tmp/' +export BIND_URL_HOST='127.0.0.1' +# export CUDA_VISIBLE_DEVICES='6,7' +``` + +## 服务就绪判定 + +**不要使用 HTTP health 作为唯一依据**。结合 **`8089` 是否 LISTEN** 与 **`server.log`**;可约 **每 20 秒** 查看一次直至可评测或确认失败。 + +## `lm_eval` 命令模板(单次) + +```bash +export http_proxy= +export https_proxy= + +export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${BIND_URL_HOST} + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 \ +lm_eval --model local-completions \ + --model_args "{\"model\":\"qwen/qwen3-8b\", \"base_url\":\"http://${BIND_URL_HOST}:8089/v1/completions\", \"max_length\": 16384, \"tokenizer\":\"${MODEL_DIR}\"}" \ + --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ + >> "${LOG_DIR}/eval_gsm8k.log" 2>&1 +``` + +场景 **5–7** 第二次:将重定向改为 **`>> "${LOG_DIR}/eval_gsm8k_run2.log" 2>&1`**。 + +## 各场景 `api_server` 命令模板 + +以下省略 **`export http_proxy=` / `export https_proxy=`**、**`LOADWORKER=18`**、**`CUDA_VISIBLE_DEVICES`**、**`nohup`** 与 **`>> "${LOG_DIR}/server.log" 2>&1 &`**;实际执行时与其它 acc skill 一致自行补全。 + +### 
场景 1:基线 + +```bash +python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --port 8089 +``` + +### 场景 2:FP8 量化(`vllm-fp8w8a8`) + +```bash +python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --port 8089 \ + --quant_type vllm-fp8w8a8 +``` + +### 场景 3:TP-SP 混合 + +```bash +python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --port 8089 \ + --enable_tpsp_mix_mode +``` + +### 场景 4:TP-SP + DP2 + DP prefill 均衡 + +```bash +python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --dp 2 \ + --port 8089 \ + --enable_tpsp_mix_mode \ + --enable_dp_prefill_balance +``` + +### 场景 5:CPU Cache(`lm_eval` 两次) + +```bash +python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --dp 2 \ + --port 8089 \ + --max_total_token_num 200000 \ + --enable_cpu_cache \ + --cpu_cache_storage_size 128 \ + --cpu_cache_token_page_size 128 +``` + +### 场景 6:CPU Cache + INT8 KV(`lm_eval` 两次) + +```bash +python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --dp 2 \ + --port 8089 \ + --max_total_token_num 200000 \ + --enable_cpu_cache \ + --cpu_cache_storage_size 128 \ + --cpu_cache_token_page_size 128 \ + --llm_kv_type int8kv +``` + +### 场景 7:Disk Cache(`lm_eval` 两次) + +与历史脚本一致:在 **`python`** 前加 **`LIGHTLLM_DISK_CACHE_PROMPT_LIMIT_LENGTH=128`**。 + +```bash +LIGHTLLM_DISK_CACHE_PROMPT_LIMIT_LENGTH=128 \ +python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --dp 2 \ + --port 8089 \ + --max_total_token_num 200000 \ + --enable_cpu_cache \ + --cpu_cache_storage_size 64 \ + --cpu_cache_token_page_size 128 \ + --enable_disk_cache \ + --disk_cache_storage_size 256 \ + --disk_cache_dir "${DISK_CACHE_DIR}" +``` + +## 执行约定 + +1. **顺序**:**1 → 7** 严格递增;每步 **新 `LOG_DIR`**,**先停旧服务**。 +2. **场景 5–7**:**`lm_eval` 各执行两次**,并在 **`summary.txt`** 说明 run1 / run2 目的。 +3. **`MODEL_DIR` / `DISK_CACHE_DIR`**:遵循 **「路径约定」**。 +4. 
**收尾**:全部结束后释放进程、端口与 GPU。 diff --git a/skills/test_model/qwen3-8b-pd-nixl/SKILL.md b/skills/test_model/qwen3-8b-pd-nixl/SKILL.md new file mode 100644 index 0000000000..a1775d09d8 --- /dev/null +++ b/skills/test_model/qwen3-8b-pd-nixl/SKILL.md @@ -0,0 +1,225 @@ +--- +name: test-model-qwen3-8b-pd-nixl +description: >- + LightLLM Qwen3-8b PD disaggregation gsm8k: pd_master on 8089, prefill on 8001, + decode on 8002, tp 2 each. Assign four GPUs via nvidia-smi then export + PREFILL_CUDA_DEVICES / DECODE_CUDA_DEVICES (no fixed card IDs; no complex shell automation). + UCX_NET_DEVICES and TLS for RDMA per cluster. lm_eval hits pd_master URL. HOST vs + PD_MASTER_IP when co-located. Before lm_eval, must POST one completion via curl to + pd_master for warmup verification. Requires LOG_DIR, MODEL_DIR, proxy cleared, no_proxy, + summary.txt. Same-GPU model_infer + pd_*_trans need NVIDIA MPS for best KV copy perf; + record MPS on/off in summary. Run check_nvidia_peermem.sh in this skill dir; record in summary.txt. + Use for PD separation tests with either the default NIXL transport or NCCL transport. 
+--- + +# Qwen3-8B **PD 分离**(`pd_master` + `prefill` + `decode`)本地 GSM8K 评测 + +**测试标识**:同一 **`--model_dir`**(Qwen3-8B)下拆 **三条** `api_server` 进程——**调度/入口(`pd_master`)**、**`prefill` 节点**、**`decode` 节点**;评测 **`lm_eval`** 只访问 **`pd_master` 的 HTTP 端口(8089)**。默认使用 NIXL 传输;需要验证 NCCL 数据面时,设置 **`LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl`**,上层仍保持相同的 `prefill` / `decode` 管理路径。 + +**端口约定**:**`pd_master`:`8089`**;**prefill:`8001`**;**decode:`8002`**。启动与就绪探测须覆盖这三处(以及日志中的 PD 注册/报错信息)。 + +**绑定 IP(`HOST` / `PD_MASTER_IP`)**:各进程的 **`--host`** 表示 **本服务监听绑定的 IP**。当 **`pd_master`、`prefill`、`decode` 部署在同一台机器上时**,三者使用的绑定 IP **相同**:可 **`export HOST="${PD_MASTER_IP}"`**;**`lm_eval` 的 `base_url` 仍指向 `pd_master`**。 + +整轮产物落在**同一日志目录**,写入 **`summary.txt`** 与各进程日志;**不要**写聚合启动脚本,按「启动说明」逐条手动启动并在后台落盘。 + +## 日志目录(含 `summary.txt`) + +- 每次评测先选定或新建**一个日志目录**(例如带时间戳或任务名),与其它测试轮次分开。 +- **三个 `api_server` 的标准输出/错误**分别写入该目录,建议命名:**`pd_master.log`**、**`prefill.log`**、**`decode.log`**(文件名可沿用习惯,与 NCCL 测试一致便于对比)。 +- **`summary.txt` 固定放在该日志目录下**,汇总:三台进程的启动参数摘要、端口与就绪情况、**UCX 配置要点**、**`check_nvidia_peermem.sh` 输出**、**MPS 是否开启**、**KV 传输指标**、`lm_eval` 关键结果、失败原因与最终结论。 +- **`eval_gsm8k.log`**:`lm_eval` 终端输出;**`curl_warmup.log`**:测试前 **`curl`** 打 **`pd_master`** 的留档(建议);**`summary.txt`** 仍为**总览结论**。 + +## 启动说明 + +本节包含:启动前检查 → 可变项说明 → 显卡分配 → UCX → **按顺序**三条完整 server 命令 → **curl warmup** → 评测命令。 + +### 启动前检查 + +1. **显卡**:prefill / decode 各需 **2 张物理 GPU**(**`--tp 2`**),共 **4 张互不重复**。**不要写死卡号**:先 **`nvidia-smi`**(见「显卡分配」),再 **`export PREFILL_CUDA_DEVICES`**、**`DECODE_CUDA_DEVICES`**。 +2. **端口**:**`8089`、`8001`、`8002`** 均须未被占用。 +3. **网络 / IP**:**`HOST`** 与 **`PD_MASTER_IP`** 约定同 NCCL PD skill;单机三进程 **`export HOST="${PD_MASTER_IP}"`**。 +4. **代理**:启动 **任一 server 前**将 **`http_proxy` / `https_proxy` 置空**;评测使用 **`no_proxy`**(见评测命令)。 +5. **RDMA / UCX**:prefill 与 decode 进程在启动 Python 前须设置 **`UCX_NET_DEVICES`**(及可选 **`UCX_LOG_LEVEL`**、**`UCX_TLS`**),取值依赖本机 **`ibv_devinfo`** 与机房拓扑(见「UCX / RDMA」);**不要**默认照抄他机上的设备名或排除列表。 +6. 
**`nvidia_peermem`**:`bash skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh >> "${LOG_DIR}/summary.txt"`;失败按脚本提示 `modprobe` 后重启服务(跨机各节点都要做)。 +7. **CUDA MPS(强烈建议,见下节)**:**要达到 PD KV 拷贝与 batch 评测最佳性能,须在启动 `api_server` 之前在本机启用 NVIDIA MPS**。未开 MPS 时功能通常仍可用,但易出现 **`read_page_gpu_time` 数十秒级毛刺**、**`lm_eval` 单 batch 近百秒**;**`summary.txt` 须写明 MPS 是否已开启及验证方式**。 + +### 启动服务的命令模板(可变项) + +| 可变项 | 含义 | +|--------|------| +| `LOG_DIR` | 本轮日志根目录;`export LOG_DIR=…`。 | +| `MODEL_DIR` | **`--model_dir`**;`lm_eval` 的 **`tokenizer` 须与此路径一致**。 | +| `PD_MASTER_IP` | **`pd_master` 的 `--host`**;**`lm_eval` 的 `base_url` 主机**。 | +| `HOST` | **`prefill` / `decode` 的 `--host`**。同机时 **`export HOST="${PD_MASTER_IP}"`**。 | +| `PREFILL_CUDA_DEVICES` | prefill 的 **`CUDA_VISIBLE_DEVICES`**(两张物理索引);**`nvidia-smi` 后 export**。 | +| `DECODE_CUDA_DEVICES` | decode 的 **`CUDA_VISIBLE_DEVICES`**;与 prefill **四卡互不重复**。 | +| `UCX_NET_DEVICES` | UCX 使用的 HCA 列表,形如 `mlx5_0:1,mlx5_1:1`;**按本机 `ibv_devinfo` 与规划填写**。 | +| `UCX_LOG_LEVEL` / `UCX_TLS` | 常见为 **`info`** 与 **`rc,cuda,gdr_copy`**;可按环境调整。 | + +开跑前导出示例: + +```bash +export LOG_DIR='〈日志根目录〉' +export MODEL_DIR='〈Qwen3-8B 模型目录〉' +export PD_MASTER_IP='〈本机绑定 IP〉' +export HOST="${PD_MASTER_IP}" +export UCX_NET_DEVICES='〈按 ibv_devinfo 填写,逗号分隔 port :1〉' +export UCX_LOG_LEVEL=info +export UCX_TLS=rc,cuda,gdr_copy +``` + +### 显卡分配(`nvidia-smi` + 人工/Agent 决策,不用复杂脚本) + +**prefill**、**decode** 各 **2** 张 GPU,共 **4** 张互不重复。需要验证 NCCL 数据面时,额外设置 **`LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl`**。 + +1. 执行 **`nvidia-smi`**(可选用 `--query-gpu=index,name,memory.used,memory.free --format=csv`)。 +2. 由执行者选定哪 2 张给 prefill、哪 2 张给 decode(不重叠)。 +3. **`export PREFILL_CUDA_DEVICES='…','…'`**、**`export DECODE_CUDA_DEVICES='…','…'`**。 +4. 
将选卡依据记入 **`summary.txt`**。 + +**禁止**:为选卡编写 **awk / mapfile / 长段 bash** 自动化;以 **`nvidia-smi` 事实 + 明确决策**为准。 + +### UCX / RDMA(默认 NIXL 传输) + +- **`UCX_NET_DEVICES`**:须覆盖本进程要用的 **RDMA 设备**;是否排除某些 HCA(例如数据面网卡)由**本机拓扑**决定,在 **`summary.txt`** 中写明依据。 +- **`UCX_TLS`**:常见 **`rc,cuda,gdr_copy`**;若环境不支持再按报错调整。 +- **IB 传 GPU KV** 需加载内核模块 **`nvidia_peermem`**(检测:**`skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh`**)。 + +#### 要达到最优性能:须开启 MPS + +如果用户没有特别说明要开启 mps,测试的时候可以不开启。 + +1. **在启动任意 `api_server` 之前**,按机房规范启动 MPS(示例,**以本集群文档为准**): + +```bash +# 确认无其它任务占用目标 GPU 后再执行;具体参数问运维 +export CUDA_VISIBLE_DEVICES="${PREFILL_CUDA_DEVICES},${DECODE_CUDA_DEVICES}" # 或整机 MPS,按规范 +nvidia-cuda-mps-control -d +# 验证:nvidia-smi 应出现 nvidia-cuda-mps-server,且各 GPU 有少量固定占用 +``` + +2. **验证 MPS 已生效**(写入 **`summary.txt`**): + +```bash +nvidia-smi --query-compute-apps=pid,process_name --format=csv | grep -i mps || true +pgrep -a mps-control || pgrep -a cuda-mps +``` + +### 1)启动 `pd_master`(须最先就绪监听) + +```bash +export http_proxy= +export https_proxy= + +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --run_mode pd_master \ + --host "${PD_MASTER_IP}" \ + --port 8089 \ + >> "${LOG_DIR}/pd_master.log" 2>&1 & +``` + +### 2)启动 `prefill` 节点 + +**须在 `pd_master` 就绪后**再启动。启动前已完成 **`nvidia-smi` 决策**并 **`export PREFILL_CUDA_DEVICES`**,且已设置 **UCX**。 + +```bash +export http_proxy= +export https_proxy= + +LOADWORKER=18 CUDA_VISIBLE_DEVICES="${PREFILL_CUDA_DEVICES}" \ +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --run_mode prefill \ + --tp 2 \ + --dp 1 \ + --host "${HOST}" \ + --port 8001 \ + --disable_cudagraph \ + --pd_master_ip "${PD_MASTER_IP}" \ + --pd_master_port 8089 \ + >> "${LOG_DIR}/prefill.log" 2>&1 & +``` + +(若需显式传入 UCX,可在同一 shell 中于本块之前 **`export UCX_NET_DEVICES`** 等;**`nohup` 会继承当前 shell 的环境变量**。) + +### 3)启动 `decode` 节点 + +启动前 **`export DECODE_CUDA_DEVICES`**,并确保 **UCX** 已设置。 + +```bash +export http_proxy= +export https_proxy= 
+export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${PD_MASTER_IP} + +LOADWORKER=18 CUDA_VISIBLE_DEVICES="${DECODE_CUDA_DEVICES}" \ +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --run_mode decode \ + --tp 2 \ + --dp 1 \ + --host "${HOST}" \ + --port 8002 \ + --pd_master_ip "${PD_MASTER_IP}" \ + --pd_master_port 8089 \ + >> "${LOG_DIR}/decode.log" 2>&1 & +``` + +### 测试前 curl warmup(**须执行**,再走 `lm_eval`) + +PD 链路在首次真实推理前易出现冷启动与传输路径问题。**在跑 `lm_eval` 正式评测之前**,必须先对 **`pd_master`** 的 **`/v1/completions`** 发 **至少一次** HTTP 请求,确认返回 **2xx** 且响应体含正常 completion(再走长评测)。 + +1. **时机**:**`prefill` 与 `decode` 均已启动**,且日志显示已与 **`pd_master`** 建立 PD 链路后再执行(可与端口 listen、日志轮询结合判断)。 +2. **代理**:执行 **`curl` 前**同样 **`export http_proxy=` / `export https_proxy=`**;若评测机对 **`PD_MASTER_IP`** 走代理会失败,可对本次 shell 设置 **`no_proxy`**(与下文 `lm_eval` 一致,须包含 **`${PD_MASTER_IP}`**)。 +3. **记录**:将 **`curl` 使用的命令、HTTP 状态码、若失败则错误摘要** 写入 **`summary.txt`**;成功后再启动 **`lm_eval`**。 + +示例(**`model` 与 `lm_eval` 中 `model` 字段保持一致**,一般为 **`qwen/qwen3-8b`**;可按需改 **`prompt` / `max_tokens`**): + +```bash +export http_proxy= +export https_proxy= +export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${PD_MASTER_IP} + +curl -sS -w "\nhttp_code:%{http_code}\n" -X POST "http://${PD_MASTER_IP}:8089/v1/completions" \ + -H "Content-Type: application/json" \ + -d "{\"model\":\"qwen/qwen3-8b\",\"prompt\":\"warmup\",\"max_tokens\":32}" \ + | tee "${LOG_DIR}/curl_warmup.log" +``` + +- 期望 **`http_code:200`**(或环境约定的成功码);非 2xx 时先查 **`pd_master.log` / `prefill.log` / `decode.log`**,**不要**直接开大批量 `lm_eval`。 +- 可将 **`curl` 输出**保留为 **`curl_warmup.log`**(如上),便于与 **`eval_gsm8k.log`** 对照。 + +### 评测命令(**curl warmup 成功后**执行) + +**`base_url` 指向 `pd_master`**;**`tokenizer` 与 `MODEL_DIR` 一致**: + +```bash +export http_proxy= +export https_proxy= + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 \ +no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${PD_MASTER_IP} \ +lm_eval --model local-completions \ + --model_args 
"{\"model\":\"qwen/qwen3-8b\", \"base_url\":\"http://${PD_MASTER_IP}:8089/v1/completions\", \"max_length\": 16384, \"tokenized_requests\": false, \"tokenizer\":\"${MODEL_DIR}\"}" \ + --tasks gsm8k --batch_size 64 --confirm_run_unsafe_code \ + >> "${LOG_DIR}/eval_gsm8k.log" 2>&1 +``` + +- 若需 **`lm_eval` 侧**再跑一次小样本,可加 **`--limit 1`**;**不能替代**上文 **`curl` warmup**。 +- 若需先用代理下载 `lm_eval` 缓存,见「执行约定」。 + +## 执行约定 + +**模型目录**:首轮可 **`export MODEL_DIR=/mtc/models/qwen3-8b`**;路径报错时由用户提供本机 **`MODEL_DIR`**。 + +1. **启动顺序**:**`bash skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh >> "${LOG_DIR}/summary.txt"`** → **`pd_master`** → **`nvidia-smi` + export 四卡** → **设置 UCX** → **`prefill`** → **`decode`** → **`curl` warmup(须成功)** → **`lm_eval`**。 +2. **不要用 health 接口**作为唯一依据;结合 **端口 listen** 与 **`pd_master.log` / `prefill.log` / `decode.log`**。 +3. **约每 20 秒**查看日志直至就绪或报错;异常写入 **`summary.txt`**。 +4. **`summary.txt`**:记录启动摘要、**`PREFILL_CUDA_DEVICES` / `DECODE_CUDA_DEVICES`** 与选卡依据、**`UCX_NET_DEVICES` 等**、**`curl` warmup 结果(或 `curl_warmup.log` 路径)**、评测关键输出、最终结论。 +5. **结束后**关闭 **`pd_master`、`prefill`、`decode`** 相关进程。 +6. 当用户说明是压测的时候,将lmeval 的 --batch_size 修改为 500 +7. 发现 connetion to pd_master has error 错误的时候,可以先容忍一会,这种网络状态错误有时是可以自行恢复的。 + +## 输出文件 + +- **`summary.txt`**、**`pd_master.log`、`prefill.log`、`decode.log`**、**`curl_warmup.log`(建议)**、**`eval_gsm8k.log`** 均落在**同一 `LOG_DIR`**。 diff --git a/skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh b/skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh new file mode 100755 index 0000000000..21bc2f35e6 --- /dev/null +++ b/skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# Check nvidia_peermem (GPUDirect RDMA) for NIXL PD / UCX over IB. +# Usage: bash skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh [LOG_DIR] +# LOG_DIR optional: scan prefill.log / decode.log for UCX GPUDirect lines. 
+set -euo pipefail + +LOG_DIR="${1:-}" +FAIL=0 + +echo "=== nvidia_peermem ===" + +if lsmod 2>/dev/null | awk '{print $1}' | grep -qx nvidia_peermem; then + ver="$(cat /sys/module/nvidia_peermem/version 2>/dev/null || echo '?')" + echo "OK: module loaded (version ${ver})" +else + echo "FAIL: nvidia_peermem not loaded" + FAIL=1 +fi + +if [[ -n "$LOG_DIR" ]]; then + for f in prefill.log decode.log; do + [[ -f "${LOG_DIR}/${f}" ]] || continue + if grep -q 'GPUDirect RDMA is not detected' "${LOG_DIR}/${f}" 2>/dev/null; then + echo "FAIL: ${f} -> GPUDirect RDMA is not detected (restart services after modprobe)" + FAIL=1 + elif grep -q 'GPUDirect RDMA is detected' "${LOG_DIR}/${f}" 2>/dev/null; then + echo "OK: ${f} -> GPUDirect RDMA is detected" + fi + done +fi + +if [[ "$FAIL" -ne 0 ]]; then + cat <<'EOF' + +Enable GPUDirect RDMA: + sudo modprobe nvidia_peermem + lsmod | grep nvidia_peermem + # cross-node: run on every host; then restart prefill / decode +EOF + exit 1 +fi + +exit 0 diff --git a/skills/test_model/qwen3-vl-8b-mmmu-val/SKILL.md b/skills/test_model/qwen3-vl-8b-mmmu-val/SKILL.md new file mode 100644 index 0000000000..5db4754500 --- /dev/null +++ b/skills/test_model/qwen3-vl-8b-mmmu-val/SKILL.md @@ -0,0 +1,114 @@ +--- +name: test-model-qwen3-vl-8b-mmmu-val +description: >- + LightLLM Qwen3-VL-8B-Instruct: api_server tp 2 on port 8089, then lmms-eval CLI + (python -m lmms_eval, model openai_compatible, tasks mmmu_val, batch_size 900) + with OPENAI_API_BASE pointing at LightLLM OpenAI-compatible /v1. Restore https_proxy for Hub + while no_proxy includes 127.0.0.1. Requires lmms-eval install, OPENAI_API_KEY placeholder, + LOG_DIR and MODEL_DIR, nvidia-smi GPU choice, pipefail with tee, summary.txt. No wrapper + script; use command line only. 
+--- + +# Qwen3-VL-8B-Instruct **MMMU 验证集(`mmmu_val`)** 评测 + +**测试标识**:先在本机启动 **`lightllm.server.api_server`**(**Qwen3-VL-8B-Instruct**,**`--tp 2`**,HTTP **`8089`**);服务就绪后,在已安装 **`lmms-eval`** 的环境中直接执行 **`python3 -m lmms_eval`**(**`openai_compatible`**,任务 **`mmmu_val`**),通过环境变量 **`OPENAI_API_BASE`** 指向 **`api_server`** 的 OpenAI 兼容前缀(**含 `/v1`**)。 + +**依赖(评测侧)**:**`lmms-eval`**(版本示例): + +```text +git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git +pip install -e lmms-eval/ +``` + +执行 **`python3 -m lmms_eval`** 的 Python 环境须已安装上述包;**不要求**在 LightLLM 仓库根目录下执行(除非你的数据或配置依赖 **`cwd`**)。 + +## 日志目录(含 `summary.txt`) + +- 选定 **`LOG_DIR`**(绝对路径建议带时间戳)。 +- **`api_server`** → **`"${LOG_DIR}/server.log"`**(推荐 **`nohup`** 后台)。 +- **`lmms_eval`** 的 **`--output_path`**:**建议 `"${LOG_DIR}/lmms_eval_out"`**;控制台输出可 **`tee`** 到 **`"${LOG_DIR}/lmms_eval_console.log"`**。 +- **`summary.txt`**:模型路径、**`OPENAI_API_BASE`**、完整 **`lmms_eval` 命令**、端口检测结果、输出目录路径、失败原因。 + +## 启动前检查 + +1. **显卡**:**`--tp 2`** → **2 张物理 GPU**;先 **`nvidia-smi`**,再 **`export CUDA_VISIBLE_DEVICES`**(**不要写死**)。 +2. **端口**:**`8089`** 未被占用。 +3. **`MODEL_DIR`**:**`api_server --model_dir`** 与 **`--model_args` 里的 `model_version=`** 须为**同一 Qwen3-VL-8B-Instruct 权重路径**(默认示例 **`/mtc/models/Qwen3-VL-8B-Instruct`**;不存在时向用户询问本机路径)。 +4. **`lmms-eval` 已安装**且 **`python3 -m lmms_eval`** 可用。 +5. 
**代理**:启动 **`api_server` 前**清空 **`http_proxy` / `https_proxy`**;跑 **`lmms_eval` 前**将 **`no_proxy`** 设为包含本机 **`127.0.0.1`**(见下文评测块);**若需从 Hugging Face Hub 拉取 `lmms-lab/MMMU`,评测阶段应恢复可用的 `https_proxy`(或等价镜像)**,否则清空代理后可能出现 **`ConnectionError: Couldn't reach 'lmms-lab/MMMU' on the Hub`**。 + +## 可变项 + +| 变量 | 含义 | +|------|------| +| `LOG_DIR` | 本轮日志与 **`lmms_eval --output_path`** 父目录。 | +| `MODEL_DIR` | **`api_server --model_dir`**;**`--model_args` 中 `model_version=`** 与之相同。 | +| `PORT` | 默认 **`8089`**。 | +| `BIND_URL_HOST` | 与 **`OPENAI_API_BASE`** 主机一致;本机常用 **`127.0.0.1`**。 | +| `OPENAI_API_BASE` | 形如 **`http://${BIND_URL_HOST}:${PORT}/v1`**(**末尾含 `/v1`**)。 | +| `OPENAI_API_KEY` | 占位即可(常用 **`lightllm123`**);若服务端校验密钥,与用户环境对齐。 | +| `CUDA_VISIBLE_DEVICES` | 两张卡。 | + +## 启动 `api_server` + +**不要用 health 作为唯一依据**;以 **端口 listen** + **`server.log`** 为准;可约 **每 20 秒**查看日志。 + +```bash +export http_proxy= +export https_proxy= + +export LOG_DIR='〈日志目录〉' +export MODEL_DIR='/mtc/models/Qwen3-VL-8B-Instruct' +export PORT=8089 + +LOADWORKER=18 CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES}" \ +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --port "${PORT}" \ + >> "${LOG_DIR}/server.log" 2>&1 & +``` + +## 运行 `lmms_eval`(服务就绪后,仅命令行) + +设置 **`OPENAI_API_*`** 与代理后,直接 **`python3 -m lmms_eval`**(**`timeout` 可选**,例如单次上限 **3600 秒**): + +```bash +# 若启动 api_server 时曾清空代理,请先保存并在评测前恢复 Hub 代理,例如: +# export ORIG_HTTPS_PROXY="${https_proxy-}" +# export http_proxy=; export https_proxy= +# … 启动 api_server … +# export https_proxy="${ORIG_HTTPS_PROXY}" + +export BIND_URL_HOST='127.0.0.1' +export PORT=8089 +export OPENAI_API_BASE="http://${BIND_URL_HOST}:${PORT}/v1" +export OPENAI_API_KEY="${OPENAI_API_KEY:-lightllm123}" +export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${BIND_URL_HOST} + +export LOG_DIR='〈与上文同一日志目录〉' +export MODEL_DIR='/mtc/models/Qwen3-VL-8B-Instruct' + +mkdir -p "${LOG_DIR}/lmms_eval_out" +set -o pipefail  # 使 tee 管道返回 lmms_eval / timeout 的退出码,而不是 tee 的 0 +timeout 3600 python3 -m lmms_eval \ + --model 
openai_compatible \ + --model_args "model_version=${MODEL_DIR},tp=1" \ + --tasks mmmu_val \ + --batch_size 900 \ + --log_samples \ + --log_samples_suffix openai_compatible \ + --output_path "${LOG_DIR}/lmms_eval_out" \ + 2>&1 | tee "${LOG_DIR}/lmms_eval_console.log" +``` + +说明:**`model_args` 中的 `tp=1`** 为 **`lmms_eval` / `openai_compatible` 侧参数**,与 **`api_server` 的 `--tp 2`** 不同;**不要**混用含义。 + +若环境无 **`timeout`** 命令,可去掉 **`timeout 3600`**。 + +## 执行约定 + +1. **顺序**:**`api_server` 就绪** → 再 **`lmms_eval`**。 +2. **`model_version` 与 `MODEL_DIR` 必须一致**。 +3. 超时或失败将摘要写入 **`summary.txt`**。 +4. 结束后关闭 **`api_server`**,释放 GPU 与端口。 diff --git a/skills/test_model/qwen3-vl-8b-vit-sep-mode/SKILL.md b/skills/test_model/qwen3-vl-8b-vit-sep-mode/SKILL.md new file mode 100644 index 0000000000..2224c35ae3 --- /dev/null +++ b/skills/test_model/qwen3-vl-8b-vit-sep-mode/SKILL.md @@ -0,0 +1,180 @@ +--- +name: test-model-qwen3-vl-8b-vit-sep-mode +description: >- + LightLLM Qwen3-VL-8B-Instruct visual separation (ViT sep / proxy): three processes in + order—config_server on 8090; internal Redis on 6000; visual_only with visual_rpyc 8091 + and afs_image_embed_dir; normal api_server tp 2 port 8089 with visual_use_proxy_mode. + After HTTP /v1/models on normal, lmms_eval mmmu_val (openai_compatible, batch 900, + OPENAI_API_BASE http://HOST:8089/v1); restore https_proxy for Hub while no_proxy includes 127.0.0.1. + lmms_eval_out, console log, mmmu_acc in summary. pipefail for tee exit code. 
+--- + +# Qwen3-VL-8B-Instruct **视觉分离(`visual_only` + `normal` + `config_server`)** + +**测试标识**:按顺序启动 **三条** `api_server` 进程——**`config_server`**(配置与元数据;**进程内部会启动 Redis 服务,并通过 `--config_server_visual_redis_port`(默认 `6000`)对外暴露**)、**`visual_only`**(独立视觉 / ViT 侧)、**`normal`**(主 LLM,**`--visual_use_proxy_mode`** 经 config 访问视觉侧)。**`6000` 不是本机另行安装的 `redis-server`**,勿与系统包管理器里的 Redis 混为一谈。本流程验证 **ViT 分离 + AFS 嵌入目录** 的联调,并在 **`normal` 就绪后须强制跑通 MMMU 验证集 `mmmu_val`**(**`lmms_eval` + `openai_compatible`**,命令与 **`skills/test_model/qwen3-vl-8b-mmmu-val/SKILL.md`** 评测块一致;仅 **`api_server` 拓扑** 为本文的 **visual 分离三进程**)。 + +**端口约定(固定,与脚本一致)**: + +| 用途 | 端口 | +|------|------| +| **`config_server`** | **`8090`**(**`--config_server_port`**) | +| **Redis(由 `config_server` 内部启动并对外暴露)** | **`6000`**(**`--config_server_visual_redis_port`**;与系统独立安装的 Redis 无关) | +| **`visual_only` RPyC** | **`8091`**(**`--visual_rpyc_port`**) | +| **`normal` HTTP** | **`8089`**(**`--port`**) | + +**算力**:**`visual_only`** 默认 **1 张 GPU**;**`normal`** 默认 **`--tp 2`** → **2 张 GPU**;**三组进程不得争抢同一物理 GPU**(脚本示例为 **visual:`0`**,**LLM:`6,7`**;实际以 **`nvidia-smi`** 选定)。 + +## 依赖 + +- **Python 环境**:与运行 **`lightllm.server.api_server`** 的虚拟环境一致。 +- **`mmmu_val` 评测(必须)**:已安装 **`lmms-eval`**,**`python3 -m lmms_eval`** 可用(安装示例见 **`skills/test_model/qwen3-vl-8b-mmmu-val/SKILL.md`**);未完成 **`mmmu_val`** 则本轮 **不算通过**。 +- **`6000` 端口**:由 **`config_server` 在启动后内部拉起 Redis 并监听**;**无需**、也**不应**依赖「事先在本机 **`apt install redis-server`** 并独占 **`6000`**」——若系统已有其它服务占用 **`6000`**,须释放或改 **`--config_server_visual_redis_port`**(**`visual_only` / `normal` 须同步同一端口参数**)。 + +## 日志目录(含 `summary.txt`) + +- 选定 **`LOG_DIR`**,三条进程日志建议:**`"${LOG_DIR}/config_server.log"`**、**`"${LOG_DIR}/visual_only.log"`**、**`"${LOG_DIR}/normal.log"`**(**`nohup … >> … 2>&1 &`**)。 +- **`summary.txt`**:三条命令摘要、各端口 listen 情况、**`MODEL_DIR` / `AFS_IMAGE_EMBED_DIR`** 最终取值;**`mmmu_val` 必记**:**`OPENAI_API_BASE`**、完整 **`lmms_eval` 命令**、**`lmms_eval_console.log`** 与 
**`lmms_eval_out`** 路径、**`mmmu_acc`**(见下文「精度」);失败时写清原因与结论。 +- **`lmms_eval_console.log`**(**必须**):**`lmms_eval`** 终端输出(**`tee`**)。 +- **`lmms_eval_out/`**(**必须**):**`--output_path`** 下的 **`*_results.json`**、**`*_samples_mmmu_val.jsonl`**(**`--log_samples`** 生成样本日志)。 + +## 启动前检查 + +1. **端口**:本机 **`8090`、`6000`、`8091`、`8089`** 未被其它进程占用(**`8090` / `6000` 在 `config_server` 启动后由其占用**)。 +2. **`config_server` 已就绪**:**`8090`** 与 **`6000`** 均已 **LISTEN**(表明 **HTTP 配置面** 与 **内部 Redis 暴露面** 已起来),再启动 **`visual_only`**。 +3. **`MODEL_DIR`**:指向 **Qwen3-VL-8B-Instruct**;默认试跑 **`/mtc/models/Qwen3-VL-8B-Instruct`**;不存在或加载失败时 **向用户询问** 本机路径。 +4. **`AFS_IMAGE_EMBED_DIR`**:**`--afs_image_embed_dir`** 指向的目录须存在或可创建;默认试跑 **`/mtc/afs/vit_embed_dir`**;不可用时 **向用户询问**。 +5. **显卡**:**`visual_only`** 与 **`normal`** 的 **`CUDA_VISIBLE_DEVICES`** **不得重叠**;先 **`nvidia-smi`** 再 **`export`**(**不要写死**示例卡号)。 +6. **`CONFIG_SERVER_HOST`**:各进程 **`--config_server_host`** 须能访问 **`config_server` 的 HTTP 服务**。**三进程同机**时常见为 **`0.0.0.0`**(与历史脚本一致);多机时改为 **对端可达的 IP**,并保证防火墙放行 **`config_server` 端口(8090)**、**其内部 Redis 暴露端口(6000)**、**`8091`** 等。 +7. **代理**:每条 **`python`(LightLLM)** 前 **`export http_proxy=`**、**`export https_proxy=`**(与仓库其它 acc 测试一致)。**`lmms_eval` 拉取 `lmms-lab/MMMU` 时**若需经企业代理访问 Hugging Face,应在 **`no_proxy` 已包含 `127.0.0.1`** 的前提下**恢复 `https_proxy`(或等价镜像)**;否则易出现 Hub **`ConnectionError`**。仅当本机已有完整离线缓存且确认 **`datasets` 可纯离线命中**时,才可全程无代理。 +8. 
**`lmms-eval`**:**`python3 -m lmms_eval --help`** 可执行;否则无法完成必跑 **`mmmu_val`**。 + +## 可变项 + +| 变量 | 含义 | +|------|------| +| `LOG_DIR` | 本轮日志根目录。 | +| `MODEL_DIR` | **`--model_dir`**(三条中涉及模型的命令一致)。 | +| `AFS_IMAGE_EMBED_DIR` | **`--afs_image_embed_dir`**;**`visual_only` 与 `normal` 须一致**。 | +| `CONFIG_SERVER_HOST` | **`--config_server_host`**;同机试跑常用 **`0.0.0.0`**。 | +| `VISUAL_CUDA_DEVICES` | **`visual_only`** 的 **`CUDA_VISIBLE_DEVICES`**(1 张卡)。 | +| `LLM_CUDA_DEVICES` | **`normal`** 的 **`CUDA_VISIBLE_DEVICES`**(**`--tp 2`** → 2 张卡)。 | +| `AFS_EMBED_CAPACITY` | **`--afs_embed_capacity`**;默认 **`250000`**;调试替换逻辑时可改为较小值(例如 **`100`**),见「调试提示」。 | +| `ORIG_HTTP_PROXY` / `ORIG_HTTPS_PROXY` | 在清空代理启动 LightLLM **之前**备份(见 **`lmms_eval` 命令块**),评测阶段恢复以便访问 Hugging Face Hub。 | + +**开跑前导出示例**: + +```bash +export ORIG_HTTP_PROXY="${http_proxy-}" +export ORIG_HTTPS_PROXY="${https_proxy-}" +export LOG_DIR='〈日志目录〉' +export MODEL_DIR='/mtc/models/Qwen3-VL-8B-Instruct' +export AFS_IMAGE_EMBED_DIR='/mtc/afs/vit_embed_dir' +export CONFIG_SERVER_HOST='0.0.0.0' +export AFS_EMBED_CAPACITY=250000 +# export VISUAL_CUDA_DEVICES='0' +# export LLM_CUDA_DEVICES='6,7' +``` + +## 服务就绪判定 + +**不要使用 HTTP health 作为唯一依据**。依次确认:**`config_server` 已占用 `8090` 与 `6000`(内部 Redis)** → **`8091`(`visual_only` RPyC)** → **`8089`(`normal`)** 的 **LISTEN** 状态,并结合各 **`*.log`**;可约 **每 20 秒** 查看日志直至就绪或报错。 + +## 启动命令(须按顺序) + +以下块前均须 **`export http_proxy=`**、**`export https_proxy=`**;生产式跑法请自行加 **`nohup`** 与 **`>> "${LOG_DIR}/….log" 2>&1 &`**。 + +### 1)`config_server`(最先) + +```bash +python -m lightllm.server.api_server \ + --run_mode config_server \ + --config_server_host 0.0.0.0 \ + --config_server_port 8090 \ + --config_server_visual_redis_port 6000 +``` + +若仅需绑定到 **`CONFIG_SERVER_HOST`**,将 **`--config_server_host`** 改为 **`"${CONFIG_SERVER_HOST}"`**(须与 **`visual_only` / `normal` 中的 `--config_server_host`** 指向同一可达地址)。 + +### 2)`visual_only`(**`config_server` 已在 8090 / 6000 就绪后**) + +**`--visual_rpyc_port 8091`** 为 
**visual_only 模式必需**,供其它服务调用本机视觉推理接口。 + +```bash +CUDA_VISIBLE_DEVICES="${VISUAL_CUDA_DEVICES}" python -m lightllm.server.api_server \ + --run_mode visual_only \ + --host 0.0.0.0 \ + --config_server_host "${CONFIG_SERVER_HOST}" \ + --config_server_port 8090 \ + --config_server_visual_redis_port 6000 \ + --model_dir "${MODEL_DIR}" \ + --visual_dp 1 \ + --visual_tp 1 \ + --afs_image_embed_dir "${AFS_IMAGE_EMBED_DIR}" \ + --afs_embed_capacity "${AFS_EMBED_CAPACITY}" \ + --visual_rpyc_port 8091 +``` + +**`--host`** 为 **本进程监听地址**;与 **`--config_server_host`** 含义不同:后者为 **config_server 的可达地址**。 + +### 3)`normal`(visual 就绪后) + +```bash +CUDA_VISIBLE_DEVICES="${LLM_CUDA_DEVICES}" python -m lightllm.server.api_server \ + --run_mode normal \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --port 8089 \ + --config_server_host "${CONFIG_SERVER_HOST}" \ + --config_server_port 8090 \ + --config_server_visual_redis_port 6000 \ + --visual_dp 1 \ + --afs_image_embed_dir "${AFS_IMAGE_EMBED_DIR}" \ + --afs_embed_capacity "${AFS_EMBED_CAPACITY}" \ + --visual_use_proxy_mode +``` + +## **`mmmu_val` 评测(`normal` 已监听 `8089` 后,必须执行)** + +在 **`config_server` / `visual_only` / `normal` 均就绪**、**`normal`** 仍占用 **`8089`** 时执行;**关停任一服务前须跑完本节**。评测流量只打 **`normal` HTTP**;**`OPENAI_API_BASE`** 须指向 **`http://〈可达主机〉:8089/v1`**(**末尾含 `/v1`**)。**`--model_args` 中 `model_version=` 与三进程共用的 `MODEL_DIR` 须为同一权重目录**。 + +```bash +export BIND_URL_HOST='127.0.0.1' +export PORT=8089 +export OPENAI_API_BASE="http://${BIND_URL_HOST}:${PORT}/v1" +export OPENAI_API_KEY="${OPENAI_API_KEY:-lightllm123}" +export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${BIND_URL_HOST} + +# 若启动 LightLLM 时清空了代理,此处恢复 Hub 代理(勿把 127.0.0.1 放进 ALL_PROXY) +export http_proxy="${ORIG_HTTP_PROXY:-}" +export https_proxy="${ORIG_HTTPS_PROXY:-}" + +export LOG_DIR='〈与上文同一日志目录〉' +export MODEL_DIR='/mtc/models/Qwen3-VL-8B-Instruct' + +mkdir -p "${LOG_DIR}/lmms_eval_out" +set -o pipefail  # 使 tee 管道返回 lmms_eval / timeout 的退出码,便于写入 summary.txt +timeout 3600 python3 -m lmms_eval \ + --model openai_compatible \ + --model_args 
"model_version=${MODEL_DIR},tp=1" \ + --tasks mmmu_val \ + --batch_size 900 \ + --log_samples \ + --log_samples_suffix openai_compatible \ + --output_path "${LOG_DIR}/lmms_eval_out" \ + 2>&1 | tee "${LOG_DIR}/lmms_eval_console.log" +``` + +说明:**`model_args` 中的 `tp=1`** 为 **`lmms_eval` / `openai_compatible` 侧参数**,与 **`normal` 的 `--tp 2`** 不同。若环境无 **`timeout`**,去掉 **`timeout 3600`**。仅本地冒烟可在命令中加 **`--limit`**;**正式回归不得省略全量 `mmmu_val`**。 + +**精度(必须写入 `summary.txt`)**:最新 **`"${LOG_DIR}/lmms_eval_out"/*_results.json`** 中 **`results.mmmu_val["mmmu_acc,none"]`**(**0~1**);无 **`jq`** 时打开 JSON 或对照 **`lmms_eval_console.log`** 末尾汇总表。 + +## 调试提示(可选) + +- 将 **`AFS_EMBED_CAPACITY`** 设为较小值(例如 **`100`**)可更快触发 **嵌入目录替换 / 淘汰** 相关逻辑,便于缩短调试周期;正式回归再恢复 **`250000`**(或业务约定值)。 + +## 执行约定 + +1. **顺序**:**`config_server` → `visual_only` → `normal` → `mmmu_val`(`lmms_eval`)**,前三步不可颠倒;**`mmmu_val` 未完成不得视为本轮通过**。 +2. **`mmmu_val`**:须在 **`normal` 就绪** 之后、**关停 `normal` 之前** 跑完(依赖 **`8089`** 与 **`OPENAI_API_BASE`**)。**`model_version` 与 `MODEL_DIR` 必须一致**。 +3. **关停**:**`mmmu_val` 成功或失败均已落盘**(日志与 **`summary.txt`**)后,依次结束 **`normal`、`visual_only`、`config_server`**;**`config_server` 退出后,其内部在 `6000` 上暴露的 Redis 随之停止**,释放端口与 GPU。 +4. **失败**:将摘要写入 **`summary.txt`**(含 **`lmms_eval` 退出码**、**`lmms_eval_console.log`** 末尾),并在对话中给出关键日志与端口状态。 diff --git a/skills/test_model/qwen3.5-0.8b-gsm8k-scenarios/SKILL.md b/skills/test_model/qwen3.5-0.8b-gsm8k-scenarios/SKILL.md new file mode 100644 index 0000000000..15f79285d3 --- /dev/null +++ b/skills/test_model/qwen3.5-0.8b-gsm8k-scenarios/SKILL.md @@ -0,0 +1,199 @@ +--- +name: test-model-qwen3.5-0.8b-gsm8k-scenarios +description: >- + LightLLM Qwen3.5-0.8B GSM8K multi-scenario regression: five isolated runs (baseline + api_server, prefill cudagraph, linear-attention cache flags, CPU cache plus linear-att, + disk cache with LIGHTLLM_DISK_CACHE_PROMPT_LIMIT_LENGTH). Each scenario uses api_server + tp 2 port 8089, then lm_eval local-completions gsm8k batch 500. 
Scenarios 4 and 5 run + lm_eval twice for cache warm hit. Per-scenario LOG_DIR, server.log, eval logs, summary.txt. + GPUs from nvidia-smi; port listen readiness not health; clear proxies and set no_proxy. + Default MODEL_DIR HuggingFace hub snapshot path; default DISK_CACHE_DIR /mtc/test/tmp/ + for scenario 5; ask user for paths if missing or not writable. +--- + +# Qwen3.5-0.8B **多场景 GSM8K 回归** + +覆盖五种 **`api_server` 配置**:**基线**、**Prefill CUDA Graph**、**Linear-Attention 缓存参数**、**CPU Cache 与 Linear-Att 组合**、**Disk Cache(含环境变量)**。每种配置单独起服务并完成 **`lm_eval`**,互不混用日志。 + +**测试标识**:同一 **`MODEL_DIR`**(Qwen3.5-0.8B 权重)下,按场景顺序执行:**启动 `api_server` → 确认端口监听与服务日志正常 → 执行 `lm_eval`**。场景 **4、5** 在相同服务配置下 **`lm_eval` 连续执行两次**(缓存预热与命中后行为/耗时对照)。 + +**端口**:**`8089`**(默认,与 **`PORT`** 一致)。 + +**算力**:**`--tp 2`**,需要 **2 张物理 GPU**。 + +**评测**:**`lm_eval`**,**`--tasks gsm8k`**,**`--batch_size 500`**,**`model`** 为 **`qwen/Qwen3.5-0.8B`**。**`--model_dir` 与 `model_args` 中的 `tokenizer` 必须为同一目录路径(即 `MODEL_DIR`)**。 + +## 场景总览 + +| # | 名称 | `api_server` 相对上一场景的增量 | `lm_eval` | +|---|------|----------------------------------|-----------| +| 1 | 基线 | **`--model_dir` / `--tp 2` / `--port`** | 1 次 | +| 2 | Prefill CUDA Graph | **`--enable_prefill_cudagraph`** | 1 次 | +| 3 | Linear-Attention 参数 | **`--linear_att_cache_size 10`**、**`--linear_att_hash_page_size 256`**、**`--linear_att_page_block_num 2`**、**`--max_total_token_num 270000`** | 1 次 | +| 4 | CPU Cache + Linear-Att | 在场景 3 同类参数基础上增加 **`--enable_cpu_cache`**、**`--cpu_cache_storage_size 128`**(**`--max_total_token_num` 仍为 `270000`**) | **2 次** | +| 5 | Disk Cache | **`LIGHTLLM_DISK_CACHE_PROMPT_LIMIT_LENGTH=128`**;**`--linear_att_cache_size 128`** 等一组参数,及 **`--enable_cpu_cache`**、**`--enable_disk_cache`**、**`--disk_cache_dir`** 等(见下文命令块) | **2 次** | + +## 日志目录(含 `summary.txt`) + +- **每个场景**使用独立 **`LOG_DIR`**(建议绝对路径,并带场景编号或时间戳)。 +- **`api_server`**:标准输出与标准错误写入 **`"${LOG_DIR}/server.log"`**(推荐使用 **`nohup … >> … 2>&1 &`**)。 +- **`lm_eval`**:默认写入 
**`"${LOG_DIR}/eval_gsm8k.log"`**;同一服务下的第二次评测写入 **`"${LOG_DIR}/eval_gsm8k_run2.log"`**(场景 4、5)。 +- **`summary.txt`**:记录本场景完整启动参数、**`lm_eval` 命令摘要**、端口与日志就绪情况、两轮评测目的与结论(若适用)、异常与最终结论。 + +## 启动前检查 + +1. **显卡**:执行 **`nvidia-smi`**,按占用选定 2 张卡后 **`export CUDA_VISIBLE_DEVICES='i,j'`**(**不要写死卡号**)。 +2. **端口**:每轮启动前确认 **`8089`**(或当前 **`PORT`**)未被占用(例如 **`ss -tlnp`**、**`lsof -i :8089`**);上一轮结束后须 **终止对应 `api_server` 进程** 再启动下一轮。 +3. **`MODEL_DIR`**:见 **「路径约定」**;启动前执行 **`test -d "${MODEL_DIR}"`**;若路径无效或服务报路径类错误,按该节处理。 +4. **`DISK_CACHE_DIR`(仅场景 5)**:见 **「路径约定」**;须为可写目录;先 **`mkdir -p "${DISK_CACHE_DIR}"`**,仍不可写则按该节处理。 +5. **代理**:每次启动 **`api_server`** 或执行 **`lm_eval`** 之前,将 **`http_proxy` / `https_proxy` 置空**;执行 **`lm_eval`** 时配置 **`no_proxy`**(见评测命令块)。 + +## 路径约定(`MODEL_DIR` 与 `DISK_CACHE_DIR`) + +**原则**:下列 **默认路径** 与历史 **`test/acc/test_qwen3.5.sh`** 一致,可作为首轮试跑起点。若目录不存在、不可读、权重加载失败,或磁盘缓存路径不可写,**不得在未向用户确认的情况下反复更换路径盲试**;应 **向用户询问** 本机可用的 **`MODEL_DIR` / `DISK_CACHE_DIR` 绝对路径**,在收到答复后更新环境变量,并在 **`summary.txt`** 中记录最终采用的路径。 + +### `MODEL_DIR` + +- **`--model_dir`** 与 **`lm_eval` 的 `tokenizer` 字段** 必须为 **同一字符串**(本 skill 中均记为 **`MODEL_DIR`**)。 +- **默认路径(HuggingFace Hub 本地缓存示例)**:以下 **`snapshots/` 下目录名随实际下载版本可能变化**;若本机不存在该路径,则默认值不适用,须改用本机真实路径或询问用户。 + +```bash +export MODEL_DIR='/root/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17' +``` + +- **备选**:若权重部署在统一管理目录(例如 **`/mtc/models/Qwen3.5-0.8B`**),且 **`test -d`** 通过,可 **`export MODEL_DIR='/mtc/models/Qwen3.5-0.8B'`**。 + +### `DISK_CACHE_DIR`(仅场景 5:`--disk_cache_dir`) + +- **默认路径**: + +```bash +export DISK_CACHE_DIR='/mtc/test/tmp/' +``` + +- 场景 5 启动 **`api_server`** 前执行 **`mkdir -p "${DISK_CACHE_DIR}"`**。若父目录不可创建、无写权限或不符合运维规范,**向用户询问** 合适的可写目录后再 **`export`**。 + +## 可变项 + +| 变量 | 含义 | +|------|------| +| `LOG_DIR` | 当前场景的日志根目录(**`server.log` / `eval_gsm8k*.log` / `summary.txt`**)。 | +| `MODEL_DIR` | **`--model_dir`**;**`lm_eval` 的 `tokenizer`**。默认路径与失败时处理见 **「路径约定」**。 | +| `PORT` | HTTP 
端口,默认 **`8089`**。 | +| `BIND_URL_HOST` | **`lm_eval` 中 `base_url` 的主机名**;本机常用 **`127.0.0.1`** 或 **`localhost`**。 | +| `CUDA_VISIBLE_DEVICES` | 2 个物理 GPU 索引(逗号分隔)。 | +| `DISK_CACHE_DIR` | 场景 5 的 **`--disk_cache_dir`**。默认 **`/mtc/test/tmp/`**;不可写或不可用时见 **「路径约定」**。 | + +**开跑前导出示例**(按需修改引号内路径): + +```bash +export LOG_DIR='〈场景日志目录,每场景换新目录〉' +export MODEL_DIR='/root/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17' +export DISK_CACHE_DIR='/mtc/test/tmp/' +export PORT=8089 +export BIND_URL_HOST='127.0.0.1' +# export CUDA_VISIBLE_DEVICES='6,7' +``` + +## 服务就绪判定 + +**不要使用 HTTP health 接口作为唯一就绪依据**。应结合:**`PORT` 是否处于 LISTEN 状态**、**`server.log` 是否出现致命错误**;可约 **每 20 秒** 查看一次日志直至可接受评测或确认失败。 + +## `lm_eval` 命令模板(单次) + +服务就绪后执行;场景 **4、5** 在**同一 `api_server` 生命周期内**再执行一次,并将第二次输出重定向到 **`eval_gsm8k_run2.log`**。 + +```bash +export http_proxy= +export https_proxy= + +export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${BIND_URL_HOST} + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 \ +lm_eval --model local-completions \ + --model_args "{\"model\":\"qwen/Qwen3.5-0.8B\", \"base_url\":\"http://${BIND_URL_HOST}:${PORT}/v1/completions\", \"max_length\": 16384, \"tokenizer\":\"${MODEL_DIR}\"}" \ + --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ + >> "${LOG_DIR}/eval_gsm8k.log" 2>&1 +``` + +第二次(场景 4、5):将末尾重定向改为 **`>> "${LOG_DIR}/eval_gsm8k_run2.log" 2>&1`**,并在 **`summary.txt`** 中说明两次运行的目的(例如缓存预热与命中后对照)。 + +## 各场景 `api_server` 命令模板 + +以下命令块仅列出 **`api_server` 参数差异**。实际执行时须在 **`export http_proxy=`、`export https_proxy=`** 之后,按仓库其它 acc 测试惯例自行补全:**`LOADWORKER=18 CUDA_VISIBLE_DEVICES=…`**、**`nohup`**、以及 **`>> "${LOG_DIR}/server.log" 2>&1 &`**。 + +### 场景 1:基线 + +```bash +python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --port "${PORT}" +``` + +### 场景 2:Prefill CUDA Graph + +```bash +python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --port "${PORT}" \ + --enable_prefill_cudagraph 
+``` + +### 场景 3:Linear-Attention 参数 + +```bash +python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --port "${PORT}" \ + --linear_att_cache_size 10 \ + --linear_att_hash_page_size 256 \ + --linear_att_page_block_num 2 \ + --max_total_token_num 270000 +``` + +### 场景 4:CPU Cache + Linear-Attention(`lm_eval` 两次) + +```bash +python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --port "${PORT}" \ + --linear_att_cache_size 10 \ + --linear_att_hash_page_size 256 \ + --linear_att_page_block_num 2 \ + --max_total_token_num 270000 \ + --enable_cpu_cache \ + --cpu_cache_storage_size 128 +``` + +### 场景 5:Disk Cache(环境变量 + `lm_eval` 两次) + +在 **`python`** 进程前设置 **`LIGHTLLM_DISK_CACHE_PROMPT_LIMIT_LENGTH=128`**(与历史脚本一致,使子进程继承该变量): + +```bash +LIGHTLLM_DISK_CACHE_PROMPT_LIMIT_LENGTH=128 \ +python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --tp 2 \ + --port "${PORT}" \ + --linear_att_cache_size 128 \ + --linear_att_hash_page_size 256 \ + --linear_att_page_block_num 32 \ + --max_total_token_num 270000 \ + --enable_cpu_cache \ + --cpu_cache_storage_size 32 \ + --enable_disk_cache \ + --disk_cache_storage_size 512 \ + --disk_cache_dir "${DISK_CACHE_DIR}" +``` + +## 执行约定 + +1. **场景顺序**:按 **1 → 2 → 3 → 4 → 5** 执行;每步 **先停止上一场景的 `api_server`**,使用 **新的 `LOG_DIR`**。 +2. **`MODEL_DIR` 与 `DISK_CACHE_DIR`**:遵循 **「路径约定」**;**`summary.txt`** 记录最终采用的路径。 +3. **场景 4、5**:在同一服务配置下 **`lm_eval` 执行两次**,并保留两次日志文件以便对比。 +4. **收尾**:全部场景结束后,确保 **`api_server` 进程已结束**,释放 GPU 与端口。 +5. 
**失败处理**:将错误摘要写入 **`summary.txt`**,并在对话中给出关键日志信息。 \ No newline at end of file diff --git a/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md b/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md new file mode 100644 index 0000000000..dca4fa9a2a --- /dev/null +++ b/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md @@ -0,0 +1,313 @@ +--- +name: test-model-qwen3.5-0.8b-pd-nixl +description: >- + LightLLM Qwen3.5-0.8B PD disaggregation over NIXL gsm8k: pd_master on 8089, + prefill on 8001, decode on 8002. Supports TP1 and TP2 runs by setting + TP / PREFILL_CUDA_DEVICES / DECODE_CUDA_DEVICES. Qwen3.5 has linear-attention + state transfer; use --pd_kv_page_size 2048 and --pd_kv_page_num 16. + lm_eval hits pd_master URL. Requires UCX/RDMA env, nvidia_peermem + check, curl warmup before lm_eval, registration wait in pd_master.log, and + summary.txt. Includes optional repeated-prompt decode cache probe for linear-att + page-boundary behavior. +--- + +# Qwen3.5-0.8B **PD 分离(NIXL)** 本地 GSM8K 评测 + +**测试标识**:同一 **`MODEL_DIR`(Qwen3.5-0.8B)** 下拆三条 `api_server` 进程: +**`pd_master`**、**`prefill`**、**`decode`**。评测和 warmup 只访问 +**`pd_master` 的 HTTP 端口 `8089`**。 + +Qwen3.5 与 Qwen3-8B 的关键差异: + +| 项 | Qwen3.5-0.8B NIXL PD 要点 | +|---|---| +| linear-att 状态 | PD 传输除了 KV page,还会传 `linear_att_state` 特殊页 | +| NIXL page size | 建议固定 **`--pd_kv_page_size 2048`**;`1024` 可能不足以容纳 linear-att 状态 | +| page num | 建议 **`--pd_kv_page_num 16`** 起步,避免 page 池过大导致显存压力 | +| cache 判断 | repeated prompt 可能只在 prefill 侧命中,decode 侧不一定 decode-only 命中 | + +## 日志目录 + +每轮使用独立 `LOG_DIR`,至少保留: + +- `summary.txt` +- `pd_master.log` +- `prefill.log` +- `decode.log` +- `curl_warmup.log` +- `eval_gsm8k.log` + +建议命名: + +```bash +export LOG_DIR="/mtc/wzj/lightllm_dev2/LightLLM/test/benchmark/static_inference/log/qwen35_pd_$(date +%Y%m%d_%H%M%S)" +mkdir -p "${LOG_DIR}" +``` + +## 启动前检查 + +1. **模型目录**:优先使用 `MODEL_DIR=/mtc/models/Qwen3.5-0.8B`;不存在时再改成本机实际路径。 +2. **端口**:确认 `8089`、`8001`、`8002` 空闲。 +3. 
**显卡**:TP1 需要 prefill/decode 各 1 张卡;TP2 需要 prefill/decode 各 2 张卡,互不重叠。 +4. **代理**:启动服务和评测前清空 `http_proxy` / `https_proxy`;评测设置 `no_proxy`。 +5. **UCX/RDMA**:prefill/decode 启动前设置 `UCX_NET_DEVICES`、`UCX_TLS`。本机若默认 UCX 打到 `mlx5_8` 报 `Address not valid`,可显式使用 `mlx5_0:1` 到 `mlx5_7:1`。 +6. **nvidia_peermem**:运行本目录的 `check_nvidia_peermem.sh`,结果写入 `summary.txt`。 +7. **MPS**:如需更稳定的高并发/传输性能,可在启动服务前开启 NVIDIA MPS,并把开启状态写入 `summary.txt`。 + +## 变量配置 + +### TP2 推荐配置 + +```bash +export MODEL_DIR=/mtc/models/Qwen3.5-0.8B +export MODEL_NAME='qwen/Qwen3.5-0.8B' +export TP=2 +export PREFILL_CUDA_DEVICES='0,1' +export DECODE_CUDA_DEVICES='2,3' +export PD_KV_PAGE_SIZE=2048 +export PD_KV_PAGE_NUM=16 +export PD_MASTER_IP="$(hostname -I | awk '{print $1}')" +export HOST="${PD_MASTER_IP}" +``` + +### TP1 快速验证配置 + +```bash +export MODEL_DIR=/mtc/models/Qwen3.5-0.8B +export MODEL_NAME='qwen/Qwen3.5-0.8B' +export TP=1 +export PREFILL_CUDA_DEVICES='4' +export DECODE_CUDA_DEVICES='5' +export PD_KV_PAGE_SIZE=2048 +export PD_KV_PAGE_NUM=16 +export PD_MASTER_IP="$(hostname -I | awk '{print $1}')" +export HOST="${PD_MASTER_IP}" +``` + +### UCX 示例 + +按本机拓扑调整,不要盲目照抄其它机器: + +```bash +export UCX_NET_DEVICES='mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1' +export UCX_TLS=rc,cuda,gdr_copy +``` + +## 启动命令 + +先写入基础信息: + +```bash +export http_proxy= +export https_proxy= +export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${PD_MASTER_IP} + +{ + echo "MODEL_DIR=${MODEL_DIR}" + echo "MODEL_NAME=${MODEL_NAME}" + echo "TP=${TP}" + echo "PREFILL_CUDA_DEVICES=${PREFILL_CUDA_DEVICES}" + echo "DECODE_CUDA_DEVICES=${DECODE_CUDA_DEVICES}" + echo "PD_KV_PAGE_SIZE=${PD_KV_PAGE_SIZE}" + echo "PD_KV_PAGE_NUM=${PD_KV_PAGE_NUM}" + echo "PD_MASTER_IP=${PD_MASTER_IP}" + echo "HOST=${HOST}" + echo "UCX_NET_DEVICES=${UCX_NET_DEVICES-}" + echo "UCX_TLS=${UCX_TLS-}" +} | tee "${LOG_DIR}/summary.txt" + +bash skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh >> "${LOG_DIR}/summary.txt" 2>&1 +``` + 
+### 1. 启动 `pd_master` + +```bash +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --run_mode pd_master \ + --host "${PD_MASTER_IP}" \ + --port 8089 \ + >> "${LOG_DIR}/pd_master.log" 2>&1 & +``` + +等待 `8089` listen 后再启动节点。 + +### 2. 启动 `prefill` + +```bash +LOADWORKER=18 CUDA_VISIBLE_DEVICES="${PREFILL_CUDA_DEVICES}" \ +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --run_mode prefill \ + --tp "${TP}" \ + --dp 1 \ + --host "${HOST}" \ + --port 8001 \ + --disable_cudagraph \ + --pd_master_ip "${PD_MASTER_IP}" \ + --pd_master_port 8089 \ + --pd_kv_page_size "${PD_KV_PAGE_SIZE}" \ + --pd_kv_page_num "${PD_KV_PAGE_NUM}" \ + >> "${LOG_DIR}/prefill.log" 2>&1 & +``` + +### 3. 启动 `decode` + +```bash +LOADWORKER=18 CUDA_VISIBLE_DEVICES="${DECODE_CUDA_DEVICES}" \ +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --run_mode decode \ + --tp "${TP}" \ + --dp 1 \ + --host "${HOST}" \ + --port 8002 \ + --pd_master_ip "${PD_MASTER_IP}" \ + --pd_master_port 8089 \ + --pd_kv_page_size "${PD_KV_PAGE_SIZE}" \ + --pd_kv_page_num "${PD_KV_PAGE_NUM}" \ + >> "${LOG_DIR}/decode.log" 2>&1 & +``` + +## 就绪判定 + +不要只看端口。必须等待 `pd_master.log` 同时出现: + +```text +mode: prefill ... registed +mode: decode ... 
registed +``` + +可用命令: + +```bash +rg 'mode: prefill .* registed|mode: decode .* registed|ERROR|Traceback|Exception' "${LOG_DIR}/pd_master.log" "${LOG_DIR}/prefill.log" "${LOG_DIR}/decode.log" +``` + +## Warmup + +`lm_eval` 前必须先打一次 `pd_master`: + +```bash +curl -sS -w "\nhttp_code:%{http_code}\n" -X POST "http://${PD_MASTER_IP}:8089/v1/completions" \ + -H "Content-Type: application/json" \ + -d "{\"model\":\"${MODEL_NAME}\",\"prompt\":\"warmup\",\"max_tokens\":16,\"temperature\":0}" \ + | tee "${LOG_DIR}/curl_warmup.log" +``` + +期望 `http_code:200`。失败时先查 `pd_master.log` / `prefill.log` / `decode.log`,不要直接跑全量评测。 + +## GSM8K 评测 + +默认并发和 batch 使用 64,避免高并发掩盖关键问题;压测时再提高。 + +```bash +export http_proxy= +export https_proxy= +export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${PD_MASTER_IP} + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 \ +lm_eval --model local-completions \ + --model_args "model=${MODEL_NAME},base_url=http://${PD_MASTER_IP}:8089/v1/completions,num_concurrent=64,max_retries=3,tokenized_requests=False,tokenizer=${MODEL_DIR}" \ + --tasks gsm8k \ + --batch_size 64 \ + --confirm_run_unsafe_code \ + >> "${LOG_DIR}/eval_gsm8k.log" 2>&1 +``` + +提取结果: + +```bash +rg -n 'flexible-extract|strict-match|exact_match|Traceback|ERROR|can not find waiting WRITE task|has_error=True' \ + "${LOG_DIR}/eval_gsm8k.log" "${LOG_DIR}/pd_master.log" "${LOG_DIR}/prefill.log" "${LOG_DIR}/decode.log" \ + | tee -a "${LOG_DIR}/summary.txt" +``` + +参考正常结果: + +| 场景 | 参考精度 | +|---|---| +| TP1 NIXL PD | `flexible-extract exact_match ~= 0.332`,`strict-match exact_match ~= 0.327` | +| TP2 NIXL PD | `flexible-extract exact_match ~= 0.331`,`strict-match exact_match ~= 0.328` | + +## 可选:decode-only cache 命中探针 + +这个探针用于确认重复 prompt 是否在 decode 节点全命中。Qwen3.5 的 linear-att cache 以 +`linear_att_hash_page_size` 为粒度,默认 `512`。历史观察显示: + +- prefill 侧会按 512 token 粒度逐步命中,例如 513 的第二次可命中 512。 +- decode 侧可能仍为 `gpu cache hit: False`、`gpu_prompt_cache_len:0`。 +- 只要 decode 未全命中,仍会出现 `recv WRITE request from prefill` 和 
`linear_att_state` 传输。 + +### 简单重复 prompt + +在同一套服务生命周期内连续请求两次相同 prompt: + +```bash +PROMPT_FILE="${LOG_DIR}/repeat_prompt.txt" +python3 - <<'PY' "${MODEL_DIR}" "${PROMPT_FILE}" +from transformers import AutoTokenizer +import sys +tok = AutoTokenizer.from_pretrained(sys.argv[1], trust_remote_code=True) +target = 2049 +s = "Qwen3.5 linear attention cache boundary probe. " +unit = " Repeatable cache probe sentence." +while len(tok.encode(s, add_special_tokens=False)) < target: + s += unit +open(sys.argv[2], "w").write(s) +print(len(tok.encode(s, add_special_tokens=False))) +PY + +for i in 1 2; do + curl -sS -X POST "http://${PD_MASTER_IP}:8089/v1/completions" \ + -H "Content-Type: application/json" \ + -d "{\"model\":\"${MODEL_NAME}\",\"prompt\":$(python3 -c 'import json,sys; print(json.dumps(open(sys.argv[1]).read()))' "${PROMPT_FILE}"),\"max_tokens\":4,\"temperature\":0}" \ + > "${LOG_DIR}/repeat_${i}.json" + sleep 2 +done +``` + +### 判定信号 + +```bash +rg -n 'gpu cache hit:|recv WRITE request from prefill|start WRITE to decode node|linear_att_state|trans task ret success' \ + "${LOG_DIR}/prefill.log" "${LOG_DIR}/decode.log" \ + | tee -a "${LOG_DIR}/summary.txt" +``` + +decode-only 全命中的期望信号: + +| 日志 | 期望 | +|---|---| +| `decode.log` | `gpu cache hit: True` | +| `decode.log` | `gpu_prompt_cache_len` 接近 `prompt_tokens` 或至少 `input_len - cur_kv_len <= 1` | +| `decode.log` | 不再出现真实 `recv WRITE request from prefill` | +| `prefill.log` | 不再出现对应请求的 `start WRITE to decode node` | + +如果 decode 仍是 `gpu cache hit: False gpu_prompt_cache_len:0`,则说明没有进入 decode-only 命中路径。 + +## 常见问题 + +| 现象 | 处理 | +|---|---| +| `NIXL_ERR_BACKEND` / `uct_iface_open(rc_verbs/mlx5_8:1) failed: Address not valid` | 显式设置可用 `UCX_NET_DEVICES`,例如避开 `mlx5_8/9` | +| `digest sent was rejected` | 多为快速重启后的共享内存 / multiprocessing authkey 残留;清理端口和残留 `lightllm::...` worker 后重启 | +| `can not find waiting WRITE task` | 检查 NIXL notify key、abort 日志、以及 `pd_io_struct.py` 中 key 是否包含进程本地 `req_idx` | +| 1024 page size 失败 | 
Qwen3.5 linear-att state 页可能放不下;使用 `--pd_kv_page_size 2048` | +| 第二次同 prompt 仍走 WRITE | 可能是 decode 侧没有建立可复用 cache,或 linear-att 尾块状态无法全命中 | + +## 收尾 + +结束后释放本轮服务: + +```bash +fuser -k 8089/tcp 8001/tcp 8002/tcp || true +``` + +如仍有显存占用,检查残留 worker: + +```bash +ps -eo pid,ppid,stat,cmd | rg 'lightllm::|api_server|hypercorn' +nvidia-smi --query-compute-apps=pid,process_name,used_memory --format=csv +``` + diff --git a/skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh b/skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh new file mode 100755 index 0000000000..86dca002d0 --- /dev/null +++ b/skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# Check nvidia_peermem (GPUDirect RDMA) for NIXL PD / UCX over IB. +# Usage: bash skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh [LOG_DIR] +# LOG_DIR optional: scan prefill.log / decode.log for UCX GPUDirect lines. +set -euo pipefail + +LOG_DIR="${1:-}" +FAIL=0 + +echo "=== nvidia_peermem ===" + +if lsmod 2>/dev/null | awk '{print $1}' | grep -qx nvidia_peermem; then + ver="$(cat /sys/module/nvidia_peermem/version 2>/dev/null || echo '?')" + echo "OK: module loaded (version ${ver})" +else + echo "FAIL: nvidia_peermem not loaded" + FAIL=1 +fi + +if [[ -n "$LOG_DIR" ]]; then + for f in prefill.log decode.log; do + [[ -f "${LOG_DIR}/${f}" ]] || continue + if grep -q 'GPUDirect RDMA is not detected' "${LOG_DIR}/${f}" 2>/dev/null; then + echo "FAIL: ${f} -> GPUDirect RDMA is not detected (restart services after modprobe)" + FAIL=1 + elif grep -q 'GPUDirect RDMA is detected' "${LOG_DIR}/${f}" 2>/dev/null; then + echo "OK: ${f} -> GPUDirect RDMA is detected" + fi + done +fi + +if [[ "$FAIL" -ne 0 ]]; then + cat <<'EOF' + +Enable GPUDirect RDMA: + sudo modprobe nvidia_peermem + lsmod | grep nvidia_peermem + # cross-node: run on every host; then restart prefill / decode +EOF + exit 1 +fi + +exit 0 diff --git a/test/acc/test_deepseekr1.sh 
b/test/acc/test_deepseekr1.sh index 899020ce42..0d2df72d20 100644 --- a/test/acc/test_deepseekr1.sh +++ b/test/acc/test_deepseekr1.sh @@ -2,4 +2,4 @@ LOADWORKER=18 python -m lightllm.server.api_server --batch_max_tokens 6000 --mod -HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384, "tokenizer":"/mtc/models/DeepSeek-R1"}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_deepseekr1_mtp.sh b/test/acc/test_deepseekr1_mtp.sh index 20975c2da6..7b511e41ac 100644 --- a/test/acc/test_deepseekr1_mtp.sh +++ b/test/acc/test_deepseekr1_mtp.sh @@ -1,3 +1,7 @@ LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --mem_fraction 0.75 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 -HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384, "tokenizer":"/mtc/models/DeepSeek-R1"}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + +# 帮我写一段提示词,告诉AI单独一个一个的进行上述测试的启动服务,然后再执行评测脚本,将结果写入out.txt 中,注意需要标记启动的参数和结果信息。不要用health 接口去判断服务是否启动,直接探测端口是否处于listen状态即可, 
执行评测命令的时候,需要用no_proxy 将本地local ip 排除。 +# 不要写额外的脚本来启动服务,就是单独一个一个的按照上面的描述启动服务,然后再执行评测脚本,然后注意等待服务启动完成,可以20s检测一次其控制台输出,看是否启动完成,还是启动报错。 +# 应该把server启动在后台,然后再去探测端口, 判断服务是否启动成功。最后需要总结下测试的结果。 \ No newline at end of file diff --git a/test/acc/test_deepseekr1_mtp_ep.sh b/test/acc/test_deepseekr1_mtp_ep.sh index d5d50f9e94..29c7515b27 100644 --- a/test/acc/test_deepseekr1_mtp_ep.sh +++ b/test/acc/test_deepseekr1_mtp_ep.sh @@ -1,3 +1,22 @@ -LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --enable_ep_moe --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 +LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --enable_ep_moe --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 --max_req_total_len 56000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 -HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code \ No newline at end of file +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code + + +LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --enable_ep_moe --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 --max_req_total_len 56000 --mtp_mode eagle_with_att --mtp_draft_model_dir 
/mtc/models/DeepSeek-R1-NextN --mtp_step 2 --enable_tpsp_mix_mode + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code + + +LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --enable_ep_moe --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 --max_req_total_len 56000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 --enable_prefill_microbatch_overlap --enable_decode_microbatch_overlap + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code + + +LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --enable_ep_moe --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 --max_req_total_len 56000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 --enable_prefill_microbatch_overlap --enable_decode_microbatch_overlap --enable_dp_prefill_balance + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code + +# 帮我写一段提示词,告诉AI单独一个一个的进行上述测试的启动服务,然后再执行评测脚本,将结果写入out.txt 中,注意需要标记启动的参数和结果信息。不要用health 接口去判断服务是否启动,直接探测端口是否处于listen状态即可, 执行评测命令的时候,需要用no_proxy 将本地local ip 排除。 +# 不要写额外的脚本来启动服务,就是单独一个一个的按照上面的描述启动服务,然后再执行评测脚本,然后注意等待服务启动完成,可以20s检测一次其控制台输出,看是否启动完成,还是启动报错。 
+# 应该把server启动在后台,然后再去探测端口, 判断服务是否启动成功。最后需要总结下测试的结果。 \ No newline at end of file diff --git a/test/acc/test_deepseekv32_ep.sh b/test/acc/test_deepseekv32_ep.sh index 815d31c5e8..c34f546022 100644 --- a/test/acc/test_deepseekv32_ep.sh +++ b/test/acc/test_deepseekv32_ep.sh @@ -1,6 +1,4 @@ -LOADWORKER=14 python -m lightllm.server.api_server --model_dir /mtc/sufubao/DeepSeek-V3.2 --tp 8 --graph_max_batch_size 32 --tool_call_parser deepseekv32 --mem_fraction 0.8 --reasoning_parser deepseek-v3 --dp 4 --enable_ep_moe +LOADWORKER=14 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V3.2 --tp 8 --graph_max_batch_size 32 --tool_call_parser deepseekv32 --mem_fraction 0.8 --reasoning_parser deepseek-v3 --dp 8 --enable_ep_moe --port 8000 -HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-V3.2", "base_url":"http://localhost:8000/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code - -export no_proxy="localhost,127.0.0.1,::1" \ No newline at end of file +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 no_proxy=127.0.0.1,localhost,::1 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-V3.2", "base_url":"http://localhost:8000/v1/completions", "max_length": 16384, "tokenizer":"/mtc/models/DeepSeek-V3.2"}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_pd.sh b/test/acc/test_pd.sh new file mode 100644 index 0000000000..ee94e73e91 --- /dev/null +++ b/test/acc/test_pd.sh @@ -0,0 +1,74 @@ +$pd_master_ip 为本机的ip地址, 测试的时候,自己修改为对应的ip地址 + +# 启动pd_master节点 +# 测试前关闭代理 +export http_proxy= +export https_proxy= +python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --run_mode "pd_master" --host $pd_master_ip --port 8089 + +# 启动prefill 节点 +$host 为本机的ip地址, 测试的时候,自己修改为对应的ip地址 +$pd_master_ip 为pd_master的ip地址, 测试的时候,自己修改为对应的ip地址,在测试的时候为本机ip地址 +# 测试前关闭代理 +export 
http_proxy= +export https_proxy= +# 设置ucx环境变量, 走 rdma 传输数据, 排除环境中的数据网卡,避免影响性能。 +export UCX_NET_DEVICES=$(ibv_devinfo | grep 'hca_id:' | grep -v -E 'mlx5_8|mlx5_9' | awk '{print $2":1"}' | paste -sd, -) +export UCX_LOG_LEVEL=info +export UCX_TLS=rc,cuda,gdr_copy +LOADWORKER=18 CUDA_VISIBLE_DEVICES=0,1 python -m lightllm.server.api_server \ +--model_dir /mtc/models/qwen3-8b \ +--run_mode "prefill" \ +--tp 2 \ +--dp 1 \ +--host $host \ +--port 8001 \ +--disable_cudagraph \ +--pd_master_ip $pd_master_ip \ +--pd_master_port 8089 + +# 启动 decode 节点 +# 测试前关闭代理 +export http_proxy= +export https_proxy= +# 设置ucx环境变量, 走 rdma 传输数据, 排除环境中的数据网卡,避免影响性能。 +export UCX_NET_DEVICES=$(ibv_devinfo | grep 'hca_id:' | grep -v -E 'mlx5_8|mlx5_9' | awk '{print $2":1"}' | paste -sd, -) +export UCX_LOG_LEVEL=info +export UCX_TLS=rc,cuda,gdr_copy +$host 为本机的ip地址, 测试的时候,自己修改为对应的ip地址 +$pd_master_ip 为pd_master的ip地址, 测试的时候,自己修改为对应的ip地址,在测试的时候为本机ip地址 +LOADWORKER=18 CUDA_VISIBLE_DEVICES=2,3 python -m lightllm.server.api_server \ +--model_dir /mtc/models/qwen3-8b \ +--run_mode "decode" \ +--tp 2 \ +--dp 1 \ +--host $host \ +--port 8002 \ +--pd_master_ip $pd_master_ip \ +--pd_master_port 8089 + +# 等待 prefill 和 decode 节点启动完成,并连上 pd master以后,执行测试脚本 +# 测试前关闭代理 +export http_proxy= +export https_proxy= +$pd_master_ip 为pd_master的ip地址, 测试的时候,自己修改为对应的ip地址 +# warm up +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval \ +--model local-completions --model_args \ +'{"model":"qwen/qwen3-8b", "base_url":"http://$pd_master_ip:8089/v1/completions", "max_length": 16384, "tokenized_requests": false}' \ +--tasks gsm8k --batch_size 1 --confirm_run_unsafe_code --limit 1 + +# test +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval \ +--model local-completions --model_args \ +'{"model":"qwen/qwen3-8b", "base_url":"http://$pd_master_ip:8089/v1/completions", "max_length": 16384, "tokenized_requests": false}' \ +--tasks gsm8k 
--batch_size 36 --confirm_run_unsafe_code + +# 1. 按顺序在不同的cmd中启动上面的程序,然后再执行评测脚本,将结果写入out.txt 中,注意需要标记启动的参数和结果信息。 +# 2. 执行评测命令的时候,需要用no_proxy 将本地local ip 排除。 +# 3. 不要写额外的脚本来启动服务,就是单独一个一个的按照上面的描述启动服务,然后再执行评测脚本,然后注意等待服务启动完成,可以20s检测一次其控制台输出,看是否启动完成,还是启动报错。 +# 4. 最后需要总结下测试的结果,并将结果输出到对话中。 +# 5. 如果启动过程中出现错误,需要记录错误信息,并输出到对话中。 +# 6. 测试完成后,关闭所有启动的进程。 +# 7. lm_eval 的评测命令有时候需要利用代理去下载一些缓存,所以可以先不关闭代理,跑一次lm_eval对应的命令,等cache下载好了以后,再关闭代理,跑第二次评测命令。 +# 8. 同一组测试的log要放在一个目录下,这样好查询。 \ No newline at end of file diff --git a/test/acc/test_qwen2.5_fp8kv_sph.sh b/test/acc/test_qwen2.5_fp8kv_sph.sh index 834102ced8..660000d4a0 100644 --- a/test/acc/test_qwen2.5_fp8kv_sph.sh +++ b/test/acc/test_qwen2.5_fp8kv_sph.sh @@ -4,4 +4,8 @@ LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/Qwen2.5-14B-Instruct --tp 2 --port 8089 --llm_kv_type fp8kv_sph --kv_quant_calibration_config_path /mtc/wzj/lightllm_dev/LightLLM/test/advanced_config/fp8_calibration_per_head/test_kv_cache_calib_per_head_qwen2.5_14b.json # second -HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-14B-Instruct", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 64 --confirm_run_unsafe_code \ No newline at end of file +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-14B-Instruct", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 64 --confirm_run_unsafe_code + +# 帮我写一段提示词,告诉AI单独一个一个的进行上述测试的启动服务,然后再执行评测脚本,将结果写入out.txt 中,注意需要标记启动的参数和结果信息。不要用health 接口去判断服务是否启动,直接探测端口是否处于listen状态即可, 执行评测命令的时候,需要用no_proxy 将本地local ip 排除。 +# 不要写额外的脚本来启动服务,就是单独一个一个的按照上面的描述启动服务,然后再执行评测脚本,然后注意等待服务启动完成,可以20s检测一次其控制台输出,看是否启动完成,还是启动报错。 +# 应该把server启动在后台,然后再去探测端口, 判断服务是否启动成功。最后需要总结下测试的结果。 \ No newline at end of file diff --git a/test/acc/test_qwen2.5_fp8kv_spt.sh 
b/test/acc/test_qwen2.5_fp8kv_spt.sh index 72a0a8a1e6..7efad967c5 100644 --- a/test/acc/test_qwen2.5_fp8kv_spt.sh +++ b/test/acc/test_qwen2.5_fp8kv_spt.sh @@ -4,4 +4,9 @@ LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/Qwen2.5-14B-Instruct --tp 2 --port 8089 --llm_kv_type fp8kv_spt --kv_quant_calibration_config_path ./test/advanced_config/fp8_calibration_per_tensor/test_kv_cache_calib_per_tensor_qwen2.5_14b.json # second -HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-14B-Instruct", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 64 --confirm_run_unsafe_code \ No newline at end of file +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-14B-Instruct", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 64 --confirm_run_unsafe_code + + +# 帮我写一段提示词,告诉AI单独一个一个的进行上述测试的启动服务,然后再执行评测脚本,将结果写入out.txt 中,注意需要标记启动的参数和结果信息。不要用health 接口去判断服务是否启动,直接探测端口是否处于listen状态即可, 执行评测命令的时候,需要用no_proxy 将本地local ip 排除。 +# 不要写额外的脚本来启动服务,就是单独一个一个的按照上面的描述启动服务,然后再执行评测脚本,然后注意等待服务启动完成,可以20s检测一次其控制台输出,看是否启动完成,还是启动报错。 +# 应该把server启动在后台,然后再去探测端口, 判断服务是否启动成功。最后需要总结下测试的结果。 \ No newline at end of file diff --git a/test/acc/test_qwen3.5.sh b/test/acc/test_qwen3.5.sh new file mode 100644 index 0000000000..a590aa0835 --- /dev/null +++ b/test/acc/test_qwen3.5.sh @@ -0,0 +1,71 @@ +/root/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17 + +# first 测试基础功能 +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \ +--model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17 \ +--tp 2 \ +--port 8089 + +# second +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 
lm_eval --model local-completions --model_args '{"model":"qwen/Qwen3.5-0.8B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + +# prefill cuda graph 功能测试 +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \ +--model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17 \ +--tp 2 \ +--port 8089 \ +--enable_prefill_cudagraph + +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/Qwen3.5-0.8B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + + +# 测试 +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \ +--model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17 \ +--tp 2 \ +--port 8089 \ +--linear_att_cache_size 10 \ +--linear_att_hash_page_size 256 \ +--linear_att_page_block_num 2 \ +--max_total_token_num 200000 + +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/Qwen3.5-0.8B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + +# 测试 cpu cache 与 linear att 的配合是否正常 +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \ +--model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17 \ +--tp 2 \ +--port 8089 \ +--linear_att_cache_size 10 \ +--linear_att_hash_page_size 256 \ +--linear_att_page_block_num 2 \ +--max_total_token_num 20000 \ +--enable_cpu_cache \ +--cpu_cache_storage_size 128 + +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 
HF_DATASETS_OFFLINE=0 lm_eval --model local-completions \ +--model_args '{"model":"qwen/Qwen3.5-0.8B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' \ +--tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + + +# disk cache test +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 LIGHTLLM_DISK_CACHE_PROMPT_LIMIT_LENGTH=128 python -m lightllm.server.api_server \ +--model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17 \ +--tp 2 --port 8089 \ +--linear_att_cache_size 128 \ +--linear_att_hash_page_size 256 \ +--linear_att_page_block_num 32 \ +--max_total_token_num 200000 \ +--enable_cpu_cache \ +--cpu_cache_storage_size 32 \ +--enable_disk_cache \ +--disk_cache_storage_size 512 \ +--disk_cache_dir /mtc/test/tmp/ + +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions \ +--model_args '{"model":"qwen/Qwen3.5-0.8B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' \ +--tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + +# 帮我写一段提示词,告诉AI单独一个一个的进行上述测试的启动服务,然后再执行评测脚本,将结果写入out.txt 中,注意需要标记启动的参数和结果信息。不要用health 接口去判断服务是否启动,直接探测端口是否处于listen状态即可, 执行评测命令的时候,需要用no_proxy 将本地local ip 排除。 +# 不要写额外的脚本来启动服务,就是单独一个一个的按照上面的描述启动服务,然后再执行评测脚本,然后注意等待服务启动完成,可以20s检测一次其控制台输出,看是否启动完成,还是启动报错。 +# 应该把server启动在后台,然后再去探测端口, 判断服务是否启动成功。最后需要总结下测试的结果。 如果是 cpu cache 和 硬盘cache的测试, lmeval要跑两次,确认命中后的效率。 \ No newline at end of file diff --git a/test/acc/test_qwen3.sh b/test/acc/test_qwen3.sh index 2714a92d7b..7f7181ef06 100644 --- a/test/acc/test_qwen3.sh +++ b/test/acc/test_qwen3.sh @@ -2,4 +2,38 @@ LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 # second -HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' 
--tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + +# test quant +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --quant_type vllm-fp8w8a8 + +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + + +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --enable_tpsp_mix_mode + +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + + +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --dp 2 --port 8089 --enable_tpsp_mix_mode --enable_dp_prefill_balance + +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + + +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --dp 2 --port 8089 --max_total_token_num 200000 --enable_cpu_cache 
--cpu_cache_storage_size 128 --cpu_cache_token_page_size 128 + +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + + +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --dp 2 --port 8089 --max_total_token_num 200000 --enable_cpu_cache --cpu_cache_storage_size 128 --cpu_cache_token_page_size 128 --llm_kv_type int8kv + +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + + +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 LIGHTLLM_DISK_CACHE_PROMPT_LIMIT_LENGTH=128 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --dp 2 --port 8089 --max_total_token_num 200000 --enable_cpu_cache --cpu_cache_storage_size 64 --cpu_cache_token_page_size 128 --enable_disk_cache --disk_cache_storage_size 256 --disk_cache_dir /mtc/test/tmp/ + +export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + +# 帮我写一段提示词,告诉AI单独一个一个的进行上述测试的启动服务,然后再执行评测脚本,将结果写入out.txt 中,注意需要标记启动的参数和结果信息。不要用health 接口去判断服务是否启动,直接探测端口是否处于listen状态即可, 执行评测命令的时候,需要用no_proxy 将本地local ip 排除。 +# 不要写额外的脚本来启动服务,就是单独一个一个的按照上面的描述启动服务,然后再执行评测脚本,然后注意等待服务启动完成,可以20s检测一次其控制台输出,看是否启动完成,还是启动报错。 +# 应该把server启动在后台,然后再去探测端口, 判断服务是否启动成功。最后需要总结下测试的结果。对于开启 cpu cache 和 disk cache的测试,需要跑两次,确认命中后的效率和精度。 
\ No newline at end of file diff --git a/test/advanced_config/mixed_quantization/qwen3_5-122b-moe-only-fp8.yaml b/test/advanced_config/mixed_quantization/qwen3_5-122b-moe-only-fp8.yaml new file mode 100644 index 0000000000..71a4524dcc --- /dev/null +++ b/test/advanced_config/mixed_quantization/qwen3_5-122b-moe-only-fp8.yaml @@ -0,0 +1,4 @@ +quant_type: none +mix_bits: + - name: "fused_moe" + quant_type: "deepgemm-fp8w8a8-b128" diff --git a/test/benchmark/service/benchmark_client.py b/test/benchmark/service/benchmark_client.py index 09009fc9e1..3f55bcab1e 100644 --- a/test/benchmark/service/benchmark_client.py +++ b/test/benchmark/service/benchmark_client.py @@ -27,6 +27,13 @@ def get_tokenizer( return tokenizer +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + def get_output_length(input_num: int, output_len: int) -> List[int]: min_len, max_len = 2, output_len * 2 mean = (min_len + max_len) * 0.5 @@ -162,7 +169,7 @@ def main(): return assert args.tokenizer_path is not None - model_name.append(args.tokenizer_path) + model_name.append(normalize_model_name(args.tokenizer_path)) seed_all(args.seed) url = args.url tokenizer = get_tokenizer(args.tokenizer_path) diff --git a/test/benchmark/service/benchmark_multiturn.py b/test/benchmark/service/benchmark_multiturn.py new file mode 100644 index 0000000000..897d125077 --- /dev/null +++ b/test/benchmark/service/benchmark_multiturn.py @@ -0,0 +1,753 @@ +""" +Multi-turn dialogue benchmark for LightLLM. + +For each concurrency level in --concurrency_levels, launches N concurrent +"sessions". Each session starts from a prompt of ~start_input_len tokens +(with a per-session random prefix so different sessions don't share KV +cache) and keeps issuing streaming requests turn by turn. 
After every +turn, deterministic synthetic assistant tokens plus a dynamically sampled +number of new user tokens are appended to the prompt. This keeps the exact +request stream reproducible for a fixed seed. +A session stops when the next prompt would exceed max_input_len, or +after max_turns turns. + +Metrics aggregated per concurrency level: + - TTFT (Time To First Token, ms): per-turn first-byte latency + - TPOT (Time Per Output Token, ms): mean inter-token gap after TTFT + - QPS (turns / wall_time) + - TPM ((prompt_tokens + completion_tokens) / wall_time * 60) + - Cache hit ratio = sum(cached_tokens) / sum(prompt_tokens) across turns + +The OpenAI v1/completions streaming endpoint is used because its final +`usage` chunk carries `prompt_tokens_details.cached_tokens`, which is +how prompt-cache hit length is exposed to clients. + +Example: + python benchmark_multiturn.py \\ + --url http://127.0.0.1:8000/v1/completions \\ + --tokenizer_path /path/to/tokenizer \\ + --model_name my-model \\ + --concurrency_levels 1,4,8,16 \\ + --start_input_len 1024 \\ + --max_input_len 16384 \\ + --turn_input_increment 256 \\ + --output_len 256 +""" + +import argparse +import json +import os +import random +import threading +import time +import urllib.parse +import urllib.request +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import requests +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + +_DEFAULT_TRANSIENT_RETRIES = 2 +_PROMPT_LEN_OVERLAP_CHARS = 512 +_TRANSIENT_STREAM_ERRORS = ( + requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError, + requests.exceptions.ReadTimeout, + requests.exceptions.Timeout, +) + + +def seed_all(seed: int) -> None: + if not seed: + seed = int(time.time()) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + + +def get_tokenizer(tokenizer_name: str) -> 
Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + return AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) + + +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + +def get_models_url(completions_url: str) -> str: + parsed = urllib.parse.urlsplit(completions_url) + path = parsed.path.rstrip("/") + for suffix in ("/chat/completions", "/completions"): + if path.endswith(suffix): + path = path[: -len(suffix)] + "/models" + return urllib.parse.urlunsplit(parsed._replace(path=path, query="", fragment="")) + return urllib.parse.urlunsplit(parsed._replace(path="/v1/models", query="", fragment="")) + + +def fetch_served_model_names(completions_url: str, timeout_s: int = 10) -> List[str]: + models_url = get_models_url(completions_url) + request = urllib.request.Request(models_url, headers={"Accept": "application/json"}) + with urllib.request.urlopen(request, timeout=timeout_s) as response: + payload = json.loads(response.read().decode("utf-8")) + return [item["id"] for item in payload.get("data", []) if item.get("id")] + + +def resolve_model_name( + completions_url: str, + requested_model_name: str, + explicit_model_name: bool, +) -> Tuple[str, Optional[str]]: + normalized_name = normalize_model_name(requested_model_name) + if normalized_name != requested_model_name: + note = f"Normalized model name from `{requested_model_name}` to `{normalized_name}`." + else: + note = None + + try: + served_model_names = fetch_served_model_names(completions_url) + except Exception as exc: + if note is not None: + note = f"{note} Failed to query served models: {exc}." 
+ return normalized_name, note + + if requested_model_name in served_model_names: + return requested_model_name, note + if normalized_name in served_model_names: + if normalized_name != requested_model_name: + return normalized_name, ( + f"Normalized model name from `{requested_model_name}` to `{normalized_name}` " "to match `/v1/models`." + ) + return normalized_name, note + + requested_basename = os.path.basename(normalized_name) + basename_matches = [ + served_name + for served_name in served_model_names + if os.path.basename(normalize_model_name(served_name)) == requested_basename + ] + if len(basename_matches) == 1: + matched_name = basename_matches[0] + return matched_name, ( + f"Resolved model name `{requested_model_name}` to served model `{matched_name}` " "via `/v1/models`." + ) + + if not explicit_model_name and len(served_model_names) == 1: + matched_name = served_model_names[0] + return matched_name, ( + f"Using the only served model `{matched_name}` returned by `/v1/models` " + f"instead of `{requested_model_name}`." + ) + + if note is not None: + note = ( + f"{note} Available served models: {', '.join(served_model_names) or '(none)'}. " + f"Using `{normalized_name}`." + ) + return normalized_name, note + + +def gen_random_token_ids(tokenizer, n: int, rng: random.Random) -> List[int]: + vocab = tokenizer.vocab_size + return [rng.randint(0, vocab - 1) for _ in range(n)] + + +def decode_ids(tokenizer, ids: List[int]) -> str: + return tokenizer.decode(ids, skip_special_tokens=False) + + +def gen_session_initial_prompt( + tokenizer, + start_input_len: int, + session_seed: int, +) -> Tuple[str, int]: + """Build the initial prompt for a session. 
The prefix is unique per + session so that prefix-cache hits across sessions are not counted.""" + rng = random.Random(session_seed) + ids = gen_random_token_ids(tokenizer, start_input_len, rng) + text = decode_ids(tokenizer, ids) + # Re-encode so that the recorded token length matches what the server + # will tokenize. Random ids -> decode -> re-encode is not lossless. + real_ids = tokenizer.encode(text, add_special_tokens=False) + return text, len(real_ids) + + +def append_turn_input( + tokenizer, + prompt: str, + prompt_token_len: int, + assistant_token_count: int, + turn_input_increment: int, + rng: random.Random, +) -> Tuple[str, int]: + """Append deterministic synthetic assistant/user text to the prompt. + + The benchmark measures server output, but the next request must not depend + on that output; otherwise repeated runs with the same seed can diverge. + """ + if assistant_token_count > 0: + assistant_ids = gen_random_token_ids(tokenizer, assistant_token_count, rng) + assistant_text = decode_ids(tokenizer, assistant_ids) + else: + assistant_text = "" + + if turn_input_increment > 0: + user_ids = gen_random_token_ids(tokenizer, turn_input_increment, rng) + user_text = decode_ids(tokenizer, user_ids) + else: + user_text = "" + + appended_text = assistant_text + user_text + new_prompt = prompt + appended_text + if not appended_text: + return new_prompt, prompt_token_len + + # Token merges only depend on a small boundary window, so avoid + # re-encoding the entire prompt on every turn. 
+ overlap_text = prompt[-_PROMPT_LEN_OVERLAP_CHARS:] + if overlap_text: + overlap_token_len = len(tokenizer.encode(overlap_text, add_special_tokens=False)) + merged_token_len = len(tokenizer.encode(overlap_text + appended_text, add_special_tokens=False)) + appended_token_len = max(merged_token_len - overlap_token_len, 0) + else: + appended_token_len = len(tokenizer.encode(appended_text, add_special_tokens=False)) + new_len = prompt_token_len + appended_token_len + return new_prompt, new_len + + +def stream_one_turn( + tokenizer, + url: str, + model_name: str, + prompt: str, + prompt_token_len: int, + max_new_tokens: int, + request_timeout_s: int, + max_retries: int = _DEFAULT_TRANSIENT_RETRIES, +) -> Optional[Dict]: + """Send one streaming completion request, return per-turn stats: + { + "ttft": float seconds, + "decode_times": [float seconds, ...], # gaps between subsequent tokens + "prompt_tokens": int, + "completion_tokens": int, + "cached_tokens": int, + "cached_tokens_reported": bool, + "usage_estimated": bool, + "generated_text": str, + } + Returns None on failure.""" + payload = { + "model": model_name, + "prompt": prompt, + "max_tokens": max_new_tokens, + "temperature": 0.0, + "ignore_eos": True, + "stream": True, + "stream_options": {"include_usage": True}, + } + headers = {"Content-Type": "application/json"} + + for attempt in range(max_retries + 1): + start_time = time.time() + first_token_time: Optional[float] = None + last_token_time: Optional[float] = None + decode_times: List[float] = [] + generated_text_parts: List[str] = [] + prompt_tokens = 0 + completion_tokens = 0 + cached_tokens = 0 + cached_tokens_reported = False + + try: + with requests.Session() as req_session: + req_session.trust_env = False + with req_session.post( + url, + headers=headers, + json=payload, + stream=True, + timeout=(10, request_timeout_s), + ) as response: + if response.status_code != 200: + err = response.text + if response.status_code >= 500 and attempt < max_retries: + 
time.sleep(0.2 * (attempt + 1)) + continue + print(f"\n[turn failed] status={response.status_code} body={err[:200]}") + return None + + for raw in response.iter_lines(): + if not raw: + continue + line = raw.strip() + if not line.startswith(b"data:"): + continue + data_str = line[len(b"data:") :].strip() + if data_str == b"[DONE]": + break + try: + chunk = json.loads(data_str) + except Exception: + continue + + # Final usage-only chunk: choices == [] and usage present + usage = chunk.get("usage") + choices = chunk.get("choices") or [] + if usage is not None and not choices: + prompt_tokens = usage.get("prompt_tokens", prompt_tokens) + completion_tokens = usage.get("completion_tokens", completion_tokens) + details = usage.get("prompt_tokens_details") + if isinstance(details, dict) and details.get("cached_tokens") is not None: + cached_tokens = details["cached_tokens"] + cached_tokens_reported = True + continue + + # Token-bearing chunk + if not choices: + continue + text_piece = choices[0].get("text", "") + if text_piece == "" and choices[0].get("finish_reason") is None: + continue + + now = time.time() + if first_token_time is None: + first_token_time = now + else: + decode_times.append(now - last_token_time) + last_token_time = now + if text_piece: + generated_text_parts.append(text_piece) + except _TRANSIENT_STREAM_ERRORS as e: + if first_token_time is None and attempt < max_retries: + time.sleep(0.2 * (attempt + 1)) + continue + + if first_token_time is not None: + print(f"\n[turn warning] {e}; discarding partial turn (attempt={attempt + 1})") + return None + + print(f"\n[turn exception] {e}") + return None + except Exception as e: + print(f"\n[turn exception] {e}") + return None + + if first_token_time is None: + if attempt < max_retries: + time.sleep(0.2 * (attempt + 1)) + continue + return None + + generated_text = "".join(generated_text_parts) + usage_estimated = False + if prompt_tokens == 0: + prompt_tokens = prompt_token_len + usage_estimated = True + if 
completion_tokens == 0: + estimated_completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False)) + completion_tokens = max(estimated_completion_tokens, len(generated_text_parts)) + usage_estimated = True + + return { + "ttft": first_token_time - start_time, + "decode_times": decode_times, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "cached_tokens": cached_tokens, + "cached_tokens_reported": cached_tokens_reported, + "usage_estimated": usage_estimated, + "generated_text": generated_text, + } + + return None + + +def run_session( + session_id: int, + tokenizer, + url: str, + model_name: str, + start_input_len: int, + max_input_len: int, + min_turn_input_increment: int, + turn_input_increment: int, + min_output_len: int, + output_len: int, + max_turns: int, + base_seed: int, + request_timeout_s: int, + progress_state: Dict, + progress_lock: threading.Lock, +) -> List[Dict]: + """Run a single multi-turn dialogue session. Returns a list of per-turn + stat dicts (same schema as stream_one_turn output).""" + rng = random.Random(base_seed + session_id) + prompt, prompt_len = gen_session_initial_prompt(tokenizer, start_input_len, base_seed + session_id) + + per_turn: List[Dict] = [] + turn_idx = 0 + try: + while turn_idx < max_turns and prompt_len < max_input_len: + turn_output_len = rng.randint(min_output_len, output_len) + result = stream_one_turn( + tokenizer=tokenizer, + url=url, + model_name=model_name, + prompt=prompt, + prompt_token_len=prompt_len, + max_new_tokens=turn_output_len, + request_timeout_s=request_timeout_s, + ) + if result is None: + break + per_turn.append(result) + with progress_lock: + progress_state["finished_turns"] += 1 + print( + f"\rconc={progress_state['concurrency']} " + f"finished_turns={progress_state['finished_turns']} " + f"active_sessions={progress_state['active_sessions']}\033[K", + end="", + flush=True, + ) + turn_input_len = rng.randint(min_turn_input_increment, 
turn_input_increment) + prompt, prompt_len = append_turn_input( + tokenizer, + prompt, + prompt_len, + turn_output_len, + turn_input_len, + rng, + ) + turn_idx += 1 + finally: + with progress_lock: + progress_state["active_sessions"] -= 1 + return per_turn + + +def run_concurrency_level( + concurrency: int, + tokenizer, + url: str, + model_name: str, + start_input_len: int, + max_input_len: int, + min_turn_input_increment: int, + turn_input_increment: int, + min_output_len: int, + output_len: int, + max_turns: int, + base_seed: int, + request_timeout_s: int, +) -> Dict: + """Run one concurrency level. Returns the aggregated stats dict.""" + progress_state = { + "concurrency": concurrency, + "finished_turns": 0, + "active_sessions": concurrency, + } + progress_lock = threading.Lock() + + wall_start = time.time() + with ThreadPoolExecutor(max_workers=concurrency) as executor: + futures = [ + executor.submit( + run_session, + sid, + tokenizer, + url, + model_name, + start_input_len, + max_input_len, + min_turn_input_increment, + turn_input_increment, + min_output_len, + output_len, + max_turns, + base_seed, + request_timeout_s, + progress_state, + progress_lock, + ) + for sid in range(concurrency) + ] + session_results: List[List[Dict]] = [] + for fut in as_completed(futures): + session_results.append(fut.result()) + wall_end = time.time() + wall_time = max(wall_end - wall_start, 1e-9) + print() # newline after progress bar + + all_turns: List[Dict] = [t for s in session_results for t in s] + return summarize( + concurrency=concurrency, + turns=all_turns, + wall_time=wall_time, + num_sessions=concurrency, + max_turns=max_turns, + ) + + +def summarize( + concurrency: int, + turns: List[Dict], + wall_time: float, + num_sessions: int, + max_turns: int, +) -> Dict: + percentiles = [50, 75, 90, 95, 99] + out: Dict = { + "concurrency": concurrency, + "num_sessions": num_sessions, + "max_turns_per_session": max_turns, + "total_turns": len(turns), + "wall_time_s": 
round(wall_time, 4), + } + + if not turns: + out["error"] = "no successful turns" + return out + + ttfts_ms = [t["ttft"] * 1000.0 for t in turns] + # TPOT per turn = mean of decode_times (skip turns with <2 tokens) + tpots_ms: List[float] = [] + for t in turns: + if t["decode_times"]: + tpots_ms.append(1000.0 * sum(t["decode_times"]) / len(t["decode_times"])) + prompt_tokens = sum(t["prompt_tokens"] for t in turns) + completion_tokens = sum(t["completion_tokens"] for t in turns) + cached_tokens = sum(t["cached_tokens"] for t in turns) + cached_tokens_reported_turns = sum(1 for t in turns if t.get("cached_tokens_reported")) + usage_estimated_turns = sum(1 for t in turns if t.get("usage_estimated")) + total_tokens = prompt_tokens + completion_tokens + + qps = len(turns) / wall_time + tpm_total = total_tokens / wall_time * 60.0 + tpm_prompt = prompt_tokens / wall_time * 60.0 + tpm_completion = completion_tokens / wall_time * 60.0 + + out["QPS"] = round(qps, 4) + out["TPM_total"] = round(tpm_total, 2) + out["TPM_prompt"] = round(tpm_prompt, 2) + out["TPM_completion"] = round(tpm_completion, 2) + out["total_prompt_tokens"] = prompt_tokens + out["total_completion_tokens"] = completion_tokens + out["total_cached_prompt_tokens"] = cached_tokens + out["cached_tokens_reported_turns"] = cached_tokens_reported_turns + out["usage_estimated_turns"] = usage_estimated_turns + if cached_tokens_reported_turns > 0: + cache_hit_ratio = cached_tokens / prompt_tokens if prompt_tokens else 0.0 + out["cache_hit_ratio"] = round(cache_hit_ratio, 6) + else: + out["cache_hit_ratio"] = None + out["cache_hit_ratio_note"] = ( + "Server did not return usage.prompt_tokens_details.cached_tokens. " + "For vLLM OpenAI-compatible APIs, start the server with " + "--enable-prompt-tokens-details to expose cache-hit stats." 
+ ) + out["avg_prompt_tokens_per_turn"] = round(prompt_tokens / len(turns), 2) + out["avg_completion_tokens_per_turn"] = round(completion_tokens / len(turns), 2) + + ttft_pcts = np.percentile(ttfts_ms, percentiles) + out["TTFT_ms"] = {"mean": round(float(np.mean(ttfts_ms)), 3)} + for p, v in zip(percentiles, ttft_pcts): + out["TTFT_ms"][f"P{p}"] = round(float(v), 3) + + if tpots_ms: + tpot_pcts = np.percentile(tpots_ms, percentiles) + out["TPOT_ms"] = {"mean": round(float(np.mean(tpots_ms)), 3)} + for p, v in zip(percentiles, tpot_pcts): + out["TPOT_ms"][f"P{p}"] = round(float(v), 3) + else: + out["TPOT_ms"] = {"mean": None, "note": "all turns produced <2 tokens"} + + return out + + +def print_summary(summary: Dict) -> None: + print("=" * 80) + print( + f"Concurrency = {summary['concurrency']} sessions = {summary['num_sessions']} " + f"total_turns = {summary['total_turns']} wall_time = {summary['wall_time_s']}s" + ) + if "error" in summary: + print(f" ERROR: {summary['error']}") + return + print(f" QPS : {summary['QPS']}") + print(f" TPM (total) : {summary['TPM_total']}") + print(f" TPM (prompt) : {summary['TPM_prompt']}") + print(f" TPM (completion) : {summary['TPM_completion']}") + if summary["cache_hit_ratio"] is None: + print(" Cache hit ratio : n/a") + print(f" Cache hit note : {summary['cache_hit_ratio_note']}") + else: + print( + f" Cache hit ratio : {summary['cache_hit_ratio'] * 100:.2f}% " + f"({summary['total_cached_prompt_tokens']} / {summary['total_prompt_tokens']})" + ) + if summary.get("usage_estimated_turns"): + print(f" Usage estimated : {summary['usage_estimated_turns']} turns") + print(f" Avg prompt tokens : {summary['avg_prompt_tokens_per_turn']}") + print(f" Avg output tokens : {summary['avg_completion_tokens_per_turn']}") + ttft = summary["TTFT_ms"] + tpot = summary["TPOT_ms"] + print( + f" TTFT ms mean={ttft['mean']} P50={ttft.get('P50')} P90={ttft.get('P90')} " + f"P95={ttft.get('P95')} P99={ttft.get('P99')}" + ) + if tpot.get("mean") is 
None: + print(f" TPOT ms (n/a: {tpot.get('note')})") + else: + print( + f" TPOT ms mean={tpot['mean']} P50={tpot.get('P50')} P90={tpot.get('P90')} " + f"P95={tpot.get('P95')} P99={tpot.get('P99')}" + ) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--url", + type=str, + default="http://127.0.0.1:8000/v1/completions", + help="Streaming OpenAI completion endpoint. The benchmark relies on " + "the final SSE `usage` chunk to obtain cached_tokens.", + ) + parser.add_argument("--tokenizer_path", type=str, required=True) + parser.add_argument( + "--model_name", + type=str, + default=None, + help="Model name passed to the server. Defaults to --tokenizer_path.", + ) + parser.add_argument( + "--concurrency_levels", + type=str, + default="1,4,8,16,32,64,128,256", + help="Comma-separated list of concurrency levels to sweep.", + ) + parser.add_argument( + "--start_input_len", type=int, default=32768, help="Initial prompt length in tokens per session." + ) + parser.add_argument( + "--max_input_len", type=int, default=163840, help="Stop a session when its prompt exceeds this length." + ) + parser.add_argument( + "--turn_input_increment", + type=int, + default=2048, + help="Maximum new 'user' tokens sampled after each turn, on top of deterministic synthetic assistant tokens.", + ) + parser.add_argument( + "--min_turn_input_increment", type=int, default=512, help="Minimum new 'user' tokens sampled after each turn." + ) + parser.add_argument("--output_len", type=int, default=512, help="Maximum max_new_tokens sampled per turn.") + parser.add_argument("--min_output_len", type=int, default=128, help="Minimum max_new_tokens sampled per turn.") + parser.add_argument( + "--max_turns", + type=int, + default=64, + help="Hard cap on turns per session. 
The session also stops once " "prompt length reaches --max_input_len.", + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--request_timeout_s", type=int, default=3600) + parser.add_argument( + "--dump_file", + type=str, + default="", + help="If set, append the per-concurrency summary dict to this JSON file. " + "If the file already exists and is non-empty, it is read and printed.", + ) + + args = parser.parse_args() + + if args.min_output_len < 1: + raise ValueError("--min_output_len must be >= 1") + if args.min_output_len > args.output_len: + raise ValueError("--min_output_len must be <= --output_len") + if args.min_turn_input_increment < 0: + raise ValueError("--min_turn_input_increment must be >= 0") + if args.min_turn_input_increment > args.turn_input_increment: + raise ValueError("--min_turn_input_increment must be <= --turn_input_increment") + + if args.dump_file and os.path.exists(args.dump_file) and os.path.getsize(args.dump_file) > 0: + with open(args.dump_file, "r") as f: + print(json.dumps(json.load(f), indent=4)) + return + + seed_all(args.seed) + requested_model_name = args.model_name or args.tokenizer_path + model_name, model_name_note = resolve_model_name( + args.url, + requested_model_name, + explicit_model_name=args.model_name is not None, + ) + tokenizer = get_tokenizer(args.tokenizer_path) + concurrency_levels = [int(x) for x in args.concurrency_levels.split(",") if x.strip()] + + print(f"URL : {args.url}") + print(f"Model : {model_name}") + if model_name_note: + print(f"Model note : {model_name_note}") + print(f"Concurrency levels : {concurrency_levels}") + print(f"start_input_len : {args.start_input_len}") + print(f"max_input_len : {args.max_input_len}") + print(f"min_turn_input_increment: {args.min_turn_input_increment}") + print(f"turn_input_increment: {args.turn_input_increment}") + print(f"min_output_len : {args.min_output_len}") + print(f"output_len : {args.output_len}") + print(f"max_turns : 
{args.max_turns}") + + all_summaries: List[Dict] = [] + for concurrency in concurrency_levels: + summary = run_concurrency_level( + concurrency=concurrency, + tokenizer=tokenizer, + url=args.url, + model_name=model_name, + start_input_len=args.start_input_len, + max_input_len=args.max_input_len, + min_turn_input_increment=args.min_turn_input_increment, + turn_input_increment=args.turn_input_increment, + min_output_len=args.min_output_len, + output_len=args.output_len, + max_turns=args.max_turns, + base_seed=args.seed, + request_timeout_s=args.request_timeout_s, + ) + print_summary(summary) + all_summaries.append(summary) + + dump = { + "config": { + "url": args.url, + "model_name": model_name, + "requested_model_name": requested_model_name, + "tokenizer_path": args.tokenizer_path, + "concurrency_levels": concurrency_levels, + "start_input_len": args.start_input_len, + "max_input_len": args.max_input_len, + "min_turn_input_increment": args.min_turn_input_increment, + "turn_input_increment": args.turn_input_increment, + "min_output_len": args.min_output_len, + "output_len": args.output_len, + "max_turns": args.max_turns, + "seed": args.seed, + }, + "results": all_summaries, + } + print("\n" + "=" * 80) + print(json.dumps(dump, indent=4, ensure_ascii=False)) + if args.dump_file: + with open(args.dump_file, "w") as f: + json.dump(dump, f, indent=4, ensure_ascii=False) + print(f"\nResults dumped to {args.dump_file}") + + +if __name__ == "__main__": + main() diff --git a/test/benchmark/service/benchmark_qps.py b/test/benchmark/service/benchmark_qps.py index 8249ae2c49..43f60b91d3 100644 --- a/test/benchmark/service/benchmark_qps.py +++ b/test/benchmark/service/benchmark_qps.py @@ -31,6 +31,13 @@ def get_tokenizer( return tokenizer +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + def get_random_length(reqs_num: int, length: int, range_ratio: float) -> 
List[int]: lens = [] lens = np.random.randint( @@ -101,6 +108,11 @@ def get_custom_input_data(data_path, output_len, tokenizer, range_ratio): model_name = [] +sampling_config = { + "temperature": 1.0, + "top_p": 0.9, + "top_k": -1, +} # Minimal fix: one retry on transient network errors. @@ -116,7 +128,9 @@ async def async_post_stream_openai(url, prompt, max_new_tokens, session): "max_tokens": max_new_tokens, "ignore_eos": True, "stream": True, - "temperature": 0.0, + "temperature": sampling_config["temperature"], + "top_p": sampling_config["top_p"], + "top_k": sampling_config["top_k"], "best_of": 1, } headers = {"Content-Type": "application/json"} @@ -159,9 +173,12 @@ async def async_post_stream_lightllm(url, prompt, max_new_tokens, session): data = { "inputs": text_input, "parameters": { - "do_sample": False, + "do_sample": True, "ignore_eos": True, "max_new_tokens": max_new_tokens, + "temperature": sampling_config["temperature"], + "top_p": sampling_config["top_p"], + "top_k": sampling_config["top_k"], "add_special_tokens": False, }, } @@ -394,6 +411,12 @@ def main(): ) parser.add_argument("--num_clients", type=int, default=100) parser.add_argument("--tokenizer_path", type=str, default=None) + parser.add_argument( + "--model_name", + type=str, + default=None, + help="Model name passed to the server. 
Defaults to --tokenizer_path.", + ) parser.add_argument("--data_path", type=str, default=None) parser.add_argument("--input_num", type=int, default=2000) parser.add_argument("--input_qps", type=float, default=30.0) @@ -429,7 +452,7 @@ def main(): return assert args.tokenizer_path is not None - model_name.append(args.tokenizer_path) + model_name.append(args.model_name if args.model_name is not None else normalize_model_name(args.tokenizer_path)) seed_all(args.seed) url = args.url tokenizer = get_tokenizer(args.tokenizer_path) diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 919f379b96..f2c900af09 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -439,4 +439,9 @@ def tppart_model_infer(args, model_kvargs, batch_size, input_len, output_len, an ans_queue.put(True) - return + try: + ans_queue.close() + ans_queue.join_thread() + except Exception: + pass + os._exit(0) diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index 07ad52a132..72f06a919c 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -276,4 +276,10 @@ def tppart_model_infer(args, model_kvargs, batch_sizes, input_len, output_len, a dist.barrier() ans_queue.put(True) - return + + try: + ans_queue.close() + ans_queue.join_thread() + except Exception: + pass + os._exit(0) diff --git a/test/cpu_cache_kernel/test_speed.py b/test/cpu_cache_kernel/test_speed.py new file mode 100644 index 0000000000..254142050c --- /dev/null +++ b/test/cpu_cache_kernel/test_speed.py @@ -0,0 +1,276 @@ +""" +Speed benchmark for copy_cpu_cache_to_kv_buffer in linear_att_cpu_cache_copy.py. 
+ +Test configuration (matching the user's LinearAttCacheConfig): + tp_world_size=8, full_att_all_num_kv_heads=2, full_att_dtype=torch.bfloat16, + full_att_num_kv_heads=1, full_att_head_dim=256, + num_linear_k_heads=2, num_linear_v_heads=8, + head_linear_k_dim=128, head_linear_v_dim=128, + conv_kernel_size=4, linear_layer_num=36, + conv_state_dtype=torch.bfloat16, ssm_state_dtype=torch.float32, + full_attention_interval=4, all_layer_num=48 +""" + +import os +import json +import time +import triton +import torch +from easydict import EasyDict + +# --------------------------------------------------------------------------- +# Step 0 – set up environment args BEFORE any import that calls +# get_env_start_args() / LinearAttCacheConfig.load_from_args(). +# --------------------------------------------------------------------------- +_env_args = { + "cpu_cache_token_page_size": 2048 * 8, # big_page_token_num + "linear_att_hash_page_size": 2048, + "linear_att_page_block_num": 8, # 2048 * 8 == 16384 (same value as cpu_cache_token_page_size) + "data_type": "bfloat16", + "linear_att_ssm_data_type": "float32", + "model_dir": "/tmp/fake_model", # dummy – not used when config is built directly + "tp": 8, + "dp": 1, + "running_max_req_size": 2048, + "enable_cpu_cache": True, +} +os.environ["LIGHTLLM_START_ARGS"] = json.dumps(_env_args) + +# --------------------------------------------------------------------------- +# Step 1 – build LinearAttCacheConfig directly (avoids needing a real model dir) +# --------------------------------------------------------------------------- +from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig + +linear_config = LinearAttCacheConfig( + tp_world_size=8, + full_att_all_num_kv_heads=2, + full_att_dtype=torch.bfloat16, + full_att_num_kv_heads=1, + full_att_head_dim=256, + num_linear_k_heads=2, + num_linear_v_heads=8, + head_linear_k_dim=128, + head_linear_v_dim=128, + conv_kernel_size=4, + linear_layer_num=36, + conv_state_dtype=torch.bfloat16, + 
ssm_state_dtype=torch.float32, + full_attention_interval=4, + all_layer_num=48, +) +print(f"LinearAttCacheConfig:\n{linear_config}\n", flush=True) + +# --------------------------------------------------------------------------- +# Step 2 – derive sizes from the config +# --------------------------------------------------------------------------- +big_page_token_num = _env_args["cpu_cache_token_page_size"] # 16384 (2048 * 8) +full_att_layer_num = linear_config.all_layer_num // linear_config.full_attention_interval # 12 + +full_att_bytes = linear_config.get_cpu_cache_full_att_bytes() # per big page +conv_bytes = linear_config.get_cpu_cache_conv_bytes() +ssm_bytes = linear_config.get_cpu_cache_ssm_bytes() +total_bytes = full_att_bytes + conv_bytes + ssm_bytes +print( + f"Per-page bytes full_att={full_att_bytes:,} conv={conv_bytes:,} " f"ssm={ssm_bytes:,} total={total_bytes:,}", + flush=True, +) +total_bytes = linear_config.get_cpu_cache_big_page_bytes() + +# --------------------------------------------------------------------------- +# Step 3 – allocate tensors +# --------------------------------------------------------------------------- +grid_num = 8 +PAGE_NUM = 1 # number of big pages to copy per call +SEQ_LEN = 2048 * 8 # total sequence length in gpu_full_att_kv_state dim-1 +BIG_PAGE_COUNT = PAGE_NUM # big_page_buffer_ids length == page_indexes length + +# --- GPU tensors --- +mem_indexes = torch.arange(0, big_page_token_num * PAGE_NUM, dtype=torch.int64, device="cpu") +big_page_buffer_ids = torch.arange(0, BIG_PAGE_COUNT, dtype=torch.int64, device="cpu") +page_indexes = torch.arange(0, PAGE_NUM, dtype=torch.int32, device="cpu") + +gpu_full_att_kv_state = torch.empty( + ( + full_att_layer_num, + SEQ_LEN, + 2 * max(1, linear_config.full_att_num_kv_heads // linear_config.tp_world_size), + linear_config.full_att_head_dim, + ), + dtype=linear_config.full_att_dtype, + device="cuda", +) + +# --- CPU tensors --- +buffer_count = triton.cdiv(SEQ_LEN, big_page_token_num) + 2 # matches 
Qwen3NextMemManager + + +conv_shape = linear_config.get_conv_state_shape() +cpu_kv_conv_state = torch.empty( + (buffer_count, linear_config.linear_layer_num, *conv_shape), + dtype=linear_config.conv_state_dtype, + device="cpu", + pin_memory=True, +) + +ssm_shape = linear_config.get_ssm_state_shape() # (num_linear_v_heads, head_linear_k_dim, head_linear_v_dim) +cpu_kv_ssm_state = torch.empty( + (buffer_count, linear_config.linear_layer_num, *ssm_shape), + dtype=linear_config.ssm_state_dtype, + device="cpu", + pin_memory=True, +) + +# conv_shape = linear_config.get_conv_state_shape() +# cpu_kv_conv_state = torch.empty( +# (buffer_count, linear_config.linear_layer_num, *conv_shape), +# dtype=linear_config.conv_state_dtype, device="cuda", +# ) + +# ssm_shape = linear_config.get_ssm_state_shape() # (num_linear_v_heads, head_linear_k_dim, head_linear_v_dim) +# cpu_kv_ssm_state = torch.empty( +# (buffer_count, linear_config.linear_layer_num, *ssm_shape), +# dtype=linear_config.ssm_state_dtype, device="cuda", +# ) + + +# cpu_cache_tensor: [page_num, 1, 1, 1, total_bytes] +cpu_cache_tensor = torch.empty( + (PAGE_NUM, 1, 1, 1, total_bytes), + dtype=torch.uint8, + device="cpu", + pin_memory=True, +) + +# Move GPU tensors to CUDA +mem_indexes_cuda = mem_indexes.cuda(non_blocking=True) +big_page_buffer_ids_cuda = big_page_buffer_ids.cuda(non_blocking=True) +page_indexes_cuda = page_indexes.cuda(non_blocking=True) +gpu_full_att_kv_state = gpu_full_att_kv_state.cuda(non_blocking=True) + +torch.cuda.synchronize() +print("All tensors allocated and moved to GPU.\n", flush=True) + +# --------------------------------------------------------------------------- +# Step 4 – import and warm-up the triton kernel +# --------------------------------------------------------------------------- +from lightllm.common.basemodel.triton_kernel.linear_att_cpu_cache_copy import ( + copy_cpu_cache_to_kv_buffer, +) + +print("Warming up …", flush=True) +copy_cpu_cache_to_kv_buffer( + 
mem_indexes=mem_indexes_cuda, + big_page_buffer_ids=big_page_buffer_ids_cuda, + page_indexes=page_indexes_cuda, + gpu_full_att_kv_state=gpu_full_att_kv_state, + cpu_kv_conv_state=cpu_kv_conv_state, + cpu_kv_ssm_state=cpu_kv_ssm_state, + cpu_cache_tensor=cpu_cache_tensor, + tp_rank=0, + tp_world_size=linear_config.tp_world_size, + big_page_token_num=big_page_token_num, + linear_config=linear_config, + grid_num=grid_num, +) +torch.cuda.synchronize() +print("Warm-up done.\n", flush=True) + +# --------------------------------------------------------------------------- +# Step 5 – benchmark +# --------------------------------------------------------------------------- +WARMUP_ITERS = 10 +BENCH_ITERS = 100 + +print(f"Benchmarking ({BENCH_ITERS} iterations, {PAGE_NUM} pages / {big_page_token_num} tokens each) …", flush=True) + +# Warm-up +for _ in range(WARMUP_ITERS): + copy_cpu_cache_to_kv_buffer( + mem_indexes=mem_indexes_cuda, + big_page_buffer_ids=big_page_buffer_ids_cuda, + page_indexes=page_indexes_cuda, + gpu_full_att_kv_state=gpu_full_att_kv_state, + cpu_kv_conv_state=cpu_kv_conv_state, + cpu_kv_ssm_state=cpu_kv_ssm_state, + cpu_cache_tensor=cpu_cache_tensor, + tp_rank=0, + tp_world_size=linear_config.tp_world_size, + big_page_token_num=big_page_token_num, + linear_config=linear_config, + grid_num=grid_num, + ) +torch.cuda.synchronize() + +# Timed runs +times = [] +for _ in range(BENCH_ITERS): + torch.cuda.synchronize() + t0 = time.perf_counter() + copy_cpu_cache_to_kv_buffer( + mem_indexes=mem_indexes_cuda, + big_page_buffer_ids=big_page_buffer_ids_cuda, + page_indexes=page_indexes_cuda, + gpu_full_att_kv_state=gpu_full_att_kv_state, + cpu_kv_conv_state=cpu_kv_conv_state, + cpu_kv_ssm_state=cpu_kv_ssm_state, + cpu_cache_tensor=cpu_cache_tensor, + tp_rank=0, + tp_world_size=linear_config.tp_world_size, + big_page_token_num=big_page_token_num, + linear_config=linear_config, + grid_num=grid_num, + ) + torch.cuda.synchronize() + t1 = time.perf_counter() + 
times.append(t1 - t0) + +# --------------------------------------------------------------------------- +# Step 6 – report +# --------------------------------------------------------------------------- +import statistics + +times_ms = [t * 1e3 for t in times] +total_tokens = PAGE_NUM * big_page_token_num + +# Calculate head_scale_size (same logic as in copy_cpu_cache_to_kv_buffer) +if linear_config.full_att_all_num_kv_heads % linear_config.tp_world_size == 0: + head_scale_size = 1 +else: + head_scale_size = linear_config.tp_world_size // linear_config.full_att_all_num_kv_heads + +# Each TP rank copies: +# - full_att_bytes / head_scale_size (full attention is sharded by head_scale_size) +# - conv_bytes / tp_world_size (conv state is sharded by tp_rank) +# - ssm_bytes / tp_world_size (ssm state is sharded by tp_rank) +full_att_bytes = linear_config.get_cpu_cache_full_att_bytes() +conv_bytes = linear_config.get_cpu_cache_conv_bytes() +ssm_bytes = linear_config.get_cpu_cache_ssm_bytes() + +bytes_per_page_per_tp = ( + full_att_bytes + * max(1, linear_config.full_att_all_num_kv_heads // linear_config.tp_world_size) + / linear_config.full_att_all_num_kv_heads + + conv_bytes // linear_config.tp_world_size + + ssm_bytes // linear_config.tp_world_size +) +total_bytes_copied = PAGE_NUM * bytes_per_page_per_tp + +print() +print("=" * 60) +print(f" copy_cpu_cache_to_kv_buffer speed benchmark") +print("=" * 60) +print(f" Pages / call : {PAGE_NUM}") +print(f" Tokens / page : {big_page_token_num}") +print(f" Total tokens / call : {total_tokens}") +print(f" Bytes / page (total) : {total_bytes:,}") +print(f" Bytes / page (per TP) : {bytes_per_page_per_tp:,}") +print(f" Total bytes / call : {total_bytes_copied:,} ({total_bytes_copied / 1024**3:.3f} GB)") +print(f" Iterations : {BENCH_ITERS}") +print(f" Mean latency : {statistics.mean(times_ms):.3f} ms") +print(f" Median latency : {statistics.median(times_ms):.3f} ms") +print(f" Std latency : {statistics.stdev(times_ms):.3f} ms") 
+print(f" Min latency : {min(times_ms):.3f} ms") +print(f" Max latency : {max(times_ms):.3f} ms") +print(f" Throughput (tokens/s) : {total_tokens / statistics.mean(times_ms) * 1e3:,.0f}") +print(f" Throughput (GB/s) : {total_bytes_copied / 1024**3 / statistics.mean(times_ms) * 1e3:.3f}") +print("=" * 60) diff --git a/test/performance/bench_image_verify.py b/test/performance/bench_image_verify.py new file mode 100644 index 0000000000..e04f1117a7 --- /dev/null +++ b/test/performance/bench_image_verify.py @@ -0,0 +1,128 @@ +"""Benchmark: find the right LIGHTLLM_IMAGE_VERIFY_WORKERS value. + +Methodology: + - Generate N independent JPEGs once (random pixels so libjpeg can't cheat). + - For each candidate pool size, create a FRESH ThreadPoolExecutor of that size, + submit all N decodes concurrently (no semaphore), measure wall time. + - This faithfully simulates production: at peak, many requests pile into + run_in_executor at once and the pool size is the real bottleneck. + +This lets us compare different LIGHTLLM_IMAGE_VERIFY_WORKERS settings in one run. 
+ +Usage: + python test/performance/bench_image_verify.py + python test/performance/bench_image_verify.py --size 4096 --num 128 --pool_sizes 1,2,4,8,16,32,64 +""" +import argparse +import asyncio +import os +import time +from concurrent.futures import ThreadPoolExecutor +from io import BytesIO +from typing import List + +import numpy as np +from PIL import Image + +from lightllm.server.multimodal_params import _verify_image_bytes + + +def make_big_jpeg(size: int, seed: int) -> bytes: + """Random-noise JPEG so decode time is real (flat images decode too fast).""" + rng = np.random.default_rng(seed) + arr = rng.integers(0, 256, (size, size, 3), dtype=np.uint8) + buf = BytesIO() + Image.fromarray(arr).save(buf, format="JPEG", quality=85) + return buf.getvalue() + + +def bench_serial(images: List[bytes]) -> float: + t0 = time.perf_counter() + for img in images: + _verify_image_bytes(img) + return time.perf_counter() - t0 + + +def bench_pool(images: List[bytes], pool_size: int) -> float: + """Fresh pool of `pool_size`, submit all images concurrently, wait, time it.""" + pool = ThreadPoolExecutor(max_workers=pool_size, thread_name_prefix=f"bench-{pool_size}") + try: + # Pre-warm threads so we don't time thread spawn-up + list(pool.map(lambda _: None, range(pool_size))) + + async def run(): + loop = asyncio.get_running_loop() + futs = [loop.run_in_executor(pool, _verify_image_bytes, img) for img in images] + await asyncio.gather(*futs) + + t0 = time.perf_counter() + asyncio.run(run()) + return time.perf_counter() - t0 + finally: + pool.shutdown(wait=True) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--size", type=int, default=2048, help="image edge length, e.g. 
2048/4096") + parser.add_argument("--num", type=int, default=64, help="total images to decode per run") + parser.add_argument( + "--pool_sizes", + default="1,2,4,8,16,32,64", + help="comma-separated pool sizes (LIGHTLLM_IMAGE_VERIFY_WORKERS candidates)", + ) + parser.add_argument("--warmup", type=int, default=4) + parser.add_argument("--repeat", type=int, default=2, help="repeats per pool size, takes the best") + args = parser.parse_args() + + print(f"CPU count : {os.cpu_count()}") + print(f"Image size : {args.size}x{args.size}") + print(f"Images per run : {args.num}") + print(f"Pool sizes to test : {args.pool_sizes}") + print(f"Repeats per pool : {args.repeat} (best time wins)\n") + + print("Generating distinct test images ...") + images = [make_big_jpeg(args.size, seed=i) for i in range(args.num)] + avg_kb = sum(len(b) for b in images) / len(images) / 1024 + print(f" per-image encoded size ~ {avg_kb:.1f} KB\n") + + # Warmup libjpeg / page faults + for _ in range(args.warmup): + _verify_image_bytes(images[0]) + + # Baseline + serial_times = [bench_serial(images) for _ in range(args.repeat)] + serial_t = min(serial_times) + print( + f"[serial] {args.num} images in {serial_t * 1000:.1f} ms " + f"=> {args.num / serial_t:.1f} img/s, {serial_t / args.num * 1000:.2f} ms/img\n" + ) + + # Sweep pool size + print("[threaded] — vary LIGHTLLM_IMAGE_VERIFY_WORKERS") + print(f" {'pool':>6} | {'time(ms)':>10} | {'img/s':>8} | {'speedup':>8} | {'efficiency':>10}") + print(f" {'-' * 6}-+-{'-' * 10}-+-{'-' * 8}-+-{'-' * 8}-+-{'-' * 10}") + rows = [] + for p in [int(x) for x in args.pool_sizes.split(",")]: + times = [bench_pool(images, p) for _ in range(args.repeat)] + t = min(times) + ips = args.num / t + speedup = serial_t / t + eff = speedup / p + rows.append((p, t, ips, speedup, eff)) + print(f" {p:>6} | {t * 1000:>10.1f} | {ips:>8.1f} | {speedup:>7.2f}x | {eff * 100:>9.1f}%") + + # Pick the sweet spot: largest speedup before efficiency drops below 50% + best = max(rows, 
key=lambda r: r[3]) + knee = next((r for r in rows if r[4] < 0.5), rows[-1]) + print(f"\nBest absolute throughput : pool={best[0]} ({best[2]:.1f} img/s, {best[3]:.2f}x)") + print(f"Diminishing-returns knee : pool={knee[0]} (efficiency drops <50% beyond here)") + print("\nHints:") + print(" - efficiency = speedup / pool_size. ~100% means perfect linear scaling.") + print(" - You usually want the smallest pool size that still gets >80% of peak throughput,") + print(" since extra threads only add scheduling + memory pressure.") + print(f" - Recommended: export LIGHTLLM_IMAGE_VERIFY_WORKERS={knee[0]}") + + +if __name__ == "__main__": + main() diff --git a/test/start_scripts/README.md b/test/start_scripts/README.md index 8ed44a2753..ffaad87f7f 100644 --- a/test/start_scripts/README.md +++ b/test/start_scripts/README.md @@ -27,8 +27,6 @@ This directory contains various startup scripts for deploying DeepSeek models wi - `multi_pd_master/config_server.sh` - Configuration server - `multi_pd_master/pd_master_1.sh` - PD Master 1 - `multi_pd_master/pd_master_2.sh` - PD Master 2 -- `multi_pd_master/pd_prefill.sh` - Prefill service -- `multi_pd_master/pd_decode.sh` - Decode service ## Usage Instructions @@ -89,9 +87,8 @@ sh multi_pd_master/config_server.sh sh multi_pd_master/pd_master_1.sh sh multi_pd_master/pd_master_2.sh -# Step 3: Start Prefill and Decode services -sh multi_pd_master/pd_prefill.sh -sh multi_pd_master/pd_decode.sh +# Step 3: Start Prefill and Decode services with the prefill/decode run modes. +# Multi-PD startup scripts for these nodes are not provided in this directory. 
``` ## Configuration Guide @@ -99,7 +96,7 @@ sh multi_pd_master/pd_decode.sh ### Environment Variables - `LOADWORKER`: Model loading thread count, recommended 8-18 -- `DISABLE_KV_TRANS_USE_P2P`: Disable P2P communication optimization to transfer kv data +- `LIGHTLLM_PD_KV_TRANSPORT_BACKEND`: KV transporter backend for PD disaggregation, `nixl` by default; set to `nccl` to use the NCCL data plane. - `CUDA_VISIBLE_DEVICES`: Specify GPU devices to use ### Important Parameters @@ -198,4 +195,4 @@ python benchmark_client.py \ 2. Adjust parameters according to actual hardware configuration 3. Ensure network environment meets multi-node deployment requirements 4. Recommend thorough testing before production deployment -5. Regularly monitor service status and performance metrics \ No newline at end of file +5. Regularly monitor service status and performance metrics diff --git a/test/start_scripts/qwen35/122b.sh b/test/start_scripts/qwen35/122b.sh new file mode 100644 index 0000000000..79dfe07157 --- /dev/null +++ b/test/start_scripts/qwen35/122b.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash +set -euo pipefail + +ARGS=( + --model_dir /mtc/models/Qwen3.5-122B-A10B + --tp 8 + --port 8088 + --max_req_total_len 262144 + --linear_att_hash_page_size 8192 + --linear_att_page_block_num 8 +) + +if [[ "${ENABLE_DEEPEP:-0}" == "1" || "${ENABLE_DEEPEP:-}" == "true" ]]; then + ARGS+=(--quant_cfg ../../advanced_config/mixed_quantization/qwen3_5-122b-moe-only-fp8.yaml --enable_ep_moe --dp 8 --batch_max_tokens 4096 --graph_max_batch_size 64 --chunked_prefill_size 2048 --mem_fraction 0.8 --linear_att_cache_size 300) +elif [[ -n "${QUANT_TYPE:-}" ]]; then + ARGS+=(--mem_fraction 0.85 --quant_type "${QUANT_TYPE}" --linear_att_cache_size 3000) +else + ARGS+=(--mem_fraction 0.85 --linear_att_cache_size 3000) +fi + +ENABLE_CPU_CACHE_ARG=false +ENABLE_DISK_CACHE_ARG=false + +if [[ "${ENABLE_DISK_CACHE:-0}" == "1" || "${ENABLE_DISK_CACHE:-}" == "true" || -n "${DISK_CACHE_STORAGE_SIZE:-}" || -n 
"${DISK_CACHE_DIR:-}" ]]; then + ENABLE_DISK_CACHE_ARG=true +fi + +if [[ "${ENABLE_CPU_CACHE:-0}" == "1" || "${ENABLE_CPU_CACHE:-}" == "true" || -n "${CPU_CACHE_STORAGE_SIZE:-}" || "${ENABLE_DISK_CACHE_ARG}" == "true" ]]; then + ENABLE_CPU_CACHE_ARG=true +fi + +if [[ "${ENABLE_CPU_CACHE_ARG}" == "true" ]]; then + ARGS+=(--enable_cpu_cache) +fi + +if [[ -n "${CPU_CACHE_STORAGE_SIZE:-}" ]]; then + ARGS+=(--cpu_cache_storage_size "${CPU_CACHE_STORAGE_SIZE}") +fi + +if [[ "${ENABLE_DISK_CACHE_ARG}" == "true" ]]; then + ARGS+=(--enable_disk_cache) +fi + +if [[ -n "${DISK_CACHE_STORAGE_SIZE:-}" ]]; then + ARGS+=(--disk_cache_storage_size "${DISK_CACHE_STORAGE_SIZE}") +fi + +if [[ -n "${DISK_CACHE_DIR:-}" ]]; then + ARGS+=(--disk_cache_dir "${DISK_CACHE_DIR}") +fi + +LOADWORKER="${LOADWORKER:-18}" CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}" \ + python -m lightllm.server.api_server "${ARGS[@]}" \ No newline at end of file diff --git a/test/start_scripts/single_pd_master/pd_decode.sh b/test/start_scripts/single_pd_master/pd_decode.sh index dac7a6dac6..eb45622f3e 100644 --- a/test/start_scripts/single_pd_master/pd_decode.sh +++ b/test/start_scripts/single_pd_master/pd_decode.sh @@ -4,6 +4,11 @@ # sh pd_decode.sh export host=$1 export pd_master_ip=$2 + +export UCX_NET_DEVICES=$(ibv_devinfo | grep 'hca_id:' | grep -v -E 'mlx5_8|mlx5_9' | awk '{print $2":1"}' | paste -sd, -) +export UCX_LOG_LEVEL=info +export UCX_TLS=rc,cuda,gdr_copy + nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ @@ -17,4 +22,4 @@ LOADWORKER=18 python -m lightllm.server.api_server \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines -#--enable_decode_microbatch_overlap \ No newline at end of file +#--enable_decode_microbatch_overlap diff --git a/test/start_scripts/single_pd_master/pd_nixl_decode.sh 
b/test/start_scripts/single_pd_master/pd_nixl_decode.sh deleted file mode 100644 index 4b3fd0bc4e..0000000000 --- a/test/start_scripts/single_pd_master/pd_nixl_decode.sh +++ /dev/null @@ -1,25 +0,0 @@ -# PD decode mode for deepseek R1 (DP+EP) on H200 -# host: the host of the current node -# pd_master_ip: the ip of the pd master -# sh pd_decode.sh -export host=$1 -export pd_master_ip=$2 - -export UCX_NET_DEVICES=$(ibv_devinfo | grep 'hca_id:' | grep -v -E 'mlx5_8|mlx5_9' | awk '{print $2":1"}' | paste -sd, -) -export UCX_LOG_LEVEL=info -export UCX_TLS=rc,cuda,gdr_copy - -nvidia-cuda-mps-control -d -LOADWORKER=18 python -m lightllm.server.api_server \ ---model_dir /path/DeepSeek-R1 \ ---run_mode "nixl_decode" \ ---tp 8 \ ---dp 8 \ ---host $host \ ---port 8121 \ ---nccl_port 12322 \ ---enable_ep_moe \ ---pd_master_ip $pd_master_ip \ ---pd_master_port 60011 -# if you want to enable microbatch overlap, you can uncomment the following lines -#--enable_decode_microbatch_overlap \ No newline at end of file diff --git a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh deleted file mode 100644 index f415919f90..0000000000 --- a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh +++ /dev/null @@ -1,27 +0,0 @@ -# PD prefill mode for deepseek R1 (DP+EP) on H200 -# host: the host of the current node -# pd_master_ip: the ip of the pd master -# sh pd_prefill.sh - -### nixl pd mode used -export UCX_NET_DEVICES=$(ibv_devinfo | grep 'hca_id:' | grep -v -E 'mlx5_8|mlx5_9' | awk '{print $2":1"}' | paste -sd, -) -export UCX_LOG_LEVEL=info -export UCX_TLS=rc,cuda,gdr_copy - -export host=$1 -export pd_master_ip=$2 -nvidia-cuda-mps-control -d -LOADWORKER=18 python -m lightllm.server.api_server \ ---model_dir /path/DeepSeek-R1 \ ---run_mode "nixl_prefill" \ ---tp 8 \ ---dp 8 \ ---host $host \ ---port 8019 \ ---nccl_port 2732 \ ---enable_ep_moe \ ---disable_cudagraph \ ---pd_master_ip $pd_master_ip \ ---pd_master_port 60011 
-# if you want to enable microbatch overlap, you can uncomment the following lines -#--enable_prefill_microbatch_overlap \ No newline at end of file diff --git a/test/start_scripts/single_pd_master/pd_prefill.sh b/test/start_scripts/single_pd_master/pd_prefill.sh index 6bde9ef32c..08f5300bc6 100644 --- a/test/start_scripts/single_pd_master/pd_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_prefill.sh @@ -2,6 +2,12 @@ # host: the host of the current node # pd_master_ip: the ip of the pd master # sh pd_prefill.sh + +### PD mode using the default KV transport +export UCX_NET_DEVICES=$(ibv_devinfo | grep 'hca_id:' | grep -v -E 'mlx5_8|mlx5_9' | awk '{print $2":1"}' | paste -sd, -) +export UCX_LOG_LEVEL=info +export UCX_TLS=rc,cuda,gdr_copy + export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d @@ -13,9 +19,9 @@ LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8019 \ --nccl_port 2732 \ +--enable_ep_moe \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ ---pd_master_port 60011 \ ---enable_ep_moe +--pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines -#--enable_prefill_microbatch_overlap \ No newline at end of file +#--enable_prefill_microbatch_overlap diff --git a/test/test_api/test_anthropic_extra_body.py b/test/test_api/test_anthropic_extra_body.py new file mode 100644 index 0000000000..99f08a01a6 --- /dev/null +++ b/test/test_api/test_anthropic_extra_body.py @@ -0,0 +1,193 @@ +"""Unit test for Anthropic -> OpenAI request translation with extra_body. + +Verifies that ``extra_body.chat_template_kwargs`` (and other backend-specific +fields nested under ``extra_body`` per OpenAI SDK convention) survive the +/v1/messages request translation, so clients can opt out of model-default +thinking modes on engines that expose the toggle through +ChatCompletionRequest.chat_template_kwargs. + +No server required — calls the pure translation helper directly. 
+""" + +import asyncio +import pytest +import ujson as json + +pytest.importorskip("litellm") + +from lightllm.server.api_anthropic import _anthropic_to_chat_request, _openai_sse_to_anthropic_events + + +def _base_body(): + return { + "model": "test-model", + "max_tokens": 32, + "messages": [{"role": "user", "content": "hi"}], + } + + +def test_extra_body_chat_template_kwargs_forwarded(): + body = _base_body() + body["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}} + + chat_dict, _ = _anthropic_to_chat_request(body) + + assert chat_dict.get("chat_template_kwargs") == {"enable_thinking": False} + assert "extra_body" not in chat_dict + + +def test_extra_body_multiple_fields_forwarded(): + body = _base_body() + body["extra_body"] = { + "chat_template_kwargs": {"enable_thinking": False}, + "do_sample": False, + "top_k": 5, + } + + chat_dict, _ = _anthropic_to_chat_request(body) + + assert chat_dict.get("chat_template_kwargs") == {"enable_thinking": False} + assert chat_dict.get("do_sample") is False + assert chat_dict.get("top_k") == 5 + + +def test_top_level_openai_field_beats_extra_body_duplicate(): + # If a field ends up in openai_dict via the Anthropic->OpenAI translation + # AND the same key appears in extra_body, the translation path wins. 
+ body = _base_body() + body["temperature"] = 0.1 # translated by litellm -> openai_dict["temperature"] = 0.1 + body["extra_body"] = {"temperature": 0.9} + + chat_dict, _ = _anthropic_to_chat_request(body) + + assert chat_dict.get("temperature") == 0.1 + + +def test_missing_extra_body_is_noop(): + body = _base_body() + chat_dict, _ = _anthropic_to_chat_request(body) + assert "extra_body" not in chat_dict + assert "chat_template_kwargs" not in chat_dict + + +def test_non_dict_extra_body_is_ignored(): + body = _base_body() + body["extra_body"] = "not-a-dict" + chat_dict, _ = _anthropic_to_chat_request(body) + assert "extra_body" not in chat_dict + + +# Helpers for streaming test +def _chunk(delta, finish_reason=None, usage=None): + obj = {"choices": [{"delta": delta, "finish_reason": finish_reason}]} + if usage is not None: + obj["usage"] = usage + return f"data: {json.dumps(obj)}\n\n" + + +def test_interleaved_tool_calls_do_not_emit_against_closed_block(): + """Deltas for tool-call idx=1 arriving after idx=0 started must not + stream into the (now-closed) idx=0 block.""" + + async def chunks(): + yield _chunk( + { + "tool_calls": [ + {"index": 0, "id": "call_a", "function": {"name": "fn_a", "arguments": '{"x":1'}}, + ] + } + ) + yield _chunk( + { + "tool_calls": [ + {"index": 1, "id": "call_b", "function": {"name": "fn_b", "arguments": '{"y":2'}}, + ] + } + ) + yield _chunk( + { + "tool_calls": [ + {"index": 0, "function": {"arguments": "}"}}, + ] + } + ) + yield _chunk({}, finish_reason="tool_calls", usage={"prompt_tokens": 3, "completion_tokens": 4}) + + async def run(): + out = [] + async for ev in _openai_sse_to_anthropic_events(chunks(), "m", "msg_x"): + out.append(ev.decode("utf-8")) + return out + + events = asyncio.run(run()) + index_of_delta = [] + currently_open = None + for raw in events: + lines = raw.strip().split("\n") + etype = lines[0].split(": ", 1)[1] + data = json.loads(lines[1].split(": ", 1)[1]) + if etype == "content_block_start": + 
currently_open = data["index"] + elif etype == "content_block_stop": + currently_open = None + elif etype == "content_block_delta": + assert ( + currently_open == data["index"] + ), f"delta for index {data['index']} but open block is {currently_open}" + index_of_delta.append(data["index"]) + assert index_of_delta, "no deltas observed" + + +def test_chat_response_translation_failure_returns_valid_json(): + """If response translation raises, the error path must return a clean + Anthropic-shaped JSONResponse — not a JSONResponse wrapped in another + JSONResponse.""" + from fastapi.responses import JSONResponse + + from lightllm.server import api_anthropic + + # Exercise the helper directly; the bug in anthropic_messages_impl was + # wrapping this return value in another JSONResponse. + resp = api_anthropic._anthropic_error_response(api_anthropic.HTTPStatus.INTERNAL_SERVER_ERROR, "synthetic") + assert isinstance(resp, JSONResponse) + body = bytes(resp.body).decode("utf-8") + assert '"type":"error"' in body + assert '"message":"synthetic"' in body + assert resp.status_code == 500 + + +def test_unknown_fields_emit_debug_log(caplog): + """Silently-dropped Anthropic fields should at least emit a debug log so + users can trace 'my metadata isn't propagating' without adding prints.""" + import logging + + from lightllm.server.api_anthropic import _anthropic_to_chat_request + + body = { + "model": "m", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 8, + "metadata": {"user_id": "abc"}, + "anthropic_version": "2023-06-01", + } + # Set logger to DEBUG so caplog can capture it + logger = logging.getLogger("lightllm.server.api_anthropic") + logger.setLevel(logging.DEBUG) + + # Manually add caplog's handler to the logger to intercept logs + # (works even with propagate=False) + caplog_handler = logging.Handler() + caplog_handler.emit = lambda record: caplog.records.append(record) + logger.addHandler(caplog_handler) + + try: + try: + 
_anthropic_to_chat_request(body) + except RuntimeError: + import pytest + + pytest.skip("litellm not available; cannot exercise drop path") + joined = "\n".join(rec.getMessage() for rec in caplog.records) + assert "metadata" in joined or "anthropic_version" in joined + finally: + logger.removeHandler(caplog_handler) diff --git a/test/test_api/test_image_verify_api.py b/test/test_api/test_image_verify_api.py new file mode 100644 index 0000000000..c1583fa57d --- /dev/null +++ b/test/test_api/test_image_verify_api.py @@ -0,0 +1,72 @@ +"""验证残缺图片在 OpenAI /v1/chat/completions 接口被前端拦截为 4xx。 + +启动 server: + python -m lightllm.server.api_server --port 8000 --model_dir --tp 1 + +运行: + python test/test_api/test_image_verify_api.py +""" +import argparse +import base64 +import os +from io import BytesIO + +import requests +from PIL import Image + + +def make_jpeg(w=512, h=512) -> bytes: + buf = BytesIO() + Image.new("RGB", (w, h), color=(255, 0, 0)).save(buf, format="JPEG", quality=85) + return buf.getvalue() + + +def truncate(data: bytes, ratio: float = 0.3) -> bytes: + return data[: int(len(data) * (1 - ratio))] + + +def data_url(img_bytes: bytes) -> str: + return "data:image/jpeg;base64," + base64.b64encode(img_bytes).decode("ascii") + + +def call(url: str, model: str, img_bytes: bytes): + payload = { + "model": model, + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": data_url(img_bytes)}}, + {"type": "text", "text": "Describe this image."}, + ], + } + ], + "max_tokens": 16, + "temperature": 0.0, + } + return requests.post(f"{url}/v1/chat/completions", json=payload, timeout=30) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--url", default="http://127.0.0.1:8000") + parser.add_argument("--model", default="your_model_name") + args = parser.parse_args() + + cases = [ + ("intact JPEG", make_jpeg(), 200), + ("truncated JPEG", truncate(make_jpeg(1024, 1024), 0.3), 400), + ("garbage bytes", 
os.urandom(4096), 400), + ] + + for name, img, expected in cases: + resp = call(args.url, args.model, img) + ok = resp.status_code == expected + print(f"[{'OK' if ok else 'FAIL'}] {name:18s} -> {resp.status_code} (expected {expected})") + if not ok: + print(f" body: {resp.text[:200]}") + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/test/test_api/test_invalid_token_ids.py b/test/test_api/test_invalid_token_ids.py new file mode 100644 index 0000000000..82923f4613 --- /dev/null +++ b/test/test_api/test_invalid_token_ids.py @@ -0,0 +1,129 @@ +""" +Smoke test for the invalid_token_ids feature (logit_bias path). + +Hits the lightllm-native /generate endpoint, which forwards `logit_bias` keys +into the SamplingParams `invalid_token_ids` field. The kernel masks those +ids to -inf, so they must never appear in the output. + +Run: + python test/test_api/test_invalid_token_ids.py + +Assumes the server is up on http://localhost:8000 and the model tokenizer +is Qwen3.5 (matches the launch command in the PR description). +""" + + +import json +import sys +from typing import Dict, List, Tuple + +import requests +from transformers import AutoTokenizer + + +URL = "http://localhost:8000/generate" +HEADERS = {"Content-Type": "application/json"} +MODEL_DIR = "/nvme/models/Qwen3.5-35B-A3B" + +# Stay under INVALID_TOKEN_IDS_MAX_LENGTH (default 10). 
+BLOCK_WORDS = ["the", " the", "The", " is", " a", " of", " and"] + + +def _post_generate(prompt: str, parameters: dict, timeout: int = 120) -> dict: + payload = {"inputs": prompt, "parameters": parameters} + resp = requests.post(URL, headers=HEADERS, data=json.dumps(payload), timeout=timeout) + if resp.status_code != 200: + raise RuntimeError(f"{resp.status_code} {resp.text}") + return resp.json() + + +def _generated_text(resp: dict) -> str: + text = resp["generated_text"] + return text[0] if isinstance(text, list) else text + + +def _token_ids_from_details(resp: dict) -> List[int]: + tokens = resp.get("tokens", []) + if tokens and isinstance(tokens[0], list): + tokens = tokens[0] + out: List[int] = [] + for tok in tokens: + tid = tok.get("id") + if tid is not None: + out.append(int(tid)) + return out + + +def _build_block_map(tokenizer) -> Tuple[Dict[int, float], Dict[int, str]]: + """Map token id -> bias (-100 = block) and id -> source word.""" + bias: Dict[int, float] = {} + source: Dict[int, str] = {} + for w in BLOCK_WORDS: + ids = tokenizer.encode(w, add_special_tokens=False) + for tid in ids: + bias.setdefault(tid, -100.0) + source.setdefault(tid, w) + return bias, source + + +def test_invalid_token_ids(): + print("[1/3] Loading tokenizer...", flush=True) + tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True) + + bias_map, source_map = _build_block_map(tokenizer) + blocked_ids = sorted(bias_map.keys()) + print(f" Blocking {len(blocked_ids)} token ids: {blocked_ids}") + for tid in blocked_ids: + print(f" {tid:6d} <- {source_map[tid]!r}") + + prompt = "Write three short English sentences about San Francisco. " "Mention the bay, the bridge and the weather." 
+ base_params = { + "do_sample": False, + "temperature": 1.0, + "max_new_tokens": 80, + "return_details": True, + } + + print("[2/3] Baseline request (no logit_bias)...", flush=True) + base_resp = _post_generate(prompt, dict(base_params)) + base_text = _generated_text(base_resp) + base_ids = _token_ids_from_details(base_resp) + print(f" text: {base_text!r}") + base_hits = [tid for tid in base_ids if tid in bias_map] + print(f" blocked-tokens that appeared in baseline: {len(base_hits)} ({base_hits[:10]})") + + print("[3/3] logit_bias request...", flush=True) + bias_params = dict(base_params) + bias_params["logit_bias"] = {str(k): v for k, v in bias_map.items()} + biased_resp = _post_generate(prompt, bias_params) + biased_text = _generated_text(biased_resp) + biased_ids = _token_ids_from_details(biased_resp) + print(f" text: {biased_text!r}") + biased_hits = [(tid, source_map[tid]) for tid in biased_ids if tid in bias_map] + print(f" blocked-tokens that appeared with bias: {len(biased_hits)} ({biased_hits[:10]})") + + failures = [] + if biased_hits: + failures.append(f"Blocked token ids leaked into biased output: {biased_hits}") + + # Sanity check: the baseline should have produced at least one of the blocked tokens. + # If it did not, the test is uninformative (but still passes the strict check above). + if not base_hits: + print( + " WARNING: baseline did not produce any of the target tokens; " + "the assertion below is trivially satisfied." 
+ ) + + if biased_text == base_text: + failures.append("Biased output is identical to baseline; bias may not be applied.") + + if failures: + for f in failures: + print(f"FAIL: {f}", file=sys.stderr) + sys.exit(1) + + print("PASS: invalid_token_ids correctly suppressed blocked tokens.") + + +if __name__ == "__main__": + test_invalid_token_ids() diff --git a/unit_tests/common/basemodel/attention/flashinfer/test_mla_cuda_graph_plan.py b/unit_tests/common/basemodel/attention/flashinfer/test_mla_cuda_graph_plan.py new file mode 100644 index 0000000000..376376d58f --- /dev/null +++ b/unit_tests/common/basemodel/attention/flashinfer/test_mla_cuda_graph_plan.py @@ -0,0 +1,132 @@ +import pytest +import torch + +flashinfer = pytest.importorskip("flashinfer") + +from lightllm.common.basemodel.triton_kernel.flashinfer_mla_plan import ( # noqa: E402 + fill_mla_decode_plan_for_cuda_graph, +) + + +def _make_wrapper( + q_indptr, + kv_indptr_buf, + kv_indptr, + kv_indices, + kv_lens_buf, + num_heads, + head_dim_ckv, + head_dim_kpe, + sm_scale, + dtype, + init_short, +): + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda") + wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + workspace, + use_cuda_graph=True, + qo_indptr=q_indptr, + kv_indices=kv_indices, + kv_indptr=kv_indptr_buf, + kv_len_arr=kv_lens_buf, + ) + if init_short: + batch_size = kv_lens_buf.numel() + init_lens = torch.full((batch_size,), 2, dtype=torch.int32, device="cuda") + init_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * 2 + wrapper.plan( + q_indptr, + init_indptr, + kv_indices, + init_lens, + num_heads, + head_dim_ckv, + head_dim_kpe, + 1, + False, + sm_scale, + dtype, + dtype, + ) + kv_indptr_buf.copy_(kv_indptr) + kv_lens_buf.copy_(torch.diff(kv_indptr)) + else: + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_lens_buf, + num_heads, + head_dim_ckv, + head_dim_kpe, + 1, + False, + sm_scale, + dtype, + dtype, + ) + return wrapper + + 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +@pytest.mark.parametrize("lengths,num_heads", [([4, 64], 32), ([1000, 32768], 32), ([1000, 131072], 128)]) +def test_mla_cuda_graph_triton_plan_matches_flashinfer_plan(lengths, num_heads): + torch.manual_seed(0) + dtype = torch.bfloat16 + head_dim_ckv = 512 + head_dim_kpe = 64 + batch_size = len(lengths) + total_kv_len = sum(lengths) + sm_scale = (head_dim_ckv + head_dim_kpe) ** -0.5 + + q_nope = torch.randn((batch_size, num_heads, head_dim_ckv), dtype=dtype, device="cuda") + q_pe = torch.randn((batch_size, num_heads, head_dim_kpe), dtype=dtype, device="cuda") + ckv = torch.randn((total_kv_len, 1, head_dim_ckv), dtype=dtype, device="cuda") + kpe = torch.randn((total_kv_len, 1, head_dim_kpe), dtype=dtype, device="cuda") + kv_lens = torch.tensor(lengths, dtype=torch.int32, device="cuda") + kv_indptr = torch.empty(batch_size + 1, dtype=torch.int32, device="cuda") + kv_indptr[0] = 0 + kv_indptr[1:] = torch.cumsum(kv_lens, dim=0) + kv_indices = torch.arange(total_kv_len, dtype=torch.int32, device="cuda") + q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") + + ref_wrapper = _make_wrapper( + q_indptr, + kv_indptr.clone(), + kv_indptr, + kv_indices, + kv_lens.clone(), + num_heads, + head_dim_ckv, + head_dim_kpe, + sm_scale, + dtype, + init_short=False, + ) + graph_wrapper = _make_wrapper( + q_indptr, + kv_indptr.clone(), + kv_indptr, + kv_indices, + kv_lens.clone(), + num_heads, + head_dim_ckv, + head_dim_kpe, + sm_scale, + dtype, + init_short=True, + ) + fill_mla_decode_plan_for_cuda_graph( + graph_wrapper, + graph_wrapper._kv_indptr_buf, + batch_size, + num_heads, + max(lengths), + ) + + ref_out = torch.empty((batch_size, num_heads, head_dim_ckv), dtype=dtype, device="cuda") + graph_out = torch.empty_like(ref_out) + ref_wrapper.run(q_nope, q_pe, ckv, kpe, out=ref_out, return_lse=False) + graph_wrapper.run(q_nope, q_pe, ckv, kpe, out=graph_out, return_lse=False) + + 
assert torch.allclose(ref_out, graph_out, atol=1e-2, rtol=1e-2) diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py index a01bbf32d8..3e2555b339 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py @@ -1,7 +1,6 @@ import pytest import torch -from lightllm.utils.light_utils import light_ops def alloc_tensor_func(shape, dtype, device): @@ -41,17 +40,17 @@ def __init__( # @pytest.mark.parametrize("shared_seq_len", [512]) @pytest.mark.parametrize("shared_seq_len", [0, 77, 256, 311, 512, 550]) @pytest.mark.parametrize("batch_size", list(range(6, 121, 6))) -def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_len, batch_size): +def test_token_decode_attention_flash_decoding_diverse_matches_normal_decode(shared_seq_len, batch_size): """ - 测试 int8kv_flash_decoding_diverse 的 token_decode_attention_flash_decoding - 与 ppl_int8kv_flash_decoding (baseline) 的对比。 + diverse 与 normal 均为仓库内 Triton 实现,应数值一致(无外部 CUDA extension)。 + diverse:int8kv_flash_decoding_diverse;对照:int8kv/normal token_decode_attention_flash_decoding。 """ from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse import ( token_decode_attention_flash_decoding as diverse_attention, ) - from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( - token_decode_attention_flash_decoding as baseline_attention, + from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.normal import ( + token_decode_attention_flash_decoding as normal_decode, ) num_heads = 32 @@ -89,7 +88,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le b_mark_shared_group 
= torch.zeros((batch_size,), dtype=torch.int32, device="cuda") b_mark_shared_group[mark_shared_group_size - 1 :: mark_shared_group_size] = mark_shared_group_size - # 创建 baseline 的 infer_state (不需要 b_shared_seq_len) + # 标准 int8 decode(单路径 Triton) baseline_infer_state = MockInferState( batch_size=batch_size, max_kv_seq_len=max_len_in_batch, @@ -98,7 +97,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le b_seq_len=b_seq_len, ) - # 创建 diverse 的 infer_state + # diverse:多流 + 共享前缀(Triton) diverse_infer_state = MockInferState( batch_size=batch_size, max_kv_seq_len=max_len_in_batch, @@ -110,7 +109,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le ) # 运行 baseline - baseline_out = baseline_attention( + normal_out = normal_decode( q=q.clone(), infer_state=baseline_infer_state, cache_k=cache_k, @@ -131,11 +130,10 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le ) print(f"\nshared_seq_len={shared_seq_len}\nbatch_size={batch_size}") - print(f"baseline_out: {baseline_out[0, 0, :4]}") + print(f"normal_out: {normal_out[0, 0, :4]}") print(f"diverse_out: {diverse_out[0, 0, :4]}") - print(f"max diff: {(baseline_out - diverse_out).abs().max()}") + print(f"max diff: {(normal_out - diverse_out).abs().max()}") - # 与 baseline 对比 assert torch.allclose( - baseline_out, diverse_out, atol=1e-2, rtol=1e-2 - ), f"Diverse attention output should match baseline for shared_seq_len={shared_seq_len}" + normal_out, diverse_out, atol=1e-2, rtol=1e-2 + ), f"diverse vs normal decode mismatch for shared_seq_len={shared_seq_len}" diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py index c7d4442543..985a33289f 100644 --- 
a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py @@ -140,9 +140,9 @@ def test_flash_decode_stage2_execution(shared_seq_len): if __name__ == "__main__": - import importlib + # 可选:对 Triton diverse stage2 做 cudagraph bench(仅本仓库内核,无外部 CUDA 扩展)。 + import triton - from lightllm.utils.light_utils import light_ops batch_sizes = [8, 16, 32, 64] seq_lens = [32, 64, 128, 256] @@ -150,7 +150,6 @@ def test_flash_decode_stage2_execution(shared_seq_len): results = [] for batch in batch_sizes: for seq in seq_lens: - # Clear GPU cache to reduce CUDA Graph capture failures. torch.cuda.empty_cache() setup_tensors = create_tensors( @@ -161,133 +160,33 @@ def test_flash_decode_stage2_execution(shared_seq_len): kv_len=seq, req_to_tokens_len=seq, ) - - # Outputs for CUDA implementation - mid_out_cuda = setup_tensors["mid_out"].clone() - mid_out_logsumexp_cuda = setup_tensors["mid_out_logsumexp"].clone() - - # Outputs for Triton implementation - mid_out_triton = setup_tensors["mid_out"].clone() - mid_out_logsumexp_triton = setup_tensors["mid_out_logsumexp"].clone() - - # Run CUDA to get reference - light_ops.group8_int8kv_flashdecoding_diverse_stage2( - setup_tensors["block_seq"], - mid_out_cuda, - mid_out_logsumexp_cuda, - 1.0 / (setup_tensors["head_dim"] ** 0.5), - setup_tensors["q"], - setup_tensors["k"], - setup_tensors["k_scale"], - setup_tensors["v"], - setup_tensors["v_scale"], - setup_tensors["Req_to_tokens"], - setup_tensors["B_req_idx"], - setup_tensors["b_seq_len"], - setup_tensors["b_shared_seq_len"], - setup_tensors["max_len_in_batch"], - ) - - # Run Triton - flash_decode_stage2( - q=setup_tensors["q"], - k=setup_tensors["k"], - k_scale=setup_tensors["k_scale"], - v=setup_tensors["v"], - v_scale=setup_tensors["v_scale"], - Req_to_tokens=setup_tensors["Req_to_tokens"], - 
B_req_idx=setup_tensors["B_req_idx"], - B_Seqlen=setup_tensors["b_seq_len"], - b_shared_seq_len=setup_tensors["b_shared_seq_len"], - max_len_in_batch=setup_tensors["max_len_in_batch"], - mid_out=mid_out_triton, - mid_out_logsumexp=mid_out_logsumexp_triton, - block_seq=setup_tensors["block_seq"], - ) - - # Compare results - diff_mid_out = torch.abs(mid_out_cuda - mid_out_triton) - diff_logsumexp = torch.abs(mid_out_logsumexp_cuda - mid_out_logsumexp_triton) - max_diff_out = diff_mid_out.max().item() - max_diff_logsumexp = diff_logsumexp.max().item() - mean_diff_out = diff_mid_out.mean().item() - mean_diff_logsumexp = diff_logsumexp.mean().item() - - cos_sim_out = torch.nn.functional.cosine_similarity( - mid_out_cuda.flatten(), mid_out_triton.flatten(), dim=0 - ).item() - cos_sim_logsumexp = torch.nn.functional.cosine_similarity( - mid_out_logsumexp_cuda.flatten(), mid_out_logsumexp_triton.flatten(), dim=0 - ).item() - - print(f"\n[batch={batch}, seq={seq}] Consistency check:") - print(" mid_out:") - print(f" max_diff: {max_diff_out:.6f}, mean_diff: {mean_diff_out:.6f}, cosine_sim: {cos_sim_out:.8f}") - print(" logsumexp:") - print( - f" max_diff: {max_diff_logsumexp:.6f}, " - f"mean_diff: {mean_diff_logsumexp:.6f}, " - f"cosine_sim: {cos_sim_logsumexp:.8f}" - ) - - # Performance - fn_cuda = lambda: light_ops.group8_int8kv_flashdecoding_diverse_stage2( - setup_tensors["block_seq"], - setup_tensors["mid_out"], - setup_tensors["mid_out_logsumexp"], - 1.0 / (setup_tensors["head_dim"] ** 0.5), - setup_tensors["q"], - setup_tensors["k"], - setup_tensors["k_scale"], - setup_tensors["v"], - setup_tensors["v_scale"], - setup_tensors["Req_to_tokens"], - setup_tensors["B_req_idx"], - setup_tensors["b_seq_len"], - setup_tensors["b_shared_seq_len"], - setup_tensors["max_len_in_batch"], - ) - ms_cuda = triton.testing.do_bench_cudagraph(fn_cuda, rep=100) - - fn_triton = lambda: flash_decode_stage2( - q=setup_tensors["q"], - k=setup_tensors["k"], - k_scale=setup_tensors["k_scale"], 
- v=setup_tensors["v"], - v_scale=setup_tensors["v_scale"], - Req_to_tokens=setup_tensors["Req_to_tokens"], - B_req_idx=setup_tensors["B_req_idx"], - B_Seqlen=setup_tensors["b_seq_len"], - b_shared_seq_len=setup_tensors["b_shared_seq_len"], - max_len_in_batch=setup_tensors["max_len_in_batch"], - mid_out=setup_tensors["mid_out"], - mid_out_logsumexp=setup_tensors["mid_out_logsumexp"], - block_seq=setup_tensors["block_seq"], - ) - ms_triton = triton.testing.do_bench_cudagraph(fn_triton, rep=100) - - results.append( - { - "batch_size": batch, - "seq_len": seq, - "triton_ms": ms_triton, - "cuda_ms": ms_cuda, - } - ) + st = setup_tensors + + def bench_stage2(): + flash_decode_stage2( + q=st["q"], + k=st["k"], + k_scale=st["k_scale"], + v=st["v"], + v_scale=st["v_scale"], + Req_to_tokens=st["Req_to_tokens"], + B_req_idx=st["B_req_idx"], + B_Seqlen=st["b_seq_len"], + b_shared_seq_len=st["b_shared_seq_len"], + max_len_in_batch=st["max_len_in_batch"], + mid_out=st["mid_out"], + mid_out_logsumexp=st["mid_out_logsumexp"], + block_seq=st["block_seq"], + ) + + ms = triton.testing.do_bench_cudagraph(bench_stage2, rep=100) + results.append({"batch_size": batch, "seq_len": seq, "flash_decode_stage2_ms": ms}) print(results[-1]) - del setup_tensors print(f"\n{'='*80}") - print("SUMMARY - Performance Comparison") - print(f"{'='*80}") - print(f"{'batch_size':<8} {'seq_len':<12} {'triton_ms':<12} {'cuda_ms':<12} {'vs cuda':<10}") + print(f"{'batch_size':<10} {'seq_len':<10} {'flash_decode_stage2_ms':<22}") print(f"{'-'*80}") for r in results: - vs_cuda = f"{r['cuda_ms']/r['triton_ms']:.2f}x" - emoji = "🎉" if r["triton_ms"] < r["cuda_ms"] else "" - print( - f"{r['batch_size']:<8} {r['seq_len']:<12} {r['triton_ms']:<12.3f} {r['cuda_ms']:<12.3f}" - f"{vs_cuda:<10} {emoji}" - ) + print(f"{r['batch_size']:<10} {r['seq_len']:<10} {r['flash_decode_stage2_ms']:<22.4f}") print(f"{'='*80}") diff --git a/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py 
b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py new file mode 100644 index 0000000000..3b2f159f62 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py @@ -0,0 +1,50 @@ +import pytest +import torch + +from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import ( + apply_invalid_token_ids, +) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_apply_invalid_token_ids(dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for Triton kernels.") + + batch_size = 4 + vocab_size = 32 + logits = torch.randn((batch_size, vocab_size), device="cuda", dtype=dtype) + expected = logits.clone() + + invalid_token_ids_per_batch = [ + [1, 3, 5], + [], + [0, 2, 31], + [7], + ] + + flat_ids = [] + cu_invalid_token_num = [0] + invalid_token_num_start = 0 + for ids in invalid_token_ids_per_batch: + flat_ids.extend(ids) + invalid_token_num_start += len(ids) + cu_invalid_token_num.append(invalid_token_num_start) + + invalid_token_ids = torch.tensor(flat_ids, device="cuda", dtype=torch.int32) + cu_invalid_token_num = torch.tensor(cu_invalid_token_num, device="cuda", dtype=torch.int32) + + for batch_idx, ids in enumerate(invalid_token_ids_per_batch): + if ids: + expected[batch_idx, ids] = float("-inf") + + apply_invalid_token_ids( + Logits=logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + ) + assert torch.equal(logits, expected) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/unit_tests/common/fused_moe/test_append_shared_expert_topk.py b/unit_tests/common/fused_moe/test_append_shared_expert_topk.py new file mode 100644 index 0000000000..bae8e5fe70 --- /dev/null +++ b/unit_tests/common/fused_moe/test_append_shared_expert_topk.py @@ -0,0 +1,109 @@ +import pytest +import torch + +from lightllm.common.basemodel.triton_kernel.fused_moe.append_shared_expert_topk import ( + 
append_fused_shared_experts, +) + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for Triton kernels") + + +def test_append_fused_shared_experts_without_gate(): + topk_ids = torch.tensor([[0, 2], [1, 3], [2, 0]], dtype=torch.int32, device="cuda") + topk_weights = torch.tensor([[0.2, 0.8], [0.4, 0.6], [0.7, 0.3]], dtype=torch.float32, device="cuda") + + out_weights, out_ids = append_fused_shared_experts( + topk_weights=topk_weights, + topk_ids=topk_ids, + shared_expert_start_id=4, + num_fused_shared_experts=1, + shared_expert_gate=None, + ) + + expect_ids = torch.tensor([[0, 2, 4], [1, 3, 4], [2, 0, 4]], dtype=torch.int32, device="cuda") + expect_weights = torch.tensor( + [[0.2, 0.8, 1.0], [0.4, 0.6, 1.0], [0.7, 0.3, 1.0]], dtype=torch.float32, device="cuda" + ) + assert torch.equal(out_ids, expect_ids) + assert torch.allclose(out_weights, expect_weights) + + +def test_append_fused_shared_experts_with_gate(): + topk_ids = torch.tensor([[0, 2], [1, 3], [2, 0]], dtype=torch.int32, device="cuda") + topk_weights = torch.tensor([[0.2, 0.8], [0.4, 0.6], [0.7, 0.3]], dtype=torch.float32, device="cuda") + shared_expert_gate = torch.tensor([[0.0], [2.0], [-2.0]], dtype=torch.float32, device="cuda") + + out_weights, out_ids = append_fused_shared_experts( + topk_weights=topk_weights, + topk_ids=topk_ids, + shared_expert_start_id=4, + num_fused_shared_experts=1, + shared_expert_gate=shared_expert_gate, + ) + + expect_ids = torch.tensor([[0, 2, 4], [1, 3, 4], [2, 0, 4]], dtype=torch.int32, device="cuda") + expect_weights = torch.cat([topk_weights, torch.sigmoid(shared_expert_gate)], dim=1) + assert torch.equal(out_ids, expect_ids) + assert torch.allclose(out_weights, expect_weights) + + +def test_append_fused_shared_experts_multiple_tokens_per_grid(): + token_num = 4097 + topk_ids = torch.stack( + [ + torch.arange(token_num, dtype=torch.int32, device="cuda") % 4, + (torch.arange(token_num, dtype=torch.int32, device="cuda") + 1) % 4, + 
], + dim=1, + ) + topk_weights = torch.rand((token_num, 2), dtype=torch.float32, device="cuda") + shared_expert_gate = torch.randn((token_num, 1), dtype=torch.float32, device="cuda") + + out_weights, out_ids = append_fused_shared_experts( + topk_weights=topk_weights, + topk_ids=topk_ids, + shared_expert_start_id=4, + num_fused_shared_experts=1, + shared_expert_gate=shared_expert_gate, + ) + + expect_ids = torch.cat( + [topk_ids, torch.full((token_num, 1), 4, dtype=torch.int32, device="cuda")], + dim=1, + ) + expect_weights = torch.cat([topk_weights, torch.sigmoid(shared_expert_gate)], dim=1) + assert torch.equal(out_ids, expect_ids) + assert torch.allclose(out_weights, expect_weights) + + +def test_append_fused_shared_experts_multi_shared_gate(): + token_num = 7 + topk_ids = torch.stack( + [ + torch.arange(token_num, dtype=torch.int32, device="cuda") % 4, + (torch.arange(token_num, dtype=torch.int32, device="cuda") + 1) % 4, + ], + dim=1, + ) + topk_weights = torch.rand((token_num, 2), dtype=torch.float32, device="cuda") + shared_expert_gate = torch.randn((token_num, 2), dtype=torch.float32, device="cuda") + + out_weights, out_ids = append_fused_shared_experts( + topk_weights=topk_weights, + topk_ids=topk_ids, + shared_expert_start_id=4, + num_fused_shared_experts=2, + shared_expert_gate=shared_expert_gate, + ) + + expect_ids = torch.cat( + [ + topk_ids, + torch.tensor([[4, 5]], dtype=torch.int32, device="cuda").repeat(token_num, 1), + ], + dim=1, + ) + expect_weights = torch.cat([topk_weights, torch.sigmoid(shared_expert_gate)], dim=1) + assert torch.equal(out_ids, expect_ids) + assert torch.allclose(out_weights, expect_weights) diff --git a/unit_tests/common/fused_moe/test_grouped_fused_moe.py b/unit_tests/common/fused_moe/test_grouped_fused_moe.py index 9c08cfc1a4..0376d01ee7 100644 --- a/unit_tests/common/fused_moe/test_grouped_fused_moe.py +++ b/unit_tests/common/fused_moe/test_grouped_fused_moe.py @@ -3,7 +3,10 @@ import pytest import triton from 
lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import ( + _moe_align_fused_atomic_token, + fused_experts_impl, moe_align, + moe_align_fused, moe_align1, moe_align2, grouped_matmul, @@ -74,6 +77,145 @@ def test_moe_align1(): assert torch.equal(experts_info, true_experts_info) +def _check_moe_align_fused(topk_ids, topk_weights, expert_num, ordered=True): + expert_to_token_index = torch.empty((expert_num, topk_ids.numel()), dtype=torch.int32, device="cuda") + expert_to_weight = torch.empty((expert_num, topk_ids.numel()), dtype=torch.float32, device="cuda") + expert_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") + + moe_align_fused( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + ) + torch.cuda.synchronize() + + flat_topk_ids = topk_ids.flatten() + flat_topk_weights = topk_weights.flatten() + expected_token_num = torch.bincount(flat_topk_ids, minlength=expert_num).to(torch.int32) + assert torch.equal(expert_token_num, expected_token_num) + + for expert_id, token_num in enumerate(expected_token_num.tolist()): + expected_index = torch.nonzero(flat_topk_ids == expert_id, as_tuple=False).flatten() + expected_weight = flat_topk_weights[expected_index] + expected_index = expected_index.to(torch.int32) + token_index = expert_to_token_index[expert_id, :token_num] + token_weight = expert_to_weight[expert_id, :token_num] + + if not ordered: + order = torch.argsort(token_index) + token_index = token_index[order] + token_weight = token_weight[order] + + assert torch.equal(token_index, expected_index) + assert torch.allclose(token_weight, expected_weight) + + +def test_moe_align_fused_small_token(): + expert_num = 5 + small_topk_ids = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 1, 4]], dtype=torch.int32, device="cuda") + small_topk_weights = torch.tensor( + [[0.3, 0.7, 0.1], [0.2, 0.8, 0.4], [0.5, 0.6, 0.9]], dtype=torch.float32, device="cuda" + ) + _check_moe_align_fused(small_topk_ids, 
small_topk_weights, expert_num) + + small_many_topk_ids = torch.arange(128 * 17, dtype=torch.int32, device="cuda").reshape(128, 17) % expert_num + small_many_topk_weights = torch.arange(small_many_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape( + 128, 17 + ) + _check_moe_align_fused(small_many_topk_ids, small_many_topk_weights, expert_num) + + +def test_moe_align_fused_large_token(): + expert_num = 5 + + base_topk_ids = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 1, 4], [2, 0, 4]], dtype=torch.int32, device="cuda") + large_topk_ids = base_topk_ids.repeat(33, 1)[:129].contiguous() + large_topk_weights = torch.arange(large_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(129, 3) + _check_moe_align_fused(large_topk_ids, large_topk_weights, expert_num, ordered=False) + + medium_topk_ids = base_topk_ids.repeat(1024, 1).contiguous() + medium_topk_weights = torch.arange(medium_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(4096, 3) + _check_moe_align_fused(medium_topk_ids, medium_topk_weights, expert_num, ordered=False) + + shared_expert_num = 257 + shared_routing = torch.arange(512 * 7, dtype=torch.int32, device="cuda").reshape(512, 7) % 256 + shared_last = torch.full((512, 1), 256, dtype=torch.int32, device="cuda") + shared_topk_ids = torch.cat([shared_routing, shared_last], dim=1).contiguous() + shared_topk_weights = torch.arange(shared_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(512, 8) + _check_moe_align_fused(shared_topk_ids, shared_topk_weights, shared_expert_num, ordered=False) + + large_atomic_topk_ids = base_topk_ids.repeat(1281, 1)[:5121].contiguous() + large_atomic_topk_weights = torch.arange(large_atomic_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape( + 5121, 3 + ) + _check_moe_align_fused(large_atomic_topk_ids, large_atomic_topk_weights, expert_num, ordered=False) + + sparse_expert_num = 257 + sparse_topk_ids = base_topk_ids.repeat(1281, 1)[:5121].contiguous() + sparse_topk_weights = 
torch.arange(sparse_topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(5121, 3) + _check_moe_align_fused(sparse_topk_ids, sparse_topk_weights, sparse_expert_num, ordered=False) + + +def test_moe_align_fused_large_token_unordered(): + expert_num = 257 + topk_ids = torch.arange(5121 * 8, dtype=torch.int32, device="cuda").reshape(5121, 8) % expert_num + topk_weights = torch.arange(topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(5121, 8) + _check_moe_align_fused(topk_ids, topk_weights, expert_num, ordered=False) + + +def test_moe_align_fused_atomic_token_unordered(): + expert_num = 9 + topk_ids = torch.arange(257 * 4, dtype=torch.int32, device="cuda").reshape(257, 4) % expert_num + topk_weights = torch.arange(topk_ids.numel(), dtype=torch.float32, device="cuda").reshape(257, 4) + expert_to_token_index = torch.empty((expert_num, topk_ids.numel()), dtype=torch.int32, device="cuda") + expert_to_weight = torch.empty((expert_num, topk_ids.numel()), dtype=torch.float32, device="cuda") + expert_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") + + _moe_align_fused_atomic_token( + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + ) + torch.cuda.synchronize() + + flat_topk_ids = topk_ids.flatten() + flat_topk_weights = topk_weights.flatten() + expected_token_num = torch.bincount(flat_topk_ids, minlength=expert_num).to(torch.int32) + assert torch.equal(expert_token_num, expected_token_num) + + for expert_id, token_num in enumerate(expected_token_num.tolist()): + expected_index = torch.nonzero(flat_topk_ids == expert_id, as_tuple=False).flatten().to(torch.int32) + expected_weight = flat_topk_weights[expected_index] + token_index = expert_to_token_index[expert_id, :token_num] + token_weight = expert_to_weight[expert_id, :token_num] + order = torch.argsort(token_index) + assert torch.equal(token_index[order], expected_index) + assert torch.allclose(token_weight[order], expected_weight) + + 
+def test_fused_experts_atomic_align_path_is_deterministic(): + token_num = 129 + expert_num = 9 + hidden_size = 64 + intermediate_size = 128 + topk = 4 + hidden_states = torch.randn((token_num, hidden_size), dtype=torch.bfloat16, device="cuda") / 10 + w1 = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device="cuda") / 10 + w2 = torch.randn((expert_num, hidden_size, intermediate_size // 2), dtype=torch.bfloat16, device="cuda") / 10 + topk_ids = torch.arange(token_num * topk, dtype=torch.int32, device="cuda").reshape(token_num, topk) % expert_num + topk_weights = torch.softmax(torch.randn((token_num, topk), dtype=torch.float32, device="cuda"), dim=-1) + + out_0 = fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids) + out_1 = fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids) + torch.cuda.synchronize() + + assert torch.equal(out_0, out_1) + + def test_moe_align2(): experts_token_num = torch.zeros((4,), dtype=torch.int32, device="cuda") @@ -83,13 +225,14 @@ def test_moe_align2(): experts_token_num[3] = 16 mblocks_to_tuple_info = moe_align2(100, experts_token_num, block_m=16) + expected_expert_ids = torch.tensor([0, 2, 2, 2, 2, 3, -1, -1, -1, -1], device="cuda", dtype=torch.int32) + valid_blocks = expected_expert_ids != -1 + assert mblocks_to_tuple_info.shape[0] == triton.cdiv(100 + 4 * (16 - 1), 16) - assert torch.allclose( - mblocks_to_tuple_info[:, 0], - torch.tensor([0, 2, 2, 2, 2, 3, -1, -1, -1, -1], device="cuda", dtype=torch.int32), - ) - assert torch.allclose( - mblocks_to_tuple_info[:, 1], torch.tensor([0, 0, 1, 2, 3, 0, 0, 0, 0, 0], device="cuda", dtype=torch.int32) + assert torch.equal(mblocks_to_tuple_info[:, 0], expected_expert_ids) + assert torch.equal( + mblocks_to_tuple_info[valid_blocks, 1], + torch.tensor([0, 0, 1, 2, 3, 0], device="cuda", dtype=torch.int32), ) diff --git a/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py 
b/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py index 29aed2a70e..8783f35a42 100644 --- a/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py +++ b/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py @@ -13,7 +13,6 @@ def is_fp8_native_supported(): if not is_fp8_native_supported(): pytest.skip(reason="not support fp8 test in this gpu card", allow_module_level=True) -import random from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul_mix_quant_ep import ( silu_and_mul_masked_post_quant_fwd, ) @@ -45,10 +44,9 @@ def test_silu_and_mul_masked(expert_num, token_num, hidden_dim): (expert_num, token_num, hidden_dim // 2 // quant_group_size), dtype=torch.float32, device="cuda" ) - true_out_tensor_mid = torch.randn((expert_num, token_num, hidden_dim // 2), dtype=torch.float16, device="cuda") + true_out_tensor_mid = torch.empty((expert_num, token_num, hidden_dim // 2), dtype=in_tensor.dtype, device="cuda") - masked_m = [random.randint(0, token_num) for _ in range(expert_num)] - masked_m = torch.tensor(masked_m, dtype=torch.int32, device="cuda") + masked_m = torch.full((expert_num,), token_num, dtype=torch.int32, device="cuda") silu_and_mul_fwd(in_tensor.view(-1, hidden_dim), true_out_tensor_mid.view(-1, hidden_dim // 2)) true_out_tensor, true_out_scale_tensor = per_token_group_quant_fp8( @@ -62,22 +60,53 @@ def test_silu_and_mul_masked(expert_num, token_num, hidden_dim): true_out_tensor = true_out_tensor.view(out_tensor.shape) true_out_scale_tensor = true_out_scale_tensor.view(out_scale_tensor.shape) + hidden_dim_scale_count = hidden_dim // 2 // quant_group_size for expert_id, expert_token_num in enumerate(masked_m.cpu().numpy()): + true_scale = true_out_scale_tensor[expert_id, :expert_token_num, :hidden_dim_scale_count] + out_scale = out_scale_tensor[expert_id, :expert_token_num, :hidden_dim_scale_count] assert torch.allclose( - true_out_tensor[expert_id, :expert_token_num, :].to(torch.float32), - 
out_tensor[expert_id, :expert_token_num, :].to(torch.float32), - atol=1e-3, - rtol=1e-2, - ) - hidden_dim_scale_count = hidden_dim // 2 // quant_group_size - assert torch.allclose( - true_out_scale_tensor[expert_id, :expert_token_num, :hidden_dim_scale_count], - out_scale_tensor[expert_id, :expert_token_num, :hidden_dim_scale_count], + true_scale, + out_scale, atol=1e-3, rtol=1e-2, ) + + true_out = true_out_tensor[expert_id, :expert_token_num, :].to(torch.float32) + out = out_tensor[expert_id, :expert_token_num, :].to(torch.float32) + true_dequant = true_out * true_scale.repeat_interleave(quant_group_size, dim=-1) + out_dequant = out * out_scale.repeat_interleave(quant_group_size, dim=-1) + assert torch.allclose(true_dequant, out_dequant, atol=1e-1, rtol=1e-1) return +def test_silu_and_mul_masked_skips_padded_tokens(): + expert_num = 3 + token_num = 4 + hidden_dim = 256 + quant_group_size = 128 + masked_m = torch.tensor([0, 2, token_num], dtype=torch.int32, device="cuda") + + in_tensor = torch.randn((expert_num, token_num, hidden_dim), dtype=torch.bfloat16, device="cuda") + out_tensor = torch.empty((expert_num, token_num, hidden_dim // 2), dtype=torch.float8_e4m3fn, device="cuda") + out_scale_tensor = torch.empty( + (expert_num, token_num, hidden_dim // 2 // quant_group_size), dtype=torch.float32, device="cuda" + ) + out_tensor.fill_(1.0) + out_scale_tensor.fill_(7.0) + + silu_and_mul_masked_post_quant_fwd(in_tensor, out_tensor, out_scale_tensor, quant_group_size, masked_m) + torch.cuda.synchronize() + + for expert_id, expert_token_num in enumerate(masked_m.cpu().tolist()): + assert torch.equal( + out_tensor[expert_id, expert_token_num:, :], + torch.ones_like(out_tensor[expert_id, expert_token_num:, :]), + ) + assert torch.equal( + out_scale_tensor[expert_id, expert_token_num:, :], + torch.full_like(out_scale_tensor[expert_id, expert_token_num:, :], 7.0), + ) + + if __name__ == "__main__": pytest.main() diff --git 
a/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py b/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py new file mode 100644 index 0000000000..8a54d6d9fd --- /dev/null +++ b/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py @@ -0,0 +1,83 @@ +import pytest +import torch +import triton + +from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform + + +TP = 8 +INDEX_N_HEADS = 64 +INDEX_HEAD_DIM = 128 +TP_INDEX_N_HEADS = INDEX_N_HEADS // TP +SCALE = INDEX_HEAD_DIM ** -0.5 + + +def _get_sgl_kernel_hadamard_transform(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for hadamard_transform comparison") + try: + from sgl_kernel import hadamard_transform as sgl_hadamard_transform + except ImportError: + pytest.skip("sgl_kernel.hadamard_transform is not available") + return sgl_hadamard_transform + + +def _bench(fn, x): + ms = triton.testing.do_bench_cudagraph(lambda: fn(x, scale=SCALE), return_mode="median") + return ms, fn(x, scale=SCALE) + + +@pytest.mark.parametrize("tokens", [1, 16, 128, 512, 1024, 2048, 4096, 8192, 16384]) +def test_hadamard_transform_matches_sgl_kernel_deepseek_v32_shapes(tokens): + sgl_hadamard_transform = _get_sgl_kernel_hadamard_transform() + + q = torch.randn(tokens, TP_INDEX_N_HEADS, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda") + k = torch.randn(tokens, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda") + + q_expected = sgl_hadamard_transform(q, scale=SCALE) + q_actual = hadamard_transform(q, scale=SCALE) + k_expected = sgl_hadamard_transform(k, scale=SCALE) + k_actual = hadamard_transform(k, scale=SCALE) + torch.cuda.synchronize() + + assert torch.equal(q_actual, q_expected) + assert torch.equal(k_actual, k_expected) + + +def test_hadamard_transform_perf_report_deepseek_v32_shapes(): + sgl_hadamard_transform = _get_sgl_kernel_hadamard_transform() + + print( + "\nDeepSeek-V3.2 per-rank shapes with tp=8:" + "\n q: [tokens, 
8, 128]" + "\n k: [tokens, 128]" + "\n\ntokens | q_diff | k_diff | sgl_q ms | tri_q ms | sgl_k ms | tri_k ms | tri(q+k) ms | slowdown q+k" + ) + + for tokens in [1, 16, 128, 512, 1024, 2048, 4096, 8192, 16384]: + q = torch.randn(tokens, TP_INDEX_N_HEADS, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda") + k = torch.randn(tokens, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda") + + q_expected = sgl_hadamard_transform(q, scale=SCALE) + q_actual = hadamard_transform(q, scale=SCALE) + k_expected = sgl_hadamard_transform(k, scale=SCALE) + k_actual = hadamard_transform(k, scale=SCALE) + torch.cuda.synchronize() + + q_diff = (q_expected.float() - q_actual.float()).abs().max().item() + k_diff = (k_expected.float() - k_actual.float()).abs().max().item() + sgl_q_ms, _ = _bench(sgl_hadamard_transform, q) + tri_q_ms, _ = _bench(hadamard_transform, q) + sgl_k_ms, _ = _bench(sgl_hadamard_transform, k) + tri_k_ms, _ = _bench(hadamard_transform, k) + sgl_sum_ms = sgl_q_ms + sgl_k_ms + tri_sum_ms = tri_q_ms + tri_k_ms + + print( + f"{tokens:6d} | {q_diff:6.1g} | {k_diff:6.1g} | " + f"{sgl_q_ms:8.4f} | {tri_q_ms:8.4f} | {sgl_k_ms:8.4f} | {tri_k_ms:8.4f} | " + f"{tri_sum_ms:11.4f} | {tri_sum_ms / sgl_sum_ms:10.2f}x" + ) + + assert q_diff == 0 + assert k_diff == 0 diff --git a/unit_tests/models/qwen3next/test_fused_recurrent_strided.py b/unit_tests/models/qwen3next/test_fused_recurrent_strided.py new file mode 100644 index 0000000000..cf9d06ec98 --- /dev/null +++ b/unit_tests/models/qwen3next/test_fused_recurrent_strided.py @@ -0,0 +1,83 @@ +import pytest +import torch + +from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( + fused_recurrent_gated_delta_rule, +) + +if not torch.cuda.is_available(): + pytest.skip("CUDA required", allow_module_level=True) + + +@pytest.mark.parametrize("batch", [1, 2, 16]) +def test_decode_strided_views_match_contiguous(batch): + """q/k/v/a/b passed as column views of one projection output (the decode + path layout) must 
produce the same result as contiguous copies.""" + torch.manual_seed(0) + H, HV, K, V = 2, 8, 128, 128 + key_dim, value_dim = H * K, HV * V + qkv_dim = 2 * key_dim + value_dim + total_dim = qkv_dim + value_dim + 2 * HV # qkv + z + b + a + cache_slots = 64 + + mixed = torch.randn(batch, total_dim, device="cuda", dtype=torch.bfloat16) + mixed_qkv = mixed[:, :qkv_dim] + b_raw = mixed[:, qkv_dim + value_dim : qkv_dim + value_dim + HV] + a_raw = mixed[:, qkv_dim + value_dim + HV :] + + query, key, value = torch.split(mixed_qkv, [key_dim, key_dim, value_dim], dim=-1) + q = query.view(batch, 1, H, K) + k = key.view(batch, 1, H, K) + v = value.view(batch, 1, HV, V) + + A_log = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1 + dt_bias = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1 + ssm_state = torch.randn(cache_slots, HV, K, V, device="cuda", dtype=torch.bfloat16) + idx = torch.randperm(cache_slots, device="cuda")[:batch].to(torch.int32) + + def run(q_, k_, v_, a_, b_, state): + out, _ = fused_recurrent_gated_delta_rule( + q=q_, + k=k_, + v=v_, + initial_state=state, + inplace_final_state=True, + ssm_state_indices=idx, + use_qk_l2norm_in_kernel=True, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_, + b_raw=b_, + ) + return out + + state_ref = ssm_state.clone() + out_ref = run(q.contiguous(), k.contiguous(), v.contiguous(), a_raw.contiguous(), b_raw.contiguous(), state_ref) + state_strided = ssm_state.clone() + out_strided = run(q, k, v, a_raw, b_raw, state_strided) + + assert torch.equal(out_ref, out_strided) + assert torch.equal(state_ref, state_strided) + + +def test_cu_seqlens_is_not_supported(): + """The fused recurrent kernel is decode-only in LightLLM's Qwen3Next path.""" + H, HV, K, V = 2, 2, 4, 4 + q = torch.randn(1, 2, H, K, device="cuda", dtype=torch.bfloat16) + k = torch.randn(1, 2, H, K, device="cuda", dtype=torch.bfloat16) + v = torch.randn(1, 2, HV, V, device="cuda", dtype=torch.bfloat16) + initial_state = torch.randn(1, HV, K, V, 
device="cuda", dtype=torch.bfloat16) + cu_seqlens = torch.tensor([0, 2], device="cuda", dtype=torch.long) + + with pytest.raises(AssertionError, match="decode-only fused recurrent kernel"): + fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + initial_state=initial_state, + cu_seqlens=cu_seqlens, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])