Skip to content
This repository has been archived by the owner on Apr 25, 2023. It is now read-only.

Enabling statement cache for PostgreSQL #143

Merged
merged 2 commits into from
Jun 19, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions .envrc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
export TEST_MYSQL=mysql://root:prisma@localhost:3306/prisma
export TEST_PSQL=postgres://postgres:prisma@localhost:5432/postgres
export TEST_MYSQL="mysql://root:prisma@localhost:3306/prisma"
export TEST_PSQL="postgres://postgres:prisma@localhost:5432/postgres"
export TEST_MSSQL="sqlserver://localhost:1433;database=master;user=SA;password=<YourStrong@Passw0rd>;trustServerCertificate=true"
14 changes: 13 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,21 @@ single-mysql = ["mysql", "json-1", "uuid-0_8", "chrono-0_4"]
single-sqlite = ["sqlite", "json-1", "uuid-0_8", "chrono-0_4"]
single-mssql = ["mssql"]

postgresql = [
"rust_decimal/tokio-pg",
"native-tls",
"tokio-postgres",
"postgres-native-tls",
"array",
"bytes",
"tokio",
"bit-vec",
"lru-cache"
]

pooled = ["mobc"]
sqlite = ["rusqlite", "libsqlite3-sys", "tokio/sync"]
json-1 = ["serde_json", "base64"]
postgresql = ["rust_decimal/tokio-pg", "native-tls", "tokio-postgres", "postgres-native-tls", "array", "bytes", "tokio", "bit-vec"]
uuid-0_8 = ["uuid"]
chrono-0_4 = ["chrono"]
mysql = ["mysql_async", "tokio"]
Expand All @@ -64,6 +75,7 @@ uuid = { version = "0.8", optional = true }
chrono = { version = "0.4", optional = true }
serde_json = { version = "1.0.48", optional = true }
base64 = { version = "0.11.0", optional = true }
lru-cache = { version = "0.1", optional = true }

rusqlite = { version = "0.21", features = ["chrono", "bundled"], optional = true }
libsqlite3-sys = { version = "0.17", default-features = false, features = ["bundled"], optional = true }
Expand Down
71 changes: 60 additions & 11 deletions src/connector/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
};
use async_trait::async_trait;
use futures::{future::FutureExt, lock::Mutex};
use lru_cache::LruCache;
use native_tls::{Certificate, Identity, TlsConnector};
use percent_encoding::percent_decode;
use postgres_native_tls::MakeTlsConnector;
Expand All @@ -19,7 +20,7 @@ use std::{
time::Duration,
};
use tokio::time::timeout;
use tokio_postgres::{config::SslMode, Client, Config};
use tokio_postgres::{config::SslMode, Client, Config, Statement};
use url::Url;

pub(crate) const DEFAULT_SCHEMA: &str = "public";
Expand All @@ -33,7 +34,7 @@ impl<T> std::fmt::Debug for Hidden<T> {
}
}

struct PostgresClient(Mutex<Client>);
struct PostgresClient(Client);

impl std::fmt::Debug for PostgresClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand All @@ -47,6 +48,7 @@ pub struct PostgreSql {
client: PostgresClient,
pg_bouncer: bool,
socket_timeout: Option<Duration>,
statement_cache: Mutex<LruCache<String, Statement>>,
}

#[derive(Debug, Clone, Copy, PartialEq)]
Expand Down Expand Up @@ -214,6 +216,14 @@ impl PostgresUrl {
self.query_params.connect_timeout
}

pub(crate) fn cache(&self) -> LruCache<String, Statement> {
if self.query_params.pg_bouncer == true {
LruCache::new(0)
} else {
LruCache::new(self.query_params.statement_cache_size)
}
}

fn parse_query_params(url: &Url) -> Result<PostgresUrlQueryParams, Error> {
let mut connection_limit = None;
let mut schema = String::from(DEFAULT_SCHEMA);
Expand All @@ -226,6 +236,7 @@ impl PostgresUrl {
let mut socket_timeout = None;
let mut connect_timeout = None;
let mut pg_bouncer = false;
let mut statement_cache_size = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll probably figure things out as we start using it, but for other quaint users, if it works we may want to default to a non-zero size.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to enable this to everybody by default. Should be opt-in.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no other way to reuse prepared statements currently, so this should at least to be well documented.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. Making the default here to 500. Let's see...


for (k, v) in url.query_pairs() {
match k.as_ref() {
Expand Down Expand Up @@ -256,6 +267,11 @@ impl PostgresUrl {
"sslpassword" => {
identity_password = Some(v.to_string());
}
"statement_cache_size" => {
statement_cache_size = v
.parse()
.map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
}
"sslaccept" => {
match v.as_ref() {
"strict" => {
Expand Down Expand Up @@ -324,6 +340,7 @@ impl PostgresUrl {
connect_timeout,
socket_timeout,
pg_bouncer,
statement_cache_size,
})
}

Expand Down Expand Up @@ -365,6 +382,7 @@ pub(crate) struct PostgresUrlQueryParams {
host: Option<String>,
socket_timeout: Option<Duration>,
connect_timeout: Option<Duration>,
statement_cache_size: usize,
}

impl PostgreSql {
Expand Down Expand Up @@ -423,9 +441,10 @@ impl PostgreSql {
client.simple_query(session_variables.as_str()).await?;

Ok(Self {
client: PostgresClient(Mutex::new(client)),
client: PostgresClient(client),
socket_timeout: url.query_params.socket_timeout,
pg_bouncer: url.query_params.pg_bouncer,
statement_cache: Mutex::new(url.cache()),
})
}

Expand All @@ -446,6 +465,39 @@ impl PostgreSql {
},
}
}

async fn fetch_cached(&self, sql: &str) -> crate::Result<Statement> {
let mut cache = self.statement_cache.lock().await;

match cache.get_mut(sql) {
Some(stmt) => {
#[cfg(not(feature = "tracing-log"))]
{
trace!("CACHE HIT: \"{}\"", sql);
}
#[cfg(feature = "tracing-log")]
{
tracing::trace!("CACHE HIT: \"{}\"", sql);
}

Ok(stmt.clone()) // arc'd
}
None => {
#[cfg(not(feature = "tracing-log"))]
{
trace!("CACHE MISS: \"{}\"", sql);
}
#[cfg(feature = "tracing-log")]
{
tracing::trace!("CACHE MISS: \"{}\"", sql);
}

let stmt = self.timeout(self.client.0.prepare(sql)).await?;
cache.insert(sql.to_string(), stmt.clone());
Ok(stmt)
}
}
}
}

impl TransactionCapable for PostgreSql {}
Expand All @@ -464,11 +516,10 @@ impl Queryable for PostgreSql {

async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
metrics::query("postgres.query_raw", sql, params, move || async move {
let client = self.client.0.lock().await;
let stmt = self.timeout(client.prepare(sql)).await?;
let stmt = self.fetch_cached(sql).await?;

let rows = self
.timeout(client.query(&stmt, conversion::conv_params(params).as_slice()))
.timeout(self.client.0.query(&stmt, conversion::conv_params(params).as_slice()))
.await?;

let mut result = ResultSet::new(stmt.to_column_names(), Vec::new());
Expand All @@ -484,11 +535,10 @@ impl Queryable for PostgreSql {

async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
metrics::query("postgres.execute_raw", sql, params, move || async move {
let client = self.client.0.lock().await;
let stmt = self.timeout(client.prepare(sql)).await?;
let stmt = self.fetch_cached(sql).await?;

let changes = self
.timeout(client.execute(&stmt, conversion::conv_params(params).as_slice()))
.timeout(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice()))
.await?;

Ok(changes)
Expand All @@ -498,8 +548,7 @@ impl Queryable for PostgreSql {

async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
metrics::query("postgres.raw_cmd", cmd, &[], move || async move {
let client = self.client.0.lock().await;
self.timeout(client.simple_query(cmd)).await?;
self.timeout(self.client.0.simple_query(cmd)).await?;

Ok(())
})
Expand Down
3 changes: 3 additions & 0 deletions src/pooled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
//! transaction, a deallocation query `DEALLOCATE ALL` is executed right after
//! `BEGIN` to avoid possible collisions with statements created in other
//! sessions.
//! - `statement_cache_size`, number of prepared statements kept cached.
//! Defaults to 0, which means caching is off. If `pgbouncer` mode is enabled,
//! caching is always off.
//!
//! ## MySQL
//!
Expand Down
9 changes: 9 additions & 0 deletions src/single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ impl Quaint {
/// - `connect_timeout` defined in seconds (default: 5). Connecting to a
/// database will return a `ConnectTimeout` error if taking more than the
/// defined value.
/// - `pgbouncer` either `true` or `false`. If set, allows usage with the
/// pgBouncer connection pool in transaction mode. Additionally a transaction
/// is required for every query for the mode to work. When starting a new
/// transaction, a deallocation query `DEALLOCATE ALL` is executed right after
/// `BEGIN` to avoid possible collisions with statements created in other
/// sessions.
/// - `statement_cache_size`, number of prepared statements kept cached.
/// Defaults to 0, which means caching is off. If `pgbouncer` mode is enabled,
/// caching is always off.
///
/// MySQL:
///
Expand Down