Skip to content

Commit 197f2e6

Browse files
committed
Change varname
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent f009c89 commit 197f2e6

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

tests/pytorch/test_sanity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ def test_sanity_grouped_linear(
608608
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
609609

610610
if single_param:
611-
os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"] = "1"
611+
os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"
612612

613613
if fp8_recipe is not None:
614614
if not is_fp8_supported(config):
@@ -650,7 +650,7 @@ def test_sanity_grouped_linear(
650650
assert out.shape == (num_tokens, ffn_hidden_size)
651651

652652
if single_param:
653-
del os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"]
653+
del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]
654654

655655

656656
@pytest.mark.parametrize("dtype", param_types)

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ def make_grouped_weights(self, defer_init=False) -> None:
795795
def reset_parameters(self, defer_init=False):
796796
super().reset_parameters(defer_init=defer_init)
797797
# Grouped tensor weights is an opt-in feature.
798-
if bool(int(os.getenv("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0"))):
798+
if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))):
799799
self.make_grouped_weights(defer_init=defer_init)
800800

801801
def set_tensor_parallel_attributes(self, defer_init=False) -> None:

0 commit comments

Comments
 (0)