diff --git a/examples/1-Getting-Started.ipynb b/examples/1-Getting-Started.ipynb index ef44aca..75c669b 100644 --- a/examples/1-Getting-Started.ipynb +++ b/examples/1-Getting-Started.ipynb @@ -298,8 +298,6 @@ "```\n", "\n", "\n", - "\n", - "\n", "#### Example of Seed Specification\n", "\n", "When using `seed` and specifying the inclusion of specific `ErrorType` objects, the seed must be set in the constructor of these objects manually if fixing of the random generator is desired.\n", @@ -323,6 +321,22 @@ " error_mechanisms_to_exclude=[error_mechanism.EAR()],\n", " seed=42\n", " )\n", + "```\n", + "\n", + "#### Config Extraction\n", + "\n", + "The user can print each error model used in the high-level API as follows. Note that in the example, the two indices on the config object are the affected column name and the error model number, i.e. `config.columns[col-name][error-model-number]`.\n", + "\n", + "```python\n", + "from tab_err.api import high_level\n", + "\n", + "corrupted_data, error_mask, config = high_level.create_errors_with_config(\n", + " data=data,\n", + " error_rate=0.5\n", + ")\n", + "\n", + "# If \"name\" were a column in data,\n", + "print(config.columns['name'][0])\n", "```" ] } diff --git a/tab_err/_error_model.py b/tab_err/_error_model.py index 4908e1c..6f14e4c 100644 --- a/tab_err/_error_model.py +++ b/tab_err/_error_model.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING from tab_err.api import low_level +from tab_err.error_mechanism import EAR if TYPE_CHECKING: import pandas as pd @@ -25,6 +26,29 @@ class ErrorModel: error_type: ErrorType error_rate: float + def __repr__(self) -> str: + """Unambiguous representation, evaluating the mechanism's repr but only the type's name.""" + if self.error_mechanism.__class__ == EAR: + return ( + f"{self.__class__.__name__}(" + f"error_mechanism={self.error_mechanism.__class__.__name__}(condition_to_column='{self.error_mechanism.condition_to_column}'), " + f"error_type={self.error_type.__class__.__name__}, " 
+ f"error_rate={self.error_rate})" + ) + return ( + f"{self.__class__.__name__}(" + f"error_mechanism={self.error_mechanism.__class__.__name__}, " + f"error_type={self.error_type.__class__.__name__}, " + f"error_rate={self.error_rate})" + ) + + def __str__(self) -> str: + """Readable representation for end-users.""" + # Assumes error_rate is a float like 0.05. Displays as 5.0%. + if self.error_mechanism.__class__ == EAR: + return f"ErrorModel: {self.error_rate:.1%} '{self.error_type.__class__.__name__}' errors via {self.error_mechanism.__class__.__name__} conditioning on column '{self.error_mechanism.condition_to_column}'" # Noqa: E501 + return f"ErrorModel: {self.error_rate:.1%} '{self.error_type.__class__.__name__}' errors via {self.error_mechanism.__class__.__name__}" + def apply(self: ErrorModel, data: pd.DataFrame, column: str | int) -> tuple[pd.DataFrame, pd.DataFrame]: """Applies the defined ErrorModel to the given column of a pandas DataFrame. diff --git a/tab_err/api/high_level.py b/tab_err/api/high_level.py index 7570981..2f68442 100644 --- a/tab_err/api/high_level.py +++ b/tab_err/api/high_level.py @@ -200,6 +200,56 @@ def create_errors( # noqa: PLR0913 ) -> tuple[pd.DataFrame, pd.DataFrame]: """Creates errors in a given DataFrame, at a rate of *approximately* max_error_rate. + Description: + Functionally identical to `create_errors_with_config`, but allows for terser usage when detailed configuration is not needed. + + Args: + data (pd.DataFrame): The pandas DataFrame to create errors in. + error_rate (float): The maximum error rate to be introduced to each column in the DataFrame. + n_error_models_per_column (int, optional): The number of valid error models to apply to each column. Defaults to 1. + error_types_to_include (list[ErrorType] | None, optional): A list of the error types to be included when building error models. Defaults to None. 
+ error_types_to_exclude (list[ErrorType] | None, optional): A list of the error types to be excluded when building error models. Defaults to None. + When both error_types_to_include and error_types_to_exclude are None, the maximum number of default error types will be used. + At least one must be None or an error will occur. + error_mechanisms_to_include (list[ErrorMechanism] | None = None): A list of the error mechanisms to be included when building error models. + Defaults to None. + error_mechanisms_to_exclude (list[ErrorMechanism] | None = None): A list of the error mechanisms to be excluded when building error models. + Defaults to None. + seed (int | None, optional): Random seed. Defaults to None. + + Returns: + tuple[pd.DataFrame, pd.DataFrame]: + - The first element is a copy of 'data' with errors. + - The second element is the associated error mask. + """ + return create_errors_with_config( + data=data, + error_rate=error_rate, + n_error_models_per_column=n_error_models_per_column, + error_types_to_include=error_types_to_include, + error_types_to_exclude=error_types_to_exclude, + error_mechanisms_to_include=error_mechanisms_to_include, + error_mechanisms_to_exclude=error_mechanisms_to_exclude, + seed=seed, + )[:2] # Drop the config from the return for this function. + + +def create_errors_with_config( # noqa: PLR0913 + data: pd.DataFrame, + error_rate: float, + n_error_models_per_column: int = 1, + error_types_to_include: list[ErrorType] | None = None, + error_types_to_exclude: list[ErrorType] | None = None, + error_mechanisms_to_include: list[ErrorMechanism] | None = None, + error_mechanisms_to_exclude: list[ErrorMechanism] | None = None, + seed: int | None = None, +) -> tuple[pd.DataFrame, pd.DataFrame, MidLevelConfig]: + """Creates errors in a given DataFrame, at a rate of *approximately* max_error_rate. 
+ + Description: + Builds a configuration of error models to apply to the DataFrame based on the input parameters and then applies the error models to the DataFrame. + Returns the dirty DataFrame, the error mask, and the configuration for reproducibility. + Args: data (pd.DataFrame): The pandas DataFrame to create errors in. error_rate (float): The maximum error rate to be introduced to each column in the DataFrame. @@ -266,4 +316,4 @@ def create_errors( # noqa: PLR0913 # Create Errors & Return dirty_data, error_mask = mid_level.create_errors(data_copy, config) - return dirty_data, error_mask + return dirty_data, error_mask, config