129 changes: 129 additions & 0 deletions finetune/models/cogvideox_t2v/kd_trainer.py
@@ -0,0 +1,129 @@
import torch
import torchvision
from .lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register
from typing_extensions import override
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models import vgg16
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
import torch.nn as nn
from collections import OrderedDict

class VGG16BackboneWrapper(nn.Module):
    def __init__(self, vgg_features):
        super().__init__()
        self.features = vgg_features
        self.out_channels = 512

    def forward(self, x):
        x = self.features(x)
        # FasterRCNN expects the backbone to return a dict of feature maps.
        return OrderedDict([("0", x)])
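
# A minimal shape check for the wrapper (a sketch, not part of the trainer: with
# the final max-pool removed, VGG16 downsamples by 16x, so a 224x224 input yields
# the single 512-channel 14x14 feature map FasterRCNN reads from key "0"):
#
#   backbone = VGG16BackboneWrapper(vgg16(weights=None).features[:-1])
#   feats = backbone(torch.randn(1, 3, 224, 224))
#   assert feats["0"].shape == (1, 512, 14, 14)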

class CogVideoXT2VKdTrainer(CogVideoXT2VLoraTrainer):
    # Keep the VAE loaded (only unload the text encoder) so it is available in compute_loss.
    UNLOAD_LIST = ["text_encoder"]

    def __init__(self, args):
        super().__init__(args)
        self.teacher_model = self.load_teacher_model()

    def load_teacher_model(self):
        teacher_model_path = getattr(self.args, "teacher_model_path", None)
        if not teacher_model_path:
            print("Warning: teacher_model_path is not provided. Knowledge distillation will be skipped.")
            return None

        try:
            # Create a VGG16-based Faster R-CNN model.
            # 1. VGG16 backbone
            vgg_features = vgg16(weights=None).features
            # torchvision's VGG16 `features` module ends with a max-pool layer.
            # Many Faster R-CNN implementations with a VGG backbone drop this last
            # max-pool to match the original Caffe model, so we remove it here.
            backbone_features = vgg_features[:-1]
            backbone = VGG16BackboneWrapper(backbone_features)

            # 2. RPN
            anchor_generator = AnchorGenerator(sizes=((128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),))

            # 3. RoI heads
            roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=["0"], output_size=7, sampling_ratio=2)

            # The converted model has fc6 and fc7 layers, which correspond to a TwoMLPHead.
            # The backbone output flattens to 512 * 7 * 7 = 25088 features.
            box_head = TwoMLPHead(in_channels=25088, representation_size=4096)

            # Fall back to 91 classes (the COCO default) when the argument is unset.
            num_classes = getattr(self.args, "teacher_model_num_classes", None) or 91
            box_predictor = FastRCNNPredictor(in_channels=4096, num_classes=num_classes)

            # 4. Faster R-CNN model
            # FasterRCNN raises a ValueError when both `box_predictor` and `num_classes`
            # are given, so the class count enters only through the predictor above.
            model = FasterRCNN(
                backbone,
                rpn_anchor_generator=anchor_generator,
                box_roi_pool=roi_pooler,
                box_head=box_head,
                box_predictor=box_predictor,
            )

            # Load the pre-trained weights from the converted file.
            print(f"Loading teacher model from: {teacher_model_path}")
            state_dict = torch.load(teacher_model_path, map_location="cpu")
            model.load_state_dict(state_dict)
            model.eval()
            model.to(self.accelerator.device)
            print("Teacher model loaded successfully.")
            return model
        except Exception as e:
            print(f"Error loading teacher model: {e}")
            return None

    @override
    def compute_loss(self, batch) -> torch.Tensor:
        # Get the original diffusion loss
        diffusion_loss = super().compute_loss(batch)

        if self.teacher_model is None:
            return diffusion_loss

        latents = batch["encoded_videos"]

        # The KD penalty below is computed from detection counts, which cannot
        # backpropagate, so decode and run the teacher without building a graph.
        with torch.no_grad():
            # Decode the latents to get video frames
            video_frames = self.components.vae.decode(latents / self.components.vae.config.scaling_factor).sample

            # The VAE output is in the range [-1, 1]; normalize it to [0, 1] for the teacher model.
            video_frames = (video_frames + 1) / 2

            # [B, C, F, H, W] -> [B, F, C, H, W] so we can iterate over frames.
            video_frames = video_frames.permute(0, 2, 1, 3, 4)

            # KD penalty: the fraction of decoded frames in which the teacher detects nothing.
            kd_loss = 0
            num_frames_processed = 0
            for i in range(video_frames.shape[0]):  # For each video in the batch
                # Faster R-CNN expects a list of float [C, H, W] tensors.
                frames = [frame.float() for frame in video_frames[i]]
                if not frames:
                    continue

                num_frames_processed += len(frames)
                teacher_output = self.teacher_model(frames)

                for output in teacher_output:
                    # Penalize frames with no detected objects.
                    if len(output["boxes"]) == 0:
                        kd_loss += 1

        if num_frames_processed > 0:
            kd_loss /= num_frames_processed
        else:
            kd_loss = 0

        # Combine the losses
        kd_loss_weight = getattr(self.args, "kd_loss_weight", 0.1)
        # kd_loss is a plain Python number, so this tensor carries no gradient;
        # it shifts the reported loss, while only diffusion_loss drives the update.
        kd_loss_tensor = torch.tensor(kd_loss, device=self.accelerator.device, dtype=diffusion_loss.dtype)
        total_loss = diffusion_loss + kd_loss_weight * kd_loss_tensor

        self.accelerator.log({"kd_loss": kd_loss, "diffusion_loss": diffusion_loss.item(), "total_loss": total_loss.item()})

        return total_loss

register("cogvideox-t2v", "kd", CogVideoXT2VKdTrainer)
2 changes: 1 addition & 1 deletion finetune/models/utils.py
@@ -6,7 +6,7 @@
SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}


def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls: Trainer):
def register(model_name: str, training_type: Literal["lora", "sft", "kd"], trainer_cls: Trainer):
"""Register a model and its associated functions for a specific training type.

Args:
12 changes: 11 additions & 1 deletion finetune/schemas/args.py
@@ -12,7 +12,12 @@ class Args(BaseModel):
    model_path: Path
    model_name: str
    model_type: Literal["i2v", "t2v"]
    training_type: Literal["lora", "sft"] = "lora"
    training_type: Literal["lora", "sft", "kd"] = "lora"

    ########## KD ##########
    teacher_model_path: Path | None = None
    teacher_model_num_classes: int | None = None
    kd_loss_weight: float = 0.1

    ########## Output ##########
    output_dir: Path = Path("train_results/{:%Y-%m-%d-%H-%M-%S}".format(datetime.datetime.now()))
@@ -239,6 +244,11 @@ def parse_args(cls):
        # Validation
        parser.add_argument("--do_validation", type=lambda x: x.lower() == 'true', default=False)
        parser.add_argument("--validation_steps", type=int, default=None)

        # KD parameters
        parser.add_argument("--teacher_model_path", type=str, default=None)
        parser.add_argument("--teacher_model_num_classes", type=int, default=None)
        parser.add_argument("--kd_loss_weight", type=float, default=0.1)
        parser.add_argument("--validation_dir", type=str, default=None)
        parser.add_argument("--validation_prompts", type=str, default=None)
        parser.add_argument("--validation_images", type=str, default=None)