diff --git a/README.md b/README.md index febe8e9ae..54f2f840c 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,11 @@ All examples below use `python core.py` for source code users. **If you're using ### Running a validation (`validate`) +Pandas Copy-on-Write (CoW) enabled globally when using the rules engine. +In Pandas 2.x this is an opt-in feature, so it affects the whole process. + +Note: in Pandas 3.x, CoW is enabled by default, so this behavior will become standard once the project is upgraded. + Clone the repository and run: ```bash diff --git a/cdisc_rules_engine/rules_engine.py b/cdisc_rules_engine/rules_engine.py index d4edf3df1..8860ff8d5 100644 --- a/cdisc_rules_engine/rules_engine.py +++ b/cdisc_rules_engine/rules_engine.py @@ -2,6 +2,7 @@ from typing import List, Union from dateutil.parser._parser import ParserError import traceback +import pandas as pd from business_rules import export_rule_data from business_rules.engine import run @@ -33,6 +34,7 @@ DataServiceInterface, ) from cdisc_rules_engine.models.actions import COREActions +from cdisc_rules_engine.models.dataset import DaskDataset from cdisc_rules_engine.models.dataset.dataset_interface import DatasetInterface from cdisc_rules_engine.models.dataset_variable import DatasetVariable from cdisc_rules_engine.models.failed_validation_entity import FailedValidationEntity @@ -59,6 +61,8 @@ from cdisc_rules_engine.models.sdtm_dataset_metadata import SDTMDatasetMetadata from cdisc_rules_engine.enums.sensitivity import Sensitivity +pd.options.mode.copy_on_write = True + class RulesEngine: def __init__( @@ -375,9 +379,9 @@ def execute_rule( rule["conditions"], dataset.columns.to_list() ) rule_copy["conditions"].set_conditions(updated_conditions) - # Adding copy for now to avoid updating cached dataset - dataset = deepcopy(dataset) # preprocess dataset + if isinstance(dataset, DaskDataset): + dataset = deepcopy(dataset) dataset_preprocessor = DatasetPreprocessor( dataset, dataset_metadata, self.data_service, self.cache ) diff --git a/cdisc_rules_engine/services/cache/in_memory_cache_service.py b/cdisc_rules_engine/services/cache/in_memory_cache_service.py index cdc79b01c..9f36280ab 100644 --- a/cdisc_rules_engine/services/cache/in_memory_cache_service.py +++ b/cdisc_rules_engine/services/cache/in_memory_cache_service.py @@ -5,7 +5,7 @@ from cdisc_rules_engine.interfaces import ( CacheServiceInterface, ) -from cdisc_rules_engine.models.dataset import DatasetInterface +from cdisc_rules_engine.models.dataset import DatasetInterface, PandasDataset from cachetools import LRUCache import psutil from multiprocessing import Lock @@ -66,7 +66,10 @@ def add_dataset(self, cache_key, data): self.dataset_cache[cache_key] = data def get_dataset(self, cache_key): - return self.dataset_cache.get(cache_key, None) + cached = self.dataset_cache.get(cache_key) + if type(cached) is PandasDataset: + return PandasDataset(cached.data.copy(deep=False)) + return cached def add_batch( self, @@ -82,27 +85,30 @@ def add_batch( self.add(prefix + cache_key, item) def get(self, cache_key): - return self.cache.get(cache_key, None) + cached = self.cache.get(cache_key) + if type(cached) is PandasDataset: + return PandasDataset(cached.data.copy(deep=False)) + return cached def get_all(self, cache_keys: List[str]): - return [self.cache.get(key) for key in cache_keys] + return [self.get(key) for key in cache_keys] def get_all_by_prefix(self, prefix): items = [] for key in self.cache: if key.startswith(prefix): - items.append(self.cache[key]) + items.append(self.get(key)) return items def dataset_keys(self): return self.dataset_cache.keys() def filter_cache(self, prefix: str) -> dict: - return {k: self.cache[k] for k in self.cache.keys() if k.startswith(prefix)} + return {k: self.get(k) for k in self.cache.keys() if k.startswith(prefix)} def get_by_regex(self, regex: str) -> dict: regex = regex.replace("*", ".*") - return {k: self.cache[k] for k in self.cache.keys() if re.search(regex, k)} + return {k: self.get(k) for k in self.cache.keys() if re.search(regex, k)} def exists(self, cache_key): return cache_key in self.cache diff --git a/tests/unit/test_services/test_cache/test_immutable_cache.py b/tests/unit/test_services/test_cache/test_immutable_cache.py new file mode 100644 index 000000000..1846fe135 --- /dev/null +++ b/tests/unit/test_services/test_cache/test_immutable_cache.py @@ -0,0 +1,224 @@ +import numpy as np +import pandas as pd +import pytest + +from cdisc_rules_engine.models.dataset.pandas_dataset import PandasDataset +from cdisc_rules_engine.services.cache.in_memory_cache_service import ( + InMemoryCacheService, +) + + +@pytest.fixture(autouse=True) +def reset_singleton(): + InMemoryCacheService._instance = None + yield + InMemoryCacheService._instance = None + + +@pytest.fixture +def cache(): + return InMemoryCacheService() + + +@pytest.fixture +def sample_dataset(): + return PandasDataset(pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})) + + +class TestGet: + def test_returns_new_wrapper_not_cached_object(self, cache, sample_dataset): + cache.add("x", sample_dataset) + result = cache.get("x") + assert result is not cache.cache["x"] + assert result.data is not cache.cache["x"].data + + def test_cow_does_not_modify_cache_on_write(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add("x", sample_dataset) + retrieved = cache.get("x") + retrieved.data.loc[0, "A"] = 999 + assert cache.cache["x"].data.loc[0, "A"] == 1 + + def test_shares_memory_before_write(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add("x", sample_dataset) + retrieved = cache.get("x") + assert np.shares_memory(retrieved.data["A"], cache.cache["x"].data["A"]) + + def test_add_rows_does_not_affect_cache(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add("x", sample_dataset) + retrieved = cache.get("x") + retrieved.data = pd.concat( + [retrieved.data, pd.DataFrame({"A": [999], "B": [999]})], + ignore_index=True, + ) + assert len(cache.cache["x"].data) == 3 + assert len(retrieved.data) == 4 + + def test_drop_rows_does_not_affect_cache(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add("x", sample_dataset) + retrieved = cache.get("x") + retrieved.data = retrieved.data.drop(index=0).reset_index(drop=True) + assert len(cache.cache["x"].data) == 3 + assert len(retrieved.data) == 2 + + def test_filter_rows_does_not_affect_cache(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add("x", sample_dataset) + retrieved = cache.get("x") + retrieved.data = retrieved.data[retrieved.data["A"] > 1].reset_index(drop=True) + assert len(cache.cache["x"].data) == 3 + assert cache.cache["x"].data["A"].tolist() == [1, 2, 3] + + def test_multiple_gets_are_independent(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add("x", sample_dataset) + first = cache.get("x") + second = cache.get("x") + first.data = first.data.drop(index=0).reset_index(drop=True) + assert len(second.data) == 3 + assert len(cache.cache["x"].data) == 3 + + def test_non_dataset_returns_as_is(self, cache): + cache.add("key", {"some": "dict"}) + assert cache.get("key") == {"some": "dict"} + + def test_object_dtype_nested_mutation_affects_cache(self, cache): + """CoW can't protect in nested mutations""" + df = pd.DataFrame({"A": [[1], [2], [3]]}) + cache.add("x", PandasDataset(df)) + retrieved = cache.get("x") + retrieved.data.loc[0, "A"].append(999) + assert cache.cache["x"].data.loc[0, "A"] == [1, 999] + + +class TestGetDataset: + def test_returns_new_wrapper_not_cached_object(self, cache, sample_dataset): + cache.add_dataset("x", sample_dataset) + result = cache.get_dataset("x") + assert result is not cache.dataset_cache["x"] + assert result.data is not cache.dataset_cache["x"].data + + def test_cow_does_not_modify_cache_on_write(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add_dataset("x", sample_dataset) + retrieved = cache.get_dataset("x") + retrieved.data.loc[0, "A"] = 999 + assert cache.dataset_cache["x"].data.loc[0, "A"] == 1 + + def test_add_rows_does_not_affect_cache(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add_dataset("x", sample_dataset) + retrieved = cache.get_dataset("x") + retrieved.data = pd.concat( + [retrieved.data, pd.DataFrame({"A": [999], "B": [999]})], + ignore_index=True, + ) + assert len(cache.dataset_cache["x"].data) == 3 + assert len(retrieved.data) == 4 + + def test_drop_rows_does_not_affect_cache(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add_dataset("x", sample_dataset) + retrieved = cache.get_dataset("x") + retrieved.data = retrieved.data.drop(index=0).reset_index(drop=True) + assert len(cache.dataset_cache["x"].data) == 3 + assert len(retrieved.data) == 2 + + +class TestGetAll: + def test_returns_new_wrappers(self, cache, sample_dataset): + cache.add("x", sample_dataset) + cache.add("y", sample_dataset) + results = cache.get_all(["x", "y"]) + assert all(r is not cache.cache["x"] for r in results) + assert all(r.data is not cache.cache["x"].data for r in results) + + def test_results_are_independent(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add("x", sample_dataset) + cache.add("y", sample_dataset) + first, second = cache.get_all(["x", "y"]) + first.data = first.data.drop(index=0).reset_index(drop=True) + assert len(second.data) == 3 + assert len(cache.cache["x"].data) == 3 + + def test_cow_does_not_modify_cache_on_write(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add("x", sample_dataset) + results = cache.get_all(["x"]) + results[0].data.loc[0, "A"] = 999 + assert cache.cache["x"].data.loc[0, "A"] == 1 + + def test_missing_key_returns_none(self, cache): + assert cache.get_all(["missing"]) == [None] + + +class TestGetAllByPrefix: + def test_returns_only_matching_keys(self, cache, sample_dataset): + cache.add("ds/ae", sample_dataset) + cache.add("ds/lb", sample_dataset) + cache.add("other/ae", sample_dataset) + results = cache.get_all_by_prefix("ds/") + assert len(results) == 2 + + def test_returns_new_wrappers(self, cache, sample_dataset): + cache.add("ds/ae", sample_dataset) + results = cache.get_all_by_prefix("ds/") + assert results[0] is not cache.cache["ds/ae"] + assert results[0].data is not cache.cache["ds/ae"].data + + def test_cow_does_not_modify_cache_on_write(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add("ds/ae", sample_dataset) + results = cache.get_all_by_prefix("ds/") + results[0].data.loc[0, "A"] = 999 + assert cache.cache["ds/ae"].data.loc[0, "A"] == 1 + + def test_drop_rows_does_not_affect_cache(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add("ds/ae", sample_dataset) + results = cache.get_all_by_prefix("ds/") + results[0].data = results[0].data.drop(index=0).reset_index(drop=True) + assert len(cache.cache["ds/ae"].data) == 3 + + def test_no_match_returns_empty(self, cache, sample_dataset): + cache.add("ds/ae", sample_dataset) + assert cache.get_all_by_prefix("other/") == [] + + +class TestGetByRegex: + def test_returns_matching_keys(self, cache, sample_dataset): + cache.add("ae_data", sample_dataset) + cache.add("lb_data", sample_dataset) + cache.add("ae_meta", sample_dataset) + result = cache.get_by_regex("ae_*") + assert set(result.keys()) == {"ae_data", "ae_meta"} + + def test_returns_new_wrappers(self, cache, sample_dataset): + cache.add("ae_data", sample_dataset) + result = cache.get_by_regex("ae_*") + assert result["ae_data"] is not cache.cache["ae_data"] + assert result["ae_data"].data is not cache.cache["ae_data"].data + + def test_cow_does_not_modify_cache_on_write(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add("ae_data", sample_dataset) + result = cache.get_by_regex("ae_*") + result["ae_data"].data.loc[0, "A"] = 999 + assert cache.cache["ae_data"].data.loc[0, "A"] == 1 + + def test_drop_rows_does_not_affect_cache(self, cache, sample_dataset): + pd.options.mode.copy_on_write = True + cache.add("ae_data", sample_dataset) + result = cache.get_by_regex("ae_*") + result["ae_data"].data = ( + result["ae_data"].data.drop(index=0).reset_index(drop=True) + ) + assert len(cache.cache["ae_data"].data) == 3 + + def test_no_match_returns_empty_dict(self, cache, sample_dataset): + cache.add("ae_data", sample_dataset) + assert cache.get_by_regex("lb_*") == {}