diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..3040967 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,73 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +OasisDataManager (`oasis-data-manager` on PyPI) is a Python data management library for the Oasis Loss Modelling Framework. It provides storage backend abstractions (local, AWS S3, Azure Blob), DataFrame reader abstractions (Pandas, Dask, PyArrow), and a complex data pipeline layer that combines fetching, reading, filtering, and adjusting data. + +## Build & Development Commands + +```bash +# Install dependencies +pip install pip-tools +pip install -r requirements.txt + +# Run all tests (requires Docker services) +docker compose up -d # Start LocalStack (S3) and Azurite (Azure Blob) +pytest + +# Run a single test file +pytest tests/filestorage/test_local.py + +# Run a single test +pytest tests/filestorage/test_local.py::TestLocalStorage::test_put_get -v + +# Linting +flake8 oasis_data_manager/ + +# Type checking +mypy oasis_data_manager/ +``` + +## Architecture + +Three main modules, each following a base-class/backend pattern: + +### FileStore (`oasis_data_manager/filestore/`) +Storage abstraction layer built on `fsspec`. `BaseStorage` defines the interface (`get`, `put`, `exists`, `extract`, `compress`, caching). Backends: `LocalStorage`, `AwsS3Storage`, `AzureABFSStorage`. Configuration via `StorageConfig` TypedDict with `storage_class` (dotted path string) and `options` dict. Factory: `get_storage_from_config()`. + +### DataFrame Reader (`oasis_data_manager/df_reader/`) +Read CSV/Parquet files with optional filtering and SQL. `OasisReader` base class with `read_csv`, `read_parquet`, `filter`, `sql`, `as_pandas`. Backends: `OasisPandasReader`, `OasisDaskReader` (adds SQL via dask-sql), `OasisPyarrowReader`. Auto-detects format from file extension. Factory: `get_df_reader()`. 
+ +### Complex Data (`oasis_data_manager/complex/`) +High-level pipeline combining fetch, read, filter, and adjust. `ComplexData` orchestrates the flow: `fetch()` retrieves remote data, `get_df_reader()` wraps it in an OasisReader, `run()` executes the pipeline. `Adjustment` is a base class for pandas DataFrame transformations applied in sequence. + +### Shared Utilities +- `config.py`: `load_class(path)` dynamically imports classes from dotted path strings. +- `errors/`: `OasisException` base exception. + +## Code Style + +- Max line length: 150 (flake8) +- flake8 ignores: E501 (line too long, so flake8 does not actually report lines over the 150 limit), E402 (module-level import not at top of file) +- mypy: `follow_imports = skip`, `ignore_missing_imports = true` +- isort is in dependencies but not actively enforced + +## Testing + +- Tests mirror source structure: `tests/filestorage/` (covers `oasis_data_manager/filestore/`; note the different directory name), `tests/df_reader/`, `tests/complex/` +- Storage tests use Docker services: LocalStack (port 4566) for S3, Azurite (port 10000) for Azure +- `tests/filestorage/test_general.py` uses hypothesis for property-based testing across all backends +- Storage test fixtures are context managers (`aws_s3_storage()`, `azure_abfs_storage()`, `local_storage()`) +- HTTP mocking uses `respx` + +## Version + +Version is stored in `oasis_data_manager/__init__.py` as `__version__`. The CI `version.yml` workflow updates it automatically. 
+ +## Branch & PR Conventions + +- Main branch: `develop` +- Release branches: `release/x.y.z` +- PRs require 2+ reviewers and must link to an issue diff --git a/oasis_data_manager/complex/complex.py b/oasis_data_manager/complex/complex.py index 266b479..12acdfb 100644 --- a/oasis_data_manager/complex/complex.py +++ b/oasis_data_manager/complex/complex.py @@ -87,8 +87,7 @@ def to_reader(self, fetch_result) -> OasisReader: def run(self): if self.fetch_required and self.filename: - filename_or_url = self.filename if self.filename else self.url - extension = pathlib.Path(filename_or_url).suffix + extension = pathlib.Path(self.filename).suffix self.fetch_required = extension not in [".parquet", ".pq", ".csv"] fetch_result = None @@ -123,10 +122,6 @@ def fetch(self): class RestComplexData(ComplexData): exceptions = ( httpx.RequestError, - httpx.TimeoutException, - httpx.ReadTimeout, - httpx.ConnectTimeout, - httpx.ConnectError, httpcore.ReadTimeout, httpcore.ConnectTimeout, httpcore.ConnectError, diff --git a/oasis_data_manager/config.py b/oasis_data_manager/config.py index 27b2f74..b7aff1e 100644 --- a/oasis_data_manager/config.py +++ b/oasis_data_manager/config.py @@ -14,7 +14,7 @@ def load_class(path, base=None): module = importlib.import_module(module_path) cls = getattr(module, cls_name) - if base and cls is not base and base not in cls.__bases__: + if base and not issubclass(cls, base): raise ConfigError(f"'{cls.__name__}' does not extend '{base.__name__}'") return cls diff --git a/oasis_data_manager/df_reader/backends/dask.py b/oasis_data_manager/df_reader/backends/dask.py index 5fe51e7..7292855 100644 --- a/oasis_data_manager/df_reader/backends/dask.py +++ b/oasis_data_manager/df_reader/backends/dask.py @@ -76,7 +76,6 @@ def apply_geo(self, shape_filename_path, *args, drop_geo=True, **kwargs): return self.copy_with_df(df) def apply_sql(self, sql): - df = self.df.copy() try: # Initially this was the filename, but some filenames are invalid for the table, # is it 
ok to call it the same name all the time? Mapped to DaskDataTable in case @@ -84,7 +83,10 @@ def apply_sql(self, sql): self.sql_context.create_table("DaskDataTable", self.df) formatted_sql = sql.replace(self.sql_table_name, "DaskDataTable") - self.pre_sql_columns.extend(df.columns) + # Combine columns from join() tables with current df columns for case restoration + col_map = {} + for col in list(self.pre_sql_columns) + list(self.df.columns): + col_map.setdefault(col.lower(), col) # dask expects the columns to be lower case, which won't match some data df = self.sql_context.sql( @@ -93,17 +95,7 @@ def apply_sql(self, sql): ) # which means we then need to map the columns back to the original # and allow for any aggregations to be retained - validated_columns = [] - for v in df.columns: - pre = False - for x in self.pre_sql_columns: - if v.lower() == x.lower(): - validated_columns.append(x) - pre = True - - if not pre: - validated_columns.append(v) - df.columns = validated_columns + df.columns = [col_map.get(v.lower(), v) for v in df.columns] return self.copy_with_df(df) except ParsingException: diff --git a/oasis_data_manager/df_reader/backends/pandas.py b/oasis_data_manager/df_reader/backends/pandas.py index b836416..109252a 100644 --- a/oasis_data_manager/df_reader/backends/pandas.py +++ b/oasis_data_manager/df_reader/backends/pandas.py @@ -14,43 +14,25 @@ class OasisPandasReader(OasisReader): - def read_csv(self, *args, **kwargs): - if isinstance(self.filename_or_buffer, str): - if self.filename_or_buffer.startswith( - "http://" - ) or self.filename_or_buffer.startswith("https://"): - self.df = pd.read_csv(self.filename_or_buffer, *args, **kwargs) - else: - _, uri = self.storage.get_storage_url( - self.filename_or_buffer, encode_params=False - ) - self.df = pd.read_csv( - uri, - *args, - **kwargs, - storage_options=self.storage.get_fsspec_storage_options(), - ) + def _read_with(self, read_fn, *args, **kwargs): + if isinstance(self.filename_or_buffer, str) and not 
self.filename_or_buffer.startswith(("http://", "https://")): + _, uri = self.storage.get_storage_url( + self.filename_or_buffer, encode_params=False + ) + self.df = read_fn( + uri, + *args, + **kwargs, + storage_options=self.storage.get_fsspec_storage_options(), + ) else: - self.df = pd.read_csv(self.filename_or_buffer, *args, **kwargs) + self.df = read_fn(self.filename_or_buffer, *args, **kwargs) + + def read_csv(self, *args, **kwargs): + self._read_with(pd.read_csv, *args, **kwargs) def read_parquet(self, *args, **kwargs): - if isinstance(self.filename_or_buffer, str): - if self.filename_or_buffer.startswith( - "http://" - ) or self.filename_or_buffer.startswith("https://"): - self.df = pd.read_parquet(self.filename_or_buffer, *args, **kwargs) - else: - _, uri = self.storage.get_storage_url( - self.filename_or_buffer, encode_params=False - ) - self.df = pd.read_parquet( - uri, - *args, - **kwargs, - storage_options=self.storage.get_fsspec_storage_options(), - ) - else: - self.df = pd.read_parquet(self.filename_or_buffer, *args, **kwargs) + self._read_with(pd.read_parquet, *args, **kwargs) def apply_geo(self, shape_filename_path, *args, drop_geo=True, **kwargs): """ diff --git a/oasis_data_manager/df_reader/backends/pyarrow.py b/oasis_data_manager/df_reader/backends/pyarrow.py index e21eb4f..8bbc36f 100644 --- a/oasis_data_manager/df_reader/backends/pyarrow.py +++ b/oasis_data_manager/df_reader/backends/pyarrow.py @@ -53,21 +53,13 @@ def list_of_lists(lol): else: ds_filter = None - if isinstance(self.filename_or_buffer, str): - if self.filename_or_buffer.startswith( - "http://" - ) or self.filename_or_buffer.startswith("https://"): - dataset = ds.dataset(self.filename_or_buffer, partitioning='hive') - self.df = dataset.to_table(filter=ds_filter).to_pandas() - else: - _, uri = self.storage.get_storage_url( - self.filename_or_buffer, encode_params=False - ) - - uri = uri.replace('file://', '') - dataset = ds.dataset(uri, partitioning='hive') - self.df = 
dataset.to_table(filter=ds_filter).to_pandas() - + if isinstance(self.filename_or_buffer, str) and not self.filename_or_buffer.startswith(("http://", "https://")): + _, uri = self.storage.get_storage_url( + self.filename_or_buffer, encode_params=False + ) + source = uri.replace('file://', '') else: - dataset = ds.dataset(self.filename_or_buffer, partitioning='hive') - self.df = dataset.to_table(filter=ds_filter).to_pandas() + source = self.filename_or_buffer + + dataset = ds.dataset(source, partitioning='hive') + self.df = dataset.to_table(filter=ds_filter).to_pandas() diff --git a/oasis_data_manager/df_reader/config.py b/oasis_data_manager/df_reader/config.py index 952bf9e..b63a03e 100644 --- a/oasis_data_manager/df_reader/config.py +++ b/oasis_data_manager/df_reader/config.py @@ -1,14 +1,9 @@ import json -import sys from copy import deepcopy from pathlib import Path +from typing import Any, Dict, TypedDict, Union -if sys.version_info >= (3, 8): - from typing import Any, Dict, TypedDict, Union - from typing_extensions import NotRequired -else: - from typing import Any, Dict, Union - from typing_extensions import NotRequired, TypedDict +from typing_extensions import NotRequired from ..config import ConfigError, load_class from ..filestore.backends.local import LocalStorage diff --git a/oasis_data_manager/filestore/backends/aws_s3.py b/oasis_data_manager/filestore/backends/aws_s3.py index 6c43f30..55c873f 100755 --- a/oasis_data_manager/filestore/backends/aws_s3.py +++ b/oasis_data_manager/filestore/backends/aws_s3.py @@ -119,12 +119,7 @@ def _set_lifecycle(self, ): self.gzip_content_types = gzip_content_types set_aws_log_level(self.aws_log_level) - root_dir = os.path.join(self.bucket_name or "", root_dir) - if root_dir.startswith(os.path.sep): - root_dir = root_dir[1:] - if root_dir.endswith(os.path.sep): - root_dir = root_dir[:-1] - + root_dir = self._normalize_root_dir(self.bucket_name, root_dir) super(AwsS3Storage, self).__init__(root_dir=root_dir, **kwargs) 
@property diff --git a/oasis_data_manager/filestore/backends/azure_abfs.py b/oasis_data_manager/filestore/backends/azure_abfs.py index 8a16d10..68ac50d 100755 --- a/oasis_data_manager/filestore/backends/azure_abfs.py +++ b/oasis_data_manager/filestore/backends/azure_abfs.py @@ -65,12 +65,7 @@ def __init__( self.endpoint_url = endpoint_url set_azure_log_level(self.azure_log_level) - root_dir = os.path.join(self.azure_container or "", root_dir or location or "") - if root_dir.startswith(os.path.sep): - root_dir = root_dir[1:] - if root_dir.endswith(os.path.sep): - root_dir = root_dir[:-1] - + root_dir = self._normalize_root_dir(self.azure_container, root_dir or location or "") super(AzureABFSStorage, self).__init__(root_dir=root_dir, **kwargs) @property diff --git a/oasis_data_manager/filestore/backends/base.py b/oasis_data_manager/filestore/backends/base.py index a6c1e96..fa9ca8b 100755 --- a/oasis_data_manager/filestore/backends/base.py +++ b/oasis_data_manager/filestore/backends/base.py @@ -44,23 +44,20 @@ def _join(self, path): return res - def exists(self, path): + def _safe_check(self, method_name, path): try: - return super().exists(path) + return getattr(super(), method_name)(path) except FileNotFoundError: return False + def exists(self, path): + return self._safe_check('exists', path) + def isfile(self, path): - try: - return super().isfile(path) - except FileNotFoundError: - return False + return self._safe_check('isfile', path) def isdir(self, path): - try: - return super().isdir(path) - except FileNotFoundError: - return False + return self._safe_check('isdir', path) class BaseStorage(object): @@ -83,6 +80,11 @@ def __init__( self.logger = logger or logging.getLogger() self._fs: Optional[StrictRootDirFs] = None + @staticmethod + def _normalize_root_dir(container, root_dir): + result = os.path.join(container or "", root_dir) + return result.strip(os.path.sep) + def to_config(self) -> dict: return { "storage_class": 
f"{self.__module__}.{type(self).__name__}", @@ -217,14 +219,12 @@ def get_from_cache(self, reference, required=False, no_cache_target=None): raise OasisException("Error: caching disabled for this filesystem and no_cache_target not provided") Path(no_cache_target).parent.mkdir(parents=True, exist_ok=True) if self._is_valid_url(reference): - with urlopen(reference, timeout=30) as r: - data = r.read() - with open(no_cache_target, "wb") as f: - f.write(data) - logging.info("Get from URL: {}".format(reference)) + with urlopen(reference, timeout=30) as r, open(no_cache_target, "wb") as f: + shutil.copyfileobj(r, f) + self.logger.info("Get from URL: {}".format(reference)) else: self.fs.get(reference, no_cache_target, recursive=True) - logging.info("Get from Filestore: {}".format(reference)) + self.logger.info("Get from Filestore: {}".format(reference)) return no_cache_target # Caching enabled @@ -394,9 +394,9 @@ def delete_file(self, reference): """ if self.fs.isfile(reference): self.fs.delete(reference) - logging.info("Deleted Shared file: {}".format(reference)) + self.logger.info("Deleted Shared file: {}".format(reference)) else: - logging.info("Delete Error - Unknwon reference {}".format(reference)) + self.logger.info("Delete Error - Unknwon reference {}".format(reference)) def delete_dir(self, reference): """ @@ -407,12 +407,12 @@ def delete_dir(self, reference): """ if self.fs.isdir(reference): if Path("/") == Path(reference).resolve(): - logging.info("Delete Error - prevented media root deletion") + self.logger.info("Delete Error - prevented media root deletion") else: self.fs.delete(reference, recursive=True) - logging.info("Deleted shared dir: {}".format(reference)) + self.logger.info("Deleted shared dir: {}".format(reference)) else: - logging.info("Delete Error - Unknwon reference {}".format(reference)) + self.logger.info("Delete Error - Unknwon reference {}".format(reference)) def create_traceback(self, stdout, stderr, output_dir=""): traceback_file = 
self._get_unique_filename(LOG_FILE_SUFFIX) diff --git a/oasis_data_manager/filestore/config.py b/oasis_data_manager/filestore/config.py index d742bf4..8664634 100644 --- a/oasis_data_manager/filestore/config.py +++ b/oasis_data_manager/filestore/config.py @@ -1,13 +1,8 @@ import json import os -import sys - -if sys.version_info >= (3, 8): - from typing import Optional, Tuple, TypedDict, Union - from typing_extensions import NotRequired -else: - from typing import Optional, Tuple, Union - from typing_extensions import NotRequired, TypedDict +from typing import Optional, Tuple, TypedDict, Union + +from typing_extensions import NotRequired from oasis_data_manager.config import ConfigError, load_class from oasis_data_manager.filestore.backends.base import BaseStorage diff --git a/oasis_data_manager/filestore/filestore.py b/oasis_data_manager/filestore/filestore.py index 382f67a..a9113ca 100644 --- a/oasis_data_manager/filestore/filestore.py +++ b/oasis_data_manager/filestore/filestore.py @@ -1,5 +1,4 @@ import contextlib -import urllib.parse from urllib import parse import fsspec @@ -20,14 +19,14 @@ def split_s3_url(parts): } return ( - urllib.parse.urlunparse( + parse.urlunparse( ( - parts[0], - parts[1], - parts[2], - parts[3], + parts.scheme, + parts.netloc, + parts.path, + parts.params, "", - parts[5], + parts.fragment, ) ), params, @@ -39,7 +38,7 @@ def split_azure_url(parts): query = parse.parse_qs(parts.query) if "connection_string" in query: - connection_string = parts.get("connection_string")[0] + connection_string = query.get("connection_string")[0] else: if "endpoint" in query: connection_string += f"BlobEndpoint={ query.get('endpoint', [None])[0]};" @@ -60,14 +59,14 @@ def split_azure_url(parts): } return ( - urllib.parse.urlunparse( + parse.urlunparse( ( - parts[0], - parts[1], - parts[2], - parts[3], + parts.scheme, + parts.netloc, + parts.path, + parts.params, "", - parts[5], + parts.fragment, ) ), params, diff --git a/oasis_data_manager/filestore/log.py 
b/oasis_data_manager/filestore/log.py index 194d9c0..a3a14c7 100644 --- a/oasis_data_manager/filestore/log.py +++ b/oasis_data_manager/filestore/log.py @@ -1,23 +1,22 @@ import logging -def set_aws_log_level(log_level): - # Set log level for s3boto3 +def _parse_log_level(log_level): try: - LOG_LEVEL = getattr(logging, log_level.upper()) + return getattr(logging, log_level.upper()) except AttributeError: - LOG_LEVEL = logging.WARNING + return logging.WARNING + - logging.getLogger("boto3").setLevel(LOG_LEVEL) - logging.getLogger("botocore").setLevel(LOG_LEVEL) - logging.getLogger("nose").setLevel(LOG_LEVEL) - logging.getLogger("s3transfer").setLevel(LOG_LEVEL) - logging.getLogger("urllib3").setLevel(LOG_LEVEL) +def set_aws_log_level(log_level): + level = _parse_log_level(log_level) + logging.getLogger("boto3").setLevel(level) + logging.getLogger("botocore").setLevel(level) + logging.getLogger("nose").setLevel(level) + logging.getLogger("s3transfer").setLevel(level) + logging.getLogger("urllib3").setLevel(level) def set_azure_log_level(log_level): - try: - LOG_LEVEL = getattr(logging, log_level.upper()) - except AttributeError: - LOG_LEVEL = logging.WARNING - logging.getLogger("azure").setLevel(LOG_LEVEL) + level = _parse_log_level(log_level) + logging.getLogger("azure").setLevel(level)