Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADD] EKFAC #127

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft

[ADD] EKFAC #127

wants to merge 16 commits into from

Conversation

runame
Copy link
Collaborator

@runame runame commented Sep 17, 2024

Implements EKFAC (and its inverse) support (resolves #116).

I think we should at some point refactor KFACLinearOperator and KFACInverseLinearOperator to inherit from KroneckerProductLinearOperator and EigendecomposedKroneckerProductLinearOperator (or similar) classes since torch_matmat and other methods can be shared. Also, currently KFACInverseLinearOperator doesn't support trace, det, etc. properties which can also be shared. I created #126 for this.

@runame runame added the enhancement New feature or request label Sep 17, 2024
@runame runame requested a review from f-dangel September 17, 2024 04:35
@runame runame self-assigned this Sep 17, 2024
@coveralls
Copy link

coveralls commented Sep 17, 2024

Pull Request Test Coverage Report for Build 10975103378

Details

  • 207 of 210 (98.57%) changed or added relevant lines in 2 files are covered.
  • 2 unchanged lines in 1 file lost coverage.
  • Overall coverage increased (+0.5%) to 89.5%

Changes Missing Coverage Covered Lines Changed/Added Lines %
curvlinops/inverse.py 36 37 97.3%
curvlinops/kfac.py 171 173 98.84%
Files with Coverage Reduction New Missed Lines %
curvlinops/kfac.py 2 93.71%
Totals Coverage Status
Change from base Build 10408891176: 0.5%
Covered Lines: 1449
Relevant Lines: 1619

💛 - Coveralls

@runame
Copy link
Collaborator Author

runame commented Sep 20, 2024

@f-dangel One thing that is not tested and that could be wrong is the per-example gradient computation when there is weight sharing.

Copy link
Owner

@f-dangel f-dangel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gave some refactoring comments.

Overall, while reading through the diff, I was wondering if there is a better way to separate the eigenvalue correction of EKFAC. Ideally, I was imagining we can keep KFAC as is and implement EKFAC separately, e.g. by inheriting EKFAC from KFAC.

Do you have a good idea how to do this? Otherwise I believe this PR will make the code a lot more complex, and long-term complicate extending KFAC, especially for developers that are less familiar with EKFAC.

curvlinops/inverse.py Outdated Show resolved Hide resolved
curvlinops/inverse.py Outdated Show resolved Hide resolved
curvlinops/kfac.py Outdated Show resolved Hide resolved
curvlinops/kfac.py Show resolved Hide resolved
curvlinops/kfac.py Outdated Show resolved Hide resolved
curvlinops/kfac.py Outdated Show resolved Hide resolved
Comment on lines +706 to +707
# Delete the cached activations
self._cached_activations.clear()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these cached activations concatenated over batches? Why don't they have to be cleared inside the data loop?

Copy link
Collaborator Author

@runame runame Sep 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No they will just be overwritten, this avoids redundant clearing of the cache before it is filled up again anyway. Do you think it is cleaner to clear the cache explicitly every iteration?

"d_out1 d_out2, ... d_out1 d_in1, d_in1 d_in2 -> ... d_out2 d_in2",
)
.square_()
.sum(dim=0)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this sum correct, or do you want to sum out the ... of the einsum result?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the above variable, I would change .sum(dim=0) into .sum(list(range(shared_axes)))

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also check the else branch below for the same suggestions.

per_example_gradient = einsum(
g,
self._cached_activations[module_name],
"shared d_out, shared d_in -> shared d_out d_in",
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shared should be replaced by ...

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then, add a line shared_axes = g.ndim - 2.

curvlinops/kfac.py Outdated Show resolved Hide resolved
@runame
Copy link
Collaborator Author

runame commented Oct 2, 2024

Will continue this PR in ~2 weeks.

@runame runame marked this pull request as draft November 29, 2024 03:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement EKFAC
3 participants