2 changes: 1 addition & 1 deletion datafusion/proto/proto/datafusion.proto
@@ -90,7 +90,7 @@ message ListingTableScanNode {
ProjectionColumns projection = 4;
datafusion_common.Schema schema = 5;
repeated LogicalExprNode filters = 6;
repeated string table_partition_cols = 7;
repeated PartitionColumn table_partition_cols = 7;
bool collect_stat = 8;
uint32 target_partitions = 9;
oneof FileFormatType {
4 changes: 2 additions & 2 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default.
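Since the regenerated prost output is collapsed in this view, here is a rough sketch of what the prost-generated struct for the new PartitionColumn message presumably looks like. The field shapes are inferred from how logical_plan/mod.rs reads col.name / col.arrow_type and constructs protobuf::PartitionColumn below; the tag numbers and prost attributes are assumptions, not copied from the actual generated file.

// Hypothetical sketch of the prost-generated PartitionColumn struct
// (field shapes inferred from usage in logical_plan/mod.rs; tags assumed).
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PartitionColumn {
    /// Name of the partition column.
    #[prost(string, tag = "1")]
    pub name: ::prost::alloc::string::String,
    /// Arrow data type of the partition column.
    #[prost(message, optional, tag = "2")]
    pub arrow_type: ::core::option::Option<ArrowType>,
}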

76 changes: 54 additions & 22 deletions datafusion/proto/src/logical_plan/mod.rs
@@ -33,7 +33,7 @@ use crate::{
};

use crate::protobuf::{proto_error, ToProtoError};
use arrow::datatypes::{DataType, Schema, SchemaRef};
use arrow::datatypes::{DataType, Schema, SchemaBuilder, SchemaRef};
use datafusion::datasource::cte_worktable::CteWorkTable;
#[cfg(feature = "avro")]
use datafusion::datasource::file_format::avro::AvroFormat;
@@ -458,23 +458,25 @@ impl AsLogicalPlan for LogicalPlanNode {
.map(ListingTableUrl::parse)
.collect::<Result<Vec<_>, _>>()?;

let partition_columns = scan
.table_partition_cols
.iter()
.map(|col| {
let Some(arrow_type) = col.arrow_type.as_ref() else {
return Err(proto_error(
"Missing Arrow type in partition columns".to_string(),
));
};
let arrow_type = DataType::try_from(arrow_type).map_err(|e| {
proto_error(format!("Received an unknown ArrowType: {}", e))
})?;
Ok((col.name.clone(), arrow_type))
})
.collect::<Result<Vec<_>>>()?;

let options = ListingOptions::new(file_format)
.with_file_extension(&scan.file_extension)
.with_table_partition_cols(
scan.table_partition_cols
.iter()
.map(|col| {
(
col.clone(),
schema
.field_with_name(col)
.unwrap()
.data_type()
.clone(),
)
})
.collect(),
)
.with_table_partition_cols(partition_columns)
.with_collect_stat(scan.collect_stat)
.with_target_partitions(scan.target_partitions as usize)
.with_file_sort_order(all_sort_orders);
@@ -1046,7 +1048,6 @@ impl AsLogicalPlan for LogicalPlanNode {
})
}
};
let schema: protobuf::Schema = schema.as_ref().try_into()?;

let filters: Vec<protobuf::LogicalExprNode> =
serialize_exprs(filters, extension_codec)?;
@@ -1099,6 +1100,21 @@ impl AsLogicalPlan for LogicalPlanNode {

let options = listing_table.options();

let mut builder = SchemaBuilder::from(schema.as_ref());
for (idx, field) in schema.fields().iter().enumerate().rev() {
if options
.table_partition_cols
.iter()
.any(|(name, _)| name == field.name())
{
builder.remove(idx);
}
}

let schema = builder.finish();

let schema: protobuf::Schema = (&schema).try_into()?;

let mut exprs_vec: Vec<SortExprNodeCollection> = vec![];
for order in &options.file_sort_order {
let expr_vec = SortExprNodeCollection {
@@ -1107,18 +1123,32 @@ impl AsLogicalPlan for LogicalPlanNode {
exprs_vec.push(expr_vec);
}

let partition_columns = options
.table_partition_cols
.iter()
.map(|(name, arrow_type)| {
let arrow_type = protobuf::ArrowType::try_from(arrow_type)
.map_err(|e| {
proto_error(format!(
"Received an unknown ArrowType: {}",
e
))
})?;
Ok(protobuf::PartitionColumn {
name: name.clone(),
arrow_type: Some(arrow_type),
})
})
.collect::<Result<Vec<_>>>()?;

Ok(LogicalPlanNode {
logical_plan_type: Some(LogicalPlanType::ListingScan(
protobuf::ListingTableScanNode {
file_format_type: Some(file_format_type),
table_name: Some(table_name.clone().into()),
collect_stat: options.collect_stat,
file_extension: options.file_extension.clone(),
table_partition_cols: options
.table_partition_cols
.iter()
.map(|x| x.0.clone())
.collect::<Vec<_>>(),
table_partition_cols: partition_columns,
paths: listing_table
.table_paths()
.iter()
@@ -1133,6 +1163,7 @@
)),
})
} else if let Some(view_table) = source.downcast_ref::<ViewTable>() {
let schema: protobuf::Schema = schema.as_ref().try_into()?;
Ok(LogicalPlanNode {
logical_plan_type: Some(LogicalPlanType::ViewScan(Box::new(
protobuf::ViewTableScanNode {
@@ -1167,6 +1198,7 @@
)),
})
} else {
let schema: protobuf::Schema = schema.as_ref().try_into()?;
let mut bytes = vec![];
extension_codec
.try_encode_table_provider(table_name, provider, &mut bytes)
35 changes: 34 additions & 1 deletion datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -24,7 +24,10 @@ use arrow::datatypes::{
DECIMAL256_MAX_PRECISION,
};
use arrow::util::pretty::pretty_format_batches;
use datafusion::datasource::file_format::json::JsonFormatFactory;
use datafusion::datasource::file_format::json::{JsonFormat, JsonFormatFactory};
use datafusion::datasource::listing::{
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::optimizer::eliminate_nested_union::EliminateNestedUnion;
use datafusion::optimizer::Optimizer;
use datafusion_common::parsers::CompressionTypeVariant;
@@ -2559,3 +2562,33 @@ async fn roundtrip_union_query() -> Result<()> {
);
Ok(())
}

#[tokio::test]
async fn roundtrip_custom_listing_tables_schema() -> Result<()> {
let ctx = SessionContext::new();
// Make sure partition column names and types are preserved during the plan round-trip
let file_format = JsonFormat::default();
let table_partition_cols = vec![("part".to_owned(), DataType::Int64)];
let data = "../core/tests/data/partitioned_table_json";
let listing_table_url = ListingTableUrl::parse(data)?;
let listing_options = ListingOptions::new(Arc::new(file_format))
.with_table_partition_cols(table_partition_cols);

let config = ListingTableConfig::new(listing_table_url)
.with_listing_options(listing_options)
.infer_schema(&ctx.state())
.await?;

ctx.register_table("hive_style", Arc::new(ListingTable::try_new(config)?))?;

let plan = ctx
.sql("SELECT part, value FROM hive_style LIMIT 1")
.await?
.logical_plan()
.clone();

let bytes = logical_plan_to_bytes(&plan)?;
let new_plan = logical_plan_from_bytes(&bytes, &ctx)?;
assert_eq!(plan, new_plan);
Ok(())
}