treescope visualization for attribute_context #284

Merged · 4 commits · Aug 14, 2024
30 changes: 10 additions & 20 deletions README.md
@@ -240,30 +240,19 @@ All commands support the full range of parameters available for `attribute`, att
<details>
<summary><code>inseq attribute-context</code> example</summary>

The following example uses a GPT-2 model to generate a continuation of <code>input_current_text</code>, and uses the additional context provided by <code>input_context_text</code> to estimate its influence on the generation. In this case, the output <code>"to the hospital. He said he was fine"</code> is produced, and the generation of token <code>hospital</code> is found to be dependent on context token <code>sick</code> according to the <code>contrast_prob_diff</code> step function.
The following example uses a small LM to generate a continuation of <code>input_current_text</code>, and uses the additional context provided by <code>input_context_text</code> to estimate its influence on the generation. In this case, the output <code>"to the hospital. He said he was fine"</code> is produced, and the generation of token <code>hospital</code> is found to be dependent on context token <code>sick</code> according to the <code>contrast_prob_diff</code> step function.

```bash
inseq attribute-context \
--model_name_or_path gpt2 \
--model_name_or_path HuggingFaceTB/SmolLM-135M \
--input_context_text "George was sick yesterday." \
--input_current_text "His colleagues asked him to come" \
--attributed_fn "contrast_prob_diff"
```

**Result:**

```
Context with [contextual cues] (std λ=1.00) followed by output sentence with {context-sensitive target spans} (std λ=1.00)
(CTI = "kl_divergence", CCI = "saliency" w/ "contrast_prob_diff" target)

Input context: George was sick yesterday.
Input current: His colleagues asked him to come
Output current: to the hospital. He said he was fine

#1.
Generated output (CTI > 0.428): to the {hospital}(0.548). He said he was fine
Input context (CCI > 0.460): George was [sick](0.516) yesterday.
```
<img src="https://raw.githubusercontent.com/inseq-team/inseq/main/docs/source/images/attribute_context_hospital_output.png" style="width:500px">
</details>
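The contrastive target used above can be pictured as a probability difference: how much more likely a token becomes when the extra context is present. The following is a minimal, self-contained sketch of that idea for illustration only, not Inseq's actual implementation, which operates on model output distributions:

```python
# Hypothetical sketch of the contrast_prob_diff idea: the attributed score is
# p(token | context + current input) - p(token | current input alone).
def contrast_prob_diff(p_with_context: float, p_without_context: float) -> float:
    # Positive scores mean the context made the token more probable.
    return p_with_context - p_without_context

# E.g. "hospital" becomes far more probable after "George was sick yesterday."
# (the probabilities here are made up for illustration):
score = contrast_prob_diff(0.75, 0.25)
```

In the example above, a large positive score for <code>hospital</code> is what flags it as context-sensitive.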

## Planned Development
@@ -280,7 +269,7 @@ Our vision for Inseq is to create a centralized, comprehensive and robust set of

## Citing Inseq

If you use Inseq in your research we suggest to include a mention to the specific release (e.g. v0.6.0) and we kindly ask you to cite our reference paper as:
If you use Inseq in your research, we suggest including a mention of the specific release (e.g. v0.6.0), and we kindly ask you to cite our reference paper as:

```bibtex
@inproceedings{sarti-etal-2023-inseq,
@@ -308,7 +297,7 @@ If you use Inseq in your research we suggest to include a mention to the specifi
Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below.

> [!TIP]
> Last update: June 2024. Please open a pull request to add your publication to the list.
> Last update: August 2024. Please open a pull request to add your publication to the list.

<details>
<summary><b>2023</b></summary>
@@ -318,7 +307,6 @@ Inseq has been used in various research projects. A list of known publications t
<li> <a href="https://aclanthology.org/2023.nlp4convai-1.1/">Response Generation in Longitudinal Dialogues: Which Knowledge Representation Helps?</a> (Mousavi et al., 2023) </li>
<li> <a href="https://openreview.net/forum?id=XTHfNGI3zT">Quantifying the Plausibility of Context Reliance in Neural Machine Translation</a> (Sarti et al., 2023)</li>
<li> <a href="https://aclanthology.org/2023.emnlp-main.243/">A Tale of Pronouns: Interpretability Informs Gender Bias Mitigation for Fairer Instruction-Tuned Machine Translation</a> (Attanasio et al., 2023)</li>
<li> <a href="https://arxiv.org/abs/2310.09820">Assessing the Reliability of Large Language Model Knowledge</a> (Wang et al., 2023)</li>
<li> <a href="https://aclanthology.org/2023.conll-1.18/">Attribution and Alignment: Effects of Local Context Repetition on Utterance Production and Comprehension in Dialogue</a> (Molnar et al., 2023)</li>
</ol>

@@ -327,13 +315,15 @@ Inseq has been used in various research projects. A list of known publications t
<details>
<summary><b>2024</b></summary>
<ol>
<li><a href="https://arxiv.org/abs/2401.12576">LLMCheckup: Conversational Examination of Large Language Models via Interpretability Tools</a> (Wang et al., 2024)</li>
<li> <a href="https://aclanthology.org/2024.naacl-long.46/">Assessing the Reliability of Large Language Model Knowledge</a> (Wang et al., 2024)</li>
<li><a href="https://aclanthology.org/2024.hcinlp-1.9">LLMCheckup: Conversational Examination of Large Language Models via Interpretability Tools</a> (Wang et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2402.00794">ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models</a> (Zhao et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2404.02421">Revisiting subword tokenization: A case study on affixal negation in large language models</a> (Truong et al., 2024)</li>
<li><a href="https://aclanthology.org/2024.naacl-long.284">Revisiting subword tokenization: A case study on affixal negation in large language models</a> (Truong et al., 2024)</li>
<li><a href="https://hal.science/hal-04581586">Exploring NMT Explainability for Translators Using NMT Visualising Tools</a> (Gonzalez-Saez et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2405.14899">DETAIL: Task DEmonsTration Attribution for Interpretable In-context Learning</a> (Zhou et al., 2024)</li>
<li><a href="https://openreview.net/forum?id=uILj5HPrag">DETAIL: Task DEmonsTration Attribution for Interpretable In-context Learning</a> (Zhou et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2406.06399">Should We Fine-Tune or RAG? Evaluating Different Techniques to Adapt LLMs for Dialogue</a> (Alghisi et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2406.13663">Model Internals-based Answer Attribution for Trustworthy Retrieval-Augmented Generation</a> (Qi, Sarti et al., 2024)</li>
<li><a href="https://link.springer.com/chapter/10.1007/978-3-031-63787-2_14">NoNE Found: Explaining the Output of Sequence-to-Sequence Models When No Named Entity Is Recognized</a> (dela Cruz et al., 2024)</li>
</ol>

</details>
7 changes: 6 additions & 1 deletion inseq/commands/attribute_context/attribute_context.py
@@ -58,6 +58,8 @@ def attribute_context(args: AttributeContextArgs) -> AttributeContextOutput:
model_kwargs=deepcopy(args.model_kwargs),
tokenizer_kwargs=deepcopy(args.tokenizer_kwargs),
)
if not isinstance(args.model_name_or_path, str):
args.model_name_or_path = model.model_name
return attribute_context_with_model(args, model)


@@ -167,6 +169,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
)
cci_kwargs = {}
contextless_output = None
contrast_token = None
if args.attributed_fn is not None and is_contrastive_step_function(args.attributed_fn):
if not model.is_encoder_decoder:
formatted_input_current_text = concat_with_sep(
@@ -191,6 +194,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
contextless_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder
)
tok_pos = -2 if model.is_encoder_decoder else -1
contrast_token = output_ctxless_tokens[tok_pos]
if args.attributed_fn == "kl_divergence" or output_ctx_tokens[tok_pos] == output_ctxless_tokens[tok_pos]:
cci_kwargs["contrast_force_inputs"] = True
bos_offset = int(model.is_encoder_decoder or output_ctx_tokens[0] == model.bos_token)
@@ -233,6 +237,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
cci_out = CCIOutput(
cti_idx=cti_idx,
cti_token=cti_tok,
contrast_token=contrast_token,
cti_score=cti_score,
contextual_output=contextual_output,
contextless_output=contextless_output,
@@ -241,7 +246,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
)
output.cci_scores.append(cci_out)
if args.show_viz or args.viz_path:
visualize_attribute_context(output, model, cti_threshold)
visualize_attribute_context(output, model, cti_threshold, args.show_viz, args.viz_path)
if not args.add_output_info:
output.info = None
if args.save_path:
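The hunk above records which token the contrastive comparison is made against. A reduced, standalone sketch of the position logic (simplified from the diff; the assumption mirrored here is that encoder-decoder outputs end with a closing special token, so the contrast token sits one position earlier):

```python
def pick_contrast_token(ctxless_tokens: list[str], is_encoder_decoder: bool) -> str:
    # Mirrors `tok_pos = -2 if model.is_encoder_decoder else -1` from the diff:
    # encoder-decoder generations end with a special token (e.g. </s>), so the
    # last content token is second-to-last; decoder-only outputs end on it directly.
    tok_pos = -2 if is_encoder_decoder else -1
    return ctxless_tokens[tok_pos]
```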
103 changes: 102 additions & 1 deletion inseq/commands/attribute_context/attribute_context_helpers.py
@@ -24,7 +24,8 @@ class CCIOutput:
cti_token: str
cti_score: float
contextual_output: str
contextless_output: str
contrast_token: str | None = None
contextless_output: str | None = None
input_context_scores: list[float] | None = None
output_context_scores: list[float] | None = None

@@ -34,6 +35,33 @@ def __repr__(self):
def to_dict(self) -> dict[str, Any]:
return dict(self.__dict__.items())

@property
def minimum(self) -> float:
scores = [0]
if self.input_context_scores:
scores.extend(self.input_context_scores)
if self.output_context_scores:
scores.extend(self.output_context_scores)
return min(scores)

@property
def maximum(self) -> float:
scores = [0]
if self.input_context_scores:
scores.extend(self.input_context_scores)
if self.output_context_scores:
scores.extend(self.output_context_scores)
return max(scores)

@property
def all_scores(self) -> list[float]:
scores = []
if self.input_context_scores:
scores.extend(self.input_context_scores)
if self.output_context_scores:
scores.extend(self.output_context_scores)
return scores
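Taken on their own, the new aggregation helpers can be exercised with a reduced stand-in for the dataclass. This sketch is for illustration only — `ScoreBag` is a hypothetical substitute for `CCIOutput`, which carries many more fields:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ScoreBag:  # hypothetical stand-in for CCIOutput's score fields
    input_context_scores: Optional[list] = None
    output_context_scores: Optional[list] = None

    @property
    def all_scores(self) -> list:
        # Flatten input- and output-side context scores into one list.
        scores = []
        if self.input_context_scores:
            scores.extend(self.input_context_scores)
        if self.output_context_scores:
            scores.extend(self.output_context_scores)
        return scores

    @property
    def minimum(self) -> float:
        # 0 is seeded in, as in the diff, so an empty bag has a defined minimum.
        return min([0, *self.all_scores])

    @property
    def maximum(self) -> float:
        return max([0, *self.all_scores])
```

Seeding the score list with 0 keeps `min`/`max` total even when no scores were recorded, which matters for the visualization's shared color scale.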


@dataclass
class AttributeContextOutput:
@@ -52,6 +80,13 @@ class AttributeContextOutput:
def __repr__(self):
return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})"

def __treescope_repr__(self, *args, **kwargs):
from inseq.commands.attribute_context.attribute_context_viz_helpers import (
visualize_attribute_context_treescope,
)

return visualize_attribute_context_treescope(self)

def to_dict(self) -> dict[str, Any]:
out_dict = {k: v for k, v in self.__dict__.items() if k not in ["cci_scores", "info"]}
out_dict["cci_scores"] = [cci_out.to_dict() for cci_out in self.cci_scores]
@@ -71,6 +106,72 @@ def from_dict(cls, out_dict: dict[str, Any]) -> "AttributeContextOutput":
out.info = AttributeContextArgs(**{k: v for k, v in out_dict["info"].items() if k in field_names})
return out

@property
def min_cti(self) -> float:
if self.cti_scores is None:
return -1
return min(self.cti_scores)

@property
def max_cti(self) -> float:
if self.cti_scores is None:
return -1
return max(self.cti_scores)

@property
def mean_cti(self) -> float:
if self.cti_scores is None:
return 0
return sum(self.cti_scores) / len(self.cti_scores)

@property
def std_cti(self) -> float:
if self.cti_scores is None:
return 0
return tensor(self.cti_scores).std().item()

@property
def min_cci(self) -> float:
if self.cci_scores is None:
return -1
return min(cci.minimum for cci in self.cci_scores)

@property
def max_cci(self) -> float:
if self.cci_scores is None:
return -1
return max(cci.maximum for cci in self.cci_scores)

@property
def cci_all_scores(self) -> list[float]:
if self.cci_scores is None:
return []
return [score for cci in self.cci_scores for score in cci.all_scores]

@property
def mean_cci(self) -> float:
if self.cci_scores is None:
return 0
return sum(self.cci_all_scores) / len(self.cci_all_scores)

@property
def std_cci(self) -> float:
if self.cci_scores is None:
return 0
return tensor(self.cci_all_scores).std().item()

@property
def input_context_scores(self) -> list[float] | None:
if self.cci_scores is None or self.cci_scores[0].input_context_scores is None:
return None
return [cci.input_context_scores for cci in self.cci_scores]

@property
def output_context_scores(self) -> list[float] | None:
if self.cci_scores is None or self.cci_scores[0].output_context_scores is None:
return None
return [cci.output_context_scores for cci in self.cci_scores]
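The `std_cti`/`std_cci` properties above call `tensor(...).std()`, which by default computes the Bessel-corrected sample standard deviation; the stdlib `statistics.stdev` yields the same quantity. A quick reference sketch with made-up scores:

```python
import statistics

# Sample (n-1 denominator) standard deviation over a flat list of scores,
# matching the default behavior of torch.tensor(scores).std().item().
scores = [0.1, 0.4, 0.7]
std = statistics.stdev(scores)
```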


def concat_with_sep(s1: str, s2: str, sep: str) -> str:
"""Adds separator between two strings if needed."""
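The body of `concat_with_sep` is folded out of the diff view. A plausible completion, under the assumption (from its docstring) that the separator is only inserted when not already present:

```python
def concat_with_sep(s1: str, s2: str, sep: str) -> str:
    """Adds separator between two strings if needed."""
    # Hypothetical body: only insert the separator when s1 does not already end with it.
    if not s1.endswith(sep):
        s1 = s1 + sep
    return s1 + s2
```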