diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index f2dd5687ecb..fb35ff79de8 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -142,7 +142,7 @@ class VersionedOnlineReadNotSupported(FeastError): def __init__(self, store_name: str, version: int): super().__init__( f"Versioned feature reads (@v{version}) are not yet supported by {store_name}. " - f"Currently only SQLite supports version-qualified feature references. " + f"Currently only SQLite, PostgreSQL, and MySQL support version-qualified feature references. " ) diff --git a/sdk/python/feast/infra/online_stores/mysql_online_store/mysql.py b/sdk/python/feast/infra/online_stores/mysql_online_store/mysql.py index 2172f3aa359..6df9331529a 100644 --- a/sdk/python/feast/infra/online_stores/mysql_online_store/mysql.py +++ b/sdk/python/feast/infra/online_stores/mysql_online_store/mysql.py @@ -70,6 +70,7 @@ def online_write_batch( cur = conn.cursor() project = config.project + versioning = config.registry.enable_online_feature_view_versioning batch_write = config.online_store.batch_write if not batch_write: @@ -92,6 +93,7 @@ def online_write_batch( table, timestamp, val, + versioning, ) conn.commit() if progress: @@ -124,7 +126,9 @@ def online_write_batch( if len(insert_values) >= batch_size: try: - self._execute_batch(cur, project, table, insert_values) + self._execute_batch( + cur, project, table, insert_values, versioning + ) conn.commit() if progress: progress(len(insert_values)) @@ -135,7 +139,7 @@ def online_write_batch( if insert_values: try: - self._execute_batch(cur, project, table, insert_values) + self._execute_batch(cur, project, table, insert_values, versioning) conn.commit() if progress: progress(len(insert_values)) @@ -143,9 +147,12 @@ def online_write_batch( conn.rollback() raise e - def _execute_batch(self, cur, project, table, insert_values): - sql = f""" - INSERT INTO {_table_id(project, table)} + def _execute_batch( + self, cur, project, table, insert_values, enable_versioning=False + ): + table_name = _table_id(project, table, enable_versioning) + stmt = f""" + INSERT INTO {table_name} (entity_key, feature_name, value, event_ts, created_ts) values (%s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE @@ -154,22 +161,29 @@ def _execute_batch(self, cur, project, table, insert_values): created_ts = VALUES(created_ts); """ try: - cur.executemany(sql, insert_values) + cur.executemany(stmt, insert_values) except Exception as e: - # Log SQL info for debugging without leaking sensitive data first_sample = insert_values[0] if insert_values else None raise RuntimeError( - f"Failed to execute batch insert into table '{_table_id(project, table)}' " + f"Failed to execute batch insert into table '{table_name}' " f"(rows={len(insert_values)}, sample={first_sample}): {e}" ) from e @staticmethod def write_to_table( - created_ts, cur, entity_key_bin, feature_name, project, table, timestamp, val + created_ts, + cur, + entity_key_bin, + feature_name, + project, + table, + timestamp, + val, + enable_versioning=False, ) -> None: cur.execute( f""" - INSERT INTO {_table_id(project, table)} + INSERT INTO {_table_id(project, table, enable_versioning)} (entity_key, feature_name, value, event_ts, created_ts) values (%s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE @@ -204,6 +218,7 @@ def online_read( result: List[Tuple[Optional[datetime], Optional[Dict[str, Any]]]] = [] project = config.project + versioning = config.registry.enable_online_feature_view_versioning for entity_key in entity_keys: entity_key_bin = serialize_entity_key( entity_key, @@ -211,7 +226,7 @@ def online_read( ).hex() cur.execute( - f"SELECT feature_name, value, event_ts FROM {_table_id(project, table)} WHERE entity_key = %s", + f"SELECT feature_name, value, event_ts FROM {_table_id(project, table, versioning)} WHERE entity_key = %s", (entity_key_bin,), ) @@ -243,10 +258,11 @@ def update( conn = self._get_conn(config) cur = conn.cursor() project = config.project + versioning = config.registry.enable_online_feature_view_versioning # We don't create any special state for the entities in this implementation. for table in tables_to_keep: - table_name = _table_id(project, table) + table_name = _table_id(project, table, versioning) index_name = f"{table_name}_ek" cur.execute( f"""CREATE TABLE IF NOT EXISTS {table_name} (entity_key VARCHAR(512), @@ -269,7 +285,7 @@ def update( ) for table in tables_to_delete: - _drop_table_and_index(cur, project, table) + _drop_table_and_index(cur, project, table, versioning) def teardown( self, @@ -280,16 +296,26 @@ def teardown( conn = self._get_conn(config) cur = conn.cursor() project = config.project + versioning = config.registry.enable_online_feature_view_versioning for table in tables: - _drop_table_and_index(cur, project, table) + _drop_table_and_index(cur, project, table, versioning) -def _drop_table_and_index(cur: Cursor, project: str, table: FeatureView) -> None: - table_name = _table_id(project, table) +def _drop_table_and_index( + cur: Cursor, project: str, table: FeatureView, enable_versioning: bool = False +) -> None: + table_name = _table_id(project, table, enable_versioning) cur.execute(f"DROP INDEX {table_name}_ek ON {table_name};") cur.execute(f"DROP TABLE IF EXISTS {table_name}") -def _table_id(project: str, table: FeatureView) -> str: - return f"{project}_{table.name}" +def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str: + name = table.name + if enable_versioning: + version = getattr(table.projection, "version_tag", None) + if version is None: + version = getattr(table, "current_version_number", None) + if version is not None and version > 0: + name = f"{table.name}_v{version}" + return f"{project}_{name}" diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index 4913046470c..7af616be01f 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -256,9 +256,17 @@ def get_online_features( def _check_versioned_read_support(self, grouped_refs): """Raise an error if versioned reads are attempted on unsupported stores.""" + from feast.infra.online_stores.mysql_online_store.mysql import ( + MySQLOnlineStore, + ) + from feast.infra.online_stores.postgres_online_store.postgres import ( + PostgreSQLOnlineStore, + ) from feast.infra.online_stores.sqlite import SqliteOnlineStore - if isinstance(self, SqliteOnlineStore): + if isinstance( + self, (SqliteOnlineStore, PostgreSQLOnlineStore, MySQLOnlineStore) + ): return for table, _ in grouped_refs: version_tag = getattr(table.projection, "version_tag", None) diff --git a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py index f7780726d12..487dacac227 100644 --- a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py +++ b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py @@ -152,7 +152,15 @@ def online_write_batch( event_ts = EXCLUDED.event_ts, created_ts = EXCLUDED.created_ts; """ - ).format(sql.Identifier(_table_id(config.project, table))) + ).format( + sql.Identifier( + _table_id( + config.project, + table, + config.registry.enable_online_feature_view_versioning, + ) + ) + ) # Push data into the online store with self._get_conn(config) as conn, conn.cursor() as cur: @@ -214,7 +222,13 @@ def _construct_query_and_params( FROM {} WHERE entity_key = ANY(%s) AND feature_name = ANY(%s); """ ).format( - sql.Identifier(_table_id(config.project, table)), + sql.Identifier( + _table_id( + config.project, + table, + config.registry.enable_online_feature_view_versioning, + ) + ), ) params = (keys, requested_features) else: @@ -224,7 +238,13 @@ def _construct_query_and_params( FROM {} WHERE entity_key = ANY(%s); """ ).format( - sql.Identifier(_table_id(config.project, table)), + sql.Identifier( + _table_id( + config.project, + table, + config.registry.enable_online_feature_view_versioning, + ) + ), ) params = (keys, []) return query, params @@ -304,12 +324,13 @@ def update( ), ) + versioning = config.registry.enable_online_feature_view_versioning for table in tables_to_delete: - table_name = _table_id(project, table) + table_name = _table_id(project, table, versioning) cur.execute(_drop_table_and_index(table_name)) for table in tables_to_keep: - table_name = _table_id(project, table) + table_name = _table_id(project, table, versioning) if config.online_store.vector_enabled: vector_value_type = "vector" else: @@ -363,10 +384,11 @@ def teardown( entities: Sequence[Entity], ): project = config.project + versioning = config.registry.enable_online_feature_view_versioning try: with self._get_conn(config) as conn, conn.cursor() as cur: for table in tables: - table_name = _table_id(project, table) + table_name = _table_id(project, table, versioning) cur.execute(_drop_table_and_index(table_name)) conn.commit() except Exception: @@ -432,7 +454,9 @@ def retrieve_online_documents( ] ] = [] with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur: - table_name = _table_id(project, table) + table_name = _table_id( + project, table, config.registry.enable_online_feature_view_versioning + ) # Search query template to find the top k items that are closest to the given embedding # SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5; @@ -533,7 +557,11 @@ def retrieve_online_documents_v2( and feature.name in requested_features ] - table_name = _table_id(config.project, table) + table_name = _table_id( + config.project, + table, + config.registry.enable_online_feature_view_versioning, + ) with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur: query = None @@ -794,8 +822,15 @@ def retrieve_online_documents_v2( return result -def _table_id(project: str, table: FeatureView) -> str: - return f"{project}_{table.name}" +def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str: + name = table.name + if enable_versioning: + version = getattr(table.projection, "version_tag", None) + if version is None: + version = getattr(table, "current_version_number", None) + if version is not None and version > 0: + name = f"{table.name}_v{version}" + return f"{project}_{name}" def _drop_table_and_index(table_name): diff --git a/sdk/python/tests/integration/online_store/test_mysql_versioning.py b/sdk/python/tests/integration/online_store/test_mysql_versioning.py new file mode 100644 index 00000000000..d1d132681c5 --- /dev/null +++ b/sdk/python/tests/integration/online_store/test_mysql_versioning.py @@ -0,0 +1,187 @@ +"""Integration tests for MySQL online store feature view versioning. + +Run with: pytest --integration sdk/python/tests/integration/online_store/test_mysql_versioning.py +""" + +import shutil +from datetime import datetime, timedelta, timezone + +import pytest + +from feast import Entity, FeatureView +from feast.field import Field +from feast.infra.online_stores.mysql_online_store.mysql import MySQLOnlineStore +from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.repo_config import RegistryConfig, RepoConfig +from feast.types import Float32, Int64 +from feast.value_type import ValueType + + +def _make_feature_view(name="driver_stats", version="latest"): + entity = Entity( + name="driver_id", + join_keys=["driver_id"], + value_type=ValueType.INT64, + ) + return FeatureView( + name=name, + entities=[entity], + ttl=timedelta(days=1), + schema=[ + Field(name="driver_id", dtype=Int64), + Field(name="trips_today", dtype=Int64), + Field(name="avg_rating", dtype=Float32), + ], + version=version, + ) + + +def _make_entity_key(driver_id: int) -> EntityKeyProto: + entity_key = EntityKeyProto() + entity_key.join_keys.append("driver_id") + val = ValueProto() + val.int64_val = driver_id + entity_key.entity_values.append(val) + return entity_key + + +def _write_and_read(store, config, fv, driver_id=1001, trips=42): + entity_key = _make_entity_key(driver_id) + val = ValueProto() + val.int64_val = trips + now = datetime.now(tz=timezone.utc) + store.online_write_batch( + config, fv, [(entity_key, {"trips_today": val}, now, now)], None + ) + return store.online_read(config, fv, [entity_key], ["trips_today"]) + + +@pytest.mark.integration +@pytest.mark.skipif( + not shutil.which("docker"), + reason="Docker not available", +) +class TestMySQLVersioningIntegration: + """Integration tests for MySQL versioning with a real database.""" + + @pytest.fixture(autouse=True) + def setup_mysql(self): + try: + from testcontainers.mysql import MySqlContainer + except ImportError: + pytest.skip("testcontainers[mysql] not installed") + + self.container = MySqlContainer( + "mysql:8.0", + username="root", + password="testpass", # pragma: allowlist secret + dbname="feast", + ).with_exposed_ports(3306) + self.container.start() + self.port = self.container.get_exposed_port(3306) + yield + self.container.stop() + + def _make_config(self, enable_versioning=False): + from feast.infra.online_stores.mysql_online_store.mysql import ( + MySQLOnlineStoreConfig, + ) + + return RepoConfig( + project="test_project", + provider="local", + online_store=MySQLOnlineStoreConfig( + type="mysql", + host="localhost", + port=int(self.port), + user="root", + password="testpass", # pragma: allowlist secret + database="feast", + ), + registry=RegistryConfig( + path="/tmp/test_mysql_registry.pb", + enable_online_feature_view_versioning=enable_versioning, + ), + entity_key_serialization_version=3, + ) + + def test_write_read_without_versioning(self): + config = self._make_config(enable_versioning=False) + store = MySQLOnlineStore() + fv = _make_feature_view() + store.update(config, [], [fv], [], [], False) + + result = _write_and_read(store, config, fv) + assert result[0][1] is not None + assert result[0][1]["trips_today"].int64_val == 42 + + def test_write_read_with_versioning_v1(self): + config = self._make_config(enable_versioning=True) + store = MySQLOnlineStore() + fv = _make_feature_view() + fv.current_version_number = 1 + store.update(config, [], [fv], [], [], False) + + result = _write_and_read(store, config, fv) + assert result[0][1] is not None + assert result[0][1]["trips_today"].int64_val == 42 + + def test_version_isolation(self): + """Data written to v1 is not visible from v2.""" + config = self._make_config(enable_versioning=True) + store = MySQLOnlineStore() + + fv_v1 = _make_feature_view() + fv_v1.current_version_number = 1 + store.update(config, [], [fv_v1], [], [], False) + _write_and_read(store, config, fv_v1, driver_id=1001, trips=10) + + fv_v2 = _make_feature_view() + fv_v2.current_version_number = 2 + store.update(config, [], [fv_v2], [], [], False) + + entity_key = _make_entity_key(1001) + result = store.online_read(config, fv_v2, [entity_key], ["trips_today"]) + assert result[0] == (None, None) + + result = store.online_read(config, fv_v1, [entity_key], ["trips_today"]) + assert result[0][1] is not None + assert result[0][1]["trips_today"].int64_val == 10 + + def test_projection_version_tag_routes_to_correct_table(self): + """projection.version_tag routes reads to the correct versioned table.""" + config = self._make_config(enable_versioning=True) + store = MySQLOnlineStore() + + fv_v1 = _make_feature_view() + fv_v1.current_version_number = 1 + store.update(config, [], [fv_v1], [], [], False) + _write_and_read(store, config, fv_v1, driver_id=1001, trips=100) + + fv_v2 = _make_feature_view() + fv_v2.current_version_number = 2 + store.update(config, [], [fv_v2], [], [], False) + _write_and_read(store, config, fv_v2, driver_id=1001, trips=200) + + fv_read = _make_feature_view() + fv_read.projection.version_tag = 1 + entity_key = _make_entity_key(1001) + result = store.online_read(config, fv_read, [entity_key], ["trips_today"]) + assert result[0][1]["trips_today"].int64_val == 100 + + fv_read2 = _make_feature_view() + fv_read2.projection.version_tag = 2 + result = store.online_read(config, fv_read2, [entity_key], ["trips_today"]) + assert result[0][1]["trips_today"].int64_val == 200 + + def test_teardown_versioned_table(self): + config = self._make_config(enable_versioning=True) + store = MySQLOnlineStore() + + fv = _make_feature_view() + fv.current_version_number = 1 + store.update(config, [], [fv], [], [], False) + _write_and_read(store, config, fv) + + store.teardown(config, [fv], []) diff --git a/sdk/python/tests/integration/online_store/test_postgres_versioning.py b/sdk/python/tests/integration/online_store/test_postgres_versioning.py new file mode 100644 index 00000000000..9816c2a43c2 --- /dev/null +++ b/sdk/python/tests/integration/online_store/test_postgres_versioning.py @@ -0,0 +1,205 @@ +"""Integration tests for PostgreSQL online store feature view versioning. + +Run with: pytest --integration sdk/python/tests/integration/online_store/test_postgres_versioning.py +""" + +import shutil +from datetime import datetime, timedelta, timezone + +import pytest + +from feast import Entity, FeatureView +from feast.field import Field +from feast.infra.online_stores.postgres_online_store.postgres import ( + PostgreSQLOnlineStore, +) +from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.repo_config import RegistryConfig, RepoConfig +from feast.types import Float32, Int64 +from feast.value_type import ValueType + + +def _make_feature_view(name="driver_stats", version="latest"): + entity = Entity( + name="driver_id", + join_keys=["driver_id"], + value_type=ValueType.INT64, + ) + return FeatureView( + name=name, + entities=[entity], + ttl=timedelta(days=1), + schema=[ + Field(name="driver_id", dtype=Int64), + Field(name="trips_today", dtype=Int64), + Field(name="avg_rating", dtype=Float32), + ], + version=version, + ) + + +def _make_entity_key(driver_id: int) -> EntityKeyProto: + entity_key = EntityKeyProto() + entity_key.join_keys.append("driver_id") + val = ValueProto() + val.int64_val = driver_id + entity_key.entity_values.append(val) + return entity_key + + +def _write_and_read(store, config, fv, driver_id=1001, trips=42): + entity_key = _make_entity_key(driver_id) + val = ValueProto() + val.int64_val = trips + now = datetime.now(tz=timezone.utc) + store.online_write_batch( + config, fv, [(entity_key, {"trips_today": val}, now, now)], None + ) + return store.online_read(config, fv, [entity_key], ["trips_today"]) + + +@pytest.mark.integration +@pytest.mark.skipif( + not shutil.which("docker"), + reason="Docker not available", +) +class TestPostgresVersioningIntegration: + """Integration tests for PostgreSQL versioning with a real database.""" + + @pytest.fixture(autouse=True) + def setup_postgres(self): + try: + from testcontainers.postgres import PostgresContainer + except ImportError: + pytest.skip("testcontainers[postgres] not installed") + + self.container = PostgresContainer( + "postgres:16", + username="root", + password="testpass", # pragma: allowlist secret + dbname="test", + ).with_exposed_ports(5432) + self.container.start() + self.port = self.container.get_exposed_port(5432) + yield + self.container.stop() + + def _make_config(self, enable_versioning=False): + from feast.infra.online_stores.postgres_online_store.postgres import ( + PostgreSQLOnlineStoreConfig, + ) + + return RepoConfig( + project="test_project", + provider="local", + online_store=PostgreSQLOnlineStoreConfig( + type="postgres", + host="localhost", + port=int(self.port), + user="root", + password="testpass", # pragma: allowlist secret + database="test", + sslmode="disable", + ), + registry=RegistryConfig( + path="/tmp/test_pg_registry.pb", + enable_online_feature_view_versioning=enable_versioning, + ), + entity_key_serialization_version=3, + ) + + def test_write_read_without_versioning(self): + config = self._make_config(enable_versioning=False) + store = PostgreSQLOnlineStore() + fv = _make_feature_view() + store.update(config, [], [fv], [], [], False) + + result = _write_and_read(store, config, fv) + assert result[0][1] is not None + assert result[0][1]["trips_today"].int64_val == 42 + + def test_write_read_with_versioning_v1(self): + config = self._make_config(enable_versioning=True) + store = PostgreSQLOnlineStore() + fv = _make_feature_view() + fv.current_version_number = 1 + store.update(config, [], [fv], [], [], False) + + result = _write_and_read(store, config, fv) + assert result[0][1] is not None + assert result[0][1]["trips_today"].int64_val == 42 + + def test_version_isolation(self): + """Data written to v1 is not visible from v2.""" + config = self._make_config(enable_versioning=True) + store = PostgreSQLOnlineStore() + + fv_v1 = _make_feature_view() + fv_v1.current_version_number = 1 + store.update(config, [], [fv_v1], [], [], False) + _write_and_read(store, config, fv_v1, driver_id=1001, trips=10) + + fv_v2 = _make_feature_view() + fv_v2.current_version_number = 2 + store.update(config, [], [fv_v2], [], [], False) + + entity_key = _make_entity_key(1001) + result = store.online_read(config, fv_v2, [entity_key], ["trips_today"]) + assert result[0] == (None, None) + + result = store.online_read(config, fv_v1, [entity_key], ["trips_today"]) + assert result[0][1] is not None + assert result[0][1]["trips_today"].int64_val == 10 + + def test_projection_version_tag_routes_to_correct_table(self): + """projection.version_tag routes reads to the correct versioned table.""" + config = self._make_config(enable_versioning=True) + store = PostgreSQLOnlineStore() + + fv_v1 = _make_feature_view() + fv_v1.current_version_number = 1 + store.update(config, [], [fv_v1], [], [], False) + _write_and_read(store, config, fv_v1, driver_id=1001, trips=100) + + fv_v2 = _make_feature_view() + fv_v2.current_version_number = 2 + store.update(config, [], [fv_v2], [], [], False) + _write_and_read(store, config, fv_v2, driver_id=1001, trips=200) + + fv_read = _make_feature_view() + fv_read.projection.version_tag = 1 + entity_key = _make_entity_key(1001) + result = store.online_read(config, fv_read, [entity_key], ["trips_today"]) + assert result[0][1]["trips_today"].int64_val == 100 + + fv_read2 = _make_feature_view() + fv_read2.projection.version_tag = 2 + result = store.online_read(config, fv_read2, [entity_key], ["trips_today"]) + assert result[0][1]["trips_today"].int64_val == 200 + + def test_teardown_versioned_table(self): + """teardown() drops the versioned table without error.""" + config = self._make_config(enable_versioning=True) + store = PostgreSQLOnlineStore() + + fv = _make_feature_view() + fv.current_version_number = 1 + store.update(config, [], [fv], [], [], False) + _write_and_read(store, config, fv) + + # Should not raise + store.teardown(config, [fv], []) + + def test_update_deletes_versioned_table(self): + """update() with tables_to_delete correctly drops versioned tables.""" + config = self._make_config(enable_versioning=True) + store = PostgreSQLOnlineStore() + + fv = _make_feature_view() + fv.current_version_number = 1 + store.update(config, [], [fv], [], [], False) + _write_and_read(store, config, fv, driver_id=1001, trips=50) + + # Delete the versioned table + store.update(config, [fv], [], [], [], False) diff --git a/sdk/python/tests/unit/infra/online_store/test_mysql_versioning.py b/sdk/python/tests/unit/infra/online_store/test_mysql_versioning.py new file mode 100644 index 00000000000..7bbe967a53b --- /dev/null +++ b/sdk/python/tests/unit/infra/online_store/test_mysql_versioning.py @@ -0,0 +1,98 @@ +"""Unit tests for MySQL online store feature view versioning.""" + +from datetime import timedelta + +from feast import Entity, FeatureView +from feast.field import Field +from feast.infra.online_stores.mysql_online_store.mysql import ( + MySQLOnlineStore, + _table_id, +) +from feast.types import Float32, Int64 +from feast.value_type import ValueType + + +def _make_feature_view(name="driver_stats", version="latest"): + entity = Entity( + name="driver_id", + join_keys=["driver_id"], + value_type=ValueType.INT64, + ) + return FeatureView( + name=name, + entities=[entity], + ttl=timedelta(days=1), + schema=[ + Field(name="driver_id", dtype=Int64), + Field(name="trips_today", dtype=Int64), + Field(name="avg_rating", dtype=Float32), + ], + version=version, + ) + + +class TestMySQLTableId: + """Test _table_id generates correct versioned table names.""" + + def test_default_no_versioning(self): + fv = _make_feature_view() + assert _table_id("proj", fv) == "proj_driver_stats" + + def test_versioning_explicitly_disabled(self): + fv = _make_feature_view() + assert _table_id("proj", fv, enable_versioning=False) == "proj_driver_stats" + + def test_versioning_enabled_no_version_set(self): + fv = _make_feature_view() + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats" + + def test_versioning_enabled_with_current_version_number(self): + fv = _make_feature_view() + fv.current_version_number = 2 + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats_v2" + + def test_version_zero_no_suffix(self): + fv = _make_feature_view() + fv.current_version_number = 0 + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats" + + def test_projection_version_tag_takes_priority(self): + fv = _make_feature_view() + fv.current_version_number = 2 + fv.projection.version_tag = 4 + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats_v4" + + def test_projection_version_tag_zero_no_suffix(self): + fv = _make_feature_view() + fv.projection.version_tag = 0 + fv.current_version_number = 1 + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats" + + def test_different_project_names(self): + fv = _make_feature_view() + fv.current_version_number = 1 + assert _table_id("prod", fv, enable_versioning=True) == "prod_driver_stats_v1" + assert ( + _table_id("staging", fv, enable_versioning=True) + == "staging_driver_stats_v1" + ) + + def test_different_feature_view_names(self): + fv = _make_feature_view(name="user_stats") + fv.current_version_number = 2 + assert _table_id("proj", fv, enable_versioning=True) == "proj_user_stats_v2" + + +class TestMySQLVersionedReadSupport: + """Test that MySQLOnlineStore passes _check_versioned_read_support.""" + + def test_allowed_with_version_tag(self): + store = MySQLOnlineStore() + fv = _make_feature_view() + fv.projection.version_tag = 2 + store._check_versioned_read_support([(fv, ["trips_today"])]) + + def test_allowed_without_version_tag(self): + store = MySQLOnlineStore() + fv = _make_feature_view() + store._check_versioned_read_support([(fv, ["trips_today"])]) diff --git a/sdk/python/tests/unit/infra/online_store/test_postgres_versioning.py b/sdk/python/tests/unit/infra/online_store/test_postgres_versioning.py new file mode 100644 index 00000000000..ad6877cd4f1 --- /dev/null +++ b/sdk/python/tests/unit/infra/online_store/test_postgres_versioning.py @@ -0,0 +1,98 @@ +"""Unit tests for PostgreSQL online store feature view versioning.""" + +from datetime import timedelta + +from feast import Entity, FeatureView +from feast.field import Field +from feast.infra.online_stores.postgres_online_store.postgres import ( + PostgreSQLOnlineStore, + _table_id, +) +from feast.types import Float32, Int64 +from feast.value_type import ValueType + + +def _make_feature_view(name="driver_stats", version="latest"): + entity = Entity( + name="driver_id", + join_keys=["driver_id"], + value_type=ValueType.INT64, + ) + return FeatureView( + name=name, + entities=[entity], + ttl=timedelta(days=1), + schema=[ + Field(name="driver_id", dtype=Int64), + Field(name="trips_today", dtype=Int64), + Field(name="avg_rating", dtype=Float32), + ], + version=version, + ) + + +class TestPostgresTableId: + """Test _table_id generates correct versioned table names.""" + + def test_default_no_versioning(self): + fv = _make_feature_view() + assert _table_id("proj", fv) == "proj_driver_stats" + + def test_versioning_explicitly_disabled(self): + fv = _make_feature_view() + assert _table_id("proj", fv, enable_versioning=False) == "proj_driver_stats" + + def test_versioning_enabled_no_version_set(self): + fv = _make_feature_view() + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats" + + def test_versioning_enabled_with_current_version_number(self): + fv = _make_feature_view() + fv.current_version_number = 3 + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats_v3" + + def test_version_zero_no_suffix(self): + fv = _make_feature_view() + fv.current_version_number = 0 + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats" + + def test_projection_version_tag_takes_priority(self): + fv = _make_feature_view() + fv.current_version_number = 2 + fv.projection.version_tag = 5 + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats_v5" + + def test_projection_version_tag_zero_no_suffix(self): + fv = _make_feature_view() + fv.projection.version_tag = 0 + fv.current_version_number = 3 + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats" + + def test_different_project_names(self): + fv = _make_feature_view() + fv.current_version_number = 1 + assert _table_id("prod", fv, enable_versioning=True) == "prod_driver_stats_v1" + assert ( + _table_id("staging", fv, enable_versioning=True) + == "staging_driver_stats_v1" + ) + + def test_different_feature_view_names(self): + fv = _make_feature_view(name="user_stats") + fv.current_version_number = 2 + assert _table_id("proj", fv, enable_versioning=True) == "proj_user_stats_v2" + + +class TestPostgresVersionedReadSupport: + """Test that PostgreSQLOnlineStore passes _check_versioned_read_support.""" + + def test_allowed_with_version_tag(self): + store = PostgreSQLOnlineStore() + fv = _make_feature_view() + fv.projection.version_tag = 2 + store._check_versioned_read_support([(fv, ["trips_today"])]) + + def test_allowed_without_version_tag(self): + store = PostgreSQLOnlineStore() + fv = _make_feature_view() + store._check_versioned_read_support([(fv, ["trips_today"])])