@@ -21,7 +21,7 @@ impl QueryEngine {
2121 route : & Route ,
2222 ) -> Result < ( ) , Error > {
2323 // Check that we're not in a transaction error state.
24- if !self . transaction_error_check ( context) . await ? {
24+ if !self . transaction_error_check ( context, route ) . await ? {
2525 return Ok ( ( ) ) ;
2626 }
2727
@@ -115,7 +115,12 @@ impl QueryEngine {
115115
116116 match state {
117117 TransactionState :: Error => {
118- context. transaction = Some ( TransactionType :: Error ) ;
118+ let error_state = match context. transaction {
119+ Some ( TransactionType :: ReadOnly ) => Some ( TransactionType :: ErrorReadOnly ) ,
120+ Some ( TransactionType :: ReadWrite ) => Some ( TransactionType :: ErrorReadWrite ) ,
121+ _ => None ,
122+ } ;
123+ context. transaction = error_state;
119124 if self . two_pc . auto ( ) {
120125 self . end_two_pc ( true ) . await ?;
121126 // TODO: this records a 2pc transaction in client
@@ -133,10 +138,23 @@ impl QueryEngine {
133138 self . end_two_pc ( false ) . await ?;
134139 two_pc_auto = true ;
135140 }
136- if context. transaction . is_none ( ) {
141+ match context. transaction {
137142 // Query parser is disabled, so the server is responsible for telling us
138143 // we started a transaction.
139- context. transaction = Some ( TransactionType :: ReadWrite ) ;
144+ None => {
145+ context. transaction = Some ( TransactionType :: ReadWrite ) ;
146+ }
147+
148+ // Restore transaction state after rollback to savepoint.
149+ Some ( TransactionType :: ErrorReadOnly ) => {
150+ context. transaction = Some ( TransactionType :: ReadOnly ) ;
151+ }
152+
153+ Some ( TransactionType :: ErrorReadWrite ) => {
154+ context. transaction = Some ( TransactionType :: ReadWrite ) ;
155+ }
156+
157+ _ => ( ) ,
140158 }
141159 }
142160 }
@@ -264,8 +282,13 @@ impl QueryEngine {
264282 async fn transaction_error_check (
265283 & mut self ,
266284 context : & mut QueryEngineContext < ' _ > ,
285+ route : & Route ,
267286 ) -> Result < bool , Error > {
268- if context. in_error ( ) && !context. rollback && context. client_request . executable ( ) {
287+ if context. in_error ( )
288+ && !context. rollback
289+ && context. client_request . executable ( )
290+ && !route. rollback_savepoint ( )
291+ {
269292 let bytes_sent = context
270293 . stream
271294 . error (
0 commit comments