Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 48 additions & 13 deletions src/book/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::{BTreeMap, HashMap};
use std::collections::{BTreeMap, HashMap, HashSet};

use camino::Utf8PathBuf;
use glob::glob;
Expand Down Expand Up @@ -27,7 +27,15 @@ macro_rules! eval_if_in_filter {
};
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExecutionContext {
#[serde(rename = "cmdline")]
CommandLine(Vec<String>),
#[serde(rename = "platforms")]
PlatformSpecific(BTreeMap<String, Vec<String>>),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Parameter {
#[serde(rename = "type")]
pub param_type: String,
Expand All @@ -42,7 +50,7 @@ fn default_required() -> bool {
true
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Container {
#[serde(flatten)]
pub source: ContainerSource,
Expand Down Expand Up @@ -143,7 +151,7 @@ impl Container {

// TODO: add optional parsers to reduce output tokens

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Function {
pub description: String,
pub parameters: BTreeMap<String, Parameter>,
Expand All @@ -153,7 +161,7 @@ pub struct Function {
pub execution: runtime::ExecutionContext,
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Page {
#[serde(skip_serializing_if = "String::is_empty")]
#[serde(default = "String::new")]
Expand Down Expand Up @@ -185,9 +193,10 @@ impl Page {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Book {
pub pages: BTreeMap<Utf8PathBuf, Page>,
pub failed_containers: HashSet<String>,
}

impl Book {
Expand Down Expand Up @@ -313,7 +322,14 @@ impl Book {
pages.insert(page_path, page);
}

Ok(Self { pages })
Ok(Self {
pages,
failed_containers: HashSet::new(),
})
}

pub fn mark_failed_container(&mut self, func_name: String) {
self.failed_containers.insert(func_name);
}

pub fn size(&self) -> usize {
Expand All @@ -335,20 +351,36 @@ impl Book {
Err(anyhow::anyhow!("function {} not found", name))
}

pub fn as_tools<'a, T>(&'a self, filter: Option<String>) -> Vec<T>
// Modify as_tools to filter out failed functions
pub fn as_tools<T>(&self, filter: Option<String>) -> Vec<T>
where
Vec<T>: std::convert::From<&'a Page>,
for<'a> Vec<T>: From<&'a Page>,
{
let mut tools = Vec::new();

for (page_path, page) in &self.pages {
// Create filtered page in its own scope with proper ownership
let filtered_page = {
// Filter out failed functions
let filtered_functions = page.functions
.iter()
.filter(|(name, _)| !self.failed_containers.contains(*name))
.map(|(name, func)| (name.clone(), func.clone()))
.collect();

Page {
name: page.name.clone(),
description: page.description.clone(),
functions: filtered_functions,
categories: page.categories.clone(),
}
};

eval_if_in_filter!(
page_path,
filter,
tools.extend(<&Page as Into<Vec<T>>>::into(page))
tools.extend(Vec::<T>::from(&filtered_page))
);
}

tools
}
}
Expand Down Expand Up @@ -381,7 +413,10 @@ mod tests {
},
);
pages.insert(Utf8PathBuf::from("test_page"), page);
Book { pages }
Book {
pages,
failed_containers: HashSet::new(),
}
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion src/book/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl std::fmt::Display for ExecutionFlavor {
}
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExecutionContext {
#[serde(rename = "cmdline")]
CommandLine(Vec<String>),
Expand Down
37 changes: 27 additions & 10 deletions src/cli/serve.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Mutex;

use actix_cors::Cors;
use actix_web::web;
Expand All @@ -20,7 +22,7 @@ use super::ServeArgs;

struct AppState {
max_running_tasks: usize,
book: Arc<Book>,
book: Arc<Mutex<Book>>,
ssh: Option<SSHConnection>,
}

Expand All @@ -38,13 +40,13 @@ async fn serve_pages_impl(

match flavor {
Flavor::Nerve => {
Ok(HttpResponse::Ok().json(state.book.as_tools::<nerve::FunctionGroup>(filter)))
Ok(HttpResponse::Ok().json(state.book.lock().await.as_tools::<nerve::FunctionGroup>(filter)))
}
Flavor::Rigging => {
Ok(HttpResponse::Ok().json(state.book.as_tools::<rigging::Tool>(filter)))
Ok(HttpResponse::Ok().json(state.book.lock().await.as_tools::<rigging::Tool>(filter)))
}
// default to openai
_ => Ok(HttpResponse::Ok().json(state.book.as_tools::<openai::Tool>(filter))),
_ => Ok(HttpResponse::Ok().json(state.book.lock().await.as_tools::<openai::Tool>(filter))),
}
}

Expand All @@ -67,10 +69,11 @@ async fn process_calls(
state: web::Data<Arc<AppState>>,
calls: web::Json<Vec<openai::Call>>,
) -> actix_web::Result<HttpResponse> {
let book = state.book.lock().await;
match runtime::execute(
state.ssh.clone(),
false,
state.book.clone(),
Arc::new(book.clone()),
calls.0,
state.max_running_tasks,
)
Expand Down Expand Up @@ -98,16 +101,30 @@ pub(crate) async fn serve(args: ServeArgs) -> anyhow::Result<()> {
None
};

let book = Arc::new(Book::from_path(args.path, args.filter)?);
let book = Book::from_path(args.path, args.filter)?;
let book = Arc::new(Mutex::new(book));

if !args.lazy {
for page in book.pages.values() {
for (func_name, func) in page.functions.iter() {
let mut book_guard = book.lock().await;
let mut failed_containers = HashSet::new();

// First collect all failures
for (_, page) in &book_guard.pages {
for (func_name, func) in &page.functions {
if let Some(container) = &func.container {
log::info!("pre building container for function {} ...", func_name);
container.resolve().await?;
if let Err(e) = container.resolve().await {
log::error!("Failed to resolve container for function {}: {}", func_name, e);
failed_containers.insert(func_name.clone());
}
}
}
}

// Then update the failed containers
for func_name in failed_containers {
book_guard.mark_failed_container(func_name);
}
}

let max_running_tasks = if args.workers == 0 {
Expand All @@ -118,7 +135,7 @@ pub(crate) async fn serve(args: ServeArgs) -> anyhow::Result<()> {

log::info!(
"serving {} pages on http://{} with {max_running_tasks} max running tasks",
book.size(),
book.lock().await.size(),
&args.address,
);

Expand Down
3 changes: 2 additions & 1 deletion src/runtime/docker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use tokio::{
task,
};

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ContainerSource {
#[serde(rename = "image")]
Image(String),
Expand Down Expand Up @@ -88,6 +88,7 @@ pub(crate) async fn pull_image(image: &str, platform: Option<String>) -> anyhow:
],
)
.await
.map_err(|e| anyhow::anyhow!("Docker pull encountered an error: {}: {}", image, e))
}

pub(crate) async fn build_image(name: &str, path: &str) -> anyhow::Result<()> {
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,14 @@ mod tests {

use super::*;
use std::collections::BTreeMap;
use std::collections::HashSet;

fn create_test_book() -> Arc<Book> {
Arc::new(Book {
pages: BTreeMap::new(),
failed_containers: HashSet::new(),
})
}

#[tokio::test]
async fn test_execute_call() {
Expand Down Expand Up @@ -257,6 +265,7 @@ mod tests {
map.insert(camino::Utf8PathBuf::from("test_page"), mock_page);
map
},
failed_containers: HashSet::new(),
});

let result = execute_call(None, false, 10, book, call).await.unwrap();
Expand Down Expand Up @@ -327,6 +336,7 @@ mod tests {
map.insert(camino::Utf8PathBuf::from("test_page"), mock_page);
map
},
failed_containers: HashSet::new(),
});

let results = execute(None, false, book, calls, 10).await.unwrap();
Expand All @@ -352,6 +362,7 @@ mod tests {
);
map
},
failed_containers: HashSet::new(),
});

let calls = vec![openai::Call {
Expand Down Expand Up @@ -397,6 +408,7 @@ mod tests {
);
map
},
failed_containers: HashSet::new(),
});

let calls = vec![openai::Call {
Expand Down