Commit a2713fa

No public description
PiperOrigin-RevId: 720640505
tensorflower-gardener committed Jan 28, 2025
1 parent 27d603e commit a2713fa
Showing 1 changed file with 121 additions and 77 deletions.
official/projects/detr/tasks/detection.py: 198 changes (121 additions, 77 deletions)

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """DETR detection task definition."""
+
 from typing import Optional
 
 from absl import logging
@@ -47,21 +48,25 @@ class DetectionTask(base_task.Task):
   def build_model(self):
     """Build DETR model."""
 
-    input_specs = tf_keras.layers.InputSpec(shape=[None] +
-                                            self._task_config.model.input_size)
+    input_specs = tf_keras.layers.InputSpec(
+        shape=[None] + self._task_config.model.input_size
+    )
 
     backbone = backbones.factory.build_backbone(
         input_specs=input_specs,
         backbone_config=self._task_config.model.backbone,
-        norm_activation_config=self._task_config.model.norm_activation)
-
-    model = detr.DETR(backbone,
-                      self._task_config.model.backbone_endpoint_name,
-                      self._task_config.model.num_queries,
-                      self._task_config.model.hidden_size,
-                      self._task_config.model.num_classes,
-                      self._task_config.model.num_encoder_layers,
-                      self._task_config.model.num_decoder_layers)
+        norm_activation_config=self._task_config.model.norm_activation,
+    )
+
+    model = detr.DETR(
+        backbone,
+        self._task_config.model.backbone_endpoint_name,
+        self._task_config.model.num_queries,
+        self._task_config.model.hidden_size,
+        self._task_config.model.num_classes,
+        self._task_config.model.num_encoder_layers,
+        self._task_config.model.num_decoder_layers,
+    )
     return model
 
   def initialize(self, model: tf_keras.Model):
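
A note on the InputSpec line this hunk rewraps: [None] + self._task_config.model.input_size simply prepends an unspecified batch dimension to the configured image shape. A two-line illustration with an assumed input size (the real value comes from the task config):

# Assumed config value, for illustration only.
input_size = [640, 640, 3]
print([None] + input_size)  # [None, 640, 640, 3]: batch dimension left open
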
@@ -84,12 +89,13 @@ def initialize(self, model: tf_keras.Model):
     status = ckpt.restore(ckpt_dir_or_file)
     status.expect_partial().assert_existing_objects_matched()
 
-    logging.info('Finished loading pretrained checkpoint from %s',
-                 ckpt_dir_or_file)
+    logging.info(
+        'Finished loading pretrained checkpoint from %s', ckpt_dir_or_file
+    )
 
-  def build_inputs(self,
-                   params,
-                   input_context: Optional[tf.distribute.InputContext] = None):
+  def build_inputs(
+      self, params, input_context: Optional[tf.distribute.InputContext] = None
+  ):
     """Build input dataset."""
     if isinstance(params, coco.COCODataConfig):
       dataset = coco.COCODataLoader(params).load(input_context)
@@ -100,14 +106,17 @@ def build_inputs(self,
         decoder_cfg = params.decoder.get()
         if params.decoder.type == 'simple_decoder':
           decoder = tf_example_decoder.TfExampleDecoder(
-              regenerate_source_id=decoder_cfg.regenerate_source_id)
+              regenerate_source_id=decoder_cfg.regenerate_source_id
+          )
         elif params.decoder.type == 'label_map_decoder':
           decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
               label_map=decoder_cfg.label_map,
-              regenerate_source_id=decoder_cfg.regenerate_source_id)
+              regenerate_source_id=decoder_cfg.regenerate_source_id,
+          )
         else:
-          raise ValueError('Unknown decoder type: {}!'.format(
-              params.decoder.type))
+          raise ValueError(
+              'Unknown decoder type: {}!'.format(params.decoder.type)
+          )
 
       parser = detr_input.Parser(
           class_offset=self._task_config.losses.class_offset,
@@ -118,7 +127,8 @@ def build_inputs(self,
           params,
           dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
           decoder_fn=decoder.decode,
-          parser_fn=parser.parse_fn(params.is_training))
+          parser_fn=parser.parse_fn(params.is_training),
+      )
       dataset = reader.read(input_context=input_context)
 
     return dataset
@@ -128,35 +138,44 @@ def _compute_cost(self, cls_outputs, box_outputs, cls_targets, box_targets):
     # The 1 is a constant that doesn't change the matching, it can be omitted.
     # background: 0
     cls_cost = self._task_config.losses.lambda_cls * tf.gather(
-        -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)
+        -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1
+    )
 
     # Compute the L1 cost between boxes,
     paired_differences = self._task_config.losses.lambda_box * tf.abs(
-        tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1))
+        tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1)
+    )
     box_cost = tf.reduce_sum(paired_differences, axis=-1)
 
     # Compute the giou cost between boxes
-    giou_cost = self._task_config.losses.lambda_giou * -box_ops.bbox_generalized_overlap(
-        box_ops.cycxhw_to_yxyx(box_outputs),
-        box_ops.cycxhw_to_yxyx(box_targets))
+    giou_cost = (
+        self._task_config.losses.lambda_giou
+        * -box_ops.bbox_generalized_overlap(
+            box_ops.cycxhw_to_yxyx(box_outputs),
+            box_ops.cycxhw_to_yxyx(box_targets),
+        )
+    )
 
     total_cost = cls_cost + box_cost + giou_cost
 
     max_cost = (
-        self._task_config.losses.lambda_cls * 0.0 +
-        self._task_config.losses.lambda_box * 4. +
-        self._task_config.losses.lambda_giou * 0.0)
+        self._task_config.losses.lambda_cls * 0.0
+        + self._task_config.losses.lambda_box * 4.0
+        + self._task_config.losses.lambda_giou * 0.0
+    )
 
     # Set pads to large constant
     valid = tf.expand_dims(
-        tf.cast(tf.not_equal(cls_targets, 0), dtype=total_cost.dtype), axis=1)
+        tf.cast(tf.not_equal(cls_targets, 0), dtype=total_cost.dtype), axis=1
+    )
    total_cost = (1 - valid) * max_cost + valid * total_cost
 
     # Set inf or nan to large constant
     total_cost = tf.where(
         tf.logical_or(tf.math.is_nan(total_cost), tf.math.is_inf(total_cost)),
         max_cost * tf.ones_like(total_cost, dtype=total_cost.dtype),
-        total_cost)
+        total_cost,
+    )
 
     return total_cost
 
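The hunk above only rewraps the matching-cost computation; the math is unchanged. As a rough illustration of what _compute_cost produces, the sketch below builds a toy [num_queries, num_targets] cost matrix from a class term and an L1 box term and solves the assignment with scipy.optimize.linear_sum_assignment. The project's matchers.hungarian_matching plays the analogous role on TF tensors; all shapes, weights, and values here are invented for the example.

# Illustrative only: a toy version of the DETR matching cost on NumPy arrays.
import numpy as np
from scipy.optimize import linear_sum_assignment

lambda_cls, lambda_box = 1.0, 5.0  # assumed weights

# 4 predicted queries, 2 ground-truth boxes; class 0 is background.
cls_probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.9, 0.05, 0.05],
])  # already-softmaxed class probabilities per query
box_preds = np.random.rand(4, 4)   # (cy, cx, h, w) in [0, 1]
cls_targets = np.array([1, 2])     # non-background target classes
box_targets = np.random.rand(2, 4)

# Class cost: negative predicted probability of each target's class.
cls_cost = lambda_cls * -cls_probs[:, cls_targets]           # shape [4, 2]
# Box cost: L1 distance between every (prediction, target) pair.
box_cost = lambda_box * np.abs(
    box_preds[:, None, :] - box_targets[None, :, :]).sum(-1)  # shape [4, 2]

total_cost = cls_cost + box_cost
rows, cols = linear_sum_assignment(total_cost)  # Hungarian assignment
print(list(zip(rows, cols)))
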
@@ -168,7 +187,8 @@ def build_losses(self, outputs, labels, aux_losses=None):
     box_targets = labels['boxes']
 
     cost = self._compute_cost(
-        cls_outputs, box_outputs, cls_targets, box_targets)
+        cls_outputs, box_outputs, cls_targets, box_targets
+    )
 
     _, indices = matchers.hungarian_matching(cost)
     indices = tf.stop_gradient(indices)
@@ -179,45 +199,53 @@
 
     background = tf.equal(cls_targets, 0)
     num_boxes = tf.reduce_sum(
-        tf.cast(tf.logical_not(background), tf.float32), axis=-1)
+        tf.cast(tf.logical_not(background), tf.float32), axis=-1
+    )
 
     # Down-weight background to account for class imbalance.
     xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
-        labels=cls_targets, logits=cls_assigned)
+        labels=cls_targets, logits=cls_assigned
+    )
     cls_loss = self._task_config.losses.lambda_cls * tf.where(
-        background, self._task_config.losses.background_cls_weight * xentropy,
-        xentropy)
+        background,
+        self._task_config.losses.background_cls_weight * xentropy,
+        xentropy,
+    )
     cls_weights = tf.where(
         background,
         self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
-        tf.ones_like(cls_loss))
+        tf.ones_like(cls_loss),
+    )
 
     # Box loss is only calculated on non-background class.
     l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
     box_loss = self._task_config.losses.lambda_box * tf.where(
-        background, tf.zeros_like(l_1), l_1)
+        background, tf.zeros_like(l_1), l_1
+    )
 
     # Giou loss is only calculated on non-background class.
-    giou = tf.linalg.diag_part(1.0 - box_ops.bbox_generalized_overlap(
-        box_ops.cycxhw_to_yxyx(box_assigned),
-        box_ops.cycxhw_to_yxyx(box_targets)
-    ))
+    giou = tf.linalg.diag_part(
+        1.0
+        - box_ops.bbox_generalized_overlap(
+            box_ops.cycxhw_to_yxyx(box_assigned),
+            box_ops.cycxhw_to_yxyx(box_targets),
+        )
+    )
     giou_loss = self._task_config.losses.lambda_giou * tf.where(
-        background, tf.zeros_like(giou), giou)
+        background, tf.zeros_like(giou), giou
+    )
 
     # Consider doing all reduce once in train_step to speed up.
     num_boxes_per_replica = tf.reduce_sum(num_boxes)
     cls_weights_per_replica = tf.reduce_sum(cls_weights)
     replica_context = tf.distribute.get_replica_context()
     num_boxes_sum, cls_weights_sum = replica_context.all_reduce(
         tf.distribute.ReduceOp.SUM,
-        [num_boxes_per_replica, cls_weights_per_replica])
-    cls_loss = tf.math.divide_no_nan(
-        tf.reduce_sum(cls_loss), cls_weights_sum)
-    box_loss = tf.math.divide_no_nan(
-        tf.reduce_sum(box_loss), num_boxes_sum)
-    giou_loss = tf.math.divide_no_nan(
-        tf.reduce_sum(giou_loss), num_boxes_sum)
+        [num_boxes_per_replica, cls_weights_per_replica],
+    )
+    cls_loss = tf.math.divide_no_nan(tf.reduce_sum(cls_loss), cls_weights_sum)
+    box_loss = tf.math.divide_no_nan(tf.reduce_sum(box_loss), num_boxes_sum)
+    giou_loss = tf.math.divide_no_nan(tf.reduce_sum(giou_loss), num_boxes_sum)
 
     aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0
 
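Again, this hunk is purely a rewrap, but the weighting it preserves is worth spelling out: background queries contribute to the classification term with weight background_cls_weight, the box and GIoU terms are zeroed for background, the classification term is normalized by the summed class weights, and the box terms by the number of real boxes (both summed across replicas). A minimal single-replica sketch of that bookkeeping, with toy values and an assumed background weight of 0.1:

# Toy, single-replica illustration of the loss weighting in build_losses.
import numpy as np

background_cls_weight = 0.1                 # assumed value
cls_targets = np.array([3, 0, 7, 0])        # 0 means "no object"
xentropy = np.array([0.5, 2.0, 0.3, 1.5])   # per-query cross-entropy
l1 = np.array([0.4, 9.9, 0.2, 9.9])         # per-query box L1 (junk for background)

background = cls_targets == 0
cls_loss = np.where(background, background_cls_weight * xentropy, xentropy)
cls_weights = np.where(background, background_cls_weight, 1.0)
box_loss = np.where(background, 0.0, l1)

num_boxes = np.sum(~background)             # 2 real boxes
print(cls_loss.sum() / cls_weights.sum())   # class term, weight-normalized
print(box_loss.sum() / num_boxes)           # box term, normalized by real boxes
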
@@ -236,7 +264,8 @@ def build_metrics(self, training=True):
           annotation_file=self._task_config.annotation_file,
           include_mask=False,
           need_rescale_bboxes=True,
-          per_category_metrics=self._task_config.per_category_metrics)
+          per_category_metrics=self._task_config.per_category_metrics,
+      )
     return metrics
 
   def train_step(self, inputs, model, optimizer, metrics=None):
@@ -262,8 +291,11 @@ def train_step(self, inputs, model, optimizer, metrics=None):
 
     for output in outputs:
       # Computes per-replica loss.
-      layer_loss, layer_cls_loss, layer_box_loss, layer_giou_loss = self.build_losses(
-          outputs=output, labels=labels, aux_losses=model.losses)
+      layer_loss, layer_cls_loss, layer_box_loss, layer_giou_loss = (
+          self.build_losses(
+              outputs=output, labels=labels, aux_losses=model.losses
+          )
+      )
       loss += layer_loss
       cls_loss += layer_cls_loss
       box_loss += layer_box_loss
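
The loop this hunk rewraps is DETR's auxiliary-loss scheme: outputs holds one set of predictions per decoder layer, and the same loss is computed for each and summed, so every layer is supervised rather than only the last. A stripped-down sketch of that accumulation, with a stand-in loss function in place of self.build_losses:

# Stand-in sketch of the per-decoder-layer loss accumulation; fake_layer_loss
# is hypothetical and only illustrates the control flow.
def fake_layer_loss(output):
  return sum(output) / len(output)

decoder_outputs = [[0.9, 0.7], [0.6, 0.5], [0.4, 0.3]]  # one entry per decoder layer
total_loss = 0.0
for output in decoder_outputs:
  total_loss += fake_layer_loss(output)  # every layer contributes
print(total_loss)
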
@@ -323,7 +355,8 @@ def validation_step(self, inputs, model, metrics=None):
 
     outputs = model(features, training=False)[-1]
     loss, cls_loss, box_loss, giou_loss = self.build_losses(
-        outputs=outputs, labels=labels, aux_losses=model.losses)
+        outputs=outputs, labels=labels, aux_losses=model.losses
+    )
 
     # Multiply for logging.
     # Since we expect the gradient replica sum to happen in the optimizer,
@@ -341,35 +374,46 @@ def validation_step(self, inputs, model, metrics=None):
     # This is for backward compatibility.
     if 'detection_boxes' not in outputs:
       detection_boxes = box_ops.cycxhw_to_yxyx(
-          outputs['box_outputs']) * tf.expand_dims(
-              tf.concat([
-                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
-                                                                         1],
-                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
-                                                                         1]
+          outputs['box_outputs']
+      ) * tf.expand_dims(
+          tf.concat(
+              [
+                  labels['image_info'][:, 1:2, 0],
+                  labels['image_info'][:, 1:2, 1],
+                  labels['image_info'][:, 1:2, 0],
+                  labels['image_info'][:, 1:2, 1],
               ],
-                        axis=1),
-              axis=1)
+              axis=1,
+          ),
+          axis=1,
+      )
     else:
       detection_boxes = outputs['detection_boxes']
 
-    detection_scores = tf.math.reduce_max(
-        tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
-    ) if 'detection_scores' not in outputs else outputs['detection_scores']
+    if 'detection_scores' not in outputs:
+      detection_scores = tf.math.reduce_max(
+          tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
+      )
+    else:
+      detection_scores = outputs['detection_scores']
 
     if 'detection_classes' not in outputs:
-      detection_classes = tf.math.argmax(
-          outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
+      detection_classes = (
+          tf.math.argmax(outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
+      )
     else:
       detection_classes = outputs['detection_classes']
 
     if 'num_detections' not in outputs:
       num_detections = tf.reduce_sum(
           tf.cast(
               tf.math.greater(
-                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0),
-              tf.int32),
-          axis=-1)
+                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0
+              ),
+              tf.int32,
+          ),
+          axis=-1,
+      )
     else:
       num_detections = outputs['num_detections']
 
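The backward-compatibility decoding that this hunk reformats is compact but easy to misread: scores come from the softmax over classes with the background column dropped, the class id is the argmax shifted by one to skip background, and num_detections counts queries whose largest logit is positive. A toy NumPy version of the same arithmetic, with made-up logits:

# Toy NumPy version of the score/class decoding kept by the hunk above.
import numpy as np

def softmax(x, axis=-1):
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

cls_outputs = np.array([[[2.0, 0.1, 0.3],     # query 0: mostly background
                         [0.1, 0.2, 2.5]]])   # query 1: mostly class 2
probs = softmax(cls_outputs)[:, :, 1:]                 # drop the background column
detection_scores = probs.max(axis=-1)                   # best foreground probability
detection_classes = probs.argmax(axis=-1) + 1           # shift past background
num_detections = (cls_outputs.max(axis=-1) > 0).sum(axis=-1)
print(detection_scores, detection_classes, num_detections)
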
@@ -379,21 +423,21 @@ def validation_step(self, inputs, model, metrics=None):
         'detection_classes': detection_classes,
         'num_detections': num_detections,
         'source_id': labels['id'],
-        'image_info': labels['image_info']
+        'image_info': labels['image_info'],
     }
 
     ground_truths = {
         'source_id': labels['id'],
         'height': labels['image_info'][:, 0:1, 0],
         'width': labels['image_info'][:, 0:1, 1],
         'num_detections': tf.reduce_sum(
-            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1),
+            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1
+        ),
         'boxes': labels['gt_boxes'],
         'classes': labels['classes'],
-        'is_crowds': labels['is_crowd']
+        'is_crowds': labels['is_crowd'],
     }
-    logs.update({'predictions': predictions,
-                 'ground_truths': ground_truths})
+    logs.update({'predictions': predictions, 'ground_truths': ground_truths})
 
     all_losses = {
         'cls_loss': cls_loss,
@@ -413,8 +457,8 @@ def aggregate_logs(self, state=None, step_outputs=None):
       state = self.coco_metric
 
     state.update_state(
-        step_outputs['ground_truths'],
-        step_outputs['predictions'])
+        step_outputs['ground_truths'], step_outputs['predictions']
+    )
     return state
 
   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
