import torch.nn as nn
import copy
from detrex.layers import PositionEmbeddingSine
from detrex.modeling.backbone import ResNet, BasicStem
from detrex.modeling.neck import ChannelMapper
from detectron2.layers import ShapeSpec
from detectron2.config import LazyCall as L
from detrex.modeling.matcher import HungarianMatcher

from projects.unified_layout_analysis_v2.modeling import (
    UniDETRMultiScales,
    DabDeformableDetrTransformer,
    DabDeformableDetrTransformerEncoder,
    DabDeformableDetrTransformerDecoder,
    TwoStageCriterion,
    DeepStem,
)

from projects.unified_layout_analysis_v2.modeling.uni_relation_prediction_head import (
    UniRelationPredictionHead,
    HRIPNHead,
)

from projects.unified_layout_analysis_v2.modeling.doc_transformer import (
    DocTransformerEncoder,
    DocTransformer,
)

from projects.unified_layout_analysis_v2.modeling.backbone.bert import (
    Bert,
    TextTokenizer,
)

# Define the main model
model = L(UniDETRMultiScales)(
    backbone=L(ResNet)(
        stem=L(DeepStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
        stages=L(ResNet.make_default_stages)(
            depth=18,
            norm="FrozenBN",
        ),
        out_features=["res2", "res3", "res4", "res5"],
        freeze_at=1,
    ),
    position_embedding=L(PositionEmbeddingSine)(
        num_pos_feats=128,
        temperature=10000,
        normalize=True,
        offset=-0.5,
    ),
    neck=L(ChannelMapper)(
        input_shapes={
            "res3": ShapeSpec(channels=128),
            "res4": ShapeSpec(channels=256),
            "res5": ShapeSpec(channels=512),
        },
        in_features=["res3", "res4", "res5"],
        out_channels=256,
        num_outs=4,
        kernel_size=1,
        norm_layer=L(nn.GroupNorm)(num_groups=32, num_channels=256),
    ),
    transformer=L(DabDeformableDetrTransformer)(
        encoder=L(DabDeformableDetrTransformerEncoder)(
            embed_dim=256,
            num_heads=8,
            feedforward_dim=2048,
            attn_dropout=0.0,
            ffn_dropout=0.0,
            num_layers=3,
            post_norm=False,
            num_feature_levels=4,
        ),
        decoder=L(DabDeformableDetrTransformerDecoder)(
            embed_dim=256,
            num_heads=8,
            feedforward_dim=2048,
            attn_dropout=0.0,
            ffn_dropout=0.0,
            num_layers=3,
            return_intermediate=True,
            num_feature_levels=4,
        ),
        as_two_stage=True,
        num_feature_levels=4,
        decoder_in_feature_level=[0, 1, 2, 3],
    ),
    embed_dim=256,
    num_classes=14,
    num_graphical_classes=2,
    num_types=3,
    # Relation types: 0 = a->a, 1 = intra, 2 = inter
    relation_prediction_head=L(UniRelationPredictionHead)(
        relation_num_classes=2,
        embed_dim=256,
        hidden_dim=1024,
    ),
    aux_loss=True,
    criterion=L(TwoStageCriterion)(
        num_classes=2,
        matcher=L(HungarianMatcher)(
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
            cost_class_type="focal_loss_cost",
            alpha=0.25,
            gamma=2.0,
        ),
        weight_dict={
            "loss_class": 1.0,
            "loss_bbox": 5.0,
            "loss_giou": 2.0,
        },
        loss_class_type="focal_loss",
        alpha=0.25,
        gamma=2.0,
        two_stage_binary_cls=False,
    ),
    as_two_stage=True,
    pixel_mean=[123.675, 116.280, 103.530],
    pixel_std=[58.395, 57.120, 57.375],
    device="cuda",
    windows_size=[6, 8],
    freeze_language_model=False,
)

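# Logical-role relation prediction head, attached after the main constructor;
# attribute assignments on a LazyCall node become extra keyword arguments for
# UniDETRMultiScales when the config is instantiated.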
model.logical_role_relation_prediction_head = L(UniRelationPredictionHead)(
    relation_num_classes=1,
    embed_dim=256,
    hidden_dim=1024,
)

# Update auxiliary loss weight dictionary
base_weight_dict = copy.deepcopy(model.criterion.weight_dict)
if model.aux_loss:
    weight_dict = model.criterion.weight_dict
    aux_weight_dict = {
        f"{k}_{i}": v
        for i in range(model.transformer.decoder.num_layers - 1)
        for k, v in base_weight_dict.items()
    }
    weight_dict.update(aux_weight_dict)
    model.criterion.weight_dict = weight_dict
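# With num_layers = 3, the loop above adds weights for the two intermediate
# decoder layers: "loss_class_0", "loss_bbox_0", "loss_giou_0",
# "loss_class_1", "loss_bbox_1", "loss_giou_1".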

# Loss weights for the encoder (first-stage proposal) outputs used when as_two_stage=True
model.criterion.weight_dict.update({
    "loss_class_enc": 1.0,
    "loss_bbox_enc": 5.0,
    "loss_giou_enc": 2.0,
})

# Add document transformer module
model.doc_transformer = L(DocTransformer)(
    encoder=L(DocTransformerEncoder)(
        embed_dim=256,
        num_heads=8,
        feedforward_dim=2048,
        attn_dropout=0.0,
        ffn_dropout=0.0,
        num_layers=3,
        post_norm=False,
        batch_first=True,
    ),
    decoder=None,
)

# Add document-level relation prediction head
model.doc_relation_prediction_head = L(HRIPNHead)(
    relation_num_classes=2,
    embed_dim=256,
    hidden_dim=1024,
)

# Add language model
model.language_model = L(Bert)(
    bert_model_type="bert-base-uncased",
    text_max_len=512,
    input_overlap_stride=0,
    output_embedding_dim=1024,
    max_batch_size=1,
    used_layers=12,
    used_hidden_idxs=[12],
    hidden_embedding_dim=768,
)
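# Note: bert-base-uncased has 12 transformer layers with hidden size 768
# (matching used_layers=12 and hidden_embedding_dim=768); used_hidden_idxs=[12]
# presumably selects the final layer's hidden states, assuming index 0 refers
# to the embedding output as in Hugging Face's hidden_states tuple. How the
# Bert wrapper interprets these indices is an assumption here.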

# Add tokenizer
model.tokenizer = L(TextTokenizer)(
    model_type="bert-base-uncased",
    text_max_len=512,
    input_overlap_stride=0,
)
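
# Usage sketch (illustrative only, not part of the config): this file is a
# detectron2 LazyConfig, so the model graph above can be materialized with
# `instantiate`. The config path below is hypothetical.
#
#   from detectron2.config import LazyConfig, instantiate
#
#   cfg = LazyConfig.load("projects/unified_layout_analysis_v2/configs/<this_config>.py")
#   detector = instantiate(cfg.model)  # builds UniDETRMultiScales and all sub-modules
#   detector.to(cfg.model.device)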