Skip to content

Commit

Permalink
Implementation People_tracker node
Browse files Browse the repository at this point in the history
  • Loading branch information
KenH2 committed Nov 7, 2023
1 parent 913461d commit 2cbc835
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 24 deletions.
104 changes: 104 additions & 0 deletions people_tracking/src/UKFclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import numpy as np
from filterpy.kalman import MerweScaledSigmaPoints
from filterpy.kalman import UnscentedKalmanFilter
from filterpy.common.discretization import Q_discrete_white_noise
# from numpy.random import randn
#
# import random
# import matplotlib.pyplot as plt


class UKF:

def __init__(self):
self.current_time = 0

dt = 0.1 # standard dt

# Create sigma points
self.points = MerweScaledSigmaPoints(4, alpha=0.1, beta=2.0, kappa=-1)

self.kf = UnscentedKalmanFilter(dim_x=4, dim_z=2, dt=dt, fx=self.fx, hx=self.hx, points=self.points)

self.kf.x = np.array([1., 0, 1., 0]) # Initial state
self.kf.P *= 0.2 # Initial uncertainty

z_std = 0.2
self.kf.R = np.diag([z_std ** 2, z_std ** 2]) # Measurement noise covariance matrix

self.kf.Q = Q_discrete_white_noise(dim=2, dt=dt, var= 3 ** 2, block_size=2) #

# https://filterpy.readthedocs.io/en/latest/kalman/UnscentedKalmanFilter.html
def fx(self, x, dt):
""" Function that returns the state x transformed by the state transition function. (cv model)
Assumption:
* x = [x, vx, z, vz]^T
"""
F = np.array([[1, dt, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, dt],
[0, 0, 0, 1]], dtype=float)
return np.dot(F, x)

def hx(self, x):
"""Measurement function - convert the state into a measurement where measurements are [x_pos, z_pos] """
return np.array([x[0], x[2]])

def predict(self, time):
""" update the kalman filter by predicting value."""
delta_t = time - self.current_time
self.kf.predict(dt=delta_t)
self.current_time = time

def update(self, time, z):
""" Update filter with measurements z."""
self.predict(time)
self.kf.update(z)


# # Example with varying dt
# z_std = 0.1
# zs = [[i + randn() * z_std, i + randn() * z_std] for i in range(50)] # Measurements
#
# start = 0
# step = 0.1
# end = 10
# times = [start + step * i for i in range(int((end - start) / step) + 1)]
# times = [x + random.uniform(0, 0.05) for x in times]
#
# print(times)
# xs = []
# ys = []
# xk = []
# yk = []
#
# i = -1
# time_last = 0
#
# ukf = UKF()
#
# for z in zs:
# i = i + 1
# # delta_t = times[i] - time_last
# # print(times[i], delta_t)
# ukf.predict(times[i])
# ukf.update(times[i], z)
# # time_last = times[i]
# # kf.update(z)
# # print(kf.x, 'log-likelihood', kf.log_likelihood)
# xs.append(z[0])
# ys.append(z[1])
# xk.append(ukf.kf.x[0])
# yk.append(ukf.kf.x[2])
# # print(xk)
#
# fig, ax = plt.subplots()
# # measured = [list(t) for t in zs]
# ax.scatter(xs, ys, s=5)
# ax.plot(xk, yk)
#
# ax.set_xlabel('X-axis')
# ax.set_ylabel('Z-axis')
#
# # Display the plot
# plt.show()
2 changes: 1 addition & 1 deletion people_tracking/src/colour_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def compare_hoc(self, detected_persons):
for idx_person, vector in enumerate(person_vectors):
distance = self.euclidean(vector, Hoc_detection)
if distance < 0.25:
rospy.loginfo(str(idx_person) + " " + str(distance))
# rospy.loginfo(str(idx_person) + " " + str(distance))
if len(self.HoC_detections) < 5:
self.HoC_detections.append(vector)
else:
Expand Down
57 changes: 42 additions & 15 deletions people_tracking/src/people_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,59 @@
import numpy as np
from cv_bridge import CvBridge

from UKFclass import *

# MSGS
from std_msgs.msg import String
from people_tracking.msg import ColourCheckedTarget
from sensor_msgs.msg import Image


NODE_NAME = 'people_tracker'
TOPIC_PREFIX = '/hero/'

laptop = True
name_subscriber_RGB = '/hero/head_rgbd_sensor/rgb/image_raw' if not laptop else 'video_frames'


class PeopleTracker:
def __init__(self) -> None:

# ROS Initialize
rospy.init_node(NODE_NAME, anonymous=True)
self.publisher = rospy.Publisher(TOPIC_PREFIX + 'Location', String, queue_size= 2)
# self.subscriber = rospy.Subscriber(name_subscriber_RGB, Image, self.callback, queue_size = 1)

rate = rospy.Rate(2)
# Keep publishing the messages until the user interrupts
while not rospy.is_shutdown():
message = "Hello World"
# rospy.loginfo('Published: ' + message)
# publish the message to the topic
self.publisher.publish(message)
rate.sleep()
# def callback(self, data):
# self.publisher.publish("Hello World")
# rospy.loginfo("send message")
self.subscriber = rospy.Subscriber(TOPIC_PREFIX + 'HoC', ColourCheckedTarget, self.callback, queue_size=1)
# self.publisher = rospy.Publisher(TOPIC_PREFIX + 'Location', String, queue_size= 2)

self.subscriber_debug = rospy.Subscriber(name_subscriber_RGB, Image, self.plot_tracker, queue_size=1)
self.publisher_debug = rospy.Publisher(TOPIC_PREFIX + 'people_tracker_debug', Image, queue_size=10)

# Variables
self.ukf = UKF()

def plot_tracker(self, data):
latest_image = data
bridge = CvBridge()
cv_image = bridge.imgmsg_to_cv2(latest_image, desired_encoding='passthrough')

self.ukf.predict(float(rospy.get_time()))
x_position = int(self.ukf.kf.x[0])
# rospy.loginfo('predict: time: ' + str(float(rospy.get_time())) + 'x: ' + str(x_position))

x_position = 0 if x_position < 0 else x_position
x_position = 639 if x_position > 639 else x_position
cv2.circle(cv_image, (x_position, 200), 5, (0, 0, 255), -1)
tracker_image = bridge.cv2_to_imgmsg(cv_image, encoding="passthrough")
self.publisher_debug.publish(tracker_image)



def callback(self, data):
x_position = data.x_position
time = data.time
position = [x_position, 0]
rospy.loginfo('time: ' + str(time) + ' x: ' +str(x_position))
self.ukf.update(time, position)




if __name__ == '__main__':
Expand Down
3 changes: 1 addition & 2 deletions people_tracking/src/person_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def process_latest_image(self):
image_message = bridge.cv2_to_imgmsg(cropped_image, encoding="passthrough")

detected_persons.append(image_message)
x_positions.append((x2-x1)// 2)
x_positions.append(int( x1 + ((x2-x1) / 2)))

# Create person_detections msg
msg = DetectedPerson()
Expand All @@ -115,7 +115,6 @@ def process_latest_image(self):
def main_loop(self):
""" Main loop that makes sure only the latest images are processed"""
while not rospy.is_shutdown():
# self.msg_callback()
self.process_latest_image()

rospy.sleep(0.001)
Expand Down
12 changes: 6 additions & 6 deletions people_tracking/test/test_with_webcam.launch
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
output="screen"
/>

<!-- <node -->
<!-- pkg="people_tracking" -->
<!-- type="people_tracker.py" -->
<!-- name="people_tracker" -->
<!-- output="screen" -->
<!-- /> -->
<node
pkg="people_tracking"
type="people_tracker.py"
name="people_tracker"
output="screen"
/>

<!-- only use when working with webcam: -->
<node
Expand Down

0 comments on commit 2cbc835

Please sign in to comment.