Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

intro: Use the pair structure for models, field and indexes #3456

Merged
merged 2 commits into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ pub enum Version {
Prisma2,
}

impl Version {
pub fn is_prisma1(self) -> bool {
matches!(self, Self::Prisma1 | Self::Prisma11)
}
}

#[derive(Debug)]
pub struct IntrospectionResult {
/// Datamodel
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{
introspection::introspect,
pair::{EnumPair, Pair},
introspection_helpers::{is_new_migration_table, is_old_migration_table, is_prisma_join_table, is_relay_table},
pair::{EnumPair, ModelPair, Pair},
warnings, EnumVariantName, IntrospectedName, ModelName, SqlFamilyTrait, SqlIntrospectionResult,
};
use introspection_connector::{IntrospectionContext, IntrospectionResult, Version, Warning};
Expand Down Expand Up @@ -113,12 +114,12 @@ pub(crate) struct InputContext<'a> {
pub(crate) render_config: bool,
pub(crate) schema: &'a sql::SqlSchema,
pub(crate) sql_family: SqlFamily,
pub(crate) version: Version,
pub(crate) previous_schema: &'a psl::ValidatedSchema,
pub(crate) introspection_map: &'a crate::introspection_map::IntrospectionMap,
}

pub(crate) struct OutputContext<'a> {
pub(crate) version: Version,
pub(crate) rendered_schema: datamodel_renderer::Datamodel<'a>,
pub(crate) target_models: HashMap<sql::TableId, usize>,
pub(crate) warnings: Warnings,
Expand Down Expand Up @@ -147,12 +148,29 @@ impl<'a> InputContext<'a> {
self.config.datasources.first().unwrap().active_connector
}

/// Iterate over the database enums, combined together with a
/// possible existing enum in the PSL.
pub(crate) fn enum_pairs(self) -> impl ExactSizeIterator<Item = EnumPair<'a>> {
self.schema
.enum_walkers()
.map(move |next| Pair::new(self, self.existing_enum(next.id), next))
}

/// Iterate over the database tables, combined together with a
/// possible existing model in the PSL.
pub(crate) fn model_pairs(self) -> impl Iterator<Item = ModelPair<'a>> {
self.schema
.table_walkers()
.filter(|table| !is_old_migration_table(*table))
.filter(|table| !is_new_migration_table(*table))
.filter(|table| !is_prisma_join_table(*table))
.filter(|table| !is_relay_table(*table))
.map(move |next| {
let previous = self.existing_model(next.id);
Pair::new(self, previous, next)
})
}

/// Given a SQL enum from the database, this method returns the enum that matches it (by name)
/// in the Prisma schema.
pub(crate) fn existing_enum(self, id: sql::EnumId) -> Option<walkers::EnumWalker<'a>> {
Expand Down Expand Up @@ -251,7 +269,8 @@ pub fn calculate_datamodel(
) -> SqlIntrospectionResult<IntrospectionResult> {
let introspection_map = crate::introspection_map::IntrospectionMap::new(schema, ctx.previous_schema());

let input = InputContext {
let mut input = InputContext {
version: Version::NonPrisma,
config: ctx.configuration(),
render_config: ctx.render_config,
schema,
Expand All @@ -261,13 +280,12 @@ pub fn calculate_datamodel(
};

let mut output = OutputContext {
version: Version::NonPrisma,
rendered_schema: datamodel_renderer::Datamodel::default(),
target_models: HashMap::default(),
warnings: Warnings::new(),
};

output.version = crate::version_checker::check_prisma_version(&input);
input.version = crate::version_checker::check_prisma_version(&input);

let (schema_string, is_empty) = introspect(input, &mut output)?;
let warnings = output.finalize_warnings();
Expand All @@ -276,7 +294,7 @@ pub fn calculate_datamodel(
let version = if warnings.iter().any(|w| ![5, 6].contains(&w.code)) {
Version::NonPrisma
} else {
output.version
input.version
};

Ok(IntrospectionResult {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,187 +1,96 @@
use crate::calculate_datamodel::{InputContext, OutputContext};
use datamodel_renderer::datamodel as renderer;
use introspection_connector::Version;
use psl::{
builtin_connectors::{MySqlType, PostgresType},
datamodel_connector::constraint_names::ConstraintNames,
dml,
parser_database::walkers,
use crate::{
calculate_datamodel::OutputContext,
pair::{DefaultKind, ScalarFieldPair},
};
use datamodel_renderer::{
datamodel as renderer,
value::{Constant, Function, Text, Value},
};
use sql_schema_describer::{self as sql, postgres::PostgresSchemaExt};

pub(crate) fn render_default<'a>(
column: sql::ColumnWalker<'a>,
existing_field: Option<walkers::ScalarFieldWalker<'a>>,
input: InputContext<'a>,
pub(crate) fn render<'a>(
field: ScalarFieldPair<'a>,
output: &mut OutputContext<'a>,
) -> Option<renderer::DefaultValue<'a>> {
use datamodel_renderer::value::{Constant, Function, Text, Value};
let mut rendered = match field.default().kind() {
Some(kind) => match kind {
DefaultKind::Sequence(sequence) => {
let mut fun = Function::new("sequence");

let mut result = match (column.default().map(|d| d.kind()), column.column_type_family()) {
(Some(sql::DefaultKind::Sequence(name)), _) if input.is_cockroach() => {
let connector_data: &PostgresSchemaExt = input.schema.downcast_connector_data();
if sequence.min_value != 1 {
fun.push_param(("minValue", Constant::from(sequence.min_value)));
}

let sequence_idx = connector_data
.sequences
.binary_search_by_key(&name, |s| &s.name)
.unwrap();
if sequence.max_value != i64::MAX {
fun.push_param(("maxValue", Constant::from(sequence.max_value)));
}

let sequence = &connector_data.sequences[sequence_idx];
if sequence.cache_size != 1 {
fun.push_param(("cache", Constant::from(sequence.cache_size)));
}

let mut fun = Function::new("sequence");
if sequence.increment_by != 1 {
fun.push_param(("increment", Constant::from(sequence.increment_by)));
}

if sequence.min_value != 1 {
fun.push_param(("minValue", Constant::from(sequence.min_value)));
}
if sequence.start_value != 1 {
fun.push_param(("start", Constant::from(sequence.start_value)));
}

if sequence.max_value != i64::MAX {
fun.push_param(("maxValue", Constant::from(sequence.max_value)));
Some(renderer::DefaultValue::function(fun))
}
DefaultKind::DbGenerated(default_string) => {
let mut fun = Function::new("dbgenerated");

if sequence.cache_size != 1 {
fun.push_param(("cache", Constant::from(sequence.cache_size)));
}
if let Some(param) = default_string.filter(|s| !s.trim_matches('\0').is_empty()) {
fun.push_param(Value::from(Text::new(param)));
}

if sequence.increment_by != 1 {
fun.push_param(("increment", Constant::from(sequence.increment_by)));
Some(renderer::DefaultValue::function(fun))
}

if sequence.start_value != 1 {
fun.push_param(("start", Constant::from(sequence.start_value)));
DefaultKind::Autoincrement => Some(renderer::DefaultValue::function(Function::new("autoincrement"))),
DefaultKind::Uuid => Some(renderer::DefaultValue::function(Function::new("uuid"))),
DefaultKind::Cuid => Some(renderer::DefaultValue::function(Function::new("cuid"))),
DefaultKind::Now => Some(renderer::DefaultValue::function(Function::new("now"))),
DefaultKind::String(s) => Some(renderer::DefaultValue::text(s)),
DefaultKind::Constant(c) => Some(renderer::DefaultValue::constant(c)),
DefaultKind::EnumVariant(c) => Some(renderer::DefaultValue::constant(c)),
DefaultKind::Bytes(b) => Some(renderer::DefaultValue::bytes(b)),
DefaultKind::StringList(vals) => {
let vals = vals.into_iter().map(Text::new).collect();
Some(renderer::DefaultValue::array(vals))
}

Some(renderer::DefaultValue::function(fun))
}
(_, sql::ColumnTypeFamily::Int | sql::ColumnTypeFamily::BigInt) if column.is_autoincrement() => {
Some(renderer::DefaultValue::function(Function::new("autoincrement")))
}
(_, sql::ColumnTypeFamily::Int | sql::ColumnTypeFamily::BigInt) if is_sequence(column) => {
Some(renderer::DefaultValue::function(Function::new("autoincrement")))
}
(Some(sql::DefaultKind::Sequence(_)), _) => {
Some(renderer::DefaultValue::function(Function::new("autoincrement")))
}
(Some(sql::DefaultKind::UniqueRowid), _) => {
Some(renderer::DefaultValue::function(Function::new("autoincrement")))
}
(Some(sql::DefaultKind::Now), sql::ColumnTypeFamily::DateTime) => {
Some(renderer::DefaultValue::function(Function::new("now")))
}
(Some(sql::DefaultKind::DbGenerated(default_string)), _) => {
let mut fun = Function::new("dbgenerated");

if let Some(param) = default_string.as_ref().filter(|s| !s.trim_matches('\0').is_empty()) {
fun.push_param(Value::from(Text::new(param)));
DefaultKind::ConstantList(vals) => Some(renderer::DefaultValue::array(vals)),
DefaultKind::BytesList(vals) => {
let vals = vals.into_iter().map(Value::from).collect();
Some(renderer::DefaultValue::array(vals))
}

Some(renderer::DefaultValue::function(fun))
}
(Some(sql::DefaultKind::Value(dml::PrismaValue::Enum(variant))), sql::ColumnTypeFamily::Enum(enum_id)) => {
let variant = input
.schema
.walk(*enum_id)
.variants()
.find(|v| v.name() == variant)
.unwrap();

let variant_name = input.enum_variant_name(variant.id).prisma_name();
Some(renderer::DefaultValue::constant(variant_name))
}
(Some(sql::DefaultKind::Value(dml::PrismaValue::String(val))), _) => Some(renderer::DefaultValue::text(val)),
(Some(sql::DefaultKind::Value(dml::PrismaValue::List(val))), _) => {
let vals = val
.iter()
.map(|val| match val {
dml::PrismaValue::String(v) => Value::from(Text::new(v)),
dml::PrismaValue::Boolean(v) => Value::from(Constant::from(v)),
dml::PrismaValue::Enum(v) => Value::from(Constant::from(v)),
dml::PrismaValue::Int(v) => Value::from(Constant::from(v)),
dml::PrismaValue::Uuid(v) => Value::from(Constant::from(v)),
dml::PrismaValue::List(_) => unreachable!("Lists of lists are not supported in defaults."),
dml::PrismaValue::Json(v) => Value::from(Text::new(v)),
dml::PrismaValue::Xml(v) => Value::from(Text::new(v)),
dml::PrismaValue::Object(_) => unreachable!("Objects are not supported in defaults."),
dml::PrismaValue::Null => Value::from(Constant::from("null")),
dml::PrismaValue::DateTime(v) => Value::from(Constant::from(v)),
dml::PrismaValue::Float(v) => Value::from(Constant::from(v)),
dml::PrismaValue::BigInt(v) => Value::from(Constant::from(v)),
dml::PrismaValue::Bytes(v) => Value::from(v.clone()),
})
.collect();

Some(renderer::DefaultValue::array(vals))
}
(Some(sql::DefaultKind::Value(val)), _) => Some(renderer::DefaultValue::constant(val)),

// Prisma-level defaults.
(None, sql::ColumnTypeFamily::String) => match existing_field.and_then(|f| f.default_value()) {
Some(value) if value.is_cuid() => Some(renderer::DefaultValue::function(Function::new("cuid"))),
Some(value) if value.is_uuid() => Some(renderer::DefaultValue::function(Function::new("uuid"))),
None if matches!(output.version, Version::Prisma1 | Version::Prisma11) => {
maybe_prisma1_default(column, input, output)
DefaultKind::Prisma1Uuid => {
let warn = crate::warnings::ModelAndField {
model: field.model().name().to_string(),
field: field.name().to_string(),
};

output.warnings.prisma_1_uuid_defaults.push(warn);
Some(renderer::DefaultValue::function(Function::new("uuid")))
}
_ => None,
},

_ => None,
};

if let Some(res) = result.as_mut() {
let default_default_value =
ConstraintNames::default_name(column.table().name(), column.name(), input.active_connector());

match column.default().and_then(|def| def.constraint_name()) {
Some(map) if map != default_default_value => {
res.map(map);
DefaultKind::Prisma1Cuid => {
let warn = crate::warnings::ModelAndField {
model: field.model().name().to_string(),
field: field.name().to_string(),
};

output.warnings.prisma_1_cuid_defaults.push(warn);
Some(renderer::DefaultValue::function(Function::new("cuid")))
}
_ => (),
}
}

result
}

fn is_sequence(column: sql::ColumnWalker<'_>) -> bool {
column.is_single_primary_key() && matches!(&column.default(), Some(d) if d.is_sequence())
}

fn maybe_prisma1_default<'a>(
column: sql::ColumnWalker<'a>,
input: InputContext<'a>,
output: &mut OutputContext<'a>,
) -> Option<renderer::DefaultValue<'a>> {
use datamodel_renderer::value::Function;

let model_and_field = || crate::warnings::ModelAndField {
model: input.table_prisma_name(column.table().id).prisma_name().into_owned(),
field: input.column_prisma_name(column.id).prisma_name().into_owned(),
},
None => None,
};

if input.sql_family.is_postgres() {
let native_type: &PostgresType = column.column_type().native_type.as_ref()?.downcast_ref();

if native_type == &PostgresType::VarChar(Some(25)) {
output.warnings.prisma_1_cuid_defaults.push(model_and_field());

return Some(renderer::DefaultValue::function(Function::new("cuid")));
} else if native_type == &PostgresType::VarChar(Some(36)) {
output.warnings.prisma_1_uuid_defaults.push(model_and_field());

return Some(renderer::DefaultValue::function(Function::new("uuid")));
}
} else if input.sql_family.is_mysql() {
let native_type: &MySqlType = column.column_type().native_type.as_ref()?.downcast_ref();

if native_type == &MySqlType::Char(25) {
output.warnings.prisma_1_cuid_defaults.push(model_and_field());

return Some(renderer::DefaultValue::function(Function::new("cuid")));
} else if native_type == &MySqlType::Char(36) {
output.warnings.prisma_1_uuid_defaults.push(model_and_field());

return Some(renderer::DefaultValue::function(Function::new("uuid")));
if let Some(res) = rendered.as_mut() {
if let Some(mapped_name) = field.default().mapped_name() {
res.map(mapped_name);
}
}

None
rendered
}