From 0d609de21e7e5b234e260caa521efd30a5841568 Mon Sep 17 00:00:00 2001 From: Scott Trinh Date: Tue, 5 Mar 2024 10:13:29 -0500 Subject: [PATCH] Return JSON for magic link register (#6974) --- edb/server/protocol/auth_ext/http.py | 49 ++++++++++++++++------------ tests/test_http_ext_auth.py | 5 ++- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/edb/server/protocol/auth_ext/http.py b/edb/server/protocol/auth_ext/http.py index 86be3537ee..3604c8f7cb 100644 --- a/edb/server/protocol/auth_ext/http.py +++ b/edb/server/protocol/auth_ext/http.py @@ -872,6 +872,14 @@ async def handle_magic_link_register(self, request: Any, response: Any): tenant=self.tenant, test_mode=self.test_mode, ) + + request_accepts_json: bool = request.accept == b"application/json" + + if not request_accepts_json and not maybe_redirect_to: + raise errors.InvalidData( + "Request must accept JSON or provide a redirect URL." + ) + try: await magic_link_client.register( email=email, @@ -888,34 +896,33 @@ async def handle_magic_link_register(self, request: Any, response: Any): "email_sent": email, } - if maybe_redirect_to: + if request_accepts_json: + response.status = http.HTTPStatus.OK + response.content_type = b"application/json" + response.body = json.dumps(return_data).encode() + elif maybe_redirect_to: response.status = http.HTTPStatus.FOUND response.custom_headers["Location"] = util.join_url_params( maybe_redirect_to, return_data ) else: - response.status = http.HTTPStatus.OK - response.content_type = b"application/json" - response.body = json.dumps(return_data).encode() + # This should not happen since we check earlier for this case + # but this seems safer than a cast + raise errors.InvalidData( + "Request must accept JSON or provide a redirect URL." + ) except Exception as ex: - redirect_on_failure = data.get( - "redirect_on_failure", maybe_redirect_to - ) - if redirect_on_failure is None: + if request_accepts_json: raise ex - else: - if not self._is_url_allowed(redirect_on_failure): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) - response.status = http.HTTPStatus.FOUND - redirect_params = { - "error": str(ex), - "email": data.get('email', ''), - } - response.custom_headers["Location"] = util.join_url_params( - redirect_on_failure, redirect_params - ) + + response.status = http.HTTPStatus.FOUND + redirect_params = { + "error": str(ex), + "email": data.get('email', ''), + } + response.custom_headers["Location"] = util.join_url_params( + redirect_on_failure, redirect_params + ) async def handle_magic_link_email(self, request: Any, response: Any): data = self._get_data_from_request(request) diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 929e6460f9..6e8fb9beb0 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -3837,7 +3837,10 @@ async def test_http_auth_ext_magic_link_01(self): "redirect_on_failure": redirect_on_failure, } ).encode(), - headers={"Content-Type": "application/json"}, + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + }, ) self.assertEqual(status, 200)