diff --git a/sentry_sdk/integrations/starlette.py b/sentry_sdk/integrations/starlette.py index dac9887e2f..036b797685 100644 --- a/sentry_sdk/integrations/starlette.py +++ b/sentry_sdk/integrations/starlette.py @@ -1,5 +1,6 @@ import asyncio import functools +import json import warnings from collections.abc import Set from copy import deepcopy @@ -20,10 +21,12 @@ ) from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from sentry_sdk.scope import should_send_default_pii +from sentry_sdk.traces import NoOpStreamedSpan, StreamedSpan from sentry_sdk.tracing import ( SOURCE_FOR_STYLE, TransactionSource, ) +from sentry_sdk.tracing_utils import has_span_streaming_enabled from sentry_sdk.utils import ( AnnotatedValue, capture_internal_exceptions, @@ -147,7 +150,8 @@ async def _create_span_call( send: "Callable[[Dict[str, Any]], Awaitable[None]]", **kwargs: "Any", ) -> None: - integration = sentry_sdk.get_client().get_integration(StarletteIntegration) + client = sentry_sdk.get_client() + integration = client.get_integration(StarletteIntegration) if integration is None: return await old_call(app, scope, receive, send, **kwargs) @@ -164,22 +168,38 @@ async def _create_span_call( return await old_call(app, scope, receive, send, **kwargs) middleware_name = app.__class__.__name__ + is_span_streaming_enabled = has_span_streaming_enabled(client.options) + + def _start_middleware_span(op: str, name: str) -> "Any": + if is_span_streaming_enabled: + return sentry_sdk.traces.start_span( + name=name, + attributes={ + "sentry.op": op, + "sentry.origin": StarletteIntegration.origin, + "middleware.name": middleware_name, + }, + ) + return sentry_sdk.start_span( + op=op, + name=name, + origin=StarletteIntegration.origin, + ) - with sentry_sdk.start_span( - op=OP.MIDDLEWARE_STARLETTE, - name=middleware_name, - origin=StarletteIntegration.origin, + with _start_middleware_span( + op=OP.MIDDLEWARE_STARLETTE, name=middleware_name ) as middleware_span: - middleware_span.set_tag("starlette.middleware_name", middleware_name) + if not is_span_streaming_enabled: + middleware_span.set_tag("starlette.middleware_name", middleware_name) # Creating spans for the "receive" callback async def _sentry_receive(*args: "Any", **kwargs: "Any") -> "Any": - with sentry_sdk.start_span( + with _start_middleware_span( op=OP.MIDDLEWARE_STARLETTE_RECEIVE, name=getattr(receive, "__qualname__", str(receive)), - origin=StarletteIntegration.origin, ) as span: - span.set_tag("starlette.middleware_name", middleware_name) + if not is_span_streaming_enabled: + span.set_tag("starlette.middleware_name", middleware_name) return await receive(*args, **kwargs) receive_name = getattr(receive, "__name__", str(receive)) @@ -188,12 +208,12 @@ async def _sentry_receive(*args: "Any", **kwargs: "Any") -> "Any": # Creating spans for the "send" callback async def _sentry_send(*args: "Any", **kwargs: "Any") -> "Any": - with sentry_sdk.start_span( + with _start_middleware_span( op=OP.MIDDLEWARE_STARLETTE_SEND, name=getattr(send, "__qualname__", str(send)), - origin=StarletteIntegration.origin, ) as span: - span.set_tag("starlette.middleware_name", middleware_name) + if not is_span_streaming_enabled: + span.set_tag("starlette.middleware_name", middleware_name) return await send(*args, **kwargs) send_name = getattr(send, "__name__", str(send)) @@ -214,6 +234,16 @@ async def _sentry_send(*args: "Any", **kwargs: "Any") -> "Any": return middleware_class +def _serialize_body_data(data: "Any") -> str: + # data may be a JSON-serializable value, an AnnotatedValue, or a dict with AnnotatedValue values + def _default(value: "Any") -> "Any": + if isinstance(value, AnnotatedValue): + return {"value": value.value, "metadata": value.metadata} + return str(value) + + return json.dumps(data, default=_default) + + @ensure_integration_enabled(StarletteIntegration) def _capture_exception(exception: BaseException, handled: "Any" = False) -> None: event, hint = event_from_exception( @@ -439,9 +469,8 @@ def _sentry_request_response(func: "Callable[[Any], Any]") -> "ASGIApp": if is_coroutine: async def _sentry_async_func(*args: "Any", **kwargs: "Any") -> "Any": - integration = sentry_sdk.get_client().get_integration( - StarletteIntegration - ) + client = sentry_sdk.get_client() + integration = client.get_integration(StarletteIntegration) if integration is None: return await old_func(*args, **kwargs) @@ -481,6 +510,22 @@ def event_processor( _make_request_event_processor(request, integration) ) + is_span_streaming_enabled = has_span_streaming_enabled(client.options) + if is_span_streaming_enabled: + current_span = sentry_sdk.get_current_span() + + if ( + info + and "data" in info + and isinstance(current_span, StreamedSpan) + and not isinstance(current_span, NoOpStreamedSpan) + ): + data = info["data"] + current_span._segment.set_attribute( + "http.request.body.data", + _serialize_body_data(data), + ) + return await old_func(*args, **kwargs) func = _sentry_async_func @@ -496,7 +541,13 @@ def _sentry_sync_func(*args: "Any", **kwargs: "Any") -> "Any": return old_func(*args, **kwargs) current_scope = sentry_sdk.get_current_scope() - if current_scope.transaction is not None: + current_span = current_scope.span + + if isinstance(current_span, StreamedSpan) and not isinstance( + current_span, NoOpStreamedSpan + ): + current_span._segment._update_active_thread() + elif current_scope.transaction is not None: current_scope.transaction.update_active_thread() sentry_scope = sentry_sdk.get_isolation_scope() diff --git a/tests/integrations/starlette/test_starlette.py b/tests/integrations/starlette/test_starlette.py index 801cd53bf4..a673b785fb 100644 --- a/tests/integrations/starlette/test_starlette.py +++ b/tests/integrations/starlette/test_starlette.py @@ -11,6 +11,7 @@ import pytest +import sentry_sdk from sentry_sdk import capture_message, get_baggage, get_traceparent from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from sentry_sdk.integrations.starlette import ( @@ -648,15 +649,23 @@ def test_user_information_transaction_no_pii(sentry_init, capture_events): assert "user" not in transaction_event -def test_middleware_spans(sentry_init, capture_events): +@pytest.mark.parametrize("span_streaming", [True, False]) +def test_middleware_spans(sentry_init, capture_events, capture_items, span_streaming): sentry_init( traces_sample_rate=1.0, integrations=[StarletteIntegration(middleware_spans=True)], + _experiments={ + "trace_lifecycle": "stream" if span_streaming else "static", + }, ) starlette_app = starlette_app_factory( middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())] ) - events = capture_events() + + if span_streaming: + items = capture_items("span") + else: + events = capture_events() client = TestClient(starlette_app, raise_server_exceptions=False) try: @@ -664,8 +673,6 @@ def test_middleware_spans(sentry_init, capture_events): except Exception: pass - (_, transaction_event) = events - expected_middleware_spans = [ "ServerErrorMiddleware", "AuthenticationMiddleware", @@ -676,27 +683,60 @@ def test_middleware_spans(sentry_init, capture_events): "ServerErrorMiddleware", # 'op': 'middleware.starlette.send' ] - assert len(transaction_event["spans"]) == len(expected_middleware_spans) + if span_streaming: + sentry_sdk.flush() + + middleware_spans = sorted( + [ + item.payload + for item in items + if item.payload.get("attributes", {}) + .get("sentry.op", "") + .startswith("middleware.starlette") + ], + key=lambda s: s["start_timestamp"], + ) - idx = 0 - for span in transaction_event["spans"]: - if span["op"].startswith("middleware.starlette"): + assert len(middleware_spans) == len(expected_middleware_spans) + + for idx, span in enumerate(middleware_spans): assert ( - span["tags"]["starlette.middleware_name"] - == expected_middleware_spans[idx] + span["attributes"]["middleware.name"] == expected_middleware_spans[idx] ) - idx += 1 + else: + (_, transaction_event) = events + assert len(transaction_event["spans"]) == len(expected_middleware_spans) -def test_middleware_spans_disabled(sentry_init, capture_events): + idx = 0 + for span in transaction_event["spans"]: + if span["op"].startswith("middleware.starlette"): + assert ( + span["tags"]["starlette.middleware_name"] + == expected_middleware_spans[idx] + ) + idx += 1 + + +@pytest.mark.parametrize("span_streaming", [True, False]) +def test_middleware_spans_disabled( + sentry_init, capture_events, capture_items, span_streaming +): sentry_init( traces_sample_rate=1.0, integrations=[StarletteIntegration(middleware_spans=False)], + _experiments={ + "trace_lifecycle": "stream" if span_streaming else "static", + }, ) starlette_app = starlette_app_factory( middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())] ) - events = capture_events() + + if span_streaming: + items = capture_items("span") + else: + events = capture_events() client = TestClient(starlette_app, raise_server_exceptions=False) try: @@ -704,18 +744,39 @@ def test_middleware_spans_disabled(sentry_init, capture_events): except Exception: pass - (_, transaction_event) = events - - assert len(transaction_event["spans"]) == 0 + if span_streaming: + sentry_sdk.flush() + + middleware_spans = [ + item.payload + for item in items + if item.payload.get("attributes", {}) + .get("sentry.op", "") + .startswith("middleware.starlette") + ] + assert len(middleware_spans) == 0 + else: + (_, transaction_event) = events + assert len(transaction_event["spans"]) == 0 -def test_middleware_callback_spans(sentry_init, capture_events): +@pytest.mark.parametrize("span_streaming", [True, False]) +def test_middleware_callback_spans( + sentry_init, capture_events, capture_items, span_streaming +): sentry_init( traces_sample_rate=1.0, - integrations=[StarletteIntegration()], + integrations=[StarletteIntegration(middleware_spans=True)], + _experiments={ + "trace_lifecycle": "stream" if span_streaming else "static", + }, ) starlette_app = starlette_app_factory(middleware=[Middleware(SampleMiddleware)]) - events = capture_events() + + if span_streaming: + items = capture_items("span") + else: + events = capture_events() client = TestClient(starlette_app, raise_server_exceptions=False) try: @@ -723,8 +784,6 @@ def test_middleware_callback_spans(sentry_init, capture_events): except Exception: pass - (_, transaction_event) = events - expected = [ { "op": "middleware.starlette", @@ -773,12 +832,37 @@ def test_middleware_callback_spans(sentry_init, capture_events): }, ] - idx = 0 - for span in transaction_event["spans"]: - assert span["op"] == expected[idx]["op"] - assert span["description"] == expected[idx]["description"] - assert span["tags"] == expected[idx]["tags"] - idx += 1 + if span_streaming: + sentry_sdk.flush() + + middleware_spans = sorted( + [ + item.payload + for item in items + if item.payload.get("attributes", {}) + .get("sentry.op", "") + .startswith("middleware.starlette") + ], + key=lambda s: s["start_timestamp"], + ) + + assert len(middleware_spans) == len(expected) + for span, exp in zip(middleware_spans, expected): + assert span["attributes"]["sentry.op"] == exp["op"] + assert span["name"] == exp["description"] + assert ( + span["attributes"]["middleware.name"] + == exp["tags"]["starlette.middleware_name"] + ) + else: + (_, transaction_event) = events + + idx = 0 + for span in transaction_event["spans"]: + assert span["op"] == expected[idx]["op"] + assert span["description"] == expected[idx]["description"] + assert span["tags"] == expected[idx]["tags"] + idx += 1 def test_middleware_receive_send(sentry_init, capture_events): @@ -946,6 +1030,158 @@ def test_active_thread_id(sentry_init, capture_envelopes, teardown_profiling, en assert str(data["active"]) == trace_context["data"]["thread.id"] +@pytest.mark.parametrize("endpoint", ["/sync/thread_ids", "/async/thread_ids"]) +def test_active_thread_id_span_streaming(sentry_init, capture_items, endpoint): + sentry_init( + auto_enabling_integrations=False, # avoid legacy spans from auto-enabled integrations leaking into streaming mode + integrations=[StarletteIntegration()], + traces_sample_rate=1.0, + _experiments={"trace_lifecycle": "stream"}, + ) + app = starlette_app_factory() + + items = capture_items("span") + + client = TestClient(app) + response = client.get(endpoint) + assert response.status_code == 200 + + data = json.loads(response.content) + + sentry_sdk.flush() + + segments = [item.payload for item in items if item.payload.get("is_segment")] + assert len(segments) == 1 + assert str(data["active"]) == segments[0]["attributes"]["thread.id"] + + +def _post_body_app(handler_awaitable): + async def _handler(request): + await handler_awaitable(request) + return starlette.responses.JSONResponse({"ok": True}) + + return starlette.applications.Starlette( + routes=[starlette.routing.Route("/body", _handler, methods=["POST"])], + ) + + +@pytest.mark.parametrize("middleware_spans", [False, True]) +def test_request_body_data_does_not_scrub_pii_span_streaming( + sentry_init, capture_items, middleware_spans +): + sentry_init( + auto_enabling_integrations=False, + integrations=[StarletteIntegration(middleware_spans=middleware_spans)], + traces_sample_rate=1.0, + _experiments={"trace_lifecycle": "stream"}, + ) + + async def _read_json(request): + await request.json() + + items = capture_items("span") + + client = TestClient(_post_body_app(_read_json)) + response = client.post( + "/body", + json={ + "password": "ohno", + "authorization": "Bearer token", + "message": "hello", + }, + ) + assert response.status_code == 200 + + sentry_sdk.flush() + + segments = [item.payload for item in items if item.payload.get("is_segment")] + assert len(segments) == 1 + attr = segments[0]["attributes"]["http.request.body.data"] + + # Going forward, the sanitization of data will need to happen within the `before_send_span` hooks + # See https://sentry.slack.com/archives/C09RR0KD2N7/p1776951331206129?thread_ts=1776951227.440659&cid=C09RR0KD2N7 + assert "ohno" in attr + assert "Bearer token" in attr + assert "hello" in attr + + +@pytest.mark.skipif( + STARLETTE_VERSION < (0, 21), + reason="Requires Starlette >= 0.21, because earlier versions use a requests-based TestClient which does not support the 'content' kwarg", +) +@pytest.mark.parametrize("middleware_spans", [False, True]) +def test_request_body_data_annotated_value_top_level_span_streaming( + sentry_init, capture_items, middleware_spans +): + sentry_init( + auto_enabling_integrations=False, + integrations=[StarletteIntegration(middleware_spans=middleware_spans)], + traces_sample_rate=1.0, + _experiments={"trace_lifecycle": "stream"}, + ) + + async def _read_body(request): + await request.body() + + items = capture_items("span") + + client = TestClient(_post_body_app(_read_body)) + response = client.post( + "/body", + content=b"not json and not form", + headers={"content-type": "application/octet-stream"}, + ) + assert response.status_code == 200 + + sentry_sdk.flush() + + segments = [item.payload for item in items if item.payload.get("is_segment")] + assert len(segments) == 1 + attr = segments[0]["attributes"]["http.request.body.data"] + + assert isinstance(attr, str) + assert "!raw" in attr + + +@pytest.mark.parametrize("middleware_spans", [False, True]) +def test_request_body_data_annotated_value_nested_span_streaming( + sentry_init, capture_items, middleware_spans +): + pytest.importorskip("multipart") + + sentry_init( + auto_enabling_integrations=False, + integrations=[StarletteIntegration(middleware_spans=middleware_spans)], + traces_sample_rate=1.0, + _experiments={"trace_lifecycle": "stream"}, + ) + + async def _read_form(request): + await request.form() + + items = capture_items("span") + + client = TestClient(_post_body_app(_read_form)) + response = client.post( + "/body", + data={"name": "erica"}, + files={"avatar": ("photo.jpg", b"fake-bytes", "image/jpeg")}, + ) + assert response.status_code == 200 + + sentry_sdk.flush() + + segments = [item.payload for item in items if item.payload.get("is_segment")] + assert len(segments) == 1 + attr = segments[0]["attributes"]["http.request.body.data"] + + assert isinstance(attr, str) + parsed = json.loads(attr) + assert parsed["name"] == "erica" + assert parsed["avatar"]["metadata"]["rem"] == [["!raw", "x"]] + assert "fake-bytes" not in attr + + def test_original_request_not_scrubbed(sentry_init, capture_events): sentry_init(integrations=[StarletteIntegration()]) @@ -1167,15 +1403,24 @@ def test_transaction_name_in_middleware( ) -def test_span_origin(sentry_init, capture_events): +@pytest.mark.parametrize("span_streaming", [True, False]) +def test_span_origin(sentry_init, capture_events, capture_items, span_streaming): sentry_init( - integrations=[StarletteIntegration()], + auto_enabling_integrations=False, # avoid httpx auto-instrumentation leaking spans + integrations=[StarletteIntegration(middleware_spans=True)], traces_sample_rate=1.0, + _experiments={ + "trace_lifecycle": "stream" if span_streaming else "static", + }, ) starlette_app = starlette_app_factory( middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())] ) - events = capture_events() + + if span_streaming: + items = capture_items("span") + else: + events = capture_events() client = TestClient(starlette_app, raise_server_exceptions=False) try: @@ -1183,11 +1428,18 @@ def test_span_origin(sentry_init, capture_events): except Exception: pass - (_, event) = events + if span_streaming: + sentry_sdk.flush() + + assert len(items) > 0 + for item in items: + assert item.payload["attributes"]["sentry.origin"] == "auto.http.starlette" + else: + (_, event) = events - assert event["contexts"]["trace"]["origin"] == "auto.http.starlette" - for span in event["spans"]: - assert span["origin"] == "auto.http.starlette" + assert event["contexts"]["trace"]["origin"] == "auto.http.starlette" + for span in event["spans"]: + assert span["origin"] == "auto.http.starlette" class NonIterableContainer: