diff --git a/pytuya/__init__.py b/pytuya/__init__.py index b031bd2..42d5653 100644 --- a/pytuya/__init__.py +++ b/pytuya/__init__.py @@ -51,6 +51,9 @@ PROTOCOL_VERSION_BYTES_31 = b'3.1' PROTOCOL_VERSION_BYTES_33 = b'3.3' +PROTOCOL_VERSION_3_1 = 3.1 +PROTOCOL_VERSION_3_3 = 3.3 + IS_PY2 = sys.version_info[0] == 2 class AESCipher(object): @@ -181,6 +184,8 @@ def _send_receive(self, payload): return data def set_version(self, version): + if (version != PROTOCOL_VERSION_3_1) and (version != PROTOCOL_VERSION_3_3): + raise ValueError("Unsupported verison") self.version = version def generate_payload(self, command, data=None): @@ -214,7 +219,7 @@ def generate_payload(self, command, data=None): json_payload = json_payload.encode('utf-8') log.debug('json_payload=%r', json_payload) - if self.version == 3.3: + if self.version == PROTOCOL_VERSION_3_3: self.cipher = AESCipher(self.local_key) # expect to connect and then disconnect to set new json_payload = self.cipher.encrypt(json_payload, False) self.cipher = None @@ -302,7 +307,7 @@ def status(self): if not isinstance(result, str): result = result.decode() result = json.loads(result) - elif self.version == 3.3: + elif self.version == PROTOCOL_VERSION_3_3: cipher = AESCipher(self.local_key) result = cipher.decrypt(result, False) log.debug('decrypted result=%r', result) diff --git a/tests.py b/tests.py index 481d06d..0b44fd1 100755 --- a/tests.py +++ b/tests.py @@ -51,7 +51,9 @@ def check_data_frame(data, expected_prefix, encrypted=True): frame_ok = True if prefix != pytuya.hex2bin(expected_prefix): frame_ok = False - elif suffix != pytuya.hex2bin("000000000000aa55"): + elif suffix[-4:] != pytuya.hex2bin("0000aa55"): + # We only check for the trailing byte signature + # We could extend the test to also check the CRC if we wanted. frame_ok = False elif encrypted: if payload_len != len(version) + len(checksum) + len(encrypted_json) + len(suffix):