Skip to content

Commit

Permalink
Add a converter from dcn_slack_analysis.proto to GViz DataTable format.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 558216192
  • Loading branch information
SurbhiJainUSC authored and copybara-github committed Aug 18, 2023
1 parent 6052c62 commit 705d953
Show file tree
Hide file tree
Showing 7 changed files with 345 additions and 0 deletions.
24 changes: 24 additions & 0 deletions plugin/tensorboard_plugin_profile/convert/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package(
py_library(
name = "all_libs",
deps = [
":dcn_collective_stats_proto_to_gviz",
":diagnostics",
":input_pipeline_proto_to_gviz",
":kernel_stats_proto_to_gviz",
Expand Down Expand Up @@ -193,12 +194,35 @@ py_test(
],
)

py_library(
name = "dcn_collective_stats_proto_to_gviz",
srcs = ["dcn_collective_stats_proto_to_gviz.py"],
srcs_version = "PY2AND3",
deps = [requirement("gviz_api")],
)

py_test(
name = "dcn_collective_stats_proto_to_gviz_test",
size = "small",
srcs = ["dcn_collective_stats_proto_to_gviz_test.py"],
main = "dcn_collective_stats_proto_to_gviz_test.py",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":dcn_collective_stats_proto_to_gviz",
requirement("gviz_api"),
"//:expect_tensorflow_installed",
"@org_xprof//plugin/tensorboard_plugin_profile/protobuf:protos_all_py_pb2",
],
)

py_library(
name = "raw_to_tool_data",
srcs = ["raw_to_tool_data.py"],
srcs_version = "PY2AND3",
visibility = visibility,
deps = [
":dcn_collective_stats_proto_to_gviz",
":input_pipeline_proto_to_gviz",
":kernel_stats_proto_to_gviz",
":overview_page_proto_to_gviz",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""For conversion of Dcn Collective Stats page protos to GViz DataTables.
Usage:
gviz_data_tables = generate_all_chart_tables(dcn_slack_analysis)
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import gviz_api

from tensorboard_plugin_profile.protobuf import dcn_slack_analysis_pb2


def get_dcn_collective_stats_table_args(dcn_slack_analysis):
"""Creates a gviz DataTable object from DcnSlackAnalysis proto.
Args:
dcn_slack_analysis: dcn_slack_analysis_pb2.DcnSlackAnalysis.
Returns:
Returns a gviz_api.DataTable
"""

table_description = [
("dcnCollectiveName", "string", "Dcn Collective Name"),
("recvOpName", "string", "Recv Op Name"),
("sendOpName", "string", "Send Op Name"),
("slackTime", "number", "Slack Time (ms)"),
("observedDuration", "number", "Observed Duration (ms)"),
("stallDuration", "number", "Stall Duration (ms)"),
("occurrences", "number", "Occurrences"),
("dataTransmittedSize", "number", "Data Transmitted Size"),
("bandwidth", "number", "Bandwidth (Gbps)"),
]

data = []
for slack in dcn_slack_analysis.dcn_slack_summary:
row = [
slack.rendezvous,
slack.recv_op_name,
slack.send_op_name,
slack.slack_us / 1000,
slack.observed_duration_us / 1000,
slack.stall_duration_us / 1000,
slack.occurrences,
slack.bytes_transmitted_over_network,
]
if slack.slack_us == 0:
row.append(sys.maxsize)
else:
row.append(
(slack.bytes_transmitted_over_network / (slack.slack_us * 8)) / 1000
)
data.append(row)

return (table_description, data, [])


def generate_dcn_collective_stats_table(dcn_slack_analysis):
(table_description, data, custom_properties) = (
get_dcn_collective_stats_table_args(dcn_slack_analysis)
)
return gviz_api.DataTable(table_description, data, custom_properties)


def generate_all_chart_tables(dcn_slack_analysis):
"""Converts a DcnSlackAnalysis proto to gviz DataTables."""
return [
generate_dcn_collective_stats_table(dcn_slack_analysis),
]


def to_json(raw_data):
"""Converts a serialized DcnCollectiveAnalysis string to json."""
dcn_slack_analysis = dcn_slack_analysis_pb2.DcnSlackAnalysis()
dcn_slack_analysis.ParseFromString(raw_data)
all_chart_tables = generate_all_chart_tables(dcn_slack_analysis)
json_join = ",".join(x.ToJSon() if x else "{}" for x in all_chart_tables)
return "[" + json_join + "]"
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for dcn_collective_stats_proto_to_gviz."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import csv
import enum
import io

import gviz_api
import tensorflow as tf

from tensorboard_plugin_profile.convert import dcn_collective_stats_proto_to_gviz
from tensorboard_plugin_profile.protobuf import dcn_slack_analysis_pb2


class StrEnum(str, enum.Enum):
pass


class MockValues(StrEnum):
DCN_COLLECTIVE_NAME = "collective-1"
RECV_OP_NAME = "recv-done"
SEND_OP_NAME = "send"
SLACK_US = 2
OBSERVED_DURATION_US = 12
STALL_DURATION_MS = 5
OCCURRENCES = 6
BYTES_TRANSMITTED_OVER_NETWORK = 8192000
BANDWIDTH = 0.512


class ProtoToGvizTest(tf.test.TestCase):

def create_empty_dcn_slack_analysis(self):
return dcn_slack_analysis_pb2.DcnSlackAnalysis()

def create_mock_dcn_slack_summary(self):
dcn_slack_summary = dcn_slack_analysis_pb2.DcnSlackSummary(
rendezvous=MockValues.DCN_COLLECTIVE_NAME,
recv_op_name=MockValues.RECV_OP_NAME,
send_op_name=MockValues.SEND_OP_NAME,
slack_us=int(MockValues.SLACK_US) * 1000,
observed_duration_us=int(MockValues.OBSERVED_DURATION_US) * 1000,
stall_duration_us=int(MockValues.STALL_DURATION_MS) * 1000,
occurrences=int(MockValues.OCCURRENCES),
bytes_transmitted_over_network=int(
MockValues.BYTES_TRANSMITTED_OVER_NETWORK
)
)
return dcn_slack_summary

def create_mock_dcn_slack_analysis(self):
dcn_slack_analysis = dcn_slack_analysis_pb2.DcnSlackAnalysis()
for _ in range(0, 3):
dcn_slack_analysis.dcn_slack_summary.append(
self.create_mock_dcn_slack_summary()
)
return dcn_slack_analysis

def test_dcn_collective_stats_empty(self):
dcn_slack_analysis = self.create_empty_dcn_slack_analysis()
data_table = (
dcn_collective_stats_proto_to_gviz.generate_dcn_collective_stats_table(
dcn_slack_analysis
)
)

self.assertEqual(0, data_table.NumberOfRows())
self.assertLen(data_table.columns, 9)

def test_dcn_collective_stats_table(self):
dcn_slack_analysis = self.create_mock_dcn_slack_analysis()
(table_description, data, custom_properties) = (
dcn_collective_stats_proto_to_gviz.get_dcn_collective_stats_table_args(
dcn_slack_analysis
)
)
data_table = gviz_api.DataTable(table_description, data, custom_properties)

self.assertLen(data, 3)
self.assertEqual(3, data_table.NumberOfRows())
self.assertLen(table_description, 9)
self.assertLen(data_table.columns, 9)

csv_file = io.StringIO(data_table.ToCsv())
reader = csv.reader(csv_file)

expected = [
MockValues.DCN_COLLECTIVE_NAME,
MockValues.RECV_OP_NAME,
MockValues.SEND_OP_NAME,
MockValues.SLACK_US,
MockValues.OBSERVED_DURATION_US,
MockValues.STALL_DURATION_MS,
MockValues.OCCURRENCES,
MockValues.BYTES_TRANSMITTED_OVER_NETWORK,
MockValues.BANDWIDTH,
]

for rr, row_values in enumerate(reader):
if rr == 0:
# DataTable columns match schema defined in table_description.
for cc, column_header in enumerate(row_values):
self.assertEqual(table_description[cc][2], column_header)
else:
for cc, cell_str in enumerate(row_values):
raw_value = data[rr - 1][cc]
value_type = table_description[cc][1]

# Only number and strings are used in the DataTable schema.
self.assertIn(value_type, ["number", "string"])

# Encode in similar fashion as DataTable.ToCsv().
expected_value = gviz_api.DataTable.CoerceValue(raw_value, value_type)
self.assertNotIsInstance(expected_value, tuple)
self.assertEqual(expected_value, raw_value)

# Check against expected values we have set in our mock table.
if value_type == "string":
self.assertEqual(expected[cc], cell_str)
else:
if (
expected[cc] == MockValues.OCCURRENCES
or expected[cc] == MockValues.BYTES_TRANSMITTED_OVER_NETWORK
):
self.assertEqual(str(int(expected[cc])), cell_str)
else:
self.assertEqual(str(float(expected[cc])), cell_str)


if __name__ == "__main__":
tf.test.main()
6 changes: 6 additions & 0 deletions plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import logging

from tensorflow.python.profiler.internal import _pywrap_profiler # pylint: disable=g-direct-tensorflow-import
from tensorboard_plugin_profile.convert import dcn_collective_stats_proto_to_gviz
from tensorboard_plugin_profile.convert import input_pipeline_proto_to_gviz
from tensorboard_plugin_profile.convert import kernel_stats_proto_to_gviz
from tensorboard_plugin_profile.convert import overview_page_proto_to_gviz
Expand Down Expand Up @@ -167,6 +168,11 @@ def xspace_to_tool_data(
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
elif tool == 'dcn_collective_stats':
options = {'host_name': params.get('host')}
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = dcn_collective_stats_proto_to_gviz.to_json(raw_data)
else:
logger.warning('%s is not a known xplane tool', tool)
return data, content_type
Expand Down
1 change: 1 addition & 0 deletions plugin/tensorboard_plugin_profile/profile_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
'overview_page^',
'pod_viewer^',
'tf_data_bottleneck_analysis^',
'dcn_collective_stats^',
])

# XPlane generated tools that only support all host mode.
Expand Down
2 changes: 2 additions & 0 deletions plugin/tensorboard_plugin_profile/protobuf/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package(
proto_library(
name = "protos_all",
srcs = [
"dcn_slack_analysis.proto",
"diagnostics.proto",
"input_pipeline.proto",
"kernel_stats.proto",
Expand All @@ -24,6 +25,7 @@ proto_library(
py_proto_library(
name = "protos_all_py_pb2",
srcs = [
"dcn_slack_analysis.proto",
"diagnostics.proto",
"input_pipeline.proto",
"kernel_stats.proto",
Expand Down
Loading

0 comments on commit 705d953

Please sign in to comment.