diff --git a/src/openhound/core/app.py b/src/openhound/core/app.py index 971aecf..526a49c 100644 --- a/src/openhound/core/app.py +++ b/src/openhound/core/app.py @@ -21,7 +21,7 @@ from openhound.core.models.extension import Extension from openhound.core.preproc import PreProcContext, PreProcessor from openhound.core.progress import Progress -from openhound.core.resources import safe_resource_wrapper +from openhound.core.resources import safe_defer_wrapper, safe_resource_wrapper logger = logging.getLogger(__name__) @@ -292,6 +292,12 @@ def wrapper( return decorator + def defer(self, func: Callable) -> Callable: + """Decorator to register a DLT defer with added exception handling.""" + logger.debug(f"Registering defer for {self.name}") + safe_func = safe_defer_wrapper(func) + return dlt.defer(safe_func) + def transformer( self, *dlt_args, diff --git a/src/openhound/core/app.pyi b/src/openhound/core/app.pyi index 9f61cd5..1720a40 100644 --- a/src/openhound/core/app.pyi +++ b/src/openhound/core/app.pyi @@ -99,6 +99,7 @@ class OpenHound: parallelized: bool = False, _impl_cls: type[DltSource] = DltSource, ) -> Any: ... + def defer(self, func: Callable[..., Any]) -> Callable[..., Callable[[], Any]]: ... def resource( self, data: Optional[Any] = None, diff --git a/src/openhound/core/resources.py b/src/openhound/core/resources.py index 344f7c5..900eb11 100644 --- a/src/openhound/core/resources.py +++ b/src/openhound/core/resources.py @@ -1,13 +1,35 @@ -import logging - - import functools import inspect +import logging from typing import Callable logger = logging.getLogger(__name__) +def safe_defer_wrapper(func: Callable) -> Callable: + """Wrap a DLT defer to catch and log exceptions without stopping the entire pipeline. + + Args: + func: The defer function + + Returns: + Wrapped function that catches exceptions and continues (if possible of course) + """ + + @functools.wraps(func) + def sync_wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logger.error( + f"Error executing DLT defer: {e}", + extra={"phase": "defer_execution"}, + ) + return [] + + return sync_wrapper + + def safe_resource_wrapper(func: Callable, resource_name: str) -> Callable: """Wrap a DLT resource to catch and log exceptions without stopping the entire pipeline. Can either be sync or async generator function. @@ -36,8 +58,6 @@ def sync_wrapper(*args, **kwargs): return if inspect.isgenerator(gen): - # Note: Don't use while item: := next(gen, None) because this will stop the full iterator - # if the resource yields any empty value while True: try: item = next(gen) @@ -53,7 +73,6 @@ def sync_wrapper(*args, **kwargs): }, ) continue - else: yield gen @@ -74,8 +93,6 @@ async def async_wrapper(*args, **kwargs): return if inspect.isasyncgen(gen): - # Note: Don't use while item: := next(gen, None) because this will stop the full iterator - # if the resource yields any empty value while True: try: item = await gen.__anext__() @@ -90,8 +107,8 @@ async def async_wrapper(*args, **kwargs): "phase": "resource_iteration", }, ) - continue + else: try: result = await gen diff --git a/tests/test_safe_dlt_wrappers.py b/tests/test_safe_dlt_wrappers.py new file mode 100644 index 0000000..f401666 --- /dev/null +++ b/tests/test_safe_dlt_wrappers.py @@ -0,0 +1,89 @@ +import logging + +from pydantic import BaseModel + +from openhound.core.app import OpenHound +from openhound.core.collect import Collector +from openhound.core.progress import Progress + + +class Computer(BaseModel): + id: int + hostname: str + + +class User(BaseModel): + id: int + email: str + + +class UserDetails(User): + office: str + + +def test_dlt_wrapper_pipeline_continues( + caplog, + monkeypatch, + tmp_path, +): + monkeypatch.setenv("DLT_DATA_DIR", str(tmp_path / ".dlt")) + monkeypatch.setattr( + "openhound.core.collect.logger_override.set_handler", lambda name: None + ) + caplog.set_level(logging.ERROR, logger="openhound.core.resources") + + app = OpenHound("safe_wrapper_test", "TEST") + + @app.resource(name="computers", columns=Computer) + def computers(): + yield {"id": 1, "hostname": "DESKTOP-12345"} + yield {"id": 2, "hostname": "DESKTOP-54321"} + raise RuntimeError("resource failed after valid rows") + + @app.transformer(name="users", columns=User) + def users(computer): + if computer["id"] == 1: + yield {"id": 10, "email": "someuser@example.org"} + raise RuntimeError("transformer failed after valid row") + + yield {"id": 20, "email": "someuser2@example.org"} + + @app.transformer(name="user_details", columns=UserDetails) + def user_details(user): + + @app.defer + def deferred_child(user_input): + if user_input["id"] == 1: + raise RuntimeError("defer failed for parent") + + return {"id": 20, "email": "someuser2@example.org", "office": "Amsterdam"} + + yield deferred_child(user) + + @app.source(name="safe_wrapper_test", max_table_nesting=0) + def source(): + computers_resource = computers() + return ( + computers_resource, + computers_resource | users(), + computers_resource | user_details(), + ) + + collector = Collector( + name="safe_wrapper_test", + output_path=tmp_path / "output", + progress=Progress.log, + ) + + load_info = collector.run(source()) + + assert load_info is not None + + messages = [record.getMessage() for record in caplog.records] + phases = {getattr(record, "phase", None) for record in caplog.records} + + assert any("resource failed after valid rows" in message for message in messages) + assert any("transformer failed after valid row" in message for message in messages) + assert any("defer failed for parent" in message for message in messages) + assert "resource_iteration" in phases + assert "defer_execution" in phases