Skip to content

Commit b5ba2f2

Browse files
committed
additional changes and optimizations after initial review
1 parent fbf3014 commit b5ba2f2

7 files changed

Lines changed: 36 additions & 26 deletions

File tree

sql_metadata/ast_parser.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ def ast(self) -> exp.Expression | None:
4848
self._ast = self._parse(self._raw_sql)
4949
return self._ast
5050

51+
def _ensure_parsed(self) -> None:
52+
"""Trigger lazy parsing so side-effect fields are populated."""
53+
_ = self.ast
54+
5155
@property
5256
def dialect(self) -> DialectType:
5357
"""The sqlglot dialect that produced the current AST.
@@ -58,7 +62,7 @@ def dialect(self) -> DialectType:
5862
5963
:rtype: DialectType
6064
"""
61-
_ = self.ast
65+
self._ensure_parsed()
6266
return self._dialect
6367

6468
@property
@@ -72,7 +76,7 @@ def is_replace(self) -> bool:
7276
7377
:rtype: bool
7478
"""
75-
_ = self.ast
79+
self._ensure_parsed()
7680
return self._is_replace
7781

7882
@property
@@ -84,7 +88,7 @@ def cte_name_map(self) -> dict[str, str]:
8488
8589
:rtype: dict[str, str]
8690
"""
87-
_ = self.ast
91+
self._ensure_parsed()
8892
return self._cte_name_map
8993

9094
def _parse(self, sql: str) -> exp.Expression | None:

sql_metadata/column_extractor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def __init__(
251251
self._table_aliases = table_aliases
252252
self._cte_name_map = cte_name_map or {}
253253
self._collector = _Collector()
254-
self._reverse_cte_map = self._cte_name_map
254+
self._cte_restore_map = self._cte_name_map
255255

256256
# -------------------------------------------------------------------
257257
# Public API
@@ -290,7 +290,7 @@ def extract(self) -> ExtractionResult:
290290
# Restore qualified CTE names (reverse placeholder mapping)
291291
final_cte = UniqueList()
292292
for name in c.cte_names:
293-
final_cte.append(self._reverse_cte_map.get(name, name))
293+
final_cte.append(self._cte_restore_map.get(name, name))
294294

295295
alias_dict = c.alias_dict
296296
return ExtractionResult(
@@ -327,7 +327,7 @@ def _seed_cte_names(self) -> None:
327327
alias = cte.alias
328328
if alias:
329329
self._collector.cte_names.append(
330-
self._reverse_cte_map.get(alias, alias)
330+
self._cte_restore_map.get(alias, alias)
331331
)
332332

333333
def _build_subquery_names(self) -> UniqueList:

sql_metadata/parser.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sql_metadata.column_extractor import ColumnExtractor
2222
from sql_metadata.comments import extract_comments, strip_comments
2323
from sql_metadata.generalizator import Generalizator
24+
from sql_metadata.keywords_lists import QueryType
2425
from sql_metadata.nested_resolver import NestedResolver
2526
from sql_metadata.query_type_extractor import QueryTypeExtractor
2627
from sql_metadata.sql_cleaner import SqlCleaner
@@ -47,7 +48,7 @@ def __init__(self, sql: str = "", disable_logging: bool = False) -> None:
4748
self._logger.disabled = disable_logging
4849

4950
self._raw_query = sql
50-
self._query_type: str | None = None
51+
self._query_type: QueryType | None = None
5152

5253
self._ast_parser = ASTParser(sql)
5354
self._resolver: NestedResolver | None = None
@@ -104,7 +105,7 @@ def query(self) -> str:
104105
return SqlCleaner.preprocess_query(self._raw_query)
105106

106107
@property
107-
def query_type(self) -> str:
108+
def query_type(self) -> QueryType | None:
108109
"""Return the type of the SQL query.
109110
110111
Delegates to :class:`QueryTypeExtractor` which maps the top-level
@@ -113,7 +114,7 @@ def query_type(self) -> str:
113114
(unparseable SQL), the extractor falls back to keyword matching on
114115
the raw string.
115116
116-
:rtype: str
117+
:rtype: QueryType | None
117118
"""
118119
if self._query_type:
119120
return self._query_type
@@ -320,7 +321,7 @@ def tables_aliases(self) -> dict[str, str]:
320321
if self._table_aliases is not None:
321322
return self._table_aliases
322323
ast = self._ast_parser.ast
323-
assert ast is not None # guaranteed by prior tables/query_type access
324+
assert ast is not None # mypy
324325
extractor = TableExtractor(ast)
325326
self._table_aliases = extractor.extract_aliases(self.tables)
326327
return self._table_aliases

test/test_edge_cases.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
"""Edge-case tests for internals not covered by feature-specific test files."""
1+
"""Edge-case tests for internal utilities.
2+
3+
These tests exercise code paths (depth guards, degenerate inputs) that
4+
are difficult or impossible to trigger through the public Parser API.
5+
They test internal symbols directly and may need updating if those
6+
internals are refactored.
7+
"""
28

39
from sql_metadata import Parser
410
from sql_metadata.sql_cleaner import SqlCleaner, _strip_outer_parens

test/test_getting_columns.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,4 +772,5 @@ def test_cte_with_table_star_in_body():
772772
"WITH cte(a) AS (SELECT t.* FROM t) "
773773
"SELECT a FROM cte"
774774
)
775-
assert "t.*" in p.columns or "a" in p.columns_aliases_names
775+
assert p.columns == ["t.*"]
776+
assert p.columns_aliases_names == ["a"]

test/test_getting_tables.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -943,11 +943,9 @@ def test_on_keyword_not_table_alias():
943943
def test_unmatched_parentheses_graceful():
944944
# solved: https://github.com/macbre/sql-metadata/issues/532
945945
# Should not raise IndexError; graceful handling of malformed SQL
946-
try:
947-
parser = Parser("SELECT arrayJoin(tags.key)) FROM foo")
948-
_ = parser.tables
949-
except (ValueError, Exception):
950-
pass
946+
parser = Parser("SELECT arrayJoin(tags.key)) FROM foo")
947+
tables = parser.tables
948+
assert isinstance(tables, list)
951949

952950

953951
def test_degraded_parse_falls_through_to_last_dialect():

test/test_query_type.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,25 @@
66
def test_insert_query():
77
queries = [
88
"INSERT IGNORE /* foo */ INTO bar VALUES (1, '123', '2017-01-01');",
9-
"/* foo */ INSERT IGNORE INTO bar VALUES (1, '123', '2017-01-01');"
10-
"-- foo\nINSERT IGNORE INTO bar VALUES (1, '123', '2017-01-01');"
9+
"/* foo */ INSERT IGNORE INTO bar VALUES (1, '123', '2017-01-01');",
10+
"-- foo\nINSERT IGNORE INTO bar VALUES (1, '123', '2017-01-01');",
1111
"# foo\nINSERT IGNORE INTO bar VALUES (1, '123', '2017-01-01');",
1212
]
1313

1414
for query in queries:
15-
assert "INSERT" == Parser(query).query_type
15+
assert QueryType.INSERT == Parser(query).query_type
1616

1717

1818
def test_select_query():
1919
queries = [
2020
"SELECT /* foo */ foo FROM bar",
21-
"/* foo */ SELECT foo FROM bar"
22-
"-- foo\nSELECT foo FROM bar"
21+
"/* foo */ SELECT foo FROM bar",
22+
"-- foo\nSELECT foo FROM bar",
2323
"# foo\nSELECT foo FROM bar",
2424
]
2525

2626
for query in queries:
27-
assert "SELECT" == Parser(query).query_type
27+
assert QueryType.SELECT == Parser(query).query_type
2828

2929

3030
def test_delete_query():
@@ -35,7 +35,7 @@ def test_delete_query():
3535

3636
for query in queries:
3737
for comment in ["", "/* foo */", "\n--foo\n", "\n# foo\n"]:
38-
assert "DELETE" == Parser(query.format(comment)).query_type
38+
assert QueryType.DELETE == Parser(query.format(comment)).query_type
3939

4040

4141
def test_drop_table_query():
@@ -45,7 +45,7 @@ def test_drop_table_query():
4545

4646
for query in queries:
4747
for comment in ["", "/* foo */", "\n--foo\n", "\n# foo\n"]:
48-
assert "DROP TABLE" == Parser(query.format(comment)).query_type
48+
assert QueryType.DROP == Parser(query.format(comment)).query_type
4949

5050

5151
def test_unsupported_query(caplog):
@@ -171,7 +171,7 @@ def test_unrecognized_command_type():
171171
def test_deeply_parenthesized_query():
172172
"""Triple-parenthesized SELECT parses correctly."""
173173
p = Parser("(((SELECT col FROM t)))")
174-
assert p.query_type == "SELECT"
174+
assert p.query_type == QueryType.SELECT
175175
assert p.tables == ["t"]
176176
assert p.columns == ["col"]
177177

0 commit comments

Comments
 (0)