1- from typing import List , Any
1+ from typing import List , Any , Dict
22import logging
33
4- from .chain_result import ChainResult
4+ from .chain_result import ChainResult , ChainResultList
55from .chain import Chain
66from .chain_sequence import ChainSequence
77
@@ -40,9 +40,8 @@ def __init__(
4040 self .output_keys = [output_name ]
4141 self .required_keys = template .required_keys
4242
43- async def async_run (self , ** inputs_dict : Any ) -> ChainResult :
43+ async def _run_with_error_handling (self , task ) :
4444 handlers = self .llm_runner ._get_error_handlers ()
45- prompt = self .template .format (** inputs_dict )
4645 while True :
4746 holds_a_lock = False
4847 try :
@@ -65,7 +64,7 @@ async def async_run(self, **inputs_dict: Any) -> ChainResult:
6564 holds_a_lock = False
6665 self ._error_handling_lock .release ()
6766
68- result = await self . llm_runner . async_run ( prompt )
67+ result = await task
6968
7069 if not holds_a_lock :
7170 async with self ._error_handling_lock :
@@ -79,9 +78,8 @@ async def async_run(self, **inputs_dict: Any) -> ChainResult:
7978 holds_a_lock = False
8079 self ._rate_limited_state = False
8180 self ._error_handling_lock .release ()
82- output_dict = inputs_dict
83- output_dict [self .output_name ] = result
84- return ChainResult (output_dict = output_dict )
81+
82+ return result
8583
8684 finally :
8785 if holds_a_lock :
@@ -129,6 +127,26 @@ async def async_run(self, **inputs_dict: Any) -> ChainResult:
129127 if holds_a_lock :
130128 holds_a_lock = False
131129 self ._error_handling_lock .release ()
130+
131+ async def async_run_multiple (
132+ self , * inputs_dict : Dict [str , Any ]
133+ ) -> List [ChainResult ]:
134+ prompts = [self .template .format (** inp ) for inp in inputs_dict ]
135+ llm_results = await self ._run_with_error_handling (asyncio .create_task (self .llm_runner .run_batch (prompts = prompts )))
136+ results = ChainResultList ()
137+ for llm_result , input_dict in zip (llm_results , inputs_dict ):
138+ result = ChainResult (input_dict )
139+ result .output_dict [self .output_name ] = llm_result
140+ results .append (result )
141+ return results
132142
143+ async def async_run (self , ** inputs_dict : Any ) -> ChainResult :
144+ handlers = self .llm_runner ._get_error_handlers ()
145+ prompt = self .template .format (** inputs_dict )
146+
147+ result = ChainResult (output_dict = inputs_dict )
148+ result .output_dict [self .output_name ] = await self ._run_with_error_handling (asyncio .create_task (self .llm_runner .async_run (prompt )))
149+ return result
150+
133151 def __add__ (self , other ) -> Chain :
134152 return ChainSequence ([self ]) + other
0 commit comments