-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
55 lines (40 loc) · 1.62 KB
/
dataset.py
File metadata and controls
55 lines (40 loc) · 1.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
@dataclass
class DataArgs:
    """Configuration for building a TokenDataset and its DataLoader."""

    # Number of sequences per batch (consumed by the DataLoader, not by
    # TokenDataset itself — TokenDataset only reads block_size/pad_token_id).
    batch_size: int
    # Fixed length of each training block; the token stream is chunked
    # into consecutive spans of this many tokens.
    block_size: int
    # Token id used to right-pad the final (possibly short) block.
    pad_token_id: int
class TokenDataset(Dataset):
    """Chunk a flat token-id sequence into fixed-size blocks.

    The input stream is split into consecutive, non-overlapping blocks of
    ``args.block_size`` tokens. The final block is right-padded with
    ``args.pad_token_id`` when the stream length is not a multiple of the
    block size, so every item has the same length.
    """

    def __init__(self, input_ids, args: DataArgs):
        """Store the token stream and the chunking configuration.

        Args:
            input_ids: Flat sequence of token ids (e.g. a list of ints).
            args: DataArgs carrying ``block_size`` and ``pad_token_id``.
        """
        self.input_ids = input_ids
        self.block_size = args.block_size
        self.pad_token_id = args.pad_token_id

    def __len__(self):
        # Ceiling division: counts every block, including a trailing
        # partial block (which __getitem__ pads up to block_size).
        return (len(self.input_ids) + self.block_size - 1) // self.block_size

    def __getitem__(self, idx):
        """Return block ``idx`` as a 1-D tensor of length ``block_size``."""
        start_idx = idx * self.block_size
        end_idx = start_idx + self.block_size
        input_ids_block = self.input_ids[start_idx:end_idx]
        # Right-pad the last (short) block. Build a new list instead of
        # mutating in place with `+=`, so we can never alias into a
        # caller-provided sequence whose slices share storage.
        if len(input_ids_block) < self.block_size:
            padding_length = self.block_size - len(input_ids_block)
            input_ids_block = list(input_ids_block) + [self.pad_token_id] * padding_length
        return torch.tensor(input_ids_block)
def collate_fn(batch, pad_token_id=0):
    """Dynamically pad a batch of 1-D token tensors to a common length.

    Args:
        batch: Sequence of 1-D tensors, possibly of different lengths.
        pad_token_id: Value used for padding. Defaults to 0 for backward
            compatibility; pass ``DataArgs.pad_token_id`` so the collate
            padding matches the dataset's pad token.

    Returns:
        A 2-D tensor of shape ``(len(batch), max_len)`` where shorter
        sequences are right-padded with ``pad_token_id``.

    Note:
        The pad value alone does not stop padded positions from affecting
        the model — the attention mask / loss ignore-index must also
        exclude them.
    """
    # pad_sequence right-pads every sequence to the batch maximum length.
    padded_batch = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
    return padded_batch