Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow copy for scalar and nested sequences when converting data to numpy arrays #95

Merged
merged 4 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion h5grove/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def data_stats(
return get_array_stats(data)

def _get_finite_data(self, selection: Selection) -> np.ndarray:
data = np.array(self.data(selection), copy=False) # So it works with scalars
data = np.asarray(self.data(selection)) # So it works with scalars

if not np.issubdtype(data.dtype, np.floating):
return data
Expand Down Expand Up @@ -288,6 +288,7 @@ def get_content_from_file(
fallback=LinkResolution.ONLY_VALID,
)
except QueryArgumentError as e:
f.close()
raise create_error(422, str(e))

try:
Expand Down
23 changes: 14 additions & 9 deletions h5grove/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def encode(content: Any, encoding: Optional[str] = "json") -> Response:
headers={"Content-Type": "application/json"},
)

content_array = np.array(content, copy=False)
content_array = np.asarray(content)

if encoding == "bin":
return Response(
Expand All @@ -126,21 +126,26 @@ def encode(content: Any, encoding: Optional[str] = "json") -> Response:
f"Unsupported encoding {encoding} for non-numeric content"
)

if encoding == "csv":
if encoding == "npy":
return Response(
csv_encode(content_array),
npy_encode(content_array),
headers={
"Content-Type": "text/csv",
"Content-Disposition": 'attachment; filename="data.csv"',
"Content-Type": "application/octet-stream",
"Content-Disposition": 'attachment; filename="data.npy"',
},
)

if encoding == "npy":
if content_array.ndim == 0:
raise QueryArgumentError(
f"Unsupported encoding {encoding} for empty and scalar datasets"
)

if encoding == "csv":
return Response(
npy_encode(content_array),
csv_encode(content_array),
headers={
"Content-Type": "application/octet-stream",
"Content-Disposition": 'attachment; filename="data.npy"',
"Content-Type": "text/csv",
"Content-Disposition": 'attachment; filename="data.csv"',
},
)

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ dev =
types-contextvars
types-dataclasses
types-orjson
types-pkg-resources
types-setuptools

# E501 (line too long) ignored for now
# E203 and W503 incompatible with black formatting (https://black.readthedocs.io/en/stable/compatible_configs.html#flake8)
Expand Down
39 changes: 38 additions & 1 deletion test/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,26 @@ def test_data_on_array_with_format(self, server, format_arg):

assert np.array_equal(retrieved_data, data)

@pytest.mark.parametrize("format_arg", ("json", "bin", "npy"))
def test_data_on_scalar_with_format(self, server, format_arg):
    """Check that /data/ returns a scalar dataset intact for the parametrized format."""
    filename = "test.h5"
    tested_h5entity_path = "/entry/scalar"
    data = 5

    # Create the fixture file; dtype and shape are read while the file is
    # still open so they can be used to decode the response afterwards.
    with h5py.File(server.served_directory / filename, mode="w") as h5file:
        scalar_dset = h5file.create_dataset(tested_h5entity_path, data=data)
        dtype, shape = scalar_dset.dtype, scalar_dset.shape

    query = urlencode(
        {"file": filename, "path": tested_h5entity_path, "format": format_arg}
    )
    response = server.get(f"/data/?{query}")

    retrieved_data = decode_array_response(response, format_arg, dtype.str, shape)
    assert np.array_equal(retrieved_data, data)

@pytest.mark.parametrize("format_arg", ("npy", "bin"))
def test_data_on_array_with_dtype_safe(
self,
Expand Down Expand Up @@ -114,7 +134,7 @@ def test_data_on_slice_with_format_and_flatten(self, server, format_arg):
response = server.get(
f"/data/?{urlencode({'file': filename, 'path': tested_h5entity_path, 'selection': '100,0', 'format': format_arg, 'flatten': True})}"
)
retrieved_data = np.array(decode_response(response, format_arg))
retrieved_data = np.asarray(decode_response(response, format_arg))

assert retrieved_data - data[100, 0] < 1e-8

Expand Down Expand Up @@ -575,3 +595,20 @@ def test_422_on_invalid_query_arg(self, server):
f"/meta/?file={filename}&path={path}&resolve_links={invalid_link_resolution}",
422,
)

@pytest.mark.parametrize("format_arg", ("csv", "tiff"))
def test_422_on_format_incompatible_with_empty_or_scalar_datasets(
    self, server, format_arg
):
    """Expect a 422 from /data/ when the parametrized format targets a scalar or empty dataset."""
    filename = "test.h5"
    target_file = server.served_directory / filename

    with h5py.File(target_file, mode="w") as h5file:
        h5file["scalar"] = 55
        # NOTE(review): "<4f" is kept as written; "<f4" may have been
        # intended for a plain little-endian float32 -- confirm.
        h5file["empty"] = h5py.Empty(dtype="<4f")

    # Both datasets must be rejected with the same error code.
    for path in ("/scalar", "/empty"):
        server.assert_error_code(
            f"/data/?file={filename}&path={path}&format={format_arg}", 422
        )
2 changes: 1 addition & 1 deletion test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def decode_array_response(
assert content_type == "application/octet-stream"
return np.frombuffer(response.content, dtype=dtype).reshape(shape)

return np.array(decode_response(response, format), copy=False)
return np.asarray(decode_response(response, format))


def assert_error_response(response: Response, error_code: int):
Expand Down
Loading