-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a converter from dcn_slack_analysis.proto to GViz DataTable format.
PiperOrigin-RevId: 558216192
- Loading branch information
1 parent
6052c62
commit 705d953
Showing
7 changed files
with
345 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
97 changes: 97 additions & 0 deletions
97
plugin/tensorboard_plugin_profile/convert/dcn_collective_stats_proto_to_gviz.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""For conversion of Dcn Collective Stats page protos to GViz DataTables. | ||
Usage: | ||
gviz_data_tables = generate_all_chart_tables(dcn_slack_analysis) | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import sys | ||
|
||
import gviz_api | ||
|
||
from tensorboard_plugin_profile.protobuf import dcn_slack_analysis_pb2 | ||
|
||
|
||
def get_dcn_collective_stats_table_args(dcn_slack_analysis): | ||
"""Creates a gviz DataTable object from DcnSlackAnalysis proto. | ||
Args: | ||
dcn_slack_analysis: dcn_slack_analysis_pb2.DcnSlackAnalysis. | ||
Returns: | ||
Returns a gviz_api.DataTable | ||
""" | ||
|
||
table_description = [ | ||
("dcnCollectiveName", "string", "Dcn Collective Name"), | ||
("recvOpName", "string", "Recv Op Name"), | ||
("sendOpName", "string", "Send Op Name"), | ||
("slackTime", "number", "Slack Time (ms)"), | ||
("observedDuration", "number", "Observed Duration (ms)"), | ||
("stallDuration", "number", "Stall Duration (ms)"), | ||
("occurrences", "number", "Occurrences"), | ||
("dataTransmittedSize", "number", "Data Transmitted Size"), | ||
("bandwidth", "number", "Bandwidth (Gbps)"), | ||
] | ||
|
||
data = [] | ||
for slack in dcn_slack_analysis.dcn_slack_summary: | ||
row = [ | ||
slack.rendezvous, | ||
slack.recv_op_name, | ||
slack.send_op_name, | ||
slack.slack_us / 1000, | ||
slack.observed_duration_us / 1000, | ||
slack.stall_duration_us / 1000, | ||
slack.occurrences, | ||
slack.bytes_transmitted_over_network, | ||
] | ||
if slack.slack_us == 0: | ||
row.append(sys.maxsize) | ||
else: | ||
row.append( | ||
(slack.bytes_transmitted_over_network / (slack.slack_us * 8)) / 1000 | ||
) | ||
data.append(row) | ||
|
||
return (table_description, data, []) | ||
|
||
|
||
def generate_dcn_collective_stats_table(dcn_slack_analysis): | ||
(table_description, data, custom_properties) = ( | ||
get_dcn_collective_stats_table_args(dcn_slack_analysis) | ||
) | ||
return gviz_api.DataTable(table_description, data, custom_properties) | ||
|
||
|
||
def generate_all_chart_tables(dcn_slack_analysis): | ||
"""Converts a DcnSlackAnalysis proto to gviz DataTables.""" | ||
return [ | ||
generate_dcn_collective_stats_table(dcn_slack_analysis), | ||
] | ||
|
||
|
||
def to_json(raw_data): | ||
"""Converts a serialized DcnCollectiveAnalysis string to json.""" | ||
dcn_slack_analysis = dcn_slack_analysis_pb2.DcnSlackAnalysis() | ||
dcn_slack_analysis.ParseFromString(raw_data) | ||
all_chart_tables = generate_all_chart_tables(dcn_slack_analysis) | ||
json_join = ",".join(x.ToJSon() if x else "{}" for x in all_chart_tables) | ||
return "[" + json_join + "]" |
149 changes: 149 additions & 0 deletions
149
plugin/tensorboard_plugin_profile/convert/dcn_collective_stats_proto_to_gviz_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Tests for dcn_collective_stats_proto_to_gviz.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import csv | ||
import enum | ||
import io | ||
|
||
import gviz_api | ||
import tensorflow as tf | ||
|
||
from tensorboard_plugin_profile.convert import dcn_collective_stats_proto_to_gviz | ||
from tensorboard_plugin_profile.protobuf import dcn_slack_analysis_pb2 | ||
|
||
|
||
class StrEnum(str, enum.Enum): | ||
pass | ||
|
||
|
||
class MockValues(StrEnum): | ||
DCN_COLLECTIVE_NAME = "collective-1" | ||
RECV_OP_NAME = "recv-done" | ||
SEND_OP_NAME = "send" | ||
SLACK_US = 2 | ||
OBSERVED_DURATION_US = 12 | ||
STALL_DURATION_MS = 5 | ||
OCCURRENCES = 6 | ||
BYTES_TRANSMITTED_OVER_NETWORK = 8192000 | ||
BANDWIDTH = 0.512 | ||
|
||
|
||
class ProtoToGvizTest(tf.test.TestCase): | ||
|
||
def create_empty_dcn_slack_analysis(self): | ||
return dcn_slack_analysis_pb2.DcnSlackAnalysis() | ||
|
||
def create_mock_dcn_slack_summary(self): | ||
dcn_slack_summary = dcn_slack_analysis_pb2.DcnSlackSummary( | ||
rendezvous=MockValues.DCN_COLLECTIVE_NAME, | ||
recv_op_name=MockValues.RECV_OP_NAME, | ||
send_op_name=MockValues.SEND_OP_NAME, | ||
slack_us=int(MockValues.SLACK_US) * 1000, | ||
observed_duration_us=int(MockValues.OBSERVED_DURATION_US) * 1000, | ||
stall_duration_us=int(MockValues.STALL_DURATION_MS) * 1000, | ||
occurrences=int(MockValues.OCCURRENCES), | ||
bytes_transmitted_over_network=int( | ||
MockValues.BYTES_TRANSMITTED_OVER_NETWORK | ||
) | ||
) | ||
return dcn_slack_summary | ||
|
||
def create_mock_dcn_slack_analysis(self): | ||
dcn_slack_analysis = dcn_slack_analysis_pb2.DcnSlackAnalysis() | ||
for _ in range(0, 3): | ||
dcn_slack_analysis.dcn_slack_summary.append( | ||
self.create_mock_dcn_slack_summary() | ||
) | ||
return dcn_slack_analysis | ||
|
||
def test_dcn_collective_stats_empty(self): | ||
dcn_slack_analysis = self.create_empty_dcn_slack_analysis() | ||
data_table = ( | ||
dcn_collective_stats_proto_to_gviz.generate_dcn_collective_stats_table( | ||
dcn_slack_analysis | ||
) | ||
) | ||
|
||
self.assertEqual(0, data_table.NumberOfRows()) | ||
self.assertLen(data_table.columns, 9) | ||
|
||
def test_dcn_collective_stats_table(self): | ||
dcn_slack_analysis = self.create_mock_dcn_slack_analysis() | ||
(table_description, data, custom_properties) = ( | ||
dcn_collective_stats_proto_to_gviz.get_dcn_collective_stats_table_args( | ||
dcn_slack_analysis | ||
) | ||
) | ||
data_table = gviz_api.DataTable(table_description, data, custom_properties) | ||
|
||
self.assertLen(data, 3) | ||
self.assertEqual(3, data_table.NumberOfRows()) | ||
self.assertLen(table_description, 9) | ||
self.assertLen(data_table.columns, 9) | ||
|
||
csv_file = io.StringIO(data_table.ToCsv()) | ||
reader = csv.reader(csv_file) | ||
|
||
expected = [ | ||
MockValues.DCN_COLLECTIVE_NAME, | ||
MockValues.RECV_OP_NAME, | ||
MockValues.SEND_OP_NAME, | ||
MockValues.SLACK_US, | ||
MockValues.OBSERVED_DURATION_US, | ||
MockValues.STALL_DURATION_MS, | ||
MockValues.OCCURRENCES, | ||
MockValues.BYTES_TRANSMITTED_OVER_NETWORK, | ||
MockValues.BANDWIDTH, | ||
] | ||
|
||
for rr, row_values in enumerate(reader): | ||
if rr == 0: | ||
# DataTable columns match schema defined in table_description. | ||
for cc, column_header in enumerate(row_values): | ||
self.assertEqual(table_description[cc][2], column_header) | ||
else: | ||
for cc, cell_str in enumerate(row_values): | ||
raw_value = data[rr - 1][cc] | ||
value_type = table_description[cc][1] | ||
|
||
# Only number and strings are used in the DataTable schema. | ||
self.assertIn(value_type, ["number", "string"]) | ||
|
||
# Encode in similar fashion as DataTable.ToCsv(). | ||
expected_value = gviz_api.DataTable.CoerceValue(raw_value, value_type) | ||
self.assertNotIsInstance(expected_value, tuple) | ||
self.assertEqual(expected_value, raw_value) | ||
|
||
# Check against expected values we have set in our mock table. | ||
if value_type == "string": | ||
self.assertEqual(expected[cc], cell_str) | ||
else: | ||
if ( | ||
expected[cc] == MockValues.OCCURRENCES | ||
or expected[cc] == MockValues.BYTES_TRANSMITTED_OVER_NETWORK | ||
): | ||
self.assertEqual(str(int(expected[cc])), cell_str) | ||
else: | ||
self.assertEqual(str(float(expected[cc])), cell_str) | ||
|
||
|
||
if __name__ == "__main__": | ||
tf.test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.