Skip to content

Commit 952e7c3

Browse files
authored
improve file path validation when reading parquet (#8267)
* improve file path validation * fix cli * update test * update test
1 parent 54a0247 commit 952e7c3

File tree

2 files changed

+71
-6
lines changed

2 files changed

+71
-6
lines changed

datafusion/core/src/execution/context/mod.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -858,10 +858,12 @@ impl SessionContext {
858858

859859
// check if the file extension matches the expected extension
860860
for path in &table_paths {
861-
let file_name = path.prefix().filename().unwrap_or_default();
862-
if !path.as_str().ends_with(&option_extension) && file_name.contains('.') {
861+
let file_path = path.as_str();
862+
if !file_path.ends_with(option_extension.clone().as_str())
863+
&& !path.is_collection()
864+
{
863865
return exec_err!(
864-
"File '{file_name}' does not match the expected extension '{option_extension}'"
866+
"File path '{file_path}' does not match the expected extension '{option_extension}'"
865867
);
866868
}
867869
}

datafusion/core/src/execution/context/parquet.rs

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,10 @@ mod tests {
138138
Ok(())
139139
}
140140

141-
#[cfg(not(target_family = "windows"))]
142141
#[tokio::test]
143142
async fn read_from_different_file_extension() -> Result<()> {
144143
let ctx = SessionContext::new();
144+
let sep = std::path::MAIN_SEPARATOR.to_string();
145145

146146
// Make up a new dataframe.
147147
let write_df = ctx.read_batch(RecordBatch::try_new(
@@ -175,6 +175,25 @@ mod tests {
175175
.unwrap()
176176
.to_string();
177177

178+
let path4 = temp_dir_path
179+
.join("output4.parquet".to_owned() + &sep)
180+
.to_str()
181+
.unwrap()
182+
.to_string();
183+
184+
let path5 = temp_dir_path
185+
.join("bbb..bbb")
186+
.join("filename.parquet")
187+
.to_str()
188+
.unwrap()
189+
.to_string();
190+
let dir = temp_dir_path
191+
.join("bbb..bbb".to_owned() + &sep)
192+
.to_str()
193+
.unwrap()
194+
.to_string();
195+
std::fs::create_dir(dir).expect("create dir failed");
196+
178197
// Write the dataframe to a parquet file named 'output1.parquet'
179198
write_df
180199
.clone()
@@ -205,6 +224,7 @@ mod tests {
205224

206225
// Write the dataframe to a parquet file named 'output3.parquet.snappy.parquet'
207226
write_df
227+
.clone()
208228
.write_parquet(
209229
&path3,
210230
DataFrameWriteOptions::new().with_single_file_output(true),
@@ -216,6 +236,19 @@ mod tests {
216236
)
217237
.await?;
218238

239+
// Write the dataframe to a parquet file named 'bbb..bbb/filename.parquet'
240+
write_df
241+
.write_parquet(
242+
&path5,
243+
DataFrameWriteOptions::new().with_single_file_output(true),
244+
Some(
245+
WriterProperties::builder()
246+
.set_compression(Compression::SNAPPY)
247+
.build(),
248+
),
249+
)
250+
.await?;
251+
219252
// Read the dataframe from 'output1.parquet' with the default file extension.
220253
let read_df = ctx
221254
.read_parquet(
@@ -253,10 +286,11 @@ mod tests {
253286
},
254287
)
255288
.await;
256-
289+
let binding = DataFilePaths::to_urls(&path2).unwrap();
290+
let expexted_path = binding[0].as_str();
257291
assert_eq!(
258292
read_df.unwrap_err().strip_backtrace(),
259-
"Execution error: File 'output2.parquet.snappy' does not match the expected extension '.parquet'"
293+
format!("Execution error: File path '{}' does not match the expected extension '.parquet'", expexted_path)
260294
);
261295

262296
// Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension.
@@ -269,6 +303,35 @@ mod tests {
269303
)
270304
.await?;
271305

306+
let results = read_df.collect().await?;
307+
let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
308+
assert_eq!(total_rows, 5);
309+
310+
// Read the dataframe from the empty directory 'output4.parquet/'
311+
std::fs::create_dir(&path4)?;
312+
let read_df = ctx
313+
.read_parquet(
314+
&path4,
315+
ParquetReadOptions {
316+
..Default::default()
317+
},
318+
)
319+
.await?;
320+
321+
let results = read_df.collect().await?;
322+
let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
323+
assert_eq!(total_rows, 0);
324+
325+
// Read the dataframe from the double-dot folder 'bbb..bbb/'.
326+
let read_df = ctx
327+
.read_parquet(
328+
&path5,
329+
ParquetReadOptions {
330+
..Default::default()
331+
},
332+
)
333+
.await?;
334+
272335
let results = read_df.collect().await?;
273336
let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
274337
assert_eq!(total_rows, 5);

0 commit comments

Comments
 (0)