diff --git a/MIGRATION.md b/MIGRATION.md index 30132a2..bb879e8 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -1,12 +1,15 @@ # Migration Guide -## 1.4.x -> 1.5.0 (REST client) +## 1.4.x -> 1.5.0 (REST/Onboarding clients) - `x10.perpetual.trading_client.PerpetualTradingClient` has been replaced with `x10.clients.rest.RestApiClient` (client has the same interface but new name reflects its purpose better). - Leftover models were migrated to `x10.models.*`. - Most of the dataclasses are immutable now. - `markets_info` module has been merged into `info` module. +- `UserClient` replaced by `OnboardingClient`, which accepts an account address and a sign-message callback instead of a raw L1 private key. +- `onboard_subaccount` error handling has changed. Previously, it silently recovered an existing sub-account (HTTP 409) by fetching it from `get_accounts()`. Now it raises `ValidationError` on conflict. Handle duplicates explicitly if you relied on the automatic recovery. +- Fixes https://github.com/x10xchange/python_sdk/issues/99. --- diff --git a/README.md b/README.md index 8da150f..d93aee2 100644 --- a/README.md +++ b/README.md @@ -257,7 +257,7 @@ All new accounts should use the `MAINNET_CONFIG` configuration bundle. ## OnBoarding via SDK (Since Version 0.3.0) -To onboard to the Extended Exchange, the `UserClient` defined in [user_client.py](x10/perpetual/user_client/user_client.py) provides a way to use an Ethereum account to onboard onto the Extended Exchange. +To onboard to the Extended Exchange, the `UserClient` defined in [user_client.py](x10/perpetual/user_client/user_client.py) provides a way to use an Ethereum account to onboard onto the Extended Exchange. ### TLDR - Check out: [onboarding_example.py](examples/onboarding_example.py) diff --git a/examples/cases/advanced/onboarding_with_eth_account.py b/examples/cases/advanced/onboarding_with_eth_account.py index 34a18d4..380ef5a 100644 --- a/examples/cases/advanced/onboarding_with_eth_account.py +++ b/examples/cases/advanced/onboarding_with_eth_account.py @@ -5,10 +5,10 @@ from eth_account.signers.local import LocalAccount from examples.utils import init_env +from x10.clients.onboarding import OnboardingClient from x10.clients.rest import RestApiClient from x10.config import TESTNET_CONFIG from x10.core.stark_account import StarkPerpetualAccount -from x10.perpetual.user_client.user_client import UserClient from x10.utils.string import is_hex_string LOGGER = logging.getLogger() @@ -23,12 +23,23 @@ async def run_example(): assert is_hex_string(eth_account_private_key), "`eth_account_private_key` must be a hex string" eth_local_account: LocalAccount = Account.from_key(eth_account_private_key) - user_client = UserClient(config=CONFIG, l1_private_key=eth_local_account.key.hex) + + onboarding_client = OnboardingClient( + config=CONFIG, + account_address=eth_local_account.address, + sign_message=lambda msg: eth_local_account.sign_message(msg).signature.hex(), + ) LOGGER.info("Onboarding with ETH account %s...", eth_local_account.address) - main_account = await user_client.onboard() - main_account_api_key = await user_client.create_account_api_key(main_account.account, "Onboarding example API key") + main_account = await onboarding_client.auth.onboard_client() + sub_account = await onboarding_client.auth.onboard_subaccount( + account_index=1, description="Onboarding example subaccount" + ) + main_account_api_key = await onboarding_client.account.create_api_key( + account_id=main_account.account.id, + description="Onboarding example API key", + ) starknet_account = StarkPerpetualAccount( api_key=main_account_api_key, @@ -38,7 +49,8 @@ async def run_example(): ) rest_client = RestApiClient(CONFIG, starknet_account) - LOGGER.info("StarkNet public key: %s", starknet_account.public_key) + LOGGER.info("StarkNet public key (main): %s", main_account.l2_key_pair.public_hex) + LOGGER.info("StarkNet public key (sub): %s", sub_account.l2_key_pair.public_hex) claim = await rest_client.testnet.claim_testing_funds() claim_id = claim.data.id if claim.data else None diff --git a/pyproject.toml b/pyproject.toml index f199601..6f956e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "x10-python-trading-starknet" -version = "1.4.1" +version = "1.5.0" description = "Python client for X10 API" authors = ["X10 "] repository = "https://github.com/x10xchange/python_sdk" diff --git a/tests/conftest.py b/tests/conftest.py index 6b3d1c0..2880a4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -70,6 +70,13 @@ def get_asset_xvs(): return _get_asset_xvs +@pytest.fixture +def get_eth_private_key(): + from tests.fixtures.onboarding import get_eth_private_key as _get_eth_private_key + + return _get_eth_private_key + + @pytest.fixture def create_asset_operations(): from tests.fixtures.asset import create_asset_operations as _create_asset_operations diff --git a/tests/fixtures/onboarding.py b/tests/fixtures/onboarding.py new file mode 100644 index 0000000..1170d57 --- /dev/null +++ b/tests/fixtures/onboarding.py @@ -0,0 +1,3 @@ +def get_eth_private_key(): + # All known values from authentication service tests are used. + return "50c8e358cc974aaaa6e460641e53f78bdc550fd372984aa78ef8fd27c751e6f4" diff --git a/tests/perpetual/test_l2_key_derivation.py b/tests/perpetual/test_l2_key_derivation.py deleted file mode 100644 index d560c9d..0000000 --- a/tests/perpetual/test_l2_key_derivation.py +++ /dev/null @@ -1,15 +0,0 @@ -from eth_account import Account -from hamcrest import assert_that, equal_to - - -def test_known_l2_accounts(): - from x10.perpetual.user_client.onboarding import get_l2_keys_from_l1_account - - known_private_key = "50c8e358cc974aaaa6e460641e53f78bdc550fd372984aa78ef8fd27c751e6f4" - known_l2_private_key = "0x7dbb2c8651cc40e1d0d60b45eb52039f317a8aa82798bda52eee272136c0c44" - known_l2_public_key = "0x78298687996aff29a0bbcb994e1305db082d084f85ec38bb78c41e6787740ec" - - derived_keys = get_l2_keys_from_l1_account(Account.from_key(known_private_key), 0, signing_domain="x10.exchange") - - assert_that(derived_keys.private_hex, equal_to(known_l2_private_key)) - assert_that(derived_keys.public_hex, equal_to(known_l2_public_key)) diff --git a/tests/perpetual/test_onboarding_payload.py b/tests/perpetual/test_onboarding_payload.py deleted file mode 100644 index 2aab37e..0000000 --- a/tests/perpetual/test_onboarding_payload.py +++ /dev/null @@ -1,58 +0,0 @@ -import datetime - -from eth_account import Account -from hamcrest import assert_that, equal_to - -from x10.perpetual.user_client.onboarding import get_l2_keys_from_l1_account - - -def test_onboarding_object_generation(): - """ - All known values from authentication service tests are used. - """ - from x10.perpetual.user_client.onboarding import get_onboarding_payload - - known_private_key = "50c8e358cc974aaaa6e460641e53f78bdc550fd372984aa78ef8fd27c751e6f4" - known_l2_public_key = "0x78298687996aff29a0bbcb994e1305db082d084f85ec38bb78c41e6787740ec" - - l1_account = Account.from_key(known_private_key) - key_pair = get_l2_keys_from_l1_account(l1_account=l1_account, account_index=0, signing_domain="x10.exchange") - - payload = get_onboarding_payload( - account=l1_account, - time=datetime.datetime( - year=2024, - month=7, - day=30, - hour=16, - minute=1, - second=2, - tzinfo=datetime.timezone.utc, - ), - host="host", - key_pair=key_pair, - signing_domain="x10.exchange", - ).to_json() - - assert_that( - payload, - equal_to( - { - "l1Signature": "9a59eb699eb58f2ec975455f33dd7205c8a569f7b6d7647c25b71e7ab7eec3d30f2b8c9038f06f077167eb90e0c002602e4ecbab180fad4b2c91d2259883e6571c", # noqa: E501 - "l2Key": known_l2_public_key, - "l2Signature": { - "r": "0x70881694c59c7212b1a47fbbc07df4d32678f0326f778861ec3a2a5dbc09157", - "s": "0x558805193faa5d780719cba5f699ae1c888eec1fee23da4215fdd94a744d2cb", - }, - "accountCreation": { - "accountIndex": 0, - "wallet": "0x2c12f074766f5eF9c5300ca8C85d06fBa605C59f", - "tosAccepted": True, - "time": "2024-07-30T16:01:02Z", - "action": "REGISTER", - "host": "host", - }, - "referralCode": None, - } - ), - ) diff --git a/tests/signing/test_onboarding.py b/tests/signing/test_onboarding.py new file mode 100644 index 0000000..83796b4 --- /dev/null +++ b/tests/signing/test_onboarding.py @@ -0,0 +1,90 @@ +from eth_account import Account +from eth_account.messages import SignableMessage +from eth_account.signers.local import LocalAccount +from freezegun import freeze_time +from hamcrest import assert_that, equal_to + +from x10.signing.onboarding import ( + RequestSignature, + get_l2_keys_from_l1_account, + get_onboarding_payload, + sign_api_request, +) +from x10.utils.date import utc_now + +# All known values from authentication service tests are used. +KNOWN_L2_PRIVATE_KEY = "0x7dbb2c8651cc40e1d0d60b45eb52039f317a8aa82798bda52eee272136c0c44" +KNOWN_L2_PUBLIC_KEY = "0x78298687996aff29a0bbcb994e1305db082d084f85ec38bb78c41e6787740ec" + + +@freeze_time("2024-01-05 01:08:56.860694") +def test_sign_api_request(get_eth_private_key): + local_account: LocalAccount = Account.from_key(get_eth_private_key()) + signature = sign_api_request("/action", lambda msg: local_account.sign_message(msg).signature.hex()) + + assert_that( + signature, + equal_to( + RequestSignature( + "f4e4e9aaf2014a3651dfafec63854e4dfd486dcc10e77f56b330e9942630fde03588e43d6c022f8513c1e4cf211e670c3134d3cfdf1bd61b570d2588bfb9fc921b", # noqa: E501 + "2024-01-05T01:08:56Z", + ) + ), + ) + + +@freeze_time("2024-07-30 16:01:02.000000") +def test_onboarding_object_generation(get_eth_private_key): + l1_account = Account.from_key(get_eth_private_key()) + + def sign_message(msg: SignableMessage) -> str: + return l1_account.sign_message(msg).signature.hex() + + key_pair = get_l2_keys_from_l1_account( + account_index=0, account_address=l1_account.address, signing_domain="x10.exchange", sign_message=sign_message + ) + + payload = get_onboarding_payload( + account_address=l1_account.address, + time=utc_now(), + host="host", + key_pair=key_pair, + signing_domain="x10.exchange", + sign_message=sign_message, + ).to_json() + + assert_that( + payload, + equal_to( + { + "l1Signature": "9a59eb699eb58f2ec975455f33dd7205c8a569f7b6d7647c25b71e7ab7eec3d30f2b8c9038f06f077167eb90e0c002602e4ecbab180fad4b2c91d2259883e6571c", # noqa: E501 + "l2Key": KNOWN_L2_PUBLIC_KEY, + "l2Signature": { + "r": "0x70881694c59c7212b1a47fbbc07df4d32678f0326f778861ec3a2a5dbc09157", + "s": "0x558805193faa5d780719cba5f699ae1c888eec1fee23da4215fdd94a744d2cb", + }, + "accountCreation": { + "accountIndex": 0, + "wallet": "0x2c12f074766f5eF9c5300ca8C85d06fBa605C59f", + "tosAccepted": True, + "time": "2024-07-30T16:01:02Z", + "action": "REGISTER", + "host": "host", + }, + "referralCode": None, + } + ), + ) + + +def test_known_l2_accounts(get_eth_private_key): + local_account: LocalAccount = Account.from_key(get_eth_private_key()) + derived_keys = get_l2_keys_from_l1_account( + account_index=0, + account_address=local_account.address, + signing_domain="x10.exchange", + sign_message=lambda msg: local_account.sign_message(msg).signature.hex(), + ) + + assert_that(derived_keys.private_hex, equal_to(KNOWN_L2_PRIVATE_KEY)) + assert_that(derived_keys.public_hex, equal_to(KNOWN_L2_PUBLIC_KEY)) diff --git a/x10/clients/onboarding/__init__.py b/x10/clients/onboarding/__init__.py new file mode 100644 index 0000000..fa2f58e --- /dev/null +++ b/x10/clients/onboarding/__init__.py @@ -0,0 +1 @@ +from x10.clients.onboarding.onboarding_client import OnboardingClient # noqa: F401 diff --git a/x10/perpetual/user_client/__init__.py b/x10/clients/onboarding/modules/__init__.py similarity index 100% rename from x10/perpetual/user_client/__init__.py rename to x10/clients/onboarding/modules/__init__.py diff --git a/x10/clients/onboarding/modules/account_module.py b/x10/clients/onboarding/modules/account_module.py new file mode 100644 index 0000000..9ea1256 --- /dev/null +++ b/x10/clients/onboarding/modules/account_module.py @@ -0,0 +1,32 @@ +from x10.clients.onboarding.modules.base_module import BaseModule +from x10.errors import ValidationError +from x10.models.account import ApiKeyRequestModel, ApiKeyResponseModel +from x10.signing.onboarding import sign_api_request +from x10.utils.http import RequestHeader, send_post_request + + +class AccountModule(BaseModule): + async def create_api_key(self, *, account_id: int, description: str) -> str: + request_path = "/api/v1/user/account/api-key" + signature = sign_api_request(request_path, self._sign_message) + headers: dict[str, str] = { + RequestHeader.AUTH_L1_SIGNATURE: signature.value, + RequestHeader.AUTH_L1_MESSAGE_TIME: signature.time, + RequestHeader.AUTH_ACTIVE_ACCOUNT: str(account_id), + } + + payload = ApiKeyRequestModel(description=description) + url = self._get_url(request_path) + response = await send_post_request( + await self._get_session(), + url, + ApiKeyResponseModel, + json=payload.to_api_request_json(), + request_headers=headers, + ) + response_data = response.data + + if response_data is None: + raise ValidationError("No API key data returned from onboarding") + + return response_data.key diff --git a/x10/clients/onboarding/modules/auth_module.py b/x10/clients/onboarding/modules/auth_module.py new file mode 100644 index 0000000..bf3db0f --- /dev/null +++ b/x10/clients/onboarding/modules/auth_module.py @@ -0,0 +1,89 @@ +from aiohttp.web_exceptions import HTTPConflict + +from x10.clients.onboarding.modules.base_module import BaseModule +from x10.errors import SdkError, ValidationError +from x10.models.account import AccountModel +from x10.models.client import OnboardedClientModel +from x10.signing.onboarding import ( + OnBoardedAccount, + get_l2_keys_from_l1_account, + get_onboarding_payload, + get_sub_account_creation_payload, + sign_api_request, +) +from x10.utils.http import RequestHeader, send_post_request + + +class SubAccountExists(SdkError): + pass + + +class AuthModule(BaseModule): + async def onboard_client(self, *, referral_code: str | None = None) -> OnBoardedAccount: + l2_key_pair = get_l2_keys_from_l1_account( + account_index=0, + account_address=self._get_account_address(), + signing_domain=self._get_config().signing.signing_domain, + sign_message=self._sign_message, + ) + payload = get_onboarding_payload( + account_address=self._get_account_address(), + signing_domain=self._get_config().signing.signing_domain, + key_pair=l2_key_pair, + referral_code=referral_code, + host=self._get_config().endpoints.onboarding_url, + sign_message=self._sign_message, + ) + + url = self._get_url("/auth/onboard") + onboarding_response = await send_post_request( + await self._get_session(), url, OnboardedClientModel, json=payload.to_json() + ) + + onboarded_client = onboarding_response.data + + if onboarded_client is None: + raise ValidationError("No account data returned from onboarding") + + return OnBoardedAccount(account=onboarded_client.default_account, l2_key_pair=l2_key_pair) + + async def onboard_subaccount(self, *, account_index: int, description: str): + request_path = "/auth/onboard/subaccount" + signature = sign_api_request(request_path, self._sign_message) + headers: dict[str, str] = { + RequestHeader.AUTH_L1_SIGNATURE: signature.value, + RequestHeader.AUTH_L1_MESSAGE_TIME: signature.time, + } + + key_pair = get_l2_keys_from_l1_account( + account_index=account_index, + account_address=self._get_account_address(), + signing_domain=self._get_config().signing.signing_domain, + sign_message=self._sign_message, + ) + payload = get_sub_account_creation_payload( + account_index=account_index, + l1_address=self._get_account_address(), + key_pair=key_pair, + description=description, + host=self._get_config().endpoints.onboarding_url, + ) + url = self._get_url(request_path) + + try: + onboarding_response = await send_post_request( + await self._get_session(), + url, + AccountModel, + json=payload.to_json(), + request_headers=headers, + response_code_to_exception={HTTPConflict.status_code: SubAccountExists}, + ) + onboarded_account = onboarding_response.data + except SubAccountExists: + raise ValidationError("Subaccount already exists") + + if onboarded_account is None: + raise ValidationError("No account data returned from onboarding") + + return OnBoardedAccount(account=onboarded_account, l2_key_pair=key_pair) diff --git a/x10/clients/onboarding/modules/base_module.py b/x10/clients/onboarding/modules/base_module.py new file mode 100644 index 0000000..e137b71 --- /dev/null +++ b/x10/clients/onboarding/modules/base_module.py @@ -0,0 +1,51 @@ +from typing import Dict, Optional + +import aiohttp +from aiohttp import ClientTimeout +from eth_account.messages import SignableMessage +from eth_typing import ChecksumAddress + +from x10.config import Config +from x10.signing.onboarding import SignMessageCallback +from x10.utils.http import get_url + + +class BaseModule: + __config: Config + __account_address: ChecksumAddress + __sign_message: SignMessageCallback + __session: Optional[aiohttp.ClientSession] + + def __init__(self, config: Config, *, account_address: ChecksumAddress, sign_message: SignMessageCallback): + super().__init__() + + self.__config = config + self.__account_address = account_address + self.__sign_message = sign_message + self.__session = None + + def _get_url(self, path: str, *, query: Optional[Dict] = None, **path_params) -> str: + return get_url(f"{self.__config.endpoints.onboarding_url}{path}", query=query, **path_params) + + def _get_config(self) -> Config: + return self.__config + + def _get_account_address(self) -> ChecksumAddress: + return self.__account_address + + def _sign_message(self, msg: SignableMessage) -> str: + return self.__sign_message(msg) + + async def _get_session(self) -> aiohttp.ClientSession: + if self.__session is None: + created_session = aiohttp.ClientSession( + timeout=ClientTimeout(total=self.__config.defaults.request_timeout_seconds) + ) + self.__session = created_session + + return self.__session + + async def close_session(self): + if self.__session: + await self.__session.close() + self.__session = None diff --git a/x10/clients/onboarding/onboarding_client.py b/x10/clients/onboarding/onboarding_client.py new file mode 100644 index 0000000..ed38234 --- /dev/null +++ b/x10/clients/onboarding/onboarding_client.py @@ -0,0 +1,33 @@ +from eth_typing import ChecksumAddress + +from x10.clients.onboarding.modules.account_module import AccountModule +from x10.clients.onboarding.modules.auth_module import AuthModule +from x10.config import Config +from x10.signing.onboarding import SignMessageCallback + + +class OnboardingClient: + __account_module: AccountModule + __auth_module: AuthModule + + async def close(self): + await self.__account_module.close_session() + await self.__auth_module.close_session() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + def __init__(self, config: Config, *, account_address: ChecksumAddress, sign_message: SignMessageCallback): + self.__account_module = AccountModule(config, account_address=account_address, sign_message=sign_message) + self.__auth_module = AuthModule(config, account_address=account_address, sign_message=sign_message) + + @property + def account(self): + return self.__account_module + + @property + def auth(self): + return self.__auth_module diff --git a/x10/clients/rest/modules/account_module.py b/x10/clients/rest/modules/account_module.py index 8a54779..366185a 100644 --- a/x10/clients/rest/modules/account_module.py +++ b/x10/clients/rest/modules/account_module.py @@ -31,11 +31,11 @@ class AccountModule(BaseModule): async def get_account(self) -> WrappedApiResponseModel[AccountModel]: url = self._get_url("/user/account/info") - return await send_get_request(await self.get_session(), url, AccountModel, api_key=self._get_api_key()) + return await send_get_request(await self._get_session(), url, AccountModel, api_key=self._get_api_key()) async def get_client(self) -> WrappedApiResponseModel[ClientModel]: url = self._get_url("/user/client/info") - return await send_get_request(await self.get_session(), url, ClientModel, api_key=self._get_api_key()) + return await send_get_request(await self._get_session(), url, ClientModel, api_key=self._get_api_key()) async def get_balance(self) -> WrappedApiResponseModel[BalanceModel]: """ @@ -43,7 +43,7 @@ async def get_balance(self) -> WrappedApiResponseModel[BalanceModel]: """ url = self._get_url("/user/balance") - return await send_get_request(await self.get_session(), url, BalanceModel, api_key=self._get_api_key()) + return await send_get_request(await self._get_session(), url, BalanceModel, api_key=self._get_api_key()) async def get_positions( self, *, market_names: Optional[List[str]] = None, position_side: Optional[PositionSide] = None @@ -53,7 +53,7 @@ async def get_positions( """ url = self._get_url("/user/positions", query={"market": market_names, "side": position_side}) - return await send_get_request(await self.get_session(), url, List[PositionModel], api_key=self._get_api_key()) + return await send_get_request(await self._get_session(), url, List[PositionModel], api_key=self._get_api_key()) async def get_positions_history( self, @@ -71,7 +71,7 @@ async def get_positions_history( query={"market": market_names, "side": position_side, "cursor": cursor, "limit": limit}, ) return await send_get_request( - await self.get_session(), url, List[PositionHistoryModel], api_key=self._get_api_key() + await self._get_session(), url, List[PositionHistoryModel], api_key=self._get_api_key() ) async def get_open_orders( @@ -88,7 +88,7 @@ async def get_open_orders( "/user/orders", query={"market": market_names, "type": order_type, "side": order_side}, ) - return await send_get_request(await self.get_session(), url, List[OpenOrderModel], api_key=self._get_api_key()) + return await send_get_request(await self._get_session(), url, List[OpenOrderModel], api_key=self._get_api_key()) async def get_orders_history( self, @@ -106,7 +106,7 @@ async def get_orders_history( "/user/orders/history", query={"market": market_names, "type": order_type, "side": order_side, "cursor": cursor, "limit": limit}, ) - return await send_get_request(await self.get_session(), url, List[OpenOrderModel], api_key=self._get_api_key()) + return await send_get_request(await self._get_session(), url, List[OpenOrderModel], api_key=self._get_api_key()) async def get_order_by_id(self, order_id: int) -> WrappedApiResponseModel[OpenOrderModel]: """ @@ -115,7 +115,7 @@ async def get_order_by_id(self, order_id: int) -> WrappedApiResponseModel[OpenOr url = self._get_url("/user/orders/", order_id=order_id) - return await send_get_request(await self.get_session(), url, OpenOrderModel, api_key=self._get_api_key()) + return await send_get_request(await self._get_session(), url, OpenOrderModel, api_key=self._get_api_key()) async def get_order_by_external_id(self, external_id: str) -> WrappedApiResponseModel[list[OpenOrderModel]]: """ @@ -124,7 +124,7 @@ async def get_order_by_external_id(self, external_id: str) -> WrappedApiResponse url = self._get_url("/user/orders/external/", external_id=external_id) - return await send_get_request(await self.get_session(), url, list[OpenOrderModel], api_key=self._get_api_key()) + return await send_get_request(await self._get_session(), url, list[OpenOrderModel], api_key=self._get_api_key()) async def get_spot_balances(self) -> WrappedApiResponseModel[List[SpotBalanceModel]]: """ @@ -133,7 +133,7 @@ async def get_spot_balances(self) -> WrappedApiResponseModel[List[SpotBalanceMod url = self._get_url("/user/spot/balances") return await send_get_request( - await self.get_session(), url, List[SpotBalanceModel], api_key=self._get_api_key() + await self._get_session(), url, List[SpotBalanceModel], api_key=self._get_api_key() ) async def get_trades( @@ -154,7 +154,7 @@ async def get_trades( ) return await send_get_request( - await self.get_session(), url, List[AccountTradeModel], api_key=self._get_api_key() + await self._get_session(), url, List[AccountTradeModel], api_key=self._get_api_key() ) async def get_fees( @@ -171,7 +171,9 @@ async def get_fees( "builderId": builder_id, }, ) - return await send_get_request(await self.get_session(), url, List[TradingFeeModel], api_key=self._get_api_key()) + return await send_get_request( + await self._get_session(), url, List[TradingFeeModel], api_key=self._get_api_key() + ) async def get_leverage( self, market_names: Optional[List[str]] = None @@ -182,7 +184,7 @@ async def get_leverage( url = self._get_url("/user/leverage", query={"market": market_names}) return await send_get_request( - await self.get_session(), url, List[AccountLeverageModel], api_key=self._get_api_key() + await self._get_session(), url, List[AccountLeverageModel], api_key=self._get_api_key() ) async def update_leverage(self, market_name: str, leverage: Decimal) -> WrappedApiResponseModel[EmptyModel]: @@ -193,7 +195,7 @@ async def update_leverage(self, market_name: str, leverage: Decimal) -> WrappedA url = self._get_url("/user/leverage") request_model = AccountLeverageModel(market=market_name, leverage=leverage) return await send_patch_request( - await self.get_session(), + await self._get_session(), url, EmptyModel, json=request_model.to_api_request_json(), @@ -202,7 +204,7 @@ async def update_leverage(self, market_name: str, leverage: Decimal) -> WrappedA async def get_bridge_config(self) -> WrappedApiResponseModel[BridgesConfigModel]: url = self._get_url("/user/bridge/config") - return await send_get_request(await self.get_session(), url, BridgesConfigModel, api_key=self._get_api_key()) + return await send_get_request(await self._get_session(), url, BridgesConfigModel, api_key=self._get_api_key()) async def get_bridge_quote( self, chain_in: str, chain_out: str, amount: Decimal @@ -215,7 +217,7 @@ async def get_bridge_quote( "amount": amount, }, ) - return await send_get_request(await self.get_session(), url, QuoteModel, api_key=self._get_api_key()) + return await send_get_request(await self._get_session(), url, QuoteModel, api_key=self._get_api_key()) async def commit_bridge_quote(self, id: str): url = self._get_url( @@ -224,7 +226,7 @@ async def commit_bridge_quote(self, id: str): "id": id, }, ) - await send_post_request(await self.get_session(), url, EmptyModel, api_key=self._get_api_key()) + await send_post_request(await self._get_session(), url, EmptyModel, api_key=self._get_api_key()) async def transfer( self, @@ -244,13 +246,13 @@ async def transfer( to_vault=to_vault, to_l2_key=to_l2_key, amount=amount, - config=self._get_endpoint_config(), + config=self._get_config(), stark_account=self._get_stark_account(), nonce=nonce, ) return await send_post_request( - await self.get_session(), + await self._get_session(), url, TransferResponseModel, json=request_model.to_api_request_json(), @@ -267,42 +269,48 @@ async def withdraw( ) -> WrappedApiResponseModel[int]: url = self._get_url("/user/withdrawal") account = (await self.get_account()).data + if account is None: raise ValidationError("Account not found") + if quote_id is None and chain_id != "STRK": raise ValidationError("quote_id is required for EVM withdrawals") - recipient_stark_address = None - if stark_address is None: + async def get_recipient_stark_address() -> str: + if stark_address: + return stark_address + if chain_id == "STRK": client = (await self.get_client()).data + if client is None: raise ValidationError("Client not found") + if client.starknet_wallet_address is None: raise ValidationError( - "Client does not have attached starknet_wallet_address. Can't determine withdrawal address." + "Client does not have attached `starknet_wallet_address`. Can't determine withdrawal address." ) - else: - recipient_stark_address = client.starknet_wallet_address - else: - if account.bridge_starknet_address is None: - raise ValidationError("Account bridge_starknet_address not found") - recipient_stark_address = account.bridge_starknet_address - else: - recipient_stark_address = stark_address + return client.starknet_wallet_address + + if account.bridge_starknet_address is None: + raise ValidationError("Account `bridge_starknet_address` not found") + + return account.bridge_starknet_address + + recipient_stark_address = await get_recipient_stark_address() request_model = create_withdrawal_object( amount=amount, recipient_stark_address=recipient_stark_address, stark_account=self._get_stark_account(), - config=self._get_endpoint_config(), + config=self._get_config(), account_id=account.id, chain_id=chain_id, quote_id=quote_id, nonce=nonce, ) return await send_post_request( - await self.get_session(), + await self._get_session(), url, int, json=request_model.to_api_request_json(), @@ -334,5 +342,5 @@ async def asset_operations( }, ) return await send_get_request( - await self.get_session(), url, List[AssetOperationModel], api_key=self._get_api_key() + await self._get_session(), url, List[AssetOperationModel], api_key=self._get_api_key() ) diff --git a/x10/clients/rest/modules/base_module.py b/x10/clients/rest/modules/base_module.py index 9725869..f099e9a 100644 --- a/x10/clients/rest/modules/base_module.py +++ b/x10/clients/rest/modules/base_module.py @@ -23,16 +23,17 @@ def __init__( stark_account: Optional[StarkPerpetualAccount] = None, ): super().__init__() + self.__config = config self.__api_key = api_key self.__stark_account = stark_account self.__session = None def _get_url(self, path: str, *, query: Optional[Dict] = None, **path_params) -> str: - return get_url(f"{self._get_endpoint_config().api_base_url}{path}", query=query, **path_params) + return get_url(f"{self.__config.endpoints.api_base_url}{path}", query=query, **path_params) - def _get_endpoint_config(self): - return self.__config.endpoints + def _get_config(self) -> Config: + return self.__config def _get_api_key(self): if not self.__api_key: @@ -46,7 +47,7 @@ def _get_stark_account(self): return self.__stark_account - async def get_session(self) -> aiohttp.ClientSession: + async def _get_session(self) -> aiohttp.ClientSession: if self.__session is None: created_session = aiohttp.ClientSession( timeout=ClientTimeout(total=self.__config.defaults.request_timeout_seconds) diff --git a/x10/clients/rest/modules/info_module.py b/x10/clients/rest/modules/info_module.py index d462721..11c8fe1 100644 --- a/x10/clients/rest/modules/info_module.py +++ b/x10/clients/rest/modules/info_module.py @@ -16,11 +16,11 @@ class InfoModule(BaseModule): async def get_settings(self): url = self._get_url("/info/settings") - return await send_get_request(await self.get_session(), url, SettingsModel) + return await send_get_request(await self._get_session(), url, SettingsModel) async def get_assets(self): url = self._get_url("/info/assets") - return await send_get_request(await self.get_session(), url, List[AssetModel]) + return await send_get_request(await self._get_session(), url, List[AssetModel]) async def get_assets_dict(self): assets = await self.get_assets() @@ -28,7 +28,7 @@ async def get_assets_dict(self): async def get_asset_price(self, *, asset_name: str): url = self._get_url("/info/assets//price", asset_name=asset_name) - return await send_get_request(await self.get_session(), url, Decimal) + return await send_get_request(await self._get_session(), url, Decimal) async def get_markets(self, *, market_names: Optional[List[str]] = None): """ @@ -36,7 +36,7 @@ async def get_markets(self, *, market_names: Optional[List[str]] = None): """ url = self._get_url("/info/markets", query={"market": market_names}) - return await send_get_request(await self.get_session(), url, List[MarketModel]) + return await send_get_request(await self._get_session(), url, List[MarketModel]) async def get_markets_dict(self): markets = await self.get_markets() @@ -48,7 +48,7 @@ async def get_market_statistics(self, *, market_name: str): """ url = self._get_url("/info/markets//stats", market=market_name) - return await send_get_request(await self.get_session(), url, MarketStatsModel) + return await send_get_request(await self._get_session(), url, MarketStatsModel) async def get_candles_history( self, @@ -73,7 +73,7 @@ async def get_candles_history( "endTime": to_epoch_millis(end_time) if end_time else None, }, ) - return await send_get_request(await self.get_session(), url, List[CandleModel]) + return await send_get_request(await self._get_session(), url, List[CandleModel]) async def get_funding_rates_history(self, *, market_name: str, start_time: datetime, end_time: datetime): """ @@ -88,7 +88,7 @@ async def get_funding_rates_history(self, *, market_name: str, start_time: datet "endTime": to_epoch_millis(end_time), }, ) - return await send_get_request(await self.get_session(), url, List[FundingRateModel]) + return await send_get_request(await self._get_session(), url, List[FundingRateModel]) async def get_orderbook_snapshot(self, *, market_name: str): """ @@ -96,4 +96,4 @@ async def get_orderbook_snapshot(self, *, market_name: str): """ url = self._get_url("/info/markets//orderbook", market=market_name) - return await send_get_request(await self.get_session(), url, OrderbookUpdateModel) + return await send_get_request(await self._get_session(), url, OrderbookUpdateModel) diff --git a/x10/clients/rest/modules/order_management_module.py b/x10/clients/rest/modules/order_management_module.py index e5ff659..b942c84 100644 --- a/x10/clients/rest/modules/order_management_module.py +++ b/x10/clients/rest/modules/order_management_module.py @@ -22,7 +22,7 @@ async def place_order(self, order: NewOrderModel): url = self._get_url("/user/order") response = await send_post_request( - await self.get_session(), + await self._get_session(), url, PlacedOrderModel, json=order.to_api_request_json(exclude_none=True), @@ -36,7 +36,7 @@ async def cancel_order(self, order_id: int): """ url = self._get_url("/user/order/", order_id=order_id) - return await send_delete_request(await self.get_session(), url, EmptyModel, api_key=self._get_api_key()) + return await send_delete_request(await self._get_session(), url, EmptyModel, api_key=self._get_api_key()) async def cancel_order_by_external_id(self, order_external_id: str): """ @@ -44,7 +44,7 @@ async def cancel_order_by_external_id(self, order_external_id: str): """ url = self._get_url("/user/order", query={"externalId": order_external_id}) - return await send_delete_request(await self.get_session(), url, EmptyModel, api_key=self._get_api_key()) + return await send_delete_request(await self._get_session(), url, EmptyModel, api_key=self._get_api_key()) async def mass_cancel( self, @@ -66,7 +66,7 @@ async def mass_cancel( cancel_all=cancel_all, ) return await send_post_request( - await self.get_session(), + await self._get_session(), url, EmptyModel, json=request_model.to_api_request_json(exclude_none=True), diff --git a/x10/clients/rest/modules/testnet_module.py b/x10/clients/rest/modules/testnet_module.py index 0c742ca..5132c31 100644 --- a/x10/clients/rest/modules/testnet_module.py +++ b/x10/clients/rest/modules/testnet_module.py @@ -14,18 +14,18 @@ class TestnetModule(BaseModule): def __init__( self, config: Config, - api_key: Optional[str] = None, + *, account_module: Optional[AccountModule] = None, + api_key: Optional[str] = None, ): super().__init__(config, api_key=api_key) + self._account_module = account_module - async def claim_testing_funds( - self, - ) -> WrappedApiResponseModel[ClaimResponseModel]: + async def claim_testing_funds(self) -> WrappedApiResponseModel[ClaimResponseModel]: url = self._get_url("/user/claim") resp = await send_post_request( - await self.get_session(), + await self._get_session(), url, ClaimResponseModel, json={}, @@ -34,6 +34,7 @@ async def claim_testing_funds( if resp.error: return resp + if self._account_module and resp.data: account_module = self._account_module claim_to_check = resp.data.id @@ -61,4 +62,5 @@ async def wait_for_claim_to_complete() -> List[AssetOperationModel]: await wait_for_claim_to_complete() except tenacity.RetryError: pass + return resp diff --git a/x10/clients/rest/modules/vault_module.py b/x10/clients/rest/modules/vault_module.py index 7f3bf7e..e9c1b89 100644 --- a/x10/clients/rest/modules/vault_module.py +++ b/x10/clients/rest/modules/vault_module.py @@ -38,7 +38,7 @@ async def get_vault_share_balance(self) -> Decimal: spot_balances = (await self._account_module.get_spot_balances()).data if spot_balances is None: raise ValidationError("Failed to get spot balances") - vault_asset_balances = filter(lambda b: b.asset == self._get_endpoint_config().vault_asset_name, spot_balances) + vault_asset_balances = filter(lambda b: b.asset == self._get_config().endpoints.vault_asset_name, spot_balances) total_vault_asset_balance = sum(map(lambda b: b.balance, vault_asset_balances), Decimal(0)) return total_vault_asset_balance @@ -49,7 +49,7 @@ async def deposit_to_vault(self, *, collateral_amount: Decimal) -> None: account_info = (await self._account_module.get_account()).data assets = await self._info_module.get_assets_dict() vault_asset_price = ( - await self._info_module.get_asset_price(asset_name=self._get_endpoint_config().vault_asset_name) + await self._info_module.get_asset_price(asset_name=self._get_config().endpoints.vault_asset_name) ).data assert account_info is not None @@ -57,7 +57,7 @@ async def deposit_to_vault(self, *, collateral_amount: Decimal) -> None: position_id = account_info.l2_vault collateral_asset = assets[COLLATERAL_ASSET_NAME] - vault_asset = assets[self._get_endpoint_config().vault_asset_name] + vault_asset = assets[self._get_config().endpoints.vault_asset_name] vault_shares_expected = self.__calc_vault_shares_expected( collateral_amount, vault_asset_price, @@ -71,7 +71,7 @@ async def deposit_to_vault(self, *, collateral_amount: Decimal) -> None: quote_asset_model=collateral_asset, base_asset_model=vault_asset, starknet_account=self._account, - starknet_domain=self._get_endpoint_config().starknet_domain, + starknet_domain=self._get_config().signing.starknet_domain, is_buy=True, ) deposit_request = DepositRequestModel( @@ -84,7 +84,7 @@ async def deposit_to_vault(self, *, collateral_amount: Decimal) -> None: url = self._get_url("/vault/user/deposits") resp = await send_post_request( - await self.get_session(), + await self._get_session(), url, NoneType, json=deposit_request.to_api_request_json(exclude_none=True), @@ -101,7 +101,7 @@ async def withdraw_from_vault(self, *, shares_amount: Decimal) -> None: assets = await self._info_module.get_assets_dict() account_info = (await self._account_module.get_account()).data vault_asset_price = ( - await self._info_module.get_asset_price(asset_name=self._get_endpoint_config().vault_asset_name) + await self._info_module.get_asset_price(asset_name=self._get_config().endpoints.vault_asset_name) ).data assert account_info is not None @@ -109,7 +109,7 @@ async def withdraw_from_vault(self, *, shares_amount: Decimal) -> None: position_id = account_info.l2_vault collateral_asset = assets[COLLATERAL_ASSET_NAME] - vault_asset = assets[self._get_endpoint_config().vault_asset_name] + vault_asset = assets[self._get_config().endpoints.vault_asset_name] collateral_amount_expected = self.__calc_collateral_amount_expected( shares_amount, vault_asset_price, @@ -123,7 +123,7 @@ async def withdraw_from_vault(self, *, shares_amount: Decimal) -> None: quote_asset_model=collateral_asset, base_asset_model=vault_asset, starknet_account=self._account, - starknet_domain=self._get_endpoint_config().starknet_domain, + starknet_domain=self._get_config().signing.starknet_domain, is_buy=False, ) withdraw_request = WithdrawRequestModel( @@ -135,7 +135,7 @@ async def withdraw_from_vault(self, *, shares_amount: Decimal) -> None: ) url = self._get_url("/vault/user/withdrawals") resp = await send_post_request( - await self.get_session(), + await self._get_session(), url, NoneType, json=withdraw_request.to_api_request_json(exclude_none=True), diff --git a/x10/perpetual/user_client/user_client.py b/x10/perpetual/user_client/user_client.py deleted file mode 100644 index 5059cf2..0000000 --- a/x10/perpetual/user_client/user_client.py +++ /dev/null @@ -1,203 +0,0 @@ -from dataclasses import dataclass -from datetime import datetime, timezone -from typing import Callable, Dict, List, Optional - -import aiohttp -from aiohttp import ClientTimeout -from aiohttp.web_exceptions import HTTPConflict -from eth_account import Account -from eth_account.messages import encode_defunct -from eth_account.signers.local import LocalAccount - -from x10.config import Config -from x10.errors import SdkError, ValidationError -from x10.models.account import AccountModel, ApiKeyRequestModel, ApiKeyResponseModel -from x10.models.client import OnboardedClientModel -from x10.perpetual.user_client.onboarding import ( - StarkKeyPair, - get_l2_keys_from_l1_account, - get_onboarding_payload, - get_sub_account_creation_payload, -) -from x10.utils.http import get_url, send_get_request, send_post_request - -L1_AUTH_SIGNATURE_HEADER = "L1_SIGNATURE" -L1_MESSAGE_TIME_HEADER = "L1_MESSAGE_TIME" -ACTIVE_ACCOUNT_HEADER = "X-X10-ACTIVE-ACCOUNT" - - -class SubAccountExists(SdkError): - pass - - -@dataclass(frozen=True) -class OnBoardedAccount: - account: AccountModel - l2_key_pair: StarkKeyPair - - -class UserClient: - __config: Config - __l1_private_key: Callable[[], str] - __session: Optional[aiohttp.ClientSession] = None - - def __init__( - self, - config: Config, - l1_private_key: Callable[[], str], - ): - super().__init__() - self.__config = config - self.__l1_private_key = l1_private_key - - def _get_url(self, base_url: str, path: str, *, query: Optional[Dict] = None, **path_params) -> str: - return get_url(f"{base_url}{path}", query=query, **path_params) - - async def get_session(self) -> aiohttp.ClientSession: - if self.__session is None: - created_session = aiohttp.ClientSession( - timeout=ClientTimeout(total=self.__config.defaults.request_timeout_seconds) - ) - self.__session = created_session - - return self.__session - - async def close_session(self): - if self.__session: - await self.__session.close() - self.__session = None - - async def onboard(self, referral_code: Optional[str] = None): - signing_account: LocalAccount = Account.from_key(self.__l1_private_key()) - key_pair = get_l2_keys_from_l1_account( - l1_account=signing_account, account_index=0, signing_domain=self.__config.signing.signing_domain - ) - payload = get_onboarding_payload( - signing_account, - signing_domain=self.__config.signing.signing_domain, - key_pair=key_pair, - referral_code=referral_code, - host=self._get_endpoint_config().onboarding_url, - ) - url = self._get_url(self._get_endpoint_config().onboarding_url, path="/auth/onboard") - onboarding_response = await send_post_request( - await self.get_session(), url, OnboardedClientModel, json=payload.to_json() - ) - - onboarded_client = onboarding_response.data - if onboarded_client is None: - raise ValidationError("No account data returned from onboarding") - - return OnBoardedAccount(account=onboarded_client.default_account, l2_key_pair=key_pair) - - async def onboard_subaccount(self, account_index: int, description: str | None = None): - request_path = "/auth/onboard/subaccount" - if description is None: - description = f"Subaccount {account_index}" - - signing_account: LocalAccount = Account.from_key(self.__l1_private_key()) - time = datetime.now(timezone.utc) - auth_time_string = time.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - l1_message = f"{request_path}@{auth_time_string}".encode(encoding="utf-8") - signable_message = encode_defunct(l1_message) - l1_signature = signing_account.sign_message(signable_message) - key_pair = get_l2_keys_from_l1_account( - l1_account=signing_account, - account_index=account_index, - signing_domain=self.__config.signing.signing_domain, - ) - payload = get_sub_account_creation_payload( - account_index=account_index, - l1_address=signing_account.address, - key_pair=key_pair, - description=description, - host=self._get_endpoint_config().onboarding_url, - ) - headers = { - L1_AUTH_SIGNATURE_HEADER: l1_signature.signature.hex(), - L1_MESSAGE_TIME_HEADER: auth_time_string, - } - url = self._get_url(self._get_endpoint_config().onboarding_url, path=request_path) - - try: - onboarding_response = await send_post_request( - await self.get_session(), - url, - AccountModel, - json=payload.to_json(), - request_headers=headers, - response_code_to_exception={HTTPConflict.status_code: SubAccountExists}, - ) - onboarded_account = onboarding_response.data - except SubAccountExists: - client_accounts = await self.get_accounts() - account_with_index = [ - account for account in client_accounts if account.account.account_index == account_index - ] - if not account_with_index: - raise ValidationError("Subaccount already exists but not found in client accounts") - onboarded_account = account_with_index[0].account - if onboarded_account is None: - raise ValidationError("No account data returned from onboarding") - return OnBoardedAccount(account=onboarded_account, l2_key_pair=key_pair) - - async def get_accounts(self) -> List[OnBoardedAccount]: - request_path = "/api/v1/user/accounts" - signing_account: LocalAccount = Account.from_key(self.__l1_private_key()) - time = datetime.now(timezone.utc) - auth_time_string = time.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - l1_message = f"{request_path}@{auth_time_string}".encode(encoding="utf-8") - signable_message = encode_defunct(l1_message) - l1_signature = signing_account.sign_message(signable_message) - headers = { - L1_AUTH_SIGNATURE_HEADER: l1_signature.signature.hex(), - L1_MESSAGE_TIME_HEADER: auth_time_string, - } - url = self._get_url(self._get_endpoint_config().onboarding_url, path=request_path) - response = await send_get_request(await self.get_session(), url, List[AccountModel], request_headers=headers) - accounts = response.data or [] - - return [ - OnBoardedAccount( - account=account, - l2_key_pair=get_l2_keys_from_l1_account( - l1_account=signing_account, - account_index=account.account_index, - signing_domain=self.__config.signing.signing_domain, - ), - ) - for account in accounts - ] - - async def create_account_api_key(self, account: AccountModel, description: str | None) -> str: - request_path = "/api/v1/user/account/api-key" - if description is None: - description = "trading api key for account {}".format(account.id) - - signing_account: LocalAccount = Account.from_key(self.__l1_private_key()) - time = datetime.now(timezone.utc) - auth_time_string = time.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - l1_message = f"{request_path}@{auth_time_string}".encode(encoding="utf-8") - signable_message = encode_defunct(l1_message) - l1_signature = signing_account.sign_message(signable_message) - headers = { - L1_AUTH_SIGNATURE_HEADER: l1_signature.signature.hex(), - L1_MESSAGE_TIME_HEADER: auth_time_string, - ACTIVE_ACCOUNT_HEADER: str(account.id), - } - url = self._get_url(self._get_endpoint_config().onboarding_url, path=request_path) - request = ApiKeyRequestModel(description=description) - response = await send_post_request( - await self.get_session(), - url, - ApiKeyResponseModel, - json=request.to_api_request_json(), - request_headers=headers, - ) - response_data = response.data - if response_data is None: - raise ValidationError("No API key data returned from onboarding") - return response_data.key - - def _get_endpoint_config(self): - return self.__config.endpoints diff --git a/x10/perpetual/user_client/onboarding.py b/x10/signing/onboarding.py similarity index 75% rename from x10/perpetual/user_client/onboarding.py rename to x10/signing/onboarding.py index df86956..f1fdad6 100644 --- a/x10/perpetual/user_client/onboarding.py +++ b/x10/signing/onboarding.py @@ -1,14 +1,21 @@ from dataclasses import dataclass from datetime import datetime, timezone from functools import cached_property +from typing import Callable, NamedTuple, TypeAlias -from eth_account.messages import SignableMessage, encode_typed_data -from eth_account.signers.local import LocalAccount +from eth_account.messages import SignableMessage, encode_defunct, encode_typed_data +from eth_typing import ChecksumAddress from fast_stark_crypto import generate_keypair_from_eth_signature, pedersen_hash from fast_stark_crypto import sign as stark_sign -register_action = "REGISTER" -sub_account_action = "CREATE_SUB_ACCOUNT" +from x10.models.account import AccountModel +from x10.utils.date import utc_now + +ACTION_REGISTER = "REGISTER" +ACTION_CREATE_SUB_ACCOUNT = "CREATE_SUB_ACCOUNT" + + +SignMessageCallback: TypeAlias = Callable[[SignableMessage], str] @dataclass(frozen=True) @@ -25,6 +32,12 @@ def private_hex(self) -> str: return hex(self.private) +@dataclass(frozen=True) +class OnBoardedAccount: + account: AccountModel + l2_key_pair: StarkKeyPair + + @dataclass(frozen=True) class AccountRegistration: account_index: int @@ -125,7 +138,7 @@ def to_json(self): def get_registration_struct_to_sign( - account_index: int, address: str, timestamp: datetime, action: str, host: str + *, account_index: int, address: ChecksumAddress, timestamp: datetime, action: str, host: str ) -> AccountRegistration: return AccountRegistration( account_index=account_index, @@ -137,7 +150,9 @@ def get_registration_struct_to_sign( ) -def get_key_derivation_struct_to_sign(account_index: int, address: str, signing_domain: str) -> SignableMessage: +def get_key_derivation_struct_to_sign( + *, account_index: int, address: ChecksumAddress, signing_domain: str +) -> SignableMessage: primary_type = "AccountCreation" domain = {"name": signing_domain} message = { @@ -164,35 +179,39 @@ def get_key_derivation_struct_to_sign(account_index: int, address: str, signing_ return encode_typed_data(full_message=structured_data) -def get_l2_keys_from_l1_account(l1_account: LocalAccount, account_index: int, signing_domain: str) -> StarkKeyPair: +def get_l2_keys_from_l1_account( + *, account_index: int, account_address: ChecksumAddress, signing_domain: str, sign_message: SignMessageCallback +) -> StarkKeyPair: struct = get_key_derivation_struct_to_sign( account_index=account_index, - address=l1_account.address, + address=account_address, signing_domain=signing_domain, ) - s = l1_account.sign_message(struct) - (private, public) = generate_keypair_from_eth_signature(s.signature.hex()) + s = sign_message(struct) + (private, public) = generate_keypair_from_eth_signature(s) return StarkKeyPair(private=private, public=public) def get_onboarding_payload( - account: LocalAccount, + *, + account_address: ChecksumAddress, signing_domain: str, key_pair: StarkKeyPair, host: str, time: datetime | None = None, referral_code: str | None = None, + sign_message: SignMessageCallback, ) -> OnboardingPayLoad: if time is None: time = datetime.now(timezone.utc) registration_payload = get_registration_struct_to_sign( - account_index=0, address=account.address, timestamp=time, action=register_action, host=host + account_index=0, address=account_address, timestamp=time, action=ACTION_REGISTER, host=host ) payload = registration_payload.to_signable_message(signing_domain=signing_domain) - l1_signature = account.sign_message(payload).signature.hex() + l1_signature = sign_message(payload) - l2_message = pedersen_hash(int(account.address, 16), key_pair.public) + l2_message = pedersen_hash(int(account_address, 16), key_pair.public) l2_r, l2_s = stark_sign(msg_hash=l2_message, private_key=key_pair.private) onboarding_payload = OnboardingPayLoad( @@ -207,8 +226,9 @@ def get_onboarding_payload( def get_sub_account_creation_payload( + *, account_index: int, - l1_address: str, + l1_address: ChecksumAddress, key_pair: StarkKeyPair, description: str, host: str, @@ -218,7 +238,7 @@ def get_sub_account_creation_payload( time = datetime.now(timezone.utc) registration_payload = get_registration_struct_to_sign( - account_index=account_index, address=l1_address, timestamp=time, action=sub_account_action, host=host + account_index=account_index, address=l1_address, timestamp=time, action=ACTION_CREATE_SUB_ACCOUNT, host=host ) l2_message = pedersen_hash(int(l1_address, 16), key_pair.public) @@ -231,3 +251,18 @@ def get_sub_account_creation_payload( account_registration=registration_payload, description=description, ) + + +class RequestSignature(NamedTuple): + value: str + time: str + + +def sign_api_request(request_path: str, sign_message: SignMessageCallback) -> RequestSignature: + now = utc_now() + now_as_string = now.strftime("%Y-%m-%dT%H:%M:%SZ") + l1_message = f"{request_path}@{now_as_string}".encode(encoding="utf-8") + encoded_l1_message = encode_defunct(l1_message) + l1_signature = sign_message(encoded_l1_message) + + return RequestSignature(l1_signature, now_as_string) diff --git a/x10/utils/http.py b/x10/utils/http.py index 2b00e0f..a179e6f 100644 --- a/x10/utils/http.py +++ b/x10/utils/http.py @@ -34,6 +34,10 @@ class RequestHeader(StrEnum): CONTENT_TYPE = "Content-Type" USER_AGENT = "User-Agent" + AUTH_ACTIVE_ACCOUNT = "X-X10-ACTIVE-ACCOUNT" + AUTH_L1_SIGNATURE = "L1_SIGNATURE" + AUTH_L1_MESSAGE_TIME = "L1_MESSAGE_TIME" + def parse_response_to_model( response_text: str, model_class: Type[ApiResponseType]