diff --git a/requirements.txt b/requirements.txt index c2540ed..80b306a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ # pinned dependencies to reproduce a working development environment hdmf>=3.1.1 pynwb>=2.0.0 -probeinterface>=0.2.17 +probeinterface>=0.2.18 diff --git a/src/pynwb/ndx_probeinterface/io.py b/src/pynwb/ndx_probeinterface/io.py index da22727..2a192f9 100644 --- a/src/pynwb/ndx_probeinterface/io.py +++ b/src/pynwb/ndx_probeinterface/io.py @@ -11,7 +11,8 @@ inverted_unit_map = {v: k for k, v in unit_map.items()} -def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup]) -> List[Device]: +def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup], + name: Optional[str] = None) -> List[Device]: """ Construct ndx-probeinterface Probe devices from a probeinterface.Probe @@ -32,7 +33,7 @@ def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup]) -> List[D probes = probe_or_probegroup.probes devices = [] for probe in probes: - devices.append(_single_probe_to_nwb_device(probe)) + devices.append(_single_probe_to_nwb_device(probe, name=name)) return devices @@ -116,8 +117,8 @@ def to_probeinterface(ndx_probe) -> Probe: return probeinterface_probe -def _single_probe_to_nwb_device(probe: Probe): - from pynwb import load_namespaces, get_class +def _single_probe_to_nwb_device(probe: Probe, name: Optional[str]=None): + from pynwb import get_class Probe = get_class("Probe", "ndx-probeinterface") ContactTable = get_class("ContactTable", "ndx-probeinterface") @@ -126,14 +127,8 @@ def _single_probe_to_nwb_device(probe: Probe): contact_plane_axes = probe.contact_plane_axes contact_ids = probe.contact_ids contacts_arr = probe.to_numpy() - shank_ids = probe.shank_ids planar_contour = probe.probe_planar_contour - if shank_ids is not None: - unique_shanks = np.unique(shank_ids) - else: - unique_shanks = ["0"] - shape_keys = [] for shape_params in probe.contact_shape_params: keys = list(shape_params.keys()) @@ -161,21 +156,13 @@ def _single_probe_to_nwb_device(probe: Probe): kwargs["shank_id"] = probe.shank_ids[index] contact_table.add_row(kwargs) - if "serial_number" in probe.annotations: - serial_number = probe.annotations["serial_number"] - else: - serial_number = None - if "model_name" in probe.annotations: - model_name = probe.annotations["model_name"] - else: - model_name = None - if "manufacturer" in probe.annotations: - manufacturer = probe.annotations["manufacturer"] - else: - manufacturer = None + serial_number = probe.serial_number + model_name = probe.model_name + manufacturer = probe.manufacturer + name = name if name is not None else probe.name probe_device = Probe( - name=probe.annotations["name"], + name=name, model_name=model_name, serial_number=serial_number, manufacturer=manufacturer, diff --git a/src/pynwb/tests/test_probe.py b/src/pynwb/tests/test_probe.py index f4c0656..0697c00 100644 --- a/src/pynwb/tests/test_probe.py +++ b/src/pynwb/tests/test_probe.py @@ -78,6 +78,9 @@ def test_constructor_from_probe_single_shank(self): contact_table = device_w_indices.contact_table np.testing.assert_array_equal(contact_table["device_channel_index_pi"][:], device_channel_indices) + devices_w_names = Probe.from_probeinterface(probe, name="Test Probe") + assert devices_w_names[0].name == "Test Probe" + def test_constructor_from_probe_multi_shank(self): """Test that the constructor from Probe sets values as expected for multi-shank.""" @@ -110,7 +113,7 @@ def test_constructor_from_probegroup(self): """Test that the constructor from probegroup sets values as expected.""" probegroup = self.probegroup - global_device_channel_indices = np.arange(probegroup.get_channel_count()) + global_device_channel_indices = np.arange(probegroup.get_contact_count()) probegroup.set_global_device_channel_indices(global_device_channel_indices) devices = Probe.from_probeinterface(probegroup) probes = probegroup.probes