
Commit 53eedc8

Migrate to ruff format (#225)

1 parent 8d50c70 · commit 53eedc8

19 files changed: +1185 -1153 lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 10 deletions
@@ -18,17 +18,10 @@ repos:
       - id: end-of-file-fixer
         exclude: LICENSE
 
-  - repo: local
-    hooks:
-      - id: black
-        name: black
-        entry: poetry run black --config pyproject.toml
-        types: [python]
-        language: system
-
-  - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: 'v0.0.267'
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.1.2
     hooks:
+      - id: ruff-format
       - id: ruff
 
   - repo: local

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ make lint
 
 ### Checks
 
-Many checks are configured for this project. Command `make check-style` will check black and isort.
+Many checks are configured for this project. Command `make check-style` will check style with `ruff`.
 The `make check-safety` command will look at the security of your code.
 
 Comand `make lint` applies all checks.

Makefile

Lines changed: 4 additions & 4 deletions
@@ -78,15 +78,15 @@ update-deps:
 #* Linting
 .PHONY: check-style
 check-style:
-    poetry run black --diff --check --config pyproject.toml ./
-    poetry run ruff --no-fix --config pyproject.toml ./
+    poetry run ruff format --check --config pyproject.toml ./
+    poetry run ruff check --no-fix --config pyproject.toml ./
     # poetry run darglint --verbosity 2 inseq tests
     # poetry run mypy --config-file pyproject.toml ./
 
 .PHONY: fix-style
 fix-style:
-    poetry run black --config pyproject.toml ./
-    poetry run ruff --config pyproject.toml ./
+    poetry run ruff format --config pyproject.toml ./
+    poetry run ruff check --config pyproject.toml ./
 
 .PHONY: check-safety
 check-safety:

examples/inseq_tutorial.ipynb

Lines changed: 28 additions & 20 deletions
@@ -170,7 +170,7 @@
 "source": [
 "import inseq\n",
 "\n",
-"# Load the model Helsinki-NLP/opus-mt-en-fr (6-layer encoder-decoder transformer) from the \n",
+"# Load the model Helsinki-NLP/opus-mt-en-fr (6-layer encoder-decoder transformer) from the\n",
 "# Huggingface Hub and hook it with the Input X Gradient feature attribution method\n",
 "model = inseq.load_model(\"Helsinki-NLP/opus-mt-en-it\", \"input_x_gradient\")\n",
 "\n",
@@ -180,7 +180,7 @@
 "out = model.attribute(\n",
 "    input_texts=\"Hello everyone, hope you're enjoying the tutorial!\",\n",
 "    attribute_target=True,\n",
-"    step_scores=[\"probability\"]\n",
+"    step_scores=[\"probability\"],\n",
 ")\n",
 "# Visualize the attributions and step scores\n",
 "out.show()"
@@ -349,9 +349,7 @@
 ],
 "source": [
 "out = model.attribute(\n",
-"    input_texts=\"Hello everyone, hope you're enjoying the tutorial!\",\n",
-"    attribute_target=True,\n",
-"    method=\"attention\"\n",
+"    input_texts=\"Hello everyone, hope you're enjoying the tutorial!\", attribute_target=True, method=\"attention\"\n",
 ")\n",
 "# out[0] is a shortcut for out.sequence_attributions[0]\n",
 "out[0].source_attributions.shape"
@@ -535,10 +533,10 @@
 ],
 "source": [
 "# Gets the mean weights of the first three attention heads only, no normalization\n",
-"# do_post_aggregation_checks=False is needed since the output has >2 dimensions and \n",
+"# do_post_aggregation_checks=False is needed since the output has >2 dimensions and\n",
 "# could not be visualized\n",
 "aggregated_heads_seq_attr_out = out[0].aggregate(\n",
-"    \"mean\", select_idx=(0,3), normalize=False, do_post_aggregation_checks=False\n",
+"    \"mean\", select_idx=(0, 3), normalize=False, do_post_aggregation_checks=False\n",
 ")\n",
 "\n",
 "# (source_len, target_len, num_layers)\n",
@@ -726,7 +724,7 @@
 "    \"Domanda: Quanti studenti hanno partecipato alle LCL nel 2023?\"\n",
 ")\n",
 "\n",
-"qa_model = inseq.load_model(\"it5/it5-base-question-answering\", \"input_x_gradient\")\n",
+"qa_model = inseq.load_model(\"it5/it5-base-question-answering\", \"input_x_gradient\")\n",
 "out = qa_model.attribute(question, attribute_target=True, step_scores=[\"probability\"])\n",
 "\n",
 "# Aggregate only source tokens, leave target tokens as they are\n",
@@ -1097,7 +1095,7 @@
 "    contrast_targets=\"Ho salutato la manager\",\n",
 "    attribute_target=True,\n",
 "    # We also visualize the score used as target using the same function as step score\n",
-"    step_scores=[\"contrast_prob_diff\"]\n",
+"    step_scores=[\"contrast_prob_diff\"],\n",
 ")\n",
 "\n",
 "# Weight attribution scores by the difference in probabilities\n",
@@ -1212,10 +1210,18 @@
 ")\n",
 "\n",
 "source_without_context = \"Do you already know when you'll be back?\"\n",
-"source_with_context = \"Thank you for your help, my friend, you really saved my life. Do you already know when you'll be back?\"\n",
+"source_with_context = (\n",
+"    \"Thank you for your help, my friend, you really saved my life. Do you already know when you'll be back?\"\n",
+")\n",
 "\n",
-"print(\"Generation without context:\", model.generate(source_without_context, forced_bos_token_id=model.tokenizer.lang_code_to_id[\"it_IT\"]))\n",
-"print(\"Generation with context:\", model.generate(source_with_context, forced_bos_token_id=model.tokenizer.lang_code_to_id[\"it_IT\"]))\n",
+"print(\n",
+"    \"Generation without context:\",\n",
+"    model.generate(source_without_context, forced_bos_token_id=model.tokenizer.lang_code_to_id[\"it_IT\"]),\n",
+")\n",
+"print(\n",
+"    \"Generation with context:\",\n",
+"    model.generate(source_with_context, forced_bos_token_id=model.tokenizer.lang_code_to_id[\"it_IT\"]),\n",
+")\n",
 "\n",
 "out = model.attribute(\n",
 "    source_without_context,\n",
@@ -1224,7 +1230,7 @@
 "    contrast_targets=\"Grazie per il tuo aiuto, mi hai davvero salvato la vita. Sai già quando tornerai?\",\n",
 "    attribute_target=True,\n",
 "    # We also visualize the score used as target using the same function as step score\n",
-"    step_scores=[\"pcxmi\", \"probability\"]\n",
+"    step_scores=[\"pcxmi\", \"probability\"],\n",
 ")\n",
 "\n",
 "out.show()"
@@ -1336,8 +1342,8 @@
 ],
 "source": [
 "# Print tokens to get token indices\n",
-"print([(i, x) for i, x in enumerate(model.encode(mt_target, as_targets=True).input_tokens[0])])\n",
-"print([(i, x) for i, x in enumerate(model.encode(pe_target, as_targets=True).input_tokens[0])])"
+"print(list(enumerate(model.encode(mt_target, as_targets=True).input_tokens[0])))\n",
+"print(list(enumerate(model.encode(pe_target, as_targets=True).input_tokens[0])))"
 ]
 },
 {
@@ -1394,7 +1400,7 @@
 "    attributed_fn=\"contrast_prob_diff\",\n",
 "    step_scores=[\"contrast_prob_diff\"],\n",
 "    contrast_targets=pe_target,\n",
-"    contrast_targets_alignments=[(0,0), (1,1), (2,2), (3,4), (4,4), (5,5), (6,7), (7,9)],\n",
+"    contrast_targets_alignments=[(0, 0), (1, 1), (2, 2), (3, 4), (4, 4), (5, 5), (6, 7), (7, 9)],\n",
 ")\n",
 "\n",
 "# Reasonable alignments\n",
@@ -1504,9 +1510,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import inseq\n",
 "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
 "\n",
+"import inseq\n",
+"\n",
 "# The model is loaded in 8-bit on available GPUs using the bitsandbytes library integrated in HF Transformers\n",
 "# This will make the model much smaller for inference purposes, but attributions are not guaranteed to match those\n",
 "# of the full-precision model.\n",
@@ -1930,15 +1937,16 @@
 }
 ],
 "source": [
-"from inseq import FeatureAttributionOutput\n",
 "import pandas as pd\n",
 "\n",
+"from inseq import FeatureAttributionOutput\n",
+"\n",
 "scores = {}\n",
 "\n",
 "for layer_idx in range(48):\n",
 "    curr_out = FeatureAttributionOutput.load(f\"../data/cat_outputs/layer_{layer_idx}.json\")\n",
 "    out_dict = curr_out.get_scores_dicts(do_aggregation=False)[0]\n",
-"    scores[layer_idx] = [score for score in out_dict[\"target_attributions\"][\"ĠParis\"].values()][:-1]\n",
+"    scores[layer_idx] = list(out_dict[\"target_attributions\"][\"ĠParis\"].values())[:-1]\n",
 "\n",
 "prefix_tokens = list(out_dict[\"target_attributions\"][\"ĠParis\"].keys())\n",
 "attributions_df = pd.DataFrame(scores, index=prefix_tokens[:-1])\n",
@@ -1989,7 +1997,7 @@
 "ax.set_xticks([0.5 + i for i in range(0, attributions_df.values.shape[1], 4)])\n",
 "ax.set_xticklabels(list(range(0, 48, 4)))\n",
 "ax.set_yticklabels(attributions_df.index)\n",
-"cb = plt.colorbar(h, ticks=[0, .15, .3, .45, .6, .75])\n",
+"cb = plt.colorbar(h, ticks=[0, 0.15, 0.3, 0.45, 0.6, 0.75])\n",
 "fig.suptitle(\"What activations are contributing to predicting 'Paris' over 'Rome'?\")\n",
 "plt.savefig(filename)\n",
 "plt.show()"

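The notebook hunks above are typical of the black-compatible style that `ruff format` enforces: trailing commas added after the last keyword argument, over-long calls wrapped with one argument per line, spaces inside tuples, and float literals written with a leading zero. A minimal self-contained sketch of the trailing-comma behaviour (the toy `attribute` function below is illustrative only, not the inseq API):

from typing import Optional

def attribute(input_texts: str, attribute_target: bool = False, step_scores: Optional[list] = None) -> dict:
    """Toy stand-in used only to illustrate formatting; not the inseq API."""
    return {"texts": input_texts, "target": attribute_target, "scores": step_scores or []}

# Without a magic trailing comma, a short call can be collapsed onto one line:
out = attribute("Hello everyone!", attribute_target=True)

# A trailing comma after the last argument keeps one argument per line,
# which is why the hunks above add commas like `step_scores=["probability"],`:
out = attribute(
    "Hello everyone!",
    attribute_target=True,
    step_scores=["probability"],
)
print(out)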
inseq/attr/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -2,7 +2,6 @@
 from .step_functions import (
     STEP_SCORES_MAP,
     StepFunctionArgs,
-    StepFunctionEncoderDecoderArgs,
     list_step_functions,
     register_step_function,
 )
@@ -15,5 +14,4 @@
     "STEP_SCORES_MAP",
     "extract_args",
     "StepFunctionArgs",
-    "StepFunctionEncoderDecoderArgs",
 ]

inseq/attr/feat/attribution_utils.py

Lines changed: 4 additions & 6 deletions
@@ -19,7 +19,6 @@
 from ...models import AttributionModel
 from .feature_attribution import FeatureAttribution
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -43,7 +42,7 @@ def rescale_attributions_to_tokens(
     attributions: OneOrMoreAttributionSequences, tokens: OneOrMoreTokenSequences
 ) -> OneOrMoreAttributionSequences:
     return [
-        attr[: len(tokens)] if not all([math.isnan(x) for x in attr]) else []
+        attr[: len(tokens)] if not all(math.isnan(x) for x in attr) else []
         for attr, tokens in zip(attributions, tokens)
     ]
 
@@ -154,8 +153,7 @@ def get_source_target_attributions(
             return attr[0], None
         else:
            return attr, None
+    elif isinstance(attr, tuple):
+        return None, attr[0]
     else:
-        if isinstance(attr, tuple):
-            return None, attr[0]
-        else:
-            return None, attr
+        return None, attr

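Two of ruff's lint fixes show up in this file: the list comprehension inside `all()` becomes a generator expression (so evaluation can short-circuit), and the `else:` followed by a nested `if` is flattened into `elif`. A minimal self-contained sketch of both patterns (illustrative code, not the inseq source):

import math

values = [0.2, float("nan"), 0.5]

# Before: builds an intermediate list, then checks every element.
all_nan_list = all([math.isnan(x) for x in values])

# After: a generator expression lets all() stop at the first non-NaN value.
all_nan_gen = all(math.isnan(x) for x in values)
assert all_nan_list == all_nan_gen

def pick(attr, is_encoder_decoder):
    """Illustrative only: collapsing `else: if ...` into `elif`, as in get_source_target_attributions."""
    if is_encoder_decoder:
        return attr, None
    elif isinstance(attr, tuple):  # was: else: / if isinstance(attr, tuple): ...
        return None, attr[0]
    else:
        return None, attr

print(pick((1, 2), is_encoder_decoder=False))  # (None, 1)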
inseq/attr/feat/internals_attribution.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
 
 class InternalsAttributionRegistry(FeatureAttribution, Registry):
     r"""Model Internals-based attribution method registry."""
+
     pass
 
 
inseq/attr/feat/ops/lime.py

Lines changed: 6 additions & 10 deletions
@@ -1,25 +1,21 @@
 import inspect
+import logging
 import math
-import warnings
 from functools import partial
 from typing import Any, Callable, Optional, cast
 
 import torch
-from captum._utils.common import (
-    _expand_additional_forward_args,
-    _expand_target,
-)
+from captum._utils.common import _expand_additional_forward_args, _expand_target
 from captum._utils.models.linear_model import SkLearnLinearModel
 from captum._utils.models.model import Model
 from captum._utils.progress import progress
-from captum._utils.typing import (
-    TargetType,
-    TensorOrTupleOfTensorsGeneric,
-)
+from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr import LimeBase
 from torch import Tensor
 from torch.utils.data import DataLoader, TensorDataset
 
+logger = logging.getLogger(__name__)
+
 
 class Lime(LimeBase):
     def __init__(
@@ -135,7 +131,7 @@ def attribute(
                 try:
                     curr_sample = next(perturb_generator)
                 except StopIteration:
-                    warnings.warn("Generator completed prior to given n_samples iterations!")
+                    logger.warning("Generator completed prior to given n_samples iterations!")
                     break
             else:
                 curr_sample = self.perturb_func(inputs, **kwargs)

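The lime.py hunk replaces `warnings.warn` with a module-level logger, matching the convention used elsewhere in the package (each module defines `logger = logging.getLogger(__name__)`). A minimal sketch of the pattern, with an illustrative sampling loop rather than the real Lime internals:

import logging

# Module-level logger named after the module, as in inseq/attr/feat/ops/lime.py.
logger = logging.getLogger(__name__)

def draw_samples(generator, n_samples):
    """Illustrative only: stop early and log a warning if the generator runs dry."""
    samples = []
    for _ in range(n_samples):
        try:
            samples.append(next(generator))
        except StopIteration:
            logger.warning("Generator completed prior to given n_samples iterations!")
            break
    return samples

logging.basicConfig(level=logging.WARNING)
print(draw_samples(iter([1, 2]), n_samples=5))  # logs a warning, returns [1, 2]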
inseq/attr/feat/ops/sequential_integrated_gradients.py

Lines changed: 4 additions & 1 deletion
@@ -151,7 +151,10 @@ def attribute(  # type: ignore
         method: str = "gausslegendre",
         internal_batch_size: Union[None, int] = None,
         return_convergence_delta: bool = False,
-    ) -> Union[TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor],]:
+    ) -> Union[
+        TensorOrTupleOfTensorsGeneric,
+        Tuple[TensorOrTupleOfTensorsGeneric, Tensor],
+    ]:
         r"""
         This method attributes the output of the model with given target index
         (in case it is provided, otherwise it assumes that output is a

inseq/data/data_utils.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def _eq(self_attr: TensorClass, other_attr: TensorClass) -> bool:
             if isinstance(self_attr, torch.Tensor):
                 return torch.allclose(self_attr, other_attr, equal_nan=True)
             elif isinstance(self_attr, dict):
-                return all([TensorWrapper._eq(self_attr[k], other_attr[k]) for k in self_attr.keys()])
+                return all(TensorWrapper._eq(self_attr[k], other_attr[k]) for k in self_attr.keys())
             else:
                 return self_attr == other_attr
         except:  # noqa: E722

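The `_eq` helper touched here compares nested values NaN-tolerantly via `torch.allclose(..., equal_nan=True)` and, after the fix, feeds `all()` a generator so dictionary entries are compared lazily. A small self-contained sketch of that comparison logic (an illustrative helper, not the actual TensorWrapper):

import torch

def tensors_equal(a, b) -> bool:
    """Illustrative stand-in for TensorWrapper._eq: NaN-tolerant and recursive over dicts."""
    if isinstance(a, torch.Tensor):
        return torch.allclose(a, b, equal_nan=True)
    elif isinstance(a, dict):
        # Generator expression: comparison stops at the first mismatching key.
        return all(tensors_equal(a[k], b[k]) for k in a.keys())
    else:
        return a == b

x = {"attr": torch.tensor([1.0, float("nan")])}
y = {"attr": torch.tensor([1.0, float("nan")])}
print(tensors_equal(x, y))  # True, because equal_nan=True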