Skip to content

Commit

Permalink
Working comparison node
Browse files Browse the repository at this point in the history
  • Loading branch information
Miguelmelon committed May 28, 2024
1 parent 125981c commit 9b32631
Showing 1 changed file with 67 additions and 19 deletions.
86 changes: 67 additions & 19 deletions people_tracking_v2/scripts/comparison_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,36 @@

import rospy
import numpy as np
from people_tracking_v2.msg import HoCVector # Import the custom message type
from people_tracking_v2.msg import HoCVector, BodySize # Import the custom message types
from std_msgs.msg import String
import os

class CompareHoCDataNode:
class ComparisonNode:
def __init__(self):
# Initialize the ROS node
rospy.init_node('compare_hoc_data', anonymous=True)
rospy.init_node('comparison_node', anonymous=True)

# Subscriber to HoC vectors
# Subscribers to HoC vectors and BodySize
self.subscriber_hoc = rospy.Subscriber('/hoc_vectors', HoCVector, self.hoc_callback)
self.subscriber_pose = rospy.Subscriber('/pose_distances', BodySize, self.pose_callback)

# Publisher for debug information or status updates
self.publisher_debug = rospy.Publisher('/comparison/debug', String, queue_size=10)

# Load saved HoC data
# Load saved HoC and Pose data
self.hoc_data_file = os.path.expanduser('~/hoc_data/hoc_data.npz')
self.pose_data_file = os.path.expanduser('~/pose_data/pose_data.npz')
self.load_hoc_data()
self.load_pose_data()

# Initialize storage for the latest incoming data
self.latest_hoc_vectors = []
self.latest_pose_data = None

rospy.spin()

def load_hoc_data(self):
"""Load the saved HoC data from the .npz file."""
"""Load the saved HoC data from the .npz file (HoC)."""
if os.path.exists(self.hoc_data_file):
data = np.load(self.hoc_data_file)
self.saved_hue = data['hue'][0]
Expand All @@ -35,20 +42,61 @@ def load_hoc_data(self):
self.saved_hue = None
self.saved_sat = None

def load_pose_data(self):
"""Load the saved Pose data from the .npz file (Pose)."""
if os.path.exists(self.pose_data_file):
data = np.load(self.pose_data_file)
self.saved_pose_data = {
'left_shoulder_hip_distance': data['left_shoulder_hip_distance'],
'right_shoulder_hip_distance': data['right_shoulder_hip_distance']
}
rospy.loginfo(f"Loaded Pose data from {self.pose_data_file}")
else:
rospy.logerr(f"Pose data file {self.pose_data_file} not found")
self.saved_pose_data = None

def hoc_callback(self, msg):
"""Callback function to handle new HoC detections."""
"""Callback function to handle new HoC detections (HoC)."""
if self.saved_hue is None or self.saved_sat is None:
rospy.logerr("No saved HoC data available for comparison")
return

self.latest_hoc_vectors.append(msg)
self.compare_data()

def pose_callback(self, msg):
"""Callback function to handle new BodySize data (Pose)."""
self.latest_pose_data = msg
self.compare_data()

def compare_data(self):
"""Compare HoC and pose data if both are available (General)."""
if not self.latest_hoc_vectors or self.latest_pose_data is None or self.saved_pose_data is None:
return

for hoc_vector in self.latest_hoc_vectors:
# Compare HoC data
hue_vector = hoc_vector.hue_vector
sat_vector = hoc_vector.sat_vector
hoc_distance_score = self.compute_hoc_distance_score(hue_vector, sat_vector)
rospy.loginfo(f"HoC Distance score for detection: {hoc_distance_score}")

# Compare pose data
left_shoulder_hip_distance = self.latest_pose_data.left_shoulder_hip_distance
right_shoulder_hip_distance = self.latest_pose_data.right_shoulder_hip_distance
left_shoulder_hip_saved = np.mean(self.saved_pose_data['left_shoulder_hip_distance'])
right_shoulder_hip_saved = np.mean(self.saved_pose_data['right_shoulder_hip_distance'])

left_distance = self.compute_distance(left_shoulder_hip_distance, left_shoulder_hip_saved)
right_distance = self.compute_distance(right_shoulder_hip_distance, right_shoulder_hip_saved)
pose_distance_score = (left_distance + right_distance) / 2
rospy.loginfo(f"Pose Distance score for detection: {pose_distance_score}")

hue_vector = msg.hue_vector
sat_vector = msg.sat_vector
distance_score = self.compute_distance_score(hue_vector, sat_vector)
rospy.loginfo(f"Distance score for detection: {distance_score:.2f}")
self.publish_debug_info(distance_score)
# Publish debug information
self.publish_debug_info(hoc_distance_score, pose_distance_score)

def compute_distance_score(self, hue_vector, sat_vector):
"""Compute the distance score between the current detection and saved data."""
def compute_hoc_distance_score(self, hue_vector, sat_vector):
"""Compute the distance score between the current detection and saved data (HoC)."""
hue_vector = np.array(hue_vector)
sat_vector = np.array(sat_vector)

Expand All @@ -58,17 +106,17 @@ def compute_distance_score(self, hue_vector, sat_vector):
return (hue_distance + sat_distance) / 2

def compute_distance(self, vector1, vector2):
"""Compute the Euclidean distance between two vectors."""
"""Compute the Euclidean distance between two vectors (General)."""
return np.linalg.norm(vector1 - vector2)

def publish_debug_info(self, distance_score):
"""Publish debug information about the current comparison."""
def publish_debug_info(self, hoc_distance_score, pose_distance_score):
"""Publish debug information about the current comparison (General)."""
debug_msg = String()
debug_msg.data = f"Distance score: {distance_score:.2f}"
debug_msg.data = f"HoC Distance score: {hoc_distance_score}, Pose Distance score: {pose_distance_score}"
self.publisher_debug.publish(debug_msg)

if __name__ == '__main__':
try:
CompareHoCDataNode()
ComparisonNode()
except rospy.ROSInterruptException:
pass

0 comments on commit 9b32631

Please sign in to comment.