Skip to content

Commit fbadebb

Browse files
alamb and NGA-TRAN authored
Fix querying and defining table / view names with period (#4530)
* Add tests for names with period * adjust docstrings * Improve docstrings * Add tests coverage * Update datafusion/common/src/table_reference.rs Co-authored-by: Nga Tran <[email protected]> * Add tests for creating tables with three periods Co-authored-by: Nga Tran <[email protected]>
1 parent 4c1f60c commit fbadebb

File tree

20 files changed

+1118
-222
lines changed

20 files changed

+1118
-222
lines changed

datafusion/common/src/dfschema.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ impl DFSchema {
206206
(Some(qq), None) => {
207207
// the original field may now be aliased with a name that matches the
208208
// original qualified name
209-
let table_ref: TableReference = field.name().as_str().into();
209+
let table_ref = TableReference::parse_str(field.name().as_str());
210210
match table_ref {
211211
TableReference::Partial { schema, table } => {
212212
schema == qq && table == name

datafusion/common/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ pub use error::{field_not_found, DataFusionError, Result, SchemaError};
3636
pub use parsers::parse_interval;
3737
pub use scalar::{ScalarType, ScalarValue};
3838
pub use stats::{ColumnStatistics, Statistics};
39-
pub use table_reference::{ResolvedTableReference, TableReference};
39+
pub use table_reference::{OwnedTableReference, ResolvedTableReference, TableReference};
4040

4141
/// Downcast an Arrow Array to a concrete type, return a `DataFusionError::Internal` if the cast is
4242
/// not possible. In normal usage of DataFusion the downcast should always succeed.

datafusion/common/src/table_reference.rs

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,73 @@ pub enum TableReference<'a> {
5252
},
5353
}
5454

55+
/// Represents a path to a table that may require further resolution
56+
/// that owns the underlying names
57+
#[derive(Debug, Clone)]
58+
pub enum OwnedTableReference {
59+
/// An unqualified table reference, e.g. "table"
60+
Bare {
61+
/// The table name
62+
table: String,
63+
},
64+
/// A partially resolved table reference, e.g. "schema.table"
65+
Partial {
66+
/// The schema containing the table
67+
schema: String,
68+
/// The table name
69+
table: String,
70+
},
71+
/// A fully resolved table reference, e.g. "catalog.schema.table"
72+
Full {
73+
/// The catalog (aka database) containing the table
74+
catalog: String,
75+
/// The schema containing the table
76+
schema: String,
77+
/// The table name
78+
table: String,
79+
},
80+
}
81+
82+
impl OwnedTableReference {
83+
/// Return a `TableReference` view of this `OwnedTableReference`
84+
pub fn as_table_reference(&self) -> TableReference<'_> {
85+
match self {
86+
Self::Bare { table } => TableReference::Bare { table },
87+
Self::Partial { schema, table } => TableReference::Partial { schema, table },
88+
Self::Full {
89+
catalog,
90+
schema,
91+
table,
92+
} => TableReference::Full {
93+
catalog,
94+
schema,
95+
table,
96+
},
97+
}
98+
}
99+
100+
/// Return a string suitable for display
101+
pub fn display_string(&self) -> String {
102+
match self {
103+
OwnedTableReference::Bare { table } => table.clone(),
104+
OwnedTableReference::Partial { schema, table } => format!("{schema}.{table}"),
105+
OwnedTableReference::Full {
106+
catalog,
107+
schema,
108+
table,
109+
} => format!("{catalog}.{schema}.{table}"),
110+
}
111+
}
112+
}
113+
114+
/// Convert `OwnedTableReference` into a `TableReference`. Somewhat
115+
/// awkward to use but 'idiomatic': `(&table_ref).into()`
116+
impl<'a> From<&'a OwnedTableReference> for TableReference<'a> {
117+
fn from(r: &'a OwnedTableReference) -> Self {
118+
r.as_table_reference()
119+
}
120+
}
121+
55122
impl<'a> TableReference<'a> {
56123
/// Retrieve the actual table name, regardless of qualification
57124
pub fn table(&self) -> &str {
@@ -90,10 +157,18 @@ impl<'a> TableReference<'a> {
90157
},
91158
}
92159
}
93-
}
94160

95-
impl<'a> From<&'a str> for TableReference<'a> {
96-
fn from(s: &'a str) -> Self {
161+
/// Forms a [`TableReference`] by splitting `s` on periods `.`.
162+
///
163+
/// Note that this function does NOT handle periods or name
164+
/// normalization correctly (e.g. `"foo.bar"` will be parsed as
165+
/// `"foo`.`bar"`, and `Foo` will be parsed as `Foo`, not `foo`).
166+
///
167+
/// If you need to handle such identifiers correctly, you should
168+
/// use a SQL parser or form the [`OwnedTableReference`] directly.
169+
///
170+
/// See more detail in <https://github.com/apache/arrow-datafusion/issues/4532>
171+
pub fn parse_str(s: &'a str) -> Self {
97172
let parts: Vec<&str> = s.split('.').collect();
98173

99174
match parts.len() {
@@ -112,6 +187,15 @@ impl<'a> From<&'a str> for TableReference<'a> {
112187
}
113188
}
114189

190+
/// Parse a string into a TableReference, by splitting on `.`
191+
///
192+
/// See caveats on [`TableReference::parse_str`]
193+
impl<'a> From<&'a str> for TableReference<'a> {
194+
fn from(s: &'a str) -> Self {
195+
Self::parse_str(s)
196+
}
197+
}
198+
115199
impl<'a> From<ResolvedTableReference<'a>> for TableReference<'a> {
116200
fn from(resolved: ResolvedTableReference<'a>) -> Self {
117201
Self::Full {

datafusion/core/src/catalog/listing_schema.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::catalog::schema::SchemaProvider;
2020
use crate::datasource::datasource::TableProviderFactory;
2121
use crate::datasource::TableProvider;
2222
use crate::execution::context::SessionState;
23-
use datafusion_common::{DFSchema, DataFusionError};
23+
use datafusion_common::{DFSchema, DataFusionError, OwnedTableReference};
2424
use datafusion_expr::CreateExternalTable;
2525
use futures::TryStreamExt;
2626
use itertools::Itertools;
@@ -115,16 +115,20 @@ impl ListingSchemaProvider {
115115
let table_path = table.to_str().ok_or_else(|| {
116116
DataFusionError::Internal("Cannot parse file name!".to_string())
117117
})?;
118+
118119
if !self.table_exist(table_name) {
119120
let table_url = format!("{}/{}", self.authority, table_path);
120121

122+
let name = OwnedTableReference::Bare {
123+
table: table_name.to_string(),
124+
};
121125
let provider = self
122126
.factory
123127
.create(
124128
state,
125129
&CreateExternalTable {
126130
schema: Arc::new(DFSchema::empty()),
127-
name: table_name.to_string(),
131+
name,
128132
location: table_url,
129133
file_type: self.format.clone(),
130134
has_header: self.has_header,

datafusion/core/src/datasource/view.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ mod tests {
502502
let actual = format!("{}", plan.display_indent());
503503
let expected = "\
504504
Explain\
505-
\n CreateView: \"xyz\"\
505+
\n CreateView: Bare { table: \"xyz\" }\
506506
\n Projection: abc.column1, abc.column2, abc.column3\
507507
\n TableScan: abc projection=[column1, column2, column3]";
508508
assert_eq!(expected, actual);
@@ -516,7 +516,7 @@ mod tests {
516516
let actual = format!("{}", plan.display_indent());
517517
let expected = "\
518518
Explain\
519-
\n CreateView: \"xyz\"\
519+
\n CreateView: Bare { table: \"xyz\" }\
520520
\n Projection: abc.column1, abc.column2, abc.column3\
521521
\n Filter: abc.column2 = Int64(5)\
522522
\n TableScan: abc projection=[column1, column2, column3]";
@@ -531,7 +531,7 @@ mod tests {
531531
let actual = format!("{}", plan.display_indent());
532532
let expected = "\
533533
Explain\
534-
\n CreateView: \"xyz\"\
534+
\n CreateView: Bare { table: \"xyz\" }\
535535
\n Projection: abc.column1, abc.column2\
536536
\n Filter: abc.column2 = Int64(5)\
537537
\n TableScan: abc projection=[column1, column2]";

datafusion/core/src/execution/context.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ impl SessionContext {
225225
batch: RecordBatch,
226226
) -> Result<Option<Arc<dyn TableProvider>>> {
227227
let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
228-
self.register_table(table_name, Arc::new(table))
228+
self.register_table(TableReference::Bare { table: table_name }, Arc::new(table))
229229
}
230230

231231
/// Return the [RuntimeEnv] used to run queries with this [SessionContext]
@@ -265,12 +265,12 @@ impl SessionContext {
265265
if_not_exists,
266266
or_replace,
267267
}) => {
268-
let table = self.table(name.as_str());
268+
let table = self.table(&name);
269269

270270
match (if_not_exists, or_replace, table) {
271271
(true, false, Ok(_)) => self.return_empty_dataframe(),
272272
(false, true, Ok(_)) => {
273-
self.deregister_table(name.as_str())?;
273+
self.deregister_table(&name)?;
274274
let physical =
275275
Arc::new(DataFrame::new(self.state.clone(), &input));
276276

@@ -280,7 +280,7 @@ impl SessionContext {
280280
batches,
281281
)?);
282282

283-
self.register_table(name.as_str(), table)?;
283+
self.register_table(&name, table)?;
284284
self.return_empty_dataframe()
285285
}
286286
(true, true, Ok(_)) => Err(DataFusionError::Internal(
@@ -296,7 +296,7 @@ impl SessionContext {
296296
batches,
297297
)?);
298298

299-
self.register_table(name.as_str(), table)?;
299+
self.register_table(&name, table)?;
300300
self.return_empty_dataframe()
301301
}
302302
(false, false, Ok(_)) => Err(DataFusionError::Execution(format!(
@@ -312,22 +312,22 @@ impl SessionContext {
312312
or_replace,
313313
definition,
314314
}) => {
315-
let view = self.table(name.as_str());
315+
let view = self.table(&name);
316316

317317
match (or_replace, view) {
318318
(true, Ok(_)) => {
319-
self.deregister_table(name.as_str())?;
319+
self.deregister_table(&name)?;
320320
let table =
321321
Arc::new(ViewTable::try_new((*input).clone(), definition)?);
322322

323-
self.register_table(name.as_str(), table)?;
323+
self.register_table(&name, table)?;
324324
self.return_empty_dataframe()
325325
}
326326
(_, Err(_)) => {
327327
let table =
328328
Arc::new(ViewTable::try_new((*input).clone(), definition)?);
329329

330-
self.register_table(name.as_str(), table)?;
330+
self.register_table(&name, table)?;
331331
self.return_empty_dataframe()
332332
}
333333
(false, Ok(_)) => Err(DataFusionError::Execution(format!(
@@ -340,7 +340,7 @@ impl SessionContext {
340340
LogicalPlan::DropTable(DropTable {
341341
name, if_exists, ..
342342
}) => {
343-
let result = self.find_and_deregister(name.as_str(), TableType::Base);
343+
let result = self.find_and_deregister(&name, TableType::Base);
344344
match (result, if_exists) {
345345
(Ok(true), _) => self.return_empty_dataframe(),
346346
(_, true) => self.return_empty_dataframe(),
@@ -354,7 +354,7 @@ impl SessionContext {
354354
LogicalPlan::DropView(DropView {
355355
name, if_exists, ..
356356
}) => {
357-
let result = self.find_and_deregister(name.as_str(), TableType::View);
357+
let result = self.find_and_deregister(&name, TableType::View);
358358
match (result, if_exists) {
359359
(Ok(true), _) => self.return_empty_dataframe(),
360360
(_, true) => self.return_empty_dataframe(),
@@ -497,11 +497,11 @@ impl SessionContext {
497497
let table_provider: Arc<dyn TableProvider> =
498498
self.create_custom_table(cmd).await?;
499499

500-
let table = self.table(cmd.name.as_str());
500+
let table = self.table(&cmd.name);
501501
match (cmd.if_not_exists, table) {
502502
(true, Ok(_)) => self.return_empty_dataframe(),
503503
(_, Err(_)) => {
504-
self.register_table(cmd.name.as_str(), table_provider)?;
504+
self.register_table(&cmd.name, table_provider)?;
505505
self.return_empty_dataframe()
506506
}
507507
(false, Ok(_)) => Err(DataFusionError::Execution(format!(
@@ -765,7 +765,7 @@ impl SessionContext {
765765
.with_listing_options(options)
766766
.with_schema(resolved_schema);
767767
let table = ListingTable::try_new(config)?.with_definition(sql_definition);
768-
self.register_table(name, Arc::new(table))?;
768+
self.register_table(TableReference::Bare { table: name }, Arc::new(table))?;
769769
Ok(())
770770
}
771771

datafusion/core/tests/sql/errors.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,12 @@ async fn invalid_qualified_table_references() -> Result<()> {
132132
"way.too.many.namespaces.as.ident.prefixes.aggregate_test_100",
133133
] {
134134
let sql = format!("SELECT COUNT(*) FROM {}", table_ref);
135-
assert!(matches!(ctx.sql(&sql).await, Err(DataFusionError::Plan(_))));
135+
let result = ctx.sql(&sql).await;
136+
assert!(
137+
matches!(result, Err(DataFusionError::Plan(_))),
138+
"result was: {:?}",
139+
result
140+
);
136141
}
137142
Ok(())
138143
}

datafusion/core/tests/sqllogictests/src/insert/mod.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,21 @@ use datafusion::datasource::MemTable;
2424
use datafusion::prelude::SessionContext;
2525
use datafusion_common::{DFSchema, DataFusionError};
2626
use datafusion_expr::Expr as DFExpr;
27-
use datafusion_sql::planner::{PlannerContext, SqlToRel};
27+
use datafusion_sql::planner::{object_name_to_table_reference, PlannerContext, SqlToRel};
2828
use sqlparser::ast::{Expr, SetExpr, Statement as SQLStatement};
2929
use std::sync::Arc;
3030

31-
pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result<String> {
31+
pub async fn insert(ctx: &SessionContext, insert_stmt: SQLStatement) -> Result<String> {
3232
// First, use sqlparser to get table name and insert values
33-
let table_name;
33+
let table_reference;
3434
let insert_values: Vec<Vec<Expr>>;
3535
match insert_stmt {
3636
SQLStatement::Insert {
37-
table_name: name,
38-
source,
39-
..
37+
table_name, source, ..
4038
} => {
39+
table_reference = object_name_to_table_reference(table_name)?;
40+
4141
// Todo: check columns match table schema
42-
table_name = name.to_string();
4342
match &*source.body {
4443
SetExpr::Values(values) => {
4544
insert_values = values.0.clone();
@@ -54,9 +53,9 @@ pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result<
5453
}
5554

5655
// Second, get batches in table and destroy the old table
57-
let mut origin_batches = ctx.table(table_name.as_str())?.collect().await?;
58-
let schema = ctx.table_provider(table_name.as_str())?.schema();
59-
ctx.deregister_table(table_name.as_str())?;
56+
let mut origin_batches = ctx.table(&table_reference)?.collect().await?;
57+
let schema = ctx.table_provider(&table_reference)?.schema();
58+
ctx.deregister_table(&table_reference)?;
6059

6160
// Third, transfer insert values to `RecordBatch`
6261
// Attention: schema info can be ignored. (insert values don't contain schema info)
@@ -84,7 +83,7 @@ pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result<
8483

8584
// Finally, create new memtable with same schema.
8685
let new_provider = MemTable::try_new(schema, vec![origin_batches])?;
87-
ctx.register_table(table_name.as_str(), Arc::new(new_provider))?;
86+
ctx.register_table(&table_reference, Arc::new(new_provider))?;
8887

8988
Ok("".to_string())
9089
}

datafusion/core/tests/sqllogictests/src/main.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,17 @@ fn format_batches(batches: Vec<RecordBatch>) -> Result<String> {
187187
async fn run_query(ctx: &SessionContext, sql: impl Into<String>) -> Result<String> {
188188
let sql = sql.into();
189189
// Check if the sql is `insert`
190-
if let Ok(statements) = DFParser::parse_sql(&sql) {
191-
if let Statement::Statement(statement) = &statements[0] {
192-
if let SQLStatement::Insert { .. } = &**statement {
190+
if let Ok(mut statements) = DFParser::parse_sql(&sql) {
191+
let statement0 = statements.pop_front().expect("at least one SQL statement");
192+
if let Statement::Statement(statement) = statement0 {
193+
let statement = *statement;
194+
if matches!(&statement, SQLStatement::Insert { .. }) {
193195
return insert(ctx, statement).await;
194196
}
195197
}
196198
}
197-
let df = ctx.sql(sql.as_str()).await.unwrap();
198-
let results: Vec<RecordBatch> = df.collect().await.unwrap();
199+
let df = ctx.sql(sql.as_str()).await?;
200+
let results: Vec<RecordBatch> = df.collect().await?;
199201
let formatted_batches = format_batches(results)?;
200202
Ok(formatted_batches)
201203
}

0 commit comments

Comments
 (0)