diff --git a/pgdog/src/backend/mod.rs b/pgdog/src/backend/mod.rs index 0ced5523c..617431388 100644 --- a/pgdog/src/backend/mod.rs +++ b/pgdog/src/backend/mod.rs @@ -15,6 +15,7 @@ pub mod replication; pub mod schema; pub mod server; pub mod server_options; +pub mod server_state; pub mod stats; pub use connect_reason::ConnectReason; diff --git a/pgdog/src/backend/prepared_statements.rs b/pgdog/src/backend/prepared_statements.rs index feaf0311f..30a5e27b2 100644 --- a/pgdog/src/backend/prepared_statements.rs +++ b/pgdog/src/backend/prepared_statements.rs @@ -103,6 +103,18 @@ impl PreparedStatements { self.state.add_ignore('1'); Ok(()) } + ProtocolMessage::Sync(_) => { + self.state.add_ignore('Z'); + Ok(()) + } + ProtocolMessage::Close(_) => { + self.state.add_ignore('3'); + Ok(()) + } + ProtocolMessage::Bind(_) => { + self.state.add_ignore('2'); + Ok(()) + } _ => Err(Error::UnsupportedHandleIgnore(request.code())), } } diff --git a/pgdog/src/backend/protocol/state.rs b/pgdog/src/backend/protocol/state.rs index a93b89a99..10414f46b 100644 --- a/pgdog/src/backend/protocol/state.rs +++ b/pgdog/src/backend/protocol/state.rs @@ -129,6 +129,15 @@ impl ProtocolState { self.simulated.push_back(message); } + /// Clear the state queue. Can be used for when a + /// statement fails and we have to retry it: + /// 1. Clear the state queue + /// 2. Fix up the statement (eg deallocate + reprepare) + /// 3. Run the statement again, building up the state again + pub(crate) fn clear(&mut self) { + self.queue.clear(); + } + /// Get a simulated message from the execution queue. /// /// Returns a message only if it should be returned at the current state diff --git a/pgdog/src/backend/server.rs b/pgdog/src/backend/server.rs index 02372cd18..2b8dd79e2 100644 --- a/pgdog/src/backend/server.rs +++ b/pgdog/src/backend/server.rs @@ -18,7 +18,7 @@ use super::{ }; use crate::{ auth::{md5, scram::Client}, - backend::pool::stats::MemoryStats, + backend::{pool::stats::MemoryStats, server_state::ServerState}, config::AuthType, frontend::ClientRequest, net::{ @@ -26,7 +26,7 @@ use crate::{ hello::SslReply, Authentication, BackendKeyData, ErrorResponse, FromBytes, Message, ParameterStatus, Password, Protocol, Query, ReadyForQuery, Startup, Terminate, ToBytes, }, - Close, MessageBuffer, Parameter, ProtocolMessage, Sync, + Close, Flush, MessageBuffer, Parameter, ProtocolMessage, Sync, }, stats::memory::MemoryUsage, }; @@ -70,6 +70,7 @@ pub struct Server { /// "use the pool's configured `max_age`" (no jitter applied). /// Sampled once at creation by [`Server::apply_lifetime_jitter`]. max_age: Option, + state: ServerState, } impl MemoryUsage for Server { @@ -348,6 +349,7 @@ impl Server { disconnect_reason: None, password_attempts: 1, // This is going to be changed by parent caller. max_age: None, + state: ServerState::new(), }; server.stats.memory_used(server.memory_stats()); // Stream capacity. @@ -374,6 +376,8 @@ impl Server { self.stats.state(State::Active); + self.state.set_state(&client_request.messages); + for message in client_request.messages.iter() { self.send_one(message).await?; } @@ -463,14 +467,76 @@ impl Server { pub async fn read(&mut self) -> Result { let message = loop { if let Some(message) = self.prepared_statements.state_mut().get_simulated() { + self.state.process(message.code())?; return Ok(message.backend(self.id)); } match self.stream_buffer.read(self.stream.as_mut().unwrap()).await { Ok(message) => { let message = message.stream(self.streaming).backend(self.id); + let code = message.code(); + + if code == 'E' { + // In the case of an "cached plan must not change result type" error, + // we want to re-prepare (deallocate, parse) and then retry this request + // while keeping the internal state. + let error = ErrorResponse::from_bytes(message.to_bytes()?)?; + if error.code == "0A000" + && error.message == "cached plan must not change result type" + { + let current_request_messages = self.state.get_current_query().to_vec(); + + let binds: Vec<&ProtocolMessage> = current_request_messages + .iter() + .filter(|msg| matches!(msg, ProtocolMessage::Bind(_))) + .collect(); + + let executes: usize = current_request_messages + .iter() + .filter(|msg| matches!(msg, ProtocolMessage::Execute(_))) + .count(); + + if binds.len() == 1 && executes == 1 { + let ProtocolMessage::Bind(bind) = binds[0] else { + unreachable!() + }; + + let was_in_cache = + self.prepared_statements.remove(bind.statement()); + + let renamed = self.prepared_statements.parse(bind.statement()); + + match (was_in_cache, renamed) { + (true, Some(renamed)) => { + self.prepared_statements.state_mut().clear(); + + self.send_ignore(&ProtocolMessage::Sync(Sync)).await?; + self.send_ignore(&ProtocolMessage::Close(Close::named( + renamed.name(), + ))) + .await?; + + for message in current_request_messages + [self.state.get_successfully_processed()..] + .iter() + { + self.send_one(&message).await?; + } + self.send_one(&Flush.into()).await?; + self.flush().await?; + continue; + } + _ => { + warn!("Tried to reprepare statement, but it was not available in cache"); + } + } + } + } + } + match self.prepared_statements.forward(&message) { Ok(forward) => { if forward { + self.state.process(code)?; break message; } } @@ -1181,6 +1247,7 @@ pub mod test { net::TcpListener, }; + use crate::frontend::PreparedStatements as FrontendPreparedStatements; use crate::{config::Memory, frontend::PreparedStatements, net::*}; use super::{Error, *}; @@ -1232,6 +1299,7 @@ pub mod test { sending_request: false, password_attempts: 1, max_age: None, + state: ServerState::new(), } } } @@ -1255,6 +1323,13 @@ pub mod test { .unwrap() } + /// Insert a prepared statement into the global cache so check_prepared can find it. + fn insert_global(name: &str, query: &str) -> String { + let parse = Parse::named(name, query); + let (_, rewritten_name) = FrontendPreparedStatements::global().write().insert(&parse); + rewritten_name + } + pub async fn test_replication_server() -> Server { Server::connect( &Address::new_test(), @@ -2766,6 +2841,247 @@ pub mod test { ); } + #[tokio::test] + async fn test_retry_prepared() { + let mut server = test_server().await; + + server + .execute( + "DROP TABLE IF EXISTS retry_prepared_1; CREATE TABLE IF NOT EXISTS retry_prepared_1 ( + id INTEGER, + value1 INTEGER + );", + ) + .await + .unwrap(); + + let name = "test".to_string(); + let parse = Parse::named(&name, "SELECT * FROM retry_prepared_1"); + + let name = insert_global("test", "SELECT * FROM retry_prepared_1"); + + server + .send( + &vec![ + ProtocolMessage::from(parse.clone()), + ProtocolMessage::from(Describe::new_statement(&name)), + Bind::new_statement("test").into(), + Execute::new().into(), + Sync.into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['1', 't', 'T', '2', 'C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + + server + .send( + &vec![ProtocolMessage::from(Query::new( + "ALTER TABLE retry_prepared_1 ADD new_col_1 INTEGER", + ))] + .into(), + ) + .await + .unwrap(); + + for c in ['C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(c, msg.code()); + } + + server + .send(&vec![Bind::new_statement(&name).into(), Execute::new().into()].into()) + .await + .unwrap(); + + for c in ['2', 'C'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + server.send(&vec![Sync.into()].into()).await.unwrap(); + + for c in ['Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + } + + #[tokio::test] + async fn test_retry_prepared_with_parse() { + let mut server = test_server().await; + + server + .execute( + "DROP TABLE IF EXISTS retry_prepared_1; CREATE TABLE IF NOT EXISTS retry_prepared_1 ( + id INTEGER, + value1 INTEGER + );", + ) + .await + .unwrap(); + + let name = insert_global("test", "SELECT * FROM retry_prepared_1"); + + let parse = Parse::named(&name, "SELECT * FROM retry_prepared_1"); + + server + .send( + &vec![ + ProtocolMessage::from(parse.clone()), + ProtocolMessage::from(Describe::new_statement(&name)), + Bind::new_statement(&name).into(), + Execute::new().into(), + Sync.into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['1', 't', 'T', '2', 'C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + + server + .send( + &vec![ProtocolMessage::from(Query::new( + "ALTER TABLE retry_prepared_1 ADD new_col_1 INTEGER", + ))] + .into(), + ) + .await + .unwrap(); + + for c in ['C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(c, msg.code()); + } + + assert!(server.done()); + + server + .send( + &vec![ + ProtocolMessage::from(parse.clone()), + Bind::new_statement(&name).into(), + Execute::new().into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['1', '2', 'C'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + server.send(&vec![Sync.into()].into()).await.unwrap(); + + for c in ['Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + } + + #[tokio::test] + async fn test_retry_prepared_with_parse_describe() { + let mut server = test_server().await; + + server + .execute( + "DROP TABLE IF EXISTS retry_prepared_1; CREATE TABLE IF NOT EXISTS retry_prepared_1 ( + id INTEGER, + value1 INTEGER + );", + ) + .await + .unwrap(); + + let name = insert_global("test", "SELECT * FROM retry_prepared_1"); + + let parse = Parse::named(&name, "SELECT * FROM retry_prepared_1"); + + server + .send( + &vec![ + ProtocolMessage::from(parse.clone()), + ProtocolMessage::from(Describe::new_statement(&name)), + Bind::new_statement(&name).into(), + Execute::new().into(), + Sync.into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['1', 't', 'T', '2', 'C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + + server + .send( + &vec![ProtocolMessage::from(Query::new( + "ALTER TABLE retry_prepared_1 ADD new_col_1 INTEGER", + ))] + .into(), + ) + .await + .unwrap(); + + for c in ['C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(c, msg.code()); + } + + assert!(server.done()); + + server + .send( + &vec![ + ProtocolMessage::from(Describe::new_statement(&name)), + Bind::new_statement(&name).into(), + Execute::new().into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['t', 't', 'T', '2', 'C'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + server.send(&vec![Sync.into()].into()).await.unwrap(); + + for c in ['Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + } + #[tokio::test] async fn test_prepare_forces_sync_prepared_flag() { let mut server = test_server().await; diff --git a/pgdog/src/backend/server_state.rs b/pgdog/src/backend/server_state.rs new file mode 100644 index 000000000..d60c22835 --- /dev/null +++ b/pgdog/src/backend/server_state.rs @@ -0,0 +1,587 @@ +use super::Error; +use crate::net::{Protocol, ProtocolMessage}; + +pub trait StateTransition { + fn process(self, message_code: char) -> impl StateTransition; +} + +#[derive(Debug)] +pub enum State { + RunningParse, + RunningBind, + RunningDescribe, + RunningExecute, + RunningClose, + RunningFlush, + RunningSync, + RunningQuery, + RunningCopy, + RunningCopyDone, + RunningCopyFail, + RunningFastpath, + RunningPrepare, +} + +impl From<&ProtocolMessage> for State { + fn from(value: &ProtocolMessage) -> Self { + match value { + ProtocolMessage::Parse(_) => State::RunningParse, + ProtocolMessage::Prepare { .. } => State::RunningPrepare, + ProtocolMessage::Bind(_) => State::RunningBind, + ProtocolMessage::Describe(_) => State::RunningDescribe, + ProtocolMessage::Execute(_) => State::RunningExecute, + ProtocolMessage::Close(_) => State::RunningClose, + ProtocolMessage::Sync(_) => State::RunningSync, + ProtocolMessage::Query(_) => State::RunningQuery, + ProtocolMessage::Other(message) => match message.code() { + 'H' => State::RunningFlush, + _ => panic!("Unexpected other type {:?}", value.code()), + }, + ProtocolMessage::CopyData(_) => State::RunningCopy, + ProtocolMessage::CopyFail(_) => State::RunningCopyFail, + ProtocolMessage::CopyDone(_) => State::RunningCopyDone, + ProtocolMessage::Fastpath(_) => State::RunningFastpath, + } + } +} + +#[derive(Debug)] +pub struct ServerState { + states: Vec, + active_state_index: usize, + current_request_messages: Vec, +} + +impl ServerState { + pub fn new() -> Self { + ServerState { + states: Vec::new(), + active_state_index: 0, + current_request_messages: Vec::new(), + } + } + + pub fn set_state(&mut self, messages: &[ProtocolMessage]) { + self.states = messages.iter().map(|msg| msg.into()).collect(); + self.current_request_messages = messages.to_vec(); + + self.active_state_index = 0; + } + + pub fn process(&mut self, message_code: char) -> Result<(), Error> { + if message_code == 'E' { + // Clear state when we get error + self.set_state(&[]); + } + + if self.active_state_index + 1 > self.states.len() { + return Ok(()); + } + + let current_state = &self.states[self.active_state_index]; + match current_state { + State::RunningParse => match message_code { + '1' => self.active_state_index += 1, + _ => return Err(Error::UnexpectedMessage(message_code)), + }, + State::RunningBind => match message_code { + '2' => self.active_state_index += 1, + _ => return Err(Error::UnexpectedMessage(message_code)), + }, + State::RunningClose => match message_code { + '3' => self.active_state_index += 1, + _ => return Err(Error::UnexpectedMessage(message_code)), + }, + State::RunningDescribe => match message_code { + 'T' | 'n' => self.active_state_index += 1, + 't' => (), + _ => return Err(Error::UnexpectedMessage(message_code)), + }, + State::RunningExecute => match message_code { + 'C' | 'I' | 's' => self.active_state_index += 1, + 'D' | 'N' => (), + _ => return Err(Error::UnexpectedMessage(message_code)), + }, + State::RunningFlush => (), + State::RunningSync => match message_code { + 'Z' => self.active_state_index += 1, + _ => return Err(Error::UnexpectedMessage(message_code)), + }, + State::RunningQuery => match message_code { + 'Z' => self.active_state_index += 1, + _ => (), + }, + State::RunningCopy => match message_code { + 'C' => self.active_state_index += 1, + _ => (), + }, + State::RunningCopyFail => { + // Backend will respond with an E, + // so nothing to do. + } + State::RunningCopyDone => match message_code { + 'Z' => self.active_state_index += 1, + _ => (), + }, + State::RunningFastpath => match message_code { + 'Z' => self.active_state_index += 1, + _ => (), + }, + State::RunningPrepare => match message_code { + 'Z' => self.active_state_index += 1, + _ => (), + }, + } + + Ok(()) + } + + pub fn get_successfully_processed(&self) -> usize { + self.active_state_index + } + + pub fn get_current_query(&self) -> &[ProtocolMessage] { + &self.current_request_messages + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::net::messages::{ + Bind, Close, CopyData, CopyDone, CopyFail, Describe, Execute, Fastpath, Flush, FromBytes, + Parse, Query, Sync, + }; + use bytes::{BufMut, BytesMut}; + + fn make_fastpath() -> ProtocolMessage { + let mut buf = BytesMut::new(); + buf.put_u8(b'F'); + buf.put_i32(4); + ProtocolMessage::Fastpath( + Fastpath::from_bytes(buf.freeze()).expect("fastpath shouldn't fail"), + ) + } + + #[test] + fn new_state_starts_empty() { + let state = ServerState::new(); + assert_eq!(state.get_successfully_processed(), 0); + assert!(state.get_current_query().is_empty()); + } + + #[test] + fn set_state_resets_index() { + let mut state = ServerState::new(); + let messages = vec![ + ProtocolMessage::Query(Query::new("SELECT 1")), + ProtocolMessage::Sync(Sync::new()), + ]; + state.set_state(&messages); + // Advance past the query + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + + // Reset with new messages + state.set_state(&[ProtocolMessage::Query(Query::new("SELECT 2"))]); + assert_eq!(state.get_successfully_processed(), 0); + } + + #[test] + fn set_state_stores_messages() { + let mut state = ServerState::new(); + let messages = vec![ + ProtocolMessage::Parse(Parse::new_anonymous("SELECT 1")), + ProtocolMessage::Bind(Bind::default()), + ProtocolMessage::Execute(Execute::new()), + ProtocolMessage::Sync(Sync::new()), + ]; + state.set_state(&messages); + assert_eq!(state.get_current_query().len(), 4); + } + + #[test] + fn state_from_parse() { + let msg = ProtocolMessage::Parse(Parse::new_anonymous("SELECT 1")); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningParse)); + } + + #[test] + fn state_from_bind() { + let msg = ProtocolMessage::Bind(Bind::default()); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningBind)); + } + + #[test] + fn state_from_describe() { + let msg = ProtocolMessage::Describe(Describe::new_statement("")); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningDescribe)); + } + + #[test] + fn state_from_execute() { + let msg = ProtocolMessage::Execute(Execute::new()); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningExecute)); + } + + #[test] + fn state_from_close() { + let msg = ProtocolMessage::Close(Close::named("stmt")); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningClose)); + } + + #[test] + fn state_from_sync() { + let msg = ProtocolMessage::Sync(Sync::new()); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningSync)); + } + + #[test] + fn state_from_query() { + let msg = ProtocolMessage::Query(Query::new("SELECT 1")); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningQuery)); + } + + #[test] + fn state_from_copy_data() { + let msg = ProtocolMessage::CopyData(CopyData::new(b"row data")); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningCopy)); + } + + #[test] + fn state_from_copy_fail() { + let msg = ProtocolMessage::CopyFail(CopyFail::new("error")); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningCopyFail)); + } + + #[test] + fn state_from_copy_done() { + let msg = ProtocolMessage::CopyDone(CopyDone); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningCopyDone)); + } + + #[test] + fn state_from_flush() { + let msg = ProtocolMessage::Other(Flush.message().unwrap()); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningFlush)); + } + + #[test] + fn state_from_fastpath() { + let msg = make_fastpath(); + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningFastpath)); + } + + #[test] + fn state_from_prepare() { + let msg = ProtocolMessage::Prepare { + name: "stmt1".to_string(), + statement: "SELECT 1".to_string(), + }; + let state: State = (&msg).into(); + assert!(matches!(state, State::RunningPrepare)); + } + + #[test] + fn simple_query_advances_on_ready_for_query() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Query(Query::new("SELECT 1"))]); + + // Data rows and other messages don't advance + state.process('D').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + state.process('T').unwrap(); // RowDescription + assert_eq!(state.get_successfully_processed(), 0); + state.process('C').unwrap(); // CommandComplete + assert_eq!(state.get_successfully_processed(), 0); + + // ReadyForQuery advances + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn extended_query_full_cycle() { + let mut state = ServerState::new(); + state.set_state(&[ + ProtocolMessage::Parse(Parse::new_anonymous("SELECT $1")), + ProtocolMessage::Bind(Bind::default()), + ProtocolMessage::Describe(Describe::new_statement("")), + ProtocolMessage::Execute(Execute::new()), + ProtocolMessage::Sync(Sync::new()), + ]); + + // ParseComplete + state.process('1').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + + // BindComplete + state.process('2').unwrap(); + assert_eq!(state.get_successfully_processed(), 2); + + // RowDescription (from Describe) + state.process('T').unwrap(); + assert_eq!(state.get_successfully_processed(), 3); + + // DataRow (from Execute) - doesn't advance + state.process('D').unwrap(); + assert_eq!(state.get_successfully_processed(), 3); + + // CommandComplete (from Execute) - advances + state.process('C').unwrap(); + assert_eq!(state.get_successfully_processed(), 4); + + // ReadyForQuery (from Sync) + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 5); + } + + #[test] + fn describe_advances_on_no_data() { + let mut state = ServerState::new(); + state.set_state(&[ + ProtocolMessage::Describe(Describe::new_statement("")), + ProtocolMessage::Sync(Sync::new()), + ]); + + // NoData response + state.process('n').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn describe_ignores_parameter_description() { + let mut state = ServerState::new(); + state.set_state(&[ + ProtocolMessage::Describe(Describe::new_statement("")), + ProtocolMessage::Sync(Sync::new()), + ]); + + // ParameterDescription ('t') doesn't advance + state.process('t').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + + // RowDescription advances + state.process('T').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn execute_advances_on_empty_query() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Execute(Execute::new())]); + + // EmptyQueryResponse + state.process('I').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn execute_advances_on_portal_suspended() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Execute(Execute::new())]); + + // PortalSuspended + state.process('s').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn execute_ignores_notice_response() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Execute(Execute::new())]); + + // NoticeResponse doesn't advance + state.process('N').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + + // CommandComplete advances + state.process('C').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn close_advances_on_close_complete() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Close(Close::named("stmt"))]); + + // CloseComplete + state.process('3').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn error_response_clears_state() { + let mut state = ServerState::new(); + state.set_state(&[ + ProtocolMessage::Parse(Parse::new_anonymous("SELECT 1")), + ProtocolMessage::Bind(Bind::default()), + ProtocolMessage::Sync(Sync::new()), + ]); + + // ParseComplete + state.process('1').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + + // ErrorResponse clears everything + state.process('E').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + assert!(state.get_current_query().is_empty()); + } + + #[test] + fn copy_data_advances_on_command_complete() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::CopyData(CopyData::new(b"data"))]); + + // Other messages don't advance + state.process('d').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + + // CommandComplete advances + state.process('C').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn copy_done_advances_on_ready_for_query() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::CopyDone(CopyDone)]); + + state.process('C').unwrap(); // CommandComplete doesn't advance + assert_eq!(state.get_successfully_processed(), 0); + + state.process('Z').unwrap(); // ReadyForQuery advances + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn copy_fail_does_not_advance() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::CopyFail(CopyFail::new("abort"))]); + + // Backend responds with ErrorResponse, but CopyFail state itself doesn't advance + state.process('E').unwrap(); + // Error clears state entirely + assert_eq!(state.get_successfully_processed(), 0); + } + + #[test] + fn flush_never_advances() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Other(Flush.message().unwrap())]); + + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + state.process('T').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + } + + #[test] + fn fastpath_advances_on_ready_for_query() { + let mut state = ServerState::new(); + state.set_state(&[make_fastpath()]); + + state.process('V').unwrap(); // FunctionCallResponse doesn't advance + assert_eq!(state.get_successfully_processed(), 0); + + state.process('Z').unwrap(); // ReadyForQuery advances + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn prepare_advances_on_ready_for_query() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Prepare { + name: "stmt1".to_string(), + statement: "SELECT 1".to_string(), + }]); + + state.process('T').unwrap(); // RowDescription doesn't advance + assert_eq!(state.get_successfully_processed(), 0); + + state.process('Z').unwrap(); // ReadyForQuery advances + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn process_does_nothing_when_all_states_consumed() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Sync(Sync::new())]); + + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + state.process('T').unwrap(); + assert_eq!(state.get_successfully_processed(), 1); + } + + #[test] + fn process_does_nothing_on_empty_state() { + let mut state = ServerState::new(); + state.process('Z').unwrap(); + assert_eq!(state.get_successfully_processed(), 0); + } + + #[test] + fn parse_errors_on_unexpected() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Parse(Parse::new_anonymous("SELECT 1"))]); + let result = state.process('Z'); + assert!(matches!(result, Err(Error::UnexpectedMessage(_)))); + } + + #[test] + fn bind_errors_on_unexpected() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Bind(Bind::default())]); + let result = state.process('Z'); + assert!(matches!(result, Err(Error::UnexpectedMessage(_)))); + } + + #[test] + fn close_errors_on_unexpected() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Close(Close::named("s"))]); + let result = state.process('Z'); + assert!(matches!(result, Err(Error::UnexpectedMessage(_)))); + } + + #[test] + fn describe_errors_on_unexpected() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Describe(Describe::new_statement(""))]); + let result = state.process('Z'); + assert!(matches!(result, Err(Error::UnexpectedMessage(_)))); + } + + #[test] + fn execute_errors_on_unexpected() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Execute(Execute::new())]); + let result = state.process('Z'); + assert!(matches!(result, Err(Error::UnexpectedMessage(_)))); + } + + #[test] + fn sync_errors_on_unexpected() { + let mut state = ServerState::new(); + state.set_state(&[ProtocolMessage::Sync(Sync::new())]); + let result = state.process('T'); + assert!(matches!(result, Err(Error::UnexpectedMessage(_)))); + } +} diff --git a/pgdog/src/net/messages/describe.rs b/pgdog/src/net/messages/describe.rs index 79d5142cc..064d1bc87 100644 --- a/pgdog/src/net/messages/describe.rs +++ b/pgdog/src/net/messages/describe.rs @@ -132,7 +132,7 @@ mod test { let pool = pool(); let mut conn = pool.get(&Request::default()).await.unwrap(); let describe = Describe::new_portal(""); - conn.send(&vec![ProtocolMessage::from(describe.message().unwrap())].into()) + conn.send(&vec![ProtocolMessage::from(describe)].into()) .await .unwrap(); let res = conn.read().await.unwrap();