-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy path torch_prof_optim.py
More file actions
136 lines (113 loc) · 5.13 KB
/
torch_prof_optim.py
File metadata and controls
136 lines (113 loc) · 5.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import torch.nn as nn
import torch.profiler
from contextlib import nullcontext
# A simplified Transformer Encoder Layer to demonstrate profiling
class SimpleTransformerLayer(nn.Module):
    """
    A minimal Transformer Encoder Layer used to demonstrate profiler labels.

    Structure: Multi-Head Self-Attention -> Add & Norm -> MLP -> Add & Norm,
    with dropout on both residual branches.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        # Keep the configuration around for introspection/debugging.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim

        # Multi-head self-attention; batch_first=True means inputs are
        # (batch, seq, embed_dim).
        self.self_attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )

        # Position-wise feed-forward network: expand to ff_dim, contract back.
        ff_layers = [
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
        ]
        self.mlp = nn.Sequential(*ff_layers)

        # One LayerNorm + dropout per residual branch (post-norm style).
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src):
        """
        Run one encoder-layer pass over `src` (batch, seq, embed_dim).

        Each stage is wrapped in record_function so it shows up as a
        distinct, named region in the profiler trace.
        """
        with torch.profiler.record_function("Self-Attention"):
            # Self-attention: query, key, and value all come from `src`.
            # OPTIM set need_weights=False to allow flash attn kernel use
            attended, _ = self.self_attention(src, src, src, need_weights=False)

        with torch.profiler.record_function("Add & Norm 1"):
            # Residual connection around attention, then normalize.
            src = self.norm1(src + self.dropout1(attended))

        with torch.profiler.record_function("MLP"):
            ff_out = self.mlp(src)

        with torch.profiler.record_function("Add & Norm 2"):
            # Residual connection around the MLP, then normalize.
            src = self.norm2(src + self.dropout2(ff_out))

        return src
def main():
    """
    Profile a single Transformer encoder layer's forward pass.

    Builds a SimpleTransformerLayer, runs it for 10 iterations under
    torch.profiler with a wait/warmup/active schedule, writes a TensorBoard
    trace to ./log/transformer, and prints a key-averages summary.

    On CUDA: enables the flash SDPA backend, casts model and input to fp16,
    and runs under autocast. On CPU: stays in fp32 (fp16 kernels are not
    generally available on CPU) and profiles CPU activity only.
    """
    # --- Model and Input Configuration ---
    batch_size = 32
    # OPTIM increase sequence length and batch size
    seq_length = 2048
    embed_dim = 512  # d_model
    num_heads = 8
    ff_dim = 2048  # Hidden dimension in MLP

    # Check for CUDA availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    print(f"Using device: {device}")

    # --- Instantiate Model and Create Dummy Input ---
    model = SimpleTransformerLayer(embed_dim, num_heads, ff_dim).to(device)
    model.eval()  # Set to evaluation mode for profiling inference

    if use_cuda:
        # --- OPTIM flip on Flash-SDPA back-end when on CUDA
        torch.backends.cuda.enable_flash_sdp(True)
        # OPTIM do everything but layernorms (see autocast below) in fp16.
        # fp16 is GPU-only here: CPU lacks fp16 kernels for these ops.
        model = model.half()
        dummy_input = torch.randn(
            batch_size, seq_length, embed_dim, dtype=torch.float16, device=device
        )
    else:
        # CPU path: keep fp32 so Linear/attention kernels are available.
        dummy_input = torch.randn(batch_size, seq_length, embed_dim, device=device)

    print("\nModel and input tensor created. Starting profiler...")
    print("-" * 50)

    # Only include the CUDA activity when CUDA is actually available;
    # requesting it on a CPU-only build is at best noisy, at worst an error.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if use_cuda:
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    # --- Run the Profiler ---
    # The profiler context manager will trace the execution and performance.
    with torch.profiler.profile(
        activities=activities,
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/transformer'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as prof:
        # The profiler will step through a schedule of wait, warmup, and active phases.
        # We need to call `prof.step()` at the end of each iteration.
        for i in range(10):
            # We only care about the forward pass for this example.
            # OPTIM keep the norms in fp32. Let the model autocast (CUDA only).
            autocast_ctx = (
                torch.autocast(device_type="cuda", dtype=torch.float16)
                if use_cuda
                else nullcontext()
            )
            with torch.no_grad(), autocast_ctx:
                model(dummy_input)
            prof.step()  # Notify the profiler that a step is complete

    # --- Print Profiler Results ---
    print("Profiler run complete. Printing summary...")
    print("-" * 50)
    # Print a summary of the results to the console, grouped by our custom labels.
    # The `group_by_input_shape` is useful for seeing how different tensor sizes perform.
    # The `group_by_stack_n` helps attribute time to specific lines of code.
    print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=15))
    print("\n" + "-" * 50)
    print("To view the detailed trace, run the following command in your terminal:")
    print("tensorboard --logdir=./log")
    print("-" * 50)
# Script entry point: run the profiling demo only when executed directly.
if __name__ == "__main__":
    main()