Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 63 additions & 26 deletions sentry_sdk/integrations/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from typing import TYPE_CHECKING

import sentry_sdk
from sentry_sdk.ai.utils import get_start_span_function
from sentry_sdk.ai.utils import _set_span_data_attribute, get_start_span_function
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations import Integration, DidNotEnable
from sentry_sdk.traces import StreamedSpan
from sentry_sdk.tracing_utils import has_span_streaming_enabled
from sentry_sdk.utils import safe_serialize
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.integrations._wsgi_common import nullcontext
Expand All @@ -33,8 +35,10 @@


if TYPE_CHECKING:
from typing import Any, Callable, Optional, Tuple, ContextManager
from typing import Any, Callable, Optional, Tuple, Union, ContextManager

from sentry_sdk.tracing import Span
from sentry_sdk.traces import StreamedSpan
from starlette.types import Receive, Scope, Send # type: ignore[import-not-found]


Expand Down Expand Up @@ -156,7 +160,7 @@ def _get_span_config(


def _set_span_input_data(
span: "Any",
span: "Union[StreamedSpan, Span]",
handler_name: str,
span_data_key: str,
mcp_method_name: str,
Expand All @@ -168,26 +172,28 @@ def _set_span_input_data(
"""Set input span data for MCP handlers."""

# Set handler identifier
span.set_data(span_data_key, handler_name)
span.set_data(SPANDATA.MCP_METHOD_NAME, mcp_method_name)
_set_span_data_attribute(span, span_data_key, handler_name)
_set_span_data_attribute(span, SPANDATA.MCP_METHOD_NAME, mcp_method_name)

# Set transport/MCP transport type
span.set_data(
SPANDATA.NETWORK_TRANSPORT, "pipe" if mcp_transport == "stdio" else "tcp"
_set_span_data_attribute(
span,
SPANDATA.NETWORK_TRANSPORT,
"pipe" if mcp_transport == "stdio" else "tcp",
)
span.set_data(SPANDATA.MCP_TRANSPORT, mcp_transport)
_set_span_data_attribute(span, SPANDATA.MCP_TRANSPORT, mcp_transport)

# Set request_id if provided
if request_id:
span.set_data(SPANDATA.MCP_REQUEST_ID, request_id)
_set_span_data_attribute(span, SPANDATA.MCP_REQUEST_ID, request_id)

# Set session_id if provided
if session_id:
span.set_data(SPANDATA.MCP_SESSION_ID, session_id)
_set_span_data_attribute(span, SPANDATA.MCP_SESSION_ID, session_id)

# Set request arguments (excluding common request context objects)
for k, v in arguments.items():
span.set_data(f"mcp.request.argument.{k}", safe_serialize(v))
_set_span_data_attribute(span, f"mcp.request.argument.{k}", safe_serialize(v))


def _extract_tool_result_content(result: "Any") -> "Any":
Expand Down Expand Up @@ -231,7 +237,10 @@ def _extract_tool_result_content(result: "Any") -> "Any":


def _set_span_output_data(
span: "Any", result: "Any", result_data_key: "Optional[str]", handler_type: str
span: "Union[StreamedSpan, Span]",
result: "Any",
result_data_key: "Optional[str]",
handler_type: str,
) -> None:
"""Set output span data for MCP handlers."""
if result is None:
Expand All @@ -248,11 +257,17 @@ def _set_span_output_data(
# For tools, extract the meaningful content
if handler_type == "tool":
extracted = _extract_tool_result_content(result)
if extracted is not None and should_include_data:
span.set_data(result_data_key, safe_serialize(extracted))
if (
extracted is not None
and should_include_data
and result_data_key is not None
):
_set_span_data_attribute(span, result_data_key, safe_serialize(extracted))
# Set content count if result is a dict
if isinstance(extracted, dict):
span.set_data(SPANDATA.MCP_TOOL_RESULT_CONTENT_COUNT, len(extracted))
_set_span_data_attribute(
span, SPANDATA.MCP_TOOL_RESULT_CONTENT_COUNT, len(extracted)
)
elif handler_type == "prompt":
# For prompts, count messages and set role/content only for single-message prompts
try:
Expand All @@ -270,7 +285,9 @@ def _set_span_output_data(

# Always set message count if we found messages
if message_count > 0:
span.set_data(SPANDATA.MCP_PROMPT_RESULT_MESSAGE_COUNT, message_count)
_set_span_data_attribute(
span, SPANDATA.MCP_PROMPT_RESULT_MESSAGE_COUNT, message_count
)

# Only set role and content for single-message prompts if PII is allowed
if message_count == 1 and should_include_data and messages:
Expand All @@ -283,7 +300,9 @@ def _set_span_output_data(
role = first_message["role"]

if role:
span.set_data(SPANDATA.MCP_PROMPT_RESULT_MESSAGE_ROLE, role)
_set_span_data_attribute(
span, SPANDATA.MCP_PROMPT_RESULT_MESSAGE_ROLE, role
)

# Extract content text
content_text = None
Expand All @@ -303,8 +322,8 @@ def _set_span_output_data(
elif isinstance(msg_content, str):
content_text = msg_content

if content_text:
span.set_data(result_data_key, content_text)
if content_text and result_data_key is not None:
_set_span_data_attribute(span, result_data_key, content_text)
except Exception:
# Silently ignore if we can't extract message info
pass
Expand Down Expand Up @@ -434,14 +453,28 @@ async def _handler_wrapper(
# Get request ID, session ID, and transport from context
request_id, session_id, mcp_transport = _get_request_context_data()

span_streaming = has_span_streaming_enabled(sentry_sdk.get_client().options)

# Start span and execute
with isolation_scope_context:
with current_scope_context:
with get_start_span_function()(
op=OP.MCP_SERVER,
name=span_name,
origin=MCPIntegration.origin,
) as span:
span_mgr: "Union[Span, StreamedSpan]"
if span_streaming:
span_mgr = sentry_sdk.traces.start_span(
name=span_name,
attributes={
"sentry.op": OP.MCP_SERVER,
"sentry.origin": MCPIntegration.origin,
},
)
else:
span_mgr = get_start_span_function()(
op=OP.MCP_SERVER,
name=span_name,
origin=MCPIntegration.origin,
)

with span_mgr as span:
# Set input span data
_set_span_input_data(
span,
Expand All @@ -467,7 +500,9 @@ async def _handler_wrapper(
elif handler_name and "://" in handler_name:
protocol = handler_name.split("://")[0]
if protocol:
span.set_data(SPANDATA.MCP_RESOURCE_PROTOCOL, protocol)
_set_span_data_attribute(
span, SPANDATA.MCP_RESOURCE_PROTOCOL, protocol
)

try:
# Execute the async handler
Expand All @@ -481,7 +516,9 @@ async def _handler_wrapper(
except Exception as e:
# Set error flag for tools
if handler_type == "tool":
span.set_data(SPANDATA.MCP_TOOL_RESULT_IS_ERROR, True)
_set_span_data_attribute(
span, SPANDATA.MCP_TOOL_RESULT_IS_ERROR, True
)
sentry_sdk.capture_exception(e)
raise

Expand Down
Loading
Loading