Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add SpyNet model #33

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions configs/_base_/models/spynet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
model = dict(
type='SpyNet',
img_channels=3,
pyramid_levels=[
'level0', 'level1', 'level2', 'level3', 'level4', 'level5'
],
decoder=dict(
type='SpyNetDecoder',
in_channels=8,
pyramid_levels=[
'level0', 'level1', 'level2', 'level3', 'level4', 'level5'
],
out_channels=(32, 64, 32, 16, 2),
kernel_size=7,
stride=1,
warp_cfg=dict(type='Warp', align_corners=True),
act_cfg=dict(type='ReLU'),
))
9 changes: 6 additions & 3 deletions mmflow/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
build_flow_estimator)
from .decoders import (FlowNetCDecoder, FlowNetSDecoder, FlowRefine,
IRRPWCDecoder, MaskFlowNetDecoder, MaskFlowNetSDecoder,
NetE, OccRefine, OccShuffleUpsample, PWCNetDecoder)
NetE, OccRefine, OccShuffleUpsample, PWCNetDecoder,
SpyNetDecoder)
from .encoders import (CorrEncoder, FlowNetEncoder, FlowNetSDEncoder, NetC,
PWCNetEncoder, RAFTEncoder)
from .flow_estimators import (IRRPWC, FlowNet2, FlowNetC, FlowNetCSS, FlowNetS,
LiteFlowNet, MaskFlowNet, MaskFlowNetS, PWCNet)
LiteFlowNet, MaskFlowNet, MaskFlowNetS, PWCNet,
SpyNet)
from .losses import (MultiLevelBCE, MultiLevelCharbonnierLoss, MultiLevelEPE,
SequenceLoss)

Expand All @@ -21,5 +23,6 @@
'build_flow_estimator', 'COMPONENTS', 'build_components', 'MultiLevelBCE',
'MultiLevelEPE', 'MultiLevelCharbonnierLoss', 'SequenceLoss', 'IRRPWC',
'IRRPWCDecoder', 'FlowRefine', 'OccRefine', 'OccShuffleUpsample',
'FlowNet2', 'FlowNetCSS', 'MaskFlowNetDecoder', 'MaskFlowNet'
'FlowNet2', 'FlowNetCSS', 'MaskFlowNetDecoder', 'MaskFlowNet',
'SpyNetDecoder', 'SpyNet'
]
4 changes: 3 additions & 1 deletion mmflow/models/decoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from .maskflownet_decoder import MaskFlowNetDecoder, MaskFlowNetSDecoder
from .pwcnet_decoder import PWCNetDecoder
from .raft_decoder import RAFTDecoder
from .spynet_decoder import SpyNetDecoder

__all__ = [
'FlowNetCDecoder', 'FlowNetSDecoder', 'PWCNetDecoder',
'MaskFlowNetSDecoder', 'NetE', 'ContextNet', 'RAFTDecoder', 'FlowRefine',
'OccRefine', 'OccShuffleUpsample', 'IRRPWCDecoder', 'MaskFlowNetDecoder'
'OccRefine', 'OccShuffleUpsample', 'IRRPWCDecoder', 'MaskFlowNetDecoder',
'SpyNetDecoder'
]
155 changes: 155 additions & 0 deletions mmflow/models/decoders/spynet_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from mmflow.ops.builder import build_operators
from ..builder import DECODERS
from .base_decoder import BaseDecoder


class BasicLayers(BaseModule):

def __init__(self,
in_channels: int,
out_channels=(32, 64, 32, 16, 2),
kernel_size=7,
stride=1,
act_cfg=dict(type='ReLU', inplace=False),
init_cfg=None):
super().__init__(init_cfg=init_cfg)

convs = []
in_ch = in_channels
for out_ch in out_channels[:-1]:
convs.append(
ConvModule(
in_channels=in_ch,
out_channels=out_ch,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2,
act_cfg=act_cfg))
in_ch = out_ch
convs.append(
nn.Conv2d(
in_channels=in_ch,
out_channels=out_channels[-1],
kernel_size=kernel_size,
padding=kernel_size // 2))
self.layers = nn.Sequential(*convs)

def forward(self, x):
return self.layers(x)


@DECODERS.register_module()
class SpyNetDecoder(BaseDecoder):

def __init__(self,
in_channels,
pyramid_levels,
out_channels=(32, 64, 32, 16, 2),
kernel_size=7,
stride=1,
warp_cfg: dict = dict(type='Warp', align_corners=True),
act_cfg=dict(type='ReLU'),
init_cfg: Optional[Union[dict, list]] = None) -> None:
super().__init__(init_cfg=init_cfg)

self.in_channels = in_channels
self.pyramid_levels = pyramid_levels
self.pyramid_levels.sort()

self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.act_cfg = act_cfg

self.warp = build_operators(warp_cfg)

layers = []

for level in self.pyramid_levels:

layers.append([level, self.make_layers()])

self.decoders = nn.ModuleDict(layers)

def make_layers(self):
return BasicLayers(
in_channels=self.in_channels, out_channels=self.out_channels)

def forward(self, imgs1, imgs2):
flow = None

residual_flow_preds = dict()
previous_flow_preds = dict()
for level in self.pyramid_levels[::-1]:

img1 = imgs1[level]
img2 = imgs2[level]
_, _, H, W = img1.shape

if flow is None:
flow = torch.zeros(1, 2, H, W).to(img1)
else:
flow = F.interpolate(
flow, scale_factor=2, mode='bilinear',
align_corners=False) * 2.0

warped_img2 = self.warp(img2, flow)
previous_flow_preds[level] = flow

in_feat = torch.cat((img1, warped_img2, flow), dim=1)

residual_flow = self.decoders[level](in_feat)
flow += residual_flow

residual_flow_preds[level] = residual_flow

return flow, residual_flow_preds, previous_flow_preds

def losses(
self,
residual_flow_preds: Dict[str, torch.Tensor],
previous_flow_preds: Dict[str, torch.Tensor],
flow_gt: torch.Tensor,
valid: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
"""Compute optical flow loss.

Args:
flow_pred (Dict[str, Tensor]): multi-level predicted optical flow.
flow_gt (Tensor): The ground truth of optical flow.
valid (Tensor, optional): The valid mask. Defaults to None.

Returns:
Dict[str, Tensor]: The dict of losses.
"""
loss = dict()
loss['loss_flow'] = self.flow_loss(residual_flow_preds,
previous_flow_preds, flow_gt, valid)
return loss

def forward_train(self, imgs1, imgs2, flow_gt, valid=None):
_, residual_flow_preds, previous_flow_preds = self.forward(
imgs1=imgs1, imgs2=imgs2)

return self.losses(
residual_flow_preds=residual_flow_preds,
previous_flow_preds=previous_flow_preds,
flow_gt=flow_gt,
valid=valid)

def forward_test(self, imgs1, imgs2, img_metas=None):
flow, _, _ = self.forward(imgs1=imgs1, imgs2=imgs2)
flow_result = flow.permute(0, 2, 3, 1).cpu().data.numpy()

# unravel batch dim,
flow_result = list(flow_result)
flow_result = [dict(flow=f) for f in flow_result]

return self.get_flow(flow_result, img_metas=img_metas)
3 changes: 2 additions & 1 deletion mmflow/models/flow_estimators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from .maskflownet import MaskFlowNet, MaskFlowNetS
from .pwcnet import PWCNet
from .raft import RAFT
from .spynet import SpyNet

__all__ = [
'FlowNetC', 'FlowNetS', 'LiteFlowNet', 'PWCNet', 'MaskFlowNetS', 'RAFT',
'IRRPWC', 'FlowNet2', 'FlowNetCSS', 'MaskFlowNet'
'IRRPWC', 'FlowNet2', 'FlowNetCSS', 'MaskFlowNet', 'SpyNet'
]
58 changes: 58 additions & 0 deletions mmflow/models/flow_estimators/spynet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F

from ..builder import FLOW_ESTIMATORS, build_decoder
from .base import FlowEstimator


@FLOW_ESTIMATORS.register_module()
class SpyNet(FlowEstimator):

def __init__(self,
pyramid_levels,
decoder,
img_channels=3,
**kwargs) -> None:
super().__init__(**kwargs)
self.pyramid_levels = pyramid_levels
self.pyramid_levels.sort()
self.img_channels = img_channels
self.decoder = build_decoder(decoder)

def downsample_images(self, imgs):
imgs1 = dict()
imgs2 = dict()
img1 = imgs[:, :self.img_channels, ...]
img2 = imgs[:, self.img_channels:, ...]

imgs1[self.pyramid_levels[0]] = img1
imgs2[self.pyramid_levels[0]] = img2

for level in self.pyramid_levels[1:]:

img1 = F.avg_pool2d(
img1,
kernel_size=2,
stride=2,
)
img2 = F.avg_pool2d(
img2,
kernel_size=2,
stride=2,
)
imgs1[level] = img1
imgs2[level] = img2

return imgs1, imgs2

def forward_train(self, imgs, flow_gt, valid=None, img_meta=None):
imgs1, imgs2 = self.downsample_images(imgs)

return self.decoder.forward_train(
imgs1=imgs1, imgs2=imgs2, flow_gt=flow_gt, valid=valid)

def forward_test(self, imgs, img_metas=None):
imgs1, imgs2 = self.downsample_images(imgs)

return self.decoder.forward_test(
imgs1=imgs1, imgs2=imgs2, img_metas=img_metas)