Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.extern.slf4j.Slf4j;

@Slf4j
Expand Down Expand Up @@ -62,7 +63,7 @@ public void close() {
inputJson);

var response = underlying.execute(bufferedRequest, requestOptions);
return new TeeingStreamHttpResponse(response, span);
return new TeeingStreamHttpResponse(response, span, inputJson);
} catch (Exception e) {
InstrumentationSemConv.tagLLMSpanResponse(span, e);
span.end();
Expand Down Expand Up @@ -90,7 +91,9 @@ public void close() {
return underlying
.executeAsync(bufferedRequest, requestOptions)
.thenApply(
response -> (HttpResponse) new TeeingStreamHttpResponse(response, span))
response ->
(HttpResponse)
new TeeingStreamHttpResponse(response, span, inputJson))
.whenComplete(
(response, t) -> {
if (t != null) {
Expand Down Expand Up @@ -170,14 +173,16 @@ private static String readBodyAsString(HttpRequestBody body) {
private static final class TeeingStreamHttpResponse implements HttpResponse {
private final HttpResponse delegate;
private final Span span;
private final @Nullable String requestBody;
private final long spanStartNanos = System.nanoTime();
private final AtomicLong timeToFirstTokenNanos = new AtomicLong();
private final ByteArrayOutputStream teeBuffer = new ByteArrayOutputStream();
private final InputStream teeStream;

TeeingStreamHttpResponse(HttpResponse delegate, Span span) {
TeeingStreamHttpResponse(HttpResponse delegate, Span span, @Nullable String requestBody) {
this.delegate = delegate;
this.span = span;
this.requestBody = requestBody;
this.teeStream =
new TeeInputStream(
delegate.body(), teeBuffer, this::onFirstByte, this::onStreamClosed);
Expand All @@ -193,7 +198,7 @@ private void onStreamClosed() {
synchronized (teeBuffer) {
bytes = teeBuffer.toByteArray();
}
tagSpanFromBuffer(span, bytes, timeToFirstTokenNanos.get());
tagSpanFromBuffer(span, bytes, timeToFirstTokenNanos.get(), requestBody);
} finally {
span.end();
}
Expand Down Expand Up @@ -287,7 +292,8 @@ private void notifyClosed() {
// Span tagging from buffered bytes
// -------------------------------------------------------------------------

private static void tagSpanFromBuffer(Span span, byte[] bytes, Long timeToFirstTokenNanos) {
private static void tagSpanFromBuffer(
Span span, byte[] bytes, Long timeToFirstTokenNanos, @Nullable String requestBody) {
if (bytes.length == 0) return;
try {
String firstLine = firstNonEmptyLine(bytes);
Expand All @@ -297,13 +303,15 @@ private static void tagSpanFromBuffer(Span span, byte[] bytes, Long timeToFirstT
firstLine != null
&& (firstLine.startsWith("data:") || firstLine.startsWith("event:"));
if (isSse) {
tagSpanFromSseBytes(span, bytes, timeToFirstTokenNanos);
tagSpanFromSseBytes(span, bytes, timeToFirstTokenNanos, requestBody);
} else {
// Non-streaming: plain Message JSON — pass it whole, no time_to_first_token
InstrumentationSemConv.tagLLMSpanResponse(
span,
InstrumentationSemConv.PROVIDER_NAME_ANTHROPIC,
new String(bytes, StandardCharsets.UTF_8));
new String(bytes, StandardCharsets.UTF_8),
null,
requestBody);
}
} catch (Exception e) {
log.error("Could not tag span from Anthropic response buffer", e);
Expand Down Expand Up @@ -338,7 +346,7 @@ private static String firstNonEmptyLine(byte[] bytes) {
* assembled {@link com.anthropic.models.messages.Message} for the span.
*/
private static void tagSpanFromSseBytes(
Span span, byte[] sseBytes, Long timeToFirstTokenNanos) {
Span span, byte[] sseBytes, Long timeToFirstTokenNanos, @Nullable String requestBody) {
try {
var mapper = BraintrustJsonMapper.get();
var reader =
Expand All @@ -362,7 +370,8 @@ private static void tagSpanFromSseBytes(
span,
InstrumentationSemConv.PROVIDER_NAME_ANTHROPIC,
assembledMessageJson,
timeToFirstTokenNanos);
timeToFirstTokenNanos,
requestBody);
} catch (Exception e) {
log.error("Could not parse Anthropic SSE buffer to tag streaming span output", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,22 @@ public static void tagLLMSpanResponse(
@Nonnull String providerName,
@Nonnull String responseBody,
@Nullable Long timeToFirstTokenNanoseconds) {
tagLLMSpanResponse(span, providerName, responseBody, timeToFirstTokenNanoseconds, null);
}

@SneakyThrows
public static void tagLLMSpanResponse(
Span span,
@Nonnull String providerName,
@Nonnull String responseBody,
@Nullable Long timeToFirstTokenNanoseconds,
@Nullable String requestBody) {
switch (providerName) {
case PROVIDER_NAME_OPENAI ->
tagOpenAIResponse(span, responseBody, timeToFirstTokenNanoseconds);
case PROVIDER_NAME_ANTHROPIC ->
tagAnthropicResponse(span, responseBody, timeToFirstTokenNanoseconds);
tagAnthropicResponse(
span, responseBody, timeToFirstTokenNanoseconds, requestBody);
case PROVIDER_NAME_BEDROCK ->
tagBedrockResponse(span, responseBody, timeToFirstTokenNanoseconds);
default -> tagOpenAIResponse(span, responseBody, timeToFirstTokenNanoseconds);
Expand Down Expand Up @@ -237,7 +248,10 @@ private static void tagAnthropicRequest(

@SneakyThrows
private static void tagAnthropicResponse(
Span span, String responseBody, @Nullable Long timeToFirstTokenNanoseconds) {
Span span,
String responseBody,
@Nullable Long timeToFirstTokenNanoseconds,
@Nullable String requestBody) {
JsonNode responseJson = BraintrustJsonMapper.get().readTree(responseBody);

// Anthropic response is the full Message object — output it whole
Expand All @@ -258,13 +272,67 @@ private static void tagAnthropicResponse(
"tokens",
usage.get("input_tokens").asLong() + usage.get("output_tokens").asLong());
}

// Prompt caching metrics
if (usage.has("cache_read_input_tokens")) {
metrics.put("prompt_cached_tokens", usage.get("cache_read_input_tokens"));
}
if (usage.has("cache_creation_input_tokens")) {
long cacheCreationTokens = usage.get("cache_creation_input_tokens").asLong();
metrics.put("prompt_cache_creation_tokens", cacheCreationTokens);

// Per-TTL breakdown: inspect the request to find which TTL tiers were used
if (requestBody != null && cacheCreationTokens > 0) {
addPerTtlCacheMetrics(metrics, requestBody, cacheCreationTokens);
}
}
}

if (!metrics.isEmpty()) {
span.setAttribute("braintrust.metrics", toJson(metrics));
}
}

/**
 * Attribute Anthropic {@code cache_creation_input_tokens} to per-TTL metrics by scanning the
 * request body for {@code cache_control} breakpoints. In the common case every breakpoint uses
 * one TTL tier, and the whole creation-token total is credited to that tier. If the request mixes
 * TTLs, the API gives no per-breakpoint split, so the total is credited to every tier seen.
 */
@SneakyThrows
private static void addPerTtlCacheMetrics(
        Map<String, Object> metrics, String requestBody, long cacheCreationTokens) {
    JsonNode parsedRequest = BraintrustJsonMapper.get().readTree(requestBody);

    // LinkedHashSet keeps tiers in the order they appear in the request.
    java.util.Set<String> ttlTiers = new java.util.LinkedHashSet<>();
    collectCacheControlTtls(parsedRequest, ttlTiers);

    // cache_control blocks without an explicit ttl default to Anthropic's 5m tier.
    if (ttlTiers.isEmpty()) {
        ttlTiers.add("5m");
    }

    for (String tier : ttlTiers) {
        metrics.put("prompt_cache_creation_" + tier + "_tokens", cacheCreationTokens);
    }
}

/** Recursively collect all distinct {@code cache_control.ttl} values from the request JSON. */
private static void collectCacheControlTtls(JsonNode node, java.util.Set<String> ttls) {
    if (node == null) {
        return;
    }
    if (node.isArray()) {
        // Arrays (e.g. messages, content blocks) — descend into each element.
        for (JsonNode element : node) {
            collectCacheControlTtls(element, ttls);
        }
        return;
    }
    if (!node.isObject()) {
        return; // Scalars carry no cache_control.
    }
    // Record an explicit ttl on this object's cache_control block, if present.
    JsonNode cacheControl = node.get("cache_control");
    if (cacheControl != null && cacheControl.has("ttl")) {
        ttls.add(cacheControl.get("ttl").asText());
    }
    // Then recurse into every field value, cache_control included (harmless — it has no nesting).
    for (var it = node.fields(); it.hasNext(); ) {
        collectCacheControlTtls(it.next().getValue(), ttls);
    }
}

// -------------------------------------------------------------------------
// AWS Bedrock provider implementation
// -------------------------------------------------------------------------
Expand Down
20 changes: 19 additions & 1 deletion btx/src/test/java/dev/braintrust/sdkspecimpl/LlmSpanSpec.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public record LlmSpanSpec(
String provider,
String endpoint,
String client,
Map<String, String> headers,
List<Map<String, Object>> requests,
List<Map<String, Object>> expectedBrainstoreSpans,
String sourcePath) {
Expand Down Expand Up @@ -59,11 +60,28 @@ static LlmSpanSpec fromMap(Map<String, Object> raw, String sourcePath, String cl
String provider = (String) raw.get("provider");
String endpoint = (String) raw.get("endpoint");

Map<String, String> headers = null;
if (raw.containsKey("headers")) {
Map<String, Object> rawHeaders = (Map<String, Object>) raw.get("headers");
headers = new java.util.LinkedHashMap<>();
for (var entry : rawHeaders.entrySet()) {
headers.put(entry.getKey(), String.valueOf(entry.getValue()));
}
}

List<Map<String, Object>> requests = (List<Map<String, Object>>) raw.get("requests");
List<Map<String, Object>> expectedSpans =
(List<Map<String, Object>>) raw.get("expected_brainstore_spans");

return new LlmSpanSpec(
name, type, provider, endpoint, client, requests, expectedSpans, sourcePath);
name,
type,
provider,
endpoint,
client,
headers,
requests,
expectedSpans,
sourcePath);
}
}
18 changes: 18 additions & 0 deletions btx/src/test/java/dev/braintrust/sdkspecimpl/SpanValidator.java
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,24 @@ private static void assertFnMatcher(Object actual, SpecMatcher.FnMatcher fn, Str
context, v));
}
}
case "is_positive_number" -> {
if (!(actual instanceof Number)) {
fail(
String.format(
"%s: is_positive_number: expected a Number but got %s"
+ " (value: %s)",
context,
actual == null ? "null" : actual.getClass().getSimpleName(),
actual));
}
double v = ((Number) actual).doubleValue();
if (v <= 0) {
fail(
String.format(
"%s: is_positive_number: value %s is not positive",
context, v));
}
}
case "is_non_empty_string" -> {
if (!(actual instanceof String) || ((String) actual).isEmpty()) {
fail(
Expand Down
Loading
Loading