From 09f242f0cc792a66d1dd659837584cd05abbd24b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:13:02 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/jax/encoder/common.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index cc6cd26ab4..15833992eb 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -9,17 +9,22 @@ These helper functions to get gpu properties via nvidia-smi """ + @lru_cache def get_device_compute_capability() -> int: try: - result = subprocess.run(['nvidia-smi', '--query-gpu=compute_cap', '--format=csv,noheader,nounits'], - stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + result = subprocess.run( + ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader,nounits"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) if result.returncode == 0: # Split the output by newlines and take the first non-empty line - compute_caps = [cap.strip() for cap in result.stdout.split('\n') if cap.strip()] + compute_caps = [cap.strip() for cap in result.stdout.split("\n") if cap.strip()] if compute_caps: # Convert MAJOR.MINOR to MAJOR * 10 + MINOR - major, minor = map(int, compute_caps[0].split('.')) + major, minor = map(int, compute_caps[0].split(".")) return major * 10 + minor else: return "No GPU detected" @@ -28,6 +33,7 @@ def get_device_compute_capability() -> int: except FileNotFoundError: return "nvidia-smi command not found. Ensure NVIDIA drivers are installed." + @lru_cache def is_bf16_supported() -> bool: """Return if BF16 has hardware supported""" @@ -41,6 +47,7 @@ def is_fp8_available() -> bool: gpu_arch = get_device_compute_capability() return gpu_arch >= 90 + @lru_cache def get_num_gpus() -> int: """Return number of available gpus"""