@@ -92,11 +92,8 @@ def execute_plan(self, plan: PhysicalPlan, num_samples: int | float = float("inf
92
92
93
93
# get handle to DataSource and pre-compute its size
94
94
source_operator = plan .operators [0 ]
95
- datasource = (
96
- source_operator .get_datasource ()
97
- if isinstance (source_operator , MarshalAndScanDataOp )
98
- else self .datadir .get_cached_result (source_operator .dataset_id )
99
- )
95
+ assert isinstance (source_operator , DataSourcePhysicalOp ), "First operator in physical plan must be a DataSourcePhysicalOp"
96
+ datasource = source_operator .get_datasource ()
100
97
datasource_len = len (datasource )
101
98
102
99
# Calculate total work units - each record needs to go through each operator
@@ -272,11 +269,8 @@ def execute_plan(self, plan: PhysicalPlan, num_samples: int | float = float("inf
272
269
273
270
# get handle to DataSource and pre-compute its size
274
271
source_operator = plan .operators [0 ]
275
- datasource = (
276
- source_operator .get_datasource ()
277
- if isinstance (source_operator , MarshalAndScanDataOp )
278
- else self .datadir .get_cached_result (source_operator .dataset_id )
279
- )
272
+ assert isinstance (source_operator , DataSourcePhysicalOp ), "First operator in physical plan must be a DataSourcePhysicalOp"
273
+ datasource = source_operator .get_datasource ()
280
274
datasource_len = len (datasource )
281
275
282
276
# Calculate total work units - each record needs to go through each operator
@@ -468,11 +462,8 @@ def __init__(self, *args, **kwargs):
468
462
469
463
# # get handle to DataSource and pre-compute its size
470
464
# source_operator = plan.operators[0]
471
- # datasource = (
472
- # source_operator.get_datasource()
473
- # if isinstance(source_operator, MarshalAndScanDataOp)
474
- # else self.datadir.get_cached_result(source_operator.dataset_id)
475
- # )
465
+ # assert isinstance(source_operator, DataSourcePhysicalOp), "First operator in physical plan must be a DataSourcePhysicalOp"
466
+ # datasource = source_operator.get_datasource()
476
467
# datasource_len = len(datasource)
477
468
478
469
# # Calculate total work units - each record needs to go through each operator
0 commit comments