Skip to content

Commit

Permalink
refactor(datamodel): move all relation reformatting logic to reformat.rs
Browse files Browse the repository at this point in the history
This makes it explicitly a private concern of the reformatter that the
rest of the code does not need to worry about. As a side effect, we do
not allocate anymore for potentially inferred fields in validations, and
some validations as well as lifting gets simpler.

closes prisma/prisma#13742 as a side effect
  • Loading branch information
tomhoule committed Jun 28, 2022
1 parent da41d2b commit 1d7fa24
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 238 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,10 @@ impl<'db> InlineRelationWalkerExt<'db> for InlineRelationWalker<'db> {
fn constraint_name(self, connector: &dyn Connector) -> Cow<'db, str> {
self.mapped_name().map(Cow::Borrowed).unwrap_or_else(|| {
let model_database_name = self.referencing_model().database_name();
let field_names: Vec<&str> = match self.referencing_fields() {
ReferencingFields::Concrete(fields) => fields.map(|f| f.database_name()).collect(),
ReferencingFields::Inferred(fields) => {
let field_names: Vec<_> = fields.iter().map(|f| f.name.as_str()).collect();
return ConstraintNames::foreign_key_constraint_name(model_database_name, &field_names, connector)
.into();
}
ReferencingFields::NA => Vec::new(),
};
let field_names: Vec<&str> = self
.referencing_fields()
.map(|fields| fields.map(|f| f.database_name()).collect())
.unwrap_or_default();
ConstraintNames::foreign_key_constraint_name(model_database_name, &field_names, connector).into()
})
}
Expand Down
145 changes: 113 additions & 32 deletions libs/datamodel/core/src/reformat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -791,9 +791,9 @@ fn push_inline_relation_missing_arguments(
if let Some(forward) = inline_relation.forward_relation_field() {
// the `fields: [...]` argument.
match inline_relation.referencing_fields() {
walkers::ReferencingFields::Concrete(_) => (),
walkers::ReferencingFields::NA => (), // error somewhere else
walkers::ReferencingFields::Inferred(fields) => {
Some(_) => (),
None => {
let fields: Vec<InferredScalarField<'_>> = infer_missing_referencing_scalar_fields(inline_relation);
let missing_arg = MissingRelationAttributeArg {
model: forward.model().name().to_owned(),
field: forward.ast_field().name.name.to_owned(),
Expand Down Expand Up @@ -857,21 +857,22 @@ fn push_missing_relation_attribute(

// the `fields: [...]` argument.
let fields: Option<ast::Argument> = match inline_relation.referencing_fields() {
walkers::ReferencingFields::Concrete(_) => None,
walkers::ReferencingFields::NA => None, // error somewhere else
walkers::ReferencingFields::Inferred(fields) => Some(ast::Argument {
name: Some(ast::Identifier::new("fields")),
value: ast::Expression::Array(
fields
.into_iter()
.map(|f| ast::Expression::ConstantValue(f.name, Span::empty()))
.collect(),
Span::empty(),
),
span: Span::empty(),
}),
Some(_) => None,
None => {
let fields = infer_missing_referencing_scalar_fields(inline_relation);
Some(ast::Argument {
name: Some(ast::Identifier::new("fields")),
value: ast::Expression::Array(
fields
.into_iter()
.map(|f| ast::Expression::ConstantValue(f.name, Span::empty()))
.collect(),
Span::empty(),
),
span: Span::empty(),
})
}
};

// the `references: [...]` argument
let references: Option<ast::Argument> = if forward.referenced_fields().is_none() {
Some(ast::Argument {
Expand Down Expand Up @@ -957,24 +958,18 @@ fn push_missing_relation_fields(inline: walkers::InlineRelationWalker<'_>, missi
field: ast::Field {
field_type: ast::FieldType::Supported(ast::Identifier::new(inline.referenced_model().name())),
name: ast::Identifier::new(inline.referenced_model().name()),
arity: inline.forward_relation_field_arity(),
arity: forward_relation_field_arity(inline),
attributes: vec![ast::Attribute {
name: ast::Identifier::new("relation"),
arguments: ast::ArgumentsList {
arguments: vec![
ast::Argument {
name: Some(ast::Identifier::new("fields")),
value: ast::Expression::Array(
match inline.referencing_fields() {
walkers::ReferencingFields::Concrete(fields) => fields
.map(|f| ast::Expression::ConstantValue(f.name().to_owned(), Span::empty()))
.collect(),
walkers::ReferencingFields::Inferred(fields) => fields
.into_iter()
.map(|f| ast::Expression::ConstantValue(f.name, Span::empty()))
.collect(),
walkers::ReferencingFields::NA => Vec::new(),
},
infer_missing_referencing_scalar_fields(inline)
.into_iter()
.map(|f| ast::Expression::ConstantValue(f.name, Span::empty()))
.collect(),
Span::empty(),
),
span: Span::empty(),
Expand Down Expand Up @@ -1005,9 +1000,9 @@ fn push_missing_relation_fields(inline: walkers::InlineRelationWalker<'_>, missi
}

fn push_missing_scalar_fields(inline: walkers::InlineRelationWalker<'_>, missing_fields: &mut Vec<MissingField>) {
let missing_scalar_fields = match inline.referencing_fields() {
walkers::ReferencingFields::Inferred(inferred_fields) => inferred_fields,
_ => return,
let missing_scalar_fields: Vec<InferredScalarField<'_>> = match inline.referencing_fields() {
Some(_) => return,
None => infer_missing_referencing_scalar_fields(inline),
};

// Filter out duplicate fields
Expand Down Expand Up @@ -1045,7 +1040,7 @@ fn push_missing_scalar_fields(inline: walkers::InlineRelationWalker<'_>, missing
field: ast::Field {
field_type: ast::FieldType::Supported(ast::Identifier::new(field_type.as_str())),
name: ast::Identifier::new(&field.name),
arity: inline.forward_relation_field_arity(),
arity: field.arity,
attributes,
documentation: None,
span: Span::empty(),
Expand All @@ -1054,3 +1049,89 @@ fn push_missing_scalar_fields(inline: walkers::InlineRelationWalker<'_>, missing
})
}
}

/// A scalar inferred by magic reformatting.
#[derive(Debug)]
struct InferredScalarField<'db> {
name: String,
arity: ast::FieldArity,
tpe: parser_database::ScalarFieldType,
blueprint: walkers::ScalarFieldWalker<'db>,
}

fn infer_missing_referencing_scalar_fields(inline: walkers::InlineRelationWalker<'_>) -> Vec<InferredScalarField<'_>> {
match inline.referenced_model().unique_criterias().next() {
Some(first_unique_criteria) => {
first_unique_criteria
.fields()
.map(|field| {
let name = format!(
"{}{}",
camel_case(inline.referenced_model().name()),
pascal_case(field.name())
);

// we cannot have composite fields in a relation for now.
let field = field.as_scalar_field().unwrap();

if let Some(existing_field) =
inline.referencing_model().scalar_fields().find(|sf| sf.name() == name)
{
InferredScalarField {
name,
arity: existing_field.ast_field().arity,
tpe: existing_field.scalar_field_type(),
blueprint: field,
}
} else {
InferredScalarField {
name,
arity: inline
.forward_relation_field()
.map(|f| f.ast_field().arity)
.unwrap_or(ast::FieldArity::Optional),
tpe: field.scalar_field_type(),
blueprint: field,
}
}
})
.collect()
}
None => Vec::new(),
}
}

fn pascal_case(input: &str) -> String {
let mut c = input.chars();
match c.next() {
None => String::new(),
Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
}
}

fn camel_case(input: &str) -> String {
let mut c = input.chars();
match c.next() {
None => String::new(),
Some(f) => f.to_lowercase().collect::<String>() + c.as_str(),
}
}

/// The arity of the forward relation field. Works even without forward relation field.
fn forward_relation_field_arity(inline: walkers::InlineRelationWalker<'_>) -> ast::FieldArity {
inline
// First use the relation field itself if it exists.
.forward_relation_field()
.map(|rf| rf.ast_field().arity)
// Otherwise, if we have fields that look right on the model, use these.
.unwrap_or_else(|| {
if infer_missing_referencing_scalar_fields(inline)
.into_iter()
.any(|f| f.arity.is_optional())
{
ast::FieldArity::Optional
} else {
ast::FieldArity::Required
}
})
}
105 changes: 33 additions & 72 deletions libs/datamodel/core/src/transform/ast_to_dml/lift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,39 +125,25 @@ impl<'a> LiftAstToDml<'a> {
}

let relation_info = dml::RelationInfo::new(relation.referenced_model().name());
let model = schema.find_model_mut(relation.referencing_model().name());

// reformatted/virtual/inferred extra scalar fields for reformatted relations.
let mut inferred_scalar_fields = Vec::new();

let mut relation_field = if let Some(relation_field) = relation.forward_relation_field() {
// Construct a relation field in the DML for an existing relation field in the source.
let arity = self.lift_field_arity(&relation_field.ast_field().arity);
let referential_arity = self.lift_field_arity(&relation_field.referential_arity());
let mut field =
dml::RelationField::new(relation_field.name(), arity, referential_arity, relation_info);

field.relation_info.fk_name = Some(relation.constraint_name(active_connector).into_owned());
common_dml_fields(&mut field, relation_field);
field_ids_for_sorting.insert(
(relation_field.model().name(), relation_field.name()),
relation_field.field_id(),
);
let forward_field_walker = relation.forward_relation_field().unwrap();
// Construct a relation field in the DML for an existing relation field in the source.
let arity = self.lift_arity(&forward_field_walker.ast_field().arity);
let referential_arity = self.lift_arity(&forward_field_walker.referential_arity());
let mut relation_field = dml::RelationField::new(
forward_field_walker.name(),
arity,
referential_arity,
relation_info,
);

field
} else {
// Construct a relation field in the DML without corresponding relation field in the source.
//
// This is part of magic reformatting.
let arity = self.lift_field_arity(&relation.forward_relation_field_arity());
let referential_arity = arity;
dml::RelationField::new(
relation.referenced_model().name(),
arity,
referential_arity,
relation_info,
)
};
relation_field.relation_info.fk_name =
Some(relation.constraint_name(active_connector).into_owned());
common_dml_fields(&mut relation_field, forward_field_walker);
field_ids_for_sorting.insert(
(forward_field_walker.model().name(), forward_field_walker.name()),
forward_field_walker.field_id(),
);

relation_field.relation_info.name = relation.relation_name().to_string();

Expand All @@ -166,39 +152,14 @@ impl<'a> LiftAstToDml<'a> {
.map(|field| field.name().to_owned())
.collect();

relation_field.relation_info.fields = match relation.referencing_fields() {
ReferencingFields::Concrete(fields) => fields.map(|f| f.name().to_owned()).collect(),
ReferencingFields::Inferred(fields) => {
// In this branch, we are creating the underlying scalar fields
// from thin air. This is part of reformatting.
let mut field_names = Vec::with_capacity(fields.len());

for field in fields {
let field_type = self.lift_scalar_field_type(
field.blueprint.ast_field(),
&field.tpe,
field.blueprint,
);
let mut scalar_field = dml::ScalarField::new_generated(&field.name, field_type);
scalar_field.arity = if relation_field.arity.is_required() {
dml::FieldArity::Required
} else {
self.lift_field_arity(&field.arity)
};
inferred_scalar_fields.push(dml::Field::ScalarField(scalar_field));

field_names.push(field.name);
}

field_names
}
ReferencingFields::NA => Vec::new(),
};
model.add_field(dml::Field::RelationField(relation_field));
relation_field.relation_info.fields = relation
.referencing_fields()
.unwrap()
.map(|f| f.name().to_owned())
.collect();

for field in inferred_scalar_fields {
model.add_field(field)
}
let model = schema.find_model_mut(relation.referencing_model().name());
model.add_field(dml::Field::RelationField(relation_field));
};

// Back field
Expand All @@ -208,8 +169,8 @@ impl<'a> LiftAstToDml<'a> {

let mut field = if let Some(relation_field) = relation.back_relation_field() {
let ast_field = relation_field.ast_field();
let arity = self.lift_field_arity(&ast_field.arity);
let referential_arity = self.lift_field_arity(&relation_field.referential_arity());
let arity = self.lift_arity(&ast_field.arity);
let referential_arity = self.lift_arity(&relation_field.referential_arity());
let mut field =
dml::RelationField::new(relation_field.name(), arity, referential_arity, relation_info);

Expand Down Expand Up @@ -242,9 +203,9 @@ impl<'a> LiftAstToDml<'a> {
RefinedRelationWalker::ImplicitManyToMany(relation) => {
for relation_field in [relation.field_a(), relation.field_b()] {
let ast_field = relation_field.ast_field();
let arity = self.lift_field_arity(&ast_field.arity);
let arity = self.lift_arity(&ast_field.arity);
let relation_info = dml::RelationInfo::new(relation_field.related_model().name());
let referential_arity = self.lift_field_arity(&relation_field.referential_arity());
let referential_arity = self.lift_arity(&relation_field.referential_arity());
let mut field =
dml::RelationField::new(relation_field.name(), arity, referential_arity, relation_info);

Expand Down Expand Up @@ -272,9 +233,9 @@ impl<'a> LiftAstToDml<'a> {
RefinedRelationWalker::TwoWayEmbeddedManyToMany(relation) => {
for relation_field in [relation.field_a(), relation.field_b()] {
let ast_field = relation_field.ast_field();
let arity = self.lift_field_arity(&ast_field.arity);
let arity = self.lift_arity(&ast_field.arity);
let relation_info = dml::RelationInfo::new(relation_field.related_model().name());
let referential_arity = self.lift_field_arity(&relation_field.referential_arity());
let referential_arity = self.lift_arity(&relation_field.referential_arity());

let mut field =
dml::RelationField::new(relation_field.name(), arity, referential_arity, relation_info);
Expand Down Expand Up @@ -314,7 +275,7 @@ impl<'a> LiftAstToDml<'a> {
let field = CompositeTypeField {
name: field.name().to_owned(),
r#type: self.lift_composite_type_field_type(field, field.r#type()),
arity: self.lift_field_arity(&field.arity()),
arity: self.lift_arity(&field.arity()),
database_name: field.mapped_name().map(String::from),
documentation: field.documentation().map(ToString::to_string),
default_value: field.default_value().map(|value| dml::DefaultValue {
Expand Down Expand Up @@ -423,7 +384,7 @@ impl<'a> LiftAstToDml<'a> {
for scalar_field in walker.scalar_fields() {
let field_id = scalar_field.field_id();
let ast_field = &ast_model[field_id];
let arity = self.lift_field_arity(&ast_field.arity);
let arity = self.lift_arity(&ast_field.arity);

field_ids_for_sorting.insert((&ast_model.name.name, &ast_field.name.name), field_id);

Expand Down Expand Up @@ -483,7 +444,7 @@ impl<'a> LiftAstToDml<'a> {
}

/// Internal: Lift a field's arity.
fn lift_field_arity(&self, field_arity: &ast::FieldArity) -> dml::FieldArity {
fn lift_arity(&self, field_arity: &ast::FieldArity) -> dml::FieldArity {
match field_arity {
ast::FieldArity::Required => dml::FieldArity::Required,
ast::FieldArity::Optional => dml::FieldArity::Optional,
Expand Down

0 comments on commit 1d7fa24

Please sign in to comment.