Skip to content

Commit 02ab737

Browse files
committed
enhancements to random number generation and accommodates more or larger bootstrap resamples
1 parent 8c5d9b1 commit 02ab737

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

src/boot.cpp

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void mexFunction (int nlhs, mxArray* plhs[],
103103
mexErrMsgTxt ("The first input argument (N or X) must be of type double.");
104104
}
105105
// Second input argument (nboot)
106-
const int nboot = *(mxGetPr (prhs[1])); // 32-bit int
106+
const int nboot = static_cast<const int> ( *(mxGetPr (prhs[1])) ); // 32-bit int
107107
if ( mxGetNumberOfElements (prhs[1]) > 1 ) {
108108
mexErrMsgTxt ("The second input argument (NBOOT) must be scalar.");
109109
}
@@ -125,7 +125,7 @@ void mexFunction (int nlhs, mxArray* plhs[],
125125
if (mxGetNumberOfElements (prhs[2]) > 1 || !mxIsClass (prhs[2], "logical")) {
126126
mexErrMsgTxt ("The third input argument (LOO) must be a logical scalar value.");
127127
}
128-
loo = *(mxGetLogicals (prhs[2]));
128+
loo = static_cast<bool> ( *(mxGetLogicals (prhs[2])) );
129129
} else {
130130
loo = false;
131131
}
@@ -138,11 +138,13 @@ void mexFunction (int nlhs, mxArray* plhs[],
138138
if ( !mxIsClass (prhs[3], "double") ) {
139139
mexErrMsgTxt ("The fourth input argument (SEED) must be of type double.");
140140
}
141-
seed = *(mxGetPr(prhs[3]));
141+
seed = static_cast<unsigned long int> ( *(mxGetPr(prhs[3])) );
142142
if ( !mxIsFinite (seed) ) {
143143
mexErrMsgTxt ("The fourth input argument (SEED) cannot be NaN or Inf.");
144144
}
145-
srand (seed);
145+
} else {
146+
random_device rd;
147+
seed = static_cast<unsigned int> ( rd () );
146148
}
147149
// Fifth input argument (w, weights)
148150
// Error checking is handled later (see below in 'Declare variables' section)
@@ -156,12 +158,11 @@ void mexFunction (int nlhs, mxArray* plhs[],
156158
mwSize dims[2] = {static_cast<mwSize>(n), static_cast<mwSize>(nboot)};
157159
plhs[0] = mxCreateNumericArray (2, dims,
158160
mxDOUBLE_CLASS,
159-
mxREAL); // Prepare array for bootstrap sample indices
160-
long long int N = n * nboot; // Total counts of all sample indices
161-
long long int k; // Variable to store random number
162-
long long int d; // Counter for cumulative sum calculations
163-
vector<long long int> c; // Counter for each of the sample indices
164-
c.reserve (n);
161+
mxREAL); // Prepare array for sample indices
162+
long long unsigned int N = n * nboot; // Total counts of all sample indices
163+
long long unsigned int k; // Variable to store random number
164+
long long unsigned int d; // Counter for cumulative sum calculation
165+
vector<long long int> c(n, nboot); // Counter for each of the sample indices
165166
if ( nrhs > 4 && !mxIsEmpty (prhs[4]) ) {
166167
// Assign user defined weights (counts)
167168
if ( !mxIsClass (prhs[4], "double") ) {
@@ -182,17 +183,12 @@ void mexFunction (int nlhs, mxArray* plhs[],
182183
if ( w[i] < 0 ) {
183184
mexErrMsgTxt ("The fifth input argument (WEIGHTS) must contain only positive integers.");
184185
}
185-
c.push_back (w[i]); // Set each element in c to the specified weight
186+
c[i] = w[i]; // Set each element in c to the specified weight
186187
s += c[i];
187188
}
188189
if ( s != N ) {
189190
mexErrMsgTxt ("The elements of WEIGHTS must sum to N * NBOOT.");
190191
}
191-
} else {
192-
// Assign weights (counts) for uniform sampling
193-
for ( int i = 0; i < n ; i++ ) {
194-
c.push_back (nboot); // Set each element in c to nboot
195-
}
196192
}
197193
long long int m = 0; // Counter for LOO sample index r
198194
int r = -1; // Sample index for LOO
@@ -201,8 +197,8 @@ void mexFunction (int nlhs, mxArray* plhs[],
201197
double *ptr = (double *) mxGetData(plhs[0]);
202198

203199
// Initialize pseudo-random number generator (Mersenne Twister 19937)
204-
mt19937 rng (rand ());
205-
uniform_int_distribution<int> distr (0, n - 1);
200+
mt19937_64 rng (seed);
201+
uniform_int_distribution<long long unsigned int> distr (0, n - 1);
206202

207203
// Perform balanced sampling
208204
for ( int b = 0; b < nboot ; b++ ) {
@@ -226,7 +222,7 @@ void mexFunction (int nlhs, mxArray* plhs[],
226222
loo = false;
227223
}
228224
}
229-
uniform_int_distribution<int> distk (0, N - m - 1);
225+
uniform_int_distribution<long long unsigned int> distk (0, N - m - 1);
230226
k = distk (rng);
231227
d = c[0];
232228
for ( int j = 0; j < n ; j++ ) {

0 commit comments

Comments
 (0)