Skip to content

Commit 58e4811

Browse files
author
Edgar Solomonik
committed
add safeguarding against potential integer overflows in a few places, other minor corrections
1 parent 4bd532c commit 58e4811

File tree

5 files changed

+72
-48
lines changed

5 files changed

+72
-48
lines changed

examples/recursive_matmul.cxx

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ void recursive_matmul(int n,
2222
MPI_Comm_rank(pcomm, &rank);
2323
MPI_Comm_size(pcomm, &num_pes);
2424

25-
if (num_pes == 1){
25+
if (num_pes == 1 || m == 1 || n == 1 || k==1){
2626
C["ij"] += 1.0*A["ik"]*B["kj"];
2727
} else {
2828
for (div=2; num_pes%div!=0; div++){}

scalapack_tests/svd.cxx

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -30,33 +30,39 @@ bool svd(Matrix<dtype> A,
3030

3131
bool pass_orthogonality = true;
3232

33-
double nrm;
34-
E.norm2(nrm);
35-
if (nrm > m*n*1.E-6){
33+
double nrm1, nrm2, nrm3;
34+
E.norm2(nrm1);
35+
if (nrm1 > m*n*1.E-6){
3636
pass_orthogonality = false;
3737
}
3838

3939
E["ii"] = 1.;
4040

4141
E["ij"] -= VT["ik"]*conj<dtype>(VT)["jk"];
4242

43-
E.norm2(nrm);
44-
if (nrm > m*n*1.E-6){
43+
E.norm2(nrm2);
44+
if (nrm2 > m*n*1.E-6){
4545
pass_orthogonality = false;
4646
}
4747

4848
A["ij"] -= U["ik"]*S["k"]*VT["kj"];
4949

5050
bool pass_residual = true;
51-
A.norm2(nrm);
52-
if (nrm > m*n*n*1.E-6){
51+
A.norm2(nrm3);
52+
if (nrm3 > m*n*n*1.E-6){
5353
pass_residual = false;
5454
}
5555

5656
#ifndef TEST_SUITE
5757
if (dw.rank == 0){
5858
printf("SVD orthogonality check returned %d, residual check %d\n", pass_orthogonality, pass_residual);
5959
}
60+
#else
61+
if (!pass_residual || ! pass_orthogonality){
62+
if (dw.rank == 0){
63+
printf("SVD orthogonality check returned %d (%lf, %lf), residual check %d (%lf)\n", pass_orthogonality, nrm1, nrm2, pass_residual, nrm3);
64+
}
65+
}
6066
#endif
6167
return pass_residual & pass_orthogonality;
6268
}

src/contraction/contraction.cxx

Lines changed: 28 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -2475,14 +2475,12 @@ namespace CTF_int {
24752475
#else
24762476
for (int t=global_comm.rank+1; t<(int)wrld->topovec.size()+8; t+=global_comm.np){
24772477
#endif
2478-
TAU_FSTART(evaluate_mappings_clear_and_init);
24792478
A->clear_mapping();
24802479
B->clear_mapping();
24812480
C->clear_mapping();
24822481
A->set_padding();
24832482
B->set_padding();
24842483
C->set_padding();
2485-
TAU_FSTOP(evaluate_mappings_clear_and_init);
24862484

24872485
topology * topo_i = NULL;
24882486
if (t < 8){
@@ -2511,9 +2509,7 @@ namespace CTF_int {
25112509
} else topo_i = wrld->topovec[t-8];
25122510
ASSERT(topo_i != NULL);
25132511

2514-
TAU_FSTART(map_ctr_to_topo);
25152512
ret = map_to_topology(topo_i, j);
2516-
TAU_FSTOP(map_ctr_to_topo);
25172513

25182514
if (ret == NEGATIVE){
25192515
//printf("map_to_topology returned negative\n");
@@ -2527,18 +2523,13 @@ namespace CTF_int {
25272523
B->topo = topo_i;
25282524
C->topo = topo_i;
25292525

2530-
TAU_FSTART(check_ctr_mapping);
25312526
if (check_mapping() == 0){
2532-
TAU_FSTOP(check_ctr_mapping);
25332527
continue;
25342528
}
2535-
TAU_FSTOP(check_ctr_mapping);
25362529
est_time = 0.0;
2537-
TAU_FSTART(evaluate_mappings_set_padding2);
25382530
A->set_padding();
25392531
B->set_padding();
25402532
C->set_padding();
2541-
TAU_FSTOP(evaluate_mappings_set_padding2);
25422533
#if DEBUG >= 3
25432534
if (global_comm.rank == 0){
25442535
printf("\nTest mappings:\n");
@@ -2565,8 +2556,14 @@ namespace CTF_int {
25652556
}
25662557
nnz_frac_C = std::min(1.,std::max(nnz_frac_C,nnz_frac_A*nnz_frac_B*len_ctr));
25672558
}
2559+
// check this early on to avoid 64-bit integer overflow
2560+
double size_memuse = A->size*nnz_frac_A*A->sr->el_size + B->size*nnz_frac_B*B->sr->el_size + C->size*nnz_frac_C*C->sr->el_size;
2561+
if (size_memuse >= (double)max_memuse){
2562+
if (global_comm.rank == 0)
2563+
DPRINTF(1,"Not enough memory available for topo %d with order %d to store tensors %ld/%ld\n", t,j,(int64_t)size_memuse,max_memuse);
2564+
continue;
2565+
}
25682566

2569-
TAU_FSTART(evaluate_mappings_folding);
25702567
#if FOLD_TSR
25712568
if (can_fold()){
25722569
est_time = est_time_fold();
@@ -2589,7 +2586,6 @@ namespace CTF_int {
25892586
est_time = sctr->est_time_rec(sctr->num_lyr);
25902587
}
25912588
}
2592-
TAU_FSTOP(evaluate_mappings_folding);
25932589
#if DEBUG >= 3
25942590
if (global_comm.rank == 0){
25952591
printf("mapping passed contr est_time = %E sec\n", est_time);
@@ -2600,7 +2596,6 @@ namespace CTF_int {
26002596
need_remap_A = 0;
26012597
need_remap_B = 0;
26022598
need_remap_C = 0;
2603-
TAU_FSTART(evaluate_mappings_comp_maps);
26042599
if (topo_i == old_topo_A){
26052600
for (d=0; d<A->order; d++){
26062601
if (!comp_dim_map(&A->edge_map[d],&old_map_A[d]))
@@ -2631,8 +2626,6 @@ namespace CTF_int {
26312626
}
26322627
} else
26332628
need_remap_C = 1;
2634-
TAU_FSTOP(evaluate_mappings_comp_maps);
2635-
TAU_FSTART(est_ctr_map_time);
26362629
if (need_remap_C) {
26372630
est_time += 2.*C->est_redist_time(*dC, nnz_frac_C);
26382631
memuse = std::max(1.0*memuse,2.*C->get_redist_mem(*dC, nnz_frac_C));
@@ -2646,20 +2639,16 @@ namespace CTF_int {
26462639
printf("total (with redistribution and transp) est_time = %E\n", est_time);
26472640
}
26482641
#endif
2649-
TAU_FSTOP(est_ctr_map_time);
26502642
ASSERT(est_time >= 0.0);
26512643

2652-
TAU_FSTART(get_avail_res);
26532644
if ((int64_t)memuse >= max_memuse){
26542645
if (global_comm.rank == 0)
26552646
DPRINTF(1,"Not enough memory available for topo %d with order %d memory %ld/%ld\n", t,j,memuse,max_memuse);
2656-
TAU_FSTOP(get_avail_res);
26572647
delete sctr;
26582648
continue;
26592649
}
26602650
if ((!A->is_sparse && A->size > INT_MAX) ||(!B->is_sparse && B->size > INT_MAX) || (!C->is_sparse && C->size > INT_MAX)){
26612651
DPRINTF(1,"MPI does not handle enough bits for topo %d with order\n", j);
2662-
TAU_FSTOP(get_avail_res);
26632652
delete sctr;
26642653
continue;
26652654
}
@@ -2670,7 +2659,6 @@ namespace CTF_int {
26702659
btopo = 6*t+j;
26712660
}
26722661
delete sctr;
2673-
TAU_FSTOP(get_avail_res);
26742662
}
26752663
}
26762664
TAU_FSTOP(evaluate_mappings)
@@ -3509,7 +3497,27 @@ namespace CTF_int {
35093497
}
35103498
#endif
35113499

3512-
3500+
if (blk_sz_A < vrt_sz_A){
3501+
printf("blk_sz_A = %ld, vrt_sz_A = %ld\n", blk_sz_A, vrt_sz_A);
3502+
printf("sizes are %ld %ld %ld\n",A->size,B->size,C->size);
3503+
A->print_map(stdout, 0);
3504+
B->print_map(stdout, 0);
3505+
C->print_map(stdout, 0);
3506+
}
3507+
if (blk_sz_B < vrt_sz_B){
3508+
printf("blk_sz_B = %ld, vrt_sz_B = %ld\n", blk_sz_B, vrt_sz_B);
3509+
printf("sizes are %ld %ld %ld\n",A->size,B->size,C->size);
3510+
A->print_map(stdout, 0);
3511+
B->print_map(stdout, 0);
3512+
C->print_map(stdout, 0);
3513+
}
3514+
if (blk_sz_C < vrt_sz_C){
3515+
printf("blk_sz_C = %ld, vrt_sz_C = %ld\n", blk_sz_C, vrt_sz_C);
3516+
printf("sizes are %ld %ld %ld\n",A->size,B->size,C->size);
3517+
A->print_map(stdout, 0);
3518+
B->print_map(stdout, 0);
3519+
C->print_map(stdout, 0);
3520+
}
35133521
ASSERT(blk_sz_A >= vrt_sz_A);
35143522
ASSERT(blk_sz_B >= vrt_sz_B);
35153523
ASSERT(blk_sz_C >= vrt_sz_C);

src/contraction/ctr_2d_general.cxx

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ namespace CTF_int {
4747
int & load_phase_C){
4848
mapping * map;
4949
int j;
50-
int nstep = 1;
50+
int64_t nstep = 1;
5151
if (comp_dim_map(&C->edge_map[i_C], &B->edge_map[i_B])){
5252
map = &B->edge_map[i_B];
5353
while (map->has_child) map = map->child;
@@ -111,16 +111,24 @@ namespace CTF_int {
111111
/virt_blk_len_C[j];
112112
}
113113
if (B->edge_map[i_B].type != PHYSICAL_MAP){
114+
if (blk_sz_B / nstep == 0)
115+
printf("blk_len_B[%d] = %d, nstep = %ld blk_sz_B = %ld\n",i_B,blk_len_B[i_B],nstep,blk_sz_B);
114116
blk_sz_B = blk_sz_B / nstep;
115117
blk_len_B[i_B] = blk_len_B[i_B] / nstep;
116118
} else {
119+
if (blk_sz_B * B->edge_map[i_B].np/ nstep == 0)
120+
printf("blk_len_B[%d] = %d B->edge_map[%d].np = %d, nstep = %ld blk_sz_B = %ld\n",i_B,blk_len_B[i_B],i_B,B->edge_map[i_B].np,nstep,blk_sz_B);
117121
blk_sz_B = blk_sz_B * B->edge_map[i_B].np / nstep;
118122
blk_len_B[i_B] = blk_len_B[i_B] * B->edge_map[i_B].np / nstep;
119123
}
120124
if (C->edge_map[i_C].type != PHYSICAL_MAP){
125+
if (blk_sz_C / nstep == 0)
126+
printf("blk_len_C[%d] = %d, nstep = %ld blk_sz_C = %ld\n",i_C,blk_len_C[i_C],nstep,blk_sz_C);
121127
blk_sz_C = blk_sz_C / nstep;
122128
blk_len_C[i_C] = blk_len_C[i_C] / nstep;
123129
} else {
130+
if (blk_sz_C * C->edge_map[i_C].np/ nstep == 0)
131+
printf("blk_len_C[%d] = %d C->edge_map[%d].np = %d, nstep = %ld blk_sz_C = %ld\n",i_C,blk_len_C[i_C],i_C,C->edge_map[i_C].np,nstep,blk_sz_C);
124132
blk_sz_C = blk_sz_C * C->edge_map[i_C].np / nstep;
125133
blk_len_C[i_C] = blk_len_C[i_C] * C->edge_map[i_C].np / nstep;
126134
}

src/interface/matrix.cxx

Lines changed: 21 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -213,8 +213,8 @@ namespace CTF {
213213
if (lda == nrow/pc){
214214
memcpy(this->data, (char*)data_, sizeof(dtype)*this->size);
215215
} else {
216-
for (int i=0; i<ncol/pc; i++){
217-
memcpy(this->data+i*lda*sizeof(dtype),(char*)(data_+i*lda), nrow*sizeof(dtype)/pr);
216+
for (int64_t i=0; i<ncol/pc; i++){
217+
memcpy(this->data+i*lda*sizeof(dtype),(char*)(data_+i*lda), ((int64_t)nrow)*sizeof(dtype)/pr);
218218
}
219219
}
220220
} else {
@@ -258,7 +258,7 @@ namespace CTF {
258258
if (lda == nrow/pc){
259259
memcpy((char*)data_, this->data, sizeof(dtype)*this->size);
260260
} else {
261-
for (int i=0; i<ncol/pc; i++){
261+
for (int64_t i=0; i<ncol/pc; i++){
262262
memcpy((char*)(data_+i*lda), this->data+i*lda*sizeof(dtype), nrow*sizeof(dtype)/pr);
263263
}
264264
}
@@ -437,11 +437,11 @@ namespace CTF {
437437

438438
this->read_mat(desca, A);
439439

440-
dtype * tau = (dtype*)malloc(n*sizeof(dtype));
440+
dtype * tau = (dtype*)malloc(((int64_t)n)*sizeof(dtype));
441441
dtype dlwork;
442442
CTF_SCALAPACK::pgeqrf<dtype>(m,n,A,1,1,desca,tau,(dtype*)&dlwork,-1,&info);
443443
int lwork = get_int_fromreal<dtype>(dlwork);
444-
dtype * work = (dtype*)malloc(lwork*sizeof(dtype));
444+
dtype * work = (dtype*)malloc(((int64_t)lwork)*sizeof(dtype));
445445
CTF_SCALAPACK::pgeqrf<dtype>(m,n,A,1,1,desca,tau,work,lwork,&info);
446446

447447

@@ -453,14 +453,14 @@ namespace CTF {
453453
R = Matrix<dtype>(Q);
454454
else {
455455
R = Matrix<dtype>(desca,dQ,*this->wrld,*this->sr);
456-
R = R.slice(0,m*(n-1)+n-1);
456+
R = R.slice(0,((int64_t)m)*(n-1)+n-1);
457457
}
458458

459459

460460
free(work);
461461
CTF_SCALAPACK::porgqr<dtype>(m,n,n,dQ,1,1,desca,tau,(dtype*)&dlwork,-1,&info);
462462
lwork = get_int_fromreal<dtype>(dlwork);
463-
work = (dtype*)malloc(lwork*sizeof(dtype));
463+
work = (dtype*)malloc(((int64_t)lwork)*sizeof(dtype));
464464
CTF_SCALAPACK::porgqr<dtype>(m,n,n,dQ,1,1,desca,tau,work,lwork,&info);
465465
Q = Matrix<dtype>(desca, dQ, (*(this->wrld)));
466466
free(work);
@@ -502,22 +502,24 @@ namespace CTF {
502502
int64_t kpr = k/pr + (k % pr != 0);
503503
int64_t kpc = k/pc + (k % pc != 0);
504504
int64_t npc = n/pc + (n % pc != 0);
505+
505506
CTF_SCALAPACK::cdescinit(descu, m, k, 1, 1, 0, 0, ictxt, mpr, &info);
506507
CTF_SCALAPACK::cdescinit(descvt, k, n, 1, 1, 0, 0, ictxt, kpr, &info);
507-
dtype * A = (dtype*)malloc(this->size*sizeof(dtype));
508+
509+
dtype * A = (dtype*)CTF_int::alloc(this->size*sizeof(dtype));
508510

509511

510-
dtype * u = (dtype*)new dtype[mpr*kpc];
511-
dtype * s = (dtype*)new dtype[k];
512-
dtype * vt = (dtype*)new dtype[kpr*npc];
512+
dtype * u = (dtype*)CTF_int::alloc(sizeof(dtype)*mpr*kpc);
513+
dtype * s = (dtype*)CTF_int::alloc(sizeof(dtype)*k);
514+
dtype * vt = (dtype*)CTF_int::alloc(sizeof(dtype)*kpr*npc);
513515
this->read_mat(desca, A);
514516

515517
int lwork;
516518
dtype dlwork;
517519
CTF_SCALAPACK::pgesvd<dtype>('V', 'V', m, n, NULL, 1, 1, desca, NULL, NULL, 1, 1, descu, vt, 1, 1, descvt, &dlwork, -1, &info);
518520

519521
lwork = get_int_fromreal<dtype>(dlwork);
520-
dtype * work = (dtype*)malloc(sizeof(dtype)*lwork);
522+
dtype * work = (dtype*)CTF_int::alloc(sizeof(dtype)*((int64_t)lwork));
521523

522524
CTF_SCALAPACK::pgesvd<dtype>('V', 'V', m, n, A, 1, 1, desca, s, u, 1, 1, descu, vt, 1, 1, descvt, work, lwork, &info);
523525

@@ -537,18 +539,18 @@ namespace CTF {
537539
}
538540
if (rank > 0 && rank < k) {
539541
S = S.slice(0, rank-1);
540-
U = U.slice(0, rank*(m)-1);
541-
VT = VT.slice(0, k*n-(k-rank+1));
542+
U = U.slice(0, rank*((int64_t)m)-1);
543+
VT = VT.slice(0, k*((int64_t)n)-(k-rank+1));
542544
}
543545

544-
free(A);
545-
delete [] u;
546-
delete [] s;
547-
delete [] vt;
546+
CTF_int::cdealloc(A);
547+
CTF_int::cdealloc(u);
548+
CTF_int::cdealloc(s);
549+
CTF_int::cdealloc(vt);
548550
free(desca);
549551
free(descu);
550552
free(descvt);
551-
free(work);
553+
CTF_int::cdealloc(work);
552554

553555
}
554556

0 commit comments

Comments
 (0)