Skip to content

Commit 71d9389

Browse files
committed
Context management cleanup
1 parent 8352537 commit 71d9389

File tree

2 files changed

+652
-598
lines changed

2 files changed

+652
-598
lines changed

src/pipedream/proxy/client.py

Lines changed: 151 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import base64
44
import typing
5+
from collections.abc import AsyncIterator, Iterator
56

67
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
78
from ..core.request_options import RequestOptions
@@ -88,13 +89,26 @@ def get(
8889
additional_headers=downstream_headers,
8990
additional_query_parameters=params or {},
9091
)
91-
_response = self._raw_client.get(
92+
ctx = self._raw_client.get(
9293
url_64,
9394
external_user_id=external_user_id,
9495
account_id=account_id,
9596
request_options=request_options,
9697
)
97-
return _response.data
98+
_response = ctx.__enter__()
99+
data = _response.data
100+
101+
if not isinstance(data, Iterator):
102+
ctx.__exit__(None, None, None)
103+
return data
104+
105+
def _stream() -> typing.Iterator[bytes]:
106+
try:
107+
for chunk in data:
108+
yield chunk
109+
finally:
110+
ctx.__exit__(None, None, None)
111+
return _stream()
98112

99113
def post(
100114
self,
@@ -162,14 +176,27 @@ def post(
162176
additional_headers=downstream_headers,
163177
additional_query_parameters=params or {},
164178
)
165-
_response = self._raw_client.post(
179+
ctx = self._raw_client.post(
166180
url_64,
167181
external_user_id=external_user_id,
168182
account_id=account_id,
169183
request=body or {},
170184
request_options=request_options,
171185
)
172-
return _response.data
186+
_response = ctx.__enter__()
187+
data = _response.data
188+
189+
if not isinstance(data, Iterator):
190+
ctx.__exit__(None, None, None)
191+
return data
192+
193+
def _stream() -> typing.Iterator[bytes]:
194+
try:
195+
for chunk in data:
196+
yield chunk
197+
finally:
198+
ctx.__exit__(None, None, None)
199+
return _stream()
173200

174201
def put(
175202
self,
@@ -237,14 +264,27 @@ def put(
237264
additional_headers=downstream_headers,
238265
additional_query_parameters=params or {},
239266
)
240-
_response = self._raw_client.put(
267+
ctx = self._raw_client.put(
241268
url_64,
242269
external_user_id=external_user_id,
243270
account_id=account_id,
244271
request=body or {},
245272
request_options=request_options,
246273
)
247-
return _response.data
274+
_response = ctx.__enter__()
275+
data = _response.data
276+
277+
if not isinstance(data, Iterator):
278+
ctx.__exit__(None, None, None)
279+
return data
280+
281+
def _stream() -> typing.Iterator[bytes]:
282+
try:
283+
for chunk in data:
284+
yield chunk
285+
finally:
286+
ctx.__exit__(None, None, None)
287+
return _stream()
248288

249289
def delete(
250290
self,
@@ -304,13 +344,26 @@ def delete(
304344
additional_headers=downstream_headers,
305345
additional_query_parameters=params or {},
306346
)
307-
_response = self._raw_client.delete(
347+
ctx = self._raw_client.delete(
308348
url_64,
309349
external_user_id=external_user_id,
310350
account_id=account_id,
311351
request_options=request_options,
312352
)
313-
return _response.data
353+
_response = ctx.__enter__()
354+
data = _response.data
355+
356+
if not isinstance(data, Iterator):
357+
ctx.__exit__(None, None, None)
358+
return data
359+
360+
def _stream() -> typing.Iterator[bytes]:
361+
try:
362+
for chunk in data:
363+
yield chunk
364+
finally:
365+
ctx.__exit__(None, None, None)
366+
return _stream()
314367

315368
def patch(
316369
self,
@@ -378,14 +431,27 @@ def patch(
378431
additional_headers=downstream_headers,
379432
additional_query_parameters=params or {},
380433
)
381-
_response = self._raw_client.patch(
434+
ctx = self._raw_client.patch(
382435
url_64,
383436
external_user_id=external_user_id,
384437
account_id=account_id,
385438
request=body or {},
386439
request_options=request_options,
387440
)
388-
return _response.data
441+
_response = ctx.__enter__()
442+
data = _response.data
443+
444+
if not isinstance(data, Iterator):
445+
ctx.__exit__(None, None, None)
446+
return data
447+
448+
def _stream() -> typing.Iterator[bytes]:
449+
try:
450+
for chunk in data:
451+
yield chunk
452+
finally:
453+
ctx.__exit__(None, None, None)
454+
return _stream()
389455

390456

391457
class AsyncProxyClient:
@@ -472,13 +538,26 @@ async def main() -> None:
472538
additional_headers=downstream_headers,
473539
additional_query_parameters=params or {},
474540
)
475-
_response = await self._raw_client.get(
541+
ctx = self._raw_client.get(
476542
url_64,
477543
external_user_id=external_user_id,
478544
account_id=account_id,
479545
request_options=request_options,
480546
)
481-
return _response.data
547+
_response = await ctx.__aenter__()
548+
data = _response.data
549+
550+
if not isinstance(data, AsyncIterator):
551+
await ctx.__aexit__(None, None, None)
552+
return data
553+
554+
async def _stream() -> typing.AsyncIterator[bytes]:
555+
try:
556+
async for chunk in data:
557+
yield chunk
558+
finally:
559+
await ctx.__aexit__(None, None, None)
560+
return _stream()
482561

483562
async def post(
484563
self,
@@ -554,14 +633,27 @@ async def main() -> None:
554633
additional_headers=downstream_headers,
555634
additional_query_parameters=params or {},
556635
)
557-
_response = await self._raw_client.post(
636+
ctx = self._raw_client.post(
558637
url_64,
559638
external_user_id=external_user_id,
560639
account_id=account_id,
561640
request=body or {},
562641
request_options=request_options,
563642
)
564-
return _response.data
643+
_response = await ctx.__aenter__()
644+
data = _response.data
645+
646+
if not isinstance(data, AsyncIterator):
647+
await ctx.__aexit__(None, None, None)
648+
return data
649+
650+
async def _stream() -> typing.AsyncIterator[bytes]:
651+
try:
652+
async for chunk in data:
653+
yield chunk
654+
finally:
655+
await ctx.__aexit__(None, None, None)
656+
return _stream()
565657

566658
async def put(
567659
self,
@@ -637,14 +729,27 @@ async def main() -> None:
637729
additional_headers=downstream_headers,
638730
additional_query_parameters=params or {},
639731
)
640-
_response = await self._raw_client.put(
732+
ctx = self._raw_client.put(
641733
url_64,
642734
external_user_id=external_user_id,
643735
account_id=account_id,
644736
request=body or {},
645737
request_options=request_options,
646738
)
647-
return _response.data
739+
_response = await ctx.__aenter__()
740+
data = _response.data
741+
742+
if not isinstance(data, AsyncIterator):
743+
await ctx.__aexit__(None, None, None)
744+
return data
745+
746+
async def _stream() -> typing.AsyncIterator[bytes]:
747+
try:
748+
async for chunk in data:
749+
yield chunk
750+
finally:
751+
await ctx.__aexit__(None, None, None)
752+
return _stream()
648753

649754
async def delete(
650755
self,
@@ -712,13 +817,26 @@ async def main() -> None:
712817
additional_headers=downstream_headers,
713818
additional_query_parameters=params or {},
714819
)
715-
_response = await self._raw_client.delete(
820+
ctx = self._raw_client.delete(
716821
url_64,
717822
external_user_id=external_user_id,
718823
account_id=account_id,
719824
request_options=request_options,
720825
)
721-
return _response.data
826+
_response = await ctx.__aenter__()
827+
data = _response.data
828+
829+
if not isinstance(data, AsyncIterator):
830+
await ctx.__aexit__(None, None, None)
831+
return data
832+
833+
async def _stream() -> typing.AsyncIterator[bytes]:
834+
try:
835+
async for chunk in data:
836+
yield chunk
837+
finally:
838+
await ctx.__aexit__(None, None, None)
839+
return _stream()
722840

723841
async def patch(
724842
self,
@@ -794,11 +912,24 @@ async def main() -> None:
794912
additional_headers=downstream_headers,
795913
additional_query_parameters=params or {},
796914
)
797-
_response = await self._raw_client.patch(
915+
ctx = self._raw_client.patch(
798916
url_64,
799917
external_user_id=external_user_id,
800918
account_id=account_id,
801919
request=body or {},
802920
request_options=request_options,
803921
)
804-
return _response.data
922+
_response = await ctx.__aenter__()
923+
data = _response.data
924+
925+
if not isinstance(data, AsyncIterator):
926+
await ctx.__aexit__(None, None, None)
927+
return data
928+
929+
async def _stream() -> typing.AsyncIterator[bytes]:
930+
try:
931+
async for chunk in data:
932+
yield chunk
933+
finally:
934+
await ctx.__aexit__(None, None, None)
935+
return _stream()

0 commit comments

Comments
 (0)