updates to functionality - not clean yet
TristanThrush committed Jan 28, 2025
1 parent d34f497 commit 22f906f
Showing 7 changed files with 261 additions and 132 deletions.
4 changes: 3 additions & 1 deletion examples/get_error_and_bpb/chunk_pretraining_data_sample.py
@@ -23,7 +23,9 @@
config.hf_name,
#name=config.subset,
#split=config.split,
)[config.split]
)
if config.split is not None:
ds = ds[config.split]
else:
ds = load_dataset(
config.hf_name,
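
As an illustrative aside (the helper below is not from the repo), the new lines follow the usual datasets pattern: calling load_dataset without a split returns a DatasetDict keyed by split name, so the split is only indexed into when one is actually configured.

from datasets import load_dataset

def load_optional_split(hf_name, split=None):
    # Hypothetical helper: load the dataset, then select a split only if given.
    ds = load_dataset(hf_name)
    if split is not None:
        ds = ds[split]
    return ds
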
14 changes: 9 additions & 5 deletions examples/get_error_and_bpb/error_and_bpb_scheduler.sh
@@ -8,18 +8,22 @@
# if the jobs fail, then this flag tells the script not to redo the cached bpb values
# and to start computing bpb where it left off.

# -q sphinx --exclude sphinx3,sphinx4,sphinx5,sphinx6,sphinx7,sphinx8,sphinx9

JOBID1=$(nlprun -q sphinx --exclude sphinx3,sphinx4,sphinx5,sphinx6,sphinx7,sphinx8,sphinx9 -r 100G -g 1 -c 16 -a perplexity-correlations -o \
$1/job_log.txt "python get_error_and_bpb.py \
$1/job_log.txt "export HF_DATASETS_TRUST_REMOTE_CODE=1; python get_error_and_bpb.py \
--raw_job_output_path $1 \
--hf_llm_family $2 \
--hf_llm_name $3 \
--hf_llm_revision $4 \
--eleuther_eval_names $5 \
--eleuther_eval_metrics $6 \
--eleuther_eval_lower_is_better $7 \
--chunked_pretraining_data_sample $8 \
--error_output_csv $9 \
--bpb_output_csv_prefix ${10} \
--eleuther_eval_num_fewshot $7 \
--eleuther_eval_lower_is_better $8 \
--chunked_pretraining_data_sample $9 \
--error_output_csv ${10} \
--bpb_output_csv_prefix ${11} \
--custom_evals ${12} \
--resume \
--half_precision \
--save_model_info \
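
Note on the argument remapping above: adding --eleuther_eval_num_fewshot as the seventh positional argument shifts every later argument up by one, so the chunked sample, error CSV, and bpb prefix move to $9, ${10}, and ${11}, with the new --custom_evals list in ${12}; the braces are required in bash for positional parameters beyond $9.
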
203 changes: 127 additions & 76 deletions examples/get_error_and_bpb/get_error_and_bpb.py
@@ -16,6 +16,11 @@
import time
import pandas as pd
import warnings
from custom_evals.jeopardy import jeopardy
from custom_evals.jeopardy_acc import jeopardy_accuracy
import math

custom_evals = {"jeopardy": jeopardy, "jeopardy_accuracy": jeopardy_accuracy}

parser = argparse.ArgumentParser()

@@ -26,6 +31,8 @@
parser.add_argument("--eleuther_eval_names", nargs="*", required=False)
parser.add_argument("--eleuther_eval_metrics", nargs="*", required=False)
parser.add_argument("--eleuther_eval_lower_is_better", nargs="*", required=False)
parser.add_argument("--eleuther_eval_num_fewshot", nargs="*", required=False)
parser.add_argument("--custom_evals", nargs="*", required=False)
parser.add_argument("--chunked_pretraining_data_sample", required=False)
parser.add_argument("--raw_job_output_path", required=False)
parser.add_argument("--error_output_csv", required=False)
@@ -37,30 +44,40 @@
parser.add_argument("--save_model_info", action="store_true")
parser.add_argument("--device", default="cuda")
parser.add_argument("--half_precision", action="store_true")
parser.add_argument("--hf_llm_batch_size", type=int, default=2)
parser.add_argument("--hf_llm_batch_size", type=int, default=1)

args = parser.parse_args()

if args.chunked_pretraining_data_sample == "None":
args.chunked_pretraining_data_sample = None

# If args.config is specified, use this script just to kick off a bunch
# of jobs, and then exit from the script
if args.config is not None:
with open(args.config, "r") as file:
config = SimpleNamespace(**yaml.safe_load(file))

eleuther_eval_names = []
eleuther_eval_metrics = []
eleuther_eval_num_fewshot = []
eleuther_eval_lower_is_better = []
for eval in config.evals:
eval = SimpleNamespace(**eval)
eleuther_eval_names.append(eval.eleuther_name)
eleuther_eval_metrics.append(eval.metric)
eleuther_eval_lower_is_better.append(eval.lower_is_better)
eleuther_eval_num_fewshot.append(eval.num_fewshot)
eleuther_eval_names = " ".join(eleuther_eval_names)
eleuther_eval_metrics = " ".join(eleuther_eval_metrics)
eleuther_eval_num_fewshot = " ".join(
[str(obj) for obj in eleuther_eval_num_fewshot]
)
eleuther_eval_lower_is_better = " ".join(
[str(obj) for obj in eleuther_eval_lower_is_better]
)

custom_evals = " ".join(config.custom_evals)

for family in config.llms:
family = SimpleNamespace(**family)
for llm in family.hf_names:
@@ -79,9 +96,9 @@
os.makedirs(output_path, exist_ok=True)
command = f"bash error_and_bpb_scheduler.sh \
'{output_path}' '{family.family}' '{llm}' '{revision}' '{eleuther_eval_names}' \
'{eleuther_eval_metrics}' '{eleuther_eval_lower_is_better}' \
'{config.chunked_pretraining_data_sample}' '{config.error_output_csv}' \
'{config.bpb_output_csv_prefix}'"
'{eleuther_eval_metrics}' '{eleuther_eval_num_fewshot}' \
'{eleuther_eval_lower_is_better}' '{config.chunked_pretraining_data_sample}' \
'{config.error_output_csv}' '{config.bpb_output_csv_prefix}' '{custom_evals}'"
subprocess.call(command, shell=True)
sys.exit()

@@ -91,8 +108,8 @@
args.hf_llm_name,
args.eleuther_eval_names,
args.eleuther_eval_metrics,
args.eleuther_eval_num_fewshot,
args.eleuther_eval_lower_is_better,
args.chunked_pretraining_data_sample,
args.error_output_csv,
args.bpb_output_csv_prefix,
):
@@ -101,16 +118,17 @@
--hf_llm_name\n\
--eleuther_eval_names\n\
--eleuther_eval_metrics\n\
--eleuther_eval_num_fewshot\n\
--eleuther_eval_lower_is_better\n\
--chunked_pretraining_data_sample\n\
--error_output_csv\n\
--bpb_output_csv_prefix\n\
are required if --config is not provided."
)

os.makedirs(args.raw_job_output_path, exist_ok=True)

ds = load_from_disk(args.chunked_pretraining_data_sample)
if args.chunked_pretraining_data_sample is not None:
ds = load_from_disk(args.chunked_pretraining_data_sample)

tokenizer = AutoTokenizer.from_pretrained(
args.hf_llm_name,
@@ -164,26 +182,41 @@ def get_loss_hf(examples):
# Some models require the attention mask to be a boolean tensor.
inputs["attention_mask"] = inputs["attention_mask"].bool()

outputs = model(**inputs)
max_len = getattr(getattr(model, "config", None), "max_position_embeddings", None)
if max_len is None:
max_len = tokenizer.model_max_length if hasattr(tokenizer, "model_max_length") else None

# Workaround: cap the context length manually for OPT-2.7B.
if 'opt-2.7b' in args.hf_llm_name:
max_len = 1024 + 512

logits = outputs.logits
if max_len is None or len(inputs["input_ids"][0]) <= max_len:
try:
outputs = model(**inputs)

loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
logits = outputs.logits

shift_logits = logits[..., :-1, :].contiguous()
loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

# Need to set pad indices to -100 for cross entropy loss to ignore the padding.
pad_indices = torch.where(inputs.attention_mask == 0)
inputs.input_ids[pad_indices] = -100
shift_logits = logits[..., :-1, :].contiguous()

shift_labels = inputs.input_ids[..., 1:].contiguous()
# Need to set pad indices to -100 for cross entropy loss to ignore the padding.
pad_indices = torch.where(inputs.attention_mask == 0)
inputs.input_ids[pad_indices] = -100

loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
shift_labels = inputs.input_ids[..., 1:].contiguous()

loss = loss.view(shift_labels.size())
loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

# This averages while ignoring the padding
losses = loss.sum(dim=1) / inputs.attention_mask[..., 1:].sum(dim=1)
loss = loss.view(shift_labels.size())

# This averages while ignoring the padding
losses = loss.sum(dim=1) / inputs.attention_mask[..., 1:].sum(dim=1)
except Exception as e:
# Size by the actual batch: the final batch can be smaller than args.hf_llm_batch_size.
losses = torch.full((len(inputs["input_ids"]),), float("nan"))

else:
# The batch is longer than the model's context window; record NaN losses.
losses = torch.full((len(inputs["input_ids"]),), float("nan"))

output_examples = {
"id": examples["id"],
@@ -199,34 +232,37 @@ def get_loss_hf(examples):
return output_examples
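
For reference, a self-contained sketch of the shifted, padding-masked per-sequence cross-entropy that get_loss_hf wraps in the context-length check and try/except above; the helper name is illustrative, not from the repo.

import torch

def per_sequence_loss(logits, input_ids, attention_mask):
    # Position t predicts token t + 1, so drop the last logit and the first label.
    shift_logits = logits[..., :-1, :].contiguous()
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100  # CrossEntropyLoss ignores -100 targets
    shift_labels = labels[..., 1:].contiguous()
    loss_fn = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
    loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    loss = loss.view(shift_labels.size())
    # Average over the real (non-padding) target positions, one value per sequence.
    return loss.sum(dim=1) / attention_mask[..., 1:].sum(dim=1)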


# Create a list to hold the shards. This enables us to resume getting the loss
# from the shard where we left off if there is some issue that causes the job to
# exit early.
shards = []
if args.chunked_pretraining_data_sample is not None:
# Create a list to hold the shards. This enables us to resume getting the loss
# from the shard where we left off if there is some issue that causes the job to
# exit early.
shards = []

# Shard the dataset and add each shard to the list
for i in range(args.num_loss_shards):
if args.resume and os.path.exists(f"{args.raw_job_output_path}/loss_shards/{i}"):
shard = load_from_disk(f"{args.raw_job_output_path}/loss_shards/{i}")
shards.append(shard)
else:
shard = ds.shard(num_shards=args.num_loss_shards, index=i)
# Shard the dataset and add each shard to the list
for i in range(args.num_loss_shards):
if args.resume and os.path.exists(f"{args.raw_job_output_path}/loss_shards/{i}"):
shard = load_from_disk(f"{args.raw_job_output_path}/loss_shards/{i}")
shards.append(shard)
else:
shard = ds.shard(num_shards=args.num_loss_shards, index=i)

# For efficiency - we want to avoid as much padding as possible
shard = shard.sort(["reference_token_count"], reverse=[True])
# For efficiency - we want to avoid as much padding as possible
shard = shard.sort(["reference_token_count"], reverse=[True])

shard = shard.map(
lambda example: get_loss_hf(example),
remove_columns=ds.column_names,
batched=True,
batch_size=args.hf_llm_batch_size,
)
shard = shard.map(
lambda example: get_loss_hf(example),
remove_columns=ds.column_names,
batched=True,
batch_size=args.hf_llm_batch_size,
)

print("NaNs in shard: ", sum(1 for item in shard['loss'] if math.isnan(item)))

shard.save_to_disk(f"{args.raw_job_output_path}/loss_shards/{i}")
shard.save_to_disk(f"{args.raw_job_output_path}/loss_shards/{i}")

shards.append(shard)
shards.append(shard)

loss_df = concatenate_datasets(shards).to_pandas()
loss_df = concatenate_datasets(shards).to_pandas()

# Convert to BPB at the end, so raw losses, token counts, and byte counts are still
# stored in the loss shard datasets in case they would be useful in the future.
Expand Down Expand Up @@ -265,18 +301,19 @@ def get_bpb(df):
(df["token_count"] / df["byte_count"]) * df["loss"] / np.log(2)
)
df.drop(columns=["token_count", "byte_count", "loss"], inplace=True)
df["id"]=df["id"].astype(str)
return df
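
The formula in get_bpb is the standard bits-per-byte conversion: loss is a mean negative log-likelihood per token in nats, so token_count / byte_count rescales it to a per-byte quantity and dividing by np.log(2) converts nats to bits. A quick check with made-up numbers:

import numpy as np

loss = 2.8          # hypothetical mean loss per token, in nats
token_count = 512
byte_count = 2048
bpb = (token_count / byte_count) * loss / np.log(2)
print(round(bpb, 2))  # 1.01 bits per byte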

if args.chunked_pretraining_data_sample is not None:
bpb_dfs = [get_bpb(loss_df)]

bpb_dfs = [get_bpb(loss_df)]

if "domain" in loss_df.columns:
agg_groups = [["chunk", "id", "domain"], ["id", "domain"], ["domain"]]
else:
agg_groups = [["chunk", "id"], ["id"]]
if "domain" in loss_df.columns:
agg_groups = [["chunk", "id", "domain"], ["id", "domain"], ["domain"]]
else:
agg_groups = [["chunk", "id"], ["id"]]

for agg_group in agg_groups[1:]:
bpb_dfs.append(get_bpb(aggregate_by_domain_or_id(loss_df, agg_group)))
for agg_group in agg_groups[1:]:
bpb_dfs.append(get_bpb(aggregate_by_domain_or_id(loss_df, agg_group)))


# Function to safely read, modify, and write to a shared CSV file.
Expand All @@ -295,6 +332,7 @@ def update_csv_async(
already_added = False
try:
shared_df = pd.read_csv(csv_file_path)
shared_df["id"]=shared_df["id"].astype(str)
if new_column_name in shared_df.columns:
shared_df = shared_df.drop(columns=[new_column_name])
warnings.warn(
Expand Down Expand Up @@ -332,17 +370,18 @@ def get_lockfile_pathname(pathname):
return lockfile_pathname
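
The update_csv_async / get_lockfile_pathname pair implements a lock-guarded read-modify-write so per-LLM jobs can each add their column to a shared CSV without clobbering one another. A minimal sketch of the same idea, assuming the filelock package (the actual locking and merge details are not fully shown in this hunk):

import os
import pandas as pd
from filelock import FileLock

def add_column_to_shared_csv(csv_path, lock_path, new_df, key_columns, column_name):
    # Hypothetical helper: hold the lock across the whole read-modify-write.
    with FileLock(lock_path):
        if os.path.exists(csv_path):
            shared_df = pd.read_csv(csv_path)
            shared_df = shared_df.drop(columns=[column_name], errors="ignore")
            shared_df = shared_df.merge(new_df, on=key_columns, how="outer")
        else:
            shared_df = new_df
        shared_df.to_csv(csv_path, index=False)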


for index in range(len(agg_groups)):
bpb_df = bpb_dfs[index]
agg_group = agg_groups[index]
bpb_output_csv_name = f"{args.bpb_output_csv_prefix}_{agg_group[0]}.csv"
bpb_lock_file_pathname = get_lockfile_pathname(bpb_output_csv_name)
update_csv_async(
bpb_output_csv_name,
bpb_lock_file_pathname,
bpb_df,
agg_group,
)
if args.chunked_pretraining_data_sample is not None:
for index in range(len(agg_groups)):
bpb_df = bpb_dfs[index]
agg_group = agg_groups[index]
bpb_output_csv_name = f"{args.bpb_output_csv_prefix}_{agg_group[0]}.csv"
bpb_lock_file_pathname = get_lockfile_pathname(bpb_output_csv_name)
update_csv_async(
bpb_output_csv_name,
bpb_lock_file_pathname,
bpb_df,
agg_group,
)

# Check to see that there are actually evals specified before continuing.
if len(args.eleuther_eval_names) == 0:
@@ -358,28 +397,40 @@ def get_model_info(self):

hflm_eleuther = HFLM_Local(pretrained=model, tokenizer=tokenizer)

results = lm_eval.simple_evaluate(
model=hflm_eleuther,
tasks=args.eleuther_eval_names,
batch_size="auto",
limit=5000,
bootstrap_iters=1000,
log_samples=False,
)

# Make the name of the error column the llm model family and name, so we can
# merge with the big shared error matrix.
error_dict = {
"benchmark": args.eleuther_eval_names,
new_column_name: [],
"benchmark": args.custom_evals + [f"{name}_{shots}" for name, shots in zip(args.eleuther_eval_names, args.eleuther_eval_num_fewshot)],
new_column_name: [custom_evals[custom_eval](model, tokenizer, args.device) for custom_eval in args.custom_evals],
}
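
The custom_evals dictionary acts as a small registry: each name passed via --custom_evals maps to a callable taking (model, tokenizer, device) and returning a scalar score, which fills the error column before the Eleuther results are appended. A hypothetical example of registering an extra eval (the probe below is invented for illustration and reuses the script's existing torch import and custom_evals dict):

def my_probe(model, tokenizer, device):
    # Invented eval: fraction of prompts whose greedy next token matches the answer.
    prompts = ["The capital of France is", "2 + 2 ="]
    answers = ["Paris", "4"]
    correct = 0
    for prompt, answer in zip(prompts, answers):
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        next_token = tokenizer.decode([int(logits[0, -1].argmax())]).strip()
        correct += int(next_token == answer)
    return correct / len(prompts)

custom_evals["my_probe"] = my_probe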

for index in range(len(args.eleuther_eval_names)):

if args.eleuther_eval_num_fewshot[index] is not None:
results = lm_eval.simple_evaluate(
model=hflm_eleuther,
tasks=[args.eleuther_eval_names[index]],
batch_size="auto",
limit=5000,
bootstrap_iters=1000,
log_samples=False,
num_fewshot=int(args.eleuther_eval_num_fewshot[index]),
)
else:
results = lm_eval.simple_evaluate(
model=hflm_eleuther,
tasks=[args.eleuther_eval_names[index]],
batch_size="auto",
limit=5000,
bootstrap_iters=1000,
log_samples=False,
)
print(results)

name = args.eleuther_eval_names[index]
metric = args.eleuther_eval_metrics[index]
lower_is_better = ast.literal_eval(args.eleuther_eval_lower_is_better[index])
score = results["results"][name][metric]
if not lower_is_better:
score = 1 - score
#if not lower_is_better:
# score = 1 - score
error_dict[new_column_name].append(score)

error_df = pd.DataFrame.from_dict(error_dict)
