From a5b653048b95d9132ffafd8da6daaecb5f50db86 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 24 Apr 2026 12:20:20 +0300 Subject: [PATCH 1/5] Enhance: integrate ObjectSidCacheRedisClient and ObjectSidAllowedObjectClassesCache into the RID manager. Update AddRequest to conditionally ensure objectSid based on object class names. Remove unused async cache for allowed object classes. --- app/ioc.py | 11 +++ app/ldap_protocol/ldap_requests/add.py | 6 +- app/ldap_protocol/rid_manager/__init__.py | 4 + .../rid_manager/object_sid_use_case.py | 12 ++- .../objectsid_allowed_object_classes_cache.py | 75 +++++++++++++++++++ app/ldap_protocol/rid_manager/types.py | 3 + app/ldap_protocol/utils/async_cache.py | 1 - tests/conftest.py | 32 +++++++- 8 files changed, 134 insertions(+), 10 deletions(-) create mode 100644 app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py diff --git a/app/ioc.py b/app/ioc.py index 71e76603c..b1d24f7a5 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -122,6 +122,7 @@ from ldap_protocol.policies.password.settings import PasswordValidatorSettings from ldap_protocol.policies.password.use_cases import PasswordBanWordUseCases, UserPasswordHistoryUseCases from ldap_protocol.rid_manager import ( + ObjectSidCacheRedisClient, ObjectSIDGateway, ObjectSIDUseCase, RIDManagerGateway, @@ -131,6 +132,7 @@ RIDSetGateway, RIDSetUseCase, ) +from ldap_protocol.rid_manager.objectsid_allowed_object_classes_cache import ObjectSidAllowedObjectClassesCache from ldap_protocol.rid_manager.types import HostMachineShortName from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO @@ -315,6 +317,14 @@ async def get_redis_for_sessions(self, settings: Settings) -> AsyncIterator[Sess yield SessionStorageClient(client) await client.aclose() + @provide(scope=Scope.APP) + def get_objectsid_cache_redis(self, client: SessionStorageClient) -> ObjectSidCacheRedisClient: + """Typed redis client for objectSid-related caches. + + Important: this does NOT create a new connection, just re-types the existing client. + """ + return ObjectSidCacheRedisClient(client) + @provide(scope=Scope.APP) def get_session_storage(self, client: SessionStorageClient, settings: Settings) -> SessionStorage: """Get session storage.""" @@ -464,6 +474,7 @@ def get_object_class_use_case_legacy(self, session: AsyncSession) -> ObjectClass rid_manager_use_case = provide(RIDManagerUseCase, scope=Scope.REQUEST) rid_manager_setup_use_case = provide(RIDManagerSetupUseCase, scope=Scope.REQUEST) object_sid_gateway = provide(ObjectSIDGateway, scope=Scope.REQUEST) + objectsid_allowed_object_classes_cache = provide(ObjectSidAllowedObjectClassesCache, scope=Scope.APP) object_sid_use_case = provide(ObjectSIDUseCase, scope=Scope.REQUEST) rid_set_gateway = provide(RIDSetGateway, scope=Scope.REQUEST) rid_set_use_case = provide(RIDSetUseCase, scope=Scope.REQUEST) diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 9adafa420..c42abeae7 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -168,9 +168,9 @@ async def handle(self, ctx: LDAPAddRequestContext) -> AsyncGenerator[AddResponse ctx.session.add(new_dir) await ctx.session.flush() - - await ctx.object_sid_use_case.ensure_objectsid(directory_id=new_dir.id) - await ctx.session.flush() + if await ctx.object_sid_use_case.is_objectsid_needed(self.object_class_names): + await ctx.object_sid_use_case.ensure_objectsid(directory_id=new_dir.id) + await ctx.session.flush() except IntegrityError: await ctx.session.rollback() yield AddResponse(result_code=LDAPCodes.ENTRY_ALREADY_EXISTS) diff --git a/app/ldap_protocol/rid_manager/__init__.py b/app/ldap_protocol/rid_manager/__init__.py index 204bbef53..5d4fac68f 100644 --- a/app/ldap_protocol/rid_manager/__init__.py +++ b/app/ldap_protocol/rid_manager/__init__.py @@ -6,12 +6,14 @@ from .object_sid_gateway import ObjectSIDGateway from .object_sid_use_case import ObjectSIDUseCase +from .objectsid_allowed_object_classes_cache import ObjectSidAllowedObjectClassesCache from .rid_manager_gateway import RIDManagerGateway from .rid_manager_use_case import RIDManagerUseCase from .rid_set_gateway import RIDSetGateway from .rid_set_use_case import RIDSetUseCase from .setup_gateway import RIDManagerSetupGateway from .setup_use_case import RIDManagerSetupUseCase +from .types import ObjectSidCacheRedisClient __all__ = [ "ObjectSIDGateway", @@ -22,4 +24,6 @@ "RIDManagerUseCase", "RIDSetGateway", "RIDSetUseCase", + "ObjectSidAllowedObjectClassesCache", + "ObjectSidCacheRedisClient", ] diff --git a/app/ldap_protocol/rid_manager/object_sid_use_case.py b/app/ldap_protocol/rid_manager/object_sid_use_case.py index 561f219c2..775cdfd40 100644 --- a/app/ldap_protocol/rid_manager/object_sid_use_case.py +++ b/app/ldap_protocol/rid_manager/object_sid_use_case.py @@ -10,9 +10,9 @@ from ldap_protocol.ldap_schema.object_class.object_class_dao import ObjectClassDAO from ldap_protocol.rid_manager.exceptions import RIDManagerObjectSIDNotFoundError from ldap_protocol.rid_manager.object_sid_gateway import ObjectSIDGateway +from ldap_protocol.rid_manager.objectsid_allowed_object_classes_cache import ObjectSidAllowedObjectClassesCache from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase -from ldap_protocol.utils.async_cache import objectsid_allowed_object_classes_cache class ObjectSIDUseCase: @@ -25,6 +25,7 @@ def __init__( session: AsyncSession, rid_manager_use_case: RIDManagerUseCase, object_class_dao: ObjectClassDAO, + objectsid_allowed_object_classes_cache: ObjectSidAllowedObjectClassesCache, ) -> None: """Initialize Object SID use case.""" self._gateway = gateway @@ -32,12 +33,15 @@ def __init__( self._session = session self._rid_manager_use_case = rid_manager_use_case self._object_class_dao = object_class_dao + self._objectsid_allowed_object_classes_cache = objectsid_allowed_object_classes_cache - @objectsid_allowed_object_classes_cache async def get_available_object_classes(self) -> set[str]: """ObjectClasses that allow objectSid (mustContain/mayContain).""" - names = await self._object_class_dao.get_object_class_names_include_attribute_type("objectSid") - return {n.lower() for n in names} + async def compute() -> set[str]: + names = await self._object_class_dao.get_object_class_names_include_attribute_type("objectSid") + return {n.lower() for n in names} + + return await self._objectsid_allowed_object_classes_cache.get_or_compute(compute) async def is_objectsid_needed(self, object_class_names: set[str]) -> bool: """Check if objectSid is needed for objectClasses.""" diff --git a/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py b/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py new file mode 100644 index 000000000..76adf0133 --- /dev/null +++ b/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py @@ -0,0 +1,75 @@ +"""Redis cache for objectSid-related metadata. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import json +from typing import Awaitable, Callable + +from loguru import logger + +from ldap_protocol.rid_manager.types import ObjectSidCacheRedisClient + + +class ObjectSidAllowedObjectClassesCache: + """Cache for ObjectClass names that allow `objectSid`.""" + + _CACHE_KEY = "ldap:objectsid:allowed_object_classes" + _LOCK_KEY = "lock:ldap:objectsid:allowed_object_classes" + _LOCK_BLOCKING_TIMEOUT_SECONDS = 5 + _LOCK_LEASE_TIMEOUT_SECONDS = 30 + + def __init__(self, redis: ObjectSidCacheRedisClient) -> None: + self._redis = redis + + def _decode(self, raw: bytes | str) -> set[str] | None: + try: + if isinstance(raw, bytes | bytearray): + raw = raw.decode("utf-8", errors="replace") + decoded = json.loads(raw) + except (TypeError, ValueError, json.JSONDecodeError): + return None + + if not isinstance(decoded, list): + return None + + return {str(v).lower() for v in decoded} + + async def get(self) -> set[str] | None: + raw = await self._redis.get(self._CACHE_KEY) + if not raw: + return None + return self._decode(raw) + + async def store(self, value: set[str]) -> None: + normalized = sorted({v.lower() for v in value}) + await self._redis.set(self._CACHE_KEY, json.dumps(normalized)) + + async def clear(self) -> None: + await self._redis.delete(self._CACHE_KEY, self._LOCK_KEY) + + async def get_or_compute(self, compute: Callable[[], Awaitable[set[str]]]) -> set[str]: + """Read from redis, or compute once under a lock.""" + if cached := await self.get(): + logger.critical("ObjectSidAllowedObjectClassesCache: read from redis") + logger.critical(cached) + return cached + + lock = self._redis.lock( + name=self._LOCK_KEY, + blocking_timeout=self._LOCK_BLOCKING_TIMEOUT_SECONDS, + timeout=self._LOCK_LEASE_TIMEOUT_SECONDS, + ) + async with lock: + if cached2 := await self.get(): + logger.critical("ObjectSidAllowedObjectClassesCache: read from redis") + logger.critical(cached2) + return cached2 + + value = await compute() + logger.critical("ObjectSidAllowedObjectClassesCache: computed") + logger.critical(value) + normalized = {v.lower() for v in value} + await self.store(normalized) + return normalized diff --git a/app/ldap_protocol/rid_manager/types.py b/app/ldap_protocol/rid_manager/types.py index ab1790f80..f4bbd0874 100644 --- a/app/ldap_protocol/rid_manager/types.py +++ b/app/ldap_protocol/rid_manager/types.py @@ -6,4 +6,7 @@ from typing import NewType +import redis.asyncio as redis + HostMachineShortName = NewType("HostMachineShortName", str) +ObjectSidCacheRedisClient = NewType("ObjectSidCacheRedisClient", redis.Redis) diff --git a/app/ldap_protocol/utils/async_cache.py b/app/ldap_protocol/utils/async_cache.py index b2e96d256..7e25ee3c4 100644 --- a/app/ldap_protocol/utils/async_cache.py +++ b/app/ldap_protocol/utils/async_cache.py @@ -41,4 +41,3 @@ async def wrapper(*args: tuple, **kwargs: dict) -> T: base_directories_cache = AsyncTTLCache[list[Directory]]() domain_identifier_cache = AsyncTTLCache[str]() rid_set_id_cache = AsyncTTLCache[int]() -objectsid_allowed_object_classes_cache = AsyncTTLCache[set[str]](ttl=60 * 60 * 24) diff --git a/tests/conftest.py b/tests/conftest.py index ab935b057..3bc201719 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -113,6 +113,7 @@ from ldap_protocol.policies.password.settings import PasswordValidatorSettings from ldap_protocol.policies.password.use_cases import PasswordBanWordUseCases, UserPasswordHistoryUseCases from ldap_protocol.rid_manager import ( + ObjectSidCacheRedisClient, ObjectSIDGateway, ObjectSIDUseCase, RIDManagerGateway, @@ -122,6 +123,7 @@ RIDSetGateway, RIDSetUseCase, ) +from ldap_protocol.rid_manager.objectsid_allowed_object_classes_cache import ObjectSidAllowedObjectClassesCache from ldap_protocol.rid_manager.types import HostMachineShortName from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO @@ -791,8 +793,16 @@ async def setup_session( rid_set_gateway, entity_type_use_case, session, rid_manager_use_case, role_use_case ) object_sid_gateway = ObjectSIDGateway(session) + redis_client = SessionStorageClient(redis.Redis.from_url(str(settings.SESSION_STORAGE_URL))) + + objectsid_allowed_object_classes_cache = ObjectSidAllowedObjectClassesCache(ObjectSidCacheRedisClient(redis_client)) object_sid_use_case = ObjectSIDUseCase( - object_sid_gateway, rid_set_use_case, session, rid_manager_use_case, object_class_dao + object_sid_gateway, + rid_set_use_case, + session, + rid_manager_use_case, + object_class_dao, + objectsid_allowed_object_classes_cache, ) directory_create_use_case = DirectoryCreateUseCase( session=session, @@ -1414,18 +1424,36 @@ async def object_sid_gateway(container: AsyncContainer) -> AsyncIterator[ObjectS yield ObjectSIDGateway(session) +@pytest_asyncio.fixture(scope="function") +async def objectsid_allowed_object_classes_cache( + container: AsyncContainer, settings: Settings +) -> AsyncIterator[ObjectSidAllowedObjectClassesCache]: + """Provide ObjectSidAllowedObjectClassesCache for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + redis_client = ObjectSidCacheRedisClient(redis.Redis.from_url(str(settings.SESSION_STORAGE_URL))) + yield ObjectSidAllowedObjectClassesCache(ObjectSidCacheRedisClient(redis_client)) + + @pytest_asyncio.fixture(scope="function") async def object_sid_use_case( container: AsyncContainer, rid_manager_use_case: RIDManagerUseCase, rid_set_use_case: RIDSetUseCase, object_sid_gateway: ObjectSIDGateway, + objectsid_allowed_object_classes_cache: ObjectSidAllowedObjectClassesCache, ) -> AsyncIterator[ObjectSIDUseCase]: """Provide RIDManagerUseCase for tests that request it explicitly.""" async with container(scope=Scope.SESSION) as container: session = await container.get(AsyncSession) object_class_dao = ObjectClassDAO(session) - yield ObjectSIDUseCase(object_sid_gateway, rid_set_use_case, session, rid_manager_use_case, object_class_dao) + yield ObjectSIDUseCase( + object_sid_gateway, + rid_set_use_case, + session, + rid_manager_use_case, + object_class_dao, + objectsid_allowed_object_classes_cache, + ) def pytest_configure(config: pytest.Config) -> None: From cdcce7e2aa5680aee958b0eefbb4a03f5cb70c96 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 24 Apr 2026 12:20:47 +0300 Subject: [PATCH 2/5] refactor: remove logging statements from ObjectSidAllowedObjectClassesCache to streamline cache operations --- .../rid_manager/objectsid_allowed_object_classes_cache.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py b/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py index 76adf0133..bc0d2d5b6 100644 --- a/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py +++ b/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py @@ -7,8 +7,6 @@ import json from typing import Awaitable, Callable -from loguru import logger - from ldap_protocol.rid_manager.types import ObjectSidCacheRedisClient @@ -52,8 +50,6 @@ async def clear(self) -> None: async def get_or_compute(self, compute: Callable[[], Awaitable[set[str]]]) -> set[str]: """Read from redis, or compute once under a lock.""" if cached := await self.get(): - logger.critical("ObjectSidAllowedObjectClassesCache: read from redis") - logger.critical(cached) return cached lock = self._redis.lock( @@ -63,13 +59,9 @@ async def get_or_compute(self, compute: Callable[[], Awaitable[set[str]]]) -> se ) async with lock: if cached2 := await self.get(): - logger.critical("ObjectSidAllowedObjectClassesCache: read from redis") - logger.critical(cached2) return cached2 value = await compute() - logger.critical("ObjectSidAllowedObjectClassesCache: computed") - logger.critical(value) normalized = {v.lower() for v in value} await self.store(normalized) return normalized From 029b4283ee017f86303487abda8e916a8b0a19bc Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 24 Apr 2026 12:22:01 +0300 Subject: [PATCH 3/5] refactor: streamline objectSid handling in AddRequest by removing conditional check --- app/ldap_protocol/ldap_requests/add.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index c42abeae7..840b7cf92 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -168,9 +168,8 @@ async def handle(self, ctx: LDAPAddRequestContext) -> AsyncGenerator[AddResponse ctx.session.add(new_dir) await ctx.session.flush() - if await ctx.object_sid_use_case.is_objectsid_needed(self.object_class_names): - await ctx.object_sid_use_case.ensure_objectsid(directory_id=new_dir.id) - await ctx.session.flush() + await ctx.object_sid_use_case.ensure_objectsid(directory_id=new_dir.id) + await ctx.session.flush() except IntegrityError: await ctx.session.rollback() yield AddResponse(result_code=LDAPCodes.ENTRY_ALREADY_EXISTS) From 611479c859bdec3fd6ad61f7fa8815cbfb26fb06 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 24 Apr 2026 12:37:51 +0300 Subject: [PATCH 4/5] refactor: simplify _decode method in ObjectSidAllowedObjectClassesCache by removing redundant checks --- .../objectsid_allowed_object_classes_cache.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py b/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py index bc0d2d5b6..a49401f24 100644 --- a/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py +++ b/app/ldap_protocol/rid_manager/objectsid_allowed_object_classes_cache.py @@ -22,14 +22,11 @@ def __init__(self, redis: ObjectSidCacheRedisClient) -> None: self._redis = redis def _decode(self, raw: bytes | str) -> set[str] | None: + if isinstance(raw, bytes | bytearray): + raw = raw.decode("utf-8", errors="replace") try: - if isinstance(raw, bytes | bytearray): - raw = raw.decode("utf-8", errors="replace") decoded = json.loads(raw) - except (TypeError, ValueError, json.JSONDecodeError): - return None - - if not isinstance(decoded, list): + except json.JSONDecodeError: return None return {str(v).lower() for v in decoded} From bc8a14f2602bcb8bb81b0fe2d4835f51fbcfa199 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 24 Apr 2026 12:50:05 +0300 Subject: [PATCH 5/5] refactor: add compute method for retrieving available object classes in ObjectSIDUseCase and integrate ObjectSidCacheRedisClient into TestProvider --- app/ldap_protocol/rid_manager/object_sid_use_case.py | 1 + tests/conftest.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/app/ldap_protocol/rid_manager/object_sid_use_case.py b/app/ldap_protocol/rid_manager/object_sid_use_case.py index 775cdfd40..5c6c6fbdb 100644 --- a/app/ldap_protocol/rid_manager/object_sid_use_case.py +++ b/app/ldap_protocol/rid_manager/object_sid_use_case.py @@ -37,6 +37,7 @@ def __init__( async def get_available_object_classes(self) -> set[str]: """ObjectClasses that allow objectSid (mustContain/mayContain).""" + async def compute() -> set[str]: names = await self._object_class_dao.get_object_class_names_include_attribute_type("objectSid") return {n.lower() for n in names} diff --git a/tests/conftest.py b/tests/conftest.py index 3bc201719..e88cdc632 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -397,6 +397,10 @@ async def get_redis_for_sessions(self, settings: Settings) -> AsyncIterator[Sess with suppress(RuntimeError): await client.aclose() + @provide(scope=Scope.APP) + def get_objectsid_cache_redis(self, client: SessionStorageClient) -> ObjectSidCacheRedisClient: + return ObjectSidCacheRedisClient(client) + @provide(scope=Scope.REQUEST, provides=MasterGatewayProtocol) async def get_master_gateway(self, session: AsyncSession, settings: Settings) -> PGMasterGateway: return PGMasterGateway(session, settings) @@ -587,6 +591,7 @@ def authorization_provider_protocol(self, identity_provider: IdentityProvider) - rid_manager_setup_gateway = provide(RIDManagerSetupGateway, scope=Scope.REQUEST) rid_manager_setup_use_case = provide(RIDManagerSetupUseCase, scope=Scope.REQUEST) object_sid_gateway = provide(ObjectSIDGateway, scope=Scope.REQUEST) + objectsid_allowed_object_classes_cache = provide(ObjectSidAllowedObjectClassesCache, scope=Scope.APP) object_sid_use_case = provide(ObjectSIDUseCase, scope=Scope.REQUEST) rid_set_gateway = provide(RIDSetGateway, scope=Scope.REQUEST) rid_set_use_case = provide(RIDSetUseCase, scope=Scope.REQUEST)