diff --git a/notebooks/topostats_file_helper_example.ipynb b/notebooks/topostats_file_helper_example.ipynb new file mode 100644 index 00000000000..9458162a237 --- /dev/null +++ b/notebooks/topostats_file_helper_example.ipynb @@ -0,0 +1,141 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import needed libraries\n", + "import numpy as np\n", + "from topostats.io import TopoFileHelper\n", + "from IPython.display import clear_output\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load topostats file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load topostats file\n", + "file = \"../tests/resources/file.topostats\"\n", + "helper = TopoFileHelper(file)\n", + "# Clear logging output\n", + "clear_output(wait=False)\n", + "print(\"File loaded\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Print the structure of the file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Print the structure of the file\n", + "helper.pretty_print_structure()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Find data within the file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Find the name of the data we want, we know it contains \"ordered_trace_heights\" and we want grain 2, but don't know\n", + "# what keys precede it\n", + "helper.find_data([\"ordered_trace_heights\", \"2\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Retrieve data from the file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get some data from the file\n", + "cropped_image = 
helper.get_data(\"grain_trace_data/above/cropped_images/2\")\n", + "ordered_trace_heights = helper.get_data(\"grain_trace_data/above/ordered_trace_heights/2\")\n", + "cumulative_distances = helper.get_data(\"grain_trace_data/above/ordered_trace_cumulative_distances/2\")\n", + "ordered_traces = helper.get_data(\"grain_trace_data/above/ordered_traces/2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use the retrieved data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the image\n", + "plt.imshow(cropped_image)\n", + "# Create a basic colour scale for the molecule trace\n", + "c = np.arange(0, len(ordered_traces))\n", + "# Plot the molecule trace\n", + "plt.scatter(ordered_traces[:, 1], ordered_traces[:, 0], c=c, s=10)\n", + "plt.show()\n", + "# Plot the height of the molecule trace against the cumulative distance in nanometres\n", + "plt.plot(cumulative_distances, ordered_trace_heights)\n", + "plt.xlabel(\"Cumulative distance (nm)\")\n", + "plt.ylabel(\"Height (nm)\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "topo-unet", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_io.py b/tests/test_io.py index 0384deeee0d..0e9bad443e7 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,5 +1,7 @@ """Tests of IO.""" +from __future__ import annotations + import argparse import json import logging @@ -15,6 +17,7 @@ from topostats.io import ( LoadScans, + TopoFileHelper, convert_basename_to_relative_paths, dict_to_hdf5, dict_to_json, @@ -1379,3 +1382,41 @@ def test_dict_to_json(dictionary: dict, target: 
dict, tmp_path: Path) -> None: with outfile.open("r", encoding="utf-8") as f: assert target == json.load(f) + + +class TestTopoFileHelper: + """Test the TopoFileHelper class.""" + + @pytest.mark.parametrize( + ("file_path_or_string"), + [ + pytest.param( + "tests/resources/file.topostats", + id="String file path", + ), + pytest.param( + Path("tests/resources/file.topostats"), + id="Path object path", + ), + ], + ) + def test_init(self, file_path_or_string: Path | str) -> None: + """Test the __init__ method of the TopoFileHelper class.""" + topo_file_helper = TopoFileHelper(file_path_or_string) + assert isinstance(topo_file_helper, TopoFileHelper) + assert isinstance(topo_file_helper.data, dict) + + def test_get_data(self) -> None: + """Test the get_data method of the TopoFileHelper class.""" + topo_file_helper = TopoFileHelper("tests/resources/file.topostats") + cropped_image = topo_file_helper.get_data("grain_trace_data/above/cropped_images/2") + assert isinstance(cropped_image, np.ndarray) + + +# This test only works when not part of the TestTopoFileHelper class +def test_pretty_print_structure(caplog) -> None: + """Test the pretty_print_structure method of the TopoFileHelper class.""" + topo_file_helper = TopoFileHelper("tests/resources/file.topostats") + topo_file_helper.pretty_print_structure() + assert "filename" in caplog.text + assert "keys with numpy arrays as values" in caplog.text diff --git a/topostats/io.py b/topostats/io.py index 06b3a669e86..c71732f6764 100644 --- a/topostats/io.py +++ b/topostats/io.py @@ -1284,3 +1284,369 @@ def dict_to_json(data: dict, output_dir: str | Path, filename: str | Path, inden output_file = output_dir / filename with output_file.open("w") as f: json.dump(data, f, indent=indent, cls=NumpyEncoder) + + +class TopoFileHelper: + """ + Helper class for searching through the data in a .topostats (hdf5) file. + + Parameters + ---------- + topofile : Path + Path to the .topostats file. 
+ + Examples + -------- + Creating a helper object. + ```python + from topostats.io import TopoFileHelper + + topofile = "path/to/topostats_file.topostats" + helper = TopoFileHelper(topofile) + ``` + + Print the structure of the data in the file. + ```python + from topostats.io import TopoFileHelper + + topofile = "path/to/topostats_file.topostats" + helper = TopoFileHelper(topofile) + helper.pretty_print_structure() + ``` + >>> [./tests/resources/file.topostats] + >>> ├ filename + >>> │ └ minicircle + >>> ├ grain_masks + >>> │ └ above + >>> │ └ Numpy array, shape: (1024, 1024), dtype: int64 + >>> ├ grain_trace_data + >>> │ └ above + >>> │ ├ cropped_images + >>> │ │ └ 21 keys with numpy arrays as values + >>> │ ├ ordered_trace_cumulative_distances + >>> │ │ └ 21 keys with numpy arrays as values + >>> │ ├ ordered_trace_heights + >>> │ │ └ 21 keys with numpy arrays as values + >>> │ ├ ordered_traces + >>> │ │ └ 21 keys with numpy arrays as values + >>> │ └ splined_traces + >>> │ └ 21 keys with numpy arrays as values + >>> ├ image + >>> │ └ Numpy array, shape: (1024, 1024), dtype: float64 + >>> ├ image_original + >>> │ └ Numpy array, shape: (1024, 1024), dtype: float64 + >>> ├ img_path + >>> │ └ /Users/sylvi/Documents/TopoStats/tests/resources/minicircle + >>> ├ pixel_to_nm_scaling + >>> │ └ 0.4940029296875 + >>> └ topostats_file_version + >>> └ 0.2 + + Finding data in a file. + ```python + from topostats.io import TopoFileHelper + + topofile = "path/to/topostats_file.topostats" + helper = TopoFileHelper(topofile) + helper.find_data(["ordered_trace_heights", "0"]) + ``` + + >>> [ Searching for ['ordered_trace_heights', '0'] in ./path/to/topostats_file.topostats ] + >>> | [search] No direct match found. + >>> | [search] Searching for partial matches. + >>> | [search] !! [ 1 Partial matches found] !! + >>> | [search] └ grain_trace_data/above/ordered_trace_heights/0 + >>> └ [End of search] + + Get data from a file. 
+ ```python + from topostats.io import TopoFileHelper + + topofile = "path/to/topostats_file.topostats" + helper = TopoFileHelper(topofile) + + data = helper.get_data("ordered_trace_heights/0") + ``` + >>> [ Get data ] Data found at grain_trace_data/above/ordered_trace_heights/0, type: + + Get data information + ```python + from topostats.io import TopoFileHelper + + topofile = "path/to/topostats_file.topostats" + helper = TopoFileHelper(topofile) + + helper.data_info("grain_trace_data/above/ordered_trace_heights/0") + ``` + >>> [ Info ] Data at grain_trace_data/above/ordered_trace_heights/0 is a numpy array with shape: (95,), + >>> dtype: float64 + """ + + def __init__(self, topofile: Path | str) -> None: + """ + Initialise the TopoFileHelper object. + + Parameters + ---------- + topofile : Path | str + Path to the .topostats file. + """ + self.topofile: Path = Path(topofile) + with h5py.File(self.topofile, "r") as f: + self.data: dict = hdf5_to_dict(open_hdf5_file=f, group_path="/") + + def search_partial_matches(self, data: dict, keys: list, current_path: list | None = None) -> list: + """ + Find partial matches to the keys in the dictionary. + + Recursively search through nested dictionaries and keep only the paths that match the keys in the correct order, + allowing gaps between the keys. + + Parameters + ---------- + data : dict + The dictionary to search through. + keys : list + The list of keys to search for. + current_path : list, optional + The current path in the dictionary, by default []. + + Returns + ------- + list + A list of paths that match the keys in the correct order. + """ + if current_path is None: + # Need to initialise the empty list here and not as a default argument since it is mutable + current_path = [] + + partial_matches = [] + + def recursive_partial_search(data, keys, current_path) -> None: + """ + Recursively find partial matches to the keys in the dictionary. 
+ + Recursive function to search through the dictionary and keep only the paths + that match the keys in the correct order, + allowing gaps between the keys. + + Parameters + ---------- + data : dict + The dictionary to search through. + keys : list + The list of keys to search for. + current_path : list + The current path in the dictionary. + """ + # If have reached the end of the current dictionary, return + if not keys: + partial_matches.append(current_path) + return + + current_key = keys[0] + + if isinstance(data, dict): + for k, v in data.items(): + new_path = current_path + [k] + try: + # Check if the current key can be converted to an integer + current_key_int = int(current_key) + k_int = int(k) + # If the current key and the key in the dictionary can be converted to integers, + # check if they are equal + if current_key_int == k_int: + # If the current key is in the key list of the dictionary, continue searching + # but remove the current key from the list + remaining_keys = keys[1:] + recursive_partial_search(v, remaining_keys, new_path) + except ValueError: + # If the current key cannot be converted to an integer, allow for partial matches + if current_key in k: + # If the current key is in the key list of the dictionary, continue searching + # but remove the current key from the list + remaining_keys = keys[1:] + recursive_partial_search(v, remaining_keys, new_path) + else: + # If the current key is not in the key list of the dictionary, continue searching + # but don't remove the current key from the list as it might be deeper in the dictionary + recursive_partial_search(v, keys, new_path) + + recursive_partial_search(data, keys, current_path) + return partial_matches + + def find_data(self, search_keys: list) -> None: + """ + Find the data in the dictionary that matches the list of keys. + + Parameters + ---------- + search_keys : list + The list of keys to search for. 
+ """ + # Find the best match for the list of keys + # First check if there is a direct match + LOGGER.info(f"[ Searching for {search_keys} in {self.topofile} ]") + + try: + current_data = self.data + for key in search_keys: + current_data = current_data[key] + + LOGGER.info("| [search] Direct match found") + except KeyError: + LOGGER.info("| [search] No direct match found.") + + # If no direct match is found, try to find a partial match + LOGGER.info("| [search] Searching for partial matches.") + partial_matches = self.search_partial_matches(data=self.data, keys=search_keys) + if partial_matches: + LOGGER.info(f"| [search] !! [ {len(partial_matches)} Partial matches found] !!") + for index, match in enumerate(partial_matches): + match_str = "/".join(match) + if index == len(partial_matches) - 1: + prefix = "| [search] └" + else: + prefix = "| [search] ├" + LOGGER.info(f"{prefix} {match_str}") + else: + LOGGER.info("| [search] No partial matches found.") + LOGGER.info("└ [End of search]") + + def pretty_print_structure(self) -> None: + """ + Print the structure of the data in the data dictionary. + + The structure is printed with the keys indented to show the hierarchy of the data. + """ + + def print_structure(data: dict, level=0, prefix=""): + """ + Recursive function to print the structure. + + Parameters + ---------- + data : dict + The dictionary to print the structure of. + level : int, optional + The current level of the dictionary, by default 0. + prefix : str, optional + The prefix to use when printing the dictionary, by default "". 
+ """ + for i, (key, value) in enumerate(data.items()): + is_last_item = i == len(data) - 1 + current_prefix = prefix + ("└ " if is_last_item else "├ ") + LOGGER.info(current_prefix + key) + + if isinstance(value, dict): + # Check if all keys are able to be integers, they are strings but need to check if they can be + # converted to integers without error + all_keys_are_integers = True + for k in value.keys(): + try: + int(k) + except ValueError: + all_keys_are_integers = False + break + all_values_are_numpy_arrays = all(isinstance(v, np.ndarray) for v in value.values()) + # if dictionary has keys that are integers and values that are numpy arrays, print the number + # of keys and the shape of the numpy arrays + if all_keys_are_integers and all_values_are_numpy_arrays: + LOGGER.info( + prefix + + (" " if is_last_item else "│ ") + + "└ " + + f"{len(value)} keys with numpy arrays as values" + ) + else: + new_prefix = prefix + (" " if is_last_item else "│ ") + print_structure(value, level + 1, new_prefix) + + elif isinstance(value, np.ndarray): + # Don't print the array, just the shape + LOGGER.info( + prefix + + (" " if is_last_item else "│ ") + + "└ " + + f"Numpy array, shape: {str(value.shape)}, dtype: {value.dtype}" + ) + else: + LOGGER.info(f"{prefix + (' ' if is_last_item else '│ ') + '└ ' + str(value)}") + + LOGGER.info(f"[{self.topofile}]") + print_structure(self.data) + + def get_data(self, location: str) -> int | float | str | np.ndarray | dict | None: + """ + Retrieve data from the dictionary using a '/' separated string. + + Parameters + ---------- + location : str + The location of the data in the dictionary, separated by '/'. + + Returns + ------- + int | float | str | np.ndarray | dict + The data at the location. 
+ """ + # If there's a trailing '/', remove it + if location[-1] == "/": + location = location[:-1] + keys = location.split("/") + + try: + current_data = self.data + for key in keys: + current_data = current_data[key] + LOGGER.info(f"[ Get data ] Data found at {location}, type: {type(current_data)}") + return current_data + except KeyError as e: + LOGGER.error(f"[ Get data ] Key not found: {e}, please check the location string.") + return None + + def data_info(self, location: str, verbose: bool = False) -> None: + """ + Get information about the data at a location. + + Parameters + ---------- + location : str + The location of the data in the dictionary, separated by '/'. + + verbose : bool, optional + Print more detailed information about the data, by default False. + """ + # If there's a trailing '/', remove it + if location[-1] == "/": + location = location[:-1] + keys = location.split("/") + + try: + current_data = self.data + for key in keys: + current_data = current_data[key] + except KeyError as e: + LOGGER.error(f"[ Info ] Key not found: {e}, please check the location string.") + return + + if isinstance(current_data, dict): + key_types = {type(k) for k in current_data.keys()} + value_types = {type(v) for v in current_data.values()} + LOGGER.info( + f"[ Info ] Data at {location} is a dictionary with {len(current_data)} " + f"keys of types {key_types} and values " + f"of types {value_types}" + ) + if verbose: + for k, v in current_data.items(): + LOGGER.info(f" {k}: {type(v)}") + elif isinstance(current_data, np.ndarray): + LOGGER.info( + f"[ Info ] Data at {location} is a numpy array with shape: {current_data.shape}, " + f"dtype: {current_data.dtype}" + ) + else: + LOGGER.info(f"[ Info ] Data at {location} is {type(current_data)}") + + return