diff --git a/docs/source_en/Components/Model/TransformersModel.md b/docs/source_en/Components/Model/TransformersModel.md
index 7e34149f..f10b0351 100644
--- a/docs/source_en/Components/Model/TransformersModel.md
+++ b/docs/source_en/Components/Model/TransformersModel.md
@@ -15,6 +15,7 @@ class TransformersModel:
         ddp_config: Dict[str, Any] = None,
         fsdp_config: Dict[str, Any] = None,
         grad_scaler_config: Dict[str, Any] = None,
+        memory_efficient_init: bool = False,
         **kwargs):
     ...

@@ -30,6 +31,7 @@ class TransformersModel:
 - ddp_config: DDP configuration when strategy is `accelerate`, see: [DDPKwargs](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L155)
 - fsdp_config: FSDP configuration when strategy is `accelerate`, see: [FSDPConfig](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L1566)
 - grad_scaler_config: PyTorch's grad_scaler initialization configuration, see: [PyTorch's GradScaler constructor](https://github.com/pytorch/pytorch/blob/main/torch/cuda/amp/grad_scaler.py#L25)
+- memory_efficient_init: Whether to enable memory-efficient model initialization for FSDP. When enabled, only rank 0 loads the full weights; the other ranks receive their parameter shards via broadcast, which reduces peak CPU and GPU memory usage during initialization. Default `False`. Note: this optimization currently applies only to transformers <= 4.57.6; with transformers >= 5.0.0 it may degrade performance.
 - kwargs:
   - If you don't want to pass a full model config, you can place individual configuration entries here; they will be forwarded to `from_pretrained` or `from_config`.

diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md"
index 468b5efb..b7b9cf0f 100644
--- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md"
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md"
@@ -15,6 +15,7 @@ class TransformersModel:
         ddp_config: Dict[str, Any] = None,
         fsdp_config: Dict[str, Any] = None,
         grad_scaler_config: Dict[str, Any] = None,
+        memory_efficient_init: bool = False,
         **kwargs):
     ...

@@ -30,6 +31,7 @@ class TransformersModel:
 - ddp_config: strategy为`accelerate`时的DDP配置,参见:[DDPKwargs](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L155)
 - fsdp_config: strategy为`accelerate`时的FSDP配置,参见:[FSDPConfig](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L1566)
 - grad_scaler_config: PyTorch的grad_scaler初始化配置,参见:[PyTorch的GradScaler构造](https://github.com/pytorch/pytorch/blob/main/torch/cuda/amp/grad_scaler.py#L25)
+- memory_efficient_init: 是否启用FSDP内存高效初始化。启用后仅rank 0加载完整权重,其余rank通过广播获取分片参数,降低初始化阶段的内存和显存峰值。默认`False`。注意:该优化目前仅适用于 transformers <= 4.57.6;对于 transformers >= 5.0.0,可能会导致负面性能影响。
 - kwargs:
   - 如果你不希望传递模型config字段,可以把零星的配置从这里放置进去。后续这些参数会传递到`from_pretrained`或者`from_config`中。
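Usage of the new flag, as documented above, looks like the following; a minimal sketch in which the import path, model id, and choice of strategy are illustrative assumptions rather than part of this change:

```python
# Hedged usage sketch: import path and model id are assumptions for
# illustration; `memory_efficient_init` is the flag documented above.
from twinkle.model.transformers import TransformersModel  # assumed import path

model = TransformersModel(
    model_id='Qwen/Qwen2.5-7B-Instruct',  # placeholder checkpoint
    strategy='native_fsdp',               # 'accelerate' also supports the flag
    mixed_precision='bf16',
    memory_efficient_init=True,           # rank 0 loads full weights, others get shards
)
```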
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index 66d8e9cb..89b497e2 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -1,8 +1,8 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
-import os
 from typing import Any, Dict, Literal, Optional

 from twinkle import DeviceMesh
+from .load_context import fsdp_pretrained_load_context


 class AccelerateStrategy:
@@ -21,13 +21,15 @@ def __init__(
         mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
         ddp_config: Dict[str, Any] = None,
         fsdp_config: Dict[str, Any] = None,
+        memory_efficient_init: bool = False,
     ):
         from accelerate import Accelerator
         self.device_mesh = device_mesh
         self.mixed_precision = mixed_precision
+        self._memory_efficient_init = memory_efficient_init
         parallelism_config = self._parallelism_config_from_device_mesh(device_mesh)
-        fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config)
+        fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient_init)

         kwargs_handlers = []
         if ddp_config is not None:
@@ -42,6 +44,9 @@ def __init__(
             kwargs_handlers=kwargs_handlers,
         )

+    def pretrained_load_context(self):
+        return fsdp_pretrained_load_context(self._memory_efficient_init and self.device_mesh is not None)
+
     @staticmethod
     def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh):
         # TODO should test with transformers v5.0
@@ -69,7 +74,8 @@ def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh):

         return parallelism_config

-    def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any]):
+    def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any],
+                                      memory_efficient: bool):
         from accelerate import FullyShardedDataParallelPlugin
         from torch.distributed.fsdp import BackwardPrefetch
         from torch.distributed.fsdp import ShardingStrategy as FSDPShardingStrategy
@@ -107,11 +113,9 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di
             activation_checkpointing=fsdp_config.pop('activation_checkpointing', False),
             auto_wrap_policy=fsdp_config.pop('auto_wrap_policy', 'transformer_based_wrap'),  # noqa
             reshard_after_forward=fsdp_config.pop('reshard_after_forward', True),
+            cpu_ram_efficient_loading=fsdp_config.pop('cpu_ram_efficient_loading', memory_efficient),
             **fsdp_config,
         )
-        # Enable memory efficient model loading in transformers(see `is_fsdp_enabled` in transformers)
-        # os.environ['ACCELERATE_USE_FSDP'] = '1'
-        # os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = '1'
         return fsdp_plugin

     def wrap_model(self, model, *args):
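Because the plugin reads the value with `fsdp_config.pop('cpu_ram_efficient_loading', memory_efficient)`, an explicit entry in `fsdp_config` still takes precedence over `memory_efficient_init`. A self-contained sketch of that precedence rule:

```python
# Precedence sketch for the pop-with-default pattern above (plain dicts, no
# accelerate required): an explicit fsdp_config entry overrides the flag.
fsdp_config = {'cpu_ram_efficient_loading': False}  # user's explicit choice
memory_efficient = True                             # memory_efficient_init flag

value = fsdp_config.pop('cpu_ram_efficient_loading', memory_efficient)
assert value is False  # the default applies only when the key is absent
```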
diff --git a/src/twinkle/model/transformers/strategy/load_context.py b/src/twinkle/model/transformers/strategy/load_context.py
new file mode 100644
index 00000000..e3c0e64c
--- /dev/null
+++ b/src/twinkle/model/transformers/strategy/load_context.py
@@ -0,0 +1,27 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import contextlib
+import os
+
+_FSDP_EFFICIENT_LOADING_ENV = {
+    'ACCELERATE_USE_FSDP': 'true',
+    'FSDP_CPU_RAM_EFFICIENT_LOADING': 'true',
+}
+
+
+@contextlib.contextmanager
+def fsdp_pretrained_load_context(enabled: bool):
+    """Enable the env flags required for transformers FSDP-aware loading when needed."""
+    if not enabled:
+        yield
+        return
+
+    saved_env = {key: os.environ.get(key) for key in _FSDP_EFFICIENT_LOADING_ENV}
+    os.environ.update(_FSDP_EFFICIENT_LOADING_ENV)
+    try:
+        yield
+    finally:
+        for key, old_val in saved_env.items():
+            if old_val is None:
+                os.environ.pop(key, None)
+            else:
+                os.environ[key] = old_val
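The new context manager is self-contained, so its contract can be shown directly. A behavior sketch, assuming both variables are unset beforehand (`is_fsdp_enabled` is the transformers helper referenced by the comment this PR removes from `accelerate.py`):

```python
# Behavior sketch for fsdp_pretrained_load_context; assumes both env vars are
# unset before entry. transformers consults these flags (see `is_fsdp_enabled`).
import os
from twinkle.model.transformers.strategy.load_context import fsdp_pretrained_load_context

with fsdp_pretrained_load_context(enabled=True):
    # Inside the context, transformers' from_pretrained takes the FSDP-aware
    # path: real weights on rank 0 only, meta tensors elsewhere.
    assert os.environ['ACCELERATE_USE_FSDP'] == 'true'
    assert os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] == 'true'

# On exit the previous values (or their absence) are restored.
assert 'ACCELERATE_USE_FSDP' not in os.environ
```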
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index ce938eef..48a1da85 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set

 from twinkle.utils import DeviceMesh, Platform, torch_util
+from .load_context import fsdp_pretrained_load_context

 if TYPE_CHECKING:
     from torch.distributed.fsdp import MixedPrecisionPolicy
@@ -18,14 +19,19 @@ def __init__(self,
                  device_mesh: Optional[DeviceMesh] = None,
                  mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
                  fsdp_config: Dict[str, Any] = None,
+                 memory_efficient_init: bool = False,
                  enable_ep: bool = True,
                  ep_size: Optional[int] = None):
         self.device_mesh = device_mesh
         self.mixed_precision = mixed_precision
         self.fsdp_config = fsdp_config or {}
+        self._memory_efficient_init = memory_efficient_init
         self.enable_ep = enable_ep
         self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None

+    def pretrained_load_context(self):
+        return fsdp_pretrained_load_context(self._memory_efficient_init and self.device_mesh is not None)
+
     def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]:
         if self.device_mesh is None:
             return None
@@ -48,6 +54,23 @@ def wrap_model(self, model, optimizer=None):
         fsdp_mesh = _build_fsdp_mesh(self.device_mesh)
         if fsdp_mesh is not None:
             ep_enabled = (self.enable_ep and self.ep_fsdp_device_mesh is not None)
+
+            # Drop optimizer references to pre-shard params before fully_shard to reduce peak memory.
+            if optimizer is not None:
+                _unbind_optimizer_params(optimizer)
+
+            # EP path requires experts on a real device, incompatible with meta-device flow.
+            use_meta = self._memory_efficient_init and not ep_enabled
+
+            original_sd = None
+            saved_buffers = None
+            if use_meta:
+                original_sd = model.state_dict()
+                saved_buffers = _get_non_persistent_buffers(model)
+                model = model.to(torch.device('meta'))
+                if hasattr(model, 'tie_weights'):
+                    model.tie_weights()
+
             if ep_enabled:
                 _ensure_moe_patched_if_needed(model, self.ep_fsdp_device_mesh)
                 _place_ep_experts_on_local_device(model, self.ep_fsdp_device_mesh)
@@ -57,11 +80,9 @@ def wrap_model(self, model, optimizer=None):
             if ep_enabled:
                 _ensure_ep_fsdp_supported(model)

-            # Collect experts map and expert params
             experts_map = _collect_ep_experts_map(model) if ep_enabled else {}
             expert_params = _collect_expert_params(model) if self.enable_ep else None

-            # Build layer_pairs: [(layer_mod, experts_mod_or_None)]
             layers = _get_decoder_layers(model)
             layer_pairs = []
             if layers is not None:
@@ -69,7 +90,6 @@ def wrap_model(self, model, optimizer=None):
                     experts_mod = _find_experts_in_layer(layer_mod, experts_map)
                     layer_pairs.append((layer_mod, experts_mod))

-            # FSDP2 wrapping per layer
             world_size = self.device_mesh.world_size
             ep_fsdp_mesh_1d = self.ep_fsdp_device_mesh['ep_fsdp'] if ep_enabled else None
@@ -79,9 +99,6 @@ def wrap_model(self, model, optimizer=None):
                 if experts_mod is not None and ep_fsdp_mesh_1d is not None:
                     from torch.distributed.tensor import Shard

-                    # PreMulSum (used by set_gradient_divide_factor) only supports
-                    # float16/float32/float64; override reduce_dtype to float32
-                    # when the base policy uses bfloat16.
                     ep_mp_policy = _build_ep_mp_policy(mp_policy)
                     fully_shard(
                         experts_mod,
@@ -90,7 +107,6 @@ def wrap_model(self, model, optimizer=None):
                         mp_policy=ep_mp_policy,
                         shard_placement_fn=lambda param: Shard(1),
                     )
-                    # gradient_divide_factor = world_size
                     experts_mod.set_gradient_divide_factor(world_size)
                     layer_mod._fsdp_modules.append(experts_mod)

@@ -103,7 +119,6 @@ def wrap_model(self, model, optimizer=None):
                     )
                     layer_mod._fsdp_modules.append(layer_mod)

-            # Root model
             fully_shard(
                 model,
                 mesh=fsdp_mesh,
@@ -112,11 +127,22 @@ def wrap_model(self, model, optimizer=None):
                 ignored_params=expert_params,
             )

-            # Manual prefetch
+            if use_meta:
+                device_type = self.device_mesh.device_type or 'cuda'
+                is_rank0 = (dist.get_rank() == 0)
+                _broadcast_sharded_state_dict(
+                    model,
+                    original_sd if is_rank0 else {},
+                    device_type=device_type,
+                )
+                target_device = torch.device(device_type)
+                _restore_non_persistent_buffers(model, saved_buffers, device=target_device)
+                if hasattr(model, 'tie_weights'):
+                    model.tie_weights()
+
             if ep_enabled and layer_pairs:
                 _setup_manual_prefetch([lp[0] for lp in layer_pairs])

-            # Rebuild groups after wrapping so grad clip sees the live Parameter objects.
             if ep_enabled:
                 _rebuild_ep_param_groups(model)
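For readers unfamiliar with the FSDP2 meta-device idiom that `wrap_model` now uses, the following is a stripped-down, standalone sketch of the same flow on a toy module. It assumes a `torchrun` launch with NCCL and a PyTorch version that exports `fully_shard` from `torch.distributed.fsdp` (2.6+), and it omits the buffer save/restore and weight re-tying that the real code performs:

```python
# Standalone sketch of the meta-init + broadcast flow (toy nn.Linear).
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard          # PyTorch >= 2.6 import path
from torch.distributed.tensor import distribute_tensor

dist.init_process_group('nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Linear(8, 8, bias=False)
full_sd = model.state_dict() if dist.get_rank() == 0 else {}  # rank 0 keeps real weights
model = model.to(torch.device('meta'))  # drop every rank's full copy
fully_shard(model)                      # parameters become sharded meta DTensors

sharded_sd = {}
for name, meta_param in model.state_dict().items():
    if dist.get_rank() == 0:
        full = full_sd[name].detach().cuda()
    else:
        full = torch.empty(meta_param.size(), dtype=meta_param.dtype, device='cuda')
    dist.broadcast(full, src=0)  # every rank receives the full tensor once...
    sharded_sd[name] = distribute_tensor(full, meta_param.device_mesh, meta_param.placements)

model.load_state_dict(sharded_sd, assign=True)  # ...but keeps only its local shard
```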
@@ -398,3 +424,76 @@ def _rebind_optimizer(optimizer: torch.optim.Optimizer, model: nn.Module) -> tor
         return optimizer
     optimizer.param_groups[0]['params'] = list(model.parameters())
     return optimizer
+
+
+def _broadcast_sharded_state_dict(
+    model: nn.Module,
+    full_sd: dict,
+    device_type: str = 'cuda',
+) -> None:
+    """Broadcast full state dict from rank 0 and materialise local shards via distribute_tensor."""
+    from torch.distributed.tensor import DTensor, distribute_tensor
+
+    meta_sharded_sd = model.state_dict()
+    sharded_sd = {}
+    is_rank0 = (dist.get_rank() == 0)
+
+    for param_name, sharded_param in meta_sharded_sd.items():
+        shape = sharded_param.size()
+        dtype = sharded_param.dtype
+
+        if is_rank0:
+            full_param = full_sd[param_name]
+            full_tensor = full_param.detach().to(device_type)
+            if isinstance(full_tensor, DTensor):
+                full_tensor = full_tensor.to_local()
+        else:
+            full_tensor = torch.empty(shape, device=device_type, dtype=dtype)
+
+        dist.broadcast(full_tensor, src=0)
+        torch_util.synchronize()
+
+        device_mesh = sharded_param.device_mesh
+        placements = sharded_param.placements
+        sharded_tensor = distribute_tensor(full_tensor, device_mesh, placements)
+        del full_tensor
+
+        sharded_sd[param_name] = sharded_tensor
+
+    model.load_state_dict(sharded_sd, assign=True)
+
+
+def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]:
+    """Return {fqn: tensor} for non-persistent buffers (lost on to('meta'))."""
+    non_persistent_fqns: Set[str] = set()
+    for fqn, module in model.named_modules():
+        for buf_name in getattr(module, '_non_persistent_buffers_set', set()):
+            full_fqn = f'{fqn}.{buf_name}' if fqn else buf_name
+            non_persistent_fqns.add(full_fqn)
+
+    return {k: v.clone() for k, v in model.named_buffers() if k in non_persistent_fqns}
+
+
+def _unbind_optimizer_params(optimizer: torch.optim.Optimizer) -> None:
+    """Drop optimizer refs to pre-shard params before fully_shard to lower peak memory."""
+    for group in optimizer.param_groups:
+        for i in range(len(group['params'])):
+            param = group['params'][i]
+            group['params'][i] = torch.empty(1, dtype=param.dtype, device=param.device)
+
+
+def _restore_non_persistent_buffers(
+    model: nn.Module,
+    saved_buffers: Dict[str, torch.Tensor],
+    device: torch.device,
+) -> None:
+    """Re-register non-persistent buffers saved before to('meta')."""
+    for fqn, buf_tensor in saved_buffers.items():
+        buf_tensor = buf_tensor.to(device)
+        if '.' in fqn:
+            parent_fqn, local_name = fqn.rsplit('.', 1)
+            parent = model.get_submodule(parent_fqn)
+        else:
+            local_name = fqn
+            parent = model
+        parent.register_buffer(local_name, buf_tensor, persistent=False)
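The two buffer helpers exist because `to('meta')` destroys values that `state_dict()` never captures, so the broadcast path alone cannot recover them. A small, self-contained demonstration of that failure mode (the `inv_freq` name echoes HF rotary-embedding modules and is used purely as an example):

```python
# Why non-persistent buffers need special handling (CPU-only demo): they are
# excluded from state_dict(), so the broadcast path cannot restore them once
# to('meta') wipes their storage.
import torch
import torch.nn as nn

class Rotary(nn.Module):
    def __init__(self):
        super().__init__()
        # HF rotary embeddings register e.g. `inv_freq` with persistent=False
        self.register_buffer('inv_freq', torch.arange(4.0), persistent=False)

m = Rotary()
assert 'inv_freq' not in m.state_dict()  # invisible to checkpoint/broadcast
m = m.to(torch.device('meta'))
assert m.inv_freq.is_meta                # value gone; must be saved and re-registered
```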
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 48f08039..520aaf9f 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -189,6 +189,7 @@ def __init__(
             ddp_config: Dict[str, Any] = None,
             fsdp_config: Dict[str, Any] = None,
             grad_scaler_config: Dict[str, Any] = None,
+            memory_efficient_init: bool = False,
             **kwargs):
         os.environ['TOKENIZERS_PARALLELISM'] = 'true'
         self._try_init_process_group()
@@ -201,6 +202,7 @@ def __init__(
         self.mixed_precision = mixed_precision
         self._fsdp_config = dict(fsdp_config or {})
         self._ddp_config = ddp_config or {}
+        self._memory_efficient_init = memory_efficient_init
         self._decide_strategy(strategy)
         self.grad_scaler_config = grad_scaler_config
         if isinstance(model_cls, str):
@@ -209,8 +211,9 @@ def __init__(
             self.model = model_cls.from_config(config, **kwargs)
         else:
             model_id = HubOperation.download_model(model_id)
-            self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
-        # Construct sequence-parallel strategy lazily during wrapping to reduce init-time side effects.
+            # Trigger transformers' FSDP-aware loading: meta-device init + rank-0-only weight load.
+            with self.strategy.pretrained_load_context():
+                self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
         self.model.gradient_checkpointing_enable()
         self.sp_strategy = None
         self._model_wrapped = False
@@ -235,6 +238,7 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']):
                 mixed_precision=self.mixed_precision,
                 fsdp_config=self._fsdp_config,
                 device_mesh=self.device_mesh,
+                memory_efficient_init=self._memory_efficient_init,
                 enable_ep=self._enable_expert_parallel,
                 ep_size=ep_size,
             )
@@ -243,7 +247,8 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']):
                 mixed_precision=self.mixed_precision,
                 ddp_config=self._ddp_config,
                 fsdp_config=self._fsdp_config,
-                device_mesh=self.device_mesh)
+                device_mesh=self.device_mesh,
+                memory_efficient_init=self._memory_efficient_init)
         # Sequence parallel ("ulysses") is derived from dp/fsdp ranks; it does not change world size.
         # We construct `sp_strategy` after the underlying HF model is initialized (see __init__).
@@ -290,6 +295,7 @@ def _lazy_wrap_model(self):
         self._ensure_sp_strategy()
         if self.sp_strategy is not None:
             self.sp_strategy.initialize()
+
         if len(optimizer_groups) == 1:
             optimizer_group = optimizer_groups[0]
             optimizer = optimizer_group.optimizer
@@ -299,7 +305,11 @@ def _lazy_wrap_model(self):
             self.register_mm_forward_hook(optimizer_group)
         else:
             # maybe forward_only, no optimizer_group available
-            self.model = self.strategy.wrap_model(self.model)
+            result = self.strategy.wrap_model(self.model)
+            if isinstance(result, tuple):
+                self.model = result[0]
+            else:
+                self.model = result
         self._model_wrapped = True

     def register_mm_forward_hook(self, optimizer_group: OptimizerGroup):
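To sanity-check the end-to-end effect, per-rank peak memory during initialization can be compared with the flag on and off. A measurement sketch reusing the assumed constructor from the first example (placeholder model id; launch both runs under `torchrun` with the same world size):

```python
# Measurement sketch (assumed import path, placeholder model id): compare
# per-rank peak GPU memory during init with memory_efficient_init on vs. off.
import torch
from twinkle.model.transformers import TransformersModel  # assumed import path

torch.cuda.reset_peak_memory_stats()
model = TransformersModel(
    model_id='Qwen/Qwen2.5-7B-Instruct',  # placeholder checkpoint
    strategy='native_fsdp',
    memory_efficient_init=True,           # rerun with False to compare
)
print(f'init peak: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB')
```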