diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 270221b..cddac23 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -26,4 +26,4 @@ jobs: pip install pytest - name: Run tests - run: pytest \ No newline at end of file + run: PYTHONPATH=. pytest \ No newline at end of file diff --git a/sfs2x/core/field.py b/sfs2x/core/field.py index 3524998..41217bd 100644 --- a/sfs2x/core/field.py +++ b/sfs2x/core/field.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import ClassVar, Generic, Never, TypeVar +from typing import ClassVar, Generic, TypeVar from .buffer import Buffer from .registry import Packable @@ -17,5 +17,5 @@ def to_bytes(self) -> bytearray: raise NotImplementedError @classmethod - def from_buffer(cls, buf: Buffer, /) -> Never: + def from_buffer(cls, buf: Buffer, /) -> "Field": raise NotImplementedError diff --git a/sfs2x/core/types/containers.pyi b/sfs2x/core/types/containers.pyi index dc3332c..2cbacb3 100644 --- a/sfs2x/core/types/containers.pyi +++ b/sfs2x/core/types/containers.pyi @@ -38,7 +38,7 @@ class SFSObject(Field[dict[str, Field]]): class SFSArray(Field[list[Field]]): - def __init__(self, value: list[Field] | None = None) -> SFSArray: ... + def __init__(self, value: list[Field] | None = None) -> None: ... def __getitem__(self, index: int) -> Any: ... # noqa: ANN401 def __iter__(self) -> Iterator[Any]: ... diff --git a/sfs2x/protocol/__init__.py b/sfs2x/protocol/__init__.py index 2d55513..8014865 100644 --- a/sfs2x/protocol/__init__.py +++ b/sfs2x/protocol/__init__.py @@ -1,9 +1,15 @@ +try: + from sfs2x.protocol.security import AESCipher +except ImportError: + AESCipher = None + from sfs2x.protocol.constants import ControllerID, Flag, SysAction # noqa: I001 from sfs2x.protocol.exceptions import ProtocolError, UnsupportedFlagError from sfs2x.protocol.message import Message from sfs2x.protocol.codec import decode, encode __all__ = [ + "AESCipher", "ControllerID", "Flag", "Message", @@ -11,5 +17,5 @@ "SysAction", "UnsupportedFlagError", "decode", - "encode", + "encode" ] diff --git a/sfs2x/protocol/codec.py b/sfs2x/protocol/codec.py index 0958d15..cedee75 100644 --- a/sfs2x/protocol/codec.py +++ b/sfs2x/protocol/codec.py @@ -1,12 +1,14 @@ +import zlib from typing import overload from sfs2x.core import Buffer from sfs2x.core import decode as core_decode from sfs2x.core.types.containers import SFSObject -from sfs2x.protocol import Flag, Message, ProtocolError, UnsupportedFlagError +from sfs2x.protocol import AESCipher, Flag, Message, ProtocolError, UnsupportedFlagError _SHORT_MAX = 0xFFFF + def _assemble_header(payload_len: int) -> bytearray: """Assemble first byte and packet length.""" flags = Flag.BINARY @@ -27,8 +29,8 @@ def _parse_header(buf: Buffer) -> tuple[int, Flag]: """Parse first bytes and return packet length and flags.""" flags = Flag(buf.read(1)[0]) - if flags & Flag.ENCRYPTED or flags & Flag.COMPRESSED: - msg = "Encryption / Compression flags don't supported yet." + if flags & Flag.BLUEBOX: + msg = "BLUEBOX don't supported yet." raise UnsupportedFlagError(msg) length = int.from_bytes(buf.read(4 if flags & Flag.BIG_SIZE else 2), byteorder="big") @@ -40,25 +42,56 @@ def _parse_header(buf: Buffer) -> tuple[int, Flag]: return length, flags -def encode(msg: Message) -> bytearray: +def encode(msg: Message, compress_threshold: int | None = 1024, encryption_key: bytes | None = None) -> bytearray: """Encode message to bytearray, TCP-Ready.""" - payload = msg.to_sfs_object().to_bytes() - return _assemble_header(len(payload)) + payload + flags = Flag.BINARY + payload: bytes = msg.to_sfs_object().to_bytes() + + if compress_threshold is not None and len(payload) > compress_threshold: + payload = zlib.compress(payload) + flags |= Flag.COMPRESSED + + if encryption_key is not None: + if AESCipher is None: + msg = "Library pycryptodome is not installed. Install it before using encryption (pip install pycryptodome)." + raise ImportError(msg) + cipher = AESCipher(encryption_key) + payload = cipher.encrypt(payload) + flags |= Flag.ENCRYPTED + + header = _assemble_header(len(payload)) + header[0] |= flags + return header + payload @overload -def decode(buf: Buffer) -> Message: ... -@overload -def decode(raw: (bytes, bytearray, memoryview)) -> Message: ... +def decode(buf: Buffer, *, encryption_key: bytes | None = None) -> Message: ... +@overload +def decode(raw: bytes | bytearray | memoryview, *, encryption_key: bytes | None = None) -> Message: ... + # noinspection PyTypeChecker -def decode(data): +def decode(data, *, encryption_key: bytes | None = None) -> Message: """Decode buffer to message.""" buf = data if isinstance(data, Buffer) else Buffer(data) length, flags = _parse_header(buf) payload_bytes = buf.read(length) + + if flags & Flag.ENCRYPTED: + if encryption_key is None: + msg = "Can't decrypt message without encryption key." + raise ProtocolError(msg) + if AESCipher is None: + msg = "Library pycryptodome is not installed. Install it before using encryption (pip install pycryptodome)." + raise ImportError(msg) + cipher = AESCipher(encryption_key) + payload_bytes = cipher.decrypt(payload_bytes) + + if flags & Flag.COMPRESSED: + payload_bytes = zlib.decompress(payload_bytes) + root: SFSObject = core_decode(Buffer(payload_bytes)) controller = root.get("c", 0) diff --git a/sfs2x/protocol/security.py b/sfs2x/protocol/security.py new file mode 100644 index 0000000..82d6688 --- /dev/null +++ b/sfs2x/protocol/security.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass +from os import urandom +from typing import Protocol, runtime_checkable + +from Crypto.Cipher import AES +from Crypto.Util.Padding import pad, unpad + +_KEY_LENGTH: int = 16 + + +@runtime_checkable +class Cipher(Protocol): + """Minimal symetric cipher protocol.""" + + def encrypt(self, data: bytes) -> bytes: ... + + def decrypt(self, data: bytes) -> bytes: ... + + +@dataclass(slots=True) +class AESCipher(Cipher): + """AES-128-CBC with PKCS#7 and padding (16-bit).""" + + key: bytes # 16 signs only + + def __post_init__(self) -> None: + """Check key length.""" + if len(self.key) != _KEY_LENGTH: + msg = "key must be 16 bytes long" + raise ValueError(msg) + + def encrypt(self, data: bytes) -> bytes: + """Encrypt data, using AES-128-CBC.""" + iv = urandom(16) + cipher = AES.new(self.key, AES.MODE_CBC, iv) + return iv + cipher.encrypt(pad(data, 16)) + + def decrypt(self, data: bytes) -> bytes: + """Decrypt data, using AES-128-CBC.""" + iv = data[:16] + cipher = AES.new(self.key, AES.MODE_CBC, iv) + return unpad(cipher.decrypt(data[16:]), 16) diff --git a/tests/test_payload.py b/tests/test_payload.py index f9612c5..30026f6 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -54,7 +54,7 @@ def test_long_packet(): SysAction.HANDSHAKE, make_payload(blob=big_string), ) - raw = encode(msg) + raw = encode(msg, compress_threshold=None) first_flag = Flag(raw[0]) assert first_flag & Flag.BINARY @@ -63,6 +63,23 @@ def test_long_packet(): decoded = decode(Buffer(raw)) assert decoded.payload.get("blob") == big_string +def test_encrypted_and_compressed_long_packet(): + big_string = "x" * 70000 + msg = Message( + ControllerID.SYSTEM, + SysAction.HANDSHAKE, + make_payload(blob=big_string), + ) + raw = encode(msg, compress_threshold=0, encryption_key=b'1234567890123456') + + first_flag = Flag(raw[0]) + assert first_flag & Flag.BINARY + assert first_flag & Flag.ENCRYPTED + assert first_flag & Flag.COMPRESSED + + decoded = decode(Buffer(raw), encryption_key=b'1234567890123456') + assert decoded.payload.get("blob") == big_string + def test_unpack_binary_packet(): binary_message = b'\x80\x00T\x12\x00\x03\x00\x01c\x02\x01\x00\x01a\x03\x00\x0c\x00\x01p\x12\x00\x03\x00\x01c\x08\x00\x0ctest_command\x00\x01r\x04\xff\xff\xff\xff\x00\x01p\x12\x00\x02\x00\x03num\x04\xff\xff\xff\xff\x00\x07strings\x10\x00\x02\x00\x02hi\x00\x04mega'