Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 38 additions & 27 deletions src/ai/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ impl AiClient {
}

pub fn execute_prompt(&self, system_prompt: &str, user_message: &str) -> Result<AiResponse> {
match self.send_prompt(&self.primary, system_prompt, user_message) {
match self.send_prompt(&self.primary, "primary", system_prompt, user_message) {
Ok(response) => Ok(response),
Err(primary_error) => {
if let Some(fallback) = &self.fallback {
self.send_prompt(fallback, system_prompt, user_message)
self.send_prompt(fallback, "fallback", system_prompt, user_message)
.map_err(|fallback_error| {
anyhow!(
"Primary provider failed: {primary_error}. Fallback provider failed: {fallback_error}"
Expand All @@ -77,10 +77,11 @@ impl AiClient {
fn send_prompt(
&self,
provider: &ProviderConfig,
keychain_role: &str,
system_prompt: &str,
user_message: &str,
) -> Result<AiResponse> {
let api_key = resolve_api_key(provider)?;
let api_key = resolve_api_key(provider, keychain_role)?;
let url = endpoint_url(provider);

let request = match provider.protocol {
Expand Down Expand Up @@ -126,7 +127,7 @@ impl AiClient {
}
}

fn resolve_api_key(provider: &ProviderConfig) -> Result<String> {
fn resolve_api_key(provider: &ProviderConfig, keychain_role: &str) -> Result<String> {
// 1. explicit per-provider env var (highest-precedence override)
if !provider.api_key_env.trim().is_empty() {
if let Ok(value) = env::var(&provider.api_key_env) {
Expand All @@ -152,8 +153,9 @@ fn resolve_api_key(provider: &ProviderConfig) -> Result<String> {
// 4. OS keychain — where `qr init` stores the key when you opt in. Checked
// last because it is only populated when the key is NOT in env or config,
// so the common paths never touch the keychain backend.
let account = crate::secret::account_for(&provider.api_key_env, well_known_env);
if let Some(value) = crate::secret::get(&account) {
if let Some(value) =
crate::secret::get_for_role(keychain_role, &provider.api_key_env, well_known_env)
{
if !value.trim().is_empty() {
return Ok(value);
}
Expand Down Expand Up @@ -437,13 +439,16 @@ mod tests {
std::env::set_var("OPENAI_API_KEY", "well-known-token");
}

let api_key = resolve_api_key(&ProviderConfig {
protocol: AiProtocol::OpenAi,
base_url: "https://example.test/v1".into(),
model: "demo".into(),
api_key: "config-token".into(),
api_key_env: "CUSTOM_QR_TEST_AI_KEY".into(),
})
let api_key = resolve_api_key(
&ProviderConfig {
protocol: AiProtocol::OpenAi,
base_url: "https://example.test/v1".into(),
model: "demo".into(),
api_key: "config-token".into(),
api_key_env: "CUSTOM_QR_TEST_AI_KEY".into(),
},
"primary",
)
.unwrap();
assert_eq!(api_key, "custom-token");

Expand All @@ -462,13 +467,16 @@ mod tests {
std::env::set_var("ANTHROPIC_API_KEY", "well-known-token");
}

let api_key = resolve_api_key(&ProviderConfig {
protocol: AiProtocol::Anthropic,
base_url: "https://example.test".into(),
model: "claude-demo".into(),
api_key: "config-token".into(),
api_key_env: "CUSTOM_QR_TEST_ANTHROPIC_KEY".into(),
})
let api_key = resolve_api_key(
&ProviderConfig {
protocol: AiProtocol::Anthropic,
base_url: "https://example.test".into(),
model: "claude-demo".into(),
api_key: "config-token".into(),
api_key_env: "CUSTOM_QR_TEST_ANTHROPIC_KEY".into(),
},
"primary",
)
.unwrap();
assert_eq!(api_key, "well-known-token");

Expand All @@ -486,13 +494,16 @@ mod tests {
std::env::remove_var("OPENAI_API_KEY");
}

let api_key = resolve_api_key(&ProviderConfig {
protocol: AiProtocol::OpenAi,
base_url: "https://example.test/v1".into(),
model: "demo".into(),
api_key: "config-token".into(),
api_key_env: "CUSTOM_QR_TEST_AI_KEY".into(),
})
let api_key = resolve_api_key(
&ProviderConfig {
protocol: AiProtocol::OpenAi,
base_url: "https://example.test/v1".into(),
model: "demo".into(),
api_key: "config-token".into(),
api_key_env: "CUSTOM_QR_TEST_AI_KEY".into(),
},
"primary",
)
.unwrap();
assert_eq!(api_key, "config-token");
}
Expand Down
79 changes: 69 additions & 10 deletions src/atomic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
//! concurrently: the project cache (read by `qr go` while the hourly cron
//! rewrites it) and the user's shell rc file (corrupting it is a 5-alarm fire).

#[cfg(unix)]
use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
use std::{
ffi::OsString,
fs,
Expand All @@ -19,6 +21,21 @@ use anyhow::{Context, Result};

/// Atomically replace `path` with `contents`.
pub fn write(path: &Path, contents: &[u8]) -> Result<()> {
write_impl(path, contents, WriteMode::PreserveExisting)
}

/// Atomically replace `path` with `contents`, creating the file with private
/// permissions (`0600`) from the start on Unix.
pub fn write_private(path: &Path, contents: &[u8]) -> Result<()> {
write_impl(path, contents, WriteMode::Private)
}

enum WriteMode {
PreserveExisting,
Private,
}

fn write_impl(path: &Path, contents: &[u8], mode: WriteMode) -> Result<()> {
// Follow a final symlink so we replace its target — as `fs::write` did —
// rather than clobbering the symlink itself with a regular file. This keeps
// dotfiles-managed symlinked rc files (e.g. `~/.zshrc` -> a versioned repo)
Expand All @@ -32,8 +49,7 @@ pub fn write(path: &Path, contents: &[u8]) -> Result<()> {
}

let tmp = temp_path(target);
let mut file = fs::File::create(&tmp)
.with_context(|| format!("Failed to create temp file {}", tmp.display()))?;
let mut file = create_temp_file(&tmp, &mode)?;
let result = file.write_all(contents).and_then(|()| file.sync_all());
drop(file);
if let Err(error) = result {
Expand All @@ -45,12 +61,24 @@ pub fn write(path: &Path, contents: &[u8]) -> Result<()> {
// installs a fresh inode, so without this an existing rc/config file's mode
// (e.g. a 0600 config) would reset to the default. Fail loudly rather than
// silently downgrade the mode.
if let Ok(metadata) = fs::metadata(target) {
if let Err(error) = fs::set_permissions(&tmp, metadata.permissions()) {
let _ = fs::remove_file(&tmp);
return Err(error).with_context(|| {
format!("Failed to preserve permissions on {}", target.display())
});
match mode {
WriteMode::PreserveExisting => {
if let Ok(metadata) = fs::metadata(target) {
if let Err(error) = fs::set_permissions(&tmp, metadata.permissions()) {
let _ = fs::remove_file(&tmp);
return Err(error).with_context(|| {
format!("Failed to preserve permissions on {}", target.display())
});
}
}
}
WriteMode::Private => {
#[cfg(unix)]
if let Err(error) = fs::set_permissions(&tmp, fs::Permissions::from_mode(0o600)) {
let _ = fs::remove_file(&tmp);
return Err(error)
.with_context(|| format!("Failed to secure {}", target.display()));
}
}
}

Expand All @@ -61,6 +89,27 @@ pub fn write(path: &Path, contents: &[u8]) -> Result<()> {
Ok(())
}

fn create_temp_file(path: &Path, mode: &WriteMode) -> Result<fs::File> {
#[cfg(unix)]
{
let mut options = fs::OpenOptions::new();
options.write(true).create(true).truncate(true);
if matches!(mode, WriteMode::Private) {
options.mode(0o600);
}
options
.open(path)
.with_context(|| format!("Failed to create temp file {}", path.display()))
}

#[cfg(not(unix))]
{
let _ = mode;
fs::File::create(path)
.with_context(|| format!("Failed to create temp file {}", path.display()))
}
}

/// Resolve a final symlink to its real target so an atomic replacement writes
/// through the link (matching `fs::write`) instead of clobbering it — including a
/// dangling symlink, whose not-yet-existing target is created rather than the
Expand Down Expand Up @@ -136,8 +185,6 @@ mod tests {
#[cfg(unix)]
#[test]
fn write_preserves_existing_file_permissions() {
use std::os::unix::fs::PermissionsExt;

let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("config.toml");
write(&path, b"v1").unwrap();
Expand All @@ -149,6 +196,18 @@ mod tests {
assert_eq!(mode, 0o600, "atomic replace must preserve the file mode");
}

#[cfg(unix)]
#[test]
fn write_private_creates_file_with_private_permissions() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("config.toml");

write_private(&path, b"secret = true").unwrap();

let mode = fs::metadata(&path).unwrap().permissions().mode() & 0o777;
assert_eq!(mode, 0o600, "private writes must create 0600 files");
}

#[test]
fn concurrent_writes_to_same_path_never_interleave() {
// Two threads writing the same target must each use a private temp file,
Expand Down
25 changes: 15 additions & 10 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use anyhow::{Context, Result, anyhow};
use serde::{Deserialize, Serialize};

use crate::ai::providers::{AiProtocol, ProviderConfig};
use crate::atomic;

const DEFAULT_CONFIG: &str = include_str!("../config/default.toml");

Expand Down Expand Up @@ -89,18 +90,22 @@ impl AppConfig {
}

pub fn load_from_env_with_path(path: PathBuf) -> Result<Self> {
let mut config = if path.exists() {
let raw = fs::read_to_string(&path)
.with_context(|| format!("Failed to read config file {}", path.display()))?;
Self::load_from_str(&raw)?
} else {
Self::load_from_str(DEFAULT_CONFIG)?
};

let mut config = Self::load_file_without_env(&path)?;
apply_env_overrides(&mut config)?;
Ok(config)
}

/// Parse `config.toml` on disk without applying `QR_*` env overrides.
pub fn load_file_without_env(path: &Path) -> Result<Self> {
if path.exists() {
let raw = fs::read_to_string(path)
.with_context(|| format!("Failed to read config file {}", path.display()))?;
Self::load_from_str(&raw)
} else {
Self::load_from_str(DEFAULT_CONFIG)
}
}

pub fn ensure_parent_dirs(&self) -> Result<()> {
for path in [self.stats_db_path(), cache_file_path(), config_file_path()] {
if let Some(parent) = path.parent() {
Expand Down Expand Up @@ -267,7 +272,7 @@ fn rewrite_legacy_paths_in_config(
};

if let Some(updated) = rewrite_stats_db_path_in_toml(&raw, &new_value)? {
fs::write(config_path, updated)?;
atomic::write_private(config_path, updated.as_bytes())?;
}

Ok(())
Expand Down Expand Up @@ -369,7 +374,7 @@ pub fn write_default_config_if_missing(path: &Path) -> Result<bool> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(path, DEFAULT_CONFIG)?;
atomic::write_private(path, DEFAULT_CONFIG.as_bytes())?;
Ok(true)
}

Expand Down
Loading
Loading