@@ -6,12 +6,13 @@ class so that callers only need to call :meth:`DialectParser.parse`.
66"""
77
88import logging
9+ from typing import Optional
910
1011import sqlglot
1112from sqlglot import Dialect , exp
1213from sqlglot .dialects .tsql import TSQL
1314from sqlglot .errors import ParseError , TokenError
14- from sqlglot .tokens import Tokenizer
15+ from sqlglot .tokens import Tokenizer as BaseTokenizer
1516
1617from sql_metadata .comments import _has_hash_variables
1718
@@ -37,12 +38,12 @@ class HashVarDialect(Dialect):
3738 of a ``VAR`` token instead.
3839 """
3940
40- class Tokenizer (Tokenizer ):
41+ class Tokenizer (BaseTokenizer ):
4142 """Tokenizer subclass that includes ``#`` in variable tokens."""
4243
43- SINGLE_TOKENS = {** Tokenizer .SINGLE_TOKENS }
44+ SINGLE_TOKENS = {** BaseTokenizer .SINGLE_TOKENS }
4445 SINGLE_TOKENS .pop ("#" , None )
45- VAR_SINGLE_TOKENS = {* Tokenizer .VAR_SINGLE_TOKENS , "#" }
46+ VAR_SINGLE_TOKENS = {* BaseTokenizer .VAR_SINGLE_TOKENS , "#" }
4647
4748
4849class BracketedTableDialect (TSQL ):
@@ -63,7 +64,7 @@ class BracketedTableDialect(TSQL):
6364class DialectParser :
6465 """Detect the appropriate sqlglot dialect and parse SQL into an AST."""
6566
66- def parse (self , clean_sql : str ) -> tuple :
67+ def parse (self , clean_sql : str ) -> tuple [ exp . Expression , object ] :
6768 """Parse *clean_sql*, returning ``(ast, dialect)``.
6869
6970 Detects candidate dialects via heuristics, tries each in order,
@@ -112,7 +113,9 @@ def _detect_dialects(sql: str) -> list:
112113
113114 # -- parsing ------------------------------------------------------------
114115
115- def _try_dialects (self , clean_sql : str , dialects : list ) -> tuple :
116+ def _try_dialects (
117+ self , clean_sql : str , dialects : list
118+ ) -> tuple [exp .Expression , object ]:
116119 """Try parsing *clean_sql* with each dialect, returning the best.
117120
118121 :returns: 2-tuple of ``(ast_node, winning_dialect)``.
@@ -141,7 +144,7 @@ def _try_dialects(self, clean_sql: str, dialects: list) -> tuple:
141144 raise ValueError ("This query is wrong" )
142145
143146 @staticmethod
144- def _parse_with_dialect (clean_sql : str , dialect ) -> exp .Expression :
147+ def _parse_with_dialect (clean_sql : str , dialect ) -> Optional [ exp .Expression ] :
145148 """Parse *clean_sql* with a single dialect, suppressing warnings."""
146149 logger = logging .getLogger ("sqlglot" )
147150 old_level = logger .level
@@ -158,9 +161,13 @@ def _parse_with_dialect(clean_sql: str, dialect) -> exp.Expression:
158161 if not results or results [0 ] is None :
159162 return None
160163 result = results [0 ]
164+ if result is None :
165+ return None
161166 if isinstance (result , exp .Subquery ) and not result .alias :
162- result = result .this
163- return result
167+ inner = result .this
168+ if isinstance (inner , exp .Expression ):
169+ return inner
170+ return result # type: ignore[return-value]
164171
165172 # -- quality checks -----------------------------------------------------
166173
0 commit comments