Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle native types for joined queries #4546

Merged
merged 6 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion psl/builtin-connectors/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ lsp-types = "0.91.1"
once_cell = "1.3"
regex = "1"
chrono = { version = "0.4.6", default_features = false }
Weakky marked this conversation as resolved.
Show resolved Hide resolved
bigdecimal = "0.3"

20 changes: 0 additions & 20 deletions psl/builtin-connectors/src/cockroach_datamodel_connector.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
mod native_types;
mod validations;

use bigdecimal::{BigDecimal, ParseBigDecimalError};
pub use native_types::CockroachType;

use chrono::*;
Expand Down Expand Up @@ -332,25 +331,6 @@ impl Connector for CockroachDatamodelConnector {
),
}
}

fn parse_json_decimal(
&self,
str: &str,
nt: Option<NativeTypeInstance>,
) -> Result<BigDecimal, ParseBigDecimalError> {
let native_type: Option<&CockroachType> = nt.as_ref().map(|nt| nt.downcast_ref());

match native_type {
Some(pt) => match pt {
CockroachType::Decimal(_) => crate::utils::parse_decimal(str),
_ => unreachable!(),
},
None => self.parse_json_decimal(
str,
Some(self.default_native_type_for_scalar_type(&ScalarType::Decimal)),
),
}
}
}

/// An `@default(sequence())` function.
Expand Down
21 changes: 0 additions & 21 deletions psl/builtin-connectors/src/postgres_datamodel_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ mod validations;

pub use native_types::PostgresType;

use bigdecimal::{BigDecimal, ParseBigDecimalError};
use chrono::*;
use enumflags2::BitFlags;
use lsp_types::{CompletionItem, CompletionItemKind, CompletionList, InsertTextFormat};
Expand Down Expand Up @@ -592,26 +591,6 @@ impl Connector for PostgresDatamodelConnector {
),
}
}

fn parse_json_decimal(
&self,
str: &str,
nt: Option<NativeTypeInstance>,
) -> Result<BigDecimal, ParseBigDecimalError> {
let native_type: Option<&PostgresType> = nt.as_ref().map(|nt| nt.downcast_ref());

match native_type {
Some(pt) => match pt {
Decimal(_) => crate::utils::parse_decimal(str),
Money => crate::utils::parse_money(str),
_ => unreachable!(),
},
None => self.parse_json_decimal(
str,
Some(self.default_native_type_for_scalar_type(&ScalarType::Decimal)),
),
}
}
}

fn allowed_index_operator_classes(algo: IndexAlgorithm, field: walkers::ScalarFieldWalker<'_>) -> Vec<OperatorClass> {
Expand Down
11 changes: 0 additions & 11 deletions psl/builtin-connectors/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use bigdecimal::{BigDecimal, ParseBigDecimalError};
use chrono::*;
use std::str::FromStr;

pub(crate) fn parse_date(str: &str) -> Result<DateTime<FixedOffset>, chrono::ParseError> {
chrono::NaiveDate::parse_from_str(str, "%Y-%m-%d")
Expand Down Expand Up @@ -37,12 +35,3 @@ pub(crate) fn parse_timetz(str: &str) -> Result<DateTime<FixedOffset>, chrono::P

parse_time(time_without_tz)
}

pub(crate) fn parse_money(str: &str) -> Result<BigDecimal, ParseBigDecimalError> {
// We strip out the currency sign from the string.
BigDecimal::from_str(&str[1..]).map(|bd| bd.normalized())
}

pub(crate) fn parse_decimal(str: &str) -> Result<BigDecimal, ParseBigDecimalError> {
BigDecimal::from_str(str).map(|bd| bd.normalized())
}
9 changes: 0 additions & 9 deletions psl/psl-core/src/datamodel_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ pub use self::{
};

use crate::{configuration::DatasourceConnectorData, Configuration, Datasource, PreviewFeature};
use bigdecimal::{BigDecimal, ParseBigDecimalError};
use chrono::{DateTime, FixedOffset};
use diagnostics::{DatamodelError, Diagnostics, NativeTypeErrorFactory, Span};
use enumflags2::BitFlags;
Expand Down Expand Up @@ -369,14 +368,6 @@ pub trait Connector: Send + Sync {
) -> chrono::ParseResult<DateTime<FixedOffset>> {
unreachable!("This method is only implemented on connectors with lateral join support.")
}

fn parse_json_decimal(
&self,
_str: &str,
_nt: Option<NativeTypeInstance>,
) -> Result<BigDecimal, ParseBigDecimalError> {
unreachable!("This method is only implemented on connectors with lateral join support.")
}
}

#[derive(Copy, Clone, Debug, PartialEq)]
Expand Down
9 changes: 8 additions & 1 deletion quaint/src/ast/column.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::Aliasable;
use super::{values::NativeColumnType, Aliasable};
use crate::{
ast::{Expression, ExpressionKind, Table},
Value,
Expand Down Expand Up @@ -32,6 +32,8 @@ pub struct Column<'a> {
pub(crate) alias: Option<Cow<'a, str>>,
pub(crate) default: Option<DefaultValue<'a>>,
pub(crate) type_family: Option<TypeFamily>,
/// The underlying native type of the column.
pub(crate) native_type: Option<NativeColumnType<'a>>,
/// Whether the column is an enum.
pub(crate) is_enum: bool,
/// Whether the column is a (scalar) list.
Expand Down Expand Up @@ -130,6 +132,11 @@ impl<'a> Column<'a> {
.map(|d| d == &DefaultValue::Generated)
.unwrap_or(false)
}

pub fn native_column_type<T: Into<NativeColumnType<'a>>>(mut self, native_type: Option<T>) -> Column<'a> {
self.native_type = native_type.map(|nt| nt.into());
self
}
}

impl<'a> From<Column<'a>> for Expression<'a> {
Expand Down
35 changes: 34 additions & 1 deletion quaint/src/visitor/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@ pub struct Postgres<'a> {
parameters: Vec<Value<'a>>,
}

impl<'a> Postgres<'a> {
fn visit_json_build_obj_expr(&mut self, expr: Expression<'a>) -> crate::Result<()> {
dbg!(&expr);
Weakky marked this conversation as resolved.
Show resolved Hide resolved
match expr.kind() {
ExpressionKind::Column(col) => match (col.type_family.as_ref(), col.native_type.as_deref()) {
(Some(TypeFamily::Decimal(_)), Some("MONEY")) => {
Weakky marked this conversation as resolved.
Show resolved Hide resolved
self.visit_expression(expr)?;
self.write("::numeric")?;

Ok(())
}
_ => self.visit_expression(expr),
},
_ => self.visit_expression(expr),
}
}
}

impl<'a> Visitor<'a> for Postgres<'a> {
const C_BACKTICK_OPEN: &'static str = "\"";
const C_BACKTICK_CLOSE: &'static str = "\"";
Expand Down Expand Up @@ -534,7 +552,7 @@ impl<'a> Visitor<'a> for Postgres<'a> {
while let Some((name, expr)) = chunk.next() {
s.visit_raw_value(Value::text(name))?;
s.write(", ")?;
s.visit_expression(expr)?;
s.visit_json_build_obj_expr(expr)?;
if chunk.peek().is_some() {
s.write(", ")?;
}
Expand Down Expand Up @@ -1290,6 +1308,21 @@ mod tests {
);
}

#[test]
fn money() {
Weakky marked this conversation as resolved.
Show resolved Hide resolved
let build_json = json_build_object(vec![(
"money".into(),
Column::from("money")
.native_column_type(Some("money"))
.type_family(TypeFamily::Decimal(None))
.into(),
)]);
let query = Select::default().value(build_json);
let (sql, _) = Postgres::build(query).unwrap();

assert_eq!(sql, "SELECT JSONB_BUILD_OBJECT('money', \"money\"::numeric)");
}

fn build_json_object(num_fields: u32) -> JsonBuildObject<'static> {
let fields = (1..=num_fields)
.map(|i| (format!("f{i}").into(), Expression::from(i as i64)))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::io;

use bigdecimal::{BigDecimal, FromPrimitive};
use bigdecimal::{BigDecimal, FromPrimitive, ParseBigDecimalError};
use itertools::{Either, Itertools};
use query_structure::*;
use std::{io, str::FromStr};

use crate::{query_arguments_ext::QueryArgumentsExt, SqlError};

Expand Down Expand Up @@ -144,7 +143,7 @@ pub(crate) fn coerce_json_scalar_to_pv(value: serde_json::Value, sf: &ScalarFiel
Ok(PrismaValue::DateTime(res))
}
TypeIdentifier::Decimal => {
let res = sf.parse_json_decimal(&s).map_err(|err| {
let res = parse_decimal(&s).map_err(|err| {
build_conversion_error_with_reason(
sf,
&format!("String({s})"),
Expand Down Expand Up @@ -215,3 +214,7 @@ fn build_conversion_error_with_reason(sf: &ScalarField, from: &str, to: &str, re

SqlError::ConversionError(error.into())
}

fn parse_decimal(str: &str) -> std::result::Result<BigDecimal, ParseBigDecimalError> {
BigDecimal::from_str(str).map(|bd| bd.normalized())
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ impl AsColumn for ScalarField {

Column::from((full_table_name, col))
.type_family(self.type_family())
.native_column_type(self.native_type().map(|nt| nt.name()))
.set_is_enum(self.type_identifier().is_enum())
.set_is_list(self.is_list())
.default(quaint::ast::DefaultValue::Generated)
Expand Down
8 changes: 0 additions & 8 deletions query-engine/query-structure/src/field/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::{ast, parent_container::ParentContainer, prelude::*, DefaultKind, NativeTypeInstance, ValueGenerator};
use bigdecimal::{BigDecimal, ParseBigDecimalError};
use chrono::{DateTime, FixedOffset};
use psl::{
parser_database::{walkers, ScalarFieldType, ScalarType},
Expand Down Expand Up @@ -179,13 +178,6 @@ impl ScalarField {
connector.parse_json_datetime(value, nt)
}

pub fn parse_json_decimal(&self, value: &str) -> Result<BigDecimal, ParseBigDecimalError> {
let nt = self.native_type().map(|nt| nt.native_type);
let connector = self.dm.schema.connector;

connector.parse_json_decimal(value, nt)
}

pub fn is_autoincrement(&self) -> bool {
match self.id {
ScalarFieldId::InModel(id) => self.dm.walk(id).is_autoincrement(),
Expand Down