diff --git a/stytra/collectors/accumulators.py b/stytra/collectors/accumulators.py index fef3eceb..90254326 100644 --- a/stytra/collectors/accumulators.py +++ b/stytra/collectors/accumulators.py @@ -1,3 +1,5 @@ +from typing import Optional + from PyQt5.QtCore import QObject, pyqtSignal import datetime import numpy as np @@ -7,6 +9,7 @@ from bisect import bisect_right from os.path import basename +from stytra.collectors.namedtuplequeue import NamedTupleQueue from stytra.utilities import save_df @@ -14,7 +17,7 @@ class Accumulator(QObject): def __init__(self, experiment, name="", max_history_if_not_running=1000): super().__init__() self.name = name - self.exp = experiment + #self.exp = experiment self.stored_data = [] self.times = [] self.max_history_if_not_running = max_history_if_not_running @@ -76,7 +79,7 @@ def __getitem__(self, item): def t(self): return np.array(self.times) - def values_at_abs_time(self, time): + def values_at_abs_time(self, time, t0): """Finds the values in the accumulator closest to the datetime time Parameters @@ -84,12 +87,15 @@ def values_at_abs_time(self, time): time : datetime time to search for + t0: + reference time 0 + Returns ------- namedtuple of values """ - find_time = (time - self.exp.t0).total_seconds() + find_time = (time - t0).total_seconds() i = bisect_right(self.times, find_time) return self.stored_data[i - 1] @@ -239,17 +245,21 @@ class QueueDataAccumulator(DataFrameAccumulator): Parameters ---------- - data_queue : (multiprocessing.Queue object) + data_queue : NamedTupleQueue queue from witch to retrieve data. + output_queue:Optional[NamedTupleQueue] + an optinal queue to forward the data to header_list : list of str headers for the data to stored. - Returns - ------- - """ - def __init__(self, data_queue, **kwargs): + def __init__( + self, + data_queue: NamedTupleQueue, + output_queue: Optional[NamedTupleQueue] = None, + **kwargs + ): """ """ super().__init__(**kwargs) @@ -257,6 +267,7 @@ def __init__(self, data_queue, **kwargs): # only time differences in milliseconds in the list (faster) self.starting_time = None self.data_queue = data_queue + self.output_queue = output_queue def update_list(self): """Upon calling put all available data into a list.""" @@ -264,6 +275,10 @@ def update_list(self): try: # Get data from queue: t, data = self.data_queue.get(timeout=0.001) + + if self.output_queue is not None: + self.output_queue.put(t, data) + newtype = False if len(self.stored_data) == 0 or type(data) != type( self.stored_data[-1] @@ -313,7 +328,7 @@ def __init__(self, *args, queue, **kwargs): super().__init__(*args, **kwargs) self.queue = queue - def update_list(self): + def update_list(self, fps): while True: try: # Get data from queue: diff --git a/stytra/examples/custom_tracking_exp.py b/stytra/examples/custom_tracking_exp.py index 9b914c39..1fa85b24 100644 --- a/stytra/examples/custom_tracking_exp.py +++ b/stytra/examples/custom_tracking_exp.py @@ -180,7 +180,7 @@ def retrieve_image(self): # To match tracked points and frame displayed looks for matching # timestamps of the displayed frame and of tracked queue: retrieved_data = self.experiment.acc_tracking.values_at_abs_time( - self.current_frame_time + self.current_frame_time, self.experiment.t0 ) # Check for valid data to be displayed: diff --git a/stytra/experiments/tracking_experiments.py b/stytra/experiments/tracking_experiments.py index 6510a55d..bf9649ef 100644 --- a/stytra/experiments/tracking_experiments.py +++ b/stytra/experiments/tracking_experiments.py @@ -21,6 +21,7 @@ EstimatorLog, FramerateQueueAccumulator, ) +from stytra.stimulation.estimator_process import EstimatorProcess from stytra.tracking.tracking_process import TrackingProcess from stytra.tracking.pipelines import Pipeline from stytra.collectors.namedtuplequeue import NamedTupleQueue @@ -191,9 +192,7 @@ class TrackingExperiment(CameraVisualExperiment): """ - def __init__( - self, *args, tracking, recording=None, second_output_queue=None, **kwargs - ): + def __init__(self, *args, tracking, recording=None, second_output_queue=None, **kwargs): """ :param tracking_method: class with the parameters for tracking (instance of TrackingMethod class, defined in the child); @@ -210,14 +209,10 @@ def __init__( super().__init__(*args, **kwargs) self.arguments.update(locals()) - self.recording_event = ( - Event() if (recording is not None or recording is False) else None - ) + self.recording_event = Event() if (recording is not None or recording is False) else None self.pipeline_cls = ( - pipeline_dict.get(tracking["method"], None) - if isinstance(tracking["method"], str) - else tracking["method"] + pipeline_dict.get(tracking["method"], None) if isinstance(tracking["method"], str) else tracking["method"] ) self.frame_dispatcher = TrackingProcess( @@ -237,20 +232,6 @@ def __init__( assert isinstance(self.pipeline, Pipeline) self.pipeline.setup(tree=self.dc) - self.acc_tracking = QueueDataAccumulator( - name="tracking", - experiment=self, - data_queue=self.tracking_output_queue, - monitored_headers=self.pipeline.headers_to_plot, - ) - self.acc_tracking.sig_acc_init.connect(self.refresh_plots) - - # Data accumulator is updated with GUI timer: - self.gui_timer.timeout.connect(self.acc_tracking.update_list) - - # Tracking is reset at experiment start: - self.protocol_runner.sig_protocol_started.connect(self.acc_tracking.reset) - # start frame dispatcher process: self.frame_dispatcher.start() @@ -263,15 +244,28 @@ def __init__( est = est_type if est is not None: + self.estimator_process = EstimatorProcess(est_type, self.tracking_output_queue, self.finished_sig) self.estimator_log = EstimatorLog(experiment=self) - self.estimator = est( - self.acc_tracking, - experiment=self, - **tracking.get("estimator_params", {}) - ) + self.estimator = est(self.acc_tracking, experiment=self, **tracking.get("estimator_params", {})) self.estimator_log.sig_acc_init.connect(self.refresh_plots) + tracking_output_queue = self.estimator_process.tracking_output_queue else: self.estimator = None + tracking_output_queue = self.tracking_output_queue + + self.acc_tracking = QueueDataAccumulator( + name="tracking", + experiment=self, + data_queue=tracking_output_queue, + monitored_headers=self.pipeline.headers_to_plot, + ) + self.acc_tracking.sig_acc_init.connect(self.refresh_plots) + + # Data accumulator is updated with GUI timer: + self.gui_timer.timeout.connect(self.acc_tracking.update_list) + + # Tracking is reset at experiment start: + self.protocol_runner.sig_protocol_started.connect(self.acc_tracking.reset) self.acc_tracking_framerate = FramerateQueueAccumulator( self, @@ -376,9 +370,7 @@ def end_protocol(self, save=True): def save_data(self): """Save tail position and dynamic parameters and terminate.""" - self.window_main.camera_display.save_image( - name=self.filename_base() + "img.png" - ) + self.window_main.camera_display.save_image(name=self.filename_base() + "img.png") self.dc.add_static_data(self.filename_prefix() + "img.png", "tracking/image") # Save log and estimators: diff --git a/stytra/gui/camera_display.py b/stytra/gui/camera_display.py index c193c618..2b1e843a 100644 --- a/stytra/gui/camera_display.py +++ b/stytra/gui/camera_display.py @@ -340,7 +340,7 @@ def retrieve_image(self): # To match tracked points and frame displayed looks for matching # timestamps from the two different queues: retrieved_data = self.experiment.acc_tracking.values_at_abs_time( - self.current_frame_time + self.current_frame_time, self.experiment.t0 ) # Check for data to be displayed: # Retrieve tail angles from tail @@ -442,7 +442,7 @@ def retrieve_image(self): # To match tracked points and frame displayed looks for matching # timestamps from the two different queues: retrieved_data = self.experiment.acc_tracking.values_at_abs_time( - self.current_frame_time + self.current_frame_time, self.experiment.t0 ) # Check for data to be displayed: @@ -622,7 +622,7 @@ def retrieve_image(self): return current_data = self.experiment.acc_tracking.values_at_abs_time( - self.current_frame_time + self.current_frame_time, self.experiment.t0 ) n_fish = self.tracking_params.n_fish_max diff --git a/stytra/stimulation/estimator_process.py b/stytra/stimulation/estimator_process.py new file mode 100644 index 00000000..994915d5 --- /dev/null +++ b/stytra/stimulation/estimator_process.py @@ -0,0 +1,29 @@ +from multiprocessing import Event, Process +from typing import Type + +from stytra.collectors import QueueDataAccumulator +from stytra.collectors.namedtuplequeue import NamedTupleQueue +from stytra.stimulation.estimators import Estimator + + +class EstimatorProcess(Process): + def __init__( + self, + estimator_cls: Type[Estimator], + tracking_queue: NamedTupleQueue, + finished_signal: Event, + ): + super().__init__() + self.tracking_queue = tracking_queue + self.tracking_output_queue = NamedTupleQueue() + self.estimator_queue = NamedTupleQueue() + self.tracking_accumulator = QueueDataAccumulator(self.tracking_queue, self.tracking_output_queue) + self.finished_signal = finished_signal + self.estimator_cls = estimator_cls + + def run(self): + estimator = self.estimator_cls(self.tracking_accumulator, self.estimator_queue) + + while not self.finished_signal.is_set(): + self.tracking_accumulator.update_list() + estimator.update() diff --git a/stytra/stimulation/estimators.py b/stytra/stimulation/estimators.py index 60d17d9e..ed0c6f6b 100644 --- a/stytra/stimulation/estimators.py +++ b/stytra/stimulation/estimators.py @@ -16,10 +16,10 @@ class Estimator: stream of the tracking pipelines (position in pixels, tail angles, etc.). """ - def __init__(self, acc_tracking: QueueDataAccumulator, experiment): - self.exp = experiment + def __init__(self, acc_tracking: QueueDataAccumulator, output_queue: NamedTupleQueue, cam_to_proj=None): self.acc_tracking = acc_tracking - self.output_queue = NamedTupleQueue() + self.output_queue = output_queue + self.cam_to_proj = cam_to_proj self._output_type = None def update(self): @@ -184,8 +184,8 @@ def get_position(self) -> Tuple[float, PositionEstimate]: past_coords = self.acc_tracking.stored_data[-1] t = self.acc_tracking.times[-1] - if not self.calibrator.cam_to_proj is None: - projmat = np.array(self.calibrator.cam_to_proj) + if not self.cam_to_proj is None: + projmat = np.array(self.cam_to_proj) if projmat.shape != (2, 3): projmat = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) diff --git a/stytra/tracking/tracking_process.py b/stytra/tracking/tracking_process.py index 7b7b19d5..47b1631f 100644 --- a/stytra/tracking/tracking_process.py +++ b/stytra/tracking/tracking_process.py @@ -1,5 +1,5 @@ from queue import Empty, Full -from multiprocessing import Event, Value +from multiprocessing import Event from stytra.utilities import FrameProcess from arrayqueues.shared_arrays import TimestampedArrayQueue