Skip to content

Commit 5765906

Browse files
committed
move step context partitions methods
branch-name: mv-step-context-partitions-def-methods
1 parent 70db15f commit 5765906

File tree

1 file changed

+98
-82
lines changed
  • python_modules/dagster/dagster/_core/execution/context

1 file changed

+98
-82
lines changed

python_modules/dagster/dagster/_core/execution/context/system.py

Lines changed: 98 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from abc import ABC, abstractmethod
8+
from functools import cached_property
89
from typing import (
910
TYPE_CHECKING,
1011
AbstractSet,
@@ -362,18 +363,6 @@ def scoped_resources_builder(self) -> ScopedResourcesBuilder:
362363
def log(self) -> DagsterLogManager:
363364
return self._log_manager
364365

365-
@property
366-
def partitions_def(self) -> Optional[PartitionsDefinition]:
367-
from dagster._core.definitions.job_definition import JobDefinition
368-
369-
job_def = self._execution_data.job_def
370-
if not isinstance(job_def, JobDefinition):
371-
check.failed(
372-
"Can only call 'partitions_def', when using jobs, not legacy pipelines",
373-
)
374-
partitions_def = job_def.partitions_def
375-
return partitions_def
376-
377366
@property
378367
def has_partitions(self) -> bool:
379368
tags = self._plan_data.dagster_run.tags
@@ -386,68 +375,6 @@ def has_partitions(self) -> bool:
386375
)
387376
)
388377

389-
@property
390-
def partition_key(self) -> str:
391-
from dagster._core.definitions.multi_dimensional_partitions import (
392-
MultiPartitionsDefinition,
393-
get_multipartition_key_from_tags,
394-
)
395-
396-
if not self.has_partitions:
397-
raise DagsterInvariantViolationError(
398-
"Cannot access partition_key for a non-partitioned run"
399-
)
400-
401-
tags = self._plan_data.dagster_run.tags
402-
if any([tag.startswith(MULTIDIMENSIONAL_PARTITION_PREFIX) for tag in tags.keys()]):
403-
return get_multipartition_key_from_tags(tags)
404-
elif PARTITION_NAME_TAG in tags:
405-
return tags[PARTITION_NAME_TAG]
406-
else:
407-
range_start = tags[ASSET_PARTITION_RANGE_START_TAG]
408-
range_end = tags[ASSET_PARTITION_RANGE_END_TAG]
409-
410-
if range_start != range_end:
411-
raise DagsterInvariantViolationError(
412-
"Cannot access partition_key for a partitioned run with a range of partitions."
413-
" Call partition_key_range instead."
414-
)
415-
else:
416-
if isinstance(self.partitions_def, MultiPartitionsDefinition):
417-
return self.partitions_def.get_partition_key_from_str(cast(str, range_start))
418-
return cast(str, range_start)
419-
420-
@property
421-
def partition_key_range(self) -> PartitionKeyRange:
422-
from dagster._core.definitions.multi_dimensional_partitions import (
423-
MultiPartitionsDefinition,
424-
get_multipartition_key_from_tags,
425-
)
426-
427-
if not self.has_partitions:
428-
raise DagsterInvariantViolationError(
429-
"Cannot access partition_key for a non-partitioned run"
430-
)
431-
432-
tags = self._plan_data.dagster_run.tags
433-
if any([tag.startswith(MULTIDIMENSIONAL_PARTITION_PREFIX) for tag in tags.keys()]):
434-
multipartition_key = get_multipartition_key_from_tags(tags)
435-
return PartitionKeyRange(multipartition_key, multipartition_key)
436-
elif PARTITION_NAME_TAG in tags:
437-
partition_key = tags[PARTITION_NAME_TAG]
438-
return PartitionKeyRange(partition_key, partition_key)
439-
else:
440-
partition_key_range_start = tags[ASSET_PARTITION_RANGE_START_TAG]
441-
if partition_key_range_start is not None:
442-
if isinstance(self.partitions_def, MultiPartitionsDefinition):
443-
return PartitionKeyRange(
444-
self.partitions_def.get_partition_key_from_str(partition_key_range_start),
445-
self.partitions_def.get_partition_key_from_str(
446-
tags[ASSET_PARTITION_RANGE_END_TAG]
447-
),
448-
)
449-
return PartitionKeyRange(partition_key_range_start, tags[ASSET_PARTITION_RANGE_END_TAG])
450-
451378
@property
452379
def has_partition_key(self) -> bool:
453380
return PARTITION_NAME_TAG in self._plan_data.dagster_run.tags
@@ -954,6 +881,100 @@ def get_output_asset_keys(self) -> AbstractSet[AssetKey]:
954881
output_keys.add(asset_key)
955882
return output_keys
956883

884+
@cached_property
885+
def run_partitions_def(self) -> Optional[PartitionsDefinition]:
886+
job_def_partitions_def = self.job_def.partitions_def
887+
if job_def_partitions_def is not None:
888+
return job_def_partitions_def
889+
890+
# In the case where a job targets assets with different PartitionsDefinitions,
891+
# job_def.partitions_def will be None, but the assets targeted in this step might still be
892+
# partitioned. All assets within a step are expected to either have the same partitions_def
893+
# or no partitions_def. Get the partitions_def from one of the assets that has one.
894+
return self.asset_partitions_def
895+
896+
@cached_property
897+
def asset_partitions_def(self) -> Optional[PartitionsDefinition]:
898+
"""If the current step is executing a partitioned asset, returns the PartitionsDefinition
899+
for that asset. If there are one or more partitioned assets executing in the step, they're
900+
expected to all have the same PartitionsDefinition.
901+
"""
902+
asset_layer = self.job_def.asset_layer
903+
assets_def = asset_layer.assets_def_for_node(self.node_handle) if asset_layer else None
904+
if assets_def is not None:
905+
for asset_key in assets_def.keys:
906+
partitions_def = self.job_def.asset_layer.get(asset_key).partitions_def
907+
if partitions_def is not None:
908+
return partitions_def
909+
910+
return None
911+
912+
@property
913+
def partition_key(self) -> str:
914+
from dagster._core.definitions.multi_dimensional_partitions import (
915+
MultiPartitionsDefinition,
916+
get_multipartition_key_from_tags,
917+
)
918+
919+
if not self.has_partitions:
920+
raise DagsterInvariantViolationError(
921+
"Cannot access partition_key for a non-partitioned run"
922+
)
923+
924+
tags = self._plan_data.dagster_run.tags
925+
if any([tag.startswith(MULTIDIMENSIONAL_PARTITION_PREFIX) for tag in tags.keys()]):
926+
return get_multipartition_key_from_tags(tags)
927+
elif PARTITION_NAME_TAG in tags:
928+
return tags[PARTITION_NAME_TAG]
929+
else:
930+
range_start = tags[ASSET_PARTITION_RANGE_START_TAG]
931+
range_end = tags[ASSET_PARTITION_RANGE_END_TAG]
932+
933+
if range_start != range_end:
934+
raise DagsterInvariantViolationError(
935+
"Cannot access partition_key for a partitioned run with a range of partitions."
936+
" Call partition_key_range instead."
937+
)
938+
else:
939+
if isinstance(self.run_partitions_def, MultiPartitionsDefinition):
940+
return self.run_partitions_def.get_partition_key_from_str(
941+
cast(str, range_start)
942+
)
943+
return cast(str, range_start)
944+
945+
@property
946+
def partition_key_range(self) -> PartitionKeyRange:
947+
from dagster._core.definitions.multi_dimensional_partitions import (
948+
MultiPartitionsDefinition,
949+
get_multipartition_key_from_tags,
950+
)
951+
952+
if not self.has_partitions:
953+
raise DagsterInvariantViolationError(
954+
"Cannot access partition_key for a non-partitioned run"
955+
)
956+
957+
tags = self._plan_data.dagster_run.tags
958+
if any([tag.startswith(MULTIDIMENSIONAL_PARTITION_PREFIX) for tag in tags.keys()]):
959+
multipartition_key = get_multipartition_key_from_tags(tags)
960+
return PartitionKeyRange(multipartition_key, multipartition_key)
961+
elif PARTITION_NAME_TAG in tags:
962+
partition_key = tags[PARTITION_NAME_TAG]
963+
return PartitionKeyRange(partition_key, partition_key)
964+
else:
965+
partition_key_range_start = tags[ASSET_PARTITION_RANGE_START_TAG]
966+
if partition_key_range_start is not None:
967+
if isinstance(self.run_partitions_def, MultiPartitionsDefinition):
968+
return PartitionKeyRange(
969+
self.run_partitions_def.get_partition_key_from_str(
970+
partition_key_range_start
971+
),
972+
self.run_partitions_def.get_partition_key_from_str(
973+
tags[ASSET_PARTITION_RANGE_END_TAG]
974+
),
975+
)
976+
return PartitionKeyRange(partition_key_range_start, tags[ASSET_PARTITION_RANGE_END_TAG])
977+
957978
def has_asset_partitions_for_input(self, input_name: str) -> bool:
958979
asset_layer = self.job_def.asset_layer
959980
upstream_asset_key = asset_layer.asset_key_for_input(self.node_handle, input_name)
@@ -999,11 +1020,11 @@ def asset_partitions_subset_for_input(
9991020
upstream_asset_partitions_def = asset_layer.get(upstream_asset_key).partitions_def
10001021

10011022
if upstream_asset_partitions_def is not None:
1002-
partitions_def = assets_def.partitions_def if assets_def else None
1023+
partitions_def = self.asset_partitions_def if assets_def else None
10031024
partitions_subset = (
10041025
partitions_def.empty_subset().with_partition_key_range(
10051026
partitions_def,
1006-
self.asset_partition_key_range,
1027+
self.partition_key_range,
10071028
dynamic_partitions_store=self.instance,
10081029
)
10091030
if partitions_def
@@ -1116,12 +1137,7 @@ def asset_partitions_time_window_for_output(self, output_name: str) -> TimeWindo
11161137

11171138
@property
11181139
def partition_time_window(self) -> TimeWindow:
1119-
asset_layer = self.job_def.asset_layer
1120-
partitions_def = self.job_def.partitions_def
1121-
if asset_layer:
1122-
assets_def = asset_layer.assets_def_for_node(self.node_handle)
1123-
if assets_def:
1124-
partitions_def = assets_def.partitions_def
1140+
partitions_def = self.run_partitions_def
11251141

11261142
if partitions_def is None:
11271143
raise DagsterInvariantViolationError("Partitions definition is not defined")

0 commit comments

Comments
 (0)