# Source code for flowvision.models.rexnet_lite

"""
Modified from https://github.com/clovaai/rexnet/blob/master/rexnetv1_lite.py
"""
import oneflow as flow
import oneflow.nn as nn

from .utils import load_state_dict_from_url
from .registry import ModelCreator
from .helpers import make_divisible


# Download locations of the ImageNet-pretrained ReXNet-lite checkpoints, keyed by
# the architecture name that ``_create_rexnet_lite`` looks up.
model_urls = {
    arch: (
        "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/RexNet/"
        + arch
        + ".zip"
    )
    for arch in (
        "rexnet_lite_1_0",
        "rexnet_lite_1_3",
        "rexnet_lite_1_5",
        "rexnet_lite_2_0",
    )
}


def _add_conv(
    out,
    in_channels,
    channels,
    kernel=1,
    stride=1,
    pad=0,
    num_group=1,
    active=True,
    relu6=True,
    bn_momentum=0.1,
    bn_eps=1e-5,
):
    out.append(
        nn.Conv2d(
            in_channels, channels, kernel, stride, pad, groups=num_group, bias=False
        )
    )
    out.append(nn.BatchNorm2d(channels, momentum=bn_momentum, eps=bn_eps))
    if active:
        out.append(nn.ReLU6(inplace=True) if relu6 else nn.ReLU(inplace=True))


class LinearBottleneck(nn.Module):
    """Inverted-residual block: optional 1x1 expansion, kxk depthwise conv, 1x1 linear projection.

    The layers are built with ``_add_conv`` and stored, in order, in the
    ``self.out`` ``nn.Sequential``:

    * if ``t != 1``: 1x1 conv ``in_channels -> in_channels * t`` + BN + ReLU6;
    * ``kernel_size`` depthwise conv (``groups`` = hidden width) with the given
      ``stride`` and ``kernel_size // 2`` padding + BN + ReLU6;
    * 1x1 conv ``hidden -> channels`` + BN, with no activation.

    When ``stride == 1`` and ``in_channels <= channels``, the input is added to
    the first ``in_channels`` output channels (a partial residual connection).

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        t (int): Expansion ratio for the hidden (depthwise) width.
        kernel_size (int): Depthwise kernel size. Default: 3
        stride (int): Depthwise stride. Default: 1
        bn_momentum (float): BatchNorm momentum. Default: 0.1
        bn_eps (float): BatchNorm epsilon. Default: 1e-5
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        in_channels,
        channels,
        t,
        kernel_size=3,
        stride=1,
        bn_momentum=0.1,
        bn_eps=1e-5,
        **kwargs
    ):
        super(LinearBottleneck, self).__init__(**kwargs)
        self.conv_shortcut = None  # never used in this module; kept as a public attribute
        self.use_shortcut = (stride == 1) and (channels >= in_channels)
        self.in_channels = in_channels
        self.out_channels = channels

        bn_kwargs = dict(bn_momentum=bn_momentum, bn_eps=bn_eps)
        expand = t != 1
        hidden = in_channels * t if expand else in_channels

        blocks = []
        if expand:
            _add_conv(blocks, in_channels=in_channels, channels=hidden, **bn_kwargs)
        # Depthwise: one group per channel, "same" padding for odd kernels.
        _add_conv(
            blocks,
            in_channels=hidden,
            channels=hidden,
            kernel=kernel_size,
            stride=stride,
            pad=kernel_size // 2,
            num_group=hidden,
            **bn_kwargs
        )
        # Linear projection: no activation after the last BatchNorm.
        _add_conv(blocks, in_channels=hidden, channels=channels, active=False, **bn_kwargs)

        # The attribute name "out" determines the state_dict keys; keep it.
        self.out = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the block; adds ``x`` in place onto the leading output channels when enabled."""
        y = self.out(x)
        if not self.use_shortcut:
            return y
        y[:, : self.in_channels] += x
        return y


class ReXNetV1_lite(nn.Module):
    """ReXNet-lite image classification network.

    Structure:

    * stem: 3x3 stride-2 convolution on a 3-channel input (+ BN + ReLU6);
    * 16 ``LinearBottleneck`` blocks in 6 stages of ``[1, 2, 2, 3, 3, 5]`` blocks
      (stage strides ``[1, 2, 2, 2, 1, 2]``, expansion 1 in the first stage and
      6 afterwards);
    * 1x1 convolution to ``pen_channels`` (+ BN + ReLU6);
    * global average pooling and a convolutional classifier
      (1x1 conv -> BN -> ReLU6 -> Dropout -> 1x1 conv), flattened to
      ``(batch, classes)``.

    Args:
        fix_head_stem (bool): If True, the stem width is computed from 32
            independently of ``multiplier`` and the penultimate width stays
            1280. Default: ``False``
        divisible_value (int): Divisor passed to ``make_divisible`` for every
            block/stem width. Default: 8
        input_ch (int): Base output width of the first block. Default: 16
        final_ch (int): Total base width added across the remaining blocks.
            Default: 164
        multiplier (float): Width multiplier. Default: 1.0
        classes (int): Number of output classes. Default: 1000
        dropout_ratio (float): Dropout probability in the classifier. Default: 0.2
        bn_momentum (float): BatchNorm momentum. Default: 0.1
        bn_eps (float): BatchNorm epsilon. Default: 1e-5
        kernel_conf (str): Exactly 6 digits, one per stage, giving the
            depthwise kernel size used by every block of that stage.
            Default: ``"333333"``

    Raises:
        ValueError: If ``kernel_conf`` does not contain exactly one entry per stage.
    """

    def __init__(
        self,
        fix_head_stem=False,
        divisible_value=8,
        input_ch=16,
        final_ch=164,
        multiplier=1.0,
        classes=1000,
        dropout_ratio=0.2,
        bn_momentum=0.1,
        bn_eps=1e-5,
        kernel_conf="333333",
    ):
        super(ReXNetV1_lite, self).__init__()

        # Per-stage block counts and first-block strides.
        layers = [1, 2, 2, 3, 3, 5]
        stage_strides = [1, 2, 2, 2, 1, 2]
        if len(kernel_conf) != len(layers):
            # Without this check, a short string silently drops trailing blocks
            # (zip truncates) and a long one fails with an opaque IndexError.
            raise ValueError(
                "kernel_conf must contain exactly {} entries (one per stage), "
                "got {!r}".format(len(layers), kernel_conf)
            )
        stage_kernels = [int(element) for element in kernel_conf]

        # Expand per-stage settings to per-block lists: only the first block of
        # a stage uses the stage stride, the remaining blocks use stride 1.
        strides = [
            stride if block == 0 else 1
            for num_blocks, stride in zip(layers, stage_strides)
            for block in range(num_blocks)
        ]
        ts = [1] * layers[0] + [6] * sum(layers[1:])
        kernel_sizes = [
            kernel
            for num_blocks, kernel in zip(layers, stage_kernels)
            for _ in range(num_blocks)
        ]
        self.num_convblocks = sum(layers)

        features = []
        # For multiplier < 1.0 the base widths are pre-divided, so after the
        # multiplication below the stem and first-block widths are not shrunk.
        inplanes = input_ch / multiplier if multiplier < 1.0 else input_ch
        first_channel = 32 / multiplier if multiplier < 1.0 or fix_head_stem else 32
        first_channel = make_divisible(
            int(round(first_channel * multiplier)), divisible_value
        )

        in_channels_group = []
        channels_group = []

        # Stem.
        _add_conv(
            features,
            3,
            first_channel,
            kernel=3,
            stride=2,
            pad=1,
            bn_momentum=bn_momentum,
            bn_eps=bn_eps,
        )

        # Channel schedule: after the first block, the base width grows by
        # final_ch / (num_convblocks - 1) per block before scaling/rounding.
        for i in range(self.num_convblocks):
            inplanes_divisible = make_divisible(
                int(round(inplanes * multiplier)), divisible_value
            )
            if i == 0:
                in_channels_group.append(first_channel)
                channels_group.append(inplanes_divisible)
            else:
                in_channels_group.append(inplanes_divisible)
                inplanes += final_ch / (self.num_convblocks - 1 * 1.0)
                inplanes_divisible = make_divisible(
                    int(round(inplanes * multiplier)), divisible_value
                )
                channels_group.append(inplanes_divisible)

        for in_c, c, t, k, s in zip(
            in_channels_group, channels_group, ts, kernel_sizes, strides
        ):
            features.append(
                LinearBottleneck(
                    in_channels=in_c,
                    channels=c,
                    t=t,
                    kernel_size=k,
                    stride=s,
                    bn_momentum=bn_momentum,
                    bn_eps=bn_eps,
                )
            )

        pen_channels = (
            int(1280 * multiplier) if multiplier > 1 and not fix_head_stem else 1280
        )
        # 1x1 conv from the last block's output width to the penultimate width.
        _add_conv(
            features,
            channels_group[-1],
            pen_channels,
            bn_momentum=bn_momentum,
            bn_eps=bn_eps,
        )

        self.features = nn.Sequential(*features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        self.output = nn.Sequential(
            nn.Conv2d(pen_channels, 1024, 1, bias=True),
            nn.BatchNorm2d(1024, momentum=bn_momentum, eps=bn_eps),
            nn.ReLU6(inplace=True),
            nn.Dropout(dropout_ratio),
            nn.Conv2d(1024, classes, 1, bias=True),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, classes)`` for a 4-D input batch."""
        x = self.features(x)
        x = self.avgpool(x)
        # The classifier is convolutional; flatten the trailing 1x1 spatial dims.
        x = self.output(x).flatten(1)
        return x


def _create_rexnet_lite(arch, pretrained=False, progress=True, **model_kwargs):
    """Build a ``ReXNetV1_lite`` and optionally load the pretrained weights for ``arch``.

    Args:
        arch (str): Key into ``model_urls`` selecting the checkpoint to download.
        pretrained (bool): If True, download the checkpoint and load it into the model.
        progress (bool): Forwarded to ``load_state_dict_from_url``.
        **model_kwargs: Constructor arguments for ``ReXNetV1_lite``.

    Returns:
        ReXNetV1_lite: The constructed (and possibly weight-initialized) model.
    """
    model = ReXNetV1_lite(**model_kwargs)
    if not pretrained:
        return model
    checkpoint = load_state_dict_from_url(model_urls[arch], progress=progress)
    model.load_state_dict(checkpoint)
    return model


@ModelCreator.register_model
def rexnet_lite_1_0(pretrained=False, progress=True, **kwargs):
    """
    Constructs the ReXNet-lite model with width multiplier of 1.0.

    .. note::
        ReXNet-lite model with width multiplier of 1.0 from the `Rethinking Channel Dimensions for Efficient Model Design <https://arxiv.org/pdf/2007.00992.pdf>`_ paper.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``ReXNetV1_lite`` (e.g. ``classes``,
            ``dropout_ratio``). ``multiplier`` is fixed to 1.0; passing it raises ``TypeError``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> rexnet_lite_1_0 = flowvision.models.rexnet_lite_1_0(pretrained=False, progress=True)

    """
    model_kwargs = dict(multiplier=1.0, **kwargs)
    return _create_rexnet_lite(
        "rexnet_lite_1_0", pretrained=pretrained, progress=progress, **model_kwargs
    )


@ModelCreator.register_model
def rexnet_lite_1_3(pretrained=False, progress=True, **kwargs):
    """
    Constructs the ReXNet-lite model with width multiplier of 1.3.

    .. note::
        ReXNet-lite model with width multiplier of 1.3 from the `Rethinking Channel Dimensions for Efficient Model Design <https://arxiv.org/pdf/2007.00992.pdf>`_ paper.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``ReXNetV1_lite`` (e.g. ``classes``,
            ``dropout_ratio``). ``multiplier`` is fixed to 1.3; passing it raises ``TypeError``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> rexnet_lite_1_3 = flowvision.models.rexnet_lite_1_3(pretrained=False, progress=True)

    """
    model_kwargs = dict(multiplier=1.3, **kwargs)
    return _create_rexnet_lite(
        "rexnet_lite_1_3", pretrained=pretrained, progress=progress, **model_kwargs
    )


@ModelCreator.register_model
def rexnet_lite_1_5(pretrained=False, progress=True, **kwargs):
    """
    Constructs the ReXNet-lite model with width multiplier of 1.5.

    .. note::
        ReXNet-lite model with width multiplier of 1.5 from the `Rethinking Channel Dimensions for Efficient Model Design <https://arxiv.org/pdf/2007.00992.pdf>`_ paper.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``ReXNetV1_lite`` (e.g. ``classes``,
            ``dropout_ratio``). ``multiplier`` is fixed to 1.5; passing it raises ``TypeError``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> rexnet_lite_1_5 = flowvision.models.rexnet_lite_1_5(pretrained=False, progress=True)

    """
    model_kwargs = dict(multiplier=1.5, **kwargs)
    return _create_rexnet_lite(
        "rexnet_lite_1_5", pretrained=pretrained, progress=progress, **model_kwargs
    )


@ModelCreator.register_model
def rexnet_lite_2_0(pretrained=False, progress=True, **kwargs):
    """
    Constructs the ReXNet-lite model with width multiplier of 2.0.

    .. note::
        ReXNet-lite model with width multiplier of 2.0 from the `Rethinking Channel Dimensions for Efficient Model Design <https://arxiv.org/pdf/2007.00992.pdf>`_ paper.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``ReXNetV1_lite`` (e.g. ``classes``,
            ``dropout_ratio``). ``multiplier`` is fixed to 2.0; passing it raises ``TypeError``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> rexnet_lite_2_0 = flowvision.models.rexnet_lite_2_0(pretrained=False, progress=True)

    """
    model_kwargs = dict(multiplier=2.0, **kwargs)
    return _create_rexnet_lite(
        "rexnet_lite_2_0", pretrained=pretrained, progress=progress, **model_kwargs
    )