@@ -276,12 +276,17 @@ def plot_image(im: np.ndarray,
276
276
in points (e.g. `12.5`) or a relative size (e.g. `'xx-large'`).
277
277
channels (list[int]): The image channel(s) to be plotted. For example,
278
278
to plot the first and third channel of a 4-channel image with
279
- shape (4,1,500,600), you can use `channels=[0, 2]`.
279
+ shape (4,1,500,600), you can use `channels=[0, 2]`. If several
280
+ channels are given, they will be mapped to colors using
281
+ `channel_colors`. In the case of a single channel, `channel_colors`
282
+ can also be the name of a supported colormap.
280
283
channel_colors (list[str]): A list with python color strings
281
284
(e.g. 'red') defining the color for each channel in `channels`.
282
285
For example, to map the selected `channels=[0, 2]` to
283
286
cyan and magenta, respectively, you can use
284
- `channel_colors=['cyan', 'magenta']`.
287
+ `channel_colors=['cyan', 'magenta']`. If a single channel is
288
+ given (e.g. `channels=[0]`), this can also be one of the following
289
+ colormaps: 'viridis', 'plasma', 'inferno', 'magma', 'cividis'.
285
290
channel_ranges (list[list[float]]): A list of 2-element lists
286
291
(e.g. [0.01, 0.95]) giving the the value ranges that should
287
292
be mapped to colors for each channel. If the given numerical
@@ -361,6 +366,10 @@ def plot_image(im: np.ndarray,
361
366
assert all ([ch <= nch for ch in channels ]), (
362
367
f"Invalid `channels` parameter, must be less or equal to { nch } "
363
368
)
369
+ assert len (channels ) == len (channel_colors ), (
370
+ f"`channels` and `channel_colors` must have the same length, "
371
+ f"but are { len (channels )} and { len (channel_colors )} "
372
+ )
364
373
assert axis_style != 'micrometer' or spacing_yx != None , (
365
374
f"For `axis_style='micrometer', the parameter `spacing_yx` needs to be provided."
366
375
)
@@ -410,14 +419,22 @@ def plot_image(im: np.ndarray,
410
419
if not image_transform is None :
411
420
im = image_transform (im )
412
421
413
- # convert (ch,y,x) to rgb (y,x,3) and plot
414
- im_rgb = convert_to_rgb (im = im [channels ],
415
- channel_colors = channel_colors ,
416
- channel_ranges = channel_ranges )
422
+ # convert (ch,y,x) to rgb (y,x,3)
423
+ if len (channels ) == 1 and channel_colors [0 ] in ['viridis' , 'plasma' , 'inferno' ,
424
+ 'magma' , 'cividis' ]:
425
+ # map to colormap
426
+ im_rgb = np .squeeze (im [channels ], axis = 0 )
427
+ cmap = channel_colors [0 ]
428
+ else :
429
+ # map to rgb
430
+ im_rgb = convert_to_rgb (im = im [channels ],
431
+ channel_colors = channel_colors ,
432
+ channel_ranges = channel_ranges )
433
+ cmap = None
417
434
418
435
# define nested function with main plotting code
419
436
def _do_plot ():
420
- plt .imshow (im_rgb )
437
+ plt .imshow (im_rgb , cmap = cmap )
421
438
if not msk is None and np .max (msk ) > 0 :
422
439
plt .imshow (msk ,
423
440
interpolation = 'none' ,
0 commit comments