Skip to content

Commit

Permalink
added working match and localize script
Browse files Browse the repository at this point in the history
  • Loading branch information
digiamm committed Jan 28, 2025
1 parent d11becb commit e2f22b7
Showing 1 changed file with 349 additions and 0 deletions.
349 changes: 349 additions & 0 deletions match_and_localize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,349 @@
import os
import argparse
import numpy as np
from scipy.spatial.transform import Rotation as R
import pycolmap
from pathlib import Path
from hloc import (
extract_features,
match_features,
pairs_from_exhaustive,
visualization
)
from hloc.utils import viz_3d
from hloc.localize_sfm import QueryLocalizer, pose_from_cluster

import matplotlib.pyplot as plt


# Preset hloc configurations, selected by name: `feature_conf` is passed to
# extract_features.main for the query images and `matcher_conf` is passed to
# match_features.main for the query/reference pairs.
feature_conf = extract_features.confs["superpoint_max"]
matcher_conf = match_features.confs["superpoint+lightglue"]


def numpy2rigid3d(cam_in_world):
    """
    Wrap a numpy transformation matrix in a pycolmap.Rigid3d.
    Args:
        cam_in_world (np.ndarray): Transformation matrix. Only its first three
            rows are forwarded to pycolmap (presumably a 4x4 homogeneous
            matrix -- the caller in this file passes a 4x4 array).
    Returns:
        pycolmap.Rigid3d: Object constructed from the top three rows of the input.
    """
    top_rows = cam_in_world[:3, :]
    return pycolmap.Rigid3d(top_rows)

class InvalidPoseLineError(Exception):
    """Custom exception to handle invalid pose lines."""

def parse_poses_file(input_file):
    """
    Parses a text file containing image names and SE3 poses (quaternion + translation) in COLMAP style
    to create transformation matrices (Tcw).
    Args:
        input_file (str or Path): Path to the input text file.
            Each non-empty line should be formatted as:
            <image_name> <qw> <qx> <qy> <qz> <tx> <ty> <tz>
    Returns:
        poses_dict (dict): Dictionary with image names as keys and the corresponding
            4x4 transformation matrices (Tcw) as values.
    Raises:
        InvalidPoseLineError: If a non-empty line does not contain exactly 8 fields.
        ValueError: If a quaternion or translation field is not a valid float.
    """
    input_file = Path(input_file)
    poses_dict = dict()

    with input_file.open("r") as f:
        for line in f:
            parts = line.split()
            if not parts:
                # tolerate blank lines (e.g. an extra empty line at the end of the file)
                continue
            if len(parts) != 8:
                raise InvalidPoseLineError(f"invalid number of elements in line: {line.strip()}")

            img_name = parts[0]
            qw, qx, qy, qz, tx, ty, tz = map(float, parts[1:])

            # the file stores the quaternion scalar-first (qw first), while
            # Rotation.from_quat expects scalar-last order (x, y, z, w)
            rotation = R.from_quat(np.array([qx, qy, qz, qw])).as_matrix()  # 3x3 rotation matrix

            # Create the 4x4 transformation matrix (Tcw)
            Tcw = np.eye(4)
            Tcw[:3, :3] = rotation  # Rotation part
            Tcw[:3, 3] = (tx, ty, tz)  # Translation part

            # Store the transformation matrix in the dictionary
            poses_dict[img_name] = Tcw

    print(f"loaded {len(poses_dict)} poses")
    return poses_dict

class InvalidCameraLineError(Exception):
    """Raised when a line of the camera file cannot be parsed."""

def parse_camera_file(input_file):
    """
    Parses a text file containing image names and intrinsics to create COLMAP camera objects.
    Args:
        input_file (str or Path): Path to the input text file.
            Each non-empty line should be formatted as:
            <image_name> <model> <width> <height> <fx> <fy> <cx> <cy>
    Returns:
        cameras (dict): Dictionary with image names as keys and pycolmap.Camera objects as values.
    Raises:
        InvalidCameraLineError: If a non-empty line does not contain exactly 8 fields.
        ValueError: If width/height are not integers or an intrinsic is not a valid float.
    """
    input_file = Path(input_file)
    cameras = dict()

    with input_file.open("r") as f:
        for line in f:
            parts = line.split()
            if not parts:
                # tolerate blank lines (e.g. an extra empty line at the end of the file)
                continue
            if len(parts) != 8:
                raise InvalidCameraLineError(f"invalid number of elements in line: {line.strip()}")

            img_name, model, width, height, fx, fy, cx, cy = parts

            width, height = int(width), int(height)
            fx, fy, cx, cy = map(float, (fx, fy, cx, cy))

            # the model name is forwarded to pycolmap as given in the file
            camera = pycolmap.Camera(
                model=model,
                width=width,
                height=height,
                params=[fx, fy, cx, cy],
            )
            cameras[img_name] = camera

    print(f"loaded {len(cameras)} cameras")
    return cameras

class MismatchedBufferError(Exception):
    """Raised when a list of names and a dictionary's keys do not line up."""

def check_buffers(list_buffer, dict_buffer):
    """
    Verify that a list of strings and a dictionary describe the same set of names.
    Args:
        list_buffer (list of str): Names expected to be present.
        dict_buffer (dict): Dictionary whose keys are compared against the list.
    Raises:
        MismatchedBufferError: If the two containers differ in size, or if an
            entry of the list is not a key of the dictionary.
    """
    # sizes must agree before looking at the individual names
    if len(dict_buffer) != len(list_buffer):
        raise MismatchedBufferError("List and dictionary have mismatched lengths.")

    # report the first list entry that has no matching dictionary key
    not_found = object()
    missing = next((name for name in list_buffer if name not in dict_buffer), not_found)
    if missing is not not_found:
        raise MismatchedBufferError(f"String '{missing}' in the list does not match any key in the dictionary.")

def get_image_filenames(image_dir, image_extensions=('.png', '.jpg', '.jpeg', '.tiff')):
    """
    List the images found directly inside a directory, in sorted order.
    Args:
        image_dir: Path to the directory containing images.
        image_extensions: Tuple of accepted (lower-case) suffixes; the comparison
            lower-cases each file's suffix first (default: ('.png', '.jpg', '.jpeg', '.tiff')).
    Returns:
        A tuple of two lists, aligned index by index:
        - `queries_fn`: Paths relative to `image_dir` (i.e. the filenames).
        - `queries_fullp_fn`: The same entries as full paths.
    """
    matched = [entry for entry in Path(image_dir).iterdir() if entry.suffix.lower() in image_extensions]
    matched.sort()

    queries_fn = []
    queries_fullp_fn = []
    for entry in matched:
        queries_fn.append(entry.relative_to(image_dir).as_posix())
        queries_fullp_fn.append(entry.as_posix())

    return queries_fn, queries_fullp_fn

def _rotation_angle(rot):
    """Return the rotation angle (in radians) encoded by a 3x3 rotation matrix."""
    cos_theta = (np.trace(rot) - 1) / 2
    # round-off can push the cosine slightly outside [-1, 1], where arccos yields NaN
    return np.arccos(np.clip(cos_theta, -1.0, 1.0))


def main(ref_images_path,
         query_images_path,
         sfm_model_path,
         ref_features_fn,
         cameras_fn,
         poses_fn,
         output_path,
         debug=True):
    """
    Localize every query image against an SfM model and report the pose errors.

    Features are extracted for the query images (appended to `ref_features_fn`),
    each query is matched exhaustively against all reference images, localized
    with hloc, and compared with its ground-truth pose. One line per query,
    "<image_name> <translation_error> <rotation_error_deg>", is written to
    <output_path>/error-poses.txt; queries whose localization fails are written
    with "nan nan".
    Args:
        ref_images_path (str or Path): Folder containing the reference images.
        query_images_path (str or Path): Folder containing the query images.
        sfm_model_path (str or Path): Folder containing the COLMAP reconstruction.
        ref_features_fn (str or Path): Features file; query features are written into it.
        cameras_fn (str or Path): Per-image intrinsics file (see parse_camera_file).
        poses_fn (str or Path): Per-image ground-truth poses file (see parse_poses_file).
        output_path (str or Path): Folder for matches, pairs, errors and debug plots.
        debug (bool): If True, save a match visualization per query under
            <output_path>/results and show a 3D plot of the cameras at the end.
    Raises:
        MismatchedBufferError: If the query images do not correspond one-to-one
            to the entries of the camera or pose files.
    """
    images = Path(ref_images_path)
    query_images = Path(query_images_path)
    sfm_dir = Path(sfm_model_path)
    features = Path(ref_features_fn)
    outputs = Path(output_path)

    # create output directory if it does not exist
    outputs.mkdir(parents=True, exist_ok=True)
    result_outputs = outputs / "results"
    if debug:
        result_outputs.mkdir(parents=True, exist_ok=True)

    # setup paths
    matches_fn = outputs / "matches.h5"
    loc_pairs_fn = outputs / "pairs-loc.txt"
    error_poses_fn = outputs / "error-poses.txt"

    # prepare ref and query image paths
    references_fn, references_fullp_fn = get_image_filenames(images)
    queries_fn, _ = get_image_filenames(query_images)

    # parse cameras and poses
    cameras = parse_camera_file(cameras_fn)
    poses = parse_poses_file(poses_fn)
    check_buffers(queries_fn, cameras)
    check_buffers(queries_fn, poses)

    # TODO: optionally restrict queries/cameras/poses to a small subset for quick debugging

    # extract features for the query images
    extract_features.main(
        feature_conf, query_images, image_list=queries_fn, feature_path=features, overwrite=True
    )

    # generate pairings and match features
    pairs_from_exhaustive.main(loc_pairs_fn, image_list=queries_fn, ref_list=references_fullp_fn)
    match_features.main(
        matcher_conf, loc_pairs_fn, features=features, matches=matches_fn, overwrite=True
    )

    # read 3D model
    model = pycolmap.Reconstruction(sfm_dir)

    # reference images are looked up in the model by `ref_images_path / filename`
    ref_ids = [model.find_image_with_name(images / r).image_id for r in references_fn]
    localizer = QueryLocalizer(model, {"estimation": {"ransac": {"max_error": 12}}})

    fig3d = None
    if debug:
        fig3d = viz_3d.init_figure()
        viz_3d.plot_reconstruction(
            fig3d, model, color="rgba(255,0,0,0.5)", name="mapping", points_rgb=True
        )

    # write errors to text file; the context manager closes it even if a query raises
    with open(error_poses_fn, "w") as error_pose_file:
        for query in queries_fn:
            cam_params = cameras[query]
            # localize and get pose
            ret, log = pose_from_cluster(localizer, query, cam_params, ref_ids, features, matches_fn)
            if ret is None:
                # pose estimation failed: record the failure and move on
                print(f"{query} - localization failed.")
                error_pose_file.write(str(query) + " nan nan\n")
                continue

            # print inliers count
            print(f'{query} - found {ret["num_inliers"]}/{len(ret["inlier_mask"])} inlier correspondences.')

            # get localized pose
            cam_from_world = ret["cam_from_world"]

            # make a 4x4 matrix, inverting here for the later matrix product
            cam_in_world_estimate = np.vstack([cam_from_world.inverse().matrix(), np.array([0, 0, 0, 1])])

            # get GT world in camera
            Tcw = poses[query]

            # GT cam_from_world times estimated world_from_cam: identity for a perfect estimate
            T_diff = Tcw @ cam_in_world_estimate
            angle_error_rad = _rotation_angle(T_diff[0:3, 0:3])
            t_diff = np.linalg.norm(T_diff[0:3, 3])

            # query img name, t_diff [m], r_diff [deg]
            error_pose_file.write(str(query) + " " + str(t_diff) + " " + str(np.rad2deg(angle_error_rad)) + "\n")

            # TODO if num inliers is lower than a threshold loc fail
            if debug:
                visualization.visualize_loc_from_log(query_images, query, log, model)
                fig = plt.gcf()
                output_file = result_outputs / f"{query}_results.png"
                fig.savefig(output_file, bbox_inches="tight")
                # release the figure so memory does not grow with the number of queries
                plt.close(fig)

                # ground-truth camera
                pose = pycolmap.Image(cam_from_world=numpy2rigid3d(Tcw))
                viz_3d.plot_camera_colmap(
                    fig3d, pose, cam_params, color="rgba(0,0,255,0.5)", name=query, fill=True
                )
                # estimated camera, only if the error is small enough to be plotted
                if t_diff < 5.0:
                    pose = pycolmap.Image(cam_from_world=cam_from_world)
                    viz_3d.plot_camera_colmap(
                        fig3d, pose, cam_params, color="rgba(0,255,0,0.5)", name=query, fill=True
                    )

    if debug:
        fig3d.show()

if __name__ == "__main__":
    # Command-line entry point: parse the paths, work on a temporary copy of the
    # reference features file (main() writes query features into it), and remove
    # that copy afterwards.
    parser = argparse.ArgumentParser(description="Localize query images using an SfM model and his features.")
    parser.add_argument(
        "--ref_images_path", type=str,
        help="path to the folder containing reference images",
        required=True
    )
    parser.add_argument(
        "--query_images_path", type=str,
        help="path to the folder containing query images",
        required=True
    )
    parser.add_argument(
        "--sfm_model_path", type=str,
        help="path to the folder containing the SFM model (COLMAP style)",
        required=True
    )
    parser.add_argument(
        "--ref_features_fn", type=str,
        help="path to the file containing reference features",
        required=True
    )
    parser.add_argument(
        "--cameras_fn", type=str,
        help="path to the file containing COLMAP camera intrinsics",
        required=True
    )
    parser.add_argument(
        "--poses_fn", type=str,
        help="path to the file containing GT camera poses COLMAP style",
        required=True
    )
    parser.add_argument(
        "--output_path", type=str,
        help="path to the folder for saving outputs",
        required=True
    )
    args = parser.parse_args()

    import shutil
    input_path = Path(args.ref_features_fn)
    if not input_path.is_file():
        raise FileNotFoundError(f"the file {args.ref_features_fn} does not exist.")

    # determine the new file's path
    feature_fn = input_path.parent / "localization_features.h5"

    # copy the file so the original reference features are left untouched
    shutil.copy(input_path, feature_fn)
    print(f"file copied from {input_path} to {feature_fn}")

    try:
        main(
            args.ref_images_path,
            args.query_images_path,
            args.sfm_model_path,
            feature_fn,
            args.cameras_fn,
            args.poses_fn,
            args.output_path,
        )
    finally:
        # remove the temporary copy even when localization raises
        feature_fn.unlink()
        print(f"file {feature_fn} removed")



# Example usage (every option is required):
# python ~/source/external/Hierarchical-Localization/match_and_localize.py --ref_images_path scene_reconstruction/images/ --query_images_path labels_directions_100/0/rendering/query_images/ --sfm_model_path scene_reconstruction/sfm/ --ref_features_fn scene_reconstruction/sfm_features.h5 --cameras_fn labels_directions_100/0/rendering/img_nm_to_colmap_cam.txt --poses_fn <path/to/gt_poses.txt> --output_path ~/Desktop/tmp_out

0 comments on commit e2f22b7

Please sign in to comment.