Fixed batched attribution
gsarti committed Dec 1, 2021
1 parent 8adf86a · commit 3ede9f8
Showing 7 changed files with 37 additions and 29 deletions.
inseq/attr/feat/feature_attribution.py (5 changes: 2 additions & 3 deletions)
@@ -217,7 +217,7 @@ def prepare(
depending on the number of inputs.
"""
if isinstance(sources, str) or isinstance(sources, list):
- sources: BatchEncoding = self.attribution_model.encode_texts(sources, return_baseline=True)
+ sources: BatchEncoding = self.attribution_model.encode(sources, return_baseline=True)
if isinstance(sources, BatchEncoding):
if self.is_layer_attribution:
embeds = BatchEmbedding(None, None)
@@ -228,7 +228,7 @@
)
sources = Batch(sources, embeds)
if isinstance(targets, str) or isinstance(targets, list):
- targets: BatchEncoding = self.attribution_model.encode_texts(
+ targets: BatchEncoding = self.attribution_model.encode(
targets,
as_targets=True,
prepend_bos_token=prepend_bos_token,
@@ -497,7 +497,6 @@ def format_attribute_args(
target_ids: TargetIdsTensor,
**kwargs,
) -> Dict[str, Any]:
logger.debug(f"batch: {batch},\ntarget_ids: {pretty_tensor(target_ids, lpad=4)}")
# For now only encoder attribution is supported
if self.is_layer_attribution:
inputs = batch.sources.input_ids
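For context, the change in this file is just the rename from encode_texts to encode inside the prepare step (the removed debug line reappears in attribute_step below). The following is a minimal, self-contained sketch of that dispatch pattern, not the actual inseq implementation: the BatchEncoding/Batch stand-ins and the toy model are hypothetical simplifications.

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class BatchEncoding:  # simplified stand-in for inseq's BatchEncoding
    input_ids: List[List[int]]
    baseline_ids: Optional[List[List[int]]] = None


@dataclass
class Batch:  # simplified stand-in for inseq's Batch
    encoding: BatchEncoding
    embedding: Optional[object] = None


class ToyModel:
    def encode(self, texts: Union[str, List[str]], return_baseline: bool = False) -> BatchEncoding:
        texts = [texts] if isinstance(texts, str) else texts
        ids = [[hash(tok) % 100 for tok in t.split()] for t in texts]
        return BatchEncoding(ids, [[0] * len(i) for i in ids] if return_baseline else None)


def prepare_sources(model: ToyModel, sources: Union[str, List[str], BatchEncoding]) -> Batch:
    # Raw strings (or lists of strings) go through the renamed `encode` entry point first
    if isinstance(sources, (str, list)):
        sources = model.encode(sources, return_baseline=True)
    # Already-encoded inputs are wrapped directly, mirroring the branch above
    return Batch(sources) if isinstance(sources, BatchEncoding) else sources


print(prepare_sources(ToyModel(), "hello world"))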
inseq/attr/feat/gradient_attribution.py (3 changes: 2 additions & 1 deletion)
@@ -85,6 +85,7 @@ def attribute_step(
of size `(batch_size)`, if the attribution step supports deltas and they are requested.
"""
attribute_args = self.format_attribute_args(batch, target_ids, **kwargs)
logger.debug(f"batch: {batch},\ntarget_ids: {pretty_tensor(target_ids, lpad=4)}")
attr = self.method.attribute(**attribute_args)
delta = None
if (
@@ -93,7 +94,7 @@
and self.method.has_convergence_delta()
):
attr, delta = attr
logger.debug(f"attributions prenorm: {pretty_tensor(attr)}\n")
logger.debug(f"attributions prenorm: {pretty_tensor(attr)}, summed: {attr.sum(dim=-1).squeeze(0)}\n")
attr = sum_normalize(attr, dim_sum=-1)
logger.debug(f"attributions: {pretty_tensor(attr)}\n" + "-" * 30)
return (attr, delta) if delta is not None else attr
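The new debug line logs the batch and the per-token sums of the raw attributions before normalization, which makes cancellation effects visible. A rough, self-contained sketch of the surrounding attribute_step flow (DummyMethod stands in for a Captum-style attribution method and is not part of inseq; normalization is omitted here and shown with torch_utils.py below):

import logging

import torch

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class DummyMethod:
    # Stand-in for a Captum-style attribution method that supports convergence deltas
    def has_convergence_delta(self) -> bool:
        return True

    def attribute(self, inputs, target, return_convergence_delta=False):
        attr = torch.randn(inputs.shape[0], inputs.shape[1], 8)
        return (attr, torch.zeros(inputs.shape[0])) if return_convergence_delta else attr


def attribute_step(method, batch, target_ids, **kwargs):
    logger.debug(f"batch: {batch.shape},\ntarget_ids: {target_ids}")
    attr = method.attribute(batch, target_ids, **kwargs)
    delta = None
    if kwargs.get("return_convergence_delta", False) and method.has_convergence_delta():
        attr, delta = attr
    # The extra debug output added by this commit: row sums before normalization
    logger.debug(f"attributions prenorm summed: {attr.sum(dim=-1).squeeze(0)}")
    # (sum_normalize from inseq/utils/torch_utils.py would be applied here before returning)
    return (attr, delta) if delta is not None else attr


attribute_step(DummyMethod(), torch.randint(0, 100, (1, 5)), torch.tensor([3]), return_convergence_delta=True)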
inseq/data/viz.py (17 changes: 9 additions & 8 deletions)
@@ -48,9 +48,10 @@ def show_attributions(
attributions: OneOrMoreFeatureAttributionSequenceOutputs,
min_val: Optional[int] = None,
max_val: Optional[int] = None,
+ display_html: bool = True,
return_html: Optional[bool] = False,
) -> Optional[str]:
- if not return_html:
+ if display_html:
try:
from IPython.core.display import HTML, display
except ImportError:
@@ -63,12 +64,12 @@
max_val = max(attribution.maximum for attribution in attributions)
html_out = ""
for i, attribution in enumerate(attributions):
- if not return_html:
- display(HTML(get_instance_html(i)))
- display(HTML(seq2seq_plots(attribution, min_val, max_val)))
- else:
- html_out += get_instance_html(i)
- html_out += seq2seq_plots(attribution, min_val, max_val)
+ curr_html = ""
+ curr_html += get_instance_html(i)
+ curr_html += seq2seq_plots(attribution, min_val, max_val)
+ if display_html:
+ display(HTML(curr_html))
+ html_out += curr_html
if return_html:
return html_out

@@ -124,7 +125,7 @@ def get_progress_bar(
return None
elif show and not pretty:
return tqdm(
- total=max([tgt_len for _, _, tgt_len in target_sentences]),
+ total=max(tgt_len for _, _, tgt_len in target_sentences),
desc=f"Attributing with {method_name}...",
)
elif show and pretty:
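With this change the display_html and return_html flags are independent, so the HTML can be shown inline, returned as a string, or both. A hedged usage sketch follows: the import path mirrors the file location above, while export_attributions_html and its arguments are our own illustration, not inseq API.

from inseq.data.viz import show_attributions


def export_attributions_html(attributions, path: str = "attributions.html") -> str:
    # Build the HTML without requiring IPython/display at all
    html = show_attributions(attributions, display_html=False, return_html=True)
    with open(path, "w") as f:
        f.write(html)
    return html


# In a notebook, both behaviours can be combined:
#   show_attributions(attributions)                                       -> display only (default)
#   show_attributions(attributions, display_html=True, return_html=True)  -> display and return the string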
inseq/models/attribution_model.py (12 changes: 7 additions & 5 deletions)
@@ -133,7 +133,7 @@ def attribute(
return []
texts, reference_texts = self.format_input_texts(texts, reference_texts)
if not reference_texts:
- texts = self.encode_texts(texts, return_baseline=True)
+ texts = self.encode(texts, return_baseline=True)
generation_args = kwargs.pop("generation_args", {})
reference_texts = self.generate(texts, return_generation_output=False, **generation_args)
logger.debug(f"reference_texts={reference_texts}")
@@ -155,8 +155,10 @@
)

def embed(self, inputs: Union[TextInput, IdsTensor], as_targets: bool = False):
- if isinstance(inputs, str) or (isinstance(inputs, list) and inputs[0] == str):
- batch = self.encode_texts(inputs, as_targets)
+ if isinstance(inputs, str) or (
+ isinstance(inputs, list) and len(inputs) > 0 and all([isinstance(x, str) for x in inputs])
+ ):
+ batch = self.encode(inputs, as_targets)
inputs = batch.input_ids
if as_targets:
return self.decoder_embed_ids(inputs)
@@ -170,14 +172,14 @@ def score_func(self, **kwargs) -> torch.Tensor:
@abstractmethod
def generate(
self,
- encodings: BatchEncoding,
+ encodings: Union[TextInput, BatchEncoding],
return_generation_output: Optional[bool] = False,
**kwargs,
) -> Union[List[str], Tuple[List[str], Any]]:
pass

@abstractmethod
- def encode_texts(self, texts: TextInput, as_targets: Optional[bool] = False, *args) -> BatchEncoding:
+ def encode(self, texts: TextInput, as_targets: Optional[bool] = False, *args) -> BatchEncoding:
pass

@abstractmethod
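The old guard in embed, `inputs[0] == str`, compared the first element against the str type itself and was therefore always False for real strings; the new guard checks that every element is a string. The corrected logic as a standalone helper (the name is_text_input is ours, not part of inseq):

def is_text_input(inputs) -> bool:
    # True for a single string or a non-empty list made entirely of strings
    return isinstance(inputs, str) or (
        isinstance(inputs, list) and len(inputs) > 0 and all(isinstance(x, str) for x in inputs)
    )


assert is_text_input("a single sentence")
assert is_text_input(["a batch", "of sentences"])
assert not is_text_input([])           # empty batches are not treated as text
assert not is_text_input([[0, 1, 2]])  # already-tokenized ids fall through to the tensor path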
inseq/models/huggingface_model.py (18 changes: 11 additions & 7 deletions)
@@ -139,7 +139,7 @@ def score_func(
@unhooked
def generate(
self,
- encodings: BatchEncoding,
+ encodings: Union[TextInput, BatchEncoding],
return_generation_output: Literal[False] = False,
**kwargs,
) -> List[str]:
@@ -149,7 +149,7 @@ def generate(
@unhooked
def generate(
self,
- encodings: BatchEncoding,
+ encodings: Union[TextInput, BatchEncoding],
return_generation_output: Literal[True],
**kwargs,
) -> Tuple[List[str], GenerationOutput]:
@@ -158,13 +158,17 @@ def generate(
@unhooked
def generate(
self,
- encodings: BatchEncoding,
+ inputs: Union[TextInput, BatchEncoding],
return_generation_output: Optional[bool] = False,
**kwargs,
) -> Union[List[str], Tuple[List[str], GenerationOutput]]:
+ if isinstance(inputs, str) or (
+ isinstance(inputs, list) and len(inputs) > 0 and all([isinstance(x, str) for x in inputs])
+ ):
+ inputs = self.encode(inputs)
generation_out = self.model.generate(
- input_ids=encodings.input_ids,
- attention_mask=encodings.attention_mask,
+ input_ids=inputs.input_ids,
+ attention_mask=inputs.attention_mask,
return_dict_in_generate=True,
**kwargs,
)
@@ -177,7 +181,7 @@ def generate(
return texts, generation_out
return texts

- def encode_texts(
+ def encode(
self,
texts: TextInput,
as_targets: Optional[bool] = False,
@@ -198,7 +202,7 @@ def encode_texts(
# Some tokenizer have weird values for max_len_single_sentence
# Cap length with max_model_input_sizes instead
if max_length > 1e6:
- max_length = max([v for _, v in self.tokenizer.max_model_input_sizes.items()])
+ max_length = max(v for _, v in self.tokenizer.max_model_input_sizes.items())
batch = self.tokenizer(
texts,
add_special_tokens=True,
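The HuggingFace generate wrapper now accepts raw text as well as a pre-built BatchEncoding, encoding internally when needed. A hedged usage sketch: `model` is assumed to be an already-loaded inseq HuggingFace attribution model, and the helper names below are our own illustration.

from typing import List


def translate_from_texts(model, texts: List[str], **generation_args) -> List[str]:
    # Raw strings: generate() now calls model.encode(texts) internally
    return model.generate(texts, **generation_args)


def translate_from_encodings(model, texts: List[str], **generation_args) -> List[str]:
    # The explicit two-step path is still supported
    encodings = model.encode(texts)
    return model.generate(encodings, **generation_args)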
inseq/utils/torch_utils.py (3 changes: 2 additions & 1 deletion)
@@ -25,9 +25,10 @@ def sum_normalize(
) -> AttributionOutputTensor:
"""
Sum and normalize tensor across dim_sum.
+ The outcome is a matrix of unit row vectors.
"""
attributions = attributions.sum(dim=dim_sum).squeeze(0)
- attributions = attributions / torch.norm(attributions)
+ attributions = attributions.T.div(torch.norm(attributions, dim=dim_sum)).T
if len(attributions.shape) == 1:
return attributions.unsqueeze(0)
return attributions
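This is the core of the batched-attribution fix: the previous code divided the whole (batch, target_len) matrix by a single global norm, so scores from different sentences in the same batch were scaled against each other; the new code rescales each row to unit L2 norm. A small numeric check (the tensor values are made up for illustration):

import torch

attr = torch.tensor([[3.0, 4.0],       # sentence 1
                     [30.0, 40.0]])    # sentence 2, with 10x larger raw scores

# Old behaviour: one global norm for the whole batch
old = attr / torch.norm(attr)
print(torch.norm(old, dim=-1))  # approx tensor([0.0995, 0.9950]) -> rows are not comparable

# New behaviour: each row becomes a unit vector
new = attr.T.div(torch.norm(attr, dim=-1)).T
print(torch.norm(new, dim=-1))  # tensor([1., 1.])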
inseq/utils/viz_utils.py (8 changes: 4 additions & 4 deletions)
@@ -35,8 +35,9 @@ def red_transparent_blue_colormap():
return LinearSegmentedColormap.from_list("red_transparent_blue", colors)


- def get_color(score, cmax, cmap):
- scaled_value = 0.5 + 0.5 * score / cmax
+ def get_color(score, min_value, max_value, cmap):
+ # Normalize between 0-1 for the color scale
+ scaled_value = (score - min_value) / (max_value - min_value)
color = cmap(scaled_value)
color = "rgba" + str((color[0] * 255, color[1] * 255, color[2] * 255, color[3]))
return color
@@ -53,11 +54,10 @@ def get_colors(
cmap,
):
input_colors = []
- cmax = max(abs(min_value), abs(max_value))
for row_index in range(scores.shape[0]):
input_colors_row = []
for col_index in range(scores.shape[1]):
- color = get_color(scores[row_index, col_index], cmax, cmap)
+ color = get_color(scores[row_index, col_index], min_value, max_value, cmap)
input_colors_row.append(color)
input_colors.append(input_colors_row)
return input_colors
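Color scaling now uses plain min-max normalization into the colormap's [0, 1] range instead of the previous symmetric 0.5 + 0.5 * score / cmax scheme, which is why get_colors no longer computes cmax. A self-contained demo of the updated mapping, using a built-in matplotlib colormap as a stand-in for red_transparent_blue_colormap:

from matplotlib import cm


def get_color(score, min_value, max_value, cmap):
    # Normalize between 0-1 for the color scale, as in the updated function
    # (assumes max_value > min_value, as guaranteed by the caller computing both)
    scaled_value = (score - min_value) / (max_value - min_value)
    color = cmap(scaled_value)
    return "rgba" + str((color[0] * 255, color[1] * 255, color[2] * 255, color[3]))


# A mildly negative score lands 40% of the way through the colormap
print(get_color(-0.2, min_value=-1.0, max_value=1.0, cmap=cm.coolwarm))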
