diff --git a/Cargo.lock b/Cargo.lock index 05bd242..702ed92 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,9 +61,9 @@ checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236" [[package]] name = "async-trait" -version = "0.1.88" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", @@ -165,11 +165,15 @@ name = "bap-onest-lite" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "axum", "chrono", "chrono-tz", "config", "dashmap", + "deadpool-redis", + "futures", + "hex", "indexmap", "redis", "reqwest", @@ -177,6 +181,7 @@ dependencies = [ "serde_json", "sha2", "sqlx", + "strsim", "tokio", "tokio-cron-scheduler", "tower-http 0.5.2", @@ -460,6 +465,37 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "deadpool" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b" +dependencies = [ + "deadpool-runtime", + "lazy_static", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-redis" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0965b977f1244bc3783bb27cd79cfcff335a8341da18f79232d00504b18eb1a" +dependencies = [ + "deadpool", + "redis", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +dependencies = [ + "tokio", +] + [[package]] name = "der" version = "0.7.10" @@ -627,6 +663,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -671,6 +722,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -689,8 +751,10 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -789,6 +853,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "hex" version = "0.4.3" @@ -1348,6 +1418,16 @@ dependencies = [ "libm", ] +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "object" version = "0.36.7" @@ -2317,6 +2397,12 @@ dependencies = [ "unicode-properties", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.6.1" diff --git a/Cargo.toml b/Cargo.toml index 714306e..2f5ad91 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,3 +26,8 @@ tower-http = { version = "0.5", features = ["cors"] } tokio-cron-scheduler = "0.14.0" chrono-tz = "0.10.4" indexmap = "2.11.1" +async-trait = "0.1.89" +hex = "0.4.3" +futures = "0.3.31" +deadpool-redis = "0.22.0" +strsim = "0.11.1" diff --git a/Dockerfile b/Dockerfile index 11522d5..a7792e0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,35 +1,39 @@ # -------- Stage 1: Build -------- -FROM debian:bookworm as build-deps - -RUN apt-get update && apt-get install -y curl build-essential pkg-config libssl-dev ca-certificates - -# Install Rust -RUN curl https://sh.rustup.rs -sSf | sh -s -- -y -ENV PATH="/root/.cargo/bin:${PATH}" +FROM rust:1.87 as build-deps +# Set working directory WORKDIR /app +# Copy Cargo files and build dummy project to cache dependencies COPY Cargo.toml Cargo.lock ./ RUN mkdir src && echo "fn main() {}" > src/main.rs RUN cargo build --release || true +# Copy the full source code COPY . . + +# Build the actual binary RUN cargo build --release --bin bap-onest-lite # -------- Stage 2: Runtime -------- FROM debian:bookworm-slim +# Install runtime dependencies RUN apt-get update && apt-get install -y \ libpq5 \ libssl3 \ ca-certificates \ && rm -rf /var/lib/apt/lists/* +# Set working directory WORKDIR /app +# Copy the built binary and config COPY --from=build-deps /app/target/release/bap-onest-lite /app/bap-onest-lite COPY config ./config +# Expose the port EXPOSE 3008 -CMD ["./bap-onest-lite"] \ No newline at end of file +# Run the binary +CMD ["./bap-onest-lite"] diff --git a/README.md b/README.md index bfd57f0..b928dc4 100644 --- a/README.md +++ b/README.md @@ -30,16 +30,8 @@ cargo run -- config/local.yaml Replace `config/local.yaml` with your configuration file as needed. -## Docker -Build and run the container manually: - -```sh -docker build -t bap-onest-lite . -docker run -p 3008:3008 bap-onest-lite ./bap-onest-lite config/local.yaml -``` - -Or use Docker Compose for multi-service setup: +use Docker Compose for multi-service setup: ```sh docker compose build --no-cache @@ -57,4 +49,4 @@ You can update the config path as needed. ## Configuration -Configuration is loaded from the path you provide as the first argument. \ No newline at end of file +Configuration is loaded from the path you provide as the first argument. diff --git a/config/match_score.json b/config/match_score.json new file mode 100644 index 0000000..75c0b05 --- /dev/null +++ b/config/match_score.json @@ -0,0 +1,108 @@ +{ + "match_score": [ + { + "name": "role", + "profile_path": "/metadata/role", + "job_path": "/tags/role", + "weight": 2, + "is_array": false, + "match_mode": "embed", + "penalty": 0.6 + }, + { + "name": "industry", + "profile_path": "/metadata/industry", + "job_path": "/tags/industry", + "weight": 1, + "is_array": false, + "match_mode": "embed", + "penalty": 0.75 + }, + { + "name": "location", + "profile_path": "/metadata/whoIAm/locationData/city", + "job_path": "/tags/jobProviderLocation/city", + "weight": 1, + "is_array": false, + "match_mode": "embed", + "penalty": 0.95 + }, + { + "name": "iti_specialization", + "profile_path": "/metadata/whatIHave/itiSpecialization", + "job_path": "/tags/jobNeeds/educationSubsection/itiSpecialtyPreference", + "weight": 1, + "is_array": true, + "match_mode": "embed", + "penalty": 0.9 + }, + { + "name": "languages", + "profile_path": "/metadata/whatIHave/languagesKnown", + "job_path": "/tags/jobNeeds/languageSubsection/languageKnown", + "weight": 1, + "is_array": true, + "match_mode": "embed", + "penalty": 0.9 + }, + { + "name": "highest_qualification", + "profile_path": "/metadata/whatIHave/highestEducation", + "job_path": "/tags/jobNeeds/highestQualificationSubsection/highestQualification", + "weight": 1, + "is_array": true, + "match_mode": "embed", + "penalty": 0.9, + "bonus": 1.0, + "embedding_only": true + }, + { + "name": "software_skills", + "profile_path": "/metadata/whatIHave/softwareSkills", + "job_path": "/tags/jobNeeds/highestQualificationSubsection/softwareKnowledgePreferred", + "weight": 1, + "is_array": true, + "match_mode": "embed", + "penalty": 0.9, + "bonus": 1.0, + "embedding_only": true + }, + { + "name": "preferred_work_mode", + "profile_path": "/metadata/whatIWant/preferredWorkMode", + "job_path": "/tags/jobDetails/modeOfWork", + "weight": 1, + "is_array": true, + "match_mode": "embed", + "penalty": 0.7, + "embedding_only": true + }, + { + "name": "age", + "profile_path": "/metadata/whoIAm/age", + "job_path_min": "/tags/jobNeeds/ageAllowedLowerLimit", + "job_path_max": "/tags/jobNeeds/ageAllowedUpperLimit", + "match_mode": "manual", + "penalty": 0.7, + "bonus": 1.05 + }, + { + "name": "monthly_in_hand", + "profile_path": "/metadata/whatIWant/monthlyInHandPreferred", + "job_path_min": "/tags/jobDetails/minMonthlyInHand", + "job_path_max": "/tags/jobDetails/maxMonthlyInHand", + "match_mode": "manual", + "penalty": 0.7, + "bonus": 1.05 + }, + { + "name": "work_hours_per_day", + "profile_path": "/metadata/whatIWant/workHoursPerDay", + "job_path_min": "/tags/jobDetails/workingHoursPerDay", + "job_path_max": "/tags/jobDetails/workingHoursPerDay", + "match_mode": "manual", + "penalty": 0.8, + "bonus": 1.05 + } + ] +} diff --git a/src/config.rs b/src/config.rs index b6874ea..8249632 100644 --- a/src/config.rs +++ b/src/config.rs @@ -23,41 +23,96 @@ pub struct Bap { pub struct RedisConfig { pub url: String, } + #[derive(Debug, Serialize, Deserialize, Clone)] pub struct DbConfig { pub url: String, } + #[derive(Debug, Serialize, Deserialize, Clone)] pub struct CacheConfig { pub result_ttl_secs: u64, pub txn_ttl_secs: u64, pub throttle_secs: u64, } -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct AppConfig { - debug: bool, - pub bap: Bap, - pub http: HttpConfig, - pub redis: RedisConfig, - pub db: DbConfig, - pub cache: CacheConfig, - pub cron: CronConfig, -} #[derive(Debug, Serialize, Deserialize, Clone)] pub struct JobSchedule { pub seconds: u64, } +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct GcpConfig { + pub project_id: String, + pub model: String, + pub auth_token: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FieldWeight { + pub path: String, + pub weight: usize, + pub label: Option, + #[serde(default)] + pub is_array: bool, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct EmbeddingWeights { + pub job: Vec, + pub profile: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct MetaDataMatch { + pub name: String, + pub profile_path: String, + #[serde(default)] + pub job_path: String, + #[serde(default)] + pub job_path_min: Option, + #[serde(default)] + pub job_path_max: Option, + #[serde(default)] + pub weight: Option, + #[serde(default)] + pub is_array: bool, + pub match_mode: MatchMode, + pub penalty: f32, + #[serde(default)] + pub bonus: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "lowercase")] +pub enum MatchMode { + Embed, + Manual, +} + #[derive(Debug, Serialize, Deserialize, Clone)] pub struct CronConfig { pub fetch_jobs: JobSchedule, } + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct AppConfig { + pub debug: bool, + pub bap: Bap, + pub http: HttpConfig, + pub redis: RedisConfig, + pub db: DbConfig, + pub cache: CacheConfig, + pub cron: CronConfig, + pub gcp: GcpConfig, + pub match_score_path: String, +} + impl AppConfig { pub fn new() -> Result { let args: Vec = env::args().collect(); if args.len() < 2 { - error!("Error: Configuration path not provided. Usage: cargo run -- "); + error!("❌ Error: Configuration path not provided. Usage: cargo run -- "); process::exit(1); } let config_path = &args[1]; @@ -66,6 +121,7 @@ impl AppConfig { .add_source(File::with_name(&config_path)) .build()? .try_deserialize()?; + Ok(config) } } diff --git a/src/cron/fetch_jobs.rs b/src/cron/fetch_jobs.rs index d640c56..3f8a089 100644 --- a/src/cron/fetch_jobs.rs +++ b/src/cron/fetch_jobs.rs @@ -64,7 +64,6 @@ pub async fn run(app_state: AppState) { // Metadata to store in Redis for additional info let redis_key = format!("cron_txn:{}", txn_id); - let mut conn = app_state.redis_conn.lock().await; let metadata = serde_json::json!({ "source": "cron", "brief": false, @@ -72,28 +71,30 @@ pub async fn run(app_state: AppState) { "timestamp": Utc::now() }); - if let Err(e) = conn - .set_ex::<_, _, ()>( - &redis_key, - metadata.to_string(), - app_state.config.cache.txn_ttl_secs, - ) - .await - { - error!(target: "cron", "❌ Failed to store cron txn metadata: {:?}", e); - } + match app_state.redis_pool.get().await { + Ok(mut conn) => { + let ttl_secs = app_state.config.cache.txn_ttl_secs; - // Update a separate key to always point to the latest cron transaction - let latest_key = "cron_txn:latest"; - if let Err(e) = conn.set::<_, _, ()>(latest_key, &txn_id).await { - error!(target: "cron", "❌ Failed to store latest cron txn_id: {:?}", e); - } else { - info!(target: "cron", "✅ Updated latest cron transaction to {}", txn_id); + // Store metadata with TTL using set_ex + let res: Result<(), redis::RedisError> = conn + .set_ex(&redis_key, metadata.to_string(), ttl_secs) + .await; + + match res { + Ok(_) => info!(target: "cron", "✅ Stored cron txn metadata at key {}", redis_key), + Err(e) => error!(target: "cron", "❌ Failed to store cron txn metadata: {:?}", e), + } + } + Err(e) => { + error!(target: "cron", "❌ Failed to get Redis connection from pool: {:?}", e); + } } // Send to BAP adapter let adapter_url = format!("{}/search", app_state.config.bap.caller_uri); if let Err(e) = post_json(&adapter_url, payload).await { error!(target: "cron", "❌ Failed to send search to BAP adapter: {}", e); + } else { + info!(target: "cron", "📨 Search request sent to BAP adapter successfully"); } } diff --git a/src/http/http_server.rs b/src/http/http_server.rs index 3725f42..a3af890 100644 --- a/src/http/http_server.rs +++ b/src/http/http_server.rs @@ -4,16 +4,13 @@ use crate::{ http::routes::create_routes, state::{AppState, SharedState}, }; -use redis::Client; use sqlx::PgPool; use std::sync::Arc; -use tokio::{ - net::TcpListener, - sync::{watch, Mutex}, - task::JoinHandle, -}; +use tokio::{net::TcpListener, sync::watch, task::JoinHandle}; use tracing::info; +use deadpool_redis::{Config as RedisConfig, Runtime}; + pub async fn start_http_server( config: AppConfig, shutdown_rx: watch::Receiver<()>, @@ -27,30 +24,27 @@ pub async fn start_http_server( let shared_state = SharedState::default(); - let redis_client = Client::open(config.redis.url.as_str())?; - let redis_conn = redis_client - .get_multiplexed_async_connection() - .await - .map_err(|e| { - tracing::error!("❌ Redis connection failed: {}", e); - e - })?; + let redis_cfg = RedisConfig::from_url(config.redis.url.as_str()); + let redis_pool = redis_cfg.create_pool(Some(Runtime::Tokio1))?; // Test Redis connection { - let mut test_conn = redis_client.get_multiplexed_async_connection().await?; - let pong: String = redis::cmd("PING").query_async(&mut test_conn).await?; + let mut conn = redis_pool.get().await?; + let pong: String = redis::cmd("PING").query_async(&mut conn).await?; info!("✅ Redis PING -> {}", pong); } + // --- Postgres pool --- let db_pool = PgPool::connect(&config.db.url).await?; info!("✅ connected to db at {}", &config.db.url); + let app_state = AppState { config: Arc::new(config.clone()), shared_state, - redis_conn: Arc::new(Mutex::new(redis_conn)), + redis_pool, db_pool, }; + let _scheduler = start_cron_jobs(app_state.clone()).await; let http_server = tokio::spawn(run_http_server(listener, shutdown_rx, app_state)); diff --git a/src/models/search.rs b/src/models/search.rs index 67f28c8..f58facf 100644 --- a/src/models/search.rs +++ b/src/models/search.rs @@ -94,6 +94,7 @@ pub struct SearchRequestV2 { pub page: Option, pub limit: Option, pub primary_filters: Option, + pub profile: Option, } #[derive(Debug, Serialize, Deserialize)] diff --git a/src/services/empeding.rs b/src/services/empeding.rs new file mode 100644 index 0000000..a6861d9 --- /dev/null +++ b/src/services/empeding.rs @@ -0,0 +1,97 @@ +use anyhow::Result; +use async_trait::async_trait; +use hex; +use redis::AsyncCommands; +use reqwest; +use serde_json::Value; +use sha2::{Digest, Sha256}; +use tracing::{error, info}; + +use crate::state::AppState; + +#[async_trait] +pub trait EmbeddingService { + async fn get_embedding( + &self, + text: &str, + conn: &mut redis::aio::MultiplexedConnection, + app_state: &AppState, + ) -> Result>; +} + +pub struct GcpEmbeddingService; + +#[async_trait] +impl EmbeddingService for GcpEmbeddingService { + async fn get_embedding( + &self, + text: &str, + conn: &mut redis::aio::MultiplexedConnection, + app_state: &AppState, + ) -> Result> { + let mut hasher = Sha256::new(); + hasher.update(text.as_bytes()); + let hash = hex::encode(hasher.finalize()); + let cache_key = format!("embedding:{}:{}", app_state.config.gcp.model, hash); + info!("🔑 Cache key for embedding: {}", cache_key); + + match conn.get::<_, Option>(&cache_key).await { + Ok(Some(cached)) => { + if let Ok(vec) = serde_json::from_str::>(&cached) { + info!("✅ Cache hit for text: {}", text); + return Ok(vec); + } + } + Ok(None) => info!("❌ Cache miss for text: {}", text), + Err(e) => error!("❌ Redis get error for key {}: {:?}", cache_key, e), + } + + info!("🚀 Fetching embedding from GCP for text: {}", text); + + let url = format!( + "https://generativelanguage.googleapis.com/v1beta/models/{}:embedContent", + app_state.config.gcp.model + ); + + let body = serde_json::json!({ + "model": format!("models/{}", app_state.config.gcp.model), + "content": { "parts": [{ "text": text }] } + }); + + info!("🔍 Calling GCP Embedding API for text: {}", text); + + let client = reqwest::Client::new(); + let resp = client + .post(&url) + .header("x-goog-api-key", &app_state.config.gcp.auth_token) + .json(&body) + .send() + .await?; + + info!("Embedding API response status: {}", resp.status()); + + let json_resp: Value = resp.json().await?; + + let embedding_values = json_resp + .get("embedding") + .and_then(|e| e.get("values")) + .and_then(|v| v.as_array()) + .ok_or_else(|| anyhow::anyhow!("Missing embedding in response"))?; + + let embedding: Vec = embedding_values + .iter() + .map(|v| v.as_f64().unwrap_or(0.0) as f32) + .collect(); + + if let Err(e) = conn + .set::<_, _, ()>(&cache_key, serde_json::to_string(&embedding)?) + .await + { + error!("❌ Failed to cache embedding in Redis: {:?}", e); + } else { + info!("🚀 Cached embedding for text (hashed key): {}", hash); + } + + Ok(embedding) + } +} diff --git a/src/services/mod.rs b/src/services/mod.rs index 9883a53..01ffec6 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -1,3 +1,4 @@ +pub mod empeding; pub mod job_apply; pub mod job_draft; pub mod payload_generator; diff --git a/src/services/search.rs b/src/services/search.rs index ba83f05..610a73e 100644 --- a/src/services/search.rs +++ b/src/services/search.rs @@ -1,16 +1,21 @@ use crate::models::search::SearchRequestV2; use crate::models::webhook::{Ack, AckResponse, AckStatus, WebhookPayload}; +use crate::services::empeding::{EmbeddingService, GcpEmbeddingService}; use crate::{ models::search::SearchRequest, services::payload_generator::build_beckn_payload, state::AppState, - utils::{hash::generate_query_hash, http_client::post_json, search::matches_query_dynamic}, + utils::{ + empeding::{compute_match_score, job_text_for_embedding, profile_text_for_embedding}, + hash::generate_query_hash, + http_client::post_json, + search::matches_query_dynamic, + }, }; use axum::{extract::State, http::StatusCode, Json}; -use indexmap::IndexMap; use redis::AsyncCommands; use serde_json::{json, Value as JsonValue}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::time::Instant; use tracing::{error, event, info, Level}; use uuid::Uuid; @@ -18,76 +23,62 @@ use uuid::Uuid; pub async fn handle_search( State(app_state): State, Json(req): Json, -) -> Result, (StatusCode, Json)> { +) -> Result, (StatusCode, Json)> { let start = Instant::now(); let message_id = format!("msg-{}", Uuid::new_v4()); let txn_id = format!("txn-{}", Uuid::new_v4()); let query_hash = generate_query_hash(&req.message); - let pattern = format!("search:{}:*", query_hash); info!("Looking for Redis keys with pattern: {}", pattern); // --- Get cached search results --- - let mut all_keys = { - let mut conn = app_state.redis_conn.lock().await; - - let mut stream = conn.scan_match::<_, String>(&pattern).await.map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": "Failed to scan Redis", - "details": e.to_string() - })), - ) - })?; - - let mut keys = vec![]; - while let Some(k) = stream.next_item().await { - keys.push(k); - } - keys - }; - - all_keys.sort(); - - info!("Matched Redis keys: {:?}", all_keys); - - let cached_results = { - let mut conn = app_state.redis_conn.lock().await; - let mut results = vec![]; - - for key in &all_keys { - match conn.get::<_, String>(key).await { - Ok(value) => { - if let Ok(json_value) = serde_json::from_str::(&value) { - results.push(json_value); - } else { - error!("Failed to parse cached value for key: {}", key); - } + let cached_results = match app_state.redis_pool.get().await { + Ok(mut conn) => { + let mut stream = conn + .scan_match::(pattern.clone()) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": "Failed to scan Redis", + "details": e.to_string() + })), + ) + })?; + + let mut keys = vec![]; + while let Some(key) = stream.next_item().await { + keys.push(key); + } + drop(stream); + + let mut results = vec![]; + for key in keys { + match conn.get::(key.clone()).await { + Ok(value) => match serde_json::from_str::(&value) { + Ok(json_value) => results.push(json_value), + Err(_) => error!("Failed to parse cached value for key: {}", key), + }, + Err(e) => error!("Redis get error for key {}: {}", key, e), } - Err(e) => error!("Redis get error for key {}: {}", key, e), } + results + } + Err(e) => { + error!("Failed to get Redis connection from pool: {:?}", e); + vec![] } - - results }; // --- Cache txn_id -> query_hash for on_search mapping --- - { - let mut conn = app_state.redis_conn.lock().await; + if let Ok(mut conn) = app_state.redis_pool.get().await { let txn_key = format!("txn_to_query:{}", txn_id); - conn.set_ex::<_, _, ()>(&txn_key, &query_hash, app_state.config.cache.txn_ttl_secs) + let _: () = conn + .set_ex::<_, _, ()>(&txn_key, &query_hash, app_state.config.cache.txn_ttl_secs) .await - .map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": "Failed to cache txn_id", - "details": e.to_string() - })), - ) - })?; + .unwrap_or_else(|e| error!("Failed to cache txn_id: {:?}", e)); } let config = app_state.config.clone(); @@ -102,35 +93,37 @@ pub async fn handle_search( ); let adapter_url = format!("{}/search", config.bap.caller_uri); - // --- Throttle BAP calls (dynamic skip time) --- - let should_call_bap = { - let mut conn = app_state.redis_conn.lock().await; - let last_call_key = format!("last_call:{}", query_hash); - - match conn.exists::<_, bool>(&last_call_key).await { - Ok(exists) if exists => { - let secs = app_state.config.cache.throttle_secs; - if secs % 60 == 0 { - info!( - ": Skipping BAP call (already called within last {} min)", - secs / 60 - ); - } else { + // --- Throttle BAP calls --- + let should_call_bap = match app_state.redis_pool.get().await { + Ok(mut conn) => { + let last_call_key = format!("last_call:{}", query_hash); + match conn.exists::<_, bool>(&last_call_key).await { + Ok(exists) if exists => { + let secs = app_state.config.cache.throttle_secs; info!( - ": Skipping BAP call (already called within last {} secs)", - secs + ": Skipping BAP call (already called within last {} {})", + if secs % 60 == 0 { secs / 60 } else { secs }, + if secs % 60 == 0 { "min" } else { "secs" } ); + false + } + _ => { + let _: () = conn + .set_ex::<_, _, ()>( + &last_call_key, + "1", + app_state.config.cache.throttle_secs, + ) + .await + .unwrap_or_default(); + true } - false - } - _ => { - let _: () = conn - .set_ex(&last_call_key, "1", app_state.config.cache.throttle_secs) - .await - .unwrap_or_default(); - true } } + Err(e) => { + error!("Failed to get Redis connection for throttle check: {:?}", e); + true + } }; if should_call_bap { @@ -138,9 +131,10 @@ pub async fn handle_search( ": Sending search request to BAP adapter at: {}", adapter_url ); + let payload_clone = payload.clone(); tokio::spawn(async move { - if let Err(e) = post_json(&adapter_url, payload).await { - error!(":x: Failed to send search to BAP adapter: {}", e); + if let Err(e) = post_json(&adapter_url, payload_clone).await { + error!("❌ Failed to send search to BAP adapter: {}", e); } }); } @@ -155,12 +149,9 @@ pub async fn handle_search( duration_ms = %elapsed.as_millis(), "API timing(search)" ); - // --- Return cached results if available --- if !cached_results.is_empty() { - return Ok(Json(serde_json::json!({ - "results": cached_results - }))); + return Ok(Json(serde_json::json!({ "results": cached_results }))); } Ok(Json(serde_json::json!([]))) @@ -175,39 +166,46 @@ pub async fn handle_on_search( return handle_cron_on_search(app_state, payload, txn_id).await; } - let mut conn = app_state.redis_conn.lock().await; - let txn_key = format!("txn_to_query:{}", txn_id); - - match conn.get::<_, String>(&txn_key).await { - Ok(query_hash) => match &payload.context.bpp_id { - Some(bpp_id) => { - let redis_key = format!("search:{}:{}", query_hash, bpp_id); - match serde_json::to_string(payload) { - Ok(data) => { - if let Err(e) = conn - .set_ex::<_, _, ()>( - &redis_key, - data, - app_state.config.cache.result_ttl_secs, - ) - .await - { - info!("❌ Failed to store in Redis: {:?}", e); - } else { - info!("✅ Stored response at key: {}", redis_key); + // --- Get a Redis connection from the pool --- + match app_state.redis_pool.get().await { + Ok(mut conn) => { + let txn_key = format!("txn_to_query:{}", txn_id); + + match conn.get::<_, String>(&txn_key).await { + Ok(query_hash) => match &payload.context.bpp_id { + Some(bpp_id) => { + let redis_key = format!("search:{}:{}", query_hash, bpp_id); + match serde_json::to_string(payload) { + Ok(data) => { + if let Err(e) = conn + .set_ex::<_, _, ()>( + &redis_key, + data, + app_state.config.cache.result_ttl_secs, + ) + .await + { + info!("❌ Failed to store in Redis: {:?}", e); + } else { + info!("✅ Stored response at key: {}", redis_key); + } + } + Err(e) => { + info!("❌ Failed to serialize payload: {:?}", e); + } } } - Err(e) => { - info!("❌ Failed to serialize payload: {:?}", e); + None => { + info!("⚠️ No bpp_id found in payload, skipping Redis cache"); } + }, + Err(_) => { + info!("❌ No query_hash found for txn_id = {}", txn_id); } } - None => { - info!("⚠️ No bpp_id found in payload, skipping Redis cache"); - } - }, - Err(_) => { - info!("❌ No query_hash found for txn_id = {}", txn_id); + } + Err(e) => { + error!("❌ Failed to get Redis connection from pool: {:?}", e); } } @@ -225,7 +223,20 @@ pub async fn handle_cron_on_search( ) -> Json { info!(target: "cron", "📦 Handling cron on_search for txn_id={}", txn_id); - let mut conn = app_state.redis_conn.lock().await; + let mut conn = match app_state.redis_pool.get().await { + Ok(c) => c, + Err(e) => { + error!(target: "cron", "❌ Failed to get Redis connection: {:?}", e); + return Json(AckResponse { + message: AckStatus { + ack: Ack { status: "ACK" }, + }, + }); + } + }; + + // Create embedding service instance + let embedding_service = GcpEmbeddingService; if let Some(bpp_id) = &payload.context.bpp_id { let redis_key = format!("cron_jobs:{}:{}", txn_id, bpp_id); @@ -238,20 +249,166 @@ pub async fn handle_cron_on_search( _ => serde_json::to_value(payload).unwrap(), }; - // Append new providers to existing ones + // Append or update new providers if let Some(new_providers) = payload .message .get("catalog") .and_then(|c| c.get("providers")) .and_then(|p| p.as_array()) { - store_data + let existing_providers = store_data .pointer_mut("/message/catalog/providers") - .and_then(|existing_providers| existing_providers.as_array_mut()) - .map(|arr| arr.extend(new_providers.clone())); + .and_then(|p| p.as_array_mut()); + + if let Some(existing) = existing_providers { + // Build a map of existing providers by jobProviderName + let mut provider_index_map = std::collections::HashMap::new(); + for (i, provider) in existing.iter().enumerate() { + if let Some(name) = provider + .pointer("/items/0/tags/basicInfo/jobProviderName") + .and_then(|v| v.as_str()) + { + provider_index_map.insert(name.to_string(), i); + } + } + + for mut provider in new_providers.clone() { + let provider_name = provider + .pointer("/items/0/tags/basicInfo/jobProviderName") + .and_then(|v| v.as_str()) + .unwrap_or("Unknown Provider") + .to_string(); + + if let Some(items) = provider.get_mut("items").and_then(|j| j.as_array_mut()) { + for job in items.iter_mut() { + let job_id = job + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("Unknown Job") + .to_string(); + + let text = job_text_for_embedding(job, &app_state.config); + + if text.trim().is_empty() { + info!( + target: "cron", + "⚠️ Skipping embedding: provider='{}', job_id='{}', reason='empty text'", + provider_name, + job_id + ); + continue; + } + + info!( + target: "cron", + "🔹 Generating embedding: provider='{}', job_id='{}', text_len={}", + provider_name, + job_id, + text.len() + ); + + match embedding_service + .get_embedding(&text, &mut conn, app_state) + .await + { + Ok(embedding) => { + let embedding_len = embedding.len(); + + if let Some(obj) = job.as_object_mut() { + obj.insert( + "embedding".to_string(), + serde_json::json!(embedding), + ); + } + + let is_stored = job.get("embedding").is_some(); + info!( + target: "cron", + "✅ Embedding stored: provider= {}, job_id='{}', embedding_len={}, stored_in_job={}", + provider_name, + job_id, + embedding_len, + is_stored + ); + } + Err(e) => { + error!( + target: "cron", + "❌ Failed embedding: provider='{}', job_id='{}', error={:?}", + provider_name, + job_id, + e + ); + } + } + } + } else { + info!( + target: "cron", + "⚠️ No items found for provider='{}', skipping embedding", + provider_name + ); + } + + // Insert or update provider + if let Some(&idx) = provider_index_map.get(&provider_name) { + // Merge items for existing provider instead of replacing the whole thing + if let Some(existing_items) = existing[idx] + .get_mut("items") + .and_then(|v| v.as_array_mut()) + { + if let Some(new_items) = + provider.get("items").and_then(|v| v.as_array()) + { + // Build set of existing job_ids to avoid duplicates + let existing_job_ids: std::collections::HashSet = + existing_items + .iter() + .filter_map(|job| { + job.get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .collect(); + + for new_job in new_items { + if let Some(job_id) = new_job.get("id").and_then(|v| v.as_str()) + { + if !existing_job_ids.contains(job_id) { + existing_items.push(new_job.clone()); + } else { + // Optional: update embedding if exists + if let Some(existing_job) = + existing_items.iter_mut().find(|j| { + j.get("id").and_then(|v| v.as_str()) + == Some(job_id) + }) + { + if let Some(new_embedding) = + new_job.get("embedding") + { + if let Some(obj) = existing_job.as_object_mut() + { + obj.insert( + "embedding".to_string(), + new_embedding.clone(), + ); + } + } + } + } + } + } + } + } + } else { + existing.push(provider); + } + } + } } - // Get pagination info from stored data + // --- Pagination info --- let (current_page, limit, total_count) = { let pagination = store_data .pointer("/message/pagination") @@ -278,13 +435,11 @@ pub async fn handle_cron_on_search( (page, limit, total_count) }; + info!( target: "cron", "📄 Pagination status for BPP {}: current_page = {} limit = {} total_count = {}", - bpp_id, - current_page, - limit, - total_count + bpp_id, current_page, limit, total_count ); // Store back to Redis with TTL @@ -304,12 +459,9 @@ pub async fn handle_cron_on_search( info!( target: "cron", "🔄 More pages to fetch: current_page = {} total_count = {} → requesting next_page = {}", - current_page, - total_count, - next_page + current_page, total_count, next_page ); - // Build intent for next page let mut intent = payload .message .get("intent") @@ -329,23 +481,19 @@ pub async fn handle_cron_on_search( } ] }); - // Build final message + let message = json!({ "intent": intent, "pagination": { "page": next_page, "limit": limit }, - "options": { - "brief": false - } + "options": { "brief": false } }); - // Update Redis with next_page prevent duplicate calls store_data.pointer_mut("/message/pagination").map(|p| { p["page"] = json!(next_page); }); - if let Err(e) = conn .set_ex::<_, String, ()>(&redis_key, store_data.to_string(), ttl_secs) .await @@ -369,19 +517,23 @@ pub async fn handle_cron_on_search( error!( target: "cron", "❌ Failed to request next_page = {} (txn_id={}): {}", - next_page, - txn_id, - e + next_page, txn_id, e ); } else { info!( target: "cron", "📨 Successfully requested next_page = {} for txn_id={}", - next_page, - txn_id + next_page, txn_id ); } } else { + let latest_key = "cron_txn:latest"; + if let Err(e) = conn.set::<_, _, ()>(latest_key, &txn_id).await { + error!(target: "cron", "❌ Failed to store latest cron txn_id: {:?}", e); + } else { + info!(target: "cron", "✅ Updated latest cron transaction to {}", txn_id); + } + info!(target: "cron", "✅ All pages fetched for txn_id={}", txn_id); info!(target: "cron", "╔════════════════════════════════════════════╗"); info!(target: "cron", "║ ✅ Finished fetch jobs cron. ║"); @@ -402,9 +554,21 @@ pub async fn handle_search_v2( State(app_state): State, Json(req): Json, ) -> Result, (StatusCode, Json)> { - let mut conn = app_state.redis_conn.lock().await; + let mut conn = match app_state.redis_pool.get().await { + Ok(c) => c, + Err(e) => { + error!("❌ Failed to get Redis connection: {:?}", e); + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": "Failed to connect to Redis" })), + )); + } + }; + // ✅ Initialize string similarity cache + + let mut string_sim_cache: HashMap<(String, String), f32> = HashMap::new(); - // 👉 Get latest txn_id + // ✅ Get latest txn_id let latest_key = "cron_txn:latest"; let txn_id: String = match conn.get(latest_key).await { Ok(Some(val)) => val, @@ -416,7 +580,7 @@ pub async fn handle_search_v2( } }; - // Fetch all BPP results for this txn_id + // ✅Fetch all BPP results for this txn_id let pattern = format!("cron_jobs:{}:*", txn_id); let keys: Vec = conn.keys(&pattern).await.map_err(|e| { ( @@ -427,10 +591,6 @@ pub async fn handle_search_v2( let page = req.page.unwrap_or(1) as usize; let limit = req.limit.unwrap_or(10) as usize; - - let mut seen_ids = HashSet::new(); - let mut flat_items = vec![]; - let provider_filter = req.provider.as_ref().map(|s| s.to_lowercase()); let role_filters: Vec = req .role @@ -441,13 +601,31 @@ pub async fn handle_search_v2( let primary_filters: Vec = req .primary_filters .as_ref() - .map(|r| { - r.split(',') - .map(|s| s.trim().to_lowercase()) - .collect::>() - }) + .map(|r| r.split(',').map(|s| s.trim().to_lowercase()).collect()) .unwrap_or_default(); + // ✅ Compute embedding for profile + let profile_embedding: Option> = if let Some(profile) = &req.profile { + let profile_text = profile_text_for_embedding(profile, &app_state.config); + info!("Profile text for embedding: {}", profile_text); + + match GcpEmbeddingService + .get_embedding(&profile_text, &mut conn, &app_state) + .await + { + Ok(vec) => Some(vec), + Err(e) => { + error!("Failed to get embedding: {:?}", e); + None + } + } + } else { + None + }; + + let mut seen_ids = HashSet::new(); + let mut flat_items = Vec::new(); + for key in keys { if let Ok(Some(payload_str)) = conn.get::<_, Option>(&key).await { if let Ok(payload_json) = serde_json::from_str::(&payload_str) { @@ -463,7 +641,7 @@ pub async fn handle_search_v2( .unwrap_or("") .to_lowercase(); - // provider filter + // Provider filter if let Some(ref pf) = provider_filter { if !provider_name.contains(pf) { continue; @@ -482,49 +660,81 @@ pub async fn handle_search_v2( let item_roles: Vec<&str> = role_name.split(',').map(|s| s.trim()).collect(); - let mut match_item = true; - - // primary_filter - if !primary_filters.is_empty() { - if !primary_filters.iter().any(|pf| role_name.contains(pf)) { - continue; - } + // Filters + if !primary_filters.is_empty() + && !primary_filters.iter().any(|pf| role_name.contains(pf)) + { + continue; } - // role filter - if !role_filters.is_empty() { - if !role_filters + if !role_filters.is_empty() + && !role_filters .iter() .any(|rf| item_roles.iter().any(|r| r.contains(rf))) - { - match_item = false; - } + { + continue; } - // query filter if let Some(ref qf) = query_filter { if !matches_query_dynamic(&provider_name, item, qf) { - match_item = false; + continue; } } - if match_item { - let id_key = item - .get("id") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - .unwrap_or_else(|| { - serde_json::to_string(item).unwrap_or_default() - }); - - if seen_ids.insert(id_key) { - flat_items.push(( - payload_json["context"].clone(), - provider.clone(), - item.clone(), - )); + // ✅ Compute match_score + let mut match_score = 0u8; + if let Some(ref profile_emb) = profile_embedding { + if let Some(embedding_json) = item.get("embedding") { + if let Ok(job_emb) = serde_json::from_value::>( + embedding_json.clone(), + ) { + // fix: avoid temporary drop + let empty_json = serde_json::json!({}); + let profile_meta = + req.profile.as_ref().unwrap_or(&empty_json); + let profile_norm = profile_embedding + .as_ref() + .map(|v| { + v.iter().map(|x| x * x).sum::().sqrt() + }) + .unwrap_or(0.0); + + let score = compute_match_score( + profile_emb, + profile_norm, + &job_emb, + job_emb.iter().map(|x| x * x).sum::().sqrt(), // job norm + profile_meta, + &item, + &app_state.config, + &mut string_sim_cache, + ); + + match_score = (score * 10.0).round() as u8; + } } } + + // ✅ Prepare cleaned item + let mut item_obj = item.as_object().cloned().unwrap_or_default(); + item_obj.remove("embedding"); + item_obj.insert("match_score".to_string(), json!(match_score)); + + let id_key = item + .get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| { + serde_json::to_string(item).unwrap_or_default() + }); + + if seen_ids.insert(id_key) { + flat_items.push(( + payload_json["context"].clone(), + provider.clone(), + json!(item_obj), + )); + } } } } @@ -533,69 +743,56 @@ pub async fn handle_search_v2( } } - let total_count = flat_items.len(); + // ✅ Global sort by match_score DESC (ensure correct ordering) + if profile_embedding.is_some() { + flat_items.sort_by(|(_, _, a), (_, _, b)| { + let sa = a.get("match_score").and_then(|v| v.as_u64()).unwrap_or(0); + let sb = b.get("match_score").and_then(|v| v.as_u64()).unwrap_or(0); + sb.cmp(&sa) // descending + }); + } + // ✅ Pagination after sorting + let total_count = flat_items.len(); let start = (page - 1) * limit; - let paginated_items = flat_items - .into_iter() - .skip(start) - .take(limit) - .collect::>(); - - // Group back into payload → providers → items - let mut results_map: IndexMap)>> = - IndexMap::new(); - - for (context, provider, item) in paginated_items { - let provider_descriptor = provider["descriptor"].clone(); - let provider_id = provider.get("id").cloned().unwrap_or(json!(null)); - let provider_fulfillments = provider.get("fulfillments").cloned().unwrap_or(json!([])); - let provider_locations = provider.get("locations").cloned().unwrap_or(json!([])); - - let key = serde_json::to_string(&provider_descriptor).unwrap_or_default(); - - results_map - .entry(context.clone()) - .or_default() - .entry(key) - .and_modify(|(_, items)| items.push(item.clone())) - .or_insert_with(|| { - ( - json!({ - "descriptor": provider_descriptor, - "id": provider_id, - "fulfillments": provider_fulfillments, - "locations": provider_locations, - }), - vec![item.clone()], - ) - }); + if start >= total_count { + return Ok(Json(json!({ + "pagination": { + "page": page, + "limit": limit, + "totalCount": total_count + }, + "results": [] + }))); } - let mut results = vec![]; + let end = std::cmp::min(start + limit, total_count); + let paginated_items = flat_items[start..end].to_vec(); - for (context, providers_map) in results_map { - let mut payload = json!({ - "context": context, - "message": { - "catalog": { - "providers": [] + // ✅ Rebuild ONDC-compatible response + let results: Vec = paginated_items + .into_iter() + .map(|(context, provider, item)| { + json!({ + "context": context, + "message": { + "catalog": { + "providers": [ + { + "descriptor": provider["descriptor"].clone(), + "id": provider.get("id").cloned().unwrap_or(json!(null)), + "fulfillments": provider.get("fulfillments").cloned().unwrap_or(json!([])), + "locations": provider.get("locations").cloned().unwrap_or(json!([])), + "items": [item] + } + ] + } } - } - }); - - let providers_arr = providers_map - .into_iter() - .map(|(_, (mut provider_obj, items))| { - provider_obj["items"] = json!(items); - provider_obj }) - .collect::>(); - - payload["message"]["catalog"]["providers"] = json!(providers_arr); - results.push(payload); - } + }) + .collect(); + // ✅ Final response let response = json!({ "pagination": { "page": page, diff --git a/src/state.rs b/src/state.rs index 2845f8b..916d6fb 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,18 +1,18 @@ use dashmap::DashMap; use std::sync::Arc; -use tokio::sync::{oneshot, Mutex}; +use tokio::sync::oneshot; pub type OnSearchResponse = serde_json::Value; use crate::config::AppConfig; -use redis::aio::MultiplexedConnection; +use deadpool_redis::Pool; use sqlx::PgPool; #[derive(Clone)] pub struct AppState { pub config: Arc, pub shared_state: SharedState, - pub redis_conn: Arc>, + pub redis_pool: Pool, pub db_pool: PgPool, } diff --git a/src/utils/empeding.rs b/src/utils/empeding.rs new file mode 100644 index 0000000..ad36155 --- /dev/null +++ b/src/utils/empeding.rs @@ -0,0 +1,208 @@ +use crate::config::AppConfig; +use crate::config::MetaDataMatch; +use serde_json::Value; +use std::collections::HashMap; +use std::fs; +use strsim::jaro_winkler; +use tracing::info; + +pub fn cached_jaro( + profile_str: &str, + job_str: &str, + cache: &mut HashMap<(String, String), f32>, +) -> f32 { + let key = (profile_str.to_string(), job_str.to_string()); + if let Some(&sim) = cache.get(&key) { + return sim; + } + let sim = jaro_winkler(profile_str, job_str) as f32; + cache.insert(key, sim); + sim +} + +fn weighted_push(parts: &mut Vec, text: &str, weight: usize) { + for _ in 0..weight { + parts.push(text.to_string()); + } +} + +fn load_match_score_config(path: &str) -> Vec { + let data = fs::read_to_string(path).expect("Failed to read match_score.json"); + + #[derive(serde::Deserialize)] + struct Wrapper { + match_score: Vec, + } + + let wrapper: Wrapper = serde_json::from_str(&data).expect("Failed to parse match_score.json"); + wrapper.match_score +} + +pub fn profile_text_for_embedding(profile: &Value, config: &AppConfig) -> String { + let mut parts = Vec::new(); + let match_score = load_match_score_config(config.match_score_path.as_str()); + + for field in &match_score { + if let crate::config::MatchMode::Embed = field.match_mode { + if let Some(value) = profile.pointer(&field.profile_path) { + let weight = field.weight.unwrap_or(1); + if field.is_array { + if let Some(arr) = value.as_array() { + for v in arr { + if let Some(s) = v.as_str() { + weighted_push(&mut parts, s, weight); + } + } + } + } else if let Some(s) = value.as_str() { + weighted_push(&mut parts, s, weight); + } + } + } + } + + parts.join(" ") +} + +pub fn job_text_for_embedding(job: &Value, config: &AppConfig) -> String { + let match_score = load_match_score_config(config.match_score_path.as_str()); + + let mut parts = Vec::new(); + + for field in &match_score { + if let crate::config::MatchMode::Embed = field.match_mode { + if let Some(value) = job.pointer(&field.job_path) { + let weight = field.weight.unwrap_or(1); + if field.is_array { + if let Some(arr) = value.as_array() { + for v in arr { + if let Some(s) = v.as_str() { + weighted_push(&mut parts, s, weight); + } + } + } + } else if let Some(s) = value.as_str() { + weighted_push(&mut parts, s, weight); + } + } + } + } + + parts.join(" ") +} + +pub fn cosine_similarity_with_norm(vec_a: &[f32], vec_b: &[f32], norm_a: f32, norm_b: f32) -> f32 { + if vec_a.len() != vec_b.len() || vec_a.is_empty() || norm_a == 0.0 || norm_b == 0.0 { + return 0.0; + } + let dot_product: f32 = vec_a.iter().zip(vec_b.iter()).map(|(a, b)| a * b).sum(); + dot_product / (norm_a * norm_b) +} + +/// Compute final match score combining embedding cosine and manual numeric fields +pub fn compute_match_score( + profile_emb: &[f32], + profile_norm: f32, + job_emb: &[f32], + job_norm: f32, + profile_meta: &Value, + job_meta: &Value, + config: &AppConfig, + string_sim_cache: &mut HashMap<(String, String), f32>, +) -> f32 { + info!("🔍 Computing match score..."); + let match_score = load_match_score_config(config.match_score_path.as_str()); + + // Base cosine similarity using precomputed norms + let mut score = cosine_similarity_with_norm(profile_emb, job_emb, profile_norm, job_norm); + let base_score = score; + info!("🧮 Base cosine similarity score: {:.4}", base_score); + + let mut mismatches = 0; + + for field in &match_score { + let profile_val = profile_meta.pointer(&field.profile_path); + let job_val = job_meta.pointer(&field.job_path); + + if job_val.is_some() && (profile_val.is_none() || profile_val == Some(&Value::Null)) { + score *= field.penalty; + mismatches += 1; + info!( + "⚠️ {} present in job but missing in profile → applied penalty {:.2}, score now {:.4}", + field.name, field.penalty, score + ); + } + + match field.match_mode { + crate::config::MatchMode::Embed => { + if field.name == "role" || field.name == "industry" { + let profile_str = profile_val.and_then(|v| v.as_str()).unwrap_or_default(); + let job_str = job_val.and_then(|v| v.as_str()).unwrap_or_default(); + + if !profile_str.is_empty() && !job_str.is_empty() { + let sim = cached_jaro(profile_str, job_str, string_sim_cache); + if sim < 0.8 { + score *= field.penalty; + mismatches += 1; + info!( + "⚠️ {} similarity low ({:.2}) → applied penalty {:.2}, score now {:.4}", + field.name, sim, field.penalty, score + ); + } else { + info!( + "✅ {} similarity good ({:.2}) → no penalty applied", + field.name, sim + ); + } + } + } + } + crate::config::MatchMode::Manual => { + let job_min = field + .job_path_min + .as_ref() + .and_then(|p| job_meta.pointer(p)); + let job_max = field + .job_path_max + .as_ref() + .and_then(|p| job_meta.pointer(p)); + + if let (Some(profile_val), Some(job_min), Some(job_max)) = + (profile_val, job_min, job_max) + { + if let (Some(p), Some(min), Some(max)) = + (profile_val.as_f64(), job_min.as_f64(), job_max.as_f64()) + { + if p < min || p > max { + score *= field.penalty; + mismatches += 1; + info!( + "⚠️ {} out of range ({} not in [{}, {}]) → applied penalty {:.2}, score now {:.4}", + field.name, p, min, max, field.penalty, score + ); + } else if let Some(bonus) = field.bonus { + score *= bonus; + info!( + "✅ {} in range ({} in [{}, {}]) → applied bonus {:.2}, score now {:.4}", + field.name, p, min, max, bonus, score + ); + } + } + } + } + } + } + + match mismatches { + 2 => score *= 0.85, + 3..=usize::MAX => score *= 0.7, + _ => {} + } + + if score.is_nan() { + score = 0.0; + info!("🚫 NaN detected — setting score to 0.0"); + } + + score.clamp(0.0, 1.0) +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index a1f567a..fe090b6 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,3 +1,4 @@ +pub mod empeding; pub mod hash; pub mod http_client; pub mod logging;