@@ -321,3 +321,146 @@ def separate_set(selections,set_divisions = [0.5,0.5],IDs=None):
321
321
if not is_none and IDs [i ] not in prime_hasher :
322
322
prime_hasher [IDs [i ]] = j + 1
323
323
return selections_ids
324
+
325
+ from copy import deepcopy as copy
326
+
327
+ def recompute_selection_ratios (selection_ratios ,selection_limits ,N ):
328
+ new_selection_ratios = copy (selection_ratios )
329
+ assert (np .any (np .isinf (selection_limits )))
330
+ variable = [True for i in range (len (selection_ratios ))]
331
+
332
+ for i in range (len (selection_ratios )):
333
+ if selection_ratios [i ] * N > selection_limits [i ]:
334
+ new_selection_ratios [i ] = selection_limits [i ] / N
335
+ variable [i ] = False
336
+ else :
337
+ new_selection_ratios [i ] = selection_ratios [i ]
338
+ vsum = 0.0
339
+ nvsum = 0.0
340
+ for i in range (len (selection_ratios )):
341
+ if variable [i ]: vsum += new_selection_ratios [i ]
342
+ else : nvsum += new_selection_ratios [i ]
343
+ assert (nvsum < 1 )
344
+ for i in range (len (selection_ratios )):
345
+ if variable [i ]:
346
+ new_selection_ratios [i ] = \
347
+ (new_selection_ratios [i ] / vsum ) * (1 - nvsum )
348
+ return new_selection_ratios
349
+
350
def get_balanced_filename_list(test_variable, confounds_array,
        selection_ratios=None, selection_limits=None, value_ranges=None,
        output_selection_savepath=None, test_value_ranges=None,
        get_all_test_set=False, total_size_limit=None,
        verbose=False, non_confound_value_ranges=None, database=None,
        n_buckets=10, patient_id_key=None):
    """Select and split rows of ``database`` into balanced train/val/test sets.

    Rows are filtered by per-column value constraints, class-balanced on
    ``test_variable`` against the confound columns, and then divided into
    ``len(selection_ratios)`` splits.

    Parameters
    ----------
    test_variable : str
        Column of ``database`` used as the label to balance on.
    confounds_array : list of str
        Confound column names. NOTE: temporarily mutated in place (entries
        appended then removed) while the selection mask is built.
    selection_ratios : list of float, optional
        Fraction of selected rows per split; default [0.66, 0.16, 0.16].
    selection_limits : list of float, optional
        Absolute size cap per split; default all ``np.inf`` (uncapped).
    value_ranges : list, optional
        Per-confound constraint: a (lo, hi) tuple, a predicate callable, a
        collection of allowed values, or None for no constraint.
    output_selection_savepath : str, optional
        If the file exists the saved selection is loaded; otherwise the
        computed selection is saved there.
    test_value_ranges : optional
        Constraint / class set for ``test_variable`` itself.
    get_all_test_set : bool
        If True, every unselected row is assigned to the last split.
    total_size_limit : int, optional
        Hard cap on the total number of selected rows.
    verbose : bool
        Print diagnostic information.
    non_confound_value_ranges : dict, optional
        Extra {column: constraint} filters that are not treated as confounds.
    database : pandas.DataFrame
        The covariate table; its index values are the returned "filenames".
    n_buckets : int
        Number of buckets used when discretizing a continuous test variable.
    patient_id_key : str, optional
        Column whose value groups rows that must stay in the same split.

    Returns
    -------
    (X_files, Y_files) : tuple of lists
        One array of index values and one of (bucketized) labels per split.
    """
    # Mutable-default safety: resolve the documented defaults here instead
    # of sharing list/dict objects across calls. np.inf (not np.Inf, which
    # was removed in NumPy 2.0) marks an uncapped split.
    if selection_ratios is None:
        selection_ratios = [0.66, 0.16, 0.16]
    if selection_limits is None:
        selection_limits = [np.inf, np.inf, np.inf]
    if value_ranges is None:
        value_ranges = []
    if non_confound_value_ranges is None:
        non_confound_value_ranges = {}

    if len(value_ranges) == 0:
        value_ranges = [None for _ in confounds_array]
    assert (len(value_ranges) == len(confounds_array))

    covars_df = database
    if verbose: print("len(covars): %d" % len(covars_df))
    value_selection = np.ones((len(covars_df),), dtype=bool)
    # Temporarily extend confounds_array/value_ranges with the non-confound
    # constraints and the test variable itself; the extra entries are
    # deleted again once the selection mask has been built.
    for ncv in non_confound_value_ranges:
        if ncv in confounds_array:
            print("confounds_array: %s" % str(confounds_array))
            print("non_confound_value_ranges: %s" % \
                str(non_confound_value_ranges))
            print("ncv: %s" % str(ncv))
            assert (ncv not in confounds_array)
        confounds_array.append(ncv)
        value_ranges.append(non_confound_value_ranges[ncv])
    confounds_array.append(test_variable)
    value_ranges.append(test_value_ranges)
    if verbose: print("confounds_array: %s" % str(confounds_array))
    if verbose: print("value_ranges: %s" % str(value_ranges))
    for i in range(len(confounds_array)):
        temp_value_selection = np.zeros((len(covars_df),), dtype=bool)
        c = covars_df[confounds_array[i]]
        value_range = value_ranges[i]
        if value_range is None:
            # No constraint on this column.
            continue
        if isinstance(value_range, tuple):
            # (lo, hi): inclusive numeric range.
            for j in range(len(c)):
                if c[j] is None:
                    continue
                if c[j] >= value_range[0] and \
                    c[j] <= value_range[1]:
                    temp_value_selection[j] = True
        elif callable(value_range):
            # Predicate: keep rows for which it returns truthy.
            for j in range(len(c)):
                if c[j] is None:
                    continue
                if value_range(c[j]):
                    temp_value_selection[j] = True
        else:
            # Otherwise: an explicit collection of allowed values.
            for j in range(len(c)):
                if c[j] is None:
                    continue
                if c[j] in value_range:
                    temp_value_selection[j] = True
        value_selection = np.logical_and(value_selection,
            temp_value_selection)
    # Undo the temporary appends so the caller's lists are unchanged.
    del confounds_array[-1]
    del value_ranges[-1]
    for ncv in non_confound_value_ranges:
        del confounds_array[-1]
        del value_ranges[-1]
    if verbose:
        print("value_selection.sum(): %s" % str(value_selection.sum()))
    if verbose:
        print("value_selection.shape: %s" % str(value_selection.shape))
    covars_df = covars_df[value_selection]
    covars_df = covars_df.sample(frac=1)  # shuffle the surviving rows
    test_vars = covars_df[test_variable].to_numpy(dtype=np.dtype(object))
    # If it's a string array, it just returns strings
    test_vars = bucketize(test_vars, n_buckets)
    if output_selection_savepath is not None and \
        os.path.isfile(output_selection_savepath):
        # Reuse a previously computed (and saved) selection.
        selection = np.load(output_selection_savepath)
    else:
        if len(confounds_array) == 0:
            if verbose: print(test_value_ranges)
            selection = class_balance(test_vars, [],
                unique_classes=test_value_ranges, plim=0.1)
        else:
            selection = class_balance(test_vars,
                covars_df[confounds_array].to_numpy(\
                    dtype=np.dtype(object)).T,
                unique_classes=test_value_ranges, plim=0.1)
        selection_ratios = recompute_selection_ratios(selection_ratios,
            selection_limits, np.sum(selection))
        if total_size_limit is not None:
            # Deselect items from the front until under the total cap.
            select_sum = selection.sum()
            for i in range(len(selection)):
                if select_sum <= total_size_limit:
                    break
                if selection[i]:
                    selection[i] = 0
                    select_sum -= 1
        if patient_id_key is None:
            selection = separate_set(selection, selection_ratios)
        else:
            # Group by patient ID so one patient never straddles splits.
            selection = separate_set(selection, selection_ratios,
                covars_df[patient_id_key].to_numpy(dtype=\
                np.dtype(object)).T)
        if output_selection_savepath is not None:
            np.save(output_selection_savepath, selection)
    all_files = (covars_df.index.values)
    if get_all_test_set:
        # Assign every unselected row to the final split.
        selection[selection == 0] = 2
    X_files = [all_files[selection == i] \
        for i in range(1, len(selection_ratios) + 1)]
    Y_files = [test_vars[selection == i] \
        for i in range(1, len(selection_ratios) + 1)]
    if verbose: print(np.sum([len(x) for x in X_files]))
    # Shuffle each split, keeping filenames and labels in unison.
    for i in range(len(X_files)):
        rr = list(range(len(X_files[i])))
        random.shuffle(rr)
        X_files[i] = X_files[i][rr]
        Y_files[i] = Y_files[i][rr]
    return X_files, Y_files
+
0 commit comments