Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
14 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions charts/model-engine/templates/_helpers.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ env:
- name: LAUNCH_SERVICE_TEMPLATE_FOLDER
value: "/workspace/model-engine/model_engine_server/infra/gateways/resources/templates"
{{- $model_cache := default dict .Values.modelCache }}
{{- $gcp_cloud_provider := and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") }}
- name: MODEL_CACHE_ENABLED
value: {{ get $model_cache "enabled" | default false | quote }}
- name: MODEL_CACHE_MOUNT_PATH
Expand Down Expand Up @@ -404,6 +405,14 @@ env:
- name: SERVICEBUS_NAMESPACE
value: {{ .Values.azure.servicebus_namespace }}
{{- end }}
{{- if $gcp_cloud_provider }}
- name: GCP_PROJECT_ID
value: {{ (.Values.gcp).project_id | default "" | quote }}
- name: PUBSUB_TOPIC_PREFIX
value: {{ (.Values.gcp).pubsub_topic_prefix | default "" | quote }}
- name: PUBSUB_SUBSCRIPTION_PREFIX
value: {{ (.Values.gcp).pubsub_subscription_prefix | default "" | quote }}
{{- end }}
{{- if eq .Values.context "circleci" }}
- name: CIRCLECI
value: "true"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
{{- $tag := .Values.tag }}
{{- $message_broker := .Values.celeryBrokerType }}
{{- $num_shards := .Values.celery_autoscaler.num_shards }}
{{- $gcp_cloud_provider := and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") -}}
{{- $broker_name := "redis-elasticache-message-broker-master" }}
{{- if eq $message_broker "sqs" }}
{{ $broker_name = "sqs-message-broker-master" }}
{{- else if eq $message_broker "servicebus" }}
{{ $broker_name = "servicebus-message-broker-master" }}
{{- else if and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") }}
{{- else if $gcp_cloud_provider }}
{{ $broker_name = "redis-gcp-memorystore-message-broker-master" }}
{{- end }}
apiVersion: apps/v1
Expand Down Expand Up @@ -86,6 +87,14 @@ spec:
- name: SERVICEBUS_NAMESPACE
value: {{ .Values.azure.servicebus_namespace }}
{{- end }}
{{- if $gcp_cloud_provider }}
- name: GCP_PROJECT_ID
value: {{ (.Values.gcp).project_id | default "" | quote }}
- name: PUBSUB_TOPIC_PREFIX
value: {{ (.Values.gcp).pubsub_topic_prefix | default "" | quote }}
- name: PUBSUB_SUBSCRIPTION_PREFIX
value: {{ (.Values.gcp).pubsub_subscription_prefix | default "" | quote }}
{{- end }}
image: "{{ .Values.image.gatewayRepository }}:{{ $tag }}"
imagePullPolicy: Always
name: main
Expand Down
6 changes: 6 additions & 0 deletions charts/model-engine/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,9 @@ utilityImages:

# Additional GPU tolerations for endpoint pods
gpuTolerations: []

# GCP configuration for GCP-based deployments
gcp:
project_id: ""
pubsub_topic_prefix: "launch-endpoint-id-"
pubsub_subscription_prefix: "launch-endpoint-id-"
6 changes: 6 additions & 0 deletions charts/model-engine/values_sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,9 @@ recommendedHardware:
gpu_type: nvidia-hopper-h100
nodes_per_worker: 1
#serviceBuilderQueue:

# GCP configuration for GCP-based deployments
gcp:
project_id: "your-gcp-project"
pubsub_topic_prefix: "launch-endpoint-id-"
pubsub_subscription_prefix: "launch-endpoint-id-"
3 changes: 3 additions & 0 deletions clients/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,8 @@
python_requires=">=3.8",
version="0.0.0.beta45",
packages=find_packages(),
# types-setuptools 82.0.0+ tightened package_data to _DictLike; the literal dict
# still works at runtime, only the new stub disagrees. Suppress at the call site
# rather than down-pinning the stub (which would mask real future tightenings).
package_data={"llmengine": ["py.typed"]}, # type: ignore[arg-type]
)
19 changes: 14 additions & 5 deletions model-engine/model_engine_server/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@
from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import (
FakeQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import (
GcpPubSubQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import (
LiveEndpointResourceGateway,
)
Expand All @@ -104,9 +107,6 @@
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
QueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.redis_queue_endpoint_resource_delegate import (
RedisQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import (
SQSQueueEndpointResourceDelegate,
)
Expand Down Expand Up @@ -248,8 +248,17 @@ def _get_external_interfaces(
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "gcp":
# GCP uses Redis (Memorystore) for Celery, so use Redis-based queue delegate
queue_delegate = RedisQueueEndpointResourceDelegate(redis_client=redis_client)
# Mirror the SQS_PROFILE env-first pattern: the Helm chart injects GCP_PROJECT_ID as a
# pod env var (from .Values.gcp.project_id), which is a different source than the YAML-
# rendered infra_service_config. Read the env first so the chart value reaches the delegate;
# the infra_config.gcp_project_id field handles setups that wire it via the config YAML.
gcp_project_id = os.getenv("GCP_PROJECT_ID") or infra_config().gcp_project_id
if not gcp_project_id:
raise ValueError(
"cloud_provider=gcp requires GCP_PROJECT_ID env var "
"(via .Values.gcp.project_id) or infra.gcp_project_id in the service config."
)
queue_delegate = GcpPubSubQueueEndpointResourceDelegate(project_id=gcp_project_id)
else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
Expand Down
1 change: 1 addition & 0 deletions model-engine/model_engine_server/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class _InfraConfig:
celery_enable_sha256: Optional[bool] = None
docker_registry_type: Optional[str] = None
debug_mode: Optional[bool] = None
gcp_project_id: Optional[str] = None


@dataclass
Expand Down
13 changes: 13 additions & 0 deletions model-engine/model_engine_server/entrypoints/k8s_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import (
FakeQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import (
GcpPubSubQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.image_cache_gateway import ImageCacheGateway
from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import (
LiveEndpointResourceGateway,
Expand Down Expand Up @@ -119,6 +122,16 @@ async def main(args: Any):
queue_delegate = OnPremQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "gcp":
# See dependencies.py for rationale: Helm injects GCP_PROJECT_ID as a pod env var;
# the infra_service_config YAML is a different source. Read the env first.
gcp_project_id = os.getenv("GCP_PROJECT_ID") or infra_config().gcp_project_id
if not gcp_project_id:
raise ValueError(
"cloud_provider=gcp requires GCP_PROJECT_ID env var "
"(via .Values.gcp.project_id) or infra.gcp_project_id in the service config."
)
queue_delegate = GcpPubSubQueueEndpointResourceDelegate(project_id=gcp_project_id)
else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import (
FakeQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import (
GcpPubSubQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import (
LiveEndpointResourceGateway,
)
Expand Down Expand Up @@ -90,6 +93,16 @@ async def run_batch_job(
queue_delegate = OnPremQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "gcp":
# See dependencies.py for rationale: Helm injects GCP_PROJECT_ID as a pod env var;
# the infra_service_config YAML is a different source. Read the env first.
gcp_project_id = os.getenv("GCP_PROJECT_ID") or infra_config().gcp_project_id
if not gcp_project_id:
raise ValueError(
"cloud_provider=gcp requires GCP_PROJECT_ID env var "
"(via .Values.gcp.project_id) or infra.gcp_project_id in the service config."
)
queue_delegate = GcpPubSubQueueEndpointResourceDelegate(project_id=gcp_project_id)
else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
Expand All @@ -110,6 +123,9 @@ async def run_batch_job(
if infra_config().cloud_provider == "azure":
inference_task_queue_gateway = servicebus_task_queue_gateway
infra_task_queue_gateway = servicebus_task_queue_gateway
elif infra_config().cloud_provider == "gcp":
inference_task_queue_gateway = redis_task_queue_gateway
infra_task_queue_gateway = redis_task_queue_gateway
elif infra_config().cloud_provider == "onprem" or infra_config().celery_broker_type_redis:
# On-prem uses Redis-based task queues
inference_task_queue_gateway = redis_task_queue_gateway
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from typing import Any, Dict, Optional

from google.api_core import exceptions as gcp_exceptions
from google.cloud import pubsub_v1
from google.protobuf import field_mask_pb2
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import EndpointResourceInfraException
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
QueueEndpointResourceDelegate,
QueueInfo,
)

# Module-level logger shared by everything in this file.
logger = make_logger(logger_name())

# Largest ack_deadline_seconds value Pub/Sub accepts for a subscription; requested
# message timeouts above this are capped before being sent to the API.
GCP_PUBSUB_MAX_ACK_DEADLINE_SECONDS = 600  # Pub/Sub hard limit


class GcpPubSubQueueEndpointResourceDelegate(QueueEndpointResourceDelegate):
    """
    Queue delegate backed by GCP Pub/Sub (one topic + one subscription per endpoint).

    topic_prefix and subscription_prefix control the GCP resource name prefix.
    The logical queue_name returned to callers always uses the canonical
    QueueEndpointResourceDelegate.endpoint_id_to_queue_name format, independent
    of these prefixes.
    """

    # Pub/Sub only accepts ack_deadline_seconds in the 10-600 second range. The ceiling
    # is the module-level GCP_PUBSUB_MAX_ACK_DEADLINE_SECONDS; this is the floor.
    _MIN_ACK_DEADLINE_SECONDS = 10
    # Ack deadline used when the caller supplies no (or a zero) message timeout.
    _DEFAULT_ACK_DEADLINE_SECONDS = 60

    def __init__(
        self,
        project_id: str,
        topic_prefix: str = "launch-endpoint-id-",
        subscription_prefix: str = "launch-endpoint-id-",
    ) -> None:
        """
        Store the project and naming configuration; no GCP call is made here.

        Args:
            project_id: GCP project that owns the topics and subscriptions.
            topic_prefix: Prefix prepended to the endpoint id to form the topic id.
            subscription_prefix: Prefix prepended to the endpoint id to form the
                subscription id.

        Raises:
            ValueError: If project_id is empty.
        """
        if not project_id:
            raise ValueError(
                "GcpPubSubQueueEndpointResourceDelegate requires a non-empty project_id; "
                "set infra.gcp_project_id in the service config."
            )
        self.project_id = project_id
        self.topic_prefix = topic_prefix
        self.subscription_prefix = subscription_prefix
        # Lazily-initialized gRPC clients. Construction calls Google ADC which is
        # unavailable in unit-test environments, so defer until first real use.
        # The clients are then cached for the lifetime of the delegate.
        self._publisher_client: Optional[pubsub_v1.PublisherClient] = None
        self._subscriber_client: Optional[pubsub_v1.SubscriberClient] = None

    @property
    def _publisher(self) -> pubsub_v1.PublisherClient:
        """Return the cached publisher client, creating it on first access."""
        if self._publisher_client is None:
            self._publisher_client = pubsub_v1.PublisherClient()
        return self._publisher_client

    @property
    def _subscriber(self) -> pubsub_v1.SubscriberClient:
        """Return the cached subscriber client, creating it on first access."""
        if self._subscriber_client is None:
            self._subscriber_client = pubsub_v1.SubscriberClient()
        return self._subscriber_client

    def _topic_id(self, endpoint_id: str) -> str:
        """Return the Pub/Sub topic id (prefix + endpoint id) for an endpoint."""
        return f"{self.topic_prefix}{endpoint_id}"

    def _subscription_id(self, endpoint_id: str) -> str:
        """Return the Pub/Sub subscription id (prefix + endpoint id) for an endpoint."""
        return f"{self.subscription_prefix}{endpoint_id}"

    def _topic_path(self, endpoint_id: str) -> str:
        """Return the fully-qualified topic resource path for an endpoint."""
        return f"projects/{self.project_id}/topics/{self._topic_id(endpoint_id)}"

    def _subscription_path(self, endpoint_id: str) -> str:
        """Return the fully-qualified subscription resource path for an endpoint."""
        return f"projects/{self.project_id}/subscriptions/{self._subscription_id(endpoint_id)}"

    @classmethod
    def _clamp_ack_deadline(cls, queue_message_timeout_seconds: Optional[int]) -> int:
        """Map a requested message timeout onto the ack deadline range Pub/Sub accepts."""
        requested = queue_message_timeout_seconds or cls._DEFAULT_ACK_DEADLINE_SECONDS
        capped = min(requested, GCP_PUBSUB_MAX_ACK_DEADLINE_SECONDS)
        # Without the floor, timeouts of 1-9 seconds would be forwarded to the API and
        # rejected as an invalid argument.
        return max(cls._MIN_ACK_DEADLINE_SECONDS, capped)

    async def create_queue_if_not_exists(
        self,
        endpoint_id: str,
        endpoint_name: str,
        endpoint_created_by: str,
        endpoint_labels: Dict[str, Any],
        queue_message_timeout_seconds: Optional[int] = None,
    ) -> QueueInfo:
        """
        Ensure the topic and subscription for an endpoint exist.

        An already-existing topic is left untouched. An already-existing subscription
        gets its ack deadline updated to the newly computed value; a failure of that
        update is logged as a warning rather than raised.

        Args:
            endpoint_id: Endpoint whose topic/subscription should exist.
            endpoint_name: Not used by this delegate.
            endpoint_created_by: Not used by this delegate.
            endpoint_labels: Not used by this delegate.
            queue_message_timeout_seconds: Requested message timeout. It is used as the
                subscription's ack deadline after being clamped to the 10-600 second
                range; None or 0 falls back to 60 seconds.

        Returns:
            QueueInfo carrying the canonical queue name and a queue_url of None.
        """
        queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
        topic_path = self._topic_path(endpoint_id)
        subscription_path = self._subscription_path(endpoint_id)
        ack_deadline = self._clamp_ack_deadline(queue_message_timeout_seconds)

        try:
            self._publisher.create_topic(name=topic_path)
        except gcp_exceptions.AlreadyExists:
            pass

        try:
            self._subscriber.create_subscription(
                name=subscription_path,
                topic=topic_path,
                ack_deadline_seconds=ack_deadline,
            )
        except gcp_exceptions.AlreadyExists:
            # The subscription is already there; bring its ack deadline in line with the
            # requested timeout. Only the masked field is updated.
            try:
                self._subscriber.update_subscription(
                    subscription=pubsub_v1.types.Subscription(
                        name=subscription_path,
                        ack_deadline_seconds=ack_deadline,
                    ),
                    update_mask=field_mask_pb2.FieldMask(paths=["ack_deadline_seconds"]),
                )
            except gcp_exceptions.GoogleAPIError as e:
                logger.warning(
                    f"Failed to update ack_deadline for Pub/Sub subscription {subscription_path}: {e}"
                )

        # Pub/Sub has no URL concept analogous to SQS queue URLs
        return QueueInfo(queue_name, queue_url=None)

    async def delete_queue(self, endpoint_id: str) -> None:
        """
        Delete the subscription and the topic belonging to an endpoint.

        Both deletions are always attempted. A missing resource is logged at info level
        and otherwise ignored.

        Raises:
            EndpointResourceInfraException: If either deletion fails with a
                GoogleAPIError other than NotFound. The message lists every failure and
                the exception is chained from the first one.
        """
        subscription_path = self._subscription_path(endpoint_id)
        topic_path = self._topic_path(endpoint_id)

        # Always attempt BOTH deletions so a failure on one doesn't leave the other resource
        # orphaned. NotFound is silent. Other GoogleAPIErrors are collected and surfaced
        # together at the end so callers see every cleanup failure, not just the first.
        errors: list[tuple[str, str, gcp_exceptions.GoogleAPIError]] = []

        try:
            self._subscriber.delete_subscription(subscription=subscription_path)
        except gcp_exceptions.NotFound:
            logger.info(
                f"Could not find Pub/Sub subscription {subscription_path} for endpoint {endpoint_id}"
            )
        except gcp_exceptions.GoogleAPIError as e:
            errors.append(("subscription", subscription_path, e))

        try:
            self._publisher.delete_topic(topic=topic_path)
        except gcp_exceptions.NotFound:
            logger.info(f"Could not find Pub/Sub topic {topic_path} for endpoint {endpoint_id}")
        except gcp_exceptions.GoogleAPIError as e:
            errors.append(("topic", topic_path, e))

        if errors:
            details = "; ".join(
                f"Failed to delete Pub/Sub {kind} {path}: {err}" for kind, path, err in errors
            )
            raise EndpointResourceInfraException(
                f"Cleanup errors for endpoint {endpoint_id}: {details}"
            ) from errors[0][2]

    async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]:
        """
        Return the attributes this delegate can report for an endpoint's queue.

        No Pub/Sub call is made. The result holds the canonical queue name and a
        num_undelivered_messages of -1, which stands for "unknown".
        """
        queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
        return {
            "name": queue_name,
            # Pub/Sub does not expose a synchronous undelivered message count;
            # real observability requires the Cloud Monitoring API as a separate concern.
            "num_undelivered_messages": -1,
        }
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ async def get_resources(
)
elif "active_message_count" in sqs_attributes: # from ASBQueueEndpointResourceDelegate
resources.num_queued_items = int(sqs_attributes["active_message_count"])
elif (
"num_undelivered_messages" in sqs_attributes
): # from GcpPubSubQueueEndpointResourceDelegate
# Pub/Sub returns -1 when num_undelivered_messages is not yet wired to Cloud Monitoring.
# Treat -1 as "unknown" and skip; downstream autoscaling expects non-negative counts.
gcp_count = int(sqs_attributes["num_undelivered_messages"])
if gcp_count >= 0:
resources.num_queued_items = gcp_count

return resources

Expand Down
6 changes: 6 additions & 0 deletions model-engine/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ azure-storage-blob~=12.19.0
# GCP dependencies
gcloud-aio-storage~=9.6
google-auth~=2.25.0
google-cloud-pubsub>=2.18
# google-cloud-pubsub transitively pulls opentelemetry-sdk, which flips
# common/startup_tracing/correlation.py's OTEL_AVAILABLE to True. Once that's
# True, tracer.py imports from opentelemetry-exporter-otlp-proto-grpc, which
# isn't otherwise a dependency. List it explicitly (unpinned) so the import resolves.
opentelemetry-exporter-otlp-proto-grpc
google-cloud-artifact-registry~=1.21.0
google-cloud-secret-manager>=2.24.0
google-cloud-storage~=2.14.0
Expand Down
Loading