diff --git a/app/ioc.py b/app/ioc.py index 71e76603c..b1d24f7a5 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -122,6 +122,7 @@ from ldap_protocol.policies.password.settings import PasswordValidatorSettings from ldap_protocol.policies.password.use_cases import PasswordBanWordUseCases, UserPasswordHistoryUseCases from ldap_protocol.rid_manager import ( + ObjectSidCacheRedisClient, ObjectSIDGateway, ObjectSIDUseCase, RIDManagerGateway, @@ -131,6 +132,7 @@ RIDSetGateway, RIDSetUseCase, ) +from ldap_protocol.rid_manager.objectsid_allowed_object_classes_cache import ObjectSidAllowedObjectClassesCache from ldap_protocol.rid_manager.types import HostMachineShortName from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO @@ -315,6 +317,14 @@ async def get_redis_for_sessions(self, settings: Settings) -> AsyncIterator[Sess yield SessionStorageClient(client) await client.aclose() + @provide(scope=Scope.APP) + def get_objectsid_cache_redis(self, client: SessionStorageClient) -> ObjectSidCacheRedisClient: + """Typed redis client for objectSid-related caches. + + Important: this does NOT create a new connection, just re-types the existing client. + """ + return ObjectSidCacheRedisClient(client) + @provide(scope=Scope.APP) def get_session_storage(self, client: SessionStorageClient, settings: Settings) -> SessionStorage: """Get session storage.""" @@ -464,6 +474,7 @@ def get_object_class_use_case_legacy(self, session: AsyncSession) -> ObjectClass rid_manager_use_case = provide(RIDManagerUseCase, scope=Scope.REQUEST) rid_manager_setup_use_case = provide(RIDManagerSetupUseCase, scope=Scope.REQUEST) object_sid_gateway = provide(ObjectSIDGateway, scope=Scope.REQUEST) + objectsid_allowed_object_classes_cache = provide(ObjectSidAllowedObjectClassesCache, scope=Scope.APP) object_sid_use_case = provide(ObjectSIDUseCase, scope=Scope.REQUEST) rid_set_gateway = provide(RIDSetGateway, scope=Scope.REQUEST) rid_set_use_case = provide(RIDSetUseCase, scope=Scope.REQUEST) diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 9adafa420..840b7cf92 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -168,7 +168,6 @@ async def handle(self, ctx: LDAPAddRequestContext) -> AsyncGenerator[AddResponse ctx.session.add(new_dir) await ctx.session.flush() - await ctx.object_sid_use_case.ensure_objectsid(directory_id=new_dir.id) await ctx.session.flush() except IntegrityError: diff --git a/app/ldap_protocol/rid_manager/__init__.py b/app/ldap_protocol/rid_manager/__init__.py index 204bbef53..5d4fac68f 100644 --- a/app/ldap_protocol/rid_manager/__init__.py +++ b/app/ldap_protocol/rid_manager/__init__.py @@ -6,12 +6,14 @@ from .object_sid_gateway import ObjectSIDGateway from .object_sid_use_case import ObjectSIDUseCase +from .objectsid_allowed_object_classes_cache import ObjectSidAllowedObjectClassesCache from .rid_manager_gateway import RIDManagerGateway from .rid_manager_use_case import RIDManagerUseCase from .rid_set_gateway import RIDSetGateway from .rid_set_use_case import RIDSetUseCase from .setup_gateway import RIDManagerSetupGateway from .setup_use_case import RIDManagerSetupUseCase +from .types import ObjectSidCacheRedisClient __all__ = [ "ObjectSIDGateway", @@ -22,4 +24,6 @@ "RIDManagerUseCase", "RIDSetGateway", "RIDSetUseCase", + "ObjectSidAllowedObjectClassesCache", + "ObjectSidCacheRedisClient", ] diff --git a/app/ldap_protocol/rid_manager/object_sid_use_case.py b/app/ldap_protocol/rid_manager/object_sid_use_case.py index 561f219c2..5c6c6fbdb 100644 --- a/app/ldap_protocol/rid_manager/object_sid_use_case.py +++ b/app/ldap_protocol/rid_manager/object_sid_use_case.py @@ -10,9 +10,9 @@ from ldap_protocol.ldap_schema.object_class.object_class_dao import ObjectClassDAO from ldap_protocol.rid_manager.exceptions import RIDManagerObjectSIDNotFoundError from ldap_protocol.rid_manager.object_sid_gateway import ObjectSIDGateway +from ldap_protocol.rid_manager.objectsid_allowed_object_classes_cache import ObjectSidAllowedObjectClassesCache from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase -from ldap_protocol.utils.async_cache import objectsid_allowed_object_classes_cache class ObjectSIDUseCase: @@ -25,6 +25,7 @@ def __init__( session: AsyncSession, rid_manager_use_case: RIDManagerUseCase, object_class_dao: ObjectClassDAO, + objectsid_allowed_object_classes_cache: ObjectSidAllowedObjectClassesCache, ) -> None: """Initialize Object SID use case.""" self._gateway = gateway @@ -32,12 +33,16 @@ def __init__( self._session = session self._rid_manager_use_case = rid_manager_use_case self._object_class_dao = object_class_dao + self._objectsid_allowed_object_classes_cache = objectsid_allowed_object_classes_cache - @objectsid_allowed_object_classes_cache async def get_available_object_classes(self) -> set[str]: """ObjectClasses that allow objectSid (mustContain/mayContain).""" - names = await self._object_class_dao.get_object_class_names_include_attribute_type("objectSid") - return {n.lower() for n in names} + + async def compute() -> set[str]: + names = await self._object_class_dao.get_object_class_names_include_attribute_type("objectSid") + return {n.lower() for n in names} + + return await self._objectsid_allowed_object_classes_cache.get_or_compute(compute) async def is_objectsid_needed(self, object_class_names: set[str]) -> bool: """Check if objectSid is needed for objectClasses.""" diff --git a/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py b/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py new file mode 100644 index 000000000..a49401f24 --- /dev/null +++ b/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py @@ -0,0 +1,64 @@ +"""Redis cache for objectSid-related metadata. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import json +from typing import Awaitable, Callable + +from ldap_protocol.rid_manager.types import ObjectSidCacheRedisClient + + +class ObjectSidAllowedObjectClassesCache: + """Cache for ObjectClass names that allow `objectSid`.""" + + _CACHE_KEY = "ldap:objectsid:allowed_object_classes" + _LOCK_KEY = "lock:ldap:objectsid:allowed_object_classes" + _LOCK_BLOCKING_TIMEOUT_SECONDS = 5 + _LOCK_LEASE_TIMEOUT_SECONDS = 30 + + def __init__(self, redis: ObjectSidCacheRedisClient) -> None: + self._redis = redis + + def _decode(self, raw: bytes | str) -> set[str] | None: + if isinstance(raw, bytes | bytearray): + raw = raw.decode("utf-8", errors="replace") + try: + decoded = json.loads(raw) + except json.JSONDecodeError: + return None + + return {str(v).lower() for v in decoded} + + async def get(self) -> set[str] | None: + raw = await self._redis.get(self._CACHE_KEY) + if not raw: + return None + return self._decode(raw) + + async def store(self, value: set[str]) -> None: + normalized = sorted({v.lower() for v in value}) + await self._redis.set(self._CACHE_KEY, json.dumps(normalized)) + + async def clear(self) -> None: + await self._redis.delete(self._CACHE_KEY, self._LOCK_KEY) + + async def get_or_compute(self, compute: Callable[[], Awaitable[set[str]]]) -> set[str]: + """Read from redis, or compute once under a lock.""" + if cached := await self.get(): + return cached + + lock = self._redis.lock( + name=self._LOCK_KEY, + blocking_timeout=self._LOCK_BLOCKING_TIMEOUT_SECONDS, + timeout=self._LOCK_LEASE_TIMEOUT_SECONDS, + ) + async with lock: + if cached2 := await self.get(): + return cached2 + + value = await compute() + normalized = {v.lower() for v in value} + await self.store(normalized) + return normalized diff --git a/app/ldap_protocol/rid_manager/types.py b/app/ldap_protocol/rid_manager/types.py index ab1790f80..f4bbd0874 100644 --- a/app/ldap_protocol/rid_manager/types.py +++ b/app/ldap_protocol/rid_manager/types.py @@ -6,4 +6,7 @@ from typing import NewType +import redis.asyncio as redis + HostMachineShortName = NewType("HostMachineShortName", str) +ObjectSidCacheRedisClient = NewType("ObjectSidCacheRedisClient", redis.Redis) diff --git a/app/ldap_protocol/utils/async_cache.py b/app/ldap_protocol/utils/async_cache.py index b2e96d256..7e25ee3c4 100644 --- a/app/ldap_protocol/utils/async_cache.py +++ b/app/ldap_protocol/utils/async_cache.py @@ -41,4 +41,3 @@ async def wrapper(*args: tuple, **kwargs: dict) -> T: base_directories_cache = AsyncTTLCache[list[Directory]]() domain_identifier_cache = AsyncTTLCache[str]() rid_set_id_cache = AsyncTTLCache[int]() -objectsid_allowed_object_classes_cache = AsyncTTLCache[set[str]](ttl=60 * 60 * 24) diff --git a/tests/conftest.py b/tests/conftest.py index ab935b057..e88cdc632 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -113,6 +113,7 @@ from ldap_protocol.policies.password.settings import PasswordValidatorSettings from ldap_protocol.policies.password.use_cases import PasswordBanWordUseCases, UserPasswordHistoryUseCases from ldap_protocol.rid_manager import ( + ObjectSidCacheRedisClient, ObjectSIDGateway, ObjectSIDUseCase, RIDManagerGateway, @@ -122,6 +123,7 @@ RIDSetGateway, RIDSetUseCase, ) +from ldap_protocol.rid_manager.objectsid_allowed_object_classes_cache import ObjectSidAllowedObjectClassesCache from ldap_protocol.rid_manager.types import HostMachineShortName from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO @@ -395,6 +397,10 @@ async def get_redis_for_sessions(self, settings: Settings) -> AsyncIterator[Sess with suppress(RuntimeError): await client.aclose() + @provide(scope=Scope.APP) + def get_objectsid_cache_redis(self, client: SessionStorageClient) -> ObjectSidCacheRedisClient: + return ObjectSidCacheRedisClient(client) + @provide(scope=Scope.REQUEST, provides=MasterGatewayProtocol) async def get_master_gateway(self, session: AsyncSession, settings: Settings) -> PGMasterGateway: return PGMasterGateway(session, settings) @@ -585,6 +591,7 @@ def authorization_provider_protocol(self, identity_provider: IdentityProvider) - rid_manager_setup_gateway = provide(RIDManagerSetupGateway, scope=Scope.REQUEST) rid_manager_setup_use_case = provide(RIDManagerSetupUseCase, scope=Scope.REQUEST) object_sid_gateway = provide(ObjectSIDGateway, scope=Scope.REQUEST) + objectsid_allowed_object_classes_cache = provide(ObjectSidAllowedObjectClassesCache, scope=Scope.APP) object_sid_use_case = provide(ObjectSIDUseCase, scope=Scope.REQUEST) rid_set_gateway = provide(RIDSetGateway, scope=Scope.REQUEST) rid_set_use_case = provide(RIDSetUseCase, scope=Scope.REQUEST) @@ -791,8 +798,16 @@ async def setup_session( rid_set_gateway, entity_type_use_case, session, rid_manager_use_case, role_use_case ) object_sid_gateway = ObjectSIDGateway(session) + redis_client = SessionStorageClient(redis.Redis.from_url(str(settings.SESSION_STORAGE_URL))) + + objectsid_allowed_object_classes_cache = ObjectSidAllowedObjectClassesCache(ObjectSidCacheRedisClient(redis_client)) object_sid_use_case = ObjectSIDUseCase( - object_sid_gateway, rid_set_use_case, session, rid_manager_use_case, object_class_dao + object_sid_gateway, + rid_set_use_case, + session, + rid_manager_use_case, + object_class_dao, + objectsid_allowed_object_classes_cache, ) directory_create_use_case = DirectoryCreateUseCase( session=session, @@ -1414,18 +1429,36 @@ async def object_sid_gateway(container: AsyncContainer) -> AsyncIterator[ObjectS yield ObjectSIDGateway(session) +@pytest_asyncio.fixture(scope="function") +async def objectsid_allowed_object_classes_cache( + container: AsyncContainer, settings: Settings +) -> AsyncIterator[ObjectSidAllowedObjectClassesCache]: + """Provide ObjectSidAllowedObjectClassesCache for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + redis_client = ObjectSidCacheRedisClient(redis.Redis.from_url(str(settings.SESSION_STORAGE_URL))) + yield ObjectSidAllowedObjectClassesCache(ObjectSidCacheRedisClient(redis_client)) + + @pytest_asyncio.fixture(scope="function") async def object_sid_use_case( container: AsyncContainer, rid_manager_use_case: RIDManagerUseCase, rid_set_use_case: RIDSetUseCase, object_sid_gateway: ObjectSIDGateway, + objectsid_allowed_object_classes_cache: ObjectSidAllowedObjectClassesCache, ) -> AsyncIterator[ObjectSIDUseCase]: """Provide RIDManagerUseCase for tests that request it explicitly.""" async with container(scope=Scope.SESSION) as container: session = await container.get(AsyncSession) object_class_dao = ObjectClassDAO(session) - yield ObjectSIDUseCase(object_sid_gateway, rid_set_use_case, session, rid_manager_use_case, object_class_dao) + yield ObjectSIDUseCase( + object_sid_gateway, + rid_set_use_case, + session, + rid_manager_use_case, + object_class_dao, + objectsid_allowed_object_classes_cache, + ) def pytest_configure(config: pytest.Config) -> None: