-
Notifications
You must be signed in to change notification settings - Fork 12
Implement pipeline and new schema for image datasets #158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6f39257
46b3ae3
3a9a049
4508e78
16f29a0
a05ecd4
7ced439
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| from schema.action.api import ApiAction | ||
| from schema.action.code import CodeAction | ||
| from schema.action.message import MessageAction | ||
| from schema.observation.image import ImageObservation | ||
| from schema.observation.text import TextObservation | ||
| from schema.observation.web import WebObservation | ||
| from schema.trajectory import Trajectory | ||
|
|
@@ -91,7 +92,52 @@ def standardized_event_to_openhands_message( | |
| else: | ||
| axtree = generate_axtree.last_xtree | ||
| prompt = get_web_user_message("", event.url, axtree, PREV_BID) | ||
| return {"from": "human", "value": prompt} | ||
|
|
||
| # Handle nested image observation | ||
| image_path = None | ||
| if hasattr(event, "image_observation") and event.image_observation: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm a bit confused: where would this nested image observation come from? I didn't see it in other parts of the code. In general, "getattr" and "hasattr" are kind of anti-patterns in Python programming. They are indicative of not strictly adhering to type definitions, and can cause all kinds of tricky runtime errors. Let's try to write this without using these. |
||
| image_path = event.image_observation.content | ||
|
|
||
| # Add visual observation section | ||
| prompt += "\n\n---\nVISUAL OBSERVATION:\n<image>" | ||
|
|
||
| # Add image annotations if present (using enhanced parsing) | ||
| if ( | ||
| hasattr(event.image_observation, "annotations") | ||
| and event.image_observation.annotations | ||
| ): | ||
| annotations = [] | ||
| for annotation in event.image_observation.annotations: | ||
| # Build annotation description from available fields | ||
| parts = [] | ||
| if hasattr(annotation, "text") and annotation.text: | ||
| parts.append(annotation.text) | ||
| elif ( | ||
| hasattr(annotation, "content_description") | ||
| and annotation.content_description | ||
| ): | ||
| parts.append(annotation.content_description) | ||
|
|
||
| # Add element type | ||
| if hasattr(annotation, "element_type"): | ||
| parts.append(f"({annotation.element_type})") | ||
|
|
||
| # Add interactivity info | ||
| attrs = [] | ||
| if hasattr(annotation, "clickable") and annotation.clickable: | ||
| attrs.append("clickable") | ||
| if hasattr(annotation, "editable") and annotation.editable: | ||
| attrs.append("editable") | ||
| if attrs: | ||
| parts.append(f"[{', '.join(attrs)}]") | ||
|
|
||
| if parts: | ||
| annotations.append(" ".join(parts)) | ||
|
|
||
| if annotations: | ||
| prompt += "\nElements detected: " + ", ".join(annotations) | ||
|
|
||
| return {"from": "human", "value": prompt, "_image_path": image_path} | ||
|
|
||
| if isinstance(event, ApiAction): | ||
| PREV_BID = None | ||
|
|
@@ -133,10 +179,23 @@ def standardized_event_to_openhands_message( | |
| event_xpath = event.kwargs.get("xpath", None) | ||
| if event_xpath: | ||
| browsergym_id = generate_axtree.get_bid(id, event_xpath, "all") | ||
|
|
||
| # Generate placeholder bid for web datasets when get_bid fails | ||
| if not browsergym_id and is_web: | ||
| event_xpath = event.kwargs.get("xpath", None) | ||
| if event_xpath: | ||
| # Use xpath hash as placeholder to maintain some consistency | ||
| placeholder_id = f"placeholder_bid_{abs(hash(event_xpath)) % 10000}" | ||
| browsergym_id = f'"{placeholder_id}"' | ||
| print( | ||
| f"Warning: Generated placeholder bid {browsergym_id} for xpath: {event_xpath}", | ||
| file=sys.stderr, | ||
| ) | ||
|
|
||
| # for tool calls that are not browser based since there is no browsergym_id | ||
| # and tool calls that are specified as non-web | ||
| # these should all be dataset specific apis | ||
| if (not browsergym_id or not is_web) and function_name in api_sigs: | ||
| if not is_web and function_name in api_sigs: | ||
| if not api_env: | ||
| # Default to 'execute_ipython_cell' if api_env is not specified | ||
| api_env = "execute_ipython_cell" | ||
|
|
@@ -151,7 +210,8 @@ def standardized_event_to_openhands_message( | |
| return {"from": "function_call", "value": f"{thought}{function_call}"} | ||
|
|
||
| api_env = "browser" | ||
| if not browsergym_id[0] == browsergym_id[-1] == '"': | ||
|
|
||
| if browsergym_id and not browsergym_id[0] == browsergym_id[-1] == '"': | ||
| browsergym_id = f'"{browsergym_id[0]}"' | ||
| PREV_BID = browsergym_id | ||
| # for apis that are browser based but are not OH default browser apis | ||
|
|
@@ -223,25 +283,49 @@ def standardized_event_to_openhands_message( | |
| raise ValueError(f"Wrong event source: {event.source}") | ||
| return {"from": event.source, "value": event.content} | ||
|
|
||
| elif hasattr(event, "__class__") and event.__class__.__name__ == "ImageObservation": | ||
| elif isinstance(event, ImageObservation): | ||
MajikalExplosions marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # Handle ImageObservation | ||
| annotations_text = "" | ||
| if hasattr(event, "annotations") and event.annotations: | ||
| annotations = [] | ||
| for annotation in event.annotations: | ||
| # Build annotation description from available fields | ||
| parts = [] | ||
| if hasattr(annotation, "text") and annotation.text: | ||
| annotations.append(f"{annotation.text} ({annotation.element_type})") | ||
| parts.append(annotation.text) | ||
| elif hasattr(annotation, "content_description") and annotation.content_description: | ||
| parts.append(annotation.content_description) | ||
|
|
||
| # Add element type | ||
| if hasattr(annotation, "element_type"): | ||
| parts.append(f"({annotation.element_type})") | ||
|
|
||
| # Add interactivity info | ||
| attrs = [] | ||
| if hasattr(annotation, "clickable") and annotation.clickable: | ||
| attrs.append("clickable") | ||
| if hasattr(annotation, "editable") and annotation.editable: | ||
| attrs.append("editable") | ||
| if attrs: | ||
| parts.append(f"[{', '.join(attrs)}]") | ||
|
|
||
| if parts: | ||
| annotations.append(" ".join(parts)) | ||
|
|
||
| if annotations: | ||
| annotations_text = "Elements detected: " + ", ".join(annotations) | ||
|
|
||
| image_path = getattr(event, "content", "unknown_image_path") | ||
| return {"from": "observation", "value": f"[Image: {image_path}]\n{annotations_text}"} | ||
| return { | ||
| "from": "observation", | ||
| "value": f"<image>{annotations_text}", | ||
| "_image_path": event.content, | ||
| } | ||
|
|
||
| else: | ||
| raise ValueError(f"Unknown event type: {type(event)}\n{event}") | ||
|
|
||
|
|
||
| def process_row(line, is_web, api_env, api_tool_description, api_sigs): | ||
| def process_row(line, is_web, api_env, api_tool_description, api_sigs, export_for="explicit"): | ||
| std_dataset = [json.loads(line)] | ||
| std_data = std_dataset[0] | ||
| trajectory = Trajectory(**std_data) | ||
|
|
@@ -251,6 +335,7 @@ def process_row(line, is_web, api_env, api_tool_description, api_sigs): | |
| conversations = [] | ||
| previous_web_actions = [] | ||
| languages = [] | ||
| image_paths = [] | ||
| for i in range(len(events)): | ||
| event = events[i] | ||
| try: | ||
|
|
@@ -259,6 +344,13 @@ def process_row(line, is_web, api_env, api_tool_description, api_sigs): | |
| ) | ||
| if not message: | ||
| return None | ||
|
|
||
| # Extract image path if present | ||
| if "_image_path" in message: | ||
| path = message.pop("_image_path") | ||
| if path: | ||
| image_paths.append(path) | ||
|
|
||
| if len(conversations) == 0: | ||
| # append api function docs to first user message when available | ||
| if api_env: | ||
|
|
@@ -290,18 +382,24 @@ def process_row(line, is_web, api_env, api_tool_description, api_sigs): | |
| language_descriptions = get_language_descriptions(languages) | ||
| conversations[0]["value"] = language_descriptions + "\n\n" + conversations[0]["value"] | ||
| for m in conversations: | ||
| if m["from"] == "function_call": | ||
| if export_for == "training" and m["from"] == "function_call": | ||
| m["from"] = "gpt" | ||
| if m["from"] == "observation": | ||
| m["from"] = "human" | ||
| return { | ||
|
|
||
| output = { | ||
| "id": trajectory.id, | ||
| "conversations": conversations, | ||
| "system": get_system_message(), | ||
| } | ||
|
|
||
| if image_paths: | ||
| output["images"] = image_paths | ||
|
|
||
| return output | ||
|
|
||
|
|
||
| def process_line(line, is_web, api_env): | ||
| def process_line(line, is_web, api_env, export_for="explicit"): | ||
| exclude_apis = browser_default_apis if is_web else {} | ||
| api_tool_description, api_sigs = get_api_tool_description(dataset, exclude_apis, api_env) | ||
| output_line = process_row( | ||
|
|
@@ -310,6 +408,7 @@ def process_line(line, is_web, api_env): | |
| api_env=api_env, | ||
| api_tool_description=api_tool_description, | ||
| api_sigs=api_sigs, | ||
| export_for=export_for, | ||
| ) | ||
| output_line = json.dumps(output_line) | ||
| # if output_line: | ||
|
|
@@ -324,11 +423,6 @@ def process_line(line, is_web, api_env): | |
| return output_line | ||
|
|
||
|
|
||
| # Keep the old main function for backward compatibility | ||
| def main_with_args(line, is_web, api_env): | ||
| return process_line(line, is_web, api_env) | ||
|
|
||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser(description="Convert standardized data to SFT format") | ||
| parser.add_argument( | ||
|
|
@@ -346,10 +440,17 @@ def main(): | |
| help="The environment in which the APIs are pre-defined", | ||
| default=None, | ||
| ) | ||
| parser.add_argument( | ||
| "--export_for", | ||
| type=str, | ||
| choices=["explicit", "training"], | ||
| default="explicit", | ||
| help="'explicit' preserves function_call message role, 'training' replaces it with gpt role for LLaMA Factory", | ||
| ) | ||
| args = parser.parse_args() | ||
| args.is_web = args.is_web == "yes" | ||
| for line in sys.stdin: | ||
| print(main_with_args(line, args.is_web, args.api_env)) | ||
| print(process_line(line, args.is_web, args.api_env, args.export_for)) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.