-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_deepSC_with_MI.py
More file actions
executable file
·68 lines (55 loc) · 2.77 KB
/
train_deepSC_with_MI.py
File metadata and controls
executable file
·68 lines (55 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
it's used to train a model guided by deepSC and mutual information system
attention that it won't modify mutual info model, only deepSC's improvement will be stored
"""
import torch
from torch.utils.data import DataLoader
import modelModifiedForMI
from tqdm import tqdm
from data_process import CorpusData
import torch.nn.functional as F
# --- Hyperparameters ---
batch_size = 256
num_epoch = 2
lamda = 0.05  # weight controlling how much the mutual-information term affects the deepSC loss
save_path = './trainedModel/deepSC_with_MI.pth'
deepSC_path = 'trainedModel/deepSC_without_MI.pth'
muInfo_path = 'trainedModel/MutualInfoSystem.pth'

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
print('Using ' + str(device).upper())

dataloader = DataLoader(CorpusData(), batch_size=batch_size, shuffle=True)

# DeepSC model: the only model that is trained and saved by this script.
scNet = modelModifiedForMI.SemanticCommunicationSystem()
scNet.load_state_dict(torch.load(deepSC_path, map_location=device))
scNet.to(device)

# Mutual-information estimator: used only to guide deepSC training. Its
# parameters are excluded from the optimizer below, so freeze them explicitly
# to avoid building unused parameter gradients (gradients still flow through
# the network to scNet's outputs).
muInfoNet = modelModifiedForMI.MutualInfoSystem()
muInfoNet.load_state_dict(torch.load(muInfo_path, map_location=device))
muInfoNet.to(device)
for param in muInfoNet.parameters():
    param.requires_grad_(False)

# Only scNet's parameters are optimized, per the module docstring.
optim = torch.optim.Adam(scNet.parameters(), lr=0.0005)
lossFn = modelModifiedForMI.LossFn()
for epoch in range(num_epoch):
    train_bar = tqdm(dataloader)
    for i, data in enumerate(train_bar):
        [inputs, length_sen] = data  # length_sen: sentence lengths without padding
        num_sample = inputs.size()[0]  # number of sentences in this batch
        # Keep only the token-id channel and cast to long so the tensor is a
        # valid index input for one_hot. (The original chained
        # .requires_grad_(True) before .long(); integer tensors cannot carry
        # gradients, so that call was a discarded no-op and is removed.)
        inputs = inputs[:, 0, :].clone().detach().long()
        inputs = inputs.to(device)
        label = F.one_hot(inputs, num_classes=35632).float()  # 35632 = vocabulary size
        label = label.to(device)

        [s_predicted, codeSent, codeWithNoise] = scNet(inputs)
        # Flatten the clean/noisy channel codes to (N, 16) samples for the
        # MINE-style mutual-information estimate.
        x = torch.reshape(codeSent, (-1, 16))
        y = torch.reshape(codeWithNoise, (-1, 16))
        batch_joint = modelModifiedForMI.sample_batch(5, 'joint', x, y).to(device)
        batch_marginal = modelModifiedForMI.sample_batch(5, 'marginal', x, y).to(device)
        t = muInfoNet(batch_joint)
        et = torch.exp(muInfoNet(batch_marginal))
        # Donsker-Varadhan lower bound on mutual information.
        MI_loss = torch.mean(t) - torch.log(torch.mean(et))

        SC_loss = lossFn(s_predicted, label, length_sen, num_sample, batch_size)
        # exp(-MI) shrinks as estimated mutual information grows, so this
        # penalty pushes the encoder toward higher mutual information.
        loss = SC_loss + torch.exp(-MI_loss) * lamda

        optim.zero_grad()  # clear gradients before backprop, not after the step
        loss.backward()
        optim.step()
        print("Total Loss: {}, Mutual Loss: {}, SC Loss: {}".format(
            loss.detach().cpu().numpy(),
            (-MI_loss).detach().cpu().numpy(),
            SC_loss.detach().cpu().numpy()))

# Persist only the deepSC weights; the mutual-information model is unchanged.
torch.save(scNet.state_dict(), save_path)
print("All done!")