-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
developing threading in process_stack()
- Loading branch information
Showing
1 changed file
with
390 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,390 @@ | ||
import cv2 | ||
import os | ||
from skimage import io | ||
import pandas as pd | ||
import numpy as np | ||
import torch | ||
import segmentation_models_pytorch as smp | ||
from collections import defaultdict | ||
import concurrent | ||
from read_roi import read_roi_file | ||
from shapely import geometry | ||
import pickle | ||
import zipfile | ||
from scipy.spatial import distance | ||
|
||
from pytraction.piv import PIV | ||
import pytraction.net.segment as pynet | ||
from pytraction.utils import normalize, allign_slice, bead_density, plot | ||
from pytraction.traction_force import PyTraction | ||
from pytraction.net.dataloader import get_preprocessing | ||
from pytraction.utils import HiddenPrints | ||
|
||
|
||
from google_drive_downloader import GoogleDriveDownloader as gdd | ||
import tempfile | ||
|
||
class TractionForce(object): | ||
|
||
def __init__(self, scaling_factor, E, s=0.5, meshsize=10, bead_density=None, device='cpu', segment=False, window_size=None): | ||
|
||
self.device = device | ||
self.segment = segment | ||
self.window_size = window_size | ||
self.E = E | ||
self.s = s | ||
|
||
self.TFM_obj = PyTraction( | ||
meshsize = meshsize, # grid spacing in pix | ||
pix_per_mu = scaling_factor, | ||
E = E, # Young's modulus in Pa | ||
s = s, # Poisson's ratio | ||
) | ||
|
||
self.model, self.pre_fn = self.get_model() | ||
|
||
|
||
def get_window_size(self, img): | ||
if not self.window_size: | ||
density = bead_density(img) | ||
|
||
file_id = '1xQuGSUdW3nIO5lAm7DQb567sMEQgHmQD' | ||
tmpdir = tempfile.gettempdir() | ||
destination = f'{tmpdir}/knn.zip' | ||
|
||
|
||
gdd.download_file_from_google_drive(file_id=file_id, | ||
dest_path=destination, | ||
unzip=True, | ||
showsize=False, | ||
overwrite=False) | ||
|
||
with open(f'{tmpdir}/knn.pickle', 'rb') as f: | ||
knn = pickle.load(f) | ||
|
||
window_size = knn.predict([[density]]) | ||
|
||
window_size = int(window_size) | ||
|
||
print(f'Automatically selected window size of {window_size}') | ||
|
||
return window_size | ||
else: | ||
return self.window_size | ||
|
||
|
||
def get_model(self): | ||
# data_20210320.zip | ||
file_id = '1zShYcG8IMsMjB8hA6FcBTIZPfi_wDL4n' | ||
tmpdir = tempfile.gettempdir() | ||
destination = f'{tmpdir}/model.zip' | ||
|
||
|
||
gdd.download_file_from_google_drive(file_id=file_id, | ||
dest_path=destination, | ||
unzip=True, | ||
showsize=True, | ||
overwrite=False) | ||
|
||
|
||
|
||
# currently using model from 20210316 | ||
best_model = torch.load(f'{tmpdir}/best_model_1.pth', map_location='cpu') | ||
if self.device == 'cuda' and torch.cuda.is_available(): | ||
best_model = best_model.to('cuda') | ||
preproc_fn = smp.encoders.get_preprocessing_fn('efficientnet-b1', 'imagenet') | ||
preprocessing_fn = get_preprocessing(preproc_fn) | ||
|
||
return best_model, preprocessing_fn | ||
|
||
|
||
def get_roi(self, img, ref, frame, roi, img_stack, crop): | ||
cell_img = np.array(img_stack[frame, 1, :, :]) | ||
cell_img = normalize(cell_img) | ||
|
||
if not roi and self.segment: | ||
mask = pynet.get_mask(cell_img, self.model, self.pre_fn, device=self.device) | ||
|
||
mask = np.array(mask.astype('bool'), dtype='uint8') | ||
|
||
contours, _ = cv2.findContours(mask,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE) | ||
# areas = [cv2.contourArea(c) for c in contours] | ||
# sorted_areas = np.sort(areas) | ||
|
||
image_center = np.asarray(mask.shape) / 2 | ||
image_center = tuple(image_center.astype('int32')) | ||
|
||
segmented = [] | ||
for contour in contours: | ||
# find center of each contour | ||
M = cv2.moments(contour) | ||
center_X = int(M["m10"] / M["m00"]) | ||
center_Y = int(M["m01"] / M["m00"]) | ||
contour_center = (center_X, center_Y) | ||
|
||
# calculate distance to image_center | ||
distances_to_center = (distance.euclidean(image_center, contour_center)) | ||
|
||
# save to a list of dictionaries | ||
segmented.append({ | ||
'contour': contour, | ||
'center': contour_center, | ||
'distance_to_center': distances_to_center | ||
} | ||
) | ||
|
||
|
||
sorted_cells = sorted(segmented, key=lambda i: i['distance_to_center']) | ||
|
||
#bounding box (red) | ||
# pts=contours[areas.index(sorted_areas[-1])] #the biggest contour | ||
pts=sorted_cells[0]['contour'] #the biggest contour | ||
|
||
cv2.drawContours(cell_img, [pts], -1, (255), 1, cv2.LINE_AA) | ||
|
||
polyx, polyy = np.squeeze(pts, axis=1).T | ||
roi = True | ||
|
||
if roi: | ||
shift=0.2 | ||
if not self.segment: | ||
polyx = roi[0] | ||
polyy = roi[1] | ||
|
||
minx, maxx = np.min(polyx), np.max(polyx) | ||
miny, maxy = np.min(polyy), np.max(polyy) | ||
|
||
midx = minx + (maxx-minx) // 2 | ||
midy = miny + (maxy-miny) // 2 | ||
|
||
pixel_shift = int(max(midx, midy) * shift) // 2 | ||
|
||
# need to raise as an issues | ||
rescaled = [] | ||
for (xi,yi) in zip(polyx, polyy): | ||
# apply shift | ||
rescaled.append([xi, yi]) | ||
|
||
self.polygon = geometry.Polygon(rescaled) | ||
|
||
x,y,w,h = cv2.boundingRect(np.array(rescaled)) | ||
pad = 50 | ||
|
||
|
||
if not self.segment: | ||
pts = np.array(list(zip(polyx, polyy)), np.int32) | ||
pts = pts.reshape((-1,1,2)) | ||
cv2.polylines(cell_img,[pts],True,(255), thickness=3) | ||
|
||
if crop: | ||
img_crop = img[y-pad:y+h+pad, x-pad:x+w+pad] | ||
ref_crop = ref[y-pad:y+h+pad, x-pad:x+w+pad] | ||
|
||
mask = cv2.fillPoly(np.zeros(cell_img.shape), [pts], (255)) | ||
mask_crop = mask[y-pad:y+h+pad, x-pad:x+w+pad] | ||
|
||
cell_img_full = cell_img | ||
cell_img_crop = cell_img[y-pad:y+h+pad, x-pad:x+w+pad] | ||
|
||
else: | ||
img_crop = img | ||
ref_crop = ref | ||
mask_crop = cv2.fillPoly(np.zeros(cell_img.shape), [pts], (255)) | ||
cell_img_crop = cell_img | ||
|
||
return img_crop, ref_crop, cell_img_crop, mask_crop | ||
else: | ||
return img, ref, cell_img, None | ||
|
||
|
||
|
||
def get_noise(self, x,y,u,v, roi=False): | ||
if not roi: | ||
noise = 10 | ||
xn, yn, un, vn = x[:noise],y[:noise],u[:noise],v[:noise] | ||
noise_vec = np.array([un.flatten(), vn.flatten()]) | ||
|
||
varnoise = np.var(noise_vec) | ||
beta = 1/varnoise | ||
|
||
elif roi: | ||
noise = [] | ||
for (x0,y0, u0, v0) in zip(x.flatten(),y.flatten(), u.flatten(), v.flatten()): | ||
p1 = geometry.Point([x0,y0]) | ||
if not p1.within(self.polygon): | ||
noise.append(np.array([u0, v0])) | ||
|
||
noise_vec = np.array(noise) | ||
varnoise = np.var(noise_vec) | ||
beta = 1/varnoise | ||
return beta | ||
|
||
|
||
def _recursive_lookup(self, k, d): | ||
if k in d: return d[k] | ||
for v in d.values(): | ||
if isinstance(v, dict): | ||
a = self._recursive_lookup(k, v) | ||
if a is not None: return a | ||
return None | ||
|
||
def load_data(self, img_path, ref_path, roi_path=''): | ||
""" | ||
:param img_path: Image path for to nd image with shape (f,c,w,h) | ||
:param ref_path: Reference path for to nd image with shape (c,w,h) | ||
:param roi_path: | ||
""" | ||
img = io.imread(img_path) | ||
ref = io.imread(ref_path) | ||
|
||
if not isinstance(img,np.ndarray) or not isinstance(ref, np.ndarray): | ||
msg = f'Image data not loaded for {img_path} or {ref_path}' | ||
raise TypeError(msg) | ||
|
||
if len(img.shape) != 4: | ||
msg = f'Please ensure that the input image has shape (t,c,w,h) the current shape is {img.shape}' | ||
raise RuntimeWarning(msg) | ||
|
||
if len(ref.shape) != 3: | ||
msg = f'Please ensure that the input image has shape (c,w,h) the current shape is {ref.shape}' | ||
raise RuntimeWarning(msg) | ||
|
||
|
||
# messy fix to include file name in log file | ||
self.ref_path = ref_path | ||
self.img_path = img_path | ||
|
||
if '.csv' in roi_path: | ||
x, y = pd.read_csv(roi_path).T.values | ||
roi = (x,y) | ||
|
||
elif '.roi' in roi_path: | ||
d = read_roi_file(roi_path) | ||
x = self._recursive_lookup('x', d) | ||
y = self._recursive_lookup('y', d) | ||
|
||
roi = (x,y) | ||
|
||
elif '.zip' in roi_path: | ||
rois = [] | ||
with zipfile.ZipFile(roi_path) as ziproi: | ||
for file in ziproi.namelist(): | ||
roi_path_file = ziproi.extract(file) | ||
d = read_roi_file(roi_path_file) | ||
x = self._recursive_lookup('x', d) | ||
y = self._recursive_lookup('y', d) | ||
rois.append((x,y)) | ||
os.remove(roi_path_file) | ||
|
||
roi = rois | ||
|
||
|
||
else: | ||
roi = None | ||
|
||
|
||
return img, ref, roi | ||
|
||
def process_stack(self, img_stack, ref_stack, bead_channel=0, roi=False, frame=[], crop=False, verbose=0, num_workers:int=4): | ||
''' | ||
TODO: docstring | ||
''' | ||
if verbose == 0: | ||
print('Processing stacks') | ||
with HiddenPrints(): | ||
output = self._process_stack(img_stack, ref_stack, bead_channel, roi, frame, crop, num_workers) | ||
elif verbose == 1: | ||
output = self._process_stack(img_stack, ref_stack, bead_channel, roi, frame, crop, num_workers) | ||
return output | ||
|
||
def _process_stack(self, img_stack, ref_stack, bead_channel=0, roi=False, frame=[], crop=False, num_workers:int=4): | ||
''' | ||
TODO: docstring | ||
''' | ||
# init log defaultdict | ||
log = defaultdict(list) | ||
|
||
# use concurrent future threads to parallelise execution | ||
with concurrent.futures.ThreadPoolExecutor(max_workers = num_workers) as executor: | ||
|
||
# submit a _process_frame() run on an executor for each frame | ||
nframes = img_stack.shape[0] | ||
futures = { | ||
executor.submit( | ||
self._process_frame, frm, img_stack, ref_stack, bead_channel, roi, frame, crop | ||
): frm for frm in list(range(nframes)) | ||
} | ||
|
||
# we're yielding each complete frame (futures) back in the main process here so we're threadsafe | ||
for future in concurrent.futures.as_completed(futures): | ||
# this is ugly but should be performant | ||
frame_dict = future.result() | ||
log['frame'].append(frame_dict['frame']) | ||
log['traction_map'].append(frame_dict['traction_map']) | ||
log['force_field'].append(frame_dict['force_field']) | ||
log['stack_bead_roi'].append(frame_dict['stack_bead_roi']) | ||
log['cell_roi'].append(frame_dict['cell_roi']) | ||
log['mask_roi'].append(frame_dict['mask_roi']) | ||
log['beta'].append(frame_dict['beta']) | ||
log['L'].append(frame_dict['L']) | ||
log['pos'].append(frame_dict['pos']) | ||
log['vec'].append(frame_dict['vec']) | ||
log['img_path'].append(frame_dict['img_path']) | ||
log['ref_path'].append(frame_dict['ref_path']) | ||
log['E'].append(frame_dict['E']) | ||
log['s'].append(frame_dict['s']) | ||
|
||
return pd.DataFrame(log) | ||
|
||
def _process_frame(self, frame, img_stack, ref_stack, bead_channel, roi, frame2, crop): | ||
''' | ||
TODO: docstring | ||
''' | ||
# load plane | ||
img = normalize(np.array(img_stack[frame, bead_channel, :, :])) | ||
ref = normalize(np.array(ref_stack[bead_channel,:,:])) | ||
|
||
window_size = self.get_window_size(img) | ||
|
||
if isinstance(roi, list): | ||
assert len(roi) == nframes, f'Warning ROI list has len {len(roi)} which is not equal to \ | ||
the number of frames ({nframes}). This would suggest that you do not \ | ||
have the correct number of ROIs in the zip file.' | ||
roi_i = roi[frame] | ||
else: | ||
roi_i = roi | ||
|
||
|
||
img_crop, ref_crop, cell_img_crop, mask_crop = self.get_roi(img, ref, frame, roi_i, img_stack, crop) | ||
|
||
# do piv | ||
piv_obj = PIV(window_size=window_size) | ||
x, y, u, v, stack = piv_obj.iterative_piv(img_crop, ref_crop) | ||
|
||
beta = self.get_noise(x,y,u,v, roi=False) | ||
|
||
# make pos and vecs for TFM | ||
pos = np.array([x.flatten(), y.flatten()]) | ||
vec = np.array([u.flatten(), v.flatten()]) | ||
|
||
# compute traction map | ||
traction_map, f_n_m, L_optimal = self.TFM_obj.calculate_traction_map(pos, vec, beta) | ||
|
||
# create an output dict for this frame that we can then apply to the main log object outside the threads | ||
frame_dict = {} | ||
frame_dict['frame'] = frame | ||
frame_dict['traction_map'] = traction_map | ||
frame_dict['force_field'] = f_n_m | ||
frame_dict['stack_bead_roi'] = stack | ||
frame_dict['cell_roi'] = cell_img_crop | ||
frame_dict['mask_roi'] = mask_crop | ||
frame_dict['beta'] = beta | ||
frame_dict['L'] = L_optimal | ||
frame_dict['pos'] = pos | ||
frame_dict['vec'] = vec | ||
frame_dict['img_path'] = self.img_path | ||
frame_dict['ref_path'] = self.ref_path | ||
frame_dict['E'] = self.E | ||
frame_dict['s'] = self.s | ||
|
||
return frame_dict |