
🐛 log the total number of rows in deltalake append mode
danielgafni committed Jun 30, 2023
1 parent a8050cc commit e10917b
Showing 2 changed files with 56 additions and 21 deletions.
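The change makes append-mode writes report the total number of rows in the Delta table after the append, rather than only the rows of the DataFrame being written. Below is a minimal standalone sketch of that counting approach, using the same deltalake call the commit adds; the table path and sample data are made up for illustration.

import polars as pl
from deltalake import DeltaTable

# Hypothetical local path, used only for this sketch.
table_path = "/tmp/append_table"

df = pl.DataFrame({"a": [0, 1, 2]})
df.write_delta(table_path, mode="append")  # first write creates the table (3 rows)
df.write_delta(table_path, mode="append")  # second write appends 3 more rows

# Count the rows of the whole table, not just the DataFrame that was written last.
total_rows = DeltaTable(table_path).to_pyarrow_dataset().count_rows()
print(total_rows)  # 6
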
70 changes: 51 additions & 19 deletions dagster_polars/io_managers/delta.py
@@ -1,8 +1,8 @@
 from pprint import pformat
-from typing import Union
+from typing import Dict, Union

 import polars as pl
-from dagster import InputContext, OutputContext
+from dagster import InputContext, MetadataValue, OutputContext
 from deltalake import DeltaTable
 from upath import UPath

@@ -19,23 +19,6 @@ class PolarsDeltaIOManager(BasePolarsUPathIOManager):
     All read/write arguments can be passed via corresponding metadata values."""
     )

-    def get_path_for_partition(self, context: Union[InputContext, OutputContext], path: UPath, partition: str) -> UPath:
-        if isinstance(context, InputContext):
-            if (
-                context.upstream_output is not None
-                and context.upstream_output.metadata is not None
-                and context.upstream_output.metadata.get("partition_by") is not None
-            ):
-                # upstream asset has "partition_by" metadata set, so partitioning for it is handled by DeltaLake itself
-                return path
-
-        if isinstance(context, OutputContext):
-            if context.metadata is not None and context.metadata.get("partition_by") is not None:
-                # this asset has "partition_by" metadata set, so partitioning for it is handled by DeltaLake itself
-                return path
-
-        return path / partition  # partitioning is handled by the IOManager
-
     def dump_df_to_path(self, context: OutputContext, df: pl.DataFrame, path: UPath):
         assert context.metadata is not None

@@ -73,3 +56,52 @@ def scan_df_from_path(self, path: UPath, context: InputContext) -> pl.LazyFrame:
             pyarrow_options=context.metadata.get("pyarrow_options"),
             storage_options=self.get_storage_options(path),
         )
+
+    def get_path_for_partition(self, context: Union[InputContext, OutputContext], path: UPath, partition: str) -> UPath:
+        if isinstance(context, InputContext):
+            if (
+                context.upstream_output is not None
+                and context.upstream_output.metadata is not None
+                and context.upstream_output.metadata.get("partition_by") is not None
+            ):
+                # upstream asset has "partition_by" metadata set, so partitioning for it is handled by DeltaLake itself
+                return path
+
+        if isinstance(context, OutputContext):
+            if context.metadata is not None and context.metadata.get("partition_by") is not None:
+                # this asset has "partition_by" metadata set, so partitioning for it is handled by DeltaLake itself
+                return path
+
+        return path / partition  # partitioning is handled by the IOManager
+
+    def get_metadata(self, context: OutputContext, obj: pl.DataFrame) -> Dict[str, MetadataValue]:
+        assert context.metadata is not None
+
+        metadata = super().get_metadata(context, obj)
+
+        if context.has_asset_partitions:
+            partition_by = context.metadata.get("partition_by")
+            if partition_by is not None:
+                metadata["partition_by"] = partition_by
+
+        if context.metadata.get("mode") == "append":
+            # FIXME: what to do if we are appending to a partitioned table?
+            # we should not be using the full table length,
+            # but it's unclear how to get the length of the partition we are appending to
+
+            if context.has_asset_partitions:
+                paths = self._get_paths_for_partitions(context)
+                assert len(paths) == 1
+                path = list(paths.values())[0]
+            else:
+                path = self._get_path(context)
+
+            if not context.has_asset_partitions:
+                # we need to get num_rows from the full table
+                metadata["num_rows"] = MetadataValue.int(
+                    DeltaTable(str(path), storage_options=self.get_storage_options(path))
+                    .to_pyarrow_dataset()
+                    .count_rows()
+                )
+
+        return metadata
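For reference, the append mode that get_metadata checks for is set through asset metadata. A hedged sketch of such an asset follows; the io_manager_key and the asset body are illustrative assumptions, not code from this commit (the test below materializes a similar append_asset twice and expects num_rows of 3 and then 6).

import polars as pl
from dagster import asset

# Only metadata={"mode": "append"} matters to get_metadata above; the IO manager
# key and the returned data are placeholders for illustration.
@asset(metadata={"mode": "append"}, io_manager_key="polars_delta_io_manager")
def append_asset() -> pl.DataFrame:
    # Each materialization appends these rows; the emitted num_rows metadata
    # reports the total row count of the Delta table after the append.
    return pl.DataFrame({"a": [0, 1, 2]})
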
7 changes: 5 additions & 2 deletions tests/test_polars_delta.py
@@ -71,12 +71,15 @@ def append_asset() -> pl.DataFrame:
     )

     handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("append_asset")))
-    saved_path = handled_output_events[0].event_specific_data.metadata["path"].value  # type: ignore[index,union-attr]
+    saved_path = handled_output_events[0].event_specific_data.metadata["path"].value  # type: ignore
+    assert handled_output_events[0].event_specific_data.metadata["num_rows"].value == 3  # type: ignore
     assert isinstance(saved_path, str)

-    materialize(
+    result = materialize(
         [append_asset],
     )
+    handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("append_asset")))
+    assert handled_output_events[0].event_specific_data.metadata["num_rows"].value == 6  # type: ignore

     pl_testing.assert_frame_equal(pl.concat([df, df]), pl.read_delta(saved_path))
