
Improve compiled RT-DETR inference speed #33412

Merged: 5 commits into huggingface:main on Sep 18, 2024

Conversation

@yonigozlan (Member) commented Sep 10, 2024

What does this PR do?

This PR is part of an ongoing effort to optimize the inference speed of Transformers' vision models.
For more info on the specific issues these optimizations target and how they help improve the inference speed of compiled models, you can check out this Notion page.

The following metrics are for model inference only! Pre/post-processing times are not measured here. Currently, Transformers image processors seem to be a big bottleneck for inference speed, so end-to-end inference will sadly be much slower than this.

[Benchmark chart: benchmark_results_rt_detr_fp16]

[Benchmark chart: benchmark_results_rt_detr_r101]

It's also worth noting that the custom CUDA kernel for deformable attention (enabled by setting disable_custom_kernels=False in the config) currently gives compiled models a slight inference-speed boost, but with these new optimizations, enabling or disabling it leads to very similar performance.
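For reference, here is a minimal sketch (not from the PR) of how that toggle is set; the config class and flag are as named above, the rest is illustrative:

from transformers import RTDetrConfig, RTDetrForObjectDetection

# disable_custom_kernels=False opts in to the custom CUDA kernel for deformable attention.
# With the optimizations in this PR, both settings should reach very similar compiled speed.
config = RTDetrConfig(disable_custom_kernels=False)
model = RTDetrForObjectDetection(config)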

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@qubvel @amyeroberts @NielsRogge

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SangbumChoi (Contributor) left a comment:

Wow, thanks for the contribution! One simple question out of curiosity: why does the major speed-up occur in fp16?

  grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
  if embed_dim % 4 != 0:
      raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
  pos_dim = embed_dim // 4
- omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+ omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
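As a rough, self-contained sketch of the idea behind this change (function and argument names are illustrative, not the PR's): build the grids and omega directly in the target dtype and device so nothing needs to be cast or moved during inference.

import torch

def sincos_pos_embed_2d(embed_dim, height, width, temperature=10000.0, device="cpu", dtype=torch.float32):
    if embed_dim % 4 != 0:
        raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
    # Create everything in the requested dtype/device up front: no float32 intermediates to cast later.
    grid_w = torch.arange(width, device=device, dtype=dtype)
    grid_h = torch.arange(height, device=device, dtype=dtype)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, device=device, dtype=dtype) / pos_dim
    omega = 1.0 / temperature**omega
    out_w = grid_w.flatten()[..., None] * omega[None]
    out_h = grid_h.flatten()[..., None] * omega[None]
    return torch.cat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None]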
@SangbumChoi (Contributor) commented Sep 11, 2024:

@yonigozlan Is this the main reason the compiled version of RT-DETR sped up so much in FP16? (Since the original omega defaulted to float32 and needed to be converted to fp16 at inference?)

@yonigozlan (Member, Author) replied:

I’m quite new to torch.compile, so take this with a grain of salt, but from my understanding the main reason for the speedup (both in fp32 and fp16) is that the model now has no graph breaks. This means no CPU/GPU transfers inside the model, and it lets torch.compile use CUDA graphs, reducing kernel launch overhead.
I think the main boost in fp16 comes from Tensor Cores, which are optimized for half precision and make GPU operations faster. In fact, GPU operations were already faster in the compiled models in the current version, but the gains were overshadowed by CPU/GPU transfer overhead. So I don't think this change of omega to fp16 made all the difference; it's more a global effect of no graph breaks plus Tensor Cores being optimized for fp16.
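Not from the PR, but a minimal sketch of the setup being discussed (the checkpoint name and compile mode are assumptions): with no graph breaks, "reduce-overhead" mode can capture the whole forward pass in a CUDA graph.

import torch
from transformers import RTDetrForObjectDetection

model = (
    RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd", torch_dtype=torch.float16)
    .to("cuda")
    .eval()
)
# "reduce-overhead" uses CUDA graphs, which only pays off once the forward pass
# has no graph breaks (no hidden CPU/GPU syncs such as .item() calls).
compiled_model = torch.compile(model, mode="reduce-overhead")

pixel_values = torch.randn(1, 3, 640, 640, dtype=torch.float16, device="cuda")
with torch.no_grad():
    for _ in range(3):  # warmup: compilation + CUDA graph capture
        _ = compiled_model(pixel_values=pixel_values)
    outputs = compiled_model(pixel_values=pixel_values)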

@SangbumChoi (Contributor) replied:

Thanks for the kind explanation. I am also a newbie with this compile logic, so I just wanted to know out of curiosity.

Regarding "model now has no graph breaks": I think this could also be the reason, thanks 👍🏼

self.encoder_hidden_dim,
self.positional_encoding_temperature,
device=src_flatten.device,
dtype=src_flatten.dtype,
).to(src_flatten.device, src_flatten.dtype)
@SangbumChoi (Contributor) commented:

I don't think we need .to() if we give device and dtype to the function.

@yonigozlan (Member, Author) replied:

Indeed thanks!

@qubvel (Member) left a comment:

Awesome! The speed-up is impressive 🔥
I left a few comments/questions below:

@@ -737,7 +737,9 @@ def multi_scale_deformable_attention(
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
# Ignore copy
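As an aside, a small sketch (toy shapes, not the PR's exact code) of the graph-break issue this line touches: .item() forces a GPU-to-CPU sync, so torch.compile breaks the graph there, whereas a plain Python list of shapes keeps the split sizes static.

import torch

value = torch.randn(2, 8400, 8, 32)                   # (batch, sum(h*w), num_heads, head_dim)
spatial_shapes_list = [(80, 80), (40, 40), (20, 20)]  # Python ints, known at trace time
# No .item() calls needed, so no device sync and no graph break under torch.compile.
value_list = value.split([h * w for h, w in spatial_shapes_list], dim=1)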
@qubvel (Member) commented:

The "Copy from" looks almost useless for this function with "Ignore copy" a few lines after 🙂
does it make sense to spread changes to other architectures?

@yonigozlan (Member, Author) replied:

That's true, but I'm also working on making the same changes in Deformable DETR, so the "Ignore copy" shouldn't be necessary for long. I'm getting some weird results with compiled Deformable DETR in fp16 though, which I really don't understand.

Comment on lines 864 to 867

# Ignore copy
total_elements = sum([shape[0] * shape[1] for shape in spatial_shapes_list])
if total_elements != sequence_length:
@qubvel (Member) commented:

I am not sure we need "Ignore copy" here; we already modified the method signature above.

Comment on lines 1597 to 1611
def conditional_lru_cache(*lru_args, **lru_kwargs):
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if not torch.compiler.is_compiling():
                # Cache the function only if the model is not being compiled
                cached_func = lru_cache(*lru_args, **lru_kwargs)(func.__get__(self, type(self)))
                return cached_func(*args, **kwargs)
            else:
                # Otherwise, just call the original function
                return func(self, *args, **kwargs)

        return wrapper

    return decorator
@qubvel (Member) commented:

Nice! 🔥

@qubvel (Member) commented:

Maybe name it as compile_compatible_lru_cache?

Comment on lines 1674 to 1709
# We always generate anchors in float32 to preserve equivalence between
# dynamic and static anchor inference
dtype = torch.float32

@conditional_lru_cache(maxsize=32)
def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
@qubvel (Member) commented Sep 12, 2024:

The question here: does lru_cache correctly handle torch.dtype argument for caching?
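One way to sanity-check this outside the PR (a sketch, not the repo's code): torch.dtype objects are hashable, so functools.lru_cache can key on them.

from functools import lru_cache

import torch

@lru_cache(maxsize=8)
def make_zeros(dtype, device):
    return torch.zeros(1, dtype=dtype, device=device)

make_zeros(torch.float16, "cpu")
make_zeros(torch.float16, "cpu")
print(make_zeros.cache_info().hits)  # 1 -> the (dtype, device) pair was served from cache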

@yonigozlan (Member, Author) replied:

Seems like it does. However, there was still a problem with conditional_lru_cache where the cached function would be recreated every time, so there was never a cache hit. It works now though: I do see the cache hits increasing when I print them with print("generate_anchors cache hits: ", self.__getattribute__(f"_cached_{func.__name__}").cache_info().hits) inside the conditional_lru_cache decorator.
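For illustration, a hedged sketch of the kind of fix described here (the PR's actual implementation may differ): build the lru_cache-wrapped bound method once per instance, store it under a _cached_<name> attribute, and reuse it on later calls instead of recreating it every time.

from functools import lru_cache, wraps

import torch

def conditional_lru_cache(*lru_args, **lru_kwargs):
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if torch.compiler.is_compiling():
                # Under torch.compile, skip caching so tracing sees the real call.
                return func(self, *args, **kwargs)
            cache_name = f"_cached_{func.__name__}"
            if not hasattr(self, cache_name):
                # Create the cached bound method once per instance so the cache survives across calls.
                setattr(self, cache_name, lru_cache(*lru_args, **lru_kwargs)(func.__get__(self, type(self))))
            return getattr(self, cache_name)(*args, **kwargs)

        return wrapper

    return decorator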

@amyeroberts (Collaborator) left a comment:

🔥 very nice!

src/transformers/models/rt_detr/modeling_rt_detr.py (outdated suggestion, resolved)
@yonigozlan yonigozlan merged commit 7b1ce63 into huggingface:main Sep 18, 2024
16 checks passed
@yonigozlan yonigozlan mentioned this pull request Sep 19, 2024
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
* modify rt detr to improve inference times when compiled

* Remove redundant "to"

* Fix conditional lru_cache and missing shapes_list

* nit unnecessary list creation

* Fix compile error when ninja not available and custom kernel activated
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Oct 2, 2024
* modify rt detr to improve inference times when compiled

* Remove redundant "to"

* Fix conditional lru_cache and missing shapes_list

* nit unnecessary list creation

* Fix compile error when ninja not available and custom kernel activated