Skip to content

Commit 9acb5d6

Browse files
committed
fix'
1 parent c8836b9 commit 9acb5d6

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

cozeloop/decorator/decorator.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -337,14 +337,15 @@ def decorator(func: Callable):
337337

338338
@wraps(func)
339339
def sync_wrapper(*args: Any, **kwargs: Any):
340-
config = _get_config(**kwargs)
340+
config = kwargs.pop("config", None)
341+
config = _convert_config(config)
341342
res = None
342343
try:
343344
inp = {
344345
"args": args,
345346
"kwargs": kwargs
346347
}
347-
res = RunnableLambda(_param_wrapped_func).invoke(input=inp, config=config, **kwargs)
348+
res = RunnableLambda(_param_wrapped_func).invoke(input=inp, config=config)
348349
if hasattr(res, "__iter__"):
349350
return res
350351
except StopIteration:
@@ -357,14 +358,15 @@ def sync_wrapper(*args: Any, **kwargs: Any):
357358

358359
@wraps(func)
359360
async def async_wrapper(*args: Any, **kwargs: Any):
360-
config = _get_config(**kwargs)
361+
config = kwargs.pop("config", None)
362+
config = _convert_config(config)
361363
res = None
362364
try:
363365
inp = {
364366
"args": args,
365367
"kwargs": kwargs
366368
}
367-
res = await RunnableLambda(_param_wrapped_func_async).ainvoke(input=inp, config=config, **kwargs)
369+
res = await RunnableLambda(_param_wrapped_func_async).ainvoke(input=inp, config=config)
368370
if hasattr(res, "__aiter__"):
369371
return res
370372
except StopIteration:
@@ -382,13 +384,14 @@ async def async_wrapper(*args: Any, **kwargs: Any):
382384

383385
@wraps(func)
384386
def gen_wrapper(*args: Any, **kwargs: Any):
385-
config = _get_config(**kwargs)
387+
config = kwargs.pop("config", None)
388+
config = _convert_config(config)
386389
try:
387390
inp = {
388391
"args": args,
389392
"kwargs": kwargs
390393
}
391-
gen = RunnableLambda(_param_wrapped_func).invoke(input=inp, config=config, **kwargs)
394+
gen = RunnableLambda(_param_wrapped_func).invoke(input=inp, config=config)
392395
try:
393396
for item in gen:
394397
yield item
@@ -399,13 +402,14 @@ def gen_wrapper(*args: Any, **kwargs: Any):
399402

400403
@wraps(func)
401404
async def async_gen_wrapper(*args: Any, **kwargs: Any):
402-
config = _get_config(**kwargs)
405+
config = kwargs.pop("config", None)
406+
config = _convert_config(config)
403407
try:
404408
inp = {
405409
"args": args,
406410
"kwargs": kwargs
407411
}
408-
gen = RunnableLambda(_param_wrapped_func_async).invoke(input=inp, config=config, **kwargs)
412+
gen = RunnableLambda(_param_wrapped_func_async).invoke(input=inp, config=config)
409413
items = []
410414
try:
411415
async for item in gen:
@@ -434,8 +438,7 @@ async def _param_wrapped_func_async(input_dict: dict) -> Any:
434438
kwargs = input_dict.get("kwargs", {})
435439
return await func(*args, **kwargs)
436440

437-
def _get_config(**kwargs: Any) -> RunnableConfig | None:
438-
config = kwargs.pop("config", None)
441+
def _convert_config(config: RunnableConfig = None) -> RunnableConfig | None:
439442
if config is None:
440443
config = RunnableConfig(run_name=func.__name__)
441444
config['run_name'] = func.__name__

0 commit comments

Comments
 (0)