Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,12 @@ def _check_versioned_read_support(self, grouped_refs):
supported_types.append(MilvusOnlineStore)
except ImportError:
pass
try:
from feast.infra.online_stores.snowflake import SnowflakeOnlineStore

supported_types.append(SnowflakeOnlineStore)
except ImportError:
pass

if isinstance(self, tuple(supported_types)):
return
Expand Down
31 changes: 24 additions & 7 deletions sdk/python/feast/infra/online_stores/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.online_stores.helpers import compute_versioned_name
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.utils.snowflake.snowflake_utils import (
GetSnowflakeConnection,
Expand All @@ -22,6 +23,12 @@
from feast.utils import to_naive_utc


def _snowflake_table_name(
project: str, table: FeatureView, enable_versioning: bool = False
) -> str:
return f"[online-transient] {project}_{compute_versioned_name(table, enable_versioning)}"


class SnowflakeOnlineStoreConfig(FeastConfigBaseModel):
"""Online store config for Snowflake"""

Expand Down Expand Up @@ -120,17 +127,19 @@ def online_write_batch(

# This combines both the data upload plus the overwrite in the same transaction
online_path = get_snowflake_online_store_path(config, table)
versioning = config.registry.enable_online_feature_view_versioning
tbl = _snowflake_table_name(config.project, table, versioning)
with GetSnowflakeConnection(config.online_store, autocommit=False) as conn:
write_pandas_binary(
conn,
agg_df,
table_name=f"[online-transient] {config.project}_{table.name}",
table_name=tbl,
database=f"{config.online_store.database}",
schema=f"{config.online_store.schema_}",
) # special function for writing binary to snowflake

query = f"""
INSERT OVERWRITE INTO {online_path}."[online-transient] {config.project}_{table.name}"
INSERT OVERWRITE INTO {online_path}."{tbl}"
SELECT
"entity_feature_key",
"entity_key",
Expand All @@ -143,7 +152,7 @@ def online_write_batch(
*,
ROW_NUMBER() OVER(PARTITION BY "entity_key","feature_name" ORDER BY "event_ts" DESC, "created_ts" DESC) AS "_feast_row"
FROM
{online_path}."[online-transient] {config.project}_{table.name}")
{online_path}."{tbl}")
WHERE
"_feast_row" = 1;
"""
Expand Down Expand Up @@ -191,12 +200,15 @@ def online_read(
)

online_path = get_snowflake_online_store_path(config, table)
tbl = _snowflake_table_name(
config.project, table, config.registry.enable_online_feature_view_versioning
)
with GetSnowflakeConnection(config.online_store) as conn:
query = f"""
SELECT
"entity_key", "feature_name", "value", "event_ts"
FROM
{online_path}."[online-transient] {config.project}_{table.name}"
{online_path}."{tbl}"
WHERE
"entity_feature_key" IN ({entity_fetch_str})
"""
Expand Down Expand Up @@ -228,11 +240,13 @@ def update(
):
assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

versioning = config.registry.enable_online_feature_view_versioning
with GetSnowflakeConnection(config.online_store) as conn:
for table in tables_to_keep:
online_path = get_snowflake_online_store_path(config, table)
tbl = _snowflake_table_name(config.project, table, versioning)
query = f"""
CREATE TRANSIENT TABLE IF NOT EXISTS {online_path}."[online-transient] {config.project}_{table.name}" (
CREATE TRANSIENT TABLE IF NOT EXISTS {online_path}."{tbl}" (
"entity_feature_key" BINARY,
"entity_key" BINARY,
"feature_name" VARCHAR,
Expand All @@ -245,7 +259,8 @@ def update(

for table in tables_to_delete:
online_path = get_snowflake_online_store_path(config, table)
query = f'DROP TABLE IF EXISTS {online_path}."[online-transient] {config.project}_{table.name}"'
tbl = _snowflake_table_name(config.project, table, versioning)
query = f'DROP TABLE IF EXISTS {online_path}."{tbl}"'
execute_snowflake_statement(conn, query)

def teardown(
Expand All @@ -256,8 +271,10 @@ def teardown(
):
assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

versioning = config.registry.enable_online_feature_view_versioning
with GetSnowflakeConnection(config.online_store) as conn:
for table in tables:
online_path = get_snowflake_online_store_path(config, table)
query = f'DROP TABLE IF EXISTS {online_path}."[online-transient] {config.project}_{table.name}"'
tbl = _snowflake_table_name(config.project, table, versioning)
query = f'DROP TABLE IF EXISTS {online_path}."{tbl}"'
execute_snowflake_statement(conn, query)
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""Unit tests for Snowflake online store feature view versioning."""

import sys
from datetime import timedelta
from types import ModuleType
from unittest.mock import MagicMock

from feast import Entity, FeatureView
from feast.field import Field
from feast.types import Float32
from feast.value_type import ValueType


def _stub_snowflake_modules():
"""Stub out Snowflake connector and cryptography so the online store can be imported."""

# Build a proper package hierarchy so submodule imports don't fail.
def _mod(name):
m = ModuleType(name)
sys.modules[name] = m
return m

if "cryptography" not in sys.modules:
crypto = _mod("cryptography")
hazmat = _mod("cryptography.hazmat")
backends = _mod("cryptography.hazmat.backends")
backends.default_backend = MagicMock()
primitives = _mod("cryptography.hazmat.primitives")
serialization = _mod("cryptography.hazmat.primitives.serialization")
serialization.Encoding = MagicMock()
serialization.PrivateFormat = MagicMock()
serialization.NoEncryption = MagicMock()
crypto.hazmat = hazmat
hazmat.backends = backends
hazmat.primitives = primitives
primitives.serialization = serialization

if "snowflake" not in sys.modules:
sf = _mod("snowflake")
connector = _mod("snowflake.connector")
connector.ProgrammingError = Exception
connector.SnowflakeConnection = MagicMock()
cursor_mod = _mod("snowflake.connector.cursor")
cursor_mod.SnowflakeCursor = MagicMock()
sf.connector = connector
connector.cursor = cursor_mod

if "tenacity" not in sys.modules:
tenacity = _mod("tenacity")
tenacity.retry = lambda *a, **kw: lambda f: f
tenacity.retry_if_exception_type = MagicMock()
tenacity.stop_after_attempt = MagicMock()
tenacity.wait_exponential = MagicMock()


_stub_snowflake_modules()


def _make_feature_view(name="driver_stats", version_number=None, version_tag=None):
entity = Entity(
name="driver_id",
join_keys=["driver_id"],
value_type=ValueType.INT64,
)
fv = FeatureView(
name=name,
entities=[entity],
ttl=timedelta(days=1),
schema=[Field(name="trips_today", dtype=Float32)],
)
if version_number is not None:
fv.current_version_number = version_number
if version_tag is not None:
fv.projection.version_tag = version_tag
return fv


class TestSnowflakeTableName:
"""Test _snowflake_table_name with versioning enabled/disabled."""

def test_no_versioning(self):
from feast.infra.online_stores.snowflake import _snowflake_table_name

fv = _make_feature_view()
assert (
_snowflake_table_name("test_project", fv, False)
== "[online-transient] test_project_driver_stats"
)

def test_versioning_disabled_ignores_version(self):
from feast.infra.online_stores.snowflake import _snowflake_table_name

fv = _make_feature_view(version_number=5)
assert (
_snowflake_table_name("test_project", fv, False)
== "[online-transient] test_project_driver_stats"
)

def test_versioning_enabled_no_version_set(self):
from feast.infra.online_stores.snowflake import _snowflake_table_name

fv = _make_feature_view()
assert (
_snowflake_table_name("test_project", fv, True)
== "[online-transient] test_project_driver_stats"
)

def test_versioning_enabled_with_current_version_number(self):
from feast.infra.online_stores.snowflake import _snowflake_table_name

fv = _make_feature_view(version_number=2)
assert (
_snowflake_table_name("test_project", fv, True)
== "[online-transient] test_project_driver_stats_v2"
)

def test_version_zero_no_suffix(self):
from feast.infra.online_stores.snowflake import _snowflake_table_name

fv = _make_feature_view(version_number=0)
assert (
_snowflake_table_name("test_project", fv, True)
== "[online-transient] test_project_driver_stats"
)

def test_projection_version_tag_takes_priority(self):
from feast.infra.online_stores.snowflake import _snowflake_table_name

fv = _make_feature_view(version_number=1, version_tag=3)
assert (
_snowflake_table_name("test_project", fv, True)
== "[online-transient] test_project_driver_stats_v3"
)

def test_projection_version_tag_zero_no_suffix(self):
from feast.infra.online_stores.snowflake import _snowflake_table_name

fv = _make_feature_view(version_tag=0, version_number=3)
assert (
_snowflake_table_name("test_project", fv, True)
== "[online-transient] test_project_driver_stats"
)

def test_different_versions_produce_different_table_names(self):
from feast.infra.online_stores.snowflake import _snowflake_table_name

fv_v1 = _make_feature_view(version_number=1)
fv_v2 = _make_feature_view(version_number=2)
name_v1 = _snowflake_table_name("prod", fv_v1, True)
name_v2 = _snowflake_table_name("prod", fv_v2, True)
assert name_v1 != name_v2
assert name_v1 == "[online-transient] prod_driver_stats_v1"
assert name_v2 == "[online-transient] prod_driver_stats_v2"


class TestSnowflakeVersionedReadSupport:
"""Test that SnowflakeOnlineStore passes _check_versioned_read_support."""

def test_allowed_with_version_tag(self):
from feast.infra.online_stores.snowflake import SnowflakeOnlineStore

store = SnowflakeOnlineStore()
fv = _make_feature_view()
fv.projection.version_tag = 2
store._check_versioned_read_support([(fv, ["trips_today"])])

def test_allowed_without_version_tag(self):
from feast.infra.online_stores.snowflake import SnowflakeOnlineStore

store = SnowflakeOnlineStore()
fv = _make_feature_view()
store._check_versioned_read_support([(fv, ["trips_today"])])
Loading