Skip to content

Commit

Permalink
Merge pull request #38 from smart-on-fhir/mikix/info-labels
Browse files Browse the repository at this point in the history
feat(info): add --labels flag to print label statistics
  • Loading branch information
mikix authored Jun 10, 2024
2 parents 2138b29 + 00c2803 commit f69caea
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 57 deletions.
14 changes: 12 additions & 2 deletions chart_review/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def __init__(self, proj_config: config.ProjectConfig):
self.ls_export = common.read_json(self.config.path("labelstudio-export.json"))
self.annotations = simplify.simplify_export(self.ls_export, self.config)

# Add a placeholder for any annotators that don't have mentions for some reason
for annotator in self.config.annotators.values():
self.annotations.mentions.setdefault(annotator, types.Mentions())

# Load external annotations (i.e. from NLP tags or ICD10 codes)
for name, value in self.config.external_annotations.items():
external.merge_external(self.annotations, self.ls_export, self.project_dir, name, value)
Expand All @@ -42,9 +46,15 @@ def __init__(self, proj_config: config.ProjectConfig):
# Calculate the final set of note ranges for each annotator
self.note_range, self.ignored_notes = self._collect_note_ranges(self.ls_export)

# Remove any ignored notes from the mentions table, for ease of consuming code
for mentions in self.annotations.mentions.values():
for note in self.ignored_notes:
if note in mentions:
del mentions[note]

def _collect_note_ranges(
self, exported_json: list[dict]
) -> tuple[dict[str, set[int]], set[int]]:
) -> tuple[dict[str, types.NoteSet], types.NoteSet]:
# Detect note ranges if they were not defined in the project config
# (i.e. default to the full set of annotated notes)
note_ranges = {k: set(v) for k, v in self.config.note_ranges.items()}
Expand All @@ -55,7 +65,7 @@ def _collect_note_ranges(
all_ls_notes = {int(entry["id"]) for entry in exported_json if "id" in entry}

# Parse ignored IDs (might be note IDs, might be external IDs)
ignored_notes: set[int] = set()
ignored_notes = types.NoteSet()
for ignore_id in self.config.ignore:
ls_id = external.external_id_to_label_studio_id(exported_json, str(ignore_id))
if ls_id is None:
Expand Down
89 changes: 64 additions & 25 deletions chart_review/commands/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import rich
import rich.box
import rich.table
import rich.text
import rich.tree

from chart_review import cli_utils, cohort, config, console_utils
from chart_review import cli_utils, cohort, config, console_utils, types


def print_info(reader: cohort.CohortReader) -> None:
Expand All @@ -26,10 +27,6 @@ def print_info(reader: cohort.CohortReader) -> None:
"Chart Count",
"Chart IDs",
box=rich.box.ROUNDED,
pad_edge=False,
title="Annotations:",
title_justify="left",
title_style="bold",
)
for annotator in sorted(reader.note_range):
notes = reader.note_range[annotator]
Expand All @@ -38,26 +35,9 @@ def print_info(reader: cohort.CohortReader) -> None:
str(len(notes)),
console_utils.pretty_note_range(notes),
)
console.print(chart_table)

# Ignored charts
if reader.ignored_notes:
ignored_count = len(reader.ignored_notes)
chart_word = "chart" if ignored_count == 1 else "charts"
pretty_ranges = console_utils.pretty_note_range(reader.ignored_notes)
console.print(
f" Ignoring {ignored_count} {chart_word} ({pretty_ranges})",
highlight=False,
style="italic",
)

# Labels
console.print()
console.print("Labels:", style="bold")
if reader.class_labels:
console.print(", ".join(sorted(reader.class_labels, key=str.casefold)))
else:
console.print("None", style="italic", highlight=False)
console.print(chart_table)
print_ignored_charts(reader)


def print_ids(reader: cohort.CohortReader) -> None:
Expand Down Expand Up @@ -98,11 +78,68 @@ def print_ids(reader: cohort.CohortReader) -> None:
writer.writerow([chart_id, None, None])


def print_labels(reader: cohort.CohortReader) -> None:
"""
Show label information on the console.
:param reader: the cohort configuration
"""
# Calculate all label counts for each annotator
label_names = sorted(reader.class_labels, key=str.casefold)
label_notes: dict[str, dict[str, types.NoteSet]] = {} # annotator -> label -> note IDs
any_annotator_note_sets: dict[str, types.NoteSet] = {}
for annotator, mentions in reader.annotations.mentions.items():
label_notes[annotator] = {}
for name in label_names:
note_ids = {note_id for note_id, labels in mentions.items() if name in labels}
label_notes[annotator][name] = note_ids
any_annotator_note_sets.setdefault(name, types.NoteSet()).update(note_ids)

label_table = rich.table.Table(
"Annotator",
"Chart Count",
"Label",
box=rich.box.ROUNDED,
)

# First add summary entries, for counts across the union of all annotators
for name in label_names:
count = str(len(any_annotator_note_sets.get(name, {})))
label_table.add_row(rich.text.Text("Any", style="italic"), count, name)

# Now do each annotator as their own little boxed section
for annotator in sorted(label_notes.keys(), key=str.casefold):
label_table.add_section()
for name, note_set in label_notes[annotator].items():
count = str(len(note_set))
label_table.add_row(annotator, count, name)

rich.get_console().print(label_table)
print_ignored_charts(reader)


def print_ignored_charts(reader: cohort.CohortReader):
"""Prints a line about ignored charts, suitable for underlying a table"""
if not reader.ignored_notes:
return

ignored_count = len(reader.ignored_notes)
chart_word = "chart" if ignored_count == 1 else "charts"
pretty_ranges = console_utils.pretty_note_range(reader.ignored_notes)
rich.get_console().print(
f" Ignoring {ignored_count} {chart_word} ({pretty_ranges})",
highlight=False,
style="italic",
)


def make_subparser(parser: argparse.ArgumentParser) -> None:
cli_utils.add_project_args(parser)
parser.add_argument(
mode = parser.add_mutually_exclusive_group()
mode.add_argument(
"--ids", action="store_true", help="Prints a CSV of ID mappings (chart & FHIR IDs)"
)
mode.add_argument("--labels", action="store_true", help="Prints label info and usage")
parser.set_defaults(func=run_info)


Expand All @@ -111,5 +148,7 @@ def run_info(args: argparse.Namespace) -> None:
reader = cohort.CohortReader(proj_config)
if args.ids:
print_ids(reader)
elif args.labels:
print_labels(reader)
else:
print_info(reader)
4 changes: 3 additions & 1 deletion chart_review/console_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Helper methods for printing to the console."""

from chart_review import types

def pretty_note_range(notes: set[int]) -> str:

def pretty_note_range(notes: types.NoteSet) -> str:
"""
Returns a pretty, human-readable string for a set of notes.
Expand Down
1 change: 1 addition & 0 deletions chart_review/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
AnnotatorMap = dict[int, str]

LabelSet = set[str]
NoteSet = set[int]

# Map of label_studio_note_id: {all labels for that note}
# Usually used in the context of a specific annotator's label mentions.
Expand Down
2 changes: 1 addition & 1 deletion docs/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ The new group labels do not need to be a part of your source `labels` list.
#### Example
```yaml
grouped-labels:
ill: [insomnia, chickenpox, ebola]
animal: [dog, cat, fox]
```

### `ignore`
Expand Down
33 changes: 29 additions & 4 deletions docs/info.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,13 @@ This is helpful to examine the computed list of chart ID ranges or labels.

```shell
$ chart-review info
Annotations:
╭──────────┬─────────────┬──────────╮
│Annotator │ Chart Count │ Chart IDs│
├──────────┼─────────────┼──────────┤
│jane │ 3 │ 1, 3–4 │
│jill │ 4 │ 1–4 │
│john │ 3 │ 1–2, 4 │
╰──────────┴─────────────┴──────────╯

Labels:
Cough, Fatigue, Headache
```

## Options
Expand Down Expand Up @@ -59,6 +55,35 @@ chart_id,original_fhir_id,anonymized_fhir_id
2,DocumentReference/D899,DocumentReference/605338cd18c2617864db23fd5fd956f3e806af2021ffa6d11c34cac998eb3b6d
```

### `--labels`

Prints some statistics on the project labels and how often each annotator used each label.

#### Example

```shell
$ chart-review info --labels
╭───────────┬─────────────┬──────────╮
│ Annotator │ Chart Count │ Label │
├───────────┼─────────────┼──────────┤
│ Any │ 2 │ Cough │
│ Any │ 3 │ Fatigue │
│ Any │ 3 │ Headache │
├───────────┼─────────────┼──────────┤
│ jane │ 1 │ Cough │
│ jane │ 2 │ Fatigue │
│ jane │ 2 │ Headache │
├───────────┼─────────────┼──────────┤
│ jill │ 2 │ Cough │
│ jill │ 3 │ Fatigue │
│ jill │ 0 │ Headache │
├───────────┼─────────────┼──────────┤
│ john │ 1 │ Cough │
│ john │ 2 │ Fatigue │
│ john │ 2 │ Headache │
╰───────────┴─────────────┴──────────╯
```

### `--config=PATH`

Use this to point to a secondary (non-default) config file.
Expand Down
119 changes: 95 additions & 24 deletions tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,37 +29,29 @@ def test_info(self):
stdout = self.grab_output("--project-dir", f"{DATA_DIR}/cold")

self.assertEqual(
"""Annotations:
╭──────────┬─────────────┬──────────╮
│Annotator │ Chart Count │ Chart IDs│
├──────────┼─────────────┼──────────┤
│jane │ 3 │ 1, 3–4 │
│jill │ 4 │ 1–4 │
│john │ 3 │ 1–2, 4 │
╰──────────┴─────────────┴──────────╯
Labels:
Cough, Fatigue, Headache
""", # noqa: W291
"""╭───────────┬─────────────┬───────────╮
│ Annotator │ Chart Count │ Chart IDs │
├───────────┼─────────────┼───────────┤
│ jane │ 3 │ 1, 3–4 │
│ jill │ 4 │ 1–4 │
│ john │ 3 │ 1–2, 4 │
╰───────────┴─────────────┴───────────╯
""",
stdout,
)

def test_info_ignored(self):
stdout = self.grab_output("--project-dir", f"{DATA_DIR}/ignore")

self.assertEqual(
"""Annotations:
╭──────────┬─────────────┬──────────╮
│Annotator │ Chart Count │ Chart IDs│
├──────────┼─────────────┼──────────┤
│adam │ 2 │ 1–2 │
│allison │ 2 │ 1–2 │
╰──────────┴─────────────┴──────────╯
Ignoring 3 charts (3–5)
Labels:
A, B
""", # noqa: W291
"""╭───────────┬─────────────┬───────────╮
│ Annotator │ Chart Count │ Chart IDs │
├───────────┼─────────────┼───────────┤
│ adam │ 2 │ 1–2 │
│ allison │ 2 │ 1–2 │
╰───────────┴─────────────┴───────────╯
Ignoring 3 charts (3–5)
""",
stdout,
)

Expand Down Expand Up @@ -167,3 +159,82 @@ def test_ids_sources(self):
],
stdout.splitlines(),
)

def test_labels(self):
stdout = self.grab_output("--project-dir", f"{DATA_DIR}/cold", "--labels")

self.assertEqual(
"""╭───────────┬─────────────┬──────────╮
│ Annotator │ Chart Count │ Label │
├───────────┼─────────────┼──────────┤
│ Any │ 2 │ Cough │
│ Any │ 3 │ Fatigue │
│ Any │ 3 │ Headache │
├───────────┼─────────────┼──────────┤
│ jane │ 1 │ Cough │
│ jane │ 2 │ Fatigue │
│ jane │ 2 │ Headache │
├───────────┼─────────────┼──────────┤
│ jill │ 2 │ Cough │
│ jill │ 3 │ Fatigue │
│ jill │ 0 │ Headache │
├───────────┼─────────────┼──────────┤
│ john │ 1 │ Cough │
│ john │ 2 │ Fatigue │
│ john │ 2 │ Headache │
╰───────────┴─────────────┴──────────╯
""",
stdout,
)

def test_labels_grouped(self):
"""Verify that we only show final grouped labels, not intermediate ones"""
with tempfile.TemporaryDirectory() as tmpdir:
common.write_json(
f"{tmpdir}/config.json",
{
"labels": ["fever", "rash", "recent"],
"grouped-labels": {"symptoms": ["fever", "rash"]},
},
)
common.write_json(
f"{tmpdir}/labelstudio-export.json",
[],
)
stdout = self.grab_output("--labels", "--project-dir", tmpdir)

self.assertEqual(
"""╭───────────┬─────────────┬──────────╮
│ Annotator │ Chart Count │ Label │
├───────────┼─────────────┼──────────┤
│ Any │ 0 │ recent │
│ Any │ 0 │ symptoms │
╰───────────┴─────────────┴──────────╯
""",
stdout,
)

def test_labels_ignored(self):
"""Verify that we show info on ignored notes"""
with tempfile.TemporaryDirectory() as tmpdir:
common.write_json(
f"{tmpdir}/config.json",
{
"ignore": [3, 4, 6],
},
)
common.write_json(
f"{tmpdir}/labelstudio-export.json",
[
{"id": 3},
{"id": 4},
{"id": 5},
{"id": 6},
],
)
stdout = self.grab_output("--labels", "--project-dir", tmpdir)

self.assertEqual(
"Ignoring 3 charts (3–4, 6)",
stdout.splitlines()[-1].strip(),
)

0 comments on commit f69caea

Please sign in to comment.