Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions vertexai/_genai/_evals_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def agent_run_wrapper( # type: ignore[no-untyped-def]
contents=contents_arg,
user_simulator_config=user_simulator_config_arg,
agent=agent_arg,
api_client=api_client_arg,
)

future = executor.submit(
Expand Down Expand Up @@ -955,6 +956,7 @@ async def _run_adk_user_simulation(
row: pd.Series,
agent: "LlmAgent", # type: ignore # noqa: F821
config: Optional[types.evals.UserSimulatorConfig] = None,
api_client: Optional[BaseApiClient] = None,
) -> list[dict[str, Any]]:
"""Runs a multi-turn user simulation using ADK's EvaluationGenerator."""
# Lazy-import ADK dependencies to avoid top-level import failures when
Expand Down Expand Up @@ -996,6 +998,23 @@ async def _run_adk_user_simulation(
if config.max_turn is not None:
user_simulator_kwargs["max_allowed_invocations"] = config.max_turn

# When using a Vertex AI client, convert the user simulator model name to
# a full resource path so that ADK's Gemini class automatically uses the
# Vertex AI backend. This removes the need for users to manually set the
# GOOGLE_GENAI_USE_VERTEXAI environment variable.
if (
api_client
and getattr(api_client, "project", None)
and getattr(api_client, "location", None)
):
model_name = user_simulator_kwargs.get("model", "gemini-2.5-flash")
if not model_name.startswith("projects/"):
user_simulator_kwargs["model"] = (
f"projects/{api_client.project}"
f"/locations/{api_client.location}"
f"/publishers/google/models/{model_name}"
)

user_simulator_config = LlmBackedUserSimulatorConfig(**user_simulator_kwargs)
user_simulator = LlmBackedUserSimulator(
conversation_scenario=scenario, config=user_simulator_config
Expand Down Expand Up @@ -2095,11 +2114,12 @@ def _execute_local_agent_run_with_retry(
agent: "LlmAgent", # type: ignore # noqa: F821
max_retries: int = 3,
user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None,
api_client: Optional[BaseApiClient] = None,
) -> Union[list[dict[str, Any]], dict[str, Any]]:
"""Executes agent run locally for a single prompt synchronously."""
return asyncio.run(
_execute_local_agent_run_with_retry_async(
row, contents, agent, max_retries, user_simulator_config
row, contents, agent, max_retries, user_simulator_config, api_client
)
)

Expand All @@ -2110,6 +2130,7 @@ async def _execute_local_agent_run_with_retry_async(
agent: "LlmAgent", # type: ignore # noqa: F821
max_retries: int = 3,
user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None,
api_client: Optional[BaseApiClient] = None,
) -> Union[list[dict[str, Any]], dict[str, Any]]:
"""Executes agent run locally for a single prompt asynchronously."""
# Lazy-import ADK dependencies to avoid top-level import failures when
Expand All @@ -2120,7 +2141,9 @@ async def _execute_local_agent_run_with_retry_async(
# Multi-turn agent scraping with user simulation.
if user_simulator_config or "conversation_plan" in row:
try:
return await _run_adk_user_simulation(row, agent, user_simulator_config)
return await _run_adk_user_simulation(
row, agent, user_simulator_config, api_client
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Multi-turn agent run with user simulation failed: %s", e)
return {"error": f"Multi-turn agent run with user simulation failed: {e}"}
Expand Down
Loading