From 2944177e5924c32a740775e632bbd088ea747c79 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Tue, 23 Dec 2025 22:09:03 -0800 Subject: [PATCH 1/2] fix: handle expressions in shard key UPDATE statements --- .../query_engine/multi_step/test/update.rs | 132 ++++++++++++++++-- .../router/parser/rewrite/statement/update.rs | 124 +++++++++++++++- 2 files changed, 237 insertions(+), 19 deletions(-) diff --git a/pgdog/src/frontend/client/query_engine/multi_step/test/update.rs b/pgdog/src/frontend/client/query_engine/multi_step/test/update.rs index 6588d88b..80adcca4 100644 --- a/pgdog/src/frontend/client/query_engine/multi_step/test/update.rs +++ b/pgdog/src/frontend/client/query_engine/multi_step/test/update.rs @@ -223,9 +223,13 @@ async fn test_transaction_required() { async fn test_move_rows_simple() { let mut client = TestClient::new_rewrites(Parameters::default()).await; + let shard_0_id = client.random_id_for_shard(0); + let shard_1_id = client.random_id_for_shard(1); + client .send_simple(Query::new(format!( - "INSERT INTO sharded (id) VALUES (1) ON CONFLICT(id) DO NOTHING", + "INSERT INTO sharded (id) VALUES ({}) ON CONFLICT(id) DO NOTHING", + shard_0_id ))) .await; client.read_until('Z').await.unwrap(); @@ -234,14 +238,16 @@ async fn test_move_rows_simple() { client.read_until('Z').await.unwrap(); client - .try_send_simple(Query::new( - "UPDATE sharded SET id = 11 WHERE id = 1 RETURNING id", - )) + .try_send_simple(Query::new(format!( + "UPDATE sharded SET id = {} WHERE id = {} RETURNING id", + shard_1_id, shard_0_id + ))) .await .unwrap(); let reply = client.read_until('Z').await.unwrap(); + let shard_1_id_str = shard_1_id.to_string(); reply .into_iter() .zip(['T', 'D', 'C', 'Z']) @@ -266,7 +272,7 @@ async fn test_move_rows_simple() { ), 'D' => assert_eq!( DataRow::try_from(message).unwrap().column(0).unwrap(), - "11".as_bytes() + shard_1_id_str.as_bytes() ), _ => unreachable!(), } @@ -277,9 +283,13 @@ async fn test_move_rows_simple() { async fn 
test_move_rows_extended() { let mut client = TestClient::new_rewrites(Parameters::default()).await; + let shard_0_id = client.random_id_for_shard(0); + let shard_1_id = client.random_id_for_shard(1); + client .send_simple(Query::new(format!( - "INSERT INTO sharded (id) VALUES (1) ON CONFLICT(id) DO NOTHING", + "INSERT INTO sharded (id) VALUES ({}) ON CONFLICT(id) DO NOTHING", + shard_0_id ))) .await; client.read_until('Z').await.unwrap(); @@ -296,8 +306,8 @@ async fn test_move_rows_extended() { .send(Bind::new_params( "", &[ - Parameter::new("1".as_bytes()), - Parameter::new("11".as_bytes()), + Parameter::new(shard_0_id.to_string().as_bytes()), + Parameter::new(shard_1_id.to_string().as_bytes()), ], )) .await; @@ -307,6 +317,7 @@ async fn test_move_rows_extended() { let reply = client.read_until('Z').await.unwrap(); + let shard_1_id_str = shard_1_id.to_string(); reply .into_iter() .zip(['1', '2', 'D', 'C', 'Z']) @@ -323,7 +334,7 @@ async fn test_move_rows_extended() { ), 'D' => assert_eq!( DataRow::try_from(message).unwrap().column(0).unwrap(), - "11".as_bytes() + shard_1_id_str.as_bytes() ), '1' | '2' => (), _ => unreachable!(), @@ -336,9 +347,13 @@ async fn test_move_rows_prepared() { crate::logger(); let mut client = TestClient::new_rewrites(Parameters::default()).await; + let shard_0_id = client.random_id_for_shard(0); + let shard_1_id = client.random_id_for_shard(1); + client .send_simple(Query::new(format!( - "INSERT INTO sharded (id) VALUES (1) ON CONFLICT(id) DO NOTHING", + "INSERT INTO sharded (id) VALUES ({}) ON CONFLICT(id) DO NOTHING", + shard_0_id ))) .await; client.read_until('Z').await.unwrap(); @@ -383,8 +398,8 @@ async fn test_move_rows_prepared() { .send(Bind::new_params( "__test_1", &[ - Parameter::new("1".as_bytes()), - Parameter::new("11".as_bytes()), + Parameter::new(shard_0_id.to_string().as_bytes()), + Parameter::new(shard_1_id.to_string().as_bytes()), ], )) .await; @@ -394,6 +409,7 @@ async fn test_move_rows_prepared() { let reply = 
client.read_until('Z').await.unwrap(); + let shard_1_id_str = shard_1_id.to_string(); reply .into_iter() .zip(['2', 'D', 'C', 'Z']) @@ -410,7 +426,7 @@ async fn test_move_rows_prepared() { ), 'D' => assert_eq!( DataRow::try_from(message).unwrap().column(0).unwrap(), - "11".as_bytes() + shard_1_id_str.as_bytes() ), '1' | '2' => (), _ => unreachable!(), @@ -463,3 +479,93 @@ async fn test_same_shard_binary() { } }); } + +#[tokio::test] +async fn test_update_with_expr() { + // Test that UPDATE with expression columns (not simple values) works correctly. + // This validates the bind parameter alignment fix where expression columns + // don't consume bind parameter slots. + // + // Note: Expressions that reference the original row's columns (like COALESCE(value, 'default')) + // won't work because they're inserted literally into the INSERT statement where those + // columns don't exist. Only standalone expressions like 'prefix' || 'suffix' work. + let mut client = TestClient::new_rewrites(Parameters::default()).await; + + // Use random IDs to avoid conflicts with other tests + let shard_0_id = client.random_id_for_shard(0); + let shard_1_id = client.random_id_for_shard(1); + + // Insert a row into shard 0 + client + .send_simple(Query::new(format!( + "INSERT INTO sharded (id, value) VALUES ({}, 'original') ON CONFLICT(id) DO UPDATE SET value = 'original'", + shard_0_id + ))) + .await; + client.read_until('Z').await.unwrap(); + + client.send_simple(Query::new("BEGIN")).await; + client.read_until('Z').await.unwrap(); + + // UPDATE that moves row to different shard with an expression column. + // Use a standalone expression that doesn't reference any columns. 
+ client + .try_send_simple(Query::new(format!( + "UPDATE sharded SET id = {}, value = 'prefix' || '_suffix' WHERE id = {} RETURNING id, value", + shard_1_id, shard_0_id + ))) + .await + .unwrap(); + + let reply = client.read_until('Z').await.unwrap(); + + let shard_1_id_str = shard_1_id.to_string(); + reply + .into_iter() + .zip(['T', 'D', 'C', 'Z']) + .for_each(|(message, code)| { + assert_eq!(message.code(), code); + match code { + 'C' => assert_eq!( + CommandComplete::try_from(message).unwrap().command(), + "UPDATE 1" + ), + 'Z' => assert!( + ReadyForQuery::try_from(message).unwrap().state().unwrap() + == TransactionState::InTrasaction + ), + 'T' => { + let rd = RowDescription::try_from(message).unwrap(); + assert_eq!(rd.field(0).unwrap().name, "id"); + assert_eq!(rd.field(1).unwrap().name, "value"); + } + 'D' => { + let dr = DataRow::try_from(message).unwrap(); + assert_eq!(dr.column(0).unwrap(), shard_1_id_str.as_bytes()); + // The value should be 'prefix_suffix' from the expression + assert_eq!(dr.column(1).unwrap(), "prefix_suffix".as_bytes()); + } + _ => unreachable!(), + } + }); + + client.send_simple(Query::new("COMMIT")).await; + client.read_until('Z').await.unwrap(); + + // Verify the row was actually moved to the new shard with correct values + client + .send_simple(Query::new(format!( + "SELECT id, value FROM sharded WHERE id = {}", + shard_1_id + ))) + .await; + let reply = client.read_until('Z').await.unwrap(); + + let data_row = reply + .iter() + .find(|m| m.code() == 'D') + .expect("should have data row"); + let dr = DataRow::try_from(data_row.clone()).unwrap(); + assert_eq!(dr.column(0).unwrap(), shard_1_id_str.as_bytes()); + assert_eq!(dr.column(1).unwrap(), "prefix_suffix".as_bytes()); +} diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/update.rs b/pgdog/src/frontend/router/parser/rewrite/statement/update.rs index 93eb329a..d8c123e2 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/update.rs +++ 
b/pgdog/src/frontend/router/parser/rewrite/statement/update.rs @@ -139,13 +139,14 @@ impl Insert { let mut columns_str = vec![]; let mut values_str = vec![]; - for (idx, field) in row_description.iter().enumerate() { + let mut bind_idx = 0; + for (row_idx, field) in row_description.iter().enumerate() { columns_str.push(format!(r#""{}""#, field.name.replace("\"", "\"\""))); // Escape " if let Some(value) = self.mapping.get(&field.name) { let value = match value { UpdateValue::Value(value) => { - values_str.push(format!("${}", idx + 1)); + values_str.push(format!("${}", bind_idx + 1)); Value::try_from(value).unwrap() // SAFETY: We check that the value is valid. } UpdateValue::Expr(expr) => { @@ -186,7 +187,9 @@ impl Insert { Value::Null => bind.push_param(Parameter::new_null(), Format::Text), } } else { - let value = data_row.get_raw(idx).ok_or(Error::MissingColumn(idx))?; + let value = data_row + .get_raw(row_idx) + .ok_or(Error::MissingColumn(row_idx))?; if value.is_null() { bind.push_param(Parameter::new_null(), Format::Text); @@ -194,7 +197,7 @@ impl Insert { bind.push_param(Parameter::new(&value), Format::Text); } - values_str.push(format!("${}", idx + 1)); + values_str.push(format!("${}", bind_idx + 1)); } columns.push(Node { @@ -206,10 +209,12 @@ impl Insert { values.push(Node { node: Some(NodeEnum::ParamRef(ParamRef { - number: idx as i32 + 1, + number: bind_idx as i32 + 1, ..Default::default() })), }); + + bind_idx += 1; } let insert = InsertStmt { @@ -442,7 +447,7 @@ fn res_targets_to_insert_res_targets( let value = if valid { UpdateValue::Value(*target.val.clone().unwrap()) } else { - UpdateValue::Expr(target.val.as_ref().unwrap().deparse()?) + UpdateValue::Expr(deparse_expr(target.val.as_ref().unwrap())?) 
}; result.insert(target.name.clone(), value); } @@ -627,6 +632,7 @@ mod test { use pgdog_config::{Rewrite, ShardedTable}; use crate::backend::{replication::ShardedSchemas, ShardedTables}; + use crate::net::messages::row_description::Field; use super::*; @@ -1048,4 +1054,110 @@ mod test { Some("id, email, random()".into()) ); } + + #[test] + fn test_res_targets_to_insert_res_targets_expr_branch() { + // Test that expression assignments (non-simple values) are deparsed correctly + // and stored as UpdateValue::Expr in the insert mapping. + let result = run_test("UPDATE sharded SET id = $1, email = random() WHERE id = $2") + .unwrap() + .unwrap(); + + // The id column should be UpdateValue::Value (simple parameter) + let id_value = result.insert.mapping.get("id").unwrap(); + assert!(matches!(id_value, UpdateValue::Value(_))); + + // The email column should be UpdateValue::Expr with the deparsed expression + let email_value = result.insert.mapping.get("email").unwrap(); + match email_value { + UpdateValue::Expr(expr) => assert_eq!(expr, "random()"), + _ => panic!("Expected UpdateValue::Expr for email"), + } + } + + #[test] + fn test_res_targets_to_insert_res_targets_expr_arithmetic() { + // Test arithmetic expressions are deparsed correctly + let result = run_test("UPDATE sharded SET id = $1, counter = counter + 1 WHERE id = $2") + .unwrap() + .unwrap(); + + let counter_value = result.insert.mapping.get("counter").unwrap(); + match counter_value { + UpdateValue::Expr(expr) => assert_eq!(expr, "counter + 1"), + _ => panic!("Expected UpdateValue::Expr for counter"), + } + } + + #[test] + fn test_res_targets_to_insert_res_targets_expr_coalesce() { + // Test COALESCE expressions are deparsed correctly + let result = + run_test("UPDATE sharded SET id = $1, name = COALESCE(name, 'default') WHERE id = $2") + .unwrap() + .unwrap(); + + let name_value = result.insert.mapping.get("name").unwrap(); + match name_value { + UpdateValue::Expr(expr) => assert_eq!(expr, "COALESCE(name, 
'default')"), + _ => panic!("Expected UpdateValue::Expr for name"), + } + } + + #[test] + fn test_insert_build_request_with_expr_column() { + // Test that INSERT statement is built correctly when there are expression columns. + // The expression should appear directly in the VALUES clause. + // Use literal values (not placeholders) to avoid needing bind parameters. + let result = run_test("UPDATE sharded SET id = 42, email = random() WHERE id = 1") + .unwrap() + .unwrap(); + + // Create a mock row description matching the SELECT * result + let row_description = RowDescription::new(&[ + Field::bigint("id"), + Field::text("email"), + Field::text("other_col"), + ]); + + // Create a mock data row with values for columns not in the UPDATE SET clause + let mut data_row = DataRow::new(); + data_row.add("1"); // id - will be overwritten by mapping + data_row.add("old@example.com"); // email - will be overwritten by mapping + data_row.add("other_value"); // other_col - from existing row + + // Create a simple query request (not prepared statement) + let request = ClientRequest::from(vec![ProtocolMessage::from(Query::new( + "UPDATE sharded SET id = 42, email = random() WHERE id = 1", + ))]); + + let insert_request = result + .insert + .build_request(&request, &row_description, &data_row) + .unwrap(); + + // Get the query from the request to verify the INSERT statement + let query = insert_request.query().unwrap().unwrap(); + let stmt = query.query(); + + // The INSERT should contain the expression random() directly in VALUES + assert!( + stmt.contains("random()"), + "INSERT statement should contain the expression: {}", + stmt + ); + // Verify it's an INSERT statement + assert!( + stmt.starts_with("INSERT INTO"), + "Should be an INSERT statement: {}", + stmt + ); + // Verify parameter numbering is correct: $1 for id, random() for email, $2 for other_col + // (not $3, which would be wrong if we used row index instead of bind param index) + assert!( + stmt.contains("$1") && 
stmt.contains("$2") && !stmt.contains("$3"), + "Parameter numbering should be sequential without gaps: {}", + stmt + ); + } } From 2f899de17aa453266929d8d4392e395f3b9aaf54 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Tue, 23 Dec 2025 22:10:49 -0800 Subject: [PATCH 2/2] chore: clippy --- pgdog/src/frontend/client/query_engine/mod.rs | 3 +- .../query_engine/multi_step/forward_check.rs | 2 +- .../client/query_engine/multi_step/update.rs | 18 +++++----- .../router/parser/rewrite/statement/update.rs | 33 +++++++++---------- pgdog/src/net/messages/bind.rs | 2 +- 5 files changed, 28 insertions(+), 30 deletions(-) diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index 1cc2a981..bb2de835 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -178,8 +178,7 @@ impl QueryEngine { .client_request .route // Admin commands don't have a route. .as_mut() - .map(|route| route.take_explain()) - .flatten() + .and_then(|route| route.take_explain()) { if config().config.general.expanded_explain { self.pending_explain = Some(ExplainResponseState::new(trace)); diff --git a/pgdog/src/frontend/client/query_engine/multi_step/forward_check.rs b/pgdog/src/frontend/client/query_engine/multi_step/forward_check.rs index 3b027e4e..a79dbf42 100644 --- a/pgdog/src/frontend/client/query_engine/multi_step/forward_check.rs +++ b/pgdog/src/frontend/client/query_engine/multi_step/forward_check.rs @@ -17,7 +17,7 @@ impl ForwardCheck { pub(crate) fn new(request: &ClientRequest) -> Self { Self { codes: request.iter().map(|m| m.code()).collect(), - describe: request.iter().find(|m| m.code() == 'D').is_some(), + describe: request.iter().any(|m| m.code() == 'D'), sent: HashSet::default(), } } diff --git a/pgdog/src/frontend/client/query_engine/multi_step/update.rs b/pgdog/src/frontend/client/query_engine/multi_step/update.rs index 78da9804..2d92b1fe 100644 --- 
a/pgdog/src/frontend/client/query_engine/multi_step/update.rs +++ b/pgdog/src/frontend/client/query_engine/multi_step/update.rs @@ -43,10 +43,10 @@ impl<'a> UpdateMulti<'a> { self.engine .error_response(context, ErrorResponse::from_err(&err)) .await?; - return Ok(()); + Ok(()) } else { // These are bad, disconnecting the client. - return Err(err.into()); + Err(err) } } } @@ -57,7 +57,7 @@ impl<'a> UpdateMulti<'a> { &mut self, context: &mut QueryEngineContext<'_>, ) -> Result<(), Error> { - let mut check = self.rewrite.check.build_request(&context.client_request)?; + let mut check = self.rewrite.check.build_request(context.client_request)?; self.route(&mut check, context)?; // The new row is on the same shard as the old row @@ -100,7 +100,7 @@ impl<'a> UpdateMulti<'a> { row: Row, ) -> Result<(), Error> { let mut request = self.rewrite.insert.build_request( - &context.client_request, + context.client_request, &row.row_description, &row.data_row, )?; @@ -168,7 +168,7 @@ impl<'a> UpdateMulti<'a> { .handle_client_request(request, &mut Router::default(), false) .await?; - let mut checker = ForwardCheck::new(&context.client_request); + let mut checker = ForwardCheck::new(context.client_request); while self.engine.backend.has_more_messages() { let message = self.engine.read_server_message(context).await?; @@ -194,7 +194,7 @@ impl<'a> UpdateMulti<'a> { self.engine .backend .handle_client_request( - &context.client_request, + context.client_request, &mut self.engine.router, self.engine.streaming, ) @@ -212,7 +212,7 @@ impl<'a> UpdateMulti<'a> { &mut self, context: &mut QueryEngineContext<'_>, ) -> Result<(), Error> { - let mut request = self.rewrite.delete.build_request(&context.client_request)?; + let mut request = self.rewrite.delete.build_request(context.client_request)?; self.route(&mut request, context)?; self.execute_request_internal(context, &mut request, false) @@ -223,7 +223,7 @@ impl<'a> UpdateMulti<'a> { &mut self, context: &mut QueryEngineContext<'_>, ) -> 
Result, Error> { - let mut request = self.rewrite.select.build_request(&context.client_request)?; + let mut request = self.rewrite.select.build_request(context.client_request)?; self.route(&mut request, context)?; self.engine @@ -262,7 +262,7 @@ impl<'a> UpdateMulti<'a> { /// This is an optimization to avoid doing a multi-shard UPDATE when /// we don't have to. pub(super) fn is_same_shard(&self, context: &QueryEngineContext<'_>) -> Result { - let mut check = self.rewrite.check.build_request(&context.client_request)?; + let mut check = self.rewrite.check.build_request(context.client_request)?; self.route(&mut check, context)?; let new_shard = check.route().shard(); diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/update.rs b/pgdog/src/frontend/router/parser/rewrite/statement/update.rs index d8c123e2..f5a401c0 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/update.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/update.rs @@ -66,7 +66,7 @@ impl Statement { request.push(Parse::new_anonymous(&self.stmt).into()); request.push(Describe::new_statement("").into()); if let Some(params) = params { - request.push(self.rewrite_bind(&params)?.into()); + request.push(self.rewrite_bind(params)?.into()); request.push(Execute::new().into()); request.push(Sync.into()); } else { @@ -194,7 +194,7 @@ impl Insert { if value.is_null() { bind.push_param(Parameter::new_null(), Format::Text); } else { - bind.push_param(Parameter::new(&value), Format::Text); + bind.push_param(Parameter::new(value), Format::Text); } values_str.push(format!("${}", bind_idx + 1)); @@ -209,7 +209,7 @@ impl Insert { values.push(Node { node: Some(NodeEnum::ParamRef(ParamRef { - number: bind_idx as i32 + 1, + number: bind_idx + 1, ..Default::default() })), }); @@ -224,9 +224,9 @@ impl Insert { node: Some(NodeEnum::SelectStmt(Box::new(SelectStmt { target_list: vec![], from_clause: vec![], - limit_option: LimitOption::Default.try_into().unwrap(), + limit_option:
LimitOption::Default.into(), where_clause: None, - op: SetOperation::SetopNone.try_into().unwrap(), + op: SetOperation::SetopNone.into(), values_lists: vec![Node { node: Some(NodeEnum::List(List { items: values })), }], @@ -234,11 +234,11 @@ impl Insert { }))), })), returning_list: self.returning_list.clone(), - r#override: OverridingKind::OverridingNotSet.try_into().unwrap(), + r#override: OverridingKind::OverridingNotSet.into(), ..Default::default() }; - let table = self.table.as_ref().map(|table| Table::from(table)).unwrap(); // SAFETY: We check that UPDATE has a table. + let table = self.table.as_ref().map(Table::from).unwrap(); // SAFETY: We check that UPDATE has a table. // This is probably one of the few places in the code where // we shouldn't use the parser. It's quicker to concatenate strings @@ -295,8 +295,7 @@ impl<'a> StatementRewrite<'a> { .stmt .stmts .first() - .map(|stmt| stmt.stmt.as_ref().map(|stmt| stmt.node.as_ref())) - .flatten() + .and_then(|stmt| stmt.stmt.as_ref().map(|stmt| stmt.node.as_ref())) .flatten(); let stmt = if let Some(NodeEnum::UpdateStmt(stmt)) = stmt { @@ -534,8 +533,8 @@ fn deparse_list(list: &[Node]) -> Result, Error> { let stmt = SelectStmt { target_list: list.to_vec(), - limit_option: LimitOption::Default.try_into().unwrap(), - op: SetOperation::SetopNone.try_into().unwrap(), + limit_option: LimitOption::Default.into(), + op: SetOperation::SetopNone.into(), ..Default::default() }; let string = parse_result(NodeEnum::SelectStmt(Box::new(stmt))) @@ -553,9 +552,9 @@ fn create_stmts(stmt: &UpdateStmt, new_value: &ResTarget) -> Result Result Result ParameterWithFormat<'a> { } pub fn parameter(&self) -> &Parameter { - &self.parameter + self.parameter } }