From ebde2e81aa703a3eafbc9a6c0f0e5603afac575f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Jolovi=C4=87?= Date: Fri, 8 May 2026 23:55:08 +0200 Subject: [PATCH] feat: support tuple items in map --- src/cashet/_batch.py | 13 +++++++++++++ src/cashet/async_client.py | 3 ++- tests/test_map.py | 12 ++++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/cashet/_batch.py b/src/cashet/_batch.py index fff3398..3187e8e 100644 --- a/src/cashet/_batch.py +++ b/src/cashet/_batch.py @@ -2,6 +2,7 @@ import asyncio import logging +from collections.abc import Callable, Iterable from datetime import timedelta from typing import Any @@ -74,6 +75,18 @@ def normalize_tasks( return normalized +def build_map_tasks( + func: Callable[..., Any], + items: Iterable[Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> list[Any]: + return [ + (func, (*item, *args) if isinstance(item, tuple) else (item, *args), kwargs) + for item in items + ] + + def build_deps( keys: list[BatchKey], normalized: list[NormalizedTask], diff --git a/src/cashet/async_client.py b/src/cashet/async_client.py index c1c9499..7bbc023 100644 --- a/src/cashet/async_client.py +++ b/src/cashet/async_client.py @@ -8,6 +8,7 @@ from cashet._batch import ( build_deps, + build_map_tasks, execute_batch, normalize_tasks, topological_sort, @@ -277,7 +278,7 @@ async def map( max_workers: int | None = None, **kwargs: Any, ) -> list[AsyncResultRef[T]]: - task_list: list[Any] = [(func, (item, *args), kwargs) for item in items] + task_list = build_map_tasks(func, items, args, kwargs) return await self.submit_many( task_list, _cache=_cache, diff --git a/tests/test_map.py b/tests/test_map.py index 841e7c6..12809be 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -27,6 +27,10 @@ def greet(name: str, greeting: str = "hello") -> str: return f"{greeting} {name}" +def add_pair(x: int, y: int) -> int: + return x + y + + class TestSyncMap: def test_map_basic(self, client: Client) -> None: refs = client.map(double, [1, 2, 3]) @@ -38,6 +42,10 @@ def test_map_with_extra_args(self, client: Client) -> None: refs = client.map(add, [1, 2, 3], 10) assert [r.load() for r in refs] == [11, 12, 13] + def test_map_with_tuple_items(self, client: Client) -> None: + refs = client.map(add_pair, [(1, 2), (3, 4)]) + assert [r.load() for r in refs] == [3, 7] + def test_map_with_kwargs(self, client: Client) -> None: refs = client.map(greet, ["alice", "bob"], greeting="hi") assert [r.load() for r in refs] == ["hi alice", "hi bob"] @@ -74,6 +82,10 @@ async def test_map_with_extra_args(self, async_client: AsyncClient) -> None: refs = await async_client.map(add, [1, 2, 3], 10) assert [await r.load() for r in refs] == [11, 12, 13] + async def test_map_with_tuple_items(self, async_client: AsyncClient) -> None: + refs = await async_client.map(add_pair, [(1, 2), (3, 4)]) + assert [await r.load() for r in refs] == [3, 7] + async def test_map_caching(self, async_client: AsyncClient) -> None: call_count = 0