-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcommon.py
More file actions
30 lines (25 loc) · 1.08 KB
/
common.py
File metadata and controls
30 lines (25 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
import tempfile
from pytorch_lightning.utilities.deepspeed import (
convert_zero_checkpoint_to_fp32_state_dict,
)
def _is_deepspeed_checkpoint(path: str):
if not os.path.exists(path):
raise FileExistsError(f"Checkpoint {path} does not exist.")
return os.path.isdir(path) and os.path.exists(os.path.join(path, "zero_to_fp32.py"))
def load_checkpoint(model_cls, ckpt_path: str, device, freeze: bool):
    """Load a Lightning checkpoint onto ``device``, handling DeepSpeed format.

    DeepSpeed (ZeRO) checkpoints are first collapsed into a single fp32
    state-dict file in a temporary directory, then loaded through the
    regular Lightning path.

    Args:
        model_cls: LightningModule class providing ``load_from_checkpoint``.
        ckpt_path: Path to a checkpoint file or DeepSpeed checkpoint directory.
        device: Target device passed to ``model.to``.
        freeze: If True, call ``model.freeze()`` before returning.

    Returns:
        The loaded model, moved to ``device``.
    """
    if _is_deepspeed_checkpoint(ckpt_path):
        with tempfile.TemporaryDirectory() as tmpdir:
            # Consolidate the sharded ZeRO checkpoint into one fp32 file
            # that Lightning can read directly.
            fp32_path = os.path.join(tmpdir, "lightning.cpkt")
            convert_zero_checkpoint_to_fp32_state_dict(ckpt_path, fp32_path)
            model = model_cls.load_from_checkpoint(fp32_path, strict=False)
        model = model.to(device)
    else:
        model = model_cls.load_from_checkpoint(ckpt_path, strict=False).to(device)
    if freeze:
        model.freeze()
    return model
def zip_strict(*args):
    """Zip sized sequences, requiring them all to have the same length.

    Unlike plain ``zip``, which silently truncates to the shortest input,
    this raises on a length mismatch.

    Args:
        *args: One or more sized sequences (must support ``len``).

    Returns:
        A ``zip`` iterator over the sequences.

    Raises:
        ValueError: If the sequences do not all have the same length.
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would silently turn this back into a truncating zip.
    if args and any(len(a) != len(args[0]) for a in args[1:]):
        lengths = [len(a) for a in args]
        raise ValueError(f"zip_strict: sequences have unequal lengths: {lengths}")
    return zip(*args)