Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.google.adk.events.Event;
import com.google.adk.events.EventActions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.reactivex.rxjava3.core.Completable;
Expand All @@ -32,10 +33,13 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* An in-memory implementation of {@link BaseSessionService} assuming {@link Session} objects are
Expand All @@ -49,6 +53,23 @@
* during retrieval operations ({@code getSession}, {@code createSession}).
*/
public final class InMemorySessionService implements BaseSessionService {

private static final Logger log = LoggerFactory.getLogger(InMemorySessionService.class);

/**
* Reserved session-state keys that are managed internally by the ADK framework. Callers are not
* permitted to set or override these keys through the public API (initial session state or
* per-run stateDelta). Allowing external writes to these keys would let an untrusted caller steer
* internal framework behaviour, such as hijacking the code-execution session identifier used by
* {@code VertexAiCodeExecutor}.
*/
private static final Set<String> RESERVED_STATE_KEYS =
ImmutableSet.of(
"_code_execution_context",
"_code_executor_input_files",
"_code_executor_error_counts",
"_code_execution_results");

// Structure: appName -> userId -> sessionId -> Session
private final ConcurrentMap<String, ConcurrentMap<String, ConcurrentMap<String, Session>>>
sessions;
Expand All @@ -65,6 +86,31 @@ public InMemorySessionService() {
this.appState = new ConcurrentHashMap<>();
}

/**
* Removes reserved internal keys from a caller-supplied state map before it is persisted.
* Logs a warning for each key that is dropped.
*
* @param state The caller-supplied state map (may be null).
* @return A new {@link ConcurrentHashMap} containing only the non-reserved entries.
*/
private static ConcurrentMap<String, Object> sanitizeCallerState(
@Nullable Map<String, Object> state) {
if (state == null) {
return new ConcurrentHashMap<>();
}
ConcurrentMap<String, Object> sanitized = new ConcurrentHashMap<>();
state.forEach(
(key, value) -> {
if (RESERVED_STATE_KEYS.contains(key)) {
log.warn(
"Caller attempted to set reserved internal state key '{}'; ignoring.", key);
} else {
sanitized.put(key, value);
}
});
return sanitized;
}

@Override
public Single<Session> createSession(
String appName,
Expand All @@ -89,9 +135,8 @@ public Single<Session> createSession(
.filter(s -> !s.isEmpty())
.orElseGet(() -> UUID.randomUUID().toString());

// Ensure state map and events list are mutable for the new session
ConcurrentMap<String, Object> initialState =
(state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state);
// Sanitize caller-supplied state: strip reserved internal keys before persisting.
ConcurrentMap<String, Object> initialState = sanitizeCallerState(state);

// Assuming Session constructor or setters allow setting these mutable collections
Session newSession =
Expand Down Expand Up @@ -268,6 +313,12 @@ public Single<Event> appendEvent(Session session, Event event) {
.put(userStateKey, value);
}
} else {
// Reject writes to reserved internal keys from any external stateDelta.
if (RESERVED_STATE_KEYS.contains(key)) {
log.warn(
"stateDelta contains reserved internal key '{}'; ignoring write.", key);
return;
}
if (value == State.REMOVED) {
session.state().remove(key);
} else {
Expand Down