diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 91fb568cd3..ae94c141e2 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -291,9 +291,13 @@ async def run_async( if ctx.end_invocation: return - async with Aclosing(self._run_async_impl(ctx)) as agen: - async for event in agen: - yield event + try: + async with Aclosing(self._run_async_impl(ctx)) as agen: + async for event in agen: + yield event + except Exception as e: + await self._handle_agent_error_callback(ctx, e) + raise if ctx.end_invocation: return @@ -323,9 +327,13 @@ async def run_live( if ctx.end_invocation: return - async with Aclosing(self._run_live_impl(ctx)) as agen: - async for event in agen: - yield event + try: + async with Aclosing(self._run_live_impl(ctx)) as agen: + async for event in agen: + yield event + except Exception as e: + await self._handle_agent_error_callback(ctx, e) + raise if event := await self._handle_after_agent_callback(ctx): yield event @@ -545,6 +553,27 @@ async def _handle_after_agent_callback( ) return None + async def _handle_agent_error_callback( + self, + invocation_context: InvocationContext, + error: Exception, + ) -> None: + """Runs the on_agent_error_callback for all plugins. + + This is notification-only: the exception is always re-raised by + the caller after this method returns. + + Args: + invocation_context: The invocation context for this agent. + error: The exception that escaped agent execution. 
+ """ + callback_context = CallbackContext(invocation_context) + await invocation_context.plugin_manager.run_on_agent_error_callback( + agent=self, + callback_context=callback_context, + error=error, + ) + @override def model_post_init(self, __context: Any) -> None: self.__set_parent_agent_for_sub_agents() diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py index 54bfab2ed2..50ce806623 100644 --- a/src/google/adk/plugins/base_plugin.py +++ b/src/google/adk/plugins/base_plugin.py @@ -370,3 +370,41 @@ async def on_tool_error_callback( allows the original error to be raised. """ pass + + async def on_agent_error_callback( + self, + *, + agent: BaseAgent, + callback_context: CallbackContext, + error: Exception, + ) -> None: + """Callback executed when an unhandled exception escapes agent execution. + + This is a notification-only callback. The exception is always re-raised + after all registered plugins have been notified. Plugins should NOT + suppress the exception. + + Args: + agent: The agent instance that encountered the error. + callback_context: The callback context for the agent invocation. + error: The exception that was raised during agent execution. + """ + pass + + async def on_run_error_callback( + self, + *, + invocation_context: InvocationContext, + error: Exception, + ) -> None: + """Callback executed when an unhandled exception escapes runner execution. + + This is a notification-only callback. The exception is always re-raised + after all registered plugins have been notified. Plugins should NOT + suppress the exception. + + Args: + invocation_context: The context for the entire invocation. + error: The exception that was raised during runner execution. 
+ """ + pass diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 50eb72ffdb..0172f30994 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -28,6 +28,7 @@ import logging import mimetypes import os +import traceback as traceback_module # Enable gRPC fork support so child processes created via os.fork() # can safely create new gRPC channels. Must be set before grpc's @@ -1870,8 +1871,15 @@ def _get_events_schema() -> list[bigquery.SchemaField]: "AGENT_COMPLETED": [ "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", ], + "AGENT_ERROR": [ + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + "JSON_VALUE(content, '$.error_traceback') AS error_traceback", + ], "INVOCATION_STARTING": [], "INVOCATION_COMPLETED": [], + "INVOCATION_ERROR": [ + "JSON_VALUE(content, '$.error_traceback') AS error_traceback", + ], "STATE_DELTA": [ "JSON_QUERY(attributes, '$.state_delta') AS state_delta", ], @@ -3505,3 +3513,98 @@ async def on_tool_error_callback( parent_span_id_override=parent_span_id, ), ) + + @_safe_callback + async def on_agent_error_callback( + self, + *, + agent: Any, + callback_context: CallbackContext, + error: Exception, + ) -> None: + """Callback when an agent execution fails with an unhandled exception. + + Emits an AGENT_ERROR event and pops the agent span from + TraceManager. + + Args: + agent: The agent instance that failed. + callback_context: The callback context. + error: The exception that escaped agent execution. 
+ """ + span_id, duration = TraceManager.pop_span() + parent_span_id, _ = TraceManager.get_current_span_and_parent() + + error_tb = "".join( + traceback_module.format_exception( + type(error), error, error.__traceback__ + ) + ) + max_len = self.config.max_content_length + if max_len > 0 and len(error_tb) > max_len: + error_tb = error_tb[:max_len] + "... [truncated]" + + await self._log_event( + "AGENT_ERROR", + callback_context, + event_data=EventData( + status="ERROR", + error_message=str(error), + latency_ms=duration, + span_id_override=span_id, + parent_span_id_override=parent_span_id, + ), + raw_content={"error_traceback": error_tb}, + ) + + @_safe_callback + async def on_run_error_callback( + self, + *, + invocation_context: "InvocationContext", + error: Exception, + ) -> None: + """Callback when a runner execution fails with an unhandled exception. + + Emits an INVOCATION_ERROR event and performs the cleanup that + after_run_callback would normally do. + + Args: + invocation_context: The context of the current invocation. + error: The exception that escaped runner execution. + """ + try: + callback_ctx = CallbackContext(invocation_context) + trace_id = TraceManager.get_trace_id(callback_ctx) + + span_id, duration = TraceManager.pop_span() + parent_span_id = TraceManager.get_current_span_id() + + error_tb = "".join( + traceback_module.format_exception( + type(error), error, error.__traceback__ + ) + ) + max_len = self.config.max_content_length + if max_len > 0 and len(error_tb) > max_len: + error_tb = error_tb[:max_len] + "... [truncated]" + + await self._log_event( + "INVOCATION_ERROR", + callback_ctx, + event_data=EventData( + trace_id_override=trace_id, + status="ERROR", + error_message=str(error), + latency_ms=duration, + span_id_override=span_id, + parent_span_id_override=parent_span_id, + ), + raw_content={"error_traceback": error_tb}, + ) + finally: + # Cleanup must run even if _log_event raises. 
+ TraceManager.clear_stack() + _active_invocation_id_ctx.set(None) + _root_agent_name_ctx.set(None) + await self.flush() diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index c781e8fa4e..1f6c203a5b 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -52,6 +52,8 @@ "after_model_callback", "on_tool_error_callback", "on_model_error_callback", + "on_agent_error_callback", + "on_run_error_callback", ] logger = logging.getLogger("google_adk." + __name__) @@ -306,6 +308,61 @@ async def _run_callbacks( return None + async def run_on_agent_error_callback( + self, + *, + agent: BaseAgent, + callback_context: CallbackContext, + error: Exception, + ) -> None: + """Runs the `on_agent_error_callback` for all plugins.""" + await self._run_notification_callbacks( + "on_agent_error_callback", + agent=agent, + callback_context=callback_context, + error=error, + ) + + async def run_on_run_error_callback( + self, + *, + invocation_context: InvocationContext, + error: Exception, + ) -> None: + """Runs the `on_run_error_callback` for all plugins.""" + await self._run_notification_callbacks( + "on_run_error_callback", + invocation_context=invocation_context, + error=error, + ) + + async def _run_notification_callbacks( + self, callback_name: PluginCallbackName, **kwargs: Any + ) -> None: + """Executes a notification-only callback for all registered plugins. + + Unlike ``_run_callbacks``, this method is best-effort: it always + iterates all plugins regardless of return values or exceptions. + If a plugin's callback raises, the error is logged and iteration + continues so that every plugin gets notified. + + Args: + callback_name: The name of the callback method to execute. + **kwargs: Keyword arguments to be passed to the callback method. 
+ """ + for plugin in self.plugins: + callback_method = getattr(plugin, callback_name) + try: + await callback_method(**kwargs) + except Exception as e: + logger.error( + "Error in plugin '%s' during '%s' callback: %s", + plugin.name, + callback_name, + e, + exc_info=True, + ) + async def close(self) -> None: """Calls the close method on all registered plugins concurrently. diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 0f36d6389d..8ec71837e7 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -851,115 +851,108 @@ async def _exec_with_plugin( plugin_manager = invocation_context.plugin_manager - # Step 1: Run the before_run callbacks to see if we should early exit. - early_exit_result = await plugin_manager.run_before_run_callback( - invocation_context=invocation_context - ) - if isinstance(early_exit_result, types.Content): - early_exit_event = Event( - invocation_id=invocation_context.invocation_id, - author='model', - content=early_exit_result, - ) - _apply_run_config_custom_metadata( - early_exit_event, invocation_context.run_config + try: + # Step 1: Run the before_run callbacks to see if we should + # early exit. + early_exit_result = await plugin_manager.run_before_run_callback( + invocation_context=invocation_context ) - if self._should_append_event(early_exit_event, is_live_call): - await self.session_service.append_event( - session=session, - event=early_exit_event, + if isinstance(early_exit_result, types.Content): + early_exit_event = Event( + invocation_id=invocation_context.invocation_id, + author='model', + content=early_exit_result, ) - yield early_exit_event - else: - # Step 2: Otherwise continue with normal execution - # Note for live/bidi: - # the transcription may arrive later than the action(function call - # event and thus function response event). In this case, the order of - # transcription and function call event will be wrong if we just - # append as it arrives. 
To address this, we should check if there is - # transcription going on. If there is transcription going on, we - # should hold on appending the function call event until the - # transcription is finished. The transcription in progress can be - # identified by checking if the transcription event is partial. When - # the next transcription event is not partial, it means the previous - # transcription is finished. Then if there is any buffered function - # call event, we should append them after this finished(non-partial) - # transcription event. - buffered_events: list[Event] = [] - is_transcribing: bool = False - - async with Aclosing(execute_fn(invocation_context)) as agen: - async for event in agen: - _apply_run_config_custom_metadata( - event, invocation_context.run_config - ) - # Step 3: Run the on_event callbacks before persisting so callback - # changes are stored in the session and match the streamed event. - modified_event = await plugin_manager.run_on_event_callback( - invocation_context=invocation_context, event=event - ) - output_event = self._get_output_event( - original_event=event, - modified_event=modified_event, - run_config=invocation_context.run_config, + _apply_run_config_custom_metadata( + early_exit_event, invocation_context.run_config + ) + if self._should_append_event(early_exit_event, is_live_call): + await self.session_service.append_event( + session=session, + event=early_exit_event, ) + yield early_exit_event + else: + # Step 2: Otherwise continue with normal execution + buffered_events: list[Event] = [] + is_transcribing: bool = False - if is_live_call: - if event.partial and _is_transcription(event): - is_transcribing = True - if is_transcribing and _is_tool_call_or_response(event): - # only buffer function call and function response event which is - # non-partial - buffered_events.append(output_event) - continue - # Note for live/bidi: for audio response, it's considered as - # non-partial event(event.partial=None) - # 
event.partial=False and event.partial=None are considered as - # non-partial event; event.partial=True is considered as partial - # event. - if event.partial is not True: - if _is_transcription(event) and ( - _has_non_empty_transcription_text(event.input_transcription) - or _has_non_empty_transcription_text( - event.output_transcription - ) - ): - # transcription end signal, append buffered events - is_transcribing = False - logger.debug( - 'Appending transcription finished event: %s', event - ) - if self._should_append_event(event, is_live_call): - await self.session_service.append_event( - session=session, event=output_event - ) + async with Aclosing(execute_fn(invocation_context)) as agen: + async for event in agen: + _apply_run_config_custom_metadata( + event, invocation_context.run_config + ) + # Step 3: Run the on_event callbacks before persisting + # so callback changes are stored in the session and + # match the streamed event. + modified_event = await plugin_manager.run_on_event_callback( + invocation_context=invocation_context, event=event + ) + output_event = self._get_output_event( + original_event=event, + modified_event=modified_event, + run_config=invocation_context.run_config, + ) - for buffered_event in buffered_events: - logger.debug('Appending buffered event: %s', buffered_event) - await self.session_service.append_event( - session=session, event=buffered_event + if is_live_call: + if event.partial and _is_transcription(event): + is_transcribing = True + if is_transcribing and _is_tool_call_or_response(event): + buffered_events.append(output_event) + continue + if event.partial is not True: + if _is_transcription(event) and ( + _has_non_empty_transcription_text(event.input_transcription) + or _has_non_empty_transcription_text( + event.output_transcription + ) + ): + is_transcribing = False + logger.debug( + 'Appending transcription finished event: %s', + event, ) - yield buffered_event # yield buffered events to caller - buffered_events = [] - 
else: - # non-transcription event or empty transcription event, for - # example, event that stores blob reference, should be appended. - if self._should_append_event(event, is_live_call): - logger.debug('Appending non-buffered event: %s', event) - await self.session_service.append_event( - session=session, event=output_event - ) - else: - if event.partial is not True: - await self.session_service.append_event( - session=session, event=output_event - ) + if self._should_append_event(event, is_live_call): + await self.session_service.append_event( + session=session, event=output_event + ) + + for buffered_event in buffered_events: + logger.debug( + 'Appending buffered event: %s', + buffered_event, + ) + await self.session_service.append_event( + session=session, event=buffered_event + ) + yield buffered_event + buffered_events = [] + else: + if self._should_append_event(event, is_live_call): + logger.debug('Appending non-buffered event: %s', event) + await self.session_service.append_event( + session=session, event=output_event + ) + else: + if event.partial is not True: + await self.session_service.append_event( + session=session, event=output_event + ) - yield output_event + yield output_event + except Exception as e: + # Notify plugins of the unhandled execution error. Covers + # failures in before_run_callback, early-exit, and the main + # execution loop. Notification-only; always re-raised. + await plugin_manager.run_on_run_error_callback( + invocation_context=invocation_context, + error=e, + ) + raise - # Step 4: Run the after_run callbacks to perform global cleanup tasks or - # finalizing logs and metrics data. - # This does NOT emit any event. + # Step 4: Run the after_run callbacks to perform global cleanup + # tasks or finalizing logs and metrics data. + # This does NOT emit any event. Only runs on success. 
await plugin_manager.run_after_run_callback( invocation_context=invocation_context ) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 8a05392bec..7239e4adb2 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -1807,6 +1807,144 @@ async def test_on_tool_error_callback_logs_correctly( assert log_entry["error_message"] == "Tool timed out" assert log_entry["status"] == "ERROR" + @pytest.mark.asyncio + async def test_on_agent_error_callback_logs_correctly( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """on_agent_error_callback emits AGENT_ERROR with traceback.""" + error = RuntimeError("Agent crashed") + try: + raise error + except RuntimeError: + pass # populate __traceback__ + bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context) + await bq_plugin_inst.on_agent_error_callback( + agent=mock_agent, + callback_context=callback_context, + error=error, + ) + await asyncio.sleep(0.05) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + log_entry = next(r for r in rows if r["event_type"] == "AGENT_ERROR") + assert log_entry["error_message"] == "Agent crashed" + assert log_entry["status"] == "ERROR" + content = json.loads(log_entry["content"]) + assert "error_traceback" in content + assert "RuntimeError: Agent crashed" in content["error_traceback"] + + @pytest.mark.asyncio + async def test_on_run_error_callback_logs_correctly( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """on_run_error_callback emits INVOCATION_ERROR with traceback.""" + error = ValueError("Invocation failed") + try: + raise error + except ValueError: + pass + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await 
bq_plugin_inst.on_run_error_callback( + invocation_context=invocation_context, + error=error, + ) + await asyncio.sleep(0.05) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + log_entry = next(r for r in rows if r["event_type"] == "INVOCATION_ERROR") + assert log_entry["error_message"] == "Invocation failed" + assert log_entry["status"] == "ERROR" + content = json.loads(log_entry["content"]) + assert "error_traceback" in content + assert "ValueError: Invocation failed" in content["error_traceback"] + + @pytest.mark.asyncio + async def test_on_run_error_callback_cleanup_runs_on_log_failure( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + ): + """on_run_error_callback cleans up even when _log_event raises.""" + # Push spans and set context vars to simulate active invocation + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set("test-inv") + bigquery_agent_analytics_plugin._root_agent_name_ctx.set("test-agent") + + # Make _log_event raise + with mock.patch.object( + bq_plugin_inst, "_log_event", side_effect=RuntimeError("boom") + ): + # @_safe_callback swallows the exception + await bq_plugin_inst.on_run_error_callback( + invocation_context=invocation_context, + error=ValueError("app error"), + ) + + # finally block must have cleaned up + assert ( + bigquery_agent_analytics_plugin._active_invocation_id_ctx.get(None) + is None + ) + assert ( + bigquery_agent_analytics_plugin._root_agent_name_ctx.get(None) is None + ) + + @pytest.mark.asyncio + async def test_traceback_not_truncated_with_negative_max_len( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + mock_asyncio_to_thread, + invocation_context, + mock_agent, + dummy_arrow_schema, + ): + """Traceback is not truncated when max_content_length is -1.""" + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + 
max_content_length=-1, + create_views=False, + ) + async with managed_plugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) as plugin: + await plugin._ensure_started() + + error = RuntimeError("x" * 2000) + try: + raise error + except RuntimeError: + pass + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await plugin.on_agent_error_callback( + agent=mock_agent, + callback_context=bigquery_agent_analytics_plugin.CallbackContext( + invocation_context + ), + error=error, + ) + await asyncio.sleep(0.05) + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + log_entry = next(r for r in rows if r["event_type"] == "AGENT_ERROR") + content = json.loads(log_entry["content"]) + # Should NOT be truncated + assert "[truncated]" not in content["error_traceback"] + assert "x" * 2000 in content["error_traceback"] + @pytest.mark.asyncio async def test_table_creation_options( self, @@ -5733,6 +5871,27 @@ def test_view_sql_contains_correct_event_filter(self): view_name = "v_" + event_type.lower() assert view_name in all_sql, f"View {view_name} not found in SQL" + def test_error_views_contain_traceback_column(self): + """AGENT_ERROR and INVOCATION_ERROR views include error_traceback.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + calls = plugin.client.query.call_args_list + all_sqls = {c[0][0] for c in calls} + + agent_error_sqls = [s for s in all_sqls if "v_agent_error" in s] + assert len(agent_error_sqls) == 1 + assert "error_traceback" in agent_error_sqls[0] + assert "total_ms" in agent_error_sqls[0] + + inv_error_sqls = [s for s in all_sqls if "v_invocation_error" in s] + assert len(inv_error_sqls) == 1 + assert "error_traceback" in inv_error_sqls[0] + def 
test_config_create_views_default_true(self): """Config create_views defaults to True.""" config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() diff --git a/tests/unittests/plugins/test_error_callbacks.py b/tests/unittests/plugins/test_error_callbacks.py new file mode 100644 index 0000000000..17110abe8a --- /dev/null +++ b/tests/unittests/plugins/test_error_callbacks.py @@ -0,0 +1,612 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for on_agent_error_callback and on_run_error_callback. + +Validates RFC #5044: agent-level and runner-level error callbacks. 
+""" + +import asyncio +from typing import AsyncGenerator +from typing import Optional +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.plugins.plugin_manager import PluginManager +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +import pytest +from typing_extensions import override + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _CrashingAgent(BaseAgent): + """Agent whose _run_async_impl always raises.""" + + crash_error: Exception = RuntimeError("agent crashed") + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + raise self.crash_error + yield # make it an async generator + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + raise self.crash_error + yield + + +class _SuccessAgent(BaseAgent): + """Agent that completes successfully.""" + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + branch=ctx.branch, + invocation_id=ctx.invocation_id, + content=types.Content(parts=[types.Part(text="ok")]), + ) + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + branch=ctx.branch, + invocation_id=ctx.invocation_id, + content=types.Content(parts=[types.Part(text="ok live")]), + ) + + +class _ErrorTrackingPlugin(BasePlugin): + """Plugin that records which error 
callbacks were called.""" + + __test__ = False + + def __init__(self, name: str = "error_tracker"): + super().__init__(name) + self.agent_errors: list[tuple[str, Exception]] = [] + self.run_errors: list[Exception] = [] + self.after_agent_called = False + self.after_run_called = False + + async def on_agent_error_callback( + self, + *, + agent: BaseAgent, + callback_context: CallbackContext, + error: Exception, + ) -> None: + self.agent_errors.append((agent.name, error)) + + async def on_run_error_callback( + self, + *, + invocation_context: InvocationContext, + error: Exception, + ) -> None: + self.run_errors.append(error) + + async def after_agent_callback( + self, + *, + agent: BaseAgent, + callback_context: CallbackContext, + ) -> Optional[types.Content]: + self.after_agent_called = True + return None + + async def after_run_callback( + self, + *, + invocation_context: InvocationContext, + ) -> None: + self.after_run_called = True + + +async def _create_ctx( + agent: BaseAgent, + plugins: list[BasePlugin] | None = None, +) -> InvocationContext: + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name="test_app", user_id="test_user" + ) + return InvocationContext( + invocation_id="test_invocation", + agent=agent, + session=session, + session_service=session_service, + plugin_manager=PluginManager(plugins=plugins or []), + ) + + +# --------------------------------------------------------------------------- +# Agent-level error callback tests +# --------------------------------------------------------------------------- + + +class TestAgentErrorCallback: + """Tests for on_agent_error_callback in base_agent.py.""" + + @pytest.mark.asyncio + async def test_agent_error_callback_fires_on_crash(self): + """Error callback fires when _run_async_impl raises.""" + plugin = _ErrorTrackingPlugin() + agent = _CrashingAgent(name="crash_agent") + ctx = await _create_ctx(agent, plugins=[plugin]) + + with pytest.raises(RuntimeError, 
match="agent crashed"): + _ = [e async for e in agent.run_async(ctx)] + + assert len(plugin.agent_errors) == 1 + assert plugin.agent_errors[0][0] == "crash_agent" + assert str(plugin.agent_errors[0][1]) == "agent crashed" + + @pytest.mark.asyncio + async def test_agent_error_callback_fires_on_live_crash(self): + """Error callback fires when _run_live_impl raises.""" + plugin = _ErrorTrackingPlugin() + agent = _CrashingAgent(name="crash_agent") + ctx = await _create_ctx(agent, plugins=[plugin]) + + with pytest.raises(RuntimeError, match="agent crashed"): + _ = [e async for e in agent.run_live(ctx)] + + assert len(plugin.agent_errors) == 1 + assert plugin.agent_errors[0][0] == "crash_agent" + + @pytest.mark.asyncio + async def test_after_agent_not_called_on_crash(self): + """after_agent_callback (success-only) is NOT called on failure.""" + plugin = _ErrorTrackingPlugin() + agent = _CrashingAgent(name="crash_agent") + ctx = await _create_ctx(agent, plugins=[plugin]) + + with pytest.raises(RuntimeError): + _ = [e async for e in agent.run_async(ctx)] + + assert not plugin.after_agent_called + + @pytest.mark.asyncio + async def test_exception_is_reraised_after_agent_error_callback(self): + """The original exception propagates after the error callback.""" + plugin = _ErrorTrackingPlugin() + err = ValueError("specific error") + agent = _CrashingAgent(name="crash_agent", crash_error=err) + ctx = await _create_ctx(agent, plugins=[plugin]) + + with pytest.raises(ValueError, match="specific error"): + _ = [e async for e in agent.run_async(ctx)] + + @pytest.mark.asyncio + async def test_agent_error_callback_not_fired_on_success(self): + """Error callback does NOT fire when agent succeeds.""" + plugin = _ErrorTrackingPlugin() + agent = _SuccessAgent(name="good_agent") + ctx = await _create_ctx(agent, plugins=[plugin]) + + events = [e async for e in agent.run_async(ctx)] + + assert len(events) > 0 + assert len(plugin.agent_errors) == 0 + # after_agent_callback should still fire 
on success
+    assert plugin.after_agent_called
+
+  @pytest.mark.asyncio
+  async def test_cancelled_error_does_not_trigger_agent_error_callback(
+      self,
+  ):
+    """asyncio.CancelledError (BaseException) does NOT trigger error callback."""
+
+    class _CancellingAgent(BaseAgent):
+      """Agent whose async and live implementations raise CancelledError."""
+
+      @override
+      async def _run_async_impl(self, ctx):
+        raise asyncio.CancelledError()
+        # Unreachable, but its presence makes this an async generator.
+        yield
+
+      @override
+      async def _run_live_impl(self, ctx):
+        raise asyncio.CancelledError()
+        # Unreachable, but its presence makes this an async generator.
+        yield
+
+    plugin = _ErrorTrackingPlugin()
+    agent = _CancellingAgent(name="cancel_agent")
+    ctx = await _create_ctx(agent, plugins=[plugin])
+
+    with pytest.raises(asyncio.CancelledError):
+      _ = [e async for e in agent.run_async(ctx)]
+
+    assert not plugin.agent_errors
+
+
+# ---------------------------------------------------------------------------
+# Runner-level error callback tests
+# ---------------------------------------------------------------------------
+
+
+async def _create_runner_and_session(agent, plugin):
+  """Builds a Runner for `agent` with `plugin` installed and opens a session.
+
+  Args:
+    agent: The root agent the runner should execute.
+    plugin: The single plugin to register on the runner.
+
+  Returns:
+    A `(runner, session)` tuple. The session is created through the runner's
+    own session service for app "test_app" and user "test_user".
+  """
+  # Imported inside the function, as each test did before the extraction.
+  from google.adk.runners import Runner
+
+  runner = Runner(
+      agent=agent,
+      app_name="test_app",
+      session_service=InMemorySessionService(),
+      plugins=[plugin],
+  )
+  session = await runner.session_service.create_session(
+      app_name="test_app", user_id="test_user"
+  )
+  return runner, session
+
+
+async def _run_hello(runner, session):
+  """Sends one "hello" user message through `runner` and collects the events.
+
+  Args:
+    runner: The Runner to drive.
+    session: The session whose id is passed to `runner.run_async`.
+
+  Returns:
+    The list of events yielded by `runner.run_async`. Any exception raised
+    while iterating propagates to the caller.
+  """
+  return [
+      e
+      async for e in runner.run_async(
+          user_id="test_user",
+          session_id=session.id,
+          new_message=types.Content(parts=[types.Part(text="hello")]),
+      )
+  ]
+
+
+class TestRunErrorCallback:
+  """Tests for on_run_error_callback in runners.py."""
+
+  @pytest.mark.asyncio
+  async def test_run_error_callback_fires_on_crash(self):
+    """on_run_error_callback fires when execute_fn raises."""
+    plugin = _ErrorTrackingPlugin()
+    runner, session = await _create_runner_and_session(
+        _CrashingAgent(name="crash_agent"), plugin
+    )
+
+    with pytest.raises(RuntimeError, match="agent crashed"):
+      await _run_hello(runner, session)
+
+    assert len(plugin.run_errors) == 1
+    assert str(plugin.run_errors[0]) == "agent crashed"
+
+  @pytest.mark.asyncio
+  async def test_after_run_not_called_on_crash(self):
+    """after_run_callback (success-only) is NOT called on failure."""
+    plugin = _ErrorTrackingPlugin()
+    runner, session = await _create_runner_and_session(
+        _CrashingAgent(name="crash_agent"), plugin
+    )
+
+    with pytest.raises(RuntimeError):
+      await _run_hello(runner, session)
+
+    assert not plugin.after_run_called
+
+  @pytest.mark.asyncio
+  async def test_run_error_callback_not_fired_on_success(self):
+    """on_run_error_callback does NOT fire on success."""
+    plugin = _ErrorTrackingPlugin()
+    runner, session = await _create_runner_and_session(
+        _SuccessAgent(name="good_agent"), plugin
+    )
+
+    events = await _run_hello(runner, session)
+
+    assert events
+    assert not plugin.run_errors
+    assert plugin.after_run_called
+
+
+# ---------------------------------------------------------------------------
+# Exactly-once-per-layer tests
+# ---------------------------------------------------------------------------
+
+
+class TestExactlyOncePerLayer:
+  """Verify each error callback fires exactly once at its own layer."""
+
+  @pytest.mark.asyncio
+  async def test_agent_crash_fires_both_callbacks_once_each(self):
+    """A crashing agent fires on_agent_error_callback once AND
+    on_run_error_callback once (the re-raised exception propagates)."""
+    plugin = _ErrorTrackingPlugin()
+    runner, session = await _create_runner_and_session(
+        _CrashingAgent(name="crash_agent"), plugin
+    )
+
+    with pytest.raises(RuntimeError, match="agent crashed"):
+      await _run_hello(runner, session)
+
+    # Agent error callback: exactly 1 call
+    assert len(plugin.agent_errors) == 1
+    assert plugin.agent_errors[0][0] == "crash_agent"
+
+    # Run error callback: exactly 1 call (same exception bubbled up)
+    assert len(plugin.run_errors) == 1
+    assert plugin.run_errors[0] is plugin.agent_errors[0][1]
+
+    # Neither after callback should fire
+    assert not plugin.after_agent_called
+    assert not plugin.after_run_called
+
+
+# ---------------------------------------------------------------------------
+# PluginManager dispatch tests
+# ---------------------------------------------------------------------------
+
+
+class TestPluginManagerErrorCallbackDispatch:
+  """Test PluginManager correctly dispatches the new error callbacks."""
+
+  @pytest.mark.asyncio
+  async def test_run_on_agent_error_callback_dispatches(self):
+    """run_on_agent_error_callback calls all plugins."""
+    plugin1 = _ErrorTrackingPlugin(name="p1")
+    plugin2 = _ErrorTrackingPlugin(name="p2")
+    pm = PluginManager(plugins=[plugin1, plugin2])
+
+    mock_agent = Mock(spec=BaseAgent)
+    mock_agent.name = "test_agent"
+    mock_ctx = Mock(spec=CallbackContext)
+    err = RuntimeError("boom")
+
+    await pm.run_on_agent_error_callback(
+        agent=mock_agent,
+        callback_context=mock_ctx,
+        error=err,
+    )
+
+    assert len(plugin1.agent_errors) == 1
+    assert len(plugin2.agent_errors) == 1
+
+  @pytest.mark.asyncio
+  async def test_run_on_run_error_callback_dispatches(self):
+    """run_on_run_error_callback calls all plugins."""
+    plugin1 = _ErrorTrackingPlugin(name="p1")
+    plugin2 = _ErrorTrackingPlugin(name="p2")
+    pm = PluginManager(plugins=[plugin1, plugin2])
+
+    mock_ctx = Mock(spec=InvocationContext)
+    err = RuntimeError("boom")
+
+    await pm.run_on_run_error_callback(
+        invocation_context=mock_ctx,
+        error=err,
+    )
+
+    assert len(plugin1.run_errors) == 1
+    assert len(plugin2.run_errors) == 1
+
+  @pytest.mark.asyncio
+  async def test_agent_error_callback_does_not_short_circuit(self):
+    """on_agent_error_callback is notification-only: a non-None return
+    from one plugin does NOT skip subsequent plugins."""
+
+    class _ReturningPlugin(BasePlugin):
+      __test__ = False
+
+      def __init__(self, name):
+        super().__init__(name)
+        self.agent_error_called = False
+
+      async def on_agent_error_callback(self, **kwargs):
+        self.agent_error_called = True
+        return "should be ignored"
+
+    p1 = _ReturningPlugin(name="p1")
+    p2 = _ReturningPlugin(name="p2")
+    pm = PluginManager(plugins=[p1, p2])
+
+    await pm.run_on_agent_error_callback(
+        agent=Mock(spec=BaseAgent),
+        callback_context=Mock(spec=CallbackContext),
+        error=RuntimeError("x"),
+    )
+
+    # Both plugins must be called even though p1 returns non-None.
+    assert p1.agent_error_called
+    assert p2.agent_error_called
+
+  @pytest.mark.asyncio
+  async def test_run_error_callback_does_not_short_circuit(self):
+    """on_run_error_callback is notification-only: a non-None return
+    from one plugin does NOT skip subsequent plugins."""
+
+    class _ReturningPlugin(BasePlugin):
+      __test__ = False
+
+      def __init__(self, name):
+        super().__init__(name)
+        self.run_error_called = False
+
+      async def on_run_error_callback(self, **kwargs):
+        self.run_error_called = True
+        return "should be ignored"
+
+    p1 = _ReturningPlugin(name="p1")
+    p2 = _ReturningPlugin(name="p2")
+    pm = PluginManager(plugins=[p1, p2])
+
+    await pm.run_on_run_error_callback(
+        invocation_context=Mock(spec=InvocationContext),
+        error=RuntimeError("x"),
+    )
+
+    # Both plugins must be called even though p1 returns non-None.
+    assert p1.run_error_called
+    assert p2.run_error_called
+
+  @pytest.mark.asyncio
+  async def test_plugin_callback_failure_does_not_mask_app_error(self):
+    """When a plugin's error callback raises, the PluginManager dispatch
+    call still returns normally and the remaining plugins are notified."""
+
+    class _FailingPlugin(BasePlugin):
+      __test__ = False
+
+      def __init__(self, name):
+        super().__init__(name)
+        self.agent_error_called = False
+        self.run_error_called = False
+
+      async def on_agent_error_callback(self, **kwargs):
+        self.agent_error_called = True
+        raise ValueError("plugin boom")
+
+      async def on_run_error_callback(self, **kwargs):
+        self.run_error_called = True
+        raise ValueError("plugin boom")
+
+    p1 = _FailingPlugin(name="p1")
+    p2 = _ErrorTrackingPlugin(name="p2")
+    pm = PluginManager(plugins=[p1, p2])
+
+    # Agent error callback: p1 raises, p2 must still be notified.
+    mock_agent = Mock(spec=BaseAgent)
+    mock_agent.name = "test_agent"
+    await pm.run_on_agent_error_callback(
+        agent=mock_agent,
+        callback_context=Mock(spec=CallbackContext),
+        error=RuntimeError("app crash"),
+    )
+    assert p1.agent_error_called
+    assert len(p2.agent_errors) == 1
+
+    # Run error callback: same behavior.
+    await pm.run_on_run_error_callback(
+        invocation_context=Mock(spec=InvocationContext),
+        error=RuntimeError("app crash"),
+    )
+    assert p1.run_error_called
+    assert len(p2.run_errors) == 1
+
+  @pytest.mark.asyncio
+  async def test_original_exception_propagates_despite_plugin_failure(
+      self,
+  ):
+    """End-to-end: a crashing plugin error callback does not mask
+    the original agent exception seen by the caller."""
+
+    class _FailingPlugin(BasePlugin):
+      __test__ = False
+
+      async def on_agent_error_callback(self, **kwargs):
+        raise ValueError("plugin internal error")
+
+    plugin = _FailingPlugin(name="bad_plugin")
+    agent = _CrashingAgent(name="crash_agent")
+    ctx = await _create_ctx(agent, plugins=[plugin])
+
+    # The caller must see the original RuntimeError("agent crashed"),
+    # NOT the plugin's ValueError.
+    with pytest.raises(RuntimeError, match="agent crashed"):
+      _ = [e async for e in agent.run_async(ctx)]
+
+  @pytest.mark.asyncio
+  async def test_original_exception_propagates_despite_run_plugin_failure(
+      self,
+  ):
+    """End-to-end: a crashing plugin on_run_error_callback does not mask
+    the original agent exception seen by the runner caller."""
+
+    class _FailingRunPlugin(BasePlugin):
+      __test__ = False
+
+      async def on_run_error_callback(self, **kwargs):
+        raise ValueError("plugin internal error")
+
+    plugin = _FailingRunPlugin(name="bad_plugin")
+    runner, session = await _create_runner_and_session(
+        _CrashingAgent(name="crash_agent"), plugin
+    )
+
+    # The caller must see the original RuntimeError("agent crashed"),
+    # NOT the plugin's ValueError.
+    with pytest.raises(RuntimeError, match="agent crashed"):
+      await _run_hello(runner, session)