-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsplit.py
More file actions
65 lines (48 loc) · 1.87 KB
/
split.py
File metadata and controls
65 lines (48 loc) · 1.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# SPDX-FileCopyrightText: (c) UIUC PurpCode Team
#
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
from pathlib import Path
from datasets import load_dataset
def split_dataset(dataset, num_splits, output_dir):
    """Split a dataset into ``num_splits`` contiguous JSONL files.

    The items keep their original order. When the item count is not evenly
    divisible, the first ``total % num_splits`` files receive one extra item.
    Files are written as ``split_001.jsonl``, ``split_002.jsonl``, ... inside
    ``output_dir`` and a one-line summary per file is printed.

    Args:
        dataset: Path to a ``.jsonl`` file (one JSON document per line), or
            any other identifier, which is handed to
            ``load_dataset(dataset, split="test")``.
        num_splits: Number of output files to produce; must be >= 1.
        output_dir: Directory for the output files. It is created, together
            with any missing parent directories, if it does not exist.

    Raises:
        ValueError: If ``num_splits`` is smaller than 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError (0) or
    # silently producing no output at all (negative values).
    if num_splits < 1:
        raise ValueError(f"num_splits must be a positive integer, got {num_splits}")

    # Create output directory (parents=True so nested paths work too)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load data
    if str(dataset).endswith(".jsonl"):
        # Explicit UTF-8 so reading does not depend on the platform locale;
        # blank lines (e.g. a trailing empty line) are skipped, not parsed.
        with open(dataset, "r", encoding="utf-8") as f:
            data = [json.loads(line) for line in f if line.strip()]
    else:
        data = list(load_dataset(dataset, split="test"))

    # Calculate items per split
    total_items = len(data)
    items_per_split, remainder = divmod(total_items, num_splits)

    # Create splits
    start_idx = 0
    for i in range(num_splits):
        # Add one extra item to first 'remainder' splits
        current_split_size = items_per_split + (1 if i < remainder else 0)
        end_idx = start_idx + current_split_size

        # Save split
        split_data = data[start_idx:end_idx]
        output_file = output_dir / f"split_{i+1:03d}.jsonl"
        # ensure_ascii=False emits raw non-ASCII characters, so the file must
        # be opened as UTF-8 explicitly to be writable on every platform.
        with open(output_file, "w", encoding="utf-8") as f:
            f.writelines(
                json.dumps(item, ensure_ascii=False) + "\n" for item in split_data
            )

        print(f"Split {i+1}: {len(split_data)} items -> {output_file}")
        start_idx = end_idx

    print(f"\nTotal: {total_items} items split into {num_splits} files")
def main():
    """Command-line entry point: parse arguments and split the dataset.

    Options:
        --dataset: Required. Path to a ``.jsonl`` file, or a dataset name
            that ``split_dataset`` loads through ``load_dataset``.
        -n / --num-splits: Required integer number of output files.
        -o / --output-dir: Output directory (default ``splits``).
    """
    parser = argparse.ArgumentParser(description="Split dataset into multiple files")
    # Required: without it args.dataset would be None and split_dataset would
    # crash on ``dataset.endswith`` instead of printing a usage error.
    parser.add_argument(
        "--dataset",
        required=True,
        help="Input JSONL file, or a dataset name to load with load_dataset",
    )
    parser.add_argument(
        "-n", "--num-splits", type=int, required=True, help="Number of splits"
    )
    parser.add_argument("-o", "--output-dir", default="splits", help="Output directory")
    args = parser.parse_args()

    split_dataset(args.dataset, args.num_splits, args.output_dir)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()