Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions megatron/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from functools import lru_cache, reduce, wraps
from importlib.metadata import version
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy
import torch
Expand Down Expand Up @@ -83,6 +83,10 @@
_causal_conv1d_version = None


_Wrapped = TypeVar('_Wrapped', bound=Callable)
"""A function or class which has been wrapped by a decorator."""


@contextmanager
def null_decorator(*args, **kwargs):
"""
Expand Down Expand Up @@ -120,7 +124,7 @@ def experimental_fn(introduced_with_version: str):
"""
logged_functions = set()

def validator(func: Callable, max_lifetime: int = 3) -> Callable:
def validator(func: _Wrapped, max_lifetime: int = 3) -> _Wrapped:
"""Validates the request to the experimental function.

Args:
Expand Down Expand Up @@ -186,7 +190,7 @@ def experimental_cls(introduced_with_version: str):
"""
logged_classes = set()

def validator(cls: Callable, max_lifetime: int = 3) -> Callable:
def validator(cls: _Wrapped, max_lifetime: int = 3) -> _Wrapped:
"""Validates the request to the experimental function.

Args:
Expand Down Expand Up @@ -2209,7 +2213,9 @@ def _nvtx_decorator_get_func_path(func):
return f"{module.__name__}.{caller_func}"


def nvtx_decorator(message: Optional[str] = None, color: Optional[str] = None):
def nvtx_decorator(
message: Optional[str] = None, color: Optional[str] = None
) -> Callable[[_Wrapped], _Wrapped]:
"""Decorator to add NVTX range to a function.

Args:
Expand All @@ -2229,7 +2235,7 @@ def another_function():
pass
"""

def decorator(func: Callable) -> Callable:
def decorator(func: _Wrapped) -> _Wrapped:
if _nvtx_enabled:
return nvtx.annotate(
message=message or _nvtx_decorator_get_func_path(func), color=color
Expand Down Expand Up @@ -2372,7 +2378,7 @@ def deprecated(
removal_version: Optional[str] = None,
alternative: Optional[str] = None,
reason: Optional[str] = None,
) -> Callable:
) -> Callable[[_Wrapped], _Wrapped]:
"""
Mark a function as deprecated.

Expand Down Expand Up @@ -2401,7 +2407,7 @@ def old_train_model(config):
pass
"""

def decorator(func: Callable) -> Callable:
def decorator(func: _Wrapped) -> _Wrapped:
# Add metadata
func._deprecated = True
func._deprecated_version = version
Expand Down Expand Up @@ -2432,7 +2438,7 @@ def wrapper(*args, **kwargs):
return decorator


def internal_api(func: Callable) -> Callable:
def internal_api(func: _Wrapped) -> _Wrapped:
"""
Mark a function or class as internal API (not for external use).

Expand Down Expand Up @@ -2465,7 +2471,7 @@ class ExperimentalFeature:
return func


def experimental_api(func: Callable) -> Callable:
def experimental_api(func: _Wrapped) -> _Wrapped:
"""
Mark a function or class as experimental API.

Expand Down Expand Up @@ -2499,8 +2505,8 @@ class ExperimentalModel:


def deprecate_args(
*deprecated_keys, message="Argument '{name}' has been deprecated and should not be used."
):
*deprecated_keys: str, message="Argument '{name}' has been deprecated and should not be used."
) -> Callable[[_Wrapped], _Wrapped]:
"""
Intercepts specific keyword arguments to raise a custom TypeError.

Expand All @@ -2509,7 +2515,7 @@ def deprecate_args(
message: Custom error message string. Use {name} as a placeholder.
"""

def decorator(func):
def decorator(func: _Wrapped) -> _Wrapped:
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Check if any deprecated key is present in kwargs
Expand Down
Loading