diff --git a/charon/config.py b/charon/config.py index 396e4be3..39e9dc80 100644 --- a/charon/config.py +++ b/charon/config.py @@ -42,6 +42,7 @@ def __init__(self, data: Dict): self.__radas_sign_timeout_retry_interval: int = data.get( "radas_sign_timeout_retry_interval", 60 ) + self.__radas_receiver_timeout: int = int(data.get("radas_receiver_timeout", 1800)) def validate(self) -> bool: if not self.__umb_host: @@ -112,6 +113,9 @@ def radas_sign_timeout_retry_count(self) -> int: def radas_sign_timeout_retry_interval(self) -> int: return self.__radas_sign_timeout_retry_interval + def receiver_timeout(self) -> int: + return self.__radas_receiver_timeout + class CharonConfig(object): """CharonConfig is used to store all configurations for charon diff --git a/charon/pkgs/radas_sign.py b/charon/pkgs/radas_sign.py index 355daa77..0ac558b8 100644 --- a/charon/pkgs/radas_sign.py +++ b/charon/pkgs/radas_sign.py @@ -17,13 +17,14 @@ import logging import json import os -import asyncio import sys +import asyncio import uuid +import time from typing import List, Any, Tuple, Callable, Dict, Optional from charon.config import RadasConfig from charon.pkgs.oras_client import OrasClient -from proton import SSLDomain, Message, Event, Sender +from proton import SSLDomain, Message, Event, Sender, Connection from proton.handlers import MessagingHandler from proton.reactor import Container @@ -40,6 +41,8 @@ class RadasReceiver(MessagingHandler): from the cmd flag,should register UmbListener when the client starts request_id (str): Identifier of the request for the signing result + rconf (RadasConfig): + the configurations for the radas messaging system. sign_result_status (str): Result of the signing(success/failed) sign_result_errors (list): @@ -50,10 +53,13 @@ def __init__(self, sign_result_loc: str, request_id: str, rconf: RadasConfig) -> super().__init__() self.sign_result_loc = sign_result_loc self.request_id = request_id - self.conn = None + self.conn: Optional[Connection] = None + self.message_handled = False self.sign_result_status: Optional[str] = None self.sign_result_errors: List[str] = [] self.rconf = rconf + self.start_time = 0.0 + self.timeout_check_delay = 30.0 self.ssl = SSLDomain(SSLDomain.MODE_CLIENT) self.ssl.set_trusted_ca_db(self.rconf.root_ca()) self.ssl.set_peer_authentication(SSLDomain.VERIFY_PEER) @@ -62,27 +68,58 @@ def __init__(self, sign_result_loc: str, request_id: str, rconf: RadasConfig) -> self.rconf.client_key(), self.rconf.client_key_password() ) + self.log = logging.getLogger("charon.pkgs.radas_sign.RadasReceiver") def on_start(self, event: Event) -> None: - self.conn = event.container.connect( - url=self.rconf.umb_target(), - ssl_domain=self.ssl + umb_target = self.rconf.umb_target() + container = event.container + self.conn = container.connect( + url=umb_target, + ssl_domain=self.ssl, + heartbeat=500 ) - event.container.create_receiver( - self.conn, self.rconf.result_queue(), dynamic=True + receiver = container.create_receiver( + context=self.conn, source=self.rconf.result_queue(), ) - logger.info("Listening on %s, queue: %s", - self.rconf.umb_target(), - self.rconf.result_queue()) + self.log.info("Listening on %s, queue: %s", + umb_target, + receiver.source.address) + self.start_time = time.time() + container.schedule(self.timeout_check_delay, self) + + def on_timer_task(self, event: Event) -> None: + current = time.time() + timeout = self.rconf.receiver_timeout() + idle_time = current - self.start_time + self.log.debug("Checking timeout: passed %s seconds, timeout time %s seconds", + idle_time, timeout) + if idle_time > self.rconf.receiver_timeout(): + self.log.error("The receiver did not receive messages for more than %s seconds," + " and needs to stop receiving and quit.", timeout) + self._close(event) + else: + event.container.schedule(self.timeout_check_delay, self) def on_message(self, event: Event) -> None: + self.log.debug("Got message: %s", event.message.body) self._process_message(event.message.body) + if self.message_handled: + self.log.debug("The signing result is handled.") + self._close(event) - def on_connection_error(self, event: Event) -> None: - logger.error("Received an error event:\n%s", event) + def on_error(self, event: Event) -> None: + self.log.error("Received an error event:\n%s", event.message.body) def on_disconnected(self, event: Event) -> None: - logger.error("Disconnected from AMQP broker.") + self.log.info("Disconnected from AMQP broker: %s", + event.connection.connected_address) + + def _close(self, event: Event) -> None: + if event: + if event.connection: + event.connection.close() + if event.container: + event.container.stop() def _process_message(self, msg: Any) -> None: """ @@ -93,32 +130,37 @@ def _process_message(self, msg: Any) -> None: msg_dict = json.loads(msg) msg_request_id = msg_dict.get("request_id") if msg_request_id != self.request_id: - logger.info( + self.log.info( "Message request_id %s does not match the request_id %s from sender, ignoring", msg_request_id, self.request_id, ) return - logger.info( + self.message_handled = True + self.log.info( "Start to process the sign event message, request_id %s is matched", msg_request_id ) self.sign_result_status = msg_dict.get("signing_status") self.sign_result_errors = msg_dict.get("errors", []) - result_reference_url = msg_dict.get("result_reference") - if not result_reference_url: - logger.warning("Not found result_reference in message,ignore.") - return + if self.sign_result_status == "success": + result_reference_url = msg_dict.get("result_reference") + if not result_reference_url: + self.log.warning("Not found result_reference in message,ignore.") + return - logger.info("Using SIGN RESULT LOC: %s", self.sign_result_loc) - sign_result_parent_dir = os.path.dirname(self.sign_result_loc) - os.makedirs(sign_result_parent_dir, exist_ok=True) + self.log.info("Using SIGN RESULT LOC: %s", self.sign_result_loc) + sign_result_parent_dir = os.path.dirname(self.sign_result_loc) + os.makedirs(sign_result_parent_dir, exist_ok=True) - oras_client = OrasClient() - files = oras_client.pull( - result_reference_url=result_reference_url, sign_result_loc=self.sign_result_loc - ) - logger.info("Number of files pulled: %d, path: %s", len(files), files[0]) + oras_client = OrasClient() + files = oras_client.pull( + result_reference_url=result_reference_url, sign_result_loc=self.sign_result_loc + ) + self.log.info("Number of files pulled: %d, path: %s", len(files), files[0]) + else: + self.log.error("The signing result received with failed status. Errors: %s", + self.sign_result_errors) class RadasSender(MessagingHandler): @@ -141,7 +183,6 @@ def __init__(self, payload: Any, rconf: RadasConfig): self.message: Optional[Message] = None self.container: Optional[Container] = None self.sender: Optional[Sender] = None - self.log = logging.getLogger("charon.pkgs.radas_sign.RadasSender") self.ssl = SSLDomain(SSLDomain.MODE_CLIENT) self.ssl.set_trusted_ca_db(self.rconf.root_ca()) self.ssl.set_peer_authentication(SSLDomain.VERIFY_PEER) @@ -150,6 +191,7 @@ def __init__(self, payload: Any, rconf: RadasConfig): self.rconf.client_key(), self.rconf.client_key_password() ) + self.log = logging.getLogger("charon.pkgs.radas_sign.RadasSender") def on_start(self, event): self.container = event.container @@ -329,7 +371,6 @@ def sign_in_radas(repo_url: str, repo_url, requester, sign_key, result_path) request_id = str(uuid.uuid4()) exclude = ignore_patterns if ignore_patterns else [] - payload = { "request_id": request_id, "requested_by": requester, @@ -347,8 +388,12 @@ def sign_in_radas(repo_url: str, logger.error("Something wrong happened in message sending, see logs") sys.exit(1) - listener = RadasReceiver(result_path, request_id, radas_config) - Container(listener).run() + # request_id = "some-request-id-1" # for test purpose + receiver = RadasReceiver(result_path, request_id, radas_config) + Container(receiver).run() - if listener.conn: - listener.conn.close() + status = receiver.sign_result_status + if status != "success": + logger.error("The signing result is processed with errors: %s", + receiver.sign_result_errors) + sys.exit(1) diff --git a/tests/test_radas_sign_receiver.py b/tests/test_radas_sign_receiver.py new file mode 100644 index 00000000..c75a8e64 --- /dev/null +++ b/tests/test_radas_sign_receiver.py @@ -0,0 +1,130 @@ +from unittest import mock +import unittest +import tempfile +import time +import json +from charon.pkgs.radas_sign import RadasReceiver + + +class RadasSignReceiverTest(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + + def tearDown(self) -> None: + super().tearDown() + + def reset_receiver(self, r_receiver: RadasReceiver) -> None: + r_receiver.message_handled = False + r_receiver.sign_result_errors = [] + r_receiver.sign_result_status = None + + def test_radas_receiver(self): + # Mock configuration + mock_radas_config = mock.MagicMock() + mock_radas_config.validate.return_value = True + mock_radas_config.client_ca.return_value = "test-client-ca" + mock_radas_config.client_key.return_value = "test-client-key" + mock_radas_config.client_key_password.return_value = "test-client-key-pass" + mock_radas_config.root_ca.return_value = "test-root-ca" + mock_radas_config.receiver_timeout.return_value = 60 + + # Mock Container run to avoid real AMQP connection + with mock.patch( + "charon.pkgs.radas_sign.Container") as mock_container, \ + mock.patch("charon.pkgs.radas_sign.SSLDomain") as ssl_domain, \ + mock.patch("charon.pkgs.radas_sign.OrasClient") as oras_client, \ + mock.patch("charon.pkgs.radas_sign.Event") as event: + test_result_path = tempfile.mkdtemp() + test_request_id = "test-request-id" + r_receiver = RadasReceiver(test_result_path, test_request_id, mock_radas_config) + self.assertEqual(ssl_domain.call_count, 1) + self.assertEqual(r_receiver.sign_result_loc, test_result_path) + self.assertEqual(r_receiver.request_id, test_request_id) + + # prepare mock + mock_receiver = mock.MagicMock() + mock_conn = mock.MagicMock() + mock_container.connect.return_value = mock_conn + mock_container.create_receiver.return_value = mock_receiver + event.container = mock_container + event.message = mock.MagicMock() + event.connection = mock.MagicMock() + + # test on_start + r_receiver.on_start(event) + self.assertEqual(mock_container.connect.call_count, 1) + self.assertEqual(mock_container.create_receiver.call_count, 1) + self.assertTrue(r_receiver.start_time > 0.0) + self.assertTrue(r_receiver.start_time < time.time()) + self.assertEqual(mock_container.schedule.call_count, 1) + + # test on_message: unmatched case + test_ummatch_result = { + "request_id": "test-request-id-no-match", + "file_reference": "quay.io/example/test-repo", + "result_reference": "quay.io/example-sign/sign-repo", + "sig_keyname": "testkey", + "signing_status": "success", + "errors": [] + } + event.message.body = json.dumps(test_ummatch_result) + r_receiver.on_message(event) + self.assertEqual(event.connection.close.call_count, 0) + self.assertEqual(mock_container.stop.call_count, 0) + self.assertFalse(r_receiver.message_handled) + self.assertIsNone(r_receiver.sign_result_status) + self.assertEqual(r_receiver.sign_result_errors, []) + self.assertEqual(oras_client.call_count, 0) + + # test on_message: matched case with failed status + self.reset_receiver(r_receiver) + test_failed_result = { + "request_id": "test-request-id", + "file_reference": "quay.io/example/test-repo", + "result_reference": "quay.io/example-sign/sign-repo", + "sig_keyname": "testkey", + "signing_status": "failed", + "errors": ["error1", "error2"] + } + event.message.body = json.dumps(test_failed_result) + r_receiver.on_message(event) + self.assertEqual(event.connection.close.call_count, 1) + self.assertEqual(mock_container.stop.call_count, 1) + self.assertTrue(r_receiver.message_handled) + self.assertEqual(r_receiver.sign_result_status, "failed") + self.assertEqual(r_receiver.sign_result_errors, ["error1", "error2"]) + self.assertEqual(oras_client.call_count, 0) + + # test on_message: matched case with success status + self.reset_receiver(r_receiver) + test_success_result = { + "request_id": "test-request-id", + "file_reference": "quay.io/example/test-repo", + "result_reference": "quay.io/example-sign/sign-repo", + "sig_keyname": "testkey", + "signing_status": "success", + "errors": [] + } + event.message.body = json.dumps(test_success_result) + r_receiver.on_message(event) + self.assertEqual(event.connection.close.call_count, 2) + self.assertEqual(mock_container.stop.call_count, 2) + self.assertTrue(r_receiver.message_handled) + self.assertEqual(r_receiver.sign_result_status, "success") + self.assertEqual(r_receiver.sign_result_errors, []) + self.assertEqual(oras_client.call_count, 1) + oras_client_call = oras_client.return_value + self.assertEqual(oras_client_call.pull.call_count, 1) + + # test on_timer_task: not timeout + r_receiver.on_timer_task(event) + self.assertEqual(event.connection.close.call_count, 2) + self.assertEqual(mock_container.stop.call_count, 2) + self.assertEqual(mock_container.schedule.call_count, 2) + + # test on_timer_task: timeout + mock_radas_config.receiver_timeout.return_value = 0 + r_receiver.on_timer_task(event) + self.assertEqual(event.connection.close.call_count, 3) + self.assertEqual(mock_container.stop.call_count, 3) + self.assertEqual(mock_container.schedule.call_count, 2) diff --git a/tests/test_radas_sign_sender.py b/tests/test_radas_sign_sender.py index 1e75b8fe..602d7de6 100644 --- a/tests/test_radas_sign_sender.py +++ b/tests/test_radas_sign_sender.py @@ -4,7 +4,7 @@ from charon.pkgs.radas_sign import RadasSender -class RadasSignHandlerTest(unittest.TestCase): +class RadasSignSenderTest(unittest.TestCase): def setUp(self) -> None: super().setUp()