Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 12, 2024
1 parent 02a67ff commit 09f242f
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions examples/jax/encoder/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,22 @@
These helper functions to get gpu properties via nvidia-smi
"""


@lru_cache
def get_device_compute_capability() -> int:
try:
result = subprocess.run(['nvidia-smi', '--query-gpu=compute_cap', '--format=csv,noheader,nounits'],
stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
result = subprocess.run(
["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader,nounits"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
if result.returncode == 0:
# Split the output by newlines and take the first non-empty line
compute_caps = [cap.strip() for cap in result.stdout.split('\n') if cap.strip()]
compute_caps = [cap.strip() for cap in result.stdout.split("\n") if cap.strip()]
if compute_caps:
# Convert MAJOR.MINOR to MAJOR * 10 + MINOR
major, minor = map(int, compute_caps[0].split('.'))
major, minor = map(int, compute_caps[0].split("."))
return major * 10 + minor
else:
return "No GPU detected"
Expand All @@ -28,6 +33,7 @@ def get_device_compute_capability() -> int:
except FileNotFoundError:
return "nvidia-smi command not found. Ensure NVIDIA drivers are installed."


@lru_cache
def is_bf16_supported() -> bool:
"""Return if BF16 has hardware supported"""
Expand All @@ -41,6 +47,7 @@ def is_fp8_available() -> bool:
gpu_arch = get_device_compute_capability()
return gpu_arch >= 90


@lru_cache
def get_num_gpus() -> int:
"""Return number of available gpus"""
Expand Down

0 comments on commit 09f242f

Please sign in to comment.