remove OmdetTurboModel
yonigozlan committed Sep 20, 2024
1 parent 66ef0b9 commit 0980201
Showing 8 changed files with 35 additions and 231 deletions.
5 changes: 0 additions & 5 deletions docs/source/en/model_doc/omdet-turbo.md
@@ -158,11 +158,6 @@ Detected statue with confidence 0.2 at location [428.1, 205.5, 767.3, 759.5] in
[[autodoc]] OmDetTurboProcessor
- post_process_grounded_object_detection

## OmDetTurboModel

[[autodoc]] OmDetTurboModel
- forward

## OmDetTurboForObjectDetection

[[autodoc]] OmDetTurboForObjectDetection
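With the `OmDetTurboModel` section gone from the doc page, `OmDetTurboForObjectDetection` is the only documented model entry point. A minimal usage sketch, with the checkpoint name, prompts, and processor call taken from the docstring examples elsewhere in this commit; the output shape is only the expected one:

```python
# Minimal sketch: raw detection outputs now come from OmDetTurboForObjectDetection,
# since the OmDetTurboModel doc section above is removed.
import requests
from PIL import Image

from transformers import AutoProcessor, OmDetTurboForObjectDetection

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-tiny")
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-tiny")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
classes = ["cat", "remote"]
inputs = processor(image, text=classes, task="Detect {}.".format(", ".join(classes)), return_tensors="pt")

outputs = model(**inputs)
# Last-decoder-layer box logits; expected shape (batch_size, num_queries, 4).
print(outputs.decoder_coord_logits.shape)
```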
2 changes: 0 additions & 2 deletions src/transformers/__init__.py
@@ -2851,7 +2851,6 @@
_import_structure["models.omdet_turbo"].extend(
[
"OmDetTurboForObjectDetection",
"OmDetTurboModel",
"OmDetTurboPreTrainedModel",
]
)
@@ -7368,7 +7367,6 @@
)
from .models.omdet_turbo import (
OmDetTurboForObjectDetection,
OmDetTurboModel,
OmDetTurboPreTrainedModel,
)
from .models.oneformer import (
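A quick sketch of the resulting public import surface after the two deletions above; the import failure for the removed name is the expected consequence of dropping it from the lazy import structure, not something verified here:

```python
# Only these OmDet-Turbo classes remain exported from the top-level namespace.
from transformers import OmDetTurboForObjectDetection, OmDetTurboPreTrainedModel

# The removed class no longer resolves; importing it should now fail.
try:
    from transformers import OmDetTurboModel
except ImportError as err:
    print(err)
```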
2 changes: 1 addition & 1 deletion src/transformers/models/auto/modeling_auto.py
@@ -180,7 +180,7 @@
("nystromformer", "NystromformerModel"),
("olmo", "OlmoModel"),
("olmoe", "OlmoeModel"),
("omdet-turbo", "OmDetTurboModel"),
("omdet-turbo", "OmDetTurboForObjectDetection"),
("oneformer", "OneFormerModel"),
("open-llama", "OpenLlamaModel"),
("openai-gpt", "OpenAIGPTModel"),
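Because the mapping above is the base-model mapping, the generic `AutoModel` entry point now resolves omdet-turbo checkpoints to `OmDetTurboForObjectDetection`. A sketch, assuming the `omlab/omdet-turbo-tiny` checkpoint name used in the docstrings of this commit:

```python
from transformers import AutoModel

# MODEL_MAPPING now maps the "omdet-turbo" model type to OmDetTurboForObjectDetection,
# so AutoModel returns the detection model directly.
model = AutoModel.from_pretrained("omlab/omdet-turbo-tiny")
print(type(model).__name__)  # expected: OmDetTurboForObjectDetection
```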
2 changes: 0 additions & 2 deletions src/transformers/models/omdet_turbo/__init__.py
@@ -30,7 +30,6 @@
else:
_import_structure["modeling_omdet_turbo"] = [
"OmDetTurboForObjectDetection",
"OmDetTurboModel",
"OmDetTurboPreTrainedModel",
]

@@ -48,7 +47,6 @@
else:
from .modeling_omdet_turbo import (
OmDetTurboForObjectDetection,
OmDetTurboModel,
OmDetTurboPreTrainedModel,
)

src/transformers/models/omdet_turbo/configuration_omdet_turbo.py
@@ -25,7 +25,7 @@

class OmDetTurboConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`OmDetTurboModel`].
This is the configuration class to store the configuration of a [`OmDetTurboForObjectDetection`].
It is used to instantiate a OmDet-Turbo model according to the specified arguments, defining the model architecture
Instantiating a configuration with the defaults will yield a similar configuration to that of the OmDet-Turbo
[yonigozlan/omdet-turbo-tiny](https://huggingface.co/yonigozlan/omdet-turbo-tiny) architecture.
@@ -132,13 +132,13 @@ class OmDetTurboConfig(PretrainedConfig):
Examples:
```python
>>> from transformers import OmDetTurboConfig, OmDetTurboModel
>>> from transformers import OmDetTurboConfig, OmDetTurboForObjectDetection
>>> # Initializing a OmDet-Turbo omlab/omdet-turbo-tiny style configuration
>>> configuration = OmDetTurboConfig()
>>> # Initializing a model (with random weights) from the omlab/omdet-turbo-tiny style configuration
>>> model = OmDetTurboModel(configuration)
>>> model = OmDetTurboForObjectDetection(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
226 changes: 29 additions & 197 deletions src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
@@ -121,60 +121,6 @@ class OmDetTurboDecoderOutput(ModelOutput):
intermediate_reference_points: Tuple[Tuple[torch.FloatTensor]] = None


@dataclass
class OmDetTurboModelOutput(ModelOutput):
"""
Output type of [`OmDetTurboObjectDetectionOutput`].
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder.
decoder_coords (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
The predicted coordinates logits of the objects.
decoder_classes (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes)`):
The predicted classes of the objects.
init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
The initial reference points.
intermediate_reference_points (`Tuple[Tuple[torch.FloatTensor]]`):
The intermediate reference points.
encoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
The predicted coordinates of the objects from the encoder.
encoder_class_logits (`Tuple[torch.FloatTensor]`):
The predicted class of the objects from the encoder.
encoder_extracted_states (`torch.FloatTensor`):
The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder.
decoder_hidden_states (`Optional[Tuple[torch.FloatTensor]]`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
plus the initial embedding outputs.
decoder_attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`):
Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
encoder_hidden_states (`Optional[Tuple[torch.FloatTensor]]`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
plus the initial embedding outputs.
encoder_attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`):
Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
"""

last_hidden_state: torch.FloatTensor = None
decoder_coords: torch.FloatTensor = None
decoder_classes: torch.FloatTensor = None
init_reference_points: torch.FloatTensor = None
intermediate_reference_points: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
encoder_coord_logits: torch.FloatTensor = None
encoder_class_logits: Tuple[torch.FloatTensor] = None
encoder_extracted_states: torch.FloatTensor = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class OmDetTurboObjectDetectionOutput(ModelOutput):
"""
@@ -1709,7 +1655,7 @@ def forward(
""",
OMDET_TURBO_START_DOCSTRING,
)
class OmDetTurboModel(OmDetTurboPreTrainedModel):
class OmDetTurboForObjectDetection(OmDetTurboPreTrainedModel):
def __init__(self, config: OmDetTurboConfig):
super().__init__(config)
self.vision_backbone = OmDetTurboVisionBackbone(config)
@@ -1741,7 +1687,7 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
return model_embeds

@add_start_docstrings_to_model_forward(OMDET_TURBO_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OmDetTurboModelOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=OmDetTurboObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
@@ -1754,7 +1700,7 @@ def forward(
output_attentions=None,
output_hidden_states=None,
return_dict=None,
) -> Union[Tuple[torch.FloatTensor], OmDetTurboModelOutput]:
) -> Union[Tuple[torch.FloatTensor], OmDetTurboObjectDetectionOutput]:
r"""
Returns:
@@ -1777,9 +1723,24 @@
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 900, 256]
>>> # convert outputs (bounding boxes and class logits)
>>> results = processor.post_process_grounded_object_detection(
... outputs,
... classes=classes,
... target_sizes=[image.size[::-1]],
... score_threshold=0.3,
... nms_threshold=0.3,
>>> )[0]
>>> for score, class_name, box in zip(results["scores"], results["classes"], results["boxes"]):
... box = [round(i, 1) for i in box.tolist()]
... print(
... f"Detected {class_name} with confidence "
... f"{round(score.item(), 2)} at location {box}"
... )
Detected remote with confidence 0.76 at location [39.9, 71.3, 176.5, 117.9]
Detected cat with confidence 0.72 at location [345.1, 22.5, 639.7, 371.9]
Detected cat with confidence 0.65 at location [12.7, 53.8, 315.5, 475.3]
Detected remote with confidence 0.57 at location [333.4, 75.6, 370.7, 187.0]
```"""
if labels is not None:
raise NotImplementedError("Training is not implemented yet")
@@ -1790,6 +1751,7 @@
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

loss = None
image_features = self.vision_backbone(pixel_values)
encoder_outputs = self.encoder(
image_features,
@@ -1819,10 +1781,9 @@
return tuple(
output
for output in [
None,
decoder_outputs[0],
decoder_outputs[3],
decoder_outputs[4],
loss,
decoder_outputs[3][-1],
decoder_outputs[4][-1],
decoder_outputs[7],
decoder_outputs[8],
decoder_outputs[5],
@@ -1836,10 +1797,10 @@ def forward(
if output is not None
)

return OmDetTurboModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_coords=decoder_outputs.decoder_coords,
decoder_classes=decoder_outputs.decoder_classes,
return OmDetTurboObjectDetectionOutput(
loss=loss,
decoder_coord_logits=decoder_outputs.decoder_coords[-1],
decoder_class_logits=decoder_outputs.decoder_classes[-1],
init_reference_points=decoder_outputs.init_reference_points,
intermediate_reference_points=decoder_outputs.intermediate_reference_points,
encoder_coord_logits=decoder_outputs.encoder_coord_logits,
@@ -1850,132 +1811,3 @@ def forward(
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)


@add_start_docstrings(
"""
OmDetTurbo Model (consisting of a vision and a text backbone, and encoder-decoder architecture) outputting
raw hidden states.
""",
OMDET_TURBO_START_DOCSTRING,
)
class OmDetTurboForObjectDetection(OmDetTurboPreTrainedModel):
def __init__(self, config: OmDetTurboConfig):
super().__init__(config)
self.model = OmDetTurboModel(config)
self.vocab_size = config.text_config.vocab_size
self.post_init()

def get_input_embeddings(self):
return self.model.language_backbone.model.get_input_embeddings()

def set_input_embeddings(self, value):
self.model.language_backbone.model.set_input_embeddings(value)

def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.model.language_backbone.model.resize_token_embeddings(
new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
)
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds

@add_start_docstrings_to_model_forward(OMDET_TURBO_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OmDetTurboObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
classes_input_ids: Tensor,
classes_attention_mask: Tensor,
tasks_input_ids: Tensor,
tasks_attention_mask: Tensor,
classes_structure: Tensor,
labels: Optional[Tensor] = None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
) -> Union[Tuple[torch.FloatTensor], OmDetTurboObjectDetectionOutput]:
r"""
Returns:
Examples:
```python
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoProcessor, OmDetTurboForObjectDetection
>>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-tiny")
>>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-tiny")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> classes = ["cat", "remote"]
>>> task = "Detect {}.".format(", ".join(classes))
>>> inputs = processor(image, text=classes, task=task, return_tensors="pt")
>>> outputs = model(**inputs)
>>> # convert outputs (bounding boxes and class logits)
>>> results = processor.post_process_grounded_object_detection(
... outputs,
... classes=classes,
... target_sizes=[image.size[::-1]],
... score_threshold=0.3,
... nms_threshold=0.3,
>>> )[0]
>>> for score, class_name, box in zip(results["scores"], results["classes"], results["boxes"]):
... box = [round(i, 1) for i in box.tolist()]
... print(
... f"Detected {class_name} with confidence "
... f"{round(score.item(), 2)} at location {box}"
... )
Detected remote with confidence 0.76 at location [39.9, 71.3, 176.5, 117.9]
Detected cat with confidence 0.72 at location [345.1, 22.5, 639.7, 371.9]
Detected cat with confidence 0.65 at location [12.7, 53.8, 315.5, 475.3]
Detected remote with confidence 0.57 at location [333.4, 75.6, 370.7, 187.0]
```"""
if labels is not None:
raise NotImplementedError("Training is not implemented yet")

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

loss = None
outputs = self.model(
pixel_values=pixel_values,
classes_input_ids=classes_input_ids,
classes_attention_mask=classes_attention_mask,
tasks_input_ids=tasks_input_ids,
tasks_attention_mask=tasks_attention_mask,
classes_structure=classes_structure,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

if not return_dict:
object_detection_outputs = (loss, outputs[1][-1], outputs[2][-1]) + outputs[3:]
return tuple(output for output in object_detection_outputs if output is not None)

decoder_coord_logits = outputs.decoder_coords[-1]
decoder_class_logits = outputs.decoder_classes[-1]

return OmDetTurboObjectDetectionOutput(
loss=loss,
decoder_coord_logits=decoder_coord_logits,
decoder_class_logits=decoder_class_logits,
init_reference_points=outputs.init_reference_points,
intermediate_reference_points=outputs.intermediate_reference_points,
encoder_coord_logits=outputs.encoder_coord_logits,
encoder_class_logits=outputs.encoder_class_logits,
encoder_extracted_states=outputs.encoder_extracted_states,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
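For reference, a hedged sketch of the surviving output container, reconstructed purely from the `OmDetTurboObjectDetectionOutput(...)` constructor call in the merged forward above; the canonical definition with field docstrings lives in modeling_omdet_turbo.py:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch

from transformers.utils import ModelOutput


@dataclass
class OmDetTurboObjectDetectionOutputSketch(ModelOutput):
    # Field names mirror the keyword arguments passed in the merged forward above.
    loss: Optional[torch.FloatTensor] = None
    decoder_coord_logits: torch.FloatTensor = None
    decoder_class_logits: torch.FloatTensor = None
    init_reference_points: torch.FloatTensor = None
    intermediate_reference_points: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    encoder_coord_logits: torch.FloatTensor = None
    encoder_class_logits: Tuple[torch.FloatTensor] = None
    encoder_extracted_states: torch.FloatTensor = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
```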
7 changes: 0 additions & 7 deletions src/transformers/utils/dummy_pt_objects.py
@@ -6559,13 +6559,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class OmDetTurboModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class OmDetTurboPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]

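The deleted block is the standard placeholder pattern for optional backends. A sketch of what the surviving dummies in dummy_pt_objects.py look like, copied from the pattern visible in this file rather than a new API:

```python
from transformers.utils import DummyObject, requires_backends


class OmDetTurboForObjectDetection(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # When torch is not installed, this raises an informative error on use
        # instead of failing at import time.
        requires_backends(self, ["torch"])
```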
