treescope visualization for attribute_context (#284)
* Basic viz working

* Treescope viz working for attribute_context

* Added contrastive alternative

* Fix tests
gsarti authored Aug 14, 2024
1 parent 904c893 commit 9de007e
Showing 7 changed files with 428 additions and 52 deletions.
30 changes: 10 additions & 20 deletions README.md
@@ -240,30 +240,19 @@ All commands support the full range of parameters available for `attribute`, att
<details>
<summary><code>inseq attribute-context</code> example</summary>

The following example uses a GPT-2 model to generate a continuation of <code>input_current_text</code>, and uses the additional context provided by <code>input_context_text</code> to estimate its influence on the generation. In this case, the output <code>"to the hospital. He said he was fine"</code> is produced, and the generation of token <code>hospital</code> is found to be dependent on context token <code>sick</code> according to the <code>contrast_prob_diff</code> step function.
The following example uses a small LM to generate a continuation of <code>input_current_text</code>, and uses the additional context provided by <code>input_context_text</code> to estimate its influence on the generation. In this case, the output <code>"to the hospital. He said he was fine"</code> is produced, and the generation of token <code>hospital</code> is found to be dependent on context token <code>sick</code> according to the <code>contrast_prob_diff</code> step function.

```bash
inseq attribute-context \
--model_name_or_path gpt2 \
--model_name_or_path HuggingFaceTB/SmolLM-135M \
--input_context_text "George was sick yesterday." \
--input_current_text "His colleagues asked him to come" \
--attributed_fn "contrast_prob_diff"
```

**Result:**

```
Context with [contextual cues] (std λ=1.00) followed by output sentence with {context-sensitive target spans} (std λ=1.00)
(CTI = "kl_divergence", CCI = "saliency" w/ "contrast_prob_diff" target)
Input context: George was sick yesterday.
Input current: His colleagues asked him to come
Output current: to the hospital. He said he was fine
#1.
Generated output (CTI > 0.428): to the {hospital}(0.548). He said he was fine
Input context (CCI > 0.460): George was [sick](0.516) yesterday.
```
<img src="https://raw.githubusercontent.com/inseq-team/inseq/main/docs/source/images/attribute_context_hospital_output.png" style="width:500px">
</details>

## Planned Development
@@ -280,7 +269,7 @@ Our vision for Inseq is to create a centralized, comprehensive and robust set of

## Citing Inseq

If you use Inseq in your research we suggest to include a mention to the specific release (e.g. v0.6.0) and we kindly ask you to cite our reference paper as:
If you use Inseq in your research we suggest including a mention of the specific release (e.g. v0.6.0) and we kindly ask you to cite our reference paper as:

```bibtex
@inproceedings{sarti-etal-2023-inseq,
@@ -308,7 +297,7 @@ If you use Inseq in your research we suggest to include a mention to the specifi
Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below.

> [!TIP]
> Last update: June 2024. Please open a pull request to add your publication to the list.
> Last update: August 2024. Please open a pull request to add your publication to the list.
<details>
<summary><b>2023</b></summary>
@@ -318,7 +307,6 @@ Inseq has been used in various research projects. A list of known publications t
<li> <a href="https://aclanthology.org/2023.nlp4convai-1.1/">Response Generation in Longitudinal Dialogues: Which Knowledge Representation Helps?</a> (Mousavi et al., 2023) </li>
<li> <a href="https://openreview.net/forum?id=XTHfNGI3zT">Quantifying the Plausibility of Context Reliance in Neural Machine Translation</a> (Sarti et al., 2023)</li>
<li> <a href="https://aclanthology.org/2023.emnlp-main.243/">A Tale of Pronouns: Interpretability Informs Gender Bias Mitigation for Fairer Instruction-Tuned Machine Translation</a> (Attanasio et al., 2023)</li>
<li> <a href="https://arxiv.org/abs/2310.09820">Assessing the Reliability of Large Language Model Knowledge</a> (Wang et al., 2023)</li>
<li> <a href="https://aclanthology.org/2023.conll-1.18/">Attribution and Alignment: Effects of Local Context Repetition on Utterance Production and Comprehension in Dialogue</a> (Molnar et al., 2023)</li>
</ol>

@@ -327,13 +315,15 @@ Inseq has been used in various research projects. A list of known publications t
<details>
<summary><b>2024</b></summary>
<ol>
<li><a href="https://arxiv.org/abs/2401.12576">LLMCheckup: Conversational Examination of Large Language Models via Interpretability Tools</a> (Wang et al., 2024)</li>
<li> <a href="https://aclanthology.org/2024.naacl-long.46/">Assessing the Reliability of Large Language Model Knowledge</a> (Wang et al., 2024)</li>
<li><a href="https://aclanthology.org/2024.hcinlp-1.9">LLMCheckup: Conversational Examination of Large Language Models via Interpretability Tools</a> (Wang et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2402.00794">ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models</a> (Zhao et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2404.02421">Revisiting subword tokenization: A case study on affixal negation in large language models</a> (Truong et al., 2024)</li>
<li><a href="https://aclanthology.org/2024.naacl-long.284">Revisiting subword tokenization: A case study on affixal negation in large language models</a> (Truong et al., 2024)</li>
<li><a href="https://hal.science/hal-04581586">Exploring NMT Explainability for Translators Using NMT Visualising Tools</a> (Gonzalez-Saez et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2405.14899">DETAIL: Task DEmonsTration Attribution for Interpretable In-context Learning</a> (Zhou et al., 2024)</li>
<li><a href="https://openreview.net/forum?id=uILj5HPrag">DETAIL: Task DEmonsTration Attribution for Interpretable In-context Learning</a> (Zhou et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2406.06399">Should We Fine-Tune or RAG? Evaluating Different Techniques to Adapt LLMs for Dialogue</a> (Alghisi et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2406.13663">Model Internals-based Answer Attribution for Trustworthy Retrieval-Augmented Generation</a> (Qi, Sarti et al., 2024)</li>
<li><a href="https://link.springer.com/chapter/10.1007/978-3-031-63787-2_14">NoNE Found: Explaining the Output of Sequence-to-Sequence Models When No Named Entity Is Recognized</a> (dela Cruz et al., 2024)</li>
</ol>

</details>
7 changes: 6 additions & 1 deletion inseq/commands/attribute_context/attribute_context.py
@@ -58,6 +58,8 @@ def attribute_context(args: AttributeContextArgs) -> AttributeContextOutput:
model_kwargs=deepcopy(args.model_kwargs),
tokenizer_kwargs=deepcopy(args.tokenizer_kwargs),
)
if not isinstance(args.model_name_or_path, str):
args.model_name_or_path = model.model_name
return attribute_context_with_model(args, model)
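This guard keeps the serialized arguments self-describing when a caller passes a preloaded model object instead of a string identifier. A minimal sketch of the same pattern, assuming only that the model exposes a `model_name` attribute as `HuggingfaceModel` does here (the `Args` and `FakeModel` classes below are illustrative stand-ins, not inseq classes):

```python
from dataclasses import dataclass


@dataclass
class Args:
    """Stand-in for AttributeContextArgs; only the relevant field is kept."""
    model_name_or_path: object


@dataclass
class FakeModel:
    """Stand-in for a loaded model exposing its name."""
    model_name: str


def normalize_model_name(args: Args, model: FakeModel) -> Args:
    # Mirror the guard above: if the caller passed a preloaded model object
    # rather than a string, record the model's name so saved outputs still
    # identify which model produced them.
    if not isinstance(args.model_name_or_path, str):
        args.model_name_or_path = model.model_name
    return args
```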


@@ -167,6 +169,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
)
cci_kwargs = {}
contextless_output = None
contrast_token = None
if args.attributed_fn is not None and is_contrastive_step_function(args.attributed_fn):
if not model.is_encoder_decoder:
formatted_input_current_text = concat_with_sep(
@@ -191,6 +194,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
contextless_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder
)
tok_pos = -2 if model.is_encoder_decoder else -1
contrast_token = output_ctxless_tokens[tok_pos]
if args.attributed_fn == "kl_divergence" or output_ctx_tokens[tok_pos] == output_ctxless_tokens[tok_pos]:
cci_kwargs["contrast_force_inputs"] = True
bos_offset = int(model.is_encoder_decoder or output_ctx_tokens[0] == model.bos_token)
@@ -233,6 +237,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
cci_out = CCIOutput(
cti_idx=cti_idx,
cti_token=cti_tok,
contrast_token=contrast_token,
cti_score=cti_score,
contextual_output=contextual_output,
contextless_output=contextless_output,
Expand All @@ -241,7 +246,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
)
output.cci_scores.append(cci_out)
if args.show_viz or args.viz_path:
visualize_attribute_context(output, model, cti_threshold)
visualize_attribute_context(output, model, cti_threshold, args.show_viz, args.viz_path)
if not args.add_output_info:
output.info = None
if args.save_path:
103 changes: 102 additions & 1 deletion inseq/commands/attribute_context/attribute_context_helpers.py
@@ -24,7 +24,8 @@ class CCIOutput:
cti_token: str
cti_score: float
contextual_output: str
contextless_output: str
contrast_token: str | None = None
contextless_output: str | None = None
input_context_scores: list[float] | None = None
output_context_scores: list[float] | None = None

@@ -34,6 +35,33 @@ def __repr__(self):
def to_dict(self) -> dict[str, Any]:
return dict(self.__dict__.items())

@property
def minimum(self) -> float:
scores = [0]
if self.input_context_scores:
scores.extend(self.input_context_scores)
if self.output_context_scores:
scores.extend(self.output_context_scores)
return min(scores)

@property
def maximum(self) -> float:
scores = [0]
if self.input_context_scores:
scores.extend(self.input_context_scores)
if self.output_context_scores:
scores.extend(self.output_context_scores)
return max(scores)

@property
def all_scores(self) -> list[float]:
scores = []
if self.input_context_scores:
scores.extend(self.input_context_scores)
if self.output_context_scores:
scores.extend(self.output_context_scores)
return scores


@dataclass
class AttributeContextOutput:
@@ -52,6 +80,13 @@ class AttributeContextOutput:
def __repr__(self):
return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})"

def __treescope_repr__(self, *args, **kwargs):
from inseq.commands.attribute_context.attribute_context_viz_helpers import (
visualize_attribute_context_treescope,
)

return visualize_attribute_context_treescope(self)

def to_dict(self) -> dict[str, Any]:
out_dict = {k: v for k, v in self.__dict__.items() if k not in ["cci_scores", "info"]}
out_dict["cci_scores"] = [cci_out.to_dict() for cci_out in self.cci_scores]
@@ -71,6 +106,72 @@ def from_dict(cls, out_dict: dict[str, Any]) -> "AttributeContextOutput":
out.info = AttributeContextArgs(**{k: v for k, v in out_dict["info"].items() if k in field_names})
return out

@property
def min_cti(self) -> float:
if self.cti_scores is None:
return -1
return min(self.cti_scores)

@property
def max_cti(self) -> float:
if self.cti_scores is None:
return -1
return max(self.cti_scores)

@property
def mean_cti(self) -> float:
if self.cti_scores is None:
return 0
return sum(self.cti_scores) / len(self.cti_scores)

@property
def std_cti(self) -> float:
if self.cti_scores is None:
return 0
return tensor(self.cti_scores).std().item()

@property
def min_cci(self) -> float:
if self.cci_scores is None:
return -1
return min(cci.minimum for cci in self.cci_scores)

@property
def max_cci(self) -> float:
if self.cci_scores is None:
return -1
return max(cci.maximum for cci in self.cci_scores)

@property
def cci_all_scores(self) -> list[float]:
if self.cci_scores is None:
return []
return [score for cci in self.cci_scores for score in cci.all_scores]

@property
def mean_cci(self) -> float:
if self.cci_scores is None:
return 0
return sum(self.cci_all_scores) / len(self.cci_all_scores)

@property
def std_cci(self) -> float:
if self.cci_scores is None:
return 0
return tensor(self.cci_all_scores).std().item()
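The `std_cti` and `std_cci` properties rely on `torch.Tensor.std()`, whose default applies Bessel's correction (an unbiased sample standard deviation, `ddof=1`). A dependency-free sketch of the same statistic; the zero fallback for fewer than two scores is an assumption of this sketch, not inseq's behavior (torch would return `nan` there):

```python
import statistics


def sample_std(scores: list) -> float:
    """Unbiased sample standard deviation (Bessel's correction, ddof=1).

    torch.tensor(scores).std() uses the same default, so for two or more
    scores this mirrors the std_cti / std_cci computation without torch.
    """
    if len(scores) < 2:
        # Assumed fallback for this sketch only: torch yields nan here.
        return 0.0
    return statistics.stdev(scores)
```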

@property
def input_context_scores(self) -> list[list[float]] | None:
if self.cci_scores is None or self.cci_scores[0].input_context_scores is None:
return None
return [cci.input_context_scores for cci in self.cci_scores]

@property
def output_context_scores(self) -> list[list[float]] | None:
if self.cci_scores is None or self.cci_scores[0].output_context_scores is None:
return None
return [cci.output_context_scores for cci in self.cci_scores]


def concat_with_sep(s1: str, s2: str, sep: str) -> str:
"""Adds separator between two strings if needed."""
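The body of `concat_with_sep` is elided in this diff; a hypothetical reconstruction consistent with its docstring might look as follows (`concat_with_sep_sketch` is an illustrative name, and the real implementation may differ):

```python
def concat_with_sep_sketch(s1: str, s2: str, sep: str) -> str:
    """Adds separator between two strings if needed (sketch).

    Hypothetical reconstruction: append the separator only when s1 does
    not already end with it, then concatenate.
    """
    if not s1.endswith(sep):
        s1 = s1 + sep
    return s1 + s2
```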
