Add Distillation API to Keras #21572
Conversation
Summary of Changes
Hello @divyashreepathihalli, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a new Knowledge Distillation API to Keras, designed to facilitate the efficient transfer of learned knowledge from larger, pre-trained "teacher" models to smaller "student" models. The API seamlessly integrates with Keras's existing training, evaluation, and prediction workflows, providing a flexible and extensible framework for various distillation techniques.
Highlights
- New Distiller Model: A core Distiller class is added, which is a keras.Model subclass, enabling the combination and training of teacher and student models within the standard Keras workflow (see the usage sketch after this list).
- Pluggable Distillation Strategies: Introduces a BaseDistillationStrategy and three concrete implementations: LogitsDistillation (for softening logits), FeatureDistillation (for intermediate feature matching), and MultiOutputDistillation (for handling models with multiple outputs).
- Configurable Loss Balancing: The Distiller allows specifying an alpha parameter to balance the contribution of the student's original loss and the distillation loss.
- Automatic Teacher Freezing: The teacher model is automatically set to non-trainable (trainable=False) during the distillation process to prevent its weights from being updated.
- Comprehensive Testing: New test files (distiller_test.py and strategies_test.py) are added to ensure the robustness and correctness of the new API, covering initialization, loss computation, and end-to-end workflows.
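For orientation, here is a minimal usage sketch of how these pieces might fit together. The import paths and the constructor argument names (`teacher`, `student`, `strategy`, `alpha`) are assumptions inferred from the highlights above, not confirmed signatures from this PR.

```python
import keras
from keras import layers

# Assumed import locations; the PR adds code under keras/src/distillation/.
from keras.src.distillation.distiller import Distiller
from keras.src.distillation.strategies import LogitsDistillation

# A pre-trained teacher and a smaller student (toy models for illustration).
teacher = keras.Sequential([layers.Dense(256, activation="relu"), layers.Dense(10)])
student = keras.Sequential([layers.Dense(64, activation="relu"), layers.Dense(10)])

# Argument names below are assumptions for illustration only.
distiller = Distiller(
    teacher=teacher,  # frozen automatically (trainable=False) per the highlights
    student=student,
    strategy=LogitsDistillation(temperature=3.0),  # softens the logits
    alpha=0.5,  # balances student loss vs. distillation loss
)
# From here the standard Keras workflow applies: compile(), fit(),
# evaluate(), and predict() on the distiller.
```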
Code Review
This pull request introduces a comprehensive Knowledge Distillation API to Keras, which is a great addition. The implementation is well-structured with a `Distiller` model and pluggable strategies. My review has identified a few issues: a critical issue with the `FeatureDistillation` strategy, which is not fully implemented; a high-severity issue with an unused `temperature` parameter in the `Distiller` class that could mislead users; and a medium-severity issue regarding a simplistic fallback for loss calculation in multi-output scenarios. Addressing these points will improve the robustness and clarity of this new API.
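As background for the `temperature` discussion, this is the role temperature typically plays in logits distillation (Hinton et al., 2015), written as a generic sketch with `keras.ops`; it is not the PR's implementation.

```python
from keras import ops

def soft_distillation_loss(teacher_logits, student_logits, temperature=3.0):
    # Soften both distributions with the temperature, then measure how far
    # the student's softened predictions are from the teacher's using a KL
    # divergence. The T**2 factor keeps gradient magnitudes comparable
    # across temperatures (Hinton et al., 2015).
    teacher_probs = ops.softmax(teacher_logits / temperature, axis=-1)
    student_log_probs = ops.log_softmax(student_logits / temperature, axis=-1)
    kl = ops.sum(
        teacher_probs * (ops.log(teacher_probs + 1e-7) - student_log_probs),
        axis=-1,
    )
    return ops.mean(kl) * temperature**2
```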
keras/src/distillation/distiller.py
Outdated
if isinstance(y_pred, (list, tuple)):
    # For multi-output models, use the first output for student
    # loss
    # This is a simplified approach for compatibility
    if isinstance(y, (list, tuple)):
        student_loss = self.student_loss_fn(y[0], y_pred[0])
    else:
        student_loss = self.student_loss_fn(y, y_pred[0])
The fallback logic for calculating the student loss in `_compute_loss` for multi-output models is overly simplistic, as it always defaults to using the first output (`y_pred[0]`). This might not align with user expectations for all multi-output scenarios and could lead to incorrect training behavior if `model.compile()` is not called with a loss that properly handles multiple outputs.
While the primary path using `self.compiled_loss` is correct, this fallback could be made more robust. Consider raising a more specific error if a multi-output model is used without a compiled loss, or clarifying this behavior more explicitly in the documentation.
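A sketch of what the suggested stricter behavior could look like inside `_compute_loss`; it reuses the attribute names quoted above, but it is the reviewer's suggestion, not code from the PR.

```python
# Instead of silently falling back to the first output, fail loudly when a
# multi-output student is used without a compiled loss.
if isinstance(y_pred, (list, tuple)) and getattr(self, "compiled_loss", None) is None:
    raise ValueError(
        "The student model returned multiple outputs, but no compiled loss "
        "was provided. Call `compile()` with a loss (or dict of losses) "
        "that covers every output, or use a single-output student model."
    )
```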
/gemini review
Code Review
This pull request introduces a well-structured Knowledge Distillation API to Keras, a valuable addition for model compression and knowledge transfer. The implementation of the `Distiller` model and pluggable strategies is clean, and the code is generally of high quality, with good documentation and examples.
I have identified a few significant issues that should be addressed to improve the robustness and completeness of the API:
- A critical flaw in `FeatureDistillation` limits its use to sequential models, which will prevent it from working with more complex architectures (see the feature-extraction sketch after this comment).
- The `Distiller`'s loss computation for multi-output models has a bug in its fallback logic.
- The `Distiller` model cannot be correctly serialized and deserialized due to an incomplete `get_config` implementation.
- The tests for `FeatureDistillation` are not comprehensive enough to cover non-sequential model architectures.
Addressing these points will significantly enhance the reliability and usability of this new API. Overall, this is a great contribution.
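For context on the `FeatureDistillation` concern, a common Keras pattern for pulling intermediate features that also works for non-sequential (functional) models is to build a sub-model over the existing graph. This is a generic sketch, not the PR's implementation; the function and layer names are placeholders.

```python
import keras

def intermediate_feature_model(model, layer_name):
    # Map the model's own inputs to the named layer's output. Because this
    # reuses the functional graph, it works for architectures with branches
    # and skip connections, not just layer-by-layer sequential stacks.
    return keras.Model(
        inputs=model.inputs,
        outputs=model.get_layer(layer_name).output,
    )

# Usage sketch (layer names are hypothetical):
# teacher_features = intermediate_feature_model(teacher, "block3_out")(x_batch)
# student_features = intermediate_feature_model(student, "mid_dense")(x_batch)
```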
Code Review
This pull request introduces a comprehensive Knowledge Distillation API to Keras, which is a great addition. The implementation is well-structured with a `Distiller` model and pluggable `DistillationStrategy` classes. The code is generally clean and includes a good set of tests.
I've identified a few key areas for improvement:
- Serialization: The `Distiller` and `MultiOutputDistillation` classes are not correctly configured for serialization, which will prevent models using them from being saved and loaded. This is a critical feature in Keras that needs to be addressed (a generic sketch of the serialization pattern follows this comment).
- Feature Extraction Robustness: The `FeatureDistillation` strategy uses a method for extracting intermediate layer features that is not robust and will fail for models with non-sequential architectures (e.g., ResNets). This limitation needs to be documented and ideally improved.
- Code Simplification: There's a small piece of unreachable code in the `Distiller`'s loss computation that can be simplified.
Addressing these points will significantly improve the robustness and usability of this new API.
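Regarding the serialization point, the usual Keras contract is to implement both `get_config` and `from_config`, serializing nested objects explicitly. The sketch below assumes `Distiller` stores `teacher`, `student`, `strategy`, and `alpha` attributes with matching constructor arguments; it is a generic illustration of the pattern, not the PR's code.

```python
import keras

class DistillerSerializationSketch(keras.Model):
    # Only the serialization-related pieces are sketched here.
    def __init__(self, teacher, student, strategy, alpha=0.5, **kwargs):
        super().__init__(**kwargs)
        self.teacher = teacher
        self.student = student
        self.strategy = strategy
        self.alpha = alpha

    def get_config(self):
        # Serialize nested models/strategy so the distiller can be rebuilt
        # on load; plain attributes like alpha are stored directly.
        config = super().get_config()
        config.update({
            "teacher": keras.saving.serialize_keras_object(self.teacher),
            "student": keras.saving.serialize_keras_object(self.student),
            "strategy": keras.saving.serialize_keras_object(self.strategy),
            "alpha": self.alpha,
        })
        return config

    @classmethod
    def from_config(cls, config):
        for key in ("teacher", "student", "strategy"):
            config[key] = keras.saving.deserialize_keras_object(config[key])
        return cls(**config)
```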
keras/src/distillation/distiller.py
Outdated
if isinstance(y_pred, list) and len(y_pred) > 0:
    # For multi-output, use first output for student loss
    student_loss = self.student_loss_fn(y[0], y_pred[0])
else:
    student_loss = self.student_loss_fn(y, y_pred)
The `isinstance(y_pred, list)` check on line 283 is redundant because `y_pred` is converted to a list on line 263. This makes the `else` block on line 286 unreachable. The logic can be simplified to directly use the first output for the student loss calculation.
# Fallback: use student_loss_fn directly
# For multi-output, use first output for student loss
student_loss = self.student_loss_fn(y[0], y_pred[0])
/gemini review
Code Review
This pull request introduces a comprehensive and well-designed Knowledge Distillation API to Keras. The implementation is robust, featuring a flexible `Distiller` class and a set of pluggable distillation strategies that cover common use cases like logits and feature distillation, as well as multi-output models. The code is accompanied by extensive and thorough tests, which is excellent. My feedback includes a couple of suggestions to improve code style in the API files and to enhance the robustness of a test case by removing a broad exception handler. Overall, this is a high-quality contribution that will be a valuable addition to Keras.
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21572      +/-   ##
==========================================
- Coverage   82.75%   82.50%   -0.25%
==========================================
  Files         567      571       +4
  Lines       56471    57108     +637
  Branches     8818     8945     +127
==========================================
+ Hits        46730    47116     +386
- Misses       7580     7790     +210
- Partials     2161     2202      +41
Thanks for the PR! Some quick comments on the API.
This PR adds a Knowledge Distillation API to Keras, enabling efficient transfer of knowledge from large teacher models to smaller student models.
Key Features
Core Components
Usage Examples
Basic Knowledge Distillation
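A sketch of what a basic run might look like, assuming the `Distiller` and `LogitsDistillation` names from this PR; the import paths and constructor arguments are illustrative assumptions rather than confirmed API.

```python
import numpy as np
import keras
from keras import layers

# Assumed import locations; the PR adds code under keras/src/distillation/.
from keras.src.distillation.distiller import Distiller
from keras.src.distillation.strategies import LogitsDistillation

# Toy data and models purely for illustration.
x_train = np.random.rand(256, 32).astype("float32")
y_train = np.random.randint(0, 10, size=(256,))

teacher = keras.Sequential([layers.Dense(128, activation="relu"), layers.Dense(10)])
teacher.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
teacher.fit(x_train, y_train, epochs=1, verbose=0)  # stands in for a real pre-trained teacher

student = keras.Sequential([layers.Dense(32, activation="relu"), layers.Dense(10)])

# Argument names are assumptions inferred from the PR description.
distiller = Distiller(
    teacher=teacher,
    student=student,
    strategy=LogitsDistillation(temperature=3.0),
    alpha=0.5,
)
distiller.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
distiller.fit(x_train, y_train, epochs=3, verbose=0)
distiller.evaluate(x_train, y_train, verbose=0)
```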
Multi-Output Models
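A sketch for the multi-output case. `MultiOutputDistillation` is named in this PR, but its constructor options are not visible on this page, so every argument shown here is a placeholder.

```python
import keras
from keras import layers

# Assumed import locations; the PR adds code under keras/src/distillation/.
from keras.src.distillation.distiller import Distiller
from keras.src.distillation.strategies import MultiOutputDistillation

# A functional teacher and student, each with two output heads.
t_in = keras.Input(shape=(32,))
t_h = layers.Dense(128, activation="relu")(t_in)
teacher = keras.Model(
    t_in, [layers.Dense(10, name="cls")(t_h), layers.Dense(1, name="reg")(t_h)]
)

s_in = keras.Input(shape=(32,))
s_h = layers.Dense(32, activation="relu")(s_in)
student = keras.Model(
    s_in, [layers.Dense(10, name="cls")(s_h), layers.Dense(1, name="reg")(s_h)]
)

# The bare MultiOutputDistillation() call is a placeholder; its actual
# configuration options are not shown in this conversation.
distiller = Distiller(
    teacher=teacher,
    student=student,
    strategy=MultiOutputDistillation(),
    alpha=0.5,
)
# Per the review discussion, multi-output students should be compiled with a
# loss that covers every output, e.g. a list or dict of per-output losses.
distiller.compile(
    optimizer="adam",
    loss=[
        keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        keras.losses.MeanSquaredError(),
    ],
)
```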