diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9cb8099080..ea8769d946 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -108,3 +108,33 @@ repos: entry: python bin/hooks/filter_commit_message.py --check language: python stages: [pre-commit] + + - id: cargo-fmt + name: cargo fmt + entry: | + bash -c ' + set -euo pipefail + git ls-files "*Cargo.toml" | while read -r m; do + cargo locate-project --manifest-path "$m" --workspace --message-format plain + done | sort -u | while read -r root; do + cargo fmt --manifest-path "$root" --all --check + done + ' + language: system + types: [rust] + pass_filenames: false + + - id: cargo-clippy + name: cargo clippy + entry: | + bash -c ' + set -euo pipefail + git ls-files "*Cargo.toml" | while read -r m; do + cargo locate-project --manifest-path "$m" --workspace --message-format plain + done | sort -u | while read -r root; do + cargo clippy --manifest-path "$root" --workspace --all-targets -- -D warnings + done + ' + language: system + types: [rust] + pass_filenames: false diff --git a/examples/native-modules/rust/Cargo.lock b/examples/native-modules/rust/Cargo.lock index 420f9b0ef4..c50ab9564c 100644 --- a/examples/native-modules/rust/Cargo.lock +++ b/examples/native-modules/rust/Cargo.lock @@ -8,10 +8,16 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + [[package]] name = "dimos-lcm" version = "0.1.0" -source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#fd2e7e2d28597b34dce1d92d3065796a1b722590" +source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#e7c9428b7201cdfeadecd181c77c9e2d60a14503" dependencies = [ "byteorder", "socket2 
0.5.10", @@ -19,25 +25,45 @@ dependencies = [ ] [[package]] -name = "dimos-native-module" +name = "dimos-module" version = "0.1.0" dependencies = [ "dimos-lcm", + "dimos-module-macros", "serde", "serde_json", "tokio", ] +[[package]] +name = "dimos-module-macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "dimos-native-module-examples" version = "0.1.0" dependencies = [ - "dimos-native-module", + "dimos-module", "lcm-msgs", "serde", "tokio", ] +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "itoa" version = "1.0.18" @@ -47,7 +73,7 @@ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "lcm-msgs" version = "0.1.0" -source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#fd2e7e2d28597b34dce1d92d3065796a1b722590" +source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#e7c9428b7201cdfeadecd181c77c9e2d60a14503" dependencies = [ "byteorder", ] @@ -142,6 +168,16 @@ dependencies = [ "zmij", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + [[package]] name = "socket2" version = "0.5.10" @@ -179,9 +215,11 @@ version = "1.52.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" dependencies = [ + "bytes", "libc", "mio", "pin-project-lite", + "signal-hook-registry", "socket2 0.6.3", "tokio-macros", "windows-sys 0.61.2", diff --git a/examples/native-modules/rust/Cargo.toml b/examples/native-modules/rust/Cargo.toml index 
65b9a8244f..6add9dd8c4 100644 --- a/examples/native-modules/rust/Cargo.toml +++ b/examples/native-modules/rust/Cargo.toml @@ -12,7 +12,7 @@ name = "native_pong" path = "src/native_pong.rs" [dependencies] -dimos-native-module = { path = "../../../native/rust" } +dimos-module = { path = "../../../native/rust/dimos-module" } lcm-msgs = { git = "https://github.com/dimensionalOS/dimos-lcm.git", branch = "rust-codegen" } tokio = { version = "1", features = ["rt-multi-thread", "macros", "time"] } serde = { version = "1", features = ["derive"] } diff --git a/examples/native-modules/rust/src/native_ping.rs b/examples/native-modules/rust/src/native_ping.rs index ab19c78b28..d1551b050c 100644 --- a/examples/native-modules/rust/src/native_ping.rs +++ b/examples/native-modules/rust/src/native_ping.rs @@ -1,40 +1,55 @@ -// NativeModule ping example. -// -// Sends a Twist message at 5 Hz and logs each echo received on `confirm`. - -use dimos_native_module::{LcmTransport, NativeModule}; +use dimos_module::{run, Input, LcmTransport, Module, Output}; use lcm_msgs::geometry_msgs::{Twist, Vector3}; use tokio::time::{interval, Duration}; -#[tokio::main] -async fn main() { - let transport = LcmTransport::new() - .await - .expect("Failed to create transport"); - let (mut module, _config) = NativeModule::from_stdin::<()>(transport) - .await - .expect("Failed to read config from stdin"); - - let mut confirm = module.input("confirm", Twist::decode); - let data = module.output("data", Twist::encode); - let _handle = module.spawn(); +#[derive(Module)] +#[module(setup = start_publisher)] +struct Ping { + #[input(decode = Twist::decode)] + confirm: Input, - let mut ticker = interval(Duration::from_millis(200)); - let mut seq = 0u64; + #[output(encode = Twist::encode)] + data: Output, +} - loop { - tokio::select! 
{ - _ = ticker.tick() => { +impl Ping { + async fn start_publisher(&mut self) { + let data = self.data.clone(); + tokio::spawn(async move { + let mut ticker = interval(Duration::from_millis(200)); + let mut seq = 0u64; + loop { + ticker.tick().await; let msg = Twist { - linear: Vector3 { x: seq as f64, y: 0.0, z: 0.0 }, - angular: Vector3 { x: 0.0, y: 0.0, z: 0.0 }, + linear: Vector3 { + x: seq as f64, + y: 0.0, + z: 0.0, + }, + angular: Vector3 { + x: 0.0, + y: 0.0, + z: 0.0, + }, }; data.publish(&msg).await.ok(); seq += 1; } - Some(echo) = confirm.recv() => { - eprintln!("ping: echo received (seq={}, sample_config={})", echo.linear.x as u64, echo.angular.z as i64); - } - } + }); } + + async fn handle_confirm(&mut self, echo: Twist) { + println!( + "ping: echo received (seq={}, sample_config={})", + echo.linear.x as u64, echo.angular.z as i64, + ); + } +} + +#[tokio::main] +async fn main() { + let transport = LcmTransport::new() + .await + .expect("Failed to create transport"); + run::(transport).await.expect("ping run failed"); } diff --git a/examples/native-modules/rust/src/native_pong.rs b/examples/native-modules/rust/src/native_pong.rs index 74109eb4ba..36f33dc44f 100644 --- a/examples/native-modules/rust/src/native_pong.rs +++ b/examples/native-modules/rust/src/native_pong.rs @@ -1,9 +1,4 @@ -// NativeModule pong example. -// -// Receives Twist messages on `data` and echoes each one back on `confirm`, -// embedding the sample_config value in the reply's angular.z field. 
- -use dimos_native_module::{LcmTransport, NativeModule}; +use dimos_module::{run, Input, LcmTransport, Module, Output}; use lcm_msgs::geometry_msgs::{Twist, Vector3}; use serde::Deserialize; @@ -13,37 +8,36 @@ struct PongConfig { sample_config: i64, } +#[derive(Module)] +struct Pong { + #[input(decode = Twist::decode)] + data: Input, + + #[output(encode = Twist::encode)] + confirm: Output, + + #[config] + config: PongConfig, +} + +impl Pong { + async fn handle_data(&mut self, msg: Twist) { + let reply = Twist { + linear: msg.linear, + angular: Vector3 { + x: 0.0, + y: 0.0, + z: self.config.sample_config as f64, + }, + }; + self.confirm.publish(&reply).await.ok(); + } +} + #[tokio::main] async fn main() { let transport = LcmTransport::new() .await .expect("Failed to create transport"); - let (mut module, config) = NativeModule::from_stdin::(transport) - .await - .expect("Failed to read config from stdin"); - - eprintln!("pong: sample_config={}", config.sample_config); - - let mut data = module.input("data", Twist::decode); - let confirm = module.output("confirm", Twist::encode); - let _handle = module.spawn(); - - eprintln!("pong ready"); - - loop { - match data.recv().await { - Some(msg) => { - let reply = Twist { - linear: msg.linear, - angular: Vector3 { - x: 0.0, - y: 0.0, - z: config.sample_config as f64, - }, - }; - confirm.publish(&reply).await.ok(); - } - None => break, - } - } + run::(transport).await.expect("pong run failed"); } diff --git a/native/rust/.gitignore b/native/rust/.gitignore index 2f7896d1d1..eccd7b4ab8 100644 --- a/native/rust/.gitignore +++ b/native/rust/.gitignore @@ -1 +1,2 @@ -target/ +/target/ +**/*.rs.bk diff --git a/native/rust/Cargo.lock b/native/rust/Cargo.lock index 45982487ec..63f7b42fd9 100644 --- a/native/rust/Cargo.lock +++ b/native/rust/Cargo.lock @@ -8,10 +8,16 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" 
+[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + [[package]] name = "dimos-lcm" version = "0.1.0" -source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#50538b1372d6e06fdb0399abc6a35c2aa650a72f" +source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#e7c9428b7201cdfeadecd181c77c9e2d60a14503" dependencies = [ "byteorder", "socket2 0.5.10", @@ -19,16 +25,36 @@ dependencies = [ ] [[package]] -name = "dimos-native-module" +name = "dimos-module" version = "0.1.0" dependencies = [ "dimos-lcm", + "dimos-module-macros", "lcm-msgs", "serde", "serde_json", "tokio", ] +[[package]] +name = "dimos-module-macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "itoa" version = "1.0.18" @@ -38,16 +64,16 @@ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "lcm-msgs" version = "0.1.0" -source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#50538b1372d6e06fdb0399abc6a35c2aa650a72f" +source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#e7c9428b7201cdfeadecd181c77c9e2d60a14503" dependencies = [ "byteorder", ] [[package]] name = "libc" -version = "0.2.185" +version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" [[package]] name = "memchr" @@ -133,6 +159,16 @@ dependencies = [ "zmij", ] +[[package]] +name = 
"signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + [[package]] name = "socket2" version = "0.5.10" @@ -166,13 +202,15 @@ dependencies = [ [[package]] name = "tokio" -version = "1.52.0" +version = "1.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a91135f59b1cbf38c91e73cf3386fca9bb77915c45ce2771460c9d92f0f3d776" +checksum = "8fc7f01b389ac15039e4dc9531aa973a135d7a4135281b12d7c1bc79fd57fffe" dependencies = [ + "bytes", "libc", "mio", "pin-project-lite", + "signal-hook-registry", "socket2 0.6.3", "tokio-macros", "windows-sys 0.61.2", diff --git a/native/rust/Cargo.toml b/native/rust/Cargo.toml index e3e24a6ad3..13f4fc83db 100644 --- a/native/rust/Cargo.toml +++ b/native/rust/Cargo.toml @@ -1,15 +1,3 @@ -[package] -name = "dimos-native-module" -version = "0.1.0" -edition = "2021" -description = "Rust native module SDK for dimos NativeModule framework" -license = "Apache-2.0" - -[dependencies] -dimos-lcm = { git = "https://github.com/dimensionalOS/dimos-lcm.git", branch = "rust-codegen" } -tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "time"] } -serde = { version = "1", features = ["derive"] } -serde_json = "1" - -[dev-dependencies] -lcm-msgs = { git = "https://github.com/dimensionalOS/dimos-lcm.git", branch = "rust-codegen" } +[workspace] +members = ["dimos-module", "dimos-module-macros"] +resolver = "2" diff --git a/native/rust/README.md b/native/rust/README.md new file mode 100644 index 0000000000..cd85b0d1d4 --- /dev/null +++ b/native/rust/README.md @@ -0,0 +1,100 @@ +# DimOS Rust module SDK + +Two crates: + +- **`dimos-module`**: runtime. `Module` trait, `Builder`, `Input`/`Output`, `Transport`/`LcmTransport`, `run()`. +- **`dimos-module-macros`**: `#[derive(Module)]` proc-macro. 
+ +## Writing a module + +```rust +use dimos_module::{run, Input, LcmTransport, Module, Output}; +use lcm_msgs::geometry_msgs::Twist; +use serde::Deserialize; + +#[derive(Debug, Deserialize, Default)] +struct MyConfig { threshold: f64 } + +#[derive(Module)] +#[module(setup = on_start, teardown = on_stop)] +struct MyModule { + #[input(decode = Twist::decode)] + cmd: Input<Twist>, + + #[output(encode = Twist::encode)] + out: Output<Twist>, + + #[config] + config: MyConfig, +} + +impl MyModule { + // initialization or publisher setup + async fn on_start(&mut self) { /* ... */ } + + // processing function expected by cmd: Input<Twist> + async fn handle_cmd(&mut self, msg: Twist) { /* ... */ } + + // teardown / clean up logic + async fn on_stop(&mut self) { /* ... */ } +} + +#[tokio::main] +async fn main() { + let transport = LcmTransport::new().await.unwrap(); + run::<MyModule, _>(transport).await.unwrap(); +} +``` + +## Attributes + +- `#[derive(Module)]`: on the struct. Required. +- `#[module(setup = fn, teardown = fn)]`: on the struct. Both optional. Each names a method on `Self`. `setup` runs once before the input dispatch loop starts (use it to spawn background tasks or initialize resources); `teardown` runs once after the loop exits (use it for cleanup). +- `#[input(decode = fn, handler = fn)]`: on a field of type `Input<T>`. `decode` is required; `handler` defaults to `handle_<field>` (e.g. `handle_cmd` for the field `cmd`). +- `#[output(encode = fn)]`: on a field of type `Output<T>`. `encode` is required. +- `#[config]`: on one field of any `Deserialize` type. At most one per struct. If absent, `Config = ()`. +- Unattributed fields are initialized via `Default::default()` and treated as module state. + +The field name is the port name. Ports map to topics via the stdin JSON; unmapped ports fall back to `/{port}`. 
+ +## What `#[derive(Module)]` generates + +For reference, the macro expands the example above to: + +```rust ignore +impl ::dimos_module::Module for MyModule { + type Config = MyConfig; + + fn build(builder: &mut ::dimos_module::Builder, config: Self::Config) -> Self { + Self { + cmd: builder.input("cmd", Twist::decode), + out: builder.output("out", Twist::encode), + config, + } + } + + async fn setup(&mut self) { self.on_start().await } + async fn teardown(&mut self) { self.on_stop().await } + + async fn handle(&mut self) { + loop { + // wait for whichever input has a message available and run its handler + tokio::select! { + Some(msg) = self.cmd.recv() => self.handle_cmd(msg).await, + else => break, + } + } + } +} +``` + +`builder.input` registers a route from the resolved topic into an mpsc channel that backs `Input<T>`. `builder.output` hands back an `Output<T>` carrying a sender into the shared publish channel. + +## Lifecycle inside `run()` + +1. Read one JSON line from stdin, parse into `(topics, config)`. +2. `M::build(&mut builder, config)`: macro-generated, populates each field. +3. Spawn two tokio tasks: one drives `transport.recv()` and dispatches to input channels; one drains the publish channel into `transport.publish()`. The two run independently so a slow publish can't block recv. +4. `module.setup().await`. +5. `module.handle().await`, racing ctrl-c and the two background tasks; whichever finishes first ends the dispatch loop. +6. `module.teardown().await`. It always runs; if a background task panicked, the panic is re-raised afterwards. 
diff --git a/native/rust/dimos-module-macros/Cargo.toml b/native/rust/dimos-module-macros/Cargo.toml new file mode 100644 index 0000000000..c9644dd5ff --- /dev/null +++ b/native/rust/dimos-module-macros/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "dimos-module-macros" +version = "0.1.0" +edition = "2021" +description = "Proc-macros for dimos-module" +license = "Apache-2.0" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "2", features = ["full"] } +quote = "1" +proc-macro2 = "1" diff --git a/native/rust/dimos-module-macros/src/lib.rs b/native/rust/dimos-module-macros/src/lib.rs new file mode 100644 index 0000000000..184e39c75e --- /dev/null +++ b/native/rust/dimos-module-macros/src/lib.rs @@ -0,0 +1,251 @@ +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, Data, DeriveInput, Field, Fields, Ident, Path, Type}; + +#[proc_macro_derive(Module, attributes(input, output, config, module))] +pub fn derive_module(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + match expand(input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + +enum FieldKind { + Input { decode: Path, handler: Ident }, + Output { encode: Path }, + Config, + State, +} + +struct ClassifiedField<'a> { + name: &'a Ident, + ty: &'a Type, + kind: FieldKind, +} + +fn expand(input: DeriveInput) -> syn::Result { + let struct_name = &input.ident; + + let fields = match &input.data { + Data::Struct(s) => match &s.fields { + Fields::Named(named) => &named.named, + _ => { + return Err(syn::Error::new_spanned( + &input, + "Module requires a struct with named fields", + )) + } + }, + _ => { + return Err(syn::Error::new_spanned( + &input, + "Module can only be derived for structs", + )) + } + }; + + let mut setup_method: Option = None; + let mut teardown_method: Option = None; + for attr in &input.attrs { + if attr.path().is_ident("module") { 
+ attr.parse_nested_meta(|meta| { + if meta.path.is_ident("setup") { + setup_method = Some(meta.value()?.parse()?); + } else if meta.path.is_ident("teardown") { + teardown_method = Some(meta.value()?.parse()?); + } else { + return Err(meta.error( + "unrecognized #[module] argument; expected `setup = ...` or `teardown = ...`", + )); + } + Ok(()) + })?; + } + } + + let mut classified: Vec = Vec::new(); + let mut config_seen: Option<&Ident> = None; + + for field in fields { + let name = field.ident.as_ref().expect("named field has an identifier"); + let kind = classify_field(field, name)?; + if matches!(kind, FieldKind::Config) { + if let Some(prev) = config_seen { + return Err(syn::Error::new_spanned( + field, + format!( + "multiple #[config] fields (previous: `{prev}`); at most one is allowed" + ), + )); + } + config_seen = Some(name); + } + classified.push(ClassifiedField { + name, + ty: &field.ty, + kind, + }); + } + + let config_type: Type = classified + .iter() + .find_map(|f| matches!(f.kind, FieldKind::Config).then(|| f.ty.clone())) + .unwrap_or_else(|| syn::parse_quote!(())); + + let config_param: TokenStream2 = if config_seen.is_some() { + quote!(config) + } else { + quote!(_config) + }; + + let build_field_inits = classified.iter().map(|f| { + let name = f.name; + let name_str = name.to_string(); + match &f.kind { + FieldKind::Input { decode, .. } => { + quote!(#name: builder.input(#name_str, #decode)) + } + FieldKind::Output { encode } => { + quote!(#name: builder.output(#name_str, #encode)) + } + FieldKind::Config => quote!(#name: config), + FieldKind::State => quote!(#name: ::core::default::Default::default()), + } + }); + + let input_fields: Vec<&ClassifiedField> = classified + .iter() + .filter(|f| matches!(f.kind, FieldKind::Input { .. })) + .collect(); + + let handle_body = if input_fields.is_empty() { + quote!(::std::future::pending::<()>().await) + } else { + let handle_arms = input_fields.iter().map(|f| { + let FieldKind::Input { handler, .. 
} = &f.kind else { + unreachable!() + }; + let name = f.name; + quote!( + ::core::option::Option::Some(msg) = self.#name.recv() => { + self.#handler(msg).await + } + ) + }); + quote! { + loop { + ::tokio::select! { + #(#handle_arms,)* + else => break, + } + } + } + }; + + let setup_impl = setup_method.map(|m| { + quote! { + async fn setup(&mut self) { + self.#m().await + } + } + }); + + let teardown_impl = teardown_method.map(|m| { + quote! { + async fn teardown(&mut self) { + self.#m().await + } + } + }); + + Ok(quote! { + impl ::dimos_module::Module for #struct_name { + type Config = #config_type; + + fn build( + builder: &mut ::dimos_module::Builder, + #config_param: ::Config, + ) -> Self { + Self { + #(#build_field_inits,)* + } + } + + #setup_impl + + async fn handle(&mut self) { + #handle_body + } + + #teardown_impl + } + }) +} + +fn classify_field(field: &Field, name: &Ident) -> syn::Result { + let mut found: Option = None; + + for attr in &field.attrs { + let path = attr.path(); + if path.is_ident("input") { + if found.is_some() { + return Err(syn::Error::new_spanned( + attr, + "field has multiple module attributes; only one of #[input], #[output], #[config] is allowed", + )); + } + let mut decode: Option = None; + let mut handler: Option = None; + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("decode") { + decode = Some(meta.value()?.parse()?); + } else if meta.path.is_ident("handler") { + handler = Some(meta.value()?.parse()?); + } else { + return Err(meta.error( + "unrecognized #[input] argument; expected `decode = ...` or `handler = ...`", + )); + } + Ok(()) + })?; + let decode = decode + .ok_or_else(|| syn::Error::new_spanned(attr, "#[input] requires `decode = ...`"))?; + let handler = handler.unwrap_or_else(|| format_ident!("handle_{}", name)); + found = Some(FieldKind::Input { decode, handler }); + } else if path.is_ident("output") { + if found.is_some() { + return Err(syn::Error::new_spanned( + attr, + "field has multiple module attributes; 
only one of #[input], #[output], #[config] is allowed", + )); + } + let mut encode: Option = None; + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("encode") { + encode = Some(meta.value()?.parse()?); + } else { + return Err( + meta.error("unrecognized #[output] argument; expected `encode = ...`") + ); + } + Ok(()) + })?; + let encode = encode.ok_or_else(|| { + syn::Error::new_spanned(attr, "#[output] requires `encode = ...`") + })?; + found = Some(FieldKind::Output { encode }); + } else if path.is_ident("config") { + if found.is_some() { + return Err(syn::Error::new_spanned( + attr, + "field has multiple module attributes; only one of #[input], #[output], #[config] is allowed", + )); + } + found = Some(FieldKind::Config); + } + } + + Ok(found.unwrap_or(FieldKind::State)) +} diff --git a/native/rust/dimos-module/Cargo.toml b/native/rust/dimos-module/Cargo.toml new file mode 100644 index 0000000000..f08e76b1a7 --- /dev/null +++ b/native/rust/dimos-module/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "dimos-module" +version = "0.1.0" +edition = "2021" +description = "Rust native module SDK for dimos NativeModule framework" +license = "Apache-2.0" + +[dependencies] +dimos-lcm = { git = "https://github.com/dimensionalOS/dimos-lcm.git", branch = "rust-codegen" } +dimos-module-macros = { version = "=0.1.0", path = "../dimos-module-macros" } +tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "time", "signal", "io-std", "io-util"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +[dev-dependencies] +lcm-msgs = { git = "https://github.com/dimensionalOS/dimos-lcm.git", branch = "rust-codegen" } diff --git a/native/rust/src/lcm.rs b/native/rust/dimos-module/src/lcm.rs similarity index 91% rename from native/rust/src/lcm.rs rename to native/rust/dimos-module/src/lcm.rs index a4fbd027f4..b903b6f5a9 100644 --- a/native/rust/src/lcm.rs +++ b/native/rust/dimos-module/src/lcm.rs @@ -22,7 +22,7 @@ impl Transport for 
LcmTransport { self.0.publish(channel, data).await } - async fn recv(&mut self) -> io::Result<(String, Vec)> { + async fn recv(&self) -> io::Result<(String, Vec)> { let msg = self.0.recv().await?; Ok((msg.channel, msg.data)) } diff --git a/native/rust/src/lib.rs b/native/rust/dimos-module/src/lib.rs similarity index 70% rename from native/rust/src/lib.rs rename to native/rust/dimos-module/src/lib.rs index d98866417f..540dd72b42 100644 --- a/native/rust/src/lib.rs +++ b/native/rust/dimos-module/src/lib.rs @@ -2,8 +2,9 @@ pub mod lcm; pub mod module; pub mod transport; +pub use dimos_module_macros::Module; pub use lcm::LcmTransport; -pub use module::{Input, NativeModule, NativeModuleHandle, Output}; +pub use module::{run, Builder, Input, Module, Output}; pub use transport::Transport; // Re-export LcmOptions so callers don't need to depend on dimos-lcm directly. diff --git a/native/rust/dimos-module/src/module.rs b/native/rust/dimos-module/src/module.rs new file mode 100644 index 0000000000..9746aec5a1 --- /dev/null +++ b/native/rust/dimos-module/src/module.rs @@ -0,0 +1,567 @@ +use std::collections::HashMap; +use std::fmt::Debug; +use std::io; +use std::sync::Arc; +use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio::sync::mpsc; + +use serde::de::DeserializeOwned; + +use crate::transport::Transport; + +const INPUT_CHANNEL_CAPACITY: usize = 16; +const PUBLISH_CHANNEL_CAPACITY: usize = 64; + +// Each input() call produces a TypedRoute that decodes its message type +// and forwards it to the right Input's mpsc channel. +pub(crate) trait Route: Send { + fn try_dispatch(&self, data: &[u8]); +} + +struct TypedRoute { + topic: String, + decode: fn(&[u8]) -> io::Result, + sender: mpsc::Sender, +} + +impl Route for TypedRoute { + fn try_dispatch(&self, data: &[u8]) { + match (self.decode)(data) { + // If the input channel is full, the newest message is dropped. 
+ Ok(msg) => { + let _ = self.sender.try_send(msg); + } + Err(e) => eprintln!("dimos_module: decode error on {}: {e}", self.topic), + } + } +} +pub struct Input { + pub topic: String, + receiver: mpsc::Receiver, +} + +impl Input { + pub async fn recv(&mut self) -> Option { + self.receiver.recv().await + } +} + +#[derive(Clone)] +pub struct Output { + pub topic: String, + encode: fn(&T) -> Vec, + sender: mpsc::Sender<(String, Vec)>, +} + +impl Output { + pub async fn publish(&self, msg: &T) -> io::Result<()> { + let data = (self.encode)(msg); + self.sender + .send((self.topic.clone(), data)) + .await + .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "background task gone")) + } +} + +/// Parse a JSON config line as written by the Python NativeModule coordinator. +/// Returns `(topics, config)`. Extracted so it can be unit-tested without stdin. +fn parse_config_json(line: &str) -> io::Result<(HashMap, C)> { + let json: serde_json::Value = serde_json::from_str(line.trim()) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + let mut topics = HashMap::new(); + if let Some(t) = json.get("topics").and_then(|v| v.as_object()) { + for (port, topic) in t { + if let Some(s) = topic.as_str() { + topics.insert(port.clone(), s.to_string()); + } + } + } + + let config: C = match json.get("config") { + None => return Err(io::Error::new( + io::ErrorKind::InvalidData, + "missing 'config' field in stdin JSON — coordinator must always send a config object", + )), + Some(v) => serde_json::from_value(v.clone()).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to deserialize config: {e}"), + ) + })?, + }; + + Ok((topics, config)) +} + +pub trait Module: Sized + Send + 'static { + type Config: DeserializeOwned + Debug; + + fn build(builder: &mut Builder, config: Self::Config) -> Self; + + fn setup(&mut self) -> impl std::future::Future + Send { + async {} + } + + fn handle(&mut self) -> impl std::future::Future + Send; + + fn 
teardown(&mut self) -> impl std::future::Future + Send { + async {} + } +} + +pub struct Builder { + topics: HashMap, + routes: HashMap>>, + publish_tx: mpsc::Sender<(String, Vec)>, +} + +impl Builder { + pub(crate) fn new( + topics: HashMap, + publish_tx: mpsc::Sender<(String, Vec)>, + ) -> Self { + Self { + topics, + routes: HashMap::new(), + publish_tx, + } + } + + fn topic_for(&self, port: &str) -> String { + self.topics + .get(port) + .cloned() + .unwrap_or_else(|| format!("/{port}")) + } + + pub fn input( + &mut self, + port: &str, + decode: fn(&[u8]) -> io::Result, + ) -> Input { + let topic = self.topic_for(port); + let (tx, rx) = mpsc::channel(INPUT_CHANNEL_CAPACITY); + self.routes + .entry(topic.clone()) + .or_default() + .push(Box::new(TypedRoute { + topic: topic.clone(), + decode, + sender: tx, + })); + Input { + topic, + receiver: rx, + } + } + + pub fn output(&self, port: &str, encode: fn(&T) -> Vec) -> Output { + Output { + topic: self.topic_for(port), + encode, + sender: self.publish_tx.clone(), + } + } +} + +pub(crate) fn spawn_pubsub_tasks( + transport: T, + routes: HashMap>>, + mut publish_rx: mpsc::Receiver<(String, Vec)>, +) -> (tokio::task::JoinHandle<()>, tokio::task::JoinHandle<()>) { + let transport = Arc::new(transport); + + let recv_transport = Arc::clone(&transport); + let recv_handle = tokio::spawn(async move { + loop { + match recv_transport.recv().await { + Ok((channel, data)) => { + if let Some(rs) = routes.get(&channel) { + for route in rs { + route.try_dispatch(&data); + } + } + } + Err(e) => eprintln!("dimos_module: recv error: {e}"), + } + } + }); + + let pub_transport = Arc::clone(&transport); + let pub_handle = tokio::spawn(async move { + while let Some((topic, data)) = publish_rx.recv().await { + if let Err(e) = pub_transport.publish(&topic, &data).await { + eprintln!("dimos_module: publish error on {topic}: {e}"); + } + } + }); + + (recv_handle, pub_handle) +} + +fn propagate_task_failure(name: &str, res: Result<(), 
tokio::task::JoinError>) { + match res { + Ok(()) => eprintln!("dimos_module: {name} task exited unexpectedly"), + Err(e) => { + eprintln!("dimos_module: {name} task panicked, propagating"); + std::panic::resume_unwind(e.into_panic()); + } + } +} + +pub async fn run(transport: T) -> io::Result<()> +where + M: Module, + T: Transport, +{ + let mut line = String::new(); + BufReader::new(tokio::io::stdin()) + .read_line(&mut line) + .await?; + let (topics, config) = parse_config_json::(&line)?; + + let exe = std::env::current_exe() + .ok() + .and_then(|p| p.file_name().map(|n| n.to_string_lossy().into_owned())) + .unwrap_or_else(|| "unknown".to_string()); + eprintln!("[{exe}] topics received:"); + for (port, topic) in &topics { + eprintln!(" {port} -> {topic}"); + } + eprintln!("[{exe}] config: {config:?}"); + + let (publish_tx, publish_rx) = mpsc::channel::<(String, Vec)>(PUBLISH_CHANNEL_CAPACITY); + let mut builder = Builder::new(topics, publish_tx); + let mut module = M::build(&mut builder, config); + let (mut recv_handle, mut pub_handle) = + spawn_pubsub_tasks(transport, builder.routes, publish_rx); + + module.setup().await; + + // record whatever resolves first, then teardown unconditionally + let failure = tokio::select! { + _ = module.handle() => None, + _ = tokio::signal::ctrl_c() => None, + res = &mut recv_handle => Some(("recv", res)), + res = &mut pub_handle => Some(("publish", res)), + }; + + module.teardown().await; + + // if the result was an error, handle it here + if let Some((name, res)) = failure { + propagate_task_failure(name, res); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::Deserialize; + use std::collections::VecDeque; + use std::sync::atomic::{AtomicU64, Ordering}; + use std::sync::{Arc, Mutex}; + use std::time::{Duration, Instant}; + use tokio::sync::Notify; + + type InboundQueue = Mutex)>>; + + /// Mock transport for testing message timing. 
+ /// + /// Lets us test for concurrency and blocking when handling different messages. + struct ControllableMockTransport { + inbound: Arc, + inbound_notify: Arc, + publish_delay_ms: Arc, + publish_entered: Arc, + recv_returned: Arc, + recv_log: Arc>>, + publish_log: Arc>>, + } + + impl ControllableMockTransport { + fn new() -> Self { + Self { + inbound: Arc::new(InboundQueue::new(VecDeque::new())), + inbound_notify: Arc::new(Notify::new()), + publish_delay_ms: Arc::new(AtomicU64::new(0)), + publish_entered: Arc::new(Notify::new()), + recv_returned: Arc::new(Notify::new()), + recv_log: Arc::new(Mutex::new(Vec::new())), + publish_log: Arc::new(Mutex::new(Vec::new())), + } + } + } + + impl crate::transport::Transport for ControllableMockTransport { + async fn publish(&self, _channel: &str, _data: &[u8]) -> io::Result<()> { + self.publish_entered.notify_one(); + let delay = self.publish_delay_ms.load(Ordering::Relaxed); + if delay > 0 { + tokio::time::sleep(Duration::from_millis(delay)).await; + } + self.publish_log.lock().unwrap().push(Instant::now()); + Ok(()) + } + + async fn recv(&self) -> io::Result<(String, Vec)> { + loop { + let popped = self.inbound.lock().unwrap().pop_front(); + if let Some(msg) = popped { + self.recv_log.lock().unwrap().push(Instant::now()); + self.recv_returned.notify_one(); + return Ok(msg); + } + self.inbound_notify.notified().await; + } + } + } + + fn inject_inbound(inbound: &InboundQueue, notify: &Notify, channel: &str, data: Vec) { + inbound + .lock() + .unwrap() + .push_back((channel.to_string(), data)); + notify.notify_one(); + } + + #[derive(Debug, Deserialize, Default, PartialEq)] + #[serde(deny_unknown_fields)] + struct TestConfig { + value: i64, + name: String, + } + + // parse_config_json + #[test] + fn parses_topics_and_config() { + let json = r#"{"topics": {"data": "/foo/data", "confirm": "/foo/confirm"}, "config": {"value": 42, "name": "hello"}}"#; + let (topics, config) = parse_config_json::(json).unwrap(); + 
assert_eq!(topics["data"], "/foo/data"); + assert_eq!(topics["confirm"], "/foo/confirm"); + assert_eq!( + config, + TestConfig { + value: 42, + name: "hello".into() + } + ); + } + + #[test] + fn missing_config_field_returns_error() { + let json = r#"{"topics": {"data": "/foo/data"}}"#; + let result = parse_config_json::(json); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("missing 'config' field")); + } + + #[test] + fn null_config_succeeds_for_unit_type() { + let json = r#"{"topics": {}, "config": null}"#; + let (_topics, _config) = parse_config_json::<()>(json).unwrap(); + } + + #[test] + fn null_config_errors_when_struct_expects_fields() { + let json = r#"{"topics": {}, "config": null}"#; + let result = parse_config_json::(json); + assert!(result.is_err()); + } + + #[test] + fn empty_config_object_errors_when_struct_expects_fields() { + let json = r#"{"topics": {}, "config": {}}"#; + let result = parse_config_json::(json); + assert!(result.is_err()); + } + + #[test] + fn config_with_wrong_type_returns_error() { + let json = r#"{"topics": {}, "config": {"value": "not_a_number", "name": "x"}}"#; + let result = parse_config_json::(json); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("failed to deserialize config")); + } + + #[test] + fn missing_topics_field_gives_empty_map() { + let json = r#"{"config": {"value": 1, "name": "x"}}"#; + let (topics, _config) = parse_config_json::(json).unwrap(); + assert!(topics.is_empty()); + } + + #[test] + fn malformed_json_returns_error() { + let result = parse_config_json::<()>("not json at all"); + assert!(result.is_err()); + } + + #[test] + fn unknown_config_field_returns_error() { + let json = r#"{"topics": {}, "config": {"value": 1, "name": "x", "unexpected": true}}"#; + let result = parse_config_json::(json); + assert!(result.is_err()); + } + + // topic_for fallback + + fn topics(pairs: &[(&str, &str)]) -> HashMap { + pairs + .iter() + 
.map(|(p, t)| (p.to_string(), t.to_string())) + .collect() + } + + fn builder_with_topics(pairs: &[(&str, &str)]) -> Builder { + let (publish_tx, _) = mpsc::channel(PUBLISH_CHANNEL_CAPACITY); + Builder::new(topics(pairs), publish_tx) + } + + #[test] + fn unmapped_port_falls_back_to_slash_port() { + let builder = builder_with_topics(&[]); + assert_eq!(builder.topic_for("cmd_vel"), "/cmd_vel"); + } + + #[test] + fn mapped_port_uses_given_topic() { + let builder = builder_with_topics(&[("cmd_vel", "/robot/cmd_vel")]); + assert_eq!(builder.topic_for("cmd_vel"), "/robot/cmd_vel"); + } + + #[test] + fn input_uses_mapped_topic() { + let mut builder = builder_with_topics(&[("data", "/test/data")]); + let input = builder.input("data", |b| Ok(b.to_vec())); + assert_eq!(input.topic, "/test/data"); + } + + #[test] + fn input_falls_back_to_slash_port_when_unmapped() { + let mut builder = builder_with_topics(&[]); + let input = builder.input("data", |b| Ok(b.to_vec())); + assert_eq!(input.topic, "/data"); + } + + #[test] + fn output_uses_mapped_topic() { + let builder = builder_with_topics(&[("cmd_vel", "/robot/cmd_vel")]); + let output = builder.output("cmd_vel", |b: &Vec| b.clone()); + assert_eq!(output.topic, "/robot/cmd_vel"); + } + + // recv/publish concurrency + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn slow_publish_does_not_block_recv() { + let transport = ControllableMockTransport::new(); + let recv_log = transport.recv_log.clone(); + let inbound = transport.inbound.clone(); + let inbound_notify = transport.inbound_notify.clone(); + let publish_delay_ms = transport.publish_delay_ms.clone(); + let publish_entered = transport.publish_entered.clone(); + + // set publishing to take 200ms + publish_delay_ms.store(200, Ordering::Relaxed); + + let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_CAPACITY); + let mut builder = Builder::new(topics(&[("data", "/data"), ("out", "/out")]), publish_tx); + let _input = builder.input("data", |b| 
Ok(b.to_vec())); + let output = builder.output("out", |b: &Vec| b.clone()); + spawn_pubsub_tasks(transport, builder.routes, publish_rx); + + // start the 200ms publish + output.publish(&vec![0u8]).await.ok(); + + // ensure the publish starts getting handled before the receive + tokio::time::timeout(Duration::from_secs(1), publish_entered.notified()) + .await + .expect("dispatch task should pick up publish_rx within 1s"); + + inject_inbound(&inbound, &inbound_notify, "/data", vec![42u8]); + + tokio::time::sleep(Duration::from_millis(50)).await; + + let recv_count = recv_log.lock().unwrap().len(); + assert!( + recv_count >= 1, + "expected recv to fire during slow publish; got {recv_count} events. \ + The recv path should be independent of publish latency." + ); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn slow_recv_dispatch_does_not_block_publish() { + let transport = ControllableMockTransport::new(); + let publish_log = transport.publish_log.clone(); + let inbound = transport.inbound.clone(); + let inbound_notify = transport.inbound_notify.clone(); + let recv_returned = transport.recv_returned.clone(); + + let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_CAPACITY); + let mut builder = Builder::new(topics(&[("slow", "/slow"), ("out", "/out")]), publish_tx); + + // simulate slow processing function in a receive + let _input = builder.input("slow", |b| { + std::thread::sleep(Duration::from_millis(200)); + Ok(b.to_vec()) + }); + let output = builder.output("out", |b: &Vec| b.clone()); + spawn_pubsub_tasks(transport, builder.routes, publish_rx); + + // send a message to the receiving + inject_inbound(&inbound, &inbound_notify, "/slow", vec![1u8]); + + // make sure the receive gets picked up before we publish + tokio::time::timeout(Duration::from_secs(1), recv_returned.notified()) + .await + .expect("dispatch task should pick up inbound within 1s"); + + output.publish(&vec![42u8]).await.ok(); + + // receive should still be 
processing, but publish should go through by now + tokio::time::sleep(Duration::from_millis(50)).await; + + let publish_count = publish_log.lock().unwrap().len(); + assert!( + publish_count >= 1, + "expected publish to fire during slow recv dispatch; got \ + {publish_count} events. The publish path should be independent \ + of recv-side CPU work." + ); + } + + // propagate_task_failure + + #[tokio::test] + async fn propagates_task_panic_payload() { + let handle = tokio::spawn(async { panic!("kaboom") }); + let res = handle.await; + + let caught = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + propagate_task_failure("recv", res); + })); + + let payload = caught.expect_err("expected helper to re-panic"); + let msg = payload + .downcast_ref::<&'static str>() + .copied() + .expect("panic payload should be a string literal"); + assert_eq!(msg, "kaboom"); + } + + #[test] + fn ok_does_not_panic() { + propagate_task_failure("recv", Ok(())); + } +} diff --git a/native/rust/src/transport.rs b/native/rust/dimos-module/src/transport.rs similarity index 77% rename from native/rust/src/transport.rs rename to native/rust/dimos-module/src/transport.rs index 0322f52681..49d0f9ceea 100644 --- a/native/rust/src/transport.rs +++ b/native/rust/dimos-module/src/transport.rs @@ -5,9 +5,9 @@ use std::io; /// /// New transport protocols should implement this trait. /// `NativeModule` is generic over any transport -pub trait Transport: Send + 'static { +pub trait Transport: Send + Sync + 'static { /// Send `data` on `channel`. fn publish(&self, channel: &str, data: &[u8]) -> impl Future> + Send; /// Block until the next inbound message, returning `(channel, data)`. 
- fn recv(&mut self) -> impl Future)>> + Send; + fn recv(&self) -> impl Future)>> + Send; } diff --git a/native/rust/src/module.rs b/native/rust/src/module.rs deleted file mode 100644 index 37ac2bd5e7..0000000000 --- a/native/rust/src/module.rs +++ /dev/null @@ -1,416 +0,0 @@ -use std::collections::HashMap; -use std::io::{self, BufRead}; -use tokio::sync::mpsc; - -use serde::de::DeserializeOwned; - -use crate::transport::Transport; - -const INPUT_CHANNEL_CAPACITY: usize = 16; -const PUBLISH_CHANNEL_CAPACITY: usize = 64; - -// Each input() call produces a TypedRoute that decodes its message type -// and forwards it to the right Input's mpsc channel. -trait Route: Send { - fn topic(&self) -> &str; - fn try_dispatch(&self, data: &[u8]); -} - -struct TypedRoute { - topic: String, - decode: fn(&[u8]) -> io::Result, - sender: mpsc::Sender, -} - -impl Route for TypedRoute { - fn topic(&self) -> &str { - &self.topic - } - - fn try_dispatch(&self, data: &[u8]) { - match (self.decode)(data) { - // If the input channel is full, the newest message is dropped. - Ok(msg) => { - let _ = self.sender.try_send(msg); - } - Err(e) => eprintln!("dimos_module: decode error on {}: {e}", self.topic), - } - } -} -pub struct Input { - pub topic: String, - receiver: mpsc::Receiver, -} - -impl Input { - pub async fn recv(&mut self) -> Option { - self.receiver.recv().await - } -} - -pub struct Output { - pub topic: String, - encode: fn(&T) -> Vec, - sender: mpsc::Sender<(String, Vec)>, -} - -impl Output { - pub async fn publish(&self, msg: &T) -> io::Result<()> { - let data = (self.encode)(msg); - self.sender - .send((self.topic.clone(), data)) - .await - .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "background task gone")) - } -} - -/// Parse a JSON config line as written by the Python NativeModule coordinator. -/// Returns `(topics, config)`. Extracted so it can be unit-tested without stdin. 
-fn parse_config_json(line: &str) -> io::Result<(HashMap, C)> { - let json: serde_json::Value = serde_json::from_str(line.trim()) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - - let mut topics = HashMap::new(); - if let Some(t) = json.get("topics").and_then(|v| v.as_object()) { - for (port, topic) in t { - if let Some(s) = topic.as_str() { - topics.insert(port.clone(), s.to_string()); - } - } - } - - let config: C = match json.get("config") { - None => return Err(io::Error::new( - io::ErrorKind::InvalidData, - "missing 'config' field in stdin JSON — coordinator must always send a config object", - )), - Some(v) => serde_json::from_value(v.clone()).map_err(|e| { - io::Error::new( - io::ErrorKind::InvalidData, - format!("failed to deserialize config: {e}"), - ) - })?, - }; - - Ok((topics, config)) -} - -/// High-level wrapper around a transport for use in dimos native modules. -/// -/// Generic over any `T: Transport`. Use `LcmTransport` for the standard LCM -/// UDP multicast transport. -/// -/// # Usage -/// -/// ```ignore -/// let transport = LcmTransport::new().await?; -/// let (mut module, config) = NativeModule::from_stdin::(transport).await?; -/// -/// let mut image_in = module.input("color_image", Image::decode); -/// let cmd_out = module.output("cmd_vel", Twist::encode); -/// let _handle = module.spawn(); -/// -/// loop { -/// tokio::select! { -/// Some(frame) = image_in.recv() => { cmd_out.publish(&twist).await.ok(); } -/// } -/// } -/// ``` -pub struct NativeModule { - transport: T, - routes: Vec>, - topics: HashMap, - publish_tx: mpsc::Sender<(String, Vec)>, - publish_rx: mpsc::Receiver<(String, Vec)>, -} - -impl NativeModule { - pub(crate) fn new(transport: T) -> Self { - let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_CAPACITY); - Self { - transport, - routes: Vec::new(), - topics: HashMap::new(), - publish_tx, - publish_rx, - } - } - - /// Parse `--port_name topic_string` pairs from argv, as injected by NativeModule. 
- pub async fn from_args(transport: T) -> io::Result { - let mut module = Self::new(transport); - let args: Vec = std::env::args().collect(); - let mut i = 1; - while i < args.len() { - if let Some(port) = args[i].strip_prefix("--") { - if i + 1 < args.len() && !args[i + 1].starts_with("--") { - module.topics.insert(port.to_string(), args[i + 1].clone()); - i += 2; - continue; - } - } - i += 1; - } - Ok(module) - } - - /// Read config from a single JSON line on stdin, as written by the Python NativeModule declaration. - /// - /// The JSON format is: - /// ```json - /// {"topics": {"port_name": "lcm/topic", ...}, "config": { ... }} - /// ``` - /// - /// `C` is the module-specific config type. Use `()` for modules with no configuration. - pub async fn from_stdin( - transport: T, - ) -> io::Result<(Self, C)> { - let mut line = String::new(); - io::stdin().lock().read_line(&mut line)?; - - let (topics, config) = parse_config_json::(&line)?; - - let mut module = Self::new(transport); - module.topics = topics; - - let exe = std::env::current_exe() - .ok() - .and_then(|p| p.file_name().map(|n| n.to_string_lossy().into_owned())) - .unwrap_or_else(|| "unknown".to_string()); - eprintln!("[{exe}] topics received:"); - for (port, topic) in &module.topics { - eprintln!(" {port} -> {topic}"); - } - eprintln!("[{exe}] config: {config:?}"); - - Ok((module, config)) - } - - /// Manually set a topic for a port — useful for testing without a parent process. - pub fn map_topic(&mut self, port: &str, topic: &str) { - self.topics.insert(port.to_string(), topic.to_string()); - } - - fn topic_for(&self, port: &str) -> String { - self.topics - .get(port) - .cloned() - .unwrap_or_else(|| format!("/{port}")) - } - - /// Register an input port. Must be called before `spawn()`. 
- pub fn input( - &mut self, - port: &str, - decode: fn(&[u8]) -> io::Result, - ) -> Input { - let topic = self.topic_for(port); - let (tx, rx) = mpsc::channel(INPUT_CHANNEL_CAPACITY); - self.routes.push(Box::new(TypedRoute { - topic: topic.clone(), - decode, - sender: tx, - })); - Input { - topic, - receiver: rx, - } - } - - /// Register an output port. Must be called before `spawn()`. - pub fn output(&self, port: &str, encode: fn(&M) -> Vec) -> Output { - Output { - topic: self.topic_for(port), - encode, - sender: self.publish_tx.clone(), - } - } - - /// Start the background recv/dispatch/publish loop. - /// - /// Consumes the module — no new ports can be registered after this point. - pub fn spawn(self) -> NativeModuleHandle { - let NativeModule { - mut transport, - routes, - mut publish_rx, - .. - } = self; - - let handle = tokio::spawn(async move { - loop { - tokio::select! { - result = transport.recv() => match result { - Ok((channel, data)) => { - for route in &routes { - if route.topic() == channel { - route.try_dispatch(&data); - } - } - } - Err(e) => { - eprintln!("dimos_module: recv error: {e}"); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - } - }, - Some((topic, data)) = publish_rx.recv() => { - if let Err(e) = transport.publish(&topic, &data).await { - eprintln!("dimos_module: publish error on {topic}: {e}"); - } - } - } - } - }); - - NativeModuleHandle(handle) - } -} - -pub struct NativeModuleHandle(tokio::task::JoinHandle<()>); - -impl NativeModuleHandle { - pub async fn join(self) -> Result<(), tokio::task::JoinError> { - self.0.await - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde::Deserialize; - - struct MockTransport; - - impl crate::transport::Transport for MockTransport { - async fn publish(&self, _channel: &str, _data: &[u8]) -> io::Result<()> { - Ok(()) - } - async fn recv(&mut self) -> io::Result<(String, Vec)> { - std::future::pending().await - } - } - - #[derive(Debug, Deserialize, Default, 
PartialEq)] - #[serde(deny_unknown_fields)] - struct TestConfig { - value: i64, - name: String, - } - - // --- parse_config_json --- - - #[test] - fn parses_topics_and_config() { - let json = r#"{"topics": {"data": "/foo/data", "confirm": "/foo/confirm"}, "config": {"value": 42, "name": "hello"}}"#; - let (topics, config) = parse_config_json::(json).unwrap(); - assert_eq!(topics["data"], "/foo/data"); - assert_eq!(topics["confirm"], "/foo/confirm"); - assert_eq!( - config, - TestConfig { - value: 42, - name: "hello".into() - } - ); - } - - #[test] - fn missing_config_field_returns_error() { - let json = r#"{"topics": {"data": "/foo/data"}}"#; - let result = parse_config_json::(json); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("missing 'config' field")); - } - - #[test] - fn null_config_succeeds_for_unit_type() { - let json = r#"{"topics": {}, "config": null}"#; - let (_topics, _config) = parse_config_json::<()>(json).unwrap(); - } - - #[test] - fn null_config_errors_when_struct_expects_fields() { - let json = r#"{"topics": {}, "config": null}"#; - let result = parse_config_json::(json); - assert!(result.is_err()); - } - - #[test] - fn empty_config_object_errors_when_struct_expects_fields() { - let json = r#"{"topics": {}, "config": {}}"#; - let result = parse_config_json::(json); - assert!(result.is_err()); - } - - #[test] - fn config_with_wrong_type_returns_error() { - let json = r#"{"topics": {}, "config": {"value": "not_a_number", "name": "x"}}"#; - let result = parse_config_json::(json); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("failed to deserialize config")); - } - - #[test] - fn missing_topics_field_gives_empty_map() { - let json = r#"{"config": {"value": 1, "name": "x"}}"#; - let (topics, _config) = parse_config_json::(json).unwrap(); - assert!(topics.is_empty()); - } - - #[test] - fn malformed_json_returns_error() { - let result = parse_config_json::<()>("not json 
at all"); - assert!(result.is_err()); - } - - #[test] - fn unknown_config_field_returns_error() { - let json = r#"{"topics": {}, "config": {"value": 1, "name": "x", "unexpected": true}}"#; - let result = parse_config_json::(json); - assert!(result.is_err()); - } - - // --- topic_for / map_topic --- - - #[test] - fn unmapped_port_falls_back_to_slash_port() { - let module = NativeModule::new(MockTransport); - assert_eq!(module.topic_for("cmd_vel"), "/cmd_vel"); - } - - #[test] - fn map_topic_overrides_fallback() { - let mut module = NativeModule::new(MockTransport); - module.map_topic("cmd_vel", "/robot/cmd_vel"); - assert_eq!(module.topic_for("cmd_vel"), "/robot/cmd_vel"); - } - - #[test] - fn input_uses_mapped_topic() { - let mut module = NativeModule::new(MockTransport); - module.map_topic("data", "/test/data"); - let input = module.input("data", |b| Ok(b.to_vec())); - assert_eq!(input.topic, "/test/data"); - } - - #[test] - fn input_falls_back_to_slash_port_when_unmapped() { - let mut module = NativeModule::new(MockTransport); - let input = module.input("data", |b| Ok(b.to_vec())); - assert_eq!(input.topic, "/data"); - } - - #[test] - fn output_uses_mapped_topic() { - let mut module = NativeModule::new(MockTransport); - module.map_topic("cmd_vel", "/robot/cmd_vel"); - let output = module.output("cmd_vel", |b: &Vec| b.clone()); - assert_eq!(output.topic, "/robot/cmd_vel"); - } -}