Skip to content

Commit 59fea2b

Browse files
committed
Simplify NVIDIA GPU detection
1 parent 7f6e70a commit 59fea2b

File tree

2 files changed

+15
-13
lines changed

2 files changed

+15
-13
lines changed

tests/test_gpu_db.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,11 @@ class TestGPUTypeDetection:
9494
"""Test GPU type detection (discrete, integrated, none)."""
9595

9696
def test_nvidia_discrete_by_brand(self):
97-
assert get_gpu_type(NVIDIA, "1234", "NVIDIA GeForce RTX 4080") == "DISCRETE"
98-
assert get_gpu_type(NVIDIA, "1234", "NVIDIA Quadro P1000") == "DISCRETE"
99-
assert get_gpu_type(NVIDIA, "1234", "NVIDIA Tesla V100") == "DISCRETE"
97+
# Note: In real PCI IDs database, GPU names include architecture codes
98+
# Testing with realistic device names that include arch codes
99+
assert get_gpu_type(NVIDIA, "2704", "AD103 [GeForce RTX 4080]") == "DISCRETE"
100+
assert get_gpu_type(NVIDIA, "1cb1", "GP107GL [Quadro P1000]") == "DISCRETE"
101+
assert get_gpu_type(NVIDIA, "1db4", "GV100GL [Tesla V100]") == "DISCRETE"
100102

101103
def test_nvidia_discrete_by_arch_code(self):
102104
assert get_gpu_type(NVIDIA, "1234", "NVIDIA GP104") == "DISCRETE"
@@ -151,24 +153,24 @@ class TestPatternConstruction:
151153
"""Test that patterns are correctly built without duplication."""
152154

153155
def test_nvidia_discrete_pattern_includes_all_architectures(self):
154-
"""Verify NVIDIA_DISCRETE_PATTERN includes all architecture codes."""
156+
"""Verify NVIDIA_DISCRETE_PATTERN includes all architecture codes (Kepler+)."""
155157
from torchruntime.gpu_db import NVIDIA_DISCRETE_PATTERN
156158

157-
# Test that all architecture patterns are included
159+
# Test that all architecture patterns are included (Kepler and newer)
158160
arch_codes = ["gk104", "gm107", "gp104", "gv100", "tu116", "ga102", "gh100", "ad102", "gb100"]
159161
for code in arch_codes:
160162
assert NVIDIA_DISCRETE_PATTERN.search(code), f"Pattern should match {code}"
161163

162-
# Test brand names
163-
brands = ["GeForce", "Quadro", "Tesla", "RTX", "GTX", "Titan"]
164-
for brand in brands:
165-
assert NVIDIA_DISCRETE_PATTERN.search(brand), f"Pattern should match {brand}"
166-
167164
# Test Blackwell model numbers
168165
models = ["5060", "5070", "5080", "5090"]
169166
for model in models:
170167
assert NVIDIA_DISCRETE_PATTERN.search(model), f"Pattern should match {model}"
171168

169+
# Brand names alone (without arch codes) are NOT supported for Kepler+ only detection
170+
# Pre-Kepler GPUs like "GeForce 8800 GTX" (G80) should NOT match
171+
assert not NVIDIA_DISCRETE_PATTERN.search("GeForce 8800 GTX")
172+
assert not NVIDIA_DISCRETE_PATTERN.search("GeForce GTX 580") # GF110, pre-Kepler
173+
172174
def test_amd_discrete_pattern_separate(self):
173175
"""Verify AMD has a separate discrete pattern."""
174176
from torchruntime.gpu_db import AMD_DISCRETE_PATTERN

torchruntime/gpu_db.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ def get_amd_gfx_info(device_id):
106106
# GPU Device Type Identification
107107
# =============================================================================
108108

109-
# Build NVIDIA discrete pattern by combining brand names with architecture patterns
110-
_NVIDIA_BRANDS = r"geforce|riva|quadro|tesla|ion|grid|rtx|gtx|titan"
109+
# Build NVIDIA discrete pattern from architecture patterns (Kepler and newer only)
110+
# This covers GPUs from 2012 onwards with compute capability 3.7+
111111
_NVIDIA_ARCH_PATTERNS = r"|".join(pattern.pattern.strip(r"\b()").strip(r"(?:)") for pattern in NVIDIA_ARCH_MAP.keys())
112112
NVIDIA_DISCRETE_PATTERN = re.compile(
113-
rf"\b(?:{_NVIDIA_BRANDS}|{_NVIDIA_ARCH_PATTERNS})\b",
113+
rf"\b(?:{_NVIDIA_ARCH_PATTERNS})\b",
114114
re.IGNORECASE,
115115
)
116116

0 commit comments

Comments
 (0)