Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
.vscode/launch.json
__pycache__/
gumble_sampling/data
17 changes: 17 additions & 0 deletions gumble_sampling/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
## Normal MNIST

```
python mnist_normal.py
```

Running the above will generate a plot like the following:

![baseline](mnist_loss_plot.png)

## Gumbel MNIST

```
python mnist_gumble_base.py
```

![gumbel](mnist_loss_plot_gumble.png)
15 changes: 15 additions & 0 deletions gumble_sampling/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch
import torch.nn.functional as F


def gumbel_softmax(logits, temperature=1.0, hard=False):
    """Sample from the Gumbel-Softmax (Concrete) distribution over the last dim.

    Args:
        logits: unnormalized log-probabilities, shape (..., num_classes).
        temperature: softmax temperature; lower values push the sample
            closer to one-hot.
        hard: if True, return an exact one-hot sample in the forward pass
            while keeping the soft sample's gradients (straight-through).

    Returns:
        Tensor with the same shape as ``logits``, summing to 1 along the
        last dimension (exactly one-hot when ``hard=True``).
    """
    # Gumbel(0, 1) noise via the inverse-CDF of uniform samples; the
    # 1e-10 terms guard against log(0).
    noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)
    y = F.softmax((logits + noise) / temperature, dim=-1)

    if hard:
        # One-hot of the argmax along the SAME axis as the softmax.
        # (The original hardcoded dim=1 here, which silently produced a
        # wrong one-hot for logits with more than 2 dimensions.)
        y_hard = torch.zeros_like(y)
        y_hard.scatter_(-1, y.argmax(dim=-1, keepdim=True), 1.0)
        # Straight-through estimator: forward uses y_hard, backward uses y.
        y = (y_hard - y).detach() + y
    return y
Binary file added gumble_sampling/gumble.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 29 additions & 0 deletions gumble_sampling/gumble_sampling_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from common import gumbel_softmax


class GumbelSamplingLayer(nn.Module):
    """Linear layer whose outputs are stochastically sub-sampled.

    Projects the input to ``full_out_dim`` features, then reduces them to
    ``sample_dim`` outputs through a learned Gumbel-Softmax selector that
    is resampled on every forward pass.
    """

    def __init__(self, in_dim, full_out_dim, sample_dim, temperature=1.0):
        super().__init__()
        self.linear = nn.Linear(in_dim, full_out_dim)
        # Learnable selection logits: one categorical distribution over
        # the full feature set per sampled output unit.
        self.logits = nn.Parameter(
            torch.randn(sample_dim, full_out_dim)
        )  # learnable selector
        self.sample_dim = sample_dim
        self.temperature = temperature

    def forward(self, x):
        """Map (batch, in_dim) -> (batch, sample_dim)."""
        features = self.linear(x)  # (batch, full_out_dim)

        # Hard one-hot selector with straight-through gradients;
        # shape (sample_dim, full_out_dim).
        selector = gumbel_softmax(
            self.logits, temperature=self.temperature, hard=True
        )

        # (batch, full_out_dim) @ (full_out_dim, sample_dim)
        # -> (batch, sample_dim); same result as the original
        # matmul(selector, features.T).T, without the double transpose.
        return features @ selector.t()
Empty file.
Empty file.
108 changes: 108 additions & 0 deletions gumble_sampling/mnist_gumble_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from gumble_sampling_layer import GumbelSamplingLayer
import matplotlib.pyplot as plt

# Device: use CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data: convert images to tensors and normalize with the standard MNIST
# mean/std (0.1307, 0.3081).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Train split is downloaded on first run into ./data; the test split
# assumes the same download already exists.
train_data = datasets.MNIST(
    root="./data", train=True, transform=transform, download=True
)
test_data = datasets.MNIST(root="./data", train=False, transform=transform)

# Shuffled mini-batches of 64 for training; large batches for evaluation.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000)


# Model using GumbelSamplingLayer
class GumbelNet(nn.Module):
    """MLP classifier for MNIST whose hidden layers are Gumbel-sampling layers."""

    def __init__(self):
        super().__init__()
        # 784 -> 512 features sub-sampled to 256, then 256 -> 256 sub-sampled
        # to 128, followed by a plain linear classification head.
        self.fc1 = GumbelSamplingLayer(784, 512, 256)
        self.relu = nn.ReLU()
        self.fc2 = GumbelSamplingLayer(256, 256, 128)
        self.output = nn.Linear(128, 10)

    def forward(self, x):
        """Flatten image batches and return class logits of shape (batch, 10)."""
        flat = x.view(-1, 28 * 28)
        hidden = self.relu(self.fc2(self.relu(self.fc1(flat))))
        return self.output(hidden)


model = GumbelNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Per-epoch history used for the plots at the bottom of the script.
train_losses = []
test_accuracies = []
test_losses = []
# Training loop
epochs = 50
for epoch in range(epochs):
    model.train()
    total_train_loss = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    # Mean of per-batch losses (the last partial batch is weighted equally).
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    # Evaluation
    # NOTE(review): model.eval() does not stop GumbelSamplingLayer from
    # resampling its selector (its forward never checks self.training), so
    # test metrics are stochastic — confirm this is intended.
    model.eval()
    correct = 0
    total = 0
    total_test_loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_test_loss += loss.item()
            # Calculate accuracy
            preds = output.argmax(dim=1)
            correct += (preds == target).sum().item()
            total += target.size(0)
    avg_test_loss = total_test_loss / len(test_loader)
    test_losses.append(avg_test_loss)
    acc = 100.0 * correct / total
    test_accuracies.append(acc)
    print(
        f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f} | Accuracy: {acc}%"
    )

# Two stacked panels: losses on top, test accuracy below.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8))

# First subplot: Losses
ax1.plot(train_losses, label="Train Loss")
ax1.plot(test_losses, label="Test Loss")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title("Train vs Test Loss on MNIST")
ax1.legend()
ax1.grid(True)

# Second subplot: Accuracy
ax2.plot(test_accuracies, label="Test Accuracy", linestyle="--", color="orange")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy (%)")
ax2.set_title("Test Accuracy on MNIST")
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.savefig("mnist_loss_plot_gumble.png")
Binary file added gumble_sampling/mnist_loss_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added gumble_sampling/mnist_loss_plot_gumble.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
117 changes: 117 additions & 0 deletions gumble_sampling/mnist_normal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Device: use CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data transforms: tensor conversion plus the standard MNIST
# normalization constants (mean 0.1307, std 0.3081).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Load MNIST dataset; the train split is downloaded into ./data on first
# run, and the test split assumes that download exists.
train_data = datasets.MNIST(
    root="./data", train=True, transform=transform, download=True
)
test_data = datasets.MNIST(root="./data", train=False, transform=transform)

# Shuffled mini-batches of 64 for training; large batches for evaluation.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000)


# Simple Linear MLP model
# Simple Linear MLP model
class SimpleNet(nn.Module):
    """Baseline 784-256-128-10 MLP used as the non-Gumbel MNIST reference."""

    def __init__(self):
        super().__init__()
        # Two ReLU hidden layers followed by a 10-way classification head.
        # Layer construction order is unchanged, so parameter init matches.
        layers = [
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten image batches and return class logits of shape (batch, 10)."""
        flat = x.view(-1, 28 * 28)
        return self.model(flat)


model = SimpleNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Track losses
# Per-epoch history used for the plots at the bottom of the script.
train_losses = []
test_losses = []
test_accuracies = []

# Training loop
epochs = 25
for epoch in range(epochs):
    model.train()
    total_train_loss = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    # Mean of per-batch losses (the last partial batch is weighted equally).
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Evaluation
    model.eval()
    total_test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_test_loss += loss.item()
            preds = output.argmax(dim=1)
            correct += (preds == target).sum().item()
            total += target.size(0)

    avg_test_loss = total_test_loss / len(test_loader)
    test_losses.append(avg_test_loss)

    acc = 100.0 * correct / total
    test_accuracies.append(acc)
    # NOTE(review): accuracy is printed twice per epoch (here and in the
    # summary line below) — likely redundant.
    print(f"Test Accuracy: {acc:.2f}%")

    print(
        f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f} | Accuracy: {acc}%"
    )


# Two stacked panels: losses on top, test accuracy below.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8))

# Loss subplot
ax1.plot(train_losses, label="Train Loss")
ax1.plot(test_losses, label="Test Loss")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title("Train vs Test Loss")
ax1.legend()
ax1.grid(True)

# Accuracy subplot
ax2.plot(test_accuracies, label="Test Accuracy", color="orange", linestyle="--")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy (%)")
ax2.set_title("Test Accuracy")
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.savefig("mnist_loss_plot.png")
plt.show()
Loading