Skip to content

Be careful when using adaptive gradient methods #17

@stevenyangyj

Description

@stevenyangyj

camp

I tested three methods in a very simple problem, and got the result as above.

Code are printed here:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import adabound

class Net(nn.Module):

def __init__(self, dim):
    
    super(Net, self).__init__()
    self.fc1 = nn.Linear(dim, 2*dim)
    self.relu = nn.ReLU(inplace=True)
    self.fc2 = nn.Linear(2*dim, dim)

def forward(self, x):
    
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    
    return x

DIM = 30
epochs = 1000
xini = (torch.ones(1, DIM) * 100)
opti = (torch.zeros(1, DIM) * 100)

lr = 0.01
net = Net(DIM)
objfun = nn.MSELoss()

loss_adab = []
loss_adam = []
loss_sgd = []
for epoch in range(epochs):

if epoch % 100 == 0:
    lr /= 10

optimizer = adabound.AdaBound(net.parameters(), lr) 
out = net(xini)
los = objfun(out, opti)
loss_adab.append(los.detach().numpy())

optimizer.zero_grad()
los.backward()
optimizer.step()

lr = 0.01
net = Net(DIM)
objfun = nn.MSELoss()

for epoch in range(epochs):

if epoch % 100 == 0:
    lr /= 10

optimizer = torch.optim.Adam(net.parameters(), lr) 
out = net(xini)
los = objfun(out, opti)
loss_adam.append(los.detach().numpy())

optimizer.zero_grad()
los.backward()
optimizer.step()   

lr = 0.001
net = Net(DIM)
objfun = nn.MSELoss()

for epoch in range(epochs):

if epoch % 100 == 0:
    lr /= 10

optimizer = torch.optim.SGD(net.parameters(), lr, momentum=0.9) 
out = net(xini)
los = objfun(out, opti)
loss_sgd.append(los.detach().numpy())

optimizer.zero_grad()
los.backward()
optimizer.step()

plt.figure()
plt.plot(loss_adab, label='adabound')
plt.plot(loss_adam, label='adam')
plt.plot(loss_sgd, label='SGD')
plt.yscale('log')
plt.xlabel('epochs')
plt.ylabel('Log(loss)')
plt.legend()
plt.savefig('camp.png', dpi=600)
plt.show()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions