Skip to content

Commit

Permalink
Initial attempt KF
Browse files Browse the repository at this point in the history
  • Loading branch information
Miguelmelon committed May 25, 2024
1 parent 3cc06ab commit 79ab10b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
26 changes: 26 additions & 0 deletions people_tracking_v2/scripts/yolo_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sensor_msgs.msg import Image
from people_tracking_v2.msg import Detection, DetectionArray, SegmentedImages
from cv_bridge import CvBridge, CvBridgeError
from kalman_filter import KalmanFilterCV # Import the Kalman Filter class

class YoloSegNode:
def __init__(self):
Expand All @@ -16,6 +17,10 @@ def __init__(self):
self.image_sub = rospy.Subscriber("/Webcam/image_raw", Image, self.image_callback)
self.segmented_images_pub = rospy.Publisher("/segmented_images", SegmentedImages, queue_size=10)
self.bounding_box_image_pub = rospy.Publisher("/bounding_box_image", Image, queue_size=10)
self.detection_pub = rospy.Publisher("/hero/predicted_detections", DetectionArray, queue_size=10)

# Initialize the Kalman Filter
self.kalman_filters = {}

def image_callback(self, data):
try:
Expand Down Expand Up @@ -60,6 +65,24 @@ def image_callback(self, data):

# Draw bounding boxes and labels on the bounding_box_image
x1, y1, x2, y2 = map(int, box)
x_center = (x1 + x2) / 2
y_center = (y1 + y2) / 2

# Initialize or update the Kalman Filter for this detection
if i not in self.kalman_filters:
self.kalman_filters[i] = KalmanFilterCV()
kalman_filter = self.kalman_filters[i]

# Update the Kalman Filter with the new measurement
kalman_filter.update(np.array([[x_center], [y_center]]))

# Predict the next position
kalman_filter.predict()
x_pred, y_pred = kalman_filter.get_state()[:2]

# Draw predicted position
cv2.circle(bounding_box_image, (int(x_pred), int(y_pred)), 5, (255, 0, 0), -1)

color = (0, 255, 0) # Set color for bounding boxes
thickness = 3
label_text = f'#{i+1} {int(label)}: {score:.2f}'
Expand Down Expand Up @@ -90,6 +113,9 @@ def image_callback(self, data):
except CvBridgeError as e:
rospy.logerr(e)

# Publish predicted detections
self.detection_pub.publish(detection_array)

def main():
rospy.init_node('yolo_seg_node', anonymous=True)
yolo_node = YoloSegNode()
Expand Down
34 changes: 34 additions & 0 deletions people_tracking_v2/src/people_tracking/kalman_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import numpy as np

class KalmanFilterCV:
def __init__(self, dt=0.1):
# Kalman Filter parameters
self.dt = dt # Time step
self.A = np.array([[1, 0, self.dt, 0],
[0, 1, 0, self.dt],
[0, 0, 1, 0],
[0, 0, 0, 1]]) # State transition matrix
self.H = np.array([[1, 0, 0, 0],
[0, 1, 0, 0]]) # Observation matrix
self.P = np.eye(4) # Initial covariance matrix
self.Q = np.eye(4) * 0.01 # Process noise covariance
self.R = np.eye(2) * 0.1 # Measurement noise covariance
self.x = np.zeros((4, 1)) # Initial state [x, y, vx, vy]
self.P = np.eye(4) # Initial covariance matrix

def predict(self):
# Predict the next state
self.x = np.dot(self.A, self.x)
self.P = np.dot(np.dot(self.A, self.P), self.A.T) + self.Q

def update(self, z):
# Update the state with the measurement
y = z - np.dot(self.H, self.x) # Measurement residual
S = np.dot(self.H, np.dot(self.P, self.H.T)) + self.R # Residual covariance
K = np.dot(np.dot(self.P, self.H.T), np.linalg.inv(S)) # Kalman gain
self.x = self.x + np.dot(K, y)
self.P = self.P - np.dot(np.dot(K, self.H), self.P)

def get_state(self):
# Return the current state
return self.x.flatten().tolist()

0 comments on commit 79ab10b

Please sign in to comment.