AssetSpec.partitions_def
sryza authored and OwenKephart committed Dec 19, 2024
1 parent 37f6bef commit f490efa
Showing 14 changed files with 350 additions and 196 deletions.
@@ -90,11 +90,11 @@ def owners(self) -> Sequence[str]:

@property
def is_partitioned(self) -> bool:
return self.assets_def.partitions_def is not None
return self.partitions_def is not None

@property
def partitions_def(self) -> Optional[PartitionsDefinition]:
return self.assets_def.partitions_def
return self.assets_def.specs_by_key[self.key].partitions_def

@property
def partition_mappings(self) -> Mapping[AssetKey, PartitionMapping]:
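
Note: with this change the per-asset node reads its PartitionsDefinition from its own AssetSpec rather than from the enclosing AssetsDefinition. A minimal sketch of the spec-level lookup this property now relies on; the asset names and the DailyPartitionsDefinition are hypothetical, and it assumes the spec-carried partitions_def that this commit wires through:

from dagster import AssetKey, AssetSpec, DailyPartitionsDefinition, multi_asset

daily = DailyPartitionsDefinition(start_date="2024-01-01")

@multi_asset(
    specs=[
        AssetSpec("events", partitions_def=daily),
        AssetSpec("events_summary", partitions_def=daily),
    ]
)
def event_assets():
    ...

# Partitioning is now resolved per spec, mirroring
# assets_def.specs_by_key[key].partitions_def in the hunk above.
print(event_assets.specs_by_key[AssetKey("events")].partitions_def)
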
@@ -22,6 +22,7 @@
from dagster._core.definitions.freshness_policy import FreshnessPolicy
from dagster._core.definitions.input import NoValueSentinel
from dagster._core.definitions.output import Out
from dagster._core.definitions.partition import PartitionsDefinition
from dagster._core.definitions.utils import resolve_automation_condition
from dagster._core.errors import DagsterInvalidDefinitionError
from dagster._core.types.dagster_type import DagsterType
@@ -217,12 +218,17 @@ def to_out(self) -> Out:
)

def to_spec(
self, key: AssetKey, deps: Sequence[AssetDep], additional_tags: Mapping[str, str] = {}
self,
key: AssetKey,
deps: Sequence[AssetDep],
additional_tags: Mapping[str, str] = {},
partitions_def: Optional[PartitionsDefinition] = ...,
) -> AssetSpec:
return self._spec.replace_attributes(
key=key,
tags={**additional_tags, **self.tags} if self.tags else additional_tags,
deps=[*self._spec.deps, *deps],
partitions_def=partitions_def,
)
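
to_spec now forwards a partitions_def to replace_attributes, defaulting to Ellipsis rather than None. As I read it, this is the usual sentinel pattern: ... means "argument not supplied, keep the spec's current value," so an explicit None can still mean "un-partitioned." A standalone sketch of that pattern with hypothetical names:

def _replace(current, new=...):
    # Ellipsis marks "not supplied"; None stays a meaningful value
    # ("explicitly un-partitioned") instead of being conflated with "unchanged".
    return current if new is ... else new

print(_replace("daily"))             # daily  (kept)
print(_replace("daily", None))       # None   (explicitly cleared)
print(_replace("daily", "hourly"))   # hourly (replaced)
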

@public
169 changes: 62 additions & 107 deletions python_modules/dagster/dagster/_core/definitions/assets.py
@@ -36,7 +36,7 @@
AssetSpec,
)
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.backfill_policy import BackfillPolicy, BackfillPolicyType
from dagster._core.definitions.backfill_policy import BackfillPolicy
from dagster._core.definitions.declarative_automation.automation_condition import (
AutomationCondition,
)
@@ -107,9 +107,9 @@ class AssetsDefinition(ResourceAddable, IHasInternalInit):
"descriptions_by_key",
"asset_deps",
"owners_by_key",
"partitions_def",
}

_partitions_def: Optional[PartitionsDefinition]
# partition mappings are also tracked inside the AssetSpecs, but this enables faster access by
# upstream asset key
_partition_mappings: Mapping[AssetKey, PartitionMapping]
@@ -229,24 +229,10 @@ def __init__(
execution_type=execution_type or AssetExecutionType.MATERIALIZATION,
)

self._partitions_def = _resolve_partitions_def(specs, partitions_def)

self._resource_defs = wrap_resources_for_execution(
check.opt_mapping_param(resource_defs, "resource_defs")
)

if self._partitions_def is None:
# check if backfill policy is BackfillPolicyType.SINGLE_RUN if asset is not partitioned
check.param_invariant(
(
backfill_policy.policy_type is BackfillPolicyType.SINGLE_RUN
if backfill_policy
else True
),
"backfill_policy",
"Non partitioned asset can only have single run backfill policy",
)

if specs is not None:
check.invariant(group_names_by_key is None)
check.invariant(metadata_by_key is None)
@@ -258,6 +244,7 @@ def __init__(
check.invariant(owners_by_key is None)
check.invariant(partition_mappings is None)
check.invariant(asset_deps is None)
check.invariant(partitions_def is None)
resolved_specs = specs

else:
@@ -297,6 +284,7 @@ def __init__(
metadata_by_key=metadata_by_key,
descriptions_by_key=descriptions_by_key,
code_versions_by_key=None,
partitions_def=partitions_def,
)

normalized_specs: List[AssetSpec] = []
@@ -333,11 +321,11 @@ def __init__(
check.invariant(
not (
spec.freshness_policy
and self._partitions_def is not None
and not isinstance(self._partitions_def, TimeWindowPartitionsDefinition)
and spec.partitions_def is not None
and not isinstance(spec.partitions_def, TimeWindowPartitionsDefinition)
),
"FreshnessPolicies are currently unsupported for assets with partitions of type"
f" {type(self._partitions_def)}.",
f" {spec.partitions_def}.",
)

normalized_specs.append(
@@ -347,10 +335,19 @@ def __init__(
metadata=metadata,
description=description,
skippable=skippable,
partitions_def=self._partitions_def,
)
)

unique_partitions_defs = {
spec.partitions_def for spec in normalized_specs if spec.partitions_def is not None
}
if len(unique_partitions_defs) > 1 and not can_subset:
raise DagsterInvalidDefinitionError(
"If different AssetSpecs have different partitions_defs, can_subset must be True"
)
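
This is the new invariant the commit introduces: a single AssetsDefinition may carry specs with differing PartitionsDefinitions, but only if it is subsettable, so any one run can be narrowed to assets that share a partitioning. A hedged sketch of a definition this check accepts; asset names are made up, and it assumes the @multi_asset(specs=...) path forwards per-spec partitions_def as this commit intends:

from dagster import AssetSpec, DailyPartitionsDefinition, StaticPartitionsDefinition, multi_asset

daily = DailyPartitionsDefinition(start_date="2024-01-01")
regions = StaticPartitionsDefinition(["us", "eu"])

@multi_asset(
    specs=[
        AssetSpec("daily_events", partitions_def=daily),
        AssetSpec("region_totals", partitions_def=regions),
    ],
    can_subset=True,  # required here: the specs disagree on partitions_def
)
def mixed_assets(context):
    ...

The same specs without can_subset=True would raise DagsterInvalidDefinitionError at definition time, per the check above.
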

_validate_self_deps(normalized_specs)

self._specs_by_key = {spec.key: spec for spec in normalized_specs}

self._partition_mappings = get_partition_mappings_from_deps(
@@ -363,27 +360,11 @@ def __init__(
spec.key: spec for spec in self._check_specs_by_output_name.values()
}

if self._computation:
_validate_self_deps(
input_keys=[
key
# filter out the special inputs which are used for cases when a multi-asset is
# subsetted, as these are not the same as self-dependencies and are never loaded
# in the same step that their corresponding output is produced
for input_name, key in self._computation.keys_by_input_name.items()
if not input_name.startswith(ASSET_SUBSET_INPUT_PREFIX)
],
output_keys=self._computation.selected_asset_keys,
partition_mappings=self._partition_mappings,
partitions_def=self._partitions_def,
)

def dagster_internal_init(
*,
keys_by_input_name: Mapping[str, AssetKey],
keys_by_output_name: Mapping[str, AssetKey],
node_def: NodeDefinition,
partitions_def: Optional[PartitionsDefinition],
selected_asset_keys: Optional[AbstractSet[AssetKey]],
can_subset: bool,
resource_defs: Optional[Mapping[str, object]],
@@ -400,7 +381,6 @@ def dagster_internal_init(
keys_by_input_name=keys_by_input_name,
keys_by_output_name=keys_by_output_name,
node_def=node_def,
partitions_def=partitions_def,
selected_asset_keys=selected_asset_keys,
can_subset=can_subset,
resource_defs=resource_defs,
@@ -771,17 +751,13 @@ def _output_dict_to_asset_dict(
metadata_by_key=_output_dict_to_asset_dict(metadata_by_output_name),
descriptions_by_key=_output_dict_to_asset_dict(descriptions_by_output_name),
code_versions_by_key=_output_dict_to_asset_dict(code_versions_by_output_name),
partitions_def=partitions_def,
)

return AssetsDefinition.dagster_internal_init(
keys_by_input_name=keys_by_input_name,
keys_by_output_name=keys_by_output_name_with_prefix,
node_def=node_def,
partitions_def=check.opt_inst_param(
partitions_def,
"partitions_def",
PartitionsDefinition,
),
resource_defs=resource_defs,
backfill_policy=check.opt_inst_param(
backfill_policy, "backfill_policy", BackfillPolicy
@@ -1044,10 +1020,20 @@ def backfill_policy(self) -> Optional[BackfillPolicy]:
return self._computation.backfill_policy if self._computation else None

@public
@property
@cached_property
def partitions_def(self) -> Optional[PartitionsDefinition]:
"""Optional[PartitionsDefinition]: The PartitionsDefinition for this AssetsDefinition (if any)."""
return self._partitions_def
partitions_defs = {
spec.partitions_def for spec in self.specs if spec.partitions_def is not None
}
if len(partitions_defs) == 1:
return next(iter(partitions_defs))
elif len(partitions_defs) == 0:
return None
else:
check.failed(
"Different assets within this AssetsDefinition have different PartitionsDefinitions"
)
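
partitions_def is now derived from the specs instead of stored on the definition: one shared value resolves normally, no partitioning resolves to None, and mixed partitionings fail a check at access time rather than silently picking one. A small usage sketch under the same assumptions as the earlier examples (hypothetical names):

from dagster import AssetSpec, DailyPartitionsDefinition, multi_asset

daily = DailyPartitionsDefinition(start_date="2024-01-01")

@multi_asset(
    specs=[
        AssetSpec("orders", partitions_def=daily),
        AssetSpec("orders_cleaned", partitions_def=daily),
    ]
)
def order_assets():
    ...

# Every spec agrees, so the definition-level property resolves to that value.
print(order_assets.partitions_def)

# For a definition whose specs disagree (see the can_subset example earlier),
# reading .partitions_def raises a check failure; inspect spec.partitions_def instead.
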

@property
def metadata_by_key(self) -> Mapping[AssetKey, ArbitraryMetadataMapping]:
@@ -1138,12 +1124,17 @@ def get_partition_mapping_for_dep(self, dep_key: AssetKey) -> Optional[PartitionMapping]:
return self._partition_mappings.get(dep_key)

def infer_partition_mapping(
self, upstream_asset_key: AssetKey, upstream_partitions_def: Optional[PartitionsDefinition]
self,
asset_key: AssetKey,
upstream_asset_key: AssetKey,
upstream_partitions_def: Optional[PartitionsDefinition],
) -> PartitionMapping:
with disable_dagster_warnings():
partition_mapping = self._partition_mappings.get(upstream_asset_key)
return infer_partition_mapping(
partition_mapping, self._partitions_def, upstream_partitions_def
partition_mapping,
self.specs_by_key[asset_key].partitions_def,
upstream_partitions_def,
)

def has_output_for_asset_key(self, key: AssetKey) -> bool:
@@ -1398,7 +1389,7 @@ def _output_to_source_asset(self, output_name: str) -> SourceAsset:
io_manager_key=output_def.io_manager_key,
description=spec.description,
resource_defs=self.resource_defs,
partitions_def=self.partitions_def,
partitions_def=spec.partitions_def,
group_name=spec.group_name,
tags=spec.tags,
io_manager_def=None,
@@ -1504,7 +1495,6 @@ def get_attributes_dict(self) -> Dict[str, Any]:
keys_by_input_name=self.node_keys_by_input_name,
keys_by_output_name=self.node_keys_by_output_name,
node_def=self._computation.node_def if self._computation else None,
partitions_def=self._partitions_def,
selected_asset_keys=self.keys,
can_subset=self.can_subset,
resource_defs=self._resource_defs,
@@ -1700,6 +1690,7 @@ def _asset_specs_from_attr_key_params(
code_versions_by_key: Optional[Mapping[AssetKey, str]],
descriptions_by_key: Optional[Mapping[AssetKey, str]],
owners_by_key: Optional[Mapping[AssetKey, Sequence[str]]],
partitions_def: Optional[PartitionsDefinition],
) -> Sequence[AssetSpec]:
validated_group_names_by_key = check.opt_mapping_param(
group_names_by_key, "group_names_by_key", key_type=AssetKey, value_type=str
@@ -1772,41 +1763,37 @@ def _asset_specs_from_attr_key_params(
# NodeDefinition
skippable=False,
auto_materialize_policy=None,
partitions_def=None,
kinds=None,
partitions_def=check.opt_inst_param(
partitions_def, "partitions_def", PartitionsDefinition
),
)
)

return result


def _validate_self_deps(
input_keys: Iterable[AssetKey],
output_keys: Iterable[AssetKey],
partition_mappings: Mapping[AssetKey, PartitionMapping],
partitions_def: Optional[PartitionsDefinition],
) -> None:
output_keys_set = set(output_keys)
for input_key in input_keys:
if input_key in output_keys_set:
if input_key in partition_mappings:
partition_mapping = partition_mappings[input_key]
time_window_partition_mapping = get_self_dep_time_window_partition_mapping(
partition_mapping, partitions_def
)
if (
time_window_partition_mapping is not None
and (time_window_partition_mapping.start_offset or 0) < 0
and (time_window_partition_mapping.end_offset or 0) < 0
):
continue
def _validate_self_deps(specs: Iterable[AssetSpec]) -> None:
for spec in specs:
for dep in spec.deps:
if dep.asset_key == spec.key:
if dep.partition_mapping:
time_window_partition_mapping = get_self_dep_time_window_partition_mapping(
dep.partition_mapping, spec.partitions_def
)
if (
time_window_partition_mapping is not None
and (time_window_partition_mapping.start_offset or 0) < 0
and (time_window_partition_mapping.end_offset or 0) < 0
):
continue

raise DagsterInvalidDefinitionError(
f'Asset "{input_key.to_user_string()}" depends on itself. Assets can only depend'
" on themselves if they are:\n(a) time-partitioned and each partition depends on"
" earlier partitions\n(b) multipartitioned, with one time dimension that depends"
" on earlier time partitions"
)
raise DagsterInvalidDefinitionError(
f'Asset "{spec.key.to_user_string()}" depends on itself. Assets can only depend'
" on themselves if they are:\n(a) time-partitioned and each partition depends on"
" earlier partitions\n(b) multipartitioned, with one time dimension that depends"
" on earlier time partitions"
)
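
For context, the self-dependency shape this validator accepts: a time-partitioned asset whose partitions depend only on strictly earlier partitions of itself, expressed through a TimeWindowPartitionMapping with negative offsets. A minimal sketch; the asset name and start date are hypothetical, and it assumes the deps/AssetDep spelling of a self-dependency, matching the spec.deps loop above:

from dagster import AssetDep, DailyPartitionsDefinition, TimeWindowPartitionMapping, asset

@asset(
    partitions_def=DailyPartitionsDefinition(start_date="2024-01-01"),
    deps=[
        AssetDep(
            "running_total",
            # Both offsets negative: each partition reads only the previous
            # day's partition of this same asset, which the check permits.
            partition_mapping=TimeWindowPartitionMapping(start_offset=-1, end_offset=-1),
        )
    ],
)
def running_total(context):
    ...
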


def get_self_dep_time_window_partition_mapping(
Expand Down Expand Up @@ -1834,38 +1821,6 @@ def get_self_dep_time_window_partition_mapping(
return None


def _resolve_partitions_def(
specs: Optional[Sequence[AssetSpec]], partitions_def: Optional[PartitionsDefinition]
) -> Optional[PartitionsDefinition]:
if specs:
asset_keys_by_partitions_def = defaultdict(set)
for spec in specs:
asset_keys_by_partitions_def[spec.partitions_def].add(spec.key)
if len(asset_keys_by_partitions_def) > 1:
partition_1_asset_keys, partition_2_asset_keys, *_ = (
asset_keys_by_partitions_def.values()
)
check.failed(
f"All AssetSpecs must have the same partitions_def, but "
f"{next(iter(partition_1_asset_keys)).to_user_string()} and "
f"{next(iter(partition_2_asset_keys)).to_user_string()} have different "
"partitions_defs."
)
common_partitions_def = next(iter(asset_keys_by_partitions_def.keys()))
if (
common_partitions_def is not None
and partitions_def is not None
and common_partitions_def != partitions_def
):
check.failed(
f"AssetSpec for {next(iter(specs)).key.to_user_string()} has partitions_def which is different "
"than the partitions_def provided to AssetsDefinition.",
)
return partitions_def or common_partitions_def
else:
return partitions_def


def get_partition_mappings_from_deps(
partition_mappings: Dict[AssetKey, PartitionMapping], deps: Iterable[AssetDep], asset_name: str
) -> Mapping[AssetKey, PartitionMapping]: