diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c index 4a8738a7..e44e06ec 100644 --- a/lib/TH/generic/THTensor.c +++ b/lib/TH/generic/THTensor.c @@ -315,6 +315,42 @@ void THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *sizes) { THFree(expandedStrides); } + +void THTensor_(expandNd)(THTensor **rets, THTensor **ops, int count) { + for (int i = 0; i < count; ++i) { + THArgCheck(THTensor_(nDimension)(ops[i]) > 0, i, "can't expand empty tensor %d", i); + } + + long *op_sizes[count]; + long op_dims[count]; + + for (int i = 0; i < count; ++i) { + op_sizes[i] = ops[i]->size; + op_dims[i] = ops[i]->nDimension; + } + + THLongStorage *sizes = THLongStorage_new(); + char error_buffer[1024]; + int ret = THLongStorage_inferSizeN(sizes, + count, + op_sizes, + op_dims, + error_buffer, + 1024); + + if(ret != 0) { + THLongStorage_free(sizes); + THError(error_buffer); + return; + } + + for (int i = 0; i < count; ++i) { + THTensor_(expand)(rets[i], ops[i], sizes); + } + + THLongStorage_free(sizes); +} + void THTensor_(set)(THTensor *self, THTensor *src) { if(self != src) diff --git a/lib/TH/generic/THTensor.h b/lib/TH/generic/THTensor.h index 9a2417fb..9fb246c8 100644 --- a/lib/TH/generic/THTensor.h +++ b/lib/TH/generic/THTensor.h @@ -72,6 +72,7 @@ TH_API THTensor *THTensor_(newView)(THTensor *tensor, THLongStorage *size); TH_API THTensor *THTensor_(newExpand)(THTensor *tensor, THLongStorage *size); TH_API void THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *size); +TH_API void THTensor_(expandNd)(THTensor **rets, THTensor **ops, int count); TH_API void THTensor_(resize)(THTensor *tensor, THLongStorage *size, THLongStorage *stride); TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src);