-
Notifications
You must be signed in to change notification settings - Fork 194
Expand file tree
/
Copy pathgraph.py
More file actions
104 lines (83 loc) · 3.1 KB
/
Copy pathgraph.py
File metadata and controls
104 lines (83 loc) · 3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import warnings
from contextlib import contextmanager
from functools import wraps
from typing import Callable, Iterator, List, Optional, Tuple, Union

import cudnn


def graph_cache(key_fn, maxsize=256):
    """Memoize a function using a caller-supplied cache key.

    Unlike ``functools.lru_cache``, the key is produced by ``key_fn`` rather
    than derived from the raw arguments. For example, a caller can key on a
    subset of the arguments, or reduce unhashable arguments to a hashable form.

    Eviction is least-recently-used: a cache hit marks the entry as most
    recently used, and when the cache is full the least recently used entry
    is dropped before the new result is stored.

    Args:
        key_fn: Function called with the same ``*args, **kwargs`` as the
            decorated function. It must return a hashable value, because the
            value is used as a dict key.
        maxsize: Maximum number of cached entries. ``None`` means unbounded.
            A value ``<= 0`` disables caching, so every call invokes the
            wrapped function.

    Returns:
        A decorator that wraps a function with this cache.
    """

    def decorator(func):
        # Dicts preserve insertion order. Re-inserting on every hit therefore
        # keeps the first key as the least recently used one.
        cache = {}

        @wraps(func)
        def wrapper(*args, **kwargs):
            key = key_fn(*args, **kwargs)
            if key in cache:
                # Move the hit to the most-recently-used (last) position.
                result = cache.pop(key)
                cache[key] = result
                return result

            result = func(*args, **kwargs)

            if maxsize is not None:
                if maxsize <= 0:
                    # Caching disabled. Without this guard the eviction loop
                    # would call next() on an empty dict and raise StopIteration.
                    return result
                while len(cache) >= maxsize:
                    # Evict the least recently used entry (the first key).
                    del cache[next(iter(cache))]

            cache[key] = result
            return result

        return wrapper

    return decorator
def jit(
    heur_modes: Union[List[cudnn.heur_mode], cudnn.heur_mode] = cudnn.heur_mode.A,
    **kwargs,
) -> Callable:
    """
    Decorator factory that builds the graph returned by the decorated function.

    The decorated function must return a ``(graph, tensors)`` pair. On each
    call, the graph is built with ``heur_modes`` if it reports no execution
    plans yet, and ``tensors`` is converted to a list of UIDs.

    Args:
        heur_modes: A single heuristic mode, or a list/tuple of modes, passed
            to ``graph.build``.
        **kwargs: Accepted for backward compatibility. This decorator does
            not currently use them.

    Returns:
        Callable: A decorator. The decorated function returns
        ``(graph, tensor_uids)``, where ``tensor_uids`` holds
        ``tensor.get_uid()`` for each returned tensor, in order.

    Example:
        >>> handle = cudnn.create_handle()
        >>> @cudnn.jit(heur_modes=[cudnn.heur_mode.A, cudnn.heur_mode.B])
        ... def my_graph():
        ...     with graph(handle) as (g, _):
        ...         X = g.tensor(name="X", dim=[8, 64, 56, 56],
        ...                      stride=[56*56*64, 1, 56*64, 64])
        ...         return g, [X]  # Return graph and list of tensors to get UIDs for
    """
    # Normalize to a list. Previously only `list` was recognised, so a tuple
    # such as (A, B) was wrapped whole into [(A, B)].
    if isinstance(heur_modes, (list, tuple)):
        modes = list(heur_modes)
    else:
        modes = [heur_modes]

    # NOTE(review): the factory-level **kwargs are documented as build options
    # but are never forwarded anywhere. Confirm the intended use before wiring
    # them into g.build().

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **func_kwargs):
            # Named func_kwargs so it no longer shadows the factory's **kwargs.
            g, tensors = func(*args, **func_kwargs)
            # Only build when no execution plans exist yet. Presumably this
            # means the graph has not been built. TODO confirm.
            if g.get_execution_plan_count() <= 0:
                g.build(modes)
            return g, [t.get_uid() for t in tensors]

        return wrapper

    return decorator
@contextmanager
def graph(
    handle: object,
    name: str = "cudnn_graph",
    io_data_type: cudnn.data_type = cudnn.data_type.HALF,
    intermediate_data_type: cudnn.data_type = cudnn.data_type.FLOAT,
    compute_data_type: cudnn.data_type = cudnn.data_type.FLOAT,
) -> Iterator[Tuple[cudnn.pygraph, List]]:
    """
    Context manager that creates a cuDNN graph object.

    Args:
        handle: cuDNN handle created with ``cudnn.create_handle()``.
        name: Name of the graph, for debugging purposes.
        io_data_type: Data type for input/output tensors.
        intermediate_data_type: Data type for intermediate tensors.
        compute_data_type: Data type for computation.

    Yields:
        Tuple[cudnn.pygraph, List]: ``(graph, [])``. The second element is a
        fresh empty list for tensors whose UIDs are wanted. This context
        manager does not read it back.

    Example:
        >>> with graph(handle) as (g, tensors):
        ...     X = g.tensor(name="X", dim=[8, 64, 56, 56],
        ...                  stride=[56*56*64, 1, 56*64, 64])
    """
    g = cudnn.pygraph(
        handle=handle,
        name=name,
        io_data_type=io_data_type,
        intermediate_data_type=intermediate_data_type,
        compute_data_type=compute_data_type,
    )
    # Yield a 2-tuple, so callers must unpack: `with graph(h) as (g, _):`.
    yield g, []