diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index 5b5daf17575..e623cf707df 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -300,6 +300,12 @@ def _check_versioned_read_support(self, grouped_refs): supported_types.append(MilvusOnlineStore) except ImportError: pass + try: + from feast.infra.online_stores.snowflake import SnowflakeOnlineStore + + supported_types.append(SnowflakeOnlineStore) + except ImportError: + pass if isinstance(self, tuple(supported_types)): return diff --git a/sdk/python/feast/infra/online_stores/snowflake.py b/sdk/python/feast/infra/online_stores/snowflake.py index d2df674ed94..2a3b3734465 100644 --- a/sdk/python/feast/infra/online_stores/snowflake.py +++ b/sdk/python/feast/infra/online_stores/snowflake.py @@ -9,6 +9,7 @@ from feast.entity import Entity from feast.feature_view import FeatureView from feast.infra.key_encoding_utils import serialize_entity_key +from feast.infra.online_stores.helpers import compute_versioned_name from feast.infra.online_stores.online_store import OnlineStore from feast.infra.utils.snowflake.snowflake_utils import ( GetSnowflakeConnection, @@ -22,6 +23,12 @@ from feast.utils import to_naive_utc +def _snowflake_table_name( + project: str, table: FeatureView, enable_versioning: bool = False +) -> str: + return f"[online-transient] {project}_{compute_versioned_name(table, enable_versioning)}" + + class SnowflakeOnlineStoreConfig(FeastConfigBaseModel): """Online store config for Snowflake""" @@ -120,17 +127,19 @@ def online_write_batch( # This combines both the data upload plus the overwrite in the same transaction online_path = get_snowflake_online_store_path(config, table) + versioning = config.registry.enable_online_feature_view_versioning + tbl = _snowflake_table_name(config.project, table, versioning) with GetSnowflakeConnection(config.online_store, autocommit=False) as conn: write_pandas_binary( conn, agg_df, - table_name=f"[online-transient] {config.project}_{table.name}", + table_name=tbl, database=f"{config.online_store.database}", schema=f"{config.online_store.schema_}", ) # special function for writing binary to snowflake query = f""" - INSERT OVERWRITE INTO {online_path}."[online-transient] {config.project}_{table.name}" + INSERT OVERWRITE INTO {online_path}."{tbl}" SELECT "entity_feature_key", "entity_key", @@ -143,7 +152,7 @@ def online_write_batch( *, ROW_NUMBER() OVER(PARTITION BY "entity_key","feature_name" ORDER BY "event_ts" DESC, "created_ts" DESC) AS "_feast_row" FROM - {online_path}."[online-transient] {config.project}_{table.name}") + {online_path}."{tbl}") WHERE "_feast_row" = 1; """ @@ -191,12 +200,15 @@ def online_read( ) online_path = get_snowflake_online_store_path(config, table) + tbl = _snowflake_table_name( + config.project, table, config.registry.enable_online_feature_view_versioning + ) with GetSnowflakeConnection(config.online_store) as conn: query = f""" SELECT "entity_key", "feature_name", "value", "event_ts" FROM - {online_path}."[online-transient] {config.project}_{table.name}" + {online_path}."{tbl}" WHERE "entity_feature_key" IN ({entity_fetch_str}) """ @@ -228,11 +240,13 @@ def update( ): assert isinstance(config.online_store, SnowflakeOnlineStoreConfig) + versioning = config.registry.enable_online_feature_view_versioning with GetSnowflakeConnection(config.online_store) as conn: for table in tables_to_keep: online_path = get_snowflake_online_store_path(config, table) + tbl = _snowflake_table_name(config.project, table, versioning) query = f""" - CREATE TRANSIENT TABLE IF NOT EXISTS {online_path}."[online-transient] {config.project}_{table.name}" ( + CREATE TRANSIENT TABLE IF NOT EXISTS {online_path}."{tbl}" ( "entity_feature_key" BINARY, "entity_key" BINARY, "feature_name" VARCHAR, @@ -245,7 +259,8 @@ def update( for table in tables_to_delete: online_path = get_snowflake_online_store_path(config, table) - query = f'DROP TABLE IF EXISTS {online_path}."[online-transient] {config.project}_{table.name}"' + tbl = _snowflake_table_name(config.project, table, versioning) + query = f'DROP TABLE IF EXISTS {online_path}."{tbl}"' execute_snowflake_statement(conn, query) def teardown( @@ -256,8 +271,10 @@ def teardown( ): assert isinstance(config.online_store, SnowflakeOnlineStoreConfig) + versioning = config.registry.enable_online_feature_view_versioning with GetSnowflakeConnection(config.online_store) as conn: for table in tables: online_path = get_snowflake_online_store_path(config, table) - query = f'DROP TABLE IF EXISTS {online_path}."[online-transient] {config.project}_{table.name}"' + tbl = _snowflake_table_name(config.project, table, versioning) + query = f'DROP TABLE IF EXISTS {online_path}."{tbl}"' execute_snowflake_statement(conn, query) diff --git a/sdk/python/tests/unit/infra/online_store/test_snowflake_versioning.py b/sdk/python/tests/unit/infra/online_store/test_snowflake_versioning.py new file mode 100644 index 00000000000..052ba9fa2a6 --- /dev/null +++ b/sdk/python/tests/unit/infra/online_store/test_snowflake_versioning.py @@ -0,0 +1,172 @@ +"""Unit tests for Snowflake online store feature view versioning.""" + +import sys +from datetime import timedelta +from types import ModuleType +from unittest.mock import MagicMock + +from feast import Entity, FeatureView +from feast.field import Field +from feast.types import Float32 +from feast.value_type import ValueType + + +def _stub_snowflake_modules(): + """Stub out Snowflake connector and cryptography so the online store can be imported.""" + + # Build a proper package hierarchy so submodule imports don't fail. + def _mod(name): + m = ModuleType(name) + sys.modules[name] = m + return m + + if "cryptography" not in sys.modules: + crypto = _mod("cryptography") + hazmat = _mod("cryptography.hazmat") + backends = _mod("cryptography.hazmat.backends") + backends.default_backend = MagicMock() + primitives = _mod("cryptography.hazmat.primitives") + serialization = _mod("cryptography.hazmat.primitives.serialization") + serialization.Encoding = MagicMock() + serialization.PrivateFormat = MagicMock() + serialization.NoEncryption = MagicMock() + crypto.hazmat = hazmat + hazmat.backends = backends + hazmat.primitives = primitives + primitives.serialization = serialization + + if "snowflake" not in sys.modules: + sf = _mod("snowflake") + connector = _mod("snowflake.connector") + connector.ProgrammingError = Exception + connector.SnowflakeConnection = MagicMock() + cursor_mod = _mod("snowflake.connector.cursor") + cursor_mod.SnowflakeCursor = MagicMock() + sf.connector = connector + connector.cursor = cursor_mod + + if "tenacity" not in sys.modules: + tenacity = _mod("tenacity") + tenacity.retry = lambda *a, **kw: lambda f: f + tenacity.retry_if_exception_type = MagicMock() + tenacity.stop_after_attempt = MagicMock() + tenacity.wait_exponential = MagicMock() + + +_stub_snowflake_modules() + + +def _make_feature_view(name="driver_stats", version_number=None, version_tag=None): + entity = Entity( + name="driver_id", + join_keys=["driver_id"], + value_type=ValueType.INT64, + ) + fv = FeatureView( + name=name, + entities=[entity], + ttl=timedelta(days=1), + schema=[Field(name="trips_today", dtype=Float32)], + ) + if version_number is not None: + fv.current_version_number = version_number + if version_tag is not None: + fv.projection.version_tag = version_tag + return fv + + +class TestSnowflakeTableName: + """Test _snowflake_table_name with versioning enabled/disabled.""" + + def test_no_versioning(self): + from feast.infra.online_stores.snowflake import _snowflake_table_name + + fv = _make_feature_view() + assert ( + _snowflake_table_name("test_project", fv, False) + == "[online-transient] test_project_driver_stats" + ) + + def test_versioning_disabled_ignores_version(self): + from feast.infra.online_stores.snowflake import _snowflake_table_name + + fv = _make_feature_view(version_number=5) + assert ( + _snowflake_table_name("test_project", fv, False) + == "[online-transient] test_project_driver_stats" + ) + + def test_versioning_enabled_no_version_set(self): + from feast.infra.online_stores.snowflake import _snowflake_table_name + + fv = _make_feature_view() + assert ( + _snowflake_table_name("test_project", fv, True) + == "[online-transient] test_project_driver_stats" + ) + + def test_versioning_enabled_with_current_version_number(self): + from feast.infra.online_stores.snowflake import _snowflake_table_name + + fv = _make_feature_view(version_number=2) + assert ( + _snowflake_table_name("test_project", fv, True) + == "[online-transient] test_project_driver_stats_v2" + ) + + def test_version_zero_no_suffix(self): + from feast.infra.online_stores.snowflake import _snowflake_table_name + + fv = _make_feature_view(version_number=0) + assert ( + _snowflake_table_name("test_project", fv, True) + == "[online-transient] test_project_driver_stats" + ) + + def test_projection_version_tag_takes_priority(self): + from feast.infra.online_stores.snowflake import _snowflake_table_name + + fv = _make_feature_view(version_number=1, version_tag=3) + assert ( + _snowflake_table_name("test_project", fv, True) + == "[online-transient] test_project_driver_stats_v3" + ) + + def test_projection_version_tag_zero_no_suffix(self): + from feast.infra.online_stores.snowflake import _snowflake_table_name + + fv = _make_feature_view(version_tag=0, version_number=3) + assert ( + _snowflake_table_name("test_project", fv, True) + == "[online-transient] test_project_driver_stats" + ) + + def test_different_versions_produce_different_table_names(self): + from feast.infra.online_stores.snowflake import _snowflake_table_name + + fv_v1 = _make_feature_view(version_number=1) + fv_v2 = _make_feature_view(version_number=2) + name_v1 = _snowflake_table_name("prod", fv_v1, True) + name_v2 = _snowflake_table_name("prod", fv_v2, True) + assert name_v1 != name_v2 + assert name_v1 == "[online-transient] prod_driver_stats_v1" + assert name_v2 == "[online-transient] prod_driver_stats_v2" + + +class TestSnowflakeVersionedReadSupport: + """Test that SnowflakeOnlineStore passes _check_versioned_read_support.""" + + def test_allowed_with_version_tag(self): + from feast.infra.online_stores.snowflake import SnowflakeOnlineStore + + store = SnowflakeOnlineStore() + fv = _make_feature_view() + fv.projection.version_tag = 2 + store._check_versioned_read_support([(fv, ["trips_today"])]) + + def test_allowed_without_version_tag(self): + from feast.infra.online_stores.snowflake import SnowflakeOnlineStore + + store = SnowflakeOnlineStore() + fv = _make_feature_view() + store._check_versioned_read_support([(fv, ["trips_today"])])