"""
Modified from https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/fcn.py
"""
import oneflow.nn as nn
from .. import resnet
from .. import mobilenet_v3
from .seg_utils import _SimpleSegmentationModel, IntermediateLayerGetter
from ..utils import load_state_dict_from_url
from flowvision.models.registry import ModelCreator
# Download locations of the pretrained checkpoints, keyed by
# "<arch_type>_<backbone>_coco" (the key format built in ``_load_weights``).
model_urls = {
    "fcn_resnet50_coco": (
        "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/"
        "flowvision/segmentation/FCN/fcn_resnet50_coco.zip"
    ),
    "fcn_resnet101_coco": (
        "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/"
        "flowvision/segmentation/FCN/fcn_resnet101_coco.zip"
    ),
}
class FCN(_SimpleSegmentationModel):
    """
    Fully-Convolutional Network for semantic segmentation.

    This class defines no members of its own; everything is inherited from
    ``_SimpleSegmentationModel``.

    Args:
        backbone (nn.Module): the network that computes the features for the
            model. It should return an OrderedDict[Tensor] whose "out" key
            holds the last feature map used, plus an "aux" key when an
            auxiliary classifier is used.
        classifier (nn.Module): module that receives the "out" element
            returned by the backbone and produces a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used
            during training.
    """
class FCNHead(nn.Sequential):
    """
    Prediction head: 3x3 conv -> batch norm -> ReLU -> dropout -> 1x1 conv.

    Args:
        in_channels (int): number of channels of the incoming feature map.
        channels (int): number of channels produced by the final 1x1
            convolution.
    """

    def __init__(self, in_channels, channels):
        # The hidden width is a quarter of the input width (floor division).
        hidden_channels = in_channels // 4
        super().__init__(
            nn.Conv2d(in_channels, hidden_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(hidden_channels, channels, 1),
        )
def _fcn_segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True):
if "resnet" in backbone_name:
backbone = resnet.__dict__[backbone_name](
pretrained=pretrained_backbone,
replace_stride_with_dilation=[False, True, True],
)
out_layer = "layer4"
out_inplanes = 2048
aux_layer = "layer3"
aux_inplanes = 1024
elif "mobilenet_v3" in backbone_name:
backbone = mobilenet_v3.__dict__[backbone_name](
pretrained=pretrained_backbone, dilated=True
).features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = (
[0]
+ [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)]
+ [len(backbone) - 1]
)
out_pos = stage_indices[-1] # use C5 which has output_stride = 16
out_layer = str(out_pos)
out_inplanes = backbone[out_pos].out_channels
aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8
aux_layer = str(aux_pos)
aux_inplanes = backbone[aux_pos].out_channels
else:
raise NotImplementedError(
"backbone {} is not supported as of now".format(backbone_name)
)
return_layers = {out_layer: "out"}
if aux:
return_layers[aux_layer] = "aux"
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
aux_classifier = None
if aux:
aux_classifier = FCNHead(aux_inplanes, num_classes)
classifier = FCNHead(out_inplanes, num_classes)
base_model = FCN
fcn_model = base_model(backbone, classifier, aux_classifier)
return fcn_model
def _load_model(
    arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs
):
    """
    Build a segmentation model and, if requested, load its pretrained weights.

    When ``pretrained`` is truthy the auxiliary head is always enabled
    (``aux_loss`` is overridden with ``True``) and ``pretrained_backbone`` is
    forced to ``False``, since the state dict loaded afterwards is applied to
    the whole model.

    Args:
        arch_type (str): architecture name, e.g. "fcn".
        backbone (str): backbone name, e.g. "resnet50".
        pretrained (bool): whether to load the pretrained checkpoint.
        progress (bool): forwarded to the checkpoint download.
        num_classes (int): number of output classes.
        aux_loss: whether to attach an auxiliary head; ignored when
            ``pretrained`` is truthy.
        **kwargs: forwarded to ``_fcn_segm_model``.

    Returns:
        The constructed model.
    """
    if not pretrained:
        return _fcn_segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)

    kwargs["pretrained_backbone"] = False
    model = _fcn_segm_model(arch_type, backbone, num_classes, True, **kwargs)
    _load_weights(model, arch_type, backbone, progress)
    return model
def _load_weights(model, arch_type, backbone, progress):
    """
    Fetch the checkpoint for ``<arch_type>_<backbone>_coco`` and load it into ``model``.

    Args:
        model: module whose ``load_state_dict`` receives the downloaded state.
        arch_type (str): architecture name, e.g. "fcn".
        backbone (str): backbone name, e.g. "resnet50".
        progress (bool): forwarded to ``load_state_dict_from_url``.

    Raises:
        NotImplementedError: if ``model_urls`` has no entry for the
            architecture/backbone combination.
    """
    arch = arch_type + "_" + backbone + "_coco"
    model_url = model_urls.get(arch)
    if model_url is None:
        raise NotImplementedError(
            "pretrained {} is not supported as of now".format(arch)
        )

    weights = load_state_dict_from_url(model_url, progress=progress)
    model.load_state_dict(weights)
@ModelCreator.register_model
def fcn_resnet50_coco(
    pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs
):
    """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool): If True, it uses an auxiliary loss. Always treated as
            True when ``pretrained`` is True.
        **kwargs: extra keyword arguments forwarded to the model builder
            (e.g. ``pretrained_backbone``).

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fcn_resnet50_coco = flowvision.models.segmentation.fcn_resnet50_coco(pretrained=True, progress=True)
    """
    return _load_model(
        "fcn", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs
    )
@ModelCreator.register_model
def fcn_resnet101_coco(
    pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs
):
    """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool): If True, it uses an auxiliary loss. Always treated as
            True when ``pretrained`` is True.
        **kwargs: extra keyword arguments forwarded to the model builder
            (e.g. ``pretrained_backbone``).

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fcn_resnet101_coco = flowvision.models.segmentation.fcn_resnet101_coco(pretrained=True, progress=True)
    """
    return _load_model(
        "fcn", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs
    )