Improve compiled RT-DETR inference speed #33412
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Wow, thanks for the contribution. One simple question out of curiosity: why does the major boost occur in fp16?
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
if embed_dim % 4 != 0:
    raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
@yonigozlan Is this the main reason the compiled version of RT-DETR gets such a major boost in FP16? (Since the original omega defaulted to float32 and needed to be cast to fp16 during inference?)
I’m quite new to torch compile, so take this with a grain of salt, but from my understanding, the main reason for the speedup (both in fp32 and fp16) is that the model now has no graph breaks. This means no CPU/GPU transfers inside the model, and lets torch compile use CUDA graphs, reducing kernel launch overhead.
I think the main boost in fp16 comes from Tensor Cores, which are optimized for half precision and make GPU operations faster. In fact, GPU operations were already faster in the compiled models in the current version, but the gains were overshadowed by CPU/GPU transfer overhead. So I don't think changing omega to fp16 made all the difference; it's more a global effect of no graph breaks plus Tensor Cores optimized for fp16.
Thanks for the kind explanation. I'm also a newbie with this compile logic, so I just wanted to know out of curiosity.
"model now has no graph breaks"

I think this could also be possible, thanks 👍🏼
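For readers new to this, here is a minimal, illustrative sketch of the compile settings this thread refers to: fullgraph=True surfaces any remaining graph breaks, and mode="reduce-overhead" lets torch.compile use CUDA graphs to cut kernel launch overhead. The toy module, input shapes, and the torch._dynamo.explain call are assumptions for illustration (the explain API differs slightly across PyTorch versions); this is not the PR's benchmark code.

import torch

# Stand-in module and input; the PR compiles the full RT-DETR model instead.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda().half()
inputs = torch.randn(8, 64, device="cuda", dtype=torch.float16)

# fullgraph=True errors out if any graph break remains, and mode="reduce-overhead"
# allows CUDA graph capture once the graph is break-free.
compiled_model = torch.compile(model, fullgraph=True, mode="reduce-overhead")

with torch.no_grad():
    out = compiled_model(inputs)

# torch._dynamo.explain reports where graph breaks occur (API varies slightly by version).
explanation = torch._dynamo.explain(model)(inputs)
print(explanation.graph_break_count)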
    self.encoder_hidden_dim,
    self.positional_encoding_temperature,
    device=src_flatten.device,
    dtype=src_flatten.dtype,
).to(src_flatten.device, src_flatten.dtype)
I don't think we need .to() if we pass device and dtype into the function.
Indeed thanks!
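As a side note, here is a minimal sketch of the pattern being discussed (the helper name and shapes are made up, not the actual RT-DETR code): creating the tensor with the target device and dtype up front makes the extra .to() redundant.

import torch

def build_sincos_omega(embed_dim, device=None, dtype=None):
    # Allocate directly on the requested device/dtype; no later cast needed.
    pos_dim = embed_dim // 4
    return torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim

src_flatten = torch.randn(2, 100, 256, dtype=torch.float16)
omega = build_sincos_omega(256, device=src_flatten.device, dtype=src_flatten.dtype)
# Equivalent to the old torch.arange(...).to(device, dtype) pattern, minus the redundant .to()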
Awesome! The speed-up is impressive 🔥
I left a few comments/questions below:
@@ -737,7 +737,9 @@ def multi_scale_deformable_attention(
) -> Tensor:
    batch_size, _, num_heads, hidden_dim = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
    # Ignore copy
The "Copy from" looks almost useless for this function with "Ignore copy" a few lines after 🙂
does it make sense to spread changes to other architectures?
That's true, but I'm also working on making the same changes in Deformable DETR, so the "Ignore copy" shouldn't be necessary soon. I'm getting some weird results with compiled Deformable DETR in fp16 though, which I really don't understand.
# Ignore copy
total_elements = sum([shape[0] * shape[1] for shape in spatial_shapes_list])
if total_elements != sequence_length:
I am not sure we need "Ignore copy" here, since we already modified the method signature above.
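For context on why the diffs above move away from tensor-based shapes, here is an illustrative sketch (made-up shapes and variable values, not the PR code): keeping the spatial shapes as a Python list means the per-level lengths are computed on the host, so no tensor .item() call forces a GPU sync and a graph break under torch.compile.

import torch

spatial_shapes_list = [(80, 80), (40, 40), (20, 20)]  # (height, width) per feature level
level_lengths = [height * width for height, width in spatial_shapes_list]
total_elements = sum(level_lengths)

value = torch.randn(1, total_elements, 8, 32)  # (batch, sum(h*w), num_heads, head_dim)
value_list = value.split(level_lengths, dim=1)  # split sizes are plain Python ints, no device sync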
def conditional_lru_cache(*lru_args, **lru_kwargs):
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if not torch.compiler.is_compiling():
                # Cache the function only if the model is not being compiled
                cached_func = lru_cache(*lru_args, **lru_kwargs)(func.__get__(self, type(self)))
                return cached_func(*args, **kwargs)
            else:
                # Otherwise, just call the original function
                return func(self, *args, **kwargs)

        return wrapper

    return decorator
Nice! 🔥
Maybe name it compile_compatible_lru_cache?
# We always generate anchors in float32 to preserve equivalence between
# dynamic and static anchor inference
dtype = torch.float32

@conditional_lru_cache(maxsize=32)
def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
The question here: does lru_cache correctly handle a torch.dtype argument for caching?
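A quick illustrative check (not from the PR) suggests it should, since torch.dtype objects are hashable and can therefore serve as lru_cache keys:

from functools import lru_cache

import torch

@lru_cache(maxsize=8)
def make_zeros(dtype):
    return torch.zeros(2, dtype=dtype)

make_zeros(torch.float16)
make_zeros(torch.float16)  # same dtype -> cache hit
make_zeros(torch.float32)  # different dtype -> new cache entry
print(make_zeros.cache_info())  # expect hits=1, misses=2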
Seems like it does. However, there was still a problem with conditional_lru_cache where the cached function would be recreated every time, so there was no cache hit. It works now though; I can see the cache hits increasing when I print them with print("generate_anchors cache hits: ", self.__getattribute__(f"_cached_{func.__name__}").cache_info().hits) inside the conditional_lru_cache decorator.
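A hedged sketch of the fix described in this comment (not the exact PR code): the lru_cache-wrapped bound method is stored on the instance under _cached_{func.__name__}, so it is created once per object and hits accumulate across calls, while caching is skipped entirely during compilation.

from functools import lru_cache, wraps

import torch

def conditional_lru_cache(*lru_args, **lru_kwargs):
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if torch.compiler.is_compiling():
                # Skip caching inside the compiled graph
                return func(self, *args, **kwargs)
            cache_name = f"_cached_{func.__name__}"
            if not hasattr(self, cache_name):
                # Build the cached bound method once and keep it on the instance
                setattr(self, cache_name, lru_cache(*lru_args, **lru_kwargs)(func.__get__(self, type(self))))
            return getattr(self, cache_name)(*args, **kwargs)

        return wrapper

    return decorator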
🔥 very nice!
* modify rt detr to improve inference times when compiled
* Remove redundant "to"
* Fix conditional lru_cache and missing shapes_list
* nit unnecessary list creation
* Fix compile error when ninja not available and custom kernel activated
What does this PR do?
This PR is part of an ongoing effort to optimize the inference speed of Transformers' vision models.
For more info on the specific issues these optimizations target and how they help improve the inference speed of compiled models, you can check out this Notion page.
The following metrics are for model inference only! Pre/post-processing times are not measured here. Currently, Transformers image processors seem to be a big bottleneck for inference speed, so end-to-end inference will sadly be much slower than this.
It's also worth noting that while using the custom CUDA kernel for deformable attention (by setting disable_custom_kernels=False in the config) currently provides a slight boost to inference speed for compiled models, with these new optimizations enabling or disabling it now leads to very similar performance.
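For reference, a hedged sketch of a setup along these lines (checkpoint name, input size, and compile options are illustrative assumptions, not the PR's benchmark script): compiled RT-DETR inference in fp16 with the custom kernels disabled.

import torch
from transformers import RTDetrForObjectDetection

model = RTDetrForObjectDetection.from_pretrained(
    "PekingU/rtdetr_r50vd", disable_custom_kernels=True
)
model = model.to("cuda", torch.float16).eval()

compiled_model = torch.compile(model, mode="reduce-overhead")

pixel_values = torch.randn(1, 3, 640, 640, device="cuda", dtype=torch.float16)
with torch.no_grad():
    outputs = compiled_model(pixel_values=pixel_values)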
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@qubvel @amyeroberts @NielsRogge