This repository has been archived by the owner on Apr 25, 2023. It is now read-only.

Enabling statement cache for PostgreSQL
Julius de Bruijn committed Jun 17, 2020
1 parent 01ee7d7 commit 58631ce
Showing 5 changed files with 87 additions and 14 deletions.
4 changes: 2 additions & 2 deletions .envrc
@@ -1,3 +1,3 @@
-export TEST_MYSQL=mysql://root:prisma@localhost:3306/prisma
-export TEST_PSQL=postgres://postgres:prisma@localhost:5432/postgres
+export TEST_MYSQL="mysql://root:prisma@localhost:3306/prisma"
+export TEST_PSQL="postgres://postgres:prisma@localhost:5432/postgres"
export TEST_MSSQL="sqlserver://localhost:1433;database=master;user=SA;password=<YourStrong@Passw0rd>;trustServerCertificate=true"
14 changes: 13 additions & 1 deletion Cargo.toml
@@ -36,10 +36,21 @@ single-mysql = ["mysql", "json-1", "uuid-0_8", "chrono-0_4"]
single-sqlite = ["sqlite", "json-1", "uuid-0_8", "chrono-0_4"]
single-mssql = ["mssql"]

+postgresql = [
+    "rust_decimal/tokio-pg",
+    "native-tls",
+    "tokio-postgres",
+    "postgres-native-tls",
+    "array",
+    "bytes",
+    "tokio",
+    "bit-vec",
+    "lru-cache"
+]
+
pooled = ["mobc"]
sqlite = ["rusqlite", "libsqlite3-sys", "tokio/sync"]
json-1 = ["serde_json", "base64"]
postgresql = ["rust_decimal/tokio-pg", "native-tls", "tokio-postgres", "postgres-native-tls", "array", "bytes", "tokio", "bit-vec"]
uuid-0_8 = ["uuid"]
chrono-0_4 = ["chrono"]
mysql = ["mysql_async", "tokio"]
@@ -64,6 +75,7 @@ uuid = { version = "0.8", optional = true }
chrono = { version = "0.4", optional = true }
serde_json = { version = "1.0.48", optional = true }
base64 = { version = "0.11.0", optional = true }
+lru-cache = { version = "0.1", optional = true }

rusqlite = { version = "0.21", features = ["chrono", "bundled"], optional = true }
libsqlite3-sys = { version = "0.17", default-features = false, features = ["bundled"], optional = true }
71 changes: 60 additions & 11 deletions src/connector/postgres.rs
@@ -9,6 +9,7 @@ use crate::{
};
use async_trait::async_trait;
use futures::{future::FutureExt, lock::Mutex};
+use lru_cache::LruCache;
use native_tls::{Certificate, Identity, TlsConnector};
use percent_encoding::percent_decode;
use postgres_native_tls::MakeTlsConnector;
@@ -19,7 +20,7 @@ use std::{
time::Duration,
};
use tokio::time::timeout;
-use tokio_postgres::{config::SslMode, Client, Config};
+use tokio_postgres::{config::SslMode, Client, Config, Statement};
use url::Url;

pub(crate) const DEFAULT_SCHEMA: &str = "public";
@@ -33,7 +34,7 @@ impl<T> std::fmt::Debug for Hidden<T> {
}
}

-struct PostgresClient(Mutex<Client>);
+struct PostgresClient(Client);

impl std::fmt::Debug for PostgresClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -47,6 +48,7 @@ pub struct PostgreSql {
client: PostgresClient,
pg_bouncer: bool,
socket_timeout: Option<Duration>,
+statement_cache: Mutex<LruCache<String, Statement>>,
}

#[derive(Debug, Clone, Copy, PartialEq)]
@@ -214,6 +216,14 @@ impl PostgresUrl {
self.query_params.connect_timeout
}

+pub(crate) fn cache(&self) -> LruCache<String, Statement> {
+    if self.query_params.pg_bouncer {
+        LruCache::new(0)
+    } else {
+        LruCache::new(self.query_params.statement_cache_size)
+    }
+}
+
fn parse_query_params(url: &Url) -> Result<PostgresUrlQueryParams, Error> {
let mut connection_limit = None;
let mut schema = String::from(DEFAULT_SCHEMA);
@@ -226,6 +236,7 @@ impl PostgresUrl {
let mut socket_timeout = None;
let mut connect_timeout = None;
let mut pg_bouncer = false;
+let mut statement_cache_size = 0;

for (k, v) in url.query_pairs() {
match k.as_ref() {
@@ -256,6 +267,11 @@
"sslpassword" => {
identity_password = Some(v.to_string());
}
"statement_cache_size" => {
statement_cache_size = v
.parse()
.map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
}
"sslaccept" => {
match v.as_ref() {
"strict" => {
@@ -324,6 +340,7 @@ impl PostgresUrl {
connect_timeout,
socket_timeout,
pg_bouncer,
+statement_cache_size,
})
}

@@ -365,6 +382,7 @@ pub(crate) struct PostgresUrlQueryParams {
host: Option<String>,
socket_timeout: Option<Duration>,
connect_timeout: Option<Duration>,
+statement_cache_size: usize,
}

impl PostgreSql {
@@ -423,9 +441,10 @@ impl PostgreSql {
client.simple_query(session_variables.as_str()).await?;

Ok(Self {
-client: PostgresClient(Mutex::new(client)),
+client: PostgresClient(client),
socket_timeout: url.query_params.socket_timeout,
pg_bouncer: url.query_params.pg_bouncer,
+statement_cache: Mutex::new(url.cache()),
})
}

@@ -446,6 +465,39 @@
},
}
}

+async fn fetch_cached(&self, sql: &str) -> crate::Result<Statement> {
+    let mut cache = self.statement_cache.lock().await;
+
+    match cache.get_mut(sql) {
+        Some(stmt) => {
+            #[cfg(not(feature = "tracing-log"))]
+            {
+                trace!("CACHE HIT: \"{}\"", sql);
+            }
+            #[cfg(feature = "tracing-log")]
+            {
+                tracing::trace!("CACHE HIT: \"{}\"", sql);
+            }
+
+            Ok(stmt.clone()) // arc'd
+        }
+        None => {
+            #[cfg(not(feature = "tracing-log"))]
+            {
+                trace!("CACHE MISS: \"{}\"", sql);
+            }
+            #[cfg(feature = "tracing-log")]
+            {
+                tracing::trace!("CACHE MISS: \"{}\"", sql);
+            }
+
+            let stmt = self.timeout(self.client.0.prepare(sql)).await?;
+            cache.insert(sql.to_string(), stmt.clone());
+            Ok(stmt)
+        }
+    }
+}
}

impl TransactionCapable for PostgreSql {}
@@ -464,11 +516,10 @@ impl Queryable for PostgreSql {

async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
metrics::query("postgres.query_raw", sql, params, move || async move {
-let client = self.client.0.lock().await;
-let stmt = self.timeout(client.prepare(sql)).await?;
+let stmt = self.fetch_cached(sql).await?;

let rows = self
-.timeout(client.query(&stmt, conversion::conv_params(params).as_slice()))
+.timeout(self.client.0.query(&stmt, conversion::conv_params(params).as_slice()))
.await?;

let mut result = ResultSet::new(stmt.to_column_names(), Vec::new());
@@ -484,11 +535,10 @@

async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
metrics::query("postgres.execute_raw", sql, params, move || async move {
-let client = self.client.0.lock().await;
-let stmt = self.timeout(client.prepare(sql)).await?;
+let stmt = self.fetch_cached(sql).await?;

let changes = self
-.timeout(client.execute(&stmt, conversion::conv_params(params).as_slice()))
+.timeout(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice()))
.await?;

Ok(changes)
@@ -498,8 +548,7 @@

async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
metrics::query("postgres.raw_cmd", cmd, &[], move || async move {
-let client = self.client.0.lock().await;
-self.timeout(client.simple_query(cmd)).await?;
+self.timeout(self.client.0.simple_query(cmd)).await?;

Ok(())
})
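A note on the pgBouncer interplay above: `cache()` hands back a zero-capacity LRU whenever `pg_bouncer` is set, and a zero-capacity `lru-cache` map never retains an entry, so `fetch_cached` always falls through to `prepare`. Below is a minimal sketch of that behaviour; it is not part of the commit and assumes `lru-cache` 0.1 evicts down to capacity on every insert.

use lru_cache::LruCache;

fn main() {
    // pgBouncer mode: capacity 0, so every insert is evicted immediately.
    let mut disabled: LruCache<String, &str> = LruCache::new(0);
    disabled.insert("SELECT 1".into(), "prepared statement");
    assert!(disabled.get_mut("SELECT 1").is_none()); // always a cache miss

    // Regular mode: the capacity comes from `statement_cache_size`.
    let mut enabled: LruCache<String, &str> = LruCache::new(500);
    enabled.insert("SELECT 1".into(), "prepared statement");
    assert!(enabled.get_mut("SELECT 1").is_some()); // cache hit on re-use
}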
3 changes: 3 additions & 0 deletions src/pooled.rs
@@ -58,6 +58,9 @@
//! transaction, a deallocation query `DEALLOCATE ALL` is executed right after
//! `BEGIN` to avoid possible collisions with statements created in other
//! sessions.
+//! - `statement_cache_size`, number of prepared statements kept cached.
+//! Defaults to 0, which means caching is off. If `pgbouncer` mode is enabled,
+//! caching is always off.
//!
//! ## MySQL
//!
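The `statement_cache_size` option documented above is read from the connection string, so for the pooled connector the cache is configured entirely through the URL. A rough usage sketch follows; it is not part of the commit, `Quaint::builder` and `check_out` are the pooled API as it existed around this release, and the URL is illustrative.

use quaint::{pooled::Quaint, prelude::Queryable};

async fn run() -> Result<(), quaint::error::Error> {
    // Keep up to 500 prepared statements per connection.
    let url = "postgresql://postgres:prisma@localhost:5432/postgres?statement_cache_size=500";
    let pool = Quaint::builder(url)?.build();

    let conn = pool.check_out().await?;
    conn.query_raw("SELECT 1", &[]).await?;

    // Adding `pgbouncer=true` to the URL runs the same code, but the cache stays empty.
    Ok(())
}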
9 changes: 9 additions & 0 deletions src/single.rs
@@ -71,6 +71,15 @@ impl Quaint {
/// - `connect_timeout` defined in seconds (default: 5). Connecting to a
/// database will return a `ConnectTimeout` error if taking more than the
/// defined value.
+/// - `pgbouncer` either `true` or `false`. If set, allows usage with the
+/// pgBouncer connection pool in transaction mode. Additionally a transaction
+/// is required for every query for the mode to work. When starting a new
+/// transaction, a deallocation query `DEALLOCATE ALL` is executed right after
+/// `BEGIN` to avoid possible collisions with statements created in other
+/// sessions.
+/// - `statement_cache_size`, number of prepared statements kept cached.
+/// Defaults to 0, which means caching is off. If `pgbouncer` mode is enabled,
+/// caching is always off.
///
/// MySQL:
///
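On a single connection, the same option is where the new `fetch_cached` path pays off: the first execution of a SQL string prepares and caches the statement, and later executions of the identical string re-use it. A hedged sketch follows; it is not part of the commit, `single::Quaint::new` and the `Queryable` trait are used as they existed around this commit, and the URL is illustrative.

use quaint::{single::Quaint, prelude::Queryable};

async fn run() -> Result<(), quaint::error::Error> {
    let url = "postgresql://postgres:prisma@localhost:5432/postgres?statement_cache_size=100";
    let conn = Quaint::new(url).await?;

    // First call: CACHE MISS, the statement is prepared and inserted into the LRU.
    conn.query_raw("SELECT 1", &[]).await?;
    // Second call with the identical SQL string: CACHE HIT, no extra prepare round-trip.
    conn.query_raw("SELECT 1", &[]).await?;

    Ok(())
}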
