@@ -5,7 +5,9 @@ use std::path::PathBuf;
55use tracing:: { info, warn} ;
66
77use crate :: sharding:: ShardedSchema ;
8- use crate :: { EnumeratedDatabase , Memory , PassthoughAuth , PreparedStatements , RewriteMode } ;
8+ use crate :: {
9+ EnumeratedDatabase , Memory , OmnishardedTable , PassthoughAuth , PreparedStatements , RewriteMode ,
10+ } ;
911
1012use super :: database:: Database ;
1113use super :: error:: Error ;
@@ -227,15 +229,18 @@ impl Config {
227229 tables
228230 }
229231
230- pub fn omnisharded_tables ( & self ) -> HashMap < String , Vec < String > > {
232+ pub fn omnisharded_tables ( & self ) -> HashMap < String , Vec < OmnishardedTable > > {
231233 let mut tables = HashMap :: new ( ) ;
232234
233235 for table in & self . omnisharded_tables {
234236 let entry = tables
235237 . entry ( table. database . clone ( ) )
236238 . or_insert_with ( Vec :: new) ;
237239 for t in & table. tables {
238- entry. push ( t. clone ( ) ) ;
240+ entry. push ( OmnishardedTable {
241+ name : t. clone ( ) ,
242+ sticky_routing : table. sticky ,
243+ } ) ;
239244 }
240245 }
241246
@@ -647,4 +652,58 @@ password = "users_admin_password"
647652 ) ;
648653 assert ! ( config_and_users. users. admin. is_none( ) ) ;
649654 }
655+
656+ #[ test]
657+ fn test_omnisharded_tables ( ) {
658+ let source = r#"
659+ [general]
660+ host = "0.0.0.0"
661+ port = 6432
662+
663+ [[databases]]
664+ name = "db1"
665+ host = "127.0.0.1"
666+ port = 5432
667+
668+ [[databases]]
669+ name = "db2"
670+ host = "127.0.0.1"
671+ port = 5433
672+
673+ [[omnisharded_tables]]
674+ database = "db1"
675+ tables = ["table_a", "table_b"]
676+
677+ [[omnisharded_tables]]
678+ database = "db1"
679+ tables = ["table_c"]
680+ sticky = true
681+
682+ [[omnisharded_tables]]
683+ database = "db2"
684+ tables = ["table_x"]
685+ "# ;
686+
687+ let config: Config = toml:: from_str ( source) . unwrap ( ) ;
688+
689+ assert_eq ! ( config. omnisharded_tables. len( ) , 3 ) ;
690+
691+ let tables = config. omnisharded_tables ( ) ;
692+
693+ assert_eq ! ( tables. len( ) , 2 ) ;
694+
695+ let db1_tables = tables. get ( "db1" ) . unwrap ( ) ;
696+ assert_eq ! ( db1_tables. len( ) , 3 ) ;
697+ assert_eq ! ( db1_tables[ 0 ] . name, "table_a" ) ;
698+ assert ! ( !db1_tables[ 0 ] . sticky_routing) ;
699+ assert_eq ! ( db1_tables[ 1 ] . name, "table_b" ) ;
700+ assert ! ( !db1_tables[ 1 ] . sticky_routing) ;
701+ assert_eq ! ( db1_tables[ 2 ] . name, "table_c" ) ;
702+ assert ! ( db1_tables[ 2 ] . sticky_routing) ;
703+
704+ let db2_tables = tables. get ( "db2" ) . unwrap ( ) ;
705+ assert_eq ! ( db2_tables. len( ) , 1 ) ;
706+ assert_eq ! ( db2_tables[ 0 ] . name, "table_x" ) ;
707+ assert ! ( !db2_tables[ 0 ] . sticky_routing) ;
708+ }
650709}
0 commit comments