Skip to content
This repository has been archived by the owner on Apr 25, 2023. It is now read-only.

Enabling statement cache for PostgreSQL #143

Merged
merged 2 commits into from
Jun 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions .envrc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
export TEST_MYSQL=mysql://root:prisma@localhost:3306/prisma
export TEST_PSQL=postgres://postgres:prisma@localhost:5432/postgres
export TEST_MYSQL="mysql://root:prisma@localhost:3306/prisma"
export TEST_PSQL="postgres://postgres:prisma@localhost:5432/postgres"
export TEST_MSSQL="sqlserver://localhost:1433;database=master;user=SA;password=<YourStrong@Passw0rd>;trustServerCertificate=true"
14 changes: 13 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,21 @@ single-mysql = ["mysql", "json-1", "uuid-0_8", "chrono-0_4"]
single-sqlite = ["sqlite", "json-1", "uuid-0_8", "chrono-0_4"]
single-mssql = ["mssql"]

postgresql = [
"rust_decimal/tokio-pg",
"native-tls",
"tokio-postgres",
"postgres-native-tls",
"array",
"bytes",
"tokio",
"bit-vec",
"lru-cache"
]

pooled = ["mobc"]
sqlite = ["rusqlite", "libsqlite3-sys", "tokio/sync"]
json-1 = ["serde_json", "base64"]
postgresql = ["rust_decimal/tokio-pg", "native-tls", "tokio-postgres", "postgres-native-tls", "array", "bytes", "tokio", "bit-vec"]
uuid-0_8 = ["uuid"]
chrono-0_4 = ["chrono"]
mysql = ["mysql_async", "tokio"]
Expand All @@ -64,6 +75,7 @@ uuid = { version = "0.8", optional = true }
chrono = { version = "0.4", optional = true }
serde_json = { version = "1.0.48", optional = true }
base64 = { version = "0.11.0", optional = true }
lru-cache = { version = "0.1", optional = true }

rusqlite = { version = "0.21", features = ["chrono", "bundled"], optional = true }
libsqlite3-sys = { version = "0.17", default-features = false, features = ["bundled"], optional = true }
Expand Down
90 changes: 79 additions & 11 deletions src/connector/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
};
use async_trait::async_trait;
use futures::{future::FutureExt, lock::Mutex};
use lru_cache::LruCache;
use native_tls::{Certificate, Identity, TlsConnector};
use percent_encoding::percent_decode;
use postgres_native_tls::MakeTlsConnector;
Expand All @@ -19,7 +20,7 @@ use std::{
time::Duration,
};
use tokio::time::timeout;
use tokio_postgres::{config::SslMode, Client, Config};
use tokio_postgres::{config::SslMode, Client, Config, Statement};
use url::Url;

pub(crate) const DEFAULT_SCHEMA: &str = "public";
Expand All @@ -33,7 +34,7 @@ impl<T> std::fmt::Debug for Hidden<T> {
}
}

struct PostgresClient(Mutex<Client>);
struct PostgresClient(Client);

impl std::fmt::Debug for PostgresClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand All @@ -47,6 +48,7 @@ pub struct PostgreSql {
client: PostgresClient,
pg_bouncer: bool,
socket_timeout: Option<Duration>,
statement_cache: Mutex<LruCache<String, Statement>>,
}

#[derive(Debug, Clone, Copy, PartialEq)]
Expand Down Expand Up @@ -214,6 +216,14 @@ impl PostgresUrl {
self.query_params.connect_timeout
}

pub(crate) fn cache(&self) -> LruCache<String, Statement> {
if self.query_params.pg_bouncer == true {
LruCache::new(0)
} else {
LruCache::new(self.query_params.statement_cache_size)
}
}

fn parse_query_params(url: &Url) -> Result<PostgresUrlQueryParams, Error> {
let mut connection_limit = None;
let mut schema = String::from(DEFAULT_SCHEMA);
Expand All @@ -226,6 +236,7 @@ impl PostgresUrl {
let mut socket_timeout = None;
let mut connect_timeout = None;
let mut pg_bouncer = false;
let mut statement_cache_size = 500;

for (k, v) in url.query_pairs() {
match k.as_ref() {
Expand Down Expand Up @@ -256,6 +267,11 @@ impl PostgresUrl {
"sslpassword" => {
identity_password = Some(v.to_string());
}
"statement_cache_size" => {
statement_cache_size = v
.parse()
.map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
}
"sslaccept" => {
match v.as_ref() {
"strict" => {
Expand Down Expand Up @@ -324,6 +340,7 @@ impl PostgresUrl {
connect_timeout,
socket_timeout,
pg_bouncer,
statement_cache_size,
})
}

Expand Down Expand Up @@ -365,6 +382,7 @@ pub(crate) struct PostgresUrlQueryParams {
host: Option<String>,
socket_timeout: Option<Duration>,
connect_timeout: Option<Duration>,
statement_cache_size: usize,
}

impl PostgreSql {
Expand Down Expand Up @@ -423,9 +441,10 @@ impl PostgreSql {
client.simple_query(session_variables.as_str()).await?;

Ok(Self {
client: PostgresClient(Mutex::new(client)),
client: PostgresClient(client),
socket_timeout: url.query_params.socket_timeout,
pg_bouncer: url.query_params.pg_bouncer,
statement_cache: Mutex::new(url.cache()),
})
}

Expand All @@ -446,6 +465,39 @@ impl PostgreSql {
},
}
}

async fn fetch_cached(&self, sql: &str) -> crate::Result<Statement> {
let mut cache = self.statement_cache.lock().await;

match cache.get_mut(sql) {
Some(stmt) => {
#[cfg(not(feature = "tracing-log"))]
{
trace!("CACHE HIT: \"{}\"", sql);
}
#[cfg(feature = "tracing-log")]
{
tracing::trace!("CACHE HIT: \"{}\"", sql);
}

Ok(stmt.clone()) // arc'd
}
None => {
#[cfg(not(feature = "tracing-log"))]
{
trace!("CACHE MISS: \"{}\"", sql);
}
#[cfg(feature = "tracing-log")]
{
tracing::trace!("CACHE MISS: \"{}\"", sql);
}

let stmt = self.timeout(self.client.0.prepare(sql)).await?;
cache.insert(sql.to_string(), stmt.clone());
Ok(stmt)
}
}
}
}

impl TransactionCapable for PostgreSql {}
Expand All @@ -464,11 +516,10 @@ impl Queryable for PostgreSql {

async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
metrics::query("postgres.query_raw", sql, params, move || async move {
let client = self.client.0.lock().await;
let stmt = self.timeout(client.prepare(sql)).await?;
let stmt = self.fetch_cached(sql).await?;

let rows = self
.timeout(client.query(&stmt, conversion::conv_params(params).as_slice()))
.timeout(self.client.0.query(&stmt, conversion::conv_params(params).as_slice()))
.await?;

let mut result = ResultSet::new(stmt.to_column_names(), Vec::new());
Expand All @@ -484,11 +535,10 @@ impl Queryable for PostgreSql {

async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
metrics::query("postgres.execute_raw", sql, params, move || async move {
let client = self.client.0.lock().await;
let stmt = self.timeout(client.prepare(sql)).await?;
let stmt = self.fetch_cached(sql).await?;

let changes = self
.timeout(client.execute(&stmt, conversion::conv_params(params).as_slice()))
.timeout(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice()))
.await?;

Ok(changes)
Expand All @@ -498,8 +548,7 @@ impl Queryable for PostgreSql {

async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
metrics::query("postgres.raw_cmd", cmd, &[], move || async move {
let client = self.client.0.lock().await;
self.timeout(client.simple_query(cmd)).await?;
self.timeout(self.client.0.simple_query(cmd)).await?;

Ok(())
})
Expand Down Expand Up @@ -557,6 +606,25 @@ mod tests {
assert_eq!("/var/run/postgresql", url.host());
}

#[test]
fn should_allow_changing_of_cache_size() {
let url =
PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap();
assert_eq!(420, url.cache().capacity());
}

#[test]
fn should_have_default_cache_size() {
let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap();
assert_eq!(500, url.cache().capacity());
}

#[test]
fn should_not_enable_caching_with_pgbouncer() {
let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap();
assert_eq!(0, url.cache().capacity());
}

#[test]
fn should_parse_default_host() {
let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap();
Expand Down
3 changes: 3 additions & 0 deletions src/pooled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
//! transaction, a deallocation query `DEALLOCATE ALL` is executed right after
//! `BEGIN` to avoid possible collisions with statements created in other
//! sessions.
//! - `statement_cache_size`, number of prepared statements kept cached.
//! Defaults to 500, which means caching is off. If `pgbouncer` mode is enabled,
//! caching is always off.
//!
//! ## MySQL
//!
Expand Down
9 changes: 9 additions & 0 deletions src/single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ impl Quaint {
/// - `connect_timeout` defined in seconds (default: 5). Connecting to a
/// database will return a `ConnectTimeout` error if taking more than the
/// defined value.
/// - `pgbouncer` either `true` or `false`. If set, allows usage with the
/// pgBouncer connection pool in transaction mode. Additionally a transaction
/// is required for every query for the mode to work. When starting a new
/// transaction, a deallocation query `DEALLOCATE ALL` is executed right after
/// `BEGIN` to avoid possible collisions with statements created in other
/// sessions.
/// - `statement_cache_size`, number of prepared statements kept cached.
/// Defaults to 500, which means caching is off. If `pgbouncer` mode is enabled,
/// caching is always off.
///
/// MySQL:
///
Expand Down