diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsHttpClient.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsHttpClient.java index 7ef671b84..b5fbc45cb 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsHttpClient.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsHttpClient.java @@ -190,14 +190,21 @@ private static Duration resolveCallTimeout(HttpOptions httpOptions) { public Flowable complete(LlmRequest llmRequest, boolean stream) { return Flowable.defer( () -> { + String effectiveModelName = llmRequest.model().orElse("?"); + logger.trace("Chat Completion Request Contents: {}", llmRequest.contents()); + llmRequest.config().ifPresent(c -> logger.trace("Chat Completion Request Config: {}", c)); + ChatCompletionsRequest dtoRequest = ChatCompletionsRequest.fromLlmRequest(llmRequest, stream); String jsonPayload = objectMapper.writeValueAsString(dtoRequest); - logger.trace( - "Chat Completion Request: model={}, stream={}, messagesCount={}", - dtoRequest.model, - dtoRequest.stream, - dtoRequest.messages != null ? dtoRequest.messages.size() : 0); + logger.trace("Chat Completion Request JSON: {}", jsonPayload); + + if (stream) { + logger.debug( + "Sending streaming chat-completion request to model {}", effectiveModelName); + } else { + logger.debug("Sending chat-completion request to model {}", effectiveModelName); + } Request.Builder requestBuilder = new Request.Builder().url(completionsUrl).post(RequestBody.create(jsonPayload, JSON)); @@ -209,11 +216,7 @@ public Flowable complete(LlmRequest llmRequest, boolean stream) { requestBuilder.header("Content-Type", JSON.toString()); Request request = requestBuilder.build(); - if (stream) { - return createStreamingFlowable(request); - } else { - return createNonStreamingFlowable(request); - } + return stream ? createStreamingFlowable(request) : createNonStreamingFlowable(request); }); } @@ -274,10 +277,14 @@ public void onResponse(Call call, Response response) { // A single malformed chunk must not abort the entire stream. Log a // warning and continue. try { + logger.trace("Raw streaming chat-completion chunk: {}", data); ChatCompletionsResponse.ChatCompletionChunk chunk = objectMapper.readValue( data, ChatCompletionsResponse.ChatCompletionChunk.class); ImmutableList responses = collection.processChunk(chunk); + if (!responses.isEmpty()) { + logger.trace("Responses to emit: {}", responses); + } for (LlmResponse resp : responses) { emitter.onNext(resp); } @@ -341,9 +348,12 @@ public void onResponse(Call call, Response response) { } String jsonResponse = body.string(); + logger.trace("Raw non-streaming chat-completion response: {}", jsonResponse); ChatCompletionsResponse.ChatCompletion completion = objectMapper.readValue(jsonResponse, ChatCompletionsResponse.ChatCompletion.class); - emitter.onNext(completion.toLlmResponse()); + LlmResponse llmResponse = completion.toLlmResponse(); + logger.trace("Response to emit: {}", llmResponse); + emitter.onNext(llmResponse); emitter.onComplete(); } catch (Exception e) { emitter.tryOnError(e); diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java index 523c04a5a..2ad0733b9 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java @@ -26,6 +26,7 @@ import com.google.adk.JsonBaseModel; import com.google.adk.models.LlmRequest; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; @@ -351,41 +352,43 @@ private static List processContent(Content content) { List toolCalls = new ArrayList<>(); List toolResponses = new ArrayList<>(); List refusals = new ArrayList<>(); - - content - .parts() - .ifPresent( - parts -> { - for (Part part : parts) { - if (part.text().isPresent()) { - // Text Parts may carry refusal content prefixed with REFUSAL_PREFIX. - ChatCompletionsCommon.RefusalSplit split = - ChatCompletionsCommon.parseRefusalPrefix(part.text().get()); - if (split.content() != null) { - ContentPart textPart = new ContentPart(); - textPart.type = "text"; - textPart.text = split.content(); - contentParts.add(textPart); - } - if (split.refusal() != null) { - refusals.add(split.refusal()); - } - } else if (part.inlineData().isPresent()) { - contentParts.add(processInlineDataPart(part)); - } else if (part.fileData().isPresent()) { - contentParts.add(processFileDataPart(part)); - } else if (part.functionCall().isPresent()) { - toolCalls.add(processFunctionCallPart(part)); - } else if (part.functionResponse().isPresent()) { - toolResponses.add(processFunctionResponsePart(part)); - } else if (part.executableCode().isPresent()) { - logger.warn("Executable code is not supported in Chat Completion conversion"); - } else if (part.codeExecutionResult().isPresent()) { - logger.warn( - "Code execution result is not supported in Chat Completion conversion"); - } - } - }); + // Capture a message-level thought_signature from the first text Part that carries one. + // This signature must be echoed back on subsequent turns to ensure proper round-tripping. + byte[] textThoughtSignature = null; + + if (content.parts().isPresent()) { + for (Part part : content.parts().get()) { + if (part.text().isPresent()) { + // Text Parts may carry refusal content prefixed with REFUSAL_PREFIX. + ChatCompletionsCommon.RefusalSplit split = + ChatCompletionsCommon.parseRefusalPrefix(part.text().get()); + if (split.content() != null) { + ContentPart textPart = new ContentPart(); + textPart.type = "text"; + textPart.text = split.content(); + contentParts.add(textPart); + } + if (split.refusal() != null) { + refusals.add(split.refusal()); + } + if (textThoughtSignature == null && part.thoughtSignature().isPresent()) { + textThoughtSignature = part.thoughtSignature().get(); + } + } else if (part.inlineData().isPresent()) { + contentParts.add(processInlineDataPart(part)); + } else if (part.fileData().isPresent()) { + contentParts.add(processFileDataPart(part)); + } else if (part.functionCall().isPresent()) { + toolCalls.add(processFunctionCallPart(part)); + } else if (part.functionResponse().isPresent()) { + toolResponses.add(processFunctionResponsePart(part)); + } else if (part.executableCode().isPresent()) { + logger.warn("Executable code is not supported in Chat Completion conversion"); + } else if (part.codeExecutionResult().isPresent()) { + logger.warn("Code execution result is not supported in Chat Completion conversion"); + } + } + } if (!toolResponses.isEmpty()) { return toolResponses; @@ -403,6 +406,14 @@ private static List processContent(Content content) { msg.content = new MessageContent(ImmutableList.copyOf(contentParts)); } } + // Round-trip the message-level thought_signature for assistant text responses. + if (textThoughtSignature != null) { + msg.extraContent = + ImmutableMap.of( + "google", + ImmutableMap.of( + "thought_signature", Base64.getEncoder().encodeToString(textThoughtSignature))); + } List messages = new ArrayList<>(); messages.add(msg); return messages; @@ -446,6 +457,10 @@ private static ContentPart processFileDataPart(Part part) { /** * Processes a function call part and returns a mapped ToolCall. * + *

If the source {@link Part} carries a {@code thoughtSignature}, it is round-tripped back out + * as a base64-encoded string in {@code extra_content.google.thought_signature} to satisfy + * endpoint requirements. + * * @param part The input part containing a requested function call or invocation. * @return The mapped function call tool call. */ @@ -464,6 +479,13 @@ private static ChatCompletionsCommon.ToolCall processFunctionCallPart(Part part) } } toolCall.function = function; + part.thoughtSignature() + .ifPresent( + sigBytes -> { + String sig = Base64.getEncoder().encodeToString(sigBytes); + toolCall.extraContent = + ImmutableMap.of("google", ImmutableMap.of("thought_signature", sig)); + }); return toolCall; } @@ -616,6 +638,13 @@ static class Message { /** See class definition for more details. */ public String refusal; + + /** + * Message-level additional parameters used by some providers. Used for round-tripping data like + * {@code extra_content.google.thought_signature}. + */ + @JsonProperty("extra_content") + public Map extraContent; } /** diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java index 61e7e8358..6cb25f38f 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java @@ -31,6 +31,7 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import java.util.Base64; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -354,6 +355,16 @@ static class Message { /** See class definition for more details. */ public Audio audio; + + /** + * Message-level additional parameters used by some providers. For example, Google Gemini's + * OpenAI-compatible {@code /chat/completions} endpoint emits {@code + * extra_content.google.thought_signature} on the assistant message (separately from any + * tool_call signatures) when the response is plain text; the signature must be echoed back on + * subsequent turns or Gemini may retry or loop. + */ + @JsonProperty("extra_content") + public Map extraContent; } /** @@ -528,6 +539,15 @@ static class ChatCompletionChunkCollection { private Usage usage; private final Map customMetadataMap = new HashMap<>(); + /** + * Base64-encoded thought_signature attached at the message level for the assistant text + * response, captured from any chunk that carries {@code + * delta.extra_content.google.thought_signature}. Gemini's OpenAI-compatible endpoint emits this + * on a dedicated chunk (alongside finish_reason=stop) for plain-text turns; if not + * round-tripped, Gemini may retry or loop on subsequent turns. + */ + private byte[] accumulatedTextThoughtSignature; + private ImmutableList getCustomMetadataList() { ImmutableList.Builder list = ImmutableList.builder(); for (Entry entry : customMetadataMap.entrySet()) { @@ -566,15 +586,52 @@ public ImmutableList processChunk(ChatCompletionChunk chunk) { ImmutableList chunkParts = mapDeltaToParts(choice); - responses.add(buildPartialResponse(chunkParts)); + // Emit a partial response only when this chunk's delta carried actual content. On the + // finish chunk, emit TWO non-partial events to mirror Gemini.processStreamingResponses + // so consumers (Runner, Web UI, evals, plugins) see the same event sequence regardless of + // which model driver produced it: + // (A) An aggregated-text event (only when text was streamed) carrying the full + // accumulated text in a single Part, with NO finishReason. Consumers that present + // streaming output incrementally use this as the "commit" signal for the + // accumulated text bubble. + // (B) A metadata-final event carrying the FinishReason and the accumulated tool_call + // Parts (with args parsed) but NO text. The accumulated text is intentionally + // excluded from the metadata-final so consumers that append on every event do not + // double-render the just-committed text. + if (!chunkParts.isEmpty()) { + responses.add(buildPartialResponse(chunkParts)); + } if (choice.finishReason != null && !choice.finishReason.isEmpty()) { + if (contentParts.length() > 0) { + responses.add(buildAggregatedTextResponse()); + } responses.add(buildFinalResponse(choice)); } return responses.build(); } + /** + * Builds the aggregated-text event for the finish chunk: a non-partial response whose content + * is a single text Part containing the fully-accumulated streamed text, with NO finishReason. + * Mirrors {@code Gemini.processStreamingResponses}'s aggregated-text emit at close-of-stream. + * See {@link #processChunk} for rationale. + */ + private LlmResponse buildAggregatedTextResponse() { + Part textPart = Part.fromText(contentParts.toString()); + if (accumulatedTextThoughtSignature != null) { + textPart = textPart.toBuilder().thoughtSignature(accumulatedTextThoughtSignature).build(); + } + ImmutableList parts = ImmutableList.of(textPart); + return LlmResponse.builder() + .content(Content.builder().role(this.role).parts(parts).build()) + .modelVersion(this.model) + .usageMetadata(mapUsage(this.usage)) + .customMetadata(getCustomMetadataList()) + .build(); + } + /** * Updates the internal state (model, usage, metadata) from the chunk. * @@ -633,13 +690,34 @@ private ImmutableList mapDeltaToParts(ChunkChoice choice) { ImmutableList.Builder chunkParts = ImmutableList.builder(); if (choice.delta != null) { updateRole(choice.delta.role); + captureMessageThoughtSignature(choice.delta.extraContent); appendContent(choice.delta.content, chunkParts); appendRefusal(choice.delta.refusal, chunkParts); - appendToolCalls(choice.delta.toolCalls, chunkParts); + accumulateToolCalls(choice.delta.toolCalls); } return chunkParts.build(); } + /** + * Reads the message-level {@code extra_content.google.thought_signature} (if present) from a + * streaming delta and stores it for later attachment to the accumulated text Part. Gemini's + * OpenAI-compatible endpoint emits this signature on a final chunk that may carry no other + * content; without round-tripping it on the next turn, Gemini may retry the response. + */ + private void captureMessageThoughtSignature(@Nullable Map extraContent) { + if (extraContent == null || !extraContent.containsKey("google")) { + return; + } + Object googleObj = extraContent.get("google"); + if (!(googleObj instanceof Map googleMap)) { + return; + } + Object sigObj = googleMap.get("thought_signature"); + if (sigObj instanceof String sig) { + accumulatedTextThoughtSignature = Base64.getDecoder().decode(sig); + } + } + /** * Updates the accumulated role if the delta contains a valid role. * @@ -684,20 +762,16 @@ private void appendRefusal(@Nullable String refusal, ImmutableList.Builder } /** - * Appends tool calls to the accumulator and adds them to the chunk parts. + * Accumulates streaming tool calls across multiple chunks. To prevent downstream flows from + * dispatching the same tool multiple times, partial tool calls are NOT emitted. The + * fully-accumulated tool call is emitted exactly once via {@link #buildFinalResponse}. * * @param toolCalls the list of tool calls, or {@code null}. - * @param chunkParts the list of parts for this chunk. */ - private void appendToolCalls( - @Nullable List toolCalls, - ImmutableList.Builder chunkParts) { + private void accumulateToolCalls(@Nullable List toolCalls) { if (toolCalls != null) { for (ChatCompletionsCommon.ToolCall toolCall : toolCalls) { - Part p = upsertToolCall(toolCall); - if (p != null) { - chunkParts.add(p); - } + upsertToolCall(toolCall); } } } @@ -719,14 +793,18 @@ private LlmResponse buildPartialResponse(List chunkParts) { } /** - * Builds the final {@link LlmResponse} with all accumulated content. + * Builds the final non-partial {@link LlmResponse} for a streaming turn. Carries the + * FinishReason and the accumulated tool calls, but excludes the accumulated text (which is + * delivered exclusively via per-chunk partial responses). See {@link #processChunk} for the + * rationale. * * @param choice the choice containing the finish reason. * @return the final response. */ private LlmResponse buildFinalResponse(ChunkChoice choice) { + ImmutableList finalParts = getFinalToolCallParts(); return LlmResponse.builder() - .content(Content.builder().role(this.role).parts(getContentParts()).build()) + .content(Content.builder().role(this.role).parts(finalParts).build()) .finishReason(ChatCompletionsResponse.mapFinishReason(choice.finishReason)) .modelVersion(this.model) .usageMetadata(mapUsage(this.usage)) @@ -735,18 +813,61 @@ private LlmResponse buildFinalResponse(ChunkChoice choice) { } /** - * Upserts a tool call from a chunk into the collection and returns the part for this chunk. + * Returns ONLY the accumulated tool_call Parts. Used by {@link #buildFinalResponse}; the + * accumulated text is emitted via per-chunk partial responses. + * + *

If a server emits non-contiguous tool_call indices (e.g. keys 0 and 2 but not 1), the + * present keys are iterated in sorted order and squashed into dense list positions (0 and 1) in + * the returned list. + * + *

Tool-call Parts carry their own per-tool-call thought_signature (attached via {@link + * ChatCompletionsCommon.ToolCall#applyThoughtSignature}). If a Part lacks one, the + * message-level {@code accumulatedTextThoughtSignature} (if any) is backfilled so the assistant + * turn round-trips with a signature. An existing per-tool-call signature is never overwritten. + */ + private ImmutableList getFinalToolCallParts() { + ImmutableList.Builder parts = ImmutableList.builder(); + ImmutableList sortedKeys = ImmutableList.sortedCopyOf(toolCallParts.keySet()); + for (int index : sortedKeys) { + Part part = toolCallParts.get(index); + if (part != null && part.functionCall().isPresent()) { + FunctionCall fc = part.functionCall().get(); + StringBuilder argsSb = toolCallArgsAccumulator.get(index); + if (argsSb != null && argsSb.length() > 0) { + try { + Map args = + objectMapper.readValue( + argsSb.toString(), new TypeReference>() {}); + fc = fc.toBuilder().args(args).build(); + part = part.toBuilder().functionCall(fc).build(); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException( + "Failed to parse final tool call arguments: " + argsSb, e); + } + } + } + if (part != null + && accumulatedTextThoughtSignature != null + && part.thoughtSignature().isEmpty()) { + part = part.toBuilder().thoughtSignature(accumulatedTextThoughtSignature).build(); + } + parts.add(part); + } + return parts.build(); + } + + /** + * Upserts a tool call from a chunk into the accumulated state. Partial tool calls are NOT + * emitted per chunk (see {@link #accumulateToolCalls} for the rationale -- the + * fully-accumulated tool call is emitted exactly once via {@link #buildFinalResponse}). * * @param toolCall the tool call from the chunk. - * @return the {@link Part} to emit for this chunk, or {@code null} if it cannot be converted. */ - private Part upsertToolCall(ChatCompletionsCommon.ToolCall toolCall) { + private void upsertToolCall(ChatCompletionsCommon.ToolCall toolCall) { int index = toolCall.index != null ? toolCall.index : toolCallParts.size(); initializeToolCallState(index); updateAccumulatedToolCall(index, toolCall); - - return buildChunkToolCallPart(toolCall); } /** @@ -798,57 +919,5 @@ private void appendFunctionDetails( toolCallArgsAccumulator.get(index).append(function.arguments); } } - - /** - * Builds the {@link Part} for the current chunk's tool call. - * - * @param toolCall the tool call from the chunk. - * @return the {@link Part} for this chunk. - */ - private Part buildChunkToolCallPart(ChatCompletionsCommon.ToolCall toolCall) { - Part chunkPart = toolCall.toPart(); - if (chunkPart == null) { - FunctionCall.Builder chunkFcBuilder = FunctionCall.builder(); - if (toolCall.id != null) { - chunkFcBuilder.id(toolCall.id); - } - chunkPart = Part.builder().functionCall(chunkFcBuilder.build()).build(); - chunkPart = toolCall.applyThoughtSignature(chunkPart); - } - return chunkPart; - } - - private ImmutableList getContentParts() { - ImmutableList.Builder parts = ImmutableList.builder(); - if (contentParts.length() > 0) { - parts.add(Part.fromText(contentParts.toString())); - } - - // If a server sends keys 0 and 2 but not 1 then squash the indices and - // return parts at indices 0 and 1. - ImmutableList sortedKeys = ImmutableList.sortedCopyOf(toolCallParts.keySet()); - - for (int index : sortedKeys) { - Part part = toolCallParts.get(index); - if (part != null && part.functionCall().isPresent()) { - FunctionCall fc = part.functionCall().get(); - StringBuilder argsSb = toolCallArgsAccumulator.get(index); - if (argsSb != null && argsSb.length() > 0) { - try { - Map args = - objectMapper.readValue( - argsSb.toString(), new TypeReference>() {}); - fc = fc.toBuilder().args(args).build(); - part = part.toBuilder().functionCall(fc).build(); - } catch (JsonProcessingException e) { - throw new IllegalArgumentException( - "Failed to parse final tool call arguments: " + argsSb, e); - } - } - } - parts.add(part); - } - return parts.build(); - } } } diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsHttpClientTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsHttpClientTest.java index b5c7888cf..c799e4700 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsHttpClientTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsHttpClientTest.java @@ -33,12 +33,14 @@ import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; +import com.google.genai.types.FunctionCall; import com.google.genai.types.HttpOptions; import com.google.genai.types.Part; import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.io.IOException; import java.lang.reflect.Field; import java.time.Duration; +import java.util.Base64; import okhttp3.Call; import okhttp3.Callback; import okhttp3.MediaType; @@ -609,4 +611,79 @@ private static OkHttpClient readInternalClient(ChatCompletionsHttpClient target) throw new LinkageError("Failed to read internal client", e); } } + + // -- thought_signature end-to-end through the HTTP layer. ------------------------------ + // + // A single round-trip test that covers the request encoder, the HTTP body writer, the + // response decoder, and the ToolCall.applyThoughtSignature site in one shot. Wider + // request- and response-side coverage lives in the unit tests in + // ChatCompletionsRequestTest and ChatCompletionsResponseTest. + + private static final byte[] httpSigBytes = {0x21, 0x22, 0x23, 0x24}; + private static final String HTTP_SIG_B64 = Base64.getEncoder().encodeToString(httpSigBytes); + + /** + * Round-trip: a {@link Part} with a {@code thoughtSignature} sent on the request must decode + * bytewise-equal on the response when the mock server echoes the same base64 string back. This is + * the strongest single regression guard for the thought_signature pipeline because it covers the + * request encoder, the HTTP body writer, the response decoder, and the {@link + * ChatCompletionsCommon.ToolCall#applyThoughtSignature} site in a single test. + */ + @Test + public void complete_nonStreaming_thoughtSignatureRoundTrip() throws Exception { + // Send a request with a function-call Part carrying httpSigBytes, then mock a response + // whose tool_call carries the same base64 signature, and assert the decoded bytes match. + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder().id("call_rt").name("ping").build()) + .thoughtSignature(httpSigBytes) + .build())) + .build())) + .build(); + + String responseBody = + String.format( + """ + { + "choices": [{ + "message": { + "role": "assistant", + "tool_calls": [{ + "id": "call_rt", + "type": "function", + "function": { "name": "ping", "arguments": "{}" }, + "extra_content": { + "google": { "thought_signature": "%s" } + } + }] + }, + "finish_reason": "tool_calls" + }] + } + """, + HTTP_SIG_B64); + Response mockResponse = createMockResponse(responseBody, JSON); + + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Callback.class); + doNothing().when(mockCall).enqueue(callbackCaptor.capture()); + + TestSubscriber testSubscriber = client.complete(llmRequest, false).test(); + callbackCaptor.getValue().onResponse(mockCall, mockResponse); + testSubscriber.await(AWAIT_TIMEOUT.toMillis(), MILLISECONDS); + + testSubscriber.assertNoErrors(); + LlmResponse response = testSubscriber.values().get(0); + Part decodedToolPart = response.content().get().parts().get().get(0); + assertThat(decodedToolPart.functionCall().get().id()).hasValue("call_rt"); + assertThat(decodedToolPart.thoughtSignature().get()).isEqualTo(httpSigBytes); + } } diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java index 1f41189a2..eb00b3770 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java @@ -37,7 +37,9 @@ import com.google.genai.types.Tool; import com.google.genai.types.ToolConfig; import java.util.AbstractMap; +import java.util.Base64; import java.util.List; +import java.util.Map; import java.util.Set; import org.junit.Before; import org.junit.Test; @@ -676,4 +678,173 @@ public void testFromLlmRequest_withConfigResponseMimeTypeJson() throws Exception assertThat(request.responseFormat) .isInstanceOf(ChatCompletionsRequest.ResponseFormatJsonObject.class); } + + // ----- thought_signature round-trip on the request side ---------------------------------- + // + // The four chat source files share a single contract for round-tripping Gemini's + // thought_signature bytes back to the OpenAI-compatible endpoint: + // - Text Parts: Part.thoughtSignature() bytes (first text Part only) --> + // message.extra_content.google.thought_signature (base64 string). + // - functionCall Parts: Part.thoughtSignature() bytes --> + // toolCall.extra_content.google.thought_signature (base64 string). + // - Tool/role=tool turns: extra_content is dropped (the turn becomes a tool message and any + // captured signature is not echoed). + // + // The tests below exercise the encoding pipeline end-to-end via fromLlmRequest, complementing + // the existing DTO-level Jackson serialization test + // (testSerializeChatCompletionRequest_withToolCallsAndExtraContent) which uses a literal + // string and does NOT exercise byte[] handling or the conversion site. + + private static final byte[] signatureBytesText = {0x01, 0x02, 0x03, 0x04}; + private static final byte[] signatureBytesFnCall = {0x10, 0x20, 0x30, 0x40, 0x50}; + private static final byte[] signatureBytesSecondText = {(byte) 0xff, (byte) 0xfe}; + + /** + * Asserts {@code msg.extraContent == {google: {thought_signature: base64(expected)}}} so all + * thought_signature encode tests share a single, precise comparison and never fall into substring + * matching. + */ + private static void assertThoughtSignatureExtraContent( + Map extraContent, byte[] expected) { + assertThat(extraContent).isNotNull(); + assertThat(extraContent).containsKey("google"); + @SuppressWarnings("unchecked") // This code won't run in production and it is a JSON object. + Map google = (Map) extraContent.get("google"); + assertThat(google).containsKey("thought_signature"); + Object sigObj = google.get("thought_signature"); + assertThat(sigObj).isInstanceOf(String.class); + assertThat(Base64.getDecoder().decode((String) sigObj)).isEqualTo(expected); + } + + @Test + public void testFromLlmRequest_textPart_withThoughtSignature_encodesAsMessageExtraContent() + throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.builder() + .text("here is the answer") + .thoughtSignature(signatureBytesText) + .build())) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message msg = request.messages.get(0); + assertThat(msg.role).isEqualTo("assistant"); + assertThat(msg.content.getValue()).isEqualTo("here is the answer"); + assertThoughtSignatureExtraContent(msg.extraContent, signatureBytesText); + } + + @Test + public void testFromLlmRequest_multipleTextParts_firstSignatureWins() throws Exception { + // processContent captures only the FIRST text Part's signature. Verifies that a second + // signature on a later text Part is silently dropped, matching the source contract at + // ChatCompletionsRequest.processContent around line 377. + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.builder() + .text("first") + .thoughtSignature(signatureBytesText) + .build(), + Part.builder() + .text("second") + .thoughtSignature(signatureBytesSecondText) + .build())) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message msg = request.messages.get(0); + assertThoughtSignatureExtraContent(msg.extraContent, signatureBytesText); + } + + @Test + public void + testFromLlmRequest_functionCallPart_withThoughtSignature_encodesAsToolCallExtraContent() + throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_42") + .name("get_weather") + .args(ImmutableMap.of("city", "Tokyo")) + .build()) + .thoughtSignature(signatureBytesFnCall) + .build())) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message msg = request.messages.get(0); + assertThat(msg.toolCalls).hasSize(1); + ChatCompletionsCommon.ToolCall toolCall = msg.toolCalls.get(0); + assertThat(toolCall.id).isEqualTo("call_42"); + assertThat(toolCall.function.name).isEqualTo("get_weather"); + assertThoughtSignatureExtraContent(toolCall.extraContent, signatureBytesFnCall); + // The message-level extraContent must remain null when there is no text Part with a sig. + assertThat(msg.extraContent).isNull(); + } + + @Test + public void testFromLlmRequest_functionResponseTurn_dropsSignature() throws Exception { + // role=tool turns return early in processContent and yield zero or more "tool" Messages + // built from function responses. Any thought_signature on the source Parts -- which would + // not make sense on a tool turn anyway -- must NOT leak into the emitted tool Messages + // via extra_content. + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("tool") + .parts( + ImmutableList.of( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id("call_x") + .response(ImmutableMap.of("ok", true)) + .build()) + .thoughtSignature(signatureBytesText) + .build())) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message toolMsg = request.messages.get(0); + assertThat(toolMsg.role).isEqualTo("tool"); + assertThat(toolMsg.extraContent).isNull(); + } } diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java index 367545207..352da6d8c 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.models.chat.ChatCompletionsResponse.ChatCompletion; import com.google.adk.models.chat.ChatCompletionsResponse.ChatCompletionChunk; @@ -767,7 +768,11 @@ public void testChunkCollection_accumulatesMultipleToolCalls() throws Exception .modelVersion("") .build(); - LlmResponse finalResponse = responses.get(1); + // Tool call deltas are accumulated across chunks 1-4 without emitting partial responses. + // Chunk 5 (finish_reason=tool_calls) emits a single metadata-final response carrying the fully + // accumulated tool calls and the FinishReason. Size 1. + assertThat(responses).hasSize(1); + LlmResponse finalResponse = responses.get(0); assertThat(finalResponse).isEqualTo(expectedFinalResponse); } @@ -800,21 +805,31 @@ public void testChunkCollection_simpleText() throws Exception { collection.processChunk( objectMapper.readValue(chunk3Json, ChatCompletionsResponse.ChatCompletionChunk.class)); - LlmResponse expectedFinalResponse = + // For a text-only turn, the finish_reason chunk emits TWO non-partial responses: + // (A) an aggregated-text response with the full text but NO finishReason. + // (B) a metadata-final response with FinishReason=STOP and no text parts. + // See ChatCompletionsResponse.processChunk for rationale. Size 2. + LlmResponse expectedAggregatedTextResponse = LlmResponse.builder() .content( Content.builder() .role("") .parts(ImmutableList.of(Part.fromText("Hello World!"))) .build()) + .customMetadata(ImmutableList.of()) + .modelVersion("") + .build(); + LlmResponse expectedFinalResponse = + LlmResponse.builder() + .content(Content.builder().role("").parts(ImmutableList.of()).build()) .finishReason(new FinishReason(Known.STOP.toString())) .customMetadata(ImmutableList.of()) .modelVersion("") .build(); - LlmResponse finalResponse = responses.get(1); - - assertThat(finalResponse).isEqualTo(expectedFinalResponse); + assertThat(responses) + .containsExactly(expectedAggregatedTextResponse, expectedFinalResponse) + .inOrder(); } @Test @@ -838,21 +853,30 @@ public void testChunkCollection_withRefusal() throws Exception { collection.processChunk( objectMapper.readValue(chunk2Json, ChatCompletionsResponse.ChatCompletionChunk.class)); - LlmResponse expectedFinalResponse = + // Similar to testChunkCollection_simpleText: chunk 1 streams the refusal, then chunk 2 + // (finish_reason) emits an aggregated-text response with the full refusal text, followed + // by a metadata-final response with FinishReason and no text parts. Size 2. + LlmResponse expectedAggregatedTextResponse = LlmResponse.builder() .content( Content.builder() .role("") .parts(ImmutableList.of(Part.fromText("I cannot do that."))) .build()) + .customMetadata(ImmutableList.of()) + .modelVersion("") + .build(); + LlmResponse expectedFinalResponse = + LlmResponse.builder() + .content(Content.builder().role("").parts(ImmutableList.of()).build()) .finishReason(new FinishReason(Known.STOP.toString())) .customMetadata(ImmutableList.of()) .modelVersion("") .build(); - LlmResponse finalResponse = responses.get(1); - - assertThat(finalResponse).isEqualTo(expectedFinalResponse); + assertThat(responses) + .containsExactly(expectedAggregatedTextResponse, expectedFinalResponse) + .inOrder(); } @Test @@ -876,4 +900,303 @@ public void testChunkCollection_noChoices() throws Exception { assertThat(response.content()).isPresent(); assertThat(response.content().get().parts()).isEmpty(); } + + // ----- thought_signature decoding on the response side ----------------------------------- + // + // The response code maps wire-level extra_content.google.thought_signature (base64) onto + // Part.thoughtSignature() bytes across four conceptual paths. The tests below cover one + // canonical positive per path plus a single malformed-input tolerance check and one + // request-response byte-equality round-trip: + // 1. Non-streaming text: Message.extra_content is PARSED onto the DTO but is NOT + // attached to the output text Part. Characterized below as + // current behavior (likely a bug; see the test's TODO). + // 2. Non-streaming tool: ToolCall.extra_content --> tool-call Part.thoughtSignature + // (already covered by testToLlmResponse_thoughtSignature). + // 3. Streaming text: Per-chunk delta.extra_content captured into + // ChatCompletionChunkCollection.accumulatedTextThoughtSignature, + // attached to the aggregated text Part on the finish chunk. + // 4. Streaming tool: Per-chunk delta.tool_calls[i].extra_content applied to the + // accumulated tool-call Part; message-level signature backfills + // tool-call Parts that lack their own. + + private static final byte[] STREAMING_TEXT_SIGNATURE = {0x0a, 0x0b, 0x0c}; + private static final byte[] STREAMING_TOOL_SIGNATURE = {0x11, 0x12, 0x13, 0x14}; + private static final String STREAMING_TEXT_SIGNATURE_B64 = + Base64.getEncoder().encodeToString(STREAMING_TEXT_SIGNATURE); + private static final String STREAMING_TOOL_SIGNATURE_B64 = + Base64.getEncoder().encodeToString(STREAMING_TOOL_SIGNATURE); + + @Test + public void testToLlmResponse_nonStreamingText_messageLevelSignatureStaysOnDtoButNotOnOutputPart() + throws Exception { + // Characterizes a known asymmetry between the streaming and non-streaming paths: + // - Streaming: ChatCompletionChunkCollection.captureMessageThoughtSignature decodes + // extra_content.google.thought_signature from any delta and attaches + // it to the aggregated text Part (see buildAggregatedTextResponse). + // - Non-streaming: mapMessageToParts does NOT decode Message.extraContent at all; the + // signature parses onto the Message DTO but never lands on any Part. + // + // Gemini's OpenAI-compatible endpoint emits a message-level thought_signature on + // assistant text responses; if not round-tripped on the next turn, Gemini may retry or + // loop. Today the non-streaming branch silently drops the signature, which is likely a + // bug. This test pins the CURRENT behavior so it is visible to future readers and so any + // future fix (propagating the signature to the text Part, mirroring streaming) flips + // this test from passing to failing -- forcing an intentional, documented update. + // + // TODO(b/...): consider attaching message.extraContent.google.thought_signature to the + // output text Part to match the streaming-path contract. If/when that fix lands, this + // test should be updated to assert that textPart.thoughtSignature() has the decoded + // bytes (compare with + // testChunkCollection_streamingText_messageLevelSignatureAttachesToAggregatedTextPart). + String json = + String.format( + """ + { + "choices": [{ + "message": { + "role": "assistant", + "content": "Hello world", + "extra_content": { + "google": { + "thought_signature": "%s" + } + } + }, + "finish_reason": "stop" + }] + } + """, + STREAMING_TEXT_SIGNATURE_B64); + + ChatCompletionsResponse.ChatCompletion completion = + objectMapper.readValue(json, ChatCompletionsResponse.ChatCompletion.class); + + // The DTO field IS populated (Jackson parses the JSON) ... + @SuppressWarnings("unchecked") + Map google = + (Map) completion.choices.get(0).message.extraContent.get("google"); + assertThat(google).containsEntry("thought_signature", STREAMING_TEXT_SIGNATURE_B64); + + // ... but mapMessageToParts does NOT propagate it to the output text Part. + LlmResponse response = completion.toLlmResponse(); + Part textPart = response.content().get().parts().get().get(0); + assertThat(textPart.text()).hasValue("Hello world"); + assertThat(textPart.thoughtSignature()).isEmpty(); + } + + @Test + public void testToLlmResponse_nonStreamingText_malformedExtraContent_doesNotCrash() + throws Exception { + // Defensive: a non-string thought_signature (e.g. the number 42) on a non-streaming + // text Message must not throw during toLlmResponse(). The Message DTO parses the field + // as Map, so a numeric value lands as Integer/Long and any future + // decoder needs to tolerate it. Today's code does nothing with it; this test guards + // both today's no-op behavior and any future decode site from a NullPointer/ClassCast. + String json = + """ + { + "choices": [{ + "message": { + "role": "assistant", + "content": "hi", + "extra_content": { + "google": { + "thought_signature": 42 + } + } + }, + "finish_reason": "stop" + }] + } + """; + + ChatCompletionsResponse.ChatCompletion completion = + objectMapper.readValue(json, ChatCompletionsResponse.ChatCompletion.class); + + LlmResponse response = completion.toLlmResponse(); + + Part textPart = response.content().get().parts().get().get(0); + assertThat(textPart.text()).hasValue("hi"); + assertThat(textPart.thoughtSignature()).isEmpty(); + } + + // ----- streaming thought_signature paths ------------------------------------------------- + + /** + * Pushes the given JSON chunks (one per varargs entry) through a fresh {@link + * ChatCompletionsResponse.ChatCompletionChunkCollection} and returns the concatenated list of all + * {@link LlmResponse} values emitted. Centralizes the four-line decode-and-process boiler so + * streaming tests stay focused on assertions. + */ + private ImmutableList runStream(String... chunkJson) throws Exception { + ChatCompletionsResponse.ChatCompletionChunkCollection collection = + new ChatCompletionsResponse.ChatCompletionChunkCollection(); + ImmutableList.Builder all = ImmutableList.builder(); + for (String json : chunkJson) { + all.addAll( + collection.processChunk( + objectMapper.readValue(json, ChatCompletionsResponse.ChatCompletionChunk.class))); + } + return all.build(); + } + + @Test + public void testChunkCollection_streamingText_messageLevelSignatureAttachesToAggregatedTextPart() + throws Exception { + // The canonical Gemini streaming-text pattern: text chunks first, then a finish chunk that + // carries the message-level thought_signature in delta.extra_content. The aggregated-text + // response emitted on the finish chunk MUST carry the signature on its single Part. + String chunk1 = "{\"choices\":[{\"delta\":{\"content\":\"Hello \"}}]}"; + String chunk2 = "{\"choices\":[{\"delta\":{\"content\":\"world!\"}}]}"; + String chunk3 = + String.format( + "{\"choices\":[{\"delta\":{\"extra_content\":{\"google\":{\"thought_signature\":\"%s\"}}},\"finish_reason\":\"stop\"}]}", + STREAMING_TEXT_SIGNATURE_B64); + + ImmutableList all = runStream(chunk1, chunk2, chunk3); + + // chunk1 and chunk2 each emit a partial text response (no signature); chunk3 emits + // (A) aggregated-text with signature, then (B) the metadata-final response. + assertThat(all).hasSize(4); + LlmResponse aggregated = all.get(2); + Part aggregatedTextPart = aggregated.content().get().parts().get().get(0); + assertThat(aggregatedTextPart.text()).hasValue("Hello world!"); + assertThat(aggregatedTextPart.thoughtSignature()).hasValue(STREAMING_TEXT_SIGNATURE); + } + + @Test + public void testChunkCollection_streamingText_malformedExtraContent_doesNotCrash() + throws Exception { + String chunk1 = "{\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}"; + String chunk2 = + "{\"choices\":[{\"delta\":{\"extra_content\":{\"google\":42}},\"finish_reason\":\"stop\"}]}"; + + ImmutableList all = runStream(chunk1, chunk2); + + assertThat(all).hasSize(3); + LlmResponse aggregated = all.get(1); + Part aggregatedTextPart = aggregated.content().get().parts().get().get(0); + assertThat(aggregatedTextPart.text()).hasValue("hi"); + assertThat(aggregatedTextPart.thoughtSignature()).isEmpty(); + } + + @Test + public void testChunkCollection_streamingToolCall_perToolCallSignatureAttachesToFinalPart() + throws Exception { + // Per-tool-call streaming: a tool_call delta with extra_content.google.thought_signature + // must land on the accumulated tool-call Part by the time finish_reason=tool_calls fires. + String chunk1 = + String.format( + "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_a\",\"type\":\"function\",\"function\":{\"name\":\"do_thing\",\"arguments\":\"{}\"},\"extra_content\":{\"google\":{\"thought_signature\":\"%s\"}}}]}}]}", + STREAMING_TOOL_SIGNATURE_B64); + String chunk2 = "{\"choices\":[{\"finish_reason\":\"tool_calls\"}]}"; + + ImmutableList all = runStream(chunk1, chunk2); + + // Tool-call chunks are accumulated silently: per the doc-comment on accumulateToolCalls + // ("To prevent downstream flows from dispatching the same tool multiple times, partial + // tool calls are NOT emitted."), chunk1 emits zero events. chunk2's finish chunk emits + // a single metadata-final response carrying the fully-accumulated tool-call Part with the + // per-tool-call signature applied by updateAccumulatedToolCall. + assertThat(all).hasSize(1); + + LlmResponse finalResponse = all.get(0); + assertThat(finalResponse.finishReason().get().knownEnum()).isEqualTo(Known.STOP); + Part finalToolPart = finalResponse.content().get().parts().get().get(0); + assertThat(finalToolPart.functionCall().get().name()).hasValue("do_thing"); + assertThat(finalToolPart.thoughtSignature()).hasValue(STREAMING_TOOL_SIGNATURE); + } + + @Test + public void testChunkCollection_streamingToolCall_backfillsMessageLevelSignatureWhenAbsent() + throws Exception { + // When a tool-call Part lacks its own per-call signature but the stream carries a + // message-level signature (typical Gemini pattern: message-level signature on the final + // chunk), getFinalToolCallParts backfills it onto the tool-call Part so the assistant + // turn round-trips with a signature. + String chunk1 = + "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_b\",\"type\":\"function\",\"function\":{\"name\":\"do_thing\",\"arguments\":\"{}\"}}]}}]}"; + String chunk2 = + String.format( + "{\"choices\":[{\"delta\":{\"extra_content\":{\"google\":{\"thought_signature\":\"%s\"}}},\"finish_reason\":\"tool_calls\"}]}", + STREAMING_TEXT_SIGNATURE_B64); + + ImmutableList all = runStream(chunk1, chunk2); + + // chunk1: silently accumulated (no partial emitted for tool calls). chunk2: single + // metadata-final response (no aggregated-text event because contentParts is empty). + assertThat(all).hasSize(1); + + LlmResponse finalResponse = all.get(0); + Part finalToolPart = finalResponse.content().get().parts().get().get(0); + // Backfilled signature. + assertThat(finalToolPart.thoughtSignature()).hasValue(STREAMING_TEXT_SIGNATURE); + } + + // ----- Round-trip: Part(sig) --> request --> response --> Part(sig) bytewise equal ------- + + @Test + public void testRoundTrip_functionCallSignature_bytesPreservedThroughRequestAndResponse() + throws Exception { + // Bytewise round-trip from a Part with a signature through the request encoder, then + // back through the response decoder. Guards against any encoding-decoding asymmetry + // (e.g. URL-safe vs standard base64) that DTO-only tests cannot catch. + byte[] originalSig = {0x00, 0x7f, (byte) 0x80, (byte) 0xff}; + + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder().id("call_rt").name("ping").build()) + .thoughtSignature(originalSig) + .build())) + .build())) + .build(); + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + // Sanity-check the outbound DTO carries the same encoded signature value... + @SuppressWarnings("unchecked") + Map outboundGoogle = + (Map) request.messages.get(0).toolCalls.get(0).extraContent.get("google"); + String encodedSig = (String) outboundGoogle.get("thought_signature"); + + // ...then synthesize a wire-shaped response carrying the same encoded sig and decode it + // back through toLlmResponse. + String responseJson = + String.format( + """ + { + "choices": [{ + "message": { + "role": "assistant", + "tool_calls": [{ + "id": "call_rt", + "type": "function", + "function": { "name": "ping", "arguments": "{}" }, + "extra_content": { + "google": { + "thought_signature": "%s" + } + } + }] + } + }] + } + """, + encodedSig); + + ChatCompletion roundTrippedCompletion = + objectMapper.readValue(responseJson, ChatCompletion.class); + LlmResponse roundTripped = roundTrippedCompletion.toLlmResponse(); + + Part decodedPart = roundTripped.content().get().parts().get().get(0); + assertThat(decodedPart.thoughtSignature()).hasValue(originalSig); + } }