Skip to content

Commit

Permalink
Updated RemoteUser middleware to use native async logic
Browse files Browse the repository at this point in the history
  • Loading branch information
bigfootjon committed Apr 6, 2024
1 parent fe6ce51 commit b0dc084
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 12 deletions.
85 changes: 81 additions & 4 deletions django/contrib/auth/middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from functools import partial

from asgiref.sync import iscoroutinefunction, markcoroutinefunction

from django.contrib import auth
from django.contrib.auth import load_backend
from django.contrib.auth.backends import RemoteUserBackend
Expand Down Expand Up @@ -34,7 +36,7 @@ def process_request(self, request):
request.auser = partial(auser, request)


class RemoteUserMiddleware(MiddlewareMixin):
class RemoteUserMiddleware:
"""
Middleware for utilizing web-server-provided authentication.
Expand All @@ -48,13 +50,27 @@ class RemoteUserMiddleware(MiddlewareMixin):
different header.
"""

sync_capable = True
async_capable = True

def __init__(self, get_response):
if get_response is None:
raise ValueError("get_response must be provided.")
self.get_response = get_response
self.async_mode = iscoroutinefunction(get_response)
if self.async_mode:
markcoroutinefunction(self)
super().__init__()

# Name of request header to grab username from. This will be the key as
# used in the request.META dictionary, i.e. the normalization of headers to
# all uppercase and the addition of "HTTP_" prefix apply.
header = "REMOTE_USER"
force_logout_if_no_header = True

def process_request(self, request):
def __call__(self, request):
if self.async_mode:
return self.__acall__(request)
# AuthenticationMiddleware is required so that request.user exists.
if not hasattr(request, "user"):
raise ImproperlyConfigured(
Expand All @@ -72,13 +88,13 @@ def process_request(self, request):
# AnonymousUser by the AuthenticationMiddleware).
if self.force_logout_if_no_header and request.user.is_authenticated:
self._remove_invalid_user(request)
return
return self.get_response(request)
# If the user is already authenticated and that user is the user we are
# getting passed in the headers, then the correct user is already
# persisted in the session and we don't need to continue.
if request.user.is_authenticated:
if request.user.get_username() == self.clean_username(username, request):
return
return self.get_response(request)
else:
# An authenticated user is associated with the request, but
# it does not match the authorized user in the header.
Expand All @@ -92,6 +108,51 @@ def process_request(self, request):
# by logging the user in.
request.user = user
auth.login(request, user)
return self.get_response(request)

async def __acall__(self, request):
# AuthenticationMiddleware is required so that request.user exists.
if not hasattr(request, "user"):
raise ImproperlyConfigured(
"The Django remote user auth middleware requires the"
" authentication middleware to be installed. Edit your"
" MIDDLEWARE setting to insert"
" 'django.contrib.auth.middleware.AuthenticationMiddleware'"
" before the RemoteUserMiddleware class."
)
try:
username = request.META["HTTP_" + self.header]
except KeyError:
# If specified header doesn't exist then remove any existing
# authenticated remote-user, or return (leaving request.user set to
# AnonymousUser by the AuthenticationMiddleware).
if self.force_logout_if_no_header:
user = await request.auser()
if user.is_authenticated:
await self._aremove_invalid_user(request)
return await self.get_response(request)
user = await request.auser()
# If the user is already authenticated and that user is the user we are
# getting passed in the headers, then the correct user is already
# persisted in the session and we don't need to continue.
if user.is_authenticated:
if user.get_username() == self.clean_username(username, request):
return await self.get_response(request)
else:
# An authenticated user is associated with the request, but
# it does not match the authorized user in the header.
await self._aremove_invalid_user(request)

# We are seeing this user for the first time in this session, attempt
# to authenticate the user.
user = await auth.aauthenticate(request, remote_user=username)
if user:
# User is valid. Set request.user and persist user in the session
# by logging the user in.
request.user = user
await auth.alogin(request, user)

return await self.get_response(request)

def clean_username(self, username, request):
"""
Expand Down Expand Up @@ -122,6 +183,22 @@ def _remove_invalid_user(self, request):
if isinstance(stored_backend, RemoteUserBackend):
auth.logout(request)

async def _aremove_invalid_user(self, request):
"""
Remove the current authenticated user in the request which is invalid
but only if the user is authenticated via the RemoteUserBackend.
"""
try:
stored_backend = load_backend(
await request.session.aget(auth.BACKEND_SESSION_KEY, "")
)
except ImportError:
# backend failed to load
await auth.alogout(request)
else:
if isinstance(stored_backend, RemoteUserBackend):
await auth.alogout(request)


class PersistentRemoteUserMiddleware(RemoteUserMiddleware):
"""
Expand Down
2 changes: 1 addition & 1 deletion django/test/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def generic(
if headers:
extra.update(HttpHeaders.to_asgi_names(headers))
s["headers"] += [
(key.lower().encode("ascii"), value.encode("latin1"))
(key.lower().encode("ascii"), (value or "").encode("latin1"))
for key, value in extra.items()
]
return self.request(**s)
Expand Down
191 changes: 189 additions & 2 deletions tests/auth_tests/test_remote_user.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from datetime import datetime, timezone

from django.conf import settings
from django.contrib.auth import authenticate
from django.contrib.auth import aauthenticate, authenticate
from django.contrib.auth.backends import RemoteUserBackend
from django.contrib.auth.middleware import RemoteUserMiddleware
from django.contrib.auth.models import User
from django.middleware.csrf import _get_new_csrf_string, _mask_cipher_secret
from django.test import Client, TestCase, modify_settings, override_settings
from django.test import (
AsyncClient,
Client,
TestCase,
modify_settings,
override_settings,
)


@override_settings(ROOT_URLCONF="auth_tests.urls")
Expand Down Expand Up @@ -170,6 +176,187 @@ def test_inactive_user(self):
self.assertTrue(response.context["user"].is_anonymous)


@override_settings(ROOT_URLCONF="auth_tests.urls")
class AsyncRemoteUserTest(TestCase):
middleware = "django.contrib.auth.middleware.RemoteUserMiddleware"
backend = "django.contrib.auth.backends.RemoteUserBackend"
header = "REMOTE_USER"
email_header = "REMOTE_EMAIL"

# Usernames to be passed in REMOTE_USER for the test_known_user test case.
known_user = "knownuser"
known_user2 = "knownuser2"

@classmethod
def setUpClass(cls):
cls.enterClassContext(
modify_settings(
AUTHENTICATION_BACKENDS={"append": cls.backend},
MIDDLEWARE={"append": cls.middleware},
)
)
super().setUpClass()

async def test_no_remote_user(self):
"""Users are not created when remote user is not specified."""
num_users = await User.objects.acount()

response = await self.async_client.get("/remote_user/")
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(await User.objects.acount(), num_users)

response = await self.async_client.get("/remote_user/", **{self.header: None})
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(await User.objects.acount(), num_users)

response = await self.async_client.get("/remote_user/", **{self.header: ""})
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(await User.objects.acount(), num_users)

async def test_csrf_validation_passes_after_process_request_login(self):
"""
CSRF check must access the CSRF token from the session or cookie,
rather than the request, as rotate_token() may have been called by an
authentication middleware during the process_request() phase.
"""
csrf_client = AsyncClient(enforce_csrf_checks=True)
csrf_secret = _get_new_csrf_string()
csrf_token = _mask_cipher_secret(csrf_secret)
csrf_token_form = _mask_cipher_secret(csrf_secret)
headers = {self.header: "fakeuser"}
data = {"csrfmiddlewaretoken": csrf_token_form}

# Verify that CSRF is configured for the view
csrf_client.cookies.load({settings.CSRF_COOKIE_NAME: csrf_token})
response = await csrf_client.post("/remote_user/", **headers)
self.assertEqual(response.status_code, 403)
self.assertIn(b"CSRF verification failed.", response.content)

# This request will call django.contrib.auth.login() which will call
# django.middleware.csrf.rotate_token() thus changing the value of
# request.META['CSRF_COOKIE'] from the user submitted value set by
# CsrfViewMiddleware.process_request() to the new csrftoken value set
# by rotate_token(). Csrf validation should still pass when the view is
# later processed by CsrfViewMiddleware.process_view()
csrf_client.cookies.load({settings.CSRF_COOKIE_NAME: csrf_token})
response = await csrf_client.post("/remote_user/", data, **headers)
self.assertEqual(response.status_code, 200)

async def test_unknown_user(self):
"""
Tests the case where the username passed in the header does not exist
as a User.
"""
num_users = await User.objects.acount()
response = await self.async_client.get(
"/remote_user/", **{self.header: "newuser"}
)
self.assertEqual(response.context["user"].username, "newuser")
self.assertEqual(await User.objects.acount(), num_users + 1)
await User.objects.aget(username="newuser")

# Another request with same user should not create any new users.
response = await self.async_client.get(
"/remote_user/", **{self.header: "newuser"}
)
self.assertEqual(await User.objects.acount(), num_users + 1)

async def test_known_user(self):
"""
Tests the case where the username passed in the header is a valid User.
"""
await User.objects.acreate(username="knownuser")
await User.objects.acreate(username="knownuser2")
num_users = await User.objects.acount()
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, "knownuser")
self.assertEqual(await User.objects.acount(), num_users)
# A different user passed in the headers causes the new user
# to be logged in.
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user2}
)
self.assertEqual(response.context["user"].username, "knownuser2")
self.assertEqual(await User.objects.acount(), num_users)

async def test_last_login(self):
"""
A user's last_login is set the first time they make a
request but not updated in subsequent requests with the same session.
"""
user = await User.objects.acreate(username="knownuser")
# Set last_login to something so we can determine if it changes.
default_login = datetime(2000, 1, 1)
if settings.USE_TZ:
default_login = default_login.replace(tzinfo=timezone.utc)
user.last_login = default_login
await user.asave()

response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertNotEqual(default_login, response.context["user"].last_login)

user = await User.objects.aget(username="knownuser")
user.last_login = default_login
await user.asave()
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(default_login, response.context["user"].last_login)

async def test_header_disappears(self):
"""
A logged in user is logged out automatically when
the REMOTE_USER header disappears during the same browser session.
"""
await User.objects.acreate(username="knownuser")
# Known user authenticates
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, "knownuser")
# During the session, the REMOTE_USER header disappears. Should trigger logout.
response = await self.async_client.get("/remote_user/")
self.assertTrue(response.context["user"].is_anonymous)
# verify the remoteuser middleware will not remove a user
# authenticated via another backend
await User.objects.acreate_user(username="modeluser", password="foo")
await self.async_client.alogin(username="modeluser", password="foo")
await aauthenticate(username="modeluser", password="foo")
response = await self.async_client.get("/remote_user/")
self.assertEqual(response.context["user"].username, "modeluser")

async def test_user_switch_forces_new_login(self):
"""
If the username in the header changes between requests
that the original user is logged out
"""
await User.objects.acreate(username="knownuser")
# Known user authenticates
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, "knownuser")
# During the session, the REMOTE_USER changes to a different user.
response = await self.async_client.get(
"/remote_user/", **{self.header: "newnewuser"}
)
# The current user is not the prior remote_user.
# In backends that create a new user, username is "newnewuser"
# In backends that do not create new users, it is '' (anonymous user)
self.assertNotEqual(response.context["user"].username, "knownuser")

async def test_inactive_user(self):
await User.objects.acreate(username="knownuser", is_active=False)
response = await self.async_client.get(
"/remote_user/", **{self.header: "knownuser"}
)
self.assertTrue(response.context["user"].is_anonymous)


class RemoteUserNoCreateBackend(RemoteUserBackend):
"""Backend that doesn't create unknown users."""

Expand Down
6 changes: 1 addition & 5 deletions tests/deprecation/test_middleware_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from asgiref.sync import async_to_sync, iscoroutinefunction

from django.contrib.admindocs.middleware import XViewMiddleware
from django.contrib.auth.middleware import (
AuthenticationMiddleware,
RemoteUserMiddleware,
)
from django.contrib.auth.middleware import AuthenticationMiddleware
from django.contrib.flatpages.middleware import FlatpageFallbackMiddleware
from django.contrib.messages.middleware import MessageMiddleware
from django.contrib.redirects.middleware import RedirectFallbackMiddleware
Expand Down Expand Up @@ -46,7 +43,6 @@ class MiddlewareMixinTests(SimpleTestCase):
LocaleMiddleware,
MessageMiddleware,
RedirectFallbackMiddleware,
RemoteUserMiddleware,
SecurityMiddleware,
SessionMiddleware,
UpdateCacheMiddleware,
Expand Down

0 comments on commit b0dc084

Please sign in to comment.