Skip to content

Commit 1f608d6

Browse files
authored
Support savepoints (#547)
* Support savepoints * Fix spelling
1 parent d727e21 commit 1f608d6

6 files changed

Lines changed: 73 additions & 8 deletions

File tree

integration/rust/tests/integration/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub mod notify;
99
pub mod per_stmt_routing;
1010
pub mod prepared;
1111
pub mod reload;
12+
pub mod savepoint;
1213
pub mod set_sharding_key;
1314
pub mod shard_consistency;
1415
pub mod stddev;
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
use rust::setup::{admin_sqlx, connections_sqlx};
2+
use sqlx::Executor;
3+
4+
#[tokio::test]
5+
async fn test_savepoint() {
6+
let conns = connections_sqlx().await;
7+
8+
for conn in conns {
9+
let mut transaction = conn.begin().await.unwrap();
10+
transaction
11+
.execute("CREATE TABLE test_savepoint (id BIGINT)")
12+
.await
13+
.unwrap();
14+
transaction.execute("SAVEPOINT test").await.unwrap();
15+
assert!(transaction.execute("SELECT sdfsf").await.is_err());
16+
transaction
17+
.execute("ROLLBACK TO SAVEPOINT test")
18+
.await
19+
.unwrap();
20+
transaction
21+
.execute("SELECT * FROM test_savepoint")
22+
.await
23+
.unwrap();
24+
transaction.rollback().await.unwrap();
25+
}
26+
}

pgdog/src/frontend/client/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ pub enum TransactionType {
5656
ReadOnly,
5757
#[default]
5858
ReadWrite,
59-
Error,
59+
ErrorReadWrite,
60+
ErrorReadOnly,
6061
}
6162

6263
impl TransactionType {
@@ -69,7 +70,7 @@ impl TransactionType {
6970
}
7071

7172
pub fn error(&self) -> bool {
72-
matches!(self, Self::Error)
73+
matches!(self, Self::ErrorReadWrite | Self::ErrorReadOnly)
7374
}
7475
}
7576

pgdog/src/frontend/client/query_engine/query.rs

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ impl QueryEngine {
2121
route: &Route,
2222
) -> Result<(), Error> {
2323
// Check that we're not in a transaction error state.
24-
if !self.transaction_error_check(context).await? {
24+
if !self.transaction_error_check(context, route).await? {
2525
return Ok(());
2626
}
2727

@@ -115,7 +115,12 @@ impl QueryEngine {
115115

116116
match state {
117117
TransactionState::Error => {
118-
context.transaction = Some(TransactionType::Error);
118+
let error_state = match context.transaction {
119+
Some(TransactionType::ReadOnly) => Some(TransactionType::ErrorReadOnly),
120+
Some(TransactionType::ReadWrite) => Some(TransactionType::ErrorReadWrite),
121+
_ => None,
122+
};
123+
context.transaction = error_state;
119124
if self.two_pc.auto() {
120125
self.end_two_pc(true).await?;
121126
// TODO: this records a 2pc transaction in client
@@ -133,10 +138,23 @@ impl QueryEngine {
133138
self.end_two_pc(false).await?;
134139
two_pc_auto = true;
135140
}
136-
if context.transaction.is_none() {
141+
match context.transaction {
137142
// Query parser is disabled, so the server is responsible for telling us
138143
// we started a transaction.
139-
context.transaction = Some(TransactionType::ReadWrite);
144+
None => {
145+
context.transaction = Some(TransactionType::ReadWrite);
146+
}
147+
148+
// Restore transaction state after rollback to savepoint.
149+
Some(TransactionType::ErrorReadOnly) => {
150+
context.transaction = Some(TransactionType::ReadOnly);
151+
}
152+
153+
Some(TransactionType::ErrorReadWrite) => {
154+
context.transaction = Some(TransactionType::ReadWrite);
155+
}
156+
157+
_ => (),
140158
}
141159
}
142160
}
@@ -264,8 +282,13 @@ impl QueryEngine {
264282
async fn transaction_error_check(
265283
&mut self,
266284
context: &mut QueryEngineContext<'_>,
285+
route: &Route,
267286
) -> Result<bool, Error> {
268-
if context.in_error() && !context.rollback && context.client_request.executable() {
287+
if context.in_error()
288+
&& !context.rollback
289+
&& context.client_request.executable()
290+
&& !route.rollback_savepoint()
291+
{
269292
let bytes_sent = context
270293
.stream
271294
.error(

pgdog/src/frontend/router/parser/query/transaction.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ impl QueryParser {
1616
context: &QueryParserContext,
1717
) -> Result<Command, Error> {
1818
let extended = !context.query()?.simple();
19+
let mut rollback_savepoint = false;
1920

2021
if context.rw_conservative() && !context.read_only {
2122
self.write_override = true;
@@ -37,6 +38,7 @@ impl QueryParser {
3738
extended,
3839
});
3940
}
41+
TransactionStmtKind::TransStmtRollbackTo => rollback_savepoint = true,
4042
TransactionStmtKind::TransStmtPrepare
4143
| TransactionStmtKind::TransStmtCommitPrepared
4244
| TransactionStmtKind::TransStmtRollbackPrepared => {
@@ -47,7 +49,9 @@ impl QueryParser {
4749
_ => (),
4850
}
4951

50-
Ok(Command::Query(Route::write(None)))
52+
Ok(Command::Query(
53+
Route::write(None).set_rollback_savepoint(rollback_savepoint),
54+
))
5155
}
5256

5357
#[inline]

pgdog/src/frontend/router/parser/route.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ pub struct Route {
6060
maintenance: bool,
6161
rewrite_plan: RewritePlan,
6262
rewritten_sql: Option<String>,
63+
rollback_savepoint: bool,
6364
}
6465

6566
impl Display for Route {
@@ -187,6 +188,15 @@ impl Route {
187188
self.read = read;
188189
}
189190

191+
pub fn set_rollback_savepoint(mut self, rollback: bool) -> Self {
192+
self.rollback_savepoint = rollback;
193+
self
194+
}
195+
196+
pub fn rollback_savepoint(&self) -> bool {
197+
self.rollback_savepoint
198+
}
199+
190200
pub fn set_write(mut self, write: FunctionBehavior) -> Self {
191201
self.set_write_mut(write);
192202
self

0 commit comments

Comments
 (0)