Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into fix_order_by_literal
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Dec 4, 2023
2 parents e21558d + 0bcf462 commit a7ac525
Show file tree
Hide file tree
Showing 12 changed files with 348 additions and 85 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ in-memory format. [Python Bindings](https://github.com/apache/arrow-datafusion-p
Here are links to some important information

- [Project Site](https://arrow.apache.org/datafusion)
- [Installation](https://arrow.apache.org/datafusion/user-guide/cli.html#installation)
- [Rust Getting Started](https://arrow.apache.org/datafusion/user-guide/example-usage.html)
- [Rust DataFrame API](https://arrow.apache.org/datafusion/user-guide/dataframe.html)
- [Rust API docs](https://docs.rs/datafusion/latest/datafusion)
Expand Down
2 changes: 2 additions & 0 deletions datafusion/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod dfschema;
mod error;
mod functional_dependencies;
mod join_type;
mod param_value;
#[cfg(feature = "pyarrow")]
mod pyarrow;
mod schema_reference;
Expand Down Expand Up @@ -59,6 +60,7 @@ pub use functional_dependencies::{
Constraints, Dependency, FunctionalDependence, FunctionalDependencies,
};
pub use join_type::{JoinConstraint, JoinSide, JoinType};
pub use param_value::ParamValues;
pub use scalar::{ScalarType, ScalarValue};
pub use schema_reference::{OwnedSchemaReference, SchemaReference};
pub use stats::{ColumnStatistics, Statistics};
Expand Down
149 changes: 149 additions & 0 deletions datafusion/common/src/param_value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::error::{_internal_err, _plan_err};
use crate::{DataFusionError, Result, ScalarValue};
use arrow_schema::DataType;
use std::collections::HashMap;

/// Values bound to the placeholders of a parameterized query.
///
/// A query uses either positional placeholders (`$1`, `$2`, ..) or named
/// placeholders (`$foo`, `$bar`, ..); each style has its own variant.
#[derive(Debug, Clone)]
pub enum ParamValues {
/// Values for positional query parameters, e.g.
/// `select * from test where a > $1 and b = $2`.
/// The value for placeholder `$N` is stored at index `N - 1`.
LIST(Vec<ScalarValue>),
/// Values for named query parameters, e.g.
/// `select * from test where a > $foo and b = $goo`.
/// Keys are the placeholder names without the leading prefix character
/// (`foo`, not `$foo`).
MAP(HashMap<String, ScalarValue>),
}

impl ParamValues {
    /// Verifies the supplied values against the expected parameter types.
    ///
    /// For [`ParamValues::LIST`] the number of values must equal
    /// `expect.len()` and every value's data type must equal the expected
    /// type at the same position. For [`ParamValues::MAP`] nothing is checked
    /// here: named parameters can be reused within a query, so the number of
    /// values does not have to match the number of placeholders.
    ///
    /// # Errors
    ///
    /// Returns a plan error on a length mismatch or on the first value whose
    /// type differs from the expected type.
    pub fn verify(&self, expect: &Vec<DataType>) -> Result<()> {
        match self {
            ParamValues::LIST(list) => {
                // The number of values must match the number of parameters.
                if expect.len() != list.len() {
                    return _plan_err!(
                        "Expected {} parameters, got {}",
                        expect.len(),
                        list.len()
                    );
                }

                // Each value must have exactly the type declared for its slot.
                for (i, (param_type, value)) in expect.iter().zip(list.iter()).enumerate() {
                    let value_type = value.data_type();
                    if *param_type != value_type {
                        return _plan_err!(
                            "Expected parameter of type {:?}, got {:?} at index {}",
                            param_type,
                            value_type,
                            i
                        );
                    }
                }
                Ok(())
            }
            // Named parameters may be reused, so the lengths are not
            // necessarily equal and there is nothing to verify here.
            ParamValues::MAP(_) => Ok(()),
        }
    }

    /// Looks up the value bound to the placeholder `id` and checks its type.
    ///
    /// `id` is the placeholder as written in the query, including its
    /// one-character prefix (`$1`, `$foo`). For [`ParamValues::LIST`] the
    /// remainder of the id is parsed as a 1-based position; for
    /// [`ParamValues::MAP`] it is used as the lookup key.
    ///
    /// # Errors
    ///
    /// * Plan error if `id` is empty or denotes position zero.
    /// * Internal error if the id cannot be parsed, no value is bound to it,
    ///   or the value's data type differs from `data_type`.
    pub fn get_placeholders_with_values(
        &self,
        id: &String,
        data_type: &Option<DataType>,
    ) -> Result<ScalarValue> {
        // An empty id is invalid for both variants. Checking it once up front
        // also keeps the named branch from slicing an empty string.
        if id.is_empty() {
            return _plan_err!("Empty placeholder id");
        }

        // Drop the one-character prefix. `str::get` is used instead of
        // `id[1..]` so that an id starting with a multi-byte character is
        // reported as an error rather than panicking on a char boundary.
        let name = match id.get(1..) {
            Some(name) => name,
            None => return _internal_err!("Failed to parse placeholder id: {}", id),
        };

        let value = match self {
            ParamValues::LIST(list) => {
                // Convert id (in format $1, $2, ..) to idx (0, 1, ..).
                let position = name.parse::<usize>().map_err(|e| {
                    DataFusionError::Internal(format!(
                        "Failed to parse placeholder id: {e}"
                    ))
                })?;
                // Positions are 1-based. `checked_sub` rejects every spelling
                // of zero (`$0`, `$00`, ..) without underflowing.
                let idx = match position.checked_sub(1) {
                    Some(idx) => idx,
                    None => return _plan_err!("Empty placeholder id"),
                };
                list.get(idx).ok_or_else(|| {
                    DataFusionError::Internal(format!(
                        "No value found for placeholder with id {id}"
                    ))
                })?
            }
            ParamValues::MAP(map) => {
                // Names are stored without the prefix ($a -> a).
                map.get(name).ok_or_else(|| {
                    DataFusionError::Internal(format!(
                        "No value found for placeholder with name {id}"
                    ))
                })?
            }
        };

        // The bound value must have exactly the type the placeholder expects.
        let value_type = value.data_type();
        if data_type.as_ref() != Some(&value_type) {
            return _internal_err!(
                "Placeholder value type mismatch: expected {:?}, got {:?}",
                data_type,
                value_type
            );
        }
        Ok(value.clone())
    }
}

impl From<Vec<ScalarValue>> for ParamValues {
fn from(value: Vec<ScalarValue>) -> Self {
Self::LIST(value)
}
}

impl<K> From<Vec<(K, ScalarValue)>> for ParamValues
where
K: Into<String>,
{
fn from(value: Vec<(K, ScalarValue)>) -> Self {
let value: HashMap<String, ScalarValue> =
value.into_iter().map(|(k, v)| (k.into(), v)).collect();
Self::MAP(value)
}
}

impl<K> From<HashMap<K, ScalarValue>> for ParamValues
where
K: Into<String>,
{
fn from(value: HashMap<K, ScalarValue>) -> Self {
let value: HashMap<String, ScalarValue> =
value.into_iter().map(|(k, v)| (k.into(), v)).collect();
Self::MAP(value)
}
}
30 changes: 26 additions & 4 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::file_options::json_writer::JsonWriterOptions;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::{
DataFusionError, FileType, FileTypeWriterOptions, SchemaError, UnnestOptions,
DataFusionError, FileType, FileTypeWriterOptions, ParamValues, SchemaError,
UnnestOptions,
};
use datafusion_expr::dml::CopyOptions;

use datafusion_common::{Column, DFSchema, ScalarValue};
use datafusion_common::{Column, DFSchema};
use datafusion_expr::{
avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown, UNNAMED_TABLE,
Expand Down Expand Up @@ -1227,11 +1228,32 @@ impl DataFrame {
/// ],
/// &results
/// );
/// // Note you can also provide named parameters
/// let results = ctx
/// .sql("SELECT a FROM example WHERE b = $my_param")
/// .await?
/// // replace $my_param with value 2
/// // Note you can also use a HashMap as well
/// .with_param_values(vec![
/// ("my_param", ScalarValue::from(2i64))
/// ])?
/// .collect()
/// .await?;
/// assert_batches_eq!(
/// &[
/// "+---+",
/// "| a |",
/// "+---+",
/// "| 1 |",
/// "+---+",
/// ],
/// &results
/// );
/// # Ok(())
/// # }
/// ```
pub fn with_param_values(self, param_values: Vec<ScalarValue>) -> Result<Self> {
let plan = self.plan.with_param_values(param_values)?;
pub fn with_param_values(self, query_values: impl Into<ParamValues>) -> Result<Self> {
let plan = self.plan.with_param_values(query_values)?;
Ok(Self::new(self.session_state, plan))
}

Expand Down
47 changes: 47 additions & 0 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,53 @@ async fn test_prepare_statement() -> Result<()> {
Ok(())
}

/// Runs a query with named placeholders (`$coo`, `$foo`) and binds them
/// through `with_param_values` using a list of `(name, value)` pairs.
#[tokio::test]
async fn test_named_query_parameters() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?;

// Plan the SQL with named placeholders, then bind values to them.
// c1 is defined as UInt32 and c2 as UInt64, so both bounds on c1 are
// supplied as UInt32 scalars.
let results = ctx
.sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo")
.await?
// The pairs are given in a different order than the placeholders appear
// in the query: named parameters are matched by name, not by position.
.with_param_values(vec![
("foo", ScalarValue::UInt32(Some(3))),
("coo", ScalarValue::UInt32(Some(0))),
])?
.collect()
.await?;
// Only rows with 0 < c1 < 3 are expected, i.e. c1 = 1 and c1 = 2.
let expected = vec![
"+----+----+",
"| c1 | c2 |",
"+----+----+",
"| 1 | 1 |",
"| 1 | 2 |",
"| 1 | 3 |",
"| 1 | 4 |",
"| 1 | 5 |",
"| 1 | 6 |",
"| 1 | 7 |",
"| 1 | 8 |",
"| 1 | 9 |",
"| 1 | 10 |",
"| 2 | 1 |",
"| 2 | 2 |",
"| 2 | 3 |",
"| 2 | 4 |",
"| 2 | 5 |",
"| 2 | 6 |",
"| 2 | 7 |",
"| 2 | 8 |",
"| 2 | 9 |",
"| 2 | 10 |",
"+----+----+",
];
// Compare via the sorted-batches assertion macro.
assert_batches_sorted_eq!(expected, &results);
Ok(())
}

#[tokio::test]
async fn parallel_query_with_filter() -> Result<()> {
let tmp_dir = TempDir::new()?;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ impl InSubquery {
}
}

/// Placeholder, representing bind parameter values such as `$1`.
/// Placeholder, representing bind parameter values such as `$1` or `$name`.
///
/// The type of these parameters is inferred using [`Expr::infer_placeholder_types`]
/// or can be specified directly using `PREPARE` statements.
Expand Down
Loading

0 comments on commit a7ac525

Please sign in to comment.