@@ -76,8 +76,9 @@ def new_tour_step(self):
7676 self .tour_step += 1
7777 if self .tour_step >= num_step :
7878 self .tour_step = 0
79-
80- self .refresh ()
79+
80+ # avoid printing refresh time
81+ self ._refresh (update_colors = False , update_components = False )
8182
8283 def next_face (self ):
8384 self .n_face += 1
@@ -118,39 +119,25 @@ def on_channel_visibility_changed(self):
118119 def apply_dot (self , data ):
119120 projected = np .dot (data [:, self .selected_comp ], self .projection [self .selected_comp , :])
120121 return projected
121-
122- def get_plotting_data (self , concatenated = True ):
123-
124- if concatenated :
125- # panel prefer concatenated
126- visible_unit_indices = self .controller .get_visible_unit_indices ()
127- spike_indices = np .flatnonzero (np .isin (self .pc_unit_index , visible_unit_indices ))
128- projected = self .apply_dot (self .data [spike_indices , :])
129- scatter_x = projected [:, 0 ]
130- scatter_y = projected [:, 1 ]
131-
132- # set new limit
133- if len (projected ) > 0 and self .auto_update_limit :
122+
123+ def get_plotting_data (self , return_spike_indices = False ):
124+ scatter_x = {}
125+ scatter_y = {}
126+ all_limits = []
127+ spike_indices = {}
128+ for unit_ind , unit_id in self .controller .iter_visible_units ():
129+ mask = np .flatnonzero (self .pc_unit_index == unit_ind )
130+ projected = self .apply_dot (self .data [mask , :])
131+ scatter_x [unit_id ] = projected [:, 0 ]
132+ scatter_y [unit_id ] = projected [:, 1 ]
133+ if self .auto_update_limit and len (projected ) > 0 :
134134 projected_2d = projected [:, :2 ]
135- self .limit = float (np .percentile (np .abs (projected_2d ), 95 ) * 2. )
136- else :
137- # qt prefer by unit (because no need for color vectors)
138- scatter_x = {}
139- scatter_y = {}
140- all_limits = []
141- spike_indices = None
142- for unit_ind , unit_id in self .controller .iter_visible_units ():
143- mask = np .flatnonzero (self .pc_unit_index == unit_ind )
144- projected = self .apply_dot (self .data [mask , :])
145- scatter_x [unit_id ] = projected [:, 0 ]
146- scatter_y [unit_id ] = projected [:, 1 ]
147- if self .auto_update_limit and len (projected ) > 0 :
148- projected_2d = projected [:, :2 ]
149- all_limits .append (float (np .percentile (np .abs (projected_2d ), 95 ) * 2. ))
150- if len (all_limits ) > 0 and self .auto_update_limit :
151- self .limit = max (all_limits )
135+ all_limits .append (float (np .percentile (np .abs (projected_2d ), 95 ) * 2. ))
136+ if return_spike_indices :
137+ spike_indices [unit_id ] = mask
138+ if len (all_limits ) > 0 and self .auto_update_limit :
139+ self .limit = max (all_limits )
152140
153-
154141 self .limit = max (self .limit , 0.1 ) # ensure limit is at least 0.1
155142
156143 mask = np .isin (self .random_spikes_indices , self .controller .get_indices_spike_selected ())
@@ -163,7 +150,10 @@ def get_plotting_data(self, concatenated=True):
163150 selected_scatter_x = projected_select [:, 0 ]
164151 selected_scatter_y = projected_select [:, 1 ]
165152
166- return scatter_x , scatter_y , spike_indices , selected_scatter_x , selected_scatter_y
153+ if return_spike_indices :
154+ return scatter_x , scatter_y , selected_scatter_x , selected_scatter_y , spike_indices
155+ else :
156+ return scatter_x , scatter_y , selected_scatter_x , selected_scatter_y
167157
168158
169159 def update_selected_components (self ):
@@ -286,7 +276,7 @@ def _qt_refresh(self, update_components=True):
286276 # self.scatter.setData(x=scatter_x, y=scatter_y, brush=scatter_colors, pen=pg.mkPen(None))
287277 # self.scatter_select.setData(selected_scatter_x, selected_scatter_y)
288278
289- scatter_x , scatter_y , spike_indices , selected_scatter_x , selected_scatter_y = self .get_plotting_data (concatenated = False )
279+ scatter_x , scatter_y , selected_scatter_x , selected_scatter_y = self .get_plotting_data ()
290280 for unit_index , unit_id in self .controller .iter_visible_units ():
291281 color = self .get_unit_color (unit_id )
292282 self .scatter .addPoints (x = scatter_x [unit_id ], y = scatter_y [unit_id ], pen = pg .mkPen (None ), brush = color )
@@ -440,20 +430,28 @@ def _panel_make_layout(self):
440430
441431 self .tour_timer = None
442432
443- def _panel_refresh (self , update_components = True ):
433+ def _panel_refresh (self , update_components = True , update_colors = True ):
444434 if update_components :
445435 self .update_selected_components ()
446- scatter_x , scatter_y , spike_indices , selected_scatter_x , selected_scatter_y = self .get_plotting_data (concatenated = True )
436+ scatter_x , scatter_y , selected_scatter_x , selected_scatter_y = self .get_plotting_data ()
447437
448- # format rgba
449- spike_colors = self .controller .get_spike_colors (self .pc_unit_index [spike_indices ])
438+ xs , ys , colors = [], [], []
439+ for unit_id in scatter_x .keys ():
440+ color = self .get_unit_color (unit_id )
441+ xs .extend (scatter_x [unit_id ])
442+ ys .extend (scatter_y [unit_id ])
443+ if update_colors :
444+ colors .extend ([color ] * len (scatter_x [unit_id ]))
450445
446+ if not update_colors :
447+ colors = self .scatter_source .data .get ("color" )
451448
452449 self .scatter_source .data = {
453- "x" : scatter_x ,
454- "y" : scatter_y ,
455- "color" : spike_colors ,
450+ "x" : xs ,
451+ "y" : ys ,
452+ "color" : colors ,
456453 }
454+
457455 self .scatter_select_source .data = {
458456 "x" : selected_scatter_x ,
459457 "y" : selected_scatter_y ,
0 commit comments