Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions source/app/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import configparser
from app.logger import logger
import json
import logging
import os
import ssl
Expand Down Expand Up @@ -227,6 +228,144 @@ class CeleryConfig:
worker_pool_restarts = True
broker_connection_retry_on_startup = True

security_key = ''
security_certificate = ''
security_cert_store = ''
security_digest = 'sha256'


def _parse_bool(value):
if isinstance(value, bool):
return value
return value.lower() in ('true', '1', 'yes', 'on')

def _parse_float(value):
return float(value)

_celery_settings = [
("accept_content", json.loads),
("enable_utc", _parse_bool),
("imports", json.loads),
("include", json.loads),
("timezone", str),
("beat_max_loop_interval", int),
("beat_schedule", json.loads),
("beat_scheduler", str),
("beat_schedule_filename", str),
("beat_sync_every", int),
("broker_url", str),
("broker_transport", str),
("broker_transport_options", json.loads),
("broker_connection_timeout", int),
("broker_connection_retry", _parse_bool),
("broker_connection_max_retries", int),
("broker_failover_strategy", str),
("broker_heartbeat", _parse_float),
("broker_login_method", str),
("broker_pool_limit", int),
("broker_use_ssl", _parse_bool),
("cache_backend", str),
("cache_backend_options", json.loads),
("cassandra_table", str),
("cassandra_entry_ttl", int),
("cassandra_keyspace", str),
("cassandra_port", int),
("cassandra_read_consistency", str),
("cassandra_servers", json.loads),
("cassandra_write_consistency", str),
("cassandra_options", json.loads),
("s3_access_key_id", str),
("s3_secret_access_key", str),
("s3_bucket", str),
("s3_base_path", str),
("s3_endpoint_url", str),
("s3_region", str),
("couchbase_backend_settings", json.loads),
("arangodb_backend_settings", json.loads),
("mongodb_backend_settings", json.loads),
("event_queue_expires", _parse_float),
("event_queue_ttl", _parse_float),
("event_queue_prefix", str),
("event_serializer", str),
("redis_db", str),
("redis_host", str),
("redis_max_connections", int),
("redis_username", str),
("redis_password", str),
("redis_port", int),
("redis_backend_use_ssl", json.loads),
("result_backend", str),
("result_cache_max", int),
("result_compression", str),
("result_exchange", str),
("result_exchange_type", str),
("result_expires", int),
("result_persistent", _parse_bool),
("result_serializer", str),
("database_engine_options", json.loads),
("database_short_lived_sessions", _parse_bool),
("database_db_names", json.loads),
("security_certificate", str),
("security_cert_store", str),
("security_key", str),
("task_acks_late", _parse_bool),
("task_acks_on_failure_or_timeout", _parse_bool),
("task_always_eager", _parse_bool),
("task_annotations", json.loads),
("task_compression", str),
("task_create_missing_queues", _parse_bool),
("task_default_delivery_mode", str),
("task_default_exchange", str),
("task_default_exchange_type", str),
("task_default_queue", str),
("task_default_rate_limit", int),
("task_default_routing_key", str),
("task_eager_propagates", _parse_bool),
("task_ignore_result", _parse_bool),
("task_publish_retry", _parse_bool),
("task_publish_retry_policy", json.loads),
("task_queues", json.loads),
("task_routes", json.loads),
("task_send_sent_event", _parse_bool),
("task_serializer", str),
("task_soft_time_limit", int),
("task_track_started", _parse_bool),
("task_reject_on_worker_lost", _parse_bool),
("task_time_limit", int),
("worker_agent", str),
("worker_autoscaler", str),
("worker_concurrency", int),
("worker_consumer", str),
("worker_direct", _parse_bool),
("worker_disable_rate_limits", _parse_bool),
("worker_enable_remote_control", _parse_bool),
("worker_log_color", _parse_bool),
("worker_log_format", str),
("worker_lost_wait", _parse_float),
("worker_max_tasks_per_child", int),
("worker_pool", str),
("worker_pool_putlocks", _parse_bool),
("worker_pool_restarts", _parse_bool),
("worker_prefetch_multiplier", int),
("worker_redirect_stdouts", _parse_bool),
("worker_redirect_stdouts_level", str),
("worker_send_task_events", _parse_bool),
("worker_state_db", str),
("worker_task_log_format", str),
("worker_timer", str),
("worker_timer_precision", _parse_float),
]

for _setting, _parse in _celery_settings:
_env_var = f"CELERY__{_setting}"
if _env_var in os.environ:
_value = os.environ[_env_var]
try:
_parsed_value = _parse(_value)
setattr(CeleryConfig, _setting, _parsed_value)
except (ValueError, SyntaxError) as e:
logger.warning(f"Failed to parse {_env_var}: {e}")


class Config:
# Handled by bumpversion
Expand Down
45 changes: 44 additions & 1 deletion source/app/iris_engine/tasker/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,58 @@
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

from celery import Celery
from celery.security import setup_security
from kombu.serialization import register
from app.configuration import CeleryConfig

def _patch_celery_cert_datetime():
import datetime
from celery.security.certificate import Certificate

_original_has_expired = Certificate.has_expired

def _patched_has_expired(self):
try:
return _original_has_expired(self)
except TypeError:
not_valid_after = self._cert.not_valid_after_utc
return datetime.datetime.now(datetime.timezone.utc) >= not_valid_after

Certificate.has_expired = _patched_has_expired


def _register_auth_serializer():
import json

def _encode_auth(data):
return json.dumps(data), 'application/auth'

def _decode_auth(data):
return json.loads(data)

register('auth', _encode_auth, _decode_auth, content_type='application/auth')


def make_celery(name):
return Celery(
_register_auth_serializer()

celery_app = Celery(
name,
config_source=CeleryConfig
)

if CeleryConfig.security_key and CeleryConfig.security_certificate:
_patch_celery_cert_datetime()
setup_security(
allowed_serializers=['auth'],
key=CeleryConfig.security_key,
cert=CeleryConfig.security_certificate,
store=CeleryConfig.security_cert_store,
digest=CeleryConfig.security_digest
)

return celery_app


def set_celery_flask_context(celery: Celery, app):
class ContextTask(celery.Task):
Expand Down
Loading