Skip to content

Commit c4e1872

Browse files
committed
Allow different 'omit trailing semicolon' behavior pro connection
1 parent 7b8a3bc commit c4e1872

2 files changed

Lines changed: 35 additions & 18 deletions

File tree

src/DatabaseLibrary/connection_manager.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
class Connection:
2929
client: Any
3030
module_name: str
31+
omit_trailing_semicolon: bool
3132

3233

3334
class ConnectionStore:
@@ -36,13 +37,13 @@ def __init__(self, warn_on_overwrite=True):
3637
self.default_alias: str = "default"
3738
self.warn_on_overwrite = warn_on_overwrite
3839

39-
def register_connection(self, client: Any, module_name: str, alias: str):
40+
def register_connection(self, client: Any, module_name: str, alias: str, omit_trailing_semicolon=False):
4041
if alias in self._connections and self.warn_on_overwrite:
4142
if alias == self.default_alias:
4243
logger.warn("Overwriting not closed connection.")
4344
else:
4445
logger.warn(f"Overwriting not closed connection for alias = '{alias}'")
45-
self._connections[alias] = Connection(client, module_name)
46+
self._connections[alias] = Connection(client, module_name, omit_trailing_semicolon)
4647

4748
def get_connection(self, alias: Optional[str]) -> Connection:
4849
"""
@@ -142,7 +143,6 @@ class ConnectionManager:
142143
"""
143144

144145
def __init__(self, warn_on_connection_overwrite=True):
145-
self.omit_trailing_semicolon: bool = False
146146
self.connection_store: ConnectionStore = ConnectionStore(warn_on_overwrite=warn_on_connection_overwrite)
147147
self.ibmdb_driver_already_added_to_path: bool = False
148148

@@ -320,6 +320,8 @@ def _arg_or_config(arg_value, param_name, *, old_param_name=None, mandatory=Fals
320320
if other_config_file_params:
321321
logger.info(f"Other params from configuration file: {list(other_config_file_params.keys())}")
322322

323+
omit_trailing_semicolon = False
324+
323325
if db_module == "excel" or db_module == "excelrw":
324326
db_api_module_name = "pyodbc"
325327
else:
@@ -444,7 +446,7 @@ def _arg_or_config(arg_value, param_name, *, old_param_name=None, mandatory=Fals
444446
con_params = _build_connection_params(user=db_user, password=db_password, dsn=oracle_dsn)
445447
_log_all_connection_params(**con_params)
446448
db_connection = db_api_2.connect(**con_params)
447-
self.omit_trailing_semicolon = True
449+
omit_trailing_semicolon = True
448450

449451
elif db_module in ["oracledb"]:
450452
db_port = db_port or 1521
@@ -473,7 +475,7 @@ def _arg_or_config(arg_value, param_name, *, old_param_name=None, mandatory=Fals
473475
f"Expected oracledb to run in thin mode: {oracle_thin_mode}, "
474476
f"but the connection has thin mode: {db_connection.thin}"
475477
)
476-
self.omit_trailing_semicolon = True
478+
omit_trailing_semicolon = True
477479

478480
elif db_module in ["teradata"]:
479481
db_port = db_port or 1025
@@ -505,7 +507,7 @@ def _arg_or_config(arg_value, param_name, *, old_param_name=None, mandatory=Fals
505507
_log_all_connection_params(**con_params)
506508
db_connection = db_api_2.connect(**con_params)
507509

508-
self.connection_store.register_connection(db_connection, db_api_module_name, alias)
510+
self.connection_store.register_connection(db_connection, db_api_module_name, alias, omit_trailing_semicolon)
509511

510512
@renamed_args(mapping={"dbapiModuleName": "db_module"})
511513
def connect_to_database_using_custom_params(

src/DatabaseLibrary/query.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,12 @@ def query(
8181
cur = None
8282
try:
8383
cur = db_connection.client.cursor()
84-
self._execute_sql(cur, select_statement, parameters=parameters)
84+
self._execute_sql(
85+
cur,
86+
select_statement,
87+
parameters=parameters,
88+
omit_trailing_semicolon=db_connection.omit_trailing_semicolon,
89+
)
8590
all_rows = cur.fetchall()
8691
col_names = [c[0] for c in cur.description]
8792
self._log_query_results(col_names, all_rows)
@@ -130,7 +135,12 @@ def row_count(
130135
cur = None
131136
try:
132137
cur = db_connection.client.cursor()
133-
self._execute_sql(cur, select_statement, parameters=parameters)
138+
self._execute_sql(
139+
cur,
140+
select_statement,
141+
parameters=parameters,
142+
omit_trailing_semicolon=db_connection.omit_trailing_semicolon,
143+
)
134144
data = cur.fetchall()
135145
col_names = [c[0] for c in cur.description]
136146
if db_connection.module_name in ["sqlite3", "ibm_db", "ibm_db_dbi", "pyodbc", "jaydebeapi"]:
@@ -182,7 +192,12 @@ def description(
182192
cur = None
183193
try:
184194
cur = db_connection.client.cursor()
185-
self._execute_sql(cur, select_statement, parameters=parameters)
195+
self._execute_sql(
196+
cur,
197+
select_statement,
198+
parameters=parameters,
199+
omit_trailing_semicolon=db_connection.omit_trailing_semicolon,
200+
)
186201
description = list(cur.description)
187202
if sys.version_info[0] < 3:
188203
for row in range(0, len(description)):
@@ -276,7 +291,11 @@ def execute_sql_script(
276291
cur = db_connection.client.cursor()
277292
if not split:
278293
logger.info("Statements splitting disabled - pass entire script content to the database module")
279-
self._execute_sql(cur, sql_file.read())
294+
self._execute_sql(
295+
cur,
296+
sql_file.read(),
297+
omit_trailing_semicolon=db_connection.omit_trailing_semicolon,
298+
)
280299
else:
281300
logger.info("Splitting script file into statements...")
282301
statements_to_execute = []
@@ -377,11 +396,7 @@ def execute_sql_string(
377396
Use ``parameters`` for query variable substitution (variable substitution syntax may be different
378397
depending on the database client).
379398
380-
Use ``omit_trailing_semicolon`` for explicit instruction,
381-
if the trailing semicolon (;) at the SQL string end should be removed or not:
382-
- Some database modules (e.g. Oracle) throw an exception, if you leave a semicolon at the string end
383-
- However, there are exceptional cases, when you need it even for Oracle - e.g. at the end of a PL/SQL block
384-
- If not explicitly specified, it's decided based on the current database module in use. For Oracle, the semicolon is removed by default.
399+
Set the ``omit_trailing_semicolon`` to explicitly control the `Omitting trailing semicolon behavior` for the command.
385400
386401
=== Some parameters were renamed in version 2.0 ===
387402
The old parameters ``sqlString``, ``sansTran`` and ``omitTrailingSemicolon`` are *deprecated*,
@@ -401,6 +416,8 @@ def execute_sql_string(
401416
cur = None
402417
try:
403418
cur = db_connection.client.cursor()
419+
if omit_trailing_semicolon is None:
420+
omit_trailing_semicolon = db_connection.omit_trailing_semicolon
404421
self._execute_sql(cur, sql_string, omit_trailing_semicolon=omit_trailing_semicolon, parameters=parameters)
405422
self._commit_if_needed(db_connection, no_transaction)
406423
except Exception as e:
@@ -734,7 +751,7 @@ def _execute_sql(
734751
self,
735752
cur,
736753
sql_statement: str,
737-
omit_trailing_semicolon: Optional[bool] = None,
754+
omit_trailing_semicolon: Optional[bool] = False,
738755
parameters: Optional[Tuple] = None,
739756
):
740757
"""
@@ -744,8 +761,6 @@ def _execute_sql(
744761
won't be executed by some databases (e.g. Oracle).
745762
Otherwise, it's decided based on the current database module in use.
746763
"""
747-
if omit_trailing_semicolon is None:
748-
omit_trailing_semicolon = self.omit_trailing_semicolon
749764
if omit_trailing_semicolon:
750765
sql_statement = sql_statement.rstrip(";")
751766
if parameters is None:

0 commit comments

Comments
 (0)