-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
aca3ba0
commit 644af4b
Showing
7 changed files
with
182 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
#!/usr/bin/env python | ||
import rospy | ||
import csv | ||
import numpy as np | ||
from people_tracking_v2.msg import DetectionArray | ||
from std_msgs.msg import String | ||
from collections import namedtuple | ||
|
||
# Named tuple for storing target information | ||
Target = namedtuple("Target", ["nr_batch", "time", "colour_vector"]) | ||
|
||
class TargetMemoryNode: | ||
def __init__(self): | ||
# Initialize the ROS node | ||
rospy.init_node('target_memory', anonymous=True) | ||
|
||
# Subscriber to HoC detections | ||
self.subscriber_hoc = rospy.Subscriber('/hero/HoC', DetectionArray, self.hoc_callback) | ||
|
||
# Publisher for debug information or status updates | ||
self.publisher_debug = rospy.Publisher('/hero/target_memory/debug', String, queue_size=10) | ||
|
||
# Dictionary to store targets with their batch number as key | ||
self.targets = {} | ||
|
||
# Load saved HoC data | ||
self.load_hoc_data() | ||
|
||
rospy.spin() | ||
|
||
def load_hoc_data(self): | ||
"""Load the saved HoC data from the CSV file.""" | ||
try: | ||
with open('hoc_data.csv', 'r') as csvfile: | ||
csv_reader = csv.reader(csvfile) | ||
for row in csv_reader: | ||
nr_batch = int(row[0]) | ||
time = rospy.Time(float(row[1])) | ||
colour_vector = list(map(float, row[2:])) | ||
self.targets[nr_batch] = Target(nr_batch=nr_batch, time=time, colour_vector=colour_vector) | ||
rospy.loginfo("Loaded HoC data from hoc_data.csv") | ||
except Exception as e: | ||
rospy.logerr(f"Failed to load HoC data: {e}") | ||
|
||
def hoc_callback(self, msg): | ||
"""Callback function to handle new HoC detections.""" | ||
for detection in msg.detections: | ||
similarity_score = self.compute_similarity(detection) | ||
self.update_target(detection, similarity_score) | ||
|
||
def update_target(self, detection, similarity_score): | ||
"""Update or add a target based on the detection.""" | ||
batch_number = detection.nr_batch | ||
colour_vector = detection.colour_vector | ||
|
||
if batch_number in self.targets: | ||
# Update existing target | ||
existing_target = self.targets[batch_number] | ||
updated_target = Target(nr_batch=batch_number, time=existing_target.time, colour_vector=colour_vector) | ||
self.targets[batch_number] = updated_target | ||
rospy.loginfo(f"Updated target batch {batch_number} with similarity score {similarity_score:.2f}") | ||
else: | ||
# Add new target | ||
new_target = Target(nr_batch=batch_number, time=rospy.Time.now(), colour_vector=colour_vector) | ||
self.targets[batch_number] = new_target | ||
rospy.loginfo(f"Added new target batch {batch_number} with similarity score {similarity_score:.2f}") | ||
|
||
# Publish debug information | ||
debug_msg = String() | ||
debug_msg.data = f"Total targets: {len(self.targets)} | Last similarity score: {similarity_score:.2f}" | ||
self.publisher_debug.publish(debug_msg) | ||
|
||
def compute_similarity(self, detection): | ||
"""Compute the similarity score between the current detection and saved targets.""" | ||
max_similarity = 0 | ||
current_vector = np.array(detection.colour_vector) | ||
|
||
for target in self.targets.values(): | ||
saved_vector = np.array(target.colour_vector) | ||
similarity = np.dot(current_vector, saved_vector) / (np.linalg.norm(current_vector) * np.linalg.norm(saved_vector)) | ||
max_similarity = max(max_similarity, similarity) | ||
|
||
return max_similarity | ||
|
||
if __name__ == '__main__': | ||
try: | ||
TargetMemoryNode() | ||
except rospy.ROSInterruptException: | ||
pass |
File renamed without changes.
34 changes: 34 additions & 0 deletions
34
people_tracking_v2/src/people_tracking/launchers/save-HoC-data.launch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
<!-- <arg name="laptop" default="True" /> --> | ||
<!-- <arg name="depth_camera" default="False" /> --> | ||
<!-- <arg name="save_data" default="False"/> --> | ||
|
||
<launch> | ||
<node | ||
pkg="cv_camera" | ||
type="cv_camera_node" | ||
name="Webcam" | ||
output="screen" | ||
/> | ||
|
||
<node | ||
pkg="people_tracking_v2" | ||
type="yolo_seg.py" | ||
name="Image_seg" | ||
output="screen" | ||
/> | ||
|
||
<node | ||
pkg="people_tracking_v2" | ||
type="HoC.py" | ||
name="HoC" | ||
output="screen" | ||
/> | ||
|
||
<node | ||
pkg="people_tracking_v2" | ||
type="save-HoC-data.py" | ||
name="Operator_HoC" | ||
output="screen" | ||
/> | ||
|
||
</launch> |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#!/usr/bin/env python | ||
|
||
import rospy | ||
import numpy as np | ||
from people_tracking_v2.msg import SegmentedImages # Custom message for batch segmented images | ||
from cv_bridge import CvBridge, CvBridgeError | ||
import os | ||
import sys | ||
|
||
# Add the scripts directory to the Python path | ||
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'scripts')) | ||
|
||
from HoC import HoCNode # Import the HoCNode class | ||
|
||
class SaveHoCDataNode: | ||
def __init__(self): | ||
rospy.init_node('save_hoc_data_node', anonymous=True) | ||
|
||
self.bridge = CvBridge() | ||
self.segmented_images_sub = rospy.Subscriber('/segmented_images', SegmentedImages, self.segmented_images_callback) | ||
|
||
# File to save HoC data | ||
self.hoc_data_file = os.path.expanduser('~/hoc_data/hoc_data.npz') | ||
|
||
# Instantiate the HoCNode for HoC computation without initializing the node again | ||
self.hoc_node = HoCNode(initialize_node=False) | ||
|
||
rospy.spin() | ||
|
||
def segmented_images_callback(self, msg): | ||
rospy.loginfo(f"Received batch of {len(msg.images)} segmented images") | ||
all_hue_histograms = [] | ||
all_sat_histograms = [] | ||
for i, segmented_image_msg in enumerate(msg.images): | ||
try: | ||
segmented_image = self.bridge.imgmsg_to_cv2(segmented_image_msg, "bgr8") | ||
hoc_hue, hoc_sat = self.hoc_node.compute_hoc(segmented_image) | ||
all_hue_histograms.append(hoc_hue) | ||
all_sat_histograms.append(hoc_sat) | ||
except CvBridgeError as e: | ||
rospy.logerr(f"Failed to convert segmented image: {e}") | ||
|
||
# Save all histograms in a single .npz file | ||
np.savez(self.hoc_data_file, hue=all_hue_histograms, sat=all_sat_histograms) | ||
rospy.loginfo(f'Saved HoC data to {self.hoc_data_file}') | ||
|
||
if __name__ == '__main__': | ||
try: | ||
SaveHoCDataNode() | ||
except rospy.ROSInterruptException: | ||
pass |