Skip to content

Commit 59b870d

Browse files
authored
MRG: clean up code a bit; do cleanrun on CI (#27)
* change CI to clean run
* cleanup and commenting
* clean up code a bit
1 parent e84c49f commit 59b870d

File tree

2 files changed

+25 -10 lines changed

2 files changed

+25 -10 lines changed

.github/workflows/build-test.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,4 +47,4 @@ jobs:
4747

4848
- name: build examples
4949
shell: bash -l {0}
50-
run: make examples
50+
run: make cleanrun

src/sourmash_plugin_betterplot.py

Lines changed: 24 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -25,10 +25,10 @@
2525
from sourmash.plugins import CommandLinePlugin
2626

2727

28-
###
29-
28+
### utility functions
3029

3130
def load_labelinfo_csv(filename):
31+
"Load file output by 'sourmash compare --labels-to'"
3232
with sourmash_args.FileInputCSV(filename) as r:
3333
labelinfo = list(r)
3434

@@ -37,12 +37,15 @@ def load_labelinfo_csv(filename):
3737

3838

3939
def load_categories_csv(filename, labelinfo):
40+
"Load categories file, integrate with labelinfo => colors"
4041
with sourmash_args.FileInputCSV(filename) as r:
4142
categories = list(r)
4243

4344
category_map = {}
4445
colors = None
4546
if categories:
47+
# first, figure out which column is matching between labelinfo
48+
# and categories file.
4649
assert labelinfo
4750
keys = set(categories[0].keys())
4851
keys -= {"category"}
@@ -54,19 +57,27 @@ def load_categories_csv(filename, labelinfo):
5457
key = k
5558
break
5659

60+
# found one? awesome. load in all the categories & assign colors.
61+
5762
if key:
58-
category_values = list(set([row["category"] for row in categories]))
59-
category_values.sort()
63+
# get distinct categories
64+
category_values = set([row["category"] for row in categories])
65+
category_values = list(sorted(category_values))
6066

67+
# map to colormap colors
6168
cat_colors = list(map(plt.cm.tab10, range(len(category_values))))
69+
70+
# build map of category => color
6271
category_map = {}
6372
for v, color in zip(category_values, cat_colors):
6473
category_map[v] = color
6574

75+
# build map of key => color
6676
category_map2 = {}
6777
for row in categories:
6878
category_map2[row[key]] = category_map[row["category"]]
6979

80+
# build list of colors
7081
colors = []
7182
for row in labelinfo:
7283
value = row[key]
@@ -82,7 +93,7 @@ def load_categories_csv(filename, labelinfo):
8293

8394

8495
def load_categories_csv_for_labels(filename, queries):
85-
"Load a categories CSV that must use label name."
96+
"Load a categories CSV that uses the 'label' column."
8697
with sourmash_args.FileInputCSV(filename) as r:
8798
categories = list(r)
8899

@@ -91,20 +102,24 @@ def load_categories_csv_for_labels(filename, queries):
91102
if categories:
92103
key = "label"
93104

105+
# load distinct categories
94106
category_values = list(set([row["category"] for row in categories]))
95107
category_values.sort()
96108

109+
# map categories to color
97110
cat_colors = list(map(plt.cm.tab10, range(len(category_values))))
98111
category_map = {}
99112
for v, color in zip(category_values, cat_colors):
100113
category_map[v] = color
101114

115+
# map label to color
102116
category_map2 = {}
103117
for row in categories:
104118
label = row[key]
105119
cat = row["category"]
106120
category_map2[label] = category_map[cat]
107121

122+
# build list of colors
108123
colors = []
109124
for label, idx in queries:
110125
color = category_map2[label]
@@ -116,10 +131,9 @@ def load_categories_csv_for_labels(filename, queries):
116131

117132

118133
#
119-
# CLI plugin - supports 'sourmash scripts plot2'
134+
# CLI plugin code
120135
#
121136

122-
123137
class Command_Plot2(CommandLinePlugin):
124138
command = "plot2" # 'scripts <command>'
125139
description = (
@@ -247,11 +261,12 @@ def plot_composite_matrix(
247261
no_labels=not show_labels,
248262
get_leaves=True,
249263
)
250-
# ax1.set_xticks([])
251264

265+
# draw cut point
252266
if cut_point is not None:
253267
ax1.axvline(x=cut_point, c="red", linestyle="dashed")
254268

269+
# draw matrix
255270
xstart = 0.45
256271
width = 0.45
257272
if not show_labels:
@@ -538,7 +553,7 @@ def main(self, args):
538553
plt.savefig(args.output_figure)
539554

540555

541-
# @CTB unused again...
556+
# @CTB unused code for sparse matrix foo. Revisit!
542557
def create_sparse_dissimilarity_matrix(tuples, num_objects):
543558
# Initialize matrix in LIL format for efficient setup
544559
similarity_matrix = lil_matrix((num_objects, num_objects))

0 commit comments

Comments
 (0)