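"""DeepSpeed training entry point for nested_learning models.

Intended to be launched with the DeepSpeed runner; a typical invocation
(the JSON config path is illustrative, not fixed by this script) would be:

    deepspeed train_deepspeed.py deepspeed.config=path/to/ds_config.json

Additional Hydra overrides (e.g. ``train.steps=500``) can be appended as usual.
"""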
from __future__ import annotations

import json
import os
from pathlib import Path

import hydra
import torch
import torch.distributed as dist
from omegaconf import DictConfig

from nested_learning.logging_utils import NullLogger, init_logger
from nested_learning.training import (
    DistributedContext,
    _seed_everything,
    build_dataloader,
    build_model_from_cfg,
    compute_teach_signal,
    unwrap_config,
)

try:
    import deepspeed
except ImportError as exc:  # pragma: no cover - optional dependency
    raise RuntimeError(
        "DeepSpeed is not installed. Install it in this environment to use train_deepspeed.py."
    ) from exc


def setup_distributed() -> DistributedContext:
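    # DeepSpeed initialises torch.distributed (NCCL by default) from the
    # RANK/WORLD_SIZE/LOCAL_RANK variables exported by the launcher.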
    deepspeed.init_distributed()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{local_rank}")
    return DistributedContext(rank=rank, world_size=world_size, device=device)


def load_ds_config(path: str | Path) -> dict:
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


@hydra.main(config_path="configs", config_name="hope/target", version_base=None)
def main(cfg: DictConfig) -> None:
    cfg = unwrap_config(cfg)
    dist_ctx = setup_distributed()
    train_seed = cfg.train.get("seed")
    deterministic = cfg.train.get("deterministic", False)
    if train_seed is not None:
        _seed_everything(int(train_seed), deterministic=bool(deterministic))
    model = build_model_from_cfg(cfg.model)
    ds_config = load_ds_config(cfg.deepspeed.config)
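    # deepspeed.initialize wraps the model in an engine that owns the optimizer,
    # loss scaling, gradient accumulation, and any ZeRO partitioning requested
    # in the JSON config.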
    engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        config=ds_config,
    )
    # Offset the dataloader seed by rank so each process draws distinct shards.
    loader_seed = None if train_seed is None else int(train_seed) + dist_ctx.rank
    dataloader, sampler = build_dataloader(
        cfg.data,
        distributed=True,
        dist_ctx=dist_ctx,
        seed=loader_seed,
    )
    logger = (
        init_logger(getattr(cfg, "logging", None), cfg) if engine.global_rank == 0 else NullLogger()
    )
    steps = cfg.train.steps
    log_interval = cfg.train.get("log_interval", 10)
    checkpoint_cfg = cfg.train.get("checkpoint", {})
    ckpt_dir = Path(checkpoint_cfg.get("dir", "checkpoints/deepspeed"))
    if checkpoint_cfg.get("resume_tag"):
        tag = checkpoint_cfg["resume_tag"]
        engine.load_checkpoint(str(ckpt_dir), tag=tag)
        if engine.global_rank == 0:
            print(f"[DeepSpeed] Resumed from {ckpt_dir} tag={tag}")
    step_iter = iter(dataloader)
    epoch = 0
    for step in range(steps):
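        # Advance the sampler's epoch once per pass over the dataset so each
        # epoch is shuffled differently while staying consistent across ranks.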
        if sampler is not None and step % len(dataloader) == 0:
            sampler.set_epoch(epoch)
            epoch += 1
        try:
            batch = next(step_iter)
        except StopIteration:
            step_iter = iter(dataloader)
            batch = next(step_iter)
        tokens = batch.to(dist_ctx.device)
        logits = engine(tokens)
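        # Next-token objective: logits at position t are scored against token t+1.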
        loss = torch.nn.functional.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)), tokens[:, 1:].reshape(-1)
        )
        engine.backward(loss)
        engine.step()
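        # Gradient-free second forward pass feeds the teach signal back into
        # the model for the nested-learning update.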
        with torch.no_grad():
            teach_signal = compute_teach_signal(engine.module, logits, tokens)
            engine.module(tokens, teach_signal=teach_signal)
        if step % log_interval == 0 and engine.global_rank == 0:
            ppl = torch.exp(loss.detach()).item()
            logger.log({"loss": loss.item(), "ppl": ppl}, step=step)
            print(f"[DeepSpeed] step={step} loss={loss.item():.4f} ppl={ppl:.2f}")
        if (
            checkpoint_cfg.get("enable", False)
            and step % checkpoint_cfg.get("save_interval", 100) == 0
        ):
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            # save_checkpoint is collective in DeepSpeed: every rank must call
            # it (not just rank 0) so that each rank's optimizer/parameter
            # partition is persisted.
            engine.save_checkpoint(str(ckpt_dir), tag=f"step_{step:06d}")
    logger.finish()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()