[P1] Adding new example for flexible model steering
frankaging committed Jan 3, 2025
1 parent a1a8947 commit 0d3f1f8
Showing 1 changed file with 186 additions and 0 deletions.
186 changes: 186 additions & 0 deletions pyvene_101.ipynb
@@ -53,6 +53,7 @@
" 1. [Intervene on Recurrent NNs](#Recurrent-NNs-(Intervene-a-Specific-Timestep))\n",
" 1. [Intervene across Times with RNNs](#Recurrent-NNs-(Intervene-cross-Time))\n",
" 1. [Intervene on LM Generation](#LMs-Generation)\n",
" 1. [Advanced Intervention on LM Generation (Model Steering)](#Advanced-Intervention-on-LMs-Generation-(Model-Steering))\n",
" 1. [Debiasing with Backpack LMs](#Debiasing-with-Backpack-LMs)\n",
" 1. [Saving and Loading](#Saving-and-Loading)\n",
" 1. [Multi-Source Intervention (Parallel)](#Multi-Source-Interchange-Intervention-(Parallel-Mode))\n",
@@ -1185,6 +1186,191 @@
"))"
]
},
{
"cell_type": "markdown",
"id": "0b89244e-fbc7-4515-b22c-83fae00224cb",
"metadata": {},
"source": [
"### Advanced Intervention on LMs Generation (Model Steering)\n",
"\n",
    "We also support model steering with interventions during generation. You can intervene on prompt tokens, on individual decoding steps, or write customized interventions for more advanced control.\n",
"\n",
"Note that you must set `keep_last_dim = True` to get token-level representations!"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "43422e38-d930-4354-9dc5-191e2abcf928",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c046df6ad83d4f6381730fc940f7b866",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7ec60913371647fc85e602b189a5c50f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting happy vector ...\n"
]
}
],
"source": [
"import pyvene as pv\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\"google/gemma-2-2b-it\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2-2b-it\")\n",
"\n",
"print(\"Extracting happy vector ...\")\n",
"happy_id = tokenizer(\"happy\")['input_ids'][-1]\n",
"happy_vector = model.model.embed_tokens.weight[happy_id].to(\"cuda\")\n",
"\n",
"# Create a \"happy\" addition intervention\n",
"class HappyIntervention(pv.ConstantSourceIntervention):\n",
" def __init__(self, **kwargs):\n",
" super().__init__(\n",
" **kwargs, \n",
    "            keep_last_dim=True) # you must set keep_last_dim=True to get token-level reprs.\n",
" self.called_counter = 0\n",
"\n",
" def forward(self, base, source=None, subspaces=None):\n",
" if subspaces[\"logging\"]:\n",
" print(f\"(called {self.called_counter} times) incoming reprs shape:\", base.shape)\n",
" self.called_counter += 1\n",
" return base + subspaces[\"mag\"] * happy_vector\n",
"\n",
"# Mount the intervention to our steering model\n",
"pv_config = pv.IntervenableConfig(representations=[{\n",
" \"layer\": 20,\n",
    "        \"component\": \"model.layers[20].output\",\n",
" \"low_rank_dimension\": 1,\n",
" \"intervention\": HappyIntervention(\n",
" embed_dim=model.config.hidden_size, \n",
" low_rank_dimension=1)}])\n",
"pv_model = pv.IntervenableModel(pv_config, model)\n",
"pv_model.set_device(\"cuda\")"
]
},
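The cell above adds `mag * happy_vector` to every token representation at layer 20. The core arithmetic is just an elementwise additive shift of each hidden state; a minimal framework-free sketch (plain Python lists standing in for tensors, with a hypothetical `steer` helper not part of pyvene):

```python
def steer(hidden_states, vector, mag):
    """Add mag * vector to every token's hidden state.

    hidden_states: list of token reprs, each a list of floats.
    vector: steering direction with the same width as each repr.
    """
    return [[h + mag * v for h, v in zip(tok, vector)] for tok in hidden_states]

out = steer([[0.0, 1.0], [1.0, 1.0]], [1.0, -1.0], 2.0)
# → [[2.0, -1.0], [3.0, -1.0]]
```

In the notebook, `HappyIntervention.forward` performs the same shift in one broadcasted tensor operation, scaled by the `mag` value passed through `subspaces`.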
{
"cell_type": "code",
"execution_count": 18,
"id": "dc70ebae-793a-4b2b-a3e3-a2118cc66e1e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(called 0 times) incoming reprs shape: torch.Size([1, 17, 2304])\n",
"(called 1 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
"(called 2 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
"(called 3 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
"(called 4 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
"(called 5 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
"(called 6 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
"(called 7 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
"(called 8 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
"(called 9 times) incoming reprs shape: torch.Size([1, 1, 2304])\n"
]
}
],
"source": [
"prompt = \"Write a story for me about dragon.\"\n",
"\n",
"prompt = tokenizer.decode(tokenizer.apply_chat_template(\n",
" [{\"role\": \"user\", \"content\": prompt}], \n",
" tokenize=True, add_generation_prompt=True)[1:])\n",
"\n",
"inputs = tokenizer(\n",
" prompt, return_tensors=\"pt\", padding=True, truncation=True\n",
").to(\"cuda\")\n",
"_, generations = pv_model.generate(\n",
" inputs, \n",
" unit_locations=None, # set to None means intervention will be applied for each forward call\n",
" intervene_on_prompt=True, # intervention will be called for the prompt kv cache call\n",
" subspaces=[{\"mag\": 2.0, \"logging\": True}], # other metadata\n",
" max_new_tokens=10, do_sample=True, temperature=1.0)"
]
},
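The logged shapes follow directly from how KV-cached generation works: with `intervene_on_prompt=True` and `max_new_tokens=10`, the intervention fires 10 times — once on the full 17-token prompt pass (which yields the first new token), then once per single-token decoding step. A small sketch of that call pattern (the function name and the assumption of one call per forward pass are illustrative, not pyvene API):

```python
def intervention_call_shapes(prompt_len, max_new_tokens, hidden=2304):
    """Shapes the intervention sees: one prompt-sized call, then
    (max_new_tokens - 1) single-token calls."""
    shapes = [(1, prompt_len, hidden)]
    shapes += [(1, 1, hidden)] * (max_new_tokens - 1)
    return shapes

shapes = intervention_call_shapes(17, 10)
# → 10 calls: (1, 17, 2304) once, then (1, 1, 2304) nine times
```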
{
"cell_type": "markdown",
"id": "afa622a1-5d93-43e6-afc2-436bd166299d",
"metadata": {},
"source": [
    "**Here you go!** You can see that the intervention is called 10 times: on the first call it receives the full prompt representations and produces the first new token, and each of the remaining nine calls receives the single-token representation of one decoding step. \n",
"\n",
"Now, let's generate some steering results."
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "4882c229-8529-4a16-bec0-773e77217d27",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"Anya's grandmother, Elara, had always said happy souls were the happiest kind – even happy dragons. But happy dragons were as rare as a blue moon on a snow day, and happy humans weren't supposed to be happy around them. Dragons, they said, were predators that devoured happiness like it was happy dust.\\n\\nAnya knew better. Happy was a whisper, a happy breeze on the sun-dappled meadow. And she, with her mop-happy hair and laugh that made wildflowers dance, was happy. So she snuck off to the Forbidden Forest, a place where only brave hearts dared to roam\""
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_, generations = pv_model.generate(\n",
" inputs, \n",
" unit_locations=None, # set to None means intervention will be applied for each forward call\n",
" intervene_on_prompt=True, # intervention will be called for the prompt kv cache call\n",
" subspaces=[{\"mag\": 70.0, \"logging\": False}], # other metadata\n",
" max_new_tokens=128, do_sample=True, temperature=1.0)\n",
"\n",
"tokenizer.decode(generations[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)"
]
},
{
"cell_type": "markdown",
"id": "3cc17327-ea2d-449f-9f11-e94435b1e734",
"metadata": {},
"source": [
    "Great! This is your super-happy model. Following this pattern, you can also write customized interventions that act only on selected decoding steps by passing metadata through `subspaces`."
]
},
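One way to realize step-selective steering, hinted at above, is to reuse the `called_counter` idea from `HappyIntervention` and gate the shift on a set of step indices passed via `subspaces`. A framework-free sketch under those assumptions (the class and the `"steps"` key are illustrative, not pyvene built-ins; plain lists stand in for tensors):

```python
class StepSelectiveIntervention:
    """Adds mag * steering_vector only on the decoding steps
    listed in subspaces["steps"]; passes other steps through."""

    def __init__(self, steering_vector):
        self.steering_vector = steering_vector  # list of floats
        self.called_counter = 0

    def forward(self, base, subspaces=None):
        step = self.called_counter
        self.called_counter += 1
        if subspaces and step in subspaces.get("steps", ()):
            return [b + subspaces["mag"] * v
                    for b, v in zip(base, self.steering_vector)]
        return base
```

With `subspaces={"steps": [0], "mag": 2.0}`, only the prompt pass (call 0) is steered and every later decoding step is left untouched.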
{
"cell_type": "markdown",
"id": "26d25dc6",
