From 4e1d828cacdf3be92fab23c7ba60a46d273bffc4 Mon Sep 17 00:00:00 2001
From: mskishan26
Date: Thu, 10 Oct 2024 12:22:07 -0400
Subject: [PATCH] uae_changes

---
 uae_gen.py | 36 ++++++++++++++++++++++--------------
 1 file changed, 22 insertions(+), 14 deletions(-)

diff --git a/uae_gen.py b/uae_gen.py
index 3aba168..0d4e8c8 100644
--- a/uae_gen.py
+++ b/uae_gen.py
@@ -10,8 +10,9 @@
 from scipy.io import savemat, loadmat
 from tqdm import tqdm
 import os
+import re
 
-__version__ = '1.0.0'
+__version__ = '1.0.2'
 
 DISPLAY_TITLE = r"""
        _
@@ -32,6 +33,18 @@
 parser.add_argument('-V', '--version', action='version',
                     version=f'%(prog)s {__version__}')
 
+def is_basename(x:str): return os.path.dirname(x) == ""
+
+def uae(f_map):
+    return torch.sum(f_map)/f_map.numel()
+
+def hook_fn(module, input, output):
+    global activation_outputs
+    result = uae(output[0])
+    activation_outputs.append(result)
+
+activation_outputs = []
+
 # The main function of this *ChRIS* plugin is denoted by this ``@chris_plugin`` "decorator."
 # Some metadata about the plugin is specified here. There is more metadata specified in setup.py.
 #
@@ -42,7 +55,7 @@
     category='',                 # ref. https://chrisstore.co/plugins
     min_memory_limit='100Mi',    # supported units: Mi, Gi
     min_cpu_limit='1000m',       # millicores, e.g. "1000m" = 1 CPU core
-    min_gpu_limit=1              # set min_gpu_limit=1 to enable GPU
+    min_gpu_limit=0              # set min_gpu_limit=1 to enable GPU
 )
 def main(options: Namespace, inputdir: Path, outputdir: Path):
     """
@@ -66,24 +79,16 @@ def main(options: Namespace, inputdir: Path, outputdir: Path):
     # adding a progress bar and parallelism.
     mapper = PathMapper.file_mapper(inputdir, outputdir, glob=options.pattern)
     for input_file, output_file in mapper:
-        def is_basename(x:str): return os.path.dirname(x) == ""
-
-        def uae(f_map):
-            return torch.sum(f_map)/f_map.numel()
-
-        def hook_fn(module, input, output):
-            result = uae(output[0])
-            activation_outputs.append(result)
-
-        activation_outputs = []
+        global activation_outputs
         patient_id = os.path.splitext(os.path.basename(input_file))[0]
-        print(patient_id)
         patient_id = patient_id.split('_')[0]
+
         model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
         # print(model)
         hooks = []
         activation_outputs = []
+
         layers = [2,4,7,9,12,14,16,19,21,23,26,28,30]
         for layer in layers:
             hook = model.features[layer].register_forward_hook(hook_fn)
@@ -101,7 +106,10 @@ def hook_fn(module, input, output):
         model = model.to(device)
 
         mat = mat73.loadmat(input_file)
-        mat = mat['tf_images_mat']
+        mat_keys= mat.keys()
+        var_pattern = re.compile(r'^tf_image')
+        var_matches = [item for item in mat_keys if var_pattern.match(item)]
+        mat = mat[var_matches[0]]
         epochs = mat.shape[0]
         channels = mat.shape[1]
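Note (not part of the patch): a minimal standalone sketch of the hook-based UAE computation that this change hoists to module level. The uae(), hook_fn() and layer indices mirror the diff above; the dummy 224x224 input and the final print are placeholders for illustration only, since the plugin instead feeds per-epoch images decoded from the tf_image* variable of the input .mat file.

import torch

activation_outputs = []

def uae(f_map):
    # mean activation over one feature map: sum / number of elements
    return torch.sum(f_map) / f_map.numel()

def hook_fn(module, input, output):
    # output has shape (N, C, H, W); output[0] is the feature map of the
    # single image in the batch
    activation_outputs.append(uae(output[0]))

model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
model.eval()

# layer indices inside vgg16.features, as listed in the diff
layers = [2, 4, 7, 9, 12, 14, 16, 19, 21, 23, 26, 28, 30]
hooks = [model.features[i].register_forward_hook(hook_fn) for i in layers]

with torch.no_grad():
    # placeholder input; the plugin builds its inputs from the .mat file instead
    _ = model(torch.randn(1, 3, 224, 224))

print([float(v) for v in activation_outputs])  # one UAE value per hooked layer

for h in hooks:
    h.remove()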