Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 90 additions & 3 deletions src-tauri/src/proxy/providers/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
let mut has_sent_message_start = false;
let mut current_non_tool_block_type: Option<&'static str> = None;
let mut current_non_tool_block_index: Option<u32> = None;
let mut pending_leading_text = String::new();
let mut tool_blocks_by_index: HashMap<usize, ToolBlockState> = HashMap::new();
let mut open_tool_block_indices: HashSet<u32> = HashSet::new();

Expand Down Expand Up @@ -180,6 +181,7 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(

// 处理 reasoning(thinking)
if let Some(reasoning) = &choice.delta.reasoning {
pending_leading_text.clear();
if current_non_tool_block_type != Some("thinking") {
if let Some(index) = current_non_tool_block_index.take() {
let event = json!({
Expand Down Expand Up @@ -225,6 +227,14 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
// 处理文本内容
if let Some(content) = &choice.delta.content {
if !content.is_empty() {
if current_non_tool_block_index.is_none()
&& current_non_tool_block_type.is_none()
&& content.trim().is_empty()
{
pending_leading_text.push_str(content);
continue;
}

if current_non_tool_block_type != Some("text") {
if let Some(index) = current_non_tool_block_index.take() {
let event = json!({
Expand Down Expand Up @@ -254,12 +264,20 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
}

if let Some(index) = current_non_tool_block_index {
let text = if pending_leading_text.is_empty() {
content.clone()
} else {
let mut text =
std::mem::take(&mut pending_leading_text);
text.push_str(content);
text
};
let event = json!({
"type": "content_block_delta",
"index": index,
"delta": {
"type": "text_delta",
"text": content
"text": text
}
});
let sse_data = format!("event: content_block_delta\ndata: {}\n\n",
Expand All @@ -271,6 +289,7 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(

// 处理工具调用
if let Some(tool_calls) = &choice.delta.tool_calls {
pending_leading_text.clear();
if let Some(index) = current_non_tool_block_index.take() {
let event = json!({
"type": "content_block_stop",
Expand Down Expand Up @@ -381,7 +400,8 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
"content_block": {
"type": "tool_use",
"id": id,
"name": name
"name": name,
"input": {}
}
});
let sse_data = format!("event: content_block_start\ndata: {}\n\n",
Expand Down Expand Up @@ -473,7 +493,8 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
"content_block": {
"type": "tool_use",
"id": id,
"name": name
"name": name,
"input": {}
}
});
let sse_data = format!("event: content_block_start\ndata: {}\n\n",
Expand Down Expand Up @@ -658,6 +679,12 @@ mod tests {
event.pointer("/content_block/id").and_then(|v| v.as_str()),
event.get("index").and_then(|v| v.as_u64()),
) {
assert!(
event
.pointer("/content_block/input")
.is_some_and(Value::is_object),
"tool_use content_block_start must include an empty input object"
);
tool_index_by_call.insert(call_id.to_string(), index);
}
}
Expand Down Expand Up @@ -760,6 +787,12 @@ mod tests {
.unwrap_or(""),
"first_tool"
);
assert!(
starts[0]
.pointer("/content_block/input")
.is_some_and(Value::is_object),
"late-started tool_use content_block_start must include input"
);

let deltas: Vec<&str> = events
.iter()
Expand All @@ -778,6 +811,60 @@ mod tests {
assert!(deltas.contains(&"1}"));
}

#[tokio::test]
async fn test_streaming_drops_whitespace_only_text_before_tool_call() {
let input = concat!(
"data: {\"id\":\"chatcmpl_4\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"\\n\\n\"}}]}\n\n",
"data: {\"id\":\"chatcmpl_4\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\" \"}}]}\n\n",
"data: {\"id\":\"chatcmpl_4\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_0\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"{\\\"q\\\":\\\"x\\\"}\"}}]}}]}\n\n",
"data: {\"id\":\"chatcmpl_4\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":3}}\n\n",
"data: [DONE]\n\n"
);

let upstream = stream::iter(vec![Ok::<_, std::io::Error>(Bytes::from(
input.as_bytes().to_vec(),
))]);
let converted = create_anthropic_sse_stream(upstream);
let chunks: Vec<_> = converted.collect().await;
let merged = chunks
.into_iter()
.map(|chunk| String::from_utf8_lossy(chunk.unwrap().as_ref()).to_string())
.collect::<String>();

let events: Vec<Value> = merged
.split("\n\n")
.filter_map(|block| {
let data = block
.lines()
.find_map(|line| strip_sse_field(line, "data"))?;
serde_json::from_str::<Value>(data).ok()
})
.collect();

let text_starts = events
.iter()
.filter(|event| {
event.get("type").and_then(|v| v.as_str()) == Some("content_block_start")
&& event
.pointer("/content_block/type")
.and_then(|v| v.as_str())
== Some("text")
})
.count();
assert_eq!(
text_starts, 0,
"leading whitespace must not create a standalone text block before tool_use"
);

assert!(events.iter().any(|event| {
event.get("type").and_then(|v| v.as_str()) == Some("content_block_start")
&& event
.pointer("/content_block/type")
.and_then(|v| v.as_str())
== Some("tool_use")
}));
}

#[tokio::test]
async fn test_streaming_chinese_split_across_chunks_no_replacement_chars() {
// "你好" split across two TCP chunks inside a streaming text delta.
Expand Down
84 changes: 43 additions & 41 deletions src-tauri/src/proxy/providers/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,9 @@ pub fn anthropic_to_openai(body: Value) -> Result<Value, ProxyError> {
// 单个字符串
messages.push(json!({"role": "system", "content": text}));
} else if let Some(arr) = system.as_array() {
// 多个 system message — preserve cache_control for compatible proxies
for msg in arr {
if let Some(text) = msg.get("text").and_then(|t| t.as_str()) {
let mut sys_msg = json!({"role": "system", "content": text});
if let Some(cc) = msg.get("cache_control") {
sys_msg["cache_control"] = cc.clone();
}
messages.push(sys_msg);
messages.push(json!({"role": "system", "content": text}));
}
}
}
Expand Down Expand Up @@ -149,18 +144,14 @@ pub fn anthropic_to_openai(body: Value) -> Result<Value, ProxyError> {
.iter()
.filter(|t| t.get("type").and_then(|v| v.as_str()) != Some("BatchTool"))
.map(|t| {
let mut tool = json!({
json!({
"type": "function",
"function": {
"name": t.get("name").and_then(|n| n.as_str()).unwrap_or(""),
"description": t.get("description"),
"parameters": clean_schema(t.get("input_schema").cloned().unwrap_or(json!({})))
}
});
if let Some(cc) = t.get("cache_control") {
tool["cache_control"] = cc.clone();
}
tool
})
})
.collect();

Expand Down Expand Up @@ -316,11 +307,7 @@ fn convert_message_to_openai(
match block_type {
"text" => {
if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
let mut part = json!({"type": "text", "text": text});
if let Some(cc) = block.get("cache_control") {
part["cache_control"] = cc.clone();
}
content_parts.push(part);
content_parts.push(json!({"type": "text", "text": text}));
}
}
"image" => {
Expand Down Expand Up @@ -382,17 +369,21 @@ fn convert_message_to_openai(
if content_parts.is_empty() {
msg["content"] = Value::Null;
} else if content_parts.len() == 1 {
// When cache_control is present, keep array format to preserve it
let has_cache_control = content_parts[0].get("cache_control").is_some();
if !has_cache_control {
if let Some(text) = content_parts[0].get("text") {
msg["content"] = text.clone();
} else {
msg["content"] = json!(content_parts);
}
if let Some(text) = content_parts[0].get("text") {
msg["content"] = text.clone();
} else {
msg["content"] = json!(content_parts);
}
} else if content_parts
.iter()
.all(|part| part.get("type").and_then(|v| v.as_str()) == Some("text"))
{
let text = content_parts
.iter()
.filter_map(|part| part.get("text").and_then(|v| v.as_str()))
.collect::<Vec<_>>()
.join("\n");
msg["content"] = json!(text);
} else {
msg["content"] = json!(content_parts);
}
Expand Down Expand Up @@ -663,7 +654,7 @@ mod tests {
}

#[test]
fn test_anthropic_to_openai_preserves_matching_system_cache_control_when_merging() {
fn test_anthropic_to_openai_drops_matching_system_cache_control_when_merging() {
let input = json!({
"model": "claude-3-sonnet",
"max_tokens": 1024,
Expand All @@ -681,7 +672,7 @@ mod tests {
result["messages"][0]["content"],
"You are Claude Code.\nBe concise."
);
assert_eq!(result["messages"][0]["cache_control"]["type"], "ephemeral");
assert!(result["messages"][0].get("cache_control").is_none());
assert_eq!(result["messages"][1]["role"], "user");
}

Expand Down Expand Up @@ -849,6 +840,27 @@ mod tests {
assert_eq!(msg["content"], "Sunny, 25°C");
}

/// A user message whose content consists solely of text blocks is expected
/// to be converted into a single string, with the block texts joined by a
/// newline.
#[test]
fn test_anthropic_to_openai_flattens_text_only_content_blocks() {
    // Two consecutive text blocks and nothing else.
    let text_blocks = json!([
        {"type": "text", "text": "Caveat: local command output follows."},
        {"type": "text", "text": "Tell me the settings path."}
    ]);
    let request = json!({
        "model": "moonshotai/kimi-k2.5",
        "max_tokens": 1024,
        "messages": [{
            "role": "user",
            "content": text_blocks
        }]
    });

    let converted = anthropic_to_openai(request).unwrap();
    let flattened = &converted["messages"][0]["content"];

    assert_eq!(
        *flattened,
        "Caveat: local command output follows.\nTell me the settings path."
    );
}

#[test]
fn test_openai_to_anthropic_simple() {
let input = json!({
Expand Down Expand Up @@ -931,7 +943,7 @@ mod tests {
}

#[test]
fn test_anthropic_to_openai_cache_control_preserved() {
fn test_anthropic_to_openai_drops_cache_control() {
let input = json!({
"model": "claude-3-opus",
"max_tokens": 1024,
Expand All @@ -953,19 +965,9 @@ mod tests {
});

let result = anthropic_to_openai(input).unwrap();
// System message cache_control preserved
assert_eq!(result["messages"][0]["cache_control"]["type"], "ephemeral");
// Text block cache_control preserved
assert_eq!(
result["messages"][1]["content"][0]["cache_control"]["type"],
"ephemeral"
);
assert_eq!(
result["messages"][1]["content"][0]["cache_control"]["ttl"],
"5m"
);
// Tool cache_control preserved
assert_eq!(result["tools"][0]["cache_control"]["type"], "ephemeral");
assert!(result["messages"][0].get("cache_control").is_none());
assert_eq!(result["messages"][1]["content"], "Hello");
assert!(result["tools"][0].get("cache_control").is_none());
}

#[test]
Expand Down