Skip to content

Commit

Permalink
Allow copy for scalar and nested sequences when converting data to nu…
Browse files Browse the repository at this point in the history
…mpy arrays
  • Loading branch information
loichuder committed Aug 2, 2024
1 parent 53d7d0f commit 09c262a
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 4 deletions.
2 changes: 1 addition & 1 deletion h5grove/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def data_stats(
return get_array_stats(data)

def _get_finite_data(self, selection: Selection) -> np.ndarray:
data = np.array(self.data(selection), copy=False) # So it works with scalars
data = np.asarray(self.data(selection)) # So it works with scalars

if not np.issubdtype(data.dtype, np.floating):
return data
Expand Down
2 changes: 1 addition & 1 deletion h5grove/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def encode(content: Any, encoding: Optional[str] = "json") -> Response:
headers={"Content-Type": "application/json"},
)

content_array = np.array(content, copy=False)
content_array = np.asarray(content)

if encoding == "bin":
return Response(
Expand Down
23 changes: 22 additions & 1 deletion test/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ def test_data_on_array_with_format(self, server, format_arg):

assert np.array_equal(retrieved_data, data)

# TODO: What should we do for csv, tiff
@pytest.mark.parametrize("format_arg", ("json", "bin", "npy"))
def test_data_on_scalar_with_format(self, server, format_arg):
"""Test /data/ endpoint on scalar dataset"""
# Test condition
tested_h5entity_path = "/entry/scalar"
data = 5

filename = "test.h5"
with h5py.File(server.served_directory / filename, mode="w") as h5file:
dset = h5file.create_dataset(tested_h5entity_path, data=data)
dtype = dset.dtype
shape = dset.shape

response = server.get(
f"/data/?{urlencode({'file': filename, 'path': tested_h5entity_path, 'format': format_arg})}"
)
retrieved_data = decode_array_response(response, format_arg, dtype.str, shape)

assert np.array_equal(retrieved_data, data)

@pytest.mark.parametrize("format_arg", ("npy", "bin"))
def test_data_on_array_with_dtype_safe(
self,
Expand Down Expand Up @@ -114,7 +135,7 @@ def test_data_on_slice_with_format_and_flatten(self, server, format_arg):
response = server.get(
f"/data/?{urlencode({'file': filename, 'path': tested_h5entity_path, 'selection': '100,0', 'format': format_arg, 'flatten': True})}"
)
retrieved_data = np.array(decode_response(response, format_arg))
retrieved_data = np.asarray(decode_response(response, format_arg))

assert retrieved_data - data[100, 0] < 1e-8

Expand Down
2 changes: 1 addition & 1 deletion test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def decode_array_response(
assert content_type == "application/octet-stream"
return np.frombuffer(response.content, dtype=dtype).reshape(shape)

return np.array(decode_response(response, format), copy=False)
return np.asarray(decode_response(response, format))


def assert_error_response(response: Response, error_code: int):
Expand Down

0 comments on commit 09c262a

Please sign in to comment.