Skip to content

Commit 24d209e

Browse files
committed
Update validation function; Use correct errors
1 parent b322858 commit 24d209e

1 file changed

Lines changed: 84 additions & 109 deletions

File tree

src/together/utils/files.py

Lines changed: 84 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -108,107 +108,142 @@ def _has_weights(messages: List[Dict[str, str | bool]]) -> bool:
108108
return any("weight" in message for message in messages)
109109

110110

111-
def validate_and_filter_messages(
112-
messages: List[Dict[str, str | bool]]
111+
def validate_messages(
112+
messages: List[Dict[str, str | bool]], idx: int = 0
113113
) -> tuple[List[Dict[str, str | bool]], bool]:
114-
"""Validate and filter the messages column."""
114+
"""Validate the messages column."""
115115
if not isinstance(messages, list):
116-
raise ValueError(
117-
"The dataset is malformed, the `messages` column must be a list."
116+
raise InvalidFileFormatError(
117+
message="The dataset is malformed, the `messages` column must be a list.",
118+
line_number=idx + 1,
119+
error_source="key_value",
118120
)
119-
if len(messages) == 0:
120-
raise ValueError(
121-
"The dataset is malformed, the `messages` column must not be empty."
121+
if not messages:
122+
raise InvalidFileFormatError(
123+
message="The dataset is malformed, the `messages` column must not be empty.",
124+
line_number=idx + 1,
125+
error_source="key_value",
122126
)
123127

124128
has_weights = False
125129
# Check for weights in messages
126130
if _has_weights(messages):
127131
has_weights = True
128132

129-
filtered_messages = []
133+
previous_role = None
130134
for message in messages:
131135
if any(column not in message for column in REQUIRED_COLUMNS_MESSAGE):
132-
raise ValueError(
133-
"The dataset is malformed. "
136+
raise InvalidFileFormatError(
137+
message="The dataset is malformed. "
134138
"Each message in the messages column must have "
135-
f"{REQUIRED_COLUMNS_MESSAGE} columns."
139+
f"{REQUIRED_COLUMNS_MESSAGE} columns.",
140+
line_number=idx + 1,
141+
error_source="key_value",
136142
)
137143
for column in REQUIRED_COLUMNS_MESSAGE:
138144
if not isinstance(message[column], str):
139-
raise ValueError(
140-
f"The dataset is malformed, the column `{column}` must be of the string type."
145+
raise InvalidFileFormatError(
146+
message=f"The dataset is malformed, the column `{column}` must be of the string type.",
147+
line_number=idx + 1,
148+
error_source="key_value",
141149
)
142150

143151
if has_weights and "weight" in message:
144152
weight = message["weight"]
145153
if not isinstance(weight, int):
146-
raise ValueError("Weight must be an integer")
154+
raise InvalidFileFormatError(
155+
message="Weight must be an integer",
156+
line_number=idx + 1,
157+
error_source="key_value",
158+
)
147159
if weight not in {0, 1}:
148-
raise ValueError("Weight must be either 0 or 1")
160+
raise InvalidFileFormatError(
161+
message="Weight must be either 0 or 1",
162+
line_number=idx + 1,
163+
error_source="key_value",
164+
)
149165
if message["role"] not in POSSIBLE_ROLES_CONVERSATION:
150-
raise ValueError(
151-
f"Invalid role {message['role']} in conversation, possible roles: "
152-
f"{', '.join(POSSIBLE_ROLES_CONVERSATION)}"
166+
raise InvalidFileFormatError(
167+
message=f"Invalid role {message['role']} in conversation, possible roles: "
168+
f"{', '.join(POSSIBLE_ROLES_CONVERSATION)}",
169+
line_number=idx + 1,
170+
error_source="key_value",
153171
)
154-
filtered_messages.append(
155-
{column: message[column] for column in REQUIRED_COLUMNS_MESSAGE}
156-
)
157172

158-
return filtered_messages, has_weights
173+
if previous_role == message["role"]:
174+
raise InvalidFileFormatError(
175+
message=f"Invalid role turns on line {idx + 1} of the input file. "
176+
"`user` and `assistant` roles must alternate user/assistant/user/assistant/...",
177+
line_number=idx + 1,
178+
error_source="key_value",
179+
)
180+
previous_role = message["role"]
181+
182+
return messages, has_weights
159183

160184

161-
def validate_preference_openai(example: Dict[str, Any]) -> Dict[str, Any]:
185+
def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> Dict[str, Any]:
162186
"""Validate the OpenAI preference dataset format.
163187
164188
Args:
165189
example (dict): Input entry to be checked.
190+
idx (int): Line number in the file.
166191
167192
Raises:
168-
ValueError: If the dataset format is invalid.
193+
InvalidFileFormatError: If the dataset format is invalid.
169194
170195
Returns:
171196
Dict[str, Any]: The validated example.
172197
"""
173198
if not isinstance(example["input"], dict):
174-
raise ValueError(
175-
"The dataset is malformed, the `input` field must be a dictionary."
199+
raise InvalidFileFormatError(
200+
message="The dataset is malformed, the `input` field must be a dictionary.",
201+
line_number=idx + 1,
202+
error_source="key_value",
176203
)
177204

178205
if "messages" not in example["input"]:
179-
raise ValueError(
180-
"The dataset is malformed, the `input` dictionary must contain a `messages` field."
206+
raise InvalidFileFormatError(
207+
message="The dataset is malformed, the `input` dictionary must contain a `messages` field.",
208+
line_number=idx + 1,
209+
error_source="key_value",
181210
)
182211

183-
example["input"]["messages"], _ = validate_and_filter_messages(
184-
example["input"]["messages"]
212+
example["input"]["messages"], _ = validate_messages(
213+
example["input"]["messages"], idx
185214
)
186215

187216
if not isinstance(example["preferred_output"], list):
188-
raise ValueError(
189-
"The dataset is malformed, the `preferred_output` field must be a list."
217+
raise InvalidFileFormatError(
218+
message="The dataset is malformed, the `preferred_output` field must be a list.",
219+
line_number=idx + 1,
220+
error_source="key_value",
190221
)
191222

192223
if not isinstance(example["non_preferred_output"], list):
193-
raise ValueError(
194-
"The dataset is malformed, the `non_preferred_output` field must be a list."
224+
raise InvalidFileFormatError(
225+
message="The dataset is malformed, the `non_preferred_output` field must be a list.",
226+
line_number=idx + 1,
227+
error_source="key_value",
195228
)
196229

197230
if len(example["preferred_output"]) != 1:
198-
raise ValueError(
199-
"The dataset is malformed, the `preferred_output` list must contain exactly one message."
231+
raise InvalidFileFormatError(
232+
message="The dataset is malformed, the `preferred_output` list must contain exactly one message.",
233+
line_number=idx + 1,
234+
error_source="key_value",
200235
)
201236

202237
if len(example["non_preferred_output"]) != 1:
203-
raise ValueError(
204-
"The dataset is malformed, the `non_preferred_output` list must contain exactly one message."
238+
raise InvalidFileFormatError(
239+
message="The dataset is malformed, the `non_preferred_output` list must contain exactly one message.",
240+
line_number=idx + 1,
241+
error_source="key_value",
205242
)
206243

207-
example["preferred_output"], _ = validate_and_filter_messages(
208-
example["preferred_output"]
209-
)
210-
example["non_preferred_output"], _ = validate_and_filter_messages(
211-
example["non_preferred_output"]
244+
example["preferred_output"], _ = validate_messages(example["preferred_output"], idx)
245+
example["non_preferred_output"], _ = validate_messages(
246+
example["non_preferred_output"], idx
212247
)
213248
return example
214249

@@ -282,7 +317,7 @@ def _check_jsonl(file: Path) -> Dict[str, Any]:
282317
error_source="format",
283318
)
284319
if current_format == DatasetFormat.PREFERENCE_OPENAI:
285-
validate_preference_openai(json_line)
320+
validate_preference_openai(json_line, idx)
286321
elif current_format == DatasetFormat.PREFERENCE:
287322
for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]:
288323
if not isinstance(json_line[column], list):
@@ -297,7 +332,7 @@ def _check_jsonl(file: Path) -> Dict[str, Any]:
297332
line_number=idx + 1,
298333
error_source="key_value",
299334
)
300-
validate_and_filter_messages(json_line[column])
335+
validate_messages(json_line[column], idx)
301336
if not json_line[column][-1].get("role") == "assistant":
302337
raise InvalidFileFormatError(
303338
message=f"The last message in {column} must be from an assistant",
@@ -333,69 +368,9 @@ def _check_jsonl(file: Path) -> Dict[str, Any]:
333368
message_column = JSONL_REQUIRED_COLUMNS_MAP[
334369
DatasetFormat.CONVERSATION
335370
][0]
336-
if not isinstance(json_line[message_column], list):
337-
raise InvalidFileFormatError(
338-
message=f"Invalid format on line {idx + 1} of the input file. "
339-
f"Expected a list of messages. Found {type(json_line[message_column])}",
340-
line_number=idx + 1,
341-
error_source="key_value",
342-
)
343-
344-
if len(json_line[message_column]) == 0:
345-
raise InvalidFileFormatError(
346-
message=f"Invalid format on line {idx + 1} of the input file. "
347-
f"Expected a non-empty list of messages. Found empty list",
348-
line_number=idx + 1,
349-
error_source="key_value",
350-
)
351-
352-
for turn_id, turn in enumerate(json_line[message_column]):
353-
if not isinstance(turn, dict):
354-
raise InvalidFileFormatError(
355-
message=f"Invalid format on line {idx + 1} of the input file. "
356-
f"Expected a dictionary in the {turn_id + 1} turn. Found {type(turn)}",
357-
line_number=idx + 1,
358-
error_source="key_value",
359-
)
360-
361-
previous_role = None
362-
for turn in json_line[message_column]:
363-
for column in REQUIRED_COLUMNS_MESSAGE:
364-
if column not in turn:
365-
raise InvalidFileFormatError(
366-
message=f"Field `{column}` is missing for a turn `{turn}` on line {idx + 1} "
367-
"of the the input file.",
368-
line_number=idx + 1,
369-
error_source="key_value",
370-
)
371-
else:
372-
if not isinstance(turn[column], str):
373-
raise InvalidFileFormatError(
374-
message=f"Invalid format on line {idx + 1} in the column {column} for turn `{turn}` "
375-
f"of the input file. Expected string. Found {type(turn[column])}",
376-
line_number=idx + 1,
377-
error_source="text_field",
378-
)
379-
role = turn["role"]
380-
381-
if role not in POSSIBLE_ROLES_CONVERSATION:
382-
raise InvalidFileFormatError(
383-
message=f"Found invalid role `{role}` in the messages on the line {idx + 1}. "
384-
f"Possible roles in the conversation are: {POSSIBLE_ROLES_CONVERSATION}",
385-
line_number=idx + 1,
386-
error_source="key_value",
387-
)
388-
389-
if previous_role == role:
390-
raise InvalidFileFormatError(
391-
message=f"Invalid role turns on line {idx + 1} of the input file. "
392-
"`user` and `assistant` roles must alternate user/assistant/user/assistant/...",
393-
line_number=idx + 1,
394-
error_source="key_value",
395-
)
396-
397-
previous_role = role
398-
371+
messages, has_weights = validate_messages(
372+
json_line[message_column], idx
373+
)
399374
else:
400375
for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]:
401376
if not isinstance(json_line[column], str):

0 commit comments

Comments
 (0)