-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_prep.py
More file actions
69 lines (57 loc) · 2.15 KB
/
data_prep.py
File metadata and controls
69 lines (57 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from datasets import load_dataset, Dataset
from transformers import LEDTokenizer
import torch
def prepare_data():
    """Load and preprocess the Indian legal summarization dataset for LED.

    Downloads the ``ninadn/indian-legal`` dataset, keeps up to 2500
    documents that have a substantive body (> 100 characters after
    stripping) and a non-empty summary, tokenizes them for
    ``allenai/led-base-16384`` (inputs capped at 16384 tokens, labels at
    1024), and splits the result 80/10/10 into train/validation/test.

    Returns:
        tuple: ``(final_dataset, tokenizer)`` where ``final_dataset`` is a
        dict mapping ``"train"``/``"validation"``/``"test"`` to
        ``datasets.Dataset`` objects and ``tokenizer`` is the fitted
        ``LEDTokenizer``.
    """
    dataset = load_dataset("ninadn/indian-legal")
    print(f"Initial dataset size: {len(dataset['train'])} documents")

    # Keep only documents with a non-trivial body and a non-empty summary,
    # capped at 2500 examples to bound preprocessing and training cost.
    # (The original comment claimed "half"; the code actually takes the
    # first 2500 filtered rows regardless of total size.)
    documents = [
        {
            'Text': str(doc['Text']).strip(),
            'excerpt': str(doc['Summary']).strip()
        }
        for doc in dataset['train']
        if len(str(doc['Text']).strip()) > 100 and len(str(doc['Summary']).strip()) > 0
    ][:2500]
    dataset = Dataset.from_list(documents)

    tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")

    def preprocess_function(examples):
        """Tokenize a batch: encode source text, labels, and LED global attention."""
        inputs = tokenizer(
            examples["Text"],
            padding="max_length",
            truncation=True,
            max_length=16384
        )
        # `text_target=` is the supported replacement for the deprecated
        # `with tokenizer.as_target_tokenizer():` context manager; it
        # produces the same label token ids.
        labels = tokenizer(
            text_target=examples["excerpt"],
            padding="max_length",
            truncation=True,
            max_length=1024
        )
        # LED requires global attention on at least the first token of the
        # sequence; all other positions use local (windowed) attention.
        global_attention_mask = torch.zeros_like(torch.tensor(inputs.input_ids))
        global_attention_mask[:, 0] = 1
        return {
            "input_ids": inputs.input_ids,
            "attention_mask": inputs.attention_mask,
            "global_attention_mask": global_attention_mask.tolist(),
            "labels": labels.input_ids
        }

    processed_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=4
    )

    # 80% train; the remaining 20% is split evenly into validation and test
    # (10% each). Fixed seed keeps the splits reproducible across runs.
    splits = processed_dataset.train_test_split(train_size=0.8, test_size=0.2, shuffle=True, seed=42)
    val_test_splits = splits["test"].train_test_split(train_size=0.5, shuffle=True, seed=42)
    final_dataset = {
        "train": splits["train"],
        "validation": val_test_splits["train"],
        "test": val_test_splits["test"]
    }

    print(f"\nFinal dataset splits:")
    for split, data in final_dataset.items():
        print(f"{split}: {len(data)} examples")
    return final_dataset, tokenizer