-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathmicro_benchmarking_pytorch.py
More file actions
574 lines (518 loc) · 26 KB
/
micro_benchmarking_pytorch.py
File metadata and controls
574 lines (518 loc) · 26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
import torch
import torchvision
import random
import time
import argparse
import os
import sys
import ast
import copy
import math
import torch.nn as nn
import torch.multiprocessing as mp
from fp16util import network_to_half, get_param_copy
from shufflenet import shufflenet
from shufflenet_v2 import shufflenet as shufflenet_v2
from xception import xception
import csv
import json
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import LambdaLR
# Optional features. Each probe is best-effort: a missing or broken package
# only disables the matching feature. `Exception` (not a bare except) is used
# so Ctrl-C / SystemExit during import are not swallowed.
try:
    import torch._dynamo
    torch._dynamo.config.verbose = True
    HAVE_DYNAMO = True
except Exception:
    HAVE_DYNAMO = False

# torch.compile only exists in PyTorch 2.x.
IS_PT2 = hasattr(torch, "compile")

# LOCAL_RANK in the environment indicates we're being launched by torchrun.
is_torchrun = "LOCAL_RANK" in os.environ

try:
    import apex
    HAVE_APEX = True
except Exception:
    HAVE_APEX = False
def xform(m: nn.Module) -> nn.Module:
    """Move *m* onto the current CUDA device in channels_last memory format.

    Both ``Module.cuda`` and ``Module.to`` modify the module in place and
    return it, so the calls can be chained.
    """
    return m.cuda().to(memory_format=torch.channels_last)
def weight_init(m):
    """Re-initialize one module in place; intended for ``network.apply(weight_init)``.

    * ``nn.Conv2d``: weights drawn from a normal distribution with mean 0 and
      std ``sqrt(2 / n)``, where ``n = kernel_h * kernel_w * out_channels``
      (Kaiming-style fan-out scaling); the bias, if any, is zeroed.
    * ``nn.BatchNorm2d``: weight set to 1 and bias to 0. Layers built with
      ``affine=False`` have no weight/bias (they are None) and are skipped
      instead of crashing.
    * Every other module type is left untouched.
    """
    if isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()
# Classification networks, keyed by the name accepted on the command line.
# They are benchmarked against 1000 target classes (ImageNet). Several
# networks are reachable under more than one name, e.g. "resnext50" and
# "resnext50_32x4d" map to the same torchvision constructor.
models = {
    "alexnet": torchvision.models.alexnet,
    "densenet121": torchvision.models.densenet121,
    "densenet161": torchvision.models.densenet161,
    "densenet169": torchvision.models.densenet169,
    "densenet201": torchvision.models.densenet201,
    "googlenet": torchvision.models.googlenet,
    "inception_v3": torchvision.models.inception_v3,
    "mnasnet0_5": torchvision.models.mnasnet0_5,
    "mnasnet0_75": torchvision.models.mnasnet0_75,
    "mnasnet1_0": torchvision.models.mnasnet1_0,
    "mnasnet1_3": torchvision.models.mnasnet1_3,
    "mobilenet_v2": torchvision.models.mobilenet_v2,
    "resnet18": torchvision.models.resnet18,
    "resnet34": torchvision.models.resnet34,
    "resnet50": torchvision.models.resnet50,
    "resnet101": torchvision.models.resnet101,
    "resnet152": torchvision.models.resnet152,
    "resnext50": torchvision.models.resnext50_32x4d,
    "resnext50_32x4d": torchvision.models.resnext50_32x4d,
    "resnext101": torchvision.models.resnext101_32x8d,
    "resnext101_32x8d": torchvision.models.resnext101_32x8d,
    # local (non-torchvision) implementations
    "shufflenet": shufflenet,
    "shufflenet_v2": shufflenet_v2,
    # torchvision ShuffleNetV2 under short and long aliases
    "shufflenet_v2_x05": torchvision.models.shufflenet_v2_x0_5,
    "shufflenet_v2_x10": torchvision.models.shufflenet_v2_x1_0,
    "shufflenet_v2_x15": torchvision.models.shufflenet_v2_x1_5,
    "shufflenet_v2_x20": torchvision.models.shufflenet_v2_x2_0,
    "shufflenet_v2_x0_5": torchvision.models.shufflenet_v2_x0_5,
    "shufflenet_v2_x1_0": torchvision.models.shufflenet_v2_x1_0,
    "shufflenet_v2_x1_5": torchvision.models.shufflenet_v2_x1_5,
    "shufflenet_v2_x2_0": torchvision.models.shufflenet_v2_x2_0,
    "SqueezeNet": torchvision.models.squeezenet1_0,
    "squeezenet1_0": torchvision.models.squeezenet1_0,
    "SqueezeNet1.1": torchvision.models.squeezenet1_1,
    "squeezenet1_1": torchvision.models.squeezenet1_1,
    "vgg11": torchvision.models.vgg11,
    "vgg13": torchvision.models.vgg13,
    "vgg16": torchvision.models.vgg16,
    "vgg19": torchvision.models.vgg19,
    "vgg11_bn": torchvision.models.vgg11_bn,
    "vgg13_bn": torchvision.models.vgg13_bn,
    "vgg16_bn": torchvision.models.vgg16_bn,
    "vgg19_bn": torchvision.models.vgg19_bn,
    "wide_resnet50_2": torchvision.models.wide_resnet50_2,
    "wide_resnet101_2": torchvision.models.wide_resnet101_2,
    # local (non-torchvision) implementation
    "xception": xception,
}
# Newer torchvision classification models, for backwards compat with older
# torchvision releases. Each name is registered independently if the
# installed torchvision provides it, so one missing model does not hide the
# others (a single try-block would stop at the first missing attribute).
for _name in (
    "swin_t", "swin_s", "swin_b",
    "swin_v2_t", "swin_v2_s", "swin_v2_b",
    "vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32", "vit_h_14",
    "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3",
    "efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7",
    "maxvit_t",
    "mobilenet_v3_large", "mobilenet_v3_small",
):
    _builder = getattr(torchvision.models, _name, None)
    if _builder is not None:
        models[_name] = _builder
# Don't leave the loop variables behind as module globals.
del _name, _builder
# Segmentation networks, keyed by CLI name; benchmarked against 21 target
# classes.
segmentation_models = {
    "fcn_resnet50": torchvision.models.segmentation.fcn_resnet50,
    "fcn_resnet101": torchvision.models.segmentation.fcn_resnet101,
    "deeplabv3_resnet50": torchvision.models.segmentation.deeplabv3_resnet50,
    "deeplabv3_resnet101": torchvision.models.segmentation.deeplabv3_resnet101,
}
# Newer torchvision segmentation models, for backwards compat: each one is
# registered only when the installed torchvision provides it. The stored value
# must be the constructor itself (a stray trailing comma would store a
# 1-tuple, which get_network could not call).
for _name in ("deeplabv3_mobilenet_v3_large", "lraspp_mobilenet_v3_large"):
    _builder = getattr(torchvision.models.segmentation, _name, None)
    if _builder is not None:
        segmentation_models[_name] = _builder
# Don't leave the loop variables behind as module globals.
del _name, _builder
def get_network_names():
    """Return every supported network name (classification and segmentation), sorted."""
    return sorted([*models, *segmentation_models])
def get_network(net, params):
    """Build network *net* on CUDA, optionally in channels_last layout.

    Args:
        net: a key of ``models`` or ``segmentation_models``.
        params: parsed CLI namespace; ``params.nhwc`` selects channels_last
            (via ``xform``) instead of a plain move to ``"cuda"``.

    Exits the process with status 1 (after printing an error) when *net* is
    not a supported name.
    """
    if net in models:
        builder = models[net]
    elif net in segmentation_models:
        builder = segmentation_models[net]
    else:
        print ("ERROR: not a supported model '%s'" % net)
        sys.exit(1)

    # aux_logits=False only used by inception_v3
    kwargs = {"aux_logits": False} if net == "inception_v3" else {}
    model = builder(**kwargs)

    if params.nhwc:
        return xform(model)
    return model.to(device="cuda")
def forwardbackward(inp, optimizer, network, params, target, scaler, step=0, opt_step=1, flops_prof_step=0):
    """Run one training step: forward, cross-entropy loss, backward, and an
    optimizer update every ``opt_step`` calls.

    Gradients accumulate over ``opt_step`` consecutive steps: they are cleared
    when ``step % opt_step == 0`` and the optimizer is stepped (then cleared
    again) when ``(step + 1) % opt_step == 0``.

    Args:
        inp: input batch passed to ``network``.
        optimizer: optimizer over the network's parameters.
        network: model to train. If its output has a ``logits`` attribute
            (e.g. HuggingFace-style outputs), that attribute is used as logits.
        params: parsed CLI namespace; only ``params.amp`` is read here.
        target: class-index targets for ``CrossEntropyLoss``.
        scaler: ``GradScaler`` used for loss scaling when ``params.amp`` is
            set; ignored otherwise.
        step: index of this step; drives gradient accumulation.
        opt_step: number of steps per optimizer update (must be >= 1).
        flops_prof_step: when non-zero, profile this step with DeepSpeed's
            ``FlopsProfiler`` (forward and backward) and print its report
            labelled with this value.

    NOTE(review): torchvision segmentation models return a dict of outputs,
    which is not unpacked here -- confirm whether training mode is expected
    to support them.
    """
    if step % opt_step == 0:
        optimizer.zero_grad()

    prof = None
    if flops_prof_step:
        prof = FlopsProfiler(network)
        prof.start_profile()

    # CrossEntropyLoss (no class weights) has no parameters or buffers, so
    # moving it to a device or memory format would be a no-op.
    loss_fn = torch.nn.CrossEntropyLoss()
    update_now = (step + 1) % opt_step == 0

    if params.amp:
        with autocast('cuda'):
            out = network(inp)
            logits = getattr(out, 'logits', out)
            loss = loss_fn(logits, target)
        scaler.scale(loss).backward()
        if update_now:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    else:
        out = network(inp)
        logits = getattr(out, 'logits', out)
        loss = loss_fn(logits, target)
        loss.backward()
        if update_now:
            optimizer.step()
            optimizer.zero_grad()

    if prof is not None:
        # Profile ends here so both the forward and backward passes are
        # covered; stop before reporting, then release the profiler hooks.
        prof.stop_profile()
        prof.print_model_profile(profile_step=flops_prof_step)
        prof.end_profile()
def forward(inp, optimizer, network, params, target=None, scaler=None, step=0, opt_step=1, flops_prof_step=0):
    """Run one inference step under ``torch.no_grad()`` and return the logits.

    The signature mirrors ``forwardbackward`` so the two are interchangeable
    in the benchmark loop; ``optimizer``, ``target``, ``scaler``, ``step`` and
    ``opt_step`` are accepted but unused.

    Args:
        inp: input batch passed to ``network``.
        network: model to evaluate.
        params: parsed CLI namespace; ``params.amp`` enables CUDA autocast.
        flops_prof_step: when non-zero, profile this forward pass with
            DeepSpeed's ``FlopsProfiler`` and print its report.

    Returns:
        ``out.logits`` when the network output has that attribute (e.g.
        HuggingFace-style outputs), otherwise the raw output.
    """
    prof = None
    if flops_prof_step:
        prof = FlopsProfiler(network)
        prof.start_profile()

    with torch.no_grad():  # Disable gradient calculation
        if params.amp:
            with autocast('cuda'):
                out = network(inp)
        else:
            out = network(inp)

    if prof is not None:
        # Must run before returning: otherwise the profiler stays attached and
        # keeps patching every later iteration without ever reporting.
        prof.stop_profile()
        prof.print_model_profile(profile_step=flops_prof_step)
        prof.end_profile()

    return getattr(out, 'logits', out)
def rendezvous(distributed_parameters):
    """Join the default ``torch.distributed`` process group.

    Args:
        distributed_parameters: mapping with keys ``'dist_backend'``,
            ``'dist_url'``, ``'rank'`` and ``'world_size'``, forwarded to
            ``init_process_group`` as backend, init_method, rank and
            world_size.
    """
    print("Initializing process group...")
    cfg = distributed_parameters
    torch.distributed.init_process_group(
        backend=cfg['dist_backend'],
        init_method=cfg['dist_url'],
        rank=cfg['rank'],
        world_size=cfg['world_size'],
    )
    print("Rendezvous complete. Created process group...")
def run_benchmarking_wrapper(params):
    """Normalize *params*, resolve distributed settings, and launch the benchmark.

    *params* is modified in place:
    * ``flops_prof_step`` is clamped to ``[0, iterations - 1]``;
    * ``device_ids`` is parsed from "0,1,..." into a list of ints (or None);
    * ``distributed_parameters`` and ``ngpus`` are filled in.

    Launch modes:
    * under torchrun: runs this worker directly. The process-group rank is
      the global ``RANK``; ``LOCAL_RANK`` only picks the GPU on this node.
    * ``--distributed_dataparallel``: spawns one process per device with
      ``torch.multiprocessing.spawn``, scaling rank/world_size by the device
      count of this launcher.
    * otherwise: a single process on one device.

    Raises:
        ValueError: if ``--distributed_dataparallel`` is given without rank,
            world-size, dist-backend and dist-url.
    """
    params.flops_prof_step = max(0, min(params.flops_prof_step, params.iterations - 1))
    if params.device_ids:
        params.device_ids = [int(x) for x in params.device_ids.split(",")]
    else:
        params.device_ids = None

    # Index of this worker's GPU on its node; only meaningful under torchrun.
    local_rank = 0
    if is_torchrun:
        local_rank = int(os.environ["LOCAL_RANK"])
        params.distributed_parameters = {
            # LOCAL_RANK repeats on every node of a multi-node job, so the
            # process group must use the global RANK (identical on one node).
            'rank': int(os.environ.get("RANK", local_rank)),
            'world_size': int(os.environ["WORLD_SIZE"]),
            'dist_backend': "nccl",
            'dist_url': 'tcp://' + os.environ["MASTER_ADDR"] + ":" + os.environ["MASTER_PORT"],
        }
    else:
        params.distributed_parameters = {
            'rank': params.rank,
            'world_size': params.world_size,
            'dist_backend': params.dist_backend,
            'dist_url': params.dist_url,
        }
    dist = params.distributed_parameters

    # Explicit check (not assert) so it still runs under `python -O`.
    if params.distributed_dataparallel and any(value is None for value in dist.values()):
        raise ValueError("rank, world-size, dist-backend and dist-url are required arguments for distributed_dataparallel")

    if is_torchrun:
        params.ngpus = dist['world_size']
    elif params.distributed_dataparallel:
        params.ngpus = len(params.device_ids) if params.device_ids else torch.cuda.device_count()
    else:
        params.ngpus = 1

    if is_torchrun:
        run_benchmarking(local_rank, params)
    elif params.distributed_dataparallel:
        # Assumption below that each process launched with --distributed_dataparallel has the same number of devices visible/specified
        dist['world_size'] = params.ngpus * dist['world_size']
        dist['rank'] = params.ngpus * dist['rank']
        mp.spawn(run_benchmarking, nprocs=params.ngpus, args=(params,))
    else:
        run_benchmarking(0, params)
def _compile_network(network, params):
    """Return *network* wrapped by ``torch.compile`` as configured by --compileContext.

    --compileContext, when given, is a Python-literal dict whose keys override
    the defaults below; values under ``"options"`` are coerced to bool.
    Exits with status 1 if this PyTorch has no ``torch.compile``.

    Raises:
        RuntimeError: if both ``mode`` and ``options`` are specified.
    """
    compile_ctx = {"mode": None,
                   "dynamic": False,
                   "fullgraph": False,
                   "backend": "inductor",
                   "options": None,
                   "disable": False}
    if params.compileContext:
        compile_ctx.update(ast.literal_eval(params.compileContext))
    if compile_ctx["mode"] is not None and compile_ctx["options"] is not None:
        raise RuntimeError("Cannot specify mode and options simultaneously")
    options = None  # needed for internal pytorch checks
    if compile_ctx["options"] is not None:
        options = {name: bool(flag) for name, flag in compile_ctx["options"].items()}
    if not IS_PT2:
        print ("ERROR: requested torch.compile but this isn't pytorch 2.x")
        sys.exit(1)
    return torch.compile(network,
                         mode=compile_ctx["mode"],
                         dynamic=bool(compile_ctx["dynamic"]),
                         fullgraph=bool(compile_ctx["fullgraph"]),
                         backend=compile_ctx["backend"],
                         options=options,
                         disable=compile_ctx["disable"])


def _write_json_result(params, result):
    """Write *result* to ``<output_dir>/<output_file>.json``, overwriting it."""
    with open(f"{params.output_dir}/{params.output_file}.json", "w") as f:
        json.dump(result, f, indent=2)


def _append_csv_result(params, result):
    """Append *result* as one row of the summary CSV, adding a header to a new file.

    The file is ``--csv-file`` when given, else
    ``<output_dir>/benchmark_summary.csv``.
    """
    csv_filename = params.csv_file or f"{params.output_dir}/benchmark_summary.csv"
    file_exists = os.path.isfile(csv_filename)
    with open(csv_filename, "a", newline='') as csvfile:
        writer = csv.writer(csvfile)
        if not file_exists:
            writer.writerow(result.keys())
        writer.writerow(result.values())
    print(f"Benchmark result saved to {csv_filename}")


def run_benchmarking(local_rank, params):
    """Benchmark ``params.network`` on one device and report its throughput.

    Args:
        local_rank: index of this process on its node. Under torchrun it is the
            CUDA device index; with --distributed_dataparallel it indexes
            ``params.device_ids`` (or is the device index) and is added to the
            launcher's base rank.
        params: parsed CLI namespace, already normalized by
            ``run_benchmarking_wrapper`` (device_ids, ngpus,
            distributed_parameters).

    Prints a summary. A single-process run, or rank 0 of a distributed run,
    also writes ``<output_dir>/<output_file>.json`` and appends a CSV row.

    Raises:
        ValueError: if ``device_ids`` is given but its length differs from
            ``ngpus``.
    """
    device_ids = params.device_ids
    ngpus = params.ngpus
    net = params.network
    run_fp16 = params.fp16
    run_amp = params.amp
    distributed_dataparallel = params.distributed_dataparallel
    distributed_parameters = params.distributed_parameters
    batch_size = params.batch_size
    kineto = params.kineto
    iterations = params.iterations
    autograd_profiler = params.autograd_profiler
    flops_prof_step = params.flops_prof_step

    # Pin this process to its CUDA device before building anything on it.
    if is_torchrun:
        torch.cuda.set_device("cuda:%d" % local_rank)
    elif device_ids:
        if ngpus != len(device_ids):
            raise ValueError("expected %d device id(s), got %d: %s" % (ngpus, len(device_ids), device_ids))
        torch.cuda.set_device("cuda:%d" % device_ids[local_rank])
    else:
        torch.cuda.set_device("cuda:0")

    network = get_network(net, params)
    if "shufflenet" == net:
        network.apply(weight_init)
    if run_fp16:
        network = network_to_half(network)
    if params.compile:
        network = _compile_network(network, params)

    # With --fp16 the optimizer is built over get_param_copy(network) rather
    # than the network's own parameters.
    param_copy = get_param_copy(network) if run_fp16 else network.parameters()

    ## MLPerf Setting
    sgd_opt_base_learning_rate = 0.01
    sgd_opt_end_learning_rate = 1e-4
    sgd_opt_learning_rate_decay_poly_power = 2
    sgd_opt_weight_decay = 0.0001
    sgd_opt_momentum = 0.9
    opt_learning_rate_warmup_epochs = 5
    total_epochs = params.iterations
    optimizer = torch.optim.SGD(param_copy, lr=sgd_opt_base_learning_rate,
                                momentum=sgd_opt_momentum, weight_decay=sgd_opt_weight_decay)

    def poly_decay(epoch):
        """LR multiplier: linear warmup, then polynomial decay towards the end LR."""
        if epoch < opt_learning_rate_warmup_epochs:
            return float(epoch + 1) / opt_learning_rate_warmup_epochs
        poly = ((1 - (epoch - opt_learning_rate_warmup_epochs) / (total_epochs - opt_learning_rate_warmup_epochs)) ** sgd_opt_learning_rate_decay_poly_power)
        return (sgd_opt_end_learning_rate + (sgd_opt_base_learning_rate - sgd_opt_end_learning_rate) * poly) / sgd_opt_base_learning_rate

    # NOTE(review): scheduler.step() is never called below, so only the
    # epoch-0 factor that LambdaLR applies at construction takes effect.
    scheduler = LambdaLR(optimizer, lr_lambda=poly_decay)

    if is_torchrun or distributed_dataparallel:
        if is_torchrun:
            devices_to_run_on = [local_rank]
        else:
            # The wrapper stored this launcher's base rank; add the local index.
            distributed_parameters['rank'] += local_rank
            devices_to_run_on = [(device_ids[local_rank] if device_ids else local_rank)]
        rendezvous(distributed_parameters)
        print ("INFO: Rank {} running distributed_dataparallel on devices: {}".format(distributed_parameters['rank'], str(devices_to_run_on)))
        network = torch.nn.parallel.DistributedDataParallel(network, device_ids=devices_to_run_on)
        # --batch-size is split among the devices.
        batch_size = int(batch_size / ngpus)

    # inception_v3 is fed 299x299 images, every other network 224x224.
    image_size = 299 if net == "inception_v3" else 224
    inp = torch.randn(batch_size, 3, image_size, image_size, device="cuda")
    if run_fp16:
        inp = inp.half()
    if params.nhwc:
        inp = inp.to(memory_format=torch.channels_last)

    if net in models:
        # number of classes is 1000 for imagenet
        target = torch.randint(0, 1000, (batch_size,), device="cuda")
    elif net in segmentation_models:
        # number of classes is 21 for segmentation
        # NOTE(review): a per-pixel loss would need an (N, H, W) target -- confirm.
        target = torch.randint(0, 21, (batch_size,), device="cuda")

    if params.mode == "training":
        forward_fn = forwardbackward
        network.train()
    else:
        forward_fn = forward
        network.eval()
    scaler = GradScaler('cuda')

    ## warmup.
    print ("INFO: running forward and backward for warmup.")
    for _ in range(20):
        forward_fn(inp, optimizer, network, params, target, scaler=scaler, step=0, opt_step=params.opt_step)
    torch.cuda.synchronize()

    ## benchmark.
    print ("INFO: running the benchmark..")
    if kineto:
        from torch.profiler import schedule, profile, ProfilerActivity, record_function
        profiler_schedule = schedule(skip_first=0, wait=1, warmup=2, active=5, repeat=1)

        def trace_ready_callback(prof):
            # Only rank 0 (or the single process) exports the Chrome trace.
            rank = 0
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                rank = torch.distributed.get_rank()
            if rank == 0:
                print("----------- Trace Ready -----------")
                prof.export_chrome_trace(f"{params.profiler_output}.json")

        with profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                schedule=profiler_schedule,
                on_trace_ready=trace_ready_callback) as prof:
            tm = time.perf_counter()
            for i in range(iterations):
                with record_function(f"iteration {i}"):
                    forward_fn(inp, optimizer, network, params, target, scaler=scaler, step=i, opt_step=params.opt_step)
                prof.step()
            torch.cuda.synchronize()
            tm2 = time.perf_counter()
        # Summarize only after the clock has stopped so building this table
        # does not inflate time_per_batch.
        print(prof.key_averages().table(sort_by="cuda_time_total"))
    else:
        tm = time.perf_counter()
        with torch.autograd.profiler.emit_nvtx(enabled=autograd_profiler):
            for i in range(iterations):
                # Only the iteration chosen by --flops-prof-step gets a
                # non-zero flops_prof_step (0 disables FLOPs profiling).
                prof_step = i if i == flops_prof_step else 0
                forward_fn(inp, optimizer, network, params, target, scaler=scaler, step=i, opt_step=params.opt_step, flops_prof_step=prof_step)
            torch.cuda.synchronize()
        tm2 = time.perf_counter()
    time_per_batch = (tm2 - tm) / iterations

    if run_fp16:
        dtype = 'FP16'
    elif run_amp:
        dtype = 'AMP: PyTorch Native Automatic Mixed Precision'
    else:
        dtype = 'FP32'

    result = None
    if not params.output_dir:
        params.output_dir = "."
    is_distributed = distributed_dataparallel or is_torchrun
    print ("OK: finished running benchmark..")
    print ("--------------------SUMMARY--------------------------")
    print ("Microbenchmark for network : {}".format(net))
    if is_distributed:
        print ("--------This process: rank " + str(distributed_parameters['rank']) + "--------")
        print ("Num devices: 1")
    else:
        print ("Num devices: {}".format(ngpus))
        result = {
            "Name": params.output_file,
            "GPUs": 1,
            "Mini batch size [img]": batch_size,
            "Mini batch size [img/gpu]": batch_size,
            "Throughput [img/sec]": batch_size / time_per_batch,
            "Time per mini-batch": time_per_batch
        }
        _write_json_result(params, result)
    print ("Dtype: {}".format(dtype))
    print ("Mini batch size [img] : {}".format(batch_size))
    print ("Throughput [img/sec] : {}".format(batch_size/time_per_batch))
    print ("Time per mini-batch : {}".format(time_per_batch))

    if is_distributed and distributed_parameters['rank'] == 0:
        print ("")
        print ("--------Overall (all ranks) (assuming same num/type devices for each rank)--------")
        world_size = distributed_parameters['world_size']
        print ("Num devices: {}".format(world_size))
        print ("Dtype: {}".format(dtype))
        print ("Mini batch size [img] : {}".format(batch_size*world_size))
        print ("Throughput [img/sec] : {}".format(batch_size*world_size/time_per_batch))
        print ("Time per mini-batch : {}".format(time_per_batch))
        result = {
            "Name": params.output_file,
            "GPUs": world_size,
            "Mini batch size [img]": batch_size * world_size,
            "Mini batch size [img/gpu]": batch_size,
            "Throughput [img/sec]": batch_size * world_size / time_per_batch,
            "Time per mini-batch": time_per_batch
        }
        _write_json_result(params, result)

    # Only processes that produced a result (single-process or rank 0) record it.
    if result:
        _append_csv_result(params, result)
def main(params=None):
    """Run the benchmark on a deep copy of *params*.

    Args:
        params: argparse-style namespace. Defaults to the module-level
            ``args`` parsed from the command line, so ``main()`` keeps working
            as before; passing a namespace allows programmatic use.

    A copy is used because run_benchmarking_wrapper rewrites several fields
    of its argument in place (flops_prof_step, device_ids, ngpus,
    distributed_parameters).
    """
    if params is None:
        params = args
    run_benchmarking_wrapper(copy.deepcopy(params))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--network", type=str, choices=get_network_names(), required=True, help="Network to run.")
    parser.add_argument("--batch-size" , type=int, required=False, default=64, help="Batch size (will be split among devices used by this invocation)")
    parser.add_argument("--iterations", type=int, required=False, default=20, help="Iterations")
    parser.add_argument("--flops-prof-step", type=int, required=False, default=0, help="The flops profiling step")
    parser.add_argument("--kineto", action='store_true', required=False, help="Turn kineto profiling on")
    parser.add_argument("--autograd_profiler", action='store_true', required=False, help="Use PyTorch autograd (old) profiler")
    parser.add_argument("--fp16", type=int, required=False, default=0,help="FP16 mixed precision benchmarking")
    parser.add_argument("--distributed_dataparallel", action='store_true', required=False, help="Use torch.nn.parallel.DistributedDataParallel api to run on multiple processes/nodes. The multiple processes need to be launched manually, this script will only launch ONE process per invocation. Either use --distributed_dataparallel and manually launch multiple processes or launch this script with `torchrun`")
    parser.add_argument("--device_ids", type=str, required=False, default=None, help="Comma-separated list (no spaces) to specify which HIP devices (0-indexed) to run distributedDataParallel api on. Might need to use HIP_VISIBLE_DEVICES to limit visiblity of devices to different processes.")
    parser.add_argument("--rank", type=int, required=False, default=None, help="Rank of this process. Required for --distributed_dataparallel")
    parser.add_argument("--world-size", type=int, required=False, default=None, help="Total number of ranks/processes. Required for --distributed_dataparallel")
    parser.add_argument("--dist-backend", type=str, required=False, default=None, help="Backend used for distributed training. Can be one of 'nccl' or 'gloo'. Required for --distributed_dataparallel")
    parser.add_argument("--dist-url", type=str, required=False, default=None, help="url used for rendezvous of processes in distributed training. Needs to contain IP and open port of master rank0 eg. 'tcp://172.23.2.1:54321'. Required for --distributed_dataparallel")
    parser.add_argument("--compile", action='store_true', required=False, help="use pytorch 2.0")
    parser.add_argument("--compileContext", default={}, required=False, help="additional compile options")
    parser.add_argument("--amp", action='store_true', default=False, required=False, help="Automatic mixed precision benchmarking")
    parser.add_argument("--csv-file", type=str, default=None, required=False, help="assign output csv file name.")
    parser.add_argument("--mode", type=str, choices=['training', 'inference'], default="training", help="Select mode: training or inference")
    parser.add_argument("--nhwc", action='store_true', default=False, help="Use nhwc format")
    parser.add_argument("--opt-step", type=int, required=False, default=1, help="Optimizer update step")
    parser.add_argument("--output-dir", type=str, default="", help="assign output directory name.")
    parser.add_argument("--output-file", type=str, default="", help="assign output file name.")
    parser.add_argument("--profiler-output", type=str, default="", help="assign profiler output name.")
    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside the run: the time
    # per batch is divided by --iterations and the step index is taken
    # modulo --opt-step.
    if args.iterations < 1:
        parser.error("--iterations must be >= 1")
    if args.opt_step < 1:
        parser.error("--opt-step must be >= 1")

    if args.flops_prof_step:
        # Bound at module level so forward/forwardbackward can find it.
        try:
            from deepspeed.profiling.flops_profiler import FlopsProfiler
        except ImportError:
            print("ERROR: You must install (or copy) deepspeed.profiling to use --flops-prof-step")
            sys.exit(1)
    main()