submit_utils work with amlt
csinva committed Apr 1, 2024
1 parent 84b2690 commit 2404d16
Showing 3 changed files with 76 additions and 37 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -47,3 +47,4 @@ scratch
 .embgrams
 **workspace
 .history
+output.png
52 changes: 26 additions & 26 deletions imodelsx/llm.py
@@ -367,8 +367,7 @@ def __call__(
             truncation=False,
         ).to(self._model.device)
 
-        # torch.manual_seed(0)
-        if return_next_token_prob_scores:
+        if return_next_token_prob_scores or target_token_strs or return_top_target_token_str:
             outputs = self._model.generate(
                 **inputs,
                 max_new_tokens=1,
@@ -380,29 +379,31 @@ def __call__(
             next_token_probs = next_token_logits.softmax(
                 axis=-1).detach().cpu().numpy()
 
-            if target_token_strs is not None:
-                target_token_ids = self._check_target_token_strs(
-                    target_token_strs)
-                if return_top_target_token_str:
-                    selected_tokens = next_token_probs[:, np.array(
-                        target_token_ids)].squeeze().argmax(axis=-1)
-                    out_strs = [
-                        target_token_strs[selected_tokens[i]]
-                        for i in range(len(selected_tokens))
-                    ]
-                    if use_cache:
-                        pkl.dump(out_strs, open(cache_file, "wb"))
-                    return out_strs
-                else:
-                    out_dict_list = [
-                        {target_token_strs[i]: next_token_probs[prompt_num, target_token_ids[i]]
-                         for i in range(len(target_token_strs))
-                         }
-                        for prompt_num in range(len(prompt))
-                    ]
-                    return out_dict_list
-            else:
-                return next_token_probs
+            if target_token_strs is None:
+                return next_token_probs
+
+            target_token_ids = self._check_target_token_strs(
+                target_token_strs)
+            if return_top_target_token_str:
+                selected_tokens = next_token_probs[:, np.array(
+                    target_token_ids)].squeeze().argmax(axis=-1)
+                out_strs = [
+                    target_token_strs[selected_tokens[i]]
+                    for i in range(len(selected_tokens))
+                ]
+                if len(out_strs) == 1:
+                    out_strs = out_strs[0]
+                if use_cache:
+                    pkl.dump(out_strs, open(cache_file, "wb"))
+                return out_strs
+            else:
+                out_dict_list = [
+                    {target_token_strs[i]: next_token_probs[prompt_num, target_token_ids[i]]
+                     for i in range(len(target_token_strs))
+                     }
+                    for prompt_num in range(len(prompt))
+                ]
+                return out_dict_list
         else:
             outputs = self._model.generate(
                 **inputs,
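With this restructuring, passing target_token_strs alone now routes through the single-token generate() path (previously return_next_token_prob_scores had to be set explicitly), and the new len(out_strs) == 1 check is evidently meant to unwrap a single-prompt result to a bare string. A minimal usage sketch, assuming imodelsx.llm exposes a get_llm-style factory; the checkpoint, prompts, and numbers below are illustrative placeholders:

import imodelsx.llm

llm = imodelsx.llm.get_llm('gpt2')  # hypothetical checkpoint
prompts = ['Is the sky blue? Answer yes or no:',
           'Is grass red? Answer yes or no:']

# one dict of target-token probabilities per prompt
out = llm(prompts, target_token_strs=[' yes', ' no'])
# e.g. [{' yes': 0.62, ' no': 0.38}, {' yes': 0.41, ' no': 0.59}]

# or just the top target token per prompt
out = llm(prompts, target_token_strs=[' yes', ' no'],
          return_top_target_token_str=True)
# e.g. [' yes', ' no']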
@@ -433,11 +434,10 @@ def __call__(
         return out_strs
 
     def _check_target_token_strs(self, target_token_strs, override_token_with_first_token_id=False):
-        # deal with target_token_strs.... ######################
         if isinstance(target_token_strs, str):
             target_token_strs = [target_token_strs]
 
-        target_token_ids = [self._tokenizer(target_token_str)["input_ids"]
+        target_token_ids = [self._tokenizer(target_token_str, add_special_tokens=False)["input_ids"]
                             for target_token_str in target_token_strs]
 
         # Check that the target token is in the vocab
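The add_special_tokens=False fix above matters because many tokenizers prepend a special token by default, so the first input id would be BOS rather than the target token itself. A quick sketch (the checkpoint and ids are illustrative, not from this repo):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')  # hypothetical choice

tok('yes')['input_ids']
# e.g. [1, 4874] -- the leading 1 is <s> (BOS), not part of the target token
tok('yes', add_special_tokens=False)['input_ids']
# e.g. [4874]   -- the id the probability lookup actually needs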
60 changes: 49 additions & 11 deletions imodelsx/submit_utils.py
@@ -8,7 +8,10 @@
 from itertools import repeat
 import time
 import traceback
+import os
+from os.path import dirname, join
 
+from dict_hash import sha256
 submit_utils_dir = dirname(__file__)
 """Handles utilities for job sweeps,
 focused on embarrassingly parallel sweeps on a single machine.
@@ -28,6 +31,7 @@ def run_args_list(
     repeat_failed_jobs: bool = False,
     slurm: bool = False,
     slurm_kwargs: Optional[Dict] = None,
+    amlt_kwargs: Optional[Dict] = None,
 ):
     """
     Params
@@ -56,9 +60,14 @@ def run_args_list(
         Whether to run on SLURM (defaults to False)
     slurm_kwargs: Optional[Dict]
         kwargs for slurm
+    amlt_kwargs: Optional[Dict]
+        kwargs for amlt (will override everything else)
     """
-    n_gpus = len(gpu_ids)
-    _validate_run_arguments(n_cpus, gpu_ids)
+    if amlt_kwargs is not None:
+        print('Running on AMLT with', amlt_kwargs)
+    else:
+        n_gpus = len(gpu_ids)
+        _validate_run_arguments(n_cpus, gpu_ids)
 
     # adjust order
     if shuffle:
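For reference, a hedged sketch of how the new amlt path might be invoked; the script and file paths are hypothetical, only the 'amlt_file' key is required, and 'sku' falls back to 'G1':

from imodelsx import submit_utils

args_list = [{'seed': s, 'lr': lr} for s in range(3) for lr in [1e-4, 1e-3]]
submit_utils.run_args_list(
    args_list,
    script_name='experiments/01_train.py',  # hypothetical script
    amlt_kwargs={
        'amlt_file': 'experiments/amlt.yaml',  # template containing {param_str} and {sku}
        'sku': 'G4',  # optional override of the 'G1' default
    },
)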
@@ -75,15 +84,8 @@ def run_args_list(
         n_gpus = 0
 
     # construct commands
-    param_str_list = []
-    for args in args_list:
-        param_str = cmd_python + ' ' + script_name + ' '
-        for k, v in args.items():
-            if isinstance(v, list):
-                param_str += '--' + k + ' ' + ' '.join(v) + ' '
-            else:
-                param_str += '--' + k + ' ' + str(v) + ' '
-        param_str_list.append(param_str)
+    param_str_list = [_param_str_from_args(
+        args, cmd_python, script_name) for args in args_list]
 
     # just print and exit
     if not actually_run:
@@ -100,6 +102,32 @@ def run_args_list(
             print(
                 f'\n\n-------------------{i + 1}/{len(param_str_list)}--------------------\n' + param_str)
             run_slurm(param_str, slurm_kwargs=slurm_kwargs)
+        return
+    elif amlt_kwargs is not None:
+        assert 'amlt_file' in amlt_kwargs
+        sku = amlt_kwargs.get('sku', 'G1')
+        amlt_dir = dirname(amlt_kwargs['amlt_file'])
+
+        amlt_text = open(amlt_kwargs['amlt_file'], 'r').read()
+        script_name = script_name.replace(amlt_dir, '').strip('/')
+        param_str_list = [_param_str_from_args(
+            args, 'python', script_name) for args in args_list]
+
+        logs_dir = join(amlt_dir, 'logs')
+        os.makedirs(logs_dir, exist_ok=True)
+        for i, param_str in enumerate(param_str_list):
+            out_file = join(logs_dir, sha256(
+                {'s': param_str_list[i]}) + '.yaml')
+            s = amlt_text.format(
+                param_str=param_str_list[i],
+                sku=sku
+            ).replace('$CONFIG_DIR', '$CONFIG_DIR/..')
+            with open(out_file, 'w') as f:
+                f.write(s)
+            subprocess.run(
+                f'amlt run {out_file}', shell=True, check=True,
+            )
+        return
 
     # run serial
     elif n_cpus == 1 and n_gpus == 0:
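The branch fills the template with str.format (placeholders {param_str} and {sku}) and rewrites $CONFIG_DIR, so the file passed as amlt_file presumably looks something like the YAML sketch below; every target/environment field here is a placeholder, not taken from this repo:

description: hypothetical sweep template
target:
  service: amlk8s     # placeholder
  name: my-cluster    # placeholder
environment:
  image: pytorch/pytorch:latest  # placeholder
code:
  local_dir: $CONFIG_DIR  # rewritten to $CONFIG_DIR/.. before submission
jobs:
- name: job
  sku: {sku}
  command:
  - {param_str}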
@@ -203,6 +231,16 @@ def run_slurm(param_str, slurm_kwargs):
     )
 
 
+def _param_str_from_args(args, cmd_python, script_name):
+    param_str = cmd_python + ' ' + script_name + ' '
+    for k, v in args.items():
+        if isinstance(v, list):
+            param_str += '--' + k + ' ' + ' '.join(v) + ' '
+        else:
+            param_str += '--' + k + ' ' + str(v) + ' '
+    return param_str
+
+
 def run_on_gpu(param_str, i, n):
     gpu_id = job_queue_multiprocessing.get()
     failed_job = None
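For concreteness, the extracted helper just flattens an args dict into a CLI string, joining list values with spaces; e.g. with made-up values:

_param_str_from_args({'seed': 1, 'tags': ['a', 'b']},
                     'python', 'experiments/01_train.py')
# -> 'python experiments/01_train.py --seed 1 --tags a b '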
