From 43c3b6a0a0942aef6b44b91764a9673123d54fcf Mon Sep 17 00:00:00 2001 From: dowzhong Date: Sat, 9 May 2026 01:26:18 +1000 Subject: [PATCH 1/8] feat: reprepare and retry prepared statements --- pgdog/src/backend/mod.rs | 1 + pgdog/src/backend/prepared_statements.rs | 12 + pgdog/src/backend/protocol/state.rs | 9 + pgdog/src/backend/server.rs | 320 ++++++++++++++++++++++- pgdog/src/backend/server_state.rs | 156 +++++++++++ pgdog/src/net/messages/describe.rs | 2 +- 6 files changed, 497 insertions(+), 3 deletions(-) create mode 100644 pgdog/src/backend/server_state.rs diff --git a/pgdog/src/backend/mod.rs b/pgdog/src/backend/mod.rs index 0ced5523c..617431388 100644 --- a/pgdog/src/backend/mod.rs +++ b/pgdog/src/backend/mod.rs @@ -15,6 +15,7 @@ pub mod replication; pub mod schema; pub mod server; pub mod server_options; +pub mod server_state; pub mod stats; pub use connect_reason::ConnectReason; diff --git a/pgdog/src/backend/prepared_statements.rs b/pgdog/src/backend/prepared_statements.rs index feaf0311f..30a5e27b2 100644 --- a/pgdog/src/backend/prepared_statements.rs +++ b/pgdog/src/backend/prepared_statements.rs @@ -103,6 +103,18 @@ impl PreparedStatements { self.state.add_ignore('1'); Ok(()) } + ProtocolMessage::Sync(_) => { + self.state.add_ignore('Z'); + Ok(()) + } + ProtocolMessage::Close(_) => { + self.state.add_ignore('3'); + Ok(()) + } + ProtocolMessage::Bind(_) => { + self.state.add_ignore('2'); + Ok(()) + } _ => Err(Error::UnsupportedHandleIgnore(request.code())), } } diff --git a/pgdog/src/backend/protocol/state.rs b/pgdog/src/backend/protocol/state.rs index a93b89a99..10414f46b 100644 --- a/pgdog/src/backend/protocol/state.rs +++ b/pgdog/src/backend/protocol/state.rs @@ -129,6 +129,15 @@ impl ProtocolState { self.simulated.push_back(message); } + /// Clear the state queue. Can be used for when a + /// statement fails and we have to retry it: + /// 1. Clear the state queue + /// 2. Fix up the statement (eg deallocate + reprepare) + /// 3. Run the statement again, building up the state again + pub(crate) fn clear(&mut self) { + self.queue.clear(); + } + /// Get a simulated message from the execution queue. /// /// Returns a message only if it should be returned at the current state diff --git a/pgdog/src/backend/server.rs b/pgdog/src/backend/server.rs index 02372cd18..fc802281a 100644 --- a/pgdog/src/backend/server.rs +++ b/pgdog/src/backend/server.rs @@ -18,7 +18,7 @@ use super::{ }; use crate::{ auth::{md5, scram::Client}, - backend::pool::stats::MemoryStats, + backend::{pool::stats::MemoryStats, server_state::ServerState}, config::AuthType, frontend::ClientRequest, net::{ @@ -26,7 +26,7 @@ use crate::{ hello::SslReply, Authentication, BackendKeyData, ErrorResponse, FromBytes, Message, ParameterStatus, Password, Protocol, Query, ReadyForQuery, Startup, Terminate, ToBytes, }, - Close, MessageBuffer, Parameter, ProtocolMessage, Sync, + Close, Flush, MessageBuffer, Parameter, ProtocolMessage, Sync, }, stats::memory::MemoryUsage, }; @@ -70,6 +70,7 @@ pub struct Server { /// "use the pool's configured `max_age`" (no jitter applied). /// Sampled once at creation by [`Server::apply_lifetime_jitter`]. max_age: Option, + state: ServerState, } impl MemoryUsage for Server { @@ -348,6 +349,7 @@ impl Server { disconnect_reason: None, password_attempts: 1, // This is going to be changed by parent caller. max_age: None, + state: ServerState::new(), }; server.stats.memory_used(server.memory_stats()); // Stream capacity. @@ -374,6 +376,8 @@ impl Server { self.stats.state(State::Active); + self.state.set_state(&client_request.messages); + for message in client_request.messages.iter() { self.send_one(message).await?; } @@ -463,14 +467,76 @@ impl Server { pub async fn read(&mut self) -> Result { let message = loop { if let Some(message) = self.prepared_statements.state_mut().get_simulated() { + self.state.process(message.code()); return Ok(message.backend(self.id)); } match self.stream_buffer.read(self.stream.as_mut().unwrap()).await { Ok(message) => { let message = message.stream(self.streaming).backend(self.id); + let code = message.code(); + + if code == 'E' { + // In the case of an "cached plan must not change result type" error, + // we want to re-prepare (deallocate, parse) and then retry this request + // while keeping the internal state. + let error = ErrorResponse::from_bytes(message.to_bytes()?)?; + if error.code == "0A000" + && error.message == "cached plan must not change result type" + { + let current_request_messages = self.state.get_current_query().to_vec(); + + let binds: Vec<&ProtocolMessage> = current_request_messages + .iter() + .filter(|msg| matches!(msg, ProtocolMessage::Bind(_))) + .collect(); + + let executes: usize = current_request_messages + .iter() + .filter(|msg| matches!(msg, ProtocolMessage::Execute(_))) + .count(); + + if binds.len() == 1 && executes == 1 { + let ProtocolMessage::Bind(bind) = binds[0] else { + unreachable!() + }; + + let was_in_cache = + self.prepared_statements.remove(bind.statement()); + + let renamed = self.prepared_statements.parse(bind.statement()); + + match (was_in_cache, renamed) { + (true, Some(renamed)) => { + self.prepared_statements.state_mut().clear(); + + self.send_ignore(&ProtocolMessage::Sync(Sync)).await?; + self.send_ignore(&ProtocolMessage::Close(Close::named( + renamed.name(), + ))) + .await?; + + for message in current_request_messages + [self.state.get_successfully_processed()..] + .iter() + { + self.send_one(&message).await?; + } + self.send_one(&Flush.into()).await?; + self.flush().await?; + continue; + } + _ => { + warn!("Tried to reprepare statement, but it was not available in cache"); + } + } + } + } + } + match self.prepared_statements.forward(&message) { Ok(forward) => { if forward { + self.state.process(code); break message; } } @@ -1181,6 +1247,7 @@ pub mod test { net::TcpListener, }; + use crate::frontend::PreparedStatements as FrontendPreparedStatements; use crate::{config::Memory, frontend::PreparedStatements, net::*}; use super::{Error, *}; @@ -1232,6 +1299,7 @@ pub mod test { sending_request: false, password_attempts: 1, max_age: None, + state: ServerState::new(), } } } @@ -1255,6 +1323,13 @@ pub mod test { .unwrap() } + /// Insert a prepared statement into the global cache so check_prepared can find it. + fn insert_global(name: &str, query: &str) -> String { + let parse = Parse::named(name, query); + let (_, rewritten_name) = FrontendPreparedStatements::global().write().insert(&parse); + rewritten_name + } + pub async fn test_replication_server() -> Server { Server::connect( &Address::new_test(), @@ -2766,6 +2841,247 @@ pub mod test { ); } + #[tokio::test] + async fn test_retry_prepared() { + let mut server = test_server().await; + + server + .execute( + "DROP TABLE IF EXISTS retry_prepared_1; CREATE TABLE IF NOT EXISTS retry_prepared_1 ( + id INTEGER, + value1 INTEGER + );", + ) + .await + .unwrap(); + + let name = "test".to_string(); + let parse = Parse::named(&name, "SELECT * FROM retry_prepared_1"); + + let name = insert_global("test", "SELECT * FROM retry_prepared_1"); + + server + .send( + &vec![ + ProtocolMessage::from(parse.clone()), + ProtocolMessage::from(Describe::new_statement(&name)), + Bind::new_statement("test").into(), + Execute::new().into(), + Sync.into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['1', 't', 'T', '2', 'C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + + server + .send( + &vec![ProtocolMessage::from(Query::new( + "ALTER TABLE retry_prepared_1 ADD new_col_1 INTEGER", + ))] + .into(), + ) + .await + .unwrap(); + + for c in ['C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(c, msg.code()); + } + + server + .send(&vec![Bind::new_statement(&name).into(), Execute::new().into()].into()) + .await + .unwrap(); + + for c in ['2', 'C'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + server.send(&vec![Sync.into()].into()).await.unwrap(); + + for c in ['Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + } + + #[tokio::test] + async fn test_retry_prepared_with_parse() { + let mut server = test_server().await; + + server + .execute( + "DROP TABLE IF EXISTS retry_prepared_1; CREATE TABLE IF NOT EXISTS retry_prepared_1 ( + id INTEGER, + value1 INTEGER + );", + ) + .await + .unwrap(); + + let name = insert_global("test", "SELECT * FROM retry_prepared_1"); + + let parse = Parse::named(&name, "SELECT * FROM retry_prepared_1"); + + server + .send( + &vec![ + ProtocolMessage::from(parse.clone()), + ProtocolMessage::from(Describe::new_statement(&name)), + Bind::new_statement(&name).into(), + Execute::new().into(), + Sync.into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['1', 't', 'T', '2', 'C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + + server + .send( + &vec![ProtocolMessage::from(Query::new( + "ALTER TABLE retry_prepared_1 ADD new_col_1 INTEGER", + ))] + .into(), + ) + .await + .unwrap(); + + for c in ['C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(c, msg.code()); + } + + assert!(server.done()); + + server + .send( + &vec![ + ProtocolMessage::from(parse.clone()), + Bind::new_statement(&name).into(), + Execute::new().into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['1', '2', 'C'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + server.send(&vec![Sync.into()].into()).await.unwrap(); + + for c in ['Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + } + + #[tokio::test] + async fn test_retry_prepared_with_parse_describe() { + let mut server = test_server().await; + + server + .execute( + "DROP TABLE IF EXISTS retry_prepared_1; CREATE TABLE IF NOT EXISTS retry_prepared_1 ( + id INTEGER, + value1 INTEGER + );", + ) + .await + .unwrap(); + + let name = insert_global("test", "SELECT * FROM retry_prepared_1"); + + let parse = Parse::named(&name, "SELECT * FROM retry_prepared_1"); + + server + .send( + &vec![ + ProtocolMessage::from(parse.clone()), + ProtocolMessage::from(Describe::new_statement(&name)), + Bind::new_statement(&name).into(), + Execute::new().into(), + Sync.into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['1', 't', 'T', '2', 'C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + + server + .send( + &vec![ProtocolMessage::from(Query::new( + "ALTER TABLE retry_prepared_1 ADD new_col_1 INTEGER", + ))] + .into(), + ) + .await + .unwrap(); + + for c in ['C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(c, msg.code()); + } + + assert!(server.done()); + + server + .send( + &vec![ + ProtocolMessage::from(Describe::new_statement(&name)), + Bind::new_statement(&name).into(), + Execute::new().into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['t', 't', 'T', '2', 'C'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + server.send(&vec![Sync.into()].into()).await.unwrap(); + + for c in ['Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + } + #[tokio::test] async fn test_prepare_forces_sync_prepared_flag() { let mut server = test_server().await; diff --git a/pgdog/src/backend/server_state.rs b/pgdog/src/backend/server_state.rs new file mode 100644 index 000000000..987ffa143 --- /dev/null +++ b/pgdog/src/backend/server_state.rs @@ -0,0 +1,156 @@ +use crate::net::{Protocol, ProtocolMessage}; + +pub trait StateTransition { + fn process(self, message_code: char) -> impl StateTransition; +} + +#[derive(Debug)] +pub enum State { + RunningParse, + RunningBind, + RunningDescribe, + RunningExecute, + RunningClose, + RunningFlush, + RunningSync, + RunningQuery, + RunningCopy, + RunningCopyDone, + RunningCopyFail, +} + +impl From<&ProtocolMessage> for State { + fn from(value: &ProtocolMessage) -> Self { + match value { + ProtocolMessage::Parse(_) => State::RunningParse, + ProtocolMessage::Bind(_) => State::RunningBind, + ProtocolMessage::Describe(_) => State::RunningDescribe, + ProtocolMessage::Execute(_) => State::RunningExecute, + ProtocolMessage::Close(_) => State::RunningClose, + ProtocolMessage::Sync(_) => State::RunningSync, + ProtocolMessage::Query(_) => State::RunningQuery, + ProtocolMessage::Other(message) => match message.code() { + 'H' => State::RunningFlush, + _ => panic!("Unexpected other type {:?}", value.code()), + }, + ProtocolMessage::CopyData(_) => State::RunningCopy, + ProtocolMessage::CopyFail(_) => State::RunningCopyFail, + ProtocolMessage::CopyDone(_) => State::RunningCopyDone, + _ => panic!("Unexpected type {:?}", value.code()), + } + } +} + +#[derive(Debug)] +pub struct ServerState { + states: Vec, + active_state_index: usize, + current_request_messages: Vec, +} + +impl ServerState { + pub fn new() -> Self { + ServerState { + states: Vec::new(), + active_state_index: 0, + current_request_messages: Vec::new(), + } + } + + pub fn set_state(&mut self, messages: &[ProtocolMessage]) { + self.states = messages.iter().map(|msg| msg.into()).collect(); + self.current_request_messages = messages.to_vec(); + + self.active_state_index = 0; + } + + pub fn process(&mut self, message_code: char) { + if message_code == 'E' { + // Clear state when we get error + self.set_state(&[]); + } + + if self.active_state_index + 1 > self.states.len() { + return; + } + + let current_state = &self.states[self.active_state_index]; + match current_state { + State::RunningParse => { + if message_code == '1' { + self.active_state_index += 1; + return; + } + panic!("Received unexpected message {}", message_code) + } + State::RunningBind => { + if message_code == '2' { + self.active_state_index += 1; + return; + } + panic!("Received unexpected message {}", message_code) + } + State::RunningClose => { + if message_code == '3' { + self.active_state_index += 1; + return; + } + panic!("Received unexpected message {}", message_code) + } + State::RunningDescribe => match message_code { + 'T' | 'n' => self.active_state_index += 1, + 't' => (), + _ => panic!("Received unexpected message {}", message_code), + }, + State::RunningExecute => match message_code { + 'C' | 'I' | 's' => self.active_state_index += 1, + 'D' | 'N' => (), + _ => panic!("Received unexpected message {}", message_code), + }, + State::RunningFlush => (), + State::RunningSync => { + if message_code == 'Z' { + self.active_state_index += 1; + return; + } + panic!("Received unexpected message {}", message_code) + } + State::RunningQuery => { + if message_code == 'Z' { + self.active_state_index += 1; + return; + } + // TODO: panic on unexpected + } + State::RunningCopy => { + if message_code == 'C' { + self.active_state_index += 1; + return; + } + if message_code == 'E' { + return; + } + } + State::RunningCopyFail => { + // Backend will respond with an E, + // so nothing to do. + return; + } + State::RunningCopyDone => { + if message_code == 'Z' { + self.active_state_index += 1; + return; + } + // TODO: panic on unexpected + } + } + } + + pub fn get_successfully_processed(&self) -> usize { + self.active_state_index + } + + pub fn get_current_query(&self) -> &[ProtocolMessage] { + &self.current_request_messages + } +} diff --git a/pgdog/src/net/messages/describe.rs b/pgdog/src/net/messages/describe.rs index 79d5142cc..064d1bc87 100644 --- a/pgdog/src/net/messages/describe.rs +++ b/pgdog/src/net/messages/describe.rs @@ -132,7 +132,7 @@ mod test { let pool = pool(); let mut conn = pool.get(&Request::default()).await.unwrap(); let describe = Describe::new_portal(""); - conn.send(&vec![ProtocolMessage::from(describe.message().unwrap())].into()) + conn.send(&vec![ProtocolMessage::from(describe)].into()) .await .unwrap(); let res = conn.read().await.unwrap(); From a4dcb3ad2fba8a103b38b9860895391adb25c183 Mon Sep 17 00:00:00 2001 From: dowzhong Date: Sat, 9 May 2026 12:20:12 +1000 Subject: [PATCH 2/8] fix: handle function message --- pgdog/src/backend/server_state.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pgdog/src/backend/server_state.rs b/pgdog/src/backend/server_state.rs index 987ffa143..7f5c52446 100644 --- a/pgdog/src/backend/server_state.rs +++ b/pgdog/src/backend/server_state.rs @@ -17,6 +17,7 @@ pub enum State { RunningCopy, RunningCopyDone, RunningCopyFail, + RunningFunction, } impl From<&ProtocolMessage> for State { @@ -31,6 +32,7 @@ impl From<&ProtocolMessage> for State { ProtocolMessage::Query(_) => State::RunningQuery, ProtocolMessage::Other(message) => match message.code() { 'H' => State::RunningFlush, + 'F' => State::RunningFunction, _ => panic!("Unexpected other type {:?}", value.code()), }, ProtocolMessage::CopyData(_) => State::RunningCopy, @@ -143,6 +145,13 @@ impl ServerState { } // TODO: panic on unexpected } + State::RunningFunction => { + if message_code == 'Z' { + self.active_state_index += 1; + return; + } + // TODO: panic on unexpected + } } } From 51c14c91c5996ee05435da86d3cd3d216a4be7b6 Mon Sep 17 00:00:00 2001 From: dowzhong Date: Sat, 9 May 2026 12:49:04 +1000 Subject: [PATCH 3/8] fix: handle prepare --- pgdog/src/backend/server_state.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pgdog/src/backend/server_state.rs b/pgdog/src/backend/server_state.rs index 7f5c52446..7b4c94797 100644 --- a/pgdog/src/backend/server_state.rs +++ b/pgdog/src/backend/server_state.rs @@ -18,12 +18,14 @@ pub enum State { RunningCopyDone, RunningCopyFail, RunningFunction, + RunningPrepare, } impl From<&ProtocolMessage> for State { fn from(value: &ProtocolMessage) -> Self { match value { ProtocolMessage::Parse(_) => State::RunningParse, + ProtocolMessage::Prepare { .. } => State::RunningPrepare, ProtocolMessage::Bind(_) => State::RunningBind, ProtocolMessage::Describe(_) => State::RunningDescribe, ProtocolMessage::Execute(_) => State::RunningExecute, @@ -152,6 +154,13 @@ impl ServerState { } // TODO: panic on unexpected } + State::RunningPrepare => { + if message_code == 'Z' { + self.active_state_index += 1; + return; + } + // TODO: panic on unexpected + } } } From 9798f3585d83de1d38e3ecf96514c2755b35b9bb Mon Sep 17 00:00:00 2001 From: dowzhong Date: Sat, 9 May 2026 13:02:07 +1000 Subject: [PATCH 4/8] chore: rename function to fastpath + refactor if to match --- pgdog/src/backend/server_state.rs | 70 +++++++++++++------------------ 1 file changed, 30 insertions(+), 40 deletions(-) diff --git a/pgdog/src/backend/server_state.rs b/pgdog/src/backend/server_state.rs index 7b4c94797..5b0b38a59 100644 --- a/pgdog/src/backend/server_state.rs +++ b/pgdog/src/backend/server_state.rs @@ -17,7 +17,7 @@ pub enum State { RunningCopy, RunningCopyDone, RunningCopyFail, - RunningFunction, + RunningFastpath, RunningPrepare, } @@ -34,13 +34,12 @@ impl From<&ProtocolMessage> for State { ProtocolMessage::Query(_) => State::RunningQuery, ProtocolMessage::Other(message) => match message.code() { 'H' => State::RunningFlush, - 'F' => State::RunningFunction, _ => panic!("Unexpected other type {:?}", value.code()), }, ProtocolMessage::CopyData(_) => State::RunningCopy, ProtocolMessage::CopyFail(_) => State::RunningCopyFail, ProtocolMessage::CopyDone(_) => State::RunningCopyDone, - _ => panic!("Unexpected type {:?}", value.code()), + ProtocolMessage::Fastpath(_) => State::RunningFastpath, } } } @@ -112,54 +111,45 @@ impl ServerState { _ => panic!("Received unexpected message {}", message_code), }, State::RunningFlush => (), - State::RunningSync => { - if message_code == 'Z' { - self.active_state_index += 1; - return; - } - panic!("Received unexpected message {}", message_code) - } - State::RunningQuery => { - if message_code == 'Z' { - self.active_state_index += 1; - return; - } - // TODO: panic on unexpected - } - State::RunningCopy => { - if message_code == 'C' { - self.active_state_index += 1; - return; - } - if message_code == 'E' { - return; - } - } + State::RunningSync => match message_code { + 'Z' => self.active_state_index += 1, + _ => panic!("Received unexpected message {}", message_code), + }, + State::RunningQuery => match message_code { + 'Z' => self.active_state_index += 1, + _ => panic!("Received unexpected message {}", message_code), + }, + State::RunningCopy => match message_code { + 'C' => self.active_state_index += 1, + 'E' => _, + }, State::RunningCopyFail => { // Backend will respond with an E, // so nothing to do. - return; } State::RunningCopyDone => { - if message_code == 'Z' { - self.active_state_index += 1; - return; + match message_code { + 'Z' => self.active_state_index += 1, + _ => { + // TODO: panic on unexpected + } } - // TODO: panic on unexpected } - State::RunningFunction => { - if message_code == 'Z' { - self.active_state_index += 1; - return; + State::RunningFastpath => { + match message_code { + 'Z' => self.active_state_index += 1, + _ => { + // TODO: panic on unexpected + } } - // TODO: panic on unexpected } State::RunningPrepare => { - if message_code == 'Z' { - self.active_state_index += 1; - return; + match message_code { + 'Z' => self.active_state_index += 1, + _ => { + // TODO: panic on unexpected + } } - // TODO: panic on unexpected } } } From 26d7ae9a9624fa5825e03c7e25940fe199103db5 Mon Sep 17 00:00:00 2001 From: dowzhong Date: Sat, 9 May 2026 13:17:15 +1000 Subject: [PATCH 5/8] fix: oops... --- pgdog/src/backend/server_state.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgdog/src/backend/server_state.rs b/pgdog/src/backend/server_state.rs index 5b0b38a59..fce732905 100644 --- a/pgdog/src/backend/server_state.rs +++ b/pgdog/src/backend/server_state.rs @@ -121,7 +121,7 @@ impl ServerState { }, State::RunningCopy => match message_code { 'C' => self.active_state_index += 1, - 'E' => _, + 'E' => (), }, State::RunningCopyFail => { // Backend will respond with an E, From d5ba779680756e2e1860798ebf3e1ac796efb0b5 Mon Sep 17 00:00:00 2001 From: dowzhong Date: Sat, 9 May 2026 13:19:33 +1000 Subject: [PATCH 6/8] fix: exhaustive match --- pgdog/src/backend/server_state.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgdog/src/backend/server_state.rs b/pgdog/src/backend/server_state.rs index fce732905..f82a8398c 100644 --- a/pgdog/src/backend/server_state.rs +++ b/pgdog/src/backend/server_state.rs @@ -121,7 +121,7 @@ impl ServerState { }, State::RunningCopy => match message_code { 'C' => self.active_state_index += 1, - 'E' => (), + _ => (), }, State::RunningCopyFail => { // Backend will respond with an E, From c152b2ef1ee482a3c0de0bf74bf590ccced6c989 Mon Sep 17 00:00:00 2001 From: dowzhong Date: Sat, 9 May 2026 13:26:46 +1000 Subject: [PATCH 7/8] fix: do not panic on non Z for running query --- pgdog/src/backend/server_state.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgdog/src/backend/server_state.rs b/pgdog/src/backend/server_state.rs index f82a8398c..fdb918e6a 100644 --- a/pgdog/src/backend/server_state.rs +++ b/pgdog/src/backend/server_state.rs @@ -117,7 +117,7 @@ impl ServerState { }, State::RunningQuery => match message_code { 'Z' => self.active_state_index += 1, - _ => panic!("Received unexpected message {}", message_code), + _ => (), }, State::RunningCopy => match message_code { 'C' => self.active_state_index += 1, From 3827ee6884557b5411d520045ca866f646c0f4d0 Mon Sep 17 00:00:00 2001 From: dowzhong Date: Mon, 11 May 2026 18:24:02 +1000 Subject: [PATCH 8/8] chore: refactor and add some tests --- pgdog/src/backend/server.rs | 4 +- pgdog/src/backend/server_state.rs | 523 +++++++++++++++++++++++++++--- 2 files changed, 475 insertions(+), 52 deletions(-) diff --git a/pgdog/src/backend/server.rs b/pgdog/src/backend/server.rs index fc802281a..2b8dd79e2 100644 --- a/pgdog/src/backend/server.rs +++ b/pgdog/src/backend/server.rs @@ -467,7 +467,7 @@ impl Server { pub async fn read(&mut self) -> Result { let message = loop { if let Some(message) = self.prepared_statements.state_mut().get_simulated() { - self.state.process(message.code()); + self.state.process(message.code())?; return Ok(message.backend(self.id)); } match self.stream_buffer.read(self.stream.as_mut().unwrap()).await { @@ -536,7 +536,7 @@ impl Server { match self.prepared_statements.forward(&message) { Ok(forward) => { if forward { - self.state.process(code); + self.state.process(code)?; break message; } } diff --git a/pgdog/src/backend/server_state.rs b/pgdog/src/backend/server_state.rs index fdb918e6a..d60c22835 100644 --- a/pgdog/src/backend/server_state.rs +++ b/pgdog/src/backend/server_state.rs @@ -1,3 +1,4 @@ +use super::Error; use crate::net::{Protocol, ProtocolMessage}; pub trait StateTransition { @@ -67,53 +68,44 @@ impl ServerState { self.active_state_index = 0; } - pub fn process(&mut self, message_code: char) { + pub fn process(&mut self, message_code: char) -> Result<(), Error> { if message_code == 'E' { // Clear state when we get error self.set_state(&[]); } if self.active_state_index + 1 > self.states.len() { - return; + return Ok(()); } let current_state = &self.states[self.active_state_index]; match current_state { - State::RunningParse => { - if message_code == '1' { - self.active_state_index += 1; - return; - } - panic!("Received unexpected message {}", message_code) - } - State::RunningBind => { - if message_code == '2' { - self.active_state_index += 1; - return; - } - panic!("Received unexpected message {}", message_code) - } - State::RunningClose => { - if message_code == '3' { - self.active_state_index += 1; - return; - } - panic!("Received unexpected message {}", message_code) - } + State::RunningParse => match message_code { + '1' => self.active_state_index += 1, + _ => return Err(Error::UnexpectedMessage(message_code)), + }, + State::RunningBind => match message_code { + '2' => self.active_state_index += 1, + _ => return Err(Error::UnexpectedMessage(message_code)), + }, + State::RunningClose => match message_code { + '3' => self.active_state_index += 1, + _ => return Err(Error::UnexpectedMessage(message_code)), + }, State::RunningDescribe => match message_code { 'T' | 'n' => self.active_state_index += 1, 't' => (), - _ => panic!("Received unexpected message {}", message_code), + _ => return Err(Error::UnexpectedMessage(message_code)), }, State::RunningExecute => match message_code { 'C' | 'I' | 's' => self.active_state_index += 1, 'D' | 'N' => (), - _ => panic!("Received unexpected message {}", message_code), + _ => return Err(Error::UnexpectedMessage(message_code)), }, State::RunningFlush => (), State::RunningSync => match message_code { 'Z' => self.active_state_index += 1, - _ => panic!("Received unexpected message {}", message_code), + _ => return Err(Error::UnexpectedMessage(message_code)), }, State::RunningQuery => match message_code { 'Z' => self.active_state_index += 1, @@ -127,31 +119,21 @@ impl ServerState { // Backend will respond with an E, // so nothing to do. } - State::RunningCopyDone => { - match message_code { - 'Z' => self.active_state_index += 1, - _ => { - // TODO: panic on unexpected - } - } - } - State::RunningFastpath => { - match message_code { - 'Z' => self.active_state_index += 1, - _ => { - // TODO: panic on unexpected - } - } - } - State::RunningPrepare => { - match message_code { - 'Z' => self.active_state_index += 1, - _ => { - // TODO: panic on unexpected - } - } - } + State::RunningCopyDone => match message_code { + 'Z' => self.active_state_index += 1, + _ => (), + }, + State::RunningFastpath => match message_code { + 'Z' => self.active_state_index += 1, + _ => (), + }, + State::RunningPrepare => match message_code { + 'Z' => self.active_state_index += 1, + _ => (), + }, } + + Ok(()) } pub fn get_successfully_processed(&self) -> usize { @@ -162,3 +144,444 @@ impl ServerState { &self.current_request_messages } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::net::messages::{ + Bind, Close, CopyData, CopyDone, CopyFail, Describe, Execute, Fastpath, Flush, FromBytes, + Parse, Query, Sync, + }; + use bytes::{BufMut, BytesMut}; + + fn make_fastpath() -> ProtocolMessage { + let mut buf = BytesMut::new(); + buf.put_u8(b'F'); + buf.put_i32(4); + ProtocolMessage::Fastpath( + Fastpath::from_bytes(buf.freeze()).expect("fastpath shouldn't fail"), + ) + } + + #[test] + fn new_state_starts_empty() { + let state = ServerState::new(); + assert_eq!(state.get_successfully_processed(), 0); + assert!(state.get_current_query().is_empty()); + } + + #[test] + fn set_state_resets_index() { + let mut state = ServerState::new(); + let messages = vec![ + ProtocolMessage::Query(Query::new("SELECT 1")), + ProtocolMessage::Sync(Sync::new()), + ]; + state.set_state(&messages); + // Advance past the query + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + + // Reset with new messages + state.set_state(&[ProtocolMessage::Query(Query::new("SELECT 2"))]); + assert_eq!(state.get_successfully_processed(), 0); + } + + #[test] + fn set_state_stores_messages() { + let mut state = ServerState::new(); + let messages = vec![ + ProtocolMessage::Parse(Parse::new_anonymous("SELECT 1")), + ProtocolMessage::Bind(Bind::default()), + ProtocolMessage::Execute(Execute::new()), + ProtocolMessage::Sync(Sync::new()), + ]; + state.set_state(&messages); + assert_eq!(state.get_current_query().len(), 4); + } + + #[test] + fn state_from_parse() { + let msg = ProtocolMessage::Parse(Parse::new_anonymous("SELECT 1")); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningParse)); + } + + #[test] + fn state_from_bind() { + let msg = ProtocolMessage::Bind(Bind::default()); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningBind)); + } + + #[test] + fn state_from_describe() { + let msg = ProtocolMessage::Describe(Describe::new_statement("")); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningDescribe)); + } + + #[test] + fn state_from_execute() { + let msg = ProtocolMessage::Execute(Execute::new()); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningExecute)); + } + + #[test] + fn state_from_close() { + let msg = ProtocolMessage::Close(Close::named("stmt")); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningClose)); + } + + #[test] + fn state_from_sync() { + let msg = ProtocolMessage::Sync(Sync::new()); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningSync)); + } + + #[test] + fn state_from_query() { + let msg = ProtocolMessage::Query(Query::new("SELECT 1")); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningQuery)); + } + + #[test] + fn state_from_copy_data() { + let msg = ProtocolMessage::CopyData(CopyData::new(b"row data")); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningCopy)); + } + + #[test] + fn state_from_copy_fail() { + let msg = ProtocolMessage::CopyFail(CopyFail::new("error")); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningCopyFail)); + } + + #[test] + fn state_from_copy_done() { + let msg = ProtocolMessage::CopyDone(CopyDone); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningCopyDone)); + } + + #[test] + fn state_from_flush() { + let msg = ProtocolMessage::Other(Flush.message().unwrap()); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningFlush)); + } + + #[test] + fn state_from_fastpath() { + let msg = make_fastpath(); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningFastpath)); + } + + #[test] + fn state_from_prepare() { + let msg = ProtocolMessage::Prepare { + name: "stmt1".to_string(), + statement: "SELECT 1".to_string(), + }; + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningPrepare)); + } + + #[test] + fn simple_query_advances_on_ready_for_query() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Query(Query::new("SELECT 1"))]); + + // Data rows and other messages don't advance + state.process('D').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + state.process('T').unwrap(); // RowDescription + assert_eq!(state.get_successfully_processed(), 0); + state.process('C').unwrap(); // CommandComplete + assert_eq!(state.get_successfully_processed(), 0); + + // ReadyForQuery advances + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn extended_query_full_cycle() { + let mut state = ServerState::new(); + state.set_state(&[ + ProtocolMessage::Parse(Parse::new_anonymous("SELECT $1")), + ProtocolMessage::Bind(Bind::default()), + ProtocolMessage::Describe(Describe::new_statement("")), + ProtocolMessage::Execute(Execute::new()), + ProtocolMessage::Sync(Sync::new()), + ]); + + // ParseComplete + state.process('1').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + + // BindComplete + state.process('2').unwrap(); + assert_eq!(state.get_successfully_processed(), 2); + + // RowDescription (from Describe) + state.process('T').unwrap(); + assert_eq!(state.get_successfully_processed(), 3); + + // DataRow (from Execute) - doesn't advance + state.process('D').unwrap(); + assert_eq!(state.get_successfully_processed(), 3); + + // CommandComplete (from Execute) - advances + state.process('C').unwrap(); + assert_eq!(state.get_successfully_processed(), 4); + + // ReadyForQuery (from Sync) + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 5); + } + + #[test] + fn describe_advances_on_no_data() { + let mut state = ServerState::new(); + state.set_state(&[ + ProtocolMessage::Describe(Describe::new_statement("")), + ProtocolMessage::Sync(Sync::new()), + ]); + + // NoData response + state.process('n').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn describe_ignores_parameter_description() { + let mut state = ServerState::new(); + state.set_state(&[ + ProtocolMessage::Describe(Describe::new_statement("")), + ProtocolMessage::Sync(Sync::new()), + ]); + + // ParameterDescription ('t') doesn't advance + state.process('t').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + + // RowDescription advances + state.process('T').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn execute_advances_on_empty_query() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Execute(Execute::new())]); + + // EmptyQueryResponse + state.process('I').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn execute_advances_on_portal_suspended() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Execute(Execute::new())]); + + // PortalSuspended + state.process('s').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn execute_ignores_notice_response() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Execute(Execute::new())]); + + // NoticeResponse doesn't advance + state.process('N').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + + // CommandComplete advances + state.process('C').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn close_advances_on_close_complete() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Close(Close::named("stmt"))]); + + // CloseComplete + state.process('3').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn error_response_clears_state() { + let mut state = ServerState::new(); + state.set_state(&[ + ProtocolMessage::Parse(Parse::new_anonymous("SELECT 1")), + ProtocolMessage::Bind(Bind::default()), + ProtocolMessage::Sync(Sync::new()), + ]); + + // ParseComplete + state.process('1').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + + // ErrorResponse clears everything + state.process('E').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + assert!(state.get_current_query().is_empty()); + } + + #[test] + fn copy_data_advances_on_command_complete() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::CopyData(CopyData::new(b"data"))]); + + // Other messages don't advance + state.process('d').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + + // CommandComplete advances + state.process('C').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn copy_done_advances_on_ready_for_query() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::CopyDone(CopyDone)]); + + state.process('C').unwrap(); // CommandComplete doesn't advance + assert_eq!(state.get_successfully_processed(), 0); + + state.process('Z').unwrap(); // ReadyForQuery advances + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn copy_fail_does_not_advance() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::CopyFail(CopyFail::new("abort"))]); + + // Backend responds with ErrorResponse, but CopyFail state itself doesn't advance + state.process('E').unwrap(); + // Error clears state entirely + assert_eq!(state.get_successfully_processed(), 0); + } + + #[test] + fn flush_never_advances() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Other(Flush.message().unwrap())]); + + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + state.process('T').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + } + + #[test] + fn fastpath_advances_on_ready_for_query() { + let mut state = ServerState::new(); + state.set_state(&[make_fastpath()]); + + state.process('V').unwrap(); // FunctionCallResponse doesn't advance + assert_eq!(state.get_successfully_processed(), 0); + + state.process('Z').unwrap(); // ReadyForQuery advances + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn prepare_advances_on_ready_for_query() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Prepare { + name: "stmt1".to_string(), + statement: "SELECT 1".to_string(), + }]); + + state.process('T').unwrap(); // RowDescription doesn't advance + assert_eq!(state.get_successfully_processed(), 0); + + state.process('Z').unwrap(); // ReadyForQuery advances + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn process_does_nothing_when_all_states_consumed() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Sync(Sync::new())]); + + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + state.process('T').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn process_does_nothing_on_empty_state() { + let mut state = ServerState::new(); + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + } + + #[test] + fn parse_errors_on_unexpected() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Parse(Parse::new_anonymous("SELECT 1"))]); + let result = state.process('Z'); + assert!(matches!(result, Err(Error::UnexpectedMessage(_)))); + } + + #[test] + fn bind_errors_on_unexpected() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Bind(Bind::default())]); + let result = state.process('Z'); + assert!(matches!(result, Err(Error::UnexpectedMessage(_)))); + } + + #[test] + fn close_errors_on_unexpected() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Close(Close::named("s"))]); + let result = state.process('Z'); + assert!(matches!(result, Err(Error::UnexpectedMessage(_)))); + } + + #[test] + fn describe_errors_on_unexpected() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Describe(Describe::new_statement(""))]); + let result = state.process('Z'); + assert!(matches!(result, Err(Error::UnexpectedMessage(_)))); + } + + #[test] + fn execute_errors_on_unexpected() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Execute(Execute::new())]); + let result = state.process('Z'); + assert!(matches!(result, Err(Error::UnexpectedMessage(_)))); + } + + #[test] + fn sync_errors_on_unexpected() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Sync(Sync::new())]); + let result = state.process('T'); + assert!(matches!(result, Err(Error::UnexpectedMessage(_)))); + } +}