diff --git a/sentry_sdk/consts.py b/sentry_sdk/consts.py index b6e546e336..3be5fe6779 100644 --- a/sentry_sdk/consts.py +++ b/sentry_sdk/consts.py @@ -118,6 +118,8 @@ class OP: HTTP_SERVER = "http.server" MIDDLEWARE_DJANGO = "middleware.django" MIDDLEWARE_STARLETTE = "middleware.starlette" + MIDDLEWARE_STARLETTE_RECEIVE = "middleware.starlette.receive" + MIDDLEWARE_STARLETTE_SEND = "middleware.starlette.send" QUEUE_SUBMIT_CELERY = "queue.submit.celery" QUEUE_TASK_CELERY = "queue.task.celery" QUEUE_TASK_RQ = "queue.task.rq" diff --git a/sentry_sdk/integrations/starlette.py b/sentry_sdk/integrations/starlette.py index dffba5afd5..aaf7fb3dc4 100644 --- a/sentry_sdk/integrations/starlette.py +++ b/sentry_sdk/integrations/starlette.py @@ -85,21 +85,49 @@ def _enable_span_for_middleware(middleware_class): # type: (Any) -> type old_call = middleware_class.__call__ - async def _create_span_call(*args, **kwargs): - # type: (Any, Any) -> None + async def _create_span_call(app, scope, receive, send, **kwargs): + # type: (Any, Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]], Any) -> None hub = Hub.current integration = hub.get_integration(StarletteIntegration) if integration is not None: - middleware_name = args[0].__class__.__name__ + middleware_name = app.__class__.__name__ + with hub.start_span( op=OP.MIDDLEWARE_STARLETTE, description=middleware_name ) as middleware_span: middleware_span.set_tag("starlette.middleware_name", middleware_name) - await old_call(*args, **kwargs) + # Creating spans for the "receive" callback + async def _sentry_receive(*args, **kwargs): + # type: (*Any, **Any) -> Any + hub = Hub.current + with hub.start_span( + op=OP.MIDDLEWARE_STARLETTE_RECEIVE, + description=receive.__qualname__, + ) as span: + span.set_tag("starlette.middleware_name", middleware_name) + await receive(*args, **kwargs) + + receive_patched = receive.__name__ == "_sentry_receive" + new_receive = _sentry_receive if not receive_patched else receive + + # Creating spans for the "send" callback + async def _sentry_send(*args, **kwargs): + # type: (*Any, **Any) -> Any + hub = Hub.current + with hub.start_span( + op=OP.MIDDLEWARE_STARLETTE_SEND, description=send.__qualname__ + ) as span: + span.set_tag("starlette.middleware_name", middleware_name) + await send(*args, **kwargs) + + send_patched = send.__name__ == "_sentry_send" + new_send = _sentry_send if not send_patched else send + + await old_call(app, scope, new_receive, new_send, **kwargs) else: - await old_call(*args, **kwargs) + await old_call(app, scope, receive, send, **kwargs) not_yet_patched = old_call.__name__ not in [ "_create_span_call", diff --git a/tests/integrations/starlette/test_starlette.py b/tests/integrations/starlette/test_starlette.py index 24254b69ef..29e5916adb 100644 --- a/tests/integrations/starlette/test_starlette.py +++ b/tests/integrations/starlette/test_starlette.py @@ -31,6 +31,8 @@ from starlette.middleware.authentication import AuthenticationMiddleware from starlette.testclient import TestClient +STARLETTE_VERSION = tuple([int(x) for x in starlette.__version__.split(".")]) + PICTURE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "photo.jpg") BODY_JSON = {"some": "json", "for": "testing", "nested": {"numbers": 123}} @@ -152,6 +154,26 @@ async def __anext__(self): raise StopAsyncIteration +class SampleMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + # only handle http requests + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + async def do_stuff(message): + if message["type"] == "http.response.start": + # do something here. + pass + + await send(message) + + await self.app(scope, receive, do_stuff) + + @pytest.mark.asyncio async def test_starlettrequestextractor_content_length(sentry_init): with mock.patch( @@ -546,6 +568,82 @@ def test_middleware_spans(sentry_init, capture_events): idx += 1 +def test_middleware_callback_spans(sentry_init, capture_events): + sentry_init( + traces_sample_rate=1.0, + integrations=[StarletteIntegration()], + ) + starlette_app = starlette_app_factory(middleware=[Middleware(SampleMiddleware)]) + events = capture_events() + + client = TestClient(starlette_app, raise_server_exceptions=False) + try: + client.get("/message", auth=("Gabriela", "hello123")) + except Exception: + pass + + (_, transaction_event) = events + + expected = [ + { + "op": "middleware.starlette", + "description": "ServerErrorMiddleware", + "tags": {"starlette.middleware_name": "ServerErrorMiddleware"}, + }, + { + "op": "middleware.starlette", + "description": "SampleMiddleware", + "tags": {"starlette.middleware_name": "SampleMiddleware"}, + }, + { + "op": "middleware.starlette", + "description": "ExceptionMiddleware", + "tags": {"starlette.middleware_name": "ExceptionMiddleware"}, + }, + { + "op": "middleware.starlette.send", + "description": "SampleMiddleware.__call__..do_stuff", + "tags": {"starlette.middleware_name": "ExceptionMiddleware"}, + }, + { + "op": "middleware.starlette.send", + "description": "ServerErrorMiddleware.__call__.._send", + "tags": {"starlette.middleware_name": "SampleMiddleware"}, + }, + { + "op": "middleware.starlette.send", + "description": "_ASGIAdapter.send..send" + if STARLETTE_VERSION < (0, 21) + else "_TestClientTransport.handle_request..send", + "tags": {"starlette.middleware_name": "ServerErrorMiddleware"}, + }, + { + "op": "middleware.starlette.send", + "description": "SampleMiddleware.__call__..do_stuff", + "tags": {"starlette.middleware_name": "ExceptionMiddleware"}, + }, + { + "op": "middleware.starlette.send", + "description": "ServerErrorMiddleware.__call__.._send", + "tags": {"starlette.middleware_name": "SampleMiddleware"}, + }, + { + "op": "middleware.starlette.send", + "description": "_ASGIAdapter.send..send" + if STARLETTE_VERSION < (0, 21) + else "_TestClientTransport.handle_request..send", + "tags": {"starlette.middleware_name": "ServerErrorMiddleware"}, + }, + ] + + idx = 0 + for span in transaction_event["spans"]: + assert span["op"] == expected[idx]["op"] + assert span["description"] == expected[idx]["description"] + assert span["tags"] == expected[idx]["tags"] + idx += 1 + + def test_last_event_id(sentry_init, capture_events): sentry_init( integrations=[StarletteIntegration()], diff --git a/tox.ini b/tox.ini index d2bf7fa2b1..8b19296671 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ envlist = {py3.7,py3.8,py3.9,py3.10}-asgi - {py3.7,py3.8,py3.9,py3.10}-starlette-{0.19.1,0.20} + {py3.7,py3.8,py3.9,py3.10}-starlette-{0.19.1,0.20,0.21} {py3.7,py3.8,py3.9,py3.10}-fastapi @@ -152,8 +152,10 @@ deps = starlette: pytest-asyncio starlette: python-multipart starlette: requests + starlette-0.21: httpx starlette-0.19.1: starlette==0.19.1 starlette-0.20: starlette>=0.20.0,<0.21.0 + starlette-0.21: starlette>=0.21.0,<0.22.0 fastapi: fastapi fastapi: pytest-asyncio