
Potential bug in Trainer+Accelerate integration #42658

@quic-meetkuma

Description


System Info

  • transformers version: 5.0.0.dev0 (Added backend specific code only)
  • Platform: Linux-6.8.0-41-generic-x86_64-with-glibc2.39
  • Python version: 3.10.19
  • Huggingface_hub version: 1.0.0.rc6
  • Safetensors version: 0.6.2
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.7.0+cpu (NA)
  • Using distributed or parallel set-up in script?: Mentioned below

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The code mentioned under "Expected behavior" below can be executed directly.

  1. Install the torch, accelerate, datasets and transformers dependencies
  2. Run the code below

Expected behavior

I am trying to run DDP+TP with DP sharding using the code below. I am also using the FSDP plugin for the DP sharding, because accelerate does not let me use plain DDP with TP without DP sharding.

As per the comment, accelerator.prepare should be called only on the optimizer, not on the model and the optimizer together, because the model was already prepared beforehand. However, as per the accelerate code, FSDP2 needs the model and the optimizer to be passed to prepare together. This looks like a potential bug to me.
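
For illustration, here is a minimal sketch of the two calling conventions that appear to conflict. The toy model and optimizer are only placeholders; accelerator.prepare is the call in question.

import torch
from torch import nn
from accelerate import Accelerator

# Toy model/optimizer, only to illustrate the two prepare() patterns.
model = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accelerator = Accelerator()

# Pattern A -- what the comment referenced above implies: the model is
# prepared earlier, and later only the optimizer is passed to prepare().
model = accelerator.prepare(model)
optimizer = accelerator.prepare(optimizer)

# Pattern B -- what the FSDP2 preparation path (see the test script linked
# below) expects: the model and optimizer passed to prepare() together, so
# that (as I understand it) the optimizer can be re-linked to the sharded
# parameters.
# model, optimizer = accelerator.prepare(model, optimizer)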

Please note that I have added a custom backend (similar to cuda) to accelerate and transformers, and I am using that backend to run the code.

Code reference: https://github.com/huggingface/accelerate/blob/b9ca0de682f25f15357a3f9f1a4d94374a1d451d/tests/tp/fsdp2_tp_preparation.py#L60

Command: torchrun --master-port=1234 --nproc-per-node 8 ./run_tp_generic.py --dp_size 2 --tp_size 2

Code

import os
import torch
import argparse
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    AutoConfig
)
from transformers import DataCollatorWithPadding
from accelerate.utils import ParallelismConfig

def parse_arguments():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Run tensor parallel training on GPU")
    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.2-1b", help="Model name or path")
    parser.add_argument("--output_dir", type=str, default="./tp_output", help="Output directory")
    parser.add_argument("--batch_size", type=int, default=1, help="Per-device batch size")
    parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--force_device", type=str, choices=["auto", "cuda", "cpu"], default="auto",
                        help="Force specific device type (auto will use best available)")
    parser.add_argument("--bf16", action="store_true", help="Use bfloat16 precision if available")
    parser.add_argument("--tp_size", type=int, help="TP degree for tensor parallelism")
    parser.add_argument("--dp_size", type=int, help="DP degree for Distributed data parallelism")
    return parser.parse_args()

def setup_parallelism(tp_size, dp_size):
    """Set up tensor parallelism configuration."""
    parallelism_config = {"tp_size": tp_size, "dp_replicate_size": dp_size, "dp_shard_size": dp_size}
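    # In accelerate's ParallelismConfig (as I understand it), dp_replicate_size controls
    # DDP-style replication while dp_shard_size controls FSDP-style sharding.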
    pc = ParallelismConfig(**parallelism_config)
    return pc

def create_training_arguments(args, pc):
    """Create training arguments with appropriate settings."""
    return TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=1,
        num_train_epochs=args.epochs,
        logging_steps=1,
        fp16=True,
        parallelism_config=pc,
        ddp_find_unused_parameters=False,
        fsdp='no_shard',
        fsdp_config={
            "fsdp_version" : 2,
            "reshard_after_forward" : True,
            "auto_wrap_policy" : "transformer_based_wrap",
            "state_dict_type" : "SHARDED_STATE_DICT",
            "activation_checkpointing" : False,
            "cpu_ram_efficient_loading" : True,
        }
    )

def load_tokenizer(model_name):
    """Load and configure tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

def load_model(model_name, tp_size, device_mesh):
    """Load model with tensor parallelism."""
    print(f"Loading model {model_name} with tensor parallelism (tp_size={tp_size})")
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        tp_plan="auto",
        tp_size=tp_size,
        device_mesh=device_mesh
    )
    return model

def create_dummy_dataset():
    """Create a dummy dataset for testing."""
    dummy_data = {
        "text": [
            "This is a sample sentence for training.",
            "Tensor parallelism is a cool technique.",
            "We need to test the training loop.",
            "This example should run without errors."
        ]
    }
    return Dataset.from_dict(dummy_data)

def tokenize_function(examples, tokenizer):
    """Tokenize dataset examples."""
    tokenized = tokenizer(
        examples["text"], 
        padding="max_length", 
        truncation=True, 
        max_length=128, 
        return_tensors="pt"
    )
    tokenized["labels"] = tokenized["input_ids"].clone()
    return tokenized


def main():
    """Main training function."""
    # Parse arguments
    args = parse_arguments()
    
    # Setup tensor parallelism
    pc = setup_parallelism(args.tp_size, args.dp_size)
    
    # Create training arguments
    training_args = create_training_arguments(args, pc)
    
    
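    # When DP is enabled, explicitly set the FSDP2 plugin arguments on the training args.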
    if args.dp_size > 0:
        training_args.fsdp_plugin_args["activation_checkpointing"] = False
        training_args.fsdp_plugin_args["state_dict_type"] = "SHARDED_STATE_DICT"
        training_args.fsdp_plugin_args["fsdp_version"] = 2
        training_args.fsdp_plugin_args["reshard_after_forward"] = True
        training_args.fsdp_plugin_args["auto_wrap_policy"] = "transformer_based_wrap"
        training_args.fsdp_plugin_args["cpu_ram_efficient_loading"] = True
        training_args.fsdp_plugin_args["forward_prefetch"] = None
    
    # Load tokenizer
    tokenizer = load_tokenizer(args.model_name)
    
    # Load model
    device_mesh = pc.get_device_mesh(args.force_device)
    model = load_model(args.model_name, args.tp_size, device_mesh)
    print(f"Model loaded on device: {model.device}")
    
    # Create dataset
    dataset = create_dummy_dataset()
    
    # Tokenize dataset
    tokenized_dataset = dataset.map(
        lambda x: tokenize_function(x, tokenizer), 
        batched=True
    )
    
    # Create data collator
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    
    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )
    
    # Train model
    trainer.train()
    print("Training complete!")

if __name__ == "__main__":
    main()
