Skip to content

Commit e11d004

Browse files
Add input shapes validation to prevent heap overflow in Concat layer
1 parent b039b87 commit e11d004

File tree

6 files changed

+30
-12
lines changed

6 files changed

+30
-12
lines changed

src/layer/arm/concat_arm.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,13 @@ int Concat_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
197197
// total channels
198198
size_t elemsize = bottom_blobs[0].elemsize;
199199
int elempack = bottom_blobs[0].elempack;
200-
int top_channels = 0;
201-
for (size_t b = 0; b < bottom_blobs.size(); b++)
200+
int top_channels = bottom_blobs[0].c * bottom_blobs[0].elempack;
201+
for (size_t b = 1; b < bottom_blobs.size(); b++)
202202
{
203203
const Mat& bottom_blob = bottom_blobs[b];
204+
if (bottom_blob.w != w || bottom_blob.h != h || bottom_blob.d != d)
205+
return -100;
206+
204207
elemsize = std::min(elemsize, bottom_blob.elemsize);
205208
elempack = std::min(elempack, bottom_blob.elempack);
206209
top_channels += bottom_blob.c * bottom_blob.elempack;

src/layer/concat.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,13 @@ int Concat::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_
128128
int d = bottom_blobs[0].d;
129129

130130
// total channels
131-
int top_channels = 0;
132-
for (size_t b = 0; b < bottom_blobs.size(); b++)
131+
int top_channels = bottom_blobs[0].c;
132+
for (size_t b = 1; b < bottom_blobs.size(); b++)
133133
{
134134
const Mat& bottom_blob = bottom_blobs[b];
135+
if (bottom_blob.w != w || bottom_blob.h != h || bottom_blob.d != d)
136+
return -100;
137+
135138
top_channels += bottom_blob.c;
136139
}
137140

src/layer/loongarch/concat_loongarch.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,13 @@ int Concat_loongarch::forward(const std::vector<Mat>& bottom_blobs, std::vector<
176176
// total channels
177177
size_t elemsize = bottom_blobs[0].elemsize;
178178
int elempack = bottom_blobs[0].elempack;
179-
int top_channels = 0;
180-
for (size_t b = 0; b < bottom_blobs.size(); b++)
179+
int top_channels = bottom_blobs[0].c * bottom_blobs[0].elempack;
180+
for (size_t b = 1; b < bottom_blobs.size(); b++)
181181
{
182182
const Mat& bottom_blob = bottom_blobs[b];
183+
if (bottom_blob.w != w || bottom_blob.h != h || bottom_blob.d != d)
184+
return -100;
185+
183186
elemsize = std::min(elemsize, bottom_blob.elemsize);
184187
elempack = std::min(elempack, bottom_blob.elempack);
185188
top_channels += bottom_blob.c * bottom_blob.elempack;

src/layer/mips/concat_mips.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,13 @@ int Concat_mips::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
176176
// total channels
177177
size_t elemsize = bottom_blobs[0].elemsize;
178178
int elempack = bottom_blobs[0].elempack;
179-
int top_channels = 0;
180-
for (size_t b = 0; b < bottom_blobs.size(); b++)
179+
int top_channels = bottom_blobs[0].c * bottom_blobs[0].elempack;
180+
for (size_t b = 1; b < bottom_blobs.size(); b++)
181181
{
182182
const Mat& bottom_blob = bottom_blobs[b];
183+
if (bottom_blob.w != w || bottom_blob.h != h || bottom_blob.d != d)
184+
return -100;
185+
183186
elemsize = std::min(elemsize, bottom_blob.elemsize);
184187
elempack = std::min(elempack, bottom_blob.elempack);
185188
top_channels += bottom_blob.c * bottom_blob.elempack;

src/layer/riscv/concat_riscv.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,13 @@ int Concat_riscv::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>
218218
// total channels
219219
size_t elemsize = bottom_blobs[0].elemsize;
220220
int elempack = bottom_blobs[0].elempack;
221-
int top_channels = 0;
222-
for (size_t b = 0; b < bottom_blobs.size(); b++)
221+
int top_channels = bottom_blobs[0].c * bottom_blobs[0].elempack;
222+
for (size_t b = 1; b < bottom_blobs.size(); b++)
223223
{
224224
const Mat& bottom_blob = bottom_blobs[b];
225+
if (bottom_blob.w != w || bottom_blob.h != h || bottom_blob.d != d)
226+
return -100;
227+
225228
elemsize = std::min(elemsize, bottom_blob.elemsize);
226229
elempack = std::min(elempack, bottom_blob.elempack);
227230
top_channels += bottom_blob.c * bottom_blob.elempack;

src/layer/x86/concat_x86.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,10 +388,13 @@ int Concat_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
388388
// total channels
389389
size_t elemsize = bottom_blobs[0].elemsize;
390390
int elempack = bottom_blobs[0].elempack;
391-
int top_channels = 0;
392-
for (size_t b = 0; b < bottom_blobs.size(); b++)
391+
int top_channels = bottom_blobs[0].c * bottom_blobs[0].elempack;
392+
for (size_t b = 1; b < bottom_blobs.size(); b++)
393393
{
394394
const Mat& bottom_blob = bottom_blobs[b];
395+
if (bottom_blob.w != w || bottom_blob.h != h || bottom_blob.d != d)
396+
return -100;
397+
395398
elemsize = std::min(elemsize, bottom_blob.elemsize);
396399
elempack = std::min(elempack, bottom_blob.elempack);
397400
top_channels += bottom_blob.c * bottom_blob.elempack;

0 commit comments

Comments
 (0)