Skip to content

Commit

Permalink
remove experiment name
Browse files: browse the repository at this point in the history
  • Loading branch information
vwxyzjn committed Oct 16, 2024
1 parent 0b0b13f commit 51e7f11
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions scripts/submit_dpo_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def load_yaml(file_path):

def main():
parser = argparse.ArgumentParser(description="Run experiment with Beaker config")
parser.add_argument("--experiment_name", required=False, help="Name of the experiment")
parser.add_argument("--default_beaker_config", default="configs/beaker_configs/default_dpo.yaml",
help="Path to the default Beaker config file")
parser.add_argument("--config", default=None,
Expand Down Expand Up @@ -156,14 +155,10 @@ def parse_args(args):
new_arguments = re.sub(r'--num_machines \d+', f'--num_machines {args.num_nodes}', new_arguments)

# if given, use the provided name. Otherwise try to guess from the config
if args.experiment_name:
model_name = args.experiment_name
else:
model_name = get_model_name(new_arguments)
if model_name.lower() == "/model":
raise ValueError("Warning - model name found was just /model. Please provide a more descriptive name via the `--experiment_name` flag.")
# if model name has /, replace with _
model_name = model_name.replace("/", "_")
exp_name = get_exp_name(new_arguments)
if exp_name is None:
exp_name = ""
exp_name = exp_name.replace("/", "_")
# check that the given config selects only one dataset source
dataset_name, dataset_mixer, train_file = check_dataset_selection(new_arguments)
print(f"Dataset selection is valid.")
Expand All @@ -175,7 +170,7 @@ def parse_args(args):
d['tasks'][0]['arguments'][0] = new_arguments

# name and description
exp_name = f"open_instruct_dpo_tune_{model_name}_{now}"
exp_name = f"dpo_tune_{exp_name}_{now}"
d['description'] = exp_name
d['tasks'][0]['name'] = exp_name

Expand Down Expand Up @@ -258,10 +253,10 @@ def parse_dataset_mixer(mixer_dict):
elems.append(str(v))
return ' '.join(elems)

def get_model_name(command_string):
def get_exp_name(command_string):
parts = shlex.split(command_string)
for i, part in enumerate(parts):
if part == '--model_name_or_path':
if part == '--exp_name':
if i + 1 < len(parts):
return parts[i + 1]
return None # Return None if model name is not found
Expand Down

0 comments on commit 51e7f11

Please sign in to comment.