-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataloader.py
More file actions
23 lines (16 loc) · 856 Bytes
/
dataloader.py
File metadata and controls
23 lines (16 loc) · 856 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
class GPT2Dataset(Dataset):
    """Map-style dataset of texts pre-tokenized for GPT-2 language modeling.

    Each input text is wrapped in ``<|startoftext|>`` / ``<|endoftext|>``
    markers, tokenized once up front, truncated/padded to a fixed
    ``max_length``, and stored eagerly as tensors, so ``__getitem__`` is a
    plain list lookup.
    """

    def __init__(self, txt_list, tokenizer, gpt2_type="gpt2", max_length=768):
        """Tokenize every text in ``txt_list`` and cache the results.

        Args:
            txt_list: iterable of raw text strings (an empty iterable yields
                an empty dataset).
            tokenizer: callable tokenizer (e.g. a HuggingFace GPT-2 tokenizer)
                returning a mapping with ``'input_ids'`` and
                ``'attention_mask'`` sequences.
            gpt2_type: unused by this class; retained only for backward
                compatibility with existing callers.
            max_length: fixed sequence length — longer texts are truncated,
                shorter ones padded (``padding="max_length"``).
        """
        # Kept for external access; not used after __init__ in this class.
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []
        for txt in txt_list:
            encodings_dict = tokenizer(
                '<|startoftext|>' + txt + '<|endoftext|>',
                truncation=True,
                max_length=max_length,
                padding="max_length",
            )
            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))

    def __len__(self):
        """Return the number of encoded examples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, attention_mask)`` tensor pair at ``idx``."""
        return self.input_ids[idx], self.attn_masks[idx]