developing threading in process_stack()
Phlair committed Apr 20, 2021
1 parent bbbab61 commit b36600a
Showing 1 changed file with 390 additions and 0 deletions.
pytraction/core copy.py
@@ -0,0 +1,390 @@
import cv2
import os
from skimage import io
import pandas as pd
import numpy as np
import torch
import segmentation_models_pytorch as smp
from collections import defaultdict
import concurrent.futures
from read_roi import read_roi_file
from shapely import geometry
import pickle
import zipfile
from scipy.spatial import distance

from pytraction.piv import PIV
import pytraction.net.segment as pynet
from pytraction.utils import normalize, allign_slice, bead_density, plot
from pytraction.traction_force import PyTraction
from pytraction.net.dataloader import get_preprocessing
from pytraction.utils import HiddenPrints


from google_drive_downloader import GoogleDriveDownloader as gdd
import tempfile

class TractionForce(object):

def __init__(self, scaling_factor, E, s=0.5, meshsize=10, bead_density=None, device='cpu', segment=False, window_size=None):

self.device = device
self.segment = segment
self.window_size = window_size
self.E = E
self.s = s

self.TFM_obj = PyTraction(
meshsize = meshsize, # grid spacing in pix
pix_per_mu = scaling_factor,
E = E, # Young's modulus in Pa
s = s, # Poisson's ratio
)

self.model, self.pre_fn = self.get_model()


def get_window_size(self, img):
if not self.window_size:
density = bead_density(img)

            file_id = '1xQuGSUdW3nIO5lAm7DQb567sMEQgHmQD'
            tmpdir = tempfile.gettempdir()
            destination = f'{tmpdir}/knn.zip'

            gdd.download_file_from_google_drive(file_id=file_id,
                                                dest_path=destination,
                                                unzip=True,
                                                showsize=False,
                                                overwrite=False)

            # load the KNN regressor shipped in the downloaded archive
            with open(f'{tmpdir}/knn.pickle', 'rb') as f:
                knn = pickle.load(f)

            window_size = int(knn.predict([[density]]))

            print(f'Automatically selected window size of {window_size}')

return window_size
else:
return self.window_size
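
    # Note: the KNN regressor downloaded above maps measured bead density to a
    # PIV window size; the exact mapping depends on the pickled model, so the
    # printed value is a data-driven default rather than a fixed rule.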


def get_model(self):
# data_20210320.zip
        file_id = '1zShYcG8IMsMjB8hA6FcBTIZPfi_wDL4n'
        tmpdir = tempfile.gettempdir()
        destination = f'{tmpdir}/model.zip'

        gdd.download_file_from_google_drive(file_id=file_id,
                                            dest_path=destination,
                                            unzip=True,
                                            showsize=True,
                                            overwrite=False)

        # currently using model from 20210316
        best_model = torch.load(f'{tmpdir}/best_model_1.pth', map_location='cpu')
        if self.device == 'cuda' and torch.cuda.is_available():
            best_model = best_model.to('cuda')

        preproc_fn = smp.encoders.get_preprocessing_fn('efficientnet-b1', 'imagenet')
        preprocessing_fn = get_preprocessing(preproc_fn)

return best_model, preprocessing_fn


def get_roi(self, img, ref, frame, roi, img_stack, crop):
cell_img = np.array(img_stack[frame, 1, :, :])
cell_img = normalize(cell_img)

if not roi and self.segment:
mask = pynet.get_mask(cell_img, self.model, self.pre_fn, device=self.device)

mask = np.array(mask.astype('bool'), dtype='uint8')

contours, _ = cv2.findContours(mask,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
# areas = [cv2.contourArea(c) for c in contours]
# sorted_areas = np.sort(areas)

image_center = np.asarray(mask.shape) / 2
image_center = tuple(image_center.astype('int32'))

            segmented = []
            for contour in contours:
                # find the centre of each contour, skipping degenerate
                # contours with zero area (which would divide by zero)
                M = cv2.moments(contour)
                if M["m00"] == 0:
                    continue
                center_X = int(M["m10"] / M["m00"])
                center_Y = int(M["m01"] / M["m00"])
                contour_center = (center_X, center_Y)

                # calculate distance to image_center
                distance_to_center = distance.euclidean(image_center, contour_center)

                # save to a list of dictionaries
                segmented.append({
                    'contour': contour,
                    'center': contour_center,
                    'distance_to_center': distance_to_center,
                })


            # sort contours by distance to the image centre
            sorted_cells = sorted(segmented, key=lambda i: i['distance_to_center'])

            # pts = contours[areas.index(sorted_areas[-1])]  # the largest contour
            pts = sorted_cells[0]['contour']  # the most central contour

            cv2.drawContours(cell_img, [pts], -1, (255), 1, cv2.LINE_AA)

polyx, polyy = np.squeeze(pts, axis=1).T
roi = True

if roi:
shift=0.2
if not self.segment:
polyx = roi[0]
polyy = roi[1]

minx, maxx = np.min(polyx), np.max(polyx)
miny, maxy = np.min(polyy), np.max(polyy)

midx = minx + (maxx-minx) // 2
midy = miny + (maxy-miny) // 2

pixel_shift = int(max(midx, midy) * shift) // 2

            # TODO: pixel_shift is computed above but never applied; raise as an issue
            rescaled = []
            for (xi, yi) in zip(polyx, polyy):
                rescaled.append([xi, yi])  # shift intentionally not applied yet

self.polygon = geometry.Polygon(rescaled)

            # cv2.boundingRect needs an int32/float32 array
            x, y, w, h = cv2.boundingRect(np.array(rescaled, dtype=np.int32))
            pad = 50

if not self.segment:
pts = np.array(list(zip(polyx, polyy)), np.int32)
pts = pts.reshape((-1,1,2))
cv2.polylines(cell_img,[pts],True,(255), thickness=3)

if crop:
img_crop = img[y-pad:y+h+pad, x-pad:x+w+pad]
ref_crop = ref[y-pad:y+h+pad, x-pad:x+w+pad]

mask = cv2.fillPoly(np.zeros(cell_img.shape), [pts], (255))
mask_crop = mask[y-pad:y+h+pad, x-pad:x+w+pad]

cell_img_full = cell_img
cell_img_crop = cell_img[y-pad:y+h+pad, x-pad:x+w+pad]

else:
img_crop = img
ref_crop = ref
mask_crop = cv2.fillPoly(np.zeros(cell_img.shape), [pts], (255))
cell_img_crop = cell_img

return img_crop, ref_crop, cell_img_crop, mask_crop
else:
return img, ref, cell_img, None



    def get_noise(self, x, y, u, v, roi=False):
        if not roi:
            # estimate noise from the first few rows of the displacement field
            noise = 10
            xn, yn, un, vn = x[:noise], y[:noise], u[:noise], v[:noise]
            noise_vec = np.array([un.flatten(), vn.flatten()])

            varnoise = np.var(noise_vec)
            beta = 1 / varnoise

        else:
            # estimate noise from displacement vectors outside the cell polygon
            noise = []
            for (x0, y0, u0, v0) in zip(x.flatten(), y.flatten(), u.flatten(), v.flatten()):
                p1 = geometry.Point([x0, y0])
                if not p1.within(self.polygon):
                    noise.append(np.array([u0, v0]))

            noise_vec = np.array(noise)
            varnoise = np.var(noise_vec)
            beta = 1 / varnoise
        return beta
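
    # Worked example for get_noise (illustrative numbers, not measured data):
    # if the displacement components in the noise region have variance
    # 0.04 pix^2, then beta = 1 / 0.04 = 25; a larger beta tells the
    # downstream traction solver that the displacements are less noisy.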


    def _recursive_lookup(self, k, d):
        if k in d:
            return d[k]
        for v in d.values():
            if isinstance(v, dict):
                a = self._recursive_lookup(k, v)
                if a is not None:
                    return a
        return None
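
    # Example: read_roi_file() parses ImageJ ROIs into nested dicts such as
    # {'roi_name': {'type': 'polygon', 'x': [...], 'y': [...]}}, so
    # self._recursive_lookup('x', d) returns the first 'x' list it finds.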

def load_data(self, img_path, ref_path, roi_path=''):
"""
        :param img_path: path to an nd image stack with shape (t,c,w,h)
        :param ref_path: path to an nd reference image with shape (c,w,h)
        :param roi_path: optional path to a .csv, .roi, or .zip file defining the ROI(s)
"""
img = io.imread(img_path)
ref = io.imread(ref_path)

if not isinstance(img,np.ndarray) or not isinstance(ref, np.ndarray):
msg = f'Image data not loaded for {img_path} or {ref_path}'
raise TypeError(msg)

        if len(img.shape) != 4:
            msg = f'Please ensure that the input image stack has shape (t,c,w,h); the current shape is {img.shape}'
            raise RuntimeWarning(msg)

        if len(ref.shape) != 3:
            msg = f'Please ensure that the reference image has shape (c,w,h); the current shape is {ref.shape}'
            raise RuntimeWarning(msg)


# messy fix to include file name in log file
self.ref_path = ref_path
self.img_path = img_path

if '.csv' in roi_path:
x, y = pd.read_csv(roi_path).T.values
roi = (x,y)

elif '.roi' in roi_path:
d = read_roi_file(roi_path)
x = self._recursive_lookup('x', d)
y = self._recursive_lookup('y', d)

roi = (x,y)

elif '.zip' in roi_path:
rois = []
with zipfile.ZipFile(roi_path) as ziproi:
for file in ziproi.namelist():
roi_path_file = ziproi.extract(file)
d = read_roi_file(roi_path_file)
x = self._recursive_lookup('x', d)
y = self._recursive_lookup('y', d)
rois.append((x,y))
os.remove(roi_path_file)

roi = rois


else:
roi = None


return img, ref, roi
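
    # Example return values (illustrative): img is a (t,c,w,h) array, ref is a
    # (c,w,h) array, and roi is an (x, y) vertex tuple, a list of such tuples
    # for .zip input, or None when no ROI file is given.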

    def process_stack(self, img_stack, ref_stack, bead_channel=0, roi=False, frame=[], crop=False, verbose=0, num_workers: int = 4):
        '''
        Process every frame of img_stack against ref_stack and return a
        pandas DataFrame log with one row per frame. verbose=1 prints
        per-frame output; verbose=0 suppresses it.
        '''
if verbose == 0:
print('Processing stacks')
with HiddenPrints():
output = self._process_stack(img_stack, ref_stack, bead_channel, roi, frame, crop, num_workers)
elif verbose == 1:
output = self._process_stack(img_stack, ref_stack, bead_channel, roi, frame, crop, num_workers)
return output

    def _process_stack(self, img_stack, ref_stack, bead_channel=0, roi=False, frame=[], crop=False, num_workers: int = 4):
        '''
        Submit one _process_frame() job per frame to a thread pool and
        collect the per-frame results into a single log.
        '''
# init log defaultdict
log = defaultdict(list)

# use concurrent future threads to parallelise execution
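        # (threads rather than processes: the heavy OpenCV/NumPy/PyTorch calls
        # release the GIL while they run, and threads avoid pickling the large
        # image arrays between workers)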
with concurrent.futures.ThreadPoolExecutor(max_workers = num_workers) as executor:

# submit a _process_frame() run on an executor for each frame
nframes = img_stack.shape[0]
            futures = {
                executor.submit(
                    self._process_frame, frm, img_stack, ref_stack, bead_channel, roi, frame, crop
                ): frm for frm in range(nframes)
            }

            # yield each completed frame (future) back in the main process so
            # that appending to the shared log stays threadsafe
            for future in concurrent.futures.as_completed(futures):
                frame_dict = future.result()
                for key, value in frame_dict.items():
                    log[key].append(value)

return pd.DataFrame(log)

    def _process_frame(self, frame, img_stack, ref_stack, bead_channel, roi, frame2, crop):
        '''
        Process a single frame: PIV on the bead channel, noise estimation,
        and traction map calculation. Returns a dict of per-frame outputs.
        (frame2 is currently unused.)
        '''
        # load plane
        img = normalize(np.array(img_stack[frame, bead_channel, :, :]))
        ref = normalize(np.array(ref_stack[bead_channel, :, :]))

        window_size = self.get_window_size(img)

        if isinstance(roi, list):
            nframes = img_stack.shape[0]
            assert len(roi) == nframes, (
                f'ROI list has len {len(roi)}, which is not equal to the '
                f'number of frames ({nframes}). This suggests the zip file '
                f'does not contain the correct number of ROIs.'
            )
            roi_i = roi[frame]
        else:
            roi_i = roi


img_crop, ref_crop, cell_img_crop, mask_crop = self.get_roi(img, ref, frame, roi_i, img_stack, crop)

# do piv
piv_obj = PIV(window_size=window_size)
x, y, u, v, stack = piv_obj.iterative_piv(img_crop, ref_crop)

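        # NOTE: noise is currently always estimated from the image edge
        # (roi=False), even when an ROI is supplied; get_noise(roi=True)
        # would instead use the vectors outside self.polygon.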
beta = self.get_noise(x,y,u,v, roi=False)

# make pos and vecs for TFM
pos = np.array([x.flatten(), y.flatten()])
vec = np.array([u.flatten(), v.flatten()])

# compute traction map
traction_map, f_n_m, L_optimal = self.TFM_obj.calculate_traction_map(pos, vec, beta)

# create an output dict for this frame that we can then apply to the main log object outside the threads
frame_dict = {}
frame_dict['frame'] = frame
frame_dict['traction_map'] = traction_map
frame_dict['force_field'] = f_n_m
frame_dict['stack_bead_roi'] = stack
frame_dict['cell_roi'] = cell_img_crop
frame_dict['mask_roi'] = mask_crop
frame_dict['beta'] = beta
frame_dict['L'] = L_optimal
frame_dict['pos'] = pos
frame_dict['vec'] = vec
frame_dict['img_path'] = self.img_path
frame_dict['ref_path'] = self.ref_path
frame_dict['E'] = self.E
frame_dict['s'] = self.s

return frame_dict
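

# Minimal usage sketch (the file paths and parameter values below are
# illustrative assumptions, not fixtures shipped with this commit):
#
#     tf = TractionForce(scaling_factor=1.3, E=3000)
#     img, ref, roi = tf.load_data('stack.tif', 'ref.tif', roi_path='rois.zip')
#     log = tf.process_stack(img, ref, roi=roi, crop=True, num_workers=4)
#     # -> pandas DataFrame with one row per frame: traction_map, force_field,
#     #    beta, L, pos, vec, img_path, ref_path, E, s, ...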
