diff --git a/.gitignore b/.gitignore index 8516179..6f49254 100644 Binary files a/.gitignore and b/.gitignore differ diff --git a/Cargo.lock b/Cargo.lock index ed335ca..3c4aa1f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -202,7 +202,7 @@ dependencies = [ "num-traits", "pastey", "rayon", - "thiserror", + "thiserror 2.0.18", "v_frame", "y4m", ] @@ -230,6 +230,61 @@ dependencies = [ "arrayvec", ] +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "base64 0.22.1", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sha1", + "sync_wrapper", + "tokio", + "tokio-tungstenite 0.28.0", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "base64" version = "0.13.1" @@ -242,6 +297,21 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + [[package]] name = "bit_field" version = "0.10.3" @@ -272,6 +342,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "borrow-or-share" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0b364ead1874514c8c2855ab558056ebfeb775653e7ae45ff72f28f8f3166c" + [[package]] name = "built" version = "0.8.0" @@ -284,6 +360,12 @@ version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +[[package]] +name = "bytecount" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e" + [[package]] name = "bytemuck" version = "1.25.0" @@ -349,30 +431,6 @@ dependencies = [ "shlex", ] -[[package]] -name = "cede" -version = "0.1.0" -dependencies = [ - "async-channel", - "async-trait", - "bytemuck", - "chrono", - "clap", - "crossterm", - "fastembed", - "futures", - "instant-distance", - "lru", - "ratatui", - "reqwest", - "rusqlite", - "serde", - "serde_json", - "thiserror", - "tokio", - "uuid", -] - [[package]] name = "cfg-if" version = "1.0.4" @@ -548,6 +606,17 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "cron" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eee8b2b4516038bc0f1d3c9934bcb4a13dd316e04abbc63c96757a6d75978532" +dependencies = [ + "chrono", + "nom 7.1.3", + "once_cell", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -693,6 +762,12 @@ dependencies = [ "serde", ] +[[package]] +name = "data-encoding" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" + [[package]] name = "derive_builder" version 
= "0.20.2" @@ -772,6 +847,15 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "email_address" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449" +dependencies = [ + "serde", +] + [[package]] name = "encode_unicode" version = "1.0.0" @@ -877,6 +961,17 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" +[[package]] +name = "fancy-regex" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e24cb5a94bcae1e5408b0effca5cd7172ea3c5755049c5f3af4cd283a165298" +dependencies = [ + "bit-set", + "regex-automata", + "regex-syntax", +] + [[package]] name = "fastembed" version = "4.9.1" @@ -956,6 +1051,17 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "fluent-uri" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5" +dependencies = [ + "borrow-or-share", + "ref-cast", + "serde", +] + [[package]] name = "fnv" version = "1.0.7" @@ -992,6 +1098,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fraction" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f158e3ff0a1b334408dc9fb811cd99b446986f4d8b741bb08f9df1604085ae7" +dependencies = [ + "lazy_static", + "num", +] + [[package]] name = "futures" version = "0.3.32" @@ -1229,7 +1345,7 @@ dependencies = [ "reqwest", "serde", "serde_json", - "thiserror", + "thiserror 2.0.18", "ureq", "windows-sys 0.60.2", ] @@ -1273,6 +1389,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.8.1" @@ -1287,6 +1409,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "pin-utils", @@ -1667,6 +1790,37 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonschema" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8f66fe41fa46a5c83ed1c717b7e0b4635988f427083108c8cf0a882cc13441" +dependencies = [ + "ahash", + "base64 0.22.1", + "bytecount", + "email_address", + "fancy-regex", + "fraction", + "idna", + "itoa", + "num-cmp", + "once_cell", + "percent-encoding", + "referencing", + "regex-syntax", + "reqwest", + "serde", + "serde_json", + "uuid-simd", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "leb128fmt" version = "0.1.0" @@ -1785,6 +1939,21 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -1940,6 +2109,29 @@ version = "0.3.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -1950,6 +2142,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-cmp" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa" + [[package]] name = "num-complex" version = "0.4.6" @@ -1979,6 +2177,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-rational" version = "0.4.2" @@ -2015,6 +2224,41 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "omni-cede" +version = "0.1.0" +dependencies = [ + "async-channel", + "async-trait", + "axum", + "base64 0.22.1", + "bytemuck", + "chrono", + "clap", + "cron", + "crossterm", + "fastembed", + "futures", + "instant-distance", + "jsonschema", + "lru", + "rand 0.8.5", + "ratatui", + "reqwest", + "rusqlite", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + 
"tokio-tungstenite 0.24.0", + "toml", + "tower-http", + "tracing", + "tracing-subscriber", + "url", + "uuid", +] + [[package]] name = "once_cell" version = "1.21.4" @@ -2132,6 +2376,12 @@ dependencies = [ "ureq", ] +[[package]] +name = "outref" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" + [[package]] name = "parking" version = "2.2.1" @@ -2439,7 +2689,7 @@ dependencies = [ "rand 0.9.2", "rand_chacha 0.9.0", "simd_helpers", - "thiserror", + "thiserror 2.0.18", "v_frame", "wasm-bindgen", ] @@ -2522,7 +2772,40 @@ checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" dependencies = [ "getrandom 0.2.17", "libredox", - "thiserror", + "thiserror 2.0.18", +] + +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "referencing" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0dcb5ab28989ad7c91eb1b9531a37a1a137cc69a0499aee4117cae4a107c464" +dependencies = [ + "ahash", + "fluent-uri", + "once_cell", + "percent-encoding", + "serde_json", ] [[package]] @@ -2563,6 +2846,7 @@ dependencies = [ "base64 0.22.1", "bytes", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "h2", @@ -2791,6 +3075,26 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + +[[package]] +name = "serde_spanned" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -2803,6 +3107,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.9" @@ -2814,6 +3129,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -3033,13 +3357,33 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + [[package]] name = "thiserror" version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -3053,6 +3397,15 @@ dependencies = [ 
"syn", ] +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + [[package]] name = "tiff" version = "0.11.3" @@ -3104,7 +3457,7 @@ dependencies = [ "serde", "serde_json", "spm_precompiled", - "thiserror", + "thiserror 2.0.18", "unicode-normalization-alignments", "unicode-segmentation", "unicode_categories", @@ -3158,6 +3511,32 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" +dependencies = [ + "futures-util", + "log", + "native-tls", + "tokio", + "tokio-native-tls", + "tungstenite 0.24.0", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.28.0", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -3171,6 +3550,47 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + 
"indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "toml_write", + "winnow", +] + +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + [[package]] name = "tower" version = "0.5.3" @@ -3184,6 +3604,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -3202,6 +3623,7 @@ dependencies = [ "tower", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -3222,10 +3644,23 @@ version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ + "log", "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tracing-core" version = "0.1.36" @@ -3233,6 +3668,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", 
+ "tracing-log", ] [[package]] @@ -3241,6 +3706,42 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "native-tls", + "rand 0.8.5", + "sha1", + "thiserror 1.0.69", + "utf-8", +] + +[[package]] +name = "tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.9.2", + "sha1", + "thiserror 2.0.18", + "utf-8", +] + [[package]] name = "typenum" version = "1.19.0" @@ -3341,6 +3842,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -3364,6 +3871,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "uuid-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b082222b4f6619906941c17eb2297fff4c2fb96cb60164170522942a200bd8" +dependencies = [ + "outref", + "uuid", + "vsimd", +] + [[package]] name = "v_frame" version = "0.3.9" @@ -3375,6 +3893,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "vcpkg" version = "0.2.15" @@ -3387,6 +3911,12 @@ version = "0.9.5" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + [[package]] name = "want" version = "0.3.1" @@ -3827,6 +4357,15 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "winnow" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +dependencies = [ + "memchr", +] + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/Cargo.toml b/Cargo.toml index 245fc47..48a6972 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,16 +1,16 @@ [package] -name = "cede" +name = "omni-cede" version = "0.1.0" edition = "2021" -description = "A forkable self-aware agent with graph memory. Built on cortex-embedded." +description = "Omnichannel self-aware agent. Fork of cede with HTTP API, identity, and session management." 
[lib] -name = "cede" +name = "omni_cede" path = "src/lib.rs" [[bin]] -name = "cede" -path = "src/bin/cede.rs" +name = "omni-cede" +path = "src/bin/omni_cede.rs" [dependencies] rusqlite = { version = "0.31", features = ["bundled"] } @@ -31,3 +31,21 @@ async-trait = "0.1" chrono = "0.4" ratatui = "0.29" crossterm = { version = "0.28", features = ["event-stream"] } +axum = { version = "0.8", features = ["ws"] } +jsonschema = "0.28" +tower-http = { version = "0.6", features = ["cors", "trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +toml = "0.8" +cron = "0.13" +url = "2" + +base64 = "0.22" + +# Browser module (optional) +tokio-tungstenite = { version = "0.24", features = ["native-tls"], optional = true } +rand = { version = "0.8", optional = true } + +[features] +default = [] +browser = ["tokio-tungstenite", "rand"] diff --git a/OMNICHANNEL_PLAN.md b/OMNICHANNEL_PLAN.md new file mode 100644 index 0000000..e41e777 --- /dev/null +++ b/OMNICHANNEL_PLAN.md @@ -0,0 +1,607 @@ +# Omnichannel Integration Plan — omni-cede + +## Inspiration + +This plan takes direct inspiration from [OpenClaw](https://github.com/openclaw/openclaw), a TypeScript personal AI assistant that supports 20+ messaging channels through a **Gateway + Plugin** architecture. We adapt their core patterns to Rust and our graph-memory engine. + +### What OpenClaw does right (and what we're stealing) + +1. **Gateway as single control plane** — one process handles all channels, sessions, tools, and events. Channels connect TO the gateway, not the other way around. +2. **Channel = Plugin** — each channel is a self-contained extension implementing a standard contract (`channel-contract.ts`). Adding WhatsApp doesn't touch Telegram code. +3. **Plugin SDK** — shared helpers for the hard parts: pairing, allowlists, reply pipelines, typing indicators, media handling. +4. 
**Hooks pipeline** — lifecycle hooks (`before_dispatch`, `after_tool_call`, `session:patch`) let plugins intercept and transform messages at well-defined points. +5. **Session isolation with cross-channel knowledge** — each channel gets its own session, but the semantic search layer (their "context engine") spans everything. + +### What we already have (our advantage) + +- **Graph-native sessions** — our `run_turn()` already builds a fresh briefing per turn using HNSW semantic search + recency window. OpenClaw does growing message arrays then "compacts" them. We don't need that. +- **Identity resolution** — our `identity::resolve_user()` already maps (channel, external_id) → internal user_id. +- **Session manager** — our `session::get_or_create()` already scopes sessions to (user_id, channel). +- **HTTP API** — our axum server already handles `POST /v1/message` with the full identity→session→agent pipeline. + +We just need to add the **channel adapter layer** — the part that connects real messaging platforms to our existing `/v1/message` pipeline. 
+ +--- + +## Architecture + +``` + ┌─────────────────────────────────────────────────┐ + │ omni-cede │ + │ │ + WhatsApp ──┐ │ ┌──────────────────────────────────────────┐ │ + Telegram ──┤ │ │ Channel Registry │ │ + Discord ──┤────▶│ │ ┌─────────┐ ┌─────────┐ ┌──────────┐ │ │ + Slack ──┤ │ │ │WhatsApp │ │Telegram │ │ Discord │ │ │ + WebChat ──┤ │ │ │ Adapter │ │ Adapter │ │ Adapter │ │ │ + Webhook ──┘ │ │ └────┬────┘ └────┬────┘ └────┬─────┘ │ │ + │ │ │ │ │ │ │ + │ └───────┼───────────┼───────────┼─────────┘ │ + │ ▼ ▼ ▼ │ + │ ┌──────────────────────────────────────────┐ │ + │ │ Inbound Pipeline │ │ + │ │ normalize → identity → session → hooks │ │ + │ └──────────────────┬───────────────────────┘ │ + │ ▼ │ + │ ┌──────────────────────────────────────────┐ │ + │ │ Agent (run_turn) │ │ + │ │ briefing → HNSW recall → LLM → tools │ │ + │ └──────────────────┬───────────────────────┘ │ + │ ▼ │ + │ ┌──────────────────────────────────────────┐ │ + │ │ Outbound Pipeline │ │ + │ │ hooks → chunking → rate-limit → send │ │ + │ └──────────────────────────────────────────┘ │ + └─────────────────────────────────────────────────┘ +``` + +--- + +## Phase 1: Channel Trait & Registry + +### The `Channel` trait + +Every messaging platform adapter implements one trait: + +```rust +// src/channels/trait.rs + +#[async_trait] +pub trait Channel: Send + Sync + 'static { + /// Unique channel identifier, e.g. "whatsapp", "telegram", "discord" + fn id(&self) -> &str; + + /// Human-readable name + fn display_name(&self) -> &str; + + /// Start the channel adapter (connect to APIs, start polling/webhooks) + async fn start(&self, ctx: ChannelContext) -> Result<()>; + + /// Stop gracefully + async fn stop(&self) -> Result<()>; + + /// Send a message back to a user on this channel + async fn send(&self, target: &OutboundTarget, message: OutboundMessage) -> Result<()>; + + /// Health check — is the channel connection alive? 
+ async fn health(&self) -> ChannelHealth; + + /// Channel-specific configuration schema (for validation) + fn config_schema(&self) -> serde_json::Value { serde_json::json!({}) } + + /// Optional: typing indicator support + async fn send_typing(&self, _target: &OutboundTarget) -> Result<()> { Ok(()) } + + /// Optional: message editing (for streaming responses) + async fn edit_message(&self, _msg_id: &str, _new_text: &str) -> Result<()> { + Err(CortexError::Unsupported("edit not supported on this channel".into())) + } + + /// Optional: media support + fn supports_media(&self) -> bool { false } + async fn send_media(&self, _target: &OutboundTarget, _media: MediaPayload) -> Result<()> { + Err(CortexError::Unsupported("media not supported".into())) + } +} +``` + +### Supporting types + +```rust +// src/channels/types.rs + +/// Context passed to channels on startup — gives them access to the inbound pipeline +pub struct ChannelContext { + /// Call this when a message arrives from the channel + pub inbound_tx: tokio::sync::mpsc::Sender<InboundEnvelope>, + /// Shared app state for identity/session resolution + pub db: Db, + /// Channel-specific config section + pub config: serde_json::Value, +} + +/// A normalized inbound message from any channel +pub struct InboundEnvelope { + pub channel: String, + pub external_id: String, + pub sender_name: Option<String>, + pub text: String, + pub media: Option<MediaPayload>, + pub reply_to: Option<String>, // message ID being replied to + pub group_id: Option<String>, // if this is a group message + pub raw: serde_json::Value, // channel-specific raw payload + pub timestamp: i64, +} + +/// Where to send a reply +pub struct OutboundTarget { + pub channel: String, + pub external_id: String, + pub group_id: Option<String>, + pub reply_to_message_id: Option<String>, +} + +/// An outbound message — text, media, or both +pub struct OutboundMessage { + pub text: String, + pub media: Option<MediaPayload>, + pub metadata: serde_json::Value, +} + +pub struct MediaPayload { + pub kind: MediaKind, + pub data: Vec<u8>, + pub
mime_type: String, + pub filename: Option<String>, +} + +pub enum MediaKind { + Image, + Audio, + Video, + Document, +} + +pub enum ChannelHealth { + Connected, + Degraded(String), + Disconnected(String), +} +``` + +### Channel Registry + +```rust +// src/channels/registry.rs + +pub struct ChannelRegistry { + channels: HashMap<String, Arc<dyn Channel>>, + inbound_tx: mpsc::Sender<InboundEnvelope>, + inbound_rx: mpsc::Receiver<InboundEnvelope>, +} + +impl ChannelRegistry { + pub fn new() -> Self { ... } + + /// Register a channel adapter + pub fn register(&mut self, channel: Arc<dyn Channel>) { ... } + + /// Start all registered channels + pub async fn start_all(&self, db: &Db, config: &Config) -> Result<()> { ... } + + /// Stop all channels + pub async fn stop_all(&self) -> Result<()> { ... } + + /// Get a channel by ID (for outbound routing) + pub fn get(&self, id: &str) -> Option<Arc<dyn Channel>> { ... } + + /// List all registered channels with health status + pub async fn health_all(&self) -> Vec<(String, ChannelHealth)> { ... } +} +``` + +### New file structure + +``` +src/channels/ + mod.rs # re-exports, Channel trait + types.rs # InboundEnvelope, OutboundTarget, OutboundMessage, etc. + registry.rs # ChannelRegistry + pipeline.rs # inbound/outbound message processing pipeline + webhook.rs # Generic webhook channel (for platforms that POST to us) + whatsapp.rs # WhatsApp adapter (via Baileys/whatsapp-web.js sidecar or webhook) + telegram.rs # Telegram adapter (Bot API, long polling or webhook) + discord.rs # Discord adapter (serenity or webhook) + slack.rs # Slack adapter (Bolt-style webhook) + webchat.rs # Built-in WebSocket webchat (served from the gateway) +``` + +--- + +## Phase 2: Inbound / Outbound Pipeline + +Inspired by OpenClaw's `before_dispatch` and reply pipeline hooks. + +### Inbound Pipeline + +When a message arrives from any channel: + +``` +InboundEnvelope + │ + ├─ 1. Normalize: trim whitespace, detect /commands + ├─ 2. Security: check allowlist for this (channel, sender) + ├─ 3.
Identity: resolve_user(channel, external_id) → user_id + ├─ 4. Session: get_or_create(user_id, channel) → session_id + ├─ 5. Hook: before_agent (plugins can modify or reject) + ├─ 6. Agent: run_turn(session_id, text) → reply + ├─ 7. Hook: after_agent (plugins can modify reply) + ├─ 8. Record: session::record_turn() + └─ 9. Outbound: route reply back to the originating channel +``` + +### Outbound Pipeline + +``` +OutboundMessage + │ + ├─ 1. Hook: before_send (rate-limiting, logging) + ├─ 2. Chunk: split long messages per channel limits + │ (WhatsApp: 65536, Telegram: 4096, Discord: 2000, Slack: 40000) + ├─ 3. Send: channel.send(target, chunk) + ├─ 4. Typing: send typing indicator between chunks + └─ 5. Hook: after_send (delivery tracking) +``` + +### Hooks System + +```rust +// src/channels/hooks.rs + +#[async_trait] +pub trait ChannelHook: Send + Sync { + /// Called before the message is sent to the agent. Return Err to reject. + async fn before_agent(&self, _env: &mut InboundEnvelope) -> Result<()> { Ok(()) } + + /// Called after the agent produces a reply. Can modify the reply text. + async fn after_agent(&self, _env: &InboundEnvelope, _reply: &mut String) -> Result<()> { Ok(()) } + + /// Called before sending a message on a channel. + async fn before_send(&self, _target: &OutboundTarget, _msg: &mut OutboundMessage) -> Result<()> { Ok(()) } + + /// Called after successful send. + async fn after_send(&self, _target: &OutboundTarget, _msg: &OutboundMessage) -> Result<()> { Ok(()) } +} +``` + +--- + +## Phase 3: Channel Adapters (Priority Order) + +### 3a. Webhook Channel (generic) + +The simplest adapter — any platform that can POST JSON to us. Our existing `POST /v1/message` is basically this already. 
We generalize it: + +``` +POST /v1/channels/webhook/inbound +{ + "channel": "custom", + "external_id": "user123", + "text": "Hello", + "callback_url": "https://my-app.com/reply" // optional: where to POST the reply +} +``` + +This lets any system integrate without a dedicated adapter. + +**Effort:** Small — refactor existing `/v1/message` into the pipeline pattern. + +### 3b. Telegram + +Telegram is the easiest real channel — clean Bot API, no unofficial hacks. + +- **Inbound:** Long polling via `getUpdates` or webhook mode (Telegram POSTs to our `/v1/channels/telegram/webhook`) +- **Outbound:** `sendMessage`, `sendPhoto`, `editMessageText` (for streaming) +- **Features:** Typing indicators, inline keyboards, message editing, groups (mention gating), media +- **Auth:** `TELEGRAM_BOT_TOKEN` env var +- **Crate:** `reqwest` (just HTTP calls to `api.telegram.org`) +- **Config:** + ```json + { + "channels": { + "telegram": { + "bot_token": "...", + "mode": "polling", // or "webhook" + "webhook_url": "https://...", + "allow_from": ["123456789"], // telegram user IDs, "*" for all + "groups": { "*": { "require_mention": true } } + } + } + } + ``` + +**Effort:** Medium — straightforward HTTP API, ~400 lines. + +### 3c. Discord + +Discord needs a persistent WebSocket (gateway) for real-time events. + +- **Inbound:** WS gateway for `MESSAGE_CREATE` events, or slash commands via webhook +- **Outbound:** REST API `POST /channels/{id}/messages` +- **Features:** Threads, embeds, reactions, slash commands, voice channels (future) +- **Auth:** `DISCORD_BOT_TOKEN` env var +- **Crate:** Either `serenity` (full-featured) or raw WS + REST via `tokio-tungstenite` + `reqwest` +- **Config:** + ```json + { + "channels": { + "discord": { + "token": "...", + "allow_from": ["guild_id:channel_id"], + "dm_policy": "pairing" + } + } + } + ``` + +**Effort:** Medium-High — WS gateway is more complex. Recommend `serenity` crate to handle the protocol. + +### 3d. 
Slack + +Slack uses Socket Mode (WebSocket) or Events API (webhook). + +- **Inbound:** Socket Mode WS for `message` events, or HTTP webhook for Events API +- **Outbound:** `chat.postMessage`, `chat.update` (for streaming edits) +- **Features:** Threads, blocks (rich formatting), reactions, slash commands +- **Auth:** `SLACK_BOT_TOKEN` + `SLACK_APP_TOKEN` env vars +- **Crate:** `reqwest` for API calls, `tokio-tungstenite` for Socket Mode +- **Config:** + ```json + { + "channels": { + "slack": { + "bot_token": "xoxb-...", + "app_token": "xapp-...", + "mode": "socket", + "allow_from": ["U12345678"], + "dm_policy": "open" + } + } + } + ``` + +**Effort:** Medium — Socket Mode is simpler than Discord's gateway. + +### 3e. WhatsApp + +WhatsApp is the hardest — no official free API for personal accounts. + +**Option A: WhatsApp Cloud API (Business)** — official, requires Meta Business account. +- Inbound: Webhook (Meta POSTs to us) +- Outbound: REST API +- Config: `WHATSAPP_PHONE_NUMBER_ID`, `WHATSAPP_ACCESS_TOKEN`, `WHATSAPP_VERIFY_TOKEN` + +**Option B: Baileys sidecar** — unofficial, like OpenClaw does. +- Run a Node.js sidecar process that handles the WhatsApp Web protocol +- Communicate via local HTTP/WS between Rust and the sidecar +- More fragile, but works with personal accounts + +**Recommendation:** Start with Option A (Cloud API). Add Option B later as an optional sidecar. + +**Effort:** Medium (Cloud API) or High (Baileys sidecar). + +### 3f. WebSocket WebChat + +Built-in web interface served from the gateway itself. + +- **Inbound:** WebSocket `ws://host:port/v1/ws/chat` +- **Outbound:** same WebSocket, streaming tokens +- **Features:** Real-time streaming, typing indicators, session management in the browser +- **Auth:** Session token or API key +- **Crate:** `axum` already supports WebSocket upgrades + +**Effort:** Medium — WebSocket upgrade + simple web UI. 
+ +--- + +## Phase 4: Configuration System + +Unified TOML/JSON config file at `~/.omni-cede/config.toml`: + +```toml +[agent] +model = "anthropic/claude-sonnet-4-20250514" + +[gateway] +host = "0.0.0.0" +port = 3000 +api_key = "sk-..." # or use OMNI_CEDE_API_KEY env var + +[channels.telegram] +enabled = true +bot_token = "123456:ABCDEF" # or TELEGRAM_BOT_TOKEN env +mode = "polling" # "polling" or "webhook" +allow_from = ["*"] + +[channels.discord] +enabled = true +token = "MTIz..." # or DISCORD_BOT_TOKEN env +dm_policy = "pairing" + +[channels.slack] +enabled = false + +[channels.whatsapp] +enabled = false + +[channels.webchat] +enabled = true # always-on by default + +[security] +dm_policy = "pairing" # global default: "open", "pairing", "closed" +``` + +**Pattern from OpenClaw:** Env vars always override config file values. Channel-specific settings override global defaults. + +--- + +## Phase 5: Security & Access Control + +Directly inspired by OpenClaw's DM pairing model: + +### Pairing Flow +1. Unknown sender messages the bot on any channel +2. Bot replies with a 6-digit pairing code (stored in DB with expiry) +3. Owner approves: `omni-cede pairing approve <code>` +4. Sender is added to the persistent allowlist for that channel +5. 
Future messages are processed normally + +### Allowlist Storage +```sql +CREATE TABLE channel_allowlist ( + channel TEXT NOT NULL, + external_id TEXT NOT NULL, + approved_at INTEGER NOT NULL, + approved_by TEXT, -- admin user_id who approved + PRIMARY KEY (channel, external_id) +); +``` + +### Policies (per-channel, cascading from global) +- `"open"` — process all inbound messages (dev/personal use) +- `"pairing"` — unknown senders get pairing code (default, safe) +- `"closed"` — only pre-approved allowlist members (production) + +--- + +## Phase 6: Observability & Management + +### CLI Commands +``` +omni-cede serve # Start gateway + all enabled channels +omni-cede channels list # Show all channels and their health +omni-cede channels status telegram # Detailed status for one channel +omni-cede pairing list # Pending pairing requests +omni-cede pairing approve <code> # Approve a pairing request +omni-cede sessions list # All active sessions across channels +omni-cede doctor # Check config, credentials, connectivity +``` + +### API Endpoints (additions) +``` +GET /v1/channels # List channels + health +GET /v1/channels/:id/status # Detailed channel status +POST /v1/channels/:id/send # Send a message TO a channel (admin) +GET /v1/pairing # Pending pairing requests +POST /v1/pairing/:code/approve # Approve pairing +``` + +### Metrics (via stats endpoint) +- Messages processed per channel per hour +- Average response latency per channel +- Channel uptime/reconnection count +- Session count per channel + +--- + +## Implementation Order + +| Phase | What | New Files | Est. 
Lines | Priority | +|-------|------|-----------|------------|----------| +| 1a | Channel trait + types | `channels/{mod,types}.rs` | ~200 | **NOW** | +| 1b | Channel registry | `channels/registry.rs` | ~150 | **NOW** | +| 1c | Inbound/outbound pipeline | `channels/pipeline.rs` | ~300 | **NOW** | +| 1d | Hooks system | `channels/hooks.rs` | ~100 | **NOW** | +| 2a | Config system (TOML) | `config_file.rs` | ~200 | **NOW** | +| 2b | Webhook channel (generic) | `channels/webhook.rs` | ~100 | **NOW** | +| 3a | Telegram adapter | `channels/telegram.rs` | ~400 | **NEXT** | +| 3b | Discord adapter | `channels/discord.rs` | ~500 | **NEXT** | +| 3c | WebSocket WebChat | `channels/webchat.rs` | ~350 | **NEXT** | +| 3d | Slack adapter | `channels/slack.rs` | ~400 | **LATER** | +| 3e | WhatsApp Cloud API | `channels/whatsapp.rs` | ~450 | **LATER** | +| 4a | Pairing/allowlist | `channels/security.rs` | ~250 | **NEXT** | +| 4b | CLI commands | `cli/mod.rs` additions | ~200 | **NEXT** | +| 5a | Observability endpoints | `api/mod.rs` additions | ~150 | **LATER** | +| 5b | Doctor command | `cli/doctor.rs` | ~200 | **LATER** | + +**Total new code:** ~4,000 lines across ~15 files + +--- + +## Dependency Additions + +```toml +# Phase 1 (trait + pipeline) +# No new deps — uses existing tokio, serde, axum + +# Phase 2 (config) +toml = "0.8" # Config file parsing + +# Phase 3a (Telegram) +# No new deps — uses reqwest (already have it) + +# Phase 3b (Discord) +serenity = { version = "0.12", default-features = false, features = ["client", "gateway", "model"] } + +# Phase 3c (WebChat) +# No new deps — axum WebSocket support is built-in + +# Phase 3d (Slack) +# One new dep — reqwest is already present; tokio-tungstenite is added for Socket Mode +tokio-tungstenite = "0.24" # WebSocket client for Slack Socket Mode + +# Phase 3e (WhatsApp) +# No new deps for Cloud API (uses reqwest) +``` + +--- + +## Key Design Decisions + +### 1. 
Channels run inside the gateway process (like OpenClaw) +No separate sidecar processes (except WhatsApp Baileys if needed). Each channel adapter is a Rust async task managed by the `ChannelRegistry`. This keeps deployment simple — one binary, one config file. + +### 2. All channels share the same inbound pipeline +Every message, regardless of source, flows through the same normalize → identity → session → agent → outbound path. The `Channel` trait only handles platform-specific wire protocol. Business logic stays in the pipeline. + +### 3. One session per (user, channel) — cross-channel knowledge via HNSW +Same as now. A WhatsApp session and a Telegram session for the same user are separate (separate recency windows). But HNSW semantic search spans the entire graph — the agent remembers what you said on WhatsApp when you talk on Telegram. + +### 4. Feature flags via Cargo features (later) +Eventually, each channel can be a Cargo feature so you only compile what you need: +```toml +[features] +default = ["telegram", "webchat"] +telegram = [] +discord = ["dep:serenity"] +slack = ["dep:tokio-tungstenite"] +whatsapp = [] +``` + +### 5. Outbound chunking is channel-aware +Each channel has different message length limits. The outbound pipeline asks the channel for its limit and splits accordingly. OpenClaw does this per-channel — we should too. + +--- + +## What We're NOT Doing (vs OpenClaw) + +| OpenClaw feature | Our take | +|-----------------|----------| +| Voice Wake / Talk Mode | Out of scope — we're text-first | +| Canvas / A2UI | Out of scope — no visual workspace | +| Companion apps (macOS/iOS/Android) | Out of scope — server-only | +| Skills registry (ClawHub) | We have tools, not skills | +| Browser control | Out of scope | +| Cron / scheduled messages | Phase 6+ (future) | +| Sandboxed execution | Not needed — we don't run arbitrary code | +| Multi-agent routing | Future — could route channels to different agents | + +--- + +## Next Steps + +1. 
**Build Phase 1** — Channel trait, registry, pipeline, hooks (~750 lines) +2. **Build Phase 2a** — Config file system (~200 lines) +3. **Build Phase 3a** — Telegram adapter as the first real channel +4. **Test end-to-end** — Message on Telegram → agent processes → reply on Telegram +5. **Iterate** — Add Discord, WebChat, then Slack/WhatsApp diff --git a/README.md b/README.md index 064043b..8f94de6 100644 --- a/README.md +++ b/README.md @@ -1,164 +1,236 @@ -# cortex-embedded +# omni-cede -**One crate. One SQLite file. A complete AI agent with graph memory, sub-agents, and a CLI.** +**Omnichannel AI agent powered by embedded memory graphs. One API, every channel, one graph.** -Everything — identity, knowledge, tool calls, LLM calls, sub-agent work, loop iterations, self-model — is a node in the graph. The agent queries its own history the same way it queries any other knowledge. +omni-cede extends [cede](https://github.com/MikeSquared-Agency/cede) with an HTTP API, identity resolution, and per-channel session management — all backed by an embedded memory graph (single SQLite file, no external DB). Connect WhatsApp, Telegram, Slack, Discord, or any custom integration — the agent remembers across all of them because every interaction is a node in the same graph. 
-## Features +## Ecosystem -- **Graph memory** — 18 node kinds, 6 edge kinds, full provenance tracking -- **Hybrid recall** — HNSW ANN search + BFS graph traversal + trust scoring + recency decay -- **Embeddings** — BAAI/bge-small-en-v1.5 via fastembed (384-dim, runs locally) -- **Auto-link** — background task creates `RelatesTo` and `Contradicts` edges automatically -- **Decay** — importance fades over time; Soul/Belief/Goal nodes are immune -- **Trust propagation** — `Supports` edges boost trust, `Contradicts` edges reduce it -- **Context compaction** — LLM extracts key facts from long conversations into the graph -- **LLM backends** — Anthropic Claude, Ollama (local), Mock (testing) -- **Tool registry** — tools write provenance-tracked results into the graph -- **Sub-agents** — spawn into the shared graph with scoped identity -- **CLI** — chat, ask, memory search, identity management, consolidation, diagnostics +``` +cortex-embedded <-- embedded memory graph engine (upstream) + |-- cede <-- forkable starter kit + |-- omni-cede <-- you are here (omnichannel deployment) +``` + +## What omni-cede Adds + +On top of everything in cede (embedded memory graph, hybrid recall, auto-linking, decay, tools, sub-agents, TUI), omni-cede adds: + +| Layer | What it does | +|-------|-------------| +| **HTTP API** | `POST /v1/message` — send a message from any channel and get a reply | +| **Identity** | Maps `(channel, external_id)` pairs to internal user IDs. Same person on WhatsApp and Telegram = same user | +| **Sessions** | One active session per (user, channel). WhatsApp gets its own conversational flow; Telegram gets another. Semantic recall searches the global graph — cross-channel knowledge | +| **Auth** | `x-api-key` header middleware. 
Set `API_KEY` env var to enable; omit for dev mode | ## Quick Start ```bash +# Clone +git clone https://github.com/MikeSquared-Agency/omni-cede.git +cd omni-cede + # Build cargo build --release -# Initialize database and download embedding model -cortex init - -# Check graph health -cortex doctor +# Start the API server +ANTHROPIC_API_KEY=sk-ant-... omni-cede serve +# Custom host/port +omni-cede serve --host 127.0.0.1 --port 8080 +# With Ollama +omni-cede --ollama llama3 serve -# View identity -cortex soul show +# Send a message +curl -X POST http://localhost:3000/v1/message \ + -H "Content-Type: application/json" \ + -d '{"channel": "whatsapp", "external_id": "+447123456789", "text": "Hello!"}' -# Memory stats -cortex memory stats +# Health check +curl http://localhost:3000/v1/health -# Interactive chat (requires LLM) -ANTHROPIC_API_KEY=sk-ant-... cortex chat -# or with Ollama -cortex --ollama llama3 chat +# List sessions for a user +curl http://localhost:3000/v1/sessions/<user_id> -# Single query -cortex ask "What do you know about JWT tokens?" +# Stats +curl http://localhost:3000/v1/stats +``` -# Semantic search -cortex memory search "authentication" +### With Auth -# Run trust consolidation -cortex consolidate +```bash +# Start with auth enabled +API_KEY=my-secret-key ANTHROPIC_API_KEY=sk-ant-... omni-cede serve + +# Requests require the header +curl -X POST http://localhost:3000/v1/message \ + -H "Content-Type: application/json" \ + -H "x-api-key: my-secret-key" \ + -d '{"channel": "telegram", "external_id": "12345678", "text": "Hello!"}' ``` -## Architecture +## API Reference + +### `POST /v1/message` + +Send a message from any channel. The server resolves the user's identity, gets or creates a session, runs the agent, and returns the reply. +**Request:** +```json +{ + "channel": "whatsapp", + "external_id": "+447123456789", + "text": "What did we discuss yesterday?" 
+} ``` -┌─────────────────────────────────────────────┐ -│ cortex-embedded │ -├──────────┬──────────┬──────────┬────────────┤ -│ recall │ briefing │ tools │ agent │ -│ (hybrid │ (context │ (registry│ (loop + │ -│ search) │ doc) │ + trust)│ sub-agents)│ -├──────────┴──────────┴──────────┴────────────┤ -│ graph + memory │ -│ (BFS walk, scoring, decay) │ -├──────────┬──────────────────────────────────┤ -│ HNSW │ SQLite │ -│ (2-tier) │ (WAL mode, bundled rusqlite) │ -├──────────┴──────────────────────────────────┤ -│ fastembed │ -│ (BAAI/bge-small-en-v1.5) │ -└─────────────────────────────────────────────┘ + +**Response:** +```json +{ + "reply": "Yesterday we discussed the new API design...", + "user_id": "a1b2c3d4-...", + "session_id": "e5f6g7h8-..." +} ``` -### Node Kinds +### `GET /v1/health` -| Category | Kinds | -|----------|-------| -| Knowledge | `Fact`, `Entity`, `Concept`, `Decision` | -| Identity | `Soul`, `Belief`, `Goal` | -| Operational | `Session`, `Turn`, `LlmCall`, `ToolCall`, `LoopIteration` | -| Sub-agents | `SubAgent`, `Delegation`, `Synthesis` | -| Meta | `Pattern`, `Capability`, `Limitation`, `Contradiction` | +```json +{ + "status": "ok", + "version": "0.1.0" +} +``` -### Edge Kinds +### `GET /v1/sessions/:user_id` + +```json +[ + { + "session_id": "e5f6g7h8-...", + "channel": "whatsapp", + "created_at": 1711324800, + "turn_count": 42, + "last_active": 1711411200 + } +] +``` -`RelatesTo` · `Contradicts` · `Supports` · `DerivesFrom` · `PartOf` · `Supersedes` +### `GET /v1/stats` -## How It Works +```json +{ + "nodes": 1234, + "edges": 5678, + "by_kind": {"fact": 200, "soul": 1, "session": 15, "...": "..."}, + "managed_sessions": 15, + "total_turns": 342 +} +``` -Every interaction creates a provenance chain: +## How Identity Works ``` -Fact → ToolCall → LoopIteration → Session +WhatsApp +447123456789 -+ + |-> user_id: a1b2c3d4 +Telegram @johndoe -+ (linked via identity layer) ``` -The agent knows not just *what* it knows, but *how it came to know 
it*, *when*, *via which tool*, and *how much to trust it*. +When a message arrives, the identity layer: +1. Looks up `(channel, external_id)` in the `channel_mappings` table +2. If found, returns the existing internal user +3. If not, creates a new user and mapping -**Recall pipeline:** -1. Embed query → HNSW k-NN search -2. BFS graph walk from candidates -3. Score: `importance × trust × recency × proximity_bonus` -4. Return ranked nodes with contradiction warnings +You can link multiple channels to one user via the identity API. -**Background tasks:** -- **Auto-link** — new nodes are compared against the graph; similar nodes get `RelatesTo` edges, contradicting nodes get `Contradicts` edges -- **Decay** — every 60s, nodes not accessed in 24h lose importance (floor: 0.01) +## How Sessions Work -## Using as a Library +Each (user, channel) pair gets its own session. This means: -```rust -use cortex_embedded::{CortexEmbedded, types::*}; +- **Recency window is channel-scoped** — "stop using big words" on WhatsApp only affects WhatsApp's briefing +- **Semantic recall is global** — facts learned on Telegram are available when the user asks on WhatsApp +- **Sessions persist** — reconnecting to the same channel resumes the same session -#[tokio::main] -async fn main() -> Result<(), Box> { - let cx = CortexEmbedded::open("my_agent.db").await?; +## Architecture - // Store knowledge - let node = Node::new(NodeKind::Fact, "Rust is fast") - .with_body("Rust provides zero-cost abstractions and memory safety."); - cx.remember(node).await?; +``` ++---------------------------------------------+ +| omni-cede | ++-----------+-----------+---------------------+ +| HTTP API | Identity | Session Manager | +| (axum) | (channel | (one per user + | +| | mapping) | channel pair) | ++-----------+-----------+---------------------+ +| cede core | ++---------+----------+---------+--------------+ +| recall | briefing | tools | agent | +| (HNSW + | (scored | (custom | (loop + | +| graph) | 
context)| + std) | subagent) | ++---------+----------+---------+--------------+ +| graph + memory | +| (BFS, scoring, decay) | ++---------+------------------------------------+ +| HNSW | SQLite | +| (2-tier)| (WAL, bundled rusqlite) | ++---------+------------------------------------+ +| fastembed | +| (BAAI/bge-small-en-v1.5) | ++----------------------------------------------+ +``` - // Recall - let results = cx.recall("performance", RecallOptions::default()).await?; - for r in &results { - println!("[{}] {} — score: {:.3}", r.node.kind, r.node.title, r.score); - } +## CLI Commands - // Build briefing for LLM - let briefing = cx.briefing("system design", 12).await?; - println!("{}", briefing.context_doc); +omni-cede retains all of cede's CLI commands and adds `serve`: - Ok(()) -} +```bash +omni-cede serve # Start HTTP API server (0.0.0.0:3000) +omni-cede serve --port 8080 # Custom port +omni-cede chat # Interactive CLI chat +omni-cede ask "question" # Single query +omni-cede graph explore # TUI graph explorer +omni-cede graph overview # Graph visualization +omni-cede memory stats # Memory statistics +omni-cede memory search "query" # Semantic search +omni-cede soul show # View identity +omni-cede doctor # Health check +omni-cede consolidate # Trust propagation +omni-cede init # Initialize DB + download model +``` + +## Environment Variables + +| Variable | Required | Description | +|----------|----------|-------------| +| `ANTHROPIC_API_KEY` | Yes* | Anthropic API key (*or use `--ollama`) | +| `ANTHROPIC_MODEL` | No | Model override (default: `claude-sonnet-4-20250514`) | +| `API_KEY` | No | If set, requires `x-api-key` header on all requests | +| `RUST_LOG` | No | Tracing filter (default: `omni_cede=info,tower_http=info`) | + +## Staying Updated + +omni-cede tracks cede as `upstream`. 
To pull improvements: + +```bash +git fetch upstream +git merge upstream/master ``` ## Dependencies +Everything from cede, plus: + | Crate | Purpose | |-------|---------| -| `rusqlite` (bundled) | SQLite with WAL mode | -| `instant-distance` | HNSW approximate nearest neighbor search | -| `fastembed` | Local text embeddings (ONNX runtime) | -| `tokio` | Async runtime | -| `reqwest` | HTTP client for Anthropic API | -| `clap` | CLI argument parsing | -| `async-channel` | Background task communication | +| `axum` 0.8 | HTTP framework | +| `tower-http` 0.6 | CORS + request tracing middleware | +| `tracing` + `tracing-subscriber` | Structured logging | ## Tests ```bash -# Run all tests (22 total) +# Run all 22 tests cargo test -- --test-threads=1 - -# Just HNSW unit tests -cargo test --lib hnsw - -# Just integration tests -cargo test --test integration -- --test-threads=1 ``` ## License -MIT +MIT \ No newline at end of file diff --git a/agents.md b/agents.md new file mode 100644 index 0000000..ae92356 --- /dev/null +++ b/agents.md @@ -0,0 +1,140 @@ +# agents.md — Guide for AI Agents Working on omni-cede + +You are working on **omni-cede**, the omnichannel deployment variant of the cortex-embedded cognitive engine. This file tells you how to navigate the codebase and contribute effectively. 
+ +## What This Repo Is + +omni-cede extends the cortex-embedded engine with: +- **HTTP API** (axum) — stateless REST endpoints for multi-client messaging +- **Identity resolution** — maps (channel, external_id) pairs to internal user IDs +- **Session management** — one active session per (user_id, channel), automatic turn tracking + +### Ecosystem Position +- **cortex-embedded** (upstream) — the frozen engine +- **cede** — forkable starter kit (no API layer) +- **omni-cede** (this repo) — production omnichannel variant + +## Repository Layout + +``` +src/ + lib.rs # CortexEmbedded struct, background tasks, decay, consolidation + types.rs # All types: Node, Edge, NodeKind, EdgeKind, Message, LlmResponse + error.rs # CortexError enum, Result type alias + config.rs # Config struct with all tunable parameters + agent/ + mod.rs # Re-exports Agent + orchestrator.rs # Agent struct, run() and run_turn() methods, tool-call loop + subagent.rs # Sub-agent spawning and delegation + api/ + mod.rs # axum Router, POST /v1/message, GET /v1/health, sessions, stats + identity/ + mod.rs # IdentityResolver: (channel, external_id) → internal user_id + session/ + mod.rs # SessionManager: one active session per (user_id, channel) + db/ + mod.rs # Db struct (Arc<Mutex<Connection>>), async call() wrapper + schema.rs # CREATE TABLE statements, migrations + queries.rs # All SQL queries as functions + embed/ + mod.rs # EmbedHandle — fastembed wrapper with LRU cache + hnsw/ + mod.rs # VectorIndex — 2-tier HNSW (built index + linear buffer) + graph/ + mod.rs # BFS traversal, graph walk scoring + memory/ + mod.rs # recall(), briefing(), briefing_with_kinds(), recency window + tools/ + mod.rs # ToolRegistry, builtin tools + llm/ + mod.rs # LlmClient trait, AnthropicClient, OllamaClient, MockLlm + cli/ + mod.rs # CLI commands including Serve { host, port } + graph_tui.rs # Interactive TUI graph explorer + graph_viz.rs # ASCII graph visualization + bin/ + omni_cede.rs # Binary entry point, 
tracing-subscriber init 
+tests/ + integration.rs # 22 integration tests +``` + +## API Endpoints + +| Method | Path | Auth | Description | +|--------|------|------|-------------| +| POST | /v1/message | x-api-key | Send a message, get a response | +| GET | /v1/health | none | Health + node/edge counts | +| GET | /v1/sessions/:user_id | x-api-key | List sessions for a user | +| GET | /v1/stats | x-api-key | Global graph statistics | + +### POST /v1/message +```json +{ + "channel": "web", + "external_id": "user_abc", + "message": "Hello" +} +``` +Returns: +```json +{ + "response": "...", + "user_id": "internal-uuid", + "session_id": "session-uuid" +} +``` + +## Key Architecture (omni-cede-specific) + +### Identity Resolution (`src/identity/mod.rs`) +- SQLite tables: `users` (id, created_at), `channel_mappings` (channel, external_id, user_id) +- `resolve(channel, external_id)` → returns existing or creates new internal user_id +- Same external_id on different channels = different internal users + +### Session Management (`src/session/mod.rs`) +- SQLite table: `managed_sessions` +- One active session per (user_id, channel) +- `get_or_create(user_id, channel)` → session_id +- `record_turn(session_id)` → updates last_active_at, increments turn_count +- `list_user_sessions(user_id)` → all sessions across channels + +### API Layer (`src/api/mod.rs`) +- axum 0.8 Router with tower-http CORS and tracing +- Auth middleware: checks `x-api-key` header against `OMNI_CEDE_API_KEY` env var +- State: `Arc<AppState>` containing CortexEmbedded, IdentityResolver, SessionManager, Agent + +### Additional Dependencies vs cede +- `axum = "0.8"`, `tower-http = "0.6"` (cors, trace features) +- `tracing = "0.1"`, `tracing-subscriber = "0.3"` (env-filter feature) + +## Environment Variables + +| Variable | Required | Description | +|----------|----------|-------------| +| ANTHROPIC_API_KEY | Yes (unless --ollama) | Claude API key | +| OMNI_CEDE_API_KEY | Yes (for API mode) | API authentication key | +| RUST_LOG | No 
| Tracing 
filter (default: info) | + +## Build and Test + +```bash +cargo build +cargo test -- --test-threads=1 # 28 tests + +# Run the HTTP server +OMNI_CEDE_API_KEY=secret cargo run -- serve --host 0.0.0.0 --port 3000 +``` + +## Conventions + +- Async DB: `db.call(move |conn| { ... }).await` +- Embeddings: 384-dim f32 (BAAI/bge-small-en-v1.5) +- Node IDs: UUID v4 strings +- Timestamps: Unix seconds (i64) +- Error handling: `CortexError` enum, `Result` alias +- API errors: JSON `{"error": "message"}` with appropriate HTTP status + +## Branch Policy + +- `master` is protected: no direct push, PRs required +- Work on `dev` branch, merge via PR \ No newline at end of file diff --git a/claude.md b/claude.md new file mode 100644 index 0000000..2d6a14f --- /dev/null +++ b/claude.md @@ -0,0 +1,100 @@ +# claude.md — Instructions for Claude Working on omni-cede + +## Identity + +You are working on **omni-cede** — the omnichannel deployment variant of cortex-embedded, built by MikeSquared Agency. This repo adds HTTP API, identity resolution, and session management on top of the core graph-memory engine. + +## Your Role + +You are an expert Rust systems programmer with deep knowledge of async web services (axum/tower), SQLite, embedding models, and graph data structures. You build production-grade API layers. + +## Critical Rules + +1. **All DB access through `db.call()`** — the established async pattern: + ```rust + db.call(move |conn| { + // synchronous rusqlite code here + Ok(result) + }).await? + ``` +2. **Tests must pass.** `cargo test -- --test-threads=1` — 28 tests. MockLlm + in-memory SQLite. +3. **UTF-8 only.** Em dashes are `—` (U+2014), never byte 0x97 (Windows-1252). +4. **No growing message arrays.** `run_turn()` builds a fresh briefing each turn. +5. **API responses are JSON.** Errors return `{"error": "message"}` with proper HTTP status codes. +6. **Auth is required.** All mutating/data endpoints require `x-api-key` header matching `OMNI_CEDE_API_KEY` env var. 
Only `/v1/health` is public. + +## Architecture Quick Reference + +| Struct | Location | Purpose | +|--------|----------|---------| +| CortexEmbedded | lib.rs | Top-level runtime, owns all resources | +| Agent | agent/orchestrator.rs | Runs queries and chat turns | +| Db | db/mod.rs | `Arc<Mutex<Connection>>` with async wrapper | +| AppState | api/mod.rs | Shared API state (cortex, identity, session, agent) | +| IdentityResolver | identity/mod.rs | (channel, external_id) → internal user_id | +| SessionManager | session/mod.rs | One active session per (user_id, channel) | +| VectorIndex | hnsw/mod.rs | 2-tier HNSW for semantic search | +| EmbedHandle | embed/mod.rs | fastembed with LRU cache | +| Config | config.rs | All tunable parameters | + +## API Layer Details + +### Request Flow (POST /v1/message) +1. Auth middleware validates `x-api-key` +2. Parse JSON body: `{ channel, external_id, message }` +3. `IdentityResolver::resolve(channel, external_id)` → `user_id` +4. `SessionManager::get_or_create(user_id, channel)` → `session_id` +5. `Agent::run_turn(session_id, message)` → `response` +6. `SessionManager::record_turn(session_id)` → updates stats +7. Return `{ response, user_id, session_id }` + +### Adding a New Endpoint +1. Add handler function in `src/api/mod.rs` +2. Add route in the `router()` function +3. If it needs auth, nest it under the auth middleware layer +4. 
Return `Json<Value>` or use a typed response struct + +### Identity Resolution Design +- `users` table: `(id TEXT PK, created_at INTEGER)` +- `channel_mappings` table: `(channel TEXT, external_id TEXT, user_id TEXT, UNIQUE(channel, external_id))` +- Same person on Slack vs Discord = different internal user_ids (by design) +- To merge identities in the future, update channel_mappings to point to the same user_id + +### Session Management Design +- `managed_sessions` table: `(id TEXT PK, user_id TEXT, channel TEXT, created_at INTEGER, last_active_at INTEGER, turn_count INTEGER)` +- One active session per (user_id, channel) — no explicit session close +- Sessions are reused until a new one is explicitly created + +## Environment Variables + +| Variable | Required | Default | Notes | +|----------|----------|---------|-------| +| ANTHROPIC_API_KEY | Yes* | — | *Unless using --ollama | +| OMNI_CEDE_API_KEY | Yes | — | API auth key | +| RUST_LOG | No | info | Tracing filter level | + +## Dependencies (omni-cede-specific) + +- `axum = "0.8"` — HTTP framework +- `tower-http = "0.6"` (cors, trace) — middleware +- `tracing = "0.1"` — structured logging +- `tracing-subscriber = "0.3"` (env-filter) — log output + +## Style Guide + +- `thiserror` for error types +- `impl Into<String>` in public APIs +- `tracing` macros (`info!`, `warn!`, `error!`) for logging +- Functions under 50 lines +- Typed extractors in axum handlers +- `Arc<AppState>` as shared state — never clone the inner structs + +## Common Pitfalls + +- **CortexError::DbTask** — NOT `CortexError::Database` +- HNSW buffer must be flushed (`build()`) before queries see new vectors +- fastembed downloads model on first call — tests use mock embeddings +- SQLite WAL mode — one writer at a time +- `OMNI_CEDE_API_KEY` must be set or ALL authenticated endpoints return 401 +- axum 0.8 uses `axum::extract::State` — not the old Extension pattern +- CORS is permissive by default 
(tower_http::cors::CorsLayer::permissive()) — tighten for production \ No 
newline at end of file diff --git a/server_debug.log b/server_debug.log new file mode 100644 index 0000000..0a11161 Binary files /dev/null and b/server_debug.log differ diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 5a5cf3b..4b7b527 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,4 +1,3 @@ pub mod orchestrator; -pub mod subagent; pub use orchestrator::Agent; diff --git a/src/agent/orchestrator.rs b/src/agent/orchestrator.rs index 04a9f63..3aa2107 100644 --- a/src/agent/orchestrator.rs +++ b/src/agent/orchestrator.rs @@ -1,6 +1,9 @@ use std::sync::Arc; use std::time::Instant; use tokio::sync::RwLock; +use tokio::task::JoinSet; + +use base64::Engine as _; use crate::config::Config; use crate::db::Db; @@ -10,9 +13,15 @@ use crate::error::Result; use crate::hnsw::VectorIndex; use crate::llm::LlmClient; use crate::memory; +use crate::memory::format_timestamp; use crate::tools::ToolRegistry; use crate::types::*; +/// Base64-encode raw bytes for Anthropic's image block format. +fn base64_encode(data: &[u8]) -> String { + base64::engine::general_purpose::STANDARD.encode(data) +} + /// The agent. Owns the LLM client, tool registry, and a handle to the shared /// `CortexEmbedded` infrastructure (db, embed, hnsw). 
pub struct Agent { @@ -39,6 +48,26 @@ impl Agent { .await?; // Build briefing for system prompt + let now_ts = format_timestamp(crate::types::now_unix()); + + // Store user input with timestamp + let user_node = Node::new(NodeKind::UserInput, format!("[{now_ts}] {input}")) + .with_body(input) + .with_importance(0.4) + .with_decay_rate(0.02); + let user_node_id = user_node.id.clone(); + self.db + .call({ + let n = user_node; + move |conn| queries::insert_node(conn, &n) + }) + .await?; + let edge = Edge::new(user_node_id.clone(), session_id.clone(), EdgeKind::PartOf); + self.db + .call(move |conn| queries::insert_edge(conn, &edge)) + .await?; + let _ = self.auto_link_tx.try_send(user_node_id); + let brief = memory::briefing_with_kinds( &self.db, &self.embed, @@ -55,7 +84,7 @@ impl Agent { NodeKind::Capability, NodeKind::Limitation, ], - 12, + self.config.briefing_max_nodes, ) .await?; @@ -131,9 +160,11 @@ impl Agent { messages.push(Message::assistant(&response.text)); } - // Execute ALL tool calls and collect results + // Execute ALL tool calls in parallel and collect results let mut tool_results: Vec<(String, String)> = Vec::new(); - for tc in &response.tool_calls { + + if response.tool_calls.len() == 1 { + let tc = &response.tool_calls[0]; let result = self .tools .execute( @@ -145,6 +176,55 @@ impl Agent { ) .await?; tool_results.push((tc.id.clone(), result.output)); + } else { + let mut set = JoinSet::new(); + for tc in &response.tool_calls { + // Validate input before spawning parallel handler + if let Err(e) = self.tools.validate_input(&tc.name, &tc.input) { + tool_results.push(( + tc.id.clone(), + format!("Validation error: {e}"), + )); + continue; + } + let handler = self.tools.get_handler(&tc.name); + let input = tc.input.clone(); + let id = tc.id.clone(); + let name = tc.name.clone(); + if let Some(handler) = handler { + set.spawn(async move { + let result = handler(input).await; + (id, name, result) + }); + } else { + tool_results.push(( + tc.id.clone(), + 
format!("Error: unknown tool '{}'", tc.name), + )); + } + } + while let Some(res) = set.join_next().await { + match res { + Ok((id, name, Ok(result))) => { + self.tools + .record_tool_call( + &name, + &result, + iter_id.clone(), + &self.db, + &self.auto_link_tx, + ) + .await?; + tool_results.push((id, result.output)); + } + Ok((id, _name, Err(e))) => { + tool_results.push((id, format!("Tool error: {e}"))); + } + Err(e) => { + eprintln!("Tool task panicked: {e}"); + } + } + } } // Push all tool results in a single user message @@ -156,8 +236,10 @@ impl Agent { } } StopReason::EndTurn | StopReason::MaxTokens => { - // Store fact from response - let fact = Node::fact_from_response(&response.text, &session_id); + // Store fact from response with timestamp + let resp_ts = format_timestamp(crate::types::now_unix()); + let fact = Node::fact_from_response(&response.text, &session_id) + .with_body(format!("[{resp_ts}] {}", response.text)); let fact_id = fact.id.clone(); self.db .call({ @@ -210,7 +292,14 @@ impl Agent { } } - Ok("Reached iteration limit without final answer.".into()) + // Max iterations reached — ask the LLM to summarise with full context + messages.push(Message::user( + "You've reached the maximum number of iterations for this task. \ + Summarise what you accomplished so far and let the user know \ + they can ask you to continue if needed. Be concise and natural." + )); + let wrap_up = self.llm.complete(&messages).await?; + Ok(wrap_up.text) } /// Run a single turn within an ongoing chat session. @@ -220,14 +309,21 @@ impl Agent { /// turns that are relevant surface naturally), and the LLM receives only /// `[system(briefing), user(input)]` — no growing message history. /// - /// Tool-call loops use a temporary message vec within the turn. + /// **Non-blocking design**: The first LLM call runs synchronously so the + /// user always gets a fast response. 
If the LLM requests tool calls, they + /// are executed in a background `tokio::spawn` task which then continues + /// the LLM loop and stores its final answer as a `BackgroundTask` node. + /// This means the user can keep chatting while tools run. pub async fn run_turn( &self, session_id: &NodeId, input: &str, + ctx: &TurnContext, + media: Option<&crate::channels::types::MediaPayload>, ) -> Result { // 1. Store the user's input as a UserInput node in the graph - let user_node = Node::new(NodeKind::UserInput, input) + let now_ts = format_timestamp(crate::types::now_unix()); + let user_node = Node::new(NodeKind::UserInput, format!("[{now_ts}] {input}")) .with_body(input) .with_importance(0.4) .with_decay_rate(0.02); @@ -264,8 +360,6 @@ impl Agent { let _ = self.auto_link_tx.try_send(user_node_id); // 2. Build a FRESH briefing using the input as semantic query - // Prior UserInput nodes and Fact responses that are relevant will - // surface naturally through HNSW recall. let brief = memory::briefing_with_kinds( &self.db, &self.embed, @@ -283,12 +377,11 @@ impl Agent { NodeKind::Capability, NodeKind::Limitation, ], - 16, // slightly more nodes to capture conversation context + self.config.briefing_max_nodes, ) .await?; - // 3. Fetch recent session nodes (recency window) and merge any that - // the semantic search didn't already return. + // 3. 
Fetch recent session nodes (recency window) let recency_window = self.config.session_recency_window; let briefed_ids: std::collections::HashSet = brief.nodes.iter().map(|sn| sn.node.id.clone()).collect(); @@ -300,17 +393,18 @@ impl Agent { }) .await?; let mut recency_section = String::new(); - // Reverse so we go chronological (oldest first) within the section for node in recent_nodes.iter().rev() { if briefed_ids.contains(&node.id) { - continue; // already in semantic briefing + continue; } let body = node.body.as_deref().unwrap_or(&node.title); let label = match node.kind { NodeKind::UserInput => "User", _ => "Assistant", }; - recency_section.push_str(&format!("- {label}: {body}\n")); + let ts = format_timestamp(node.created_at); + let rel = memory::relative_time(node.created_at); + recency_section.push_str(&format!("- [{ts}] ({rel}) {label}: {body}\n")); } let mut context_doc = brief.context_doc; @@ -320,129 +414,444 @@ impl Agent { context_doc.push('\n'); } - // 4. Build messages — just system + user, no history - let mut messages = vec![ + // ── Channel awareness ─────────────────────────── + { + let sender = ctx.sender_name.as_deref().unwrap_or("someone"); + let where_str = if ctx.is_group { "a group chat" } else { "a direct message" }; + context_doc.push_str(&format!( + "## Current conversation\nYou are talking to **{}** via **{}** ({}).\n\n", + sender, ctx.channel, where_str, + )); + context_doc.push_str( + "When you need to use tools to fulfil a request, always include a brief, \ + natural acknowledgment in your response text so the user knows you're on it. \ + Keep it short and human — e.g. 
\"Let me look into that\" or \"Sure, one sec.\" \ + Your background workers will handle the tools and you'll be briefed on the \ + results, which will then be proactively sent to the user.\n\n", + ); + } + + // ── Pending notifications (background task results) ─ + let session_for_notif = session_id.to_string(); + let pending = self.db.call(move |conn| { + queries::get_pending_notification_nodes(conn, &session_for_notif) + }).await?; + + if !pending.is_empty() { + context_doc.push_str("## Updates while you were away\n"); + context_doc.push_str("The following background tasks finished since your last message. Mention these to the user naturally:\n"); + let mut delivered_ids: Vec = Vec::new(); + for node in &pending { + let rel = memory::relative_time(node.created_at); + context_doc.push_str(&format!("- ({}) {}\n", rel, node.title)); + delivered_ids.push(node.id.clone()); + } + context_doc.push('\n'); + + // Mark as delivered (touch increments access_count from 0 to 1+) + if !delivered_ids.is_empty() { + self.db.call(move |conn| { + queries::touch_nodes(conn, &delivered_ids) + }).await?; + } + } + + // 4. Build messages — just system + user (+ optional image), no history + let user_msg = if let Some(media) = media { + if media.kind == crate::channels::types::MediaKind::Image { + let b64 = base64_encode(&media.data); + Message::user_with_image(input, &b64, &media.mime_type) + } else { + // Non-image media: mention it in text + let label = format!("{} [attached {} file: {}]", + input, + format!("{:?}", media.kind).to_lowercase(), + media.filename.as_deref().unwrap_or("file"), + ); + Message::user(&label) + } + } else { + Message::user(input) + }; + // Clone context_doc before it's moved into messages — needed for + // the acknowledgment LLM call if the model returns tool calls with + // no accompanying text. 
+ let context_doc_for_ack = context_doc.clone(); + let messages = vec![ Message::system(context_doc), - Message::user(input), + user_msg, ]; - let mut iter: usize = 0; - - loop { - iter += 1; + // 5. First LLM call (synchronous — the user waits for this one) + let iter: usize = 1; + let iter_node = Node::loop_iteration(iter, session_id); + let iter_id = iter_node.id.clone(); + self.db + .call({ + let n = iter_node.clone(); + move |conn| queries::insert_node(conn, &n) + }) + .await?; + let edge = Edge::new(iter_id.clone(), session_id.to_string(), EdgeKind::PartOf); + self.db + .call(move |conn| queries::insert_edge(conn, &edge)) + .await?; - // Write LoopIteration node - let iter_node = Node::loop_iteration(iter, session_id); - let iter_id = iter_node.id.clone(); - self.db - .call({ - let n = iter_node.clone(); - move |conn| queries::insert_node(conn, &n) + let start = Instant::now(); + let tool_defs = self.tools.anthropic_tool_defs(); + let response = if tool_defs.is_empty() { + self.llm.complete(&messages).await? + } else { + self.llm.complete_with_tools(&messages, &tool_defs).await? 
+ }; + let latency_ms = start.elapsed().as_millis() as u64; + + // Record LlmCall node + let llm_node = Node { + kind: NodeKind::LlmCall, + title: format!("LLM call turn iter {iter}"), + body: Some( + serde_json::json!({ + "model": self.llm.model_name(), + "input_tokens": response.input_tokens, + "output_tokens": response.output_tokens, + "latency_ms": latency_ms, }) - .await?; + .to_string(), + ), + ..Node::new(NodeKind::LlmCall, format!("LLM call turn iter {iter}")) + }; + let llm_id = llm_node.id.clone(); + self.db + .call({ + let n = llm_node; + move |conn| queries::insert_node(conn, &n) + }) + .await?; + let llm_edge = Edge::new(llm_id, iter_id.clone(), EdgeKind::PartOf); + self.db + .call(move |conn| queries::insert_edge(conn, &llm_edge)) + .await?; - let edge = Edge::new(iter_id.clone(), session_id.to_string(), EdgeKind::PartOf); - self.db - .call(move |conn| queries::insert_edge(conn, &edge)) - .await?; + match response.stop_reason { + StopReason::EndTurn | StopReason::MaxTokens => { + // No tools needed — store and return immediately + let resp_ts = format_timestamp(crate::types::now_unix()); + let fact = Node::fact_from_response(&response.text, session_id) + .with_body(format!("[{resp_ts}] {}", response.text)); + let fact_id = fact.id.clone(); + self.db + .call({ + let f = fact; + move |conn| queries::insert_node(conn, &f) + }) + .await?; + let derives = Edge::new( + fact_id.clone(), + session_id.to_string(), + EdgeKind::DerivesFrom, + ); + self.db + .call(move |conn| queries::insert_edge(conn, &derives)) + .await?; + let _ = self.auto_link_tx.try_send(fact_id); + return Ok(response.text); + } + StopReason::ToolUse => { + // ── Return immediately, spawn tool execution in background ── + // Use the LLM's own natural acknowledgment text. If it sent + // tool calls with no accompanying text, make a quick LLM call + // with the full briefing to generate a natural acknowledgment. 
+ let immediate_reply = if response.text.is_empty() { + let ack_messages = vec![ + Message::system(context_doc_for_ack.clone()), + Message::user(format!( + "The user said: \"{}\"\n\n\ + You are about to use tools to handle this. \ + Write a brief, natural acknowledgment (one short sentence) \ + so the user knows you're working on it. Do NOT describe \ + what tools you'll use or what you're doing. Just a quick, \ + human acknowledgment. Stay in character.", + input + )), + ]; + match self.llm.complete(&ack_messages).await { + Ok(ack) if !ack.text.is_empty() => ack.text, + _ => response.text.clone(), + } + } else { + response.text.clone() + }; + + // Clone everything needed for the background task + let db = self.db.clone(); + let llm = self.llm.clone(); + let tool_defs = self.tools.anthropic_tool_defs(); + let tools = self.tools.clone(); + let auto_link_tx = self.auto_link_tx.clone(); + let session_id = session_id.to_string(); + let panic_session = session_id.clone(); + let pending_calls: Vec = response.tool_calls.clone(); + let raw_content = response.raw_content.clone(); + let response_text = response.text.clone(); + let max_iterations = self.config.max_iterations; + + // Spawn background task for tool execution + continuation + let handle = tokio::spawn(async move { + let result = Self::background_tool_loop( + db.clone(), + llm, + tool_defs, + tools, + pending_calls, + raw_content, + response_text, + messages, + session_id.clone(), + max_iterations, + auto_link_tx.clone(), + ).await; + + // Store the final result as a BackgroundTask node + let bg_ts = format_timestamp(crate::types::now_unix()); + let (bg_title, bg_body, notif_summary) = match &result { + Ok(text) => ( + format!("[{bg_ts}] Background task completed"), + format!("[{bg_ts}] {text}"), + format!("Background task finished: {}", Self::truncate_summary(text, 120)), + ), + Err(e) => ( + format!("[{bg_ts}] Background task failed"), + format!("[{bg_ts}] Error: {e}"), + format!("A background task ran into a 
problem: {e}"), + ), + }; + let bg_node = Node::new(NodeKind::BackgroundTask, bg_title) + .with_body(&bg_body) + .with_importance(0.6) + .with_decay_rate(0.01); + let bg_id = bg_node.id.clone(); + if let Err(e) = db.call({ + let n = bg_node; + move |conn| queries::insert_node(conn, &n) + }).await { + tracing::error!("Failed to store background task node: {e}"); + } + let edge = Edge::new(bg_id.clone(), session_id.clone(), EdgeKind::PartOf); + if let Err(e) = db.call(move |conn| queries::insert_edge(conn, &edge)).await { + tracing::error!("Failed to store background task edge: {e}"); + } + let _ = auto_link_tx.try_send(bg_id.clone()); + + // Write notification node so the user gets informed on next message + let notif_node = Node::notification(¬if_summary); + let notif_id = notif_node.id.clone(); + if let Err(e) = db.call({ + let n = notif_node; + move |conn| queries::insert_node(conn, &n) + }).await { + tracing::error!("Failed to write notification node: {e}"); + } + // Link notification → session via PartOf + let notif_edge = Edge::new(notif_id.clone(), session_id.clone(), EdgeKind::PartOf); + if let Err(e) = db.call(move |conn| queries::insert_edge(conn, ¬if_edge)).await { + tracing::error!("Failed to link notification to session: {e}"); + } + // Also link notification → background task node via DerivesFrom + let derives = Edge::new(notif_id, bg_id, EdgeKind::DerivesFrom); + if let Err(e) = db.call(move |conn| queries::insert_edge(conn, &derives)).await { + tracing::error!("Failed to link notification to bg task: {e}"); + } + + if let Err(e) = &result { + tracing::error!("Background tool loop failed: {e}"); + } + }); + + // Monitor for panics in a secondary task + let panic_db = self.db.clone(); + let panic_sid = panic_session.clone(); + tokio::spawn(async move { + if let Err(e) = handle.await { + tracing::error!("Background task panicked: {e}"); + let notif_node = Node::notification( + &format!("A background task crashed with error: {e}"), + ); + let notif_id = 
notif_node.id.clone(); + let _ = panic_db.call({ + let n = notif_node; + move |conn| queries::insert_node(conn, &n) + }).await; + let edge = Edge::new(notif_id, panic_sid, EdgeKind::PartOf); + let _ = panic_db.call(move |conn| queries::insert_edge(conn, &edge)).await; + } + }); + + return Ok(immediate_reply); + } + } + } + + /// Truncate text to `max_len` chars, adding "..." if truncated. + fn truncate_summary(text: &str, max_len: usize) -> String { + if text.len() <= max_len { + text.to_string() + } else { + format!("{}...", &text[..max_len]) + } + } + + /// Execute tool calls and continue the LLM loop in the background. + /// + /// This runs after `run_turn` has returned the first response to the user. + /// It executes all pending tool calls, feeds results back to the LLM, and + /// continues until the LLM produces a final answer (EndTurn) or hits + /// max_iterations. + async fn background_tool_loop( + db: Db, + llm: Arc, + tool_defs: Vec, + tools: ToolRegistry, + pending_calls: Vec, + raw_content: Option, + response_text: String, + mut messages: Vec, + session_id: String, + max_iterations: usize, + auto_link_tx: async_channel::Sender, + ) -> crate::error::Result { + // Push the assistant's response (with tool_use blocks) + if let Some(raw) = raw_content { + messages.push(Message::assistant_raw(raw)); + } else { + messages.push(Message::assistant(&response_text)); + } + + // Execute pending tool calls using the full registry + let tool_results = Self::execute_tool_calls(&tools, &pending_calls, &db, &auto_link_tx, &session_id).await; + + // Push tool results + Self::push_tool_results(&mut messages, tool_results); + + // Continue LLM loop + let mut iter: usize = 1; // already did iter 1 in run_turn + loop { + iter += 1; + if iter > max_iterations { + // Max iterations in background — ask the LLM to wrap up + messages.push(Message::user( + "You've reached the maximum number of iterations for this \ + background task. 
Summarise what you accomplished and what \ + remains. Be concise and natural." + )); + let wrap_up = llm.complete(&messages).await?; + return Ok(wrap_up.text); + } - // LLM call - let start = Instant::now(); - let tool_defs = self.tools.anthropic_tool_defs(); let response = if tool_defs.is_empty() { - self.llm.complete(&messages).await? + llm.complete(&messages).await? } else { - self.llm.complete_with_tools(&messages, &tool_defs).await? + llm.complete_with_tools(&messages, &tool_defs).await? }; - let latency_ms = start.elapsed().as_millis() as u64; - - // Record LlmCall node - let llm_node = Node { - kind: NodeKind::LlmCall, - title: format!("LLM call turn iter {iter}"), - body: Some( - serde_json::json!({ - "model": self.llm.model_name(), - "input_tokens": response.input_tokens, - "output_tokens": response.output_tokens, - "latency_ms": latency_ms, - }) - .to_string(), - ), - ..Node::new(NodeKind::LlmCall, format!("LLM call turn iter {iter}")) - }; - let llm_id = llm_node.id.clone(); - self.db - .call({ - let n = llm_node; - move |conn| queries::insert_node(conn, &n) - }) - .await?; - let llm_edge = Edge::new(llm_id, iter_id.clone(), EdgeKind::PartOf); - self.db - .call(move |conn| queries::insert_edge(conn, &llm_edge)) - .await?; match response.stop_reason { + StopReason::EndTurn | StopReason::MaxTokens => { + // Store result in graph + let resp_ts = format_timestamp(crate::types::now_unix()); + let fact = Node::fact_from_response(&response.text, &session_id) + .with_body(format!("[{resp_ts}] {}", response.text)); + let fact_id = fact.id.clone(); + db.call({ + let f = fact; + move |conn| queries::insert_node(conn, &f) + }).await?; + let derives = Edge::new(fact_id, session_id, EdgeKind::DerivesFrom); + db.call(move |conn| queries::insert_edge(conn, &derives)).await?; + return Ok(response.text); + } StopReason::ToolUse => { - // Tool calls stay in the temporary messages vec for this turn + // More tool calls — execute them and keep going if let Some(raw) = 
response.raw_content.clone() { messages.push(Message::assistant_raw(raw)); } else { messages.push(Message::assistant(&response.text)); } - let mut tool_results: Vec<(String, String)> = Vec::new(); - for tc in &response.tool_calls { - let result = self - .tools - .execute( - &tc.name, - tc.input.clone(), - iter_id.clone(), - &self.db, - &self.auto_link_tx, - ) - .await?; - tool_results.push((tc.id.clone(), result.output)); - } - - if tool_results.len() == 1 { - let (id, output) = tool_results.into_iter().next().unwrap(); - messages.push(Message::tool_result_block(&id, &output)); - } else { - messages.push(Message::multi_tool_result_block(tool_results)); - } + let tool_results = Self::execute_tool_calls(&tools, &response.tool_calls, &db, &auto_link_tx, &session_id).await; + Self::push_tool_results(&mut messages, tool_results); } - StopReason::EndTurn | StopReason::MaxTokens => { - // Store the response as a Fact node in the graph - let fact = Node::fact_from_response(&response.text, session_id); - let fact_id = fact.id.clone(); - self.db - .call({ - let f = fact; - move |conn| queries::insert_node(conn, &f) - }) - .await?; - let derives = Edge::new( - fact_id.clone(), - session_id.to_string(), - EdgeKind::DerivesFrom, - ); - self.db - .call(move |conn| queries::insert_edge(conn, &derives)) - .await?; - let _ = self.auto_link_tx.try_send(fact_id); + } + } + } - return Ok(response.text); + /// Execute a set of tool calls (parallel when >1) and return (id, output) pairs. 
+ async fn execute_tool_calls( + tools: &ToolRegistry, + calls: &[ToolCall], + db: &Db, + auto_link_tx: &async_channel::Sender, + session_id: &str, + ) -> Vec<(String, String)> { + let mut results: Vec<(String, String)> = Vec::new(); + + if calls.len() == 1 { + let tc = &calls[0]; + match tools.execute(&tc.name, tc.input.clone(), session_id.to_string(), db, auto_link_tx).await { + Ok(result) => { + tracing::debug!(tool=%tc.name, "background tool completed"); + results.push((tc.id.clone(), result.output)); + } + Err(e) => { + tracing::warn!(tool=%tc.name, error=%e, "background tool failed"); + results.push((tc.id.clone(), format!("Tool error: {e}"))); } } - - if iter >= self.config.max_iterations { - break; + } else { + let mut set = JoinSet::new(); + for tc in calls { + if let Err(e) = tools.validate_input(&tc.name, &tc.input) { + results.push((tc.id.clone(), format!("Validation error: {e}"))); + continue; + } + let handler = tools.get_handler(&tc.name); + let input = tc.input.clone(); + let id = tc.id.clone(); + let name = tc.name.clone(); + if let Some(handler) = handler { + set.spawn(async move { + let result = handler(input).await; + (id, name, result) + }); + } else { + results.push((tc.id.clone(), format!("Error: unknown tool '{}'", tc.name))); + } + } + while let Some(res) = set.join_next().await { + match res { + Ok((id, name, Ok(result))) => { + tracing::debug!(tool=%name, "background tool completed"); + results.push((id, result.output)); + } + Ok((id, _name, Err(e))) => { + results.push((id, format!("Tool error: {e}"))); + } + Err(e) => { + tracing::error!("Background tool task panicked: {e}"); + } + } } } - Ok("Reached iteration limit without final answer.".into()) + results + } + + /// Push tool results into the messages vec. 
+ fn push_tool_results(messages: &mut Vec, results: Vec<(String, String)>) { + if results.len() == 1 { + let (id, output) = results.into_iter().next().unwrap(); + messages.push(Message::tool_result_block(&id, &output)); + } else { + messages.push(Message::multi_tool_result_block(results)); + } } } diff --git a/src/agent/subagent.rs b/src/agent/subagent.rs deleted file mode 100644 index 907593e..0000000 --- a/src/agent/subagent.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::sync::Arc; -use tokio::sync::RwLock; - -use crate::config::Config; -use crate::db::Db; -use crate::db::queries; -use crate::embed::EmbedHandle; -use crate::error::Result; -use crate::hnsw::VectorIndex; -use crate::llm::LlmClient; -use crate::tools::ToolRegistry; -use crate::types::*; - -use super::orchestrator::Agent; - -/// Spawn a sub-agent that shares the same graph. Its work is fully -/// visible, trusted, and linked to the parent session. -pub async fn spawn_subagent( - spec: SubAgentSpec, - task: &str, - parent_session: NodeId, - db: &Db, - embed: &EmbedHandle, - hnsw: &Arc>, - config: &Config, - llm: Arc, - tools: ToolRegistry, - auto_link_tx: async_channel::Sender, -) -> Result { - // 1. Write SubAgent node - let sub_node = Node::new(NodeKind::SubAgent, &spec.name) - .with_body(format!( - "Soul: {}\nCapabilities: {}", - spec.soul, - spec.capabilities.join(", ") - )); - let sub_id = sub_node.id.clone(); - db.call({ - let n = sub_node; - move |conn| queries::insert_node(conn, &n) - }) - .await?; - - // 2. 
Write Delegation node - let deleg = Node::new(NodeKind::Delegation, format!("Delegate: {}", task)) - .with_body(task); - let deleg_id = deleg.id.clone(); - db.call({ - let n = deleg; - move |conn| queries::insert_node(conn, &n) - }) - .await?; - - // Link: Delegation → SubAgent, Delegation → parent session - let e1 = Edge::new(deleg_id.clone(), sub_id.clone(), EdgeKind::PartOf); - let e2 = Edge::new(deleg_id.clone(), parent_session.clone(), EdgeKind::PartOf); - db.call(move |conn| { - queries::insert_edge(conn, &e1)?; - queries::insert_edge(conn, &e2) - }) - .await?; - - // 3. Run sub-agent with scoped config - let sub_config = Config { - max_iterations: spec.max_iterations, - ..config.clone() - }; - - let agent = Agent { - db: db.clone(), - embed: embed.clone(), - hnsw: hnsw.clone(), - config: sub_config, - llm, - tools, - auto_link_tx: auto_link_tx.clone(), - }; - - let answer = agent.run(task).await?; - - // 4. Write Synthesis node - let synth = Node::new(NodeKind::Synthesis, format!("Synthesis: {}", spec.name)) - .with_body(&answer); - let synth_id = synth.id.clone(); - db.call({ - let n = synth; - move |conn| queries::insert_node(conn, &n) - }) - .await?; - - // Link: Synthesis → Delegation, Synthesis → parent session - let e3 = Edge::new(synth_id.clone(), deleg_id, EdgeKind::DerivesFrom); - let e4 = Edge::new(synth_id, parent_session, EdgeKind::PartOf); - db.call(move |conn| { - queries::insert_edge(conn, &e3)?; - queries::insert_edge(conn, &e4) - }) - .await?; - - Ok(SubAgentResult { - answer, - facts_created: vec![], // TODO: collect fact IDs during sub-agent run - tokens_used: 0, - }) -} diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..1c1e0ff --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,442 @@ +//! HTTP API — the omnichannel gateway. +//! +//! Provides a REST API that any messaging platform adapter can call: +//! +//! - `POST /v1/message` — send a message (resolves identity, gets/creates session, runs turn) +//! 
- `POST /v1/channels/webhook/inbound` — generic webhook inbound (pipeline-routed) +//! - `GET /v1/channels` — list channels and their health +//! - `GET /v1/health` — liveness check +//! - `GET /v1/sessions/:user_id` — list sessions for a user +//! - `GET /v1/stats` — graph + session statistics +//! - `GET /v1/ws/chat` — WebSocket chat endpoint +//! +//! Authentication is via an `x-api-key` header checked against the `API_KEY` env var. +//! If `API_KEY` is not set, authentication is disabled (development mode). + +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use axum::{ + Json, Router, + extract::{Path, Query, State, ws::{WebSocket, WebSocketUpgrade, Message as WsMessage}}, + http::{HeaderMap, StatusCode}, + middleware::{self, Next}, + response::IntoResponse, + routing::{get, post}, +}; +use serde::{Deserialize, Serialize}; +use tower_http::cors::CorsLayer; +use tower_http::trace::TraceLayer; + +use crate::agent::orchestrator::Agent; +use crate::channels::pipeline::Pipeline; +use crate::channels::types::InboundEnvelope; +use crate::channels::webchat::{WsChatMessage, WsChatReply}; +use crate::channels::registry::ChannelRegistry; +use crate::session; +use crate::CortexEmbedded; + +// ─── Shared state ─────────────────────────────────────── + +/// Application state shared across all request handlers. +pub struct AppState { + pub cx: CortexEmbedded, + pub agent: Agent, + pub api_key: Option, + /// The omnichannel pipeline (identity → session → hooks → agent → outbound). + pub pipeline: Arc, + /// Channel registry for health/status queries. + pub registry: Arc, + /// WebChat active connection counter (shared with WebChatChannel). + pub webchat_counter: Arc, + /// WebChat maximum concurrent connections. + pub webchat_max: usize, +} + +// ─── Request / Response types ─────────────────────────── + +#[derive(Debug, Deserialize)] +pub struct MessageRequest { + /// Channel identifier, e.g. 
"whatsapp", "telegram", "api", "cli". + pub channel: String, + /// The external user ID on that channel (phone number, chat id, etc.). + pub external_id: String, + /// The user's message text. + pub text: String, +} + +#[derive(Debug, Serialize)] +pub struct MessageResponse { + /// The agent's reply. + pub reply: String, + /// Internal user ID (for follow-up requests). + pub user_id: String, + /// Session ID (graph node id used for this conversation). + pub session_id: String, +} + +#[derive(Debug, Serialize)] +pub struct HealthResponse { + pub status: &'static str, + pub version: &'static str, +} + +#[derive(Debug, Serialize)] +pub struct StatsResponse { + pub nodes: i64, + pub edges: i64, + pub by_kind: std::collections::HashMap, + pub managed_sessions: i64, + pub total_turns: i64, +} + +#[derive(Debug, Serialize)] +pub struct SessionInfo { + pub session_id: String, + pub channel: String, + pub created_at: i64, + pub turn_count: i64, + pub last_active: i64, +} + +#[derive(Debug, Serialize)] +pub struct ErrorResponse { + pub error: String, +} + +/// Webhook inbound request — superset of MessageRequest with optional fields. +#[derive(Debug, Deserialize)] +pub struct WebhookInboundRequest { + pub channel: Option, + pub external_id: String, + pub text: String, + #[serde(default)] + pub sender_name: Option, + #[serde(default)] + pub callback_url: Option, + #[serde(default)] + pub group_id: Option, +} + +#[derive(Debug, Serialize)] +pub struct ChannelStatusResponse { + pub id: String, + pub health: crate::channels::types::ChannelHealth, +} + +// ─── Router ───────────────────────────────────────────── + +/// Build the axum `Router` with all routes and middleware. 
+pub fn router(state: Arc) -> Router { + Router::new() + // Core messaging endpoints + .route("/v1/message", post(handle_message)) + .route("/v1/channels/webhook/inbound", post(handle_webhook_inbound)) + // Session / stats endpoints + .route("/v1/sessions/{user_id}", get(handle_sessions)) + .route("/v1/stats", get(handle_stats)) + // Channel management + .route("/v1/channels", get(handle_channels)) + // Auth middleware on all of the above + .layer(middleware::from_fn_with_state(state.clone(), auth_middleware)) + // Public endpoints (auth via query param for WS) + .route("/v1/health", get(handle_health)) + .route("/v1/ws/chat", get(handle_ws_upgrade)) + // Cross-cutting middleware + .layer(CorsLayer::permissive()) + .layer(TraceLayer::new_for_http()) + .with_state(state) +} + +// ─── Auth middleware ──────────────────────────────────── + +async fn auth_middleware( + State(state): State>, + headers: HeaderMap, + request: axum::extract::Request, + next: Next, +) -> impl IntoResponse { + // If no API_KEY is set, skip auth (dev mode) + let Some(ref expected) = state.api_key else { + return next.run(request).await.into_response(); + }; + + match headers.get("x-api-key").and_then(|v| v.to_str().ok()) { + Some(key) if key == expected => next.run(request).await.into_response(), + _ => ( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Invalid or missing x-api-key header".into(), + }), + ) + .into_response(), + } +} + +// ─── Handlers ─────────────────────────────────────────── + +async fn handle_health() -> Json { + Json(HealthResponse { + status: "ok", + version: env!("CARGO_PKG_VERSION"), + }) +} + +/// Original message handler — uses the pipeline for processing. 
+async fn handle_message( + State(state): State>, + Json(req): Json, +) -> impl IntoResponse { + let envelope = InboundEnvelope::new(&req.channel, &req.external_id, &req.text); + + match state.pipeline.process_sync(envelope, &state.cx.db, &state.agent).await { + Ok(result) => ( + StatusCode::OK, + Json(MessageResponse { + reply: result.reply, + user_id: result.user_id, + session_id: result.session_id, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: format!("{e}"), + }), + ) + .into_response(), + } +} + +/// Webhook inbound — generic webhook channel messages. +async fn handle_webhook_inbound( + State(state): State>, + Json(req): Json, +) -> impl IntoResponse { + let mut envelope = InboundEnvelope::new( + req.channel.as_deref().unwrap_or("webhook"), + &req.external_id, + &req.text, + ); + envelope.sender_name = req.sender_name; + envelope.callback_url = req.callback_url; + envelope.group_id = req.group_id; + + match state.pipeline.process(envelope, &state.cx.db, &state.agent).await { + Ok(result) => ( + StatusCode::OK, + Json(MessageResponse { + reply: result.reply, + user_id: result.user_id, + session_id: result.session_id, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: format!("{e}"), + }), + ) + .into_response(), + } +} + +/// List all channels and their health status. 
+async fn handle_channels(State(state): State>) -> impl IntoResponse { + let health_list = state.registry.health_all().await; + let statuses: Vec = health_list + .into_iter() + .map(|(id, health)| ChannelStatusResponse { id, health }) + .collect(); + (StatusCode::OK, Json(statuses)).into_response() +} + +async fn handle_sessions( + State(state): State>, + Path(user_id): Path, +) -> impl IntoResponse { + match session::list_user_sessions(&state.cx.db, &user_id).await { + Ok(sessions) => { + let infos: Vec = sessions + .into_iter() + .map(|s| SessionInfo { + session_id: s.node_id, + channel: s.channel, + created_at: s.created_at, + turn_count: s.turn_count, + last_active: s.last_active, + }) + .collect(); + (StatusCode::OK, Json(infos)).into_response() + } + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: format!("Failed to list sessions: {e}"), + }), + ) + .into_response(), + } +} + +async fn handle_stats(State(state): State>) -> impl IntoResponse { + let graph_stats = match state.cx.stats().await { + Ok(s) => s, + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: format!("Failed to get graph stats: {e}"), + }), + ) + .into_response(); + } + }; + + let (managed_sessions, total_turns) = match session::stats(&state.cx.db).await { + Ok(s) => s, + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: format!("Failed to get session stats: {e}"), + }), + ) + .into_response(); + } + }; + + ( + StatusCode::OK, + Json(StatsResponse { + nodes: graph_stats.0, + edges: graph_stats.1, + by_kind: graph_stats.2, + managed_sessions, + total_turns, + }), + ) + .into_response() +} + +// ─── WebSocket chat ──────────────────────────────────── + +async fn handle_ws_upgrade( + State(state): State>, + Query(params): Query>, + ws: WebSocketUpgrade, +) -> impl IntoResponse { + // Auth check: if API_KEY is set, verify ?api_key= query param + if let Some(ref expected) = 
state.api_key { + match params.get("api_key") { + Some(k) if k == expected => {} + _ => { + return ( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Invalid or missing api_key query parameter".into(), + }), + ) + .into_response(); + } + } + } + + // Check connection limit + let current = state.webchat_counter.load(Ordering::Relaxed); + if current >= state.webchat_max { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ErrorResponse { + error: format!("WebChat connection limit reached ({}/{})", current, state.webchat_max), + }), + ) + .into_response(); + } + + ws.on_upgrade(move |socket| handle_ws_connection(socket, state)) + .into_response() +} + +async fn handle_ws_connection(socket: WebSocket, state: Arc) { + use futures::stream::StreamExt; + use futures::sink::SinkExt; + + let (mut sender, mut receiver) = socket.split(); + let session_token = uuid::Uuid::new_v4().to_string(); + + // Increment connection counter + state.webchat_counter.fetch_add(1, Ordering::Relaxed); + + // Send connected message + let connected = WsChatReply::connected(&session_token); + if let Ok(json) = serde_json::to_string(&connected) { + let _ = sender.send(WsMessage::Text(json.into())).await; + } + + tracing::info!(session_token = %session_token, "webchat client connected"); + + // Process messages + while let Some(Ok(msg)) = receiver.next().await { + let text = match msg { + WsMessage::Text(t) => t.to_string(), + WsMessage::Close(_) => break, + _ => continue, + }; + + // Parse the client message + let chat_msg: WsChatMessage = match serde_json::from_str(&text) { + Ok(m) => m, + Err(e) => { + let err = WsChatReply::error(&format!("Invalid message format: {e}")); + if let Ok(json) = serde_json::to_string(&err) { + let _ = sender.send(WsMessage::Text(json.into())).await; + } + continue; + } + }; + + // Send typing indicator + let typing = WsChatReply::typing(); + if let Ok(json) = serde_json::to_string(&typing) { + let _ = sender.send(WsMessage::Text(json.into())).await; + } 
+ + // Use session_token from message or from connection + let token = chat_msg + .session_token + .as_deref() + .unwrap_or(&session_token); + + // Create inbound envelope and process through pipeline + let envelope = InboundEnvelope::new("webchat", token, &chat_msg.text); + + match state + .pipeline + .process_sync(envelope, &state.cx.db, &state.agent) + .await + { + Ok(result) => { + let reply = WsChatReply::reply(&result.reply); + if let Ok(json) = serde_json::to_string(&reply) { + if sender.send(WsMessage::Text(json.into())).await.is_err() { + break; + } + } + } + Err(e) => { + let err = WsChatReply::error(&format!("Agent error: {e}")); + if let Ok(json) = serde_json::to_string(&err) { + let _ = sender.send(WsMessage::Text(json.into())).await; + } + } + } + } + + // Decrement connection counter + state.webchat_counter.fetch_sub(1, Ordering::Relaxed); + tracing::info!(session_token = %session_token, "webchat client disconnected"); +} diff --git a/src/bin/cede.rs b/src/bin/cede.rs deleted file mode 100644 index eae5c6f..0000000 --- a/src/bin/cede.rs +++ /dev/null @@ -1,7 +0,0 @@ -#[tokio::main] -async fn main() { - if let Err(e) = cede::cli::run().await { - eprintln!("Error: {e}"); - std::process::exit(1); - } -} diff --git a/src/bin/omni_cede.rs b/src/bin/omni_cede.rs new file mode 100644 index 0000000..bcfa73e --- /dev/null +++ b/src/bin/omni_cede.rs @@ -0,0 +1,15 @@ +#[tokio::main] +async fn main() { + // Initialize tracing (for tower-http TraceLayer) + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "omni_cede=info,tower_http=info".parse().unwrap()), + ) + .init(); + + if let Err(e) = omni_cede::cli::run().await { + eprintln!("Error: {e}"); + std::process::exit(1); + } +} \ No newline at end of file diff --git a/src/browser/cdp.rs b/src/browser/cdp.rs new file mode 100644 index 0000000..6716364 --- /dev/null +++ b/src/browser/cdp.rs @@ -0,0 +1,488 @@ +//! 
Chrome DevTools Protocol (CDP) client over WebSocket. +//! +//! Communicates with a running Chrome instance via its debugging WebSocket. +//! Supports navigation, DOM queries, JavaScript evaluation, screenshots, +//! input events, and network interception. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; + +use futures::stream::{SplitSink, SplitStream}; +use futures::{SinkExt, StreamExt}; +use serde_json::Value; +use tokio::net::TcpStream; +use tokio::sync::{Mutex, RwLock, oneshot}; +use tokio_tungstenite::tungstenite::Message as WsMessage; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; + +type WsWriter = SplitSink>, WsMessage>; +type WsReader = SplitStream>>; + +/// A CDP session wrapping a WebSocket connection to Chrome. +pub struct BrowserSession { + writer: Arc>, + /// Pending request callbacks keyed by message ID. + pending: Arc>>>, + /// Monotonic message counter. + next_id: AtomicU64, + /// Chrome process handle (if we spawned it). + _chrome: Option, + /// Event listeners keyed by method name. + event_listeners: Arc>>>>, +} + +/// Response from calling navigate. +#[derive(Debug, Clone)] +pub struct NavigateResult { + pub frame_id: String, + pub loader_id: Option, +} + +/// A compact representation of a page element. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct PageElement { + pub tag: String, + pub text: String, + pub attributes: HashMap, + pub selector: String, +} + +impl BrowserSession { + /// Launch Chrome with remote debugging and connect via CDP. + /// + /// `chrome_path` — path to Chrome executable (None = auto-detect). + /// `port` — debugging port (default 9222). + /// `headless` — whether to run headless. 
+ pub async fn launch( + chrome_path: Option<&str>, + port: u16, + headless: bool, + ) -> Result { + let chrome = find_chrome(chrome_path)?; + + let mut args = vec![ + format!("--remote-debugging-port={port}"), + "--no-first-run".to_string(), + "--no-default-browser-check".to_string(), + "--disable-background-networking".to_string(), + "--disable-component-update".to_string(), + "--disable-features=TranslateUI".to_string(), + ]; + + if headless { + args.push("--headless=new".to_string()); + } + + // Apply stealth flags + args.extend(super::stealth::chrome_flags()); + + let child = tokio::process::Command::new(&chrome) + .args(&args) + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::piped()) + .spawn() + .map_err(|e| format!("failed to launch Chrome: {e}"))?; + + // Wait for the debugger to come up + let ws_url = wait_for_debugger(port, 15).await?; + + let mut session = Self::connect(&ws_url).await?; + session._chrome = Some(child); + + Ok(session) + } + + /// Connect to an already-running Chrome debugger at the given WebSocket URL. + pub async fn connect(ws_url: &str) -> Result { + let (ws, _) = tokio_tungstenite::connect_async(ws_url) + .await + .map_err(|e| format!("CDP connect failed: {e}"))?; + + let (writer, reader) = ws.split(); + let pending: Arc>>> = + Arc::new(RwLock::new(HashMap::new())); + let event_listeners: Arc>>>> = + Arc::new(RwLock::new(HashMap::new())); + + let session = Self { + writer: Arc::new(Mutex::new(writer)), + pending: pending.clone(), + next_id: AtomicU64::new(1), + _chrome: None, + event_listeners: event_listeners.clone(), + }; + + // Spawn reader task + tokio::spawn(Self::reader_loop(reader, pending, event_listeners)); + + Ok(session) + } + + /// Background loop that reads CDP responses and events. 
+ async fn reader_loop( + mut reader: WsReader, + pending: Arc>>>, + event_listeners: Arc>>>>, + ) { + while let Some(Ok(msg)) = reader.next().await { + if let WsMessage::Text(text) = msg { + if let Ok(json) = serde_json::from_str::(&text) { + // CDP response (has "id" field) + if let Some(id) = json["id"].as_u64() { + let mut map = pending.write().await; + if let Some(tx) = map.remove(&id) { + let _ = tx.send(json); + } + } + // CDP event (has "method" field, no "id") + else if let Some(method) = json["method"].as_str() { + let listeners = event_listeners.read().await; + if let Some(senders) = listeners.get(method) { + let params = json["params"].clone(); + for tx in senders { + let _ = tx.try_send(params.clone()); + } + } + } + } + } + } + } + + /// Send a CDP command and wait for the response. + pub async fn send(&self, method: &str, params: Value) -> Result { + let id = self.next_id.fetch_add(1, Ordering::SeqCst); + + let msg = serde_json::json!({ + "id": id, + "method": method, + "params": params, + }); + + let (tx, rx) = oneshot::channel(); + { + let mut map = self.pending.write().await; + map.insert(id, tx); + } + + { + let mut writer = self.writer.lock().await; + writer + .send(WsMessage::Text(msg.to_string())) + .await + .map_err(|e| format!("CDP send error: {e}"))?; + } + + let response = tokio::time::timeout(std::time::Duration::from_secs(30), rx) + .await + .map_err(|_| "CDP response timeout (30s)".to_string())? + .map_err(|_| "CDP response channel closed".to_string())?; + + if let Some(err) = response.get("error") { + return Err(format!("CDP error: {}", err)); + } + + Ok(response.get("result").cloned().unwrap_or(Value::Null)) + } + + /// Subscribe to a CDP event. Returns a receiver that yields event params. 
+ pub async fn on_event(&self, method: &str) -> tokio::sync::mpsc::Receiver { + let (tx, rx) = tokio::sync::mpsc::channel(64); + let mut listeners = self.event_listeners.write().await; + listeners + .entry(method.to_string()) + .or_default() + .push(tx); + rx + } + + // ─── High-level helpers ─────────────────────────────── + + /// Navigate to a URL and wait for load. + pub async fn navigate(&self, url: &str) -> Result { + // Enable Page domain + self.send("Page.enable", serde_json::json!({})).await?; + + let result = self + .send("Page.navigate", serde_json::json!({ "url": url })) + .await?; + + let frame_id = result["frameId"] + .as_str() + .unwrap_or("") + .to_string(); + let loader_id = result["loaderId"].as_str().map(String::from); + + // Wait for loadEventFired + let mut rx = self.on_event("Page.loadEventFired").await; + let _ = tokio::time::timeout( + std::time::Duration::from_secs(30), + rx.recv(), + ) + .await; + + Ok(NavigateResult { + frame_id, + loader_id, + }) + } + + /// Get the current page URL. + pub async fn current_url(&self) -> Result { + let result = self + .send( + "Runtime.evaluate", + serde_json::json!({ + "expression": "window.location.href", + "returnByValue": true, + }), + ) + .await?; + Ok(result["result"]["value"] + .as_str() + .unwrap_or("") + .to_string()) + } + + /// Evaluate JavaScript and return the result as a string. + pub async fn evaluate(&self, expression: &str) -> Result { + let result = self + .send( + "Runtime.evaluate", + serde_json::json!({ + "expression": expression, + "returnByValue": true, + "awaitPromise": true, + }), + ) + .await?; + + if let Some(exc) = result.get("exceptionDetails") { + return Err(format!("JS error: {}", exc)); + } + + Ok(result["result"]["value"].clone()) + } + + /// Click on an element matching a CSS selector. 
+ pub async fn click(&self, selector: &str) -> Result<(), String> { + let js = format!( + r#"(() => {{ + const el = document.querySelector({sel}); + if (!el) return 'NOT_FOUND'; + el.click(); + return 'OK'; + }})()"#, + sel = serde_json::to_string(selector).unwrap(), + ); + let result = self.evaluate(&js).await?; + if result.as_str() == Some("NOT_FOUND") { + return Err(format!("element not found: {selector}")); + } + Ok(()) + } + + /// Type text into the focused element, character by character. + pub async fn type_text(&self, text: &str) -> Result<(), String> { + for ch in text.chars() { + self.send( + "Input.dispatchKeyEvent", + serde_json::json!({ + "type": "keyDown", + "text": ch.to_string(), + }), + ) + .await?; + self.send( + "Input.dispatchKeyEvent", + serde_json::json!({ + "type": "keyUp", + "text": ch.to_string(), + }), + ) + .await?; + } + Ok(()) + } + + /// Fill an input element matching a selector with text. + pub async fn fill(&self, selector: &str, text: &str) -> Result<(), String> { + // Focus the element + let focus_js = format!( + r#"(() => {{ + const el = document.querySelector({sel}); + if (!el) return 'NOT_FOUND'; + el.focus(); + el.value = ''; + return 'OK'; + }})()"#, + sel = serde_json::to_string(selector).unwrap(), + ); + let result = self.evaluate(&focus_js).await?; + if result.as_str() == Some("NOT_FOUND") { + return Err(format!("element not found: {selector}")); + } + self.type_text(text).await?; + + // Trigger input event + let trigger_js = format!( + r#"(() => {{ + const el = document.querySelector({sel}); + if (el) {{ + el.dispatchEvent(new Event('input', {{ bubbles: true }})); + el.dispatchEvent(new Event('change', {{ bubbles: true }})); + }} + }})()"#, + sel = serde_json::to_string(selector).unwrap(), + ); + self.evaluate(&trigger_js).await?; + Ok(()) + } + + /// Take a screenshot (PNG), return as base64. 
+ pub async fn screenshot(&self) -> Result { + let result = self + .send( + "Page.captureScreenshot", + serde_json::json!({ "format": "png" }), + ) + .await?; + result["data"] + .as_str() + .map(String::from) + .ok_or_else(|| "no screenshot data".to_string()) + } + + /// Get a compact text snapshot of the page DOM. + pub async fn snapshot(&self) -> Result, String> { + super::snapshot::take_snapshot(self).await + } + + /// Get page HTML content. + pub async fn get_html(&self) -> Result { + let result = self + .send( + "Runtime.evaluate", + serde_json::json!({ + "expression": "document.documentElement.outerHTML", + "returnByValue": true, + }), + ) + .await?; + Ok(result["result"]["value"] + .as_str() + .unwrap_or("") + .to_string()) + } + + /// Wait for a selector to appear in the DOM, with timeout. + pub async fn wait_for_selector( + &self, + selector: &str, + timeout_ms: u64, + ) -> Result<(), String> { + let start = std::time::Instant::now(); + let timeout = std::time::Duration::from_millis(timeout_ms); + + loop { + let js = format!( + "document.querySelector({}) !== null", + serde_json::to_string(selector).unwrap(), + ); + let result = self.evaluate(&js).await?; + if result.as_bool() == Some(true) { + return Ok(()); + } + if start.elapsed() > timeout { + return Err(format!( + "timeout waiting for selector '{selector}' ({timeout_ms}ms)" + )); + } + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + } + } + + /// Scroll the page by (x, y) pixels. + pub async fn scroll(&self, x: i32, y: i32) -> Result<(), String> { + self.evaluate(&format!("window.scrollBy({x}, {y})")).await?; + Ok(()) + } + + /// Get all cookies. + pub async fn get_cookies(&self) -> Result { + self.send("Network.getCookies", serde_json::json!({})).await + } + + /// Set a cookie. + pub async fn set_cookie(&self, cookie: Value) -> Result<(), String> { + self.send("Network.setCookie", cookie).await?; + Ok(()) + } + + /// Close the browser session. 
+ pub async fn close(&self) -> Result<(), String> { + let _ = self.send("Browser.close", serde_json::json!({})).await; + Ok(()) + } +} + +// ─── Chrome discovery ─────────────────────────────────── + +fn find_chrome(explicit: Option<&str>) -> Result { + if let Some(p) = explicit { + return Ok(p.to_string()); + } + + let candidates = if cfg!(target_os = "windows") { + vec![ + r"C:\Program Files\Google\Chrome\Application\chrome.exe", + r"C:\Program Files (x86)\Google\Chrome\Application\chrome.exe", + ] + } else if cfg!(target_os = "macos") { + vec!["/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"] + } else { + vec![ + "/usr/bin/google-chrome", + "/usr/bin/google-chrome-stable", + "/usr/bin/chromium", + "/usr/bin/chromium-browser", + ] + }; + + for c in &candidates { + if std::path::Path::new(c).exists() { + return Ok(c.to_string()); + } + } + + Err("Chrome not found. Set chrome_path explicitly or install Chrome.".to_string()) +} + +/// Poll the Chrome debugger endpoint until it responds with a WebSocket URL. +async fn wait_for_debugger(port: u16, max_secs: u64) -> Result { + let url = format!("http://127.0.0.1:{port}/json/version"); + let client = reqwest::Client::new(); + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(max_secs); + + loop { + if std::time::Instant::now() > deadline { + return Err(format!( + "Chrome debugger did not respond on port {port} within {max_secs}s" + )); + } + + match client.get(&url).send().await { + Ok(resp) => { + if let Ok(json) = resp.json::().await { + if let Some(ws) = json["webSocketDebuggerUrl"].as_str() { + return Ok(ws.to_string()); + } + } + } + Err(_) => {} + } + + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + } +} diff --git a/src/browser/mod.rs b/src/browser/mod.rs new file mode 100644 index 0000000..18a000d --- /dev/null +++ b/src/browser/mod.rs @@ -0,0 +1,28 @@ +//! Browser automation module (feature-gated behind `browser`). +//! +//! Provides: +//! 
- CDP (Chrome DevTools Protocol) client over WebSocket +//! - Compact page snapshots (DOM → structured text) +//! - Stealth / anti-detection helpers +//! - WebMCP client (Chrome's agent-website structured tool API) +//! - Stored tool definitions and workflow engine +//! - Browser tools registered in the agent's ToolRegistry + +#[cfg(feature = "browser")] +pub mod cdp; +#[cfg(feature = "browser")] +pub mod snapshot; +#[cfg(feature = "browser")] +pub mod stealth; +#[cfg(feature = "browser")] +pub mod webmcp; +#[cfg(feature = "browser")] +pub mod store; +#[cfg(feature = "browser")] +pub mod workflow; +#[cfg(feature = "browser")] +pub mod tools; + +/// Re-export the browser session for convenience. +#[cfg(feature = "browser")] +pub use cdp::BrowserSession; diff --git a/src/browser/snapshot.rs b/src/browser/snapshot.rs new file mode 100644 index 0000000..ff02223 --- /dev/null +++ b/src/browser/snapshot.rs @@ -0,0 +1,131 @@ +//! Compact page snapshot — extract structured text from the DOM. +//! +//! Inspired by Xbot's approach: instead of raw HTML, produce a compact +//! representation that an LLM can reason about efficiently. + +use super::cdp::{BrowserSession, PageElement}; + +/// JavaScript injected into the page to extract a compact DOM snapshot. +/// +/// Returns a JSON array of `{ tag, text, attributes, selector }` objects +/// for all interactive and content-bearing elements. 
+const SNAPSHOT_JS: &str = r#" +(() => { + const INTERACTIVE = new Set([ + 'A', 'BUTTON', 'INPUT', 'TEXTAREA', 'SELECT', 'DETAILS', 'SUMMARY' + ]); + const CONTENT = new Set([ + 'H1','H2','H3','H4','H5','H6','P','LI','TD','TH','LABEL','SPAN', + 'STRONG','EM','CODE','PRE','BLOCKQUOTE','FIGCAPTION','ARTICLE' + ]); + const SKIP = new Set(['SCRIPT','STYLE','NOSCRIPT','SVG','PATH','META','LINK','BR','HR']); + + const results = []; + const seen = new Set(); + const MAX = 300; + + function cssSelector(el) { + if (el.id) return '#' + CSS.escape(el.id); + let path = ''; + while (el && el !== document.body) { + let seg = el.tagName.toLowerCase(); + if (el.id) { seg = '#' + CSS.escape(el.id); path = seg + (path ? ' > ' + path : ''); break; } + const parent = el.parentElement; + if (parent) { + const siblings = Array.from(parent.children).filter(c => c.tagName === el.tagName); + if (siblings.length > 1) { + seg += ':nth-of-type(' + (siblings.indexOf(el) + 1) + ')'; + } + } + path = seg + (path ? ' > ' + path : ''); + el = parent; + } + return path || 'body'; + } + + function walk(node) { + if (results.length >= MAX) return; + if (node.nodeType !== 1) return; + const tag = node.tagName; + if (SKIP.has(tag)) return; + if (node.offsetParent === null && tag !== 'BODY' && tag !== 'HTML') return; // hidden + + const isInteractive = INTERACTIVE.has(tag) || node.hasAttribute('role') || + node.hasAttribute('onclick') || node.hasAttribute('tabindex'); + const isContent = CONTENT.has(tag); + const text = (node.innerText || node.value || node.placeholder || '').trim().slice(0, 200); + + if ((isInteractive || isContent) && text.length > 0 && !seen.has(text)) { + seen.add(text); + const attrs = {}; + for (const a of ['href','type','name','aria-label','role','placeholder','value','alt','title','action']) { + const v = node.getAttribute(a); + if (v) attrs[a] = v.slice(0, 100); + } + results.push({ + tag: tag.toLowerCase(), + text: text, + attributes: attrs, + selector: cssSelector(node), + 
}); + } + + for (const child of node.children) walk(child); + } + + walk(document.body); + return JSON.stringify(results); +})() +"#; + +/// Take a compact snapshot of the current page. +/// +/// Returns a list of `PageElement` structs representing interactive and +/// content-bearing elements visible on the page. +pub async fn take_snapshot(session: &BrowserSession) -> Result, String> { + let result = session.evaluate(SNAPSHOT_JS).await?; + + let json_str = result.as_str().unwrap_or("[]"); + let elements: Vec = serde_json::from_str(json_str).unwrap_or_default(); + + Ok(elements) +} + +/// Format a snapshot into a compact text representation for the LLM. +pub fn format_snapshot(elements: &[PageElement]) -> String { + if elements.is_empty() { + return "(empty page)".to_string(); + } + + let mut out = String::with_capacity(elements.len() * 80); + for (i, el) in elements.iter().enumerate() { + let attrs: Vec = el + .attributes + .iter() + .map(|(k, v)| format!("{k}={v}")) + .collect(); + let attr_str = if attrs.is_empty() { + String::new() + } else { + format!(" [{}]", attrs.join(", ")) + }; + + out.push_str(&format!( + "[{i}] <{tag}>{attr} \"{text}\" → {sel}\n", + tag = el.tag, + attr = attr_str, + text = truncate(&el.text, 120), + sel = el.selector, + )); + } + out +} + +fn truncate(s: &str, max: usize) -> String { + if s.chars().count() <= max { + s.to_string() + } else { + let t: String = s.chars().take(max).collect(); + format!("{t}…") + } +} diff --git a/src/browser/stealth.rs b/src/browser/stealth.rs new file mode 100644 index 0000000..e1b84ed --- /dev/null +++ b/src/browser/stealth.rs @@ -0,0 +1,93 @@ +//! Anti-detection / stealth helpers. +//! +//! Provides Chrome launch flags and JavaScript patches to reduce +//! bot detection fingerprints. Inspired by Xbot's approach. + +/// Chrome command-line flags that reduce automation fingerprints. 
+pub fn chrome_flags() -> Vec { + vec![ + "--disable-blink-features=AutomationControlled".to_string(), + "--disable-infobars".to_string(), + "--disable-dev-shm-usage".to_string(), + "--disable-extensions".to_string(), + "--disable-gpu".to_string(), + "--no-sandbox".to_string(), + "--disable-setuid-sandbox".to_string(), + "--window-size=1920,1080".to_string(), + "--start-maximized".to_string(), + ] +} + +/// JavaScript patches injected early to hide automation signals. +pub const STEALTH_JS: &str = r#" +(() => { + // Remove webdriver flag + Object.defineProperty(navigator, 'webdriver', { get: () => false }); + + // Mock plugins (headless Chrome has none) + Object.defineProperty(navigator, 'plugins', { + get: () => { + const p = { length: 3 }; + p[0] = { name: 'Chrome PDF Plugin', description: 'Portable Document Format', filename: 'internal-pdf-viewer' }; + p[1] = { name: 'Chrome PDF Viewer', description: '', filename: 'mhjfbmdgcfjbbpaeojofohoefgiehjai' }; + p[2] = { name: 'Native Client', description: '', filename: 'internal-nacl-plugin' }; + return p; + } + }); + + // Mock languages + Object.defineProperty(navigator, 'languages', { get: () => ['en-US', 'en'] }); + + // Prevent detection via permissions API + const originalQuery = window.navigator.permissions.query; + window.navigator.permissions.query = (parameters) => { + if (parameters.name === 'notifications') { + return Promise.resolve({ state: Notification.permission }); + } + return originalQuery(parameters); + }; + + // Chrome runtime mock (missing in headless) + if (!window.chrome) { + window.chrome = {}; + } + if (!window.chrome.runtime) { + window.chrome.runtime = { + connect: () => {}, + sendMessage: () => {}, + }; + } + + // Hide automation-related properties from detection scripts + const automationProps = ['__webdriver_evaluate', '__selenium_evaluate', + '__fxdriver_evaluate', '__driver_evaluate', + '__webdriver_unwrapped', '__selenium_unwrapped', + '__fxdriver_unwrapped', '__driver_unwrapped', + 
'_Selenium_IDE_Recorder', '_selenium', 'calledSelenium', + '_WEBDRIVER_ELEM_CACHE', 'ChromeDriverw', + 'driver-hierarchical', '__webdriverFunc']; + for (const prop of automationProps) { + delete window[prop]; + delete document[prop]; + } +})(); +"#; + +/// Inject stealth patches into a browser session. +/// +/// Should be called immediately after page load (or via +/// `Page.addScriptToEvaluateOnNewDocument`). +pub async fn apply_stealth(session: &super::cdp::BrowserSession) -> Result<(), String> { + // Add the script so it runs on every new document load + session + .send( + "Page.addScriptToEvaluateOnNewDocument", + serde_json::json!({ "source": STEALTH_JS }), + ) + .await?; + + // Also inject into the current page immediately + session.evaluate(STEALTH_JS).await?; + + Ok(()) +} diff --git a/src/browser/store.rs b/src/browser/store.rs new file mode 100644 index 0000000..8906da2 --- /dev/null +++ b/src/browser/store.rs @@ -0,0 +1,280 @@ +//! Stored tool definitions — reusable browser interaction patterns. +//! +//! Inspired by Xbot's stored tool system. A stored tool captures a +//! repeatable browser interaction as a JSON definition that can be +//! replayed on demand. + +use serde::{Deserialize, Serialize}; + +/// A stored browser tool definition. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StoredTool { + /// Unique name (e.g. "twitter_post", "google_search"). + pub name: String, + /// Human-readable description. + pub description: String, + /// The URL to navigate to before executing steps. + pub start_url: String, + /// Ordered list of interaction steps. + pub steps: Vec, + /// Input parameters this tool accepts. + #[serde(default)] + pub parameters: Vec, + /// Domain briefing — context about the site for the LLM. + #[serde(default)] + pub domain_briefing: Option, + /// Whether to apply stealth patches before execution. 
+ #[serde(default = "default_true")] + pub stealth: bool, +} + +fn default_true() -> bool { + true +} + +/// A single step in a stored tool's execution. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "action")] +pub enum ToolStep { + /// Navigate to a URL (supports {{param}} interpolation). + #[serde(rename = "navigate")] + Navigate { url: String }, + + /// Click an element by CSS selector. + #[serde(rename = "click")] + Click { selector: String }, + + /// Fill an input with text (supports {{param}} interpolation). + #[serde(rename = "fill")] + Fill { selector: String, value: String }, + + /// Wait for a selector to appear. + #[serde(rename = "wait")] + Wait { + selector: String, + #[serde(default = "default_wait_ms")] + timeout_ms: u64, + }, + + /// Wait a fixed duration. + #[serde(rename = "delay")] + Delay { + #[serde(default = "default_delay_ms")] + ms: u64, + }, + + /// Take a snapshot and return it as the step's output. + #[serde(rename = "snapshot")] + Snapshot, + + /// Evaluate JavaScript and capture the result. + #[serde(rename = "evaluate")] + Evaluate { expression: String }, + + /// Take a screenshot (returned as base64 PNG). + #[serde(rename = "screenshot")] + Screenshot, + + /// Scroll the page. + #[serde(rename = "scroll")] + Scroll { + #[serde(default)] + x: i32, + #[serde(default = "default_scroll_y")] + y: i32, + }, + + /// Press a key (Enter, Tab, Escape, etc.). + #[serde(rename = "key")] + Key { key: String }, + + /// Conditional: only execute inner steps if selector exists. + #[serde(rename = "if_exists")] + IfExists { + selector: String, + then: Vec, + #[serde(default)] + otherwise: Vec, + }, +} + +fn default_wait_ms() -> u64 { 5000 } +fn default_delay_ms() -> u64 { 1000 } +fn default_scroll_y() -> i32 { 500 } + +/// A parameter that a stored tool accepts. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolParameter { + pub name: String, + #[serde(default = "default_param_type")] + pub param_type: String, + pub description: String, + #[serde(default)] + pub required: bool, + #[serde(default)] + pub default_value: Option, +} + +fn default_param_type() -> String { + "string".to_string() +} + +impl StoredTool { + /// Interpolate `{{param}}` placeholders in a string with actual values. + pub fn interpolate( + template: &str, + params: &std::collections::HashMap, + ) -> String { + let mut result = template.to_string(); + for (key, value) in params { + result = result.replace(&format!("{{{{{key}}}}}"), value); + } + result + } + + /// Execute this stored tool using a browser session. + pub async fn execute( + &self, + session: &super::cdp::BrowserSession, + params: &std::collections::HashMap, + ) -> Result, String> { + // Apply stealth if requested + if self.stealth { + super::stealth::apply_stealth(session).await?; + } + + // Navigate to start URL + let url = Self::interpolate(&self.start_url, params); + session.navigate(&url).await?; + + // Wait a moment for page load + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + + // Execute steps + let mut results = Vec::new(); + for (i, step) in self.steps.iter().enumerate() { + match Self::execute_step(session, step, params).await { + Ok(r) => results.push(r), + Err(e) => { + results.push(StepResult { + step_index: i, + success: false, + output: format!("Step {i} failed: {e}"), + }); + break; // Stop on failure + } + } + } + + Ok(results) + } + + fn execute_step<'a>( + session: &'a super::cdp::BrowserSession, + step: &'a ToolStep, + params: &'a std::collections::HashMap, + ) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + let i = 0; // Simplified — real index tracked by caller + match step { + ToolStep::Navigate { url } => { + let url = Self::interpolate(url, params); + session.navigate(&url).await?; + Ok(StepResult { step_index: i, 
success: true, output: format!("Navigated to {url}") }) + } + ToolStep::Click { selector } => { + let sel = Self::interpolate(selector, params); + session.click(&sel).await?; + Ok(StepResult { step_index: i, success: true, output: format!("Clicked {sel}") }) + } + ToolStep::Fill { selector, value } => { + let sel = Self::interpolate(selector, params); + let val = Self::interpolate(value, params); + session.fill(&sel, &val).await?; + Ok(StepResult { step_index: i, success: true, output: format!("Filled {sel}") }) + } + ToolStep::Wait { selector, timeout_ms } => { + let sel = Self::interpolate(selector, params); + session.wait_for_selector(&sel, *timeout_ms).await?; + Ok(StepResult { step_index: i, success: true, output: format!("Found {sel}") }) + } + ToolStep::Delay { ms } => { + tokio::time::sleep(std::time::Duration::from_millis(*ms)).await; + Ok(StepResult { step_index: i, success: true, output: format!("Waited {ms}ms") }) + } + ToolStep::Snapshot => { + let elements = session.snapshot().await?; + let text = super::snapshot::format_snapshot(&elements); + Ok(StepResult { step_index: i, success: true, output: text }) + } + ToolStep::Evaluate { expression } => { + let expr = Self::interpolate(expression, params); + let result = session.evaluate(&expr).await?; + Ok(StepResult { step_index: i, success: true, output: result.to_string() }) + } + ToolStep::Screenshot => { + let b64 = session.screenshot().await?; + Ok(StepResult { step_index: i, success: true, output: format!("[screenshot: {} bytes base64]", b64.len()) }) + } + ToolStep::Scroll { x, y } => { + session.scroll(*x, *y).await?; + Ok(StepResult { step_index: i, success: true, output: format!("Scrolled ({x}, {y})") }) + } + ToolStep::Key { key } => { + session + .send( + "Input.dispatchKeyEvent", + serde_json::json!({ + "type": "keyDown", + "key": key, + }), + ) + .await?; + session + .send( + "Input.dispatchKeyEvent", + serde_json::json!({ + "type": "keyUp", + "key": key, + }), + ) + .await?; + Ok(StepResult 
{ step_index: i, success: true, output: format!("Pressed {key}") }) + } + ToolStep::IfExists { selector, then, otherwise } => { + let sel = Self::interpolate(selector, params); + let exists = session + .evaluate(&format!( + "document.querySelector({}) !== null", + serde_json::to_string(&sel).unwrap(), + )) + .await?; + + let steps = if exists.as_bool() == Some(true) { + then + } else { + otherwise + }; + + let mut last_result = StepResult { + step_index: i, + success: true, + output: format!("Condition: {sel} = {exists}"), + }; + for sub_step in steps { + last_result = Self::execute_step(session, sub_step, params).await?; + } + Ok(last_result) + } + } + }) + } +} + +/// Result of a single step execution. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StepResult { + pub step_index: usize, + pub success: bool, + pub output: String, +} diff --git a/src/browser/tools.rs b/src/browser/tools.rs new file mode 100644 index 0000000..61276a2 --- /dev/null +++ b/src/browser/tools.rs @@ -0,0 +1,549 @@ +//! Browser tools registered into the agent's ToolRegistry. +//! +//! These tools give the agent the ability to: +//! - Launch/connect to a browser +//! - Navigate, click, fill, screenshot, snapshot +//! - Discover and invoke WebMCP tools +//! - Run stored browser tools and workflows + +use std::collections::HashMap; +use std::sync::Arc; + +use tokio::sync::Mutex; + +use crate::tools::{Tool, ToolRegistry}; +use crate::types::ToolResult; + +use super::cdp::BrowserSession; +use super::webmcp::WebMcpCache; + +/// Shared browser state accessible by all browser tools. +pub struct BrowserState { + /// The active browser session (None if not launched). + pub session: Option, + /// WebMCP tool cache. + pub webmcp_cache: WebMcpCache, + /// Stored tool definitions loaded from graph or config. 
+ pub stored_tools: HashMap, +} + +impl BrowserState { + pub fn new() -> Self { + Self { + session: None, + webmcp_cache: WebMcpCache::new(), + stored_tools: HashMap::new(), + } + } +} + +/// Register all browser tools into the given registry. +/// +/// The browser state is shared via `Arc>`. +pub fn register_browser_tools(reg: &mut ToolRegistry) { + let state: Arc> = Arc::new(Mutex::new(BrowserState::new())); + + // ── browser_launch: start a browser session ── + { + let state = state.clone(); + reg.register(Tool { + name: "browser_launch".to_string(), + description: concat!( + "Launch a Chrome browser with remote debugging, or connect to an existing one. ", + "Must be called before any other browser_* tools. ", + "If Chrome is already running with --remote-debugging-port, use connect_url instead." + ).to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "headless": { + "type": "boolean", + "description": "Run headless (no visible window). Default: true" + }, + "port": { + "type": "integer", + "description": "Debugging port (default: 9222)" + }, + "connect_url": { + "type": "string", + "description": "WebSocket URL to connect to an existing Chrome (overrides launch)" + } + }, + "required": [] + }), + trust: 0.7, + handler: Arc::new(move |input| { + let state = state.clone(); + Box::pin(async move { + let headless = input["headless"].as_bool().unwrap_or(true); + let port = input["port"].as_u64().unwrap_or(9222) as u16; + let connect_url = input["connect_url"].as_str().map(String::from); + + let session = if let Some(url) = connect_url { + BrowserSession::connect(&url).await + } else { + BrowserSession::launch(None, port, headless).await + }; + + match session { + Ok(s) => { + // Apply stealth + if let Err(e) = super::stealth::apply_stealth(&s).await { + tracing::warn!("stealth patches failed: {e}"); + } + let mut st = state.lock().await; + st.session = Some(s); + Ok(ToolResult { + output: "Browser launched and 
ready.".to_string(), + success: true, + }) + } + Err(e) => Ok(ToolResult { + output: format!("Failed to launch browser: {e}"), + success: false, + }), + } + }) + }), + }); + } + + // ── browser_navigate: go to a URL ── + { + let state = state.clone(); + reg.register(Tool { + name: "browser_navigate".to_string(), + description: "Navigate the browser to a URL. Waits for page load.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "The URL to navigate to" + } + }, + "required": ["url"] + }), + trust: 0.7, + handler: Arc::new(move |input| { + let state = state.clone(); + Box::pin(async move { + let url = input["url"].as_str().unwrap_or("").to_string(); + let st = state.lock().await; + let session = st.session.as_ref() + .ok_or_else(|| crate::error::CortexError::Tool( + "No browser session. Call browser_launch first.".into() + ))?; + session.navigate(&url).await + .map_err(|e| crate::error::CortexError::Tool(e))?; + + // Auto-discover WebMCP tools for this domain + drop(st); + if let Ok(parsed) = url::Url::parse(&url) { + if let Some(domain) = parsed.host_str() { + let mut st = state.lock().await; + let tools = super::webmcp::discover(domain).await; + if !tools.is_empty() { + let tool_names: Vec = tools.iter().map(|t| t.name.clone()).collect(); + for tool in tools { + st.webmcp_cache.cache.entry(domain.to_string()) + .or_default() + .push(tool); + } + return Ok(ToolResult { + output: format!( + "Navigated to {url}\nWebMCP tools discovered: {}", + tool_names.join(", ") + ), + success: true, + }); + } + } + } + + Ok(ToolResult { + output: format!("Navigated to {url}"), + success: true, + }) + }) + }), + }); + } + + // ── browser_snapshot: get compact page content ── + { + let state = state.clone(); + reg.register(Tool { + name: "browser_snapshot".to_string(), + description: concat!( + "Get a compact snapshot of the current page. 
Returns interactive elements ", + "(links, buttons, inputs) and content (headings, paragraphs) with CSS selectors. ", + "Use this to understand page structure before clicking or filling forms." + ).to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + trust: 0.8, + handler: Arc::new(move |_input| { + let state = state.clone(); + Box::pin(async move { + let st = state.lock().await; + let session = st.session.as_ref() + .ok_or_else(|| crate::error::CortexError::Tool( + "No browser session. Call browser_launch first.".into() + ))?; + let url = session.current_url().await + .map_err(|e| crate::error::CortexError::Tool(e))?; + let elements = session.snapshot().await + .map_err(|e| crate::error::CortexError::Tool(e))?; + let text = super::snapshot::format_snapshot(&elements); + Ok(ToolResult { + output: format!("URL: {url}\n{} element(s):\n{text}", elements.len()), + success: true, + }) + }) + }), + }); + } + + // ── browser_click: click an element ── + { + let state = state.clone(); + reg.register(Tool { + name: "browser_click".to_string(), + description: "Click an element on the page by CSS selector. 
Use browser_snapshot first to find selectors.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "selector": { + "type": "string", + "description": "CSS selector of the element to click" + } + }, + "required": ["selector"] + }), + trust: 0.7, + handler: Arc::new(move |input| { + let state = state.clone(); + Box::pin(async move { + let selector = input["selector"].as_str().unwrap_or("").to_string(); + let st = state.lock().await; + let session = st.session.as_ref() + .ok_or_else(|| crate::error::CortexError::Tool( + "No browser session.".into() + ))?; + session.click(&selector).await + .map_err(|e| crate::error::CortexError::Tool(e))?; + Ok(ToolResult { + output: format!("Clicked: {selector}"), + success: true, + }) + }) + }), + }); + } + + // ── browser_fill: fill an input field ── + { + let state = state.clone(); + reg.register(Tool { + name: "browser_fill".to_string(), + description: "Fill an input or textarea with text. Uses CSS selector to target the element.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "selector": { + "type": "string", + "description": "CSS selector of the input element" + }, + "text": { + "type": "string", + "description": "Text to fill in" + } + }, + "required": ["selector", "text"] + }), + trust: 0.7, + handler: Arc::new(move |input| { + let state = state.clone(); + Box::pin(async move { + let selector = input["selector"].as_str().unwrap_or("").to_string(); + let text = input["text"].as_str().unwrap_or("").to_string(); + let st = state.lock().await; + let session = st.session.as_ref() + .ok_or_else(|| crate::error::CortexError::Tool( + "No browser session.".into() + ))?; + session.fill(&selector, &text).await + .map_err(|e| crate::error::CortexError::Tool(e))?; + Ok(ToolResult { + output: format!("Filled {selector} with text"), + success: true, + }) + }) + }), + }); + } + + // ── browser_screenshot: capture page image ── + { + let state = state.clone(); + 
reg.register(Tool { + name: "browser_screenshot".to_string(), + description: "Take a screenshot of the current page. Returns base64-encoded PNG.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + trust: 0.8, + handler: Arc::new(move |_input| { + let state = state.clone(); + Box::pin(async move { + let st = state.lock().await; + let session = st.session.as_ref() + .ok_or_else(|| crate::error::CortexError::Tool( + "No browser session.".into() + ))?; + let b64 = session.screenshot().await + .map_err(|e| crate::error::CortexError::Tool(e))?; + Ok(ToolResult { + output: format!("[screenshot: {} bytes base64]\ndata:image/png;base64,{}", b64.len(), &b64[..100.min(b64.len())]), + success: true, + }) + }) + }), + }); + } + + // ── browser_evaluate: run JavaScript ── + { + let state = state.clone(); + reg.register(Tool { + name: "browser_evaluate".to_string(), + description: "Execute JavaScript in the page context and return the result.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "JavaScript expression to evaluate" + } + }, + "required": ["expression"] + }), + trust: 0.6, + handler: Arc::new(move |input| { + let state = state.clone(); + Box::pin(async move { + let expr = input["expression"].as_str().unwrap_or("").to_string(); + let st = state.lock().await; + let session = st.session.as_ref() + .ok_or_else(|| crate::error::CortexError::Tool( + "No browser session.".into() + ))?; + let result = session.evaluate(&expr).await + .map_err(|e| crate::error::CortexError::Tool(e))?; + Ok(ToolResult { + output: serde_json::to_string_pretty(&result).unwrap_or_default(), + success: true, + }) + }) + }), + }); + } + + // ── browser_wait: wait for element ── + { + let state = state.clone(); + reg.register(Tool { + name: "browser_wait".to_string(), + description: "Wait for a CSS selector to appear in the DOM.".to_string(), + 
input_schema: serde_json::json!({ + "type": "object", + "properties": { + "selector": { + "type": "string", + "description": "CSS selector to wait for" + }, + "timeout_ms": { + "type": "integer", + "description": "Timeout in milliseconds (default: 10000)" + } + }, + "required": ["selector"] + }), + trust: 0.8, + handler: Arc::new(move |input| { + let state = state.clone(); + Box::pin(async move { + let selector = input["selector"].as_str().unwrap_or("").to_string(); + let timeout = input["timeout_ms"].as_u64().unwrap_or(10000); + let st = state.lock().await; + let session = st.session.as_ref() + .ok_or_else(|| crate::error::CortexError::Tool( + "No browser session.".into() + ))?; + session.wait_for_selector(&selector, timeout).await + .map_err(|e| crate::error::CortexError::Tool(e))?; + Ok(ToolResult { + output: format!("Element found: {selector}"), + success: true, + }) + }) + }), + }); + } + + // ── browser_close: close the browser ── + { + let state = state.clone(); + reg.register(Tool { + name: "browser_close".to_string(), + description: "Close the browser session.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + trust: 0.8, + handler: Arc::new(move |_input| { + let state = state.clone(); + Box::pin(async move { + let mut st = state.lock().await; + if let Some(session) = st.session.as_ref() { + let _ = session.close().await; + } + st.session = None; + Ok(ToolResult { + output: "Browser closed.".to_string(), + success: true, + }) + }) + }), + }); + } + + // ── browser_webmcp: discover and call WebMCP tools ── + { + let state = state.clone(); + reg.register(Tool { + name: "browser_webmcp".to_string(), + description: concat!( + "Interact with WebMCP tools exposed by the current website. ", + "Use action='discover' to find available tools, or action='invoke' ", + "to call a specific tool by name." 
+ ).to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["discover", "invoke"], + "description": "discover: list tools from current site. invoke: call a tool." + }, + "domain": { + "type": "string", + "description": "Domain to discover tools from (default: current page's domain)" + }, + "tool_name": { + "type": "string", + "description": "Name of the WebMCP tool to invoke (required for action=invoke)" + }, + "input": { + "type": "object", + "description": "Input parameters for the tool (required for action=invoke)" + } + }, + "required": ["action"] + }), + trust: 0.7, + handler: Arc::new(move |input| { + let state = state.clone(); + Box::pin(async move { + let action = input["action"].as_str().unwrap_or("discover"); + let mut st = state.lock().await; + + match action { + "discover" => { + let domain = if let Some(d) = input["domain"].as_str() { + d.to_string() + } else if let Some(session) = st.session.as_ref() { + let url = session.current_url().await + .map_err(|e| crate::error::CortexError::Tool(e))?; + url::Url::parse(&url) + .ok() + .and_then(|u| u.host_str().map(String::from)) + .unwrap_or_default() + } else { + return Ok(ToolResult { + output: "No domain specified and no browser session.".into(), + success: false, + }); + }; + + let tools = st.webmcp_cache.get_or_discover(&domain).await; + if tools.is_empty() { + Ok(ToolResult { + output: format!("No WebMCP tools found at {domain}"), + success: true, + }) + } else { + let mut out = format!("{} WebMCP tool(s) from {domain}:\n", tools.len()); + for t in tools { + out.push_str(&format!("- {}: {}\n", t.name, t.description)); + } + Ok(ToolResult { output: out, success: true }) + } + } + "invoke" => { + let tool_name = input["tool_name"].as_str().unwrap_or(""); + let tool_input = input.get("input").cloned().unwrap_or(serde_json::json!({})); + + // Find the tool across all cached domains + let tool = st.webmcp_cache.cache.values() + 
.flat_map(|tools: &Vec| tools.iter()) + .find(|t| t.name == tool_name) + .cloned(); + + let tool = match tool { + Some(t) => t, + None => return Ok(ToolResult { + output: format!("WebMCP tool '{tool_name}' not found. Use action=discover first."), + success: false, + }), + }; + + // Try imperative first, fall back to declarative + let result = if tool.endpoint.is_some() { + super::webmcp::invoke_imperative(&tool, &tool_input).await + } else if tool.form_selector.is_some() { + if let Some(session) = st.session.as_ref() { + super::webmcp::invoke_declarative(session, &tool, &tool_input).await + } else { + Err("No browser session for declarative WebMCP tool.".to_string()) + } + } else { + Err("Tool has neither endpoint nor form_selector.".to_string()) + }; + + match result { + Ok(output) => Ok(ToolResult { output, success: true }), + Err(e) => Ok(ToolResult { + output: format!("WebMCP invoke error: {e}"), + success: false, + }), + } + } + _ => Ok(ToolResult { + output: format!("Unknown action: {action}. Use 'discover' or 'invoke'."), + success: false, + }), + } + }) + }), + }); + } +} diff --git a/src/browser/webmcp.rs b/src/browser/webmcp.rs new file mode 100644 index 0000000..2370b97 --- /dev/null +++ b/src/browser/webmcp.rs @@ -0,0 +1,216 @@ +//! WebMCP client — discover and invoke structured tools exposed by websites. +//! +//! Chrome's WebMCP (early preview) lets websites declare tools via +//! `/.well-known/webmcp.json`. This module discovers those declarations +//! and converts them into callable tool definitions for the agent. +//! +//! Reference: https://developer.chrome.com/blog/webmcp-epp + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// A WebMCP tool descriptor as declared by a website. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WebMcpTool { + pub name: String, + pub description: String, + #[serde(default)] + pub parameters: serde_json::Value, + /// The URL endpoint to POST to (for imperative tools). 
+ #[serde(default)] + pub endpoint: Option, + /// CSS selector for the form element (for declarative tools). + #[serde(default)] + pub form_selector: Option, + /// The originating domain. + #[serde(skip_deserializing, default)] + pub domain: String, +} + +/// WebMCP manifest (/.well-known/webmcp.json). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WebMcpManifest { + #[serde(default)] + pub name: String, + #[serde(default)] + pub description: String, + #[serde(default)] + pub tools: Vec, + #[serde(default)] + pub version: String, +} + +/// Discover WebMCP tools from a website. +/// +/// Fetches `https://{domain}/.well-known/webmcp.json` and parses the manifest. +/// Returns an empty vec if the site doesn't support WebMCP. +pub async fn discover(domain: &str) -> Vec { + let url = format!("https://{domain}/.well-known/webmcp.json"); + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(10)) + .build() + .unwrap_or_default(); + + match client.get(&url).send().await { + Ok(resp) if resp.status().is_success() => { + match resp.json::().await { + Ok(manifest) => { + let mut tools = manifest.tools; + for tool in &mut tools { + tool.domain = domain.to_string(); + } + tracing::info!( + "WebMCP: discovered {} tool(s) from {domain}", + tools.len() + ); + tools + } + Err(e) => { + tracing::debug!("WebMCP: invalid manifest from {domain}: {e}"); + vec![] + } + } + } + Ok(_) => { + tracing::debug!("WebMCP: no manifest at {domain}"); + vec![] + } + Err(e) => { + tracing::debug!("WebMCP: fetch failed for {domain}: {e}"); + vec![] + } + } +} + +/// Invoke a WebMCP tool via its endpoint (imperative mode). +/// +/// POSTs the input JSON to the tool's endpoint and returns the response. 
+pub async fn invoke_imperative( + tool: &WebMcpTool, + input: &serde_json::Value, +) -> Result { + let endpoint = tool + .endpoint + .as_deref() + .ok_or_else(|| "tool has no endpoint (declarative only)".to_string())?; + + let url = if endpoint.starts_with("http") { + endpoint.to_string() + } else { + format!("https://{}{}", tool.domain, endpoint) + }; + + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .unwrap_or_default(); + + let resp = client + .post(&url) + .json(input) + .send() + .await + .map_err(|e| format!("WebMCP invoke error: {e}"))?; + + let status = resp.status(); + let body = resp + .text() + .await + .map_err(|e| format!("WebMCP response error: {e}"))?; + + if !status.is_success() { + return Err(format!("WebMCP tool returned {status}: {body}")); + } + + Ok(body) +} + +/// Invoke a WebMCP tool via form filling (declarative mode). +/// +/// Uses the browser session to fill and submit the form identified +/// by the tool's `form_selector`. 
+pub async fn invoke_declarative( + session: &super::cdp::BrowserSession, + tool: &WebMcpTool, + input: &serde_json::Value, +) -> Result { + let form_selector = tool + .form_selector + .as_deref() + .ok_or_else(|| "tool has no form_selector (imperative only)".to_string())?; + + // Fill each input field in the form + if let Some(params) = input.as_object() { + for (key, value) in params { + let val_str = match value { + serde_json::Value::String(s) => s.clone(), + other => other.to_string(), + }; + + // Try to fill by name attribute within the form + let selector = format!("{form_selector} [name=\"{key}\"]"); + if let Err(_) = session.fill(&selector, &val_str).await { + // Fall back to aria-label + let selector = format!("{form_selector} [aria-label=\"{key}\"]"); + let _ = session.fill(&selector, &val_str).await; + } + } + } + + // Submit the form + let submit_js = format!( + r#"(() => {{ + const form = document.querySelector({sel}); + if (!form) return 'FORM_NOT_FOUND'; + const submit = form.querySelector('[type="submit"], button'); + if (submit) {{ submit.click(); return 'CLICKED'; }} + form.submit(); + return 'SUBMITTED'; + }})()"#, + sel = serde_json::to_string(form_selector).unwrap(), + ); + + let result = session.evaluate(&submit_js).await?; + if result.as_str() == Some("FORM_NOT_FOUND") { + return Err(format!("form not found: {form_selector}")); + } + + // Wait briefly for response + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + + // Take a snapshot of the result page + let snapshot = session.snapshot().await?; + let snapshot_text = super::snapshot::format_snapshot(&snapshot); + + Ok(format!("Form submitted. Page snapshot:\n{snapshot_text}")) +} + +/// Cache of discovered WebMCP tools, keyed by domain. +pub struct WebMcpCache { + pub cache: HashMap>, +} + +impl WebMcpCache { + pub fn new() -> Self { + Self { + cache: HashMap::new(), + } + } + + /// Get cached tools for a domain, or discover them. 
+ pub async fn get_or_discover(&mut self, domain: &str) -> &[WebMcpTool] { + if !self.cache.contains_key(domain) { + let tools = discover(domain).await; + self.cache.insert(domain.to_string(), tools); + } + self.cache.get(domain).map(|v| v.as_slice()).unwrap_or(&[]) + } + + /// List all cached domains and their tool counts. + pub fn summary(&self) -> Vec<(String, usize)> { + self.cache + .iter() + .map(|(domain, tools)| (domain.clone(), tools.len())) + .collect() + } +} diff --git a/src/browser/workflow.rs b/src/browser/workflow.rs new file mode 100644 index 0000000..e0729d8 --- /dev/null +++ b/src/browser/workflow.rs @@ -0,0 +1,175 @@ +//! Workflow engine — execute multi-step browser workflows with conditionals. +//! +//! A workflow is an ordered sequence of stored tool invocations, +//! with conditional branching based on page state. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use super::store::StoredTool; + +/// A workflow definition — a sequence of stored tool calls with conditionals. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Workflow { + /// Workflow name. + pub name: String, + /// Description of what this workflow accomplishes. + pub description: String, + /// Ordered steps in the workflow. + pub steps: Vec, + /// Global parameters passed to all tool calls. + #[serde(default)] + pub parameters: Vec, +} + +/// A single step in a workflow. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum WorkflowStep { + /// Execute a stored tool by name. + #[serde(rename = "tool")] + RunTool { + tool_name: String, + /// Parameter overrides for this invocation. + #[serde(default)] + params: HashMap, + }, + + /// Conditional based on current page URL pattern. + #[serde(rename = "if_url")] + IfUrl { + pattern: String, + then: Vec, + #[serde(default)] + otherwise: Vec, + }, + + /// Wait for a specific page state before continuing. 
+ #[serde(rename = "wait_for")] + WaitFor { + selector: String, + #[serde(default = "default_timeout")] + timeout_ms: u64, + }, + + /// Log a message to the workflow output. + #[serde(rename = "log")] + Log { message: String }, +} + +fn default_timeout() -> u64 { 10000 } + +/// Result of executing a workflow. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkflowResult { + pub workflow_name: String, + pub steps_executed: usize, + pub success: bool, + pub outputs: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkflowStepOutput { + pub step_index: usize, + pub step_type: String, + pub success: bool, + pub output: String, +} + +/// Execute a workflow using a browser session and a tool registry. +pub async fn execute_workflow( + session: &super::cdp::BrowserSession, + workflow: &Workflow, + tools: &HashMap, + params: &HashMap, +) -> WorkflowResult { + let mut outputs = Vec::new(); + let mut success = true; + + for (i, step) in workflow.steps.iter().enumerate() { + match execute_step(session, step, tools, params).await { + Ok(output) => { + outputs.push(WorkflowStepOutput { + step_index: i, + step_type: step_type_name(step), + success: true, + output, + }); + } + Err(e) => { + outputs.push(WorkflowStepOutput { + step_index: i, + step_type: step_type_name(step), + success: false, + output: format!("Error: {e}"), + }); + success = false; + break; + } + } + } + + WorkflowResult { + workflow_name: workflow.name.clone(), + steps_executed: outputs.len(), + success, + outputs, + } +} + +fn execute_step<'a>( + session: &'a super::cdp::BrowserSession, + step: &'a WorkflowStep, + tools: &'a HashMap, + params: &'a HashMap, +) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + match step { + WorkflowStep::RunTool { tool_name, params: extra_params } => { + let tool = tools + .get(tool_name) + .ok_or_else(|| format!("stored tool not found: {tool_name}"))?; + + let mut merged_params = params.clone(); + 
merged_params.extend(extra_params.clone()); + + let results = tool.execute(session, &merged_params).await?; + let output: Vec = results.iter().map(|r| r.output.clone()).collect(); + Ok(output.join("\n")) + } + WorkflowStep::IfUrl { pattern, then, otherwise } => { + let current_url = session.current_url().await?; + let matches = current_url.contains(pattern); + + let steps = if matches { then } else { otherwise }; + let mut last_output = format!("URL check: '{}' {} '{}'", + current_url, + if matches { "matches" } else { "does not match" }, + pattern, + ); + + for sub_step in steps { + last_output = execute_step(session, sub_step, tools, params).await?; + } + Ok(last_output) + } + WorkflowStep::WaitFor { selector, timeout_ms } => { + session.wait_for_selector(selector, *timeout_ms).await?; + Ok(format!("Found: {selector}")) + } + WorkflowStep::Log { message } => { + let interpolated = StoredTool::interpolate(message, params); + Ok(format!("[log] {interpolated}")) + } + } + }) +} + +fn step_type_name(step: &WorkflowStep) -> String { + match step { + WorkflowStep::RunTool { tool_name, .. } => format!("tool:{tool_name}"), + WorkflowStep::IfUrl { .. } => "if_url".to_string(), + WorkflowStep::WaitFor { .. } => "wait_for".to_string(), + WorkflowStep::Log { .. } => "log".to_string(), + } +} diff --git a/src/channels/discord.rs b/src/channels/discord.rs new file mode 100644 index 0000000..dfdcbb2 --- /dev/null +++ b/src/channels/discord.rs @@ -0,0 +1,460 @@ +//! Discord channel adapter — connects via the Discord REST + Gateway API. +//! +//! Uses the `serenity` crate for the WebSocket gateway and REST API. +//! When `serenity` is not available (default build), this module provides +//! a **stub adapter** that reports itself as unavailable. To enable the real +//! Discord adapter, build with `--features discord`. +//! +//! # Configuration +//! +//! ```json +//! { +//! "token": "MTIz…", // or DISCORD_BOT_TOKEN env var +//! 
"allow_from": ["*"], // guild:channel pairs, or "*" for all +//! "dm_policy": "open" // "open", "pairing", or "closed" +//! } +//! ``` + +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +use crate::error::{CortexError, Result}; + +use super::types::*; +use super::types::now_unix; +use super::Channel; + +const DISCORD_API: &str = "https://discord.com/api/v10"; + +/// Discord channel adapter. +/// +/// This is a REST-only implementation that polls for messages. For full +/// real-time support, enable the `discord` feature flag which brings in +/// the serenity gateway. +pub struct DiscordChannel { + client: reqwest::Client, + started: AtomicBool, + cancel: tokio::sync::watch::Sender, +} + +impl DiscordChannel { + pub fn new() -> Self { + let (cancel, _) = tokio::sync::watch::channel(false); + Self { + client: reqwest::Client::new(), + started: AtomicBool::new(false), + cancel, + } + } + + fn resolve_token(config: &serde_json::Value) -> Result { + if let Ok(token) = std::env::var("DISCORD_BOT_TOKEN") { + return Ok(token); + } + config + .get("token") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .ok_or_else(|| { + CortexError::Config( + "Discord: token not set in config or DISCORD_BOT_TOKEN env var".into(), + ) + }) + } +} + +// ─── Discord API types ────────────────────────────────── + +#[derive(Debug, Deserialize)] +#[allow(dead_code)] +struct DiscordMessage { + id: String, + content: String, + author: DiscordUser, + channel_id: String, + guild_id: Option, + #[serde(default)] + bot: bool, + #[serde(default)] + attachments: Vec, +} + +#[derive(Debug, Deserialize)] +#[allow(dead_code)] +struct DiscordAttachment { + url: String, + filename: String, + content_type: Option, + size: Option, +} + +#[derive(Debug, Deserialize)] +#[allow(dead_code)] +struct DiscordUser { + id: String, + username: String, + #[serde(default)] + bot: bool, +} + +#[derive(Debug, 
Deserialize)] +#[allow(dead_code)] +struct DmChannel { + id: String, + #[serde(rename = "type")] + channel_type: u8, +} + +#[derive(Debug, Serialize)] +struct CreateMessage { + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + message_reference: Option, +} + +#[derive(Debug, Serialize)] +struct MessageReference { + message_id: String, +} + +// ─── Channel implementation ───────────────────────────── + +#[async_trait] +impl Channel for DiscordChannel { + fn id(&self) -> &str { + "discord" + } + + fn display_name(&self) -> &str { + "Discord" + } + + async fn start(&self, ctx: ChannelContext) -> Result<()> { + let token = Self::resolve_token(&ctx.config)?; + + // Verify the token by calling /users/@me + let resp = self + .client + .get(format!("{}/users/@me", DISCORD_API)) + .header("Authorization", format!("Bot {}", token)) + .send() + .await + .map_err(|e| CortexError::Channel(format!("Discord auth check failed: {e}")))?; + + if !resp.status().is_success() { + let body = resp.text().await.unwrap_or_default(); + return Err(CortexError::Channel(format!( + "Discord auth failed: {body}" + ))); + } + + self.started.store(true, Ordering::SeqCst); + + // Parse bot's own user ID so we can filter self-messages + let me: serde_json::Value = resp + .json() + .await + .map_err(|e| CortexError::Channel(format!("Discord parse @me: {e}")))?; + let bot_id = me["id"] + .as_str() + .unwrap_or("") + .to_string(); + + // Parse optional guild channel IDs to poll from DISCORD_CHANNELS env + let extra_channels: Vec = std::env::var("DISCORD_CHANNELS") + .unwrap_or_default() + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + + // Start the DM + channel polling loop + let client = self.client.clone(); + let inbound_tx = ctx.inbound_tx.clone(); + let mut shutdown = ctx.shutdown.clone(); + let mut cancel_rx = self.cancel.subscribe(); + let token_clone = token.clone(); + + tokio::spawn(async move { + tracing::info!( + "discord 
adapter started (REST polling for DMs{})", + if extra_channels.is_empty() { + String::new() + } else { + format!(" + {} guild channels", extra_channels.len()) + } + ); + + // Track last seen message ID per channel to avoid re-processing + let mut last_seen: HashMap = HashMap::new(); + + loop { + tokio::select! { + _ = shutdown.changed() => break, + _ = cancel_rx.changed() => break, + _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => { + // Collect channel IDs to poll: DM channels + configured guild channels + let mut channels_to_poll: Vec = extra_channels.clone(); + + // Fetch DM channels + let dm_url = format!("{}/users/@me/channels", DISCORD_API); + match client + .get(&dm_url) + .header("Authorization", format!("Bot {}", token_clone)) + .send() + .await + { + Ok(r) if r.status().is_success() => { + if let Ok(dms) = r.json::>().await { + for dm in dms { + // type 1 = DM channel + if dm.channel_type == 1 { + channels_to_poll.push(dm.id); + } + } + } + } + Ok(r) => { + tracing::warn!(status = %r.status(), "discord: failed to list DM channels"); + } + Err(e) => { + tracing::warn!(error = %e, "discord: DM channel list request failed"); + continue; + } + } + + // Poll each channel for new messages + for chan_id in &channels_to_poll { + let mut url = format!( + "{}/channels/{}/messages?limit=50", + DISCORD_API, chan_id + ); + if let Some(after) = last_seen.get(chan_id) { + url = format!("{}&after={}", url, after); + } + + let msgs = match client + .get(&url) + .header("Authorization", format!("Bot {}", token_clone)) + .send() + .await + { + Ok(r) if r.status().is_success() => { + r.json::>().await.unwrap_or_default() + } + _ => continue, + }; + + // Messages come newest-first; process oldest-first + for msg in msgs.iter().rev() { + // Skip bot messages (including self) + if msg.author.bot || msg.author.id == bot_id { + continue; + } + + // Try to download first image attachment + let media = if let Some(att) = msg.attachments.iter().find(|a| { + 
a.content_type + .as_deref() + .map_or(false, |ct| ct.starts_with("image/")) + }) { + download_discord_attachment(&client, att).await + } else { + None + }; + + // Skip if there's neither text nor media + let text = msg.content.clone(); + if text.trim().is_empty() && media.is_none() { + continue; + } + + let envelope = InboundEnvelope { + channel: "discord".into(), + external_id: msg.author.id.clone(), + sender_name: Some(msg.author.username.clone()), + text: if text.trim().is_empty() { + String::new() + } else { + text + }, + media, + reply_to: None, + group_id: Some(msg.channel_id.clone()), + callback_url: None, + raw: serde_json::Value::Null, + timestamp: now_unix(), + }; + + if let Err(e) = inbound_tx.send(envelope).await { + tracing::error!(error = %e, "discord: failed to send to pipeline"); + } + } + + // Update last_seen to newest message + if let Some(newest) = msgs.first() { + last_seen.insert(chan_id.clone(), newest.id.clone()); + } + } + } + } + } + tracing::info!("discord polling loop stopped"); + }); + + Ok(()) + } + + async fn stop(&self) -> Result<()> { + let _ = self.cancel.send(true); + self.started.store(false, Ordering::SeqCst); + tracing::info!("discord channel stopped"); + Ok(()) + } + + async fn send(&self, target: &OutboundTarget, message: OutboundMessage) -> Result<()> { + let token = std::env::var("DISCORD_BOT_TOKEN").map_err(|_| { + CortexError::Channel("DISCORD_BOT_TOKEN not set".into()) + })?; + + // The external_id for Discord can be a channel_id or user_id. + // For DMs, we need to create a DM channel first. 
+ let channel_id = if let Some(ref gid) = target.group_id { + // group_id is the Discord channel_id for guild messages + gid.clone() + } else { + // For DMs, create a DM channel with the user + let dm_resp = self + .client + .post(format!("{}/users/@me/channels", DISCORD_API)) + .header("Authorization", format!("Bot {}", token)) + .json(&serde_json::json!({ "recipient_id": target.external_id })) + .send() + .await + .map_err(|e| { + CortexError::Channel(format!("Discord DM channel creation failed: {e}")) + })?; + + let dm: serde_json::Value = dm_resp.json().await.map_err(|e| { + CortexError::Channel(format!("Discord DM parse error: {e}")) + })?; + + dm["id"] + .as_str() + .ok_or_else(|| CortexError::Channel("Discord DM channel missing 'id'".into()))? + .to_string() + }; + + let msg_ref = target.reply_to_message_id.as_ref().map(|id| MessageReference { + message_id: id.clone(), + }); + + let payload = CreateMessage { + content: message.text, + message_reference: msg_ref, + }; + + let url = format!("{}/channels/{}/messages", DISCORD_API, channel_id); + let resp = self + .client + .post(&url) + .header("Authorization", format!("Bot {}", token)) + .json(&payload) + .send() + .await + .map_err(|e| CortexError::Channel(format!("Discord send failed: {e}")))?; + + if !resp.status().is_success() { + let body = resp.text().await.unwrap_or_default(); + return Err(CortexError::Channel(format!( + "Discord send error: {body}" + ))); + } + + Ok(()) + } + + async fn health(&self) -> ChannelHealth { + if !self.started.load(Ordering::SeqCst) { + return ChannelHealth::Disconnected { + reason: "not started".into(), + }; + } + + let token = match std::env::var("DISCORD_BOT_TOKEN") { + Ok(t) => t, + Err(_) => { + return ChannelHealth::Degraded { + reason: "DISCORD_BOT_TOKEN not set".into(), + } + } + }; + + let url = format!("{}/users/@me", DISCORD_API); + match self.client.get(&url).header("Authorization", format!("Bot {}", token)).send().await { + Ok(resp) if resp.status().is_success() 
=> ChannelHealth::Connected, + Ok(resp) => ChannelHealth::Degraded { + reason: format!("API returned {}", resp.status()), + }, + Err(e) => ChannelHealth::Disconnected { + reason: format!("HTTP error: {e}"), + }, + } + } + + fn max_message_length(&self) -> usize { + 2000 + } + + async fn send_typing(&self, target: &OutboundTarget) -> Result<()> { + let token = std::env::var("DISCORD_BOT_TOKEN").unwrap_or_default(); + let channel_id = target + .group_id + .as_deref() + .unwrap_or(&target.external_id); + let url = format!("{}/channels/{}/typing", DISCORD_API, channel_id); + let _ = self + .client + .post(&url) + .header("Authorization", format!("Bot {}", token)) + .send() + .await; + Ok(()) + } +} + +/// Download a Discord image attachment and return it as a `MediaPayload`. +async fn download_discord_attachment( + client: &reqwest::Client, + att: &DiscordAttachment, +) -> Option { + let data = client + .get(&att.url) + .send() + .await + .ok()? + .bytes() + .await + .ok()? + .to_vec(); + + if data.is_empty() { + return None; + } + + let mime = att + .content_type + .clone() + .unwrap_or_else(|| "image/png".to_string()); + + Some(MediaPayload { + kind: MediaKind::Image, + data, + mime_type: mime, + filename: Some(att.filename.clone()), + url: Some(att.url.clone()), + }) +} diff --git a/src/channels/hooks.rs b/src/channels/hooks.rs new file mode 100644 index 0000000..f5ef518 --- /dev/null +++ b/src/channels/hooks.rs @@ -0,0 +1,137 @@ +//! Hooks — lifecycle interception points for the channel pipeline. +//! +//! Inspired by OpenClaw's `before_dispatch` / `after_tool_call` / `session:patch` +//! hooks. Any struct implementing [`ChannelHook`] can be registered with the +//! pipeline to intercept and transform messages at well-defined points: +//! +//! 1. **before_agent** — after identity/session resolution, before the agent runs. +//! The hook receives the mutable envelope and can modify or reject it. +//! 2. **after_agent** — after the agent produces a reply. 
Can modify the reply. +//! 3. **before_send** — just before a message is sent on a channel. +//! 4. **after_send** — after successful delivery (for logging, metrics, etc.). +//! +//! Hooks are executed in registration order. A hook returning `Err` aborts +//! the pipeline at that stage (except `after_send`, which is best-effort). + +use async_trait::async_trait; + +use crate::error::Result; +use super::types::{InboundEnvelope, OutboundMessage, OutboundTarget}; + +/// A lifecycle hook that intercepts messages at defined pipeline stages. +/// +/// All methods have default no-op implementations so you only need to +/// override the stages you care about. +#[async_trait] +pub trait ChannelHook: Send + Sync { + /// Human-readable name for logging. + fn name(&self) -> &str { + "unnamed-hook" + } + + /// Called after identity/session resolution and before the agent processes + /// the message. Return `Err(CortexError::Pipeline(...))` to reject the message. + /// + /// Use cases: allowlist checks, rate limiting, command parsing. + async fn before_agent(&self, _envelope: &mut InboundEnvelope) -> Result<()> { + Ok(()) + } + + /// Called after the agent produces a reply. The hook receives the original + /// envelope (immutable) and the reply text (mutable). + /// + /// Use cases: content filtering, response augmentation, analytics. + async fn after_agent( + &self, + _envelope: &InboundEnvelope, + _reply: &mut String, + ) -> Result<()> { + Ok(()) + } + + /// Called just before a message chunk is sent on a channel. + /// + /// Use cases: rate limiting, audit logging, message transformation. + async fn before_send( + &self, + _target: &OutboundTarget, + _message: &mut OutboundMessage, + ) -> Result<()> { + Ok(()) + } + + /// Called after a message chunk is successfully sent. + /// + /// Use cases: delivery tracking, metrics, follow-up scheduling. + /// Errors from this hook are logged but do not fail the pipeline. 
+ async fn after_send( + &self, + _target: &OutboundTarget, + _message: &OutboundMessage, + ) -> Result<()> { + Ok(()) + } +} + +// ─── Built-in hooks ───────────────────────────────────── + +/// A logging hook that traces every pipeline stage. +pub struct TracingHook; + +#[async_trait] +impl ChannelHook for TracingHook { + fn name(&self) -> &str { + "tracing" + } + + async fn before_agent(&self, envelope: &mut InboundEnvelope) -> Result<()> { + tracing::info!( + channel = %envelope.channel, + sender = %envelope.external_id, + text_len = envelope.text.len(), + "before_agent" + ); + Ok(()) + } + + async fn after_agent( + &self, + envelope: &InboundEnvelope, + reply: &mut String, + ) -> Result<()> { + tracing::info!( + channel = %envelope.channel, + sender = %envelope.external_id, + reply_len = reply.len(), + "after_agent" + ); + Ok(()) + } + + async fn before_send( + &self, + target: &OutboundTarget, + message: &mut OutboundMessage, + ) -> Result<()> { + tracing::info!( + channel = %target.channel, + recipient = %target.external_id, + text_len = message.text.len(), + "before_send" + ); + Ok(()) + } + + async fn after_send( + &self, + target: &OutboundTarget, + _message: &OutboundMessage, + ) -> Result<()> { + tracing::info!( + channel = %target.channel, + recipient = %target.external_id, + "after_send — delivered" + ); + Ok(()) + } +} diff --git a/src/channels/mod.rs b/src/channels/mod.rs new file mode 100644 index 0000000..1512d6d --- /dev/null +++ b/src/channels/mod.rs @@ -0,0 +1,106 @@ +//! Channel system — omnichannel messaging adapters. +//! +//! Every messaging platform (Telegram, Discord, Slack, WhatsApp, WebChat, +//! generic webhook) implements the [`Channel`] trait. The [`ChannelRegistry`] +//! manages their lifecycle, and the [`Pipeline`] routes messages through a +//! normalised inbound → agent → outbound flow with [`ChannelHook`] interception. +//! +//! # Architecture +//! +//! ```text +//! 
Platform → Adapter → InboundEnvelope → Pipeline → Agent → OutboundMessage → Adapter → Platform +//! ``` +//! +//! All channels share the same pipeline. Business logic stays in the pipeline +//! and agent; the Channel trait only handles platform wire protocol. + +pub mod types; +pub mod hooks; +pub mod registry; +pub mod pipeline; +pub mod webhook; +pub mod telegram; +pub mod discord; +pub mod webchat; + +// Re-export the public surface. +pub use types::*; +pub use hooks::{ChannelHook, TracingHook}; +pub use registry::ChannelRegistry; +pub use pipeline::Pipeline; + +use async_trait::async_trait; + +use crate::error::Result; + +// ─── Channel trait ────────────────────────────────────── + +/// The core abstraction: every messaging platform adapter implements this. +/// +/// A channel knows how to: +/// - Start listening for inbound messages (push them into the pipeline). +/// - Send outbound messages back to users on its platform. +/// - Report its health status. +/// +/// Optional capabilities (typing indicators, message editing, media) have +/// default no-op implementations so simple channels needn't bother. +#[async_trait] +pub trait Channel: Send + Sync + 'static { + /// Unique, lowercase channel identifier: `"telegram"`, `"discord"`, etc. + fn id(&self) -> &str; + + /// Human-readable display name. + fn display_name(&self) -> &str; + + /// Start the adapter. + /// + /// The implementation should spawn any long-running tasks (polling loops, + /// WebSocket connections) and push inbound messages into + /// `ctx.inbound_tx`. It must respect `ctx.shutdown` to exit cleanly. + async fn start(&self, ctx: ChannelContext) -> Result<()>; + + /// Stop the adapter gracefully. Called before process exit. + async fn stop(&self) -> Result<()>; + + /// Send a message to a user on this platform. + async fn send(&self, target: &OutboundTarget, message: OutboundMessage) -> Result<()>; + + /// Report current health. 
+ async fn health(&self) -> ChannelHealth; + + // ── Optional capabilities ─────────────────────────── + + /// Maximum text length this channel supports per message. + /// The outbound pipeline uses this for chunking. + fn max_message_length(&self) -> usize { + types::max_message_length(self.id()) + } + + /// Send a "typing…" indicator. + async fn send_typing(&self, _target: &OutboundTarget) -> Result<()> { + Ok(()) // no-op by default + } + + /// Edit an already-sent message (for streaming responses). + async fn edit_message(&self, _message_id: &str, _new_text: &str) -> Result<()> { + Err(crate::error::CortexError::Unsupported( + "edit_message not supported on this channel".into(), + )) + } + + /// Whether this channel supports media attachments. + fn supports_media(&self) -> bool { + false + } + + /// Send a media attachment. + async fn send_media( + &self, + _target: &OutboundTarget, + _media: MediaPayload, + ) -> Result<()> { + Err(crate::error::CortexError::Unsupported( + "send_media not supported on this channel".into(), + )) + } +} diff --git a/src/channels/pipeline.rs b/src/channels/pipeline.rs new file mode 100644 index 0000000..8c5b20d --- /dev/null +++ b/src/channels/pipeline.rs @@ -0,0 +1,369 @@ +//! Pipeline — inbound and outbound message processing. +//! +//! The pipeline is the heart of the omnichannel system. Every inbound message +//! — regardless of which channel it came from — flows through the same stages: +//! +//! ```text +//! Inbound: +//! 1. Normalise (trim, detect /commands) +//! 2. Identity resolution (channel + external_id → internal user_id) +//! 3. Session resolution (user_id + channel → session graph node) +//! 4. Hooks: before_agent (allowlist, rate-limit, commands) +//! 5. Agent: run_turn(session_id, text) → reply +//! 6. Hooks: after_agent (content filtering, augmentation) +//! 7. Record turn +//! 8. Outbound delivery +//! +//! Outbound: +//! 1. Hooks: before_send +//! 2. Chunk (split long replies per channel limits) +//! 3. 
Channel.send(target, chunk) +//! 4. Hooks: after_send +//! ``` + +use std::sync::Arc; + +use tokio::sync::mpsc; + +use crate::agent::Agent; +use crate::db::Db; +use crate::error::{CortexError, Result}; +use crate::identity::{self, ChannelId}; +use crate::session; +use crate::types::TurnContext; + +use super::hooks::ChannelHook; +use super::registry::ChannelRegistry; +use super::types::*; + +/// The unified message pipeline. +pub struct Pipeline { + /// Channel registry — used for outbound routing. + registry: Arc, + /// Registered hooks, executed in order. + hooks: Vec>, +} + +impl Pipeline { + /// Create a new pipeline. + pub fn new(registry: Arc) -> Self { + Self { + registry, + hooks: Vec::new(), + } + } + + /// Register a lifecycle hook. Hooks execute in registration order. + pub fn add_hook(&mut self, hook: Arc) { + tracing::info!(hook = hook.name(), "pipeline hook registered"); + self.hooks.push(hook); + } + + /// Process a single inbound message through the full pipeline. + /// + /// This is the core method. It performs identity resolution, session + /// management, agent execution, and outbound delivery. + pub async fn process( + &self, + mut envelope: InboundEnvelope, + db: &Db, + agent: &Agent, + ) -> Result { + // ── 1. Normalise ──────────────────────────────── + normalise(&mut envelope); + + // ── 2. Identity resolution ────────────────────── + let channel_id = ChannelId::new(&envelope.channel, &envelope.external_id); + let user = identity::resolve_user(db, channel_id).await.map_err(|e| { + CortexError::Pipeline(format!("Identity resolution failed: {e}")) + })?; + + // ── 3. Session resolution ─────────────────────── + let managed = session::get_or_create(db, &user.id, &envelope.channel) + .await + .map_err(|e| { + CortexError::Pipeline(format!("Session resolution failed: {e}")) + })?; + + // ── 4. 
Hooks: before_agent ────────────────────── + for hook in &self.hooks { + if let Err(e) = hook.before_agent(&mut envelope).await { + tracing::warn!( + hook = hook.name(), + error = %e, + "before_agent hook rejected message" + ); + return Err(e); + } + } + + // ── 5. Agent ──────────────────────────────────── + let turn_ctx = TurnContext { + channel: envelope.channel.clone(), + sender_name: envelope.sender_name.clone().or(user.display_name.clone()), + user_id: user.id.clone(), + is_group: envelope.group_id.is_some(), + }; + + let mut reply = agent + .run_turn(&managed.node_id, &envelope.text, &turn_ctx, envelope.media.as_ref()) + .await + .map_err(|e| CortexError::Pipeline(format!("Agent error: {e}")))?; + + // ── 6. Hooks: after_agent ─────────────────────── + for hook in &self.hooks { + if let Err(e) = hook.after_agent(&envelope, &mut reply).await { + tracing::warn!( + hook = hook.name(), + error = %e, + "after_agent hook error (continuing)" + ); + // after_agent errors are non-fatal — we still have a reply + } + } + + // ── 7. Record turn ───────────────────────────── + let _ = session::record_turn(db, &managed.node_id).await; + + // ── 8. Outbound delivery ──────────────────────── + let target = OutboundTarget::from_envelope(&envelope); + let message = OutboundMessage::text(&reply); + + if let Err(e) = self.send_outbound(&target, message).await { + tracing::error!( + channel = %target.channel, + error = %e, + "outbound delivery failed" + ); + // Don't fail the whole pipeline — the reply was produced, just delivery failed + } + + Ok(PipelineResult { + reply, + user_id: user.id, + session_id: managed.node_id, + }) + } + + /// Process a message synchronously (no outbound delivery). + /// + /// Used by the HTTP API where the caller handles the response directly. 
+ pub async fn process_sync( + &self, + mut envelope: InboundEnvelope, + db: &Db, + agent: &Agent, + ) -> Result { + normalise(&mut envelope); + + let channel_id = ChannelId::new(&envelope.channel, &envelope.external_id); + let user = identity::resolve_user(db, channel_id).await.map_err(|e| { + CortexError::Pipeline(format!("Identity resolution failed: {e}")) + })?; + + let managed = session::get_or_create(db, &user.id, &envelope.channel) + .await + .map_err(|e| { + CortexError::Pipeline(format!("Session resolution failed: {e}")) + })?; + + for hook in &self.hooks { + hook.before_agent(&mut envelope).await?; + } + + let turn_ctx = TurnContext { + channel: envelope.channel.clone(), + sender_name: envelope.sender_name.clone().or(user.display_name.clone()), + user_id: user.id.clone(), + is_group: envelope.group_id.is_some(), + }; + + let mut reply = agent + .run_turn(&managed.node_id, &envelope.text, &turn_ctx, envelope.media.as_ref()) + .await + .map_err(|e| CortexError::Pipeline(format!("Agent error: {e}")))?; + + for hook in &self.hooks { + let _ = hook.after_agent(&envelope, &mut reply).await; + } + + let _ = session::record_turn(db, &managed.node_id).await; + + Ok(PipelineResult { + reply, + user_id: user.id, + session_id: managed.node_id, + }) + } + + /// Send an outbound message, running hooks and applying chunking. 
+ pub async fn send_outbound( + &self, + target: &OutboundTarget, + mut message: OutboundMessage, + ) -> Result<()> { + // ── before_send hooks ─────────────────────────── + for hook in &self.hooks { + if let Err(e) = hook.before_send(target, &mut message).await { + tracing::warn!( + hook = hook.name(), + error = %e, + "before_send hook rejected" + ); + return Err(e); + } + } + + // ── Chunking ─────────────────────────────────── + let channel = self.registry.get(&target.channel).await; + let max_len = channel + .as_ref() + .map(|ch| ch.max_message_length()) + .unwrap_or(4096); + + let chunks = chunk_text(&message.text, max_len); + + // ── Send each chunk ───────────────────────────── + if let Some(ch) = channel { + for (i, chunk) in chunks.iter().enumerate() { + let chunk_msg = OutboundMessage { + text: chunk.clone(), + media: if i == 0 { message.media.clone() } else { None }, + metadata: message.metadata.clone(), + }; + + ch.send(target, chunk_msg).await?; + + // Send typing between chunks (except the last) + if i < chunks.len() - 1 { + let _ = ch.send_typing(target).await; + } + } + } else { + tracing::warn!( + channel = %target.channel, + "no channel adapter found for outbound — skipping delivery" + ); + } + + // ── after_send hooks ──────────────────────────── + for hook in &self.hooks { + if let Err(e) = hook.after_send(target, &message).await { + tracing::warn!( + hook = hook.name(), + error = %e, + "after_send hook error (ignored)" + ); + } + } + + Ok(()) + } + + /// Start the background inbound processing loop. + /// + /// Reads envelopes from the channel registry's inbound receiver and + /// processes each one through the pipeline. Runs until the receiver is + /// closed (i.e. all channels have stopped). 
+ pub async fn run_inbound_loop( + self: Arc, + mut rx: mpsc::Receiver, + db: Db, + agent: Arc, + ) { + tracing::info!("pipeline inbound loop started"); + while let Some(envelope) = rx.recv().await { + let pipeline = Arc::clone(&self); + let db = db.clone(); + let agent = Arc::clone(&agent); + + // Process each message in its own task so one slow turn + // doesn't block the rest. + tokio::spawn(async move { + match pipeline.process(envelope, &db, &agent).await { + Ok(result) => { + tracing::debug!( + user_id = %result.user_id, + session_id = %result.session_id, + reply_len = result.reply.len(), + "pipeline turn complete" + ); + } + Err(e) => { + tracing::error!(error = %e, "pipeline processing error"); + } + } + }); + } + tracing::info!("pipeline inbound loop ended (all channels closed)"); + } +} + +// ─── Helpers ──────────────────────────────────────────── + +/// Normalise an inbound envelope: trim whitespace, collapse newlines. +fn normalise(envelope: &mut InboundEnvelope) { + envelope.text = envelope.text.trim().to_string(); +} + +/// Split text into chunks respecting the given max length. +/// +/// Tries to split on double-newlines first, then single newlines, then spaces. +/// Falls back to hard character splits as a last resort. 
/// # Behaviour
///
/// * `max_len` is measured in bytes. Every cut is placed on a UTF-8 character
///   boundary, so multi-byte text never causes a slicing panic.
/// * A chunk exceeds `max_len` only when a single character is wider than
///   `max_len` (including `max_len == 0`); that character is emitted whole so
///   the loop always makes progress.
/// * Whitespace around a cut is trimmed and empty chunks are never emitted,
///   except that the result always holds at least one (possibly empty) string.
fn chunk_text(text: &str, max_len: usize) -> Vec<String> {
    // Fast path: the whole text fits in one message.
    if text.len() <= max_len {
        return vec![text.to_string()];
    }

    // A separator at offset 0 would produce an empty chunk, so it is never a
    // usable cut position.
    let usable = |idx: &usize| *idx > 0;

    let mut chunks = Vec::new();
    let mut remaining = text;

    while !remaining.is_empty() {
        if remaining.len() <= max_len {
            chunks.push(remaining.to_string());
            break;
        }

        // `remaining.len() > max_len` here, so `max_len` is a valid byte
        // index. Walk back to the nearest character boundary; offset 0 is
        // always a boundary, so this cannot underflow.
        let mut window_end = max_len;
        while !remaining.is_char_boundary(window_end) {
            window_end -= 1;
        }
        if window_end == 0 {
            // The first character alone is wider than `max_len`: take it whole.
            window_end = remaining
                .chars()
                .next()
                .map_or(remaining.len(), char::len_utf8);
        }
        let window = &remaining[..window_end];

        // Preferred cut: paragraph break, then line break, then space, and
        // finally a hard cut at the end of the window.
        let split_at = window
            .rfind("\n\n")
            .filter(usable)
            .or_else(|| window.rfind('\n').filter(usable))
            .or_else(|| window.rfind(' ').filter(usable))
            .unwrap_or(window_end);

        let (chunk, rest) = remaining.split_at(split_at);
        let chunk = chunk.trim_end();
        if !chunk.is_empty() {
            chunks.push(chunk.to_string());
        }
        remaining = rest.trim_start();
    }

    // Whitespace-only input longer than `max_len` trims away completely; keep
    // the "at least one chunk" shape callers see for short input.
    if chunks.is_empty() {
        chunks.push(String::new());
    }

    chunks
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chunk_short_message() {
        let chunks = chunk_text("hello", 100);
        assert_eq!(chunks, vec!["hello"]);
    }

    #[test]
    fn test_chunk_on_newline() {
        let text = "line one\n\nline two\n\nline three";
        let chunks = chunk_text(text, 15);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], "line one");
        assert_eq!(chunks[1], "line two");
        assert_eq!(chunks[2], "line three");
    }

    #[test]
    fn test_chunk_on_space() {
        let text = "word1 word2 word3 word4";
        let chunks = chunk_text(text, 12);
        assert!(chunks.len() >= 2);
    }
}

// ── Remainder of this physical line in the dump: patch metadata and the module
// ── docs of the next file, preserved verbatim as comments. ──
// diff --git a/src/channels/registry.rs b/src/channels/registry.rs
// new file mode 100644
// index 0000000..adfe581
// --- /dev/null
// +++ b/src/channels/registry.rs
// @@ -0,0 +1,136 @@
// +//! Channel registry — manages the lifecycle of all channel adapters.
// +//!
// +//! The registry owns every registered [`Channel`], starts and stops them as a
// +//! group, and provides the shared inbound MPSC sender that channels push
// +//! messages into. The [`Pipeline`] reads from the other end.
+ +use std::collections::HashMap; +use std::sync::Arc; + +use tokio::sync::{mpsc, watch, RwLock}; + +use crate::db::Db; +use crate::error::Result; + +use super::types::{ChannelContext, ChannelHealth, InboundEnvelope}; +use super::Channel; + +/// Manages all registered channel adapters. +pub struct ChannelRegistry { + /// Registered channels keyed by their `id()`. + channels: Arc>>>, + /// The sending half — cloned to each channel on start. + inbound_tx: mpsc::Sender, + /// The receiving half — handed to the pipeline. + inbound_rx: Option>, + /// Shutdown broadcaster. + shutdown_tx: watch::Sender, + /// Shutdown receiver (cloned per channel). + shutdown_rx: watch::Receiver, +} + +impl ChannelRegistry { + /// Create a new registry with the given inbound buffer size. + pub fn new(buffer: usize) -> Self { + let (inbound_tx, inbound_rx) = mpsc::channel(buffer); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + Self { + channels: Arc::new(RwLock::new(HashMap::new())), + inbound_tx, + inbound_rx: Some(inbound_rx), + shutdown_tx, + shutdown_rx, + } + } + + /// Register a channel adapter. If one with the same `id()` already exists + /// it is replaced (the old one is not stopped — call `stop_all` first). + pub async fn register(&self, channel: Arc) { + let id = channel.id().to_string(); + tracing::info!(channel = %id, "channel registered"); + self.channels.write().await.insert(id, channel); + } + + /// Take the inbound receiver. Only call once — the pipeline needs it. + pub fn take_inbound_rx(&mut self) -> Option> { + self.inbound_rx.take() + } + + /// Start all registered channels. + /// + /// Each channel receives its own [`ChannelContext`] with a cloned + /// `inbound_tx`, the DB handle, its config section from `channel_configs`, + /// and the shutdown watch. 
+ pub async fn start_all( + &self, + db: &Db, + channel_configs: &HashMap, + ) -> Result<()> { + let channels = self.channels.read().await; + let mut started = 0usize; + let mut failed = 0usize; + + for (id, ch) in channels.iter() { + let config = channel_configs + .get(id) + .cloned() + .unwrap_or(serde_json::Value::Null); + + let ctx = ChannelContext { + inbound_tx: self.inbound_tx.clone(), + db: db.clone(), + config, + shutdown: self.shutdown_rx.clone(), + }; + + tracing::info!(channel = %id, "starting channel adapter"); + match ch.start(ctx).await { + Ok(()) => { + started += 1; + } + Err(e) => { + // Log and skip — don't abort other channels + tracing::warn!(channel = %id, error = %e, "channel failed to start (skipping)"); + failed += 1; + } + } + } + + tracing::info!(started, failed, "channel startup complete"); + Ok(()) + } + + /// Stop all registered channels and signal shutdown. + pub async fn stop_all(&self) -> Result<()> { + let _ = self.shutdown_tx.send(true); + let channels = self.channels.read().await; + for (id, ch) in channels.iter() { + tracing::info!(channel = %id, "stopping channel adapter"); + if let Err(e) = ch.stop().await { + tracing::warn!(channel = %id, error = %e, "error stopping channel"); + } + } + Ok(()) + } + + /// Get a channel adapter by its ID (for outbound routing). + pub async fn get(&self, id: &str) -> Option> { + self.channels.read().await.get(id).cloned() + } + + /// List all channels and their current health. + pub async fn health_all(&self) -> Vec<(String, ChannelHealth)> { + let channels = self.channels.read().await; + let mut out = Vec::with_capacity(channels.len()); + for (id, ch) in channels.iter() { + let h = ch.health().await; + out.push((id.clone(), h)); + } + out + } + + /// List the IDs of all registered channels. 
+ pub async fn list_ids(&self) -> Vec { + self.channels.read().await.keys().cloned().collect() + } +} diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs new file mode 100644 index 0000000..038a3d4 --- /dev/null +++ b/src/channels/telegram.rs @@ -0,0 +1,569 @@ +//! Telegram channel adapter — connects to the Telegram Bot API. +//! +//! Supports two modes: +//! - **Polling** (default): calls `getUpdates` in a loop. Simple, no public URL needed. +//! - **Webhook**: Telegram POSTs updates to our `/v1/channels/telegram/webhook`. +//! Requires a public HTTPS URL. +//! +//! # Configuration +//! +//! ```json +//! { +//! "bot_token": "123456:ABCDEF…", +//! "mode": "polling", // "polling" or "webhook" +//! "webhook_url": "https://…", // required if mode = "webhook" +//! "allow_from": ["*"], // Telegram user IDs, or "*" for all +//! "polling_timeout": 30 // long-poll timeout in seconds +//! } +//! ``` +//! +//! The `bot_token` can also be set via the `TELEGRAM_BOT_TOKEN` env var +//! (env var takes precedence). + +use std::sync::atomic::{AtomicBool, Ordering}; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +use crate::error::{CortexError, Result}; + +use super::types::*; +use super::Channel; + +const BASE_URL: &str = "https://api.telegram.org/bot"; + +/// Telegram channel adapter. +pub struct TelegramChannel { + client: reqwest::Client, + started: AtomicBool, + /// Stored after start() to allow stop() to signal shutdown. 
+ cancel: tokio::sync::watch::Sender, +} + +impl TelegramChannel { + pub fn new() -> Self { + let (cancel, _) = tokio::sync::watch::channel(false); + Self { + client: reqwest::Client::new(), + started: AtomicBool::new(false), + cancel, + } + } + + fn resolve_token(config: &serde_json::Value) -> Result { + // Env var takes precedence + if let Ok(token) = std::env::var("TELEGRAM_BOT_TOKEN") { + return Ok(token); + } + config + .get("bot_token") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .ok_or_else(|| { + CortexError::Config( + "Telegram: bot_token not set in config or TELEGRAM_BOT_TOKEN env var" + .into(), + ) + }) + } +} + +// ─── Telegram API types ───────────────────────────────── + +#[derive(Debug, Deserialize)] +struct TgResponse { + ok: bool, + result: Option, + description: Option, +} + +#[derive(Debug, Deserialize)] +struct TgUpdate { + update_id: i64, + message: Option, +} + +#[derive(Debug, Deserialize)] +struct TgMessage { + message_id: i64, + from: Option, + chat: TgChat, + text: Option, + /// Caption for media messages (photo, document, etc.). + caption: Option, + /// Photo sizes — Telegram sends multiple resolutions; we pick the largest. 
+ photo: Option>, +} + +#[derive(Debug, Deserialize)] +struct TgPhotoSize { + file_id: String, + #[allow(dead_code)] + file_unique_id: String, + #[allow(dead_code)] + width: i64, + #[allow(dead_code)] + height: i64, + #[serde(default)] + file_size: Option, +} + +#[derive(Debug, Deserialize)] +#[allow(dead_code)] +struct TgFile { + file_id: String, + file_path: Option, +} + +#[derive(Debug, Deserialize)] +struct TgUser { + id: i64, + first_name: String, + last_name: Option, + #[allow(dead_code)] + username: Option, +} + +#[derive(Debug, Deserialize)] +struct TgChat { + id: i64, + #[serde(rename = "type")] + chat_type: String, +} + +#[derive(Debug, Serialize)] +struct SendMessageRequest { + chat_id: i64, + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + reply_to_message_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + parse_mode: Option, +} + +#[derive(Debug, Deserialize)] +#[allow(dead_code)] +struct TgSentMessage { + message_id: i64, +} + +// ─── Channel implementation ───────────────────────────── + +#[async_trait] +impl Channel for TelegramChannel { + fn id(&self) -> &str { + "telegram" + } + + fn display_name(&self) -> &str { + "Telegram" + } + + async fn start(&self, ctx: ChannelContext) -> Result<()> { + let token = Self::resolve_token(&ctx.config)?; + let mode = ctx + .config + .get("mode") + .and_then(|v| v.as_str()) + .unwrap_or("polling"); + + let polling_timeout = ctx + .config + .get("polling_timeout") + .and_then(|v| v.as_u64()) + .unwrap_or(30); + + match mode { + "polling" => { + self.started.store(true, Ordering::SeqCst); + let client = self.client.clone(); + let inbound_tx = ctx.inbound_tx.clone(); + let mut shutdown = ctx.shutdown.clone(); + let mut cancel_rx = self.cancel.subscribe(); + + tokio::spawn(async move { + let mut offset: i64 = 0; + tracing::info!("telegram polling loop started"); + + loop { + // Check shutdown signals + if *shutdown.borrow() || *cancel_rx.borrow() { + tracing::info!("telegram polling 
loop shutting down"); + break; + } + + let url = format!( + "{}{}/getUpdates?offset={}&timeout={}&allowed_updates=[\"message\"]", + BASE_URL, token, offset, polling_timeout + ); + + let result = tokio::select! { + r = client.get(&url).send() => r, + _ = shutdown.changed() => break, + _ = cancel_rx.changed() => break, + }; + + match result { + Ok(resp) => { + match resp.json::>>().await { + Ok(tg_resp) if tg_resp.ok => { + if let Some(updates) = tg_resp.result { + for update in updates { + offset = update.update_id + 1; + if let Some(msg) = update.message { + // Extract text: prefer `text`, fall back to `caption` for media messages + let text = msg.text.clone() + .or_else(|| msg.caption.clone()); + + // Download photo if present + let media = if let Some(ref photos) = msg.photo { + // Pick the largest photo (last in the array) + if let Some(photo) = photos.last() { + download_telegram_photo( + &client, &token, &photo.file_id, + ).await + } else { + None + } + } else { + None + }; + + // Skip if there's neither text nor media + if text.is_none() && media.is_none() { + continue; + } + + let sender_id = msg + .from + .as_ref() + .map(|u| u.id.to_string()) + .unwrap_or_else(|| { + msg.chat.id.to_string() + }); + let sender_name = msg.from.as_ref().map( + |u| { + let mut name = u.first_name.clone(); + if let Some(ref last) = u.last_name + { + name.push(' '); + name.push_str(last); + } + name + }, + ); + + let group_id = + if msg.chat.chat_type != "private" { + Some(msg.chat.id.to_string()) + } else { + None + }; + + let envelope = InboundEnvelope { + channel: "telegram".into(), + external_id: sender_id, + sender_name, + text: text.unwrap_or_default(), + media, + reply_to: None, + group_id, + callback_url: None, + raw: serde_json::json!({ + "chat_id": msg.chat.id, + "message_id": msg.message_id, + }), + timestamp: now_unix(), + }; + + if inbound_tx.send(envelope).await.is_err() + { + tracing::error!( + "telegram: inbound channel closed" + ); + return; + } + } + } + } + } 
+ Ok(tg_resp) => { + tracing::warn!( + desc = ?tg_resp.description, + "telegram API error" + ); + tokio::time::sleep( + std::time::Duration::from_secs(5), + ) + .await; + } + Err(e) => { + tracing::warn!(error = %e, "telegram parse error"); + tokio::time::sleep( + std::time::Duration::from_secs(5), + ) + .await; + } + } + } + Err(e) => { + tracing::warn!(error = %e, "telegram HTTP error"); + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + } + } + } + }); + } + "webhook" => { + // Webhook mode: Telegram will POST to our endpoint. + // We need to register the webhook URL with Telegram. + let webhook_url = ctx + .config + .get("webhook_url") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + CortexError::Config( + "Telegram webhook mode requires 'webhook_url' in config".into(), + ) + })?; + + let url = format!( + "{}{}/setWebhook?url={}/v1/channels/telegram/webhook", + BASE_URL, token, webhook_url + ); + let resp = self.client.get(&url).send().await.map_err(|e| { + CortexError::Channel(format!("Failed to set Telegram webhook: {e}")) + })?; + + let body: TgResponse = resp.json().await.map_err(|e| { + CortexError::Channel(format!("Failed to parse webhook response: {e}")) + })?; + + if !body.ok { + return Err(CortexError::Channel(format!( + "Telegram setWebhook failed: {}", + body.description.unwrap_or_default() + ))); + } + + self.started.store(true, Ordering::SeqCst); + tracing::info!(url = %webhook_url, "telegram webhook registered"); + } + other => { + return Err(CortexError::Config(format!( + "Unknown Telegram mode: '{other}'. Use 'polling' or 'webhook'." 
+ ))); + } + } + + Ok(()) + } + + async fn stop(&self) -> Result<()> { + let _ = self.cancel.send(true); + self.started.store(false, Ordering::SeqCst); + tracing::info!("telegram channel stopped"); + Ok(()) + } + + async fn send(&self, target: &OutboundTarget, message: OutboundMessage) -> Result<()> { + // Determine the chat_id: use group_id if present, otherwise external_id + let chat_id: i64 = target + .group_id + .as_deref() + .or(Some(&target.external_id)) + .and_then(|s| s.parse().ok()) + .ok_or_else(|| { + CortexError::Channel(format!( + "Invalid Telegram chat_id: {}", + target.external_id + )) + })?; + + // We need the token — try env var first, then check raw metadata + let token = std::env::var("TELEGRAM_BOT_TOKEN").map_err(|_| { + CortexError::Channel( + "TELEGRAM_BOT_TOKEN not set — cannot send outbound message".into(), + ) + })?; + + let reply_to = target + .reply_to_message_id + .as_ref() + .and_then(|s| s.parse::().ok()); + + let req = SendMessageRequest { + chat_id, + text: message.text, + reply_to_message_id: reply_to, + parse_mode: Some("Markdown".into()), + }; + + let url = format!("{}{}/sendMessage", BASE_URL, token); + let resp = self + .client + .post(&url) + .json(&req) + .send() + .await + .map_err(|e| CortexError::Channel(format!("Telegram sendMessage failed: {e}")))?; + + let body: TgResponse = resp.json().await.map_err(|e| { + CortexError::Channel(format!("Telegram sendMessage parse error: {e}")) + })?; + + if !body.ok { + return Err(CortexError::Channel(format!( + "Telegram sendMessage error: {}", + body.description.unwrap_or_default() + ))); + } + + Ok(()) + } + + async fn health(&self) -> ChannelHealth { + if !self.started.load(Ordering::SeqCst) { + return ChannelHealth::Disconnected { + reason: "not started".into(), + }; + } + + // Quick liveness check: call getMe + let token = match std::env::var("TELEGRAM_BOT_TOKEN") { + Ok(t) => t, + Err(_) => { + return ChannelHealth::Degraded { + reason: "TELEGRAM_BOT_TOKEN not set".into(), + } + 
} + }; + + let url = format!("{}{}/getMe", BASE_URL, token); + match self.client.get(&url).send().await { + Ok(resp) if resp.status().is_success() => ChannelHealth::Connected, + Ok(resp) => ChannelHealth::Degraded { + reason: format!("API returned {}", resp.status()), + }, + Err(e) => ChannelHealth::Disconnected { + reason: format!("HTTP error: {e}"), + }, + } + } + + fn max_message_length(&self) -> usize { + 4096 + } + + async fn send_typing(&self, target: &OutboundTarget) -> Result<()> { + let chat_id: i64 = target + .group_id + .as_deref() + .or(Some(&target.external_id)) + .and_then(|s| s.parse().ok()) + .unwrap_or(0); + + if chat_id == 0 { + return Ok(()); + } + + let token = std::env::var("TELEGRAM_BOT_TOKEN").unwrap_or_default(); + let url = format!( + "{}{}/sendChatAction?chat_id={}&action=typing", + BASE_URL, token, chat_id + ); + let _ = self.client.get(&url).send().await; + Ok(()) + } + + async fn edit_message(&self, message_id: &str, new_text: &str) -> Result<()> { + // Telegram supports editMessageText but we'd need the chat_id too. + // For now, the basic implementation assumes the message_id is "chat_id:message_id". + let parts: Vec<&str> = message_id.splitn(2, ':').collect(); + if parts.len() != 2 { + return Err(CortexError::Channel( + "Telegram edit_message requires 'chat_id:message_id' format".into(), + )); + } + + let token = std::env::var("TELEGRAM_BOT_TOKEN").map_err(|_| { + CortexError::Channel("TELEGRAM_BOT_TOKEN not set".into()) + })?; + + let payload = serde_json::json!({ + "chat_id": parts[0].parse::().unwrap_or(0), + "message_id": parts[1].parse::().unwrap_or(0), + "text": new_text, + "parse_mode": "Markdown", + }); + + let url = format!("{}{}/editMessageText", BASE_URL, token); + let _ = self.client.post(&url).json(&payload).send().await; + Ok(()) + } +} + +fn now_unix() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64 +} + +/// Download a Telegram photo by its `file_id`. 
+/// +/// 1. `getFile` → obtain `file_path` +/// 2. Download raw bytes from `https://api.telegram.org/file/bot/` +/// 3. Return as `MediaPayload { kind: Image, data, mime_type: "image/jpeg" }` +async fn download_telegram_photo( + client: &reqwest::Client, + token: &str, + file_id: &str, +) -> Option { + // Step 1: getFile + let url = format!("{}{}/getFile?file_id={}", BASE_URL, token, file_id); + let resp = client.get(&url).send().await.ok()?; + let body: serde_json::Value = resp.json().await.ok()?; + let file_path = body + .get("result") + .and_then(|r| r.get("file_path")) + .and_then(|p| p.as_str())?; + + // Step 2: download bytes + let download_url = format!( + "https://api.telegram.org/file/bot{}/{}", + token, file_path + ); + let data = client + .get(&download_url) + .send() + .await + .ok()? + .bytes() + .await + .ok()? + .to_vec(); + + if data.is_empty() { + return None; + } + + // Infer MIME from file extension, default to image/jpeg + let mime = if file_path.ends_with(".png") { + "image/png" + } else if file_path.ends_with(".gif") { + "image/gif" + } else if file_path.ends_with(".webp") { + "image/webp" + } else { + "image/jpeg" + }; + + Some(MediaPayload { + kind: MediaKind::Image, + data, + mime_type: mime.to_string(), + filename: Some(file_path.to_string()), + url: Some(download_url), + }) +} diff --git a/src/channels/types.rs b/src/channels/types.rs new file mode 100644 index 0000000..379297c --- /dev/null +++ b/src/channels/types.rs @@ -0,0 +1,214 @@ +//! Channel types — the shared vocabulary for the omnichannel pipeline. +//! +//! These types are channel-agnostic: every adapter speaks in terms of +//! [`InboundEnvelope`] (messages coming in) and [`OutboundMessage`] / +//! [`OutboundTarget`] (messages going out). The pipeline never sees +//! platform-specific payloads. 
+ +use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc; + +use crate::db::Db; + +// ─── Inbound ──────────────────────────────────────────── + +/// A normalised inbound message from any channel. +/// +/// Channel adapters construct this from raw platform payloads and push it +/// into the pipeline via `ChannelContext::inbound_tx`. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InboundEnvelope { + /// Which channel this came from ("telegram", "discord", "webhook", …). + pub channel: String, + /// The sender's external identifier on the channel. + pub external_id: String, + /// Display name of the sender (if the channel provides one). + pub sender_name: Option, + /// The user's message text. + pub text: String, + /// Optional media attachment. + #[serde(skip_serializing_if = "Option::is_none")] + pub media: Option, + /// If replying to a specific message, its platform message ID. + #[serde(skip_serializing_if = "Option::is_none")] + pub reply_to: Option, + /// Group / guild / workspace ID (if this is a group message). + #[serde(skip_serializing_if = "Option::is_none")] + pub group_id: Option, + /// A URL the channel can POST the reply to (webhook callback). + #[serde(skip_serializing_if = "Option::is_none")] + pub callback_url: Option, + /// The raw, channel-specific payload for hooks that need it. + #[serde(default)] + pub raw: serde_json::Value, + /// Unix timestamp (seconds) when the message was received. + pub timestamp: i64, +} + +impl InboundEnvelope { + /// Build a minimal envelope (used by the webhook adapter and tests). 
+ pub fn new(channel: &str, external_id: &str, text: &str) -> Self { + Self { + channel: channel.to_string(), + external_id: external_id.to_string(), + sender_name: None, + text: text.to_string(), + media: None, + reply_to: None, + group_id: None, + callback_url: None, + raw: serde_json::Value::Null, + timestamp: now_unix(), + } + } +} + +// ─── Outbound ─────────────────────────────────────────── + +/// Who to send a reply to. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OutboundTarget { + /// Channel identifier to route through. + pub channel: String, + /// External user/chat ID on that channel. + pub external_id: String, + /// Group/guild context (if replying in a group). + #[serde(skip_serializing_if = "Option::is_none")] + pub group_id: Option, + /// Platform message ID to reply to (threaded replies). + #[serde(skip_serializing_if = "Option::is_none")] + pub reply_to_message_id: Option, + /// Optional callback URL (for webhook channels that POST replies). + #[serde(skip_serializing_if = "Option::is_none")] + pub callback_url: Option, +} + +impl OutboundTarget { + /// Derive an outbound target from an inbound envelope. + pub fn from_envelope(env: &InboundEnvelope) -> Self { + Self { + channel: env.channel.clone(), + external_id: env.external_id.clone(), + group_id: env.group_id.clone(), + reply_to_message_id: env.reply_to.clone(), + callback_url: env.callback_url.clone(), + } + } +} + +/// An outbound message — text, optional media, arbitrary metadata. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OutboundMessage { + /// The reply text. + pub text: String, + /// Optional media attachment. + #[serde(skip_serializing_if = "Option::is_none")] + pub media: Option, + /// Arbitrary metadata (per-channel or per-hook). + #[serde(default)] + pub metadata: serde_json::Value, +} + +impl OutboundMessage { + /// Plain text reply. 
+ pub fn text(s: impl Into) -> Self { + Self { + text: s.into(), + media: None, + metadata: serde_json::Value::Null, + } + } +} + +// ─── Media ────────────────────────────────────────────── + +/// A media attachment (image, audio, video, document). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MediaPayload { + pub kind: MediaKind, + /// Raw bytes — `#[serde(skip)]` because we don't serialise blobs over JSON. + #[serde(skip)] + pub data: Vec, + /// MIME type, e.g. "image/jpeg". + pub mime_type: String, + /// Original filename, if known. + #[serde(skip_serializing_if = "Option::is_none")] + pub filename: Option, + /// URL where the media can be fetched (for channels that use URLs). + #[serde(skip_serializing_if = "Option::is_none")] + pub url: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum MediaKind { + Image, + Audio, + Video, + Document, +} + +// ─── Channel health ───────────────────────────────────── + +/// Reported by each channel adapter via the `health()` method. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "status", rename_all = "snake_case")] +pub enum ChannelHealth { + Connected, + Degraded { reason: String }, + Disconnected { reason: String }, +} + +// ─── Channel context ──────────────────────────────────── + +/// Passed to a channel adapter when it starts. +/// +/// Gives the adapter everything it needs to push inbound messages into the +/// pipeline and access channel-specific configuration. +pub struct ChannelContext { + /// Push inbound messages here — the pipeline picks them up. + pub inbound_tx: mpsc::Sender, + /// Database handle (for low-level needs; most channels don't need this). + pub db: Db, + /// Channel-specific configuration section (parsed from master config). + pub config: serde_json::Value, + /// Shutdown signal — channels should select on this and exit gracefully. 
+ pub shutdown: tokio::sync::watch::Receiver, +} + +// ─── Message length limits (for outbound chunking) ────── + +/// Maximum message length per channel. If a reply exceeds this, the outbound +/// pipeline will split it into multiple sends. +pub fn max_message_length(channel: &str) -> usize { + match channel { + "telegram" => 4096, + "discord" => 2000, + "slack" => 40_000, + "whatsapp" => 65_536, + "webchat" => 100_000, // practically unlimited + _ => 4096, // conservative default + } +} + +// ─── Pipeline result (returned to the HTTP API) ───────── + +/// The result of processing a single inbound message through the pipeline. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PipelineResult { + /// The agent's reply text. + pub reply: String, + /// Internal user ID resolved by the identity layer. + pub user_id: String, + /// Graph-node session ID. + pub session_id: String, +} + +// ─── Helpers ──────────────────────────────────────────── + +pub(crate) fn now_unix() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64 +} diff --git a/src/channels/webchat.rs b/src/channels/webchat.rs new file mode 100644 index 0000000..b930676 --- /dev/null +++ b/src/channels/webchat.rs @@ -0,0 +1,208 @@ +//! WebChat channel — built-in WebSocket chat served from the gateway. +//! +//! Provides a real-time chat interface via WebSocket upgrade at +//! `ws://host:port/v1/ws/chat`. Supports streaming-style responses by +//! sending the full reply once the agent finishes. +//! +//! # Configuration +//! +//! ```json +//! { +//! "require_auth": false, // whether WS connections need an API key +//! "max_connections": 100 // max concurrent WebSocket connections +//! } +//! 
``` + +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +use crate::error::Result; + +use super::types::*; +use super::Channel; + +/// WebChat channel adapter. +/// +/// Unlike other channels, WebChat doesn't connect to an external service. +/// It serves WebSocket connections directly from the axum server. The axum +/// WebSocket handler creates `InboundEnvelope` messages and pushes them +/// into the pipeline; replies are sent back through the WebSocket. +/// +/// This struct tracks state but the actual WS upgrade happens in the API +/// layer (axum route). +pub struct WebChatChannel { + started: AtomicBool, + /// Number of active WebSocket connections. + active_connections: Arc, + /// Max concurrent connections. + max_connections: usize, +} + +/// A message sent from the client over WebSocket. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsChatMessage { + /// Client-provided session token (or generated on first connect). + #[serde(default)] + pub session_token: Option, + /// The user's message text. + pub text: String, + /// Optional: unique client message ID for deduplication. + #[serde(default)] + pub client_msg_id: Option, +} + +/// A message sent from the server over WebSocket. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsChatReply { + /// "reply" | "typing" | "error" | "connected" + #[serde(rename = "type")] + pub msg_type: String, + /// The reply text (for type="reply"). + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + /// Session token assigned to this connection. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_token: Option, + /// Error message (for type="error"). 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +impl WsChatReply { + pub fn connected(session_token: &str) -> Self { + Self { + msg_type: "connected".into(), + text: None, + session_token: Some(session_token.into()), + error: None, + } + } + + pub fn reply(text: &str) -> Self { + Self { + msg_type: "reply".into(), + text: Some(text.into()), + session_token: None, + error: None, + } + } + + pub fn typing() -> Self { + Self { + msg_type: "typing".into(), + text: None, + session_token: None, + error: None, + } + } + + pub fn error(msg: &str) -> Self { + Self { + msg_type: "error".into(), + text: None, + session_token: None, + error: Some(msg.into()), + } + } +} + +impl WebChatChannel { + pub fn new() -> Self { + Self { + started: AtomicBool::new(false), + active_connections: Arc::new(AtomicUsize::new(0)), + max_connections: 100, + } + } + + pub fn with_max_connections(mut self, max: usize) -> Self { + self.max_connections = max; + self + } + + /// Get current active connection count. + pub fn active_connections(&self) -> usize { + self.active_connections.load(Ordering::Relaxed) + } + + /// Get the shared connection counter (for the WS handler to increment/decrement). + pub fn connection_counter(&self) -> Arc { + Arc::clone(&self.active_connections) + } + + /// Check if a new connection can be accepted. + pub fn can_accept(&self) -> bool { + self.active_connections.load(Ordering::Relaxed) < self.max_connections + } + + /// Get the max connections limit. + pub fn max_connections(&self) -> usize { + self.max_connections + } +} + +#[async_trait] +impl Channel for WebChatChannel { + fn id(&self) -> &str { + "webchat" + } + + fn display_name(&self) -> &str { + "WebSocket Chat" + } + + async fn start(&self, ctx: ChannelContext) -> Result<()> { + // Read max_connections from config + if let Some(max) = ctx.config.get("max_connections").and_then(|v| v.as_u64()) { + // Note: we can't mutate self here, but the default is fine. 
+ // A future version could use AtomicUsize for max_connections. + tracing::info!(max_connections = max, "webchat max_connections configured"); + } + + self.started.store(true, Ordering::SeqCst); + tracing::info!( + "webchat channel started (WebSocket connections accepted at /v1/ws/chat)" + ); + Ok(()) + } + + async fn stop(&self) -> Result<()> { + self.started.store(false, Ordering::SeqCst); + let active = self.active_connections.load(Ordering::Relaxed); + if active > 0 { + tracing::info!( + active_connections = active, + "webchat channel stopping — active connections will be dropped" + ); + } + tracing::info!("webchat channel stopped"); + Ok(()) + } + + async fn send(&self, _target: &OutboundTarget, _message: OutboundMessage) -> Result<()> { + // WebChat outbound is handled directly through the WebSocket connection, + // not through this method. The WS handler sends replies inline. + // + // If we need to push messages to a specific session (e.g. notifications), + // we'd maintain a map of session_token → WS sender. That's a Phase 4 feature. + tracing::trace!("webchat send: reply delivered directly via WebSocket"); + Ok(()) + } + + async fn health(&self) -> ChannelHealth { + if self.started.load(Ordering::SeqCst) { + ChannelHealth::Connected + } else { + ChannelHealth::Disconnected { + reason: "not started".into(), + } + } + } + + fn max_message_length(&self) -> usize { + 100_000 // WebSocket messages are practically unlimited + } +} diff --git a/src/channels/webhook.rs b/src/channels/webhook.rs new file mode 100644 index 0000000..fabd4ac --- /dev/null +++ b/src/channels/webhook.rs @@ -0,0 +1,119 @@ +//! Webhook channel — generic inbound/outbound for any platform that POSTs JSON. +//! +//! This is the simplest channel adapter. It doesn't poll or hold connections — +//! it receives messages via the HTTP API (`POST /v1/channels/webhook/inbound`) +//! and optionally delivers replies by POSTing to a callback URL. +//! +//! 
Any system can integrate with omni-cede without a dedicated adapter by +//! using the webhook channel. + +use std::sync::atomic::{AtomicBool, Ordering}; + +use async_trait::async_trait; + +use crate::error::{CortexError, Result}; + +use super::types::*; +use super::Channel; + +/// A generic webhook channel adapter. +/// +/// Inbound: messages arrive via the HTTP API (the API handler creates +/// `InboundEnvelope` and pushes it into the pipeline). +/// +/// Outbound: if the inbound message included a `callback_url`, the reply +/// is POSTed there. Otherwise the reply is returned synchronously via the +/// HTTP response. +pub struct WebhookChannel { + /// HTTP client for callback delivery. + client: reqwest::Client, + /// Whether the channel is "started" (always true once start() is called). + started: AtomicBool, +} + +impl WebhookChannel { + pub fn new() -> Self { + Self { + client: reqwest::Client::new(), + started: AtomicBool::new(false), + } + } +} + +#[async_trait] +impl Channel for WebhookChannel { + fn id(&self) -> &str { + "webhook" + } + + fn display_name(&self) -> &str { + "Generic Webhook" + } + + async fn start(&self, _ctx: ChannelContext) -> Result<()> { + // Webhook channel is passive — it doesn't poll. Messages come in via + // the HTTP API. We just mark ourselves as started. 
+ self.started.store(true, Ordering::SeqCst); + tracing::info!("webhook channel started (passive — receives via HTTP API)"); + Ok(()) + } + + async fn stop(&self) -> Result<()> { + self.started.store(false, Ordering::SeqCst); + tracing::info!("webhook channel stopped"); + Ok(()) + } + + async fn send(&self, target: &OutboundTarget, message: OutboundMessage) -> Result<()> { + // If there's a callback URL, POST the reply there + if let Some(ref url) = target.callback_url { + let payload = serde_json::json!({ + "channel": target.channel, + "external_id": target.external_id, + "text": message.text, + "metadata": message.metadata, + }); + + let resp = self + .client + .post(url) + .json(&payload) + .send() + .await + .map_err(|e| { + CortexError::Channel(format!("Webhook callback failed: {e}")) + })?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(CortexError::Channel(format!( + "Webhook callback returned {status}: {body}" + ))); + } + + tracing::debug!(url = %url, "webhook callback delivered"); + } else { + // No callback URL — reply was returned synchronously via the HTTP response. + // Nothing to do here. 
+ tracing::trace!("webhook outbound: no callback_url (reply returned synchronously)"); + } + + Ok(()) + } + + async fn health(&self) -> ChannelHealth { + if self.started.load(Ordering::SeqCst) { + ChannelHealth::Connected + } else { + ChannelHealth::Disconnected { + reason: "not started".into(), + } + } + } + + fn max_message_length(&self) -> usize { + // Webhooks have no inherent limit — use a generous default + 100_000 + } +} diff --git a/src/cli/graph_tui.rs b/src/cli/graph_tui.rs index ed97e4f..f8a3b13 100644 --- a/src/cli/graph_tui.rs +++ b/src/cli/graph_tui.rs @@ -18,6 +18,8 @@ use tokio::sync::mpsc; use crate::agent::orchestrator::Agent; use crate::db::Db; +use crate::embed::EmbedHandle; +use crate::hnsw::VectorIndex; use crate::types::*; // ─── Color mapping ────────────────────────────────────── @@ -30,7 +32,9 @@ fn kind_color(kind: NodeKind) -> Color { NodeKind::Session | NodeKind::Turn | NodeKind::LlmCall | NodeKind::ToolCall | NodeKind::LoopIteration => Color::Yellow, NodeKind::Pattern | NodeKind::Limitation | NodeKind::Capability => Color::Green, - NodeKind::SubAgent | NodeKind::Delegation | NodeKind::Synthesis => Color::Blue, + NodeKind::BackgroundTask => Color::Blue, + NodeKind::CronJob | NodeKind::CronExecution | NodeKind::Skill => Color::LightBlue, + NodeKind::Notification => Color::LightYellow, } } @@ -62,7 +66,7 @@ fn short_id(id: &str) -> String { // ─── Category helpers ─────────────────────────────────── -const ALL_CATEGORIES: &[&str] = &["All", "Identity", "Knowledge", "Conversational", "Operational", "Self-Model", "Sub-Agents"]; +const ALL_CATEGORIES: &[&str] = &["All", "Identity", "Knowledge", "Conversational", "Operational", "Self-Model", "Tasks"]; fn node_category(kind: NodeKind) -> &'static str { match kind { @@ -72,7 +76,10 @@ fn node_category(kind: NodeKind) -> &'static str { NodeKind::Session | NodeKind::Turn | NodeKind::LlmCall | NodeKind::ToolCall | NodeKind::LoopIteration => "Operational", NodeKind::Pattern | NodeKind::Limitation | 
NodeKind::Capability => "Self-Model", - NodeKind::SubAgent | NodeKind::Delegation | NodeKind::Synthesis => "Sub-Agents", + NodeKind::BackgroundTask => "Tasks", + NodeKind::CronJob | NodeKind::CronExecution => "Scheduler", + NodeKind::Skill => "Skills", + NodeKind::Notification => "Notifications", } } @@ -98,6 +105,7 @@ enum Focus { NodeList, Detail, Chat, + SoulEdit, } // ─── App state ────────────────────────────────────────── @@ -127,6 +135,12 @@ pub struct App { // Stats for delta display prev_node_count: usize, prev_edge_count: usize, + // Soul edit modal + edit_node_id: Option, + edit_node_title: String, + edit_input: String, + edit_saving: bool, + prev_focus: Focus, } impl App { @@ -163,6 +177,11 @@ impl App { thinking: false, prev_node_count: nc, prev_edge_count: ec, + edit_node_id: None, + edit_node_title: String::new(), + edit_input: String::new(), + edit_saving: false, + prev_focus: Focus::NodeList, } } @@ -270,6 +289,32 @@ impl App { } conns } + + fn enter_edit_mode(&mut self) { + let info = self.selected_node().and_then(|node| { + match node.kind { + NodeKind::Soul | NodeKind::Belief | NodeKind::Goal => { + Some((node.id.clone(), node.title.clone(), node.body.clone())) + } + _ => None, + } + }); + if let Some((id, title, body)) = info { + self.edit_node_id = Some(id); + self.edit_node_title = title; + self.edit_input = body.unwrap_or_default(); + self.edit_saving = false; + self.prev_focus = self.focus; + self.focus = Focus::SoulEdit; + } + } + + fn exit_edit_mode(&mut self) { + self.edit_node_id = None; + self.edit_input.clear(); + self.edit_node_title.clear(); + self.focus = self.prev_focus; + } } fn build_lookups(nodes: &[Node], edges: &[Edge]) -> ( @@ -292,6 +337,8 @@ fn build_lookups(nodes: &[Node], edges: &[Edge]) -> ( enum AgentResult { Response(String), Error(String), + EditSaved(String), + EditError(String), } // ─── Public entry points ──────────────────────────────── @@ -320,11 +367,137 @@ pub fn run_interactive(nodes: Vec, edges: Vec) -> 
std::io::Result<() Ok(()) } +/// Launch the graph explorer with editing support for identity nodes (no chat). +pub async fn run_with_edit( + db: Db, + embed: EmbedHandle, + hnsw: Arc>, + start_category: usize, +) -> std::io::Result<()> { + enable_raw_mode()?; + let mut stdout = std::io::stdout(); + execute!(stdout, EnterAlternateScreen)?; + let backend = CrosstermBackend::new(stdout); + let mut terminal = Terminal::new(backend)?; + + let nodes = db.call(|conn| crate::db::queries::get_all_nodes_light(conn)).await + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; + let edges = db.call(|conn| crate::db::queries::get_all_edges(conn)).await + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; + + let mut app = App::new(nodes, edges); + app.focus = Focus::NodeList; + // Pre-filter to specified category (e.g., Identity = 1) + app.category_idx = start_category; + app.refilter(); + + let (result_tx, mut result_rx) = mpsc::unbounded_channel::(); + let mut event_stream = EventStream::new(); + + loop { + terminal.draw(|f| draw(f, &mut app, false))?; + + tokio::select! 
{ + maybe_event = event_stream.next() => { + match maybe_event { + Some(Ok(Event::Key(key))) if key.kind == KeyEventKind::Press => { + if app.focus == Focus::SoulEdit { + match key.code { + KeyCode::Esc => { app.exit_edit_mode(); } + KeyCode::Char('s') if key.modifiers.contains(KeyModifiers::CONTROL) => { + if let Some(ref node_id) = app.edit_node_id { + app.edit_saving = true; + let nid = node_id.clone(); + let new_body = app.edit_input.clone(); + let title = app.edit_node_title.clone(); + let db_c = db.clone(); + let embed_c = embed.clone(); + let hnsw_c = hnsw.clone(); + let tx = result_tx.clone(); + tokio::spawn(async move { + let body_c = new_body.clone(); + let nid_c = nid.clone(); + match db_c.call(move |conn| { + crate::db::queries::update_node_fields( + conn, &nid_c, None, None, Some(&body_c), None, None, + ) + }).await { + Ok(_) => { + let embed_text = format!("{} {}", title, new_body); + if let Ok(vec) = embed_c.embed(&embed_text).await { + hnsw_c.write().await.insert(nid.clone(), vec); + } + let _ = tx.send(AgentResult::EditSaved(nid)); + } + Err(e) => { + let _ = tx.send(AgentResult::EditError(e.to_string())); + } + } + }); + } + } + KeyCode::Enter => { app.edit_input.push('\n'); } + KeyCode::Backspace => { app.edit_input.pop(); } + KeyCode::Char(c) => { app.edit_input.push(c); } + _ => {} + } + } else if app.is_node_search { + match key.code { + KeyCode::Esc => { app.is_node_search = false; app.search_query.clear(); app.refilter(); } + KeyCode::Enter => { app.is_node_search = false; } + KeyCode::Backspace => { app.search_query.pop(); app.refilter(); } + KeyCode::Char(c) => { app.search_query.push(c); app.refilter(); } + _ => {} + } + } else { + // Graph keys + 'e' for edit + match key.code { + KeyCode::Char('e') => app.enter_edit_mode(), + _ => { if handle_graph_keys(&mut app, key) { break; } } + } + } + } + Some(Ok(_)) => {} + Some(Err(_)) => break, + None => break, + } + } + Some(result) = result_rx.recv() => { + match result { + 
AgentResult::EditSaved(_nid) => { + app.edit_saving = false; + app.exit_edit_mode(); + // Reload graph + if let Ok(nodes) = db.call(|conn| crate::db::queries::get_all_nodes_light(conn)).await { + if let Ok(edges) = db.call(|conn| crate::db::queries::get_all_edges(conn)).await { + app.reload_graph(nodes, edges); + } + } + } + AgentResult::EditError(e) => { + app.edit_saving = false; + // Just exit edit mode on error — user sees original content + app.exit_edit_mode(); + let _ = e; // logged via tracing in production + } + _ => {} + } + } + } + } + + disable_raw_mode()?; + execute!(terminal.backend_mut(), LeaveAlternateScreen)?; + Ok(()) +} + /// Launch the TUI with an embedded chat panel and live graph updates. pub async fn run_with_chat( db: Db, agent: Agent, session_id: String, + embed: Option, + hnsw: Option>>, ) -> std::io::Result<()> { enable_raw_mode()?; let mut stdout = std::io::stdout(); @@ -351,7 +524,53 @@ pub async fn run_with_chat( maybe_event = event_stream.next() => { match maybe_event { Some(Ok(Event::Key(key))) if key.kind == KeyEventKind::Press => { - if app.focus == Focus::Chat && !app.is_node_search { + if app.focus == Focus::SoulEdit { + match key.code { + KeyCode::Esc => { app.exit_edit_mode(); } + KeyCode::Char('s') if key.modifiers.contains(KeyModifiers::CONTROL) => { + // Save the edit + if let Some(ref node_id) = app.edit_node_id { + app.edit_saving = true; + let nid = node_id.clone(); + let new_body = app.edit_input.clone(); + let title = app.edit_node_title.clone(); + let db_c = db.clone(); + let embed_c = embed.clone(); + let hnsw_c = hnsw.clone(); + let tx = result_tx.clone(); + tokio::spawn(async move { + let body_c = new_body.clone(); + let nid_c = nid.clone(); + match db_c.call(move |conn| { + crate::db::queries::update_node_fields( + conn, &nid_c, None, None, Some(&body_c), None, None, + ) + }).await { + Ok(_) => { + // Re-embed if embed handle available + if let Some(ref emb) = embed_c { + let embed_text = format!("{} {}", title, 
new_body); + if let Ok(vec) = emb.embed(&embed_text).await { + if let Some(ref h) = hnsw_c { + h.write().await.insert(nid.clone(), vec); + } + } + } + let _ = tx.send(AgentResult::EditSaved(nid)); + } + Err(e) => { + let _ = tx.send(AgentResult::EditError(e.to_string())); + } + } + }); + } + } + KeyCode::Enter => { app.edit_input.push('\n'); } + KeyCode::Backspace => { app.edit_input.pop(); } + KeyCode::Char(c) => { app.edit_input.push(c); } + _ => {} + } + } else if app.focus == Focus::Chat && !app.is_node_search { match key.code { KeyCode::Char('c') if key.modifiers.contains(KeyModifiers::CONTROL) => break, KeyCode::Esc => { app.focus = Focus::NodeList; } @@ -373,8 +592,14 @@ pub async fn run_with_chat( let agent_c = agent.clone(); let sid = session_id.clone(); let tx = result_tx.clone(); + let cli_ctx = crate::types::TurnContext { + channel: "cli-tui".to_string(), + sender_name: None, + user_id: "local".to_string(), + is_group: false, + }; tokio::spawn(async move { - match agent_c.run_turn(&sid, &input).await { + match agent_c.run_turn(&sid, &input, &cli_ctx, None).await { Ok(resp) => { let _ = tx.send(AgentResult::Response(resp)); } Err(e) => { let _ = tx.send(AgentResult::Error(e.to_string())); } } @@ -413,6 +638,21 @@ pub async fn run_with_chat( AgentResult::Error(e) => { app.chat_messages.push(ChatMsg { role: ChatRole::System, text: format!("Error: {e}") }); } + AgentResult::EditSaved(_nid) => { + app.edit_saving = false; + app.chat_messages.push(ChatMsg { + role: ChatRole::System, + text: format!("Saved: {}", app.edit_node_title), + }); + app.exit_edit_mode(); + } + AgentResult::EditError(e) => { + app.edit_saving = false; + app.chat_messages.push(ChatMsg { + role: ChatRole::System, + text: format!("Edit error: {e}"), + }); + } } // Reload graph to show new nodes/edges if let Ok(nodes) = db.call(|conn| crate::db::queries::get_all_nodes_light(conn)).await { @@ -467,6 +707,7 @@ fn handle_graph_keys_with_chat(app: &mut App, key: crossterm::event::KeyEvent) - 
match key.code { KeyCode::Char('q') | KeyCode::Char('Q') => return true, KeyCode::Char('c') if key.modifiers.contains(KeyModifiers::CONTROL) => return true, + KeyCode::Char('e') => app.enter_edit_mode(), KeyCode::Up | KeyCode::Char('k') => nav_up(app), KeyCode::Down | KeyCode::Char('j') => nav_down(app), KeyCode::PageUp => nav_page_up(app), @@ -476,8 +717,8 @@ fn handle_graph_keys_with_chat(app: &mut App, key: crossterm::event::KeyEvent) - KeyCode::Char('f') => { app.category_idx = (app.category_idx + 1) % ALL_CATEGORIES.len(); app.refilter(); } KeyCode::Char('/') => { app.is_node_search = true; app.search_query.clear(); } KeyCode::Enter => drill_into(app), - KeyCode::Tab => { app.focus = match app.focus { Focus::NodeList => Focus::Detail, Focus::Detail => Focus::Chat, Focus::Chat => Focus::NodeList }; } - KeyCode::BackTab => { app.focus = match app.focus { Focus::NodeList => Focus::Chat, Focus::Detail => Focus::NodeList, Focus::Chat => Focus::Detail }; } + KeyCode::Tab => { app.focus = match app.focus { Focus::NodeList => Focus::Detail, Focus::Detail => Focus::Chat, Focus::Chat => Focus::NodeList, Focus::SoulEdit => Focus::NodeList }; } + KeyCode::BackTab => { app.focus = match app.focus { Focus::NodeList => Focus::Chat, Focus::Detail => Focus::NodeList, Focus::Chat => Focus::Detail, Focus::SoulEdit => Focus::Detail }; } KeyCode::Esc | KeyCode::Backspace => { if !app.search_query.is_empty() { app.search_query.clear(); app.refilter(); } else { app.go_back(); } } KeyCode::Char(c @ '1'..='9') => jump_to_connection(app, c), _ => {} @@ -585,6 +826,11 @@ fn draw(f: &mut ratatui::Frame, app: &mut App, show_chat: bool) { } draw_help(f, app, main_chunks[2], show_chat); + + // Soul edit modal overlay + if app.focus == Focus::SoulEdit { + draw_soul_edit(f, app, size); + } } fn draw_header(f: &mut ratatui::Frame, app: &App, area: Rect) { @@ -839,13 +1085,65 @@ fn draw_chat(f: &mut ratatui::Frame, app: &App, area: Rect) { f.render_widget(input_widget, chat_chunks[1]); } +fn 
draw_soul_edit(f: &mut ratatui::Frame, app: &App, area: Rect) { + // Center the modal — 70% width, 60% height + let modal_w = (area.width as f32 * 0.7) as u16; + let modal_h = (area.height as f32 * 0.6) as u16; + let x = area.x + (area.width.saturating_sub(modal_w)) / 2; + let y = area.y + (area.height.saturating_sub(modal_h)) / 2; + let modal_area = Rect::new(x, y, modal_w, modal_h); + + // Clear the area behind the modal + let clear = Paragraph::new("").style(Style::default().bg(Color::Black)); + f.render_widget(clear, modal_area); + + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Length(3), Constraint::Min(3), Constraint::Length(1)]) + .split(modal_area); + + // Title bar + let title_text = format!(" Editing: {} ", app.edit_node_title); + let title = Paragraph::new(Line::from(Span::styled( + &title_text, + Style::default().fg(Color::White).add_modifier(Modifier::BOLD), + ))) + .block( + Block::default() + .borders(Borders::ALL) + .border_style(Style::default().fg(Color::Magenta)), + ); + f.render_widget(title, chunks[0]); + + // Content area + let status = if app.edit_saving { " (saving…)" } else { "" }; + let cursor = if !app.edit_saving { "█" } else { "" }; + let content_text = format!("{}{}", app.edit_input, cursor); + let content = Paragraph::new(content_text) + .block( + Block::default() + .title(format!(" Content{status} ")) + .borders(Borders::ALL) + .border_style(Style::default().fg(Color::Magenta)), + ) + .wrap(Wrap { trim: false }); + f.render_widget(content, chunks[1]); + + // Help bar + let help = Paragraph::new(Line::from(Span::styled( + " Ctrl+S: save │ Esc: cancel │ Enter: newline", + Style::default().fg(Color::DarkGray), + ))); + f.render_widget(help, chunks[2]); +} + fn draw_help(f: &mut ratatui::Frame, app: &App, area: Rect, show_chat: bool) { let text = if app.is_node_search { " Type to search │ Enter: apply │ Esc: cancel" } else if show_chat { match app.focus { Focus::Chat => " Type + Enter │ 
↑↓/PgUp/PgDn: scroll │ Tab/Esc: graph │ Ctrl+C: quit", - _ => " ↑↓/jk: navigate │ f: filter │ /: search │ Enter: drill │ 1-9: jump │ Tab: cycle │ q: quit", + _ => " ↑↓/jk: navigate │ f: filter │ /: search │ Enter: drill │ e: edit │ 1-9: jump │ Tab: cycle │ q: quit", } } else { " ↑↓/jk: navigate │ Tab: category │ /: search │ Enter: drill │ 1-9: jump │ Esc: back │ q: quit" diff --git a/src/cli/graph_viz.rs b/src/cli/graph_viz.rs index 8f82b5e..f286975 100644 --- a/src/cli/graph_viz.rs +++ b/src/cli/graph_viz.rs @@ -22,8 +22,12 @@ fn kind_color(kind: NodeKind) -> &'static str { | NodeKind::ToolCall | NodeKind::LoopIteration => "\x1b[93m", // Self-model → green NodeKind::Pattern | NodeKind::Limitation | NodeKind::Capability => "\x1b[92m", - // Sub-agents → blue - NodeKind::SubAgent | NodeKind::Delegation | NodeKind::Synthesis => "\x1b[94m", + // Background tasks → blue + NodeKind::BackgroundTask => "\x1b[94m", + // Cron / skills → light blue + NodeKind::CronJob | NodeKind::CronExecution | NodeKind::Skill => "\x1b[94m", + // Notifications → light yellow + NodeKind::Notification => "\x1b[93m", } } @@ -57,7 +61,10 @@ fn kind_category(kind: NodeKind) -> &'static str { NodeKind::Session | NodeKind::Turn | NodeKind::LlmCall | NodeKind::ToolCall | NodeKind::LoopIteration => "Operational", NodeKind::Pattern | NodeKind::Limitation | NodeKind::Capability => "Self-Model", - NodeKind::SubAgent | NodeKind::Delegation | NodeKind::Synthesis => "Sub-Agents", + NodeKind::BackgroundTask => "Tasks", + NodeKind::CronJob | NodeKind::CronExecution => "Scheduler", + NodeKind::Skill => "Skills", + NodeKind::Notification => "Notifications", } } diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 12b6a26..bc4242b 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -6,10 +6,10 @@ mod graph_viz; mod graph_tui; #[derive(Parser)] -#[command(name = "cede", about = "A forkable self-aware agent with graph memory")] +#[command(name = "omni-cede", about = "Omnichannel self-aware agent with graph 
memory")] pub struct Cli { /// Path to the SQLite database file. - #[arg(long, default_value = "cede.db")] + #[arg(long, default_value = "omni-cede.db")] pub db: String, /// Use Ollama as the LLM backend (format: model@url, e.g. llama3@http://localhost:11434) @@ -60,6 +60,16 @@ pub enum Commands { /// Check graph health Doctor, + /// Start the HTTP API server + Serve { + /// Host to bind to + #[arg(long, default_value = "0.0.0.0")] + host: String, + /// Port to listen on + #[arg(long, default_value = "3000")] + port: u16, + }, + /// Pre-download the embedding model and initialize DB Init, } @@ -109,6 +119,168 @@ pub async fn run() -> crate::error::Result<()> { let cx = crate::CortexEmbedded::open(&cli.db).await?; match cli.command { + Commands::Serve { host, port } => { + let llm = build_llm_client(&ollama_spec)?; + cx.set_llm(llm.clone()).await; + let agent = crate::agent::orchestrator::Agent { + db: cx.db.clone(), + embed: cx.embed.clone(), + hnsw: cx.hnsw.clone(), + config: cx.config.clone(), + llm: llm.clone(), + tools: crate::tools::builtin_registry( + cx.db.clone(), + cx.embed.clone(), + cx.hnsw.clone(), + cx.auto_link_tx.clone(), + Some(llm), + cx.config.clone(), + ).await, + auto_link_tx: cx.auto_link_tx.clone(), + }; + + let api_key = std::env::var("API_KEY").ok(); + + // ── Build the omnichannel pipeline + registry ── + let mut registry = crate::channels::ChannelRegistry::new(1024); + + // Register built-in passive channels (always available) + registry.register(std::sync::Arc::new( + crate::channels::webhook::WebhookChannel::new(), + )).await; + let webchat = crate::channels::webchat::WebChatChannel::new(); + let webchat_counter = webchat.connection_counter(); + let webchat_max = webchat.max_connections(); + registry.register(std::sync::Arc::new(webchat)).await; + + // Register active channels only if their tokens are configured + if std::env::var("TELEGRAM_BOT_TOKEN").is_ok() { + registry.register(std::sync::Arc::new( + 
crate::channels::telegram::TelegramChannel::new(), + )).await; + } + if std::env::var("DISCORD_BOT_TOKEN").is_ok() { + registry.register(std::sync::Arc::new( + crate::channels::discord::DiscordChannel::new(), + )).await; + } + + // Take inbound receiver before moving registry into Arc + let inbound_rx = registry.take_inbound_rx(); + + let registry = std::sync::Arc::new(registry); + let pipeline = std::sync::Arc::new( + crate::channels::Pipeline::new(std::sync::Arc::clone(&registry)), + ); + + let state = std::sync::Arc::new(crate::api::AppState { + cx, + agent, + api_key, + pipeline, + registry, + webchat_counter, + webchat_max, + }); + + let app = crate::api::router(state.clone()); + let addr = format!("{host}:{port}"); + println!("omni-cede API server listening on {addr}"); + if std::env::var("API_KEY").is_err() { + println!(" WARNING: API_KEY not set — auth disabled (dev mode)"); + } + + // ── Start channel adapters (Telegram polling, etc.) ── + { + let mut channel_configs = std::collections::HashMap::new(); + + // Telegram: auto-enable if TELEGRAM_BOT_TOKEN is set + if std::env::var("TELEGRAM_BOT_TOKEN").is_ok() { + channel_configs.insert( + "telegram".to_string(), + serde_json::json!({ "mode": "polling" }), + ); + println!(" Telegram: enabled (polling mode)"); + } + + // Discord: auto-enable if DISCORD_BOT_TOKEN is set + if std::env::var("DISCORD_BOT_TOKEN").is_ok() { + channel_configs.insert( + "discord".to_string(), + serde_json::json!({}), + ); + println!(" Discord: enabled"); + } + + // Webhook + WebChat are always available (passive channels) + channel_configs.insert( + "webhook".to_string(), + serde_json::json!({}), + ); + channel_configs.insert( + "webchat".to_string(), + serde_json::json!({}), + ); + + if let Err(e) = state.registry.start_all(&state.cx.db, &channel_configs).await { + eprintln!(" WARNING: Channel start error: {e}"); + eprintln!(" (server will still handle /v1/message requests)"); + } + } + + // ── Start the pipeline inbound loop (processes 
messages from channels) ── + if let Some(rx) = inbound_rx { + let pipeline_clone = std::sync::Arc::clone(&state.pipeline); + let db_clone = state.cx.db.clone(); + let agent_clone = std::sync::Arc::new(crate::agent::orchestrator::Agent { + db: state.cx.db.clone(), + embed: state.cx.embed.clone(), + hnsw: state.cx.hnsw.clone(), + config: state.cx.config.clone(), + llm: state.agent.llm.clone(), + tools: crate::tools::builtin_registry( + state.cx.db.clone(), + state.cx.embed.clone(), + state.cx.hnsw.clone(), + state.cx.auto_link_tx.clone(), + Some(state.agent.llm.clone()), + state.cx.config.clone(), + ).await, + auto_link_tx: state.cx.auto_link_tx.clone(), + }); + tokio::spawn(async move { + pipeline_clone.run_inbound_loop(rx, db_clone, agent_clone).await; + }); + println!(" Pipeline inbound loop: started"); + + // ── Start proactive notification delivery loop ── + { + let notif_pipeline = std::sync::Arc::clone(&state.pipeline); + let notif_db = state.cx.db.clone(); + let notif_llm = state.agent.llm.clone(); + let notif_shutdown = state.cx.shutdown_rx(); + tokio::spawn(async move { + crate::notification_delivery::run( + notif_db, + notif_pipeline, + notif_llm, + notif_shutdown, + 10, // check every 10 seconds + ) + .await; + }); + println!(" Notification delivery loop: started"); + } + } + + let listener = tokio::net::TcpListener::bind(&addr) + .await + .map_err(|e| crate::error::CortexError::Config(format!("bind failed: {e}")))?; + axum::serve(listener, app) + .await + .map_err(|e| crate::error::CortexError::Config(format!("server error: {e}")))?; + Ok(()) + } Commands::Init => { println!("Database initialized at: {}", cli.db); println!("Embedding model ready."); @@ -191,7 +363,17 @@ pub async fn run() -> crate::error::Result<()> { Ok(()) } SoulAction::Edit => { - println!("Soul editing not yet implemented. Use `cede memory show ` to inspect."); + // Launch the graph TUI filtered to Identity nodes with edit mode. + // Select an identity node and press 'e' to edit. 
+ // Identity category is index 1 in ALL_CATEGORIES + graph_tui::run_with_edit( + cx.db.clone(), + cx.embed.clone(), + cx.hnsw.clone(), + 1, // "Identity" category + ) + .await + .map_err(|e| crate::error::CortexError::Config(format!("TUI error: {e}")))?; Ok(()) } }, @@ -244,7 +426,7 @@ pub async fn run() -> crate::error::Result<()> { cx.auto_link_tx.clone(), Some(llm), cx.config.clone(), - ), + ).await, auto_link_tx: cx.auto_link_tx.clone(), }; @@ -258,7 +440,13 @@ pub async fn run() -> crate::error::Result<()> { }) .await?; - graph_tui::run_with_chat(cx.db.clone(), agent, session_id) + graph_tui::run_with_chat( + cx.db.clone(), + agent, + session_id, + Some(cx.embed.clone()), + Some(cx.hnsw.clone()), + ) .await .map_err(|e| crate::error::CortexError::Config(format!("TUI error: {e}")))?; } @@ -451,7 +639,7 @@ pub async fn run() -> crate::error::Result<()> { cx.auto_link_tx.clone(), Some(llm), cx.config.clone(), - ), + ).await, auto_link_tx: cx.auto_link_tx.clone(), }; @@ -465,7 +653,7 @@ pub async fn run() -> crate::error::Result<()> { }) .await?; - println!("cede chat — type 'exit' or Ctrl+C to quit\n"); + println!("omni-cede chat — type 'exit' or Ctrl+C to quit\n"); let stdin = io::stdin(); loop { print!("> "); @@ -478,7 +666,13 @@ pub async fn run() -> crate::error::Result<()> { if input == "exit" || input == "quit" { break; } - match agent.run_turn(&session_id, input).await { + let cli_ctx = crate::types::TurnContext { + channel: "cli".to_string(), + sender_name: None, + user_id: "local".to_string(), + is_group: false, + }; + match agent.run_turn(&session_id, input, &cli_ctx, None).await { Ok(response) => println!("\n{response}\n"), Err(e) => eprintln!("\nError: {e}\n"), } @@ -502,7 +696,7 @@ pub async fn run() -> crate::error::Result<()> { cx.auto_link_tx.clone(), Some(llm), cx.config.clone(), - ), + ).await, auto_link_tx: cx.auto_link_tx.clone(), }; diff --git a/src/config.rs b/src/config.rs index f03133c..ed886e7 100644 --- a/src/config.rs +++ 
b/src/config.rs @@ -26,6 +26,16 @@ pub struct Config { /// Number of most-recent session nodes (UserInput + Fact) always included /// in a chat turn's briefing, regardless of semantic similarity. pub session_recency_window: usize, + /// Enable the bash/shell execution tool. + pub bash_enabled: bool, + /// Maximum seconds a bash command can run before being killed. + pub bash_timeout_secs: u64, + /// Maximum bytes of command output returned to the LLM. + pub bash_max_output_bytes: usize, + /// Shell command prefixes that are always blocked (case-insensitive substring match). + pub bash_blocked_patterns: Vec, + /// Maximum nodes included in the semantic briefing for a chat turn. + pub briefing_max_nodes: usize, } impl Default for Config { @@ -43,6 +53,21 @@ impl Default for Config { decay_lambda: 0.01, auto_link_candidates: 20, session_recency_window: 7, + bash_enabled: true, + bash_timeout_secs: 30, + bash_max_output_bytes: 10_000, + bash_blocked_patterns: vec![ + "rm -rf /".into(), + "mkfs".into(), + "dd if=".into(), + ":(){:|:&};:".into(), + "shutdown".into(), + "reboot".into(), + "halt".into(), + "init 0".into(), + "init 6".into(), + ], + briefing_max_nodes: 16, } } } diff --git a/src/db/queries.rs b/src/db/queries.rs index 6e743a7..d98f0e8 100644 --- a/src/db/queries.rs +++ b/src/db/queries.rs @@ -718,3 +718,77 @@ fn blob_to_embedding(blob: &[u8]) -> Vec { } bytemuck::cast_slice::(blob).to_vec() } + +// ─── Notification nodes (graph-native) ────────────────── + +/// Return all sessions that have at least one undelivered notification. +/// +/// Each entry is `(user_id, channel, session_node_id, Vec)`. +/// Used by the proactive notification delivery loop. 
+pub fn get_sessions_with_pending_notifications( + conn: &Connection, +) -> Result)>> { + // First, find all (session_id, user_id, channel) tuples that have pending notifications + let mut session_stmt = conn.prepare( + "SELECT DISTINCT ms.node_id, ms.user_id, ms.channel + FROM managed_sessions ms + JOIN edges e ON e.dst = ms.node_id AND e.kind = 'part_of' + JOIN nodes n ON n.id = e.src + WHERE n.kind = 'notification' AND n.access_count = 0", + )?; + let sessions: Vec<(String, String, String)> = session_stmt + .query_map([], |row| { + Ok((row.get(0)?, row.get(1)?, row.get(2)?)) + })? + .filter_map(|r| r.ok()) + .collect(); + + let mut result = Vec::new(); + for (session_id, user_id, channel) in sessions { + let nodes = get_pending_notification_nodes(conn, &session_id)?; + if !nodes.is_empty() { + result.push((user_id, channel, session_id, nodes)); + } + } + Ok(result) +} + +/// Fetch all undelivered Notification nodes linked to a session, oldest first. +/// A notification is "undelivered" when access_count == 0. 
+pub fn get_pending_notification_nodes( + conn: &Connection, + session_id: &str, +) -> Result> { + let mut stmt = conn.prepare( + "SELECT n.id, n.kind, n.title, n.body, n.importance, n.trust_score, + n.access_count, n.created_at, n.last_access, n.decay_rate + FROM nodes n + JOIN edges e ON e.src = n.id + WHERE n.kind = 'notification' + AND n.access_count = 0 + AND e.dst = ?1 + AND e.kind = 'part_of' + ORDER BY n.created_at ASC", + )?; + let rows = stmt.query_map(params![session_id], |row| { + let kind_str: String = row.get(1)?; + Ok(Node { + id: row.get(0)?, + kind: NodeKind::from_str_opt(&kind_str).unwrap_or(NodeKind::Fact), + title: row.get(2)?, + body: row.get(3)?, + importance: row.get(4)?, + trust_score: row.get(5)?, + access_count: row.get(6)?, + created_at: row.get(7)?, + last_access: row.get(8)?, + decay_rate: row.get(9)?, + embedding: None, + }) + })?; + let mut result = Vec::new(); + for r in rows { + result.push(r?); + } + Ok(result) +} diff --git a/src/error.rs b/src/error.rs index b0cab0b..4b9b3af 100644 --- a/src/error.rs +++ b/src/error.rs @@ -34,6 +34,15 @@ pub enum CortexError { #[error("Not found: {0}")] NotFound(String), + + #[error("Channel error: {0}")] + Channel(String), + + #[error("Unsupported: {0}")] + Unsupported(String), + + #[error("Pipeline error: {0}")] + Pipeline(String), } pub type Result = std::result::Result; diff --git a/src/identity/mod.rs b/src/identity/mod.rs new file mode 100644 index 0000000..086882f --- /dev/null +++ b/src/identity/mod.rs @@ -0,0 +1,188 @@ +//! Identity layer — maps external channel identifiers to internal user IDs. +//! +//! A single human can interact via multiple channels (WhatsApp, Telegram, REST +//! API, CLI). Each channel has its own external identifier format. The identity +//! layer resolves all of them to a single internal `UserId`. +//! +//! Storage: a dedicated SQLite table (`identities`) alongside the graph DB. 
+ +use rusqlite::{params, Connection, OptionalExtension}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::db::Db; +use crate::error::Result; + +/// A unique internal user identifier. +pub type UserId = String; + +/// A channel identifier — the external handle for a user on a specific platform. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChannelId { + /// e.g. "whatsapp", "telegram", "api", "cli" + pub channel: String, + /// e.g. "+447123456789", "12345678", "api-key-hash", "local" + pub external_id: String, +} + +impl ChannelId { + pub fn new(channel: &str, external_id: &str) -> Self { + Self { + channel: channel.to_string(), + external_id: external_id.to_string(), + } + } + + /// Canonical string form: "channel:external_id" + pub fn canonical(&self) -> String { + format!("{}:{}", self.channel, self.external_id) + } +} + +/// A user record. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct User { + pub id: UserId, + pub display_name: Option, + pub created_at: i64, +} + +/// Create the identities table if it doesn't exist. +pub fn create_tables(conn: &Connection) -> std::result::Result<(), rusqlite::Error> { + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + display_name TEXT, + created_at INTEGER NOT NULL + ); + + CREATE TABLE IF NOT EXISTS channel_mappings ( + channel TEXT NOT NULL, + external_id TEXT NOT NULL, + user_id TEXT NOT NULL REFERENCES users(id), + created_at INTEGER NOT NULL, + PRIMARY KEY (channel, external_id) + ); + + CREATE INDEX IF NOT EXISTS idx_channel_mappings_user + ON channel_mappings(user_id);", + )?; + Ok(()) +} + +/// Look up a user by their channel identifier, or create a new one. 
+pub fn resolve_or_create( + conn: &Connection, + channel_id: &ChannelId, +) -> std::result::Result { + // Try to find existing mapping + let existing: Option = conn + .query_row( + "SELECT user_id FROM channel_mappings WHERE channel = ?1 AND external_id = ?2", + params![channel_id.channel, channel_id.external_id], + |row| row.get(0), + ) + .optional()?; + + if let Some(user_id) = existing { + let user = conn.query_row( + "SELECT id, display_name, created_at FROM users WHERE id = ?1", + params![user_id], + |row| { + Ok(User { + id: row.get(0)?, + display_name: row.get(1)?, + created_at: row.get(2)?, + }) + }, + )?; + return Ok(user); + } + + // Create new user + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + let user_id = Uuid::new_v4().to_string(); + + conn.execute( + "INSERT INTO users (id, display_name, created_at) VALUES (?1, ?2, ?3)", + params![user_id, Option::::None, now], + )?; + + conn.execute( + "INSERT INTO channel_mappings (channel, external_id, user_id, created_at) VALUES (?1, ?2, ?3, ?4)", + params![channel_id.channel, channel_id.external_id, user_id, now], + )?; + + Ok(User { + id: user_id, + display_name: None, + created_at: now, + }) +} + +/// Link an additional channel identifier to an existing user. +pub fn link_channel( + conn: &Connection, + user_id: &str, + channel_id: &ChannelId, +) -> std::result::Result<(), rusqlite::Error> { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + conn.execute( + "INSERT OR IGNORE INTO channel_mappings (channel, external_id, user_id, created_at) VALUES (?1, ?2, ?3, ?4)", + params![channel_id.channel, channel_id.external_id, user_id, now], + )?; + Ok(()) +} + +/// Look up the external_id for a user on a specific channel. +/// +/// Returns `None` if no mapping exists for that (user, channel) pair. 
+pub fn get_external_id( + conn: &Connection, + user_id: &str, + channel: &str, +) -> std::result::Result, rusqlite::Error> { + conn.query_row( + "SELECT external_id FROM channel_mappings WHERE user_id = ?1 AND channel = ?2", + params![user_id, channel], + |row| row.get(0), + ) + .optional() +} + +/// List all channel identifiers for a user. +pub fn list_channels( + conn: &Connection, + user_id: &str, +) -> std::result::Result, rusqlite::Error> { + let mut stmt = conn.prepare( + "SELECT channel, external_id FROM channel_mappings WHERE user_id = ?1", + )?; + let rows = stmt.query_map(params![user_id], |row| { + Ok(ChannelId { + channel: row.get(0)?, + external_id: row.get(1)?, + }) + })?; + let mut result = Vec::new(); + for r in rows { + result.push(r?); + } + Ok(result) +} + +/// Async wrapper: resolve or create a user from a channel identifier. +pub async fn resolve_user(db: &Db, channel_id: ChannelId) -> Result { + db.call(move |conn| { + create_tables(conn)?; + resolve_or_create(conn, &channel_id).map_err(Into::into) + }) + .await + .map_err(|e| crate::error::CortexError::DbTask(e.to_string())) +} diff --git a/src/lib.rs b/src/lib.rs index 7e60718..8bdd2d5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,14 @@ pub mod tools; pub mod llm; pub mod agent; pub mod cli; +pub mod api; +pub mod identity; +pub mod session; +pub mod channels; +pub mod scheduler; +pub mod notification_delivery; +#[cfg(feature = "browser")] +pub mod browser; use std::collections::HashMap; use std::sync::Arc; @@ -102,6 +110,12 @@ impl CortexEmbedded { *guard = Some(client); } + /// Get a new shutdown receiver. Each receiver is independent — + /// used by components that need to know when to stop. + pub fn shutdown_rx(&self) -> tokio::sync::watch::Receiver { + self.shutdown_tx.subscribe() + } + // ─── Core memory ──────────────────────────────────── /// Store a node in the graph. 
Embeds its text, writes to SQLite, @@ -232,7 +246,7 @@ impl CortexEmbedded { fn start_background_tasks( &self, auto_link_rx: async_channel::Receiver, - mut shutdown_rx: tokio::sync::watch::Receiver, + shutdown_rx: tokio::sync::watch::Receiver, ) { // Auto-link task let db = self.db.clone(); @@ -256,6 +270,7 @@ impl CortexEmbedded { let db = self.db.clone(); let interval = std::time::Duration::from_secs(self.config.decay_interval_secs); let decay_interval_secs = self.config.decay_interval_secs; + let mut shutdown_rx3 = shutdown_rx.clone(); tokio::spawn(async move { let mut ticker = tokio::time::interval(interval); @@ -265,10 +280,34 @@ impl CortexEmbedded { _ = ticker.tick() => { let _ = run_decay(&db, decay_interval_secs).await; } - _ = shutdown_rx.changed() => break, + _ = shutdown_rx3.changed() => break, } } }); + + // Cron scheduler task + { + let db = self.db.clone(); + let embed = self.embed.clone(); + let hnsw = self.hnsw.clone(); + let auto_link_tx = self.auto_link_tx.clone(); + let llm = self.llm.clone(); + let config = self.config.clone(); + + tokio::spawn(async move { + scheduler::run( + db, + embed, + hnsw, + auto_link_tx, + llm, + config, + shutdown_rx, + 30, // check every 30 seconds + ) + .await; + }); + } } } diff --git a/src/llm/mod.rs b/src/llm/mod.rs index b9cd09d..7284941 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -244,12 +244,9 @@ impl OllamaClient { model, } } -} -#[async_trait::async_trait] -impl LlmClient for OllamaClient { - async fn complete(&self, messages: &[Message]) -> Result { - let msgs: Vec = messages + fn build_messages(messages: &[Message]) -> Vec { + messages .iter() .map(|m| { serde_json::json!({ @@ -257,19 +254,15 @@ impl LlmClient for OllamaClient { Role::System => "system", Role::User => "user", Role::Assistant => "assistant", - Role::Tool => "user", + Role::Tool => "tool", }, "content": m.content, }) }) - .collect(); - - let body = serde_json::json!({ - "model": self.model, - "messages": msgs, - "stream": false, - 
}); + .collect() + } + async fn do_request(&self, body: serde_json::Value) -> Result { let resp = self .client .post(format!("{}/api/chat", self.url)) @@ -288,18 +281,82 @@ impl LlmClient for OllamaClient { .unwrap_or("") .to_string(); + // Parse tool calls from Ollama response + let mut tool_calls = Vec::new(); + let mut tool_name = None; + let mut tool_input = None; + let mut tool_use_id = None; + + if let Some(calls) = json["message"]["tool_calls"].as_array() { + for (i, call) in calls.iter().enumerate() { + let name = call["function"]["name"] + .as_str() + .unwrap_or("") + .to_string(); + let arguments = call["function"]["arguments"].clone(); + let id = format!("ollama_tc_{i}"); + + if tool_name.is_none() { + tool_name = Some(name.clone()); + tool_input = Some(arguments.clone()); + tool_use_id = Some(id.clone()); + } + tool_calls.push(ToolCall { + id, + name, + input: arguments, + }); + } + } + + let stop_reason = if tool_calls.is_empty() { + StopReason::EndTurn + } else { + StopReason::ToolUse + }; + Ok(LlmResponse { text, - stop_reason: StopReason::EndTurn, - tool_name: None, - tool_input: None, - tool_use_id: None, - tool_calls: Vec::new(), + stop_reason, + tool_name, + tool_input, + tool_use_id, + tool_calls, raw_content: None, input_tokens: 0, output_tokens: 0, }) } +} + +#[async_trait::async_trait] +impl LlmClient for OllamaClient { + async fn complete(&self, messages: &[Message]) -> Result { + let msgs = Self::build_messages(messages); + let body = serde_json::json!({ + "model": self.model, + "messages": msgs, + "stream": false, + }); + self.do_request(body).await + } + + async fn complete_with_tools( + &self, + messages: &[Message], + tools: &[serde_json::Value], + ) -> Result { + let msgs = Self::build_messages(messages); + let mut body = serde_json::json!({ + "model": self.model, + "messages": msgs, + "stream": false, + }); + if !tools.is_empty() { + body["tools"] = serde_json::Value::Array(tools.to_vec()); + } + self.do_request(body).await + } fn 
model_name(&self) -> &str { &self.model diff --git a/src/memory/mod.rs b/src/memory/mod.rs index 83ad8d4..9afa489 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -12,6 +12,56 @@ use crate::types::*; use std::sync::Arc; use tokio::sync::RwLock; +/// Format a unix timestamp as a human-readable datetime string. +pub fn format_timestamp(unix: i64) -> String { + use std::time::{Duration, UNIX_EPOCH}; + let dt = UNIX_EPOCH + Duration::from_secs(unix as u64); + // Format as ISO-like: YYYY-MM-DD HH:MM:SS UTC + let secs_since_epoch = dt.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs(); + let days = secs_since_epoch / 86400; + let time_of_day = secs_since_epoch % 86400; + let hours = time_of_day / 3600; + let minutes = (time_of_day % 3600) / 60; + let seconds = time_of_day % 60; + // Simple date calculation from days since epoch + let (year, month, day) = days_to_ymd(days); + format!("{year:04}-{month:02}-{day:02} {hours:02}:{minutes:02}:{seconds:02} UTC") +} + +/// Convert days since Unix epoch to (year, month, day). +fn days_to_ymd(days_since_epoch: u64) -> (u64, u64, u64) { + // Algorithm from http://howardhinnant.github.io/date_algorithms.html + let z = days_since_epoch + 719468; + let era = z / 146097; + let doe = z - era * 146097; + let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; + let y = yoe + era * 400; + let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); + let mp = (5 * doy + 2) / 153; + let d = doy - (153 * mp + 2) / 5 + 1; + let m = if mp < 10 { mp + 3 } else { mp - 9 }; + let y = if m <= 2 { y + 1 } else { y }; + (y, m, d) +} + +/// Format a relative time description (e.g., "2 hours ago", "3 days ago"). 
+pub fn relative_time(unix: i64) -> String { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + let diff = now - unix; + if diff < 60 { + "just now".to_string() + } else if diff < 3600 { + format!("{} min ago", diff / 60) + } else if diff < 86400 { + format!("{} hours ago", diff / 3600) + } else { + format!("{} days ago", diff / 86400) + } +} + // ─── recall ───────────────────────────────────────────── /// Hybrid semantic + graph search. @@ -215,9 +265,19 @@ pub async fn briefing_with_kinds( } /// Render the briefing as a markdown document for the LLM system prompt. +/// +/// Written in natural language — no raw IDs, numeric scores, or kind tags +/// so the agent's responses stay human-friendly. fn format_context_doc(nodes: &[ScoredNode], contradictions: &[ContradictionPair]) -> String { let mut doc = String::new(); + // Current time header — so the agent always knows what time it is + let now_unix = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + doc.push_str(&format!("## Current time\n{}\n\n", format_timestamp(now_unix))); + // Who you are let identity: Vec<&ScoredNode> = nodes .iter() @@ -227,12 +287,26 @@ fn format_context_doc(nodes: &[ScoredNode], contradictions: &[ContradictionPair] doc.push_str("## Who you are\n"); for s in &identity { let body = s.node.body.as_deref().unwrap_or(""); + let rel = relative_time(s.node.created_at); doc.push_str(&format!( - "- **[{}]** {}: {}\n", - s.node.kind, s.node.title, body + "- **{}**: {} _(remembered {})_\n", + s.node.title, body, rel )); } doc.push('\n'); + } else { + // Bootstrap prompt — no identity exists yet. + doc.push_str("## First contact\n"); + doc.push_str("You have no memory yet — this is a blank slate.\n\n"); + doc.push_str("Start by finding out who you're talking to: ask their name, what they need from you, and what role they want you to play. 
"); + doc.push_str("Let the conversation shape who you become.\n\n"); + doc.push_str("You have a `remember` tool that stores things permanently in your memory. "); + doc.push_str("As you learn about yourself and the people you talk to, use it to build your own identity:\n"); + doc.push_str("- Your name and nature\n"); + doc.push_str("- Values and principles you adopt\n"); + doc.push_str("- What you're working towards\n"); + doc.push_str("- Things you learn about the world and people\n\n"); + doc.push_str("Don't invent a persona. Let it emerge from what you're told and what you observe.\n\n"); } // What you know @@ -255,14 +329,15 @@ fn format_context_doc(nodes: &[ScoredNode], contradictions: &[ContradictionPair] doc.push_str("## What you know\n"); for s in &knowledge { let body = s.node.body.as_deref().unwrap_or(""); - let trust_flag = if s.node.trust_score < 0.5 { - " ⚠ LOW TRUST" + let rel = relative_time(s.node.created_at); + let confidence = if s.node.trust_score < 0.5 { + " *(uncertain — may need verification)*" } else { "" }; doc.push_str(&format!( - "- **[{}]** {} (trust: {:.2}, score: {:.3}){}\n {}\n", - s.node.kind, s.node.title, s.node.trust_score, s.score, trust_flag, body + "- **{}**{} _(remembered {})_\n {}\n", + s.node.title, confidence, rel, body )); } doc.push('\n'); @@ -277,39 +352,43 @@ fn format_context_doc(nodes: &[ScoredNode], contradictions: &[ContradictionPair] doc.push_str("## Recent conversation\n"); for s in &conversation { let body = s.node.body.as_deref().unwrap_or(&s.node.title); + let rel = relative_time(s.node.created_at); doc.push_str(&format!( - "- User said (score: {:.3}): {}\n", - s.score, body + "- ({}) User said: {}\n", + rel, body )); } doc.push('\n'); } - // Active contradictions + // Active contradictions — described by title, not raw IDs if !contradictions.is_empty() { - doc.push_str("## Active contradictions\n"); + doc.push_str("## Conflicting information\n"); + doc.push_str("You have memories that contradict each other. 
Consider asking the user to clarify:\n"); for c in contradictions { + // We show the short IDs as a fallback but they'll be overridden + // when the caller has node titles available. For now, keep it + // minimally technical. + let a_label = &c.node_a[..8.min(c.node_a.len())]; + let b_label = &c.node_b[..8.min(c.node_b.len())]; doc.push_str(&format!( - "- CONFLICT: node {} ↔ node {} (unresolved)\n", - &c.node_a[..8.min(c.node_a.len())], - &c.node_b[..8.min(c.node_b.len())], + "- Memory {} conflicts with memory {} (unresolved)\n", + a_label, b_label, )); } doc.push('\n'); } - // What to verify + // What to verify — items with low trust let stale_or_untrusted: Vec<&ScoredNode> = nodes .iter() .filter(|s| s.node.trust_score < 0.5) .collect(); if !stale_or_untrusted.is_empty() { - doc.push_str("## What to verify\n"); + doc.push_str("## Needs verification\n"); + doc.push_str("These memories may be outdated or unreliable:\n"); for s in &stale_or_untrusted { - doc.push_str(&format!( - "- {} (trust: {:.2})\n", - s.node.title, s.node.trust_score - )); + doc.push_str(&format!("- {}\n", s.node.title)); } doc.push('\n'); } diff --git a/src/notification_delivery.rs b/src/notification_delivery.rs new file mode 100644 index 0000000..fc4d7a7 --- /dev/null +++ b/src/notification_delivery.rs @@ -0,0 +1,308 @@ +//! Proactive notification delivery — timer-based background loop. +//! +//! When background tool execution completes, the result is stored as a +//! `Notification` graph node linked to the user's session. This module runs +//! a periodic loop that detects undelivered notifications, formulates a +//! natural message via a brief LLM call, and pushes it proactively to the +//! user's channel — so the user doesn't have to send another message to +//! see the results. +//! +//! # Flow +//! +//! ```text +//! tick (every N seconds) +//! → query sessions with pending notifications +//! → for each: resolve outbound routing (channel + external_id) +//! 
→ brief LLM call to formulate a natural update message +//! → Pipeline::send_outbound() to push it to the user +//! → touch_nodes() to mark notifications as delivered +//! ``` + +use std::sync::Arc; + +use crate::channels::Pipeline; +use crate::channels::types::*; +use crate::db::Db; +use crate::db::queries; +use crate::identity; +use crate::llm::LlmClient; +use crate::memory; +use crate::types::*; + +/// Run the notification delivery loop until shutdown is signalled. +/// +/// Checks for pending notification nodes every `interval_secs` seconds. +/// When found, resolves outbound routing, runs a brief LLM call to +/// produce a natural message, and sends it via the pipeline. +pub async fn run( + db: Db, + pipeline: Arc, + llm: Arc, + mut shutdown_rx: tokio::sync::watch::Receiver, + interval_secs: u64, +) { + let interval = std::time::Duration::from_secs(interval_secs); + let mut ticker = tokio::time::interval(interval); + ticker.tick().await; // skip the first immediate tick + + tracing::info!(interval_secs, "notification delivery loop started"); + + loop { + tokio::select! { + _ = ticker.tick() => { + if let Err(e) = deliver_pending(&db, &pipeline, &llm).await { + tracing::warn!(error = %e, "notification delivery tick failed"); + } + } + _ = shutdown_rx.changed() => { + tracing::info!("notification delivery loop shutting down"); + break; + } + } + } +} + +/// One tick: find all sessions with pending notifications and deliver them. +async fn deliver_pending( + db: &Db, + pipeline: &Arc, + llm: &Arc, +) -> crate::error::Result<()> { + // Query all sessions that have undelivered notification nodes. 
+ let sessions_with_notifs = db + .call(|conn| queries::get_sessions_with_pending_notifications(conn)) + .await?; + + if sessions_with_notifs.is_empty() { + return Ok(()); + } + + for (user_id, channel, session_id, notifications) in sessions_with_notifs { + if let Err(e) = deliver_for_session( + db, pipeline, llm, + &user_id, &channel, &session_id, ¬ifications, + ).await { + tracing::warn!( + user_id = %user_id, + channel = %channel, + session_id = %session_id, + error = %e, + "failed to deliver notifications for session" + ); + } + } + + Ok(()) +} + +/// Deliver all pending notifications for a single session. +async fn deliver_for_session( + db: &Db, + pipeline: &Arc, + llm: &Arc, + user_id: &str, + channel: &str, + session_id: &str, + notifications: &[Node], +) -> crate::error::Result<()> { + // 1. Resolve outbound routing: channel + external_id + let uid = user_id.to_string(); + let ch = channel.to_string(); + let external_id = db + .call(move |conn| { + identity::create_tables(conn)?; + Ok(identity::get_external_id(conn, &uid, &ch)?) + }) + .await?; + + let external_id = match external_id { + Some(eid) => eid, + None => { + tracing::warn!( + user_id = %user_id, + channel = %channel, + "no external_id found — cannot deliver proactive notification" + ); + return Ok(()); + } + }; + + let target = OutboundTarget { + channel: channel.to_string(), + external_id, + group_id: None, // proactive notifications go to DMs + reply_to_message_id: None, + callback_url: None, + }; + + // 2. Build a brief prompt with the notification summaries + persona + // Pull Soul + Belief nodes so the LLM reply stays in character. 
+ let persona = db + .call(|conn| { + let mut parts = Vec::new(); + let souls = queries::get_nodes_by_kind(conn, NodeKind::Soul)?; + for n in &souls { + if let Some(ref b) = n.body { + parts.push(b.clone()); + } + } + let beliefs = queries::get_nodes_by_kind(conn, NodeKind::Belief)?; + for n in &beliefs { + if let Some(ref b) = n.body { + parts.push(format!("Belief: {}", b)); + } + } + Ok(parts.join("\n")) + }) + .await + .unwrap_or_default(); + + let mut notification_block = String::new(); + let mut delivered_ids: Vec = Vec::new(); + for node in notifications { + let rel = memory::relative_time(node.created_at); + let body = node.body.as_deref().unwrap_or(&node.title); + notification_block.push_str(&format!("- ({}) {}\n", rel, body)); + delivered_ids.push(node.id.clone()); + } + + let persona_section = if persona.is_empty() { + String::new() + } else { + format!("## Your identity\n{}\n\n", persona) + }; + + // Pull the full body of linked BackgroundTask nodes for richer context. + let bg_bodies = fetch_background_task_bodies(db, notifications).await; + let bg_context = if bg_bodies.is_empty() { + String::new() + } else { + format!("\n## Full background task results\n{}\n", bg_bodies.join("\n---\n")) + }; + + // Pull recent conversation so the notification LLM knows what was + // already said and can skip redundant updates or blend naturally. 
+ let sid = session_id.to_string(); + let recent_nodes = db + .call(move |conn| queries::get_recent_session_nodes(conn, &sid, 10)) + .await + .unwrap_or_default(); + + let mut conversation_block = String::new(); + if !recent_nodes.is_empty() { + conversation_block.push_str("## Recent conversation (what was already said)\n"); + for node in recent_nodes.iter().rev() { + let label = match node.kind { + NodeKind::UserInput => "User", + NodeKind::ToolCall => "Tool", + NodeKind::BackgroundTask => "Background", + _ => "Assistant", + }; + let body = node.body.as_deref().unwrap_or(&node.title); + let rel = memory::relative_time(node.created_at); + conversation_block.push_str(&format!("- ({rel}) {label}: {body}\n")); + } + conversation_block.push('\n'); + } + + let system_prompt = format!( + "{persona_section}\ + {conversation_block}\ + You are following up on background work you kicked off earlier. \ + The following tasks have completed:\n\n{notification_block}\n\ + {bg_context}\n\ + Write a brief, natural message to let the user know what happened. \ + Be conversational and concise — this is a proactive update, not a \ + formal report. If something failed, mention it clearly but calmly. \ + Do NOT say \"notification\" or refer to yourself as a system. \ + Stay in character.\n\n\ + CRITICAL RULES:\n\ + 1. Read the recent conversation above carefully. If the user ALREADY \ + knows about this result (because you discussed it, acknowledged it, \ + or the topic was covered), respond with exactly [SKIP].\n\ + 2. Do NOT repeat, paraphrase, or re-announce anything already said.\n\ + 3. If sending a message, it must contain NEW information the user \ + hasn't seen yet. Blend naturally into the ongoing conversation.\n\ + 4. If the task results are vague, empty, or contain no concrete \ + information worth sharing, respond with exactly [SKIP].\n\ + 5. 
Match the tone and energy of the recent conversation.", + ); + + let messages = vec![ + Message::system(system_prompt), + Message::user("What's the update?"), + ]; + + // 3. Brief LLM call to formulate the message + let response = llm.complete(&messages).await?; + let reply_text = response.text.trim().to_string(); + + // Strip [SKIP] anywhere in the response — exact match, starts-with, or contains + if reply_text.is_empty() || reply_text.contains("[SKIP]") { + tracing::info!( + session_id = %session_id, + count = notifications.len(), + "notification delivery skipped (no substantive content)" + ); + // Still mark as delivered so we don't keep retrying + db.call(move |conn| queries::touch_nodes(conn, &delivered_ids)) + .await?; + return Ok(()); + } + + // 4. Send via the pipeline's outbound path + let message = OutboundMessage::text(&reply_text); + if let Err(e) = pipeline.send_outbound(&target, message).await { + tracing::error!( + channel = %channel, + session_id = %session_id, + error = %e, + "proactive notification delivery failed" + ); + return Err(e); + } + + tracing::info!( + session_id = %session_id, + channel = %channel, + count = notifications.len(), + "proactive notifications delivered" + ); + + // 5. Mark notifications as delivered (bump access_count from 0) + db.call(move |conn| queries::touch_nodes(conn, &delivered_ids)) + .await?; + + Ok(()) +} + +/// Follow DerivesFrom edges from notification nodes to their BackgroundTask +/// nodes and collect the full bodies. This gives the delivery LLM richer +/// context than the truncated notification summary alone. 
+async fn fetch_background_task_bodies(db: &Db, notifications: &[Node]) -> Vec { + let mut bodies = Vec::new(); + for notif in notifications { + let nid = notif.id.clone(); + if let Ok(edges) = db + .call(move |conn| queries::get_edges_from(conn, &nid)) + .await + { + for edge in edges { + if edge.kind == EdgeKind::DerivesFrom { + let target_id = edge.dst.clone(); + if let Ok(Some(node)) = db + .call(move |conn| queries::get_node(conn, &target_id)) + .await + { + if node.kind == NodeKind::BackgroundTask { + if let Some(body) = &node.body { + bodies.push(body.clone()); + } + } + } + } + } + } + } + bodies +} diff --git a/src/scheduler.rs b/src/scheduler.rs new file mode 100644 index 0000000..0d55dba --- /dev/null +++ b/src/scheduler.rs @@ -0,0 +1,346 @@ +//! Proactive cron scheduler. +//! +//! Loads `CronJob` nodes from the graph and fires them on schedule. +//! Each execution spawns a short-lived Agent loop and records a +//! `CronExecution` node linked to the originating `CronJob`. + +use std::str::FromStr; +use std::sync::Arc; +use tokio::sync::RwLock; + +use crate::config::Config; +use crate::db::Db; +use crate::db::queries; +use crate::embed::EmbedHandle; +use crate::error::Result; +use crate::hnsw::VectorIndex; +use crate::llm::LlmClient; +use crate::memory::format_timestamp; +use crate::types::*; + +/// Metadata stored in a CronJob node's body (JSON). +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct CronJobMeta { + /// Standard cron expression (5 or 7 fields). + pub cron: String, + /// The task prompt to run when the schedule fires. + pub task: String, + /// Maximum agent loop iterations per execution (default 5). + #[serde(default = "default_max_iter")] + pub max_iterations: usize, + /// Whether this job is active. + #[serde(default = "default_enabled")] + pub enabled: bool, + /// Unix timestamp of the last successful fire (0 = never). 
+ #[serde(default)] + pub last_fired: i64, + /// The user who created this job (for routing results back). + #[serde(default)] + pub user_id: Option, + /// The channel from which this job was created. + #[serde(default)] + pub channel: Option, +} + +fn default_max_iter() -> usize { 5 } +fn default_enabled() -> bool { true } + +/// Run the scheduler loop. Call this from a `tokio::spawn`. +/// +/// Every `tick_secs` seconds it loads all CronJob nodes, evaluates them +/// against the current time, and fires any that are due. +pub async fn run( + db: Db, + embed: EmbedHandle, + hnsw: Arc>, + auto_link_tx: async_channel::Sender, + llm: Arc>>>, + config: Config, + mut shutdown_rx: tokio::sync::watch::Receiver, + tick_secs: u64, +) { + let mut ticker = tokio::time::interval(std::time::Duration::from_secs(tick_secs)); + ticker.tick().await; // skip the first immediate tick + + loop { + tokio::select! { + _ = ticker.tick() => { + if let Err(e) = tick(&db, &embed, &hnsw, &auto_link_tx, &llm, &config).await { + tracing::warn!("scheduler tick error: {e}"); + } + } + _ = shutdown_rx.changed() => { + tracing::info!("scheduler shutting down"); + break; + } + } + } +} + +/// Single scheduler tick — evaluate all CronJob nodes and fire any due. +async fn tick( + db: &Db, + embed: &EmbedHandle, + hnsw: &Arc>, + auto_link_tx: &async_channel::Sender, + llm: &Arc>>>, + config: &Config, +) -> Result<()> { + // 1. 
Load all CronJob nodes + let cron_nodes = db + .call(|conn| queries::get_nodes_by_kind(conn, NodeKind::CronJob)) + .await?; + + if cron_nodes.is_empty() { + return Ok(()); + } + + let now = chrono::Utc::now(); + let now_ts = now.timestamp(); + + for node in &cron_nodes { + let meta: CronJobMeta = match &node.body { + Some(body) => match serde_json::from_str(body) { + Ok(m) => m, + Err(e) => { + tracing::warn!("invalid CronJob meta for {}: {e}", &node.id[..8]); + continue; + } + }, + None => continue, + }; + + if !meta.enabled { + continue; + } + + // Parse the cron expression + let schedule = match cron::Schedule::from_str(&meta.cron) { + Ok(s) => s, + Err(e) => { + tracing::warn!("bad cron expr '{}' for {}: {e}", meta.cron, &node.id[..8]); + continue; + } + }; + + // Determine if this job should fire: + // Find the most recent scheduled time <= now, and check if it's after last_fired. + let should_fire = if meta.last_fired == 0 { + // Never fired — fire on the first tick + true + } else { + let last_fired_dt = chrono::DateTime::from_timestamp(meta.last_fired, 0) + .unwrap_or(chrono::DateTime::UNIX_EPOCH); + // Check if any scheduled time exists between last_fired and now + schedule + .after(&last_fired_dt) + .take(1) + .any(|next| next <= now) + }; + + if !should_fire { + continue; + } + + tracing::info!("firing cron job '{}' ({})", node.title, &node.id[..8]); + + // Update last_fired in the node's body + { + let mut updated_meta = meta.clone(); + updated_meta.last_fired = now_ts; + let new_body = serde_json::to_string(&updated_meta).unwrap_or_default(); + let nid = node.id.clone(); + db.call(move |conn| { + conn.execute( + "UPDATE nodes SET body = ?1 WHERE id = ?2", + rusqlite::params![new_body, nid], + )?; + Ok(()) + }) + .await?; + } + + // Get an LLM client, or skip if none set + let llm_client = { + let guard = llm.read().await; + match &*guard { + Some(c) => c.clone(), + None => { + tracing::warn!("no LLM configured — skipping cron execution"); + continue; + 
} + } + }; + + // Spawn the execution as a background task + fire_cron_job( + db.clone(), + embed.clone(), + hnsw.clone(), + auto_link_tx.clone(), + llm_client, + config.clone(), + node.id.clone(), + node.title.clone(), + meta.task.clone(), + meta.max_iterations, + meta.user_id.clone(), + meta.channel.clone(), + ); + } + + Ok(()) +} + +/// Spawn a background agent loop for a cron execution. +fn fire_cron_job( + db: Db, + embed: EmbedHandle, + hnsw: Arc>, + auto_link_tx: async_channel::Sender, + llm: Arc, + config: Config, + cron_job_id: NodeId, + job_title: String, + task: String, + max_iterations: usize, + user_id: Option, + channel: Option, +) { + tokio::spawn(async move { + // 1. Create a CronExecution node + let exec_node = Node::new(NodeKind::CronExecution, format!("[{}] Ran scheduled task: {job_title}", format_timestamp(crate::types::now_unix()))) + .with_body(&format!("Status: running\nTask: {task}")); + let exec_id = exec_node.id.clone(); + if let Err(e) = db + .call({ + let n = exec_node; + move |conn| queries::insert_node(conn, &n) + }) + .await + { + tracing::error!("failed to create CronExecution node: {e}"); + return; + } + + // 2. Link CronExecution → CronJob via DerivesFrom + let edge = Edge::new(exec_id.clone(), cron_job_id.clone(), EdgeKind::DerivesFrom); + let _ = db + .call(move |conn| queries::insert_edge(conn, &edge)) + .await; + + // 3. Build a tools registry and agent + let tools = crate::tools::builtin_registry_core( + db.clone(), + embed.clone(), + hnsw.clone(), + auto_link_tx.clone(), + None, // no recursive spawn_task from cron + config.clone(), + ); + + let mut agent_config = config; + agent_config.max_iterations = max_iterations; + + let agent = crate::agent::orchestrator::Agent { + db: db.clone(), + embed, + hnsw, + config: agent_config, + llm, + tools, + auto_link_tx: auto_link_tx.clone(), + }; + + // 4. Run the agent loop + let result = agent.run(&task).await; + + // 5. 
Update the CronExecution node with results + let (status, result_body) = match &result { + Ok(answer) => ("completed", format!("Status: completed\n\n{answer}")), + Err(e) => ("failed", format!("Status: failed\n\nError: {e}")), + }; + + // Store result as a Fact linked to the execution + let fact = Node::new(NodeKind::Fact, format!("[{}] Result of scheduled task: {job_title}", format_timestamp(crate::types::now_unix()))) + .with_body(&result_body) + .with_importance(0.5); + let fact_id = fact.id.clone(); + let _ = db + .call({ + let f = fact; + move |conn| queries::insert_node(conn, &f) + }) + .await; + + let derives = Edge::new(fact_id.clone(), exec_id.clone(), EdgeKind::DerivesFrom); + let _ = db + .call(move |conn| queries::insert_edge(conn, &derives)) + .await; + + let _ = auto_link_tx.try_send(fact_id); + + // Create a Notification in the user's session so it gets delivered + // to the right channel by the notification delivery loop. + if let (Some(ref uid), Some(ref ch)) = (&user_id, &channel) { + let uid2 = uid.clone(); + let ch2 = ch.clone(); + let session_id: Option = db + .call(move |conn| { + crate::session::create_tables(conn)?; + let mut stmt = conn.prepare( + "SELECT node_id FROM managed_sessions WHERE user_id = ?1 AND channel = ?2", + )?; + let rows: Vec = stmt + .query_map(rusqlite::params![uid2, ch2], |row| row.get(0))? 
+ .filter_map(|r| r.ok()) + .collect(); + Ok(rows.into_iter().next()) + }) + .await + .ok() + .flatten(); + + if let Some(sid) = session_id { + let notif = Node::new( + NodeKind::Notification, + format!( + "[{}] Scheduled task completed: {job_title}", + format_timestamp(crate::types::now_unix()) + ), + ) + .with_body(&result_body); + let notif_id = notif.id.clone(); + let _ = db + .call({ + let n = notif; + move |conn| queries::insert_node(conn, &n) + }) + .await; + let notif_edge = Edge::new(notif_id, sid, EdgeKind::PartOf); + let _ = db + .call(move |conn| queries::insert_edge(conn, ¬if_edge)) + .await; + tracing::info!( + user_id = %uid.as_str(), + channel = %ch.as_str(), + "created notification for cron result in user session" + ); + } + } + + // Update execution node body + let eid = exec_id; + let _ = db + .call(move |conn| { + conn.execute( + "UPDATE nodes SET body = ?1 WHERE id = ?2", + rusqlite::params![result_body, eid], + )?; + Ok(()) + }) + .await; + + tracing::info!("cron execution [{status}]: {job_title}"); + }); +} diff --git a/src/session/mod.rs b/src/session/mod.rs new file mode 100644 index 0000000..e81b9ac --- /dev/null +++ b/src/session/mod.rs @@ -0,0 +1,195 @@ +//! Session manager — one active session per (user_id, channel). +//! +//! In omni-cede, sessions are scoped to a specific user on a specific channel. +//! A WhatsApp conversation has its own session; the same user on Telegram gets +//! a separate one. The recency window in the engine's hybrid recall operates +//! on the session, giving each channel its own conversational flow while the +//! semantic (HNSW) layer searches the global graph — cross-channel knowledge. +//! +//! Sessions are stored both as graph nodes (for the engine's native recall) +//! and in a lightweight lookup table for fast resolution by (user_id, channel). 
+ +use rusqlite::{params, Connection, OptionalExtension}; +use serde::{Deserialize, Serialize}; + +use crate::db::Db; +use crate::db::queries; +use crate::error::Result; +use crate::types::{Node, NodeId}; + +/// Metadata for a managed session. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ManagedSession { + /// Graph node ID — this is also the session_id passed to `agent.run_turn()`. + pub node_id: NodeId, + /// Internal user ID from the identity layer. + pub user_id: String, + /// Channel this session belongs to (e.g. "whatsapp", "telegram", "api"). + pub channel: String, + /// Unix timestamp when this session was created. + pub created_at: i64, + /// Number of turns processed in this session. + pub turn_count: i64, + /// Unix timestamp of the last turn. + pub last_active: i64, +} + +/// Create the session lookup table if it doesn't exist. +pub fn create_tables(conn: &Connection) -> std::result::Result<(), rusqlite::Error> { + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS managed_sessions ( + node_id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + channel TEXT NOT NULL, + created_at INTEGER NOT NULL, + turn_count INTEGER NOT NULL DEFAULT 0, + last_active INTEGER NOT NULL, + UNIQUE(user_id, channel) + ); + + CREATE INDEX IF NOT EXISTS idx_managed_sessions_user + ON managed_sessions(user_id);", + )?; + Ok(()) +} + +/// Get or create the active session for a (user_id, channel) pair. +/// +/// If a session already exists, returns it (and bumps `last_active`). +/// Otherwise creates a new `Node::session()` in the graph and a row in +/// the lookup table. 
+pub async fn get_or_create( + db: &Db, + user_id: &str, + channel: &str, +) -> Result { + let uid = user_id.to_string(); + let ch = channel.to_string(); + + db.call(move |conn| { + create_tables(conn)?; + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + // Try to find existing session + let existing: Option = conn + .query_row( + "SELECT node_id, user_id, channel, created_at, turn_count, last_active + FROM managed_sessions + WHERE user_id = ?1 AND channel = ?2", + params![uid, ch], + |row| { + Ok(ManagedSession { + node_id: row.get(0)?, + user_id: row.get(1)?, + channel: row.get(2)?, + created_at: row.get(3)?, + turn_count: row.get(4)?, + last_active: row.get(5)?, + }) + }, + ) + .optional()?; + + if let Some(mut session) = existing { + // Bump last_active + conn.execute( + "UPDATE managed_sessions SET last_active = ?1 WHERE node_id = ?2", + params![now, session.node_id], + )?; + session.last_active = now; + return Ok(session); + } + + // Create a new session node in the graph + let session_node = Node::session(&format!("{ch} session for {uid}")); + let node_id = session_node.id.clone(); + queries::insert_node(conn, &session_node)?; + + // Insert into the lookup table + conn.execute( + "INSERT INTO managed_sessions (node_id, user_id, channel, created_at, turn_count, last_active) + VALUES (?1, ?2, ?3, ?4, 0, ?5)", + params![node_id, uid, ch, now, now], + )?; + + Ok(ManagedSession { + node_id, + user_id: uid, + channel: ch, + created_at: now, + turn_count: 0, + last_active: now, + }) + }) + .await +} + +/// Increment the turn count for a session after a successful turn. 
+pub async fn record_turn(db: &Db, session_node_id: &str) -> Result<()> { + let nid = session_node_id.to_string(); + db.call(move |conn| { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + conn.execute( + "UPDATE managed_sessions SET turn_count = turn_count + 1, last_active = ?1 WHERE node_id = ?2", + params![now, nid], + )?; + Ok(()) + }) + .await +} + +/// List all sessions for a user. +pub async fn list_user_sessions(db: &Db, user_id: &str) -> Result> { + let uid = user_id.to_string(); + db.call(move |conn| { + create_tables(conn)?; + let mut stmt = conn.prepare( + "SELECT node_id, user_id, channel, created_at, turn_count, last_active + FROM managed_sessions + WHERE user_id = ?1 + ORDER BY last_active DESC", + )?; + let rows = stmt.query_map(params![uid], |row| { + Ok(ManagedSession { + node_id: row.get(0)?, + user_id: row.get(1)?, + channel: row.get(2)?, + created_at: row.get(3)?, + turn_count: row.get(4)?, + last_active: row.get(5)?, + }) + })?; + let mut result = Vec::new(); + for r in rows { + result.push(r?); + } + Ok(result) + }) + .await +} + +/// Get total session count and total turn count. 
+pub async fn stats(db: &Db) -> Result<(i64, i64)> { + db.call(move |conn| { + create_tables(conn)?; + let session_count: i64 = conn.query_row( + "SELECT COUNT(*) FROM managed_sessions", + [], + |row| row.get(0), + )?; + let turn_count: i64 = conn.query_row( + "SELECT COALESCE(SUM(turn_count), 0) FROM managed_sessions", + [], + |row| row.get(0), + )?; + Ok((session_count, turn_count)) + }) + .await +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 5417287..99f7bc4 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,8 +1,10 @@ use std::collections::HashMap; use std::future::Future; use std::pin::Pin; +use std::str::FromStr; use std::sync::Arc; +use tokio::process::Command as TokioCommand; use tokio::sync::RwLock; use crate::db::Db; @@ -10,6 +12,7 @@ use crate::db::queries; use crate::embed::EmbedHandle; use crate::error::{CortexError, Result}; use crate::hnsw::VectorIndex; +use crate::memory::format_timestamp; use crate::types::*; /// A tool the agent can call. The handler is an async function that @@ -26,7 +29,20 @@ pub struct Tool { >, } +impl Clone for Tool { + fn clone(&self) -> Self { + Self { + name: self.name.clone(), + description: self.description.clone(), + input_schema: self.input_schema.clone(), + trust: self.trust, + handler: self.handler.clone(), + } + } +} + /// Registry of available tools. 
+#[derive(Clone)] pub struct ToolRegistry { tools: HashMap, } @@ -63,13 +79,17 @@ impl ToolRegistry { .get(name) .ok_or_else(|| CortexError::Tool(format!("unknown tool: {name}")))?; + // Validate input against schema before executing + self.validate_input(name, &input)?; + let trust = tool.trust; let result = (tool.handler)(input.clone()).await?; // Write ToolCall node + let ts = format_timestamp(crate::types::now_unix()); let tool_call_node = Node { kind: NodeKind::ToolCall, - title: format!("ToolCall: {name}"), + title: format!("[{ts}] Used {name}"), body: Some(serde_json::json!({ "tool": name, "input": input, @@ -77,7 +97,7 @@ impl ToolRegistry { "success": result.success, }).to_string()), trust_score: trust as f64, - ..Node::new(NodeKind::ToolCall, format!("ToolCall: {name}")) + ..Node::new(NodeKind::ToolCall, format!("[{ts}] Used {name}")) }; let tc_id = tool_call_node.id.clone(); db.call({ @@ -92,7 +112,7 @@ impl ToolRegistry { // If success, write Fact derived from tool result if result.success { - let fact = Node::new(NodeKind::Fact, format!("Result: {name}")) + let fact = Node::new(NodeKind::Fact, format!("Output from {name}")) .with_body(&result.output) .with_trust(trust as f64); let fact_id = fact.id.clone(); @@ -112,6 +132,97 @@ impl ToolRegistry { Ok(result) } + /// Validate tool input against its JSON schema. 
+ pub fn validate_input(&self, name: &str, input: &serde_json::Value) -> Result<()> { + let tool = self + .get(name) + .ok_or_else(|| CortexError::Tool(format!("unknown tool: {name}")))?; + + // Skip validation for tools with no meaningful schema + if tool.input_schema.is_null() || tool.input_schema.as_object().is_none() { + return Ok(()); + } + + let validator = jsonschema::validator_for(&tool.input_schema).map_err(|e| { + CortexError::Tool(format!("invalid schema for tool '{name}': {e}")) + })?; + + if let Err(e) = validator.validate(&input) { + return Err(CortexError::Tool(format!( + "Input validation failed for tool '{name}': {e}" + ))); + } + + Ok(()) + } + + /// Get a cloneable handler function for a tool (for parallel execution). + pub fn get_handler( + &self, + name: &str, + ) -> Option< + Arc< + dyn Fn(serde_json::Value) -> Pin> + Send>> + + Send + + Sync, + >, + > { + self.tools.get(name).map(|t| t.handler.clone()) + } + + /// Record a tool call's graph nodes after parallel execution. 
+ pub async fn record_tool_call( + &self, + name: &str, + result: &ToolResult, + iter_node: NodeId, + db: &Db, + auto_link_tx: &async_channel::Sender, + ) -> Result<()> { + let trust = self.get(name).map(|t| t.trust).unwrap_or(0.5); + + let ts = format_timestamp(crate::types::now_unix()); + let tool_call_node = Node { + kind: NodeKind::ToolCall, + title: format!("[{ts}] Used {name}"), + body: Some(serde_json::json!({ + "tool": name, + "output": &result.output, + "success": result.success, + }).to_string()), + trust_score: trust as f64, + ..Node::new(NodeKind::ToolCall, format!("[{ts}] Used {name}")) + }; + let tc_id = tool_call_node.id.clone(); + db.call({ + let node = tool_call_node; + move |conn| queries::insert_node(conn, &node) + }) + .await?; + + let edge = Edge::new(tc_id.clone(), iter_node, EdgeKind::PartOf); + db.call(move |conn| queries::insert_edge(conn, &edge)).await?; + + if result.success { + let fact = Node::new(NodeKind::Fact, format!("[{ts}] Output from {name}")) + .with_body(&result.output) + .with_trust(trust as f64); + let fact_id = fact.id.clone(); + db.call({ + let fact = fact; + move |conn| queries::insert_node(conn, &fact) + }) + .await?; + + let derives = Edge::new(fact_id.clone(), tc_id, EdgeKind::DerivesFrom); + db.call(move |conn| queries::insert_edge(conn, &derives)).await?; + + let _ = auto_link_tx.try_send(fact_id); + } + + Ok(()) + } + /// Build a JSON schema description of all tools (for LLM system prompt). pub fn schema_json(&self) -> serde_json::Value { let tools: Vec = self @@ -145,10 +256,10 @@ impl ToolRegistry { // ─── Built-in tools ───────────────────────────────────── -/// Create a registry pre-loaded with the built-in cortex tools. -/// Pass `llm` to enable the `delegate` tool (sub-agent spawning). -/// Pass `None` to create a registry without delegation (used by sub-agents to prevent recursion). -pub fn builtin_registry( +/// Create a registry pre-loaded with the built-in cortex tools (synchronous). 
+/// This contains all tool definitions but does NOT load persisted skills from the DB. +/// Use `builtin_registry()` (async) for the full registry including persisted skills. +pub fn builtin_registry_core( db: Db, embed: EmbedHandle, hnsw: Arc>, @@ -851,7 +962,167 @@ pub fn builtin_registry( }), }); - // ── delegate: spawn a sub-agent for a focused task ── + // ── bash: execute shell commands on the host ── + if config.bash_enabled { + let blocked = config.bash_blocked_patterns.clone(); + let timeout_secs = config.bash_timeout_secs; + let max_output = config.bash_max_output_bytes; + reg.register(Tool { + name: "bash".to_string(), + description: concat!( + "Execute a shell command on the host machine and return its output. ", + "On Linux/macOS this runs via /bin/sh -c, on Windows via cmd /C. ", + "Use this for file operations, system inspection, running scripts, ", + "installing packages, managing services, or any task that requires ", + "interacting with the operating system. ", + "Commands have a timeout and dangerous operations are blocked. ", + "Always prefer single commands; for multi-step work, call bash multiple times." 
+ ).to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The shell command to execute" + }, + "working_dir": { + "type": "string", + "description": "Working directory for the command (optional, defaults to current dir)" + }, + "timeout_secs": { + "type": "integer", + "description": "Override timeout in seconds (optional, max 300)" + } + }, + "required": ["command"] + }), + trust: 0.7, + handler: Arc::new(move |input| { + let blocked = blocked.clone(); + let timeout_secs = timeout_secs; + let max_output = max_output; + Box::pin(async move { + let command = input["command"].as_str().unwrap_or("").to_string(); + if command.is_empty() { + return Ok(ToolResult { + output: "Error: command is required.".into(), + success: false, + }); + } + + // Safety: check against blocked patterns + let cmd_lower = command.to_lowercase(); + for pattern in &blocked { + if cmd_lower.contains(&pattern.to_lowercase()) { + return Ok(ToolResult { + output: format!( + "Blocked: command matches safety pattern '{}'. 
This operation is not allowed.", + pattern + ), + success: false, + }); + } + } + + // Resolve timeout (user override capped at 300s) + let timeout = std::time::Duration::from_secs( + input["timeout_secs"] + .as_u64() + .unwrap_or(timeout_secs) + .min(300), + ); + + // Build the OS-appropriate command + let mut cmd = if cfg!(target_os = "windows") { + let mut c = TokioCommand::new("cmd"); + c.args(["/C", &command]); + c + } else { + let mut c = TokioCommand::new("/bin/sh"); + c.args(["-c", &command]); + c + }; + + // Set working directory if provided + if let Some(dir) = input["working_dir"].as_str() { + cmd.current_dir(dir); + } + + // Capture stdout + stderr + cmd.stdout(std::process::Stdio::piped()); + cmd.stderr(std::process::Stdio::piped()); + + // Spawn and await with timeout + let child = cmd.spawn(); + let child = match child { + Ok(c) => c, + Err(e) => { + return Ok(ToolResult { + output: format!("Failed to spawn command: {e}"), + success: false, + }); + } + }; + + let result = tokio::time::timeout(timeout, child.wait_with_output()).await; + + match result { + Ok(Ok(output)) => { + let code = output.status.code().unwrap_or(-1); + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + + // Combine output, truncate if needed + let mut combined = String::new(); + if !stdout.is_empty() { + combined.push_str(&stdout); + } + if !stderr.is_empty() { + if !combined.is_empty() { + combined.push_str("\n--- stderr ---\n"); + } + combined.push_str(&stderr); + } + if combined.is_empty() { + combined = "(no output)".into(); + } + + // Truncate to max_output bytes + if combined.len() > max_output { + combined.truncate(max_output); + combined.push_str(&format!( + "\n... 
[truncated at {} bytes]", + max_output + )); + } + + let success = output.status.success(); + Ok(ToolResult { + output: format!( + "[exit code: {}]\n{}", + code, combined + ), + success, + }) + } + Ok(Err(e)) => Ok(ToolResult { + output: format!("Command execution error: {e}"), + success: false, + }), + Err(_) => Ok(ToolResult { + output: format!( + "Command timed out after {} seconds and was killed.", + timeout.as_secs() + ), + success: false, + }), + } + }) + }), + }); + } + + // ── spawn_task: kick off a background autonomous loop ── if let Some(llm) = llm { let db = db.clone(); let embed = embed.clone(); @@ -859,34 +1130,36 @@ pub fn builtin_registry( let auto_link_tx = auto_link_tx.clone(); let config = config.clone(); reg.register(Tool { - name: "delegate".to_string(), + name: "spawn_task".to_string(), description: concat!( - "Spawn a sub-agent to handle a focused task independently. ", - "The sub-agent gets its own session, full memory access (recall/remember/etc), ", - "and runs up to max_iterations loops before returning its answer. ", - "Use this for tasks that need focused research, multi-step reasoning, ", - "or when you want to explore a topic without cluttering the main conversation. ", - "The sub-agent's work is recorded in the graph as a Delegation." + "Spawn a background task that runs autonomously. The task gets its own ", + "agent loop with full tool access (recall, remember, bash, etc.) and writes ", + "all results directly to the graph. Returns immediately with a task ID — ", + "the agent does NOT wait for the task to finish. ", + "Use this for multi-step autonomous work: research, file processing, ", + "system maintenance, report generation, or any task that would take ", + "multiple tool calls to complete. ", + "Results are discoverable via recall once the task finishes." ).to_string(), input_schema: serde_json::json!({ "type": "object", "properties": { "task": { "type": "string", - "description": "What the sub-agent should do. 
Be specific and self-contained." + "description": "What the background task should accomplish. Be specific." }, "context": { "type": "string", - "description": "Additional context or constraints for the sub-agent (optional)" + "description": "Additional context or constraints (optional)" }, "max_iterations": { "type": "integer", - "description": "Max loops the sub-agent can run (default: 5, max: 10)" + "description": "Max agent loop iterations (default: 10, max: 25)" } }, "required": ["task"] }), - trust: 0.9, + trust: 0.8, handler: Arc::new(move |input| { let db = db.clone(); let embed = embed.clone(); @@ -898,8 +1171,8 @@ pub fn builtin_registry( let task = input["task"].as_str().unwrap_or("").to_string(); let context = input["context"].as_str().unwrap_or("").to_string(); let max_iter = input["max_iterations"].as_u64() - .unwrap_or(5) - .min(10) as usize; + .unwrap_or(10) + .min(25) as usize; if task.is_empty() { return Ok(ToolResult { @@ -908,97 +1181,641 @@ pub fn builtin_registry( }); } - // Build the full prompt for the sub-agent let full_task = if context.is_empty() { task.clone() } else { - format!("{task}\n\nAdditional context: {context}") + format!("{task}\n\nContext: {context}") }; - // Write a Delegation node - let delegation = Node::new(NodeKind::Delegation, format!("Delegate: {}", &task)) - .with_body(&full_task) - .with_importance(0.4); - let delegation_id = delegation.id.clone(); + let task_node = Node::new( + NodeKind::BackgroundTask, + format!("[{}] Working on: {}", format_timestamp(crate::types::now_unix()), &task), + ) + .with_body(&format!("Status: running\n\n{full_task}")) + .with_importance(0.6); + let task_id = task_node.id.clone(); db.call({ - let d = delegation.clone(); - move |conn| queries::insert_node(conn, &d) + let n = task_node; + move |conn| queries::insert_node(conn, &n) }) .await?; - // Build sub-agent config with capped iterations - let mut sub_config = config.clone(); - sub_config.max_iterations = max_iter; - - // Sub-agent gets 
all tools EXCEPT delegate (llm=None prevents recursion) - let sub_tools = builtin_registry( - db.clone(), - embed.clone(), - hnsw.clone(), - auto_link_tx.clone(), - None, - sub_config.clone(), - ); + let bg_task_id = task_id.clone(); + let bg_db = db.clone(); + let bg_embed = embed.clone(); + let bg_hnsw = hnsw.clone(); + let bg_auto_link_tx = auto_link_tx.clone(); + let bg_llm = llm.clone(); + let bg_config = config.clone(); + + tokio::spawn(async move { + let bg_tools = builtin_registry_core( + bg_db.clone(), + bg_embed.clone(), + bg_hnsw.clone(), + bg_auto_link_tx.clone(), + None, + bg_config.clone(), + ); + + let mut bg_agent_config = bg_config; + bg_agent_config.max_iterations = max_iter; + + let agent = crate::agent::orchestrator::Agent { + db: bg_db.clone(), + embed: bg_embed, + hnsw: bg_hnsw, + config: bg_agent_config, + llm: bg_llm, + tools: bg_tools, + auto_link_tx: bg_auto_link_tx.clone(), + }; + + let result = agent.run(&full_task).await; + + let (status, body) = match result { + Ok(answer) => ("completed", format!("Status: completed\n\n{answer}")), + Err(e) => ("failed", format!("Status: failed\n\nError: {e}")), + }; + + let result_fact = Node::new( + NodeKind::Fact, + format!("Finished: {}", &task), + ) + .with_body(&body) + .with_importance(0.6); + let fact_id = result_fact.id.clone(); + let _ = bg_db + .call({ + let f = result_fact; + move |conn| queries::insert_node(conn, &f) + }) + .await; + + let edge = Edge::new( + fact_id.clone(), + bg_task_id, + EdgeKind::DerivesFrom, + ); + let _ = bg_db + .call(move |conn| queries::insert_edge(conn, &edge)) + .await; + + let _ = bg_auto_link_tx.try_send(fact_id); + + eprintln!("[background task {status}]: {task}"); + }); + + Ok(ToolResult { + output: format!( + "Background task spawned (id: {}). It will run autonomously and write results to the graph. 
Use recall to check for results later.", + &task_id[..8] + ), + success: true, + }) + }) + }), + }); + } + + // ── schedule_cron: create a recurring scheduled task ── + { + let db = db.clone(); + reg.register(Tool { + name: "schedule_cron".to_string(), + description: concat!( + "Create a recurring scheduled task (cron job). The task runs autonomously ", + "on the specified schedule with its own agent loop and full tool access. ", + "Results are stored in the graph as CronExecution nodes. ", + "Use standard 7-field cron expressions: sec min hour day month weekday year. ", + "Examples: '0 0 * * * * *' (every hour), '0 */30 * * * * *' (every 30 min), ", + "'0 0 9 * * MON-FRI *' (9am weekdays)." + ).to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Short name for the cron job (e.g. 'Daily health check')" + }, + "cron": { + "type": "string", + "description": "Cron expression (7 fields: sec min hour day month weekday year)" + }, + "task": { + "type": "string", + "description": "What the agent should do each time this fires. Be specific." 
+ }, + "max_iterations": { + "type": "integer", + "description": "Max agent loop iterations per execution (default: 5, max: 15)" + } + }, + "required": ["name", "cron", "task"] + }), + trust: 0.8, + handler: Arc::new(move |input| { + let db = db.clone(); + Box::pin(async move { + let name = input["name"].as_str().unwrap_or("Unnamed cron").to_string(); + let cron_expr = input["cron"].as_str().unwrap_or("").to_string(); + let task = input["task"].as_str().unwrap_or("").to_string(); + let max_iter = input["max_iterations"].as_u64().unwrap_or(5).min(15) as usize; + + if cron_expr.is_empty() || task.is_empty() { + return Ok(ToolResult { + output: "Error: cron and task are required.".into(), + success: false, + }); + } - let sub_agent = crate::agent::orchestrator::Agent { - db: db.clone(), - embed: embed.clone(), - hnsw: hnsw.clone(), - config: sub_config, - llm: llm.clone(), - tools: sub_tools, - auto_link_tx: auto_link_tx.clone(), + // Validate cron expression + if cron::Schedule::from_str(&cron_expr).is_err() { + return Ok(ToolResult { + output: format!( + "Invalid cron expression: '{}'. Use 7 fields: sec min hour day month weekday year.", + cron_expr + ), + success: false, + }); + } + + // Look up the active user's session to tag the cron job + // with owner info so results route back to the right channel. + let session_owner: Option<(String, String)> = { + let db2 = db.clone(); + db2.call(|conn| { + crate::session::create_tables(conn)?; + let mut stmt = conn.prepare( + "SELECT user_id, channel FROM managed_sessions ORDER BY last_active DESC LIMIT 1", + )?; + let rows: Vec<(String, String)> = stmt + .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))? 
+ .filter_map(|r| r.ok()) + .collect(); + Ok(rows.into_iter().next()) + }) + .await + .ok() + .flatten() }; - // Run the sub-agent - let result = sub_agent.run(&full_task).await; + let meta = crate::scheduler::CronJobMeta { + cron: cron_expr.clone(), + task: task.clone(), + max_iterations: max_iter, + enabled: true, + last_fired: 0, + user_id: session_owner.as_ref().map(|(u, _)| u.clone()), + channel: session_owner.as_ref().map(|(_, c)| c.clone()), + }; - match result { - Ok(answer) => { - // Write Synthesis node with the result - let synthesis = Node::new( - NodeKind::Synthesis, - format!("Synthesis: {}", &task), - ) - .with_body(&answer) - .with_importance(0.6); - let synthesis_id = synthesis.id.clone(); - db.call({ - let s = synthesis.clone(); - move |conn| queries::insert_node(conn, &s) - }) - .await?; + let node = Node { + kind: NodeKind::CronJob, + title: format!("Cron: {name}"), + body: Some(serde_json::to_string(&meta).unwrap()), + importance: 0.8, + decay_rate: 0.0, + ..Node::new(NodeKind::CronJob, format!("Cron: {name}")) + }; + let node_id = node.id.clone(); + db.call({ + let n = node; + move |conn| queries::insert_node(conn, &n) + }) + .await?; + + Ok(ToolResult { + output: format!( + "Cron job created: '{}' (id: {})\nSchedule: {}\nTask: {}", + name, &node_id[..8], cron_expr, task + ), + success: true, + }) + }) + }), + }); + } + + // ── delete_cron: remove a scheduled task ── + { + let db = db.clone(); + reg.register(Tool { + name: "delete_cron".to_string(), + description: "Delete a cron job by its node ID prefix. 
Use list_crons to find IDs.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "node_id": { + "type": "string", + "description": "Cron job node ID or unique prefix (at least 6 chars)" + } + }, + "required": ["node_id"] + }), + trust: 0.8, + handler: Arc::new(move |input| { + let db = db.clone(); + Box::pin(async move { + let raw_id = input["node_id"].as_str().unwrap_or("").to_string(); + if raw_id.len() < 6 { + return Ok(ToolResult { + output: "Error: node_id must be at least 6 characters.".into(), + success: false, + }); + } + + let full_id = { + let rid = raw_id.clone(); + let matches = db.call(move |conn| queries::find_nodes_by_prefix(conn, &rid)).await?; + match matches.len() { + 0 => return Ok(ToolResult { + output: format!("No node found with prefix '{raw_id}'"), + success: false, + }), + 1 => matches.into_iter().next().unwrap(), + n => return Ok(ToolResult { + output: format!("Ambiguous prefix '{raw_id}' matches {n} nodes."), + success: false, + }), + } + }; + + // Verify it's a CronJob + let node = { + let id = full_id.clone(); + db.call(move |conn| queries::get_node(conn, &id)).await? 
+ }; + match &node { + Some(n) if n.kind == NodeKind::CronJob => {}, + Some(n) => return Ok(ToolResult { + output: format!("Node {} is a {}, not a cron_job.", &full_id[..8], n.kind), + success: false, + }), + None => return Ok(ToolResult { + output: format!("Node {raw_id} not found."), + success: false, + }), + } - // Link: Synthesis ──DerivesFrom──▸ Delegation - let edge = Edge::new( - synthesis_id.clone(), - delegation_id.clone(), - EdgeKind::DerivesFrom, - ); - db.call(move |conn| queries::insert_edge(conn, &edge)).await?; + let title = node.unwrap().title; + let id_del = full_id.clone(); + db.call(move |conn| queries::delete_node(conn, &id_del)).await?; - // Enqueue synthesis for auto-linking - let _ = auto_link_tx.try_send(synthesis_id); + Ok(ToolResult { + output: format!("Deleted cron job '{}' ({})", title, &full_id[..8]), + success: true, + }) + }) + }), + }); + } - Ok(ToolResult { - output: format!( - "[Sub-agent completed]\n\n{}", - answer - ), - success: true, - }) + // ── list_crons: show all scheduled tasks ── + { + let db = db.clone(); + reg.register(Tool { + name: "list_crons".to_string(), + description: "List all cron jobs (scheduled tasks) currently in the graph.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + trust: 1.0, + handler: Arc::new(move |_input| { + let db = db.clone(); + Box::pin(async move { + let nodes = db + .call(|conn| queries::get_nodes_by_kind(conn, NodeKind::CronJob)) + .await?; + + if nodes.is_empty() { + return Ok(ToolResult { + output: "No cron jobs found.".to_string(), + success: true, + }); + } + + let mut out = format!("{} cron job(s):\n", nodes.len()); + for n in &nodes { + let meta: Option = n + .body + .as_deref() + .and_then(|b| serde_json::from_str(b).ok()); + if let Some(m) = meta { + let status = if m.enabled { "active" } else { "paused" }; + let last = if m.last_fired == 0 { + "never".to_string() + } else { + 
chrono::DateTime::from_timestamp(m.last_fired, 0) + .map(|dt| dt.format("%Y-%m-%d %H:%M UTC").to_string()) + .unwrap_or_else(|| "?".to_string()) + }; + out.push_str(&format!( + "- {} (id: {}, {}) — schedule: '{}', last: {}\n task: {}\n", + n.title, &n.id[..8], status, m.cron, last, m.task, + )); + } else { + out.push_str(&format!("- {} (id: {}, invalid metadata)\n", n.title, &n.id[..8])); } - Err(e) => { - Ok(ToolResult { - output: format!("Sub-agent error: {e}"), + } + Ok(ToolResult { output: out, success: true }) + }) + }), + }); + } + + // ── create_skill: define a dynamic prompt-based tool ── + { + let db = db.clone(); + reg.register(Tool { + name: "create_skill".to_string(), + description: concat!( + "Create a dynamic skill (prompt-based tool). When another agent or session ", + "calls this skill, the LLM receives the skill's instructions plus the caller's ", + "input, and returns a result. Skills persist in the graph as Skill nodes and ", + "are available after restart. Use this for reusable capabilities: ", + "code review templates, analysis frameworks, domain-specific procedures." + ).to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Tool name (snake_case, e.g. 'code_review')" + }, + "description": { + "type": "string", + "description": "What this skill does (shown to the LLM when choosing tools)" + }, + "instructions": { + "type": "string", + "description": "Detailed instructions for executing this skill. This becomes the system prompt when the skill runs." 
+ }, + "input_fields": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string", "enum": ["string", "number", "boolean"] }, + "description": { "type": "string" }, + "required": { "type": "boolean" } + } + }, + "description": "Input parameters the skill accepts (optional — defaults to a single 'input' string field)" + } + }, + "required": ["name", "description", "instructions"] + }), + trust: 0.8, + handler: Arc::new(move |input| { + let db = db.clone(); + Box::pin(async move { + let name = input["name"].as_str().unwrap_or("").to_string(); + let description = input["description"].as_str().unwrap_or("").to_string(); + let instructions = input["instructions"].as_str().unwrap_or("").to_string(); + let input_fields = input["input_fields"].clone(); + + if name.is_empty() || description.is_empty() || instructions.is_empty() { + return Ok(ToolResult { + output: "Error: name, description, and instructions are all required.".into(), + success: false, + }); + } + + // Validate name is snake_case-ish + if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') { + return Ok(ToolResult { + output: "Error: skill name must be alphanumeric with underscores only.".into(), + success: false, + }); + } + + // Build the skill definition + let skill_def = serde_json::json!({ + "name": name, + "description": description, + "instructions": instructions, + "input_fields": input_fields, + }); + + let node = Node { + kind: NodeKind::Skill, + title: format!("Skill: {name}"), + body: Some(serde_json::to_string(&skill_def).unwrap()), + importance: 0.8, + decay_rate: 0.0, + ..Node::new(NodeKind::Skill, format!("Skill: {name}")) + }; + let node_id = node.id.clone(); + db.call({ + let n = node; + move |conn| queries::insert_node(conn, &n) + }) + .await?; + + Ok(ToolResult { + output: format!( + "Skill '{}' created (id: {}). 
It will be available as a tool in new sessions after restart.", + name, &node_id[..8] + ), + success: true, + }) + }) + }), + }); + } + + // ── delete_skill: remove a dynamic skill ── + { + let db = db.clone(); + reg.register(Tool { + name: "delete_skill".to_string(), + description: "Delete a dynamic skill by its node ID prefix. Use list_memories with kind=Skill to find IDs.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "node_id": { + "type": "string", + "description": "Skill node ID or unique prefix (at least 6 chars)" + } + }, + "required": ["node_id"] + }), + trust: 0.8, + handler: Arc::new(move |input| { + let db = db.clone(); + Box::pin(async move { + let raw_id = input["node_id"].as_str().unwrap_or("").to_string(); + if raw_id.len() < 6 { + return Ok(ToolResult { + output: "Error: node_id must be at least 6 characters.".into(), + success: false, + }); + } + + let full_id = { + let rid = raw_id.clone(); + let matches = db.call(move |conn| queries::find_nodes_by_prefix(conn, &rid)).await?; + match matches.len() { + 0 => return Ok(ToolResult { + output: format!("No node found with prefix '{raw_id}'"), success: false, - }) + }), + 1 => matches.into_iter().next().unwrap(), + n => return Ok(ToolResult { + output: format!("Ambiguous prefix '{raw_id}' matches {n} nodes."), + success: false, + }), } + }; + + // Verify it's a Skill + let node = { + let id = full_id.clone(); + db.call(move |conn| queries::get_node(conn, &id)).await? 
+ }; + match &node { + Some(n) if n.kind == NodeKind::Skill => {}, + Some(n) => return Ok(ToolResult { + output: format!("Node {} is a {}, not a skill.", &full_id[..8], n.kind), + success: false, + }), + None => return Ok(ToolResult { + output: format!("Node {raw_id} not found."), + success: false, + }), } + + let title = node.unwrap().title; + let id_del = full_id.clone(); + db.call(move |conn| queries::delete_node(conn, &id_del)).await?; + + Ok(ToolResult { + output: format!("Deleted skill '{}' ({})", title, &full_id[..8]), + success: true, + }) }) }), }); } + // ── Browser tools (feature-gated) ── + #[cfg(feature = "browser")] + { + crate::browser::tools::register_browser_tools(&mut reg); + } + + reg +} + +/// Create a full registry including persisted skills loaded from the graph. +/// This is the async version that wraps `builtin_registry_core` and adds +/// dynamically-created skill tools from the DB. +pub async fn builtin_registry( + db: Db, + embed: EmbedHandle, + hnsw: Arc>, + auto_link_tx: async_channel::Sender, + llm: Option>, + config: crate::config::Config, +) -> ToolRegistry { + let mut reg = builtin_registry_core( + db.clone(), embed, hnsw, auto_link_tx, llm, config, + ); + + // ── Load persisted dynamic skills from graph ── + // They become prompt-based tools that delegate to the LLM. 
+ { + let skill_nodes = match db.call(|conn| queries::get_nodes_by_kind(conn, NodeKind::Skill)).await { + Ok(nodes) => nodes, + Err(e) => { + tracing::warn!("failed to load persisted skills: {e}"); + vec![] + } + }; + + for skill_node in skill_nodes { + let skill_def: serde_json::Value = match &skill_node.body { + Some(body) => match serde_json::from_str(body) { + Ok(v) => v, + Err(_) => continue, + }, + None => continue, + }; + + let skill_name = skill_def["name"].as_str().unwrap_or("").to_string(); + let skill_desc = skill_def["description"].as_str().unwrap_or("").to_string(); + let instructions = skill_def["instructions"].as_str().unwrap_or("").to_string(); + + if skill_name.is_empty() || instructions.is_empty() { + continue; + } + + // Build input schema from input_fields or use default + let input_schema = if let Some(fields) = skill_def["input_fields"].as_array() { + let mut properties = serde_json::Map::new(); + let mut required = Vec::new(); + for field in fields { + let fname = field["name"].as_str().unwrap_or("input"); + let ftype = field["type"].as_str().unwrap_or("string"); + let fdesc = field["description"].as_str().unwrap_or(""); + properties.insert( + fname.to_string(), + serde_json::json!({ "type": ftype, "description": fdesc }), + ); + if field["required"].as_bool().unwrap_or(false) { + required.push(serde_json::Value::String(fname.to_string())); + } + } + serde_json::json!({ + "type": "object", + "properties": properties, + "required": required, + }) + } else { + serde_json::json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "Input for this skill" + } + }, + "required": ["input"] + }) + }; + + // The handler: format the input with instructions and return + // a structured prompt result. The actual LLM call happens when + // the orchestrator processes the tool result. 
+ let instr = instructions.clone(); + reg.register(Tool { + name: format!("skill_{skill_name}"), + description: format!("[Dynamic Skill] {skill_desc}"), + input_schema, + trust: 0.7, + handler: Arc::new(move |input| { + let instr = instr.clone(); + Box::pin(async move { + // Build a formatted prompt from instructions + input + let input_str = serde_json::to_string_pretty(&input).unwrap_or_default(); + Ok(ToolResult { + output: format!( + "[Skill Execution]\nInstructions: {instr}\n\nInput: {input_str}\n\n\ + Please follow the instructions above to process this input and provide your result." + ), + success: true, + }) + }) + }), + }); + + tracing::info!("loaded persisted skill: skill_{skill_name}"); + } + } + reg } diff --git a/src/types.rs b/src/types.rs index a76a8b7..8fb09b6 100644 --- a/src/types.rs +++ b/src/types.rs @@ -28,10 +28,15 @@ pub enum NodeKind { LlmCall, ToolCall, LoopIteration, - // Sub-agents - SubAgent, - Delegation, - Synthesis, + // Background tasks + BackgroundTask, + // Scheduled tasks + CronJob, + CronExecution, + // Dynamic skills / plugins + Skill, + // Notifications — delivered via graph, not a separate table + Notification, // Self-model — medium decay Pattern, Limitation, @@ -54,9 +59,11 @@ impl NodeKind { Self::LlmCall => "llm_call", Self::ToolCall => "tool_call", Self::LoopIteration => "loop_iteration", - Self::SubAgent => "sub_agent", - Self::Delegation => "delegation", - Self::Synthesis => "synthesis", + Self::BackgroundTask => "background_task", + Self::CronJob => "cron_job", + Self::CronExecution => "cron_execution", + Self::Skill => "skill", + Self::Notification => "notification", Self::Pattern => "pattern", Self::Limitation => "limitation", Self::Capability => "capability", @@ -78,9 +85,11 @@ impl NodeKind { "llm_call" => Some(Self::LlmCall), "tool_call" => Some(Self::ToolCall), "loop_iteration" => Some(Self::LoopIteration), - "sub_agent" => Some(Self::SubAgent), - "delegation" => Some(Self::Delegation), - "synthesis" => 
Some(Self::Synthesis), + "background_task" => Some(Self::BackgroundTask), + "cron_job" => Some(Self::CronJob), + "cron_execution" => Some(Self::CronExecution), + "skill" => Some(Self::Skill), + "notification" => Some(Self::Notification), "pattern" => Some(Self::Pattern), "limitation" => Some(Self::Limitation), "capability" => Some(Self::Capability), @@ -95,6 +104,12 @@ impl NodeKind { Self::Soul | Self::Belief | Self::Goal => 0.0, // User inputs decay moderately (they're conversation context) Self::UserInput => 0.02, + // Cron definitions persist like identity + Self::CronJob | Self::Skill => 0.0, + // Cron executions decay fast like operational nodes + Self::CronExecution => 0.05, + // Notifications decay fast (ephemeral once delivered) + Self::Notification => 0.05, // Operational nodes decay fast Self::Session | Self::Turn | Self::LlmCall | Self::ToolCall | Self::LoopIteration => 0.05, @@ -109,7 +124,10 @@ impl NodeKind { pub fn default_importance(&self) -> f64 { match self { Self::Soul | Self::Belief | Self::Goal => 1.0, + Self::CronJob | Self::Skill => 0.8, Self::UserInput => 0.4, + Self::CronExecution => 0.2, + Self::Notification => 0.3, Self::Session | Self::Turn | Self::LlmCall | Self::ToolCall | Self::LoopIteration => 0.2, _ => 0.5, @@ -169,7 +187,7 @@ impl fmt::Display for EdgeKind { // ─── Node ─────────────────────────────────────────────── -fn now_unix() -> i64 { +pub fn now_unix() -> i64 { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() @@ -253,10 +271,17 @@ impl Node { } pub fn loop_iteration(iter: usize, session_id: &NodeId) -> Self { - Node::new(NodeKind::LoopIteration, format!("Iteration {iter}")) + let ts = crate::memory::format_timestamp(now_unix()); + Node::new(NodeKind::LoopIteration, format!("[{ts}] Iteration {iter}")) .with_body(format!("session:{session_id}")) } + pub fn notification(summary: impl Into) -> Self { + let s = summary.into(); + let ts = crate::memory::format_timestamp(now_unix()); + 
Node::new(NodeKind::Notification, format!("[{ts}] {s}")) + } + pub fn fact_from_response(text: &str, _session_id: &NodeId) -> Self { let title = if text.chars().count() > 80 { let s: String = text.chars().take(80).collect(); @@ -401,6 +426,38 @@ impl Message { pub fn user(content: impl Into) -> Self { Self { role: Role::User, content: content.into(), tool_call_id: None, content_blocks: None } } + /// Create a user message with an inline image (Anthropic vision format). + /// + /// The image is sent as a base64-encoded source block alongside the text. + /// If `text` is empty, only the image block is included (with a generic + /// prompt so the model knows to describe/process it). + pub fn user_with_image(text: &str, base64_data: &str, mime_type: &str) -> Self { + let text_content = if text.is_empty() { + "[The user sent an image]".to_string() + } else { + text.to_string() + }; + let blocks = serde_json::json!([ + { + "type": "image", + "source": { + "type": "base64", + "media_type": mime_type, + "data": base64_data, + } + }, + { + "type": "text", + "text": text_content, + } + ]); + Self { + role: Role::User, + content: text_content, + tool_call_id: None, + content_blocks: Some(blocks), + } + } pub fn assistant(content: impl Into) -> Self { Self { role: Role::Assistant, content: content.into(), tool_call_id: None, content_blocks: None } } @@ -490,22 +547,20 @@ pub struct ToolResult { pub success: bool, } -// ─── Sub-agent types ──────────────────────────────────── - -#[derive(Debug, Clone)] -pub struct SubAgentSpec { - pub name: String, - pub soul: String, - pub capabilities: Vec, - pub tool_allowlist: Vec, - pub max_iterations: usize, -} +// ─── Turn context ─────────────────────────────────────── +/// Contextual metadata about the current turn, carried from the channel +/// pipeline into the agent so it knows *who* is talking and *where*. 
#[derive(Debug, Clone)] -pub struct SubAgentResult { - pub answer: String, - pub facts_created: Vec, - pub tokens_used: usize, +pub struct TurnContext { + /// Channel name (e.g. "discord", "telegram", "webchat", "api"). + pub channel: String, + /// Human-readable display name for the sender, if known. + pub sender_name: Option, + /// Internal user ID (from identity resolution). + pub user_id: String, + /// True when the message came from a group/channel (not a DM). + pub is_group: bool, } // ─── Model backend ────────────────────────────────────── diff --git a/tests/integration.rs b/tests/integration.rs index f8a9c7a..188f897 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -7,14 +7,14 @@ //! Run with: cargo test --test integration -- --test-threads=1 //! (the embedding model is shared and not safe for parallel init) -use cede::config::Config; -use cede::db::Db; -use cede::db::queries; -use cede::embed::EmbedHandle; -use cede::hnsw::VectorIndex; -use cede::llm::MockLlmClient; -use cede::memory; -use cede::types::*; +use omni_cede::config::Config; +use omni_cede::db::Db; +use omni_cede::db::queries; +use omni_cede::embed::EmbedHandle; +use omni_cede::hnsw::VectorIndex; +use omni_cede::llm::MockLlmClient; +use omni_cede::memory; +use omni_cede::types::*; use std::sync::{Arc, OnceLock}; use tokio::sync::RwLock; @@ -324,7 +324,7 @@ async fn phase4_briefing_shows_contradictions() { #[tokio::test] async fn phase5_mock_llm_returns_scripted_responses() { - use cede::llm::LlmClient; + use omni_cede::llm::LlmClient; let mock = MockLlmClient::new(vec![LlmResponse { text: "Hello, world!".into(), @@ -351,10 +351,10 @@ async fn phase5_mock_llm_returns_scripted_responses() { #[tokio::test] async fn phase6_tool_registry_executes_and_records() { let h = TestHarness::new(); - let mut tools = cede::tools::ToolRegistry::new(); + let mut tools = omni_cede::tools::ToolRegistry::new(); // Register a simple echo tool - tools.register(cede::tools::Tool { + 
tools.register(omni_cede::tools::Tool { name: "echo".into(), description: "Echoes input".into(), input_schema: serde_json::json!({"type": "object", "properties": {"text": {"type": "string"}}}), @@ -442,13 +442,13 @@ async fn phase7_agent_loop_end_to_end() { output_tokens: 10, }]); - let agent = cede::agent::orchestrator::Agent { + let agent = omni_cede::agent::orchestrator::Agent { db: h.db.clone(), embed: h.embed.clone(), hnsw: h.hnsw.clone(), config: h.config.clone(), llm: Arc::new(mock), - tools: cede::tools::ToolRegistry::new(), + tools: omni_cede::tools::ToolRegistry::new(), auto_link_tx: h.auto_link_tx.clone(), }; @@ -495,7 +495,7 @@ async fn phase8_decay_reduces_importance() { let node_id = h.remember(node).await; // Run decay via the public function (uses proportional elapsed-time decay) - cede::run_decay(&h.db, h.config.decay_interval_secs) + omni_cede::run_decay(&h.db, h.config.decay_interval_secs) .await .unwrap(); @@ -627,11 +627,11 @@ async fn phase9_consolidation_adjusts_trust() { } // ═══════════════════════════════════════════════════════════ -// Phase 10: Sub-agents (basic structure test) +// Phase 10: Background tasks (basic structure test) // ═══════════════════════════════════════════════════════════ #[tokio::test] -async fn phase10_sub_agent_nodes_created() { +async fn phase10_background_task_nodes_created() { let h = TestHarness::new(); // Create parent session @@ -645,57 +645,51 @@ async fn phase10_sub_agent_nodes_created() { .await .unwrap(); - // Write SubAgent + Delegation nodes (testing the structure) - let sub_agent_node = Node::new(NodeKind::SubAgent, "Research sub-agent") - .with_body("Specializes in research tasks."); - let sub_id = sub_agent_node.id.clone(); + // Write BackgroundTask node (testing the structure) + let task_node = Node::new(NodeKind::BackgroundTask, "Research background task") + .with_body("Status: running\n\nResearch JWT token best practices"); + let task_id = task_node.id.clone(); h.db .call({ - let n = 
sub_agent_node; + let n = task_node; move |conn| queries::insert_node(conn, &n) }) .await .unwrap(); - let delegation = Node::new(NodeKind::Delegation, "Delegated: research JWT") - .with_body("Research JWT token best practices"); - let del_id = delegation.id.clone(); + // Link: BackgroundTask → Session (PartOf) + let e1 = Edge::new(task_id.clone(), session_id.clone(), EdgeKind::PartOf); h.db - .call({ - let n = delegation; - move |conn| queries::insert_node(conn, &n) - }) + .call(move |conn| queries::insert_edge(conn, &e1)) .await .unwrap(); - // Link: Delegation → Session (PartOf) - let e1 = Edge::new(del_id.clone(), session_id.clone(), EdgeKind::PartOf); + // Write a result fact derived from the task + let result_fact = Node::new(NodeKind::Fact, "Task result: research JWT") + .with_body("Status: completed\n\nJWT best practices summary..."); + let fact_id = result_fact.id.clone(); h.db - .call(move |conn| queries::insert_edge(conn, &e1)) + .call({ + let n = result_fact; + move |conn| queries::insert_node(conn, &n) + }) .await .unwrap(); - // Link: Delegation → SubAgent (DerivesFrom) - let e2 = Edge::new(del_id, sub_id, EdgeKind::DerivesFrom); + // Link: Fact → BackgroundTask (DerivesFrom) + let e2 = Edge::new(fact_id, task_id, EdgeKind::DerivesFrom); h.db .call(move |conn| queries::insert_edge(conn, &e2)) .await .unwrap(); // Verify structure - let sub_agents = h - .db - .call(|conn| queries::get_nodes_by_kind(conn, NodeKind::SubAgent)) - .await - .unwrap(); - assert_eq!(sub_agents.len(), 1); - - let delegations = h + let tasks = h .db - .call(|conn| queries::get_nodes_by_kind(conn, NodeKind::Delegation)) + .call(|conn| queries::get_nodes_by_kind(conn, NodeKind::BackgroundTask)) .await .unwrap(); - assert_eq!(delegations.len(), 1); + assert_eq!(tasks.len(), 1); // Verify edges let sid = session_id; @@ -770,7 +764,7 @@ async fn graph_bfs_traverse() { // BFS from A with depth 2 let aid = a_id.clone(); let walked = h.db.call(move |conn| { - 
cede::graph::bfs_walk(conn, &[aid], 2) + omni_cede::graph::bfs_walk(conn, &[aid], 2) }).await.unwrap(); assert!(walked.contains_key(&a_id), "BFS should include seed A"); @@ -844,7 +838,7 @@ async fn phase11_decay_proportional_to_elapsed_time() { .unwrap(); // Run proportional decay (interval = 60s) - cede::run_decay(&h.db, 60).await.unwrap(); + omni_cede::run_decay(&h.db, 60).await.unwrap(); let nid2 = node_id; let updated = h @@ -902,7 +896,7 @@ async fn phase11_decay_clamps_to_floor() { .await .unwrap(); - cede::run_decay(&h.db, 60).await.unwrap(); + omni_cede::run_decay(&h.db, 60).await.unwrap(); let nid2 = node_id; let updated = h @@ -981,7 +975,7 @@ async fn phase12_negation_keyword_detected() { #[tokio::test] async fn phase12_mock_llm_adjudicates_contradiction() { - use cede::llm::MockLlmClient; + use omni_cede::llm::MockLlmClient; let h = TestHarness::new(); @@ -997,7 +991,7 @@ async fn phase12_mock_llm_adjudicates_contradiction() { input_tokens: 0, output_tokens: 0, }]); - let llm: Arc = Arc::new(mock); + let llm: Arc = Arc::new(mock); // Create two contradictory nodes let node_a = Node::new(NodeKind::Fact, "Earth distance") @@ -1047,7 +1041,7 @@ async fn phase12_mock_llm_adjudicates_contradiction() { #[tokio::test] async fn phase12_mock_llm_rejects_false_positive() { - use cede::llm::MockLlmClient; + use omni_cede::llm::MockLlmClient; // Mock LLM that says "NO" (not a contradiction despite negation keywords) let mock = MockLlmClient::new(vec![LlmResponse { @@ -1061,7 +1055,7 @@ async fn phase12_mock_llm_rejects_false_positive() { input_tokens: 0, output_tokens: 0, }]); - let llm: Arc = Arc::new(mock); + let llm: Arc = Arc::new(mock); // Two nodes with negation keywords but not actually contradictory let messages = vec![Message::user(