Skip to content

Commit b7b050b

Browse files
committed
feat(mypy): fix async web3 method typing
1 parent a133ebd commit b7b050b

File tree

3 files changed

+40
-20
lines changed

3 files changed

+40
-20
lines changed

faster_web3/eth/async_eth.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,6 @@ async def modify_transaction(
612612
current_transaction
613613
)
614614
new_transaction = merge(current_transaction_params, transaction_params)
615-
616615
return await async_replace_transaction(w3, current_transaction, new_transaction)
617616

618617
# eth_sign
@@ -730,7 +729,8 @@ async def subscribe(
730729
label: Optional[str] = None,
731730
parallelize: Optional[bool] = None,
732731
) -> HexStr:
733-
if not isinstance(self.w3.provider, PersistentConnectionProvider):
732+
w3 = self.w3
733+
if not isinstance(w3.provider, PersistentConnectionProvider):
734734
raise MethodNotSupported(
735735
"eth_subscribe is only supported with providers that support "
736736
"persistent connections."
@@ -743,21 +743,22 @@ async def subscribe(
743743
label=label,
744744
parallelize=parallelize,
745745
)
746-
return await self.w3.subscription_manager.subscribe(sub)
746+
return await w3.subscription_manager.subscribe(sub)
747747

748748
_unsubscribe: Method[Callable[[HexStr], Awaitable[bool]]] = Method(
749749
RPC.eth_unsubscribe,
750750
mungers=[default_root_munger],
751751
)
752752

753753
async def unsubscribe(self, subscription_id: HexStr) -> bool:
754-
if not isinstance(self.w3.provider, PersistentConnectionProvider):
754+
w3 = self.w3
755+
if not isinstance(w3.provider, PersistentConnectionProvider):
755756
raise MethodNotSupported(
756757
"eth_unsubscribe is only supported with providers that support "
757758
"persistent connections."
758759
)
759760

760-
for sub in self.w3.subscription_manager.subscriptions:
761+
for sub in w3.subscription_manager.subscriptions:
761762
if sub._id == subscription_id:
762763
return await sub.unsubscribe()
763764

faster_web3/providers/persistent/subscription_container.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,4 @@ def get_handler_subscription_by_id(
5454
self, sub_id: HexStr
5555
) -> Optional[EthSubscription[Any]]:
5656
sub = self.get_by_id(sub_id)
57-
if sub and sub._handler:
58-
return sub
59-
return None
57+
return sub if sub and sub._handler else None

faster_web3/providers/persistent/utils.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,65 @@
11
import functools
22
from typing import (
33
TYPE_CHECKING,
4-
Any,
54
Callable,
65
Optional,
76
)
87

8+
from typing_extensions import (
9+
Concatenate,
10+
ParamSpec,
11+
)
12+
913
from faster_web3.exceptions import (
1014
Web3ValidationError,
1115
)
1216
from faster_web3.providers import (
1317
PersistentConnectionProvider,
1418
)
19+
from faster_web3.types import (
20+
TReturn,
21+
)
1522

1623
if TYPE_CHECKING:
1724
from faster_web3.main import ( # noqa: F401
25+
AsyncProviderT,
1826
AsyncWeb3,
1927
)
2028

2129

30+
P = ParamSpec("P")
31+
32+
AsyncWeb3Method = Callable[Concatenate["AsyncWeb3[AsyncProviderT]", P], TReturn]
33+
34+
2235
def persistent_connection_provider_method(
2336
message: Optional[str] = None,
24-
) -> Callable[..., Any]:
37+
) -> Callable[
38+
[AsyncWeb3Method[AsyncProviderT, P, TReturn]],
39+
AsyncWeb3Method[AsyncProviderT, P, TReturn],
40+
]:
2541
"""
2642
Decorator that raises an exception if the provider is not an instance of
2743
``PersistentConnectionProvider``.
2844
"""
2945

30-
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
31-
@functools.wraps(func)
32-
def inner(self: "AsyncWeb3[Any]", *args: Any, **kwargs: Any) -> Any:
33-
nonlocal message
34-
if message is None:
35-
message = (
36-
f"``{func.__name__}`` can only be called on a "
37-
"``PersistentConnectionProvider`` instance."
38-
)
46+
def decorator(
47+
func: AsyncWeb3Method[AsyncProviderT, P, TReturn],
48+
) -> AsyncWeb3Method[AsyncProviderT, P, TReturn]:
49+
if message is None:
50+
message_actual = (
51+
f"``{func.__name__}`` can only be called on a "
52+
"``PersistentConnectionProvider`` instance."
53+
)
54+
else:
55+
message_actual = message
3956

57+
@functools.wraps(func)
58+
def inner(
59+
self: "AsyncWeb3[AsyncProviderT]", *args: P.args, **kwargs: P.kwargs
60+
) -> TReturn:
4061
if not isinstance(self.provider, PersistentConnectionProvider):
41-
raise Web3ValidationError(message)
62+
raise Web3ValidationError(message_actual)
4263
return func(self, *args, **kwargs)
4364

4465
return inner

0 commit comments

Comments
 (0)