Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions astraflow/raas/engine/remote_inf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,14 @@ def __init__(
self.lock = Lock()

self.lora_initialized = False
# Versioned LoRA adapter naming: each weight sync loads under a NEW
# name (``lora_v{seq}``) and we never unload. Unloading an adapter that
# still has paused/aborted in-flight requests deadlocks on SGLang's
# ``wait_for_unload`` (aborted requests never release their usage
# counter). New unique names avoid the unload entirely; SGLang's
# mem-pool LRU evicts stale adapters from GPU automatically.
self._lora_seq = 0
self._current_lora_name: str | None = None

self._executor: ProcessPoolExecutor | None = None
self._paused: bool = False
Expand Down Expand Up @@ -654,7 +662,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:
f"agenerate() building HTTP request, rid={req.rid}, "
f"iteration={iteration}, server_addr={server_addr}"
)
http_req = self.backend.build_generation_request(req, self.lora_initialized)
http_req = self.backend.build_generation_request(req, self._current_lora_name)

# Loop until the generation is complete
logger.debug(
Expand Down Expand Up @@ -745,19 +753,33 @@ def load_weights_from_path(
For full weights: ``/update_weights_from_disk`` includes
``abort_all_requests: True`` and ``flush_cache`` internally.

For LoRA adapters (``use_lora=True``): unloads the old adapter,
loads the new one, then flushes the KV cache via ``/flush_cache``
to discard stale entries computed with the old LoRA weights.
Relies on sglang releasing the ``lora_registry`` counter for
aborted requests (fixed upstream in
``TokenizerManager._handle_abort_finish_reason`` as of 0.5.12).
For LoRA adapters (``use_lora=True``): loads the new adapter under a
fresh versioned name (``lora_v{seq}``) without explicitly unloading the
previous one, then flushes the KV cache. SGLang's registry LRU evicts
old versions once ``max_loaded_loras`` is reached and its mem-pool LRU
reclaims GPU slots (bounded by ``max_loras_per_batch``); an evicted
adapter is transparently re-loaded on next use.

Historically, explicitly unloading an adapter that still had
paused/aborted in-flight requests deadlocked SGLang's ``wait_for_unload``
because the adapter's usage counter was never released on abort. That
leak is now fixed at the source by ``LoRACounterLeakPatch``
(``astraflow/raas/patch/sglang.py``), so unload/eviction is safe. We keep
the fresh-name scheme because it stays correct without draining under
``lora_update_lock`` on every sync.
"""
import time as _time

_t0 = _time.monotonic()
lora_name = "lora_1"

if use_lora:
# Load under a NEW versioned name and do NOT explicitly unload the
# old one. The abort-time usage-counter leak that used to make
# ``wait_for_unload`` (and thus registry-LRU eviction) hang is fixed
# by LoRACounterLeakPatch, so eviction is safe; the fresh name also
# avoids draining under ``lora_update_lock`` on every sync.
self._lora_seq += 1
lora_name = f"lora_v{self._lora_seq}"
logger.info(
"load_weights_from_path: sending /load_lora_adapter "
"to %d servers (path=%s, lora_name=%s) ...",
Expand All @@ -766,19 +788,13 @@ def load_weights_from_path(
lora_name,
)
try:
if self.lora_initialized:
unload_req = HttpRequest(
endpoint="/unload_lora_adapter",
payload={"lora_name": lora_name},
)
self._run_request_on_all_servers(unload_req)

load_req = HttpRequest(
endpoint="/load_lora_adapter",
payload={"lora_name": lora_name, "lora_path": str(path)},
)
self._run_request_on_all_servers(load_req)
self.lora_initialized = True
self._current_lora_name = lora_name

# Flush stale KV cache entries computed with old LoRA weights.
# Safe because caller already paused generation (is_pause=True
Expand Down
12 changes: 8 additions & 4 deletions astraflow/raas/engine/sglang_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ class SGLangBackend:
"""Backend that translates engine operations into SGLang HTTP API calls."""

def build_generation_request(
self, req: ModelRequest, with_lora: bool
self, req: ModelRequest, lora_name: str | None
) -> HttpRequest:
"""Convert a ModelRequest into an SGLang /generate HTTP request."""
"""Convert a ModelRequest into an SGLang /generate HTTP request.

``lora_name`` is the currently-active versioned adapter name (e.g.
``lora_v3``) or ``None`` when no adapter is loaded.
"""
gconfig = req.gconfig
stop_token_ids = gconfig.stop_token_ids
stop = gconfig.stop
Expand Down Expand Up @@ -55,8 +59,8 @@ def build_generation_request(
"stream": False,
}

if with_lora:
payload["lora_path"] = "lora_1"
if lora_name:
payload["lora_path"] = lora_name

return HttpRequest(endpoint="/generate", payload=payload)

Expand Down
15 changes: 12 additions & 3 deletions astraflow/raas/engine/vllm_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,14 @@ def __init__(self):
pass

def build_generation_request(
self, req: ModelRequest, with_lora: bool
self, req: ModelRequest, lora_name: str | None
) -> HttpRequest:
"""Convert a ModelRequest into a vLLM completions or chat HTTP request."""
"""Convert a ModelRequest into a vLLM completions or chat HTTP request.

``lora_name`` is a truthy marker that a LoRA is active; vLLM selects
the adapter via ``gconfig.lora_name`` (its own naming), so the marker's
value is unused here.
"""
gconfig = req.gconfig
stop_token_ids = gconfig.stop_token_ids
stop = gconfig.stop
Expand All @@ -54,7 +59,7 @@ def build_generation_request(
if stop:
payload["stop"] = stop

if with_lora and len(gconfig.lora_name) > 0:
if lora_name and len(gconfig.lora_name) > 0:
payload["model"] = gconfig.lora_name

if req.vision_msg_vllm:
Expand Down Expand Up @@ -181,6 +186,10 @@ def __init__(self, config: InferenceEngineConfig):
self.config = config
self._engine = RemoteInfEngine(config, VLLMBackend())
self._engine.lora_initialized = config.use_lora
# vLLM selects the adapter via gconfig.lora_name; this just marks LoRA
# active so the shared generation-request builder passes a truthy flag.
if config.use_lora:
self._engine._current_lora_name = "vllm_lora"

def __getattr__(self, name: str):
return getattr(self._engine, name)
Expand Down
2 changes: 2 additions & 0 deletions astraflow/raas/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,14 @@ def _validate_patch_results(results: Dict[str, bool], strict: bool) -> None:
def _run_sglang_patches(strict: bool) -> bool:
from astraflow.raas.patch.sglang import (
HttpServerPatch,
LoRACounterLeakPatch,
ServerArgsPatch,
)

manager = PatchManager()
manager.register(ServerArgsPatch())
manager.register(HttpServerPatch())
manager.register(LoRACounterLeakPatch())

results = manager.apply_all()
_log_patch_results(results)
Expand Down
113 changes: 113 additions & 0 deletions astraflow/raas/patch/sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
can register with RaaS at startup.
2. HttpServerPatch — register SGLang instance with the rollout manager
during ``launch_server``.
3. LoRACounterLeakPatch — guarantee the LoRA adapter usage counter is
released for every request, including aborted / client-disconnected ones,
fixing a weight-sync deadlock at its source (see the class docstring).
"""

import logging
Expand All @@ -16,6 +19,37 @@
logger = logging.getLogger(__name__)


async def release_lora_ref_once(tm, sub_obj) -> None:
"""Release ``sub_obj``'s LoRA usage counter on ``tm`` (TokenizerManager)
exactly once, if it is still held.

Idempotency invariant: SGLang's two native release sites both
``del rid_to_state[rid]`` immediately before releasing, so ``rid in
rid_to_state`` iff the request has NOT yet been released. The membership
check and ``pop`` have no ``await`` between them, so they are atomic on the
single-threaded event loop — guaranteeing release is awaited at most once
per request. This matters because ``ConcurrentCounter.decrement`` has no
floor: a double-release would drive the counter to -1 and make
``wait_for_zero`` (hence ``wait_for_unload``) hang forever.
"""
if not getattr(tm.server_args, "enable_lora", False):
return
if not getattr(sub_obj, "lora_path", None):
return
rid = getattr(sub_obj, "rid", None)
if rid is None or rid not in tm.rid_to_state:
return
tm.rid_to_state.pop(rid, None)
lora_id = getattr(sub_obj, "lora_id", None)
if lora_id is not None:
try:
await tm.lora_registry.release(lora_id)
except Exception:
logger.exception(
"release_lora_ref_once: release failed for rid=%s", rid
)


class ServerArgsPatch(BasePatch):
"""Add ``--rollout-manager-address`` to SGLang's ServerArgs."""

Expand Down Expand Up @@ -94,3 +128,82 @@ def patched_launch_server(server_args, *args, **kwargs):

traceback.print_exc()
return False


class LoRACounterLeakPatch(BasePatch):
"""Release the LoRA adapter usage counter on EVERY request teardown.

Root cause of the LoRA weight-sync deadlock: SGLang's ``LoRARegistry`` keeps
a per-adapter ``ConcurrentCounter`` (``lora/lora_registry.py``). It is
``acquire()``-ed for every generate request but ``release()``-ed only on two
conditional branches in the tokenizer manager — normal completion
(``_handle_batch_output``) and one scheduler-abort case (``_wait_one_response``,
status SERVICE_UNAVAILABLE / INTERNAL_SERVER_ERROR). Requests that are aborted
or whose client disconnects (which the RaaS per-step drain routinely creates)
exit ``_wait_one_response`` without releasing — via a ``raise`` (client
disconnect, BAD_REQUEST) or a plain ``break`` (waiting-queue abort). The
adapter's counter then never returns to zero, so ``LoRARegistry.wait_for_unload``
blocks forever. That hangs both an explicit ``/unload_lora_adapter`` AND the
``load_lora_adapter`` LRU eviction that fires once ``max_loaded_loras`` versioned
adapters accumulate — while holding ``lora_update_lock``, freezing all further
LoRA ops. (The RaaS versioned-name scheme merely defers this to ~``max_loaded_loras``
steps; this patch removes the leak so unload/eviction is always safe.)

Fix: wrap ``TokenizerManager.generate_request`` — the single outermost
per-request async generator, where ``acquire`` happens (via
``_validate_and_resolve_lora``) — and release in a ``finally`` so it runs on
every exit (normal return, raise, ``GeneratorExit``, ``CancelledError``).
Release is idempotent via the invariant that both native release sites
``del rid_to_state[rid]`` immediately before releasing: ``rid in rid_to_state``
iff not yet released. The membership check and ``pop`` have no ``await``
between them, so they are atomic on the single-threaded event loop — no
double-release (which would drive the counter to -1 and hang ``wait_for_zero``
permanently, since ``ConcurrentCounter.decrement`` has no floor).
"""

def apply(self) -> bool:
import os

if os.getenv("ASTRAFLOW_DISABLE_LORA_LEAK_FIX", "0").lower() in ("1", "true"):
logger.warning(
"LoRACounterLeakPatch disabled via ASTRAFLOW_DISABLE_LORA_LEAK_FIX; "
"LoRA weight-sync may deadlock on registry-LRU eviction."
)
return True

try:
from sglang.srt.managers.tokenizer_manager import TokenizerManager
except Exception as e:
logger.error(f"LoRACounterLeakPatch failed: {e}")
return False

original_generate_request = TokenizerManager.generate_request
if self._is_patched(original_generate_request, "generate_request"):
return True

async def patched_generate_request(self, obj, request=None):
try:
async for response in original_generate_request(self, obj, request):
yield response
finally:
# Guaranteed release on every exit path (normal, raise,
# GeneratorExit, CancelledError). ``obj`` has been normalized by
# ``original_generate_request`` before it reached the scheduler.
try:
if getattr(obj, "is_single", True):
await release_lora_ref_once(self, obj)
else:
# Batch request: release each sub-request that still
# holds its counter. (RaaS rollouts are single;
# best-effort.)
rids = getattr(obj, "rid", None)
if isinstance(rids, (list, tuple)):
for i in range(len(rids)):
await release_lora_ref_once(self, obj[i])
except Exception:
logger.exception("LoRACounterLeakPatch cleanup error")

self._mark_as_patched(patched_generate_request, "generate_request")
TokenizerManager.generate_request = patched_generate_request

return True
Empty file.
Loading
Loading