Skip to content

Commit

Permalink
feat(psl): add cuid(2) support, fix uuid(7) re-introspection
Browse files Browse the repository at this point in the history
  • Loading branch information
jkomyno committed Nov 15, 2024
1 parent 378004b commit 454112f
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 37 deletions.
41 changes: 31 additions & 10 deletions psl/parser-database/src/attributes/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
ast::{self, WithName},
coerce,
context::Context,
generators::{CUID_SUPPORTED_VERSIONS, UUID_SUPPORTED_VERSIONS},
types::{DefaultAttribute, ScalarFieldType, ScalarType},
DatamodelError, ScalarFieldId, StringId,
};
Expand Down Expand Up @@ -197,10 +198,10 @@ fn validate_model_builtin_scalar_type_default(
validate_empty_function_args(funcname, &funcargs.arguments, accept, ctx)
}
(ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_CUID => {
validate_empty_function_args(funcname, &funcargs.arguments, accept, ctx)
validate_uid_int_args(funcname, &funcargs.arguments, &CUID_SUPPORTED_VERSIONS, accept, ctx)
}
(ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_UUID => {
validate_uuid_args(&funcargs.arguments, accept, ctx)
validate_uid_int_args(funcname, &funcargs.arguments, &UUID_SUPPORTED_VERSIONS, accept, ctx)
}
(ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_NANOID => {
validate_nanoid_args(&funcargs.arguments, accept, ctx)
Expand Down Expand Up @@ -244,10 +245,10 @@ fn validate_composite_builtin_scalar_type_default(
match (scalar_type, value) {
// Functions
(ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_CUID => {
validate_empty_function_args(funcname, &funcargs.arguments, accept, ctx)
validate_uid_int_args(funcname, &funcargs.arguments, &CUID_SUPPORTED_VERSIONS, accept, ctx)
}
(ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_UUID => {
validate_uuid_args(&funcargs.arguments, accept, ctx)
validate_uid_int_args(funcname, &funcargs.arguments, &UUID_SUPPORTED_VERSIONS, accept, ctx)
}
(ScalarType::DateTime, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_NOW => {
validate_empty_function_args(FN_NOW, &funcargs.arguments, accept, ctx)
Expand Down Expand Up @@ -381,18 +382,38 @@ fn validate_dbgenerated_args(args: &[ast::Argument], accept: AcceptFn<'_>, ctx:
}
}

fn validate_uuid_args(args: &[ast::Argument], accept: AcceptFn<'_>, ctx: &mut Context<'_>) {
let mut bail = || ctx.push_attribute_validation_error("`uuid()` takes a single Int argument.");
fn format_valid_values<const N: usize>(valid_values: &[u8; N]) -> String {
match valid_values.len() {
0 => String::new(),
1 => valid_values[0].to_string(),
2 => format!("{} or {}", valid_values[0], valid_values[1]),
_ => {
let (last, rest) = valid_values.split_last().unwrap();
let rest_str = rest.iter().map(|v| v.to_string()).collect::<Vec<_>>().join(", ");
format!("{}, or {}", rest_str, last)
}
}
}

fn validate_uid_int_args<const N: usize>(
fn_name: &str,
args: &[ast::Argument],
valid_values: &[u8; N],
accept: AcceptFn<'_>,
ctx: &mut Context<'_>,
) {
let mut bail = || ctx.push_attribute_validation_error(&format!("`{fn_name}()` takes a single Int argument."));

if args.len() > 1 {
bail()
}

match args.first().map(|arg| &arg.value) {
Some(ast::Expression::NumericValue(val, _)) if ![4u8, 7u8].contains(&val.parse::<u8>().unwrap()) => {
ctx.push_attribute_validation_error(
"`uuid()` takes either no argument, or a single integer argument which is either 4 or 7.",
);
Some(ast::Expression::NumericValue(val, _)) if !valid_values.contains(&val.parse::<u8>().unwrap()) => {
let valid_values_str = format_valid_values(&valid_values);

Check failure on line 413 in psl/parser-database/src/attributes/default.rs

View workflow job for this annotation

GitHub Actions / clippy linting

this expression creates a reference which is immediately dereferenced by the compiler
ctx.push_attribute_validation_error(&format!(
"`{fn_name}()` takes either no argument, or a single integer argument which is either {valid_values_str}.",
));
}
None | Some(ast::Expression::NumericValue(_, _)) => accept(ctx),
_ => bail(),
Expand Down
13 changes: 13 additions & 0 deletions psl/parser-database/src/generators.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//! Convenient access to a ID generator constants, used by Prisma in psl, Query Engine and Schema Engine.
/// Version of the `uuid()` ID generator supported by Prisma.
pub const UUID_SUPPORTED_VERSIONS: [u8; 2] = [4, 7];

/// Version of the `cuid()` ID generator supported by Prisma.
pub const CUID_SUPPORTED_VERSIONS: [u8; 2] = [1, 2];

/// Default version of the `uuid()` ID generator.
pub const DEFAULT_UUID_VERSION: u8 = 4;

/// Default version of the `cuid()` ID generator.
pub const DEFAULT_CUID_VERSION: u8 = 2;
1 change: 1 addition & 0 deletions psl/parser-database/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ mod attributes;
mod coerce_expression;
mod context;
mod files;
pub mod generators;
mod ids;
mod interner;
mod names;
Expand Down
2 changes: 1 addition & 1 deletion psl/psl-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub use crate::{
reformat::{reformat, reformat_multiple, reformat_validated_schema_into_single},
};
pub use diagnostics;
pub use parser_database::{self, is_reserved_type_name};
pub use parser_database::{self, generators, is_reserved_type_name};
pub use schema_ast;
pub use set_config_dir::set_config_dir;

Expand Down
1 change: 1 addition & 0 deletions psl/psl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub use psl_core::{
builtin_connectors::{can_have_capability, can_support_relation_load_strategy, has_capability},
datamodel_connector,
diagnostics::{self, Diagnostics},
generators,
is_reserved_type_name,
mcf::config_to_mcf_json_value as get_config,
mcf::{generators_to_json, render_sources_to_json}, // for tests
Expand Down
43 changes: 41 additions & 2 deletions psl/psl/tests/attributes/id_positive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,45 @@ fn should_allow_string_ids_with_cuid() {
model.assert_id_on_fields(&["id"]);
}

#[test]
fn should_allow_string_ids_with_cuid_version_specified() {
let dml = indoc! {r#"
model ModelA {
id String @id @default(cuid(1))
}
model ModelB {
id String @id @default(cuid(2))
}
"#};

let schema = psl::parse_schema(dml).unwrap();

{
let model = schema.assert_has_model("ModelA");

model
.assert_has_scalar_field("id")
.assert_scalar_type(ScalarType::String)
.assert_default_value()
.assert_cuid_version(1);

model.assert_id_on_fields(&["id"]);
}

{
let model = schema.assert_has_model("ModelB");

model
.assert_has_scalar_field("id")
.assert_scalar_type(ScalarType::String)
.assert_default_value()
.assert_cuid_version(2);

model.assert_id_on_fields(&["id"]);
}
}

#[test]
fn should_allow_string_ids_with_uuid() {
let dml = indoc! {r#"
Expand Down Expand Up @@ -91,7 +130,7 @@ fn should_allow_string_ids_with_uuid_version_specified() {
.assert_has_scalar_field("id")
.assert_scalar_type(ScalarType::String)
.assert_default_value()
.assert_uuid();
.assert_uuid_version(4);

model.assert_id_on_fields(&["id"]);
}
Expand All @@ -103,7 +142,7 @@ fn should_allow_string_ids_with_uuid_version_specified() {
.assert_has_scalar_field("id")
.assert_scalar_type(ScalarType::String)
.assert_default_value()
.assert_uuid();
.assert_uuid_version(7);

model.assert_id_on_fields(&["id"]);
}
Expand Down
50 changes: 49 additions & 1 deletion psl/psl/tests/common/asserts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ pub(crate) trait DefaultValueAssert {
fn assert_bytes(&self, val: &[u8]) -> &Self;
fn assert_now(&self) -> &Self;
fn assert_cuid(&self) -> &Self;
fn assert_cuid_version(&self, version: u8) -> &Self;
fn assert_uuid(&self) -> &Self;
fn assert_uuid_version(&self, version: u8) -> &Self;
fn assert_dbgenerated(&self, val: &str) -> &Self;
fn assert_mapped_name(&self, val: &str) -> &Self;
}
Expand Down Expand Up @@ -433,12 +435,24 @@ impl<'a> DefaultValueAssert for walkers::DefaultValueWalker<'a> {
self
}

#[track_caller]
fn assert_cuid_version(&self, version: u8) -> &Self {
self.value().assert_cuid_version(version);
self
}

#[track_caller]
fn assert_uuid(&self) -> &Self {
self.value().assert_uuid();
self
}

#[track_caller]
fn assert_uuid_version(&self, version: u8) -> &Self {
self.value().assert_uuid_version(version);
self
}

#[track_caller]
fn assert_dbgenerated(&self, val: &str) -> &Self {
self.value().assert_dbgenerated(val);
Expand Down Expand Up @@ -623,19 +637,53 @@ impl DefaultValueAssert for ast::Expression {
#[track_caller]
fn assert_cuid(&self) -> &Self {
assert!(
matches!(self, ast::Expression::Function(name, args, _) if name == "cuid" && args.arguments.is_empty())
matches!(self, ast::Expression::Function(name, _, _) if name == "cuid" /* && args.arguments.is_empty() */)
);

self
}

#[track_caller]
fn assert_cuid_version(&self, version: u8) -> &Self {
self.assert_cuid();

if let ast::Expression::Function(_, args, _) = self {
if let ast::Expression::NumericValue(actual, _) = &args.arguments[0].value {
assert_eq!(actual, &format!("{version}"));
} else {
panic!("Expected a numeric value for the version.");
}
} else {
unreachable!();
}

self
}

#[track_caller]
fn assert_uuid(&self) -> &Self {
assert!(matches!(self, ast::Expression::Function(name, _, _) if name == "uuid"));

self
}

#[track_caller]
fn assert_uuid_version(&self, version: u8) -> &Self {
self.assert_uuid();

if let ast::Expression::Function(_, args, _) = self {
if let ast::Expression::NumericValue(actual, _) = &args.arguments[0].value {
assert_eq!(actual, &format!("{version}"));
} else {
panic!("Expected a numeric value for the version.");
}
} else {
unreachable!();
}

self
}

#[track_caller]
fn assert_dbgenerated(&self, val: &str) -> &Self {
match self {
Expand Down
32 changes: 22 additions & 10 deletions query-engine/query-structure/src/default_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ impl ValueGenerator {
ValueGenerator::new("now".to_owned(), vec![]).unwrap()
}

pub fn new_cuid() -> Self {
ValueGenerator::new("cuid".to_owned(), vec![]).unwrap()
pub fn new_cuid(version: u8) -> Self {
ValueGenerator::new(format!("cuid({version})"), vec![]).unwrap()
}

pub fn new_uuid(version: u8) -> Self {
Expand Down Expand Up @@ -239,7 +239,7 @@ impl ValueGenerator {
#[derive(Clone, Copy, PartialEq)]
pub enum ValueGeneratorFn {
Uuid(u8),
Cuid,
Cuid(u8),
Nanoid(Option<u8>),
Now,
Autoincrement,
Expand All @@ -250,7 +250,8 @@ pub enum ValueGeneratorFn {
impl ValueGeneratorFn {
fn new(name: &str) -> std::result::Result<Self, String> {
match name {
"cuid" => Ok(Self::Cuid),
"cuid" | "cuid(2)" => Ok(Self::Cuid(2)),
"cuid(1)" => Ok(Self::Cuid(1)),
"uuid" | "uuid(4)" => Ok(Self::Uuid(4)),
"uuid(7)" => Ok(Self::Uuid(7)),
"now" => Ok(Self::Now),
Expand All @@ -267,7 +268,7 @@ impl ValueGeneratorFn {
fn invoke(&self) -> Option<PrismaValue> {
match self {
Self::Uuid(version) => Some(Self::generate_uuid(*version)),
Self::Cuid => Some(Self::generate_cuid()),
Self::Cuid(version) => Some(Self::generate_cuid(*version)),
Self::Nanoid(length) => Some(Self::generate_nanoid(length)),
Self::Now => Some(Self::generate_now()),
Self::Autoincrement => None,
Expand All @@ -277,9 +278,13 @@ impl ValueGeneratorFn {
}

#[cfg(feature = "default_generators")]
fn generate_cuid() -> PrismaValue {
#[allow(deprecated)]
PrismaValue::String(cuid::cuid().unwrap())
fn generate_cuid(version: u8) -> PrismaValue {
PrismaValue::String(match version {
#[allow(deprecated)]
1 => cuid::cuid().unwrap(),
2 => cuid::cuid2(),
_ => panic!("Unknown `cuid` version: {}", version),
})
}

#[cfg(feature = "default_generators")]
Expand Down Expand Up @@ -358,8 +363,15 @@ mod tests {
}

#[test]
fn default_value_is_cuid() {
let cuid_default = DefaultValue::new_expression(ValueGenerator::new_cuid());
fn default_value_is_cuidv1() {
let cuid_default = DefaultValue::new_expression(ValueGenerator::new_cuid(1));

assert!(cuid_default.is_cuid());
assert!(!cuid_default.is_now());
}

fn default_value_is_cuidv2() {

Check failure on line 373 in query-engine/query-structure/src/default_value.rs

View workflow job for this annotation

GitHub Actions / test

function `default_value_is_cuidv2` is never used
let cuid_default = DefaultValue::new_expression(ValueGenerator::new_cuid(2));

assert!(cuid_default.is_cuid());
assert!(!cuid_default.is_now());
Expand Down
13 changes: 10 additions & 3 deletions query-engine/query-structure/src/field/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{ast, parent_container::ParentContainer, prelude::*, DefaultKind, NativeTypeInstance, ValueGenerator};
use chrono::{DateTime, FixedOffset};
use psl::{
generators::{DEFAULT_CUID_VERSION, DEFAULT_UUID_VERSION},
parser_database::{self as db, walkers, ScalarFieldType, ScalarType},
schema_ast::ast::FieldArity,
};
Expand Down Expand Up @@ -250,12 +251,18 @@ pub fn dml_default_kind(default_value: &ast::Expression, scalar_type: Option<Sca
.first()
.and_then(|arg| arg.value.as_numeric_value())
.map(|(val, _)| val.parse::<u8>().unwrap())
.unwrap_or(4);
.unwrap_or(DEFAULT_UUID_VERSION);

DefaultKind::Expression(ValueGenerator::new_uuid(version))
}
ast::Expression::Function(funcname, _args, _) if funcname == "cuid" => {
DefaultKind::Expression(ValueGenerator::new_cuid())
ast::Expression::Function(funcname, args, _) if funcname == "cuid" => {
let version = args
.arguments
.first()
.and_then(|arg| arg.value.as_numeric_value())
.map(|(val, _)| val.parse::<u8>().unwrap())
.unwrap_or(DEFAULT_CUID_VERSION);
DefaultKind::Expression(ValueGenerator::new_cuid(version))
}
ast::Expression::Function(funcname, args, _) if funcname == "nanoid" => {
DefaultKind::Expression(ValueGenerator::new_nanoid(
Expand Down
Loading

0 comments on commit 454112f

Please sign in to comment.