@@ -311,6 +311,183 @@ async def async_stream_wrapper(*args: Any, **kwargs: Any):
311311 else :
312312 return decorator (func )
313313
314+ def to_runnable (
315+ self ,
316+ func : Callable = None ,
317+ ) -> Callable :
318+ """
319+ Decorator to be RunnableLambda.
320+
321+ :param func: The function to be decorated, Requirements are as follows:
322+ 1. When the func is called, parameter config(RunnableConfig) is required, you must use the config containing cozeloop callback handler of 'current request', otherwise, the trace may be lost!
323+
324+ Examples:
325+ @to_runnable
326+ def runnable_func(my_input: dict) -> str:
327+ return input
328+
329+ async def scorer_leader(state: MyState) -> dict | str:
330+ await runnable_func({"a": "111", "b": 222, "c": "333"}, config=state.config) # config is required
331+ """
332+
333+ def decorator (func : Callable ):
334+ from langchain_core .runnables import RunnableLambda , RunnableConfig
335+
336+ @wraps (func )
337+ def sync_wrapper (* args : Any , ** kwargs : Any ):
338+ config = kwargs .pop ("config" , None )
339+ config = _convert_config (config )
340+ res = None
341+ try :
342+ extra = {}
343+ if len (args ) > 0 and is_class_func (func ):
344+ extra = {"_inner_class_self" : args [0 ]}
345+ args = args [1 :]
346+ inp = {}
347+ if len (args ) > 0 :
348+ inp ['args' ] = args
349+ if len (kwargs ) > 0 :
350+ inp ['kwargs' ] = kwargs
351+ res = RunnableLambda (_param_wrapped_func ).invoke (input = inp , config = config , ** extra )
352+ if hasattr (res , "__iter__" ):
353+ return res
354+ except StopIteration :
355+ pass
356+ except Exception as e :
357+ raise e
358+ finally :
359+ if res is not None :
360+ return res
361+
362+ @wraps (func )
363+ async def async_wrapper (* args : Any , ** kwargs : Any ):
364+ config = kwargs .pop ("config" , None )
365+ config = _convert_config (config )
366+ res = None
367+ try :
368+ extra = {}
369+ if len (args ) > 0 and is_class_func (func ):
370+ extra = {"_inner_class_self" : args [0 ]}
371+ args = args [1 :]
372+ inp = {}
373+ if len (args ) > 0 :
374+ inp ['args' ] = args
375+ if len (kwargs ) > 0 :
376+ inp ['kwargs' ] = kwargs
377+ res = await RunnableLambda (_param_wrapped_func_async ).ainvoke (input = inp , config = config , ** extra )
378+ if hasattr (res , "__aiter__" ):
379+ return res
380+ except StopIteration :
381+ pass
382+ except StopAsyncIteration :
383+ pass
384+ except Exception as e :
385+ if e .args and e .args [0 ] == 'coroutine raised StopIteration' : # coroutine StopIteration
386+ pass
387+ else :
388+ raise e
389+ finally :
390+ if res is not None :
391+ return res
392+
393+ @wraps (func )
394+ def gen_wrapper (* args : Any , ** kwargs : Any ):
395+ config = kwargs .pop ("config" , None )
396+ config = _convert_config (config )
397+ try :
398+ extra = {}
399+ if len (args ) > 0 and is_class_func (func ):
400+ extra = {"_inner_class_self" : args [0 ]}
401+ args = args [1 :]
402+ inp = {}
403+ if len (args ) > 0 :
404+ inp ['args' ] = args
405+ if len (kwargs ) > 0 :
406+ inp ['kwargs' ] = kwargs
407+ gen = RunnableLambda (_param_wrapped_func ).invoke (input = inp , config = config , * extra )
408+ try :
409+ for item in gen :
410+ yield item
411+ except StopIteration :
412+ pass
413+ except Exception as e :
414+ raise e
415+
416+ @wraps (func )
417+ async def async_gen_wrapper (* args : Any , ** kwargs : Any ):
418+ config = kwargs .pop ("config" , None )
419+ config = _convert_config (config )
420+ try :
421+ extra = {}
422+ if len (args ) > 0 and is_class_func (func ):
423+ extra = {"_inner_class_self" : args [0 ]}
424+ args = args [1 :]
425+ inp = {}
426+ if len (args ) > 0 :
427+ inp ['args' ] = args
428+ if len (kwargs ) > 0 :
429+ inp ['kwargs' ] = kwargs
430+ gen = RunnableLambda (_param_wrapped_func_async ).invoke (input = inp , config = config , ** extra )
431+ items = []
432+ try :
433+ async for item in gen :
434+ items .append (item )
435+ yield item
436+ finally :
437+ pass
438+ except StopIteration :
439+ pass
440+ except StopAsyncIteration :
441+ pass
442+ except Exception as e :
443+ if e .args and e .args [0 ] == 'coroutine raised StopIteration' :
444+ pass
445+ else :
446+ raise e
447+
448+ # for convert parameter
449+ def _param_wrapped_func (input_dict : dict , ** kwargs ) -> Any :
450+ real_args = input_dict .get ("args" , ())
451+ real_kwargs = input_dict .get ("kwargs" , {})
452+
453+ inner_class_self = kwargs .get ("_inner_class_self" , None )
454+ if inner_class_self is not None :
455+ real_args = (inner_class_self , * real_args )
456+
457+ return func (* real_args , ** real_kwargs )
458+
459+ async def _param_wrapped_func_async (input_dict : dict , ** kwargs ) -> Any :
460+ real_args = input_dict .get ("args" , ())
461+ real_kwargs = input_dict .get ("kwargs" , {})
462+
463+ inner_class_self = kwargs .get ("_inner_class_self" , None )
464+ if inner_class_self is not None :
465+ real_args = (inner_class_self , * real_args )
466+
467+ return await func (* real_args , ** real_kwargs )
468+
469+ def _convert_config (config : RunnableConfig = None ) -> RunnableConfig | None :
470+ if config is None :
471+ config = RunnableConfig (run_name = func .__name__ )
472+ config ['run_name' ] = func .__name__
473+ elif isinstance (config , dict ):
474+ config ['run_name' ] = func .__name__
475+ return config
476+
477+ if is_async_gen_func (func ):
478+ return async_gen_wrapper
479+ if is_gen_func (func ):
480+ return gen_wrapper
481+ elif is_async_func (func ):
482+ return async_wrapper
483+ else :
484+ return sync_wrapper
485+
486+ if func is None :
487+ return decorator
488+ else :
489+ return decorator (func )
490+
314491
315492class _CozeLoopTraceStream (Generic [S ]):
316493 def __init__ (
0 commit comments