-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathfiles.py
More file actions
659 lines (562 loc) · 24.5 KB
/
files.py
File metadata and controls
659 lines (562 loc) · 24.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
from __future__ import annotations
import json
import os
import csv
from pathlib import Path
from traceback import format_exc
from typing import Any, Dict, List
from tqdm import tqdm
from together.constants import (
MAX_FILE_SIZE_GB,
MIN_SAMPLES,
NUM_BYTES_IN_GB,
PARQUET_EXPECTED_COLUMNS,
JSONL_REQUIRED_COLUMNS_MAP,
REQUIRED_COLUMNS_MESSAGE,
POSSIBLE_ROLES_CONVERSATION,
DatasetFormat,
)
from together.types import FilePurpose
class InvalidFileFormatError(ValueError):
    """Raised when a dataset file fails one of the format checks."""

    def __init__(
        self,
        message: str = "",
        line_number: int | None = None,
        error_source: str | None = None,
    ) -> None:
        """Store the problem description plus optional location metadata.

        Args:
            message: Human-readable description of the problem.
            line_number: 1-based line in the offending file, when known.
            error_source: Name of the report field this error maps to
                (e.g. "key_value", "format", "line_type").
        """
        super().__init__(message)
        self.message = message
        self.line_number = line_number
        self.error_source = error_source
def check_file(
    file: Path | str,
    purpose: FilePurpose | str = FilePurpose.FineTune,
) -> Dict[str, Any]:
    """Run all validation checks for a dataset file.

    Args:
        file: Path to the file to validate.
        purpose: Purpose of the upload; controls which formats are accepted
            and how strictly the contents are validated.

    Returns:
        Dict[str, Any]: Report with `is_check_passed`, a human-readable
        `message`, and per-check fields (`found`, `file_size`, `utf8`, ...).
    """
    if not isinstance(file, Path):
        file = Path(file)

    report_dict = {
        "is_check_passed": True,
        "message": "Checks passed",
        "found": None,
        "file_size": None,
        "utf8": None,
        "line_type": None,
        "text_field": None,
        "key_value": None,
        "has_min_samples": None,
        "num_samples": None,
        "load_json": None,
        "load_csv": None,
    }

    if not file.is_file():
        report_dict["found"] = False
        # Fix: previously the message stayed "Checks passed" even though the
        # check failed because the file does not exist.
        report_dict["message"] = f"File not found at {file}."
        report_dict["is_check_passed"] = False
        return report_dict
    else:
        report_dict["found"] = True

    file_size = os.stat(file).st_size
    if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
        report_dict["message"] = (
            f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB ,3)} GB."
        )
        report_dict["is_check_passed"] = False
    elif file_size == 0:
        report_dict["message"] = "File is empty"
        report_dict["file_size"] = 0
        report_dict["is_check_passed"] = False
        return report_dict
    else:
        report_dict["file_size"] = file_size

    data_report_dict = {}
    if file.suffix == ".jsonl":
        report_dict["filetype"] = "jsonl"
        data_report_dict = _check_jsonl(file, purpose)
    elif file.suffix == ".parquet":
        report_dict["filetype"] = "parquet"
        data_report_dict = _check_parquet(file, purpose)
    elif file.suffix == ".csv":
        report_dict["filetype"] = "csv"
        data_report_dict = _check_csv(file, purpose)
    else:
        # Fix: the old message omitted .csv even though it is handled above.
        report_dict["filetype"] = (
            f"Unknown extension of file {file}. "
            "Only files with extensions .jsonl, .parquet, and .csv are supported."
        )
        report_dict["is_check_passed"] = False

    report_dict.update(data_report_dict)
    return report_dict
def _check_conversation_type(messages: List[Dict[str, str | bool]], idx: int) -> None:
    """Ensure `messages` is a non-empty list of well-formed message dicts.

    Args:
        messages: Candidate value of the `messages` column. May be any type;
            this function verifies the expected structure.
        idx: Zero-based line number in the file (reported as idx + 1).

    Raises:
        InvalidFileFormatError: If the conversation structure is invalid.
    """
    line_no = idx + 1

    if not isinstance(messages, list):
        raise InvalidFileFormatError(
            message=f"Invalid format on line {line_no} of the input file. "
            f"The `messages` column must be a list. Found {type(messages)}",
            line_number=line_no,
            error_source="key_value",
        )

    if not messages:
        raise InvalidFileFormatError(
            message=f"Invalid format on line {line_no} of the input file. "
            "The `messages` column must not be empty.",
            line_number=line_no,
            error_source="key_value",
        )

    for entry in messages:
        if not isinstance(entry, dict):
            raise InvalidFileFormatError(
                message=f"Invalid format on line {line_no} of the input file. "
                f"The `messages` column must be a list of dicts. Found {type(entry)}",
                line_number=line_no,
                error_source="key_value",
            )
        for required in REQUIRED_COLUMNS_MESSAGE:
            if required not in entry:
                raise InvalidFileFormatError(
                    message=f"Missing required column `{required}` in message on line {line_no}.",
                    line_number=line_no,
                    error_source="key_value",
                )
            if not isinstance(entry[required], str):
                raise InvalidFileFormatError(
                    message=f"Column `{required}` is not a string on line {line_no}. Found {type(entry[required])}",
                    line_number=line_no,
                    error_source="text_field",
                )
def _check_conversation_roles(
    require_assistant_role: bool, assistant_role_exists: bool, idx: int
) -> None:
    """Verify an assistant message was seen when one is required.

    Args:
        require_assistant_role: Whether at least one assistant message must exist.
        assistant_role_exists: Whether an assistant message was found.
        idx: Zero-based line number in the file (reported as idx + 1).

    Raises:
        InvalidFileFormatError: If a required assistant message is missing.
    """
    if not require_assistant_role:
        return
    if assistant_role_exists:
        return
    raise InvalidFileFormatError(
        message=f"Invalid format on line {idx + 1} of the input file. "
        "At least one message with the assistant role must be present in the example.",
        line_number=idx + 1,
        error_source="key_value",
    )
def _check_message_weight(message: Dict[str, str | bool], idx: int) -> None:
    """Validate the optional `weight` field of a message.

    Args:
        message: The message to check.
        idx: Zero-based line number in the file (reported as idx + 1).

    Raises:
        InvalidFileFormatError: If a present weight is not an integer 0 or 1.
    """
    if "weight" not in message:
        return

    weight_value = message["weight"]
    if not isinstance(weight_value, int):
        raise InvalidFileFormatError(
            message=f"Weight must be an integer on line {idx + 1}.",
            line_number=idx + 1,
            error_source="key_value",
        )
    if weight_value != 0 and weight_value != 1:
        raise InvalidFileFormatError(
            message=f"Weight must be either 0 or 1 on line {idx + 1}.",
            line_number=idx + 1,
            error_source="key_value",
        )
def _check_message_role(
    message: Dict[str, str | bool], previous_role: str | None, idx: int
) -> str | bool:
    """Validate a message's role and the alternation of conversation turns.

    Args:
        message: The message whose `role` field is checked.
        previous_role: Role of the preceding message, or None for the first one.
        idx: Zero-based line number in the file (reported as idx + 1).

    Returns:
        The role of the current message.

    Raises:
        InvalidFileFormatError: If the role is unknown or repeats the previous role.
    """
    role = message["role"]

    if role not in POSSIBLE_ROLES_CONVERSATION:
        raise InvalidFileFormatError(
            message=f"Invalid role `{role}` in conversation on line {idx + 1}. "
            f"Possible roles: {', '.join(POSSIBLE_ROLES_CONVERSATION)}",
            line_number=idx + 1,
            error_source="key_value",
        )

    # After the optional system message, roles must strictly alternate.
    if previous_role is not None and role == previous_role:
        raise InvalidFileFormatError(
            message=f"Invalid role turns on line {idx + 1} of the input file. "
            "After the optional system message, conversation roles must alternate between user/assistant/user/assistant.",
            line_number=idx + 1,
            error_source="key_value",
        )

    return role
def validate_messages(
    messages: List[Dict[str, str | bool]], idx: int, require_assistant_role: bool = True
) -> None:
    """Validate the `messages` column of a single example.

    Args:
        messages: List of message dictionaries to validate.
        idx: Zero-based line number in the file (reported as idx + 1).
        require_assistant_role: Whether at least one assistant message is required.

    Raises:
        InvalidFileFormatError: If the messages are invalid.
    """
    _check_conversation_type(messages, idx)

    # Weights are validated only when at least one message declares one.
    weights_present = any("weight" in entry for entry in messages)

    last_role = None
    seen_assistant = False
    for current in messages:
        if weights_present:
            _check_message_weight(current, idx)
        last_role = _check_message_role(current, last_role, idx)
        if last_role == "assistant":
            seen_assistant = True

    _check_conversation_roles(require_assistant_role, seen_assistant, idx)
def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None:
    """Validate one example in the OpenAI preference dataset format.

    Args:
        example: Input entry to be checked.
        idx: Zero-based line number in the file (reported as idx + 1).

    Raises:
        InvalidFileFormatError: If the example does not follow the format.
    """
    line_no = idx + 1

    def _fail(msg: str) -> None:
        # Every failure in this format maps to the `key_value` report field.
        raise InvalidFileFormatError(
            message=msg, line_number=line_no, error_source="key_value"
        )

    if not isinstance(example["input"], dict):
        _fail("The dataset is malformed, the `input` field must be a dictionary.")
    if "messages" not in example["input"]:
        _fail(
            "The dataset is malformed, the `input` dictionary must contain a `messages` field."
        )

    # The prompt conversation must be valid but must not end with an assistant
    # turn — the preferred/non-preferred outputs supply the assistant reply.
    validate_messages(example["input"]["messages"], idx, require_assistant_role=False)
    if example["input"]["messages"][-1]["role"] == "assistant":
        _fail(
            f"The last message in the input conversation must not be from the assistant on line {line_no}."
        )

    for key in ("preferred_output", "non_preferred_output"):
        if key not in example:
            _fail(
                f"The dataset is malformed, the `{key}` field must be present in the input dictionary on line {line_no}."
            )
        output = example[key]
        if not isinstance(output, list):
            _fail(
                f"The dataset is malformed, the `{key}` field must be a list on line {line_no}."
            )
        if len(output) != 1:
            _fail(
                f"The dataset is malformed, the `{key}` list must contain exactly one message on line {line_no}."
            )
        reply = output[0]
        if not isinstance(reply, dict):
            _fail(
                f"The dataset is malformed, the first element of `{key}` must be a dictionary on line {line_no}."
            )
        if "role" not in reply:
            _fail(
                f"The dataset is malformed, the first element of `{key}` must have a 'role' field on line {line_no}."
            )
        if reply["role"] != "assistant":
            _fail(
                f"The dataset is malformed, the first element of `{key}` must have the 'assistant' role on line {line_no}."
            )
        if "content" not in reply:
            _fail(
                f"The dataset is malformed, the first element of `{key}` must have a 'content' field on line {line_no}."
            )
        if not isinstance(reply["content"], str):
            _fail(
                f"The dataset is malformed, the 'content' field in `{key}` must be a string on line {line_no}."
            )
def _check_utf8(file: Path) -> Dict[str, Any]:
    """Check whether the file is valid UTF-8.

    Args:
        file: Path to the file to check.

    Returns:
        Dict[str, Any]: Partial report; `utf8` is always set, while `message`
        and `is_check_passed` are added only on failure.
    """
    report_dict: Dict[str, Any] = {}
    try:
        # Dry-run decode: stream the file line by line so nothing large is
        # held in memory just to validate the encoding.
        with file.open(encoding="utf-8") as handle:
            for _ in handle:
                pass
    except UnicodeDecodeError as e:
        report_dict["utf8"] = False
        report_dict["message"] = f"File is not UTF-8 encoded. Error raised: {e}."
        report_dict["is_check_passed"] = False
    else:
        report_dict["utf8"] = True
    return report_dict
def _check_samples_count(
    file: Path, report_dict: Dict[str, Any], idx: int
) -> Dict[str, Any]:
    """Record the sample count and enforce the minimum-samples requirement.

    Args:
        file: Path to the file being checked (used in the error message).
        report_dict: Report dictionary to update in place.
        idx: Zero-based index of the last sample read (idx + 1 samples total).

    Returns:
        Dict[str, Any]: The updated report.
    """
    num_samples = idx + 1
    # Fix: record the observed count even when the check fails, so the report
    # always says how many samples were actually found (previously
    # `num_samples` stayed None on failure).
    report_dict["num_samples"] = num_samples
    if num_samples < MIN_SAMPLES:
        report_dict["has_min_samples"] = False
        report_dict["message"] = (
            f"Processing {file} resulted in only {num_samples} samples. "
            f"Our minimum is {MIN_SAMPLES} samples. "
        )
        report_dict["is_check_passed"] = False
    else:
        report_dict["has_min_samples"] = True
    return report_dict
def _check_csv(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
    """Check that the file is a valid CSV file (evaluation uploads only).

    Args:
        file (Path): Path to the file to check.
        purpose (FilePurpose | str): Purpose of the file; CSV is accepted only
            for evaluations.

    Returns:
        Dict[str, Any]: A dictionary with the results of the check.
    """
    report_dict: Dict[str, Any] = {}

    if purpose != FilePurpose.Eval:
        report_dict["is_check_passed"] = False
        report_dict["message"] = (
            f"CSV files are not supported for {purpose}. "
            "Only JSONL and Parquet files are supported."
        )
        return report_dict

    report_dict.update(_check_utf8(file))
    if not report_dict["utf8"]:
        return report_dict

    with file.open() as handle:
        rows = csv.DictReader(handle)
        if not rows.fieldnames:
            report_dict["message"] = "CSV file is empty or has no header."
            report_dict["is_check_passed"] = False
            return report_dict

        row_idx = -1
        try:
            for row_idx, row in enumerate(rows):
                # DictReader signals a column-count mismatch between a row and
                # the header by introducing None keys or None values.
                if None in row.keys() or None in row.values():
                    raise InvalidFileFormatError(
                        message=f"CSV file is malformed or the number of columns found on line {row_idx + 1} is inconsistent with the header",
                        line_number=row_idx + 1,
                        error_source="format",
                    )
            report_dict.update(_check_samples_count(file, report_dict, row_idx))
            report_dict["load_csv"] = True
        except InvalidFileFormatError as e:
            report_dict["load_csv"] = False
            report_dict["is_check_passed"] = False
            report_dict["message"] = e.message
            if e.line_number is not None:
                report_dict["line_number"] = e.line_number
            if e.error_source is not None:
                report_dict[e.error_source] = False
        except ValueError:
            report_dict["load_csv"] = False
            if row_idx < 0:
                report_dict["message"] = (
                    "Unable to decode file. "
                    "File may be empty or in an unsupported format. "
                )
            else:
                report_dict["message"] = (
                    f"Error parsing the CSV file. Unexpected format on line {row_idx + 1}."
                )
            report_dict["is_check_passed"] = False

    return report_dict
def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
    """Validate a JSONL dataset file line by line.

    Args:
        file (Path): Path to the file to check.
        purpose (FilePurpose | str): Purpose of the file. For evaluation
            uploads only JSON parsing is verified; otherwise every line must
            match exactly one known dataset format.

    Returns:
        Dict[str, Any]: A dictionary with the results of the check.
    """
    report_dict: Dict[str, Any] = {}
    report_dict.update(_check_utf8(file))
    if not report_dict["utf8"]:
        return report_dict
    # Format detected on the first validated line; later lines must match it.
    dataset_format = None
    with file.open() as f:
        idx = -1
        try:
            for idx, line in tqdm(enumerate(f), desc="Validating file", unit=" lines"):
                json_line = json.loads(line)
                # Every line must be a JSON object, not a list or scalar.
                if not isinstance(json_line, dict):
                    raise InvalidFileFormatError(
                        message=(
                            f"Error parsing file. Invalid format on line {idx + 1} of the input file. "
                            "Datasets must follow text, conversational, or instruction format. For more"
                            "information, see https://docs.together.ai/docs/fine-tuning-data-preparation"
                        ),
                        line_number=idx + 1,
                        error_source="line_type",
                    )
                # In evals, we don't check the format of the dataset.
                if purpose != FilePurpose.Eval:
                    current_format = None
                    # A line matches a format when it contains all of that
                    # format's required columns; matching two formats is an error.
                    for possible_format in JSONL_REQUIRED_COLUMNS_MAP:
                        if all(
                            column in json_line
                            for column in JSONL_REQUIRED_COLUMNS_MAP[possible_format]
                        ):
                            if current_format is None:
                                current_format = possible_format
                            elif current_format != possible_format:
                                raise InvalidFileFormatError(
                                    message="Found multiple dataset formats in the input file. "
                                    f"Got {current_format} and {possible_format} on line {idx + 1}.",
                                    line_number=idx + 1,
                                    error_source="format",
                                )
                            # Check that there are no extra columns
                            for column in json_line:
                                if (
                                    column
                                    not in JSONL_REQUIRED_COLUMNS_MAP[possible_format]
                                ):
                                    raise InvalidFileFormatError(
                                        message=f'Found extra column "{column}" in the line {idx + 1}.',
                                        line_number=idx + 1,
                                        error_source="format",
                                    )
                    if current_format is None:
                        raise InvalidFileFormatError(
                            message=(
                                f"Error parsing file. Could not detect a format for the line {idx + 1} with the columns:\n"
                                f"{json_line.keys()}"
                            ),
                            line_number=idx + 1,
                            error_source="format",
                        )
                    # Format-specific validation of the line's contents.
                    if current_format == DatasetFormat.PREFERENCE_OPENAI:
                        validate_preference_openai(json_line, idx)
                    elif current_format == DatasetFormat.CONVERSATION:
                        message_column = JSONL_REQUIRED_COLUMNS_MAP[
                            DatasetFormat.CONVERSATION
                        ][0]
                        # NOTE(review): this branch only runs when
                        # purpose != FilePurpose.Eval, so `require_assistant`
                        # is always True here — confirm intent.
                        require_assistant = purpose != FilePurpose.Eval
                        validate_messages(
                            json_line[message_column],
                            idx,
                            require_assistant_role=require_assistant,
                        )
                    else:
                        # Remaining formats: every required column must be a string.
                        for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]:
                            if not isinstance(json_line[column], str):
                                raise InvalidFileFormatError(
                                    message=f'Invalid value type for "{column}" key on line {idx + 1}. '
                                    f"Expected string. Found {type(json_line[column])}.",
                                    line_number=idx + 1,
                                    error_source="key_value",
                                )
                    if dataset_format is None:
                        dataset_format = current_format
                    elif current_format is not None:
                        if current_format != dataset_format:
                            raise InvalidFileFormatError(
                                message="All samples in the dataset must have the same dataset format. "
                                f"Got {dataset_format} for the first line and {current_format} "
                                f"for the line {idx + 1}.",
                                line_number=idx + 1,
                                error_source="format",
                            )
            report_dict.update(_check_samples_count(file, report_dict, idx))
            report_dict["load_json"] = True
        except InvalidFileFormatError as e:
            # Structured validation failure: surface the message and flag the
            # specific report field named by `error_source`.
            report_dict["load_json"] = False
            report_dict["is_check_passed"] = False
            report_dict["message"] = e.message
            if e.line_number is not None:
                report_dict["line_number"] = e.line_number
            if e.error_source is not None:
                report_dict[e.error_source] = False
        except ValueError:
            # json.loads failure. (InvalidFileFormatError is also a ValueError,
            # but the clause above catches it first.)
            report_dict["load_json"] = False
            if idx < 0:
                report_dict["message"] = (
                    "Unable to decode file. "
                    "File may be empty or in an unsupported format. "
                )
            else:
                report_dict["message"] = (
                    f"Error parsing json payload. Unexpected format on line {idx + 1}."
                )
            report_dict["is_check_passed"] = False
    # Per-check fields that were never flagged as failed default to True.
    if "text_field" not in report_dict:
        report_dict["text_field"] = True
    if "line_type" not in report_dict:
        report_dict["line_type"] = True
    if "key_value" not in report_dict:
        report_dict["key_value"] = True
    return report_dict
def _check_parquet(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
    """Check that the file is a valid tokenized Parquet dataset.

    Args:
        file (Path): Path to the file to check.
        purpose (FilePurpose | str): Purpose of the file; Parquet is not
            accepted for evaluations.

    Returns:
        Dict[str, Any]: A dictionary with the results of the check.

    Raises:
        ImportError: If pyarrow is not installed.
    """
    try:
        # Pyarrow is optional as it's large (~80MB) and isn't compatible with older systems.
        from pyarrow import ArrowInvalid, parquet
    except ImportError:
        raise ImportError(
            "pyarrow is not installed and is required to use parquet files. Please install it via `pip install together[pyarrow]`"
        )

    report_dict: Dict[str, Any] = {}

    if purpose == FilePurpose.Eval:
        report_dict["is_check_passed"] = False
        report_dict["message"] = (
            f"Parquet files are not supported for {purpose}. "
            "Only JSONL and CSV files are supported."
        )
        return report_dict

    try:
        table = parquet.read_table(str(file), memory_map=True)
    except ArrowInvalid:
        report_dict["load_parquet"] = (
            f"An exception has occurred when loading the Parquet file {file}. Please check the file for corruption. "
            f"Exception trace:\n{format_exc()}"
        )
        report_dict["is_check_passed"] = False
        return report_dict

    schema_columns = table.schema.names

    if "input_ids" not in schema_columns:
        report_dict["load_parquet"] = (
            f"Parquet file {file} does not contain the `input_ids` column."
        )
        report_dict["is_check_passed"] = False
        return report_dict

    # Reject any column outside the expected tokenized-dataset schema,
    # reporting the first unexpected one found.
    unexpected = [name for name in schema_columns if name not in PARQUET_EXPECTED_COLUMNS]
    if unexpected:
        report_dict["load_parquet"] = (
            f"Parquet file {file} contains an unexpected column {unexpected[0]}. "
            f"Only columns {PARQUET_EXPECTED_COLUMNS} are supported."
        )
        report_dict["is_check_passed"] = False
        return report_dict

    num_samples = len(table)
    if num_samples < MIN_SAMPLES:
        report_dict["has_min_samples"] = False
        report_dict["message"] = (
            f"Processing {file} resulted in only {num_samples} samples. "
            f"Our minimum is {MIN_SAMPLES} samples. "
        )
        report_dict["is_check_passed"] = False
        return report_dict
    else:
        report_dict["num_samples"] = num_samples

    report_dict["is_check_passed"] = True
    return report_dict