diff --git a/Cargo.lock b/Cargo.lock index 96876295449d..4ae6b067cb73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1950,11 +1950,13 @@ dependencies = [ "hyper", "indexmap", "indoc", + "introspection-core", "itertools", "log", "migration-connector", "migration-core", "once_cell", + "pretty_assertions", "prisma-inflector", "prisma-models", "quaint", @@ -1974,6 +1976,7 @@ dependencies = [ "tokio", "tracing", "tracing-attributes", + "tracing-futures", "tracing-log", "tracing-subscriber", "url 2.1.1", @@ -2107,9 +2110,10 @@ dependencies = [ [[package]] name = "quaint" version = "0.2.0-alpha.9" -source = "git+https://github.com/prisma/quaint#f3c27482fd49bde532339716cce514d764fa9548" +source = "git+https://github.com/prisma/quaint#5c85aed15976ebd8c9004251302424a587c8840b" dependencies = [ "async-trait", + "base64 0.11.0", "bytes", "chrono", "futures 0.3.4", diff --git a/introspection-engine/connectors/sql-introspection-connector/src/misc_helpers.rs b/introspection-engine/connectors/sql-introspection-connector/src/misc_helpers.rs index 2a137cb9efa1..7dd71724c187 100644 --- a/introspection-engine/connectors/sql-introspection-connector/src/misc_helpers.rs +++ b/introspection-engine/connectors/sql-introspection-connector/src/misc_helpers.rs @@ -463,7 +463,7 @@ fn parse_bool(value: &str) -> Option<bool> { static RE_FLOAT: Lazy<Regex> = Lazy::new(|| Regex::new(r"^'?([^']+)'?$").expect("compile regex")); -fn parse_float(value: &str) -> Option<f32> { +fn parse_float(value: &str) -> Option<f64> { debug!("Parsing float '{}'", value); let rslt = RE_FLOAT.captures(value); if rslt.is_none() { @@ -473,7 +473,7 @@ fn parse_float(value: &str) -> Option<f32> { let captures = rslt.expect("get captures"); let num_str = captures.get(1).expect("get capture").as_str(); - let num_rslt = num_str.parse::<f32>(); + let num_rslt = num_str.parse::<f64>(); match num_rslt { Ok(num) => Some(num), Err(_) => { diff --git a/introspection-engine/core/src/lib.rs b/introspection-engine/core/src/lib.rs new file mode 100644 index 000000000000..4b1a69169d4b --- /dev/null +++ b/introspection-engine/core/src/lib.rs @@ -0,0 +1,7 @@ +mod command_error; +mod error; +mod error_rendering; +mod rpc; + +pub use error::Error; +pub use rpc::RpcImpl; diff --git a/introspection-engine/core/src/rpc.rs b/introspection-engine/core/src/rpc.rs index f5189c1b0fc0..390aea6c3d98 100644 --- a/introspection-engine/core/src/rpc.rs +++ b/introspection-engine/core/src/rpc.rs @@ -26,7 +26,7 @@ pub trait Rpc { fn introspect(&self, input: IntrospectionInput) -> RpcFutureResult<String>; } -pub(crate) struct RpcImpl; +pub struct RpcImpl; impl Rpc for RpcImpl { fn list_databases(&self, input: IntrospectionInput) -> RpcFutureResult<Vec<String>> { @@ -47,7 +47,7 @@ impl Rpc for RpcImpl { } impl RpcImpl { - pub(crate) fn new() -> Self { + pub fn new() -> Self { RpcImpl } @@ -63,7 +63,7 @@ impl RpcImpl { Ok(Box::new(SqlIntrospectionConnector::new(&url).await?)) } - pub(crate) async fn introspect_internal(schema: String) -> RpcResult<String> { + pub async fn introspect_internal(schema: String) -> RpcResult<String> { let config = datamodel::parse_configuration(&schema).map_err(Error::from)?; let url = config .datasources @@ -85,17 +85,17 @@ impl RpcImpl { } } - pub(crate) async fn list_databases_internal(schema: String) -> RpcResult<Vec<String>> { + pub async fn list_databases_internal(schema: String) -> RpcResult<Vec<String>> { let connector = RpcImpl::load_connector(&schema).await?; Ok(connector.list_databases().await.map_err(Error::from)?)
} - pub(crate) async fn get_database_description(schema: String) -> RpcResult<String> { + pub async fn get_database_description(schema: String) -> RpcResult<String> { let connector = RpcImpl::load_connector(&schema).await?; Ok(connector.get_database_description().await.map_err(Error::from)?) } - pub(crate) async fn get_database_metadata_internal(schema: String) -> RpcResult<DatabaseMetadata> { + pub async fn get_database_metadata_internal(schema: String) -> RpcResult<DatabaseMetadata> { let connector = RpcImpl::load_connector(&schema).await?; Ok(connector.get_metadata().await.map_err(Error::from)?) } diff --git a/libs/datamodel/connectors/datamodel-connector/src/scalars.rs b/libs/datamodel/connectors/datamodel-connector/src/scalars.rs index 6b556a236fd7..124e540a845d 100644 --- a/libs/datamodel/connectors/datamodel-connector/src/scalars.rs +++ b/libs/datamodel/connectors/datamodel-connector/src/scalars.rs @@ -43,8 +43,8 @@ impl ToString for ScalarType { #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub enum ScalarValue { Int(i32), - Float(f32), - Decimal(f32), + Float(f64), + Decimal(f64), Boolean(bool), String(String), DateTime(DateTime<Utc>), diff --git a/libs/datamodel/core/src/common/value_validator.rs b/libs/datamodel/core/src/common/value_validator.rs index f96e0f77341e..cec388dd8ef8 100644 --- a/libs/datamodel/core/src/common/value_validator.rs +++ b/libs/datamodel/core/src/common/value_validator.rs @@ -115,20 +115,20 @@ impl ValueValidator { } /// Tries to convert the wrapped value to a Prisma Float. - pub fn as_float(&self) -> Result<f32, DatamodelError> { + pub fn as_float(&self) -> Result<f64, DatamodelError> { match &self.value { - ast::Expression::NumericValue(value, _) => self.wrap_error_from_result(value.parse::<f32>(), "numeric"), - ast::Expression::Any(value, _) => self.wrap_error_from_result(value.parse::<f32>(), "numeric"), + ast::Expression::NumericValue(value, _) => self.wrap_error_from_result(value.parse::<f64>(), "numeric"), + ast::Expression::Any(value, _) => self.wrap_error_from_result(value.parse::<f64>(), "numeric"), _ => Err(self.construct_type_mismatch_error("numeric")), } } // TODO: Ask which decimal type to take. /// Tries to convert the wrapped value to a Prisma Decimal.
- pub fn as_decimal(&self) -> Result<f32, DatamodelError> { + pub fn as_decimal(&self) -> Result<f64, DatamodelError> { match &self.value { - ast::Expression::NumericValue(value, _) => self.wrap_error_from_result(value.parse::<f32>(), "numeric"), - ast::Expression::Any(value, _) => self.wrap_error_from_result(value.parse::<f32>(), "numeric"), + ast::Expression::NumericValue(value, _) => self.wrap_error_from_result(value.parse::<f64>(), "numeric"), + ast::Expression::Any(value, _) => self.wrap_error_from_result(value.parse::<f64>(), "numeric"), _ => Err(self.construct_type_mismatch_error("numeric")), } } diff --git a/libs/prisma-value/src/lib.rs b/libs/prisma-value/src/lib.rs index e759153bfd82..ace486df9a0b 100644 --- a/libs/prisma-value/src/lib.rs +++ b/libs/prisma-value/src/lib.rs @@ -140,16 +140,6 @@ impl TryFrom<f64> for PrismaValue { fn try_from(f: f64) -> PrismaValueResult<PrismaValue> { Decimal::from_f64(f) - .map(|d| PrismaValue::Float(d)) - .ok_or(ConversionFailure::new("f32", "Decimal")) - } -} - -impl TryFrom<f32> for PrismaValue { - type Error = ConversionFailure; - - fn try_from(f: f32) -> PrismaValueResult<PrismaValue> { - Decimal::from_f32(f) .map(|d| PrismaValue::Float(d)) .ok_or(ConversionFailure::new("f64", "Decimal")) } diff --git a/libs/prisma-value/src/sql_ext.rs b/libs/prisma-value/src/sql_ext.rs index ce9291f6a7d7..692d13778d75 100644 --- a/libs/prisma-value/src/sql_ext.rs +++ b/libs/prisma-value/src/sql_ext.rs @@ -15,6 +15,9 @@ impl<'a> From<ParameterizedValue<'a>> for PrismaValue { ParameterizedValue::Uuid(uuid) => PrismaValue::Uuid(uuid), ParameterizedValue::DateTime(dt) => PrismaValue::DateTime(dt), ParameterizedValue::Char(c) => PrismaValue::String(c.to_string()), + ParameterizedValue::Bytes(bytes) => PrismaValue::String( + String::from_utf8(bytes.into_owned()).expect("PrismaValue::String from ParameterizedValue::Bytes"), + ), } } } diff --git a/libs/sql-schema-describer/src/mysql.rs b/libs/sql-schema-describer/src/mysql.rs index fd30c7c581cc..565402f00dba 100644 --- a/libs/sql-schema-describer/src/mysql.rs +++ b/libs/sql-schema-describer/src/mysql.rs @@ -456,11 +456,12 @@ fn get_column_type_and_enum( ("numeric", _) => ColumnTypeFamily::Float, ("float", _) => ColumnTypeFamily::Float, ("double", _) => ColumnTypeFamily::Float, + ("bit", _) => ColumnTypeFamily::Int, ("date", _) => ColumnTypeFamily::DateTime, ("time", _) => ColumnTypeFamily::DateTime, ("datetime", _) => ColumnTypeFamily::DateTime, ("timestamp", _) => ColumnTypeFamily::DateTime, - ("year", _) => ColumnTypeFamily::DateTime, + ("year", _) => ColumnTypeFamily::Int, ("char", _) => ColumnTypeFamily::String, ("varchar", _) => ColumnTypeFamily::String, ("text", _) => ColumnTypeFamily::String, diff --git a/libs/sql-schema-describer/src/postgres.rs b/libs/sql-schema-describer/src/postgres.rs index 883954aca725..6cd2be06bbfe 100644 --- a/libs/sql-schema-describer/src/postgres.rs +++ b/libs/sql-schema-describer/src/postgres.rs @@ -573,6 +573,7 @@ fn get_column_type<'a>(data_type: &str, full_data_type: &'a str, arity: ColumnAr "int2" | "_int2" => Int, "int4" | "_int4" => Int, "int8" | "_int8" => Int, + "oid" | "_oid" => Int, "float4" | "_float4" => Float, "float8" | "_float8" => Float, "bool" | "_bool" => Boolean, @@ -592,8 +593,9 @@ fn get_column_type<'a>(data_type: &str, full_data_type: &'a str, arity: ColumnAr "path" | "_path" => Geometric, "polygon" | "_polygon" => Geometric, "bpchar" | "_bpchar" => String, - "interval" | "_interval" => DateTime, + "interval" | "_interval" => String, "numeric" | "_numeric" => Float, + "money" | "_money" => Float, "pg_lsn" | "_pg_lsn" => LogSequenceNumber, "time" | "_time" => DateTime, "timetz"
| "_timetz" => DateTime, diff --git a/libs/sql-schema-describer/tests/mysql_introspection_tests.rs b/libs/sql-schema-describer/tests/mysql_introspection_tests.rs index 7e8e6e6d0224..bd4ff4889873 100644 --- a/libs/sql-schema-describer/tests/mysql_introspection_tests.rs +++ b/libs/sql-schema-describer/tests/mysql_introspection_tests.rs @@ -228,7 +228,7 @@ async fn all_mysql_column_types_must_work() { name: "year_col".to_string(), tpe: ColumnType { raw: "year".to_string(), - family: ColumnTypeFamily::DateTime, + family: ColumnTypeFamily::Int, arity: ColumnArity::Required, }, diff --git a/libs/sql-schema-describer/tests/postgres_introspection_tests.rs b/libs/sql-schema-describer/tests/postgres_introspection_tests.rs index 75ba4a4b461c..fc72e053443a 100644 --- a/libs/sql-schema-describer/tests/postgres_introspection_tests.rs +++ b/libs/sql-schema-describer/tests/postgres_introspection_tests.rs @@ -327,7 +327,7 @@ async fn all_postgres_column_types_must_work() { name: "interval_col".into(), tpe: ColumnType { raw: "interval".into(), - family: ColumnTypeFamily::DateTime, + family: ColumnTypeFamily::String, arity: ColumnArity::Required, }, diff --git a/libs/test-setup/src/lib.rs b/libs/test-setup/src/lib.rs index 1683536bd8f4..ac18d64ee572 100644 --- a/libs/test-setup/src/lib.rs +++ b/libs/test-setup/src/lib.rs @@ -335,11 +335,17 @@ pub async fn create_postgres_database(original_url: &Url) -> Result(self, idents: &[(TypeIdentifier, FieldArity)]) -> crate::Result { let mut row = SqlRow::default(); let row_width = idents.len(); + row.values.reserve(row_width); for (i, p_value) in self.into_iter().enumerate().take(row_width) { let pv = match &idents[i] { (type_identifier, FieldArity::List) => match p_value { @@ -118,7 +119,10 @@ pub fn row_value_to_prisma_value( ParameterizedValue::Text(dt_string) => { let dt = DateTime::parse_from_rfc3339(dt_string.borrow()) .or_else(|_| DateTime::parse_from_rfc2822(dt_string.borrow())) - .expect(&format!("Could not parse stored DateTime string: {}", dt_string)); + .map_err(|err| { + failure::format_err!("Could not parse stored DateTime string: {} ({})", dt_string, err) + }) + .unwrap(); PrismaValue::DateTime(dt.with_timezone(&Utc)) } @@ -136,7 +140,13 @@ pub fn row_value_to_prisma_value( ParameterizedValue::Integer(i) => { PrismaValue::Float(Decimal::from_f64(i as f64).expect("f64 was not a Decimal.")) } - ParameterizedValue::Text(s) => PrismaValue::Float(s.parse().unwrap()), + ParameterizedValue::Text(_) | ParameterizedValue::Bytes(_) => PrismaValue::Float( + p_value + .as_str() + .expect("text/bytes as str") + .parse() + .map_err(|err: rust_decimal::Error| SqlError::ColumnReadFailure(err.into()))?, + ), _ => { let error = io::Error::new( io::ErrorKind::InvalidData, @@ -145,7 +155,23 @@ pub fn row_value_to_prisma_value( return Err(SqlError::ConversionError(error.into())); } }, - _ => PrismaValue::from(p_value), + TypeIdentifier::Int => match p_value { + ParameterizedValue::Integer(i) => PrismaValue::Int(i), + ParameterizedValue::Bytes(bytes) => PrismaValue::Int(interpret_bytes_as_i64(&bytes)), + ParameterizedValue::Text(txt) => PrismaValue::Int( + i64::from_str(dbg!(txt.trim_start_matches('\0'))) + .map_err(|err| SqlError::ConversionError(err.into()))?, + ), + other => PrismaValue::from(other), + }, + TypeIdentifier::String => match p_value { + ParameterizedValue::Uuid(uuid) => PrismaValue::String(uuid.to_string()), + ParameterizedValue::Json(json_value) => { + PrismaValue::String(serde_json::to_string(&json_value).expect("JSON value to string")) + } + 
ParameterizedValue::Null => PrismaValue::Null, + other => PrismaValue::from(other), + }, }) } @@ -171,3 +197,84 @@ impl From<&SqlId> for DatabaseValue<'static> { id.clone().into() } } + +// We assume the bytes are stored as a big endian signed integer, because that is what +// mysql does if you enter a numeric value for a bits column. +fn interpret_bytes_as_i64(bytes: &[u8]) -> i64 { + match bytes.len() { + 8 => i64::from_be_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ]), + len if len < 8 => { + let sign_bit_mask: u8 = 0b10000000; + // The first byte will only contain the sign bit. + let most_significant_bit_byte = bytes[0] & sign_bit_mask; + let padding = if most_significant_bit_byte == 0 { 0 } else { 0b11111111 }; + let mut i64_bytes = [padding; 8]; + + for (target_byte, source_byte) in i64_bytes.iter_mut().rev().zip(bytes.iter().rev()) { + *target_byte = *source_byte; + } + + i64::from_be_bytes(i64_bytes) + } + 0 => 0, + _ => panic!("Attempted to interpret more than 8 bytes as an integer."), + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn quaint_bytes_to_integer_conversion_works() { + // Negative i64 + { + let i: i64 = -123456789123; + let bytes = i.to_be_bytes(); + let roundtripped = interpret_bytes_as_i64(&bytes); + assert_eq!(roundtripped, i); + } + + // Positive i64 + { + let i: i64 = 123456789123; + let bytes = i.to_be_bytes(); + let roundtripped = interpret_bytes_as_i64(&bytes); + assert_eq!(roundtripped, i); + } + + // Positive i32 + { + let i: i32 = 123456789; + let bytes = i.to_be_bytes(); + let roundtripped = interpret_bytes_as_i64(&bytes); + assert_eq!(roundtripped, i as i64); + } + + // Negative i32 + { + let i: i32 = -123456789; + let bytes = i.to_be_bytes(); + let roundtripped = interpret_bytes_as_i64(&bytes); + assert_eq!(roundtripped, i as i64); + } + + // Positive i16 + { + let i: i16 = 12345; + let bytes = i.to_be_bytes(); + let roundtripped = interpret_bytes_as_i64(&bytes); + assert_eq!(roundtripped, i as i64); + } + + // Negative i16 + { + let i: i16 = -12345; + let bytes = i.to_be_bytes(); + let roundtripped = interpret_bytes_as_i64(&bytes); + assert_eq!(roundtripped, i as i64); + } + } +} diff --git a/query-engine/prisma/Cargo.toml b/query-engine/prisma/Cargo.toml index 93e823e7bd36..7577a2552598 100644 --- a/query-engine/prisma/Cargo.toml +++ b/query-engine/prisma/Cargo.toml @@ -42,6 +42,8 @@ tracing-attributes = "0.1" log = "0.4" user-facing-errors = { path = "../../libs/user-facing-errors" } +pretty_assertions = "0.6.1" +tracing-futures = "0.2.3" [build-dependencies] rustc_version = "0.2.3" @@ -52,6 +54,7 @@ test-setup = { path = "../../libs/test-setup" } quaint = { git = "https://github.com/prisma/quaint", features = ["full"] } migration-connector = { path = "../../migration-engine/connectors/migration-connector" } migration-core = { path = "../../migration-engine/core" } +introspection-core = { path = "../../introspection-engine/core" } sql-migration-connector = { path = "../../migration-engine/connectors/sql-migration-connector" } indoc = "0.3" anyhow = "1" diff --git a/query-engine/prisma/src/request_handlers/graphql/protocol_adapter.rs b/query-engine/prisma/src/request_handlers/graphql/protocol_adapter.rs index 1c69770bbcc4..739a227c8341 100644 --- a/query-engine/prisma/src/request_handlers/graphql/protocol_adapter.rs +++ b/query-engine/prisma/src/request_handlers/graphql/protocol_adapter.rs @@ -3,8 +3,8 @@ use graphql_parser::query::{ Definition, Document, OperationDefinition, Selection as 
GqlSelection, SelectionSet, Value, }; use query_core::query_document::*; -use rust_decimal::{prelude::FromPrimitive, Decimal}; -use std::collections::BTreeMap; +use rust_decimal::Decimal; +use std::{collections::BTreeMap, str::FromStr}; /// Protocol adapter for GraphQL -> Query Document. /// @@ -145,7 +145,9 @@ impl GraphQLProtocolAdapter { i ))), }, - Value::Float(f) => match Decimal::from_f64(f) { + // We can't use Decimal::from_f64 here due to a bug in rust_decimal. + // Issue: https://github.com/paupino/rust-decimal/issues/228 + Value::Float(f) => match Decimal::from_str(&f.to_string()).ok() { Some(dec) => Ok(QueryValue::Float(dec)), None => Err(PrismaError::QueryConversionError(format!( "invalid 64-bit float: {:?}", diff --git a/query-engine/prisma/src/tests.rs b/query-engine/prisma/src/tests.rs index c1d78226bb79..474cf570ede3 100644 --- a/query-engine/prisma/src/tests.rs +++ b/query-engine/prisma/src/tests.rs @@ -1,3 +1,4 @@ mod dmmf; mod execute_raw; mod test_api; +mod type_mappings; diff --git a/query-engine/prisma/src/tests/test_api.rs b/query-engine/prisma/src/tests/test_api.rs index b37e8a66af20..466db70985a8 100644 --- a/query-engine/prisma/src/tests/test_api.rs +++ b/query-engine/prisma/src/tests/test_api.rs @@ -22,6 +22,10 @@ pub struct QueryEngine { } impl QueryEngine { + pub fn new(ctx: PrismaContext) -> Self { + QueryEngine { context: Arc::new(ctx) } + } + pub async fn request(&self, body: impl Into) -> serde_json::Value { let request = PrismaRequest { body: GraphQlBody::Single(body.into()), diff --git a/query-engine/prisma/src/tests/type_mappings.rs b/query-engine/prisma/src/tests/type_mappings.rs new file mode 100644 index 000000000000..58281837a938 --- /dev/null +++ b/query-engine/prisma/src/tests/type_mappings.rs @@ -0,0 +1,3 @@ +mod mysql_types; +mod postgres_types; +mod test_api; diff --git a/query-engine/prisma/src/tests/type_mappings/mysql_types.rs b/query-engine/prisma/src/tests/type_mappings/mysql_types.rs new file mode 100644 index 000000000000..802fd9b7ef59 --- /dev/null +++ b/query-engine/prisma/src/tests/type_mappings/mysql_types.rs @@ -0,0 +1,274 @@ +use super::test_api::*; +use datamodel::dml::ScalarType; +use indoc::indoc; +use pretty_assertions::assert_eq; +use serde_json::json; +use test_macros::*; + +const CREATE_TYPES_TABLE: &str = indoc! 
{ + r##" + CREATE TABLE `types` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `numeric_integer_tinyint` tinyint(4), + `numeric_integer_smallint` smallint(6), + `numeric_integer_int` int(11), + `numeric_integer_bigint` bigint(20), + `numeric_floating_decimal` decimal(10,2), + `numeric_floating_float` float, + `numeric_fixed_double` double, + `numeric_fixed_real` double, + `numeric_bit` bit(64), + `numeric_boolean` tinyint(1), + `date_date` date, + `date_datetime` datetime, + `date_timestamp` timestamp null DEFAULT null, + `date_time` time, + `date_year` year(4), + `string_char` char(255), + `string_varchar` varchar(255), + `string_text_tinytext` tinytext, + `string_text_text` text, + `string_text_mediumtext` mediumtext, + `string_text_longtext` longtext, + `string_binary_binary` binary(20), + `string_binary_varbinary` varbinary(255), + `string_blob_tinyblob` tinyblob, + `string_blob_mediumblob` mediumblob, + `string_blob_blob` blob, + `string_blob_longblob` longblob, + `string_enum` enum('pollicle_dogs','jellicle_cats'), + `string_set` set('a','b','c'), + `spatial_geometry` geometry, + `spatial_point` point, + `spatial_linestring` linestring, + `spatial_polygon` polygon, + `spatial_multipoint` multipoint, + `spatial_multilinestring` multilinestring, + `spatial_multipolygon` multipolygon, + `spatial_geometrycollection` geometrycollection, + `json` json, + + PRIMARY KEY (`id`) + ) ENGINE=InnoDB DEFAULT CHARSET=latin1; + "## +}; + +#[test_each_connector(tags("mysql"))] +async fn mysql_types_roundtrip(api: &TestApi) -> TestResult { + api.execute_sql(CREATE_TYPES_TABLE).await?; + + let (datamodel, engine) = api.introspect_and_start_query_engine().await?; + + datamodel.assert_model("types", |model| { + model + .assert_field_type("numeric_integer_tinyint", ScalarType::Int)? + .assert_field_type("numeric_integer_smallint", ScalarType::Int)? + .assert_field_type("numeric_integer_int", ScalarType::Int)? + .assert_field_type("numeric_integer_bigint", ScalarType::Int)? + .assert_field_type("numeric_floating_decimal", ScalarType::Float)? + .assert_field_type("numeric_floating_float", ScalarType::Float)? + .assert_field_type("numeric_fixed_double", ScalarType::Float)? + .assert_field_type("numeric_fixed_real", ScalarType::Float)? + .assert_field_type("numeric_bit", ScalarType::Int)? + .assert_field_type("numeric_boolean", ScalarType::Boolean)? + .assert_field_type("date_date", ScalarType::DateTime)? + .assert_field_type("date_datetime", ScalarType::DateTime)? + .assert_field_type("date_timestamp", ScalarType::DateTime)? + .assert_field_type("date_time", ScalarType::DateTime)? + .assert_field_type("date_year", ScalarType::Int)? + .assert_field_type("string_char", ScalarType::String)? + .assert_field_type("string_varchar", ScalarType::String)? + .assert_field_type("string_text_tinytext", ScalarType::String)? + .assert_field_type("string_text_text", ScalarType::String)? + .assert_field_type("string_text_mediumtext", ScalarType::String)? + .assert_field_type("string_text_longtext", ScalarType::String)? + .assert_field_type("string_binary_binary", ScalarType::String)? + .assert_field_type("string_blob_tinyblob", ScalarType::String)? + .assert_field_type("string_blob_mediumblob", ScalarType::String)? + .assert_field_type("string_blob_blob", ScalarType::String)? + .assert_field_type("string_blob_longblob", ScalarType::String)? + .assert_field_enum_type("string_enum", "types_string_enum")? + .assert_field_type("string_set", ScalarType::String)? + .assert_field_type("spatial_geometry", ScalarType::String)? 
+ .assert_field_type("spatial_point", ScalarType::String)? + .assert_field_type("spatial_linestring", ScalarType::String)? + .assert_field_type("spatial_polygon", ScalarType::String)? + .assert_field_type("spatial_multipoint", ScalarType::String)? + .assert_field_type("spatial_multilinestring", ScalarType::String)? + .assert_field_type("spatial_multipolygon", ScalarType::String)? + .assert_field_type("spatial_geometrycollection", ScalarType::String)? + .assert_field_type("json", ScalarType::String) + })?; + + // Write the values. + { + let write = indoc! { + " + mutation { + createOnetypes( + data: { + numeric_integer_tinyint: 12, + numeric_integer_smallint: 350, + numeric_integer_int: 9002, + numeric_integer_bigint: 30000, + numeric_floating_decimal: 3.14 + numeric_floating_float: -32.0 + numeric_fixed_double: 0.14 + numeric_fixed_real: 12.12 + numeric_bit: 4 + numeric_boolean: true + date_date: \"2020-02-27T00:00:00Z\" + date_datetime: \"2020-02-27T19:10:22Z\" + date_timestamp: \"2020-02-27T19:11:22Z\" + # date_time: \"2020-02-20T12:50:01Z\" + date_year: 2012 + string_char: \"make dolphins easy\" + string_varchar: \"dolphins of varying characters\" + string_text_tinytext: \"tiny dolphins\" + string_text_text: \"dolphins\" + string_text_mediumtext: \"medium dolphins\" + string_text_longtext: \"long dolphins\" + string_binary_binary: \"hello 2020\" + string_blob_tinyblob: \"smol blob\" + string_blob_mediumblob: \"average blob\" + string_blob_blob: \"very average blob\" + string_blob_longblob: \"loong looooong bloooooooob\" + string_enum: \"jellicle_cats\" + json: \"{\\\"name\\\": null}\" + } + ) { id } + } + " + }; + + let write_response = engine.request(write).await; + + let expected_write_response = json!({ + "data": { + "createOnetypes": { + "id": 1, + } + } + }); + + assert_eq!(write_response, expected_write_response); + } + + // Read the values back. + { + let read = indoc! 
{ + " + query { + findManytypes { + numeric_integer_tinyint + numeric_integer_smallint + numeric_integer_int + numeric_integer_bigint + numeric_floating_decimal + numeric_floating_float + numeric_fixed_double + numeric_fixed_real + numeric_bit + numeric_boolean + date_date + date_datetime + date_timestamp + # date_time + date_year + string_char + string_varchar + string_text_tinytext + string_text_text + string_text_mediumtext + string_text_longtext + string_binary_binary + string_blob_tinyblob + string_blob_mediumblob + string_blob_blob + string_blob_longblob + string_enum + # omitting spatial/geometry types + json + } + } + " + }; + + let read_response = engine.request(read).await; + + let expected_read_response = json!({ + "data": { + "findManytypes": [ + { + "numeric_integer_tinyint": 12, + "numeric_integer_smallint": 350, + "numeric_integer_int": 9002, + "numeric_integer_bigint": 30000, + "numeric_floating_decimal": 3.14, + "numeric_floating_float": -32.0, + "numeric_fixed_double": 0.14, + "numeric_fixed_real": 12.12, + "numeric_bit": 4, + "numeric_boolean": true, + "date_date": "2020-02-27T00:00:00.000Z", + "date_datetime": "2020-02-27T19:10:22.000Z", + "date_timestamp": "2020-02-27T19:11:22.000Z", + // "date_time": "2020-02-27T19:11:22.000Z", + "date_year": 2012, + "string_char": "make dolphins easy", + "string_varchar": "dolphins of varying characters", + "string_text_tinytext": "tiny dolphins", + "string_text_text": "dolphins", + "string_text_mediumtext": "medium dolphins", + "string_text_longtext": "long dolphins", + "string_binary_binary": "hello 2020\u{0}\u{0}\u{0}\u{0}\u{0}\u{0}\u{0}\u{0}\u{0}\u{0}", + "string_blob_tinyblob": "smol blob", + "string_blob_mediumblob": "average blob", + "string_blob_blob": "very average blob", + "string_blob_longblob": "loong looooong bloooooooob", + "string_enum": "jellicle_cats", + "json": "{\"name\": null}", + }, + ] + }, + }); + + assert_eq!(read_response, expected_read_response); + } + + Ok(()) +} + +#[test_each_connector(tags("mysql"))] +async fn mysql_bit_columns_are_properly_mapped_to_signed_integers(api: &TestApi) -> TestResult { + api.execute_sql(CREATE_TYPES_TABLE).await?; + + let (_datamodel, engine) = api.introspect_and_start_query_engine().await?; + + let write = indoc! { + " + mutation { + createOnetypes( + data: { + numeric_bit: -12 + } + ) { id numeric_bit } + } + " + }; + + let write_response = engine.request(write).await; + + let expected_write_response = json!({ + "data": { + "createOnetypes": { + "id": 1, + "numeric_bit": -12, + } + } + }); + + assert_eq!(write_response, expected_write_response); + + Ok(()) +} diff --git a/query-engine/prisma/src/tests/type_mappings/postgres_types.rs b/query-engine/prisma/src/tests/type_mappings/postgres_types.rs new file mode 100644 index 000000000000..fff12ef73890 --- /dev/null +++ b/query-engine/prisma/src/tests/type_mappings/postgres_types.rs @@ -0,0 +1,458 @@ +use super::test_api::*; +use datamodel::ScalarType; +use indoc::indoc; +use pretty_assertions::assert_eq; +use serde_json::json; +use test_macros::test_each_connector; + +const CREATE_TYPES_TABLE: &str = indoc! 
{ + r##" + CREATE TABLE "prisma-tests"."types" ( + id SERIAL PRIMARY KEY, + numeric_int2 int2, + numeric_int4 int4, + numeric_int8 int8, + + numeric_decimal decimal(8, 4), + numeric_float4 float4, + numeric_float8 float8, + + numeric_serial2 serial2, + numeric_serial4 serial4, + numeric_serial8 serial8, + + numeric_money money, + numeric_oid oid, + + string_char char(8), + string_varchar varchar(20), + string_text text, + + binary_bytea bytea, + binary_bits bit(80), + binary_bits_varying bit varying(80), + binary_uuid uuid, + + time_timestamp timestamp, + time_timestamptz timestamptz, + time_date date, + time_time time, + time_timetz timetz, + time_interval interval, + + boolean_boolean boolean, + + network_cidr cidr, + network_inet inet, + network_mac macaddr, + + search_tsvector tsvector, + search_tsquery tsquery, + + json_json json, + json_jsonb jsonb, + + range_int4range int4range, + range_int8range int8range, + range_numrange numrange, + range_tsrange tsrange, + range_tstzrange tstzrange, + range_daterange daterange + ); + "## +}; + +#[test_each_connector(tags("postgres"), log = "debug")] +async fn postgres_types_roundtrip(api: &TestApi) -> TestResult { + api.execute_sql(CREATE_TYPES_TABLE).await?; + + let (datamodel, engine) = api.introspect_and_start_query_engine().await?; + + datamodel.assert_model("types", |model| { + model + .assert_field_type("numeric_int2", ScalarType::Int)? + .assert_field_type("numeric_int4", ScalarType::Int)? + .assert_field_type("numeric_int8", ScalarType::Int)? + .assert_field_type("numeric_decimal", ScalarType::Float)? + .assert_field_type("numeric_float4", ScalarType::Float)? + .assert_field_type("numeric_float8", ScalarType::Float)? + .assert_field_type("numeric_serial2", ScalarType::Int)? + .assert_field_type("numeric_serial4", ScalarType::Int)? + .assert_field_type("numeric_serial8", ScalarType::Int)? + .assert_field_type("numeric_money", ScalarType::Float)? + .assert_field_type("numeric_oid", ScalarType::Int)? + .assert_field_type("string_char", ScalarType::String)? + .assert_field_type("string_varchar", ScalarType::String)? + .assert_field_type("string_text", ScalarType::String)? + .assert_field_type("binary_bytea", ScalarType::String)? + .assert_field_type("binary_bits", ScalarType::String)? + .assert_field_type("binary_bits_varying", ScalarType::String)? + .assert_field_type("binary_uuid", ScalarType::String)? + .assert_field_type("time_timestamp", ScalarType::DateTime)? + .assert_field_type("time_timestamptz", ScalarType::DateTime)? + .assert_field_type("time_date", ScalarType::DateTime)? + .assert_field_type("time_time", ScalarType::DateTime)? + .assert_field_type("time_timetz", ScalarType::DateTime)? + .assert_field_type("time_interval", ScalarType::String)? + .assert_field_type("boolean_boolean", ScalarType::Boolean)? + .assert_field_type("network_cidr", ScalarType::String)? + .assert_field_type("network_inet", ScalarType::String)? + .assert_field_type("network_mac", ScalarType::String)? + .assert_field_type("search_tsvector", ScalarType::String)? + .assert_field_type("search_tsquery", ScalarType::String)? + .assert_field_type("json_json", ScalarType::String)? + .assert_field_type("json_jsonb", ScalarType::String)? + .assert_field_type("range_int4range", ScalarType::String)? + .assert_field_type("range_int8range", ScalarType::String)? + .assert_field_type("range_numrange", ScalarType::String)? + .assert_field_type("range_tsrange", ScalarType::String)? + .assert_field_type("range_tstzrange", ScalarType::String)? 
+ .assert_field_type("range_daterange", ScalarType::String) + })?; + + let query = indoc! { + r##" + mutation { + createOnetypes( + data: { + numeric_int2: 12 + numeric_int4: 9002 + numeric_int8: 100000000 + numeric_decimal: 49.3444 + numeric_float4: 12.12 + numeric_float8: 3.139428 + numeric_serial2: 8, + numeric_serial4: 80, + numeric_serial8: 80000, + numeric_money: 3.50 + numeric_oid: 2000 + string_char: "yeet" + string_varchar: "yeet variable" + string_text: "to yeet or not to yeet" + # binary_bytea: "test" + binary_uuid: "111142ec-880b-4062-913d-8eac479ab957" + time_timestamp: "2020-03-02T08:00:00.000" + time_timestamptz: "2020-03-02T08:00:00.000" + time_date: "2020-03-05T00:00:00.000" + time_time: "2020-03-05T08:00:00.000" + time_timetz: "2020-03-05T08:00:00.000" + # time_interval: "3 hours" + boolean_boolean: true + # network_cidr: "192.168.100.14/24" + network_inet: "192.168.100.14" + # network_mac: "12:33:ed:44:49:36" + # search_tsvector: "''a'' ''dump'' ''dumps'' ''fox'' ''in'' ''the''" + # search_tsquery: "''foxy cat''" + json_json: "{ \"isJson\": true }" + json_jsonb: "{ \"isJSONB\": true }" + # range_int4range: "[-4, 8)" + # range_int8range: "[4000, 9000)" + # range_numrange: "[11.1, 22.2)" + # range_tsrange: "[2010-01-01 14:30, 2010-01-01 15:30)" + # range_tstzrange: "[2010-01-01 14:30, 2010-01-01 15:30)" + # range_daterange: "[2020-03-02, 2020-03-22)" + } + ) { + numeric_int2 + numeric_int4 + numeric_int8 + numeric_decimal + numeric_float4 + numeric_float8 + numeric_serial2 + numeric_serial4 + numeric_serial8 + numeric_money + numeric_oid + string_char + string_varchar + string_text + # binary_bytea + binary_uuid + time_timestamp + time_timestamptz + time_date + time_time + time_timetz + # time_interval + boolean_boolean + # network_cidr + network_inet + # network_mac + # search_tsvector + # search_tsquery + json_json + json_jsonb + # range_int4range + # range_int8range + # range_numrange + # range_tsrange + # range_tstzrange + # range_daterange + } + } + "## + }; + + let response = engine.request(query).await; + + let expected_response = json!({ + "data": { + "createOnetypes": { + "numeric_int2": 12, + "numeric_int4": 9002, + "numeric_int8": 100000000, + "numeric_serial2": 8, + "numeric_serial4": 80, + "numeric_serial8": 80000, + "numeric_decimal": 49.3444, + "numeric_float4": 12.12, + "numeric_float8": 3.139428, + "numeric_money": 3.5, + "numeric_oid": 2000, + "string_char": "yeet ", + "string_varchar": "yeet variable", + "string_text": "to yeet or not to yeet", + "binary_uuid": "111142ec-880b-4062-913d-8eac479ab957", + "time_timestamp": "2020-03-02T08:00:00.000Z", + "time_timestamptz": "2020-03-02T08:00:00.000Z", + "time_date": "2020-03-05T00:00:00.000Z", + "time_time": "1970-01-01T08:00:00.000Z", + "time_timetz": "1970-01-01T08:00:00.000Z", + "boolean_boolean": true, + "network_inet": "192.168.100.14", + "json_json": "{\"isJson\":true}", + "json_jsonb": "{\"isJSONB\":true}", + } + } + }); + + assert_eq!(response, expected_response); + + Ok(()) +} + +#[test_each_connector(tags("postgres"), log = "debug")] +async fn small_float_values_must_work(api: &TestApi) -> TestResult { + let schema = indoc! { + r#" + CREATE TABLE floatilla ( + id SERIAL PRIMARY KEY, + f32 float4, + f64 float8, + decimal_column decimal + ); + "# + }; + + api.execute_sql(schema).await?; + + let (datamodel, engine) = api.introspect_and_start_query_engine().await?; + + datamodel.assert_model("floatilla", |model| { + model + .assert_field_type("f32", ScalarType::Float)? 
+ .assert_field_type("f64", ScalarType::Float)? + .assert_field_type("decimal_column", ScalarType::Float) + })?; + + let query = indoc! { + r##" + mutation { + createOnefloatilla( + data: { + f32: 0.00006927, + f64: 0.00006927, + decimal_column: 0.00006927 + } + ) { + id + f32 + f64 + decimal_column + } + } + "## + }; + + let response = engine.request(query).await; + + let expected_response = json!({ + "data": { + "createOnefloatilla": { + "id": 1, + "f32": 0.00006927, + "f64": 0.00006927, + "decimal_column": 0.00006927 + } + } + }); + + assert_eq!(response, expected_response); + + Ok(()) +} + +const CREATE_ARRAY_TYPES_TABLE: &str = indoc! { + r##" + CREATE TABLE "prisma-tests"."arraytypes" ( + id SERIAL PRIMARY KEY, + numeric_int2 int2[], + numeric_int4 int4[], + numeric_int8 int8[], + + numeric_decimal decimal(8, 4)[], + numeric_float4 float4[], + numeric_float8 float8[], + + numeric_money money[], + numeric_oid oid[], + + string_char char(8)[], + string_varchar varchar(20)[], + string_text text[], + + binary_bytea bytea[], + binary_bits bit(80)[], + binary_bits_varying bit varying(80)[], + binary_uuid uuid[], + + time_timestamp timestamp[], + time_timestamptz timestamptz[], + time_date date[], + time_time time[], + time_timetz timetz[], + + boolean_boolean boolean[], + + network_cidr cidr[], + network_inet inet[], + + json_json json[], + json_jsonb jsonb[] + ); + "## +}; + +#[test_each_connector(tags("postgres"), log = "debug")] +async fn postgres_array_types_roundtrip(api: &TestApi) -> TestResult { + api.execute_sql(CREATE_ARRAY_TYPES_TABLE).await?; + + let (datamodel, engine) = api.introspect_and_start_query_engine().await?; + + datamodel.assert_model("arraytypes", |model| { + model + .assert_field_type("numeric_int2", ScalarType::Int)? + .assert_field_type("numeric_int4", ScalarType::Int)? + .assert_field_type("numeric_int8", ScalarType::Int)? + .assert_field_type("numeric_decimal", ScalarType::Float)? + .assert_field_type("numeric_float4", ScalarType::Float)? + .assert_field_type("numeric_float8", ScalarType::Float)? + .assert_field_type("numeric_money", ScalarType::Float)? + .assert_field_type("numeric_oid", ScalarType::Int)? + .assert_field_type("string_char", ScalarType::String)? + .assert_field_type("string_varchar", ScalarType::String)? + .assert_field_type("string_text", ScalarType::String)? + .assert_field_type("binary_bytea", ScalarType::String)? + .assert_field_type("binary_bits", ScalarType::String)? + .assert_field_type("binary_bits_varying", ScalarType::String)? + .assert_field_type("binary_uuid", ScalarType::String)? + .assert_field_type("time_timestamp", ScalarType::DateTime)? + .assert_field_type("time_timestamptz", ScalarType::DateTime)? + .assert_field_type("time_date", ScalarType::DateTime)? + .assert_field_type("time_time", ScalarType::DateTime)? + .assert_field_type("time_timetz", ScalarType::DateTime)? + .assert_field_type("boolean_boolean", ScalarType::Boolean)? + .assert_field_type("network_inet", ScalarType::String)? + .assert_field_type("json_json", ScalarType::String)? + .assert_field_type("json_jsonb", ScalarType::String) + })?; + + let query = indoc! 
{ + r##" + mutation { + createOnearraytypes( + data: { + numeric_int2: { set: [12] } + numeric_int4: { set: [9002] } + numeric_int8: { set: [100000000] } + numeric_decimal: { set: [49.3444] } + numeric_float4: { set: [12.12] } + numeric_float8: { set: [3.139428] } + numeric_money: { set: [3.50] } + numeric_oid: { set: [2000] } + string_char: { set: ["yeet"] } + string_varchar: { set: ["yeet variable"] } + string_text: { set: ["to yeet or not to yeet"] } + binary_uuid: { set: ["111142ec-880b-4062-913d-8eac479ab957"] } + time_timestamp: { set: ["2020-03-02T08:00:00.000"] } + time_timestamptz: { set: ["2020-03-02T08:00:00.000"] } + time_date: { set: ["2020-03-05T00:00:00.000"] } + time_time: { set: ["2020-03-05T08:00:00.000"] } + time_timetz: { set: ["2020-03-05T08:00:00.000"] } + boolean_boolean: { set: [true, true, false, true] } + network_inet: { set: ["192.168.100.14"] } + json_json: { set: ["{ \"isJson\": true }"] } + json_jsonb: { set: ["{ \"isJSONB\": true }"] } + } + ) { + numeric_int2 + numeric_int4 + numeric_int8 + numeric_decimal + numeric_float4 + numeric_float8 + numeric_money + numeric_oid + string_char + string_varchar + string_text + binary_uuid + time_timestamp + time_timestamptz + time_date + time_time + time_timetz + boolean_boolean + network_inet + json_json + json_jsonb + } + } + "## + }; + + let response = engine.request(query).await; + + let expected_response = json!({ + "data": { + "createOnearraytypes": { + "numeric_int2": [12], + "numeric_int4": [9002], + "numeric_int8": [100000000], + "numeric_decimal": [49.3444], + "numeric_float4": [12.12], + "numeric_float8": [3.139428], + "numeric_money": [3.5], + "numeric_oid": [2000], + "string_char": ["yeet "], + "string_varchar": ["yeet variable"], + "string_text": ["to yeet or not to yeet"], + "binary_uuid": ["111142ec-880b-4062-913d-8eac479ab957"], + "time_timestamp": ["2020-03-02T08:00:00.000Z"], + "time_timestamptz": ["2020-03-02T08:00:00.000Z"], + "time_date": ["2020-03-05T00:00:00.000Z"], + "time_time": ["1970-01-01T08:00:00.000Z"], + "time_timetz": ["1970-01-01T08:00:00.000Z"], + "boolean_boolean": [true, true, false, true], + "network_inet": ["192.168.100.14"], + "json_json": ["{\"isJson\":true}"], + "json_jsonb": ["{\"isJSONB\":true}"], + } + } + }); + + assert_eq!(response, expected_response); + + Ok(()) +} diff --git a/query-engine/prisma/src/tests/type_mappings/test_api.rs b/query-engine/prisma/src/tests/type_mappings/test_api.rs new file mode 100644 index 000000000000..836c816ebbdf --- /dev/null +++ b/query-engine/prisma/src/tests/type_mappings/test_api.rs @@ -0,0 +1,209 @@ +use super::super::test_api::QueryEngine; +use crate::context::PrismaContext; +use quaint::prelude::Queryable; + +pub type TestResult = anyhow::Result<()>; + +pub struct TestApi { + provider: &'static str, + database_string: String, + is_pgbouncer: bool, +} + +impl TestApi { + fn datasource(&self) -> String { + format!( + r#" + datasource my_db {{ + provider = "{provider}" + url = "{url}" + }} + "#, + provider = self.provider, + url = self.database_string, + ) + } + + pub async fn execute_sql(&self, sql: &str) -> anyhow::Result<()> { + let conn = quaint::single::Quaint::new(&self.database_string).await?; + + conn.execute_raw(sql, &[]).await?; + + Ok(()) + } + + pub async fn introspect_and_start_query_engine(&self) -> anyhow::Result<(DatamodelAssertions, QueryEngine)> { + let datasource = self.datasource(); + + let schema = introspection_core::RpcImpl::introspect_internal(datasource) + .await + .map_err(|err| anyhow::anyhow!("{:?}", 
err.data))?; + + let context = PrismaContext::builder() + .enable_raw_queries(true) + .datamodel(schema.clone()) + .force_transactions(self.is_pgbouncer) + .build() + .await + .unwrap(); + + eprintln!("{}", schema); + let schema = datamodel::parse_datamodel(&schema).unwrap(); + + Ok((DatamodelAssertions(schema), QueryEngine::new(context))) + } +} + +pub struct DatamodelAssertions(datamodel::Datamodel); + +impl DatamodelAssertions { + pub fn assert_model<F>(self, name: &str, assert_fn: F) -> anyhow::Result<Self> + where + F: for<'a> FnOnce(ModelAssertions<'a>) -> anyhow::Result<ModelAssertions<'a>>, + { + let model = self + .0 + .find_model(name) + .ok_or_else(|| anyhow::anyhow!("Assertion error: could not find model {}", name))?; + + assert_fn(ModelAssertions(model))?; + + Ok(self) + } +} + +pub struct ModelAssertions<'a>(&'a datamodel::dml::Model); + +impl<'a> ModelAssertions<'a> { + pub fn assert_field_type(self, name: &str, r#type: datamodel::dml::ScalarType) -> anyhow::Result<Self> { + let field = self + .0 + .find_field(name) + .ok_or_else(|| anyhow::anyhow!("Assertion error: could not find field {}", name))?; + + anyhow::ensure!( + field.field_type == datamodel::dml::FieldType::Base(r#type), + "Assertion error: expected the field {} to have type {:?}, but found {:?}", + field.name, + r#type, + &field.field_type, + ); + + Ok(self) + } + + pub fn assert_field_enum_type(self, name: &str, enum_name: &str) -> anyhow::Result<Self> { + let field = self + .0 + .find_field(name) + .ok_or_else(|| anyhow::anyhow!("Assertion error: could not find field {}", name))?; + + anyhow::ensure!( + field.field_type == datamodel::dml::FieldType::Enum(enum_name.into()), + "Assertion error: expected the field {} to have enum type {:?}, but found {:?}", + field.name, + enum_name, + &field.field_type, + ); + + Ok(self) + } +} + +pub async fn mysql_8_test_api(db_name: &str) -> TestApi { + let mysql_url = test_setup::mysql_8_url(db_name); + + test_setup::create_mysql_database(&mysql_url.parse().unwrap()) + .await + .unwrap(); + + TestApi { + database_string: mysql_url, + provider: "mysql", + is_pgbouncer: false, + } +} + +pub async fn mysql_test_api(db_name: &str) -> TestApi { + let mysql_url = test_setup::mysql_url(db_name); + + test_setup::create_mysql_database(&mysql_url.parse().unwrap()) + .await + .unwrap(); + + TestApi { + database_string: mysql_url, + provider: "mysql", + is_pgbouncer: false, + } +} + +pub async fn mysql_mariadb_test_api(db_name: &str) -> TestApi { + let mysql_url = test_setup::mariadb_url(db_name); + + test_setup::create_mysql_database(&mysql_url.parse().unwrap()) + .await + .unwrap(); + + TestApi { + database_string: mysql_url, + provider: "mysql", + is_pgbouncer: false, + } +} + +pub async fn postgres_test_api(db_name: &str) -> TestApi { + let postgres_url = test_setup::postgres_10_url(db_name); + + test_setup::create_postgres_database(&postgres_url.parse().unwrap()) + .await + .unwrap(); + + TestApi { + database_string: postgres_url, + provider: "postgres", + is_pgbouncer: false, + } +} + +pub async fn postgres9_test_api(db_name: &str) -> TestApi { + let postgres_url = test_setup::postgres_9_url(db_name); + + test_setup::create_postgres_database(&postgres_url.parse().unwrap()) + .await + .unwrap(); + + TestApi { + database_string: postgres_url, + provider: "postgres", + is_pgbouncer: false, + } +} + +pub async fn postgres11_test_api(db_name: &str) -> TestApi { + let postgres_url = test_setup::postgres_11_url(db_name); + + test_setup::create_postgres_database(&postgres_url.parse().unwrap()) + .await + .unwrap(); + + TestApi {
+ database_string: postgres_url, + provider: "postgres", + is_pgbouncer: false, + } +} + +pub async fn postgres12_test_api(db_name: &str) -> TestApi { + let postgres_url = test_setup::postgres_12_url(db_name); + + test_setup::create_postgres_database(&postgres_url.parse().unwrap()) + .await + .unwrap(); + + TestApi { + database_string: postgres_url, + provider: "postgres", + is_pgbouncer: false, + } +}