diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..72455a4 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,16 @@ +.git +.github +.venv +.pytest_cache +.vscode +__pycache__ +*.pyc +*.pyo +*.pyd +*.egg-info +build +dist +tests +sample_data +tmp_forag_test +tmp_forag_test_remote diff --git a/.gitignore b/.gitignore index 10ef763..0156186 100644 --- a/.gitignore +++ b/.gitignore @@ -138,4 +138,7 @@ local/ pylate/* pylate/ -.byaldi/ \ No newline at end of file +.byaldi/ + +.pat + diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..73bcd79 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,178 @@ +# syntax=docker/dockerfile:1.7 +FROM python:3.12-slim AS runtime-base + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +WORKDIR /app + +# Copy only dependency metadata first for better layer caching +COPY pyproject.toml ./ + +# Install third-party runtime dependencies from pyproject extras. +# This layer is invalidated only when dependency metadata changes. +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install --upgrade pip setuptools wheel && \ + python - <<'PY' > /tmp/requirements-runtime.txt +import tomllib +from pathlib import Path + +cfg = tomllib.loads(Path("pyproject.toml").read_text()) +project = cfg["project"] +deps = list(project.get("dependencies", [])) +optional = project.get("optional-dependencies", {}) +deps.extend(optional.get("server", [])) + +# Stable order and dedupe for deterministic cache behavior. +seen = set() +ordered = [] +for dep in deps: + if dep not in seen: + seen.add(dep) + ordered.append(dep) +print("\n".join(ordered)) +PY +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install -r /tmp/requirements-runtime.txt + +# Copy package sources late so app code changes do not bust dependency layers +COPY README.md ./ +COPY foretrieval ./foretrieval + +# Install local package without re-resolving dependencies +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install --no-deps . + + + + + + + +# Shared CUDA Python runtime base for GPU stages. +# Keep this stage free of app source copies so app-only changes do not bust +# heavyweight builder caches (e.g. flash-attn wheel stage). +FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 AS cuda-python-runtime-base + +ENV DEBIAN_FRONTEND=noninteractive \ + PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + CUDA_HOME=/usr/local/cuda + +RUN apt-get update && apt-get install -y --no-install-recommends \ + software-properties-common \ + ca-certificates \ + gnupg \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update && apt-get install -y --no-install-recommends \ + python3.12 \ + python3.12-venv \ + && rm -rf /var/lib/apt/lists/* + +RUN ln -sf /usr/bin/python3.12 /usr/local/bin/python +RUN python -m venv /opt/venv +ENV VIRTUAL_ENV=/opt/venv +ENV PATH="/opt/venv/bin:${PATH}" + +WORKDIR /app +COPY pyproject.toml ./ + +# Single dependency/app installation point for GPU pipeline. +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install --upgrade pip setuptools wheel && \ + python - <<'PY' > /tmp/requirements-runtime.txt +import tomllib +from pathlib import Path + +cfg = tomllib.loads(Path("pyproject.toml").read_text()) +project = cfg["project"] +deps = list(project.get("dependencies", [])) +optional = project.get("optional-dependencies", {}) +deps.extend(optional.get("server", [])) + +seen = set() +ordered = [] +for dep in deps: + if dep not in seen: + seen.add(dep) + ordered.append(dep) +print("\n".join(ordered)) +PY + +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install -r /tmp/requirements-runtime.txt && \ + python -m pip install "bitsandbytes>=0.43" + +# App installation layer for CUDA runtime images. +FROM cuda-python-runtime-base AS cuda-app-base + +COPY README.md ./ +COPY foretrieval ./foretrieval + +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install --no-deps . + +# Dedicated builder for flash-attn wheel artifacts. +FROM cuda-python-runtime-base AS flashattn-builder + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + git \ + ninja-build \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /build +RUN mkdir -p /wheels && \ + python -m pip wheel --no-build-isolation flash-attn -w /wheels + +# GPU target: installs flash-attn wheel built in dedicated builder. +FROM cuda-app-base AS gpu + +COPY --from=flashattn-builder /wheels /tmp/wheels + +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install /tmp/wheels/flash_attn-*.whl + +# Server runtime defaults (override at `docker run -e ...`) +ENV FOR_EMBED_MODEL=vidore/colqwen2-v1.0 \ + FOR_EMBED_DEVICE=cuda \ + FOR_EMBED_VERBOSE=1 \ + FOR_EMBED_FLASH_ATTENTION=auto \ + FOR_SERVER_WORKERS=1 \ + FOR_SERVER_MAX_INFLIGHT=1 \ + FOR_SERVER_SINGLE_MODEL_CACHE=true \ + FOR_EMBED_LOAD_IN_4BIT=false \ + FOR_EMBED_LOAD_IN_8BIT=false \ + FOR_EMBED_BNB_4BIT_QUANT_TYPE=nf4 \ + FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE=float16 + +EXPOSE 8000 + +CMD ["sh", "-c", "python -m uvicorn foretrieval.server.server_main:app --host 0.0.0.0 --port 8000 --workers ${FOR_SERVER_WORKERS:-1}"] + + + + + + +# Default target: CPU-friendly runtime without GPU-specific optional deps. +FROM runtime-base AS cpu + +# Server runtime defaults (override at `docker run -e ...`) +ENV FOR_EMBED_MODEL=vidore/colqwen2-v1.0 \ + FOR_EMBED_DEVICE=cpu \ + FOR_EMBED_VERBOSE=1 \ + FOR_EMBED_FLASH_ATTENTION=auto \ + FOR_SERVER_WORKERS=1 \ + FOR_SERVER_MAX_INFLIGHT=1 \ + FOR_SERVER_SINGLE_MODEL_CACHE=true \ + FOR_EMBED_LOAD_IN_4BIT=false \ + FOR_EMBED_LOAD_IN_8BIT=false \ + FOR_EMBED_BNB_4BIT_QUANT_TYPE=nf4 \ + FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE=float16 + +EXPOSE 8000 + +CMD ["sh", "-c", "python -m uvicorn foretrieval.server.server_main:app --host 0.0.0.0 --port 8000 --workers ${FOR_SERVER_WORKERS:-1}"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..0507a37 --- /dev/null +++ b/Makefile @@ -0,0 +1,69 @@ +REGISTRY ?= ghcr.io +OWNER ?= random-plm +IMAGE_NAME ?= foretrieval-server +IMAGE_TAG ?= v0.0.1 +IMAGE_REPO ?= $(REGISTRY)/$(OWNER)/$(IMAGE_NAME) +IMAGE ?= $(IMAGE_REPO):$(IMAGE_TAG) +IMAGE_CPU ?= $(IMAGE_REPO):$(IMAGE_TAG)-base +GPU_CUDA_VERSION ?= 12.2.2 +GPU_CUDNN_MAJOR ?= 8 +HOST_ARCH_RAW := $(shell uname -m) +HOST_ARCH ?= $(if $(filter x86_64 amd64,$(HOST_ARCH_RAW)),amd64,$(if $(filter aarch64 arm64,$(HOST_ARCH_RAW)),arm64,$(HOST_ARCH_RAW))) +IMAGE_GPU ?= $(IMAGE_REPO):$(IMAGE_TAG)-cuda$(GPU_CUDA_VERSION)-cudnn$(GPU_CUDNN_MAJOR)-flashattn-bnb-$(HOST_ARCH) + +# Backward-compatible aliases. +NAMESPACE ?= $(OWNER) +PROJECT ?= $(IMAGE_NAME) +DOCKERFILE ?= Dockerfile +BUILD_CONTEXT ?= . +PYTEST ?= pytest +TEST_ARGS ?= -m "not slow and not integration" +SRC_PACKAGE ?= foretrieval + +.PHONY: check-buildx login-registry build build-cpu build-gpu publish publish-cpu publish-gpu run-server test-fast test-all coverage-fast coverage-all + +check-buildx: + @docker buildx version >/dev/null 2>&1 || { \ + echo "Error: docker buildx is not available."; \ + echo "Install Docker Buildx (or Docker Desktop), then run this command again."; \ + echo "Linux plugin package is often named 'docker-buildx-plugin'."; \ + exit 1; \ + } + +login-registry: + @test -n "$$GITHUB_PAT" || { \ + echo "Error: GITHUB_PAT is not set. Export a classic PAT with write:packages."; \ + exit 1; \ + } + @printf '%s' "$$GITHUB_PAT" | docker login ghcr.io -u $(OWNER) --password-stdin + +build: build-cpu build-gpu + +build-cpu: check-buildx + docker buildx build --load -f $(DOCKERFILE) --target cpu -t $(IMAGE_CPU) $(BUILD_CONTEXT) + +build-gpu: check-buildx + docker buildx build --load -f $(DOCKERFILE) --target gpu -t $(IMAGE_GPU) $(BUILD_CONTEXT) + +publish: publish-cpu publish-gpu + +publish-cpu: check-buildx login-registry + docker buildx build --push -f $(DOCKERFILE) --target cpu -t $(IMAGE_CPU) $(BUILD_CONTEXT) + +publish-gpu: check-buildx login-registry + docker buildx build --push -f $(DOCKERFILE) --target gpu -t $(IMAGE_GPU) $(BUILD_CONTEXT) + +run-server: + IMAGE=$(IMAGE_GPU) REGISTRY=$(REGISTRY) OWNER=$(OWNER) IMAGE_NAME=$(IMAGE_NAME) IMAGE_TAG=$(IMAGE_TAG) ./scripts/run-docker.sh + +test-fast: + $(PYTEST) $(TEST_ARGS) + +test-all: + $(PYTEST) + +coverage-fast: + $(PYTEST) $(TEST_ARGS) --cov=$(SRC_PACKAGE) --cov-report=term-missing --cov-report=xml + +coverage-all: + $(PYTEST) --cov=$(SRC_PACKAGE) --cov-report=term-missing --cov-report=xml diff --git a/foretrieval/client/__init__.py b/foretrieval/client/__init__.py new file mode 100644 index 0000000..f6ffba7 --- /dev/null +++ b/foretrieval/client/__init__.py @@ -0,0 +1,3 @@ +from .embedding_backends import LocalEmbeddingBackend, RemoteEmbeddingBackend +from .remote_backend import RemoteEmbeddingClient +from .transport import dumps_payload, loads_payload diff --git a/foretrieval/client/embedding_backends.py b/foretrieval/client/embedding_backends.py new file mode 100644 index 0000000..e492987 --- /dev/null +++ b/foretrieval/client/embedding_backends.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import List, Protocol, Union + +import torch +from PIL import Image + +from .remote_backend import RemoteEmbeddingClient + + +class EmbeddingBackend(Protocol): + def embed_images(self, images: List[Image.Image]) -> torch.Tensor: + ... + + def embed_queries(self, queries: List[str]) -> torch.Tensor: + ... + + +class LocalEmbeddingBackend: + def __init__(self, model, processor, device): + self.model = model + self.processor = processor + self.device = device + + def embed_images(self, images: List[Image.Image]) -> torch.Tensor: + with torch.inference_mode(): + batch = self.processor.process_images(images) + batch = { + k: v.to(self.device).to( + self.model.dtype + if v.dtype in [torch.float16, torch.bfloat16, torch.float32] + else v.dtype + ) + for k, v in batch.items() + } + embeddings = self.model(**batch) + return embeddings.cpu() + + def embed_queries(self, queries: List[str]) -> torch.Tensor: + with torch.inference_mode(): + batch = self.processor.process_queries(queries) + batch = { + k: v.to(self.device).to( + self.model.dtype + if v.dtype in [torch.float16, torch.bfloat16, torch.float32] + else v.dtype + ) + for k, v in batch.items() + } + embeddings = self.model(**batch) + return embeddings.cpu() + + +class RemoteEmbeddingBackend: + def __init__(self, client: RemoteEmbeddingClient): + self.client = client + + def embed_images(self, images: List[Image.Image]) -> torch.Tensor: + return self.client.encode_images(images).cpu() + + def embed_queries(self, queries: List[str]) -> torch.Tensor: + return self.client.encode_queries(queries).cpu() diff --git a/foretrieval/client/remote_backend.py b/foretrieval/client/remote_backend.py new file mode 100644 index 0000000..54a6b4c --- /dev/null +++ b/foretrieval/client/remote_backend.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import io +from concurrent.futures import ThreadPoolExecutor +from typing import List, Optional + +import torch +from PIL import Image + +from .transport import dumps_payload, loads_payload + + +class RemoteEmbeddingClient: + def __init__( + self, + server_url: str, + model_name: str, + token: Optional[str] = None, + timeout: float = 30.0, + verify_ssl: bool = True, + concurrency: int = 1, + request_batch_size: Optional[int] = None, + ): + try: + import httpx + except ImportError as exc: + raise ImportError( + "Remote embedding mode requires `httpx`. Install with `pip install httpx`." + ) from exc + + self._httpx = httpx + self.server_url = server_url.rstrip("/") + self.model_name = model_name + self.concurrency = max(1, int(concurrency)) + self.request_batch_size = int(request_batch_size) if request_batch_size else None + if self.request_batch_size is not None and self.request_batch_size < 1: + raise ValueError("request_batch_size must be >= 1 when provided.") + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + self.client = httpx.Client(timeout=timeout, headers=headers, verify=verify_ssl) + + @staticmethod + def _image_to_bytes(image: Image.Image) -> bytes: + buf = io.BytesIO() + image.save(buf, format="PNG") + return buf.getvalue() + + @staticmethod + def _chunked(items: list, chunk_size: int) -> List[list]: + return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)] + + def _post_embeddings(self, endpoint: str, payload: dict) -> torch.Tensor: + resp = self.client.post( + f"{self.server_url}{endpoint}", + content=dumps_payload(payload), + headers={"Content-Type": "application/octet-stream"}, + ) + resp.raise_for_status() + out = loads_payload(resp.content) + return out["embeddings"].cpu() + + def _encode_batched(self, endpoint: str, field_name: str, values: list) -> torch.Tensor: + if not values: + return torch.empty((0,)) + + batch_size = self.request_batch_size or len(values) + chunks = self._chunked(values, batch_size) + + # avoids thread pool overhead if only one element + if self.concurrency == 1 or len(chunks) == 1: + tensors = [] + for chunk in chunks: + payload = {"model": self.model_name, field_name: chunk} + tensors.append(self._post_embeddings(endpoint, payload)) + return torch.cat(tensors, dim=0) if len(tensors) > 1 else tensors[0] + + with ThreadPoolExecutor(max_workers=self.concurrency) as pool: + futures = [] + for idx, chunk in enumerate(chunks): + payload = {"model": self.model_name, field_name: chunk} + futures.append((idx, pool.submit(self._post_embeddings, endpoint, payload))) + + ordered = [None] * len(chunks) + for idx, future in futures: + ordered[idx] = future.result() + + return torch.cat(ordered, dim=0) if len(ordered) > 1 else ordered[0] + + def encode_images(self, images: List[Image.Image]) -> torch.Tensor: + encoded = [self._image_to_bytes(image.convert("RGB")) for image in images] + return self._encode_batched("/v1/embed/images", "images", encoded) + + def encode_queries(self, queries: List[str]) -> torch.Tensor: + return self._encode_batched("/v1/embed/queries", "queries", queries) diff --git a/foretrieval/client/transport.py b/foretrieval/client/transport.py new file mode 100644 index 0000000..9db18ac --- /dev/null +++ b/foretrieval/client/transport.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import io +from typing import Any + +import torch + + +def dumps_payload(payload: Any) -> bytes: + buffer = io.BytesIO() + torch.save(payload, buffer) + return buffer.getvalue() + + +def loads_payload(data: bytes) -> Any: + buffer = io.BytesIO(data) + return torch.load(buffer, map_location="cpu") diff --git a/foretrieval/colpali.py b/foretrieval/colpali.py index 1530737..589047b 100644 --- a/foretrieval/colpali.py +++ b/foretrieval/colpali.py @@ -1,4 +1,5 @@ import base64 +import importlib.util import io import logging import os @@ -13,6 +14,7 @@ from pdf2image import convert_from_path from PIL import Image import torch +from transformers import BitsAndBytesConfig try: from qdrant_client import QdrantClient @@ -41,6 +43,8 @@ from .objects import Result from .plot_utils import draw_circle_on_max_patch, pil_from_base64, pil_to_base64_png, compute_patch_heatmap, majority_token_id, build_heatmap_overlays_base64 from .utils import _value_match +from .client.embedding_backends import LocalEmbeddingBackend, RemoteEmbeddingBackend +from .client.remote_backend import RemoteEmbeddingClient VERSION = "0.0.1" @@ -121,11 +125,64 @@ def __init__( self.enable_circle = False self.SOURCE_EXTS = {".doc", ".docx", ".rtf", ".odt", ".ppt", ".pptx", ".odp", ".xls", ".xlsx", ".ods", ".txt", ".md", ".csv", ".json", ".yaml", ".yml", ".epub", ".html"} self.IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".gif"} + self.embedding_mode = str(kwargs.pop("embedding_mode", "local")) + self.embedding_server_url = kwargs.pop("embedding_server_url", None) + self.embedding_server_token = kwargs.pop("embedding_server_token", None) + self.embedding_request_timeout = float(kwargs.pop("embedding_request_timeout", 30.0)) + self.embedding_verify_ssl = bool(kwargs.pop("embedding_verify_ssl", True)) + self.embedding_concurrency = int(kwargs.pop("embedding_concurrency", 1)) + self.embedding_request_batch_size = kwargs.pop("embedding_request_batch_size", None) + if self.embedding_concurrency < 1: + raise ValueError("embedding_concurrency must be >= 1.") + if self.embedding_request_batch_size is not None: + self.embedding_request_batch_size = int(self.embedding_request_batch_size) + if self.embedding_request_batch_size < 1: + raise ValueError("embedding_request_batch_size must be >= 1 when provided.") + self.flash_attention_mode = str(kwargs.pop("flash_attention_mode", "auto")).strip().lower() + if self.flash_attention_mode not in {"auto", "on", "off"}: + raise ValueError("flash_attention_mode must be one of: auto, on, off.") self.qdrant_client = None self.qdrant_collection = index_name self.qdrant_path = None + load_in_4bit = bool(kwargs.pop("load_in_4bit", False)) + load_in_8bit = bool(kwargs.pop("load_in_8bit", False)) + bnb_4bit_quant_type = str(kwargs.pop("bnb_4bit_quant_type", "nf4")) + bnb_4bit_compute_dtype = str(kwargs.pop("bnb_4bit_compute_dtype", "float16")) + + if load_in_4bit and load_in_8bit: + raise ValueError("Only one quantization mode can be enabled: 4-bit or 8-bit.") + + quantization_config = None + if load_in_4bit or load_in_8bit: + if importlib.util.find_spec("bitsandbytes") is None: + raise ImportError( + "Quantization requested but `bitsandbytes` is not installed. " + "Install it with `pip install bitsandbytes`." + ) + if load_in_4bit: + compute_dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + if bnb_4bit_compute_dtype not in compute_dtype_map: + raise ValueError( + "Invalid bnb_4bit_compute_dtype. Expected one of: " + "'float16', 'bfloat16', 'float32'." + ) + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_compute_dtype=compute_dtype_map[bnb_4bit_compute_dtype], + ) + else: + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + self._load_in_4bit = load_in_4bit + self._load_in_8bit = load_in_8bit + self._quantization_config = quantization_config + if self.storage_qdrant and self.index_name is not None: self.qdrant_path = Path(self.index_root) / self.index_name / "qdrant" self.qdrant_client = QdrantClient(path=str(self.qdrant_path)) @@ -135,7 +192,28 @@ def __init__( self.docling_dir = Path(index_root) / self.index_name / "docling_chunks" self.docling_dir.mkdir(parents=True, exist_ok=True) - self._load_model_and_processor() + self.embedding_backend = None + if self.embedding_mode == "remote": + if not self.embedding_server_url: + raise ValueError("embedding_server_url is required when embedding_mode='remote'.") + self.remote_client = RemoteEmbeddingClient( + self.embedding_server_url, + model_name=self.pretrained_model_name_or_path, + token=self.embedding_server_token, + timeout=self.embedding_request_timeout, + verify_ssl=self.embedding_verify_ssl, + concurrency=self.embedding_concurrency, + request_batch_size=self.embedding_request_batch_size, + ) + self.embedding_backend = RemoteEmbeddingBackend(self.remote_client) + self._load_remote_processor_only() + self.model = None + elif self.embedding_mode == "local": + self.remote_client = None + self._load_model_and_processor() + self.embedding_backend = LocalEmbeddingBackend(self.model, self.processor, self.device) + else: + raise ValueError("embedding_mode must be either 'local' or 'remote'.") if not load_from_index: self.full_document_collection = False @@ -150,6 +228,7 @@ def _load_model_and_processor(self): token = self.kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN") is_cuda = self.device == "cuda" or (isinstance(self.device, torch.device) and self.device.type == "cuda") device_map = "cuda" if is_cuda else None + flash_attn_available = importlib.util.find_spec("flash_attn") is not None if "colpali" in self.pretrained_model_name_or_path.lower(): model_cls = ColPali @@ -161,11 +240,33 @@ def _load_model_and_processor(self): model_cls = ColQwen2 processor_cls = ColQwen2Processor + model_kwargs = { + "torch_dtype": torch.bfloat16, + "device_map": device_map, + "quantization_config": self._quantization_config, + "token": token, + } + if self.flash_attention_mode == "on": + if not is_cuda: + raise RuntimeError("flash_attention_mode='on' requires CUDA device.") + if not flash_attn_available: + raise RuntimeError("flash_attention_mode='on' but flash_attn is not installed.") + model_kwargs["attn_implementation"] = "flash_attention_2" + logger.info("Using Flash-Attention 2 for model loading (flash_attention_mode=on).") + elif self.flash_attention_mode == "off": + logger.info("Using default attention implementation for model loading (flash_attention_mode=off).") + elif is_cuda and flash_attn_available: + model_kwargs["attn_implementation"] = "flash_attention_2" + logger.info("Using Flash-Attention 2 for model loading (flash_attention_mode=auto).") + else: + logger.info( + "Using default attention implementation for model loading " + f"(flash_attention_mode=auto, is_cuda={is_cuda}, flash_attn_available={flash_attn_available})." + ) + self.model = model_cls.from_pretrained( self.pretrained_model_name_or_path, - torch_dtype=torch.bfloat16, - device_map=device_map, - token=token, + **model_kwargs, ) self.processor = processor_cls.from_pretrained( self.pretrained_model_name_or_path, @@ -173,9 +274,28 @@ def _load_model_and_processor(self): ) self.model = self.model.eval() - if device_map is None: + if device_map is None and not (self._load_in_4bit or self._load_in_8bit): self.model = self.model.to(self.device) + def _load_remote_processor_only(self): + token = self.kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN") + model_name = self.pretrained_model_name_or_path.lower() + if "colpali" in model_name: + self.processor = ColPaliProcessor.from_pretrained( + self.pretrained_model_name_or_path, + token=token, + ) + elif "colqwen2.5" in model_name: + self.processor = ColQwen2_5_Processor.from_pretrained( + self.pretrained_model_name_or_path, + token=token, + ) + else: + self.processor = ColQwen2Processor.from_pretrained( + self.pretrained_model_name_or_path, + token=token, + ) + def _load_index_state(self): if self.index_name is None: raise ValueError("No index name specified. Cannot load from index.") @@ -859,17 +979,7 @@ def _add_to_index( # Optionnel mais utile : taille originale de l'image orig_sizes = [img.size for img in images] # (W,H) PIL - # Generate embeddings - with torch.inference_mode(): - processed_images = { - k: v.to(self.device).to( - self.model.dtype - if v.dtype in [torch.float16, torch.bfloat16, torch.float32] - else v.dtype - ) - for k, v in processed_images.items() - } - embeddings = self.model(**processed_images) + embeddings = self._embed_images(images) # 1. Compute embeddings # 2. Ensure backend storage and check duplicates @@ -1089,19 +1199,10 @@ def remove_from_index(self): # ============================================================ def _encode_search_query(self, query: str): - with torch.inference_mode(): - batch_query = self.processor.process_queries([query]) - batch_query = { - kk: vv.to(self.device).to( - self.model.dtype - if vv.dtype in [torch.float16, torch.bfloat16, torch.float32] - else vv.dtype - ) - for kk, vv in batch_query.items() - } - embeddings_query = self.model(**batch_query) - qs = list(torch.unbind(embeddings_query.to("cpu"))) + embeddings_query = self._embed_queries([query]) + qs = list(torch.unbind(embeddings_query.to("cpu"))) + batch_query = self.processor.process_queries([query]) input_ids = batch_query["input_ids"][0].detach().cpu().tolist() tokens = self.processor.tokenizer.convert_ids_to_tokens(input_ids) valid_idxs = [i for i, tok in enumerate(tokens) if tok not in {"<|endoftext|>", "Query", ":"}] @@ -1499,6 +1600,57 @@ def _post_process_image(self, image: Image.Image) -> str: # File helpers # ============================================================ + def encode_image( + self, input_data: Union[str, Image.Image, List[Union[str, Image.Image]]] + ) -> torch.Tensor: + if not isinstance(input_data, list): + input_data = [input_data] + + images = [] + for item in input_data: + if isinstance(item, Image.Image): + images.append(item) + elif isinstance(item, str): + if os.path.isdir(item): + for file in os.listdir(item): + if file.lower().endswith( + (".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif") + ): + images.append(Image.open(os.path.join(item, file))) + elif item.lower().endswith(".pdf"): + with tempfile.TemporaryDirectory() as path: + pdf_images = convert_from_path( + item, thread_count=os.cpu_count() - 1, output_folder=path + ) + images.extend(pdf_images) + elif item.lower().endswith( + (".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif") + ): + images.append(Image.open(item)) + else: + raise ValueError(f"Unsupported file type: {item}") + else: + raise ValueError(f"Unsupported input type: {type(item)}") + + return self._embed_images(images).cpu() + + def encode_query(self, query: Union[str, List[str]]) -> torch.Tensor: + if isinstance(query, str): + query = [query] + + return self._embed_queries(query).cpu() + + def _embed_images(self, images: List[Image.Image]) -> torch.Tensor: + if self.embedding_backend is None: + raise RuntimeError("Embedding backend is not initialized.") + return self.embedding_backend.embed_images(images) + + def _embed_queries(self, query: Union[str, List[str]]) -> torch.Tensor: + queries = [query] if isinstance(query, str) else query + if self.embedding_backend is None: + raise RuntimeError("Embedding backend is not initialized.") + return self.embedding_backend.embed_queries(queries) + def _looks_like_pdf(self, path: Path) -> bool: try: if not path.exists() or path.stat().st_size < 5: diff --git a/foretrieval/retriever.py b/foretrieval/retriever.py index 0b76b77..a00dd17 100644 --- a/foretrieval/retriever.py +++ b/foretrieval/retriever.py @@ -44,8 +44,17 @@ def from_pretrained( pretrained_model_name_or_path: Union[str, Path], index_root: str = ".rag_index", ingestion: Dict[str, Any] = {"backend": "default"}, + storage_qdrant: bool = False, device: str = "cuda", verbose: int = 1, + hf_token: Optional[str] = None, + embedding_mode: str = "local", + remote: Optional[Dict[str, Any]] = None, + load_in_4bit: bool = False, + load_in_8bit: bool = False, + bnb_4bit_quant_type: str = "nf4", + bnb_4bit_compute_dtype: str = "float16", + flash_attention_mode: str = "auto", ): """Load a ColPali model from a pre-trained checkpoint. @@ -59,13 +68,28 @@ def from_pretrained( Returns: cls (RAGMultiModalModel): The current instance of RAGMultiModalModel, with the model initialised. """ + remote_cfg: Dict[str, Any] = remote or {} instance = cls() instance.model = ColPaliModel.from_pretrained( pretrained_model_name_or_path, index_root=index_root, ingestion=ingestion, + storage_qdrant=storage_qdrant, device=device, verbose=verbose, + hf_token=hf_token, + embedding_mode=embedding_mode, + embedding_server_url=remote_cfg.get("url"), + embedding_server_token=remote_cfg.get("token"), + embedding_request_timeout=remote_cfg.get("request_timeout", 30.0), + embedding_verify_ssl=remote_cfg.get("verify_ssl", True), + embedding_concurrency=remote_cfg.get("concurrency", 1), + embedding_request_batch_size=remote_cfg.get("request_batch_size"), + load_in_4bit=load_in_4bit, + load_in_8bit=load_in_8bit, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + flash_attention_mode=flash_attention_mode, ) return instance diff --git a/foretrieval/server/__init__.py b/foretrieval/server/__init__.py new file mode 100644 index 0000000..28e776c --- /dev/null +++ b/foretrieval/server/__init__.py @@ -0,0 +1 @@ +from .embedding_server import create_app diff --git a/foretrieval/server/embedding_server.py b/foretrieval/server/embedding_server.py new file mode 100644 index 0000000..6ab0d5f --- /dev/null +++ b/foretrieval/server/embedding_server.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import asyncio +import gc +import io +from typing import Dict, Optional, Tuple + +from fastapi import FastAPI, HTTPException, Request, Response +from PIL import Image + +from ..client.transport import dumps_payload, loads_payload +from ..retriever import MultiModalRetrieverModel + + +def create_app( + model_name: str, + device: str = "cuda", + hf_token: Optional[str] = None, + verbose: int = 1, + load_in_4bit: bool = False, + load_in_8bit: bool = False, + bnb_4bit_quant_type: str = "nf4", + bnb_4bit_compute_dtype: str = "float16", + flash_attention_mode: str = "auto", + max_inflight: int = 1, + single_model_cache: bool = False, +) -> FastAPI: + app = FastAPI(title="FORetrieval Embedding Server") + inflight_semaphore = asyncio.Semaphore(max_inflight) + cache_lock = asyncio.Lock() + + retriever_cache: Dict[str, MultiModalRetrieverModel] = {} + retriever_in_use: Dict[int, int] = {} + pending_releases: Dict[int, MultiModalRetrieverModel] = {} + + def _release_retriever(retriever: MultiModalRetrieverModel) -> None: + # Best effort release: drop strong refs, force GC, then flush CUDA allocator cache. + del retriever + gc.collect() + try: + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if hasattr(torch.cuda, "ipc_collect"): + torch.cuda.ipc_collect() + except Exception: + # Cache eviction should not fail request processing. + pass + + def _schedule_or_release_locked(retriever: MultiModalRetrieverModel) -> None: + rid = id(retriever) + if retriever_in_use.get(rid, 0) > 0: + pending_releases[rid] = retriever + return + _release_retriever(retriever) + + async def _acquire_retriever(requested_model: Optional[str]) -> Tuple[MultiModalRetrieverModel, int]: + target_model = requested_model or model_name + async with cache_lock: + cached = retriever_cache.get(target_model) + if cached is None: + if single_model_cache and retriever_cache: + old_retrievers = list(retriever_cache.values()) + retriever_cache.clear() + for old in old_retrievers: + _schedule_or_release_locked(old) + + cached = MultiModalRetrieverModel.from_pretrained( + target_model, + device=device, + verbose=verbose, + hf_token=hf_token, + embedding_mode="local", + load_in_4bit=load_in_4bit, + load_in_8bit=load_in_8bit, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + flash_attention_mode=flash_attention_mode, + ) + retriever_cache[target_model] = cached + + rid = id(cached) + retriever_in_use[rid] = retriever_in_use.get(rid, 0) + 1 + return cached, rid + + async def _release_retriever_use(rid: int) -> None: + to_release = None + async with cache_lock: + current = retriever_in_use.get(rid, 0) + if current <= 1: + retriever_in_use.pop(rid, None) + to_release = pending_releases.pop(rid, None) + else: + retriever_in_use[rid] = current - 1 + if to_release is not None: + _release_retriever(to_release) + + @app.get("/health") + def health(): + return { + "status": "ok", + "default_model": model_name, + "loaded_models": list(retriever_cache.keys()), + "single_model_cache": single_model_cache, + "in_use_retrievers": len(retriever_in_use), + "pending_evictions": len(pending_releases), + "device": str(device), + "max_inflight": max_inflight, + } + + @app.post("/v1/embed/images") + async def embed_images(request: Request): + try: + payload = loads_payload(await request.body()) + async with inflight_semaphore: + retriever, rid = await _acquire_retriever(payload.get("model")) + images = [Image.open(io.BytesIO(b)).convert("RGB") for b in payload["images"]] + try: + embeddings = retriever.model.encode_image(images) + finally: + await _release_retriever_use(rid) + return Response( + content=dumps_payload({"embeddings": embeddings.cpu()}), + media_type="application/octet-stream", + ) + except Exception as exc: + raise HTTPException(status_code=400, detail=f"Invalid image payload: {exc}") from exc + + @app.post("/v1/embed/queries") + async def embed_queries(request: Request): + try: + payload = loads_payload(await request.body()) + async with inflight_semaphore: + retriever, rid = await _acquire_retriever(payload.get("model")) + queries = payload["queries"] + try: + embeddings = retriever.model.encode_query(queries) + finally: + await _release_retriever_use(rid) + return Response( + content=dumps_payload({"embeddings": embeddings.cpu()}), + media_type="application/octet-stream", + ) + except Exception as exc: + raise HTTPException(status_code=400, detail=f"Invalid query payload: {exc}") from exc + + return app diff --git a/foretrieval/server/server_main.py b/foretrieval/server/server_main.py new file mode 100644 index 0000000..e9edbdb --- /dev/null +++ b/foretrieval/server/server_main.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import os + +from .embedding_server import create_app + + +def _as_bool(name: str, default: bool = False) -> bool: + val = os.environ.get(name) + if val is None: + return default + return val.strip().lower() in {"1", "true", "yes", "on"} + + +def _as_positive_int(name: str, default: int) -> int: + val = os.environ.get(name) + if val is None: + return default + try: + parsed = int(val) + except ValueError as exc: + raise ValueError(f"{name} must be an integer, got {val!r}.") from exc + if parsed < 1: + raise ValueError(f"{name} must be >= 1, got {parsed}.") + return parsed + + +def _as_optional_str(name: str) -> str | None: + val = os.environ.get(name) + if val is None: + return None + stripped = val.strip() + return stripped if stripped else None + + +def _as_flash_attention_mode(name: str, default: str = "auto") -> str: + val = os.environ.get(name) + if val is None: + return default + mode = val.strip().lower() + aliases = { + "true": "on", + "1": "on", + "yes": "on", + "false": "off", + "0": "off", + "no": "off", + } + mode = aliases.get(mode, mode) + if mode not in {"auto", "on", "off"}: + raise ValueError(f"{name} must be one of auto/on/off (or boolean aliases), got {val!r}.") + return mode + + +MODEL_NAME = os.environ.get("FOR_EMBED_MODEL", "vidore/colqwen2-v1.0") +DEVICE = os.environ.get("FOR_EMBED_DEVICE", "cuda") +VERBOSE = int(os.environ.get("FOR_EMBED_VERBOSE", "1")) +HF_TOKEN = _as_optional_str("HF_TOKEN") +LOAD_IN_4BIT = _as_bool("FOR_EMBED_LOAD_IN_4BIT", False) +LOAD_IN_8BIT = _as_bool("FOR_EMBED_LOAD_IN_8BIT", False) +BNB_4BIT_QUANT_TYPE = os.environ.get("FOR_EMBED_BNB_4BIT_QUANT_TYPE", "nf4") +BNB_4BIT_COMPUTE_DTYPE = os.environ.get("FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE", "float16") +FLASH_ATTENTION_MODE = _as_flash_attention_mode("FOR_EMBED_FLASH_ATTENTION", "auto") +MAX_INFLIGHT = _as_positive_int("FOR_SERVER_MAX_INFLIGHT", 1) +SINGLE_MODEL_CACHE = _as_bool("FOR_SERVER_SINGLE_MODEL_CACHE", False) + +app = create_app( + model_name=MODEL_NAME, + device=DEVICE, + verbose=VERBOSE, + hf_token=HF_TOKEN, + load_in_4bit=LOAD_IN_4BIT, + load_in_8bit=LOAD_IN_8BIT, + bnb_4bit_quant_type=BNB_4BIT_QUANT_TYPE, + bnb_4bit_compute_dtype=BNB_4BIT_COMPUTE_DTYPE, + flash_attention_mode=FLASH_ATTENTION_MODE, + max_inflight=MAX_INFLIGHT, + single_model_cache=SINGLE_MODEL_CACHE, +) diff --git a/pyproject.toml b/pyproject.toml index c640dfd..3ab76cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,13 +27,14 @@ dependencies = [ "colpali-engine>=0.3.4,<0.4.0", "docx2pdf>=0.1.8", "langdetect>=1.0.9", + "matplotlib>=3.8.0", "ml-dtypes", "pydantic-ai>=1.4.0", "pydantic-ai-slim[openai]>=1.8.0", "pypdf>=6.1.3", "srsly", "torch>=2.7.1", - "transformers>=4.42.0", + "transformers>=4.42.0, <5", "pdf2image>=1.17.0", ] @@ -42,7 +43,11 @@ dependencies = [ qdrant = ["qdrant-client>=1.17.1"] # Docling-based PDF chunking ingestion pipeline docling = ["docling>=2.76.0"] -dev = ["pytest>=7.4.0", "ruff>=0.1.9"] +dev = ["pytest>=7.4.0", "pytest-cov>=5.0.0", "ruff>=0.1.9"] +server = ["uvicorn", "fastapi"] +client = ["httpx>=0.27.0"] +full = ["httpx>=0.27.0", "uvicorn", "fastapi"] +quant = ["bitsandbytes>=0.43"] langchain = ["langchain-core"] extra_converters = [ "docx2pdf>=0.1.8; sys_platform == \"win32\"", diff --git a/scripts/run-docker.sh b/scripts/run-docker.sh new file mode 100755 index 0000000..a42fed8 --- /dev/null +++ b/scripts/run-docker.sh @@ -0,0 +1,76 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Image selection (GHCR-friendly defaults) +REGISTRY="${REGISTRY:-ghcr.io}" +OWNER="${OWNER:-random-plm}" +IMAGE_NAME="${IMAGE_NAME:-foretrieval-server}" +IMAGE_TAG="${IMAGE_TAG:-v0.0.1}" +IMAGE="${IMAGE:-${REGISTRY}/${OWNER}/${IMAGE_NAME}:${IMAGE_TAG}}" +GITHUB_PAT="${GITHUB_PAT:-}" + +# Container/runtime settings +CONTAINER_NAME="${CONTAINER_NAME:-foretrieval-server}" +HOST_PORT="${HOST_PORT:-8000}" +CONTAINER_PORT="${CONTAINER_PORT:-8000}" +HF_CACHE_DIR="${HF_CACHE_DIR:-/models/foretrieval}" + +# GPU pinning +CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" + +# Server env vars +FOR_EMBED_MODEL="${FOR_EMBED_MODEL:-vidore/colqwen2-v1.0}" +FOR_EMBED_DEVICE="${FOR_EMBED_DEVICE:-cuda}" +FOR_EMBED_VERBOSE="${FOR_EMBED_VERBOSE:-1}" +FOR_SERVER_WORKERS="${FOR_SERVER_WORKERS:-1}" +FOR_SERVER_MAX_INFLIGHT="${FOR_SERVER_MAX_INFLIGHT:-1}" +FOR_SERVER_SINGLE_MODEL_CACHE="${FOR_SERVER_SINGLE_MODEL_CACHE:-true}" +FOR_EMBED_LOAD_IN_4BIT="${FOR_EMBED_LOAD_IN_4BIT:-false}" +FOR_EMBED_LOAD_IN_8BIT="${FOR_EMBED_LOAD_IN_8BIT:-false}" +FOR_EMBED_BNB_4BIT_QUANT_TYPE="${FOR_EMBED_BNB_4BIT_QUANT_TYPE:-nf4}" +FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE="${FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE:-float16}" +HF_TOKEN="${HF_TOKEN:-}" + +echo "Starting container '${CONTAINER_NAME}' from image '${IMAGE}'..." +echo "GPU pinning: CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}" +echo "HF cache mount: ${HF_CACHE_DIR} -> /root/.cache/huggingface" + +if [[ -z "${GITHUB_PAT}" ]]; then + echo "Error: GITHUB_PAT is not set." + echo "Set it before running, e.g.:" + echo " export GITHUB_PAT=ghp_xxx" + exit 1 +fi + +printf '%s' "${GITHUB_PAT}" | docker login ghcr.io -u "${OWNER}" --password-stdin >/dev/null +echo "Authenticated to ghcr.io as ${OWNER}." + +mkdir -p "${HF_CACHE_DIR}" + +env_args=( + -e CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES}" + -e FOR_EMBED_MODEL="${FOR_EMBED_MODEL}" + -e FOR_EMBED_DEVICE="${FOR_EMBED_DEVICE}" + -e FOR_EMBED_VERBOSE="${FOR_EMBED_VERBOSE}" + -e FOR_SERVER_WORKERS="${FOR_SERVER_WORKERS}" + -e FOR_SERVER_MAX_INFLIGHT="${FOR_SERVER_MAX_INFLIGHT}" + -e FOR_SERVER_SINGLE_MODEL_CACHE="${FOR_SERVER_SINGLE_MODEL_CACHE}" + -e FOR_EMBED_LOAD_IN_4BIT="${FOR_EMBED_LOAD_IN_4BIT}" + -e FOR_EMBED_LOAD_IN_8BIT="${FOR_EMBED_LOAD_IN_8BIT}" + -e FOR_EMBED_BNB_4BIT_QUANT_TYPE="${FOR_EMBED_BNB_4BIT_QUANT_TYPE}" + -e FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE="${FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE}" +) + +if [[ -n "${HF_TOKEN}" ]]; then + env_args+=(-e HF_TOKEN="${HF_TOKEN}") +else + echo "HF_TOKEN is empty; not passing it to the container." +fi + +docker run --rm -it \ + --name "${CONTAINER_NAME}" \ + --gpus "device=${CUDA_VISIBLE_DEVICES}" \ + -p "${HOST_PORT}:${CONTAINER_PORT}" \ + -v "${HF_CACHE_DIR}:/root/.cache/huggingface" \ + "${env_args[@]}" \ + "${IMAGE}" diff --git a/tests/test_colpali.py b/tests/test_colpali.py index 04da413..5e86d7e 100644 --- a/tests/test_colpali.py +++ b/tests/test_colpali.py @@ -12,7 +12,13 @@ def colpali_rag_model() -> Generator[MultiModalRetrieverModel, None, None]: device = get_torch_device("auto") print(f"Using device: {device}") - yield MultiModalRetrieverModel.from_pretrained("vidore/colpali-v1.3", device=device) + yield MultiModalRetrieverModel.from_pretrained( + "vidore/colpali-v1.3", + device=device, + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype="float16", + ) tear_down_torch() diff --git a/tests/test_colqwen.py b/tests/test_colqwen.py index 6b20283..449e53f 100644 --- a/tests/test_colqwen.py +++ b/tests/test_colqwen.py @@ -13,7 +13,11 @@ def colqwen_rag_model() -> Generator[MultiModalRetrieverModel, None, None]: device = get_torch_device("auto") print(f"Using device: {device}") yield MultiModalRetrieverModel.from_pretrained( - "vidore/colqwen2.5-v0.2", device=device + "vidore/colqwen2.5-v0.2", + device=device, + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype="float16", ) tear_down_torch() diff --git a/tests/test_quantization.py b/tests/test_quantization.py new file mode 100644 index 0000000..d246b43 --- /dev/null +++ b/tests/test_quantization.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import pytest +import torch +from transformers import BitsAndBytesConfig + +from foretrieval.colpali import ColPaliModel +from foretrieval.retriever import MultiModalRetrieverModel + + +class _DummyModel: + def eval(self): + return self + + def to(self, _device): + return self + + +class _DummyProcessor: + pass + + +def test_quant_modes_mutually_exclusive(): + with pytest.raises(ValueError, match="Only one quantization mode"): + ColPaliModel.from_pretrained( + "vidore/colpali-v1.3", + device="cpu", + load_in_4bit=True, + load_in_8bit=True, + ) + + +def test_invalid_4bit_compute_dtype(monkeypatch): + import foretrieval.colpali as colpali_module + + monkeypatch.setattr(colpali_module.importlib.util, "find_spec", lambda _name: object()) + + with pytest.raises(ValueError, match="Invalid bnb_4bit_compute_dtype"): + ColPaliModel.from_pretrained( + "vidore/colpali-v1.3", + device="cpu", + load_in_4bit=True, + bnb_4bit_compute_dtype="float64", + ) + + +def test_wrapper_forwards_quant_args(monkeypatch): + import foretrieval.retriever as retriever_module + + captured = {} + + def _fake_from_pretrained(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + return object() + + monkeypatch.setattr(retriever_module.ColPaliModel, "from_pretrained", _fake_from_pretrained) + + _ = MultiModalRetrieverModel.from_pretrained( + "vidore/colpali-v1.3", + device="cpu", + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype="float16", + ) + + assert captured["args"][0] == "vidore/colpali-v1.3" + assert captured["kwargs"]["load_in_4bit"] is True + assert captured["kwargs"]["load_in_8bit"] is False + assert captured["kwargs"]["bnb_4bit_quant_type"] == "nf4" + assert captured["kwargs"]["bnb_4bit_compute_dtype"] == "float16" + + +@pytest.mark.parametrize( + ("loader_attr", "quant_kwargs", "expected"), + [ + ( + "ColPali", + {"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": "float16"}, + {"load_in_4bit": True, "load_in_8bit": False, "dtype": torch.float16}, + ), + ( + "ColPali", + {"load_in_8bit": True}, + {"load_in_4bit": False, "load_in_8bit": True, "dtype": None}, + ), + ], +) +def test_quantization_config_is_passed(monkeypatch, loader_attr, quant_kwargs, expected): + import foretrieval.colpali as colpali_module + + monkeypatch.setattr(colpali_module.importlib.util, "find_spec", lambda _name: object()) + + captured = {} + + def _fake_model_from_pretrained(*_args, **kwargs): + captured["model_kwargs"] = kwargs + return _DummyModel() + + def _fake_processor_from_pretrained(*_args, **_kwargs): + return _DummyProcessor() + + monkeypatch.setattr(getattr(colpali_module, loader_attr), "from_pretrained", _fake_model_from_pretrained) + monkeypatch.setattr(colpali_module.ColPaliProcessor, "from_pretrained", _fake_processor_from_pretrained) + + _ = ColPaliModel.from_pretrained("vidore/colpali-v1.3", device="cpu", **quant_kwargs) + + qcfg = captured["model_kwargs"]["quantization_config"] + assert isinstance(qcfg, BitsAndBytesConfig) + assert qcfg.load_in_4bit is expected["load_in_4bit"] + assert qcfg.load_in_8bit is expected["load_in_8bit"] + if expected["dtype"] is not None: + assert qcfg.bnb_4bit_compute_dtype == expected["dtype"] + + +def test_quantization_requires_bitsandbytes(monkeypatch): + import foretrieval.colpali as colpali_module + + monkeypatch.setattr(colpali_module.importlib.util, "find_spec", lambda _name: None) + + with pytest.raises(ImportError, match="bitsandbytes"): + ColPaliModel.from_pretrained( + "vidore/colpali-v1.3", + device="cpu", + load_in_4bit=True, + ) diff --git a/tests/test_remote_mode.py b/tests/test_remote_mode.py new file mode 100644 index 0000000..c720363 --- /dev/null +++ b/tests/test_remote_mode.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import torch + +from foretrieval.colpali import ColPaliModel +from foretrieval.retriever import MultiModalRetrieverModel + + +def test_wrapper_forwards_remote_mode_args(monkeypatch): + import foretrieval.retriever as retriever_module + + captured = {} + + def _fake_from_pretrained(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + return object() + + monkeypatch.setattr( + retriever_module.ColPaliModel, "from_pretrained", _fake_from_pretrained + ) + + _ = MultiModalRetrieverModel.from_pretrained( + "vidore/colqwen2-v1.0", + embedding_mode="remote", + remote={ + "url": "http://localhost:8000", + "token": "abc", + "request_timeout": 12.0, + "verify_ssl": False, + }, + ) + + assert captured["kwargs"]["embedding_mode"] == "remote" + assert captured["kwargs"]["embedding_server_url"] == "http://localhost:8000" + assert captured["kwargs"]["embedding_server_token"] == "abc" + assert captured["kwargs"]["embedding_request_timeout"] == 12.0 + assert captured["kwargs"]["embedding_verify_ssl"] is False + + +def test_remote_mode_requires_server_url(monkeypatch): + import foretrieval.colpali as colpali_module + + with monkeypatch.context() as m: + m.setattr(colpali_module.importlib.util, "find_spec", lambda _name: None) + try: + ColPaliModel.from_pretrained("vidore/colqwen2-v1.0", embedding_mode="remote") + assert False, "Expected ValueError" + except ValueError as exc: + assert "embedding_server_url" in str(exc) + + +def test_embed_queries_uses_remote_client(monkeypatch): + import foretrieval.colpali as colpali_module + + class _DummyProcessor: + pass + + class _FakeRemoteClient: + def __init__(self, *_args, **_kwargs): + pass + + def encode_queries(self, queries): + assert queries == ["hello"] + return torch.randn(1, 4, 8) + + def encode_images(self, _images): + return torch.randn(1, 4, 8) + + monkeypatch.setattr(colpali_module, "RemoteEmbeddingClient", _FakeRemoteClient) + monkeypatch.setattr( + colpali_module.ColQwen2Processor, + "from_pretrained", + lambda *_args, **_kwargs: _DummyProcessor(), + ) + + model = ColPaliModel.from_pretrained( + "vidore/colqwen2-v1.0", + embedding_mode="remote", + embedding_server_url="http://localhost:8000", + device="cpu", + ) + out = model.encode_query("hello") + assert isinstance(out, torch.Tensor) + assert out.shape[0] == 1 + + +def test_embedding_verify_ssl_disables_remote_ssl_verification(monkeypatch): + import foretrieval.colpali as colpali_module + + class _DummyProcessor: + pass + + captured = {} + + class _FakeRemoteClient: + def __init__(self, *_args, **kwargs): + captured["verify_ssl"] = kwargs.get("verify_ssl") + + def encode_queries(self, _queries): + return torch.randn(1, 4, 8) + + def encode_images(self, _images): + return torch.randn(1, 4, 8) + + monkeypatch.setattr(colpali_module, "RemoteEmbeddingClient", _FakeRemoteClient) + monkeypatch.setattr( + colpali_module.ColQwen2Processor, + "from_pretrained", + lambda *_args, **_kwargs: _DummyProcessor(), + ) + + ColPaliModel.from_pretrained( + "vidore/colqwen2-v1.0", + embedding_mode="remote", + embedding_server_url="https://localhost:8000", + embedding_verify_ssl=False, + device="cpu", + ) + + assert captured["verify_ssl"] is False diff --git a/uv.lock b/uv.lock index af0739e..33399ae 100644 --- a/uv.lock +++ b/uv.lock @@ -1068,9 +1068,11 @@ name = "foretrieval" version = "0.1" source = { editable = "." } dependencies = [ + { name = "bitsandbytes", marker = "extra == 'quant'" }, { name = "colpali-engine" }, { name = "docx2pdf" }, { name = "langdetect" }, + { name = "matplotlib" }, { name = "ml-dtypes" }, { name = "pdf2image" }, { name = "pydantic-ai" }, @@ -1082,6 +1084,9 @@ dependencies = [ ] [package.optional-dependencies] +client = [ + { name = "httpx" }, +] dev = [ { name = "pytest" }, { name = "ruff" }, @@ -1095,12 +1100,20 @@ extra-converters = [ { name = "python-pptx" }, { name = "reportlab" }, ] +full = [ + { name = "fastapi" }, + { name = "httpx" }, + { name = "uvicorn" }, +] langchain = [ { name = "langchain-core" }, ] qdrant = [ { name = "qdrant-client" }, ] +quant = [ + { name = "bitsandbytes" }, +] server = [ { name = "fastapi" }, { name = "uvicorn" }, @@ -1113,13 +1126,18 @@ dev = [ [package.metadata] requires-dist = [ + { name = "bitsandbytes", marker = "extra == 'quant'", specifier = ">=0.43" }, { name = "colpali-engine", specifier = ">=0.3.4,<0.4.0" }, { name = "docling", marker = "extra == 'docling'", specifier = ">=2.76.0" }, { name = "docx2pdf", specifier = ">=0.1.8" }, { name = "docx2pdf", marker = "sys_platform == 'win32' and extra == 'extra-converters'", specifier = ">=0.1.8" }, + { name = "fastapi", marker = "extra == 'full'" }, { name = "fastapi", marker = "extra == 'server'" }, + { name = "httpx", marker = "extra == 'client'", specifier = ">=0.27.0" }, + { name = "httpx", marker = "extra == 'full'", specifier = ">=0.27.0" }, { name = "langchain-core", marker = "extra == 'langchain'" }, { name = "langdetect", specifier = ">=1.0.9" }, + { name = "matplotlib", specifier = ">=3.8.0" }, { name = "ml-dtypes" }, { name = "pdf2image", specifier = ">=1.17.0" }, { name = "pydantic-ai", specifier = ">=1.4.0" }, @@ -1134,8 +1152,10 @@ requires-dist = [ { name = "srsly" }, { name = "torch", specifier = ">=2.7.1" }, { name = "transformers", specifier = ">=4.42.0" }, + { name = "uvicorn", marker = "extra == 'full'" }, { name = "uvicorn", marker = "extra == 'server'" }, ] +provides-extras = ["dev", "server", "client", "full", "quant", "langchain", "extra-converters"] [package.metadata.requires-dev] dev = [{ name = "matplotlib", specifier = ">=3.10.8" }]