diff --git a/braintrust-sdk/instrumentation/anthropic_2_2_0/src/main/java/dev/braintrust/instrumentation/anthropic/v2_2_0/TracingHttpClient.java b/braintrust-sdk/instrumentation/anthropic_2_2_0/src/main/java/dev/braintrust/instrumentation/anthropic/v2_2_0/TracingHttpClient.java index 8dc0824c..da295f25 100644 --- a/braintrust-sdk/instrumentation/anthropic_2_2_0/src/main/java/dev/braintrust/instrumentation/anthropic/v2_2_0/TracingHttpClient.java +++ b/braintrust-sdk/instrumentation/anthropic_2_2_0/src/main/java/dev/braintrust/instrumentation/anthropic/v2_2_0/TracingHttpClient.java @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import lombok.extern.slf4j.Slf4j; @Slf4j @@ -62,7 +63,7 @@ public void close() { inputJson); var response = underlying.execute(bufferedRequest, requestOptions); - return new TeeingStreamHttpResponse(response, span); + return new TeeingStreamHttpResponse(response, span, inputJson); } catch (Exception e) { InstrumentationSemConv.tagLLMSpanResponse(span, e); span.end(); @@ -90,7 +91,9 @@ public void close() { return underlying .executeAsync(bufferedRequest, requestOptions) .thenApply( - response -> (HttpResponse) new TeeingStreamHttpResponse(response, span)) + response -> + (HttpResponse) + new TeeingStreamHttpResponse(response, span, inputJson)) .whenComplete( (response, t) -> { if (t != null) { @@ -170,14 +173,16 @@ private static String readBodyAsString(HttpRequestBody body) { private static final class TeeingStreamHttpResponse implements HttpResponse { private final HttpResponse delegate; private final Span span; + private final @Nullable String requestBody; private final long spanStartNanos = System.nanoTime(); private final AtomicLong timeToFirstTokenNanos = new AtomicLong(); private final ByteArrayOutputStream teeBuffer = new ByteArrayOutputStream(); private final InputStream teeStream; - TeeingStreamHttpResponse(HttpResponse delegate, Span span) { + TeeingStreamHttpResponse(HttpResponse delegate, Span span, @Nullable String requestBody) { this.delegate = delegate; this.span = span; + this.requestBody = requestBody; this.teeStream = new TeeInputStream( delegate.body(), teeBuffer, this::onFirstByte, this::onStreamClosed); @@ -193,7 +198,7 @@ private void onStreamClosed() { synchronized (teeBuffer) { bytes = teeBuffer.toByteArray(); } - tagSpanFromBuffer(span, bytes, timeToFirstTokenNanos.get()); + tagSpanFromBuffer(span, bytes, timeToFirstTokenNanos.get(), requestBody); } finally { span.end(); } @@ -287,7 +292,8 @@ private void notifyClosed() { // Span tagging from buffered bytes // ------------------------------------------------------------------------- - private static void tagSpanFromBuffer(Span span, byte[] bytes, Long timeToFirstTokenNanos) { + private static void tagSpanFromBuffer( + Span span, byte[] bytes, Long timeToFirstTokenNanos, @Nullable String requestBody) { if (bytes.length == 0) return; try { String firstLine = firstNonEmptyLine(bytes); @@ -297,13 +303,15 @@ private static void tagSpanFromBuffer(Span span, byte[] bytes, Long timeToFirstT firstLine != null && (firstLine.startsWith("data:") || firstLine.startsWith("event:")); if (isSse) { - tagSpanFromSseBytes(span, bytes, timeToFirstTokenNanos); + tagSpanFromSseBytes(span, bytes, timeToFirstTokenNanos, requestBody); } else { // Non-streaming: plain Message JSON — pass it whole, no time_to_first_token InstrumentationSemConv.tagLLMSpanResponse( span, InstrumentationSemConv.PROVIDER_NAME_ANTHROPIC, - new String(bytes, StandardCharsets.UTF_8)); + new String(bytes, StandardCharsets.UTF_8), + null, + requestBody); } } catch (Exception e) { log.error("Could not tag span from Anthropic response buffer", e); @@ -338,7 +346,7 @@ private static String firstNonEmptyLine(byte[] bytes) { * assembled {@link com.anthropic.models.messages.Message} for the span. */ private static void tagSpanFromSseBytes( - Span span, byte[] sseBytes, Long timeToFirstTokenNanos) { + Span span, byte[] sseBytes, Long timeToFirstTokenNanos, @Nullable String requestBody) { try { var mapper = BraintrustJsonMapper.get(); var reader = @@ -362,7 +370,8 @@ private static void tagSpanFromSseBytes( span, InstrumentationSemConv.PROVIDER_NAME_ANTHROPIC, assembledMessageJson, - timeToFirstTokenNanos); + timeToFirstTokenNanos, + requestBody); } catch (Exception e) { log.error("Could not parse Anthropic SSE buffer to tag streaming span output", e); } diff --git a/braintrust-sdk/src/main/java/dev/braintrust/instrumentation/InstrumentationSemConv.java b/braintrust-sdk/src/main/java/dev/braintrust/instrumentation/InstrumentationSemConv.java index 05c04afd..d15f7db2 100644 --- a/braintrust-sdk/src/main/java/dev/braintrust/instrumentation/InstrumentationSemConv.java +++ b/braintrust-sdk/src/main/java/dev/braintrust/instrumentation/InstrumentationSemConv.java @@ -85,11 +85,22 @@ public static void tagLLMSpanResponse( @Nonnull String providerName, @Nonnull String responseBody, @Nullable Long timeToFirstTokenNanoseconds) { + tagLLMSpanResponse(span, providerName, responseBody, timeToFirstTokenNanoseconds, null); + } + + @SneakyThrows + public static void tagLLMSpanResponse( + Span span, + @Nonnull String providerName, + @Nonnull String responseBody, + @Nullable Long timeToFirstTokenNanoseconds, + @Nullable String requestBody) { switch (providerName) { case PROVIDER_NAME_OPENAI -> tagOpenAIResponse(span, responseBody, timeToFirstTokenNanoseconds); case PROVIDER_NAME_ANTHROPIC -> - tagAnthropicResponse(span, responseBody, timeToFirstTokenNanoseconds); + tagAnthropicResponse( + span, responseBody, timeToFirstTokenNanoseconds, requestBody); case PROVIDER_NAME_BEDROCK -> tagBedrockResponse(span, responseBody, timeToFirstTokenNanoseconds); default -> tagOpenAIResponse(span, responseBody, timeToFirstTokenNanoseconds); @@ -237,7 +248,10 @@ private static void tagAnthropicRequest( @SneakyThrows private static void tagAnthropicResponse( - Span span, String responseBody, @Nullable Long timeToFirstTokenNanoseconds) { + Span span, + String responseBody, + @Nullable Long timeToFirstTokenNanoseconds, + @Nullable String requestBody) { JsonNode responseJson = BraintrustJsonMapper.get().readTree(responseBody); // Anthropic response is the full Message object — output it whole @@ -258,6 +272,20 @@ private static void tagAnthropicResponse( "tokens", usage.get("input_tokens").asLong() + usage.get("output_tokens").asLong()); } + + // Prompt caching metrics + if (usage.has("cache_read_input_tokens")) { + metrics.put("prompt_cached_tokens", usage.get("cache_read_input_tokens")); + } + if (usage.has("cache_creation_input_tokens")) { + long cacheCreationTokens = usage.get("cache_creation_input_tokens").asLong(); + metrics.put("prompt_cache_creation_tokens", cacheCreationTokens); + + // Per-TTL breakdown: inspect the request to find which TTL tiers were used + if (requestBody != null && cacheCreationTokens > 0) { + addPerTtlCacheMetrics(metrics, requestBody, cacheCreationTokens); + } + } } if (!metrics.isEmpty()) { @@ -265,6 +293,46 @@ private static void tagAnthropicResponse( } } + /** + * Inspect the Anthropic request body for {@code cache_control} blocks and attribute the total + * {@code cache_creation_input_tokens} to per-TTL metrics. If all breakpoints share the same TTL + * (the common case), all creation tokens are attributed to that tier. When multiple TTLs are + * present, we cannot split the total so we attribute it to each tier (the API does not provide + * a per-breakpoint breakdown). + */ + @SneakyThrows + private static void addPerTtlCacheMetrics( + Map metrics, String requestBody, long cacheCreationTokens) { + JsonNode requestJson = BraintrustJsonMapper.get().readTree(requestBody); + java.util.Set ttls = new java.util.LinkedHashSet<>(); + collectCacheControlTtls(requestJson, ttls); + // Default TTL for Anthropic is 5m when cache_control is present but no ttl specified + if (ttls.isEmpty()) { + // There were cache_control blocks but none had an explicit ttl — default is 5m + ttls.add("5m"); + } + for (String ttl : ttls) { + metrics.put("prompt_cache_creation_" + ttl + "_tokens", cacheCreationTokens); + } + } + + /** Recursively collect all distinct {@code cache_control.ttl} values from the request JSON. */ + private static void collectCacheControlTtls(JsonNode node, java.util.Set ttls) { + if (node == null) return; + if (node.isObject()) { + if (node.has("cache_control") && node.get("cache_control").has("ttl")) { + ttls.add(node.get("cache_control").get("ttl").asText()); + } + for (var it = node.fields(); it.hasNext(); ) { + collectCacheControlTtls(it.next().getValue(), ttls); + } + } else if (node.isArray()) { + for (JsonNode child : node) { + collectCacheControlTtls(child, ttls); + } + } + } + // ------------------------------------------------------------------------- // AWS Bedrock provider implementation // ------------------------------------------------------------------------- diff --git a/btx/src/test/java/dev/braintrust/sdkspecimpl/LlmSpanSpec.java b/btx/src/test/java/dev/braintrust/sdkspecimpl/LlmSpanSpec.java index b8b96fbe..a3ba3fe1 100644 --- a/btx/src/test/java/dev/braintrust/sdkspecimpl/LlmSpanSpec.java +++ b/btx/src/test/java/dev/braintrust/sdkspecimpl/LlmSpanSpec.java @@ -30,6 +30,7 @@ public record LlmSpanSpec( String provider, String endpoint, String client, + Map headers, List> requests, List> expectedBrainstoreSpans, String sourcePath) { @@ -59,11 +60,28 @@ static LlmSpanSpec fromMap(Map raw, String sourcePath, String cl String provider = (String) raw.get("provider"); String endpoint = (String) raw.get("endpoint"); + Map headers = null; + if (raw.containsKey("headers")) { + Map rawHeaders = (Map) raw.get("headers"); + headers = new java.util.LinkedHashMap<>(); + for (var entry : rawHeaders.entrySet()) { + headers.put(entry.getKey(), String.valueOf(entry.getValue())); + } + } + List> requests = (List>) raw.get("requests"); List> expectedSpans = (List>) raw.get("expected_brainstore_spans"); return new LlmSpanSpec( - name, type, provider, endpoint, client, requests, expectedSpans, sourcePath); + name, + type, + provider, + endpoint, + client, + headers, + requests, + expectedSpans, + sourcePath); } } diff --git a/btx/src/test/java/dev/braintrust/sdkspecimpl/SpanValidator.java b/btx/src/test/java/dev/braintrust/sdkspecimpl/SpanValidator.java index 637c9a81..f75ed295 100644 --- a/btx/src/test/java/dev/braintrust/sdkspecimpl/SpanValidator.java +++ b/btx/src/test/java/dev/braintrust/sdkspecimpl/SpanValidator.java @@ -187,6 +187,24 @@ private static void assertFnMatcher(Object actual, SpecMatcher.FnMatcher fn, Str context, v)); } } + case "is_positive_number" -> { + if (!(actual instanceof Number)) { + fail( + String.format( + "%s: is_positive_number: expected a Number but got %s" + + " (value: %s)", + context, + actual == null ? "null" : actual.getClass().getSimpleName(), + actual)); + } + double v = ((Number) actual).doubleValue(); + if (v <= 0) { + fail( + String.format( + "%s: is_positive_number: value %s is not positive", + context, v)); + } + } case "is_non_empty_string" -> { if (!(actual instanceof String) || ((String) actual).isEmpty()) { fail( diff --git a/btx/src/test/java/dev/braintrust/sdkspecimpl/SpecExecutor.java b/btx/src/test/java/dev/braintrust/sdkspecimpl/SpecExecutor.java index 40dda005..a47a9a16 100644 --- a/btx/src/test/java/dev/braintrust/sdkspecimpl/SpecExecutor.java +++ b/btx/src/test/java/dev/braintrust/sdkspecimpl/SpecExecutor.java @@ -25,17 +25,6 @@ import dev.braintrust.instrumentation.langchain.BraintrustLangchain; import dev.braintrust.instrumentation.openai.BraintrustOpenAI; import dev.braintrust.instrumentation.springai.v1_0_0.BraintrustSpringAI; -import dev.langchain4j.agent.tool.ToolSpecification; -import dev.langchain4j.data.message.ChatMessage; -import dev.langchain4j.data.message.Content; -import dev.langchain4j.data.message.ImageContent; -import dev.langchain4j.data.message.SystemMessage; -import dev.langchain4j.data.message.TextContent; -import dev.langchain4j.data.message.UserMessage; -import dev.langchain4j.model.chat.ChatModel; -import dev.langchain4j.model.chat.request.ChatRequest; -import dev.langchain4j.model.chat.request.json.JsonObjectSchema; -import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.model.openai.OpenAiStreamingChatModel; import io.opentelemetry.api.trace.Span; @@ -49,22 +38,11 @@ import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; import org.springframework.ai.anthropic.api.AnthropicApi; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.ai.tool.ToolCallback; -import org.springframework.ai.tool.function.FunctionToolCallback; -import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; -import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole; import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; -import software.amazon.awssdk.services.bedrockruntime.model.ImageBlock; -import software.amazon.awssdk.services.bedrockruntime.model.ImageFormat; -import software.amazon.awssdk.services.bedrockruntime.model.ImageSource; -import software.amazon.awssdk.services.bedrockruntime.model.Message; /** * Executes LLM spec tests in-process using the Braintrust Java SDK instrumentation. @@ -137,8 +115,7 @@ public String execute(LlmSpanSpec spec) throws Exception { List responsesHistory = new ArrayList<>(); for (Map request : spec.requests()) { - dispatchRequest( - spec.provider(), spec.endpoint(), spec.client(), request, responsesHistory); + dispatchRequest(spec, request, responsesHistory); } } finally { rootSpan.end(); @@ -147,12 +124,11 @@ public String execute(LlmSpanSpec spec) throws Exception { } private void dispatchRequest( - String provider, - String endpoint, - String client, - Map request, - List responsesHistory) + LlmSpanSpec spec, Map request, List responsesHistory) throws Exception { + String provider = spec.provider(); + String endpoint = spec.endpoint(); + String client = spec.client(); if ("openai".equals(provider) && "/v1/chat/completions".equals(endpoint)) { if ("langchain-openai".equals(client)) { executeLangChainChatCompletion(request); @@ -165,9 +141,9 @@ private void dispatchRequest( executeResponses(request, responsesHistory); } else if ("anthropic".equals(provider) && "/v1/messages".equals(endpoint)) { if ("springai-anthropic".equals(client)) { - executeSpringAiAnthropicMessages(request); + executeSpringAiAnthropicMessages(spec, request); } else { - executeAnthropicMessages(request); + executeAnthropicMessages(spec, request); } } else if ("bedrock".equals(provider) && endpoint.contains("/converse-stream")) { executeBedrockConverseStream(request); @@ -191,51 +167,15 @@ private void dispatchRequest( private void executeChatCompletion(Map request) throws Exception { boolean streaming = Boolean.TRUE.equals(request.get("stream")); - // Serialize the whole request map to JSON, then let the SDK's mapper deserialize each - // field into the correct SDK type — no manual field extraction needed. - String json = MAPPER.writeValueAsString(request); - com.fasterxml.jackson.databind.JsonNode node = ObjectMappers.jsonMapper().readTree(json); + // Ensure "stream" is always present in the body — the OpenAI API expects it + // and VCR cassettes were recorded with it. + Map bodyMap = new java.util.LinkedHashMap<>(request); + bodyMap.putIfAbsent("stream", false); + String json = MAPPER.writeValueAsString(bodyMap); + ChatCompletionCreateParams.Body body = + ObjectMappers.jsonMapper().readValue(json, ChatCompletionCreateParams.Body.class); + var params = ChatCompletionCreateParams.builder().body(body).build(); - var builder = ChatCompletionCreateParams.builder(); - if (node.has("model")) - builder.model(com.openai.models.ChatModel.of(node.get("model").asText())); - if (node.has("messages")) { - List msgs = - ObjectMappers.jsonMapper() - .convertValue( - node.get("messages"), - ObjectMappers.jsonMapper() - .getTypeFactory() - .constructCollectionType( - List.class, - com.openai.models.chat.completions - .ChatCompletionMessageParam.class)); - builder.messages(msgs); - } - if (node.has("tools")) { - List tools = - ObjectMappers.jsonMapper() - .convertValue( - node.get("tools"), - ObjectMappers.jsonMapper() - .getTypeFactory() - .constructCollectionType( - List.class, - com.openai.models.chat.completions - .ChatCompletionTool.class)); - builder.tools(tools); - } - if (node.has("temperature")) builder.temperature(node.get("temperature").asDouble()); - if (node.has("max_tokens")) builder.maxCompletionTokens(node.get("max_tokens").asLong()); - if (node.has("stream_options")) - builder.streamOptions( - ObjectMappers.jsonMapper() - .convertValue( - node.get("stream_options"), - com.openai.models.chat.completions.ChatCompletionStreamOptions - .class)); - - var params = builder.build(); if (streaming) { // Hold a reference to prevent GC-driven PhantomReachable cleanup before the stream // is fully consumed, which would close the SSE stream early. @@ -249,155 +189,144 @@ private void executeChatCompletion(Map request) throws Exception // ---- LangChain4j OpenAI chat/completions ------------------------------------ - @SuppressWarnings("unchecked") - private void executeLangChainChatCompletion(Map request) throws Exception { - var node = MAPPER.valueToTree(request); - boolean streaming = node.has("stream") && node.get("stream").asBoolean(); - - List messages = new ArrayList<>(); - for (Map msg : (List>) request.get("messages")) { - String role = (String) msg.get("role"); - Object rawContent = msg.get("content"); - switch (role) { - case "system" -> - messages.add( - SystemMessage.from( - rawContent != null ? rawContent.toString() : "")); - case "user" -> { - if (rawContent instanceof List) { - messages.add( - new UserMessage( - buildLangChainContents( - (List>) rawContent))); - } else { - messages.add( - UserMessage.from(rawContent != null ? rawContent.toString() : "")); + /** + * Jackson ObjectMapper for deserializing spec JSON into LangChain4j's internal {@link + * dev.langchain4j.model.openai.internal.chat.ChatCompletionRequest}. + * + *

LangChain4j's {@code Message} interface has no {@code @JsonTypeInfo}, so we register a + * custom deserializer that dispatches on the {@code role} field. + */ + private static final ObjectMapper LANGCHAIN_MAPPER = createLangChainMapper(); + + private static ObjectMapper createLangChainMapper() { + var module = new com.fasterxml.jackson.databind.module.SimpleModule(); + module.addDeserializer( + dev.langchain4j.model.openai.internal.chat.Message.class, + new com.fasterxml.jackson.databind.JsonDeserializer< + dev.langchain4j.model.openai.internal.chat.Message>() { + @Override + public dev.langchain4j.model.openai.internal.chat.Message deserialize( + com.fasterxml.jackson.core.JsonParser p, + com.fasterxml.jackson.databind.DeserializationContext ctx) + throws java.io.IOException { + com.fasterxml.jackson.databind.JsonNode node = p.getCodec().readTree(p); + String role = node.has("role") ? node.get("role").asText() : ""; + com.fasterxml.jackson.databind.ObjectMapper codec = + (com.fasterxml.jackson.databind.ObjectMapper) p.getCodec(); + return switch (role) { + case "system" -> + codec.treeToValue( + node, + dev.langchain4j.model.openai.internal.chat.SystemMessage + .class); + case "user" -> deserializeUserMessage(codec, node); + case "assistant" -> + codec.treeToValue( + node, + dev.langchain4j.model.openai.internal.chat + .AssistantMessage.class); + case "tool" -> + codec.treeToValue( + node, + dev.langchain4j.model.openai.internal.chat.ToolMessage + .class); + default -> + throw new java.io.IOException( + "Unsupported langchain message role: " + role); + }; } - } - default -> - throw new UnsupportedOperationException( - "langchain-openai: unsupported role: " + role); + }); + return new ObjectMapper() + .disable( + com.fasterxml.jackson.databind.DeserializationFeature + .FAIL_ON_IGNORED_PROPERTIES) + .disable( + com.fasterxml.jackson.databind.DeserializationFeature + .FAIL_ON_UNKNOWN_PROPERTIES) + .registerModule(module); + } + + /** + * Deserialize a LangChain4j UserMessage from a JSON node, handling the polymorphic {@code + * content} field (string vs array of Content blocks) that the Builder can't dispatch + * automatically. + */ + private static dev.langchain4j.model.openai.internal.chat.UserMessage deserializeUserMessage( + ObjectMapper mapper, com.fasterxml.jackson.databind.JsonNode node) + throws com.fasterxml.jackson.core.JsonProcessingException { + var builder = dev.langchain4j.model.openai.internal.chat.UserMessage.builder(); + if (node.has("content")) { + var content = node.get("content"); + if (content.isTextual()) { + builder.content(content.asText()); + } else if (content.isArray()) { + List list = + mapper.convertValue( + content, + mapper.getTypeFactory() + .constructCollectionType( + List.class, + dev.langchain4j.model.openai.internal.chat.Content + .class)); + builder.content(list); } } + if (node.has("name")) { + builder.name(node.get("name").asText()); + } + return builder.build(); + } + + private void executeLangChainChatCompletion(Map request) throws Exception { + boolean streaming = Boolean.TRUE.equals(request.get("stream")); + // Build a model just to get an instrumented client via BraintrustLangchain.wrap(). + dev.langchain4j.model.openai.internal.OpenAiClient langchainClient; if (streaming) { var modelBuilder = OpenAiStreamingChatModel.builder().baseUrl(openAiBaseUrl).apiKey(openAiApiKey); - if (node.has("model")) modelBuilder.modelName(node.get("model").asText()); - if (node.has("temperature")) - modelBuilder.temperature(node.get("temperature").asDouble()); - if (node.has("max_tokens")) modelBuilder.maxTokens(node.get("max_tokens").asInt()); var model = BraintrustLangchain.wrap(otel, modelBuilder); - var done = new CompletableFuture(); - model.chat( - messages, - new StreamingChatResponseHandler() { - @Override - public void onPartialResponse(String s) {} - - @Override - public void onCompleteResponse( - dev.langchain4j.model.chat.response.ChatResponse r) { - done.complete(null); - } - - @Override - public void onError(Throwable t) { - done.completeExceptionally(t); - } - }); - done.get(); + langchainClient = getPrivateField(model, "client"); } else { var modelBuilder = OpenAiChatModel.builder().baseUrl(openAiBaseUrl).apiKey(openAiApiKey); - if (node.has("model")) modelBuilder.modelName(node.get("model").asText()); - if (node.has("temperature")) - modelBuilder.temperature(node.get("temperature").asDouble()); - if (node.has("max_tokens")) modelBuilder.maxTokens(node.get("max_tokens").asInt()); - ChatModel model = BraintrustLangchain.wrap(otel, modelBuilder); - var reqBuilder = ChatRequest.builder().messages(messages); - if (node.has("tools")) { - reqBuilder.toolSpecifications(buildLangChainToolSpecs(node.get("tools"))); - } - model.chat(reqBuilder.build()); + OpenAiChatModel model = BraintrustLangchain.wrap(otel, modelBuilder); + langchainClient = getPrivateField(model, "client"); } - } - /** Build LangChain4j {@link Content} list from a multi-part YAML content array. */ - @SuppressWarnings("unchecked") - private static List buildLangChainContents(List> parts) { - List contents = new ArrayList<>(); - for (Map part : parts) { - String type = (String) part.get("type"); - if ("text".equals(type)) { - contents.add(new TextContent((String) part.get("text"))); - } else if ("image_url".equals(type)) { - Map imageUrl = (Map) part.get("image_url"); - String url = (String) imageUrl.get("url"); - if (url != null && url.startsWith("data:")) { - // data:;base64, - int semi = url.indexOf(';'); - int comma = url.indexOf(','); - String mimeType = semi > 0 ? url.substring(5, semi) : "image/png"; - String base64 = comma > 0 ? url.substring(comma + 1) : ""; - contents.add(new ImageContent(base64, mimeType)); - } else { - contents.add(new ImageContent(url)); - } - } + // Deserialize the spec JSON directly into LangChain4j's ChatCompletionRequest. + // The LANGCHAIN_MAPPER has custom deserializers for Message (role-based dispatch) + // and UserMessage (polymorphic string/array content handling). + String json = MAPPER.writeValueAsString(request); + var chatRequest = + LANGCHAIN_MAPPER.readValue( + json, + dev.langchain4j.model.openai.internal.chat.ChatCompletionRequest.class); + + if (streaming) { + var done = new CompletableFuture(); + langchainClient + .chatCompletion(chatRequest) + .onPartialResponse(response -> {}) + .onComplete(() -> done.complete(null)) + .onError(done::completeExceptionally) + .execute(); + done.get(); + } else { + langchainClient.chatCompletion(chatRequest).execute(); } - return contents; } - /** Build LangChain4j {@link ToolSpecification}s from the YAML {@code tools} array. */ - private static List buildLangChainToolSpecs( - com.fasterxml.jackson.databind.JsonNode toolsNode) { - List specs = new ArrayList<>(); - for (com.fasterxml.jackson.databind.JsonNode toolNode : toolsNode) { - com.fasterxml.jackson.databind.JsonNode fn = toolNode.get("function"); - if (fn == null) continue; - var schemaBuilder = JsonObjectSchema.builder(); - com.fasterxml.jackson.databind.JsonNode params = fn.get("parameters"); - if (params != null && params.has("properties")) { - List required = new ArrayList<>(); - if (params.has("required")) { - params.get("required").forEach(r -> required.add(r.asText())); - } - params.get("properties") - .fields() - .forEachRemaining( - entry -> { - var prop = entry.getValue(); - String name = entry.getKey(); - String desc = - prop.has("description") - ? prop.get("description").asText() - : null; - if (prop.has("enum")) { - List vals = new ArrayList<>(); - prop.get("enum").forEach(e -> vals.add(e.asText())); - schemaBuilder.addEnumProperty(name, vals); - } else { - schemaBuilder.addStringProperty(name, desc); - } - }); - schemaBuilder.required(required); - } - specs.add( - ToolSpecification.builder() - .name(fn.get("name").asText()) - .description( - fn.has("description") ? fn.get("description").asText() : null) - .parameters(schemaBuilder.build()) - .build()); - } - return specs; + @SuppressWarnings("unchecked") + private static T getPrivateField(Object obj, String fieldName) throws Exception { + var field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return (T) field.get(obj); } // ---- Spring AI OpenAI chat/completions -------------------------------------- private void executeSpringAiOpenAiChatCompletion(Map request) throws Exception { - var node = MAPPER.valueToTree(request); // Pass the full base URL (including /v1) and override completionsPath so Spring AI // appends just "/chat/completions" rather than the default "/v1/chat/completions". var api = @@ -406,162 +335,107 @@ private void executeSpringAiOpenAiChatCompletion(Map request) th .completionsPath("/chat/completions") .apiKey(openAiApiKey) .build(); - var optionsBuilder = OpenAiChatOptions.builder(); - if (node.has("model")) optionsBuilder.model(node.get("model").asText()); - if (node.has("temperature")) optionsBuilder.temperature(node.get("temperature").asDouble()); - if (node.has("max_tokens")) optionsBuilder.maxTokens(node.get("max_tokens").asInt()); - if (node.has("stream") && node.get("stream").asBoolean()) { - optionsBuilder.streamUsage(true); - } - if (node.has("tools")) { - optionsBuilder.toolCallbacks(buildSpringAiToolCallbacks(node.get("tools"))); - // Disable internal execution so tool_calls surface in the response output - optionsBuilder.internalToolExecutionEnabled(false); - } + + // We need to wrap the api's HTTP clients for instrumentation. The easiest way + // is to go through OpenAiChatModel.builder() + BraintrustSpringAI.wrap(), + // which instruments the RestClient/WebClient inside the api object in-place. var modelBuilder = org.springframework.ai.openai.OpenAiChatModel.builder() .openAiApi(api) - .defaultOptions(optionsBuilder.build()); + .defaultOptions(OpenAiChatOptions.builder().build()); BraintrustSpringAI.wrap(otel, modelBuilder); - var model = modelBuilder.build(); - var prompt = buildSpringAiPrompt(request); - if (node.has("stream") && node.get("stream").asBoolean()) { - model.stream(prompt).blockLast(); - } else { - model.call(prompt); - } - } - /** Build Spring AI {@link ToolCallback}s from the YAML {@code tools} array. */ - private static List buildSpringAiToolCallbacks( - com.fasterxml.jackson.databind.JsonNode toolsNode) { - List callbacks = new ArrayList<>(); - for (com.fasterxml.jackson.databind.JsonNode toolNode : toolsNode) { - com.fasterxml.jackson.databind.JsonNode fn = toolNode.get("function"); - if (fn == null) continue; - String paramsJson = fn.has("parameters") ? fn.get("parameters").toString() : "{}"; - callbacks.add( - FunctionToolCallback.builder( - fn.get("name").asText(), (String input) -> "not implemented") - .description( - fn.has("description") ? fn.get("description").asText() : "") - .inputSchema(paramsJson) - .inputType(String.class) - .build()); + // Deserialize the spec JSON directly into Spring AI's ChatCompletionRequest. + // Default "stream" to false since Spring AI's OpenAiApi unboxes it. + var node = MAPPER.valueToTree(request); + if (!node.has("stream")) { + ((com.fasterxml.jackson.databind.node.ObjectNode) node).put("stream", false); + } + boolean stream = node.get("stream").asBoolean(); + // Add stream_options for streaming so usage stats are returned. + if (stream && !node.has("stream_options")) { + var streamOpts = MAPPER.createObjectNode(); + streamOpts.put("include_usage", true); + ((com.fasterxml.jackson.databind.node.ObjectNode) node) + .set("stream_options", streamOpts); + } + var chatRequest = MAPPER.treeToValue(node, OpenAiApi.ChatCompletionRequest.class); + if (stream) { + api.chatCompletionStream(chatRequest).blockLast(); + } else { + api.chatCompletionEntity(chatRequest); } - return callbacks; } // ---- Spring AI Anthropic messages ------------------------------------------- - private void executeSpringAiAnthropicMessages(Map request) throws Exception { - var node = MAPPER.valueToTree(request); - var api = AnthropicApi.builder().baseUrl(anthropicBaseUrl).apiKey(anthropicApiKey).build(); - var optionsBuilder = AnthropicChatOptions.builder(); - if (node.has("model")) optionsBuilder.model(node.get("model").asText()); - if (node.has("temperature")) optionsBuilder.temperature(node.get("temperature").asDouble()); - if (node.has("max_tokens")) optionsBuilder.maxTokens(node.get("max_tokens").asInt()); + private void executeSpringAiAnthropicMessages(LlmSpanSpec spec, Map request) + throws Exception { + var apiBuilder = AnthropicApi.builder().baseUrl(anthropicBaseUrl).apiKey(anthropicApiKey); + if (spec.headers() != null && spec.headers().containsKey("anthropic-beta")) { + apiBuilder.anthropicBetaFeatures(spec.headers().get("anthropic-beta")); + } + var api = apiBuilder.build(); + + // We need to wrap the api's HTTP clients for instrumentation. The easiest way + // is to go through AnthropicChatModel.builder() + BraintrustSpringAI.wrap(), + // which instruments the RestClient/WebClient inside the api object in-place. var modelBuilder = AnthropicChatModel.builder() .anthropicApi(api) - .defaultOptions(optionsBuilder.build()); + .defaultOptions(AnthropicChatOptions.builder().build()); BraintrustSpringAI.wrap(otel, modelBuilder); - var model = modelBuilder.build(); - var prompt = buildSpringAiPrompt(request); - if (node.has("stream") && node.get("stream").asBoolean()) { - model.stream(prompt).blockLast(); - } else { - model.call(prompt); - } - } - /** - * Build a Spring AI {@link Prompt} from the YAML request's {@code messages} list. - * - *

Also handles top-level {@code system:} fields (used by Anthropic-style YAML) by prepending - * a {@link org.springframework.ai.chat.messages.SystemMessage}. - */ - @SuppressWarnings("unchecked") - private static Prompt buildSpringAiPrompt(Map request) throws Exception { - List messages = new ArrayList<>(); - for (Map msg : (List>) request.get("messages")) { - String role = (String) msg.get("role"); - Object rawContent = msg.get("content"); - if ("user".equals(role) && rawContent instanceof List) { - messages.add(buildSpringAiUserMessage((List>) rawContent)); - } else { - String content = rawContent != null ? rawContent.toString() : ""; - messages.add( - switch (role) { - case "system" -> - new org.springframework.ai.chat.messages.SystemMessage(content); - case "user" -> - new org.springframework.ai.chat.messages.UserMessage(content); - case "assistant" -> new AssistantMessage(content); - default -> - throw new UnsupportedOperationException( - "unsupported role: " + role); - }); - } + // Normalize the spec JSON so it deserializes into Spring AI's + // ChatCompletionRequest: message "content" strings must become + // [{type:"text", text:"..."}] lists since AnthropicMessage expects + // List, and "stream" must be explicitly present since + // AnthropicApi unboxes the Boolean without a null check. + var node = MAPPER.valueToTree(request); + normalizeAnthropicMessages(node); + if (!node.has("stream")) { + ((com.fasterxml.jackson.databind.node.ObjectNode) node).put("stream", false); } - // Append a system message for top-level "system" field (Anthropic-style YAML) - if (request.containsKey("system")) { - messages.add( - new org.springframework.ai.chat.messages.SystemMessage( - request.get("system").toString())); + + boolean stream = node.get("stream").asBoolean(); + var chatRequest = MAPPER.treeToValue(node, AnthropicApi.ChatCompletionRequest.class); + if (stream) { + api.chatCompletionStream(chatRequest).blockLast(); + } else { + api.chatCompletionEntity(chatRequest); } - return new Prompt(messages); } /** - * Build a Spring AI {@link org.springframework.ai.chat.messages.UserMessage} with text and - * optional media parts. + * Normalize Anthropic message content for Spring AI deserialization. The Anthropic API accepts + * both {@code "content": "text"} and {@code "content": [{...}]}, but Spring AI's {@link + * AnthropicApi.AnthropicMessage} only models the list form. This converts any string content + * into {@code [{type:"text", text:"..."}]}. */ - @SuppressWarnings("unchecked") - private static org.springframework.ai.chat.messages.UserMessage buildSpringAiUserMessage( - List> parts) throws Exception { - String text = ""; - List mediaList = new ArrayList<>(); - for (Map part : parts) { - String type = (String) part.get("type"); - if ("text".equals(type)) { - text = (String) part.getOrDefault("text", ""); - } else if ("image_url".equals(type)) { - // OpenAI format: {type: image_url, image_url: {url: data:mime;base64,...}} - Map imageUrl = (Map) part.get("image_url"); - String url = (String) imageUrl.get("url"); - if (url != null && url.startsWith("data:")) { - int semi = url.indexOf(';'), comma = url.indexOf(','); - String mimeType = semi > 0 ? url.substring(5, semi) : "image/png"; - byte[] bytes = java.util.Base64.getDecoder().decode(url.substring(comma + 1)); - mediaList.add( - new org.springframework.ai.content.Media( - org.springframework.util.MimeTypeUtils.parseMimeType(mimeType), - new org.springframework.core.io.ByteArrayResource(bytes))); - } - } else if ("image".equals(type)) { - // Anthropic format: {type: image, source: {type: base64, media_type, data}} - Map source = (Map) part.get("source"); - if ("base64".equals(source.get("type"))) { - String mimeType = (String) source.getOrDefault("media_type", "image/png"); - byte[] bytes = - java.util.Base64.getDecoder().decode((String) source.get("data")); - mediaList.add( - new org.springframework.ai.content.Media( - org.springframework.util.MimeTypeUtils.parseMimeType(mimeType), - new org.springframework.core.io.ByteArrayResource(bytes))); - } + private static void normalizeAnthropicMessages(com.fasterxml.jackson.databind.JsonNode root) { + var messages = root.get("messages"); + if (messages == null || !messages.isArray()) return; + for (var msg : messages) { + var content = msg.get("content"); + if (content != null && content.isTextual()) { + var arr = MAPPER.createArrayNode(); + var block = MAPPER.createObjectNode(); + block.put("type", "text"); + block.put("text", content.asText()); + arr.add(block); + ((com.fasterxml.jackson.databind.node.ObjectNode) msg).set("content", arr); } } - var builder = org.springframework.ai.chat.messages.UserMessage.builder().text(text); - if (!mediaList.isEmpty()) builder.media(mediaList); - return builder.build(); } // ---- OpenAI responses ------------------------------------------------------- private void executeResponses(Map request, List history) throws Exception { + // The responses API has multi-turn history: each turn's input items are + // prepended with outputs from prior turns. We deserialize the "input" field + // separately to accumulate history, then deserialize the rest of the body + // generically. String json = MAPPER.writeValueAsString(request); com.fasterxml.jackson.databind.JsonNode node = ObjectMappers.jsonMapper().readTree(json); @@ -579,15 +453,12 @@ private void executeResponses(Map request, List fullInput = new ArrayList<>(history); fullInput.addAll(thisInput); - var builder = ResponseCreateParams.builder().inputOfResponse(fullInput); - if (node.has("model")) builder.model(node.get("model").asText()); - if (node.has("reasoning")) - builder.reasoning( - ObjectMappers.jsonMapper() - .convertValue( - node.get("reasoning"), com.openai.models.Reasoning.class)); + // Deserialize the full body, then override input with the accumulated history. + ResponseCreateParams.Body body = + ObjectMappers.jsonMapper().readValue(json, ResponseCreateParams.Body.class); + var params = ResponseCreateParams.builder().body(body).inputOfResponse(fullInput).build(); - Response response = openAIClient.responses().create(builder.build()); + Response response = openAIClient.responses().create(params); // Accumulate this turn's input + output into history for the next turn history.addAll(thisInput); @@ -599,34 +470,28 @@ private void executeResponses(Map request, List request) throws Exception { - String json = MAPPER.writeValueAsString(request); - com.fasterxml.jackson.databind.JsonNode node = - com.anthropic.core.ObjectMappers.jsonMapper().readTree(json); - - var builder = MessageCreateParams.builder(); - if (node.has("model")) builder.model(node.get("model").asText()); - if (node.has("max_tokens")) builder.maxTokens(node.get("max_tokens").asLong()); - if (node.has("temperature")) builder.temperature(node.get("temperature").asDouble()); - if (node.has("system")) builder.system(node.get("system").asText()); - if (node.has("messages")) { - List msgs = - com.anthropic.core.ObjectMappers.jsonMapper() - .convertValue( - node.get("messages"), - com.anthropic.core.ObjectMappers.jsonMapper() - .getTypeFactory() - .constructCollectionType( - List.class, - com.anthropic.models.messages.MessageParam - .class)); - builder.messages(msgs); + private void executeAnthropicMessages(LlmSpanSpec spec, Map request) + throws Exception { + // Strip the "stream" key before deserializing — it's not part of + // MessageCreateParams.Body; we handle it ourselves. + boolean stream = Boolean.TRUE.equals(request.get("stream")); + Map bodyMap = new java.util.LinkedHashMap<>(request); + bodyMap.remove("stream"); + + String json = MAPPER.writeValueAsString(bodyMap); + MessageCreateParams.Body body = + com.anthropic.core.ObjectMappers.jsonMapper() + .readValue(json, MessageCreateParams.Body.class); + + var builder = MessageCreateParams.builder().body(body); + if (spec.headers() != null) { + spec.headers().forEach(builder::putAdditionalHeader); } - var params = builder.build(); - if (node.has("stream") && node.get("stream").asBoolean()) { - try (var stream = anthropicClient.messages().createStreaming(params)) { - stream.stream().forEach(event -> {}); + + if (stream) { + try (var s = anthropicClient.messages().createStreaming(params)) { + s.stream().forEach(event -> {}); } } else { anthropicClient.messages().create(params); @@ -635,63 +500,98 @@ private void executeAnthropicMessages(Map request) throws Except // ---- AWS Bedrock ------------------------------------------------------------ - @SuppressWarnings("unchecked") - private void executeBedrockConverse(Map request) { - String modelId = (String) request.get("modelId"); - - // Build messages from the spec YAML format: [{role, content: [{text: ...} | {image: ...}]}] - List messages = new ArrayList<>(); - for (Map msg : (List>) request.get("messages")) { - String role = (String) msg.get("role"); - List contentBlocks = new ArrayList<>(); - for (Map part : (List>) msg.get("content")) { - if (part.containsKey("text")) { - contentBlocks.add(ContentBlock.fromText((String) part.get("text"))); - } else if (part.containsKey("image")) { - contentBlocks.add( - buildBedrockImageBlock((Map) part.get("image"))); - } - } - messages.add( - Message.builder() - .role(ConversationRole.fromValue(role)) - .content(contentBlocks) - .build()); + /** + * Unmarshaller that uses the AWS SDK's internal {@link + * software.amazon.awssdk.protocols.json.internal.unmarshall.JsonProtocolUnmarshaller} (via + * reflection) to deserialize JSON into SDK model objects (SdkPojo). This is the same machinery + * the SDK uses to parse API responses. + */ + private static final Object BEDROCK_UNMARSHALLER; + + private static final software.amazon.awssdk.protocols.jsoncore.JsonNodeParser + BEDROCK_JSON_PARSER = software.amazon.awssdk.protocols.jsoncore.JsonNodeParser.create(); + + static { + try { + // JsonProtocolUnmarshaller is @SdkInternalApi, so we construct it reflectively. + Class unmarshallerClass = + Class.forName( + "software.amazon.awssdk.protocols.json.internal.unmarshall.JsonProtocolUnmarshaller"); + var builderMethod = unmarshallerClass.getMethod("builder"); + var builderObj = builderMethod.invoke(null); + var builderClass = builderObj.getClass(); + + // Set the parser + builderClass + .getMethod( + "parser", + software.amazon.awssdk.protocols.jsoncore.JsonNodeParser.class) + .invoke(builderObj, BEDROCK_JSON_PARSER); + + // Use default protocol unmarshall dependencies + var depsMethod = unmarshallerClass.getMethod("defaultProtocolUnmarshallDependencies"); + var deps = depsMethod.invoke(null); + builderClass + .getMethod( + "protocolUnmarshallDependencies", + Class.forName( + "software.amazon.awssdk.protocols.json.internal.unmarshall.ProtocolUnmarshallDependencies")) + .invoke(builderObj, deps); + + BEDROCK_UNMARSHALLER = builderClass.getMethod("build").invoke(builderObj); + } catch (Exception e) { + throw new RuntimeException("Failed to create Bedrock JSON unmarshaller", e); } + } + + /** + * Deserialize a JSON string into an AWS SDK model object using the SDK's internal unmarshaller. + * The object must implement {@link software.amazon.awssdk.core.SdkPojo}. + */ + @SuppressWarnings("unchecked") + private static T bedrockFromJson( + String json, software.amazon.awssdk.core.SdkPojo builderInstance) throws Exception { + software.amazon.awssdk.protocols.jsoncore.JsonNode jsonNode = + BEDROCK_JSON_PARSER.parse( + new java.io.ByteArrayInputStream( + json.getBytes(java.nio.charset.StandardCharsets.UTF_8))); + + // Build a minimal SdkHttpFullResponse — the unmarshaller only uses it for + // explicit payload members (SdkBytes/String), which normal Converse fields don't have. + var response = + software.amazon.awssdk.http.SdkHttpFullResponse.builder().statusCode(200).build(); + + // Call unmarshall(SdkPojo, SdkHttpFullResponse, JsonNode) reflectively. + var method = + BEDROCK_UNMARSHALLER + .getClass() + .getMethod( + "unmarshall", + software.amazon.awssdk.core.SdkPojo.class, + software.amazon.awssdk.http.SdkHttpFullResponse.class, + software.amazon.awssdk.protocols.jsoncore.JsonNode.class); + return (T) method.invoke(BEDROCK_UNMARSHALLER, builderInstance, response, jsonNode); + } + + private void executeBedrockConverse(Map request) throws Exception { + String json = MAPPER.writeValueAsString(request); + ConverseRequest converseRequest = bedrockFromJson(json, ConverseRequest.builder()); var builder = BraintrustAWSBedrock.wrap(otel, bedrockUtils.syncClientBuilder()); try (var client = builder.build()) { - client.converse(ConverseRequest.builder().modelId(modelId).messages(messages).build()); + client.converse(converseRequest); } } - @SuppressWarnings("unchecked") private void executeBedrockConverseStream(Map request) throws Exception { - String modelId = (String) request.get("modelId"); - - List messages = new ArrayList<>(); - for (Map msg : (List>) request.get("messages")) { - String role = (String) msg.get("role"); - List contentBlocks = new ArrayList<>(); - for (Map part : (List>) msg.get("content")) { - if (part.containsKey("text")) { - contentBlocks.add(ContentBlock.fromText((String) part.get("text"))); - } - } - messages.add( - Message.builder() - .role(ConversationRole.fromValue(role)) - .content(contentBlocks) - .build()); - } + String json = MAPPER.writeValueAsString(request); + ConverseStreamRequest converseStreamRequest = + bedrockFromJson(json, ConverseStreamRequest.builder()); var asyncBuilder = BraintrustAWSBedrock.wrap(otel, bedrockUtils.asyncClientBuilder()); try (var client = asyncBuilder.build()) { client.converseStream( - ConverseStreamRequest.builder() - .modelId(modelId) - .messages(messages) - .build(), + converseStreamRequest, ConverseStreamResponseHandler.builder() .subscriber( ConverseStreamResponseHandler.Visitor.builder().build()) @@ -700,20 +600,6 @@ private void executeBedrockConverseStream(Map request) throws Ex } } - /** Builds a Bedrock {@link ContentBlock} image from the YAML {@code image:} map. */ - @SuppressWarnings("unchecked") - private static ContentBlock buildBedrockImageBlock(Map imageMap) { - String format = (String) imageMap.getOrDefault("format", "png"); - Map sourceMap = (Map) imageMap.get("source"); - String base64 = (String) sourceMap.get("bytes"); - byte[] imageBytes = java.util.Base64.getDecoder().decode(base64); - return ContentBlock.fromImage( - ImageBlock.builder() - .format(ImageFormat.fromValue(format)) - .source(ImageSource.fromBytes(SdkBytes.fromByteArray(imageBytes))) - .build()); - } - // ---- Google Gemini ---------------------------------------------------------- @SuppressWarnings("unchecked")