
Migration to latest versions of torch & flash-attn to solve warmstart/fsdp2/weight tying problem #384

Draft · wants to merge 7 commits into main

Conversation

flxst (Member) commented on Jul 8, 2025

What does this PR do?

This PR addresses #381. The problem can be traced back to a bug in torch 2.6 related to the fact that we flatten the optimizer state dict here and here.
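
For context, the flattening referred to above is the kind of operation exposed by `torch.distributed.checkpoint`; the following is an illustrative sketch only (not the actual Modalities code), showing how an optimizer state dict is typically flattened with that API:

```python
# Illustrative sketch, not the actual Modalities implementation:
# obtain a flattened optimizer state dict via torch.distributed.checkpoint,
# the kind of code path affected by the torch 2.6 bug described above.
import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
)


def flattened_optimizer_state(model: torch.nn.Module, optimizer: torch.optim.Optimizer) -> dict:
    # flatten_optimizer_state_dict=True keys the optimizer state by fully
    # qualified parameter names instead of nested parameter-group indices.
    options = StateDictOptions(flatten_optimizer_state_dict=True)
    return get_optimizer_state_dict(model, optimizer, options=options)
```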

A solution is to simply migrate to torch 2.7, which in turn requires migrating flash-attn to version 2.8.

This PR includes both migrations, along with minor adjustments to the warmstart config files.

Unit tests pass with GitHub Actions.

General Changes

  • None

Breaking Changes

  • None

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

flxst requested a review from le1nux on July 8, 2025, 14:13
flxst (Member, Author) commented on Jul 11, 2025

Everything seems to work when conda is used. However, with uv or python -m venv, the installation of flash-attn==2.8.0.post2 fails, as reported here.

flxst marked this pull request as draft on July 11, 2025, 11:38
flxst (Member, Author) commented on Jul 18, 2025

Follow-up problem: Dao-AILab/flash-attention#1708

  • torch==2.6.0 & flash-attn==2.7.4.post1: works
  • torch==2.7.1 & flash-attn==2.8.0.post2: sometimes fails (depending on platform)
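
For reference, here is a hypothetical, minimal import check (not part of this PR) that can be run on a given platform to see whether the installed torch / flash-attn pair loads cleanly:

```python
# Hypothetical helper, not part of this PR: check that the installed
# torch / flash-attn combination imports cleanly on the current platform
# before launching a training run.
import torch


def check_flash_attn_import() -> None:
    try:
        import flash_attn
        from flash_attn import flash_attn_func  # noqa: F401
    except ImportError as exc:
        raise RuntimeError(
            f"flash-attn failed to import against torch=={torch.__version__}: {exc}"
        ) from exc
    print(f"OK: torch=={torch.__version__}, flash-attn=={flash_attn.__version__}")


if __name__ == "__main__":
    check_flash_attn_import()
```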
