Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions crates/hiroz-codegen/src/python_msgspec_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ pub fn generate_python_bindings(
.push(msg);
}

// Group services by package so we can emit rclpy-style grouping classes (P4).
let mut service_groups: HashMap<String, Vec<&ResolvedService>> = HashMap::new();
for srv in services {
service_groups
.entry(srv.parsed.package.clone())
.or_default()
.push(srv);
}

// Group service Request/Response by package, and track service type hashes
let mut service_messages: HashMap<String, Vec<&ResolvedMessage>> = HashMap::new();
let mut service_hashes: HashMap<String, HashMap<String, String>> = HashMap::new();
Expand Down Expand Up @@ -65,11 +74,16 @@ pub fn generate_python_bindings(
.get(package_name)
.cloned()
.unwrap_or_default();
let srv_groups = service_groups
.get(package_name)
.map(|v| v.as_slice())
.unwrap_or(&[]);
let python_code = generate_python_package_with_services(
package_name,
package_msgs,
srv_msgs,
&svc_hashes,
srv_groups,
)?;
let output_path = python_output_dir.join(format!("{}.py", package_name));
fs::write(output_path, python_code)?;
Expand All @@ -82,8 +96,17 @@ pub fn generate_python_bindings(
.get(package_name)
.cloned()
.unwrap_or_default();
let python_code =
generate_python_package_with_services(package_name, &[], srv_msgs, &svc_hashes)?;
let srv_groups = service_groups
.get(package_name)
.map(|v| v.as_slice())
.unwrap_or(&[]);
let python_code = generate_python_package_with_services(
package_name,
&[],
srv_msgs,
&svc_hashes,
srv_groups,
)?;
let output_path = python_output_dir.join(format!("{}.py", package_name));
fs::write(output_path, python_code)?;
}
Expand Down Expand Up @@ -113,6 +136,7 @@ fn generate_python_package_with_services(
messages: &[&ResolvedMessage],
service_messages: &[&ResolvedMessage],
service_hashes: &HashMap<String, String>,
service_groups: &[&ResolvedService],
) -> Result<String> {
let mut code = format!(
"\"\"\"Auto-generated ROS 2 message types for {}.\"\"\"\n\
Expand All @@ -132,9 +156,33 @@ fn generate_python_package_with_services(
code.push_str(&generate_msgspec_struct(msg, svc_hash.map(|s| s.as_str()))?);
}

// Emit rclpy-style service grouping classes (P4). These reference the
// Request/Response structs above, so they must come after them.
for srv in service_groups {
code.push_str(&generate_service_grouping_class(srv));
}

Ok(code)
}

/// Generate a service grouping class: `AddTwoInts.Request` / `.Response` (P4).
///
/// Lets `create_client`/`create_server` accept a single rclpy-style type
/// (`example_interfaces.AddTwoInts`) instead of the bare Request class.
fn generate_service_grouping_class(srv: &ResolvedService) -> String {
let srv_name = &srv.parsed.name;
let package = &srv.parsed.package;
let request_struct = &srv.request.parsed.name;
let response_struct = &srv.response.parsed.name;
format!(
"class {srv_name}:\n \
\"\"\"Service grouping type. Use {srv_name}.Request and {srv_name}.Response.\"\"\"\n \
__srvtype__: ClassVar[str] = '{package}/srv/{srv_name}'\n \
Request: ClassVar[type] = {request_struct}\n \
Response: ClassVar[type] = {response_struct}\n\n"
)
}

fn rust_to_python_type(field_type: &FieldType, current_package: &str) -> Result<String> {
// Get the base field type (without array indicators)
let base_type = &field_type.base_type;
Expand Down
18 changes: 12 additions & 6 deletions crates/hiroz-msgs/python/hiroz_msgs_py/types/action_msgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
import msgspec
from typing import ClassVar

class GoalStatusArray(msgspec.Struct, frozen=True, kw_only=True):
status_list: list["action_msgs.GoalStatus"] = msgspec.field(default_factory=list)

__msgtype__: ClassVar[str] = 'action_msgs/msg/GoalStatusArray'
__hash__: ClassVar[str] = 'RIHS01_6c1684b00f177d37438febe6e709fc4e2b0d4248dca4854946f9ed8b30cda83e'

class GoalInfo(msgspec.Struct, frozen=True, kw_only=True):
goal_id: "unique_identifier_msgs.UUID | None" = None
stamp: "builtin_interfaces.Time | None" = msgspec.field(default_factory=lambda: {'sec': 0, 'nanosec': 0})
Expand All @@ -16,12 +22,6 @@ class GoalStatus(msgspec.Struct, frozen=True, kw_only=True):
__msgtype__: ClassVar[str] = 'action_msgs/msg/GoalStatus'
__hash__: ClassVar[str] = 'RIHS01_32f4cfd717735d17657e1178f24431c1ce996c878c515230f6c5b3476819dbb9'

class GoalStatusArray(msgspec.Struct, frozen=True, kw_only=True):
status_list: list["action_msgs.GoalStatus"] = msgspec.field(default_factory=list)

__msgtype__: ClassVar[str] = 'action_msgs/msg/GoalStatusArray'
__hash__: ClassVar[str] = 'RIHS01_6c1684b00f177d37438febe6e709fc4e2b0d4248dca4854946f9ed8b30cda83e'

class CancelGoalRequest(msgspec.Struct, frozen=True, kw_only=True):
goal_info: "action_msgs.GoalInfo | None" = None

Expand All @@ -35,3 +35,9 @@ class CancelGoalResponse(msgspec.Struct, frozen=True, kw_only=True):
__msgtype__: ClassVar[str] = 'action_msgs/msg/CancelGoalResponse'
__hash__: ClassVar[str] = 'RIHS01_c66d49f351ea4375bf3eef8569e74b7afc19305d9fa94c71b412262e411f2a8f'

class CancelGoal:
"""Service grouping type. Use CancelGoal.Request and CancelGoal.Response."""
__srvtype__: ClassVar[str] = 'action_msgs/srv/CancelGoal'
Request: ClassVar[type] = CancelGoalRequest
Response: ClassVar[type] = CancelGoalResponse

Loading
Loading