
Removed unnecessary transpose in Switch Transformer Routing #33582

Merged
merged 1 commit into from
Oct 4, 2024

Conversation

karan-uppal3
Contributor

What does this PR do?

There is an unnecessary transpose operation at link in the Switch Transformers implementation. The router logits tensor has shape [batch_size, seq_len, num_experts]; it undergoes .transpose(1,2) to shape [batch_size, num_experts, seq_len] and is later flattened with .reshape(batch_size * seq_len, num_experts). Because of the transpose, the reshape no longer aligns each row with a single token, so the routing gives incorrect results. Simply removing the .transpose(1,2) gives the correct output. A simple reproducible example can be found in #33463.
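The shape mismatch can be demonstrated in a few lines. This is a minimal NumPy sketch (used here to avoid a torch dependency; NumPy's transpose(0, 2, 1) is the equivalent of PyTorch's .transpose(1, 2)), with toy sizes, not the actual model code:

```python
import numpy as np

B, S, E = 2, 3, 4  # batch size, sequence length, number of experts (toy values)
logits = np.arange(B * S * E).reshape(B, S, E)  # router logits, one row of E scores per token

# Correct: flatten tokens directly; row i holds the expert scores for token i.
correct = logits.reshape(B * S, E)

# Buggy: swapping the seq_len and num_experts axes first, then reshaping,
# interleaves scores from different tokens into the same row.
buggy = logits.transpose(0, 2, 1).reshape(B * S, E)

print(np.array_equal(correct, buggy))  # False
print(correct[0].tolist())  # [0, 1, 2, 3] -- token 0's expert scores
print(buggy[0].tolist())    # [0, 4, 8, 1] -- scores mixed across tokens
```

Since every downstream step (softmax over experts, argmax routing) operates row-wise on the flattened tensor, the transposed version routes tokens based on scrambled scores.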

Fixes #33463

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@LysandreJik @ArthurZucker @younesbelkada

@LysandreJik
Member

Thanks for the PR @karan-uppal3 !

cc @ArthurZucker when you have a second

Collaborator

@ArthurZucker ArthurZucker left a comment


Hey! We have run heavy integration tests here:

class SwitchTransformerRouterTest(unittest.TestCase):

Could you make sure the existing expected outputs are either:

  • wrong (meaning the outputs we have today are just wrong), or
  • right (meaning this change does not influence the output)?

Otherwise this makes sense; quite surprised that we did not catch this before 😢


@karan-uppal3
Contributor Author

Hey @ArthurZucker! I went through the tests and found unit tests for the SwitchTransformersTop1Router module but not for the SwitchTransformersSparseMLP module, which is where the unnecessary transpose exists (link). This is why the existing pytest run passes unchanged.

@ArthurZucker
Collaborator

Did you run the tests with RUN_SLOW=1?

class SwitchTransformerModelIntegrationTests(unittest.TestCase):
are the relevant tests!
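For context, transformers gates its heavy integration tests behind the RUN_SLOW environment variable. The real decorator lives in transformers.testing_utils; the helper below is a simplified stand-in that sketches the same skip-unless-enabled pattern:

```python
import os
import unittest

# Make the demo deterministic regardless of the caller's environment.
os.environ.pop("RUN_SLOW", None)

def slow(test_case):
    """Simplified stand-in for transformers' `@slow`: skip unless RUN_SLOW is set."""
    run_slow = os.environ.get("RUN_SLOW", "").lower() in ("1", "true", "yes")
    return unittest.skipUnless(run_slow, "set RUN_SLOW=1 to run slow tests")(test_case)

class ExampleIntegrationTests(unittest.TestCase):
    @slow
    def test_heavy_integration(self):
        self.assertTrue(True)

suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleIntegrationTests)
result = unittest.TextTestRunner(stream=open(os.devnull, "w"), verbosity=0).run(suite)
print(len(result.skipped))  # 1 -- the slow test was skipped because RUN_SLOW is unset
```

This is why a plain pytest invocation reports the integration tests as skipped (`s` in the output), while prefixing the command with RUN_SLOW=1 actually executes them.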

@ArthurZucker ArthurZucker mentioned this pull request Sep 27, 2024
5 tasks
@karan-uppal3
Contributor Author

karan-uppal3 commented Oct 1, 2024

Hey @ArthurZucker, I just ran the tests with RUN_SLOW=1; this is the output I am getting:

==================================================================================================== test session starts ====================================================================================================
platform linux -- Python 3.8.12, pytest-7.4.4, pluggy-1.0.0
rootdir: /home/aiscuser/transformers
configfile: pyproject.toml
plugins: rich-0.1.1, timeout-2.3.1, dash-2.18.1, hypothesis-6.112.2, xdist-3.6.1, pythonpath-0.7.4, cov-3.0.0
collected 299 items                                                                                                                                                                                                         

tests/models/switch_transformers/test_modeling_switch_transformers.py ....s.........s.........s.......sssss.sss.ssssssssssss.ss.ss..ss........ssss......sss...........s.s...ssss.ssssssssssssssssssssssssss..ss..ssss [ 47%]
ssss..ssssssssssss...s.s...........sss.sss.sss.sss.........s...s...ssssss.sssssssssss....ssss......ss.........s.s....ssss........ssssss.sss.sssssss......sss                                                          [100%]

===================================================================================================== warnings summary ======================================================================================================
../../../opt/conda/lib/python3.8/site-packages/sklearn/utils/multiclass.py:14
  /opt/conda/lib/python3.8/site-packages/sklearn/utils/multiclass.py:14: DeprecationWarning: Please use `spmatrix` from the `scipy.sparse` namespace, the `scipy.sparse.base` namespace is deprecated.
    from scipy.sparse.base import spmatrix

src/transformers/deepspeed.py:24
  /home/aiscuser/transformers/src/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
    warnings.warn(

tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_attention_outputs
  /home/aiscuser/transformers/src/transformers/generation/configuration_utils.py:775: UserWarning: `return_dict_in_generate` is NOT set to `True`, but `output_attentions` is. When `return_dict_in_generate` is not `True`, `output_attentions` is ignored.
    warnings.warn(

tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_batching_equivalence
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_headmasking
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_hidden_states_output
  /home/aiscuser/transformers/src/transformers/generation/configuration_utils.py:775: UserWarning: `return_dict_in_generate` is NOT set to `True`, but `output_hidden_states` is. When `return_dict_in_generate` is not `True`, `output_hidden_states` is ignored.
    warnings.warn(

tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_disk_offload_bin
  /home/aiscuser/.local/lib/python3.8/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
    return self.fget.__get__(instance, owner)()

tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_feature_extraction
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_feature_extraction_fp16
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_summarization
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_summarization_fp16
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_text2text_generation
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_text2text_generation_fp16
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_translation
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_translation_fp16
  /home/aiscuser/transformers/src/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
    warnings.warn(

tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_summarization
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_summarization_fp16
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_text2text_generation
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_text2text_generation_fp16
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_translation
tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_pipeline_translation_fp16
  /home/aiscuser/transformers/src/transformers/generation/utils.py:1230: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================= 145 passed, 154 skipped, 21 warnings in 138.15s (0:02:18) =================================================================================

@ArthurZucker ArthurZucker merged commit 614660f into huggingface:main Oct 4, 2024
16 checks passed
@ArthurZucker
Collaborator

Cool! I tested the skipped tests (the generation ones) locally as well; they are not affected.

@ArthurZucker
Collaborator

thanks!

NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Oct 21, 2024