forked from PolyAI-LDN/conversational-datasets
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tfrutil.py
118 lines (88 loc) · 3.22 KB
/
tfrutil.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# -*- coding: utf-8 -*-
"""Command line utilities for manipulating tfrecords files.

Usage:
  To count the number of examples in a tfrecord file:
    python tfrutil.py size train-00999-of-01000.tfrecords

  To pretty print the contents of a tfrecord file:
    python tfrutil.py pp train-00999-of-01000.tfrecords

Both commands accept gs:// file paths as well as local files.
"""
import codecs
import sys
import click
import six
import tensorflow as tf
@click.group()
def _cli():
    """Command line utilities for manipulating tfrecords files."""
    # click renders this docstring as the group's --help text, so it is
    # user-facing output (the previous text misspelled "manipulating").
@_cli.command(name="size")
@click.argument("path", type=str, required=True, nargs=1)
def _size(path):
    """Compute the number of examples in the input tfrecord file."""
    # The docstring above doubles as click's help text; keep it stable.
    # Records are streamed one at a time; only their count is needed.
    record_count = sum(1 for _ in tf.python_io.tf_record_iterator(path))
    print(record_count)
@_cli.command(name="pp")
@click.argument("path", type=str, required=True, nargs=1)
def _pretty_print(path):
    """Format and print the contents of the tfrecord file to stdout."""
    # The docstring above doubles as click's help text; keep it stable.
    records = tf.python_io.tf_record_iterator(path)
    for index, serialized in enumerate(records):
        # Each record is a serialized tf.train.Example protobuf.
        parsed = tf.train.Example()
        parsed.ParseFromString(serialized)
        print("Example %i\n--------" % index)
        _pretty_print_example(parsed)
        print("--------\n\n")
def _pretty_print_example(example):
    """Print one example: context, response, extra contexts, then the rest.

    Raises whatever ``_get_string_feature`` raises if the ``context`` or
    ``response`` feature cannot be read.
    """
    # Context is fetched and printed before response is read, as before.
    for label, key in (("Context", "context"), ("Response", "response")):
        _print_field(label, _get_string_feature(example, key))
    _print_extra_contexts(example)
    _print_other_features(example)
def _print_field(name, content, indent=False):
    """Print ``[name]:`` followed by a tab-indented line holding ``content``.

    Newlines inside ``content`` are rendered as a literal ``\\n`` followed by
    a space, so each value occupies a single output line. With ``indent``
    set, both lines get one extra leading tab.
    """
    prefix = "\t" if indent else ""
    escaped = content.replace("\n", "\\n ")
    header = "%s[%s]:" % (prefix, name)
    body = "%s\t%s" % (prefix, escaped)
    print(header)
    print(body)
def _get_string_feature(example, feature_name):
    """Return the first bytes value of a feature, decoded as UTF-8.

    Args:
        example: a ``tf.train.Example`` (anything exposing
            ``example.features.feature`` as a mapping of Feature objects).
        feature_name: name of the feature to read.

    Returns:
        The first entry of the feature's ``bytes_list`` as text.

    Raises:
        IndexError: if the feature is absent or holds no bytes values.
            Callers catch IndexError to detect a missing feature.
        UnicodeDecodeError: if the bytes are not valid UTF-8.
    """
    features = example.features.feature
    # Test membership first. Indexing a protobuf message-valued map with a
    # missing key inserts an empty entry, which would silently mutate the
    # caller's example on every failed lookup.
    if feature_name not in features:
        raise IndexError("feature %r not found" % feature_name)
    values = features[feature_name].bytes_list.value
    if not values:
        raise IndexError("feature %r has no bytes values" % feature_name)
    return values[0].decode("utf-8")
def _print_extra_contexts(example):
    """Print the numbered ``context/0``, ``context/1``, ... features.

    Probing stops at the first index whose feature cannot be read
    (signalled by IndexError). Nothing is printed when there are none.
    Otherwise the collected features are printed highest index first,
    under an "Extra Contexts" header.
    """
    found = []
    index = 0
    while True:
        key = "context/{}".format(index)
        try:
            found.append((key, _get_string_feature(example, key)))
        except IndexError:
            break
        index += 1
    if found:
        print("\nExtra Contexts:")
        for key, text in found[::-1]:
            _print_field(key, text, indent=True)
def _print_other_features(example):
    """Print every feature not already shown as a context or response.

    Which features exist depends on the dataset. Features are printed in
    name order, under an "Other features" header that is emitted only if at
    least one such feature exists. Each value is rendered by
    ``_format_feature_value``.
    """
    printed_header = False
    for feature_name, value in sorted(example.features.feature.items()):
        if (feature_name in {"context", "response"} or
                feature_name.startswith("context/")):
            continue
        if not printed_header:
            # Only print the header if there are other features in this
            # example.
            print("\nOther features:")
            printed_header = True
        _print_field(
            feature_name, _format_feature_value(value), indent=True)


def _format_feature_value(value):
    """Return a printable string for a single ``tf.train.Feature``.

    Bytes features yield their first value decoded as UTF-8; undecodable
    bytes are replaced with U+FFFD rather than aborting the whole listing.
    int64 and float features yield all their values joined by ", ".
    A feature with no values yields an empty string.
    """
    # Reading only bytes_list.value[0] used to raise IndexError for any
    # numeric or empty feature, crashing the pretty-printer mid-example.
    if value.bytes_list.value:
        return value.bytes_list.value[0].decode("utf-8", "replace")
    numbers = list(value.int64_list.value) or list(value.float_list.value)
    return ", ".join(str(number) for number in numbers)
if __name__ == "__main__":
    # Python 2's sys.stdout is byte-oriented; wrap it in a UTF-8 writer so
    # printing decoded (unicode) feature text does not fail on non-ASCII.
    if six.PY2:
        utf8_writer = codecs.getwriter("utf8")
        sys.stdout = utf8_writer(sys.stdout)
    _cli()