From f0bb7208aa420fa43c65b3f2dc530031f2439d19 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 22 Jan 2026 17:38:08 +0100 Subject: [PATCH 1/4] Add more security and allow_lists, and type annotations --- .github/workflows/ci.yaml | 2 +- doc/source/index.rst | 3 + posttroll/address_receiver.py | 4 +- posttroll/backends/zmq/__init__.py | 2 +- posttroll/backends/zmq/address_receiver.py | 16 +- posttroll/backends/zmq/message_broadcaster.py | 8 +- posttroll/backends/zmq/ns.py | 78 ++++++++-- posttroll/backends/zmq/publisher.py | 2 +- posttroll/backends/zmq/socket.py | 145 +++++++++++++----- posttroll/backends/zmq/subscriber.py | 72 ++++++--- posttroll/message.py | 14 +- posttroll/message_broadcaster.py | 10 +- posttroll/ns.py | 19 ++- posttroll/publisher.py | 2 +- posttroll/subscriber.py | 36 ++--- posttroll/tests/test_nameserver.py | 49 ++++-- posttroll/tests/test_pubsub.py | 1 + posttroll/tests/test_secure_zmq_backend.py | 38 ++++- 18 files changed, 354 insertions(+), 147 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e52e844..ee7bb85 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -10,7 +10,7 @@ jobs: strategy: fail-fast: true matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.12", "3.13", "3.14"] experimental: [false] steps: - name: Checkout source diff --git a/doc/source/index.rst b/doc/source/index.rst index 113678c..59d73aa 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -178,6 +178,9 @@ These settings can also be set using the posttroll config object, for example:: The posttroll configuration uses donfig, for more information, check https://donfig.readthedocs.io/en/latest/. +Nameserver also now use both a secure and unsecure port for communicating. The port for secure connection can be set with +the "secure_zmq_nameserver_port" config item, while the unsecure one use the known "nameserver_port" config item. + Generating the public and secret key pairs ****************************************** diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index 5fda077..88615f8 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -56,8 +56,8 @@ def get_local_ips(): class AddressReceiver: """General thread to receive broadcast addresses.""" - def __init__(self, max_age=ten_minutes, port=None, - do_heartbeat=True, multicast_enabled=True, restrict_to_localhost=False): + def __init__(self, max_age: dt.timedelta = ten_minutes, port: int|None =None, + do_heartbeat: bool = True, multicast_enabled: bool = True, restrict_to_localhost: bool = False): """Set up the address receiver.""" self._max_age = max_age self._port = port or get_configured_address_port() diff --git a/posttroll/backends/zmq/__init__.py b/posttroll/backends/zmq/__init__.py index c943737..1e6876f 100644 --- a/posttroll/backends/zmq/__init__.py +++ b/posttroll/backends/zmq/__init__.py @@ -13,7 +13,7 @@ context = {} -def get_context(): +def get_context() -> zmq.Context: """Provide the context to use. This function takes care of creating new contexts in case of forks. diff --git a/posttroll/backends/zmq/address_receiver.py b/posttroll/backends/zmq/address_receiver.py index ef58dfa..0c14acb 100644 --- a/posttroll/backends/zmq/address_receiver.py +++ b/posttroll/backends/zmq/address_receiver.py @@ -1,6 +1,6 @@ """ZMQ implementation of the the simple receiver.""" -from zmq import REP +import zmq from posttroll.address_receiver import get_configured_address_port from posttroll.backends.zmq.socket import close_socket, set_up_server_socket @@ -13,7 +13,8 @@ def __init__(self, port=None, timeout=2): """Set up the receiver.""" self._port = port or get_configured_address_port() address = "tcp://*:" + str(port) - self._socket, _, self._authenticator = set_up_server_socket(REP, address) + self._socket, _, self._authenticator = set_up_server_socket(zmq.REP, address) + self._socket.setsockopt(zmq.RCVTIMEO, timeout * 1000) # timeout in milliseconds self._running = True self.timeout = timeout @@ -21,12 +22,11 @@ def __call__(self): """Receive a message.""" while self._running: try: - message = self._socket.recv_string(self.timeout) - except TimeoutError: - continue - else: - self._socket.send_string("ok") - return message, None + message = self._socket.recv_string() + except zmq.Again: + raise TimeoutError("Receive timed out") + self._socket.send_string("ok") + return message, None def close(self): """Close the receiver.""" diff --git a/posttroll/backends/zmq/message_broadcaster.py b/posttroll/backends/zmq/message_broadcaster.py index 238d9eb..b078170 100644 --- a/posttroll/backends/zmq/message_broadcaster.py +++ b/posttroll/backends/zmq/message_broadcaster.py @@ -13,18 +13,18 @@ class ZMQDesignatedReceiversSender: """Sends message to multiple *receivers* on *port*.""" - def __init__(self, default_port, receivers): + def __init__(self, default_port: int, receivers: list[str]): """Set up the sender.""" self.default_port = default_port - self.receivers = receivers + self.receivers: list[str] = receivers self._shutdown_event = threading.Event() - def __call__(self, data): + def __call__(self, data: str): """Send data.""" for receiver in self.receivers: self._send_to_address(receiver, data) - def _send_to_address(self, address, data, timeout=10): + def _send_to_address(self, address: str, data: str, timeout: int = 10): """Send data to *address* and *port* without verification of response.""" # Socket to talk to server if address.find(":") == -1: diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py index 17faf48..f12acdd 100644 --- a/posttroll/backends/zmq/ns.py +++ b/posttroll/backends/zmq/ns.py @@ -7,9 +7,21 @@ from zmq import LINGER, REP, REQ -from posttroll.backends.zmq.socket import SocketReceiver, close_socket, set_up_client_socket, set_up_server_socket +from posttroll import config +from posttroll.address_receiver import AddressReceiver +from posttroll.backends.zmq.socket import ( + ConfigurationError, + SocketReceiver, + close_socket, + set_up_client_socket, + set_up_server_socket, +) from posttroll.message import Message -from posttroll.ns import get_active_address, get_configured_nameserver_port +from posttroll.ns import ( + get_active_address, + get_configured_secure_zmq_nameserver_port, + get_configured_unsecure_zmq_nameserver_port, +) logger = logging.getLogger("__name__") @@ -21,18 +33,26 @@ def zmq_get_pub_address(name: str, timeout: float | int = 10, nameserver: str = For a given publisher *name* from the nameserver on *nameserver* (localhost by default). """ - nameserver_address = create_nameserver_address(nameserver) + backend = config["backend"] + if backend == "unsecure_zmq": + nameserver_address = create_unsecure_zmq_nameserver_address(nameserver) + elif backend == "secure_zmq": + nameserver_address = create_secure_zmq_nameserver_address(nameserver) + else: + raise NotImplementedError() return _fetch_address_using_socket(nameserver_address, name, timeout) -def create_nameserver_address(nameserver:str): +def create_unsecure_zmq_nameserver_address(nameserver:str): """Create the nameserver address. If `nameserver` is already preformatted and complete, the address is returned without change. """ - url_parts = urlsplit(nameserver) - port = get_configured_nameserver_port() + port = get_configured_unsecure_zmq_nameserver_port() + return _create_nameserver_address(nameserver, port) +def _create_nameserver_address(nameserver:str, port:int): + url_parts = urlsplit(nameserver) if not url_parts.scheme: nameserver_address = "tcp://" + nameserver + ":" + str(port) elif url_parts.scheme == "tcp" and url_parts.port is None: @@ -42,6 +62,15 @@ def create_nameserver_address(nameserver:str): return nameserver_address +def create_secure_zmq_nameserver_address(nameserver:str): + """Create the nameserver address. + + If `nameserver` is already preformatted and complete, the address is returned without change. + """ + port = get_configured_secure_zmq_nameserver_port() + return _create_nameserver_address(nameserver, port) + + def _fetch_address_using_socket(nameserver_address, name, timeout): try: request = Message("/oper/ns", "request", {"service": name}) @@ -83,12 +112,13 @@ class ZMQNameServer: def __init__(self): """Set up the nameserver.""" self.running: bool = True - self.listener: SocketReceiver | None = None + self.unsecure_listener: SocketReceiver | None = None + self.secure_listener: SocketReceiver | None = None self._authenticator = None - def run(self, address_receiver, address:str|None=None): + def run(self, address_receiver: AddressReceiver, address:str|None=None): """Run the listener and answer to requests.""" - port = get_configured_nameserver_port() + unsecure_port = get_configured_unsecure_zmq_nameserver_port() try: # stop was called before we could start running, exit @@ -96,31 +126,45 @@ def run(self, address_receiver, address:str|None=None): return if address is None: address = "*" - address = create_nameserver_address(address) - self.listener, _, self._authenticator = set_up_server_socket(REP, address) - logger.debug(f"Nameserver listening on port {port}") + unsecure_address = create_unsecure_zmq_nameserver_address(address) + self.unsecure_listener, _, self._authenticator = set_up_server_socket(REP, unsecure_address, backend="unsecure_zmq") + socks = [self.unsecure_listener] + ports = [unsecure_port] + try: + secure_port = get_configured_secure_zmq_nameserver_port() + secure_address = create_secure_zmq_nameserver_address(address) + self.secure_listener, _, self._authenticator = set_up_server_socket(REP, secure_address, backend="secure_zmq") + socks.append(self.secure_listener) + ports.append(secure_port) + except ConfigurationError as err: + logger.warning(f"Cannot create secure access to nameserver: {str(err)}") + + logger.debug(f"Nameserver listening on ports {ports}") socket_receiver = SocketReceiver() - socket_receiver.register(self.listener) + for sock in socks: + socket_receiver.register(sock) while self.running: try: - for msg, _ in socket_receiver.receive(self.listener, timeout=1): + for msg, sock in socket_receiver.receive(*socks, timeout=1): logger.debug("Replying to request: " + str(msg)) active_address = get_active_address(msg.data["service"], address_receiver, msg.version) - self.listener.send_unicode(str(active_address)) + sock.send_unicode(str(active_address)) except TimeoutError: continue except KeyboardInterrupt: # Needed to stop the nameserver. pass finally: - socket_receiver.unregister(self.listener) + with suppress(UnboundLocalError): + for sock in socks: + socket_receiver.unregister(sock) self.close_sockets_and_threads() def close_sockets_and_threads(self): """Close all sockets and threads.""" with suppress(AttributeError): - close_socket(self.listener) + close_socket(self.unsecure_listener) with suppress(AttributeError): self._authenticator.stop() diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py index 8d2bec5..47cc0ae 100644 --- a/posttroll/backends/zmq/publisher.py +++ b/posttroll/backends/zmq/publisher.py @@ -15,7 +15,7 @@ class ZMQPublisher: """Unsecure ZMQ implementation of the publisher class.""" - def __init__(self, address, name="", min_port=None, max_port=None): + def __init__(self, address:str, name="", min_port=None, max_port=None): """Set up the publisher. Args: diff --git a/posttroll/backends/zmq/socket.py b/posttroll/backends/zmq/socket.py index d6bad00..5b9a341 100644 --- a/posttroll/backends/zmq/socket.py +++ b/posttroll/backends/zmq/socket.py @@ -1,7 +1,9 @@ """ZMQ socket handling functions.""" +import logging from contextlib import suppress from functools import cache +from socket import AF_INET, gaierror, getaddrinfo from threading import Lock from urllib.parse import urlsplit, urlunsplit @@ -11,30 +13,35 @@ from posttroll import config from posttroll.backends.zmq import get_context -from posttroll.message import Message +from posttroll.message import Message, MessageError authenticator_lock = Lock() +logger = logging.getLogger(__name__) -def close_socket(sock): +def close_socket(sock: zmq.Socket[int]): """Close a zmq socket.""" with suppress(zmq.ContextTerminated): sock.setsockopt(zmq.LINGER, 1) sock.close() -def set_up_client_socket(socket_type, address, options=None): +def set_up_client_socket(socket_type: int, address: str, + options: dict[int|str, str]|None = None, backend: str|None = None) -> zmq.Socket[int]: """Set up a client (connecting) zmq socket.""" - backend = config["backend"] + options = options or dict() + backend = backend or config["backend"] if backend == "unsecure_zmq": sock = create_unsecure_client_socket(socket_type) elif backend == "secure_zmq": sock = create_secure_client_socket(socket_type) + else: + raise NotImplementedError() add_options(sock, options) sock.connect(address) return sock -def create_unsecure_client_socket(socket_type): +def create_unsecure_client_socket(socket_type: int) -> zmq.Socket[int]: """Create an unsecure client socket.""" return get_context().socket(socket_type) @@ -47,12 +54,21 @@ def add_options(sock, options=None): sock.setsockopt(param, val) -def create_secure_client_socket(socket_type): - """Create a secure client socket.""" - subscriber = get_context().socket(socket_type) +class ConfigurationError(Exception): + """Error when something is not configured correctly.""" - client_secret_key_file = config["client_secret_key_file"] - server_public_key_file = config["server_public_key_file"] +def create_secure_client_socket(socket_type: int) -> zmq.Socket[int]: + """Create a secure client socket.""" + subscriber: zmq.Socket[int] = get_context().socket(socket_type) + + try: + client_secret_key_file = config["client_secret_key_file"] + except KeyError: + raise ConfigurationError("Missing config parameter 'client_secret_key_file'") + try: + server_public_key_file = config["server_public_key_file"] + except KeyError: + raise ConfigurationError("Missing config parameter 'server_public_key_file'") client_public, client_secret = load_certificate(client_secret_key_file) subscriber.curve_secretkey = client_secret subscriber.curve_publickey = client_public @@ -63,16 +79,19 @@ def create_secure_client_socket(socket_type): return subscriber -def set_up_server_socket(socket_type:int, destination, options=None, port_interval=(None, None)): +def set_up_server_socket(socket_type: int, destination: str, options: dict[int, str]|None = None, + port_interval: tuple[int|None, int|None] = (None, None), + backend: str|None = None) -> tuple[zmq.Socket[int], ThreadAuthenticator|None]: """Set up a server (binding) socket.""" if options is None: options = {} - backend:str = config["backend"] - if backend == "unsecure_zmq": - sock = create_unsecure_server_socket(socket_type) - authenticator = None - elif backend == "secure_zmq": + _backend:str = backend or config["backend"] + if _backend == "unsecure_zmq": + sock, authenticator = create_unsecure_server_socket(socket_type) + elif _backend == "secure_zmq": sock, authenticator = create_secure_server_socket(socket_type) + else: + raise NotImplementedError() add_options(sock, options) @@ -80,55 +99,96 @@ def set_up_server_socket(socket_type:int, destination, options=None, port_interv return sock, port, authenticator -def create_unsecure_server_socket(socket_type:int) -> zmq.Socket[int]: +def create_unsecure_server_socket(socket_type: int) -> zmq.Socket[int]: """Create an unsecure server socket.""" - return get_context().socket(socket_type) + ctx = get_context() + sock = ctx.socket(socket_type) + authenticator = get_auth_thread(ctx) + allowed_hosts = config.get("authorized_client_addresses", None) + if allowed_hosts: + ips = resolve_to_ips(allowed_hosts) + if ips: + authenticator.allow(*ips) + sock.setsockopt_string(zmq.ZAP_DOMAIN, "global") -def bind(sock, destination, port_interval): + return sock, authenticator + + +def resolve_to_ips(hosts: list[str]) -> list[str]: + """Resolve hostnames to ips.""" + ips: set[str] = set() + for host in hosts: + try: + results = getaddrinfo(host, None, AF_INET) + for result in results: + ips.add(result[4][0]) + except gaierror: + logger.warning(f"Could not resolve hostname {host}") + return list(ips) + + +def bind(sock, destination: str, port_interval: tuple[int, int]) -> int: """Bind the socket to a destination. If a random port is to be chosen, the port_interval is used. """ # Check for port 0 (random port) min_port, max_port = port_interval - u__ = urlsplit(destination) - port = u__.port + url = urlsplit(destination) + port = url.port if port == 0: - dest = urlunsplit((u__.scheme, u__.hostname, - u__.path, u__.query, u__.fragment)) - port_number = sock.bind_to_random_port(dest, - min_port=min_port, - max_port=max_port) - netloc = u__.hostname + ":" + str(port_number) - destination = urlunsplit((u__.scheme, netloc, u__.path, - u__.query, u__.fragment)) + dest = urlunsplit((url.scheme, url.hostname, + url.path, url.query, url.fragment)) + port_number: int = sock.bind_to_random_port(dest, + min_port=min_port, + max_port=max_port) + netloc = url.hostname + ":" + str(port_number) + destination = urlunsplit((url.scheme, netloc, url.path, + url.query, url.fragment)) else: sock.bind(destination) port_number = port return port_number -@cache + +def enable_auth_curve(ctx): + """Enable curve on the authenticator.""" + try: + clients_public_keys_directory = config["clients_public_keys_directory"] + except KeyError: + raise ConfigurationError("Missing config parameter 'clients_public_keys_directory'") + authorized_sub_addresses = config.get("authorized_client_addresses", []) + + # Start an authenticator for this context. + authenticator_thread = get_auth_thread(ctx) + authenticator_thread.allow(*authorized_sub_addresses) + # Tell authenticator to use the certificate in a directory + authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory) + return authenticator_thread + + def get_auth_thread(ctx): + with authenticator_lock: + return _get_auth_thread(ctx) + +@cache +def _get_auth_thread(ctx): """Get the authenticator thread for the context.""" thr = ThreadAuthenticator(ctx) thr.start() return thr -def create_secure_server_socket(socket_type): +def create_secure_server_socket(socket_type: int) -> tuple[zmq.Socket[int], ThreadAuthenticator]: """Create a secure server socket.""" - server_secret_key = config["server_secret_key_file"] - clients_public_keys_directory = config["clients_public_keys_directory"] - authorized_sub_addresses = config.get("authorized_client_addresses", []) + try: + server_secret_key = config["server_secret_key_file"] + except KeyError: + raise ConfigurationError("Missing config parameter 'server_secret_key_file'") ctx = get_context() # Start an authenticator for this context. - with authenticator_lock: - authenticator_thread = get_auth_thread(ctx) - authenticator_thread.allow(*authorized_sub_addresses) - # Tell authenticator to use the certificate in a directory - authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory) - + authenticator_thread = enable_auth_curve(ctx) server_socket = ctx.socket(socket_type) server_public, server_secret = load_certificate(server_secret_key) @@ -162,6 +222,9 @@ def receive(self, *sockets, timeout=None): for sock in sockets: if socks.get(sock) == zmq.POLLIN: received = sock.recv_string(zmq.NOBLOCK) - yield Message.decode(received), sock + try: + yield Message.decode(received), sock + except MessageError: + logger.debug(f"Invalid message received, dropping: {received}") else: raise TimeoutError("Did not receive anything on sockets.") diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index 4e7c8cf..6a0b86d 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -5,10 +5,12 @@ from time import sleep from urllib.parse import urlsplit -from zmq import PULL, SUB, SUBSCRIBE, ZMQError +from zmq import PULL, SUB, SUBSCRIBE, ContextTerminated, ZMQError +from posttroll import config from posttroll.backends.zmq import get_tcp_keepalive_options from posttroll.backends.zmq.socket import SocketReceiver, close_socket, set_up_client_socket +from posttroll.message import MESSAGE_VERSION LOGGER = logging.getLogger(__name__) @@ -16,7 +18,7 @@ class ZMQSubscriber: """A ZMQ subscriber class.""" - def __init__(self, addresses, topics="", message_filter=None, translate=False): + def __init__(self, *addresses, topics="", message_filter=None, translate=False): """Initialize the subscriber.""" self._topics = topics self._filter = message_filter @@ -30,8 +32,8 @@ def __init__(self, addresses, topics="", message_filter=None, translate=False): self._sock_receiver = SocketReceiver() self._lock = Lock() - - self.update(addresses) + dict_addresses = [ensure_address_is_dict(addr) for addr in addresses] + self.update(dict_addresses) self._loop = None @@ -40,23 +42,27 @@ def running(self): """Check if suscriber is running.""" return self._loop - def add(self, address, topics=None): + def add(self, address: dict[str, str], topics=None): """Add *address* to the subscribing list for *topics*. It topics is None we will subscribe to already specified topics. """ with self._lock: - if address in self.addresses: + addr = ensure_address_is_dict(address) + if addr.get("supported_message_version", MESSAGE_VERSION) > MESSAGE_VERSION: + LOGGER.warning(f"Will not connect to {str(addr)}, message version mismatch") + return + if addr["URI"] in self.address_keys: return topics = topics or self._topics LOGGER.info("Subscriber adding address %s with topics %s", str(address), str(topics)) - subscriber = self._add_sub_socket(address, topics) - self.sub_addr[subscriber] = address - self.addr_sub[address] = subscriber + subscriber = self._add_sub_socket(addr, topics) + self.sub_addr[subscriber] = addr + self.addr_sub[addr["URI"]] = subscriber - def remove(self, address): + def remove(self, address: str): """Remove *address* from the subscribing list for *topics*.""" with self._lock: try: @@ -73,17 +79,20 @@ def _remove_sub_socket(self, subscriber): self._sock_receiver.unregister(subscriber) subscriber.close() + @property + def address_keys(self) -> list[str]: + return [addr["URI"] for addr in self.addresses] + def update(self, addresses): """Update with a set of addresses.""" - if isinstance(addresses, str): - addresses = [addresses, ] - current_addresses, new_addresses = set(self.addresses), set(addresses) + uri_dict = uri_keys(addresses) + current_addresses, new_addresses = set(self.address_keys), set(uri_dict.keys()) addresses_to_remove = current_addresses.difference(new_addresses) addresses_to_add = new_addresses.difference(current_addresses) for addr in addresses_to_remove: self.remove(addr) for addr in addresses_to_add: - self.add(addr) + self.add(uri_dict[addr]) return bool(addresses_to_remove or addresses_to_add) def add_hook_sub(self, address, topics, callback): @@ -160,6 +169,8 @@ def _new_messages(self, timeout): self._hooks_cb[sock](m__) except TimeoutError: yield None + except ContextTerminated: + raise except ZMQError as err: if self._loop: LOGGER.exception("Receive failed: %s", str(err)) @@ -189,20 +200,45 @@ def __del__(self): except Exception: # noqa: E722 pass - def _add_sub_socket(self, address, topics): + def _add_sub_socket(self, address: dict[str, str], topics): options = get_tcp_keepalive_options() + try: + backend = address.get("backend", "unsecure_zmq") + uri = address["URI"] + except AttributeError: + backend = config["backend"] + uri = address - subscriber = self._create_socket(SUB, address, options) + subscriber = self._create_socket(SUB, uri, options, backend) add_subscriptions(subscriber, topics) if self._sock_receiver: self._sock_receiver.register(subscriber) return subscriber - def _create_socket(self, socket_type, address, options): - return set_up_client_socket(socket_type, address, options) + def _create_socket(self, socket_type: int, address: str, options, backend: str|None = None): + return set_up_client_socket(socket_type, address, options, backend) + + +def ensure_address_is_dict(addr: dict[str, str]|str) -> dict[str, str]: + """Ensure the passed address is in dict form.""" + if isinstance(addr, dict): + res = addr.copy() + elif isinstance(addr, str): + res = dict(URI=addr) + else: + NotImplementedError(f"Don't know how to handle {type(addr)} addresses") + res.setdefault("backend", config["backend"]) + return res + +def uri_keys(addresses) -> list[str]: + res = dict() + for addr in addresses: + new_addr = ensure_address_is_dict(addr) + res[new_addr["URI"]] = new_addr + return res def add_subscriptions(socket, topics): """Add subscriptions to a socket.""" diff --git a/posttroll/message.py b/posttroll/message.py index 6ab35f2..d0808cc 100644 --- a/posttroll/message.py +++ b/posttroll/message.py @@ -20,8 +20,8 @@ from posttroll import config -_MAGICK = "pytroll:/" -MESSAGE_VERSION = config.get("message_version", "v1.2") +_MAGICK : str = "pytroll:/" +MESSAGE_VERSION : str = config.get("message_version", "v1.2") class MessageError(Exception): @@ -37,7 +37,7 @@ class MessageError(Exception): # ----------------------------------------------------------------------------- -def is_valid_subject(obj): +def is_valid_subject(obj: object): """Check that the message subject is valid. Currently we only check for empty strings. @@ -45,7 +45,7 @@ def is_valid_subject(obj): return isinstance(obj, str) and bool(obj) -def is_valid_type(obj): +def is_valid_type(obj: object): """Check that the message type is valid. Currently we only check for empty strings. @@ -53,7 +53,7 @@ def is_valid_type(obj): return isinstance(obj, str) and bool(obj) -def is_valid_sender(obj): +def is_valid_sender(obj: object): """Check that the sender is valid. Currently we only check for empty strings. @@ -61,7 +61,7 @@ def is_valid_sender(obj): return isinstance(obj, str) and bool(obj) -def is_valid_data(obj, version=MESSAGE_VERSION): +def is_valid_data(obj:object, version:str = MESSAGE_VERSION): """Check if data is JSON serializable.""" if obj: encoder = create_datetime_json_encoder_for_version(version) @@ -264,7 +264,7 @@ def _check_for_element_count(rawstr): return raw -def _check_for_magic_word(rawstr): +def _check_for_magic_word(rawstr: str | bytes): """Check for the magick word.""" try: rawstr = rawstr.decode("utf-8") diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py index b7fe3e1..0443024 100644 --- a/posttroll/message_broadcaster.py +++ b/posttroll/message_broadcaster.py @@ -14,9 +14,9 @@ class DesignatedReceiversSender: """Sends message to multiple *receivers* on *port*.""" - def __init__(self, default_port, receivers): + def __init__(self, default_port: int, receivers: list[str]): """Set settings.""" - backend = config.get("backend", "unsecure_zmq") + backend = config["backend"] if backend == "unsecure_zmq": from posttroll.backends.zmq.message_broadcaster import ZMQDesignatedReceiversSender self._sender = ZMQDesignatedReceiversSender(default_port, receivers) @@ -130,11 +130,13 @@ def __init__(self, name, address, interval, nameservers): class AddressServiceBroadcaster(MessageBroadcaster): """Class to broadcast stuff.""" - def __init__(self, name, address, data_type, interval=2, nameservers=None): + def __init__(self, name, address, data_type: str, interval: int = 2, nameservers: list[str] | None = None): """Initialize broadcaster.""" msg = message.Message("/address/%s" % name, "info", {"URI": address, - "service": data_type}).encode() + "service": data_type, + "supported_message_version": message.MESSAGE_VERSION, + "backend": config["backend"]}).encode() MessageBroadcaster.__init__(self, msg, get_configured_broadcast_port(), interval, nameservers) diff --git a/posttroll/ns.py b/posttroll/ns.py index 48e90b6..f27e25f 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -17,22 +17,28 @@ # pylint: enable=E0611 -DEFAULT_NAMESERVER_PORT = 5557 +DEFAULT_UNSECURE_ZMQ_NAMESERVER_PORT = 5557 +DEFAULT_SECURE_ZMQ_NAMESERVER_PORT = 5558 logger = logging.getLogger(__name__) -def get_configured_nameserver_port() -> int: - """Get the configured nameserver port.""" +def get_configured_unsecure_zmq_nameserver_port() -> int: + """Get the configured unsecure zmq nameserver port.""" try: port = int(os.environ["NAMESERVER_PORT"]) warnings.warn("NAMESERVER_PORT is pending deprecation, please use POSTTROLL_NAMESERVER_PORT instead.", PendingDeprecationWarning, stacklevel=2) except KeyError: - port = DEFAULT_NAMESERVER_PORT + port = DEFAULT_UNSECURE_ZMQ_NAMESERVER_PORT return int(config.get("nameserver_port", port)) +def get_configured_secure_zmq_nameserver_port() -> int: + """Get the configured secure zmq nameserver port.""" + return int(config.get("secure_zmq_nameserver_port", DEFAULT_SECURE_ZMQ_NAMESERVER_PORT)) + + # Client functions. @@ -46,6 +52,8 @@ def get_pub_addresses(names:list[str] | None=None, timeout:float=10, nameserver: addrs = [] if names is None: names = ["", ] + elif isinstance(names, str): + names = [names] for name in names: then = dt.datetime.now(dt.timezone.utc) + dt.timedelta(seconds=timeout) while dt.datetime.now(dt.timezone.utc) < then: @@ -88,7 +96,6 @@ class NameServer: def __init__(self, max_age=None, multicast_enabled=True, restrict_to_localhost=False): """Initialize nameserver.""" self.loop = True - self.listener = None self._max_age = max_age or dt.timedelta(minutes=10) self._multicast_enabled = multicast_enabled self._restrict_to_localhost = restrict_to_localhost @@ -98,7 +105,7 @@ def __init__(self, max_age=None, multicast_enabled=True, restrict_to_localhost=F from posttroll.backends.zmq.ns import ZMQNameServer self._ns = ZMQNameServer() - def run(self, address_receiver=None, nameserver_address=None): + def run(self, address_receiver: AddressReceiver|None =None, nameserver_address: str|None = None): """Run the listener and answer to requests.""" if address_receiver is None: address_receiver = AddressReceiver(max_age=self._max_age, diff --git a/posttroll/publisher.py b/posttroll/publisher.py index 2d0ef48..959944b 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -71,7 +71,7 @@ def __init__(self, address, name="", min_port=None, max_port=None): # Initialize no heartbeat self._heartbeat = None - backend = config.get("backend", "unsecure_zmq") + backend = config["backend"] if backend not in ["unsecure_zmq", "secure_zmq"]: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") from posttroll.backends.zmq.publisher import ZMQPublisher diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 9131b53..1215cac 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -1,13 +1,12 @@ """Simple library to subscribe to messages.""" -import datetime as dt import logging -import time from posttroll import config from posttroll.address_receiver import get_configured_address_port +from posttroll.backends.zmq.subscriber import ensure_address_is_dict from posttroll.message import _MAGICK -from posttroll.ns import get_pub_address +from posttroll.ns import get_pub_addresses LOGGER = logging.getLogger(__name__) @@ -36,15 +35,20 @@ class Subscriber: """ - def __init__(self, addresses, topics="", message_filter=None, translate=False): + def __init__(self, addresses: list[str|dict[str, str]], topics="", message_filter=None, translate=False): """Initialize the subscriber.""" topics = self._magickfy_topics(topics) - backend = config.get("backend", "unsecure_zmq") + backend = config["backend"] if backend not in ["unsecure_zmq", "secure_zmq"]: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") from posttroll.backends.zmq.subscriber import ZMQSubscriber - self._subscriber = ZMQSubscriber(addresses, topics=topics, + addrs = [] + if isinstance(addresses, (str, dict)): + addresses = [addresses, ] + for addr in addresses: + addrs.append(ensure_address_is_dict(addr)) + self._subscriber = ZMQSubscriber(*addresses, topics=topics, message_filter=message_filter, translate=translate) def add(self, address, topics=None): @@ -170,35 +174,24 @@ def __init__(self, services="", topics=_MAGICK, addr_listener=False, def start(self): """Start the subscriber.""" - def _get_addr_loop(service, timeout): - """Try to get the address of *service* until for *timeout* seconds.""" - then = dt.datetime.now(dt.timezone.utc) + dt.timedelta(seconds=timeout) - while dt.datetime.now(dt.timezone.utc) < then: - addrs = get_pub_address(service, self._timeout, nameserver=self._nameserver) - if addrs: - return [addr["URI"] for addr in addrs] - time.sleep(1) - return [] - # Subscribe to those services and topics. LOGGER.debug("Subscribing to topics %s", str(self._topics)) self._subscriber = Subscriber(self._addresses, self._topics, translate=self._translate) - if self._addr_listener: self._addr_listener = _AddressListener(self._subscriber, self._services, nameserver=self._nameserver) - # Search for addresses corresponding to service. for service in self._services: - addresses = _get_addr_loop(service, self._timeout) + # addresses = _get_addr_loop(service, self._timeout) + addresses = get_pub_addresses(service, self._timeout, self._nameserver) if not addresses: LOGGER.warning("Can't get any address for %s", service) continue else: - LOGGER.debug("Got address for %s: %s", + LOGGER.debug("Got addresses for %s: %s", str(service), str(addresses)) for addr in addresses: self._subscriber.add(addr) @@ -285,6 +278,7 @@ def __init__(self, subscriber, services="", nameserver="localhost"): def handle_msg(self, msg): """Handle the message *msg*.""" + addr_dict = msg.data addr_ = msg.data["URI"] status = msg.data.get("status", True) if status: @@ -293,7 +287,7 @@ def handle_msg(self, msg): if not service or service in msg_services: LOGGER.debug("Adding address %s %s", str(addr_), str(service)) - self.subscriber.add(addr_) + self.subscriber.add(addr_dict) break else: LOGGER.debug("Removing address %s", str(addr_)) diff --git a/posttroll/tests/test_nameserver.py b/posttroll/tests/test_nameserver.py index 648deb3..b8d92f8 100644 --- a/posttroll/tests/test_nameserver.py +++ b/posttroll/tests/test_nameserver.py @@ -11,9 +11,14 @@ import pytest from posttroll import config -from posttroll.backends.zmq.ns import create_nameserver_address -from posttroll.message import Message -from posttroll.ns import NameServer, get_configured_nameserver_port, get_pub_address, get_pub_addresses +from posttroll.backends.zmq.ns import create_unsecure_zmq_nameserver_address +from posttroll.message import MESSAGE_VERSION, Message +from posttroll.ns import ( + NameServer, + get_configured_unsecure_zmq_nameserver_port, + get_pub_address, + get_pub_addresses, +) from posttroll.publisher import Publish from posttroll.subscriber import Subscribe from posttroll.tests.test_bbmcast import random_valid_mc_address @@ -57,11 +62,11 @@ def free_port() -> int: def create_nameserver_instance(max_age=3, multicast_enabled=True): """Create a nameserver instance.""" config.set(nameserver_port=free_port()) + config.set(secure_zmq_nameserver_port=free_port()) config.set(address_publish_port=free_port()) ns = NameServer(max_age=dt.timedelta(seconds=max_age), multicast_enabled=multicast_enabled) thr = Thread(target=ns.run) thr.start() - try: yield finally: @@ -115,7 +120,9 @@ def test_pub_addresses(multicast_enabled): assert len(res) == 1 expected = {u"status": True, u"service": [u"data_provider", u"this_data"], - u"name": u"address"} + u"name": u"address", + "backend": "unsecure_zmq", + "supported_message_version": MESSAGE_VERSION} for key, val in expected.items(): assert res[0][key] == val assert "receive_time" in res[0] @@ -124,7 +131,9 @@ def test_pub_addresses(multicast_enabled): assert len(res) == 1 expected = {u"status": True, u"service": [u"data_provider", u"this_data"], - u"name": u"address"} + u"name": u"address", + "backend": "unsecure_zmq", + "supported_message_version": MESSAGE_VERSION} for key, val in expected.items(): assert res[0][key] == val assert "receive_time" in res[0] @@ -165,7 +174,7 @@ def test_pub_sub_ctx(multicast_enabled): [True, False], ids=["mc on", "mc off"] ) -def test_pub_sub_add_rm(multicast_enabled): +def test_pub_sub_add_rm(multicast_enabled: bool): """Test adding and removing publishers.""" if multicast_enabled: if os.getenv("DISABLED_MULTICAST"): @@ -297,16 +306,16 @@ def test_switch_backend_for_nameserver(): def test_create_nameserver_address(tmp_path): """Test creating the nameserver address.""" - port = get_configured_nameserver_port() - res = create_nameserver_address("somehost") + port = get_configured_unsecure_zmq_nameserver_port() + res = create_unsecure_zmq_nameserver_address("somehost") assert res == f"tcp://somehost:{port}" preformatted_address = f"ipc://{str(tmp_path)}" - res = create_nameserver_address(preformatted_address) + res = create_unsecure_zmq_nameserver_address(preformatted_address) assert res == preformatted_address tcp_without_port = "tcp://somehost" - res = create_nameserver_address(tcp_without_port) + res = create_unsecure_zmq_nameserver_address(tcp_without_port) assert res == f"tcp://somehost:{port}" @@ -327,6 +336,24 @@ def test_no_tcp_nameserver(tmp_path): thr.join() +def test_unsecure_tcp_nameserver(tmp_path): + """Test running a nameserver without tcp and multicast.""" + nserver = NameServer() + port = get_configured_unsecure_zmq_nameserver_port() + ns_address = "tcp://localhost" + service_addresses = ["some", "addresses"] + thr = Thread(target=nserver.run, + args=(dict(cool_service=service_addresses), + ns_address)) + thr.start() + try: + addrs = get_pub_address("cool_service", nameserver=ns_address) + assert addrs == service_addresses + finally: + nserver.stop() + thr.join() + + @pytest.mark.parametrize("version", ["v1.01", "v1.2"]) def test_message_version_compatibility(tmp_path, version): """Ensure the message version of nameserver responses.""" diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 559ab61..5e07300 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -81,6 +81,7 @@ def test_pub_suber(self): msg = next(sub.recv(2)) if msg is not None: + assert isinstance(msg, Message) assert str(msg) == str(message) tested = True assert tested diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index d9fc95f..898611b 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -3,6 +3,7 @@ import os import shutil import time +from pathlib import Path from threading import Thread import zmq.auth @@ -10,7 +11,7 @@ from posttroll import config from posttroll.backends.zmq import generate_keys from posttroll.message import Message -from posttroll.ns import get_pub_address +from posttroll.ns import NameServer, get_pub_address from posttroll.publisher import Publisher, create_publisher_from_dict_config from posttroll.subscriber import Subscriber, create_subscriber_from_dict_config from posttroll.tests.test_nameserver import create_nameserver_instance @@ -64,15 +65,15 @@ def test_ipc_pubsub_with_sec(tmp_path): server_secret_key_file=server_secret_key_file): subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202) sub = create_subscriber_from_dict_config(subscriber_settings) - pub = Publisher(ipc_address) pub.start() def delayed_send(msg): + to_send = Message(subject="/hi", atype="string", data=msg) time.sleep(.2) - msg = Message(subject="/hi", atype="string", data=msg) - pub.send(str(msg)) + pub.send(str(to_send)) + thr = Thread(target=delayed_send, args=["very sensitive message"]) thr.start() try: @@ -159,6 +160,35 @@ def test_switch_to_secure_backend_for_nameserver(tmp_path): assert res == "" +def test_secure_tcp_nameserver(tmp_path: Path): + """Test nameserver can be queried on as secured port too.""" + server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") + client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") + with config.set(backend="unsecure_zmq", + client_secret_key_file=client_secret_key_file, + clients_public_keys_directory=os.path.dirname(client_public_key_file), + server_public_key_file=server_public_key_file, + server_secret_key_file=server_secret_key_file): + nserver = NameServer() + ns_address = "tcp://localhost" + service_addresses = ["some", "addresses"] + thr = Thread(target=nserver.run, + args=(dict(cool_unsecure_service=service_addresses, + cool_secure_service=service_addresses), + ns_address)) + thr.start() + try: + with config.set(backend="secure_zmq"): + addrs = get_pub_address("cool_secure_service", nameserver=ns_address) + assert addrs == service_addresses + with config.set(backend="unsecure_zmq"): + addrs = get_pub_address("cool_unsecure_service", nameserver=ns_address) + assert addrs == service_addresses + finally: + nserver.stop() + thr.join() + + def test_create_certificates_cli(tmp_path): """Test the certificate creation cli.""" name = "server" From 0381e80bee59190a86fd6311f1be088ac8261854 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 22 Jan 2026 18:53:09 +0100 Subject: [PATCH 2/4] Allow host names to be passed for white-listing --- posttroll/backends/zmq/socket.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/posttroll/backends/zmq/socket.py b/posttroll/backends/zmq/socket.py index 5b9a341..222d096 100644 --- a/posttroll/backends/zmq/socket.py +++ b/posttroll/backends/zmq/socket.py @@ -158,11 +158,14 @@ def enable_auth_curve(ctx): clients_public_keys_directory = config["clients_public_keys_directory"] except KeyError: raise ConfigurationError("Missing config parameter 'clients_public_keys_directory'") - authorized_sub_addresses = config.get("authorized_client_addresses", []) + allowed_hosts = config.get("authorized_client_addresses", []) # Start an authenticator for this context. authenticator_thread = get_auth_thread(ctx) - authenticator_thread.allow(*authorized_sub_addresses) + if allowed_hosts: + ips = resolve_to_ips(allowed_hosts) + if ips: + authenticator.allow(*ips) # Tell authenticator to use the certificate in a directory authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory) return authenticator_thread From e3085e7722dd2850eeaee35918364cd6096b86d0 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Fri, 23 Jan 2026 10:23:29 +0100 Subject: [PATCH 3/4] Disable authentication for inproc sockets --- posttroll/backends/zmq/address_receiver.py | 6 +++--- posttroll/backends/zmq/ns.py | 4 ++-- posttroll/backends/zmq/socket.py | 15 ++++++++++++--- posttroll/logger.py | 6 +++--- posttroll/message.py | 20 ++++++++++++-------- posttroll/ns.py | 2 +- posttroll/subscriber.py | 10 +++++----- pyproject.toml | 2 +- 8 files changed, 39 insertions(+), 26 deletions(-) diff --git a/posttroll/backends/zmq/address_receiver.py b/posttroll/backends/zmq/address_receiver.py index 0c14acb..b3c91b7 100644 --- a/posttroll/backends/zmq/address_receiver.py +++ b/posttroll/backends/zmq/address_receiver.py @@ -1,4 +1,4 @@ -"""ZMQ implementation of the the simple receiver.""" +"""ZMQ implementation of the simple receiver.""" import zmq @@ -6,13 +6,13 @@ from posttroll.backends.zmq.socket import close_socket, set_up_server_socket -class SimpleReceiver(object): +class SimpleReceiver: """Simple listing on port for address messages.""" def __init__(self, port=None, timeout=2): """Set up the receiver.""" self._port = port or get_configured_address_port() - address = "tcp://*:" + str(port) + address = "tcp://*:" + str(self._port) self._socket, _, self._authenticator = set_up_server_socket(zmq.REP, address) self._socket.setsockopt(zmq.RCVTIMEO, timeout * 1000) # timeout in milliseconds self._running = True diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py index f12acdd..03dca80 100644 --- a/posttroll/backends/zmq/ns.py +++ b/posttroll/backends/zmq/ns.py @@ -1,4 +1,4 @@ -"""ZMQ implexentation of ns.""" +"""ZMQ implementation of ns.""" import logging from contextlib import suppress @@ -23,7 +23,7 @@ get_configured_unsecure_zmq_nameserver_port, ) -logger = logging.getLogger("__name__") +logger = logging.getLogger(__name__) nslock = Lock() diff --git a/posttroll/backends/zmq/socket.py b/posttroll/backends/zmq/socket.py index 222d096..eb3417f 100644 --- a/posttroll/backends/zmq/socket.py +++ b/posttroll/backends/zmq/socket.py @@ -30,6 +30,9 @@ def set_up_client_socket(socket_type: int, address: str, """Set up a client (connecting) zmq socket.""" options = options or dict() backend = backend or config["backend"] + # Skip secure setup for inproc (internal thread communication) + if address.startswith("inproc://"): + backend = "unsecure_zmq" if backend == "unsecure_zmq": sock = create_unsecure_client_socket(socket_type) elif backend == "secure_zmq": @@ -81,13 +84,15 @@ def create_secure_client_socket(socket_type: int) -> zmq.Socket[int]: def set_up_server_socket(socket_type: int, destination: str, options: dict[int, str]|None = None, port_interval: tuple[int|None, int|None] = (None, None), - backend: str|None = None) -> tuple[zmq.Socket[int], ThreadAuthenticator|None]: + backend: str|None = None) -> tuple[zmq.Socket[int], int, ThreadAuthenticator|None]: """Set up a server (binding) socket.""" if options is None: options = {} _backend:str = backend or config["backend"] + # Skip ZAP for inproc (internal thread communication) + enable_zap = not destination.startswith("inproc://") if _backend == "unsecure_zmq": - sock, authenticator = create_unsecure_server_socket(socket_type) + sock, authenticator = create_unsecure_server_socket(socket_type, enable_zap=enable_zap) elif _backend == "secure_zmq": sock, authenticator = create_secure_server_socket(socket_type) else: @@ -99,10 +104,14 @@ def set_up_server_socket(socket_type: int, destination: str, options: dict[int, return sock, port, authenticator -def create_unsecure_server_socket(socket_type: int) -> zmq.Socket[int]: +def create_unsecure_server_socket(socket_type: int, enable_zap: bool = True) -> tuple[zmq.Socket[int], ThreadAuthenticator | None]: """Create an unsecure server socket.""" ctx = get_context() sock = ctx.socket(socket_type) + + if not enable_zap: + return sock, None + authenticator = get_auth_thread(ctx) allowed_hosts = config.get("authorized_client_addresses", None) if allowed_hosts: diff --git a/posttroll/logger.py b/posttroll/logger.py index 0253157..40efe0f 100644 --- a/posttroll/logger.py +++ b/posttroll/logger.py @@ -20,7 +20,7 @@ class PytrollFormatter(logging.Formatter): def __init__(self, fmt, datefmt): """Initialize formatter.""" - logging.Formatter.__init__(self, fmt, datefmt) + super().__init__(fmt, datefmt) def format(self, record): """Format the message.""" @@ -35,7 +35,7 @@ class PytrollHandler(logging.Handler): def __init__(self, name, port=0): """Initialize the handler.""" - logging.Handler.__init__(self) + super().__init__() self._publisher = NoisyPublisher(name, port) self._publisher.start() @@ -47,7 +47,7 @@ def emit(self, record): def close(self): """Close the handler.""" self._publisher.stop() - logging.Handler.close(self) + super().close() BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) diff --git a/posttroll/message.py b/posttroll/message.py index d0808cc..13c09d7 100644 --- a/posttroll/message.py +++ b/posttroll/message.py @@ -37,28 +37,33 @@ class MessageError(Exception): # ----------------------------------------------------------------------------- -def is_valid_subject(obj: object): +def _is_valid_nonempty_string(obj: object) -> bool: + """Check that an object is a non-empty string.""" + return isinstance(obj, str) and bool(obj) + + +def is_valid_subject(obj: object) -> bool: """Check that the message subject is valid. Currently we only check for empty strings. """ - return isinstance(obj, str) and bool(obj) + return _is_valid_nonempty_string(obj) -def is_valid_type(obj: object): +def is_valid_type(obj: object) -> bool: """Check that the message type is valid. Currently we only check for empty strings. """ - return isinstance(obj, str) and bool(obj) + return _is_valid_nonempty_string(obj) -def is_valid_sender(obj: object): +def is_valid_sender(obj: object) -> bool: """Check that the sender is valid. Currently we only check for empty strings. """ - return isinstance(obj, str) and bool(obj) + return _is_valid_nonempty_string(obj) def is_valid_data(obj:object, version:str = MESSAGE_VERSION): @@ -258,8 +263,7 @@ def _check_for_version(raw): def _check_for_element_count(rawstr): raw = re.split(r"\s+", rawstr, maxsplit=6) if len(raw) < 5: - raise MessageError("Could node decode raw string: '%s ...'" - % str(rawstr[:36])) + raise MessageError(f"Could not decode raw string: '{rawstr[:36]} ...'") return raw diff --git a/posttroll/ns.py b/posttroll/ns.py index f27e25f..f3c1251 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -175,7 +175,7 @@ def run(ns, logger): ns.run() except KeyboardInterrupt: pass - except: + except Exception: logger.exception("Something wrong happened...") raise finally: diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 1215cac..99b3139 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -317,11 +317,11 @@ def create_subscriber_from_dict_config(settings): def _get_subscriber_instance(settings): - _ = settings.pop("nameserver", None) - _ = settings.pop("port", None) - _ = settings.pop("services", None) - _ = settings.pop("addr_listener", None), - _ = settings.pop("timeout", None) + settings.pop("nameserver", None) + settings.pop("port", None) + settings.pop("services", None) + settings.pop("addr_listener", None) + settings.pop("timeout", None) return Subscriber(**settings) diff --git a/pyproject.toml b/pyproject.toml index 3fd670b..5cf3169 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "donfig", ] readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.12" license = "Apache-2.0" license-files = ["LICENSE.txt"] classifiers = [ From 0bb8cdb8331e0e575fff7ce4650c4e392d9618b4 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Fri, 23 Jan 2026 14:32:09 +0100 Subject: [PATCH 4/4] Fix wrong variable name --- posttroll/backends/zmq/socket.py | 2 +- posttroll/logger.py | 14 +++++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/posttroll/backends/zmq/socket.py b/posttroll/backends/zmq/socket.py index eb3417f..f115bb8 100644 --- a/posttroll/backends/zmq/socket.py +++ b/posttroll/backends/zmq/socket.py @@ -174,7 +174,7 @@ def enable_auth_curve(ctx): if allowed_hosts: ips = resolve_to_ips(allowed_hosts) if ips: - authenticator.allow(*ips) + authenticator_thread.allow(*ips) # Tell authenticator to use the certificate in a directory authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory) return authenticator_thread diff --git a/posttroll/logger.py b/posttroll/logger.py index 40efe0f..3241734 100644 --- a/posttroll/logger.py +++ b/posttroll/logger.py @@ -1,8 +1,5 @@ """Logger module for Posttroll.""" - -# TODO: remove old hanging subscriptions - import copy import logging import logging.handlers @@ -69,7 +66,7 @@ class ColoredFormatter(logging.Formatter): def __init__(self, msg, use_color=True): """Initialize the colored formatter.""" - logging.Formatter.__init__(self, msg) + super().__init__(msg) self.use_color = use_color def format(self, record): @@ -80,18 +77,17 @@ def format(self, record): + levelname + RESET_SEQ) record2 = copy.copy(record) record2.levelname = levelname_color - return logging.Formatter.format(self, record2) + return super().format(record2) -class Logger(object): +class Logger: """The logging machine. - Contains a thread listening to incomming messages, and a thread logging. + Contains a thread listening to incoming messages, and a thread logging. """ - def __init__(self, nameserver_address="localhost", nameserver_port=16543): + def __init__(self, nameserver_address="localhost", nameserver_port=16543): # noqa: ARG002 """Initialize the logger.""" - del nameserver_address, nameserver_port self.log_thread = Thread(target=self.log) self.loop = True