Skip to content

Commit 9ce3bdd

Browse files
committed
Add support for unnamed (unaliased) subqueries and correctly extract Hive tables from queries with subscripts; add tests documenting some additional issues that were already fixed; clean up and refactor to reduce unreachable code paths.
1 parent 9b967db commit 9ce3bdd

13 files changed

Lines changed: 406 additions & 140 deletions

sql_metadata/column_extractor.py

Lines changed: 13 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ class ExtractionResult:
3737
alias_map: Dict[str, Union[str, list]]
3838
cte_names: UniqueList
3939
subquery_names: UniqueList
40+
output_columns: list
4041

4142

4243
# ---------------------------------------------------------------------------
@@ -149,6 +150,7 @@ class _Collector:
149150
"cte_names",
150151
"cte_alias_names",
151152
"subquery_items",
153+
"output_columns",
152154
)
153155

154156
def __init__(self, table_aliases: Dict[str, str]):
@@ -161,6 +163,7 @@ def __init__(self, table_aliases: Dict[str, str]):
161163
self.cte_names = UniqueList()
162164
self.cte_alias_names: set = set()
163165
self.subquery_items: list = []
166+
self.output_columns: list = []
164167

165168
def add_column(self, name: str, clause: str) -> None:
166169
"""Record a column name, filing it into the appropriate section."""
@@ -227,12 +230,6 @@ def extract(self) -> ExtractionResult:
227230

228231
self._seed_cte_names()
229232

230-
# Handle CREATE TABLE with column defs (no SELECT)
231-
if isinstance(self._ast, exp.Create) and not self._ast.find(exp.Select):
232-
for col_def in self._ast.find_all(exp.ColumnDef):
233-
c.add_column(col_def.name, "")
234-
return self._build_result()
235-
236233
# Reset cte_names — walk will re-collect them in order
237234
c.cte_names = UniqueList()
238235
self._walk(self._ast)
@@ -251,6 +248,7 @@ def extract(self) -> ExtractionResult:
251248
alias_map=c.alias_map,
252249
cte_names=final_cte,
253250
subquery_names=self._build_subquery_names(),
251+
output_columns=c.output_columns,
254252
)
255253

256254
# -------------------------------------------------------------------
@@ -279,19 +277,6 @@ def _build_subquery_names(self) -> UniqueList:
279277
names.append(name)
280278
return names
281279

282-
def _build_result(self) -> ExtractionResult:
283-
"""Build result from collector (used for early-return CREATE TABLE path)."""
284-
c = self._collector
285-
alias_dict = c.alias_dict if c.alias_dict else None
286-
return ExtractionResult(
287-
columns=c.columns,
288-
columns_dict=c.columns_dict,
289-
alias_names=c.alias_names,
290-
alias_dict=alias_dict,
291-
alias_map=c.alias_map,
292-
cte_names=c.cte_names,
293-
subquery_names=self._build_subquery_names(),
294-
)
295280

296281
# -------------------------------------------------------------------
297282
# Column name helpers
@@ -366,15 +351,11 @@ def _dispatch_leaf(self, node: exp.Expression, clause: str, depth: int) -> bool:
366351
367352
Returns ``True`` if handled (stop recursion), ``False`` to continue.
368353
"""
369-
if isinstance(node, (exp.Values, exp.Star, exp.ColumnDef, exp.Identifier)):
370-
# TODO: revisit if Stars appear outside Select.expressions
371-
if isinstance(node, exp.Star): # pragma: no cover
372-
self._handle_star(node, clause)
373-
# TODO: revisit if CREATE TABLE walk stops returning early
374-
elif isinstance(node, exp.ColumnDef): # pragma: no cover
354+
if isinstance(node, exp.Values) and not node.find(exp.Select):
355+
return True
356+
if isinstance(node, (exp.Star, exp.ColumnDef, exp.Identifier)):
357+
if isinstance(node, exp.ColumnDef):
375358
self._collector.add_column(node.name, clause)
376-
elif isinstance(node, exp.Identifier):
377-
self._handle_identifier(node, clause)
378359
return True
379360
if isinstance(node, exp.CTE):
380361
self._handle_cte(node, depth)
@@ -417,23 +398,6 @@ def _recurse_child(self, child: Any, clause: str, depth: int) -> None:
417398
# Node handlers
418399
# -------------------------------------------------------------------
419400

420-
# TODO: revisit if Stars reach _dispatch_leaf
421-
def _handle_star(self, node: exp.Star, clause: str) -> None: # pragma: no cover
422-
"""Handle a standalone Star node (not inside a Column or function)."""
423-
not_in_col = not isinstance(node.parent, exp.Column)
424-
if not_in_col and not self._is_star_inside_function(node):
425-
self._collector.add_column("*", clause)
426-
427-
def _handle_identifier(self, node: exp.Identifier, clause: str) -> None:
428-
"""Handle an Identifier in a USING clause (not inside a Column)."""
429-
if not isinstance(
430-
node.parent,
431-
(exp.Column, exp.Table, exp.TableAlias, exp.CTE),
432-
):
433-
# TODO: revisit if JOIN produces bare Identifiers
434-
if clause == "join": # pragma: no cover
435-
self._collector.add_column(node.name, clause)
436-
437401
def _handle_insert_schema(self, node: exp.Insert) -> None:
438402
"""Extract column names from the Schema of an INSERT statement."""
439403
schema = node.find(exp.Schema)
@@ -482,18 +446,23 @@ def _handle_column(self, col: exp.Column, clause: str) -> None:
482446
def _handle_select_exprs(self, exprs: list, clause: str, depth: int) -> None:
483447
"""Handle the expressions list of a SELECT clause."""
484448
assert isinstance(exprs, list)
449+
out = self._collector.output_columns
485450

486451
for expr in exprs:
487452
if isinstance(expr, exp.Alias):
488453
self._handle_alias(expr, clause, depth)
454+
out.append(expr.alias)
489455
elif isinstance(expr, exp.Star):
490456
self._collector.add_column("*", clause)
457+
out.append("*")
491458
elif isinstance(expr, exp.Column):
492459
self._handle_column(expr, clause)
460+
out.append(self._column_full_name(expr))
493461
else:
494462
cols = self._flat_columns(expr)
495463
for col in cols:
496464
self._collector.add_column(col, clause)
465+
out.append(cols[0] if len(cols) == 1 else str(expr))
497466

498467
def _handle_alias(self, alias_node: exp.Alias, clause: str, depth: int) -> None:
499468
"""Handle an Alias node inside a SELECT expression list."""

sql_metadata/dialect_parser.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -103,12 +103,12 @@ def _detect_dialects(sql: str) -> list:
103103
return [HashVarDialect, None, "mysql"]
104104
if "`" in sql:
105105
return ["mysql", None]
106+
if "LATERAL VIEW" in upper:
107+
return ["spark", None, "mysql"]
106108
if "[" in sql or " TOP " in upper:
107109
return [BracketedTableDialect, None, "mysql"]
108110
if " UNIQUE " in upper:
109111
return [None, "mysql", "oracle"]
110-
if "LATERAL VIEW" in upper:
111-
return ["spark", None, "mysql"]
112112
return [None, "mysql"]
113113

114114
# -- parsing ------------------------------------------------------------

sql_metadata/nested_resolver.py

Lines changed: 38 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from __future__ import annotations
1010

1111
import copy
12-
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
12+
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
1313

1414
if TYPE_CHECKING:
1515
from sql_metadata.parser import Parser
@@ -167,24 +167,49 @@ def extract_cte_names(
167167
return names
168168

169169
@staticmethod
170-
def extract_subquery_names(ast: Optional[exp.Expression]) -> List[str]:
171-
"""Extract aliased subquery names from the AST in post-order.
170+
def extract_subqueries(
171+
ast: Optional[exp.Expression],
172+
) -> Tuple[List[str], Dict[str, str]]:
173+
"""Extract subquery names and bodies in a single post-order walk.
174+
175+
Aliased subqueries keep their alias as the name. Unaliased
176+
subqueries (e.g. ``WHERE id IN (SELECT …)``) get auto-generated
177+
names ``subquery_1``, ``subquery_2``, etc.
172178
173-
Called by :attr:`Parser.subqueries_names`.
179+
:returns: ``(names, bodies)`` where *names* is ordered innermost-first.
174180
"""
175181
if ast is None: # pragma: no cover — Parser ensures AST exists
176-
return []
177-
names = UniqueList()
178-
NestedResolver._collect_subquery_names_postorder(ast, names)
179-
return names
182+
return [], {}
183+
names: list = UniqueList()
184+
bodies: Dict[str, str] = {}
185+
NestedResolver._walk_subqueries(ast, names, bodies, 0)
186+
return names, bodies
180187

181188
@staticmethod
182-
def _collect_subquery_names_postorder(node: exp.Expression, out: list) -> None:
183-
"""Recursively collect subquery aliases in post-order."""
189+
def _walk_subqueries(
190+
node: exp.Expression,
191+
names: list,
192+
bodies: Dict[str, str],
193+
counter: int,
194+
) -> int:
195+
"""Post-order walk collecting subquery names and bodies.
196+
197+
Returns the updated *counter* so unnamed subqueries are numbered
198+
sequentially.
199+
"""
184200
for child in node.iter_expressions():
185-
NestedResolver._collect_subquery_names_postorder(child, out)
186-
if isinstance(node, exp.Subquery) and node.alias:
187-
out.append(node.alias)
201+
counter = NestedResolver._walk_subqueries(
202+
child, names, bodies, counter
203+
)
204+
if isinstance(node, exp.Subquery):
205+
if node.alias:
206+
name = node.alias
207+
else:
208+
counter += 1
209+
name = f"subquery_{counter}"
210+
names.append(name)
211+
bodies[name] = NestedResolver._body_sql(node.this)
212+
return counter
188213

189214
# -------------------------------------------------------------------
190215
# Body extraction
@@ -226,39 +251,6 @@ def extract_cte_bodies(
226251

227252
return results
228253

229-
def extract_subquery_bodies(
230-
self,
231-
subquery_names: List[str],
232-
) -> Dict[str, str]:
233-
"""Extract subquery body SQL for each name in *subquery_names*.
234-
235-
Uses a post-order AST walk so that inner subqueries appear before
236-
outer ones.
237-
238-
:param subquery_names: List of subquery alias names to extract.
239-
:returns: Mapping of ``{subquery_name: body_sql}``.
240-
"""
241-
if not self._ast or not subquery_names:
242-
return {}
243-
244-
names_upper = {n.upper(): n for n in subquery_names}
245-
results: Dict[str, str] = {}
246-
self._collect_subqueries_postorder(self._ast, names_upper, results)
247-
return results
248-
249-
@staticmethod
250-
def _collect_subqueries_postorder(
251-
node: exp.Expression, names_upper: Dict[str, str], out: Dict[str, str]
252-
) -> None:
253-
"""Recursively collect subquery bodies in post-order."""
254-
for child in node.iter_expressions():
255-
NestedResolver._collect_subqueries_postorder(child, names_upper, out)
256-
if isinstance(node, exp.Subquery) and node.alias:
257-
alias_upper = node.alias.upper()
258-
if alias_upper in names_upper:
259-
original_name = names_upper[alias_upper]
260-
out[original_name] = NestedResolver._body_sql(node.this)
261-
262254
# -------------------------------------------------------------------
263255
# Column resolution (from parser.py)
264256
# -------------------------------------------------------------------

0 commit comments

Comments
 (0)