Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 118 additions & 17 deletions agents/openhands/std_to_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from schema.action.api import ApiAction
from schema.action.code import CodeAction
from schema.action.message import MessageAction
from schema.observation.image import ImageObservation
from schema.observation.text import TextObservation
from schema.observation.web import WebObservation
from schema.trajectory import Trajectory
Expand Down Expand Up @@ -91,7 +92,52 @@ def standardized_event_to_openhands_message(
else:
axtree = generate_axtree.last_xtree
prompt = get_web_user_message("", event.url, axtree, PREV_BID)
return {"from": "human", "value": prompt}

# Handle nested image observation
image_path = None
if hasattr(event, "image_observation") and event.image_observation:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused: where would this nested image observation come from? I didn't see it anywhere else in the code.

In general, "getattr" and "hasattr" are something of an anti-pattern in Python. They usually indicate that the code is not strictly adhering to its type definitions, and they can cause all kinds of tricky runtime errors. Let's try to write this without them.

image_path = event.image_observation.content

# Add visual observation section
prompt += "\n\n---\nVISUAL OBSERVATION:\n<image>"

# Add image annotations if present (using enhanced parsing)
if (
hasattr(event.image_observation, "annotations")
and event.image_observation.annotations
):
annotations = []
for annotation in event.image_observation.annotations:
# Build annotation description from available fields
parts = []
if hasattr(annotation, "text") and annotation.text:
parts.append(annotation.text)
elif (
hasattr(annotation, "content_description")
and annotation.content_description
):
parts.append(annotation.content_description)

# Add element type
if hasattr(annotation, "element_type"):
parts.append(f"({annotation.element_type})")

# Add interactivity info
attrs = []
if hasattr(annotation, "clickable") and annotation.clickable:
attrs.append("clickable")
if hasattr(annotation, "editable") and annotation.editable:
attrs.append("editable")
if attrs:
parts.append(f"[{', '.join(attrs)}]")

if parts:
annotations.append(" ".join(parts))

if annotations:
prompt += "\nElements detected: " + ", ".join(annotations)

return {"from": "human", "value": prompt, "_image_path": image_path}

if isinstance(event, ApiAction):
PREV_BID = None
Expand Down Expand Up @@ -133,10 +179,23 @@ def standardized_event_to_openhands_message(
event_xpath = event.kwargs.get("xpath", None)
if event_xpath:
browsergym_id = generate_axtree.get_bid(id, event_xpath, "all")

# Generate placeholder bid for web datasets when get_bid fails
if not browsergym_id and is_web:
event_xpath = event.kwargs.get("xpath", None)
if event_xpath:
# Use xpath hash as placeholder to maintain some consistency
placeholder_id = f"placeholder_bid_{abs(hash(event_xpath)) % 10000}"
browsergym_id = f'"{placeholder_id}"'
print(
f"Warning: Generated placeholder bid {browsergym_id} for xpath: {event_xpath}",
file=sys.stderr,
)

# for tool calls that are not browser based since there is no browsergym_id
# and tool calls that are specified as non-web
# these should all be dataset specific apis
if (not browsergym_id or not is_web) and function_name in api_sigs:
if not is_web and function_name in api_sigs:
if not api_env:
# Default to 'execute_ipython_cell' if api_env is not specified
api_env = "execute_ipython_cell"
Expand All @@ -151,7 +210,8 @@ def standardized_event_to_openhands_message(
return {"from": "function_call", "value": f"{thought}{function_call}"}

api_env = "browser"
if not browsergym_id[0] == browsergym_id[-1] == '"':

if browsergym_id and not browsergym_id[0] == browsergym_id[-1] == '"':
browsergym_id = f'"{browsergym_id[0]}"'
PREV_BID = browsergym_id
# for apis that are browser based but are not OH default browser apis
Expand Down Expand Up @@ -223,25 +283,49 @@ def standardized_event_to_openhands_message(
raise ValueError(f"Wrong event source: {event.source}")
return {"from": event.source, "value": event.content}

elif hasattr(event, "__class__") and event.__class__.__name__ == "ImageObservation":
elif isinstance(event, ImageObservation):
# Handle ImageObservation
annotations_text = ""
if hasattr(event, "annotations") and event.annotations:
annotations = []
for annotation in event.annotations:
# Build annotation description from available fields
parts = []
if hasattr(annotation, "text") and annotation.text:
annotations.append(f"{annotation.text} ({annotation.element_type})")
parts.append(annotation.text)
elif hasattr(annotation, "content_description") and annotation.content_description:
parts.append(annotation.content_description)

# Add element type
if hasattr(annotation, "element_type"):
parts.append(f"({annotation.element_type})")

# Add interactivity info
attrs = []
if hasattr(annotation, "clickable") and annotation.clickable:
attrs.append("clickable")
if hasattr(annotation, "editable") and annotation.editable:
attrs.append("editable")
if attrs:
parts.append(f"[{', '.join(attrs)}]")

if parts:
annotations.append(" ".join(parts))

if annotations:
annotations_text = "Elements detected: " + ", ".join(annotations)

image_path = getattr(event, "content", "unknown_image_path")
return {"from": "observation", "value": f"[Image: {image_path}]\n{annotations_text}"}
return {
"from": "observation",
"value": f"<image>{annotations_text}",
"_image_path": event.content,
}

else:
raise ValueError(f"Unknown event type: {type(event)}\n{event}")


def process_row(line, is_web, api_env, api_tool_description, api_sigs):
def process_row(line, is_web, api_env, api_tool_description, api_sigs, export_for="explicit"):
std_dataset = [json.loads(line)]
std_data = std_dataset[0]
trajectory = Trajectory(**std_data)
Expand All @@ -251,6 +335,7 @@ def process_row(line, is_web, api_env, api_tool_description, api_sigs):
conversations = []
previous_web_actions = []
languages = []
image_paths = []
for i in range(len(events)):
event = events[i]
try:
Expand All @@ -259,6 +344,13 @@ def process_row(line, is_web, api_env, api_tool_description, api_sigs):
)
if not message:
return None

# Extract image path if present
if "_image_path" in message:
path = message.pop("_image_path")
if path:
image_paths.append(path)

if len(conversations) == 0:
# append api function docs to first user message when available
if api_env:
Expand Down Expand Up @@ -290,18 +382,24 @@ def process_row(line, is_web, api_env, api_tool_description, api_sigs):
language_descriptions = get_language_descriptions(languages)
conversations[0]["value"] = language_descriptions + "\n\n" + conversations[0]["value"]
for m in conversations:
if m["from"] == "function_call":
if export_for == "training" and m["from"] == "function_call":
m["from"] = "gpt"
if m["from"] == "observation":
m["from"] = "human"
return {

output = {
"id": trajectory.id,
"conversations": conversations,
"system": get_system_message(),
}

if image_paths:
output["images"] = image_paths

return output


def process_line(line, is_web, api_env):
def process_line(line, is_web, api_env, export_for="explicit"):
exclude_apis = browser_default_apis if is_web else {}
api_tool_description, api_sigs = get_api_tool_description(dataset, exclude_apis, api_env)
output_line = process_row(
Expand All @@ -310,6 +408,7 @@ def process_line(line, is_web, api_env):
api_env=api_env,
api_tool_description=api_tool_description,
api_sigs=api_sigs,
export_for=export_for,
)
output_line = json.dumps(output_line)
# if output_line:
Expand All @@ -324,11 +423,6 @@ def process_line(line, is_web, api_env):
return output_line


# Keep the old main function for backward compatibility
def main_with_args(line, is_web, api_env):
    """Legacy entry point that forwards its arguments unchanged to ``process_line``."""
    converted = process_line(line=line, is_web=is_web, api_env=api_env)
    return converted


def main():
parser = argparse.ArgumentParser(description="Convert standardized data to SFT format")
parser.add_argument(
Expand All @@ -346,10 +440,17 @@ def main():
help="The environment in which the APIs are pre-defined",
default=None,
)
parser.add_argument(
"--export_for",
type=str,
choices=["explicit", "training"],
default="explicit",
help="'explicit' preserves function_call message role, 'training' replaces it with gpt role for LLaMA Factory",
)
args = parser.parse_args()
args.is_web = args.is_web == "yes"
for line in sys.stdin:
print(main_with_args(line, args.is_web, args.api_env))
print(process_line(line, args.is_web, args.api_env, args.export_for))


if __name__ == "__main__":
Expand Down
99 changes: 93 additions & 6 deletions datasets/android_in_the_wild/raw_to_standardized.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,40 @@
from schema.observation.observation import Observation
from schema.trajectory import Trajectory

# Constants from Android in the Wild action matching code
_SWIPE_DISTANCE_THRESHOLD = 0.04


def _is_tap(
    touch_yx: List[float],
    lift_yx: List[float],
    *,
    threshold: float = _SWIPE_DISTANCE_THRESHOLD,
) -> bool:
    """Check if a dual-point gesture is a tap (touch and lift are close together).

    Args:
        touch_yx: The (y, x) coordinates where the touch started
        lift_yx: The (y, x) coordinates where the touch lifted
        threshold: Largest touch-to-lift distance that still counts as a tap.
            Defaults to ``_SWIPE_DISTANCE_THRESHOLD``.

    Returns:
        True if the action is a tap, False if it's a swipe
    """
    delta_y = touch_yx[0] - lift_yx[0]
    delta_x = touch_yx[1] - lift_yx[1]
    # Straight-line (Euclidean) distance between the touch and lift points.
    distance = (delta_y**2 + delta_x**2) ** 0.5
    return distance <= threshold


def _point_in_bbox(point_yx: List[float], bbox: List[float]) -> bool:
    """Report whether a (y, x) point lies inside a bounding box, edges included.

    Args:
        point_yx: The (y, x) coordinates of the point (normalized 0-1)
        bbox: The bounding box as [y, x, height, width] (normalized 0-1)

    Returns:
        True if the point is within the bounding box
    """
    row, col = point_yx
    top, left, box_height, box_width = bbox

    # Both axes use inclusive bounds, so a point on an edge counts as inside.
    within_rows = top <= row <= top + box_height
    return within_rows and left <= col <= left + box_width


def process_episode(episode_data: List[Dict]) -> Dict:
"""Process a list of data for a single episode into a standardized trajectory.
Expand All @@ -28,6 +62,43 @@ def process_episode(episode_data: List[Dict]) -> Dict:
# Add the goal info as the first message
content.append(MessageAction(content=episode_data[0]["goal_info"]))

# Pass 1: Analyze actions to determine clickability and editability
# Structure: {step_id: {annotation_idx: {"clickable": bool, "editable": bool}}}
annotation_properties = {}

for idx, data in enumerate(episode_data):
step_id = data["step_id"]
annotation_properties[step_id] = {}

# Initialize properties for each annotation
num_annotations = len(data["image/ui_annotations_positions"])
for ann_idx in range(num_annotations):
annotation_properties[step_id][ann_idx] = {"clickable": False, "editable": False}

# Mark ICON_* elements as clickable
for ann_idx, ui_type in enumerate(data["image/ui_annotations_ui_types"]):
if ui_type.startswith("ICON_"):
annotation_properties[step_id][ann_idx]["clickable"] = True

# Check if current action is a tap and mark the tapped element as clickable
if data["results/action_type"] == "dual-point gesture":
touch_yx = data["results/yx_touch"]
lift_yx = data["results/yx_lift"]

if _is_tap(touch_yx, lift_yx):
# Find which annotation contains the tap point
for ann_idx, bbox in enumerate(data["image/ui_annotations_positions"]):
if _point_in_bbox(touch_yx, bbox):
annotation_properties[step_id][ann_idx]["clickable"] = True

# Check if next action is type, then mark as editable
if idx + 1 < len(episode_data):
next_data = episode_data[idx + 1]
if next_data["results/action_type"] == "type":
annotation_properties[step_id][ann_idx]["editable"] = True
break # Only mark the first matching bounding box

# Pass 2: Create trajectory content with enhanced annotations
for data in episode_data:
# Validating assumptions
if data["goal_info"] != content[0].content:
Expand All @@ -36,6 +107,7 @@ def process_episode(episode_data: List[Dict]) -> Dict:
f" but got: {data['goal_info']} != {content[0].content}"
)
# Create the image observation
step_id = data["step_id"]
annotations = [
ImageAnnotation(
text=text,
Expand All @@ -46,11 +118,15 @@ def process_episode(episode_data: List[Dict]) -> Dict:
width=pos[3],
height=pos[2],
),
clickable=annotation_properties[step_id][ann_idx]["clickable"],
editable=annotation_properties[step_id][ann_idx]["editable"],
)
for text, ui_type, pos in zip(
data["image/ui_annotations_text"],
data["image/ui_annotations_ui_types"],
data["image/ui_annotations_positions"],
for ann_idx, (text, ui_type, pos) in enumerate(
zip(
data["image/ui_annotations_text"],
data["image/ui_annotations_ui_types"],
data["image/ui_annotations_positions"],
)
)
]
content.append(
Expand Down Expand Up @@ -98,9 +174,20 @@ def process_episode(episode_data: List[Dict]) -> Dict:
current_episode_id = None
current_episode_data = []

for line in sys.stdin:
data = json.loads(line)
# Read all input first to detect format
input_text = sys.stdin.read()

# Try to parse as a JSON array first (for sample files)
try:
all_data = json.loads(input_text)
# If it's a single object, wrap it in a list
if isinstance(all_data, dict):
all_data = [all_data]
except json.JSONDecodeError:
# Fall back to newline-delimited JSON
all_data = [json.loads(line) for line in input_text.strip().split("\n") if line.strip()]

for data in all_data:
# If we encounter a new episode, process the previous one
if current_episode_id is not None and current_episode_id != data["episode_id"]:
# Process and output the current episode
Expand Down
Loading