Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
pip install --no-deps -e .
- name: Run tests
run: |
pytest --cov=posttroll posttroll/tests --cov-report=xml
pytest -v --cov=posttroll posttroll/tests --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
Expand Down
3 changes: 2 additions & 1 deletion posttroll/address_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from posttroll import config
from posttroll.bbmcast import MulticastReceiver, get_configured_broadcast_port
from posttroll.message import Message
from posttroll.message import Message, version_needed
from posttroll.publisher import Publish

__all__ = ("AddressReceiver", "getaddress")
Expand Down Expand Up @@ -178,6 +178,7 @@ def process_address_message(self, data, pub):
if addr not in self._addresses:
logger.info("nameserver: publish add '%s'",
str(msg))
msg.version = version_needed(msg.data, False)
pub.send(msg.encode())
self._add(addr, metadata)

Expand Down
13 changes: 3 additions & 10 deletions posttroll/backends/zmq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,15 @@ def get_context() -> zmq.Context:
return context[pid]


def destroy_context(linger=None):
def destroy_context(linger: int|None = None):
"""Destroy the context."""
pid = os.getpid()
context.pop(pid).destroy(linger)


def _set_tcp_keepalive(socket):
"""Set the tcp keepalive parameters on *socket*."""
keepalive_options = get_tcp_keepalive_options()
for param, value in keepalive_options.items():
socket.setsockopt(param, value)


def get_tcp_keepalive_options():
def get_tcp_keepalive_options() -> dict[int, int]:
"""Get the tcp_keepalive options from config."""
keepalive_options = dict()
keepalive_options: dict[int, int] = dict()
for opt in ("tcp_keepalive",
"tcp_keepalive_cnt",
"tcp_keepalive_idle",
Expand Down
9 changes: 4 additions & 5 deletions posttroll/backends/zmq/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import zmq

from posttroll.backends.zmq import get_tcp_keepalive_options
from posttroll.backends.zmq.socket import close_socket, set_up_server_socket

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -42,12 +41,12 @@ def start(self):
return self

def _create_socket(self):
options = get_tcp_keepalive_options()
self.publish_socket, port, self._authenticator = set_up_server_socket(zmq.PUB, self.destination, options,
(self.min_port, self.max_port))
self.publish_socket, port, self._authenticator = set_up_server_socket(zmq.PUB, self.destination,
port_interval=(self.min_port,
self.max_port))
self.port_number = port

def send(self, msg):
def send(self, msg: str):
"""Send the given message."""
with self._pub_lock:
self.publish_socket.send_string(msg)
Expand Down
21 changes: 9 additions & 12 deletions posttroll/backends/zmq/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from zmq.auth.thread import ThreadAuthenticator

from posttroll import config
from posttroll.backends.zmq import get_context
from posttroll.backends.zmq import get_context, get_tcp_keepalive_options
from posttroll.message import Message, MessageError

authenticator_lock = Lock()
Expand All @@ -26,9 +26,8 @@ def close_socket(sock: zmq.Socket[int]):


def set_up_client_socket(socket_type: int, address: str,
options: dict[int|str, str]|None = None, backend: str|None = None) -> zmq.Socket[int]:
options: dict[int, int]|None = None, backend: str|None = None) -> zmq.Socket[int]:
"""Set up a client (connecting) zmq socket."""
options = options or dict()
backend = backend or config["backend"]
# Skip secure setup for inproc (internal thread communication)
if address.startswith("inproc://"):
Expand All @@ -39,7 +38,7 @@ def set_up_client_socket(socket_type: int, address: str,
sock = create_secure_client_socket(socket_type)
else:
raise NotImplementedError()
add_options(sock, options)
_add_options(sock, options)
sock.connect(address)
return sock

Expand All @@ -49,11 +48,11 @@ def create_unsecure_client_socket(socket_type: int) -> zmq.Socket[int]:
return get_context().socket(socket_type)


def add_options(sock, options=None):
def _add_options(sock: zmq.Socket[int], options: dict[int, int]|None = None):
"""Add options to a socket."""
if not options:
return
for param, val in options.items():
combined_options = get_tcp_keepalive_options()
combined_options.update(options or {})
for param, val in combined_options.items():
sock.setsockopt(param, val)


Expand Down Expand Up @@ -82,12 +81,10 @@ def create_secure_client_socket(socket_type: int) -> zmq.Socket[int]:
return subscriber


def set_up_server_socket(socket_type: int, destination: str, options: dict[int, str]|None = None,
def set_up_server_socket(socket_type: int, destination: str, options: dict[int, int]|None = None,
port_interval: tuple[int|None, int|None] = (None, None),
backend: str|None = None) -> tuple[zmq.Socket[int], int, ThreadAuthenticator|None]:
"""Set up a server (binding) socket."""
if options is None:
options = {}
_backend:str = backend or config["backend"]
# Skip ZAP for inproc (internal thread communication)
enable_zap = not destination.startswith("inproc://")
Expand All @@ -98,7 +95,7 @@ def set_up_server_socket(socket_type: int, destination: str, options: dict[int,
else:
raise NotImplementedError()

add_options(sock, options)
_add_options(sock, options)

port = bind(sock, destination, port_interval)
return sock, port, authenticator
Expand Down
12 changes: 4 additions & 8 deletions posttroll/backends/zmq/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from zmq import PULL, SUB, SUBSCRIBE, ContextTerminated, ZMQError

from posttroll import config
from posttroll.backends.zmq import get_tcp_keepalive_options
from posttroll.backends.zmq.socket import SocketReceiver, close_socket, set_up_client_socket
from posttroll.message import CURRENT_MESSAGE_VERSION

Expand Down Expand Up @@ -116,8 +115,7 @@ def add_hook_pull(self, address, callback):
specified subscription. Good for pushed 'inproc' messages from another thread.
"""
LOGGER.info("Subscriber adding PULL hook %s", str(address))
options = get_tcp_keepalive_options()
socket = self._create_socket(PULL, address, options)
socket = self._create_socket(PULL, address)
if self._sock_receiver:
self._sock_receiver.register(socket)
self._add_hook(socket, callback)
Expand Down Expand Up @@ -200,23 +198,21 @@ def __del__(self):
pass

def _add_sub_socket(self, address: dict[str, str], topics):

options = get_tcp_keepalive_options()
try:
backend = address.get("backend", "unsecure_zmq")
uri = address["URI"]
except AttributeError:
backend = config["backend"]
uri = address

subscriber = self._create_socket(SUB, uri, options, backend)
subscriber = self._create_socket(SUB, uri, backend=backend)
add_subscriptions(subscriber, topics)

if self._sock_receiver:
self._sock_receiver.register(subscriber)
return subscriber

def _create_socket(self, socket_type: int, address: str, options, backend: str|None = None):
def _create_socket(self, socket_type: int, address: str, options: dict[int,int]|None=None, backend: str|None = None):
return set_up_client_socket(socket_type, address, options, backend)


Expand All @@ -227,7 +223,7 @@ def ensure_address_is_dict(addr: dict[str, str]|str) -> dict[str, str]:
elif isinstance(addr, str):
res = dict(URI=addr)
else:
NotImplementedError(f"Don't know how to handle {type(addr)} addresses")
raise NotImplementedError(f"Don't know how to handle {type(addr)} addresses")
res.setdefault("backend", config["backend"])
return res

Expand Down
69 changes: 58 additions & 11 deletions posttroll/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import re
from functools import partial
from typing import Any, Callable
from warnings import warn

from posttroll import config

Expand Down Expand Up @@ -98,9 +99,11 @@ class Message:

def __init__(self, subject:str="", atype:str="", data:str|dict[str, Any]="", binary:bool=False,
rawstr:str|bytes|None=None, version:str|None=None):
"""Initialize a Message from a subject, type and data, or from a raw string."""
"""Initialize a Message from a subject, type and data."""
if rawstr:
self.__dict__ = _decode(rawstr)
warn("The `rawstr` argument of `Message` instantiation is being depracated in favour of the class method"
"`Message.from_string`.", PendingDeprecationWarning)
self._decode_string(rawstr)
else:
self.subject:str = subject
self.type:str = atype
Expand Down Expand Up @@ -133,10 +136,10 @@ def head(self):
self._validate()
return _encode(self, head=True)

@staticmethod
def decode(rawstr:str|bytes):
@classmethod
def decode(cls, rawstr:str|bytes):
"""Decode a raw string into a Message."""
return Message(rawstr=rawstr)
return cls.from_string(rawstr)

def encode(self) -> str:
"""Encode a Message to a raw string."""
Expand Down Expand Up @@ -167,14 +170,58 @@ def _validate(self):
raise MessageError("Invalid data: data is not JSON serializable: %s"
% str(self.data))

def __getstate__(self):
def __getstate__(self) -> str:
"""Get the Message state for pickle()."""
return self.encode()

def __setstate__(self, state):
def __setstate__(self, state: str):
"""Set the Message when unpickling."""
self.__dict__.clear()
self.__dict__ = _decode(state)
self._decode_string(state)

def _decode_string(self, rawstr:str|bytes):
"""Convert a raw string to a Message."""
rawstr = _check_for_magic_word(rawstr)

raw = _check_for_element_count(rawstr)
version = _check_for_version(raw)

# Start to build message
self.subject = raw[0].strip()
self.type = raw[1].strip()
self.sender = raw[2].strip()
self.time = dt.datetime.fromisoformat(raw[3].strip())
self.version = version

# Data part
try:
mimetype = raw[5].lower()
except IndexError:
mimetype = None
self.data = ""
self.binary = False
return
else:
data = raw[6]

if mimetype == "application/json":
try:
self.data = json.loads(raw[6], object_hook=datetime_decoder)
self.binary = False
except ValueError:
raise MessageError("JSON decode failed on '%s ...'" % raw[6][:36])
elif mimetype == "text/ascii":
self.data = str(data)
self.binary = False
elif mimetype == "binary/octet-stream":
self.data = data
self.binary = True
else:
raise MessageError("Unknown mime-type '%s'" % mimetype)

@classmethod
def from_string(cls, rawstr:str|bytes):
"""Create a message from string."""
return cls(rawstr=rawstr)


# -----------------------------------------------------------------------------
Expand All @@ -184,7 +231,7 @@ def __setstate__(self, state):
# -----------------------------------------------------------------------------


def _is_valid_version(version):
def _is_valid_version(version: str) -> bool:
"""Check version."""
return version <= CURRENT_MESSAGE_VERSION

Expand Down Expand Up @@ -322,7 +369,7 @@ def _encode(msg:Message, head:bool=False, binary:bool=False) -> str:
version = render_version(msg.version, msg.data, binary)

rawstr = str(_MAGICK) + u"{0:s} {1:s} {2:s} {3:s} {4:s}".format(
msg.subject, msg.type, msg.sender, msg.time.isoformat(), version)
msg.subject, msg.type, msg.sender, create_datetime_encoder_for_version(version)(msg.time), version)
if not head and msg.data:
mimetype, data = _encode_data(msg.data, binary, version)
return " ".join((rawstr, mimetype, data))
Expand Down
6 changes: 3 additions & 3 deletions posttroll/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def start(self):
self._publisher.start()
return self

def send(self, msg):
def send(self, msg: str):
"""Send the given message."""
return self._publisher.send(msg)

Expand Down Expand Up @@ -181,7 +181,7 @@ def start(self):
self._broadcaster.start()
return self

def send(self, msg):
def send(self, msg: str):
"""Send a *msg*."""
return self._publisher.send(msg)

Expand Down Expand Up @@ -241,7 +241,7 @@ class Publish:

"""

def __init__(self, name, port=0, aliases=None, broadcast_interval=2, nameservers=None,
def __init__(self, name: str, port=0, aliases=None, broadcast_interval=2, nameservers=None,
min_port=None, max_port=None):
"""Initialize the class."""
settings = {"name": name, "port": port, "min_port": min_port, "max_port": max_port,
Expand Down
Loading
Loading