Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions app/ioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
from ldap_protocol.policies.password.settings import PasswordValidatorSettings
from ldap_protocol.policies.password.use_cases import PasswordBanWordUseCases, UserPasswordHistoryUseCases
from ldap_protocol.rid_manager import (
ObjectSidCacheRedisClient,
ObjectSIDGateway,
ObjectSIDUseCase,
RIDManagerGateway,
Expand All @@ -131,6 +132,7 @@
RIDSetGateway,
RIDSetUseCase,
)
from ldap_protocol.rid_manager.objectsid_allowed_object_classes_cache import ObjectSidAllowedObjectClassesCache
from ldap_protocol.rid_manager.types import HostMachineShortName
from ldap_protocol.roles.access_manager import AccessManager
from ldap_protocol.roles.ace_dao import AccessControlEntryDAO
Expand Down Expand Up @@ -315,6 +317,14 @@ async def get_redis_for_sessions(self, settings: Settings) -> AsyncIterator[Sess
yield SessionStorageClient(client)
await client.aclose()

@provide(scope=Scope.APP)
def get_objectsid_cache_redis(self, client: SessionStorageClient) -> ObjectSidCacheRedisClient:
"""Typed redis client for objectSid-related caches.

Important: this does NOT create a new connection, just re-types the existing client.
"""
return ObjectSidCacheRedisClient(client)

@provide(scope=Scope.APP)
def get_session_storage(self, client: SessionStorageClient, settings: Settings) -> SessionStorage:
"""Get session storage."""
Expand Down Expand Up @@ -464,6 +474,7 @@ def get_object_class_use_case_legacy(self, session: AsyncSession) -> ObjectClass
rid_manager_use_case = provide(RIDManagerUseCase, scope=Scope.REQUEST)
rid_manager_setup_use_case = provide(RIDManagerSetupUseCase, scope=Scope.REQUEST)
object_sid_gateway = provide(ObjectSIDGateway, scope=Scope.REQUEST)
objectsid_allowed_object_classes_cache = provide(ObjectSidAllowedObjectClassesCache, scope=Scope.APP)
object_sid_use_case = provide(ObjectSIDUseCase, scope=Scope.REQUEST)
rid_set_gateway = provide(RIDSetGateway, scope=Scope.REQUEST)
rid_set_use_case = provide(RIDSetUseCase, scope=Scope.REQUEST)
Expand Down
1 change: 0 additions & 1 deletion app/ldap_protocol/ldap_requests/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ async def handle(self, ctx: LDAPAddRequestContext) -> AsyncGenerator[AddResponse
ctx.session.add(new_dir)

await ctx.session.flush()

await ctx.object_sid_use_case.ensure_objectsid(directory_id=new_dir.id)
await ctx.session.flush()
except IntegrityError:
Expand Down
4 changes: 4 additions & 0 deletions app/ldap_protocol/rid_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

from .object_sid_gateway import ObjectSIDGateway
from .object_sid_use_case import ObjectSIDUseCase
from .objectsid_allowed_object_classes_cache import ObjectSidAllowedObjectClassesCache
from .rid_manager_gateway import RIDManagerGateway
from .rid_manager_use_case import RIDManagerUseCase
from .rid_set_gateway import RIDSetGateway
from .rid_set_use_case import RIDSetUseCase
from .setup_gateway import RIDManagerSetupGateway
from .setup_use_case import RIDManagerSetupUseCase
from .types import ObjectSidCacheRedisClient

__all__ = [
"ObjectSIDGateway",
Expand All @@ -22,4 +24,6 @@
"RIDManagerUseCase",
"RIDSetGateway",
"RIDSetUseCase",
"ObjectSidAllowedObjectClassesCache",
"ObjectSidCacheRedisClient",
]
13 changes: 9 additions & 4 deletions app/ldap_protocol/rid_manager/object_sid_use_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from ldap_protocol.ldap_schema.object_class.object_class_dao import ObjectClassDAO
from ldap_protocol.rid_manager.exceptions import RIDManagerObjectSIDNotFoundError
from ldap_protocol.rid_manager.object_sid_gateway import ObjectSIDGateway
from ldap_protocol.rid_manager.objectsid_allowed_object_classes_cache import ObjectSidAllowedObjectClassesCache
from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase
from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase
from ldap_protocol.utils.async_cache import objectsid_allowed_object_classes_cache


class ObjectSIDUseCase:
Expand All @@ -25,19 +25,24 @@ def __init__(
session: AsyncSession,
rid_manager_use_case: RIDManagerUseCase,
object_class_dao: ObjectClassDAO,
objectsid_allowed_object_classes_cache: ObjectSidAllowedObjectClassesCache,
) -> None:
"""Initialize Object SID use case."""
self._gateway = gateway
self._rid_set_use_case = rid_set_use_case
self._session = session
self._rid_manager_use_case = rid_manager_use_case
self._object_class_dao = object_class_dao
self._objectsid_allowed_object_classes_cache = objectsid_allowed_object_classes_cache

@objectsid_allowed_object_classes_cache
async def get_available_object_classes(self) -> set[str]:
"""ObjectClasses that allow objectSid (mustContain/mayContain)."""
names = await self._object_class_dao.get_object_class_names_include_attribute_type("objectSid")
return {n.lower() for n in names}

async def compute() -> set[str]:
names = await self._object_class_dao.get_object_class_names_include_attribute_type("objectSid")
return {n.lower() for n in names}

return await self._objectsid_allowed_object_classes_cache.get_or_compute(compute)
Comment on lines +41 to +45
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

давай async def compute унесем внутрь ObjectSidAllowedObjectClassesCache
думаю тут нет смысла держать эту функцию

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

обсудили голосом


async def is_objectsid_needed(self, object_class_names: set[str]) -> bool:
"""Check if objectSid is needed for objectClasses."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Redis cache for objectSid-related metadata.

Copyright (c) 2026 MultiFactor
License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE
"""

import json
from typing import Awaitable, Callable

from ldap_protocol.rid_manager.types import ObjectSidCacheRedisClient


class ObjectSidAllowedObjectClassesCache:
"""Cache for ObjectClass names that allow `objectSid`."""

_CACHE_KEY = "ldap:objectsid:allowed_object_classes"
_LOCK_KEY = "lock:ldap:objectsid:allowed_object_classes"
_LOCK_BLOCKING_TIMEOUT_SECONDS = 5
_LOCK_LEASE_TIMEOUT_SECONDS = 30

def __init__(self, redis: ObjectSidCacheRedisClient) -> None:
self._redis = redis

def _decode(self, raw: bytes | str) -> set[str] | None:
if isinstance(raw, bytes | bytearray):
raw = raw.decode("utf-8", errors="replace")
try:
decoded = json.loads(raw)
except json.JSONDecodeError:
return None

return {str(v).lower() for v in decoded}

async def get(self) -> set[str] | None:
raw = await self._redis.get(self._CACHE_KEY)
if not raw:
return None
return self._decode(raw)

async def store(self, value: set[str]) -> None:
normalized = sorted({v.lower() for v in value})
await self._redis.set(self._CACHE_KEY, json.dumps(normalized))

async def clear(self) -> None:
await self._redis.delete(self._CACHE_KEY, self._LOCK_KEY)

async def get_or_compute(self, compute: Callable[[], Awaitable[set[str]]]) -> set[str]:
"""Read from redis, or compute once under a lock."""
if cached := await self.get():
return cached

lock = self._redis.lock(
name=self._LOCK_KEY,
blocking_timeout=self._LOCK_BLOCKING_TIMEOUT_SECONDS,
timeout=self._LOCK_LEASE_TIMEOUT_SECONDS,
)
async with lock:
if cached2 := await self.get():
return cached2

value = await compute()
normalized = {v.lower() for v in value}
await self.store(normalized)
return normalized
3 changes: 3 additions & 0 deletions app/ldap_protocol/rid_manager/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@

from typing import NewType

import redis.asyncio as redis

HostMachineShortName = NewType("HostMachineShortName", str)
ObjectSidCacheRedisClient = NewType("ObjectSidCacheRedisClient", redis.Redis)
1 change: 0 additions & 1 deletion app/ldap_protocol/utils/async_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,3 @@ async def wrapper(*args: tuple, **kwargs: dict) -> T:
base_directories_cache = AsyncTTLCache[list[Directory]]()
domain_identifier_cache = AsyncTTLCache[str]()
rid_set_id_cache = AsyncTTLCache[int]()
objectsid_allowed_object_classes_cache = AsyncTTLCache[set[str]](ttl=60 * 60 * 24)
37 changes: 35 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
from ldap_protocol.policies.password.settings import PasswordValidatorSettings
from ldap_protocol.policies.password.use_cases import PasswordBanWordUseCases, UserPasswordHistoryUseCases
from ldap_protocol.rid_manager import (
ObjectSidCacheRedisClient,
ObjectSIDGateway,
ObjectSIDUseCase,
RIDManagerGateway,
Expand All @@ -122,6 +123,7 @@
RIDSetGateway,
RIDSetUseCase,
)
from ldap_protocol.rid_manager.objectsid_allowed_object_classes_cache import ObjectSidAllowedObjectClassesCache
from ldap_protocol.rid_manager.types import HostMachineShortName
from ldap_protocol.roles.access_manager import AccessManager
from ldap_protocol.roles.ace_dao import AccessControlEntryDAO
Expand Down Expand Up @@ -395,6 +397,10 @@ async def get_redis_for_sessions(self, settings: Settings) -> AsyncIterator[Sess
with suppress(RuntimeError):
await client.aclose()

@provide(scope=Scope.APP)
def get_objectsid_cache_redis(self, client: SessionStorageClient) -> ObjectSidCacheRedisClient:
return ObjectSidCacheRedisClient(client)

@provide(scope=Scope.REQUEST, provides=MasterGatewayProtocol)
async def get_master_gateway(self, session: AsyncSession, settings: Settings) -> PGMasterGateway:
return PGMasterGateway(session, settings)
Expand Down Expand Up @@ -585,6 +591,7 @@ def authorization_provider_protocol(self, identity_provider: IdentityProvider) -
rid_manager_setup_gateway = provide(RIDManagerSetupGateway, scope=Scope.REQUEST)
rid_manager_setup_use_case = provide(RIDManagerSetupUseCase, scope=Scope.REQUEST)
object_sid_gateway = provide(ObjectSIDGateway, scope=Scope.REQUEST)
objectsid_allowed_object_classes_cache = provide(ObjectSidAllowedObjectClassesCache, scope=Scope.APP)
object_sid_use_case = provide(ObjectSIDUseCase, scope=Scope.REQUEST)
rid_set_gateway = provide(RIDSetGateway, scope=Scope.REQUEST)
rid_set_use_case = provide(RIDSetUseCase, scope=Scope.REQUEST)
Expand Down Expand Up @@ -791,8 +798,16 @@ async def setup_session(
rid_set_gateway, entity_type_use_case, session, rid_manager_use_case, role_use_case
)
object_sid_gateway = ObjectSIDGateway(session)
redis_client = SessionStorageClient(redis.Redis.from_url(str(settings.SESSION_STORAGE_URL)))

objectsid_allowed_object_classes_cache = ObjectSidAllowedObjectClassesCache(ObjectSidCacheRedisClient(redis_client))
object_sid_use_case = ObjectSIDUseCase(
object_sid_gateway, rid_set_use_case, session, rid_manager_use_case, object_class_dao
object_sid_gateway,
rid_set_use_case,
session,
rid_manager_use_case,
object_class_dao,
objectsid_allowed_object_classes_cache,
)
directory_create_use_case = DirectoryCreateUseCase(
session=session,
Expand Down Expand Up @@ -1414,18 +1429,36 @@ async def object_sid_gateway(container: AsyncContainer) -> AsyncIterator[ObjectS
yield ObjectSIDGateway(session)


@pytest_asyncio.fixture(scope="function")
async def objectsid_allowed_object_classes_cache(
container: AsyncContainer, settings: Settings
) -> AsyncIterator[ObjectSidAllowedObjectClassesCache]:
"""Provide ObjectSidAllowedObjectClassesCache for tests that request it explicitly."""
async with container(scope=Scope.SESSION) as container:
redis_client = ObjectSidCacheRedisClient(redis.Redis.from_url(str(settings.SESSION_STORAGE_URL)))
yield ObjectSidAllowedObjectClassesCache(ObjectSidCacheRedisClient(redis_client))


@pytest_asyncio.fixture(scope="function")
async def object_sid_use_case(
container: AsyncContainer,
rid_manager_use_case: RIDManagerUseCase,
rid_set_use_case: RIDSetUseCase,
object_sid_gateway: ObjectSIDGateway,
objectsid_allowed_object_classes_cache: ObjectSidAllowedObjectClassesCache,
) -> AsyncIterator[ObjectSIDUseCase]:
"""Provide RIDManagerUseCase for tests that request it explicitly."""
async with container(scope=Scope.SESSION) as container:
session = await container.get(AsyncSession)
object_class_dao = ObjectClassDAO(session)
yield ObjectSIDUseCase(object_sid_gateway, rid_set_use_case, session, rid_manager_use_case, object_class_dao)
yield ObjectSIDUseCase(
object_sid_gateway,
rid_set_use_case,
session,
rid_manager_use_case,
object_class_dao,
objectsid_allowed_object_classes_cache,
)


def pytest_configure(config: pytest.Config) -> None:
Expand Down
Loading