Skip to content

Commit

Permalink
Fix for device issue (#126)
Browse files Browse the repository at this point in the history
* Will's fix for device issue
* more descriptive GPU choices
  • Loading branch information
matsen authored Mar 6, 2025
1 parent 25a3a56 commit 92d29c9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,15 @@ def find_least_used_cuda_gpu(mem_round_val=1300):

# Check prioritization order:
if max(utilization) > 0:
print("used utilization")
print("GPU chosen via utilization")
return utilization.index(min(utilization)) # Least utilized GPU

if max(memory_used) > 0:
print("used memory")
print("GPU chosen via memory")
return memory_used.index(min(memory_used)) # Least memory used GPU

if len(set(gpu_counts.values())) > 1:
print("used processes")
print("GPU chosen via process count")
return min(
uuid_to_index[uuid]
for uuid, count in gpu_counts.items()
Expand Down
2 changes: 1 addition & 1 deletion netam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def selection_factors_of_aa_str(self, aa_sequence: Tuple[str, str]) -> Tensor:
).squeeze(0)
else:
result = result.squeeze(0)
return split_heavy_light_model_outputs(result, idx_seq.squeeze(0))
return split_heavy_light_model_outputs(result.cpu(), idx_seq.squeeze(0).cpu())


class TransformerBinarySelectionModelLinAct(AbstractBinarySelectionModel):
Expand Down

0 comments on commit 92d29c9

Please sign in to comment.