Commit

update
Vermouth7 committed Nov 11, 2024
1 parent 9dc4081 commit c9d2b81
Showing 6 changed files with 357 additions and 216 deletions.
5 changes: 5 additions & 0 deletions bigcode_eval/generation.py
@@ -155,6 +155,11 @@ def parallel_generations(
intermediate_save_generations_path=intermediate_save_generations_path,
my_mode=args.my_mode,
split_file=args.split_file,
insert_layers=args.insert_layers,
normalize=args.normalize,
operator=args.operator,
coef=args.coef,
discriminator=args.discriminator,
**gen_kwargs,
)
return generations
37 changes: 28 additions & 9 deletions bigcode_eval/utils.py
@@ -174,6 +174,11 @@ def _make_infill_prompt(self, prefix, suffix, preprefix=""):

def _make_instruction_prompt(self, instruction, context, prefix=""):
"""Make a prompt for instruction-tuning. Delimit instruction and context with specific tokens if provided."""

messages = [
{"role": "user", "content": instruction},
]

if not self.instruction_tokens:
warnings.warn(
"Instruction-tuning tokens are not provided for an instruction-tuning task, we will leave them empty."
@@ -185,10 +190,15 @@ def _make_instruction_prompt(self, instruction, context, prefix=""):
warnings.warn(
"Instruction-tuning tokens provided but one or more are empty. Ignore warning if this was intended"
)
prompt = (
prefix + user_token + instruction + end_token + assistant_token + context
)

# prompt = (
# prefix + user_token + instruction + end_token + assistant_token + context
# )
prompt=self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)+context

return prompt
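
For context, a minimal sketch (not part of the commit) of what the new chat-template prompt construction produces; the checkpoint name and the instruction/context strings are placeholders:

from transformers import AutoTokenizer

# placeholder checkpoint; any model that ships a chat template behaves analogously
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")
messages = [{"role": "user", "content": "Write a function that reverses a string."}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return a string rather than token ids
    add_generation_prompt=True,  # append the assistant header so generation continues from it
)
prompt += "def reverse_string(s):"  # `context` is concatenated after the template, as in the hunk above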


@@ -257,6 +267,11 @@ def complete_code(
intermediate_save_generations_path: Optional[str] = None,
my_mode=0,
split_file=None,
insert_layers=None,
normalize=True,
operator=None,
coef=1.0,
discriminator=None,
**gen_kwargs,
):
"""Generate multiple codes for each task in the dataset using multiple GPUs with accelerate.
@@ -321,21 +336,25 @@
# In transformers (>= 4.40.2), if the length of input_ids == max_length, a ValueError is thrown.
# We want to ignore this error in order to reproduce old results with mbpp.
try:
if split_file is not None:
insert_layer=[18]
if my_mode!=0 and split_file is not None:
insert_layer=eval(insert_layers)
print("insert layers: ",insert_layer)
layers = [i - 1 for i in insert_layer]
vector=get_split_hs(model,tokenizer,inputs_ids,split_file)[0] ## only for one batch
model.reset()
model.set_controller(layer_ids=layers, activations=vector)
model.set_controller(layer_ids=layers, activations=vector,normalize=normalize,operator=operator,coef=coef)
model.set_pos(inputs)

generated_tokens = model.generate(
input_ids=inputs,
num_return_sequences=batch_size,
mode=my_mode,
use_cache=True,
output_hidden_states=False,
use_cache=False,
output_hidden_states=True,
discriminator=discriminator,
**gen_kwargs,
)

except ValueError as e:
# When the length of input_ids == max_length, the generation is the same as the input
if str(e).startswith(f"Input length of input_ids is {inputs.shape[1]}, but `max_length` is set to {gen_kwargs['max_length']}"):
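As a worked reference for this hunk, this is how the new layer-selection arguments are interpreted (values assumed for illustration; eval mirrors the committed code, although ast.literal_eval would be a safer drop-in):

insert_layers = "[18, 20]"              # string received from --insert_layers
insert_layer = eval(insert_layers)      # -> [18, 20]
layers = [i - 1 for i in insert_layer]  # -> [17, 19], zero-based indices of the wrapped decoder blocks
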
151 changes: 111 additions & 40 deletions bigcode_eval/wrappedmodel.py
@@ -1,4 +1,6 @@
# wrapping classes
import random

import numpy as np
import torch

@@ -14,10 +16,10 @@ def __init__(self, block):
self.normalize = False
self.input_pos=None
self.operator=None
self.controller_chosen=None
# self.last_saved=None


def forward(self,*args,edit=False, **kwargs):
# print(edit)
def forward(self,*args,edit, **kwargs):
output = self.block(*args, **kwargs)
if isinstance(output, tuple):
self.output = output[0]
@@ -26,12 +28,16 @@ def forward(self,*args,edit, **kwargs):
self.output = output
modified = output
# print("output 0:")
# print(output[0].shape)
# print(self.controller.shape)
# print(modified.shape)

# if edit==False and self.controller is not None and self.last_saved is not None:
# for i in range(0,len(self.token_pos)):
# token=self.token_pos[i]
# modified[:, token]=self.last_saved

if self.controller is not None and edit == True:
norm_pre = torch.norm(modified, dim=-1, keepdim=True)

if self.mask is not None:
mask = self.mask

@@ -50,45 +56,67 @@ def forward(self,*args,edit, **kwargs):
mask = 1.0
# print("mask",mask.shape)

if len(self.controller.shape) == 1:
self.controller = self.controller.reshape(1, 1, -1)
# assert len(self.controller.shape) == len(modified.shape), f"Shape of controller {self.controller.shape} does not match shape of modified {modified.shape}."

self.controller = self.controller.to(modified.device)
if self.controller_chosen == None:
if len(self.controller.shape) == 1:
self.controller = self.controller.reshape(1, 1, -1)
# assert len(self.controller.shape) == len(modified.shape), f"Shape of controller {self.controller.shape} does not match shape of modified {modified.shape}."

self.controller = self.controller.to(modified.device)

if type(mask) == torch.Tensor:
mask = mask.to(modified.device)
# handle activation
# print(self.token_pos)
if isinstance(self.token_pos, int):
modified[:, self.token_pos] = self.operator(modified[:, self.token_pos], self.controller[:, self.input_pos] * mask[:, self.input_pos])
elif isinstance(self.token_pos, list) or isinstance(self.token_pos, tuple):
for i in range(0,len(self.token_pos)):
token=self.token_pos[i]
# if self.last_saved==None:
# self.last_saved=modified[:, token]
modified[:, token] = self.operator(modified[:, token], self.controller[self.input_pos[i], -1].unsqueeze(0) * mask[:, -1])
if self.normalize:
norm_post = torch.norm(modified, dim=-1, keepdim=True)
modified = modified / norm_post * norm_pre
else:
if len(self.controller_chosen.shape) == 1:
self.controller_chosen = self.controller_chosen.reshape(1, 1, -1)
# assert len(self.controller.shape) == len(modified.shape), f"Shape of controller {self.controller.shape} does not match shape of modified {modified.shape}."

self.controller_chosen = self.controller_chosen.to(modified.device)

if type(mask) == torch.Tensor:
mask = mask.to(modified.device)
# handle activation
# print(self.token_pos)
if isinstance(self.token_pos, int):
modified[:, self.token_pos] = self.operator(modified[:, self.token_pos], self.controller_chosen[:, self.input_pos] * mask[:, self.input_pos])
elif isinstance(self.token_pos, list) or isinstance(self.token_pos, tuple):
for i in range(0,len(self.token_pos)):
token=self.token_pos[i]
modified[:, token] = self.operator(modified[:, token], self.controller_chosen[self.input_pos[i], -1].unsqueeze(0) * mask[:, -1])
if self.normalize:
norm_post = torch.norm(modified, dim=-1, keepdim=True)
modified = modified / norm_post * norm_pre

if type(mask) == torch.Tensor:
mask = mask.to(modified.device)
# handle activation
# print(self.token_pos)
if isinstance(self.token_pos, int):
modified[:, self.token_pos] = self.operator(modified[:, self.token_pos], self.controller[:, self.input_pos] * mask[:, self.input_pos])
elif isinstance(self.token_pos, list) or isinstance(self.token_pos, tuple):
for i in range(0,len(self.token_pos)):
token=self.token_pos[i]
modified[:, token] = self.operator(modified[:, token], self.controller[self.input_pos[i], -1].unsqueeze(0) * mask[:, -1])
if self.normalize:
norm_post = torch.norm(modified, dim=-1, keepdim=True)
modified = modified / norm_post * norm_pre
# if isinstance(self.token_pos,int):
# self.token_pos-=1
# elif isinstance(self.token_pos,list):
# self.token_pos = [pos - 1 for pos in self.token_pos]
if isinstance(output, tuple):
output = (modified,) + output[1:]
else:
output = modified

return output

def set_controller(self, activations,token_pos=-1, masks=None, normalize=False, operator='replace'):
def set_controller(self, activations,token_pos=-1, masks=None, normalize=False, operator='replace',coef=1.0):
self.normalize = normalize
self.controller = activations
self.mask = masks

if operator == 'linear_comb':

def op(current, controller):
return current + controller
return current + coef*controller
elif operator == 'piecewise_linear':

def op(current, controller):
sign = torch.sign((current * controller).sum(-1, keepdim=True))
return current + controller * sign
@@ -101,14 +129,21 @@ def op(current,controller):
else:
raise NotImplementedError(f"Operator {operator} not implemented.")
self.operator = op
def adjust_controller(self,index):
try:
self.controller_chosen = self.controller[index]
except IndexError:
index = random.randint(0, len(self.controller) - 1)
self.controller_chosen = self.controller[index]

def reset(self):
self.output = None
self.controller = None
self.mask = None
self.token_pos = None
self.operator = None

# self.last_saved=None
self.controller_chosen=None
def set_masks(self, masks):
self.mask = masks
def set_token_pos(self,token_pos):
Expand All @@ -117,6 +152,9 @@ def set_token_pos(self,token_pos):
else:
self.input_pos=-1
self.token_pos=token_pos

# self.token_pos=[-1]


BLOCK_NAMES = [
"self_attn",
@@ -206,33 +244,59 @@ def _get_activations(layer_id, block_name):
return _get_activations(layer_ids, block_name)


def set_controller(self, layer_ids, activations, block_name='decoder_block', token_pos=-1, masks=None, normalize=False, operator='replace'):
def set_controller(self, layer_ids, activations, block_name='decoder_block', token_pos=-1, masks=None, normalize=False, operator='replace',coef=1.0):

def _set_controller(layer_id, activations, block_name, masks, normalize, operator,coef):
current_layer = self.model.model.layers[layer_id]
if block_name == 'decoder_block':
current_layer.set_controller(activations, token_pos, masks, normalize, operator,coef)
elif self.is_wrapped(current_layer):
current_block = current_layer.block
if block_name in BLOCK_NAMES and self.is_wrapped(getattr(current_block, block_name)):
getattr(current_block, block_name).set_controller(activations, token_pos, masks, normalize, operator,coef)
else:
return f"No wrapped block named {block_name}."

else:
if block_name in BLOCK_NAMES and self.is_wrapped(getattr(current_layer, block_name)):
getattr(current_layer, block_name).set_controller(activations, token_pos, masks, normalize, operator,coef)
else:
return f"No wrapped block named {block_name}."

if isinstance(layer_ids, list) or isinstance(layer_ids, tuple) or isinstance(layer_ids, np.ndarray):
for layer_id in layer_ids:
_set_controller(layer_id, activations[:,layer_id+1], block_name, masks, normalize, operator,coef)
elif isinstance(layer_ids,int):
_set_controller(layer_ids, activations[:,layer_ids+1], block_name, masks, normalize, operator,coef)
else:
_set_controller(layer_ids, activations, block_name, masks, normalize, operator,coef)

def set_controller_2(self, layer_ids, activations, block_name='decoder_block', token_pos=-1, masks=None, normalize=False, operator='replace',coef=1.0):

def _set_controller(layer_id, activations, block_name, masks, normalize, operator):
def _set_controller(layer_id, activations, block_name, masks, normalize, operator,coef):
current_layer = self.model.model.layers[layer_id]
if block_name == 'decoder_block':
current_layer.set_controller(activations, token_pos, masks, normalize, operator)
current_layer.set_controller(activations, token_pos, masks, normalize, operator,coef)
elif self.is_wrapped(current_layer):
current_block = current_layer.block
if block_name in BLOCK_NAMES and self.is_wrapped(getattr(current_block, block_name)):
getattr(current_block, block_name).set_controller(activations, token_pos, masks, normalize, operator)
getattr(current_block, block_name).set_controller(activations, token_pos, masks, normalize, operator,coef)
else:
return f"No wrapped block named {block_name}."

else:
if block_name in BLOCK_NAMES and self.is_wrapped(getattr(current_layer, block_name)):
getattr(current_layer, block_name).set_controller(activations, token_pos, masks, normalize, operator)
getattr(current_layer, block_name).set_controller(activations, token_pos, masks, normalize, operator,coef)
else:
return f"No wrapped block named {block_name}."

if isinstance(layer_ids, list) or isinstance(layer_ids, tuple) or isinstance(layer_ids, np.ndarray):
for layer_id in layer_ids:
_set_controller(layer_id, activations[:,layer_id+1], block_name, masks, normalize, operator)
_set_controller(layer_id, activations[:,:,layer_id+1], block_name, masks, normalize, operator,coef)
elif isinstance(layer_ids,int):
_set_controller(layer_ids, activations[:,layer_ids+1], block_name, masks, normalize, operator)
_set_controller(layer_ids, activations[:,:,layer_ids+1], block_name, masks, normalize, operator,coef)
else:
_set_controller(layer_ids, activations, block_name, masks, normalize, operator)

_set_controller(layer_ids, activations, block_name, masks, normalize, operator,coef)

def reset(self):
for layer in self.model.model.layers:
@@ -272,6 +336,13 @@ def unwrap(self):
setattr(self.model.model.layers[l],
block_name,
getattr(self.model.model.layers[l], block_name).block)
def set_pos(self,token_positions_list):
for layer in self.model.model.layers:
layer.set_token_pos(token_positions_list)
def adjust_controller(self,index):
for layer in self.model.model.layers:
if layer.controller is not None:
layer.adjust_controller(index)
def set_pos(self,inputs):
batch,seq_len=inputs.shape
token_positions_list=[seq_len-1]*batch
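To illustrate the new coef and normalize parameters of set_controller: under the 'linear_comb' operator the steering activation is scaled by coef before being added, and normalization restores the pre-edit L2 norm. A minimal sketch with assumed shapes, not the wrapped-model code itself:

import torch

coef = 0.5
hidden = torch.randn(1, 1, 4096)      # hidden state at one token position (batch, token, dim)
controller = torch.randn(1, 1, 4096)  # steering activation for the same position
norm_pre = torch.norm(hidden, dim=-1, keepdim=True)

steered = hidden + coef * controller  # op(current, controller) for operator='linear_comb'
# with normalize=True, the edited state is rescaled to keep its original norm
steered = steered / torch.norm(steered, dim=-1, keepdim=True) * norm_pre
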
24 changes: 24 additions & 0 deletions main.py
@@ -222,6 +222,30 @@ def parse_args():
type=str,
default=None,
)
parser.add_argument(
"--insert_layers",
type=str,
default="[20]",
)
parser.add_argument(
"--normalize",
action="store_true",
)
parser.add_argument(
"--operator",
type=str,
default="replace",
)
parser.add_argument(
"--coef",
type=float,
default="1.0",
)
parser.add_argument(
"--discriminator",
type=str,
default=None,
)
return parser.parse_args()
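
A quick sketch of how the new flags parse, with assumed argument values (the consumer in bigcode_eval/utils.py expects --insert_layers to be a Python-style list literal):

import argparse
import ast

parser = argparse.ArgumentParser()
parser.add_argument("--insert_layers", type=str, default="[20]")
parser.add_argument("--normalize", action="store_true")
parser.add_argument("--operator", type=str, default="replace")
parser.add_argument("--coef", type=float, default=1.0)
parser.add_argument("--discriminator", type=str, default=None)

args = parser.parse_args(
    ["--insert_layers", "[18,20]", "--normalize", "--operator", "linear_comb", "--coef", "0.5"]
)
print(ast.literal_eval(args.insert_layers), args.normalize, args.operator, args.coef)
# -> [18, 20] True linear_comb 0.5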


(Diffs for the remaining two changed files are not shown here.)
