Skip to content

Commit 203dd5d

Browse files
committed
framework/v37.4.2
1 parent 04c0e1d commit 203dd5d

4 files changed

Lines changed: 56 additions & 34 deletions

File tree

framework/src/framework/common/storage/field/field_data.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from __future__ import annotations
1717

18-
from beartype.typing import Generic, TypeVar, get_args
18+
from beartype.typing import Any, Generic, TypeVar, get_args
1919

2020
from superlinked.framework.common.data_types import Vector
2121
from superlinked.framework.common.exception import InvalidInputException
@@ -36,22 +36,42 @@ def __init__(self, type_: FieldDataType, name: str, value: FT) -> None:
3636
self.__validate_value(type_, value)
3737
self.value = value
3838

39+
@classmethod
40+
def from_field(cls, field: Field, value: FT) -> FieldData:
41+
return cls(field.data_type, field.name, value)
42+
3943
def __validate_value(self, data_type: FieldDataType, value: FT) -> None:
4044
valid_types = FieldTypeConverter.get_valid_node_data_types(data_type)
41-
error_msg = "Invalid value {value} for the given field data type {data_type}"
4245
if not isinstance(value, tuple(valid_types)):
43-
raise InvalidInputException(error_msg.format(value=value, data_type=data_type))
44-
if isinstance(value, list):
45-
# Assuming list types have only 1 valid type
46-
valid_type = VALID_TYPE_BY_FIELD_DATA_TYPE[data_type][0]
47-
generic_type = get_args(valid_type)[0]
48-
# TODO FAB-3719 FieldData validation is slow for large float lists
49-
if data_type != FieldDataType.FLOAT_LIST and not all(isinstance(item, generic_type) for item in value):
50-
raise InvalidInputException(error_msg.format(value=value, data_type=data_type))
46+
self.__raise_validation_exception(data_type, value)
47+
self.__validate_list_type(data_type, value)
5148

52-
@classmethod
53-
def from_field(cls, field: Field, value: FT) -> FieldData:
54-
return cls(field.data_type, field.name, value)
49+
def __validate_list_type(self, data_type: FieldDataType, value: FT) -> None:
50+
if not isinstance(value, list):
51+
return
52+
if data_type == FieldDataType.FLOAT_LIST:
53+
if self.__is_valid_float_list(value):
54+
return
55+
self.__raise_validation_exception(data_type, value)
56+
valid_type = VALID_TYPE_BY_FIELD_DATA_TYPE[data_type][0]
57+
generic_type = get_args(valid_type)[0]
58+
if not all(isinstance(item, generic_type) for item in value):
59+
self.__raise_validation_exception(data_type, value)
60+
61+
def __is_valid_float_list(self, value: list[Any]) -> bool:
62+
"""
63+
This is a performance hot-spot. It needs to be fast as we use it
64+
for custom space input validation where we can have long float lists
65+
"""
66+
try:
67+
for item in value:
68+
float(item)
69+
return True
70+
except (ValueError, TypeError):
71+
return False
72+
73+
def __raise_validation_exception(self, data_type: FieldDataType, value: Any) -> None:
74+
raise InvalidInputException(f"Invalid value {value} for the given field data type {data_type}")
5575

5676

5777
class VectorFieldData(FieldData[Vector]):

framework/src/framework/dsl/storage/qdrant_vector_database.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515

16+
from collections.abc import Mapping
17+
1618
from beartype.typing import Any
1719

1820
from superlinked.framework.common.precision import Precision
@@ -40,11 +42,10 @@ def __init__( # pylint: disable=too-many-arguments
4042
url: str,
4143
api_key: str,
4244
default_query_limit: int = 10,
43-
timeout: int | None = None,
4445
search_algorithm: SearchAlgorithm = SearchAlgorithm.FLAT,
4546
vector_precision: Precision = Precision.FLOAT16,
46-
prefer_grpc: bool | None = None,
47-
**extra_params: Any
47+
client_params: Mapping[str, Any] | None = None,
48+
**extra_params: Any,
4849
) -> None:
4950
"""
5051
Initialize the QdrantVectorDatabase.
@@ -53,18 +54,20 @@ def __init__( # pylint: disable=too-many-arguments
5354
url (str): The url of the Qdrant server.
5455
api_key (str): The api key of the Qdrant cluster.
5556
default_query_limit (int): Default vector search limit, set to Qdrant's default of 10.
56-
timeout (int | None): Timeout in seconds for Qdrant operations. Default is 5 seconds.
5757
vector_precision (Precision): Precision to use for storing vectors. Defaults to FLOAT16.
58-
prefer_grpc (bool | None): Whether to prefer gRPC for Qdrant operations. Default is False.
58+
client_params (Mapping[str, Any] | None): Additional parameters for the QdrantClient.
59+
These are passed directly to the QdrantClient constructor, so any valid QdrantClient
60+
parameter can be used. For example `{"prefer_grpc": True}`. Defaults to None.
5961
**extra_params (Any): Additional parameters for the Qdrant connection.
6062
"""
6163
super().__init__()
62-
self._connection_params = QdrantConnectionParams(url, api_key, timeout, prefer_grpc, **extra_params)
64+
self._connection_params = QdrantConnectionParams(url, api_key, **extra_params)
6365
self._settings = VDBSettings(
6466
default_query_limit=default_query_limit,
6567
search_algorithm=search_algorithm,
6668
vector_precision=vector_precision,
6769
)
70+
self._client_params = client_params or {}
6871

6972
@property
7073
def _vdb_connector(self) -> QdrantVDBConnector:
@@ -74,4 +77,4 @@ def _vdb_connector(self) -> QdrantVDBConnector:
7477
Returns:
7578
QdrantVDBConnector: The Qdrant vector database connector instance.
7679
"""
77-
return QdrantVDBConnector(self._connection_params, self._settings)
80+
return QdrantVDBConnector(self._connection_params, self._settings, self._client_params)

framework/src/framework/storage/qdrant/qdrant_connection_params.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,16 @@
1919

2020
class QdrantConnectionParams(ConnectionParams):
2121
def __init__(
22-
self, url: str, api_key: str, timeout: int | None = None, prefer_grpc: bool | None = None, **extra_params: Any
22+
self,
23+
url: str,
24+
api_key: str,
25+
**extra_params: Any,
2326
) -> None:
2427
super().__init__()
2528
extra_params_str = self.get_uri_params_string(**extra_params)
2629
self._connection_string = f"{url}{extra_params_str}"
2730
self._api_key = api_key
28-
self._timeout = timeout
29-
self._prefer_grpc = prefer_grpc
3031

3132
@property
3233
def connection_string(self) -> str:
3334
return self._connection_string
34-
35-
@property
36-
def timeout(self) -> int | None:
37-
return self._timeout
38-
39-
@property
40-
def prefer_grpc(self) -> bool:
41-
return bool(self._prefer_grpc)

framework/src/framework/storage/qdrant/qdrant_vdb_connector.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import re
15+
from collections.abc import Mapping
1516

1617
from beartype.typing import Any, Sequence, cast
1718
from qdrant_client import QdrantClient
@@ -76,13 +77,18 @@
7677

7778

7879
class QdrantVDBConnector(VDBConnector[VDBKNNSearchConfig]):
79-
def __init__(self, connection_params: QdrantConnectionParams, vdb_settings: VDBSettings) -> None:
80+
def __init__(
81+
self,
82+
connection_params: QdrantConnectionParams,
83+
vdb_settings: VDBSettings,
84+
client_params: Mapping[str, Any] | None = None,
85+
) -> None:
8086
super().__init__(vdb_settings=vdb_settings)
87+
8188
self._client = QdrantClient(
8289
url=connection_params.connection_string,
8390
api_key=connection_params._api_key,
84-
timeout=connection_params.timeout,
85-
prefer_grpc=connection_params.prefer_grpc,
91+
**client_params or {},
8692
)
8793
self._encoder = QdrantFieldEncoder(self.vector_precision)
8894
self.__search_index_manager = QdrantSearchIndexManager(self._client)

0 commit comments

Comments
 (0)