Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support overriding request headers in middlewares #41380

Merged
merged 15 commits into from Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 31 additions & 0 deletions packages/next/server/next-server.ts
Expand Up @@ -1882,6 +1882,37 @@ export default class NextNodeServer extends BaseServer {
result.response.headers.set('x-middleware-rewrite', rel)
}

if (result.response.headers.has('x-middleware-override-headers')) {
const overriddenHeaders: Set<string> = new Set()
for (const key of result.response.headers
.get('x-middleware-override-headers')!
.split(',')) {
overriddenHeaders.add(key.trim())
}

result.response.headers.delete('x-middleware-override-headers')

// Delete headers.
for (const key of Object.keys(req.headers)) {
if (!overriddenHeaders.has(key)) {
delete req.headers[key]
}
}

// Update or add headers.
for (const key of overriddenHeaders.keys()) {
const valueKey = 'x-middleware-request-' + key
const newValue = result.response.headers.get(valueKey)
const oldValue = req.headers[key]

if (oldValue !== newValue) {
req.headers[key] = newValue === null ? undefined : newValue
}

result.response.headers.delete(valueKey)
}
}
nuta marked this conversation as resolved.
Show resolved Hide resolved

if (result.response.headers.has('Location')) {
const value = result.response.headers.get('Location')!
const rel = relativizeURL(value, initUrl)
Expand Down
44 changes: 42 additions & 2 deletions packages/next/server/web/spec-extension/response.ts
Expand Up @@ -7,6 +7,25 @@ import { NextCookies } from './cookies'
const INTERNALS = Symbol('internal response')
const REDIRECTS = new Set([301, 302, 303, 307, 308])

function handleMiddlewareField(
init: MiddlewareResponseInit | undefined,
headers: Headers
) {
if (init?.request?.headers) {
if (!(init.request.headers instanceof Headers)) {
throw new Error('request.headers must be an instance of Headers')
}
nuta marked this conversation as resolved.
Show resolved Hide resolved

const keys = []
for (const [key, value] of init.request.headers) {
headers.set('x-middleware-request-' + key, value)
keys.push(key)
}

headers.set('x-middleware-override-headers', keys.join(','))
}
}

export class NextResponse extends Response {
[INTERNALS]: {
cookies: NextCookies
Expand Down Expand Up @@ -71,15 +90,22 @@ export class NextResponse extends Response {
})
}

static rewrite(destination: string | NextURL | URL, init?: ResponseInit) {
static rewrite(
destination: string | NextURL | URL,
init?: MiddlewareResponseInit
) {
const headers = new Headers(init?.headers)
headers.set('x-middleware-rewrite', validateURL(destination))

handleMiddlewareField(init, headers)
ijjk marked this conversation as resolved.
Show resolved Hide resolved
return new NextResponse(null, { ...init, headers })
}

static next(init?: ResponseInit) {
static next(init?: MiddlewareResponseInit) {
const headers = new Headers(init?.headers)
headers.set('x-middleware-next', '1')

handleMiddlewareField(init, headers)
return new NextResponse(null, { ...init, headers })
}
}
Expand All @@ -92,3 +118,17 @@ interface ResponseInit extends globalThis.ResponseInit {
}
url?: string
}

interface ModifiedRequest {
/**
* If this is set, the request headers will be overridden with this value.
*/
headers?: Headers
}

interface MiddlewareResponseInit extends globalThis.ResponseInit {
/**
* These fields will override the request from clients.
*/
request?: ModifiedRequest
}
129 changes: 129 additions & 0 deletions test/e2e/app-dir/app-middleware.test.ts
@@ -0,0 +1,129 @@
/* eslint-env jest */

import { NextInstance } from 'test/lib/next-modes/base'
import { fetchViaHTTP } from 'next-test-utils'
import { createNext, FileRef } from 'e2e-utils'
import cheerio from 'cheerio'
import path from 'path'

describe('app-dir with middleware', () => {
if ((global as any).isNextDeploy) {
it('should skip next deploy for now', () => {})
return
}

if (process.env.NEXT_TEST_REACT_VERSION === '^17') {
it('should skip for react v17', () => {})
return
}

let next: NextInstance

afterAll(() => next.destroy())
beforeAll(async () => {
next = await createNext({
files: new FileRef(path.join(__dirname, 'app-middleware')),
dependencies: {
react: 'experimental',
'react-dom': 'experimental',
},
})
})

describe.each([
{
title: 'Serverless Functions',
path: '/api/dump-headers-serverless',
toJson: (res: Response) => res.json(),
},
{
title: 'Edge Functions',
path: '/api/dump-headers-edge',
toJson: (res: Response) => res.json(),
},
{
title: 'next/headers',
path: '/headers',
toJson: async (res: Response) => {
const $ = cheerio.load(await res.text())
return JSON.parse($('#headers').text())
},
},
])('Mutate request headers for $title', ({ path, toJson }) => {
it(`Adds new headers`, async () => {
const res = await fetchViaHTTP(next.url, path, null, {
headers: {
'x-from-client': 'hello-from-client',
},
})
expect(await toJson(res)).toMatchObject({
'x-from-client': 'hello-from-client',
'x-from-middleware': 'hello-from-middleware',
})
})

it(`Deletes headers`, async () => {
const res = await fetchViaHTTP(
next.url,
path,
{
'remove-headers': 'x-from-client1,x-from-client2',
},
{
headers: {
'x-from-client1': 'hello-from-client',
'X-From-Client2': 'hello-from-client',
},
}
)

const json = await toJson(res)
expect(json).not.toHaveProperty('x-from-client1')
expect(json).not.toHaveProperty('X-From-Client2')
expect(json).toMatchObject({
'x-from-middleware': 'hello-from-middleware',
})

// Should not be included in response headers.
expect(res.headers.get('x-middleware-override-headers')).toBeNull()
expect(
res.headers.get('x-middleware-request-x-from-middleware')
).toBeNull()
expect(res.headers.get('x-middleware-request-x-from-client1')).toBeNull()
expect(res.headers.get('x-middleware-request-x-from-client2')).toBeNull()
})

it(`Updates headers`, async () => {
const res = await fetchViaHTTP(
next.url,
path,
{
'update-headers':
'x-from-client1=new-value1,x-from-client2=new-value2',
},
{
headers: {
'x-from-client1': 'old-value1',
'X-From-Client2': 'old-value2',
'x-from-client3': 'old-value3',
},
}
)
expect(await toJson(res)).toMatchObject({
'x-from-client1': 'new-value1',
'x-from-client2': 'new-value2',
'x-from-client3': 'old-value3',
'x-from-middleware': 'hello-from-middleware',
})

// Should not be included in response headers.
expect(res.headers.get('x-middleware-override-headers')).toBeNull()
expect(
res.headers.get('x-middleware-request-x-from-middleware')
).toBeNull()
expect(res.headers.get('x-middleware-request-x-from-client1')).toBeNull()
expect(res.headers.get('x-middleware-request-x-from-client2')).toBeNull()
expect(res.headers.get('x-middleware-request-x-from-client3')).toBeNull()
})
})
})
10 changes: 10 additions & 0 deletions test/e2e/app-dir/app-middleware/app/headers/page.js
@@ -0,0 +1,10 @@
import { headers } from 'next/headers'

export default function SSRPage() {
const headersObj = Object.fromEntries(headers())
return (
<>
<p id="headers">{JSON.stringify(headersObj)}</p>
</>
)
}
10 changes: 10 additions & 0 deletions test/e2e/app-dir/app-middleware/app/layout.js
@@ -0,0 +1,10 @@
export default function Layout({ children }) {
return (
<html lang="en">
<head>
<title>app-middleware</title>
</head>
<body>{children}</body>
</html>
)
}
30 changes: 30 additions & 0 deletions test/e2e/app-dir/app-middleware/middleware.js
@@ -0,0 +1,30 @@
import { NextResponse } from 'next/server'

/**
* @param {import('next/server').NextRequest} request
*/
export async function middleware(request) {
const headers = new Headers(request.headers)
headers.set('x-from-middleware', 'hello-from-middleware')

const removeHeaders = request.nextUrl.searchParams.get('remove-headers')
if (removeHeaders) {
for (const key of removeHeaders.split(',')) {
headers.delete(key)
}
}

const updateHeader = request.nextUrl.searchParams.get('update-headers')
if (updateHeader) {
for (const kv of updateHeader.split(',')) {
const [key, value] = kv.split('=')
headers.set(key, value)
}
}

return NextResponse.next({
request: {
headers,
},
})
}
7 changes: 7 additions & 0 deletions test/e2e/app-dir/app-middleware/next.config.js
@@ -0,0 +1,7 @@
module.exports = {
experimental: {
appDir: true,
legacyBrowsers: false,
browsersListForSwc: true,
},
}
11 changes: 11 additions & 0 deletions test/e2e/app-dir/app-middleware/pages/api/dump-headers-edge.js
@@ -0,0 +1,11 @@
export const config = {
runtime: 'experimental-edge',
}

export default (req) => {
return Response.json(Object.fromEntries(req.headers.entries()), {
headers: {
'headers-from-edge-function': '1',
},
})
}
@@ -0,0 +1,6 @@
export default (req, res) => {
return res
.status(200)
.setHeader('headers-from-serverless', '1')
.json(req.headers)
}
@@ -0,0 +1 @@
.vercel
30 changes: 30 additions & 0 deletions test/e2e/middleware-request-header-overrides/app/middleware.js
@@ -0,0 +1,30 @@
import { NextResponse } from 'next/server'

/**
* @param {import('next/server').NextRequest} request
*/
export async function middleware(request) {
const headers = new Headers(request.headers)
headers.set('x-from-middleware', 'hello-from-middleware')

const removeHeaders = request.nextUrl.searchParams.get('remove-headers')
if (removeHeaders) {
for (const key of removeHeaders.split(',')) {
headers.delete(key)
}
}

const updateHeader = request.nextUrl.searchParams.get('update-headers')
if (updateHeader) {
for (const kv of updateHeader.split(',')) {
const [key, value] = kv.split('=')
headers.set(key, value)
}
}

return NextResponse.next({
request: {
headers,
},
})
}
@@ -0,0 +1 @@
module.exports = {}
@@ -0,0 +1,11 @@
export const config = {
runtime: 'experimental-edge',
}

export default (req) => {
return Response.json(Object.fromEntries(req.headers.entries()), {
headers: {
'headers-from-edge-function': '1',
},
})
}
@@ -0,0 +1,6 @@
export default (req, res) => {
return res
.status(200)
.setHeader('headers-from-serverless', '1')
.json(req.headers)
}
@@ -0,0 +1,15 @@
export default function SSRPage({ headers }) {
return (
<>
<p id="headers">{JSON.stringify(headers)}</p>
</>
)
}

export const getServerSideProps = (ctx) => {
return {
props: {
headers: ctx.req.headers,
},
}
}