Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pycryptodome
pip install pytest pytest-asyncio pycryptodome

- name: Run tests
run: PYTHONPATH=. pytest
12 changes: 12 additions & 0 deletions sfs2x/transport/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from sfs2x.transport.base import Acceptor, Transport # noqa: I001
from sfs2x.transport.tcp import TCPAcceptor, TCPTransport
from sfs2x.transport.factory import client_from_url, server_from_url

__all__ = [
"Acceptor",
"TCPAcceptor",
"TCPTransport",
"Transport",
"client_from_url",
"server_from_url",
]
68 changes: 68 additions & 0 deletions sfs2x/transport/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Protocol

from sfs2x.core import Buffer
from sfs2x.protocol import Message, decode, encode


class Transport(ABC):
"""Abstract base class for transports."""

_closed: bool

def __init__(self) -> None:
self._closed = True

async def open(self) -> "Transport":
await self._open()
self._closed = False
return self

async def send(self, msg: Message) -> None:
if self._closed:
err_msg = "Connection closed by remote host"
raise ConnectionError(err_msg)
await self._send_raw(encode(msg))

async def recv(self) -> Message:
if self._closed:
msg = "Connection closed by remote host"
raise ConnectionError(msg)
raw = await self._recv_raw()
return decode(Buffer(raw))

async def close(self) -> None:
if not self._closed:
await self._close_impl()
self._closed = True

@abstractmethod
async def _open(self) -> None:
...

@abstractmethod
async def _send_raw(self, raw: bytes) -> None:
...

@abstractmethod
async def _recv_raw(self) -> bytes:
...

@abstractmethod
async def _close_impl(self) -> None:
...

@abstractmethod
def host(self) -> str:
...

@abstractmethod
def port(self) -> int:
...


class Acceptor(Protocol):
"""Async listener for server."""

async def __aiter__(self) -> AsyncIterator[Transport]: ... # noqa: D105
37 changes: 37 additions & 0 deletions sfs2x/transport/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from urllib.parse import urlparse

from sfs2x.transport import Acceptor, TCPAcceptor, TCPTransport, Transport


def client_from_url(url: str) -> Transport:
"""
Create transport from url.

* ``tcp://host:port``
* ``ws://host:port/path``
* ``http://host:port/path
"""
u = urlparse(url)
scheme = (u.scheme or "tcp").lower()

if scheme == "tcp":
port = u.port or 9933
return TCPTransport(u.hostname or "localhost", port)
raise NotImplementedError


def server_from_url(url: str) -> TCPAcceptor | Acceptor:
"""
Create acceptor from url.

* ``tcp://host:port``
* ``ws://host:port/path``
* ``http://host:port/path
"""
u = urlparse(url)
scheme = u.scheme.lower()

if scheme == "tcp":
port = u.port or 9933
return TCPAcceptor(u.hostname or "localhost", port)
raise NotImplementedError
113 changes: 113 additions & 0 deletions sfs2x/transport/tcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import asyncio
import logging
from asyncio import AbstractServer, IncompleteReadError, StreamReader, StreamWriter, get_running_loop, start_server
from collections.abc import AsyncIterator

from sfs2x.protocol import Flag
from sfs2x.transport import Acceptor, Transport

logger = logging.getLogger("SFS2X/TCPTransport")


class TCPTransport(Transport):
"""SmartFox Transport realisation with Async Streams."""

def __init__(self, host: str, port: int) -> None:
super().__init__()
self._host = host
self._port = port
self._reader: StreamReader | None = None
self._writer: StreamWriter | None = None

@property
def host(self) -> str:
return self._host

@property
def port(self) -> int:
return self._port

async def _open(self) -> None:
self._reader, self._writer = await asyncio.open_connection(self._host, self._port)
logger.info("Opened connection to %s:%s", self._host, self._port)

async def _send_raw(self, raw: bytes) -> None:
if not self._writer:
msg = "Connection closed by remote host"
raise ConnectionError(msg)

self._writer.write(raw)
await self._writer.drain()
logger.info("Sent %s bytes", {len(raw)})

async def _recv_raw(self) -> bytes:
if not self._reader:
msg = "Connection closed by remote host"
raise ConnectionError(msg)

try:
_flags = await self._reader.readexactly(1)
flags = Flag(_flags[0])
if not flags & Flag.BINARY:
msg = "Invalid packet type"
raise RuntimeWarning(msg)

len_bytes = await self._reader.readexactly(2)
if flags & Flag.BIG_SIZE:
len_bytes += await self._reader.readexactly(2)

length = int.from_bytes(len_bytes, byteorder="big", signed=False)
body = await self._reader.readexactly(length)
except IncompleteReadError as e:
msg = "Connection closed by remote host"
raise ConnectionError(msg) from e


logger.info("Received %s bytes from %s:%s", length, self._host, self._port)

return _flags + len_bytes + body

async def _close_impl(self) -> None:
if self._writer:
self._writer.close()
await self._writer.wait_closed()
logger.info("Closed connection to %s:%s", self._host, self._port)


class TCPAcceptor(Acceptor):
"""Server-Side implementation of the TCP Acceptor."""

def __init__(self, host: str, port: int) -> None:
super().__init__()
self._host = host
self._port = port
self._server: AbstractServer | None = None

async def __aiter__(self) -> AsyncIterator[Transport]: # type: ignore # noqa: PGH003
"""Iterate all new connections."""
loop = get_running_loop()
self._server = await start_server(self._on_conn, self._host, self._port)
logger.info("Started server on %s:%s", self._host, self._port)

self._queue: asyncio.Queue[TCPTransport] = asyncio.Queue()

async def producer() -> None:
async with self._server: # type: ignore # noqa: PGH003
await self._server.serve_forever() # type: ignore # noqa: PGH003

loop.create_task(producer()) # noqa: RUF006

try:
while True:
yield await self._queue.get()
finally:
self._server.close()

async def _on_conn(self, reader: StreamReader, writer: StreamWriter) -> None:
host, port = writer.get_extra_info("peername")
logger.info("Connection from %s:%s", host, port)
transport = TCPTransport(host, port)
transport._reader = reader # noqa: SLF001
transport._writer = writer # noqa: SLF001
transport._closed = False # noqa: SLF001
await self._queue.put(transport)
73 changes: 73 additions & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import asyncio
import pytest
import pytest_asyncio

from sfs2x.core import Float, UtfString, Int, Double
from sfs2x.transport import client_from_url, server_from_url, TCPTransport
from sfs2x.protocol import Message, ControllerID, SysAction
from sfs2x.core.types.containers import SFSObject

@pytest_asyncio.fixture
async def echo_server(event_loop):
server_task = event_loop.create_task(run_echo_server())
await asyncio.sleep(0.2)

yield

server_task.cancel()
with pytest.raises(asyncio.CancelledError):
await server_task

async def run_echo_server():
async for conn in server_from_url("tcp://0.0.0.0:9000"):
asyncio.create_task(some_handler(conn))

async def some_handler(conn: TCPTransport):
try:
while True:
msg = await conn.recv()
obj = msg.payload.value.get('input')
obj.value *= 2
msg.payload['resp'] = obj
await conn.send(msg)
except ConnectionError:
await conn.close()

@pytest.mark.asyncio
async def test_tcp_echo_roundtrip(echo_server):
conn = await client_from_url("tcp://localhost:9000").open()
for value in [UtfString('Friday, '), Int(8), Double(123.12)]:
test_msg = Message(ControllerID.SYSTEM, SysAction.PING_PONG, SFSObject({'input': value}))
await conn.send(test_msg)

answer = await conn.recv()
assert answer.controller == test_msg.controller
assert answer.action == test_msg.action
assert answer.payload.get('resp') == value.value * 2
await conn.close()

@pytest.mark.asyncio
async def test_msm_server():
conn = await client_from_url("tcp://107.20.67.227").open()

session_info = SFSObject()
session_info.put_utf_string("api", "1.0.3")
session_info.put_utf_string("cl", "UnityPlayer::")
session_info.put_bool("bin", True)

await conn.send(Message(ControllerID.SYSTEM, SysAction.HANDSHAKE, session_info))

handshake = await conn.recv()
assert handshake.controller == ControllerID.SYSTEM
assert handshake.action == SysAction.HANDSHAKE

auth_info = SFSObject()
auth_info.put_utf_string("zn", "MySingingPenis")
auth_info.put_utf_string("un", "")
auth_info.put_utf_string("pw", "")
auth_info.put_sfs_object("p", SFSObject())

await conn.send(Message(ControllerID.SYSTEM, SysAction.LOGIN, auth_info))

resp = await conn.recv()
assert resp.payload['ec'] == 1