diff --git a/detection/backbonev2.py b/detection/backbonev2.py index 647c73e..7a35cb4 100644 --- a/detection/backbonev2.py +++ b/detection/backbonev2.py @@ -668,7 +668,7 @@ def forward(self, x): if not self.training: cls_out = (cls_out[0] + cls_out[1]) / 2 else: - cls_out = self.head(x.mean(-2)) + cls_out = self.head(x.flatten(2).mean(-1)) # for image classification return cls_out