From af27c13e1b80ce293c0ac1ec46204bc37ce9769e Mon Sep 17 00:00:00 2001 From: nexadodigital Date: Mon, 11 May 2026 21:57:17 +0100 Subject: [PATCH] fix: replace pickle.loads with RestrictedUnpickler in WebSocket endpoints (CVE-2026-26220) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The PD disaggregation WebSocket endpoints (/pd_register, /kv_move_status) and the config-server HTTP response handler called bare pickle.loads() on untrusted network data, enabling unauthenticated remote code execution. Introduce lightllm/utils/safe_pickle.py with a RestrictedUnpickler that whitelists only the internal LightLLM dataclass modules that legitimately flow through these channels. All four vulnerable callsites in api_http.py and httpserver/pd_loop.py are replaced with safe_loads(). No protocol changes — pickle wire format is preserved; only class instantiation is restricted. Fixes CVE-2026-26220 (CVSS 9.3 Critical, CWE-502). Co-Authored-By: Claude Sonnet 4.6 --- lightllm/server/api_http.py | 7 +-- lightllm/server/httpserver/pd_loop.py | 7 +-- lightllm/utils/safe_pickle.py | 78 +++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 6 deletions(-) create mode 100644 lightllm/utils/safe_pickle.py diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index c106ca1cd9..a1621febcd 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -25,8 +25,9 @@ import base64 import os from io import BytesIO -import pickle +import pickle # kept for non-network uses; WebSocket paths use safe_pickle import setproctitle +from lightllm.utils.safe_pickle import safe_loads as _safe_pickle_loads asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) import ujson as json @@ -426,7 +427,7 @@ async def register_and_keep_alive(websocket: WebSocket): while True: # 等待接收消息,设置超时为10秒 data = await websocket.receive_bytes() - obj = pickle.loads(data) + obj = _safe_pickle_loads(data) # CVE-2026-26220: restricted unpickling await g_objs.httpserver_manager.put_to_handle_queue(obj) except (WebSocketDisconnect, Exception, RuntimeError) as e: @@ -447,7 +448,7 @@ async def kv_move_status(websocket: WebSocket): while True: # 等待接收消息,设置超时为10秒 data = await websocket.receive_bytes() - upkv_status = pickle.loads(data) + upkv_status = _safe_pickle_loads(data) # CVE-2026-26220: restricted unpickling logger.info(f"received upkv_status {upkv_status} from {(client_ip, client_port)}") await g_objs.httpserver_manager.update_req_status(upkv_status) except (WebSocketDisconnect, Exception, RuntimeError) as e: diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index a646d4f4cc..9ac3c28324 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -1,6 +1,7 @@ import asyncio -import pickle +import pickle # kept for outbound pickle.dumps; inbound paths use safe_pickle import websockets +from lightllm.utils.safe_pickle import safe_loads as _safe_pickle_loads import ujson as json import socket import httpx @@ -103,7 +104,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O # 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。 while True: recv_bytes = await websocket.recv() - obj = pickle.loads(recv_bytes) + obj = _safe_pickle_loads(recv_bytes) # CVE-2026-26220: restricted unpickling if obj[0] == ObjType.REQ: prompt, sampling_params, multimodal_params = obj[1] group_req_id = sampling_params.group_request_id @@ -183,7 +184,7 @@ async def _get_pd_master_objs(args: StartArgs) -> Optional[Dict[int, PD_Master_O response = await client.get(uri) if response.status_code == 200: base64data = response.json()["data"] - id_to_pd_master_obj = pickle.loads(base64.b64decode(base64data)) + id_to_pd_master_obj = _safe_pickle_loads(base64.b64decode(base64data)) # CVE-2026-26220: restricted unpickling return id_to_pd_master_obj else: logger.error(f"get pd_master_objs error {response.status_code}") diff --git a/lightllm/utils/safe_pickle.py b/lightllm/utils/safe_pickle.py new file mode 100644 index 0000000000..0f786977a7 --- /dev/null +++ b/lightllm/utils/safe_pickle.py @@ -0,0 +1,78 @@ +""" +safe_pickle.py — Restricted unpickling for LightLLM PD WebSocket endpoints. + +CVE-2026-26220: The PD disaggregation WebSocket endpoints and the config-server +HTTP response handler called pickle.loads() on untrusted network data, enabling +unauthenticated remote code execution. This module replaces those bare +pickle.loads() calls with a RestrictedUnpickler that whitelists only the +internal LightLLM dataclass modules that legitimately flow through those +channels. + +Usage: + from lightllm.utils.safe_pickle import safe_loads + + obj = safe_loads(data) # raises UnpicklingError for non-whitelisted types +""" + +import io +import pickle + +# --------------------------------------------------------------------------- +# Allowlist: (module, name) pairs permitted to be deserialized. +# Keep this list minimal — add entries only when a new type is deliberately +# added to a PD WebSocket protocol message. +# --------------------------------------------------------------------------- +_ALLOWED_MODULES: dict[str, set[str]] = { + # Built-in safe types used as containers / primitives + "builtins": {"dict", "list", "tuple", "set", "int", "float", "str", "bool", "bytes", "NoneType"}, + # LightLLM PD protocol structures + "lightllm.server.pd_io_struct": { + "ObjType", + "NodeRole", + "PD_Master_Obj", + "PD_Client_Obj", + "_PD_Client_RunStatus", + "UpKVStatus", + "NixlUpKVStatus", + "DecodeNodeInfo", + "NIXLDecodeNodeInfo", + "KVMoveTask", + "KVMoveTaskGroup", + "PDTransJoinInfo", + "PDTransLeaveInfo", + "NIXLChunckedTransTask", + "NIXLChunckedTransTaskRet", + "NIXLChunckedTransTaskGroup", + }, + # SamplingParams / StartArgs sent inside REQ messages + "lightllm.server.core.objs.py_sampling_params": {"SamplingParams"}, + "lightllm.server.core.objs.sampling_params": {"SamplingParams"}, + "lightllm.server.core.objs.start_args_type": {"StartArgs"}, + # Multimodal params can accompany REQ messages + "lightllm.server.multimodal_params": {"MultimodalParams"}, + # enum base class (needed for ObjType / NodeRole reconstruction) + "enum": {"Enum"}, +} + + +class _RestrictedUnpickler(pickle.Unpickler): + """Unpickler that raises UnpicklingError for any non-whitelisted class.""" + + def find_class(self, module: str, name: str): + allowed_names = _ALLOWED_MODULES.get(module) + if allowed_names is not None and name in allowed_names: + return super().find_class(module, name) + raise pickle.UnpicklingError( + f"Refusing to deserialize {module}.{name}: not in safe_pickle allowlist. " + f"This may indicate an attempted pickle-based RCE exploit (CVE-2026-26220)." + ) + + +def safe_loads(data: bytes) -> object: + """ + Drop-in replacement for pickle.loads() that only permits whitelisted types. + + Raises pickle.UnpicklingError if the payload attempts to instantiate a + class outside the allowlist. + """ + return _RestrictedUnpickler(io.BytesIO(data)).load()