diff --git a/detection_operator_runner.py b/detection_operator_runner.py index 80bca99ad..bece02b26 100644 --- a/detection_operator_runner.py +++ b/detection_operator_runner.py @@ -44,7 +44,7 @@ 'camera_logger', 'multiple_object_logger', 'collision_sensor', 'object_tracker', 'linear_predictor', 'obstacle_finder', 'fusion', 'gnss_sensor', 'imu_sensor', 'lane_invasion_sensor', 'lidar', - 'traffic_light_invasion', 'prediction_eval' + 'traffic_light_invasion', 'prediction_eval', 'r2p2' ], help='Operator of choice to test') @@ -518,6 +518,31 @@ def main(args): obstacle_pos_stream = pylot.operator_creator.add_fusion( pose_stream, obstacles_stream, depth_camera_ingest_stream, None) + if FLAGS.test_operator == 'r2p2': + time_to_decision_loop_stream = erdos.streams.LoopStream() + + obstacles_stream = pylot.operator_creator.add_obstacle_detection( + rgb_camera_ingest_stream, time_to_decision_loop_stream)[0] + + time_to_decision_stream = pylot.operator_creator.add_time_to_decision( + pose_stream, obstacles_stream) + time_to_decision_loop_stream.connect_loop(time_to_decision_stream) + + obstacles_wo_history_tracking_stream = pylot.operator_creator.add_obstacle_tracking( + obstacles_stream, rgb_camera_ingest_stream, + time_to_decision_stream) + + obstacles_tracking_stream = pylot.operator_creator.add_obstacle_location_history( + obstacles_wo_history_tracking_stream, + depth_camera_ingest_stream, pose_stream, rgb_camera_setup) + + (point_cloud_stream, notify_lidar_stream, + lidar_setup) = pylot.operator_creator.add_lidar( + transform, vehicle_id_stream, release_sensor_stream) + + prediction_stream = pylot.operator_creator.add_r2p2_prediction( + point_cloud_stream, obstacles_tracking_stream, + time_to_decision_loop_stream, lidar_setup) erdos.run_async() diff --git a/pylot/operator_creator.py b/pylot/operator_creator.py index 57543c02b..2bfa83937 100644 --- a/pylot/operator_creator.py +++ b/pylot/operator_creator.py @@ -392,17 +392,23 @@ def add_linear_prediction(tracking_stream: Stream, return prediction_stream -def add_r2p2_prediction(point_cloud_stream, obstacles_tracking_stream, - time_to_decision_stream, lidar_setup): +def add_r2p2_prediction(point_cloud_stream: Stream, + obstacles_tracking_stream: Stream, + time_to_decision_stream: Stream, + lidar_setup: LidarSetup) -> Stream: from pylot.prediction.r2p2_predictor_operator import \ R2P2PredictorOperator - op_config = erdos.OperatorConfig(name='r2p2_prediction_operator', - log_file_name=FLAGS.log_file_name, - csv_log_file_name=FLAGS.csv_log_file_name, - profile_file_name=FLAGS.profile_file_name) - [prediction_stream] = erdos.connect(R2P2PredictorOperator, op_config, [ - point_cloud_stream, obstacles_tracking_stream, time_to_decision_stream - ], FLAGS, lidar_setup) + op_config = erdos.operator.OperatorConfig( + name='r2p2_prediction_operator', + log_file_name=FLAGS.log_file_name, + csv_log_file_name=FLAGS.csv_log_file_name, + profile_file_name=FLAGS.profile_file_name) + concatenated_streams = point_cloud_stream.concat(obstacles_tracking_stream, + time_to_decision_stream) + prediction_stream = erdos.connect_one_in_one_out(R2P2PredictorOperator, + op_config, + concatenated_streams, + FLAGS, lidar_setup) return prediction_stream diff --git a/pylot/perception/tracking/obstacle_location_history_operator.py b/pylot/perception/tracking/obstacle_location_history_operator.py index bfb4f4a0d..2f8fbae7f 100644 --- a/pylot/perception/tracking/obstacle_location_history_operator.py +++ b/pylot/perception/tracking/obstacle_location_history_operator.py @@ -98,7 +98,6 @@ def on_watermark(self, context: OneInOneOutContext): # The trajectory is relative to the current location. obstacle_trajectories.append( ObstacleTrajectory(obstacle, cur_obstacle_trajectory)) - context.write_stream.send( erdos.Message( context.timestamp, diff --git a/pylot/prediction/r2p2_predictor_operator.py b/pylot/prediction/r2p2_predictor_operator.py index 34618f5c8..49ca6dd93 100644 --- a/pylot/prediction/r2p2_predictor_operator.py +++ b/pylot/prediction/r2p2_predictor_operator.py @@ -1,13 +1,17 @@ +from typing import Union import time from collections import deque import erdos +from erdos.operator import OneInOneOut +from erdos.context import OneInOneOutContext import numpy as np +from pylot.perception.messages import ObstacleTrajectoriesMessageTuple import pylot.prediction.utils -from pylot.prediction.messages import PredictionMessage from pylot.prediction.obstacle_prediction import ObstaclePrediction +from pylot.perception.point_cloud import PointCloud from pylot.utils import Location, Transform, time_epoch_ms import torch @@ -18,7 +22,7 @@ raise Exception('Error importing R2P2.') -class R2P2PredictorOperator(erdos.Operator): +class R2P2PredictorOperator(OneInOneOut): """Wrapper operator for R2P2 ego-vehicle prediction module. Args: @@ -35,10 +39,7 @@ class R2P2PredictorOperator(erdos.Operator): of the lidar. This setup is used to get the maximum range of the lidar. """ - def __init__(self, point_cloud_stream: erdos.ReadStream, - tracking_stream: erdos.ReadStream, - time_to_decision_stream: erdos.ReadStream, - prediction_stream: erdos.WriteStream, flags, lidar_setup): + def __init__(self, flags, lidar_setup): print("WARNING: R2P2 predicts only vehicle trajectories") self._logger = erdos.utils.setup_logging(self.config.name, self.config.log_file_name) @@ -51,32 +52,44 @@ def __init__(self, point_cloud_stream: erdos.ReadStream, state_dict = torch.load(flags.r2p2_model_path) self._r2p2_model.load_state_dict(state_dict) - point_cloud_stream.add_callback(self.on_point_cloud_update) - tracking_stream.add_callback(self.on_trajectory_update) - time_to_decision_stream.add_callback(self.on_time_to_decision_update) - erdos.add_watermark_callback([point_cloud_stream, tracking_stream], - [prediction_stream], self.on_watermark) - self._lidar_setup = lidar_setup self._point_cloud_msgs = deque() self._tracking_msgs = deque() - @staticmethod - def connect(point_cloud_stream: erdos.ReadStream, - tracking_stream: erdos.ReadStream, - time_to_decision_stream: erdos.ReadStream): - prediction_stream = erdos.WriteStream() - return [prediction_stream] + def on_data(self, context: OneInOneOutContext, + data: Union[PointCloud, ObstacleTrajectoriesMessageTuple, + float]): + if isinstance(data, PointCloud): + self.on_point_cloud_update(context, data) + elif isinstance(data, ObstacleTrajectoriesMessageTuple): + self.on_trajectory_update(context, data) + elif isinstance(data, float): + self.on_time_to_decision_update(context, data) + else: + raise ValueError('Unexpected data type') + + def on_point_cloud_update(self, context: OneInOneOutContext, + data: PointCloud): + self._logger.debug('@{}: received point cloud message'.format( + context.timestamp)) + self._point_cloud_msgs.append(data) - def destroy(self): - self._logger.warn('destroying {}'.format(self.config.name)) + def on_trajectory_update(self, context: OneInOneOutContext, + data: ObstacleTrajectoriesMessageTuple): + self._logger.debug('@{}: received trajectories message'.format( + context.timestamp)) + self._tracking_msgs.append(data) + + def on_time_to_decision_update(self, context: OneInOneOutContext, + data: float): + self._logger.debug('@{}: {} received ttd update {}'.format( + context.timestamp, self.config.name, data)) @erdos.profile_method() - def on_watermark(self, timestamp: erdos.Timestamp, - prediction_stream: erdos.WriteStream): - self._logger.debug('@{}: received watermark'.format(timestamp)) - if timestamp.is_top: + def on_watermark(self, context: OneInOneOutContext): + self._logger.debug('@{}: received watermark'.format(context.timestamp)) + if context.timestamp.is_top: return point_cloud_msg = self._point_cloud_msgs.popleft() tracking_msg = self._tracking_msgs.popleft() @@ -89,10 +102,10 @@ def on_watermark(self, timestamp: erdos.Timestamp, num_predictions = len(nearby_trajectories) self._logger.info( '@{}: Getting R2P2 predictions for {} vehicles'.format( - timestamp, num_predictions)) + context.timestamp, num_predictions)) if num_predictions == 0: - prediction_stream.send(PredictionMessage(timestamp, [])) + context.write_stream.send(erdos.Message(context.timestamp, [])) return # Run the forward pass. @@ -105,7 +118,7 @@ def on_watermark(self, timestamp: erdos.Timestamp, z, nearby_trajectories_tensor, binned_lidars_tensor) model_runtime = (time.time() - model_start_time) * 1000 self._csv_logger.debug("{},{},{},{:.4f}".format( - time_epoch_ms(), timestamp.coordinates[0], + time_epoch_ms(), context.timestamp.coordinates[0], 'r2p2-modelonly-runtime', model_runtime)) prediction_array = prediction_array.cpu().detach().numpy() @@ -114,10 +127,13 @@ def on_watermark(self, timestamp: erdos.Timestamp, nearby_vehicle_ego_transforms) runtime = (time.time() - start_time) * 1000 self._csv_logger.debug("{},{},{},{:.4f}".format( - time_epoch_ms(), timestamp.coordinates[0], 'r2p2-runtime', + time_epoch_ms(), context.timestamp.coordinates[0], 'r2p2-runtime', runtime)) - prediction_stream.send( - PredictionMessage(timestamp, obstacle_predictions_list)) + + print('***************************') + print(obstacle_predictions_list) + context.write_stream.send( + erdos.Message(context.timestamp, obstacle_predictions_list)) def _preprocess_input(self, tracking_msg, point_cloud_msg): @@ -203,16 +219,5 @@ def _postprocess_predictions(self, prediction_array, vehicle_trajectories, obstacle_transform, 1.0, predictions)) return obstacle_predictions_list - def on_point_cloud_update(self, msg: erdos.Message): - self._logger.debug('@{}: received point cloud message'.format( - msg.timestamp)) - self._point_cloud_msgs.append(msg) - - def on_trajectory_update(self, msg: erdos.Message): - self._logger.debug('@{}: received trajectories message'.format( - msg.timestamp)) - self._tracking_msgs.append(msg) - - def on_time_to_decision_update(self, msg: erdos.Message): - self._logger.debug('@{}: {} received ttd update {}'.format( - msg.timestamp, self.config.name, msg)) + def destroy(self): + self._logger.warn('destroying {}'.format(self.config.name))