treescope support and new visualizations #283

Merged
merged 13 commits into from
Aug 9, 2024
Conversation

@gsarti (Member) commented Aug 1, 2024

Description

The new treescope stand-alone package provides powerful utilities for visualizing objects and tensors in an interactive environment. It would be a perfect addition to improve the flexibility of visualizations for input attribution scores offered by Inseq.

Proposed changes

1. Interactive Objects Visualization

  • Enable treescope by default in notebook environments to improve the explorability of AttributionModel weights and FeatureAttributionOutput objects:
[Screenshots: treescope interactive view of AttributionModel weights and of a FeatureAttributionOutput object]

Notes

1.1. Define a __treescope_repr__ for FeatureAttributionOutput and FeatureAttributionSequenceOutput classes to make the visualization more informative:
- Add length info for source and target fields.
- Override the default array visualization for source_attributions and target_attributions with the one used in the new show_granular_attributions method (see below for details).
- Add highlighted tokens in the format of show_tokens_importance with hoverable scores to every step_score.
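The extra information from 1.1 could be assembled by a small helper before it is handed to treescope's renderer. A minimal sketch, with a hypothetical helper name and the treescope-specific `__treescope_repr__` wiring omitted:

```python
# Hypothetical helper collecting the summary info described in 1.1; the actual
# __treescope_repr__ hook that feeds this into treescope is not shown here.
def summarize_attribution(source_tokens, target_tokens, step_scores):
    """Collect length info for source/target and the available step scores."""
    return {
        "source": f"{len(source_tokens)} tokens",
        "target": f"{len(target_tokens)} tokens",
        "step_scores": sorted(step_scores),
    }
```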

1.2 Since we follow the common feature-attribution convention of red for positive and blue for negative scores, the default treescope colormap needs to be reversed:

import treescope

treescope.basic_interactive_setup()
# Reverse the 17-entry diverging palette so the blue end maps to
# negative scores and the red end to positive scores
cmap = list(reversed([
    (96, 14, 34),
    (134, 14, 41),
    (167, 36, 36),
    (186, 72, 46),
    (198, 107, 77),
    (208, 139, 115),
    (218, 171, 155),
    (228, 203, 196),
    (241, 236, 235),
    (202, 212, 216),
    (161, 190, 200),
    (117, 170, 190),
    (75, 148, 186),
    (38, 123, 186),
    (12, 94, 190),
    (41, 66, 162),
    (37, 47, 111),
]))
treescope.default_diverging_colormap.set_globally(cmap)
treescope.default_sequential_colormap.set_globally(cmap)
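As a sanity check on the orientation of the reversed palette, a score in [-1, 1] can be mapped onto a diverging colormap by nearest-bin lookup. This is a standalone sketch, independent of treescope's own interpolation:

```python
# Standalone sketch: pick the nearest colormap entry for a score in [-1, 1].
# With the reversed palette above, -1 lands on the blue end and +1 on the red end.
def score_to_color(score, cmap):
    score = max(-1.0, min(1.0, score))
    idx = round((score + 1.0) / 2.0 * (len(cmap) - 1))
    return cmap[idx]
```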

2. New show_granular method for granular attribution visualization

The new show_granular method of the FeatureAttributionSequenceOutput class would exploit NDArray visualizations from treescope to complement the current show method, enabling the visualization of N-dimensional attribution tensors without aggregation. By contrast, show is limited to outputs that can be visualized as 2D matrices, so it currently applies the default aggregator via out.aggregate().

Notes

2.1. Will require adding named dimensions to all FeatureAttributionSequenceOutput objects. The first two will be Input Tokens and Generated Tokens by default, but all the following ones will depend on the object, e.g. Embedding Dimension for saliency attribution (3D), Attention Head and Model Layer for attention attribution (4D).

2.2 To make the visualization more effective, every extra dimension beyond the third, as well as any dimension with more than maximum_size elements (e.g. 20), will be moved to a slider.
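The slider rule in 2.2 can be sketched as a pure function (helper and return names are assumptions, not the final Inseq API): the first two token axes stay on the grid, and the remaining axes go to sliders when they sit beyond the third axis or exceed maximum_size.

```python
# Hypothetical sketch of the 2.2 slider rule, not the final Inseq API.
# The first two axes (Input Tokens, Generated Tokens) always stay on the grid;
# any further axis goes to a slider if it is beyond the third axis overall
# or has more than maximum_size elements.
def split_axes(dim_names, shape, maximum_size=20):
    shown = list(dim_names[:2])
    sliders = []
    for axis in range(2, len(dim_names)):
        if axis >= 3 or shape[axis] > maximum_size:
            sliders.append(dim_names[axis])
        else:
            shown.append(dim_names[axis])
    return shown, sliders
```

For the 3D saliency case from the example, the 512-wide embedding axis exceeds maximum_size and is moved to a slider.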

Example: Visualizing dimension 41 of a saliency attribution tensor with 95 input tokens and 18 generated tokens

image

Example: Visualizing attention weights across all model layers for attention head 0 (input/generated token pair and numeric score shown on hover)

image

2.3 Enable slicing of dimensions with a slices dictionary (e.g. slices = {2: range(20)} to get the first 20 dimensions of the saliency tensor from 2.2). If the sliced dimension is shorter than maximum_size, it is visualized.
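The effect of a slices dictionary on the tensor shape can be sketched as follows (a hypothetical helper, not the Inseq API): each entry maps an axis index to the range kept along that axis, and a sliced axis that ends up no longer than maximum_size is visualized instead of placed on a slider.

```python
# Hypothetical sketch of the 2.3 `slices` behaviour, not the Inseq API.
# Returns the shape after slicing and, per sliced axis, whether the result
# is short enough (<= maximum_size) to be visualized directly.
def sliced_shape(shape, slices, maximum_size=20):
    new_shape = list(shape)
    visualized = {}
    for axis, keep in slices.items():
        new_shape[axis] = min(len(keep), shape[axis])
        visualized[axis] = new_shape[axis] <= maximum_size
    return tuple(new_shape), visualized
```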

Example: Saliency attribution tensor from 2.2 with 0:20 slicing over Embedding Dimension applied

image

3. New show_tokens method for text highlights

The new show_tokens method would make large attributions more legible by presenting scores as text highlights, exploiting the text_on_color method in treescope. Hoverable input tokens with highlights will be shown for every generation step, with color highlight matching the respective token attribution score for the generated token:

image

Notes

3.1 Parameters:
- replace_char: a dict mapping special characters to arbitrary replacement characters (default: no replacement), e.g. replace_char = {"Ġ": " ", "▁": " ", "Ċ": ""} to clean up special characters from GPT-2-style and SentencePiece tokenizers.
- wrap_after: if an int, the number of tokens after which a newline is inserted (e.g. 10 in the figure); if a string or an iterable, a newline is inserted whenever the current token is contained in it. Default: no wrapping (one line per attributed sequence).
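The two parameters can be sketched together as a pure token-formatting helper (hypothetical name and implementation, not the Inseq API):

```python
# Hypothetical sketch of the 3.1 parameters, not the Inseq API.
# replace_char rewrites special tokenizer characters; wrap_after breaks the
# token stream into lines, either every N tokens (int) or after any token
# contained in a string/iterable.
def format_tokens(tokens, replace_char=None, wrap_after=None):
    replace_char = replace_char or {}
    cleaned = ["".join(replace_char.get(ch, ch) for ch in tok) for tok in tokens]
    lines, line = [], []
    for i, tok in enumerate(cleaned, start=1):
        line.append(tok)
        if isinstance(wrap_after, int):
            if i % wrap_after == 0:
                lines.append(line)
                line = []
        elif wrap_after is not None and tok in wrap_after:
            lines.append(line)
            line = []
    if line:
        lines.append(line)
    return lines
```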

@gsarti (Member, Author) commented Aug 8, 2024

The current version of show_tokens was further improved to make it more concise, with a collapsible section for every generated token revealing its source/target attributions:

Example: clicking on villaggio reveals attributions showing high scores for village in the source for an MT model

image

Another argument named step_score_highlight can be used to highlight generated tokens according to a step score collected via the model.attribute method.

Example using the probability of generated tokens for highlighting:

import inseq

model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "saliency")

input_prompt = """ In a quiet village nestled between rolling hills, an ancient tree whispered secrets to those who listened."""

out = model.attribute(
    input_texts=input_prompt,
    step_scores=["probability", "entropy"],
    attribute_target=True,
)

out.show_tokens(step_score_highlight="probability")
image

@gsarti gsarti marked this pull request as ready for review August 8, 2024 14:20
@gsarti (Member, Author) commented Aug 8, 2024

Issue: treescope conflicts with rich.jupyter.JupyterRenderable for displaying attribute_context outputs in notebook environments. Awaiting updates in google-deepmind/treescope#18.

Ideally attribute_context visualizations would be entirely replaced by treescope with a visualization in the style of show_tokens. Should be handled on a separate PR once this is merged.

EDIT: Solved by temporarily unhooking treescope visualizer when showing attribute_context outputs.
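The temporary unhooking mentioned in the EDIT can follow a generic save/restore pattern. A minimal sketch; the specific attribute that Inseq swaps out is an implementation detail and is assumed here:

```python
import contextlib

# Generic save/restore sketch of the workaround described above: temporarily
# replace an attribute (e.g. a display hook) and restore it on exit, even if
# the body raises. The attribute Inseq actually swaps is an assumption.
@contextlib.contextmanager
def temporarily_unset(obj, attr, placeholder=None):
    saved = getattr(obj, attr)
    setattr(obj, attr, placeholder)
    try:
        yield
    finally:
        setattr(obj, attr, saved)
```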

@gsarti gsarti merged commit 904c893 into main Aug 9, 2024
3 checks passed
@gsarti gsarti deleted the treescope branch August 9, 2024 14:32