From c46c4835af2667de19f3a26474f49ba17d67a837 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pascal=20Le=20M=C3=A9tayer?= Date: Tue, 14 Apr 2026 15:41:08 +0200 Subject: [PATCH 1/7] Add 4-bit and 8-bit model quantization support --- .gitignore | 3 +- foretrieval/colpali.py | 42 ++++++++++++- foretrieval/retriever.py | 8 +++ pyproject.toml | 3 + tests/test_quantization.py | 126 +++++++++++++++++++++++++++++++++++++ uv.lock | 8 +++ 6 files changed, 188 insertions(+), 2 deletions(-) create mode 100644 tests/test_quantization.py diff --git a/.gitignore b/.gitignore index 10ef763..c0a3523 100644 --- a/.gitignore +++ b/.gitignore @@ -138,4 +138,5 @@ local/ pylate/* pylate/ -.byaldi/ \ No newline at end of file +.byaldi/ + diff --git a/foretrieval/colpali.py b/foretrieval/colpali.py index 1530737..69c47d9 100644 --- a/foretrieval/colpali.py +++ b/foretrieval/colpali.py @@ -1,4 +1,5 @@ import base64 +import importlib.util import io import logging import os @@ -13,6 +14,7 @@ from pdf2image import convert_from_path from PIL import Image import torch +from transformers import BitsAndBytesConfig try: from qdrant_client import QdrantClient @@ -126,6 +128,43 @@ def __init__( self.qdrant_collection = index_name self.qdrant_path = None + load_in_4bit = bool(kwargs.pop("load_in_4bit", False)) + load_in_8bit = bool(kwargs.pop("load_in_8bit", False)) + bnb_4bit_quant_type = str(kwargs.pop("bnb_4bit_quant_type", "nf4")) + bnb_4bit_compute_dtype = str(kwargs.pop("bnb_4bit_compute_dtype", "float16")) + + if load_in_4bit and load_in_8bit: + raise ValueError("Only one quantization mode can be enabled: 4-bit or 8-bit.") + + quantization_config = None + if load_in_4bit or load_in_8bit: + if importlib.util.find_spec("bitsandbytes") is None: + raise ImportError( + "Quantization requested but `bitsandbytes` is not installed. " + "Install it with `pip install bitsandbytes`." + ) + if load_in_4bit: + compute_dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + if bnb_4bit_compute_dtype not in compute_dtype_map: + raise ValueError( + "Invalid bnb_4bit_compute_dtype. Expected one of: " + "'float16', 'bfloat16', 'float32'." + ) + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_compute_dtype=compute_dtype_map[bnb_4bit_compute_dtype], + ) + else: + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + self._load_in_4bit = load_in_4bit + self._load_in_8bit = load_in_8bit + self._quantization_config = quantization_config + if self.storage_qdrant and self.index_name is not None: self.qdrant_path = Path(self.index_root) / self.index_name / "qdrant" self.qdrant_client = QdrantClient(path=str(self.qdrant_path)) @@ -165,6 +204,7 @@ def _load_model_and_processor(self): self.pretrained_model_name_or_path, torch_dtype=torch.bfloat16, device_map=device_map, + quantization_config=self._quantization_config, token=token, ) self.processor = processor_cls.from_pretrained( @@ -173,7 +213,7 @@ def _load_model_and_processor(self): ) self.model = self.model.eval() - if device_map is None: + if device_map is None and not (self._load_in_4bit or self._load_in_8bit): self.model = self.model.to(self.device) def _load_index_state(self): diff --git a/foretrieval/retriever.py b/foretrieval/retriever.py index 0b76b77..f8cb581 100644 --- a/foretrieval/retriever.py +++ b/foretrieval/retriever.py @@ -46,6 +46,10 @@ def from_pretrained( ingestion: Dict[str, Any] = {"backend": "default"}, device: str = "cuda", verbose: int = 1, + load_in_4bit: bool = False, + load_in_8bit: bool = False, + bnb_4bit_quant_type: str = "nf4", + bnb_4bit_compute_dtype: str = "float16", ): """Load a ColPali model from a pre-trained checkpoint. @@ -66,6 +70,10 @@ def from_pretrained( ingestion=ingestion, device=device, verbose=verbose, + load_in_4bit=load_in_4bit, + load_in_8bit=load_in_8bit, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, ) return instance diff --git a/pyproject.toml b/pyproject.toml index c640dfd..a284878 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "colpali-engine>=0.3.4,<0.4.0", "docx2pdf>=0.1.8", "langdetect>=1.0.9", + "matplotlib>=3.8.0", "ml-dtypes", "pydantic-ai>=1.4.0", "pydantic-ai-slim[openai]>=1.8.0", @@ -43,6 +44,8 @@ qdrant = ["qdrant-client>=1.17.1"] # Docling-based PDF chunking ingestion pipeline docling = ["docling>=2.76.0"] dev = ["pytest>=7.4.0", "ruff>=0.1.9"] +server = ["uvicorn", "fastapi"] +quant = ["bitsandbytes>=0.43"] langchain = ["langchain-core"] extra_converters = [ "docx2pdf>=0.1.8; sys_platform == \"win32\"", diff --git a/tests/test_quantization.py b/tests/test_quantization.py new file mode 100644 index 0000000..d246b43 --- /dev/null +++ b/tests/test_quantization.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import pytest +import torch +from transformers import BitsAndBytesConfig + +from foretrieval.colpali import ColPaliModel +from foretrieval.retriever import MultiModalRetrieverModel + + +class _DummyModel: + def eval(self): + return self + + def to(self, _device): + return self + + +class _DummyProcessor: + pass + + +def test_quant_modes_mutually_exclusive(): + with pytest.raises(ValueError, match="Only one quantization mode"): + ColPaliModel.from_pretrained( + "vidore/colpali-v1.3", + device="cpu", + load_in_4bit=True, + load_in_8bit=True, + ) + + +def test_invalid_4bit_compute_dtype(monkeypatch): + import foretrieval.colpali as colpali_module + + monkeypatch.setattr(colpali_module.importlib.util, "find_spec", lambda _name: object()) + + with pytest.raises(ValueError, match="Invalid bnb_4bit_compute_dtype"): + ColPaliModel.from_pretrained( + "vidore/colpali-v1.3", + device="cpu", + load_in_4bit=True, + bnb_4bit_compute_dtype="float64", + ) + + +def test_wrapper_forwards_quant_args(monkeypatch): + import foretrieval.retriever as retriever_module + + captured = {} + + def _fake_from_pretrained(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + return object() + + monkeypatch.setattr(retriever_module.ColPaliModel, "from_pretrained", _fake_from_pretrained) + + _ = MultiModalRetrieverModel.from_pretrained( + "vidore/colpali-v1.3", + device="cpu", + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype="float16", + ) + + assert captured["args"][0] == "vidore/colpali-v1.3" + assert captured["kwargs"]["load_in_4bit"] is True + assert captured["kwargs"]["load_in_8bit"] is False + assert captured["kwargs"]["bnb_4bit_quant_type"] == "nf4" + assert captured["kwargs"]["bnb_4bit_compute_dtype"] == "float16" + + +@pytest.mark.parametrize( + ("loader_attr", "quant_kwargs", "expected"), + [ + ( + "ColPali", + {"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": "float16"}, + {"load_in_4bit": True, "load_in_8bit": False, "dtype": torch.float16}, + ), + ( + "ColPali", + {"load_in_8bit": True}, + {"load_in_4bit": False, "load_in_8bit": True, "dtype": None}, + ), + ], +) +def test_quantization_config_is_passed(monkeypatch, loader_attr, quant_kwargs, expected): + import foretrieval.colpali as colpali_module + + monkeypatch.setattr(colpali_module.importlib.util, "find_spec", lambda _name: object()) + + captured = {} + + def _fake_model_from_pretrained(*_args, **kwargs): + captured["model_kwargs"] = kwargs + return _DummyModel() + + def _fake_processor_from_pretrained(*_args, **_kwargs): + return _DummyProcessor() + + monkeypatch.setattr(getattr(colpali_module, loader_attr), "from_pretrained", _fake_model_from_pretrained) + monkeypatch.setattr(colpali_module.ColPaliProcessor, "from_pretrained", _fake_processor_from_pretrained) + + _ = ColPaliModel.from_pretrained("vidore/colpali-v1.3", device="cpu", **quant_kwargs) + + qcfg = captured["model_kwargs"]["quantization_config"] + assert isinstance(qcfg, BitsAndBytesConfig) + assert qcfg.load_in_4bit is expected["load_in_4bit"] + assert qcfg.load_in_8bit is expected["load_in_8bit"] + if expected["dtype"] is not None: + assert qcfg.bnb_4bit_compute_dtype == expected["dtype"] + + +def test_quantization_requires_bitsandbytes(monkeypatch): + import foretrieval.colpali as colpali_module + + monkeypatch.setattr(colpali_module.importlib.util, "find_spec", lambda _name: None) + + with pytest.raises(ImportError, match="bitsandbytes"): + ColPaliModel.from_pretrained( + "vidore/colpali-v1.3", + device="cpu", + load_in_4bit=True, + ) diff --git a/uv.lock b/uv.lock index af0739e..328c7b1 100644 --- a/uv.lock +++ b/uv.lock @@ -1068,9 +1068,11 @@ name = "foretrieval" version = "0.1" source = { editable = "." } dependencies = [ + { name = "bitsandbytes", marker = "extra == 'quant'" }, { name = "colpali-engine" }, { name = "docx2pdf" }, { name = "langdetect" }, + { name = "matplotlib" }, { name = "ml-dtypes" }, { name = "pdf2image" }, { name = "pydantic-ai" }, @@ -1101,6 +1103,9 @@ langchain = [ qdrant = [ { name = "qdrant-client" }, ] +quant = [ + { name = "bitsandbytes" }, +] server = [ { name = "fastapi" }, { name = "uvicorn" }, @@ -1113,6 +1118,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "bitsandbytes", marker = "extra == 'quant'", specifier = ">=0.43" }, { name = "colpali-engine", specifier = ">=0.3.4,<0.4.0" }, { name = "docling", marker = "extra == 'docling'", specifier = ">=2.76.0" }, { name = "docx2pdf", specifier = ">=0.1.8" }, @@ -1120,6 +1126,7 @@ requires-dist = [ { name = "fastapi", marker = "extra == 'server'" }, { name = "langchain-core", marker = "extra == 'langchain'" }, { name = "langdetect", specifier = ">=1.0.9" }, + { name = "matplotlib", specifier = ">=3.8.0" }, { name = "ml-dtypes" }, { name = "pdf2image", specifier = ">=1.17.0" }, { name = "pydantic-ai", specifier = ">=1.4.0" }, @@ -1136,6 +1143,7 @@ requires-dist = [ { name = "transformers", specifier = ">=4.42.0" }, { name = "uvicorn", marker = "extra == 'server'" }, ] +provides-extras = ["dev", "server", "quant", "langchain", "extra-converters"] [package.metadata.requires-dev] dev = [{ name = "matplotlib", specifier = ">=3.10.8" }] From b7149c343d9c6fed1c7ad3c896e847350c18642c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pascal=20Le=20M=C3=A9tayer?= Date: Thu, 16 Apr 2026 10:34:55 +0200 Subject: [PATCH 2/7] feat(architecture): introduce client-server split to offload inference to remote server --- .dockerignore | 16 +++ .gitignore | 2 + Dockerfile | 61 +++++++++++ Makefile | 53 +++++++++ foretrieval/client/__init__.py | 3 + foretrieval/client/embedding_backends.py | 62 +++++++++++ foretrieval/client/remote_backend.py | 64 +++++++++++ foretrieval/client/transport.py | 17 +++ foretrieval/colpali.py | 133 ++++++++++++++++++----- foretrieval/retriever.py | 10 ++ foretrieval/server/__init__.py | 1 + foretrieval/server/embedding_server.py | 86 +++++++++++++++ foretrieval/server/server_main.py | 56 ++++++++++ pyproject.toml | 4 +- scripts/run-docker.sh | 74 +++++++++++++ tests/test_colpali.py | 8 +- tests/test_colqwen.py | 6 +- tests/test_remote_mode.py | 81 ++++++++++++++ uv.lock | 14 ++- 19 files changed, 719 insertions(+), 32 deletions(-) create mode 100644 .dockerignore create mode 100644 Dockerfile create mode 100644 Makefile create mode 100644 foretrieval/client/__init__.py create mode 100644 foretrieval/client/embedding_backends.py create mode 100644 foretrieval/client/remote_backend.py create mode 100644 foretrieval/client/transport.py create mode 100644 foretrieval/server/__init__.py create mode 100644 foretrieval/server/embedding_server.py create mode 100644 foretrieval/server/server_main.py create mode 100755 scripts/run-docker.sh create mode 100644 tests/test_remote_mode.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..72455a4 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,16 @@ +.git +.github +.venv +.pytest_cache +.vscode +__pycache__ +*.pyc +*.pyo +*.pyd +*.egg-info +build +dist +tests +sample_data +tmp_forag_test +tmp_forag_test_remote diff --git a/.gitignore b/.gitignore index c0a3523..0156186 100644 --- a/.gitignore +++ b/.gitignore @@ -140,3 +140,5 @@ pylate/ .byaldi/ +.pat + diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..be1f41e --- /dev/null +++ b/Dockerfile @@ -0,0 +1,61 @@ +# syntax=docker/dockerfile:1.7 +FROM python:3.12-slim + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +WORKDIR /app + +# Copy only dependency metadata first for better layer caching +COPY pyproject.toml ./ + +# Install third-party runtime dependencies from pyproject extras. +# This layer is invalidated only when dependency metadata changes. +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install --upgrade pip setuptools wheel && \ + python - <<'PY' > /tmp/requirements-runtime.txt +import tomllib +from pathlib import Path + +cfg = tomllib.loads(Path("pyproject.toml").read_text()) +project = cfg["project"] +deps = list(project.get("dependencies", [])) +optional = project.get("optional-dependencies", {}) +deps.extend(optional.get("server", [])) +deps.extend(optional.get("quant", [])) + +# Stable order and dedupe for deterministic cache behavior. +seen = set() +ordered = [] +for dep in deps: + if dep not in seen: + seen.add(dep) + ordered.append(dep) +print("\n".join(ordered)) +PY +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install -r /tmp/requirements-runtime.txt + +# Copy package sources late so app code changes do not bust dependency layers +COPY README.md ./ +COPY foretrieval ./foretrieval + +# Install local package without re-resolving dependencies +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install --no-deps . + +# Server runtime defaults (override at `docker run -e ...`) +ENV FOR_EMBED_MODEL=vidore/colqwen2-v1.0 \ + FOR_EMBED_DEVICE=cpu \ + FOR_EMBED_VERBOSE=1 \ + FOR_SERVER_WORKERS=1 \ + FOR_SERVER_MAX_INFLIGHT=1 \ + FOR_EMBED_LOAD_IN_4BIT=false \ + FOR_EMBED_LOAD_IN_8BIT=false \ + FOR_EMBED_BNB_4BIT_QUANT_TYPE=nf4 \ + FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE=float16 + +EXPOSE 8000 + +CMD ["sh", "-c", "python -m uvicorn foretrieval.server.server_main:app --host 0.0.0.0 --port 8000 --workers ${FOR_SERVER_WORKERS:-1}"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..c643780 --- /dev/null +++ b/Makefile @@ -0,0 +1,53 @@ +REGISTRY ?= ghcr.io +OWNER ?= random-plm +IMAGE_NAME ?= foretrieval-server +IMAGE_TAG ?= v0.0.1 +IMAGE_REPO ?= $(REGISTRY)/$(OWNER)/$(IMAGE_NAME) +IMAGE ?= $(IMAGE_REPO):$(IMAGE_TAG) + +# Backward-compatible aliases. +NAMESPACE ?= $(OWNER) +PROJECT ?= $(IMAGE_NAME) +DOCKERFILE ?= Dockerfile +BUILD_CONTEXT ?= . +PYTEST ?= pytest +TEST_ARGS ?= -m "not slow and not integration" +SRC_PACKAGE ?= foretrieval + +.PHONY: check-buildx login-registry build publish run-server test-fast test-all coverage-fast coverage-all + +check-buildx: + @docker buildx version >/dev/null 2>&1 || { \ + echo "Error: docker buildx is not available."; \ + echo "Install Docker Buildx (or Docker Desktop), then run this command again."; \ + echo "Linux plugin package is often named 'docker-buildx-plugin'."; \ + exit 1; \ + } + +login-registry: + @test -n "$$GITHUB_PAT" || { \ + echo "Error: GITHUB_PAT is not set. Export a classic PAT with write:packages."; \ + exit 1; \ + } + @printf '%s' "$$GITHUB_PAT" | docker login ghcr.io -u $(OWNER) --password-stdin + +build: check-buildx + docker buildx build --load -f $(DOCKERFILE) -t $(IMAGE) $(BUILD_CONTEXT) + +publish: check-buildx login-registry + docker buildx build --push -f $(DOCKERFILE) -t $(IMAGE) $(BUILD_CONTEXT) + +run-server: + REGISTRY=$(REGISTRY) OWNER=$(OWNER) IMAGE_NAME=$(IMAGE_NAME) IMAGE_TAG=$(IMAGE_TAG) ./scripts/run-docker.sh + +test-fast: + $(PYTEST) $(TEST_ARGS) + +test-all: + $(PYTEST) + +coverage-fast: + $(PYTEST) $(TEST_ARGS) --cov=$(SRC_PACKAGE) --cov-report=term-missing --cov-report=xml + +coverage-all: + $(PYTEST) --cov=$(SRC_PACKAGE) --cov-report=term-missing --cov-report=xml diff --git a/foretrieval/client/__init__.py b/foretrieval/client/__init__.py new file mode 100644 index 0000000..f6ffba7 --- /dev/null +++ b/foretrieval/client/__init__.py @@ -0,0 +1,3 @@ +from .embedding_backends import LocalEmbeddingBackend, RemoteEmbeddingBackend +from .remote_backend import RemoteEmbeddingClient +from .transport import dumps_payload, loads_payload diff --git a/foretrieval/client/embedding_backends.py b/foretrieval/client/embedding_backends.py new file mode 100644 index 0000000..e492987 --- /dev/null +++ b/foretrieval/client/embedding_backends.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import List, Protocol, Union + +import torch +from PIL import Image + +from .remote_backend import RemoteEmbeddingClient + + +class EmbeddingBackend(Protocol): + def embed_images(self, images: List[Image.Image]) -> torch.Tensor: + ... + + def embed_queries(self, queries: List[str]) -> torch.Tensor: + ... + + +class LocalEmbeddingBackend: + def __init__(self, model, processor, device): + self.model = model + self.processor = processor + self.device = device + + def embed_images(self, images: List[Image.Image]) -> torch.Tensor: + with torch.inference_mode(): + batch = self.processor.process_images(images) + batch = { + k: v.to(self.device).to( + self.model.dtype + if v.dtype in [torch.float16, torch.bfloat16, torch.float32] + else v.dtype + ) + for k, v in batch.items() + } + embeddings = self.model(**batch) + return embeddings.cpu() + + def embed_queries(self, queries: List[str]) -> torch.Tensor: + with torch.inference_mode(): + batch = self.processor.process_queries(queries) + batch = { + k: v.to(self.device).to( + self.model.dtype + if v.dtype in [torch.float16, torch.bfloat16, torch.float32] + else v.dtype + ) + for k, v in batch.items() + } + embeddings = self.model(**batch) + return embeddings.cpu() + + +class RemoteEmbeddingBackend: + def __init__(self, client: RemoteEmbeddingClient): + self.client = client + + def embed_images(self, images: List[Image.Image]) -> torch.Tensor: + return self.client.encode_images(images).cpu() + + def embed_queries(self, queries: List[str]) -> torch.Tensor: + return self.client.encode_queries(queries).cpu() diff --git a/foretrieval/client/remote_backend.py b/foretrieval/client/remote_backend.py new file mode 100644 index 0000000..7c38726 --- /dev/null +++ b/foretrieval/client/remote_backend.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import io +from typing import List, Optional + +import torch +from PIL import Image + +from .transport import dumps_payload, loads_payload + + +class RemoteEmbeddingClient: + def __init__( + self, + server_url: str, + model_name: str, + token: Optional[str] = None, + timeout: float = 30.0, + ): + try: + import httpx + except ImportError as exc: + raise ImportError( + "Remote embedding mode requires `httpx`. Install with `pip install httpx`." + ) from exc + + self._httpx = httpx + self.server_url = server_url.rstrip("/") + self.model_name = model_name + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + self.client = httpx.Client(timeout=timeout, headers=headers) + + @staticmethod + def _image_to_bytes(image: Image.Image) -> bytes: + buf = io.BytesIO() + image.save(buf, format="PNG") + return buf.getvalue() + + def encode_images(self, images: List[Image.Image]) -> torch.Tensor: + payload = { + "model": self.model_name, + "images": [self._image_to_bytes(image.convert("RGB")) for image in images], + } + resp = self.client.post( + f"{self.server_url}/v1/embed/images", + content=dumps_payload(payload), + headers={"Content-Type": "application/octet-stream"}, + ) + resp.raise_for_status() + out = loads_payload(resp.content) + return out["embeddings"].cpu() + + def encode_queries(self, queries: List[str]) -> torch.Tensor: + payload = {"model": self.model_name, "queries": queries} + resp = self.client.post( + f"{self.server_url}/v1/embed/queries", + content=dumps_payload(payload), + headers={"Content-Type": "application/octet-stream"}, + ) + resp.raise_for_status() + out = loads_payload(resp.content) + return out["embeddings"].cpu() diff --git a/foretrieval/client/transport.py b/foretrieval/client/transport.py new file mode 100644 index 0000000..9db18ac --- /dev/null +++ b/foretrieval/client/transport.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import io +from typing import Any + +import torch + + +def dumps_payload(payload: Any) -> bytes: + buffer = io.BytesIO() + torch.save(payload, buffer) + return buffer.getvalue() + + +def loads_payload(data: bytes) -> Any: + buffer = io.BytesIO(data) + return torch.load(buffer, map_location="cpu") diff --git a/foretrieval/colpali.py b/foretrieval/colpali.py index 69c47d9..b6aa44d 100644 --- a/foretrieval/colpali.py +++ b/foretrieval/colpali.py @@ -43,6 +43,8 @@ from .objects import Result from .plot_utils import draw_circle_on_max_patch, pil_from_base64, pil_to_base64_png, compute_patch_heatmap, majority_token_id, build_heatmap_overlays_base64 from .utils import _value_match +from .client.embedding_backends import LocalEmbeddingBackend, RemoteEmbeddingBackend +from .client.remote_backend import RemoteEmbeddingClient VERSION = "0.0.1" @@ -123,6 +125,10 @@ def __init__( self.enable_circle = False self.SOURCE_EXTS = {".doc", ".docx", ".rtf", ".odt", ".ppt", ".pptx", ".odp", ".xls", ".xlsx", ".ods", ".txt", ".md", ".csv", ".json", ".yaml", ".yml", ".epub", ".html"} self.IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".gif"} + self.embedding_mode = str(kwargs.pop("embedding_mode", "local")) + self.embedding_server_url = kwargs.pop("embedding_server_url", None) + self.embedding_server_token = kwargs.pop("embedding_server_token", None) + self.embedding_request_timeout = float(kwargs.pop("embedding_request_timeout", 30.0)) self.qdrant_client = None self.qdrant_collection = index_name @@ -174,7 +180,25 @@ def __init__( self.docling_dir = Path(index_root) / self.index_name / "docling_chunks" self.docling_dir.mkdir(parents=True, exist_ok=True) - self._load_model_and_processor() + self.embedding_backend = None + if self.embedding_mode == "remote": + if not self.embedding_server_url: + raise ValueError("embedding_server_url is required when embedding_mode='remote'.") + self.remote_client = RemoteEmbeddingClient( + self.embedding_server_url, + model_name=self.pretrained_model_name_or_path, + token=self.embedding_server_token, + timeout=self.embedding_request_timeout, + ) + self.embedding_backend = RemoteEmbeddingBackend(self.remote_client) + self._load_remote_processor_only() + self.model = None + elif self.embedding_mode == "local": + self.remote_client = None + self._load_model_and_processor() + self.embedding_backend = LocalEmbeddingBackend(self.model, self.processor, self.device) + else: + raise ValueError("embedding_mode must be either 'local' or 'remote'.") if not load_from_index: self.full_document_collection = False @@ -216,6 +240,25 @@ def _load_model_and_processor(self): if device_map is None and not (self._load_in_4bit or self._load_in_8bit): self.model = self.model.to(self.device) + def _load_remote_processor_only(self): + token = self.kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN") + model_name = self.pretrained_model_name_or_path.lower() + if "colpali" in model_name: + self.processor = ColPaliProcessor.from_pretrained( + self.pretrained_model_name_or_path, + token=token, + ) + elif "colqwen2.5" in model_name: + self.processor = ColQwen2_5_Processor.from_pretrained( + self.pretrained_model_name_or_path, + token=token, + ) + else: + self.processor = ColQwen2Processor.from_pretrained( + self.pretrained_model_name_or_path, + token=token, + ) + def _load_index_state(self): if self.index_name is None: raise ValueError("No index name specified. Cannot load from index.") @@ -899,17 +942,7 @@ def _add_to_index( # Optionnel mais utile : taille originale de l'image orig_sizes = [img.size for img in images] # (W,H) PIL - # Generate embeddings - with torch.inference_mode(): - processed_images = { - k: v.to(self.device).to( - self.model.dtype - if v.dtype in [torch.float16, torch.bfloat16, torch.float32] - else v.dtype - ) - for k, v in processed_images.items() - } - embeddings = self.model(**processed_images) + embeddings = self._embed_images(images) # 1. Compute embeddings # 2. Ensure backend storage and check duplicates @@ -1129,23 +1162,16 @@ def remove_from_index(self): # ============================================================ def _encode_search_query(self, query: str): - with torch.inference_mode(): - batch_query = self.processor.process_queries([query]) - batch_query = { - kk: vv.to(self.device).to( - self.model.dtype - if vv.dtype in [torch.float16, torch.bfloat16, torch.float32] - else vv.dtype - ) - for kk, vv in batch_query.items() - } - embeddings_query = self.model(**batch_query) - qs = list(torch.unbind(embeddings_query.to("cpu"))) + embeddings_query = self._embed_queries([query]) + qs = list(torch.unbind(embeddings_query.to("cpu"))) - input_ids = batch_query["input_ids"][0].detach().cpu().tolist() - tokens = self.processor.tokenizer.convert_ids_to_tokens(input_ids) - valid_idxs = [i for i, tok in enumerate(tokens) if tok not in {"<|endoftext|>", "Query", ":"}] - return [qs[0][valid_idxs]] + if self.embedding_mode == "local": + batch_query = self.processor.process_queries([query]) + input_ids = batch_query["input_ids"][0].detach().cpu().tolist() + tokens = self.processor.tokenizer.convert_ids_to_tokens(input_ids) + valid_idxs = [i for i, tok in enumerate(tokens) if tok not in {"<|endoftext|>", "Query", ":"}] + return [qs[0][valid_idxs]] + return qs def _search_local( self, @@ -1539,6 +1565,57 @@ def _post_process_image(self, image: Image.Image) -> str: # File helpers # ============================================================ + def encode_image( + self, input_data: Union[str, Image.Image, List[Union[str, Image.Image]]] + ) -> torch.Tensor: + if not isinstance(input_data, list): + input_data = [input_data] + + images = [] + for item in input_data: + if isinstance(item, Image.Image): + images.append(item) + elif isinstance(item, str): + if os.path.isdir(item): + for file in os.listdir(item): + if file.lower().endswith( + (".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif") + ): + images.append(Image.open(os.path.join(item, file))) + elif item.lower().endswith(".pdf"): + with tempfile.TemporaryDirectory() as path: + pdf_images = convert_from_path( + item, thread_count=os.cpu_count() - 1, output_folder=path + ) + images.extend(pdf_images) + elif item.lower().endswith( + (".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif") + ): + images.append(Image.open(item)) + else: + raise ValueError(f"Unsupported file type: {item}") + else: + raise ValueError(f"Unsupported input type: {type(item)}") + + return self._embed_images(images).cpu() + + def encode_query(self, query: Union[str, List[str]]) -> torch.Tensor: + if isinstance(query, str): + query = [query] + + return self._embed_queries(query).cpu() + + def _embed_images(self, images: List[Image.Image]) -> torch.Tensor: + if self.embedding_backend is None: + raise RuntimeError("Embedding backend is not initialized.") + return self.embedding_backend.embed_images(images) + + def _embed_queries(self, query: Union[str, List[str]]) -> torch.Tensor: + queries = [query] if isinstance(query, str) else query + if self.embedding_backend is None: + raise RuntimeError("Embedding backend is not initialized.") + return self.embedding_backend.embed_queries(queries) + def _looks_like_pdf(self, path: Path) -> bool: try: if not path.exists() or path.stat().st_size < 5: diff --git a/foretrieval/retriever.py b/foretrieval/retriever.py index f8cb581..03c992e 100644 --- a/foretrieval/retriever.py +++ b/foretrieval/retriever.py @@ -46,6 +46,11 @@ def from_pretrained( ingestion: Dict[str, Any] = {"backend": "default"}, device: str = "cuda", verbose: int = 1, + hf_token: Optional[str] = None, + embedding_mode: str = "local", + embedding_server_url: Optional[str] = None, + embedding_server_token: Optional[str] = None, + embedding_request_timeout: float = 30.0, load_in_4bit: bool = False, load_in_8bit: bool = False, bnb_4bit_quant_type: str = "nf4", @@ -70,6 +75,11 @@ def from_pretrained( ingestion=ingestion, device=device, verbose=verbose, + hf_token=hf_token, + embedding_mode=embedding_mode, + embedding_server_url=embedding_server_url, + embedding_server_token=embedding_server_token, + embedding_request_timeout=embedding_request_timeout, load_in_4bit=load_in_4bit, load_in_8bit=load_in_8bit, bnb_4bit_quant_type=bnb_4bit_quant_type, diff --git a/foretrieval/server/__init__.py b/foretrieval/server/__init__.py new file mode 100644 index 0000000..28e776c --- /dev/null +++ b/foretrieval/server/__init__.py @@ -0,0 +1 @@ +from .embedding_server import create_app diff --git a/foretrieval/server/embedding_server.py b/foretrieval/server/embedding_server.py new file mode 100644 index 0000000..8b20eed --- /dev/null +++ b/foretrieval/server/embedding_server.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import asyncio +import io +from typing import Dict, Optional + +from fastapi import FastAPI, HTTPException, Request, Response +from PIL import Image + +from ..client.transport import dumps_payload, loads_payload +from ..retriever import MultiModalRetrieverModel + + +def create_app( + model_name: str, + device: str = "cuda", + hf_token: Optional[str] = None, + verbose: int = 1, + load_in_4bit: bool = False, + load_in_8bit: bool = False, + bnb_4bit_quant_type: str = "nf4", + bnb_4bit_compute_dtype: str = "float16", + max_inflight: int = 1, +) -> FastAPI: + app = FastAPI(title="FORetrieval Embedding Server") + inflight_semaphore = asyncio.Semaphore(max_inflight) + + retriever_cache: Dict[str, MultiModalRetrieverModel] = {} + + def get_retriever(requested_model: Optional[str]) -> MultiModalRetrieverModel: + target_model = requested_model or model_name + if target_model not in retriever_cache: + retriever_cache[target_model] = MultiModalRetrieverModel.from_pretrained( + target_model, + device=device, + verbose=verbose, + hf_token=hf_token, + embedding_mode="local", + load_in_4bit=load_in_4bit, + load_in_8bit=load_in_8bit, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + ) + return retriever_cache[target_model] + + @app.get("/health") + def health(): + return { + "status": "ok", + "default_model": model_name, + "loaded_models": list(retriever_cache.keys()), + "device": str(device), + "max_inflight": max_inflight, + } + + @app.post("/v1/embed/images") + async def embed_images(request: Request): + try: + payload = loads_payload(await request.body()) + async with inflight_semaphore: + retriever = get_retriever(payload.get("model")) + images = [Image.open(io.BytesIO(b)).convert("RGB") for b in payload["images"]] + embeddings = retriever.model.encode_image(images) + return Response( + content=dumps_payload({"embeddings": embeddings.cpu()}), + media_type="application/octet-stream", + ) + except Exception as exc: + raise HTTPException(status_code=400, detail=f"Invalid image payload: {exc}") from exc + + @app.post("/v1/embed/queries") + async def embed_queries(request: Request): + try: + payload = loads_payload(await request.body()) + async with inflight_semaphore: + retriever = get_retriever(payload.get("model")) + queries = payload["queries"] + embeddings = retriever.model.encode_query(queries) + return Response( + content=dumps_payload({"embeddings": embeddings.cpu()}), + media_type="application/octet-stream", + ) + except Exception as exc: + raise HTTPException(status_code=400, detail=f"Invalid query payload: {exc}") from exc + + return app diff --git a/foretrieval/server/server_main.py b/foretrieval/server/server_main.py new file mode 100644 index 0000000..2fb5fa7 --- /dev/null +++ b/foretrieval/server/server_main.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import os + +from .embedding_server import create_app + + +def _as_bool(name: str, default: bool = False) -> bool: + val = os.environ.get(name) + if val is None: + return default + return val.strip().lower() in {"1", "true", "yes", "on"} + + +def _as_positive_int(name: str, default: int) -> int: + val = os.environ.get(name) + if val is None: + return default + try: + parsed = int(val) + except ValueError as exc: + raise ValueError(f"{name} must be an integer, got {val!r}.") from exc + if parsed < 1: + raise ValueError(f"{name} must be >= 1, got {parsed}.") + return parsed + + +def _as_optional_str(name: str) -> str | None: + val = os.environ.get(name) + if val is None: + return None + stripped = val.strip() + return stripped if stripped else None + + +MODEL_NAME = os.environ.get("FOR_EMBED_MODEL", "vidore/colqwen2-v1.0") +DEVICE = os.environ.get("FOR_EMBED_DEVICE", "cuda") +VERBOSE = int(os.environ.get("FOR_EMBED_VERBOSE", "1")) +HF_TOKEN = _as_optional_str("HF_TOKEN") +LOAD_IN_4BIT = _as_bool("FOR_EMBED_LOAD_IN_4BIT", False) +LOAD_IN_8BIT = _as_bool("FOR_EMBED_LOAD_IN_8BIT", False) +BNB_4BIT_QUANT_TYPE = os.environ.get("FOR_EMBED_BNB_4BIT_QUANT_TYPE", "nf4") +BNB_4BIT_COMPUTE_DTYPE = os.environ.get("FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE", "float16") +MAX_INFLIGHT = _as_positive_int("FOR_SERVER_MAX_INFLIGHT", 1) + +app = create_app( + model_name=MODEL_NAME, + device=DEVICE, + verbose=VERBOSE, + hf_token=HF_TOKEN, + load_in_4bit=LOAD_IN_4BIT, + load_in_8bit=LOAD_IN_8BIT, + bnb_4bit_quant_type=BNB_4BIT_QUANT_TYPE, + bnb_4bit_compute_dtype=BNB_4BIT_COMPUTE_DTYPE, + max_inflight=MAX_INFLIGHT, +) diff --git a/pyproject.toml b/pyproject.toml index a284878..f82cd0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,8 +43,10 @@ dependencies = [ qdrant = ["qdrant-client>=1.17.1"] # Docling-based PDF chunking ingestion pipeline docling = ["docling>=2.76.0"] -dev = ["pytest>=7.4.0", "ruff>=0.1.9"] +dev = ["pytest>=7.4.0", "pytest-cov>=5.0.0", "ruff>=0.1.9"] server = ["uvicorn", "fastapi"] +client = ["httpx>=0.27.0"] +full = ["httpx>=0.27.0", "uvicorn", "fastapi"] quant = ["bitsandbytes>=0.43"] langchain = ["langchain-core"] extra_converters = [ diff --git a/scripts/run-docker.sh b/scripts/run-docker.sh new file mode 100755 index 0000000..ea9bcba --- /dev/null +++ b/scripts/run-docker.sh @@ -0,0 +1,74 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Image selection (GHCR-friendly defaults) +REGISTRY="${REGISTRY:-ghcr.io}" +OWNER="${OWNER:-random-plm}" +IMAGE_NAME="${IMAGE_NAME:-foretrieval-server}" +IMAGE_TAG="${IMAGE_TAG:-v0.0.1}" +IMAGE="${IMAGE:-${REGISTRY}/${OWNER}/${IMAGE_NAME}:${IMAGE_TAG}}" +GITHUB_PAT="${GITHUB_PAT:-}" + +# Container/runtime settings +CONTAINER_NAME="${CONTAINER_NAME:-foretrieval-server}" +HOST_PORT="${HOST_PORT:-8000}" +CONTAINER_PORT="${CONTAINER_PORT:-8000}" +HF_CACHE_DIR="${HF_CACHE_DIR:-/models/foretrieval}" + +# GPU pinning +CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" + +# Server env vars +FOR_EMBED_MODEL="${FOR_EMBED_MODEL:-vidore/colqwen2-v1.0}" +FOR_EMBED_DEVICE="${FOR_EMBED_DEVICE:-cuda}" +FOR_EMBED_VERBOSE="${FOR_EMBED_VERBOSE:-1}" +FOR_SERVER_WORKERS="${FOR_SERVER_WORKERS:-1}" +FOR_SERVER_MAX_INFLIGHT="${FOR_SERVER_MAX_INFLIGHT:-1}" +FOR_EMBED_LOAD_IN_4BIT="${FOR_EMBED_LOAD_IN_4BIT:-false}" +FOR_EMBED_LOAD_IN_8BIT="${FOR_EMBED_LOAD_IN_8BIT:-false}" +FOR_EMBED_BNB_4BIT_QUANT_TYPE="${FOR_EMBED_BNB_4BIT_QUANT_TYPE:-nf4}" +FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE="${FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE:-float16}" +HF_TOKEN="${HF_TOKEN:-}" + +echo "Starting container '${CONTAINER_NAME}' from image '${IMAGE}'..." +echo "GPU pinning: CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}" +echo "HF cache mount: ${HF_CACHE_DIR} -> /root/.cache/huggingface" + +if [[ -z "${GITHUB_PAT}" ]]; then + echo "Error: GITHUB_PAT is not set." + echo "Set it before running, e.g.:" + echo " export GITHUB_PAT=ghp_xxx" + exit 1 +fi + +printf '%s' "${GITHUB_PAT}" | docker login ghcr.io -u "${OWNER}" --password-stdin >/dev/null +echo "Authenticated to ghcr.io as ${OWNER}." + +mkdir -p "${HF_CACHE_DIR}" + +env_args=( + -e CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES}" + -e FOR_EMBED_MODEL="${FOR_EMBED_MODEL}" + -e FOR_EMBED_DEVICE="${FOR_EMBED_DEVICE}" + -e FOR_EMBED_VERBOSE="${FOR_EMBED_VERBOSE}" + -e FOR_SERVER_WORKERS="${FOR_SERVER_WORKERS}" + -e FOR_SERVER_MAX_INFLIGHT="${FOR_SERVER_MAX_INFLIGHT}" + -e FOR_EMBED_LOAD_IN_4BIT="${FOR_EMBED_LOAD_IN_4BIT}" + -e FOR_EMBED_LOAD_IN_8BIT="${FOR_EMBED_LOAD_IN_8BIT}" + -e FOR_EMBED_BNB_4BIT_QUANT_TYPE="${FOR_EMBED_BNB_4BIT_QUANT_TYPE}" + -e FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE="${FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE}" +) + +if [[ -n "${HF_TOKEN}" ]]; then + env_args+=(-e HF_TOKEN="${HF_TOKEN}") +else + echo "HF_TOKEN is empty; not passing it to the container." +fi + +docker run --rm -it \ + --name "${CONTAINER_NAME}" \ + --gpus "device=${CUDA_VISIBLE_DEVICES}" \ + -p "${HOST_PORT}:${CONTAINER_PORT}" \ + -v "${HF_CACHE_DIR}:/root/.cache/huggingface" \ + "${env_args[@]}" \ + "${IMAGE}" diff --git a/tests/test_colpali.py b/tests/test_colpali.py index 04da413..5e86d7e 100644 --- a/tests/test_colpali.py +++ b/tests/test_colpali.py @@ -12,7 +12,13 @@ def colpali_rag_model() -> Generator[MultiModalRetrieverModel, None, None]: device = get_torch_device("auto") print(f"Using device: {device}") - yield MultiModalRetrieverModel.from_pretrained("vidore/colpali-v1.3", device=device) + yield MultiModalRetrieverModel.from_pretrained( + "vidore/colpali-v1.3", + device=device, + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype="float16", + ) tear_down_torch() diff --git a/tests/test_colqwen.py b/tests/test_colqwen.py index 6b20283..449e53f 100644 --- a/tests/test_colqwen.py +++ b/tests/test_colqwen.py @@ -13,7 +13,11 @@ def colqwen_rag_model() -> Generator[MultiModalRetrieverModel, None, None]: device = get_torch_device("auto") print(f"Using device: {device}") yield MultiModalRetrieverModel.from_pretrained( - "vidore/colqwen2.5-v0.2", device=device + "vidore/colqwen2.5-v0.2", + device=device, + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype="float16", ) tear_down_torch() diff --git a/tests/test_remote_mode.py b/tests/test_remote_mode.py new file mode 100644 index 0000000..c8b0851 --- /dev/null +++ b/tests/test_remote_mode.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import torch + +from foretrieval.colpali import ColPaliModel +from foretrieval.retriever import MultiModalRetrieverModel + + +def test_wrapper_forwards_remote_mode_args(monkeypatch): + import foretrieval.retriever as retriever_module + + captured = {} + + def _fake_from_pretrained(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + return object() + + monkeypatch.setattr( + retriever_module.ColPaliModel, "from_pretrained", _fake_from_pretrained + ) + + _ = MultiModalRetrieverModel.from_pretrained( + "vidore/colqwen2-v1.0", + embedding_mode="remote", + embedding_server_url="http://localhost:8000", + embedding_server_token="abc", + embedding_request_timeout=12.0, + ) + + assert captured["kwargs"]["embedding_mode"] == "remote" + assert captured["kwargs"]["embedding_server_url"] == "http://localhost:8000" + assert captured["kwargs"]["embedding_server_token"] == "abc" + assert captured["kwargs"]["embedding_request_timeout"] == 12.0 + + +def test_remote_mode_requires_server_url(monkeypatch): + import foretrieval.colpali as colpali_module + + with monkeypatch.context() as m: + m.setattr(colpali_module.importlib.util, "find_spec", lambda _name: None) + try: + ColPaliModel.from_pretrained("vidore/colqwen2-v1.0", embedding_mode="remote") + assert False, "Expected ValueError" + except ValueError as exc: + assert "embedding_server_url" in str(exc) + + +def test_embed_queries_uses_remote_client(monkeypatch): + import foretrieval.colpali as colpali_module + + class _DummyProcessor: + pass + + class _FakeRemoteClient: + def __init__(self, *_args, **_kwargs): + pass + + def encode_queries(self, queries): + assert queries == ["hello"] + return torch.randn(1, 4, 8) + + def encode_images(self, _images): + return torch.randn(1, 4, 8) + + monkeypatch.setattr(colpali_module, "RemoteEmbeddingClient", _FakeRemoteClient) + monkeypatch.setattr( + colpali_module.ColQwen2Processor, + "from_pretrained", + lambda *_args, **_kwargs: _DummyProcessor(), + ) + + model = ColPaliModel.from_pretrained( + "vidore/colqwen2-v1.0", + embedding_mode="remote", + embedding_server_url="http://localhost:8000", + device="cpu", + ) + out = model.encode_query("hello") + assert isinstance(out, torch.Tensor) + assert out.shape[0] == 1 diff --git a/uv.lock b/uv.lock index 328c7b1..33399ae 100644 --- a/uv.lock +++ b/uv.lock @@ -1084,6 +1084,9 @@ dependencies = [ ] [package.optional-dependencies] +client = [ + { name = "httpx" }, +] dev = [ { name = "pytest" }, { name = "ruff" }, @@ -1097,6 +1100,11 @@ extra-converters = [ { name = "python-pptx" }, { name = "reportlab" }, ] +full = [ + { name = "fastapi" }, + { name = "httpx" }, + { name = "uvicorn" }, +] langchain = [ { name = "langchain-core" }, ] @@ -1123,7 +1131,10 @@ requires-dist = [ { name = "docling", marker = "extra == 'docling'", specifier = ">=2.76.0" }, { name = "docx2pdf", specifier = ">=0.1.8" }, { name = "docx2pdf", marker = "sys_platform == 'win32' and extra == 'extra-converters'", specifier = ">=0.1.8" }, + { name = "fastapi", marker = "extra == 'full'" }, { name = "fastapi", marker = "extra == 'server'" }, + { name = "httpx", marker = "extra == 'client'", specifier = ">=0.27.0" }, + { name = "httpx", marker = "extra == 'full'", specifier = ">=0.27.0" }, { name = "langchain-core", marker = "extra == 'langchain'" }, { name = "langdetect", specifier = ">=1.0.9" }, { name = "matplotlib", specifier = ">=3.8.0" }, @@ -1141,9 +1152,10 @@ requires-dist = [ { name = "srsly" }, { name = "torch", specifier = ">=2.7.1" }, { name = "transformers", specifier = ">=4.42.0" }, + { name = "uvicorn", marker = "extra == 'full'" }, { name = "uvicorn", marker = "extra == 'server'" }, ] -provides-extras = ["dev", "server", "quant", "langchain", "extra-converters"] +provides-extras = ["dev", "server", "client", "full", "quant", "langchain", "extra-converters"] [package.metadata.requires-dev] dev = [{ name = "matplotlib", specifier = ">=3.10.8" }] From a3dc2874012c9f80904e44692279b98d64cb1b85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pascal=20Le=20M=C3=A9tayer?= Date: Thu, 16 Apr 2026 16:06:46 +0200 Subject: [PATCH 3/7] Adding storage qdrant option to MultiModalRetrieverModel from_pretrained signature; defaulting to false as it's an extra dependency --- foretrieval/retriever.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/foretrieval/retriever.py b/foretrieval/retriever.py index 03c992e..07e6edf 100644 --- a/foretrieval/retriever.py +++ b/foretrieval/retriever.py @@ -44,6 +44,7 @@ def from_pretrained( pretrained_model_name_or_path: Union[str, Path], index_root: str = ".rag_index", ingestion: Dict[str, Any] = {"backend": "default"}, + storage_qdrant: bool = False, device: str = "cuda", verbose: int = 1, hf_token: Optional[str] = None, @@ -73,6 +74,7 @@ def from_pretrained( pretrained_model_name_or_path, index_root=index_root, ingestion=ingestion, + storage_qdrant=storage_qdrant, device=device, verbose=verbose, hf_token=hf_token, From 1ab28ef07b672dde88991465953dfd4bd4abcc80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pascal=20Le=20M=C3=A9tayer?= Date: Fri, 17 Apr 2026 10:08:57 +0200 Subject: [PATCH 4/7] Apply query token filtering in both local and remote modes --- foretrieval/colpali.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/foretrieval/colpali.py b/foretrieval/colpali.py index b6aa44d..8ee264f 100644 --- a/foretrieval/colpali.py +++ b/foretrieval/colpali.py @@ -1165,13 +1165,11 @@ def _encode_search_query(self, query: str): embeddings_query = self._embed_queries([query]) qs = list(torch.unbind(embeddings_query.to("cpu"))) - if self.embedding_mode == "local": - batch_query = self.processor.process_queries([query]) - input_ids = batch_query["input_ids"][0].detach().cpu().tolist() - tokens = self.processor.tokenizer.convert_ids_to_tokens(input_ids) - valid_idxs = [i for i, tok in enumerate(tokens) if tok not in {"<|endoftext|>", "Query", ":"}] - return [qs[0][valid_idxs]] - return qs + batch_query = self.processor.process_queries([query]) + input_ids = batch_query["input_ids"][0].detach().cpu().tolist() + tokens = self.processor.tokenizer.convert_ids_to_tokens(input_ids) + valid_idxs = [i for i, tok in enumerate(tokens) if tok not in {"<|endoftext|>", "Query", ":"}] + return [qs[0][valid_idxs]] def _search_local( self, From 75cb82de95022b1d803b962269aef4267e133cad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pascal=20Le=20M=C3=A9tayer?= Date: Sun, 19 Apr 2026 16:14:00 +0200 Subject: [PATCH 5/7] cap transformers lib version to <5; no support for qwen2.5 when using >= 5 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f82cd0b..3ab76cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "pypdf>=6.1.3", "srsly", "torch>=2.7.1", - "transformers>=4.42.0", + "transformers>=4.42.0, <5", "pdf2image>=1.17.0", ] From 8810c14c956b47c6e5d72b03f0f76494a1ded5ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pascal=20Le=20M=C3=A9tayer?= Date: Sun, 19 Apr 2026 16:15:20 +0200 Subject: [PATCH 6/7] Adding verify ssl option support --- foretrieval/client/remote_backend.py | 3 ++- foretrieval/colpali.py | 2 ++ foretrieval/retriever.py | 2 ++ tests/test_remote_mode.py | 38 ++++++++++++++++++++++++++++ 4 files changed, 44 insertions(+), 1 deletion(-) diff --git a/foretrieval/client/remote_backend.py b/foretrieval/client/remote_backend.py index 7c38726..0d87638 100644 --- a/foretrieval/client/remote_backend.py +++ b/foretrieval/client/remote_backend.py @@ -16,6 +16,7 @@ def __init__( model_name: str, token: Optional[str] = None, timeout: float = 30.0, + verify_ssl: bool = True, ): try: import httpx @@ -30,7 +31,7 @@ def __init__( headers = {} if token: headers["Authorization"] = f"Bearer {token}" - self.client = httpx.Client(timeout=timeout, headers=headers) + self.client = httpx.Client(timeout=timeout, headers=headers, verify=verify_ssl) @staticmethod def _image_to_bytes(image: Image.Image) -> bytes: diff --git a/foretrieval/colpali.py b/foretrieval/colpali.py index 8ee264f..f9fedea 100644 --- a/foretrieval/colpali.py +++ b/foretrieval/colpali.py @@ -129,6 +129,7 @@ def __init__( self.embedding_server_url = kwargs.pop("embedding_server_url", None) self.embedding_server_token = kwargs.pop("embedding_server_token", None) self.embedding_request_timeout = float(kwargs.pop("embedding_request_timeout", 30.0)) + self.embedding_verify_ssl = bool(kwargs.pop("embedding_verify_ssl", True)) self.qdrant_client = None self.qdrant_collection = index_name @@ -189,6 +190,7 @@ def __init__( model_name=self.pretrained_model_name_or_path, token=self.embedding_server_token, timeout=self.embedding_request_timeout, + verify_ssl=self.embedding_verify_ssl, ) self.embedding_backend = RemoteEmbeddingBackend(self.remote_client) self._load_remote_processor_only() diff --git a/foretrieval/retriever.py b/foretrieval/retriever.py index 07e6edf..2f14593 100644 --- a/foretrieval/retriever.py +++ b/foretrieval/retriever.py @@ -52,6 +52,7 @@ def from_pretrained( embedding_server_url: Optional[str] = None, embedding_server_token: Optional[str] = None, embedding_request_timeout: float = 30.0, + embedding_verify_ssl: bool = True, load_in_4bit: bool = False, load_in_8bit: bool = False, bnb_4bit_quant_type: str = "nf4", @@ -82,6 +83,7 @@ def from_pretrained( embedding_server_url=embedding_server_url, embedding_server_token=embedding_server_token, embedding_request_timeout=embedding_request_timeout, + embedding_verify_ssl=embedding_verify_ssl, load_in_4bit=load_in_4bit, load_in_8bit=load_in_8bit, bnb_4bit_quant_type=bnb_4bit_quant_type, diff --git a/tests/test_remote_mode.py b/tests/test_remote_mode.py index c8b0851..040fea9 100644 --- a/tests/test_remote_mode.py +++ b/tests/test_remote_mode.py @@ -26,12 +26,14 @@ def _fake_from_pretrained(*args, **kwargs): embedding_server_url="http://localhost:8000", embedding_server_token="abc", embedding_request_timeout=12.0, + embedding_verify_ssl=False, ) assert captured["kwargs"]["embedding_mode"] == "remote" assert captured["kwargs"]["embedding_server_url"] == "http://localhost:8000" assert captured["kwargs"]["embedding_server_token"] == "abc" assert captured["kwargs"]["embedding_request_timeout"] == 12.0 + assert captured["kwargs"]["embedding_verify_ssl"] is False def test_remote_mode_requires_server_url(monkeypatch): @@ -79,3 +81,39 @@ def encode_images(self, _images): out = model.encode_query("hello") assert isinstance(out, torch.Tensor) assert out.shape[0] == 1 + + +def test_embedding_verify_ssl_disables_remote_ssl_verification(monkeypatch): + import foretrieval.colpali as colpali_module + + class _DummyProcessor: + pass + + captured = {} + + class _FakeRemoteClient: + def __init__(self, *_args, **kwargs): + captured["verify_ssl"] = kwargs.get("verify_ssl") + + def encode_queries(self, _queries): + return torch.randn(1, 4, 8) + + def encode_images(self, _images): + return torch.randn(1, 4, 8) + + monkeypatch.setattr(colpali_module, "RemoteEmbeddingClient", _FakeRemoteClient) + monkeypatch.setattr( + colpali_module.ColQwen2Processor, + "from_pretrained", + lambda *_args, **_kwargs: _DummyProcessor(), + ) + + ColPaliModel.from_pretrained( + "vidore/colqwen2-v1.0", + embedding_mode="remote", + embedding_server_url="https://localhost:8000", + embedding_verify_ssl=False, + device="cpu", + ) + + assert captured["verify_ssl"] is False From 7b274b84893de4464d4955deda4c32516366b7de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pascal=20Le=20M=C3=A9tayer?= Date: Mon, 20 Apr 2026 15:58:29 +0200 Subject: [PATCH 7/7] Multiple features: - split server image into base image and cuda fa2 enabled image - add concurrency option to parallelize embeddings computation - implement auto fa2 activation + force option - implement single cache model (worker doesn't retain previously loaded model on memory to avois OOM/memory exhausion) --- Dockerfile | 121 ++++++++++++++++++++++++- Makefile | 28 ++++-- foretrieval/client/remote_backend.py | 60 +++++++++--- foretrieval/colpali.py | 43 ++++++++- foretrieval/retriever.py | 18 ++-- foretrieval/server/embedding_server.py | 100 ++++++++++++++++---- foretrieval/server/server_main.py | 23 +++++ scripts/run-docker.sh | 2 + tests/test_remote_mode.py | 10 +- 9 files changed, 347 insertions(+), 58 deletions(-) diff --git a/Dockerfile b/Dockerfile index be1f41e..73bcd79 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1.7 -FROM python:3.12-slim +FROM python:3.12-slim AS runtime-base ENV PYTHONDONTWRITEBYTECODE=1 \ PYTHONUNBUFFERED=1 \ @@ -23,7 +23,6 @@ project = cfg["project"] deps = list(project.get("dependencies", [])) optional = project.get("optional-dependencies", {}) deps.extend(optional.get("server", [])) -deps.extend(optional.get("quant", [])) # Stable order and dedupe for deterministic cache behavior. seen = set() @@ -45,12 +44,130 @@ COPY foretrieval ./foretrieval RUN --mount=type=cache,target=/root/.cache/pip \ python -m pip install --no-deps . + + + + + + +# Shared CUDA Python runtime base for GPU stages. +# Keep this stage free of app source copies so app-only changes do not bust +# heavyweight builder caches (e.g. flash-attn wheel stage). +FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 AS cuda-python-runtime-base + +ENV DEBIAN_FRONTEND=noninteractive \ + PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + CUDA_HOME=/usr/local/cuda + +RUN apt-get update && apt-get install -y --no-install-recommends \ + software-properties-common \ + ca-certificates \ + gnupg \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update && apt-get install -y --no-install-recommends \ + python3.12 \ + python3.12-venv \ + && rm -rf /var/lib/apt/lists/* + +RUN ln -sf /usr/bin/python3.12 /usr/local/bin/python +RUN python -m venv /opt/venv +ENV VIRTUAL_ENV=/opt/venv +ENV PATH="/opt/venv/bin:${PATH}" + +WORKDIR /app +COPY pyproject.toml ./ + +# Single dependency/app installation point for GPU pipeline. +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install --upgrade pip setuptools wheel && \ + python - <<'PY' > /tmp/requirements-runtime.txt +import tomllib +from pathlib import Path + +cfg = tomllib.loads(Path("pyproject.toml").read_text()) +project = cfg["project"] +deps = list(project.get("dependencies", [])) +optional = project.get("optional-dependencies", {}) +deps.extend(optional.get("server", [])) + +seen = set() +ordered = [] +for dep in deps: + if dep not in seen: + seen.add(dep) + ordered.append(dep) +print("\n".join(ordered)) +PY + +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install -r /tmp/requirements-runtime.txt && \ + python -m pip install "bitsandbytes>=0.43" + +# App installation layer for CUDA runtime images. +FROM cuda-python-runtime-base AS cuda-app-base + +COPY README.md ./ +COPY foretrieval ./foretrieval + +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install --no-deps . + +# Dedicated builder for flash-attn wheel artifacts. +FROM cuda-python-runtime-base AS flashattn-builder + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + git \ + ninja-build \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /build +RUN mkdir -p /wheels && \ + python -m pip wheel --no-build-isolation flash-attn -w /wheels + +# GPU target: installs flash-attn wheel built in dedicated builder. +FROM cuda-app-base AS gpu + +COPY --from=flashattn-builder /wheels /tmp/wheels + +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install /tmp/wheels/flash_attn-*.whl + +# Server runtime defaults (override at `docker run -e ...`) +ENV FOR_EMBED_MODEL=vidore/colqwen2-v1.0 \ + FOR_EMBED_DEVICE=cuda \ + FOR_EMBED_VERBOSE=1 \ + FOR_EMBED_FLASH_ATTENTION=auto \ + FOR_SERVER_WORKERS=1 \ + FOR_SERVER_MAX_INFLIGHT=1 \ + FOR_SERVER_SINGLE_MODEL_CACHE=true \ + FOR_EMBED_LOAD_IN_4BIT=false \ + FOR_EMBED_LOAD_IN_8BIT=false \ + FOR_EMBED_BNB_4BIT_QUANT_TYPE=nf4 \ + FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE=float16 + +EXPOSE 8000 + +CMD ["sh", "-c", "python -m uvicorn foretrieval.server.server_main:app --host 0.0.0.0 --port 8000 --workers ${FOR_SERVER_WORKERS:-1}"] + + + + + + +# Default target: CPU-friendly runtime without GPU-specific optional deps. +FROM runtime-base AS cpu + # Server runtime defaults (override at `docker run -e ...`) ENV FOR_EMBED_MODEL=vidore/colqwen2-v1.0 \ FOR_EMBED_DEVICE=cpu \ FOR_EMBED_VERBOSE=1 \ + FOR_EMBED_FLASH_ATTENTION=auto \ FOR_SERVER_WORKERS=1 \ FOR_SERVER_MAX_INFLIGHT=1 \ + FOR_SERVER_SINGLE_MODEL_CACHE=true \ FOR_EMBED_LOAD_IN_4BIT=false \ FOR_EMBED_LOAD_IN_8BIT=false \ FOR_EMBED_BNB_4BIT_QUANT_TYPE=nf4 \ diff --git a/Makefile b/Makefile index c643780..0507a37 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,12 @@ IMAGE_NAME ?= foretrieval-server IMAGE_TAG ?= v0.0.1 IMAGE_REPO ?= $(REGISTRY)/$(OWNER)/$(IMAGE_NAME) IMAGE ?= $(IMAGE_REPO):$(IMAGE_TAG) +IMAGE_CPU ?= $(IMAGE_REPO):$(IMAGE_TAG)-base +GPU_CUDA_VERSION ?= 12.2.2 +GPU_CUDNN_MAJOR ?= 8 +HOST_ARCH_RAW := $(shell uname -m) +HOST_ARCH ?= $(if $(filter x86_64 amd64,$(HOST_ARCH_RAW)),amd64,$(if $(filter aarch64 arm64,$(HOST_ARCH_RAW)),arm64,$(HOST_ARCH_RAW))) +IMAGE_GPU ?= $(IMAGE_REPO):$(IMAGE_TAG)-cuda$(GPU_CUDA_VERSION)-cudnn$(GPU_CUDNN_MAJOR)-flashattn-bnb-$(HOST_ARCH) # Backward-compatible aliases. NAMESPACE ?= $(OWNER) @@ -14,7 +20,7 @@ PYTEST ?= pytest TEST_ARGS ?= -m "not slow and not integration" SRC_PACKAGE ?= foretrieval -.PHONY: check-buildx login-registry build publish run-server test-fast test-all coverage-fast coverage-all +.PHONY: check-buildx login-registry build build-cpu build-gpu publish publish-cpu publish-gpu run-server test-fast test-all coverage-fast coverage-all check-buildx: @docker buildx version >/dev/null 2>&1 || { \ @@ -31,14 +37,24 @@ login-registry: } @printf '%s' "$$GITHUB_PAT" | docker login ghcr.io -u $(OWNER) --password-stdin -build: check-buildx - docker buildx build --load -f $(DOCKERFILE) -t $(IMAGE) $(BUILD_CONTEXT) +build: build-cpu build-gpu -publish: check-buildx login-registry - docker buildx build --push -f $(DOCKERFILE) -t $(IMAGE) $(BUILD_CONTEXT) +build-cpu: check-buildx + docker buildx build --load -f $(DOCKERFILE) --target cpu -t $(IMAGE_CPU) $(BUILD_CONTEXT) + +build-gpu: check-buildx + docker buildx build --load -f $(DOCKERFILE) --target gpu -t $(IMAGE_GPU) $(BUILD_CONTEXT) + +publish: publish-cpu publish-gpu + +publish-cpu: check-buildx login-registry + docker buildx build --push -f $(DOCKERFILE) --target cpu -t $(IMAGE_CPU) $(BUILD_CONTEXT) + +publish-gpu: check-buildx login-registry + docker buildx build --push -f $(DOCKERFILE) --target gpu -t $(IMAGE_GPU) $(BUILD_CONTEXT) run-server: - REGISTRY=$(REGISTRY) OWNER=$(OWNER) IMAGE_NAME=$(IMAGE_NAME) IMAGE_TAG=$(IMAGE_TAG) ./scripts/run-docker.sh + IMAGE=$(IMAGE_GPU) REGISTRY=$(REGISTRY) OWNER=$(OWNER) IMAGE_NAME=$(IMAGE_NAME) IMAGE_TAG=$(IMAGE_TAG) ./scripts/run-docker.sh test-fast: $(PYTEST) $(TEST_ARGS) diff --git a/foretrieval/client/remote_backend.py b/foretrieval/client/remote_backend.py index 0d87638..54a6b4c 100644 --- a/foretrieval/client/remote_backend.py +++ b/foretrieval/client/remote_backend.py @@ -1,6 +1,7 @@ from __future__ import annotations import io +from concurrent.futures import ThreadPoolExecutor from typing import List, Optional import torch @@ -17,6 +18,8 @@ def __init__( token: Optional[str] = None, timeout: float = 30.0, verify_ssl: bool = True, + concurrency: int = 1, + request_batch_size: Optional[int] = None, ): try: import httpx @@ -28,6 +31,10 @@ def __init__( self._httpx = httpx self.server_url = server_url.rstrip("/") self.model_name = model_name + self.concurrency = max(1, int(concurrency)) + self.request_batch_size = int(request_batch_size) if request_batch_size else None + if self.request_batch_size is not None and self.request_batch_size < 1: + raise ValueError("request_batch_size must be >= 1 when provided.") headers = {} if token: headers["Authorization"] = f"Bearer {token}" @@ -39,13 +46,13 @@ def _image_to_bytes(image: Image.Image) -> bytes: image.save(buf, format="PNG") return buf.getvalue() - def encode_images(self, images: List[Image.Image]) -> torch.Tensor: - payload = { - "model": self.model_name, - "images": [self._image_to_bytes(image.convert("RGB")) for image in images], - } + @staticmethod + def _chunked(items: list, chunk_size: int) -> List[list]: + return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)] + + def _post_embeddings(self, endpoint: str, payload: dict) -> torch.Tensor: resp = self.client.post( - f"{self.server_url}/v1/embed/images", + f"{self.server_url}{endpoint}", content=dumps_payload(payload), headers={"Content-Type": "application/octet-stream"}, ) @@ -53,13 +60,36 @@ def encode_images(self, images: List[Image.Image]) -> torch.Tensor: out = loads_payload(resp.content) return out["embeddings"].cpu() + def _encode_batched(self, endpoint: str, field_name: str, values: list) -> torch.Tensor: + if not values: + return torch.empty((0,)) + + batch_size = self.request_batch_size or len(values) + chunks = self._chunked(values, batch_size) + + # avoids thread pool overhead if only one element + if self.concurrency == 1 or len(chunks) == 1: + tensors = [] + for chunk in chunks: + payload = {"model": self.model_name, field_name: chunk} + tensors.append(self._post_embeddings(endpoint, payload)) + return torch.cat(tensors, dim=0) if len(tensors) > 1 else tensors[0] + + with ThreadPoolExecutor(max_workers=self.concurrency) as pool: + futures = [] + for idx, chunk in enumerate(chunks): + payload = {"model": self.model_name, field_name: chunk} + futures.append((idx, pool.submit(self._post_embeddings, endpoint, payload))) + + ordered = [None] * len(chunks) + for idx, future in futures: + ordered[idx] = future.result() + + return torch.cat(ordered, dim=0) if len(ordered) > 1 else ordered[0] + + def encode_images(self, images: List[Image.Image]) -> torch.Tensor: + encoded = [self._image_to_bytes(image.convert("RGB")) for image in images] + return self._encode_batched("/v1/embed/images", "images", encoded) + def encode_queries(self, queries: List[str]) -> torch.Tensor: - payload = {"model": self.model_name, "queries": queries} - resp = self.client.post( - f"{self.server_url}/v1/embed/queries", - content=dumps_payload(payload), - headers={"Content-Type": "application/octet-stream"}, - ) - resp.raise_for_status() - out = loads_payload(resp.content) - return out["embeddings"].cpu() + return self._encode_batched("/v1/embed/queries", "queries", queries) diff --git a/foretrieval/colpali.py b/foretrieval/colpali.py index f9fedea..589047b 100644 --- a/foretrieval/colpali.py +++ b/foretrieval/colpali.py @@ -130,6 +130,17 @@ def __init__( self.embedding_server_token = kwargs.pop("embedding_server_token", None) self.embedding_request_timeout = float(kwargs.pop("embedding_request_timeout", 30.0)) self.embedding_verify_ssl = bool(kwargs.pop("embedding_verify_ssl", True)) + self.embedding_concurrency = int(kwargs.pop("embedding_concurrency", 1)) + self.embedding_request_batch_size = kwargs.pop("embedding_request_batch_size", None) + if self.embedding_concurrency < 1: + raise ValueError("embedding_concurrency must be >= 1.") + if self.embedding_request_batch_size is not None: + self.embedding_request_batch_size = int(self.embedding_request_batch_size) + if self.embedding_request_batch_size < 1: + raise ValueError("embedding_request_batch_size must be >= 1 when provided.") + self.flash_attention_mode = str(kwargs.pop("flash_attention_mode", "auto")).strip().lower() + if self.flash_attention_mode not in {"auto", "on", "off"}: + raise ValueError("flash_attention_mode must be one of: auto, on, off.") self.qdrant_client = None self.qdrant_collection = index_name @@ -191,6 +202,8 @@ def __init__( token=self.embedding_server_token, timeout=self.embedding_request_timeout, verify_ssl=self.embedding_verify_ssl, + concurrency=self.embedding_concurrency, + request_batch_size=self.embedding_request_batch_size, ) self.embedding_backend = RemoteEmbeddingBackend(self.remote_client) self._load_remote_processor_only() @@ -215,6 +228,7 @@ def _load_model_and_processor(self): token = self.kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN") is_cuda = self.device == "cuda" or (isinstance(self.device, torch.device) and self.device.type == "cuda") device_map = "cuda" if is_cuda else None + flash_attn_available = importlib.util.find_spec("flash_attn") is not None if "colpali" in self.pretrained_model_name_or_path.lower(): model_cls = ColPali @@ -226,12 +240,33 @@ def _load_model_and_processor(self): model_cls = ColQwen2 processor_cls = ColQwen2Processor + model_kwargs = { + "torch_dtype": torch.bfloat16, + "device_map": device_map, + "quantization_config": self._quantization_config, + "token": token, + } + if self.flash_attention_mode == "on": + if not is_cuda: + raise RuntimeError("flash_attention_mode='on' requires CUDA device.") + if not flash_attn_available: + raise RuntimeError("flash_attention_mode='on' but flash_attn is not installed.") + model_kwargs["attn_implementation"] = "flash_attention_2" + logger.info("Using Flash-Attention 2 for model loading (flash_attention_mode=on).") + elif self.flash_attention_mode == "off": + logger.info("Using default attention implementation for model loading (flash_attention_mode=off).") + elif is_cuda and flash_attn_available: + model_kwargs["attn_implementation"] = "flash_attention_2" + logger.info("Using Flash-Attention 2 for model loading (flash_attention_mode=auto).") + else: + logger.info( + "Using default attention implementation for model loading " + f"(flash_attention_mode=auto, is_cuda={is_cuda}, flash_attn_available={flash_attn_available})." + ) + self.model = model_cls.from_pretrained( self.pretrained_model_name_or_path, - torch_dtype=torch.bfloat16, - device_map=device_map, - quantization_config=self._quantization_config, - token=token, + **model_kwargs, ) self.processor = processor_cls.from_pretrained( self.pretrained_model_name_or_path, diff --git a/foretrieval/retriever.py b/foretrieval/retriever.py index 2f14593..a00dd17 100644 --- a/foretrieval/retriever.py +++ b/foretrieval/retriever.py @@ -49,14 +49,12 @@ def from_pretrained( verbose: int = 1, hf_token: Optional[str] = None, embedding_mode: str = "local", - embedding_server_url: Optional[str] = None, - embedding_server_token: Optional[str] = None, - embedding_request_timeout: float = 30.0, - embedding_verify_ssl: bool = True, + remote: Optional[Dict[str, Any]] = None, load_in_4bit: bool = False, load_in_8bit: bool = False, bnb_4bit_quant_type: str = "nf4", bnb_4bit_compute_dtype: str = "float16", + flash_attention_mode: str = "auto", ): """Load a ColPali model from a pre-trained checkpoint. @@ -70,6 +68,7 @@ def from_pretrained( Returns: cls (RAGMultiModalModel): The current instance of RAGMultiModalModel, with the model initialised. """ + remote_cfg: Dict[str, Any] = remote or {} instance = cls() instance.model = ColPaliModel.from_pretrained( pretrained_model_name_or_path, @@ -80,14 +79,17 @@ def from_pretrained( verbose=verbose, hf_token=hf_token, embedding_mode=embedding_mode, - embedding_server_url=embedding_server_url, - embedding_server_token=embedding_server_token, - embedding_request_timeout=embedding_request_timeout, - embedding_verify_ssl=embedding_verify_ssl, + embedding_server_url=remote_cfg.get("url"), + embedding_server_token=remote_cfg.get("token"), + embedding_request_timeout=remote_cfg.get("request_timeout", 30.0), + embedding_verify_ssl=remote_cfg.get("verify_ssl", True), + embedding_concurrency=remote_cfg.get("concurrency", 1), + embedding_request_batch_size=remote_cfg.get("request_batch_size"), load_in_4bit=load_in_4bit, load_in_8bit=load_in_8bit, bnb_4bit_quant_type=bnb_4bit_quant_type, bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + flash_attention_mode=flash_attention_mode, ) return instance diff --git a/foretrieval/server/embedding_server.py b/foretrieval/server/embedding_server.py index 8b20eed..6ab0d5f 100644 --- a/foretrieval/server/embedding_server.py +++ b/foretrieval/server/embedding_server.py @@ -1,8 +1,9 @@ from __future__ import annotations import asyncio +import gc import io -from typing import Dict, Optional +from typing import Dict, Optional, Tuple from fastapi import FastAPI, HTTPException, Request, Response from PIL import Image @@ -20,28 +21,80 @@ def create_app( load_in_8bit: bool = False, bnb_4bit_quant_type: str = "nf4", bnb_4bit_compute_dtype: str = "float16", + flash_attention_mode: str = "auto", max_inflight: int = 1, + single_model_cache: bool = False, ) -> FastAPI: app = FastAPI(title="FORetrieval Embedding Server") inflight_semaphore = asyncio.Semaphore(max_inflight) + cache_lock = asyncio.Lock() retriever_cache: Dict[str, MultiModalRetrieverModel] = {} + retriever_in_use: Dict[int, int] = {} + pending_releases: Dict[int, MultiModalRetrieverModel] = {} - def get_retriever(requested_model: Optional[str]) -> MultiModalRetrieverModel: + def _release_retriever(retriever: MultiModalRetrieverModel) -> None: + # Best effort release: drop strong refs, force GC, then flush CUDA allocator cache. + del retriever + gc.collect() + try: + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if hasattr(torch.cuda, "ipc_collect"): + torch.cuda.ipc_collect() + except Exception: + # Cache eviction should not fail request processing. + pass + + def _schedule_or_release_locked(retriever: MultiModalRetrieverModel) -> None: + rid = id(retriever) + if retriever_in_use.get(rid, 0) > 0: + pending_releases[rid] = retriever + return + _release_retriever(retriever) + + async def _acquire_retriever(requested_model: Optional[str]) -> Tuple[MultiModalRetrieverModel, int]: target_model = requested_model or model_name - if target_model not in retriever_cache: - retriever_cache[target_model] = MultiModalRetrieverModel.from_pretrained( - target_model, - device=device, - verbose=verbose, - hf_token=hf_token, - embedding_mode="local", - load_in_4bit=load_in_4bit, - load_in_8bit=load_in_8bit, - bnb_4bit_quant_type=bnb_4bit_quant_type, - bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, - ) - return retriever_cache[target_model] + async with cache_lock: + cached = retriever_cache.get(target_model) + if cached is None: + if single_model_cache and retriever_cache: + old_retrievers = list(retriever_cache.values()) + retriever_cache.clear() + for old in old_retrievers: + _schedule_or_release_locked(old) + + cached = MultiModalRetrieverModel.from_pretrained( + target_model, + device=device, + verbose=verbose, + hf_token=hf_token, + embedding_mode="local", + load_in_4bit=load_in_4bit, + load_in_8bit=load_in_8bit, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + flash_attention_mode=flash_attention_mode, + ) + retriever_cache[target_model] = cached + + rid = id(cached) + retriever_in_use[rid] = retriever_in_use.get(rid, 0) + 1 + return cached, rid + + async def _release_retriever_use(rid: int) -> None: + to_release = None + async with cache_lock: + current = retriever_in_use.get(rid, 0) + if current <= 1: + retriever_in_use.pop(rid, None) + to_release = pending_releases.pop(rid, None) + else: + retriever_in_use[rid] = current - 1 + if to_release is not None: + _release_retriever(to_release) @app.get("/health") def health(): @@ -49,6 +102,9 @@ def health(): "status": "ok", "default_model": model_name, "loaded_models": list(retriever_cache.keys()), + "single_model_cache": single_model_cache, + "in_use_retrievers": len(retriever_in_use), + "pending_evictions": len(pending_releases), "device": str(device), "max_inflight": max_inflight, } @@ -58,9 +114,12 @@ async def embed_images(request: Request): try: payload = loads_payload(await request.body()) async with inflight_semaphore: - retriever = get_retriever(payload.get("model")) + retriever, rid = await _acquire_retriever(payload.get("model")) images = [Image.open(io.BytesIO(b)).convert("RGB") for b in payload["images"]] - embeddings = retriever.model.encode_image(images) + try: + embeddings = retriever.model.encode_image(images) + finally: + await _release_retriever_use(rid) return Response( content=dumps_payload({"embeddings": embeddings.cpu()}), media_type="application/octet-stream", @@ -73,9 +132,12 @@ async def embed_queries(request: Request): try: payload = loads_payload(await request.body()) async with inflight_semaphore: - retriever = get_retriever(payload.get("model")) + retriever, rid = await _acquire_retriever(payload.get("model")) queries = payload["queries"] - embeddings = retriever.model.encode_query(queries) + try: + embeddings = retriever.model.encode_query(queries) + finally: + await _release_retriever_use(rid) return Response( content=dumps_payload({"embeddings": embeddings.cpu()}), media_type="application/octet-stream", diff --git a/foretrieval/server/server_main.py b/foretrieval/server/server_main.py index 2fb5fa7..e9edbdb 100644 --- a/foretrieval/server/server_main.py +++ b/foretrieval/server/server_main.py @@ -33,6 +33,25 @@ def _as_optional_str(name: str) -> str | None: return stripped if stripped else None +def _as_flash_attention_mode(name: str, default: str = "auto") -> str: + val = os.environ.get(name) + if val is None: + return default + mode = val.strip().lower() + aliases = { + "true": "on", + "1": "on", + "yes": "on", + "false": "off", + "0": "off", + "no": "off", + } + mode = aliases.get(mode, mode) + if mode not in {"auto", "on", "off"}: + raise ValueError(f"{name} must be one of auto/on/off (or boolean aliases), got {val!r}.") + return mode + + MODEL_NAME = os.environ.get("FOR_EMBED_MODEL", "vidore/colqwen2-v1.0") DEVICE = os.environ.get("FOR_EMBED_DEVICE", "cuda") VERBOSE = int(os.environ.get("FOR_EMBED_VERBOSE", "1")) @@ -41,7 +60,9 @@ def _as_optional_str(name: str) -> str | None: LOAD_IN_8BIT = _as_bool("FOR_EMBED_LOAD_IN_8BIT", False) BNB_4BIT_QUANT_TYPE = os.environ.get("FOR_EMBED_BNB_4BIT_QUANT_TYPE", "nf4") BNB_4BIT_COMPUTE_DTYPE = os.environ.get("FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE", "float16") +FLASH_ATTENTION_MODE = _as_flash_attention_mode("FOR_EMBED_FLASH_ATTENTION", "auto") MAX_INFLIGHT = _as_positive_int("FOR_SERVER_MAX_INFLIGHT", 1) +SINGLE_MODEL_CACHE = _as_bool("FOR_SERVER_SINGLE_MODEL_CACHE", False) app = create_app( model_name=MODEL_NAME, @@ -52,5 +73,7 @@ def _as_optional_str(name: str) -> str | None: load_in_8bit=LOAD_IN_8BIT, bnb_4bit_quant_type=BNB_4BIT_QUANT_TYPE, bnb_4bit_compute_dtype=BNB_4BIT_COMPUTE_DTYPE, + flash_attention_mode=FLASH_ATTENTION_MODE, max_inflight=MAX_INFLIGHT, + single_model_cache=SINGLE_MODEL_CACHE, ) diff --git a/scripts/run-docker.sh b/scripts/run-docker.sh index ea9bcba..a42fed8 100755 --- a/scripts/run-docker.sh +++ b/scripts/run-docker.sh @@ -24,6 +24,7 @@ FOR_EMBED_DEVICE="${FOR_EMBED_DEVICE:-cuda}" FOR_EMBED_VERBOSE="${FOR_EMBED_VERBOSE:-1}" FOR_SERVER_WORKERS="${FOR_SERVER_WORKERS:-1}" FOR_SERVER_MAX_INFLIGHT="${FOR_SERVER_MAX_INFLIGHT:-1}" +FOR_SERVER_SINGLE_MODEL_CACHE="${FOR_SERVER_SINGLE_MODEL_CACHE:-true}" FOR_EMBED_LOAD_IN_4BIT="${FOR_EMBED_LOAD_IN_4BIT:-false}" FOR_EMBED_LOAD_IN_8BIT="${FOR_EMBED_LOAD_IN_8BIT:-false}" FOR_EMBED_BNB_4BIT_QUANT_TYPE="${FOR_EMBED_BNB_4BIT_QUANT_TYPE:-nf4}" @@ -53,6 +54,7 @@ env_args=( -e FOR_EMBED_VERBOSE="${FOR_EMBED_VERBOSE}" -e FOR_SERVER_WORKERS="${FOR_SERVER_WORKERS}" -e FOR_SERVER_MAX_INFLIGHT="${FOR_SERVER_MAX_INFLIGHT}" + -e FOR_SERVER_SINGLE_MODEL_CACHE="${FOR_SERVER_SINGLE_MODEL_CACHE}" -e FOR_EMBED_LOAD_IN_4BIT="${FOR_EMBED_LOAD_IN_4BIT}" -e FOR_EMBED_LOAD_IN_8BIT="${FOR_EMBED_LOAD_IN_8BIT}" -e FOR_EMBED_BNB_4BIT_QUANT_TYPE="${FOR_EMBED_BNB_4BIT_QUANT_TYPE}" diff --git a/tests/test_remote_mode.py b/tests/test_remote_mode.py index 040fea9..c720363 100644 --- a/tests/test_remote_mode.py +++ b/tests/test_remote_mode.py @@ -23,10 +23,12 @@ def _fake_from_pretrained(*args, **kwargs): _ = MultiModalRetrieverModel.from_pretrained( "vidore/colqwen2-v1.0", embedding_mode="remote", - embedding_server_url="http://localhost:8000", - embedding_server_token="abc", - embedding_request_timeout=12.0, - embedding_verify_ssl=False, + remote={ + "url": "http://localhost:8000", + "token": "abc", + "request_timeout": 12.0, + "verify_ssl": False, + }, ) assert captured["kwargs"]["embedding_mode"] == "remote"