diff --git a/recipes_source/torch_logs.py b/recipes_source/torch_logs.py index b5c3f0bd8a..05100745c2 100644 --- a/recipes_source/torch_logs.py +++ b/recipes_source/torch_logs.py @@ -32,47 +32,48 @@ import torch -# exit cleanly if we are on a device that doesn't support torch.compile -if torch.cuda.get_device_capability() < (7, 0): +# torch.compile supports CPU and CUDA devices with compute capability >= 7.0. +# If a CUDA device is available but has insufficient capability, we skip. +if torch.cuda.is_available() and torch.cuda.get_device_capability() < (7, 0): print("Skipping because torch.compile is not supported on this device.") else: + device = "cuda" if torch.cuda.is_available() else "cpu" + @torch.compile() def fn(x, y): z = x + y return z + 2 - - inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda")) + inputs = (torch.ones(2, 2, device=device), torch.zeros(2, 2, device=device)) -# print separator and reset dynamo -# between each example + # print separator and reset dynamo + # between each example def separator(name): print(f"==================={name}=========================") torch._dynamo.reset() - separator("Dynamo Tracing") -# View dynamo tracing -# TORCH_LOGS="+dynamo" + # View dynamo tracing + # TORCH_LOGS="+dynamo" torch._logging.set_logs(dynamo=logging.DEBUG) fn(*inputs) separator("Traced Graph") -# View traced graph -# TORCH_LOGS="graph" + # View traced graph + # TORCH_LOGS="graph" torch._logging.set_logs(graph=True) fn(*inputs) separator("Fusion Decisions") -# View fusion decisions -# TORCH_LOGS="fusion" + # View fusion decisions + # TORCH_LOGS="fusion" torch._logging.set_logs(fusion=True) fn(*inputs) separator("Output Code") -# View output code generated by inductor -# TORCH_LOGS="output_code" + # View output code generated by inductor + # TORCH_LOGS="output_code" torch._logging.set_logs(output_code=True) fn(*inputs)