Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net: implement UnixSocket #6290

Merged
merged 12 commits into from
Jan 22, 2024
1 change: 1 addition & 0 deletions tokio/src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ cfg_net_unix! {
pub use unix::datagram::socket::UnixDatagram;
pub use unix::listener::UnixListener;
pub use unix::stream::UnixStream;
pub use unix::socket::UnixSocket;
}

cfg_net_windows! {
Expand Down
10 changes: 10 additions & 0 deletions tokio/src/net/unix/datagram/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ cfg_net_unix! {
}

impl UnixDatagram {
pub(crate) fn from_mio(sys: mio::net::UnixDatagram) -> io::Result<UnixDatagram> {
let datagram = UnixDatagram::new(sys)?;

if let Some(e) = datagram.io.take_error()? {
return Err(e);
}

Ok(datagram)
}

/// Waits for any of the requested ready states.
///
/// This function is usually paired with `try_recv()` or `try_send()`. It
Expand Down
5 changes: 5 additions & 0 deletions tokio/src/net/unix/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ cfg_net_unix! {
}

impl UnixListener {
pub(crate) fn new(listener: mio::net::UnixListener) -> io::Result<UnixListener> {
let io = PollEvented::new(listener)?;
Ok(UnixListener { io })
}

/// Creates a new `UnixListener` bound to the specified path.
///
/// # Panics
Expand Down
2 changes: 2 additions & 0 deletions tokio/src/net/unix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ pub mod datagram;

pub(crate) mod listener;

pub(crate) mod socket;

mod split;
pub use split::{ReadHalf, WriteHalf};

Expand Down
271 changes: 271 additions & 0 deletions tokio/src/net/unix/socket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
use std::io;
use std::path::Path;

use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd};

use crate::net::{UnixDatagram, UnixListener, UnixStream};

cfg_net_unix! {
/// A Unix socket that has not yet been converted to a [`UnixStream`], [`UnixDatagram`], or
/// [`UnixListener`].
///
/// `UnixSocket` wraps an operating system socket and enables the caller to
/// configure the socket before establishing a connection or accepting
/// inbound connections. The caller is able to set socket option and explicitly
/// bind the socket with a socket address.
///
/// The underlying socket is closed when the `UnixSocket` value is dropped.
///
/// `UnixSocket` should only be used directly if the default configuration used
/// by [`UnixStream::connect`], [`UnixDatagram::bind`], and [`UnixListener::bind`]
/// does not meet the required use case.
///
/// Calling `UnixStream::connect(path)` effectively performs the same function as:
///
/// ```no_run
/// use tokio::net::UnixSocket;
/// use std::error::Error;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn Error>> {
/// let dir = tempfile::tempdir().unwrap();
/// let path = dir.path().join("bind_path");
/// let socket = UnixSocket::new_stream()?;
///
/// let stream = socket.connect(path).await?;
///
/// Ok(())
/// }
/// ```
///
/// Calling `UnixDatagram::bind(path)` effectively performs the same function as:
///
/// ```no_run
/// use tokio::net::UnixSocket;
/// use std::error::Error;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn Error>> {
/// let dir = tempfile::tempdir().unwrap();
/// let path = dir.path().join("bind_path");
/// let socket = UnixSocket::new_datagram()?;
/// socket.bind(path)?;
///
/// let datagram = socket.datagram()?;
///
/// Ok(())
/// }
/// ```
///
/// Calling `UnixListener::bind(path)` effectively performs the same function as:
///
/// ```no_run
/// use tokio::net::UnixSocket;
/// use std::error::Error;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn Error>> {
/// let dir = tempfile::tempdir().unwrap();
/// let path = dir.path().join("bind_path");
/// let socket = UnixSocket::new_stream()?;
/// socket.bind(path)?;
///
/// let listener = socket.listen(1024)?;
///
/// Ok(())
/// }
/// ```
///
/// Setting socket options not explicitly provided by `UnixSocket` may be done by
/// accessing the [`RawFd`]/[`RawSocket`] using [`AsRawFd`]/[`AsRawSocket`] and
/// setting the option with a crate like [`socket2`].
///
/// [`RawFd`]: std::os::fd::RawFd
/// [`RawSocket`]: https://doc.rust-lang.org/std/os/windows/io/type.RawSocket.html
/// [`AsRawFd`]: std::os::fd::AsRawFd
/// [`AsRawSocket`]: https://doc.rust-lang.org/std/os/windows/io/trait.AsRawSocket.html
/// [`socket2`]: https://docs.rs/socket2/
#[derive(Debug)]
pub struct UnixSocket {
inner: socket2::Socket,
}
}

impl UnixSocket {
fn ty(&self) -> socket2::Type {
self.inner.r#type().unwrap()
}

/// Creates a new Unix datagram socket.
///
/// Calls `socket(2)` with `AF_UNIX` and `SOCK_DGRAM`.
///
/// # Returns
///
/// On success, the newly created [`UnixSocket`] is returned. If an error is
/// encountered, it is returned instead.
pub fn new_datagram() -> io::Result<UnixSocket> {
UnixSocket::new(socket2::Type::DGRAM)
}

/// Creates a new Unix stream socket.
///
/// Calls `socket(2)` with `AF_UNIX` and `SOCK_STREAM`.
///
/// # Returns
///
/// On success, the newly created [`UnixSocket`] is returned. If an error is
/// encountered, it is returned instead.
pub fn new_stream() -> io::Result<UnixSocket> {
UnixSocket::new(socket2::Type::STREAM)
}

fn new(ty: socket2::Type) -> io::Result<UnixSocket> {
#[cfg(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "linux",
target_os = "netbsd",
target_os = "openbsd"
))]
let ty = ty.nonblocking();
let inner = socket2::Socket::new(socket2::Domain::UNIX, ty, None)?;
#[cfg(not(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "linux",
target_os = "netbsd",
target_os = "openbsd"
)))]
inner.set_nonblocking(true)?;
Ok(UnixSocket { inner })
}

/// Binds the socket to the given address.
///
/// This calls the `bind(2)` operating-system function.
pub fn bind(&self, path: impl AsRef<Path>) -> io::Result<()> {
let addr = socket2::SockAddr::unix(path)?;
self.inner.bind(&addr)
}

/// Converts the socket into a `UnixListener`.
///
/// `backlog` defines the maximum number of pending connections are queued
/// by the operating system at any given time. Connection are removed from
/// the queue with [`UnixListener::accept`]. When the queue is full, the
/// operating-system will start rejecting connections.
///
/// Calling this function on a socket created by [`new_datagram`] will return an error.
///
/// This calls the `listen(2)` operating-system function, marking the socket
/// as a passive socket.
///
/// [`new_datagram`]: `UnixSocket::new_datagram`
pub fn listen(self, backlog: u32) -> io::Result<UnixListener> {
if self.ty() == socket2::Type::DGRAM {
return Err(io::Error::new(
io::ErrorKind::Other,
"listen cannot be called on a datagram socket",
));
}

self.inner.listen(backlog as i32)?;
let mio = {
use std::os::unix::io::{FromRawFd, IntoRawFd};

let raw_fd = self.inner.into_raw_fd();
unsafe { mio::net::UnixListener::from_raw_fd(raw_fd) }
};

UnixListener::new(mio)
}

/// Establishes a Unix connection with a peer at the specified socket address.
///
/// The `UnixSocket` is consumed. Once the connection is established, a
/// connected [`UnixStream`] is returned. If the connection fails, the
/// encountered error is returned.
///
/// Calling this function on a socket created by [`new_datagram`] will return an error.
///
/// This calls the `connect(2)` operating-system function.
///
/// [`new_datagram`]: `UnixSocket::new_datagram`
pub async fn connect(self, path: impl AsRef<Path>) -> io::Result<UnixStream> {
if self.ty() == socket2::Type::DGRAM {
return Err(io::Error::new(
io::ErrorKind::Other,
"connect cannot be called on a datagram socket",
));
}

let addr = socket2::SockAddr::unix(path)?;
if let Err(err) = self.inner.connect(&addr) {
if err.raw_os_error() != Some(libc::EINPROGRESS) {
return Err(err);
}
}
let mio = {
use std::os::unix::io::{FromRawFd, IntoRawFd};

let raw_fd = self.inner.into_raw_fd();
unsafe { mio::net::UnixStream::from_raw_fd(raw_fd) }
};

UnixStream::connect_mio(mio).await
}

/// Converts the socket into a [`UnixDatagram`].
///
/// Calling this function on a socket created by [`new_stream`] will return an error.
///
/// [`new_stream`]: `UnixSocket::new_stream`
pub fn datagram(self) -> io::Result<UnixDatagram> {
if self.ty() == socket2::Type::STREAM {
return Err(io::Error::new(
io::ErrorKind::Other,
"datagram cannot be called on a stream socket",
));
}
let mio = {
use std::os::unix::io::{FromRawFd, IntoRawFd};

let raw_fd = self.inner.into_raw_fd();
unsafe { mio::net::UnixDatagram::from_raw_fd(raw_fd) }
};

UnixDatagram::from_mio(mio)
}
}

impl AsRawFd for UnixSocket {
fn as_raw_fd(&self) -> RawFd {
self.inner.as_raw_fd()
}
}

impl AsFd for UnixSocket {
fn as_fd(&self) -> BorrowedFd<'_> {
unsafe { BorrowedFd::borrow_raw(self.as_raw_fd()) }
}
}

impl FromRawFd for UnixSocket {
unsafe fn from_raw_fd(fd: RawFd) -> UnixSocket {
let inner = socket2::Socket::from_raw_fd(fd);
UnixSocket { inner }
}
}

impl IntoRawFd for UnixSocket {
fn into_raw_fd(self) -> RawFd {
self.inner.into_raw_fd()
}
}
18 changes: 18 additions & 0 deletions tokio/src/net/unix/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,24 @@ cfg_net_unix! {
}

impl UnixStream {
pub(crate) async fn connect_mio(sys: mio::net::UnixStream) -> io::Result<UnixStream> {
let stream = UnixStream::new(sys)?;

// Once we've connected, wait for the stream to be writable as
// that's when the actual connection has been initiated. Once we're
// writable we check for `take_socket_error` to see if the connect
// actually hit an error or not.
//
// If all that succeeded then we ship everything on up.
poll_fn(|cx| stream.io.registration().poll_write_ready(cx)).await?;

if let Some(e) = stream.io.take_error()? {
return Err(e);
}

Ok(stream)
}

/// Connects to the socket named by `path`.
///
/// This function will create a new Unix socket and connect to the path
Expand Down