Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ redis = { version = "0.27", features = ["tokio-comp", "connection-manager", "jso

# SQLite for history (always available)
rusqlite = { version = "0.32", features = ["bundled"] }
url = "=2.5.8"
pyo3 = { version = "0.22", features = ["macros", "auto-initialize"], optional = true }

[dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion examples/async_openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

println!("Extracted {} memories:", result.results.len());
for r in &result.results {
println!(" - {} ({})", r.memory, r.event.to_string());
println!(" - {} ({:?})", r.memory, r.event);
}

// Search for relevant memories
Expand Down
11 changes: 8 additions & 3 deletions src/embeddings/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ impl OllamaEmbedder {
let url = url::Url::parse(&config.base_url).unwrap_or_else(|_| {
url::Url::parse("http://localhost:11434").unwrap()
});

let host = url.host_str().unwrap_or("localhost").to_string();

let scheme: String = url.scheme().try_into().unwrap_or("http").to_string();
let hostname = url.host_str().unwrap_or("localhost").to_string();
let port = url.port().unwrap_or(11434);
let host = format!("{}://{}", scheme, hostname);

let client = Ollama::new(host, port);

Expand All @@ -39,7 +41,10 @@ impl Embedder for OllamaEmbedder {
async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
let response = self
.client
.generate_embeddings(self.model.clone(), text.to_string(), None)
.generate_embeddings(ollama_rs::generation::embeddings::request::GenerateEmbeddingsRequest::new(
self.model.clone(),
ollama_rs::generation::embeddings::request::EmbeddingsInput::Single(text.to_string())
))
.await
.map_err(|e| EmbeddingError::Api(e.to_string()))?;

Expand Down
8 changes: 5 additions & 3 deletions src/llms/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ impl OllamaLLM {
url::Url::parse("http://localhost:11434").unwrap()
});

let host = url.host_str().unwrap_or("localhost").to_string();
let scheme: String = url.scheme().try_into().unwrap_or("http").to_string();
let hostname = url.host_str().unwrap_or("localhost").to_string();
let port = url.port().unwrap_or(11434);
let host = format!("{}://{}", scheme, hostname);

let client = Ollama::new(host, port);

Expand Down Expand Up @@ -72,11 +74,11 @@ impl LLM for OllamaLLM {
let temperature = options.temperature.unwrap_or(self.default_temperature);
request = request.options(
ollama_rs::generation::options::GenerationOptions::default()
.temperature(temperature as f64),
.temperature(temperature),
);

if options.json_mode {
request = request.format(ollama_rs::generation::completion::request::FormatType::Json);
request = request.format(ollama_rs::generation::parameters::FormatType::Json);
}

let response = self
Expand Down
12 changes: 5 additions & 7 deletions src/llms/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use async_openai::{
types::{
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestUserMessageArgs, ChatCompletionRequestAssistantMessageArgs,
CreateChatCompletionRequestArgs, ResponseFormat, ResponseFormatType,
CreateChatCompletionRequestArgs, ResponseFormat
},
Client,
};
Expand Down Expand Up @@ -52,19 +52,19 @@ impl OpenAILLM {
match msg.role {
Role::System => Ok(ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessageArgs::default()
.content(&msg.content)
.content(msg.content.clone())
.build()
.map_err(|e| LLMError::Api(e.to_string()))?,
)),
Role::User => Ok(ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessageArgs::default()
.content(&msg.content)
.content(msg.content.clone())
.build()
.map_err(|e| LLMError::Api(e.to_string()))?,
)),
Role::Assistant => Ok(ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessageArgs::default()
.content(&msg.content)
.content(msg.content.clone())
.build()
.map_err(|e| LLMError::Api(e.to_string()))?,
)),
Expand Down Expand Up @@ -99,9 +99,7 @@ impl LLM for OpenAILLM {
}

if options.json_mode {
request_builder.response_format(ResponseFormat {
r#type: ResponseFormatType::JsonObject,
});
request_builder.response_format(ResponseFormat::JsonObject);
}

let request = request_builder
Expand Down
6 changes: 3 additions & 3 deletions src/vector_stores/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl VectorStore for PostgresStore {
let embedding_str = Self::format_embedding(embedding);

// Build WHERE clause from filters
let mut where_clauses = Vec::new();
let where_clauses: Vec<String> = Vec::new();

if let Some(_f) = filters {
// TODO: Implement full filter translation
Expand Down Expand Up @@ -272,7 +272,7 @@ impl VectorStore for PostgresStore {
filters: Option<&Filters>,
limit: usize,
) -> Result<Vec<VectorSearchResult>, VectorStoreError> {
let mut where_clauses = Vec::new();
let where_clauses: Vec<String> = Vec::new();

if let Some(_f) = filters {
// TODO: Implement full filter translation
Expand Down Expand Up @@ -312,7 +312,7 @@ impl VectorStore for PostgresStore {
}

async fn delete_all(&self, filters: Option<&Filters>) -> Result<usize, VectorStoreError> {
let mut where_clauses = Vec::new();
let where_clauses: Vec<String> = Vec::new();

if let Some(_f) = filters {
// TODO: Implement full filter translation
Expand Down
16 changes: 7 additions & 9 deletions src/vector_stores/qdrant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

use async_trait::async_trait;
use qdrant_client::qdrant::{
points_selector::PointsSelectorOneOf, Condition, CreateCollectionBuilder, Distance, Filter,
PointId, PointStruct, PointsIdsList, PointsSelector, ScrollPointsBuilder, SearchPointsBuilder,
UpsertPointsBuilder, VectorParamsBuilder, WithPayloadSelector,
Condition, CreateCollectionBuilder, Distance, Filter,
PointId, PointStruct, PointsIdsList, ScrollPointsBuilder, SearchPointsBuilder,
UpsertPointsBuilder, VectorParamsBuilder, DeletePointsBuilder
};
use qdrant_client::Qdrant;
use std::collections::HashMap;
Expand Down Expand Up @@ -239,12 +239,10 @@ impl VectorStore for QdrantStore {
async fn delete(&self, id: &str) -> Result<(), VectorStoreError> {
self.client
.delete_points(
PointsSelector {
points_selector_one_of: Some(PointsSelectorOneOf::Points(PointsIdsList {
ids: vec![PointId::from(id.to_string())],
})),
},
Some(&self.collection_name),
DeletePointsBuilder::new(&self.collection_name)
.points(PointsIdsList {
ids: vec![PointId::from(id.to_string())],
})
)
.await
.map_err(|e| VectorStoreError::Delete(e.to_string()))?;
Expand Down