Skip to content

Commit

Permalink
PyGrain performance and debugging tool
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 725628257
  • Loading branch information
Grain Team authored and copybara-github committed Feb 11, 2025
1 parent 6bd8ecf commit c19d9dd
Show file tree
Hide file tree
Showing 12 changed files with 531 additions and 18 deletions.
1 change: 1 addition & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ bazel_dep(name = "rules_python", version = "0.34.0")
bazel_dep(name = "pybind11_bazel", version = "2.13.6")
bazel_dep(name = "abseil-py", version = "2.1.0")
bazel_dep(name = "abseil-cpp", version = "20240722.0")
bazel_dep(name = "protobuf", version = "29.0", repo_name = "com_google_protobuf")

python = use_extension("@rules_python//python/extensions:python.bzl", "python")

Expand Down
2 changes: 2 additions & 0 deletions grain/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ py_library(
"python/experimental.py",
"python/fast_proto.py",
],
data = ["//grain/_src/python/experimental/index_shuffle/python:index_shuffle_module.so"],
srcs_version = "PY3",
# Implicit build flag
visibility = ["//visibility:public"],
Expand Down Expand Up @@ -44,5 +45,6 @@ py_library(
"//grain/_src/python/dataset/transformations:zip",
"//grain/_src/python/experimental/example_packing:packing",
"//grain/_src/python/testing:experimental",
"//grain/proto:execution_summary_py_pb2",
],
)
1 change: 1 addition & 0 deletions grain/_src/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ py_library(
":options",
":record",
":shared_memory_array",
"//grain/_src/core:config",
"//grain/_src/core:parallel",
"//grain/_src/core:tree_lib",
"@abseil-py//absl/logging",
Expand Down
4 changes: 4 additions & 0 deletions grain/_src/python/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ py_library(
"//grain/_src/python:grain_pool",
"//grain/_src/python:options",
"//grain/_src/python:shared_memory_array",
"//grain/proto:execution_summary_py_pb2",
"@abseil-py//absl/logging",
"@pypi//cloudpickle:pkg",
"@pypi//numpy:pkg",
Expand All @@ -62,6 +63,7 @@ py_test(
"//grain/_src/core:transforms",
"//grain/_src/python:options",
"//grain/_src/python/testing:experimental",
"//grain/proto:execution_summary_py_pb2",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:flagsaver",
"@abseil-py//absl/testing:parameterized",
Expand Down Expand Up @@ -105,6 +107,7 @@ py_library(
"//grain/_src/core:config",
"//grain/_src/core:monitoring",
"//grain/_src/core:tree_lib",
"//grain/proto:execution_summary_py_pb2",
"@abseil-py//absl/logging",
],
)
Expand All @@ -117,6 +120,7 @@ py_test(
":dataset",
":stats",
"//grain/_src/core:transforms",
"//grain/proto:execution_summary_py_pb2",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:flagsaver",
"@pypi//cloudpickle:pkg",
Expand Down
16 changes: 16 additions & 0 deletions grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from grain._src.python import options as grain_options
from grain._src.python.dataset import base
from grain._src.python.dataset import stats as dataset_stats
from grain.proto import execution_summary_pb2
import numpy as np

from grain._src.core import monitoring
Expand Down Expand Up @@ -1287,3 +1288,18 @@ def apply_transformations(
f"Transformation type: {transformation} is not supported."
)
return ds


def get_execution_summary(
ds: DatasetIterator,
) -> execution_summary_pb2.ExecutionSummary:
"""Returns the execution summary for the dataset."""
# pylint: disable=protected-access
execution_stats = ds._stats
if not isinstance(execution_stats, dataset_stats._ExecutionStats):
raise ValueError(
"Set `grain_py_debug_mode` or set `execution_tracking_mode` in grain"
" options to `STAGE_TIMING` to enable execution statistics collection."
)
return execution_stats._get_execution_summary()
# pylint: enable=protected-access
50 changes: 50 additions & 0 deletions grain/_src/python/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats as dataset_stats
import grain._src.python.testing.experimental as test_util
from grain.proto import execution_summary_pb2
import numpy as np
from typing_extensions import override

Expand Down Expand Up @@ -907,5 +908,54 @@ def test_conflicting_options(self):
)


class GetExecutionSummaryTest(parameterized.TestCase):

def test_get_execution_summary_without_collection(self):
ds = dataset.MapDataset.range(10).shuffle(42)
ds = ds.to_iter_dataset()
it = ds.__iter__()
with self.assertRaisesRegex(
ValueError,
"Set `grain_py_debug_mode` or set `execution_tracking_mode` in grain"
" options to `STAGE_TIMING` to enable execution statistics collection.",
):
dataset.get_execution_summary(it)

@mock.patch.object(dataset_stats, "_REPORTING_PERIOD_SEC", 0.05)
@mock.patch.object(dataset_stats, "_LOG_EXECUTION_SUMMARY_PERIOD_SEC", 0.06)
@flagsaver.flagsaver(grain_py_debug_mode=True)
def test_execution_summary_with_logging(self):
with self.assertLogs(level="INFO") as logs:
ds = dataset.MapDataset.range(10).shuffle(42)
ds = ds.map(MapTransformAddingOne())
ds = ds.to_iter_dataset()
it = ds.__iter__()
# Get execution summary after iterating through the dataset.
_ = list(it)
# reporting stats after 0.05 seconds.
time.sleep(0.1)
log_value = "Grain Dataset Execution Summary"
self.assertRegex("".join(logs.output), log_value)

@mock.patch.object(dataset_stats, "_REPORTING_PERIOD_SEC", 0.05)
@mock.patch.object(dataset_stats, "_LOG_EXECUTION_SUMMARY_PERIOD_SEC", 0.06)
def test_execution_summary_with_no_logging(self):
with self.assertNoLogs(level="INFO"):
ds = dataset.MapDataset.range(10).shuffle(42)
ds = ds.map(MapTransformAddingOne())
ds = ds.to_iter_dataset()
ds = dataset.WithOptionsIterDataset(
ds,
base.DatasetOptions(
execution_tracking_mode=base.ExecutionTrackingMode.STAGE_TIMING
),
)
it = ds.__iter__()
# Get execution summary after iterating through the dataset.
_ = list(it)
# reporting stats after 0.05 seconds.
time.sleep(0.1)


if __name__ == "__main__":
absltest.main()
17 changes: 17 additions & 0 deletions grain/_src/python/dataset/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from grain._src.core import monitoring as grain_monitoring
from grain._src.core import tree_lib
from grain._src.python.dataset import base
from grain.proto import execution_summary_pb2

from grain._src.core import monitoring

Expand Down Expand Up @@ -712,4 +713,20 @@ def make_stats(
),
) -> Stats:
"""Produces statistics instance according to the current execution mode."""
vis_output_dir = grain_config.config.py_dataset_visualization_output_dir
# Only None and "" are supported.
if vis_output_dir:
raise NotImplementedError(
"Saving the dataset graph to a file is not supported yet. Set"
" `grain_py_dataset_visualization_output_dir` to empty string to"
" produce visualization in the logs."
)
if grain_config.config.py_debug_mode:
# In debug mode, we always log the execution summary.
config = dataclasses.replace(config, log_summary=True)
return _ExecutionStats(config, parents=parents)
if execution_tracking_mode == base.ExecutionTrackingMode.STAGE_TIMING:
return _ExecutionStats(config, parents=parents)
if vis_output_dir is not None:
return _VisualizationStats(config, parents=parents)
return _NoopStats(config, parents=parents)
Loading

0 comments on commit c19d9dd

Please sign in to comment.