@@ -37,9 +37,15 @@ use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, Table
3737use datafusion:: dataframe:: { DataFrame , DataFrameWriteOptions } ;
3838use datafusion:: error:: DataFusionError ;
3939use datafusion:: execution:: SendableRecordBatchStream ;
40+ use datafusion:: execution:: context:: TaskContext ;
4041use datafusion:: logical_expr:: SortExpr ;
4142use datafusion:: logical_expr:: dml:: InsertOp ;
4243use datafusion:: parquet:: basic:: { BrotliLevel , Compression , GzipLevel , ZstdLevel } ;
44+ use datafusion:: physical_plan:: {
45+ ExecutionPlan as DFExecutionPlan , collect as df_collect,
46+ collect_partitioned as df_collect_partitioned, execute_stream as df_execute_stream,
47+ execute_stream_partitioned as df_execute_stream_partitioned,
48+ } ;
4349use datafusion:: prelude:: * ;
4450use datafusion_python_util:: { is_ipython_env, spawn_future, wait_for_future} ;
4551use futures:: { StreamExt , TryStreamExt } ;
@@ -308,6 +314,9 @@ pub struct PyDataFrame {
308314
309315 // In IPython environment cache batches between __repr__ and _repr_html_ calls.
310316 batches : SharedCachedBatches ,
317+
318+ // Cache the last physical plan so that metrics are available after execution.
319+ last_plan : Arc < Mutex < Option < Arc < dyn DFExecutionPlan > > > > ,
311320}
312321
313322impl PyDataFrame {
@@ -316,6 +325,7 @@ impl PyDataFrame {
316325 Self {
317326 df : Arc :: new ( df) ,
318327 batches : Arc :: new ( Mutex :: new ( None ) ) ,
328+ last_plan : Arc :: new ( Mutex :: new ( None ) ) ,
319329 }
320330 }
321331
@@ -387,6 +397,20 @@ impl PyDataFrame {
387397 Ok ( html_str)
388398 }
389399
400+ /// Create the physical plan, cache it in `last_plan`, and return the plan together
401+ /// with a task context. Centralises the repeated three-line pattern that appears in
402+ /// `collect`, `collect_partitioned`, `execute_stream`, and `execute_stream_partitioned`.
403+ fn create_and_cache_plan (
404+ & self ,
405+ py : Python ,
406+ ) -> PyDataFusionResult < ( Arc < dyn DFExecutionPlan > , Arc < TaskContext > ) > {
407+ let df = self . df . as_ref ( ) . clone ( ) ;
408+ let new_plan = wait_for_future ( py, df. create_physical_plan ( ) ) ??;
409+ * self . last_plan . lock ( ) = Some ( Arc :: clone ( & new_plan) ) ;
410+ let task_ctx = Arc :: new ( self . df . as_ref ( ) . task_ctx ( ) ) ;
411+ Ok ( ( new_plan, task_ctx) )
412+ }
413+
390414 async fn collect_column_inner ( & self , column : & str ) -> Result < ArrayRef , DataFusionError > {
391415 let batches = self
392416 . df
@@ -646,8 +670,9 @@ impl PyDataFrame {
646670 /// Unless some order is specified in the plan, there is no
647671 /// guarantee of the order of the result.
648672 fn collect < ' py > ( & self , py : Python < ' py > ) -> PyResult < Vec < Bound < ' py , PyAny > > > {
649- let batches = wait_for_future ( py, self . df . as_ref ( ) . clone ( ) . collect ( ) ) ?
650- . map_err ( PyDataFusionError :: from) ?;
673+ let ( plan, task_ctx) = self . create_and_cache_plan ( py) ?;
674+ let batches =
675+ wait_for_future ( py, df_collect ( plan, task_ctx) ) ?. map_err ( PyDataFusionError :: from) ?;
651676 // cannot use PyResult<Vec<RecordBatch>> return type due to
652677 // https://github.com/PyO3/pyo3/issues/1813
653678 batches. into_iter ( ) . map ( |rb| rb. to_pyarrow ( py) ) . collect ( )
@@ -662,7 +687,8 @@ impl PyDataFrame {
662687 /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
663688 /// maintaining the input partitioning.
664689 fn collect_partitioned < ' py > ( & self , py : Python < ' py > ) -> PyResult < Vec < Vec < Bound < ' py , PyAny > > > > {
665- let batches = wait_for_future ( py, self . df . as_ref ( ) . clone ( ) . collect_partitioned ( ) ) ?
690+ let ( plan, task_ctx) = self . create_and_cache_plan ( py) ?;
691+ let batches = wait_for_future ( py, df_collect_partitioned ( plan, task_ctx) ) ?
666692 . map_err ( PyDataFusionError :: from) ?;
667693
668694 batches
@@ -840,7 +866,13 @@ impl PyDataFrame {
840866 }
841867
842868 /// Get the execution plan for this `DataFrame`
869+ ///
870+ /// If the DataFrame has already been executed (e.g. via `collect()`),
871+ /// returns the cached plan which includes populated metrics.
843872 fn execution_plan ( & self , py : Python ) -> PyDataFusionResult < PyExecutionPlan > {
873+ if let Some ( plan) = self . last_plan . lock ( ) . as_ref ( ) {
874+ return Ok ( PyExecutionPlan :: new ( Arc :: clone ( plan) ) ) ;
875+ }
844876 let plan = wait_for_future ( py, self . df . as_ref ( ) . clone ( ) . create_physical_plan ( ) ) ??;
845877 Ok ( plan. into ( ) )
846878 }
@@ -1198,14 +1230,17 @@ impl PyDataFrame {
11981230 }
11991231
12001232 fn execute_stream ( & self , py : Python ) -> PyDataFusionResult < PyRecordBatchStream > {
1201- let df = self . df . as_ref ( ) . clone ( ) ;
1202- let stream = spawn_future ( py, async move { df . execute_stream ( ) . await } ) ?;
1233+ let ( plan , task_ctx ) = self . create_and_cache_plan ( py ) ? ;
1234+ let stream = spawn_future ( py, async move { df_execute_stream ( plan , task_ctx ) } ) ?;
12031235 Ok ( PyRecordBatchStream :: new ( stream) )
12041236 }
12051237
12061238 fn execute_stream_partitioned ( & self , py : Python ) -> PyResult < Vec < PyRecordBatchStream > > {
1207- let df = self . df . as_ref ( ) . clone ( ) ;
1208- let streams = spawn_future ( py, async move { df. execute_stream_partitioned ( ) . await } ) ?;
1239+ let ( plan, task_ctx) = self . create_and_cache_plan ( py) ?;
1240+ let streams = spawn_future (
1241+ py,
1242+ async move { df_execute_stream_partitioned ( plan, task_ctx) } ,
1243+ ) ?;
12091244 Ok ( streams. into_iter ( ) . map ( PyRecordBatchStream :: new) . collect ( ) )
12101245 }
12111246
0 commit comments