diff --git a/examples/pytorch/README.md b/examples/pytorch/README.md index 4e318b3edb920c..2f2766a04e493d 100644 --- a/examples/pytorch/README.md +++ b/examples/pytorch/README.md @@ -48,7 +48,7 @@ Coming soon! | [**`semantic-segmentation`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/semantic-segmentation) | [SCENE_PARSE_150](https://huggingface.co/datasets/scene_parse_150) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/semantic_segmentation.ipynb) | [**`object-detection`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection) | [CPPE-5](https://huggingface.co/datasets/cppe-5) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/object_detection.ipynb) | [**`instance-segmentation`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/instance-segmentation) | [ADE20K sample](https://huggingface.co/datasets/qubvel-hf/ade20k-mini) | ✅ | ✅ |✅ | - +| [**`zero-shot`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/zero-shot) | [CPPE-5](https://huggingface.co/datasets/cppe-5) | ✅ | ✅ |✅ | / ## Running quick tests diff --git a/examples/pytorch/object-detection/README.md b/examples/pytorch/object-detection/README.md index ab474f76075305..50d8bcc20fe1a2 100644 --- a/examples/pytorch/object-detection/README.md +++ b/examples/pytorch/object-detection/README.md @@ -69,7 +69,7 @@ python run_object_detection.py \ `--eval_do_concat_batches false` is required for correct evaluation of detection models; `--ignore_mismatched_sizes true` is required to load detection model for finetuning with different number of classes. -The resulting model can be seen here: https://huggingface.co/qubvel-hf/qubvel-hf/detr-resnet-50-finetuned-10k-cppe5. The corresponding Weights and Biases report [here](https://api.wandb.ai/links/qubvel-hf-co/bnm0r5ex). Note that it's always advised to check the original paper to know the details regarding training hyperparameters. Hyperparameters for current example were not tuned. To improve model quality you could try: +The resulting model can be seen here: https://huggingface.co/qubvel-hf/detr-resnet-50-finetuned-10k-cppe5. The corresponding Weights and Biases report [here](https://api.wandb.ai/links/qubvel-hf-co/bnm0r5ex). Note that it's always advised to check the original paper to know the details regarding training hyperparameters. Hyperparameters for current example were not tuned. 
To improve model quality you could try: - changing image size parameters (`--shortest_edge`/`--longest_edge`) - changing training parameters, such as learning rate, batch size, warmup, optimizer and many more (see [TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments)) - adding more image augmentations (we created a helpful [HF Space](https://huggingface.co/spaces/qubvel-hf/albumentations-demo) to choose some) diff --git a/examples/pytorch/test_pytorch_examples.py b/examples/pytorch/test_pytorch_examples.py index c609ee860c728f..5bd17a0eb982cc 100644 --- a/examples/pytorch/test_pytorch_examples.py +++ b/examples/pytorch/test_pytorch_examples.py @@ -50,6 +50,7 @@ "semantic-segmentation", "object-detection", "instance-segmentation", + "zero-shot", ] ] sys.path.extend(SRC_DIRS) @@ -76,6 +77,7 @@ import run_swag import run_translation import run_wav2vec2_pretraining_no_trainer + import run_zero_shot_object_detection logging.basicConfig(level=logging.DEBUG) @@ -678,3 +680,31 @@ def test_run_instance_segmentation(self): run_instance_segmentation.main() result = get_results(tmp_dir) self.assertGreaterEqual(result["test_map"], 0.1) + + @patch.dict(os.environ, {"WANDB_DISABLED": "true"}) + def test_zero_shotrun_object_detection(self): + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_zero_shot_object_detection.py + --model_name_or_path IDEA-Research/grounding-dino-tiny + --output_dir {tmp_dir} + --dataset_name qubvel-hf/cppe-5-sample + --do_train + --do_eval + --remove_unused_columns False + --overwrite_output_dir True + --eval_do_concat_batches False + --max_steps 10 + --learning_rate=5e-5 + --per_device_train_batch_size=1 + --per_device_eval_batch_size=1 + --seed 32 + """.split() + + if is_torch_fp16_available_on_device(torch_device): + testargs.append("--fp16") + + with patch.object(sys, "argv", testargs): + run_zero_shot_object_detection.main() + result = get_results(tmp_dir) + self.assertGreaterEqual(result["test_map"], 0.01) diff --git a/examples/pytorch/zero-shot/README.md b/examples/pytorch/zero-shot/README.md new file mode 100644 index 00000000000000..1a16d3557dd010 --- /dev/null +++ b/examples/pytorch/zero-shot/README.md @@ -0,0 +1,254 @@ + + +# Object detection examples + +This directory contains 2 scripts that showcase how to fine-tune any model supported by the [`GroundingDinoForObjectDetection` API](https://huggingface.co/docs/transformers/main/en/model_doc/grounding-dino#transformers.GroundingDinoForObjectDetection) using PyTorch. + +Content: +* [PyTorch version, Trainer](#pytorch-version-trainer) +* [PyTorch version, no Trainer](#pytorch-version-no-trainer) +* [Reload and perform inference](#reload-and-perform-inference) +* [Note on custom data](#note-on-custom-data) + + +## PyTorch version, Trainer + +Based on the script [`run_zero_shot_object_detection.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/zero-shot/run_zero_shot_object_detection.py). + +The script leverages the [🤗 Trainer API](https://huggingface.co/docs/transformers/main_classes/trainer) to automatically take care of the training for you, running on distributed environments right away. 
+ +Here we show how to fine-tune a [GroundingDino](https://huggingface.co/IDEA-Research/grounding-dino-tiny) model on the [CPPE-5](https://huggingface.co/datasets/cppe-5) dataset: + +```bash +python run_zero_shot_object_detection.py \ + --model_name_or_path IDEA-Research/grounding-dino-tiny \ + --dataset_name cppe-5 \ + --do_train true \ + --do_eval true \ + --output_dir grounding-dino-tiny-finetuned-cppe-5-10k-steps \ + --num_train_epochs 10 \ + --image_square_size 600 \ + --fp16 true \ + --learning_rate 5e-5 \ + --weight_decay 1e-4 \ + --dataloader_num_workers 4 \ + --dataloader_prefetch_factor 2 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --remove_unused_columns false \ + --eval_do_concat_batches false \ + --ignore_mismatched_sizes true \ + --include_inputs_for_metrics true \ + --metric_for_best_model eval_map \ + --greater_is_better true \ + --load_best_model_at_end true \ + --logging_strategy epoch \ + --evaluation_strategy epoch \ + --save_strategy epoch \ + --save_total_limit 2 \ + --push_to_hub true \ + --push_to_hub_model_id grounding-dino-tiny-finetuned-cppe-5-10k-steps \ + --hub_strategy end \ + --seed 1337 +``` + +> Note: +`--eval_do_concat_batches false` is required for correct evaluation of detection models; +`--ignore_mismatched_sizes true` is required to load detection model for finetuning with different number of classes. + +The resulting model can be seen here: https://huggingface.co/danelcsb/grounding-dino-tiny-finetuned-10k-cppe-5-10k-steps. Note that it's always advised to check the original paper to know the details regarding training hyperparameters. Hyperparameters for current example were not tuned. To improve model quality you could try: + - changing freeze policy of image backbone and text backbone + - changing image size parameters (`--shortest_edge`/`--longest_edge`) + - changing training parameters, such as learning rate, batch size, warmup, optimizer and many more (see [TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments)) + - adding more image augmentations (we created a helpful [HF Space](https://huggingface.co/spaces/qubvel-hf/albumentations-demo) to choose some) + +Note that you can replace the model and dataset by simply setting the `model_name_or_path` and `dataset_name` arguments respectively, with model or dataset from the [hub](https://huggingface.co/). +For dataset, make sure it provides labels in the same format as [CPPE-5](https://huggingface.co/datasets/cppe-5) dataset and boxes are provided in [COCO format](https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/#coco). + +Note that zero-shot inference output is not the same output format as object-detection output. In order to compute the evaluation metric performance such as mean average precision, we have to modify the output little bit. 
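+
+Concretely, the post-processed zero-shot predictions carry text labels rather than integer class ids, so the scripts map them back through the dataset's `label2id` before computing mAP with `torchmetrics`. Below is a minimal, illustrative sketch of that conversion (the same idea as the `convert_zero_shot_to_coco_format` helper in the scripts); `predictions` is assumed to be the output of `processor.post_process_grounded_object_detection`:
+
+```python
+import torch
+
+def map_text_labels_to_ids(predictions, label2id):
+    """Replace predicted text labels with integer class ids so mAP can be computed."""
+    for prediction in predictions:
+        device = prediction["scores"].device
+        # Phrases that do not match a known class fall back to id 0 (background)
+        int_labels = [label2id.get(label, 0) for label in prediction["labels"]]
+        prediction["labels"] = torch.tensor(int_labels, dtype=torch.int32, device=device)
+    return predictions
+```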
+
+| Train method | Batch size | freeze_text_backbone | freeze_backbone | precision | MSDA kernels | GPU Memory Usage (GB) | Time (s/epoch) |
+|--------------|------------|----------------------|-----------------|-----------|--------------|-----------------------|----------------|
+| trainer | 2 | Y | Y | fp16 | Y | 22.785 | 353 |
+| trainer | 1 | Y | Y | fp32 | Y | 8.813 | 429 |
+| no_trainer | 2 | N | N | fp32 | Y | OOM | - |
+| no_trainer | 1 | N | N | fp32 | N | 20.441 | 724 |
+| no_trainer | 1 | N | N | fp32 | Y | 11.243 | 473 |
+| no_trainer | 1 | Y | Y | fp32 | Y | 11.539 | 386 |
+
+The table above was measured on the following setup:
+- Platform: Linux-5.4.0-167-generic-x86_64-with-glibc2.35
+- GPU type: NVIDIA TITAN RTX
+- PyTorch version (GPU): 2.2.2
+
+However, multi-GPU evaluation is currently not working due to a [flatten issue](https://github.com/huggingface/transformers/pull/33561), so it is advised to use `accelerate` (the no-Trainer version below) by default.
+
+## PyTorch version, no Trainer
+
+Based on the script [`run_zero_shot_object_detection_no_trainer.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/zero-shot/run_zero_shot_object_detection_no_trainer.py).
+
+The script leverages [🤗 `Accelerate`](https://github.com/huggingface/accelerate), which allows you to write your own training loop in PyTorch, but have it run instantly on any (distributed) environment, including CPU, multi-CPU, GPU, multi-GPU and TPU. It also supports mixed precision.
+
+First, run:
+
+```bash
+accelerate config
+```
+
+and reply to the questions asked regarding the environment on which you'd like to train. Then
+
+```bash
+accelerate test
+```
+
+that will check everything is ready for training. Finally, you can launch training with
+
+```bash
+accelerate launch run_zero_shot_object_detection_no_trainer.py \
+    --model_name_or_path "IDEA-Research/grounding-dino-tiny" \
+    --dataset_name cppe-5 \
+    --output_dir "grounding-dino-tiny-finetuned-cppe-5-10k-steps-no-trainer" \
+    --num_train_epochs 10 \
+    --image_square_size 600 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --checkpointing_steps epoch \
+    --learning_rate 5e-5 \
+    --ignore_mismatched_sizes \
+    --with_tracking \
+    --push_to_hub \
+    --freeze_backbone \
+    --freeze_text_backbone
+```
+
+and boom, you're training, possibly on multiple GPUs, logging everything to all trackers found in your environment (like Weights and Biases, Tensorboard) and regularly pushing your model to the hub (with the repo name being equal to `args.output_dir` at your HF username) 🤗
+
+With the default settings, the script fine-tunes a [GroundingDino](https://huggingface.co/IDEA-Research/grounding-dino-tiny) model on the [CPPE-5](https://huggingface.co/datasets/cppe-5) dataset. The resulting model can be seen here: https://huggingface.co/danelcsb/grounding-dino-tiny-finetuned-10k-cppe-5-no-trainer.
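+
+The `--freeze_backbone` and `--freeze_text_backbone` flags keep the image backbone and the text encoder frozen during fine-tuning, which reduces the number of trainable parameters and, as the table above suggests, shortens the epoch time. A short sketch of what the scripts do when these flags are set (`model` is the loaded `AutoModelForZeroShotObjectDetection` instance):
+
+```python
+# Freeze the image backbone and the text encoder
+model.model.freeze_backbone()
+for _, param in model.model.text_backbone.named_parameters():
+    param.requires_grad_(False)
+```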
+ + +## Reload and perform inference + +This means that after training, you can easily load your trained model and perform inference as follows:: + +```python +import requests +import torch + +from PIL import Image +from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection + +# Name of repo on the hub or path to a local folder +model_name = "danelcsb/grounding-dino-tiny-finetuned-10k-cppe5" + +image_processor = AutoProcessor.from_pretrained(model_name) +model = AutoModelForZeroShotObjectDetection.from_pretrained(model_name) + +# Load image for inference +url = "https://images.pexels.com/photos/8413299/pexels-photo-8413299.jpeg?auto=compress&cs=tinysrgb&w=630&h=375&dpr=2" +image = Image.open(requests.get(url, stream=True).raw) +text = "Coverall. Face_Shield. Gloves. Goggles. Mask" + +# Prepare image for the model +inputs = image_processor(images=image, text=text, return_tensors="pt") + +with torch.no_grad(): + outputs = model(**inputs) + +# Post process model predictions +# this include conversion to Pascal VOC format and filtering non confident boxes +width, height = image.size +target_sizes = torch.tensor([height, width]).unsqueeze(0) # add batch dim +results = processor.post_process_grounded_object_detection(outputs, inputs.input_ids, box_threshold=0.15, text_threshold=0.1, target_sizes=target_sizes)[0] + +for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + box = [round(i, 2) for i in box.tolist()] + print( + f"Detected {model.config.id2label[label.item()]} with confidence " + f"{round(score.item(), 3)} at location {box}" + ) +``` + +And visualize with the following code: +```python +from PIL import ImageDraw +draw = ImageDraw.Draw(image) + +for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + box = [round(i, 2) for i in box.tolist()] + x, y, x2, y2 = tuple(box) + draw.rectangle((x, y, x2, y2), outline="red", width=1) + draw.text((x, y), model.config.id2label[label.item()], fill="white") + +image +``` + + +## Note on custom data + +In case you'd like to use the script with custom data, you could prepare your data with the following way: + +```bash +custom_dataset/ +└── train + ├── 0001.jpg + ├── 0002.jpg + ├── ... + └── metadata.jsonl +└── validation + └── ... +└── test + └── ... +``` + +Where `metadata.jsonl` is a file with the following structure: +```json +{"file_name": "0001.jpg", "objects": {"bbox": [[302.0, 109.0, 73.0, 52.0]], "categories": [0], "id": [1], "area": [50.0]}} +{"file_name": "0002.jpg", "objects": {"bbox": [[810.0, 100.0, 57.0, 28.0]], "categories": [1], "id": [2], "area": [40.0]}} +... +``` +Trining script support bounding boxes in COCO format (x_min, y_min, width, height). + +Then, you cat load the dataset with just a few lines of code: + +```python +from datasets import load_dataset + +# Load dataset +dataset = load_dataset("imagefolder", data_dir="custom_dataset/") + +# >>> DatasetDict({ +# ... train: Dataset({ +# ... features: ['image', 'objects'], +# ... num_rows: 2 +# ... }) +# ... 
}) + +# Push to hub (assumes you have ran the huggingface-cli login command in a terminal/notebook) +dataset.push_to_hub("name of repo on the hub") + +# optionally, you can push to a private repo on the hub +# dataset.push_to_hub("name of repo on the hub", private=True) +``` + +And the final step, for training you should provide id2label mapping in the following way: +```python +id2label = {0: "Car", 1: "Bird", ...} +``` +Just find it in code and replace for simplicity, or save `json` locally and with the dataset on the hub! + +See also: [Dataset Creation Guide](https://huggingface.co/docs/datasets/image_dataset#create-an-image-dataset) diff --git a/examples/pytorch/zero-shot/requirements.txt b/examples/pytorch/zero-shot/requirements.txt new file mode 100644 index 00000000000000..2aa0d9bcf01672 --- /dev/null +++ b/examples/pytorch/zero-shot/requirements.txt @@ -0,0 +1,5 @@ +albumentations >= 1.4.5 +timm +datasets +torchmetrics +pycocotools diff --git a/examples/pytorch/zero-shot/run_zero_shot_object_detection.py b/examples/pytorch/zero-shot/run_zero_shot_object_detection.py new file mode 100644 index 00000000000000..32162f237c061b --- /dev/null +++ b/examples/pytorch/zero-shot/run_zero_shot_object_detection.py @@ -0,0 +1,637 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +"""Finetuning any 🤗 Transformers model supported by AutoModelForZeroShotObjectDetection for object detection leveraging the Trainer API.""" + +import logging +import os +import random +import sys +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Dict, List, Mapping, Optional, Tuple, Union + +import albumentations as A +import numpy as np +import torch +from datasets import load_dataset +from torchmetrics.detection.mean_ap import MeanAveragePrecision + +import transformers +from transformers import ( + AutoConfig, + AutoModelForZeroShotObjectDetection, + AutoProcessor, + HfArgumentParser, + Trainer, + TrainingArguments, +) +from transformers.image_processing_utils import BatchFeature +from transformers.image_transforms import center_to_corners_format +from transformers.trainer import EvalPrediction +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version, send_example_telemetry +from transformers.utils.versions import require_version + + +logger = logging.getLogger(__name__) + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
+check_min_version("4.44.0.dev0") + +require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/zero_shot/requirements.txt") + + +@dataclass +class ModelOutput: + logits: torch.Tensor + pred_boxes: torch.Tensor + + +class ZeroShotTrainer(Trainer): + def _select_inputs_for_validation(self, inputs): + return inputs["input_ids"] + + +def format_image_annotations_as_coco( + image_id: str, categories: List[int], areas: List[float], bboxes: List[Tuple[float]] +) -> dict: + """Format one set of image annotations to the COCO format + + Args: + image_id (str): image id. e.g. "0001" + categories (List[int]): list of categories/class labels corresponding to provided bounding boxes + areas (List[float]): list of corresponding areas to provided bounding boxes + bboxes (List[Tuple[float]]): list of bounding boxes provided in COCO format + ([center_x, center_y, width, height] in absolute coordinates) + + Returns: + dict: { + "image_id": image id, + "annotations": list of formatted annotations + } + """ + annotations = [] + for category, area, bbox in zip(categories, areas, bboxes): + formatted_annotation = { + "image_id": image_id, + "category_id": category, + "iscrowd": 0, + "area": area, + "bbox": list(bbox), + } + annotations.append(formatted_annotation) + + return { + "image_id": image_id, + "annotations": annotations, + } + + +def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor: + """ + Convert bounding boxes from YOLO format (x_center, y_center, width, height) in range [0, 1] + to Pascal VOC format (x_min, y_min, x_max, y_max) in absolute coordinates. + + Args: + boxes (torch.Tensor): Bounding boxes in YOLO format + image_size (Tuple[int, int]): Image size in format (height, width) + + Returns: + torch.Tensor: Bounding boxes in Pascal VOC format (x_min, y_min, x_max, y_max) + """ + # convert center to corners format + boxes = center_to_corners_format(boxes) + + # convert to absolute coordinates + height, width = image_size + boxes = boxes * torch.tensor([[width, height, width, height]]) + + return boxes + + +def convert_zero_shot_to_coco_format(predictions, label2id): + """ + Convert zershot format output to typical object detection format in order to calculate mAP. + + Args: + predictions (Dict): Output of zero-shot object detection + e.g. {'scores': tensor([0.4786, 0.4379, 0.4760], device='cuda:0'), 'labels': ['a cat', 'a cat', 'a remote control'], 'boxes': tensor([[344.6973, 23.1085, 637.1817, 374.2748],[ 12.2690, 51.9104, 316.8564, 472.4341],[ 38.5870, 70.0092, 176.7755, 118.1748]], device='cuda:0')} + label2id (Dict): Dictionary of label to id mapping + + Returns: + Dict: Output of zero-shot object detection + e.g. {'scores': tensor([0.4786, 0.4379, 0.4760], device='cuda:0'), 'labels': [1, 1, 2], 'boxes': tensor([[344.6973, 23.1085, 637.1817, 374.2748],[ 12.2690, 51.9104, 316.8564, 472.4341],[ 38.5870, 70.0092, 176.7755, 118.1748]], device='cuda:0')} + + """ + # convert center to corners format + torch_label = [] + for prediction in predictions: + scores = prediction["scores"] + device = scores.device + labels = prediction["labels"] + for label in labels: + if label in label2id: + torch_label.append(label2id[label]) + else: + # Give background class + torch_label.append(0) + prediction["labels"] = torch.Tensor(torch_label).to(dtype=torch.int32).to(device) + + return predictions + + +def to_label_list(id2label): + return list(id2label.values()) + + +def concat_func(id2label): + return ". ".join(to_label_list(id2label)) + "." 
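+
+# Example (illustrative): with id2label = {0: "Coverall", 1: "Mask"}, to_label_list(id2label)
+# returns ["Coverall", "Mask"] and concat_func(id2label) returns "Coverall. Mask.", which is the
+# dot-separated text prompt format this script feeds to the model.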
+ + +def augment_and_transform_batch( + examples: Mapping[str, Any], + transform: A.Compose, + processor: AutoProcessor, + id2label: Dict[int, str], + label2id: Dict[str, int], + random_text_prompt: bool = False, + return_pixel_mask: bool = False, +) -> BatchFeature: + """ + Apply augmentations and format annotations in COCO format for object detection task. + Generates the text prompt used. If `random_text_prompt` is False + then the prompt will follow the same ordering in `id2label` if set to + True a new ordering will be created and the prompt will be build accordingly + and labels will be updated as well. + + Example: + `id2label` -> {'0': 'fish', '1': 'jellyfish', '2': 'penguins', '3': + 'sharks', '4': 'puffins', '5': 'stingrays', '6': 'starfish'} + + If `random_text_prompt` -> False + `text` -> "fish. jellyfish. penguins. sharks. puffins. stingrays. starfish." + + If `random_text_prompt` -> True + `id2label` gets shuffled e.g. {0: 'fish', 1: 'penguins', 2: 'stingrays', 3: + 'jellyfish', 4: 'sharks', 5: 'starfish', 6: 'puffins'} + `text` -> "fish. penguins. stingrays. jellyfish. sharks. starfish. puffins." + """ + + images = [] + annotations = [] + text = [] + + for image_id, image, objects in zip(examples["image_id"], examples["image"], examples["objects"]): + image = np.array(image.convert("RGB")) + + if random_text_prompt: + # Original ordering label list + label_list = to_label_list(id2label) + # Shuffle label list + random.shuffle(label_list) + # Create shuffled id2label + shuffled_id2label = dict(enumerate(label_list)) + + # Mapping of original to shuffled id to update annotations + old2new = {label2id[label]: new_id for new_id, label in shuffled_id2label.items()} + prompt = concat_func(shuffled_id2label) + category = [old2new[category] for category in objects["category"]] + else: + prompt = concat_func(id2label) + category = objects["category"] + + # apply augmentations + output = transform(image=image, bboxes=objects["bbox"], category=category) + images.append(output["image"]) + + # format annotations in COCO format + formatted_annotations = format_image_annotations_as_coco( + image_id, output["category"], objects["area"], output["bboxes"] + ) + annotations.append(formatted_annotations) + text.append(prompt) + + # Apply the image processor transformations: resizing, rescaling, normalization + result = processor(images=images, text=text, annotations=annotations, return_tensors="pt") + + if not return_pixel_mask: + result.pop("pixel_mask", None) + + return result + + +def collate_fn(batch: List[BatchFeature]) -> Mapping[str, Union[torch.Tensor, List[Any]]]: + data = {} + data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch]) + data["input_ids"] = torch.stack([x["input_ids"] for x in batch]) + data["token_type_ids"] = torch.stack([x["token_type_ids"] for x in batch]) + data["labels"] = [x["labels"] for x in batch] + if "pixel_mask" in batch[0]: + data["pixel_mask"] = torch.stack([x["pixel_mask"] for x in batch]) + if "attention_mask" in batch[0]: + data["attention_mask"] = torch.stack([x["attention_mask"] for x in batch]) + return data + + +@torch.no_grad() +def compute_metrics( + evaluation_results: EvalPrediction, + processor: AutoProcessor, + box_threshold: float = 0.15, + text_threshold: float = 0.1, + id2label: Optional[Mapping[int, str]] = None, + label2id: Optional[Mapping[str, int]] = None, +) -> Mapping[str, float]: + """ + Compute mean average mAP, mAR and their variants for the object detection task. 
+ + Args: + evaluation_results (EvalPrediction): Predictions and targets from evaluation. + box_threshold (float, optional): Threshold to filter predicted boxes by confidence. Defaults to 0.15. + text_threshold (float, optional): Threshold to filter predicted text by confidence. Defaults to 0.1. + id2label (Optional[dict], optional): Mapping from class id to class name. Defaults to None. + + Returns: + Mapping[str, float]: Metrics in a form of dictionary {: } + """ + + predictions, targets, inputs = ( + evaluation_results.predictions, + evaluation_results.label_ids, + evaluation_results.inputs, + ) + # For metric computation we need to provide: + # - targets in a form of list of dictionaries with keys "boxes", "labels" + # - predictions in a form of list of dictionaries with keys "boxes", "scores", "labels" + + image_sizes = [] + post_processed_targets = [] + post_processed_predictions = [] + + # Collect targets in the required format for metric computation + for batch in targets: + # collect image sizes, we will need them for predictions post processing + batch_image_sizes = torch.tensor([x["orig_size"] for x in batch]) + image_sizes.append(batch_image_sizes) + # collect targets in the required format for metric computation + # boxes were converted to YOLO format needed for model training + # here we will convert them to Pascal VOC format (x_min, y_min, x_max, y_max) + for image_target in batch: + boxes = torch.tensor(image_target["boxes"]) + boxes = convert_bbox_yolo_to_pascal(boxes, image_target["orig_size"]) + labels = torch.tensor(image_target["class_labels"]) + post_processed_targets.append({"boxes": boxes, "labels": labels}) + + # Collect predictions in the required format for metric computation, + # model produce boxes in YOLO format, then processor convert them to Pascal VOC format + for batch, target_sizes, input_ids in zip(predictions, image_sizes, inputs): + batch_logits, batch_boxes = batch[1], batch[2] + output = ModelOutput(logits=torch.tensor(batch_logits), pred_boxes=torch.tensor(batch_boxes)) + post_processed_output = processor.post_process_grounded_object_detection( + output, input_ids, box_threshold=box_threshold, text_threshold=text_threshold, target_sizes=target_sizes + ) + post_processed_output = convert_zero_shot_to_coco_format(post_processed_output, label2id) + post_processed_predictions.extend(post_processed_output) + + # Compute metrics + metric = MeanAveragePrecision(box_format="xyxy", class_metrics=True) + metric.update(post_processed_predictions, post_processed_targets) + metrics = metric.compute() + + # Replace list of per class metrics with separate metric for each class + classes = metrics.pop("classes") + map_per_class = metrics.pop("map_per_class") + mar_100_per_class = metrics.pop("mar_100_per_class") + for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class): + class_name = id2label[class_id.item()] if id2label is not None else class_id.item() + metrics[f"map_{class_name}"] = class_map + metrics[f"mar_100_{class_name}"] = class_mar + + metrics = {k: round(v.item(), 4) for k, v in metrics.items()} + + return metrics + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify + them on the command line. 
+ """ + + dataset_name: str = field( + default="cppe-5", + metadata={ + "help": "Name of a dataset from the hub (could be your own, possibly private dataset hosted on the hub)." + }, + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_val_split: Optional[float] = field( + default=0.15, metadata={"help": "Percent to split off of train for validation."} + ) + image_square_size: Optional[int] = field( + default=600, + metadata={"help": "Image longest size will be resized to this value, then image will be padded to square."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + default="IDEA-Research/grounding-dino-tiny", + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."}) + ignore_mismatched_sizes: bool = field( + default=False, + metadata={ + "help": "Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (if for instance, you are instantiating a model with 10 labels from a checkpoint with 3 labels)." + }, + ) + token: str = field( + default=None, + metadata={ + "help": ( + "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token " + "generated when running `huggingface-cli login` (stored in `~/.huggingface`)." + ) + }, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": ( + "Whether to trust the execution of code from datasets/models defined on the Hub." + " This option should only be set to `True` for repositories you trust and in which you have read the" + " code, as it will execute code present on the Hub on your local machine." + ) + }, + ) + freeze_backbone: bool = field( + default=True, + metadata={"help": ("Whether freeze the image backbone.")}, + ) + freeze_text_backbone: bool = field( + default=True, + metadata={"help": ("Whether freeze the text encoder.")}, + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. 
+ + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_object_detection", model_args, data_args) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level at info here to have that default. + transformers.utils.logging.set_verbosity_info() + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " + + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: + checkpoint = get_last_checkpoint(training_args.output_dir) + if checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
+ ) + + # ------------------------------------------------------------------------------------------------ + # Load dataset, prepare splits + # ------------------------------------------------------------------------------------------------ + + dataset = load_dataset( + data_args.dataset_name, cache_dir=model_args.cache_dir, trust_remote_code=model_args.trust_remote_code + ) + + # If we don't have a validation split, split off a percentage of train as validation + data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split + if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0: + split = dataset["train"].train_test_split(data_args.train_val_split, seed=training_args.seed) + dataset["train"] = split["train"] + dataset["validation"] = split["test"] + + # Get dataset categories and prepare mappings for label_name <-> label_id + categories = dataset["train"].features["objects"].feature["category"].names + id2label = dict(enumerate(categories)) + label2id = {v: k for k, v in id2label.items()} + + # ------------------------------------------------------------------------------------------------ + # Load pretrained config, model and image processor + # ------------------------------------------------------------------------------------------------ + + common_pretrained_args = { + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "token": model_args.token, + "trust_remote_code": model_args.trust_remote_code, + } + config = AutoConfig.from_pretrained( + model_args.config_name or model_args.model_name_or_path, + label2id=label2id, + id2label=id2label, + **common_pretrained_args, + ) + model = AutoModelForZeroShotObjectDetection.from_pretrained( + model_args.model_name_or_path, + config=config, + ignore_mismatched_sizes=model_args.ignore_mismatched_sizes, + **common_pretrained_args, + ) + processor = AutoProcessor.from_pretrained( + model_args.image_processor_name or model_args.model_name_or_path, + ) + + # Freeze both text_backbone + if model_args.freeze_backbone: + model.model.freeze_backbone() + if model_args.freeze_text_backbone: + for name, param in model.model.text_backbone.named_parameters(): + param.requires_grad_(False) + + # ------------------------------------------------------------------------------------------------ + # Define image augmentations and dataset transforms + # ------------------------------------------------------------------------------------------------ + max_size = data_args.image_square_size + train_augment_and_transform = A.Compose( + [ + A.Compose( + [ + A.SmallestMaxSize(max_size=max_size, p=1.0), + A.RandomSizedBBoxSafeCrop(height=max_size, width=max_size, p=1.0), + ], + p=0.2, + ), + A.OneOf( + [ + A.Blur(blur_limit=7, p=0.5), + A.MotionBlur(blur_limit=7, p=0.5), + A.Defocus(radius=(1, 5), alias_blur=(0.1, 0.25), p=0.1), + ], + p=0.1, + ), + A.Perspective(p=0.1), + A.HorizontalFlip(p=0.5), + A.RandomBrightnessContrast(p=0.5), + A.HueSaturationValue(p=0.1), + ], + bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True, min_area=25), + ) + validation_transform = A.Compose( + [A.NoOp()], + bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True), + ) + + # Make transform functions for batch and apply for dataset splits + train_transform_batch = partial( + augment_and_transform_batch, + transform=train_augment_and_transform, + processor=processor, + id2label=id2label, + label2id=label2id, + random_text_prompt=False, + ) + 
validation_transform_batch = partial( + augment_and_transform_batch, + transform=validation_transform, + processor=processor, + id2label=id2label, + label2id=label2id, + random_text_prompt=True, + ) + + dataset["train"] = dataset["train"].with_transform(train_transform_batch) + dataset["validation"] = dataset["validation"].with_transform(validation_transform_batch) + dataset["test"] = dataset["test"].with_transform(validation_transform_batch) + + # ------------------------------------------------------------------------------------------------ + # Model training and evaluation with Trainer API + # ------------------------------------------------------------------------------------------------ + + eval_compute_metrics_fn = partial(compute_metrics, processor=processor, id2label=id2label, label2id=label2id) + + trainer = ZeroShotTrainer( + model=model, + args=training_args, + train_dataset=dataset["train"] if training_args.do_train else None, + eval_dataset=dataset["validation"] if training_args.do_eval else None, + tokenizer=processor, + data_collator=collate_fn, + compute_metrics=eval_compute_metrics_fn, + ) + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + + # Final evaluation + if training_args.do_eval: + metrics = trainer.evaluate(eval_dataset=dataset["test"], metric_key_prefix="test") + trainer.log_metrics("test", metrics) + trainer.save_metrics("test", metrics) + + # Write model card and (optionally) push to hub + kwargs = { + "finetuned_from": model_args.model_name_or_path, + "dataset": data_args.dataset_name, + "tags": ["object-detection", "vision"], + } + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/zero-shot/run_zero_shot_object_detection_no_trainer.py b/examples/pytorch/zero-shot/run_zero_shot_object_detection_no_trainer.py new file mode 100644 index 00000000000000..26b44010a468d6 --- /dev/null +++ b/examples/pytorch/zero-shot/run_zero_shot_object_detection_no_trainer.py @@ -0,0 +1,895 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Finetuning 🤗 Transformers model for object detection with Accelerate.""" + +import argparse +import json +import logging +import math +import os +import random +from functools import partial +from pathlib import Path +from typing import Any, Dict, List, Mapping, Tuple, Union + +import albumentations as A +import datasets +import numpy as np +import torch +from accelerate import Accelerator, DistributedDataParallelKwargs +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from datasets import load_dataset +from huggingface_hub import HfApi +from torch.utils.data import DataLoader +from torchmetrics.detection.mean_ap import MeanAveragePrecision +from tqdm.auto import tqdm + +import transformers +from transformers import ( + AutoConfig, + AutoModelForZeroShotObjectDetection, + AutoProcessor, + SchedulerType, + get_scheduler, +) +from transformers.image_processing_utils import BatchFeature +from transformers.image_transforms import center_to_corners_format +from transformers.utils import check_min_version, send_example_telemetry +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.44.0.dev0") + +logging.basicConfig(level=logging.INFO) +logger = get_logger(__name__) + +require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt") + + +# Copied from examples/pytorch/object-detection/run_object_detection.format_image_annotations_as_coco +def format_image_annotations_as_coco( + image_id: str, categories: List[int], areas: List[float], bboxes: List[Tuple[float]] +) -> dict: + """Format one set of image annotations to the COCO format + + Args: + image_id (str): image id. e.g. "0001" + categories (List[int]): list of categories/class labels corresponding to provided bounding boxes + areas (List[float]): list of corresponding areas to provided bounding boxes + bboxes (List[Tuple[float]]): list of bounding boxes provided in COCO format + ([center_x, center_y, width, height] in absolute coordinates) + + Returns: + dict: { + "image_id": image id, + "annotations": list of formatted annotations + } + """ + annotations = [] + for category, area, bbox in zip(categories, areas, bboxes): + formatted_annotation = { + "image_id": image_id, + "category_id": category, + "iscrowd": 0, + "area": area, + "bbox": list(bbox), + } + annotations.append(formatted_annotation) + + return { + "image_id": image_id, + "annotations": annotations, + } + + +# Copied from examples/pytorch/object-detection/run_object_detection.convert_bbox_yolo_to_pascal +def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor: + """ + Convert bounding boxes from YOLO format (x_center, y_center, width, height) in range [0, 1] + to Pascal VOC format (x_min, y_min, x_max, y_max) in absolute coordinates. 
+ + Args: + boxes (torch.Tensor): Bounding boxes in YOLO format + image_size (Tuple[int, int]): Image size in format (height, width) + + Returns: + torch.Tensor: Bounding boxes in Pascal VOC format (x_min, y_min, x_max, y_max) + """ + # convert center to corners format + boxes = center_to_corners_format(boxes) + + # convert to absolute coordinates + height, width = image_size + boxes = boxes * torch.tensor([[width, height, width, height]]) + + return boxes + + +def convert_zero_shot_to_coco_format(predictions, label2id): + """ + Convert zershot format output to typical object detection format in order to calculate mAP. + + Args: + predictions (Dict): Output of zero-shot object detection + e.g. {'scores': tensor([0.4786, 0.4379, 0.4760], device='cuda:0'), 'labels': ['a cat', 'a cat', 'a remote control'], 'boxes': tensor([[344.6973, 23.1085, 637.1817, 374.2748],[ 12.2690, 51.9104, 316.8564, 472.4341],[ 38.5870, 70.0092, 176.7755, 118.1748]], device='cuda:0')} + label2id (Dict): Dictionary of label to id mapping + + Returns: + Dict: Output of zero-shot object detection + e.g. {'scores': tensor([0.4786, 0.4379, 0.4760], device='cuda:0'), 'labels': [1, 1, 2], 'boxes': tensor([[344.6973, 23.1085, 637.1817, 374.2748],[ 12.2690, 51.9104, 316.8564, 472.4341],[ 38.5870, 70.0092, 176.7755, 118.1748]], device='cuda:0')} + + """ + # convert center to corners format + torch_label = [] + for prediction in predictions: + scores = prediction["scores"] + device = scores.device + labels = prediction["labels"] + for label in labels: + if label in label2id: + torch_label.append(label2id[label]) + else: + # Give background class + torch_label.append(0) + prediction["labels"] = torch.Tensor(torch_label).to(dtype=torch.int32).to(device) + + return predictions + + +def to_label_list(id2label): + return list(id2label.values()) + + +def concat_func(id2label): + return ". ".join(to_label_list(id2label)) + "." + + +def augment_and_transform_batch( + examples: Mapping[str, Any], + transform: A.Compose, + processor: AutoProcessor, + id2label: Dict[int, str], + label2id: Dict[str, int], + random_text_prompt: bool = False, + return_pixel_mask: bool = False, +) -> BatchFeature: + """ + Apply augmentations and format annotations in COCO format for object detection task. + Generates the text prompt used. If `random_text_prompt` is False + then the prompt will follow the same ordering in `id2label` if set to + True a new ordering will be created and the prompt will be build accordingly + and labels will be updated as well. + + Example: + `id2label` -> {'0': 'fish', '1': 'jellyfish', '2': 'penguins', '3': + 'sharks', '4': 'puffins', '5': 'stingrays', '6': 'starfish'} + + If `random_text_prompt` -> False + `text` -> "fish. jellyfish. penguins. sharks. puffins. stingrays. starfish." + + If `random_text_prompt` -> True + `id2label` gets shuffled e.g. {0: 'fish', 1: 'penguins', 2: 'stingrays', 3: + 'jellyfish', 4: 'sharks', 5: 'starfish', 6: 'puffins'} + `text` -> "fish. penguins. stingrays. jellyfish. sharks. starfish. puffins." 
+ """ + + images = [] + annotations = [] + text = [] + + for image_id, image, objects in zip(examples["image_id"], examples["image"], examples["objects"]): + image = np.array(image.convert("RGB")) + + if random_text_prompt: + # Original ordering label list + label_list = to_label_list(id2label) + # Shuffle label list + random.shuffle(label_list) + # Create shuffled id2label + shuffled_id2label = dict(enumerate(label_list)) + + # Mapping of original to shuffled id to update annotations + old2new = {label2id[label]: new_id for new_id, label in shuffled_id2label.items()} + prompt = concat_func(shuffled_id2label) + category = [old2new[category] for category in objects["category"]] + else: + prompt = concat_func(id2label) + category = objects["category"] + + # apply augmentations + output = transform(image=image, bboxes=objects["bbox"], category=category) + images.append(output["image"]) + + # format annotations in COCO format + formatted_annotations = format_image_annotations_as_coco( + image_id, output["category"], objects["area"], output["bboxes"] + ) + annotations.append(formatted_annotations) + text.append(prompt) + + # Apply the image processor transformations: resizing, rescaling, normalization + result = processor(images=images, text=text, annotations=annotations, return_tensors="pt") + + if not return_pixel_mask: + result.pop("pixel_mask", None) + + return result + + +def collate_fn(batch: List[BatchFeature]) -> Mapping[str, Union[torch.Tensor, List[Any]]]: + data = {} + data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch]) + data["input_ids"] = torch.stack([x["input_ids"] for x in batch]) + data["token_type_ids"] = torch.stack([x["token_type_ids"] for x in batch]) + data["labels"] = [x["labels"] for x in batch] + if "pixel_mask" in batch[0]: + data["pixel_mask"] = torch.stack([x["pixel_mask"] for x in batch]) + if "attention_mask" in batch[0]: + data["attention_mask"] = torch.stack([x["attention_mask"] for x in batch]) + return data + + +def nested_to_cpu(objects): + """Move nested tesnors in objects to CPU if they are on GPU""" + if isinstance(objects, torch.Tensor): + return objects.cpu() + elif isinstance(objects, Mapping): + return type(objects)({k: nested_to_cpu(v) for k, v in objects.items()}) + elif isinstance(objects, (list, tuple)): + return type(objects)([nested_to_cpu(v) for v in objects]) + elif isinstance(objects, (np.ndarray, str, int, float, bool)): + return objects + raise ValueError(f"Unsupported type {type(objects)}") + + +def evaluation_loop( + model: torch.nn.Module, + processor: AutoProcessor, + accelerator: Accelerator, + dataloader: DataLoader, + id2label: Mapping[int, str], + label2id: Mapping[str, int], +) -> dict: + model.eval() + metric = MeanAveragePrecision(box_format="xyxy", class_metrics=True) + + for step, batch in enumerate(tqdm(dataloader, disable=not accelerator.is_local_main_process)): + with torch.no_grad(): + outputs = model(**batch) + + # For metric computation we need to collect ground truth and predicted boxes in the same format + + # 1. 
Collect predicted boxes, classes, scores + # processor convert boxes from YOLO format to Pascal VOC format + # ([x_min, y_min, x_max, y_max] in absolute coordinates) + image_size = torch.stack([example["orig_size"] for example in batch["labels"]], dim=0) + input_ids = batch["input_ids"] + predictions = processor.post_process_grounded_object_detection( + outputs, input_ids, box_threshold=0.15, text_threshold=0.1, target_sizes=image_size + ) + predictions = nested_to_cpu(predictions) + predictions = convert_zero_shot_to_coco_format(predictions, label2id) + + # 2. Collect ground truth boxes in the same format for metric computation + # Do the same, convert YOLO boxes to Pascal VOC format + target = [] + for label in batch["labels"]: + label = nested_to_cpu(label) + boxes = convert_bbox_yolo_to_pascal(label["boxes"], label["orig_size"]) + labels = label["class_labels"] + target.append({"boxes": boxes, "labels": labels}) + + metric.update(predictions, target) + + metric.to(accelerator.device) + metrics = metric.compute() + + # Replace list of per class metrics with separate metric for each class + classes = metrics.pop("classes") + map_per_class = metrics.pop("map_per_class") + mar_100_per_class = metrics.pop("mar_100_per_class") + for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class): + class_name = id2label[class_id.item()] + metrics[f"map_{class_name}"] = class_map + metrics[f"mar_100_{class_name}"] = class_mar + + # Convert metrics to float + metrics = {k: round(v.item(), 4) for k, v in metrics.items()} + + return metrics + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model for object detection task") + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to a pretrained model or model identifier from huggingface.co/models.", + default="IDEA-Research/grounding-dino-tiny", + ) + parser.add_argument( + "--dataset_name", + type=str, + help="Name of the dataset on the hub.", + default="cppe-5", + ) + parser.add_argument( + "--train_val_split", + type=float, + default=0.15, + help="Fraction of the dataset to be used for validation.", + ) + parser.add_argument( + "--ignore_mismatched_sizes", + action="store_true", + help="Ignore mismatched sizes between the model and the dataset.", + ) + parser.add_argument( + "--image_square_size", + type=int, + default=1333, + help="Image longest size will be resized to this value, then image will be padded to square.", + ) + parser.add_argument( + "--cache_dir", + type=str, + help="Path to a folder in which the model and dataset will be cached.", + ) + parser.add_argument( + "--use_auth_token", + action="store_true", + help="Whether to use an authentication token to access the model repository.", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=4, + help="Number of workers to use for the dataloaders.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="Beta1 for AdamW optimizer", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + 
help="Beta2 for AdamW optimizer", + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-8, + help="Epsilon for AdamW optimizer", + ) + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + ) + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--trust_remote_code", + action="store_true", + help=( + "Whether to trust the execution of code from datasets/models defined on the Hub." + " This option should only be set to `True` for repositories you trust and in which you have read the" + " code, as it will execute code present on the Hub on your local machine." + ), + ) + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + required=False, + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. ' + "Only applicable when `--with_tracking` is passed." + ), + ) + parser.add_argument( + "--freeze_backbone", + required=False, + action="store_true", + help="Whether to freeze the image encoder while training.", + ) + parser.add_argument( + "--freeze_text_backbone", + required=False, + action="store_true", + help="Whether to freeze the text encoder while training.", + ) + args = parser.parse_args() + + # Sanity checks + if args.push_to_hub or args.with_tracking: + if args.output_dir is None: + raise ValueError( + "Need an `output_dir` to create a repo when `--push_to_hub` or `with_tracking` is specified." + ) + + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + return args + + +def main(): + args = parse_args() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. 
The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_object_detection_no_trainer", args) + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers + # in the environment + accelerator_log_kwargs = {} + + if args.with_tracking: + accelerator_log_kwargs["log_with"] = args.report_to + accelerator_log_kwargs["project_dir"] = args.output_dir + accelerator_log_kwargs["kwargs_handlers"] = [DistributedDataParallelKwargs(find_unused_parameters=True)] + + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs) + + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + # We set device_specific to True as we want different data augmentation per device. + if args.seed is not None: + set_seed(args.seed, device_specific=True) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + # Retrieve of infer repo_name + repo_name = args.hub_model_id + if repo_name is None: + repo_name = Path(args.output_dir).absolute().name + # Create repo and retrieve repo_id + api = HfApi() + repo_id = api.create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Load dataset + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + dataset = load_dataset(args.dataset_name, cache_dir=args.cache_dir, trust_remote_code=args.trust_remote_code) + + # If we don't have a validation split, split off a percentage of train as validation. 
+ args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split + if isinstance(args.train_val_split, float) and args.train_val_split > 0.0: + split = dataset["train"].train_test_split(args.train_val_split, seed=args.seed) + dataset["train"] = split["train"] + dataset["validation"] = split["test"] + + # Get dataset categories and prepare mappings for label_name <-> label_id + categories = dataset["train"].features["objects"].feature["category"].names + id2label = dict(enumerate(categories)) + label2id = {v: k for k, v in id2label.items()} + + # ------------------------------------------------------------------------------------------------ + # Load pretrained config, model and image processor + # ------------------------------------------------------------------------------------------------ + + common_pretrained_args = { + "cache_dir": args.cache_dir, + "token": args.hub_token, + "trust_remote_code": args.trust_remote_code, + } + config = AutoConfig.from_pretrained( + args.model_name_or_path, auxiliary_loss=True, label2id=label2id, id2label=id2label, **common_pretrained_args + ) + model = AutoModelForZeroShotObjectDetection.from_pretrained( + args.model_name_or_path, + config=config, + ignore_mismatched_sizes=args.ignore_mismatched_sizes, + **common_pretrained_args, + ) + processor = AutoProcessor.from_pretrained( + args.model_name_or_path, + ) + + # Freeze both text_backbone + if args.freeze_backbone: + model.model.freeze_backbone() + if args.freeze_text_backbone: + for name, param in model.model.text_backbone.named_parameters(): + param.requires_grad_(False) + + # ------------------------------------------------------------------------------------------------ + # Define image augmentations and dataset transforms + # ------------------------------------------------------------------------------------------------ + max_size = args.image_square_size + train_augment_and_transform = A.Compose( + [ + A.Compose( + [ + A.SmallestMaxSize(max_size=max_size, p=1.0), + A.RandomSizedBBoxSafeCrop(height=max_size, width=max_size, p=1.0), + ], + p=0.2, + ), + A.OneOf( + [ + A.Blur(blur_limit=7, p=0.5), + A.MotionBlur(blur_limit=7, p=0.5), + A.Defocus(radius=(1, 5), alias_blur=(0.1, 0.25), p=0.1), + ], + p=0.1, + ), + A.Perspective(p=0.1), + A.HorizontalFlip(p=0.5), + A.RandomBrightnessContrast(p=0.5), + A.HueSaturationValue(p=0.1), + ], + bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True, min_area=25), + ) + validation_transform = A.Compose( + [A.NoOp()], + bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True), + ) + + # Make transform functions for batch and apply for dataset splits + train_transform_batch = partial( + augment_and_transform_batch, + transform=train_augment_and_transform, + processor=processor, + id2label=id2label, + label2id=label2id, + random_text_prompt=False, + ) + validation_transform_batch = partial( + augment_and_transform_batch, + transform=validation_transform, + processor=processor, + id2label=id2label, + label2id=label2id, + random_text_prompt=False, + ) + + with accelerator.main_process_first(): + train_dataset = dataset["train"].with_transform(train_transform_batch) + valid_dataset = dataset["validation"].with_transform(validation_transform_batch) + test_dataset = dataset["test"].with_transform(validation_transform_batch) + + dataloader_common_args = { + "num_workers": args.dataloader_num_workers, + "collate_fn": collate_fn, + } + train_dataloader = DataLoader( + train_dataset, shuffle=True, 
batch_size=args.per_device_train_batch_size, **dataloader_common_args + ) + valid_dataloader = DataLoader( + valid_dataset, shuffle=False, batch_size=args.per_device_eval_batch_size, **dataloader_common_args + ) + test_dataloader = DataLoader( + test_dataset, shuffle=False, batch_size=args.per_device_eval_batch_size, **dataloader_common_args + ) + + # ------------------------------------------------------------------------------------------------ + # Define optimizer, scheduler and prepare everything with the accelerator + # ------------------------------------------------------------------------------------------------ + + # Optimizer + optimizer = torch.optim.AdamW( + list(model.parameters()), + lr=args.learning_rate, + betas=[args.adam_beta1, args.adam_beta2], + eps=args.adam_epsilon, + ) + + # Figure out how many steps we should save the Accelerator states + checkpointing_steps = args.checkpointing_steps + if checkpointing_steps is not None and checkpointing_steps.isdigit(): + checkpointing_steps = int(checkpointing_steps) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps + if overrode_max_train_steps + else args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, valid_dataloader, test_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, valid_dataloader, test_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if args.with_tracking: + experiment_config = vars(args) + # TensorBoard cannot log Enums, need the raw value + experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value + accelerator.init_trackers("object_detection_no_trainer", experiment_config) + + # ------------------------------------------------------------------------------------------------ + # Run training with evaluation on each epoch + # ------------------------------------------------------------------------------------------------ + + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + starting_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": + checkpoint_path = args.resume_from_checkpoint + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] + dirs.sort(key=os.path.getctime) + path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last + checkpoint_path = path + path = os.path.basename(checkpoint_path) + + accelerator.print(f"Resumed from checkpoint: {checkpoint_path}") + accelerator.load_state(checkpoint_path) + # Extract `epoch_{i}` or `step_{i}` + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + 1 + resume_step = None + completed_steps = starting_epoch * num_update_steps_per_epoch + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps + starting_epoch = resume_step // len(train_dataloader) + completed_steps = resume_step // args.gradient_accumulation_steps + resume_step -= starting_epoch * len(train_dataloader) + + # update the progress_bar if load from checkpoint + progress_bar.update(completed_steps) + + for epoch in range(starting_epoch, args.num_train_epochs): + model.train() + if args.with_tracking: + total_loss = 0 + if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: + # We skip the first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) + else: + active_dataloader = train_dataloader + + for step, batch in enumerate(active_dataloader): + with accelerator.accumulate(model): + outputs = model(**batch) + loss = outputs.loss + # We keep track of the loss at each epoch + if args.with_tracking: + total_loss += loss.detach().float() + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + completed_steps += 1 + + if isinstance(checkpointing_steps, int): + if completed_steps % checkpointing_steps == 0: + output_dir = f"step_{completed_steps}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + + if args.push_to_hub and epoch < args.num_train_epochs - 1: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + ) + if accelerator.is_main_process: + processor.save_pretrained(args.output_dir) + api.upload_folder( + commit_message=f"Training in progress epoch {epoch}", + folder_path=args.output_dir, + repo_id=repo_id, + repo_type="model", + 
token=args.hub_token, + ) + + if completed_steps >= args.max_train_steps: + break + + logger.info("***** Running evaluation *****") + metrics = evaluation_loop(model, processor, accelerator, valid_dataloader, id2label, label2id) + + logger.info(f"epoch {epoch}: {metrics}") + + if args.with_tracking: + accelerator.log( + { + "train_loss": total_loss / len(train_dataloader), + **metrics, + "epoch": epoch, + "step": completed_steps, + }, + step=completed_steps, + ) + + if args.push_to_hub and epoch < args.num_train_epochs - 1: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + processor.save_pretrained(args.output_dir) + api.upload_folder( + commit_message=f"Training in progress epoch {epoch}", + folder_path=args.output_dir, + repo_id=repo_id, + repo_type="model", + token=args.hub_token, + ) + + if args.checkpointing_steps == "epoch": + output_dir = f"epoch_{epoch}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + + # ------------------------------------------------------------------------------------------------ + # Run evaluation on test dataset and save the model + # ------------------------------------------------------------------------------------------------ + + logger.info("***** Running evaluation on test dataset *****") + metrics = evaluation_loop(model, processor, accelerator, test_dataloader, id2label, label2id) + metrics = {f"test_{k}": v for k, v in metrics.items()} + + logger.info(f"Test metrics: {metrics}") + + if args.with_tracking: + accelerator.end_training() + + if args.output_dir is not None: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: + json.dump(metrics, f, indent=2) + + processor.save_pretrained(args.output_dir) + + if args.push_to_hub: + api.upload_folder( + commit_message="End of training", + folder_path=args.output_dir, + repo_id=repo_id, + repo_type="model", + token=args.hub_token, + ignore_patterns=["epoch_*"], + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/grounding_dino/configuration_grounding_dino.py b/src/transformers/models/grounding_dino/configuration_grounding_dino.py index 362e50a1c1cc68..6d9cf0156c4bcc 100644 --- a/src/transformers/models/grounding_dino/configuration_grounding_dino.py +++ b/src/transformers/models/grounding_dino/configuration_grounding_dino.py @@ -97,10 +97,14 @@ class GroundingDinoConfig(PretrainedConfig): Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. giou_cost (`float`, *optional*, defaults to 2.0): Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. + class_loss_coefficient (`float`, *optional*, defaults to 2.0): + Relative weight of the cross-entropy loss in the object detection loss. bbox_loss_coefficient (`float`, *optional*, defaults to 5.0): Relative weight of the L1 bounding box loss in the object detection loss. 
giou_loss_coefficient (`float`, *optional*, defaults to 2.0): Relative weight of the generalized IoU loss in the object detection loss. + class_loss_reduction (`str`, *optional*, defaults to `"sum"`): + The reduction method for the classification loss. One of `"mean"` or `"sum"`. focal_alpha (`float`, *optional*, defaults to 0.25): Alpha parameter in the focal loss. disable_custom_kernels (`bool`, *optional*, defaults to `False`): @@ -181,8 +185,10 @@ def __init__( class_cost=1.0, bbox_cost=5.0, giou_cost=2.0, + class_loss_coefficient=2.0, bbox_loss_coefficient=5.0, giou_loss_coefficient=2.0, + class_loss_reduction="sum", focal_alpha=0.25, disable_custom_kernels=False, # other parameters @@ -199,6 +205,11 @@ def __init__( layer_norm_eps=1e-5, **kwargs, ): + if class_loss_reduction not in ["sum", "mean"]: + raise ValueError( + f"Invalid class_loss_reduction: {class_loss_reduction}. It must be either 'sum' or 'mean'." + ) + if backbone_config is None and backbone is None: logger.info("`backbone_config` is `None`. Initializing the config with the default `Swin` backbone.") backbone_config = CONFIG_MAPPING["swin"]( @@ -255,8 +266,10 @@ def __init__( self.bbox_cost = bbox_cost self.giou_cost = giou_cost # Loss coefficients + self.class_loss_coefficient = class_loss_coefficient self.bbox_loss_coefficient = bbox_loss_coefficient self.giou_loss_coefficient = giou_loss_coefficient + self.class_loss_reduction = class_loss_reduction self.focal_alpha = focal_alpha self.disable_custom_kernels = disable_custom_kernels # Text backbone diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 3b298704de32fb..fe6c63a4bc662c 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch Grounding DINO model.""" -import copy import math import os import warnings @@ -259,12 +258,18 @@ class GroundingDinoModelOutput(ModelOutput): weighted average in the text-vision attention, vision-text attention, text-enhancer (self-attention) and multi-scale deformable attention heads. attention softmax, used to compute the weighted average in the bi-attention heads. + encoder_topk_proposals (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*, returned when `config.two_stage=True`): + Top `config.num_queries` scoring bounding boxes indices picked as region proposals in the first stage. enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.two_stage=True`): Predicted bounding boxes scores where the top `config.num_queries` scoring bounding boxes are picked as region proposals in the first stage. Output of bounding box binary classification (i.e. foreground and background). enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`): Logits of predicted bounding boxes coordinates in the first stage. + encoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.two_stage=True`): + Logits of top `config.num_queries` scoring bounding boxes in the first stage. 
+ encoder_pred_boxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`): + Coordinates of top `config.num_queries` scoring bounding boxes in the first stage. """ last_hidden_state: torch.FloatTensor = None @@ -278,8 +283,11 @@ class GroundingDinoModelOutput(ModelOutput): encoder_vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_text_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + encoder_topk_proposals: Optional[torch.FloatTensor] = None enc_outputs_class: Optional[torch.FloatTensor] = None enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + encoder_logits: Optional[torch.FloatTensor] = None + encoder_pred_boxes: Optional[torch.FloatTensor] = None @dataclass @@ -338,12 +346,18 @@ class GroundingDinoObjectDetectionOutput(ModelOutput): Stacked intermediate reference points (reference points of each layer of the decoder). init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): Initial reference points sent through the Transformer decoder. + encoder_topk_proposals (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*, returned when `config.two_stage=True`): + Top `config.num_queries` scoring bounding boxes indices picked as region proposals in the first stage. enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.two_stage=True`): Predicted bounding boxes scores where the top `config.num_queries` scoring bounding boxes are picked as region proposals in the first stage. Output of bounding box binary classification (i.e. foreground and background). enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`): Logits of predicted bounding boxes coordinates in the first stage. + encoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.two_stage=True`): + Logits of top `config.num_queries` scoring bounding boxes in the first stage. + encoder_pred_boxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`): + Coordinates of top `config.num_queries` scoring bounding boxes in the first stage. 
""" loss: Optional[torch.FloatTensor] = None @@ -362,8 +376,11 @@ class GroundingDinoObjectDetectionOutput(ModelOutput): encoder_vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_text_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + encoder_topk_proposals: Optional[torch.FloatTensor] = None enc_outputs_class: Optional[torch.FloatTensor] = None enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + encoder_logits: Optional[torch.FloatTensor] = None + encoder_pred_boxes: Optional[torch.FloatTensor] = None # Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->GroundingDino @@ -2384,8 +2401,11 @@ def forward( ) # Fifth, prepare decoder inputs + topk_proposals = None enc_outputs_class = None enc_outputs_coord_logits = None + encoder_logits = None + encoder_pred_boxes = None if self.config.two_stage: object_query_embedding, output_proposals = self.generate_encoder_output_proposals( encoder_outputs[0], ~mask_flatten, spatial_shapes @@ -2418,6 +2438,10 @@ def forward( target = torch.gather( object_query_embedding, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model) ).detach() + + # Set intermediate topk proposals (coords and class) for loss computation + encoder_pred_boxes = reference_points + encoder_logits = self.encoder_output_class_embed(target, text_features, text_token_mask) else: target = query_embeds.unsqueeze(0).repeat(batch_size, 1, 1) reference_points = self.reference_points.weight.unsqueeze(0).repeat(batch_size, 1, 1).sigmoid() @@ -2440,7 +2464,17 @@ def forward( ) if not return_dict: - enc_outputs = tuple(value for value in [enc_outputs_class, enc_outputs_coord_logits] if value is not None) + enc_outputs = tuple( + value + for value in [ + topk_proposals, + enc_outputs_class, + enc_outputs_coord_logits, + encoder_logits, + encoder_pred_boxes, + ] + if value is not None + ) tuple_outputs = ( (decoder_outputs[0], init_reference_points) + decoder_outputs[1:] + encoder_outputs + enc_outputs ) @@ -2459,8 +2493,11 @@ def forward( encoder_vision_hidden_states=encoder_outputs.vision_hidden_states, encoder_text_hidden_states=encoder_outputs.text_hidden_states, encoder_attentions=encoder_outputs.attentions, + encoder_topk_proposals=topk_proposals, enc_outputs_class=enc_outputs_class, enc_outputs_coord_logits=enc_outputs_coord_logits, + encoder_logits=encoder_logits, + encoder_pred_boxes=encoder_pred_boxes, ) @@ -2554,38 +2591,17 @@ def generalized_box_iou(boxes1, boxes2): return iou - (area - union) / area -# Copied from transformers.models.detr.modeling_detr._max_by_axis -def _max_by_axis(the_list): - # type: (List[List[int]]) -> List[int] - maxes = the_list[0] - for sublist in the_list[1:]: - for index, item in enumerate(sublist): - maxes[index] = max(maxes[index], item) - return maxes - - -# Copied from transformers.models.detr.modeling_detr.dice_loss -def dice_loss(inputs, targets, num_boxes): - """ - Compute the DICE loss, similar to generalized IOU for masks - - Args: - inputs: A float tensor of arbitrary shape. - The predictions for each example. - targets: A float tensor with the same shape as inputs. Stores the binary - classification label for each element in inputs (0 for the negative class and 1 for the positive - class). 
- """ - inputs = inputs.sigmoid() - inputs = inputs.flatten(1) - numerator = 2 * (inputs * targets).sum(1) - denominator = inputs.sum(-1) + targets.sum(-1) - loss = 1 - (numerator + 1) / (denominator + 1) - return loss.sum() / num_boxes - - -# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss -def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): +# Similar to the one used in `DeformableDetr` but we pass `num_queries`, as `logits` are flattened +# due to masked selection, and support different `reduction` modes. +def sigmoid_focal_loss( + inputs: torch.Tensor, + targets: torch.Tensor, + num_boxes: int, + num_queries: int, + reduction: str = "mean", + alpha: float = 0.25, + gamma: float = 2, +): """ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. @@ -2595,9 +2611,15 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f targets (`torch.FloatTensor` with the same shape as `inputs`) A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class and 1 for the positive class). - alpha (`float`, *optional*, defaults to `0.25`): + num_boxes (`int`): + The total number of boxes in the batch. + num_queries (`int`): + The number of query boxes per image. + reduction (`str`, *optional*, defaults to `'mean'`): + Specifies the redction to apply to the loss. Can be either `'mean'`, or `'sum'`. + alpha (`float`, *optional*, defaults to 0.25): Optional weighting factor in the range (0,1) to balance positive vs. negative examples. - gamma (`int`, *optional*, defaults to `2`): + gamma (`int`, *optional*, defaults to 2): Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. Returns: @@ -2613,47 +2635,12 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f alpha_t = alpha * targets + (1 - alpha) * (1 - targets) loss = alpha_t * loss - return loss.mean(1).sum() / num_boxes - - -# Copied from transformers.models.detr.modeling_detr.NestedTensor -class NestedTensor: - def __init__(self, tensors, mask: Optional[Tensor]): - self.tensors = tensors - self.mask = mask - - def to(self, device): - cast_tensor = self.tensors.to(device) - mask = self.mask - if mask is not None: - cast_mask = mask.to(device) - else: - cast_mask = None - return NestedTensor(cast_tensor, cast_mask) - - def decompose(self): - return self.tensors, self.mask - - def __repr__(self): - return str(self.tensors) - - -# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list -def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): - if tensor_list[0].ndim == 3: - max_size = _max_by_axis([list(img.shape) for img in tensor_list]) - batch_shape = [len(tensor_list)] + max_size - batch_size, num_channels, height, width = batch_shape - dtype = tensor_list[0].dtype - device = tensor_list[0].device - tensor = torch.zeros(batch_shape, dtype=dtype, device=device) - mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) - for img, pad_img, m in zip(tensor_list, tensor, mask): - pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - m[: img.shape[1], : img.shape[2]] = False + if reduction == "mean": + return loss.sum() / num_queries / num_boxes + elif reduction == "sum": + return loss.sum() / num_boxes else: - raise ValueError("Only 3-dimensional tensors are supported") - return NestedTensor(tensor, mask) + raise ValueError(f"{reduction=} is not a valid reduction method") # 
Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->GroundingDino @@ -2685,6 +2672,7 @@ def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float raise ValueError("All costs of the Matcher can't be 0") @torch.no_grad() + # Ignore copy def forward(self, outputs, targets): """ Args: @@ -2692,6 +2680,7 @@ def forward(self, outputs, targets): A dictionary that contains at least these entries: * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates. + * "label_maps": Tuple of tensors of dim [num_classes, hidden_dim]. targets (`List[dict]`): A list of targets (len(targets) = batch_size), where each target is a dict containing: * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of @@ -2708,11 +2697,16 @@ def forward(self, outputs, targets): batch_size, num_queries = outputs["logits"].shape[:2] # We flatten to compute the cost matrices in a batch - out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] + out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, hidden_dim] out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + label_maps = outputs["label_maps"] + + # First take the label map for each class in each batch and then concatenate them + label_maps = torch.cat([label_map[target["class_labels"]] for label_map, target in zip(label_maps, targets)]) + # Normalize label maps based on number of tokens per class + label_maps = label_maps / label_maps.sum(dim=-1, keepdim=True) # Also concat the target labels and boxes - target_ids = torch.cat([v["class_labels"] for v in targets]) target_bbox = torch.cat([v["boxes"] for v in targets]) # Compute the classification cost. @@ -2720,7 +2714,8 @@ def forward(self, outputs, targets): gamma = 2.0 neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log()) pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) - class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids] + # Compute the classification cost by taking pos and neg cost in the appropriate index + class_cost = (pos_cost_class - neg_cost_class) @ label_maps.t() # Compute the L1 cost between boxes bbox_cost = torch.cdist(out_bbox, target_bbox, p=1) @@ -2737,7 +2732,6 @@ def forward(self, outputs, targets): return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss with DeformableDetr->GroundingDino class GroundingDinoLoss(nn.Module): """ This class computes the losses for `GroundingDinoForObjectDetection`. The process happens in two steps: 1) we @@ -2747,22 +2741,41 @@ class GroundingDinoLoss(nn.Module): Args: matcher (`GroundingDinoHungarianMatcher`): Module able to compute a matching between targets and proposals. - num_classes (`int`): - Number of object categories, omitting the special no-object category. focal_alpha (`float`): Alpha parameter in focal loss. losses (`List[str]`): List of all the losses to be applied. See `get_loss` for a list of all available losses. + class_reduction (`str`): + Specifies the reduction to apply to the label loss. 
Can be either `'mean'` or `'sum'` """ - def __init__(self, matcher, num_classes, focal_alpha, losses): + def __init__(self, matcher, focal_alpha, losses, class_reduction): super().__init__() self.matcher = matcher - self.num_classes = num_classes self.focal_alpha = focal_alpha self.losses = losses + self.class_reduction = class_reduction + + def _get_target_classes_one_hot(self, outputs, targets, indices): + """ + Create one_hot based on the matching indices + """ + logits = outputs["logits"] + # Add offsets to class_labels to select the correct label map + class_labels = torch.cat( + [ + target["class_labels"][J] + len(outputs["label_maps"][i]) if i > 0 else target["class_labels"][J] + for i, (target, (_, J)) in enumerate(zip(targets, indices)) + ] + ) + label_maps = torch.cat(outputs["label_maps"], dim=0) + + idx = self._get_source_permutation_idx(indices) + target_classes_onehot = torch.zeros_like(logits, device=logits.device, dtype=torch.long) + target_classes_onehot[idx] = label_maps[class_labels].to(torch.long) + + return target_classes_onehot - # removed logging parameter, which was part of the original implementation def loss_labels(self, outputs, targets, indices, num_boxes): """ Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor @@ -2770,50 +2783,35 @@ def loss_labels(self, outputs, targets, indices, num_boxes): """ if "logits" not in outputs: raise KeyError("No logits were found in the outputs") - source_logits = outputs["logits"] - - idx = self._get_source_permutation_idx(indices) - target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)]) - target_classes = torch.full( - source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device - ) - target_classes[idx] = target_classes_o + if "text_mask" not in outputs: + raise KeyError("No text_mask were found in the outputs") - target_classes_onehot = torch.zeros( - [source_logits.shape[0], source_logits.shape[1], source_logits.shape[2] + 1], - dtype=source_logits.dtype, - layout=source_logits.layout, - device=source_logits.device, + target_classes_onehot = self._get_target_classes_one_hot(outputs, targets, indices) + source_logits = outputs["logits"] + text_mask = outputs["text_mask"] + + # Select only valid logits + source_logits = torch.masked_select(source_logits, text_mask) + target_classes_onehot = torch.masked_select(target_classes_onehot, text_mask) + + num_queries = source_logits.shape[0] + + target_classes_onehot = target_classes_onehot.float() + loss_ce = sigmoid_focal_loss( + inputs=source_logits, + targets=target_classes_onehot, + num_boxes=num_boxes, + num_queries=num_queries, + reduction=self.class_reduction, + alpha=self.focal_alpha, + gamma=2, ) - target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) - target_classes_onehot = target_classes_onehot[:, :, :-1] - loss_ce = ( - sigmoid_focal_loss(source_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) - * source_logits.shape[1] - ) losses = {"loss_ce": loss_ce} return losses - @torch.no_grad() - # Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_cardinality - def loss_cardinality(self, outputs, targets, indices, num_boxes): - """ - Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. - - This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients. 
- """ - logits = outputs["logits"] - device = logits.device - target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) - # Count the number of predictions that are NOT "no-object" (which is the last class) - card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1) - card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float()) - losses = {"cardinality_error": card_err} - return losses - - # Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_boxes + # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_boxes def loss_boxes(self, outputs, targets, indices, num_boxes): """ Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. @@ -2838,14 +2836,31 @@ def loss_boxes(self, outputs, targets, indices, num_boxes): losses["loss_giou"] = loss_giou.sum() / num_boxes return losses - # Copied from transformers.models.detr.modeling_detr.DetrLoss._get_source_permutation_idx + @torch.no_grad() + # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_cardinality + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ + Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. + + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients. + """ + logits = outputs["logits"] + device = logits.device + target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1) + card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float()) + losses = {"cardinality_error": card_err} + return losses + + # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss._get_source_permutation_idx def _get_source_permutation_idx(self, indices): # permute predictions following indices batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)]) source_idx = torch.cat([source for (source, _) in indices]) return batch_idx, source_idx - # Copied from transformers.models.detr.modeling_detr.DetrLoss._get_target_permutation_idx + # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss._get_target_permutation_idx def _get_target_permutation_idx(self, indices): # permute targets following indices batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)]) @@ -2855,8 +2870,8 @@ def _get_target_permutation_idx(self, indices): def get_loss(self, loss, outputs, targets, indices, num_boxes): loss_map = { "labels": self.loss_labels, - "cardinality": self.loss_cardinality, "boxes": self.loss_boxes, + "cardinality": self.loss_cardinality, } if loss not in loss_map: raise ValueError(f"Loss {loss} not supported") @@ -2904,18 +2919,82 @@ def forward(self, outputs, targets): if "enc_outputs" in outputs: enc_outputs = outputs["enc_outputs"] - bin_targets = copy.deepcopy(targets) - for bt in bin_targets: - bt["class_labels"] = torch.zeros_like(bt["class_labels"]) - indices = self.matcher(enc_outputs, bin_targets) + indices = self.matcher(enc_outputs, targets) for loss in self.losses: - l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes) + l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes) l_dict = 
{k + "_enc": v for k, v in l_dict.items()} losses.update(l_dict) return losses +def build_label_maps(logits: torch.FloatTensor, input_ids: torch.LongTensor) -> Tuple[torch.FloatTensor]: + """ + Computes a mapping between tokens and their corresponding labels, where `num_labels` is determined by the number of classes in the input prompt. + The function identifies segments of tokens between specific delimiter tokens and generates label maps for those segments. + Args: + logits (`torch.Tensor` of shape `(batch_size, seq_length, hidden_size)`): + The output logits from the model, where `hidden_size` corresponds to the dimension of the model's output features. + + input_ids (`torch.Tensor` of shape `(batch_size, seq_length)`): + The input token IDs corresponding to the input prompt. For example, given the prompt "fish. shark.", + `input_ids` might look like `[101, 3869, 1012, 11420, 1012, 102]` where each number corresponds to a token including special tokens. + Returns: + tuple: A tuple containing label maps for each instance in the batch. + - label_maps (tuple of `torch.Tensor`): + A tuple of tensors, where each tensor in the tuple corresponds to an instance in the batch. Each tensor + has shape `(num_labels, hidden_size)` and contains binary values (0 or 1), where `1` indicates the tokens + that are associated with a specific label (class) between delimiter tokens, and `0` elsewhere. + Example: + Given an input prompt "fish. shark." and corresponding `input_ids` as `[101, 3869, 1012, 11420, 1012, 102]`: + - The function identifies the tokens for "fish" (IDs `[3869]`) and "shark" (IDs `[11420]`). + - The function then constructs label maps for these tokens, where each label map indicates which tokens + correspond to which label between the delimiter tokens (e.g., between the period `.`). + - The output is a tuple of label maps, one for each instance in the batch. + Note: + - `SPECIAL_TOKENS` should be a predefined list of tokens that are considered special (e.g., `[CLS]`, `[SEP]`, etc.). + """ + max_seq_len = logits.shape[-1] + # Add [PAD] token to the list of special tokens + delimiter_tokens = torch.tensor(SPECIAL_TOKENS + [0], device=input_ids.device) + + delimiter_token_masks = torch.isin(input_ids, delimiter_tokens) + label_groups = torch.cumsum(delimiter_token_masks, dim=1) * (~delimiter_token_masks).to(torch.int32) + + label_maps = () + + # Iterate over batch dimension as we can have different number of labels + for label_group in label_groups: + # `label_group` is a tensor of shape `(seq_len,)` with zeros for non-label tokens and integers for label tokens + # label tokens with same integer value are part of the same label group + + # Get unique labels and exclude 0 (i.e. 
non-label tokens) + unique_labels = torch.unique(label_group)[1:, None] + num_labels = unique_labels.shape[0] + + # Create one-hot encoding for each label group + label_map = label_group.unsqueeze(0).repeat(num_labels, 1) + label_map = torch.where(label_map == unique_labels, 1, 0) + + # Pad label_map to match `max_seq_len` + label_map = F.pad(label_map, (0, max_seq_len - label_map.shape[1]), value=0) + + label_maps += (label_map,) + + return label_maps + + +def build_text_mask(logits, attention_mask): + """ + Create text_mask based on the matching indices + """ + seq_len = attention_mask.shape[1] + text_mask = torch.zeros_like(logits, device=logits.device, dtype=attention_mask.dtype) + text_mask[:, :, :seq_len] = attention_mask[:, None, :] + + return text_mask.bool() + + @add_start_docstrings( """ Grounding DINO Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, @@ -2956,11 +3035,14 @@ def __init__(self, config: GroundingDinoConfig): # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py @torch.jit.unused - def _set_aux_loss(self, outputs_class, outputs_coord): + def _set_aux_loss(self, outputs_class, outputs_coord, label_maps, text_mask): # this is a workaround to make torchscript happy, as torchscript # doesn't support dictionary with non-homogeneous values, such # as a dict having both a Tensor and a list. - return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + return [ + {"logits": a, "pred_boxes": b, "label_maps": label_maps, "text_mask": text_mask} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) + ] @add_start_docstrings_to_model_forward(GROUNDING_DINO_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=GroundingDinoObjectDetectionOutput, config_class=_CONFIG_FOR_DOC) @@ -3085,31 +3167,50 @@ def forward( losses = ["labels", "boxes", "cardinality"] criterion = GroundingDinoLoss( matcher=matcher, - num_classes=self.config.num_labels, + class_reduction=self.config.class_loss_reduction, focal_alpha=self.config.focal_alpha, losses=losses, ) criterion.to(self.device) + label_maps = build_label_maps(logits, input_ids) + text_mask = build_text_mask(logits, attention_mask) # Third: compute the losses, based on outputs and labels outputs_loss = {} outputs_loss["logits"] = logits outputs_loss["pred_boxes"] = pred_boxes + outputs_loss["label_maps"] = label_maps + outputs_loss["text_mask"] = text_mask + if self.config.auxiliary_loss: - auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord) + auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord, label_maps, text_mask) outputs_loss["auxiliary_outputs"] = auxiliary_outputs + if self.config.two_stage: - enc_outputs_coord = outputs[-1].sigmoid() - outputs_loss["enc_outputs"] = {"logits": outputs[-2], "pred_boxes": enc_outputs_coord} + outputs_loss["enc_outputs"] = { + "logits": outputs[-2], + "pred_boxes": outputs[-1], + "label_maps": label_maps, + "text_mask": text_mask, + } loss_dict = criterion(outputs_loss, labels) # Fourth: compute total loss, as a weighted sum of the various losses - weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient} - weight_dict["loss_giou"] = self.config.giou_loss_coefficient + weight_dict = { + "loss_ce": self.config.class_loss_coefficient, + "loss_bbox": self.config.bbox_loss_coefficient, + "loss_giou": self.config.giou_loss_coefficient, + } + + if self.config.two_stage: + enc_weight_dict = {k + "_enc": v for k, v in weight_dict.items()} 
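+            # Apply the same coefficients to the first-stage (encoder) losses, which are keyed with an "_enc" suffix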
+ weight_dict.update(enc_weight_dict) + if self.config.auxiliary_loss: aux_weight_dict = {} for i in range(self.config.decoder_layers - 1): aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) weight_dict.update(aux_weight_dict) + loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) if not return_dict: @@ -3138,8 +3239,11 @@ def forward( intermediate_hidden_states=outputs.intermediate_hidden_states, intermediate_reference_points=outputs.intermediate_reference_points, init_reference_points=outputs.init_reference_points, + encoder_topk_proposals=outputs.encoder_topk_proposals, enc_outputs_class=outputs.enc_outputs_class, enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + encoder_logits=outputs.encoder_logits, + encoder_pred_boxes=outputs.encoder_pred_boxes, ) return dict_outputs diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 216d5cd4296008..b79f20205f8978 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2346,25 +2346,16 @@ def _inner_training_loop( total_batched_samples += 1 if self.args.include_num_input_tokens_seen: - main_input_name = getattr(self.model, "main_input_name", "input_ids") - if main_input_name not in inputs: - logger.warning( - "Tried to track the number of tokens seen, however the current model is " - "not configured properly to know what item is the input. To fix this, add " - "a `main_input_name` attribute to the model class you are using." - ) - else: - self.state.num_input_tokens_seen += ( - torch.sum( - self.accelerator.gather( - torch.tensor( - inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64 - ) - ) + selected_inputs = self._select_inputs_for_validation(inputs) + self.state.num_input_tokens_seen += ( + torch.sum( + self.accelerator.gather( + torch.tensor(selected_inputs, device=self.args.device, dtype=torch.int64) ) - .cpu() - .item() ) + .cpu() + .item() + ) if rng_to_sync: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False @@ -3402,6 +3393,11 @@ def log(self, logs: Dict[str, float]) -> None: self.state.log_history.append(output) self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) + def _select_inputs_for_validation(self, inputs): + """Simple getattr function for getting input by name""" + main_input_name = getattr(self.model, "main_input_name", "input_ids") + return inputs[main_input_name] + def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]: """ Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. 
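As a reading aid for the refactor above, here is a minimal standalone sketch of what `_select_inputs_for_validation` resolves: the model's `main_input_name` attribute picks the tensor used for token counting and for `include_inputs_for_metrics`, with `"input_ids"` as the fallback. The toy model and the free function below are illustrative stand-ins, not part of the patch.

```python
import torch

# Illustrative stand-in; the real logic lives in `Trainer._select_inputs_for_validation`.
class ToyVisionModel(torch.nn.Module):
    main_input_name = "pixel_values"  # vision models override the default "input_ids"

def select_inputs_for_validation(model, inputs):
    # Same lookup as the Trainer helper: prefer the model's declared main input,
    # fall back to "input_ids" when the attribute is missing.
    main_input_name = getattr(model, "main_input_name", "input_ids")
    return inputs[main_input_name]

inputs = {
    "pixel_values": torch.rand(2, 3, 224, 224),
    "input_ids": torch.randint(0, 1000, (2, 6)),
}
print(select_inputs_for_validation(ToyVisionModel(), inputs).shape)       # torch.Size([2, 3, 224, 224])
print(select_inputs_for_validation(torch.nn.Linear(2, 2), inputs).shape)  # fallback: torch.Size([2, 6])
```

Models that consume images rather than token ids, such as the detection models touched by this patch, only need to expose `main_input_name` for the token-seen counter and input logging to pick the right tensor.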
@@ -4059,8 +4055,8 @@ def evaluation_loop( # Prediction step losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) - main_input_name = getattr(self.model, "main_input_name", "input_ids") - inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None + selected_inputs = self._select_inputs_for_validation(inputs) + inputs_decode = self._prepare_input(selected_inputs) if args.include_inputs_for_metrics else None if is_torch_xla_available(): xm.mark_step() @@ -4655,8 +4651,8 @@ def prediction_loop( for step, inputs in enumerate(dataloader): loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) - main_input_name = getattr(self.model, "main_input_name", "input_ids") - inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None + selected_inputs = self._select_inputs_for_validation(inputs) + inputs_decode = self._prepare_input(selected_inputs) if args.include_inputs_for_metrics else None if loss is not None: losses = loss.repeat(batch_size) diff --git a/tests/models/grounding_dino/test_modeling_grounding_dino.py b/tests/models/grounding_dino/test_modeling_grounding_dino.py index c6e9671dd59ae0..f70b08ba5c6039 100644 --- a/tests/models/grounding_dino/test_modeling_grounding_dino.py +++ b/tests/models/grounding_dino/test_modeling_grounding_dino.py @@ -20,6 +20,8 @@ import re import unittest +from datasets import load_dataset + from transformers import ( GroundingDinoConfig, SwinConfig, @@ -37,14 +39,14 @@ ) from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): import torch - from transformers import GroundingDinoForObjectDetection, GroundingDinoModel + from transformers import GroundingDinoConfig, GroundingDinoForObjectDetection, GroundingDinoModel from transformers.pytorch_utils import id_tensor_storage @@ -54,6 +56,39 @@ from transformers import AutoProcessor +def generate_fake_bounding_boxes(n_boxes): + """Generate bounding boxes in the format (cx, cy, w, h)""" + # Validate the input + if not isinstance(n_boxes, int): + raise ValueError("n_boxes must be an integer") + if n_boxes <= 0: + raise ValueError("n_boxes must be a positive integer") + + # Generate random bounding boxes in the format (cx, cy, w, h) + bounding_boxes = torch.rand((n_boxes, 4)) + + # Extract the components + cx = bounding_boxes[:, 0] + cy = bounding_boxes[:, 1] + w = bounding_boxes[:, 2] + h = bounding_boxes[:, 3] + + # Ensure width and height do not exceed bounds + w = torch.min(w, torch.tensor(1.0)) + h = torch.min(h, torch.tensor(1.0)) + + # Ensure the bounding box stays within the normalized space + cx = torch.where(cx - w / 2 < 0, w / 2, cx) + cx = torch.where(cx + w / 2 > 1, 1 - w / 2, cx) + cy = torch.where(cy - h / 2 < 0, h / 2, cy) + cy = torch.where(cy + h / 2 > 1, 1 - h / 2, cy) + + # Combine back into bounding boxes + bounding_boxes = torch.stack([cx, cy, w, h], dim=1) + + return bounding_boxes + + class GroundingDinoModelTester: def __init__( self, @@ -72,7 +107,7 @@ def __init__( num_channels=3, image_size=98, n_targets=8, - num_labels=3, + num_labels=2, num_feature_levels=4, encoder_n_points=2, decoder_n_points=6, @@ -115,7 +150,11 @@ def 
prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) - input_ids = ids_tensor([self.batch_size, self.max_text_len], self.num_labels) + # When using `GroundingDino` the text input template is '{label1}. {label2}. {label3. ... {labelN}.' + # Therefore to avoid errors when running tests with `labels` `input_ids` have to follow this structure. + # Otherwise when running `build_label_maps` it will throw an error when trying to split the input_ids into segments. + input_ids = torch.tensor([101, 3869, 1012, 11420, 3869, 1012, 102]) + input_ids = input_ids.unsqueeze(0).expand(self.batch_size, -1) labels = None if self.use_labels: @@ -126,7 +165,7 @@ def prepare_config_and_inputs(self): target["class_labels"] = torch.randint( high=self.num_labels, size=(self.n_targets,), device=torch_device ) - target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + target["boxes"] = generate_fake_bounding_boxes(self.n_targets).to(torch_device) target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device) labels.append(target) @@ -317,7 +356,7 @@ def test_attention_outputs(self): ) out_len = len(outputs) - correct_outlen = 10 + correct_outlen = 13 # loss is at first position if "labels" in inputs_dict: @@ -741,3 +780,53 @@ def test_cross_attention_mask(self): self.assertTrue(torch.allclose(outputs1.logits, outputs_batched.logits[:1], atol=1e-3)) # For some reason 12 elements are > 1e-3, but the rest are fine self.assertTrue(torch.allclose(outputs2.logits, outputs_batched.logits[1:], atol=1.8e-3)) + + def test_grounding_dino_loss(self): + ds = load_dataset("EduardoPacheco/aquarium-sample", split="train") + image_processor = self.default_processor.image_processor + tokenizer = self.default_processor.tokenizer + id2label = {0: "fish", 1: "jellyfish", 2: "penguins", 3: "sharks", 4: "puffins", 5: "stingrays", 6: "starfish"} + prompt = ". ".join(id2label.values()) + "." 
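+        # GroundingDino expects the text prompt as "label1. label2. ... labelN." (one entry per class)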
+ + text_inputs = tokenizer([prompt, prompt], return_tensors="pt") + image_inputs = image_processor(images=ds["image"], annotations=ds["annotations"], return_tensors="pt") + + # Passing class_reduction="sum" and auxiliary_loss=True to compare with the expected loss + model = GroundingDinoForObjectDetection.from_pretrained( + "IDEA-Research/grounding-dino-tiny", auxiliary_loss=True, class_loss_reduction="sum" + ) + # Interested in the loss only + model.eval() + with torch.no_grad(): + outputs = model(**text_inputs, **image_inputs) + + expected_loss_dict = { + "loss_ce": torch.tensor(1.1151), + "loss_bbox": torch.tensor(0.2031), + "loss_giou": torch.tensor(0.5819), + "loss_ce_0": torch.tensor(1.1942), + "loss_bbox_0": torch.tensor(0.1978), + "loss_giou_0": torch.tensor(0.5524), + "loss_ce_1": torch.tensor(1.1623), + "loss_bbox_1": torch.tensor(0.1909), + "loss_giou_1": torch.tensor(0.5892), + "loss_ce_2": torch.tensor(1.1643), + "loss_bbox_2": torch.tensor(0.1891), + "loss_giou_2": torch.tensor(0.5626), + "loss_ce_3": torch.tensor(1.1945), + "loss_bbox_3": torch.tensor(0.1943), + "loss_giou_3": torch.tensor(0.5592), + "loss_ce_4": torch.tensor(1.0946), + "loss_bbox_4": torch.tensor(0.2037), + "loss_giou_4": torch.tensor(0.5813), + "loss_ce_enc": torch.tensor(16226.3145), + "loss_bbox_enc": torch.tensor(0.3063), + "loss_giou_enc": torch.tensor(0.7380), + } + + expected_loss = torch.tensor(32482.2344) + + for key in expected_loss_dict: + self.assertTrue(torch.allclose(outputs.loss_dict[key], expected_loss_dict[key], atol=1e-3)) + + self.assertTrue(torch.allclose(outputs.loss, expected_loss, atol=1e-3))
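As a companion to the loss test above, the following self-contained sketch reproduces the label-map construction that `build_label_maps` performs for the documented "fish. shark." prompt. The `SPECIAL_TOKENS` ids used here (BERT's [CLS], [SEP], "." and "?") are assumed for illustration only; the real constant is defined in `modeling_grounding_dino.py`.

```python
import torch
import torch.nn.functional as F

# Assumed special-token ids ([CLS], [SEP], ".", "?") for illustration;
# the actual SPECIAL_TOKENS constant lives in modeling_grounding_dino.py.
SPECIAL_TOKENS = [101, 102, 1012, 1029]

def build_label_maps_sketch(input_ids: torch.LongTensor, max_seq_len: int):
    # Tokens equal to a delimiter (or [PAD] == 0) split the prompt into label groups
    delimiter_tokens = torch.tensor(SPECIAL_TOKENS + [0])
    delimiter_token_masks = torch.isin(input_ids, delimiter_tokens)
    label_groups = torch.cumsum(delimiter_token_masks, dim=1) * (~delimiter_token_masks).to(torch.int32)

    label_maps = ()
    for label_group in label_groups:
        # Drop group 0 (delimiters/special tokens), keep one row per remaining class
        unique_labels = torch.unique(label_group)[1:, None]
        label_map = (label_group.unsqueeze(0) == unique_labels).to(torch.int64)
        # Pad to the classification dimension of the logits
        label_map = F.pad(label_map, (0, max_seq_len - label_map.shape[1]), value=0)
        label_maps += (label_map,)
    return label_maps

# "fish. shark." -> [CLS] fish . shark . [SEP]
input_ids = torch.tensor([[101, 3869, 1012, 11420, 1012, 102]])
print(build_label_maps_sketch(input_ids, max_seq_len=8)[0])
# tensor([[0, 1, 0, 0, 0, 0, 0, 0],    row 0 -> "fish"
#         [0, 0, 0, 1, 0, 0, 0, 0]])   row 1 -> "shark"
```

Each row of the resulting map is later normalized by its token count in the Hungarian matcher, so multi-token class names contribute the same total weight to the classification cost as single-token ones.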