Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/trigger_files/beam_PreCommit_Python_ML.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"revision": 3
"revision": 6
}
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def __getitem__(self, key):


@pytest.mark.require_docker_in_docker
@pytest.mark.no_xdist
@unittest.skipUnless(
platform.system() == "Linux",
"Test runs only on Linux due to lack of support, as yet, for nested "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def drop_collection(client: MilvusClient, collection_name: str):


@pytest.mark.require_docker_in_docker
@pytest.mark.no_xdist
@unittest.skipIf(not PYMILVUS_AVAILABLE, 'pymilvus is not installed.')
@unittest.skipUnless(
platform.system() == "Linux",
Expand Down
75 changes: 23 additions & 52 deletions sdks/python/apache_beam/ml/rag/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import contextlib
import logging
import os
import socket
import tempfile
import unittest
from dataclasses import dataclass
Expand Down Expand Up @@ -47,6 +46,10 @@

_LOGGER = logging.getLogger(__name__)

# Milvus standalone defaults (match testcontainers MilvusContainer).
_MILVUS_SERVICE_PORT = 19530
_MILVUS_METRICS_PORT = 9091


@dataclass
class VectorDBContainerInfo:
Expand All @@ -68,58 +71,26 @@ def uri(self) -> str:
return f"http://{self.host}:{self.port}"


class TestHelpers:
@staticmethod
def find_free_port():
"""Find a free port on the local machine."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
# Bind to port 0, which asks OS to assign a free port.
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# Return the port number assigned by OS.
return s.getsockname()[1]


class CustomMilvusContainer(MilvusContainer):
"""Custom Milvus container with configurable ports and environment setup.

Extends MilvusContainer to provide custom port binding and environment
configuration for testing with standalone Milvus instances.
"""
"""Milvus container with user.yaml volume for integration test configuration."""

def __init__( # pylint: disable=bad-super-call
self,
image: str,
service_container_port,
healthcheck_container_port,
**kwargs,
) -> None:
# Skip the parent class's constructor and go straight to
# GenericContainer.
super(
MilvusContainer,
self,
).__init__(
image=image, **kwargs)
self.port = service_container_port
self.healthcheck_port = healthcheck_container_port
self.with_exposed_ports(service_container_port, healthcheck_container_port)

# Get free host ports.
service_host_port = TestHelpers.find_free_port()
healthcheck_host_port = TestHelpers.find_free_port()

# Bind container and host ports.
self.with_bind_ports(service_container_port, service_host_port)
self.with_bind_ports(healthcheck_container_port, healthcheck_host_port)
super(MilvusContainer, self).__init__(image=image, **kwargs)
self.port = _MILVUS_SERVICE_PORT
self.healthcheck_port = _MILVUS_METRICS_PORT
self.with_exposed_ports(self.port, self.healthcheck_port)
self.cmd = "milvus run standalone"
self.with_command(self.cmd)

# Set environment variables needed for Milvus.
envs = {
"ETCD_USE_EMBED": "true",
"ETCD_DATA_DIR": "/var/lib/milvus/etcd",
"COMMON_STORAGETYPE": "local",
"METRICS_PORT": str(healthcheck_container_port)
"METRICS_PORT": str(_MILVUS_METRICS_PORT),
}
for env, value in envs.items():
self.with_env(env, value)
Expand All @@ -139,9 +110,11 @@ class MilvusTestHelpers:
# Example: Milvus v2.6.0 requires pymilvus==2.6.0 (exact match required).
@staticmethod
def _wait_for_milvus_grpc(uri: str) -> None:
"""Wait until Milvus accepts RPCs.
"""Wait until Milvus gRPC proxy accepts connections.

Docker may report started before gRPC is ready.
MilvusContainer.start() only health-checks the metrics HTTP port; the gRPC
proxy can become ready later. Use the same bounded retry budget as other
Milvus client setup in this module (well under the pytest 600s limit).
"""
def list_collections_probe():
client = MilvusClient(uri=uri)
Expand All @@ -152,9 +125,8 @@ def list_collections_probe():

retry_with_backoff(
list_collections_probe,
max_retries=25,
max_retries=5,
retry_delay=2.0,
retry_backoff_factor=1.2,
operation_name="Milvus client connection after container start",
exception_types=(MilvusException, ))

Expand All @@ -164,29 +136,24 @@ def start_db_container(
max_vec_fields=5,
vector_client_max_retries=3,
tc_max_retries=None) -> Optional[VectorDBContainerInfo]:
service_container_port = TestHelpers.find_free_port()
healthcheck_container_port = TestHelpers.find_free_port()
user_yaml_creator = MilvusTestHelpers.create_user_yaml
with user_yaml_creator(service_container_port, max_vec_fields) as cfg:
with user_yaml_creator(_MILVUS_SERVICE_PORT, max_vec_fields) as cfg:
info = None
original_tc_max_tries = testcontainers_config.max_tries
if tc_max_retries is not None:
testcontainers_config.max_tries = tc_max_retries
for i in range(vector_client_max_retries):
vector_db_container: Optional[CustomMilvusContainer] = None
try:
vector_db_container = CustomMilvusContainer(
image=image,
service_container_port=service_container_port,
healthcheck_container_port=healthcheck_container_port)
vector_db_container = CustomMilvusContainer(image=image)
mapped_container = vector_db_container.with_volume_mapping(
cfg, "/milvus/configs/user.yaml")
assert mapped_container is not None
running_container: CustomMilvusContainer = mapped_container
vector_db_container = running_container
running_container.start()
host = running_container.get_container_host_ip()
port = running_container.get_exposed_port(service_container_port)
port = running_container.get_exposed_port(_MILVUS_SERVICE_PORT)
info = VectorDBContainerInfo(running_container, host, port)
MilvusTestHelpers._wait_for_milvus_grpc(info.uri)
_LOGGER.info(
Expand All @@ -198,6 +165,10 @@ def start_db_container(
raw_out, raw_err = vector_db_container.get_logs()
stdout_logs = raw_out.decode("utf-8")
stderr_logs = raw_err.decode("utf-8")
try:
vector_db_container.stop()
except Exception: # pylint: disable=broad-except
pass
_LOGGER.warning(
"Retry %d/%d: Failed to start Milvus DB container. Reason: %s. "
"STDOUT logs:\n%s\nSTDERR logs:\n%s",
Expand Down
Loading