Skip to content

Commit 757f27a

Browse files
authored
Merge pull request #140 from samuelgarcia/some_fix
Implement color mode in main settings.
2 parents c7a95f2 + 0e9faed commit 757f27a

File tree

14 files changed

+304
-90
lines changed

14 files changed

+304
-90
lines changed

spikeinterface_gui/backend_panel.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class SignalNotifier(param.Parameterized):
1313
manual_curation_updated = param.Event()
1414
time_info_updated = param.Event()
1515
active_view_updated = param.Event()
16+
unit_color_changed = param.Event()
1617

1718
def __init__(self, view=None):
1819
param.Parameterized.__init__(self)
@@ -41,6 +42,9 @@ def notify_active_view_updated(self):
4142
# views
4243
self.param.trigger("active_view_updated")
4344

45+
def notify_unit_color_changed(self):
46+
self.param.trigger("unit_color_changed")
47+
4448

4549
class SignalHandler(param.Parameterized):
4650
def __init__(self, controller, parent=None):
@@ -61,6 +65,7 @@ def connect_view(self, view):
6165
view.notifier.param.watch(self.on_manual_curation_updated, "manual_curation_updated")
6266
view.notifier.param.watch(self.on_time_info_updated, "time_info_updated")
6367
view.notifier.param.watch(self.on_active_view_updated, "active_view_updated")
68+
view.notifier.param.watch(self.on_unit_color_changed, "unit_color_changed")
6469

6570
def on_spike_selection_changed(self, param):
6671
if not self._active:
@@ -112,6 +117,14 @@ def on_active_view_updated(self, param):
112117
view._panel_view_is_active = True
113118
else:
114119
view._panel_view_is_active = False
120+
121+
def on_unit_color_changed(self, param):
122+
if not self._active:
123+
return
124+
for view in self.controller.views:
125+
if param.obj.view == view:
126+
continue
127+
view.on_unit_color_changed()
115128

116129
param_type_map = {
117130
"float": param.Number,
@@ -125,16 +138,16 @@ class SettingsProxy:
125138
# for instance self.settings['my_params'] instead of self.settings.my_params
126139
# self.settings['my_params'] = value instead of self.settings.my_params = value
127140
def __init__(self, myparametrized):
128-
self._parametrized = myparametrized
141+
self._parameterized = myparametrized
129142

130143
def __getitem__(self, key):
131-
return getattr(self._parametrized, key)
144+
return getattr(self._parameterized, key)
132145

133146
def __setitem__(self, key, value):
134-
self._parametrized.param.update(**{key:value})
147+
self._parameterized.param.update(**{key:value})
135148

136149
def keys(self):
137-
return list(p for p in self._parametrized.param if p != "name")
150+
return list(p for p in self._parameterized.param if p != "name")
138151

139152

140153
def create_dynamic_parameterized(settings):
@@ -167,7 +180,7 @@ def create_settings(view):
167180

168181
def listen_setting_changes(view):
169182
for setting_data in view._settings:
170-
view.settings._parametrized.param.watch(view.on_settings_changed, setting_data["name"])
183+
view.settings._parameterized.param.watch(view.on_settings_changed, setting_data["name"])
171184

172185

173186

@@ -216,7 +229,7 @@ def make_views(self):
216229

217230
tabs = [("📊", view.layout)]
218231
if view_class._settings is not None:
219-
settings = pn.Param(view.settings._parametrized, sizing_mode="stretch_height",
232+
settings = pn.Param(view.settings._parameterized, sizing_mode="stretch_height",
220233
name=f"{view_name.capitalize()} settings")
221234
if view_class._need_compute:
222235
compute_button = pn.widgets.Button(name="Compute", button_type="primary")
@@ -351,8 +364,14 @@ def start_server(mainwindow, address="localhost", port=0):
351364

352365
pn.config.sizing_mode = "stretch_width"
353366

354-
mainwindow.main_layout.servable()
367+
# mainwindow.main_layout.servable()
368+
# TODO alessio : find automatically a port when port = 0
355369

370+
if address != "localhost":
371+
websocket_origin = f"{address}:{port}"
372+
else:
373+
websocket_origin = None
374+
356375
server = pn.serve({"/": mainwindow.main_layout}, address=address, port=port,
357-
show=False, start=True, dev=True, autoreload=True,
376+
show=False, start=True, dev=True, autoreload=True,websocket_origin=websocket_origin,
358377
title="SpikeInterface GUI")

spikeinterface_gui/backend_qt.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class SignalNotifier(QT.QObject):
1818
channel_visibility_changed = QT.pyqtSignal()
1919
manual_curation_updated = QT.pyqtSignal()
2020
time_info_updated = QT.pyqtSignal()
21+
unit_color_changed = QT.pyqtSignal()
2122

2223
def __init__(self, parent=None, view=None):
2324
QT.QObject.__init__(self, parent=parent)
@@ -38,6 +39,10 @@ def notify_manual_curation_updated(self):
3839
def notify_time_info_updated(self):
3940
self.time_info_updated.emit()
4041

42+
def notify_unit_color_changed(self):
43+
self.unit_color_changed.emit()
44+
45+
4146
# Used by controler to handle/callback signals
4247
class SignalHandler(QT.QObject):
4348
def __init__(self, controller, parent=None):
@@ -57,6 +62,7 @@ def connect_view(self, view):
5762
view.notifier.channel_visibility_changed.connect(self.on_channel_visibility_changed)
5863
view.notifier.manual_curation_updated.connect(self.on_manual_curation_updated)
5964
view.notifier.time_info_updated.connect(self.on_time_info_updated)
65+
view.notifier.unit_color_changed.connect(self.on_unit_color_changed)
6066

6167
def on_spike_selection_changed(self):
6268
if not self._active:
@@ -68,6 +74,7 @@ def on_spike_selection_changed(self):
6874
view.on_spike_selection_changed()
6975

7076
def on_unit_visibility_changed(self):
77+
7178
if not self._active:
7279
return
7380
for view in self.controller.views:
@@ -102,6 +109,15 @@ def on_time_info_updated(self):
102109
# do not refresh it self
103110
continue
104111
view.on_time_info_updated()
112+
113+
def on_unit_color_changed(self):
114+
if not self._active:
115+
return
116+
for view in self.controller.views:
117+
if view.qt_widget == self.sender().parent():
118+
# do not refresh it self
119+
continue
120+
view.on_unit_color_changed()
105121

106122

107123
def create_settings(view, parent):

spikeinterface_gui/controller.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121
('visible', 'bool'), ('selected', 'bool'), ('rand_selected', 'bool')]
2222

2323

24+
_default_main_settings = dict(
25+
max_visible_units=10,
26+
color_mode='color_by_unit',
27+
)
28+
2429
# TODO handle return_scaled
2530
from spikeinterface.widgets.sorting_summary import _default_displayed_unit_properties
2631

@@ -54,6 +59,11 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
5459
self.verbose = verbose
5560
t0 = time.perf_counter()
5661

62+
63+
self.main_settings = _default_main_settings.copy()
64+
65+
66+
5767
self.num_channels = self.analyzer.get_num_channels()
5868
# this now private and shoudl be acess using function
5969
self._visible_unit_ids = [self.unit_ids[0]]
@@ -218,8 +228,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
218228
self.sampling_frequency = self.analyzer.sampling_frequency
219229

220230
# spikeinterface handle colors in matplotlib style tuple values in range (0,1)
221-
self.colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar',
222-
shuffle=True, seed=42)
231+
self.refresh_colors()
223232

224233
# at init, we set the visible channels as the sparsity of the first unit
225234
self.visible_channel_inds = self.analyzer_sparsity.unit_id_to_channel_indices[self.unit_ids[0]].astype("int64")
@@ -328,9 +337,6 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
328337
print('Curation quality labels are the default ones')
329338
self.has_default_quality_labels = True
330339

331-
self.main_settings = dict(
332-
max_visible_units=10,
333-
)
334340

335341
def check_is_view_possible(self, view_name):
336342
from .viewlist import possible_class_views
@@ -387,6 +393,29 @@ def get_information_txt(self):
387393

388394
return txt
389395

396+
def refresh_colors(self):
397+
if self.backend == "qt":
398+
self._cached_qcolors = {}
399+
elif self.backend == "panel":
400+
pass
401+
402+
if self.main_settings['color_mode'] == 'color_by_unit':
403+
self.colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar',
404+
shuffle=True, seed=42)
405+
elif self.main_settings['color_mode'] == 'color_only_visible':
406+
unit_colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar',
407+
shuffle=True, seed=42)
408+
self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids}
409+
for unit_id in self.get_visible_unit_ids():
410+
self.colors[unit_id] = unit_colors[unit_id]
411+
elif self.main_settings['color_mode'] == 'color_by_visibility':
412+
self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids}
413+
import matplotlib.pyplot as plt
414+
cmap = plt.colormaps['tab10']
415+
for i, unit_id in enumerate(self.get_visible_unit_ids()):
416+
self.colors[unit_id] = cmap(i)
417+
418+
390419
def get_unit_color(self, unit_id):
391420
# scalar unit_id -> color html or QtColor
392421
return self.colors[unit_id]

spikeinterface_gui/main.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ def run_mainwindow(
2626
skip_extensions=None,
2727
recording=None,
2828
start_app=True,
29-
make_servable=False,
3029
layout_preset=None,
30+
address="localhost",
31+
port=0,
3132
verbose=False,
3233
):
3334
"""
@@ -62,6 +63,10 @@ def run_mainwindow(
6263
If True, the QT app loop is started
6364
layout_preset : str | None
6465
The name of the layout preset. None is default.
66+
address: str, default : "localhost"
67+
For "web" mode only. By default only on local machine.
68+
port: int, default: 0
69+
For "web" mode only. If 0 then the port is automatic.
6570
verbose: bool, default: False
6671
If True, print some information in the console
6772
"""
@@ -111,14 +116,15 @@ def run_mainwindow(
111116
win.show()
112117
if start_app:
113118
app.exec()
119+
114120
elif backend == "panel":
115121
import panel
116122
from .backend_panel import PanelMainWindow, start_server
117123
win = PanelMainWindow(controller, layout_preset=layout_preset)
124+
win.main_layout.servable(title='SpikeInterface GUI')
118125
if start_app:
119-
start_server(win)
120-
elif make_servable:
121-
win.main_layout.servable(title='SpikeInterface GUI')
126+
start_server(win, address=address, port=port)
127+
122128

123129
return win
124130

spikeinterface_gui/mainsettingsview.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from .view_base import ViewBase
22

33

4-
4+
# this control controller.main_settings
55
main_settings = [
66
{'name': 'max_visible_units', 'type': 'int', 'value' : 10 },
7+
{'name': 'color_mode', 'type': 'list', 'value' : 'color_by_unit',
8+
'limits': ['color_by_unit', 'color_only_visible', 'color_by_visibility']},
79
]
810

911

1012
class MainSettingsView(ViewBase):
11-
_supported_backend = ['qt', ]
13+
_supported_backend = ['qt', 'panel']
1214
_settings = None
1315
_depend_on = []
1416
_need_compute = False
@@ -26,7 +28,15 @@ def on_max_visible_units_changed(self):
2628
visible_ids = visible_ids[:max_visible]
2729
self.controller.set_visible_unit_ids(visible_ids)
2830
self.notify_unit_visibility_changed()
31+
32+
def on_change_color_mode(self):
2933

34+
self.controller.main_settings['color_mode'] = self.main_settings['color_mode']
35+
self.controller.refresh_colors()
36+
self.notify_unit_color_changed()
37+
38+
# for view in self.controller.views:
39+
# view.refresh()
3040

3141
## QT zone
3242
def _qt_make_layout(self):
@@ -49,6 +59,7 @@ def _qt_make_layout(self):
4959
self.layout.addWidget(self.tree_main_settings)
5060

5161
self.main_settings.param('max_visible_units').sigValueChanged.connect(self.on_max_visible_units_changed)
62+
self.main_settings.param('color_mode').sigValueChanged.connect(self.on_change_color_mode)
5263

5364

5465
def _qt_refresh(self):
@@ -57,8 +68,22 @@ def _qt_refresh(self):
5768

5869
## panel zone
5970
def _panel_make_layout(self):
60-
pass
61-
71+
import panel as pn
72+
from .backend_panel import create_dynamic_parameterized, SettingsProxy
73+
74+
# Create method and arguments layout
75+
self.main_settings = SettingsProxy(create_dynamic_parameterized(main_settings))
76+
self.main_settings_layout = pn.Param(self.main_settings._parameterized, sizing_mode="stretch_both",
77+
name=f"Main settings")
78+
self.main_settings._parameterized.param.watch(self._panel_on_max_visible_units_changed, 'max_visible_units')
79+
self.main_settings._parameterized.param.watch(self._panel_on_change_color_mode, 'color_mode')
80+
self.layout = pn.Column(self.main_settings_layout, sizing_mode="stretch_both")
81+
82+
def _panel_on_max_visible_units_changed(self, event):
83+
self.on_max_visible_units_changed()
84+
85+
def _panel_on_change_color_mode(self, event):
86+
self.on_change_color_mode()
6287

6388
def _panel_refresh(self):
6489
pass

spikeinterface_gui/mergeview.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,16 +285,16 @@ def _panel_make_layout(self):
285285

286286
# Create method and arguments layout
287287
method_settings = SettingsProxy(create_dynamic_parameterized(self._methods))
288-
self.method_selector = pn.Param(method_settings._parametrized, sizing_mode="stretch_width", name="Method")
288+
self.method_selector = pn.Param(method_settings._parameterized, sizing_mode="stretch_width", name="Method")
289289
for setting_data in self._methods:
290-
method_settings._parametrized.param.watch(self._panel_on_method_change, setting_data["name"])
290+
method_settings._parameterized.param.watch(self._panel_on_method_change, setting_data["name"])
291291

292292
self.method_params = {}
293293
self.method_params_selectors = {}
294294
for method, params in self._method_params.items():
295295
method_params = SettingsProxy(create_dynamic_parameterized(params))
296296
self.method_params[method] = method_params
297-
self.method_params_selectors[method] = pn.Param(method_params._parametrized, sizing_mode="stretch_width",
297+
self.method_params_selectors[method] = pn.Param(method_params._parameterized, sizing_mode="stretch_width",
298298
name=f"{method.capitalize()} parameters")
299299
self.method = list(self.method_params.keys())[0]
300300

0 commit comments

Comments
 (0)