diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
index 7183fdfc6b..ea29d1a18c 100644
--- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
+++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
@@ -1988,6 +1988,7 @@ def __init__(
     self._startup_error: Optional[Exception] = None
     self._is_shutting_down = False
     self._setup_lock = None
+    self._user_credentials = credentials
     self._credentials = credentials
     self.client = None
     self._loop_state_by_loop: dict[asyncio.AbstractEventLoop, _LoopState] = {}
@@ -2106,6 +2107,10 @@ def get_credentials():
       )
       return creds
 
+    # Note: this read-then-write is not locked. If two event loops
+    # race here both will resolve ADC and write back the same creds.
+    # This is benign — the result is idempotent — so we accept the
+    # race rather than adding a lock for a one-time init path.
     if self._credentials is None:
       self._credentials = await loop.run_in_executor(
           self._executor, get_credentials
@@ -2196,13 +2201,18 @@ async def _lazy_setup(self, **kwargs) -> None:
     self.offloader = None
 
     if self.config.gcs_bucket_name:
+      # GCSOffloader always creates a storage.Client eagerly
+      # (line 1329: storage_client or storage.Client(...)).
+      # Pass credentials so it uses the same auth as the other
+      # clients; omit when None to let it use ADC.
+      gcs_kwargs = {"project": self.project_id}
+      if self._credentials is not None:
+        gcs_kwargs["credentials"] = self._credentials
       self.offloader = GCSOffloader(
          self.project_id,
          self.config.gcs_bucket_name,
          self._executor,
-          storage_client=storage.Client(
-              project=self.project_id, credentials=self._credentials
-          ),
+          storage_client=storage.Client(**gcs_kwargs),
       )
 
     self.parser = HybridContentParser(
@@ -2536,6 +2546,18 @@ def __getstate__(self):
     state["_startup_error"] = None
     state["_is_shutting_down"] = False
     state["_init_pid"] = 0
+    # _credentials is always runtime-resolved; clear unconditionally.
+    state["_credentials"] = None
+    # Preserve _user_credentials if they are picklable (e.g.,
+    # service-account, AnonymousCredentials). Drop only when
+    # pickle would fail (e.g., compute_engine.Credentials holding
+    # a requests.Session).
+    import pickle as _pickle
+
+    try:
+      _pickle.dumps(state.get("_user_credentials"))
+    except Exception:
+      state["_user_credentials"] = None
     return state
 
   def __setstate__(self, state):
@@ -2543,6 +2565,13 @@ def __setstate__(self, state):
     # Backfill keys that may be absent in pickled state from older
     # code versions so _ensure_started does not raise AttributeError.
     state.setdefault("_init_pid", 0)
+    state.setdefault("_user_credentials", None)
+    state.setdefault("_credentials", None)
+    # Restore _credentials from _user_credentials if available so
+    # _create_loop_state uses the user's identity. When both are
+    # None (non-picklable credentials were dropped), ADC is used.
+    if state["_credentials"] is None and state["_user_credentials"] is not None:
+      state["_credentials"] = state["_user_credentials"]
     self.__dict__.update(state)
 
   def _reset_runtime_state(self) -> None:
@@ -2597,6 +2626,11 @@ def _reset_runtime_state(self) -> None:
     self._startup_error = None
     self._is_shutting_down = False
     self._init_pid = os.getpid()
+    # For ADC-resolved credentials, clear so they are re-resolved
+    # in the child process. For user-provided credentials, keep
+    # the original object — we cannot re-create it. The user is
+    # responsible for providing fork-safe credentials if needed.
+    self._credentials = self._user_credentials
 
   async def __aenter__(self) -> BigQueryAgentAnalyticsPlugin:
     await self._ensure_started()
diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
index 05b8976da2..3b723b4ceb 100644
--- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
+++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
@@ -2093,6 +2093,56 @@ async def test_pickle_safety(self, mock_auth_default, mock_bq_client):
     finally:
       await plugin.shutdown()
 
+  @pytest.mark.asyncio
+  async def test_pickle_preserves_picklable_credentials(
+      self, mock_auth_default, mock_bq_client
+  ):
+    """Picklable user credentials survive pickle/unpickle."""
+    import pickle
+
+    picklable_creds = FakeCredentials()
+    plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
+        PROJECT_ID,
+        DATASET_ID,
+        table_id=TABLE_ID,
+        credentials=picklable_creds,
+    )
+    pickled = pickle.dumps(plugin)
+    unpickled = pickle.loads(pickled)
+    # User-provided picklable credentials are preserved.
+    assert unpickled._user_credentials is not None
+    assert unpickled._credentials is not None
+    await plugin.shutdown()
+
+  @pytest.mark.asyncio
+  async def test_pickle_drops_non_picklable_credentials(
+      self, mock_auth_default, mock_bq_client
+  ):
+    """Non-picklable user credentials are dropped gracefully."""
+    import pickle
+
+    class NonPicklableCreds(google.auth.credentials.Credentials):
+
+      def refresh(self, request):
+        pass
+
+      def __getstate__(self):
+        raise TypeError("cannot pickle")
+
+    plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
+        PROJECT_ID,
+        DATASET_ID,
+        table_id=TABLE_ID,
+        credentials=NonPicklableCreds(),
+    )
+    # Should not raise — non-picklable credentials are dropped.
+    pickled = pickle.dumps(plugin)
+    unpickled = pickle.loads(pickled)
+    # Credentials fall back to None (ADC on next use).
+    assert unpickled._user_credentials is None
+    assert unpickled._credentials is None
+    await plugin.shutdown()
+
   @pytest.mark.asyncio
   async def test_span_hierarchy_llm_call(
       self,