Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Keep VCS metadata, editor/test caches, build artefacts, tests and local
# secrets out of the Docker build context.
.git
.github
.venv
.pytest_cache
.vscode
# .dockerignore patterns are anchored at the context root; the leading `**/`
# is required to also match bytecode nested inside the package directories.
**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
*.egg-info
build
dist
tests
sample_data
tmp_forag_test
tmp_forag_test_remote
# Local credential file (git-ignored as well); must never reach a build.
.pat
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,7 @@ local/
pylate/*
pylate/

.byaldi/
.byaldi/

.pat

178 changes: 178 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# syntax=docker/dockerfile:1.7
# CPU runtime base: third-party dependencies first, application code last.
FROM python:3.12-slim AS runtime-base

ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=1

WORKDIR /app

# Copy only dependency metadata first for better layer caching
COPY pyproject.toml ./

# Install third-party runtime dependencies from pyproject extras.
# This layer is invalidated only when dependency metadata changes.
# The inline script writes the base dependencies plus the "server" extra,
# one requirement per line, to /tmp/requirements-runtime.txt.
RUN --mount=type=cache,target=/root/.cache/pip \
    python -m pip install --upgrade pip setuptools wheel && \
    python - <<'PY' > /tmp/requirements-runtime.txt
import tomllib

# tomllib.load takes a binary handle and performs the UTF-8 decoding itself,
# so no text-mode default encoding is involved.
with open("pyproject.toml", "rb") as fh:
    project = tomllib.load(fh)["project"]

deps = [
    *project.get("dependencies", []),
    *project.get("optional-dependencies", {}).get("server", []),
]

# dict.fromkeys removes duplicates while keeping first-seen order, which
# keeps the generated file (and therefore the pip layer) deterministic.
print("\n".join(dict.fromkeys(deps)))
PY
RUN --mount=type=cache,target=/root/.cache/pip \
    python -m pip install -r /tmp/requirements-runtime.txt

# Copy package sources late so app code changes do not bust dependency layers
COPY README.md ./
COPY foretrieval ./foretrieval

# Install local package without re-resolving dependencies
RUN --mount=type=cache,target=/root/.cache/pip \
    python -m pip install --no-deps .







# Shared CUDA Python runtime base for GPU stages.
# Keep this stage free of app source copies so app-only changes do not bust
# heavyweight builder caches (e.g. flash-attn wheel stage).
FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 AS cuda-python-runtime-base

ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=1 \
    CUDA_HOME=/usr/local/cuda

# Python 3.12 is installed from the deadsnakes PPA.
# `-y` makes add-apt-repository explicitly non-interactive so the build can
# never stall on its confirmation prompt.
RUN apt-get update && apt-get install -y --no-install-recommends \
        software-properties-common \
        ca-certificates \
        gnupg \
    && add-apt-repository -y ppa:deadsnakes/ppa \
    && apt-get update && apt-get install -y --no-install-recommends \
        python3.12 \
        python3.12-venv \
    && rm -rf /var/lib/apt/lists/*

# Expose python3.12 as `python`, then put a dedicated virtualenv first on
# PATH so every later `python` / `pip` call targets /opt/venv.
RUN ln -sf /usr/bin/python3.12 /usr/local/bin/python
RUN python -m venv /opt/venv
ENV VIRTUAL_ENV=/opt/venv
ENV PATH="/opt/venv/bin:${PATH}"

WORKDIR /app
COPY pyproject.toml ./

# Single dependency/app installation point for GPU pipeline.
# The inline script writes the base dependencies plus the "server" extra,
# one requirement per line, to /tmp/requirements-runtime.txt.
RUN --mount=type=cache,target=/root/.cache/pip \
    python -m pip install --upgrade pip setuptools wheel && \
    python - <<'PY' > /tmp/requirements-runtime.txt
import tomllib

# tomllib.load takes a binary handle and performs the UTF-8 decoding itself,
# so no text-mode default encoding is involved.
with open("pyproject.toml", "rb") as fh:
    project = tomllib.load(fh)["project"]

deps = [
    *project.get("dependencies", []),
    *project.get("optional-dependencies", {}).get("server", []),
]

# dict.fromkeys removes duplicates while keeping first-seen order, which
# keeps the generated file (and therefore the pip layer) deterministic.
print("\n".join(dict.fromkeys(deps)))
PY

# bitsandbytes is installed only in the CUDA pipeline, on top of the
# requirements generated above.
RUN --mount=type=cache,target=/root/.cache/pip \
    python -m pip install -r /tmp/requirements-runtime.txt && \
    python -m pip install "bitsandbytes>=0.43"

# App installation layer for CUDA runtime images.
# Builds on the dependency-only CUDA stage; pyproject.toml is already present
# in /app because the parent stage copied it there.
FROM cuda-python-runtime-base AS cuda-app-base

# Application sources are copied in this stage only, so a source change
# rebuilds from here and leaves the parent stage's layers untouched.
COPY README.md ./
COPY foretrieval ./foretrieval

# --no-deps: third-party requirements were installed in the parent stage, so
# only the local package itself is installed here.
RUN --mount=type=cache,target=/root/.cache/pip \
python -m pip install --no-deps .

# Dedicated builder for flash-attn wheel artifacts.
# Branches off the dependency-only CUDA stage, so application source changes
# do not invalidate this (expensive) stage.
FROM cuda-python-runtime-base AS flashattn-builder

# Toolchain for compiling the extension. python3.12-dev provides the CPython
# headers required whenever the wheel has to be compiled from source.
RUN apt-get update && apt-get install -y --no-install-recommends \
        build-essential \
        git \
        ninja-build \
        python3.12-dev \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /build
# --no-build-isolation: build against the packages already installed in the
#   virtualenv instead of a fresh isolated build environment.
# --no-deps: emit only the flash-attn wheel. Without it `pip wheel` also
#   writes a wheel for every dependency of flash-attn into /wheels, although
#   the consuming stage installs nothing but flash_attn-*.whl.
RUN --mount=type=cache,target=/root/.cache/pip \
    mkdir -p /wheels && \
    python -m pip wheel --no-build-isolation --no-deps flash-attn -w /wheels

# GPU target: installs flash-attn wheel built in dedicated builder.
FROM cuda-app-base AS gpu

# Bind-mount the builder's /wheels for the duration of this RUN instead of
# COPYing it: a COPY would keep the wheel files in an image layer even
# though they are only needed while pip installs them.
RUN --mount=type=bind,from=flashattn-builder,source=/wheels,target=/tmp/wheels \
    --mount=type=cache,target=/root/.cache/pip \
    python -m pip install /tmp/wheels/flash_attn-*.whl

# Server runtime defaults (override at `docker run -e ...`)
ENV FOR_EMBED_MODEL=vidore/colqwen2-v1.0 \
    FOR_EMBED_DEVICE=cuda \
    FOR_EMBED_VERBOSE=1 \
    FOR_EMBED_FLASH_ATTENTION=auto \
    FOR_SERVER_WORKERS=1 \
    FOR_SERVER_MAX_INFLIGHT=1 \
    FOR_SERVER_SINGLE_MODEL_CACHE=true \
    FOR_EMBED_LOAD_IN_4BIT=false \
    FOR_EMBED_LOAD_IN_8BIT=false \
    FOR_EMBED_BNB_4BIT_QUANT_TYPE=nf4 \
    FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE=float16

EXPOSE 8000

# A shell wrapper is needed to expand ${FOR_SERVER_WORKERS} at container
# start. `exec` replaces that shell with the uvicorn process so signals sent
# to the container (e.g. SIGTERM from `docker stop`) reach the server
# directly rather than an intermediate shell.
CMD ["sh", "-c", "exec python -m uvicorn foretrieval.server.server_main:app --host 0.0.0.0 --port 8000 --workers ${FOR_SERVER_WORKERS:-1}"]






# Default target: CPU-friendly runtime without GPU-specific optional deps.
# Kept as the last stage in the file so that a build without --target
# produces this image.
FROM runtime-base AS cpu

# Server runtime defaults (override at `docker run -e ...`)
ENV FOR_EMBED_MODEL=vidore/colqwen2-v1.0 \
    FOR_EMBED_DEVICE=cpu \
    FOR_EMBED_VERBOSE=1 \
    FOR_EMBED_FLASH_ATTENTION=auto \
    FOR_SERVER_WORKERS=1 \
    FOR_SERVER_MAX_INFLIGHT=1 \
    FOR_SERVER_SINGLE_MODEL_CACHE=true \
    FOR_EMBED_LOAD_IN_4BIT=false \
    FOR_EMBED_LOAD_IN_8BIT=false \
    FOR_EMBED_BNB_4BIT_QUANT_TYPE=nf4 \
    FOR_EMBED_BNB_4BIT_COMPUTE_DTYPE=float16

EXPOSE 8000

# A shell wrapper is needed to expand ${FOR_SERVER_WORKERS} at container
# start. `exec` replaces that shell with the uvicorn process so signals sent
# to the container (e.g. SIGTERM from `docker stop`) reach the server
# directly rather than an intermediate shell.
CMD ["sh", "-c", "exec python -m uvicorn foretrieval.server.server_main:app --host 0.0.0.0 --port 8000 --workers ${FOR_SERVER_WORKERS:-1}"]
69 changes: 69 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Image coordinates. Every `?=` assignment can be overridden from the
# environment or the command line, e.g. `make build-cpu IMAGE_TAG=dev`.
REGISTRY ?= ghcr.io
OWNER ?= random-plm
IMAGE_NAME ?= foretrieval-server
IMAGE_TAG ?= v0.0.1
IMAGE_REPO ?= $(REGISTRY)/$(OWNER)/$(IMAGE_NAME)
IMAGE ?= $(IMAGE_REPO):$(IMAGE_TAG)
# Tag used for the image built from the `cpu` Dockerfile target.
IMAGE_CPU ?= $(IMAGE_REPO):$(IMAGE_TAG)-base
# These two values only shape the GPU image tag below; they are not passed to
# `docker buildx build` by any target in this Makefile.
GPU_CUDA_VERSION ?= 12.2.2
GPU_CUDNN_MAJOR ?= 8
# Normalise `uname -m` output to amd64 / arm64; any other value is used as-is.
HOST_ARCH_RAW := $(shell uname -m)
HOST_ARCH ?= $(if $(filter x86_64 amd64,$(HOST_ARCH_RAW)),amd64,$(if $(filter aarch64 arm64,$(HOST_ARCH_RAW)),arm64,$(HOST_ARCH_RAW)))
# Tag used for the image built from the `gpu` Dockerfile target.
IMAGE_GPU ?= $(IMAGE_REPO):$(IMAGE_TAG)-cuda$(GPU_CUDA_VERSION)-cudnn$(GPU_CUDNN_MAJOR)-flashattn-bnb-$(HOST_ARCH)

# Backward-compatible aliases.
NAMESPACE ?= $(OWNER)
PROJECT ?= $(IMAGE_NAME)
# Build inputs handed to `docker buildx build`.
DOCKERFILE ?= Dockerfile
BUILD_CONTEXT ?= .
# Test runner, the marker filter used by the *-fast targets, and the package
# measured by the coverage targets.
PYTEST ?= pytest
TEST_ARGS ?= -m "not slow and not integration"
SRC_PACKAGE ?= foretrieval

# None of the targets below produce a file named after themselves.
.PHONY: check-buildx login-registry build build-cpu build-gpu publish publish-cpu publish-gpu run-server test-fast test-all coverage-fast coverage-all

# Abort with installation hints when the Docker buildx plugin is unavailable.
check-buildx:
	@if ! docker buildx version >/dev/null 2>&1; then \
		echo "Error: docker buildx is not available."; \
		echo "Install Docker Buildx (or Docker Desktop), then run this command again."; \
		echo "Linux plugin package is often named 'docker-buildx-plugin'."; \
		exit 1; \
	fi

# Log in to the container registry with the token held in $GITHUB_PAT.
# The token is piped through stdin so it never appears on a command line.
# The registry host comes from $(REGISTRY), the same variable that builds
# the image names, so overriding REGISTRY keeps login and push targets aligned.
login-registry:
	@test -n "$$GITHUB_PAT" || { \
		echo "Error: GITHUB_PAT is not set. Export a classic PAT with write:packages."; \
		exit 1; \
	}
	@printf '%s' "$$GITHUB_PAT" | docker login $(REGISTRY) -u $(OWNER) --password-stdin

# Local builds: `--load` places the resulting image in the local Docker
# image store. `build` produces both variants.
build: build-cpu build-gpu

build-cpu: check-buildx
	docker buildx build --load \
		-f $(DOCKERFILE) \
		--target cpu \
		-t $(IMAGE_CPU) \
		$(BUILD_CONTEXT)

build-gpu: check-buildx
	docker buildx build --load \
		-f $(DOCKERFILE) \
		--target gpu \
		-t $(IMAGE_GPU) \
		$(BUILD_CONTEXT)

# Registry builds: `--push` uploads the resulting image, so a registry login
# is a prerequisite. `publish` pushes both variants.
publish: publish-cpu publish-gpu

publish-cpu: check-buildx login-registry
	docker buildx build --push \
		-f $(DOCKERFILE) \
		--target cpu \
		-t $(IMAGE_CPU) \
		$(BUILD_CONTEXT)

publish-gpu: check-buildx login-registry
	docker buildx build --push \
		-f $(DOCKERFILE) \
		--target gpu \
		-t $(IMAGE_GPU) \
		$(BUILD_CONTEXT)

# Start the server through scripts/run-docker.sh, handing it the GPU image
# reference and the image coordinates as environment variables.
run-server:
	IMAGE=$(IMAGE_GPU) \
	REGISTRY=$(REGISTRY) \
	OWNER=$(OWNER) \
	IMAGE_NAME=$(IMAGE_NAME) \
	IMAGE_TAG=$(IMAGE_TAG) \
	./scripts/run-docker.sh

# Coverage options shared by both coverage-* targets.
COVERAGE_FLAGS = --cov=$(SRC_PACKAGE) --cov-report=term-missing --cov-report=xml

# *-fast targets apply the $(TEST_ARGS) marker filter; *-all targets run the
# whole suite.
test-fast:
	$(PYTEST) $(TEST_ARGS)

test-all:
	$(PYTEST)

coverage-fast:
	$(PYTEST) $(TEST_ARGS) $(COVERAGE_FLAGS)

coverage-all:
	$(PYTEST) $(COVERAGE_FLAGS)
3 changes: 3 additions & 0 deletions foretrieval/client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .embedding_backends import LocalEmbeddingBackend, RemoteEmbeddingBackend
from .remote_backend import RemoteEmbeddingClient
from .transport import dumps_payload, loads_payload
62 changes: 62 additions & 0 deletions foretrieval/client/embedding_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

from typing import List, Protocol, Union

import torch
from PIL import Image

from .remote_backend import RemoteEmbeddingClient


class EmbeddingBackend(Protocol):
    """Structural interface shared by the embedding backends.

    Any object providing both methods below with compatible signatures
    satisfies this protocol; no inheritance is required.
    """

    def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
        """Return the embeddings computed for ``images`` as a tensor."""

    def embed_queries(self, queries: List[str]) -> torch.Tensor:
        """Return the embeddings computed for ``queries`` as a tensor."""


class LocalEmbeddingBackend:
    """Embedding backend that runs the model in the current process.

    Args:
        model: Callable invoked with the processed batch as keyword
            arguments. Its ``dtype`` attribute is the cast target for
            floating-point inputs.
        processor: Object exposing ``process_images`` and
            ``process_queries``; each must return a mapping whose values are
            tensors.
        device: Device the batch tensors are moved to before the forward
            pass.
    """

    # Input dtypes that get cast to the model's dtype. Tensors of any other
    # dtype are moved to the device but keep their dtype.
    _CASTABLE_DTYPES = (torch.float16, torch.bfloat16, torch.float32)

    def __init__(self, model, processor, device):
        self.model = model
        self.processor = processor
        self.device = device

    def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
        """Embed ``images`` and return the model output on the CPU."""
        with torch.inference_mode():
            return self._forward(self.processor.process_images(images))

    def embed_queries(self, queries: List[str]) -> torch.Tensor:
        """Embed ``queries`` and return the model output on the CPU."""
        with torch.inference_mode():
            return self._forward(self.processor.process_queries(queries))

    def _forward(self, batch) -> torch.Tensor:
        """Place ``batch`` on the device, run the model, return a CPU tensor."""
        inputs = {name: self._place(tensor) for name, tensor in batch.items()}
        return self.model(**inputs).cpu()

    def _place(self, tensor: torch.Tensor) -> torch.Tensor:
        """Move ``tensor`` to the device, casting castable floats to the model dtype."""
        if tensor.dtype in self._CASTABLE_DTYPES:
            target_dtype = self.model.dtype
        else:
            target_dtype = tensor.dtype
        return tensor.to(self.device).to(target_dtype)


class RemoteEmbeddingBackend:
    """Embedding backend that delegates encoding to a remote client.

    Args:
        client: Object whose ``encode_images`` and ``encode_queries`` methods
            return tensors.
    """

    def __init__(self, client: RemoteEmbeddingClient):
        self.client = client

    def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
        """Encode ``images`` through the client and return the result on the CPU."""
        encoded = self.client.encode_images(images)
        return encoded.cpu()

    def embed_queries(self, queries: List[str]) -> torch.Tensor:
        """Encode ``queries`` through the client and return the result on the CPU."""
        encoded = self.client.encode_queries(queries)
        return encoded.cpu()
Loading