diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index c40dd522a457..0fa6d6cbde06 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1271,11 +1271,12 @@ impl DataFrame { /// ``` pub async fn cache(self) -> Result { let context = SessionContext::new_with_state(self.session_state.clone()); - let mem_table = MemTable::try_new( - SchemaRef::from(self.schema().clone()), - self.collect_partitioned().await?, - )?; - + // The schema is consistent with the output + let plan = self.clone().create_physical_plan().await?; + let schema = plan.schema(); + let task_ctx = Arc::new(self.task_ctx()); + let partitions = collect_partitioned(plan, task_ctx).await?; + let mem_table = MemTable::try_new(schema, partitions)?; context.read_table(Arc::new(mem_table)) } } @@ -2633,6 +2634,17 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_cache_mismatch() -> Result<()> { + let ctx = SessionContext::new(); + let df = ctx + .sql("SELECT CASE WHEN true THEN NULL ELSE 1 END") + .await?; + let cache_df = df.cache().await; + assert!(cache_df.is_ok()); + Ok(()) + } + #[tokio::test] async fn cache_test() -> Result<()> { let df = test_table()