Skip to content

Commit ea620b7

Browse files
authored
fix: role detection; feat: set inside transactions (#647)
The role detection was pretty bad, it was causing an infinite loop on role change and DoSing itself. Now, we're using atomics on load balancer targets, so we don't need to reload the config when role changes. For `SET`, we now handle it inside transactions and run it on the client's behalf. This makes it work when `cross_shard_disabled = true`. Additionally, all `SET` commands executed inside a transaction are rewritten to use `SET LOCAL` instead to avoid leaking client state between servers.
1 parent 345bd28 commit ea620b7

30 files changed

Lines changed: 947 additions & 783 deletions

File tree

integration/rust/tests/integration/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pub mod prepared;
1212
pub mod reload;
1313
pub mod rewrite;
1414
pub mod savepoint;
15+
pub mod set_in_transaction;
1516
pub mod set_sharding_key;
1617
pub mod shard_consistency;
1718
pub mod stddev;
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
use rust::setup::{admin_sqlx, connections_sqlx};
2+
use serial_test::serial;
3+
use sqlx::Executor;
4+
5+
#[tokio::test]
6+
#[serial]
7+
async fn test_set_in_transaction_reset_after_commit() {
8+
let admin = admin_sqlx().await;
9+
admin
10+
.execute("SET cross_shard_disabled TO true")
11+
.await
12+
.unwrap();
13+
14+
let pools = connections_sqlx().await;
15+
let sharded = &pools[1];
16+
17+
let mut conn = sharded.acquire().await.unwrap();
18+
19+
// Get the original lock_timeout before any transaction
20+
let original_timeout: String = sqlx::query_scalar("SHOW lock_timeout")
21+
.fetch_one(&mut *conn)
22+
.await
23+
.unwrap();
24+
25+
// Make sure we set it to something different
26+
let new_timeout = if original_timeout == "45s" {
27+
"30s"
28+
} else {
29+
"45s"
30+
};
31+
32+
// Start a transaction and change lock_timeout
33+
conn.execute("BEGIN").await.unwrap();
34+
conn.execute(format!("SET lock_timeout TO '{}'", new_timeout).as_str())
35+
.await
36+
.unwrap();
37+
38+
// Verify lock_timeout is set inside transaction
39+
let timeout_in_tx: String = sqlx::query_scalar("SHOW lock_timeout")
40+
.fetch_one(&mut *conn)
41+
.await
42+
.unwrap();
43+
assert_eq!(
44+
timeout_in_tx, new_timeout,
45+
"lock_timeout should be {} inside transaction",
46+
new_timeout
47+
);
48+
49+
conn.execute("COMMIT").await.unwrap();
50+
51+
// Verify lock_timeout is reset to original after commit
52+
let timeout_after_commit: String = sqlx::query_scalar("SHOW lock_timeout")
53+
.fetch_one(&mut *conn)
54+
.await
55+
.unwrap();
56+
assert_eq!(
57+
timeout_after_commit, original_timeout,
58+
"lock_timeout should be reset to original after commit"
59+
);
60+
61+
admin
62+
.execute("SET cross_shard_disabled TO false")
63+
.await
64+
.unwrap();
65+
}
66+
67+
#[tokio::test]
68+
#[serial]
69+
async fn test_set_in_transaction_reset_after_rollback() {
70+
let admin = admin_sqlx().await;
71+
admin
72+
.execute("SET cross_shard_disabled TO true")
73+
.await
74+
.unwrap();
75+
76+
let pools = connections_sqlx().await;
77+
let sharded = &pools[1];
78+
79+
let mut conn = sharded.acquire().await.unwrap();
80+
81+
// Get the original statement_timeout before any transaction
82+
let original_timeout: String = sqlx::query_scalar("SHOW statement_timeout")
83+
.fetch_one(&mut *conn)
84+
.await
85+
.unwrap();
86+
87+
// Make sure we set it to something different
88+
let new_timeout = if original_timeout == "30s" {
89+
"45s"
90+
} else {
91+
"30s"
92+
};
93+
94+
// Start a transaction and change statement_timeout
95+
conn.execute("BEGIN").await.unwrap();
96+
conn.execute(format!("SET statement_timeout TO '{}'", new_timeout).as_str())
97+
.await
98+
.unwrap();
99+
100+
// Verify statement_timeout is set inside transaction
101+
let timeout_in_tx: String = sqlx::query_scalar("SHOW statement_timeout")
102+
.fetch_one(&mut *conn)
103+
.await
104+
.unwrap();
105+
assert_eq!(
106+
timeout_in_tx, new_timeout,
107+
"statement_timeout should be {} inside transaction",
108+
new_timeout
109+
);
110+
111+
conn.execute("ROLLBACK").await.unwrap();
112+
113+
// Verify statement_timeout is back to original after rollback
114+
let timeout_after_rollback: String = sqlx::query_scalar("SHOW statement_timeout")
115+
.fetch_one(&mut *conn)
116+
.await
117+
.unwrap();
118+
assert_eq!(
119+
timeout_after_rollback, original_timeout,
120+
"statement_timeout should be reset to original after rollback"
121+
);
122+
123+
admin
124+
.execute("SET cross_shard_disabled TO false")
125+
.await
126+
.unwrap();
127+
}
128+
129+
#[tokio::test]
130+
#[serial]
131+
async fn test_set_local_in_transaction_reset_after_commit() {
132+
let admin = admin_sqlx().await;
133+
admin
134+
.execute("SET cross_shard_disabled TO true")
135+
.await
136+
.unwrap();
137+
138+
let pools = connections_sqlx().await;
139+
let sharded = &pools[1];
140+
141+
let mut conn = sharded.acquire().await.unwrap();
142+
143+
// Get the original work_mem before any transaction
144+
let original_work_mem: String = sqlx::query_scalar("SHOW work_mem")
145+
.fetch_one(&mut *conn)
146+
.await
147+
.unwrap();
148+
149+
// Make sure we set it to something different
150+
let new_work_mem = if original_work_mem == "8MB" {
151+
"16MB"
152+
} else {
153+
"8MB"
154+
};
155+
156+
// Start a transaction and change work_mem using SET LOCAL
157+
conn.execute("BEGIN").await.unwrap();
158+
conn.execute(format!("SET LOCAL work_mem TO '{}'", new_work_mem).as_str())
159+
.await
160+
.unwrap();
161+
162+
// Verify work_mem is set inside transaction
163+
let work_mem_in_tx: String = sqlx::query_scalar("SHOW work_mem")
164+
.fetch_one(&mut *conn)
165+
.await
166+
.unwrap();
167+
assert_eq!(
168+
work_mem_in_tx, new_work_mem,
169+
"work_mem should be {} inside transaction",
170+
new_work_mem
171+
);
172+
173+
conn.execute("COMMIT").await.unwrap();
174+
175+
// Verify work_mem is reset to original after commit (SET LOCAL is transaction-scoped)
176+
let work_mem_after_commit: String = sqlx::query_scalar("SHOW work_mem")
177+
.fetch_one(&mut *conn)
178+
.await
179+
.unwrap();
180+
assert_eq!(
181+
work_mem_after_commit, original_work_mem,
182+
"work_mem should be reset to original after commit (SET LOCAL is transaction-scoped)"
183+
);
184+
185+
admin
186+
.execute("SET cross_shard_disabled TO false")
187+
.await
188+
.unwrap();
189+
}

pgdog-config/src/core.rs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use tracing::{info, warn};
77
use crate::sharding::ShardedSchema;
88
use crate::{
99
EnumeratedDatabase, Memory, OmnishardedTable, PassthoughAuth, PreparedStatements, RewriteMode,
10+
Role,
1011
};
1112

1213
use super::database::Database;
@@ -312,18 +313,39 @@ impl Config {
312313
}
313314
}
314315

315-
// Check pooler mode.
316-
let mut pooler_mode = HashMap::<String, Option<PoolerMode>>::new();
316+
struct Check {
317+
pooler_mode: Option<PoolerMode>,
318+
role: Role,
319+
role_warned: bool,
320+
}
321+
322+
// Check identical configs.
323+
let mut checks = HashMap::<String, Check>::new();
317324
for database in &self.databases {
318-
if let Some(mode) = pooler_mode.get(&database.name) {
319-
if mode != &database.pooler_mode {
325+
if let Some(existing) = checks.get_mut(&database.name) {
326+
if existing.pooler_mode != database.pooler_mode {
320327
warn!(
321328
"database \"{}\" (shard={}, role={}) has a different \"pooler_mode\" setting, ignoring",
322329
database.name, database.shard, database.role,
323330
);
324331
}
332+
let auto = existing.role == Role::Auto || database.role == Role::Auto;
333+
if auto && existing.role != database.role && !existing.role_warned {
334+
warn!(
335+
r#"database "{}" has a mix of auto and specific roles, automatic role detection will be disabled"#,
336+
database.name
337+
);
338+
existing.role_warned = true;
339+
}
325340
} else {
326-
pooler_mode.insert(database.name.clone(), database.pooler_mode.clone());
341+
checks.insert(
342+
database.name.clone(),
343+
Check {
344+
pooler_mode: database.pooler_mode.clone(),
345+
role: database.role,
346+
role_warned: false,
347+
},
348+
);
327349
}
328350
}
329351

0 commit comments

Comments
 (0)