Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions src/autointent/generation/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
import logging
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
Expand Down Expand Up @@ -43,6 +44,17 @@ def _get_structured_output_cache_path(dirname: str) -> Path:
return Path(user_cache_dir("autointent")) / "structured_outputs" / dirname


def _remove_cache_entry(path: Path) -> None:
"""Remove a single on-disk cache entry.

Each entry is a *directory* (``PydanticModelDumper.dump`` writes
``class_info.json`` + ``model_dump.json`` inside it), so eviction must use
``rmtree`` rather than ``unlink``. ``ignore_errors`` keeps a missing or
partially removed entry from raising during cleanup.
"""
shutil.rmtree(path, ignore_errors=True)


class StructuredOutputCache:
"""Cache for structured output results."""

Expand Down Expand Up @@ -70,8 +82,10 @@ def _load_existing_cache(self) -> None:
if not cache_dir.exists():
return

# Get all cache files to process
cache_files = [f for f in cache_dir.iterdir() if f.is_file()]
# Each cache entry is a directory written by PydanticModelDumper, so
# collect directories (filtering on is_file() here matched nothing and
# silently disabled eager loading entirely).
cache_files = [f for f in cache_dir.iterdir() if f.is_dir()]

if not cache_files:
return
Expand Down Expand Up @@ -118,7 +132,7 @@ def _load_single_cache_file(self, cache_file: Path) -> tuple[str, BaseModel] | N
cached_data = PydanticModelDumper.load(cache_file)
except (ValidationError, ImportError) as e:
logger.warning("Failed to load cached item %s: %s", cache_file.name, e)
cache_file.unlink(missing_ok=True)
_remove_cache_entry(cache_file)
else:
return cache_file.name, cached_data

Expand Down Expand Up @@ -184,10 +198,10 @@ def _load_from_disk(self, cache_key: str, output_model: type[T]) -> T | None:
return cached_data

logger.warning("Cached data type mismatch on disk, removing invalid cache")
cache_path.unlink()
_remove_cache_entry(cache_path)
except (ValidationError, ImportError) as e:
logger.warning("Failed to load cached structured output from disk: %s", e)
cache_path.unlink(missing_ok=True)
_remove_cache_entry(cache_path)

return None

Expand Down Expand Up @@ -271,10 +285,10 @@ async def _load_from_disk_async(self, cache_key: str, output_model: type[T]) ->
return cached_data

logger.warning("Cached data type mismatch on disk, removing invalid cache")
cache_path.unlink()
_remove_cache_entry(cache_path)
except (ValidationError, ImportError) as e:
logger.warning("Failed to load cached structured output from disk: %s", e)
cache_path.unlink(missing_ok=True)
_remove_cache_entry(cache_path)

return None

Expand Down
59 changes: 58 additions & 1 deletion tests/generation/structured_output/test_cache_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

import pytest
from pydantic import BaseModel

from autointent.generation._cache import StructuredOutputCache
from autointent.generation._cache import StructuredOutputCache, _get_structured_output_cache_path
from autointent.generation.chat_templates import Role

if TYPE_CHECKING:
Expand Down Expand Up @@ -98,3 +99,59 @@ async def test_async_disabled_cache_is_noop() -> None:
cache = StructuredOutputCache(use_cache=False)
await cache.set_async(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1))
assert await cache.get_async(MESSAGES, CacheModel, PARAMS) is None


# --- Regression tests for the on-disk-cache bugs (#326 eager load, #327 eviction) ---
# Disk entries are directories (PydanticModelDumper writes class_info.json +
# model_dump.json), so eager load must collect directories and eviction must
# rmtree rather than unlink.


def test_eager_load_populates_memory_from_disk() -> None:
"""A fresh instance eagerly batch-loads existing on-disk entries into memory (#326)."""
StructuredOutputCache(use_cache=True).set(MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9))

fresh = StructuredOutputCache(use_cache=True)
key = fresh._get_cache_key(MESSAGES, CacheModel, PARAMS)

# populated at construction by the eager load, before any get() call
assert key in fresh._memory_cache
assert isinstance(fresh._memory_cache[key], CacheModel)


def test_eager_load_removes_corrupted_entry() -> None:
"""A cache directory whose payload fails to load is skipped and cleaned up, not raised."""
entry = _get_structured_output_cache_path("corrupted-entry")
entry.mkdir(parents=True)
(entry / "class_info.json").write_text(json.dumps({"name": CacheModel.__name__, "module": CacheModel.__module__}))
# missing the required "value" field -> ValidationError on load
(entry / "model_dump.json").write_text(json.dumps({"name": "x"}))

cache = StructuredOutputCache(use_cache=True) # eager load must not raise

assert not cache._memory_cache
assert not entry.exists()


def test_disk_type_mismatch_evicts_entry() -> None:
"""A type-mismatched disk entry is evicted (rmtree) instead of crashing on unlink (#327)."""
cache = StructuredOutputCache(use_cache=True)
# plant a CacheModel at the key the cache derives for OtherModel inputs
key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS)
cache._save_to_disk(key, CacheModel(name="x", value=1))
cache._memory_cache.clear()

assert cache._load_from_disk(key, OtherModel) is None
assert not _get_structured_output_cache_path(key).exists()


@pytest.mark.asyncio
async def test_async_disk_type_mismatch_evicts_entry() -> None:
"""Async type-mismatched disk entry is evicted (rmtree) instead of crashing on unlink (#327)."""
cache = StructuredOutputCache(use_cache=True)
key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS)
await cache._save_to_disk_async(key, CacheModel(name="x", value=1))
cache._memory_cache.clear()

assert await cache._load_from_disk_async(key, OtherModel) is None
assert not _get_structured_output_cache_path(key).exists()