1- use std:: collections:: HashMap ;
1+ use std:: collections:: { HashMap , HashSet } ;
22
33use once_cell:: sync:: Lazy ;
44use pg_query:: { protobuf, Node , NodeEnum } ;
@@ -20,6 +20,9 @@ static WRITE_ONLY: Lazy<HashMap<&'static str, LockingBehavior>> = Lazy::new(|| {
2020 ] )
2121} ) ;
2222
23+ static CROSS_SHARD : Lazy < HashSet < ( & ' static str , & ' static str ) > > =
24+ Lazy :: new ( || HashSet :: from ( [ ( "pgdog" , "install_sharded_sequence" ) ] ) ) ;
25+
2326#[ derive( Debug , Copy , Clone , PartialEq , Default ) ]
2427pub enum LockingBehavior {
2528 Lock ,
@@ -32,6 +35,7 @@ pub enum LockingBehavior {
3235pub struct FunctionBehavior {
3336 pub writes : bool ,
3437 pub locking_behavior : LockingBehavior ,
38+ pub cross_shard : bool ,
3539}
3640
3741impl FunctionBehavior {
@@ -45,28 +49,49 @@ impl FunctionBehavior {
4549
4650pub struct Function < ' a > {
4751 pub name : & ' a str ,
52+ pub schema : Option < & ' a str > ,
4853}
4954
5055impl < ' a > Function < ' a > {
51- fn from_string ( node : & ' a Option < NodeEnum > ) -> Result < Self , ( ) > {
52- match node {
53- Some ( NodeEnum :: String ( protobuf:: String { sval } ) ) => Ok ( Self {
54- name : sval. as_str ( ) ,
56+ /// Build a Function from a qualified name list (as found in `FuncCall.funcname`).
57+ /// The last element is the function name; the preceding element (if any) is the
58+ /// schema.
59+ fn from_strings ( parts : & ' a [ Node ] ) -> Result < Self , ( ) > {
60+ let str_of = |node : & ' a Node | match & node. node {
61+ Some ( NodeEnum :: String ( protobuf:: String { sval } ) ) => Ok ( sval. as_str ( ) ) ,
62+ _ => Err ( ( ) ) ,
63+ } ;
64+ match parts {
65+ [ name] => Ok ( Self {
66+ name : str_of ( name) ?,
67+ schema : None ,
68+ } ) ,
69+ [ .., schema, name] => Ok ( Self {
70+ name : str_of ( name) ?,
71+ schema : Some ( str_of ( schema) ?) ,
5572 } ) ,
56-
5773 _ => Err ( ( ) ) ,
5874 }
5975 }
6076
6177 /// This function likely writes.
6278 pub fn behavior ( & self ) -> FunctionBehavior {
79+ let cross_shard = self
80+ . schema
81+ . map ( |schema| CROSS_SHARD . contains ( & ( schema, self . name ) ) )
82+ . unwrap_or ( false ) ;
83+
6384 if let Some ( locks) = WRITE_ONLY . get ( & self . name ) {
6485 FunctionBehavior {
6586 writes : true ,
6687 locking_behavior : * locks,
88+ cross_shard,
6789 }
6890 } else {
69- FunctionBehavior :: default ( )
91+ FunctionBehavior {
92+ cross_shard,
93+ ..FunctionBehavior :: default ( )
94+ }
7095 }
7196 }
7297}
@@ -76,9 +101,7 @@ impl<'a> TryFrom<&'a Node> for Function<'a> {
76101 fn try_from ( value : & ' a Node ) -> Result < Self , Self :: Error > {
77102 match & value. node {
78103 Some ( NodeEnum :: FuncCall ( func) ) => {
79- if let Some ( node) = func. funcname . last ( ) {
80- return Self :: from_string ( & node. node ) ;
81- }
104+ return Self :: from_strings ( & func. funcname ) ;
82105 }
83106
84107 Some ( NodeEnum :: TypeCast ( cast) ) => {
@@ -123,10 +146,52 @@ mod test {
123146 for node in & stmt. target_list {
124147 let func = Function :: try_from ( node) . unwrap ( ) ;
125148 assert ! ( func. name. contains( "advisory_lock" ) ) ;
149+ assert ! ( func. schema. is_none( ) ) ;
150+ assert ! ( !func. behavior( ) . cross_shard) ;
126151 }
127152 }
128153
129154 _ => panic ! ( "not a select" ) ,
130155 }
131156 }
157+
158+ fn first_func < R > ( query : & str , check : impl FnOnce ( Function < ' _ > ) -> R ) -> R {
159+ let ast = parse ( query) . unwrap ( ) ;
160+ let root = ast. protobuf . stmts . first ( ) . unwrap ( ) . stmt . as_ref ( ) . unwrap ( ) ;
161+ match root. node . as_ref ( ) {
162+ Some ( NodeEnum :: SelectStmt ( stmt) ) => {
163+ let target = stmt. target_list . first ( ) . unwrap ( ) ;
164+ check ( Function :: try_from ( target) . unwrap ( ) )
165+ }
166+ _ => panic ! ( "not a select" ) ,
167+ }
168+ }
169+
170+ #[ test]
171+ fn test_cross_shard_function ( ) {
172+ first_func (
173+ "SELECT pgdog.install_sharded_sequence('foo', 'id')" ,
174+ |func| {
175+ assert_eq ! ( func. name, "install_sharded_sequence" ) ;
176+ assert_eq ! ( func. schema, Some ( "pgdog" ) ) ;
177+ assert ! ( func. behavior( ) . cross_shard) ;
178+ } ,
179+ ) ;
180+
181+ // Same function name without the schema should not be flagged.
182+ first_func ( "SELECT install_sharded_sequence('foo', 'id')" , |func| {
183+ assert_eq ! ( func. name, "install_sharded_sequence" ) ;
184+ assert ! ( func. schema. is_none( ) ) ;
185+ assert ! ( !func. behavior( ) . cross_shard) ;
186+ } ) ;
187+
188+ // Different schema should not be flagged.
189+ first_func (
190+ "SELECT other.install_sharded_sequence('foo', 'id')" ,
191+ |func| {
192+ assert_eq ! ( func. schema, Some ( "other" ) ) ;
193+ assert ! ( !func. behavior( ) . cross_shard) ;
194+ } ,
195+ ) ;
196+ }
132197}
0 commit comments