feat(RL): add RL support for verl#1298

Open
shihaobai wants to merge 196 commits into main from rl_verl_rebase_main

Conversation

@shihaobai
Collaborator

No description provided.

shihaobai and others added 13 commits March 26, 2026 13:58
Adds multi-instance port isolation to allow multiple LightLLM servers
on the same machine without port conflicts. Each instance gets a
dedicated 1000-port range (instance 0: 10000-10999, etc.); a minimal
sketch of the scheme follows the change list.

Changes:
- Added --lightllm_instance_id CLI arg (0-7) for instance selection
- Refactored port allocation to use deterministic ranges instead of
  random selection via portpicker
- Removed portpicker dependency from requirements.txt
- Base port configurable via LIGHTLLM_BASE_PORT env var
- Removed SO_REUSEADDR from port probe to avoid false positives
- Simplified to single linear scan (removed ineffective retry logic)
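
A minimal sketch of the deterministic allocation described above, assuming the 1000-port-per-instance layout and the LIGHTLLM_BASE_PORT env var from the change list; the helper names allocate_port and _port_is_free are illustrative, not the actual LightLLM functions:

    import os
    import socket

    PORTS_PER_INSTANCE = 1000

    def _port_is_free(port: int) -> bool:
        # Probe without SO_REUSEADDR so ports in TIME_WAIT are not reported as free.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("127.0.0.1", port))
                return True
            except OSError:
                return False

    def allocate_port(instance_id: int) -> int:
        # Instance 0 owns 10000-10999, instance 1 owns 11000-11999, and so on.
        base = int(os.environ.get("LIGHTLLM_BASE_PORT", "10000"))
        start = base + instance_id * PORTS_PER_INSTANCE
        # Single linear scan over the instance's dedicated range (no retries).
        for port in range(start, start + PORTS_PER_INSTANCE):
            if _port_is_free(port):
                return port
        raise RuntimeError(f"no free port in {start}-{start + PORTS_PER_INSTANCE - 1}")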

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces several significant enhancements to the LightLLM framework, including support for Multi-Token Prediction (MTP) with optimized Triton kernels, memory management improvements (such as a torch memory saver), and support for new model architectures like NeoChat. It also adds a routing data capture mechanism for MoE models and improves the robustness of the server launch process. My review identified several critical issues in the new MTP state management and kernel logic, including incorrect assertion logic, device mismatch errors when indexing tensors, and incorrect method names being called. I have provided specific suggestions to address these bugs.

):
    start_args = get_env_start_args()
    if self.size is not None:
        assert self.size < start_args.running_max_req_size * 2, (

critical

The assertion logic here is inverted. It should check that self.size is at least running_max_req_size * 2, as indicated by the error message on the following lines.

Suggested change
- assert self.size < start_args.running_max_req_size * 2, (
+ assert self.size >= start_args.running_max_req_size * 2, (

Comment on lines +279 to +280
if mask.sum() > 0:
    actual_req_idxes = model_input.b_req_idx[b_req_mtp_start_loc[mask]]

critical

There are two issues here:

  1. mask.sum() > 0 on a boolean tensor returns a 0-d tensor rather than a Python bool; it is safer to use .any().
  2. b_req_mtp_start_loc is a list (initialized at line 255), so it cannot be indexed by a boolean tensor mask. Convert it to a tensor or use a list comprehension (a standalone repro follows the suggested change).
Suggested change
- if mask.sum() > 0:
-     actual_req_idxes = model_input.b_req_idx[b_req_mtp_start_loc[mask]]
+ if mask.any():
+     mask_cpu = mask.cpu()
+     actual_req_idxes = model_input.b_req_idx[[b_req_mtp_start_loc[i] for i, m in enumerate(mask_cpu) if m]]
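
A standalone repro of both pitfalls (toy values, not LightLLM code):

    import torch

    mask = torch.tensor([True, False, True])
    print(mask.sum() > 0)       # tensor(True): a 0-d bool tensor, not a Python bool
    print(mask.any().item())    # True

    b_req_mtp_start_loc = [0, 2, 5]   # a plain Python list, as in the reviewed code
    # b_req_mtp_start_loc[mask] would raise here: a list cannot be indexed by a tensor mask.

    # Either convert the list to a tensor first ...
    print(torch.tensor(b_req_mtp_start_loc)[mask])                    # tensor([0, 5])
    # ... or filter it with a comprehension, as in the suggested change above.
    print([b_req_mtp_start_loc[i] for i, m in enumerate(mask) if m])  # [0, 5]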

Comment on lines +485 to +487
if mask.sum() > 0:
    actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]]
    src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[

critical

Similar to the issue in chunked_prefill/impl.py:

  1. mask.sum() > 0 should be mask.any().
  2. b_req_mtp_start_loc is a list and cannot be indexed by a tensor.
  3. mtp_accept_len[mask] should be moved to CPU to avoid device mismatch with req_to_buffer_index.
Suggested change
- if mask.sum() > 0:
-     actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]]
-     src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[
+ if mask.any():
+     mask_cpu = mask.cpu()
+     actual_req_idxes = b_req_idx[[b_req_mtp_start_loc[i] for i, m in enumerate(mask_cpu) if m]]
+     src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[
+         actual_req_idxes, mtp_accept_len[mask].cpu() - 1
+     ]

Comment thread: lightllm/common/basemodel/basemodel.py (Outdated)
Comment on lines +1235 to +1241
def resume_all(self):
    torch.cuda.empty_cache()
    gc.collect()
    self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
    self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
    self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
    self.mem_manager.free_all()

high

The resume_all method is missing a call to self.req_manager.resume(), which is present in resume_kv_cache. Without this, the request manager might not be properly re-initialized after a memory resume operation.

Suggested change
- def resume_all(self):
-     torch.cuda.empty_cache()
-     gc.collect()
-     self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
-     self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
-     self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
-     self.mem_manager.free_all()
+ def resume_all(self):
+     torch.cuda.empty_cache()
+     gc.collect()
+     self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
+     self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
+     self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
+     self.mem_manager.free_all()
+     self.req_manager.resume()

    cur_group_reqs, is_busy, new_batch_first_router_need_tokens
)
- if ok_insert:
+ if ok_insert and False:

high

The condition if ok_insert and False: effectively disables adding new group requests to the running list. This looks like debug code that was accidentally left in.

Suggested change
- if ok_insert and False:
+ if ok_insert:

actual_req_idxes = model_input.b_req_idx[b_req_mtp_start_loc[mask]]
# Source: the accepted buffer (at index accept_len - 1)
src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[
    actual_req_idxes, mtp_accept_len[mask] - 1

high

Potential device mismatch. mtp_accept_len is a GPU tensor, while req_to_buffer_index is likely a CPU tensor managed by ReqManager. You should move the indices to CPU before indexing.

                src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[
                    actual_req_idxes, mtp_accept_len[mask].cpu() - 1
                ]
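
A standalone repro of the mismatch (illustrative shapes; assumes a CUDA device is available):

    import torch

    req_to_buffer_index = torch.zeros(16, 8, dtype=torch.long)    # CPU tensor, as ReqManager likely keeps it
    actual_req_idxes = torch.tensor([1, 3])                       # CPU indices
    mtp_accept_len_masked = torch.tensor([2, 1], device="cuda")   # GPU tensor, like mtp_accept_len[mask]

    try:
        req_to_buffer_index[actual_req_idxes, mtp_accept_len_masked - 1]
    except RuntimeError as e:
        print(e)  # indices should be either on cpu or on the same device as the indexed tensor

    # Moving the GPU indices to CPU first makes the advanced indexing consistent.
    ok = req_to_buffer_index[actual_req_idxes, (mtp_accept_len_masked - 1).cpu()]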

Comment on lines +491 to +494
if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"):
    g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p(
        src_buffer_indexes, dst_buffer_indexes
    )

high

The method name in MambaCacheManager is copy_state_buffers, not copy_buffer_p2p. The current check and call will fail to execute the state copy.

Suggested change
- if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"):
-     g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p(
-         src_buffer_indexes, dst_buffer_indexes
-     )
+ if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_state_buffers"):
+     g_infer_context.req_manager.buffer_mem_manager.copy_state_buffers(
+         src_buffer_indexes, dst_buffer_indexes
+     )
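
A tiny illustration of why the wrong name fails silently rather than raising (toy class, not the real MambaCacheManager):

    class _ToyCacheManager:
        def copy_state_buffers(self, src, dst):
            print("copied", src, "->", dst)

    mgr = _ToyCacheManager()
    if hasattr(mgr, "copy_buffer_p2p"):      # False: misspelled name, the copy is silently skipped
        mgr.copy_buffer_p2p([0], [1])
    if hasattr(mgr, "copy_state_buffers"):   # True: the actual method name, the copy runs
        mgr.copy_state_buffers([0], [1])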

# NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
# https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
# We need to determine the appropriate parameter name based on PyTorch version
pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"

medium

Using string comparison on torch.__version__ (e.g., '2.10' < '2.6') can lead to incorrect results. It is safer to use version.parse since it is already imported.

Suggested change
- pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
+ pg_options_param_name = "backend_options" if version.parse(torch.__version__) >= version.parse("2.6") else "pg_options"
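
A quick illustration of the pitfall, assuming version.parse comes from the packaging library:

    from packaging import version

    print("2.10" >= "2.6")                                # False: strings compare character by character
    print(version.parse("2.10") >= version.parse("2.6"))  # True: compared as version numbers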

@shihaobai requested review from hiworldwzj and kingder on May 9, 2026, 05:34