diff --git a/src/safeoutputs/add_pr_comment.rs b/src/safeoutputs/add_pr_comment.rs index 290181e8..23880854 100644 --- a/src/safeoutputs/add_pr_comment.rs +++ b/src/safeoutputs/add_pr_comment.rs @@ -4,6 +4,7 @@ use log::{debug, info}; use percent_encoding::utf8_percent_encode; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; use super::PATH_SEGMENT; use crate::safeoutputs::{ExecutionContext, ExecutionResult, Executor, Validate}; @@ -55,6 +56,14 @@ fn default_status() -> String { "active".to_string() } +fn validate_repository_selector(repository: &str) -> anyhow::Result<()> { + reject_pipeline_injection(repository, "repository")?; + if !repository.is_empty() { + crate::validate::validate_relative_safe_path(repository, "repository")?; + } + Ok(()) +} + impl Validate for AddPrCommentParams { fn validate(&self) -> anyhow::Result<()> { ensure!(self.pull_request_id > 0, "pull_request_id must be positive"); @@ -87,7 +96,7 @@ impl Validate for AddPrCommentParams { if let Some(fp) = &self.file_path { validate_file_path(fp)?; } - reject_pipeline_injection(&self.repository, "repository")?; + validate_repository_selector(&self.repository)?; Ok(()) } } @@ -199,6 +208,93 @@ fn validate_file_path(path: &str) -> anyhow::Result<()> { Ok(()) } +fn repository_checkout_dir(repository: &str, ctx: &ExecutionContext) -> anyhow::Result { + if crate::safeoutputs::input_refers_to_self(repository, ctx) { + return Ok(ctx.source_directory.clone()); + } + + if let Some((alias, _)) = ctx.allowed_repositories.get_key_value(repository) { + return Ok(ctx.source_directory.join(alias)); + } + + if let Some((alias, _)) = ctx + .allowed_repositories + .iter() + .find(|(_, name)| name.eq_ignore_ascii_case(repository)) + { + return Ok(ctx.source_directory.join(alias)); + } + + if let Some((alias, _)) = ctx.allowed_repositories.iter().find(|(_, name)| { + name.rsplit('/') + .next() + .unwrap_or(name.as_str()) + .eq_ignore_ascii_case(repository) + }) { + return Ok(ctx.source_directory.join(alias)); + } + + anyhow::bail!( + "Repository alias '{}' not found in allowed repositories", + repository + ) +} + +fn build_inline_thread_context( + workspace_root: &Path, + repo_root: &Path, + file_path: &str, + start_line: i32, + end_line: i32, +) -> anyhow::Result { + ensure!(start_line > 0, "start_line must be positive"); + ensure!(end_line > 0, "end_line must be positive"); + ensure!( + start_line <= end_line, + "start_line ({start_line}) must be less than or equal to line ({end_line})" + ); + + let resolved_path = repo_root.join(file_path); + let canonical = resolved_path.canonicalize().with_context(|| { + format!( + "Failed to canonicalize inline comment file '{}' β€” file may not exist", + file_path + ) + })?; + let canonical_root = repo_root + .canonicalize() + .context("Failed to canonicalize repository checkout root")?; + ensure!( + canonical.starts_with(&canonical_root), + "Inline comment file '{}' resolves outside the repository checkout", + file_path + ); + let canonical_workspace = workspace_root + .canonicalize() + .context("Failed to canonicalize build workspace root")?; + ensure!( + canonical.starts_with(&canonical_workspace), + "Inline comment file '{}' resolves outside the build workspace", + file_path + ); + + let contents = std::fs::read_to_string(&canonical) + .with_context(|| format!("Failed to read inline comment file '{}'", file_path))?; + let target_line = contents + .lines() + .nth((end_line - 1) as usize) + .with_context(|| format!("Inline comment line {} is out of range", end_line))?; + // Azure DevOps threadContext offsets are 1-based, so the end offset must point + // one UTF-16 code unit past the final character to span the whole target line. + let end_offset = target_line.encode_utf16().count() as i32 + 1; + + Ok(serde_json::json!({ + "filePath": format!("/{}", file_path), + "rightFileStart": { "line": start_line, "offset": 1 }, + "rightFileEnd": { "line": end_line, "offset": end_offset } + })) +} + #[async_trait::async_trait] impl Executor for AddPrCommentResult { fn dry_run_summary(&self) -> String { @@ -334,11 +430,36 @@ impl Executor for AddPrCommentResult { if let Some(ref fp) = self.file_path { let end_line = self.line.unwrap_or(1); let start_line = self.start_line.unwrap_or(end_line); - thread_body["threadContext"] = serde_json::json!({ - "filePath": format!("/{}", fp), - "rightFileStart": { "line": start_line, "offset": 1 }, - "rightFileEnd": { "line": end_line, "offset": 1 } - }); + let repo_root = match repository_checkout_dir(&self.repository, ctx).and_then(|path| { + crate::validate::ensure_path_within_base( + &path, + &ctx.source_directory, + "Repository checkout root", + ) + }) { + Ok(path) => path, + Err(err) => { + return Ok(ExecutionResult::failure(format!( + "Failed to resolve repository checkout for '{}': {}", + self.repository, err + ))); + } + }; + match build_inline_thread_context( + &ctx.source_directory, + &repo_root, + fp, + start_line, + end_line, + ) { + Ok(thread_context) => thread_body["threadContext"] = thread_context, + Err(err) => { + return Ok(ExecutionResult::failure(format!( + "Failed to anchor inline comment for '{}': {}", + fp, err + ))); + } + } } let client = reqwest::Client::new(); @@ -398,6 +519,7 @@ impl Executor for AddPrCommentResult { mod tests { use super::*; use crate::safeoutputs::ToolResult; + use tempfile::tempdir; #[test] fn test_result_has_correct_name() { @@ -478,6 +600,36 @@ mod tests { assert!(result.is_err()); } + #[test] + fn test_validation_rejects_repository_traversal_selector() { + let params = AddPrCommentParams { + pull_request_id: 42, + content: "This is a valid comment body text.".to_string(), + repository: "../sibling-repo".to_string(), + file_path: None, + start_line: None, + line: None, + status: "active".to_string(), + }; + let result: Result = params.try_into(); + assert!(result.is_err()); + } + + #[test] + fn test_validation_accepts_project_scoped_repository_selector() { + let params = AddPrCommentParams { + pull_request_id: 42, + content: "This is a valid comment body text.".to_string(), + repository: "4x4/sdk-FtdiDeviceControl".to_string(), + file_path: None, + start_line: None, + line: None, + status: "active".to_string(), + }; + let result: Result = params.try_into(); + assert!(result.is_ok()); + } + #[test] fn test_validation_rejects_line_without_file_path() { let params = AddPrCommentParams { @@ -664,4 +816,78 @@ allowed-statuses: result.repository ); } + + #[test] + fn test_build_inline_thread_context_uses_utf16_end_offset() { + let dir = tempdir().unwrap(); + std::fs::write(dir.path().join("suggestion.rs"), "prefix\nabπŸ˜€\n").unwrap(); + + let thread_context = + build_inline_thread_context(dir.path(), dir.path(), "suggestion.rs", 2, 2).unwrap(); + + assert_eq!(thread_context["rightFileStart"]["line"], 2); + assert_eq!(thread_context["rightFileStart"]["offset"], 1); + assert_eq!(thread_context["rightFileEnd"]["line"], 2); + assert_eq!(thread_context["rightFileEnd"]["offset"], 5); + } + + #[test] + fn test_build_inline_thread_context_uses_last_line_for_multiline_span() { + let dir = tempdir().unwrap(); + std::fs::write( + dir.path().join("suggestion.rs"), + "first line\nabπŸ˜€\nthird\n", + ) + .unwrap(); + + let thread_context = + build_inline_thread_context(dir.path(), dir.path(), "suggestion.rs", 1, 2).unwrap(); + + assert_eq!(thread_context["rightFileStart"]["line"], 1); + assert_eq!(thread_context["rightFileEnd"]["line"], 2); + assert_eq!(thread_context["rightFileEnd"]["offset"], 5); + } + + #[test] + fn test_build_inline_thread_context_rejects_repo_root_outside_workspace() { + let workspace = tempdir().unwrap(); + let outside_repo = tempdir().unwrap(); + std::fs::write(outside_repo.path().join("suggestion.rs"), "line 1\n").unwrap(); + + let err = build_inline_thread_context( + workspace.path(), + outside_repo.path(), + "suggestion.rs", + 1, + 1, + ) + .unwrap_err() + .to_string(); + + assert!(err.contains("outside the build workspace"), "got: {err}"); + } + + #[test] + fn test_repository_checkout_dir_resolves_full_repository_name_to_alias_path() { + let workspace = tempdir().unwrap(); + let alias_dir = workspace.path().join("repo-sdk-ftdidevicecontrol"); + std::fs::create_dir(&alias_dir).unwrap(); + + let mut allowed_repositories = std::collections::HashMap::new(); + allowed_repositories.insert( + "repo-sdk-ftdidevicecontrol".to_string(), + "4x4/sdk-FtdiDeviceControl".to_string(), + ); + + let ctx = ExecutionContext { + source_directory: workspace.path().to_path_buf(), + allowed_repositories, + repository_name: Some("4x4/current-repo".to_string()), + ..Default::default() + }; + + let resolved = repository_checkout_dir("4x4/sdk-ftdidevicecontrol", &ctx).unwrap(); + + assert_eq!(resolved, alias_dir); + } } diff --git a/src/sanitize.rs b/src/sanitize.rs index 554dea66..8456b4f6 100644 --- a/src/sanitize.rs +++ b/src/sanitize.rs @@ -16,6 +16,7 @@ //! identifiers like area paths, wiki names, or assignee emails. use log::debug; +use std::ops::Range; /// Trait for types that contain untrusted agent-generated text fields. /// @@ -260,6 +261,23 @@ fn remove_xml_comments(input: &str) -> String { /// Convert HTML/XML tags to safe HTML entities (IS-06). fn escape_html_tags(input: &str) -> String { + let protected = markdown_protected_ranges(input); + if protected.is_empty() { + return escape_html_fragment(input); + } + + let mut result = String::with_capacity(input.len()); + let mut cursor = 0; + for range in protected { + result.push_str(&escape_html_fragment(&input[cursor..range.start])); + result.push_str(&input[range.start..range.end]); + cursor = range.end; + } + result.push_str(&escape_html_fragment(&input[cursor..])); + result +} + +fn escape_html_fragment(input: &str) -> String { let mut result = String::with_capacity(input.len()); let mut rest = input; @@ -281,6 +299,156 @@ fn escape_html_tags(input: &str) -> String { result } +fn markdown_protected_ranges(input: &str) -> Vec> { + let fence_ranges = fenced_code_ranges(input); + let mut ranges = Vec::new(); + let mut cursor = 0; + + for fence in &fence_ranges { + if cursor < fence.start { + collect_inline_code_ranges(input, cursor, fence.start, &mut ranges); + } + ranges.push(fence.clone()); + cursor = fence.end; + } + + if cursor < input.len() { + collect_inline_code_ranges(input, cursor, input.len(), &mut ranges); + } + + ranges +} + +fn fenced_code_ranges(input: &str) -> Vec> { + let mut ranges = Vec::new(); + let mut line_start = 0; + + while line_start < input.len() { + let (line_end, next_line_start) = line_bounds(input, line_start); + let line = &input[line_start..line_end]; + + if let Some((marker, count)) = parse_fence_opener(line) + && let Some(block_end) = find_matching_fence_end(input, next_line_start, marker, count) + { + ranges.push(line_start..block_end); + line_start = block_end; + continue; + } + + line_start = next_line_start; + } + + ranges +} + +fn collect_inline_code_ranges( + input: &str, + start: usize, + end: usize, + ranges: &mut Vec>, +) { + let bytes = input.as_bytes(); + let mut i = start; + + while i < end { + if bytes[i] != b'`' { + i += 1; + continue; + } + + let tick_count = count_repeated_byte(bytes, i, end, b'`'); + let inline_code_boundary = input[i..end] + .find('\n') + .map(|offset| i + offset) + .unwrap_or(end); + let mut cursor = i + tick_count; + let mut matched_end = None; + + while cursor < inline_code_boundary { + if bytes[cursor] == b'`' { + let candidate_count = + count_repeated_byte(bytes, cursor, inline_code_boundary, b'`'); + if candidate_count == tick_count { + matched_end = Some(cursor + candidate_count); + break; + } + cursor += candidate_count; + } else { + cursor += 1; + } + } + + if let Some(span_end) = matched_end { + ranges.push(i..span_end); + i = span_end; + } else { + i += tick_count; + } + } +} + +fn line_bounds(input: &str, line_start: usize) -> (usize, usize) { + let line_end = input[line_start..] + .find('\n') + .map(|offset| line_start + offset) + .unwrap_or(input.len()); + let next_line_start = if line_end < input.len() { + line_end + 1 + } else { + input.len() + }; + (line_end, next_line_start) +} + +fn parse_fence_opener(line: &str) -> Option<(u8, usize)> { + let indent = line.bytes().take_while(|b| *b == b' ').count(); + if indent > 3 { + return None; + } + + let rest = &line.as_bytes()[indent..]; + let marker = *rest.first()?; + if marker != b'`' && marker != b'~' { + return None; + } + + let count = rest.iter().take_while(|&&b| b == marker).count(); + (count >= 3).then_some((marker, count)) +} + +fn find_matching_fence_end( + input: &str, + mut line_start: usize, + marker: u8, + min_count: usize, +) -> Option { + while line_start < input.len() { + let (line_end, next_line_start) = line_bounds(input, line_start); + let line = &input[line_start..line_end]; + let indent = line.bytes().take_while(|b| *b == b' ').count(); + + if indent <= 3 { + let rest = &line.as_bytes()[indent..]; + let count = rest.iter().take_while(|&&b| b == marker).count(); + if count >= min_count && rest[count..].iter().all(|b| matches!(b, b' ' | b'\t')) { + return Some(next_line_start); + } + } + + line_start = next_line_start; + } + + None +} + +fn count_repeated_byte(bytes: &[u8], start: usize, end: usize, byte: u8) -> usize { + let mut count = 0; + while start + count < end && bytes[start + count] == byte { + count += 1; + } + count +} + // ── IS-07b: URL protocol sanitization ────────────────────────────────────── /// Strip unsafe URL protocols (javascript:, data:, file:, vbscript:). @@ -586,6 +754,33 @@ mod tests { assert_eq!(sanitize(input), input); } + #[test] + fn test_escape_html_tags_preserves_inline_code_spans() { + let input = "Use `` and bold."; + assert_eq!( + escape_html_tags(input), + "Use `` and <b>bold</b>." + ); + } + + #[test] + fn test_escape_html_tags_preserves_fenced_code_blocks() { + let input = "```suggestion\nif (a < b) {\n return;\n}\n```\n
tail
"; + assert_eq!( + escape_html_tags(input), + "```suggestion\nif (a < b) {\n return;\n}\n```\n<div>tail</div>" + ); + } + + #[test] + fn test_escape_html_tags_unmatched_inline_backtick_does_not_disable_escaping() { + let input = "Unclosed `code still escaped"; + assert_eq!( + escape_html_tags(input), + "Unclosed `code <b>still escaped</b>" + ); + } + // ── sanitize_config tests ───────────────────────────────────────────── #[test]