Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions roborock/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,7 @@ class RoborockTooManyRequest(RoborockException):

class RoborockRateLimit(RoborockException):
"""Class for our rate limits exceptions."""


class RoborockNoResponseFromBaseURL(RoborockException):
"""We could not find an url that had a record of the given account."""
130 changes: 87 additions & 43 deletions roborock/web_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import secrets
import string
import time
from dataclasses import dataclass

import aiohttp
from aiohttp import ContentTypeError, FormData
Expand All @@ -22,14 +23,28 @@
RoborockInvalidEmail,
RoborockInvalidUserAgreement,
RoborockMissingParameters,
RoborockNoResponseFromBaseURL,
RoborockNoUserAgreement,
RoborockRateLimit,
RoborockTooFrequentCodeRequests,
RoborockTooManyRequest,
RoborockUrlException,
)

_LOGGER = logging.getLogger(__name__)
BASE_URLS = [
"https://usiot.roborock.com",
"https://euiot.roborock.com",
"https://cniot.roborock.com",
"https://ruiot.roborock.com",
]


@dataclass
class IotLoginInfo:
"""Information about the login to the iot server."""

base_url: str
country_code: str
country: str


class RoborockApiClient:
Expand All @@ -49,41 +64,64 @@ class RoborockApiClient:
_login_limiter = Limiter(_LOGIN_RATES)
_home_data_limiter = Limiter(_HOME_DATA_RATES)

def __init__(self, username: str, base_url=None, session: aiohttp.ClientSession | None = None) -> None:
def __init__(
self, username: str, base_url: str | None = None, session: aiohttp.ClientSession | None = None
) -> None:
"""Sample API Client."""
self._username = username
self._default_url = "https://euiot.roborock.com"
self.base_url = base_url
self._base_url = base_url
self._device_identifier = secrets.token_urlsafe(16)
self.session = session

async def _get_base_url(self) -> str:
if not self.base_url:
url_request = PreparedRequest(self._default_url, self.session)
response = await url_request.request(
"post",
"/api/v1/getUrlByEmail",
params={"email": self._username, "needtwostepauth": "false"},
)
if response is None:
raise RoborockUrlException("get url by email returned None")
response_code = response.get("code")
if response_code != 200:
_LOGGER.info("Get base url failed for %s with the following context: %s", self._username, response)
if response_code == 2003:
raise RoborockInvalidEmail("Your email was incorrectly formatted.")
elif response_code == 1001:
raise RoborockMissingParameters(
"You are missing parameters for this request, are you sure you entered your username?"
self._iot_login_info: IotLoginInfo | None = None

async def _get_iot_login_info(self) -> IotLoginInfo:
Comment thread
Lash-L marked this conversation as resolved.
if self._iot_login_info is None:
valid_urls = BASE_URLS if self._base_url is None else [self._base_url]
for iot_url in valid_urls:
url_request = PreparedRequest(iot_url, self.session)
response = await url_request.request(
"post",
"/api/v1/getUrlByEmail",
params={"email": self._username, "needtwostepauth": "false"},
)
if response is None:
continue
response_code = response.get("code")
if response_code != 200:
if response_code == 2003:
raise RoborockInvalidEmail("Your email was incorrectly formatted.")
elif response_code == 1001:
raise RoborockMissingParameters(
"You are missing parameters for this request, are you sure you entered your username?"
)
else:
raise RoborockException(f"{response.get('msg')} - response code: {response_code}")
if response["data"]["countrycode"] is not None:
self._iot_login_info = IotLoginInfo(
base_url=response["data"]["url"],
country=response["data"]["country"],
country_code=response["data"]["countrycode"],
)
elif response_code == 9002:
raise RoborockTooManyRequest("Please temporarily disable making requests and try again later.")
raise RoborockUrlException(f"error code: {response_code} msg: {response.get('error')}")
response_data = response.get("data")
if response_data is None:
raise RoborockUrlException("response does not have 'data'")
self.base_url = response_data.get("url")
return self.base_url
return self._iot_login_info
raise RoborockNoResponseFromBaseURL(
"No account was found for any base url we tried. Either your email is incorrect or we do not have a"
" record of the roborock server your device is on."
)
return self._iot_login_info

@property
async def base_url(self):
if self._base_url is not None:
return self._base_url
return (await self._get_iot_login_info()).base_url

@property
async def country(self):
return (await self._get_iot_login_info()).country

@property
async def country_code(self):
return (await self._get_iot_login_info()).country_code

def _get_header_client_id(self):
md5 = hashlib.md5()
Expand Down Expand Up @@ -167,7 +205,7 @@ async def request_code(self) -> None:
except BucketFullException as ex:
_LOGGER.info(ex.meta_info)
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
code_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})

Expand Down Expand Up @@ -198,7 +236,7 @@ async def request_code_v4(self) -> None:
except BucketFullException as ex:
_LOGGER.info(ex.meta_info)
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
code_request = PreparedRequest(
base_url,
Expand Down Expand Up @@ -229,7 +267,7 @@ async def request_code_v4(self) -> None:

async def _sign_key_v3(self, s: str) -> str:
"""Sign a randomly generated string."""
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
code_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})

Expand All @@ -249,14 +287,20 @@ async def _sign_key_v3(self, s: str) -> str:

return code_response["data"]["k"]

async def code_login_v4(self, code: int | str, country: str, country_code: int) -> UserData:
async def code_login_v4(
self, code: int | str, country: str | None = None, country_code: int | None = None
) -> UserData:
"""
Login via code authentication.
:param code: The code from the email.
:param country: The two-character representation of the country, i.e. "US"
:param country_code: the country phone number code i.e. 1 for US.
"""
base_url = await self._get_base_url()
base_url = await self.base_url
if country is None:
country = await self.country
if country_code is None:
country_code = await self.country_code
header_clientid = self._get_header_client_id()
x_mercy_ks = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(16))
x_mercy_k = await self._sign_key_v3(x_mercy_ks)
Expand Down Expand Up @@ -304,7 +348,7 @@ async def pass_login(self, password: str) -> UserData:
except BucketFullException as ex:
_LOGGER.info(ex.meta_info)
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()

login_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
Expand Down Expand Up @@ -343,7 +387,7 @@ async def pass_login_v3(self, password: str) -> UserData:
raise NotImplementedError("Pass_login_v3 has not yet been implemented")

async def code_login(self, code: int | str) -> UserData:
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()

login_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
Expand Down Expand Up @@ -376,7 +420,7 @@ async def code_login(self, code: int | str) -> UserData:
return UserData.from_dict(user_data)

async def _get_home_id(self, user_data: UserData):
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
home_id_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
home_id_response = await home_id_request.request(
Expand Down Expand Up @@ -547,7 +591,7 @@ async def execute_scene(self, user_data: UserData, scene_id: int) -> None:

async def get_products(self, user_data: UserData) -> ProductResponse:
"""Gets all products and their schemas, good for determining status codes and model numbers."""
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
product_response = await product_request.request(
Expand All @@ -565,7 +609,7 @@ async def get_products(self, user_data: UserData) -> ProductResponse:
raise RoborockException("product result was an unexpected type")

async def download_code(self, user_data: UserData, product_id: int):
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
request = {"apilevel": 99999, "productids": [product_id], "type": 2}
Expand All @@ -578,7 +622,7 @@ async def download_code(self, user_data: UserData, product_id: int):
return response["data"][0]["url"]

async def download_category_code(self, user_data: UserData):
base_url = await self._get_base_url()
base_url = await self.base_url
header_clientid = self._get_header_client_id()
product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
response = await product_request.request(
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def mock_rest() -> aioresponses:
with aioresponses() as mocked:
# Match the base URL and allow any query params
mocked.post(
re.compile(r"https://euiot\.roborock\.com/api/v1/getUrlByEmail.*"),
re.compile(r"https://.*iot\.roborock\.com/api/v1/getUrlByEmail.*"),
status=200,
payload={
"code": 200,
Expand Down
2 changes: 1 addition & 1 deletion tests/mock_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@
BASE_URL_REQUEST = {
"code": 200,
"msg": "success",
"data": {"url": "https://sample.com"},
"data": {"url": "https://sample.com", "countrycode": 1, "country": "US"},
}

GET_CODE_RESPONSE = {"code": 200, "msg": "success", "data": None}
Expand Down
49 changes: 29 additions & 20 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,28 @@ async def test_get_base_url_no_url():
rc = RoborockApiClient("sample@gmail.com")
with patch("roborock.web_api.PreparedRequest.request") as mock_request:
mock_request.return_value = BASE_URL_REQUEST
await rc._get_base_url()
assert rc.base_url == "https://sample.com"
await rc._get_iot_login_info()
assert await rc.base_url == "https://sample.com"


async def test_request_code():
rc = RoborockApiClient("sample@gmail.com")
with patch("roborock.web_api.RoborockApiClient._get_base_url"), patch(
"roborock.web_api.RoborockApiClient._get_header_client_id"
), patch("roborock.web_api.PreparedRequest.request") as mock_request:
with (
patch("roborock.web_api.RoborockApiClient._get_iot_login_info"),
patch("roborock.web_api.RoborockApiClient._get_header_client_id"),
patch("roborock.web_api.PreparedRequest.request") as mock_request,
):
mock_request.return_value = GET_CODE_RESPONSE
await rc.request_code()


async def test_get_home_data():
rc = RoborockApiClient("sample@gmail.com")
with patch("roborock.web_api.RoborockApiClient._get_base_url"), patch(
"roborock.web_api.RoborockApiClient._get_header_client_id"
), patch("roborock.web_api.PreparedRequest.request") as mock_prepared_request:
with (
patch("roborock.web_api.RoborockApiClient._get_iot_login_info"),
patch("roborock.web_api.RoborockApiClient._get_header_client_id"),
patch("roborock.web_api.PreparedRequest.request") as mock_prepared_request,
):
mock_prepared_request.side_effect = [
{"code": 200, "msg": "success", "data": {"rrHomeId": 1}},
{"code": 200, "success": True, "result": HOME_DATA_RAW},
Expand Down Expand Up @@ -117,10 +121,11 @@ async def test_get_prop():
home_data = HomeData.from_dict(HOME_DATA_RAW)
device_info = DeviceData(device=home_data.devices[0], model=home_data.products[0].model)
rmc = RoborockMqttClientV1(UserData.from_dict(USER_DATA), device_info)
with patch("roborock.version_1_apis.roborock_mqtt_client_v1.RoborockMqttClientV1.get_status") as get_status, patch(
"roborock.version_1_apis.roborock_client_v1.RoborockClientV1.send_command"
), patch("roborock.version_1_apis.roborock_client_v1.AttributeCache.async_value"), patch(
"roborock.version_1_apis.roborock_mqtt_client_v1.RoborockMqttClientV1.get_dust_collection_mode"
with (
patch("roborock.version_1_apis.roborock_mqtt_client_v1.RoborockMqttClientV1.get_status") as get_status,
patch("roborock.version_1_apis.roborock_client_v1.RoborockClientV1.send_command"),
patch("roborock.version_1_apis.roborock_client_v1.AttributeCache.async_value"),
patch("roborock.version_1_apis.roborock_mqtt_client_v1.RoborockMqttClientV1.get_dust_collection_mode"),
):
status = S7MaxVStatus.from_dict(STATUS)
status.dock_type = RoborockDockTypeCode.auto_empty_dock_pure
Expand Down Expand Up @@ -194,8 +199,9 @@ async def test_disconnect_failure(connected_mqtt_client: RoborockMqttClientV1) -
assert connected_mqtt_client.is_connected()

# Make the MQTT client returns with an error when disconnecting
with patch("roborock.cloud_api.mqtt.Client.disconnect", return_value=mqtt.MQTT_ERR_PROTOCOL), pytest.raises(
RoborockException, match="Failed to disconnect"
with (
patch("roborock.cloud_api.mqtt.Client.disconnect", return_value=mqtt.MQTT_ERR_PROTOCOL),
pytest.raises(RoborockException, match="Failed to disconnect"),
):
await connected_mqtt_client.async_disconnect()

Expand Down Expand Up @@ -231,8 +237,9 @@ async def test_subscribe_failure(

response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2))

with patch("roborock.cloud_api.mqtt.Client.subscribe", return_value=(mqtt.MQTT_ERR_NO_CONN, None)), pytest.raises(
RoborockException, match="Failed to subscribe"
with (
patch("roborock.cloud_api.mqtt.Client.subscribe", return_value=(mqtt.MQTT_ERR_NO_CONN, None)),
pytest.raises(RoborockException, match="Failed to subscribe"),
):
await mqtt_client.async_connect()

Expand Down Expand Up @@ -298,8 +305,9 @@ async def test_publish_failure(

msg = mqtt.MQTTMessageInfo(0)
msg.rc = mqtt.MQTT_ERR_PROTOCOL
with patch("roborock.cloud_api.mqtt.Client.publish", return_value=msg), pytest.raises(
RoborockException, match="Failed to publish"
with (
patch("roborock.cloud_api.mqtt.Client.publish", return_value=msg),
pytest.raises(RoborockException, match="Failed to publish"),
):
await connected_mqtt_client.get_room_mapping()

Expand All @@ -308,7 +316,8 @@ async def test_future_timeout(
connected_mqtt_client: RoborockMqttClientV1,
) -> None:
"""Test a timeout raised while waiting for an RPC response."""
with patch("roborock.roborock_future.async_timeout.timeout", side_effect=asyncio.TimeoutError), pytest.raises(
RoborockTimeout, match="Timeout after"
with (
patch("roborock.roborock_future.async_timeout.timeout", side_effect=asyncio.TimeoutError),
pytest.raises(RoborockTimeout, match="Timeout after"),
):
await connected_mqtt_client.get_room_mapping()
Loading