Skip to content

Commit 345bd28

Browse files
authored
Fix aggregate group by count (#645)
Addresses #638. I had a bit of spare time, and was curious about this. So, it turns out the issue is an incorrectly parsed out `group_by` most of the time. The `group_by` expects an index into the `target_list`, so I made that work for the case where you have the exact `table_name.column_name` or same `column_name` in the `select` and the `group by` queries. this feels a little bit janky to me, but 🤷 figured I'd raise for feedback, especially as I see a draft PR for a rewritten query engine. ```sql (michael)@127.0.0.1:6432 16:54:06 [repro_sharded] > select count(1), user_id from example group by example.user_id; count | user_id -------+--------- 6 | 1 3 | 2 4 | 3 (3 rows) (michael)@127.0.0.1:6432 16:54:07 [repro_sharded] > select count(1), example.user_id from example group by example.user_id; count | user_id -------+--------- 6 | 1 3 | 2 4 | 3 (3 rows) (michael)@127.0.0.1:6432 16:54:14 [repro_sharded] > select example.user_id, count(1), example.user_id from example group by example.user_id; unexpected field count in "D" message (michael)@127.0.0.1:6432 16:54:21 [repro_sharded] > select example.user_id, count(1) from example group by example.user_id; user_id | count ---------+------- 1 | 6 2 | 3 3 | 4 (3 rows) (michael)@127.0.0.1:6432 16:54:28 [repro_sharded] > select user_id, count(1) from example group by example.user_id; user_id | count ---------+------- 3 | 4 1 | 6 2 | 3 (3 rows) (michael)@127.0.0.1:6432 16:54:32 [repro_sharded] > select example.user_id, count(1) from example group by user_id; user_id | count ---------+------- | 13 (1 row) (michael)@127.0.0.1:6432 16:54:37 [repro_sharded] > select user_id, count(1) from example group by user_id; user_id | count ---------+------- 3 | 4 1 | 6 2 | 3 (3 rows) ``` See here for results, including some cases where it fails to work correctly. (where we duplicate the column in the select, and where we specify inconsistent `example.user_id` / `user_id` (NOTE: it only fails if we're more specific in the `select` than in the `group_by`.
1 parent 99ae692 commit 345bd28

1 file changed

Lines changed: 277 additions & 1 deletion

File tree

pgdog/src/frontend/router/parser/aggregate.rs

Lines changed: 277 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use pg_query::protobuf::Integer;
2-
use pg_query::protobuf::{a_const::Val, SelectStmt};
2+
use pg_query::protobuf::{a_const::Val, Node, SelectStmt, String as PgQueryString};
33
use pg_query::NodeEnum;
44

55
use crate::frontend::router::parser::{ExpressionRegistry, Function};
@@ -67,6 +67,60 @@ pub struct Aggregate {
6767
group_by: Vec<usize>,
6868
}
6969

70+
fn target_list_to_index(stmt: &SelectStmt, column_names: Vec<&String>) -> Option<usize> {
71+
for (idx, node) in stmt.target_list.iter().enumerate() {
72+
if let Some(NodeEnum::ResTarget(res_target_box)) = node.node.as_ref() {
73+
let res_target = res_target_box.as_ref();
74+
if let Some(node_box) = res_target.val.as_ref() {
75+
if let Some(NodeEnum::ColumnRef(column_ref)) = node_box.node.as_ref() {
76+
let select_names: Vec<&String> = column_ref
77+
.fields
78+
.iter()
79+
.filter_map(|field_node| {
80+
if let Some(node_box) = field_node.node.as_ref() {
81+
match node_box {
82+
NodeEnum::String(PgQueryString {
83+
sval: found_column_name,
84+
..
85+
}) => Some(found_column_name),
86+
_ => None,
87+
}
88+
} else {
89+
None
90+
}
91+
})
92+
.collect();
93+
94+
if select_names.is_empty() {
95+
continue;
96+
}
97+
98+
if columns_match(&column_names, &select_names) {
99+
return Some(idx);
100+
}
101+
}
102+
}
103+
}
104+
}
105+
None
106+
}
107+
108+
fn columns_match(group_by_names: &[&String], select_names: &[&String]) -> bool {
109+
if group_by_names == select_names {
110+
return true;
111+
}
112+
113+
if group_by_names.len() == 1 && select_names.len() == 2 {
114+
return select_names[1] == group_by_names[0];
115+
}
116+
117+
if group_by_names.len() == 2 && select_names.len() == 1 {
118+
return group_by_names[1] == select_names[0];
119+
}
120+
121+
false
122+
}
123+
70124
impl Aggregate {
71125
/// Figure out what aggregates are present and which ones PgDog supports.
72126
pub fn parse(stmt: &SelectStmt) -> Result<Self, Error> {
@@ -81,6 +135,20 @@ impl Aggregate {
81135
Val::Ival(Integer { ival }) => Some(*ival as usize - 1), // We use 0-indexed arrays, Postgres uses 1-indexed.
82136
_ => None,
83137
}),
138+
NodeEnum::ColumnRef(column_ref) => {
139+
let column_names: Vec<&String> = column_ref
140+
.fields
141+
.iter()
142+
.filter_map(|node| match node {
143+
Node {
144+
node:
145+
Some(NodeEnum::String(PgQueryString { sval: column_name })),
146+
} => Some(column_name),
147+
_ => None,
148+
})
149+
.collect();
150+
Some(target_list_to_index(stmt, column_names))
151+
}
84152
_ => None,
85153
})
86154
})
@@ -381,4 +449,212 @@ mod test {
381449
_ => panic!("not a select"),
382450
}
383451
}
452+
453+
#[test]
454+
fn test_parse_group_by_column_name_single() {
455+
let query = pg_query::parse("SELECT user_id, COUNT(1) FROM example GROUP BY user_id")
456+
.unwrap()
457+
.protobuf
458+
.stmts
459+
.first()
460+
.cloned()
461+
.unwrap();
462+
match query.stmt.unwrap().node.unwrap() {
463+
NodeEnum::SelectStmt(stmt) => {
464+
let aggr = Aggregate::parse(&stmt).unwrap();
465+
assert_eq!(aggr.group_by(), &[0]);
466+
assert_eq!(aggr.targets().len(), 1);
467+
let target = &aggr.targets()[0];
468+
assert!(matches!(target.function(), AggregateFunction::Count));
469+
assert_eq!(target.column(), 1);
470+
}
471+
_ => panic!("not a select"),
472+
}
473+
}
474+
475+
#[test]
476+
fn test_parse_group_by_column_name_multiple() {
477+
let query = pg_query::parse(
478+
"SELECT COUNT(*), user_id, category_id FROM example GROUP BY user_id, category_id",
479+
)
480+
.unwrap()
481+
.protobuf
482+
.stmts
483+
.first()
484+
.cloned()
485+
.unwrap();
486+
match query.stmt.unwrap().node.unwrap() {
487+
NodeEnum::SelectStmt(stmt) => {
488+
let aggr = Aggregate::parse(&stmt).unwrap();
489+
assert_eq!(aggr.group_by(), &[1, 2]);
490+
assert_eq!(aggr.targets().len(), 1);
491+
let target = &aggr.targets()[0];
492+
assert!(matches!(target.function(), AggregateFunction::Count));
493+
assert_eq!(target.column(), 0);
494+
}
495+
_ => panic!("not a select"),
496+
}
497+
}
498+
499+
#[test]
500+
fn test_parse_group_by_qualified_column_name() {
501+
let query = pg_query::parse(
502+
"SELECT COUNT(1), example.user_id FROM example GROUP BY example.user_id",
503+
)
504+
.unwrap()
505+
.protobuf
506+
.stmts
507+
.first()
508+
.cloned()
509+
.unwrap();
510+
match query.stmt.unwrap().node.unwrap() {
511+
NodeEnum::SelectStmt(stmt) => {
512+
let aggr = Aggregate::parse(&stmt).unwrap();
513+
assert_eq!(aggr.group_by(), &[1]);
514+
assert_eq!(aggr.targets().len(), 1);
515+
let target = &aggr.targets()[0];
516+
assert!(matches!(target.function(), AggregateFunction::Count));
517+
assert_eq!(target.column(), 0);
518+
}
519+
_ => panic!("not a select"),
520+
}
521+
}
522+
523+
#[test]
524+
fn test_parse_group_by_mixed_ordinal_and_column_name() {
525+
let query = pg_query::parse(
526+
"SELECT user_id, category_id, SUM(quantity) FROM example GROUP BY user_id, 2",
527+
)
528+
.unwrap()
529+
.protobuf
530+
.stmts
531+
.first()
532+
.cloned()
533+
.unwrap();
534+
match query.stmt.unwrap().node.unwrap() {
535+
NodeEnum::SelectStmt(stmt) => {
536+
let aggr = Aggregate::parse(&stmt).unwrap();
537+
assert_eq!(aggr.group_by(), &[0, 1]);
538+
assert_eq!(aggr.targets().len(), 1);
539+
let target = &aggr.targets()[0];
540+
assert!(matches!(target.function(), AggregateFunction::Sum));
541+
assert_eq!(target.column(), 2);
542+
}
543+
_ => panic!("not a select"),
544+
}
545+
}
546+
547+
#[test]
548+
fn test_parse_group_by_column_not_in_select() {
549+
let query = pg_query::parse("SELECT COUNT(*) FROM example GROUP BY user_id")
550+
.unwrap()
551+
.protobuf
552+
.stmts
553+
.first()
554+
.cloned()
555+
.unwrap();
556+
match query.stmt.unwrap().node.unwrap() {
557+
NodeEnum::SelectStmt(stmt) => {
558+
let aggr = Aggregate::parse(&stmt).unwrap();
559+
let empty: Vec<usize> = vec![];
560+
assert_eq!(aggr.group_by(), empty.as_slice());
561+
assert_eq!(aggr.targets().len(), 1);
562+
}
563+
_ => panic!("not a select"),
564+
}
565+
}
566+
567+
#[test]
568+
fn test_parse_group_by_with_multiple_aggregates() {
569+
let query = pg_query::parse(
570+
"SELECT COUNT(*), SUM(price), user_id, AVG(price) FROM example GROUP BY user_id",
571+
)
572+
.unwrap()
573+
.protobuf
574+
.stmts
575+
.first()
576+
.cloned()
577+
.unwrap();
578+
match query.stmt.unwrap().node.unwrap() {
579+
NodeEnum::SelectStmt(stmt) => {
580+
let aggr = Aggregate::parse(&stmt).unwrap();
581+
assert_eq!(aggr.group_by(), &[2]);
582+
assert_eq!(aggr.targets().len(), 3);
583+
assert!(matches!(
584+
aggr.targets()[0].function(),
585+
AggregateFunction::Count
586+
));
587+
assert!(matches!(
588+
aggr.targets()[1].function(),
589+
AggregateFunction::Sum
590+
));
591+
assert!(matches!(
592+
aggr.targets()[2].function(),
593+
AggregateFunction::Avg
594+
));
595+
}
596+
_ => panic!("not a select"),
597+
}
598+
}
599+
600+
#[test]
601+
fn test_parse_group_by_qualified_matches_select_unqualified() {
602+
let query =
603+
pg_query::parse("SELECT user_id, COUNT(1) FROM example GROUP BY example.user_id")
604+
.unwrap()
605+
.protobuf
606+
.stmts
607+
.first()
608+
.cloned()
609+
.unwrap();
610+
match query.stmt.unwrap().node.unwrap() {
611+
NodeEnum::SelectStmt(stmt) => {
612+
let aggr = Aggregate::parse(&stmt).unwrap();
613+
assert_eq!(aggr.group_by(), &[0]);
614+
assert_eq!(aggr.targets().len(), 1);
615+
}
616+
_ => panic!("not a select"),
617+
}
618+
}
619+
620+
#[test]
621+
fn test_parse_group_by_unqualified_matches_select_qualified() {
622+
let query =
623+
pg_query::parse("SELECT example.user_id, COUNT(1) FROM example GROUP BY user_id")
624+
.unwrap()
625+
.protobuf
626+
.stmts
627+
.first()
628+
.cloned()
629+
.unwrap();
630+
match query.stmt.unwrap().node.unwrap() {
631+
NodeEnum::SelectStmt(stmt) => {
632+
let aggr = Aggregate::parse(&stmt).unwrap();
633+
assert_eq!(aggr.group_by(), &[0]);
634+
assert_eq!(aggr.targets().len(), 1);
635+
}
636+
_ => panic!("not a select"),
637+
}
638+
}
639+
640+
#[test]
641+
fn test_parse_group_by_both_qualified_order_matters() {
642+
let query = pg_query::parse(
643+
"SELECT example.user_id, COUNT(1) FROM example GROUP BY example.user_id",
644+
)
645+
.unwrap()
646+
.protobuf
647+
.stmts
648+
.first()
649+
.cloned()
650+
.unwrap();
651+
match query.stmt.unwrap().node.unwrap() {
652+
NodeEnum::SelectStmt(stmt) => {
653+
let aggr = Aggregate::parse(&stmt).unwrap();
654+
assert_eq!(aggr.group_by(), &[0]);
655+
assert_eq!(aggr.targets().len(), 1);
656+
}
657+
_ => panic!("not a select"),
658+
}
659+
}
384660
}

0 commit comments

Comments
 (0)