diff --git a/official/projects/detr/tasks/detection.py b/official/projects/detr/tasks/detection.py
index 870c10820f..abe78bf3fc 100644
--- a/official/projects/detr/tasks/detection.py
+++ b/official/projects/detr/tasks/detection.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """DETR detection task definition."""
+
 from typing import Optional
 
 from absl import logging
@@ -47,21 +48,25 @@ class DetectionTask(base_task.Task):
 
   def build_model(self):
     """Build DETR model."""
-    input_specs = tf_keras.layers.InputSpec(shape=[None] +
-                                            self._task_config.model.input_size)
+    input_specs = tf_keras.layers.InputSpec(
+        shape=[None] + self._task_config.model.input_size
+    )
 
     backbone = backbones.factory.build_backbone(
         input_specs=input_specs,
         backbone_config=self._task_config.model.backbone,
-        norm_activation_config=self._task_config.model.norm_activation)
-
-    model = detr.DETR(backbone,
-                      self._task_config.model.backbone_endpoint_name,
-                      self._task_config.model.num_queries,
-                      self._task_config.model.hidden_size,
-                      self._task_config.model.num_classes,
-                      self._task_config.model.num_encoder_layers,
-                      self._task_config.model.num_decoder_layers)
+        norm_activation_config=self._task_config.model.norm_activation,
+    )
+
+    model = detr.DETR(
+        backbone,
+        self._task_config.model.backbone_endpoint_name,
+        self._task_config.model.num_queries,
+        self._task_config.model.hidden_size,
+        self._task_config.model.num_classes,
+        self._task_config.model.num_encoder_layers,
+        self._task_config.model.num_decoder_layers,
+    )
     return model
 
   def initialize(self, model: tf_keras.Model):
@@ -84,12 +89,13 @@ def initialize(self, model: tf_keras.Model):
       status = ckpt.restore(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
 
-    logging.info('Finished loading pretrained checkpoint from %s',
-                 ckpt_dir_or_file)
+    logging.info(
+        'Finished loading pretrained checkpoint from %s', ckpt_dir_or_file
+    )
 
-  def build_inputs(self,
-                   params,
-                   input_context: Optional[tf.distribute.InputContext] = None):
+  def build_inputs(
+      self, params, input_context: Optional[tf.distribute.InputContext] = None
+  ):
     """Build input dataset."""
     if isinstance(params, coco.COCODataConfig):
       dataset = coco.COCODataLoader(params).load(input_context)
@@ -100,14 +106,17 @@ def build_inputs(self,
       decoder_cfg = params.decoder.get()
       if params.decoder.type == 'simple_decoder':
         decoder = tf_example_decoder.TfExampleDecoder(
-            regenerate_source_id=decoder_cfg.regenerate_source_id)
+            regenerate_source_id=decoder_cfg.regenerate_source_id
+        )
       elif params.decoder.type == 'label_map_decoder':
         decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
             label_map=decoder_cfg.label_map,
-            regenerate_source_id=decoder_cfg.regenerate_source_id)
+            regenerate_source_id=decoder_cfg.regenerate_source_id,
+        )
       else:
-        raise ValueError('Unknown decoder type: {}!'.format(
-            params.decoder.type))
+        raise ValueError(
+            'Unknown decoder type: {}!'.format(params.decoder.type)
+        )
 
       parser = detr_input.Parser(
           class_offset=self._task_config.losses.class_offset,
@@ -118,7 +127,8 @@ def build_inputs(self,
           params,
          dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
           decoder_fn=decoder.decode,
-          parser_fn=parser.parse_fn(params.is_training))
+          parser_fn=parser.parse_fn(params.is_training),
+      )
       dataset = reader.read(input_context=input_context)
 
     return dataset
@@ -128,35 +138,44 @@ def _compute_cost(self, cls_outputs, box_outputs, cls_targets, box_targets):
     # The 1 is a constant that doesn't change the matching, it can be ommitted.
     # background: 0
     cls_cost = self._task_config.losses.lambda_cls * tf.gather(
-        -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)
+        -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1
+    )
 
     # Compute the L1 cost between boxes,
     paired_differences = self._task_config.losses.lambda_box * tf.abs(
-        tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1))
+        tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1)
+    )
     box_cost = tf.reduce_sum(paired_differences, axis=-1)
 
     # Compute the giou cost betwen boxes
-    giou_cost = self._task_config.losses.lambda_giou * -box_ops.bbox_generalized_overlap(
-        box_ops.cycxhw_to_yxyx(box_outputs),
-        box_ops.cycxhw_to_yxyx(box_targets))
+    giou_cost = (
+        self._task_config.losses.lambda_giou
+        * -box_ops.bbox_generalized_overlap(
+            box_ops.cycxhw_to_yxyx(box_outputs),
+            box_ops.cycxhw_to_yxyx(box_targets),
+        )
+    )
 
     total_cost = cls_cost + box_cost + giou_cost
 
     max_cost = (
-        self._task_config.losses.lambda_cls * 0.0 +
-        self._task_config.losses.lambda_box * 4. +
-        self._task_config.losses.lambda_giou * 0.0)
+        self._task_config.losses.lambda_cls * 0.0
+        + self._task_config.losses.lambda_box * 4.0
+        + self._task_config.losses.lambda_giou * 0.0
+    )
 
     # Set pads to large constant
     valid = tf.expand_dims(
-        tf.cast(tf.not_equal(cls_targets, 0), dtype=total_cost.dtype), axis=1)
+        tf.cast(tf.not_equal(cls_targets, 0), dtype=total_cost.dtype), axis=1
+    )
     total_cost = (1 - valid) * max_cost + valid * total_cost
 
     # Set inf of nan to large constant
     total_cost = tf.where(
         tf.logical_or(tf.math.is_nan(total_cost), tf.math.is_inf(total_cost)),
         max_cost * tf.ones_like(total_cost, dtype=total_cost.dtype),
-        total_cost)
+        total_cost,
+    )
 
     return total_cost
 
@@ -168,7 +187,8 @@ def build_losses(self, outputs, labels, aux_losses=None):
     box_targets = labels['boxes']
 
     cost = self._compute_cost(
-        cls_outputs, box_outputs, cls_targets, box_targets)
+        cls_outputs, box_outputs, cls_targets, box_targets
+    )
 
     _, indices = matchers.hungarian_matching(cost)
     indices = tf.stop_gradient(indices)
@@ -179,31 +199,41 @@ def build_losses(self, outputs, labels, aux_losses=None):
 
     background = tf.equal(cls_targets, 0)
     num_boxes = tf.reduce_sum(
-        tf.cast(tf.logical_not(background), tf.float32), axis=-1)
+        tf.cast(tf.logical_not(background), tf.float32), axis=-1
+    )
 
     # Down-weight background to account for class imbalance.
     xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
-        labels=cls_targets, logits=cls_assigned)
+        labels=cls_targets, logits=cls_assigned
+    )
     cls_loss = self._task_config.losses.lambda_cls * tf.where(
-        background, self._task_config.losses.background_cls_weight * xentropy,
-        xentropy)
+        background,
+        self._task_config.losses.background_cls_weight * xentropy,
+        xentropy,
+    )
     cls_weights = tf.where(
         background,
         self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
-        tf.ones_like(cls_loss))
+        tf.ones_like(cls_loss),
+    )
 
     # Box loss is only calculated on non-background class.
     l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
     box_loss = self._task_config.losses.lambda_box * tf.where(
-        background, tf.zeros_like(l_1), l_1)
+        background, tf.zeros_like(l_1), l_1
+    )
 
     # Giou loss is only calculated on non-background class.
-    giou = tf.linalg.diag_part(1.0 - box_ops.bbox_generalized_overlap(
-        box_ops.cycxhw_to_yxyx(box_assigned),
-        box_ops.cycxhw_to_yxyx(box_targets)
-    ))
+    giou = tf.linalg.diag_part(
+        1.0
+        - box_ops.bbox_generalized_overlap(
+            box_ops.cycxhw_to_yxyx(box_assigned),
+            box_ops.cycxhw_to_yxyx(box_targets),
+        )
+    )
     giou_loss = self._task_config.losses.lambda_giou * tf.where(
-        background, tf.zeros_like(giou), giou)
+        background, tf.zeros_like(giou), giou
+    )
 
     # Consider doing all reduce once in train_step to speed up.
     num_boxes_per_replica = tf.reduce_sum(num_boxes)
@@ -211,13 +241,11 @@ def build_losses(self, outputs, labels, aux_losses=None):
     replica_context = tf.distribute.get_replica_context()
     num_boxes_sum, cls_weights_sum = replica_context.all_reduce(
         tf.distribute.ReduceOp.SUM,
-        [num_boxes_per_replica, cls_weights_per_replica])
-    cls_loss = tf.math.divide_no_nan(
-        tf.reduce_sum(cls_loss), cls_weights_sum)
-    box_loss = tf.math.divide_no_nan(
-        tf.reduce_sum(box_loss), num_boxes_sum)
-    giou_loss = tf.math.divide_no_nan(
-        tf.reduce_sum(giou_loss), num_boxes_sum)
+        [num_boxes_per_replica, cls_weights_per_replica],
+    )
+    cls_loss = tf.math.divide_no_nan(tf.reduce_sum(cls_loss), cls_weights_sum)
+    box_loss = tf.math.divide_no_nan(tf.reduce_sum(box_loss), num_boxes_sum)
+    giou_loss = tf.math.divide_no_nan(tf.reduce_sum(giou_loss), num_boxes_sum)
 
     aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0
 
@@ -236,7 +264,8 @@ def build_metrics(self, training=True):
           annotation_file=self._task_config.annotation_file,
           include_mask=False,
           need_rescale_bboxes=True,
-          per_category_metrics=self._task_config.per_category_metrics)
+          per_category_metrics=self._task_config.per_category_metrics,
+      )
     return metrics
 
   def train_step(self, inputs, model, optimizer, metrics=None):
@@ -262,8 +291,11 @@ def train_step(self, inputs, model, optimizer, metrics=None):
 
       for output in outputs:
         # Computes per-replica loss.
-        layer_loss, layer_cls_loss, layer_box_loss, layer_giou_loss = self.build_losses(
-            outputs=output, labels=labels, aux_losses=model.losses)
+        layer_loss, layer_cls_loss, layer_box_loss, layer_giou_loss = (
+            self.build_losses(
+                outputs=output, labels=labels, aux_losses=model.losses
+            )
+        )
         loss += layer_loss
         cls_loss += layer_cls_loss
         box_loss += layer_box_loss
@@ -323,7 +355,8 @@ def validation_step(self, inputs, model, metrics=None):
 
     outputs = model(features, training=False)[-1]
     loss, cls_loss, box_loss, giou_loss = self.build_losses(
-        outputs=outputs, labels=labels, aux_losses=model.losses)
+        outputs=outputs, labels=labels, aux_losses=model.losses
+    )
 
     # Multiply for logging.
     # Since we expect the gradient replica sum to happen in the optimizer,
@@ -341,25 +374,33 @@ def validation_step(self, inputs, model, metrics=None):
     # This is for backward compatibility.
     if 'detection_boxes' not in outputs:
       detection_boxes = box_ops.cycxhw_to_yxyx(
-          outputs['box_outputs']) * tf.expand_dims(
-              tf.concat([
-                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
-                                                                        1],
-                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
-                                                                        1]
+          outputs['box_outputs']
+      ) * tf.expand_dims(
+          tf.concat(
+              [
+                  labels['image_info'][:, 1:2, 0],
+                  labels['image_info'][:, 1:2, 1],
+                  labels['image_info'][:, 1:2, 0],
+                  labels['image_info'][:, 1:2, 1],
               ],
-                        axis=1),
-              axis=1)
+              axis=1,
+          ),
+          axis=1,
+      )
     else:
       detection_boxes = outputs['detection_boxes']
 
-    detection_scores = tf.math.reduce_max(
-        tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
-    ) if 'detection_scores' not in outputs else outputs['detection_scores']
+    if 'detection_scores' not in outputs:
+      detection_scores = tf.math.reduce_max(
+          tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
+      )
+    else:
+      detection_scores = outputs['detection_scores']
 
     if 'detection_classes' not in outputs:
-      detection_classes = tf.math.argmax(
-          outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
+      detection_classes = (
+          tf.math.argmax(outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
+      )
     else:
       detection_classes = outputs['detection_classes']
 
@@ -367,9 +408,12 @@ def validation_step(self, inputs, model, metrics=None):
       num_detections = tf.reduce_sum(
          tf.cast(
              tf.math.greater(
-                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0),
-              tf.int32),
-          axis=-1)
+                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0
+              ),
+              tf.int32,
+          ),
+          axis=-1,
+      )
     else:
       num_detections = outputs['num_detections']
 
@@ -379,7 +423,7 @@ def validation_step(self, inputs, model, metrics=None):
         'detection_classes': detection_classes,
         'num_detections': num_detections,
         'source_id': labels['id'],
-        'image_info': labels['image_info']
+        'image_info': labels['image_info'],
     }
 
     ground_truths = {
@@ -387,13 +431,13 @@ def validation_step(self, inputs, model, metrics=None):
         'height': labels['image_info'][:, 0:1, 0],
         'width': labels['image_info'][:, 0:1, 1],
         'num_detections': tf.reduce_sum(
-            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1),
+            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1
+        ),
         'boxes': labels['gt_boxes'],
         'classes': labels['classes'],
-        'is_crowds': labels['is_crowd']
+        'is_crowds': labels['is_crowd'],
     }
-    logs.update({'predictions': predictions,
-                 'ground_truths': ground_truths})
+    logs.update({'predictions': predictions, 'ground_truths': ground_truths})
 
     all_losses = {
         'cls_loss': cls_loss,
@@ -413,8 +457,8 @@ def aggregate_logs(self, state=None, step_outputs=None):
       state = self.coco_metric
 
     state.update_state(
-        step_outputs['ground_truths'],
-        step_outputs['predictions'])
+        step_outputs['ground_truths'], step_outputs['predictions']
+    )
     return state
 
   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
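
Note for reviewers: the `_compute_cost` hunks above are a pure reformat; the matching-cost logic is unchanged. For readers skimming the diff, a rough, self-contained sketch of that pairwise-cost pattern is shown below. The standalone function name and lambda values are hypothetical, and the GIoU term plus the is_nan/is_inf guard are omitted; it only illustrates how a [batch, num_queries, num_targets] cost tensor is built before `matchers.hungarian_matching` is applied.

import tensorflow as tf


def pairwise_match_cost(cls_outputs, box_outputs, cls_targets, box_targets,
                        lambda_cls=1.0, lambda_box=5.0):
  """Sketch: cost[b, i, j] of assigning prediction i to (padded) target j."""
  # Classification term: negative predicted probability of each target class.
  cls_cost = lambda_cls * tf.gather(
      -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)
  # L1 term between every (prediction, target) box pair.
  box_cost = lambda_box * tf.reduce_sum(
      tf.abs(tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1)),
      axis=-1)
  total_cost = cls_cost + box_cost
  # Padded targets (class 0) get a constant cost so they never win a match,
  # mirroring the max_cost handling in _compute_cost above.
  valid = tf.expand_dims(
      tf.cast(tf.not_equal(cls_targets, 0), total_cost.dtype), axis=1)
  return valid * total_cost + (1.0 - valid) * (lambda_box * 4.0)

The real task additionally adds the lambda_giou-weighted generalized-IoU term and replaces inf/nan entries with the same constant before running the Hungarian matcher.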