|
170 | 170 | "source": [
|
171 | 171 | "import inseq\n",
|
172 | 172 | "\n",
|
173 |
| - "# Load the model Helsinki-NLP/opus-mt-en-fr (6-layer encoder-decoder transformer) from the \n", |
| 173 | + "# Load the model Helsinki-NLP/opus-mt-en-fr (6-layer encoder-decoder transformer) from the\n", |
174 | 174 | "# Huggingface Hub and hook it with the Input X Gradient feature attribution method\n",
|
175 | 175 | "model = inseq.load_model(\"Helsinki-NLP/opus-mt-en-it\", \"input_x_gradient\")\n",
|
176 | 176 | "\n",
|
|
180 | 180 | "out = model.attribute(\n",
|
181 | 181 | " input_texts=\"Hello everyone, hope you're enjoying the tutorial!\",\n",
|
182 | 182 | " attribute_target=True,\n",
|
183 |
| - " step_scores=[\"probability\"]\n", |
| 183 | + " step_scores=[\"probability\"],\n", |
184 | 184 | ")\n",
|
185 | 185 | "# Visualize the attributions and step scores\n",
|
186 | 186 | "out.show()"
|
|
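Beyond the visualization, the probability step score requested above can also be read back programmatically from the returned object. A minimal sketch, assuming the step_scores dict attribute on sequence attributions matches the current inseq API:

    # The output holds one sequence attribution per input text; step scores
    # live in a dict keyed by the names passed to attribute().
    seq_attr = out[0]
    probs = seq_attr.step_scores["probability"]
    print(probs)  # one probability per generated target token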
349 | 349 | ],
|
350 | 350 | "source": [
|
351 | 351 | "out = model.attribute(\n",
|
352 |
| - " input_texts=\"Hello everyone, hope you're enjoying the tutorial!\",\n", |
353 |
| - " attribute_target=True,\n", |
354 |
| - " method=\"attention\"\n", |
| 352 | + " input_texts=\"Hello everyone, hope you're enjoying the tutorial!\", attribute_target=True, method=\"attention\"\n", |
355 | 353 | ")\n",
|
356 | 354 | "# out[0] is a shortcut for out.sequence_attributions[0]\n",
|
357 | 355 | "out[0].source_attributions.shape"
|
|
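For the attention method, the tensor inspected above has four axes, (source_len, target_len, n_layers, n_heads), rather than the two of an aggregated gradient map. A minimal sketch of collapsing it by hand, assuming that axis order and a torch tensor:

    attn = out[0].source_attributions
    # Averaging over the layer and head axes yields a single
    # (source_len, target_len) map comparable to gradient attributions.
    avg_map = attn.mean(dim=(2, 3))
    print(tuple(attn.shape), "->", tuple(avg_map.shape))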
535 | 533 | ],
|
536 | 534 | "source": [
|
537 | 535 | "# Gets the mean weights of the first three attention heads only, no normalization\n",
|
538 |
| - "# do_post_aggregation_checks=False is needed since the output has >2 dimensions and \n", |
| 536 | + "# do_post_aggregation_checks=False is needed since the output has >2 dimensions and\n", |
539 | 537 | "# could not be visualized\n",
|
540 | 538 | "aggregated_heads_seq_attr_out = out[0].aggregate(\n",
|
541 |
| - " \"mean\", select_idx=(0,3), normalize=False, do_post_aggregation_checks=False\n", |
| 539 | + " \"mean\", select_idx=(0, 3), normalize=False, do_post_aggregation_checks=False\n", |
542 | 540 | ")\n",
|
543 | 541 | "\n",
|
544 | 542 | "# (source_len, target_len, num_layers)\n",
|
|
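The head-aggregated result above still carries a layer axis, so one more reduction brings it down to two dimensions, at which point the skipped post-aggregation checks would pass. A minimal sketch on the raw tensor, assuming the (source_len, target_len, num_layers) layout noted in the comment:

    # Collapse the remaining layer axis; the resulting 2-D map is the kind
    # of output the visualization checks expect.
    layer_avg = aggregated_heads_seq_attr_out.source_attributions.mean(dim=-1)
    print(layer_avg.shape)  # (source_len, target_len)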
726 | 724 | " \"Domanda: Quanti studenti hanno partecipato alle LCL nel 2023?\"\n",
|
727 | 725 | ")\n",
|
728 | 726 | "\n",
|
729 |
| - "qa_model = inseq.load_model(\"it5/it5-base-question-answering\", \"input_x_gradient\")\n", |
| 727 | + "qa_model = inseq.load_model(\"it5/it5-base-question-answering\", \"input_x_gradient\")\n", |
730 | 728 | "out = qa_model.attribute(question, attribute_target=True, step_scores=[\"probability\"])\n",
|
731 | 729 | "\n",
|
732 | 730 | "# Aggregate only source tokens, leave target tokens as they are\n",
|
|
1097 | 1095 | " contrast_targets=\"Ho salutato la manager\",\n",
|
1098 | 1096 | " attribute_target=True,\n",
|
1099 | 1097 | " # We also visualize the score used as target using the same function as step score\n",
|
1100 |
| - " step_scores=[\"contrast_prob_diff\"]\n", |
| 1098 | + " step_scores=[\"contrast_prob_diff\"],\n", |
1101 | 1099 | ")\n",
|
1102 | 1100 | "\n",
|
1103 | 1101 | "# Weight attribution scores by the difference in probabilities\n",
|
|
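The weighting announced in the last comment can be expressed directly on the tensors. A minimal sketch, assuming the default aggregation yields a (source_len, target_len) map and that step_scores holds one contrast_prob_diff value per generated token:

    # aggregate() collapses the granular hidden dimension to a 2-D map.
    seq_attr = out[0].aggregate()
    diff = out[0].step_scores["contrast_prob_diff"]
    # Broadcasting scales every source row by the per-step probability gap,
    # emphasizing the steps where the two readings actually diverge.
    weighted = seq_attr.source_attributions * diff.unsqueeze(0)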
1212 | 1210 | ")\n",
|
1213 | 1211 | "\n",
|
1214 | 1212 | "source_without_context = \"Do you already know when you'll be back?\"\n",
|
1215 |
| - "source_with_context = \"Thank you for your help, my friend, you really saved my life. Do you already know when you'll be back?\"\n", |
| 1213 | + "source_with_context = (\n", |
| 1214 | + " \"Thank you for your help, my friend, you really saved my life. Do you already know when you'll be back?\"\n", |
| 1215 | + ")\n", |
1216 | 1216 | "\n",
|
1217 |
| - "print(\"Generation without context:\", model.generate(source_without_context, forced_bos_token_id=model.tokenizer.lang_code_to_id[\"it_IT\"]))\n", |
1218 |
| - "print(\"Generation with context:\", model.generate(source_with_context, forced_bos_token_id=model.tokenizer.lang_code_to_id[\"it_IT\"]))\n", |
| 1217 | + "print(\n", |
| 1218 | + " \"Generation without context:\",\n", |
| 1219 | + " model.generate(source_without_context, forced_bos_token_id=model.tokenizer.lang_code_to_id[\"it_IT\"]),\n", |
| 1220 | + ")\n", |
| 1221 | + "print(\n", |
| 1222 | + " \"Generation with context:\",\n", |
| 1223 | + " model.generate(source_with_context, forced_bos_token_id=model.tokenizer.lang_code_to_id[\"it_IT\"]),\n", |
| 1224 | + ")\n", |
1219 | 1225 | "\n",
|
1220 | 1226 | "out = model.attribute(\n",
|
1221 | 1227 | " source_without_context,\n",
|
|
1224 | 1230 | " contrast_targets=\"Grazie per il tuo aiuto, mi hai davvero salvato la vita. Sai già quando tornerai?\",\n",
|
1225 | 1231 | " attribute_target=True,\n",
|
1226 | 1232 | " # We also visualize the score used as target using the same function as step score\n",
|
1227 |
| - " step_scores=[\"pcxmi\", \"probability\"]\n", |
| 1233 | + " step_scores=[\"pcxmi\", \"probability\"],\n", |
1228 | 1234 | ")\n",
|
1229 | 1235 | "\n",
|
1230 | 1236 | "out.show()"
|
|
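Since P-CXMI measures how much the added context shifts each target token's probability, a simple threshold over the returned scores flags the context-sensitive positions. A minimal sketch; the 0.5 cutoff is an arbitrary illustration, and the .target / .token attribute names are assumptions about the inseq output classes:

    pcxmi = out[0].step_scores["pcxmi"]
    tokens = [t.token for t in out[0].target]
    # Keep tokens whose probability rose noticeably once context was added.
    sensitive = [(tok, round(score.item(), 3)) for tok, score in zip(tokens, pcxmi) if score > 0.5]
    print(sensitive)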
1336 | 1342 | ],
|
1337 | 1343 | "source": [
|
1338 | 1344 | "# Print tokens to get token indices\n",
|
1339 |
| - "print([(i, x) for i, x in enumerate(model.encode(mt_target, as_targets=True).input_tokens[0])])\n", |
1340 |
| - "print([(i, x) for i, x in enumerate(model.encode(pe_target, as_targets=True).input_tokens[0])])" |
| 1345 | + "print(list(enumerate(model.encode(mt_target, as_targets=True).input_tokens[0])))\n", |
| 1346 | + "print(list(enumerate(model.encode(pe_target, as_targets=True).input_tokens[0])))" |
1341 | 1347 | ]
|
1342 | 1348 | },
|
1343 | 1349 | {
|
|
1394 | 1400 | " attributed_fn=\"contrast_prob_diff\",\n",
|
1395 | 1401 | " step_scores=[\"contrast_prob_diff\"],\n",
|
1396 | 1402 | " contrast_targets=pe_target,\n",
|
1397 |
| - " contrast_targets_alignments=[(0,0), (1,1), (2,2), (3,4), (4,4), (5,5), (6,7), (7,9)],\n", |
| 1403 | + " contrast_targets_alignments=[(0, 0), (1, 1), (2, 2), (3, 4), (4, 4), (5, 5), (6, 7), (7, 9)],\n", |
1398 | 1404 | ")\n",
|
1399 | 1405 | "\n",
|
1400 | 1406 | "# Reasonable alignments\n",
|
|
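Each (i, j) pair above maps token i of the attributed target onto token j of the contrast target, which is how two differently tokenized hypotheses are compared step by step. A quick helper makes a candidate alignment easy to eyeball before passing it in; the function below is written ad hoc for this tutorial:

    def preview_alignment(a_tokens, b_tokens, alignments):
        # Print aligned token pairs side by side so mismatches stand out.
        for i, j in alignments:
            print(f"{i:>2} {a_tokens[i]:<12} <-> {j:>2} {b_tokens[j]}")

    mt_tokens = model.encode(mt_target, as_targets=True).input_tokens[0]
    pe_tokens = model.encode(pe_target, as_targets=True).input_tokens[0]
    preview_alignment(mt_tokens, pe_tokens, [(0, 0), (1, 1), (2, 2), (3, 4), (4, 4), (5, 5), (6, 7), (7, 9)])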
1504 | 1510 | "metadata": {},
|
1505 | 1511 | "outputs": [],
|
1506 | 1512 | "source": [
|
1507 |
| - "import inseq\n", |
1508 | 1513 | "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
1509 | 1514 | "\n",
|
| 1515 | + "import inseq\n", |
| 1516 | + "\n", |
1510 | 1517 | "# The model is loaded in 8-bit on available GPUs using the bitsandbytes library integrated in HF Transformers\n",
|
1511 | 1518 | "# This will make the model much smaller for inference purposes, but attributions are not guaranteed to match those\n",
|
1512 | 1519 | "# of the full-precision model.\n",
|
|
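The loading pattern referenced above hands a pre-quantized Transformers model to inseq. A minimal sketch, assuming bitsandbytes is installed, a CUDA device is available, and that inseq.load_model accepts a pre-loaded model plus a separate tokenizer argument; the model name is only an illustrative stand-in:

    from transformers import AutoModelForCausalLM

    import inseq

    model_name = "gpt2-large"  # illustrative stand-in for the tutorial's model
    # load_in_8bit quantizes the weights at load time via bitsandbytes: much
    # cheaper to run, but attributions may drift from the full-precision model.
    hf_model = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True, device_map="auto")
    inseq_model = inseq.load_model(hf_model, "attention", tokenizer=model_name)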
1930 | 1937 | }
|
1931 | 1938 | ],
|
1932 | 1939 | "source": [
|
1933 |
| - "from inseq import FeatureAttributionOutput\n", |
1934 | 1940 | "import pandas as pd\n",
|
1935 | 1941 | "\n",
|
| 1942 | + "from inseq import FeatureAttributionOutput\n", |
| 1943 | + "\n", |
1936 | 1944 | "scores = {}\n",
|
1937 | 1945 | "\n",
|
1938 | 1946 | "for layer_idx in range(48):\n",
|
1939 | 1947 | " curr_out = FeatureAttributionOutput.load(f\"../data/cat_outputs/layer_{layer_idx}.json\")\n",
|
1940 | 1948 | " out_dict = curr_out.get_scores_dicts(do_aggregation=False)[0]\n",
|
1941 |
| - " scores[layer_idx] = [score for score in out_dict[\"target_attributions\"][\"ĠParis\"].values()][:-1]\n", |
| 1949 | + " scores[layer_idx] = list(out_dict[\"target_attributions\"][\"ĠParis\"].values())[:-1]\n", |
1942 | 1950 | "\n",
|
1943 | 1951 | "prefix_tokens = list(out_dict[\"target_attributions\"][\"ĠParis\"].keys())\n",
|
1944 | 1952 | "attributions_df = pd.DataFrame(scores, index=prefix_tokens[:-1])\n",
|
|
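With the frame built above (one row per prefix token, one column per layer), pandas can locate the strongest contribution directly. A minimal sketch under that assumed layout:

    # Find the prefix token with the highest score anywhere, then the layer
    # at which that token peaks.
    peak_token = attributions_df.max(axis=1).idxmax()
    peak_layer = attributions_df.loc[peak_token].idxmax()
    print(f"Peak attribution: token {peak_token!r} at layer {peak_layer}")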
1989 | 1997 | "ax.set_xticks([0.5 + i for i in range(0, attributions_df.values.shape[1], 4)])\n",
|
1990 | 1998 | "ax.set_xticklabels(list(range(0, 48, 4)))\n",
|
1991 | 1999 | "ax.set_yticklabels(attributions_df.index)\n",
|
1992 |
| - "cb = plt.colorbar(h, ticks=[0, .15, .3, .45, .6, .75])\n", |
| 2000 | + "cb = plt.colorbar(h, ticks=[0, 0.15, 0.3, 0.45, 0.6, 0.75])\n", |
1993 | 2001 | "fig.suptitle(\"What activations are contributing to predicting 'Paris' over 'Rome'?\")\n",
|
1994 | 2002 | "plt.savefig(filename)\n",
|
1995 | 2003 | "plt.show()"
|
|