From 454112fd466075270cefb55d4659bbd30345e604 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 15 Nov 2024 19:41:18 +0100 Subject: [PATCH] feat(psl): add cuid(2) support, fix uuid(7) re-introspection --- psl/parser-database/src/attributes/default.rs | 41 +++++++++++---- psl/parser-database/src/generators.rs | 13 +++++ psl/parser-database/src/lib.rs | 1 + psl/psl-core/src/lib.rs | 2 +- psl/psl/src/lib.rs | 1 + psl/psl/tests/attributes/id_positive.rs | 43 ++++++++++++++- psl/psl/tests/common/asserts.rs | 50 +++++++++++++++++- .../query-structure/src/default_value.rs | 32 ++++++++---- .../query-structure/src/field/scalar.rs | 13 +++-- .../tests/datamodel_converter_tests.rs | 52 +++++++++++++++++++ .../introspection_pair/default.rs | 24 +++++++-- .../src/introspection/rendering/defaults.rs | 20 ++++++- .../tests/re_introspection/mod.rs | 17 ++++-- 13 files changed, 272 insertions(+), 37 deletions(-) create mode 100644 psl/parser-database/src/generators.rs diff --git a/psl/parser-database/src/attributes/default.rs b/psl/parser-database/src/attributes/default.rs index d1f6b887fa22..fd5181b284e3 100644 --- a/psl/parser-database/src/attributes/default.rs +++ b/psl/parser-database/src/attributes/default.rs @@ -2,6 +2,7 @@ use crate::{ ast::{self, WithName}, coerce, context::Context, + generators::{CUID_SUPPORTED_VERSIONS, UUID_SUPPORTED_VERSIONS}, types::{DefaultAttribute, ScalarFieldType, ScalarType}, DatamodelError, ScalarFieldId, StringId, }; @@ -197,10 +198,10 @@ fn validate_model_builtin_scalar_type_default( validate_empty_function_args(funcname, &funcargs.arguments, accept, ctx) } (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_CUID => { - validate_empty_function_args(funcname, &funcargs.arguments, accept, ctx) + validate_uid_int_args(funcname, &funcargs.arguments, &CUID_SUPPORTED_VERSIONS, accept, ctx) } (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_UUID => { - validate_uuid_args(&funcargs.arguments, accept, ctx) + validate_uid_int_args(funcname, &funcargs.arguments, &UUID_SUPPORTED_VERSIONS, accept, ctx) } (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_NANOID => { validate_nanoid_args(&funcargs.arguments, accept, ctx) @@ -244,10 +245,10 @@ fn validate_composite_builtin_scalar_type_default( match (scalar_type, value) { // Functions (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_CUID => { - validate_empty_function_args(funcname, &funcargs.arguments, accept, ctx) + validate_uid_int_args(funcname, &funcargs.arguments, &CUID_SUPPORTED_VERSIONS, accept, ctx) } (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_UUID => { - validate_uuid_args(&funcargs.arguments, accept, ctx) + validate_uid_int_args(funcname, &funcargs.arguments, &UUID_SUPPORTED_VERSIONS, accept, ctx) } (ScalarType::DateTime, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_NOW => { validate_empty_function_args(FN_NOW, &funcargs.arguments, accept, ctx) @@ -381,18 +382,38 @@ fn validate_dbgenerated_args(args: &[ast::Argument], accept: AcceptFn<'_>, ctx: } } -fn validate_uuid_args(args: &[ast::Argument], accept: AcceptFn<'_>, ctx: &mut Context<'_>) { - let mut bail = || ctx.push_attribute_validation_error("`uuid()` takes a single Int argument."); +fn format_valid_values(valid_values: &[u8; N]) -> String { + match valid_values.len() { + 0 => String::new(), + 1 => valid_values[0].to_string(), + 2 => format!("{} or {}", valid_values[0], valid_values[1]), + _ => { + let (last, rest) = valid_values.split_last().unwrap(); + let rest_str = rest.iter().map(|v| v.to_string()).collect::>().join(", "); + format!("{}, or {}", rest_str, last) + } + } +} + +fn validate_uid_int_args( + fn_name: &str, + args: &[ast::Argument], + valid_values: &[u8; N], + accept: AcceptFn<'_>, + ctx: &mut Context<'_>, +) { + let mut bail = || ctx.push_attribute_validation_error(&format!("`{fn_name}()` takes a single Int argument.")); if args.len() > 1 { bail() } match args.first().map(|arg| &arg.value) { - Some(ast::Expression::NumericValue(val, _)) if ![4u8, 7u8].contains(&val.parse::().unwrap()) => { - ctx.push_attribute_validation_error( - "`uuid()` takes either no argument, or a single integer argument which is either 4 or 7.", - ); + Some(ast::Expression::NumericValue(val, _)) if !valid_values.contains(&val.parse::().unwrap()) => { + let valid_values_str = format_valid_values(&valid_values); + ctx.push_attribute_validation_error(&format!( + "`{fn_name}()` takes either no argument, or a single integer argument which is either {valid_values_str}.", + )); } None | Some(ast::Expression::NumericValue(_, _)) => accept(ctx), _ => bail(), diff --git a/psl/parser-database/src/generators.rs b/psl/parser-database/src/generators.rs new file mode 100644 index 000000000000..5e7cfe16f9bc --- /dev/null +++ b/psl/parser-database/src/generators.rs @@ -0,0 +1,13 @@ +//! Convenient access to a ID generator constants, used by Prisma in psl, Query Engine and Schema Engine. + +/// Version of the `uuid()` ID generator supported by Prisma. +pub const UUID_SUPPORTED_VERSIONS: [u8; 2] = [4, 7]; + +/// Version of the `cuid()` ID generator supported by Prisma. +pub const CUID_SUPPORTED_VERSIONS: [u8; 2] = [1, 2]; + +/// Default version of the `uuid()` ID generator. +pub const DEFAULT_UUID_VERSION: u8 = 4; + +/// Default version of the `cuid()` ID generator. +pub const DEFAULT_CUID_VERSION: u8 = 2; diff --git a/psl/parser-database/src/lib.rs b/psl/parser-database/src/lib.rs index 5ada8cebb961..087abc42c970 100644 --- a/psl/parser-database/src/lib.rs +++ b/psl/parser-database/src/lib.rs @@ -32,6 +32,7 @@ mod attributes; mod coerce_expression; mod context; mod files; +pub mod generators; mod ids; mod interner; mod names; diff --git a/psl/psl-core/src/lib.rs b/psl/psl-core/src/lib.rs index 304ef3cc5cfe..e4b36dd15e92 100644 --- a/psl/psl-core/src/lib.rs +++ b/psl/psl-core/src/lib.rs @@ -25,7 +25,7 @@ pub use crate::{ reformat::{reformat, reformat_multiple, reformat_validated_schema_into_single}, }; pub use diagnostics; -pub use parser_database::{self, is_reserved_type_name}; +pub use parser_database::{self, generators, is_reserved_type_name}; pub use schema_ast; pub use set_config_dir::set_config_dir; diff --git a/psl/psl/src/lib.rs b/psl/psl/src/lib.rs index a0085d2b790d..b58b84d95546 100644 --- a/psl/psl/src/lib.rs +++ b/psl/psl/src/lib.rs @@ -7,6 +7,7 @@ pub use psl_core::{ builtin_connectors::{can_have_capability, can_support_relation_load_strategy, has_capability}, datamodel_connector, diagnostics::{self, Diagnostics}, + generators, is_reserved_type_name, mcf::config_to_mcf_json_value as get_config, mcf::{generators_to_json, render_sources_to_json}, // for tests diff --git a/psl/psl/tests/attributes/id_positive.rs b/psl/psl/tests/attributes/id_positive.rs index e81c1305c274..e6e7ec010295 100644 --- a/psl/psl/tests/attributes/id_positive.rs +++ b/psl/psl/tests/attributes/id_positive.rs @@ -50,6 +50,45 @@ fn should_allow_string_ids_with_cuid() { model.assert_id_on_fields(&["id"]); } +#[test] +fn should_allow_string_ids_with_cuid_version_specified() { + let dml = indoc! {r#" + model ModelA { + id String @id @default(cuid(1)) + } + + model ModelB { + id String @id @default(cuid(2)) + } + "#}; + + let schema = psl::parse_schema(dml).unwrap(); + + { + let model = schema.assert_has_model("ModelA"); + + model + .assert_has_scalar_field("id") + .assert_scalar_type(ScalarType::String) + .assert_default_value() + .assert_cuid_version(1); + + model.assert_id_on_fields(&["id"]); + } + + { + let model = schema.assert_has_model("ModelB"); + + model + .assert_has_scalar_field("id") + .assert_scalar_type(ScalarType::String) + .assert_default_value() + .assert_cuid_version(2); + + model.assert_id_on_fields(&["id"]); + } +} + #[test] fn should_allow_string_ids_with_uuid() { let dml = indoc! {r#" @@ -91,7 +130,7 @@ fn should_allow_string_ids_with_uuid_version_specified() { .assert_has_scalar_field("id") .assert_scalar_type(ScalarType::String) .assert_default_value() - .assert_uuid(); + .assert_uuid_version(4); model.assert_id_on_fields(&["id"]); } @@ -103,7 +142,7 @@ fn should_allow_string_ids_with_uuid_version_specified() { .assert_has_scalar_field("id") .assert_scalar_type(ScalarType::String) .assert_default_value() - .assert_uuid(); + .assert_uuid_version(7); model.assert_id_on_fields(&["id"]); } diff --git a/psl/psl/tests/common/asserts.rs b/psl/psl/tests/common/asserts.rs index e75df7c2cd7b..accb524650a6 100644 --- a/psl/psl/tests/common/asserts.rs +++ b/psl/psl/tests/common/asserts.rs @@ -83,7 +83,9 @@ pub(crate) trait DefaultValueAssert { fn assert_bytes(&self, val: &[u8]) -> &Self; fn assert_now(&self) -> &Self; fn assert_cuid(&self) -> &Self; + fn assert_cuid_version(&self, version: u8) -> &Self; fn assert_uuid(&self) -> &Self; + fn assert_uuid_version(&self, version: u8) -> &Self; fn assert_dbgenerated(&self, val: &str) -> &Self; fn assert_mapped_name(&self, val: &str) -> &Self; } @@ -433,12 +435,24 @@ impl<'a> DefaultValueAssert for walkers::DefaultValueWalker<'a> { self } + #[track_caller] + fn assert_cuid_version(&self, version: u8) -> &Self { + self.value().assert_cuid_version(version); + self + } + #[track_caller] fn assert_uuid(&self) -> &Self { self.value().assert_uuid(); self } + #[track_caller] + fn assert_uuid_version(&self, version: u8) -> &Self { + self.value().assert_uuid_version(version); + self + } + #[track_caller] fn assert_dbgenerated(&self, val: &str) -> &Self { self.value().assert_dbgenerated(val); @@ -623,12 +637,29 @@ impl DefaultValueAssert for ast::Expression { #[track_caller] fn assert_cuid(&self) -> &Self { assert!( - matches!(self, ast::Expression::Function(name, args, _) if name == "cuid" && args.arguments.is_empty()) + matches!(self, ast::Expression::Function(name, _, _) if name == "cuid" /* && args.arguments.is_empty() */) ); self } + #[track_caller] + fn assert_cuid_version(&self, version: u8) -> &Self { + self.assert_cuid(); + + if let ast::Expression::Function(_, args, _) = self { + if let ast::Expression::NumericValue(actual, _) = &args.arguments[0].value { + assert_eq!(actual, &format!("{version}")); + } else { + panic!("Expected a numeric value for the version."); + } + } else { + unreachable!(); + } + + self + } + #[track_caller] fn assert_uuid(&self) -> &Self { assert!(matches!(self, ast::Expression::Function(name, _, _) if name == "uuid")); @@ -636,6 +667,23 @@ impl DefaultValueAssert for ast::Expression { self } + #[track_caller] + fn assert_uuid_version(&self, version: u8) -> &Self { + self.assert_uuid(); + + if let ast::Expression::Function(_, args, _) = self { + if let ast::Expression::NumericValue(actual, _) = &args.arguments[0].value { + assert_eq!(actual, &format!("{version}")); + } else { + panic!("Expected a numeric value for the version."); + } + } else { + unreachable!(); + } + + self + } + #[track_caller] fn assert_dbgenerated(&self, val: &str) -> &Self { match self { diff --git a/query-engine/query-structure/src/default_value.rs b/query-engine/query-structure/src/default_value.rs index 605224909e31..5afe2a5b7181 100644 --- a/query-engine/query-structure/src/default_value.rs +++ b/query-engine/query-structure/src/default_value.rs @@ -182,8 +182,8 @@ impl ValueGenerator { ValueGenerator::new("now".to_owned(), vec![]).unwrap() } - pub fn new_cuid() -> Self { - ValueGenerator::new("cuid".to_owned(), vec![]).unwrap() + pub fn new_cuid(version: u8) -> Self { + ValueGenerator::new(format!("cuid({version})"), vec![]).unwrap() } pub fn new_uuid(version: u8) -> Self { @@ -239,7 +239,7 @@ impl ValueGenerator { #[derive(Clone, Copy, PartialEq)] pub enum ValueGeneratorFn { Uuid(u8), - Cuid, + Cuid(u8), Nanoid(Option), Now, Autoincrement, @@ -250,7 +250,8 @@ pub enum ValueGeneratorFn { impl ValueGeneratorFn { fn new(name: &str) -> std::result::Result { match name { - "cuid" => Ok(Self::Cuid), + "cuid" | "cuid(2)" => Ok(Self::Cuid(2)), + "cuid(1)" => Ok(Self::Cuid(1)), "uuid" | "uuid(4)" => Ok(Self::Uuid(4)), "uuid(7)" => Ok(Self::Uuid(7)), "now" => Ok(Self::Now), @@ -267,7 +268,7 @@ impl ValueGeneratorFn { fn invoke(&self) -> Option { match self { Self::Uuid(version) => Some(Self::generate_uuid(*version)), - Self::Cuid => Some(Self::generate_cuid()), + Self::Cuid(version) => Some(Self::generate_cuid(*version)), Self::Nanoid(length) => Some(Self::generate_nanoid(length)), Self::Now => Some(Self::generate_now()), Self::Autoincrement => None, @@ -277,9 +278,13 @@ impl ValueGeneratorFn { } #[cfg(feature = "default_generators")] - fn generate_cuid() -> PrismaValue { - #[allow(deprecated)] - PrismaValue::String(cuid::cuid().unwrap()) + fn generate_cuid(version: u8) -> PrismaValue { + PrismaValue::String(match version { + #[allow(deprecated)] + 1 => cuid::cuid().unwrap(), + 2 => cuid::cuid2(), + _ => panic!("Unknown `cuid` version: {}", version), + }) } #[cfg(feature = "default_generators")] @@ -358,8 +363,15 @@ mod tests { } #[test] - fn default_value_is_cuid() { - let cuid_default = DefaultValue::new_expression(ValueGenerator::new_cuid()); + fn default_value_is_cuidv1() { + let cuid_default = DefaultValue::new_expression(ValueGenerator::new_cuid(1)); + + assert!(cuid_default.is_cuid()); + assert!(!cuid_default.is_now()); + } + + fn default_value_is_cuidv2() { + let cuid_default = DefaultValue::new_expression(ValueGenerator::new_cuid(2)); assert!(cuid_default.is_cuid()); assert!(!cuid_default.is_now()); diff --git a/query-engine/query-structure/src/field/scalar.rs b/query-engine/query-structure/src/field/scalar.rs index 5fc10acddd13..8027715a3f9c 100644 --- a/query-engine/query-structure/src/field/scalar.rs +++ b/query-engine/query-structure/src/field/scalar.rs @@ -1,6 +1,7 @@ use crate::{ast, parent_container::ParentContainer, prelude::*, DefaultKind, NativeTypeInstance, ValueGenerator}; use chrono::{DateTime, FixedOffset}; use psl::{ + generators::{DEFAULT_CUID_VERSION, DEFAULT_UUID_VERSION}, parser_database::{self as db, walkers, ScalarFieldType, ScalarType}, schema_ast::ast::FieldArity, }; @@ -250,12 +251,18 @@ pub fn dml_default_kind(default_value: &ast::Expression, scalar_type: Option().unwrap()) - .unwrap_or(4); + .unwrap_or(DEFAULT_UUID_VERSION); DefaultKind::Expression(ValueGenerator::new_uuid(version)) } - ast::Expression::Function(funcname, _args, _) if funcname == "cuid" => { - DefaultKind::Expression(ValueGenerator::new_cuid()) + ast::Expression::Function(funcname, args, _) if funcname == "cuid" => { + let version = args + .arguments + .first() + .and_then(|arg| arg.value.as_numeric_value()) + .map(|(val, _)| val.parse::().unwrap()) + .unwrap_or(DEFAULT_CUID_VERSION); + DefaultKind::Expression(ValueGenerator::new_cuid(version)) } ast::Expression::Function(funcname, args, _) if funcname == "nanoid" => { DefaultKind::Expression(ValueGenerator::new_nanoid( diff --git a/query-engine/query-structure/tests/datamodel_converter_tests.rs b/query-engine/query-structure/tests/datamodel_converter_tests.rs index 31f00976378e..8a84b74b2622 100644 --- a/query-engine/query-structure/tests/datamodel_converter_tests.rs +++ b/query-engine/query-structure/tests/datamodel_converter_tests.rs @@ -333,6 +333,32 @@ fn uuid_fields_must_work() { "#, ); + let model = datamodel.assert_model("Test"); + model + .assert_scalar_field("id") + .assert_type_identifier(TypeIdentifier::String); + + let datamodel = convert( + r#" + model Test { + id String @id @default(uuid(4)) + } + "#, + ); + + let model = datamodel.assert_model("Test"); + model + .assert_scalar_field("id") + .assert_type_identifier(TypeIdentifier::String); + + let datamodel = convert( + r#" + model Test { + id String @id @default(uuid(7)) + } + "#, + ); + let model = datamodel.assert_model("Test"); model .assert_scalar_field("id") @@ -349,6 +375,32 @@ fn cuid_fields_must_work() { "#, ); + let model = datamodel.assert_model("Test"); + model + .assert_scalar_field("id") + .assert_type_identifier(TypeIdentifier::String); + + let datamodel = convert( + r#" + model Test { + id String @id @default(cuid(1)) + } + "#, + ); + + let model = datamodel.assert_model("Test"); + model + .assert_scalar_field("id") + .assert_type_identifier(TypeIdentifier::String); + + let datamodel = convert( + r#" + model Test { + id String @id @default(cuid(2)) + } + "#, + ); + let model = datamodel.assert_model("Test"); model .assert_scalar_field("id") diff --git a/schema-engine/connectors/sql-schema-connector/src/introspection/introspection_pair/default.rs b/schema-engine/connectors/sql-schema-connector/src/introspection/introspection_pair/default.rs index dcb04371439b..e19a946ad5a4 100644 --- a/schema-engine/connectors/sql-schema-connector/src/introspection/introspection_pair/default.rs +++ b/schema-engine/connectors/sql-schema-connector/src/introspection/introspection_pair/default.rs @@ -14,8 +14,8 @@ pub(crate) enum DefaultKind<'a> { Sequence(&'a sql::postgres::Sequence), DbGenerated(Option<&'a str>), Autoincrement, - Uuid, - Cuid, + Uuid(Option), + Cuid(Option), Nanoid(Option), Now, String(&'a str), @@ -116,8 +116,24 @@ impl<'a> DefaultValuePair<'a> { }, (None, sql::ColumnTypeFamily::String | sql::ColumnTypeFamily::Uuid) => match self.previous { - Some(previous) if previous.is_cuid() => Some(DefaultKind::Cuid), - Some(previous) if previous.is_uuid() => Some(DefaultKind::Uuid), + Some(previous) if previous.is_cuid() => { + let version = previous.value().as_function().and_then(|(_, args, _)| { + args.arguments + .first() + .map(|arg| arg.value.as_numeric_value().unwrap().0.parse::().unwrap()) + }); + + Some(DefaultKind::Cuid(version)) + } + Some(previous) if previous.is_uuid() => { + let version = previous.value().as_function().and_then(|(_, args, _)| { + args.arguments + .first() + .map(|arg| arg.value.as_numeric_value().unwrap().0.parse::().unwrap()) + }); + + Some(DefaultKind::Uuid(version)) + } Some(previous) if previous.is_nanoid() => { let length = previous.value().as_function().and_then(|(_, args, _)| { args.arguments diff --git a/schema-engine/connectors/sql-schema-connector/src/introspection/rendering/defaults.rs b/schema-engine/connectors/sql-schema-connector/src/introspection/rendering/defaults.rs index ad8ef45e192d..748cf77c6883 100644 --- a/schema-engine/connectors/sql-schema-connector/src/introspection/rendering/defaults.rs +++ b/schema-engine/connectors/sql-schema-connector/src/introspection/rendering/defaults.rs @@ -45,8 +45,24 @@ pub(crate) fn render(default: DefaultValuePair<'_>) -> Option Some(renderer::DefaultValue::function(Function::new("autoincrement"))), - DefaultKind::Uuid => Some(renderer::DefaultValue::function(Function::new("uuid"))), - DefaultKind::Cuid => Some(renderer::DefaultValue::function(Function::new("cuid"))), + DefaultKind::Uuid(version) => { + let mut fun = Function::new("uuid"); + + if let Some(version) = version { + fun.push_param(Value::from(Constant::from(version))); + } + + Some(renderer::DefaultValue::function(fun)) + } + DefaultKind::Cuid(version) => { + let mut fun = Function::new("cuid"); + + if let Some(version) = version { + fun.push_param(Value::from(Constant::from(version))); + } + + Some(renderer::DefaultValue::function(fun)) + } DefaultKind::Nanoid(length) => { let mut fun = Function::new("nanoid"); diff --git a/schema-engine/sql-introspection-tests/tests/re_introspection/mod.rs b/schema-engine/sql-introspection-tests/tests/re_introspection/mod.rs index a7c6e1897b95..f4b23e31d8d0 100644 --- a/schema-engine/sql-introspection-tests/tests/re_introspection/mod.rs +++ b/schema-engine/sql-introspection-tests/tests/re_introspection/mod.rs @@ -969,16 +969,19 @@ async fn multiple_changed_relation_names_due_to_mapped_models(api: &mut TestApi) } #[test_connector(tags(Postgres), exclude(CockroachDb))] -async fn virtual_cuid_default(api: &mut TestApi) { +async fn virtual_uid_default(api: &mut TestApi) { api.barrel() .execute(|migration| { migration.create_table("User", |t| { t.add_column("id", types::varchar(30).primary(true)); - t.add_column("non_id", types::varchar(30)); + t.add_column("non_id_1", types::varchar(30)); + t.add_column("non_id_2", types::varchar(30)); }); migration.create_table("User2", |t| { t.add_column("id", types::varchar(36).primary(true)); + t.add_column("non_id_1", types::varchar(36)); + t.add_column("non_id_2", types::varchar(36)); }); migration.create_table("User3", |t| { @@ -995,11 +998,14 @@ async fn virtual_cuid_default(api: &mut TestApi) { let input_dm = r#" model User { id String @id @default(cuid()) @db.VarChar(30) - non_id String @default(cuid()) @db.VarChar(30) + non_id_1 String @default(cuid(1)) @db.VarChar(30) + non_id_2 String @default(cuid(2)) @db.VarChar(30) } model User2 { id String @id @default(uuid()) @db.VarChar(36) + non_id_1 String @default(uuid(4)) @db.VarChar(36) + non_id_2 String @default(uuid(7)) @db.VarChar(36) } model User3 { @@ -1010,11 +1016,14 @@ async fn virtual_cuid_default(api: &mut TestApi) { let final_dm = indoc! {r#" model User { id String @id @default(cuid()) @db.VarChar(30) - non_id String @default(cuid()) @db.VarChar(30) + non_id_1 String @default(cuid(1)) @db.VarChar(30) + non_id_2 String @default(cuid(2)) @db.VarChar(30) } model User2 { id String @id @default(uuid()) @db.VarChar(36) + non_id_1 String @default(uuid(4)) @db.VarChar(36) + non_id_2 String @default(uuid(7)) @db.VarChar(36) } model User3 {