diff --git a/text_to_video/wan-2.2-t2v-a14b/run_mlperf.py b/text_to_video/wan-2.2-t2v-a14b/run_mlperf.py
index 147624b340..ab73c25966 100644
--- a/text_to_video/wan-2.2-t2v-a14b/run_mlperf.py
+++ b/text_to_video/wan-2.2-t2v-a14b/run_mlperf.py
@@ -46,7 +46,8 @@ def load_prompts(dataset_path):
 
 
 class Model:
-    def __init__(self, model_path, device, config, prompts, fixed_latent=None, rank=0):
+    def __init__(self, model_path, device, config,
+                 prompts, fixed_latent=None, rank=0):
         self.device = device
         self.rank = rank
         self.height = config["height"]
@@ -106,7 +107,8 @@ def flush_queries(self):
 
 
 class DebugModel:
-    def __init__(self, model_path, device, config, prompts, fixed_latent=None, rank=0):
+    def __init__(self, model_path, device, config,
+                 prompts, fixed_latent=None, rank=0):
         self.prompts = prompts
 
     def issue_queries(self, query_samples):
@@ -186,7 +188,8 @@ def get_args():
     parser.add_argument(
         "--scenario",
         default="SingleStream",
-        help="mlperf benchmark scenario, one of " + str(list(SCENARIO_MAP.keys())),
+        help="mlperf benchmark scenario, one of " +
+        str(list(SCENARIO_MAP.keys())),
     )
     parser.add_argument(
         "--user_conf",
@@ -202,7 +205,10 @@ def get_args():
         help="performance sample count",
         default=5000,
     )
-    parser.add_argument("--accuracy", action="store_true", help="enable accuracy pass")
+    parser.add_argument(
+        "--accuracy",
+        action="store_true",
+        help="enable accuracy pass")
     # Dont overwrite these for official submission
     parser.add_argument("--count", type=int, help="dataset items to use")
     parser.add_argument("--time", type=int, help="time to scan in seconds")
@@ -271,7 +277,10 @@ def run_mlperf(args, config):
 
     audit_config = os.path.abspath(args.audit_conf)
     if os.path.exists(audit_config):
-        settings.FromConfig(audit_config, "wan-2.2-t2v-a14b", args.scenario)
+        settings.FromConfig(
+            audit_config,
+            "wan-2.2-t2v-a14b",
+            args.scenario)
 
     settings.scenario = SCENARIO_MAP[args.scenario]
     settings.mode = lg.TestMode.PerformanceOnly
@@ -297,8 +306,10 @@ def run_mlperf(args, config):
     if args.samples_per_query:
         settings.multi_stream_samples_per_query = args.samples_per_query
     if args.max_latency:
-        settings.server_target_latency_ns = int(args.max_latency * NANO_SEC)
-        settings.multi_stream_expected_latency_ns = int(args.max_latency * NANO_SEC)
+        settings.server_target_latency_ns = int(
+            args.max_latency * NANO_SEC)
+        settings.multi_stream_expected_latency_ns = int(
+            args.max_latency * NANO_SEC)
 
     performance_sample_count = (
         args.performance_sample_count
@@ -311,7 +322,8 @@ def run_mlperf(args, config):
         count, performance_sample_count, load_query_samples, unload_query_samples
     )
 
-    lg.StartTestWithLogSettings(sut, qsl, settings, log_settings, audit_config)
+    lg.StartTestWithLogSettings(
+        sut, qsl, settings, log_settings, audit_config)
 
     lg.DestroyQSL(qsl)
     lg.DestroySUT(sut)
diff --git a/vision/classification_and_detection/tools/calibrate_torchvision_model.py b/vision/classification_and_detection/tools/calibrate_torchvision_model.py
index 3b002003ab..875d26c388 100644
--- a/vision/classification_and_detection/tools/calibrate_torchvision_model.py
+++ b/vision/classification_and_detection/tools/calibrate_torchvision_model.py
@@ -7,7 +7,7 @@
 from torch.utils.data import DataLoader, Dataset
 
 import torchvision.transforms as transforms
-from torchvision.models.quantization import *
+import torchvision.models.quantization as torchvision_quantization_models
 
 
 class CalibrationDataset(Dataset):
@@ -73,7 +73,11 @@ def main():
     )
     dataloader = DataLoader(dataset, batch_size=1)
 
-    model = eval(args.model)(pretrained=True, progress=True, quantize=False)
+    if not hasattr(torchvision_quantization_models, args.model):
+        raise ValueError(
+            f"Model {args.model} not found in torchvision quantization models")
+    model = getattr(torchvision_quantization_models, args.model)(
+        pretrained=True, progress=True, quantize=False)
 
     quantize_model(model, dataloader)
     print(model)