Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP, Redesign] Update R2P2 Predictor #271

Open
wants to merge 1 commit into
base: redesign
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion detection_operator_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
'camera_logger', 'multiple_object_logger', 'collision_sensor',
'object_tracker', 'linear_predictor', 'obstacle_finder', 'fusion',
'gnss_sensor', 'imu_sensor', 'lane_invasion_sensor', 'lidar',
'traffic_light_invasion', 'prediction_eval'
'traffic_light_invasion', 'prediction_eval', 'r2p2'
],
help='Operator of choice to test')

Expand Down Expand Up @@ -518,6 +518,31 @@ def main(args):
obstacle_pos_stream = pylot.operator_creator.add_fusion(
pose_stream, obstacles_stream, depth_camera_ingest_stream,
None)
if FLAGS.test_operator == 'r2p2':
time_to_decision_loop_stream = erdos.streams.LoopStream()

obstacles_stream = pylot.operator_creator.add_obstacle_detection(
rgb_camera_ingest_stream, time_to_decision_loop_stream)[0]

time_to_decision_stream = pylot.operator_creator.add_time_to_decision(
pose_stream, obstacles_stream)
time_to_decision_loop_stream.connect_loop(time_to_decision_stream)

obstacles_wo_history_tracking_stream = pylot.operator_creator.add_obstacle_tracking(
obstacles_stream, rgb_camera_ingest_stream,
time_to_decision_stream)

obstacles_tracking_stream = pylot.operator_creator.add_obstacle_location_history(
obstacles_wo_history_tracking_stream,
depth_camera_ingest_stream, pose_stream, rgb_camera_setup)

(point_cloud_stream, notify_lidar_stream,
lidar_setup) = pylot.operator_creator.add_lidar(
transform, vehicle_id_stream, release_sensor_stream)

prediction_stream = pylot.operator_creator.add_r2p2_prediction(
point_cloud_stream, obstacles_tracking_stream,
time_to_decision_loop_stream, lidar_setup)

erdos.run_async()

Expand Down
24 changes: 15 additions & 9 deletions pylot/operator_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,17 +392,23 @@ def add_linear_prediction(tracking_stream: Stream,
return prediction_stream


def add_r2p2_prediction(point_cloud_stream, obstacles_tracking_stream,
time_to_decision_stream, lidar_setup):
def add_r2p2_prediction(point_cloud_stream: Stream,
obstacles_tracking_stream: Stream,
time_to_decision_stream: Stream,
lidar_setup: LidarSetup) -> Stream:
from pylot.prediction.r2p2_predictor_operator import \
R2P2PredictorOperator
op_config = erdos.OperatorConfig(name='r2p2_prediction_operator',
log_file_name=FLAGS.log_file_name,
csv_log_file_name=FLAGS.csv_log_file_name,
profile_file_name=FLAGS.profile_file_name)
[prediction_stream] = erdos.connect(R2P2PredictorOperator, op_config, [
point_cloud_stream, obstacles_tracking_stream, time_to_decision_stream
], FLAGS, lidar_setup)
op_config = erdos.operator.OperatorConfig(
name='r2p2_prediction_operator',
log_file_name=FLAGS.log_file_name,
csv_log_file_name=FLAGS.csv_log_file_name,
profile_file_name=FLAGS.profile_file_name)
concatenated_streams = point_cloud_stream.concat(obstacles_tracking_stream,
time_to_decision_stream)
prediction_stream = erdos.connect_one_in_one_out(R2P2PredictorOperator,
op_config,
concatenated_streams,
FLAGS, lidar_setup)
return prediction_stream


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def on_watermark(self, context: OneInOneOutContext):
# The trajectory is relative to the current location.
obstacle_trajectories.append(
ObstacleTrajectory(obstacle, cur_obstacle_trajectory))

context.write_stream.send(
erdos.Message(
context.timestamp,
Expand Down
91 changes: 48 additions & 43 deletions pylot/prediction/r2p2_predictor_operator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import Union
import time
from collections import deque

import erdos
from erdos.operator import OneInOneOut
from erdos.context import OneInOneOutContext

import numpy as np
from pylot.perception.messages import ObstacleTrajectoriesMessageTuple

import pylot.prediction.utils
from pylot.prediction.messages import PredictionMessage
from pylot.prediction.obstacle_prediction import ObstaclePrediction
from pylot.perception.point_cloud import PointCloud
from pylot.utils import Location, Transform, time_epoch_ms

import torch
Expand All @@ -18,7 +22,7 @@
raise Exception('Error importing R2P2.')


class R2P2PredictorOperator(erdos.Operator):
class R2P2PredictorOperator(OneInOneOut):
"""Wrapper operator for R2P2 ego-vehicle prediction module.

Args:
Expand All @@ -35,10 +39,7 @@ class R2P2PredictorOperator(erdos.Operator):
of the lidar. This setup is used to get the maximum range of the
lidar.
"""
def __init__(self, point_cloud_stream: erdos.ReadStream,
tracking_stream: erdos.ReadStream,
time_to_decision_stream: erdos.ReadStream,
prediction_stream: erdos.WriteStream, flags, lidar_setup):
def __init__(self, flags, lidar_setup):
print("WARNING: R2P2 predicts only vehicle trajectories")
self._logger = erdos.utils.setup_logging(self.config.name,
self.config.log_file_name)
Expand All @@ -51,32 +52,44 @@ def __init__(self, point_cloud_stream: erdos.ReadStream,
state_dict = torch.load(flags.r2p2_model_path)
self._r2p2_model.load_state_dict(state_dict)

point_cloud_stream.add_callback(self.on_point_cloud_update)
tracking_stream.add_callback(self.on_trajectory_update)
time_to_decision_stream.add_callback(self.on_time_to_decision_update)
erdos.add_watermark_callback([point_cloud_stream, tracking_stream],
[prediction_stream], self.on_watermark)

self._lidar_setup = lidar_setup

self._point_cloud_msgs = deque()
self._tracking_msgs = deque()

@staticmethod
def connect(point_cloud_stream: erdos.ReadStream,
tracking_stream: erdos.ReadStream,
time_to_decision_stream: erdos.ReadStream):
prediction_stream = erdos.WriteStream()
return [prediction_stream]
def on_data(self, context: OneInOneOutContext,
data: Union[PointCloud, ObstacleTrajectoriesMessageTuple,
float]):
if isinstance(data, PointCloud):
self.on_point_cloud_update(context, data)
elif isinstance(data, ObstacleTrajectoriesMessageTuple):
self.on_trajectory_update(context, data)
elif isinstance(data, float):
self.on_time_to_decision_update(context, data)
else:
raise ValueError('Unexpected data type')

def on_point_cloud_update(self, context: OneInOneOutContext,
data: PointCloud):
self._logger.debug('@{}: received point cloud message'.format(
context.timestamp))
self._point_cloud_msgs.append(data)

def destroy(self):
self._logger.warn('destroying {}'.format(self.config.name))
def on_trajectory_update(self, context: OneInOneOutContext,
data: ObstacleTrajectoriesMessageTuple):
self._logger.debug('@{}: received trajectories message'.format(
context.timestamp))
self._tracking_msgs.append(data)

def on_time_to_decision_update(self, context: OneInOneOutContext,
data: float):
self._logger.debug('@{}: {} received ttd update {}'.format(
context.timestamp, self.config.name, data))

@erdos.profile_method()
def on_watermark(self, timestamp: erdos.Timestamp,
prediction_stream: erdos.WriteStream):
self._logger.debug('@{}: received watermark'.format(timestamp))
if timestamp.is_top:
def on_watermark(self, context: OneInOneOutContext):
self._logger.debug('@{}: received watermark'.format(context.timestamp))
if context.timestamp.is_top:
return
point_cloud_msg = self._point_cloud_msgs.popleft()
tracking_msg = self._tracking_msgs.popleft()
Expand All @@ -89,10 +102,10 @@ def on_watermark(self, timestamp: erdos.Timestamp,
num_predictions = len(nearby_trajectories)
self._logger.info(
'@{}: Getting R2P2 predictions for {} vehicles'.format(
timestamp, num_predictions))
context.timestamp, num_predictions))

if num_predictions == 0:
prediction_stream.send(PredictionMessage(timestamp, []))
context.write_stream.send(erdos.Message(context.timestamp, []))
return

# Run the forward pass.
Expand All @@ -105,7 +118,7 @@ def on_watermark(self, timestamp: erdos.Timestamp,
z, nearby_trajectories_tensor, binned_lidars_tensor)
model_runtime = (time.time() - model_start_time) * 1000
self._csv_logger.debug("{},{},{},{:.4f}".format(
time_epoch_ms(), timestamp.coordinates[0],
time_epoch_ms(), context.timestamp.coordinates[0],
'r2p2-modelonly-runtime', model_runtime))
prediction_array = prediction_array.cpu().detach().numpy()

Expand All @@ -114,10 +127,13 @@ def on_watermark(self, timestamp: erdos.Timestamp,
nearby_vehicle_ego_transforms)
runtime = (time.time() - start_time) * 1000
self._csv_logger.debug("{},{},{},{:.4f}".format(
time_epoch_ms(), timestamp.coordinates[0], 'r2p2-runtime',
time_epoch_ms(), context.timestamp.coordinates[0], 'r2p2-runtime',
runtime))
prediction_stream.send(
PredictionMessage(timestamp, obstacle_predictions_list))

print('***************************')
print(obstacle_predictions_list)
context.write_stream.send(
erdos.Message(context.timestamp, obstacle_predictions_list))

def _preprocess_input(self, tracking_msg, point_cloud_msg):

Expand Down Expand Up @@ -203,16 +219,5 @@ def _postprocess_predictions(self, prediction_array, vehicle_trajectories,
obstacle_transform, 1.0, predictions))
return obstacle_predictions_list

def on_point_cloud_update(self, msg: erdos.Message):
self._logger.debug('@{}: received point cloud message'.format(
msg.timestamp))
self._point_cloud_msgs.append(msg)

def on_trajectory_update(self, msg: erdos.Message):
self._logger.debug('@{}: received trajectories message'.format(
msg.timestamp))
self._tracking_msgs.append(msg)

def on_time_to_decision_update(self, msg: erdos.Message):
self._logger.debug('@{}: {} received ttd update {}'.format(
msg.timestamp, self.config.name, msg))
def destroy(self):
self._logger.warn('destroying {}'.format(self.config.name))