# Source code for flowvision.models.levit

"""
Modified from https://github.com/facebookresearch/LeViT
"""
import itertools

import oneflow as flow

from flowvision.layers import trunc_normal_
from .utils import load_state_dict_from_url
from .registry import ModelCreator


# Download locations of the pretrained LeViT checkpoints, keyed by model name.
# Every archive lives in the same bucket folder and is named after its model.
model_urls = {
    arch: f"https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/LeViT/{arch}.zip"
    for arch in ("levit_128s", "levit_128", "levit_192", "levit_256", "levit_384")
}


class Conv2d_BN(flow.nn.Sequential):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d``.

    The two layers are registered as children ``c`` and ``bn`` (these names
    appear in the state dict). The BatchNorm scale starts at
    ``bn_weight_init`` and its shift at zero. ``resolution`` is accepted for
    signature compatibility only and is not used here.
    """

    def __init__(
        self,
        a,
        b,
        ks=1,
        stride=1,
        pad=0,
        dilation=1,
        groups=1,
        bn_weight_init=1,
        resolution=-10000,
    ):
        super().__init__()
        conv = flow.nn.Conv2d(
            a,
            b,
            ks,
            stride=stride,
            padding=pad,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        self.add_module("c", conv)

        norm = flow.nn.BatchNorm2d(b)
        for param, value in ((norm.weight, bn_weight_init), (norm.bias, 0)):
            flow.nn.init.constant_(param, value)
        self.add_module("bn", norm)


class Linear_BN(flow.nn.Sequential):
    """Bias-free ``Linear`` followed by ``BatchNorm1d`` on the projected features.

    The children are registered as ``c`` and ``bn``. The BatchNorm scale
    starts at ``bn_weight_init`` and its shift at zero. ``resolution`` is
    accepted for signature compatibility only and is unused.
    """

    def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
        super().__init__()
        self.add_module("c", flow.nn.Linear(a, b, bias=False))
        norm = flow.nn.BatchNorm1d(b)
        for param, value in ((norm.weight, bn_weight_init), (norm.bias, 0)):
            flow.nn.init.constant_(param, value)
        self.add_module("bn", norm)

    def forward(self, x):
        # Presumably x is (B, N, C). The two leading dims are folded together so
        # that BatchNorm1d sees (B*N, C), and the projected shape is then restored.
        projected = self.c(x)
        normed = self.bn(projected.flatten(0, 1))
        return normed.reshape(*projected.shape)


class BN_Linear(flow.nn.Sequential):
    """``BatchNorm1d`` followed by ``Linear``, used as a classification head.

    The children are registered as ``bn`` and ``l``. The linear weight is
    initialised with ``trunc_normal_`` using ``std``; the bias, if requested,
    starts at zero.
    """

    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_module("bn", flow.nn.BatchNorm1d(a))

        fc = flow.nn.Linear(a, b, bias=bias)
        trunc_normal_(fc.weight, std=std)
        if bias:
            flow.nn.init.constant_(fc.bias, 0)
        self.add_module("l", fc)


def b16(n, activation, resolution=224):
    """Build the convolutional stem: four stride-2 3x3 ``Conv2d_BN`` layers.

    Channels go 3 -> n//8 -> n//4 -> n//2 -> n. An ``activation()`` instance
    sits between consecutive convolutions (none after the last one), so the
    input is downsampled 16x overall. The ``resolution`` passed to each layer
    is halved at every step.
    """
    widths = [3, n // 8, n // 4, n // 2, n]
    layers = []
    for step, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
        if step > 0:
            layers.append(activation())
        layers.append(
            Conv2d_BN(c_in, c_out, 3, 2, 1, resolution=resolution // 2 ** step)
        )
    return flow.nn.Sequential(*layers)


class Residual(flow.nn.Module):
    """Residual wrapper ``x + m(x)`` with optional per-sample branch dropping.

    In training mode with ``drop > 0``, the branch output of each sample
    (first dim of ``x``) is zeroed with probability ``drop``. Surviving samples
    are rescaled by ``1 / (1 - drop)``. The mask has shape ``(B, 1, 1)``, so
    ``x`` is presumably 3-D.
    """

    def __init__(self, m, drop):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        branch = self.m(x)
        if not (self.training and self.drop > 0):
            return x + branch
        keep = flow.rand(x.size(0), 1, 1, device=x.device) > self.drop
        scale = keep.div(1 - self.drop).detach()
        return x + branch * scale


class Attention(flow.nn.Module):
    """Multi-head self-attention with a learned relative-position bias (LeViT).

    Tokens are laid out on a ``resolution x resolution`` grid. One bias per
    head is learned for every distinct absolute (row, column) offset between
    two grid points, and that bias is added to the attention logits.

    Args:
        dim: channel count of the input and output tokens.
        key_dim: per-head width of queries and keys.
        num_heads: number of attention heads.
        attn_ratio: per-head value width is ``int(attn_ratio * key_dim)``.
        activation: module class instantiated with no arguments before the
            output projection; must not be ``None``.
        resolution: side of the token grid; the input sequence length must be
            ``resolution ** 2`` to match the bias table.
    """

    def __init__(
        self, dim, key_dim, num_heads=8, attn_ratio=4, activation=None, resolution=14
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        h = self.dh + nh_kd * 2
        self.qkv = Linear_BN(dim, h, resolution=resolution)
        self.proj = flow.nn.Sequential(
            activation(),
            Linear_BN(self.dh, dim, bn_weight_init=0, resolution=resolution),
        )

        # Map every (query point, key point) pair to the index of its absolute
        # grid offset, so that one parameter is shared by all equal offsets.
        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = flow.nn.Parameter(
            flow.zeros(num_heads, len(attention_offsets))
        )
        self.register_buffer("attention_bias_idxs", flow.LongTensor(idxs).view(N, N))
        # Eval-mode cache of the gathered (num_heads, N, N) bias, keyed by device.
        # It is a plain dict, so it never enters the state dict and is not moved
        # by .to(); keying by device avoids serving a tensor on the wrong device.
        self.attention_bias_cache = {}

    @flow.no_grad()
    def train(self, mode=True):
        """Switch train/eval mode, drop cached biases, and return ``self``."""
        super().train(mode)
        # Any mode switch invalidates the cache, so the next eval forward
        # re-gathers the biases from the current parameters.
        self.attention_bias_cache = {}
        return self

    def get_attention_biases(self, device):
        """Return the (num_heads, N, N) bias for the current mode.

        In training mode the bias is gathered on every call so gradients reach
        ``attention_biases``. In eval mode it is gathered once per device,
        without gradients, and reused.
        """
        if self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        key = str(device)
        cached = self.attention_bias_cache.get(key)
        if cached is None:
            with flow.no_grad():
                cached = self.attention_biases[:, self.attention_bias_idxs]
            self.attention_bias_cache[key] = cached
        return cached

    def forward(self, x):  # x (B,N,C)
        B, N, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.view(B, N, self.num_heads, -1).split(
            [self.key_dim, self.key_dim, self.d], dim=3
        )
        # (B, N, H, ·) -> (B, H, N, ·)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale + self.get_attention_biases(
            x.device
        )
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x


class Subsample(flow.nn.Module):
    """Keep every ``stride``-th token along both axes of a square token grid.

    The input (B, N, C) is viewed as a ``resolution x resolution`` grid, so N
    must equal ``resolution ** 2``. The strided grid is flattened back into a
    token sequence.
    """

    def __init__(self, stride, resolution):
        super().__init__()
        self.stride = stride
        self.resolution = resolution

    def forward(self, x):
        batch, _, channels = x.shape
        grid = x.view(batch, self.resolution, self.resolution, channels)
        kept = grid[:, :: self.stride, :: self.stride]
        return kept.reshape(batch, -1, channels)


class AttentionSubsample(flow.nn.Module):
    """Downsampling attention block with learned relative-position biases (LeViT).

    Keys and values come from all ``resolution ** 2`` input tokens. Queries
    come from the ``stride``-subsampled grid of ``resolution_ ** 2`` tokens.
    The output has ``resolution_ ** 2`` tokens of width ``out_dim``.

    Args:
        in_dim: channel count of the input tokens.
        out_dim: channel count of the output tokens.
        key_dim: per-head width of queries and keys.
        num_heads: number of attention heads.
        attn_ratio: per-head value width is ``int(attn_ratio * key_dim)``.
        activation: module class instantiated with no arguments before the
            output projection; must not be ``None``.
        stride: subsampling step applied to the query grid.
        resolution: side of the input token grid.
        resolution_: side of the output grid. It must equal
            ``(resolution - 1) // stride + 1`` (the size ``::stride`` slicing
            produces), otherwise the query reshape fails.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        key_dim,
        num_heads=8,
        attn_ratio=2,
        activation=None,
        stride=2,
        resolution=14,
        resolution_=7,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * self.num_heads
        self.attn_ratio = attn_ratio
        self.resolution_ = resolution_
        self.resolution_2 = resolution_ ** 2
        h = self.dh + nh_kd
        self.kv = Linear_BN(in_dim, h, resolution=resolution)

        self.q = flow.nn.Sequential(
            Subsample(stride, resolution),
            Linear_BN(in_dim, nh_kd, resolution=resolution_),
        )
        self.proj = flow.nn.Sequential(
            activation(), Linear_BN(self.dh, out_dim, resolution=resolution_)
        )

        self.stride = stride
        self.resolution = resolution
        points = list(itertools.product(range(resolution), range(resolution)))
        points_ = list(itertools.product(range(resolution_), range(resolution_)))
        N = len(points)
        N_ = len(points_)
        attention_offsets = {}
        idxs = []
        for p1 in points_:
            for p2 in points:
                # p1 lives on the output grid; scale it by the stride to compare
                # it with an input-grid position.
                offset = (
                    abs(p1[0] * stride - p2[0]),
                    abs(p1[1] * stride - p2[1]),
                )
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = flow.nn.Parameter(
            flow.zeros(num_heads, len(attention_offsets))
        )
        self.register_buffer("attention_bias_idxs", flow.LongTensor(idxs).view(N_, N))
        # Eval-mode cache of the gathered (num_heads, N_, N) bias, keyed by
        # device. It is a plain dict, so it is not part of the state dict and is
        # not moved by .to(); keying by device avoids device mismatches.
        self.attention_bias_cache = {}

    @flow.no_grad()
    def train(self, mode=True):
        """Switch train/eval mode, drop cached biases, and return ``self``."""
        super().train(mode)
        # Any mode switch invalidates the cache, so eval re-gathers fresh biases.
        self.attention_bias_cache = {}
        return self

    def get_attention_biases(self, device):
        """Return the (num_heads, N_, N) bias for the current mode.

        In training mode the bias is gathered on every call so gradients flow.
        In eval mode it is gathered once per device, without gradients.
        """
        if self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        key = str(device)
        cached = self.attention_bias_cache.get(key)
        if cached is None:
            with flow.no_grad():
                cached = self.attention_biases[:, self.attention_bias_idxs]
            self.attention_bias_cache[key] = cached
        return cached

    def forward(self, x):
        B, N, C = x.shape
        k, v = (
            self.kv(x)
            .view(B, N, self.num_heads, -1)
            .split([self.key_dim, self.d], dim=3)
        )
        k = k.permute(0, 2, 1, 3)  # BHNC
        v = v.permute(0, 2, 1, 3)  # BHNC
        q = (
            self.q(x)
            .view(B, self.resolution_2, self.num_heads, self.key_dim)
            .permute(0, 2, 1, 3)
        )

        attn = (q @ k.transpose(-2, -1)) * self.scale + self.get_attention_biases(
            x.device
        )
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
        x = self.proj(x)
        return x


class LeViT(flow.nn.Module):
    """LeViT: a convolutional stem followed by stages of attention and MLP blocks.

    ``hybrid_backbone`` produces a feature map that is flattened over its
    spatial dims (dims 2 and up) into a token sequence. The per-stage lists
    (``embed_dim``, ``key_dim``, ``depth``, ``num_heads``, ``attn_ratio``,
    ``mlp_ratio``) are consumed in lockstep with ``down_ops``.

    Args:
        img_size: input side length; ``img_size // patch_size`` is the token
            grid side of the first stage.
        patch_size: total downsampling factor of ``hybrid_backbone``.
        in_chans: unused here; the stem is supplied via ``hybrid_backbone``.
        num_classes: classifier width; ``<= 0`` replaces the heads with
            ``Identity``.
        embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio:
            per-stage settings; ``mlp_ratio <= 0`` disables the stage MLPs.
        hybrid_backbone: stem module applied to the input image.
        down_ops: transitions between stages, each
            ``["Subsample", key_dim, num_heads, attn_ratio, mlp_ratio, stride]``.
            The given sequence is copied, never modified. ``None`` means no
            transitions.
        attention_activation, mlp_activation: activation module classes.
        distillation: if true, add a second head. ``forward`` then returns
            ``(head, head_dist)`` in training mode and their mean in eval mode.
        drop_path: per-sample branch drop probability for ``Residual`` blocks.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=[192],
        key_dim=[64],
        depth=[12],
        num_heads=[3],
        attn_ratio=[2],
        mlp_ratio=[2],
        hybrid_backbone=None,
        down_ops=None,
        attention_activation=flow.nn.Hardswish,
        mlp_activation=flow.nn.Hardswish,
        distillation=True,
        drop_path=0,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.num_features = embed_dim[-1]
        self.embed_dim = embed_dim
        self.distillation = distillation

        self.patch_embed = hybrid_backbone

        # Work on a copy so the caller's list (and no shared default) is ever
        # mutated. The trailing sentinel gives the last stage a no-op transition.
        stage_down_ops = list(down_ops) if down_ops is not None else []
        stage_down_ops.append([""])

        blocks = []
        resolution = img_size // patch_size
        for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
            zip(
                embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio,
                stage_down_ops,
            )
        ):
            for _ in range(dpth):
                blocks.append(
                    Residual(
                        Attention(
                            ed,
                            kd,
                            nh,
                            attn_ratio=ar,
                            activation=attention_activation,
                            resolution=resolution,
                        ),
                        drop_path,
                    )
                )
                if mr > 0:
                    blocks.append(
                        self._mlp_residual(
                            ed, int(ed * mr), mlp_activation, drop_path, resolution
                        )
                    )
            if do[0] == "Subsample":
                # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
                # Grid side after ::stride slicing, i.e. ceil(resolution / stride).
                resolution_ = (resolution - 1) // do[5] + 1
                blocks.append(
                    AttentionSubsample(
                        *embed_dim[i : i + 2],
                        key_dim=do[1],
                        num_heads=do[2],
                        attn_ratio=do[3],
                        activation=attention_activation,
                        stride=do[5],
                        resolution=resolution,
                        resolution_=resolution_,
                    )
                )
                resolution = resolution_
                if do[4] > 0:  # mlp_ratio
                    blocks.append(
                        self._mlp_residual(
                            embed_dim[i + 1],
                            int(embed_dim[i + 1] * do[4]),
                            mlp_activation,
                            drop_path,
                            resolution,
                        )
                    )
        self.blocks = flow.nn.Sequential(*blocks)

        # Classifier head
        self.head = (
            BN_Linear(embed_dim[-1], num_classes)
            if num_classes > 0
            else flow.nn.Identity()
        )
        if distillation:
            self.head_dist = (
                BN_Linear(embed_dim[-1], num_classes)
                if num_classes > 0
                else flow.nn.Identity()
            )

    @staticmethod
    def _mlp_residual(dim, hidden, activation, drop_path, resolution):
        """Residual MLP: Linear_BN(dim->hidden), activation, zero-init Linear_BN(hidden->dim)."""
        return Residual(
            flow.nn.Sequential(
                Linear_BN(dim, hidden, resolution=resolution),
                activation(),
                Linear_BN(hidden, dim, bn_weight_init=0, resolution=resolution),
            ),
            drop_path,
        )

    def forward(self, x):
        x = self.patch_embed(x)
        # Spatial feature map -> token sequence (B, tokens, C).
        x = x.flatten(2).transpose(1, 2)
        x = self.blocks(x)
        # Global average over tokens.
        x = x.mean(1)
        if self.distillation:
            x = self.head(x), self.head_dist(x)
            if not self.training:
                x = (x[0] + x[1]) / 2
        else:
            x = self.head(x)
        return x


def model_factory(
    C, D, X, N, drop_path, num_classes, distillation, pretrained, name, progress=True
):
    """Build a three-stage LeViT from compact specs and optionally load weights.

    Args:
        C: underscore-separated embedding width of each stage, e.g. "128_256_384".
        D: key dimension shared by all attention blocks and both transitions.
        X: underscore-separated number of blocks per stage.
        N: underscore-separated number of heads per stage.
        drop_path: branch drop probability for the residual blocks.
        num_classes: classifier width.
        distillation: whether to add the distillation head.
        pretrained: if true, download and load the weights at ``model_urls[name]``.
        name: key into ``model_urls``.
        progress: show a download progress bar.
    """

    def parse(spec):
        return [int(token) for token in spec.split("_")]

    embed_dim = parse(C)
    num_heads = parse(N)
    depth = parse(X)
    hardswish = flow.nn.Hardswish

    # One downsampling transition after each of the first two stages:
    # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
    transitions = [
        ["Subsample", D, embed_dim[stage] // D, 4, 2, 2] for stage in (0, 1)
    ]
    stem = b16(embed_dim[0], activation=hardswish)

    model = LeViT(
        patch_size=16,
        embed_dim=embed_dim,
        num_heads=num_heads,
        key_dim=[D] * 3,
        depth=depth,
        attn_ratio=[2, 2, 2],
        mlp_ratio=[2, 2, 2],
        down_ops=transitions,
        attention_activation=hardswish,
        mlp_activation=hardswish,
        hybrid_backbone=stem,
        num_classes=num_classes,
        drop_path=drop_path,
        distillation=distillation,
    )
    if not pretrained:
        return model

    weights = load_state_dict_from_url(model_urls[name], progress=progress)
    model.load_state_dict(weights)
    return model


# Hyper-parameters of each LeViT variant, passed as keywords to model_factory.
# C: stage widths, N: heads per stage, X: blocks per stage (all underscore-
# separated, one value per stage); D: key dimension; drop_path: residual drop.
specification = {
    "levit_128s": dict(C="128_256_384", D=16, N="4_6_8", X="2_3_4", drop_path=0),
    "levit_128": dict(C="128_256_384", D=16, N="4_8_12", X="4_4_4", drop_path=0),
    "levit_192": dict(C="192_288_384", D=32, N="3_5_6", X="4_4_4", drop_path=0),
    "levit_256": dict(C="256_384_512", D=32, N="4_6_8", X="4_4_4", drop_path=0),
    "levit_384": dict(C="384_512_768", D=32, N="6_9_12", X="4_4_4", drop_path=0.1),
}


@ModelCreator.register_model
def levit_128s(num_classes=1000, distillation=True, pretrained=False, progress=True):
    """
    Constructs the LeViT-128S model.

    .. note::
        LeViT-128S model architecture from the `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference <https://arxiv.org/abs/2104.01136>`_ paper.

    Args:
        num_classes (int): Width of the classification head(s); ``<= 0`` replaces them with ``Identity``. Default: ``1000``
        distillation (bool): If True, adds a distillation head. The model then returns both head outputs in training mode and their average in eval mode. Default: ``True``
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> levit_128s = flowvision.models.levit_128s(pretrained=False, progress=True)

    """
    return model_factory(
        **specification["levit_128s"],
        num_classes=num_classes,
        distillation=distillation,
        pretrained=pretrained,
        name="levit_128s",
        progress=progress,
    )
@ModelCreator.register_model
def levit_128(num_classes=1000, distillation=True, pretrained=False, progress=True):
    """
    Constructs the LeViT-128 model.

    .. note::
        LeViT-128 model architecture from the `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference <https://arxiv.org/abs/2104.01136>`_ paper.

    Args:
        num_classes (int): Width of the classification head(s); ``<= 0`` replaces them with ``Identity``. Default: ``1000``
        distillation (bool): If True, adds a distillation head. The model then returns both head outputs in training mode and their average in eval mode. Default: ``True``
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> levit_128 = flowvision.models.levit_128(pretrained=False, progress=True)

    """
    return model_factory(
        **specification["levit_128"],
        num_classes=num_classes,
        distillation=distillation,
        pretrained=pretrained,
        name="levit_128",
        progress=progress,
    )
@ModelCreator.register_model
def levit_192(num_classes=1000, distillation=True, pretrained=False, progress=True):
    """
    Constructs the LeViT-192 model.

    .. note::
        LeViT-192 model architecture from the `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference <https://arxiv.org/abs/2104.01136>`_ paper.

    Args:
        num_classes (int): Width of the classification head(s); ``<= 0`` replaces them with ``Identity``. Default: ``1000``
        distillation (bool): If True, adds a distillation head. The model then returns both head outputs in training mode and their average in eval mode. Default: ``True``
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> levit_192 = flowvision.models.levit_192(pretrained=False, progress=True)

    """
    return model_factory(
        **specification["levit_192"],
        num_classes=num_classes,
        distillation=distillation,
        pretrained=pretrained,
        name="levit_192",
        progress=progress,
    )
@ModelCreator.register_model
def levit_256(num_classes=1000, distillation=True, pretrained=False, progress=True):
    """
    Constructs the LeViT-256 model.

    .. note::
        LeViT-256 model architecture from the `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference <https://arxiv.org/abs/2104.01136>`_ paper.

    Args:
        num_classes (int): Width of the classification head(s); ``<= 0`` replaces them with ``Identity``. Default: ``1000``
        distillation (bool): If True, adds a distillation head. The model then returns both head outputs in training mode and their average in eval mode. Default: ``True``
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> levit_256 = flowvision.models.levit_256(pretrained=False, progress=True)

    """
    return model_factory(
        **specification["levit_256"],
        num_classes=num_classes,
        distillation=distillation,
        pretrained=pretrained,
        name="levit_256",
        progress=progress,
    )
@ModelCreator.register_model
def levit_384(num_classes=1000, distillation=True, pretrained=False, progress=True):
    """
    Constructs the LeViT-384 model.

    .. note::
        LeViT-384 model architecture from the `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference <https://arxiv.org/abs/2104.01136>`_ paper.

    Args:
        num_classes (int): Width of the classification head(s); ``<= 0`` replaces them with ``Identity``. Default: ``1000``
        distillation (bool): If True, adds a distillation head. The model then returns both head outputs in training mode and their average in eval mode. Default: ``True``
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> levit_384 = flowvision.models.levit_384(pretrained=False, progress=True)

    """
    return model_factory(
        **specification["levit_384"],
        num_classes=num_classes,
        distillation=distillation,
        pretrained=pretrained,
        name="levit_384",
        progress=progress,
    )