
Add Distillation API to Keras #21572

Open · wants to merge 16 commits into master

Conversation

@divyashreepathihalli (Collaborator) commented Aug 11, 2025

This PR adds a Knowledge Distillation API to Keras, enabling efficient transfer of knowledge from large teacher models to smaller student models.

Key Features

Core Components

  • Distiller: Main distillation model that combines teacher and student models
  • Strategies: Pluggable distillation strategies (LogitsDistillation, FeatureDistillation, MultiOutputDistillation)

Usage Examples

Basic Knowledge Distillation

import keras
from keras.distillation import Distiller, LogitsDistillation

# Create models
teacher = keras.Sequential([...])  # Large, pre-trained model
student = keras.Sequential([...])  # Smaller model to train

# Set up distillation
distiller = Distiller(
    teacher=teacher,
    student=student,
    strategies=[LogitsDistillation(temperature=3.0)],
    alpha=0.7  # 70% student loss, 30% distillation loss
)

# Standard Keras workflow
distiller.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
distiller.fit(x_train, y_train, epochs=10)
predictions = distiller.predict(x_test)
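
For intuition, the loss optimized in the example above follows the standard temperature-scaled distillation recipe: LogitsDistillation softens teacher and student logits with the temperature and compares the resulting distributions, and alpha weights the student's own loss against that distillation term (alpha=0.7 means 70% student loss, 30% distillation loss). The snippet below is an illustrative sketch of that conventional formulation using plain Keras ops; it is not the PR's internal implementation.

from keras import ops

def distillation_loss(teacher_logits, student_logits, temperature=3.0):
    # Temperature-softened teacher/student distributions (Hinton et al., 2015).
    log_p_teacher = ops.log_softmax(teacher_logits / temperature, axis=-1)
    log_p_student = ops.log_softmax(student_logits / temperature, axis=-1)
    # KL(teacher || student), scaled by T**2 so gradient magnitudes stay
    # comparable across temperatures.
    kl = ops.sum(ops.exp(log_p_teacher) * (log_p_teacher - log_p_student), axis=-1)
    return ops.mean(kl) * temperature**2

# With alpha=0.7 as above, the total training loss is combined as:
#   total_loss = 0.7 * student_loss + 0.3 * distillation_loss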

Multi-Output Models

from keras.distillation import MultiOutputDistillation

multi_strategy = MultiOutputDistillation(
    output_strategies={
        0: LogitsDistillation(temperature=3.0, output_index=0),
        1: LogitsDistillation(temperature=2.0, output_index=1)
    },
    weights={0: 1.0, 1: 0.5}
)
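
Feature Distillation (illustrative)

The PR also exposes a FeatureDistillation strategy for matching intermediate activations between teacher and student. Its exact constructor signature is not shown in this description, so the layer-name arguments below are hypothetical placeholders that only sketch the intended usage pattern.

from keras.distillation import Distiller, FeatureDistillation

# Hypothetical argument names; check the API files in this PR for the
# actual signature.
feature_strategy = FeatureDistillation(
    teacher_layer_name="teacher_hidden",  # intermediate teacher layer to match
    student_layer_name="student_hidden",  # corresponding student layer
)

distiller = Distiller(
    teacher=teacher,
    student=student,
    strategies=[feature_strategy],
    alpha=0.5,
)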

@gemini-code-assist (bot, Contributor) left a comment

Summary of Changes

Hello @divyashreepathihalli, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new Knowledge Distillation API to Keras, designed to facilitate the efficient transfer of learned knowledge from larger, pre-trained "teacher" models to smaller "student" models. The API seamlessly integrates with Keras's existing training, evaluation, and prediction workflows, providing a flexible and extensible framework for various distillation techniques.

Highlights

  • New Distiller Model: A core Distiller class is added, which is a keras.Model subclass, enabling the combination and training of teacher and student models within the standard Keras workflow.
  • Pluggable Distillation Strategies: Introduces a BaseDistillationStrategy and three concrete implementations: LogitsDistillation (for softening logits), FeatureDistillation (for intermediate feature matching), and MultiOutputDistillation (for handling models with multiple outputs).
  • Configurable Loss Balancing: The Distiller allows specifying an alpha parameter to balance the contribution of the student's original loss and the distillation loss.
  • Automatic Teacher Freezing: The teacher model is automatically set to non-trainable (trainable=False) during the distillation process to prevent its weights from being updated.
  • Comprehensive Testing: New test files (distiller_test.py and strategies_test.py) are added to ensure the robustness and correctness of the new API, covering initialization, loss computation, and end-to-end workflows.
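
For readers skimming the Highlights, the sketch below shows roughly how these pieces fit together during training. It is reconstructed from the bullet points above; attribute and method names other than student_loss_fn (which appears in the review snippets below) are assumptions, not the PR's actual code.

# Illustrative reconstruction of a Distiller-style training loss.
def combined_loss_sketch(distiller, x, y):
    distiller.teacher.trainable = False                  # automatic teacher freezing
    teacher_out = distiller.teacher(x, training=False)
    student_out = distiller.student(x, training=True)

    student_loss = distiller.student_loss_fn(y, student_out)
    distill_loss = sum(
        strategy.compute_loss(teacher_out, student_out)  # assumed strategy hook
        for strategy in distiller.strategies
    )
    # alpha is the weight on the student's own loss.
    return distiller.alpha * student_loss + (1 - distiller.alpha) * distill_loss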

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces a comprehensive Knowledge Distillation API to Keras, which is a great addition. The implementation is well-structured with a Distiller model and pluggable strategies. My review has identified a few issues: a critical issue with the FeatureDistillation strategy which is not fully implemented, a high-severity issue with an unused temperature parameter in the Distiller class that could mislead users, and a medium-severity issue regarding a simplistic fallback for loss calculation in multi-output scenarios. Addressing these points will improve the robustness and clarity of this new API.

Comment on lines 229 to 236
if isinstance(y_pred, (list, tuple)):
    # For multi-output models, use the first output for student loss.
    # This is a simplified approach for compatibility.
    if isinstance(y, (list, tuple)):
        student_loss = self.student_loss_fn(y[0], y_pred[0])
    else:
        student_loss = self.student_loss_fn(y, y_pred[0])
Severity: medium

The fallback logic for calculating the student loss in _compute_loss for multi-output models is overly simplistic as it always defaults to using the first output (y_pred[0]). This might not align with user expectations for all multi-output scenarios and could lead to incorrect training behavior if model.compile() is not called with a loss that properly handles multiple outputs.

While the primary path using self.compiled_loss is correct, this fallback could be made more robust. Consider raising a more specific error if a multi-output model is used without a compiled loss, or clarifying this behavior more explicitly in the documentation.
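
One possible shape for that stricter check, purely as an illustration (the compiled_loss attribute name is taken from the comment above; this is not a prescription for the PR):

# Illustrative only: fail fast instead of silently training on y_pred[0].
if isinstance(y_pred, (list, tuple)) and self.compiled_loss is None:
    raise ValueError(
        "This Distiller wraps a multi-output student but was compiled "
        "without a loss that handles multiple outputs. Pass a per-output "
        "loss (list or dict) to `Distiller.compile()` instead of relying "
        "on the first-output fallback."
    )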

@divyashreepathihalli marked this pull request as draft on August 11, 2025 at 23:44
@divyashreepathihalli (Collaborator, Author) commented:

/gemini review

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces a well-structured Knowledge Distillation API to Keras, a valuable addition for model compression and knowledge transfer. The implementation of the Distiller model and pluggable strategies is clean and the code is generally of high quality, with good documentation and examples.

I have identified a few significant issues that should be addressed to improve the robustness and completeness of the API:

  1. A critical flaw in FeatureDistillation limits its use to sequential models, which will prevent it from working with more complex architectures.
  2. The Distiller's loss computation for multi-output models has a bug in its fallback logic.
  3. The Distiller model cannot be correctly serialized and deserialized due to an incomplete get_config implementation.
  4. The tests for FeatureDistillation are not comprehensive enough to cover non-sequential model architectures.

Addressing these points will significantly enhance the reliability and usability of this new API. Overall, this is a great contribution.

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces a comprehensive Knowledge Distillation API to Keras, which is a great addition. The implementation is well-structured with a Distiller model and pluggable DistillationStrategy classes. The code is generally clean and includes a good set of tests.

I've identified a few key areas for improvement:

  1. Serialization: The Distiller and MultiOutputDistillation classes are not correctly configured for serialization, which will prevent models using them from being saved and loaded. This is a critical feature in Keras that needs to be addressed.
  2. Feature Extraction Robustness: The FeatureDistillation strategy uses a method for extracting intermediate layer features that is not robust and will fail for models with non-sequential architectures (e.g., ResNets). This limitation needs to be documented and ideally improved.
  3. Code Simplification: There's a small piece of unreachable code in the Distiller's loss computation that can be simplified.

Addressing these points will significantly improve the robustness and usability of this new API.
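
A minimal sketch of the serialization support called out in point 1, assuming Distiller.__init__ accepts the teacher, student, strategies, and alpha arguments shown in the PR description and forwards remaining keyword arguments to keras.Model. This is illustrative, not the PR's actual implementation.

import keras

class Distiller(keras.Model):
    ...  # __init__, call, and train_step elided

    def get_config(self):
        config = super().get_config()
        config.update({
            "teacher": keras.saving.serialize_keras_object(self.teacher),
            "student": keras.saving.serialize_keras_object(self.student),
            "strategies": [
                keras.saving.serialize_keras_object(s) for s in self.strategies
            ],
            "alpha": self.alpha,
        })
        return config

    @classmethod
    def from_config(cls, config):
        for key in ("teacher", "student"):
            config[key] = keras.saving.deserialize_keras_object(config[key])
        config["strategies"] = [
            keras.saving.deserialize_keras_object(s)
            for s in config["strategies"]
        ]
        return cls(**config)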

Comment on lines 283 to 287
if isinstance(y_pred, list) and len(y_pred) > 0:
    # For multi-output, use first output for student loss
    student_loss = self.student_loss_fn(y[0], y_pred[0])
else:
    student_loss = self.student_loss_fn(y, y_pred)
Severity: medium

The isinstance(y_pred, list) check on line 283 is redundant because y_pred is converted to a list on line 263. This makes the else block on line 286 unreachable. The logic can be simplified to directly use the first output for the student loss calculation.

                # Fallback: use student_loss_fn directly
                # For multi-output, use first output for student loss
                student_loss = self.student_loss_fn(y[0], y_pred[0])

@divyashreepathihalli removed the request for review from hertschuh on August 12, 2025 at 00:11
@divyashreepathihalli (Collaborator, Author) commented:

/gemini review

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces a comprehensive and well-designed Knowledge Distillation API to Keras. The implementation is robust, featuring a flexible Distiller class and a set of pluggable distillation strategies that cover common use cases like logits and feature distillation, as well as multi-output models. The code is accompanied by extensive and thorough tests, which is excellent. My feedback includes a couple of suggestions to improve code style in the API files and to enhance the robustness of a test case by removing a broad exception handler. Overall, this is a high-quality contribution that will be a valuable addition to Keras.

@codecov-commenter commented Aug 12, 2025

Codecov Report

❌ Patch coverage is 47.88030% with 209 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.50%. Comparing base (8c55abe) to head (a109178).
⚠️ Report is 10 commits behind head on master.

Files with missing lines                           | Patch % | Lines
keras/src/distillation/strategies.py               | 47.00%  | 98 Missing and 8 partials ⚠️
keras/src/distillation/distiller.py                | 48.14%  | 77 Missing and 21 partials ⚠️
keras/api/_tf_keras/keras/distillation/__init__.py | 0.00%   | 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21572      +/-   ##
==========================================
- Coverage   82.75%   82.50%   -0.25%     
==========================================
  Files         567      571       +4     
  Lines       56471    57108     +637     
  Branches     8818     8945     +127     
==========================================
+ Hits        46730    47116     +386     
- Misses       7580     7790     +210     
- Partials     2161     2202      +41     
Flag             | Coverage Δ
keras            | 82.31% <47.88%> (-0.25%) ⬇️
keras-jax        | 63.54% <47.88%> (-0.24%) ⬇️
keras-numpy      | 57.93% <16.20%> (-0.36%) ⬇️
keras-openvino   | 34.48% <16.20%> (-0.20%) ⬇️
keras-tensorflow | 64.09% <47.88%> (-0.13%) ⬇️
keras-torch      | 63.68% <47.88%> (-0.15%) ⬇️

Flags with carried forward coverage won't be shown.
@fchollet (Collaborator) left a comment

Thanks for the PR! Some quick comments on the API.

@divyashreepathihalli marked this pull request as ready for review on August 19, 2025 at 20:17