From 19363670eb5b347ff33e20346d396035213a83a2 Mon Sep 17 00:00:00 2001 From: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com> Date: Thu, 9 Apr 2026 10:41:11 -0700 Subject: [PATCH 1/4] [#11548][feat] AutoDeploy: Optimize Qwen3.5 perf (#12265) Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com> --- .../configs/qwen3.5_moe_400b.yaml | 20 +- .../models/custom/modeling_qwen3_5_moe.py | 41 ++- .../transform/library/quantization.py | 4 + .../auto_deploy/transform/library/sharding.py | 128 ++++++++- .../_torch/auto_deploy/utils/_graph.py | 10 + .../library/test_tp_sharding.py | 245 +++++++++++++++++- .../singlegpu/models/test_qwen3_5_moe.py | 13 +- 7 files changed, 437 insertions(+), 24 deletions(-) diff --git a/examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml b/examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml index 3056ab839b1..89ccabe63c8 100644 --- a/examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml +++ b/examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml @@ -2,12 +2,16 @@ runtime: trtllm compile_backend: torch-cudagraph attn_backend: trtllm max_seq_len: 262144 -max_num_tokens: 8192 -max_batch_size: 32 +max_num_tokens: 16000 +max_batch_size: 256 cuda_graph_config: - batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + batch_sizes: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 64, 128, 256] +world_size: 8 enable_chunked_prefill: true -model_factory: Qwen3_5MoeForConditionalGeneration +# For text-only mode, use AutoModelForCausalLM until issue #12699 is resolved +# Once issue #12699 is resolved, consider to unify the factory to Qwen3_5MoeForConditionGeneration for both VLM and text mode +# model_factory: Qwen3_5MoeForConditionalGeneration +model_factory: AutoModelForCausalLM kv_cache_config: enable_block_reuse: false free_gpu_memory_fraction: 0.8 @@ -15,13 +19,18 @@ 
kv_cache_config: model_kwargs: torch_dtype: bfloat16 transforms: + # disable for text only use case initialize_mrope_delta_cache: enabled: true export_to_gm: num_moe_experts_for_export: 2 fuse_gemms_mixed_children: enabled: true + fuse_nvfp4_moe: + backend: trtllm_gen detect_sharding: + # for long input, tp8ep1 gives better performance + # dist_mapping: {moe_tp: 8, moe_ep: 1} allreduce_strategy: SYMM_MEM shard_all_unprocessed: true simple_shard_filter: "lm_head" @@ -37,6 +46,9 @@ transforms: "k_proj": "colwise" "v_proj": "colwise" "o_proj": "rowwise" + # lm_head: "gather" = column split + all_gather (not "colwise" which + # requires a LayerSubgraph and crashes for standalone unprocessed nodes) + "lm_head": "gather" # replicating shared experts (keep them commented out) # "shared_expert_gate_proj": "colwise" # "shared_expert_up_proj": "colwise" diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py index e227bc7ebec..6afe6481165 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py @@ -727,10 +727,15 @@ class Qwen3_5MoeCausalLMOutput(ModelOutput): """Output of the Qwen3.5 MoE causal language model.""" logits: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None class Qwen3_5MoeTextModel(Qwen3_5MoePreTrainedModel): - """Qwen3.5 MoE text model (embed + decoder layers + final norm).""" + """Qwen3.5 MoE text model (embed + decoder layers + final norm + lm_head). + + lm_head is included so that the exported GraphModule contains it directly, + allowing sharding and gather_logits_before_lm_head transforms to see it. 
+ """ def __init__(self, config: Qwen3_5MoeTextConfig): super().__init__(config) @@ -746,10 +751,15 @@ def __init__(self, config: Qwen3_5MoeTextConfig): ) self.norm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen3_5MoeTextRotaryEmbedding(config=config) + self.lm_head = None # set by parent model via set_lm_head() # Initialize weights and apply final processing self.post_init() + def set_lm_head(self, lm_head: nn.Module): + """Set the lm_head from the parent model.""" + self.lm_head = lm_head + def get_input_embeddings(self): return self.embed_tokens @@ -801,7 +811,11 @@ def forward( hidden_states = decoder_layer(hidden_states, position_embeddings=position_embeddings) hidden_states = self.norm(hidden_states) - return Qwen3_5MoeOutput(last_hidden_state=hidden_states) + assert self.lm_head is not None, ( + "lm_head not set — call set_lm_head() from the parent model before forward()" + ) + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + return Qwen3_5MoeCausalLMOutput(logits=logits, last_hidden_state=hidden_states) class Qwen3_5MoeForCausalLM(Qwen3_5MoePreTrainedModel, GenerationMixin): @@ -814,6 +828,7 @@ def __init__(self, config: Qwen3_5MoeTextConfig, **kwargs): self.model = Qwen3_5MoeTextModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.model.set_lm_head(self.lm_head) # Initialize weights and apply final processing self.post_init() @@ -829,6 +844,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + self.model.set_lm_head(new_embeddings) def forward( self, @@ -848,8 +864,7 @@ def forward( rope_cos=rope_cos, rope_sin=rope_sin, ) - hidden_states = outputs[0] - logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + logits = outputs.logits return Qwen3_5MoeCausalLMOutput(logits=logits) @@ -2565,10 +2580,19 @@ def __init__(self, 
config: Qwen3_5MoeConfig, **kwargs): self.lm_head = nn.Linear( config.text_config.hidden_size, config.text_config.vocab_size, bias=False ) + # Share lm_head with the text model so it's inside the exported graph + self.model.language_model.set_lm_head(self.lm_head) # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.model.language_model.get_input_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + self.model.language_model.set_lm_head(new_embeddings) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -2590,8 +2614,7 @@ def forward( video_grid_thw=video_grid_thw, **kwargs, ) - hidden_states = outputs.last_hidden_state - logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + logits = outputs.logits return Qwen3_5MoeConditionalOutput(logits=logits) @@ -2607,6 +2630,9 @@ class Qwen3_5MoeTextExportInfo(TextModelExportInfo): (batch, sequence) are dynamic. 
""" + def __init__(self, submodule_name: str): + super().__init__(submodule_name) + def _init_dynamic_shape_lookup(self): base = super()._init_dynamic_shape_lookup() batch_size_dyn = Dim.DYNAMIC @@ -2858,4 +2884,7 @@ def init_input_processor(self, base): AutoConfig.register("qwen3_5_moe_text", Qwen3_5MoeTextConfig) AutoModelForCausalLMFactory.register_custom_model_cls("Qwen3_5MoeTextConfig", Qwen3_5MoeForCausalLM) +AutoModelForCausalLMFactory.register_custom_model_cls( + "Qwen3_5MoeConfig", Qwen3_5MoeForConditionalGeneration +) Qwen3_5MoeFactory.register_custom_model_cls("Qwen3_5MoeConfig", Qwen3_5MoeForConditionalGeneration) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index a3cf9661ee3..633d0500e9f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -405,6 +405,10 @@ def build_custom_args_for_linear(self, scales: Dict[str, Node]) -> Tuple: return ([scales["input_scale"]], [scales["weight_scale"], scales["alpha"]], [], []) def load_hook(self, state_dict, prefix, *args, weight_name): + # Prepend prefix so the hook works when the GraphModule is a submodule + # of the model on which load_state_dict is called (e.g., VLM models + # where the text model lives at model.language_model.*). 
+ weight_name = prefix + weight_name if weight_name in state_dict: input_scale_name = weight_name.rsplit(".", 1)[0] + ".input_scale" alpha_name = weight_name.rsplit(".", 1)[0] + ".alpha" diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 2983b2c4bf6..ba9b1a325c4 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -674,7 +674,8 @@ def shard_load_hook( world_size: int, min_local_shape: int = 1, ) -> None: - scale_key = weight_name + "_scale_inv" + # Prepend prefix for VLM models where gm is a submodule + scale_key = prefix + weight_name + "_scale_inv" if scale_key in state_dict: scale = state_dict[scale_key] weight_original_n = weight_original_shape[dim] @@ -759,7 +760,8 @@ def shard_load_hook( world_size: int, min_local_shape: int = 1, ) -> None: - key = weight_name + "_scale" + # Prepend prefix for VLM models where gm is a submodule + key = prefix + weight_name + "_scale" if key in state_dict: state_dict[key] = _shard_fp4_weight_scale( state_dict[key], @@ -1607,19 +1609,23 @@ def f_split( sharded_weight = f_split(weight_tensor) sharded_shape = sharded_weight.shape - # Register load hook - gm._register_load_state_dict_pre_hook( + # Update the parameter in the module + modname, _, param_name = param_key.rpartition(".") + submod = gm.get_submodule(modname) + + # Register load hook on the owning submodule (not the top-level gm). + # This ensures the hook runs *after* any parent-level hooks that transform + # the state_dict (e.g., unfusing fused MoE checkpoint weights into + # individual expert keys). With the hook on gm, it would run before + # unfusing and fail to find the individual expert keys. 
+ submod._register_load_state_dict_pre_hook( partial( _load_hook, f_split=f_split, - param_key=param_key, + param_key=param_name, param_shape=sharded_shape, ) ) - - # Update the parameter in the module - modname, _, param_name = param_key.rpartition(".") - submod = gm.get_submodule(modname) param_new = nn.Parameter(sharded_weight.detach().clone(), requires_grad=requires_grad) setattr(submod, param_name, param_new) @@ -1786,6 +1792,84 @@ def _merge_arg(current_arg: Any, stored_arg: Any) -> Any: ad_logger.debug(f"Updated node {node}: sharded arguments are now {node.args}.") +def _shard_nvfp4_moe_scale( + scale: torch.Tensor, + orig_weight_shape: torch.Size, + dim: int, + rank: int, + world_size: int, +) -> torch.Tensor: + """Shard NVFP4 weight_scale for MoE TP, preserving 2D cutlass format. + + Unlike _shard_fp4_weight_scale (which returns 1D), this returns a 2D tensor + with the correct padded shape, matching the format expected by MoE stacking. + """ + weight_shape_elements = list(orig_weight_shape) + weight_shape_elements[-1] *= 2 # uint8 -> element count (FP4 packs 2 per byte) + modelopt_scale = cutlass_fp4_scale_to_modelopt_fp4_scale(scale, tuple(weight_shape_elements)) + sharded = _split_tensor_for_tp(modelopt_scale, dim, rank, world_size) + m, n = sharded.shape + # Pad to match CUTLASS FP4 scale swizzle alignment requirements: + # 128 rows (4 * 32 tile in M dim) and 4 columns (N dim grouping). + # See modelopt_fp4_scale_to_cutlass_fp4_scale in quantization_utils.py. + pad_m = (128 - m % 128) % 128 + pad_n = (4 - n % 4) % 4 + result_1d = modelopt_fp4_scale_to_cutlass_fp4_scale(sharded) + return result_1d.reshape(m + pad_m, n + pad_n) + + +def _tp_shard_moe_scale( + gm: GraphModule, + scale_node: Node, + scale_name: str, + dim: int, + rank: int, + world_size: int, + orig_weight_shape: torch.Size, +) -> None: + """TP-shard a single MoE expert's blocked scale tensor. + + For NVFP4 (weight_scale): converts from cutlass format, splits, reconverts to 2D. 
+ For FineGrained FP8 (weight_scale_inv): directly splits the 2D scale tensor. + """ + param_key = scale_node.target + modname, _, attr_name = param_key.rpartition(".") + submod = gm.get_submodule(modname) + scale_tensor = submod.get_buffer(attr_name) + + if scale_name == "weight_scale": + f_split = partial( + _shard_nvfp4_moe_scale, + orig_weight_shape=orig_weight_shape, + dim=dim, + rank=rank, + world_size=world_size, + ) + elif scale_name == "weight_scale_inv": + f_split = partial( + FineGrainedFP8WeightShardingInfo._split_scale, + dim=dim, + rank=rank, + world_size=world_size, + ) + else: + return + + sharded_scale = f_split(scale_tensor) + submod.register_buffer(attr_name, sharded_scale) + + # Register load hook on the owning submodule so it runs after any + # parent-level checkpoint format conversion hooks (e.g., fused MoE unfusing). + submod._register_load_state_dict_pre_hook( + partial( + _load_hook, + f_split=f_split, + param_key=attr_name, + param_shape=sharded_scale.shape, + ) + ) + + def _insert_sharded_moe( gm: GraphModule, node: Node, @@ -1846,6 +1930,11 @@ def get_partition(lst, world_size, rank): # if tp_size > 1, we do 2D EP+TP sharding. if tp_size > 1: + # Capture original weight shapes before TP sharding (needed for scale TP sharding) + w_up_orig_shapes = [gm.get_parameter(w.target).shape for w in w_up_list_sharded] + w_down_orig_shapes = [gm.get_parameter(w.target).shape for w in w_down_list_sharded] + w_gate_orig_shapes = [gm.get_parameter(w.target).shape for w in w_gate_list_sharded] + # we add TP sharding of all expert weights. 
for w_up in w_up_list_sharded + w_gate_list_sharded: shard_weight_tensor( @@ -1882,6 +1971,27 @@ def get_partition(lst, world_size, rank): args[6 + i] = sharded scales_to_remove.extend(to_remove) + # ===================================================================================== + # TP-shard blocked scales (weight_scale for NVFP4, weight_scale_inv for FineGrained FP8) + # ===================================================================================== + if tp_size > 1 and scale_names: + _BLOCKED_SCALE_NAMES = {"weight_scale", "weight_scale_inv"} + for s_idx, s_name in enumerate(scale_names): + if s_name not in _BLOCKED_SCALE_NAMES: + continue + # For each scale_name, the 3 lists correspond to w_up, w_down, w_gate + # w_up/w_gate use COLUMN split (dim=0), w_down uses ROW split (dim=1) + scale_dim_groups = [ + (6 + s_idx * 3 + 0, SplitDimension.COLUMN, w_up_orig_shapes), + (6 + s_idx * 3 + 1, SplitDimension.ROW, w_down_orig_shapes), + (6 + s_idx * 3 + 2, SplitDimension.COLUMN, w_gate_orig_shapes), + ] + for arg_idx, dim, orig_shapes in scale_dim_groups: + for j, scale_node in enumerate(args[arg_idx]): + _tp_shard_moe_scale( + gm, scale_node, s_name, dim, tp_rank, tp_size, orig_shapes[j] + ) + if enable_alltoall: # --------------------------------------------------------------------------- # ALL-TO-ALL PARADIGM diff --git a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py index 6ba653cec9b..117301c3302 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py @@ -571,6 +571,16 @@ def get_lm_head_node(gm: GraphModule, output_node: Optional[Node] = None) -> Nod if is_op(lm_head_node, torch.ops.aten.to): lm_head_node = lm_head_node.all_input_nodes[0] + # Unwrap all_gather for sharded lm_head: when lm_head weight is column- + # sharded the graph contains lm_head_linear -> all_gather -> output. + # We look through the all_gather so that callers (e.g. 
+ # gather_logits_before_lm_head) see the underlying linear and can insert + # gather_tokens *before* the sharded GEMM + all_gather, keeping both out + # of the main CUDA graph and avoiding NVLink contention with layer + # AllReduces. + if is_op(lm_head_node, torch.ops.auto_deploy.trtllm_dist_all_gather): + lm_head_node = lm_head_node.all_input_nodes[0] + return lm_head_node diff --git a/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py index fe1827fa216..f1667be8e2c 100644 --- a/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py @@ -11,7 +11,8 @@ import torch.nn.functional as F from _dist_test_utils import get_device_counts from _graph_test_helpers import run_sharding_pattern_detection_test, run_test_transformed_gm -from _model_test_utils import FakeFineGrainedFP8Linear, FakeFP8Linear +from _model_test_utils import FakeFineGrainedFP8Linear, FakeFP8Linear, MoEOpModel +from _torch_test_utils import fp4_compatible, trtllm_ops_available from torch._inductor.pattern_matcher import stable_topological_sort import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common @@ -31,6 +32,7 @@ from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op, is_weight_node from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import ( cutlass_fp4_scale_to_modelopt_fp4_scale, + fp4_global_scale, modelopt_fp4_scale_to_cutlass_fp4_scale, ) from tensorrt_llm.functional import AllReduceStrategy @@ -1373,3 +1375,244 @@ def test_pad_nvfp4_weight_scale_roundtrip(n, k): if k_padded > k: pad_region_k = padded_modelopt[:n, k // block_size :] assert (pad_region_k.float() == 0).all(), "k-padding region should be zero" + + +class NVFP4MoEOpModel(nn.Module): + """NVFP4-quantized MoE model using torch.ops.trtllm.fp4_quantize. 
+ + Mimics the real Qwen3.5-MoE NVFP4 checkpoint loading path where weights are + quantized via fp4_quantize and scales are in cutlass swizzled 2D format. + Dimensions must be compatible with NVFP4 block size (16) and cutlass alignment. + """ + + SCALING_VECTOR_SIZE = 16 + + def __init__(self, hidden_size=128, intermediate_size=256, num_experts=4, top_k=2): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts = num_experts + self.top_k = top_k + + self.gate = nn.Linear(hidden_size, num_experts) + + for i in range(num_experts): + w1_bf16 = (torch.randn(intermediate_size, hidden_size) * 0.1).to(torch.bfloat16) + w2_bf16 = (torch.randn(hidden_size, intermediate_size) * 0.1).to(torch.bfloat16) + w3_bf16 = (torch.randn(intermediate_size, hidden_size) * 0.1).to(torch.bfloat16) + + inp_scale = fp4_global_scale(w1_bf16) + + for prefix, w_bf16 in [("w1", w1_bf16), ("w2", w2_bf16), ("w3", w3_bf16)]: + wt_scale_2 = fp4_global_scale(w_bf16) + w_fp4, w_scale_1d = torch.ops.trtllm.fp4_quantize( + w_bf16.cuda(), wt_scale_2.cuda(), self.SCALING_VECTOR_SIZE, False + ) + _, k_packed = w_fp4.shape + k_elements = k_packed * 2 # uint8 packs 2 fp4 values + n_scale = k_elements // self.SCALING_VECTOR_SIZE + m_scale = w_scale_1d.numel() // n_scale + w_scale_2d = w_scale_1d.reshape(m_scale, n_scale).contiguous() + + self.register_parameter( + f"expert_{i}_{prefix}", + nn.Parameter(w_fp4.cpu(), requires_grad=False), + ) + self.register_buffer(f"expert_{i}_{prefix}_input_scale", inp_scale) + self.register_buffer(f"expert_{i}_{prefix}_weight_scale", w_scale_2d.cpu()) + self.register_buffer( + f"expert_{i}_{prefix}_alpha", + (1.0 / (inp_scale * wt_scale_2)).cpu(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + router_logits = self.gate(x) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights = 
routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(x.dtype) + + w1 = [getattr(self, f"expert_{i}_w1") for i in range(self.num_experts)] + w2 = [getattr(self, f"expert_{i}_w2") for i in range(self.num_experts)] + w3 = [getattr(self, f"expert_{i}_w3") for i in range(self.num_experts)] + w1_is = [getattr(self, f"expert_{i}_w1_input_scale") for i in range(self.num_experts)] + w2_is = [getattr(self, f"expert_{i}_w2_input_scale") for i in range(self.num_experts)] + w3_is = [getattr(self, f"expert_{i}_w3_input_scale") for i in range(self.num_experts)] + w1_ws = [getattr(self, f"expert_{i}_w1_weight_scale") for i in range(self.num_experts)] + w2_ws = [getattr(self, f"expert_{i}_w2_weight_scale") for i in range(self.num_experts)] + w3_ws = [getattr(self, f"expert_{i}_w3_weight_scale") for i in range(self.num_experts)] + w1_a = [getattr(self, f"expert_{i}_w1_alpha") for i in range(self.num_experts)] + w2_a = [getattr(self, f"expert_{i}_w2_alpha") for i in range(self.num_experts)] + w3_a = [getattr(self, f"expert_{i}_w3_alpha") for i in range(self.num_experts)] + + return torch.ops.auto_deploy.torch_quant_nvfp4_moe( + x, + selected_experts, + routing_weights, + w1, + w2, + w3, + w1_is, + w2_is, + w3_is, + w1_ws, + w2_ws, + w3_ws, + w1_a, + w2_a, + w3_a, + ) + + def get_input(self, device, dtype=torch.bfloat16): + return torch.randn(2, self.hidden_size, device=device, dtype=dtype) + + +def _run_nvfp4_moe_tp_shard_job( + num_experts: int, + _rank: int, + world_size: int, +) -> None: + """Run NVFP4 MoE TP sharding test. 
See NVFP4MoEOpModel for scale format details.""" + device = "cuda" + hidden_size = 128 + intermediate_size = 256 + model = NVFP4MoEOpModel( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts=num_experts, + ).to(device=device) + x = model.get_input(device=device, dtype=torch.bfloat16) + + gm = torch_export_to_gm(model, args=(x,), clone=True) + + # Apply MoE TP sharding with moe_tp=world_size, moe_ep=1 + gm_transformed = InferenceOptimizer( + None, + { + "detect_sharding": { + "stage": "sharding", + "sharding_dims": ["ep"], + "dist_mapping": {"moe_tp": world_size, "moe_ep": 1}, + }, + "sharding_transform_executor": { + "stage": "sharding", + }, + }, + )(None, gm) + + # Verify all_reduce is inserted after MoE node + allreduce_correct = any( + is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce) for n in gm_transformed.graph.nodes + ) == (world_size > 1) + assert allreduce_correct, f"Expected all_reduce present={world_size > 1} after MoE TP sharding" + + # Verify: NVFP4 expert weights should be sharded along TP dimension + # FP4 weights are uint8-packed (2 elements per byte), so packed dim is halved + if world_size > 1: + for name, param in gm_transformed.named_parameters(): + if "experts" in name and "w1" in name: + # w1 (up_proj) column-sharded: intermediate_size // world_size rows + assert param.shape[0] == intermediate_size // world_size, ( + f"w1 {name} shape {param.shape} not TP-sharded " + f"(expected dim0={intermediate_size // world_size})" + ) + elif "experts" in name and "w2" in name: + # w2 (down_proj) row-sharded: packed k dim = intermediate_size // world_size // 2 + expected_k_packed = intermediate_size // world_size // 2 + assert param.shape[1] == expected_k_packed, ( + f"w2 {name} shape {param.shape} not TP-sharded " + f"(expected packed dim1={expected_k_packed})" + ) + elif "experts" in name and "w3" in name: + # w3 (gate_proj) column-sharded: intermediate_size // world_size rows + assert param.shape[0] == intermediate_size 
// world_size, ( + f"w3 {name} shape {param.shape} not TP-sharded " + f"(expected dim0={intermediate_size // world_size})" + ) + + +def _run_moe_tp_shard_job( + num_experts: int, + _rank: int, + world_size: int, +) -> None: + """Run BF16 MoE TP sharding test.""" + device = "cuda" + hidden_size = 32 + intermediate_size = 16 + model = MoEOpModel( + hidden_size=hidden_size, + num_experts=num_experts, + intermediate_size=intermediate_size, + ).to(device=device, dtype=torch.bfloat16) + x = model.get_input(device=device, dtype=torch.bfloat16) + + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "detect_sharding": { + "stage": "sharding", + "sharding_dims": ["ep"], + "dist_mapping": {"moe_tp": world_size, "moe_ep": 1}, + }, + "sharding_transform_executor": { + "stage": "sharding", + }, + }, + )(None, gm) + + # Verify: TP sharding should insert all_reduce after MoE node + allreduce_correct = any( + is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce) for n in gm_transformed.graph.nodes + ) == (world_size > 1) + assert allreduce_correct, ( + f"Expected all_reduce present={world_size > 1} after MoE TP sharding, " + f"world_size={world_size}" + ) + + # Verify: expert weights should be sharded along TP dimension + if world_size > 1: + for name, param in gm_transformed.named_parameters(): + if "experts" in name and "w1" in name: + # w1 (up_proj) is column-sharded: intermediate_size // world_size + assert param.shape[0] == intermediate_size // world_size, ( + f"w1 {name} shape {param.shape} not TP-sharded " + f"(expected dim0={intermediate_size // world_size})" + ) + elif "experts" in name and "w2" in name: + # w2 (down_proj) is row-sharded: hidden_size x (intermediate_size // world_size) + assert param.shape[1] == intermediate_size // world_size, ( + f"w2 {name} shape {param.shape} not TP-sharded " + f"(expected dim1={intermediate_size // world_size})" + ) + elif "experts" in name and "w3" in name: + # w3 (gate_proj) is 
column-sharded: intermediate_size // world_size + assert param.shape[0] == intermediate_size // world_size, ( + f"w3 {name} shape {param.shape} not TP-sharded " + f"(expected dim0={intermediate_size // world_size})" + ) + + +@pytest.mark.parametrize("device_count", get_device_counts([2, 8])) +@pytest.mark.parametrize("num_experts", [4, 8]) +def test_moe_tp_shard_bf16(device_count: int, num_experts: int): + """Test MoE TP sharding with BF16 weights.""" + dist_common.spawn_multiprocess_job( + job=partial(_run_moe_tp_shard_job, num_experts), + size=device_count, + ) + + +@pytest.mark.skipif( + not (fp4_compatible() and trtllm_ops_available()), + reason="Requires NVFP4 support (SM100+) and TRTLLM ops", +) +@pytest.mark.parametrize("device_count", get_device_counts([2, 8])) +@pytest.mark.parametrize("num_experts", [4, 8]) +def test_moe_tp_shard_nvfp4(device_count: int, num_experts: int): + """Test MoE TP sharding with NVFP4 quantized weights (Qwen3.5-like).""" + dist_common.spawn_multiprocess_job( + job=partial(_run_nvfp4_moe_tp_shard_job, num_experts), + size=device_count, + ) diff --git a/tests/unittest/auto_deploy/singlegpu/models/test_qwen3_5_moe.py b/tests/unittest/auto_deploy/singlegpu/models/test_qwen3_5_moe.py index bfc57a4cc81..44ae368c817 100644 --- a/tests/unittest/auto_deploy/singlegpu/models/test_qwen3_5_moe.py +++ b/tests/unittest/auto_deploy/singlegpu/models/test_qwen3_5_moe.py @@ -33,6 +33,7 @@ import pytest import torch +import torch.nn as nn import torch.nn.functional as F from PIL import Image @@ -999,6 +1000,8 @@ def test_position_embeddings_passthrough(): torch.manual_seed(42) model = Qwen3_5MoeTextModel(config) model.eval() + lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + model.set_lm_head(lm_head) B, S = 2, 8 input_ids = torch.randint(0, config.vocab_size, (B, S)) @@ -1014,8 +1017,8 @@ def test_position_embeddings_passthrough(): output_external = model(input_ids=input_ids, position_embeddings=(cos, sin)) 
torch.testing.assert_close( - output_internal.last_hidden_state, - output_external.last_hidden_state, + output_internal.logits, + output_external.logits, rtol=1e-5, atol=1e-5, msg="External position_embeddings should produce identical output to internal computation", @@ -1029,6 +1032,8 @@ def test_rope_cos_sin_kwargs(): torch.manual_seed(42) model = Qwen3_5MoeTextModel(config) model.eval() + lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + model.set_lm_head(lm_head) B, S = 2, 8 input_ids = torch.randint(0, config.vocab_size, (B, S)) @@ -1046,8 +1051,8 @@ def test_rope_cos_sin_kwargs(): output_kwargs = model(input_ids=input_ids, rope_cos=cos, rope_sin=sin) torch.testing.assert_close( - output_ref.last_hidden_state, - output_kwargs.last_hidden_state, + output_ref.logits, + output_kwargs.logits, rtol=1e-5, atol=1e-5, msg="rope_cos/rope_sin kwargs should produce identical output to internal computation", From f8d2090e793a94f48a851344d4a1c18d396dbe48 Mon Sep 17 00:00:00 2001 From: Guoming Zhang <137257613+nv-guomingz@users.noreply.github.com> Date: Fri, 10 Apr 2026 02:16:22 +0800 Subject: [PATCH 2/4] =?UTF-8?q?[None][chore]=20Set=20the=20use=5Fone=5Fmod?= =?UTF-8?q?el=20flag=20to=20True=20by=20default=20on=20llm=20ap=E2=80=A6?= =?UTF-8?q?=20(#12836)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com> --- examples/llm-api/quickstart_advanced.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 55ffaf6f52b..076fec69717 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -181,7 +181,9 @@ def add_llm_args(parser): parser.add_argument('--spec_decode_max_draft_len', type=int, default=1) parser.add_argument('--draft_model_dir', type=str, default=None) 
parser.add_argument('--max_matching_ngram_size', type=int, default=5) - parser.add_argument('--use_one_model', default=False, action='store_true') + parser.add_argument('--use_one_model', + default=True, + action=argparse.BooleanOptionalAction) parser.add_argument('--eagle_choices', type=str, default=None) parser.add_argument('--use_dynamic_tree', default=False, From aed6b6aa88ca2033e4451c141a80ae144f826216 Mon Sep 17 00:00:00 2001 From: tcherckez-nvidia <127761168+tcherckez-nvidia@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:51:02 +0300 Subject: [PATCH 3/4] [https://nvbugs/5921674][fix] unwaive TestNemotronNanoV3 fp8 tests (#12792) Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com> --- tests/integration/test_lists/waives.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 201de58dc31..58238a7b174 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -249,7 +249,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugspro.nvidia.com/bug/5916155) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugspro.nvidia.com/bug/5916155) unittest/_torch/visual_gen/test_wan.py::TestWanTwoStageTransformer::test_two_stage_with_trtllm_attention SKIP (https://nvbugspro.nvidia.com/bug/5916830) -accuracy/test_llm_api_autodeploy.py::TestNemotronNanoV3::test_accuracy[fp8-1-trtllm] SKIP (https://nvbugs/5921674) full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True] SKIP 
(https://nvbugs/5929339) accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5940463) unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460) @@ -307,7 +306,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_piecewi accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5992113) accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[use_temperature=False-attn_backend=TRTLLM] SKIP (https://nvbugs/5997547) accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B_Instruct_Eagle3::test_eagle3_one_model SKIP (https://nvbugs/5997534) -accuracy/test_llm_api_autodeploy.py::TestNemotronNanoV3::test_accuracy[fp8-4-trtllm] SKIP (https://nvbugs/5997046) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_python_scheduler[ep4-mtp_nextn=0] SKIP (https://nvbugs/5997051) perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_v32_fp4_blackwell-v32_fp4_tep8_mtp3_8k1k] SKIP (https://nvbugs/5997092) accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8 SKIP (https://nvbugs/6004530) From 0de88a283141888a3e4b55bb03e862068c18eb1b Mon Sep 17 00:00:00 2001 From: Venky <23023424+venkywonka@users.noreply.github.com> Date: Thu, 9 Apr 2026 13:50:25 -0700 Subject: [PATCH 4/4] [None][feat] Add NvTelemetry/GXT-compliant usage telemetry (#12384) Signed-off-by: venkywonka <23023424+venkywonka@users.noreply.github.com> --- .github/tava_architecture_diagram.md | 18 + README.md | 35 + setup.py | 3 +- tensorrt_llm/bench/benchmark/low_latency.py | 2 + tensorrt_llm/bench/benchmark/throughput.py | 2 + tensorrt_llm/bench/build/build.py | 3 +- tensorrt_llm/bench/dataclasses/general.py | 1 + tensorrt_llm/commands/bench.py | 18 +- tensorrt_llm/commands/eval.py | 48 +- tensorrt_llm/commands/serve.py | 50 +- tensorrt_llm/llmapi/llm.py | 28 +- tensorrt_llm/llmapi/llm_args.py | 25 + 
tensorrt_llm/usage/__init__.py | 42 + tensorrt_llm/usage/config.py | 59 ++ tensorrt_llm/usage/schema.py | 276 +++++++ tensorrt_llm/usage/schemas/README.md | 174 +++++ tensorrt_llm/usage/schemas/__init__.py | 44 ++ tensorrt_llm/usage/schemas/__main__.py | 84 ++ .../schemas/trtllm_usage_event_schema.json | 166 ++++ tensorrt_llm/usage/usage_lib.py | 717 ++++++++++++++++++ tests/integration/defs/conftest.py | 2 + .../integration/test_lists/test-db/l0_a10.yml | 12 + .../test_lists/test-db/l0_h100.yml | 8 + .../api_stability/references/llm.yaml | 4 + tests/unittest/conftest.py | 2 + .../unittest/llmapi/test_features_contract.py | 155 ++++ tests/unittest/llmapi/test_llm_args.py | 138 +++- tests/unittest/llmapi/test_llm_telemetry.py | 409 ++++++++++ .../llmapi/test_llm_telemetry_payload.py | 193 +++++ tests/unittest/usage/__init__.py | 14 + tests/unittest/usage/conftest.py | 93 +++ tests/unittest/usage/test_collectors.py | 562 ++++++++++++++ tests/unittest/usage/test_config.py | 106 +++ tests/unittest/usage/test_e2e_capture.py | 261 +++++++ tests/unittest/usage/test_opt_out.py | 240 ++++++ tests/unittest/usage/test_reporter.py | 593 +++++++++++++++ tests/unittest/usage/test_schema.py | 714 +++++++++++++++++ tests/unittest/usage/test_transport.py | 283 +++++++ 38 files changed, 5552 insertions(+), 32 deletions(-) create mode 100644 tensorrt_llm/usage/__init__.py create mode 100644 tensorrt_llm/usage/config.py create mode 100644 tensorrt_llm/usage/schema.py create mode 100644 tensorrt_llm/usage/schemas/README.md create mode 100644 tensorrt_llm/usage/schemas/__init__.py create mode 100644 tensorrt_llm/usage/schemas/__main__.py create mode 100644 tensorrt_llm/usage/schemas/trtllm_usage_event_schema.json create mode 100644 tensorrt_llm/usage/usage_lib.py create mode 100644 tests/unittest/llmapi/test_features_contract.py create mode 100644 tests/unittest/llmapi/test_llm_telemetry.py create mode 100644 tests/unittest/llmapi/test_llm_telemetry_payload.py create mode 100644 
tests/unittest/usage/__init__.py create mode 100644 tests/unittest/usage/conftest.py create mode 100644 tests/unittest/usage/test_collectors.py create mode 100644 tests/unittest/usage/test_config.py create mode 100644 tests/unittest/usage/test_e2e_capture.py create mode 100644 tests/unittest/usage/test_opt_out.py create mode 100644 tests/unittest/usage/test_reporter.py create mode 100644 tests/unittest/usage/test_schema.py create mode 100644 tests/unittest/usage/test_transport.py diff --git a/.github/tava_architecture_diagram.md b/.github/tava_architecture_diagram.md index 2744d626a87..4eea834bb98 100644 --- a/.github/tava_architecture_diagram.md +++ b/.github/tava_architecture_diagram.md @@ -91,6 +91,16 @@ graph TB BatchManager --> KVCache end + subgraph "Usage_Telemetry" + ReportUsage[report_usage] + BgReporter[Background Reporter] + GxtPayload[GXT Payload Builder] + GxtEndpoint[NvTelemetry Endpoint] + ReportUsage --> BgReporter + BgReporter --> GxtPayload + GxtPayload --> GxtEndpoint + end + subgraph "Output_Results" Tokens[Generated Tokens] Stats[Performance Stats] @@ -99,6 +109,8 @@ graph TB GenVideos[Generated Videos] end + LLMAPI --> ReportUsage + PyTorch_Flow ~~~ TensorRT_Flow TensorRT_Flow --> Output_Results @@ -106,6 +118,8 @@ graph TB AutoDeploy_Flow --> Output_Results Visual_Gen_Flow --> Output_Results + AutoDeploy_Flow ~~~ Usage_Telemetry + %% Force Output_Results to be between PyTorch_flow and TensorRT_flow PyTorch_Flow ~~~ Output_Results @@ -141,6 +155,10 @@ graph TB classDef api fill:#bfb,stroke:#333,stroke-width:2px; class PythonAPI,CppAPI,LLMAPI api; + %% Telemetry format + classDef telemetry fill:#cef,stroke:#333,stroke-width:2px; + class ReportUsage,BgReporter,GxtPayload,GxtEndpoint telemetry; + %% Results format classDef result fill:#fbb,stroke:#333,stroke-width:2px; class Tokens,Stats,Metrics,GenImages,GenVideos result; diff --git a/README.md b/README.md index bb8124d6316..0c367356ed4 100644 --- a/README.md +++ b/README.md @@ -284,6 +284,41 @@ 
Deprecation is used to inform developers that some APIs and tools are no longer 4. Removal After Migration Period - After the 3-month migration period ends, deprecated APIs, tools, or parameters are removed in a manner consistent with semantic versioning (major version changes may include breaking removals). +## Telemetry Data Collection + +TensorRT-LLM collects anonymous telemetry data by default. This data is used +in aggregate to understand usage patterns and prioritize engineering efforts. +**This data cannot be traced back to any individual user.** No prompts, +user-identifying information, or persistent identifiers are collected. Any +deployment identifiers are ephemeral, randomly generated per deployment, and +not linked to users. The data we collect includes: + +- Ingress point (e.g., LLM API, CLI, serve command) +- Deployment duration (via periodic heartbeats) +- GPU SKUs, count, memory, and CUDA version +- Model architecture class name (e.g., `LlamaForCausalLM`) +- Parallelism configuration (TP/PP/CP/MoE-EP/MoE-TP sizes), quantization algorithm, dtype, KV cache dtype +- System information (OS platform, Python version, CPU architecture, CPU count) +- TRT-LLM version and backend +- Feature flags (LoRA, speculative decoding, prefix caching, CUDA graphs, chunked context, data parallelism) +- Disaggregated serving metadata (role and deployment ID) + +Telemetry is automatically disabled in CI and test environments. 
+ +### Opting Out of Telemetry Data Collection + +To disable telemetry data collection, use any of the following methods: + +- **Environment variable**: Set `TRTLLM_NO_USAGE_STATS=1`, `DO_NOT_TRACK=1`, or `TELEMETRY_DISABLED=true` +- **File-based**: Create the file `~/.config/trtllm/do_not_track` +- **Python API**: Pass `TelemetryConfig(disabled=True)` to `LLM()` +- **CLI flag**: Use `--no-telemetry` on `trtllm-serve`, `trtllm-bench`, or `trtllm-eval` + +The telemetry collection code is fully open source and auditable at +[`tensorrt_llm/usage/`](./tensorrt_llm/usage/). For a detailed field-by-field +reference of exactly what is collected, see the +[schema documentation](./tensorrt_llm/usage/schemas/README.md). + ## Useful Links - [Quantized models on Hugging Face](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4): A growing collection of quantized (e.g., FP8, FP4) and optimized LLMs, including [DeepSeek FP4](https://huggingface.co/nvidia/DeepSeek-R1-FP4), ready for fast inference with TensorRT LLM. - [NVIDIA Dynamo](https://github.com/ai-dynamo/dynamo): A datacenter scale distributed inference serving framework that works seamlessly with TensorRT LLM. 
diff --git a/setup.py b/setup.py index b26af764f83..394eb4ea852 100644 --- a/setup.py +++ b/setup.py @@ -170,7 +170,8 @@ def has_ext_modules(self): "_torch/auto_deploy/config/*.yaml", # Include CUDA source for fused MoE align extension so runtime JIT can find it in wheels '_torch/auto_deploy/custom_ops/fused_moe/moe_align_kernel.cu', - '_torch/auto_deploy/custom_ops/fused_moe/triton_fused_moe_configs/*' + '_torch/auto_deploy/custom_ops/fused_moe/triton_fused_moe_configs/*', + 'usage/schemas/*.json', ] diff --git a/tensorrt_llm/bench/benchmark/low_latency.py b/tensorrt_llm/bench/benchmark/low_latency.py index 2619f74d2fc..3743e914c6f 100644 --- a/tensorrt_llm/bench/benchmark/low_latency.py +++ b/tensorrt_llm/bench/benchmark/low_latency.py @@ -292,6 +292,8 @@ def latency_command( llm = None kwargs = kwargs | runtime_config.get_llm_args() kwargs['backend'] = options.backend + if bench_env.telemetry_config is not None: + kwargs["telemetry_config"] = bench_env.telemetry_config # Set environment variables for setting runtime options. 
default_env_overrides = { diff --git a/tensorrt_llm/bench/benchmark/throughput.py b/tensorrt_llm/bench/benchmark/throughput.py index 9660f6ba6e7..dfef40b106f 100755 --- a/tensorrt_llm/bench/benchmark/throughput.py +++ b/tensorrt_llm/bench/benchmark/throughput.py @@ -428,6 +428,8 @@ def throughput_command( kwargs = kwargs | runtime_config.get_llm_args() kwargs['skip_tokenizer_init'] = not no_skip_tokenizer_init kwargs['backend'] = options.backend + if bench_env.telemetry_config is not None: + kwargs["telemetry_config"] = bench_env.telemetry_config llm = get_llm(runtime_config, kwargs) diff --git a/tensorrt_llm/bench/build/build.py b/tensorrt_llm/bench/build/build.py index 10f21ba86c0..2c33f0e0f11 100644 --- a/tensorrt_llm/bench/build/build.py +++ b/tensorrt_llm/bench/build/build.py @@ -332,7 +332,8 @@ def build_command( quant_config=quant_config, workspace=str(bench_env.workspace), load_format=load_format, - trust_remote_code=trust_remote_code) + trust_remote_code=trust_remote_code, + telemetry_config=bench_env.telemetry_config) # Save the engine. 
llm.save(engine_dir) llm.shutdown() diff --git a/tensorrt_llm/bench/dataclasses/general.py b/tensorrt_llm/bench/dataclasses/general.py index 33b23836d17..ec509d62ffe 100644 --- a/tensorrt_llm/bench/dataclasses/general.py +++ b/tensorrt_llm/bench/dataclasses/general.py @@ -15,6 +15,7 @@ class BenchmarkEnvironment(BaseModel): checkpoint_path: Optional[Path] workspace: Path revision: Optional[str] = None + telemetry_config: Optional[Any] = None class InferenceRequest(BaseModel): diff --git a/tensorrt_llm/commands/bench.py b/tensorrt_llm/commands/bench.py index c11b8ad4567..269d4d1d3f2 100644 --- a/tensorrt_llm/commands/bench.py +++ b/tensorrt_llm/commands/bench.py @@ -10,6 +10,7 @@ from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment from tensorrt_llm.bench.dataset.prepare_dataset import prepare_dataset from tensorrt_llm.logger import logger, severity_map +from tensorrt_llm.usage import config as _telemetry_config class NotRequiredForHelp(click.Option): @@ -56,6 +57,9 @@ def handle_parse_result(self, ctx, opts, args): default=None, help="The revision to use for the HuggingFace model " "(branch name, tag name, or commit id).") +@click.option("--telemetry/--no-telemetry", + default=True, + help="Enable or disable anonymous usage telemetry collection.") @click.pass_context def main( ctx, @@ -64,15 +68,21 @@ def main( workspace: Path, log_level: str, revision: Optional[str], + telemetry: bool, ) -> None: logger.set_level(log_level) if model is None: return - ctx.obj = BenchmarkEnvironment(model=model, - checkpoint_path=model_path, - workspace=workspace, - revision=revision) + ctx.obj = BenchmarkEnvironment( + model=model, + checkpoint_path=model_path, + workspace=workspace, + revision=revision, + telemetry_config=_telemetry_config.TelemetryConfig( + disabled=not telemetry, + usage_context=_telemetry_config.UsageContext.CLI_BENCH), + ) # Create the workspace where we plan to store intermediate files. 
ctx.obj.workspace.mkdir(parents=True, exist_ok=True) diff --git a/tensorrt_llm/commands/eval.py b/tensorrt_llm/commands/eval.py index 8bfd8bab0f7..4bdec03443c 100644 --- a/tensorrt_llm/commands/eval.py +++ b/tensorrt_llm/commands/eval.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,6 +26,7 @@ from ..llmapi import BuildConfig, KvCacheConfig from ..llmapi.llm_utils import update_llm_args_with_extra_options from ..logger import logger, severity_map +from ..usage import config as _telemetry_config @click.group() @@ -117,6 +118,9 @@ is_flag=True, default=False, help="Flag for disabling KV cache reuse.") +@click.option("--telemetry/--no-telemetry", + default=True, + help="Enable or disable anonymous usage telemetry collection.") @click.pass_context def main(ctx, model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str], log_level: str, backend: str, @@ -124,7 +128,8 @@ def main(ctx, model: str, tokenizer: Optional[str], max_seq_len: int, tp_size: int, pp_size: int, ep_size: Optional[int], gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float, trust_remote_code: bool, revision: Optional[str], - extra_llm_api_options: Optional[str], disable_kv_cache_reuse: bool): + extra_llm_api_options: Optional[str], disable_kv_cache_reuse: bool, + telemetry: bool): logger.set_level(log_level) kv_cache_config = KvCacheConfig( @@ -132,16 +137,30 @@ def main(ctx, model: str, tokenizer: Optional[str], enable_block_reuse=not disable_kv_cache_reuse) llm_args = { - "model": model, - "tokenizer": tokenizer, - "custom_tokenizer": custom_tokenizer, - "tensor_parallel_size": tp_size, - "pipeline_parallel_size": pp_size, - "moe_expert_parallel_size": ep_size, - "gpus_per_node": 
gpus_per_node, - "trust_remote_code": trust_remote_code, - "revision": revision, - "kv_cache_config": kv_cache_config, + "model": + model, + "tokenizer": + tokenizer, + "custom_tokenizer": + custom_tokenizer, + "tensor_parallel_size": + tp_size, + "pipeline_parallel_size": + pp_size, + "moe_expert_parallel_size": + ep_size, + "gpus_per_node": + gpus_per_node, + "trust_remote_code": + trust_remote_code, + "revision": + revision, + "kv_cache_config": + kv_cache_config, + "telemetry_config": + _telemetry_config.TelemetryConfig( + disabled=not telemetry, + usage_context=_telemetry_config.UsageContext.CLI_EVAL), } if backend == 'pytorch': @@ -166,6 +185,11 @@ def main(ctx, model: str, tokenizer: Optional[str], llm_args = update_llm_args_with_extra_options(llm_args, extra_llm_api_options) + # CLI --no-telemetry always wins over YAML config + if not telemetry: + llm_args["telemetry_config"] = llm_args["telemetry_config"].model_copy( + update={"disabled": True}) + profiler.start("trtllm init") llm = llm_cls(**llm_args) profiler.stop("trtllm init") diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 48663599f76..78c2ab89d3c 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -7,6 +7,7 @@ import socket import subprocess # nosec B404 import sys +import uuid from pathlib import Path from typing import Any, Dict, Literal, Mapping, Optional, Sequence @@ -44,6 +45,7 @@ from tensorrt_llm.serve.tool_parser.tool_parser_factory import \ resolve_auto_tool_parser from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir +from tensorrt_llm.usage import config as _telemetry_config from tensorrt_llm.visual_gen import VisualGen # Global variable to store the Popen object of the child process @@ -128,6 +130,8 @@ def is_non_default_or_required(param_name, value, backend): default = field_info.default if callable(default): default = default() + elif field_info.default_factory is not None: + default = 
field_info.default_factory() return value != default @@ -161,6 +165,7 @@ def get_llm_args( enable_chunked_prefill: bool = False, enable_attention_dp: bool = False, video_pruning_rate: Optional[float] = None, + telemetry: bool = True, **llm_args_extra_dict: Any): if gpus_per_node is None: @@ -247,6 +252,10 @@ def get_llm_args( fail_fast_on_attention_window_too_large, "video_pruning_rate": video_pruning_rate, + "telemetry_config": + _telemetry_config.TelemetryConfig( + disabled=not telemetry, + usage_context=_telemetry_config.UsageContext.CLI_SERVE), } llm_args = { @@ -723,6 +732,9 @@ def convert(self, value: Any, param: Optional["click.Parameter"], help=help_info_with_stability_tag( "Target URL to which OpenTelemetry traces will be sent.", "prototype")) +@click.option("--telemetry/--no-telemetry", + default=True, + help="Enable or disable anonymous usage telemetry collection.") @click.option("--disagg_cluster_uri", type=str, default=None, @@ -794,8 +806,9 @@ def serve( otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool, enable_attention_dp: bool, disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str], video_pruning_rate: Optional[float], - custom_module_dirs: list[Path], chat_template: Optional[str], - grpc: bool, served_model_name: Optional[str], + telemetry: bool, custom_module_dirs: list[Path], + chat_template: Optional[str], grpc: bool, + served_model_name: Optional[str], extra_visual_gen_options: Optional[str]): """Running an OpenAI API compatible server @@ -881,7 +894,8 @@ def _serve_llm(): otlp_traces_endpoint=otlp_traces_endpoint, enable_chunked_prefill=enable_chunked_prefill, enable_attention_dp=enable_attention_dp, - video_pruning_rate=video_pruning_rate) + video_pruning_rate=video_pruning_rate, + telemetry=telemetry) llm_args_extra_dict = {} if extra_llm_api_options is not None: @@ -890,6 +904,11 @@ def _serve_llm(): llm_args = update_llm_args_with_extra_dict(llm_args, llm_args_extra_dict) + # CLI --no-telemetry always wins over 
YAML config + if not telemetry: + llm_args["telemetry_config"] = llm_args[ + "telemetry_config"].model_copy(update={"disabled": True}) + metadata_server_cfg = parse_metadata_server_config_file( metadata_server_config_file) @@ -1039,12 +1058,15 @@ def _serve_visual_gen(): type=str, default=None, help="Path to metadata server config file") +@click.option("--telemetry/--no-telemetry", + default=True, + help="Enable or disable anonymous usage telemetry collection.") def serve_encoder(model: str, host: str, port: int, log_level: str, max_batch_size: int, max_num_tokens: int, gpus_per_node: Optional[int], trust_remote_code: bool, extra_encoder_options: Optional[str], revision: Optional[str], free_gpu_memory_fraction: float, tensor_parallel_size: int, - metadata_server_config_file: Optional[str]): + metadata_server_config_file: Optional[str], telemetry: bool): """Running an OpenAI API compatible server MODEL: model name | HF checkpoint path | TensorRT engine path @@ -1063,7 +1085,8 @@ def serve_encoder(model: str, host: str, port: int, log_level: str, trust_remote_code=trust_remote_code, revision=revision, free_gpu_memory_fraction=free_gpu_memory_fraction, - tensor_parallel_size=tensor_parallel_size) + tensor_parallel_size=tensor_parallel_size, + telemetry=telemetry) encoder_args_extra_dict = {} if extra_encoder_options is not None: @@ -1072,6 +1095,11 @@ def serve_encoder(model: str, host: str, port: int, log_level: str, encoder_args = update_llm_args_with_extra_dict(llm_args, encoder_args_extra_dict) + # CLI --no-telemetry always wins over YAML config + if not telemetry: + encoder_args["telemetry_config"] = encoder_args[ + "telemetry_config"].model_copy(update={"disabled": True}) + metadata_server_cfg = parse_metadata_server_config_file( metadata_server_config_file) @@ -1143,6 +1171,11 @@ def disaggregated( disagg_cfg = parse_disagg_config_file(config_file) if schedule_style: disagg_cfg.schedule_style = schedule_style + + # Generate a shared deployment ID for all 
workers in this disagg deployment. + # Inherited by child processes via env var; used for deduplication at query time. + os.environ[DisaggLauncherEnvs.TLLM_DISAGG_DEPLOYMENT_ID] = uuid.uuid4().hex + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.bind((disagg_cfg.hostname, disagg_cfg.port)) @@ -1268,6 +1301,8 @@ def disaggregated_mpi_worker(config_file: Optional[str], log_level: str): class DisaggLauncherEnvs(StrEnum): TLLM_DISAGG_INSTANCE_IDX = "TLLM_DISAGG_INSTANCE_IDX" TLLM_DISAGG_RUN_REMOTE_MPI_SESSION_CLIENT = "TLLM_DISAGG_RUN_REMOTE_MPI_SESSION_CLIENT" + TLLM_DISAGG_DEPLOYMENT_ID = "TRTLLM_DISAGG_DEPLOYMENT_ID" + TLLM_DISAGG_ROLE = "TRTLLM_DISAGG_ROLE" def _launch_disaggregated_server(disagg_config_file: str, llm_args: dict): @@ -1277,6 +1312,11 @@ def _launch_disaggregated_server(disagg_config_file: str, llm_args: dict): disagg_config = parse_disagg_config_file(disagg_config_file) server_cfg = disagg_config.server_configs[int(instance_idx)] + # Set disagg role for telemetry (server_cfg.type is 'ctx' or 'gen') + role_map = {"ctx": "context", "gen": "generation"} + os.environ[DisaggLauncherEnvs.TLLM_DISAGG_ROLE] = role_map.get( + server_cfg.type, "") + logger.info( f"rank {mpi_rank()} for index {instance_idx} launch the disagg server") diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 4f207cf4403..49fd606a136 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -256,6 +256,25 @@ def __init__(self, self.mpi_session.shutdown() raise + # --- Usage telemetry (fail-silent) --- + try: + import tensorrt_llm.usage as _usage + telemetry_config = getattr(self.args, 'telemetry_config', None) + # Promote UNKNOWN -> LLM_CLASS for direct Python API usage. + # CLI commands set their specific context before LLM construction, + # so this only fires for users calling LLM() directly. 
+ if telemetry_config is not None: + if telemetry_config.usage_context == _usage.UsageContext.UNKNOWN: + telemetry_config = telemetry_config.model_copy( + update={"usage_context": _usage.UsageContext.LLM_CLASS}) + _usage.report_usage( + llm_args=self.args, + pretrained_config=self._hf_model_config, + telemetry_config=telemetry_config, + ) + except Exception as exc: + logger.debug("Usage telemetry setup failed: %s", exc) + try: if self.args.otlp_traces_endpoint: tracing.init_tracer("trt.llm", self.args.otlp_traces_endpoint) @@ -1064,7 +1083,14 @@ def _build_model(self): # Tokenizer and config loading should be after calling model_loader(), since model_loader() may download the model from HF hub. # It should also be before bindings ExecutorConfig, which may depend on tokenizer info. self._tokenizer = self._try_load_tokenizer() - self._hf_model_config = self._try_load_hf_model_config() + # Load HF config from the original HF model dir when available, + # since self.args.model now points to the engine dir (whose + # config.json uses TRT-LLM schema, not HF schema). 
+ if self._hf_model_dir is not None: + self._hf_model_config = ModelLoader.load_hf_model_config( + self._hf_model_dir) + else: + self._hf_model_config = self._try_load_hf_model_config() self._generation_config = self._try_load_generation_config() # Multimodal special handling: diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 9f001b4e5ae..30cccbc88f9 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -59,6 +59,7 @@ from ..models.modeling_utils import (PretrainedConfig, QuantAlgo, QuantConfig, SpeculativeDecodingMode) from ..sampling_params import BatchedLogitsProcessor +from ..usage.config import TelemetryConfig, UsageContext # noqa: F401 from .build_cache import BuildCacheConfig from .tokenizer import TokenizerBase, tokenizer_factory from .utils import (StrictBaseModel, generate_api_docs_as_docstring, @@ -2787,6 +2788,11 @@ class BaseLlmArgs(StrictBaseModel): "[EXPERIMENTAL] Environment variable overrides. NOTE: import-time-cached env vars in the code won't update unless the code fetches them from os.environ on demand.", status="prototype") + telemetry_config: TelemetryConfig = Field( + default_factory=TelemetryConfig, + description="Telemetry configuration (opt-out, usage context).", + status="prototype") + @field_validator('env_overrides', mode='before') @classmethod def coerce_env_overrides_to_str(cls, v): @@ -3875,6 +3881,24 @@ def update_llm_args_with_extra_dict( llm_args_dict['kv_cache_config'] = base_kv_config | llm_args_dict[ 'kv_cache_config'] + # Deep merge telemetry_config: YAML can override fields like `disabled`, + # but `usage_context` is determined by the CLI entry point and must not + # be overridden by user config. + if 'telemetry_config' in llm_args and 'telemetry_config' in llm_args_dict: + yaml_tc = llm_args_dict['telemetry_config'] + if not isinstance(yaml_tc, (dict, TelemetryConfig)): + # YAML value is null / false / etc. 
— drop it so the CLI default + # is preserved by the field_mapping coercion step below. + del llm_args_dict['telemetry_config'] + else: + base_tc = llm_args['telemetry_config'] + if isinstance(base_tc, TelemetryConfig): + base_tc = base_tc.model_dump(exclude_unset=True) + if isinstance(yaml_tc, TelemetryConfig): + yaml_tc = yaml_tc.model_dump(exclude_unset=True) + yaml_tc.pop('usage_context', None) + llm_args_dict['telemetry_config'] = base_tc | yaml_tc + field_mapping = { "quant_config": QuantConfig, "calib_config": CalibConfig, @@ -3887,6 +3911,7 @@ def update_llm_args_with_extra_dict( "attention_dp_config": AttentionDpConfig, "kv_cache_config": KvCacheConfig, "dwdp_config": DwdpConfig, + "telemetry_config": TelemetryConfig, } for field_name, field_type in field_mapping.items(): if field_name in llm_args_dict: diff --git a/tensorrt_llm/usage/__init__.py b/tensorrt_llm/usage/__init__.py new file mode 100644 index 00000000000..5a753234ff4 --- /dev/null +++ b/tensorrt_llm/usage/__init__.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TRT-LLM Usage Telemetry. + +Collects anonymous usage statistics to help improve TensorRT-LLM. +Data is sent to NVIDIA's telemetry service (NvTelemetry/GXT). 
+ +Opt-out: + - Set environment variable TRTLLM_NO_USAGE_STATS=1 + - Set environment variable TELEMETRY_DISABLED=true or TELEMETRY_DISABLED=1 + - Set environment variable DO_NOT_TRACK=1 + - Create file ~/.config/trtllm/do_not_track + - Pass TelemetryConfig(disabled=True) to LLM() or --telemetry-disabled via CLI + - Automatically disabled in CI/test environments (override with TRTLLM_USAGE_FORCE_ENABLED=1) +""" + +from tensorrt_llm.usage import config as _config +from tensorrt_llm.usage import usage_lib as _usage_lib + +TelemetryConfig = _config.TelemetryConfig +UsageContext = _config.UsageContext +report_usage = _usage_lib.report_usage +is_usage_stats_enabled = _usage_lib.is_usage_stats_enabled + +__all__ = [ + "TelemetryConfig", + "UsageContext", + "report_usage", + "is_usage_stats_enabled", +] diff --git a/tensorrt_llm/usage/config.py b/tensorrt_llm/usage/config.py new file mode 100644 index 00000000000..f08359f1ed6 --- /dev/null +++ b/tensorrt_llm/usage/config.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Telemetry configuration types. + +Canonical location for TelemetryConfig and UsageContext. These are defined +here (in the usage package) rather than in llm_args.py so that the +dependency arrow points correctly: llm_args imports from usage, not +vice versa. 
+ +Imported by tensorrt_llm.llmapi.llm_args for use in BaseLlmArgs. +""" + +from enum import Enum + +from pydantic import Field + +from tensorrt_llm.llmapi.utils import StrictBaseModel + + +class UsageContext(str, Enum): + """Identifies how TRT-LLM was invoked for telemetry tracking.""" + + UNKNOWN = "unknown" + LLM_CLASS = "llm_class" + CLI_SERVE = "cli_serve" + CLI_BENCH = "cli_bench" + CLI_EVAL = "cli_eval" + + +class TelemetryConfig(StrictBaseModel): + """Telemetry configuration for usage data collection. + + Controls opt-out behavior and tracks which entry point invoked TRT-LLM. + """ + + disabled: bool = Field( + default=False, + description="Disable anonymous usage telemetry collection. " + "Can also be set via TRTLLM_NO_USAGE_STATS=1, TELEMETRY_DISABLED=true, " + "DO_NOT_TRACK=1, or file ~/.config/trtllm/do_not_track.", + ) + usage_context: UsageContext = Field( + default=UsageContext.UNKNOWN, + description="Identifies how TRT-LLM was invoked (CLI command vs Python API). " + "Set automatically by CLI commands; defaults to UNKNOWN (promoted to " + "LLM_CLASS by BaseLLM.__init__ for direct Python API usage).", + ) diff --git a/tensorrt_llm/usage/schema.py b/tensorrt_llm/usage/schema.py new file mode 100644 index 00000000000..be9f9127cc7 --- /dev/null +++ b/tensorrt_llm/usage/schema.py @@ -0,0 +1,276 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Pydantic models for GXT Event Protocol v1.6 telemetry payloads. + +These models define the wire format for TRT-LLM usage telemetry sent to +the NvTelemetry/GXT endpoint. The envelope follows the GXT Event Protocol v1.6 +specification; the event parameters are TRT-LLM-specific. + +Reference: +- GXT API Design: GX Telemetry API Design.md +- DataDesigner reference: DataDesigner telemetry.py +""" + +import platform +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, Field + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_UINT32_MAX = 4_294_967_295 # 2**32 - 1; NvTelemetry "PositiveInt" +_SHORT_STR = 128 # NvTelemetry "ShortString" maxLength +_LONG_STR = 256 # NvTelemetry "LongString" maxLength + +CLIENT_ID = "616561816355034" +EVENT_PROTOCOL = "1.6" +EVENT_SCHEMA_VER = "0.1" +EVENT_SYS_VER = "trtllm-telemetry/1.0" +CLIENT_TYPE = "Native" +CLIENT_VARIANT = "Release" +CPU_ARCHITECTURE = platform.uname().machine + + +# --------------------------------------------------------------------------- +# TRT-LLM Event Parameters (inner payloads) +# --------------------------------------------------------------------------- + + +class TrtllmInitialReport(BaseModel): + """TRT-LLM initial report event parameters. + + Sent once at startup with full environment and configuration details. + All fields are required by the SMS schema (GXT convention: every declared + property must be in ``required``). Fields use sentinel defaults (empty + string for strings, 0 for ints) when the actual value is unavailable. 
+ + Field constraints match the SMS JSON schema type definitions: + - ShortString: maxLength=128 + - LongString: maxLength=256 + - PositiveInt: ge=0, le=_UINT32_MAX (2**32 - 1) + """ + + # TRT-LLM version (ShortString) + trtllm_version: str = Field(default="", max_length=_SHORT_STR, alias="trtllmVersion") + + # System info + platform_info: str = Field(default="", max_length=_LONG_STR, alias="platform") # LongString + python_version: str = Field( + default="", max_length=_SHORT_STR, alias="pythonVersion" + ) # ShortString + cpu_architecture: str = Field( + default="", max_length=_SHORT_STR, alias="cpuArchitecture" + ) # ShortString + cpu_count: int = Field(default=0, ge=0, le=_UINT32_MAX, alias="cpuCount") # PositiveInt + + # GPU info + gpu_count: int = Field(default=0, ge=0, le=_UINT32_MAX, alias="gpuCount") # PositiveInt + gpu_name: str = Field(default="", max_length=_LONG_STR, alias="gpuName") # LongString + gpu_memory_mb: int = Field(default=0, ge=0, le=_UINT32_MAX, alias="gpuMemoryMB") # PositiveInt + cuda_version: str = Field(default="", max_length=_SHORT_STR, alias="cudaVersion") # ShortString + + # Model info (architecture class name only -- no raw config) (LongString) + architecture_class_name: str = Field( + default="", max_length=_LONG_STR, alias="architectureClassName" + ) + + # TRT-LLM config + backend: str = Field(default="", max_length=_SHORT_STR, alias="backend") # ShortString + tensor_parallel_size: int = Field(default=1, ge=0, le=_UINT32_MAX, alias="tensorParallelSize") + pipeline_parallel_size: int = Field( + default=1, ge=0, le=_UINT32_MAX, alias="pipelineParallelSize" + ) + context_parallel_size: int = Field(default=1, ge=0, le=_UINT32_MAX, alias="contextParallelSize") + moe_expert_parallel_size: int = Field( + default=0, ge=0, le=_UINT32_MAX, alias="moeExpertParallelSize" + ) + moe_tensor_parallel_size: int = Field( + default=0, ge=0, le=_UINT32_MAX, alias="moeTensorParallelSize" + ) + dtype: str = Field(default="", max_length=_SHORT_STR, 
alias="dtype") # ShortString + quantization_algo: str = Field( + default="", max_length=_SHORT_STR, alias="quantizationAlgo" + ) # ShortString + kv_cache_dtype: str = Field( + default="", max_length=_SHORT_STR, alias="kvCacheDtype" + ) # ShortString + + # Ingress point (how TRT-LLM was invoked) (ShortString) + ingress_point: str = Field(default="", max_length=_SHORT_STR, alias="ingressPoint") + + # Disaggregated serving metadata (ShortString) + disagg_role: str = Field(default="", max_length=_SHORT_STR, alias="disaggRole") + deployment_id: str = Field(default="", max_length=_SHORT_STR, alias="deploymentId") + + # Feature flags (JSON-serialized dict of enabled features) + features_json: str = Field(default="{}", alias="featuresJson") + + model_config = {"populate_by_name": True} + + +class TrtllmHeartbeat(BaseModel): + """TRT-LLM heartbeat event parameters. + + Sent periodically to signal the session is still alive. + Contains only a monotonically increasing sequence counter. + """ + + seq: int = Field(..., ge=0, le=_UINT32_MAX, alias="seq") # PositiveInt + + model_config = {"populate_by_name": True} + + +# --------------------------------------------------------------------------- +# GXT Event Wrapper (single event in the events array) +# --------------------------------------------------------------------------- + + +class GxtEvent(BaseModel): + """A single event entry in the GXT events array.""" + + ts: str = Field(..., description="ISO 8601 timestamp") + name: str = Field(default="trtllm_usage_event") + parameters: Dict[str, Any] = Field(...) + + +# --------------------------------------------------------------------------- +# GXT Envelope (top-level payload) +# --------------------------------------------------------------------------- + + +class GxtPayload(BaseModel): + """GXT Event Protocol v1.6 envelope. 
+ + The GXT ingestion endpoint validates a fixed envelope schema and rejects + payloads that omit any of the required top-level fields — even fields + that are semantically irrelevant for this client. + + GXT was designed for consumer products (GeForce Experience, NVIDIA App) + where device fingerprints, user IDs, and GDPR consent flags carry real + values. TRT-LLM is a server-side SDK with no browser, device, or login, + so those fields are hardcoded to sentinel values: + + - Privacy/identity fields → "undefined" + Signals to downstream pipelines and auditors that these fields were + *deliberately not collected*, not accidentally missing. + + - GDPR opt-in fields → "None" (the string, not Python None) + In GXT's data-policy framework this means "no consent model applies". + + This pattern matches the DataDesigner reference implementation (another + internal GXT client for server-side telemetry). + """ + + # Client identification + client_id: str = Field(default=CLIENT_ID, alias="clientId") + client_type: str = Field(default=CLIENT_TYPE, alias="clientType") + client_variant: str = Field(default=CLIENT_VARIANT, alias="clientVariant") + client_ver: str = Field(..., alias="clientVer") + cpu_architecture: str = Field(default=CPU_ARCHITECTURE, alias="cpuArchitecture") + + # Protocol metadata + event_protocol: str = Field(default=EVENT_PROTOCOL, alias="eventProtocol") + event_schema_ver: str = Field(default=EVENT_SCHEMA_VER, alias="eventSchemaVer") + event_sys_ver: str = Field(default=EVENT_SYS_VER, alias="eventSysVer") + + # Session + sent_ts: str = Field(..., alias="sentTs") + session_id: str = Field(..., alias="sessionId") + + # Required by GXT schema but unused by server-side SDK clients. + # "undefined" = deliberately not collected (not accidentally missing). 
+ browser_type: str = Field(default="undefined", alias="browserType") + device_id: str = Field(default="undefined", alias="deviceId") + device_make: str = Field(default="undefined", alias="deviceMake") + device_model: str = Field(default="undefined", alias="deviceModel") + device_os: str = Field(default="undefined", alias="deviceOS") + device_os_version: str = Field(default="undefined", alias="deviceOSVersion") + device_type: str = Field(default="undefined", alias="deviceType") + user_id: str = Field(default="undefined", alias="userId") + external_user_id: str = Field(default="undefined", alias="externalUserId") + idp_id: str = Field(default="undefined", alias="idpId") + integration_id: str = Field(default="undefined", alias="integrationId") + product_name: str = Field(default="undefined", alias="productName") + product_version: str = Field(default="undefined", alias="productVersion") + + # Required by GXT schema but no consent model applies for opt-out + # server-side telemetry. "None" (string) = no GDPR consent tracking. + gdpr_beh_opt_in: str = Field(default="None", alias="gdprBehOptIn") + gdpr_func_opt_in: str = Field(default="None", alias="gdprFuncOptIn") + gdpr_tech_opt_in: str = Field(default="None", alias="gdprTechOptIn") + device_gdpr_beh_opt_in: str = Field(default="None", alias="deviceGdprBehOptIn") + device_gdpr_func_opt_in: str = Field(default="None", alias="deviceGdprFuncOptIn") + device_gdpr_tech_opt_in: str = Field(default="None", alias="deviceGdprTechOptIn") + + # Events array + events: List[GxtEvent] = Field(...) 
+ + model_config = {"populate_by_name": True} + + +# --------------------------------------------------------------------------- +# Helper: build a ready-to-serialize payload +# --------------------------------------------------------------------------- + + +def get_iso_timestamp(dt: Optional[datetime] = None) -> str: + """Return ISO 8601 timestamp with millisecond precision.""" + if dt is None: + dt = datetime.now(timezone.utc) + return dt.strftime("%Y-%m-%dT%H:%M:%S.") + f"{dt.microsecond // 1000:03d}Z" + + +def build_gxt_payload( + event: Union[TrtllmInitialReport, TrtllmHeartbeat], + *, + session_id: str, + trtllm_version: str, +) -> dict: + """Build a complete GXT payload dict ready for json.dumps(). + + Args: + event: The TRT-LLM event to send (initial report or heartbeat). + session_id: Ephemeral session UUID (hex string). + trtllm_version: TRT-LLM package version string. + + Returns: + dict suitable for json.dumps() and HTTP POST body. + """ + now = get_iso_timestamp() + + if isinstance(event, TrtllmInitialReport): + event_name = "trtllm_initial_report" + elif isinstance(event, TrtllmHeartbeat): + event_name = "trtllm_heartbeat" + else: + raise TypeError(f"Unknown event type: {type(event).__name__}") + + gxt_event = GxtEvent( + ts=now, + name=event_name, + parameters=event.model_dump(by_alias=True), + ) + + payload = GxtPayload( + client_ver=trtllm_version, + sent_ts=now, + session_id=session_id, + events=[gxt_event], + ) + + return payload.model_dump(by_alias=True) diff --git a/tensorrt_llm/usage/schemas/README.md b/tensorrt_llm/usage/schemas/README.md new file mode 100644 index 00000000000..8bc49551fe3 --- /dev/null +++ b/tensorrt_llm/usage/schemas/README.md @@ -0,0 +1,174 @@ +# TRT-LLM Telemetry Schema Reference + +Schema version: **0.1** | Client ID: `616561816355034` | Protocol: GXT Event Protocol v1.6 + +## Overview + +TRT-LLM collects anonymous, session-level deployment telemetry to understand +how the library is used in production (GPU types, 
parallelism configs, model
+architectures). No PII, model weights, prompts, or outputs are collected.
+
+**Opt-out** (any one of these disables telemetry):
+- `TRTLLM_NO_USAGE_STATS=1`
+- `TELEMETRY_DISABLED=true`
+- `DO_NOT_TRACK=1`
+- Create file `~/.config/trtllm/do_not_track`
+- `TelemetryConfig(disabled=True)` in code
+
+**Auto-disabled** in CI/test environments (detects `CI`, `GITHUB_ACTIONS`,
+`JENKINS_URL`, `GITLAB_CI`, `PYTEST_CURRENT_TEST`, etc.). Override with
+`TRTLLM_USAGE_FORCE_ENABLED=1` for staging deployments.
+
+## GXT Envelope
+
+Every payload is wrapped in a GXT v1.6 envelope. Dashboard builders will see
+these top-level fields in Kibana alongside the event parameters.
+
+| Field | Type | Description |
+|-------|------|-------------|
+| `clientId` | string | Always `"616561816355034"`. Identifies TRT-LLM in the GXT system. |
+| `clientType` | string | Always `"Native"`. |
+| `clientVer` | string | TRT-LLM version, e.g. `"1.3.0rc9"`. |
+| `eventProtocol` | string | Always `"1.6"`. |
+| `eventSchemaVer` | string | Schema version, currently `"0.1"`. |
+| `eventSysVer` | string | Always `"trtllm-telemetry/1.0"`. |
+| `sessionId` | string | Unique hex UUID per server lifetime. Use this to correlate initial report with heartbeats. |
+| `sentTs` | string | ISO 8601 UTC timestamp of when the payload was sent. |
+
+Privacy/identity fields (`browserType`, `deviceId`, `userId`, etc.) are
+hardcoded to `"undefined"` — TRT-LLM is a server-side SDK with no browser or
+login context.
+
+## Events
+
+### `trtllm_initial_report`
+
+Sent once at server startup. Contains system info and serving configuration.
+
+#### System fields
+
+| Field | Type | Description | Example |
+|-------|------|-------------|---------|
+| `trtllmVersion` | ShortString | TRT-LLM package version. | `"1.3.0rc9"` |
+| `platform` | LongString | OS platform string. | `"Linux-5.15.0-88-generic-x86_64"` |
+| `pythonVersion` | ShortString | Python version.
| `"3.12.3"` | +| `cpuArchitecture` | ShortString | CPU architecture. | `"x86_64"`, `"aarch64"` | +| `cpuCount` | PositiveInt | Number of logical CPUs (from `os.cpu_count()`). | `128` | + +#### GPU fields + +| Field | Type | Description | Example | +|-------|------|-------------|---------| +| `gpuCount` | PositiveInt | Number of GPUs **visible to the process** (`torch.cuda.device_count()`). Reflects `CUDA_VISIBLE_DEVICES`, not total system GPUs. | `8` | +| `gpuName` | LongString | Name of GPU 0. | `"NVIDIA H100 80GB HBM3"` | +| `gpuMemoryMB` | PositiveInt | Total memory of GPU 0 in MB. | `81559` | +| `cudaVersion` | ShortString | CUDA toolkit version. | `"12.4"` | + +#### Parallelism fields + +| Field | Type | Description | Example | +|-------|------|-------------|---------| +| `tensorParallelSize` | PositiveInt | Tensor parallelism degree. | `8` | +| `pipelineParallelSize` | PositiveInt | Pipeline parallelism degree. | `1` | +| `contextParallelSize` | PositiveInt | Context parallelism degree. | `1` | +| `moeExpertParallelSize` | PositiveInt | MoE expert parallelism. **`0` = auto/unset** (runtime decides). Positive value = explicitly configured. | `0`, `8` | +| `moeTensorParallelSize` | PositiveInt | MoE tensor parallelism. **`0` = auto/unset** (runtime decides). Positive value = explicitly configured. | `0`, `2` | + +#### Model & config fields + +| Field | Type | Description | Example | +|-------|------|-------------|---------| +| `architectureClassName` | LongString | HuggingFace model architecture class. | `"MixtralForCausalLM"`, `"LlamaForCausalLM"` | +| `backend` | ShortString | Execution backend. | `"pytorch"`, `"tensorrt"` | +| `dtype` | ShortString | Model data type. | `"float16"`, `"bfloat16"`, `"auto"` | +| `quantizationAlgo` | ShortString | Quantization algorithm. Empty string if none. | `""`, `"fp8"`, `"w4a16_awq"` | +| `kvCacheDtype` | ShortString | KV cache data type. Empty string if default. 
| `""`, `"fp8"`, `"auto"` | + +#### Serving context fields + +| Field | Type | Description | Example | +|-------|------|-------------|---------| +| `ingressPoint` | ShortString | How TRT-LLM was invoked. See [Ingress point values](#ingress-point-values). | `"cli_serve"` | +| `featuresJson` | string | JSON-serialized dict of feature flags. See [featuresJson keys](#featuresjson-keys). | `'{"lora":false,...}'` | +| `disaggRole` | ShortString | Disaggregated serving role. Empty if not disaggregated. | `""`, `"context"`, `"generation"` | +| `deploymentId` | ShortString | Shared ID across disaggregated workers. Empty if not disaggregated. | `""`, `"dep-abc123"` | + +### `trtllm_heartbeat` + +Sent periodically (default: every 600s) to track session duration. Up to 1000 +heartbeats per session. + +| Field | Type | Description | Example | +|-------|------|-------------|---------| +| `seq` | PositiveInt | Zero-based heartbeat sequence number. | `0`, `1`, `42` | + +## Type Reference + +| Type | JSON type | Constraints | +|------|-----------|-------------| +| ShortString | string | 0–128 characters | +| LongString | string | 0–256 characters | +| PositiveInt | integer | 0–4,294,967,295 | + +## Ingress Point Values + +The `ingressPoint` field identifies which TRT-LLM entry point started the session. + +| Value | Meaning | +|-------|---------| +| `"cli_serve"` | Started via `trtllm-serve` CLI | +| `"cli_bench"` | Started via `trtllm-bench` CLI | +| `"cli_eval"` | Started via evaluation CLI | +| `"llm_class"` | Started via `LLM()` Python API directly | +| `"unknown"` | Entry point not identified | + +## `featuresJson` Keys + +The `featuresJson` field is a JSON-serialized dict. All keys are always present +with safe defaults. This list may evolve as features are added. + +| Key | Type | Default | Description | +|-----|------|---------|-------------| +| `lora` | bool | `false` | LoRA adapter enabled (`enable_lora=True` or `lora_config` provided). 
| +| `speculative_decoding` | bool | `false` | Speculative decoding enabled (`speculative_config` is not None). Covers MTP, EAGLE, Medusa, etc. | +| `prefix_caching` | bool | `false` | KV cache block reuse / prefix caching enabled. | +| `cuda_graphs` | bool | `false` | CUDA graphs enabled for reduced launch overhead. | +| `chunked_context` | bool | `false` | Chunked prefill enabled (`enable_chunked_prefill=True`). | +| `data_parallel_size` | int | `1` | Data parallel degree. `1` = no data parallelism. Derived from `tp_size` when attention DP is enabled. | + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `TRTLLM_NO_USAGE_STATS` | unset | Set to `1` to disable telemetry. | +| `TELEMETRY_DISABLED` | unset | Set to `true` to disable telemetry. | +| `DO_NOT_TRACK` | unset | Set to `1` to disable telemetry. | +| `TRTLLM_USAGE_STATS_SERVER` | `https://events.gfe.nvidia.com/v1.1/events/json` | Override the GXT endpoint URL. Use for staging. | +| `TRTLLM_USAGE_HEARTBEAT_INTERVAL` | `600` | Heartbeat interval in seconds. | +| `TRTLLM_USAGE_FORCE_ENABLED` | `0` | Set to `1` to force-enable telemetry in CI/test environments. | +| `TRTLLM_DISAGG_ROLE` | unset | Disaggregated serving role (`context` or `generation`). | +| `TRTLLM_DISAGG_DEPLOYMENT_ID` | unset | Shared deployment ID across disaggregated workers. | + +## For Developers: Adding a New Field + +Checklist for adding a telemetry field: + +1. **`tensorrt_llm/usage/schema.py`** — Add field to `TrtllmInitialReport` (or `TrtllmHeartbeat`) Pydantic model with alias. +2. **`tensorrt_llm/usage/schemas/trtllm_usage_event_schema.json`** — Add to `properties` and `required` array. +3. **`tensorrt_llm/usage/usage_lib.py`** — Populate the field in `_background_reporter()` and add extraction logic in `_extract_trtllm_config()` or `_collect_gpu_info()` as appropriate. +4. **`tests/unittest/usage/test_schema.py`** — Update test fixtures and expected field sets. +5. 
**`tests/unittest/usage/test_collectors.py`** — Add extraction test. +6. **`tests/unittest/usage/test_e2e_capture.py`** — Update e2e payload assertions if needed. +7. **SMS schema upload** — Upload the updated JSON schema to the NvTelemetry Schema Management Service and toggle "on stage" / "on prod". +8. **Update this README** — Add the field to the appropriate table above. + +### Conventions + +- Use **camelCase** aliases for JSON wire format (Pydantic `alias=`). +- Use **snake_case** for Python field names. +- String fields: use `ShortString` (128 chars) or `LongString` (256 chars). +- Integer fields: use `PositiveInt` (0–4B). Use `0` for "auto/unset" semantics. +- All fields must be **required** in the JSON schema (no optional fields). +- Empty string `""` is the sentinel for "not applicable" string fields. +- The telemetry code is **fail-silent** — exceptions are caught and swallowed. +- No PII. No model weights. No prompts. No outputs. Architecture class names only. diff --git a/tensorrt_llm/usage/schemas/__init__.py b/tensorrt_llm/usage/schemas/__init__.py new file mode 100644 index 00000000000..2c439deb076 --- /dev/null +++ b/tensorrt_llm/usage/schemas/__init__.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Ground-truth SMS Event Definition Schema for TRT-LLM usage telemetry. 
+ +A single JSON Schema file (``trtllm_usage_event_schema.json``) captures the +SMS Event Definition -- the event structure registered with the NvTelemetry +Data Platform. This is the **ground truth** for the telemetry schema; the +Pydantic models in ``tensorrt_llm.usage.schema`` must stay in sync with it. + +This file is hand-written (not auto-generated from Pydantic models) and +checked into version control as the canonical reference for: + +1. **Drift detection** -- CI tests assert the Pydantic model fields match + the SMS schema properties. A field change in ``schema.py`` without + updating the SMS schema (or vice versa) causes a test failure. + +2. **NvTelemetry cross-reference** -- The SMS schema is the format expected + by the NvTelemetry / Data Platform team for event registration. + +3. **Auditability** -- Legal and privacy reviewers can inspect the schema + without reading Python. + +Validate after changing ``tensorrt_llm.usage.schema``:: + + pytest tests/unittest/usage/test_schema.py -x -q +""" + +from pathlib import Path + +SCHEMAS_DIR = Path(__file__).parent + +SMS_SCHEMA_PATH = SCHEMAS_DIR / "trtllm_usage_event_schema.json" diff --git a/tensorrt_llm/usage/schemas/__main__.py b/tensorrt_llm/usage/schemas/__main__.py new file mode 100644 index 00000000000..3cacdcdc00b --- /dev/null +++ b/tensorrt_llm/usage/schemas/__main__.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Validate that the Pydantic models stay in sync with the SMS JSON schema. + +Usage:: + + python -m tensorrt_llm.usage.schemas +""" + +import json +import sys +from typing import List + +from tensorrt_llm.usage import schema +from tensorrt_llm.usage.schemas import SMS_SCHEMA_PATH + + +def validate() -> List[str]: + """Check Pydantic model fields match the SMS JSON schema properties. + + Returns a list of human-readable error strings (empty = all good). + """ + errors: List[str] = [] + + if not SMS_SCHEMA_PATH.exists(): + errors.append(f"SMS schema file not found: {SMS_SCHEMA_PATH}") + return errors + + sms = json.loads(SMS_SCHEMA_PATH.read_text()) + events = sms.get("definitions", {}).get("events", {}) + + # Map SMS event name -> Pydantic model + model_map = { + "trtllm_initial_report": schema.TrtllmInitialReport, + "trtllm_heartbeat": schema.TrtllmHeartbeat, + } + + for event_name, model_cls in model_map.items(): + if event_name not in events: + errors.append(f"SMS schema missing event definition: {event_name}") + continue + + sms_props = set(events[event_name].get("properties", {}).keys()) + # Pydantic field aliases are the wire names (camelCase) + pydantic_aliases = {f.alias or name for name, f in model_cls.model_fields.items()} + + missing_in_pydantic = sms_props - pydantic_aliases + missing_in_sms = pydantic_aliases - sms_props + + for field in sorted(missing_in_pydantic): + errors.append( + f"{event_name}: field '{field}' in SMS schema but missing " + f"from Pydantic model {model_cls.__name__}" + ) + for field in sorted(missing_in_sms): + errors.append( + f"{event_name}: field '{field}' in Pydantic model " + f"{model_cls.__name__} but missing from SMS schema" + ) + + return errors + + +if __name__ == "__main__": + errs = validate() + if errs: + print("Validation FAILED:") + for e in errs: + print(f" - {e}") + sys.exit(1) + else: + print("Validation OK: 
Pydantic models match SMS schema.") diff --git a/tensorrt_llm/usage/schemas/trtllm_usage_event_schema.json b/tensorrt_llm/usage/schemas/trtllm_usage_event_schema.json new file mode 100644 index 00000000000..e175b939eee --- /dev/null +++ b/tensorrt_llm/usage/schemas/trtllm_usage_event_schema.json @@ -0,0 +1,166 @@ +{ + "oneOf": [ + { + "$ref": "#/definitions/events/trtllm_initial_report" + }, + { + "$ref": "#/definitions/events/trtllm_heartbeat" + } + ], + "$schema": "http://json-schema.org/draft-07/schema#", + "schemaMeta": { + "schemaVersion": "0.1", + "clientId": "616561816355034", + "clientName": "TrtllmTelemetry", + "definitionVersion": "2.0" + }, + "description": "TensorRT-LLM usage telemetry events. Collects anonymous session-level deployment data (GPU type, parallelism config, model architecture class) with opt-out via TRTLLM_NO_USAGE_STATS=1. Auto-disabled in CI/test environments.", + "definitions": { + "types": { + "ShortString": { + "type": "string", + "minLength": 0, + "maxLength": 128 + }, + "LongString": { + "type": "string", + "minLength": 0, + "maxLength": 256 + }, + "PositiveInt": { + "type": "integer", + "minimum": 0, + "maximum": 4294967295 + } + }, + "events": { + "trtllm_initial_report": { + "eventMeta": { + "service": "telemetry", + "gdpr": { + "category": "functional", + "description": "TRT-LLM initial session report with system and config details. No PII collected." 
+ } + }, + "additionalProperties": false, + "type": "object", + "properties": { + "trtllmVersion": { + "$ref": "#/definitions/types/ShortString" + }, + "platform": { + "$ref": "#/definitions/types/LongString" + }, + "pythonVersion": { + "$ref": "#/definitions/types/ShortString" + }, + "cpuArchitecture": { + "$ref": "#/definitions/types/ShortString" + }, + "cpuCount": { + "$ref": "#/definitions/types/PositiveInt" + }, + "gpuCount": { + "$ref": "#/definitions/types/PositiveInt" + }, + "gpuName": { + "$ref": "#/definitions/types/LongString" + }, + "gpuMemoryMB": { + "$ref": "#/definitions/types/PositiveInt" + }, + "cudaVersion": { + "$ref": "#/definitions/types/ShortString" + }, + "architectureClassName": { + "$ref": "#/definitions/types/LongString" + }, + "backend": { + "$ref": "#/definitions/types/ShortString" + }, + "tensorParallelSize": { + "$ref": "#/definitions/types/PositiveInt" + }, + "pipelineParallelSize": { + "$ref": "#/definitions/types/PositiveInt" + }, + "contextParallelSize": { + "$ref": "#/definitions/types/PositiveInt" + }, + "moeExpertParallelSize": { + "$ref": "#/definitions/types/PositiveInt" + }, + "moeTensorParallelSize": { + "$ref": "#/definitions/types/PositiveInt" + }, + "dtype": { + "$ref": "#/definitions/types/ShortString" + }, + "quantizationAlgo": { + "$ref": "#/definitions/types/ShortString" + }, + "kvCacheDtype": { + "$ref": "#/definitions/types/ShortString" + }, + "ingressPoint": { + "$ref": "#/definitions/types/ShortString" + }, + "featuresJson": { + "type": "string", + "description": "JSON-serialized dict of feature flags (lora, speculative_decoding, prefix_caching, cuda_graphs, chunked_context, data_parallel_size)" + }, + "disaggRole": { + "$ref": "#/definitions/types/ShortString" + }, + "deploymentId": { + "$ref": "#/definitions/types/ShortString" + } + }, + "required": [ + "trtllmVersion", + "platform", + "pythonVersion", + "cpuArchitecture", + "cpuCount", + "gpuCount", + "gpuName", + "gpuMemoryMB", + "cudaVersion", + 
"architectureClassName", + "backend", + "tensorParallelSize", + "pipelineParallelSize", + "contextParallelSize", + "moeExpertParallelSize", + "moeTensorParallelSize", + "dtype", + "quantizationAlgo", + "kvCacheDtype", + "ingressPoint", + "featuresJson", + "disaggRole", + "deploymentId" + ] + }, + "trtllm_heartbeat": { + "eventMeta": { + "service": "telemetry", + "gdpr": { + "category": "functional", + "description": "TRT-LLM periodic heartbeat for session duration tracking. No payload beyond sequence number." + } + }, + "additionalProperties": false, + "type": "object", + "properties": { + "seq": { + "$ref": "#/definitions/types/PositiveInt" + } + }, + "required": [ + "seq" + ] + } + } + } +} diff --git a/tensorrt_llm/usage/usage_lib.py b/tensorrt_llm/usage/usage_lib.py new file mode 100644 index 00000000000..3b4ae954598 --- /dev/null +++ b/tensorrt_llm/usage/usage_lib.py @@ -0,0 +1,717 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TRT-LLM usage telemetry collection and reporting. + +Collects anonymous usage data (system info, GPU config, model architecture) +and sends it to NVIDIA's NvTelemetry/GXT service. Runs in a background +daemon thread, never blocks or crashes the main process. 
+ +Adapted from PR #11299 (usage lib POC), with: +- GXT Event Protocol v1.6 envelope (NvTelemetry-compliant) +- Architecture-class-only model sanitization +- DO_NOT_TRACK industry-standard env var support +- First-launch console notification + +Environment variables: + TRTLLM_NO_USAGE_STATS: Set to "1" to disable telemetry. + TELEMETRY_DISABLED: Set to "true" or "1" to disable telemetry. + DO_NOT_TRACK: Set to "1" to disable telemetry (industry standard). + TRTLLM_USAGE_STATS_SERVER: Override the GXT endpoint URL. + TRTLLM_USAGE_HEARTBEAT_INTERVAL: Heartbeat interval in seconds (default 600). + TRTLLM_USAGE_FORCE_ENABLED: Set to "1" to force-enable telemetry even in + CI/test environments (e.g., for staging deployments run via CI). + +CI/Test auto-detection: + Telemetry is automatically disabled when running in CI environments or + test frameworks to ensure only real deployment data is collected. Detected + via well-known environment variables set by CI systems (CI, GITHUB_ACTIONS, + JENKINS_URL, etc.) and test runners (PYTEST_CURRENT_TEST). Override with + TRTLLM_USAGE_FORCE_ENABLED=1 if needed. +""" + +import json +import logging +import os +import platform +import threading +import urllib.error +import urllib.parse +import urllib.request +import uuid +from pathlib import Path +from typing import Any, Dict, Optional + +from tensorrt_llm.usage import schema + +logger = logging.getLogger("tensorrt_llm") + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +_DISAGG_ROLE_ENV = "TRTLLM_DISAGG_ROLE" +_DISAGG_DEPLOYMENT_ID_ENV = "TRTLLM_DISAGG_DEPLOYMENT_ID" +_DEFAULT_ENDPOINT = "https://events.gfe.nvidia.com/v1.1/events/json" +_HTTP_TIMEOUT = 2.0 +_MAX_HEARTBEATS = 1000 + + +class _NoRedirectHandler(urllib.request.HTTPRedirectHandler): + """Redirect handler that rejects all redirects (SSRF protection). 
+ + build_opener() auto-adds HTTPRedirectHandler unless a *subclass* is + provided. By passing this handler, the default is replaced and any + 3xx response raises HTTPError instead of being followed. + """ + + def redirect_request(self, req, fp, code, msg, headers, newurl): + raise urllib.error.HTTPError(req.full_url, code, msg, headers, fp) + + +try: + _OPT_OUT_FILE: Optional[Path] = Path.home() / ".config" / "trtllm" / "do_not_track" +except (RuntimeError, KeyError): + # Path.home() fails when HOME is unset and passwd lookup fails + # (e.g. minimal containers). Degrade gracefully — the file-based + # opt-out simply becomes unavailable; env-var opt-out still works. + _OPT_OUT_FILE = None + +# --------------------------------------------------------------------------- +# CI / Test environment detection +# --------------------------------------------------------------------------- + +# Well-known environment variables set by CI systems. +# If any of these are set (to any non-empty value), telemetry is auto-disabled. +_CI_ENV_VARS = ( + "CI", # GitHub Actions, GitLab CI, Travis CI, generic + "GITHUB_ACTIONS", # GitHub Actions + "JENKINS_URL", # Jenkins + "GITLAB_CI", # GitLab CI + "BUILDKITE", # Buildkite + "CIRCLECI", # CircleCI + "TRAVIS", # Travis CI + "TF_BUILD", # Azure DevOps Pipelines + "TEAMCITY_VERSION", # TeamCity + "CODEBUILD_BUILD_ID", # AWS CodeBuild +) + +# Well-known environment variables set by test frameworks. +_TEST_ENV_VARS = ( + "PYTEST_CURRENT_TEST", # Set by pytest during test execution +) + + +def _is_ci_or_test_environment() -> bool: + """Detect if we are running inside a CI pipeline or test framework. + + Returns True if any well-known CI or test environment variable is set + to a non-empty value. This ensures telemetry only fires in real + deployment scenarios -- not during development, testing, or CI runs. 
+ + Neither vLLM nor NeMo DataDesigner implement CI/test auto-detection; + they rely on CI engineers remembering to set opt-out env vars, which + is fragile. By detecting CI/test environments automatically, we + avoid polluting telemetry data with non-deployment noise. + + Users who genuinely want telemetry from CI (e.g., staging deployments) + can override this by setting TRTLLM_USAGE_FORCE_ENABLED=1. + """ + # Allow force-enable override for CI-based deployments + if os.environ.get("TRTLLM_USAGE_FORCE_ENABLED", "0") == "1": + return False + + for var in _CI_ENV_VARS: + if os.environ.get(var): + return True + for var in _TEST_ENV_VARS: + if os.environ.get(var): + return True + return False + + +def _get_stats_server() -> str: + """Read endpoint URL at call time so env changes after import take effect. + + Validates overrides: HTTPS required, domain must be *.nvidia.com. + Invalid overrides fall back to the default endpoint. + """ + override = os.environ.get("TRTLLM_USAGE_STATS_SERVER") + if override is None: + return _DEFAULT_ENDPOINT + + try: + parsed = urllib.parse.urlparse(override) + if parsed.scheme != "https": + logger.warning( + "TRTLLM_USAGE_STATS_SERVER must use HTTPS; " + "ignoring override and using default endpoint." + ) + return _DEFAULT_ENDPOINT + host = (parsed.hostname or "").lower() + if not (host == "nvidia.com" or host.endswith(".nvidia.com")): + logger.warning( + "TRTLLM_USAGE_STATS_SERVER must be an *.nvidia.com domain; " + "ignoring override and using default endpoint." + ) + return _DEFAULT_ENDPOINT + except Exception: + logger.warning( + "TRTLLM_USAGE_STATS_SERVER is not a valid URL; " + "ignoring override and using default endpoint." 
+ ) + return _DEFAULT_ENDPOINT + + logger.info(f"Telemetry endpoint overridden: {override}") + return override + + +def _get_heartbeat_interval() -> int: + """Read heartbeat interval at call time, with safe fallback on bad values.""" + try: + val = int(os.environ.get("TRTLLM_USAGE_HEARTBEAT_INTERVAL", "600")) + return val if val > 0 else 600 + except ValueError: + return 600 + + +# --------------------------------------------------------------------------- +# Notification (shown once per process) +# --------------------------------------------------------------------------- + +_NOTIFICATION_SHOWN = threading.Event() +_USAGE_NOTICE = ( + "TRT-LLM collects anonymous usage data to help improve the product. " + "This data cannot be traced back to any individual user. " + "No user-identifying information, persistent identifiers, or prompts " + "are collected. To disable, set TRTLLM_NO_USAGE_STATS=1, " + "TELEMETRY_DISABLED=true, or pass " + "TelemetryConfig(disabled=True). " + "See https://github.com/NVIDIA/TensorRT-LLM for details." +) + + +def _show_usage_notification(): + """Show a one-time usage notification via logger (thread-safe).""" + if not _NOTIFICATION_SHOWN.is_set(): + _NOTIFICATION_SHOWN.set() + logger.info(_USAGE_NOTICE) + + +# --------------------------------------------------------------------------- +# Opt-out check +# --------------------------------------------------------------------------- + + +def is_usage_stats_enabled(telemetry_disabled: bool = False) -> bool: + """Check whether usage stats collection is enabled. + + Returns False if any of these conditions are met: + - telemetry_disabled=True (programmatic opt-out via LLM API or CLI) + - TRTLLM_NO_USAGE_STATS=1 + - TELEMETRY_DISABLED=true/1 (case-insensitive) + - DO_NOT_TRACK=1 (industry standard: https://consoledonottrack.com/) + - File ~/.config/trtllm/do_not_track exists + - Running in a CI pipeline or test framework (auto-detected) + Override with TRTLLM_USAGE_FORCE_ENABLED=1 if needed. 
+ """ + if telemetry_disabled: + return False + if os.environ.get("TRTLLM_NO_USAGE_STATS", "0") == "1": + return False + if os.environ.get("TELEMETRY_DISABLED", "").lower() in ("1", "true"): + return False + if os.environ.get("DO_NOT_TRACK", "0") == "1": + return False + if _OPT_OUT_FILE is not None and _OPT_OUT_FILE.exists(): + return False + if _is_ci_or_test_environment(): + logger.debug( + "Telemetry auto-disabled: CI/test environment detected. " + "Set TRTLLM_USAGE_FORCE_ENABLED=1 to override." + ) + return False + return True + + +# --------------------------------------------------------------------------- +# Version detection +# --------------------------------------------------------------------------- + + +def _get_trtllm_version() -> str: + """Get TRT-LLM package version, or 'unknown' if not installed.""" + try: + import tensorrt_llm + + return getattr(tensorrt_llm, "__version__", "unknown") + except (ImportError, AttributeError): + return "unknown" + + +# --------------------------------------------------------------------------- +# System info collection (from PR #11299) +# --------------------------------------------------------------------------- + + +def _collect_system_info() -> Dict[str, Any]: + """Collect platform, Python version, CPU info.""" + return { + "platform": platform.platform(), + "python_version": platform.python_version(), + "cpu_architecture": platform.machine(), + "cpu_count": os.cpu_count(), + } + + +def _collect_gpu_info() -> Dict[str, Any]: + """Collect GPU info via torch.cuda. 
Returns empty dict if unavailable.""" + try: + import torch + + if not torch.cuda.is_available(): + return {} + return { + "gpu_count": torch.cuda.device_count(), + "gpu_name": torch.cuda.get_device_name(0), + "gpu_memory_mb": torch.cuda.get_device_properties(0).total_memory // (1024 * 1024), + "cuda_version": torch.version.cuda or "unknown", + } + except (ImportError, RuntimeError, AttributeError, OSError): + return {} + + +# --------------------------------------------------------------------------- +# Model info extraction (sanitized -- architecture class name only) +# --------------------------------------------------------------------------- + + +def _extract_architecture_class_name(pretrained_config: Any) -> Optional[str]: + """Extract the architecture class name from a pretrained model config. + + Handles three config formats: + + 1. **HF PretrainedConfig** (from ``transformers.PretrainedConfig``): + Has ``.architectures`` — a *list* of strings, e.g. ``["LlamaForCausalLM"]``. + This is the standard format when loading from a HuggingFace model dir. + + 2. [DEPRECATED] **TRT-LLM PretrainedConfig** (from ``tensorrt_llm.models.modeling_utils``): + Has ``.architecture`` — a *singular string*, e.g. ``"LlamaForCausalLM"``. + This is the format used in TRT-LLM checkpoint ``config.json`` files + (``_ModelFormatKind.TLLM_CKPT``). + + 3. [DEPRECATED] **Engine config loaded by HF** (``transformers.PretrainedConfig.from_pretrained`` + reading a TRT-LLM engine dir): + The engine ``config.json`` has top-level keys ``pretrained_config`` (dict) + and ``build_config`` (dict). HF's loader puts these as attributes on a + generic ``PretrainedConfig`` object. The architecture string is at + ``pretrained_config["architecture"]``. 
+ """ + if pretrained_config is None: + return None + try: + # Case 1: HF PretrainedConfig — .architectures (plural list) + architectures = getattr(pretrained_config, "architectures", None) + if architectures and isinstance(architectures, (list, tuple)) and len(architectures) > 0: + return str(architectures[0]) + + # Case 2: TRT-LLM PretrainedConfig / TLLM_CKPT — .architecture (singular str) + architecture = getattr(pretrained_config, "architecture", None) + if architecture and isinstance(architecture, str): + return architecture + + # Case 3: HF from_pretrained on engine dir — nested pretrained_config dict + nested_config = getattr(pretrained_config, "pretrained_config", None) + if isinstance(nested_config, dict) and "architecture" in nested_config: + return str(nested_config["architecture"]) + + # Last resort: config class name (e.g. "LlamaConfig") + return type(pretrained_config).__name__ + except (AttributeError, TypeError, KeyError, IndexError): + return None + + +# --------------------------------------------------------------------------- +# TRT-LLM config extraction +# --------------------------------------------------------------------------- + + +def _extract_trtllm_config(llm_args: Any) -> Dict[str, Any]: + """Extract TRT-LLM configuration from LlmArgs. + + Args: + llm_args: The args object from BaseLLM (TrtLlmArgs, TorchLlmArgs, etc.) + + Returns: + Dict of config values, with None for unavailable fields. 
+ """ + if llm_args is None: + return {} + + config = {} + try: + # Backend detection + backend = getattr(llm_args, "backend", None) + if backend is not None: + config["backend"] = str(backend) + else: + # Infer backend from args class when not explicitly set + cls_name = type(llm_args).__name__ + if "TrtLlm" in cls_name: + config["backend"] = "tensorrt" + + # Parallelism + parallel_config = getattr(llm_args, "parallel_config", None) + if parallel_config is not None: + config["tensor_parallel_size"] = getattr(parallel_config, "tp_size", None) + config["pipeline_parallel_size"] = getattr(parallel_config, "pp_size", None) + config["context_parallel_size"] = getattr(parallel_config, "cp_size", None) + moe_ep = getattr(parallel_config, "moe_ep_size", None) + if moe_ep is not None: + # Map -1 (auto/unset) to 0 for telemetry; PositiveInt schema. + config["moe_expert_parallel_size"] = max(moe_ep, 0) + moe_tp = getattr(parallel_config, "moe_tp_size", None) + if moe_tp is not None: + config["moe_tensor_parallel_size"] = max(moe_tp, 0) + + # dtype + dtype = getattr(llm_args, "dtype", None) + if dtype is not None: + config["dtype"] = str(dtype) + + # Quantization + quant_config = getattr(llm_args, "quant_config", None) + if quant_config is not None: + quant_algo = getattr(quant_config, "quant_algo", None) + if quant_algo is not None: + config["quantization_algo"] = str(quant_algo) + + # KV cache dtype + kv_cache_config = getattr(llm_args, "kv_cache_config", None) + if kv_cache_config is not None: + kv_dtype = getattr(kv_cache_config, "dtype", None) + if kv_dtype is not None: + config["kv_cache_dtype"] = str(kv_dtype) + + except (AttributeError, TypeError): + pass # fail-silent + + return config + + +# --------------------------------------------------------------------------- +# Feature flag collection +# --------------------------------------------------------------------------- + +# Keys and defaults for the features JSON blob. 
All keys are always present +# in the output to simplify downstream analytics (no ambiguity between +# "feature disabled" and "field missing because old client version"). +_FEATURES_DEFAULTS = { + "lora": False, + "speculative_decoding": False, + "prefix_caching": False, + "cuda_graphs": False, + "chunked_context": False, + "data_parallel_size": 1, +} + + +def _collect_features(llm_args: Any) -> str: + """Collect feature flags from llm_args and return as compact JSON string. + + Inspects the LlmArgs object for enabled features (LoRA, speculative + decoding, prefix caching, CUDA graphs, chunked context, data parallelism). + Returns a JSON-serialized dict with snake_case keys. All keys are always + present with safe defaults, even if extraction fails. + + The output is a string suitable for the ``featuresJson`` field in the + GXT event schema (``stringVariableLength``). + + Args: + llm_args: The args object from BaseLLM (TrtLlmArgs, TorchLlmArgs, etc.) + May be None. + + Returns: + Compact JSON string, e.g. '{"lora":false,"speculative_decoding":false,...}' + """ + features = dict(_FEATURES_DEFAULTS) + + if llm_args is None: + return json.dumps(features, separators=(",", ":")) + + try: + # LoRA: enabled if enable_lora flag is True OR lora_config is provided. + # On PyTorch backend, enable_lora is ignored when lora_config is set, + # so checking both catches all cases. + enable_lora = getattr(llm_args, "enable_lora", False) or False + lora_config = getattr(llm_args, "lora_config", None) + features["lora"] = bool(enable_lora or lora_config is not None) + + # Speculative decoding: enabled if speculative_config is not None. + spec_config = getattr(llm_args, "speculative_config", None) + features["speculative_decoding"] = spec_config is not None + + # Prefix caching (KV block reuse): kv_cache_config.enable_block_reuse. + # kv_cache_config has a default_factory (never None in practice), but + # we guard defensively since llm_args may be a mock or partial object. 
+ kv_cache_config = getattr(llm_args, "kv_cache_config", None) + if kv_cache_config is not None: + block_reuse = getattr(kv_cache_config, "enable_block_reuse", None) + if block_reuse is not None: + features["prefix_caching"] = bool(block_reuse) + + # CUDA graphs: two different config paths depending on backend. + # PyTorch backend: cuda_graph_config (TorchLlmArgs only). + # None = disabled; CudaGraphConfig() = enabled (default). + # TRT backend: extended_runtime_perf_knob_config.cuda_graph_mode (TrtLlmArgs only). + cuda_graph_config = getattr(llm_args, "cuda_graph_config", None) + ext_config = getattr(llm_args, "extended_runtime_perf_knob_config", None) + if cuda_graph_config is not None: + # PyTorch path: presence of config object means enabled + features["cuda_graphs"] = True + elif ext_config is not None: + # TRT path: explicit cuda_graph_mode flag + features["cuda_graphs"] = bool(getattr(ext_config, "cuda_graph_mode", False)) + + # Chunked context / chunked prefill: defined on BaseLlmArgs. + features["chunked_context"] = bool(getattr(llm_args, "enable_chunked_prefill", False)) + + # Data parallel size: derived from parallel_config. + # dp_size = tp_size if enable_attention_dp else 1 (no dp_size field exists). + parallel_config = getattr(llm_args, "parallel_config", None) + if parallel_config is not None: + enable_adp = getattr(parallel_config, "enable_attention_dp", False) + if enable_adp: + tp_size = getattr(parallel_config, "tp_size", 1) or 1 + features["data_parallel_size"] = int(tp_size) + + except Exception: + pass # fail-silent: return whatever we collected so far + + return json.dumps(features, separators=(",", ":")) + + +# --------------------------------------------------------------------------- +# HTTP transport +# --------------------------------------------------------------------------- + + +def _send_to_gxt(payload: dict) -> None: + """Send a GXT payload via HTTP POST. Fail-silent. 
+ + Uses urllib (stdlib) with 2s timeout and no redirects (SSRF protection). + """ + try: + data = json.dumps(payload).encode("utf-8") + req = urllib.request.Request( + _get_stats_server(), + data=data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + }, + method="POST", + ) + # SSRF protection: use a custom opener that does NOT follow redirects. + # build_opener auto-adds HTTPRedirectHandler unless a subclass is + # provided, so we pass a handler that rejects all redirects. + opener = urllib.request.build_opener( + urllib.request.HTTPHandler, + urllib.request.HTTPSHandler, + _NoRedirectHandler, + ) + opener.open(req, timeout=_HTTP_TIMEOUT) + except (urllib.error.URLError, OSError, ValueError, TypeError): + pass # fail-silent: network errors, timeouts, etc. + + +# --------------------------------------------------------------------------- +# Background reporter (daemon thread) +# --------------------------------------------------------------------------- + + +def _clamp_str(value: str, max_len: int) -> str: + """Truncate a string to max_len if it exceeds the limit.""" + return value[:max_len] if len(value) > max_len else value + + +def _background_reporter( + llm_args: Any, + pretrained_config: Any, + usage_context: str = "", +) -> None: + """Background thread entry point. Sends initial report + heartbeats. + + This function is the target of the daemon thread spawned by report_usage(). + It is wrapped in try/except at every level to ensure fail-silent behavior. 
+ """ + try: + session_id = uuid.uuid4().hex + trtllm_version = _get_trtllm_version() + + # --- Collect initial data --- + system_info = _collect_system_info() + gpu_info = _collect_gpu_info() + arch_class_name = _extract_architecture_class_name(pretrained_config) + trtllm_config = _extract_trtllm_config(llm_args) + features_json = _collect_features(llm_args) + + # Disaggregated serving metadata (set by serve.py orchestrator) + disagg_role = os.environ.get(_DISAGG_ROLE_ENV, "") + deployment_id = os.environ.get(_DISAGG_DEPLOYMENT_ID_ENV, "") + + # --- Build initial report event --- + # All fields are required by the SMS schema. Use empty string / 0 + # as sentinel values when actual data is unavailable (e.g., no GPU). + # String values are clamped to schema limits (ShortString=128, + # LongString=256) to prevent ValidationError from real-world data + # exceeding the Pydantic field constraints. + _S = schema._SHORT_STR # ShortString maxLength + _L = schema._LONG_STR # LongString maxLength + initial_event = schema.TrtllmInitialReport( + trtllmVersion=_clamp_str(trtllm_version or "", _S), + # System info + platform=_clamp_str(system_info.get("platform") or "", _L), + pythonVersion=_clamp_str(system_info.get("python_version") or "", _S), + cpuArchitecture=_clamp_str(system_info.get("cpu_architecture") or "", _S), + cpuCount=system_info.get("cpu_count") or 0, + # GPU info + gpuCount=gpu_info.get("gpu_count") or 0, + gpuName=_clamp_str(gpu_info.get("gpu_name") or "", _L), + gpuMemoryMB=gpu_info.get("gpu_memory_mb") or 0, + cudaVersion=_clamp_str(gpu_info.get("cuda_version") or "", _S), + # Model + architectureClassName=_clamp_str(arch_class_name or "", _L), + # Config + backend=_clamp_str(trtllm_config.get("backend") or "", _S), + tensorParallelSize=trtllm_config.get("tensor_parallel_size") or 1, + pipelineParallelSize=trtllm_config.get("pipeline_parallel_size") or 1, + contextParallelSize=trtllm_config.get("context_parallel_size") or 1, + 
moeExpertParallelSize=trtllm_config.get("moe_expert_parallel_size", 0), + moeTensorParallelSize=trtllm_config.get("moe_tensor_parallel_size", 0), + dtype=_clamp_str(trtllm_config.get("dtype") or "", _S), + quantizationAlgo=_clamp_str(trtllm_config.get("quantization_algo") or "", _S), + kvCacheDtype=_clamp_str(trtllm_config.get("kv_cache_dtype") or "", _S), + # Ingress point + ingressPoint=_clamp_str(usage_context or "", _S), + # Feature flags + featuresJson=features_json, + # Disaggregated serving + disaggRole=_clamp_str(disagg_role, _S), + deploymentId=_clamp_str(deployment_id, _S), + ) + + # --- Send initial report --- + payload = schema.build_gxt_payload( + event=initial_event, + session_id=session_id, + trtllm_version=trtllm_version, + ) + _send_to_gxt(payload) + + # --- Heartbeat loop --- + heartbeat_interval = _get_heartbeat_interval() + for seq in range(_MAX_HEARTBEATS): + if _REPORTER_STOP.wait(timeout=heartbeat_interval): + return # stop requested + + try: + heartbeat_event = schema.TrtllmHeartbeat(seq=seq) + heartbeat_payload = schema.build_gxt_payload( + event=heartbeat_event, + session_id=session_id, + trtllm_version=trtllm_version, + ) + _send_to_gxt(heartbeat_payload) + except (urllib.error.URLError, OSError, ValueError, TypeError): + pass # fail-silent on individual heartbeat + + except Exception: + pass # fail-silent: entire background reporter + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +_REPORTER_STARTED = False +_REPORTER_LOCK = threading.Lock() +_REPORTER_STOP = threading.Event() # signal heartbeat loop to exit + + +def report_usage( + llm_args: Any = None, + pretrained_config: Any = None, + telemetry_config: Any = None, +) -> None: + """Start background usage telemetry reporting. + + Call this once after model initialization. It spawns a daemon thread + that sends an initial report and periodic heartbeats. 
Subsequent calls + are no-ops (only one reporter thread per process). + + This function is fail-silent -- it will never raise an exception or + block the calling thread. + + Args: + llm_args: The LlmArgs object from BaseLLM (for config extraction). + pretrained_config: The pretrained model config (for architecture name). + telemetry_config: TelemetryConfig object (opt-out + usage context). + """ + global _REPORTER_STARTED + try: + # Extract fields from TelemetryConfig (defensive -- may be None or wrong type) + disabled = False + usage_context = "" + if telemetry_config is not None: + disabled = getattr(telemetry_config, "disabled", False) + ctx = getattr(telemetry_config, "usage_context", None) + if ctx is not None: + usage_context = ctx.value if hasattr(ctx, "value") else str(ctx) + + if not is_usage_stats_enabled(telemetry_disabled=disabled): + return + + # Only rank 0 in a TP group should report (matches vLLM behavior). + # NOTE: This import is intentionally deferred (not top-level) because + # usage_lib.py must be importable without the full TRT-LLM stack — + # test conftest stubs out tensorrt_llm. The try/except ensures + # lightweight installs and test environments aren't broken. 
+ try: + from tensorrt_llm._utils import mpi_rank # noqa: E402 — deferred by design + + if mpi_rank() != 0: + return + except Exception: + pass # fail-silent: if we can't determine rank, proceed + + with _REPORTER_LOCK: + if _REPORTER_STARTED: + return + _REPORTER_STARTED = True + + _show_usage_notification() + + thread = threading.Thread( + target=_background_reporter, + args=(llm_args, pretrained_config, usage_context), + daemon=True, + name="trtllm-usage-stats", + ) + thread.start() + + except Exception: + with _REPORTER_LOCK: + _REPORTER_STARTED = False diff --git a/tests/integration/defs/conftest.py b/tests/integration/defs/conftest.py index 77fdcc2d477..8bef41e091c 100644 --- a/tests/integration/defs/conftest.py +++ b/tests/integration/defs/conftest.py @@ -2268,6 +2268,8 @@ def pytest_collection_modifyitems(session, config, items): def pytest_configure(config): + os.environ.setdefault("TRTLLM_NO_USAGE_STATS", "1") + # avoid thread leak of tqdm's TMonitor tqdm.tqdm.monitor_interval = 0 if config.getoption("--run-ray"): diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 6c51374b67e..73e1aaf88d4 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -50,6 +50,13 @@ l0_a10: - unittest/disaggregated/test_peer.py - unittest/disaggregated/region/test_block.py - unittest/disaggregated/test_mamba_transfer.py + - unittest/usage/test_collectors.py + - unittest/usage/test_config.py + - unittest/usage/test_opt_out.py + - unittest/usage/test_reporter.py + - unittest/usage/test_schema.py + - unittest/usage/test_transport.py + - unittest/usage/test_e2e_capture.py - disaggregated/test_disaggregated.py::test_disaggregated_single_gpu[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_trt_backend[TinyLlama-1.1B-Chat-v1.0] - 
disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0] @@ -107,6 +114,9 @@ l0_a10: - unittest/llmapi/apps/test_chat_utils.py - unittest/llmapi/apps/test_tool_parsers.py - unittest/llmapi/apps/test_harmony_channel_validation.py + # usage telemetry + - unittest/llmapi/test_llm_telemetry.py::TestTelemetryPyTorchBackend + - unittest/llmapi/test_llm_telemetry.py::TestTelemetryArchitectureExtraction - llmapi/test_llm_api_connector.py::test_connector_simple[True] - llmapi/test_llm_api_connector.py::test_connector_simple[False] - llmapi/test_llm_api_connector.py::test_connector_async_onboard[True] @@ -169,6 +179,8 @@ l0_a10: - examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-BertModel-bert/bert-base-uncased] - unittest/trt/model/test_mistral.py - unittest/trt/model/test_llama.py + # usage telemetry (TRT backend) + - unittest/llmapi/test_llm_telemetry.py::TestTelemetryTRTBackend - llmapi/test_llm_e2e.py::test_llmapi_load_engine_from_build_command[llama-llama-models/llama-7b-hf] # 5min - llmapi/test_llm_e2e.py::test_llmapi_build_command_parameters_align[llama-llama-models-v2/TinyLlama-1.1B-Chat-v1.0] - llmapi/test_llm_e2e.py::test_llmapi_load_engine_from_build_command_with_lora[llama-llama-models-v2/llama-v2-7b-hf] diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 58750e43b3f..ed1039b81e9 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -65,6 +65,14 @@ l0_h100: - unittest/disaggregated/test_kv_transfer_mp.py - unittest/others/test_kv_cache_transceiver.py::test_kv_cache_transceiver_single_process[PYTHON-mha-ctx_fp16_gen_fp16] - unittest/others/test_kv_cache_transceiver.py::test_kv_cache_transceiver_single_process[PYTHON-mla-ctx_fp16_gen_fp16] + - unittest/llmapi/test_llm_telemetry.py + - 
unittest/usage/test_collectors.py + - unittest/usage/test_config.py + - unittest/usage/test_opt_out.py + - unittest/usage/test_reporter.py + - unittest/usage/test_schema.py + - unittest/usage/test_transport.py + - unittest/usage/test_e2e_capture.py - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_without_reuse - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index f436ead7c82..3e5a4206279 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -231,6 +231,10 @@ methods: annotation: Optional[Dict[str, str]] default: null status: prototype + telemetry_config: + annotation: tensorrt_llm.usage.config.TelemetryConfig + default: null + status: prototype max_stats_len: annotation: int default: 1000 diff --git a/tests/unittest/conftest.py b/tests/unittest/conftest.py index b3df6951f2f..e56aab6cf0c 100644 --- a/tests/unittest/conftest.py +++ b/tests/unittest/conftest.py @@ -45,6 +45,8 @@ def dump_threads(signum, frame): def pytest_configure(config): + os.environ.setdefault("TRTLLM_NO_USAGE_STATS", "1") + # avoid thread leak of tqdm's TMonitor tqdm.tqdm.monitor_interval = 0 diff --git a/tests/unittest/llmapi/test_features_contract.py b/tests/unittest/llmapi/test_features_contract.py new file mode 100644 index 00000000000..d2ba22150e2 --- /dev/null +++ b/tests/unittest/llmapi/test_features_contract.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contract tests for telemetry feature collection.""" + +import json +import types +from pathlib import Path + +import pytest +import yaml + +from tensorrt_llm import lora_helper +from tensorrt_llm.llmapi import llm_args +from tensorrt_llm.usage import usage_lib + +_STABILITY_DIR = Path(__file__).resolve().parents[1] / "api_stability" +_COMMITTED_YAML = _STABILITY_DIR / "references_committed" / "llm.yaml" +_REFERENCE_YAML = _STABILITY_DIR / "references" / "llm.yaml" + +_FEATURE_API_DEPS = { + "lora": (_COMMITTED_YAML, ("enable_lora", "lora_config")), + "speculative_decoding": (_COMMITTED_YAML, ("speculative_config",)), + "prefix_caching": (_COMMITTED_YAML, ("kv_cache_config",)), + "chunked_context": (_COMMITTED_YAML, ("enable_chunked_prefill",)), + "cuda_graphs": (_REFERENCE_YAML, ("cuda_graph_config",)), + "data_parallel_size": (_REFERENCE_YAML, ("enable_attention_dp",)), +} + +_KV_DEFAULT = llm_args.KvCacheConfig() +_KV_NO_REUSE = llm_args.KvCacheConfig(enable_block_reuse=False) +_LORA_CONFIG = lora_helper.LoraConfig(lora_dir=["/tmp/fake"]) +_NGRAM_CONFIG = llm_args.NGramDecodingConfig(max_draft_len=1) + + +def _load_init_params(yaml_path: Path) -> dict: + with yaml_path.open(encoding="utf-8") as yaml_file: + return yaml.safe_load(yaml_file)["methods"]["__init__"]["parameters"] + + +def _args(**kwargs): + defaults = { + "enable_lora": False, + "lora_config": None, + "speculative_config": None, + "kv_cache_config": _KV_DEFAULT, + "cuda_graph_config": None, + "extended_runtime_perf_knob_config": None, + "enable_chunked_prefill": False, + 
"parallel_config": llm_args._ParallelConfig(), + } + defaults.update(kwargs) + return types.SimpleNamespace(**defaults) + + +@pytest.mark.parametrize( + ("feature", "yaml_path", "fields"), + [(feature, *yaml_and_fields) for feature, yaml_and_fields in _FEATURE_API_DEPS.items()], + ids=list(_FEATURE_API_DEPS), +) +def test_telemetry_fields_exist_in_api_yaml(feature, yaml_path, fields): + """If this fails, the LLM API changed and `_collect_features()` needs updating.""" + init_params = _load_init_params(yaml_path) + missing_fields = [field for field in fields if field not in init_params] + + assert not missing_fields, ( + f"Telemetry feature '{feature}' depends on LLM.__init__ parameter(s) {missing_fields} " + f"which no longer exist in {yaml_path.name}. Update _collect_features() in usage_lib.py." + ) + + +@pytest.mark.parametrize( + ("args_kwargs", "key", "expected"), + [ + ({}, "lora", False), + ({}, "speculative_decoding", False), + ({}, "prefix_caching", True), + ({}, "cuda_graphs", False), + ({}, "chunked_context", False), + ({}, "data_parallel_size", 1), + ({"enable_lora": True}, "lora", True), + ({"lora_config": _LORA_CONFIG}, "lora", True), + ({"enable_lora": True, "lora_config": _LORA_CONFIG}, "lora", True), + ({"speculative_config": _NGRAM_CONFIG}, "speculative_decoding", True), + ({"kv_cache_config": _KV_NO_REUSE}, "prefix_caching", False), + ({"cuda_graph_config": llm_args.CudaGraphConfig()}, "cuda_graphs", True), + ({"enable_chunked_prefill": True}, "chunked_context", True), + ( + {"parallel_config": llm_args._ParallelConfig(tp_size=4, enable_attention_dp=True)}, + "data_parallel_size", + 4, + ), + ( + {"parallel_config": llm_args._ParallelConfig(tp_size=4, enable_attention_dp=False)}, + "data_parallel_size", + 1, + ), + ], + ids=[ + "default-lora", + "default-spec", + "default-prefix", + "default-cuda", + "default-chunked", + "default-dp", + "lora-flag", + "lora-config", + "lora-both", + "spec-ngram", + "prefix-disabled", + "cuda-pytorch", + 
"chunked-enabled", + "dp-4gpu", + "dp-disabled", + ], +) +def test_collect_features_real_configs(args_kwargs, key, expected): + """Real config objects should drive the same feature values as live telemetry.""" + features = json.loads(usage_lib._collect_features(_args(**args_kwargs))) + assert features[key] == expected + + +def test_all_features_enabled_real_configs(): + """A fully enabled config should emit the expected feature payload.""" + args = _args( + enable_lora=True, + lora_config=_LORA_CONFIG, + speculative_config=_NGRAM_CONFIG, + kv_cache_config=_KV_DEFAULT, + cuda_graph_config=llm_args.CudaGraphConfig(), + enable_chunked_prefill=True, + parallel_config=llm_args._ParallelConfig(tp_size=8, enable_attention_dp=True), + ) + + features = json.loads(usage_lib._collect_features(args)) + + assert features == { + "lora": True, + "speculative_decoding": True, + "prefix_caching": True, + "cuda_graphs": True, + "chunked_context": True, + "data_parallel_size": 8, + } diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index 999edbcdcde..f5d1c33f77e 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -549,6 +549,129 @@ def check_nested_dict_equality(dict1, dict2, path=""): check_nested_dict_equality(build_config_dict1, build_config_dict2) +class TestTelemetryConfigPrecedence: + """Test that telemetry config follows: default < YAML < CLI precedence.""" + + def test_default_telemetry_config_preserved_when_no_yaml(self): + """Default telemetry_config survives YAML merge when YAML has none.""" + from tensorrt_llm.usage.config import TelemetryConfig, UsageContext + base = { + "model": + "dummy", + "telemetry_config": + TelemetryConfig(disabled=False, + usage_context=UsageContext.CLI_SERVE), + } + yaml_dict = {"max_batch_size": 8} + merged = update_llm_args_with_extra_dict(base, yaml_dict) + tc = merged["telemetry_config"] + assert isinstance(tc, TelemetryConfig) + assert tc.disabled is 
False + assert tc.usage_context == UsageContext.CLI_SERVE + + def test_yaml_can_override_disabled(self): + """YAML telemetry_config.disabled overrides the default.""" + from tensorrt_llm.usage.config import TelemetryConfig, UsageContext + base = { + "model": + "dummy", + "telemetry_config": + TelemetryConfig(disabled=False, + usage_context=UsageContext.CLI_SERVE), + } + yaml_dict = {"telemetry_config": {"disabled": True}} + merged = update_llm_args_with_extra_dict(base, yaml_dict) + tc = merged["telemetry_config"] + assert isinstance(tc, TelemetryConfig) + assert tc.disabled is True + + def test_yaml_cannot_override_usage_context(self): + """usage_context is coupled to the CLI entry point. + + The CLI entry point (serve, eval, etc.) that first creates the + TelemetryConfig sets usage_context, so YAML must not override it. + """ + from tensorrt_llm.usage.config import TelemetryConfig, UsageContext + base = { + "model": + "dummy", + "telemetry_config": + TelemetryConfig(disabled=False, + usage_context=UsageContext.CLI_SERVE), + } + yaml_dict = { + "telemetry_config": { + "disabled": True, + "usage_context": "cli_eval", + } + } + merged = update_llm_args_with_extra_dict(base, yaml_dict) + tc = merged["telemetry_config"] + assert isinstance(tc, TelemetryConfig) + assert tc.usage_context == UsageContext.CLI_SERVE + assert tc.disabled is True + + def test_cli_disabled_overrides_yaml_enabled(self): + """CLI --telemetry-disabled wins over YAML disabled=false.""" + from tensorrt_llm.usage.config import TelemetryConfig, UsageContext + base = { + "model": + "dummy", + "telemetry_config": + TelemetryConfig(disabled=False, + usage_context=UsageContext.CLI_EVAL), + } + yaml_dict = {"telemetry_config": {"disabled": False}} + merged = update_llm_args_with_extra_dict(base, yaml_dict) + # Simulate CLI --no-telemetry (as done in eval.py / serve.py) + telemetry = False + if not telemetry: + merged["telemetry_config"] = merged["telemetry_config"].model_copy( + update={"disabled": 
True}) + tc = merged["telemetry_config"] + assert tc.disabled is True + assert tc.usage_context == UsageContext.CLI_EVAL + + def test_yaml_disabled_respected_when_cli_not_set(self): + """When CLI doesn't set --no-telemetry, YAML disabled=true is kept.""" + from tensorrt_llm.usage.config import TelemetryConfig, UsageContext + base = { + "model": + "dummy", + "telemetry_config": + TelemetryConfig(disabled=False, + usage_context=UsageContext.CLI_SERVE), + } + yaml_dict = {"telemetry_config": {"disabled": True}} + merged = update_llm_args_with_extra_dict(base, yaml_dict) + # CLI flag not set (--telemetry is default True) — no override + telemetry = True + if not telemetry: + merged["telemetry_config"] = merged["telemetry_config"].model_copy( + update={"disabled": True}) + tc = merged["telemetry_config"] + assert tc.disabled is True + assert tc.usage_context == UsageContext.CLI_SERVE + + @pytest.mark.parametrize("yaml_value", [None, False, "invalid", 0]) + def test_yaml_null_telemetry_config_preserves_default(self, yaml_value): + """YAML telemetry_config: null/false/invalid preserves the CLI default.""" + from tensorrt_llm.usage.config import TelemetryConfig, UsageContext + base = { + "model": + "dummy", + "telemetry_config": + TelemetryConfig(disabled=False, + usage_context=UsageContext.CLI_SERVE), + } + yaml_dict = {"telemetry_config": yaml_value} + merged = update_llm_args_with_extra_dict(base, yaml_dict) + tc = merged["telemetry_config"] + assert isinstance(tc, TelemetryConfig) + assert tc.usage_context == UsageContext.CLI_SERVE + assert tc.disabled is False + + class TestTorchLlmArgsCudaGraphSettings: def test_cuda_graph_batch_sizes_case_0(self): @@ -1169,8 +1292,7 @@ def test_partial_user_override(self): @pytest.mark.part0 def test_empty_nested_config_preserves_defaults(self): - """Passing an empty nested config (e.g. KvCacheConfig()) should not - block model defaults from applying to that config's sub-fields. 
+ """Passing an empty nested config should not block model defaults. This covers the pattern used by tests that conditionally build a KvCacheConfig: ``kv_cache_config=KvCacheConfig(...) if cond else @@ -1225,10 +1347,7 @@ def _get_all_llm_args_classes(): def _get_all_pydantic_models_from_llm_args(): - """ - Get all Pydantic models referenced by BaseLlmArgs and its subclasses, - including nested models. - """ + """Get all Pydantic models referenced by BaseLlmArgs and its subclasses.""" visited = set() models = [] @@ -1422,8 +1541,11 @@ def test_all_fields_have_descriptions(self): ) def test_all_fields_have_allowed_types(self): - """Test that all fields in LlmArgs classes (including subfields) have types that are allowed - (i.e. are Pydantic-compatible) according to the logic in _is_allowed_type.""" + """Test that all fields in LlmArgs classes have allowed types. + + Checks that fields (including subfields) have Pydantic-compatible + types according to the logic in _is_allowed_type. + """ violations = [] for cls in _get_all_pydantic_models_from_llm_args(): diff --git a/tests/unittest/llmapi/test_llm_telemetry.py b/tests/unittest/llmapi/test_llm_telemetry.py new file mode 100644 index 00000000000..421ad9633ae --- /dev/null +++ b/tests/unittest/llmapi/test_llm_telemetry.py @@ -0,0 +1,409 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Integration tests for telemetry hook in BaseLLM.__init__(). + +Verifies that pretrained_config is populated and valid when the telemetry +hook fires, and that telemetry_disabled flows through correctly. +""" + +import os +import sys +from unittest.mock import patch + +import pytest + +from tensorrt_llm import LLM as LLM_torch +from tensorrt_llm._tensorrt_engine import LLM +from tensorrt_llm.llmapi import KvCacheConfig, llm_args +from tensorrt_llm.usage import usage_lib + +sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") +from utils.llm_data import llm_models_root # noqa: E402 + +pytestmark = pytest.mark.threadleak(enabled=False) + +MODEL_NAME = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" +_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) + + +def _get_model_path(): + root = llm_models_root() + assert root is not None, ( + "LLM_MODELS_ROOT must be set or /home/scratch.trt_llm_data must be " + "accessible to run telemetry integration tests" + ) + return str(root / MODEL_NAME) + + +def _make_spy(): + """Return (captured_dict, spy_function) for patching report_usage.""" + captured = {} + + def spy_report_usage(**kwargs): + captured.update(kwargs) + + return captured, spy_report_usage + + +class TestTelemetryPyTorchBackend: + """Verify pretrained_config lifecycle with PyTorch backend (no engine build).""" + + @pytest.fixture(autouse=True) + def _setup(self): + self.model_path = _get_model_path() + + def test_telemetry_receives_hf_config_pytorch(self): + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch(model=self.model_path, kv_cache_config=_kv_cache_config) as _: + pass + + pretrained_config = captured.get("pretrained_config") + assert pretrained_config is not None, "report_usage was not called with pretrained_config" + assert hasattr(pretrained_config, 
"architectures"), ( + "pretrained_config missing .architectures attribute" + ) + assert isinstance(pretrained_config.architectures, list) + assert len(pretrained_config.architectures) > 0 + assert pretrained_config.architectures[0] == "LlamaForCausalLM" + + assert captured.get("llm_args") is not None, "report_usage was not called with llm_args" + + +class TestTelemetryTRTBackend: + """Verify pretrained_config lifecycle with TensorRT backend (engine build). + + When starting from an HF model, _TrtLLM builds a TRT engine and + overwrites self.args.model to point to the engine dir. The telemetry + hook must still receive the *original* HF PretrainedConfig (loaded from + _hf_model_dir before the overwrite), not the TRT-LLM engine config. + """ + + @pytest.fixture(autouse=True) + def _setup(self): + self.model_path = _get_model_path() + + def test_telemetry_receives_hf_config_trt(self): + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM(model=self.model_path, kv_cache_config=_kv_cache_config) as _: + pass + + pretrained_config = captured.get("pretrained_config") + assert pretrained_config is not None, "report_usage was not called with pretrained_config" + # The config should be an HF PretrainedConfig with .architectures + # (plural list), NOT a TRT-LLM config with .architecture (singular). 
+ assert hasattr(pretrained_config, "architectures"), ( + "pretrained_config missing .architectures attribute" + ) + assert isinstance(pretrained_config.architectures, list) + assert len(pretrained_config.architectures) > 0 + assert pretrained_config.architectures[0] == "LlamaForCausalLM" + + assert captured.get("llm_args") is not None, "report_usage was not called with llm_args" + + def test_telemetry_arch_extraction_trt(self): + """End-to-end: _extract_architecture_class_name with TRT backend config.""" + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM(model=self.model_path, kv_cache_config=_kv_cache_config) as _: + pass + + pretrained_config = captured.get("pretrained_config") + assert pretrained_config is not None + + arch = usage_lib._extract_architecture_class_name(pretrained_config) + assert arch == "LlamaForCausalLM", f"Expected 'LlamaForCausalLM', got '{arch}'" + + +class TestTelemetryArchitectureExtraction: + """End-to-end: _extract_architecture_class_name with a real HF config.""" + + @pytest.fixture(autouse=True) + def _setup(self): + self.model_path = _get_model_path() + + def test_telemetry_config_has_extractable_architecture(self): + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch(model=self.model_path, kv_cache_config=_kv_cache_config) as _: + pass + + pretrained_config = captured.get("pretrained_config") + assert pretrained_config is not None + + arch = usage_lib._extract_architecture_class_name(pretrained_config) + assert arch == "LlamaForCausalLM", f"Expected 'LlamaForCausalLM', got '{arch}'" + + +class TestTelemetryDisabledFlag: + """Verify that telemetry_disabled=True flows from LLM args to report_usage.""" + + @pytest.fixture(autouse=True) + def _setup(self): + self.model_path = _get_model_path() + + def test_telemetry_disabled_passed_to_report_usage(self): + """When TelemetryConfig(disabled=True), report_usage receives 
it.""" + from tensorrt_llm.llmapi import llm_args as _llm_args_mod + + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch( + model=self.model_path, + kv_cache_config=_kv_cache_config, + telemetry_config=_llm_args_mod.TelemetryConfig(disabled=True), + ) as _: + pass + + telemetry_config = captured.get("telemetry_config") + assert telemetry_config is not None + assert telemetry_config.disabled is True + + def test_telemetry_enabled_by_default(self): + """When TelemetryConfig not set, disabled defaults to False.""" + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch(model=self.model_path, kv_cache_config=_kv_cache_config) as _: + pass + + telemetry_config = captured.get("telemetry_config") + assert telemetry_config is not None + assert telemetry_config.disabled is False + + +class TestUsageContextFlow: + """Verify that usage_context flows correctly from LLM args to report_usage.""" + + @pytest.fixture(autouse=True) + def _setup(self): + self.model_path = _get_model_path() + + def test_default_promotes_to_llm_class(self): + """LLM() without explicit context gets promoted to LLM_CLASS.""" + from tensorrt_llm.llmapi import llm_args as _llm_args_mod + + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch(model=self.model_path, kv_cache_config=_kv_cache_config) as _: + pass + + telemetry_config = captured.get("telemetry_config") + assert telemetry_config is not None + assert telemetry_config.usage_context == _llm_args_mod.UsageContext.LLM_CLASS + + def test_cli_serve_context_preserved(self): + """CLI_SERVE context is not overridden by BaseLLM.""" + from tensorrt_llm.llmapi import llm_args as _llm_args_mod + + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch( + model=self.model_path, + kv_cache_config=_kv_cache_config, + 
telemetry_config=_llm_args_mod.TelemetryConfig( + usage_context=_llm_args_mod.UsageContext.CLI_SERVE + ), + ) as _: + pass + + telemetry_config = captured.get("telemetry_config") + assert telemetry_config is not None + assert telemetry_config.usage_context == _llm_args_mod.UsageContext.CLI_SERVE + + def test_cli_bench_context_preserved(self): + """CLI_BENCH context is not overridden by BaseLLM.""" + from tensorrt_llm.llmapi import llm_args as _llm_args_mod + + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch( + model=self.model_path, + kv_cache_config=_kv_cache_config, + telemetry_config=_llm_args_mod.TelemetryConfig( + usage_context=_llm_args_mod.UsageContext.CLI_BENCH + ), + ) as _: + pass + + telemetry_config = captured.get("telemetry_config") + assert telemetry_config is not None + assert telemetry_config.usage_context == _llm_args_mod.UsageContext.CLI_BENCH + + def test_cli_eval_context_preserved(self): + """CLI_EVAL context is not overridden by BaseLLM.""" + from tensorrt_llm.llmapi import llm_args as _llm_args_mod + + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch( + model=self.model_path, + kv_cache_config=_kv_cache_config, + telemetry_config=_llm_args_mod.TelemetryConfig( + usage_context=_llm_args_mod.UsageContext.CLI_EVAL + ), + ) as _: + pass + + telemetry_config = captured.get("telemetry_config") + assert telemetry_config is not None + assert telemetry_config.usage_context == _llm_args_mod.UsageContext.CLI_EVAL + + +class TestFeatureTrackingIntegration: + """Verify that _collect_features works with real llm_args from model init.""" + + @pytest.fixture(autouse=True) + def _setup(self): + self.model_path = _get_model_path() + + def test_features_json_present_in_report_pytorch(self): + """_collect_features returns valid JSON with all 6 keys from real llm_args.""" + import json + + captured, spy = _make_spy() + + with 
patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch(model=self.model_path, kv_cache_config=_kv_cache_config) as _: + pass + + llm_args = captured.get("llm_args") + assert llm_args is not None, "report_usage was not called with llm_args" + + features_str = usage_lib._collect_features(llm_args) + features = json.loads(features_str) + expected_keys = { + "lora", + "speculative_decoding", + "prefix_caching", + "cuda_graphs", + "chunked_context", + "data_parallel_size", + } + assert set(features.keys()) == expected_keys + + def test_features_json_default_values_pytorch(self): + """Default TinyLlama config has expected feature defaults.""" + import json + + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch(model=self.model_path, kv_cache_config=_kv_cache_config) as _: + pass + + llm_args = captured.get("llm_args") + features = json.loads(usage_lib._collect_features(llm_args)) + + # TinyLlama loaded with defaults: no LoRA, no spec dec, no chunked prefill + assert features["lora"] is False + assert features["speculative_decoding"] is False + assert features["chunked_context"] is False + assert features["data_parallel_size"] == 1 + + def test_features_json_chunked_prefill_pytorch(self): + """enable_chunked_prefill=True is detected in features.""" + import json + + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch( + model=self.model_path, + kv_cache_config=_kv_cache_config, + enable_chunked_prefill=True, + ) as _: + pass + + llm_args = captured.get("llm_args") + features = json.loads(usage_lib._collect_features(llm_args)) + assert features["chunked_context"] is True + + @pytest.mark.parametrize( + ("extra_kwargs", "key", "expected"), + [ + ({}, "lora", False), + ({}, "prefix_caching", True), + ({"enable_lora": True}, "lora", True), + ( + { + "kv_cache_config": llm_args.KvCacheConfig( + free_gpu_memory_fraction=0.4, 
enable_block_reuse=False + ) + }, + "prefix_caching", + False, + ), + ({"enable_chunked_prefill": True}, "chunked_context", True), + ], + ids=[ + "default-lora", + "default-prefix", + "lora-enabled", + "prefix-disabled", + "chunked-enabled", + ], + ) + def test_feature_detection_pytorch(self, extra_kwargs, key, expected): + """Feature extraction should match real PyTorch backend llm_args values.""" + import json + + captured, spy = _make_spy() + kwargs = {"model": self.model_path, "kv_cache_config": _kv_cache_config} + kwargs.update(extra_kwargs) + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch(**kwargs) as _: + pass + + features = json.loads(usage_lib._collect_features(captured["llm_args"])) + assert features[key] == expected + + +# --------------------------------------------------------------------------- +# Cycle 8: Eval and Bench CLI context tests (COVERAGE GAP) +# --------------------------------------------------------------------------- + + +class TestTelemetryEvalContext: + """Verify UsageContext.CLI_EVAL flows through TelemetryConfig.""" + + def test_eval_sets_cli_eval_context(self): + """eval.py sets UsageContext.CLI_EVAL in TelemetryConfig.""" + from tensorrt_llm.llmapi import llm_args as _llm_args_mod + + config = _llm_args_mod.TelemetryConfig(usage_context=_llm_args_mod.UsageContext.CLI_EVAL) + assert config.usage_context == _llm_args_mod.UsageContext.CLI_EVAL + + +class TestTelemetryBenchContext: + """Verify UsageContext.CLI_BENCH flows through TelemetryConfig.""" + + def test_bench_sets_cli_bench_context(self): + """bench.py sets UsageContext.CLI_BENCH in TelemetryConfig.""" + from tensorrt_llm.llmapi import llm_args as _llm_args_mod + + config = _llm_args_mod.TelemetryConfig(usage_context=_llm_args_mod.UsageContext.CLI_BENCH) + assert config.usage_context == _llm_args_mod.UsageContext.CLI_BENCH diff --git a/tests/unittest/llmapi/test_llm_telemetry_payload.py b/tests/unittest/llmapi/test_llm_telemetry_payload.py 
new file mode 100644 index 00000000000..c24ed6b9b1b --- /dev/null +++ b/tests/unittest/llmapi/test_llm_telemetry_payload.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""End-to-end payload verification: real model → real JSON → ground truth check. + +Loads a real model on real GPUs, captures the report_usage kwargs via a spy, +then calls _background_reporter() with those real args and verifies every +JSON parameter against ground truth values from torch.cuda, platform, etc. 
+""" + +import json +import os +import platform +import sys +import threading +from unittest.mock import patch + +import pytest + +from tensorrt_llm import LLM as LLM_torch +from tensorrt_llm.llmapi import KvCacheConfig +from tensorrt_llm.usage import schemas + +sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") +from utils.llm_data import llm_models_root # noqa: E402 + +pytestmark = pytest.mark.threadleak(enabled=False) + +MODEL_NAME = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" +_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) + + +def _get_model_path(): + root = llm_models_root() + assert root is not None, ( + "LLM_MODELS_ROOT must be set or /home/scratch.trt_llm_data must be " + "accessible to run payload verification tests" + ) + return str(root / MODEL_NAME) + + +def _make_spy(): + captured = {} + + def spy_report_usage(**kwargs): + captured.update(kwargs) + + return captured, spy_report_usage + + +class TestPayloadVerification: + """Verify actual JSON payload parameters against ground truth on real GPUs.""" + + @pytest.fixture(autouse=True) + def _setup(self): + self.model_path = _get_model_path() + + def test_payload_parameters_match_ground_truth(self): + """Load real model, build payload, verify every parameter is accurate.""" + import torch + + import tensorrt_llm.usage.usage_lib as usage_lib + + # Step 1: Load real model, capture report_usage kwargs + captured, spy = _make_spy() + + with patch("tensorrt_llm.usage.report_usage", side_effect=spy): + with LLM_torch(model=self.model_path, kv_cache_config=_kv_cache_config) as _: + pass + + llm_args = captured.get("llm_args") + pretrained_config = captured.get("pretrained_config") + telemetry_config = captured.get("telemetry_config") + + assert llm_args is not None, "report_usage was not called with llm_args" + assert pretrained_config is not None, "report_usage was not called with pretrained_config" + + # Extract usage_context the same way report_usage does + usage_context = "" + 
if telemetry_config is not None: + ctx = getattr(telemetry_config, "usage_context", None) + if ctx is not None: + usage_context = ctx.value if hasattr(ctx, "value") else str(ctx) + + # Step 2: Call _background_reporter with real args, capture payload + captured_payloads = [] + + def capture_send(payload): + captured_payloads.append(json.loads(json.dumps(payload))) + + stop = threading.Event() + stop.set() + + with ( + patch.object(usage_lib, "_send_to_gxt", side_effect=capture_send), + patch.object(usage_lib, "_REPORTER_STOP", stop), + ): + usage_lib._background_reporter(llm_args, pretrained_config, usage_context) + + assert captured_payloads, "No payloads captured from _background_reporter" + + payload = captured_payloads[0] + params = payload["events"][0]["parameters"] + event_name = payload["events"][0]["name"] + + assert event_name == "trtllm_initial_report" + + # Step 3: Spot-check against ground truth + assert params["gpuName"] == torch.cuda.get_device_name(0) + assert params["gpuCount"] == torch.cuda.device_count() + assert params["gpuMemoryMB"] == torch.cuda.get_device_properties(0).total_memory // ( + 1024 * 1024 + ) + assert params["platform"] == platform.platform() + assert params["pythonVersion"] == platform.python_version() + assert params["cpuArchitecture"] == platform.machine() + assert params["cpuCount"] == os.cpu_count() + assert params["cudaVersion"] == torch.version.cuda + assert params["architectureClassName"] == "LlamaForCausalLM" + assert params["backend"] == "pytorch" + + # Step 4: String length checks (ShortString<=128, LongString<=256) + short_fields = [ + "trtllmVersion", + "pythonVersion", + "cpuArchitecture", + "cudaVersion", + "cloudProvider", + "backend", + "dtype", + "quantizationAlgo", + "kvCacheDtype", + "ingressPoint", + "disaggRole", + "deploymentId", + ] + long_fields = ["platform", "gpuName", "architectureClassName"] + + for f in short_fields: + v = params.get(f, "") + assert len(v) <= 128, f"{f} len={len(v)} exceeds ShortString 
max 128" + + for f in long_fields: + v = params.get(f, "") + assert len(v) <= 256, f"{f} len={len(v)} exceeds LongString max 256" + + # Step 5: Integer range checks (0 <= x <= 4294967295) + int_fields = [ + "cpuCount", + "gpuCount", + "gpuMemoryMB", + "tensorParallelSize", + "pipelineParallelSize", + "contextParallelSize", + "moeExpertParallelSize", + "moeTensorParallelSize", + ] + for f in int_fields: + v = params.get(f) + assert isinstance(v, int), f"{f} is not int: {type(v)}" + assert 0 <= v <= 4294967295, f"{f}={v} out of PositiveInt range" + + # Step 6: featuresJson check + fj = params.get("featuresJson", "") + features = json.loads(fj) + expected_keys = { + "lora", + "speculative_decoding", + "prefix_caching", + "cuda_graphs", + "chunked_context", + "data_parallel_size", + } + assert set(features.keys()) == expected_keys + + # Step 7: Full jsonschema validation + import jsonschema + + schema = json.loads(schemas.SMS_SCHEMA_PATH.read_text()) + initial_schema = schema["definitions"]["events"]["trtllm_initial_report"].copy() + initial_schema["definitions"] = schema["definitions"] + jsonschema.validate(instance=params, schema=initial_schema) diff --git a/tests/unittest/usage/__init__.py b/tests/unittest/usage/__init__.py new file mode 100644 index 00000000000..1f175b4aca9 --- /dev/null +++ b/tests/unittest/usage/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittest/usage/conftest.py b/tests/unittest/usage/conftest.py new file mode 100644 index 00000000000..e1ecad5231e --- /dev/null +++ b/tests/unittest/usage/conftest.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conftest for usage tests. + +Patches the tensorrt_llm top-level __init__.py import chain so that +the usage subpackage can be tested in isolation (without GPU libs, +nvtx, MPI, etc.). + +This works by pre-populating sys.modules with a stub 'tensorrt_llm' +package that only contains the usage subpackage, before any test +imports trigger the full (heavy) __init__.py. +""" + +import sys +import types +from pathlib import Path + +import pytest + +# Locate the repo root so we can import from tensorrt_llm/usage/ directly +_REPO_ROOT = Path(__file__).resolve().parents[3] +_TRTLLM_PKG = _REPO_ROOT / "tensorrt_llm" + + +def _create_stub_tensorrt_llm(): + """Create a minimal stub tensorrt_llm package in sys.modules. + + This prevents the real __init__.py (which imports torch, nvtx, etc.) + from executing, while still allowing `from tensorrt_llm.usage import ...` + to resolve correctly. + """ + if "tensorrt_llm" in sys.modules: + # Already loaded (e.g. 
in a full TRT-LLM environment) -- skip + return + + # Create stub package module + stub = types.ModuleType("tensorrt_llm") + stub.__path__ = [str(_TRTLLM_PKG)] + stub.__file__ = str(_TRTLLM_PKG / "__init__.py") + stub.__package__ = "tensorrt_llm" + stub.__version__ = "0.0.0-test" + sys.modules["tensorrt_llm"] = stub + + +# Run before any test collection +_create_stub_tensorrt_llm() + + +def pytest_addoption(parser): + """Register --run-staging CLI flag for live endpoint tests.""" + parser.addoption( + "--run-staging", + action="store_true", + default=False, + help="Run live tests against the GXT staging endpoint.", + ) + + +def pytest_configure(config): + """Register custom markers.""" + config.addinivalue_line( + "markers", "staging: live tests against GXT staging endpoint (require --run-staging)" + ) + + +@pytest.fixture +def enable_telemetry(monkeypatch): + """Clear all opt-out env vars and force-enable for clean telemetry testing. + + Uses TRTLLM_USAGE_FORCE_ENABLED=1 to override CI/test auto-detection + (PYTEST_CURRENT_TEST is set by pytest after fixtures run, so delenv + is not sufficient). + """ + monkeypatch.delenv("TRTLLM_NO_USAGE_STATS", raising=False) + monkeypatch.delenv("DO_NOT_TRACK", raising=False) + monkeypatch.delenv("TELEMETRY_DISABLED", raising=False) + monkeypatch.setenv("TRTLLM_USAGE_FORCE_ENABLED", "1") + monkeypatch.setattr( + "tensorrt_llm.usage.usage_lib._OPT_OUT_FILE", + Path("/nonexistent/path/do_not_track"), + ) diff --git a/tests/unittest/usage/test_collectors.py b/tests/unittest/usage/test_collectors.py new file mode 100644 index 00000000000..294570c07b0 --- /dev/null +++ b/tests/unittest/usage/test_collectors.py @@ -0,0 +1,562 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for data collection functions: system info, GPU, model, config, features.""" + +import json +from unittest.mock import MagicMock, patch + +from tensorrt_llm.usage import schema, usage_lib + +# --------------------------------------------------------------------------- +# System info tests +# --------------------------------------------------------------------------- + + +class TestSystemInfo: + def test_collect_system_info(self): + """System info returns expected keys with valid types.""" + info = usage_lib._collect_system_info() + assert "platform" in info + assert isinstance(info["platform"], str) + assert "python_version" in info + assert isinstance(info["python_version"], str) + assert "cpu_architecture" in info + assert isinstance(info["cpu_architecture"], str) + assert "cpu_count" in info + assert isinstance(info["cpu_count"], int) + assert info["cpu_count"] > 0 + + def test_collect_system_info_handles_none_cpu_count(self): + """cpu_count can be None (e.g. 
some container/embedded envs).""" + with patch("os.cpu_count", return_value=None): + info = usage_lib._collect_system_info() + assert info["cpu_count"] is None + + def test_collect_gpu_info_no_torch(self): + """GPU info returns empty dict when torch is unavailable.""" + with patch.dict("sys.modules", {"torch": None}): + result = usage_lib._collect_gpu_info() + assert result == {} + + def test_collect_gpu_info_no_cuda(self): + """GPU info returns empty dict when CUDA is unavailable.""" + mock_torch = MagicMock() + mock_torch.cuda.is_available.return_value = False + with patch.dict("sys.modules", {"torch": mock_torch, "torch.cuda": mock_torch.cuda}): + result = usage_lib._collect_gpu_info() + assert result == {} + + def test_collect_gpu_info_with_cuda(self): + """GPU info returns populated dict when CUDA is available.""" + mock_torch = MagicMock() + mock_torch.cuda.is_available.return_value = True + mock_torch.cuda.device_count.return_value = 8 + mock_torch.cuda.get_device_name.return_value = "NVIDIA H100" + mock_props = MagicMock() + mock_props.total_memory = 80 * 1024 * 1024 * 1024 # 80 GB in bytes + mock_torch.cuda.get_device_properties.return_value = mock_props + mock_torch.version.cuda = "12.4" + with patch.dict( + "sys.modules", + { + "torch": mock_torch, + "torch.cuda": mock_torch.cuda, + "torch.version": mock_torch.version, + }, + ): + result = usage_lib._collect_gpu_info() + assert result["gpu_count"] == 8 + assert result["gpu_name"] == "NVIDIA H100" + assert result["gpu_memory_mb"] == 80 * 1024 # 80 GB in MB + assert result["cuda_version"] == "12.4" + + def test_collect_gpu_info_catches_import_error(self): + """_collect_gpu_info returns {} when torch is not installed.""" + with patch.dict("sys.modules", {"torch": None}): + assert usage_lib._collect_gpu_info() == {} + + def test_collect_gpu_info_catches_runtime_error(self): + """_collect_gpu_info returns {} when CUDA is broken.""" + mock_torch = MagicMock() + mock_torch.cuda.is_available.side_effect = 
RuntimeError("CUDA error") + with patch.dict( + "sys.modules", + {"torch": mock_torch, "torch.cuda": mock_torch.cuda}, + ): + assert usage_lib._collect_gpu_info() == {} + + +# --------------------------------------------------------------------------- +# Model info tests +# --------------------------------------------------------------------------- + + +class TestModelInfo: + def test_extract_architecture_class_name(self): + """Extracts first architecture from config.architectures list.""" + mock = MagicMock() + mock.architectures = ["LlamaForCausalLM"] + assert usage_lib._extract_architecture_class_name(mock) == "LlamaForCausalLM" + + def test_extract_architecture_multiple(self): + """Extracts first architecture when list has multiple entries.""" + mock = MagicMock() + mock.architectures = ["Qwen2ForCausalLM", "GPT2LMHeadModel"] + assert usage_lib._extract_architecture_class_name(mock) == "Qwen2ForCausalLM" + + def test_extract_architecture_none_config(self): + """Returns None when config is None.""" + assert usage_lib._extract_architecture_class_name(None) is None + + def test_extract_architecture_empty_list(self): + """Falls back to class name when architectures list is empty.""" + mock = MagicMock(spec=[]) # No attributes by default + mock.architectures = [] + result = usage_lib._extract_architecture_class_name(mock) + assert result is not None # Should return the class name + + def test_extract_architecture_no_attr(self): + """Falls back to class name when architecture attr is missing.""" + + class FakeConfig: + pass + + config = FakeConfig() + result = usage_lib._extract_architecture_class_name(config) + assert result == "FakeConfig" + + def test_extract_architecture_trtllm_singular(self): + """TRT-LLM PretrainedConfig uses .architecture (singular string).""" + mock = MagicMock(spec=[]) + mock.architecture = "LlamaForCausalLM" + assert usage_lib._extract_architecture_class_name(mock) == "LlamaForCausalLM" + + def 
test_extract_architecture_engine_config_nested(self): + """HF from_pretrained on engine dir produces nested pretrained_config dict.""" + + class FakeEngineConfig: + pretrained_config = { + "architecture": "LlamaForCausalLM", + "dtype": "float16", + "hidden_size": 2048, + } + build_config = {"max_batch_size": 8} + + config = FakeEngineConfig() + assert usage_lib._extract_architecture_class_name(config) == "LlamaForCausalLM" + + def test_extract_architecture_hf_takes_priority(self): + """HF .architectures (plural) takes priority over TRT-LLM .architecture.""" + mock = MagicMock(spec=[]) + mock.architectures = ["LlamaForCausalLM"] + mock.architecture = "ShouldNotBeUsed" + assert usage_lib._extract_architecture_class_name(mock) == "LlamaForCausalLM" + + def test_extract_architecture_singular_over_nested(self): + """Direct .architecture (singular) takes priority over nested dict.""" + + class FakeConfig: + architecture = "MixtralForCausalLM" + pretrained_config = { + "architecture": "ShouldNotBeUsed", + } + + assert usage_lib._extract_architecture_class_name(FakeConfig()) == "MixtralForCausalLM" + + def test_no_raw_config_fields(self): + """Ensure PR #11299's raw config fields are NOT in the schema.""" + fields = schema.TrtllmInitialReport.model_fields + assert "num_layers" not in fields + assert "hidden_size" not in fields + assert "num_attention_heads" not in fields + assert "model_type" not in fields + + def test_extract_arch_catches_attribute_error(self): + """_extract_architecture_class_name handles configs without expected attrs.""" + result = usage_lib._extract_architecture_class_name(42) # int has no .architectures + assert result is not None # falls through to type(42).__name__ == "int" + + +# --------------------------------------------------------------------------- +# TRT-LLM config extraction tests +# --------------------------------------------------------------------------- + + +class TestConfigExtraction: + def test_extract_config(self): + """Extracts 
all config fields from a fully-populated mock.""" + mock = MagicMock() + mock.backend = "pytorch" + mock.parallel_config.tp_size = 4 + mock.parallel_config.pp_size = 2 + mock.parallel_config.cp_size = 1 + mock.parallel_config.moe_ep_size = 8 + mock.parallel_config.moe_tp_size = 2 + mock.dtype = "float16" + mock.quant_config.quant_algo = "fp8" + mock.kv_cache_config.dtype = "auto" + + result = usage_lib._extract_trtllm_config(mock) + assert result["backend"] == "pytorch" + assert result["tensor_parallel_size"] == 4 + assert result["pipeline_parallel_size"] == 2 + assert result["context_parallel_size"] == 1 + assert result["moe_expert_parallel_size"] == 8 + assert result["moe_tensor_parallel_size"] == 2 + assert result["dtype"] == "float16" + assert result["quantization_algo"] == "fp8" + assert result["kv_cache_dtype"] == "auto" + + def test_extract_config_moe_sentinel_mapped_to_zero(self): + """MoE parallel sizes of -1 (auto) are mapped to 0 for telemetry.""" + mock = MagicMock() + mock.backend = "pytorch" + mock.parallel_config.tp_size = 4 + mock.parallel_config.pp_size = 1 + mock.parallel_config.cp_size = 1 + mock.parallel_config.moe_ep_size = -1 + mock.parallel_config.moe_tp_size = -1 + mock.dtype = "float16" + mock.quant_config.quant_algo = None + mock.kv_cache_config.dtype = None + + result = usage_lib._extract_trtllm_config(mock) + assert result["moe_expert_parallel_size"] == 0 + assert result["moe_tensor_parallel_size"] == 0 + + def test_extract_config_none(self): + """Returns empty dict when llm_args is None.""" + assert usage_lib._extract_trtllm_config(None) == {} + + def test_extract_config_partial(self): + """Extracts available fields without crashing on missing ones.""" + + class PartialArgs: + backend = "tensorrt" + + result = usage_lib._extract_trtllm_config(PartialArgs()) + assert result["backend"] == "tensorrt" + + def test_extract_config_defaults_for_missing(self): + """Missing optional configs are omitted from the result dict.""" + mock = 
MagicMock() + mock.backend = "pytorch" + mock.parallel_config.tp_size = 4 + mock.parallel_config.pp_size = 1 + mock.parallel_config.cp_size = 1 + mock.parallel_config.moe_ep_size = None + mock.parallel_config.moe_tp_size = None + mock.dtype = None + mock.quant_config.quant_algo = None + mock.kv_cache_config.dtype = None + + result = usage_lib._extract_trtllm_config(mock) + assert "moe_expert_parallel_size" not in result + assert "moe_tensor_parallel_size" not in result + assert "dtype" not in result + assert "quantization_algo" not in result + assert "kv_cache_dtype" not in result + + def test_extract_config_infers_backend_from_class_name(self): + """Backend inferred as 'tensorrt' when backend missing and class name contains 'TrtLlm'.""" + + class TrtLlmArgsLike: + pass # no backend attr -> triggers cls_name inference + + result = usage_lib._extract_trtllm_config(TrtLlmArgsLike()) + assert result.get("backend") == "tensorrt" + + def test_extract_config_no_backend_no_trtllm_in_name(self): + """Backend omitted when class name does not contain 'TrtLlm'.""" + + class GenericArgs: + pass + + result = usage_lib._extract_trtllm_config(GenericArgs()) + assert "backend" not in result + + +# --------------------------------------------------------------------------- +# _clamp_str truncation helper tests +# --------------------------------------------------------------------------- + + +class TestClampStr: + """Tests for _clamp_str truncation helper.""" + + def test_truncates_long_value(self): + """Strings exceeding max_len are truncated.""" + assert usage_lib._clamp_str("a" * 200, 128) == "a" * 128 + + def test_preserves_short_value(self): + """Strings shorter than max_len are unchanged.""" + assert usage_lib._clamp_str("short", 128) == "short" + + def test_exact_boundary_not_truncated(self): + """String exactly at max_len is not truncated.""" + assert usage_lib._clamp_str("a" * 128, 128) == "a" * 128 + + def test_one_over_boundary_truncated(self): + """String one char over 
max_len is truncated.""" + assert usage_lib._clamp_str("a" * 129, 128) == "a" * 128 + + +# --------------------------------------------------------------------------- +# Feature flag extraction tests +# --------------------------------------------------------------------------- + + +class TestFeatureExtraction: + """Tests for _collect_features() -- extracting feature flags from llm_args.""" + + # --- All keys always present --- + + def test_all_keys_always_present(self): + """Every expected key is present regardless of config values.""" + mock = MagicMock() + result = json.loads(usage_lib._collect_features(mock)) + expected_keys = set(usage_lib._FEATURES_DEFAULTS.keys()) + assert set(result.keys()) == expected_keys + + # --- Default / None scenarios --- + + def test_all_features_disabled_defaults(self): + """All features default to false/1 when llm_args has no feature configs.""" + mock = MagicMock(spec=[]) # No attributes + result = json.loads(usage_lib._collect_features(mock)) + assert result == { + "lora": False, + "speculative_decoding": False, + "prefix_caching": False, + "cuda_graphs": False, + "chunked_context": False, + "data_parallel_size": 1, + } + + def test_none_llm_args_returns_defaults(self): + """None llm_args returns all defaults.""" + result = json.loads(usage_lib._collect_features(None)) + assert result == dict(usage_lib._FEATURES_DEFAULTS) + + def test_exception_in_extraction_returns_partial(self): + """If extraction raises, returns whatever was collected (fail-silent).""" + mock = MagicMock() + mock.enable_lora = True + mock.lora_config = None + # Make speculative_config raise AttributeError (caught by narrowed handler) + type(mock).speculative_config = property( + lambda self: (_ for _ in ()).throw(AttributeError("boom")) + ) + result = json.loads(usage_lib._collect_features(mock)) + # lora should be collected before the exception + assert result["lora"] is True + # All keys must still be present + assert set(result.keys()) == 
set(usage_lib._FEATURES_DEFAULTS.keys()) + + # --- LoRA --- + + def test_lora_enabled_via_lora_config(self): + """LoRA detected when lora_config is not None.""" + mock = MagicMock() + mock.enable_lora = False + mock.lora_config = MagicMock() # non-None + result = json.loads(usage_lib._collect_features(mock)) + assert result["lora"] is True + + def test_lora_enabled_via_enable_lora_flag(self): + """LoRA detected when enable_lora is True.""" + mock = MagicMock() + mock.enable_lora = True + mock.lora_config = None + result = json.loads(usage_lib._collect_features(mock)) + assert result["lora"] is True + + def test_lora_both_signals(self): + """LoRA detected when both enable_lora and lora_config are set.""" + mock = MagicMock() + mock.enable_lora = True + mock.lora_config = MagicMock() + result = json.loads(usage_lib._collect_features(mock)) + assert result["lora"] is True + + def test_lora_disabled(self): + """LoRA is false when neither signal is set.""" + mock = MagicMock() + mock.enable_lora = False + mock.lora_config = None + result = json.loads(usage_lib._collect_features(mock)) + assert result["lora"] is False + + # --- Speculative decoding --- + + def test_speculative_decoding_enabled(self): + """Speculative decoding detected when speculative_config is not None.""" + mock = MagicMock() + mock.speculative_config = MagicMock() + result = json.loads(usage_lib._collect_features(mock)) + assert result["speculative_decoding"] is True + + def test_speculative_decoding_none(self): + """Speculative decoding is false when speculative_config is None.""" + mock = MagicMock() + mock.speculative_config = None + result = json.loads(usage_lib._collect_features(mock)) + assert result["speculative_decoding"] is False + + # --- Prefix caching --- + + def test_prefix_caching_enabled(self): + """Prefix caching detected when enable_block_reuse is True.""" + mock = MagicMock() + mock.kv_cache_config.enable_block_reuse = True + result = json.loads(usage_lib._collect_features(mock)) + 
assert result["prefix_caching"] is True + + def test_prefix_caching_disabled(self): + """Prefix caching is false when enable_block_reuse is False.""" + mock = MagicMock() + mock.kv_cache_config.enable_block_reuse = False + result = json.loads(usage_lib._collect_features(mock)) + assert result["prefix_caching"] is False + + def test_prefix_caching_no_kv_config(self): + """Prefix caching defaults to false when kv_cache_config is None.""" + mock = MagicMock() + mock.kv_cache_config = None + result = json.loads(usage_lib._collect_features(mock)) + assert result["prefix_caching"] is False + + # --- CUDA graphs --- + + def test_cuda_graphs_pytorch_backend(self): + """CUDA graphs detected via cuda_graph_config (PyTorch backend).""" + mock = MagicMock() + mock.cuda_graph_config = MagicMock() # non-None = enabled + mock.extended_runtime_perf_knob_config = None + result = json.loads(usage_lib._collect_features(mock)) + assert result["cuda_graphs"] is True + + def test_cuda_graphs_pytorch_disabled(self): + """CUDA graphs false when cuda_graph_config is None (PyTorch).""" + mock = MagicMock() + mock.cuda_graph_config = None + mock.extended_runtime_perf_knob_config = None + result = json.loads(usage_lib._collect_features(mock)) + assert result["cuda_graphs"] is False + + def test_cuda_graphs_trt_backend(self): + """CUDA graphs detected via extended_runtime_perf_knob_config (TRT).""" + mock = MagicMock() + mock.cuda_graph_config = None + mock.extended_runtime_perf_knob_config.cuda_graph_mode = True + result = json.loads(usage_lib._collect_features(mock)) + assert result["cuda_graphs"] is True + + def test_cuda_graphs_trt_disabled(self): + """CUDA graphs false when cuda_graph_mode is False (TRT).""" + mock = MagicMock() + mock.cuda_graph_config = None + mock.extended_runtime_perf_knob_config.cuda_graph_mode = False + result = json.loads(usage_lib._collect_features(mock)) + assert result["cuda_graphs"] is False + + def test_cuda_graphs_no_config_either_backend(self): + """CUDA 
graphs false when neither backend config is present.""" + mock = MagicMock() + mock.cuda_graph_config = None + mock.extended_runtime_perf_knob_config = None + result = json.loads(usage_lib._collect_features(mock)) + assert result["cuda_graphs"] is False + + # --- Chunked context --- + + def test_chunked_context_enabled(self): + """Chunked context detected when enable_chunked_prefill is True.""" + mock = MagicMock() + mock.enable_chunked_prefill = True + result = json.loads(usage_lib._collect_features(mock)) + assert result["chunked_context"] is True + + def test_chunked_context_disabled(self): + """Chunked context is false when enable_chunked_prefill is False.""" + mock = MagicMock() + mock.enable_chunked_prefill = False + result = json.loads(usage_lib._collect_features(mock)) + assert result["chunked_context"] is False + + # --- Data parallel size --- + + def test_data_parallel_size_with_attention_dp(self): + """dp_size = tp_size when enable_attention_dp is True.""" + mock = MagicMock() + mock.parallel_config.enable_attention_dp = True + mock.parallel_config.tp_size = 4 + result = json.loads(usage_lib._collect_features(mock)) + assert result["data_parallel_size"] == 4 + + def test_data_parallel_size_without_attention_dp(self): + """dp_size = 1 when enable_attention_dp is False.""" + mock = MagicMock() + mock.parallel_config.enable_attention_dp = False + mock.parallel_config.tp_size = 4 + result = json.loads(usage_lib._collect_features(mock)) + assert result["data_parallel_size"] == 1 + + def test_data_parallel_size_no_parallel_config(self): + """dp_size defaults to 1 when parallel_config is None.""" + mock = MagicMock() + mock.parallel_config = None + result = json.loads(usage_lib._collect_features(mock)) + assert result["data_parallel_size"] == 1 + + # --- All features enabled --- + + def test_all_features_enabled(self): + """All features active simultaneously.""" + mock = MagicMock() + mock.enable_lora = True + mock.lora_config = MagicMock() + 
mock.speculative_config = MagicMock() + mock.kv_cache_config.enable_block_reuse = True + mock.cuda_graph_config = MagicMock() + mock.extended_runtime_perf_knob_config = None + mock.enable_chunked_prefill = True + mock.parallel_config.enable_attention_dp = True + mock.parallel_config.tp_size = 8 + result = json.loads(usage_lib._collect_features(mock)) + assert result == { + "lora": True, + "speculative_decoding": True, + "prefix_caching": True, + "cuda_graphs": True, + "chunked_context": True, + "data_parallel_size": 8, + } + + # --- Output format --- + + def test_output_is_valid_json(self): + """Output is valid JSON.""" + mock = MagicMock() + result_str = usage_lib._collect_features(mock) + parsed = json.loads(result_str) + assert isinstance(parsed, dict) + + def test_output_uses_compact_separators(self): + """Output uses compact separators (no spaces after : or ,).""" + mock = MagicMock() + result_str = usage_lib._collect_features(mock) + assert ": " not in result_str + assert ", " not in result_str diff --git a/tests/unittest/usage/test_config.py b/tests/unittest/usage/test_config.py new file mode 100644 index 00000000000..03092e244ee --- /dev/null +++ b/tests/unittest/usage/test_config.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for tensorrt_llm.usage.config -- canonical location for telemetry types.""" + +import pytest + + +class TestTelemetryConfigLocation: + """Verify TelemetryConfig and UsageContext live in tensorrt_llm.usage.config.""" + + def test_import_telemetry_config_from_usage_config(self): + """TelemetryConfig must be importable from tensorrt_llm.usage.config.""" + from tensorrt_llm.usage import config + + assert hasattr(config, "TelemetryConfig") + + def test_import_usage_context_from_usage_config(self): + """UsageContext must be importable from tensorrt_llm.usage.config.""" + from tensorrt_llm.usage import config + + assert hasattr(config, "UsageContext") + + def test_telemetry_config_defaults(self): + """TelemetryConfig defaults: disabled=False, usage_context=UNKNOWN.""" + from tensorrt_llm.usage import config + + tc = config.TelemetryConfig() + assert tc.disabled is False + assert tc.usage_context == config.UsageContext.UNKNOWN + + def test_usage_context_values(self): + """UsageContext enum has all expected members.""" + from tensorrt_llm.usage import config + + expected = {"UNKNOWN", "LLM_CLASS", "CLI_SERVE", "CLI_BENCH", "CLI_EVAL"} + actual = {e.name for e in config.UsageContext} + assert expected == actual + + def test_usage_context_string_values(self): + """UsageContext members have correct string values.""" + from tensorrt_llm.usage import config + + assert config.UsageContext.UNKNOWN.value == "unknown" + assert config.UsageContext.LLM_CLASS.value == "llm_class" + assert config.UsageContext.CLI_SERVE.value == "cli_serve" + assert config.UsageContext.CLI_BENCH.value == "cli_bench" + assert config.UsageContext.CLI_EVAL.value == "cli_eval" + + def test_telemetry_config_disabled_flag(self): + """TelemetryConfig(disabled=True) sets the flag.""" + from tensorrt_llm.usage import config + + tc = config.TelemetryConfig(disabled=True) + assert tc.disabled is True + + def test_telemetry_config_with_context(self): + """TelemetryConfig accepts usage_context parameter.""" 
+ from tensorrt_llm.usage import config + + tc = config.TelemetryConfig(usage_context=config.UsageContext.CLI_SERVE) + assert tc.usage_context == config.UsageContext.CLI_SERVE + + def test_telemetry_config_rejects_extra_fields(self): + """TelemetryConfig(extra='forbid') raises ValidationError on unknown fields.""" + from pydantic import ValidationError + + from tensorrt_llm.usage import config + + with pytest.raises(ValidationError): + config.TelemetryConfig(unknown_field=1) + + +class TestBackwardCompatibility: + """Verify types are still importable from llm_args for backward compat.""" + + def test_telemetry_config_importable_from_llm_args(self): + """TelemetryConfig must still be importable from llm_args.""" + from tensorrt_llm.llmapi import llm_args + + assert hasattr(llm_args, "TelemetryConfig") + + def test_usage_context_importable_from_llm_args(self): + """UsageContext must still be importable from llm_args.""" + from tensorrt_llm.llmapi import llm_args + + assert hasattr(llm_args, "UsageContext") + + def test_same_types_both_locations(self): + """Types from both locations must be the same class.""" + from tensorrt_llm.llmapi import llm_args + from tensorrt_llm.usage import config + + assert config.TelemetryConfig is llm_args.TelemetryConfig + assert config.UsageContext is llm_args.UsageContext diff --git a/tests/unittest/usage/test_e2e_capture.py b/tests/unittest/usage/test_e2e_capture.py new file mode 100644 index 00000000000..0f62e060f67 --- /dev/null +++ b/tests/unittest/usage/test_e2e_capture.py @@ -0,0 +1,261 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""End-to-end telemetry capture test. + +Verifies the full data flow: LLM.__init__() → report_usage() → +_background_reporter() → _send_to_gxt() → HTTP POST with valid JSON. + +Uses a local HTTP capture server to intercept the telemetry payload without +hitting any external endpoint. + +Requirements: + - GPU (loads TinyLlama via PyTorch backend) + - LLM_MODELS_ROOT set (or /home/scratch.trt_llm_data accessible) + - Must be run with TRTLLM_USAGE_FORCE_ENABLED=1 to bypass pytest + auto-detection (conftest or env) + +Usage: + TRTLLM_USAGE_FORCE_ENABLED=1 LLM_MODELS_ROOT=/home/scratch.trt_llm_data/llm-models \ + python -m pytest tests/unittest/usage/test_e2e_capture.py -v -s +""" + +import json +import os +import threading +from http.server import BaseHTTPRequestHandler, HTTPServer +from pathlib import Path + +import pytest + +# --------------------------------------------------------------------------- +# Model path resolution (same pattern as test_llm_telemetry.py) +# --------------------------------------------------------------------------- + +MODEL_NAME = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" + + +def _get_model_path(): + """Resolve TinyLlama model path from LLM_MODELS_ROOT.""" + root = os.environ.get("LLM_MODELS_ROOT") + if root is None: + # Fallback to standard scratch path + fallback = Path("/home/scratch.trt_llm_data/llm-models") + if fallback.exists(): + root = str(fallback) + if root is None: + pytest.skip("LLM_MODELS_ROOT not set and fallback path not available") + model_path = Path(root) / MODEL_NAME + if not model_path.exists(): + 
pytest.skip(f"Model not found at {model_path}") + return str(model_path) + + +# --------------------------------------------------------------------------- +# Local HTTP capture server +# --------------------------------------------------------------------------- + + +class CaptureHandler(BaseHTTPRequestHandler): + """HTTP handler that captures POST bodies.""" + + captured_payloads = [] + capture_event = threading.Event() + + def do_POST(self): + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) + try: + payload = json.loads(body) + except json.JSONDecodeError: + payload = {"_raw": body.decode("utf-8", errors="replace")} + + CaptureHandler.captured_payloads.append(payload) + CaptureHandler.capture_event.set() + + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(b'{"status": "ok"}') + + def log_message(self, format, *args): + """Suppress request logging to keep test output clean.""" + pass + + +@pytest.fixture +def capture_server(): + """Start a local HTTP server on a free port and yield its URL.""" + # Reset state from any previous test + CaptureHandler.captured_payloads = [] + CaptureHandler.capture_event = threading.Event() + + server = HTTPServer(("127.0.0.1", 0), CaptureHandler) + port = server.server_address[1] + url = f"http://127.0.0.1:{port}/events" + + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + yield url + + server.shutdown() + + +# --------------------------------------------------------------------------- +# E2E test +# --------------------------------------------------------------------------- + + +pytestmark = pytest.mark.threadleak(enabled=False) + + +@pytest.mark.skipif( + not os.environ.get("TRTLLM_USAGE_FORCE_ENABLED"), + reason="Set TRTLLM_USAGE_FORCE_ENABLED=1 to run e2e telemetry tests", +) +class TestE2ECapture: + """End-to-end telemetry capture: real model → real HTTP POST → 
validate JSON.""" + + def test_initial_report_captured(self, capture_server, monkeypatch): + """Load TinyLlama and verify the initial telemetry report arrives.""" + import tensorrt_llm.usage.usage_lib as usage_lib + + # Bypass endpoint validation for local capture server + monkeypatch.setattr(usage_lib, "_get_stats_server", lambda: capture_server) + monkeypatch.setenv("TRTLLM_USAGE_FORCE_ENABLED", "1") + # The parent conftest (tests/unittest/conftest.py) sets + # TRTLLM_NO_USAGE_STATS=1 to prevent telemetry during normal tests. + # We must clear it for e2e verification. + monkeypatch.delenv("TRTLLM_NO_USAGE_STATS", raising=False) + + # Reset the global reporter guard so we can trigger a fresh report + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + + model_path = _get_model_path() + + from tensorrt_llm import LLM as LLM_torch + from tensorrt_llm.llmapi import KvCacheConfig + + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) + + with LLM_torch(model=model_path, kv_cache_config=kv_cache_config) as _: + # Wait for the background thread to POST the initial report + received = CaptureHandler.capture_event.wait(timeout=30) + assert received, ( + "Timed out waiting for telemetry POST. The background reporter may not have fired." 
+ ) + + # --- Validate the captured payload --- + assert len(CaptureHandler.captured_payloads) >= 1, "Expected at least 1 captured payload" + payload = CaptureHandler.captured_payloads[0] + + # GXT envelope fields + assert "clientId" in payload + assert "eventProtocol" in payload + assert payload["eventProtocol"] == "1.6" + assert "sessionId" in payload + assert "sentTs" in payload + assert "events" in payload + assert len(payload["events"]) == 1 + + event = payload["events"][0] + assert event["name"] == "trtllm_initial_report" + assert "ts" in event + assert "parameters" in event + + params = event["parameters"] + + # TRT-LLM version + assert "trtllmVersion" in params + assert isinstance(params["trtllmVersion"], str) + + # System info + assert "platform" in params + assert "pythonVersion" in params + assert "cpuArchitecture" in params + assert "cpuCount" in params + assert params["cpuCount"] > 0 + + # GPU info (we require a GPU for this test) + assert "gpuCount" in params + assert params["gpuCount"] > 0 + assert "gpuName" in params + assert len(params["gpuName"]) > 0 + assert "gpuMemoryMB" in params + assert params["gpuMemoryMB"] > 0 + assert "cudaVersion" in params + + # Model architecture + assert params["architectureClassName"] == "LlamaForCausalLM" + + # Backend + assert params["backend"] == "pytorch" + + # Parallelism defaults for single-GPU + assert params["tensorParallelSize"] == 1 + assert params["pipelineParallelSize"] == 1 + + # Ingress point (default Python API → promoted to llm_class) + assert params["ingressPoint"] == "llm_class" + + # Features JSON + assert "featuresJson" in params + features = json.loads(params["featuresJson"]) + expected_keys = { + "lora", + "speculative_decoding", + "prefix_caching", + "cuda_graphs", + "chunked_context", + "data_parallel_size", + } + assert set(features.keys()) == expected_keys + + # Schema version + assert payload["eventSchemaVer"] == "0.1" + + # Disagg fields present (may be empty strings) + assert "disaggRole" 
in params + assert "deploymentId" in params + + def test_cli_serve_context_e2e(self, capture_server, monkeypatch): + """Verify CLI_SERVE context flows through to the captured payload.""" + import tensorrt_llm.usage.usage_lib as usage_lib + + # Bypass endpoint validation for local capture server + monkeypatch.setattr(usage_lib, "_get_stats_server", lambda: capture_server) + monkeypatch.setenv("TRTLLM_USAGE_FORCE_ENABLED", "1") + monkeypatch.delenv("TRTLLM_NO_USAGE_STATS", raising=False) + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + + model_path = _get_model_path() + + from tensorrt_llm import LLM as LLM_torch + from tensorrt_llm.llmapi import KvCacheConfig + from tensorrt_llm.usage import TelemetryConfig, UsageContext + + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) + + with LLM_torch( + model=model_path, + kv_cache_config=kv_cache_config, + telemetry_config=TelemetryConfig(usage_context=UsageContext.CLI_SERVE), + ) as _: + received = CaptureHandler.capture_event.wait(timeout=30) + assert received, "Timed out waiting for telemetry POST" + + payload = CaptureHandler.captured_payloads[0] + params = payload["events"][0]["parameters"] + assert params["ingressPoint"] == "cli_serve" diff --git a/tests/unittest/usage/test_opt_out.py b/tests/unittest/usage/test_opt_out.py new file mode 100644 index 00000000000..f88fe8126c5 --- /dev/null +++ b/tests/unittest/usage/test_opt_out.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for telemetry opt-out mechanisms and CI/test auto-detection.""" + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from tensorrt_llm.usage import usage_lib + +# --------------------------------------------------------------------------- +# Opt-out tests +# --------------------------------------------------------------------------- + + +class TestOptOut: + def test_opt_out_env_var(self, monkeypatch): + """TRTLLM_NO_USAGE_STATS=1 disables telemetry.""" + monkeypatch.setenv("TRTLLM_NO_USAGE_STATS", "1") + monkeypatch.delenv("DO_NOT_TRACK", raising=False) + monkeypatch.setattr( + "tensorrt_llm.usage.usage_lib._OPT_OUT_FILE", + Path("/nonexistent"), + ) + assert not usage_lib.is_usage_stats_enabled() + + def test_opt_out_do_not_track(self, monkeypatch): + """DO_NOT_TRACK=1 (industry standard) disables telemetry.""" + monkeypatch.delenv("TRTLLM_NO_USAGE_STATS", raising=False) + monkeypatch.setenv("DO_NOT_TRACK", "1") + monkeypatch.setattr( + "tensorrt_llm.usage.usage_lib._OPT_OUT_FILE", + Path("/nonexistent"), + ) + assert not usage_lib.is_usage_stats_enabled() + + def test_opt_out_file(self, tmp_path, monkeypatch): + """File-based opt-out (~/.config/trtllm/do_not_track) disables telemetry.""" + opt_out = tmp_path / "do_not_track" + opt_out.touch() + monkeypatch.delenv("TRTLLM_NO_USAGE_STATS", raising=False) + monkeypatch.delenv("DO_NOT_TRACK", raising=False) + monkeypatch.setattr("tensorrt_llm.usage.usage_lib._OPT_OUT_FILE", opt_out) + assert not usage_lib.is_usage_stats_enabled() + + def test_enabled_by_default(self, 
enable_telemetry): + """Telemetry is enabled when no opt-out is configured.""" + assert usage_lib.is_usage_stats_enabled() + + def test_opt_out_telemetry_disabled_env_var_true(self, monkeypatch): + """TELEMETRY_DISABLED=true disables telemetry.""" + monkeypatch.setenv("TELEMETRY_DISABLED", "true") + assert not usage_lib.is_usage_stats_enabled() + + def test_opt_out_telemetry_disabled_env_var_one(self, monkeypatch): + """TELEMETRY_DISABLED=1 disables telemetry.""" + monkeypatch.setenv("TELEMETRY_DISABLED", "1") + assert not usage_lib.is_usage_stats_enabled() + + def test_opt_out_telemetry_disabled_env_var_case_insensitive(self, monkeypatch): + """TELEMETRY_DISABLED=True (mixed case) disables telemetry.""" + monkeypatch.setenv("TELEMETRY_DISABLED", "True") + assert not usage_lib.is_usage_stats_enabled() + + def test_opt_out_telemetry_disabled_env_var_false(self, monkeypatch, enable_telemetry): + """TELEMETRY_DISABLED=false does NOT disable telemetry.""" + monkeypatch.setenv("TELEMETRY_DISABLED", "false") + assert usage_lib.is_usage_stats_enabled() + + def test_opt_out_programmatic_flag(self): + """telemetry_disabled=True (programmatic) disables telemetry.""" + assert not usage_lib.is_usage_stats_enabled(telemetry_disabled=True) + + def test_programmatic_flag_default_false(self, enable_telemetry): + """Default telemetry_disabled=False does not disable.""" + assert usage_lib.is_usage_stats_enabled(telemetry_disabled=False) + + +# --------------------------------------------------------------------------- +# CI/Test auto-detection tests +# --------------------------------------------------------------------------- + + +class TestCIAutoDetection: + """Test automatic disabling of telemetry in CI and test environments.""" + + @pytest.fixture(autouse=True) + def _clear_all(self, monkeypatch): + """Clear all CI/test/opt-out env vars for a clean slate.""" + for var in usage_lib._CI_ENV_VARS + usage_lib._TEST_ENV_VARS: + monkeypatch.delenv(var, raising=False) + 
monkeypatch.delenv("TRTLLM_USAGE_FORCE_ENABLED", raising=False) + monkeypatch.delenv("TRTLLM_NO_USAGE_STATS", raising=False) + monkeypatch.delenv("DO_NOT_TRACK", raising=False) + monkeypatch.delenv("TELEMETRY_DISABLED", raising=False) + monkeypatch.setattr( + "tensorrt_llm.usage.usage_lib._OPT_OUT_FILE", + Path("/nonexistent/path/do_not_track"), + ) + + def test_auto_disable_ci_generic(self, monkeypatch): + """CI=true (generic CI env var) disables telemetry.""" + monkeypatch.setenv("CI", "true") + assert not usage_lib.is_usage_stats_enabled() + + def test_auto_disable_github_actions(self, monkeypatch): + """GITHUB_ACTIONS=true disables telemetry.""" + monkeypatch.setenv("GITHUB_ACTIONS", "true") + assert not usage_lib.is_usage_stats_enabled() + + def test_auto_disable_jenkins(self, monkeypatch): + """JENKINS_URL set disables telemetry.""" + monkeypatch.setenv("JENKINS_URL", "http://jenkins.example.com") + assert not usage_lib.is_usage_stats_enabled() + + def test_auto_disable_gitlab_ci(self, monkeypatch): + """GITLAB_CI=true disables telemetry.""" + monkeypatch.setenv("GITLAB_CI", "true") + assert not usage_lib.is_usage_stats_enabled() + + def test_auto_disable_buildkite(self, monkeypatch): + """BUILDKITE=true disables telemetry.""" + monkeypatch.setenv("BUILDKITE", "true") + assert not usage_lib.is_usage_stats_enabled() + + def test_auto_disable_pytest(self, monkeypatch): + """PYTEST_CURRENT_TEST set disables telemetry.""" + monkeypatch.setenv("PYTEST_CURRENT_TEST", "tests/test_example.py::test_foo") + assert not usage_lib.is_usage_stats_enabled() + + def test_ci_detection_returns_true(self, monkeypatch): + """_is_ci_or_test_environment() returns True when CI var is set.""" + monkeypatch.setenv("CI", "true") + assert usage_lib._is_ci_or_test_environment() + + def test_no_ci_detection_returns_false(self, monkeypatch): + """_is_ci_or_test_environment() returns False with no CI vars.""" + # PYTEST_CURRENT_TEST is set by pytest after fixtures run, re-clear it + 
monkeypatch.delenv("PYTEST_CURRENT_TEST", raising=False) + assert not usage_lib._is_ci_or_test_environment() + + def test_force_enable_overrides_ci_detection(self, monkeypatch): + """TRTLLM_USAGE_FORCE_ENABLED=1 re-enables telemetry in CI.""" + monkeypatch.setenv("CI", "true") + monkeypatch.setenv("TRTLLM_USAGE_FORCE_ENABLED", "1") + assert usage_lib.is_usage_stats_enabled() + + def test_force_enable_does_not_override_explicit_opt_out(self, monkeypatch): + """Explicit opt-out takes precedence over force-enable.""" + monkeypatch.setenv("CI", "true") + monkeypatch.setenv("TRTLLM_USAGE_FORCE_ENABLED", "1") + monkeypatch.setenv("TRTLLM_NO_USAGE_STATS", "1") + assert not usage_lib.is_usage_stats_enabled() + + def test_all_ci_vars_detected(self, monkeypatch): + """Every CI env var in _CI_ENV_VARS triggers detection.""" + for var in usage_lib._CI_ENV_VARS: + monkeypatch.setenv(var, "true") + assert usage_lib._is_ci_or_test_environment(), f"{var} was not detected" + + def test_empty_ci_var_not_detected(self, monkeypatch): + """Empty string CI var does NOT trigger detection.""" + # PYTEST_CURRENT_TEST is set by pytest after fixtures run, re-clear it + monkeypatch.delenv("PYTEST_CURRENT_TEST", raising=False) + monkeypatch.setenv("CI", "") + assert not usage_lib._is_ci_or_test_environment() + + def test_noop_in_ci_without_force(self, monkeypatch): + """report_usage() does not spawn thread in CI environment.""" + monkeypatch.setenv("CI", "true") + with patch("tensorrt_llm.usage.usage_lib.threading.Thread") as thread_cls: + usage_lib.report_usage() + thread_cls.assert_not_called() + + +# --------------------------------------------------------------------------- +# Path.home() failure resilience tests +# --------------------------------------------------------------------------- + + +class TestPathHomeFailure: + """Verify telemetry degrades gracefully when Path.home() fails. 
+ + In minimal containers or non-standard service accounts, HOME may be + unset and passwd lookup may fail, causing Path.home() to raise + RuntimeError. The file-based opt-out becomes unavailable, but + everything else (env-var opt-out, telemetry reporting) must still work. + """ + + def test_opt_out_file_none_does_not_crash(self, monkeypatch, enable_telemetry): + """is_usage_stats_enabled() works when _OPT_OUT_FILE is None.""" + monkeypatch.setattr("tensorrt_llm.usage.usage_lib._OPT_OUT_FILE", None) + # Should not raise; telemetry enabled since no opt-out is active + assert usage_lib.is_usage_stats_enabled() + + def test_env_var_opt_out_still_works_when_file_unavailable(self, monkeypatch): + """Env-var opt-out works even when file-based opt-out is unavailable.""" + monkeypatch.setattr("tensorrt_llm.usage.usage_lib._OPT_OUT_FILE", None) + monkeypatch.setenv("TRTLLM_NO_USAGE_STATS", "1") + assert not usage_lib.is_usage_stats_enabled() + + def test_report_usage_does_not_crash_when_file_unavailable(self, monkeypatch): + """report_usage() degrades silently when _OPT_OUT_FILE is None.""" + monkeypatch.setattr("tensorrt_llm.usage.usage_lib._OPT_OUT_FILE", None) + monkeypatch.setenv("CI", "true") + # Should not raise + usage_lib.report_usage() + + def test_module_import_survives_path_home_failure(self): + """Simulates Path.home() raising RuntimeError at module scope. + + The module-level _OPT_OUT_FILE assignment is guarded by + try/except, so the module should already be imported. This test + verifies the guard produces None (not a crash) by replicating + the exact logic. 
+ """ + from pathlib import Path + + def failing_home(): + raise RuntimeError("Could not determine home directory") + + with patch.object(Path, "home", side_effect=failing_home): + try: + result = Path.home() / ".config" / "trtllm" / "do_not_track" + except (RuntimeError, KeyError): + result = None + assert result is None diff --git a/tests/unittest/usage/test_reporter.py b/tests/unittest/usage/test_reporter.py new file mode 100644 index 00000000000..e498be3013e --- /dev/null +++ b/tests/unittest/usage/test_reporter.py @@ -0,0 +1,593 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for report_usage(), background reporter, thread lifecycle, and heartbeat.""" + +import logging +import threading +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from tensorrt_llm.usage import usage_lib + +# --------------------------------------------------------------------------- +# Console notification tests +# --------------------------------------------------------------------------- + + +class TestNotification: + def test_usage_notification_shown(self, monkeypatch, caplog, enable_telemetry): + """Notification is logged when telemetry is enabled.""" + usage_lib._NOTIFICATION_SHOWN.clear() + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + + mock_thread = MagicMock() + with patch("tensorrt_llm.usage.usage_lib.threading.Thread", return_value=mock_thread): + with caplog.at_level(logging.INFO, logger="tensorrt_llm"): + usage_lib.report_usage() + + assert "anonymous usage data" in caplog.text + + def test_usage_notification_not_shown_when_disabled(self, monkeypatch, caplog): + """Notification is NOT shown when telemetry is disabled.""" + usage_lib._NOTIFICATION_SHOWN.clear() + monkeypatch.setenv("TRTLLM_NO_USAGE_STATS", "1") + + with caplog.at_level(logging.INFO, logger="tensorrt_llm"): + usage_lib.report_usage() + + assert "anonymous usage data" not in caplog.text + + +# --------------------------------------------------------------------------- +# Thread lifecycle tests +# --------------------------------------------------------------------------- + + +class TestReportUsage: + def test_spawns_daemon_thread(self, monkeypatch, enable_telemetry): + """report_usage() spawns a daemon thread named 'trtllm-usage-stats'.""" + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + usage_lib._NOTIFICATION_SHOWN.set() + + mock_thread = MagicMock() + with patch( + "tensorrt_llm.usage.usage_lib.threading.Thread", return_value=mock_thread + ) as thread_cls: + usage_lib.report_usage() + 
thread_cls.assert_called_once() + call_kwargs = thread_cls.call_args + assert call_kwargs.kwargs["daemon"] is True + assert call_kwargs.kwargs["name"] == "trtllm-usage-stats" + mock_thread.start.assert_called_once() + + def test_noop_when_disabled(self, monkeypatch): + """report_usage() does nothing when telemetry is disabled.""" + monkeypatch.setenv("TRTLLM_NO_USAGE_STATS", "1") + with patch("tensorrt_llm.usage.usage_lib.threading.Thread") as thread_cls: + usage_lib.report_usage() + thread_cls.assert_not_called() + + def test_fail_silent(self, monkeypatch, enable_telemetry): + """report_usage() never raises, even if thread creation fails.""" + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + usage_lib._NOTIFICATION_SHOWN.set() + + with patch( + "tensorrt_llm.usage.usage_lib.threading.Thread", side_effect=RuntimeError("boom") + ): + usage_lib.report_usage() # Must not raise + + def test_report_usage_passes_args(self, monkeypatch, enable_telemetry): + """report_usage() passes llm_args and pretrained_config to thread.""" + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + usage_lib._NOTIFICATION_SHOWN.set() + + mock_args = MagicMock() + mock_config = MagicMock() + mock_thread = MagicMock() + + with patch( + "tensorrt_llm.usage.usage_lib.threading.Thread", return_value=mock_thread + ) as thread_cls: + usage_lib.report_usage( + llm_args=mock_args, + pretrained_config=mock_config, + ) + call_args = thread_cls.call_args + assert call_args.kwargs["target"].__name__ == "_background_reporter" + assert call_args.kwargs["args"] == (mock_args, mock_config, "") + + def test_report_usage_telemetry_disabled_no_thread(self, monkeypatch): + """report_usage with TelemetryConfig(disabled=True) should not start a thread.""" + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + + telemetry_config = SimpleNamespace(disabled=True) + + initial_count = threading.active_count() + usage_lib.report_usage(telemetry_config=telemetry_config) + assert 
threading.active_count() == initial_count + + def test_get_trtllm_version_returns_string(self): + """_get_trtllm_version returns a string.""" + result = usage_lib._get_trtllm_version() + assert isinstance(result, str) + + +# --------------------------------------------------------------------------- +# Duplicate reporter guard tests +# --------------------------------------------------------------------------- + + +class TestDuplicateReporterGuard: + def test_second_call_is_noop(self, monkeypatch, enable_telemetry): + """Calling report_usage() twice only spawns one thread.""" + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + usage_lib._NOTIFICATION_SHOWN.set() + + mock_thread = MagicMock() + with patch( + "tensorrt_llm.usage.usage_lib.threading.Thread", return_value=mock_thread + ) as thread_cls: + usage_lib.report_usage() + usage_lib.report_usage() # second call should be a no-op + assert thread_cls.call_count == 1 + + def test_guard_resets_on_thread_failure(self, monkeypatch, enable_telemetry): + """_REPORTER_STARTED resets if thread creation fails, allowing retry.""" + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + usage_lib._NOTIFICATION_SHOWN.set() + + # First call: thread creation fails + with patch( + "tensorrt_llm.usage.usage_lib.threading.Thread", + side_effect=RuntimeError("too many threads"), + ): + usage_lib.report_usage() # should not raise + + assert not usage_lib._REPORTER_STARTED + + # Second call: thread creation succeeds + mock_thread = MagicMock() + with patch( + "tensorrt_llm.usage.usage_lib.threading.Thread", return_value=mock_thread + ) as thread_cls: + usage_lib.report_usage() + thread_cls.assert_called_once() + mock_thread.start.assert_called_once() + + +# --------------------------------------------------------------------------- +# Heartbeat interval tests +# --------------------------------------------------------------------------- + + +class TestHeartbeatInterval: + def test_default_value(self, monkeypatch): + 
"""Default heartbeat interval is 600.""" + monkeypatch.delenv("TRTLLM_USAGE_HEARTBEAT_INTERVAL", raising=False) + assert usage_lib._get_heartbeat_interval() == 600 + + def test_custom_value(self, monkeypatch): + """Custom heartbeat interval is parsed correctly.""" + monkeypatch.setenv("TRTLLM_USAGE_HEARTBEAT_INTERVAL", "120") + assert usage_lib._get_heartbeat_interval() == 120 + + def test_invalid_value_falls_back(self, monkeypatch): + """Invalid env var falls back to 600 instead of crashing.""" + monkeypatch.setenv("TRTLLM_USAGE_HEARTBEAT_INTERVAL", "abc") + assert usage_lib._get_heartbeat_interval() == 600 + + def test_empty_value_falls_back(self, monkeypatch): + """Empty env var falls back to 600.""" + monkeypatch.setenv("TRTLLM_USAGE_HEARTBEAT_INTERVAL", "") + assert usage_lib._get_heartbeat_interval() == 600 + + +# --------------------------------------------------------------------------- +# Env vars read at call time tests +# --------------------------------------------------------------------------- + + +class TestEnvVarCallTime: + def test_stats_server_reads_at_call_time(self, monkeypatch): + """Stats server URL is read at call time, not import time.""" + monkeypatch.setenv( + "TRTLLM_USAGE_STATS_SERVER", "https://events.gfestage.nvidia.com/v1.1/events/json" + ) + assert ( + usage_lib._get_stats_server() == "https://events.gfestage.nvidia.com/v1.1/events/json" + ) + + def test_stats_server_default(self, monkeypatch): + """Default stats server is the GXT endpoint.""" + monkeypatch.delenv("TRTLLM_USAGE_STATS_SERVER", raising=False) + assert usage_lib._get_stats_server() == "https://events.gfe.nvidia.com/v1.1/events/json" + + def test_stats_server_rejects_non_nvidia_domain(self, monkeypatch): + """Non-nvidia.com domains fall back to the default endpoint.""" + monkeypatch.setenv("TRTLLM_USAGE_STATS_SERVER", "https://evil.example.com/capture") + assert usage_lib._get_stats_server() == usage_lib._DEFAULT_ENDPOINT + + def test_stats_server_rejects_http(self, 
monkeypatch): + """HTTP (non-TLS) endpoints fall back to the default.""" + monkeypatch.setenv( + "TRTLLM_USAGE_STATS_SERVER", "http://events.gfe.nvidia.com/v1.1/events/json" + ) + assert usage_lib._get_stats_server() == usage_lib._DEFAULT_ENDPOINT + + def test_stats_server_rejects_garbage(self, monkeypatch): + """Garbage URLs fall back to the default.""" + monkeypatch.setenv("TRTLLM_USAGE_STATS_SERVER", "not-a-url") + assert usage_lib._get_stats_server() == usage_lib._DEFAULT_ENDPOINT + + def test_stats_server_accepts_nvidia_subdomain(self, monkeypatch): + """Any *.nvidia.com HTTPS URL is accepted.""" + monkeypatch.setenv( + "TRTLLM_USAGE_STATS_SERVER", "https://telemetry.internal.nvidia.com/v2/events" + ) + assert usage_lib._get_stats_server() == "https://telemetry.internal.nvidia.com/v2/events" + + +# --------------------------------------------------------------------------- +# Notice text accuracy tests +# --------------------------------------------------------------------------- + + +class TestNoticeText: + def test_notice_does_not_claim_no_model_names(self): + """Notice no longer claims 'no model names' since arch class is collected.""" + assert "No model names" not in usage_lib._USAGE_NOTICE + assert "No user-identifying information" in usage_lib._USAGE_NOTICE + + +# --------------------------------------------------------------------------- +# Ingress point reporter tests +# --------------------------------------------------------------------------- + + +class TestIngressPointReporter: + """Tests for usage_context flowing through report_usage().""" + + def test_report_usage_passes_usage_context_to_thread(self, monkeypatch, enable_telemetry): + """report_usage() passes usage_context string to background thread.""" + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + usage_lib._NOTIFICATION_SHOWN.set() + + mock_thread = MagicMock() + mock_config = MagicMock() + mock_config.disabled = False + mock_config.usage_context = MagicMock() + 
mock_config.usage_context.value = "cli_serve" + + with patch( + "tensorrt_llm.usage.usage_lib.threading.Thread", + return_value=mock_thread, + ) as thread_cls: + usage_lib.report_usage(telemetry_config=mock_config) + call_args = thread_cls.call_args + assert call_args.kwargs["args"][2] == "cli_serve" + + def test_report_usage_none_config_sends_empty_context(self, monkeypatch, enable_telemetry): + """report_usage(telemetry_config=None) sends empty usage_context.""" + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + usage_lib._NOTIFICATION_SHOWN.set() + + mock_thread = MagicMock() + with patch( + "tensorrt_llm.usage.usage_lib.threading.Thread", + return_value=mock_thread, + ) as thread_cls: + usage_lib.report_usage(telemetry_config=None) + call_args = thread_cls.call_args + assert call_args.kwargs["args"][2] == "" + + def test_report_usage_context_without_value_falls_back_to_str( + self, monkeypatch, enable_telemetry + ): + """usage_context without .value attribute falls back to str().""" + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + usage_lib._NOTIFICATION_SHOWN.set() + + mock_thread = MagicMock() + mock_config = SimpleNamespace(disabled=False, usage_context="plain_string") + + with patch( + "tensorrt_llm.usage.usage_lib.threading.Thread", + return_value=mock_thread, + ) as thread_cls: + usage_lib.report_usage(telemetry_config=mock_config) + call_args = thread_cls.call_args + assert call_args.kwargs["args"][2] == "plain_string" + + def test_report_usage_disabled_via_telemetry_config(self, monkeypatch): + """report_usage with TelemetryConfig(disabled=True) is a no-op.""" + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + + mock_config = MagicMock() + mock_config.disabled = True + + with patch("tensorrt_llm.usage.usage_lib.threading.Thread") as thread_cls: + usage_lib.report_usage(telemetry_config=mock_config) + thread_cls.assert_not_called() + + +# --------------------------------------------------------------------------- +# 
_clamp_str integration tests +# --------------------------------------------------------------------------- + + +class TestClampStrIntegration: + """Verify _background_reporter() clamps long strings to schema limits.""" + + def test_background_reporter_clamps_long_platform_string(self): + """Long platform string does not cause ValidationError; len <= 256.""" + long_platform = "x" * 300 + + captured = {} + + def fake_send(payload): + captured.update(payload) + + stop_event = threading.Event() + stop_event.set() + + with ( + patch.object( + usage_lib, + "_collect_system_info", + return_value={ + "platform": long_platform, + "python_version": "3.12.0", + "cpu_architecture": "x86_64", + "cpu_count": 8, + }, + ), + patch.object(usage_lib, "_send_to_gxt", side_effect=fake_send), + patch.object(usage_lib, "_REPORTER_STOP", stop_event), + ): + usage_lib._background_reporter(None, None, "") + + assert captured, "No payload was captured" + params = captured["events"][0]["parameters"] + assert len(params["platform"]) <= 256 + + +# --------------------------------------------------------------------------- +# Disaggregated serving metadata tests +# --------------------------------------------------------------------------- + + +class TestDisaggMetadata: + """Verify _background_reporter() reads disagg env vars into initial report.""" + + def test_disagg_env_vars_appear_in_payload(self, monkeypatch): + """Disagg env vars appear as disaggRole and deploymentId in payload.""" + monkeypatch.setenv("TRTLLM_DISAGG_ROLE", "context") + monkeypatch.setenv("TRTLLM_DISAGG_DEPLOYMENT_ID", "abc123") + + captured = {} + + def fake_send(payload): + captured.update(payload) + + stop_event = threading.Event() + stop_event.set() + + with ( + patch.object(usage_lib, "_send_to_gxt", side_effect=fake_send), + patch.object(usage_lib, "_REPORTER_STOP", stop_event), + ): + usage_lib._background_reporter(None, None, "") + + assert captured, "No payload was captured" + params = 
captured["events"][0]["parameters"] + assert params["disaggRole"] == "context" + assert params["deploymentId"] == "abc123" + + +class TestDisaggMetadataEmpty: + """Verify empty defaults when disagg env vars are unset (non-disagg mode).""" + + def test_disagg_fields_empty_when_unset(self, monkeypatch): + """Without disagg env vars, disaggRole and deploymentId are empty strings.""" + monkeypatch.delenv("TRTLLM_DISAGG_ROLE", raising=False) + monkeypatch.delenv("TRTLLM_DISAGG_DEPLOYMENT_ID", raising=False) + + captured = {} + + def fake_send(payload): + captured.update(payload) + + stop_event = threading.Event() + stop_event.set() + + with ( + patch.object(usage_lib, "_send_to_gxt", side_effect=fake_send), + patch.object(usage_lib, "_REPORTER_STOP", stop_event), + ): + usage_lib._background_reporter(None, None, "") + + assert captured, "No payload was captured" + params = captured["events"][0]["parameters"] + assert params["disaggRole"] == "" + assert params["deploymentId"] == "" + + +# --------------------------------------------------------------------------- +# Rank-0 guard tests +# --------------------------------------------------------------------------- + + +class TestRankGuard: + """Verify report_usage() skips reporting for non-zero MPI ranks.""" + + def _setup_reporter(self, monkeypatch): + """Reset reporter state so report_usage() can proceed.""" + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + usage_lib._NOTIFICATION_SHOWN.set() + + def test_rank_nonzero_no_thread(self, monkeypatch, enable_telemetry): + """report_usage() is a no-op when mpi_rank() != 0.""" + self._setup_reporter(monkeypatch) + + with patch("tensorrt_llm.usage.usage_lib.threading.Thread") as thread_cls: + with patch("tensorrt_llm._utils.mpi_rank", return_value=1): + usage_lib.report_usage() + thread_cls.assert_not_called() + + def test_rank_zero_proceeds(self, monkeypatch, enable_telemetry): + """report_usage() proceeds normally when mpi_rank() == 0.""" + 
self._setup_reporter(monkeypatch) + + mock_thread = MagicMock() + with patch( + "tensorrt_llm.usage.usage_lib.threading.Thread", + return_value=mock_thread, + ) as thread_cls: + with patch("tensorrt_llm._utils.mpi_rank", return_value=0): + usage_lib.report_usage() + thread_cls.assert_called_once() + mock_thread.start.assert_called_once() + + def test_rank_import_fails_proceeds(self, monkeypatch, enable_telemetry): + """report_usage() proceeds (fail-open) when mpi_rank import fails.""" + self._setup_reporter(monkeypatch) + + mock_thread = MagicMock() + with patch( + "tensorrt_llm.usage.usage_lib.threading.Thread", + return_value=mock_thread, + ) as thread_cls: + with patch.dict( + "sys.modules", + {"tensorrt_llm._utils": None}, + ): + usage_lib.report_usage() + thread_cls.assert_called_once() + mock_thread.start.assert_called_once() + + +# --------------------------------------------------------------------------- +# Reporter shutdown tests +# --------------------------------------------------------------------------- + + +class TestReporterShutdown: + """Verify _REPORTER_STOP event exits the heartbeat loop.""" + + def test_reporter_stop_event_exits_heartbeat_loop(self): + """Setting _REPORTER_STOP causes the heartbeat loop to exit.""" + send_count = {"n": 0} + + def counting_send(payload): + send_count["n"] += 1 + + stop_event = threading.Event() + threading.Timer(0.1, stop_event.set).start() + + with ( + patch.object(usage_lib, "_send_to_gxt", side_effect=counting_send), + patch.object(usage_lib, "_REPORTER_STOP", stop_event), + patch.object(usage_lib, "_get_heartbeat_interval", return_value=3600), + ): + usage_lib._background_reporter(None, None, "") + + assert send_count["n"] == 1 + + +# --------------------------------------------------------------------------- +# Heartbeat fail-silent continuation test +# --------------------------------------------------------------------------- + + +class TestHeartbeatFailSilent: + """Verify transient heartbeat failure 
doesn't kill the loop.""" + + def test_heartbeat_continues_after_transient_failure(self): + """OSError on one heartbeat doesn't prevent subsequent heartbeats.""" + calls = [] + + def tracking_send(payload): + calls.append(payload) + if len(calls) == 2: # first heartbeat (seq=0) + raise OSError("transient network failure") + + stop = threading.Event() + timer = threading.Timer(0.05, stop.set) + timer.start() + + with ( + patch.object(usage_lib, "_send_to_gxt", side_effect=tracking_send), + patch.object(usage_lib, "_REPORTER_STOP", stop), + patch.object(usage_lib, "_get_heartbeat_interval", return_value=0), + ): + usage_lib._background_reporter(None, None, "") + + timer.join(timeout=1) + + # call 1 = initial report, call 2 = heartbeat (failed), call 3+ = more heartbeats + assert len(calls) >= 3, ( + f"Expected >=3 _send_to_gxt calls (loop should continue after failure), got {len(calls)}" + ) + + +# --------------------------------------------------------------------------- +# Concurrent reporter start test +# --------------------------------------------------------------------------- + + +class TestConcurrentReporterStart: + """Verify _REPORTER_LOCK works under real thread contention.""" + + def test_concurrent_calls_spawn_single_thread(self, monkeypatch, enable_telemetry): + """10 concurrent report_usage() calls produce exactly 1 reporter thread.""" + monkeypatch.setattr(usage_lib, "_REPORTER_STARTED", False) + usage_lib._NOTIFICATION_SHOWN.set() + + call_count_lock = threading.Lock() + threads_started = {"count": 0} + + mock_thread = MagicMock() + + def counting_thread(*args, **kwargs): + with call_count_lock: + threads_started["count"] += 1 + return mock_thread + + with ( + patch.object( + usage_lib, + "threading", + wraps=threading, + ) as mock_threading_mod, + patch("tensorrt_llm._utils.mpi_rank", return_value=0), + ): + mock_threading_mod.Thread = MagicMock(side_effect=counting_thread) + mock_threading_mod.Lock = threading.Lock + mock_threading_mod.Event = 
threading.Event + + barrier = threading.Barrier(10) + + def call_report(): + barrier.wait() + usage_lib.report_usage() + + pool = [threading.Thread(target=call_report) for _ in range(10)] + for t in pool: + t.start() + for t in pool: + t.join(timeout=5) + + assert threads_started["count"] == 1 diff --git a/tests/unittest/usage/test_schema.py b/tests/unittest/usage/test_schema.py new file mode 100644 index 00000000000..f65fed9e34a --- /dev/null +++ b/tests/unittest/usage/test_schema.py @@ -0,0 +1,714 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for Pydantic schema models, GXT payload format, and SMS schema compliance.""" + +import json + +import pytest +from pydantic import ValidationError + +from tensorrt_llm.usage import schema, schemas + +# --------------------------------------------------------------------------- +# Features JSON payload structure tests +# --------------------------------------------------------------------------- + + +class TestFeaturesJsonPayload: + """Tests for featuresJson field in the GXT payload.""" + + def test_initial_report_contains_features_json(self): + """FeaturesJson appears in initial report event parameters.""" + report = schema.TrtllmInitialReport( + featuresJson='{"lora":true}', + ) + payload = schema.build_gxt_payload( + event=report, + session_id="test", + trtllm_version="0.18.0", + ) + params = payload["events"][0]["parameters"] + assert "featuresJson" in params + assert params["featuresJson"] == '{"lora":true}' + + def test_features_json_round_trips_through_gxt_payload(self): + """FeaturesJson survives full JSON serialization round-trip.""" + features = '{"lora":true,"speculative_decoding":false,"data_parallel_size":4}' + report = schema.TrtllmInitialReport(featuresJson=features) + payload = schema.build_gxt_payload( + event=report, + session_id="test", + trtllm_version="0.18.0", + ) + json_str = json.dumps(payload) + parsed = json.loads(json_str) + inner = json.loads(parsed["events"][0]["parameters"]["featuresJson"]) + assert inner["lora"] is True + assert inner["speculative_decoding"] is False + assert inner["data_parallel_size"] == 4 + + def test_features_json_default_is_empty_object(self): + """Default featuresJson value is '{}'.""" + report = schema.TrtllmInitialReport() + data = report.model_dump(by_alias=True) + assert data["featuresJson"] == "{}" + + +# --------------------------------------------------------------------------- +# GXT payload format tests +# --------------------------------------------------------------------------- + + +class 
TestGxtPayload: + def test_initial_report_format(self): + """Initial report payload matches GXT protocol v1.6 envelope.""" + event = schema.TrtllmInitialReport( + trtllmVersion="0.18.0", + platform="Linux-5.15.0", + pythonVersion="3.10.12", + cpuArchitecture="x86_64", + cpuCount=64, + gpuCount=8, + gpuName="NVIDIA H100 80GB HBM3", + gpuMemoryMB=81920, + cudaVersion="12.4", + architectureClassName="LlamaForCausalLM", + backend="pytorch", + tensorParallelSize=4, + pipelineParallelSize=2, + contextParallelSize=1, + moeExpertParallelSize=8, + moeTensorParallelSize=2, + dtype="float16", + quantizationAlgo="fp8", + kvCacheDtype="auto", + featuresJson='{"lora":true,"speculative_decoding":false}', + ) + payload = schema.build_gxt_payload( + event=event, session_id="abc123", trtllm_version="0.18.0" + ) + + # Top-level envelope + assert payload["clientId"] == schema.CLIENT_ID + assert payload["eventProtocol"] == schema.EVENT_PROTOCOL + assert payload["eventSchemaVer"] == schema.EVENT_SCHEMA_VER + assert payload["eventSysVer"] == schema.EVENT_SYS_VER + assert payload["clientType"] == "Native" + assert payload["clientVariant"] == "Release" + assert payload["clientVer"] == "0.18.0" + assert payload["sessionId"] == "abc123" + + # Privacy fields + assert payload["userId"] == "undefined" + assert payload["deviceId"] == "undefined" + assert payload["browserType"] == "undefined" + assert payload["externalUserId"] == "undefined" + + # GDPR fields + assert payload["gdprBehOptIn"] == "None" + assert payload["gdprFuncOptIn"] == "None" + assert payload["gdprTechOptIn"] == "None" + assert payload["deviceGdprBehOptIn"] == "None" + + # Events array + assert len(payload["events"]) == 1 + assert payload["events"][0]["name"] == "trtllm_initial_report" + assert payload["events"][0]["ts"] # non-empty timestamp + + # Event parameters + params = payload["events"][0]["parameters"] + assert params["trtllmVersion"] == "0.18.0" + assert params["gpuCount"] == 8 + assert params["gpuName"] == "NVIDIA 
H100 80GB HBM3" + assert params["moeExpertParallelSize"] == 8 + assert params["moeTensorParallelSize"] == 2 + + def test_heartbeat_minimal(self): + """Heartbeat payload contains only seq field.""" + event = schema.TrtllmHeartbeat(seq=0) + payload = schema.build_gxt_payload(event=event, session_id="xyz", trtllm_version="0.18.0") + + assert payload["events"][0]["name"] == "trtllm_heartbeat" + params = payload["events"][0]["parameters"] + assert params == {"seq": 0} + + def test_json_serializable(self): + """Entire payload is JSON-serializable (no enum, datetime issues).""" + event = schema.TrtllmInitialReport( + trtllmVersion="0.18.0", + platform="Linux-5.15.0", + pythonVersion="3.10.12", + cpuArchitecture="x86_64", + cpuCount=64, + gpuCount=4, + gpuName="NVIDIA A100", + gpuMemoryMB=40960, + cudaVersion="12.4", + architectureClassName="LlamaForCausalLM", + backend="pytorch", + tensorParallelSize=4, + pipelineParallelSize=1, + contextParallelSize=1, + moeExpertParallelSize=0, + moeTensorParallelSize=0, + dtype="float16", + quantizationAlgo="none", + kvCacheDtype="auto", + featuresJson='{"lora":false}', + ) + payload = schema.build_gxt_payload(event=event, session_id="test", trtllm_version="0.18.0") + json_str = json.dumps(payload) + assert isinstance(json_str, str) + parsed = json.loads(json_str) + assert parsed["clientId"] == schema.CLIENT_ID + + def test_timestamp_format(self): + """ISO timestamp has millisecond precision and ends with Z.""" + from datetime import datetime, timezone + + ts = schema.get_iso_timestamp(datetime(2026, 2, 17, 10, 30, 0, 500000, tzinfo=timezone.utc)) + assert ts == "2026-02-17T10:30:00.500Z" + + def test_timestamp_default_utc(self): + """Default timestamp is current UTC time.""" + ts = schema.get_iso_timestamp() + assert ts.endswith("Z") + assert "T" in ts + + +# --------------------------------------------------------------------------- +# Schema constants tests +# 
--------------------------------------------------------------------------- + + +class TestSchemaConstants: + def test_client_id(self): + """CLIENT_ID matches provisioned value.""" + assert schema.CLIENT_ID == "616561816355034" + + def test_event_protocol(self): + """EVENT_PROTOCOL is v1.6.""" + assert schema.EVENT_PROTOCOL == "1.6" + + def test_event_schema_ver(self): + """EVENT_SCHEMA_VER is 0.1 (matches SMS schema schemaVersion).""" + assert schema.EVENT_SCHEMA_VER == "0.1" + + def test_event_sys_ver(self): + """EVENT_SYS_VER identifies the telemetry subsystem.""" + assert schema.EVENT_SYS_VER == "trtllm-telemetry/1.0" + + +# --------------------------------------------------------------------------- +# Heartbeat event content tests +# --------------------------------------------------------------------------- + + +class TestHeartbeatContent: + def test_heartbeat_contains_only_seq(self): + """Heartbeat events contain only seq field.""" + event = schema.TrtllmHeartbeat(seq=0) + payload = schema.build_gxt_payload(event=event, session_id="xyz", trtllm_version="0.18.0") + params = payload["events"][0]["parameters"] + assert params == {"seq": 0} + assert payload["events"][0]["name"] == "trtllm_heartbeat" + + def test_heartbeat_seq_increments(self): + """Heartbeat seq field supports incrementing values.""" + for i in range(5): + event = schema.TrtllmHeartbeat(seq=i) + payload = schema.build_gxt_payload( + event=event, session_id="xyz", trtllm_version="0.18.0" + ) + params = payload["events"][0]["parameters"] + assert params["seq"] == i + + +# --------------------------------------------------------------------------- +# Ingress point field tests +# --------------------------------------------------------------------------- + + +class TestIngressPoint: + """Tests for the ingressPoint field in telemetry events.""" + + def test_initial_report_has_ingress_point_field(self): + """TrtllmInitialReport has an ingressPoint field.""" + report = schema.TrtllmInitialReport() + 
data = report.model_dump(by_alias=True) + assert "ingressPoint" in data + + def test_ingress_point_default_empty_string(self): + """Default ingressPoint is empty string.""" + report = schema.TrtllmInitialReport() + data = report.model_dump(by_alias=True) + assert data["ingressPoint"] == "" + + def test_ingress_point_set_to_cli_serve(self): + """IngressPoint can be set to 'cli_serve'.""" + report = schema.TrtllmInitialReport(ingressPoint="cli_serve") + data = report.model_dump(by_alias=True) + assert data["ingressPoint"] == "cli_serve" + + def test_ingress_point_set_to_llm_class(self): + """IngressPoint can be set to 'llm_class'.""" + report = schema.TrtllmInitialReport(ingressPoint="llm_class") + data = report.model_dump(by_alias=True) + assert data["ingressPoint"] == "llm_class" + + def test_ingress_point_set_to_cli_bench(self): + """IngressPoint can be set to 'cli_bench'.""" + report = schema.TrtllmInitialReport(ingressPoint="cli_bench") + data = report.model_dump(by_alias=True) + assert data["ingressPoint"] == "cli_bench" + + def test_ingress_point_set_to_cli_eval(self): + """IngressPoint can be set to 'cli_eval'.""" + report = schema.TrtllmInitialReport(ingressPoint="cli_eval") + data = report.model_dump(by_alias=True) + assert data["ingressPoint"] == "cli_eval" + + def test_ingress_point_in_gxt_payload(self): + """IngressPoint appears in the full GXT payload.""" + report = schema.TrtllmInitialReport(ingressPoint="cli_serve") + payload = schema.build_gxt_payload( + event=report, + session_id="test-session", + trtllm_version="1.0.0", + ) + params = payload["events"][0]["parameters"] + assert params["ingressPoint"] == "cli_serve" + + +# --------------------------------------------------------------------------- +# UsageContext enum tests +# --------------------------------------------------------------------------- + + +try: + from tensorrt_llm.llmapi import llm_args as _llm_args_mod + + _HAS_LLMAPI = True +except ImportError: + _HAS_LLMAPI = False + 
# Shared marker: skip when tensorrt_llm.llmapi is unavailable (see _HAS_LLMAPI probe).
_skip_no_llmapi = pytest.mark.skipif(
    not _HAS_LLMAPI,
    reason="Requires full tensorrt_llm.llmapi (C++ bindings not available)",
)


@_skip_no_llmapi
class TestUsageContextEnum:
    """Tests for UsageContext enum values."""

    def test_all_context_values_are_strings(self):
        """All UsageContext values are valid strings."""
        for ctx in _llm_args_mod.UsageContext:
            assert isinstance(ctx.value, str)
            assert len(ctx.value) > 0

    def test_expected_contexts_exist(self):
        """All expected UsageContext values are defined."""
        expected = {"unknown", "llm_class", "cli_serve", "cli_bench", "cli_eval"}
        actual = {ctx.value for ctx in _llm_args_mod.UsageContext}
        assert actual == expected

    def test_unknown_is_default(self):
        """UNKNOWN is the default UsageContext."""
        config = _llm_args_mod.TelemetryConfig()
        assert config.usage_context == _llm_args_mod.UsageContext.UNKNOWN


# ---------------------------------------------------------------------------
# Schema compliance tests (SMS Event Definition validation)
# ---------------------------------------------------------------------------


class TestSchemaCompliance:
    """Validate Pydantic models match the SMS Event Definition Schema.

    The ground truth is the SMS schema file which defines the event
    structure registered with the NvTelemetry Data Platform. These tests
    catch accidental schema drift.
    """

    def _load_sms_schema(self) -> dict:
        """Load the checked-in SMS Event Definition Schema."""
        assert schemas.SMS_SCHEMA_PATH.exists(), (
            f"SMS schema file not found: {schemas.SMS_SCHEMA_PATH}\n"
            "This file is the ground-truth event definition."
        )
        return json.loads(schemas.SMS_SCHEMA_PATH.read_text())

    def _get_pydantic_aliases(self, model_cls) -> set:
        """Extract alias names from a Pydantic model."""
        aliases = set()
        for name, field_info in model_cls.model_fields.items():
            # Fields without an explicit alias serialize under their own name.
            alias = field_info.alias if field_info.alias else name
            aliases.add(alias)
        return aliases

    # --- Schema metadata validation ---

    def test_schema_version_matches_code_constant(self):
        """SMS schemaVersion matches EVENT_SCHEMA_VER in code."""
        sms_schema = self._load_sms_schema()
        assert sms_schema["schemaMeta"]["schemaVersion"] == schema.EVENT_SCHEMA_VER

    def test_client_id_matches_code_constant(self):
        """SMS clientId matches CLIENT_ID in code."""
        sms_schema = self._load_sms_schema()
        assert sms_schema["schemaMeta"]["clientId"] == schema.CLIENT_ID

    def test_schema_has_draft07(self):
        """SMS schema uses JSON Schema draft-07."""
        sms_schema = self._load_sms_schema()
        assert sms_schema["$schema"] == "http://json-schema.org/draft-07/schema#"

    def test_schema_has_two_events(self):
        """SMS schema defines exactly two events."""
        sms_schema = self._load_sms_schema()
        events = sms_schema["definitions"]["events"]
        assert set(events.keys()) == {"trtllm_initial_report", "trtllm_heartbeat"}

    # --- TrtllmInitialReport <-> trtllm_initial_report field sync ---

    def test_initial_report_fields_match_sms(self):
        """Every TrtllmInitialReport field (by alias) exists in SMS initial_report."""
        sms_schema = self._load_sms_schema()
        sms_props = set(
            sms_schema["definitions"]["events"]["trtllm_initial_report"]["properties"].keys()
        )
        pydantic_aliases = self._get_pydantic_aliases(schema.TrtllmInitialReport)

        missing_in_sms = pydantic_aliases - sms_props
        assert not missing_in_sms, (
            f"Fields in TrtllmInitialReport but missing from SMS schema: {sorted(missing_in_sms)}"
        )

    def test_sms_initial_report_fields_match_pydantic(self):
        """Every SMS initial_report property exists in TrtllmInitialReport."""
        sms_schema = self._load_sms_schema()
        sms_props = set(
            sms_schema["definitions"]["events"]["trtllm_initial_report"]["properties"].keys()
        )
        pydantic_aliases = self._get_pydantic_aliases(schema.TrtllmInitialReport)

        missing_in_pydantic = sms_props - pydantic_aliases
        assert not missing_in_pydantic, (
            f"Fields in SMS schema but missing from TrtllmInitialReport: "
            f"{sorted(missing_in_pydantic)}"
        )

    def test_initial_report_has_all_expected_fields(self):
        """SMS trtllm_initial_report contains all expected telemetry fields."""
        sms_schema = self._load_sms_schema()
        props = set(
            sms_schema["definitions"]["events"]["trtllm_initial_report"]["properties"].keys()
        )

        expected_fields = {
            "trtllmVersion",
            "platform",
            "pythonVersion",
            "cpuArchitecture",
            "cpuCount",
            "gpuCount",
            "gpuName",
            "gpuMemoryMB",
            "cudaVersion",
            "architectureClassName",
            "backend",
            "tensorParallelSize",
            "pipelineParallelSize",
            "contextParallelSize",
            "moeExpertParallelSize",
            "moeTensorParallelSize",
            "dtype",
            "quantizationAlgo",
            "kvCacheDtype",
            "ingressPoint",
            "featuresJson",
            "disaggRole",
            "deploymentId",
        }

        missing = expected_fields - props
        assert not missing, f"Missing expected initial_report fields: {missing}"

        extra = props - expected_fields
        assert not extra, (
            f"Unexpected initial_report fields: {extra}. "
            "If intentional, add them to the expected set in this test."
        )

    # --- TrtllmHeartbeat <-> trtllm_heartbeat field sync ---

    def test_heartbeat_fields_match_sms(self):
        """Every TrtllmHeartbeat field (by alias) exists in SMS heartbeat."""
        sms_schema = self._load_sms_schema()
        sms_props = set(
            sms_schema["definitions"]["events"]["trtllm_heartbeat"]["properties"].keys()
        )
        pydantic_aliases = self._get_pydantic_aliases(schema.TrtllmHeartbeat)

        missing_in_sms = pydantic_aliases - sms_props
        assert not missing_in_sms, (
            f"Fields in TrtllmHeartbeat but missing from SMS schema: {sorted(missing_in_sms)}"
        )

    def test_sms_heartbeat_fields_match_pydantic(self):
        """Every SMS heartbeat property exists in TrtllmHeartbeat."""
        sms_schema = self._load_sms_schema()
        sms_props = set(
            sms_schema["definitions"]["events"]["trtllm_heartbeat"]["properties"].keys()
        )
        pydantic_aliases = self._get_pydantic_aliases(schema.TrtllmHeartbeat)

        missing_in_pydantic = sms_props - pydantic_aliases
        assert not missing_in_pydantic, (
            f"Fields in SMS schema but missing from TrtllmHeartbeat: {sorted(missing_in_pydantic)}"
        )

    def test_heartbeat_required_fields(self):
        """SMS trtllm_heartbeat requires exactly 'seq'."""
        sms_schema = self._load_sms_schema()
        required = sms_schema["definitions"]["events"]["trtllm_heartbeat"]["required"]
        assert required == ["seq"]

    def test_initial_report_required_fields(self):
        """SMS trtllm_initial_report requires all declared fields."""
        sms_schema = self._load_sms_schema()
        required = set(sms_schema["definitions"]["events"]["trtllm_initial_report"]["required"])
        all_props = set(
            sms_schema["definitions"]["events"]["trtllm_initial_report"]["properties"].keys()
        )
        assert required == all_props, (
            f"Expected all properties to be required.\n"
            f"  Missing from required: {all_props - required}\n"
            f"  Extra in required: {required - all_props}"
        )

    # --- Additional properties enforcement ---

    def test_initial_report_no_additional_properties(self):
        """SMS trtllm_initial_report has additionalProperties: false."""
        sms_schema = self._load_sms_schema()
        event = sms_schema["definitions"]["events"]["trtllm_initial_report"]
        assert event["additionalProperties"] is False

    def test_heartbeat_no_additional_properties(self):
        """SMS trtllm_heartbeat has additionalProperties: false."""
        sms_schema = self._load_sms_schema()
        event = sms_schema["definitions"]["events"]["trtllm_heartbeat"]
        assert event["additionalProperties"] is False

    # --- Privacy / PII guard ---

    def test_no_pii_fields_in_initial_report(self):
        """Initial report must NOT contain PII or sensitive fields."""
        sms_schema = self._load_sms_schema()
        props = set(
            sms_schema["definitions"]["events"]["trtllm_initial_report"]["properties"].keys()
        )

        forbidden_fields = {
            "num_layers",
            "numLayers",
            "hidden_size",
            "hiddenSize",
            "num_attention_heads",
            "numAttentionHeads",
            "model_type",
            "modelType",
            "userId",
            "userName",
            "hostName",
            "hostname",
            "macAddress",
            "ipAddress",
            "modelName",
            "modelPath",
            "ttft",
            "tpot",
            "latency",
            "throughput",
            "tokensPerSecond",
        }

        found = forbidden_fields & props
        assert not found, (
            f"Forbidden fields found in SMS schema: {found}. "
            "These fields violate privacy/compliance constraints."
        )

    # --- GDPR metadata ---

    def test_events_have_gdpr_metadata(self):
        """Both events have GDPR metadata in eventMeta."""
        sms_schema = self._load_sms_schema()
        for event_name in ("trtllm_initial_report", "trtllm_heartbeat"):
            event = sms_schema["definitions"]["events"][event_name]
            assert "eventMeta" in event, f"{event_name} missing eventMeta"
            assert "gdpr" in event["eventMeta"], f"{event_name} missing gdpr in eventMeta"
            gdpr = event["eventMeta"]["gdpr"]
            assert gdpr["category"] == "functional"
            assert "description" in gdpr

    # --- GXT envelope ---

    def test_envelope_contains_all_gxt_v16_keys(self):
        """GxtPayload model contains all GXT Event Protocol v1.6 envelope fields."""
        payload = schema.GxtPayload(
            clientVer="0.18.0",
            sentTs="2026-01-01T00:00:00.000Z",
            sessionId="test",
            events=[],
        )
        serialized = payload.model_dump(by_alias=True)
        props = set(serialized.keys())

        gxt_v16_fields = {
            "clientId",
            "clientType",
            "clientVariant",
            "clientVer",
            "cpuArchitecture",
            "eventProtocol",
            "eventSchemaVer",
            "eventSysVer",
            "sentTs",
            "sessionId",
            "browserType",
            "deviceId",
            "deviceMake",
            "deviceModel",
            "deviceOS",
            "deviceOSVersion",
            "deviceType",
            "userId",
            "externalUserId",
            "idpId",
            "integrationId",
            "productName",
            "productVersion",
            "gdprBehOptIn",
            "gdprFuncOptIn",
            "gdprTechOptIn",
            "deviceGdprBehOptIn",
            "deviceGdprFuncOptIn",
            "deviceGdprTechOptIn",
            "events",
        }

        missing = gxt_v16_fields - props
        assert not missing, (
            f"GXT v1.6 fields missing from envelope: {missing}. "
            "The GXT endpoint will reject payloads without these fields."
        )

        extra = props - gxt_v16_fields
        assert not extra, (
            f"Unexpected envelope fields: {extra}. "
            "If intentional, add them to gxt_v16_fields in this test."
        )

    # --- JSON schema validation (from collapsed TestSchemaDriftDetection) ---

    def test_initial_report_validates_against_json_schema(self):
        """A fully-populated TrtllmInitialReport must validate against the JSON schema."""
        # Local import: jsonschema is only needed for these two validation tests.
        import jsonschema

        sms_schema = json.loads(schemas.SMS_SCHEMA_PATH.read_text())
        report = schema.TrtllmInitialReport(
            trtllmVersion="1.0",
            platform="Linux",
            pythonVersion="3.10",
            cpuArchitecture="x86_64",
            cpuCount=8,
            gpuCount=1,
            gpuName="H100",
            gpuMemoryMB=81920,
            cudaVersion="12.0",
            architectureClassName="LlamaForCausalLM",
            backend="pytorch",
            tensorParallelSize=1,
            pipelineParallelSize=1,
            contextParallelSize=1,
            moeExpertParallelSize=0,
            moeTensorParallelSize=0,
            dtype="float16",
            quantizationAlgo="",
            kvCacheDtype="",
            ingressPoint="llm_class",
            featuresJson='{"lora":false}',
            disaggRole="",
            deploymentId="",
        )
        payload = report.model_dump(by_alias=True)
        initial_schema = sms_schema["definitions"]["events"]["trtllm_initial_report"].copy()
        initial_schema["definitions"] = sms_schema["definitions"]
        jsonschema.validate(instance=payload, schema=initial_schema)

    def test_heartbeat_validates_against_json_schema(self):
        """A TrtllmHeartbeat must validate against the JSON schema."""
        import jsonschema

        sms_schema = json.loads(schemas.SMS_SCHEMA_PATH.read_text())
        heartbeat = schema.TrtllmHeartbeat(seq=0)
        payload = heartbeat.model_dump(by_alias=True)
        hb_schema = sms_schema["definitions"]["events"]["trtllm_heartbeat"].copy()
        hb_schema["definitions"] = sms_schema["definitions"]
        jsonschema.validate(instance=payload, schema=hb_schema)


# ---------------------------------------------------------------------------
# Pydantic field constraint validation tests
# ---------------------------------------------------------------------------


class TestPydanticValidation:
    """Pydantic field constraint enforcement (max_length, ge/le)."""

    def test_initial_report_rejects_overlength_short_string(self):
        """ShortString field (max_length=128) rejects 129-char value."""
        with pytest.raises(ValidationError):
            schema.TrtllmInitialReport(trtllmVersion="a" * 129)

    def test_initial_report_rejects_overlength_long_string(self):
        """LongString field (max_length=256) rejects 257-char value."""
        with pytest.raises(ValidationError):
            schema.TrtllmInitialReport(platform="a" * 257)

    def test_initial_report_accepts_max_length_string(self):
        """ShortString field accepts exactly 128 chars (boundary)."""
        report = schema.TrtllmInitialReport(trtllmVersion="a" * 128)
        assert len(report.trtllm_version) == 128

    def test_heartbeat_rejects_negative_seq(self):
        """TrtllmHeartbeat rejects seq < 0 (ge=0 constraint)."""
        with pytest.raises(ValidationError):
            schema.TrtllmHeartbeat(seq=-1)

    def test_heartbeat_rejects_overflow_seq(self):
        """TrtllmHeartbeat rejects seq > uint32 max (le=4294967295)."""
        with pytest.raises(ValidationError):
            schema.TrtllmHeartbeat(seq=4294967296)

    def test_initial_report_rejects_negative_int(self):
        """PositiveInt field (ge=0) rejects negative value."""
        with pytest.raises(ValidationError):
            schema.TrtllmInitialReport(cpuCount=-1)
diff --git a/tests/unittest/usage/test_transport.py b/tests/unittest/usage/test_transport.py
new file mode 100644
index 00000000000..ad69d40bde8
--- /dev/null
+++ b/tests/unittest/usage/test_transport.py
@@ -0,0 +1,283 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for HTTP transport (_send_to_gxt) and live staging endpoint."""

import json
# FIX: urllib.error.HTTPError is referenced below; import the submodule
# explicitly instead of relying on urllib.request importing it as a side effect.
import urllib.error
import urllib.request
from unittest.mock import MagicMock, patch

import pytest

from tensorrt_llm.usage import schema, usage_lib

# ---------------------------------------------------------------------------
# HTTP transport tests
# ---------------------------------------------------------------------------


class TestSendToGxt:
    def test_send_fail_silent(self):
        """_send_to_gxt never raises on network error."""
        with patch(
            "tensorrt_llm.usage.usage_lib._get_stats_server",
            return_value="http://192.0.2.1/nonexistent",
        ):
            usage_lib._send_to_gxt({"test": "data"})  # Must not raise

    def test_send_uses_json_content_type(self):
        """The POST request carries JSON body and content-type headers."""
        captured_req = {}

        class MockOpener:
            def open(self, req, timeout=None):
                captured_req["headers"] = dict(req.headers)
                captured_req["data"] = req.data
                captured_req["method"] = req.method
                return MagicMock()

        with patch("urllib.request.build_opener", return_value=MockOpener()):
            usage_lib._send_to_gxt({"key": "value"})

        # urllib.request.Request stores header names capitalized, hence "Content-type".
        assert captured_req["headers"]["Content-type"] == "application/json"
        assert captured_req["headers"]["Accept"] == "application/json"
        assert captured_req["method"] == "POST"
        assert json.loads(captured_req["data"]) == {"key": "value"}

    def test_send_to_gxt_does_not_follow_redirects(self):
        """Custom opener excludes HTTPRedirectHandler (SSRF protection)."""
        captured_handlers = []

        def mock_build_opener(*handlers):
            captured_handlers.extend(handlers)
            return MagicMock()

        with patch("urllib.request.build_opener", side_effect=mock_build_opener):
            usage_lib._send_to_gxt({"test": True})

        handler_types = set(captured_handlers)
        assert urllib.request.HTTPRedirectHandler not in handler_types
        assert usage_lib._NoRedirectHandler in handler_types

    def test_no_redirect_handler_blocks_redirects(self):
        """_NoRedirectHandler rejects all redirect responses."""
        handler = usage_lib._NoRedirectHandler()
        with pytest.raises(urllib.error.HTTPError) as exc_info:
            handler.redirect_request(
                MagicMock(full_url="http://example.com"),
                None,
                302,
                "Found",
                {},
                "http://evil.com",
            )
        assert exc_info.value.code == 302

    def test_real_opener_lacks_default_redirect_handler(self):
        """Verify the real opener built by build_opener has no HTTPRedirectHandler."""
        opener = urllib.request.build_opener(
            urllib.request.HTTPHandler,
            urllib.request.HTTPSHandler,
            usage_lib._NoRedirectHandler,
        )
        handler_names = [h.__class__.__name__ for h in opener.handlers]
        assert "HTTPRedirectHandler" not in handler_names
        assert "_NoRedirectHandler" in handler_names

    def test_send_to_gxt_catches_url_error(self, monkeypatch):
        """_send_to_gxt silently handles URLError."""
        monkeypatch.setenv("TRTLLM_USAGE_STATS_SERVER", "http://localhost:1")
        usage_lib._send_to_gxt({"test": True})  # should not raise


# ---------------------------------------------------------------------------
# HTTPS handler tests
# ---------------------------------------------------------------------------


class TestHttpsHandler:
    def test_opener_has_https_handler(self):
        """Opener includes HTTPSHandler for HTTPS endpoints."""
        captured_handlers = []

        def mock_build_opener(*handlers):
            captured_handlers.extend(handlers)
            return MagicMock()

        with patch("urllib.request.build_opener", side_effect=mock_build_opener):
            usage_lib._send_to_gxt({"test": True})

        # build_opener accepts handler classes or instances; check for both forms.
        handler_set = set(captured_handlers)
        has_http = urllib.request.HTTPHandler in handler_set or any(
            isinstance(h, urllib.request.HTTPHandler) for h in captured_handlers
        )
        has_https = urllib.request.HTTPSHandler in handler_set or any(
            isinstance(h, urllib.request.HTTPSHandler) for h in captured_handlers
        )
        assert has_http, f"HTTPHandler not found in {captured_handlers}"
        assert has_https, f"HTTPSHandler not found in {captured_handlers}"


# ---------------------------------------------------------------------------
# Malformed server URL tests
# ---------------------------------------------------------------------------


class TestMalformedServerUrl:
    """Verify _send_to_gxt handles malformed server URLs without crashing."""

    def test_empty_server_url_fail_silent(self, monkeypatch):
        """Empty TRTLLM_USAGE_STATS_SERVER doesn't crash."""
        monkeypatch.setenv("TRTLLM_USAGE_STATS_SERVER", "")
        usage_lib._send_to_gxt({"test": True})  # should not raise

    def test_non_http_scheme_fail_silent(self, monkeypatch):
        """Non-HTTP scheme doesn't crash."""
        monkeypatch.setenv("TRTLLM_USAGE_STATS_SERVER", "ftp://bad.example.com")
        usage_lib._send_to_gxt({"test": True})

    def test_garbage_url_fail_silent(self, monkeypatch):
        """Completely invalid URL doesn't crash."""
        monkeypatch.setenv("TRTLLM_USAGE_STATS_SERVER", "not-a-url")
        usage_lib._send_to_gxt({"test": True})


# ---------------------------------------------------------------------------
# Live staging endpoint tests (opt-in via --run-staging or -m staging)
# ---------------------------------------------------------------------------

_STAGING_ENDPOINT = "https://events.gfestage.nvidia.com/v1.1/events/json"


def _post_to_staging(payload: dict, timeout: float = 10.0) -> int:
    """POST payload to GXT staging and return HTTP status code."""
    data = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(
        _STAGING_ENDPOINT,
        data=data,
        headers={
            "Content-Type": "application/json",
            "Accept": "application/json",
        },
        method="POST",
    )
    # FIX: close the response deterministically instead of leaking the connection.
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return resp.status


@pytest.mark.skipif(
    "not config.getoption('--run-staging', default=False)",
    reason="Live staging tests require --run-staging flag",
)
@pytest.mark.staging
class TestStagingEndpoint:
    """Live tests that send payloads to the GXT staging endpoint.

    These are opt-in only -- they require network access to
    events.gfestage.nvidia.com and are gated behind ``--run-staging``.
    """

    def test_initial_report_accepted(self):
        """GXT staging accepts a well-formed trtllm_initial_report envelope (HTTP 200)."""
        import os
        import platform as plat
        import uuid

        report = schema.TrtllmInitialReport(
            trtllmVersion="0.0.0-test",
            platform=plat.platform(),
            pythonVersion=plat.python_version(),
            cpuArchitecture=plat.machine(),
            cpuCount=os.cpu_count() or 0,
            gpuCount=0,
            gpuName="",
            gpuMemoryMB=0,
            cudaVersion="",
            architectureClassName="TestModel",
            backend="pytorch",
            tensorParallelSize=1,
            pipelineParallelSize=1,
            contextParallelSize=1,
            moeExpertParallelSize=0,
            moeTensorParallelSize=0,
            dtype="float16",
            quantizationAlgo="",
            kvCacheDtype="",
            ingressPoint="cli_serve",
            featuresJson='{"lora":false,"speculative_decoding":false,"prefix_caching":false,"cuda_graphs":false,"chunked_context":false,"data_parallel_size":1}',
            disaggRole="",
            deploymentId="",
        )
        payload = schema.build_gxt_payload(
            event=report,
            session_id=uuid.uuid4().hex,
            trtllm_version="0.0.0-test",
        )

        status = _post_to_staging(payload)
        assert status == 200, f"Expected HTTP 200, got {status}"

    def test_heartbeat_accepted(self):
        """GXT staging accepts a well-formed trtllm_heartbeat envelope (HTTP 200)."""
        import uuid

        heartbeat = schema.TrtllmHeartbeat(seq=0)
        payload = schema.build_gxt_payload(
            event=heartbeat,
            session_id=uuid.uuid4().hex,
            trtllm_version="0.0.0-test",
        )

        status = _post_to_staging(payload)
        assert status == 200, f"Expected HTTP 200, got {status}"

    def test_ingress_point_in_accepted_payload(self):
        """Staging accepts payloads containing ingressPoint without envelope rejection."""
        import os
        import platform as plat
        import uuid

        for context_value in ("cli_serve", "cli_bench", "cli_eval", "llm_class", "unknown", ""):
            report = schema.TrtllmInitialReport(
                trtllmVersion="0.0.0-test",
                platform=plat.platform(),
                pythonVersion=plat.python_version(),
                cpuArchitecture=plat.machine(),
                cpuCount=os.cpu_count() or 0,
                gpuCount=0,
                gpuName="",
                gpuMemoryMB=0,
                cudaVersion="",
                architectureClassName="TestModel",
                backend="pytorch",
                tensorParallelSize=1,
                pipelineParallelSize=1,
                contextParallelSize=1,
                moeExpertParallelSize=0,
                moeTensorParallelSize=0,
                dtype="float16",
                quantizationAlgo="",
                kvCacheDtype="",
                ingressPoint=context_value,
                featuresJson="{}",
                disaggRole="",
                deploymentId="",
            )
            payload = schema.build_gxt_payload(
                event=report,
                session_id=uuid.uuid4().hex,
                trtllm_version="0.0.0-test",
            )

            status = _post_to_staging(payload)
            assert status == 200, f"Staging rejected ingressPoint={context_value!r}: HTTP {status}"