Skip to content

Commit

Permalink
typehints and review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
eroell committed Mar 9, 2024
1 parent 7472eaa commit 89e5d5b
Showing 1 changed file with 11 additions and 17 deletions.
28 changes: 11 additions & 17 deletions ehrapy/tools/cohort_tracking/_cohort_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,21 @@ def __call__(self, adata: AnnData, label: str = None, operations_done: str = Non
self._tracked_operations.append(operations_done)
self._tracked_steps += 1

# track new stuff
# track new tableone object
t1 = TableOne(adata.obs, columns=self.columns, categorical=self.categorical, **tableone_kwargs)
self._track_t1.append(t1)

def _get_cat_dicts(self, table_one, col):
cat_pct = {category: [] for category in table_one.cat_table.loc[col].index}
def _get_cat_dicts(self, table_one: TableOne, col: str) -> pd.DataFrame:
# mypy error if not specifying dict below
cat_pct: dict = {category: [] for category in table_one.cat_table.loc[col].index}
for cat in cat_pct.keys():
# if tableone does not have the category of this column anymore, set the percentage to 0
# for categorized columns (e.g. gender 1.0/0.0), str(cat) helps to avoid considering the category as a float
# if (col, str(cat)) in table_one.cat_table["Overall"].index:
pct = float(table_one.cat_table["Overall"].loc[(col, str(cat))].split("(")[1].split(")")[0])
# else:
# pct = 0

cat_pct[cat] = [pct]
return pd.DataFrame(cat_pct).T[0]

def _get_num_dicts(self, table_one, col):
def _get_num_dicts(self, table_one: TableOne, col: str):
return table_one.cont_table["Overall"].loc[(col, "")]

@property
Expand All @@ -117,7 +115,7 @@ def track_t1(self):

def plot_cohort_change(
self,
set_axis_labels=True,
set_axis_labels: bool = True,
subfigure_title: bool = False,
color_palette: str = "colorblind",
show: bool = True,
Expand Down Expand Up @@ -240,14 +238,10 @@ def plot_cohort_change(
if idx == 0:
legend_labels.append([Patch(color=level_color, label=col)])

# Set y-axis labels
if set_axis_labels:
single_ax.set_yticks(
range(len(self.columns))
) # Set ticks at positions corresponding to the number of columns
single_ax.set_yticklabels(self.columns) # Set y-axis labels to the column names
single_ax.set_yticks(range(len(self.columns)))
single_ax.set_yticklabels(self.columns)

# Add legend
# These list of lists is needed to reverse the order of the legend labels,
# making the plot much more readable
legend_labels.reverse()
Expand Down Expand Up @@ -278,7 +272,7 @@ def plot_flowchart(
title: str = None,
arrow_size: float = 0.7,
show: bool = True,
ax=None,
ax: Axes = None,
bbox_kwargs: dict = None,
arrowprops_kwargs: dict = None,
) -> None | list[Axes] | tuple[Figure, list[Axes]]:
Expand Down Expand Up @@ -362,7 +356,7 @@ def plot_flowchart(
arrowprops=tot_arrowprops_kwargs,
)

# Set the limits of the axes to center the plot
# required to center the plot
axes.set_xlim(-0.5, 0.5)
axes.set_ylim(0, 1.1)

Expand Down

0 comments on commit 89e5d5b

Please sign in to comment.