Skip to content

Commit

Permalink
Add a converter from dcn_slack_analysis.proto to GViz DataTable format.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 559508558
  • Loading branch information
SurbhiJainUSC authored and copybara-github committed Aug 23, 2023
1 parent f3e3719 commit 4bfdd0a
Show file tree
Hide file tree
Showing 7 changed files with 325 additions and 0 deletions.
24 changes: 24 additions & 0 deletions plugin/tensorboard_plugin_profile/convert/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package(
py_library(
name = "all_libs",
deps = [
":dcn_collective_stats_proto_to_gviz",
":diagnostics",
":input_pipeline_proto_to_gviz",
":kernel_stats_proto_to_gviz",
Expand Down Expand Up @@ -193,12 +194,35 @@ py_test(
],
)

# Converter from DcnSlackAnalysis protos to GViz DataTables.
py_library(
    name = "dcn_collective_stats_proto_to_gviz",
    srcs = ["dcn_collective_stats_proto_to_gviz.py"],
    srcs_version = "PY2AND3",
    deps = [
        requirement("gviz_api"),
        # The source imports dcn_slack_analysis_pb2, so the generated protos
        # must be a direct dependency rather than something the test target
        # happens to provide.
        "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:protos_all_py_pb2",
    ],
)

# Unit test for the dcn_collective_stats_proto_to_gviz converter.
py_test(
    name = "dcn_collective_stats_proto_to_gviz_test",
    size = "small",
    srcs = ["dcn_collective_stats_proto_to_gviz_test.py"],
    main = "dcn_collective_stats_proto_to_gviz_test.py",
    python_version = "PY3",
    srcs_version = "PY2AND3",
    deps = [
        # Library under test.
        ":dcn_collective_stats_proto_to_gviz",
        requirement("gviz_api"),
        "//:expect_tensorflow_installed",
        # Generated Python protos (the test builds DcnSlackAnalysis messages).
        "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:protos_all_py_pb2",
    ],
)

py_library(
name = "raw_to_tool_data",
srcs = ["raw_to_tool_data.py"],
srcs_version = "PY2AND3",
visibility = visibility,
deps = [
":dcn_collective_stats_proto_to_gviz",
":input_pipeline_proto_to_gviz",
":kernel_stats_proto_to_gviz",
":overview_page_proto_to_gviz",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""For conversion of Dcn Collective Stats page protos to GViz DataTables.
Usage:
gviz_data_tables = generate_all_chart_tables(dcn_slack_analysis)
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gviz_api

from tensorboard_plugin_profile.protobuf import dcn_slack_analysis_pb2


def get_dcn_collective_stats_table_args(dcn_slack_analysis):
  """Builds the gviz DataTable constructor arguments for a DcnSlackAnalysis.

  One row is emitted per entry of `dcn_slack_analysis.dcn_slack_summary`. The
  three `*_us` fields are divided by 1000 so they match the "(ms)" column
  labels.

  Args:
    dcn_slack_analysis: dcn_slack_analysis_pb2.DcnSlackAnalysis.

  Returns:
    A `(table_description, data, custom_properties)` tuple. `table_description`
    is a list of `(id, type, label)` column tuples, `data` is a list of rows
    (one list per summary), and `custom_properties` is an empty list.
  """
  columns = [
      ("dcnCollectiveName", "string", "Dcn Collective Name"),
      ("recvOpName", "string", "Recv Op Name"),
      ("sendOpName", "string", "Send Op Name"),
      ("slackTime", "number", "Slack Time (ms)"),
      ("observedDuration", "number", "Observed Duration (ms)"),
      ("stallDuration", "number", "Stall Duration (ms)"),
      ("occurrences", "number", "Occurrences"),
  ]

  # Cell order must follow the column order declared above.
  rows = [
      [
          summary.rendezvous,
          summary.recv_op_name,
          summary.send_op_name,
          summary.slack_us / 1000,
          summary.observed_duration_us / 1000,
          summary.stall_duration_us / 1000,
          summary.occurrences,
      ]
      for summary in dcn_slack_analysis.dcn_slack_summary
  ]

  return (columns, rows, [])


def generate_dcn_collective_stats_table(dcn_slack_analysis):
  """Creates the Dcn Collective Stats gviz DataTable.

  Args:
    dcn_slack_analysis: dcn_slack_analysis_pb2.DcnSlackAnalysis.

  Returns:
    A gviz_api.DataTable constructed from the arguments produced by
    get_dcn_collective_stats_table_args.
  """
  table_args = get_dcn_collective_stats_table_args(dcn_slack_analysis)
  # table_args is (table_description, data, custom_properties), which is the
  # positional order the DataTable constructor is called with.
  return gviz_api.DataTable(*table_args)


def generate_all_chart_tables(dcn_slack_analysis):
  """Converts a DcnSlackAnalysis proto to a list of gviz DataTables.

  Args:
    dcn_slack_analysis: dcn_slack_analysis_pb2.DcnSlackAnalysis.

  Returns:
    A one-element list holding the Dcn Collective Stats table.
  """
  stats_table = generate_dcn_collective_stats_table(dcn_slack_analysis)
  return [stats_table]


def to_json(raw_data):
  """Converts a serialized DcnSlackAnalysis proto to a JSON array string.

  Args:
    raw_data: Serialized dcn_slack_analysis_pb2.DcnSlackAnalysis.

  Returns:
    A string of the form "[<table>,<table>,...]" in which each entry is the
    ToJSon() output of one chart table, or "{}" for a falsy table.
  """
  analysis = dcn_slack_analysis_pb2.DcnSlackAnalysis()
  analysis.ParseFromString(raw_data)

  serialized_tables = []
  for table in generate_all_chart_tables(analysis):
    serialized_tables.append(table.ToJSon() if table else "{}")

  return "[" + ",".join(serialized_tables) + "]"
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for dcn_collective_stats_proto_to_gviz."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import csv
import enum
import io

import gviz_api
import tensorflow as tf

from tensorboard_plugin_profile.convert import dcn_collective_stats_proto_to_gviz
from tensorboard_plugin_profile.protobuf import dcn_slack_analysis_pb2


class StrEnum(str, enum.Enum):
  """Enum base class whose members are also `str` instances."""


class MockValues(StrEnum):
  """Mock values used to fill the test protos and to check the table cells.

  The tests convert the numeric members with int() / float() before use.
  """

  # String columns.
  DCN_COLLECTIVE_NAME = "collective-1"
  RECV_OP_NAME = "recv-done"
  SEND_OP_NAME = "send"
  # NOTE: the tests multiply the next three values by 1000 before storing them
  # in the *_us proto fields and compare the unscaled value with the "(ms)"
  # table columns, so all three are effectively milliseconds even though two
  # of the names end in _US.
  SLACK_US = 2
  OBSERVED_DURATION_US = 12
  STALL_DURATION_MS = 5
  # Plain count; stored in the proto without scaling.
  OCCURRENCES = 6


class ProtoToGvizTest(tf.test.TestCase):
  """Tests for the DcnSlackAnalysis proto to gviz DataTable conversion."""

  def create_empty_dcn_slack_analysis(self):
    """Returns a DcnSlackAnalysis with no entries."""
    return dcn_slack_analysis_pb2.DcnSlackAnalysis()

  def create_mock_dcn_slack_summary(self):
    """Returns a DcnSlackSummary filled from MockValues.

    The three duration mocks are multiplied by 1000 before being stored in the
    *_us fields, so the table's "(ms)" columns are expected to show the
    unscaled MockValues numbers.
    """
    dcn_slack_summary = dcn_slack_analysis_pb2.DcnSlackSummary(
        rendezvous=MockValues.DCN_COLLECTIVE_NAME,
        recv_op_name=MockValues.RECV_OP_NAME,
        send_op_name=MockValues.SEND_OP_NAME,
        slack_us=int(MockValues.SLACK_US) * 1000,
        observed_duration_us=int(MockValues.OBSERVED_DURATION_US) * 1000,
        stall_duration_us=int(MockValues.STALL_DURATION_MS) * 1000,
        occurrences=int(MockValues.OCCURRENCES),
    )
    return dcn_slack_summary

  def create_mock_dcn_slack_analysis(self):
    """Returns a DcnSlackAnalysis holding three identical mock summaries."""
    dcn_slack_analysis = dcn_slack_analysis_pb2.DcnSlackAnalysis()
    for _ in range(3):
      dcn_slack_analysis.dcn_slack_summary.append(
          self.create_mock_dcn_slack_summary()
      )
    return dcn_slack_analysis

  def test_dcn_collective_stats_empty(self):
    """An empty proto yields a table with all columns and no rows."""
    dcn_slack_analysis = self.create_empty_dcn_slack_analysis()
    data_table = (
        dcn_collective_stats_proto_to_gviz.generate_dcn_collective_stats_table(
            dcn_slack_analysis
        )
    )

    self.assertEqual(0, data_table.NumberOfRows())
    self.assertLen(data_table.columns, 7)

  def test_dcn_collective_stats_table(self):
    """A populated proto yields one correctly formatted row per summary."""
    dcn_slack_analysis = self.create_mock_dcn_slack_analysis()
    (table_description, data, custom_properties) = (
        dcn_collective_stats_proto_to_gviz.get_dcn_collective_stats_table_args(
            dcn_slack_analysis
        )
    )
    data_table = gviz_api.DataTable(table_description, data, custom_properties)

    self.assertLen(data, 3)
    self.assertEqual(3, data_table.NumberOfRows())
    self.assertLen(table_description, 7)
    self.assertLen(data_table.columns, 7)

    csv_file = io.StringIO(data_table.ToCsv())
    reader = csv.reader(csv_file)

    # Expected cell values, in the same order as the table columns.
    expected = [
        MockValues.DCN_COLLECTIVE_NAME,
        MockValues.RECV_OP_NAME,
        MockValues.SEND_OP_NAME,
        MockValues.SLACK_US,
        MockValues.OBSERVED_DURATION_US,
        MockValues.STALL_DURATION_MS,
        MockValues.OCCURRENCES,
    ]

    for rr, row_values in enumerate(reader):
      if rr == 0:
        # DataTable columns match schema defined in table_description.
        for cc, column_header in enumerate(row_values):
          self.assertEqual(table_description[cc][2], column_header)
      else:
        for cc, cell_str in enumerate(row_values):
          raw_value = data[rr - 1][cc]
          column_id = table_description[cc][0]
          value_type = table_description[cc][1]

          # Only number and strings are used in the DataTable schema.
          self.assertIn(value_type, ["number", "string"])

          # Encode in similar fashion as DataTable.ToCsv().
          expected_value = gviz_api.DataTable.CoerceValue(raw_value, value_type)
          self.assertNotIsInstance(expected_value, tuple)
          self.assertEqual(expected_value, raw_value)

          # Check against expected values we have set in our mock table.
          if value_type == "string":
            self.assertEqual(expected[cc], cell_str)
          elif column_id == "occurrences":
            # Occurrences is the only integer column. Select it by column id:
            # comparing mock values would also match any other numeric mock
            # that happens to share the same value.
            self.assertEqual(str(int(expected[cc])), cell_str)
          else:
            # Durations are produced by true division and so render as floats.
            self.assertEqual(str(float(expected[cc])), cell_str)


# Run the tests through the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  tf.test.main()
6 changes: 6 additions & 0 deletions plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import logging

from tensorflow.python.profiler.internal import _pywrap_profiler # pylint: disable=g-direct-tensorflow-import
from tensorboard_plugin_profile.convert import dcn_collective_stats_proto_to_gviz
from tensorboard_plugin_profile.convert import input_pipeline_proto_to_gviz
from tensorboard_plugin_profile.convert import kernel_stats_proto_to_gviz
from tensorboard_plugin_profile.convert import overview_page_proto_to_gviz
Expand Down Expand Up @@ -167,6 +168,11 @@ def xspace_to_tool_data(
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
elif tool == 'dcn_collective_stats':
options = {'host_name': params.get('host')}
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = dcn_collective_stats_proto_to_gviz.to_json(raw_data)
else:
logger.warning('%s is not a known xplane tool', tool)
return data, content_type
Expand Down
1 change: 1 addition & 0 deletions plugin/tensorboard_plugin_profile/profile_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
'overview_page^',
'pod_viewer^',
'tf_data_bottleneck_analysis^',
'dcn_collective_stats^',
])

# XPlane generated tools that only support all host mode.
Expand Down
2 changes: 2 additions & 0 deletions plugin/tensorboard_plugin_profile/protobuf/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package(
proto_library(
name = "protos_all",
srcs = [
"dcn_slack_analysis.proto",
"diagnostics.proto",
"input_pipeline.proto",
"kernel_stats.proto",
Expand All @@ -24,6 +25,7 @@ proto_library(
py_proto_library(
name = "protos_all_py_pb2",
srcs = [
"dcn_slack_analysis.proto",
"diagnostics.proto",
"input_pipeline.proto",
"kernel_stats.proto",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
syntax = "proto3";

package tensorboard_plugin_profile;

// Slack and stall timing recorded for a DCN collective.
message DcnSlack {
  // Rendezvous name for the collective.
  string rendezvous = 1;
  // Xprof observed send start time.
  uint64 send_start_time_us = 2;
  // Xprof observed recv_done end time.
  uint64 recv_done_end_time_us = 3;

  // Slack is defined as the time the collective has to send and recv data
  // without stalling the tpu. The effect of the network and other overlapping
  // collectives are removed from the collective of interest.
  //
  //
  // HOST 1 :
  // |--------|SEND1|-------|SEND1.DONE|-------|RECV1|------|RECV1.DONE|-------
  // HOST 2:
  // |------|SEND2|-------|SEND2.DONE|-------|RECV2|------|RECV2.DONE |-----
  //
  // Slack is computed as
  // RECV2.DONE.StartTime - SEND2.StartTime - (Overlapping Communication)
  // In this case, Overlapping communication is the duration of SEND2,
  // SEND2.DONE and RECV2. In cases where other collectives are interspersed
  // with this collective, the overlapping duration includes their durations
  // as well. Host 1 is ignored while computing the slack, as we assume that
  // similar ops are executing on each core. This also prevents clock drift
  // from affecting the analysis.
  uint64 slack_us = 4;

  // Bytes transmitted over the network.
  uint64 bytes_transmitted_over_network = 5;

  // Duration the collective stalled the TPU.
  uint64 stall_duration_us = 6;

  // Recv op name
  string recv_op_name = 7;

  // Send op name
  string send_op_name = 8;
}

// Summary of the slack statistics for one collective. Each entry becomes one
// row of the Dcn Collective Stats table.
message DcnSlackSummary {
  // Rendezvous name for the collective.
  string rendezvous = 1;
  // Slack time in microseconds.
  uint64 slack_us = 2;
  // Number of occurrences in the sampled duration.
  uint64 occurrences = 3;
  // Bytes transmitted over the network.
  uint64 bytes_transmitted_over_network = 4;
  // Duration the collective stalled the TPU, in microseconds.
  uint64 stall_duration_us = 5;
  // Observed duration, in microseconds.
  uint64 observed_duration_us = 6;
  // Recv op name.
  string recv_op_name = 7;

  // Send op name.
  string send_op_name = 8;
}

// Top-level message: the DcnSlack records plus the DcnSlackSummary records.
message DcnSlackAnalysis {
  // Individual slack records.
  repeated DcnSlack dcn_slack = 1;
  // Summary records; these are what the GViz converter reads.
  repeated DcnSlackSummary dcn_slack_summary = 2;
}

0 comments on commit 4bfdd0a

Please sign in to comment.