models.py
import pytorch_lightning as pl
from torch.optim import AdamW
from transformers import AutoModelForSeq2SeqLM
class Model(pl.LightningModule):
    """Lightning wrapper around a Hugging Face seq2seq model."""

    def __init__(self, model_name, lr):
        super().__init__()
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.lr = lr

    def forward(self, batch):
        # Pass the whole batch (input_ids, attention_mask, labels, ...)
        # through to the underlying Hugging Face model.
        return self.model(**batch)

    def generate(self, input_ids, attention_mask, max_length=128):
        # Autoregressive decoding for inference.
        return self.model.generate(
            input_ids, attention_mask=attention_mask, max_length=max_length
        )

    def _step(self, batch, idx):
        # When the batch contains a `labels` key, the Hugging Face model
        # computes the cross-entropy loss itself.
        output = self(batch)
        return output.loss

    def training_step(self, batch, idx):
        loss = self._step(batch, idx)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, idx):
        loss = self._step(batch, idx)
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def test_step(self, batch, idx):
        loss = self._step(batch, idx)
        self.log('test_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=self.lr)
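

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example of loading the model and running `generate` for
# inference. The checkpoint name "t5-small", the learning rate, and the
# input sentence below are placeholder assumptions, since the file itself
# does not fix a model name or hyperparameters.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # "t5-small" is an assumed checkpoint; any seq2seq model name works.
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = Model(model_name="t5-small", lr=3e-4)

    # Tokenize one example and decode the model's generated output.
    enc = tokenizer("Oh great, another Monday.", return_tensors="pt")
    out = model.generate(enc["input_ids"], enc["attention_mask"])
    print(tokenizer.decode(out[0], skip_special_tokens=True))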