Skip to content

Commit 0611906

Browse files
committed
Fix color mode and NDscatter in panel
1 parent b905bcd commit 0611906

File tree

9 files changed

+117
-79
lines changed

9 files changed

+117
-79
lines changed

spikeinterface_gui/backend_panel.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def notify_active_view_updated(self):
4242
# views
4343
self.param.trigger("active_view_updated")
4444

45+
def notify_unit_color_changed(self):
46+
self.param.trigger("unit_color_changed")
47+
4548

4649
class SignalHandler(param.Parameterized):
4750
def __init__(self, controller, parent=None):
@@ -135,16 +138,16 @@ class SettingsProxy:
135138
# for instance self.settings['my_params'] instead of self.settings.my_params
136139
# self.settings['my_params'] = value instead of self.settings.my_params = value
137140
def __init__(self, myparametrized):
138-
self._parametrized = myparametrized
141+
self._parameterized = myparametrized
139142

140143
def __getitem__(self, key):
141-
return getattr(self._parametrized, key)
144+
return getattr(self._parameterized, key)
142145

143146
def __setitem__(self, key, value):
144-
self._parametrized.param.update(**{key:value})
147+
self._parameterized.param.update(**{key:value})
145148

146149
def keys(self):
147-
return list(p for p in self._parametrized.param if p != "name")
150+
return list(p for p in self._parameterized.param if p != "name")
148151

149152

150153
def create_dynamic_parameterized(settings):
@@ -177,7 +180,7 @@ def create_settings(view):
177180

178181
def listen_setting_changes(view):
179182
for setting_data in view._settings:
180-
view.settings._parametrized.param.watch(view.on_settings_changed, setting_data["name"])
183+
view.settings._parameterized.param.watch(view.on_settings_changed, setting_data["name"])
181184

182185

183186

@@ -226,7 +229,7 @@ def make_views(self):
226229

227230
tabs = [("📊", view.layout)]
228231
if view_class._settings is not None:
229-
settings = pn.Param(view.settings._parametrized, sizing_mode="stretch_height",
232+
settings = pn.Param(view.settings._parameterized, sizing_mode="stretch_height",
230233
name=f"{view_name.capitalize()} settings")
231234
if view_class._need_compute:
232235
compute_button = pn.widgets.Button(name="Compute", button_type="primary")

spikeinterface_gui/controller.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
_default_main_settings = dict(
2525
max_visible_units=10,
26-
color_mode='all_colorized',
26+
color_mode='color_by_unit',
2727
)
2828

2929
# TODO handle return_scaled
@@ -399,16 +399,16 @@ def refresh_colors(self):
399399
elif self.backend == "panel":
400400
pass
401401

402-
if self.main_settings['color_mode'] == 'all_colorized':
402+
if self.main_settings['color_mode'] == 'color_by_unit':
403403
self.colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar',
404404
shuffle=True, seed=42)
405-
elif self.main_settings['color_mode'] == 'colorize_only_visible':
405+
elif self.main_settings['color_mode'] == 'color_only_visible':
406406
unit_colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar',
407407
shuffle=True, seed=42)
408408
self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids}
409409
for unit_id in self.get_visible_unit_ids():
410410
self.colors[unit_id] = unit_colors[unit_id]
411-
elif self.main_settings['color_mode'] == 'colorize_by_visibility':
411+
elif self.main_settings['color_mode'] == 'color_by_visibility':
412412
self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids}
413413
import matplotlib.pyplot as plt
414414
cmap = plt.colormaps['tab10']

spikeinterface_gui/mainsettingsview.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
# this control controller.main_settings
55
main_settings = [
66
{'name': 'max_visible_units', 'type': 'int', 'value' : 10 },
7-
{'name': 'color_mode', 'type': 'list', 'value' : 'all_colorized',
8-
'limits': ['all_colorized', 'colorize_only_visible', 'colorize_by_visibility']},
9-
7+
{'name': 'color_mode', 'type': 'list', 'value' : 'color_by_unit',
8+
'limits': ['color_by_unit', 'color_only_visible', 'color_by_visibility']},
109
]
1110

1211

1312
class MainSettingsView(ViewBase):
14-
_supported_backend = ['qt', ]
13+
_supported_backend = ['qt', 'panel']
1514
_settings = None
1615
_depend_on = []
1716
_need_compute = False
@@ -62,17 +61,29 @@ def _qt_make_layout(self):
6261
self.main_settings.param('max_visible_units').sigValueChanged.connect(self.on_max_visible_units_changed)
6362
self.main_settings.param('color_mode').sigValueChanged.connect(self.on_change_color_mode)
6463

65-
66-
6764

6865
def _qt_refresh(self):
6966
pass
7067

7168

7269
## panel zone
7370
def _panel_make_layout(self):
74-
pass
75-
71+
import panel as pn
72+
from .backend_panel import create_dynamic_parameterized, SettingsProxy
73+
74+
# Create method and arguments layout
75+
self.main_settings = SettingsProxy(create_dynamic_parameterized(main_settings))
76+
self.main_settings_layout = pn.Param(self.main_settings._parameterized, sizing_mode="stretch_both",
77+
name=f"Main settings")
78+
self.main_settings._parameterized.param.watch(self._panel_on_max_visible_units_changed, 'max_visible_units')
79+
self.main_settings._parameterized.param.watch(self._panel_on_change_color_mode, 'color_mode')
80+
self.layout = pn.Column(self.main_settings_layout, sizing_mode="stretch_both")
81+
82+
def _panel_on_max_visible_units_changed(self, event):
83+
self.on_max_visible_units_changed()
84+
85+
def _panel_on_change_color_mode(self, event):
86+
self.on_change_color_mode()
7687

7788
def _panel_refresh(self):
7889
pass

spikeinterface_gui/mergeview.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,16 +285,16 @@ def _panel_make_layout(self):
285285

286286
# Create method and arguments layout
287287
method_settings = SettingsProxy(create_dynamic_parameterized(self._methods))
288-
self.method_selector = pn.Param(method_settings._parametrized, sizing_mode="stretch_width", name="Method")
288+
self.method_selector = pn.Param(method_settings._parameterized, sizing_mode="stretch_width", name="Method")
289289
for setting_data in self._methods:
290-
method_settings._parametrized.param.watch(self._panel_on_method_change, setting_data["name"])
290+
method_settings._parameterized.param.watch(self._panel_on_method_change, setting_data["name"])
291291

292292
self.method_params = {}
293293
self.method_params_selectors = {}
294294
for method, params in self._method_params.items():
295295
method_params = SettingsProxy(create_dynamic_parameterized(params))
296296
self.method_params[method] = method_params
297-
self.method_params_selectors[method] = pn.Param(method_params._parametrized, sizing_mode="stretch_width",
297+
self.method_params_selectors[method] = pn.Param(method_params._parameterized, sizing_mode="stretch_width",
298298
name=f"{method.capitalize()} parameters")
299299
self.method = list(self.method_params.keys())[0]
300300

spikeinterface_gui/ndscatterview.py

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,9 @@ def new_tour_step(self):
7676
self.tour_step+=1
7777
if self.tour_step>=num_step:
7878
self.tour_step = 0
79-
80-
self.refresh()
79+
80+
# avoid printing refresh time
81+
self._refresh(update_colors=False, update_components=False)
8182

8283
def next_face(self):
8384
self.n_face += 1
@@ -118,39 +119,25 @@ def on_channel_visibility_changed(self):
118119
def apply_dot(self, data):
119120
projected = np.dot(data[:, self.selected_comp], self.projection[self.selected_comp, :])
120121
return projected
121-
122-
def get_plotting_data(self, concatenated=True):
123-
124-
if concatenated:
125-
# panel prefer concatenated
126-
visible_unit_indices = self.controller.get_visible_unit_indices()
127-
spike_indices = np.flatnonzero(np.isin(self.pc_unit_index, visible_unit_indices))
128-
projected = self.apply_dot(self.data[spike_indices, :])
129-
scatter_x = projected[:, 0]
130-
scatter_y = projected[:, 1]
131-
132-
# set new limit
133-
if len(projected) > 0 and self.auto_update_limit:
122+
123+
def get_plotting_data(self, return_spike_indices=False):
124+
scatter_x = {}
125+
scatter_y = {}
126+
all_limits = []
127+
spike_indices = {}
128+
for unit_ind, unit_id in self.controller.iter_visible_units():
129+
mask = np.flatnonzero(self.pc_unit_index == unit_ind)
130+
projected = self.apply_dot(self.data[mask, :])
131+
scatter_x[unit_id] = projected[:, 0]
132+
scatter_y[unit_id] = projected[:, 1]
133+
if self.auto_update_limit and len(projected) > 0:
134134
projected_2d = projected[:, :2]
135-
self.limit = float(np.percentile(np.abs(projected_2d), 95) * 2.)
136-
else:
137-
# qt prefer by unit (because no need for color vectors)
138-
scatter_x = {}
139-
scatter_y = {}
140-
all_limits = []
141-
spike_indices = None
142-
for unit_ind, unit_id in self.controller.iter_visible_units():
143-
mask = np.flatnonzero(self.pc_unit_index == unit_ind)
144-
projected = self.apply_dot(self.data[mask, :])
145-
scatter_x[unit_id] = projected[:, 0]
146-
scatter_y[unit_id] = projected[:, 1]
147-
if self.auto_update_limit and len(projected) > 0:
148-
projected_2d = projected[:, :2]
149-
all_limits.append(float(np.percentile(np.abs(projected_2d), 95) * 2.))
150-
if len(all_limits) > 0 and self.auto_update_limit:
151-
self.limit = max(all_limits)
135+
all_limits.append(float(np.percentile(np.abs(projected_2d), 95) * 2.))
136+
if return_spike_indices:
137+
spike_indices[unit_id] = mask
138+
if len(all_limits) > 0 and self.auto_update_limit:
139+
self.limit = max(all_limits)
152140

153-
154141
self.limit = max(self.limit, 0.1) # ensure limit is at least 0.1
155142

156143
mask = np.isin(self.random_spikes_indices, self.controller.get_indices_spike_selected())
@@ -163,7 +150,10 @@ def get_plotting_data(self, concatenated=True):
163150
selected_scatter_x = projected_select[:, 0]
164151
selected_scatter_y = projected_select[:, 1]
165152

166-
return scatter_x, scatter_y, spike_indices, selected_scatter_x, selected_scatter_y
153+
if return_spike_indices:
154+
return scatter_x, scatter_y, selected_scatter_x, selected_scatter_y, spike_indices
155+
else:
156+
return scatter_x, scatter_y, selected_scatter_x, selected_scatter_y
167157

168158

169159
def update_selected_components(self):
@@ -286,7 +276,7 @@ def _qt_refresh(self, update_components=True):
286276
# self.scatter.setData(x=scatter_x, y=scatter_y, brush=scatter_colors, pen=pg.mkPen(None))
287277
# self.scatter_select.setData(selected_scatter_x, selected_scatter_y)
288278

289-
scatter_x, scatter_y, spike_indices, selected_scatter_x, selected_scatter_y = self.get_plotting_data(concatenated=False)
279+
scatter_x, scatter_y, selected_scatter_x, selected_scatter_y = self.get_plotting_data()
290280
for unit_index, unit_id in self.controller.iter_visible_units():
291281
color = self.get_unit_color(unit_id)
292282
self.scatter.addPoints(x=scatter_x[unit_id], y=scatter_y[unit_id], pen=pg.mkPen(None), brush=color)
@@ -440,20 +430,28 @@ def _panel_make_layout(self):
440430

441431
self.tour_timer = None
442432

443-
def _panel_refresh(self, update_components=True):
433+
def _panel_refresh(self, update_components=True, update_colors=True):
444434
if update_components:
445435
self.update_selected_components()
446-
scatter_x, scatter_y, spike_indices, selected_scatter_x, selected_scatter_y = self.get_plotting_data(concatenated=True)
436+
scatter_x, scatter_y, selected_scatter_x, selected_scatter_y = self.get_plotting_data()
447437

448-
# format rgba
449-
spike_colors = self.controller.get_spike_colors(self.pc_unit_index[spike_indices])
438+
xs, ys, colors = [], [], []
439+
for unit_id in scatter_x.keys():
440+
color = self.get_unit_color(unit_id)
441+
xs.extend(scatter_x[unit_id])
442+
ys.extend(scatter_y[unit_id])
443+
if update_colors:
444+
colors.extend([color] * len(scatter_x[unit_id]))
450445

446+
if not update_colors:
447+
colors = self.scatter_source.data.get("color")
451448

452449
self.scatter_source.data = {
453-
"x": scatter_x,
454-
"y": scatter_y,
455-
"color": spike_colors,
450+
"x": xs,
451+
"y": ys,
452+
"color": colors,
456453
}
454+
457455
self.scatter_select_source.data = {
458456
"x": selected_scatter_x,
459457
"y": selected_scatter_y,

spikeinterface_gui/probeview.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -580,8 +580,8 @@ def _panel_on_pan_end(self, event):
580580
# Update unit visibility
581581
new_x, new_y = self.unit_circle.center
582582
self.update_unit_visibility(new_x, new_y, self.settings['radius_units'])
583-
self._panel_update_unit_glyphs() # Update glyphs to reflect new visibility
584583
self.notify_unit_visibility_changed()
584+
self._panel_update_unit_glyphs() # Update glyphs to reflect new visibility
585585

586586
elif self.should_resize_channel_circle is not None:
587587
x_center, y_center = self.channel_circle.center
@@ -609,8 +609,8 @@ def _panel_on_pan_end(self, event):
609609
self.settings["radius_units"] = new_radius
610610
# Update unit visibility
611611
self.update_unit_visibility(x_center, y_center, self.settings['radius_units'])
612-
self._panel_update_unit_glyphs()
613612
self.notify_unit_visibility_changed()
613+
self._panel_update_unit_glyphs()
614614

615615
self.should_move_channel_circle = None
616616
self.should_move_unit_circle = None
@@ -647,7 +647,6 @@ def _panel_on_tap(self, event):
647647
if len(self.controller.get_visible_unit_ids()) == 1:
648648
select_only = True
649649

650-
self._panel_update_unit_glyphs()
651650

652651
if select_only:
653652
# Update selection circles
@@ -659,6 +658,7 @@ def _panel_on_tap(self, event):
659658
self.controller.set_channel_visibility(visible_channel_inds)
660659
self.notify_channel_visibility_changed
661660
self.notify_unit_visibility_changed()
661+
self._panel_update_unit_glyphs()
662662

663663

664664
def circle_from_roi(roi):

spikeinterface_gui/unitlist.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import warnings
22
import numpy as np
3-
import time
43

54
from .view_base import ViewBase
65

@@ -22,13 +21,13 @@ def __init__(self, controller=None, parent=None, backend="qt"):
2221
## common ##
2322
def show_all(self):
2423
self.controller.set_visible_unit_ids(self.controller.unit_ids)
25-
self.refresh()
2624
self.notify_unit_visibility_changed()
25+
self.refresh()
2726

2827
def hide_all(self):
2928
self.controller.set_all_unit_visibility_off()
30-
self.refresh()
3129
self.notify_unit_visibility_changed()
30+
self.refresh()
3231

3332
def get_selected_unit_ids(self):
3433
if self.backend == 'qt':
@@ -605,6 +604,10 @@ def _panel_refresh(self):
605604
visible.append(dict_unit_visible[unit_id])
606605
df.loc[:, "visible"] = visible
607606

607+
if self.controller.main_settings['color_mode'] in ('color_by_visibility', 'color_only_visible'):
608+
# in the mode color change dynamically but without notify to avoid double refresh
609+
self._panel_refresh_colors()
610+
608611
table_columns = self.table.value.columns
609612

610613
for table_col in table_columns:
@@ -650,8 +653,8 @@ def _panel_on_visible_checkbox_toggled(self, row):
650653

651654
# update the visible column
652655
self.table.value.loc[self.controller.unit_ids, "visible"] = self.controller.get_units_visibility_mask()
653-
self.refresh()
654656
self.notify_unit_visibility_changed()
657+
self.refresh()
655658

656659
def _panel_on_unit_visibility_changed(self):
657660
# update selection to match visible units
@@ -661,6 +664,25 @@ def _panel_on_unit_visibility_changed(self):
661664
self.table.selection = rows_to_select
662665
self.refresh()
663666

667+
def _panel_refresh_colors(self):
668+
import matplotlib.colors as mcolors
669+
670+
unit_ids_data = []
671+
for unit_id in self.table.value.index.values:
672+
unit_ids_data.append(
673+
{
674+
"id": str(unit_id),
675+
"color": mcolors.to_hex(self.controller.get_unit_color(unit_id))
676+
}
677+
)
678+
self.table.value.loc[:, "unit_id"] = unit_ids_data
679+
680+
def _panel_on_unit_color_changed(self):
681+
# here we update the unit colors, since they are then fixed in the table
682+
# during refresh
683+
self._panel_refresh_colors()
684+
self.refresh()
685+
664686
def _panel_on_edit(self, event):
665687
column = event.column
666688
if self.label_definitions is not None and column in self.label_definitions:

0 commit comments

Comments
 (0)