diff --git a/imdclient/IMDClient.py b/imdclient/IMDClient.py index a0f71dd..1efa567 100644 --- a/imdclient/IMDClient.py +++ b/imdclient/IMDClient.py @@ -54,7 +54,7 @@ class IMDClient: buffer_size : int (optional) :class:`IMDFrameBuffer` will be filled with as many :class:`IMDFrame` fit in `buffer_size` bytes [``10MB``] timeout : int, optional - Timeout for the socket in seconds [``5``] + Timeout for the socket in seconds [``600``] continue_after_disconnect : bool, optional [``None``] If True, the client will attempt to change the simulation engine's waiting behavior to non-blocking after the client disconnects. If False, the client will attempt to change it @@ -73,6 +73,14 @@ def __init__( continue_after_disconnect=None, **kwargs, ): + + # Warn if timeout is overly optimistic + if "timeout" in kwargs and kwargs["timeout"] <= 1: + logger.warning( + f"IMDClient: timeout value of {kwargs['timeout']} second(s) is very low and may lead to " + "premature disconnection by the client. Consider using a higher value (default is 600 seconds)." + ) + self._stopped = False self._conn = self._connect_to_server(host, port, socket_bufsize) self._imdsinfo = self._await_IMD_handshake() @@ -169,21 +177,26 @@ def get_imdframe(self): if self._multithreaded: try: return self._buf.pop_full_imdframe() - except EOFError: + except EOFError as e: # in this case, consumer is already finished # and doesn't need to be notified + logger.debug(f"IMDClient: Multithreaded connection ended") self._disconnect() self._stopped = True - if self._error_queue.qsize(): - raise EOFError(f"{self._error_queue.get()}") - raise EOFError + try: + error = self._error_queue.get_nowait() + except queue.Empty: + raise EOFError from e + else: + raise EOFError(str(error)) from error else: try: return self._producer._get_imdframe() - except EOFError: + except EOFError as e: + logger.debug(f"IMDClient: Single-threaded connection ended") self._disconnect() - raise EOFError + raise EOFError from e def get_imdsessioninfo(self): """ @@ -241,7 +254,9 @@ def _await_IMD_handshake(self) -> IMDSessionInfo: read_into_buf(self._conn, h_buf) except (ConnectionError, TimeoutError, Exception) as e: logger.debug("IMDClient: No handshake packet received: %s", e) - raise ConnectionError("IMDClient: No handshake packet received") + raise ConnectionError( + "IMDClient: No handshake packet received" + ) from e header = IMDHeader(h_buf) @@ -362,7 +377,7 @@ class BaseIMDProducer(threading.Thread): error_queue: queue.Queue Queue to hold errors produced by the producer thread timeout : int, optional - Timeout for the socket in seconds [``5``] + Timeout for the socket in seconds [``600``] """ def __init__( @@ -373,7 +388,7 @@ def __init__( n_atoms, multithreaded, error_queue, - timeout=5, + timeout=600, **kwargs, ): super(BaseIMDProducer, self).__init__(daemon=True) @@ -424,7 +439,8 @@ def _get_imdframe(self): try: self._parse_imdframe() except EOFError as e: - raise EOFError + logger.debug(f"IMDProducer: No more frames to read: {e}") + raise EOFError from e except Exception as e: raise RuntimeError("An unexpected error occurred") from e @@ -468,11 +484,13 @@ def run(self): self._frame, t.elapsed, ) - except EOFError: + except EOFError as e: # simulation ended in a way # that we expected # i.e. consumer stopped or read_into_buf didn't find # full token of data + logger.debug("IMDProducer: %s", e) + self.error_queue.put(e) logger.debug("IMDProducer: Simulation ended normally, cleaning up") except Exception as e: logger.debug("IMDProducer: An unexpected error occurred: %s", e) @@ -513,13 +531,19 @@ def _read(self, buf): """Wraps `read_into_buf` call to give uniform error handling which indicates end of stream""" try: read_into_buf(self._conn, buf) - except (ConnectionError, TimeoutError, BlockingIOError, Exception): # ConnectionError: Server is definitely done sending frames, socket is closed # TimeoutError: Server is *likely* done sending frames. # BlockingIOError: Occurs when timeout is 0 in place of a TimeoutError. Server is *likely* done sending frames # OSError: Occurs when main thread disconnects from the server and closes the socket, but producer thread attempts to read another frame # Exception: Something unexpected happened - raise EOFError + except ConnectionError as e: + raise EOFError("Server is definitely done sending frames") from e + except TimeoutError as e: + raise EOFError("Server is likely done sending frames") from e + except BlockingIOError as e: + raise EOFError("Server is likely done sending frames") from e + except Exception as e: + raise EOFError("Something unexpected happened") from e class IMDProducerV2(BaseIMDProducer): @@ -597,7 +621,9 @@ def _pause(self): self._conn.sendall(pause) except ConnectionResetError as e: # Simulation has already ended by the time we paused - raise EOFError + raise EOFError( + "Simulation has already ended by the time we paused" + ) from e # Edge case: pause occured in the time between server sends its last frame # and closing socket # Simulation is not actually paused but is over, but we still want to read remaining data @@ -612,7 +638,9 @@ def _unpause(self): # Edge case: pause occured in the time between server sends its last frame # and closing socket # Simulation was never actually paused in this case and is now over - raise EOFError + raise EOFError( + "Simulation was never actually paused as pause was sent after the last frame; simulation is now over" + ) from e # Edge case: pause & unpause occured in the time between server sends its last frame and closing socket # in this case, the simulation isn't actually unpaused but over @@ -662,7 +690,9 @@ def _pause(self): self._conn.sendall(pause) except ConnectionResetError as e: # Simulation has already ended by the time we paused - raise EOFError + raise EOFError( + "Simulation has already ended by the time we paused" + ) from e # Edge case: pause occured in the time between server sends its last frame # and closing socket # Simulation is not actually paused but is over, but we still want to read remaining data @@ -677,7 +707,9 @@ def _unpause(self): # Edge case: pause occured in the time between server sends its last frame # and closing socket # Simulation was never actually paused in this case and is now over - raise EOFError + raise EOFError( + "Simulation was never actually paused as pause was sent after the last frame; simulation is now over" + ) from e # Edge case: pause & unpause occured in the time between server sends its last frame and closing socket # in this case, the simulation isn't actually unpaused but over @@ -852,9 +884,10 @@ def wait_for_space(self): if self._consumer_finished: logger.debug("IMDProducer: Noticing consumer finished") - raise EOFError + raise EOFError("Consumer has finished") except Exception as e: logger.debug(f"IMDProducer: Error waiting for space in buffer: {e}") + raise RuntimeError("Error waiting for space in buffer") from e def pop_empty_imdframe(self): logger.debug("IMDProducer: Getting empty frame") @@ -870,7 +903,7 @@ def pop_empty_imdframe(self): if self._consumer_finished: logger.debug("IMDProducer: Noticing consumer finished") - raise EOFError + raise EOFError("Consumer has finished") return self._empty_q.get() @@ -905,7 +938,7 @@ def pop_full_imdframe(self): if self._producer_finished and self._full_q.qsize() == 0: logger.debug("IMDFrameBuffer(Consumer): Producer finished") - raise EOFError + raise EOFError("Producer has finished") imdf = self._full_q.get() diff --git a/imdclient/tests/server.py b/imdclient/tests/server.py index 6c467c3..b515da5 100644 --- a/imdclient/tests/server.py +++ b/imdclient/tests/server.py @@ -37,7 +37,7 @@ def set_imdsessioninfo(self, imdsinfo): @property def port(self): """Get the port the server is bound to. - + Returns: int: The port number, or None if not bound yet. """ @@ -47,7 +47,9 @@ def handshake_sequence(self, host, first_frame=True): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind((host, 0)) # Bind to port 0 to get a free port self._bound_port = s.getsockname()[1] # Store the actual bound port - logger.debug(f"InThreadIMDServer: Listening on {host}:{self._bound_port}") + logger.debug( + f"InThreadIMDServer: Listening on {host}:{self._bound_port}" + ) s.listen(60) self.listen_socket = s diff --git a/imdclient/tests/test_imdclient.py b/imdclient/tests/test_imdclient.py index 656d08c..6a93d8c 100644 --- a/imdclient/tests/test_imdclient.py +++ b/imdclient/tests/test_imdclient.py @@ -1,6 +1,7 @@ """Test for IMDClient functionality""" import logging +import time import pytest from numpy.testing import ( @@ -54,55 +55,62 @@ def imdsinfo(self): return create_default_imdsinfo_v3() @pytest.fixture - def server_client_two_frame_buf(self, universe, imdsinfo): - server = InThreadIMDServer(universe.trajectory) - server.set_imdsessioninfo(imdsinfo) - server.handshake_sequence("localhost", first_frame=False) - client = IMDClient( - f"localhost", - server.port, - universe.trajectory.n_atoms, - buffer_size=imdframe_memsize(universe.trajectory.n_atoms, imdsinfo) - * 2, + def server_client(self, universe, imdsinfo): + created = [] + + def _server_client(endianness=None, **client_kwargs): + server = InThreadIMDServer(universe.trajectory) + if endianness is not None: + imdsinfo.endianness = endianness + server.set_imdsessioninfo(imdsinfo) + + n_atoms = client_kwargs.pop("n_atoms", universe.atoms.n_atoms) + server.handshake_sequence("localhost", first_frame=False) + client = IMDClient( + "localhost", + server.port, + n_atoms, + **client_kwargs, + ) + server.join_accept_thread() + created.append((server, client)) + return server, client + + yield _server_client + + # Teardown: stop clients and cleanup servers + for server, client in created: + try: + client.stop() + except Exception: + pass + try: + server.cleanup() + except Exception: + pass + + @pytest.fixture + def server_client_two_frame_buf(self, server_client, universe, imdsinfo): + # Calculate the buffer size + buffer_size = ( + imdframe_memsize(universe.trajectory.n_atoms, imdsinfo) * 2 ) - server.join_accept_thread() + timeout = 5 # to speed up no disconnect test + server, client = server_client(buffer_size=buffer_size, timeout=timeout) yield server, client - client.stop() - server.cleanup() - @pytest.fixture(params=[">", "<"]) - def server_client(self, universe, imdsinfo, request): - server = InThreadIMDServer(universe.trajectory) - imdsinfo.endianness = request.param - server.set_imdsessioninfo(imdsinfo) - server.handshake_sequence("localhost", first_frame=False) - client = IMDClient( - f"localhost", - server.port, - universe.atoms.n_atoms, - ) - server.join_accept_thread() + @pytest.fixture(params=["<", ">"]) + def server_client_endianness(self, server_client, request): + server, client = server_client(endianness=request.param) yield server, client - client.stop() - server.cleanup() @pytest.fixture - def server_client_incorrect_atoms(self, universe, imdsinfo): - server = InThreadIMDServer(universe.trajectory) - server.set_imdsessioninfo(imdsinfo) - server.handshake_sequence("localhost", first_frame=False) - client = IMDClient( - f"localhost", - server.port, - universe.atoms.n_atoms + 1, - ) - server.join_accept_thread() + def server_client_incorrect_atoms(self, server_client, universe): + server, client = server_client(n_atoms=universe.trajectory.n_atoms + 1) yield server, client - client.stop() - server.cleanup() - def test_traj_unchanged(self, server_client, universe): - server, client = server_client + def test_traj_unchanged(self, server_client_endianness, universe): + server, client = server_client_endianness server.send_frames(0, 5) for i in range(5): imdf = client.get_imdframe() @@ -163,31 +171,71 @@ def test_pause_resume_no_disconnect(self, server_client_two_frame_buf): server.expect_packet(IMDHeaderType.IMD_DISCONNECT) @pytest.mark.parametrize("cont", [True, False]) - def test_continue_after_disconnect(self, universe, imdsinfo, cont): - server = InThreadIMDServer(universe.trajectory) - server.set_imdsessioninfo(imdsinfo) - server.handshake_sequence("localhost", first_frame=False) - client = IMDClient( - f"localhost", - server.port, - universe.trajectory.n_atoms, - continue_after_disconnect=cont, - ) - server.join_accept_thread() + def test_continue_after_disconnect(self, server_client, cont): + server, client = server_client(continue_after_disconnect=cont) server.expect_packet( IMDHeaderType.IMD_WAIT, expected_length=(int)(not cont) ) - def test_incorrect_atom_count(self, server_client_incorrect_atoms, universe): + def test_timeout_warning_low_value(self, server_client, caplog): + """Test that warning is issued for timeout values <= 1 second""" + with caplog.at_level(logging.WARNING): + server, client = server_client(timeout=1) + + # Check that warning was logged + assert any( + "timeout value of 1 second(s) is very low" in record.message + for record in caplog.records + ) + + @pytest.mark.parametrize("timeout_val", [2, 10]) + def test_timeout_within_limit(self, server_client, universe, timeout_val): + """Test that timeout does not trigger when server responds within timeout period""" + server, client = server_client(timeout=timeout_val) + + # Sleep for less than timeout before sending frames + time.sleep(timeout_val - 1) + server.send_frame(0) + + # Should successfully receive frame without timeout + imdf = client.get_imdframe() + assert_allclose(universe.trajectory[0].positions, imdf.positions) + + @pytest.mark.parametrize("timeout_val", [2, 10]) + def test_timeout_when_exceeded(self, server_client, timeout_val): + """Test that timeout triggers EOFError when server doesn't respond within timeout period""" + server, client = server_client(timeout=timeout_val) + + # Sleep for longer than timeout without sending any frames + time.sleep(timeout_val + 1) + + # Client should timeout and raise EOFError when trying to get first frame + with pytest.raises(EOFError) as exc_info: + client.get_imdframe() + + # Verify TimeoutError is somewhere in the exception chain + exception_chain = [] + current = exc_info.value + while current is not None: + exception_chain.append(type(current)) + current = current.__cause__ + + assert TimeoutError in exception_chain + + def test_incorrect_atom_count( + self, server_client_incorrect_atoms, universe + ): server, client = server_client_incorrect_atoms - + server.send_frame(0) - + with pytest.raises(EOFError) as exc_info: client.get_imdframe() - + error_msg = str(exc_info.value) - assert f"Expected n_atoms value {universe.atoms.n_atoms + 1}" in error_msg + assert ( + f"Expected n_atoms value {universe.atoms.n_atoms + 1}" in error_msg + ) assert f"got {universe.atoms.n_atoms}" in error_msg assert "Ensure you are using the correct topology file" in error_msg