From c70a8a23eced302943e223002af9d813c23e33c7 Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Fri, 16 Jan 2026 12:01:30 +0300 Subject: [PATCH] added cache func and cached get_base_directories --- app/ldap_protocol/utils/helpers.py | 76 +++++++++++++++++++++++++++++- app/ldap_protocol/utils/queries.py | 2 + 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index 296ff5978..73b061050 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -130,6 +130,7 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +import asyncio import functools import hashlib import random @@ -138,19 +139,23 @@ import time from calendar import timegm from datetime import datetime +from functools import wraps from hashlib import blake2b from operator import attrgetter -from typing import Callable +from typing import Any, Callable, Iterable from zoneinfo import ZoneInfo from loguru import logger from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.compiler import compiles +from sqlalchemy.orm.attributes import instance_state from sqlalchemy.sql.compiler import DDLCompiler from sqlalchemy.sql.expression import ClauseElement, Executable, Visitable from entities import Directory +DEFAULT_CACHE_TIME = 5 * 60 # 5 minutes + def validate_entry(entry: str) -> bool: """Validate entry str. @@ -402,3 +407,72 @@ async def explain_query( for row in await session.execute(explain(query, analyze=True)) ), ) + + +def has_expired_sqla_objs(obj: Any, max_depth: int = 3) -> bool: + def _check(value: Any) -> bool: + try: + state = instance_state(value) + return bool(state.expired_attributes) + except AttributeError: + return False + + def _walk(value: Any, depth: int = 0) -> bool: + if depth > max_depth: + return False + + if _check(value): + return True + + if isinstance(value, str | bytes | bytearray): + return False + + if isinstance(value, dict): + return any(_walk(v, depth + 1) for v in value.values()) + + if isinstance(value, Iterable): + return any(_walk(v, depth + 1) for v in value) + + return False + + return _walk(obj) + + +def async_lru_cache(ttl: int | None = DEFAULT_CACHE_TIME) -> Callable: + cache: dict = {} + locks: dict = {} + + def _is_value_expired( + value: Any, + now: float, + expires_at: float | None, + ) -> bool: + return bool( + expires_at and expires_at < now or has_expired_sqla_objs(value), + ) + + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args: tuple, **kwargs: dict) -> Any: + key = (args, tuple(sorted(kwargs.items()))) + now = time.monotonic() + if key not in locks: + locks[key] = asyncio.Lock() + + async with locks[key]: + if key in cache: + value, expires_at = cache[key] + if not _is_value_expired(value, now, expires_at): + return value + else: + del cache[key] + + result = await func(*args, **kwargs) + expires_at = now + ttl if ttl else None + cache[key] = (result, expires_at) + del locks[key] + return result + + return wrapper + + return decorator diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index a1f9243de..4e40514aa 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -27,6 +27,7 @@ from .const import EMAIL_RE, GRANT_DN_STRING from .helpers import ( + async_lru_cache, create_integer_hash, create_object_sid, dn_is_base_directory, @@ -35,6 +36,7 @@ ) +@async_lru_cache() async def get_base_directories(session: AsyncSession) -> list[Directory]: """Get base domain directories.""" result = await session.execute(