Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions scripts/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,14 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
row, skipped_count = proc_fn(item)
if row is None:
continue
# Data starts with an assistant message, skip the entire conversation
if row["conversations"][0]["role"] == "assistant":
total_skipped_count += len(row["conversations"])
continue
total_skipped_count += skipped_count
else:
row = item
row, skipped_count = proc_fn(item)
f.write(json.dumps(row, ensure_ascii=False) + "\n")

if test_ds is not None:
Expand All @@ -210,6 +215,10 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
row, skipped_count = proc_fn(item)
if row is None:
continue
# Data starts with an assistant message, skip the entire conversation
if row["conversations"][0]["role"] == "assistant":
total_skipped_count += len(row["conversations"])
continue
total_skipped_count += skipped_count
else:
row = item
Expand Down
124 changes: 88 additions & 36 deletions scripts/train_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import os
import time
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Tuple, Union
from itertools import islice
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -315,6 +316,39 @@ def build_target_model(
return target_head, None


def load_checkpoint(args: Namespace) -> Tuple[int, int, Optional[str], Any]:
    """
    Find the latest checkpoint under the output directory and read back the
    saved training progress, if any.

    Args:
        args: Parsed arguments; ``args.resume`` and ``args.output_dir`` are
            the only fields read here.

    Returns:
        A ``(start_epoch, global_step, checkpoint_path, state)`` tuple:
            start_epoch: Value of ``state["epoch"]``, or 0 when no training
                state was loaded.
            global_step: Value of ``state["global_step"]``, or 0 when no
                training state was loaded.
            checkpoint_path: Result of ``get_last_checkpoint``, or None when
                resuming is disabled, the output directory does not exist,
                or no checkpoint was found.
            state: The object loaded from ``training_state.pt``, or None
                when that file is absent or no checkpoint was used.
    """
    # Resuming needs both the flag and an existing output directory.
    if not (args.resume and os.path.isdir(args.output_dir)):
        print_on_rank0("Starting training from scratch")
        return 0, 0, None, None

    last_ckpt = get_last_checkpoint(args.output_dir)
    if not last_ckpt:
        print_on_rank0("No checkpoint found, starting from scratch")
        return 0, 0, None, None

    state_file = os.path.join(last_ckpt, "training_state.pt")
    if os.path.exists(state_file):
        # NOTE(review): weights_only=False unpickles arbitrary objects, so the
        # checkpoint directory must be trusted.
        training_state = torch.load(
            state_file, weights_only=False, map_location="cpu"
        )
        resumed_epoch, resumed_step = (
            training_state["epoch"],
            training_state["global_step"],
        )
        print_on_rank0(f"Resumed from epoch {resumed_epoch}, step {resumed_step}")
        return resumed_epoch, resumed_step, last_ckpt, training_state

    # A checkpoint without a training state is still returned: it could be a
    # pretrained model used for fine-tuning, so progress counters start at 0.
    print_on_rank0(f"Training state not found at {state_file}")
    return 0, 0, last_ckpt, None


def sanity_check(args: Namespace) -> None:
"""
Perform sanity checks on the arguments.
Expand All @@ -336,7 +370,9 @@ def sanity_check(args: Namespace) -> None:
), "train_hidden_states_path should not be None for usp"


def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]:
def build_draft_model(
args: Namespace, draft_model_last_checkpoint: Optional[str]
) -> Tuple[AutoDraftModelConfig, nn.Module]:
# Handle draft model config
if args.draft_model_config is None:
# Auto-generate and save config file
Expand All @@ -348,27 +384,10 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]
# Use provided config file
draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config)

# Handle base ckpt, config file
draft_model_last_checkpoint = None
if args.ckpt_dir is not None:
if os.path.isdir(args.ckpt_dir):
draft_model_config = AutoDraftModelConfig.from_file(
os.path.join(args.ckpt_dir, "config.json")
)
draft_model_last_checkpoint = args.ckpt_dir
print_on_rank0(f"Finetuning from base model: {draft_model_last_checkpoint}")
else:
raise ValueError(
f"Provided base model dir {args.ckpt_dir} is not a valid directory."
)

# detecting last ckpt for draft model
if args.resume and os.path.isdir(args.output_dir):
print_on_rank0(args.output_dir)
draft_model_last_checkpoint = get_last_checkpoint(args.output_dir)
print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}")

if draft_model_last_checkpoint:
draft_model_config = AutoDraftModelConfig.from_file(
os.path.join(args.ckpt_dir, "config.json")
)
draft_model = AutoEagle3DraftModel.from_pretrained(
draft_model_last_checkpoint,
attention_backend=args.attention_backend,
Expand All @@ -395,22 +414,30 @@ def build_dataloaders(
tokenizer = AutoTokenizer.from_pretrained(args.target_model_path)

# convert to dataloader
cache_params_string = (
train_cache_params_string = (
f"{args.train_data_path}-"
f"{args.max_length}-"
f"{args.chat_template}-"
f"{args.target_model_path}" # Tokenizer may also different
)
cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
train_cache_key = hashlib.md5(train_cache_params_string.encode()).hexdigest()
eval_cache_params_string = (
f"{args.eval_data_path}-"
f"{args.max_length}-"
f"{args.chat_template}-"
f"{args.target_model_path}" # Tokenizer may also different
)
eval_cache_key = hashlib.md5(eval_cache_params_string.encode()).hexdigest()
cache_dir = os.path.join(args.cache_dir, "processed_dataset")
train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
with rank_0_priority():
train_eagle3_dataset = build_eagle3_dataset(
dataset=train_dataset,
tokenizer=tokenizer,
chat_template=args.chat_template,
max_length=args.max_length,
cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
cache_key=cache_key,
cache_dir=cache_dir,
cache_key=train_cache_key,
is_vlm=args.is_vlm,
is_preformatted=args.is_preformatted,
processor=processor,
Expand All @@ -421,7 +448,7 @@ def build_dataloaders(
target_vocab_size=draft_model_config.vocab_size,
draft_vocab_size=draft_model_config.draft_vocab_size,
cache_dir=os.path.join(args.cache_dir, "vocab_mapping"),
cache_key=cache_key,
cache_key=train_cache_key,
)

if args.train_hidden_states_path is not None:
Expand Down Expand Up @@ -449,6 +476,8 @@ def build_dataloaders(
tokenizer,
args.chat_template,
args.max_length,
cache_dir=cache_dir,
cache_key=eval_cache_key,
is_vlm=args.is_vlm,
processor=processor,
num_proc=args.build_dataset_num_proc,
Expand Down Expand Up @@ -652,10 +681,16 @@ def main():
print_args_with_dots(args)
print_with_rank("Initialized distributed environment")

start_epoch, global_step, draft_model_last_checkpoint, optimizer_state = (
load_checkpoint(args)
)

# ================================================
# 2. Build models
# ================================================
draft_model_config, draft_model = build_draft_model(args)
draft_model_config, draft_model = build_draft_model(
args, draft_model_last_checkpoint
)
target_model, processor = build_target_model(args, draft_model_config, is_online)

# ================================================
Expand All @@ -664,6 +699,8 @@ def main():
train_dataloader, vocab_mapping_path, eval_dataloader = build_dataloaders(
args, draft_model_config, processor
)
# Set this attribute to show draft model config in the tracker
args.draft_model_config_dict = draft_model_config.__dict__

# we load the vocab mapping then
draft_model.load_vocab_mapping(vocab_mapping_path)
Expand Down Expand Up @@ -724,14 +761,15 @@ def main():
warmup_ratio=args.warmup_ratio,
total_steps=args.total_steps,
)
if optimizer_state is not None:
optimizer.load_state_dict(optimizer_state)
print_with_rank("Loaded optimizer state from checkpoint")
print_with_rank("Initialized optimizer and scheduler")

# ================================================
# 6. Build tracker
# ================================================
tracker = build_tracker(args, parser)
global_step = 0
start_epoch = 0
dist.barrier()

last_time = time.time()
Expand All @@ -741,19 +779,27 @@ def main():
# ================================================
print_on_rank0(f"Starting training from epoch {start_epoch}")

steps_to_skip = global_step - steps_per_epoch * start_epoch

for epoch in range(start_epoch, args.num_epochs):
# Run training
train_dataloader.sampler.set_epoch(epoch + 1)
draft_model.train()

if dist.get_rank() == 0:
progress_bar = tqdm(
train_dataloader, desc=f"Training Epoch {epoch}", leave=True
train_dataloader,
desc=f"Training Epoch {epoch}",
leave=True,
initial=steps_to_skip,
)
else:
progress_bar = train_dataloader

for data in progress_bar:
for batch_idx, data in enumerate(
islice(progress_bar, steps_to_skip, None), start=steps_to_skip
):
steps_to_skip = 0 # reset for next epoch
global_step += 1

# ================================================
Expand Down Expand Up @@ -789,12 +835,16 @@ def main():
)
run_backward_and_update(args, plosses, optimizer, global_step)

# detach losses and accuracies to avoid memory leak
plosses_for_metrics = [p.detach() for p in plosses]
acces_for_metrics = [a.detach() for a in acces]

# log training metrics
if global_step % (args.log_interval * args.draft_accumulation_steps) == 0:
record_metrcs(
args,
acces,
plosses,
acces_for_metrics,
plosses_for_metrics,
global_step // args.draft_accumulation_steps,
tracker,
optimizer,
Expand All @@ -804,8 +854,10 @@ def main():
if dist.get_rank() == 0:
time_per_step = time.time() - last_time
last_time = time.time()
avg_loss = sum(pl for pl in plosses) / len(plosses)
avg_acc = sum(acces) / len(acces)
avg_loss = sum(pl for pl in plosses_for_metrics) / len(
plosses_for_metrics
)
avg_acc = sum(acces_for_metrics) / len(acces_for_metrics)
progress_bar.set_postfix(
{
"loss": f"{avg_loss:.2f}",
Expand Down
8 changes: 7 additions & 1 deletion specforge/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,14 @@ def __init__(self, args, output_dir: str):
if self.rank == 0:
wandb.login(key=args.wandb_key)
wandb.init(
project=args.wandb_project, name=args.wandb_name, config=vars(args)
project=args.wandb_project,
name=args.wandb_name,
config={
**vars(args),
"draft_model_config_dict": args.draft_model_config_dict,
},
)
wandb.save(args.draft_model_config)
self.is_initialized = True

def log(self, log_dict: Dict[str, Any], step: Optional[int] = None):
Expand Down
4 changes: 2 additions & 2 deletions specforge/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def print_on_rank0(message):
logger.info(message)


def get_last_checkpoint(folder, prefix="epoch"):
def get_last_checkpoint(folder):
content = os.listdir(folder)
_re_checkpoint = re.compile(r"^" + prefix + r"_(\d+)$")
_re_checkpoint = re.compile(r"^epoch_(\d+)_step_(\d+)$")
checkpoints = [
path
for path in content
Expand Down