
I tested three optimizers (AdaBound, Adam and SGD with momentum) on a very simple problem and got the results shown above.
The code is below:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import adabound
class Net(nn.Module):
    """Two-layer MLP that maps ``dim`` input features to ``dim`` outputs.

    Architecture: Linear(dim -> 2*dim) -> ReLU -> Linear(2*dim -> dim).

    Args:
        dim: number of input features, which is also the number of outputs.
    """

    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, 2 * dim)
        # In-place ReLU overwrites fc1's output tensor; that tensor is not
        # reused anywhere else in forward(), so this only saves memory.
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(2 * dim, dim)

    def forward(self, x):
        """Apply the MLP; the last dimension of ``x`` must equal ``dim``."""
        return self.fc2(self.relu(self.fc1(x)))
# Experiment: fit Net so that a constant input (all 100s) maps onto an all-zero
# target, and compare how AdaBound, Adam and SGD+momentum reduce the MSE loss.
DIM = 30
epochs = 1000
xini = torch.ones(1, DIM) * 100
# Regression target. (Multiplying zeros by 100, as the original did, is a no-op.)
opti = torch.zeros(1, DIM)
objfun = nn.MSELoss()

# Learning-rate schedule: divide the rate by 10 every 100 iterations.
LR_STEP = 100
LR_GAMMA = 0.1


def _train(make_optimizer, first_lr, seed=0):
    """Train a freshly initialised ``Net`` and return its per-iteration losses.

    The optimizer is constructed ONCE and the learning rate is decayed by a
    scheduler. Re-creating the optimizer inside the training loop throws away
    its internal state: Adam/AdaBound moment estimates and step counts, and
    SGD momentum buffers. The resulting curves then do not reflect the
    algorithms being compared.

    Args:
        make_optimizer: callable ``(params, lr) -> torch.optim.Optimizer``.
        first_lr: learning rate used during the first ``LR_STEP`` iterations.
        seed: RNG seed applied before building the model, so every optimizer
            starts from identical initial weights.

    Returns:
        list of float: the loss measured before each update, one per iteration.
    """
    torch.manual_seed(seed)
    model = Net(DIM)
    optimizer = make_optimizer(model.parameters(), first_lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=LR_STEP, gamma=LR_GAMMA)
    losses = []
    for _ in range(epochs):
        loss = objfun(model(xini), opti)
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    return losses


# The original divided lr by 10 already at epoch 0. The rates actually used in
# the first 100 iterations were therefore 0.01 / 10 (AdaBound, Adam) and
# 0.001 / 10 (SGD). Start from those values to keep the same effective schedule.
loss_adab = _train(adabound.AdaBound, 0.01 / 10)
loss_adam = _train(torch.optim.Adam, 0.01 / 10)
loss_sgd = _train(
    lambda params, lr: torch.optim.SGD(params, lr, momentum=0.9), 0.001 / 10)

plt.figure()
plt.plot(loss_adab, label='adabound')
plt.plot(loss_adam, label='adam')
plt.plot(loss_sgd, label='SGD')
plt.yscale('log')
plt.xlabel('epochs')
# The y-axis shows loss values on a logarithmic scale, not values of log(loss).
plt.ylabel('loss (log scale)')
plt.legend()
plt.savefig('camp.png', dpi=600)
plt.show()
I tested three optimizers (AdaBound, Adam and SGD with momentum) on a very simple problem and got the results shown above.
The code is below:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import adabound
class Net(nn.Module):
    """Two-layer MLP that maps ``dim`` input features to ``dim`` outputs.

    Architecture: Linear(dim -> 2*dim) -> ReLU -> Linear(2*dim -> dim).

    Args:
        dim: number of input features, which is also the number of outputs.
    """

    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, 2 * dim)
        # In-place ReLU overwrites fc1's output tensor; that tensor is not
        # reused anywhere else in forward(), so this only saves memory.
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(2 * dim, dim)

    def forward(self, x):
        """Apply the MLP; the last dimension of ``x`` must equal ``dim``."""
        return self.fc2(self.relu(self.fc1(x)))


# Experiment: fit Net so that a constant input (all 100s) maps onto an all-zero
# target, and compare how AdaBound, Adam and SGD+momentum reduce the MSE loss.
# (This copy of the listing had its class and loop bodies stripped in the
# paste; they are restored here to match the script's evident intent.)
DIM = 30
epochs = 1000
xini = torch.ones(1, DIM) * 100
# Regression target. (Multiplying zeros by 100, as the original did, is a no-op.)
opti = torch.zeros(1, DIM)
objfun = nn.MSELoss()

# Learning-rate schedule: divide the rate by 10 every 100 iterations.
LR_STEP = 100
LR_GAMMA = 0.1


def _train(make_optimizer, first_lr, seed=0):
    """Train a freshly initialised ``Net`` and return its per-iteration losses.

    The optimizer is constructed ONCE and the learning rate is decayed by a
    scheduler. Re-creating the optimizer inside the training loop throws away
    its internal state: Adam/AdaBound moment estimates and step counts, and
    SGD momentum buffers.

    Args:
        make_optimizer: callable ``(params, lr) -> torch.optim.Optimizer``.
        first_lr: learning rate used during the first ``LR_STEP`` iterations.
        seed: RNG seed applied before building the model, so every optimizer
            starts from identical initial weights.

    Returns:
        list of float: the loss measured before each update, one per iteration.
    """
    torch.manual_seed(seed)
    model = Net(DIM)
    optimizer = make_optimizer(model.parameters(), first_lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=LR_STEP, gamma=LR_GAMMA)
    losses = []
    for _ in range(epochs):
        loss = objfun(model(xini), opti)
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    return losses


# The rates actually used in the first 100 iterations of the original loops
# were 0.01 / 10 (AdaBound, Adam) and 0.001 / 10 (SGD), because lr was
# divided by 10 already at epoch 0. Start from those values.
loss_adab = _train(adabound.AdaBound, 0.01 / 10)
loss_adam = _train(torch.optim.Adam, 0.01 / 10)
loss_sgd = _train(
    lambda params, lr: torch.optim.SGD(params, lr, momentum=0.9), 0.001 / 10)

plt.figure()
plt.plot(loss_adab, label='adabound')
plt.plot(loss_adam, label='adam')
plt.plot(loss_sgd, label='SGD')
plt.yscale('log')
plt.xlabel('epochs')
# The y-axis shows loss values on a logarithmic scale, not values of log(loss).
plt.ylabel('loss (log scale)')
plt.legend()
plt.savefig('camp.png', dpi=600)
plt.show()