5
5
"""
6
6
7
7
from abc import ABC , abstractmethod
8
+ from functools import cached_property
8
9
from typing import (
9
10
TYPE_CHECKING ,
10
11
AbstractSet ,
@@ -362,18 +363,6 @@ def scoped_resources_builder(self) -> ScopedResourcesBuilder:
362
363
def log (self ) -> DagsterLogManager :
363
364
return self ._log_manager
364
365
365
- @property
366
- def partitions_def (self ) -> Optional [PartitionsDefinition ]:
367
- from dagster ._core .definitions .job_definition import JobDefinition
368
-
369
- job_def = self ._execution_data .job_def
370
- if not isinstance (job_def , JobDefinition ):
371
- check .failed (
372
- "Can only call 'partitions_def', when using jobs, not legacy pipelines" ,
373
- )
374
- partitions_def = job_def .partitions_def
375
- return partitions_def
376
-
377
366
@property
378
367
def has_partitions (self ) -> bool :
379
368
tags = self ._plan_data .dagster_run .tags
@@ -386,68 +375,6 @@ def has_partitions(self) -> bool:
386
375
)
387
376
)
388
377
389
- @property
390
- def partition_key (self ) -> str :
391
- from dagster ._core .definitions .multi_dimensional_partitions import (
392
- MultiPartitionsDefinition ,
393
- get_multipartition_key_from_tags ,
394
- )
395
-
396
- if not self .has_partitions :
397
- raise DagsterInvariantViolationError (
398
- "Cannot access partition_key for a non-partitioned run"
399
- )
400
-
401
- tags = self ._plan_data .dagster_run .tags
402
- if any ([tag .startswith (MULTIDIMENSIONAL_PARTITION_PREFIX ) for tag in tags .keys ()]):
403
- return get_multipartition_key_from_tags (tags )
404
- elif PARTITION_NAME_TAG in tags :
405
- return tags [PARTITION_NAME_TAG ]
406
- else :
407
- range_start = tags [ASSET_PARTITION_RANGE_START_TAG ]
408
- range_end = tags [ASSET_PARTITION_RANGE_END_TAG ]
409
-
410
- if range_start != range_end :
411
- raise DagsterInvariantViolationError (
412
- "Cannot access partition_key for a partitioned run with a range of partitions."
413
- " Call partition_key_range instead."
414
- )
415
- else :
416
- if isinstance (self .partitions_def , MultiPartitionsDefinition ):
417
- return self .partitions_def .get_partition_key_from_str (cast (str , range_start ))
418
- return cast (str , range_start )
419
-
420
- @property
421
- def partition_key_range (self ) -> PartitionKeyRange :
422
- from dagster ._core .definitions .multi_dimensional_partitions import (
423
- MultiPartitionsDefinition ,
424
- get_multipartition_key_from_tags ,
425
- )
426
-
427
- if not self .has_partitions :
428
- raise DagsterInvariantViolationError (
429
- "Cannot access partition_key for a non-partitioned run"
430
- )
431
-
432
- tags = self ._plan_data .dagster_run .tags
433
- if any ([tag .startswith (MULTIDIMENSIONAL_PARTITION_PREFIX ) for tag in tags .keys ()]):
434
- multipartition_key = get_multipartition_key_from_tags (tags )
435
- return PartitionKeyRange (multipartition_key , multipartition_key )
436
- elif PARTITION_NAME_TAG in tags :
437
- partition_key = tags [PARTITION_NAME_TAG ]
438
- return PartitionKeyRange (partition_key , partition_key )
439
- else :
440
- partition_key_range_start = tags [ASSET_PARTITION_RANGE_START_TAG ]
441
- if partition_key_range_start is not None :
442
- if isinstance (self .partitions_def , MultiPartitionsDefinition ):
443
- return PartitionKeyRange (
444
- self .partitions_def .get_partition_key_from_str (partition_key_range_start ),
445
- self .partitions_def .get_partition_key_from_str (
446
- tags [ASSET_PARTITION_RANGE_END_TAG ]
447
- ),
448
- )
449
- return PartitionKeyRange (partition_key_range_start , tags [ASSET_PARTITION_RANGE_END_TAG ])
450
-
451
378
@property
452
379
def has_partition_key (self ) -> bool :
453
380
return PARTITION_NAME_TAG in self ._plan_data .dagster_run .tags
@@ -954,6 +881,100 @@ def get_output_asset_keys(self) -> AbstractSet[AssetKey]:
954
881
output_keys .add (asset_key )
955
882
return output_keys
956
883
884
+ @cached_property
885
+ def run_partitions_def (self ) -> Optional [PartitionsDefinition ]:
886
+ job_def_partitions_def = self .job_def .partitions_def
887
+ if job_def_partitions_def is not None :
888
+ return job_def_partitions_def
889
+
890
+ # In the case where a job targets assets with different PartitionsDefinitions,
891
+ # job_def.partitions_def will be None, but the assets targeted in this step might still be
892
+ # partitioned. All assets within a step are expected to either have the same partitions_def
893
+ # or no partitions_def. Get the partitions_def from one of the assets that has one.
894
+ return self .asset_partitions_def
895
+
896
+ @cached_property
897
+ def asset_partitions_def (self ) -> Optional [PartitionsDefinition ]:
898
+ """If the current step is executing a partitioned asset, returns the PartitionsDefinition
899
+ for that asset. If there are one or more partitioned assets executing in the step, they're
900
+ expected to all have the same PartitionsDefinition.
901
+ """
902
+ asset_layer = self .job_def .asset_layer
903
+ assets_def = asset_layer .assets_def_for_node (self .node_handle ) if asset_layer else None
904
+ if assets_def is not None :
905
+ for asset_key in assets_def .keys :
906
+ partitions_def = self .job_def .asset_layer .get (asset_key ).partitions_def
907
+ if partitions_def is not None :
908
+ return partitions_def
909
+
910
+ return None
911
+
912
+ @property
913
+ def partition_key (self ) -> str :
914
+ from dagster ._core .definitions .multi_dimensional_partitions import (
915
+ MultiPartitionsDefinition ,
916
+ get_multipartition_key_from_tags ,
917
+ )
918
+
919
+ if not self .has_partitions :
920
+ raise DagsterInvariantViolationError (
921
+ "Cannot access partition_key for a non-partitioned run"
922
+ )
923
+
924
+ tags = self ._plan_data .dagster_run .tags
925
+ if any ([tag .startswith (MULTIDIMENSIONAL_PARTITION_PREFIX ) for tag in tags .keys ()]):
926
+ return get_multipartition_key_from_tags (tags )
927
+ elif PARTITION_NAME_TAG in tags :
928
+ return tags [PARTITION_NAME_TAG ]
929
+ else :
930
+ range_start = tags [ASSET_PARTITION_RANGE_START_TAG ]
931
+ range_end = tags [ASSET_PARTITION_RANGE_END_TAG ]
932
+
933
+ if range_start != range_end :
934
+ raise DagsterInvariantViolationError (
935
+ "Cannot access partition_key for a partitioned run with a range of partitions."
936
+ " Call partition_key_range instead."
937
+ )
938
+ else :
939
+ if isinstance (self .run_partitions_def , MultiPartitionsDefinition ):
940
+ return self .run_partitions_def .get_partition_key_from_str (
941
+ cast (str , range_start )
942
+ )
943
+ return cast (str , range_start )
944
+
945
+ @property
946
+ def partition_key_range (self ) -> PartitionKeyRange :
947
+ from dagster ._core .definitions .multi_dimensional_partitions import (
948
+ MultiPartitionsDefinition ,
949
+ get_multipartition_key_from_tags ,
950
+ )
951
+
952
+ if not self .has_partitions :
953
+ raise DagsterInvariantViolationError (
954
+ "Cannot access partition_key for a non-partitioned run"
955
+ )
956
+
957
+ tags = self ._plan_data .dagster_run .tags
958
+ if any ([tag .startswith (MULTIDIMENSIONAL_PARTITION_PREFIX ) for tag in tags .keys ()]):
959
+ multipartition_key = get_multipartition_key_from_tags (tags )
960
+ return PartitionKeyRange (multipartition_key , multipartition_key )
961
+ elif PARTITION_NAME_TAG in tags :
962
+ partition_key = tags [PARTITION_NAME_TAG ]
963
+ return PartitionKeyRange (partition_key , partition_key )
964
+ else :
965
+ partition_key_range_start = tags [ASSET_PARTITION_RANGE_START_TAG ]
966
+ if partition_key_range_start is not None :
967
+ if isinstance (self .run_partitions_def , MultiPartitionsDefinition ):
968
+ return PartitionKeyRange (
969
+ self .run_partitions_def .get_partition_key_from_str (
970
+ partition_key_range_start
971
+ ),
972
+ self .run_partitions_def .get_partition_key_from_str (
973
+ tags [ASSET_PARTITION_RANGE_END_TAG ]
974
+ ),
975
+ )
976
+ return PartitionKeyRange (partition_key_range_start , tags [ASSET_PARTITION_RANGE_END_TAG ])
977
+
957
978
def has_asset_partitions_for_input (self , input_name : str ) -> bool :
958
979
asset_layer = self .job_def .asset_layer
959
980
upstream_asset_key = asset_layer .asset_key_for_input (self .node_handle , input_name )
@@ -999,7 +1020,7 @@ def asset_partitions_subset_for_input(
999
1020
upstream_asset_partitions_def = asset_layer .get (upstream_asset_key ).partitions_def
1000
1021
1001
1022
if upstream_asset_partitions_def is not None :
1002
- partitions_def = assets_def . partitions_def if assets_def else None
1023
+ partitions_def = self . asset_partitions_def if assets_def else None
1003
1024
partitions_subset = (
1004
1025
partitions_def .empty_subset ().with_partition_key_range (
1005
1026
partitions_def ,
@@ -1116,12 +1137,7 @@ def asset_partitions_time_window_for_output(self, output_name: str) -> TimeWindo
1116
1137
1117
1138
@property
1118
1139
def partition_time_window (self ) -> TimeWindow :
1119
- asset_layer = self .job_def .asset_layer
1120
- partitions_def = self .job_def .partitions_def
1121
- if asset_layer :
1122
- assets_def = asset_layer .assets_def_for_node (self .node_handle )
1123
- if assets_def :
1124
- partitions_def = assets_def .partitions_def
1140
+ partitions_def = self .run_partitions_def
1125
1141
1126
1142
if partitions_def is None :
1127
1143
raise DagsterInvariantViolationError ("Partitions definition is not defined" )
0 commit comments