Skip to content

Commit 1fd9d36

Browse files
committed
additional cleanup in dialect_parser.py and query_type_extractor.py
1 parent 2dd4685 commit 1fd9d36

3 files changed

Lines changed: 162 additions & 73 deletions

File tree

sql_metadata/dialect_parser.py

Lines changed: 114 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -91,19 +91,31 @@ class BracketedTableDialect(TSQL):
9191

9292

9393
class DialectParser:
94-
"""Detect the appropriate sqlglot dialect and parse SQL into an AST."""
94+
"""Detect the appropriate sqlglot dialect and parse SQL into an AST.
95+
96+
SQL varies across database engines — back-ticks (MySQL), square
97+
brackets (TSQL), ``#temp`` tables (MSSQL), ``LATERAL VIEW`` (Hive),
98+
etc. A single sqlglot dialect cannot handle all of them, so this
99+
class first inspects the raw SQL for dialect markers, then tries
100+
candidate dialects in order and picks the first result that passes
101+
quality checks.
102+
"""
95103

96104
def parse(self, clean_sql: str) -> tuple[exp.Expression, DialectType]:
97-
"""Parse *clean_sql*, returning ``(ast, dialect)``.
98-
99-
Detects candidate dialects via heuristics, tries each in order,
100-
and returns the first non-degraded result.
101-
102-
:param clean_sql: Preprocessed SQL string (comments stripped, etc.).
103-
:type clean_sql: str
104-
:returns: 2-tuple of ``(ast_node, winning_dialect)``.
105-
:rtype: tuple
106-
:raises ValueError: If all dialect attempts fail.
105+
"""Parse *clean_sql* into a sqlglot AST, returning ``(ast, dialect)``.
106+
107+
Entry point for the two-phase process: first
108+
:meth:`_detect_dialects` builds a priority-ordered list of
109+
candidate dialects from syntactic markers in the SQL, then
110+
:meth:`_try_dialects` attempts each one and returns the first
111+
non-degraded result.
112+
113+
:param clean_sql: Preprocessed SQL string produced by
114+
:class:`~sql_metadata.sql_cleaner.SqlCleaner` (comments
115+
stripped, outer parentheses removed, CTE names normalised).
116+
:returns: 2-tuple of ``(ast_root_node, winning_dialect)``.
117+
:raises InvalidQueryDefinition: If every candidate dialect
118+
fails to produce a usable AST.
107119
"""
108120
dialects = self._detect_dialects(clean_sql)
109121
return self._try_dialects(clean_sql, dialects)
@@ -112,20 +124,30 @@ def parse(self, clean_sql: str) -> tuple[exp.Expression, DialectType]:
112124

113125
@staticmethod
114126
def _detect_dialects(sql: str) -> list[Any]:
115-
"""Choose an ordered list of sqlglot dialects to try for *sql*.
116-
117-
Heuristics:
118-
119-
* ``#WORD`` → :class:`HashVarDialect` (MSSQL temp tables).
120-
* Back-ticks → ``"mysql"``.
121-
* Square brackets or ``TOP`` → :class:`BracketedTableDialect`.
122-
* ``UNIQUE`` → try default, MySQL, Oracle.
123-
* ``LATERAL VIEW`` → ``"spark"`` (Hive).
127+
"""Build a priority-ordered list of sqlglot dialects for *sql*.
128+
129+
Scans the SQL string for syntactic markers that reveal which
130+
database engine produced it and returns the most likely dialect
131+
first. Every list includes at least one fallback so that the
132+
subsequent :meth:`_try_dialects` loop always has alternatives.
133+
134+
Heuristics (checked in order, first match wins):
135+
136+
* ``#WORD`` patterns → :class:`HashVarDialect` (MSSQL ``#temp``
137+
tables or ``#VAR#`` template placeholders).
138+
* Back-tick quoting → ``"mysql"`` (MySQL-style identifiers).
139+
* ``LATERAL VIEW`` → ``"spark"`` (Hive/Spark explode syntax).
140+
* Square brackets or ``TOP`` keyword →
141+
:class:`BracketedTableDialect` (TSQL bracket-quoted names).
142+
* ``UNIQUE`` keyword → default, ``"mysql"``, ``"oracle"``
143+
(ambiguous across engines).
144+
* ``APPEND FROM`` → :class:`RedshiftAppendDialect` (Redshift
145+
``ALTER TABLE … APPEND FROM`` not natively supported).
146+
* No markers → default dialect with ``"mysql"`` fallback.
124147
125148
:param sql: Cleaned SQL string.
126-
:type sql: str
127-
:returns: Ordered list of dialects to attempt.
128-
:rtype: list
149+
:returns: Ordered list of dialect identifiers or classes to
150+
attempt.
129151
"""
130152
upper = sql.upper()
131153
if _has_hash_variables(sql):
@@ -147,20 +169,25 @@ def _detect_dialects(sql: str) -> list[Any]:
147169
def _try_dialects(
148170
self, clean_sql: str, dialects: list[Any]
149171
) -> tuple[exp.Expression, DialectType]:
150-
"""Try parsing *clean_sql* with each dialect, returning the best.
151-
152-
:returns: 2-tuple of ``(ast_node, winning_dialect)``.
153-
:raises ValueError: If all dialect attempts fail.
172+
"""Try each candidate dialect in order and return the first good result.
173+
174+
Iterates over *dialects*, calling :meth:`_parse_with_dialect` for
175+
each. A result is accepted immediately if it is the last dialect
176+
in the list (best-effort) or if :meth:`_is_degraded` reports no
177+
quality issues. Degraded results from non-last dialects are
178+
skipped so the next candidate gets a chance.
179+
180+
:param clean_sql: Preprocessed SQL string.
181+
:param dialects: Priority-ordered list from :meth:`_detect_dialects`.
182+
:returns: 2-tuple of ``(ast_root_node, winning_dialect)``.
183+
:raises InvalidQueryDefinition: If the last dialect raises a
184+
parse error, or if no dialect produces a usable AST.
154185
"""
155-
last_result = None
156-
winning_dialect = None
157186
for dialect in dialects:
158187
try:
159188
result = self._parse_with_dialect(clean_sql, dialect)
160189
if result is None:
161190
continue
162-
last_result = result
163-
winning_dialect = dialect
164191
is_last = dialect == dialects[-1]
165192
if not is_last and self._is_degraded(result, clean_sql):
166193
continue
@@ -172,16 +199,32 @@ def _try_dialects(
172199
)
173200
continue
174201

175-
# TODO: revisit if sqlglot starts returning None from parse for last dialect
176-
if last_result is not None: # pragma: no cover
177-
return last_result, winning_dialect
178202
raise InvalidQueryDefinition(
179203
"Query could not be parsed — no dialect could handle this SQL"
180204
)
181205

182206
@staticmethod
183207
def _parse_with_dialect(clean_sql: str, dialect: Any) -> exp.Expression | None:
184-
"""Parse *clean_sql* with a single dialect, suppressing warnings."""
208+
"""Parse *clean_sql* with a single sqlglot dialect.
209+
210+
Uses ``ErrorLevel.WARN`` so that sqlglot returns a best-effort
211+
AST instead of raising on the first syntax problem — the caller
212+
decides whether the result is good enough via
213+
:meth:`_is_degraded`.
214+
215+
The sqlglot logger is temporarily raised to ``CRITICAL`` during
216+
the parse call because ``WARN`` mode emits noisy warnings for
217+
every token it cannot handle. Since :meth:`_try_dialects`
218+
intentionally tries multiple dialects expecting some to produce
219+
degraded results, those warnings are expected and would mislead
220+
end-users if left visible.
221+
222+
:param clean_sql: Preprocessed SQL string.
223+
:param dialect: A sqlglot dialect identifier, class, or ``None``
224+
for the default dialect.
225+
:returns: The root AST node, or ``None`` if sqlglot could not
226+
produce any result.
227+
"""
185228
logger = logging.getLogger("sqlglot")
186229
old_level = logger.level
187230
logger.setLevel(logging.CRITICAL)
@@ -196,35 +239,58 @@ def _parse_with_dialect(clean_sql: str, dialect: Any) -> exp.Expression | None:
196239

197240
if not results or results[0] is None:
198241
return None
199-
result = results[0]
200-
assert result is not None # guaranteed by check above
201-
# TODO: revisit if sqlglot returns top-level Subquery
202-
if isinstance(result, exp.Subquery) and not result.alias: # pragma: no cover
203-
inner = result.this
204-
if isinstance(inner, exp.Expression):
205-
return inner
206-
return result # type: ignore[return-value]
242+
return results[0] # type: ignore[return-value]
207243

208244
# -- quality checks -----------------------------------------------------
209245

210246
def _is_degraded(self, result: exp.Expression, clean_sql: str) -> bool:
211-
"""Return ``True`` when a better dialect should be tried."""
247+
"""Return ``True`` when the parse result is low quality.
248+
249+
A degraded result means the dialect parsed the SQL without
250+
raising, but the AST is suspicious — either the whole statement
251+
collapsed into an opaque ``exp.Command`` (when it should not
252+
have) or :meth:`_has_parse_issues` found placeholder-like table
253+
or column names. When ``True``, :meth:`_try_dialects` skips
254+
this dialect and moves on to the next candidate.
255+
256+
:param result: Root AST node from :meth:`_parse_with_dialect`.
257+
:param clean_sql: Original cleaned SQL (needed to check whether
258+
``exp.Command`` is expected).
259+
:returns: ``True`` if the result should be discarded in favour
260+
of the next dialect.
261+
"""
212262
if isinstance(result, exp.Command) and not self._is_expected_command(clean_sql):
213263
return True
214264
return self._has_parse_issues(result)
215265

216266
@staticmethod
217267
def _is_expected_command(sql: str) -> bool:
218-
"""Check whether *sql* legitimately parses as ``exp.Command``."""
268+
"""Return ``True`` when *sql* legitimately parses as ``exp.Command``.
269+
270+
Some dialect-specific DDL (e.g. Hive ``CREATE FUNCTION … USING
271+
JAR … WITH SERDEPROPERTIES``) is not supported by any sqlglot
272+
dialect and always degrades to ``exp.Command``. This method
273+
whitelists those known cases so :meth:`_is_degraded` does not
274+
reject them.
275+
276+
:param sql: Cleaned SQL string.
277+
:returns: ``True`` if ``exp.Command`` is the expected result.
278+
"""
219279
upper = sql.strip().upper()
220280
return upper.startswith("CREATE FUNCTION")
221281

222282
@staticmethod
223283
def _has_parse_issues(ast: exp.Expression) -> bool:
224-
"""Detect signs of a degraded or incorrect parse.
284+
"""Walk the AST looking for signs of a degraded or incorrect parse.
285+
286+
When sqlglot misinterprets a query it often places SQL keywords
287+
(``UNIQUE``, ``DISTINCT``, etc.) into column or table name
288+
positions, or produces table nodes with empty names. This
289+
method scans all :class:`~sqlglot.exp.Table` and
290+
:class:`~sqlglot.exp.Column` nodes for those telltale patterns.
225291
226-
Checks for table nodes with empty/keyword-like names and column
227-
nodes whose name is a SQL keyword without a table qualifier.
292+
:param ast: Root AST node to inspect.
293+
:returns: ``True`` if suspicious nodes were found.
228294
"""
229295
for table in ast.find_all(exp.Table):
230296
if table.name in _BAD_TABLE_NAMES:

sql_metadata/nested_resolver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ def resolve(
232232
2. Drop unqualified column names that are actually aliases defined
233233
inside a nested query.
234234
235-
Also applies the same resolution to *columns_dict*.
235+
Also applies the resolution to *columns_dict*,
236+
but there instead of dropping we also resolve unqualified aliases.
236237
237238
Example SQL::
238239

sql_metadata/query_type_extractor.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
"""Extract the query type from a sqlglot AST root node.
22
33
The :class:`QueryTypeExtractor` class maps the top-level AST node to a
4-
:class:`QueryType` enum value, handling parenthesised wrappers, set
5-
operations, and opaque ``Command`` nodes.
4+
:class:`QueryType` enum value, handling set operations (``UNION``,
5+
``INTERSECT``, ``EXCEPT``) and opaque ``Command`` nodes that sqlglot
6+
cannot fully parse (e.g. Hive DDL).
67
"""
78

89
import logging
910
from typing import NoReturn
1011

1112
from sqlglot import exp
1213

14+
from sql_metadata.comments import strip_comments
1315
from sql_metadata.exceptions import InvalidQueryDefinition
1416
from sql_metadata.keywords_lists import QueryType
1517

@@ -36,8 +38,15 @@
3638
class QueryTypeExtractor:
3739
"""Determine the query type from a sqlglot AST root node.
3840
39-
:param ast: Root AST node (may be ``None``).
40-
:param raw_query: Original SQL string (for error messages).
41+
Maps the root AST node type to a :class:`QueryType` enum value
42+
using :data:`_SIMPLE_TYPE_MAP` for common statement types. Falls
43+
back to command-text inspection for opaque ``exp.Command`` nodes
44+
that sqlglot could not fully parse.
45+
46+
:param ast: Root AST node produced by :class:`DialectParser`, or
47+
``None`` when the input was empty or comment-only.
48+
:param raw_query: Original SQL string, kept for error messages
49+
when the AST is ``None``.
4150
"""
4251

4352
def __init__(
@@ -51,14 +60,20 @@ def __init__(
5160
def extract(self) -> QueryType:
5261
"""Determine the :class:`QueryType` for the parsed SQL.
5362
63+
Checks the root node type against :data:`_SIMPLE_TYPE_MAP` first.
64+
A bare ``exp.With`` node (CTE without a main statement) is rejected
65+
as invalid SQL. ``exp.Command`` nodes are forwarded to
66+
:meth:`_resolve_command_type` for command-text inspection.
67+
5468
:returns: The detected query type.
55-
:raises ValueError: If the query is empty, malformed, or
56-
unsupported.
69+
:raises InvalidQueryDefinition: If the AST is ``None`` (empty or
70+
comment-only input), the query is a bare ``WITH`` clause, or
71+
the root node type is not recognised.
5772
"""
5873
if self._ast is None:
5974
self._raise_for_none_ast()
6075

61-
root = self._unwrap_parens(self._ast)
76+
root = self._ast
6277
node_type = type(root)
6378

6479
if node_type is exp.With:
@@ -79,32 +94,39 @@ def extract(self) -> QueryType:
7994
logger.error("Not supported query type: %s", shorten_query)
8095
raise InvalidQueryDefinition("Not supported query type!")
8196

82-
@staticmethod
83-
def _unwrap_parens(ast: exp.Expression) -> exp.Expression:
84-
"""Remove Paren and Subquery wrappers to reach the real statement."""
85-
# TODO: revisit if sqlglot stops stripping outer parens before this is called
86-
if isinstance(ast, (exp.Paren, exp.Subquery)): # pragma: no cover
87-
return QueryTypeExtractor._unwrap_parens(ast.this)
88-
return ast
89-
9097
@staticmethod
9198
def _resolve_command_type(root: exp.Expression) -> QueryType | None:
92-
"""Determine query type for an opaque ``exp.Command`` node.
93-
94-
Hive ``CREATE FUNCTION ... USING JAR ... WITH SERDEPROPERTIES``
95-
is not supported by any sqlglot dialect and degrades to
96-
``exp.Command(this='CREATE', ...)``. This fallback extracts
97-
the query type from the command text so callers still get
98-
``QueryType.CREATE``.
99+
"""Extract query type from the command text of an opaque node.
100+
101+
Some dialect-specific DDL (e.g. Hive
102+
``CREATE FUNCTION ... USING JAR ... WITH SERDEPROPERTIES``) is
103+
not supported by any sqlglot dialect and degrades to
104+
``exp.Command(this='CREATE', ...)``. This method reads the
105+
``this`` attribute of the command node and maps it back to the
106+
corresponding :class:`QueryType`, so callers still get
107+
``QueryType.CREATE`` instead of an unsupported-type error.
108+
109+
:param root: An ``exp.Command`` AST node.
110+
:returns: The resolved :class:`QueryType`, or ``None`` if the
111+
command text does not match any known type.
99112
"""
100113
expression_text = str(root.this).upper() if root.this else ""
101114
if expression_text == "CREATE":
102115
return QueryType.CREATE
103116
return None
104117

105118
def _raise_for_none_ast(self) -> "NoReturn":
106-
"""Raise an appropriate error when the AST is None."""
107-
from sql_metadata.comments import strip_comments
119+
"""Raise a descriptive error when the AST is ``None``.
120+
121+
Distinguishes between truly empty input (blank or comment-only
122+
SQL) and SQL that has content but could not be parsed by
123+
sqlglot. In the first case a "not supported" message is raised;
124+
in the second a "could not parse" message points the caller
125+
toward a syntax problem.
126+
127+
:raises InvalidQueryDefinition: Always — this method never
128+
returns normally.
129+
"""
108130

109131
stripped = strip_comments(self._raw_query) if self._raw_query else ""
110132
if stripped.strip():

0 commit comments

Comments
 (0)