Skip to content

Commit 6023687

Browse files
committed
2 parents 1ffc755 + 52a0927 commit 6023687

9 files changed

Lines changed: 107 additions & 16 deletions

File tree

.github/workflows/test.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,30 @@ jobs:
2323
2424
- name: Run tests
2525
run: tox
26+
27+
publish-on-pypi:
  runs-on: ubuntu-latest
  permissions:
    id-token: write  # OIDC token for PyPI trusted publishing
  needs: test
  if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')

  steps:
    - name: Checkout code
      uses: actions/checkout@v4  # v2 runs on a deprecated Node.js runtime

    - name: Set up Python
      uses: actions/setup-python@v5  # v2 is likewise deprecated
      with:
        python-version: '3.10'

    - name: Install dependencies
      run: |
        pip install build twine

    - name: Build package
      run: |
        python -m build

    - name: Publish on PyPI
      # Fix: gh-action-pypi-publish@v1.4.2 predates trusted-publishing
      # (OIDC) support, so with only `id-token: write` and no password
      # the upload would fail to authenticate.  `release/v1` tracks the
      # current v1.x line, which supports trusted publishing.
      uses: pypa/gh-action-pypi-publish@release/v1

src/hyperchain/chain/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .chain_result import ChainResult
1+
from .chain_result import ChainResult, ChainResultList
22

33
from .chain import Chain
44
from .chain_sequence import ChainSequence

src/hyperchain/chain/chain.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import asyncio
66

7-
from .chain_result import ChainResult
7+
from .chain_result import ChainResult, ChainResultList
88
from ..prompt_templates import Template
99

1010
class Chain(ABC):
@@ -27,8 +27,10 @@ def run_multiple(self, *inputs_dict: Dict[str, Any]) -> List[ChainResult]:
2727
async def async_run_multiple(
    self, *inputs_dict: Dict[str, Any]
) -> List[ChainResult]:
    """Run the chain concurrently for each input dict.

    Returns a ChainResultList (a list of ChainResult objects, in input
    order) so callers can broadcast attribute access across the batch.

    Bug fix: the previous code did ``await ChainResultList(asyncio.gather(...))``
    -- a list instance is not awaitable, and constructing the list from
    the un-awaited gather future does not yield the results.  The gather
    must be awaited first, then its results wrapped in ChainResultList.
    """
    return ChainResultList(
        await asyncio.gather(
            *[self.async_run(**inputs) for inputs in inputs_dict]
        )
    )
3335

3436
def __add__(self, other: Any) -> Chain:

src/hyperchain/chain/chain_result.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ class ChainResult:
1212
def __post_init__(self):
1313
self.output_dict = deepcopy(self.output_dict)
1414

15+
def __getstate__(self):
    """Pickle protocol: the output mapping is the only state serialized."""
    return self.output_dict

def __setstate__(self, state):
    """Pickle protocol: restore the output mapping (skips __init__/__post_init__)."""
    self.output_dict = state
18+
1519
def __getattr__(self, name):
1620
if name == "output_dict":
1721
return self.output_dict
@@ -24,3 +28,12 @@ def __getattr__(self, name):
2428
return str(self.output_dict[name])
2529

2630
return None
31+
32+
33+
class ChainResultList(list):
    """A list of chain results that broadcasts attribute access.

    Reading ``results.foo`` returns ``[getattr(item, "foo") for item in results]``,
    so code consuming a batch can read one named output across every
    element at once.
    """

    def __getstate__(self):
        """Pickle the instance ``__dict__``; list contents are pickled by ``list``."""
        return self.__dict__

    def __setstate__(self, state):
        """Restore the instance ``__dict__`` produced by :meth:`__getstate__`."""
        self.__dict__ = state

    def __getattr__(self, name):
        # Only reached when normal lookup on the list object fails.
        # Fix: use getattr() instead of calling item.__getattr__(name)
        # directly -- the direct call bypassed regular attribute lookup
        # and raised for elements that define no custom __getattr__ hook.
        return [getattr(item, name) for item in self]

src/hyperchain/chain/llm_chain.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import List, Any
1+
from typing import List, Any, Dict
22
import logging
33

4-
from .chain_result import ChainResult
4+
from .chain_result import ChainResult, ChainResultList
55
from .chain import Chain
66
from .chain_sequence import ChainSequence
77

@@ -40,9 +40,8 @@ def __init__(
4040
self.output_keys = [output_name]
4141
self.required_keys = template.required_keys
4242

43-
async def async_run(self, **inputs_dict: Any) -> ChainResult:
43+
async def _run_with_error_handling(self, task):
4444
handlers = self.llm_runner._get_error_handlers()
45-
prompt = self.template.format(**inputs_dict)
4645
while True:
4746
holds_a_lock = False
4847
try:
@@ -65,7 +64,7 @@ async def async_run(self, **inputs_dict: Any) -> ChainResult:
6564
holds_a_lock = False
6665
self._error_handling_lock.release()
6766

68-
result = await self.llm_runner.async_run(prompt)
67+
result = await task
6968

7069
if not holds_a_lock:
7170
async with self._error_handling_lock:
@@ -79,9 +78,8 @@ async def async_run(self, **inputs_dict: Any) -> ChainResult:
7978
holds_a_lock = False
8079
self._rate_limited_state = False
8180
self._error_handling_lock.release()
82-
output_dict = inputs_dict
83-
output_dict[self.output_name] = result
84-
return ChainResult(output_dict=output_dict)
81+
82+
return result
8583

8684
finally:
8785
if holds_a_lock:
@@ -129,6 +127,26 @@ async def async_run(self, **inputs_dict: Any) -> ChainResult:
129127
if holds_a_lock:
130128
holds_a_lock = False
131129
self._error_handling_lock.release()
130+
131+
async def async_run_multiple(
    self, *inputs_dict: Dict[str, Any]
) -> List[ChainResult]:
    """Run the chain over several input dicts as one batched LLM call.

    Formats one prompt per input dict, submits them together through the
    runner's ``run_batch`` (wrapped in the shared error-handling retry
    logic), then pairs each LLM result back with the inputs that
    produced it under ``self.output_name``.
    """
    # NOTE(review): asyncio is not visibly imported at this module's top;
    # the local import keeps this method self-contained -- confirm and
    # hoist to the module imports if appropriate.
    import asyncio

    prompts = [self.template.format(**inputs) for inputs in inputs_dict]
    batch_task = asyncio.create_task(
        self.llm_runner.run_batch(prompts=prompts)
    )
    llm_results = await self._run_with_error_handling(batch_task)

    results = ChainResultList()
    for llm_result, input_dict in zip(llm_results, inputs_dict):
        result = ChainResult(input_dict)
        # Attach the model output under this chain's configured output key.
        result.output_dict[self.output_name] = llm_result
        results.append(result)
    return results
132142

143+
async def async_run(self, **inputs_dict: Any) -> ChainResult:
    """Format the template with ``inputs_dict``, run the LLM once, and
    return a ChainResult holding the inputs plus the model output stored
    under ``self.output_name``.

    Fix: dropped the unused local ``handlers`` -- the error-handler list
    is fetched inside ``_run_with_error_handling`` itself, so fetching it
    here was dead code.
    """
    # NOTE(review): asyncio is not visibly imported at this module's top;
    # the local import keeps this method self-contained -- confirm and
    # hoist to the module imports if appropriate.
    import asyncio

    prompt = self.template.format(**inputs_dict)
    run_task = asyncio.create_task(self.llm_runner.async_run(prompt))

    # ChainResult deep-copies its dict in __post_init__, so attaching the
    # output key below does not mutate the caller-visible inputs.
    result = ChainResult(output_dict=inputs_dict)
    result.output_dict[self.output_name] = await self._run_with_error_handling(
        run_task
    )
    return result
150+
133151
def __add__(self, other) -> Chain:
134152
return ChainSequence([self]) + other

src/hyperchain/llm_runners/llm_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
from .error_handler import BaseErrorHandler
33
from .llm_result import LLMResult
44
from typing import List, Any
5-
5+
from asyncio import create_task, gather
66

77
class LLMRunner(ABC):
88
@abstractmethod
99
async def async_run(self, prompt: Any) -> LLMResult:
1010
pass
1111

12+
async def run_batch(self, prompts: List[Any]) -> List[LLMResult]:
    """Run async_run concurrently for every prompt.

    Default batching strategy: one async_run task per prompt, gathered
    together; results come back in the same order as ``prompts``.
    """
    tasks = [self.async_run(prompt=prompt) for prompt in prompts]
    return await gather(*tasks)
14+
1215
def _get_error_handlers(self) -> List[BaseErrorHandler]:
1316
return []

src/hyperchain/llm_runners/masked_model_runner.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,20 @@ async def async_run(self, prompt: str):
3939
response = response.replace('<mask>', prediction[0]['token_str'], 1)
4040

4141
return LLMResult(response, extra_llm_outputs=predictions)
42+
43+
async def run_batch(self, prompts):
    """Fill the ``<mask>`` tokens of a whole batch of prompts in one
    pipeline call and return one LLMResult per prompt.
    """
    from torch import inference_mode

    # Single batched pipeline invocation; no gradients are needed.
    with inference_mode():
        predictions_list = self.fill_mask(prompts, **self.pipeline_parameters)

    results = []
    for prompt, predictions in zip(prompts, predictions_list):
        filled = prompt
        # Replace one <mask> occurrence per prediction, left to right,
        # using each prediction's first candidate (presumably the
        # highest-scoring one -- mirrors async_run above; confirm).
        for prediction in predictions:
            filled = filled.replace('<mask>', prediction[0]['token_str'], 1)
        results.append(LLMResult(filled, extra_llm_outputs=predictions))

    return results
4256

4357
def _get_error_handlers(self):
4458
return []

src/hyperchain/llm_runners/t5_model_runner.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def __init__(
99
model,
1010
tokenizer = None,
1111
model_kwargs = {},
12-
):
12+
):
1313
if isinstance(model, str):
1414
self.model = T5ForConditionalGeneration.from_pretrained(model)
1515
else:
@@ -60,10 +60,24 @@ def _apply_response(self, prompt, response):
6060
return result
6161

6262
async def async_run(self, prompt: str):
    """Generate a completion for a single prompt with the T5 model."""
    from torch import inference_mode

    input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
    # inference_mode disables autograd bookkeeping for the forward pass.
    with inference_mode():
        response = self.model.generate(input_ids, **self.model_kwargs)
    decoded = self.tokenizer.decode(
        self._apply_response(input_ids, response)[0],
        skip_special_tokens=True,
    )
    return LLMResult(
        decoded,
        extra_llm_outputs={"input_ids": input_ids, "response": response},
    )
69+
70+
async def run_batch(self, prompts):
    """Generate completions for a batch of prompts in one padded forward
    pass, returning one LLMResult per prompt in input order.
    """
    from torch import inference_mode

    encoded = self.tokenizer(prompts, padding=True, return_tensors="pt")
    input_ids_batch = encoded.input_ids
    # Single batched generate call; no gradients are needed.
    with inference_mode():
        responses = self.model.generate(input_ids_batch, **self.model_kwargs)
    decoded_responses = self.tokenizer.batch_decode(
        self._apply_response(input_ids_batch, responses),
        skip_special_tokens=True,
    )
    return [
        LLMResult(text, extra_llm_outputs={"input_ids": ids, "response": resp})
        for text, ids, resp in zip(decoded_responses, input_ids_batch, responses)
    ]
80+
6781

6882
def _get_error_handlers(self):
6983
return []

tests/llm_runners/test_t5_model_runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from unittest.mock import Mock
2+
from unittest.mock import Mock, MagicMock, patch
33
from hyperchain.llm_runners.t5_model_runner import T5ConditionalModelRunner
44

55
def test_apply_response():
@@ -18,6 +18,7 @@ def test_apply_response():
1818
assert result == [[4, 5, 6, 7, 8, 9, 10]]
1919

2020
@pytest.mark.asyncio
21+
@patch.dict('sys.modules', torch=MagicMock())
2122
async def test_async_run():
2223
mocked_model = Mock()
2324
mocked_tokenizer = Mock()

0 commit comments

Comments
 (0)