Skip to content

Commit c68ec31

Browse files
authored
Add fast_set_attr to modules not inheriting from base.py (NVIDIA#2724)
Fix fast_set_attr in other nn modules for FSDP. Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
1 parent 9dac78e commit c68ec31

5 files changed

Lines changed: 24 additions & 4 deletions

File tree

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,10 @@ def mask_func(x, y):
293293
bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
294294
)
295295

296+
def fast_setattr(self, name: str, value: Any) -> None:
297+
"""Fast attribute set for non-parameter fields."""
298+
self.__dict__[name] = value
299+
296300
def forward(
297301
self,
298302
_alibi_cache: Dict[str, Any],

transformer_engine/pytorch/attention/multi_head_attention.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""Multi-head Attention."""
66
import os
77
import collections
8-
from typing import Callable, List, Optional, Tuple, Union
8+
from typing import Any, Callable, List, Optional, Tuple, Union
99
import torch
1010

1111
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
@@ -478,6 +478,10 @@ def __init__(
478478
**common_gemm_kwargs,
479479
)
480480

481+
def fast_setattr(self, name: str, value: Any) -> None:
482+
"""Fast attribute set for non-parameter fields."""
483+
self.__dict__[name] = value
484+
481485
def _create_qk_norm_modules(
482486
self,
483487
qk_norm_type: Optional[str],

transformer_engine/pytorch/module/layernorm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
"""LayerNorm API"""
66
import warnings
7-
from typing import Iterable, Optional, Union
7+
from typing import Any, Iterable, Optional, Union
88

99
import torch
1010

@@ -102,6 +102,10 @@ def __init__(
102102
**kwargs,
103103
)
104104

105+
def fast_setattr(self, name: str, value: Any) -> None:
106+
"""Fast attribute set for non-parameter fields."""
107+
self.__dict__[name] = value
108+
105109
def reset_layer_norm_parameters(self) -> None:
106110
"""Init LN params"""
107111
warnings.warn(

transformer_engine/pytorch/module/rmsnorm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
"""RMSNorm API"""
66
import warnings
7-
from typing import Iterable, Optional, Union
7+
from typing import Any, Iterable, Optional, Union
88

99
import torch
1010

@@ -106,6 +106,10 @@ def __init__(
106106
**kwargs,
107107
)
108108

109+
def fast_setattr(self, name: str, value: Any) -> None:
110+
"""Fast attribute set for non-parameter fields."""
111+
self.__dict__[name] = value
112+
109113
def reset_rms_norm_parameters(self) -> None:
110114
"""Deprecated"""
111115
warnings.warn(

transformer_engine/pytorch/transformer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
import warnings
88
from contextlib import nullcontext
9-
from typing import Callable, List, Optional, Tuple, Union
9+
from typing import Any, Callable, List, Optional, Tuple, Union
1010

1111
import torch
1212

@@ -545,6 +545,10 @@ def __init__(
545545
device=device,
546546
)
547547

548+
def fast_setattr(self, name: str, value: Any) -> None:
549+
"""Fast attribute set for non-parameter fields."""
550+
self.__dict__[name] = value
551+
548552
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
549553
"""
550554
Set the tensor parallel group for the given

0 commit comments

Comments (0)