Skip to content

Commit

Permalink
fixup! [substrait] Add support for ExtensionTable
Browse files Browse the repository at this point in the history
  • Loading branch information
ccciudatu committed Jan 8, 2025
1 parent 20ba857 commit a3f89c8
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 97 deletions.
34 changes: 18 additions & 16 deletions datafusion/expr/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub trait SerializerRegistry: Debug + Send + Sync {
fn serialize_logical_plan(
&self,
node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>> {
) -> Result<NamedBytes> {
not_impl_err!(
"Serializing user defined logical plan node `{}` is not supported",
node.name()
Expand All @@ -143,34 +143,36 @@ pub trait SerializerRegistry: Debug + Send + Sync {
/// bytes.
fn deserialize_logical_plan(
&self,
name: &str,
_bytes: &[u8],
NamedBytes(qualifier, _bytes): &NamedBytes,
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
not_impl_err!(
"Deserializing user defined logical plan node `{name}` is not supported"
"Deserializing user defined logical plan node `{qualifier}` is not supported"
)
}

/// Serialized table definition for UDTFs or manually registered table providers that can't be
/// marshaled by reference. Should return some benign error for regular tables that can be
/// found/restored by name in the destination execution context.
fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result<Vec<u8>> {
not_impl_err!("No custom table support")
/// Serialized table definition for UDTFs or some other table provider implementation that
/// can't be marshaled by reference.
fn serialize_custom_table(
&self,
_table: &dyn TableSource,
) -> Result<Option<NamedBytes>> {
Ok(None)
}

/// Deserialize the custom table with the given name.
/// Note: more often than not, the name can't be used as a discriminator if multiple different
/// `TableSource` and/or `TableProvider` implementations are expected (this is particularly true
/// for UDTFs in DataFusion, which are always registered under the same name: `tmp_table`).
/// Deserialize a custom table.
fn deserialize_custom_table(
&self,
name: &str,
_bytes: &[u8],
NamedBytes(qualifier, _bytes): &NamedBytes,
) -> Result<Arc<dyn TableSource>> {
not_impl_err!("Deserializing custom table `{name}` is not supported")
not_impl_err!("Deserializing custom table `{qualifier}` is not supported")
}
}

/// A sequence of bytes with a string qualifier. Meant to encapsulate serialized extensions
/// that need to carry their type, e.g. the `type_url` in `protobuf::Any`.
///
/// Field `0` is the qualifier and field `1` is the serialized payload.
///
/// Equality and hashing are value-based over both fields, so instances can be compared in
/// tests with `assert_eq!` or used as keys in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NamedBytes(pub String, pub Vec<u8>);

/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s
#[derive(Default, Debug)]
pub struct MemoryFunctionRegistry {
Expand Down
91 changes: 49 additions & 42 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};

use datafusion::logical_expr::{
Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension,
LogicalPlan, Operator, Projection, SortExpr, Subquery, TableScan, TryCast, Values,
LogicalPlan, Operator, Projection, SortExpr, Subquery, TableSource, TryCast, Values,
};
use substrait::proto::aggregate_rel::Grouping;
use substrait::proto::expression as substrait_expression;
Expand Down Expand Up @@ -61,6 +61,7 @@ use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::execution::{FunctionRegistry, SessionState};
use datafusion::logical_expr::builder::project;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::registry::NamedBytes;
use datafusion::logical_expr::{
col, expr, GroupingSet, Like, LogicalPlanBuilder, Partitioning, Repartition,
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
Expand Down Expand Up @@ -462,9 +463,7 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
fn consume_extension_table(
&self,
extension_table: &ExtensionTable,
_schema: &DFSchema,
_projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
) -> Result<Arc<dyn TableSource>> {
if let Some(ext_detail) = extension_table.detail.as_ref() {
substrait_err!(
"Missing handler for extension table: {}",
Expand Down Expand Up @@ -548,10 +547,12 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
let Some(ext_detail) = &rel.detail else {
return substrait_err!("Unexpected empty detail in ExtensionLeafRel");
};
let named_bytes =
NamedBytes(ext_detail.type_url.to_owned(), ext_detail.value.to_vec());
let plan = self
.state
.serializer_registry()
.deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?;
.deserialize_logical_plan(&named_bytes)?;
Ok(LogicalPlan::Extension(Extension { node: plan }))
}

Expand All @@ -562,10 +563,12 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
let Some(ext_detail) = &rel.detail else {
return substrait_err!("Unexpected empty detail in ExtensionSingleRel");
};
let named_bytes =
NamedBytes(ext_detail.type_url.to_owned(), ext_detail.value.to_vec());
let plan = self
.state
.serializer_registry()
.deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?;
.deserialize_logical_plan(&named_bytes)?;
let Some(input_rel) = &rel.input else {
return substrait_err!(
"ExtensionSingleRel missing input rel, try using ExtensionLeafRel instead"
Expand All @@ -583,10 +586,12 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
let Some(ext_detail) = &rel.detail else {
return substrait_err!("Unexpected empty detail in ExtensionMultiRel");
};
let named_bytes =
NamedBytes(ext_detail.type_url.to_owned(), ext_detail.value.to_vec());
let plan = self
.state
.serializer_registry()
.deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?;
.deserialize_logical_plan(&named_bytes)?;
let mut inputs = Vec::with_capacity(rel.inputs.len());
for input in &rel.inputs {
let input_plan = self.consume_rel(input).await?;
Expand All @@ -599,24 +604,13 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
fn consume_extension_table(
&self,
extension_table: &ExtensionTable,
schema: &DFSchema,
projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
) -> Result<Arc<dyn TableSource>> {
if let Some(ext_detail) = &extension_table.detail {
let source = self
.state
let named_bytes =
NamedBytes(ext_detail.type_url.to_owned(), ext_detail.value.to_vec());
self.state
.serializer_registry()
.deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)?;
let table_name = ext_detail
.type_url
.rsplit_once('/')
.map(|(_, name)| name)
.unwrap_or(&ext_detail.type_url);
let table_scan = TableScan::try_new(table_name, source, None, vec![], None)?;
let plan = LogicalPlan::TableScan(table_scan);
ensure_schema_compatibility(plan.schema(), schema.clone())?;
let schema = apply_masking(schema.clone(), projection)?;
apply_projection(plan, schema)
.deserialize_custom_table(&named_bytes)
} else {
substrait_err!("Unexpected empty detail in ExtensionTable")
}
Expand Down Expand Up @@ -1366,26 +1360,14 @@ pub async fn from_read_rel(
read: &ReadRel,
) -> Result<LogicalPlan> {
async fn read_with_schema(
consumer: &impl SubstraitConsumer,
table_ref: TableReference,
table_source: Arc<dyn TableSource>,
schema: DFSchema,
projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
let schema = schema.replace_qualifier(table_ref.clone());

let plan = {
let provider = match consumer.resolve_table_ref(&table_ref).await? {
Some(ref provider) => Arc::clone(provider),
_ => return plan_err!("No table named '{table_ref}'"),
};

LogicalPlanBuilder::scan(
table_ref,
provider_as_source(Arc::clone(&provider)),
None,
)?
.build()?
};
let plan = { LogicalPlanBuilder::scan(table_ref, table_source, None)?.build()? };

ensure_schema_compatibility(plan.schema(), schema.clone())?;

Expand All @@ -1394,6 +1376,17 @@ pub async fn from_read_rel(
apply_projection(plan, schema)
}

/// Looks up `table_ref` through the consumer and adapts the resolved provider into a
/// [`TableSource`] via `provider_as_source`.
///
/// # Errors
/// Propagates any error from `resolve_table_ref`, and returns a planning error when the
/// consumer resolves no provider for `table_ref`.
async fn table_source(
    consumer: &impl SubstraitConsumer,
    table_ref: &TableReference,
) -> Result<Arc<dyn TableSource>> {
    match consumer.resolve_table_ref(table_ref).await? {
        Some(provider) => Ok(provider_as_source(provider)),
        None => plan_err!("No table named '{table_ref}'"),
    }
}

let named_struct = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for Read Relation")
})?;
Expand All @@ -1419,10 +1412,10 @@ pub async fn from_read_rel(
table: nt.names[2].clone().into(),
},
};

let table_source = table_source(consumer, &table_reference).await?;
read_with_schema(
consumer,
table_reference,
table_source,
substrait_schema,
&read.projection,
)
Expand Down Expand Up @@ -1501,17 +1494,31 @@ pub async fn from_read_rel(
let name = filename.unwrap();
// directly use unwrap here since we could determine it is a valid one
let table_reference = TableReference::Bare { table: name.into() };
let table_source = table_source(consumer, &table_reference).await?;

read_with_schema(
consumer,
table_reference,
table_source,
substrait_schema,
&read.projection,
)
.await
}
Some(ReadType::ExtensionTable(ext)) => {
consumer.consume_extension_table(ext, &substrait_schema, &read.projection)
let name_hint = read
.common
.as_ref()
.and_then(|rel_common| rel_common.hint.as_ref())
.map(|hint| hint.alias.as_str())
.filter(|alias| !alias.is_empty());
let table_name = name_hint.unwrap_or("tmp_table");
read_with_schema(
TableReference::from(table_name),
consumer.consume_extension_table(ext)?,
substrait_schema,
&read.projection,
)
.await
}
None => {
substrait_err!("Unexpected empty read_type")
Expand Down Expand Up @@ -1917,7 +1924,7 @@ pub async fn from_substrait_sorts(
},
None => not_impl_err!("Sort without sort kind is invalid"),
};
let (asc, nulls_first) = asc_nullfirst.unwrap();
let (asc, nulls_first) = asc_nullfirst?;
sorts.push(Sort {
expr,
asc,
Expand Down
Loading

0 comments on commit a3f89c8

Please sign in to comment.