
Commit 7ecb48c

conv -> patch
1 parent 3b50598 commit 7ecb48c

8 files changed

Lines changed: 62 additions & 82 deletions


README.md

Lines changed: 13 additions & 33 deletions
@@ -91,56 +91,36 @@ flowchart TB
 emb --> L1 --> L2 --> LN --> head
 ```
 
-### 2. Conv1D UNet Transformer
-A U-Net style architecture that uses Conv1D for downsampling and ConvTranspose1D for upsampling. It progressively reduces sequence length while increasing hidden dimension, allowing the model to process information at hierarchically different resolutions.
+### 2. Transformer UNet
+A U-Net architecture that uses skip connections between encoder and decoder layers, but maintains the same sequence length and hidden dimension throughout (no downsampling). This allows the model to mix features from early and late layers.
 
 ```mermaid
 flowchart TB
 subgraph Input
-emb[Embedding Layer<br/>seq_len x hidden_size]
+emb[Embedding Layer]
 end
-
+
 subgraph Encoder[Encoder Path]
-e0[TransformerBlock 0<br/>FULL RESOLUTION<br/>1024 x 768]
-down1[Conv1D Downsample]
-e1[TransformerBlock 1<br/>512 x 832]
-down2[Conv1D Downsample]
-e2[TransformerBlock 2<br/>256 x 896]
-down3[...]
-eN[MLP Block if seq=1]
+e1[TransformerBlock 1]
+e2[TransformerBlock 2]
 end
-
+
 subgraph Decoder[Decoder Path]
-dN[MLP Block if seq=1]
-up1[ConvTranspose1D Upsample]
-d2[TransformerBlock + Skip<br/>256 x 896]
-up2[ConvTranspose1D Upsample]
-d1[TransformerBlock + Skip<br/>512 x 832]
-up3[ConvTranspose1D Upsample]
-d0[TransformerBlock N<br/>FULL RESOLUTION<br/>1024 x 768]
+d2[TransformerBlock 3 + Skip]
+d1[TransformerBlock 4 + Skip]
 end
-
-subgraph ExtraLayers[Extra Sequential Layers]
-ex1[TransformerBlock<br/>Full Resolution]
-ex2[TransformerBlock<br/>Full Resolution]
-end
-
+
 subgraph Output
 head[LM Head]
 end
-
-emb --> e0
-e0 --> down1 --> e1 --> down2 --> e2 --> down3 --> eN
-eN --> dN --> up1 --> d2 --> up2 --> d1 --> up3 --> d0
-d0 --> ex1 --> ex2 --> head
-
-e0 -.->|skip| d0
+
+emb --> e1 --> e2 --> d2 --> d1 --> head
 e1 -.->|skip| d1
 e2 -.->|skip| d2
 ```
 
 ### 3. Patch UNet Transformer
-An optimized U-Net architecture designed for speed. It uses "Patch Merging" (concatenating adjacent tokens) for downsampling instead of convolutions, which is faster and cleaner. It operates on batched inputs `(B, L)` and efficiently handles document boundaries and padding without complex dynamic shape logic.
+An optimized U-Net architecture designed for speed. It uses "Patch Merging" (concatenating adjacent tokens) for downsampling, which is faster and cleaner than convolutions. It operates on batched inputs `(B, L)` and efficiently handles document boundaries and padding without complex dynamic shape logic.
 
 ```mermaid
 flowchart TB
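
To make the added "Transformer UNet" description concrete, here is a minimal, self-contained sketch of the pattern it describes: encoder and decoder blocks at a constant sequence length and hidden size, with each decoder block consuming a skip connection from its mirrored encoder block. The `TransformerBlock`, the concat-plus-projection skip, and all parameter values are illustrative assumptions, not the repository's actual implementation.

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Stand-in block: any module mapping (B, L, D) -> (B, L, D)."""
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.mlp(self.norm2(x))

class SkipUnetTransformer(nn.Module):
    """U-Net-style transformer: same sequence length and hidden size throughout,
    with a skip connection from encoder block i to the mirrored decoder block."""
    def __init__(self, hidden_size=768, num_heads=6, num_layers=4):
        super().__init__()
        assert num_layers % 2 == 0
        half = num_layers // 2
        self.encoder = nn.ModuleList([TransformerBlock(hidden_size, num_heads) for _ in range(half)])
        self.decoder = nn.ModuleList([TransformerBlock(hidden_size, num_heads) for _ in range(half)])
        # One common skip mechanism: concatenate on the feature dim and project back down.
        self.skip_proj = nn.ModuleList([nn.Linear(2 * hidden_size, hidden_size) for _ in range(half)])

    def forward(self, x):
        skips = []
        for block in self.encoder:
            x = block(x)
            skips.append(x)
        for block, proj in zip(self.decoder, self.skip_proj):
            skip = skips.pop()  # mirror order: last encoder output feeds the first decoder block
            x = block(proj(torch.cat([x, skip], dim=-1)))
        return x

x = torch.randn(2, 1024, 768)
print(SkipUnetTransformer()(x).shape)  # torch.Size([2, 1024, 768]) -- length and width unchanged
```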

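Likewise, a rough sketch of the "Patch Merging" downsample mentioned in the Patch UNet description: adjacent tokens are concatenated along the feature dimension and projected, halving the sequence length, and a patch expand reverses it. Class names and dimensions here are hypothetical; the repository's `BatchedUnetTransformer` may differ in detail.

```python
import torch
import torch.nn as nn

class PatchMerge(nn.Module):
    """Halve sequence length by concatenating each pair of adjacent tokens."""
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(2 * dim_in, dim_out)

    def forward(self, x):                      # x: (B, L, D_in), L assumed even
        B, L, D = x.shape
        x = x.reshape(B, L // 2, 2 * D)        # pair up neighbouring tokens
        return self.proj(x)                    # (B, L/2, D_out)

class PatchExpand(nn.Module):
    """Double sequence length by splitting each token back into two."""
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, 2 * dim_out)

    def forward(self, x):                      # x: (B, L, D_in)
        B, L, D = x.shape
        x = self.proj(x)                       # (B, L, 2 * D_out)
        return x.reshape(B, 2 * L, -1)         # (B, 2L, D_out)

x = torch.randn(4, 128, 384)
down = PatchMerge(384, 512)(x)                 # (4, 64, 512)
up = PatchExpand(512, 384)(down)               # (4, 128, 384)
print(down.shape, up.shape)
```
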
example_yamls/debug.yaml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ expansion_ratio: 2.0
 soft_logit_cap: 16.0
 tie_embeddings: false
 unet: true
-conv_unet: false
+patch_unet: false
 token_dropout: true
 bfloat16: true
 compile_model: false

example_yamls/default.yaml

Lines changed: 2 additions & 2 deletions
@@ -21,14 +21,14 @@ auto_grad_clip_p: 10.0
 hidden_size: 768
 num_attention_heads: 6
 num_hidden_layers: 24
-num_unet_layers: 0 # Number of layers for Conv1D UNet (set to > 0 to use)
+num_unet_layers: 0 # Number of layers for Patch UNet (set to > 0 to use)
 num_extra_layers: 0 # Number of extra transformer layers after UNet
 vocab_size: 33
 expansion_ratio: 2.0
 soft_logit_cap: 32.0
 tie_embeddings: false
 unet: true
-conv_unet: false # Use Conv1D UNet with downsampling
+patch_unet: false # Use Patch UNet with downsampling
 token_dropout: true
 bfloat16: false
 compile_model: true
Lines changed: 4 additions & 4 deletions
@@ -6,7 +6,7 @@
 
 # General Configuration
 bugfix: false
-save_path: "Synthyra/speedrun_conv_unet"
+save_path: "Synthyra/speedrun_patch_unet"
 data_name: "uniref50"
 num_chunks: 197
 log_name: null
@@ -24,15 +24,15 @@ auto_grad_clip_p: 10.0
 # With only 6 heads: hidden dimension is capped to keep head_dim <= 128.
 hidden_size: 768
 num_attention_heads: 12
-num_hidden_layers: 0 # Not used for conv_unet
+num_hidden_layers: 0 # Not used for patch_unet
 num_unet_layers: 12 # 6 encoder + 6 decoder
 num_extra_layers: 4 # Extra full-resolution transformer layers after UNet
 vocab_size: 33
 expansion_ratio: 2.0
 soft_logit_cap: 32.0
 tie_embeddings: false
 unet: false # Standard UNet off
-conv_unet: true # Batched Conv UNet on
+patch_unet: true # Batched Patch UNet on
 token_dropout: false
 bfloat16: true
 compile_model: true
@@ -69,7 +69,7 @@ muon_momentum_warmup_steps: 300
 
 # Evaluation & Logging
 eval_every: 1000
-hf_model_name: "lhallee/speedrun_conv_unet"
+hf_model_name: "lhallee/speedrun_patch_unet"
 save_every: null
 
 # Dataloader Parameters
Lines changed: 4 additions & 4 deletions
@@ -4,10 +4,10 @@
 
 # General Configuration
 bugfix: true
-save_path: "Synthyra/debug_conv_unet"
+save_path: "Synthyra/debug_patch_unet"
 data_name: "uniref50"
 num_chunks: 10
-log_name: "debug_conv_unet"
+log_name: "debug_patch_unet"
 
 # Distributed Training & Reproducibility
 seed: 42
@@ -27,7 +27,7 @@ expansion_ratio: 2.0
 soft_logit_cap: 16.0
 tie_embeddings: false
 unet: false
-conv_unet: true
+patch_unet: true
 token_dropout: false
 bfloat16: true
 compile_model: false # Faster startup for debugging
@@ -64,7 +64,7 @@ muon_momentum_warmup_steps: 10
 
 # Evaluation & Logging
 eval_every: 50
-hf_model_name: "Synthyra/debug_conv_unet"
+hf_model_name: "Synthyra/debug_patch_unet"
 save_every: null
 
 # Dataloader Parameters

example_yamls/test.yaml

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@
 
 # General Configuration
 bugfix: false
-save_path: "Synthyra/conv_unet_test"
+save_path: "Synthyra/patch_unet_test"
 data_name: "uniref50"
 num_chunks: 197
 log_name: null # If null, a random UUID will be generated
@@ -21,14 +21,14 @@ auto_grad_clip_p: 10.0
 hidden_size: 768
 num_attention_heads: 6
 num_hidden_layers: 24
-num_unet_layers: 12 # Number of layers for Conv1D UNet (set to > 0 to use)
+num_unet_layers: 12 # Number of layers for Patch UNet (set to > 0 to use)
 num_extra_layers: 4 # Number of extra transformer layers after UNet
 vocab_size: 33
 expansion_ratio: 2.0
 soft_logit_cap: 32.0
 tie_embeddings: false
 unet: true
-conv_unet: true # Use Conv1D UNet with downsampling
+patch_unet: true # Use Patch UNet with downsampling
 token_dropout: false
 bfloat16: true
 compile_model: true
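
For reference, the renamed flag maps onto the Python config in the same way as the test block in the model/model.py diff below. A hedged sketch only: the import path, and the idea that constructing `PLMConfig` directly is equivalent to setting these YAML options, are assumptions.

```python
# Rough Python equivalent of the renamed YAML flags above, mirroring the test
# code shown further down in model/model.py. The import path is an assumption.
from model.model import PLMConfig, PLM

config = PLMConfig(
    hidden_size=384,
    num_attention_heads=6,
    num_unet_layers=8,        # split between encoder and decoder (4 + 4)
    num_extra_layers=2,       # extra full-resolution layers after the UNet
    max_sequence_length=128,  # power of 2 so patch merging divides evenly
    expansion_ratio=8/3,
    patch_unet=True,          # formerly conv_unet
)
model = PLM(config)
```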

model/model.py

Lines changed: 25 additions & 25 deletions
@@ -28,7 +28,7 @@ def __init__(
 sliding_window_size: int = 2048,
 tie_embeddings: bool = False,
 unet: bool = False,
-conv_unet: bool = False,
+patch_unet: bool = False,
 mlm: bool = False,
 masked_diffusion: bool = False,
 token_dropout: bool = True,
@@ -48,7 +48,7 @@ def __init__(
 self.sliding_window_size = sliding_window_size
 self.tie_embeddings = tie_embeddings
 self.unet = unet
-self.conv_unet = conv_unet
+self.patch_unet = patch_unet
 self.mlm = mlm
 self.masked_diffusion = masked_diffusion
 self.token_dropout = token_dropout
@@ -661,11 +661,11 @@ def __init__(self, config: PLMConfig):
 self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
 
 self.unet = config.unet
-self.conv_unet = config.conv_unet
+self.patch_unet = config.patch_unet
 
-if config.conv_unet:
+if config.patch_unet:
 # Batched UNet with Swin-style patch merge/expand
-assert config.num_unet_layers > 0, "num_unet_layers must be > 0 for conv_unet"
+assert config.num_unet_layers > 0, "num_unet_layers must be > 0 for patch_unet"
 self.transformer = BatchedUnetTransformer(config)
 hidden_sizes = self.transformer.hidden_sizes
 self.value_embeds = BatchedValueEmbedding(config.vocab_size, hidden_sizes)
@@ -698,9 +698,9 @@ def __init__(self, config: PLMConfig):
 self.ce = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
 
 def get_last_hidden_state(self, input_ids: torch.Tensor, sliding_window_size: int) -> torch.Tensor:
-if self.conv_unet:
+if self.patch_unet:
 # Batched UNet path: input_ids is (B, L)
-assert input_ids.dim() == 2, f"conv_unet expects (B, L) input, got shape {input_ids.shape}"
+assert input_ids.dim() == 2, f"patch_unet expects (B, L) input, got shape {input_ids.shape}"
 B, L = input_ids.shape
 
 # Pre-compute multi-resolution block masks
@@ -795,18 +795,18 @@ def get_vector_embeddings(self, input_ids: torch.Tensor, sliding_window_size: Op
 """Mean-pool hidden states per document to get per-document embeddings.
 
 Args:
-input_ids: (B, L) for conv_unet or (total_len,) for standard/unet
+input_ids: (B, L) for patch_unet or (total_len,) for standard/unet
 sliding_window_size: Override sliding window size
 
 Returns:
-For conv_unet (B, L): flattened (total_docs, hidden_size) across all batch elements
+For patch_unet (B, L): flattened (total_docs, hidden_size) across all batch elements
 For standard (total_len,): (num_docs, hidden_size)
 """
 if sliding_window_size is None:
 sliding_window_size = self.sliding_window_size
 x = self.get_last_hidden_state(input_ids, sliding_window_size)
 
-if self.conv_unet:
+if self.patch_unet:
 # Batched: x is (B, L, D), input_ids is (B, L)
 B, L, D = x.shape
 doc_ids = (input_ids == self.cls_token_id).cumsum(dim=1) # (B, L)
@@ -871,7 +871,7 @@ def get_logits(self, input_ids: torch.Tensor, sliding_window_size: Optional[int]
 """Get LM logits without computing loss.
 
 Args:
-input_ids: (B, L) for conv_unet or (total_len,) for standard/unet
+input_ids: (B, L) for patch_unet or (total_len,) for standard/unet
 sliding_window_size: Override sliding window size
 
 Returns:
@@ -892,7 +892,7 @@ def get_embeddings(
 """Get per-sequence pooled embeddings.
 
 Args:
-input_ids: (B, L) for conv_unet or (total_len,) for standard/unet
+input_ids: (B, L) for patch_unet or (total_len,) for standard/unet
 sliding_window_size: Override sliding window size
 pooling: 'mean' for mean pooling over non-pad tokens, 'cls' for CLS token embedding
 
@@ -903,7 +903,7 @@ def get_embeddings(
 sliding_window_size = self.sliding_window_size
 hidden = self.get_last_hidden_state(input_ids, sliding_window_size)
 
-if self.conv_unet:
+if self.patch_unet:
 # Batched: hidden is (B, L, D), input_ids is (B, L)
 assert input_ids.dim() == 2
 B, L, D = hidden.shape
@@ -1013,20 +1013,20 @@ def push_weights_to_hub(self, repo_id: str):
 print(f"Original UNet loss: {loss.item():.4f}")
 
 print("\n" + "=" * 80)
-print("Testing Batched UNet Transformer (conv_unet)")
+print("Testing Batched UNet Transformer (patch_unet)")
 print("=" * 80)
 max_length = 128 # Power of 2 for patch merging
-conv_config = PLMConfig(
+patch_config = PLMConfig(
 hidden_size=384,
 num_attention_heads=6,
 num_unet_layers=8, # 4 encoder + 4 decoder
 num_extra_layers=2,
 max_sequence_length=max_length,
 expansion_ratio=8/3,
-conv_unet=True,
+patch_unet=True,
 )
-conv_model = PLM(conv_config).cuda()
-print(f"Model parameters: {sum(p.numel() for p in conv_model.parameters()):,}")
+patch_model = PLM(patch_config).cuda()
+print(f"Model parameters: {sum(p.numel() for p in patch_model.parameters()):,}")
 
 # Create batched test input (B, max_length) with packed documents per element
 B = 4
@@ -1042,13 +1042,13 @@ def push_weights_to_hub(self, repo_id: str):
 batched_labels = batched_ids.clone()
 batched_labels[batched_labels != 32] = -100
 
-loss = conv_model(batched_ids, batched_labels, mask_rate)
+loss = patch_model(batched_ids, batched_labels, mask_rate)
 print(f"Batched UNet loss: {loss.item():.4f}")
 
-print(f"\nHidden sizes: {conv_model.transformer.hidden_sizes}")
-print(f"Vector depth (log2(max_length)): {conv_model.transformer.vector_depth}")
-print(f"Num encoder layers: {conv_model.transformer.num_encoder_layers}")
-print(f"Num decoder layers: {conv_model.transformer.num_decoder_layers}")
+print(f"\nHidden sizes: {patch_model.transformer.hidden_sizes}")
+print(f"Vector depth (log2(max_length)): {patch_model.transformer.vector_depth}")
+print(f"Num encoder layers: {patch_model.transformer.num_encoder_layers}")
+print(f"Num decoder layers: {patch_model.transformer.num_decoder_layers}")
 
 print("\n" + "=" * 80)
 print("Testing Batched UNet with deep layers (MLP at vector depth)")
@@ -1060,7 +1060,7 @@ def push_weights_to_hub(self, repo_id: str):
 num_extra_layers=1,
 max_sequence_length=128, # log2(128)=7, so layers 7+ become MLPs
 expansion_ratio=8/3,
-conv_unet=True,
+patch_unet=True,
 )
 deep_model = PLM(deep_config).cuda()
 
@@ -1086,7 +1086,7 @@ def push_weights_to_hub(self, repo_id: str):
 input_ids=batched_ids,
 cls_token_id=0,
 pad_token_id=1,
-num_levels=conv_model.transformer.num_resolution_levels,
+num_levels=patch_model.transformer.num_resolution_levels,
 sliding_window_size=128,
 n_heads=6,
 device=batched_ids.device,
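
The `doc_ids = (input_ids == self.cls_token_id).cumsum(dim=1)` line in the `get_vector_embeddings` hunk is the key to handling packed documents in the batched path: every CLS token starts a new document, so a cumulative sum labels each position with its document index. A small illustrative re-implementation of that pooling idea follows (not the repository's exact code, which presumably vectorizes this):

```python
import torch

def mean_pool_per_document(hidden, input_ids, cls_token_id=0, pad_token_id=1):
    # hidden: (B, L, D), input_ids: (B, L). Every CLS token starts a new document,
    # so a cumulative sum over the CLS mask gives each position a document index.
    B, L, D = hidden.shape
    doc_ids = (input_ids == cls_token_id).cumsum(dim=1)   # (B, L), 1-based per batch row
    valid = (input_ids != pad_token_id) & (doc_ids > 0)   # drop padding and pre-CLS positions
    pooled = []
    for b in range(B):
        for d in range(1, int(doc_ids[b].max().item()) + 1):
            mask = valid[b] & (doc_ids[b] == d)
            if mask.any():
                pooled.append(hidden[b][mask].mean(dim=0))
    return torch.stack(pooled)                            # (total_docs, D)

hidden = torch.randn(2, 16, 8)
input_ids = torch.randint(2, 30, (2, 16))
input_ids[:, 0] = 0      # CLS starts the first document in each row
input_ids[0, 8] = 0      # a second document packed into row 0
print(mean_pool_per_document(hidden, input_ids).shape)    # torch.Size([3, 8])
```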
