diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 8dfea644a5..deab0c480c 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -4,6 +4,9 @@ import os import sys +import copy +import json +import traceback import logging from contextlib import nullcontext import torch @@ -21,6 +24,7 @@ Float8CurrentScalingQuantizer, MXFP8Quantizer, ) +from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.common.recipe import ( DelayedScaling, Float8CurrentScaling, @@ -209,10 +213,10 @@ def run_dpa_with_cp( os.environ["NVTE_FUSED_ATTN"] = "0" if kernel_backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" - config = model_configs_flash_attn[model] + config = copy.deepcopy(model_configs_flash_attn[model]) if kernel_backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - config = model_configs_fused_attn[model] + config = copy.deepcopy(model_configs_fused_attn[model]) assert config.attn_mask_type in [ "causal", "no_mask", @@ -223,23 +227,17 @@ def run_dpa_with_cp( else: config.attn_mask_type = "padding" - # set up distributed group - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - else: - device_count = torch.cuda.device_count() - device = rank % device_count - torch.cuda.set_device(device) + # Process group is managed by main(); one init/destroy per torchrun, not per config. + assert dist.is_initialized(), "dist.init_process_group must be called before run_dpa_with_cp" + world_size = dist.get_world_size() + rank = dist.get_rank() logging.info(f"[Rank {rank}] Setup: world_size {world_size}") - dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) # set up communication group for CP cp_comm_ranks = range(world_size) assert rank in cp_comm_ranks cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + cp_comm_sub_groups = [] if cp_comm_type == "a2a+p2p": assert world_size % 2 == 0, ( "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size" @@ -247,504 +245,524 @@ def run_dpa_with_cp( ) cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] - cp_comm_sub_groups = [] for sub_ranks in cp_comm_sub_ranks: sub_group = dist.new_group(sub_ranks, backend="nccl") if rank in sub_ranks: cp_comm_sub_groups.append(sub_group) - if dtype == "fp8": + try: + if dtype == "fp8": + if scaling_mode == "delayed": + fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "current": + fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "mxfp8": + fp8_recipe = MXFP8BlockScaling( + fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha + ) + + # instantiate attention module + core_attn = DotProductAttention( + config.num_heads, + (config.head_dim_qk, config.head_dim_v), + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=config.attn_mask_type, + window_size=config.window_size, + softmax_type=config.softmax_type, + return_max_logit=config.return_max_logit, + ).cuda() + if not is_training: + core_attn.eval() + if is_training and config.softmax_type != "vanilla": + core_attn.softmax_offset.requires_grad = True + + # generate attention inputs + ( + q_input_shape, + k_input_shape, + v_input_shape, + attn_output_shape, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) + q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + dout_orig = torch.clamp( + torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 + ).cuda() if scaling_mode == "delayed": - fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + qkv_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) + dout_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) if scaling_mode == "current": - fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + qkv_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + dout_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + device="cuda", + ) if scaling_mode == "mxfp8": - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) - - # instantiate attention module - core_attn = DotProductAttention( - config.num_heads, - (config.head_dim_qk, config.head_dim_v), - num_gqa_groups=config.num_gqa_groups, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=config.attn_mask_type, - window_size=config.window_size, - softmax_type=config.softmax_type, - return_max_logit=config.return_max_logit, - ).cuda() - if not is_training: - core_attn.eval() - if is_training and config.softmax_type != "vanilla": - core_attn.softmax_offset.requires_grad = True - - # generate attention inputs - ( - q_input_shape, - k_input_shape, - v_input_shape, - attn_output_shape, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) - q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() - k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() - v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() - dout_orig = torch.clamp( - torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 - ).cuda() - if scaling_mode == "delayed": - qkv_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - scale=torch.tensor([1], dtype=torch.float32).cuda(), - amax=torch.tensor([0], dtype=torch.float32).cuda(), - ) - dout_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - scale=torch.tensor([1], dtype=torch.float32).cuda(), - amax=torch.tensor([0], dtype=torch.float32).cuda(), - ) - if scaling_mode == "current": - qkv_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - device="cuda", - ) - dout_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - device="cuda", - ) - if scaling_mode == "mxfp8": - qkv_quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=True, - columnwise=True, - ) - qkv_quantizer.optimize_for_gemm = True - qkv_quantizer.internal = False - dout_quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - rowwise=True, - columnwise=True, - ) - dout_quantizer.optimize_for_gemm = True - dout_quantizer.internal = False - qkv_layout = "_".join([qkv_format] * 3) - q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] - if fp8_mha: - q, k, v, qkv_layout, _ = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) - for x in [q, k, v]: - x.requires_grad = True - - if config.attn_bias_type not in ["no_bias", "alibi"]: - bias_shape_map = { - "1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), - "11ss": (1, 1, config.max_seqlen_q, config.max_seqlen_kv), - "b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv), - "bhss": ( - config.batch_size, - config.num_heads, - config.max_seqlen_q, - config.max_seqlen_kv, - ), - "111s": (1, 1, 1, config.max_seqlen_kv), - } - attn_bias_shape = bias_shape_map.get(config.bias_shape) - if attn_bias_shape is None: - assert False, f"cuDNN does not support {config.bias_shape=}" - bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() - # cuDNN does not support dbias calculation for 111s as of cuDNN 9.18 - # TODO(KshitijLakhani): Set requires_grad to True for all shapes once 111s is supported - bias.requires_grad = True if config.bias_shape != "111s" else False - else: - bias = None - - ############ run without CP ############ - logging.info(f"[Rank {rank}] Run without context parallelism") - if dtype == "fp8": - fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) - else: - fp8_context = nullcontext() - max_logit = None - with fp8_context: - # q, k, v, out in FP8; dout in F16 - out = core_attn( - q, - k, - v, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, - ) - if config.return_max_logit: - out, max_logit = out - if is_training: - if fp8_bwd and fp8_mha: - dout_fp8 = dout_quantizer(dout) - out.backward(dout_fp8) - else: - out.backward(dout) - if is_training: - dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None - d_softmax_offset = ( - core_attn.softmax_offset.grad if config.softmax_type != "vanilla" else None - ) - else: - dq, dk, dv, dbias = None, None, None, None - d_softmax_offset = None - - ############ run with CP ############ - logging.info(f"[Rank {rank}] Run with context parallelism") - - # set up inputs - q_, k_, v_, dout_, *rest = [ - x.clone().detach() - for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias]) - ] - bias_ = rest[0] if len(rest) else None - if qkv_format == "bshd" or qkv_format == "sbhd": - seq_dim = qkv_format.index("s") - q_, k_, v_, dout_ = [ - x.view( - *x.shape[:seq_dim], - 2 * world_size, - x.shape[seq_dim] // (2 * world_size), - *x.shape[(seq_dim + 1) :], + qkv_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, ) - for x in [q_, k_, v_, dout_] - ] - seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device) - q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] - q_, k_, v_, dout_ = [ - x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_] - ] - elif qkv_format == "thd": - seq_idx_q = tex.thd_get_partitioned_indices( - cu_seqlens_q_padded, q_.shape[0], world_size, rank - ) - seq_idx_kv = tex.thd_get_partitioned_indices( - cu_seqlens_kv_padded, k_.shape[0], world_size, rank - ) - q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] - k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" - q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]] - if scaling_mode == "delayed": - qkv_quantizer.scale.fill_(1.0) - qkv_quantizer.amax.fill_(0.0) - dout_quantizer.scale.fill_(1.0) - dout_quantizer.amax.fill_(0.0) - if fp8_mha: - q_, k_, v_, qkv_layout, _ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) - if is_training: - q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] - if bias_ is not None: - ndim = bias_.ndim - seq_q_dim = ndim - 2 - if qkv_format == "thd": - bias_seq_idx = seq_idx_q + qkv_quantizer.optimize_for_gemm = True + qkv_quantizer.internal = False + dout_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + rowwise=True, + columnwise=True, + ) + dout_quantizer.optimize_for_gemm = True + dout_quantizer.internal = False + qkv_layout = "_".join([qkv_format] * 3) + q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] + if fp8_mha: + q, k, v, qkv_layout, _ = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + for x in [q, k, v]: + x.requires_grad = True + + if config.attn_bias_type not in ["no_bias", "alibi"]: + bias_shape_map = { + "1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), + "11ss": (1, 1, config.max_seqlen_q, config.max_seqlen_kv), + "b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv), + "bhss": ( + config.batch_size, + config.num_heads, + config.max_seqlen_q, + config.max_seqlen_kv, + ), + "111s": (1, 1, 1, config.max_seqlen_kv), + } + attn_bias_shape = bias_shape_map.get(config.bias_shape) + if attn_bias_shape is None: + assert False, f"cuDNN does not support {config.bias_shape=}" + bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() + # cuDNN does not support dbias calculation for 111s as of cuDNN 9.18 + # TODO(KshitijLakhani): Set requires_grad to True for all shapes once 111s is supported + bias.requires_grad = True if config.bias_shape != "111s" else False else: - bias_seq_idx = seq_idx - shape_before_seq = bias_.shape[:seq_q_dim] - seq_q_size = bias_.shape[seq_q_dim] - seq_kv_size = bias_.shape[-1] - if seq_q_size == 1: - # TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s - bias_.requires_grad = False - # Bias is broadcast, no need to partition along sequence dimension - pass + bias = None + + ############ run without CP ############ + logging.info(f"[Rank {rank}] Run without context parallelism") + if dtype == "fp8": + fp8_context = autocast( + enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group + ) else: - bias_ = bias_.view( - *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + fp8_context = nullcontext() + max_logit = None + with fp8_context: + # q, k, v, out in FP8; dout in F16 + out = core_attn( + q, + k, + v, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, ) - bias_ = bias_.index_select(seq_q_dim, bias_seq_idx) - bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size) - bias_.requires_grad = True - # set up environment - core_attn.set_context_parallel_group( - cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, - cp_comm_ranks, - torch.cuda.Stream(), - cp_comm_type, - ) - if config.softmax_type != "vanilla": - core_attn.softmax_offset.grad.zero_() - if dtype == "fp8": - core_attn.fp8_initialized = False - core_attn.fp8_meta_tensors_initialized = False - fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) - else: - fp8_context = nullcontext() - - # run attention - max_logit_ = None - with fp8_context: - # q, k, v, out in FP8; dout in F16 - out_ = core_attn( - q_, - k_, - v_, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias_, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, - ) - if config.return_max_logit: - out_, max_logit_ = out_ - if is_training: - if fp8_bwd and fp8_mha: - dout_fp8_ = dout_quantizer(dout_) - out_.backward(dout_fp8_) - else: - out_.backward(dout_) - if is_training: - dq_, dk_, dv_, dbias_ = ( - q_.grad, - k_.grad, - v_.grad, - bias_.grad if bias_ is not None else None, - ) - d_softmax_offset_ = ( - core_attn.softmax_offset.grad.clone() if config.softmax_type != "vanilla" else None - ) - else: - dq_, dk_, dv_, dbias_ = None, None, None, None - d_softmax_offset_ = None - - # get outputs - tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] - names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"] - if fp8_mha: - tensors_to_deq = [out, out_] if not fp8_bwd else tensors - for i, tensor in enumerate(tensors_to_deq): - # dbias/dbias_ could be None, so skip check for it - if tensor is not None: - tensors_to_deq[i] = tensor.dequantize() - if not fp8_bwd: - tensors[0], tensors[5] = tensors_to_deq - for i, tensor in enumerate(tensors): - # dbias/dbias_ could be None, so skip check for it - if tensor is not None: - assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" - assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" - out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors - - ############ compare results between CP and no-CP ############ - if qkv_format == "bshd" or qkv_format == "sbhd": + if config.return_max_logit: + out, max_logit = out + if is_training: + if fp8_bwd and fp8_mha: + dout_fp8 = dout_quantizer(dout) + out.backward(dout_fp8) + else: + out.backward(dout) if is_training: - dq, dk, dv, out = [ + dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None + d_softmax_offset = ( + core_attn.softmax_offset.grad if config.softmax_type != "vanilla" else None + ) + else: + dq, dk, dv, dbias = None, None, None, None + d_softmax_offset = None + + ############ run with CP ############ + logging.info(f"[Rank {rank}] Run with context parallelism") + + # set up inputs + q_, k_, v_, dout_, *rest = [ + x.clone().detach() + for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias]) + ] + bias_ = rest[0] if len(rest) else None + if qkv_format == "bshd" or qkv_format == "sbhd": + seq_dim = qkv_format.index("s") + q_, k_, v_, dout_ = [ x.view( *x.shape[:seq_dim], 2 * world_size, x.shape[seq_dim] // (2 * world_size), *x.shape[(seq_dim + 1) :], ) - for x in [dq, dk, dv, out] + for x in [q_, k_, v_, dout_] ] - dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] - dq_, dk_, dv_, out_ = [ - x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) - for x in [dq_, dk_, dv_, out_] + seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device) + q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] + q_, k_, v_, dout_ = [ + x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + for x in [q_, k_, v_, dout_] ] - if dbias is not None and dbias_ is not None: - ndim = dbias.ndim - # Query seq is at dim -2 - seq_q_dim = ndim - 2 - shape_before_seq = dbias.shape[:seq_q_dim] - seq_q_size = dbias.shape[seq_q_dim] - seq_kv_size = dbias.shape[-1] - # Reshape to split seq_q dimension - dbias = dbias.view( + elif qkv_format == "thd": + seq_idx_q = tex.thd_get_partitioned_indices( + cu_seqlens_q_padded, q_.shape[0], world_size, rank + ) + seq_idx_kv = tex.thd_get_partitioned_indices( + cu_seqlens_kv_padded, k_.shape[0], world_size, rank + ) + q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] + k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] + else: + assert False, f"{qkv_format} is an unsupported qkv_format!" + q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]] + if scaling_mode == "delayed": + qkv_quantizer.scale.fill_(1.0) + qkv_quantizer.amax.fill_(0.0) + dout_quantizer.scale.fill_(1.0) + dout_quantizer.amax.fill_(0.0) + if fp8_mha: + q_, k_, v_, qkv_layout, _ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) + if is_training: + q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] + if bias_ is not None: + ndim = bias_.ndim + seq_q_dim = ndim - 2 + if qkv_format == "thd": + bias_seq_idx = seq_idx_q + else: + bias_seq_idx = seq_idx + shape_before_seq = bias_.shape[:seq_q_dim] + seq_q_size = bias_.shape[seq_q_dim] + seq_kv_size = bias_.shape[-1] + if seq_q_size == 1: + # TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s + bias_.requires_grad = False + # Bias is broadcast, no need to partition along sequence dimension + pass + else: + bias_ = bias_.view( *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size ) - # Index select on the newly created dimension (now at position seq_q_dim) - dbias = dbias.index_select(seq_q_dim, seq_idx) - dbias_ = dbias_.view( - *shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size - ) - else: - # Forward-only: reshape only out/out_ for comparison - out = out.view( - *out.shape[:seq_dim], - 2 * world_size, - out.shape[seq_dim] // (2 * world_size), - *out.shape[(seq_dim + 1) :], + bias_ = bias_.index_select(seq_q_dim, bias_seq_idx) + bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size) + bias_.requires_grad = True + # set up environment + core_attn.set_context_parallel_group( + cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, + cp_comm_ranks, + torch.cuda.Stream(), + cp_comm_type, + ) + if config.softmax_type != "vanilla": + core_attn.softmax_offset.grad.zero_() + if dtype == "fp8": + core_attn.fp8_initialized = False + core_attn.fp8_meta_tensors_initialized = False + fp8_context = autocast( + enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group ) - out = out.index_select(seq_dim, seq_idx) - out_ = out_.view( - *out_.shape[:seq_dim], 2, out_.shape[seq_dim] // 2, *out_.shape[(seq_dim + 1) :] + else: + fp8_context = nullcontext() + + # run attention + max_logit_ = None + with fp8_context: + # q, k, v, out in FP8; dout in F16 + out_ = core_attn( + q_, + k_, + v_, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias_, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, ) - - elif qkv_format == "thd": + if config.return_max_logit: + out_, max_logit_ = out_ + if is_training: + if fp8_bwd and fp8_mha: + dout_fp8_ = dout_quantizer(dout_) + out_.backward(dout_fp8_) + else: + out_.backward(dout_) if is_training: - dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] - dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded // world_size - cu_seqlens_q = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True + dq_, dk_, dv_, dbias_ = ( + q_.grad, + k_.grad, + v_.grad, + bias_.grad if bias_ is not None else None, ) - cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q - num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] - for x in [dq, out, dq_, out_]: - assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_q[b] == 0 - or torch.count_nonzero( - x[ - (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ - b + 1 - ] - ] - ).item() - == 0 - ) - cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size - cu_seqlens_kv = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True + d_softmax_offset_ = ( + core_attn.softmax_offset.grad.clone() if config.softmax_type != "vanilla" else None ) - cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv - num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] - for x in [dk, dv, dk_, dv_]: - assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_kv[b] == 0 - or torch.count_nonzero( - x[ - ( - cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] - ) : cu_seqlens_kv_padded[b + 1] - ] - ).item() - == 0 - ) else: - # Forward-only: reshape only out/out_ for comparison - out = out.index_select(0, seq_idx_q).contiguous() - out_ = out_ - - atol, rtol, rmse_tol = get_tols(config, dtype) - tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] - tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit] - names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"] - names_cp = [x + "_cp" for x in names] - names_no_cp = [x + "_no_cp" for x in names] - is_fp8 = dtype == "fp8" - for i, t in enumerate(tensors_no_cp): - if t is not None: - if "softmax_offset" not in names[i] and "max_logit" not in names[i]: - if qkv_format == "bshd": - # Compare the two sequence chunks separately - # Compare dbias - if names[i] == "dbias": - # Compare the two chunks along dimension 2 (the split sequence dimension) - seq_q_dim_bias = 2 - ndim_bias = t.ndim - slice_0 = [slice(None)] * ndim_bias - slice_0[seq_q_dim_bias] = 0 - slice_1 = [slice(None)] * ndim_bias - slice_1[seq_q_dim_bias] = 1 - compare_and_assert( - t[tuple(slice_0)], - tensors_cp[i][tuple(slice_0)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[tuple(slice_1)], - tensors_cp[i][tuple(slice_1)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - # Compare Q/K/V/out - else: - # Compare the two chunks along dimension 1 (the split sequence dimension) - compare_and_assert( - t[:, 0], - tensors_cp[i][:, 0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[:, 1], - tensors_cp[i][:, 1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - elif qkv_format == "sbhd": - # Compare the two sequence chunks separately - # Compare dbias (same as BSHD) - if names[i] == "dbias": - # Same as bshd: Compare the two chunks along dimension 2 (the split sequence dimension) - seq_q_dim_bias = 2 - ndim_bias = t.ndim - slice_0 = [slice(None)] * ndim_bias - slice_0[seq_q_dim_bias] = 0 - slice_1 = [slice(None)] * ndim_bias - slice_1[seq_q_dim_bias] = 1 - compare_and_assert( - t[tuple(slice_0)], - tensors_cp[i][tuple(slice_0)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[tuple(slice_1)], - tensors_cp[i][tuple(slice_1)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, + dq_, dk_, dv_, dbias_ = None, None, None, None + d_softmax_offset_ = None + + # get outputs + tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] + names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"] + if fp8_mha: + tensors_to_deq = [out, out_] if not fp8_bwd else tensors + for i, tensor in enumerate(tensors_to_deq): + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + tensors_to_deq[i] = tensor.dequantize() + if not fp8_bwd: + tensors[0], tensors[5] = tensors_to_deq + for i, tensor in enumerate(tensors): + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" + assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" + out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors + + ############ compare results between CP and no-CP ############ + if qkv_format == "bshd" or qkv_format == "sbhd": + if is_training: + dq, dk, dv, out = [ + x.view( + *x.shape[:seq_dim], + 2 * world_size, + x.shape[seq_dim] // (2 * world_size), + *x.shape[(seq_dim + 1) :], + ) + for x in [dq, dk, dv, out] + ] + dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] + dq_, dk_, dv_, out_ = [ + x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) + for x in [dq_, dk_, dv_, out_] + ] + if dbias is not None and dbias_ is not None: + ndim = dbias.ndim + # Query seq is at dim -2 + seq_q_dim = ndim - 2 + shape_before_seq = dbias.shape[:seq_q_dim] + seq_q_size = dbias.shape[seq_q_dim] + seq_kv_size = dbias.shape[-1] + # Reshape to split seq_q dimension + dbias = dbias.view( + *shape_before_seq, + 2 * world_size, + seq_q_size // (2 * world_size), + seq_kv_size, + ) + # Index select on the newly created dimension (now at position seq_q_dim) + dbias = dbias.index_select(seq_q_dim, seq_idx) + dbias_ = dbias_.view( + *shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size + ) + else: + # Forward-only: reshape only out/out_ for comparison + out = out.view( + *out.shape[:seq_dim], + 2 * world_size, + out.shape[seq_dim] // (2 * world_size), + *out.shape[(seq_dim + 1) :], + ) + out = out.index_select(seq_dim, seq_idx) + out_ = out_.view( + *out_.shape[:seq_dim], 2, out_.shape[seq_dim] // 2, *out_.shape[(seq_dim + 1) :] + ) + + elif qkv_format == "thd": + if is_training: + dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] + dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] + cu_seqlens_q_padded = cu_seqlens_q_padded // world_size + cu_seqlens_q = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True + ) + cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q + num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] + for x in [dq, out, dq_, out_]: + assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_q[b] == 0 + or torch.count_nonzero( + x[ + ( + cu_seqlens_q_padded[b + 1] - num_pads_q[b] + ) : cu_seqlens_q_padded[b + 1] + ] + ).item() + == 0 ) - # Compare Q/K/V/out - else: - # Compare the two chunks along dimension 0 (the split sequence dimension) - compare_and_assert( - t[0], - tensors_cp[i][0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, + cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size + cu_seqlens_kv = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True + ) + cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv + num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] + for x in [dk, dv, dk_, dv_]: + assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_kv[b] == 0 + or torch.count_nonzero( + x[ + ( + cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] + ) : cu_seqlens_kv_padded[b + 1] + ] + ).item() + == 0 ) + else: + out = out.index_select(0, seq_idx_q).contiguous() + out_ = out_ + + atol, rtol, rmse_tol = get_tols(config, dtype) + tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] + tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit] + names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"] + names_cp = [x + "_cp" for x in names] + names_no_cp = [x + "_no_cp" for x in names] + is_fp8 = dtype == "fp8" + for i, t in enumerate(tensors_no_cp): + if t is not None: + if "softmax_offset" not in names[i] and "max_logit" not in names[i]: + if qkv_format == "bshd": + # Compare the two sequence chunks separately + # Compare dbias + if names[i] == "dbias": + # Compare the two chunks along dimension 2 (the split sequence dimension) + seq_q_dim_bias = 2 + ndim_bias = t.ndim + slice_0 = [slice(None)] * ndim_bias + slice_0[seq_q_dim_bias] = 0 + slice_1 = [slice(None)] * ndim_bias + slice_1[seq_q_dim_bias] = 1 + compare_and_assert( + t[tuple(slice_0)], + tensors_cp[i][tuple(slice_0)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[tuple(slice_1)], + tensors_cp[i][tuple(slice_1)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare the two chunks along dimension 1 (the split sequence dimension) + compare_and_assert( + t[:, 0], + tensors_cp[i][:, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, 1], + tensors_cp[i][:, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "sbhd": + # Compare the two sequence chunks separately + # Compare dbias (same as BSHD) + if names[i] == "dbias": + # Same as bshd: Compare the two chunks along dimension 2 (the split sequence dimension) + seq_q_dim_bias = 2 + ndim_bias = t.ndim + slice_0 = [slice(None)] * ndim_bias + slice_0[seq_q_dim_bias] = 0 + slice_1 = [slice(None)] * ndim_bias + slice_1[seq_q_dim_bias] = 1 + compare_and_assert( + t[tuple(slice_0)], + tensors_cp[i][tuple(slice_0)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[tuple(slice_1)], + tensors_cp[i][tuple(slice_1)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare the two chunks along dimension 0 (the split sequence dimension) + compare_and_assert( + t[0], + tensors_cp[i][0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[1], + tensors_cp[i][1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "thd": compare_and_assert( - t[1], - tensors_cp[i][1], + t, + tensors_cp[i], names_no_cp[i], names_cp[i], atol, @@ -752,24 +770,124 @@ def run_dpa_with_cp( rmse_tol, is_fp8, ) - elif qkv_format == "thd": + else: compare_and_assert( t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 ) - else: - compare_and_assert( - t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 - ) - logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") + logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") + finally: + # Destroy per-config groups so they survive exceptions in batch mode. + # The global process group is torn down by main(). + dist.destroy_process_group(cp_comm_group) + for sg in cp_comm_sub_groups: + dist.destroy_process_group(sg) + + +# Env vars set by run_dpa_with_cp; cleared between batch configs to prevent leakage. +_TRANSIENT_ENV_KEYS = ( + "NVTE_FP8_DPA_BWD", + "NVTE_DPA_FP8CS_O_in_F16", + "NVTE_FLASH_ATTN", + "NVTE_FUSED_ATTN", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO", +) + - # destroy distribution group - dist.destroy_process_group() +def _init_distributed(): + """Init NCCL process group + CUDA device once per torchrun invocation.""" + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + device_count = torch.cuda.device_count() + # Prefer LOCAL_RANK when available (set by torchrun / torch.distributed.launch); + # fall back to RANK % device_count for single-node runs. + local_rank = int(os.getenv("LOCAL_RANK", str(rank % device_count))) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + return rank, world_size + + +def _run_single_config(kwargs): + """Run one config, return ``(ok, error_message)``. + + Re-seeds RNG before each config so results are deterministic and + order-independent within a batch. + """ + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + try: + run_dpa_with_cp(**kwargs) + return True, None + except BaseException: # noqa: BLE001 - capture any failure for per-config reporting + return False, traceback.format_exc() def main(**kwargs): - run_dpa_with_cp(**kwargs) + """Entry point. + + Two modes: + + * Single-config (legacy): ``run_attention_with_cp.py key=val ...`` runs + one config, propagates exceptions for normal exit-code signalling. + * Batch: ``run_attention_with_cp.py batch_config_json=`` reads a + JSON list of kwargs dicts, runs each via ``_run_single_config``, + aggregates ``ok`` across ranks (any rank failure → False), and flushes + ``[{ok,error}, ...]`` atomically to ``.results.json`` after each + config so a worker crash mid-batch leaves earlier results intact. + Transient env vars are reset between configs; per-config NCCL groups + are torn down inside ``run_dpa_with_cp``. + """ + batch_path = kwargs.pop("batch_config_json", None) + rank, _ = _init_distributed() + try: + if batch_path is None: + run_dpa_with_cp(**kwargs) + else: + with open(batch_path, "r") as f: + configs = json.load(f) + assert isinstance( + configs, list + ), f"batch_config_json must be a JSON list, got {type(configs)}" + results_path = batch_path + ".results.json" + results = [] + + def _flush_results(): + if rank != 0: + return + # Atomic write: tmp + rename so the reader never sees partial JSON. + tmp_path = results_path + ".tmp" + with open(tmp_path, "w") as f: + json.dump(results, f) + os.replace(tmp_path, results_path) + + for cfg in configs: + FP8GlobalStateManager.reset() + for env_key in _TRANSIENT_ENV_KEYS: + os.environ.pop(env_key, None) + ok, err = _run_single_config(cfg) + # Aggregate ok across ranks so a non-rank-0 failure (e.g. a + # per-partition compare assertion that fires only on rank > 0) + # is not silently swallowed when only rank 0 writes the result. + ok_tensor = torch.tensor(1 if ok else 0, dtype=torch.int32, device="cuda") + dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN) + ok_aggregate = bool(ok_tensor.item()) + if not ok_aggregate and ok and err is None: + err = "Failed on a non-zero rank (see subprocess stderr for traceback)" + results.append({"ok": ok_aggregate, "error": err}) + _flush_results() + try: + dist.barrier() + except BaseException: # noqa: BLE001 + results[-1]["ok"] = False + if results[-1]["error"] is None: + results[-1]["error"] = traceback.format_exc() + _flush_results() + break + torch.cuda.empty_cache() + finally: + if dist.is_initialized(): + dist.destroy_process_group() if __name__ == "__main__": - kwargs = dict(arg.split("=") for arg in sys.argv[2:]) + kwargs = dict(arg.split("=", 1) for arg in sys.argv[2:]) main(**kwargs) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 23d1bfdd85..176b36b7d8 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -3,8 +3,9 @@ # See LICENSE for license information. import os -import subprocess +import json import sys +import tempfile import pathlib import logging import copy @@ -41,6 +42,8 @@ test_essential = True +_BATCH_SIZE = int(os.getenv("CP_TEST_BATCH_SIZE", "16")) + model_configs_flash_attn = { # test: ModelConfig(b, sq, hq, dqk) "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA @@ -67,6 +70,8 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): "torch.distributed.launch", "--nproc-per-node=" + str(num_gpus_per_node), ] + if "MASTER_PORT" in os.environ: + args.append("--master-port=" + os.environ["MASTER_PORT"]) te_path = os.getenv("TE_PATH", "/opt/transformerengine") script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py") args.append(script_path) @@ -75,6 +80,276 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): return args +# --------------------------------------------------------------------------- +# Batched dispatch — keeps test bodies identical to the non-batched flow +# (parametrize stack + inline ``pytest.skip(...)``) and replaces only the +# final ``run_distributed(...)`` call with ``_run_or_fetch(...)``. +# +# Flow: +# 1. Collect (dry-run, in-process). Session fixture ``_cp_batch_results`` +# walks each parametrized item that requests this fixture, calls it +# with a stub ``request`` (only ``request.node.nodeid``). The body runs +# its ``pytest.skip(...)`` checks; if none fire, ``_run_or_fetch`` +# records the kwargs in ``_COLLECTED_KWARGS`` instead of launching +# torchrun. ``@pytest.mark.skip(if)`` markers are evaluated up front +# via ``_item_static_skip`` so marker-skipped items aren't queued. +# 2. Batch + execute. Recorded kwargs are grouped by num_gpus_per_node, +# chunked into batches of CP_TEST_BATCH_SIZE (default 16), and each +# batch runs in one torchrun (``_run_one_batch``). Worker +# (run_attention_with_cp.py) inits NCCL once, loops over configs, +# flushes per-config results to ``.results.json`` atomically. +# 3. Execute mode (normal pytest run). The test body re-evaluates its +# skip checks; if none fire, ``_run_or_fetch`` looks up the recorded +# result by nodeid and asserts pass/fail. +# +# Failure handling: +# - Inline ``pytest.skip``: same code path as non-batched. +# - Worker assertion: surfaced as ``AssertionError`` from the JSON entry. +# - Per-rank failure: cross-rank ``dist.all_reduce(ok, op=MIN)`` in the +# worker so a rank > 0 assertion isn't swallowed by the rank-0-only flush. +# - Worker subprocess crash mid-batch: configs without flushed results are +# marked unattributed and ``_run_one_batch`` retries each as a singleton +# to identify the actual culprit. Disable via ``CP_TEST_BATCH_RETRY=0``. +# - Dry-run exception: caught; the same error fires in execute mode and +# pytest reports it as a normal test ERROR (no fixture-level cascade). +# +# To add a new batched test: write it like a non-batched CP test +# (parametrize + inline ``pytest.skip(...)``), accept +# ``request, _cp_batch_results`` as fixtures, and replace +# ``run_distributed(get_bash_arguments(...))`` with +# ``_run_or_fetch(request, _cp_batch_results, num_gpus_per_node=N, ...)``. +# Nothing else needs wiring up. +# +# Knobs: +# CP_TEST_BATCH_SIZE=N configs per torchrun; default 16; set 1 to bisect. +# CP_TEST_BATCH_RETRY=0 skip the singleton retry on unattributed crashes. +# +# Caveats: +# - ``pytest -k`` reduces what's collected and therefore what's batched. +# - The body executes once per item in collect mode (cheap; only Python +# skip logic + ``get_available_attention_backends``). +# - Mutations to module-level state during the body persist between collect +# and execute. Worker uses ``copy.deepcopy(model_configs_*[model])`` so +# ``run_dpa_with_cp`` mutations don't leak across configs in a batch. +# --------------------------------------------------------------------------- + +# Module-level state used by the session fixture's collect phase. +_COLLECT_MODE = False +_COLLECTED_KWARGS = {} # nodeid -> kwargs dict (populated in collect mode) +_BACKEND_CACHE = ( + {} +) # nodeid -> (fused_attn_supported, fused_attn_backends) or (flash_attn_supported,) + + +def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs): + """Drop-in replacement for ``run_distributed(get_bash_arguments(...))``. + + In *collect mode* (during the session fixture's first pass), records this + test's kwargs so the fixture can batch them. In *execute mode* (the normal + test run), looks up the pre-computed result and either passes, fails, or + skips. + """ + if _COLLECT_MODE: + _COLLECTED_KWARGS[request.node.nodeid] = dict(num_gpus=num_gpus_per_node, **worker_kwargs) + return + entry = batch_results.get(request.node.nodeid) + if entry is None: + pytest.skip("No batched result recorded (collection mismatch).") + if not entry.get("ok", False): + raise AssertionError(entry.get("error") or "Batched config failed (no error captured)") + + +def _run_batch_once(num_gpus, configs): + """Launch one torchrun that runs *configs* sequentially inside one NCCL world. + + Returns a list of ``{"ok": bool, "error": str|None}`` dicts, one per config. + Missing entries (subprocess crashed mid-batch) are synthesized as failures. + """ + # Stringify values: run_dpa_with_cp uses ``== "True"`` string comparisons. + # Strip ``num_gpus`` (launcher-only, not a worker kwarg). + worker_kwargs = [{k: str(v) for k, v in cfg.items() if k != "num_gpus"} for cfg in configs] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".cp_batch.json", delete=False) as fh: + batch_path = fh.name + json.dump(worker_kwargs, fh) + results_path = batch_path + ".results.json" + + try: + argv = get_bash_arguments(num_gpus_per_node=num_gpus, batch_config_json=batch_path) + launch_err = None + try: + run_distributed(argv) + except Exception as exc: + # Catch broadly: subprocess.run can raise OSError/FileNotFoundError + # in addition to the AssertionError that run_distributed wraps a + # non-zero exit in. Letting any of these propagate would tear down + # the session fixture and ERROR every batched test instead of + # marking just this batch's configs as failed. + launch_err = str(exc) or repr(exc) + + try: + with open(results_path, "r") as f: + per_cfg = json.load(f) + except (OSError, json.JSONDecodeError): + per_cfg = [] + + results = [] + for i in range(len(configs)): + if i < len(per_cfg): + results.append(per_cfg[i]) + else: + results.append( + { + "ok": False, + "error": launch_err or "Subprocess exited before this config ran.", + "_unattributed": True, + } + ) + return results + finally: + for p in (batch_path, results_path): + try: + os.unlink(p) + except OSError: + pass + + +def _run_one_batch(num_gpus, configs): + """Run a batch, then retry any unattributed-crash entries as singletons. + + When the worker subprocess crashes (segfault / NCCL hang / OOM) before it + can flush a per-config result, every config past the crash gets the + generic "Subprocess exited before this config ran" marker and we don't + know which config is the actual culprit. Re-running each marked config in + its own torchrun pinpoints which one crashes on its own (and salvages a + real result for the ones that ran fine but were caught downstream of the + crash). + + Disable via ``CP_TEST_BATCH_RETRY=0`` — useful if the singleton retries + themselves are taking too long on a flaky cluster. + """ + results = _run_batch_once(num_gpus, configs) + if len(configs) <= 1 or not int(os.getenv("CP_TEST_BATCH_RETRY", "1")): + for r in results: + r.pop("_unattributed", None) + return results + for i, r in enumerate(results): + if r.pop("_unattributed", False): + results[i] = _run_batch_once(num_gpus, [configs[i]])[0] + results[i].pop("_unattributed", None) + return results + + +class _DummyRequest: + """Stand-in for the ``request`` fixture during the dry-run phase. + + The test body only touches ``request.node.nodeid``, so this is enough. + """ + + def __init__(self, nodeid): + self.node = type("_DummyNode", (), {"nodeid": nodeid})() + + +def _item_static_skip(item): + """Return True if pytest's static skip/skipif markers would skip *item*. + + These markers are evaluated by pytest at runtime, before the test body + runs. The dry-run calls ``item.function(...)`` directly and would bypass + them — we replicate the check here so a marker-skipped test isn't queued + for torchrun unnecessarily. + """ + for marker in item.iter_markers("skip"): + return True + for marker in item.iter_markers("skipif"): + cond = marker.args[0] if marker.args else marker.kwargs.get("condition") + if cond: + return True + return False + + +def _dry_run_item(item): + """Invoke a parametrized test body in collect mode. + + Raises ``pytest.skip.Exception`` if the body skips, otherwise returns + after ``_run_or_fetch`` has stashed the kwargs in ``_COLLECTED_KWARGS``. + """ + func = item.function + params = dict(item.callspec.params) + func(_DummyRequest(item.nodeid), {}, **params) + + +@pytest.fixture(scope="session") +def _cp_batch_results(request): + """Run all batched test bodies once in collect mode, then run torchrun batches. + + Skips are NOT tracked here — the test body raises ``pytest.skip(...)`` in + both collect and execute mode, so skipped tests never reach ``_run_or_fetch`` + and don't need an entry in the result map. + """ + global _COLLECT_MODE + + items = [ + it for it in request.session.items if "_cp_batch_results" in getattr(it, "fixturenames", ()) + ] + + import time as _time + + _COLLECTED_KWARGS.clear() + _COLLECT_MODE = True + _t0 = _time.monotonic() + try: + for item in items: + if _item_static_skip(item): + continue # pytest will skip this at runtime; don't queue for torchrun + try: + _dry_run_item(item) + except pytest.skip.Exception: + pass # the same pytest.skip will fire again in execute mode + except BaseException: # noqa: BLE001 + # Don't let a single bad item kill the whole session fixture — + # pytest will re-raise the same error in execute mode and the + # failure will surface there as a normal test ERROR. + pass + finally: + _COLLECT_MODE = False + print( + f"\n[CP-BATCH] Collect done: {len(_COLLECTED_KWARGS)} configs from" + f" {len(items)} items in {_time.monotonic() - _t0:.1f}s", + flush=True, + ) + + by_num_gpus = {} + for nodeid, kwargs in _COLLECTED_KWARGS.items(): + num_gpus = kwargs.pop("num_gpus") + by_num_gpus.setdefault(num_gpus, []).append((nodeid, kwargs)) + + results = {} + for num_gpus, entries in by_num_gpus.items(): + n_batches = (len(entries) + _BATCH_SIZE - 1) // _BATCH_SIZE + for batch_idx, start in enumerate(range(0, len(entries), _BATCH_SIZE)): + chunk = entries[start : start + _BATCH_SIZE] + print( + f"[CP-BATCH] Running batch {batch_idx + 1}/{n_batches}" + f" ({len(chunk)} cfgs, {num_gpus} GPUs)...", + flush=True, + ) + _bt = _time.monotonic() + chunk_results = _run_one_batch(num_gpus, [kw for _, kw in chunk]) + ok = sum(1 for r in chunk_results if r.get("ok")) + print( + f"[CP-BATCH] => {ok}/{len(chunk)} passed in {_time.monotonic() - _bt:.1f}s", + flush=True, + ) + for (nodeid, _), res in zip(chunk, chunk_results): + results[nodeid] = res + print( + f"[CP-BATCH] All batches done: {len(results)} results" + f" in {_time.monotonic() - _t0:.1f}s total", + flush=True, + ) + return results + + dtypes = ["bf16", "fp16"] qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] @@ -91,7 +366,9 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) -def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): +def test_cp_with_flash_attention( + request, _cp_batch_results, dtype, model, qkv_format, cp_comm_type +): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") @@ -131,25 +408,30 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} - available_backends, *_ = get_available_attention_backends( - config, - qkv_dtype=dtypes[dtype], - qkv_layout="_".join([qkv_format] * 3), - ) - flash_attn_supported, *_ = available_backends + nodeid = request.node.nodeid + if nodeid in _BACKEND_CACHE: + flash_attn_supported = _BACKEND_CACHE[nodeid] + else: + available_backends, *_ = get_available_attention_backends( + config, + qkv_dtype=dtypes[dtype], + qkv_layout="_".join([qkv_format] * 3), + ) + flash_attn_supported, *_ = available_backends + _BACKEND_CACHE[nodeid] = flash_attn_supported if not flash_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FlashAttention", - cp_comm_type=cp_comm_type, - log_level=pytest_logging_level, - ), + _run_or_fetch( + request, + _cp_batch_results, + num_gpus_per_node=num_gpus, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FlashAttention", + cp_comm_type=cp_comm_type, + log_level=pytest_logging_level, ) @@ -274,7 +556,17 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) @pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( - dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O + request, + _cp_batch_results, + dtype, + model, + qkv_format, + cp_comm_type, + fp8_bwd, + fp8_mha, + fp8_dpa, + scaling_mode, + f16_O, ): config = model_configs_fused_attn[model] config.context_parallel = True @@ -377,22 +669,12 @@ def test_cp_with_fused_attention( # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. is_training = False if config.bias_shape == "111s" else True - available_backends, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, - qkv_layout="_".join([qkv_format] * 3), - fp8=fp8, - fp8_meta=fp8_meta, - is_training=is_training, - deterministic=_deterministic, - ) - _, fused_attn_supported, _ = available_backends - if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: - config_copy = copy.deepcopy(config) - config_copy.context_parallel = False - config_copy.attn_mask_type = config.attn_mask_type + "_bottom_right" + nodeid = request.node.nodeid + if nodeid in _BACKEND_CACHE: + fused_attn_supported = _BACKEND_CACHE[nodeid] + else: available_backends, _, fused_attn_backends = get_available_attention_backends( - config_copy, + config, qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, qkv_layout="_".join([qkv_format] * 3), fp8=fp8, @@ -400,25 +682,46 @@ def test_cp_with_fused_attention( is_training=is_training, deterministic=_deterministic, ) + _, fused_attn_supported, _ = available_backends + if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: + config_copy = copy.deepcopy(config) + config_copy.context_parallel = False + config_copy.attn_mask_type = config.attn_mask_type + "_bottom_right" + available_backends, _, fused_attn_backends = get_available_attention_backends( + config_copy, + qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, + qkv_layout="_".join([qkv_format] * 3), + fp8=fp8, + fp8_meta=fp8_meta, + is_training=is_training, + deterministic=_deterministic, + ) + _, fused_attn_supported, _ = available_backends + _BACKEND_CACHE[nodeid] = fused_attn_supported if not fused_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FusedAttention", - cp_comm_type=cp_comm_type, - fp8_bwd=fp8_bwd, - fp8_dpa=fp8_dpa, - fp8_mha=fp8_mha, - scaling_mode=scaling_mode, - f16_O=f16_O, - is_training=is_training, - deterministic=_deterministic, - log_level=pytest_logging_level, - ), + if _deterministic and config.softmax_type != "vanilla": + pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention") + if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training: + pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") + + _run_or_fetch( + request, + _cp_batch_results, + num_gpus_per_node=num_gpus, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FusedAttention", + cp_comm_type=cp_comm_type, + fp8_bwd=fp8_bwd, + fp8_dpa=fp8_dpa, + fp8_mha=fp8_mha, + scaling_mode=scaling_mode, + f16_O=f16_O, + is_training=is_training, + deterministic=_deterministic, + log_level=pytest_logging_level, )