Skip to content

Commit

Permalink
Merge pull request #27 from mboisson/shard_support
Browse files Browse the repository at this point in the history
Adds support for GPU shards detection.
  • Loading branch information
cmd-ntrf authored Jan 20, 2025
2 parents b44ac72 + e574142 commit 611a2a5
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 13 deletions.
53 changes: 41 additions & 12 deletions slurmformspawner/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,20 +267,49 @@ def config_gpus(self):
lock = self.resolve(self.gpus.get('lock'))

gpu_choice_map = {}
for gres in choices:
if gres == 'gpu:0':
# if the node has shards, we need the number of gpus and number of shards
max_shard_per_gpu = 0
gpu_types = set()
for choice in choices:
if choice == 'gpu:0':
gpu_choice_map['gpu:0'] = 'None'
continue
match = re.match(r"(gpu:[\w:.]+)", gres)
if match:
gres = match.group(1).split(':')
number = int(gres[-1])
if len(gres) == 2:
strings = ('gpu:{}', '{} x GPU')
elif len(gres) > 2:
strings = ('gpu:{}:{{}}'.format(gres[1]), '{{}} x {}'.format(gres[1].upper()))
for i in range(1, number + 1):
gpu_choice_map[strings[0].format(i)] = strings[1].format(i)

# we now have one choice per type of gres configuration to support
# heterogenous cluster configuration, each node could have multiple types of gres
gres_list = choice.split(',')

total_gpu = 0
num_shard = 0
gpu_type = ''
for gres_def in gres_list:
match = re.match(r"(gpu:[\w:.]+)", gres_def)
if match:
gres = match.group(1).split(':')
number = int(gres[-1])
total_gpu += number
if len(gres) == 2:
strings = ('gpu:{}', '{} x GPU')
gpu_type = 'GPU'
elif len(gres) > 2:
strings = ('gpu:{}:{{}}'.format(gres[1]), '{{}} x {}'.format(gres[1].upper()))
gpu_type = gres[1].upper()
for i in range(1, number + 1):
gpu_choice_map[strings[0].format(i)] = strings[1].format(i)
else:
match = re.match(r"(shard:[\w:.]+)", gres_def)
if match:
gres = match.group(1).split(':')
num_shard = int(gres[-1])
if num_shard > 0:
gpu_types.add(gpu_type)
max_shard_per_gpu = max(max_shard_per_gpu, int(num_shard / total_gpu))

if max_shard_per_gpu > 0:
strings = ('shard:{}', '{}/{} x ({})')
for i in range(1, max_shard_per_gpu):
gpu_choice_map[strings[0].format(i)] = strings[1].format(i, max_shard_per_gpu, '|'.join(gpu_types))

self.form['gpus'].choices = list(gpu_choice_map.items())
if lock:
self.form['gpus'].render_kw = {'disabled': 'disabled'}
Expand Down
2 changes: 1 addition & 1 deletion slurmformspawner/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_node_info(self):
for node in nodes:
output['cpu'].append(int(node['CPUTot']))
output['mem'].append(int(node['RealMemory']) - int(node.get('MemSpecLimit', '0')))
output['gres'].extend(node['Gres'].split(","))
output['gres'].extend([node['Gres']])
output['partitions'].extend(node['Partitions'].split(","))
return output

Expand Down

0 comments on commit 611a2a5

Please sign in to comment.