Commit ccdbdcd

feat: multi-rank adaptation (dist)
1 parent 9fbb804 commit ccdbdcd

File tree

9 files changed: +76 -46 lines changed

    .pre-commit-config.yaml
    README.md
    padiff/abstracts/hooks/guard.py
    padiff/abstracts/proxy/__init__.py
    padiff/abstracts/proxy/model.py
    padiff/cli.py
    padiff/tools/dump.py
    padiff/utils/log.py
    padiff/utils/utils.py

.pre-commit-config.yaml
Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 repos:
   # Common hooks
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.1.0
+    rev: v6.0.0
     hooks:
       - id: check-added-large-files
       - id: check-merge-conflict

README.md
Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ from padiff.utils import logger
 from padiff import compare_dumps

 if __name__ == "__main__":
-    logger.reset_dir("./padiff_log")
+    logger.setup("./padiff_log")

     cfg = {
         "atol": 1e-6,

padiff/abstracts/hooks/guard.py
Lines changed: 1 addition & 1 deletion

@@ -248,7 +248,7 @@ def PaDiffGuard(
     # set max calls
     calls_context.set_limit(max_calls)

-    proxy_model = create_model(model, name=name, reset_dir=reset_flag)
+    proxy_model = create_model(model, name=name, reset=reset_flag)
     model._padiff_proxy = proxy_model
     logger.debug(f"PaDiffGuard: creating proxy model.")

padiff/abstracts/proxy/__init__.py
Lines changed: 8 additions & 6 deletions

@@ -14,8 +14,9 @@

 # this folder is just used to support assign_weight interface

+import os
 import paddle
-from ...utils import logger
+from ...utils import logger, reset_dir
 from .model import ProxyModel


@@ -42,14 +43,15 @@ def remove_inplace(model):
         submodel.inplace = False


-def create_model(model, name=None, dump_freq=1, reset_dir=True):
+def create_model(model, name=None, dump_freq=1, reset=True):
     retval = ProxyModel.create_from(model, name, dump_freq)
     init_route(retval)
-    if retval.framework == "paddle" and paddle.distributed.get_rank() % 8 == 0 and reset_dir:
+    if retval.framework == "paddle" and retval.rank // 8 == 0 and reset:
         # Only reset the root path once for each machine, here we assume each machine has 8 GPUs
-        logger.reset_dir(retval.dump_path)
+        reset_dir(retval.dump_path)
     if retval.framework == "torch":
-        if reset_dir:
-            logger.reset_dir(retval.dump_path)
+        if reset:
+            reset_dir(retval.dump_path)
     remove_inplace(retval)
+    logger.setup(retval.dump_path)
     return retval
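
Read in isolation, the reset gating in create_model boils down to a small predicate. A minimal sketch of the committed behavior (should_reset is an illustrative name, not part of the commit; the 8-GPUs-per-machine figure comes from the inline comment):

    def should_reset(framework: str, rank: int, reset: bool) -> bool:
        # Paddle branch: the inline comment intends one reset per machine.
        # Note that rank // 8 == 0 holds only for ranks 0-7 (the first machine),
        # whereas rank % 8 == 0 would pick the first rank of every machine.
        if framework == "paddle":
            return reset and rank // 8 == 0
        # Torch branch: every rank may reset its own rank-suffixed dump path.
        if framework == "torch":
            return reset
        return False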

padiff/abstracts/proxy/model.py
Lines changed: 9 additions & 12 deletions

@@ -20,7 +20,7 @@
 from ..marker import Marker
 from ..report import Report
 from ...tools import dump_grads, dump_params, dump_report, dump_weights, get_dump_root_path
-from ...utils import deco_iter, logger
+from ...utils import deco_iter, logger, get_rank, reset_dir
 from .params import ProxyParam

@@ -34,7 +34,8 @@ def __init__(self, model, name, framework, dump_freq=1):
         self.report = Report(self.marker)
         self.step = 0

-        self.dump_path = get_dump_root_path() + "/" + self.name
+        self.rank = get_rank(framework)
+        self.dump_path = f"{get_dump_root_path()}/{self.name}/rank_{self.rank}"

         self.dump_freq = dump_freq
         if self.dump_freq > 1:
@@ -190,31 +191,27 @@ def clear_report(self):
     def try_dump(self, dump_path=None):
         if self.step % self.dump_freq == 0:
             if dump_path is None:
-                dump_path = f"{self.dump_path}/step_{self.step}/rank_{paddle.distributed.get_rank()}"
-            logger.reset_dir(dump_path)
+                dump_path = f"{self.dump_path}/step_{self.step}/rank_{self.rank}"
+            reset_dir(dump_path)
             self.dump_params(dump_path)
             self.dump_report(dump_path)
             self.clear_report()
         self.step += 1

     def dump_report(self, dump_path=None):
-        if dump_path is None:
-            dump_path = f"{self.dump_path}/rank_{paddle.distributed.get_rank()}"
+        dump_path = self.dump_path if dump_path is None else dump_path
         dump_report(self, dump_path)

     def dump_params(self, dump_path=None):
-        if dump_path is None:
-            dump_path = f"{self.dump_path}/rank_{paddle.distributed.get_rank()}"
+        dump_path = self.dump_path if dump_path is None else dump_path
         dump_params(self, dump_path)

     def dump_weights(self, dump_path=None):
-        if dump_path is None:
-            dump_path = f"{self.dump_path}/rank_{paddle.distributed.get_rank()}"
+        dump_path = self.dump_path if dump_path is None else dump_path
         dump_weights(self, dump_path)

     def dump_grads(self, dump_path=None):
-        if dump_path is None:
-            dump_path = f"{self.dump_path}/rank_{paddle.distributed.get_rank()}"
+        dump_path = self.dump_path if dump_path is None else dump_path
         dump_grads(self, dump_path)

     """

padiff/cli.py
Lines changed: 1 addition & 1 deletion

@@ -228,7 +228,7 @@ def main():
     cli_cfg["base_framework"] = args.base_framework

     log_dir = cli_cfg.pop("log_dir", "./padiff_log")
-    logger.reset_dir(log_dir)
+    logger.setup(log_dir)

     pt_cmd = cli_cfg.get("pt_cmd")
     pd_cmd = cli_cfg.get("pd_cmd")

padiff/tools/dump.py
Lines changed: 25 additions & 3 deletions

@@ -19,7 +19,7 @@
 import numpy
 import paddle

-from ..utils import Counter, frames_to_string, logger, save_model_struct, get_numpy_from_tensor
+from ..utils import Counter, frames_to_string, logger, save_model_struct, get_numpy_from_tensor, reset_dir

 dump_root_path = os.path.join(sys.path[0], "padiff_dump")

@@ -34,7 +34,7 @@ def get_dump_root_path():


 def numpy_dumper(path, prefix):
-    logger.reset_dir(path)
+    reset_dir(path)
     counter = Counter()

     def dumper(value):
@@ -51,8 +51,30 @@ def dumper(value):
 """


+def report_deduplicate(report):
+    """
+    In some cases, it is necessary to deduplicate report (e.g., when pipeline parallelism > 1
+    and gradient accumulation steps > 1).
+    """
+    if not report.stack.root:
+        return report
+
+    base_root_str = report.stack.root[0].net_str
+    idx_repeat = len(report.stack.root)
+    for i in range(1, len(report.stack.root)):
+        if report.stack.root[i].net_str == base_root_str:
+            idx_repeat = i
+            break
+    report.stack.root = report.stack.root[:idx_repeat]
+    logger.warning_once(
+        "The report contains duplications, which might occur when pipeline parallelism > 1 and gradient accumulation > 1. "
+        "The report is automatically deduplicated, please note if this deduplication was incorrect."
+    )
+    return report
+
+
 def dump_report(model, dump_path):
-    report = model.report
+    report = report_deduplicate(model.report)
     tensor_path = dump_path + "/tensors"
     tensor_dumper = numpy_dumper(tensor_path, "tensor")

padiff/utils/log.py
Lines changed: 21 additions & 21 deletions

@@ -30,7 +30,6 @@
 class Logger:
     def __init__(self):
         self._logger = None
-        self._is_initialized = False
         self.log_path = "padiff_log"

         for key, conf in log_config.items():
@@ -41,26 +40,25 @@ def __init__(self):
                 log_colors={key: conf["color"] for key, conf in log_config.items()},
             )

-    def setup(self, log_parent_dir):
-        if self._is_initialized:
-            return
-
-        self._logger = logging.getLogger("padiff")
+    def setup(self, log_root_dir):
+        os.makedirs(log_root_dir, exist_ok=True)

         silent_flag = os.getenv("PADIFF_SILENT")
         log_level_flag = os.getenv("PADIFF_LOG_LEVEL")
-
         if log_level_flag and log_level_flag.upper() in ("DEBUG", "INFO", "WARNING", "ERROR"):
             log_level = getattr(logging, log_level_flag.upper())
         else:
             log_level = logging.INFO
-        self._logger.setLevel(log_level)
-        self._logger.propagate = False

-        if self._logger.handlers:
-            self._logger.handlers.clear()
+        if self._logger is None:
+            self._logger = logging.getLogger("padiff")
+        self._logger.setLevel(log_level)
+        self._logger.propagate = False
+
+        if self._logger.handlers:
+            self._logger.handlers.clear()

-        log_file_path = os.path.join(log_parent_dir, "padiff.log")
+        log_file_path = os.path.join(log_root_dir, "padiff.log")
         file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
         file_formatter = logging.Formatter("[AutoDiff] [%(levelname)s] %(message)s")
         file_handler.setFormatter(file_formatter)
@@ -74,7 +72,7 @@ def setup(self, log_parent_dir):
         self._logger.info(f"Logging initialized. Log file: {log_file_path}")

         self._is_initialized = True
-        self.log_path = log_parent_dir
+        self.log_path = log_root_dir

     def info(self, *args):
         if self._logger is not None:
@@ -108,14 +106,10 @@ def debug(self, *args):
     def debug_once(self, *args):
         self.debug(*args)

-    def reset_dir(self, path):
-        if os.path.exists(path):
-            shutil.rmtree(path)
-        os.makedirs(path)
-        self.setup(path)
-
-    def log_file(self, filename, mode, info):
-        filepath = os.path.join(self.log_path, filename)
+    def log_file(self, filename, mode, info, root_dir=None):
+        if root_dir is None:
+            root_dir = self.log_path
+        filepath = os.path.join(root_dir, filename)
         with open(filepath, mode) as f:
             f.write(info)
         return filepath
@@ -129,6 +123,12 @@ def log_file(self, filename, mode, info):
 """


+def reset_dir(path):
+    if os.path.exists(path):
+        shutil.rmtree(path)
+    os.makedirs(path)
+
+
 def print_report_info(nodes, reports, exc, stage, msg=None):

     logger.error(
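
With reset_dir promoted to a module-level function, wiping a directory and initializing the logger become separate, repeatable steps. A minimal usage sketch (the directory name is arbitrary):

    from padiff.utils import logger, reset_dir

    reset_dir("./padiff_log")     # rmtree if the directory exists, then recreate it
    logger.setup("./padiff_log")  # safe to call repeatedly: handlers are rebuilt each time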

padiff/utils/utils.py
Lines changed: 9 additions & 0 deletions

@@ -57,6 +57,15 @@ def get_numpy_from_tensor(tensor):
     return np_array


+def get_rank(framework: str):
+    rank = 0
+    if framework == "paddle" and paddle.distributed.is_initialized():
+        rank = paddle.distributed.get_rank()
+    elif framework == "torch" and torch.distributed.is_initialized():
+        rank = torch.distributed.get_rank()
+    return rank
+
+
 """
 clone tensor
 """
