-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmulti_classify_model.py
More file actions
24 lines (21 loc) · 1.03 KB
/
multi_classify_model.py
File metadata and controls
24 lines (21 loc) · 1.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from pytorch_pretrained_bert import BertModel
from pytorch_pretrained_bert.modeling import BertPreTrainedModel
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT encoder with an independent-sigmoid head for multi-label classification.

    The pooled output of ``BertModel`` goes through dropout, a linear layer with
    ``num_labels`` outputs, and an element-wise sigmoid. Every label therefore
    gets its own score in ``[0, 1]``, independently of the other labels.

    Args:
        config: BERT configuration used to build the encoder; ``hidden_size``
            and ``hidden_dropout_prob`` are read here.
        num_labels: Number of output labels.
    """

    def __init__(self, config, num_labels):
        super().__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Kept as a Sequential (Linear -> Sigmoid) so parameter names
        # (``classifier.0.weight`` / ``classifier.0.bias``) stay compatible
        # with checkpoints saved from the original definition.
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, num_labels),
            nn.Sigmoid(),
        )
        self.apply(self.init_bert_weights)

    @staticmethod
    def _pad_if_ragged(sequences):
        """Right-pad a list/tuple of tensors with 0 into one batch-first tensor.

        Anything else (an already-batched tensor, or ``None``) is returned
        unchanged.
        """
        if isinstance(sequences, (list, tuple)):
            return pad_sequence(sequences, batch_first=True)
        return sequences

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Score every label for each example in the batch.

        Args:
            input_ids: Token ids, either as an already-padded batch tensor or as
                a list/tuple of per-example tensors of possibly different
                lengths. Lists are right-padded with 0 (the [PAD] id in the
                standard BERT vocabularies — confirm for custom vocabularies).
            token_type_ids: Optional segment ids, in the same form as
                ``input_ids``. Lists are padded the same way.
            attention_mask: Optional mask (1 = real token, 0 = padding), in the
                same form as ``input_ids``. If omitted while ``input_ids`` is a
                list/tuple, a mask covering exactly each sequence's real tokens
                is built here, so padding is not attended to.

        Returns:
            Tensor of shape ``(batch_size, num_labels)``. NOTE: despite the
            local name ``logits``, these values have already been passed
            through a sigmoid. They are probabilities, so pair them with
            ``nn.BCELoss``, not ``nn.BCEWithLogitsLoss``, which would apply
            the sigmoid a second time.
        """
        if isinstance(input_ids, (list, tuple)) and attention_mask is None:
            # Without an explicit mask BertModel treats every position
            # (including the padding added below) as a real token. Build a
            # per-sequence all-ones mask that pads to 0 alongside the ids.
            attention_mask = [ids.new_ones(ids.shape) for ids in input_ids]

        input_ids = self._pad_if_ragged(input_ids)
        token_type_ids = self._pad_if_ragged(token_type_ids)
        attention_mask = self._pad_if_ragged(attention_mask)

        # output_all_encoded_layers=False -> (last_encoded_layer, pooled_output).
        _, pooled_output = self.bert(
            input_ids,
            token_type_ids,
            attention_mask,
            output_all_encoded_layers=False,
        )
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits