diff --git a/src/net.cpp b/src/net.cpp index 9aafee03e4ee..58863db74812 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -2206,6 +2206,14 @@ int Extractor::input(int blob_index, const Mat& in) if (blob_index < 0 || blob_index >= (int)d->blob_mats.size()) return -1; + const Mat& shape = d->net->blobs()[blob_index].shape; + if (shape.total()) + { + const Mat& in_shape = in.shape(); + if (shape.dims != in_shape.dims || shape.w != in_shape.w || shape.h != in_shape.h || shape.d != in_shape.d || shape.c != in_shape.c) + return -1; + } + d->blob_mats[blob_index] = in; return 0;