Skip to content

Commit

Permalink
Rework PgCopyBoth interface
Browse files Browse the repository at this point in the history
Use something similar to PgListener
  • Loading branch information
grim7reaper committed Mar 11, 2024
1 parent 0bdb8d9 commit 0d1bcf6
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 205 deletions.
197 changes: 2 additions & 195 deletions sqlx-postgres/src/copy.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
use std::borrow::Cow;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};

use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::future::Either;
use futures_util::stream::Stream;
use sqlx_core::bytes::{BufMut, Bytes};

use crate::connection::PgConnection;
use crate::error::{Error, Result};
use crate::ext::async_stream::TryAsyncStream;
use crate::io::{AsyncRead, AsyncReadExt, Encode};
use crate::io::{AsyncRead, AsyncReadExt};
use crate::message::{
CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
};
use crate::pool::{Pool, PoolConnection};
use crate::Postgres;
use sqlx_core::bytes::{BufMut, Bytes};

impl PgConnection {
/// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
Expand Down Expand Up @@ -362,192 +358,3 @@ async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(

Ok(Box::pin(stream))
}

/// Commands delivered over the internal channel to the background task
/// (`copy_both_handler`) that owns the connection.
enum PgCopyBothCommand {
/// Already-encoded bytes of the statement that switches the connection into
/// CopyBoth mode (built from a `Query` message in `pg_copy_both`).
Begin(Vec<u8>),
/// Payload to forward to the server wrapped in a `CopyData` message.
CopyData(Vec<u8>),
/// Terminates the session. `from_client` is `true` when the request comes
/// from `PgCopyBothSender::finish`, and `false` when the handler generates it
/// in reaction to a `CopyDone` received from the server.
CopyDone { from_client: bool },
}

/// Sending half returned by `pg_copy_both`: wraps the channel used to push
/// commands to the background handler task.
pub struct PgCopyBothSender(flume::Sender<PgCopyBothCommand>);
/// Receiving half returned by `pg_copy_both`: a stream of the items
/// (`CopyData` payloads or errors) published by the background handler task.
pub struct PgCopyBothReceiver(flume::r#async::RecvStream<'static, Result<Bytes>>);

// Open a duplex connection allowing high-speed bulk data transfer to and from the server.
//
// # Example
//
// ```rust,no_run
// # use sqlx::postgres::{
// # pg_copy_both, PgConnectOptions, PgPoolOptions, PgReplicationMode,
// # };
// // Connection must be configured with a replication mode!
// let options = PgConnectOptions::new()
// .host("0.0.0.0")
// .replication_mode(PgReplicationMode::Logical);
//
// let pool = PgPoolOptions::new()
// .connect_with(options)
// .await
// .expect("failed to connect to postgres");
//
// let query = format!(
// r#"START_REPLICATION SLOT "{}" LOGICAL {} ("proto_version" '1', "publication_names" '{}')"#,
// "test_slot", "0/1573178", "test_publication",
// );
// let (copy_tx, mut copy_recv) = pg_copy_both(&pool, query.as_str())
// .await
// .expect("start replication");
// // Read data from the server.
// while let Some(data) = copy_recv.next().await {
// println!("data: {:?}", data);
// // And send some back (e.g. keepalive).
// copy_tx.send(Vec::new()).await?;
// }
// // Connection closed.
// ```
pub async fn pg_copy_both(
pool: &Pool<Postgres>,
statement: &str,
) -> Result<(PgCopyBothSender, PgCopyBothReceiver), Error> {
// Setup upstream/downstream channels.
let (recv_tx, recv_rx) = flume::bounded(1);
let (send_tx, send_rx) = flume::bounded(1);

crate::rt::spawn({
let pool = pool.clone();
async move {
if let Err(err) = copy_both_handler(pool, recv_tx.clone(), send_rx).await {
let _ignored = recv_tx.send_async(Err(err)).await;
}
}
});

// Execute the given statement to switch into CopyBoth mode.
let mut buf = Vec::new();
Query(statement).encode(&mut buf);
send_tx
.send_async(PgCopyBothCommand::Begin(buf))
.await
.map_err(|_err| Error::WorkerCrashed)?;

Ok((
PgCopyBothSender(send_tx),
PgCopyBothReceiver(recv_rx.into_stream()),
))
}

impl PgCopyBothSender {
/// Send a chunk of `COPY` data.
pub async fn send(&self, data: impl Into<Vec<u8>>) -> Result<()> {
self.0
.send_async(PgCopyBothCommand::CopyData(data.into()))
.await
.map_err(|_err| Error::WorkerCrashed)?;

Ok(())
}

/// Signal that the CopyBoth mode is complete.
pub async fn finish(self) -> Result<()> {
self.0
.send_async(PgCopyBothCommand::CopyDone { from_client: true })
.await
.map_err(|_err| Error::WorkerCrashed)?;

Ok(())
}
}

impl Stream for PgCopyBothReceiver {
    type Item = Result<Bytes, Error>;

    /// Forwards the poll to the wrapped channel stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let inner = &mut self.get_mut().0;
        Pin::new(inner).poll_next(cx)
    }
}

/// Background task body for a CopyBoth session.
///
/// Acquires a connection from `pool`, then loops, reacting to whichever comes
/// first: a message read from the server or a command received on `send_rx`.
/// `CopyData` payloads from the server are published on `recv_tx`; commands
/// are translated into protocol messages written to the connection.
///
/// Returns `Ok(())` once a `CopyDone` exchange has been processed. Returns an
/// error on I/O failure, on an unexpected message or out-of-order command, or
/// with `Error::WorkerCrashed` when either channel peer has been dropped.
async fn copy_both_handler(
pool: Pool<Postgres>,
recv_tx: flume::Sender<Result<Bytes>>,
send_rx: flume::Receiver<PgCopyBothCommand>,
) -> Result<()> {
// Set once the `Begin` command has been processed; guards the other commands.
let mut has_started = false;
let mut conn = pool.acquire().await?;
conn.wait_until_ready().await?;

loop {
// Wait for either incoming data or a message to send.
// `command` ends up `None` when the iteration was fully handled by
// forwarding server data, `Some(..)` when there is a command to execute.
let command = match futures_util::future::select(
std::pin::pin!(conn.stream.recv()),
std::pin::pin!(send_rx.recv_async()),
)
.await
{
// The server sent a message first.
Either::Left((data, _)) => {
let message = data?;
match message.format {
MessageFormat::CopyData => {
// The decode result is forwarded as-is, so a decoding failure
// reaches the consumer as an `Err` item instead of ending the task.
recv_tx
.send_async(message.decode::<CopyData<Bytes>>().map(|x| x.0))
.await
.map_err(|_err| Error::WorkerCrashed)?;
None
}
// Server is done sending data, close our side.
MessageFormat::CopyDone => {
let _ = message.decode::<CopyDone>()?;
Some(PgCopyBothCommand::CopyDone { from_client: false })
}
// Any other message format aborts the session with a protocol error.
_ => {
return Err(err_protocol!(
"unexpected message format during copy out: {:?}",
message.format
))
}
}
}
// This only errors if the consumer has been dropped.
// There is no reason to continue.
Either::Right((command, _)) => Some(command.map_err(|_| Error::WorkerCrashed)?),
};

if let Some(command) = command {
match command {
// Start the stream.
PgCopyBothCommand::Begin(buf) => {
if has_started {
return Err(err_protocol!("Copy-Both mode already initiated"));
}
conn.stream.send(buf.as_slice()).await?;
// Consume the server response.
conn.stream
.recv_expect::<CopyResponse>(MessageFormat::CopyBothResponse)
.await?;
has_started = true;
}
// Send data to the server.
PgCopyBothCommand::CopyData(data) => {
if !has_started {
return Err(err_protocol!("connection hasn't been started"));
}
conn.stream.send(CopyData(data)).await?;
}

// Graceful shutdown of the stream.
PgCopyBothCommand::CopyDone { from_client } => {
if !has_started {
return Err(err_protocol!("connection hasn't been started"));
}
conn.stream.send(CopyDone).await?;
// If we are the first to send CopyDone, wait for the server to send its own.
// NOTE(review): presumably `recv_expect` fails on any other message format,
// so server data still in flight at this point would surface as an error — confirm.
if from_client {
conn.stream.recv_expect(MessageFormat::CopyDone).await?;
}
break;
}
}
}
}

Ok(())
}
4 changes: 3 additions & 1 deletion sqlx-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod listener;
mod message;
mod options;
mod query_result;
mod replication;
mod row;
mod statement;
mod transaction;
Expand All @@ -39,13 +40,14 @@ pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey};
pub use arguments::{PgArgumentBuffer, PgArguments};
pub use column::PgColumn;
pub use connection::PgConnection;
pub use copy::{pg_copy_both, PgCopyBothReceiver, PgCopyBothSender, PgCopyIn};
pub use copy::PgCopyIn;
pub use database::Postgres;
pub use error::{PgDatabaseError, PgErrorPosition};
pub use listener::{PgListener, PgNotification};
pub use message::PgSeverity;
pub use options::{PgConnectOptions, PgReplicationMode, PgSslMode};
pub use query_result::PgQueryResult;
pub use replication::{PgCopyBothReceiver, PgCopyBothSender, PgReplication, PgReplicationPool};
pub use row::PgRow;
pub use statement::PgStatement;
pub use transaction::PgTransactionManager;
Expand Down
10 changes: 1 addition & 9 deletions sqlx-postgres/src/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,15 +514,7 @@ impl PgConnectOptions {
/// can be used.
///
/// The default behavior is to disable the replication mode.
///
/// # Example
///
/// ```rust
/// # use sqlx::postgres::{PgConnectOptions, PgReplicationMode};
/// let options = PgConnectOptions::new()
/// .replication_mode(PgReplicationMode::Logical);
/// ```
pub fn replication_mode(mut self, replication_mode: PgReplicationMode) -> Self {
pub(crate) fn replication_mode(mut self, replication_mode: PgReplicationMode) -> Self {
self.replication_mode = Some(replication_mode);
self
}
Expand Down

0 comments on commit 0d1bcf6

Please sign in to comment.