diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 818b38a..50ff524 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -50,7 +50,6 @@ def __init__( logger (Optional[Logger], optional): Custom logger to use. Defaults to None. """ self.update_callback = None - self.parent_conn = None self.host = host self.port = port self.user = user @@ -83,7 +82,7 @@ def _create_subscription_process( self._cleanup_connections_and_processes() self.parent_conn, self.child_conn = Pipe() - self.subscription_proces = Process( + self.subscription_process = Process( target=casbin_channel_subscription, args=( self.child_conn, @@ -109,9 +108,12 @@ def start( self, timeout=20, # seconds ): - if not self.subscription_proces.is_alive(): + if self.subscription_process is None: + self._create_subscription_process(start_listening=False) + + if not self.subscription_process.is_alive(): # Start listening to messages - self.subscription_proces.start() + self.subscription_process.start() # And wait for the Process to be ready to listen for updates # from PostgreSQL timeout_time = time() + timeout @@ -124,6 +126,9 @@ def start( raise PostgresqlWatcherChannelSubscriptionTimeoutError(timeout) sleep(1 / 1000) # wait for 1 ms + def stop(self): + self._cleanup_connections_and_processes() + def _cleanup_connections_and_processes(self) -> None: # Clean up potentially existing Connections and Processes if self.parent_conn is not None: @@ -132,8 +137,9 @@ def _cleanup_connections_and_processes(self) -> None: if self.child_conn is not None: self.child_conn.close() self.child_conn = None - if self.subscription_process is not None: + if self.subscription_process is not None and self.subscription_process.pid is not None: self.subscription_process.terminate() + self.subscription_process.join() self.subscription_process = None def set_update_callback(self, update_handler: Optional[Callable[[None], None]]): diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index d3f9d70..5311ccb 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -50,7 +50,7 @@ def test_pg_watcher_init(self): assert isinstance(pg_watcher.parent_conn, connection.PipeConnection) else: assert isinstance(pg_watcher.parent_conn, connection.Connection) - assert isinstance(pg_watcher.subscription_proces, context.Process) + assert isinstance(pg_watcher.subscription_process, context.Process) def test_update_single_pg_watcher(self): pg_watcher = get_watcher("test_update_single_pg_watcher") @@ -115,6 +115,28 @@ def test_update_handler_not_called(self): self.assertFalse(main_watcher.should_reload()) self.assertTrue(handler.call_count == 0) + def test_stop_and_restart(self): + channel_name = "test_stop_and_restart" + pg_watcher = get_watcher(channel_name) + + # Verify initially started + self.assertTrue(pg_watcher.subscription_process.is_alive()) + + # Stop the watcher + pg_watcher.stop() + self.assertIsNone(pg_watcher.subscription_process) + + # Restart the watcher + pg_watcher.start() + + # Verify resources are recreated and process is alive + self.assertTrue(pg_watcher.subscription_process.is_alive()) + + # Verify it still works after restart + pg_watcher.update() + sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) + self.assertTrue(pg_watcher.should_reload()) + if __name__ == "__main__": main()