diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index f10293311..9b2f8e8db 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -98,6 +98,17 @@ def __init__( self._auth_provider = auth_provider self._catalog = catalog self._schema = schema + # ``_use_arrow_native_complex_types`` is the connector-side + # toggle for whether complex columns (ARRAY / MAP / STRUCT) + # are surfaced as native Arrow shapes or as compact JSON + # strings. The Thrift backend forwards it server-side + # (``complexTypesAsArrow``); the kernel doesn't have a wire + # equivalent, so we flip the kernel's client-side + # ``complex_types_as_json`` post-processor to match. Default + # ``True`` mirrors the connector's existing default. + self._use_arrow_native_complex_types = kwargs.get( + "_use_arrow_native_complex_types", True + ) # NB: don't call ``kernel_auth_kwargs`` here. That call # materialises the bearer token in-process; keeping a # cleartext copy on a long-lived connector object that may @@ -155,6 +166,7 @@ def open_session( catalog=catalog or self._catalog, schema=schema or self._schema, session_conf=session_conf, + complex_types_as_json=not self._use_arrow_native_complex_types, **auth_kwargs, ) except Exception as exc: diff --git a/src/databricks/sql/backend/kernel/type_mapping.py b/src/databricks/sql/backend/kernel/type_mapping.py index cc6ccc7b9..fc1a338cd 100644 --- a/src/databricks/sql/backend/kernel/type_mapping.py +++ b/src/databricks/sql/backend/kernel/type_mapping.py @@ -123,15 +123,35 @@ def _databricks_type_for_field(field: pyarrow.Field) -> str: Consults the field's Arrow metadata under ``databricks.type_name`` (written by the kernel from the SEA response's column type) so types that collapse onto a generic - Arrow shape can still be distinguished. Today only ``VARIANT`` - is mapped; everything else delegates to - ``_arrow_type_to_dbapi_string``. 
+ Arrow shape can still be distinguished. This matters in two + cases: + + - ``VARIANT`` (always ``Utf8`` on the wire — no Arrow shape + distinguishes it from ``STRING``). + - The ``complex_types_as_json`` post-processor rewrites + ``ARRAY`` / ``MAP`` / ``STRUCT`` columns to ``Utf8`` carrying + compact JSON text. The Thrift backend reports the original + SQL type in ``description`` even when ``complexTypesAsArrow`` + is off and the wire payload is a JSON string; we match that + by recovering the type name from manifest metadata. """ md = field.metadata or {} - # `databricks.type_name` is bytes (Arrow metadata is always - # bytes); compare against bytes to avoid one encode per field. - if md.get(b"databricks.type_name") == b"VARIANT": - return "variant" + # `databricks.type_name` is bytes (Arrow metadata is always + # bytes); look it up by its bytes key and decode it once below. + type_name = md.get(b"databricks.type_name") + if type_name is not None: + # Lowercase to match the canonical SqlType slugs the Thrift + # backend produces (``"array"`` / ``"map"`` / ``"struct"`` / + # ``"variant"``). Other server-reported names (``"INT"`` etc.) + # also reach this branch, but we deliberately ignore them and + # fall through to the Arrow mapper — the Arrow shape is the + # authoritative source for primitives, and the kernel's own + # type-name mapping (`map_databricks_type`) is conservative on + # some types (e.g. ``DECIMAL`` arrives as ``decimal`` on the + # Arrow side, which matches Thrift). 
+ decoded = type_name.decode("ascii", errors="replace").lower() + if decoded in {"variant", "array", "map", "struct"}: + return decoded return _arrow_type_to_dbapi_string(field.type) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 97790e4d9..2910576f8 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -146,6 +146,7 @@ def _create_backend( http_client=self.http_client, catalog=kwargs.get("catalog"), schema=kwargs.get("schema"), + _use_arrow_native_complex_types=_use_arrow_native_complex_types, ) databricks_client_class: Type[DatabricksClient] diff --git a/tests/unit/test_kernel_client.py b/tests/unit/test_kernel_client.py index 0e2948284..541bf73a4 100644 --- a/tests/unit/test_kernel_client.py +++ b/tests/unit/test_kernel_client.py @@ -241,6 +241,44 @@ def test_open_session_rejects_double_open(monkeypatch): c.open_session(session_configuration=None, catalog=None, schema=None) +@pytest.mark.parametrize( + "kwargs, expected_flag", + [ + ({}, False), # default → arrow-native → kernel JSON off + ({"_use_arrow_native_complex_types": True}, False), + ({"_use_arrow_native_complex_types": False}, True), + ], +) +def test_open_session_passes_complex_types_as_json_to_kernel( + monkeypatch, kwargs, expected_flag +): + """``_use_arrow_native_complex_types=False`` flips the kernel's + ``complex_types_as_json`` post-processor on; the default and + explicit ``True`` both leave it off. 
The flag is inverted at the + boundary because the connector's option is "native Arrow"-shaped + and the kernel's is "rewrite to JSON strings"-shaped.""" + captured = {} + + def fake_session(**kw): + captured.update(kw) + sess = MagicMock() + sess.session_id = "sess-id" + return sess + + monkeypatch.setattr(kernel_client._kernel, "Session", fake_session) + + c = kernel_client.KernelDatabricksClient( + server_hostname="example.cloud.databricks.com", + http_path="/sql/1.0/warehouses/abc", + auth_provider=AccessTokenAuthProvider("dapi-test"), + ssl_options=None, + **kwargs, + ) + c.open_session(session_configuration=None, catalog=None, schema=None) + + assert captured.get("complex_types_as_json") is expected_flag + + def test_execute_command_forwards_parameters_to_bind_param(): """``execute_command(parameters=[...])`` routes each parameter through ``bind_tspark_params`` onto the kernel statement before diff --git a/tests/unit/test_kernel_type_mapping.py b/tests/unit/test_kernel_type_mapping.py index 7d1cce546..c859ffca1 100644 --- a/tests/unit/test_kernel_type_mapping.py +++ b/tests/unit/test_kernel_type_mapping.py @@ -114,6 +114,58 @@ def test_description_uses_databricks_type_name_for_variant(): assert desc[1][1] == "string" +@pytest.mark.parametrize( + "metadata_value, expected", + [ + (b"ARRAY", "array"), + (b"MAP", "map"), + (b"STRUCT", "struct"), + # Lowercase / mixed case both fine — server may report either. + (b"array", "array"), + (b"Struct", "struct"), + ], +) +def test_description_recovers_complex_type_name_from_metadata(metadata_value, expected): + """When ``complex_types_as_json`` rewrites a complex column to + ``Utf8``, the kernel preserves the original SQL type name under + ``databricks.type_name``. ``description`` must report that name + (matching the Thrift backend's behaviour with + ``complexTypesAsArrow=False``), not the post-processed ``string``. 
+ """ + schema = pa.schema( + [ + pa.field( + "c", + pa.string(), + metadata={b"databricks.type_name": metadata_value}, + ), + ] + ) + desc = description_from_arrow_schema(schema) + assert desc[0][1] == expected + + +def test_description_passes_through_unknown_databricks_type_name(): + """Server-reported names other than the handful we explicitly + recognise (VARIANT / ARRAY / MAP / STRUCT) defer to the Arrow + shape — the Arrow type is the authoritative source for primitives + and the kernel's own type mapping is conservative there. Confirms + we don't accidentally claim ``int`` from metadata when the Arrow + column is something concrete like ``int64``.""" + schema = pa.schema( + [ + pa.field( + "n", + pa.int64(), + metadata={b"databricks.type_name": b"INT"}, + ), + ] + ) + desc = description_from_arrow_schema(schema) + # `int64` Arrow → "bigint" via the existing arrow-type mapper. + assert desc[0][1] == "bigint" + + # ─── bind_tspark_params ──────────────────────────────────────────────────