diff --git a/.changeset/add-library-crate.md b/.changeset/add-library-crate.md new file mode 100644 index 0000000..3845f7e --- /dev/null +++ b/.changeset/add-library-crate.md @@ -0,0 +1,5 @@ +--- +"@googleworkspace/cli": minor +--- + +Expose library crate (`lib.rs`) for programmatic API access. Extracts `config_dir()` and Model Armor sanitization types into standalone modules so they can be shared between the binary and library targets without pulling in CLI-only code. diff --git a/.changeset/sanitize-ssrf-validation.md b/.changeset/sanitize-ssrf-validation.md new file mode 100644 index 0000000..d46902a --- /dev/null +++ b/.changeset/sanitize-ssrf-validation.md @@ -0,0 +1,5 @@ +--- +"@googleworkspace/cli": patch +--- + +Validate sanitize template parameter against path traversal and SSRF before constructing API URLs diff --git a/Cargo.toml b/Cargo.toml index f42163f..2ca2e43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,10 @@ authors = ["Justin Poehnelt"] keywords = ["cli", "google-workspace", "google", "drive", "gmail"] categories = ["command-line-utilities", "web-programming"] +[lib] +name = "gws" +path = "src/lib.rs" + [[bin]] name = "gws" path = "src/main.rs" diff --git a/src/auth.rs b/src/auth.rs index ec57ac2..1b1457e 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -78,7 +78,7 @@ pub async fn get_token(scopes: &[&str]) -> anyhow::Result { } let creds_file = std::env::var("GOOGLE_WORKSPACE_CLI_CREDENTIALS_FILE").ok(); - let config_dir = crate::auth_commands::config_dir(); + let config_dir = crate::config::config_dir(); let enc_path = credential_store::encrypted_credentials_path(); let default_path = config_dir.join("credentials.json"); let token_cache = config_dir.join("token_cache.json"); diff --git a/src/auth_commands.rs b/src/auth_commands.rs index 47d2d4e..cc469f8 100644 --- a/src/auth_commands.rs +++ b/src/auth_commands.rs @@ -92,29 +92,7 @@ const READONLY_SCOPES: &[&str] = &[ ]; pub fn config_dir() -> PathBuf { - if let Ok(dir) = std::env::var("GOOGLE_WORKSPACE_CLI_CONFIG_DIR") { - return PathBuf::from(dir); - } - - // Use ~/.config/gws on all platforms for a consistent, user-friendly path. - let primary = dirs::home_dir() - .unwrap_or_else(|| PathBuf::from(".")) - .join(".config") - .join("gws"); - if primary.exists() { - return primary; - } - - // Backward compat: fall back to OS-specific config dir for existing installs - // (e.g. ~/Library/Application Support/gws on macOS, %APPDATA%\gws on Windows). - let legacy = dirs::config_dir() - .unwrap_or_else(|| PathBuf::from(".")) - .join("gws"); - if legacy.exists() { - return legacy; - } - - primary + crate::config::config_dir() } fn plain_credentials_path() -> PathBuf { diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..3981980 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,48 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::path::PathBuf; + +/// Returns the gws configuration directory. +/// +/// Prefers `~/.config/gws` for a consistent, cross-platform path. +/// Falls back to the OS-specific config directory (e.g. `~/Library/Application Support/gws` +/// on macOS) for backward compatibility with existing installs. +/// +/// The `GOOGLE_WORKSPACE_CLI_CONFIG_DIR` environment variable overrides the default. +pub fn config_dir() -> PathBuf { + if let Ok(dir) = std::env::var("GOOGLE_WORKSPACE_CLI_CONFIG_DIR") { + return PathBuf::from(dir); + } + + // Use ~/.config/gws on all platforms for a consistent, user-friendly path. + let primary = dirs::home_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join(".config") + .join("gws"); + if primary.exists() { + return primary; + } + + // Backward compat: fall back to OS-specific config dir for existing installs + // (e.g. ~/Library/Application Support/gws on macOS, %APPDATA%\gws on Windows). + let legacy = dirs::config_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join("gws"); + if legacy.exists() { + return legacy; + } + + primary +} diff --git a/src/credential_store.rs b/src/credential_store.rs index 867d5bc..1f5507c 100644 --- a/src/credential_store.rs +++ b/src/credential_store.rs @@ -45,7 +45,7 @@ fn get_or_create_key() -> anyhow::Result<[u8; 32]> { .or_else(|_| std::env::var("USERNAME")) .unwrap_or_else(|_| "unknown-user".to_string()); - let key_file = crate::auth_commands::config_dir().join(".encryption_key"); + let key_file = crate::config::config_dir().join(".encryption_key"); let entry = Entry::new("gws-cli", &username); @@ -218,7 +218,7 @@ pub fn decrypt(data: &[u8]) -> anyhow::Result> { /// Returns the path for encrypted credentials. pub fn encrypted_credentials_path() -> PathBuf { - crate::auth_commands::config_dir().join("credentials.enc") + crate::config::config_dir().join("credentials.enc") } /// Saves credentials JSON to an encrypted file. diff --git a/src/discovery.rs b/src/discovery.rs index b4fa9ff..5a6bad5 100644 --- a/src/discovery.rs +++ b/src/discovery.rs @@ -195,7 +195,7 @@ pub async fn fetch_discovery_document( let version = crate::validate::validate_api_identifier(version).map_err(|e| anyhow::anyhow!("{e}"))?; - let cache_dir = crate::auth_commands::config_dir().join("cache"); + let cache_dir = crate::config::config_dir().join("cache"); std::fs::create_dir_all(&cache_dir)?; let cache_file = cache_dir.join(format!("{service}_{version}.json")); diff --git a/src/executor.rs b/src/executor.rs index 49101ec..bbe489e 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -211,7 +211,7 @@ async fn handle_json_response( body_text: &str, pagination: &PaginationConfig, sanitize_template: Option<&str>, - sanitize_mode: &crate::helpers::modelarmor::SanitizeMode, + sanitize_mode: &crate::sanitize::SanitizeMode, output_format: &crate::formatter::OutputFormat, pages_fetched: &mut u32, page_token: &mut Option, @@ -224,15 +224,14 @@ async fn handle_json_response( // Run Model Armor sanitization if --sanitize is enabled if let Some(template) = sanitize_template { let text_to_check = serde_json::to_string(&json_val).unwrap_or_default(); - match crate::helpers::modelarmor::sanitize_text(template, &text_to_check).await { + match crate::sanitize::sanitize_text(template, &text_to_check).await { Ok(result) => { let is_match = result.filter_match_state == "MATCH_FOUND"; if is_match { eprintln!("⚠️ Model Armor: prompt injection detected (filterMatchState: MATCH_FOUND)"); } - if is_match && *sanitize_mode == crate::helpers::modelarmor::SanitizeMode::Block - { + if is_match && *sanitize_mode == crate::sanitize::SanitizeMode::Block { let blocked = serde_json::json!({ "error": "Content blocked by Model Armor", "sanitizationResult": serde_json::to_value(&result).unwrap_or_default(), @@ -370,7 +369,7 @@ pub async fn execute_method( dry_run: bool, pagination: &PaginationConfig, sanitize_template: Option<&str>, - sanitize_mode: &crate::helpers::modelarmor::SanitizeMode, + sanitize_mode: &crate::sanitize::SanitizeMode, output_format: &crate::formatter::OutputFormat, capture_output: bool, ) -> Result, GwsError> { @@ -1656,7 +1655,7 @@ async fn test_execute_method_dry_run() { let params_json = r#"{"fileId": "123"}"#; let body_json = r#"{"name": "test.txt"}"#; - let sanitize_mode = crate::helpers::modelarmor::SanitizeMode::Warn; + let sanitize_mode = crate::sanitize::SanitizeMode::Warn; let pagination = PaginationConfig::default(); let result = execute_method( @@ -1701,7 +1700,7 @@ async fn test_execute_method_missing_path_param() { ..Default::default() }; - let sanitize_mode = crate::helpers::modelarmor::SanitizeMode::Warn; + let sanitize_mode = crate::sanitize::SanitizeMode::Warn; let result = execute_method( &doc, &method, diff --git a/src/helpers/modelarmor.rs b/src/helpers/modelarmor.rs index 8ac9fc1..ab17eb4 100644 --- a/src/helpers/modelarmor.rs +++ b/src/helpers/modelarmor.rs @@ -18,65 +18,19 @@ use crate::discovery::RestDescription; use crate::error::GwsError; use anyhow::Context; use clap::{Arg, ArgMatches, Command}; -use serde::{Deserialize, Serialize}; use serde_json::json; use std::future::Future; use std::pin::Pin; -/// Result of a Model Armor sanitization check. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct SanitizationResult { - /// The overall state of the match (e.g., "MATCH_FOUND", "NO_MATCH_FOUND"). - pub filter_match_state: String, - /// Detailed results from specific filters (PI, Jailbreak, etc.). - #[serde(default)] - pub filter_results: serde_json::Value, - /// The final decision based on the policy (e.g., "BLOCK", "ALLOW"). - #[serde(default)] - pub invocation_result: String, -} - -/// Controls behavior when sanitization finds a match. -#[derive(Debug, Clone, PartialEq)] -pub enum SanitizeMode { - /// Log warning to stderr, annotate output with _sanitization field - Warn, - /// Suppress response output, exit non-zero - Block, -} - -/// Configuration for Model Armor sanitization, threaded through the CLI. -#[derive(Debug, Clone)] -pub struct SanitizeConfig { - pub template: Option, - pub mode: SanitizeMode, -} - -impl Default for SanitizeConfig { - /// Provides default values for `SanitizeConfig`. - /// - /// By default, no template is set (sanitization disabled) and the mode is `Warn`. - fn default() -> Self { - Self { - template: None, - mode: SanitizeMode::Warn, - } - } -} +// Re-export sanitization types from the standalone module so existing +// `helpers::modelarmor::` paths continue to compile. +pub use crate::sanitize::{ + sanitize_text, SanitizationResult, SanitizeConfig, SanitizeMode, CLOUD_PLATFORM_SCOPE, +}; -impl SanitizeMode { - /// Parses a string into a `SanitizeMode`. - /// - /// * "block" (case-insensitive) -> `Block` - /// * Any other value -> `Warn` (safe default) - pub fn from_str(s: &str) -> Self { - match s.to_lowercase().as_str() { - "block" => SanitizeMode::Block, - _ => SanitizeMode::Warn, - } - } -} +// Re-export for tests in this module +#[cfg(test)] +pub(crate) use crate::sanitize::{build_sanitize_request_data, parse_sanitize_response}; pub struct ModelArmorHelper; @@ -243,42 +197,6 @@ TIPS: } } -pub const CLOUD_PLATFORM_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform"; - -/// Sanitize text through a Model Armor template and return the result. -/// Template format: projects/PROJECT/locations/LOCATION/templates/TEMPLATE -pub async fn sanitize_text(template: &str, text: &str) -> Result { - let (body, url) = build_sanitize_request_data(template, text, "sanitizeUserPrompt")?; - - let token = auth::get_token(&[CLOUD_PLATFORM_SCOPE]) - .await - .context("Failed to get auth token for Model Armor")?; - - let client = crate::client::build_client()?; - let resp = client - .post(&url) - .header("Authorization", format!("Bearer {token}")) - .header("Content-Type", "application/json") - .body(body) - .send() - .await - .context("Model Armor request failed")?; - - let status = resp.status(); - let resp_text = resp - .text() - .await - .context("Failed to read Model Armor response")?; - - if !status.is_success() { - return Err(GwsError::Other(anyhow::anyhow!( - "Model Armor API returned status {status}: {resp_text}" - ))); - } - - parse_sanitize_response(&resp_text) -} - /// Make a POST request to Model Armor's regional API endpoint. async fn model_armor_post(url: &str, body: &str) -> Result<(), GwsError> { let token = auth::get_token(&[CLOUD_PLATFORM_SCOPE]) @@ -459,23 +377,23 @@ mod tests { #[test] fn test_sanitize_mode_from_str_warn() { - assert_eq!(SanitizeMode::from_str("warn"), SanitizeMode::Warn); - assert_eq!(SanitizeMode::from_str("WARN"), SanitizeMode::Warn); - assert_eq!(SanitizeMode::from_str("Warn"), SanitizeMode::Warn); + assert_eq!(SanitizeMode::from("warn"), SanitizeMode::Warn); + assert_eq!(SanitizeMode::from("WARN"), SanitizeMode::Warn); + assert_eq!(SanitizeMode::from("Warn"), SanitizeMode::Warn); } #[test] fn test_sanitize_mode_from_str_block() { - assert_eq!(SanitizeMode::from_str("block"), SanitizeMode::Block); - assert_eq!(SanitizeMode::from_str("BLOCK"), SanitizeMode::Block); - assert_eq!(SanitizeMode::from_str("Block"), SanitizeMode::Block); + assert_eq!(SanitizeMode::from("block"), SanitizeMode::Block); + assert_eq!(SanitizeMode::from("BLOCK"), SanitizeMode::Block); + assert_eq!(SanitizeMode::from("Block"), SanitizeMode::Block); } #[test] fn test_sanitize_mode_from_str_unknown_defaults_to_warn() { - assert_eq!(SanitizeMode::from_str(""), SanitizeMode::Warn); - assert_eq!(SanitizeMode::from_str("invalid"), SanitizeMode::Warn); - assert_eq!(SanitizeMode::from_str("stop"), SanitizeMode::Warn); + assert_eq!(SanitizeMode::from(""), SanitizeMode::Warn); + assert_eq!(SanitizeMode::from("invalid"), SanitizeMode::Warn); + assert_eq!(SanitizeMode::from("stop"), SanitizeMode::Warn); } #[test] @@ -543,6 +461,43 @@ mod tests { assert_eq!(json["userPromptData"]["text"], "some text"); } + #[test] + fn test_build_sanitize_request_data_rejects_traversal() { + let result = + build_sanitize_request_data("../../etc/passwd", "text", "sanitizeUserPrompt"); + assert!(result.is_err()); + } + + #[test] + fn test_build_sanitize_request_data_rejects_query_injection() { + let result = build_sanitize_request_data( + "projects/p/locations/evil.com?x=y/templates/t", + "text", + "sanitizeUserPrompt", + ); + assert!(result.is_err()); + } + + #[test] + fn test_build_sanitize_request_data_rejects_percent_encoded() { + let result = build_sanitize_request_data( + "projects/p/locations/evil%2ecom/templates/t", + "text", + "sanitizeUserPrompt", + ); + assert!(result.is_err()); + } + + #[test] + fn test_build_sanitize_request_data_rejects_dotted_location() { + let result = build_sanitize_request_data( + "projects/p/locations/evil.com/templates/t", + "text", + "sanitizeUserPrompt", + ); + assert!(result.is_err()); + } + #[test] fn test_parse_sanitize_response_success() { let json_resp = json!({ @@ -565,47 +520,6 @@ mod tests { } } -pub fn build_sanitize_request_data( - template: &str, - text: &str, - method: &str, -) -> Result<(String, String), GwsError> { - let location = extract_location(template).ok_or_else(|| { - GwsError::Validation( - "Cannot extract location from --sanitize template. Expected format: projects/PROJECT/locations/LOCATION/templates/TEMPLATE".to_string(), - ) - })?; - - let base = regional_base_url(location); - let url = format!("{base}/{template}:{method}"); - - // Identify data field based on method - let data_field = if method == "sanitizeUserPrompt" { - "userPromptData" - } else { - "modelResponseData" - }; - - let body = json!({data_field: {"text": text}}).to_string(); - Ok((body, url)) -} - -pub fn parse_sanitize_response(resp_text: &str) -> Result { - // Parse the response to extract sanitizationResult - let parsed: serde_json::Value = - serde_json::from_str(resp_text).context("Failed to parse Model Armor response")?; - - let result = parsed.get("sanitizationResult").ok_or_else(|| { - GwsError::Other(anyhow::anyhow!( - "No sanitizationResult in Model Armor response" - )) - })?; - - let res = - serde_json::from_value(result.clone()).context("Failed to parse sanitization result")?; - Ok(res) -} - fn parse_sanitize_args(matches: &ArgMatches, data_field: &str) -> Result { if let Some(json_str) = matches.get_one::("json") { Ok(json_str.clone()) diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..c038733 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,49 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Google Workspace CLI — library crate. +//! +//! Provides programmatic access to the core functionality of `gws`: +//! Discovery Document parsing, OAuth / Service Account authentication, +//! API execution, and input validation. +//! +//! The binary (`gws`) re-declares these same modules; both targets compile +//! from the same source files (dual compilation). + +// Internal modules are shared with the binary via dual compilation. +// They appear unused from the library's perspective but are needed by the bin target. +#![allow(dead_code)] + +pub mod auth; +pub(crate) mod auth_commands; +pub mod client; +pub mod commands; +pub mod config; +pub mod credential_store; +pub mod discovery; +pub mod error; +pub mod executor; +pub mod formatter; +pub mod fs_util; +pub(crate) mod generate_skills; +pub(crate) mod helpers; +pub mod oauth_config; +pub mod sanitize; +pub(crate) mod schema; +pub mod services; +pub(crate) mod setup; +pub(crate) mod setup_tui; +pub(crate) mod text; +pub mod token_storage; +pub mod validate; diff --git a/src/main.rs b/src/main.rs index 89d73d9..73b8889 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,6 +23,7 @@ mod auth; pub(crate) mod auth_commands; mod client; mod commands; +pub(crate) mod config; pub(crate) mod credential_store; mod discovery; mod error; @@ -32,6 +33,7 @@ mod fs_util; mod generate_skills; mod helpers; mod oauth_config; +pub(crate) mod sanitize; mod schema; mod services; mod setup; @@ -189,7 +191,7 @@ async fn run() -> Result<(), GwsError> { .or_else(|| std::env::var("GOOGLE_WORKSPACE_CLI_SANITIZE_TEMPLATE").ok()); let sanitize_mode = std::env::var("GOOGLE_WORKSPACE_CLI_SANITIZE_MODE") - .map(|v| helpers::modelarmor::SanitizeMode::from_str(&v)) + .map(|v| helpers::modelarmor::SanitizeMode::from(v.as_str())) .unwrap_or(helpers::modelarmor::SanitizeMode::Warn); let sanitize_config = parse_sanitize_config(sanitize_template, &sanitize_mode)?; @@ -225,7 +227,7 @@ async fn run() -> Result<(), GwsError> { // Select the best scope for the method. Discovery Documents list scopes as // alternatives (any one grants access). We pick the first (broadest) scope // to avoid restrictive scopes like gmail.metadata that block query parameters. - let scopes: Vec<&str> = select_scope(&method.scopes).into_iter().collect(); + let scopes: Vec<&str> = services::select_scope(&method.scopes).into_iter().collect(); // Authenticate: try OAuth, fail with error if credentials exist but are broken let (token, auth_method) = match auth::get_token(&scopes).await { @@ -265,17 +267,6 @@ async fn run() -> Result<(), GwsError> { .map(|_| ()) } -/// Select the best scope from a method's scope list. -/// -/// Discovery Documents list method scopes as alternatives — any single scope -/// grants access. The first scope is typically the broadest. Using all scopes -/// causes issues when restrictive scopes (e.g., `gmail.metadata`) are included, -/// as the API enforces that scope's restrictions even when broader scopes are -/// also present. -pub(crate) fn select_scope(scopes: &[String]) -> Option<&str> { - scopes.first().map(|s| s.as_str()) -} - fn parse_pagination_config(matches: &clap::ArgMatches) -> executor::PaginationConfig { executor::PaginationConfig { page_all: matches.get_flag("page-all"), @@ -288,27 +279,7 @@ pub fn parse_service_and_version( args: &[String], first_arg: &str, ) -> Result<(String, String), GwsError> { - let mut service_arg = first_arg; - let mut version_override: Option = None; - - // Check for --api-version flag anywhere in args - for i in 0..args.len() { - if args[i] == "--api-version" && i + 1 < args.len() { - version_override = Some(args[i + 1].clone()); - } - } - - // Support "service:version" syntax on the service arg itself - if let Some((svc, ver)) = service_arg.split_once(':') { - service_arg = svc; - if version_override.is_none() { - version_override = Some(ver.to_string()); - } - } - - let (api_name, default_version) = services::resolve_service(service_arg)?; - let version = version_override.unwrap_or(default_version); - Ok((api_name, version)) + services::parse_service_and_version(args, first_arg) } pub fn filter_args_for_subcommand(args: &[String], service_name: &str) -> Vec { @@ -663,29 +634,4 @@ mod tests { assert_eq!(filtered, vec!["gws", "files", "list", "--format", "table"]); } - #[test] - fn test_select_scope_picks_first() { - let scopes = vec![ - "https://mail.google.com/".to_string(), - "https://www.googleapis.com/auth/gmail.metadata".to_string(), - "https://www.googleapis.com/auth/gmail.modify".to_string(), - "https://www.googleapis.com/auth/gmail.readonly".to_string(), - ]; - assert_eq!(select_scope(&scopes), Some("https://mail.google.com/")); - } - - #[test] - fn test_select_scope_single() { - let scopes = vec!["https://www.googleapis.com/auth/drive".to_string()]; - assert_eq!( - select_scope(&scopes), - Some("https://www.googleapis.com/auth/drive") - ); - } - - #[test] - fn test_select_scope_empty() { - let scopes: Vec = vec![]; - assert_eq!(select_scope(&scopes), None); - } } diff --git a/src/oauth_config.rs b/src/oauth_config.rs index 02154b5..ec9002d 100644 --- a/src/oauth_config.rs +++ b/src/oauth_config.rs @@ -54,7 +54,7 @@ pub struct ClientSecretFile { /// Returns the path for the client secret config file. pub fn client_config_path() -> PathBuf { - crate::auth_commands::config_dir().join("client_secret.json") + crate::config::config_dir().join("client_secret.json") } /// Saves OAuth client configuration in the standard Google Cloud Console format. diff --git a/src/sanitize.rs b/src/sanitize.rs new file mode 100644 index 0000000..4eea0ea --- /dev/null +++ b/src/sanitize.rs @@ -0,0 +1,175 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Model Armor sanitization types and helpers. +//! +//! Extracted from `helpers::modelarmor` so that library consumers can use +//! sanitization without pulling in CLI-only helper infrastructure. + +use anyhow::Context; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use crate::error::GwsError; + +/// Result of a Model Armor sanitization check. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SanitizationResult { + /// The overall state of the match (e.g., "MATCH_FOUND", "NO_MATCH_FOUND"). + pub filter_match_state: String, + /// Detailed results from specific filters (PI, Jailbreak, etc.). + #[serde(default)] + pub filter_results: serde_json::Value, + /// The final decision based on the policy (e.g., "BLOCK", "ALLOW"). + #[serde(default)] + pub invocation_result: String, +} + +/// Controls behavior when sanitization finds a match. +#[derive(Debug, Clone, PartialEq)] +pub enum SanitizeMode { + /// Log warning to stderr, annotate output with _sanitization field + Warn, + /// Suppress response output, exit non-zero + Block, +} + +/// Configuration for Model Armor sanitization, threaded through the CLI. +#[derive(Debug, Clone)] +pub struct SanitizeConfig { + pub template: Option, + pub mode: SanitizeMode, +} + +impl Default for SanitizeConfig { + /// Provides default values for `SanitizeConfig`. + /// + /// By default, no template is set (sanitization disabled) and the mode is `Warn`. + fn default() -> Self { + Self { + template: None, + mode: SanitizeMode::Warn, + } + } +} + +impl From<&str> for SanitizeMode { + /// Parses a string into a `SanitizeMode`. + /// + /// * "block" (case-insensitive) -> `Block` + /// * Any other value -> `Warn` (safe default) + fn from(s: &str) -> Self { + match s.to_lowercase().as_str() { + "block" => SanitizeMode::Block, + _ => SanitizeMode::Warn, + } + } +} + +pub const CLOUD_PLATFORM_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform"; + +/// Sanitize text through a Model Armor template and return the result. +/// Template format: projects/PROJECT/locations/LOCATION/templates/TEMPLATE +pub async fn sanitize_text(template: &str, text: &str) -> Result { + let (body, url) = build_sanitize_request_data(template, text, "sanitizeUserPrompt")?; + + let token = crate::auth::get_token(&[CLOUD_PLATFORM_SCOPE]) + .await + .context("Failed to get auth token for Model Armor")?; + + let client = crate::client::build_client()?; + let resp = client + .post(&url) + .header("Authorization", format!("Bearer {token}")) + .header("Content-Type", "application/json") + .body(body) + .send() + .await + .context("Model Armor request failed")?; + + let status = resp.status(); + let resp_text = resp + .text() + .await + .context("Failed to read Model Armor response")?; + + if !status.is_success() { + return Err(GwsError::Other(anyhow::anyhow!( + "Model Armor API returned status {status}: {resp_text}" + ))); + } + + parse_sanitize_response(&resp_text) +} + +/// Build the regional base URL for Model Armor API. +fn regional_base_url(location: &str) -> String { + format!("https://modelarmor.{location}.rep.googleapis.com/v1") +} + +/// Extract location from a full template resource name. +/// e.g. "projects/my-project/locations/us-central1/templates/my-template" -> "us-central1" +fn extract_location(resource_name: &str) -> Option<&str> { + let parts: Vec<&str> = resource_name.split('/').collect(); + for i in 0..parts.len() { + if parts[i] == "locations" && i + 1 < parts.len() { + return Some(parts[i + 1]); + } + } + None +} + +pub fn build_sanitize_request_data( + template: &str, + text: &str, + method: &str, +) -> Result<(String, String), GwsError> { + let template = crate::validate::validate_resource_name(template)?; + let location = extract_location(template).ok_or_else(|| { + GwsError::Validation( + "Cannot extract location from --sanitize template. Expected format: projects/PROJECT/locations/LOCATION/templates/TEMPLATE".to_string(), + ) + })?; + let location = crate::validate::validate_gcp_location(location)?; + + let base = regional_base_url(location); + let url = format!("{base}/{template}:{method}"); + + // Identify data field based on method + let data_field = if method == "sanitizeUserPrompt" { + "userPromptData" + } else { + "modelResponseData" + }; + + let body = json!({data_field: {"text": text}}).to_string(); + Ok((body, url)) +} + +pub fn parse_sanitize_response(resp_text: &str) -> Result { + // Parse the response to extract sanitizationResult + let parsed: serde_json::Value = + serde_json::from_str(resp_text).context("Failed to parse Model Armor response")?; + + let result = parsed.get("sanitizationResult").ok_or_else(|| { + GwsError::Other(anyhow::anyhow!( + "No sanitizationResult in Model Armor response" + )) + })?; + + let res = + serde_json::from_value(result.clone()).context("Failed to parse sanitization result")?; + Ok(res) +} diff --git a/src/services.rs b/src/services.rs index 40a4b81..9c6fc8b 100644 --- a/src/services.rs +++ b/src/services.rs @@ -128,6 +128,48 @@ pub const SERVICES: &[ServiceEntry] = &[ }, ]; +/// Selects the scope to request for an API method. +/// +/// Google API methods list their accepted scopes from broadest to narrowest. +/// We pick only the first (broadest) scope because requesting multiple scopes +/// causes issues when restrictive scopes (e.g., `gmail.metadata`) are included, +/// as the API enforces that scope's restrictions even when broader scopes are +/// also present. +pub fn select_scope(scopes: &[String]) -> Option<&str> { + scopes.first().map(|s| s.as_str()) +} + +/// Parses a service name (with optional `:version` suffix) and an `--api-version` +/// flag from raw CLI args into `(api_name, version)`. +pub fn parse_service_and_version( + args: &[String], + first_arg: &str, +) -> Result<(String, String), GwsError> { + let mut service_arg = first_arg; + let mut version_override: Option = None; + + // Check for --api-version flag anywhere in args + for i in 0..args.len() { + if args[i] == "--api-version" && i + 1 < args.len() && !args[i + 1].starts_with('-') { + version_override = Some(args[i + 1].clone()); + } else if let Some(val) = args[i].strip_prefix("--api-version=") { + version_override = Some(val.to_string()); + } + } + + // Support "service:version" syntax on the service arg itself + if let Some((svc, ver)) = service_arg.split_once(':') { + service_arg = svc; + if version_override.is_none() { + version_override = Some(ver.to_string()); + } + } + + let (api_name, default_version) = resolve_service(service_arg)?; + let version = version_override.unwrap_or(default_version); + Ok((api_name, version)) +} + /// Resolves a service alias to (api_name, version). pub fn resolve_service(name: &str) -> Result<(String, String), GwsError> { for entry in SERVICES { @@ -166,6 +208,32 @@ mod tests { ); } + #[test] + fn test_select_scope_picks_first() { + let scopes = vec![ + "https://mail.google.com/".to_string(), + "https://www.googleapis.com/auth/gmail.metadata".to_string(), + "https://www.googleapis.com/auth/gmail.modify".to_string(), + "https://www.googleapis.com/auth/gmail.readonly".to_string(), + ]; + assert_eq!(select_scope(&scopes), Some("https://mail.google.com/")); + } + + #[test] + fn test_select_scope_single() { + let scopes = vec!["https://www.googleapis.com/auth/drive".to_string()]; + assert_eq!( + select_scope(&scopes), + Some("https://www.googleapis.com/auth/drive") + ); + } + + #[test] + fn test_select_scope_empty() { + let scopes: Vec = vec![]; + assert_eq!(select_scope(&scopes), None); + } + #[test] fn test_resolve_service_unknown() { let err = resolve_service("unknown_service"); diff --git a/src/validate.rs b/src/validate.rs index cfefd60..898193b 100644 --- a/src/validate.rs +++ b/src/validate.rs @@ -241,6 +241,26 @@ pub fn validate_api_identifier(s: &str) -> Result<&str, GwsError> { Ok(s) } +/// Validate a GCP region/location identifier (e.g. `us-central1`, `europe-west4`). +/// Only ASCII lowercase letters, digits, and hyphens are allowed — dots are +/// rejected to prevent SSRF when the location is interpolated into a hostname. +pub fn validate_gcp_location(s: &str) -> Result<&str, GwsError> { + if s.is_empty() { + return Err(GwsError::Validation( + "GCP location must not be empty".to_string(), + )); + } + if !s + .chars() + .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-') + { + return Err(GwsError::Validation(format!( + "GCP location contains invalid characters (only lowercase alphanumeric and '-' allowed): {s}" + ))); + } + Ok(s) +} + #[cfg(test)] mod tests { use super::*; @@ -566,4 +586,37 @@ mod tests { fn test_validate_api_identifier_empty() { assert!(validate_api_identifier("").is_err()); } + + // --- validate_gcp_location --- + + #[test] + fn test_validate_gcp_location_valid() { + assert_eq!(validate_gcp_location("us-central1").unwrap(), "us-central1"); + assert_eq!(validate_gcp_location("europe-west4").unwrap(), "europe-west4"); + assert_eq!(validate_gcp_location("asia-east1").unwrap(), "asia-east1"); + assert_eq!(validate_gcp_location("global").unwrap(), "global"); + } + + #[test] + fn test_validate_gcp_location_rejects_dots() { + assert!(validate_gcp_location("evil.com").is_err()); + assert!(validate_gcp_location("internal.evil.com").is_err()); + } + + #[test] + fn test_validate_gcp_location_rejects_uppercase() { + assert!(validate_gcp_location("US-CENTRAL1").is_err()); + } + + #[test] + fn test_validate_gcp_location_rejects_special_chars() { + assert!(validate_gcp_location("us central1").is_err()); + assert!(validate_gcp_location("us/central1").is_err()); + assert!(validate_gcp_location("us_central1").is_err()); + } + + #[test] + fn test_validate_gcp_location_empty() { + assert!(validate_gcp_location("").is_err()); + } } diff --git a/tests/lib_integration.rs b/tests/lib_integration.rs new file mode 100644 index 0000000..8c33f30 --- /dev/null +++ b/tests/lib_integration.rs @@ -0,0 +1,59 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Integration tests verifying that the library crate exposes key types +//! and functions. These are offline tests — no network calls. + +use gws::discovery::RestDescription; +use gws::error::GwsError; +use gws::services::resolve_service; +use gws::validate::validate_api_identifier; + +#[test] +fn rest_description_deserializes_minimal() { + let json = r#"{"name":"test","version":"v1","rootUrl":"https://example.com/"}"#; + let doc: RestDescription = serde_json::from_str(json).unwrap(); + assert_eq!(doc.name, "test"); + assert_eq!(doc.version, "v1"); +} + +#[test] +fn resolve_service_returns_known() { + let (api, ver) = resolve_service("drive").unwrap(); + assert_eq!(api, "drive"); + assert_eq!(ver, "v3"); +} + +#[test] +fn resolve_service_rejects_unknown() { + assert!(resolve_service("nonexistent").is_err()); +} + +#[test] +fn gws_error_variants_exist() { + let err = GwsError::Validation("test".to_string()); + let json = err.to_json(); + assert_eq!(json["error"]["code"], 400); +} + +#[test] +fn validate_api_identifier_accepts_valid() { + assert!(validate_api_identifier("drive").is_ok()); + assert!(validate_api_identifier("v3").is_ok()); +} + +#[test] +fn validate_api_identifier_rejects_traversal() { + assert!(validate_api_identifier("../etc").is_err()); +}