Skip to content

Commit 95988a9

Browse files
committed
Refactor GPU identification
1 parent 92e4248 commit 95988a9

File tree

5 files changed

+386
-125
lines changed

5 files changed

+386
-125
lines changed

tests/test_gpu_db.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
"""
2+
Tests for the gpu_db module to ensure GPU identification patterns work correctly.
3+
"""
4+
5+
import pytest
6+
from torchruntime.gpu_db import (
7+
get_nvidia_arch,
8+
get_amd_gfx_info,
9+
get_gpu_type,
10+
is_gpu_vendor,
11+
NVIDIA,
12+
AMD,
13+
INTEL,
14+
)
15+
16+
17+
class TestNvidiaArchDetection:
    """Verify the architecture number get_nvidia_arch reports for each NVIDIA generation."""

    def test_kepler_detection(self):
        kepler_names = ["NVIDIA GeForce GK104", "GeForce GTX 780 Ti"]
        detected = get_nvidia_arch(kepler_names)
        assert detected == 3.7

    def test_maxwell_detection(self):
        maxwell_names = ["NVIDIA GeForce GM107", "GeForce GTX 950"]
        detected = get_nvidia_arch(maxwell_names)
        assert detected == 5.0

    def test_pascal_detection(self):
        pascal_names = ["NVIDIA GeForce GP104", "GeForce GTX 1080"]
        detected = get_nvidia_arch(pascal_names)
        assert detected == 6.0

    def test_volta_detection(self):
        volta_names = ["NVIDIA Tesla GV100"]
        detected = get_nvidia_arch(volta_names)
        assert detected == 7.0

    def test_turing_detection(self):
        turing_names = ["NVIDIA GeForce TU116", "GeForce GTX 1660"]
        detected = get_nvidia_arch(turing_names)
        assert detected == 7.5

    def test_ampere_detection(self):
        ampere_names = ["NVIDIA GeForce GA102", "GeForce RTX 3080"]
        detected = get_nvidia_arch(ampere_names)
        assert detected == 8.6

    def test_hopper_detection(self):
        hopper_names = ["NVIDIA GH100", "H100 PCIe"]
        detected = get_nvidia_arch(hopper_names)
        assert detected == 9.0

    def test_ada_lovelace_detection(self):
        ada_names = ["NVIDIA GeForce AD102", "GeForce RTX 4090"]
        detected = get_nvidia_arch(ada_names)
        assert detected == 8.9

    def test_blackwell_detection(self):
        # Identified from the model number alone; no architecture code in the name.
        blackwell_names = ["NVIDIA GeForce RTX 5090"]
        detected = get_nvidia_arch(blackwell_names)
        assert detected == 12.0

    def test_blackwell_gb_code_detection(self):
        # Identified from the "GB" architecture code alone.
        gb_names = ["NVIDIA GB100"]
        detected = get_nvidia_arch(gb_names)
        assert detected == 12.0

    def test_unknown_architecture(self):
        # A name with no recognizable code or model is expected to yield 0.
        unknown_names = ["NVIDIA GeForce Unknown"]
        detected = get_nvidia_arch(unknown_names)
        assert detected == 0
63+
64+
65+
class TestAMDGfxInfo:
    """Verify the (name, gfx, hsa) triple get_amd_gfx_info returns for a PCI device id."""

    def test_renoir_gfx(self):
        family, gfx_id, hsa_version = get_amd_gfx_info("1636")
        assert (family, gfx_id, hsa_version) == ("Renoir", "gfx90c", "9.0.12")

    def test_phoenix_gfx(self):
        family, gfx_id, hsa_version = get_amd_gfx_info("164f")
        assert (family, gfx_id, hsa_version) == ("Phoenix", "gfx1103", "11.0.1")

    def test_strix_gfx(self):
        family, gfx_id, hsa_version = get_amd_gfx_info("150e")
        assert (family, gfx_id, hsa_version) == ("Strix", "gfx1150", "11.5.0")

    def test_unknown_device_id(self):
        # An unrecognized device id is expected to produce three empty strings.
        family, gfx_id, hsa_version = get_amd_gfx_info("9999")
        assert (family, gfx_id, hsa_version) == ("", "", "")
91+
92+
93+
class TestGPUTypeDetection:
    """Verify get_gpu_type classifies devices as "DISCRETE", "INTEGRATED" or "NONE"."""

    def test_nvidia_discrete_by_brand(self):
        brand_names = ("NVIDIA GeForce RTX 4080", "NVIDIA Quadro P1000", "NVIDIA Tesla V100")
        for device_name in brand_names:
            assert get_gpu_type(NVIDIA, "1234", device_name) == "DISCRETE"

    def test_nvidia_discrete_by_arch_code(self):
        arch_code_names = ("NVIDIA GP104", "NVIDIA TU116", "NVIDIA GA102", "NVIDIA AD103", "NVIDIA GH100")
        for device_name in arch_code_names:
            assert get_gpu_type(NVIDIA, "1234", device_name) == "DISCRETE"

    def test_nvidia_excluded_devices(self):
        non_gpu_names = ("NVIDIA Audio Controller", "NVIDIA USB Controller")
        for device_name in non_gpu_names:
            assert get_gpu_type(NVIDIA, "1234", device_name) == "NONE"

    def test_amd_discrete_by_name(self):
        discrete_names = ("AMD Radeon RX 7900 XT", "AMD Navi 31", "AMD Instinct MI300")
        for device_name in discrete_names:
            assert get_gpu_type(AMD, "1234", device_name) == "DISCRETE"

    def test_amd_integrated_by_device_id(self):
        # Each pair is (PCI device id, device name).
        integrated_devices = (
            ("1636", "AMD Renoir"),
            ("164f", "AMD Phoenix"),
            ("150e", "AMD Strix"),
        )
        for device_id, device_name in integrated_devices:
            assert get_gpu_type(AMD, device_id, device_name) == "INTEGRATED"

    def test_amd_excluded_devices(self):
        non_gpu_names = ("AMD Audio Device", "AMD USB Controller")
        for device_name in non_gpu_names:
            assert get_gpu_type(AMD, "1234", device_name) == "NONE"

    def test_intel_discrete(self):
        gpu_type = get_gpu_type(INTEL, "1234", "Intel Arc A770")
        assert gpu_type == "DISCRETE"

    def test_intel_integrated(self):
        integrated_names = ("Intel Iris Xe Graphics", "Intel HD Graphics 620", "Intel UHD Graphics 630")
        for device_name in integrated_names:
            assert get_gpu_type(INTEL, "1234", device_name) == "INTEGRATED"

    def test_intel_excluded_devices(self):
        gpu_type = get_gpu_type(INTEL, "1234", "Intel Audio Device")
        assert gpu_type == "NONE"
136+
137+
138+
class TestGPUVendorCheck:
    """Verify is_gpu_vendor accepts the known vendor ids and rejects others."""

    def test_known_vendors(self):
        for vendor_id in (NVIDIA, AMD, INTEL):
            assert is_gpu_vendor(vendor_id) is True

    def test_unknown_vendor(self):
        result = is_gpu_vendor("9999")
        assert result is False
148+
149+
150+
class TestPatternConstruction:
    """Verify the compiled discrete-GPU patterns exported by gpu_db match what they should."""

    def test_nvidia_discrete_pattern_includes_all_architectures(self):
        """NVIDIA_DISCRETE_PATTERN must match architecture codes, brand names and Blackwell models."""
        from torchruntime.gpu_db import NVIDIA_DISCRETE_PATTERN

        expected_groups = (
            # architecture codes
            ["gk104", "gm107", "gp104", "gv100", "tu116", "ga102", "gh100", "ad102", "gb100"],
            # brand names
            ["GeForce", "Quadro", "Tesla", "RTX", "GTX", "Titan"],
            # Blackwell model numbers
            ["5060", "5070", "5080", "5090"],
        )
        for group in expected_groups:
            for text in group:
                assert NVIDIA_DISCRETE_PATTERN.search(text), f"Pattern should match {text}"

    def test_amd_discrete_pattern_separate(self):
        """AMD_DISCRETE_PATTERN must match discrete product lines but not integrated codenames."""
        from torchruntime.gpu_db import AMD_DISCRETE_PATTERN

        # AMD product lines that should match
        for product in ("Radeon RX 7900", "Instinct MI300", "Navi 31"):
            assert AMD_DISCRETE_PATTERN.search(product)

        # Integrated codenames that should not match
        for codename in ("Phoenix", "Renoir"):
            assert not AMD_DISCRETE_PATTERN.search(codename)

torchruntime/configuration.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import os
22

33
from .consts import AMD
4-
from .device_db import get_gpus, GPU_DEVICES
4+
from .device_db import get_gpus
5+
from .gpu_db import get_amd_gfx_info
56
from .platform_detection import get_torch_platform, os_name
67

78

@@ -110,9 +111,7 @@ def _set_rocm_vars_for_integrated(gpu_infos):
110111
gpu = gpu_infos[0]
111112
env = {}
112113

113-
integrated_amd_gpus = GPU_DEVICES[AMD]["integrated"]
114-
115-
family_name, gfx_id, hsa_version = integrated_amd_gpus.get(gpu.device_id, ("", "", ""))
114+
family_name, gfx_id, hsa_version = get_amd_gfx_info(gpu.device_id)
116115
env["HSA_OVERRIDE_GFX_VERSION"] = hsa_version
117116

118117
if gfx_id.startswith("gfx8"):

torchruntime/device_db.py

Lines changed: 1 addition & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from dataclasses import dataclass
88

99
from .consts import NVIDIA, AMD, INTEL
10+
from .gpu_db import is_gpu_vendor, get_gpu_type
1011

1112
DEVICE_DB_FILE = "gpu_pci_ids.db" # this file will only include AMD, NVIDIA and Discrete Intel GPUs
1213

@@ -20,97 +21,9 @@ class GPU:
2021
is_discrete: bool
2122

2223

23-
# Per-vendor classification rules, keyed by PCI vendor id.
# A compiled regex value is applied to the device name; a dict value is looked up by PCI device id.
GPU_DEVICES = {
    AMD: dict(
        discrete=re.compile(r"\b(?:radeon|instinct|fire|rage|polaris|aldebaran|navi)\b", re.IGNORECASE),
        # pci_id -> (device_name, gfx_name, hsa_override)
        integrated={
            "15dd": ("Raven Ridge", "gfx902", "9.1.0"),
            "15d8": ("Picasso", "gfx903", "9.1.0"),
            # gfx90c - https://github.com/ROCm/clr/commit/fbbae8055fbad748254bec1675a6970ce1dce594
            "1636": ("Renoir", "gfx90c", "9.0.12"),
            "164c": ("Lucienne", "gfx90c", "9.0.12"),
            "1638": ("Cezanne", "gfx90c", "9.0.12"),
            "15e7": ("Barcelo", "gfx90c", "9.0.12"),
            "163f": ("VanGogh", "gfx1033", "10.3.3"),
            "164d": ("Rembrandt", "gfx1035", "10.3.5"),
            "1681": ("Rembrandt", "gfx1035", "10.3.5"),
            "164e": ("Raphael", "gfx1036", "10.3.6"),
            "1506": ("Mendocino", "gfx1037", "10.3.7"),
            "13c0": ("Granite Ridge", "gfx1030", "10.3.0"),
            "164f": ("Phoenix", "gfx1103", "11.0.1"),
            "15bf": ("Phoenix1", "gfx1103", "11.0.3"),
            "15c8": ("Phoenix2", "gfx1103", "11.0.3"),
            "1900": ("Phoenix3", "gfx1103", "11.0.4"),
            "1901": ("Phoenix4", "gfx1103", "11.0.4"),
            # gfx1150 - https://github.com/ROCm/clr/commit/4bc515aa62368c6189d19e14b3ec18cb6dd4415e
            "150e": ("Strix", "gfx1150", "11.5.0"),
            "1586": ("Strix Halo", "gfx1151", "11.5.1"),
            "1114": ("Krackan", "gfx1151", "11.5.1"),
        },
        exclude=re.compile(r"\b(?:audio|bridge|arden|oberon|stoney|wani|usb|switch)\b", re.IGNORECASE),
    ),
    INTEL: dict(
        discrete=re.compile(r"\b(?:arc)\b", re.IGNORECASE),
        integrated=re.compile(r"\b(?:iris|hd graphics|uhd graphics)\b", re.IGNORECASE),
        exclude=re.compile(r"\b(?:audio|bridge)\b", re.IGNORECASE),
    ),
    NVIDIA: dict(
        discrete=re.compile(
            r"\b(?:geforce|riva|quadro|tesla|ion|grid|rtx|tu\d{2,}.+t\d{2,}|gk1\d{2}\w*|gm10\d\w*|gp10\d\w*|gv100\w*|tu1\d{2}\w*|ga10\d\w*|ad10\d\w*)\b",
            re.IGNORECASE,
        ),
        exclude=re.compile(
            r"\b(?:audio|switch|pci|memory|smbus|ide|co-processor|bridge|usb|sata|controller)\b", re.IGNORECASE
        ),
    ),
}
67-
68-
6924
os_name = platform.system()
7025

7126

72-
def is_gpu_vendor(vendor_id):
    """Return True when vendor_id is one of the keys of GPU_DEVICES, False otherwise."""
    known_vendor = vendor_id in GPU_DEVICES
    return known_vendor
74-
75-
76-
def get_gpu_type(vendor_id, device_id, device_name):
    """
    Returns the GPU type of the given PCI device.

    Args:
        vendor_id (str): PCI Vendor ID.
        device_id (str): PCI Device ID.
        device_name (str): PCI Device Name.

    Returns:
        gpu_type (str): "DISCRETE", "INTEGRATED", or "NONE". A vendor that has no
            entry in GPU_DEVICES is reported as "NONE" instead of raising KeyError.
    """

    vendor_devices = GPU_DEVICES.get(vendor_id)
    if vendor_devices is None:
        # Unknown vendor: nothing to classify against, so keep the documented return contract.
        return "NONE"

    def matches(pattern):
        # A compiled regex is applied to the device name; a dict is looked up by PCI device id.
        # A missing rule (None) never matches.
        if isinstance(pattern, re.Pattern):
            return pattern.search(device_name)
        if isinstance(pattern, dict):
            return device_id in pattern
        return False

    # Exclusions win over everything else (audio functions, bridges, etc.).
    if matches(vendor_devices.get("exclude")):
        return "NONE"

    # Check integrated first, to avoid matching "Radeon" and classifying it as discrete.
    if matches(vendor_devices.get("integrated")):
        return "INTEGRATED"

    if matches(vendor_devices.get("discrete")):
        return "DISCRETE"

    return "NONE"
112-
113-
11427
def get_windows_output():
11528
try:
11629
command = [

0 commit comments

Comments
 (0)