diff --git a/training/verl/workers/fsdp_workers.py b/training/verl/workers/fsdp_workers.py
index 5aa1328..adf2653 100644
--- a/training/verl/workers/fsdp_workers.py
+++ b/training/verl/workers/fsdp_workers.py
@@ -134,6 +134,16 @@ def _build_model_optimizer(self,
         # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang
         init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings)
 
+        # Apply Liger kernel optimizations to Qwen2 model
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen2
+        apply_liger_kernel_to_qwen2(
+            rope=False,
+            cross_entropy=False,
+            fused_linear_cross_entropy=True,
+            rms_norm=True,
+            swiglu=True
+        )
+
         with init_context(), warnings.catch_warnings():
             warnings.simplefilter("ignore")
             actor_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
@@ -869,13 +879,23 @@ def _build_model_optimizer(self, config, enable_gradient_checkpointing=False):
         check_model_support_rmpad(model_config.model_type)
         init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings)
 
+        # Apply Liger kernel optimizations to Qwen2 model
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen2
+        apply_liger_kernel_to_qwen2(
+            rope=False,
+            cross_entropy=False,
+            fused_linear_cross_entropy=True,
+            rms_norm=True,
+            swiglu=True
+        )
+
         with init_context(), warnings.catch_warnings():
             warnings.simplefilter("ignore")
             reward_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
-                                                                 torch_dtype=torch.bfloat16,
+                                                                 torch_dtype=torch.float32,
                                                                  attn_implementation='flash_attention_2',
                                                                  trust_remote_code=trust_remote_code)
-            reward_module.to(torch.bfloat16)
+            reward_module.to(torch.float32)
             if enable_gradient_checkpointing:
                 reward_module.gradient_checkpointing_enable()
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision