Skip to content

Commit

Permalink
Allow queues to be named for easier debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Nov 28, 2018
1 parent d403eac commit 4cd1967
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
9 changes: 7 additions & 2 deletions montblanc/rime/queue_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TensorQueue(object):
A Queue of tensors.
"""

def __init__(self, dtypes, shapes=None, shared_name=None):
def __init__(self, dtypes, shapes=None, name=None, shared_name=None):
"""
Constructs a simple queue accepting ``put`` operations
of tensors with the specified ``dtypes`` and ``shapes``.
Expand Down Expand Up @@ -47,11 +47,16 @@ def __init__(self, dtypes, shapes=None, shared_name=None):
A nested collection of dicts or tuples
containing shapes associated with ``dtypes``.
Must have the same structure as ``dtypes``
name : str, optional
Queue name
shared_name : str, optional
Shared resource name if this Queue is to be
shared amongst multiple tensorflow Sesssions.
"""
with ops.name_scope("tensor_queue") as scope:
if name is None:
name = "tensor_queue"

with ops.name_scope(name) as scope:
flat_dtypes = nest.flatten(dtypes)

if shapes is None:
Expand Down
9 changes: 6 additions & 3 deletions montblanc/rime/rime_ops/simple_queue_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ class QueueResource : public ResourceBase

DataTypeVector dtypes_;
std::vector<PartialTensorShape> shapes_;
std::string name_;

public:
public:
explicit QueueResource(const DataTypeVector & dtypes,
const std::vector<PartialTensorShape> & shapes)
: dtypes_(dtypes), shapes_(shapes), closed_(false)
const std::vector<PartialTensorShape> & shapes,
const std::string & name)
: dtypes_(dtypes), shapes_(shapes), name_(name), closed_(false)
{
// printf("Creating QueueResource %p\n", (void *) this);
}
Expand Down Expand Up @@ -86,6 +88,7 @@ class QueueResource : public ResourceBase
if(closed_)
{ return errors::OutOfRange("Queue is closed"); }

// No registered queues, push it on the stash
if(queues.size() == 0)
{ stash.push_back(data); }
else
Expand Down Expand Up @@ -238,7 +241,7 @@ class DatasetQueueHandleOp : public OpKernel
cinfo.container(), cinfo.name(), &queue_resource,
[this, ctx](QueueResource ** result) EXCLUSIVE_LOCKS_REQUIRED(mu_)
{
*result = new QueueResource(dtypes_, shapes_);
*result = new QueueResource(dtypes_, shapes_, cinfo.name());
return Status::OK();
}
));
Expand Down
3 changes: 2 additions & 1 deletion montblanc/rime/tensorflow_mock_analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@ def tensor_queue(ds_name, ds_ph, dtypes, shapes):
"""
Creates TensorQueue dataset
"""
tensor_queue = TensorQueue(dtypes, shapes)

tensor_queue = TensorQueue(dtypes, shapes, name=ds_name)
tensor_dataset = QueueDataset(tensor_queue, name=ds_name)
put = tensor_queue.put(ds_ph)
close = tensor_queue.close()
Expand Down

0 comments on commit 4cd1967

Please sign in to comment.