diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index 3fc675ea402..92c651258e3 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -523,20 +523,33 @@ def persist( ): """ Run the retrieval and persist the results in the same offline store used for read. - Please note the persisting is done only within the scope of the spark session for local warehouse directory. + Supports both table-based and path-based SparkSource configurations. + For table-based: persists via saveAsTable (remote warehouse) or createOrReplaceTempView (local). + For path-based: writes directly to the specified path in the given file format. """ assert isinstance(storage, SavedDatasetSparkStorage) + table_name = storage.spark_options.table - if not table_name: - raise ValueError("Cannot persist, table_name is not defined") - if self._has_remote_warehouse_in_config(): - file_format = storage.spark_options.file_format + path = storage.spark_options.path + file_format = storage.spark_options.file_format + + if path: if not file_format: - self.to_spark_df().write.saveAsTable(table_name) + file_format = "parquet" + write_mode = "overwrite" if allow_overwrite else "error" + self.to_spark_df().write.format(file_format).mode(write_mode).save(path) + elif table_name: + if self._has_remote_warehouse_in_config(): + if not file_format: + self.to_spark_df().write.saveAsTable(table_name) + else: + self.to_spark_df().write.format(file_format).saveAsTable(table_name) else: - self.to_spark_df().write.format(file_format).saveAsTable(table_name) + self.to_spark_df().createOrReplaceTempView(table_name) else: - self.to_spark_df().createOrReplaceTempView(table_name) + raise ValueError( + "Cannot persist: either 'table' or 'path' must be specified in SavedDatasetSparkStorage" + ) def _has_remote_warehouse_in_config(self) -> bool: """ diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py index cd41921e56a..d94a14123c7 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py @@ -453,3 +453,14 @@ def to_data_source(self) -> DataSource: file_format=self.spark_options.file_format, table_format=self.spark_options.table_format, ) + + @staticmethod + def from_data_source(data_source: DataSource) -> "SavedDatasetSparkStorage": + assert isinstance(data_source, SparkSource) + return SavedDatasetSparkStorage( + table=data_source.table, + query=data_source.query, + path=data_source.path, + file_format=data_source.file_format, + table_format=data_source.table_format, + ) diff --git a/sdk/python/feast/saved_dataset.py b/sdk/python/feast/saved_dataset.py index 4a3043a8731..9ed69861d4b 100644 --- a/sdk/python/feast/saved_dataset.py +++ b/sdk/python/feast/saved_dataset.py @@ -34,6 +34,7 @@ def __new__(cls, name, bases, dct): _DATA_SOURCE_TO_SAVED_DATASET_STORAGE = { "FileSource": "feast.infra.offline_stores.file_source.SavedDatasetFileStorage", + "SparkSource": "feast.infra.offline_stores.contrib.spark_offline_store.spark_source.SavedDatasetSparkStorage", } diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_persist.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_persist.py new file mode 100644 index 00000000000..adcb54a1cf7 --- /dev/null +++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_persist.py @@ -0,0 +1,260 @@ +""" +Unit tests for SparkRetrievalJob.persist() and SavedDatasetSparkStorage.from_data_source(). + +Covers the fix for https://github.com/feast-dev/feast/issues/6261 where: +1. SavedDatasetStorage.from_data_source() did not support SparkSource +2. SavedDatasetSparkStorage lacked a from_data_source() method +3. SparkRetrievalJob.persist() only supported table-based storage, not path-based +""" + +from unittest.mock import MagicMock + +import pytest + +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkOfflineStoreConfig, + SparkRetrievalJob, +) +from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( + SavedDatasetSparkStorage, + SparkSource, +) +from feast.infra.offline_stores.file_source import FileSource +from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig +from feast.repo_config import RepoConfig +from feast.saved_dataset import SavedDatasetStorage +from feast.table_format import IcebergFormat + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def repo_config(): + return RepoConfig( + registry="file:///tmp/registry.db", + project="test", + provider="local", + online_store=SqliteOnlineStoreConfig(type="sqlite"), + offline_store=SparkOfflineStoreConfig(type="spark"), + ) + + +@pytest.fixture() +def table_spark_source(): + return SparkSource( + name="my_table", + table="db.my_table", + timestamp_field="event_timestamp", + ) + + +@pytest.fixture() +def path_spark_source(): + return SparkSource( + name="my_path_source", + path="s3a://bucket/data/features/", + file_format="parquet", + timestamp_field="event_timestamp", + ) + + +def _make_spark_retrieval_job(repo_config, remote_warehouse=True): + """Build a SparkRetrievalJob with a mocked SparkSession.""" + mock_spark = MagicMock() + + if remote_warehouse: + mock_spark.conf.get.side_effect = lambda key: { + "hive.metastore.uris": "thrift://metastore:9083", + }.get(key, None) + else: + + def _local_conf_get(key): + if key == "hive.metastore.uris": + raise Exception("not set") + if key == "spark.sql.warehouse.dir": + return "file:///tmp/spark-warehouse" + return None + + mock_spark.conf.get.side_effect = _local_conf_get + + return SparkRetrievalJob( + spark_session=mock_spark, + query="SELECT 1", + full_feature_names=False, + config=repo_config, + ) + + +# --------------------------------------------------------------------------- +# Group 1: SavedDatasetSparkStorage.from_data_source() +# --------------------------------------------------------------------------- + + +class TestSavedDatasetSparkStorageFromDataSource: + def test_from_data_source_with_table_source(self, table_spark_source): + storage = SavedDatasetSparkStorage.from_data_source(table_spark_source) + + assert isinstance(storage, SavedDatasetSparkStorage) + assert storage.spark_options.table == "db.my_table" + assert storage.spark_options.query is None + assert storage.spark_options.path is None + + def test_from_data_source_with_path_source(self, path_spark_source): + storage = SavedDatasetSparkStorage.from_data_source(path_spark_source) + + assert isinstance(storage, SavedDatasetSparkStorage) + assert storage.spark_options.path == "s3a://bucket/data/features/" + assert storage.spark_options.file_format == "parquet" + assert storage.spark_options.table is None + assert storage.spark_options.query is None + + def test_from_data_source_rejects_non_spark_source(self): + file_source = FileSource( + path="/tmp/data.parquet", + timestamp_field="event_timestamp", + ) + with pytest.raises(AssertionError): + SavedDatasetSparkStorage.from_data_source(file_source) + + +# --------------------------------------------------------------------------- +# Group 2: SavedDatasetStorage.from_data_source() dispatch +# --------------------------------------------------------------------------- + + +class TestSavedDatasetStorageDispatch: + def test_from_data_source_resolves_spark(self, table_spark_source): + storage = SavedDatasetStorage.from_data_source(table_spark_source) + + assert isinstance(storage, SavedDatasetSparkStorage) + assert storage.spark_options.table == "db.my_table" + + def test_from_data_source_resolves_path_spark(self, path_spark_source): + storage = SavedDatasetStorage.from_data_source(path_spark_source) + + assert isinstance(storage, SavedDatasetSparkStorage) + assert storage.spark_options.path == "s3a://bucket/data/features/" + assert storage.spark_options.file_format == "parquet" + + def test_roundtrip_table_source(self, table_spark_source): + storage = SavedDatasetStorage.from_data_source(table_spark_source) + roundtripped = storage.to_data_source() + + assert isinstance(roundtripped, SparkSource) + assert roundtripped.table == table_spark_source.table + assert roundtripped.query == table_spark_source.query + assert roundtripped.path == table_spark_source.path + + def test_roundtrip_path_source(self): + source = SparkSource( + name="my_path_source", + table="fallback_name", + timestamp_field="event_timestamp", + ) + storage = SavedDatasetStorage.from_data_source(source) + roundtripped = storage.to_data_source() + + assert isinstance(roundtripped, SparkSource) + assert roundtripped.table == source.table + + +# --------------------------------------------------------------------------- +# Group 3: SparkRetrievalJob.persist() +# --------------------------------------------------------------------------- + + +class TestSparkRetrievalJobPersist: + def test_persist_with_table_saves_as_table(self, repo_config): + job = _make_spark_retrieval_job(repo_config, remote_warehouse=True) + storage = SavedDatasetSparkStorage(table="output_table") + + job.persist(storage) + + mock_df = job.spark_session.sql.return_value + mock_df.write.saveAsTable.assert_called_once_with("output_table") + + def test_persist_with_table_and_format(self, repo_config): + job = _make_spark_retrieval_job(repo_config, remote_warehouse=True) + storage = SavedDatasetSparkStorage(table="output_table", file_format="parquet") + + job.persist(storage) + + mock_df = job.spark_session.sql.return_value + mock_df.write.format.assert_called_once_with("parquet") + mock_df.write.format.return_value.saveAsTable.assert_called_once_with( + "output_table" + ) + + def test_persist_with_path_writes_to_path(self, repo_config): + job = _make_spark_retrieval_job(repo_config, remote_warehouse=True) + storage = SavedDatasetSparkStorage( + path="s3a://bucket/output/", file_format="parquet" + ) + + job.persist(storage) + + mock_df = job.spark_session.sql.return_value + mock_df.write.format.assert_called_once_with("parquet") + mock_df.write.format.return_value.mode.assert_called_once_with("error") + mock_df.write.format.return_value.mode.return_value.save.assert_called_once_with( + "s3a://bucket/output/" + ) + + def test_persist_with_path_defaults_to_parquet(self, repo_config): + """When path is set with table_format but no file_format, persist defaults to parquet.""" + job = _make_spark_retrieval_job(repo_config, remote_warehouse=True) + storage = SavedDatasetSparkStorage( + path="s3a://bucket/output/", + file_format=None, + table_format=IcebergFormat(catalog="test_catalog"), + ) + + job.persist(storage) + + mock_df = job.spark_session.sql.return_value + mock_df.write.format.assert_called_once_with("parquet") + + def test_persist_with_path_allow_overwrite(self, repo_config): + job = _make_spark_retrieval_job(repo_config, remote_warehouse=True) + storage = SavedDatasetSparkStorage( + path="s3a://bucket/output/", file_format="parquet" + ) + + job.persist(storage, allow_overwrite=True) + + mock_df = job.spark_session.sql.return_value + mock_df.write.format.return_value.mode.assert_called_once_with("overwrite") + + def test_persist_with_path_custom_format(self, repo_config): + job = _make_spark_retrieval_job(repo_config, remote_warehouse=True) + storage = SavedDatasetSparkStorage( + path="s3a://bucket/output/", file_format="avro" + ) + + job.persist(storage) + + mock_df = job.spark_session.sql.return_value + mock_df.write.format.assert_called_once_with("avro") + mock_df.write.format.return_value.mode.return_value.save.assert_called_once_with( + "s3a://bucket/output/" + ) + + def test_persist_raises_without_table_or_path(self, repo_config): + job = _make_spark_retrieval_job(repo_config, remote_warehouse=True) + storage = SavedDatasetSparkStorage(query="SELECT * FROM t") + + with pytest.raises( + ValueError, match="either 'table' or 'path' must be specified" + ): + job.persist(storage) + + def test_persist_local_warehouse_creates_temp_view(self, repo_config): + job = _make_spark_retrieval_job(repo_config, remote_warehouse=False) + storage = SavedDatasetSparkStorage(table="output_table") + + job.persist(storage) + + mock_df = job.spark_session.sql.return_value + mock_df.createOrReplaceTempView.assert_called_once_with("output_table")