5
5
from typing import Optional , Sequence
6
6
from numbers import Integral
7
7
from ipywidgets import widgets , interact
8
- from numpy import arange , r_ , nan , fill_diagonal
8
+ from numpy import arange , nan , fill_diagonal
9
9
from pandas import DataFrame
10
10
from altair import (
11
11
Axis , Chart , Color , condition , data_transformers , selection_multi ,
@@ -401,7 +401,7 @@ def plot_stairs(
401
401
x = X (** x_kws ),
402
402
y = Y (** y_kws ),
403
403
tooltip = [xlabel , ylabel , tooltip_label ]
404
- )
404
+ )
405
405
return chart \
406
406
.configure_axis (labelFontSize = font_size , titleFontSize = font_size )\
407
407
.configure_legend (labelFontSize = font_size , titleFontSize = font_size )
@@ -472,17 +472,15 @@ def plot_stairbars(
472
472
x = X (** x_kws ),
473
473
y = Y (** y_kws ),
474
474
tooltip = [xlabel , ylabel , tooltip_label ]
475
- )
475
+ )
476
476
return chart \
477
477
.configure_axis (labelFontSize = font_size , titleFontSize = font_size )\
478
478
.configure_legend (labelFontSize = font_size , titleFontSize = font_size )
479
479
480
480
481
-
482
481
def plot_heatmap (
483
482
data : DataFrame ,
484
483
columns : Optional [Sequence [str ]] = None ,
485
- tooltip_cols : Optional [Sequence [str ]] = None ,
486
484
names : list = None ,
487
485
sort : bool = True ,
488
486
droppable : bool = True ,
@@ -506,8 +504,6 @@ def plot_heatmap(
506
504
Input data.
507
505
columns : Optional[Sequence[str]], optional
508
506
Columns that are to be displayed on a plot.
509
- tooltip_cols : Optional[Sequence[str]], optional
510
- Columns to be used in tooltips.
511
507
names : list, optional
512
508
Values labels passed as a list.
513
509
The first element corresponds to non-missing values,
@@ -545,7 +541,8 @@ def plot_heatmap(
545
541
if not x_kws :
546
542
x_kws = {'sort' : None , 'shorthand' : xlabel , 'type' : 'nominal' }
547
543
if not y_kws :
548
- y_kws = {'sort' : None , 'shorthand' : ylabel , 'type' : 'ordinal' , 'axis' : Axis (labelOverlap = 'greedy' )}
544
+ y_kws = {'sort' : None , 'shorthand' : ylabel ,
545
+ 'type' : 'ordinal' , 'axis' : Axis (labelOverlap = 'greedy' )}
549
546
if not names :
550
547
names = ['Filled' , 'NA' , 'Droppable' ]
551
548
if not color_kws :
@@ -557,37 +554,35 @@ def plot_heatmap(
557
554
range = ["green" , "red" , "orange" ])
558
555
}
559
556
if not rect_kws :
560
- rect_kws = {}
557
+ rect_kws = {"clip" : True }
561
558
562
559
cols = _select_cols (data , columns )
563
- tt_cols = _select_cols (data , tooltip_cols , [])
564
560
565
- data_copy = data .loc [:, r_ [cols , tt_cols ]].copy ()
566
- data_copy .loc [:, cols ] = data_copy .loc [:, cols ].isna ()
561
+ data_copy = data .loc [:, cols ].copy ().isna ()
567
562
if sort :
568
- cols_sorted = data_copy . loc [:, cols ] \
563
+ cols_sorted = data_copy \
569
564
.sum ()\
570
565
.sort_values (ascending = False )\
571
566
.index .tolist ()
572
567
data_copy .sort_values (by = cols_sorted , inplace = True )
573
568
x_kws .update ({'sort' : cols_sorted })
574
569
575
570
if droppable :
576
- non_na_mask = ~ data_copy .loc [:, cols ].values
577
- na_rows_mask = data_copy .loc [:, cols ].any (axis = 1 ).values [:, None ]
571
+ non_na_mask = ~ data_copy .values
572
+ na_rows_mask = data_copy .any (
573
+ axis = 1 ).values [:, None ]
578
574
droppable_mask = non_na_mask & na_rows_mask
579
- data_copy .loc [:, cols ] = data_copy .loc [:, cols ].astype (int )
580
- data_copy .loc [:, cols ] = data_copy .loc [:, cols ]\
575
+ data_copy = data_copy .astype (int )\
581
576
.mask (droppable_mask , other = 2 )
582
577
else :
583
- data_copy . loc [:, cols ] = data_copy . loc [:, cols ] .astype (int )
578
+ data_copy = data_copy .astype (int )
584
579
585
- data_copy . loc [:, cols ] = data_copy . loc [:, cols ] .replace (
580
+ data_copy = data_copy .replace (
586
581
dict (zip ([0 , 1 , 2 ], names )))
587
582
588
583
data_copy [ylabel ] = arange (data .shape [0 ])
589
584
data_copy = data_copy .melt (
590
- id_vars = r_ [[ ylabel ], tt_cols ],
585
+ id_vars = [ ylabel ],
591
586
value_vars = cols ,
592
587
var_name = xlabel ,
593
588
value_name = zlabel )
@@ -598,8 +593,7 @@ def plot_heatmap(
598
593
x = X (** x_kws ),
599
594
y = Y (** y_kws ),
600
595
color = Color (** color_kws ),
601
- tooltip = tt_cols .tolist ()
602
- )
596
+ )
603
597
604
598
return chart \
605
599
.configure_axis (labelFontSize = font_size , titleFontSize = font_size )\
0 commit comments