From 2cbc83549439753a0e9f9a33125369230194cb74 Mon Sep 17 00:00:00 2001 From: Ken Haagh Date: Tue, 7 Nov 2023 17:20:27 +0100 Subject: [PATCH] Implementation People_tracker node --- people_tracking/src/UKFclass.py | 104 +++++++++++++++++++ people_tracking/src/colour_check.py | 2 +- people_tracking/src/people_tracker.py | 57 +++++++--- people_tracking/src/person_detection.py | 3 +- people_tracking/test/test_with_webcam.launch | 12 +-- 5 files changed, 154 insertions(+), 24 deletions(-) create mode 100755 people_tracking/src/UKFclass.py diff --git a/people_tracking/src/UKFclass.py b/people_tracking/src/UKFclass.py new file mode 100755 index 0000000..2e43fd4 --- /dev/null +++ b/people_tracking/src/UKFclass.py @@ -0,0 +1,104 @@ +import numpy as np +from filterpy.kalman import MerweScaledSigmaPoints +from filterpy.kalman import UnscentedKalmanFilter +from filterpy.common.discretization import Q_discrete_white_noise +# from numpy.random import randn +# +# import random +# import matplotlib.pyplot as plt + + +class UKF: + + def __init__(self): + self.current_time = 0 + + dt = 0.1 # standard dt + + # Create sigma points + self.points = MerweScaledSigmaPoints(4, alpha=0.1, beta=2.0, kappa=-1) + + self.kf = UnscentedKalmanFilter(dim_x=4, dim_z=2, dt=dt, fx=self.fx, hx=self.hx, points=self.points) + + self.kf.x = np.array([1., 0, 1., 0]) # Initial state + self.kf.P *= 0.2 # Initial uncertainty + + z_std = 0.2 + self.kf.R = np.diag([z_std ** 2, z_std ** 2]) # Measurement noise covariance matrix + + self.kf.Q = Q_discrete_white_noise(dim=2, dt=dt, var= 3 ** 2, block_size=2) # + + # https://filterpy.readthedocs.io/en/latest/kalman/UnscentedKalmanFilter.html + def fx(self, x, dt): + """ Function that returns the state x transformed by the state transition function. (cv model) + Assumption: + * x = [x, vx, z, vz]^T + """ + F = np.array([[1, dt, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, dt], + [0, 0, 0, 1]], dtype=float) + return np.dot(F, x) + + def hx(self, x): + """Measurement function - convert the state into a measurement where measurements are [x_pos, z_pos] """ + return np.array([x[0], x[2]]) + + def predict(self, time): + """ update the kalman filter by predicting value.""" + delta_t = time - self.current_time + self.kf.predict(dt=delta_t) + self.current_time = time + + def update(self, time, z): + """ Update filter with measurements z.""" + self.predict(time) + self.kf.update(z) + + +# # Example with varying dt +# z_std = 0.1 +# zs = [[i + randn() * z_std, i + randn() * z_std] for i in range(50)] # Measurements +# +# start = 0 +# step = 0.1 +# end = 10 +# times = [start + step * i for i in range(int((end - start) / step) + 1)] +# times = [x + random.uniform(0, 0.05) for x in times] +# +# print(times) +# xs = [] +# ys = [] +# xk = [] +# yk = [] +# +# i = -1 +# time_last = 0 +# +# ukf = UKF() +# +# for z in zs: +# i = i + 1 +# # delta_t = times[i] - time_last +# # print(times[i], delta_t) +# ukf.predict(times[i]) +# ukf.update(times[i], z) +# # time_last = times[i] +# # kf.update(z) +# # print(kf.x, 'log-likelihood', kf.log_likelihood) +# xs.append(z[0]) +# ys.append(z[1]) +# xk.append(ukf.kf.x[0]) +# yk.append(ukf.kf.x[2]) +# # print(xk) +# +# fig, ax = plt.subplots() +# # measured = [list(t) for t in zs] +# ax.scatter(xs, ys, s=5) +# ax.plot(xk, yk) +# +# ax.set_xlabel('X-axis') +# ax.set_ylabel('Z-axis') +# +# # Display the plot +# plt.show() diff --git a/people_tracking/src/colour_check.py b/people_tracking/src/colour_check.py index 11addf9..670ffc5 100755 --- a/people_tracking/src/colour_check.py +++ b/people_tracking/src/colour_check.py @@ -78,7 +78,7 @@ def compare_hoc(self, detected_persons): for idx_person, vector in enumerate(person_vectors): distance = self.euclidean(vector, Hoc_detection) if distance < 0.25: - rospy.loginfo(str(idx_person) + " " + str(distance)) + # rospy.loginfo(str(idx_person) + " " + str(distance)) if len(self.HoC_detections) < 5: self.HoC_detections.append(vector) else: diff --git a/people_tracking/src/people_tracker.py b/people_tracking/src/people_tracker.py index 9c06687..bdfcf44 100755 --- a/people_tracking/src/people_tracker.py +++ b/people_tracking/src/people_tracker.py @@ -5,32 +5,59 @@ import numpy as np from cv_bridge import CvBridge +from UKFclass import * + # MSGS -from std_msgs.msg import String +from people_tracking.msg import ColourCheckedTarget +from sensor_msgs.msg import Image NODE_NAME = 'people_tracker' TOPIC_PREFIX = '/hero/' +laptop = True +name_subscriber_RGB = '/hero/head_rgbd_sensor/rgb/image_raw' if not laptop else 'video_frames' + + class PeopleTracker: def __init__(self) -> None: # ROS Initialize rospy.init_node(NODE_NAME, anonymous=True) - self.publisher = rospy.Publisher(TOPIC_PREFIX + 'Location', String, queue_size= 2) - # self.subscriber = rospy.Subscriber(name_subscriber_RGB, Image, self.callback, queue_size = 1) - - rate = rospy.Rate(2) - # Keep publishing the messages until the user interrupts - while not rospy.is_shutdown(): - message = "Hello World" - # rospy.loginfo('Published: ' + message) - # publish the message to the topic - self.publisher.publish(message) - rate.sleep() - # def callback(self, data): - # self.publisher.publish("Hello World") - # rospy.loginfo("send message") + self.subscriber = rospy.Subscriber(TOPIC_PREFIX + 'HoC', ColourCheckedTarget, self.callback, queue_size=1) + # self.publisher = rospy.Publisher(TOPIC_PREFIX + 'Location', String, queue_size= 2) + + self.subscriber_debug = rospy.Subscriber(name_subscriber_RGB, Image, self.plot_tracker, queue_size=1) + self.publisher_debug = rospy.Publisher(TOPIC_PREFIX + 'people_tracker_debug', Image, queue_size=10) + + # Variables + self.ukf = UKF() + + def plot_tracker(self, data): + latest_image = data + bridge = CvBridge() + cv_image = bridge.imgmsg_to_cv2(latest_image, desired_encoding='passthrough') + + self.ukf.predict(float(rospy.get_time())) + x_position = int(self.ukf.kf.x[0]) + # rospy.loginfo('predict: time: ' + str(float(rospy.get_time())) + 'x: ' + str(x_position)) + + x_position = 0 if x_position < 0 else x_position + x_position = 639 if x_position > 639 else x_position + cv2.circle(cv_image, (x_position, 200), 5, (0, 0, 255), -1) + tracker_image = bridge.cv2_to_imgmsg(cv_image, encoding="passthrough") + self.publisher_debug.publish(tracker_image) + + + + def callback(self, data): + x_position = data.x_position + time = data.time + position = [x_position, 0] + rospy.loginfo('time: ' + str(time) + ' x: ' +str(x_position)) + self.ukf.update(time, position) + + if __name__ == '__main__': diff --git a/people_tracking/src/person_detection.py b/people_tracking/src/person_detection.py index 4742872..9b9feb1 100755 --- a/people_tracking/src/person_detection.py +++ b/people_tracking/src/person_detection.py @@ -95,7 +95,7 @@ def process_latest_image(self): image_message = bridge.cv2_to_imgmsg(cropped_image, encoding="passthrough") detected_persons.append(image_message) - x_positions.append((x2-x1)// 2) + x_positions.append(int( x1 + ((x2-x1) / 2))) # Create person_detections msg msg = DetectedPerson() @@ -115,7 +115,6 @@ def process_latest_image(self): def main_loop(self): """ Main loop that makes sure only the latest images are processed""" while not rospy.is_shutdown(): - # self.msg_callback() self.process_latest_image() rospy.sleep(0.001) diff --git a/people_tracking/test/test_with_webcam.launch b/people_tracking/test/test_with_webcam.launch index f4d97aa..38a9ba3 100644 --- a/people_tracking/test/test_with_webcam.launch +++ b/people_tracking/test/test_with_webcam.launch @@ -14,12 +14,12 @@ output="screen" /> - - - - - - +