Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1988,6 +1988,7 @@ def __init__(
self._startup_error: Optional[Exception] = None
self._is_shutting_down = False
self._setup_lock = None
self._user_credentials = credentials
self._credentials = credentials
self.client = None
self._loop_state_by_loop: dict[asyncio.AbstractEventLoop, _LoopState] = {}
Expand Down Expand Up @@ -2106,6 +2107,10 @@ def get_credentials():
)
return creds

# Note: this read-then-write is not locked. If two event loops
# race here both will resolve ADC and write back the same creds.
# This is benign — the result is idempotent — so we accept the
# race rather than adding a lock for a one-time init path.
if self._credentials is None:
self._credentials = await loop.run_in_executor(
self._executor, get_credentials
Expand Down Expand Up @@ -2196,13 +2201,18 @@ async def _lazy_setup(self, **kwargs) -> None:

self.offloader = None
if self.config.gcs_bucket_name:
# GCSOffloader always creates a storage.Client eagerly
# (line 1329: storage_client or storage.Client(...)).
# Pass credentials so it uses the same auth as the other
# clients; omit when None to let it use ADC.
gcs_kwargs = {"project": self.project_id}
if self._credentials is not None:
gcs_kwargs["credentials"] = self._credentials
self.offloader = GCSOffloader(
self.project_id,
self.config.gcs_bucket_name,
self._executor,
storage_client=storage.Client(
project=self.project_id, credentials=self._credentials
),
storage_client=storage.Client(**gcs_kwargs),
)

self.parser = HybridContentParser(
Expand Down Expand Up @@ -2536,13 +2546,32 @@ def __getstate__(self):
state["_startup_error"] = None
state["_is_shutting_down"] = False
state["_init_pid"] = 0
# _credentials is always runtime-resolved; clear unconditionally.
state["_credentials"] = None
# Preserve _user_credentials if they are picklable (e.g.,
# service-account, AnonymousCredentials). Drop only when
# pickle would fail (e.g., compute_engine.Credentials holding
# a requests.Session).
import pickle as _pickle

try:
_pickle.dumps(state.get("_user_credentials"))
except Exception:
state["_user_credentials"] = None
return state

def __setstate__(self, state):
"""Custom unpickling to restore state."""
# Backfill keys that may be absent in pickled state from older
# code versions so _ensure_started does not raise AttributeError.
state.setdefault("_init_pid", 0)
state.setdefault("_user_credentials", None)
state.setdefault("_credentials", None)
# Restore _credentials from _user_credentials if available so
# _create_loop_state uses the user's identity. When both are
# None (non-picklable credentials were dropped), ADC is used.
if state["_credentials"] is None and state["_user_credentials"] is not None:
state["_credentials"] = state["_user_credentials"]
self.__dict__.update(state)

def _reset_runtime_state(self) -> None:
Expand Down Expand Up @@ -2597,6 +2626,11 @@ def _reset_runtime_state(self) -> None:
self._startup_error = None
self._is_shutting_down = False
self._init_pid = os.getpid()
# For ADC-resolved credentials, clear so they are re-resolved
# in the child process. For user-provided credentials, keep
# the original object — we cannot re-create it. The user is
# responsible for providing fork-safe credentials if needed.
self._credentials = self._user_credentials

async def __aenter__(self) -> BigQueryAgentAnalyticsPlugin:
await self._ensure_started()
Expand Down
50 changes: 50 additions & 0 deletions tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2093,6 +2093,56 @@ async def test_pickle_safety(self, mock_auth_default, mock_bq_client):
finally:
await plugin.shutdown()

@pytest.mark.asyncio
async def test_pickle_preserves_picklable_credentials(
    self, mock_auth_default, mock_bq_client
):
  """Picklable user credentials survive pickle/unpickle."""
  import pickle

  picklable_creds = FakeCredentials()
  plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
      PROJECT_ID,
      DATASET_ID,
      table_id=TABLE_ID,
      credentials=picklable_creds,
  )
  try:
    pickled = pickle.dumps(plugin)
    unpickled = pickle.loads(pickled)
    # User-provided picklable credentials are preserved.
    assert unpickled._user_credentials is not None
    assert unpickled._credentials is not None
  finally:
    # Always release plugin resources, even when an assertion above
    # fails (matches the try/finally pattern of test_pickle_safety).
    await plugin.shutdown()

@pytest.mark.asyncio
async def test_pickle_drops_non_picklable_credentials(
    self, mock_auth_default, mock_bq_client
):
  """Non-picklable user credentials are dropped gracefully."""
  import pickle

  class NonPicklableCreds(google.auth.credentials.Credentials):
    # Minimal credentials stub whose pickling always fails, to
    # exercise the drop-on-pickle-error path in __getstate__.

    def refresh(self, request):
      pass

    def __getstate__(self):
      raise TypeError("cannot pickle")

  plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
      PROJECT_ID,
      DATASET_ID,
      table_id=TABLE_ID,
      credentials=NonPicklableCreds(),
  )
  try:
    # Should not raise — non-picklable credentials are dropped.
    pickled = pickle.dumps(plugin)
    unpickled = pickle.loads(pickled)
    # Credentials fall back to None (ADC on next use).
    assert unpickled._user_credentials is None
    assert unpickled._credentials is None
  finally:
    # Always release plugin resources, even when pickling or an
    # assertion fails (matches the try/finally pattern of
    # test_pickle_safety).
    await plugin.shutdown()

@pytest.mark.asyncio
async def test_span_hierarchy_llm_call(
self,
Expand Down