Skip to content

Commit a6b375c

Browse files
committed
fix(functions-pull): preserve prompt serialization and identity metadata
1 parent f4c9a73 commit a6b375c

1 file changed

Lines changed: 122 additions & 31 deletions

File tree

src/functions/pull.rs

Lines changed: 122 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ struct PullFunctionRow {
5151

5252
#[derive(Debug, Clone)]
5353
struct NormalizedPrompt {
54+
id: String,
55+
version: Option<String>,
5456
variable_seed: String,
5557
name: String,
5658
slug: String,
@@ -60,6 +62,8 @@ struct NormalizedPrompt {
6062
model: Option<Value>,
6163
params: Option<Value>,
6264
tools: Option<Value>,
65+
raw_tools_json: Option<String>,
66+
tool_functions: Option<Value>,
6367
}
6468

6569
#[derive(Debug)]
@@ -374,6 +378,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> {
374378
&repo,
375379
args.force,
376380
args.language,
381+
&project_id,
377382
&project_name,
378383
&file_name,
379384
&rows,
@@ -615,6 +620,7 @@ fn write_pull_file(
615620
repo: &Option<GitRepo>,
616621
force: bool,
617622
language: FunctionsLanguage,
623+
project_id: &str,
618624
project_name: &str,
619625
file_name: &str,
620626
rows: &[PullFunctionRow],
@@ -655,18 +661,19 @@ fn write_pull_file(
655661
return;
656662
}
657663

658-
let rendered = match render_project_file(language, project_name, &display_target, rows) {
659-
Ok(rendered) => rendered,
660-
Err(err) => {
661-
record_pull_file_failure(
662-
summary,
663-
target.display().to_string(),
664-
HardFailureReason::ResponseInvalid,
665-
err.to_string(),
666-
);
667-
return;
668-
}
669-
};
664+
let rendered =
665+
match render_project_file(language, project_id, project_name, &display_target, rows) {
666+
Ok(rendered) => rendered,
667+
Err(err) => {
668+
record_pull_file_failure(
669+
summary,
670+
target.display().to_string(),
671+
HardFailureReason::ResponseInvalid,
672+
err.to_string(),
673+
);
674+
return;
675+
}
676+
};
670677
match write_text_atomic(&target, &rendered) {
671678
Ok(()) => {
672679
summary.files_written += 1;
@@ -959,6 +966,7 @@ fn display_output_path(target: &Path) -> String {
959966

960967
fn render_project_file(
961968
language: FunctionsLanguage,
969+
project_id: &str,
962970
project_name: &str,
963971
file_name: &str,
964972
rows: &[PullFunctionRow],
@@ -973,7 +981,7 @@ fn render_project_file(
973981

974982
match language {
975983
FunctionsLanguage::Typescript => {
976-
render_project_file_ts(project_name, file_name, &normalized)
984+
render_project_file_ts(project_id, project_name, file_name, &normalized)
977985
}
978986
FunctionsLanguage::Python => render_project_file_py(project_name, file_name, &normalized),
979987
}
@@ -1042,27 +1050,31 @@ fn normalize_prompt_row(row: &PullFunctionRow) -> Result<NormalizedPrompt> {
10421050
.filter(|value| !is_empty_render_value(value))
10431051
.cloned();
10441052

1045-
let mut tools: Vec<Value> = prompt_data
1046-
.get("tool_functions")
1047-
.and_then(Value::as_array)
1048-
.cloned()
1049-
.unwrap_or_default();
1050-
if let Some(raw_tools) = prompt_block.get("tools").and_then(Value::as_str) {
1051-
if !raw_tools.trim().is_empty() {
1052-
if let Ok(parsed) = serde_json::from_str::<Value>(raw_tools) {
1053-
if let Some(items) = parsed.as_array() {
1054-
tools.extend(items.iter().cloned());
1055-
}
1056-
}
1057-
}
1058-
}
1059-
let tools = if tools.is_empty() {
1060-
None
1061-
} else {
1062-
Some(Value::Array(tools))
1053+
let raw_tools_json = prompt_block
1054+
.get("tools")
1055+
.and_then(Value::as_str)
1056+
.map(str::trim)
1057+
.filter(|value| !value.is_empty())
1058+
.map(ToOwned::to_owned);
1059+
let tools = match prompt_block.get("tools") {
1060+
Some(Value::String(_)) | None => None,
1061+
Some(other) if is_empty_render_value(other) => None,
1062+
Some(other) => Some(other.clone()),
10631063
};
1064+
let tool_functions = prompt_data
1065+
.get("tool_functions")
1066+
.filter(|value| !is_empty_render_value(value))
1067+
.cloned();
1068+
let version = row
1069+
._xact_id
1070+
.as_deref()
1071+
.map(str::trim)
1072+
.filter(|value| !value.is_empty())
1073+
.map(ToOwned::to_owned);
10641074

10651075
Ok(NormalizedPrompt {
1076+
id: row.id.clone(),
1077+
version,
10661078
variable_seed: row.slug.clone(),
10671079
name: row.name.clone(),
10681080
slug: row.slug.clone(),
@@ -1072,10 +1084,13 @@ fn normalize_prompt_row(row: &PullFunctionRow) -> Result<NormalizedPrompt> {
10721084
model,
10731085
params,
10741086
tools,
1087+
raw_tools_json,
1088+
tool_functions,
10751089
})
10761090
}
10771091

10781092
fn render_project_file_ts(
1093+
project_id: &str,
10791094
project_name: &str,
10801095
file_name: &str,
10811096
prompts: &[NormalizedPrompt],
@@ -1098,6 +1113,7 @@ fn render_project_file_ts(
10981113

10991114
out.push_str("import braintrust from \"braintrust\";\n\n");
11001115
out.push_str("const project = braintrust.projects.create({\n");
1116+
out.push_str(&format!(" id: {},\n", serde_json::to_string(project_id)?));
11011117
out.push_str(&format!(
11021118
" name: {},\n",
11031119
serde_json::to_string(project_name)?
@@ -1114,8 +1130,12 @@ fn render_project_file_ts(
11141130
);
11151131

11161132
let mut body_lines = Vec::new();
1133+
body_lines.push(format!(" id: {},", serde_json::to_string(&row.id)?));
11171134
body_lines.push(format!(" name: {},", serde_json::to_string(&row.name)?));
11181135
body_lines.push(format!(" slug: {},", serde_json::to_string(&row.slug)?));
1136+
if let Some(version) = &row.version {
1137+
body_lines.push(format!(" version: {},", serde_json::to_string(version)?));
1138+
}
11191139

11201140
if let Some(description) = &row.description {
11211141
body_lines.push(format!(
@@ -1139,6 +1159,18 @@ fn render_project_file_ts(
11391159
if let Some(tools) = &row.tools {
11401160
body_lines.push(format!(" tools: {},", format_ts_value(tools, 2)));
11411161
}
1162+
if let Some(raw_tools_json) = &row.raw_tools_json {
1163+
body_lines.push(format!(
1164+
" tools: JSON.parse({}),",
1165+
serde_json::to_string(raw_tools_json)?
1166+
));
1167+
}
1168+
if let Some(tool_functions) = &row.tool_functions {
1169+
body_lines.push(format!(
1170+
" toolFunctions: {},",
1171+
format_ts_value(tool_functions, 2)
1172+
));
1173+
}
11421174

11431175
out.push_str(&format!(
11441176
"export const {var_name} = project.prompts.create({{\n"
@@ -1155,6 +1187,7 @@ fn render_project_file_py(
11551187
file_name: &str,
11561188
prompts: &[NormalizedPrompt],
11571189
) -> Result<String> {
1190+
let needs_json_import = prompts.iter().any(|row| row.raw_tools_json.is_some());
11581191
let mut out = String::new();
11591192
out.push_str("# This file was automatically generated by bt functions pull. You can\n");
11601193
out.push_str("# generate it again by running:\n");
@@ -1170,6 +1203,9 @@ fn render_project_file_py(
11701203
"# $ bt functions push --file {}\n\n",
11711204
serde_json::to_string(file_name)?
11721205
));
1206+
if needs_json_import {
1207+
out.push_str("import json\n");
1208+
}
11731209
out.push_str("import braintrust\n\n");
11741210
out.push_str(&format!(
11751211
"project = braintrust.projects.create(name={})\n\n",
@@ -1213,6 +1249,18 @@ fn render_project_file_py(
12131249
if let Some(tools) = &row.tools {
12141250
out.push_str(&format!(" tools={},\n", format_py_value(tools, 4)));
12151251
}
1252+
if let Some(raw_tools_json) = &row.raw_tools_json {
1253+
out.push_str(&format!(
1254+
" tools=json.loads({}),\n",
1255+
format_py_value(&Value::String(raw_tools_json.clone()), 4)
1256+
));
1257+
}
1258+
if let Some(tool_functions) = &row.tool_functions {
1259+
out.push_str(&format!(
1260+
" tool_functions={},\n",
1261+
format_py_value(tool_functions, 4)
1262+
));
1263+
}
12161264
out.push_str(")\n\n");
12171265
}
12181266

@@ -1738,6 +1786,7 @@ mod tests {
17381786

17391787
let rendered = render_project_file(
17401788
FunctionsLanguage::Python,
1789+
"p1",
17411790
"woohoo",
17421791
"braintrust/woohoo.py",
17431792
&[row],
@@ -1753,6 +1802,48 @@ mod tests {
17531802
assert!(rendered.contains("model=\"gpt-4o-mini\""));
17541803
}
17551804

1805+
#[test]
1806+
fn render_project_file_typescript_includes_prompt_identity() {
1807+
let row = PullFunctionRow {
1808+
id: "f1".to_string(),
1809+
name: "Basic math".to_string(),
1810+
slug: "basic-math".to_string(),
1811+
project_id: "p1".to_string(),
1812+
project_name: Some("woohoo".to_string()),
1813+
description: None,
1814+
prompt_data: Some(serde_json::json!({
1815+
"prompt": {
1816+
"type": "chat",
1817+
"messages": [
1818+
{ "content": "Hello", "role": "system" }
1819+
]
1820+
},
1821+
"options": {
1822+
"model": "gpt-4o-mini"
1823+
}
1824+
})),
1825+
function_data: Some(serde_json::json!({ "type": "prompt" })),
1826+
created: None,
1827+
_xact_id: Some("123".to_string()),
1828+
};
1829+
1830+
let rendered = render_project_file(
1831+
FunctionsLanguage::Typescript,
1832+
"p1",
1833+
"woohoo",
1834+
"braintrust/woohoo.ts",
1835+
&[row],
1836+
)
1837+
.expect("rendered");
1838+
1839+
assert!(rendered.contains("const project = braintrust.projects.create({"));
1840+
assert!(rendered.contains(" id: \"p1\","));
1841+
assert!(rendered.contains(" name: \"woohoo\","));
1842+
assert!(rendered.contains("export const basicMath = project.prompts.create({"));
1843+
assert!(rendered.contains(" id: \"f1\","));
1844+
assert!(rendered.contains(" version: \"123\","));
1845+
}
1846+
17561847
#[test]
17571848
fn format_ts_value_unquotes_safe_keys_only() {
17581849
let value = serde_json::json!({

0 commit comments

Comments
 (0)