diff --git a/source/app/configuration.py b/source/app/configuration.py index 6d84a7d81..c0d93c597 100644 --- a/source/app/configuration.py +++ b/source/app/configuration.py @@ -18,6 +18,7 @@ import configparser from app.logger import logger +import json import logging import os import ssl @@ -227,6 +228,144 @@ class CeleryConfig: worker_pool_restarts = True broker_connection_retry_on_startup = True + security_key = '' + security_certificate = '' + security_cert_store = '' + security_digest = 'sha256' + + +def _parse_bool(value): + if isinstance(value, bool): + return value + return value.lower() in ('true', '1', 'yes', 'on') + +def _parse_float(value): + return float(value) + +_celery_settings = [ + ("accept_content", json.loads), + ("enable_utc", _parse_bool), + ("imports", json.loads), + ("include", json.loads), + ("timezone", str), + ("beat_max_loop_interval", int), + ("beat_schedule", json.loads), + ("beat_scheduler", str), + ("beat_schedule_filename", str), + ("beat_sync_every", int), + ("broker_url", str), + ("broker_transport", str), + ("broker_transport_options", json.loads), + ("broker_connection_timeout", int), + ("broker_connection_retry", _parse_bool), + ("broker_connection_max_retries", int), + ("broker_failover_strategy", str), + ("broker_heartbeat", _parse_float), + ("broker_login_method", str), + ("broker_pool_limit", int), + ("broker_use_ssl", _parse_bool), + ("cache_backend", str), + ("cache_backend_options", json.loads), + ("cassandra_table", str), + ("cassandra_entry_ttl", int), + ("cassandra_keyspace", str), + ("cassandra_port", int), + ("cassandra_read_consistency", str), + ("cassandra_servers", json.loads), + ("cassandra_write_consistency", str), + ("cassandra_options", json.loads), + ("s3_access_key_id", str), + ("s3_secret_access_key", str), + ("s3_bucket", str), + ("s3_base_path", str), + ("s3_endpoint_url", str), + ("s3_region", str), + ("couchbase_backend_settings", json.loads), + ("arangodb_backend_settings", json.loads), + ("mongodb_backend_settings", json.loads), + ("event_queue_expires", _parse_float), + ("event_queue_ttl", _parse_float), + ("event_queue_prefix", str), + ("event_serializer", str), + ("redis_db", str), + ("redis_host", str), + ("redis_max_connections", int), + ("redis_username", str), + ("redis_password", str), + ("redis_port", int), + ("redis_backend_use_ssl", json.loads), + ("result_backend", str), + ("result_cache_max", int), + ("result_compression", str), + ("result_exchange", str), + ("result_exchange_type", str), + ("result_expires", int), + ("result_persistent", _parse_bool), + ("result_serializer", str), + ("database_engine_options", json.loads), + ("database_short_lived_sessions", _parse_bool), + ("database_db_names", json.loads), + ("security_certificate", str), + ("security_cert_store", str), + ("security_key", str), + ("task_acks_late", _parse_bool), + ("task_acks_on_failure_or_timeout", _parse_bool), + ("task_always_eager", _parse_bool), + ("task_annotations", json.loads), + ("task_compression", str), + ("task_create_missing_queues", _parse_bool), + ("task_default_delivery_mode", str), + ("task_default_exchange", str), + ("task_default_exchange_type", str), + ("task_default_queue", str), + ("task_default_rate_limit", int), + ("task_default_routing_key", str), + ("task_eager_propagates", _parse_bool), + ("task_ignore_result", _parse_bool), + ("task_publish_retry", _parse_bool), + ("task_publish_retry_policy", json.loads), + ("task_queues", json.loads), + ("task_routes", json.loads), + ("task_send_sent_event", _parse_bool), + ("task_serializer", str), + ("task_soft_time_limit", int), + ("task_track_started", _parse_bool), + ("task_reject_on_worker_lost", _parse_bool), + ("task_time_limit", int), + ("worker_agent", str), + ("worker_autoscaler", str), + ("worker_concurrency", int), + ("worker_consumer", str), + ("worker_direct", _parse_bool), + ("worker_disable_rate_limits", _parse_bool), + ("worker_enable_remote_control", _parse_bool), + ("worker_log_color", _parse_bool), + ("worker_log_format", str), + ("worker_lost_wait", _parse_float), + ("worker_max_tasks_per_child", int), + ("worker_pool", str), + ("worker_pool_putlocks", _parse_bool), + ("worker_pool_restarts", _parse_bool), + ("worker_prefetch_multiplier", int), + ("worker_redirect_stdouts", _parse_bool), + ("worker_redirect_stdouts_level", str), + ("worker_send_task_events", _parse_bool), + ("worker_state_db", str), + ("worker_task_log_format", str), + ("worker_timer", str), + ("worker_timer_precision", _parse_float), +] + +for _setting, _parse in _celery_settings: + _env_var = f"CELERY__{_setting}" + if _env_var in os.environ: + _value = os.environ[_env_var] + try: + _parsed_value = _parse(_value) + setattr(CeleryConfig, _setting, _parsed_value) + except (ValueError, SyntaxError) as e: + logger.warning(f"Failed to parse {_env_var}: {e}") + class Config: # Handled by bumpversion diff --git a/source/app/iris_engine/tasker/celery.py b/source/app/iris_engine/tasker/celery.py index 3eb377cb4..bc353f9d8 100644 --- a/source/app/iris_engine/tasker/celery.py +++ b/source/app/iris_engine/tasker/celery.py @@ -17,15 +17,58 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. from celery import Celery +from celery.security import setup_security +from kombu.serialization import register from app.configuration import CeleryConfig +def _patch_celery_cert_datetime(): + import datetime + from celery.security.certificate import Certificate + + _original_has_expired = Certificate.has_expired + + def _patched_has_expired(self): + try: + return _original_has_expired(self) + except TypeError: + not_valid_after = self._cert.not_valid_after_utc + return datetime.datetime.now(datetime.timezone.utc) >= not_valid_after + + Certificate.has_expired = _patched_has_expired + + +def _register_auth_serializer(): + import json + + def _encode_auth(data): + return json.dumps(data), 'application/auth' + + def _decode_auth(data): + return json.loads(data) + + register('auth', _encode_auth, _decode_auth, content_type='application/auth') + def make_celery(name): - return Celery( + _register_auth_serializer() + + celery_app = Celery( name, config_source=CeleryConfig ) + if CeleryConfig.security_key and CeleryConfig.security_certificate: + _patch_celery_cert_datetime() + setup_security( + allowed_serializers=['auth'], + key=CeleryConfig.security_key, + cert=CeleryConfig.security_certificate, + store=CeleryConfig.security_cert_store, + digest=CeleryConfig.security_digest + ) + + return celery_app + def set_celery_flask_context(celery: Celery, app): class ContextTask(celery.Task):