diff --git a/moddata/_utils.py b/moddata/_utils.py
index 48e818e..2c0a16d 100644
--- a/moddata/_utils.py
+++ b/moddata/_utils.py
@@ -59,11 +59,9 @@ def _load_btc():
 
 
 def _load_pl_banking_stocks() -> pd.DataFrame:
-    with (
+    return pd.read_parquet(str(
         resources.files('moddata.data').joinpath('pl_banking_stocks.parquet')
-        as f
-    ):
-        return pd.read_parquet(f)
+    ))
 
 
 def load_data(dataset: Dataset) -> pd.DataFrame | None:
diff --git a/tests/pipeline/test_bankchurn_pipeline.py b/tests/pipeline/test_bankchurn_pipeline.py
index 4173d06..4c0278d 100644
--- a/tests/pipeline/test_bankchurn_pipeline.py
+++ b/tests/pipeline/test_bankchurn_pipeline.py
@@ -4,17 +4,21 @@ from moddata.src.config import BankchurnPipelineConfig
 
 
 
-def test_bankchurn_pipeline_run():
+def test_bankchurn_pipeline_tree_like():
     X_train, X_test, y_train, y_test = BankchurnPipeline(
         config=BankchurnPipelineConfig(
             random_state=12345,
-            train_size=0.8
+            train_size=0.8,
+            encoding_and_scaling_model_type="tree_like"
         )
     ).run()
 
-    assert X_train.shape == (8_000, 10)
-    assert X_test.shape == (2_000, 10)
+    assert X_train.shape == (8_000, 11)
+    assert X_test.shape == (2_000, 11)
     assert y_train.shape == (8_000, 1)
     assert y_test.shape == (2_000, 1)
 
     assert np.all(np.array(y_test.index[:3]) == np.array([7867, 1402, 8606]))
+
+
+test_bankchurn_pipeline_tree_like
\ No newline at end of file