diff --git a/MODULE.bazel b/MODULE.bazel
index fc290873..77722fbd 100644
--- a/MODULE.bazel
+++ b/MODULE.bazel
@@ -23,6 +23,7 @@ bazel_dep(name = "rules_python", version = "0.34.0")
 bazel_dep(name = "pybind11_bazel", version = "2.13.6")
 bazel_dep(name = "abseil-py", version = "2.1.0")
 bazel_dep(name = "abseil-cpp", version = "20240722.0")
+bazel_dep(name = "protobuf", version = "29.0", repo_name = "com_google_protobuf")
 
 python = use_extension("@rules_python//python/extensions:python.bzl", "python")
diff --git a/grain/BUILD b/grain/BUILD
index 504c25dc..1d533744 100644
--- a/grain/BUILD
+++ b/grain/BUILD
@@ -16,6 +16,7 @@ py_library(
         "python/experimental.py",
         "python/fast_proto.py",
     ],
+    data = ["//grain/_src/python/experimental/index_shuffle/python:index_shuffle_module.so"],
     srcs_version = "PY3",  # Implicit build flag
     visibility = ["//visibility:public"],
@@ -44,5 +45,6 @@ py_library(
         "//grain/_src/python/dataset/transformations:zip",
         "//grain/_src/python/experimental/example_packing:packing",
         "//grain/_src/python/testing:experimental",
+        "//grain/proto:execution_summary_py_pb2",
     ],
 )
diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD
index 94cffe02..7d847e73 100644
--- a/grain/_src/python/BUILD
+++ b/grain/_src/python/BUILD
@@ -203,6 +203,7 @@ py_library(
         ":options",
         ":record",
         ":shared_memory_array",
+        "//grain/_src/core:config",
         "//grain/_src/core:parallel",
         "//grain/_src/core:tree_lib",
         "@abseil-py//absl/logging",
diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD
index 8ee63fa3..a960d686 100644
--- a/grain/_src/python/dataset/BUILD
+++ b/grain/_src/python/dataset/BUILD
@@ -43,6 +43,7 @@ py_library(
         "//grain/_src/python:grain_pool",
         "//grain/_src/python:options",
         "//grain/_src/python:shared_memory_array",
+        "//grain/proto:execution_summary_py_pb2",
         "@abseil-py//absl/logging",
         "@pypi//cloudpickle:pkg",
         "@pypi//numpy:pkg",
@@ -62,6 +63,7 @@ py_test(
         "//grain/_src/core:transforms",
         "//grain/_src/python:options",
         "//grain/_src/python/testing:experimental",
+        "//grain/proto:execution_summary_py_pb2",
         "@abseil-py//absl/testing:absltest",
         "@abseil-py//absl/testing:flagsaver",
         "@abseil-py//absl/testing:parameterized",
@@ -105,6 +107,7 @@ py_library(
         "//grain/_src/core:config",
         "//grain/_src/core:monitoring",
         "//grain/_src/core:tree_lib",
+        "//grain/proto:execution_summary_py_pb2",
         "@abseil-py//absl/logging",
     ],
 )
@@ -117,6 +120,7 @@ py_test(
         ":dataset",
         ":stats",
         "//grain/_src/core:transforms",
+        "//grain/proto:execution_summary_py_pb2",
         "@abseil-py//absl/testing:absltest",
         "@abseil-py//absl/testing:flagsaver",
         "@pypi//cloudpickle:pkg",
diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py
index 6a24a83a..b6bef8d7 100644
--- a/grain/_src/python/dataset/dataset.py
+++ b/grain/_src/python/dataset/dataset.py
@@ -58,6 +58,7 @@
 from grain._src.python import options as grain_options
 from grain._src.python.dataset import base
 from grain._src.python.dataset import stats as dataset_stats
+from grain.proto import execution_summary_pb2
 import numpy as np
 
 from grain._src.core import monitoring
@@ -1287,3 +1288,18 @@ def apply_transformations(
           f"Transformation type: {transformation} is not supported."
       )
   return ds
+
+
+def get_execution_summary(
+    ds: DatasetIterator,
+) -> execution_summary_pb2.ExecutionSummary:
+  """Returns the execution summary for the dataset."""
+  # pylint: disable=protected-access
+  execution_stats = ds._stats
+  if not isinstance(execution_stats, dataset_stats._ExecutionStats):
+    raise ValueError(
+        "Set `grain_py_debug_mode` or set `execution_tracking_mode` in grain"
+        " options to `STAGE_TIMING` to enable execution statistics collection."
+    )
+  return execution_stats._get_execution_summary()
+  # pylint: enable=protected-access
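
Reviewer note: a minimal usage sketch of the new `get_execution_summary` entry
point, assembled only from the APIs exercised in the tests below
(`WithOptionsIterDataset`, `DatasetOptions`, `ExecutionTrackingMode.STAGE_TIMING`);
treat it as a sketch, not documented API:

    from grain._src.python.dataset import base
    from grain._src.python.dataset import dataset

    # Enable stage timing via dataset options (setting the `grain_py_debug_mode`
    # flag is the other way to turn collection on).
    ds = dataset.MapDataset.range(10).shuffle(42).to_iter_dataset()
    ds = dataset.WithOptionsIterDataset(
        ds,
        base.DatasetOptions(
            execution_tracking_mode=base.ExecutionTrackingMode.STAGE_TIMING
        ),
    )
    it = ds.__iter__()
    _ = list(it)  # Statistics are collected while iterating.
    summary = dataset.get_execution_summary(it)  # ExecutionSummary proto.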
diff --git a/grain/_src/python/dataset/dataset_test.py b/grain/_src/python/dataset/dataset_test.py
index 035b6090..d0ee4662 100644
--- a/grain/_src/python/dataset/dataset_test.py
+++ b/grain/_src/python/dataset/dataset_test.py
@@ -31,6 +31,7 @@
 from grain._src.python.dataset import dataset
 from grain._src.python.dataset import stats as dataset_stats
 import grain._src.python.testing.experimental as test_util
+from grain.proto import execution_summary_pb2
 import numpy as np
 from typing_extensions import override
 
@@ -907,5 +908,54 @@ def test_conflicting_options(self):
     )
 
 
+class GetExecutionSummaryTest(parameterized.TestCase):
+
+  def test_get_execution_summary_without_collection(self):
+    ds = dataset.MapDataset.range(10).shuffle(42)
+    ds = ds.to_iter_dataset()
+    it = ds.__iter__()
+    with self.assertRaisesRegex(
+        ValueError,
+        "Set `grain_py_debug_mode` or set `execution_tracking_mode` in grain"
+        " options to `STAGE_TIMING` to enable execution statistics collection.",
+    ):
+      dataset.get_execution_summary(it)
+
+  @mock.patch.object(dataset_stats, "_REPORTING_PERIOD_SEC", 0.05)
+  @mock.patch.object(dataset_stats, "_LOG_EXECUTION_SUMMARY_PERIOD_SEC", 0.06)
+  @flagsaver.flagsaver(grain_py_debug_mode=True)
+  def test_execution_summary_with_logging(self):
+    with self.assertLogs(level="INFO") as logs:
+      ds = dataset.MapDataset.range(10).shuffle(42)
+      ds = ds.map(MapTransformAddingOne())
+      ds = ds.to_iter_dataset()
+      it = ds.__iter__()
+      # The execution summary is collected while iterating through the dataset.
+      _ = list(it)
+      # Stats are reported after 0.05 seconds.
+      time.sleep(0.1)
+    log_value = "Grain Dataset Execution Summary"
+    self.assertRegex("".join(logs.output), log_value)
+
+  @mock.patch.object(dataset_stats, "_REPORTING_PERIOD_SEC", 0.05)
+  @mock.patch.object(dataset_stats, "_LOG_EXECUTION_SUMMARY_PERIOD_SEC", 0.06)
+  def test_execution_summary_with_no_logging(self):
+    with self.assertNoLogs(level="INFO"):
+      ds = dataset.MapDataset.range(10).shuffle(42)
+      ds = ds.map(MapTransformAddingOne())
+      ds = ds.to_iter_dataset()
+      ds = dataset.WithOptionsIterDataset(
+          ds,
+          base.DatasetOptions(
+              execution_tracking_mode=base.ExecutionTrackingMode.STAGE_TIMING
+          ),
+      )
+      it = ds.__iter__()
+      # The execution summary is collected while iterating through the dataset.
+      _ = list(it)
+      # Stats are reported after 0.05 seconds.
+      time.sleep(0.1)
+
+
 if __name__ == "__main__":
   absltest.main()
diff --git a/grain/_src/python/dataset/stats.py b/grain/_src/python/dataset/stats.py
index c4a03d95..5ff329aa 100644
--- a/grain/_src/python/dataset/stats.py
+++ b/grain/_src/python/dataset/stats.py
@@ -31,6 +31,7 @@
 from grain._src.core import monitoring as grain_monitoring
 from grain._src.core import tree_lib
 from grain._src.python.dataset import base
+from grain.proto import execution_summary_pb2
 
 from grain._src.core import monitoring
 
@@ -712,4 +713,20 @@ def make_stats(
     ),
 ) -> Stats:
   """Produces statistics instance according to the current execution mode."""
+  vis_output_dir = grain_config.config.py_dataset_visualization_output_dir
+  # Only None and "" are supported.
+  if vis_output_dir:
+    raise NotImplementedError(
+        "Saving the dataset graph to a file is not supported yet. Set"
+        " `grain_py_dataset_visualization_output_dir` to empty string to"
+        " produce visualization in the logs."
+    )
+  if grain_config.config.py_debug_mode:
+    # In debug mode, we always log the execution summary.
+    config = dataclasses.replace(config, log_summary=True)
+    return _ExecutionStats(config, parents=parents)
+  if execution_tracking_mode == base.ExecutionTrackingMode.STAGE_TIMING:
+    return _ExecutionStats(config, parents=parents)
+  if vis_output_dir is not None:
+    return _VisualizationStats(config, parents=parents)
   return _NoopStats(config, parents=parents)
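
Reviewer note: the `make_stats` dispatch above maps debug settings to `Stats`
implementations. A sketch of the mapping, using only calls that appear in the
tests (the name "demo" is illustrative):

    from absl.testing import flagsaver
    from grain._src.python.dataset import stats

    # grain_py_debug_mode=True -> _ExecutionStats (summary is also logged).
    with flagsaver.flagsaver(grain_py_debug_mode=True):
      s = stats.make_stats(stats.StatsConfig(name="demo"), ())
      assert isinstance(s, stats._ExecutionStats)

    # Empty visualization dir -> _VisualizationStats (graph goes to the logs).
    with flagsaver.flagsaver(grain_py_dataset_visualization_output_dir=""):
      s = stats.make_stats(stats.StatsConfig(name="demo"), ())
      assert isinstance(s, stats._VisualizationStats)

    # Default (dir is None, no debug mode, no STAGE_TIMING) -> _NoopStats;
    # a non-empty dir raises NotImplementedError.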
diff --git a/grain/_src/python/dataset/stats_test.py b/grain/_src/python/dataset/stats_test.py
index 803e0bd2..bd46d853 100644
--- a/grain/_src/python/dataset/stats_test.py
+++ b/grain/_src/python/dataset/stats_test.py
@@ -23,6 +23,7 @@
 from grain._src.core import transforms
 from grain._src.python.dataset import dataset
 from grain._src.python.dataset import stats
+from grain.proto import execution_summary_pb2
 
 from absl.testing import absltest
 
@@ -100,7 +101,7 @@
       "[]"
   ││
-  ││  MapDatasetIterator(transform=<lambda> @ .../python/dataset/stats_test.py:525)
+  ││  MapDatasetIterator(transform=<lambda> @ .../python/dataset/stats_test.py:524)
   ││
   ╲╱
 {'data': "[]",
@@ -232,5 +233,347 @@ def test_report(self):
       s = s._parents[0]
     s.report()
 
 
+class DebugModeStatsTest(absltest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.enter_context(flagsaver.flagsaver(grain_py_debug_mode=True))
+
+  @mock.patch.object(stats, "_REPORTING_PERIOD_SEC", 0.05)
+  def test_record_stats(self):
+    s = _make_stats_tree(stats.make_stats)
+    self.assertIsInstance(s, stats._ExecutionStats)
+    flat_stats = []
+    to_visit = [s]
+    while to_visit:
+      node = to_visit.pop(0)
+      flat_stats.append(node)
+      to_visit.extend(node._parents)
+
+    reported_self_times = collections.defaultdict(int)
+
+    def mock_report(node):
+      while node._self_times_buffer:
+        reported_self_times[id(node)] += node._self_times_buffer.pop()
+      for p in node._parents:
+        p.report()
+
+    for node in flat_stats:
+      node.report = functools.partial(mock_report, node)
+    for node in flat_stats:
+      with node.record_self_time(offset_ns=10**9):
+        time.sleep(0.5)
+    time.sleep(0.05)
+    self_times = list(reported_self_times.values())
+    self.assertLen(self_times, len(flat_stats))
+    for self_time in self_times:
+      self.assertGreaterEqual(self_time, 1.05 * 10**9)
+
+  @mock.patch.object(stats, "_REPORTING_PERIOD_SEC", 0.05)
+  def test_record_stats_thread_safe(self):
+    s = stats.make_stats(stats.StatsConfig(name="test_stats"), ())
+    reported_self_time = 0
+
+    def mock_report(node):
+      nonlocal reported_self_time
+      while node._self_times_buffer:
+        reported_self_time += node._self_times_buffer.pop()
+      for p in node._parents:
+        p.report()
+
+    s.report = functools.partial(mock_report, s)
+
+    def record_self_time():
+      with s.record_self_time():
+        # Sleep releases the GIL, so this will actually execute concurrently.
+        time.sleep(1)
+
+    n_threads = 100
+    recording_threads = []
+    for _ in range(n_threads):
+      t = threading.Thread(target=record_self_time)
+      t.start()
+      recording_threads.append(t)
+    for t in recording_threads:
+      t.join()
+    time.sleep(0.05)
+    self.assertGreaterEqual(reported_self_time, n_threads)
+
+  def test_picklable(self):
+    s = stats.make_stats(stats.StatsConfig(name="test_stats"), ())
+    self.assertIsInstance(s, stats._ExecutionStats)
+    s = cloudpickle.loads(cloudpickle.dumps(s))
+    self.assertIsInstance(s, stats._ExecutionStats)
+    with s.record_self_time():
+      time.sleep(0.5)
+    s = cloudpickle.loads(cloudpickle.dumps(s))
+    self.assertIsInstance(s, stats._ExecutionStats)
+
+  def test_dataset_visualization(self):
+    ds = (
+        dataset.MapDataset.range(10)
+        .seed(42)
+        .shuffle()
+        .slice(slice(1, None, 3))
+        .map_with_index(_add_dummy_metadata)
+        .map(_identity)
+        .repeat(2)
+    )
+    # The visualization graph is constructed while iterating through the
+    # pipeline.
+    _ = list(ds)
+    self.assertIsInstance(ds._stats, stats._ExecutionStats)
+    self.assertEqual(ds._stats._visualize_dataset_graph(), _MAP_DATASET_REPR)
+
+  def test_pretty_print_execution_summary(self):
+    dummy_summary = execution_summary_pb2.ExecutionSummary()
+    dummy_summary.nodes[0].CopyFrom(
+        execution_summary_pb2.ExecutionSummary.Node(
+            id=0,
+            name="MapDatasetIterator(transform=_MapFnFromPreprocessingBuilder(preprocessing_builder=NextTokenAsTargetTextPreprocessingBuilder))",
+            inputs=[1],
+            wait_time_ratio=0.5,
+            total_processing_time_ns=0,
+            min_processing_time_ns=400_000,
+            max_processing_time_ns=0,
+            num_produced_elements=0,
+            output_spec="[]",
+        )
+    )
+    dummy_summary.nodes[1].CopyFrom(
+        execution_summary_pb2.ExecutionSummary.Node(
+            id=1,
+            name="PrefetchDatasetIterator",
+            inputs=[2],
+            wait_time_ratio=0.5,
+            total_processing_time_ns=400_000,
+            min_processing_time_ns=400,
+            max_processing_time_ns=40000,
+            num_produced_elements=10,
+            output_spec="[]",
+            is_output=True,
+            is_prefetch=True,
+        )
+    )
+    dummy_summary.nodes[2].CopyFrom(
+        execution_summary_pb2.ExecutionSummary.Node(
+            id=2,
+            name="MapMapDataset",
+            inputs=[3, 4],
+            wait_time_ratio=0.375,
+            total_processing_time_ns=400_000_000,
+            min_processing_time_ns=4000,
+            max_processing_time_ns=40_000_000,
+            num_produced_elements=10,
+            output_spec="[]",
+        )
+    )
+    dummy_summary.nodes[3].CopyFrom(
+        execution_summary_pb2.ExecutionSummary.Node(
+            id=3,
+            name="RangeMapDataset",
+            wait_time_ratio=0.125,
+            total_processing_time_ns=4_000_000_000,
+            min_processing_time_ns=400_000,
+            max_processing_time_ns=400_000_000,
+            num_produced_elements=10,
+            inputs=[],
+            output_spec="[]",
+        )
+    )
+    dummy_summary.nodes[4].CopyFrom(
+        execution_summary_pb2.ExecutionSummary.Node(
+            id=4,
+            name="RangeMapDataset",
+            total_processing_time_ns=0,
+            wait_time_ratio=0,
+            min_processing_time_ns=400_000,
+            max_processing_time_ns=0,
+            num_produced_elements=0,
+            inputs=[],
+            output_spec="[]",
+        )
+    )
+
+    expected_result = """
+|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| id | name                           | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements |
+|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| 4  | RangeMapDataset                | []     | 0.00%             | N/A                   | N/A                 | N/A                 | N/A                 | N/A                   |
+|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| 3  | RangeMapDataset                | []     | 12.50%            | 4.00s                 | 400.00us            | 400.00ms            | 400.00ms            | 10                    |
+|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| 2  | MapMapDataset                  | [3, 4] | 37.50%            | 400.00ms              | 4.00us              | 40.00ms             | 40.00ms             | 10                    |
+|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| 1  | PrefetchDatasetIterator        | [2]    | N/A               | 400.00us              | 400ns               | 40.00us             | 40.00us             | 10                    |
+|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| 0  | MapDatasetIterator(transform=_ | [1]    | 50.00%            | N/A                   | N/A                 | N/A                 | N/A                 | N/A                   |
+|    | MapFnFromPreprocessingBuilder( |        |                   |                       |                     |                     |                     |                       |
+|    | preprocessing_builder=NextToke |        |                   |                       |                     |                     |                     |                       |
+|    | nAsTargetTextPreprocessingBuil |        |                   |                       |                     |                     |                     |                       |
+|    | der))                          |        |                   |                       |                     |                     |                     |                       |
+|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+"""
+    self.assertEqual(
+        expected_result,
+        "\n" + stats._pretty_format_summary(dummy_summary),
+    )
+
+  def test_compute_iterator_wait_time_ratio(self):
+    dummy_summary = execution_summary_pb2.ExecutionSummary()
+    dummy_summary.nodes[0].CopyFrom(
+        execution_summary_pb2.ExecutionSummary.Node(
+            id=0,
+            name="MapDatasetIterator",
+            inputs=[1],
+            total_processing_time_ns=4_000_000_000,
+            min_processing_time_ns=400,
+            max_processing_time_ns=40000,
+            num_produced_elements=10,
+            output_spec="[]",
+            is_output=True,
+        )
+    )
+    dummy_summary.nodes[1].CopyFrom(
+        execution_summary_pb2.ExecutionSummary.Node(
+            id=1,
+            name="PrefetchDatasetIterator",
+            inputs=[2],
+            total_processing_time_ns=4_000_000_000,
+            min_processing_time_ns=4000,
+            max_processing_time_ns=40_000_000,
+            num_produced_elements=10,
+            output_spec="[]",
+            is_prefetch=True,
+        )
+    )
+    dummy_summary.nodes[2].CopyFrom(
+        execution_summary_pb2.ExecutionSummary.Node(
+            id=2,
+            name="MapMapDataset",
+            inputs=[3],
+            total_processing_time_ns=1_000_000_000,
+            min_processing_time_ns=400_000,
+            max_processing_time_ns=400_000_000,
+            num_produced_elements=10,
+            output_spec="[]",
+        )
+    )
+    dummy_summary.nodes[3].CopyFrom(
+        execution_summary_pb2.ExecutionSummary.Node(
+            id=3,
+            name="RangeMapDataset",
+            total_processing_time_ns=3_000_000_000,
+            min_processing_time_ns=400_000,
+            max_processing_time_ns=4_000_000,
+            num_produced_elements=10,
+            inputs=[],
+            output_spec="[]",
+        )
+    )
+    stats._populate_wait_time_ratio(dummy_summary)
+    self.assertEqual(dummy_summary.nodes[0].wait_time_ratio, 0.5)
+    self.assertEqual(dummy_summary.nodes[1].wait_time_ratio, 0)
+    self.assertEqual(dummy_summary.nodes[2].wait_time_ratio, 0.125)
+    self.assertEqual(dummy_summary.nodes[3].wait_time_ratio, 0.375)
+
+  @flagsaver.flagsaver(grain_py_dataset_visualization_output_dir="TEST_DIR")
+  def test_dataset_visualization_with_output_dir(self):
+    ds = (
+        dataset.MapDataset.range(10)
+        .shuffle(42)
+        .map_with_index(_add_dummy_metadata)
+        .map(_identity)
+    )
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        "Saving the dataset graph to a file is not supported yet.",
+    ):
+      _ = list(ds)
+
+
+class GraphModeStatsTest(absltest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.enter_context(
+        flagsaver.flagsaver(grain_py_dataset_visualization_output_dir="")
+    )
+
+  def test_visualize_map(self):
+    ds = (
+        dataset.MapDataset.range(10)
+        .seed(42)
+        .shuffle()
+        .slice(slice(1, None, 3))
+        .map_with_index(_add_dummy_metadata)
+        .map(_identity)
+        .repeat(2)
+    )
+    # The visualization graph is constructed while iterating through the
+    # pipeline.
+    _ = list(ds)
+    self.assertIsInstance(ds._stats, stats._VisualizationStats)
+    self.assertEqual(ds._stats._visualize_dataset_graph(), _MAP_DATASET_REPR)
+
+  def test_visualize_iter(self):
+    ds = (
+        dataset.MapDataset.range(10)
+        .shuffle(42)
+        .to_iter_dataset()
+        .seed(42)
+        .map(lambda x: _add_dummy_metadata(2, x))
+        .batch(2)
+    )
+    # The visualization graph is constructed while iterating through the
+    # pipeline.
+    it = ds.__iter__()
+    _ = list(it)
+    self.assertIsInstance(it._stats, stats._VisualizationStats)
+    self.assertEqual(it._stats._visualize_dataset_graph(), _ITER_DATASET_REPR)
+
+  def test_visualize_with_mix(self):
+    ds1 = dataset.MapDataset.range(10).shuffle(42)
+    ds2 = dataset.MapDataset.range(10).shuffle(43)
+    ds = dataset.MapDataset.mix([ds1, ds2]).map(_AddOne())
+    # The visualization graph is constructed while iterating through the
+    # pipeline.
+    _ = list(ds)
+    self.assertIsInstance(ds._stats, stats._VisualizationStats)
+    self.assertEqual(ds._stats._visualize_dataset_graph(), _MIX_DATASET_REPR)
+
+  @flagsaver.flagsaver(grain_py_dataset_visualization_output_dir="TEST_DIR")
+  def test_dataset_visualization_with_output_dir(self):
+    ds = (
+        dataset.MapDataset.range(10)
+        .shuffle(42)
+        .map_with_index(_add_dummy_metadata)
+        .map(_identity)
+    )
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        "Saving the dataset graph to a file is not supported yet.",
+    ):
+      _ = list(ds)
+
+  def test_picklable(self):
+    ds = (
+        dataset.MapDataset.range(10)
+        .seed(42)
+        .shuffle()
+        .slice(slice(1, None, 3))
+        .map_with_index(_add_dummy_metadata)
+        .map(_identity)
+        .repeat(2)
+    )
+    ds = cloudpickle.loads(cloudpickle.dumps(ds))
+    # The visualization graph is constructed while iterating through the
+    # pipeline.
+    _ = list(ds)
+    self.assertIsInstance(ds._stats, stats._VisualizationStats)
+    self.assertEqual(ds._stats._visualize_dataset_graph(), _MAP_DATASET_REPR)
+
+  @flagsaver.flagsaver(grain_py_dataset_visualization_output_dir=None)
+  def test_dataset_visualization_with_output_dir_none(self):
+    s = stats.make_stats(stats.StatsConfig(name="test_stats"), ())
+    self.assertIsInstance(s, stats._NoopStats)
+
+
 if __name__ == "__main__":
   absltest.main()
diff --git a/grain/_src/python/grain_pool.py b/grain/_src/python/grain_pool.py
index 53d9ae8e..f4473143 100644
--- a/grain/_src/python/grain_pool.py
+++ b/grain/_src/python/grain_pool.py
@@ -61,6 +61,7 @@
 import cloudpickle
 from grain._src.core import parallel
 from grain._src.core import tree_lib
+from grain._src.core.config import config
 import multiprocessing as mp
 from grain._src.python import grain_logging
 from grain._src.python import multiprocessing_common
@@ -153,12 +154,35 @@ def deserialize(cls, serialized: bytes) -> GetElementProducerFn[T]:
     return obj
 
 
+def parse_debug_flags(debug_flags: dict[str, Any]):
+  """Parses debug flags."""
+  from absl import flags
+
+  flags.FLAGS["grain_py_debug_mode"].present = True
+  flags.FLAGS["grain_py_dataset_visualization_output_dir"].present = True
+  config.update("py_debug_mode", debug_flags["grain_py_debug_mode"])
+  config.update(
+      "py_dataset_visualization_output_dir",
+      debug_flags["grain_py_dataset_visualization_output_dir"],
+  )
+
+
 def _initialize_and_get_element_producer(
-    args_queue: queues.Queue, *, worker_index: int, worker_count: int
+    args_queue: queues.Queue,
+    *,
+    debug_flags: dict[str, Any],
+    worker_index: int,
+    worker_count: int,
 ) -> Iterator[Any]:
   """Unpickles the element producer from the args queue and closes the queue."""
-  serialized_init_fn, serialized_element_producer_fn = args_queue.get()
+  (
+      serialized_flag_parse_fn,
+      serialized_init_fn,
+      serialized_element_producer_fn,
+  ) = args_queue.get()
+  flag_parse_fn: Callable[[Any], None] = cloudpickle.loads(
+      serialized_flag_parse_fn
+  )
+  flag_parse_fn(debug_flags)
   init_fn: Callable[[], None] = cloudpickle.loads(serialized_init_fn)
   init_fn()
   element_producer_fn: GetElementProducerFn[Any] = (
@@ -182,6 +206,7 @@ def _worker_loop(
     worker_index: int,
     worker_count: int,
     enable_profiling: bool,
+    debug_flags: dict[str, Any],
 ):
   """Code to be run on each child process."""
   out_of_elements = False
@@ -192,7 +217,10 @@
     )
     logging.info("Starting work.")
     element_producer = _initialize_and_get_element_producer(
-        args_queue, worker_index=worker_index, worker_count=worker_count
+        args_queue,
+        debug_flags=debug_flags,
+        worker_index=worker_index,
+        worker_count=worker_count,
     )
     profiling_enabled = enable_profiling and worker_index == 0
     if profiling_enabled:
@@ -323,14 +351,24 @@ def __init__(
           "worker_index": worker_index,
           "worker_count": options.num_workers,
           "enable_profiling": options.enable_profiling,
+          "debug_flags": {
+              "grain_py_debug_mode": config.py_debug_mode,
+              "grain_py_dataset_visualization_output_dir": (
+                  config.py_dataset_visualization_output_dir
+              ),
+          },
       }
       # The process kwargs must all be pickable and will be unpickle before
       # absl.app.run() is called. We send arguments via a queue to ensure that
      # they are unpickled after absl.app.run() was called in the child
      # processes.
       worker_init_fn = lambda: None
+      parse_debug_flags_fn = parse_debug_flags
       worker_init_fn = cloudpickle.dumps(worker_init_fn)
-      worker_args_queue.put((worker_init_fn, get_element_producer_fn))
+      parse_debug_flags_fn = cloudpickle.dumps(parse_debug_flags_fn)
+      worker_args_queue.put(
+          (parse_debug_flags_fn, worker_init_fn, get_element_producer_fn)
+      )
       process = ctx.Process(  # pytype: disable=attribute-error  # re-none
           target=_worker_loop, kwargs=process_kwargs, daemon=True
       )
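
Reviewer note: the hand-off above means each worker applies the parent's debug
flags before anything else is unpickled. A minimal sketch of what the child
effectively runs first (the flag values here are illustrative):

    from grain._src.python.grain_pool import parse_debug_flags

    parse_debug_flags({
        "grain_py_debug_mode": True,
        "grain_py_dataset_visualization_output_dir": "",
    })
    # From here on, grain._src.core.config reflects the parent's settings, so
    # stats.make_stats() in the worker picks the same Stats implementation.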
diff --git a/grain/proto/BUILD b/grain/proto/BUILD
new file mode 100644
index 00000000..36058ad6
--- /dev/null
+++ b/grain/proto/BUILD
@@ -0,0 +1,18 @@
+load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library")
+load("@com_google_protobuf//bazel:py_proto_library.bzl", "py_proto_library")
+
+default_visibility = ["//grain:__subpackages__"]
+
+package(default_visibility = default_visibility)
+
+proto_library(
+    name = "execution_summary_proto",
+    srcs = ["execution_summary.proto"],
+    # For profiling tooling.
+)
+
+py_proto_library(
+    name = "execution_summary_py_pb2",
+    # For profiling tooling.
+    deps = [":execution_summary_proto"],
+)
diff --git a/grain/proto/execution_summary.proto b/grain/proto/execution_summary.proto
new file mode 100644
index 00000000..f9d92299
--- /dev/null
+++ b/grain/proto/execution_summary.proto
@@ -0,0 +1,36 @@
+syntax = "proto3";
+
+package grain.python.execution_summary;
+
+message ExecutionSummary {
+  message Node {
+    // Unique ID of the node.
+    int32 id = 2;
+    // Human-readable name of the node.
+    string name = 3;
+    // Node IDs of the parent nodes.
+    repeated int32 inputs = 4;
+    // Ratio of time spent by the pipeline waiting for the given
+    // transformation node.
+    double wait_time_ratio = 5;
+    // Cumulative processing time spent in the node from the start, in
+    // nanoseconds.
+    int64 total_processing_time_ns = 6;
+    // Minimum per-element processing time in nanoseconds.
+    int64 min_processing_time_ns = 7;
+    // Maximum per-element processing time in nanoseconds.
+    int64 max_processing_time_ns = 8;
+    // Number of elements produced by the node.
+    int64 num_produced_elements = 9;
+    // Human-readable specification of the produced elements.
+    string output_spec = 10;
+    // Whether the node is the root node.
+    bool is_output = 11;
+    // Whether the node is a prefetch node. Child nodes of a prefetch node
+    // have their wait time ratio derived from the ratio of the prefetch
+    // node. The sum of all ratios in a single pipeline is 1.
+    bool is_prefetch = 12;
+  }
+  // Map of node IDs to nodes in the pipeline.
+  map<int32, Node> nodes = 1;
+}
diff --git a/grain/python/stats/__init__.py b/grain/python/stats/__init__.py
deleted file mode 100644
index 9240c2aa..00000000
--- a/grain/python/stats/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2025 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
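
Reviewer note: a sketch of consuming the new proto from the summary returned by
`dataset.get_execution_summary`. The field names come from
`execution_summary.proto` above; `avg_processing_time_ns` and `busiest_nodes`
are hypothetical helpers mirroring the "avg processing time" column of the
pretty-printed table:

    from grain.proto import execution_summary_pb2

    def avg_processing_time_ns(
        node: execution_summary_pb2.ExecutionSummary.Node,
    ):
      # Nodes that have produced nothing show up as N/A in the table.
      if node.num_produced_elements == 0:
        return None
      return node.total_processing_time_ns / node.num_produced_elements

    def busiest_nodes(summary: execution_summary_pb2.ExecutionSummary):
      # wait_time_ratio sums to 1 across a pipeline; children of a prefetch
      # node split the prefetch node's share in proportion to their total
      # processing time (in test_compute_iterator_wait_time_ratio above:
      # 0.5 * 1/4 = 0.125 and 0.5 * 3/4 = 0.375).
      return sorted(
          summary.nodes.values(), key=lambda n: n.wait_time_ratio, reverse=True
      )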