diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index fba814fc3..fc243a458 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -14,10 +14,6 @@ Phase 1 gaps documented in the integration design: -- Parameter binding (``parameters=[TSparkParameter, ...]``) is not - yet supported — the PyO3 ``Statement`` doesn't expose - ``bind_param``. ``execute_command(parameters=[...])`` raises - ``NotSupportedError``. - ``query_tags`` on execute is not supported (kernel exposes ``statement_conf`` but PyO3 doesn't surface it). - ``get_tables`` with a non-empty ``table_types`` filter applies @@ -231,11 +227,6 @@ def execute_command( ) -> Union["ResultSet", None]: if self._kernel_session is None: raise InterfaceError("Cannot execute_command without an open session.") - if parameters: - raise NotSupportedError( - "Parameter binding is not yet supported on the kernel backend " - "(PyO3 Statement.bind_param lands in a follow-up PR)." - ) if query_tags: raise NotSupportedError( "Statement-level query_tags are not yet supported on the kernel backend." @@ -248,6 +239,15 @@ def execute_command( try: try: stmt.set_sql(operation) + if parameters: + # Lazy import — type_mapping touches pyarrow at + # module load; keep ``execute_command`` callable + # from contexts that don't yet need it. + from databricks.sql.backend.kernel.type_mapping import ( + bind_tspark_params, + ) + + bind_tspark_params(stmt, parameters) if async_op: async_exec = stmt.submit() command_id = CommandId.from_sea_statement_id( diff --git a/src/databricks/sql/backend/kernel/type_mapping.py b/src/databricks/sql/backend/kernel/type_mapping.py index bedcdcebd..f53662e4d 100644 --- a/src/databricks/sql/backend/kernel/type_mapping.py +++ b/src/databricks/sql/backend/kernel/type_mapping.py @@ -13,19 +13,20 @@ the kernel receives Arrow schemas directly), so the mapping function stays local but the names are shared. -Parameter binding (``TSparkParameter`` → kernel ``TypedValue``) is -not yet implemented — the PyO3 ``Statement`` doesn't expose a -``bind_param`` method on this branch. It'll land in a follow-up -once that PyO3 surface ships. +Parameter binding (``TSparkParameter`` → kernel +``Statement.bind_param``) is handled by ``bind_tspark_params`` — +forwards the connector's already-string-encoded form to the kernel +binding without an intermediate Python-typed round-trip. """ from __future__ import annotations -from typing import List, Tuple +from typing import Any, List, Tuple import pyarrow from databricks.sql.backend.sea.utils.conversion import SqlType +from databricks.sql.thrift_api.TCLIService import ttypes def _arrow_type_to_dbapi_string(arrow_type: pyarrow.DataType) -> str: @@ -92,3 +93,55 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]: ) for field in schema ] + + +def _tspark_param_value_str(param: ttypes.TSparkParameter) -> Any: + """Extract the string-encoded value from a ``TSparkParameter``, + or ``None`` for SQL NULL. + + Native parameters (``IntegerParameter`` etc.) always wrap their + value in ``TSparkParameterValue(stringValue=str(self.value))``; + ``VoidParameter`` sets ``stringValue="None"`` but the type is + ``"VOID"`` — the kernel-side parser ignores the value when the + type is VOID, so we don't have to special-case here. + """ + if param.value is None: + return None + return param.value.stringValue + + +def bind_tspark_params(kernel_stmt, parameters: List[ttypes.TSparkParameter]) -> None: + """Bind a list of ``TSparkParameter`` onto a kernel ``Statement``. + + The kernel expects positional bindings only (SEA v0 doesn't + accept named bindings on the wire). The connector's + ``TSparkParameter`` has an ``ordinal: bool`` flag; ``True`` means + "treat as positional in source-list order". Native bindings + almost always come through positional today; named-binding + parameters surface as ``NotSupportedError`` so the user gets a + clear message instead of a server-side rejection. + + Compound types (``ARRAY`` / ``MAP`` / ``STRUCT``) are routed + through the kernel parser which currently rejects them — same + user-visible message ("compound parameter types … are not yet + supported"). Tracked as a follow-up. + """ + for i, param in enumerate(parameters, start=1): + # The connector's `ordinal` field is a bool (True/False) on + # native params and indicates positional vs named. Named + # params can't flow through the kernel today; raise early + # rather than letting the server reject. + if getattr(param, "ordinal", None) is False and getattr(param, "name", None): + from databricks.sql.exc import NotSupportedError + + raise NotSupportedError( + f"Named parameter binding (got name={param.name!r}) is not yet " + "supported on the kernel backend; pass parameters positionally." + ) + + sql_type = param.type or "STRING" + value_str = _tspark_param_value_str(param) + # The kernel takes 1-based ordinals; `i` is already that. + # Errors from the kernel side (bad literal, unsupported type, + # etc.) come up as KernelError and bubble through normally. + kernel_stmt.bind_param(i, value_str, sql_type) diff --git a/tests/e2e/test_kernel_backend.py b/tests/e2e/test_kernel_backend.py index 67f6e858d..d2c0c9b9c 100644 --- a/tests/e2e/test_kernel_backend.py +++ b/tests/e2e/test_kernel_backend.py @@ -199,3 +199,69 @@ def test_bad_sql_surfaces_as_databaseerror(conn): # Structured fields copied off the kernel exception: assert getattr(err, "code", None) == "SqlError" assert getattr(err, "sql_state", None) == "42P01" + + +# ── Parameter binding ───────────────────────────────────────────── + + +def test_parameterized_query_round_trips(conn): + """Positional parameter binding via the kernel backend. The + connector's native parameter classes (IntegerParameter etc.) + serialize to TSparkParameter under the hood; the kernel + backend's mapper forwards them positionally to the kernel. + """ + from databricks.sql.parameters.native import ( + IntegerParameter, + StringParameter, + BooleanParameter, + ) + + with conn.cursor() as cur: + cur.execute( + "SELECT ? AS i, ? AS s, ? AS b", + [ + IntegerParameter(42), + StringParameter("alice"), + BooleanParameter(True), + ], + ) + rows = cur.fetchall() + assert len(rows) == 1 + assert rows[0][0] == 42 + assert rows[0][1] == "alice" + assert rows[0][2] is True + + +def test_parameterized_query_with_null(conn): + """`None` in the parameter list flows through as VoidParameter + → kernel TypedValue::Null.""" + with conn.cursor() as cur: + cur.execute("SELECT ? IS NULL AS is_null", [None]) + rows = cur.fetchall() + assert rows[0][0] is True + + +def test_parameterized_query_decimal(conn): + """DECIMAL parameters carry precision/scale in the SQL type + string ('DECIMAL(p,s)') — the kernel parser extracts them so + fractional digits survive the wire. + + Uses the connector's auto-inference path + (`calculate_decimal_cast_string`) to derive precision/scale + from the value; the explicit-arg path + (`DecimalParameter(v, scale=, precision=)`) has a pre-existing + bug in this branch where the format-args are passed + `(scale, precision)` instead of `(precision, scale)` — out of + scope for this PR. + """ + import decimal + from databricks.sql.parameters.native import DecimalParameter + + with conn.cursor() as cur: + cur.execute( + "SELECT ? AS d", + [DecimalParameter(decimal.Decimal("-123.45"))], + ) + rows = cur.fetchall() + # Server echoes back as decimal.Decimal. + assert str(rows[0][0]) == "-123.45" diff --git a/tests/unit/test_kernel_client.py b/tests/unit/test_kernel_client.py index f43d8c7c7..a9e9c9090 100644 --- a/tests/unit/test_kernel_client.py +++ b/tests/unit/test_kernel_client.py @@ -234,25 +234,57 @@ def test_open_session_rejects_double_open(monkeypatch): c.open_session(session_configuration=None, catalog=None, schema=None) -def test_execute_command_rejects_parameters(): +def test_execute_command_forwards_parameters_to_bind_param(): + """``execute_command(parameters=[...])`` routes each parameter + through ``bind_tspark_params`` onto the kernel statement before + ``execute()`` is called. Replaces the prior ``NotSupportedError`` + rejection now that the kernel-side ``Statement.bind_param`` is + live (kernel PR #18).""" + from databricks.sql.thrift_api.TCLIService import ttypes + c = _make_client() c._kernel_session = MagicMock() cursor = MagicMock() cursor.arraysize = 100 cursor.buffer_size_bytes = 1024 - with pytest.raises(NotSupportedError, match="Parameter binding"): - c.execute_command( - operation="SELECT ?", - session_id=MagicMock(), - max_rows=1, - max_bytes=1, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[object()], # any non-empty list - async_op=False, - enforce_embedded_schema_correctness=False, - ) + + # Stub the statement chain so we can observe bind_param calls + # without exercising the full ExecutedStatement → arrow_schema() + # path (that's covered elsewhere). + stmt = MagicMock() + stmt.bind_param = MagicMock() + stmt.execute.return_value = MagicMock( + statement_id="stmt-id", + arrow_schema=MagicMock(return_value=pa.schema([("x", pa.int64())])), + ) + c._kernel_session.statement.return_value = stmt + + p1 = ttypes.TSparkParameter(ordinal=True, name=None, type="INT") + p1.value = ttypes.TSparkParameterValue(stringValue="42") + p2 = ttypes.TSparkParameter(ordinal=True, name=None, type="STRING") + p2.value = ttypes.TSparkParameterValue(stringValue="hello") + + c.execute_command( + operation="SELECT ?, ?", + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[p1, p2], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # bind_param was called once per TSparkParameter, in order, with + # 1-based ordinals. + assert stmt.bind_param.call_args_list == [ + ((1, "42", "INT"), {}), + ((2, "hello", "STRING"), {}), + ] + # …and execute fired after binding. + assert stmt.execute.called def test_execute_command_rejects_query_tags(): diff --git a/tests/unit/test_kernel_type_mapping.py b/tests/unit/test_kernel_type_mapping.py index 82f62559a..30cca9425 100644 --- a/tests/unit/test_kernel_type_mapping.py +++ b/tests/unit/test_kernel_type_mapping.py @@ -84,3 +84,102 @@ def test_description_from_schema_reports_non_nullable_fields(): desc = description_from_arrow_schema(schema) assert desc[0][6] is False assert desc[1][6] is True + + +# ─── bind_tspark_params ────────────────────────────────────────────────── + + +def _mk_param(*, type, value, ordinal=True, name=None): + """Build a minimal TSparkParameter for tests.""" + from databricks.sql.thrift_api.TCLIService import ttypes + + p = ttypes.TSparkParameter(ordinal=ordinal, name=name, type=type) + p.value = ttypes.TSparkParameterValue(stringValue=value) if value is not None else None + return p + + +class _RecordingStmt: + """Stand-in for the kernel `Statement` pyclass — records every + `bind_param` call so tests can assert the (ordinal, value, type) + triples the mapper forwarded.""" + + def __init__(self): + self.calls = [] + + def bind_param(self, ordinal, value_str, sql_type): + self.calls.append((ordinal, value_str, sql_type)) + + +def test_bind_tspark_params_forwards_each_param_positionally(): + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + + params = [ + _mk_param(type="INT", value="42"), + _mk_param(type="STRING", value="alice"), + _mk_param(type="DATE", value="2026-05-15"), + ] + stmt = _RecordingStmt() + bind_tspark_params(stmt, params) + assert stmt.calls == [ + (1, "42", "INT"), + (2, "alice", "STRING"), + (3, "2026-05-15", "DATE"), + ] + + +def test_bind_tspark_params_null_value(): + """TSparkParameter with value=None → kernel sees value_str=None, + interpreted as SQL NULL regardless of the SQL type.""" + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + + p = _mk_param(type="STRING", value=None) + stmt = _RecordingStmt() + bind_tspark_params(stmt, [p]) + assert stmt.calls == [(1, None, "STRING")] + + +def test_bind_tspark_params_void_passes_through(): + """VoidParameter sets type='VOID' with stringValue='None'; the + kernel parser ignores the value when type=VOID.""" + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + + p = _mk_param(type="VOID", value="None") + stmt = _RecordingStmt() + bind_tspark_params(stmt, [p]) + assert stmt.calls == [(1, "None", "VOID")] + + +def test_bind_tspark_params_named_param_rejected(): + """The kernel doesn't accept named bindings on the SEA wire; + surface that at the connector layer so the user gets a pointed + error instead of a server-side rejection.""" + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + from databricks.sql.exc import NotSupportedError + + p = _mk_param(type="INT", value="42", ordinal=False, name="my_param") + stmt = _RecordingStmt() + with pytest.raises(NotSupportedError, match="(?i)named"): + bind_tspark_params(stmt, [p]) + # Nothing should have been forwarded before the rejection. + assert stmt.calls == [] + + +def test_bind_tspark_params_missing_type_defaults_to_string(): + """Defensive: a TSparkParameter with no `type` shouldn't crash + the mapper — fall back to STRING and let the kernel parse.""" + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + from databricks.sql.thrift_api.TCLIService import ttypes + + p = ttypes.TSparkParameter(ordinal=True, name=None, type=None) + p.value = ttypes.TSparkParameterValue(stringValue="hello") + stmt = _RecordingStmt() + bind_tspark_params(stmt, [p]) + assert stmt.calls == [(1, "hello", "STRING")] + + +def test_bind_tspark_params_empty_list_is_noop(): + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + + stmt = _RecordingStmt() + bind_tspark_params(stmt, []) + assert stmt.calls == []