-
Notifications
You must be signed in to change notification settings - Fork 0
Remove restriction #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: self-supervised-nas
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,8 @@ | |
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
| import utils | ||
| from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence | ||
|
|
||
| SOS_ID = 0 | ||
| EOS_ID = 0 | ||
|
|
@@ -73,25 +75,40 @@ def __init__(self, | |
| for i in range(self.n): | ||
| self.offsets.append( (i + 3) * i // 2 - 1) | ||
|
|
||
| def forward(self, x, encoder_hidden=None, encoder_outputs=None): | ||
| def forward(self, x, x_len, encoder_hidden=None, encoder_outputs=None): | ||
| # x is decoder_inputs = [0] + encoder_inputs[:-1] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this |
||
|
|
||
| decoder_hidden = self._init_state(encoder_hidden) | ||
| if x is not None: | ||
| bsz = x.size(0) | ||
| tgt_len = x.size(1) | ||
| x = self.embedding(x) | ||
| x = F.dropout(x, self.dropout, training=self.training) | ||
| residual = x | ||
|
|
||
| x = pack_padded_sequence(x, x_len, batch_first=True) | ||
| x, hidden = self.rnn(x, decoder_hidden) | ||
| x = pad_packed_sequence(x, batch_first=True)[0] | ||
|
|
||
| x = (residual + x) * math.sqrt(0.5) | ||
| residual = x | ||
| x, _ = self.attention(x, encoder_outputs) | ||
|
|
||
| # create mask | ||
| mask = torch.zeros(bsz, x.size(1)) | ||
| for i,l in enumerate(x_len): | ||
| for j in range(l): | ||
| mask[i][j] = 1 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about ? |
||
| mask = (mask == 0).unsqueeze(1) | ||
| mask = utils.move_to_cuda(mask) | ||
|
|
||
| x, _ = self.attention(x, encoder_outputs, mask=mask) | ||
| x = (residual + x) * math.sqrt(0.5) | ||
| predicted_softmax = F.log_softmax(self.out(x.view(-1, self.hidden_size)), dim=-1) | ||
| predicted_softmax = predicted_softmax.view(bsz, tgt_len, -1) | ||
| return predicted_softmax, None | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does predicted_softmax return sane values? If padded elements are zero-initialized, then the probability values will be broken. |
||
|
|
||
|
|
||
| # inference | ||
| # inference : not using xlen. pad packed. | ||
| assert x is None | ||
| bsz = encoder_hidden[0].size(1) | ||
| length = self.length | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| import torch | ||
| import torch.utils.data | ||
| import torch.nn.functional as F | ||
| from torch.autograd import Variable | ||
| from nasbench import api | ||
|
|
||
| INPUT = 'input' | ||
|
|
@@ -48,19 +49,19 @@ def generate_arch(n, nasbench, need_perf=False): | |
| np.random.shuffle(all_keys) | ||
| for key in all_keys: | ||
| fixed_stat, computed_stat = nasbench.get_metrics_from_hash(key) | ||
| if len(fixed_stat['module_operations']) < 7: | ||
| continue | ||
| #if len(fixed_stat['module_operations']) < 7: | ||
| # continue | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove plz |
||
| arch = api.ModelSpec( | ||
| matrix=fixed_stat['module_adjacency'], | ||
| ops=fixed_stat['module_operations'], | ||
| ) | ||
| if need_perf: | ||
| data = nasbench.query(arch) | ||
| if data['validation_accuracy'] < 0.9: | ||
| val_acc = nasbench.query(arch, option='valid') | ||
| if val_acc < 0.9: | ||
| continue | ||
| valid_accs.append(data['validation_accuracy']) | ||
| valid_accs.append(val_acc) | ||
| archs.append(arch) | ||
| seqs.append(convert_arch_to_seq(arch.matrix, arch.ops)) | ||
| seqs.append(convert_arch_to_seq(arch.matrix, arch.ops, nasbench.search_space)) | ||
| count += 1 | ||
| if count >= n: | ||
| return archs, seqs, valid_accs | ||
|
|
@@ -75,74 +76,76 @@ def __init__(self, inputs, targets=None, train=True, sos_id=0, eos_id=0): | |
| super(ControllerDataset, self).__init__() | ||
| if targets is not None: | ||
| assert len(inputs) == len(targets) | ||
| self.inputs = inputs | ||
| self.inputs = inputs # list of seqs | ||
| self.len_inputs = [len(i) for i in inputs] | ||
| self.max_len = max(self.len_inputs) | ||
| self.targets = targets | ||
| self.train = train | ||
| self.sos_id = sos_id | ||
| self.eos_id = eos_id | ||
|
|
||
| def __getitem__(self, index): | ||
| encoder_input = self.inputs[index] | ||
| encoder_input = self.inputs[index] + [0 for _ in range(self.max_len - len(self.inputs[index]))] # fix length as max_len | ||
| len_input = self.len_inputs[index] | ||
| encoder_target = None | ||
| if self.targets is not None: | ||
| encoder_target = [self.targets[index]] | ||
| if self.train: | ||
| decoder_input = [self.sos_id] + encoder_input[:-1] | ||
| sample = { | ||
| 'encoder_input': torch.LongTensor(encoder_input), | ||
| 'encoder_target': torch.FloatTensor(encoder_target), | ||
| 'decoder_input': torch.LongTensor(decoder_input), | ||
| 'decoder_target': torch.LongTensor(encoder_input), | ||
| 'encoder_input': np.array(encoder_input, dtype=np.int64), | ||
| 'encoder_target': np.array(encoder_target, dtype=np.float64), | ||
| 'decoder_input': np.array(decoder_input, dtype=np.int64), | ||
| 'decoder_target': np.array(encoder_input, dtype=np.int64), | ||
| 'input_len': len_input, | ||
| } | ||
| else: | ||
| sample = { | ||
| 'encoder_input': torch.LongTensor(encoder_input), | ||
| 'decoder_target': torch.LongTensor(encoder_input), | ||
| 'encoder_input': np.array(encoder_input, dtype=np.int64), | ||
| 'decoder_target': np.array(encoder_input, dtype=np.int64), | ||
| 'input_len': len_input, | ||
| } | ||
| if encoder_target is not None: | ||
| sample['encoder_target'] = torch.FloatTensor(encoder_target) | ||
| sample['encoder_target'] = np.array(encoder_target, dtype=np.float64) | ||
| return sample | ||
|
|
||
| def __len__(self): | ||
| return len(self.inputs) | ||
|
|
||
|
|
||
| def convert_arch_to_seq(matrix, ops): | ||
| def convert_arch_to_seq(matrix, ops, search_space): | ||
| seq = [] | ||
| n = len(matrix) | ||
| assert n == len(ops) | ||
|
|
||
| for col in range(1, n): | ||
| for row in range(col): | ||
| seq.append(matrix[row][col]+1) | ||
| if ops[col] == CONV1X1: | ||
| seq.append(3) | ||
| elif ops[col] == CONV3X3: | ||
| seq.append(4) | ||
| elif ops[col] == MAXPOOL3X3: | ||
| seq.append(5) | ||
| if ops[col] == OUTPUT: | ||
| seq.append(6) | ||
| if ops[col] == 'output': | ||
| seq.append(len(search_space) + 3) | ||
| elif ops[col] != 'input': | ||
| seq.append(search_space.index(ops[col]) + 3) | ||
|
|
||
| assert len(seq) == (n+2)*(n-1)/2 | ||
| return seq | ||
|
|
||
|
|
||
| def convert_seq_to_arch(seq): | ||
| def convert_seq_to_arch(seq, search_space): | ||
| n = int(math.floor(math.sqrt((len(seq) + 1) * 2))) | ||
| matrix = [[0 for _ in range(n)] for _ in range(n)] | ||
| ops = [INPUT] | ||
| ops = ['input'] | ||
|
|
||
| for i in range(n-1): | ||
| offset=(i+3)*i//2 | ||
| for j in range(i+1): | ||
| matrix[j][i+1] = seq[offset+j] - 1 | ||
| if seq[offset+i+1] == 3: | ||
| op = CONV1X1 | ||
| elif seq[offset+i+1] == 4: | ||
| op = CONV3X3 | ||
| elif seq[offset+i+1] == 5: | ||
| op = MAXPOOL3X3 | ||
| elif seq[offset+i+1] == 6: | ||
| op = OUTPUT | ||
| idx = seq[offset+i+1] - 3 | ||
| if idx == len(search_space): | ||
| op = 'output' | ||
| else: | ||
| op = search_space[idx] | ||
| ops.append(op) | ||
|
|
||
| return matrix, ops | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x_len feels like somewhat of a misnomer. If I understood it correctly, it is somewhat like this:
x_len = [len(x) for x in xs]
So, maybe x_len_per_elem? x_len_list? Or at least some comments would be helpful!