Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support add connection header thought header apis #3221

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
151 changes: 117 additions & 34 deletions actix-http/src/h1/encoder.rs
Expand Up @@ -6,6 +6,7 @@ use std::{
slice::from_raw_parts_mut,
};

use ahash::AHashMap;
use bytes::{BufMut, BytesMut};

use crate::{
Expand Down Expand Up @@ -109,28 +110,21 @@ pub(crate) trait MessageType: Sized {
BodySize::None => dst.put_slice(b"\r\n"),
}

// Connection
match conn_type {
ConnectionType::Upgrade => dst.put_slice(b"connection: upgrade\r\n"),
ConnectionType::KeepAlive if version < Version::HTTP_11 => {
if camel_case {
dst.put_slice(b"Connection: keep-alive\r\n")
} else {
dst.put_slice(b"connection: keep-alive\r\n")
}
}
ConnectionType::Close if version >= Version::HTTP_11 => {
if camel_case {
dst.put_slice(b"Connection: close\r\n")
} else {
dst.put_slice(b"connection: close\r\n")
}
}
_ => {}
}
let headers = match self.extra_headers() {
Some(extra_headers) => self
.headers()
.inner
.iter()
.filter(|(name, _)| !extra_headers.contains_key(*name))
.chain(extra_headers.inner.iter())
.collect::<AHashMap<_, _>>(),
None => self.headers().inner.iter().collect::<AHashMap<_, _>>(),
};

// write headers
// write connection header
self.write_connection_header(&headers, conn_type, version, dst);

// write headers
let mut has_date = false;

let mut buf = dst.chunk_mut().as_mut_ptr();
Expand All @@ -141,7 +135,7 @@ pub(crate) trait MessageType: Sized {
// container's knowledge, this is used to sync the containers cursor after data is written
let mut pos = 0;

self.write_headers(|key, value| {
self.write_headers(&headers, |key, value| {
match *key {
CONNECTION => return,
TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => return,
Expand Down Expand Up @@ -221,22 +215,54 @@ pub(crate) trait MessageType: Sized {
Ok(())
}

fn write_headers<F>(&mut self, mut f: F)
fn write_connection_header<B: BufMut>(
&self,
headers: &AHashMap<&HeaderName, &Value>,
conn_type: ConnectionType,
version: Version,
buf: &mut B,
) {
let camel_case = self.camel_case();

if let Some(header_value) = headers.get(&CONNECTION) {
if camel_case {
buf.put_slice(b"Connection: ");
} else {
buf.put_slice(b"connection: ");
}
for val in header_value.iter() {
buf.put_slice(val.as_ref());
}
buf.put_slice(b"\r\n");
return;
}

// Connection
match conn_type {
ConnectionType::Upgrade => buf.put_slice(b"connection: upgrade\r\n"),
ConnectionType::KeepAlive if version < Version::HTTP_11 => {
if camel_case {
buf.put_slice(b"Connection: keep-alive\r\n")
} else {
buf.put_slice(b"connection: keep-alive\r\n")
}
}
ConnectionType::Close if version >= Version::HTTP_11 => {
if camel_case {
buf.put_slice(b"Connection: close\r\n")
} else {
buf.put_slice(b"connection: close\r\n")
}
}
_ => {}
}
}

fn write_headers<F>(&self, headers: &AHashMap<&HeaderName, &Value>, mut f: F)
where
F: FnMut(&HeaderName, &Value),
{
match self.extra_headers() {
Some(headers) => {
// merging headers from head and extra headers.
self.headers()
.inner
.iter()
.filter(|(name, _)| !headers.contains_key(*name))
.chain(headers.inner.iter())
.for_each(|(k, v)| f(k, v))
}
None => self.headers().inner.iter().for_each(|(k, v)| f(k, v)),
}
headers.iter().for_each(|(key, value)| f(key, value));
}
}

Expand Down Expand Up @@ -668,4 +694,61 @@ mod tests {
assert!(!data.contains("content-length: 0\r\n"));
assert!(!data.contains("transfer-encoding: chunked\r\n"));
}

#[actix_rt::test]
async fn test_close_connection_header_even_keep_alive_was_provided() {
let mut bytes = BytesMut::with_capacity(2048);

let mut res = Response::with_body(StatusCode::OK, ());
res.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("close"));

let _ = res.encode_headers(
&mut bytes,
Version::HTTP_11,
BodySize::Stream,
ConnectionType::KeepAlive,
&ServiceConfig::default(),
);
let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
assert!(data.contains("connection: close\r\n"));
}

#[actix_rt::test]
async fn test_keep_alive_connection_header_when_provided() {
let mut bytes = BytesMut::with_capacity(2048);

let mut res = Response::with_body(StatusCode::OK, ());
res.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("keep-alive"));

let _ = res.encode_headers(
&mut bytes,
Version::HTTP_11,
BodySize::Stream,
ConnectionType::KeepAlive,
&ServiceConfig::default(),
);
let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
assert!(data.contains("connection: keep-alive\r\n"));
}

#[actix_rt::test]
async fn test_keep_alive_connection_header_even_close_was_provided() {
let mut bytes = BytesMut::with_capacity(2048);

let mut res = Response::with_body(StatusCode::OK, ());
res.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("keep-alive"));

let _ = res.encode_headers(
&mut bytes,
Version::HTTP_11,
BodySize::Stream,
ConnectionType::Close,
&ServiceConfig::default(),
);
let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
assert!(data.contains("connection: keep-alive\r\n"));
}
}