Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test and fix mapping of mysql and postgres native types #538

Merged
merged 15 commits into from Mar 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -463,7 +463,7 @@ fn parse_bool(value: &str) -> Option<bool> {

static RE_FLOAT: Lazy<Regex> = Lazy::new(|| Regex::new(r"^'?([^']+)'?$").expect("compile regex"));

fn parse_float(value: &str) -> Option<f32> {
fn parse_float(value: &str) -> Option<f64> {
debug!("Parsing float '{}'", value);
let rslt = RE_FLOAT.captures(value);
if rslt.is_none() {
Expand All @@ -473,7 +473,7 @@ fn parse_float(value: &str) -> Option<f32> {

let captures = rslt.expect("get captures");
let num_str = captures.get(1).expect("get capture").as_str();
let num_rslt = num_str.parse::<f32>();
let num_rslt = num_str.parse::<f64>();
match num_rslt {
Ok(num) => Some(num),
Err(_) => {
Expand Down
7 changes: 7 additions & 0 deletions introspection-engine/core/src/lib.rs
@@ -0,0 +1,7 @@
mod command_error;
mod error;
mod error_rendering;
mod rpc;

pub use error::Error;
pub use rpc::RpcImpl;
12 changes: 6 additions & 6 deletions introspection-engine/core/src/rpc.rs
Expand Up @@ -26,7 +26,7 @@ pub trait Rpc {
fn introspect(&self, input: IntrospectionInput) -> RpcFutureResult<String>;
}

pub(crate) struct RpcImpl;
pub struct RpcImpl;

impl Rpc for RpcImpl {
fn list_databases(&self, input: IntrospectionInput) -> RpcFutureResult<Vec<String>> {
Expand All @@ -47,7 +47,7 @@ impl Rpc for RpcImpl {
}

impl RpcImpl {
pub(crate) fn new() -> Self {
pub fn new() -> Self {
RpcImpl
}

Expand All @@ -63,7 +63,7 @@ impl RpcImpl {
Ok(Box::new(SqlIntrospectionConnector::new(&url).await?))
}

pub(crate) async fn introspect_internal(schema: String) -> RpcResult<String> {
pub async fn introspect_internal(schema: String) -> RpcResult<String> {
let config = datamodel::parse_configuration(&schema).map_err(Error::from)?;
let url = config
.datasources
Expand All @@ -85,17 +85,17 @@ impl RpcImpl {
}
}

pub(crate) async fn list_databases_internal(schema: String) -> RpcResult<Vec<String>> {
pub async fn list_databases_internal(schema: String) -> RpcResult<Vec<String>> {
let connector = RpcImpl::load_connector(&schema).await?;
Ok(connector.list_databases().await.map_err(Error::from)?)
}

pub(crate) async fn get_database_description(schema: String) -> RpcResult<String> {
pub async fn get_database_description(schema: String) -> RpcResult<String> {
let connector = RpcImpl::load_connector(&schema).await?;
Ok(connector.get_database_description().await.map_err(Error::from)?)
}

pub(crate) async fn get_database_metadata_internal(schema: String) -> RpcResult<DatabaseMetadata> {
pub async fn get_database_metadata_internal(schema: String) -> RpcResult<DatabaseMetadata> {
let connector = RpcImpl::load_connector(&schema).await?;
Ok(connector.get_metadata().await.map_err(Error::from)?)
}
Expand Down
4 changes: 2 additions & 2 deletions libs/datamodel/connectors/datamodel-connector/src/scalars.rs
Expand Up @@ -43,8 +43,8 @@ impl ToString for ScalarType {
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum ScalarValue {
Int(i32),
Float(f32),
Decimal(f32),
Float(f64),
Decimal(f64),
Boolean(bool),
String(String),
DateTime(DateTime<Utc>),
Expand Down
12 changes: 6 additions & 6 deletions libs/datamodel/core/src/common/value_validator.rs
Expand Up @@ -115,20 +115,20 @@ impl ValueValidator {
}

/// Tries to convert the wrapped value to a Prisma Float.
pub fn as_float(&self) -> Result<f32, DatamodelError> {
pub fn as_float(&self) -> Result<f64, DatamodelError> {
match &self.value {
ast::Expression::NumericValue(value, _) => self.wrap_error_from_result(value.parse::<f32>(), "numeric"),
ast::Expression::Any(value, _) => self.wrap_error_from_result(value.parse::<f32>(), "numeric"),
ast::Expression::NumericValue(value, _) => self.wrap_error_from_result(value.parse::<f64>(), "numeric"),
ast::Expression::Any(value, _) => self.wrap_error_from_result(value.parse::<f64>(), "numeric"),
_ => Err(self.construct_type_mismatch_error("numeric")),
}
}

// TODO: Ask which decimal type to take.
/// Tries to convert the wrapped value to a Prisma Decimal.
pub fn as_decimal(&self) -> Result<f32, DatamodelError> {
pub fn as_decimal(&self) -> Result<f64, DatamodelError> {
match &self.value {
ast::Expression::NumericValue(value, _) => self.wrap_error_from_result(value.parse::<f32>(), "numeric"),
ast::Expression::Any(value, _) => self.wrap_error_from_result(value.parse::<f32>(), "numeric"),
ast::Expression::NumericValue(value, _) => self.wrap_error_from_result(value.parse::<f64>(), "numeric"),
ast::Expression::Any(value, _) => self.wrap_error_from_result(value.parse::<f64>(), "numeric"),
_ => Err(self.construct_type_mismatch_error("numeric")),
}
}
Expand Down
10 changes: 0 additions & 10 deletions libs/prisma-value/src/lib.rs
Expand Up @@ -140,16 +140,6 @@ impl TryFrom<f64> for PrismaValue {

fn try_from(f: f64) -> PrismaValueResult<PrismaValue> {
Decimal::from_f64(f)
.map(|d| PrismaValue::Float(d))
.ok_or(ConversionFailure::new("f32", "Decimal"))
}
}

impl TryFrom<f32> for PrismaValue {
type Error = ConversionFailure;

fn try_from(f: f32) -> PrismaValueResult<PrismaValue> {
Decimal::from_f32(f)
.map(|d| PrismaValue::Float(d))
.ok_or(ConversionFailure::new("f64", "Decimal"))
}
Expand Down
3 changes: 3 additions & 0 deletions libs/prisma-value/src/sql_ext.rs
Expand Up @@ -15,6 +15,9 @@ impl<'a> From<ParameterizedValue<'a>> for PrismaValue {
ParameterizedValue::Uuid(uuid) => PrismaValue::Uuid(uuid),
ParameterizedValue::DateTime(dt) => PrismaValue::DateTime(dt),
ParameterizedValue::Char(c) => PrismaValue::String(c.to_string()),
ParameterizedValue::Bytes(bytes) => PrismaValue::String(
String::from_utf8(bytes.into_owned()).expect("PrismaValue::String from ParameterizedValue::Bytes"),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Next PR will test this (and things like non-utf8 encodings) to make sure we return good errors.

),
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion libs/sql-schema-describer/src/mysql.rs
Expand Up @@ -456,11 +456,12 @@ fn get_column_type_and_enum(
("numeric", _) => ColumnTypeFamily::Float,
("float", _) => ColumnTypeFamily::Float,
("double", _) => ColumnTypeFamily::Float,
("bit", _) => ColumnTypeFamily::Int,
("date", _) => ColumnTypeFamily::DateTime,
("time", _) => ColumnTypeFamily::DateTime,
("datetime", _) => ColumnTypeFamily::DateTime,
("timestamp", _) => ColumnTypeFamily::DateTime,
("year", _) => ColumnTypeFamily::DateTime,
("year", _) => ColumnTypeFamily::Int,
("char", _) => ColumnTypeFamily::String,
("varchar", _) => ColumnTypeFamily::String,
("text", _) => ColumnTypeFamily::String,
Expand Down
4 changes: 3 additions & 1 deletion libs/sql-schema-describer/src/postgres.rs
Expand Up @@ -573,6 +573,7 @@ fn get_column_type<'a>(data_type: &str, full_data_type: &'a str, arity: ColumnAr
"int2" | "_int2" => Int,
"int4" | "_int4" => Int,
"int8" | "_int8" => Int,
"oid" | "_oid" => Int,
"float4" | "_float4" => Float,
"float8" | "_float8" => Float,
"bool" | "_bool" => Boolean,
Expand All @@ -592,8 +593,9 @@ fn get_column_type<'a>(data_type: &str, full_data_type: &'a str, arity: ColumnAr
"path" | "_path" => Geometric,
"polygon" | "_polygon" => Geometric,
"bpchar" | "_bpchar" => String,
"interval" | "_interval" => DateTime,
"interval" | "_interval" => String,
"numeric" | "_numeric" => Float,
"money" | "_money" => Float,
"pg_lsn" | "_pg_lsn" => LogSequenceNumber,
"time" | "_time" => DateTime,
"timetz" | "_timetz" => DateTime,
Expand Down
Expand Up @@ -228,7 +228,7 @@ async fn all_mysql_column_types_must_work() {
name: "year_col".to_string(),
tpe: ColumnType {
raw: "year".to_string(),
family: ColumnTypeFamily::DateTime,
family: ColumnTypeFamily::Int,
arity: ColumnArity::Required,
},

Expand Down
Expand Up @@ -327,7 +327,7 @@ async fn all_postgres_column_types_must_work() {
name: "interval_col".into(),
tpe: ColumnType {
raw: "interval".into(),
family: ColumnTypeFamily::DateTime,
family: ColumnTypeFamily::String,
arity: ColumnArity::Required,
},

Expand Down
8 changes: 7 additions & 1 deletion libs/test-setup/src/lib.rs
Expand Up @@ -335,11 +335,17 @@ pub async fn create_postgres_database(original_url: &Url) -> Result<Quaint, AnyE

let db_name = fetch_db_name(&original_url, "postgres");

let drop_stmt = format!("DROP DATABASE IF EXISTS \"{}\"", db_name);
let create_stmt = format!("CREATE DATABASE \"{}\"", db_name);
let create_schema_stmt = format!("CREATE SCHEMA \"{}\"", SCHEMA_NAME);

let conn = Quaint::new(url.as_str()).await.unwrap();

conn.query_raw(&drop_stmt, &[]).await.ok();
conn.query_raw(&create_stmt, &[]).await.ok();

Ok(Quaint::new(original_url.as_str()).await?)
let conn = Quaint::new(original_url.as_str()).await?;
conn.query_raw(&create_schema_stmt, &[]).await.ok();

Ok(conn)
}
115 changes: 111 additions & 4 deletions query-engine/connectors/sql-query-connector/src/row.rs
Expand Up @@ -7,7 +7,7 @@ use quaint::{
connector::ResultRow,
};
use rust_decimal::{prelude::FromPrimitive, Decimal};
use std::{borrow::Borrow, io};
use std::{borrow::Borrow, io, str::FromStr};
use uuid::Uuid;

/// An allocated representation of a `Row` returned from the database.
Expand All @@ -33,6 +33,7 @@ impl ToSqlRow for ResultRow {
fn to_sql_row<'b>(self, idents: &[(TypeIdentifier, FieldArity)]) -> crate::Result<SqlRow> {
let mut row = SqlRow::default();
let row_width = idents.len();
row.values.reserve(row_width);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. What is the difference between this and then implementing with_capacity to SqlRow and using it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's the same.

for (i, p_value) in self.into_iter().enumerate().take(row_width) {
let pv = match &idents[i] {
(type_identifier, FieldArity::List) => match p_value {
Expand Down Expand Up @@ -118,7 +119,10 @@ pub fn row_value_to_prisma_value(
ParameterizedValue::Text(dt_string) => {
let dt = DateTime::parse_from_rfc3339(dt_string.borrow())
.or_else(|_| DateTime::parse_from_rfc2822(dt_string.borrow()))
.expect(&format!("Could not parse stored DateTime string: {}", dt_string));
.map_err(|err| {
failure::format_err!("Could not parse stored DateTime string: {} ({})", dt_string, err)
})
.unwrap();

PrismaValue::DateTime(dt.with_timezone(&Utc))
}
Expand All @@ -136,7 +140,13 @@ pub fn row_value_to_prisma_value(
ParameterizedValue::Integer(i) => {
PrismaValue::Float(Decimal::from_f64(i as f64).expect("f64 was not a Decimal."))
}
ParameterizedValue::Text(s) => PrismaValue::Float(s.parse().unwrap()),
ParameterizedValue::Text(_) | ParameterizedValue::Bytes(_) => PrismaValue::Float(
p_value
.as_str()
.expect("text/bytes as str")
.parse()
.map_err(|err: rust_decimal::Error| SqlError::ColumnReadFailure(err.into()))?,
),
_ => {
let error = io::Error::new(
io::ErrorKind::InvalidData,
Expand All @@ -145,7 +155,23 @@ pub fn row_value_to_prisma_value(
return Err(SqlError::ConversionError(error.into()));
}
},
_ => PrismaValue::from(p_value),
TypeIdentifier::Int => match p_value {
ParameterizedValue::Integer(i) => PrismaValue::Int(i),
ParameterizedValue::Bytes(bytes) => PrismaValue::Int(interpret_bytes_as_i64(&bytes)),
ParameterizedValue::Text(txt) => PrismaValue::Int(
i64::from_str(dbg!(txt.trim_start_matches('\0')))
.map_err(|err| SqlError::ConversionError(err.into()))?,
),
other => PrismaValue::from(other),
},
TypeIdentifier::String => match p_value {
ParameterizedValue::Uuid(uuid) => PrismaValue::String(uuid.to_string()),
ParameterizedValue::Json(json_value) => {
PrismaValue::String(serde_json::to_string(&json_value).expect("JSON value to string"))
}
ParameterizedValue::Null => PrismaValue::Null,
other => PrismaValue::from(other),
},
})
}

Expand All @@ -171,3 +197,84 @@ impl From<&SqlId> for DatabaseValue<'static> {
id.clone().into()
}
}

// We assume the bytes are stored as a big endian signed integer, because that is what
// mysql does if you enter a numeric value for a bits column.
fn interpret_bytes_as_i64(bytes: &[u8]) -> i64 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Talked already in slack. This makes me a bit worried due to this not being obvious in the type level that we're only talking about MySQL. I'm worried that somebody can without knowing write code that returns bytes from Quaint that are not from MySQL and we run this code and do some crazy wrong stuff.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I'm not proposing or asking any changes, just writing down here that I'm a bit uncomfortable with this :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it would be nice to have a way to specialize things depending on the specific database we're working with in the sql query connector. That way we could do this, but only on mysql. This hack may go away when we expand the type system.

match bytes.len() {
8 => i64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
]),
len if len < 8 => {
let sign_bit_mask: u8 = 0b10000000;
// The first byte will only contain the sign bit.
let most_significant_bit_byte = bytes[0] & sign_bit_mask;
let padding = if most_significant_bit_byte == 0 { 0 } else { 0b11111111 };
let mut i64_bytes = [padding; 8];

for (target_byte, source_byte) in i64_bytes.iter_mut().rev().zip(bytes.iter().rev()) {
*target_byte = *source_byte;
}

i64::from_be_bytes(i64_bytes)
}
0 => 0,
_ => panic!("Attempted to interpret more than 8 bytes as an integer."),
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn quaint_bytes_to_integer_conversion_works() {
// Negative i64
{
let i: i64 = -123456789123;
let bytes = i.to_be_bytes();
let roundtripped = interpret_bytes_as_i64(&bytes);
assert_eq!(roundtripped, i);
}

// Positive i64
{
let i: i64 = 123456789123;
let bytes = i.to_be_bytes();
let roundtripped = interpret_bytes_as_i64(&bytes);
assert_eq!(roundtripped, i);
}

// Positive i32
{
let i: i32 = 123456789;
let bytes = i.to_be_bytes();
let roundtripped = interpret_bytes_as_i64(&bytes);
assert_eq!(roundtripped, i as i64);
}

// Negative i32
{
let i: i32 = -123456789;
let bytes = i.to_be_bytes();
let roundtripped = interpret_bytes_as_i64(&bytes);
assert_eq!(roundtripped, i as i64);
}

// Positive i16
{
let i: i16 = 12345;
let bytes = i.to_be_bytes();
let roundtripped = interpret_bytes_as_i64(&bytes);
assert_eq!(roundtripped, i as i64);
}

// Negative i16
{
let i: i16 = -12345;
let bytes = i.to_be_bytes();
let roundtripped = interpret_bytes_as_i64(&bytes);
assert_eq!(roundtripped, i as i64);
}
}
}