Skip to content

Commit

Permalink
Initial comparison method
Browse files Browse the repository at this point in the history
  • Loading branch information
Miguelmelon committed May 26, 2024
1 parent 644af4b commit 1d01bca
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 77 deletions.
3 changes: 2 additions & 1 deletion people_tracking_v2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_message_files(
Detection.msg
DetectionArray.msg
SegmentedImages.msg
HoCVector.msg
)

generate_messages(
Expand All @@ -37,9 +38,9 @@ include_directories(
install(PROGRAMS
scripts/face_recognition_node.py
scripts/pose_estimation_node.py
scripts/yolo.py
scripts/yolo_seg.py
scripts/HoC.py
scripts/comparison_node.py
src/kalman_filter.py
tools/save-HoC-data.py
DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
Expand Down
3 changes: 3 additions & 0 deletions people_tracking_v2/msg/HoCVector.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Header header
float32[] hue_vector
float32[] sat_vector
17 changes: 14 additions & 3 deletions people_tracking_v2/scripts/HoC.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python

import rospy
from people_tracking_v2.msg import SegmentedImages # Custom message for batch segmented images
from people_tracking_v2.msg import SegmentedImages, HoCVector # Custom message for batch segmented images and HoC vectors
from cv_bridge import CvBridge, CvBridgeError
import cv2
import numpy as np
Expand All @@ -14,6 +14,9 @@ def __init__(self, initialize_node=True):
self.bridge = CvBridge()
self.segmented_images_sub = rospy.Subscriber('/segmented_images', SegmentedImages, self.segmented_images_callback)

# Publisher for HoC vectors
self.hoc_vector_pub = rospy.Publisher('/hoc_vectors', HoCVector, queue_size=10)

if initialize_node:
rospy.spin()

Expand All @@ -24,7 +27,7 @@ def segmented_images_callback(self, msg):
segmented_image = self.bridge.imgmsg_to_cv2(segmented_image_msg, "bgr8")
hoc_hue, hoc_sat = self.compute_hoc(segmented_image)
rospy.loginfo(f'Computed HoC for segmented image #{i + 1}')
# You can process hoc_hue and hoc_sat here or pass them to another node
self.publish_hoc_vectors(hoc_hue, hoc_sat)
except CvBridgeError as e:
rospy.logerr(f"Failed to convert segmented image: {e}")

Expand All @@ -47,7 +50,15 @@ def compute_hoc(self, segmented_image):

# Flatten the histograms
return hist_hue.flatten(), hist_sat.flatten()


def publish_hoc_vectors(self, hue_vector, sat_vector):
"""Publish the computed HoC vectors."""
hoc_msg = HoCVector()
hoc_msg.header.stamp = rospy.Time.now()
hoc_msg.hue_vector = hue_vector.tolist()
hoc_msg.sat_vector = sat_vector.tolist()
self.hoc_vector_pub.publish(hoc_msg)

if __name__ == '__main__':
try:
HoCNode()
Expand Down
74 changes: 74 additions & 0 deletions people_tracking_v2/scripts/comparison_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#!/usr/bin/env python

import rospy
import numpy as np
from people_tracking_v2.msg import HoCVector # Import the custom message type
from std_msgs.msg import String
import os

class CompareHoCDataNode:
def __init__(self):
# Initialize the ROS node
rospy.init_node('compare_hoc_data', anonymous=True)

# Subscriber to HoC vectors
self.subscriber_hoc = rospy.Subscriber('/hoc_vectors', HoCVector, self.hoc_callback)

# Publisher for debug information or status updates
self.publisher_debug = rospy.Publisher('/comparison/debug', String, queue_size=10)

# Load saved HoC data
self.hoc_data_file = os.path.expanduser('~/hoc_data/hoc_data.npz')
self.load_hoc_data()

rospy.spin()

def load_hoc_data(self):
"""Load the saved HoC data from the .npz file."""
if os.path.exists(self.hoc_data_file):
data = np.load(self.hoc_data_file)
self.saved_hue = data['hue'][0]
self.saved_sat = data['sat'][0]
rospy.loginfo(f"Loaded HoC data from {self.hoc_data_file}")
else:
rospy.logerr(f"HoC data file {self.hoc_data_file} not found")
self.saved_hue = None
self.saved_sat = None

def hoc_callback(self, msg):
"""Callback function to handle new HoC detections."""
if self.saved_hue is None or self.saved_sat is None:
rospy.logerr("No saved HoC data available for comparison")
return

hue_vector = msg.hue_vector
sat_vector = msg.sat_vector
distance_score = self.compute_distance_score(hue_vector, sat_vector)
rospy.loginfo(f"Distance score for detection: {distance_score:.2f}")
self.publish_debug_info(distance_score)

def compute_distance_score(self, hue_vector, sat_vector):
"""Compute the distance score between the current detection and saved data."""
hue_vector = np.array(hue_vector)
sat_vector = np.array(sat_vector)

hue_distance = self.compute_distance(hue_vector, self.saved_hue)
sat_distance = self.compute_distance(sat_vector, self.saved_sat)

return (hue_distance + sat_distance) / 2

def compute_distance(self, vector1, vector2):
"""Compute the Euclidean distance between two vectors."""
return np.linalg.norm(vector1 - vector2)

def publish_debug_info(self, distance_score):
"""Publish debug information about the current comparison."""
debug_msg = String()
debug_msg.data = f"Distance score: {distance_score:.2f}"
self.publisher_debug.publish(debug_msg)

if __name__ == '__main__':
try:
CompareHoCDataNode()
except rospy.ROSInterruptException:
pass
73 changes: 0 additions & 73 deletions people_tracking_v2/scripts/memory_node.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@
output="screen"
/>

<node
pkg="people_tracking_v2"
type="comparison_node.py"
name="Comparison"
output="screen"
/>

<!-- Uncomment these nodes when ready -->
<!--
<node
Expand Down

0 comments on commit 1d01bca

Please sign in to comment.