Skip to content

Commit

Permalink
add tests + move segmentation src + add CI
Browse files Browse the repository at this point in the history
  • Loading branch information
pauldoucet committed Jun 12, 2024
1 parent 1030d80 commit 7e0ad3e
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 65 deletions.
37 changes: 37 additions & 0 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Hest tests

on:
push:
branches: [ "main", "develop"]
pull_request:
branches: [ "main" ]

permissions:
contents: read

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v3
with:
python-version: "3.9"
- name: Install python dependencies
run: |
python -m pip install -e .
- name: Install apt dependencies
run: |
apt install libvips libvips-dev openslide-tools
- name: Run reader tests
run: |
python reader_tests.py
- name: Run hestdata tests
run: |
python hestdata_tests.py
7 changes: 1 addition & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,20 +162,15 @@ st.save(save_dir, pyramidal=True)

## Patching and otsu-based tissue segmentation

Note that by default the whole WSI image is loaded in RAM to speed up the patching process,
if the image is too large to be loaded in RAM, you can pass `load_in_memory=False` but the patching will be slower.

```python
# By default the whole WSI image is loaded in RAM
# to speed up the patching process
load_in_memory = True

st.dump_patches(
patch_save_dir,
'demo',
target_patch_size=224,
target_pixel_size=0.5,
load_in_memory=load_in_memory
target_pixel_size=0.5
)
```

Expand Down
30 changes: 25 additions & 5 deletions src/hest/HESTData.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
from spatialdata import SpatialData
from tqdm import tqdm

from .SegDataset import SegDataset
from .segmentation import (apply_otsu_thresholding, keep_largest_area,
from .segmentation.SegDataset import SegDataset
from .segmentation.segmentation import (apply_otsu_thresholding, keep_largest_area,
mask_to_contours, save_pkl, scale_contour_dim,
segment_tissue_deep, visualize_tissue_seg)
from .utils import (ALIGNED_HE_FILENAME, check_arg, get_path_from_meta_row,
Expand Down Expand Up @@ -109,6 +109,9 @@ class representing a single ST profile + its associated WSI image
self._verify_format(adata)
self.pixel_size = pixel_size
self.cellvit_seg = cellvit_seg
self.tissue_mask = None
self.contours_holes = None
self.contours_tissue = None

if 'total_counts' not in self.adata.var_names:
sc.pp.calculate_qc_metrics(self.adata, inplace=True)
Expand All @@ -129,19 +132,20 @@ def __repr__(self):
return rep


def save_spatial_plot(self, save_path: str, key='total_counts', pl_kwargs={}):
def save_spatial_plot(self, save_path: str, name: str='', key='total_counts', pl_kwargs={}):
"""Save the spatial plot from that STObject
Args:
save_path (str): path to a directory where the spatial plot will be saved
name (str): save plot as {name}spatial_plots.png
key (str): feature to plot. Default: 'total_counts'
pl_kwargs(Dict): arguments for sc.pl.spatial
"""
print("Plotting spatial plots...")

sc.pl.spatial(self.adata, show=None, img_key="downscaled_fullres", color=[key], title=f"in_tissue spots", **pl_kwargs)

filename = f"spatial_plots.png"
filename = f"{name}spatial_plots.png"

# Save the figure
plt.savefig(os.path.join(save_path, filename))
Expand Down Expand Up @@ -418,7 +422,7 @@ def dump_patches(
attr_dict['img'] = {'patch_size': patch_size_pxl,
'factor': scale_factor}

initsave_hdf5(output_datafile, asset_dict, attr_dict, mode=mode_HE, verbose=1)
initsave_hdf5(output_datafile, asset_dict, attr_dict, mode=mode_HE)
mode_HE = 'a'


Expand Down Expand Up @@ -489,6 +493,22 @@ def save_tissue_seg_pkl(self, save_dir: str, name: str) -> None:

asset_dict = self.get_tissue_contours()
save_pkl(os.path.join(save_dir, f'{name}_mask.pkl'), asset_dict)


def save_vis(self, save_dir, name) -> None:

vis = visualize_tissue_seg(
self.wsi.img,
self.tissue_mask,
self.contours_tissue,
self.contours_holes,
line_color=(0, 255, 0),
hole_color=(0, 0, 255),
line_thickness=5,
target_width=1000,
seg_display=True,
)
vis.save(os.path.join(save_dir, f'{name}_vis.jpg'))


def to_spatial_data(self, lazy_img=True) -> SpatialData:
Expand Down
File renamed without changes.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torch.utils.data import DataLoader
from torchvision import models, transforms

from hest.SegDataset import SegDataset
from hest.segmentation.SegDataset import SegDataset
from hest.utils import get_path_relative
from hest.wsi import WSI, WSIPatcher

Expand Down Expand Up @@ -51,7 +51,7 @@ def segment_tissue_deep(img: Union[np.ndarray, openslide.OpenSlide, 'CuImage', W

width, height = wsi.get_dimensions()

weights_path = get_path_relative(__file__, '../../models/deeplabv3_seg_v3.ckpt')
weights_path = get_path_relative(__file__, '../../../models/deeplabv3_seg_v3.ckpt')

patcher = WSIPatcher(wsi, patch_size_src)

Expand Down Expand Up @@ -204,6 +204,8 @@ def visualize_tissue_seg(

img = wsi.get_thumbnail(round(width * downsample), round(height * downsample))
#img = cv2.resize(img, (round(width * downsample), round(height * downsample)))
if tissue_mask is None and contours_tissue is None and contour_holes is None:
return Image.fromarray(img)

downscaled_mask = cv2.resize(tissue_mask, (img.shape[1], img.shape[0]))
downscaled_mask = np.expand_dims(downscaled_mask, axis=-1)
Expand Down
22 changes: 15 additions & 7 deletions tests/test.py → tests/hestdata_tests.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import List
import unittest
from os.path import join as _j

Expand All @@ -23,47 +24,54 @@ class TestHESTData(unittest.TestCase):
def setUpClass(self):
self.cur_dir = get_path_relative(__file__, '')
cur_dir = self.cur_dir
self.output_dir = _j(cur_dir, 'output_tests')
self.output_dir = _j(cur_dir, 'output_tests/hestdata_tests')

# Create an instance of HESTData
adata = sc.read_h5ad(_j(cur_dir, './assets/SPA154.h5ad'))
pixel_size = 0.9206

self.st_objects = []
self.st_objects: List[HESTData] = []


if CuImage is not None:
img = CuImage(_j(cur_dir, './assets/SPA154.tif'))
self.st_objects.append(HESTData(adata, img, pixel_size))

img = openslide.OpenSlide(_j(cur_dir, './assets/SPA154.tif'))
self.st_objects.append(HESTData(adata, img, pixel_size))
self.st_objects.append({'name': 'numpy', 'st': HESTData(adata, img, pixel_size)})

if CuImage is not None:
img = WSI(CuImage(_j(cur_dir, './assets/SPA154.tif'))).numpy()
self.st_objects.append(HESTData(adata, img, pixel_size))
self.st_objects.append({'name': 'cuimage', 'st': HESTData(adata, img, pixel_size)})
else:
img = WSI(openslide.OpenSlide(_j(cur_dir, './assets/SPA154.tif'))).numpy()
self.st_objects.append(HESTData(adata, img, pixel_size))
self.st_objects.append({'name': 'openslide', 'st': HESTData(adata, img, pixel_size)})

def test_tissue_seg(self):
for idx, st in enumerate(self.st_objects):
st = st['st']
with self.subTest(st_object=idx):
st.compute_mask(method='deep')
st.save_tissue_seg_jpg(self.output_dir, name=f'deep_{idx}')
st.save_tissue_seg_pkl(self.output_dir, name=f'deep_{idx}')
st.save_vis(self.output_dir, name=f'deep_{idx}')

st.compute_mask(method='otsu')
st.save_tissue_seg_jpg(self.output_dir, name=f'otsu_{idx}')
st.save_tissue_seg_pkl(self.output_dir, name=f'otsu_{idx}')
st.save_vis(self.output_dir, name=f'otsu_{idx}')

def test_patching(self):
for idx, st in enumerate(self.st_objects):
for idx, conf in enumerate(self.st_objects):
st = conf['st']
with self.subTest(st_object=idx):
st.dump_patches(self.output_dir)
name = ''
name += conf['name']
st.dump_patches(self.output_dir, name=name)

def test_wsi(self):
for idx, st in enumerate(self.st_objects):
st = st['st']
with self.subTest(st_object=idx):
os.makedirs(_j(self.output_dir, f'test_save_{idx}'), exist_ok=True)
st.save(_j(self.output_dir, f'test_save_{idx}'), save_img=True)
Expand Down
68 changes: 68 additions & 0 deletions tests/reader_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os
import unittest
from os.path import join as _j

import scanpy as sc

from hest.readers import VisiumReader
from hest.utils import get_path_relative


class TestHESTData(unittest.TestCase):

@classmethod
def setUpClass(self):
self.cur_dir = get_path_relative(__file__, '')
cur_dir = self.cur_dir
self.output_dir = _j(cur_dir, 'output_tests', 'reader_tests')
os.makedirs(self.output_dir, exist_ok=True)


def test_visium_reader_img_matrix_spatial(self):
cur_dir = self.cur_dir
fullres_img_path = _j(cur_dir, './assets/WSA_LngSP9258463.jpg')
bc_matrix_path = _j(cur_dir, './assets/filtered_feature_bc_matrix.h5')
spatial_coord_path = _j(cur_dir, './assets/spatial')


st = VisiumReader().read(
fullres_img_path, # path to a full res image
bc_matrix_path, # path to filtered_feature_bc_matrix.h5
spatial_coord_path=spatial_coord_path # path to a space ranger spatial/ folder containing either a tissue_positions.csv or tissue_position_list.csv
)

st.save(_j(self.output_dir, 'img+filtered_matrix+spatial'), pyramidal=True)
st.save_spatial_plot(_j(self.output_dir, 'img+filtered_matrix+spatial'), self.output_dir)


st.dump_patches(
self.output_dir,
'demo',
target_patch_size=224,
target_pixel_size=0.5
)


def test_visium_reader_img_matrix(self):
cur_dir = self.cur_dir
fullres_img_path = _j(cur_dir, './assets/WSA_LngSP9258463.jpg')
bc_matrix_path = _j(cur_dir, './assets/filtered_feature_bc_matrix.h5')

# if both the alignment file and the spatial folder are missing, attempt autoalignment
st = VisiumReader().read(
fullres_img_path, # path to a full res image
bc_matrix_path, # path to filtered_feature_bc_matrix.h5
)
st.save(_j(self.output_dir, 'img+filtered_matrix'), pyramidal=True)
st.save_spatial_plot(_j(self.output_dir, 'img+filtered_matrix'), self.output_dir)

st.dump_patches(
self.output_dir,
'demo',
target_patch_size=224,
target_pixel_size=0.5
)


if __name__ == '__main__':
unittest.main()
45 changes: 0 additions & 45 deletions tests/test_visium_reader.py

This file was deleted.

0 comments on commit 7e0ad3e

Please sign in to comment.