diff --git a/montblanc/rime/queue_dataset.py b/montblanc/rime/queue_dataset.py index 04e094e72..44f81feff 100644 --- a/montblanc/rime/queue_dataset.py +++ b/montblanc/rime/queue_dataset.py @@ -16,7 +16,7 @@ class TensorQueue(object): A Queue of tensors. """ - def __init__(self, dtypes, shapes=None, shared_name=None): + def __init__(self, dtypes, shapes=None, name=None, shared_name=None): """ Constructs a simple queue accepting ``put`` operations of tensors with the specified ``dtypes`` and ``shapes``. @@ -47,11 +47,16 @@ def __init__(self, dtypes, shapes=None, shared_name=None): A nested collection of dicts or tuples containing shapes associated with ``dtypes``. Must have the same structure as ``dtypes`` + name : str, optional + Queue name shared_name : str, optional Shared resource name if this Queue is to be shared amongst multiple tensorflow Sesssions. """ - with ops.name_scope("tensor_queue") as scope: + if name is None: + name = "tensor_queue" + + with ops.name_scope(name) as scope: flat_dtypes = nest.flatten(dtypes) if shapes is None: diff --git a/montblanc/rime/rime_ops/simple_queue_dataset.cpp b/montblanc/rime/rime_ops/simple_queue_dataset.cpp index 9faff76da..ae7d21889 100644 --- a/montblanc/rime/rime_ops/simple_queue_dataset.cpp +++ b/montblanc/rime/rime_ops/simple_queue_dataset.cpp @@ -33,12 +33,14 @@ class QueueResource : public ResourceBase DataTypeVector dtypes_; std::vector shapes_; + std::string name_; public: public: explicit QueueResource(const DataTypeVector & dtypes, - const std::vector & shapes) - : dtypes_(dtypes), shapes_(shapes), closed_(false) + const std::vector & shapes, + const std::string & name) + : dtypes_(dtypes), shapes_(shapes), name_(name), closed_(false) { // printf("Creating QueueResource %p\n", (void *) this); } @@ -86,6 +88,7 @@ class QueueResource : public ResourceBase if(closed_) { return errors::OutOfRange("Queue is closed"); } + // No registered queues, push it on the stash if(queues.size() == 0) { stash.push_back(data); } else @@ -238,7 +241,7 @@ class DatasetQueueHandleOp : public OpKernel cinfo.container(), cinfo.name(), &queue_resource, [this, ctx](QueueResource ** result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *result = new QueueResource(dtypes_, shapes_); + *result = new QueueResource(dtypes_, shapes_, cinfo.name()); return Status::OK(); } ));