From 8aaa9c6f49129b2f05a4d65f882ab1fad70f5c2f Mon Sep 17 00:00:00 2001 From: mattes <1240531+mattes@users.noreply.github.com> Date: Mon, 10 Mar 2025 18:37:04 -0700 Subject: [PATCH 01/21] with session mutex --- notary/src/frame.rs | 121 +++++++++++++++++++++++++++++++++++ notary/src/frame/sessions.rs | 36 +++++++++++ notary/src/frame/states.rs | 4 ++ notary/src/main.rs | 13 +++- 4 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 notary/src/frame.rs create mode 100644 notary/src/frame/sessions.rs create mode 100644 notary/src/frame/states.rs diff --git a/notary/src/frame.rs b/notary/src/frame.rs new file mode 100644 index 000000000..02ac76ad6 --- /dev/null +++ b/notary/src/frame.rs @@ -0,0 +1,121 @@ +use std::sync::Arc; + +use axum::{ + extract::{ws::WebSocket, Query, State, WebSocketUpgrade}, + response::IntoResponse, +}; +use futures::StreamExt; +use serde::Serialize; +use tokio::sync::Mutex; +use tracing::warn; +use uuid::Uuid; + +use crate::SharedState; + +pub mod sessions; +pub use sessions::Session; + +pub mod states; + +pub async fn handler( + ws: WebSocketUpgrade, + Query(params): Query>, + State(state): State>, +) -> impl IntoResponse { + // Parse ?session_id from query + let session_id = match params.get("session_id") { + Some(id) => match Uuid::parse_str(id) { + Ok(uuid) => uuid, + Err(_) => + return (axum::http::StatusCode::BAD_REQUEST, "Invalid session_id format, expected UUID") + .into_response(), + }, + None => + return (axum::http::StatusCode::BAD_REQUEST, "Missing required session_id query parameter") + .into_response(), + }; + + // create or resume session + let session = { + let mut sessions = state.frame_sessions.lock().await; + match sessions.get(&session_id) { + Some(session) => session.clone(), + None => { + let session = Arc::new(Mutex::new(Session::new(session_id))); + sessions.insert(session_id, session.clone()); + session + }, + } + }; + + ws.on_upgrade(|socket| handle_websocket(session, socket, state)) +} + +async fn handle_websocket( + session: Arc>>, + socket: WebSocket, + _state: Arc, +) { + let (sender, mut receiver) = socket.split(); + + // allow frame session to write to websocket + session.lock().await.set_writer(Some(WebSocketWriter::new(sender))).await; + + // handle incoming websocket messages + while let Some(Ok(message)) = receiver.next().await { + match message { + axum::extract::ws::Message::Text(text) => { + let state = match serde_json::from_str::(&text) { + Ok(state) => state, + Err(e) => { + warn!("Failed to parse websocket message: {}", e); + continue; + }, + }; + session.lock().await.read(state).await; + }, + axum::extract::ws::Message::Binary(_) => { + warn!("Binary messages are not supported"); + break; + }, + axum::extract::ws::Message::Ping(_) => { + todo!("Are Pings handled by axum's tokio-tungstenite?"); + }, + axum::extract::ws::Message::Pong(_) => { + todo!("Are Pongs handled by axum's tokio-tungstenite?"); + }, + axum::extract::ws::Message::Close(_) => { + break; + }, + } + } + + session.lock().await.close().await; +} + +pub struct WebSocketWriter { + sender: futures::stream::SplitSink, +} + +impl WebSocketWriter { + pub fn new( + sender: futures::stream::SplitSink, + ) -> Self { + WebSocketWriter { sender } + } +} + +impl sessions::Writer for WebSocketWriter { + async fn write(&mut self, data: &T) -> Result<(), String> { + use futures::SinkExt; + + let json = + serde_json::to_string(data).map_err(|e| format!("Failed to serialize to JSON: {}", e))?; + + self + .sender + .send(axum::extract::ws::Message::Text(json)) + .await + .map_err(|e| format!("Failed to send message: {}", e)) + } +} diff --git a/notary/src/frame/sessions.rs b/notary/src/frame/sessions.rs new file mode 100644 index 000000000..7d2414517 --- /dev/null +++ b/notary/src/frame/sessions.rs @@ -0,0 +1,36 @@ +use serde::Serialize; +use tokio::sync::Mutex; +use uuid::Uuid; + +use super::states::State; + +pub struct Session { + session_id: Uuid, + writer: Mutex>, +} + +impl Session { + pub fn new(session_id: Uuid) -> Self { Session { session_id, writer: Mutex::new(None) } } + + pub async fn set_writer(&mut self, writer: Option) { *self.writer.lock().await = writer; } + + async fn write(&mut self, data: &T) { + // TODO return error if no writer is set + if let Some(writer) = &mut *self.writer.lock().await { + writer.write(data).await; + } + } + + pub async fn read(&mut self, state: State) { + // TODO read incoming message from websocket, here it is already parsed into a struct + } + + pub async fn close(&mut self) { + // TODO: end or keep the session alive for another 10 mins? + // in case clients wants to resume session + } +} + +pub trait Writer: Send { + async fn write(&mut self, data: &T) -> Result<(), String>; +} diff --git a/notary/src/frame/states.rs b/notary/src/frame/states.rs new file mode 100644 index 000000000..8e6badce7 --- /dev/null +++ b/notary/src/frame/states.rs @@ -0,0 +1,4 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub enum State {} diff --git a/notary/src/main.rs b/notary/src/main.rs index 1491f4a83..620e9d192 100644 --- a/notary/src/main.rs +++ b/notary/src/main.rs @@ -1,4 +1,5 @@ use std::{ + collections::HashMap, fs, io::{self}, sync::Arc, @@ -20,21 +21,24 @@ use rustls::{ ServerConfig, }; use rustls_acme::{caches::DirCache, AcmeConfig}; -use tokio::{io::AsyncWriteExt, net::TcpListener}; +use tokio::{io::AsyncWriteExt, net::TcpListener, sync::Mutex}; use tokio_rustls::{LazyConfigAcceptor, TlsAcceptor}; use tokio_stream::StreamExt; use tower_http::cors::CorsLayer; use tower_service::Service; use tracing::{error, info}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use uuid::Uuid; mod config; mod error; +mod frame; mod proxy; mod verifier; struct SharedState { notary_signing_key: SigningKey, + frame_sessions: Arc>>>>>, } /// Main entry point for the notary server application. @@ -84,12 +88,15 @@ async fn main() -> Result<(), NotaryServerError> { let listener = TcpListener::bind(&c.listen).await?; info!("Listening on https://{}", &c.listen); - let shared_state = - Arc::new(SharedState { notary_signing_key: load_notary_signing_key(&c.notary_signing_key) }); + let shared_state = Arc::new(SharedState { + notary_signing_key: load_notary_signing_key(&c.notary_signing_key), + frame_sessions: Arc::new(Mutex::new(HashMap::new())), + }); let router = Router::new() .route("/health", get(|| async move { (StatusCode::OK, "Ok").into_response() })) .route("/v1/proxy", post(proxy::proxy)) + .route("/v1/frame", post(frame::handler)) .route("/v1/meta/keys/:key", get(meta_keys)) .layer(CorsLayer::permissive()) .with_state(shared_state); From 584a61fa51b555e7278ae2c598c6335f30ab4224 Mon Sep 17 00:00:00 2001 From: mattes <1240531+mattes@users.noreply.github.com> Date: Tue, 11 Mar 2025 12:00:40 -0700 Subject: [PATCH 02/21] . --- notary/src/frame.rs | 16 +- .../src/frame/{sessions.rs => _sessions.rs} | 22 ++- notary/src/frame/_views.rs | 32 ++++ notary/src/frame/states.rs | 4 - notary/src/frame/views.rs | 154 ++++++++++++++++++ notary/src/main.rs | 4 +- 6 files changed, 216 insertions(+), 16 deletions(-) rename notary/src/frame/{sessions.rs => _sessions.rs} (61%) create mode 100644 notary/src/frame/_views.rs delete mode 100644 notary/src/frame/states.rs create mode 100644 notary/src/frame/views.rs diff --git a/notary/src/frame.rs b/notary/src/frame.rs index 02ac76ad6..e97c827b2 100644 --- a/notary/src/frame.rs +++ b/notary/src/frame.rs @@ -12,10 +12,13 @@ use uuid::Uuid; use crate::SharedState; -pub mod sessions; -pub use sessions::Session; +pub mod _sessions; +pub use _sessions::Session; -pub mod states; +pub mod _views; +pub mod actions; + +pub mod views; pub async fn handler( ws: WebSocketUpgrade, @@ -61,11 +64,14 @@ async fn handle_websocket( // allow frame session to write to websocket session.lock().await.set_writer(Some(WebSocketWriter::new(sender))).await; + // send current view to client + // TODO + // handle incoming websocket messages while let Some(Ok(message)) = receiver.next().await { match message { axum::extract::ws::Message::Text(text) => { - let state = match serde_json::from_str::(&text) { + let state = match serde_json::from_str::(&text) { Ok(state) => state, Err(e) => { warn!("Failed to parse websocket message: {}", e); @@ -105,7 +111,7 @@ impl WebSocketWriter { } } -impl sessions::Writer for WebSocketWriter { +impl _sessions::Writer for WebSocketWriter { async fn write(&mut self, data: &T) -> Result<(), String> { use futures::SinkExt; diff --git a/notary/src/frame/sessions.rs b/notary/src/frame/_sessions.rs similarity index 61% rename from notary/src/frame/sessions.rs rename to notary/src/frame/_sessions.rs index 7d2414517..594d66625 100644 --- a/notary/src/frame/sessions.rs +++ b/notary/src/frame/_sessions.rs @@ -1,16 +1,21 @@ +use std::collections::HashMap; + use serde::Serialize; use tokio::sync::Mutex; use uuid::Uuid; -use super::states::State; +use super::_views::{InitialView, View}; pub struct Session { - session_id: Uuid, - writer: Mutex>, + session_id: Uuid, + writer: Mutex>, + current_view: View, } impl Session { - pub fn new(session_id: Uuid) -> Self { Session { session_id, writer: Mutex::new(None) } } + pub fn new(session_id: Uuid) -> Self { + Session { session_id, writer: Mutex::new(None), current_view: InitialView::new() } + } pub async fn set_writer(&mut self, writer: Option) { *self.writer.lock().await = writer; } @@ -21,8 +26,8 @@ impl Session { } } - pub async fn read(&mut self, state: State) { - // TODO read incoming message from websocket, here it is already parsed into a struct + pub async fn read(&mut self, action: Action) { + // TODO dispatch to current view } pub async fn close(&mut self) { @@ -34,3 +39,8 @@ impl Session { pub trait Writer: Send { async fn write(&mut self, data: &T) -> Result<(), String>; } + +pub struct Action { + kind: String, + data: HashMap, +} diff --git a/notary/src/frame/_views.rs b/notary/src/frame/_views.rs new file mode 100644 index 000000000..1f9e8a831 --- /dev/null +++ b/notary/src/frame/_views.rs @@ -0,0 +1,32 @@ +use super::_sessions::{Action, Writer}; + +pub trait View: Send + Sync { + fn handle(&mut self, action: &Action) -> impl Response; + fn name(&self) -> String; +} + +pub struct InitialView { + foobar: String, +} + +impl View for InitialView { + fn handle(action: &Action) -> impl Response { Some(Box::new(ResultView {})) } + + fn name(&self) -> String { "initial".to_string() } + + fn serialize(&self) -> serde_json::Value { + serde_json::json!({ + "foobar": self.foobar, + }) + } +} + +pub struct ResultView {} + +impl View for ResultView { + fn handle(&mut self, writer: &mut W, action: &Action) -> Option>> { None } + + fn name(&self) -> String { "result".to_string() } + + fn serialize(&self) -> serde_json::Value { serde_json::json!({}) } +} diff --git a/notary/src/frame/states.rs b/notary/src/frame/states.rs deleted file mode 100644 index 8e6badce7..000000000 --- a/notary/src/frame/states.rs +++ /dev/null @@ -1,4 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Serialize, Deserialize)] -pub enum State {} diff --git a/notary/src/frame/views.rs b/notary/src/frame/views.rs new file mode 100644 index 000000000..ca640cf2e --- /dev/null +++ b/notary/src/frame/views.rs @@ -0,0 +1,154 @@ +use serde::Serialize; + +// Views: +// * IntialView +// * PromptView +// * action: client sends credentials back to notary (action: login) +// * PendingView +// * action: notary might send status update (30/100) done +// * action/gotoview: DoneView +// * DoneView + +pub enum Action { + GoToView(ViewKind), // can only be "sent" by Notary/ server + Message(Payload), + Close, +} + +pub enum ViewKind { + InitialViewKind(InitialView), + DoneViewKind(DoneView), +} + +pub struct Payload {} + +pub trait Handler { + fn handle(&mut self, action: &Action) -> Action; +} + +pub struct InitialView { + foo: String, +} + +impl Handler for InitialView {} + +pub struct DoneView { + bar: String, +} + +pub struct Session { + current_view: ViewKind, +} + +impl Session { + pub fn handle(&mut self, input_json: &[u8]) { + // serde deseralize into Action::Message::Payload + let response = self.current_view.handle(Action::Message::Payload); + } +} + +// pub struct Handler { +// state: HandlerStates, +// } + +// pub enum HandlerStates {} + +// ----------------------------------- +// pub struct View {} + +// pub struct Response {} + +// #[derive(Clone, Serialize)] +// pub struct Action {} + +// pub enum Response { +// GoToView(View), +// Response(Action), +// Close, +// } + +// pub struct InitialView { +// foobar: String, +// state: String, +// } + +// impl ViewT for InitialView { +// fn handle(&mut self, action: &Action) -> Action { +// match (action, self.state) { +// (Action::Response::InitialView::DoSomethingCrazy(payload), "initial") => { +// // do the crazy thing here +// self.state = "crazythingcompleted" +// return Action::Response(MyCrazyResult); +// }, + +// (Action::Response::InitialView::TheThingAfterTheCrazyCompute(payload), +// "crazythingcompleted") => { return Action::Reponse(Error()) +// } +// } + +// Action::GoToView(VIEW) +// } +// } + +// pub struct DoneView { +// foobar: String, +// } + +// impl ViewT for DoneView { +// fn handle(&mut self, action: &Action) -> Action { +// // match action { +// // // do the crazy thing here +// // }, +// // } + +// return Action::Close; +// } +// } + +// // json: +// // // { +// // "action": { +// // "type": "go_to_view", +// // "data": { +// // "foo": "bar" +// // } +// // } +// // } +// // { +// // "action": { +// // "type": "my_foobar_action", +// // "data": { +// // "foo": "bar" +// // } +// // } +// // } + +// // ------------- + +// // #[derive(Clone, Serialize)] +// // pub struct Action {} + +// // pub trait Response {} + +// // pub struct Done {} +// // pub struct GotoView {} + +// // pub enum ActionKind { +// // GoToView(View), +// // Response(Action), +// // Close, +// // } + +// // pub struct InitialView { +// // foobar: String, +// // } + +// // impl View for InitialView { +// // fn handle(&mut self, action: &Action) -> impl Response { +// // match action { +// // Action::Response(_) => Done {}, +// // Action::GoToView(view) => GotoView {}, +// // Action::Close => Done {}, +// // } +// // } +// // } diff --git a/notary/src/main.rs b/notary/src/main.rs index 620e9d192..ed03887c2 100644 --- a/notary/src/main.rs +++ b/notary/src/main.rs @@ -38,7 +38,9 @@ mod verifier; struct SharedState { notary_signing_key: SigningKey, - frame_sessions: Arc>>>>>, + + // TODO do we really need to wrap frame::Session in an Arc>? + frame_sessions: Arc>>>>>, } /// Main entry point for the notary server application. From 3e90a75c5c1cd24142a8b10650a6c92f57b72f92 Mon Sep 17 00:00:00 2001 From: mattes <1240531+mattes@users.noreply.github.com> Date: Tue, 11 Mar 2025 18:29:30 -0700 Subject: [PATCH 03/21] Connection State --- notary/src/_frame.rs | 206 ++++++++++++++++++ notary/src/frame.rs | 126 ++++------- notary/src/frame/{views.rs => __views.rs} | 59 ++++- .../src/frame/{_sessions.rs => sessions.rs} | 9 +- notary/src/main.rs | 5 +- 5 files changed, 301 insertions(+), 104 deletions(-) create mode 100644 notary/src/_frame.rs rename notary/src/frame/{views.rs => __views.rs} (69%) rename notary/src/frame/{_sessions.rs => sessions.rs} (78%) diff --git a/notary/src/_frame.rs b/notary/src/_frame.rs new file mode 100644 index 000000000..00d8cab4f --- /dev/null +++ b/notary/src/_frame.rs @@ -0,0 +1,206 @@ +use std::sync::Arc; + +use axum::{ + extract::{ws::WebSocket, Query, State, WebSocketUpgrade}, + response::IntoResponse, +}; +use futures::StreamExt; +use serde::Serialize; +use tokio::sync::Mutex; +use tracing::warn; +use uuid::Uuid; + +use crate::SharedState; + +pub mod sessions; +pub use sessions::Session; + +// pub mod _views; +// pub mod actions; + +// pub mod __views; + +// pub enum SessionState { +// Connected, +// Disconnected(Session), +// } + + +// impl SessionState { +// pub fn connect(self) -> Session { +// match self { +// SessionState::Connected => panic!(), +// SessionState::Disconnected(sess) => sess, +// } +// } +// } +// HashMap + + +pub struct MyStruct { + _d: PhantomData, + data: S::Data, +} + +// MyStruct { +// data: Session +// } + +pub trait State { + type Data; +} + +pub struct Start; +impl State for Start { + type Data = Session; +} +pub struct End; +impl State for End { + type Data = (); +} + +impl MyStruct { + pub fn have_fun(&mut self) { + todo!("does something") + } +} + +impl MyStruct { + pub fn end(self) -> MyStruct { + self.data.set_writer(writer) + MyStruct + } +} + + + + + +pub async fn handler( + ws: WebSocketUpgrade, + Query(params): Query>, + State(state): State>, +) -> impl IntoResponse { + // Parse ?session_id from query + let session_id = match params.get("session_id") { + Some(id) => match Uuid::parse_str(id) { + Ok(uuid) => uuid, + Err(_) => + return (axum::http::StatusCode::BAD_REQUEST, "Invalid session_id format, expected UUID") + .into_response(), + }, + None => + return (axum::http::StatusCode::BAD_REQUEST, "Missing required session_id query parameter") + .into_response(), + }; + + + let frame_sessions = state.frame_sessions.lock().await; + + frame_sessions.contains_key(&session_idion_id) { + + + // if yes: + match frame_sessions.get(&session_id) { + // None: websocket is already connected + // Some(session), no websocket currently connected + } + + // if no: + // session does not exist, create a new one + } + + + // when connected, set uuid to None + + drop(frame_sessions); + + + // create or resume session + // let session = { + // let mut sessions = state.frame_sessions.lock().await; + // match sessions.get(&session_id) { + // Some(session) => session.clone(), + // None => { + // let session = Session::new(session_id); + // sessions.insert(session_id, session); + // &session + // }, + // } + // }; + + ws.on_upgrade(move |socket| handle_websocket(session, socket, state)) +} + +async fn handle_websocket( + mut session: Session, + socket: WebSocket, + _state: Arc, +) { + let (sender, mut receiver) = socket.split(); + + // allow frame session to write to websocket + session.set_writer(Some(WebSocketWriter::new(sender))).await; + + // send current view to client + // TODO + + // handle incoming websocket messages + while let Some(Ok(message)) = receiver.next().await { + match message { + axum::extract::ws::Message::Text(text) => { + // let state = match serde_json::from_str::(&text) { + // Ok(state) => state, + // Err(e) => { + // warn!("Failed to parse websocket message: {}", e); + // continue; + // }, + // }; + // session.read(state).await; + todo!("fsdf") + }, + axum::extract::ws::Message::Binary(_) => { + warn!("Binary messages are not supported"); + break; + }, + axum::extract::ws::Message::Ping(_) => { + todo!("Are Pings handled by axum's tokio-tungstenite?"); + }, + axum::extract::ws::Message::Pong(_) => { + todo!("Are Pongs handled by axum's tokio-tungstenite?"); + }, + axum::extract::ws::Message::Close(_) => { + break; + }, + } + } + + session.close().await; +} + +pub struct WebSocketWriter { + sender: futures::stream::SplitSink, +} + +impl WebSocketWriter { + pub fn new( + sender: futures::stream::SplitSink, + ) -> Self { + WebSocketWriter { sender } + } +} + +impl sessions::Writer for WebSocketWriter { + async fn write(&mut self, data: &T) -> Result<(), String> { + use futures::SinkExt; + + let json = + serde_json::to_string(data).map_err(|e| format!("Failed to serialize to JSON: {}", e))?; + + self + .sender + .send(axum::extract::ws::Message::Text(json)) + .await + .map_err(|e| format!("Failed to send message: {}", e)) + } +} diff --git a/notary/src/frame.rs b/notary/src/frame.rs index e97c827b2..957249925 100644 --- a/notary/src/frame.rs +++ b/notary/src/frame.rs @@ -1,24 +1,27 @@ -use std::sync::Arc; +use std::{sync::Arc, time::SystemTime}; use axum::{ extract::{ws::WebSocket, Query, State, WebSocketUpgrade}, response::IntoResponse, }; -use futures::StreamExt; -use serde::Serialize; -use tokio::sync::Mutex; -use tracing::warn; +use tracing::info; use uuid::Uuid; use crate::SharedState; -pub mod _sessions; -pub use _sessions::Session; +pub enum ConnectionState { + Connected, + Disconnected(Session, SystemTime), /* TODO run a task that cleans up disconnected sessions + * every 60 secs */ +} -pub mod _views; -pub mod actions; +pub struct Session { + session_id: Uuid, +} -pub mod views; +impl Session { + pub fn new(session_id: Uuid) -> Self { Session { session_id } } +} pub async fn handler( ws: WebSocketUpgrade, @@ -38,90 +41,39 @@ pub async fn handler( .into_response(), }; - // create or resume session - let session = { - let mut sessions = state.frame_sessions.lock().await; - match sessions.get(&session_id) { - Some(session) => session.clone(), - None => { - let session = Arc::new(Mutex::new(Session::new(session_id))); - sessions.insert(session_id, session.clone()); - session - }, - } - }; - - ws.on_upgrade(|socket| handle_websocket(session, socket, state)) -} - -async fn handle_websocket( - session: Arc>>, - socket: WebSocket, - _state: Arc, -) { - let (sender, mut receiver) = socket.split(); + let mut frame_sessions = state.frame_sessions.lock().await; - // allow frame session to write to websocket - session.lock().await.set_writer(Some(WebSocketWriter::new(sender))).await; - - // send current view to client - // TODO + let session = match frame_sessions.remove(&session_id) { + Some(ConnectionState::Connected) => { + frame_sessions.insert(session_id, ConnectionState::Connected); + return (axum::http::StatusCode::BAD_REQUEST, "Session already connected").into_response(); + }, - // handle incoming websocket messages - while let Some(Ok(message)) = receiver.next().await { - match message { - axum::extract::ws::Message::Text(text) => { - let state = match serde_json::from_str::(&text) { - Ok(state) => state, - Err(e) => { - warn!("Failed to parse websocket message: {}", e); - continue; - }, - }; - session.lock().await.read(state).await; - }, - axum::extract::ws::Message::Binary(_) => { - warn!("Binary messages are not supported"); - break; - }, - axum::extract::ws::Message::Ping(_) => { - todo!("Are Pings handled by axum's tokio-tungstenite?"); - }, - axum::extract::ws::Message::Pong(_) => { - todo!("Are Pongs handled by axum's tokio-tungstenite?"); - }, - axum::extract::ws::Message::Close(_) => { - break; - }, - } - } + Some(ConnectionState::Disconnected(session, _)) => { + frame_sessions.insert(session_id, ConnectionState::Connected); + session + }, - session.lock().await.close().await; -} + None => { + let session = Session::new(session_id); + frame_sessions.insert(session_id, ConnectionState::Connected); + session + }, + }; -pub struct WebSocketWriter { - sender: futures::stream::SplitSink, -} + drop(frame_sessions); // drop mutex guard -impl WebSocketWriter { - pub fn new( - sender: futures::stream::SplitSink, - ) -> Self { - WebSocketWriter { sender } - } + ws.on_upgrade(move |socket| handle_websocket_connection(state, socket, session)) } -impl _sessions::Writer for WebSocketWriter { - async fn write(&mut self, data: &T) -> Result<(), String> { - use futures::SinkExt; +async fn handle_websocket_connection(state: Arc, socket: WebSocket, session: Session) { + info!("[{}] New Websocket connected", session.session_id); - let json = - serde_json::to_string(data).map_err(|e| format!("Failed to serialize to JSON: {}", e))?; + // TODO: Handle Websocket messages - self - .sender - .send(axum::extract::ws::Message::Text(json)) - .await - .map_err(|e| format!("Failed to send message: {}", e)) - } + // If the Websocket connection drops, mark it as disconnected, unless it was correctly closed. + info!("[{}] Websocket disconnected", session.session_id); + let mut frame_sessions = state.frame_sessions.lock().await; + frame_sessions + .insert(session.session_id, ConnectionState::Disconnected(session, SystemTime::now())); } diff --git a/notary/src/frame/views.rs b/notary/src/frame/__views.rs similarity index 69% rename from notary/src/frame/views.rs rename to notary/src/frame/__views.rs index ca640cf2e..6e9b840dc 100644 --- a/notary/src/frame/views.rs +++ b/notary/src/frame/__views.rs @@ -10,43 +10,78 @@ use serde::Serialize; // * DoneView pub enum Action { - GoToView(ViewKind), // can only be "sent" by Notary/ server + View(View), // can only be "sent" by Notary/ server Message(Payload), Close, } -pub enum ViewKind { - InitialViewKind(InitialView), - DoneViewKind(DoneView), +pub enum View { + InitialView(InitialView), + DoneView(DoneView), } pub struct Payload {} -pub trait Handler { - fn handle(&mut self, action: &Action) -> Action; -} + pub struct InitialView { foo: String, } -impl Handler for InitialView {} +impl InitialView { + pub fn into_done_view(&self) -> DoneView { + + } +} pub struct DoneView { bar: String, } -pub struct Session { - current_view: ViewKind, +pub struct Session { + current_view: V, +} + +pub trait ViewT: Serialize { + fn handle(&mut self, action: &Action) -> View; +} + +impl ViewT for InitialView { + fn handle(&mut self, action: &Action) -> View { + match action { + Action::View(_) => .., + + } + } } + impl Session { - pub fn handle(&mut self, input_json: &[u8]) { + pub fn handle(&mut self, input_json: &[u8]) -> Vec { // serde deseralize into Action::Message::Payload - let response = self.current_view.handle(Action::Message::Payload); + let action: Action = serde_json::from_slice(input_json); + + let next_view = self.current_view.handle(action); + + return serde_json::to_vec(next_view); + + // match (self.current_view, action) { + // (View::InitialView(initial_view), Action::View(View::DoneView(_))) => {serde_json::to_vec(initial_view.into_done_view());}, + // _ => Err("Invalid action given:\nAction:{:?}\nState:{:?}"); + // } + + let response = self.current_view.handle(action); + // TODO serialize response and then send to websocket connection + // if response == Close -> close connection } } + + +// impl Handler for InitialView {} + + + // pub struct Handler { // state: HandlerStates, // } diff --git a/notary/src/frame/_sessions.rs b/notary/src/frame/sessions.rs similarity index 78% rename from notary/src/frame/_sessions.rs rename to notary/src/frame/sessions.rs index 594d66625..fe28dbe40 100644 --- a/notary/src/frame/_sessions.rs +++ b/notary/src/frame/sessions.rs @@ -4,19 +4,22 @@ use serde::Serialize; use tokio::sync::Mutex; use uuid::Uuid; -use super::_views::{InitialView, View}; +// use super::_views::{InitialView, View}; pub struct Session { session_id: Uuid, writer: Mutex>, - current_view: View, + // current_view: View, } impl Session { pub fn new(session_id: Uuid) -> Self { - Session { session_id, writer: Mutex::new(None), current_view: InitialView::new() } + // Session { session_id, writer: Mutex::new(None), current_view: InitialView::new() } + Session { session_id, writer: Mutex::new(None) } } + // pub handle() func which passes through to current_view.handle() + pub async fn set_writer(&mut self, writer: Option) { *self.writer.lock().await = writer; } async fn write(&mut self, data: &T) { diff --git a/notary/src/main.rs b/notary/src/main.rs index ed03887c2..ea837abd1 100644 --- a/notary/src/main.rs +++ b/notary/src/main.rs @@ -39,8 +39,9 @@ mod verifier; struct SharedState { notary_signing_key: SigningKey, - // TODO do we really need to wrap frame::Session in an Arc>? - frame_sessions: Arc>>>>>, + // Can be None if a websocket is currently connected and dealing with the session + // frame_sessions: Arc>>>>, + frame_sessions: Arc>>>, } /// Main entry point for the notary server application. From 9d31263f144af993e949ac54c8692b94d18fb22c Mon Sep 17 00:00:00 2001 From: mattes <1240531+mattes@users.noreply.github.com> Date: Tue, 11 Mar 2025 19:13:18 -0700 Subject: [PATCH 04/21] add to closing logic --- notary/src/frame.rs | 59 +++++++++++++++++++++++++++++++++++++++------ notary/src/main.rs | 2 -- 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/notary/src/frame.rs b/notary/src/frame.rs index 957249925..10965778d 100644 --- a/notary/src/frame.rs +++ b/notary/src/frame.rs @@ -4,7 +4,8 @@ use axum::{ extract::{ws::WebSocket, Query, State, WebSocketUpgrade}, response::IntoResponse, }; -use tracing::info; +use futures::StreamExt; +use tracing::{info, warn}; use uuid::Uuid; use crate::SharedState; @@ -21,6 +22,9 @@ pub struct Session { impl Session { pub fn new(session_id: Uuid) -> Self { Session { session_id } } + + /// Called when the client closes the connection. + pub async fn on_client_close(&mut self) {} } pub async fn handler( @@ -66,14 +70,55 @@ pub async fn handler( ws.on_upgrade(move |socket| handle_websocket_connection(state, socket, session)) } -async fn handle_websocket_connection(state: Arc, socket: WebSocket, session: Session) { +async fn handle_websocket_connection( + state: Arc, + socket: WebSocket, + mut session: Session, +) { info!("[{}] New Websocket connected", session.session_id); + let mut disconnected = false; + let (sender, mut receiver) = socket.split(); - // TODO: Handle Websocket messages + // TODO what if next() returns None?! + while let Some(result) = receiver.next().await { + match result { + Ok(message) => { + match message { + axum::extract::ws::Message::Text(text) => { + // TODO + }, + axum::extract::ws::Message::Binary(_) => { + warn!("Binary messages are not supported"); + disconnected = true; + break; + }, + axum::extract::ws::Message::Ping(_) => { + todo!("Are Pings handled by axum's tokio-tungstenite?"); + }, + axum::extract::ws::Message::Pong(_) => { + todo!("Are Pongs handled by axum's tokio-tungstenite?"); + }, + axum::extract::ws::Message::Close(_) => { + session.on_client_close().await; + disconnected = true; + break; + }, + } + }, + Err(_err) => { + disconnected = false; + break; + }, + } + } - // If the Websocket connection drops, mark it as disconnected, unless it was correctly closed. - info!("[{}] Websocket disconnected", session.session_id); let mut frame_sessions = state.frame_sessions.lock().await; - frame_sessions - .insert(session.session_id, ConnectionState::Disconnected(session, SystemTime::now())); + if !disconnected { + // If the Websocket connection drops, mark it as disconnected, unless it was correctly closed. + info!("[{}] Websocket disconnected", session.session_id); + frame_sessions + .insert(session.session_id, ConnectionState::Disconnected(session, SystemTime::now())); + } else { + frame_sessions.remove(&session.session_id); + } } diff --git a/notary/src/main.rs b/notary/src/main.rs index ea837abd1..92d79577a 100644 --- a/notary/src/main.rs +++ b/notary/src/main.rs @@ -39,8 +39,6 @@ mod verifier; struct SharedState { notary_signing_key: SigningKey, - // Can be None if a websocket is currently connected and dealing with the session - // frame_sessions: Arc>>>>, frame_sessions: Arc>>>, } From 24cbe020c60c9aeaefbb252cb50c3928305e5602 Mon Sep 17 00:00:00 2001 From: mattes <1240531+mattes@users.noreply.github.com> Date: Tue, 11 Mar 2025 19:24:21 -0700 Subject: [PATCH 05/21] add callbacks --- notary/src/frame.rs | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/notary/src/frame.rs b/notary/src/frame.rs index 10965778d..7d749c8bf 100644 --- a/notary/src/frame.rs +++ b/notary/src/frame.rs @@ -23,7 +23,13 @@ pub struct Session { impl Session { pub fn new(session_id: Uuid) -> Self { Session { session_id } } - /// Called when the client closes the connection. + /// Called when the client connects. Can be called multiple times. + pub async fn on_client_connect(&mut self) {} + + /// Called when the client disconnects unexpectedly. Can be called multiple times. + pub async fn on_client_disconnect(&mut self) {} + + /// Called when the client closes the connection. Called only once. pub async fn on_client_close(&mut self) {} } @@ -37,11 +43,11 @@ pub async fn handler( Some(id) => match Uuid::parse_str(id) { Ok(uuid) => uuid, Err(_) => - return (axum::http::StatusCode::BAD_REQUEST, "Invalid session_id format, expected UUID") + return (axum::http::StatusCode::BAD_REQUEST, "Invalid session_id format, expected UUID") // TODO return json error .into_response(), }, None => - return (axum::http::StatusCode::BAD_REQUEST, "Missing required session_id query parameter") + return (axum::http::StatusCode::BAD_REQUEST, "Missing required session_id query parameter") // TODO return json error .into_response(), }; @@ -50,7 +56,7 @@ pub async fn handler( let session = match frame_sessions.remove(&session_id) { Some(ConnectionState::Connected) => { frame_sessions.insert(session_id, ConnectionState::Connected); - return (axum::http::StatusCode::BAD_REQUEST, "Session already connected").into_response(); + return (axum::http::StatusCode::BAD_REQUEST, "Session already connected").into_response(); // TODO return json error }, Some(ConnectionState::Disconnected(session, _)) => { @@ -76,8 +82,9 @@ async fn handle_websocket_connection( mut session: Session, ) { info!("[{}] New Websocket connected", session.session_id); - let mut disconnected = false; + let mut keepalive = false; let (sender, mut receiver) = socket.split(); + session.on_client_connect().await; // TODO pass sender? // TODO what if next() returns None?! while let Some(result) = receiver.next().await { @@ -85,11 +92,11 @@ async fn handle_websocket_connection( Ok(message) => { match message { axum::extract::ws::Message::Text(text) => { - // TODO + // TODO parse json text and call session handle func }, axum::extract::ws::Message::Binary(_) => { warn!("Binary messages are not supported"); - disconnected = true; + keepalive = false; break; }, axum::extract::ws::Message::Ping(_) => { @@ -99,26 +106,27 @@ async fn handle_websocket_connection( todo!("Are Pongs handled by axum's tokio-tungstenite?"); }, axum::extract::ws::Message::Close(_) => { - session.on_client_close().await; - disconnected = true; + keepalive = false; break; }, } }, Err(_err) => { - disconnected = false; + keepalive = true; break; }, } } let mut frame_sessions = state.frame_sessions.lock().await; - if !disconnected { + if keepalive { // If the Websocket connection drops, mark it as disconnected, unless it was correctly closed. info!("[{}] Websocket disconnected", session.session_id); + session.on_client_disconnect().await; frame_sessions .insert(session.session_id, ConnectionState::Disconnected(session, SystemTime::now())); } else { + session.on_client_close().await; frame_sessions.remove(&session.session_id); } } From 928819b6c5c998c3ff8cd98e1916b16f9a4ec0b2 Mon Sep 17 00:00:00 2001 From: mattes <1240531+mattes@users.noreply.github.com> Date: Tue, 11 Mar 2025 22:05:01 -0700 Subject: [PATCH 06/21] . --- notary/src/frame.rs | 21 ++++++-- notary/src/frame/views.rs | 67 ++++++++++++++++++++++++++ notary/src/frame/views/done_view.rs | 1 + notary/src/frame/views/initial_view.rs | 1 + notary/src/frame/views/pending_view.rs | 1 + notary/src/frame/views/prompt_view.rs | 30 ++++++++++++ notary/src/main.rs | 2 +- 7 files changed, 117 insertions(+), 6 deletions(-) create mode 100644 notary/src/frame/views.rs create mode 100644 notary/src/frame/views/done_view.rs create mode 100644 notary/src/frame/views/initial_view.rs create mode 100644 notary/src/frame/views/pending_view.rs create mode 100644 notary/src/frame/views/prompt_view.rs diff --git a/notary/src/frame.rs b/notary/src/frame.rs index 7d749c8bf..94a5703b8 100644 --- a/notary/src/frame.rs +++ b/notary/src/frame.rs @@ -7,9 +7,12 @@ use axum::{ use futures::StreamExt; use tracing::{info, warn}; use uuid::Uuid; +use views::View; use crate::SharedState; +pub mod views; + pub enum ConnectionState { Connected, Disconnected(Session, SystemTime), /* TODO run a task that cleans up disconnected sessions @@ -17,14 +20,22 @@ pub enum ConnectionState { } pub struct Session { - session_id: Uuid, + session_id: Uuid, + // client: Option, + current_view: View, } impl Session { - pub fn new(session_id: Uuid) -> Self { Session { session_id } } + pub fn new(session_id: Uuid) -> Self { + Session { session_id, current_view: View::InitialView(views::InitialView {}) } + } + + // pub async fn handle(&mut self, request: Request) -> Response; /// Called when the client connects. Can be called multiple times. - pub async fn on_client_connect(&mut self) {} + pub async fn on_client_connect(&mut self) { + // TODO send current_view serialized + } /// Called when the client disconnects unexpectedly. Can be called multiple times. pub async fn on_client_disconnect(&mut self) {} @@ -33,7 +44,7 @@ impl Session { pub async fn on_client_close(&mut self) {} } -pub async fn handler( +pub async fn on_websocket( ws: WebSocketUpgrade, Query(params): Query>, State(state): State>, @@ -92,7 +103,7 @@ async fn handle_websocket_connection( Ok(message) => { match message { axum::extract::ws::Message::Text(text) => { - // TODO parse json text and call session handle func + // TODO parse json text and call session handle func, then call send with it }, axum::extract::ws::Message::Binary(_) => { warn!("Binary messages are not supported"); diff --git a/notary/src/frame/views.rs b/notary/src/frame/views.rs new file mode 100644 index 000000000..7d79a95b5 --- /dev/null +++ b/notary/src/frame/views.rs @@ -0,0 +1,67 @@ +// macro_rules! define_views { +// ($(($variant:ident, $module:ident)),*) => { +// // Define the modules +// $( +// pub mod $module; +// pub use $module::$variant; +// )* + +// pub enum View { +// $($variant($variant)),* +// } + +// impl View { +// pub fn handle(&mut self) { +// // match self { +// // $(View::$variant(view) => view.handle()),* +// // } +// } +// } +// } +// } + +// define_views!( +// (InitialView, initial_view), +// (PendingView, pending_view), +// (PromptView, prompt_view), +// (DoneView, done_view) +// ); + +pub mod initial_view; +pub use initial_view::InitialView; + +// pub mod pending_view; +// pub use pending_view::PendingView; + +pub mod prompt_view; +pub use prompt_view::PromptView; + +// pub mod done_view; +// pub use done_view::DoneView; + +pub enum View { + InitialView(InitialView), + // PendingView(PendingView), + // PromptView(PromptView), + // DoneView(DoneView), +} + +// impl View { +// pub fn handle(&mut self) { +// match self { +// View::InitialView(view) => view.handle(), +// View::PendingView(view) => view.handle(), +// View::PromptView(view) => view.handle(), +// View::DoneView(view) => view.handle(), +// } +// } + +// pub fn serialize(self) { +// match self { +// View::InitialView(view) => view.serialize(), +// View::PendingView(view) => view.serialize(), +// View::PromptView(view) => view.serialize(), +// View::DoneView(view) => view.serialize(), +// } +// } +// } diff --git a/notary/src/frame/views/done_view.rs b/notary/src/frame/views/done_view.rs new file mode 100644 index 000000000..484a93b6a --- /dev/null +++ b/notary/src/frame/views/done_view.rs @@ -0,0 +1 @@ +pub struct DoneView {} diff --git a/notary/src/frame/views/initial_view.rs b/notary/src/frame/views/initial_view.rs new file mode 100644 index 000000000..efd0ebe22 --- /dev/null +++ b/notary/src/frame/views/initial_view.rs @@ -0,0 +1 @@ +pub struct InitialView {} diff --git a/notary/src/frame/views/pending_view.rs b/notary/src/frame/views/pending_view.rs new file mode 100644 index 000000000..891cc989b --- /dev/null +++ b/notary/src/frame/views/pending_view.rs @@ -0,0 +1 @@ +pub struct PendingView {} diff --git a/notary/src/frame/views/prompt_view.rs b/notary/src/frame/views/prompt_view.rs new file mode 100644 index 000000000..0e39f66f8 --- /dev/null +++ b/notary/src/frame/views/prompt_view.rs @@ -0,0 +1,30 @@ +pub struct PromptView { + state: State, +} + +impl PromptView { + pub fn new() -> Self { PromptView { state: State::Initial } } + + pub fn handle(self, action: Action) -> Action { + match (self.state, action) { + (State::Initial, Action::PromptsReply(prompts_request)) => todo!(), + + _ => todo!(), // TODO return error + } + } +} + +pub enum State { + Initial, +} + +pub enum Action { + PromptsRequest(actions::PromptsRequest), + PromptsReply(actions::PromptsReply), +} + +pub mod actions { + pub struct PromptsRequest {} + + pub struct PromptsReply {} +} diff --git a/notary/src/main.rs b/notary/src/main.rs index 92d79577a..a343e3d21 100644 --- a/notary/src/main.rs +++ b/notary/src/main.rs @@ -97,7 +97,7 @@ async fn main() -> Result<(), NotaryServerError> { let router = Router::new() .route("/health", get(|| async move { (StatusCode::OK, "Ok").into_response() })) .route("/v1/proxy", post(proxy::proxy)) - .route("/v1/frame", post(frame::handler)) + .route("/v1/frame", post(frame::on_websocket)) .route("/v1/meta/keys/:key", get(meta_keys)) .layer(CorsLayer::permissive()) .with_state(shared_state); From a485aed784ab57848d800434c2f3623d4bf448d7 Mon Sep 17 00:00:00 2001 From: mattes <1240531+mattes@users.noreply.github.com> Date: Tue, 11 Mar 2025 22:20:19 -0700 Subject: [PATCH 07/21] fsd --- notary/src/frame/{sessions.rs => _sessions.rs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename notary/src/frame/{sessions.rs => _sessions.rs} (100%) diff --git a/notary/src/frame/sessions.rs b/notary/src/frame/_sessions.rs similarity index 100% rename from notary/src/frame/sessions.rs rename to notary/src/frame/_sessions.rs From 874392bdcf076d5a896c2929555a46977d7d7453 Mon Sep 17 00:00:00 2001 From: mattes <1240531+mattes@users.noreply.github.com> Date: Tue, 11 Mar 2025 23:15:05 -0700 Subject: [PATCH 08/21] pseudo code --- notary/src/frame.rs | 53 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/notary/src/frame.rs b/notary/src/frame.rs index 94a5703b8..b8e2bb4e1 100644 --- a/notary/src/frame.rs +++ b/notary/src/frame.rs @@ -5,13 +5,21 @@ use axum::{ response::IntoResponse, }; use futures::StreamExt; +use serde::{Deserialize, Serialize}; +use tokio::sync::oneshot; +use thiserror::Error; use tracing::{info, warn}; use uuid::Uuid; -use views::View; +// use views::View; use crate::SharedState; -pub mod views; +// pub mod views; + + +#[derive(Debug, Error)] +pub enum FrameError { +} pub enum ConnectionState { Connected, @@ -19,18 +27,42 @@ pub enum ConnectionState { * every 60 secs */ } +#[derive(Debug, Serialize, Deserialize)] +pub struct Action { + pub kind: String, + pub payload: serde_json::Value, +} + +#[derive(Debug, Serialize)] +pub enum View { + InitialView, +} + pub struct Session { session_id: Uuid, - // client: Option, + // sender: Option>, current_view: View, + cancel: oneshot::Sender<()>, } impl Session { pub fn new(session_id: Uuid) -> Self { - Session { session_id, current_view: View::InitialView(views::InitialView {}) } + let (cancel_sender, cancel_receiver) = oneshot::channel(); + let session = Session { session_id, current_view: View::InitialView, cancel: cancel_sender }; + tokio::spawn(session.run(cancel_receiver)); + session } - // pub async fn handle(&mut self, request: Request) -> Response; + async fn run(&self, cancel: oneshot::Receiver<()>) { + // TODO start running playwright script etc + + // TODO kill the session if cancelled + let _ = cancel.await; + } + + pub async fn handle(&mut self, request: Action) -> Action { + todo!("") + }; /// Called when the client connects. Can be called multiple times. pub async fn on_client_connect(&mut self) { @@ -41,7 +73,7 @@ impl Session { pub async fn on_client_disconnect(&mut self) {} /// Called when the client closes the connection. Called only once. - pub async fn on_client_close(&mut self) {} + pub async fn on_client_close(&self) { let _ = self.cancel.send(()); } } pub async fn on_websocket( @@ -103,7 +135,7 @@ async fn handle_websocket_connection( Ok(message) => { match message { axum::extract::ws::Message::Text(text) => { - // TODO parse json text and call session handle func, then call send with it + process_text_message(text, &mut session, sender).await; }, axum::extract::ws::Message::Binary(_) => { warn!("Binary messages are not supported"); @@ -141,3 +173,10 @@ async fn handle_websocket_connection( frame_sessions.remove(&session.session_id); } } + +async fn process_text_message(text: String, session: Session, sender: SplitSink) { + // TODO parse text into Action + // TODO call session.handle(action) + // TODO send error result to client + // TODO send action result to client +} From 7c5dd2c211edbddfb4f420ef77a821377659e42e Mon Sep 17 00:00:00 2001 From: mattes <1240531+mattes@users.noreply.github.com> Date: Tue, 11 Mar 2025 23:15:55 -0700 Subject: [PATCH 09/21] remove frame views --- notary/src/{frame => __frame}/__views.rs | 0 notary/src/{frame => __frame}/_sessions.rs | 0 notary/src/{frame => __frame}/_views.rs | 0 notary/src/{frame => __frame}/views.rs | 0 notary/src/{frame => __frame}/views/done_view.rs | 0 notary/src/{frame => __frame}/views/initial_view.rs | 0 notary/src/{frame => __frame}/views/pending_view.rs | 0 notary/src/{frame => __frame}/views/prompt_view.rs | 0 8 files changed, 0 insertions(+), 0 deletions(-) rename notary/src/{frame => __frame}/__views.rs (100%) rename notary/src/{frame => __frame}/_sessions.rs (100%) rename notary/src/{frame => __frame}/_views.rs (100%) rename notary/src/{frame => __frame}/views.rs (100%) rename notary/src/{frame => __frame}/views/done_view.rs (100%) rename notary/src/{frame => __frame}/views/initial_view.rs (100%) rename notary/src/{frame => __frame}/views/pending_view.rs (100%) rename notary/src/{frame => __frame}/views/prompt_view.rs (100%) diff --git a/notary/src/frame/__views.rs b/notary/src/__frame/__views.rs similarity index 100% rename from notary/src/frame/__views.rs rename to notary/src/__frame/__views.rs diff --git a/notary/src/frame/_sessions.rs b/notary/src/__frame/_sessions.rs similarity index 100% rename from notary/src/frame/_sessions.rs rename to notary/src/__frame/_sessions.rs diff --git a/notary/src/frame/_views.rs b/notary/src/__frame/_views.rs similarity index 100% rename from notary/src/frame/_views.rs rename to notary/src/__frame/_views.rs diff --git a/notary/src/frame/views.rs b/notary/src/__frame/views.rs similarity index 100% rename from notary/src/frame/views.rs rename to notary/src/__frame/views.rs diff --git a/notary/src/frame/views/done_view.rs b/notary/src/__frame/views/done_view.rs similarity index 100% rename from notary/src/frame/views/done_view.rs rename to notary/src/__frame/views/done_view.rs diff --git a/notary/src/frame/views/initial_view.rs b/notary/src/__frame/views/initial_view.rs similarity index 100% rename from notary/src/frame/views/initial_view.rs rename to notary/src/__frame/views/initial_view.rs diff --git a/notary/src/frame/views/pending_view.rs b/notary/src/__frame/views/pending_view.rs similarity index 100% rename from notary/src/frame/views/pending_view.rs rename to notary/src/__frame/views/pending_view.rs diff --git a/notary/src/frame/views/prompt_view.rs b/notary/src/__frame/views/prompt_view.rs similarity index 100% rename from notary/src/frame/views/prompt_view.rs rename to notary/src/__frame/views/prompt_view.rs From a19535911bcdd899b938a79340c6b42f59254706 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Tue, 11 Mar 2025 15:00:37 +0530 Subject: [PATCH 10/21] init playwright --- Cargo.lock | 29 +++++++++++++-- Cargo.toml | 4 +- executor/Cargo.toml | 8 ++++ executor/src/main.rs | 89 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 126 insertions(+), 4 deletions(-) create mode 100644 executor/Cargo.toml create mode 100644 executor/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 8d2563ad2..46d6e1081 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1909,6 +1909,12 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db9c683daf087dc577b7506e9695b3d556a9f3849903fa28186283afd6809e9" +[[package]] +name = "linux-raw-sys" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db9c683daf087dc577b7506e9695b3d556a9f3849903fa28186283afd6809e9" + [[package]] name = "litemap" version = "0.7.5" @@ -2791,9 +2797,9 @@ dependencies = [ [[package]] name = "rustix" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dade4812df5c384711475be5fcd8c162555352945401aed22a35bffeab61f657" +checksum = "f7178faa4b75a30e269c71e61c353ce2748cf3d76f0c44c393f4e60abf49b825" dependencies = [ "bitflags", "errno", @@ -3281,7 +3287,7 @@ dependencies = [ "fastrand", "getrandom 0.3.1", "once_cell", - "rustix 1.0.1", + "rustix 1.0.2", "windows-sys 0.59.0", ] @@ -3817,6 +3823,15 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "wasm-bindgen" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -3934,6 +3949,14 @@ dependencies = [ "url", ] +[[package]] +name = "web-prover-executor" +version = "0.1.0" +dependencies = [ + "tempfile", + "uuid", +] + [[package]] name = "web-prover-notary" version = "0.7.0" diff --git a/Cargo.toml b/Cargo.toml index f0ab4aabe..ac2fb4312 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members =["client", "notary", "core", "tests"] +members =["client", "notary", "core", "tests", "executor"] resolver="2" [workspace.dependencies] @@ -38,6 +38,8 @@ uuid ={ version="1.10.0", default-features=false, features=["v4", "serde"] tracing-test="0.2" +tempfile="3.18.0" + [profile.dev] incremental =true opt-level =1 diff --git a/executor/Cargo.toml b/executor/Cargo.toml new file mode 100644 index 000000000..294c45bc0 --- /dev/null +++ b/executor/Cargo.toml @@ -0,0 +1,8 @@ +[package] +edition="2021" +name ="web-prover-executor" +version="0.1.0" + +[dependencies] +tempfile={ workspace=true } +uuid ={ workspace=true } diff --git a/executor/src/main.rs b/executor/src/main.rs new file mode 100644 index 000000000..d89c29151 --- /dev/null +++ b/executor/src/main.rs @@ -0,0 +1,89 @@ +use std::{io::Write, process::Stdio, time::Duration}; + +use tempfile::NamedTempFile; +use uuid::Uuid; + +/// The Playwright template with a placeholder for the script +const PLAYWRIGHT_TEMPLATE: &str = r#" +const { chromium } = require('playwright-core'); +const { prompt, prove, setSessionUUID } = require("@plutoxyz/playwright-utils"); + +(async () => { + const sessionUUID = process.argv[2]; + setSessionUUID(sessionUUID); + console.log("Starting Playwright session with UUID:", sessionUUID); + + const browser = await chromium.launch({ + headless: true, + executablePath: '/Users/darkrai/Library/Caches/ms-playwright/chromium_headless_shell-1155/chrome-mac/headless_shell' + }); + const context = await browser.newContext(); + const page = await context.newPage(); + + // Developer provided script: + {{.Script}} + + await browser.close(); +})(); +"#; + +fn run_playwright_script(script: &str) -> Result<(), Box> { + let filled_template = PLAYWRIGHT_TEMPLATE.replace("{{.Script}}", script); + + // Generate a session UUID + let session_uuid = Uuid::new_v4().to_string(); + + let mut temp_file = NamedTempFile::new()?; + let temp_path = temp_file.path().to_owned(); + + temp_file.write_all(filled_template.as_bytes())?; + + // close the file to flush the buffer + let _temp_file = temp_file.into_temp_path(); + + // Execute the command with timeout + println!("Starting Playwright session with UUID: {}", session_uuid); + let mut command = std::process::Command::new("node"); + let mut child = command + .arg(temp_path) + .arg(session_uuid.clone()) + .env("DEBUG", "pw:api") + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + + // Set a timeout of 20 seconds (matching the Go version) + // let timeout = Duration::from_secs(20); + // // kill process after timeout + // let _ = std::thread::spawn(move || { + // std::thread::sleep(timeout); + // let _ = child.kill(); + // }); + + let output = child.wait_with_output()?; + println!("Output: {:?}", output); + + // Convert output to string + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + + println!("Stdout: {}", stdout); + println!("Stderr: {}", stderr); + + Ok(()) +} + +fn main() { + println!("Hello, world!"); + // Example developer script to inject + let developer_script = r#" + await page.goto('https://example.com'); + console.log('Page title:', await page.title()); + + // Take a screenshot + await page.screenshot({ path: 'example.png' }); + console.log('Screenshot taken'); + "#; + + let _ = run_playwright_script(developer_script); +} From 55585bb8765e5a3cb83c12372ce298c1e3e69f6a Mon Sep 17 00:00:00 2001 From: lonerapier Date: Tue, 11 Mar 2025 18:14:10 +0530 Subject: [PATCH 11/21] basic script working --- README.md | 4 ++-- executor/README.md | 26 ++++++++++++++++++++++ executor/src/main.rs | 31 +++++++++++++++++++++++++-- notary/src/main.rs | 1 + notary/src/runner.rs | 51 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 109 insertions(+), 4 deletions(-) create mode 100644 executor/README.md create mode 100644 notary/src/runner.rs diff --git a/README.md b/README.md index f1d6a0034..129cb84c1 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,8 @@ If you have any questions, please reach out to any of Pluto's [team members](htt ### Usage ``` -cargo run -p notary -- --config ./fixture/notary-config.toml -cargo run -p client -- --config ./fixture/client.proxy.json +cargo run -p web-prover-notary -- --config ./fixture/notary-config.toml +cargo run -p web-prover-client -- --config ./fixture/client.proxy.json ``` ## Security Status diff --git a/executor/README.md b/executor/README.md new file mode 100644 index 000000000..27f22adb0 --- /dev/null +++ b/executor/README.md @@ -0,0 +1,26 @@ +# Web Prover Executor + +## Set up playground + +``` +npx playwright install +npm install -g playwright-core + +git clone git@github.com:pluto/playwright-playground.git +cd playwright-playground +npm install -g ./playwright-utils + +export NODE_PATH=$(npm root -g) +``` + +## Run example + +Run notary in a separate terminal: +``` +RUST_LOG=debug cargo run -p web-prover-notary -- --config ./fixture/notary-config.toml +``` + +Run example executor: +``` +cargo run -p web-prover-executor +``` \ No newline at end of file diff --git a/executor/src/main.rs b/executor/src/main.rs index d89c29151..a75dcf5f5 100644 --- a/executor/src/main.rs +++ b/executor/src/main.rs @@ -1,4 +1,4 @@ -use std::{io::Write, process::Stdio, time::Duration}; +use std::{io::Write, process::Stdio}; use tempfile::NamedTempFile; use uuid::Uuid; @@ -73,8 +73,34 @@ fn run_playwright_script(script: &str) -> Result<(), Box> Ok(()) } +const DEVELOPER_SCRIPT: &str = r#" +await page.goto("https://pseudo-bank.pluto.dev"); + +const username = page.getByRole("textbox", { name: "Username" }); +const password = page.getByRole("textbox", { name: "Password" }); + +let input = await prompt([ + { title: "Username", types: "text" }, + { title: "Password", types: "password" }, +]); + +await username.fill(input.inputs[0]); +await password.fill(input.inputs[1]); + +const loginBtn = page.getByRole("button", { name: "Login" }); +await loginBtn.click(); + +await page.waitForSelector("text=Your Accounts", { timeout: 5000 }); + +const balanceLocator = page.locator("\#balance-2"); +await balanceLocator.waitFor({ state: "visible", timeout: 5000 }); +const balanceText = (await balanceLocator.textContent()) || ""; +const balance = parseFloat(balanceText.replace(/[$,]/g, "")); + +await prove("bank_balance", balance); +"#; + fn main() { - println!("Hello, world!"); // Example developer script to inject let developer_script = r#" await page.goto('https://example.com'); @@ -84,6 +110,7 @@ fn main() { await page.screenshot({ path: 'example.png' }); console.log('Screenshot taken'); "#; + let developer_script = DEVELOPER_SCRIPT; let _ = run_playwright_script(developer_script); } diff --git a/notary/src/main.rs b/notary/src/main.rs index a343e3d21..b31f61ad7 100644 --- a/notary/src/main.rs +++ b/notary/src/main.rs @@ -34,6 +34,7 @@ mod config; mod error; mod frame; mod proxy; +mod runner; mod verifier; struct SharedState { diff --git a/notary/src/runner.rs b/notary/src/runner.rs new file mode 100644 index 000000000..3e8c4a369 --- /dev/null +++ b/notary/src/runner.rs @@ -0,0 +1,51 @@ +use std::sync::Arc; + +use axum::{ + extract::{self, State}, + Json, +}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::{error::NotaryServerError, SharedState}; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Prompt { + pub title: String, + pub types: String, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct PromptRequest { + pub uuid: String, + pub prompts: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct PromptResponse { + pub inputs: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ProveRequest { + pub uuid: String, + pub key: String, + pub value: Value, +} + +pub async fn prompt( + State(state): State>, + extract::Json(payload): extract::Json, +) -> Result, NotaryServerError> { + let inputs = payload.prompts.iter().map(|prompt| prompt.title.clone()).collect(); + let response = PromptResponse { inputs }; + Ok(Json(response)) +} + +pub async fn prove( + State(state): State>, + extract::Json(payload): extract::Json, +) -> Result, NotaryServerError> { + println!("Proving: {:?}", payload); + Ok(Json(())) +} From 65660be24a4298a7965e58d9e67cc5486283b698 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Wed, 12 Mar 2025 01:08:31 +0530 Subject: [PATCH 12/21] add playwright config and timeout --- Cargo.lock | 2 + Cargo.toml | 3 +- executor/Cargo.toml | 6 +- executor/src/lib.rs | 1 + executor/src/main.rs | 116 ----------------------- executor/src/playwright.rs | 188 +++++++++++++++++++++++++++++++++++++ 6 files changed, 197 insertions(+), 119 deletions(-) create mode 100644 executor/src/lib.rs delete mode 100644 executor/src/main.rs create mode 100644 executor/src/playwright.rs diff --git a/Cargo.lock b/Cargo.lock index 46d6e1081..f41cd8686 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3954,7 +3954,9 @@ name = "web-prover-executor" version = "0.1.0" dependencies = [ "tempfile", + "tracing", "uuid", + "wait-timeout", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index ac2fb4312..29cac4711 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,8 @@ uuid ={ version="1.10.0", default-features=false, features=["v4", "serde"] tracing-test="0.2" -tempfile="3.18.0" +tempfile ="3.18.0" +wait-timeout="0.2.1" [profile.dev] incremental =true diff --git a/executor/Cargo.toml b/executor/Cargo.toml index 294c45bc0..536379fc1 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -4,5 +4,7 @@ name ="web-prover-executor" version="0.1.0" [dependencies] -tempfile={ workspace=true } -uuid ={ workspace=true } +tempfile ={ workspace=true } +tracing ={ workspace=true } +uuid ={ workspace=true } +wait-timeout={ workspace=true } diff --git a/executor/src/lib.rs b/executor/src/lib.rs new file mode 100644 index 000000000..6d6e1590d --- /dev/null +++ b/executor/src/lib.rs @@ -0,0 +1 @@ +mod playwright; diff --git a/executor/src/main.rs b/executor/src/main.rs deleted file mode 100644 index a75dcf5f5..000000000 --- a/executor/src/main.rs +++ /dev/null @@ -1,116 +0,0 @@ -use std::{io::Write, process::Stdio}; - -use tempfile::NamedTempFile; -use uuid::Uuid; - -/// The Playwright template with a placeholder for the script -const PLAYWRIGHT_TEMPLATE: &str = r#" -const { chromium } = require('playwright-core'); -const { prompt, prove, setSessionUUID } = require("@plutoxyz/playwright-utils"); - -(async () => { - const sessionUUID = process.argv[2]; - setSessionUUID(sessionUUID); - console.log("Starting Playwright session with UUID:", sessionUUID); - - const browser = await chromium.launch({ - headless: true, - executablePath: '/Users/darkrai/Library/Caches/ms-playwright/chromium_headless_shell-1155/chrome-mac/headless_shell' - }); - const context = await browser.newContext(); - const page = await context.newPage(); - - // Developer provided script: - {{.Script}} - - await browser.close(); -})(); -"#; - -fn run_playwright_script(script: &str) -> Result<(), Box> { - let filled_template = PLAYWRIGHT_TEMPLATE.replace("{{.Script}}", script); - - // Generate a session UUID - let session_uuid = Uuid::new_v4().to_string(); - - let mut temp_file = NamedTempFile::new()?; - let temp_path = temp_file.path().to_owned(); - - temp_file.write_all(filled_template.as_bytes())?; - - // close the file to flush the buffer - let _temp_file = temp_file.into_temp_path(); - - // Execute the command with timeout - println!("Starting Playwright session with UUID: {}", session_uuid); - let mut command = std::process::Command::new("node"); - let mut child = command - .arg(temp_path) - .arg(session_uuid.clone()) - .env("DEBUG", "pw:api") - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn()?; - - // Set a timeout of 20 seconds (matching the Go version) - // let timeout = Duration::from_secs(20); - // // kill process after timeout - // let _ = std::thread::spawn(move || { - // std::thread::sleep(timeout); - // let _ = child.kill(); - // }); - - let output = child.wait_with_output()?; - println!("Output: {:?}", output); - - // Convert output to string - let stdout = String::from_utf8_lossy(&output.stdout).to_string(); - let stderr = String::from_utf8_lossy(&output.stderr).to_string(); - - println!("Stdout: {}", stdout); - println!("Stderr: {}", stderr); - - Ok(()) -} - -const DEVELOPER_SCRIPT: &str = r#" -await page.goto("https://pseudo-bank.pluto.dev"); - -const username = page.getByRole("textbox", { name: "Username" }); -const password = page.getByRole("textbox", { name: "Password" }); - -let input = await prompt([ - { title: "Username", types: "text" }, - { title: "Password", types: "password" }, -]); - -await username.fill(input.inputs[0]); -await password.fill(input.inputs[1]); - -const loginBtn = page.getByRole("button", { name: "Login" }); -await loginBtn.click(); - -await page.waitForSelector("text=Your Accounts", { timeout: 5000 }); - -const balanceLocator = page.locator("\#balance-2"); -await balanceLocator.waitFor({ state: "visible", timeout: 5000 }); -const balanceText = (await balanceLocator.textContent()) || ""; -const balance = parseFloat(balanceText.replace(/[$,]/g, "")); - -await prove("bank_balance", balance); -"#; - -fn main() { - // Example developer script to inject - let developer_script = r#" - await page.goto('https://example.com'); - console.log('Page title:', await page.title()); - - // Take a screenshot - await page.screenshot({ path: 'example.png' }); - console.log('Screenshot taken'); - "#; - let developer_script = DEVELOPER_SCRIPT; - - let _ = run_playwright_script(developer_script); -} diff --git a/executor/src/playwright.rs b/executor/src/playwright.rs new file mode 100644 index 000000000..bc4e6e666 --- /dev/null +++ b/executor/src/playwright.rs @@ -0,0 +1,188 @@ +use std::{ + error::Error, + io::{Read, Write}, + path::PathBuf, + process::{Command, Stdio}, + time::Duration, +}; + +use tempfile::NamedTempFile; +use tracing::{debug, error}; +use uuid::Uuid; +use wait_timeout::ChildExt; + +/// The Playwright template with a placeholder for the script +const PLAYWRIGHT_TEMPLATE: &str = r#" +const { chromium } = require('playwright-core'); +const { prompt, prove, setSessionUUID } = require("@plutoxyz/playwright-utils"); + +(async () => { + const sessionUUID = process.argv[2]; + setSessionUUID(sessionUUID); + console.log("Starting Playwright session with UUID:", sessionUUID); + + const browser = await chromium.launch({ + headless: true, + executablePath: '/Users/darkrai/Library/Caches/ms-playwright/chromium_headless_shell-1155/chrome-mac/headless_shell' + }); + const context = await browser.newContext(); + const page = await context.newPage(); + + // Developer provided script: + {{.Script}} + + await browser.close(); +})(); +"#; + +/// Configuration for the Playwright runner +pub struct PlaywrightRunnerConfig { + /// Developer script to run in the Playwright template + script: String, + /// Timeout for script execution in seconds + pub timeout_seconds: u64, +} + +pub struct PlaywrightRunner { + /// scipt template with placeholder for the developer script + template: String, + /// Playwright runner configuration + config: PlaywrightRunnerConfig, + /// Path to the Node.js executable + node_path: PathBuf, +} + +#[derive(Debug)] +pub struct PlaywrightOutput { + pub stdout: String, + pub stderr: String, +} + +impl PlaywrightRunner { + pub fn new(config: PlaywrightRunnerConfig, template: String, node_path: PathBuf) -> Self { + Self { config, template, node_path } + } + + pub fn run_script(&self, session_id: Uuid) -> Result> { + // fill the template with the developer script + let template = self.template.replace("{{.Script}}", &self.config.script); + + // create a temporary file to store the template + let mut temp_file = NamedTempFile::new()?; + temp_file.write_all(template.as_bytes())?; + let temp_path = temp_file.path().to_owned(); + let temp_dir = temp_path.parent().unwrap(); + + // close the file to flush the buffer + let _temp_file = temp_file.into_temp_path(); + + // Execute the command with timeout + debug!("Starting Playwright session id: {}", session_id); + let mut command = Command::new(&self.node_path); + let mut child = command + .arg(&temp_path) + .arg(session_id.to_string()) + .env("DEBUG", "pw:api") + .current_dir(temp_dir) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + + // Set a timeout + let timeout = Duration::from_secs(self.config.timeout_seconds); + let _ = match child.wait_timeout(timeout)? { + Some(status) => + if let Some(code) = status.code() { + code + } else { + error!("Process terminated by signal: {:?}", status); + return Err("Process terminated by signal".into()); + }, + None => { + child.kill()?; + error!("Process timed out after {:?}", timeout); + return Err("Process timed out".into()); + }, + }; + + // Convert output to string + let stdout = match child.stdout.take() { + Some(mut stdout_stream) => { + let mut stdout = String::new(); + stdout_stream.read_to_string(&mut stdout)?; + stdout + }, + None => String::new(), + }; + + let stderr = match child.stderr.take() { + Some(mut stderr_stream) => { + let mut stderr = String::new(); + stderr_stream.read_to_string(&mut stderr)?; + stderr + }, + None => String::new(), + }; + + let output = PlaywrightOutput { stdout, stderr }; + + Ok(output) + } +} + +mod tests { + + use super::*; + + const EXAMPLE_DEVELOPER_SCRIPT: &str = r#" +await page.goto("https://pseudo-bank.pluto.dev"); + +const username = page.getByRole("textbox", { name: "Username" }); +const password = page.getByRole("textbox", { name: "Password" }); + +let input = await prompt([ + { title: "Username", types: "text" }, + { title: "Password", types: "password" }, +]); + +await username.fill(input.inputs[0]); +await password.fill(input.inputs[1]); + +const loginBtn = page.getByRole("button", { name: "Login" }); +await loginBtn.click(); + +await page.waitForSelector("text=Your Accounts", { timeout: 5000 }); + +const balanceLocator = page.locator("\#balance-2"); +await balanceLocator.waitFor({ state: "visible", timeout: 5000 }); +const balanceText = (await balanceLocator.textContent()) || ""; +const balance = parseFloat(balanceText.replace(/[$,]/g, "")); + +await prove("bank_balance", balance); +"#; + + #[test] + fn test_playwright_script() { + // Example developer script to inject into the Playwright template + let session_id = Uuid::new_v4(); + // output of `which node` + let node_path = + Command::new("which").arg("node").output().expect("Failed to run `which node`").stdout; + let node_path = String::from_utf8_lossy(&node_path).trim().to_string(); + + let config = PlaywrightRunnerConfig { + script: EXAMPLE_DEVELOPER_SCRIPT.to_string(), + timeout_seconds: 30, + }; + let runner = + PlaywrightRunner::new(config, PLAYWRIGHT_TEMPLATE.to_string(), PathBuf::from(node_path)); + + let result = runner.run_script(session_id); + + if let Err(e) = result { + eprintln!("Failed to run Playwright script: {:?}", e); + } else { + println!("output: {:?}", result.unwrap()); + } + } +} From 50bbead8bf4e1f964de9063661d42b253bf60d8d Mon Sep 17 00:00:00 2001 From: lonerapier Date: Wed, 12 Mar 2025 01:08:40 +0530 Subject: [PATCH 13/21] start internal listener --- notary/src/config.rs | 2 ++ notary/src/main.rs | 12 ++++++++++++ notary/src/runner.rs | 13 ++++++------- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/notary/src/config.rs b/notary/src/config.rs index b399ff171..33f2505b1 100644 --- a/notary/src/config.rs +++ b/notary/src/config.rs @@ -15,6 +15,7 @@ pub struct Config { pub server_cert: String, pub server_key: String, pub listen: String, + pub listen_internal: String, pub notary_signing_key: String, pub acme_email: String, pub acme_domain: String, @@ -27,6 +28,7 @@ pub fn read_config() -> Config { let builder = config::Config::builder() // TODO is this the right way to make server_cert optional? .set_default("listen", "0.0.0.0:443").unwrap() + .set_default("listen_internal", "127.0.0.1:7935").unwrap() .set_default("server_cert", "").unwrap() .set_default("server_key", "").unwrap() .set_default("notary_signing_key", "").unwrap() diff --git a/notary/src/main.rs b/notary/src/main.rs index b31f61ad7..45bf3e44c 100644 --- a/notary/src/main.rs +++ b/notary/src/main.rs @@ -103,6 +103,18 @@ async fn main() -> Result<(), NotaryServerError> { .layer(CorsLayer::permissive()) .with_state(shared_state); + // Create a separate internal router for prompts + // and Start the internal HTTP server as a separate task + let internal_router = + Router::new().route("/prompt", post(runner::prompt)).route("/prove", post(runner::prove)); + let internal_listener = TcpListener::bind(&c.listen_internal).await?; + info!("Internal server listening on http://{}", &c.listen_internal); + tokio::spawn(async move { + if let Err(e) = axum::serve(internal_listener, internal_router).await { + error!("Internal server error: {:?}", e); + } + }); + if !c.server_cert.is_empty() || !c.server_key.is_empty() { let _ = listen(listener, router, &c.server_cert, &c.server_key).await; } else { diff --git a/notary/src/runner.rs b/notary/src/runner.rs index 3e8c4a369..e8b6f37a9 100644 --- a/notary/src/runner.rs +++ b/notary/src/runner.rs @@ -1,13 +1,12 @@ -use std::sync::Arc; - use axum::{ - extract::{self, State}, + extract::{self}, Json, }; use serde::{Deserialize, Serialize}; use serde_json::Value; +use tracing::debug; -use crate::{error::NotaryServerError, SharedState}; +use crate::error::NotaryServerError; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Prompt { @@ -34,18 +33,18 @@ pub struct ProveRequest { } pub async fn prompt( - State(state): State>, extract::Json(payload): extract::Json, ) -> Result, NotaryServerError> { + debug!("Prompting: {:?}", payload); let inputs = payload.prompts.iter().map(|prompt| prompt.title.clone()).collect(); let response = PromptResponse { inputs }; + Ok(Json(response)) } pub async fn prove( - State(state): State>, extract::Json(payload): extract::Json, ) -> Result, NotaryServerError> { - println!("Proving: {:?}", payload); + debug!("Proving: {:?}", payload); Ok(Json(())) } From 02d6eab5d7b03a97b644c754d91124536e748348 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Wed, 12 Mar 2025 01:32:52 +0530 Subject: [PATCH 14/21] add env vars --- executor/src/playwright.rs | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/executor/src/playwright.rs b/executor/src/playwright.rs index bc4e6e666..78cb2ae0c 100644 --- a/executor/src/playwright.rs +++ b/executor/src/playwright.rs @@ -50,6 +50,8 @@ pub struct PlaywrightRunner { config: PlaywrightRunnerConfig, /// Path to the Node.js executable node_path: PathBuf, + /// environment variables + env: Vec<(String, String)>, } #[derive(Debug)] @@ -58,9 +60,16 @@ pub struct PlaywrightOutput { pub stderr: String, } +// TODO: add a PlaywrightError type + impl PlaywrightRunner { - pub fn new(config: PlaywrightRunnerConfig, template: String, node_path: PathBuf) -> Self { - Self { config, template, node_path } + pub fn new( + config: PlaywrightRunnerConfig, + template: String, + node_path: PathBuf, + env_vars: Vec<(String, String)>, + ) -> Self { + Self { config, template, node_path, env: env_vars } } pub fn run_script(&self, session_id: Uuid) -> Result> { @@ -79,14 +88,19 @@ impl PlaywrightRunner { // Execute the command with timeout debug!("Starting Playwright session id: {}", session_id); let mut command = Command::new(&self.node_path); - let mut child = command + let command = command .arg(&temp_path) .arg(session_id.to_string()) - .env("DEBUG", "pw:api") .current_dir(temp_dir) .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn()?; + .stderr(Stdio::piped()); + + // Add environment variables + for (key, value) in &self.env { + command.env(key, value); + } + + let mut child = command.spawn()?; // Set a timeout let timeout = Duration::from_secs(self.config.timeout_seconds); @@ -174,8 +188,12 @@ await prove("bank_balance", balance); script: EXAMPLE_DEVELOPER_SCRIPT.to_string(), timeout_seconds: 30, }; - let runner = - PlaywrightRunner::new(config, PLAYWRIGHT_TEMPLATE.to_string(), PathBuf::from(node_path)); + let runner = PlaywrightRunner::new( + config, + PLAYWRIGHT_TEMPLATE.to_string(), + PathBuf::from(node_path), + vec![(String::from("DEBUG"), String::from("pw:api"))], + ); let result = runner.run_script(session_id); From b547409d32c7c38f8d0cfdb3f7ff5f3d1130173e Mon Sep 17 00:00:00 2001 From: lonerapier Date: Wed, 12 Mar 2025 18:28:14 +0530 Subject: [PATCH 15/21] start playwright integration --- Cargo.lock | 1 + Cargo.toml | 7 ++-- notary/Cargo.toml | 6 ++-- notary/src/frame.rs | 88 ++++++++++++++++++++++++++------------------- notary/src/main.rs | 21 ++++++----- 5 files changed, 73 insertions(+), 50 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f41cd8686..9a17166c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4000,6 +4000,7 @@ dependencies = [ "uuid", "web-prover-client", "web-prover-core", + "web-prover-executor", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 29cac4711..b2b53d3cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,9 +4,10 @@ resolver="2" [workspace.dependencies] # Local re-exporting -web-prover-client={ path="client" } -web-prover-core ={ path="core" } -web-prover-notary={ path="notary" } +web-prover-client ={ path="client" } +web-prover-core ={ path="core" } +web-prover-executor={ path="executor" } +web-prover-notary ={ path="notary" } # Serde serde ={ version="1.0.204", features=["derive"] } serde_json="1.0.120" diff --git a/notary/Cargo.toml b/notary/Cargo.toml index 77e62e3d8..801e89db1 100644 --- a/notary/Cargo.toml +++ b/notary/Cargo.toml @@ -5,6 +5,10 @@ name ="web-prover-notary" version="0.7.0" [dependencies] +web-prover-client ={ workspace=true } +web-prover-core ={ workspace=true } +web-prover-executor={ workspace=true } + chrono ={ workspace=true } futures ={ workspace=true } futures-util ="0.3.30" @@ -22,8 +26,6 @@ tower-http ={ version="0.5.2", features=["cors"] } tower-service ="0.3.2" tracing ={ workspace=true } tracing-subscriber={ workspace=true } -web-prover-client ={ workspace=true } -web-prover-core ={ workspace=true } alloy-primitives={ version="0.8.2", features=["k256"] } async-trait ="0.1.67" diff --git a/notary/src/frame.rs b/notary/src/frame.rs index b8e2bb4e1..ea13794ca 100644 --- a/notary/src/frame.rs +++ b/notary/src/frame.rs @@ -1,13 +1,17 @@ use std::{sync::Arc, time::SystemTime}; use axum::{ - extract::{ws::WebSocket, Query, State, WebSocketUpgrade}, + extract::{ + ws::{Message, WebSocket}, + Query, State, WebSocketUpgrade, + }, response::IntoResponse, }; use futures::StreamExt; +use futures_util::{stream::SplitSink, SinkExt}; use serde::{Deserialize, Serialize}; -use tokio::sync::oneshot; use thiserror::Error; +use tokio::sync::oneshot; use tracing::{info, warn}; use uuid::Uuid; @@ -16,10 +20,8 @@ use crate::SharedState; // pub mod views; - #[derive(Debug, Error)] -pub enum FrameError { -} +pub enum FrameError {} pub enum ConnectionState { Connected, @@ -49,20 +51,21 @@ impl Session { pub fn new(session_id: Uuid) -> Self { let (cancel_sender, cancel_receiver) = oneshot::channel(); let session = Session { session_id, current_view: View::InitialView, cancel: cancel_sender }; - tokio::spawn(session.run(cancel_receiver)); + // TODO: this moves session to the new task, so we can't use it anymore + // Run should be executed elsewhere + // tokio::spawn(session.run(cancel_receiver)); session } async fn run(&self, cancel: oneshot::Receiver<()>) { - // TODO start running playwright script etc + // TODO start running playwright script + // // TODO kill the session if cancelled let _ = cancel.await; } - pub async fn handle(&mut self, request: Action) -> Action { - todo!("") - }; + pub async fn handle(&mut self, request: Action) -> Action { todo!("") } /// Called when the client connects. Can be called multiple times. pub async fn on_client_connect(&mut self) { @@ -126,33 +129,31 @@ async fn handle_websocket_connection( ) { info!("[{}] New Websocket connected", session.session_id); let mut keepalive = false; - let (sender, mut receiver) = socket.split(); + let (mut sender, mut receiver) = socket.split(); session.on_client_connect().await; // TODO pass sender? // TODO what if next() returns None?! while let Some(result) = receiver.next().await { match result { - Ok(message) => { - match message { - axum::extract::ws::Message::Text(text) => { - process_text_message(text, &mut session, sender).await; - }, - axum::extract::ws::Message::Binary(_) => { - warn!("Binary messages are not supported"); - keepalive = false; - break; - }, - axum::extract::ws::Message::Ping(_) => { - todo!("Are Pings handled by axum's tokio-tungstenite?"); - }, - axum::extract::ws::Message::Pong(_) => { - todo!("Are Pongs handled by axum's tokio-tungstenite?"); - }, - axum::extract::ws::Message::Close(_) => { - keepalive = false; - break; - }, - } + Ok(message) => match message { + axum::extract::ws::Message::Text(text) => { + process_text_message(text, &mut session, &mut sender).await; + }, + axum::extract::ws::Message::Binary(_) => { + warn!("Binary messages are not supported"); + keepalive = false; + break; + }, + axum::extract::ws::Message::Ping(_) => { + todo!("Are Pings handled by axum's tokio-tungstenite?"); + }, + axum::extract::ws::Message::Pong(_) => { + todo!("Are Pongs handled by axum's tokio-tungstenite?"); + }, + axum::extract::ws::Message::Close(_) => { + keepalive = false; + break; + }, }, Err(_err) => { keepalive = true; @@ -174,9 +175,22 @@ async fn handle_websocket_connection( } } -async fn process_text_message(text: String, session: Session, sender: SplitSink) { - // TODO parse text into Action - // TODO call session.handle(action) - // TODO send error result to client - // TODO send action result to client +async fn process_text_message( + text: String, + session: &mut Session, + sender: &mut SplitSink, +) { + let action = serde_json::from_str::(&text); + match action { + Ok(action) => { + let result = session.handle(action).await; + // TODO send result to client + }, + Err(err) => { + // TODO send error to client + let _ = sender.send(Message::Text(format!("Invalid action: {}", err))).await; + }, + } + // TODO send error result to client + // TODO send action result to client } diff --git a/notary/src/main.rs b/notary/src/main.rs index 45bf3e44c..1899869cb 100644 --- a/notary/src/main.rs +++ b/notary/src/main.rs @@ -95,6 +95,8 @@ async fn main() -> Result<(), NotaryServerError> { frame_sessions: Arc::new(Mutex::new(HashMap::new())), }); + let _ = start_internal_server(&c).await?; + let router = Router::new() .route("/health", get(|| async move { (StatusCode::OK, "Ok").into_response() })) .route("/v1/proxy", post(proxy::proxy)) @@ -103,8 +105,17 @@ async fn main() -> Result<(), NotaryServerError> { .layer(CorsLayer::permissive()) .with_state(shared_state); - // Create a separate internal router for prompts - // and Start the internal HTTP server as a separate task + if !c.server_cert.is_empty() || !c.server_key.is_empty() { + let _ = listen(listener, router, &c.server_cert, &c.server_key).await; + } else { + let _ = acme_listen(listener, router, &c.acme_domain, &c.acme_email).await; + } + Ok(()) +} + +/// Create a separate internal router for prompts and Start the internal HTTP server as a separate +/// task +async fn start_internal_server(c: &config::Config) -> Result<(), NotaryServerError> { let internal_router = Router::new().route("/prompt", post(runner::prompt)).route("/prove", post(runner::prove)); let internal_listener = TcpListener::bind(&c.listen_internal).await?; @@ -114,12 +125,6 @@ async fn main() -> Result<(), NotaryServerError> { error!("Internal server error: {:?}", e); } }); - - if !c.server_cert.is_empty() || !c.server_key.is_empty() { - let _ = listen(listener, router, &c.server_cert, &c.server_key).await; - } else { - let _ = acme_listen(listener, router, &c.acme_domain, &c.acme_email).await; - } Ok(()) } From b4b0110781e6b5b55560f8bf274729f48bd02209 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Thu, 13 Mar 2025 01:03:21 +0530 Subject: [PATCH 16/21] playwright update --- Cargo.lock | 2 + executor/Cargo.toml | 2 + executor/src/lib.rs | 2 +- executor/src/playwright.rs | 26 ++++++---- notary/src/frame.rs | 103 ++++++++++++++++++++++++++----------- notary/src/main.rs | 17 ++++-- notary/src/runner.rs | 20 ++++++- 7 files changed, 124 insertions(+), 48 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9a17166c7..8e9443469 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3954,6 +3954,8 @@ name = "web-prover-executor" version = "0.1.0" dependencies = [ "tempfile", + "thiserror 1.0.69", + "tokio", "tracing", "uuid", "wait-timeout", diff --git a/executor/Cargo.toml b/executor/Cargo.toml index 536379fc1..3c13b2722 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -5,6 +5,8 @@ version="0.1.0" [dependencies] tempfile ={ workspace=true } +thiserror ={ workspace=true } +tokio ={ workspace=true } tracing ={ workspace=true } uuid ={ workspace=true } wait-timeout={ workspace=true } diff --git a/executor/src/lib.rs b/executor/src/lib.rs index 6d6e1590d..4bcccddcc 100644 --- a/executor/src/lib.rs +++ b/executor/src/lib.rs @@ -1 +1 @@ -mod playwright; +pub mod playwright; diff --git a/executor/src/playwright.rs b/executor/src/playwright.rs index 78cb2ae0c..f5de5f14f 100644 --- a/executor/src/playwright.rs +++ b/executor/src/playwright.rs @@ -1,5 +1,4 @@ use std::{ - error::Error, io::{Read, Write}, path::PathBuf, process::{Command, Stdio}, @@ -12,7 +11,7 @@ use uuid::Uuid; use wait_timeout::ChildExt; /// The Playwright template with a placeholder for the script -const PLAYWRIGHT_TEMPLATE: &str = r#" +pub const PLAYWRIGHT_TEMPLATE: &str = r#" const { chromium } = require('playwright-core'); const { prompt, prove, setSessionUUID } = require("@plutoxyz/playwright-utils"); @@ -38,7 +37,7 @@ const { prompt, prove, setSessionUUID } = require("@plutoxyz/playwright-utils"); /// Configuration for the Playwright runner pub struct PlaywrightRunnerConfig { /// Developer script to run in the Playwright template - script: String, + pub script: String, /// Timeout for script execution in seconds pub timeout_seconds: u64, } @@ -60,6 +59,15 @@ pub struct PlaywrightOutput { pub stderr: String, } +#[derive(Debug, thiserror::Error)] +pub enum PlaywrightError { + #[error(transparent)] + IoError(#[from] std::io::Error), + + #[error("Playwright execution failed: {0}")] + ExecutionError(String), +} + // TODO: add a PlaywrightError type impl PlaywrightRunner { @@ -72,7 +80,7 @@ impl PlaywrightRunner { Self { config, template, node_path, env: env_vars } } - pub fn run_script(&self, session_id: Uuid) -> Result> { + pub async fn run_script(&self, session_id: &Uuid) -> Result { // fill the template with the developer script let template = self.template.replace("{{.Script}}", &self.config.script); @@ -110,12 +118,12 @@ impl PlaywrightRunner { code } else { error!("Process terminated by signal: {:?}", status); - return Err("Process terminated by signal".into()); + return Err(PlaywrightError::ExecutionError("Process terminated by signal".into())); }, None => { child.kill()?; error!("Process timed out after {:?}", timeout); - return Err("Process timed out".into()); + return Err(PlaywrightError::ExecutionError("Process timed out".into())); }, }; @@ -175,8 +183,8 @@ const balance = parseFloat(balanceText.replace(/[$,]/g, "")); await prove("bank_balance", balance); "#; - #[test] - fn test_playwright_script() { + #[tokio::test] + async fn test_playwright_script() { // Example developer script to inject into the Playwright template let session_id = Uuid::new_v4(); // output of `which node` @@ -195,7 +203,7 @@ await prove("bank_balance", balance); vec![(String::from("DEBUG"), String::from("pw:api"))], ); - let result = runner.run_script(session_id); + let result = runner.run_script(&session_id).await; if let Err(e) = result { eprintln!("Failed to run Playwright script: {:?}", e); diff --git a/notary/src/frame.rs b/notary/src/frame.rs index ea13794ca..6bc88e8d3 100644 --- a/notary/src/frame.rs +++ b/notary/src/frame.rs @@ -1,4 +1,4 @@ -use std::{sync::Arc, time::SystemTime}; +use std::{path::PathBuf, process::Command, sync::Arc, time::SystemTime}; use axum::{ extract::{ @@ -11,10 +11,11 @@ use futures::StreamExt; use futures_util::{stream::SplitSink, SinkExt}; use serde::{Deserialize, Serialize}; use thiserror::Error; -use tokio::sync::oneshot; -use tracing::{info, warn}; +use tokio::sync::{mpsc, oneshot, Mutex}; +use tracing::{error, info, warn}; use uuid::Uuid; +use crate::runner::{Prompt, PromptResponse}; // use views::View; use crate::SharedState; @@ -23,10 +24,12 @@ use crate::SharedState; #[derive(Debug, Error)] pub enum FrameError {} -pub enum ConnectionState { +// TODO: either session should live under connection state or connection state should be a session +#[derive(Debug)] +pub enum ConnectionState { Connected, - Disconnected(Session, SystemTime), /* TODO run a task that cleans up disconnected sessions - * every 60 secs */ + Disconnected(SystemTime), /* TODO run a task that cleans up disconnected sessions + * every 60 secs */ } #[derive(Debug, Serialize, Deserialize)] @@ -38,31 +41,60 @@ pub struct Action { #[derive(Debug, Serialize)] pub enum View { InitialView, + PromptView { prompts: Vec }, } pub struct Session { session_id: Uuid, // sender: Option>, + sender: Option>, current_view: View, - cancel: oneshot::Sender<()>, + // prompt_request_sender: Arc>>, + // cancel: oneshot::Sender<()>, } impl Session { pub fn new(session_id: Uuid) -> Self { - let (cancel_sender, cancel_receiver) = oneshot::channel(); - let session = Session { session_id, current_view: View::InitialView, cancel: cancel_sender }; - // TODO: this moves session to the new task, so we can't use it anymore - // Run should be executed elsewhere - // tokio::spawn(session.run(cancel_receiver)); + // let (cancel_sender, cancel_receiver) = oneshot::channel(); + let session = Session { session_id, current_view: View::InitialView, sender: None }; session } - async fn run(&self, cancel: oneshot::Receiver<()>) { - // TODO start running playwright script - // + async fn run(&self) { + let playwright_runner_config = web_prover_executor::playwright::PlaywrightRunnerConfig { + script: "".to_string(), + timeout_seconds: 0, + }; + + let node_path = + Command::new("which").arg("node").output().expect("Failed to run `which node`").stdout; + let node_path = String::from_utf8_lossy(&node_path).trim().to_string(); + + let playwright_runner = web_prover_executor::playwright::PlaywrightRunner::new( + playwright_runner_config, + web_prover_executor::playwright::PLAYWRIGHT_TEMPLATE.to_string(), + PathBuf::from(node_path), + vec![(String::from("DEBUG"), String::from("pw:api"))], + ); + + let session_id = self.session_id.clone(); + let script_result = + tokio::spawn(async move { playwright_runner.run_script(&session_id).await }); + + match script_result.await { + Ok(Ok(output)) => { + info!("Playwright output: {:?}", output); + }, + Ok(Err(e)) => { + error!("Playwright script failed: {:?}", e); + }, + Err(e) => { + error!("Failed to await script result: {:?}", e); + }, + } // TODO kill the session if cancelled - let _ = cancel.await; + // let _ = cancel.await; } pub async fn handle(&mut self, request: Action) -> Action { todo!("") } @@ -76,7 +108,9 @@ impl Session { pub async fn on_client_disconnect(&mut self) {} /// Called when the client closes the connection. Called only once. - pub async fn on_client_close(&self) { let _ = self.cancel.send(()); } + pub async fn on_client_close(&self) { + // let _ = self.cancel.send(()); + } } pub async fn on_websocket( @@ -105,15 +139,16 @@ pub async fn on_websocket( return (axum::http::StatusCode::BAD_REQUEST, "Session already connected").into_response(); // TODO return json error }, - Some(ConnectionState::Disconnected(session, _)) => { + Some(ConnectionState::Disconnected(_)) => { frame_sessions.insert(session_id, ConnectionState::Connected); + let session = state.sessions.lock().await.get(&session_id).unwrap().clone(); session }, None => { let session = Session::new(session_id); frame_sessions.insert(session_id, ConnectionState::Connected); - session + Arc::new(Mutex::new(session)) }, }; @@ -125,19 +160,24 @@ pub async fn on_websocket( async fn handle_websocket_connection( state: Arc, socket: WebSocket, - mut session: Session, + session: Arc>, ) { - info!("[{}] New Websocket connected", session.session_id); + info!("[{}] New Websocket connected", session.lock().await.session_id); let mut keepalive = false; let (mut sender, mut receiver) = socket.split(); - session.on_client_connect().await; // TODO pass sender? + + state.sessions.lock().await.insert(session.lock().await.session_id, session.clone()); + + session.lock().await.on_client_connect().await; // TODO pass sender? + + session.lock().await.run().await; // TODO what if next() returns None?! while let Some(result) = receiver.next().await { match result { Ok(message) => match message { axum::extract::ws::Message::Text(text) => { - process_text_message(text, &mut session, &mut sender).await; + process_text_message(text, session.clone(), &mut sender).await; }, axum::extract::ws::Message::Binary(_) => { warn!("Binary messages are not supported"); @@ -165,25 +205,26 @@ async fn handle_websocket_connection( let mut frame_sessions = state.frame_sessions.lock().await; if keepalive { // If the Websocket connection drops, mark it as disconnected, unless it was correctly closed. - info!("[{}] Websocket disconnected", session.session_id); - session.on_client_disconnect().await; - frame_sessions - .insert(session.session_id, ConnectionState::Disconnected(session, SystemTime::now())); + info!("[{}] Websocket disconnected", session.lock().await.session_id); + session.lock().await.on_client_disconnect().await; + // frame_sessions + // .insert(session.lock().await.session_id, ConnectionState::Disconnected(session.clone(), + // SystemTime::now())); } else { - session.on_client_close().await; - frame_sessions.remove(&session.session_id); + session.lock().await.on_client_close().await; + frame_sessions.remove(&session.lock().await.session_id); } } async fn process_text_message( text: String, - session: &mut Session, + session: Arc>, sender: &mut SplitSink, ) { let action = serde_json::from_str::(&text); match action { Ok(action) => { - let result = session.handle(action).await; + let result = session.lock().await.handle(action).await; // TODO send result to client }, Err(err) => { diff --git a/notary/src/main.rs b/notary/src/main.rs index 1899869cb..e92dc99f4 100644 --- a/notary/src/main.rs +++ b/notary/src/main.rs @@ -40,7 +40,8 @@ mod verifier; struct SharedState { notary_signing_key: SigningKey, - frame_sessions: Arc>>>, + frame_sessions: Arc>>, + sessions: Arc>>>>, } /// Main entry point for the notary server application. @@ -93,9 +94,10 @@ async fn main() -> Result<(), NotaryServerError> { let shared_state = Arc::new(SharedState { notary_signing_key: load_notary_signing_key(&c.notary_signing_key), frame_sessions: Arc::new(Mutex::new(HashMap::new())), + sessions: Arc::new(Mutex::new(HashMap::new())), }); - let _ = start_internal_server(&c).await?; + let _ = start_internal_server(&c, shared_state.clone()).await?; let router = Router::new() .route("/health", get(|| async move { (StatusCode::OK, "Ok").into_response() })) @@ -115,9 +117,14 @@ async fn main() -> Result<(), NotaryServerError> { /// Create a separate internal router for prompts and Start the internal HTTP server as a separate /// task -async fn start_internal_server(c: &config::Config) -> Result<(), NotaryServerError> { - let internal_router = - Router::new().route("/prompt", post(runner::prompt)).route("/prove", post(runner::prove)); +async fn start_internal_server( + c: &config::Config, + shared_state: Arc, +) -> Result<(), NotaryServerError> { + let internal_router = Router::new() + .route("/prompt", post(runner::prompt)) + .route("/prove", post(runner::prove)) + .with_state(shared_state); let internal_listener = TcpListener::bind(&c.listen_internal).await?; info!("Internal server listening on http://{}", &c.listen_internal); tokio::spawn(async move { diff --git a/notary/src/runner.rs b/notary/src/runner.rs index e8b6f37a9..f7f386b0c 100644 --- a/notary/src/runner.rs +++ b/notary/src/runner.rs @@ -1,12 +1,14 @@ +use std::sync::Arc; + use axum::{ - extract::{self}, + extract::{self, State}, Json, }; use serde::{Deserialize, Serialize}; use serde_json::Value; use tracing::debug; -use crate::error::NotaryServerError; +use crate::{error::NotaryServerError, SharedState}; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Prompt { @@ -33,16 +35,30 @@ pub struct ProveRequest { } pub async fn prompt( + State(state): State>, extract::Json(payload): extract::Json, ) -> Result, NotaryServerError> { debug!("Prompting: {:?}", payload); let inputs = payload.prompts.iter().map(|prompt| prompt.title.clone()).collect(); let response = PromptResponse { inputs }; + let session_id = uuid::Uuid::parse_str(&payload.uuid).unwrap(); + let frame_sessions = state.frame_sessions.lock().await; + // match frame_sessions.get(&session_id) { + // Some(crate::frame::ConnectionState::Connected) => {}, + // Some(crate::frame::ConnectionState::Disconnected(_)) => { + // return Err(NotaryServerError::SessionDisconnected); + // }, + // None => { + // return Err(NotaryServerError::SessionNotConnected); + // }, + // } + Ok(Json(response)) } pub async fn prove( + State(_state): State>, extract::Json(payload): extract::Json, ) -> Result, NotaryServerError> { debug!("Proving: {:?}", payload); From 72be3ad3cb47ca5edf5df659841d179a5f1c6fc9 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Thu, 13 Mar 2025 15:35:06 +0530 Subject: [PATCH 17/21] add notary -> client interface --- client/src/lib.rs | 8 +++ executor/src/playwright.rs | 1 - notary/src/error.rs | 2 + notary/src/frame.rs | 111 +++++++++++++++++++++++++++++-------- notary/src/runner.rs | 39 +++++++++---- 5 files changed, 126 insertions(+), 35 deletions(-) diff --git a/client/src/lib.rs b/client/src/lib.rs index 69ec1cb45..662b4e135 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -82,3 +82,11 @@ pub async fn verify( Ok(verify_response) } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_frame() {} +} diff --git a/executor/src/playwright.rs b/executor/src/playwright.rs index f5de5f14f..8b75964f8 100644 --- a/executor/src/playwright.rs +++ b/executor/src/playwright.rs @@ -153,7 +153,6 @@ impl PlaywrightRunner { } mod tests { - use super::*; const EXAMPLE_DEVELOPER_SCRIPT: &str = r#" diff --git a/notary/src/error.rs b/notary/src/error.rs index a43e9020c..a4e643000 100644 --- a/notary/src/error.rs +++ b/notary/src/error.rs @@ -49,6 +49,8 @@ pub enum NotaryServerError { #[error(transparent)] WebProverCoreError(#[from] WebProverCoreError), + #[error(transparent)] + RunnerError(#[from] crate::runner::RunnerError), } /// Trait implementation to convert this error into an axum http response diff --git a/notary/src/frame.rs b/notary/src/frame.rs index 6bc88e8d3..89c5e7f3b 100644 --- a/notary/src/frame.rs +++ b/notary/src/frame.rs @@ -1,4 +1,10 @@ -use std::{path::PathBuf, process::Command, sync::Arc, time::SystemTime}; +use core::panic; +use std::{ + path::PathBuf, + process::Command, + sync::Arc, + time::{Duration, SystemTime}, +}; use axum::{ extract::{ @@ -11,7 +17,7 @@ use futures::StreamExt; use futures_util::{stream::SplitSink, SinkExt}; use serde::{Deserialize, Serialize}; use thiserror::Error; -use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::sync::{oneshot, Mutex}; use tracing::{error, info, warn}; use uuid::Uuid; @@ -22,7 +28,12 @@ use crate::SharedState; // pub mod views; #[derive(Debug, Error)] -pub enum FrameError {} +pub enum FrameError { + #[error("WebSocket error: {0}")] + WebSocketError(String), + #[error("Prompt timeout")] + PromptTimeout, +} // TODO: either session should live under connection state or connection state should be a session #[derive(Debug)] @@ -45,18 +56,23 @@ pub enum View { } pub struct Session { - session_id: Uuid, - // sender: Option>, - sender: Option>, - current_view: View, - // prompt_request_sender: Arc>>, + session_id: Uuid, + ws_sender: Option>, + // sender: Option>, + current_view: View, + prompt_response_sender: Arc>>>, // cancel: oneshot::Sender<()>, } impl Session { pub fn new(session_id: Uuid) -> Self { // let (cancel_sender, cancel_receiver) = oneshot::channel(); - let session = Session { session_id, current_view: View::InitialView, sender: None }; + let session = Session { + session_id, + current_view: View::InitialView, + ws_sender: None, + prompt_response_sender: Arc::new(Mutex::new(None)), + }; session } @@ -99,17 +115,51 @@ impl Session { pub async fn handle(&mut self, request: Action) -> Action { todo!("") } + pub async fn handle_prompt( + &mut self, + prompts: Vec, + ) -> Result { + let prompt_view = View::PromptView { prompts }; + let serialized_view = serde_json::to_string(&prompt_view).unwrap(); + + let (prompt_response_sender, prompt_response_receiver) = oneshot::channel::(); + self.prompt_response_sender.lock().await.replace(prompt_response_sender); + + // TODO: session should store each view sent with a request id, so that it can match the + // response + self.ws_sender.as_mut().unwrap().send(Message::Text(serialized_view)).await.unwrap(); + + self.current_view = prompt_view; + + match tokio::time::timeout(Duration::from_secs(60), prompt_response_receiver).await { + Ok(Ok(prompt_response)) => Ok(prompt_response), + Ok(Err(_)) => Err(FrameError::WebSocketError("Prompt response channel closed".to_string())), + Err(_) => Err(FrameError::PromptTimeout), + } + } + + pub async fn handle_prompt_response(&mut self, prompt_response: PromptResponse) { + let prompt_response_sender = self.prompt_response_sender.lock().await.take().unwrap(); + prompt_response_sender.send(prompt_response).unwrap(); + } + /// Called when the client connects. Can be called multiple times. pub async fn on_client_connect(&mut self) { - // TODO send current_view serialized + // send initial view + let current_view_serialized = serde_json::to_string(&self.current_view).unwrap(); + self.ws_sender.as_mut().unwrap().send(Message::Text(current_view_serialized)).await.unwrap(); } /// Called when the client disconnects unexpectedly. Can be called multiple times. pub async fn on_client_disconnect(&mut self) {} /// Called when the client closes the connection. Called only once. - pub async fn on_client_close(&self) { + pub async fn on_client_close(&mut self) { // let _ = self.cancel.send(()); + // TODO: cancel other tasks like playwright + if let Some(mut ws_sender) = self.ws_sender.take() { + let _ = ws_sender.close().await; // attempt to close the websocket connection gracefully + } } } @@ -164,7 +214,9 @@ async fn handle_websocket_connection( ) { info!("[{}] New Websocket connected", session.lock().await.session_id); let mut keepalive = false; - let (mut sender, mut receiver) = socket.split(); + let (sender, mut receiver) = socket.split(); + + session.lock().await.ws_sender = Some(sender); state.sessions.lock().await.insert(session.lock().await.session_id, session.clone()); @@ -177,7 +229,7 @@ async fn handle_websocket_connection( match result { Ok(message) => match message { axum::extract::ws::Message::Text(text) => { - process_text_message(text, session.clone(), &mut sender).await; + process_text_message(text, session.clone()).await; }, axum::extract::ws::Message::Binary(_) => { warn!("Binary messages are not supported"); @@ -207,29 +259,42 @@ async fn handle_websocket_connection( // If the Websocket connection drops, mark it as disconnected, unless it was correctly closed. info!("[{}] Websocket disconnected", session.lock().await.session_id); session.lock().await.on_client_disconnect().await; - // frame_sessions - // .insert(session.lock().await.session_id, ConnectionState::Disconnected(session.clone(), - // SystemTime::now())); + frame_sessions + .insert(session.lock().await.session_id, ConnectionState::Disconnected(SystemTime::now())); } else { session.lock().await.on_client_close().await; frame_sessions.remove(&session.lock().await.session_id); + state.sessions.lock().await.remove(&session.lock().await.session_id); } + drop(frame_sessions); } -async fn process_text_message( - text: String, - session: Arc>, - sender: &mut SplitSink, -) { +async fn process_text_message(text: String, session: Arc>) { let action = serde_json::from_str::(&text); match action { Ok(action) => { - let result = session.lock().await.handle(action).await; + let action = session.lock().await.handle(action).await; + match action.kind.as_str() { + "prompt_response" => { + let prompt_response = serde_json::from_value::(action.payload).unwrap(); + session.lock().await.handle_prompt_response(prompt_response).await; + }, + _ => { + panic!("Invalid action: {}", action.kind); + }, + } // TODO send result to client }, Err(err) => { // TODO send error to client - let _ = sender.send(Message::Text(format!("Invalid action: {}", err))).await; + + // let sender = session.lock().await.ws_sender.as_mut(); + + // // Send an error message to the client + // if let Some(sender) = sender { + // let _ = + // sender.send(axum::extract::ws::Message::Text(format!("Invalid action: {}", + // err))).await; } }, } // TODO send error result to client diff --git a/notary/src/runner.rs b/notary/src/runner.rs index f7f386b0c..f4eb8b509 100644 --- a/notary/src/runner.rs +++ b/notary/src/runner.rs @@ -34,25 +34,42 @@ pub struct ProveRequest { pub value: Value, } +#[derive(Debug, thiserror::Error)] +pub enum RunnerError { + #[error("Playwright session disconnected")] + PlaywrightSessionDisconnected, + #[error("Playwright session not connected")] + PlaywrightSessionNotConnected, + #[error(transparent)] + FrameError(#[from] crate::frame::FrameError), +} + pub async fn prompt( State(state): State>, extract::Json(payload): extract::Json, ) -> Result, NotaryServerError> { debug!("Prompting: {:?}", payload); - let inputs = payload.prompts.iter().map(|prompt| prompt.title.clone()).collect(); - let response = PromptResponse { inputs }; + // let inputs = payload.prompts.iter().map(|prompt| prompt.title.clone()).collect(); + // let response = PromptResponse { inputs }; let session_id = uuid::Uuid::parse_str(&payload.uuid).unwrap(); let frame_sessions = state.frame_sessions.lock().await; - // match frame_sessions.get(&session_id) { - // Some(crate::frame::ConnectionState::Connected) => {}, - // Some(crate::frame::ConnectionState::Disconnected(_)) => { - // return Err(NotaryServerError::SessionDisconnected); - // }, - // None => { - // return Err(NotaryServerError::SessionNotConnected); - // }, - // } + let response = match frame_sessions.get(&session_id) { + Some(crate::frame::ConnectionState::Connected) => { + let session = state.sessions.lock().await.get(&session_id).unwrap().clone(); + let response = + session.lock().await.handle_prompt(payload.prompts).await.map_err(RunnerError::from)?; + Ok::(response) + }, + Some(crate::frame::ConnectionState::Disconnected(_)) => { + return Err(RunnerError::PlaywrightSessionDisconnected).map_err(NotaryServerError::from); + }, + None => { + return Err(RunnerError::PlaywrightSessionNotConnected).map_err(NotaryServerError::from); + }, + }?; + + drop(frame_sessions); Ok(Json(response)) } From 80938534553779f6831d0e999c0e2ab887f16236 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Thu, 13 Mar 2025 15:41:15 +0530 Subject: [PATCH 18/21] fix lockfile --- Cargo.lock | 196 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 122 insertions(+), 74 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8e9443469..0e3768029 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,9 +46,9 @@ checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "alloy-primitives" -version = "0.8.22" +version = "0.8.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c66bb6715b7499ea755bde4c96223ae8eb74e05c014ab38b9db602879ffb825" +checksum = "eacedba97e65cdc7ab592f2b22ef5d3ab8d60b2056bc3a6e6363577e8270ec6f" dependencies = [ "alloy-rlp", "bytes", @@ -57,7 +57,7 @@ dependencies = [ "derive_more", "foldhash", "hashbrown 0.15.2", - "indexmap 2.7.1", + "indexmap 2.8.0", "itoa", "k256", "keccak-asm", @@ -521,7 +521,7 @@ dependencies = [ "miniz_oxide", "object", "rustc-demangle", - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -544,9 +544,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64ct" -version = "1.6.0" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +checksum = "bb97d56060ee67d285efb8001fec9d2a4c710c32efd2e14b5cbb5ba71930fc2d" [[package]] name = "bit-set" @@ -668,9 +668,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.31" +version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "027bb0d98429ae334a8698531da7077bdf906419543a35a55c2cb1b66437d767" +checksum = "6088f3ae8c3608d19260cd7445411865a485688711b78b5be70d78cd96136f83" dependencies = [ "clap_builder", "clap_derive", @@ -678,9 +678,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.31" +version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5589e0cba072e0f3d23791efac0fd8627b49c829c196a492e88168e6a669d863" +checksum = "22a7ef7f676155edfb82daa97f99441f3ebf4a58d5e32f295a56259f1b6facc8" dependencies = [ "anstream", "anstyle", @@ -690,9 +690,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.28" +version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed" +checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" dependencies = [ "heck", "proc-macro2", @@ -1360,7 +1360,7 @@ dependencies = [ "cfg-if", "libc", "wasi 0.13.3+wasi-0.2.2", - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -1392,7 +1392,7 @@ dependencies = [ "futures-core", "futures-sink", "http", - "indexmap 2.7.1", + "indexmap 2.8.0", "slab", "tokio", "tokio-util", @@ -1464,9 +1464,9 @@ dependencies = [ [[package]] name = "http" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -1485,12 +1485,12 @@ dependencies = [ [[package]] name = "http-body-util" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", - "futures-util", + "futures-core", "http", "http-body", "pin-project-lite", @@ -1789,9 +1789,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.7.1" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" +checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -1887,9 +1887,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.170" +version = "0.2.171" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828" +checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" [[package]] name = "libm" @@ -1909,12 +1909,6 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db9c683daf087dc577b7506e9695b3d556a9f3849903fa28186283afd6809e9" -[[package]] -name = "linux-raw-sys" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db9c683daf087dc577b7506e9695b3d556a9f3849903fa28186283afd6809e9" - [[package]] name = "litemap" version = "0.7.5" @@ -2082,9 +2076,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.3" +version = "1.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" +checksum = "cde51589ab56b20a6f686b2c68f7a0bd6add753d697abf720d63f8db3ab7b1ad" [[package]] name = "openssl" @@ -2200,7 +2194,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -2463,9 +2457,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.39" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1f1914ce909e1658d9907913b4b91947430c7d9be598b15a1912935b8c04801" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] @@ -2593,9 +2587,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.12" +version = "0.12.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" +checksum = "989e327e510263980e231de548a33e63d34962d29ae61b467389a1a09627a254" dependencies = [ "base64 0.22.1", "bytes", @@ -2652,9 +2646,9 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.13" +version = "0.17.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", @@ -3053,7 +3047,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.7.1", + "indexmap 2.8.0", "serde", "serde_derive", "serde_json", @@ -3414,9 +3408,9 @@ checksum = "b130bd8a58c163224b44e217b4239ca7b927d82bf6cc2fea1fc561d15056e3f7" [[package]] name = "tokio" -version = "1.44.0" +version = "1.44.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9975ea0f48b5aa3972bf2d888c238182458437cc2a19374b81b25cdf1023fb3a" +checksum = "f382da615b842244d4b8738c82ed1275e6c5dd90c459a30941cd07080b06c91a" dependencies = [ "backtrace", "bytes", @@ -3525,7 +3519,7 @@ version = "0.22.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.8.0", "serde", "serde_spanned", "toml_datetime", @@ -3823,15 +3817,6 @@ dependencies = [ "wit-bindgen-rt", ] -[[package]] -name = "wasm-bindgen" -version = "0.2.93" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" -dependencies = [ - "wit-bindgen-rt", -] - [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -4071,7 +4056,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -4082,32 +4067,31 @@ checksum = "6dccfd733ce2b1753b03b6d3c65edf020262ea35e20ccdf3e288043e6dd620e3" [[package]] name = "windows-registry" -version = "0.2.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" dependencies = [ "windows-result", "windows-strings", - "windows-targets", + "windows-targets 0.53.0", ] [[package]] name = "windows-result" -version = "0.2.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +checksum = "06374efe858fab7e4f881500e6e86ec8bc28f9462c47e5a9941a0142ad86b189" dependencies = [ - "windows-targets", + "windows-link", ] [[package]] name = "windows-strings" -version = "0.1.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" dependencies = [ - "windows-result", - "windows-targets", + "windows-link", ] [[package]] @@ -4116,7 +4100,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -4125,7 +4109,7 @@ version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -4134,14 +4118,30 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", ] [[package]] @@ -4150,53 +4150,101 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + [[package]] name = "windows_i686_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "winnow" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1" +checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36" dependencies = [ "memchr", ] From 26c1ed522ead46d8aa3945808ed6352a1664a0ad Mon Sep 17 00:00:00 2001 From: lonerapier Date: Fri, 14 Mar 2025 00:23:37 +0530 Subject: [PATCH 19/21] initial websocket connection --- Cargo.lock | 105 +++++++++++++++++++++++++------ Cargo.toml | 9 ++- README.md | 1 + client/Cargo.toml | 17 +++-- client/src/config.rs | 5 -- client/src/lib.rs | 126 +++++++++++++++++++++++++++++++++++++- core/src/frame.rs | 43 +++++++++++++ core/src/lib.rs | 1 + fixture/client.proxy.json | 101 +++++++++++++++--------------- notary/Cargo.toml | 4 +- notary/src/frame.rs | 39 +++++------- notary/src/main.rs | 2 +- notary/src/runner.rs | 27 +------- 13 files changed, 341 insertions(+), 139 deletions(-) create mode 100644 core/src/frame.rs diff --git a/Cargo.lock b/Cargo.lock index 0e3768029..6373ebf82 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -63,7 +63,7 @@ dependencies = [ "keccak-asm", "paste", "proptest", - "rand", + "rand 0.8.5", "ruint", "rustc-hash", "serde", @@ -263,7 +263,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1df2c09229cbc5a028b1d70e00fdb2acee28b1055dfb5ca73eea49c5a25c4e7c" dependencies = [ "num-traits", - "rand", + "rand 0.8.5", ] [[package]] @@ -273,7 +273,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94893f1e0c6eeab764ade8dc4c0db24caf4fe7cbbaafc0eba0a9030f447b5185" dependencies = [ "num-traits", - "rand", + "rand 0.8.5", ] [[package]] @@ -481,7 +481,7 @@ dependencies = [ "sha1", "sync_wrapper", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.24.0", "tower 0.5.2", "tower-layer", "tower-service", @@ -861,7 +861,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" dependencies = [ "generic-array", - "rand_core", + "rand_core 0.6.4", "subtle", "zeroize", ] @@ -1060,7 +1060,7 @@ dependencies = [ "group", "pem-rfc7468", "pkcs8", - "rand_core", + "rand_core 0.6.4", "sec1", "subtle", "zeroize", @@ -1156,7 +1156,7 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" dependencies = [ - "rand_core", + "rand_core 0.6.4", "subtle", ] @@ -1167,7 +1167,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "835c052cb0c08c1acf6ffd71c022172e18723949c8282f2b9f27efbc51e64534" dependencies = [ "byteorder", - "rand", + "rand 0.8.5", "rustc-hex", "static_assertions", ] @@ -1376,7 +1376,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" dependencies = [ "ff", - "rand_core", + "rand_core 0.6.4", "subtle", ] @@ -2388,8 +2388,8 @@ dependencies = [ "bitflags", "lazy_static", "num-traits", - "rand", - "rand_chacha", + "rand 0.8.5", + "rand_chacha 0.3.1", "rand_xorshift", "regex-syntax 0.8.5", "rusty-fork", @@ -2429,7 +2429,7 @@ checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", "getrandom 0.2.15", - "rand", + "rand 0.8.5", "ring", "rustc-hash", "rustls", @@ -2477,8 +2477,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", + "zerocopy 0.8.23", ] [[package]] @@ -2488,7 +2499,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", ] [[package]] @@ -2500,13 +2521,22 @@ dependencies = [ "getrandom 0.2.15", ] +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.1", +] + [[package]] name = "rand_xorshift" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" dependencies = [ - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -2707,7 +2737,7 @@ dependencies = [ "parity-scale-codec", "primitive-types", "proptest", - "rand", + "rand 0.8.5", "rlp", "ruint-macro", "serde", @@ -3140,7 +3170,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" dependencies = [ "digest 0.10.7", - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -3475,7 +3505,22 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.24.0", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084" +dependencies = [ + "futures-util", + "log", + "native-tls", + "rustls", + "tokio", + "tokio-native-tls", + "tungstenite 0.26.2", ] [[package]] @@ -3661,12 +3706,30 @@ dependencies = [ "http", "httparse", "log", - "rand", + "rand 0.8.5", "sha1", "thiserror 1.0.69", "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "native-tls", + "rand 0.9.0", + "sha1", + "thiserror 2.0.12", + "utf-8", +] + [[package]] name = "typenum" version = "1.18.0" @@ -3900,6 +3963,7 @@ dependencies = [ "http-body-util", "hyper", "hyper-util", + "native-tls", "reqwest", "rustls", "rustls-pki-types", @@ -3909,6 +3973,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-rustls", + "tokio-tungstenite 0.26.2", "tokio-util", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index b2b53d3cc..f0291f061 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,9 +28,12 @@ hyper ={ version="1.6", features=["full"] } hyper-util ={ version="0.1", features=["full"] } # Async -tokio ={ version="1.39.1", features=["full"] } -tokio-rustls={ version="0.26.0", default-features=false, features=["logging", "tls12"] } -tokio-util ={ version="0.7" } +axum ={ version="0.7", features=["ws", "json"] } +axum-core ="0.4" +tokio ={ version="1.39.1", features=["full"] } +tokio-rustls ={ version="0.26.0", default-features=false, features=["logging", "tls12"] } +tokio-tungstenite={ version="0.26.2", features=["native-tls", "rustls"] } +tokio-util ={ version="0.7" } chrono ="0.4" derive_more={ version="2.0.1", features=["full"] } diff --git a/README.md b/README.md index 129cb84c1..bc0dd26c9 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ If you have any questions, please reach out to any of Pluto's [team members](htt - `client`: contains components for the client that are shared across both WASM and iOS targets. - `fixture`: contains testing artifacts such as TLS certificates and configuration files. - `notary`: notary server which can notarize TEE proofs. +- `core`: core features of web proofs, i.e. manifest validation, parser, extraction. ### Usage diff --git a/client/Cargo.toml b/client/Cargo.toml index 3f908ddeb..d2052f395 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -14,10 +14,11 @@ unsafe_skip_cert_verification=[] # Shared dependencies for all targets [dependencies] -bytes ="1" -pki-types ={ package="rustls-pki-types", version="1.7" } web-prover-core={ workspace=true } -webpki-roots ="0.26.1" + +bytes ="1" +pki-types ={ package="rustls-pki-types", version="1.7" } +webpki-roots="0.26.1" # Serde serde ={ workspace=true } serde_json={ workspace=true } @@ -45,8 +46,12 @@ uuid ={ workspace=true } # Web hyper-util={ workspace=true } # Async -rustls ={ version="0.23", default-features=false, features=["ring"] } -tokio ={ workspace=true, features=["rt", "rt-multi-thread", "macros", "net", "io-std", "fs"] } -tokio-rustls={ version="0.26", default-features=false, features=["logging", "tls12"] } +rustls ={ version="0.23", default-features=false, features=["ring"] } +tokio ={ workspace=true, features=["rt", "rt-multi-thread", "macros", "net", "io-std", "fs"] } +tokio-rustls ={ version="0.26", default-features=false, features=["logging", "tls12"] } +tokio-tungstenite={ workspace=true } # TLSN reqwest={ version="0.12", features=["json", "rustls-tls"] } + +[dev-dependencies] +native-tls="0.2.14" diff --git a/client/src/config.rs b/client/src/config.rs index 4ceba47e9..bba789d4a 100644 --- a/client/src/config.rs +++ b/client/src/config.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use serde::Deserialize; use serde_with::{ base64::{Base64, Standard}, @@ -20,9 +18,6 @@ pub struct Config { // this is helpful for local debugging with self-signed certs #[serde_as(as = "Option>")] pub notary_ca_cert: Option>, - pub target_method: String, - pub target_url: String, - pub target_headers: HashMap, pub target_body: String, pub manifest: Manifest, #[serde(skip)] diff --git a/client/src/lib.rs b/client/src/lib.rs index 662b4e135..43c646f89 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -85,8 +85,132 @@ pub async fn verify( #[cfg(test)] mod tests { + use futures::{SinkExt, StreamExt}; + // use tokio_rustls::{ + // rustls::{Certificate, ClientConfig, RootCertStore}, + // TlsConnector, + // }; + use tokio_tungstenite::tungstenite::client::IntoClientRequest; + + const EXAMPLE_DEVELOPER_SCRIPT: &str = r#" + await page.goto("https://pseudo-bank.pluto.dev"); + + const username = page.getByRole("textbox", { name: "Username" }); + const password = page.getByRole("textbox", { name: "Password" }); + + let input = await prompt([ + { title: "Username", types: "text" }, + { title: "Password", types: "password" }, + ]); + + await username.fill(input.inputs[0]); + await password.fill(input.inputs[1]); + + const loginBtn = page.getByRole("button", { name: "Login" }); + await loginBtn.click(); + + await page.waitForSelector("text=Your Accounts", { timeout: 5000 }); + + const balanceLocator = page.locator("\#balance-2"); + await balanceLocator.waitFor({ state: "visible", timeout: 5000 }); + const balanceText = (await balanceLocator.textContent()) || ""; + const balance = parseFloat(balanceText.replace(/[$,]/g, "")); + + await prove("bank_balance", balance); + "#; + use super::*; #[tokio::test] - async fn test_frame() {} + async fn test_frame() { + let config = std::fs::read("../fixture/client.proxy.json").unwrap(); + let mut config: config::Config = serde_json::from_slice(&config).unwrap(); + config.set_session_id(); + + let url = format!( + "wss://{}:{}/v1/frame?session_id={}", + config.notary_host.clone(), + config.notary_port.clone(), + config.session_id + ); + println!("url={}", url); + + // Set up TLS connector that accepts your server certificate + let mut connector_builder = native_tls::TlsConnector::builder(); + + // For testing only: disable certificate verification + // WARNING: Only use this for testing, never in production + connector_builder.danger_accept_invalid_certs(true); + + let connector = connector_builder.build().unwrap(); + let connector = native_tls::TlsConnector::from(connector); + + // Connect with TLS + let request = url.into_client_request().unwrap(); + let (mut ws_stream, response) = tokio_tungstenite::connect_async_tls_with_config( + request, + None, + false, + Some(tokio_tungstenite::Connector::NativeTls(connector)), + ) + .await + .unwrap(); + + // assert!(response.status().is_success(), "WebSocket connection failed"); + println!("response={:?}", response); + + // let message = "Hello, server!"; + // ws_stream.send(tokio_tungstenite::tungstenite::Message::Text(message.into())).await.unwrap(); + + // let received_message = ws_stream.next().await.unwrap().unwrap(); + // assert_eq!(received_message, tokio_tungstenite::tungstenite::Message::Text(message.into())); + tokio::spawn(async move { + while let Some(message) = ws_stream.next().await { + let message = message.unwrap(); + println!("message={:?}", message); + + match message { + tokio_tungstenite::tungstenite::Message::Text(text) => { + let view: web_prover_core::frame::View = serde_json::from_str(&text).unwrap(); + match view { + web_prover_core::frame::View::InitialView => { + println!("InitialView"); + ws_stream + .send(tokio_tungstenite::tungstenite::Message::Text( + serde_json::to_string(&web_prover_core::frame::Action { + kind: "initial_input".to_owned(), + payload: serde_json::to_value(web_prover_core::frame::InitialInput { + script: EXAMPLE_DEVELOPER_SCRIPT.to_owned(), + }) + .unwrap(), + }) + .unwrap() + .into(), + )) + .await + .unwrap(); + }, + web_prover_core::frame::View::PromptView { prompts } => { + println!("Received PromptView with prompts: {:?}", prompts); + let prompt_response = web_prover_core::frame::PromptResponse { + inputs: prompts.iter().map(|prompt| prompt.title.clone()).collect(), + }; + let action = web_prover_core::frame::Action { + kind: "prompt_response".to_owned(), + payload: serde_json::to_value(prompt_response).unwrap(), + }; + ws_stream + .send(tokio_tungstenite::tungstenite::Message::Text( + serde_json::to_string(&action).unwrap().into(), + )) + .await + .unwrap(); + }, + } + }, + _ => panic!("unexpected message"), + }; + } + }); + } } diff --git a/core/src/frame.rs b/core/src/frame.rs new file mode 100644 index 000000000..ec234b429 --- /dev/null +++ b/core/src/frame.rs @@ -0,0 +1,43 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Action { + pub kind: String, + pub payload: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum View { + InitialView, + PromptView { prompts: Vec }, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct InitialInput { + pub script: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Prompt { + pub title: String, + pub types: String, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct PromptRequest { + pub uuid: String, + pub prompts: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct PromptResponse { + pub inputs: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ProveRequest { + pub uuid: String, + pub key: String, + pub value: Value, +} diff --git a/core/src/lib.rs b/core/src/lib.rs index be972d683..c91f3400c 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -2,6 +2,7 @@ pub mod error; pub mod http; pub mod manifest; +pub mod frame; pub mod hash; pub mod parser; pub mod proof; diff --git a/fixture/client.proxy.json b/fixture/client.proxy.json index 4b387f458..5798ec52a 100644 --- a/fixture/client.proxy.json +++ b/fixture/client.proxy.json @@ -3,63 +3,60 @@ "notary_port": 7443, "notary_ca_cert": "MIIFszCCA5ugAwIBAgIUeXLQmnjeXsHpGji7xA8oJjjw7WwwDQYJKoZIhvcNAQELBQAwaTELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExEjAQBgNVBAcMCVBhbG8gQWx0bzEVMBMGA1UECgwMT3JnYW5pemF0aW9uMRowGAYDVQQDDBFMb2NhbGhvc3QgUm9vdCBDQTAeFw0yNDA2MDcyMjUyNDBaFw0yOTA2MDYyMjUyNDBaMGkxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRIwEAYDVQQHDAlQYWxvIEFsdG8xFTATBgNVBAoMDE9yZ2FuaXphdGlvbjEaMBgGA1UEAwwRTG9jYWxob3N0IFJvb3QgQ0EwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDIHMEL9gVSxT/J0qS4xDkDZs1d1UG0+z6NFLLsGdV7gu0ZJbDPlNd0kpjWsisVNB7TcqWoq5ROK5CR+6lZxXC8nbqr2YAJ2O8mHIXcYv7msAN3UYxtM6v1M7K+vNMJdDjZVAxcOKq5R7uUDUPw1weePz6eVEjntAW8mUjqkfnCqYml943Ud3724SkI5wyUT9rKS3bk6hvneq1ah/b1zRGDF2gp+T/oNe4ieS/LGoIUluE2csGRXtt542gpJnw5L54JASmGgt6hunUSWtoaht7Qxv6hYpieu4iHqZY1kfcFDjDH2WI16g1YqrWHzk1l7vWNLVDEcK3kdSQ1GmYAij8ZjAi0LJizLwtN//EkfxiOPlV435itK3uugY+etxrk77BeA6PmVcpZeLLXYuKSrzfaBh0ifP2p0uRlShURi5Rz4IE0I7wHkZ44x9MKYv8YzXK7O29HD158tgorxqwwKmkHqSxWpp7SRKvNnMulHN/el+IKDrPeBhVXsSSkd6U+/H61q7i0SY9TqqhdiMQLW/efK9LkVRen5myhwqogwiF/42Jp2nrCeuzv5YDsAFSrQ0lukW+Hz7FXV+0axnKeXZ08Nd+IS1BhGyMgHo6PWMP1fWyfO0DJVUfIqrHqvBy8bW0yNOuhiyU2oeyDRKv75OMxpIrUeX3qvmrVcYUPvfXsVQIDAQABo1MwUTAdBgNVHQ4EFgQUlro6UuKRY+lHZ3T4FQMhoQAJ2a8wHwYDVR0jBBgwFoAUlro6UuKRY+lHZ3T4FQMhoQAJ2a8wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAgEARliSCLdjNI5fFkNwXDK6k6jN5hITrY07XIawUIEwinjQqgLkYFvzXbMtfLwsWmxMPjDVYYDa5Y2NAGcH08b//2z9EuHxMOTOFTKr59BEcVxV1cxV+lspAunH8DLSlLJhf/EeR+MhIIHAfhlE8V7lvlE1EbM+Uj5JYIeefV/4omsGrphyHD3oSJAQDae0su200I/i2yAaTrwXLZ4HtaXsnxKZ4PMPFWaLvMQ8DsLgx2VB3/vQJn74Xepau6mYEWlRnUu90mj79gJOnwBKPlLojF6dJOMIJ2YHr9fI8sUfkVwPFVlkDKJcr0ll5RL3O/naNlLQZuOgijOM5YF5iTrefliVodEHpBPID2mhtq/E+ZIQWLpik8ulsJ8ufN9YfrbjbsiC/KeoMqoFCImRSyMGQDMADo4EV3DNfDFvfrHx0qBMmJ0nkhuGobphegMPCjZ3axvQwQulKuHXmFpAvGYcpK/twBMC1MJkV04tIwVEDZG6id5oKYtrIXHdSFshf6r3z4bbgq6kJnOxZ8Vo4cEw/dgc3hRivr+HnxOJcEk2CTQlCVOiCQAg64OqDEOoswVg6nzoO3RJhFatu+abO22MIXPNGma02zBoQZLYpGzL9z6pMnPKjL15G9H1SYVSTGhmq+GVtdRibg8rLBciSm3ERd7gNRqvYP5GrjCtUIbOTEc=", "target_body": "", - "proving": { - "manifest": { - "manifestVersion": "1", - "id": "reddit-user-karma", - "title": "Total Reddit Karma", - "description": "Generate a proof that you have a certain amount of karma", - "prepareUrl": "https://www.reddit.com/login/", - "request": { - "method": "GET", - "version": "HTTP/1.1", - "url": "https://gist.githubusercontent.com/mattes/23e64faadb5fd4b5112f379903d2572e/raw/74e517a60c21a5c11d94fec8b572f68addfade39/example.json", - "headers": { - }, - "body": { - "userId": "<% userId %>" - }, - "vars": { - "userId": { - "description": "Reddit username", - "required": true, - "pattern": "^[A-Za-z0-9_-]{3,20}$" - }, - "authToken": { - "description": "Authentication token", - "required": false, - "default": "abcdef1234567890abcdef1234567890", - "pattern": "^[A-Za-z0-9]{32}$" - } + "manifest": { + "manifestVersion": "1", + "id": "reddit-user-karma", + "title": "Total Reddit Karma", + "description": "Generate a proof that you have a certain amount of karma", + "prepareUrl": "https://www.reddit.com/login/", + "request": { + "method": "GET", + "version": "HTTP/1.1", + "url": "https://gist.githubusercontent.com/mattes/23e64faadb5fd4b5112f379903d2572e/raw/74e517a60c21a5c11d94fec8b572f68addfade39/example.json", + "headers": {}, + "body": { + "userId": "<% userId %>" + }, + "vars": { + "userId": { + "description": "Reddit username", + "required": true, + "pattern": "^[A-Za-z0-9_-]{3,20}$" }, - "extra": { - "headers": { - "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Mobile Safari/537.36", - "Content-Type": "application/json" - } + "authToken": { + "description": "Authentication token", + "required": false, + "default": "abcdef1234567890abcdef1234567890", + "pattern": "^[A-Za-z0-9]{32}$" } }, - "response": { - "status": "200", - "version": "HTTP/1.1", - "message": "OK", + "extra": { "headers": { - "Content-Type": "text/plain; charset=utf-8" - }, - "body": { - "format": "json", - "extractors": [ - { - "id": "helloValue", - "description": "Extract the hello value", - "selector": [ - "hello" - ], - "type": "string" - } - ] + "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Mobile Safari/537.36", + "Content-Type": "application/json" } } + }, + "response": { + "status": "200", + "version": "HTTP/1.1", + "message": "OK", + "headers": { + "Content-Type": "text/plain; charset=utf-8" + }, + "body": { + "format": "json", + "extractors": [ + { + "id": "helloValue", + "description": "Extract the hello value", + "selector": [ + "hello" + ], + "type": "string" + } + ] + } } } -} +} \ No newline at end of file diff --git a/notary/Cargo.toml b/notary/Cargo.toml index 801e89db1..914e2f8cf 100644 --- a/notary/Cargo.toml +++ b/notary/Cargo.toml @@ -29,8 +29,8 @@ tracing-subscriber={ workspace=true } alloy-primitives={ version="0.8.2", features=["k256"] } async-trait ="0.1.67" -axum ={ version="0.7", features=["ws", "json"] } -axum-core ="0.4" +axum ={ workspace=true } +axum-core ={ workspace=true } base64 ="0.21.0" clap ={ workspace=true } config ="0.14.0" diff --git a/notary/src/frame.rs b/notary/src/frame.rs index 89c5e7f3b..a8fa82f27 100644 --- a/notary/src/frame.rs +++ b/notary/src/frame.rs @@ -1,4 +1,3 @@ -use core::panic; use std::{ path::PathBuf, process::Command, @@ -18,11 +17,11 @@ use futures_util::{stream::SplitSink, SinkExt}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::sync::{oneshot, Mutex}; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; use uuid::Uuid; - -use crate::runner::{Prompt, PromptResponse}; // use views::View; +use web_prover_core::frame::{Action, Prompt, PromptResponse, View}; + use crate::SharedState; // pub mod views; @@ -43,18 +42,6 @@ pub enum ConnectionState { * every 60 secs */ } -#[derive(Debug, Serialize, Deserialize)] -pub struct Action { - pub kind: String, - pub payload: serde_json::Value, -} - -#[derive(Debug, Serialize)] -pub enum View { - InitialView, - PromptView { prompts: Vec }, -} - pub struct Session { session_id: Uuid, ws_sender: Option>, @@ -76,10 +63,10 @@ impl Session { session } - async fn run(&self) { + async fn run(&self, initial_input: web_prover_core::frame::InitialInput) { let playwright_runner_config = web_prover_executor::playwright::PlaywrightRunnerConfig { - script: "".to_string(), - timeout_seconds: 0, + script: initial_input.script, + timeout_seconds: 60, }; let node_path = @@ -145,7 +132,7 @@ impl Session { /// Called when the client connects. Can be called multiple times. pub async fn on_client_connect(&mut self) { - // send initial view + debug!("Sending initial view"); let current_view_serialized = serde_json::to_string(&self.current_view).unwrap(); self.ws_sender.as_mut().unwrap().send(Message::Text(current_view_serialized)).await.unwrap(); } @@ -168,6 +155,7 @@ pub async fn on_websocket( Query(params): Query>, State(state): State>, ) -> impl IntoResponse { + debug!("Starting frame connection"); // Parse ?session_id from query let session_id = match params.get("session_id") { Some(id) => match Uuid::parse_str(id) { @@ -222,7 +210,7 @@ async fn handle_websocket_connection( session.lock().await.on_client_connect().await; // TODO pass sender? - session.lock().await.run().await; + // session.lock().await.run().await; // TODO what if next() returns None?! while let Some(result) = receiver.next().await { @@ -275,19 +263,24 @@ async fn process_text_message(text: String, session: Arc>) { Ok(action) => { let action = session.lock().await.handle(action).await; match action.kind.as_str() { + "initial_input" => { + let initial_input = + serde_json::from_value::(action.payload).unwrap(); + session.lock().await.run(initial_input).await; + }, "prompt_response" => { let prompt_response = serde_json::from_value::(action.payload).unwrap(); session.lock().await.handle_prompt_response(prompt_response).await; }, _ => { - panic!("Invalid action: {}", action.kind); + error!("Invalid action: {}", action.kind); }, } // TODO send result to client }, Err(err) => { // TODO send error to client - + error!("Failed to parse action: {}", err); // let sender = session.lock().await.ws_sender.as_mut(); // // Send an error message to the client diff --git a/notary/src/main.rs b/notary/src/main.rs index e92dc99f4..7d28af1f0 100644 --- a/notary/src/main.rs +++ b/notary/src/main.rs @@ -102,7 +102,7 @@ async fn main() -> Result<(), NotaryServerError> { let router = Router::new() .route("/health", get(|| async move { (StatusCode::OK, "Ok").into_response() })) .route("/v1/proxy", post(proxy::proxy)) - .route("/v1/frame", post(frame::on_websocket)) + .route("/v1/frame", get(frame::on_websocket)) .route("/v1/meta/keys/:key", get(meta_keys)) .layer(CorsLayer::permissive()) .with_state(shared_state); diff --git a/notary/src/runner.rs b/notary/src/runner.rs index f4eb8b509..736f40187 100644 --- a/notary/src/runner.rs +++ b/notary/src/runner.rs @@ -4,36 +4,11 @@ use axum::{ extract::{self, State}, Json, }; -use serde::{Deserialize, Serialize}; -use serde_json::Value; use tracing::debug; +use web_prover_core::frame::{PromptRequest, PromptResponse, ProveRequest}; use crate::{error::NotaryServerError, SharedState}; -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Prompt { - pub title: String, - pub types: String, -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct PromptRequest { - pub uuid: String, - pub prompts: Vec, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct PromptResponse { - pub inputs: Vec, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct ProveRequest { - pub uuid: String, - pub key: String, - pub value: Value, -} - #[derive(Debug, thiserror::Error)] pub enum RunnerError { #[error("Playwright session disconnected")] From 0e0e687f8b05bbe20e3c54a70512f339515e0a7c Mon Sep 17 00:00:00 2001 From: lonerapier Date: Fri, 14 Mar 2025 02:16:12 +0530 Subject: [PATCH 20/21] wip: save state --- client/Cargo.toml | 4 +- client/src/lib.rs | 200 +++++++++++++++++++++++++------------ client/src/main.rs | 7 +- executor/src/playwright.rs | 78 ++++++++------- notary/src/frame.rs | 36 ++++--- notary/src/runner.rs | 20 +++- 6 files changed, 227 insertions(+), 118 deletions(-) diff --git a/client/Cargo.toml b/client/Cargo.toml index d2052f395..123c5420d 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -51,7 +51,7 @@ tokio ={ workspace=true, features=["rt", "rt-multi-thread", "macros", tokio-rustls ={ version="0.26", default-features=false, features=["logging", "tls12"] } tokio-tungstenite={ workspace=true } # TLSN -reqwest={ version="0.12", features=["json", "rustls-tls"] } +native-tls="0.2.14" +reqwest ={ version="0.12", features=["json", "rustls-tls"] } [dev-dependencies] -native-tls="0.2.14" diff --git a/client/src/lib.rs b/client/src/lib.rs index 43c646f89..b29fcac2e 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -3,15 +3,47 @@ pub mod config; pub mod error; use std::collections::HashMap; +use futures::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; +use tokio_tungstenite::tungstenite::{client::IntoClientRequest, Message::Text}; use tracing::debug; use web_prover_core::{ + frame::{ + Action, InitialInput, PromptResponse, + View::{self, InitialView, PromptView}, + }, manifest::Manifest, proof::{SignedVerificationReply, TeeProof}, }; use crate::error::WebProverClientError; +const EXAMPLE_DEVELOPER_SCRIPT: &str = r#" +await page.goto("https://pseudo-bank.pluto.dev"); + +const username = page.getByRole("textbox", { name: "Username" }); +const password = page.getByRole("textbox", { name: "Password" }); + +let input = await prompt([ + { title: "Username", types: "text" }, + { title: "Password", types: "password" }, +]); + +await username.fill(input.inputs[0]); +await password.fill(input.inputs[1]); + +const loginBtn = page.getByRole("button", { name: "Login" }); +await loginBtn.click(); + +await page.waitForSelector("text=Your Accounts", { timeout: 5000 }); + +const balanceLocator = page.locator("\#balance-2"); +await balanceLocator.waitFor({ state: "visible", timeout: 5000 }); +const balanceText = (await balanceLocator.textContent()) || ""; +const balance = parseFloat(balanceText.replace(/[$,]/g, "")); + +await prove("bank_balance", balance); +"#; #[derive(Serialize, Deserialize, Clone, Debug)] pub struct ProxyConfig { pub target_method: String, @@ -83,45 +115,97 @@ pub async fn verify( Ok(verify_response) } +pub async fn frame() { + let config = std::fs::read("./fixture/client.proxy.json").unwrap(); + let mut config: config::Config = serde_json::from_slice(&config).unwrap(); + config.set_session_id(); + + let url = format!( + "wss://{}:{}/v1/frame?session_id={}", + config.notary_host.clone(), + config.notary_port.clone(), + config.session_id + ); + debug!("url={}", url); + + // Set up TLS connector that accepts your server certificate + let mut connector_builder = native_tls::TlsConnector::builder(); + + // For testing only: disable certificate verification + // WARNING: Only use this for testing, never in production + connector_builder.danger_accept_invalid_certs(true); + + let connector = connector_builder.build().unwrap(); + let connector = native_tls::TlsConnector::from(connector); + + // Connect with TLS + let request = url.into_client_request().unwrap(); + let (mut ws_stream, response) = tokio_tungstenite::connect_async_tls_with_config( + request, + None, + false, + Some(tokio_tungstenite::Connector::NativeTls(connector)), + ) + .await + .unwrap(); + + // assert!(response.status().is_success(), "WebSocket connection failed"); + debug!("response={:?}", response); + + let ws_spawn = tokio::spawn(async move { + while let Some(message) = ws_stream.next().await { + let message = message.unwrap(); + debug!("message={:?}", message); + + match message { + Text(text) => { + let view: View = serde_json::from_str(&text).unwrap(); + match view { + InitialView => { + debug!("InitialView"); + let action = Action { + kind: "initial_input".to_owned(), + payload: serde_json::to_value(InitialInput { + script: EXAMPLE_DEVELOPER_SCRIPT.to_owned(), + }) + .unwrap(), + }; + ws_stream.send(Text(serde_json::to_string(&action).unwrap().into())).await.unwrap(); + }, + PromptView { prompts } => { + debug!("Received PromptView with prompts: {:?}", prompts); + let prompt_response = PromptResponse { + inputs: prompts.iter().map(|prompt| prompt.title.clone()).collect(), + }; + let action = Action { + kind: "prompt_response".to_owned(), + payload: serde_json::to_value(prompt_response).unwrap(), + }; + ws_stream.send(Text(serde_json::to_string(&action).unwrap().into())).await.unwrap(); + debug!("Sent prompt response: {:?}", action); + }, + } + }, + _ => panic!("unexpected message"), + }; + } + }) + .await; +} + #[cfg(test)] mod tests { use futures::{SinkExt, StreamExt}; - // use tokio_rustls::{ - // rustls::{Certificate, ClientConfig, RootCertStore}, - // TlsConnector, - // }; - use tokio_tungstenite::tungstenite::client::IntoClientRequest; - - const EXAMPLE_DEVELOPER_SCRIPT: &str = r#" - await page.goto("https://pseudo-bank.pluto.dev"); - - const username = page.getByRole("textbox", { name: "Username" }); - const password = page.getByRole("textbox", { name: "Password" }); - - let input = await prompt([ - { title: "Username", types: "text" }, - { title: "Password", types: "password" }, - ]); - - await username.fill(input.inputs[0]); - await password.fill(input.inputs[1]); - - const loginBtn = page.getByRole("button", { name: "Login" }); - await loginBtn.click(); - - await page.waitForSelector("text=Your Accounts", { timeout: 5000 }); - - const balanceLocator = page.locator("\#balance-2"); - await balanceLocator.waitFor({ state: "visible", timeout: 5000 }); - const balanceText = (await balanceLocator.textContent()) || ""; - const balance = parseFloat(balanceText.replace(/[$,]/g, "")); - - await prove("bank_balance", balance); - "#; + use tokio_tungstenite::tungstenite::{client::IntoClientRequest, Message::Text}; + use web_prover_core::frame::{ + Action, InitialInput, PromptResponse, + View::{self, InitialView, PromptView}, + }; use super::*; #[tokio::test] + #[tracing::instrument] async fn test_frame() { let config = std::fs::read("../fixture/client.proxy.json").unwrap(); let mut config: config::Config = serde_json::from_slice(&config).unwrap(); @@ -159,58 +243,44 @@ mod tests { // assert!(response.status().is_success(), "WebSocket connection failed"); println!("response={:?}", response); - // let message = "Hello, server!"; - // ws_stream.send(tokio_tungstenite::tungstenite::Message::Text(message.into())).await.unwrap(); - - // let received_message = ws_stream.next().await.unwrap().unwrap(); - // assert_eq!(received_message, tokio_tungstenite::tungstenite::Message::Text(message.into())); - tokio::spawn(async move { + let ws_spawn = tokio::spawn(async move { while let Some(message) = ws_stream.next().await { let message = message.unwrap(); println!("message={:?}", message); match message { - tokio_tungstenite::tungstenite::Message::Text(text) => { - let view: web_prover_core::frame::View = serde_json::from_str(&text).unwrap(); + Text(text) => { + let view: View = serde_json::from_str(&text).unwrap(); match view { - web_prover_core::frame::View::InitialView => { + InitialView => { println!("InitialView"); - ws_stream - .send(tokio_tungstenite::tungstenite::Message::Text( - serde_json::to_string(&web_prover_core::frame::Action { - kind: "initial_input".to_owned(), - payload: serde_json::to_value(web_prover_core::frame::InitialInput { - script: EXAMPLE_DEVELOPER_SCRIPT.to_owned(), - }) - .unwrap(), - }) - .unwrap() - .into(), - )) - .await - .unwrap(); + let action = Action { + kind: "initial_input".to_owned(), + payload: serde_json::to_value(InitialInput { + script: EXAMPLE_DEVELOPER_SCRIPT.to_owned(), + }) + .unwrap(), + }; + ws_stream.send(Text(serde_json::to_string(&action).unwrap().into())).await.unwrap(); }, - web_prover_core::frame::View::PromptView { prompts } => { + PromptView { prompts } => { println!("Received PromptView with prompts: {:?}", prompts); - let prompt_response = web_prover_core::frame::PromptResponse { + let prompt_response = PromptResponse { inputs: prompts.iter().map(|prompt| prompt.title.clone()).collect(), }; - let action = web_prover_core::frame::Action { + let action = Action { kind: "prompt_response".to_owned(), payload: serde_json::to_value(prompt_response).unwrap(), }; - ws_stream - .send(tokio_tungstenite::tungstenite::Message::Text( - serde_json::to_string(&action).unwrap().into(), - )) - .await - .unwrap(); + ws_stream.send(Text(serde_json::to_string(&action).unwrap().into())).await.unwrap(); + println!("Sent prompt response: {:?}", action); }, } }, _ => panic!("unexpected message"), }; } - }); + }) + .await; } } diff --git a/client/src/main.rs b/client/src/main.rs index 9c828f96b..e49669889 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -32,8 +32,9 @@ async fn main() -> Result<(), WebProverClientError> { let mut config: Config = serde_json::from_str(&config_json)?; config.set_session_id(); - let proof = web_prover_client::proxy(config).await?; - let proof_json = serde_json::to_string_pretty(&proof)?; - println!("Proving Successful: proof_len={:?}", proof_json.len()); + // let proof = web_prover_client::proxy(config).await?; + // let proof_json = serde_json::to_string_pretty(&proof)?; + // println!("Proving Successful: proof_len={:?}", proof_json.len()); + web_prover_client::frame().await; Ok(()) } diff --git a/executor/src/playwright.rs b/executor/src/playwright.rs index 8b75964f8..8a3e18157 100644 --- a/executor/src/playwright.rs +++ b/executor/src/playwright.rs @@ -1,11 +1,12 @@ use std::{ io::{Read, Write}, path::PathBuf, - process::{Command, Stdio}, + process::Stdio, time::Duration, }; use tempfile::NamedTempFile; +use tokio::process::Command; use tracing::{debug, error}; use uuid::Uuid; use wait_timeout::ChildExt; @@ -108,43 +109,52 @@ impl PlaywrightRunner { command.env(key, value); } - let mut child = command.spawn()?; + let child = command.spawn()?; + + let status = child.wait_with_output().await?; + if !status.status.success() { + error!("Playwright execution failed: {:?}", status); + return Err(PlaywrightError::ExecutionError("Playwright execution failed".into())); + } + + let stdout = String::from_utf8_lossy(&status.stdout).to_string(); + let stderr = String::from_utf8_lossy(&status.stderr).to_string(); // Set a timeout - let timeout = Duration::from_secs(self.config.timeout_seconds); - let _ = match child.wait_timeout(timeout)? { - Some(status) => - if let Some(code) = status.code() { - code - } else { - error!("Process terminated by signal: {:?}", status); - return Err(PlaywrightError::ExecutionError("Process terminated by signal".into())); - }, - None => { - child.kill()?; - error!("Process timed out after {:?}", timeout); - return Err(PlaywrightError::ExecutionError("Process timed out".into())); - }, - }; + // let timeout = Duration::from_secs(self.config.timeout_seconds); + // let _ = match child.wait_timeout(timeout)? { + // Some(status) => + // if let Some(code) = status.code() { + // code + // } else { + // error!("Process terminated by signal: {:?}", status); + // return Err(PlaywrightError::ExecutionError("Process terminated by signal".into())); + // }, + // None => { + // child.kill()?; + // error!("Process timed out after {:?}", timeout); + // return Err(PlaywrightError::ExecutionError("Process timed out".into())); + // }, + // }; // Convert output to string - let stdout = match child.stdout.take() { - Some(mut stdout_stream) => { - let mut stdout = String::new(); - stdout_stream.read_to_string(&mut stdout)?; - stdout - }, - None => String::new(), - }; - - let stderr = match child.stderr.take() { - Some(mut stderr_stream) => { - let mut stderr = String::new(); - stderr_stream.read_to_string(&mut stderr)?; - stderr - }, - None => String::new(), - }; + // let stdout = match child.stdout.take() { + // Some(mut stdout_stream) => { + // let mut stdout = String::new(); + // stdout_stream.read_to_string(&mut stdout)?; + // stdout + // }, + // None => String::new(), + // }; + + // let stderr = match child.stderr.take() { + // Some(mut stderr_stream) => { + // let mut stderr = String::new(); + // stderr_stream.read_to_string(&mut stderr)?; + // stderr + // }, + // None => String::new(), + // }; let output = PlaywrightOutput { stdout, stderr }; diff --git a/notary/src/frame.rs b/notary/src/frame.rs index a8fa82f27..c7d83c285 100644 --- a/notary/src/frame.rs +++ b/notary/src/frame.rs @@ -63,7 +63,8 @@ impl Session { session } - async fn run(&self, initial_input: web_prover_core::frame::InitialInput) { + // TODO: return result + async fn run(session_id: Uuid, initial_input: web_prover_core::frame::InitialInput) { let playwright_runner_config = web_prover_executor::playwright::PlaywrightRunnerConfig { script: initial_input.script, timeout_seconds: 60, @@ -80,7 +81,7 @@ impl Session { vec![(String::from("DEBUG"), String::from("pw:api"))], ); - let session_id = self.session_id.clone(); + let session_id = session_id.clone(); let script_result = tokio::spawn(async move { playwright_runner.run_script(&session_id).await }); @@ -105,29 +106,36 @@ impl Session { pub async fn handle_prompt( &mut self, prompts: Vec, - ) -> Result { + ) -> Result, FrameError> { + debug!("Handling prompt: {:?}", prompts); let prompt_view = View::PromptView { prompts }; let serialized_view = serde_json::to_string(&prompt_view).unwrap(); let (prompt_response_sender, prompt_response_receiver) = oneshot::channel::(); self.prompt_response_sender.lock().await.replace(prompt_response_sender); + assert!(self.prompt_response_sender.lock().await.is_some()); + // TODO: session should store each view sent with a request id, so that it can match the // response - self.ws_sender.as_mut().unwrap().send(Message::Text(serialized_view)).await.unwrap(); - + debug!("Sending prompt view to client"); self.current_view = prompt_view; + self.ws_sender.as_mut().unwrap().send(Message::Text(serialized_view)).await.unwrap(); - match tokio::time::timeout(Duration::from_secs(60), prompt_response_receiver).await { - Ok(Ok(prompt_response)) => Ok(prompt_response), - Ok(Err(_)) => Err(FrameError::WebSocketError("Prompt response channel closed".to_string())), - Err(_) => Err(FrameError::PromptTimeout), - } + debug!("Prompt view sent successfully"); + Ok(prompt_response_receiver) } pub async fn handle_prompt_response(&mut self, prompt_response: PromptResponse) { + debug!("Received prompt response: {:?}", prompt_response); + assert!(self.prompt_response_sender.lock().await.is_some(), "No prompt response sender"); let prompt_response_sender = self.prompt_response_sender.lock().await.take().unwrap(); - prompt_response_sender.send(prompt_response).unwrap(); + let send_result = prompt_response_sender.send(prompt_response); + if let Err(e) = send_result { + error!("Failed to send prompt response: {:?}", e); + } else { + debug!("Prompt response handled!"); + } } /// Called when the client connects. Can be called multiple times. @@ -261,15 +269,17 @@ async fn process_text_message(text: String, session: Arc>) { let action = serde_json::from_str::(&text); match action { Ok(action) => { - let action = session.lock().await.handle(action).await; + // let action = session.lock().await.handle(action).await; match action.kind.as_str() { "initial_input" => { let initial_input = serde_json::from_value::(action.payload).unwrap(); - session.lock().await.run(initial_input).await; + let session_id = session.lock().await.session_id; + Session::run(session_id, initial_input).await; }, "prompt_response" => { let prompt_response = serde_json::from_value::(action.payload).unwrap(); + debug!("Received prompt response: {:?}", prompt_response); session.lock().await.handle_prompt_response(prompt_response).await; }, _ => { diff --git a/notary/src/runner.rs b/notary/src/runner.rs index 736f40187..92d0b056c 100644 --- a/notary/src/runner.rs +++ b/notary/src/runner.rs @@ -17,6 +17,8 @@ pub enum RunnerError { PlaywrightSessionNotConnected, #[error(transparent)] FrameError(#[from] crate::frame::FrameError), + #[error(transparent)] + RecvError(#[from] tokio::sync::oneshot::error::RecvError), } pub async fn prompt( @@ -28,12 +30,28 @@ pub async fn prompt( // let response = PromptResponse { inputs }; let session_id = uuid::Uuid::parse_str(&payload.uuid).unwrap(); + debug!("session_id: {:?}", session_id); let frame_sessions = state.frame_sessions.lock().await; + debug!("frame_sessions_got"); let response = match frame_sessions.get(&session_id) { Some(crate::frame::ConnectionState::Connected) => { let session = state.sessions.lock().await.get(&session_id).unwrap().clone(); - let response = + debug!("session lock acquired"); + let prompt_response_receiver = session.lock().await.handle_prompt(payload.prompts).await.map_err(RunnerError::from)?; + debug!("prompt_response_receiver acquired"); + // TODO: is there a deadlock here??? + // prompt response is received after timeout has passed + let response = + tokio::time::timeout(std::time::Duration::from_secs(30), prompt_response_receiver) + .await + .map_err(|_| RunnerError::FrameError(crate::frame::FrameError::PromptTimeout))? + .map_err(RunnerError::from)?; + // let response = match prompt_response_receiver.await { + // Ok(response) => response, + // Err(e) => return Err(e).map_err(RunnerError::from)?, + // }; + debug!("Prompt response: {:?}", response); Ok::(response) }, Some(crate::frame::ConnectionState::Disconnected(_)) => { From d368d7ba579d36cb234048fa702c8b3e1ac39aae Mon Sep 17 00:00:00 2001 From: lonerapier Date: Fri, 14 Mar 2025 13:31:53 +0530 Subject: [PATCH 21/21] frame working --- client/src/lib.rs | 29 +++++++++++++++++++------ core/src/frame.rs | 9 +++++--- notary/src/frame.rs | 18 ++++++++++++---- notary/src/runner.rs | 51 ++++++++++++++++++++++++-------------------- 4 files changed, 71 insertions(+), 36 deletions(-) diff --git a/client/src/lib.rs b/client/src/lib.rs index b29fcac2e..c4d98b68b 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -162,7 +162,7 @@ pub async fn frame() { let view: View = serde_json::from_str(&text).unwrap(); match view { InitialView => { - debug!("InitialView"); + debug!("Received InitialView"); let action = Action { kind: "initial_input".to_owned(), payload: serde_json::to_value(InitialInput { @@ -182,15 +182,23 @@ pub async fn frame() { payload: serde_json::to_value(prompt_response).unwrap(), }; ws_stream.send(Text(serde_json::to_string(&action).unwrap().into())).await.unwrap(); - debug!("Sent prompt response: {:?}", action); + }, + View::ProveView { proof } => { + debug!("Received ProveView with proof: {:?}", proof); + + ws_stream.close(None).await.unwrap(); }, } }, _ => panic!("unexpected message"), }; } - }) - .await; + }); + + match ws_spawn.await { + Ok(_) => debug!("WebSocket task completed"), + Err(e) => debug!("WebSocket task failed: {:?}", e), + } } #[cfg(test)] @@ -275,12 +283,21 @@ mod tests { ws_stream.send(Text(serde_json::to_string(&action).unwrap().into())).await.unwrap(); println!("Sent prompt response: {:?}", action); }, + View::ProveView { proof } => { + println!("Received ProveView with proof: {:?}", proof); + + ws_stream.close(None).await.unwrap(); + }, } }, _ => panic!("unexpected message"), }; } - }) - .await; + }); + + match ws_spawn.await { + Ok(_) => println!("WebSocket task completed"), + Err(e) => println!("WebSocket task failed: {:?}", e), + } } } diff --git a/core/src/frame.rs b/core/src/frame.rs index ec234b429..2b11afbad 100644 --- a/core/src/frame.rs +++ b/core/src/frame.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -11,6 +13,7 @@ pub struct Action { pub enum View { InitialView, PromptView { prompts: Vec }, + ProveView { proof: FrameProof }, } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -35,9 +38,9 @@ pub struct PromptResponse { pub inputs: Vec, } +pub type FrameProof = HashMap; #[derive(Serialize, Deserialize, Debug)] -pub struct ProveRequest { +pub struct ProveOutput { pub uuid: String, - pub key: String, - pub value: Value, + pub proof: FrameProof, } diff --git a/notary/src/frame.rs b/notary/src/frame.rs index c7d83c285..7c4ba65a1 100644 --- a/notary/src/frame.rs +++ b/notary/src/frame.rs @@ -20,7 +20,7 @@ use tokio::sync::{oneshot, Mutex}; use tracing::{debug, error, info, warn}; use uuid::Uuid; // use views::View; -use web_prover_core::frame::{Action, Prompt, PromptResponse, View}; +use web_prover_core::frame::{Action, FrameProof, Prompt, PromptResponse, ProveOutput, View}; use crate::SharedState; @@ -126,6 +126,15 @@ impl Session { Ok(prompt_response_receiver) } + pub async fn handle_prove(&mut self, proof: FrameProof) -> Result<(), FrameError> { + debug!("Handling prove: {:?}", proof); + let prove_view = View::ProveView { proof }; + let serialized_view = serde_json::to_string(&prove_view).unwrap(); + self.ws_sender.as_mut().unwrap().send(Message::Text(serialized_view)).await.unwrap(); + debug!("Prove view sent successfully"); + Ok(()) + } + pub async fn handle_prompt_response(&mut self, prompt_response: PromptResponse) { debug!("Received prompt response: {:?}", prompt_response); assert!(self.prompt_response_sender.lock().await.is_some(), "No prompt response sender"); @@ -218,14 +227,15 @@ async fn handle_websocket_connection( session.lock().await.on_client_connect().await; // TODO pass sender? - // session.lock().await.run().await; - // TODO what if next() returns None?! while let Some(result) = receiver.next().await { match result { Ok(message) => match message { axum::extract::ws::Message::Text(text) => { - process_text_message(text, session.clone()).await; + let session_clone = session.clone(); + tokio::spawn(async move { + process_text_message(text, session_clone.clone()).await; + }); }, axum::extract::ws::Message::Binary(_) => { warn!("Binary messages are not supported"); diff --git a/notary/src/runner.rs b/notary/src/runner.rs index 92d0b056c..157645528 100644 --- a/notary/src/runner.rs +++ b/notary/src/runner.rs @@ -5,7 +5,7 @@ use axum::{ Json, }; use tracing::debug; -use web_prover_core::frame::{PromptRequest, PromptResponse, ProveRequest}; +use web_prover_core::frame::{PromptRequest, PromptResponse, ProveOutput}; use crate::{error::NotaryServerError, SharedState}; @@ -26,33 +26,17 @@ pub async fn prompt( extract::Json(payload): extract::Json, ) -> Result, NotaryServerError> { debug!("Prompting: {:?}", payload); - // let inputs = payload.prompts.iter().map(|prompt| prompt.title.clone()).collect(); - // let response = PromptResponse { inputs }; let session_id = uuid::Uuid::parse_str(&payload.uuid).unwrap(); - debug!("session_id: {:?}", session_id); + let frame_sessions = state.frame_sessions.lock().await; - debug!("frame_sessions_got"); - let response = match frame_sessions.get(&session_id) { + let prompt_response_receiver = match frame_sessions.get(&session_id) { Some(crate::frame::ConnectionState::Connected) => { let session = state.sessions.lock().await.get(&session_id).unwrap().clone(); - debug!("session lock acquired"); let prompt_response_receiver = session.lock().await.handle_prompt(payload.prompts).await.map_err(RunnerError::from)?; debug!("prompt_response_receiver acquired"); - // TODO: is there a deadlock here??? - // prompt response is received after timeout has passed - let response = - tokio::time::timeout(std::time::Duration::from_secs(30), prompt_response_receiver) - .await - .map_err(|_| RunnerError::FrameError(crate::frame::FrameError::PromptTimeout))? - .map_err(RunnerError::from)?; - // let response = match prompt_response_receiver.await { - // Ok(response) => response, - // Err(e) => return Err(e).map_err(RunnerError::from)?, - // }; - debug!("Prompt response: {:?}", response); - Ok::(response) + Ok::, RunnerError>(prompt_response_receiver) }, Some(crate::frame::ConnectionState::Disconnected(_)) => { return Err(RunnerError::PlaywrightSessionDisconnected).map_err(NotaryServerError::from); @@ -62,15 +46,36 @@ pub async fn prompt( }, }?; - drop(frame_sessions); + debug!("waiting for prompt response"); + let response = tokio::time::timeout(std::time::Duration::from_secs(30), prompt_response_receiver) + .await + .map_err(|_| RunnerError::FrameError(crate::frame::FrameError::PromptTimeout))? + .map_err(RunnerError::from)?; Ok(Json(response)) } pub async fn prove( - State(_state): State>, - extract::Json(payload): extract::Json, + State(state): State>, + extract::Json(payload): extract::Json, ) -> Result, NotaryServerError> { debug!("Proving: {:?}", payload); + + let session_id = uuid::Uuid::parse_str(&payload.uuid).unwrap(); + + let frame_sessions = state.frame_sessions.lock().await; + let session = match frame_sessions.get(&session_id) { + Some(crate::frame::ConnectionState::Connected) => + state.sessions.lock().await.get(&session_id).unwrap().clone(), + Some(crate::frame::ConnectionState::Disconnected(_)) => { + return Err(RunnerError::PlaywrightSessionDisconnected).map_err(NotaryServerError::from); + }, + None => { + return Err(RunnerError::PlaywrightSessionNotConnected).map_err(NotaryServerError::from); + }, + }; + + session.lock().await.handle_prove(payload.proof).await.map_err(RunnerError::from)?; + Ok(Json(())) }