Skip to content

Commit

Permalink
Changed uae_gen file
Browse files Browse the repository at this point in the history
  • Loading branch information
mskishan26 committed Oct 9, 2024
1 parent a22e05b commit 3fd433f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 26 deletions.
9 changes: 1 addition & 8 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,14 @@ ARG SRCDIR=/usr/local/src/pl-uae_gen
WORKDIR ${SRCDIR}

COPY requirements.txt .
RUN apt-get update && \
apt-get install -y \
git \
python3-pip \
python3-dev \
python3-opencv \
libglib2.0-0
RUN python3 -m pip install --upgrade pip
# RUN pip3 install torch -f https://download.pytorch.org/whl/cu111/torch_stable.html
RUN pip install -r requirements.txt

COPY . .
ARG extras_require=none
RUN pip install ".[${extras_require}]" \
&& cd / && rm -rf ${SRCDIR}

WORKDIR /

CMD ["uae_gen"]
46 changes: 28 additions & 18 deletions uae_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,12 @@
"""


parser = ArgumentParser(description='!!!CHANGE ME!!! An example ChRIS plugin which '
'counts the number of occurrences of a given '
'word in text files.',
parser = ArgumentParser(description='Find activation energy from multiple layers of VGGNet',
formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('-p', '--pattern', default='*0.txt', type=str,
parser.add_argument('-p', '--pattern', default='*0.mat', type=str,
help='input file filter glob')
parser.add_argument('-V', '--version', action='version',
version=f'%(prog)s {__version__}')
options = parser.parse_args()

def is_basename(x:str): return os.path.dirname(x) == ""

def uae(f_map):
return torch.sum(f_map)/f_map.numel()

def hook_fn(module, input, output):
result = uae(output[0])
activation_outputs.append(result)

activation_outputs = []

# The main function of this *ChRIS* plugin is denoted by this ``@chris_plugin`` "decorator."
# Some metadata about the plugin is specified here. There is more metadata specified in setup.py.
Expand All @@ -56,7 +42,7 @@ def hook_fn(module, input, output):
category='', # ref. https://chrisstore.co/plugins
min_memory_limit='100Mi', # supported units: Mi, Gi
min_cpu_limit='1000m', # millicores, e.g. "1000m" = 1 CPU core
min_gpu_limit=1 # set min_gpu_limit=1 to enable GPU
min_gpu_limit=0 # set min_gpu_limit=1 to enable GPU
)
def main(options: Namespace, inputdir: Path, outputdir: Path):
"""
Expand All @@ -80,8 +66,20 @@ def main(options: Namespace, inputdir: Path, outputdir: Path):
# adding a progress bar and parallelism.
mapper = PathMapper.file_mapper(inputdir, outputdir, glob=options.pattern)
for input_file, output_file in mapper:
def is_basename(x:str): return os.path.dirname(x) == ""

def uae(f_map):
return torch.sum(f_map)/f_map.numel()

def hook_fn(module, input, output):
result = uae(output[0])
activation_outputs.append(result)

activation_outputs = []

patient_id = os.path.splitext(os.path.basename(input_file))[0]
patient_id = patient_id.split_id('_')[0]
print(patient_id)
patient_id = patient_id.split('_')[0]
model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
# print(model)
hooks = []
Expand All @@ -91,6 +89,17 @@ def main(options: Namespace, inputdir: Path, outputdir: Path):
hook = model.features[layer].register_forward_hook(hook_fn)
hooks.append(hook)

device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
print(f"Using {device} device")

model = model.to(device)

mat = mat73.loadmat(input_file)
mat = mat['tf_images_mat']

Expand All @@ -104,6 +113,7 @@ def main(options: Namespace, inputdir: Path, outputdir: Path):
img = mat[epoch,channel,:,:,:]
img = img.reshape(-1,img.shape[2],img.shape[0],img.shape[1])
img = torch.tensor(img, dtype=torch.float32)
img = img.to(device)
res = model(img)
np_activ = np.array([tensor.item() for tensor in activation_outputs])
results[epoch,channel,:] = np_activ
Expand Down

0 comments on commit 3fd433f

Please sign in to comment.