Source code for flowvision.models.segmentation.lraspp

"""
Modified from https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/lraspp.py
"""
from collections import OrderedDict

from oneflow import nn, Tensor
from oneflow.nn import functional as F
from typing import Dict

from .. import mobilenet_v3
from .seg_utils import _SimpleSegmentationModel, IntermediateLayerGetter
from flowvision.models.utils import load_state_dict_from_url
from flowvision.models.registry import ModelCreator


# Download locations of the pretrained checkpoints, keyed by architecture name.
# ``_load_weights`` builds the lookup key as ``<arch_type>_<backbone>_coco``.
model_urls = {
    "lraspp_mobilenet_v3_large_coco": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/segmentation/lraspp/lraspp_mobilenet_v3_large_coco.zip",
}


class LRASPP(nn.Module):
    """
    Lite R-ASPP semantic segmentation network from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    The model chains a feature-extracting backbone with an ``LRASPPHead``
    and resizes the head's prediction back to the spatial size of the input.

    Args:
        backbone (nn.Module): feature extractor. Its output is passed as-is to
            the head, which looks up the entries ``"low"`` (low level feature
            map) and ``"high"`` (high level feature map).
        low_channels (int): channel count of the low level feature map.
        high_channels (int): channel count of the high level feature map.
        num_classes (int): number of predicted classes (background included).
        inter_channels (int, optional): width of the head's intermediate
            feature maps. Defaults to 128.
    """

    def __init__(
        self, backbone, low_channels, high_channels, num_classes, inter_channels=128
    ):
        super().__init__()
        self.backbone = backbone
        self.classifier = LRASPPHead(
            low_channels=low_channels,
            high_channels=high_channels,
            num_classes=num_classes,
            inter_channels=inter_channels,
        )

    def forward(self, input):
        """Run the backbone and head, then upsample the scores to the input size.

        Returns:
            OrderedDict: a single entry ``"out"`` holding the class scores,
            bilinearly resized to the last two dimensions of ``input``.
        """
        scores = self.classifier(self.backbone(input))
        # The head predicts at the low level feature resolution; bring the
        # result back to the resolution of the incoming tensor.
        upsampled = F.interpolate(
            scores, size=input.shape[-2:], mode="bilinear", align_corners=False
        )
        return OrderedDict([("out", upsampled)])


class LRASPPHead(nn.Module):
    """
    Prediction head of the Lite R-ASPP network.

    The high level features go through two parallel branches: ``cbr``
    (1x1 convolution, batch norm, ReLU) and ``scale`` (global average pooling,
    1x1 convolution, sigmoid). The ``scale`` output gates the ``cbr`` output
    channel-wise. The gated map is resized to the low level feature resolution
    and the per-class scores of both levels are summed.

    Args:
        low_channels (int): channel count of the low level feature map.
        high_channels (int): channel count of the high level feature map.
        num_classes (int): number of predicted classes.
        inter_channels (int): channel count produced by both high level branches.
    """

    def __init__(self, low_channels, high_channels, num_classes, inter_channels):
        super().__init__()
        # Attribute names and creation order are kept stable so that existing
        # checkpoints (state_dict keys) continue to load.
        self.cbr = nn.Sequential(
            nn.Conv2d(high_channels, inter_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
        )
        self.scale = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(high_channels, inter_channels, kernel_size=1, bias=False),
            nn.Sigmoid(),
        )
        self.low_classifier = nn.Conv2d(low_channels, num_classes, kernel_size=1)
        self.high_classifier = nn.Conv2d(inter_channels, num_classes, kernel_size=1)

    def forward(self, input: Dict[str, Tensor]) -> Tensor:
        """Fuse the ``"low"`` and ``"high"`` feature maps into class scores.

        Args:
            input: mapping that provides the ``"low"`` and ``"high"`` feature maps.

        Returns:
            Class scores at the spatial resolution of the ``"low"`` feature map.
        """
        low_feat, high_feat = input["low"], input["high"]

        # Channel-wise gating: the pooled branch is 1x1 spatially and
        # broadcasts over the spatial dimensions of the cbr branch.
        gated = self.cbr(high_feat) * self.scale(high_feat)
        gated = F.interpolate(
            gated, size=low_feat.shape[-2:], mode="bilinear", align_corners=False
        )

        low_scores = self.low_classifier(low_feat)
        return low_scores + self.high_classifier(gated)


def _load_weights(model, arch_type, backbone, progress):
    """Fetch the pretrained checkpoint for an architecture and load it into ``model``.

    Args:
        model: module whose ``load_state_dict`` receives the downloaded weights.
        arch_type (str): architecture prefix of the checkpoint name.
        backbone (str): backbone name used in the checkpoint name.
        progress (bool): forwarded to ``load_state_dict_from_url``.

    Raises:
        NotImplementedError: if ``model_urls`` has no entry for
            ``<arch_type>_<backbone>_coco``.
    """
    arch = "_".join((arch_type, backbone, "coco"))
    model_url = model_urls.get(arch)
    if model_url is None:
        raise NotImplementedError(
            "pretrained {} is not supported as of now".format(arch)
        )

    weights = load_state_dict_from_url(model_url, progress=progress)
    model.load_state_dict(weights)


def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True):
    """Build an ``LRASPP`` model on top of a dilated MobileNetV3 feature extractor.

    Args:
        backbone_name (str): name of the constructor to look up in ``mobilenet_v3``.
        num_classes (int): number of output classes of the segmentation model.
        pretrained_backbone (bool): forwarded as ``pretrained`` to the backbone
            constructor.

    Returns:
        LRASPP: the assembled segmentation model.
    """
    build_backbone = mobilenet_v3.__dict__[backbone_name]
    features = build_backbone(pretrained=pretrained_backbone, dilated=True).features

    # Blocks flagged with ``_is_cn`` are the strided ones, i.e. the positions of
    # C1, ..., Cn-1. The first and the last block are always added because they
    # are C0 (conv1) and Cn.
    strided = [
        idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)
    ]
    stage_indices = [0, *strided, len(features) - 1]

    low_pos = stage_indices[-4]  # C2, which has output_stride = 8
    high_pos = stage_indices[-1]  # C5, which has output_stride = 16
    low_channels = features[low_pos].out_channels
    high_channels = features[high_pos].out_channels

    # Expose only the two feature maps the head consumes, under the keys it expects.
    return_layers = {str(low_pos): "low", str(high_pos): "high"}
    feature_getter = IntermediateLayerGetter(features, return_layers=return_layers)

    return LRASPP(feature_getter, low_channels, high_channels, num_classes)


@ModelCreator.register_model
def lraspp_mobilenet_v3_large_coco(
    pretrained=False, progress=True, num_classes=21, **kwargs
):
    """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        **kwargs: extra options forwarded to the model builder (e.g.
            ``pretrained_backbone``). ``aux_loss`` is accepted only so that a
            truthy value can be rejected explicitly.

    Raises:
        NotImplementedError: if a truthy ``aux_loss`` is requested.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> lraspp_mobilenet_v3_large_coco = flowvision.models.segmentation.lraspp_mobilenet_v3_large_coco(pretrained=True, progress=True)

    """
    if kwargs.pop("aux_loss", False):
        raise NotImplementedError("This model does not use auxiliary loss")

    if pretrained:
        # The COCO checkpoint loaded below replaces every backbone weight, so
        # fetching the backbone's own pretrained weights first would only be
        # a redundant download.
        kwargs["pretrained_backbone"] = False

    backbone_name = "mobilenet_v3_large"
    model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs)
    if pretrained:
        _load_weights(model, "lraspp", backbone_name, progress)
    return model