Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 44 additions & 47 deletions kombu/transport/SQS/SNS.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from kombu.log import get_logger

from .exceptions import (UnableToSubscribeQueueToTopicException,
UnableToUnsubscribeQueueFromTopicException,
UndefinedExchangeException)

# pragma: no branch
Expand Down Expand Up @@ -48,31 +49,17 @@ def __init__(self, channel: Channel):
def initialise_exchange(self, exchange_name: str) -> None:
"""Initialise SNS topic for a fanout exchange.

This method will create the SNS topic if it doesn't exist, and check for any SNS topic subscriptions
that no longer exist.
This method will create the SNS topic if it doesn't exist, and
check for any SNS topic subscriptions that no longer exist.
If there are any SNS topic subscriptions that no longer exist,
then they will be removed.

:param exchange_name: The name of the exchange.
:returns: None
"""
with self._lock:
# If topic has already been initialised, then do nothing
if self._topic_arn_cache.get(exchange_name):
return None

# Clear any old subscriptions
self.subscriptions.cleanup(exchange_name)

# If predefined_exchanges are set, then do not try to create an SNS topic
if self.channel.predefined_exchanges:
logger.debug(
"'predefined_exchanges' has been specified, so SNS topics will"
" not be created."
)
return

# Create the topic and cache the ARN
self._topic_arn_cache[exchange_name] = self._create_sns_topic(exchange_name)
return None
self.subscriptions.cleanup(exchange_name)
self._get_or_create_topic(exchange_name)
return
Comment thread
auvipy marked this conversation as resolved.

def publish(
self,
Expand All @@ -90,7 +77,7 @@ def publish(
:return: None
"""
# Get topic ARN for the given exchange
topic_arn = self._get_topic_arn(exchange_name)
topic_arn = self._get_or_create_topic(exchange_name)

# Build request args for boto
request_args: dict[str, str | dict] = {
Expand All @@ -110,46 +97,50 @@ def publish(
f"Unable to send message to topic '{topic_arn}': status code was {status_code}"
)

def _get_topic_arn(self, exchange_name: str) -> str:
"""Get the SNS topic ARN.
def _get_or_create_topic(self, exchange_name: str) -> str:
"""Create the SNS topic if not found, otherwise get the topic ARN.

This method will get the SNS topic ARN for the given exchange from the cache.
If the topic ARN is not in the cache, it will create the topic and add the ARN
to the cache before returning it.

As specified here, if the topic already exists, AWS will return the existing
topic ARN without creating a new topic:
https://docs.aws.amazon.com/boto3/latest/reference/services/sns/client/create_topic.html

If the topic ARN is not in the cache, then create it
:param exchange_name: The exchange to create the SNS topic for
:return: The SNS topic ARN
:raises UndefinedExchangeException: If predefined_exchanges are used and the
topic is not defined in the predefined_exchanges
"""
# If topic ARN is in the cache, then return it
if topic_arn := self._topic_arn_cache.get(exchange_name):
return topic_arn

# If predefined-exchanges are used, then do not create a new topic and raise an exception
if self.channel.predefined_exchanges:
return self._handle_getting_topic_arn_for_predefined_exchanges(exchange_name)

# If predefined_exchanges are not used, then create a new
# SNS topic/retrieve the ARN from AWS SNS and cache it
with self._lock:
# Re-check the cache after acquiring the lock to avoid redundant topic creation
if topic_arn := self._topic_arn_cache.get(exchange_name):
return topic_arn

arn = self._create_sns_topic(exchange_name)
self._topic_arn_cache[exchange_name] = arn
return arn
# If predefined-exchanges are used, then do not create a new topic and raise an exception
if self.channel.predefined_exchanges:
return self._get_topic_arn_for_predefined_exchange(exchange_name)

# If predefined_exchanges are not used, then create a new
# SNS topic/retrieve the ARN from AWS SNS and cache it
self._topic_arn_cache[exchange_name] = self._create_sns_topic(exchange_name)
return self._topic_arn_cache[exchange_name]

def _handle_getting_topic_arn_for_predefined_exchanges(self, exchange_name: str) -> str:
def _get_topic_arn_for_predefined_exchange(self, exchange_name: str) -> str:
"""Handles getting the topic ARN for predefined exchanges.

:param exchange_name: The exchange name to get the topic ARN for
:return: The SNS topic ARN for the exchange
:raises UndefinedExchangeException: If the exchange is not defined in the predefined_exchanges
"""
with self._lock:
if topic_arn := self._topic_arn_cache.get(exchange_name):
return topic_arn

if pre_defined_exchange_arn := self.channel.predefined_exchanges.get(exchange_name, {}).get("arn"):
self._topic_arn_cache[exchange_name] = pre_defined_exchange_arn
return pre_defined_exchange_arn
if pre_defined_exchange_arn := self.channel.predefined_exchanges.get(exchange_name, {}).get("arn"):
self._topic_arn_cache[exchange_name] = pre_defined_exchange_arn
return pre_defined_exchange_arn

# If pre-defined exchanges do not have the exchange, then raise an exception
raise UndefinedExchangeException(
Expand Down Expand Up @@ -366,7 +357,7 @@ def _handle_create_new_subscription_for_predefined_exchanges(

# Get ARNs for queue and topic
queue_arn = self._get_queue_arn(queue_name)
topic_arn = self.sns._get_topic_arn(exchange_name)
topic_arn = self.sns._get_or_create_topic(exchange_name)

# Subscribe the SQS queue to the SNS topic
subscription_arn = self._subscribe_queue_to_sns_topic(
Expand All @@ -387,7 +378,7 @@ def unsubscribe_queue(self, queue_name: str, exchange_name: str) -> None:
"""Unsubscribes a queue from an AWS SNS topic.

:param queue_name: The queue to unsubscribe
:param exchange_name: The exchange to unsubscribe from the queue, if not provided
:param exchange_name: The exchange to unsubscribe from the queue
:return: None
"""
cache_key = f"{exchange_name}:{queue_name}"
Expand All @@ -401,6 +392,11 @@ def unsubscribe_queue(self, queue_name: str, exchange_name: str) -> None:
logger.info(
f"Unsubscribed subscription '{subscription_arn}' for SQS queue '{queue_name}'"
)
except Exception as e:
Comment thread
rlaunch marked this conversation as resolved.
logger.error(
f"Failed to unsubscribe queue '{queue_name}' from SNS topic '{exchange_name}': {e}"
)
raise
finally:
# Remove the cached subscription ARN so future subscribe calls don't use a stale value
self._subscription_arn_cache.pop(cache_key, None)
Expand All @@ -426,7 +422,7 @@ def cleanup(self, exchange_name: str) -> None:
)

# Get subscriptions to check
topic_arn = self.sns._get_topic_arn(exchange_name)
topic_arn = self.sns._get_or_create_topic(exchange_name)

# Iterate through the subscriptions and remove any that are not associated with SQS queues
for subscription_arn in self._get_invalid_sns_subscriptions(topic_arn):
Expand Down Expand Up @@ -612,11 +608,12 @@ def _unsubscribe_sns_subscription(self, subscription_arn: str) -> None:

:param subscription_arn: The ARN of the subscription to unsubscribe
:return: None
:raises UnableToUnsubscribeQueueFromTopicException: If the SNS unsubscribe API call fails
"""
response = self.sns.get_client().unsubscribe(SubscriptionArn=subscription_arn)
if (status_code := response["ResponseMetadata"]["HTTPStatusCode"]) != 200:
logger.error(
f"Unable to remove subscription '{subscription_arn}': status code was {status_code}"
raise UnableToUnsubscribeQueueFromTopicException(
f"SNS unsubscribe API returned status code '{status_code}'"
)

def _get_invalid_sns_subscriptions(self, sns_topic_arn: str) -> list[str]:
Expand Down
4 changes: 4 additions & 0 deletions kombu/transport/SQS/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@ class DoesNotExistQueueException(KombuError):

class UnableToSubscribeQueueToTopicException(KombuError):
"""Raised when unable to subscribe a queue to an SNS topic."""


class UnableToUnsubscribeQueueFromTopicException(KombuError):
"""Raised when unable to unsubscribe a queue from an SNS topic."""
Loading
Loading