Skip to content

Commit

Permalink
Comments and tutorial figure labels
Browse files Browse the repository at this point in the history
  • Loading branch information
ccli3896 committed Dec 21, 2023
1 parent f1f2c73 commit cb0e456
Show file tree
Hide file tree
Showing 20 changed files with 202 additions and 1,711 deletions.
3 changes: 3 additions & 0 deletions Animal scripts/SAC.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,17 +330,20 @@ def weights_init_(m):
'''

class Ensemble(nn.Module):
# Initialize list of actors
def __init__(self, actors, actions=2):
super(Ensemble, self).__init__()
self.actors = nn.ModuleList(actors).to(device)
self.actions = actions

def forward(self, state):
# Return action by averaging over ensemble choices
sums = torch.cat([act(state)[1].detach().view(state.shape[0],self.actions,1) for act in self.actors], dim=2)
sums = torch.mean(sums, dim=2)
return torch.argmax(sums, dim=1)

def sample(self, state, greedy=False):
# Sample from ensemble: if not greedy, choose from one actor
if not greedy:
return random.choice([act.sample(state) for act in self.actors])
else:
Expand Down
2 changes: 2 additions & 0 deletions Animal scripts/check_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ def main():
# Plot track
plt.scatter(x,-np.array(y),c=cm,marker='.')
target = args.target.split(',')

# Plot the target point
plt.scatter(int(target[0]),-int(target[1]),c='r')

# Plot the starting point
plt.scatter(x[0],-y[0])

Expand Down
22 changes: 21 additions & 1 deletion Animal scripts/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,56 @@

def main(minutes, cam_id, randomrate=.1, lightrate=3, fps=3):

# Make folder name to save all image data in, by timestamp
nowtime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
save_folder = f'./Data/{nowtime}/'
os.makedirs(save_folder)

# Initialize hardware: camera and high-powered LED ('task' object)
cam,task,cam_params = utils.init_instruments(cam_id)

# Number of frames to track animal for
frames = int(minutes*60*args.fps)

# For cases where I just want the initial picture
if minutes==0:
frames = 3 # For cases where I just want the initial picture
frames = 3

light = []

# Throw out first two images
for i in range(2):
img = utils.grab_im(cam,None,cam_params)

# For episode length, take picture and save it to folder
for i in range(frames):
img = utils.grab_im(cam,None,cam_params)
cv2.imwrite(f'{save_folder}/{i}.jpg',img)

# If I want random light flashes:
if task is not None:

# Every [lightrate] frames, choose to randomly turn light on with probability randomrate
if i%args.lightrate==0:

# Turn light on if random number is less than randomrate
# Write light history to list
if np.random.rand() < randomrate:
task.write(cam_params['light_amp'], auto_start=True)
task.stop()
[light.append(1) for _ in range(lightrate)]

# If random number is not less than randomrate, turn off and write light off history to list
else:
task.write(0, auto_start=True)
task.stop()
[light.append(0) for _ in range(lightrate)]

# Save light data
with open(f'{save_folder}/light.pkl','wb') as f:
pickle.dump(light,f)

# Close hardware
utils.exit_instruments(cam,task)


Expand Down
31 changes: 24 additions & 7 deletions Animal scripts/improc_v.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
'''
Image processing functions as of 4.27.21.
Base image processing functions as of 4.27.21.
~25 ms for full set of functions
'''

Expand All @@ -10,19 +10,21 @@
from utils import *

from skimage.morphology import skeletonize

'''
Initial functions: defining mask, background,
'''

def make_mask(bg,size=50):
# Make a large circle mask to remove parts of image outside the plate
center = tuple(map(int,np.flip(np.array(bg.shape)[:2])/2)) # Makes the center of the circle the center of the img
radius = int(np.mean(center))-size
img = np.zeros(bg.shape, dtype="uint8")
img = cv2.bitwise_not(cv2.circle(img, center, radius, (255,255,255), thickness=-1))
return img

def define_endpt_kernels():
# To save time in skeleton_endpoints().
# To save time in skeleton_endpoints(). Kernels to identify endpoints in a skeleton
kernels = [np.array((
[-1, -1, 0],
[-1, 1, 0],
Expand All @@ -36,16 +38,20 @@ def define_endpt_kernels():
return kernels, not_kerns

def make_bg(cam,t):
# Make the background by averaging many collected images. This will be subtracted from tracking imgs.
bg = grab_im(cam,None).astype('float64')
start_time = time.monotonic()
i=0

# For a given time, find average of imgs
while time.monotonic()-start_time < t:
bg -= 1/(i+1)*(cv2.subtract(bg,grab_im(cam,None).astype('float64')))
i+=1
bg = bg.astype('uint8')[:,:,0]
return bg

def load_templates():
# Load templates for head tracking. These are separately saved in templates folder and specific to each orientation.
DEG_INCR = 30
templates = []
for i in np.arange(0,360,DEG_INCR):
Expand All @@ -57,8 +63,14 @@ def load_templates():
'''

def find_worm(img,mask,bg,threshold=8,buffer=1,imsz=64,pix_num=15):
'''Img right from camera. Returns cropped worm image (unthresholded) and location of center of contour.
Inputs:
threshold is brightness level to consider as part of worm
buffer is a small amount of padding to avoid endpoint errors
imsz is the rough area of a worm in pixels so things that are much too large can be eliminated
pix_num is the minimum window for a worm
'''
# 2.5 ms
# Im right from camera. Returns worm image (unthresholded) and location of center of contour.
worm_im = np.zeros((imsz,imsz),dtype='uint8')
im = cv2.subtract(cv2.subtract(img,mask),bg)
retval, th = cv2.threshold(im, thresh=threshold, maxval=255, type=cv2.THRESH_BINARY_INV)
Expand All @@ -73,17 +85,18 @@ def find_worm(img,mask,bg,threshold=8,buffer=1,imsz=64,pix_num=15):
if area > mx_area and area < (imsz*4)**2: # imsz**2 is too big for a worm!
mx = x,y,w,h
mx_area = area

# Find centroid and width, height of image
x,y,w,h = mx
center = np.array([y+h/2,x+w/2],dtype=int)

# Check size is reasonable and get cropped worm image
if mx_area > pix_num: # and ((w<imsz) & (h<imsz)):
worm = img[y-buffer:y+buffer+h, x-buffer:x+buffer+w][:imsz,:imsz]
#percs = np.sort(worm.flatten())
#worm_im = cv2.add(worm_im,int(percs[pix_num]))
# worm_im[(imsz-h)//2-buffer:(imsz-h)//2+h+buffer,
# (imsz-w)//2-buffer:(imsz-w)//2+w+buffer] = worm
worm_im = worm
return worm_im, center
else:
# Error: worm not found
return None, None

def get_wormies(img,process=True):
Expand Down Expand Up @@ -139,13 +152,16 @@ def skeleton_endpoints(skel, kernels, not_kerns):
'''

def get_body_angle(skel,discretization=30):
'''From skeleton, find the best fit line and its angle.
'''
vy,vx,y,x = cv2.fitLine(np.vstack(skel.nonzero()).T, cv2.DIST_L2,0,0.01,0.01)
body_angle = np.round((np.arctan2(-vy,vx)*180/np.pi)/discretization)*discretization
if body_angle==180: body_angle=-180
return body_angle


def get_HT_ims(th,endpoints,padby=5):
# Gets a zoomed-in crop of head and tail for template-matching later.
# 4 us
# Assumes template size of 7x7
def get_one_HT_im(th,endpt):
Expand All @@ -161,6 +177,7 @@ def get_one_HT_im(th,endpt):
return [get_one_HT_im(th_pad,endpts[0,:])]

def get_HT_angles(HTs,templates):
# Match head and tail to templates and identify the angle they're pointing in.
# 350 us
DEG_INCR = 30
angs = np.zeros(len(HTs))
Expand Down
5 changes: 5 additions & 0 deletions Animal scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
'''

def worm_bound(a):
# Make sure angles always stay within +/-180 degrees
if hasattr(a,'__len__'):
a = np.array(a)
a = np.where(a<-180, a+360, a)
Expand Down Expand Up @@ -44,10 +45,13 @@ def projection(body_angle, mvt_angle, speed):
'''

def off(cam_id):
# Reset
cam,task,cam_params = init_instruments(cam_id)
exit_instruments(cam,task)

def init_instruments(cam_id):
# Initializes instruments for the two different rigs. Different camera models and resolutions.
# These are the parameters that seem to work and allow me to use the same image processing settings, mostly

if cam_id==1:
cam_params = {
Expand Down Expand Up @@ -121,6 +125,7 @@ def grab_im(cam,bg,cam_params):
return cv2.subtract(imdata,bg)

def make_bg(cam, cam_params, bgtime=30):
# Make the background to be subtracted from images later, for worm-finding
start = time.monotonic()
bg = grab_im(cam,None,cam_params).astype('float64')
i=1
Expand Down
Binary file added Cross evaluation data/.DS_Store
Binary file not shown.
Loading

0 comments on commit cb0e456

Please sign in to comment.