uae_changes
mskishan26 committed Oct 10, 2024
1 parent 45bd360 commit 4e1d828
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions uae_gen.py
@@ -10,8 +10,9 @@
from scipy.io import savemat, loadmat
from tqdm import tqdm
import os
import re

__version__ = '1.0.0'
__version__ = '1.0.2'

DISPLAY_TITLE = r"""
_
@@ -32,6 +33,18 @@
parser.add_argument('-V', '--version', action='version',
                    version=f'%(prog)s {__version__}')

def is_basename(x:str): return os.path.dirname(x) == ""

def uae(f_map):
    return torch.sum(f_map)/f_map.numel()

def hook_fn(module, input, output):
    global activation_outputs
    result = uae(output[0])
    activation_outputs.append(result)

activation_outputs = []

# The main function of this *ChRIS* plugin is denoted by this ``@chris_plugin`` "decorator."
# Some metadata about the plugin is specified here. There is more metadata specified in setup.py.
#
@@ -42,7 +55,7 @@
    category='',                 # ref. https://chrisstore.co/plugins
    min_memory_limit='100Mi',    # supported units: Mi, Gi
    min_cpu_limit='1000m',       # millicores, e.g. "1000m" = 1 CPU core
    min_gpu_limit=1              # set min_gpu_limit=1 to enable GPU
    min_gpu_limit=0              # set min_gpu_limit=1 to enable GPU
)
def main(options: Namespace, inputdir: Path, outputdir: Path):
"""
@@ -66,24 +79,16 @@ def main(options: Namespace, inputdir: Path, outputdir: Path):
    # adding a progress bar and parallelism.
    mapper = PathMapper.file_mapper(inputdir, outputdir, glob=options.pattern)
    for input_file, output_file in mapper:
        def is_basename(x:str): return os.path.dirname(x) == ""

        def uae(f_map):
            return torch.sum(f_map)/f_map.numel()

        def hook_fn(module, input, output):
            result = uae(output[0])
            activation_outputs.append(result)

        activation_outputs = []
        global activation_outputs

        patient_id = os.path.splitext(os.path.basename(input_file))[0]
        print(patient_id)
        patient_id = patient_id.split('_')[0]

        model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
        # print(model)
        hooks = []
        activation_outputs = []

        layers = [2,4,7,9,12,14,16,19,21,23,26,28,30]
        for layer in layers:
            hook = model.features[layer].register_forward_hook(hook_fn)
@@ -101,7 +106,10 @@ def hook_fn(module, input, output):
        model = model.to(device)

        mat = mat73.loadmat(input_file)
        mat = mat['tf_images_mat']
        mat_keys= mat.keys()
        var_pattern = re.compile(r'^tf_image')
        var_matches = [item for item in mat_keys if var_pattern.match(item)]
        mat = mat[var_matches[0]]

        epochs = mat.shape[0]
        channels = mat.shape[1]
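The refactor above moves the forward-hook helpers to module level so they are defined once rather than on every loop iteration, with the per-file reset handled by the activation_outputs = [] inside the loop. For reference, a minimal self-contained sketch of the same hook pattern follows; the random input tensor, model.eval(), torch.no_grad(), and the hook cleanup are illustrative assumptions rather than code from uae_gen.py:

    import torch

    activation_outputs = []

    def uae(f_map):
        # Mean activation of a feature map: sum of elements / element count
        return torch.sum(f_map) / f_map.numel()

    def hook_fn(module, input, output):
        # Forward hook: record the mean activation of the first sample's feature map
        activation_outputs.append(uae(output[0]))

    model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
    model.eval()

    # Same VGG16 feature-layer indices as in uae_gen.py
    layers = [2, 4, 7, 9, 12, 14, 16, 19, 21, 23, 26, 28, 30]
    hooks = [model.features[i].register_forward_hook(hook_fn) for i in layers]

    x = torch.randn(1, 3, 224, 224)  # hypothetical placeholder input
    with torch.no_grad():
        model(x)

    print([v.item() for v in activation_outputs])  # one scalar per hooked layer

    for h in hooks:
        h.remove()

Running this yields one scalar per hooked layer, which is the per-layer summary uae_gen.py accumulates for each input it forwards through the network.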
