Skip to content

Commit

Permalink
Improve databricks compatibility with UMAP. [#250]
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesdolezal committed Oct 25, 2023
1 parent 646be9a commit 6f79a83
Showing 1 changed file with 50 additions and 18 deletions.
68 changes: 50 additions & 18 deletions slideflow/stats/slidemap.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ def load(cls, path: str):
"the path is a valid directory with either 'parametric_umap' "
"subdirectory or a valid 'umap.pkl'.")
# Load range/clip
if exists(join(path, 'range_clip.npz')):
obj.load_range_clip(join(path, 'range_clip.npz'))
else:
try:
obj.load_range_clip(path)
except FileNotFoundError:
log.warn("Could not find range_clip.npz; results from "
"umap_transform() will not be normalized.")
if exists(join(path, 'tfrecords.json')):
Expand Down Expand Up @@ -1041,21 +1041,38 @@ def save_umap(self, path: str) -> None:
pickle.dump(self.umap, f)
log.info(f"Wrote UMAP coordinates to [green]{path}")
self.save_coordinates(join(path, 'slidemap.parquet'))
np.savez(
join(path, 'range_clip.npz'),
range=self._umap_normalized_range,
clip=self._umap_normalized_clip)
self.save_range_clip(path)

def save_encoder(self, path: str) -> None:
    """Save Parametric UMAP encoder only.

    Saves the encoder to ``<path>/encoder``, passes
    ``<path>/slidemap.parquet`` to ``save_coordinates()``, and writes the
    range/clip values through ``save_range_clip()``.

    Args:
        path (str): Destination directory.

    Raises:
        ValueError: If the SlideMap was not built with Parametric UMAP.

    """
    if not self.parametric_umap:
        raise ValueError("SlideMap not built with Parametric UMAP.")
    self.umap.encoder.save(join(path, 'encoder'))
    self.save_coordinates(join(path, 'slidemap.parquet'))
    # Range/clip is written only through save_range_clip(), which decides
    # between a single .npz archive and separate .npy files. Calling
    # np.savez() here as well would always produce a ZIP archive, even when
    # ZIP saving has been disabled. The directory is passed with a trailing
    # separator so the files land inside `path` even when the helper builds
    # filenames by plain string concatenation.
    self.save_range_clip(join(path, ''))

def save_range_clip(self, dest: str) -> None:
    """Save range/clip information used to normalize raw UMAP output.

    If ZIP saving is enabled, saves a single ``range_clip.npz`` archive in
    ``dest``, with the attributes ``"range"`` and ``"clip"``.

    If ZIP saving is disabled (SF_ALLOW_ZIP=0, for databricks
    compatibility), saves these attributes separately, as ``range.npy``
    and ``clip.npy`` in ``dest``.

    Args:
        dest (str): Destination directory.

    """
    # Filenames are built with join() rather than string concatenation:
    # `dest` is a directory and normally has no trailing separator, so
    # `dest + 'range_clip.npz'` would create a mis-named file next to the
    # directory instead of inside it, where load_range_clip() looks.
    if sf.util.zip_allowed():
        np.savez(
            join(dest, 'range_clip.npz'),
            range=self._umap_normalized_range,
            clip=self._umap_normalized_clip
        )
    else:
        np.save(join(dest, 'range.npy'), self._umap_normalized_range)
        np.save(join(dest, 'clip.npy'), self._umap_normalized_clip)

def load_range_clip(self, path: str) -> None:
"""Load a saved range_clip.npz file for normalizing raw UMAP output.
Expand All @@ -1065,13 +1082,28 @@ def load_range_clip(self, path: str) -> None:
as generated from ``SlideMap.save()``.
"""
log.debug(f"Loading range_clip at {path}")
loaded = np.load(path)
if not ('range' in loaded and 'clip' in loaded):
raise ValueError(f"Unable to load {path}; did not find values "
"'range' and 'clip'.")
self._umap_normalized_clip = loaded['clip']
self._umap_normalized_range = loaded['range']
rc_path, r_path, c_path = None, None, None
if exists(path) and path.endswith('.npz'):
rc_path = path
elif exists(join(path, 'range_clip.npz')):
rc_path = join(path, 'range_clip.npz')
elif exists(join(path, 'range.npy')) and exists(join(path, 'clip.npy')):
r_path = join(path, 'range.npy')
c_path = join(path, 'clip.npy')
else:
raise FileNotFoundError(
f"Unable to find range/clip information at {path}."
)
if rc_path:
loaded = np.load(path)
if not ('range' in loaded and 'clip' in loaded):
raise ValueError(f"Unable to load {path}; did not find values "
"'range' and 'clip'.")
self._umap_normalized_clip = loaded['clip']
self._umap_normalized_range = loaded['range']
else:
self._umap_normalized_clip = np.load(c_path)
self._umap_normalized_range = np.load(r_path)
log.info("Loaded range={}, clip={}".format(
self._umap_normalized_range,
self._umap_normalized_clip
Expand Down

0 comments on commit 6f79a83

Please sign in to comment.