Commit 419c112
fixes
TristanThrush committed Feb 7, 2025
1 parent cf879de commit 419c112
Showing 2 changed files with 33 additions and 21 deletions.
49 changes: 29 additions & 20 deletions examples/get_error_and_bpb/get_error_and_bpb.py
@@ -53,6 +53,8 @@

if args.mode == "suffix":
    percentage_positions_list = [0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64]
+    if args.hf_llm_name is not None and "AI-Sweden-Models" in args.hf_llm_name:  # percentage-position suffix losses are disabled for these models
+        percentage_positions_list = []
else:
    percentage_positions_list = []

@@ -140,11 +142,21 @@
if args.chunked_pretraining_data_sample is not None:
    ds = load_from_disk(args.chunked_pretraining_data_sample)

-tokenizer = AutoTokenizer.from_pretrained(
-    args.hf_llm_name,
-    revision=args.hf_llm_revision,
-    trust_remote_code=True,
-)
+try:
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.hf_llm_name,
+        revision=args.hf_llm_revision,
+        trust_remote_code=True,
+        use_fast=True,
+    )
+    print("Loaded fast tokenizer successfully!")
+except ValueError:
+    print("Fast tokenizer not available, falling back to slow tokenizer.")
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.hf_llm_name,
+        revision=args.hf_llm_revision,
+        trust_remote_code=True,
+    )

if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
    if not hasattr(tokenizer, "eos_token") or tokenizer.eos_token is None:
@@ -188,7 +200,7 @@ def get_loss_hf(examples):
    if args.mode == "suffix":
        inputs, suffix_indices, char_indices_list = batch_tokenize_with_percentage_based_indices(tokenizer, texts, percentage_positions_list)
    elif args.mode == "token":
-        inputs, token_strings = batch_tokenize_with_token_str_info(tokenizer, text)
+        inputs, token_strings = batch_tokenize_with_token_str_info(tokenizer, texts)
    else: # "sub_chunk"
        inputs, step_indices, char_indices_list = batch_tokenize_with_char_step(tokenizer, texts, args.sub_chunk_char_step)

@@ -230,15 +242,12 @@ def get_loss_hf(examples):
            losses = loss.sum(dim=1) / inputs.attention_mask[..., 1:].sum(dim=1)
            suffix_losses = compute_average_loss_from_indices(loss, suffix_indices, inputs.attention_mask)
        elif args.mode == "token":
-            step_losses_list = compute_average_loss_from_index_tuples(loss, step_indices, inputs.attention_mask)
-        else: # "sub_chunk"
-            shifted_token_strings = []
-            for strings in token_strings:
-                shifted_strings = strings[1:]
-                shifted_token_strings.append(shifted_strings)
+            token_loss_dicts = compute_token_loss_dicts(loss, inputs.attention_mask, token_strings)
+        else: # "sub_chunk"
+            step_losses_list = compute_average_loss_from_index_tuples(loss, step_indices, inputs.attention_mask)

    except Exception as e:
        print(e)
        if args.mode == "suffix":
            losses = torch.full((args.hf_llm_batch_size,), float('nan'))
            suffix_losses = torch.full((args.hf_llm_batch_size,len(percentage_positions_list)), float('nan'))
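For reference, the reduction at the top of this hunk turns a per-position loss matrix into one scalar per sequence by dividing the summed loss by the number of attended target positions (the attention mask shifted by one). A minimal standalone sketch of that arithmetic, assuming the per-position loss is already zero at padded targets:

import torch

batch_size, seq_len = 2, 6
# Per-position cross-entropy for targets 1..seq_len-1 (shape: batch x seq_len-1).
loss = torch.rand(batch_size, seq_len - 1)
# 1 for real tokens, 0 for padding.
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]])

# Zero out positions whose target token is padding, then average over real targets only.
loss = loss * attention_mask[..., 1:]
losses = loss.sum(dim=1) / attention_mask[..., 1:].sum(dim=1)
print(losses)  # one mean loss (in nats) per sequence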
@@ -284,9 +293,9 @@ def get_loss_hf(examples):
if "domain" in examples.keys():
output_examples["domain"] = []
for index, token_loss_dict in enumerate(token_loss_dicts):
output_examples["id"].append([examples["id"][index]]*len(token_loss_dict))
output_examples["id"] += [examples["id"][index]]*len(token_loss_dict)
for token_str, (loss, token_count) in token_loss_dict.items():
output_examples["chunk"].append(examples["chunk"][index] + "_" + token_str)
output_examples["chunk"].append(str(examples["chunk"][index]) + "_" + token_str)
output_examples["loss"].append(loss)
output_examples["token_count"].append(token_count)
output_examples["byte_count"].append(len(token_str.encode("utf-8")))
@@ -303,12 +312,12 @@ def get_loss_hf(examples):
if "domain" in examples.keys():
output_examples["domain"] = []
for index, char_indices in enumerate(char_indices_list):
output_examples["id"].append([examples["id"][index]]*len(char_indices))
for loss_index, index_pair in enumerate(char_indices):
output_examples["chunk"].append(examples["chunk"][index] + "_" + str(index_pair[0]) + ":" + str(index_pair[1]))
output_examples["id"] += [examples["id"][index]]*len(char_indices)
for loss_index, start_index in enumerate(char_indices):
output_examples["chunk"].append(str(examples["chunk"][index]) + "_" + str(start_index) + ":" + str(start_index + args.sub_chunk_char_step))
output_examples["loss"].append(step_losses_list[index][loss_index])
output_examples["byte_count"].append(len(texts[index][index_pair[0]:index_pair[1]].encode("utf-8")))
output_examples["token_count"].append(inputs.attention_mask[index][index_pair[0]:index_pair[1]].sum(dim=1))
output_examples["byte_count"].append(len(texts[index][start_index:start_index + args.sub_chunk_char_step].encode("utf-8")))
output_examples["token_count"].append(inputs.attention_mask[index][step_indices[index][loss_index][0]:step_indices[index][loss_index][1]].sum(dim=0))
if "domain" in examples.keys():
output_examples["domain"].append(examples["domain"][index])

@@ -394,7 +403,7 @@ def get_bpb(df, percent_prefix_designation=""):
df["id"] = df["id"].astype(str)
df = df[keep_columns].copy()

df[new_column_name + percent_prefix_designation] = (
df[new_column_name] = (
(df["token_count" + percent_prefix_designation] / df["byte_count" + percent_prefix_designation]) * df["loss" + percent_prefix_designation] / np.log(2)
)
df.drop(columns=["token_count" + percent_prefix_designation, "byte_count" + percent_prefix_designation, "loss" + percent_prefix_designation], inplace=True)
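For context, the expression in this hunk is the usual bits-per-byte conversion: the mean per-token loss (in nats) is multiplied by tokens-per-byte and divided by ln 2. A standalone sketch of the same arithmetic with made-up numbers (the column names are illustrative, not the script's exact schema):

import numpy as np
import pandas as pd

# Toy rows: mean cross-entropy per token (nats), token count, and UTF-8 byte count per chunk.
df = pd.DataFrame({
    "loss": [2.8, 3.1],
    "token_count": [512, 480],
    "byte_count": [2048, 1900],
})

# bits per byte = (tokens / bytes) * nats per token / ln(2)
df["bpb"] = (df["token_count"] / df["byte_count"]) * df["loss"] / np.log(2)
print(df[["bpb"]])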
5 changes: 4 additions & 1 deletion examples/get_error_and_bpb/tokenization_utils.py
@@ -119,7 +119,10 @@ def batch_tokenize_with_percentage_based_indices(
            if distance < min_distance:
                min_distance = distance
                nearest_token_index = j
-
+
+        if nearest_token_index is None:
+            # No token boundary matched this position; signal failure to the caller.
+            return encoded_batch, None, char_indices_list
        nearest_token_indices.append(nearest_token_index)

    suffix_indices.append(nearest_token_indices)
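The added lines make the helper return None in place of the suffix indices when no token boundary lines up with a requested position, so callers can detect the failure instead of crashing. A self-contained sketch of the same sentinel pattern (the function below is a simplified stand-in, not the repository's actual helper):

from typing import List, Optional

def nearest_token_index(char_pos: int, token_starts: List[int], max_distance: int = 5) -> Optional[int]:
    """Return the index of the token start closest to char_pos, or None if none is close enough."""
    best_index, best_distance = None, max_distance + 1
    for j, start in enumerate(token_starts):
        distance = abs(start - char_pos)
        if distance < best_distance:
            best_index, best_distance = j, distance
    return best_index

# Caller-side handling of the None sentinel, mirroring the early return added above.
index = nearest_token_index(37, token_starts=[0, 12, 25, 51])
if index is None:
    print("No usable token boundary; skipping this chunk.")
else:
    print(f"Nearest token index: {index}")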
