diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index f01e79418aa..6e9846c72da 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -69,11 +69,11 @@ process = [ "mio/net", "signal-hook-registry", "winapi/handleapi", + "winapi/minwindef", "winapi/processthreadsapi", "winapi/threadpoollegacyapiset", "winapi/winbase", "winapi/winnt", - "winapi/minwindef", ] # Includes basic task execution capabilities rt = ["once_cell"] diff --git a/tokio/src/io/blocking.rs b/tokio/src/io/blocking.rs index 1d79ee7a279..f6db4500af1 100644 --- a/tokio/src/io/blocking.rs +++ b/tokio/src/io/blocking.rs @@ -34,8 +34,9 @@ enum State { Busy(sys::Blocking<(io::Result, Buf, T)>), } -cfg_io_std! { +cfg_io_blocking! { impl Blocking { + #[cfg_attr(feature = "fs", allow(dead_code))] pub(crate) fn new(inner: T) -> Blocking { Blocking { inner: Some(inner), diff --git a/tokio/src/io/poll_evented.rs b/tokio/src/io/poll_evented.rs index df071897d91..25cece62764 100644 --- a/tokio/src/io/poll_evented.rs +++ b/tokio/src/io/poll_evented.rs @@ -136,7 +136,7 @@ impl PollEvented { } feature! { - #![any(feature = "net", feature = "process")] + #![any(feature = "net", all(unix, feature = "process"))] use crate::io::ReadBuf; use std::task::{Context, Poll}; diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs index 4d9678c5df0..e8ddf16b39e 100644 --- a/tokio/src/lib.rs +++ b/tokio/src/lib.rs @@ -431,7 +431,12 @@ cfg_process! { pub mod process; } -#[cfg(any(feature = "net", feature = "fs", feature = "io-std"))] +#[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + all(windows, feature = "process"), +))] mod blocking; cfg_rt! { diff --git a/tokio/src/macros/cfg.rs b/tokio/src/macros/cfg.rs index fd3556c302e..f728722132d 100644 --- a/tokio/src/macros/cfg.rs +++ b/tokio/src/macros/cfg.rs @@ -70,7 +70,11 @@ macro_rules! cfg_fs { macro_rules! cfg_io_blocking { ($($item:item)*) => { - $( #[cfg(any(feature = "io-std", feature = "fs"))] $item )* + $( #[cfg(any( + feature = "io-std", + feature = "fs", + all(windows, feature = "process"), + ))] $item )* } } @@ -79,12 +83,12 @@ macro_rules! cfg_io_driver { $( #[cfg(any( feature = "net", - feature = "process", + all(unix, feature = "process"), all(unix, feature = "signal"), ))] #[cfg_attr(docsrs, doc(cfg(any( feature = "net", - feature = "process", + all(unix, feature = "process"), all(unix, feature = "signal"), ))))] $item @@ -97,7 +101,7 @@ macro_rules! cfg_io_driver_impl { $( #[cfg(any( feature = "net", - feature = "process", + all(unix, feature = "process"), all(unix, feature = "signal"), ))] $item @@ -110,7 +114,7 @@ macro_rules! cfg_not_io_driver { $( #[cfg(not(any( feature = "net", - feature = "process", + all(unix, feature = "process"), all(unix, feature = "signal"), )))] $item diff --git a/tokio/src/process/mod.rs b/tokio/src/process/mod.rs index 719fdeee6a1..e5ee5db2ba0 100644 --- a/tokio/src/process/mod.rs +++ b/tokio/src/process/mod.rs @@ -1285,41 +1285,39 @@ impl ChildStderr { impl AsyncWrite for ChildStdin { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.inner.poll_write(cx, buf) + Pin::new(&mut self.inner).poll_write(cx, buf) } - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) } - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) } } impl AsyncRead for ChildStdout { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - // Safety: pipes support reading into uninitialized memory - unsafe { self.inner.poll_read(cx, buf) } + Pin::new(&mut self.inner).poll_read(cx, buf) } } impl AsyncRead for ChildStderr { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - // Safety: pipes support reading into uninitialized memory - unsafe { self.inner.poll_read(cx, buf) } + Pin::new(&mut self.inner).poll_read(cx, buf) } } diff --git a/tokio/src/process/unix/mod.rs b/tokio/src/process/unix/mod.rs index 576fe6cb47f..ba34c852b58 100644 --- a/tokio/src/process/unix/mod.rs +++ b/tokio/src/process/unix/mod.rs @@ -29,7 +29,7 @@ use orphan::{OrphanQueue, OrphanQueueImpl, Wait}; mod reap; use reap::Reaper; -use crate::io::PollEvented; +use crate::io::{AsyncRead, AsyncWrite, PollEvented, ReadBuf}; use crate::process::kill::Kill; use crate::process::SpawnedChild; use crate::signal::unix::driver::Handle as SignalHandle; @@ -177,8 +177,8 @@ impl AsRawFd for Pipe { } } -pub(crate) fn convert_to_stdio(io: PollEvented) -> io::Result { - let mut fd = io.into_inner()?.fd; +pub(crate) fn convert_to_stdio(io: ChildStdio) -> io::Result { + let mut fd = io.inner.into_inner()?.fd; // Ensure that the fd to be inherited is set to *blocking* mode, as this // is the default that virtually all programs expect to have. Those @@ -213,7 +213,50 @@ impl Source for Pipe { } } -pub(crate) type ChildStdio = PollEvented; +pub(crate) struct ChildStdio { + inner: PollEvented, +} + +impl fmt::Debug for ChildStdio { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(fmt) + } +} + +impl AsRawFd for ChildStdio { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } +} + +impl AsyncWrite for ChildStdio { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.inner.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl AsyncRead for ChildStdio { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + // Safety: pipes support reading into uninitialized memory + unsafe { self.inner.poll_read(cx, buf) } + } +} fn set_nonblocking(fd: &mut T, nonblocking: bool) -> io::Result<()> { unsafe { @@ -238,7 +281,7 @@ fn set_nonblocking(fd: &mut T, nonblocking: bool) -> io::Result<()> Ok(()) } -pub(super) fn stdio(io: T) -> io::Result> +pub(super) fn stdio(io: T) -> io::Result where T: IntoRawFd, { @@ -246,5 +289,5 @@ where let mut pipe = Pipe::from(io); set_nonblocking(&mut pipe, true)?; - PollEvented::new(pipe) + PollEvented::new(pipe).map(|inner| ChildStdio { inner }) } diff --git a/tokio/src/process/windows.rs b/tokio/src/process/windows.rs index 136d5b0cab8..651233ba376 100644 --- a/tokio/src/process/windows.rs +++ b/tokio/src/process/windows.rs @@ -15,22 +15,22 @@ //! `RegisterWaitForSingleObject` and then wait on the other end of the oneshot //! from then on out. -use crate::io::PollEvented; +use crate::io::{blocking::Blocking, AsyncRead, AsyncWrite, ReadBuf}; use crate::process::kill::Kill; use crate::process::SpawnedChild; use crate::sync::oneshot; -use mio::windows::NamedPipe; use std::fmt; +use std::fs::File as StdFile; use std::future::Future; use std::io; -use std::os::windows::prelude::{AsRawHandle, FromRawHandle, IntoRawHandle, RawHandle}; +use std::os::windows::prelude::{AsRawHandle, IntoRawHandle, RawHandle}; use std::pin::Pin; use std::process::Stdio; use std::process::{Child as StdChild, Command as StdCommand, ExitStatus}; use std::ptr; -use std::task::Context; -use std::task::Poll; +use std::sync::Arc; +use std::task::{Context, Poll}; use winapi::shared::minwindef::{DWORD, FALSE}; use winapi::um::handleapi::{DuplicateHandle, INVALID_HANDLE_VALUE}; use winapi::um::processthreadsapi::GetCurrentProcess; @@ -167,28 +167,97 @@ unsafe extern "system" fn callback(ptr: PVOID, _timer_fired: BOOLEAN) { let _ = complete.take().unwrap().send(()); } -pub(crate) type ChildStdio = PollEvented; +#[derive(Debug)] +struct ArcFile(Arc); -pub(super) fn stdio(io: T) -> io::Result> +impl io::Read for ArcFile { + fn read(&mut self, bytes: &mut [u8]) -> io::Result { + (&*self.0).read(bytes) + } +} + +impl io::Write for ArcFile { + fn write(&mut self, bytes: &[u8]) -> io::Result { + (&*self.0).write(bytes) + } + + fn flush(&mut self) -> io::Result<()> { + (&*self.0).flush() + } +} + +#[derive(Debug)] +pub(crate) struct ChildStdio { + // Used for accessing the raw handle, even if the io version is busy + raw: Arc, + // For doing I/O operations asynchronously + io: Blocking, +} + +impl AsRawHandle for ChildStdio { + fn as_raw_handle(&self) -> RawHandle { + self.raw.as_raw_handle() + } +} + +impl AsyncRead for ChildStdio { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.io).poll_read(cx, buf) + } +} + +impl AsyncWrite for ChildStdio { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.io).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_shutdown(cx) + } +} + +pub(super) fn stdio(io: T) -> io::Result where T: IntoRawHandle, { - let pipe = unsafe { NamedPipe::from_raw_handle(io.into_raw_handle()) }; - PollEvented::new(pipe) + use std::os::windows::prelude::FromRawHandle; + + let raw = Arc::new(unsafe { StdFile::from_raw_handle(io.into_raw_handle()) }); + let io = Blocking::new(ArcFile(raw.clone())); + Ok(ChildStdio { raw, io }) +} + +pub(crate) fn convert_to_stdio(child_stdio: ChildStdio) -> io::Result { + let ChildStdio { raw, io } = child_stdio; + drop(io); // Try to drop the Arc count here + + Arc::try_unwrap(raw) + .or_else(|raw| duplicate_handle(&*raw)) + .map(Stdio::from) } -pub(crate) fn convert_to_stdio(io: PollEvented) -> io::Result { - let named_pipe = io.into_inner()?; +fn duplicate_handle(io: &T) -> io::Result { + use std::os::windows::prelude::FromRawHandle; - // Mio does not implement `IntoRawHandle` for `NamedPipe`, so we'll manually - // duplicate the handle here... unsafe { let mut dup_handle = INVALID_HANDLE_VALUE; let cur_proc = GetCurrentProcess(); let status = DuplicateHandle( cur_proc, - named_pipe.as_raw_handle(), + io.as_raw_handle(), cur_proc, &mut dup_handle, 0 as DWORD, @@ -200,6 +269,6 @@ pub(crate) fn convert_to_stdio(io: PollEvented) -> io::Result return Err(io::Error::last_os_error()); } - Ok(Stdio::from_raw_handle(dup_handle)) + Ok(StdFile::from_raw_handle(dup_handle)) } } diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index 8f477a94004..08c0bbd3e32 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -272,7 +272,11 @@ impl Builder { /// .unwrap(); /// ``` pub fn enable_all(&mut self) -> &mut Self { - #[cfg(any(feature = "net", feature = "process", all(unix, feature = "signal")))] + #[cfg(any( + feature = "net", + all(unix, feature = "process"), + all(unix, feature = "signal") + ))] self.enable_io(); #[cfg(feature = "time")] self.enable_time(); diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 961668c0a74..075792a3077 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -23,7 +23,11 @@ pub struct Handle { pub(crate) struct HandleInner { /// Handles to the I/O drivers #[cfg_attr( - not(any(feature = "net", feature = "process", all(unix, feature = "signal"))), + not(any( + feature = "net", + all(unix, feature = "process"), + all(unix, feature = "signal"), + )), allow(dead_code) )] pub(super) io_handle: driver::IoHandle,