Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pgdog/src/frontend/client/query_engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ impl QueryEngine {
.client_request
.route // Admin commands don't have a route.
.as_mut()
.map(|route| route.take_explain())
.flatten()
.and_then(|route| route.take_explain())
{
if config().config.general.expanded_explain {
self.pending_explain = Some(ExplainResponseState::new(trace));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ impl ForwardCheck {
pub(crate) fn new(request: &ClientRequest) -> Self {
Self {
codes: request.iter().map(|m| m.code()).collect(),
describe: request.iter().find(|m| m.code() == 'D').is_some(),
describe: request.iter().any(|m| m.code() == 'D'),
sent: HashSet::default(),
}
}
Expand Down
132 changes: 119 additions & 13 deletions pgdog/src/frontend/client/query_engine/multi_step/test/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,13 @@ async fn test_transaction_required() {
async fn test_move_rows_simple() {
let mut client = TestClient::new_rewrites(Parameters::default()).await;

let shard_0_id = client.random_id_for_shard(0);
let shard_1_id = client.random_id_for_shard(1);

client
.send_simple(Query::new(format!(
"INSERT INTO sharded (id) VALUES (1) ON CONFLICT(id) DO NOTHING",
"INSERT INTO sharded (id) VALUES ({}) ON CONFLICT(id) DO NOTHING",
shard_0_id
)))
.await;
client.read_until('Z').await.unwrap();
Expand All @@ -234,14 +238,16 @@ async fn test_move_rows_simple() {
client.read_until('Z').await.unwrap();

client
.try_send_simple(Query::new(
"UPDATE sharded SET id = 11 WHERE id = 1 RETURNING id",
))
.try_send_simple(Query::new(format!(
"UPDATE sharded SET id = {} WHERE id = {} RETURNING id",
shard_1_id, shard_0_id
)))
.await
.unwrap();

let reply = client.read_until('Z').await.unwrap();

let shard_1_id_str = shard_1_id.to_string();
reply
.into_iter()
.zip(['T', 'D', 'C', 'Z'])
Expand All @@ -266,7 +272,7 @@ async fn test_move_rows_simple() {
),
'D' => assert_eq!(
DataRow::try_from(message).unwrap().column(0).unwrap(),
"11".as_bytes()
shard_1_id_str.as_bytes()
),
_ => unreachable!(),
}
Expand All @@ -277,9 +283,13 @@ async fn test_move_rows_simple() {
async fn test_move_rows_extended() {
let mut client = TestClient::new_rewrites(Parameters::default()).await;

let shard_0_id = client.random_id_for_shard(0);
let shard_1_id = client.random_id_for_shard(1);

client
.send_simple(Query::new(format!(
"INSERT INTO sharded (id) VALUES (1) ON CONFLICT(id) DO NOTHING",
"INSERT INTO sharded (id) VALUES ({}) ON CONFLICT(id) DO NOTHING",
shard_0_id
)))
.await;
client.read_until('Z').await.unwrap();
Expand All @@ -296,8 +306,8 @@ async fn test_move_rows_extended() {
.send(Bind::new_params(
"",
&[
Parameter::new("1".as_bytes()),
Parameter::new("11".as_bytes()),
Parameter::new(shard_0_id.to_string().as_bytes()),
Parameter::new(shard_1_id.to_string().as_bytes()),
],
))
.await;
Expand All @@ -307,6 +317,7 @@ async fn test_move_rows_extended() {

let reply = client.read_until('Z').await.unwrap();

let shard_1_id_str = shard_1_id.to_string();
reply
.into_iter()
.zip(['1', '2', 'D', 'C', 'Z'])
Expand All @@ -323,7 +334,7 @@ async fn test_move_rows_extended() {
),
'D' => assert_eq!(
DataRow::try_from(message).unwrap().column(0).unwrap(),
"11".as_bytes()
shard_1_id_str.as_bytes()
),
'1' | '2' => (),
_ => unreachable!(),
Expand All @@ -336,9 +347,13 @@ async fn test_move_rows_prepared() {
crate::logger();
let mut client = TestClient::new_rewrites(Parameters::default()).await;

let shard_0_id = client.random_id_for_shard(0);
let shard_1_id = client.random_id_for_shard(1);

client
.send_simple(Query::new(format!(
"INSERT INTO sharded (id) VALUES (1) ON CONFLICT(id) DO NOTHING",
"INSERT INTO sharded (id) VALUES ({}) ON CONFLICT(id) DO NOTHING",
shard_0_id
)))
.await;
client.read_until('Z').await.unwrap();
Expand Down Expand Up @@ -383,8 +398,8 @@ async fn test_move_rows_prepared() {
.send(Bind::new_params(
"__test_1",
&[
Parameter::new("1".as_bytes()),
Parameter::new("11".as_bytes()),
Parameter::new(shard_0_id.to_string().as_bytes()),
Parameter::new(shard_1_id.to_string().as_bytes()),
],
))
.await;
Expand All @@ -394,6 +409,7 @@ async fn test_move_rows_prepared() {

let reply = client.read_until('Z').await.unwrap();

let shard_1_id_str = shard_1_id.to_string();
reply
.into_iter()
.zip(['2', 'D', 'C', 'Z'])
Expand All @@ -410,7 +426,7 @@ async fn test_move_rows_prepared() {
),
'D' => assert_eq!(
DataRow::try_from(message).unwrap().column(0).unwrap(),
"11".as_bytes()
shard_1_id_str.as_bytes()
),
'1' | '2' => (),
_ => unreachable!(),
Expand Down Expand Up @@ -463,3 +479,93 @@ async fn test_same_shard_binary() {
}
});
}

#[tokio::test]
async fn test_update_with_expr() {
// Test that UPDATE with expression columns (not simple values) works correctly.
// This validates the bind parameter alignment fix where expression columns
// don't consume bind parameter slots.
//
// Note: Expressions that reference the original row's columns (like COALESCE(value, 'default'))
// won't work because they're inserted literally into the INSERT statement where those
// columns don't exist. Only standalone expressions like 'prefix' || 'suffix' work.
let mut client = TestClient::new_rewrites(Parameters::default()).await;

// Use random IDs to avoid conflicts with other tests
let shard_0_id = client.random_id_for_shard(0);
let shard_1_id = client.random_id_for_shard(1);

// Insert a row into shard 0
client
.send_simple(Query::new(format!(
"INSERT INTO sharded (id, value) VALUES ({}, 'original') ON CONFLICT(id) DO UPDATE SET value = 'original'",
shard_0_id
)))
.await;
client.read_until('Z').await.unwrap();

client.send_simple(Query::new("BEGIN")).await;
client.read_until('Z').await.unwrap();

// UPDATE that moves row to different shard with an expression column.
// Use a standalone expression that doesn't reference any columns.
client
.try_send_simple(Query::new(format!(
"UPDATE sharded SET id = {}, value = 'prefix' || '_suffix' WHERE id = {} RETURNING id, value",
shard_1_id, shard_0_id
)))
.await
.unwrap();

let reply = client.read_until('Z').await.unwrap();

let shard_1_id_str = shard_1_id.to_string();
reply
.into_iter()
.zip(['T', 'D', 'C', 'Z'])
.for_each(|(message, code)| {
assert_eq!(message.code(), code);
match code {
'C' => assert_eq!(
CommandComplete::try_from(message).unwrap().command(),
"UPDATE 1"
),
'Z' => assert!(
ReadyForQuery::try_from(message).unwrap().state().unwrap()
== TransactionState::InTrasaction
),
'T' => {
let rd = RowDescription::try_from(message).unwrap();
assert_eq!(rd.field(0).unwrap().name, "id");
assert_eq!(rd.field(1).unwrap().name, "value");
}
'D' => {
let dr = DataRow::try_from(message).unwrap();
assert_eq!(dr.column(0).unwrap(), shard_1_id_str.as_bytes());
// The value should be 'prefix_suffix' from the expression
assert_eq!(dr.column(1).unwrap(), "prefix_suffix".as_bytes());
}
_ => unreachable!(),
}
});

client.send_simple(Query::new("COMMIT")).await;
client.read_until('Z').await.unwrap();

// Verify the row was actually moved to the new shard with correct values
client
.send_simple(Query::new(format!(
"SELECT id, value FROM sharded WHERE id = {}",
shard_1_id
)))
.await;
let reply = client.read_until('Z').await.unwrap();

let data_row = reply
.iter()
.find(|m| m.code() == 'D')
.expect("should have data row");
let dr = DataRow::try_from(data_row.clone()).unwrap();
assert_eq!(dr.column(0).unwrap(), shard_1_id_str.as_bytes());
assert_eq!(dr.column(1).unwrap(), "prefix_suffix".as_bytes());
}
18 changes: 9 additions & 9 deletions pgdog/src/frontend/client/query_engine/multi_step/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ impl<'a> UpdateMulti<'a> {
self.engine
.error_response(context, ErrorResponse::from_err(&err))
.await?;
return Ok(());
Ok(())
} else {
// These are bad, disconnecting the client.
return Err(err.into());
Err(err)
}
}
}
Expand All @@ -57,7 +57,7 @@ impl<'a> UpdateMulti<'a> {
&mut self,
context: &mut QueryEngineContext<'_>,
) -> Result<(), Error> {
let mut check = self.rewrite.check.build_request(&context.client_request)?;
let mut check = self.rewrite.check.build_request(context.client_request)?;
self.route(&mut check, context)?;

// The new row is on the same shard as the old row
Expand Down Expand Up @@ -100,7 +100,7 @@ impl<'a> UpdateMulti<'a> {
row: Row,
) -> Result<(), Error> {
let mut request = self.rewrite.insert.build_request(
&context.client_request,
context.client_request,
&row.row_description,
&row.data_row,
)?;
Expand Down Expand Up @@ -168,7 +168,7 @@ impl<'a> UpdateMulti<'a> {
.handle_client_request(request, &mut Router::default(), false)
.await?;

let mut checker = ForwardCheck::new(&context.client_request);
let mut checker = ForwardCheck::new(context.client_request);

while self.engine.backend.has_more_messages() {
let message = self.engine.read_server_message(context).await?;
Expand All @@ -194,7 +194,7 @@ impl<'a> UpdateMulti<'a> {
self.engine
.backend
.handle_client_request(
&context.client_request,
context.client_request,
&mut self.engine.router,
self.engine.streaming,
)
Expand All @@ -212,7 +212,7 @@ impl<'a> UpdateMulti<'a> {
&mut self,
context: &mut QueryEngineContext<'_>,
) -> Result<(), Error> {
let mut request = self.rewrite.delete.build_request(&context.client_request)?;
let mut request = self.rewrite.delete.build_request(context.client_request)?;
self.route(&mut request, context)?;

self.execute_request_internal(context, &mut request, false)
Expand All @@ -223,7 +223,7 @@ impl<'a> UpdateMulti<'a> {
&mut self,
context: &mut QueryEngineContext<'_>,
) -> Result<Option<Row>, Error> {
let mut request = self.rewrite.select.build_request(&context.client_request)?;
let mut request = self.rewrite.select.build_request(context.client_request)?;
self.route(&mut request, context)?;

self.engine
Expand Down Expand Up @@ -262,7 +262,7 @@ impl<'a> UpdateMulti<'a> {
/// This is an optimization to avoid doing a multi-shard UPDATE when
/// we don't have to.
pub(super) fn is_same_shard(&self, context: &QueryEngineContext<'_>) -> Result<bool, Error> {
let mut check = self.rewrite.check.build_request(&context.client_request)?;
let mut check = self.rewrite.check.build_request(context.client_request)?;
self.route(&mut check, context)?;

let new_shard = check.route().shard();
Expand Down
Loading
Loading