diff --git a/hamilton/function_modifiers/delayed.py b/hamilton/function_modifiers/delayed.py index d22ef9231..2eb8725fe 100644 --- a/hamilton/function_modifiers/delayed.py +++ b/hamilton/function_modifiers/delayed.py @@ -169,7 +169,16 @@ def resolve(self, config: dict[str, Any], fn: Callable) -> NodeTransformLifecycl for key in self._optional_config: if key in config: kwargs[key] = config[key] - return self.decorate_with(**kwargs) + decorator = self.decorate_with(**kwargs) + + # NOTE: cases where `decorator` has no `validate` method should be caught by type checkers + # since `decorate_with` is typed as `Callable[..., NodeTransformLifecycle]`. The following + # check allows non-conforming functions to be used with `resolve` without immediately + # throwing an error, which may be undesirable. + if hasattr(decorator, "validate"): + decorator.validate(fn) + + return decorator class resolve_from_config(resolve): diff --git a/tests/function_modifiers/test_delayed.py b/tests/function_modifiers/test_delayed.py index 4dcf11818..b48fc92b9 100644 --- a/tests/function_modifiers/test_delayed.py +++ b/tests/function_modifiers/test_delayed.py @@ -17,6 +17,7 @@ from collections.abc import Callable +import pandas as pd import pytest from hamilton import settings @@ -24,6 +25,8 @@ ResolveAt, base, extract_columns, + extract_fields, + parameterize_sources, resolve, resolve_from_config, ) @@ -68,12 +71,17 @@ def test_extract_and_validate_params_unhappy(fn: Callable): def test_dynamic_resolves(): + # Note: we use an empty DataFrame for validation only. This test would fail at runtime + # if we actually tried to execute the DAG because there are no columns "a" or "b" to extract. + def fn() -> pd.DataFrame: + return pd.DataFrame() + decorator = resolve( when=ResolveAt.CONFIG_AVAILABLE, decorate_with=lambda cols_to_extract: extract_columns(*cols_to_extract), ) decorator_resolved = decorator.resolve( - {"cols_to_extract": ["a", "b"], **CONFIG_WITH_POWER_MODE_ENABLED}, fn=test_dynamic_resolves + {"cols_to_extract": ["a", "b"], **CONFIG_WITH_POWER_MODE_ENABLED}, fn=fn ) # This uses an internal component of extract_columns # We may want to add a little more comprehensive testing @@ -82,11 +90,15 @@ def test_dynamic_resolves(): def test_dynamic_resolve_with_configs(): + def fn() -> pd.DataFrame: + return pd.DataFrame() + decorator = resolve_from_config( decorate_with=lambda cols_to_extract: extract_columns(*cols_to_extract), ) decorator_resolved = decorator.resolve( - {"cols_to_extract": ["a", "b"], **CONFIG_WITH_POWER_MODE_ENABLED}, fn=test_dynamic_resolves + {"cols_to_extract": ["a", "b"], **CONFIG_WITH_POWER_MODE_ENABLED}, + fn=fn, ) # This uses an internal component of extract_columns # We may want to add a little more comprehensive testing @@ -94,19 +106,16 @@ def test_dynamic_resolve_with_configs(): assert decorator_resolved.columns == ("a", "b") -def test_dynamic_fails_without_power_mode_fails(): +def test_dynamic_resolve_without_power_mode_fails(): + def fn() -> pd.DataFrame: + return pd.DataFrame() + decorator = resolve( when=ResolveAt.CONFIG_AVAILABLE, decorate_with=lambda cols_to_extract: extract_columns(*cols_to_extract), ) with pytest.raises(base.InvalidDecoratorException): - decorator_resolved = decorator.resolve( - CONFIG_WITH_POWER_MODE_DISABLED, fn=test_dynamic_fails_without_power_mode_fails - ) - # This uses an internal component of extract_columns - # We may want to add a little more comprehensive testing - # But for now this will work - assert decorator_resolved.columns == ("a", "b") + decorator.resolve(CONFIG_WITH_POWER_MODE_DISABLED, fn=fn) def test_config_derivation(): @@ -123,6 +132,9 @@ def test_config_derivation(): def test_delayed_with_optional(): + def fn() -> pd.DataFrame: + return pd.DataFrame() + decorator = resolve( when=ResolveAt.CONFIG_AVAILABLE, decorate_with=lambda cols_to_extract, some_cols_you_might_want_to_extract=["c"]: ( @@ -131,7 +143,7 @@ def test_delayed_with_optional(): ) resolved = decorator.resolve( {"cols_to_extract": ["a", "b"], **CONFIG_WITH_POWER_MODE_ENABLED}, - fn=test_delayed_with_optional, + fn=fn, ) assert list(resolved.columns) == ["a", "b", "c"] resolved = decorator.resolve( @@ -140,12 +152,15 @@ def test_delayed_with_optional(): "some_cols_you_might_want_to_extract": ["d"], **CONFIG_WITH_POWER_MODE_ENABLED, }, - fn=test_delayed_with_optional, + fn=fn, ) assert list(resolved.columns) == ["a", "b", "d"] def test_delayed_without_power_mode_fails(): + def fn() -> pd.DataFrame: + return pd.DataFrame() + decorator = resolve( when=ResolveAt.CONFIG_AVAILABLE, decorate_with=lambda cols_to_extract, some_cols_you_might_want_to_extract=["c"]: ( @@ -155,5 +170,104 @@ def test_delayed_without_power_mode_fails(): with pytest.raises(base.InvalidDecoratorException): decorator.resolve( {"cols_to_extract": ["a", "b"], **CONFIG_WITH_POWER_MODE_DISABLED}, - fn=test_delayed_with_optional, + fn=fn, ) + + +def test_dynamic_resolve_with_extract_fields(): + """Test that @resolve with @extract_fields calls validate() correctly.""" + + def fn() -> dict[str, int]: + return {"a": 1, "b": 2} + + decorator = resolve( + when=ResolveAt.CONFIG_AVAILABLE, + decorate_with=lambda fields: extract_fields(fields), + ) + decorator_resolved = decorator.resolve( + {"fields": {"a": int, "b": int}, **CONFIG_WITH_POWER_MODE_ENABLED}, + fn=fn, + ) + assert hasattr(decorator_resolved, "resolved_fields") + assert decorator_resolved.resolved_fields == {"a": int, "b": int} + + +def test_resolve_with_parameterize_sources(): + """Test that @resolve with @parameterize_sources calls validate() correctly.""" + + def fn(x: int, y: int) -> int: + return x + y + + decorator = resolve( + when=ResolveAt.CONFIG_AVAILABLE, + decorate_with=lambda: parameterize_sources(result_1={"x": "source_x", "y": "source_y"}), + ) + decorator_resolved = decorator.resolve( + {**CONFIG_WITH_POWER_MODE_ENABLED}, + fn=fn, + ) + assert "result_1" in decorator_resolved.parameterization + mapping = decorator_resolved.parameterization["result_1"] + assert mapping["x"].source == "source_x" + assert mapping["y"].source == "source_y" + + +def test_resolve_from_config_with_extract_fields(): + """Test @resolve_from_config with @extract_fields calls validate() correctly.""" + + def fn() -> dict[str, int]: + return {"a": 1, "b": 2} + + decorator = resolve_from_config( + decorate_with=lambda fields: extract_fields(fields), + ) + decorator_resolved = decorator.resolve( + {"fields": {"a": int, "b": int}, **CONFIG_WITH_POWER_MODE_ENABLED}, + fn=fn, + ) + assert hasattr(decorator_resolved, "resolved_fields") + assert decorator_resolved.resolved_fields == {"a": int, "b": int} + + +def test_resolve_propagates_validate_failure(): + """Test that validate() failures are propagated through resolve.""" + + def fn() -> str: + return "not what you were expecting..." + + decorator = resolve( + when=ResolveAt.CONFIG_AVAILABLE, + decorate_with=lambda fields: extract_fields(fields), + ) + with pytest.raises(base.InvalidDecoratorException): + decorator.resolve( + {"fields": {"a": int, "b": int}, **CONFIG_WITH_POWER_MODE_ENABLED}, + fn=fn, + ) + + +def test_resolve_with_arbitrary_decorator(): + """Test behavior when decorate_with returns something that is not a NodeTransformLifecycle.""" + + # NOTE: we want to ensure we don't interfere with other decorators on functions. + + # A decorator that doesn't inherit from NodeTransformLifecycle (but still uses kwargs only) + class ArbitraryDecorator: + def __init__(self, a: int, b: int) -> None: + pass + + def __call__(self, f: Callable) -> Callable: + return f + + def fn() -> pd.DataFrame: + return pd.DataFrame() + + decorator = resolve( + when=ResolveAt.CONFIG_AVAILABLE, + decorate_with=lambda kwargs: ArbitraryDecorator(**kwargs), + ) + decorator_resolved = decorator.resolve( + {"kwargs": {"a": 1, "b": 2}, **CONFIG_WITH_POWER_MODE_ENABLED}, + fn=fn, + ) + assert isinstance(decorator_resolved, ArbitraryDecorator)