Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for different labels #93

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions anipose/anipose.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
'pose_2d_projected': 'pose-2d-proj',
'pose_3d': 'pose-3d',
'pose_3d_filter': 'pose-3d-filtered',
'pose_3d_ext': 'h5',
'videos_labeled_2d': 'videos-labeled',
'videos_labeled_2d_filter': 'videos-labeled-filtered',
'calibration_videos': 'calibration',
Expand Down Expand Up @@ -131,7 +132,7 @@ def tracking_errors(config, scorer=None):
from .tracking_errors import get_tracking_errors
click.echo('Comparing tracking to labeled data...')
get_tracking_errors(config, scorer)

@cli.command()
@pass_config
def analyze(config):
Expand Down Expand Up @@ -236,7 +237,7 @@ def label_2d_proj(config):
from .label_videos_proj import label_proj_all
click.echo('Making 2D videos from 3D projections...')
label_proj_all(config)

@cli.command()
@pass_config
def label_2d(config):
Expand Down Expand Up @@ -322,7 +323,7 @@ def run_data(config):
from .filter_3d import filter_pose_3d_all
click.echo('Filtering 3D points...')
filter_pose_3d_all(config)

click.echo('Computing angles...')
compute_angles_all(config)

Expand Down
52 changes: 40 additions & 12 deletions anipose/triangulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os.path
import pandas as pd
import toml
import pickle
from numpy import array as arr
from glob import glob
from scipy import optimize
Expand Down Expand Up @@ -81,12 +82,14 @@ def correct_coordinate_frame(config, all_points_3d, bodyparts):
def load_pose2d_fnames(fname_dict, offsets_dict=None, cam_names=None):
if cam_names is None:
cam_names = sorted(fname_dict.keys())

pose_names = [fname_dict[cname] for cname in cam_names]

if offsets_dict is None:
offsets_dict = dict([(cname, (0,0)) for cname in cam_names])

datas = []
joints_all = []
for ix_cam, (cam_name, pose_name) in \
enumerate(zip(cam_names, pose_names)):
dlabs = pd.read_hdf(pose_name)
Expand All @@ -104,19 +107,21 @@ def load_pose2d_fnames(fname_dict, offsets_dict=None, cam_names=None):
dlabs.loc[:, (joint, 'y')] += dy

datas.append(dlabs)
joints_all += joint_names

joint_names_unique = np.unique(joints_all)
n_cams = len(cam_names)
n_joints = len(joint_names)
n_joints = len(joint_names_unique)
n_frames = min([d.shape[0] for d in datas])

# frame, camera, bodypart, xy
points = np.full((n_cams, n_frames, n_joints, 2), np.nan, 'float')
scores = np.full((n_cams, n_frames, n_joints), np.zeros(1), 'float')#initialise as zeros, instead of NaN, makes more sense?
scores = np.full((n_cams, n_frames, n_joints), np.zeros(1), 'float')#initialise as zeros, instead of NaN, makes more sense?

for cam_ix, dlabs in enumerate(datas):
for joint_ix, joint_name in enumerate(joint_names):
for joint_ix, joint_name in enumerate(joint_names_unique):
try:
points[cam_ix, :, joint_ix] = np.array(dlabs.loc[:, (joint_name, ('x', 'y'))])[:n_frames]
points[cam_ix, :, joint_ix] = np.array(dlabs.loc[:, (joint_name, ('x', 'y'))])[:n_frames]
scores[cam_ix, :, joint_ix] = np.array(dlabs.loc[:, (joint_name, ('likelihood'))])[:n_frames].ravel()
except KeyError:
pass
Expand All @@ -125,7 +130,7 @@ def load_pose2d_fnames(fname_dict, offsets_dict=None, cam_names=None):
'cam_names': cam_names,
'points': points,
'scores': scores,
'bodyparts': joint_names
'bodyparts': joint_names_unique
}


Expand Down Expand Up @@ -290,7 +295,22 @@ def triangulate(config,

dout['fnum'] = np.arange(n_frames)

dout.to_csv(output_fname, index=False)
if output_fname.endswith('.csv'):
dout.to_csv(output_fname, index=False)
else:
with open(output_fname, 'wb') as f: pickle.dump(dout, f)

print(f'Triangulated pose is saved at: {output_fname}')


def get_camera_names_from_calib(calib_folder):
""" Get camera names from the calibration folder."""
calib_fname = os.path.join(calib_folder, 'calibration.toml')
master_dict = toml.load(calib_fname)
return [
camera_dict['name']
for camera, camera_dict in master_dict.items() if 'cam' in camera
]


def process_session(config, session_path):
Expand All @@ -299,6 +319,10 @@ def process_session(config, session_path):
pipeline_pose = config['pipeline']['pose_2d']
pipeline_pose_filter = config['pipeline']['pose_2d_filter']
pipeline_3d = config['pipeline']['pose_3d']
output_ext = config['pipeline']['pose_3d_ext']

if not output_ext in ['csv', 'h5', 'pkl']:
raise ValueError('The output extension should be csv, h5 or pkl!')

calibration_path = find_calibration_folder(config, session_path)
if calibration_path is None:
Expand All @@ -313,7 +337,12 @@ def process_session(config, session_path):
video_folder = os.path.join(session_path, pipeline_videos_raw)
output_folder = os.path.join(session_path, pipeline_3d)

pose_files = glob(os.path.join(pose_folder, '*.h5'))
camera_names = get_camera_names_from_calib(calib_folder)

pose_files = []

for cam_name in camera_names:
pose_files.append(glob(os.path.join(pose_folder, f'*{cam_name}*.h5'))[0])

cam_videos = defaultdict(list)

Expand All @@ -332,21 +361,20 @@ def process_session(config, session_path):
cam_names = [get_cam_name(config, f) for f in fnames]
fname_dict = dict(zip(cam_names, fnames))

output_fname = os.path.join(output_folder, name + '.csv')

output_fname = os.path.join(output_folder, name + f'pose3d.{output_ext}')
print(output_fname)

if os.path.exists(output_fname):
print(f'Triangulation file already exists at: {output_fname}')
continue


try:
triangulate(config,
calib_folder, video_folder, pose_folder,
fname_dict, output_fname)
except ValueError:
import traceback, sys
traceback.print_exc(file=sys.stdout)


triangulate_all = make_process_fun(process_session)