diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index a793188727..8001ceb73d 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -33,6 +33,18 @@ index 6fbd1db82..4c681b58d 100644 ) elif not needs_tf_v5: logger.warning( +diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py +index da4629e52..c03f98231 100644 +--- a/python/sglang/srt/disaggregation/base/conn.py ++++ b/python/sglang/srt/disaggregation/base/conn.py +@@ -17,6 +17,7 @@ class KVArgs: + kv_data_ptrs: List[int] + kv_data_lens: List[int] + kv_item_lens: List[int] ++ aux_buffer_names: List[str] + aux_data_ptrs: List[int] + aux_data_lens: List[int] + aux_item_lens: List[int] diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 67fe82ad6..2ef25c49b 100644 --- a/python/sglang/srt/disaggregation/common/conn.py @@ -114,7 +126,7 @@ index 67fe82ad6..2ef25c49b 100644 "prefill_pp_size": self.pp_size, "prefill_page_size": self.page_size, diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py -index 1d8baf002..8f0a28741 100644 +index 1d8baf002..1ebb95929 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -21,6 +21,7 @@ Life cycle of a request in the decode server @@ -125,7 +137,26 @@ index 1d8baf002..8f0a28741 100644 import time from collections import deque from dataclasses import dataclass -@@ -336,6 +337,16 @@ class DecodePreallocQueue: +@@ -40,8 +41,10 @@ from sglang.srt.disaggregation.utils import ( + MetadataBuffers, + ReqToMetadataIdxAllocator, + TransferBackend, ++ apply_prefill_timing_payload, + get_kv_class, + is_mla_backend, ++ is_slime_profiling_enabled, + kv_to_page_indices, + poll_and_all_reduce, + prepare_abort, +@@ -295,6 +298,7 @@ class DecodePreallocQueue: + kv_args.aux_data_ptrs, kv_args.aux_data_lens, 
kv_args.aux_item_lens = ( + self.metadata_buffers.get_buf_infos() + ) ++ kv_args.aux_buffer_names = self.metadata_buffers.get_aux_buffer_names() + + if hasattr(self.token_to_kv_pool, "get_state_buf_infos"): + state_data_ptrs, state_data_lens, state_item_lens = ( +@@ -336,6 +340,16 @@ class DecodePreallocQueue: ) return kv_manager @@ -142,7 +173,7 @@ index 1d8baf002..8f0a28741 100644 def add(self, req: Req, is_retracted: bool = False) -> None: """Add a request to the pending queue.""" if self._check_if_req_exceed_kv_capacity(req): -@@ -440,12 +451,37 @@ class DecodePreallocQueue: +@@ -440,12 +454,37 @@ class DecodePreallocQueue: [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group ) @@ -181,7 +212,38 @@ index 1d8baf002..8f0a28741 100644 elif poll == KVPoll.WaitingForInput: decode_req.waiting_for_input = True elif poll == KVPoll.Failed: -@@ -830,6 +866,13 @@ class DecodeTransferQueue: +@@ -590,6 +629,7 @@ class DecodePreallocQueue: + self.req_to_metadata_buffer_idx_allocator.alloc() + ) + assert decode_req.metadata_buffer_index is not None ++ self.metadata_buffers.clear_profiling_buf(decode_req.metadata_buffer_index) + page_indices = kv_to_page_indices(kv_indices, page_size) + decode_req.kv_receiver.init( + page_indices, decode_req.metadata_buffer_index, state_indices +@@ -751,6 +791,7 @@ class DecodeTransferQueue: + output_topk_index, + output_hidden_states, + output_bootstrap_room, ++ output_prefill_timing, + ) = self.metadata_buffers.get_buf(idx) + + # Validate bootstrap_room to detect context corruption +@@ -813,6 +854,14 @@ class DecodeTransferQueue: + output_top_logprobs_idx[: decode_req.req.top_logprobs_num].tolist() + ) + ++ # Inject prefill-side PD timing forwarded from the P instance. 
++ # Layout: [bootstrap_queue, forward, transfer_queue, bootstrap, ++ # alloc_waiting, transfer_speed, transfer_mb, retry_count] ++ if is_slime_profiling_enabled(): ++ apply_prefill_timing_payload( ++ decode_req.req.time_stats, output_prefill_timing ++ ) ++ + decode_req.kv_receiver.clear() + decode_req.kv_receiver = None + trace_slice_end( +@@ -830,6 +879,13 @@ class DecodeTransferQueue: [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group ) @@ -195,7 +257,7 @@ index 1d8baf002..8f0a28741 100644 transferred_reqs = [] indices_to_remove = set() for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): -@@ -877,7 +920,20 @@ class DecodeTransferQueue: +@@ -877,7 +933,20 @@ class DecodeTransferQueue: KVPoll.WaitingForInput, KVPoll.Transferring, ]: @@ -217,7 +279,7 @@ index 1d8baf002..8f0a28741 100644 else: raise ValueError(f"Unexpected poll case: {poll}") -@@ -893,6 +949,14 @@ class DecodeTransferQueue: +@@ -893,6 +962,14 @@ class DecodeTransferQueue: return transferred_reqs @@ -232,7 +294,7 @@ index 1d8baf002..8f0a28741 100644 class SchedulerDisaggregationDecodeMixin: -@@ -1072,7 +1136,15 @@ class SchedulerDisaggregationDecodeMixin: +@@ -1072,7 +1149,15 @@ class SchedulerDisaggregationDecodeMixin: resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs() self.waiting_queue.extend(resumed_reqs) if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0: @@ -291,9 +353,18 @@ index a2d08e0e3..ed0790604 100644 mm_item = MultimodalDataItem.from_dict( { diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py -index d0d4efd95..960f585e2 100644 +index d0d4efd95..b3a207063 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py +@@ -30,7 +30,7 @@ from sglang.srt.disaggregation.common.utils import ( + from sglang.srt.disaggregation.mooncake.utils import ( + check_mooncake_custom_mem_pool_enabled, + ) +-from 
sglang.srt.disaggregation.utils import DisaggregationMode ++from sglang.srt.disaggregation.utils import DisaggregationMode, iter_aux_transfer_specs + from sglang.srt.distributed.parallel_state import get_mooncake_transfer_engine + from sglang.srt.environ import envs + from sglang.srt.server_args import ServerArgs @@ -260,6 +260,19 @@ class MooncakeKVManager(CommonKVManager): self.kv_args.state_data_ptrs, self.kv_args.state_data_lens ) @@ -314,7 +385,44 @@ index d0d4efd95..960f585e2 100644 def _transfer_data(self, mooncake_session_id, transfer_blocks): if not transfer_blocks: return 0 -@@ -643,13 +656,13 @@ class MooncakeKVManager(CommonKVManager): +@@ -524,10 +537,14 @@ class MooncakeKVManager(CommonKVManager): + prefill_aux_ptrs = self.kv_args.aux_data_ptrs + prefill_aux_item_lens = self.kv_args.aux_item_lens + +- for i, dst_aux_ptr in enumerate(dst_aux_ptrs): +- length = prefill_aux_item_lens[i] +- src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index +- dst_addr = dst_aux_ptrs[i] + length * req.dst_aux_index ++ for _, src_addr, dst_addr, length in iter_aux_transfer_specs( ++ self.kv_args.aux_buffer_names, ++ prefill_aux_ptrs, ++ prefill_aux_item_lens, ++ dst_aux_ptrs, ++ prefill_aux_index, ++ req.dst_aux_index, ++ ): + transfer_blocks.append((src_addr, dst_addr, length)) + + return self._transfer_data(req.mooncake_session_id, transfer_blocks) +@@ -541,9 +558,14 @@ class MooncakeKVManager(CommonKVManager): + prefill_aux_ptrs = self.kv_args.aux_data_ptrs + prefill_aux_item_lens = self.kv_args.aux_item_lens + +- for i in range(len(prefill_aux_ptrs)): +- length = prefill_aux_item_lens[i] +- src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index ++ for i, src_addr, _, length in iter_aux_transfer_specs( ++ self.kv_args.aux_buffer_names, ++ prefill_aux_ptrs, ++ prefill_aux_item_lens, ++ dst_aux_ptrs, ++ prefill_aux_index, ++ req.dst_aux_index, ++ ): + data = AuxDataCodec.serialize_data_from_buffer(src_addr, length) + + self.send_aux_data_to_endpoint( +@@ 
-643,13 +665,13 @@ class MooncakeKVManager(CommonKVManager): raise RuntimeError( f"PD Disaggregation does NOT support PD different TP sizes for non-MLA {state_type.upper()} hybrid models yet." ) @@ -334,7 +442,7 @@ index d0d4efd95..960f585e2 100644 # Reuse _send_kvcache_generic interface to send extra pool data prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32) dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32) -@@ -858,12 +871,6 @@ class MooncakeKVManager(CommonKVManager): +@@ -858,12 +880,6 @@ class MooncakeKVManager(CommonKVManager): if ret != 0: with self.session_lock: self.session_failures[req.mooncake_session_id] += 1 @@ -347,7 +455,7 @@ index d0d4efd95..960f585e2 100644 self.record_failure( kv_chunk.room, f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}", -@@ -880,13 +887,31 @@ class MooncakeKVManager(CommonKVManager): +@@ -880,13 +896,31 @@ class MooncakeKVManager(CommonKVManager): if kv_chunk.is_last: if kv_chunk.state_indices is not None: @@ -380,7 +488,7 @@ index d0d4efd95..960f585e2 100644 # Only the last chunk we need to send the aux data ret = self.send_aux( -@@ -895,6 +920,11 @@ class MooncakeKVManager(CommonKVManager): +@@ -895,6 +929,11 @@ class MooncakeKVManager(CommonKVManager): target_rank_registration_info.dst_aux_ptrs, ) polls.append(True if ret == 0 else False) @@ -392,7 +500,7 @@ index d0d4efd95..960f585e2 100644 dst_ranks_infos.append( (req.endpoint, req.dst_port, req.room) ) -@@ -977,15 +1007,20 @@ class MooncakeKVManager(CommonKVManager): +@@ -977,15 +1016,20 @@ class MooncakeKVManager(CommonKVManager): if status == KVPoll.Success: if bootstrap_room in self.request_status: @@ -420,7 +528,7 @@ index d0d4efd95..960f585e2 100644 elif status == KVPoll.Failed: self.record_failure( bootstrap_room, -@@ -1266,7 +1301,10 @@ class MooncakeKVReceiver(CommonKVReceiver): +@@ -1266,7 +1310,10 @@ class MooncakeKVReceiver(CommonKVReceiver): super().__init__(mgr, bootstrap_addr, 
bootstrap_room, prefill_dp_rank) self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room) @@ -433,7 +541,7 @@ index d0d4efd95..960f585e2 100644 def _register_kv_args(self): for bootstrap_info in self.bootstrap_infos: diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py -index fbc801635..ff1f52796 100644 +index fbc801635..ade111c9f 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -20,6 +20,7 @@ Life cycle of a request in the prefill server @@ -444,7 +552,15 @@ index fbc801635..ff1f52796 100644 import time from collections import deque from http import HTTPStatus -@@ -276,6 +277,12 @@ class PrefillBootstrapQueue: +@@ -167,6 +168,7 @@ class PrefillBootstrapQueue: + kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = ( + self.metadata_buffers.get_buf_infos() + ) ++ kv_args.aux_buffer_names = self.metadata_buffers.get_aux_buffer_names() + kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device + kv_args.gpu_id = self.scheduler.gpu_id + +@@ -276,6 +278,12 @@ class PrefillBootstrapQueue: [req.disagg_kv_sender for req in self.queue], self.gloo_group ) @@ -457,7 +573,7 @@ index fbc801635..ff1f52796 100644 for i, (req, poll) in enumerate(zip(self.queue, polls)): if rids_to_check is not None: # if req not in reqs_info_to_check, skip -@@ -283,6 +290,27 @@ class PrefillBootstrapQueue: +@@ -283,6 +291,27 @@ class PrefillBootstrapQueue: continue if poll == KVPoll.Bootstrapping: @@ -485,7 +601,7 @@ index fbc801635..ff1f52796 100644 continue elif poll == KVPoll.Failed: error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}" -@@ -335,6 +363,15 @@ class PrefillBootstrapQueue: +@@ -335,6 +364,15 @@ class PrefillBootstrapQueue: else: return bootstrapped_reqs, failed_reqs @@ -501,7 +617,7 @@ index fbc801635..ff1f52796 100644 class SchedulerDisaggregationPrefillMixin: 
""" -@@ -547,6 +584,18 @@ class SchedulerDisaggregationPrefillMixin: +@@ -547,6 +585,18 @@ class SchedulerDisaggregationPrefillMixin: self.maybe_send_health_check_signal() @@ -520,7 +636,7 @@ index fbc801635..ff1f52796 100644 def process_disagg_prefill_inflight_queue( self: Scheduler, rids_to_check: Optional[List[str]] = None ) -> List[Req]: -@@ -559,11 +608,24 @@ class SchedulerDisaggregationPrefillMixin: +@@ -559,11 +609,24 @@ class SchedulerDisaggregationPrefillMixin: done_reqs = [] @@ -546,7 +662,7 @@ index fbc801635..ff1f52796 100644 undone_reqs: List[Req] = [] # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue for req, poll in zip(self.disagg_prefill_inflight_queue, polls): -@@ -573,10 +635,35 @@ class SchedulerDisaggregationPrefillMixin: +@@ -573,10 +636,35 @@ class SchedulerDisaggregationPrefillMixin: undone_reqs.append(req) continue @@ -584,7 +700,7 @@ index fbc801635..ff1f52796 100644 elif poll == KVPoll.Success: # transfer done release_kv_cache(req, self.tree_cache) # unlock the tree req.finished_reason = FINISH_LENGTH(length=0) -@@ -628,9 +715,12 @@ class SchedulerDisaggregationPrefillMixin: +@@ -628,9 +716,12 @@ class SchedulerDisaggregationPrefillMixin: """ Used by PP, get the transferred rids but **do not pop** """ @@ -598,6 +714,198 @@ index fbc801635..ff1f52796 100644 ) transferred_rids: List[str] = [] +diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py +index 6d58f415a..84723c342 100644 +--- a/python/sglang/srt/disaggregation/utils.py ++++ b/python/sglang/srt/disaggregation/utils.py +@@ -21,6 +21,17 @@ if TYPE_CHECKING: + # Constants & Enums + ######################### + FAKE_BOOTSTRAP_HOST = "2.2.2.2" ++PREFILL_TIMING_AUX_BUFFER_NAME = "prefill_timing" ++PREFILL_TIMING_DEST_ATTRS = ( ++ ("fwd_prefill_bootstrap_queue_duration", float), ++ ("fwd_prefill_forward_duration", float), ++ ("fwd_prefill_transfer_queue_duration", 
float), ++ ("fwd_bootstrap_duration", float), ++ ("fwd_alloc_waiting_duration", float), ++ ("fwd_transfer_speed_gb_s", float), ++ ("fwd_transfer_total_mb", float), ++ ("fwd_prefill_retry_count", int), ++) + + + class DisaggregationMode(Enum): +@@ -139,46 +150,35 @@ class MetadataBuffers: + self.bootstrap_room = torch.zeros( + (size, 8), dtype=torch.uint64, device=device + ) ++ # Prefill-side PD timing (8 floats, padded to 16 for RDMA alignment). ++ # Layout: [bootstrap_queue, forward, transfer_queue, bootstrap, ++ # alloc_waiting, transfer_speed, transfer_mb, retry_count] ++ self.prefill_timing = torch.zeros( ++ (size, 16), dtype=torch.float32, device=device ++ ) ++ self.aux_buffers = [ ++ ("output_ids", self.output_ids), ++ ("cached_tokens", self.cached_tokens), ++ ("output_token_logprobs_val", self.output_token_logprobs_val), ++ ("output_token_logprobs_idx", self.output_token_logprobs_idx), ++ ("output_top_logprobs_val", self.output_top_logprobs_val), ++ ("output_top_logprobs_idx", self.output_top_logprobs_idx), ++ ("output_topk_p", self.output_topk_p), ++ ("output_topk_index", self.output_topk_index), ++ ("output_hidden_states", self.output_hidden_states), ++ ("bootstrap_room", self.bootstrap_room), ++ (PREFILL_TIMING_AUX_BUFFER_NAME, self.prefill_timing), ++ ] + + def get_buf_infos(self): +- ptrs = [ +- self.output_ids.data_ptr(), +- self.cached_tokens.data_ptr(), +- self.output_token_logprobs_val.data_ptr(), +- self.output_token_logprobs_idx.data_ptr(), +- self.output_top_logprobs_val.data_ptr(), +- self.output_top_logprobs_idx.data_ptr(), +- self.output_topk_p.data_ptr(), +- self.output_topk_index.data_ptr(), +- self.output_hidden_states.data_ptr(), +- self.bootstrap_room.data_ptr(), +- ] +- data_lens = [ +- self.output_ids.nbytes, +- self.cached_tokens.nbytes, +- self.output_token_logprobs_val.nbytes, +- self.output_token_logprobs_idx.nbytes, +- self.output_top_logprobs_val.nbytes, +- self.output_top_logprobs_idx.nbytes, +- self.output_topk_p.nbytes, +- 
self.output_topk_index.nbytes, +- self.output_hidden_states.nbytes, +- self.bootstrap_room.nbytes, +- ] +- item_lens = [ +- self.output_ids[0].nbytes, +- self.cached_tokens[0].nbytes, +- self.output_token_logprobs_val[0].nbytes, +- self.output_token_logprobs_idx[0].nbytes, +- self.output_top_logprobs_val[0].nbytes, +- self.output_top_logprobs_idx[0].nbytes, +- self.output_topk_p[0].nbytes, +- self.output_topk_index[0].nbytes, +- self.output_hidden_states[0].nbytes, +- self.bootstrap_room[0].nbytes, +- ] ++ ptrs = [buffer.data_ptr() for _, buffer in self.aux_buffers] ++ data_lens = [buffer.nbytes for _, buffer in self.aux_buffers] ++ item_lens = [buffer[0].nbytes for _, buffer in self.aux_buffers] + return ptrs, data_lens, item_lens + ++ def get_aux_buffer_names(self): ++ return [name for name, _ in self.aux_buffers] ++ + def get_buf(self, idx: int): + return ( + self.output_ids[idx], +@@ -191,8 +191,12 @@ class MetadataBuffers: + self.output_topk_index[idx], + self.output_hidden_states[idx], + self.bootstrap_room[idx], ++ self.prefill_timing[idx], + ) + ++ def clear_profiling_buf(self, idx: int): ++ self.prefill_timing[idx].zero_() ++ + def set_buf(self, req: Req): + + self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0] +@@ -237,6 +241,84 @@ class MetadataBuffers: + self.bootstrap_room[req.metadata_buffer_index, 0] = ( + req.bootstrap_room if req.bootstrap_room is not None else 0 + ) ++ # Pack prefill-side PD timing durations for transfer to decode instance. ++ # Note: set_buf is called at the START of the last KV chunk send, so ++ # completion_time and prefill_transfer_queue_entry_time are not yet set. ++ # We use time.perf_counter() as the "forward just completed" timestamp. 
++ import time ++ ++ ts = req.time_stats ++ timing = self.prefill_timing[req.metadata_buffer_index] ++ self.clear_profiling_buf(req.metadata_buffer_index) ++ if not is_slime_profiling_enabled(): ++ return ++ for idx, value in enumerate( ++ build_prefill_timing_payload(ts, now=time.perf_counter()) ++ ): ++ if value > 0: ++ timing[idx] = value ++ ++ ++def is_slime_profiling_enabled() -> bool: ++ return envs.SLIME_ENABLE_PROFILING.get() ++ ++ ++def build_prefill_timing_payload(time_stats, now: float) -> tuple[float, ...]: ++ bootstrap_queue_duration = 0.0 ++ if ( ++ time_stats.prefill_bootstrap_queue_entry_time > 0 ++ and time_stats.wait_queue_entry_time > 0 ++ ): ++ bootstrap_queue_duration = ( ++ time_stats.wait_queue_entry_time ++ - time_stats.prefill_bootstrap_queue_entry_time ++ ) ++ ++ prefill_forward_duration = ( ++ now - time_stats.forward_entry_time ++ if time_stats.forward_entry_time > 0 ++ else 0.0 ++ ) ++ ++ return ( ++ bootstrap_queue_duration, ++ prefill_forward_duration, ++ 0.0, ++ max(0.0, time_stats.bootstrap_duration), ++ max(0.0, time_stats.alloc_waiting_duration), ++ max(0.0, time_stats.transfer_speed_gb_s), ++ max(0.0, time_stats.transfer_total_mb), ++ float(max(0, time_stats.prefill_retry_count)), ++ ) ++ ++ ++def apply_prefill_timing_payload(time_stats, timing) -> None: ++ for value, (attr_name, caster) in zip( ++ timing[: len(PREFILL_TIMING_DEST_ATTRS)].tolist(), ++ PREFILL_TIMING_DEST_ATTRS, ++ ): ++ if value > 0: ++ setattr(time_stats, attr_name, caster(value)) ++ ++ ++def iter_aux_transfer_specs( ++ aux_buffer_names: list[str], ++ prefill_aux_ptrs: list[int], ++ prefill_aux_item_lens: list[int], ++ dst_aux_ptrs: list[int], ++ prefill_aux_index: int, ++ dst_aux_index: int, ++): ++ profiling_enabled = is_slime_profiling_enabled() ++ for i, (buffer_name, dst_aux_ptr) in enumerate(zip(aux_buffer_names, dst_aux_ptrs)): ++ if not profiling_enabled and buffer_name == PREFILL_TIMING_AUX_BUFFER_NAME: ++ continue ++ length = prefill_aux_item_lens[i] 
++ if length <= 0: ++ continue ++ src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index ++ dst_addr = dst_aux_ptr + length * dst_aux_index ++ yield i, src_addr, dst_addr, length + + + ######################### diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 8f1069c00..e47589295 100644 --- a/python/sglang/srt/distributed/parallel_state.py @@ -723,6 +1031,18 @@ index 1d6816c01..402b42e05 100644 @app.post("/update_weight_version") @auth_level(AuthLevel.ADMIN_OPTIONAL) async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): +diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py +index 8293796a2..bff34e422 100644 +--- a/python/sglang/srt/environ.py ++++ b/python/sglang/srt/environ.py +@@ -244,6 +244,7 @@ class Envs: + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE = EnvInt(2) + SGLANG_DISAGGREGATION_WAITING_TIMEOUT = EnvInt(300) + SGLANG_DISAGGREGATION_NIXL_BACKEND = EnvStr("UCX") ++ SLIME_ENABLE_PROFILING = EnvBool(False) + + # Scheduler: others: + SGLANG_EMPTY_CACHE_INTERVAL = EnvFloat(-1) # in seconds. Set if you observe high memory accumulation over a long serving period. 
diff --git a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py index 1cdf65b91..4783cd18f 100644 --- a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py @@ -1348,7 +1668,7 @@ index 00bd68755..12d5577af 100644 def get_routed_experts( diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py -index 8539639d5..e7f5d1565 100644 +index 8539639d5..d44496c2f 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py @@ -388,6 +388,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): @@ -1387,7 +1707,7 @@ index 8539639d5..e7f5d1565 100644 output = hidden_states else: raise NotImplementedError() # triton runner was supported but it's temporarily disabled -@@ -551,10 +562,12 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): +@@ -551,10 +562,18 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): buffer = self._get_buffer() topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids topk_ids = topk_ids.to(torch.int64) @@ -1395,16 +1715,22 @@ index 8539639d5..e7f5d1565 100644 - hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1] - + self.num_experts - ) // self.num_experts -+ # Use a correctness-preserving upper bound for per-expert token count. -+ # In the worst case, every rank routes all local tokens to the same expert. -+ expected_m = min( -+ hidden_states.shape[0] * buffer.group_size, -+ self.num_max_dispatch_tokens_per_rank * buffer.group_size, -+ ) ++ if self.quant_config.get("bf16_weights", False): ++ # BF16 low-latency path slices hidden_states[:, :expected_m, :], so ++ # expected_m must remain a correctness-preserving upper bound. 
++ expected_m = min( ++ hidden_states.shape[0] * buffer.group_size, ++ self.num_max_dispatch_tokens_per_rank * buffer.group_size, ++ ) ++ else: ++ expected_m = ( ++ hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1] ++ + self.num_experts ++ ) // self.num_experts hidden_states, masked_m, event, hook = self._dispatch_core( hidden_states, topk_ids, -@@ -609,7 +622,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): +@@ -609,7 +628,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): input_global_scale = self.quant_config.get("input_global_scale", None) if input_global_scale is not None: use_nvfp4 = True @@ -1581,11 +1907,76 @@ index ae0614635..3b6a8d254 100644 # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096 assert ( fused_set_kv_buffer_arg is None +diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py +index 652227860..7d3a5d0c4 100644 +--- a/python/sglang/srt/managers/detokenizer_manager.py ++++ b/python/sglang/srt/managers/detokenizer_manager.py +@@ -405,6 +405,17 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): + prefill_launch_delay=recv_obj.prefill_launch_delay, + prefill_launch_latency=recv_obj.prefill_launch_latency, + prefill_finished_ts=recv_obj.prefill_finished_ts, ++ pd_prefill_bootstrap_queue_duration=recv_obj.pd_prefill_bootstrap_queue_duration, ++ pd_prefill_forward_duration=recv_obj.pd_prefill_forward_duration, ++ pd_prefill_transfer_queue_duration=recv_obj.pd_prefill_transfer_queue_duration, ++ pd_decode_prealloc_duration=recv_obj.pd_decode_prealloc_duration, ++ pd_decode_transfer_duration=recv_obj.pd_decode_transfer_duration, ++ pd_decode_forward_duration=recv_obj.pd_decode_forward_duration, ++ pd_bootstrap_duration=recv_obj.pd_bootstrap_duration, ++ pd_alloc_waiting_duration=recv_obj.pd_alloc_waiting_duration, ++ pd_transfer_speed_gb_s=recv_obj.pd_transfer_speed_gb_s, ++ 
pd_transfer_total_mb=recv_obj.pd_transfer_total_mb, ++ pd_prefill_retry_count=recv_obj.pd_prefill_retry_count, + ) + + def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq): diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py -index ff1774567..42d27a82a 100644 +index ff1774567..f947e71d7 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py -@@ -1403,6 +1403,20 @@ class UpdateWeightsFromIPCReqOutput(BaseReq): +@@ -101,6 +101,42 @@ class RequestTimingMetricsMixin: + # This marks when the prefill computation finishes. + prefill_finished_ts: Optional[List[Optional[float]]] + ++ # --- PD disaggregation timing fields --- ++ # All fields are None when profiling is disabled or not in PD disaggregation mode. ++ ++ # P instance: duration spent in bootstrap queue before entering the wait queue. ++ pd_prefill_bootstrap_queue_duration: Optional[List[Optional[float]]] ++ ++ # P instance: duration for the actual prefill forward computation. ++ pd_prefill_forward_duration: Optional[List[Optional[float]]] ++ ++ # P instance: duration spent in the KV transfer queue. ++ pd_prefill_transfer_queue_duration: Optional[List[Optional[float]]] ++ ++ # D instance: duration waiting for KV cache slot pre-allocation. ++ pd_decode_prealloc_duration: Optional[List[Optional[float]]] ++ ++ # D instance: duration waiting for the KV cache transfer to complete. ++ pd_decode_transfer_duration: Optional[List[Optional[float]]] ++ ++ # D instance: duration for the actual decode forward computation. ++ pd_decode_forward_duration: Optional[List[Optional[float]]] ++ ++ # Bootstrap handshake duration (P and D instances). ++ pd_bootstrap_duration: Optional[List[Optional[float]]] ++ ++ # KV cache allocation waiting duration (P and D instances). ++ pd_alloc_waiting_duration: Optional[List[Optional[float]]] ++ ++ # KV cache transfer speed in GB/s. 
++ pd_transfer_speed_gb_s: Optional[List[Optional[float]]] ++ ++ # Total KV cache transferred in MB. ++ pd_transfer_total_mb: Optional[List[Optional[float]]] ++ ++ # Number of prefill retries (P instance only). ++ pd_prefill_retry_count: Optional[List[Optional[int]]] ++ + + @dataclass + class SpeculativeDecodingMetricsMixin: +@@ -1403,6 +1439,20 @@ class UpdateWeightsFromIPCReqOutput(BaseReq): message: str @@ -1606,7 +1997,7 @@ index ff1774567..42d27a82a 100644 @dataclass class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq): success: bool -@@ -1802,6 +1816,10 @@ class GetLoadReqOutput(BaseReq): +@@ -1802,6 +1852,10 @@ class GetLoadReqOutput(BaseReq): num_waiting_reqs: int num_tokens: int ts_tic: float @@ -1617,6 +2008,202 @@ index ff1774567..42d27a82a 100644 @dataclass +diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py +index e1236aa0f..daa598a1f 100644 +--- a/python/sglang/srt/managers/multi_tokenizer_mixin.py ++++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py +@@ -142,6 +142,39 @@ def _handle_output_by_index(output, i): + prefill_finished_ts=_extract_field_by_index( + output, "prefill_finished_ts", i + ), ++ pd_prefill_bootstrap_queue_duration=_extract_field_by_index( ++ output, "pd_prefill_bootstrap_queue_duration", i ++ ), ++ pd_prefill_forward_duration=_extract_field_by_index( ++ output, "pd_prefill_forward_duration", i ++ ), ++ pd_prefill_transfer_queue_duration=_extract_field_by_index( ++ output, "pd_prefill_transfer_queue_duration", i ++ ), ++ pd_decode_prealloc_duration=_extract_field_by_index( ++ output, "pd_decode_prealloc_duration", i ++ ), ++ pd_decode_transfer_duration=_extract_field_by_index( ++ output, "pd_decode_transfer_duration", i ++ ), ++ pd_decode_forward_duration=_extract_field_by_index( ++ output, "pd_decode_forward_duration", i ++ ), ++ pd_bootstrap_duration=_extract_field_by_index( ++ output, "pd_bootstrap_duration", i ++ ), ++ 
pd_alloc_waiting_duration=_extract_field_by_index( ++ output, "pd_alloc_waiting_duration", i ++ ), ++ pd_transfer_speed_gb_s=_extract_field_by_index( ++ output, "pd_transfer_speed_gb_s", i ++ ), ++ pd_transfer_total_mb=_extract_field_by_index( ++ output, "pd_transfer_total_mb", i ++ ), ++ pd_prefill_retry_count=_extract_field_by_index( ++ output, "pd_prefill_retry_count", i ++ ), + finished_reasons=_extract_field_by_index(output, "finished_reasons", i), + decoded_texts=_extract_field_by_index(output, "decoded_texts", i), + decode_ids=_extract_field_by_index(output, "decode_ids", i), +@@ -211,6 +244,50 @@ def _handle_output_by_index(output, i): + elif isinstance(output, BatchEmbeddingOutput): + new_output = BatchEmbeddingOutput( + rids=[output.rids[i]], ++ queue_time=_extract_field_by_index(output, "queue_time", i), ++ forward_entry_time=_extract_field_by_index(output, "forward_entry_time", i), ++ prefill_launch_delay=_extract_field_by_index( ++ output, "prefill_launch_delay", i ++ ), ++ prefill_launch_latency=_extract_field_by_index( ++ output, "prefill_launch_latency", i ++ ), ++ prefill_finished_ts=_extract_field_by_index( ++ output, "prefill_finished_ts", i ++ ), ++ pd_prefill_bootstrap_queue_duration=_extract_field_by_index( ++ output, "pd_prefill_bootstrap_queue_duration", i ++ ), ++ pd_prefill_forward_duration=_extract_field_by_index( ++ output, "pd_prefill_forward_duration", i ++ ), ++ pd_prefill_transfer_queue_duration=_extract_field_by_index( ++ output, "pd_prefill_transfer_queue_duration", i ++ ), ++ pd_decode_prealloc_duration=_extract_field_by_index( ++ output, "pd_decode_prealloc_duration", i ++ ), ++ pd_decode_transfer_duration=_extract_field_by_index( ++ output, "pd_decode_transfer_duration", i ++ ), ++ pd_decode_forward_duration=_extract_field_by_index( ++ output, "pd_decode_forward_duration", i ++ ), ++ pd_bootstrap_duration=_extract_field_by_index( ++ output, "pd_bootstrap_duration", i ++ ), ++ pd_alloc_waiting_duration=_extract_field_by_index( ++ 
output, "pd_alloc_waiting_duration", i ++ ), ++ pd_transfer_speed_gb_s=_extract_field_by_index( ++ output, "pd_transfer_speed_gb_s", i ++ ), ++ pd_transfer_total_mb=_extract_field_by_index( ++ output, "pd_transfer_total_mb", i ++ ), ++ pd_prefill_retry_count=_extract_field_by_index( ++ output, "pd_prefill_retry_count", i ++ ), + finished_reasons=_extract_field_by_index(output, "finished_reasons", i), + embeddings=_extract_field_by_index(output, "embeddings", i), + prompt_tokens=_extract_field_by_index(output, "prompt_tokens", i), +@@ -239,6 +316,39 @@ def _handle_output_by_index(output, i): + prefill_finished_ts=_extract_field_by_index( + output, "prefill_finished_ts", i + ), ++ pd_prefill_bootstrap_queue_duration=_extract_field_by_index( ++ output, "pd_prefill_bootstrap_queue_duration", i ++ ), ++ pd_prefill_forward_duration=_extract_field_by_index( ++ output, "pd_prefill_forward_duration", i ++ ), ++ pd_prefill_transfer_queue_duration=_extract_field_by_index( ++ output, "pd_prefill_transfer_queue_duration", i ++ ), ++ pd_decode_prealloc_duration=_extract_field_by_index( ++ output, "pd_decode_prealloc_duration", i ++ ), ++ pd_decode_transfer_duration=_extract_field_by_index( ++ output, "pd_decode_transfer_duration", i ++ ), ++ pd_decode_forward_duration=_extract_field_by_index( ++ output, "pd_decode_forward_duration", i ++ ), ++ pd_bootstrap_duration=_extract_field_by_index( ++ output, "pd_bootstrap_duration", i ++ ), ++ pd_alloc_waiting_duration=_extract_field_by_index( ++ output, "pd_alloc_waiting_duration", i ++ ), ++ pd_transfer_speed_gb_s=_extract_field_by_index( ++ output, "pd_transfer_speed_gb_s", i ++ ), ++ pd_transfer_total_mb=_extract_field_by_index( ++ output, "pd_transfer_total_mb", i ++ ), ++ pd_prefill_retry_count=_extract_field_by_index( ++ output, "pd_prefill_retry_count", i ++ ), + finished_reasons=_extract_field_by_index(output, "finished_reasons", i), + output_strs=_extract_field_by_index(output, "output_strs", i), + 
output_ids=_extract_field_by_index(output, "output_ids", i), +@@ -524,6 +634,60 @@ def monkey_patch_uvicorn_multiprocessing(timeout: float = 10): + "uvicorn.supervisors.multiprocess not found, skipping monkey patch" + ) + ++ # Fix stdin fd issue when running under Ray (or other managed ++ # environments where stdin may not be a real terminal): ++ # ++ # Uvicorn's get_subprocess() captures sys.stdin.fileno() in the parent ++ # and passes it to spawn'd children, which call os.fdopen(stdin_fileno) ++ # to re-attach stdin. This is intended for interactive debugging (e.g. ++ # pdb attach to a child worker). ++ # ++ # In Ray Actors, sys.stdin.fileno() succeeds in the parent (returns a ++ # valid fd number), but the fd is not inheritable across spawn. The ++ # child's os.fdopen() then crashes with OSError: [Errno 9] Bad file ++ # descriptor, killing every tokenizer worker. ++ # ++ # Instead of unconditionally disabling stdin passthrough, we probe ++ # whether the fd is truly usable by dup'ing it. If os.dup() fails, ++ # the fd won't survive spawn either, so we fall back to None. In a ++ # normal terminal environment os.dup() succeeds and debugging ability ++ # is preserved. ++ try: ++ import uvicorn._subprocess as _uv_sub ++ import uvicorn.supervisors.multiprocess as _uv_mp ++ ++ def _safe_get_stdin_fileno(): ++ """Return stdin fileno only if it is genuinely usable.""" ++ try: ++ fileno = sys.stdin.fileno() ++ # Verify the fd is valid and duplicable — if it isn't, ++ # spawn'd children won't be able to reopen it either. 
++ dup_fd = os.dup(fileno) ++ os.close(dup_fd) ++ return fileno ++ except (AttributeError, OSError): ++ return None ++ ++ def _patched_get_subprocess(config, target, sockets): ++ stdin_fileno = _safe_get_stdin_fileno() ++ kwargs = { ++ "config": config, ++ "target": target, ++ "sockets": sockets, ++ "stdin_fileno": stdin_fileno, ++ } ++ return _uv_sub.spawn.Process( ++ target=_uv_sub.subprocess_started, kwargs=kwargs ++ ) ++ ++ # Must patch both: the supervisor module caches its own reference ++ # to get_subprocess at import time via ++ # ``from uvicorn._subprocess import get_subprocess``. ++ _uv_sub.get_subprocess = _patched_get_subprocess ++ _uv_mp.get_subprocess = _patched_get_subprocess ++ except Exception: ++ pass ++ + + class SenderWrapper: + def __init__(self, port_args: PortArgs, send_to_scheduler: zmq.Socket): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c07995798..dd8ca7167 100644 --- a/python/sglang/srt/managers/schedule_batch.py @@ -1734,10 +2321,70 @@ index 30b2732b9..68090b161 100644 def get_loads(self: Scheduler, req: GetLoadsReqInput = None) -> GetLoadsReqOutput: diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -index 482bc6ca6..857cfa6a3 100644 +index 482bc6ca6..fbc486417 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -@@ -1134,7 +1134,7 @@ class SchedulerOutputProcessorMixin: +@@ -922,6 +922,18 @@ class SchedulerOutputProcessorMixin: + prefill_launch_delays = [] + prefill_launch_latencies = [] + prefill_finished_timestamps = [] ++ profiling_enabled = envs.SLIME_ENABLE_PROFILING.get() ++ pd_prefill_bootstrap_queue_durations = [] if profiling_enabled else None ++ pd_prefill_forward_durations = [] if profiling_enabled else None ++ pd_prefill_transfer_queue_durations = [] if profiling_enabled else 
None ++ pd_decode_prealloc_durations = [] if profiling_enabled else None ++ pd_decode_transfer_durations = [] if profiling_enabled else None ++ pd_decode_forward_durations = [] if profiling_enabled else None ++ pd_bootstrap_durations = [] if profiling_enabled else None ++ pd_alloc_waiting_durations = [] if profiling_enabled else None ++ pd_transfer_speeds_gb_s = [] if profiling_enabled else None ++ pd_transfer_totals_mb = [] if profiling_enabled else None ++ pd_prefill_retry_counts = [] if profiling_enabled else None + + if return_logprob: + input_token_logprobs_val = [] +@@ -1037,6 +1049,40 @@ class SchedulerOutputProcessorMixin: + prefill_finished_timestamps.append( + req.time_stats.get_prefill_finished_ts() + ) ++ if profiling_enabled: ++ pd_prefill_bootstrap_queue_durations.append( ++ req.time_stats.get_pd_prefill_bootstrap_queue_duration() ++ ) ++ pd_prefill_forward_durations.append( ++ req.time_stats.get_pd_prefill_forward_duration() ++ ) ++ pd_prefill_transfer_queue_durations.append( ++ req.time_stats.get_pd_prefill_transfer_queue_duration() ++ ) ++ pd_decode_prealloc_durations.append( ++ req.time_stats.get_pd_decode_prealloc_duration() ++ ) ++ pd_decode_transfer_durations.append( ++ req.time_stats.get_pd_decode_transfer_duration() ++ ) ++ pd_decode_forward_durations.append( ++ req.time_stats.get_pd_decode_forward_duration() ++ ) ++ pd_bootstrap_durations.append( ++ req.time_stats.get_pd_bootstrap_duration() ++ ) ++ pd_alloc_waiting_durations.append( ++ req.time_stats.get_pd_alloc_waiting_duration() ++ ) ++ pd_transfer_speeds_gb_s.append( ++ req.time_stats.get_pd_transfer_speed_gb_s() ++ ) ++ pd_transfer_totals_mb.append( ++ req.time_stats.get_pd_transfer_total_mb() ++ ) ++ pd_prefill_retry_counts.append( ++ req.time_stats.get_pd_prefill_retry_count() ++ ) + + if not self.spec_algorithm.is_none(): + spec_verify_ct.append(req.spec_verify_ct) +@@ -1134,7 +1180,7 @@ class SchedulerOutputProcessorMixin: req.log_time_stats() # Send to detokenizer @@ -1746,6 
+2393,102 @@ index 482bc6ca6..857cfa6a3 100644 if self.model_config.is_multimodal_gen: return self.send_to_detokenizer.send_output( +@@ -1149,6 +1195,17 @@ class SchedulerOutputProcessorMixin: + prefill_launch_delay=prefill_launch_delays, + prefill_launch_latency=prefill_launch_latencies, + prefill_finished_ts=prefill_finished_timestamps, ++ pd_prefill_bootstrap_queue_duration=pd_prefill_bootstrap_queue_durations, ++ pd_prefill_forward_duration=pd_prefill_forward_durations, ++ pd_prefill_transfer_queue_duration=pd_prefill_transfer_queue_durations, ++ pd_decode_prealloc_duration=pd_decode_prealloc_durations, ++ pd_decode_transfer_duration=pd_decode_transfer_durations, ++ pd_decode_forward_duration=pd_decode_forward_durations, ++ pd_bootstrap_duration=pd_bootstrap_durations, ++ pd_alloc_waiting_duration=pd_alloc_waiting_durations, ++ pd_transfer_speed_gb_s=pd_transfer_speeds_gb_s, ++ pd_transfer_total_mb=pd_transfer_totals_mb, ++ pd_prefill_retry_count=pd_prefill_retry_counts, + finished_reasons=finished_reasons, + decoded_texts=decoded_texts, + decode_ids=decode_ids_list, +@@ -1198,6 +1255,18 @@ class SchedulerOutputProcessorMixin: + prefill_launch_delays = [] + prefill_launch_latencies = [] + prefill_finished_timestamps = [] ++ profiling_enabled = envs.SLIME_ENABLE_PROFILING.get() ++ pd_prefill_bootstrap_queue_durations = [] if profiling_enabled else None ++ pd_prefill_forward_durations = [] if profiling_enabled else None ++ pd_prefill_transfer_queue_durations = [] if profiling_enabled else None ++ pd_decode_prealloc_durations = [] if profiling_enabled else None ++ pd_decode_transfer_durations = [] if profiling_enabled else None ++ pd_decode_forward_durations = [] if profiling_enabled else None ++ pd_bootstrap_durations = [] if profiling_enabled else None ++ pd_alloc_waiting_durations = [] if profiling_enabled else None ++ pd_transfer_speeds_gb_s = [] if profiling_enabled else None ++ pd_transfer_totals_mb = [] if profiling_enabled else None ++ 
pd_prefill_retry_counts = [] if profiling_enabled else None + retraction_counts = [] + for req in reqs: + if req.finished(): +@@ -1221,6 +1290,40 @@ class SchedulerOutputProcessorMixin: + prefill_finished_timestamps.append( + req.time_stats.get_prefill_finished_ts() + ) ++ if profiling_enabled: ++ pd_prefill_bootstrap_queue_durations.append( ++ req.time_stats.get_pd_prefill_bootstrap_queue_duration() ++ ) ++ pd_prefill_forward_durations.append( ++ req.time_stats.get_pd_prefill_forward_duration() ++ ) ++ pd_prefill_transfer_queue_durations.append( ++ req.time_stats.get_pd_prefill_transfer_queue_duration() ++ ) ++ pd_decode_prealloc_durations.append( ++ req.time_stats.get_pd_decode_prealloc_duration() ++ ) ++ pd_decode_transfer_durations.append( ++ req.time_stats.get_pd_decode_transfer_duration() ++ ) ++ pd_decode_forward_durations.append( ++ req.time_stats.get_pd_decode_forward_duration() ++ ) ++ pd_bootstrap_durations.append( ++ req.time_stats.get_pd_bootstrap_duration() ++ ) ++ pd_alloc_waiting_durations.append( ++ req.time_stats.get_pd_alloc_waiting_duration() ++ ) ++ pd_transfer_speeds_gb_s.append( ++ req.time_stats.get_pd_transfer_speed_gb_s() ++ ) ++ pd_transfer_totals_mb.append( ++ req.time_stats.get_pd_transfer_total_mb() ++ ) ++ pd_prefill_retry_counts.append( ++ req.time_stats.get_pd_prefill_retry_count() ++ ) + retraction_counts.append(req.retraction_count) + self.send_to_detokenizer.send_output( + BatchEmbeddingOutput( +@@ -1231,6 +1334,17 @@ class SchedulerOutputProcessorMixin: + prefill_launch_delay=prefill_launch_delays, + prefill_launch_latency=prefill_launch_latencies, + prefill_finished_ts=prefill_finished_timestamps, ++ pd_prefill_bootstrap_queue_duration=pd_prefill_bootstrap_queue_durations, ++ pd_prefill_forward_duration=pd_prefill_forward_durations, ++ pd_prefill_transfer_queue_duration=pd_prefill_transfer_queue_durations, ++ pd_decode_prealloc_duration=pd_decode_prealloc_durations, ++ pd_decode_transfer_duration=pd_decode_transfer_durations, 
++ pd_decode_forward_duration=pd_decode_forward_durations, ++ pd_bootstrap_duration=pd_bootstrap_durations, ++ pd_alloc_waiting_duration=pd_alloc_waiting_durations, ++ pd_transfer_speed_gb_s=pd_transfer_speeds_gb_s, ++ pd_transfer_total_mb=pd_transfer_totals_mb, ++ pd_prefill_retry_count=pd_prefill_retry_counts, + finished_reasons=finished_reasons, + embeddings=embeddings, + prompt_tokens=prompt_tokens, diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 1a65a3c3d..f76606469 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -2074,7 +2817,7 @@ index f2ffa9909..6e4d1d460 100644 self, obj: InitWeightsSendGroupForRemoteInstanceReqInput, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index 0914a5230..cce2d8a2b 100644 +index 0914a5230..33bb3844a 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -324,8 +324,12 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi @@ -2109,6 +2852,82 @@ index 0914a5230..cce2d8a2b 100644 self.is_pause_cond.notify_all() async def update_weights_from_disk( +@@ -1510,6 +1514,40 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + self._add_metric_if_present( + recv_obj, "prefill_finished_ts", meta_info, i + ) ++ # PD disaggregation timing ++ self._add_metric_if_present( ++ recv_obj, "pd_prefill_bootstrap_queue_duration", meta_info, i ++ ) ++ self._add_metric_if_present( ++ recv_obj, "pd_prefill_forward_duration", meta_info, i ++ ) ++ self._add_metric_if_present( ++ recv_obj, "pd_prefill_transfer_queue_duration", meta_info, i ++ ) ++ self._add_metric_if_present( ++ recv_obj, "pd_decode_prealloc_duration", meta_info, i ++ ) ++ self._add_metric_if_present( ++ recv_obj, "pd_decode_transfer_duration", meta_info, i ++ ) ++ self._add_metric_if_present( ++ recv_obj, 
"pd_decode_forward_duration", meta_info, i ++ ) ++ self._add_metric_if_present( ++ recv_obj, "pd_bootstrap_duration", meta_info, i ++ ) ++ self._add_metric_if_present( ++ recv_obj, "pd_alloc_waiting_duration", meta_info, i ++ ) ++ self._add_metric_if_present( ++ recv_obj, "pd_transfer_speed_gb_s", meta_info, i ++ ) ++ self._add_metric_if_present( ++ recv_obj, "pd_transfer_total_mb", meta_info, i ++ ) ++ self._add_metric_if_present( ++ recv_obj, "pd_prefill_retry_count", meta_info, i ++ ) + + if getattr(state.obj, "return_logprob", False): + self.convert_logprob_style( +@@ -1955,19 +1993,17 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + if custom_labels + else self.metrics_collector.labels + ) +- if ( +- state.first_token_time == 0.0 +- and self.disaggregation_mode != DisaggregationMode.PREFILL +- ): ++ if state.first_token_time == 0.0: + state.first_token_time = state.last_time = time.time() + state.first_token_time_perf = time.perf_counter() + state.last_completion_tokens = completion_tokens +- self.metrics_collector.observe_time_to_first_token( +- labels, state.first_token_time - state.created_time +- ) ++ if self.disaggregation_mode != DisaggregationMode.PREFILL: ++ self.metrics_collector.observe_time_to_first_token( ++ labels, state.first_token_time - state.created_time ++ ) + else: + num_new_tokens = completion_tokens - state.last_completion_tokens +- if num_new_tokens: ++ if num_new_tokens > 0: + new_time = time.time() + interval = new_time - state.last_time + self.metrics_collector.observe_inter_token_latency( +@@ -1976,7 +2012,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + num_new_tokens, + ) + state.last_time = new_time +- state.last_completion_tokens = completion_tokens ++ state.last_completion_tokens = completion_tokens + + if state.finished: + retraction_count = ( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 
86b009df4..16ebd52ae 100644 --- a/python/sglang/srt/managers/tp_worker.py @@ -2415,6 +3234,189 @@ index 42b169728..8e799196a 100644 node = node.parent return delta +diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py +index 255d41ccc..f93bedb4d 100644 +--- a/python/sglang/srt/metrics/collector.py ++++ b/python/sglang/srt/metrics/collector.py +@@ -20,7 +20,10 @@ import time + from dataclasses import dataclass, field + from typing import Any, Dict, List, Optional, Union + +-from sglang.srt.disaggregation.utils import DisaggregationMode ++from sglang.srt.disaggregation.utils import ( ++ DisaggregationMode, ++ is_slime_profiling_enabled, ++) + from sglang.srt.environ import envs + from sglang.srt.metrics.utils import exponential_buckets, generate_buckets + from sglang.srt.model_executor.forward_batch_info import ForwardMode +@@ -77,6 +80,17 @@ class TimeStats: + # Number of prefill retries for this request + prefill_retry_count: int = 0 + ++ # Prefill-side durations forwarded via metadata transfer from P to D instance. ++ # Set on the decode instance after KV cache transfer completes. ++ fwd_prefill_bootstrap_queue_duration: Optional[float] = None ++ fwd_prefill_forward_duration: Optional[float] = None ++ fwd_prefill_transfer_queue_duration: Optional[float] = None ++ fwd_bootstrap_duration: Optional[float] = None ++ fwd_alloc_waiting_duration: Optional[float] = None ++ fwd_transfer_speed_gb_s: Optional[float] = None ++ fwd_transfer_total_mb: Optional[float] = None ++ fwd_prefill_retry_count: Optional[int] = None ++ + # Timestamp when prefill phase finishes, obtained from `time.time()`. + # Note that this differs from the other `_time` fields tracked by the + # `TimeStats` class, which are obtained from `time.perf_counter()`. 
+@@ -102,6 +116,148 @@ class TimeStats: + return self.prefill_finished_ts + return None + ++ # --- PD disaggregation timing getters --- ++ ++ def get_pd_prefill_bootstrap_queue_duration(self) -> Optional[float]: ++ """P instance: time spent in bootstrap queue before entering the wait queue.""" ++ if not is_slime_profiling_enabled(): ++ return None ++ if self.fwd_prefill_bootstrap_queue_duration is not None: ++ return self.fwd_prefill_bootstrap_queue_duration ++ if ( ++ self.disagg_mode == DisaggregationMode.PREFILL ++ and self.prefill_bootstrap_queue_entry_time > 0.0 ++ and self.wait_queue_entry_time > 0.0 ++ ): ++ return self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time ++ return None ++ ++ def get_pd_prefill_forward_duration(self) -> Optional[float]: ++ """P instance: time for the actual prefill forward computation.""" ++ if not is_slime_profiling_enabled(): ++ return None ++ if self.fwd_prefill_forward_duration is not None: ++ return self.fwd_prefill_forward_duration ++ if ( ++ self.disagg_mode == DisaggregationMode.PREFILL ++ and self.forward_entry_time > 0.0 ++ and self.completion_time > 0.0 ++ ): ++ return self.completion_time - self.forward_entry_time ++ return None ++ ++ def get_pd_prefill_transfer_queue_duration(self) -> Optional[float]: ++ """P instance: time spent in the transfer queue (KV cache send).""" ++ if not is_slime_profiling_enabled(): ++ return None ++ if self.fwd_prefill_transfer_queue_duration is not None: ++ return self.fwd_prefill_transfer_queue_duration ++ if ( ++ self.disagg_mode == DisaggregationMode.PREFILL ++ and self.prefill_transfer_queue_entry_time > 0.0 ++ and self.completion_time > 0.0 ++ ): ++ return self.completion_time - self.prefill_transfer_queue_entry_time ++ return None ++ ++ def get_pd_decode_prealloc_duration(self) -> Optional[float]: ++ """D instance: time spent in the pre-alloc queue (waiting for KV cache slot allocation).""" ++ if not is_slime_profiling_enabled(): ++ return None ++ if ( ++ 
self.disagg_mode == DisaggregationMode.DECODE ++ and self.decode_prealloc_queue_entry_time > 0.0 ++ and self.decode_transfer_queue_entry_time > 0.0 ++ ): ++ return ( ++ self.decode_transfer_queue_entry_time ++ - self.decode_prealloc_queue_entry_time ++ ) ++ return None ++ ++ def get_pd_decode_transfer_duration(self) -> Optional[float]: ++ """D instance: time spent waiting for KV cache transfer to complete.""" ++ if not is_slime_profiling_enabled(): ++ return None ++ if ( ++ self.disagg_mode == DisaggregationMode.DECODE ++ and self.decode_transfer_queue_entry_time > 0.0 ++ and self.wait_queue_entry_time > 0.0 ++ ): ++ return self.wait_queue_entry_time - self.decode_transfer_queue_entry_time ++ return None ++ ++ def get_pd_decode_forward_duration(self) -> Optional[float]: ++ """D instance: time for the actual decode forward computation.""" ++ if not is_slime_profiling_enabled(): ++ return None ++ if ( ++ self.disagg_mode == DisaggregationMode.DECODE ++ and self.forward_entry_time > 0.0 ++ and self.completion_time > 0.0 ++ ): ++ return self.completion_time - self.forward_entry_time ++ return None ++ ++ def get_pd_bootstrap_duration(self) -> Optional[float]: ++ """Bootstrap handshake duration (both P and D instances).""" ++ if not is_slime_profiling_enabled(): ++ return None ++ if self.fwd_bootstrap_duration is not None: ++ return self.fwd_bootstrap_duration ++ if ( ++ self.disagg_mode != DisaggregationMode.NULL ++ and self.bootstrap_duration > 0.0 ++ ): ++ return self.bootstrap_duration ++ return None ++ ++ def get_pd_alloc_waiting_duration(self) -> Optional[float]: ++ """KV cache allocation waiting duration (both P and D instances).""" ++ if not is_slime_profiling_enabled(): ++ return None ++ if self.fwd_alloc_waiting_duration is not None: ++ return self.fwd_alloc_waiting_duration ++ if ( ++ self.disagg_mode != DisaggregationMode.NULL ++ and self.alloc_waiting_duration > 0.0 ++ ): ++ return self.alloc_waiting_duration ++ return None ++ ++ def 
get_pd_transfer_speed_gb_s(self) -> Optional[float]: ++ """KV cache transfer speed in GB/s.""" ++ if not is_slime_profiling_enabled(): ++ return None ++ if self.fwd_transfer_speed_gb_s is not None: ++ return self.fwd_transfer_speed_gb_s ++ if ( ++ self.disagg_mode != DisaggregationMode.NULL ++ and self.transfer_speed_gb_s > 0.0 ++ ): ++ return self.transfer_speed_gb_s ++ return None ++ ++ def get_pd_transfer_total_mb(self) -> Optional[float]: ++ """Total KV cache transferred in MB.""" ++ if not is_slime_profiling_enabled(): ++ return None ++ if self.fwd_transfer_total_mb is not None: ++ return self.fwd_transfer_total_mb ++ if self.disagg_mode != DisaggregationMode.NULL and self.transfer_total_mb > 0.0: ++ return self.transfer_total_mb ++ return None ++ ++ def get_pd_prefill_retry_count(self) -> Optional[int]: ++ """Number of prefill retries for this request.""" ++ if not is_slime_profiling_enabled(): ++ return None ++ if self.fwd_prefill_retry_count is not None: ++ return self.fwd_prefill_retry_count ++ if self.disagg_mode == DisaggregationMode.PREFILL: ++ return self.prefill_retry_count ++ return None ++ + def convert_to_duration(self) -> str: + if self.disagg_mode == DisaggregationMode.NULL: + queue_duration = self.forward_entry_time - self.wait_queue_entry_time diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 234523532..f5d479945 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py diff --git a/docker/version.txt b/docker/version.txt index 198a048061..709d39f4bf 100644 --- a/docker/version.txt +++ b/docker/version.txt @@ -1 +1 @@ -nightly-dev-20260327a +nightly-dev-20260329a diff --git a/docs/_static/image/trace.png b/docs/_static/image/trace.png new file mode 100644 index 0000000000..847f4009cb Binary files /dev/null and b/docs/_static/image/trace.png differ diff --git a/docs/en/developer_guide/trace.md b/docs/en/developer_guide/trace.md new file mode 100644 index 
0000000000..d37f867eca --- /dev/null +++ b/docs/en/developer_guide/trace.md @@ -0,0 +1,65 @@ +# Trace Viewer + +slime can attach lightweight execution traces to each rollout sample. These traces capture span-style events such as generation and reward-model calls, and they can be inspected later from a saved rollout debug dump. + +![trace timeline viewer](../../_static/image/trace.png) + +## Save rollout trace data + +To inspect traces later, save rollout debug data during a run: + +```bash +python train.py \ + ... \ + --save-debug-rollout-data /path/to/debug/rollout_{rollout_id}.pt +``` + +Each saved `.pt` file contains the rollout samples together with their `trace` payloads. You can also replay the same dump later with `--load-debug-rollout-data`. + +## Open the timeline viewer + +Use the trace viewer script on a saved rollout dump: + +```bash +python tools/trace_timeline_viewer.py /path/to/debug/rollout_0.pt +``` + +The script generates: + +- `rollout_0.trace_timeline_cache.json` +- `rollout_0.trace_timeline_viewer.html` + +By default it also starts a local static server so you can open the generated HTML immediately. If you only want the files, use `--no-serve`. + +## How to read the viewer + +- Each row corresponds to one sample. +- Bars represent spans, while point markers represent instant events. +- Span attributes recorded at the start or end of `trace_span(...)` are shown in the details panel. +- When SGLang returns PD disaggregation timings, the viewer adds synthetic `[P]` and `[D]` lanes to break out prefill/decode work. +- When PD is not enabled, those virtual lanes are omitted automatically and the base trace still renders normally. + +## Instrument custom code + +For custom rollout or reward code, reuse helpers from `slime.utils.trace_utils`: + +- `trace_span(target, name, attrs=...)`: record a duration span. +- `trace_event(target, name, attrs=...)`: record an instant event. 
+- `bind_trace(sample)`: ensure a sample already has a trace carrier before passing it across helpers or tasks. + +If you want to record SGLang generation metadata in a consistent way, reuse `build_sglang_meta_trace_attrs`: + +```python +from slime.utils.trace_utils import build_sglang_meta_trace_attrs, trace_span + +with trace_span(sample, "sglang_generate") as span: + output = await post(url, payload) + span.update(build_sglang_meta_trace_attrs(output["meta_info"])) +``` + +## Tips + +- Save a small number of rollouts first; the viewer is easiest to read when each dump contains a manageable number of samples. +- The viewer is built from the saved `.pt` dump, so traces can be inspected offline on another machine. +- For GPU/kernel-level SGLang profiling traces, see [Profiling](./profiling.md). + diff --git a/docs/en/index.rst b/docs/en/index.rst index e0af06ec31..0c0b7521f9 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -66,6 +66,7 @@ slime is the RL-framework behind GLM-4.7, GLM-4.6 and GLM-4.5. Apart from models developer_guide/ci.md developer_guide/debug.md + developer_guide/trace.md developer_guide/profiling.md .. toctree:: diff --git a/docs/zh/developer_guide/trace.md b/docs/zh/developer_guide/trace.md new file mode 100644 index 0000000000..04c424bdaa --- /dev/null +++ b/docs/zh/developer_guide/trace.md @@ -0,0 +1,65 @@ +# Trace 可视化 + +slime 可以为每条 rollout sample 挂上轻量级执行 trace。它会记录生成、奖励模型等 span 事件,并且可以在保存下来的 rollout debug dump 中离线查看。 + +![trace 时间线查看器](../../_static/image/trace.png) + +## 保存 rollout trace 数据 + +如果想在运行结束后查看 trace,可以在训练时打开 rollout debug dump: + +```bash +python train.py \ + ... 
\ + --save-debug-rollout-data /path/to/debug/rollout_{rollout_id}.pt +``` + +每个保存出来的 `.pt` 文件都会包含 rollout samples,以及对应的 `trace` 数据。之后也可以通过 `--load-debug-rollout-data` 复用同一份 dump。 + +## 打开时间线查看器 + +对保存好的 rollout dump 运行: + +```bash +python tools/trace_timeline_viewer.py /path/to/debug/rollout_0.pt +``` + +脚本会生成: + +- `rollout_0.trace_timeline_cache.json` +- `rollout_0.trace_timeline_viewer.html` + +默认情况下,它还会启动一个本地静态文件服务,方便直接在浏览器里打开。如果只想生成文件,可以加 `--no-serve`。 + +## 如何理解可视化结果 + +- 每一行对应一条 sample。 +- 条形块表示 span,点表示瞬时事件。 +- `trace_span(...)` 在开始和结束时记录的属性,都会显示在详情面板里。 +- 当 SGLang 返回 PD 分离相关时延时,viewer 会自动补出 `[P]` 和 `[D]` 两条虚拟 lane,用来拆开展示 prefill/decode。 +- 如果没有开启 PD,这两条虚拟 lane 不会出现,基础 trace 也仍然可以正常渲染。 + +## 给自定义代码打点 + +在自定义 rollout 或 reward 逻辑中,可以直接复用 `slime.utils.trace_utils` 里的工具: + +- `trace_span(target, name, attrs=...)`:记录一段持续时间。 +- `trace_event(target, name, attrs=...)`:记录一个瞬时事件。 +- `bind_trace(sample)`:在 sample 被传递到其他 helper 或任务之前,确保它已经绑定好 trace carrier。 + +如果想统一记录 SGLang 返回的 generation 元信息,可以复用 `build_sglang_meta_trace_attrs`: + +```python +from slime.utils.trace_utils import build_sglang_meta_trace_attrs, trace_span + +with trace_span(sample, "sglang_generate") as span: + output = await post(url, payload) + span.update(build_sglang_meta_trace_attrs(output["meta_info"])) +``` + +## 使用建议 + +- 先保存少量 rollout;单个 dump 的 sample 数量适中时,viewer 会更容易阅读。 +- viewer 直接基于保存下来的 `.pt` dump 工作,因此可以把文件拷到别的机器离线分析。 +- 如果你想看的是 SGLang 自身的 GPU / kernel 级 profiling trace,请参考 [性能分析](./profiling.md)。 + diff --git a/docs/zh/index.rst b/docs/zh/index.rst index 3ae7a783a3..6c9a81fe57 100644 --- a/docs/zh/index.rst +++ b/docs/zh/index.rst @@ -66,6 +66,7 @@ slime 是 GLM-4.7、GLM-4.6、GLM-4.5 背后的 RL 训练框架。除此之外 developer_guide/ci.md developer_guide/debug.md + developer_guide/trace.md developer_guide/profiling.md .. 
toctree:: diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 2a7f883c7a..0bd4c0fa02 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -120,6 +120,8 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_FALLBACK_VARIANT": "true", "SGLANG_ENABLE_HEALTH_ENDPOINT_GENERATION": "false", "SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_IDLE": "false", + "SGLANG_TRANSFER_PROFILING_INFO": "true", + "SLIME_ENABLE_PROFILING": "true", }.items() } diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index e75346a64f..72b42b0752 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -26,6 +26,7 @@ load_processor, load_tokenizer, ) +from slime.utils.trace_utils import build_sglang_meta_trace_attrs, trace_function, trace_span from slime.utils.types import Sample from .rm_hub import async_rm, batched_async_rm @@ -174,7 +175,9 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A if getattr(args, "router_policy", None) == "consistent_hashing": headers = {"X-SMG-Routing-Key": sample.session_id} - output = await post(url, payload, headers=headers) + with trace_span(sample, "sglang_generate", attrs={"max_new_tokens": sampling_params["max_new_tokens"]}) as span: + output = await post(url, payload, headers=headers) + span.update(build_sglang_meta_trace_attrs(output["meta_info"])) if "output_token_logprobs" in output["meta_info"]: new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] @@ -211,6 +214,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A return sample +@trace_function("generate_and_rm", target="sample") async def generate_and_rm( args: Namespace, sample: Sample | list[Sample], @@ -261,7 +265,8 @@ async def generate_and_rm( # for multi agent system, the reward of some sample is calculated during generation. 
samples_need_reward = [sample for sample in samples if sample.reward is None] - rewards = await batched_async_rm(args, samples_need_reward) + with trace_span(samples_need_reward, "reward_model"): + rewards = await batched_async_rm(args, samples_need_reward) for sample, reward in zip(samples_need_reward, rewards, strict=False): sample.reward = reward return samples @@ -270,11 +275,17 @@ async def generate_and_rm( return sample # for multi-turn environment, a reward could be assigned to the agent. if sample.reward is None: - sample.reward = await async_rm(args, sample) + with trace_span(sample, "reward_model"): + sample.reward = await async_rm(args, sample) return sample +@trace_function( + "generate_and_rm_group", + target="group", + attrs_getter=lambda args, group, sampling_params, evaluation=False: {"group_size": len(group)}, +) async def generate_and_rm_group( args: Namespace, group: list[Sample], sampling_params: dict[str, Any], evaluation: bool = False ) -> list[Sample]: @@ -302,7 +313,8 @@ async def generate_and_rm_group( # for the rm that need the whole group, we will do the rm here if not state.aborted and args.group_rm: - rewards = await batched_async_rm(args, group) + with trace_span(group, "group_reward_model"): + rewards = await batched_async_rm(args, group) for sample, reward in zip(group, rewards, strict=False): sample.reward = reward diff --git a/slime/utils/trace_utils.py b/slime/utils/trace_utils.py new file mode 100644 index 0000000000..b51bc5b931 --- /dev/null +++ b/slime/utils/trace_utils.py @@ -0,0 +1,612 @@ +from __future__ import annotations + +import contextvars +import functools +import inspect +import logging +import time +import uuid +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any + +from slime.utils.types import Sample + +TRACE_VERSION = 1 +SGLANG_TRACE_META_KEYS = ( + "prompt_tokens", + "completion_tokens", + "cached_tokens", + 
"pd_prefill_bootstrap_queue_duration", + "pd_prefill_forward_duration", + "pd_prefill_transfer_queue_duration", + "pd_prefill_retry_count", + "pd_decode_prealloc_duration", + "pd_decode_transfer_duration", + "pd_decode_forward_duration", + "pd_bootstrap_duration", + "pd_alloc_waiting_duration", + "pd_transfer_speed_gb_s", + "pd_transfer_total_mb", +) + +logger = logging.getLogger(__name__) +_TRACE_STACK: contextvars.ContextVar[tuple[tuple[str, str], ...]] = contextvars.ContextVar( + "slime_trace_stack", + default=(), +) +_TRACE_HANDLE_STACK: contextvars.ContextVar[tuple[tuple[TraceHandle, ...], ...]] = contextvars.ContextVar( + "slime_trace_handle_stack", + default=(), +) +_TRACE_AUTO_INFER_WARNED: set[str] = set() + + +@dataclass +class TraceHandle: + trace_id: str + carrier: dict[str, Any] + sample_id: int | str | None = None + group_id: int | str | None = None + attempt: int = 0 + parent_span_id: str | None = None + + +@dataclass +class TraceSpanContext: + target: Sample | TraceHandle | list[Sample | TraceHandle] + handles: list[TraceHandle] + end_attrs: dict[str, Any] = field(default_factory=dict) + end_events: list[dict[str, Any]] = field(default_factory=list) + closed: bool = False + + def set(self, key: str, value: Any) -> TraceSpanContext: + self.end_attrs[key] = value + self._sync_end_events({key: value}) + return self + + def update(self, attrs: dict[str, Any] | None) -> TraceSpanContext: + if attrs: + self.end_attrs.update(attrs) + self._sync_end_events(attrs) + return self + + def set_attr(self, key: str, value: Any) -> TraceSpanContext: + return self.set(key, value) + + def update_attrs(self, attrs: dict[str, Any] | None) -> TraceSpanContext: + return self.update(attrs) + + def build_end_attrs(self) -> dict[str, Any] | None: + return dict(self.end_attrs) or None + + def finalize(self, end_events: list[dict[str, Any]]) -> None: + self.end_events = end_events + self.closed = True + if self.end_attrs: + self._sync_end_events(self.end_attrs) + + def 
_sync_end_events(self, attrs: dict[str, Any]) -> None: + if not self.end_events or not attrs: + return + for event in self.end_events: + event.setdefault("attrs", {}) + event["attrs"].update(attrs) + + +def _noop_handle() -> TraceHandle: + return TraceHandle( + trace_id="", + carrier={ + "version": TRACE_VERSION, + "trace_id": "", + "events": [], + "sample_id": None, + "group_id": None, + "attempt": 0, + }, + ) + + +def _log_trace_error(action: str, exc: Exception) -> None: + logger.debug("trace %s skipped: %s", action, exc, exc_info=True) + + +def _new_trace_id() -> str: + return uuid.uuid4().hex + + +def _new_span_id() -> str: + return uuid.uuid4().hex + + +def build_sglang_meta_trace_attrs(meta: dict[str, Any]) -> dict[str, Any]: + attrs = {key: meta[key] for key in SGLANG_TRACE_META_KEYS if key in meta and meta[key] is not None} + attrs["finish_reason"] = meta["finish_reason"]["type"] + return attrs + + +def _ensure_trace_carrier( + carrier: dict[str, Any] | None, + *, + trace_id: str | None = None, + sample_id: int | str | None = None, + group_id: int | str | None = None, + attempt: int = 0, +) -> dict[str, Any]: + if carrier is None: + carrier = {} + carrier.setdefault("version", TRACE_VERSION) + carrier.setdefault("trace_id", trace_id or _new_trace_id()) + carrier.setdefault("events", []) + if sample_id is not None: + carrier["sample_id"] = sample_id + else: + carrier.setdefault("sample_id", None) + if group_id is not None: + carrier["group_id"] = group_id + else: + carrier.setdefault("group_id", None) + carrier["attempt"] = int(carrier.get("attempt", attempt)) + return carrier + + +def bind_trace(sample: Sample) -> TraceHandle: + try: + sample.trace = _ensure_trace_carrier( + getattr(sample, "trace", None), + sample_id=sample.index, + group_id=sample.group_index, + ) + return TraceHandle( + trace_id=sample.trace["trace_id"], + carrier=sample.trace, + sample_id=sample.trace.get("sample_id"), + group_id=sample.trace.get("group_id"), + 
attempt=int(sample.trace.get("attempt", 0)), + ) + except Exception as exc: + _log_trace_error("bind", exc) + return _noop_handle() + + +def bind_trace_carrier( + carrier: dict[str, Any] | None, + *, + trace_id: str | None = None, + sample_id: int | str | None = None, + group_id: int | str | None = None, + attempt: int = 0, + parent_span_id: str | None = None, +) -> TraceHandle: + try: + trace = _ensure_trace_carrier( + carrier, + trace_id=trace_id, + sample_id=sample_id, + group_id=group_id, + attempt=attempt, + ) + return TraceHandle( + trace_id=trace["trace_id"], + carrier=trace, + sample_id=trace.get("sample_id"), + group_id=trace.get("group_id"), + attempt=int(trace.get("attempt", 0)), + parent_span_id=parent_span_id, + ) + except Exception as exc: + _log_trace_error("bind_carrier", exc) + handle = _noop_handle() + handle.parent_span_id = parent_span_id + return handle + + +def export_trace(handle: TraceHandle) -> dict[str, Any]: + try: + return { + "version": TRACE_VERSION, + "trace_id": handle.trace_id, + "sample_id": handle.sample_id, + "group_id": handle.group_id, + "attempt": handle.attempt, + "parent_span_id": handle.parent_span_id or _get_current_parent_span_id(handle.trace_id), + } + except Exception as exc: + _log_trace_error("export", exc) + return { + "version": TRACE_VERSION, + "trace_id": "", + "sample_id": None, + "group_id": None, + "attempt": 0, + "parent_span_id": None, + } + + +def import_trace(payload: dict[str, Any], carrier: dict[str, Any] | None = None) -> TraceHandle: + try: + return bind_trace_carrier( + carrier, + trace_id=payload.get("trace_id"), + sample_id=payload.get("sample_id"), + group_id=payload.get("group_id"), + attempt=int(payload.get("attempt", 0)), + parent_span_id=payload.get("parent_span_id"), + ) + except Exception as exc: + _log_trace_error("import", exc) + return _noop_handle() + + +def trace_event( + target: Sample | TraceHandle | list[Sample | TraceHandle], name: str, *, attrs: dict[str, Any] | None = None +): + 
try: + timestamp = time.time() + for handle in _coerce_handles(target): + _append_event(handle, kind="event", name=name, timestamp=timestamp, attrs=attrs) + except Exception as exc: + _log_trace_error(f"event:{name}", exc) + + +@contextmanager +def trace_span( + target: Sample | TraceHandle | list[Sample | TraceHandle], + name: str, + *, + attrs: dict[str, Any] | None = None, + record_error: bool = True, +): + try: + handles = _coerce_handles(target) + except Exception as exc: + _log_trace_error(f"span:{name}", exc) + handles = [] + + if not handles: + yield target + return + + timestamp = time.time() + stack_before = _TRACE_STACK.get() + handle_stack_before = _TRACE_HANDLE_STACK.get() + span_records: list[tuple[TraceHandle, str]] = [] + new_entries: list[tuple[str, str]] = [] + + try: + for handle in handles: + span_id = _new_span_id() + parent_span_id = handle.parent_span_id or _get_current_parent_span_id(handle.trace_id, stack=stack_before) + _append_event( + handle, + kind="span_start", + name=name, + timestamp=timestamp, + span_id=span_id, + parent_span_id=parent_span_id, + attrs=attrs, + ) + span_records.append((handle, span_id)) + new_entries.append((handle.trace_id, span_id)) + token = _TRACE_STACK.set(stack_before + tuple(new_entries)) + handle_token = _TRACE_HANDLE_STACK.set(handle_stack_before + (tuple(handles),)) + except Exception as exc: + _log_trace_error(f"span:{name}", exc) + yield target + return + + span_context = TraceSpanContext( + target=handles[0] if len(handles) == 1 else handles, + handles=handles, + ) + + try: + yield span_context + except Exception as exc: + try: + end_attrs = span_context.build_end_attrs() + if record_error: + error_attrs = {"error_type": type(exc).__name__, "error_message": str(exc)} + if end_attrs: + end_attrs.update(error_attrs) + else: + end_attrs = error_attrs + span_context.finalize(_record_span_end(span_records, name=name, attrs=end_attrs)) + except Exception as trace_exc: + _log_trace_error(f"span_end:{name}", 
trace_exc) + raise + else: + try: + span_context.finalize(_record_span_end(span_records, name=name, attrs=span_context.build_end_attrs())) + except Exception as exc: + _log_trace_error(f"span_end:{name}", exc) + finally: + try: + _TRACE_STACK.reset(token) + except Exception as exc: + _log_trace_error(f"span_reset:{name}", exc) + try: + _TRACE_HANDLE_STACK.reset(handle_token) + except Exception as exc: + _log_trace_error(f"span_handle_reset:{name}", exc) + + +def trace_next_attempt( + target: Sample | TraceHandle | list[Sample | TraceHandle], + *, + attrs: dict[str, Any] | None = None, +): + try: + handles = _coerce_handles(target) + for handle in handles: + next_attempt = int(handle.carrier.get("attempt", 0)) + 1 + handle.carrier["attempt"] = next_attempt + handle.attempt = next_attempt + attempt_attrs = {"attempt": next_attempt} + if attrs: + attempt_attrs.update(attrs) + trace_event(handle, "attempt_start", attrs=attempt_attrs) + if not handles: + return target + return handles[0] if len(handles) == 1 else handles + except Exception as exc: + _log_trace_error("next_attempt", exc) + return target + + +def trace_function( + name: str, + *, + target: str | None = None, + target_getter: Callable[..., Sample | TraceHandle | list[Sample | TraceHandle] | None] | None = None, + attrs_getter: Callable[..., dict[str, Any] | None] | None = None, + record_error: bool = True, +): + def decorator(fn): + if inspect.iscoroutinefunction(fn): + + @functools.wraps(fn) + async def async_wrapper(*args, **kwargs): + resolved_target = _resolve_trace_function_target( + fn, + args, + kwargs, + target=target, + target_getter=target_getter, + ) + if resolved_target is None: + return await fn(*args, **kwargs) + attrs = _resolve_trace_function_attrs(fn, args, kwargs, attrs_getter=attrs_getter) + with trace_span(resolved_target, name, attrs=attrs, record_error=record_error): + return await fn(*args, **kwargs) + + return async_wrapper + + @functools.wraps(fn) + def sync_wrapper(*args, 
**kwargs): + resolved_target = _resolve_trace_function_target( + fn, + args, + kwargs, + target=target, + target_getter=target_getter, + ) + if resolved_target is None: + return fn(*args, **kwargs) + attrs = _resolve_trace_function_attrs(fn, args, kwargs, attrs_getter=attrs_getter) + with trace_span(resolved_target, name, attrs=attrs, record_error=record_error): + return fn(*args, **kwargs) + + return sync_wrapper + + return decorator + + +def _record_span_end( + span_records: list[tuple[TraceHandle, str]], + *, + name: str, + attrs: dict[str, Any] | None, +) -> list[dict[str, Any]]: + timestamp = time.time() + events = [] + for handle, span_id in span_records: + events.append( + _append_event( + handle, + kind="span_end", + name=name, + timestamp=timestamp, + span_id=span_id, + attrs=attrs, + ) + ) + return events + + +def _append_event( + handle: TraceHandle, + *, + kind: str, + name: str, + timestamp: float, + attrs: dict[str, Any] | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, +) -> dict[str, Any]: + event = { + "type": kind, + "name": name, + "ts": timestamp, + "trace_id": handle.trace_id, + "sample_id": handle.sample_id, + "group_id": handle.group_id, + "attempt": int(handle.carrier.get("attempt", handle.attempt)), + } + if span_id is not None: + event["span_id"] = span_id + if parent_span_id is not None: + event["parent_span_id"] = parent_span_id + if attrs: + event["attrs"] = dict(attrs) + handle.carrier["events"].append(event) + return event + + +def _coerce_handles(target: Sample | TraceHandle | list[Sample | TraceHandle]) -> list[TraceHandle]: + target = _adapt_trace_target(target) + if isinstance(target, TraceHandle): + return [target] + if isinstance(target, Sample): + return [bind_trace(target)] + if isinstance(target, list): + handles = [] + for item in target: + handles.extend(_coerce_handles(item)) + return handles + return [] + + +def _get_current_parent_span_id( + trace_id: str, + *, + stack: tuple[tuple[str, 
str], ...] | None = None, +) -> str | None: + stack = _TRACE_STACK.get() if stack is None else stack + for current_trace_id, span_id in reversed(stack): + if current_trace_id == trace_id: + return span_id + return None + + +def _resolve_trace_function_target( + fn, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + target: str | None, + target_getter: Callable[..., Sample | TraceHandle | list[Sample | TraceHandle] | None] | None, +): + try: + bound = inspect.signature(fn).bind_partial(*args, **kwargs) + except Exception as exc: + _log_trace_error(f"trace_function_bind:{getattr(fn, '__qualname__', fn)}", exc) + bound = None + + if target is not None: + if bound is None or target not in bound.arguments: + logger.warning( + "trace_function target '%s' not found for %s; tracing disabled for this call", + target, + getattr(fn, "__qualname__", repr(fn)), + ) + return None + resolved = _normalize_trace_target(bound.arguments.get(target)) + if resolved is None: + logger.warning( + "trace_function target '%s' for %s is not a supported trace target; tracing disabled for this call", + target, + getattr(fn, "__qualname__", repr(fn)), + ) + return resolved + + if target_getter is not None: + try: + resolved = _normalize_trace_target(target_getter(*args, **kwargs)) + return resolved + except Exception as exc: + _log_trace_error(f"trace_function_target_getter:{getattr(fn, '__qualname__', fn)}", exc) + return None + + inferred = _infer_trace_target(bound.arguments.values() if bound is not None else args) + if inferred is not None: + warn_key = getattr(fn, "__module__", "") + "." 
+ getattr(fn, "__qualname__", repr(fn)) + if warn_key not in _TRACE_AUTO_INFER_WARNED: + _TRACE_AUTO_INFER_WARNED.add(warn_key) + logger.warning( + "trace_function auto-inferred target for %s; inference may be ambiguous, prefer explicit target=...", + getattr(fn, "__qualname__", repr(fn)), + ) + return inferred + + return _get_current_trace_target() + + +def _resolve_trace_function_attrs( + fn, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + attrs_getter: Callable[..., dict[str, Any] | None] | None, +) -> dict[str, Any] | None: + if attrs_getter is None: + return None + try: + attrs = attrs_getter(*args, **kwargs) + if attrs is None: + return None + if isinstance(attrs, dict): + return attrs + logger.warning( + "trace_function attrs_getter for %s returned non-dict %s; ignoring attrs", + getattr(fn, "__qualname__", repr(fn)), + type(attrs).__name__, + ) + return None + except Exception as exc: + _log_trace_error(f"trace_function_attrs_getter:{getattr(fn, '__qualname__', fn)}", exc) + return None + + +def _infer_trace_target(values) -> Sample | TraceHandle | list[Sample | TraceHandle] | None: + for value in values: + normalized = _normalize_trace_target(value) + if normalized is not None: + return normalized + return None + + +def _normalize_trace_target(value): + value = _adapt_trace_target(value) + if isinstance(value, (Sample, TraceHandle)): + return value + if isinstance(value, list) and value: + if all(_normalize_trace_target(item) is not None for item in value): + return value + return None + + +def _adapt_trace_target(value): + if value is None: + return None + if isinstance(value, (Sample, TraceHandle)): + return value + if isinstance(value, list): + return [_adapt_trace_target(item) for item in value] + if _looks_like_sample_box(value): + generation = getattr(value, "generation", None) + if generation: + return generation + return getattr(value, "prompt_sample", None) + return value + + +def _get_current_trace_target() -> TraceHandle | 
list[TraceHandle] | None: + handle_stack = _TRACE_HANDLE_STACK.get() + if not handle_stack: + return None + current_handles = list(handle_stack[-1]) + if not current_handles: + return None + if len(current_handles) == 1: + return current_handles[0] + return current_handles + + +def _looks_like_sample_box(value: Any) -> bool: + cls = getattr(value, "__class__", None) + if cls is None or getattr(cls, "__name__", "") != "SampleBox": + return False + return hasattr(value, "prompt_sample") and hasattr(value, "generation") diff --git a/tests/plugin_contracts/test_plugin_generate_contracts.py b/tests/plugin_contracts/test_plugin_generate_contracts.py index d63a69f45f..9cf7af3a32 100644 --- a/tests/plugin_contracts/test_plugin_generate_contracts.py +++ b/tests/plugin_contracts/test_plugin_generate_contracts.py @@ -173,5 +173,29 @@ def test_custom_generate_function_path_supports_user_override(patch_generate_sta assert_sample_contract(result) +def test_generate_and_rm_group_rm_accepts_list_result_from_custom_generate(patch_generate_state, monkeypatch): + sglang_rollout = patch_generate_state + + async def custom_generate_list(args, sample: Sample, sampling_params: dict): + sample.status = Sample.Status.COMPLETED + sibling = Sample(index=1, prompt="prompt-1", status=Sample.Status.COMPLETED) + return [sample, sibling] + + monkeypatch.setattr(sglang_rollout, "load_function", lambda _path: custom_generate_list) + + result = asyncio.run( + generate_and_rm( + make_args(custom_generate_function_path="plugin_contracts.fake_generate", group_rm=True), + Sample(index=0, prompt="prompt-0"), + sampling_params={"temperature": 0.3}, + evaluation=False, + ) + ) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(sample, Sample) for sample in result) + + if __name__ == "__main__": run_contract_test_file() diff --git a/tests/utils/test_trace_utils.py b/tests/utils/test_trace_utils.py new file mode 100644 index 0000000000..01a6b527b8 --- /dev/null +++ 
b/tests/utils/test_trace_utils.py @@ -0,0 +1,83 @@ +import importlib.util +import sys +from pathlib import Path + +import pytest +import torch + +from slime.utils.trace_utils import build_sglang_meta_trace_attrs, trace_span +from slime.utils.types import Sample + + +def _load_trace_timeline_viewer_module(): + module_path = Path(__file__).resolve().parents[2] / "tools" / "trace_timeline_viewer.py" + module_name = "test_trace_timeline_viewer_module" + sys.modules.pop(module_name, None) + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +@pytest.mark.unit +def test_build_sglang_meta_trace_attrs_keeps_standard_and_pd_fields(): + meta = { + "prompt_tokens": 12, + "completion_tokens": 7, + "cached_tokens": 3, + "pd_prefill_forward_duration": 0.125, + "pd_decode_transfer_duration": None, + "finish_reason": {"type": "stop"}, + "unused_field": "ignored", + } + + assert build_sglang_meta_trace_attrs(meta) == { + "prompt_tokens": 12, + "completion_tokens": 7, + "cached_tokens": 3, + "pd_prefill_forward_duration": 0.125, + "finish_reason": "stop", + } + + +@pytest.mark.unit +def test_trace_timeline_viewer_omits_virtual_pd_lanes_without_pd_attrs(tmp_path: Path): + viewer = _load_trace_timeline_viewer_module() + sample = Sample(index=0, prompt="hello") + + with trace_span(sample, "sglang_generate", attrs={"max_new_tokens": 8}) as span: + span.update( + build_sglang_meta_trace_attrs( + { + "prompt_tokens": 4, + "completion_tokens": 2, + "cached_tokens": 1, + "finish_reason": {"type": "stop"}, + } + ) + ) + + pt_path = tmp_path / "rollout.pt" + torch.save({"samples": [sample]}, pt_path) + + cache = viewer._build_cache_data(pt_path) + + assert cache["sample_count"] == 1 + row = cache["rows"][0] + assert row["lane_count"] == 1 + assert row["item_count"] == 1 + 
assert row["closed_span_count"] == 1 + + item = row["items"][0] + assert item["name"] == "sglang_generate" + assert item["attrs"]["end_attrs"] == { + "prompt_tokens": 4, + "completion_tokens": 2, + "cached_tokens": 1, + "finish_reason": "stop", + } + assert "[P]" not in item["name"] + assert "[D]" not in item["name"] diff --git a/tools/trace_timeline_viewer.py b/tools/trace_timeline_viewer.py new file mode 100644 index 0000000000..d1658f66cc --- /dev/null +++ b/tools/trace_timeline_viewer.py @@ -0,0 +1,2369 @@ +#!/usr/bin/env python3 +"""Build and serve an interactive timeline viewer for rollout trace dumps. + +The viewer consumes a rollout debug dump `.pt` file, extracts per-sample trace +events, rebuilds spans and point events, and writes a lightweight JSON cache +plus a self-contained HTML viewer next to the source file. +""" + +from __future__ import annotations + +import argparse +import functools +import json +import pickle +import socketserver +import sys +import time +import types +from dataclasses import dataclass +from http.server import SimpleHTTPRequestHandler +from pathlib import Path +from typing import Any + +import torch + +CACHE_VERSION = 1 + + +class _MissingPickleObject: + def __setstate__(self, state: Any) -> None: + if isinstance(state, dict): + self.__dict__.update(state) + return + self.__dict__["_raw_state"] = state + + +_MISSING_PICKLE_GLOBALS: set[tuple[str, str]] = set() + + +def _ensure_dummy_module(module_name: str) -> types.ModuleType: + module = sys.modules.get(module_name) + if isinstance(module, types.ModuleType): + return module + + module = types.ModuleType(module_name) + sys.modules[module_name] = module + if "." 
in module_name: + parent_name, child_name = module_name.rsplit(".", 1) + parent = _ensure_dummy_module(parent_name) + setattr(parent, child_name, module) + return module + + +def _make_dummy_pickle_global(module_name: str, name: str) -> type[_MissingPickleObject]: + module = _ensure_dummy_module(module_name) + existing = getattr(module, name, None) + if isinstance(existing, type): + return existing + + dummy_type = type(name, (_MissingPickleObject,), {"__module__": module_name}) + setattr(module, name, dummy_type) + _MISSING_PICKLE_GLOBALS.add((module_name, name)) + return dummy_type + + +class _DummyFallbackUnpickler(pickle.Unpickler): + def find_class(self, module: str, name: str) -> Any: + try: + return super().find_class(module, name) + except (AttributeError, ImportError, ModuleNotFoundError): + return _make_dummy_pickle_global(module, name) + + +_DUMMY_FALLBACK_PICKLE_MODULE = types.SimpleNamespace( + __name__="pickle", + Unpickler=_DummyFallbackUnpickler, + load=pickle.load, + loads=pickle.loads, +) + + +@dataclass +class TimelinePaths: + pt_path: Path + cache_path: Path + html_path: Path + + +def _json_safe(value: Any) -> Any: + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, dict): + return {str(k): _json_safe(v) for k, v in value.items()} + if isinstance(value, (list, tuple)): + return [_json_safe(v) for v in value] + return str(value) + + +def _round_float(value: float | None) -> float | None: + if value is None: + return None + return round(float(value), 6) + + +def _compact_text(value: Any, max_len: int = 256) -> Any: + value = _json_safe(value) + if not isinstance(value, str): + return value + if len(value) <= max_len: + return value + return f"{value[:max_len]}..." 
+ + +def _safe_duration(start: float | None, end: float | None) -> float | None: + if start is None or end is None: + return None + return max(0.0, float(end) - float(start)) + + +def _to_sample_dict(sample: Any) -> dict[str, Any]: + if hasattr(sample, "to_dict"): + sample = sample.to_dict() + if isinstance(sample, dict): + return sample + result = {} + for key in ( + "group_index", + "index", + "prompt", + "response", + "response_length", + "reward", + "metadata", + "source", + "status", + "label", + "trace", + ): + if hasattr(sample, key): + result[key] = getattr(sample, key) + return result + + +def _infer_source(sample: dict[str, Any], metadata: dict[str, Any]) -> Any: + if sample.get("source") not in (None, ""): + return sample.get("source") + if metadata.get("source") not in (None, ""): + return metadata.get("source") + if metadata.get("source_name") not in (None, ""): + return metadata.get("source_name") + for key, value in metadata.items(): + if "source" in str(key).lower() and value not in (None, ""): + return value + return None + + +def _event_timestamp(event: dict[str, Any]) -> float | None: + ts = event.get("ts") + if ts is None: + return None + try: + return float(ts) + except (TypeError, ValueError): + return None + + +def _normalize_trace_events(trace: dict[str, Any]) -> list[dict[str, Any]]: + raw_events = trace.get("events") or [] + normalized = [] + active_stack: list[str] = [] + + for order, raw_event in enumerate(raw_events): + if not isinstance(raw_event, dict): + continue + ts = _event_timestamp(raw_event) + if ts is None: + continue + + event = { + "order": order, + "ts": ts, + "type": _json_safe(raw_event.get("type")), + "name": _json_safe(raw_event.get("name")), + "attempt": int(raw_event.get("attempt", trace.get("attempt", 0)) or 0), + "sample_id": _json_safe(raw_event.get("sample_id", trace.get("sample_id"))), + "group_id": _json_safe(raw_event.get("group_id", trace.get("group_id"))), + "span_id": _json_safe(raw_event.get("span_id")), + 
"parent_span_id": _json_safe(raw_event.get("parent_span_id")), + "attrs": _json_safe(raw_event.get("attrs") or {}), + } + event["inferred_parent_span_id"] = active_stack[-1] if active_stack else None + normalized.append(event) + + if event["type"] == "span_start" and event["span_id"]: + active_stack.append(event["span_id"]) + continue + + if event["type"] == "span_end" and event["span_id"]: + for idx in range(len(active_stack) - 1, -1, -1): + if active_stack[idx] == event["span_id"]: + del active_stack[idx] + break + + return normalized + + +def _span_name(item: dict[str, Any]) -> str: + return str(item.get("name") or "span") + + +def _span_type(item: dict[str, Any]) -> str: + if item["type"] == "event": + return "point_event" + if item["type"] == "orphan_end": + return "orphan_end" + return item["state"] + + +def _compute_span_depths(spans: list[dict[str, Any]]) -> dict[str, int]: + span_by_id = {span["span_id"]: span for span in spans if span.get("span_id")} + cache: dict[str, int] = {} + + def resolve(span_id: str | None, seen: set[str]) -> int: + if not span_id or span_id not in span_by_id: + return 0 + if span_id in cache: + return cache[span_id] + if span_id in seen: + cache[span_id] = 0 + return 0 + seen.add(span_id) + parent_id = span_by_id[span_id].get("parent_span_id") + depth = 0 if not parent_id or parent_id not in span_by_id else resolve(parent_id, seen) + 1 + cache[span_id] = depth + return depth + + for span in spans: + span_id = span.get("span_id") + if span_id: + resolve(span_id, set()) + return cache + + +def _build_items_from_trace(sample: dict[str, Any], sample_idx: int) -> dict[str, Any] | None: + trace = sample.get("trace") + if not isinstance(trace, dict): + return None + + events = _normalize_trace_events(trace) + if not events: + return None + + open_starts: dict[str, dict[str, Any]] = {} + closed_spans: list[dict[str, Any]] = [] + point_events: list[dict[str, Any]] = [] + orphan_ends: list[dict[str, Any]] = [] + all_timestamps: list[float] 
= [] + + for event in events: + all_timestamps.append(event["ts"]) + event_type = event["type"] + + if event_type == "span_start" and event["span_id"]: + open_starts[event["span_id"]] = { + "type": "span", + "state": "closed_span", + "name": event["name"], + "start_ts": event["ts"], + "end_ts": None, + "display_end_ts": None, + "attempt": event["attempt"], + "span_id": event["span_id"], + "parent_span_id": event.get("parent_span_id"), + "start_attrs": event.get("attrs") or {}, + "end_attrs": {}, + } + continue + + if event_type == "span_end": + span_id = event.get("span_id") + start_record = open_starts.pop(span_id, None) if span_id else None + if start_record is None: + orphan_ends.append( + { + "type": "orphan_end", + "state": "orphan_end", + "name": event["name"], + "ts": event["ts"], + "attempt": event["attempt"], + "span_id": span_id, + "parent_span_id": event.get("parent_span_id") or event.get("inferred_parent_span_id"), + "attrs": event.get("attrs") or {}, + } + ) + continue + + start_record["end_ts"] = event["ts"] + start_record["display_end_ts"] = event["ts"] + start_record["end_attrs"] = event.get("attrs") or {} + closed_spans.append(start_record) + continue + + point_events.append( + { + "type": "event", + "state": "point_event", + "name": event["name"], + "ts": event["ts"], + "attempt": event["attempt"], + "span_id": None, + "parent_span_id": event.get("inferred_parent_span_id"), + "attrs": event.get("attrs") or {}, + } + ) + + row_end_ts = max(all_timestamps) if all_timestamps else None + open_spans = list(open_starts.values()) + all_spans = closed_spans + open_spans + span_depths = _compute_span_depths(all_spans) + span_by_id = {span["span_id"]: span for span in all_spans if span.get("span_id")} + sibling_groups: dict[str | None, list[dict[str, Any]]] = {} + + for span in all_spans: + sibling_groups.setdefault(span.get("parent_span_id"), []).append(span) + + for siblings in sibling_groups.values(): + siblings.sort(key=lambda item: (item["start_ts"], 
item["end_ts"] or float("inf"), _span_name(item))) + + def nearest_closed_ancestor_end(span: dict[str, Any]) -> float | None: + current_parent = span.get("parent_span_id") + while current_parent: + parent = span_by_id.get(current_parent) + if parent is None: + return None + if parent.get("end_ts") is not None: + return float(parent["end_ts"]) + current_parent = parent.get("parent_span_id") + return None + + for span in open_spans: + candidates: list[tuple[float, str]] = [] + if row_end_ts is not None: + candidates.append((row_end_ts, "row_end")) + + siblings = sibling_groups.get(span.get("parent_span_id"), []) + for sibling in siblings: + if sibling is span: + continue + sibling_start = float(sibling["start_ts"]) + if sibling_start > float(span["start_ts"]): + candidates.append((sibling_start, "next_sibling_start")) + break + + ancestor_end = nearest_closed_ancestor_end(span) + if ancestor_end is not None and ancestor_end > float(span["start_ts"]): + candidates.append((ancestor_end, "ancestor_end")) + + if candidates: + display_end_ts, clipped_by = min(candidates, key=lambda item: item[0]) + if display_end_ts <= float(span["start_ts"]): + display_end_ts = float(span["start_ts"]) + clipped_by = "self" + else: + display_end_ts = float(span["start_ts"]) + clipped_by = "self" + + span["state"] = "open_span" + span["display_end_ts"] = display_end_ts + span.setdefault("end_attrs", {}) + span["end_attrs"]["clipped_by"] = clipped_by + + for span in all_spans: + span["depth"] = span_depths.get(span.get("span_id") or "", 0) + span["lane"] = span["depth"] + + for event in point_events: + parent_span_id = event.get("parent_span_id") + event["depth"] = span_depths.get(parent_span_id or "", 0) + event["lane"] = event["depth"] + + for item in orphan_ends: + parent_span_id = item.get("parent_span_id") + item["depth"] = span_depths.get(parent_span_id or "", 0) + item["lane"] = item["depth"] + + def parent_span_name(parent_span_id: str | None) -> str | None: + if not parent_span_id: 
+ return None + parent = span_by_id.get(parent_span_id) + if not parent: + return None + return parent.get("name") + + all_items: list[dict[str, Any]] = [] + for span in all_spans: + all_items.append( + { + "type": "span", + "state": span["state"], + "name": span["name"], + "start_ts": _round_float(span["start_ts"]), + "end_ts": _round_float(span["end_ts"]), + "display_end_ts": _round_float(span["display_end_ts"]), + "attempt": span["attempt"], + "span_id": span.get("span_id"), + "parent_span_id": span.get("parent_span_id"), + "parent_span_name": parent_span_name(span.get("parent_span_id")), + "lane": span["lane"], + "depth": span["depth"], + "attrs": { + "start_attrs": _json_safe(span.get("start_attrs") or {}), + "end_attrs": _json_safe(span.get("end_attrs") or {}), + }, + } + ) + + for event in point_events: + all_items.append( + { + "type": "event", + "state": "point_event", + "name": event["name"], + "ts": _round_float(event["ts"]), + "attempt": event["attempt"], + "span_id": None, + "parent_span_id": event.get("parent_span_id"), + "parent_span_name": parent_span_name(event.get("parent_span_id")), + "lane": event["lane"], + "depth": event["depth"], + "attrs": _json_safe(event.get("attrs") or {}), + } + ) + + for item in orphan_ends: + all_items.append( + { + "type": "orphan_end", + "state": "orphan_end", + "name": item["name"], + "ts": _round_float(item["ts"]), + "attempt": item["attempt"], + "span_id": item.get("span_id"), + "parent_span_id": item.get("parent_span_id"), + "parent_span_name": parent_span_name(item.get("parent_span_id")), + "lane": item["lane"], + "depth": item["depth"], + "attrs": _json_safe(item.get("attrs") or {}), + } + ) + + pd_lane_specs = [ + ( + "prefill", + "P", + [ + "pd_prefill_bootstrap_queue_duration", + "pd_bootstrap_duration", + "pd_alloc_waiting_duration", + "pd_prefill_forward_duration", + "pd_prefill_transfer_queue_duration", + ], + ), + ( + "decode", + "D", + [ + "pd_decode_prealloc_duration", + "pd_decode_transfer_duration", 
+ "pd_decode_forward_duration", + ], + ), + ] + next_virtual_lane = max((item["lane"] for item in all_items), default=-1) + for span in all_spans: + if span["state"] != "closed_span" or span.get("end_ts") is None: + continue + end_attrs = span.get("end_attrs") or {} + for role, suffix, keys in pd_lane_specs: + role_attrs = { + key: value for key in keys if isinstance((value := end_attrs.get(key)), (int, float)) and value > 0 + } + if not role_attrs: + continue + next_virtual_lane += 1 + role_attrs.update( + { + "timeline_pd_virtual_role": role, + "timeline_pd_parent_name": span["name"], + "timeline_pd_parent_duration": _round_float(_safe_duration(span["start_ts"], span["end_ts"])), + } + ) + all_items.append( + { + "type": "span", + "state": "closed_span", + "name": f'{span["name"]} [{suffix}]', + "start_ts": _round_float(span["start_ts"]), + "end_ts": _round_float(span["end_ts"]), + "display_end_ts": _round_float(span["display_end_ts"]), + "attempt": span["attempt"], + "span_id": f'{span.get("span_id") or span["name"]}:pd:{role}', + "parent_span_id": span.get("span_id"), + "parent_span_name": span["name"], + "lane": next_virtual_lane, + "depth": next_virtual_lane, + "attrs": { + "start_attrs": {}, + "end_attrs": _json_safe(role_attrs), + }, + } + ) + + all_items.sort( + key=lambda item: ( + item["lane"], + item.get("start_ts", item.get("ts", 0.0)), + item.get("display_end_ts", item.get("ts", 0.0)), + item["name"], + ) + ) + + row_start = min(item.get("start_ts", item.get("ts")) for item in all_items) + row_end = max(item.get("display_end_ts", item.get("ts")) for item in all_items) + response_lengths = [] + for item in all_items: + attrs = item.get("attrs") or {} + for payload in (attrs, attrs.get("start_attrs"), attrs.get("end_attrs")): + if not isinstance(payload, dict): + continue + response_length = payload.get("response_length") + if isinstance(response_length, (int, float)): + response_lengths.append(int(response_length)) + + metadata = sample.get("metadata") 
or {} + if not isinstance(metadata, dict): + metadata = {} + + reward = sample.get("reward") + if isinstance(reward, dict): + reward = _json_safe(reward) + + return { + "row_id": sample_idx, + "sample_index": sample.get("index", sample_idx), + "group_index": sample.get("group_index"), + "source": _compact_text(_infer_source(sample, metadata), max_len=64), + "status": _compact_text(sample.get("status"), max_len=64), + "label": _compact_text(sample.get("label"), max_len=256), + "reward": reward, + "trace_id": _json_safe(trace.get("trace_id")), + "attempt": int(trace.get("attempt", 0) or 0), + "start": row_start, + "end": row_end, + "duration": _round_float(_safe_duration(row_start, row_end)), + "lane_count": 1 + max((item["lane"] for item in all_items), default=0), + "item_count": len(all_items), + "closed_span_count": sum(1 for item in all_items if item["state"] == "closed_span"), + "open_span_count": sum(1 for item in all_items if item["state"] == "open_span"), + "point_event_count": sum(1 for item in all_items if item["state"] == "point_event"), + "orphan_count": sum(1 for item in all_items if item["state"] == "orphan_end"), + "total_response_length": sum(response_lengths), + "max_response_length": max(response_lengths, default=0), + "items": all_items, + } + + +def _build_cache_data(pt_path: Path) -> dict[str, Any]: + before_missing = len(_MISSING_PICKLE_GLOBALS) + data = torch.load( + pt_path, + map_location="cpu", + weights_only=False, + pickle_module=_DUMMY_FALLBACK_PICKLE_MODULE, + ) + if len(_MISSING_PICKLE_GLOBALS) > before_missing: + missing_names = ", ".join(f"{module}.{name}" for module, name in sorted(_MISSING_PICKLE_GLOBALS)) + print( + f"[trace_timeline_viewer] substituted missing pickle globals with dummy classes: {missing_names}", + file=sys.stderr, + ) + samples = data["samples"] if isinstance(data, dict) and "samples" in data else data + + rows: list[dict[str, Any]] = [] + global_start = None + global_end = None + + for sample_idx, raw_sample in 
enumerate(samples): + sample = _to_sample_dict(raw_sample) + row = _build_items_from_trace(sample, sample_idx) + if row is None: + continue + rows.append(row) + global_start = row["start"] if global_start is None else min(global_start, row["start"]) + global_end = row["end"] if global_end is None else max(global_end, row["end"]) + + return { + "cache_version": CACHE_VERSION, + "pt_path": str(pt_path), + "generated_at": time.time(), + "sample_count": len(rows), + "global_start": _round_float(global_start), + "global_end": _round_float(global_end), + "rows": rows, + } + + +def _timeline_paths(pt_path: Path) -> TimelinePaths: + stem = pt_path.stem + directory = pt_path.parent + return TimelinePaths( + pt_path=pt_path, + cache_path=directory / f"{stem}.trace_timeline_cache.json", + html_path=directory / f"{stem}.trace_timeline_viewer.html", + ) + + +def ensure_cache(paths: TimelinePaths, rebuild: bool = False) -> dict[str, Any]: + if not rebuild and paths.cache_path.exists() and paths.cache_path.stat().st_mtime >= paths.pt_path.stat().st_mtime: + with paths.cache_path.open("r", encoding="utf-8") as handle: + cached = json.load(handle) + if cached.get("cache_version") == CACHE_VERSION: + return cached + + cache_data = _build_cache_data(paths.pt_path) + with paths.cache_path.open("w", encoding="utf-8") as handle: + json.dump(cache_data, handle, ensure_ascii=True, separators=(",", ":")) + return cache_data + + +HTML_TEMPLATE = r""" + + + + + __TITLE__ + + + +
+
+
+
+ Trace Timeline + +
+ +
+
+
+ + + + + + + + + + + + +
+
+ + drag = pan, wheel = zoom, click = set cursor and select item + + +
+
+
+ +
+
+
+ +
+
+
+
+ +
+
+ +
+
+
def _render_timeline_html(template: str, cache_file_name: str, title: str) -> str:
    """Return *template* with its two placeholders filled in.

    ``__CACHE_FILE__`` is replaced verbatim with *cache_file_name*.
    ``__TITLE__`` is replaced with *title* after escaping ``&``, ``<`` and
    ``>``, so a dump file name containing markup characters cannot corrupt
    the generated page.
    """
    import html  # function-scope import: nothing else in this block needs it

    # NOTE(review): escaping assumes __TITLE__ sits in HTML text (e.g. the
    # document title) rather than inside a script string — confirm against
    # the template.
    safe_title = html.escape(title, quote=False)
    return template.replace("__CACHE_FILE__", cache_file_name).replace("__TITLE__", safe_title)


def ensure_html(paths: "TimelinePaths") -> None:
    """Write the viewer page for *paths* to ``paths.html_path``.

    The page is rendered from ``HTML_TEMPLATE``; the cache file is referenced
    by its file name only (``paths.cache_path.name``), and the title is built
    from the dump file's name. An existing HTML file is overwritten.
    """
    title = f"{paths.pt_path.name} trace timeline"
    page = _render_timeline_html(HTML_TEMPLATE, paths.cache_path.name, title)
    with paths.html_path.open("w", encoding="utf-8") as handle:
        handle.write(page)


class QuietHandler(SimpleHTTPRequestHandler):
    """Static-file request handler that suppresses per-request log output."""

    def log_message(self, format: str, *args: Any) -> None:
        # The parameter name mirrors the base-class signature; doing nothing
        # here is what silences the default per-request logging.
        return


class _ReusableTCPServer(socketserver.TCPServer):
    """TCP server that can rebind its port right after a previous run exits."""

    # socketserver.TCPServer leaves this off, so restarting the viewer shortly
    # after stopping it can fail with "Address already in use".
    allow_reuse_address = True


def serve_directory(directory: Path, port: int, host: str = "0.0.0.0") -> None:
    """Serve *directory* over HTTP until interrupted with Ctrl+C.

    Args:
        directory: Root of the files to serve.
        port: TCP port to listen on; ``0`` lets the OS choose a free port.
        host: Address to bind. The default keeps the previous behaviour of
            listening on every interface; pass ``"127.0.0.1"`` to make the
            server reachable from this machine only.
    """
    handler = functools.partial(QuietHandler, directory=str(directory))
    with _ReusableTCPServer((host, port), handler) as httpd:
        # Report the port that was actually bound (differs from *port* when 0).
        bound_port = httpd.server_address[1]
        wildcard = host in ("", "0.0.0.0")
        shown_host = "127.0.0.1" if wildcard else host
        print(f"Serving http://{shown_host}:{bound_port}/")
        if wildcard:
            # The URL above is only the local one; be explicit about exposure.
            print("Listening on all network interfaces; use host 127.0.0.1 to restrict to this machine.")
        print("Press Ctrl+C to stop.")
        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            # Ctrl+C is the documented way to stop; exit without a traceback.
            pass


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        Namespace with ``pt_path``, ``rebuild``, ``serve``, ``port`` and
        ``host``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("pt_path", help="Path to rollout debug dump .pt file")
    parser.add_argument("--rebuild", action="store_true", help="Rebuild cache even if it already exists")
    parser.add_argument(
        "--serve",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Start a local static file server for the generated HTML",
    )
    parser.add_argument("--port", type=int, default=9999, help="Port for --serve")
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="Address to bind for --serve (use 127.0.0.1 for local-only access)",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Build (or reuse) the cache, write the viewer page, and optionally serve it.

    Raises:
        SystemExit: If the given .pt file does not exist.
    """
    args = parse_args()
    pt_path = Path(args.pt_path).expanduser().resolve()
    if not pt_path.exists():
        raise SystemExit(f"pt file not found: {pt_path}")

    paths = _timeline_paths(pt_path)
    cache_data = ensure_cache(paths, rebuild=args.rebuild)
    ensure_html(paths)

    print(f"pt: {paths.pt_path}")
    print(f"cache: {paths.cache_path}")
    print(f"html: {paths.html_path}")
    print(f"samples: {cache_data['sample_count']}")

    if args.serve:
        serve_directory(paths.html_path.parent, args.port, host=args.host)


if __name__ == "__main__":
    main()