# Source code for flowvision.models.cswin

"""
Modified from https://github.com/microsoft/CSWin-Transformer/blob/main/models/cswin.py
"""

import numpy as np

import oneflow as flow
import oneflow.nn as nn

from flowvision.layers import DropPath, trunc_normal_
from .utils import load_state_dict_from_url
from .registry import ModelCreator


# Pretrained ImageNet checkpoints, keyed by the architecture name that the
# factory functions pass to ``_create_cswin_transformer`` as ``arch``.
model_urls = dict(
    cswin_tiny_224="https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/CSWin_Transformer/cswin_tiny_224.zip",
    cswin_small_224="https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/CSWin_Transformer/cswin_small_224.zip",
    cswin_base_224="https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/CSWin_Transformer/cswin_base_224.zip",
    cswin_large_224="https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/CSWin_Transformer/cswin_large_224.zip",
    cswin_base_384="https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/CSWin_Transformer/cswin_base_384.zip",
    cswin_large_384="https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/CSWin_Transformer/cswin_large_384.zip",
)


class Mlp(nn.Module):
    """Two-layer feed-forward network.

    Computes ``Dropout(fc2(Dropout(act(fc1(x)))))``.

    Args:
        in_features (int): Input width.
        hidden_features (int, optional): Hidden width; ``in_features`` when falsy.
        out_features (int, optional): Output width; ``in_features`` when falsy.
        act_layer: Activation class, instantiated with no arguments.
        drop (float): Dropout probability used at both dropout points.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        out = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        # One Dropout module, reused after the activation and after fc2.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x


def img2windows(img, H_sp, W_sp):
    """Cut a ``(B, C, H, W)`` feature map into non-overlapping windows.

    Args:
        img: Tensor of shape ``(B, C, H, W)``; ``H`` and ``W`` are expected to be
            divisible by ``H_sp`` and ``W_sp`` (otherwise ``view`` fails).
        H_sp: Window height.
        W_sp: Window width.

    Returns:
        Tensor of shape ``(B * (H // H_sp) * (W // W_sp), H_sp * W_sp, C)``.
        Windows are ordered row-major per image; tokens are ordered row-major
        inside each window.
    """
    batch, channels, height, width = img.shape
    grid = img.view(batch, channels, height // H_sp, H_sp, width // W_sp, W_sp)
    # -> (B, H/H_sp, W/W_sp, H_sp, W_sp, C)
    windows = grid.permute(0, 2, 4, 3, 5, 1).contiguous()
    return windows.reshape(-1, H_sp * W_sp, channels)


def windows2img(img_splits_hw, H_sp, W_sp, H, W):
    """Stitch windows back into a ``(B, H, W, C)`` feature map.

    Args:
        img_splits_hw: Windows whose leading dim is ``B * (H / H_sp) * (W / W_sp)``
            and whose remaining elements form ``H_sp * W_sp * C`` values per window
            (e.g. ``(B', H_sp * W_sp, C)``).
        H_sp: Window height.
        W_sp: Window width.
        H: Output height.
        W: Output width.

    Returns:
        Tensor of shape ``(B, H, W, C)``.
    """
    windows_per_image = H * W / H_sp / W_sp
    batch = int(img_splits_hw.shape[0] / windows_per_image)
    grid = img_splits_hw.view(batch, H // H_sp, W // W_sp, H_sp, W_sp, -1)
    # Interleave window rows with in-window rows (and likewise for columns).
    merged = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return merged.view(batch, H, W, -1)


class LePEAttention(nn.Module):
    """Window self-attention with a locally-enhanced positional encoding (LePE).

    The ``resolution x resolution`` token grid is cut into non-overlapping
    ``H_sp x W_sp`` windows. Multi-head attention runs inside each window, and
    a depthwise 3x3 convolution of V (``self.get_v``) is added to the attention
    output as a positional term.

    Args:
        dim (int): Channels of each of q, k and v.
        resolution (int): Side of the square token grid (H == W == resolution).
        idx (int): Window shape. ``-1``: a single ``resolution x resolution``
            window; ``0``: full-height stripes ``split_size`` wide; ``1``:
            full-width stripes ``split_size`` tall.
        split_size (int): Stripe width used when ``idx`` is 0 or 1.
        dim_out (int, optional): Stored as ``self.dim_out`` (defaults to ``dim``);
            not otherwise used in this class.
        num_heads (int): Number of heads; each head gets ``dim // num_heads`` channels.
        attn_drop (float): Dropout probability on the attention weights.
        proj_drop (float): Accepted for interface compatibility; unused here.
        qk_scale (float, optional): Attention scale; ``head_dim ** -0.5`` when falsy.

    Raises:
        ValueError: If ``idx`` is not -1, 0 or 1.
    """

    def __init__(
        self,
        dim,
        resolution,
        idx,
        split_size=7,
        dim_out=None,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        qk_scale=None,
    ):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out or dim
        self.resolution = resolution
        self.split_size = split_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # using idx to control vertical or horizontal attention
        if idx == -1:
            H_sp, W_sp = resolution, resolution
        elif idx == 0:
            H_sp, W_sp = resolution, split_size
        elif idx == 1:
            W_sp, H_sp = resolution, split_size
        else:
            # Fail loudly: terminating the interpreter (with a success status)
            # on a bad argument hides the error from callers.
            raise ValueError(f"LePEAttention: idx must be -1, 0 or 1, got {idx!r}")

        self.H_sp = H_sp
        self.W_sp = W_sp
        # Depthwise (groups=dim) 3x3 conv that produces the LePE term from V.
        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

        self.attn_drop = nn.Dropout(attn_drop)

    def im2cswin(self, x):
        """Split ``(B, N, C)`` tokens into per-window, per-head sequences.

        The token grid is taken to be square (``H = W = int(sqrt(N))``).

        Returns:
            Tensor of shape ``(B * num_windows, num_heads, H_sp * W_sp, C // num_heads)``.
        """
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
        x = img2windows(x, self.H_sp, self.W_sp)
        x = (
            x.reshape(-1, self.H_sp * self.W_sp, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
            .contiguous()
        )
        return x

    def get_lepe(self, x, func):
        """Window ``x`` (the V tokens) and compute the positional term with ``func``.

        Args:
            x: ``(B, N, C)`` tokens of a square grid.
            func: Module applied to each ``(C, H_sp, W_sp)`` window (here ``self.get_v``).

        Returns:
            ``(v, lepe)``, both of shape
            ``(B * num_windows, num_heads, H_sp * W_sp, C // num_heads)``.
        """
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)

        H_sp, W_sp = self.H_sp, self.W_sp
        x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
        # -> (B * num_windows, C, H_sp, W_sp): one small image per window
        x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp)

        # local position embedding
        lepe = func(x)
        lepe = (
            lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp)
            .permute(0, 1, 3, 2)
            .contiguous()
        )

        x = (
            x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp)
            .permute(0, 1, 3, 2)
            .contiguous()
        )
        return x, lepe

    def forward(self, qkv):
        """Attend within windows.

        Args:
            qkv: q, k and v stacked along dim 0 (``qkv[0]``, ``qkv[1]``,
                ``qkv[2]``), each of shape ``(B, L, C)`` with ``L == resolution ** 2``.

        Returns:
            Tensor of shape ``(B, L, C)``.
        """
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Image to Window
        H = W = self.resolution
        B, L, C = q.shape
        assert L == H * W, "flatten img_tokens has wrong size"

        q = self.im2cswin(q)
        k = self.im2cswin(k)
        # get position embedding based on v
        v, lepe = self.get_lepe(v, self.get_v)
        q = q * self.scale
        # (B', heads, N, d) x (B', heads, d, N) -> (B', heads, N, N), N = H_sp * W_sp
        attn = flow.matmul(q, k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B', heads, N, N) x (B', heads, N, d) -> (B', heads, N, d), plus LePE
        x = flow.matmul(attn, v) + lepe
        # -> (B', N, heads * d)
        x = x.transpose(1, 2).reshape(-1, self.H_sp * self.W_sp, C)

        # Window to Image, then back to (B, L, C) tokens
        x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C)

        return x


class CSWinBlock(nn.Module):
    """One CSWin Transformer block.

    Structure: ``x + drop_path(proj(attn(norm1(x))))``, then
    ``x + drop_path(mlp(norm2(x)))``.

    With two branches, channels and heads are split in half. The first half
    attends inside full-height stripes (``idx=0``) and the second half inside
    full-width stripes (``idx=1``). In the last stage, or whenever ``reso``
    equals ``split_size``, a single branch attends over the whole
    ``reso x reso`` window (``idx=-1``).

    Args:
        dim (int): Token channels.
        reso (int): Side of the square token grid; inputs must have ``reso ** 2`` tokens.
        num_heads (int): Total heads (halved per branch when two branches are used).
        split_size (int): Stripe width.
        mlp_ratio (float): MLP hidden width relative to ``dim``.
        qkv_bias (bool): Bias in the qkv projection.
        qk_scale (float, optional): Forwarded to ``LePEAttention``.
        drop (float): Dropout in the MLP. Also builds ``self.proj_drop``, which
            ``forward`` does not apply.
        attn_drop (float): Dropout on attention weights.
        drop_path (float): Stochastic-depth rate; ``0`` uses ``nn.Identity``.
        act_layer: MLP activation class.
        norm_layer: Normalization class for ``norm1`` / ``norm2``.
        last_stage (bool): Force the single full-window branch.
    """

    def __init__(
        self,
        dim,
        reso,
        num_heads,
        split_size=7,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        last_stage=False,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.patches_resolution = reso
        self.split_size = split_size
        self.mlp_ratio = mlp_ratio
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm1 = norm_layer(dim)

        # A stripe as wide as the grid leaves nothing to split.
        last_stage = last_stage or self.patches_resolution == split_size
        self.branch_num = 1 if last_stage else 2
        self.proj = nn.Linear(dim, dim)
        # NOTE(review): built but not applied in forward (same as upstream CSWin).
        self.proj_drop = nn.Dropout(drop)

        if last_stage:
            branch_dim, branch_heads, modes = dim, num_heads, (-1,)
        else:
            branch_dim, branch_heads, modes = dim // 2, num_heads // 2, (0, 1)
        self.attns = nn.ModuleList(
            [
                LePEAttention(
                    branch_dim,
                    resolution=self.patches_resolution,
                    idx=mode,
                    split_size=split_size,
                    num_heads=branch_heads,
                    dim_out=branch_dim,
                    qk_scale=qk_scale,
                    attn_drop=attn_drop,
                    proj_drop=drop,
                )
                for mode in modes
            ]
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            out_features=dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.norm2 = norm_layer(dim)

    def forward(self, x):
        """Apply the block to ``(B, H*W, C)`` tokens (H == W == patches_resolution)."""
        H = W = self.patches_resolution
        B, L, C = x.shape
        assert L == H * W, "flatten img_tokens has wrong size"
        # (B, L, 3C) -> (3, B, L, C): q, k, v stacked along dim 0.
        qkv = self.qkv(self.norm1(x)).reshape(B, -1, 3, C).permute(2, 0, 1, 3)

        if self.branch_num == 2:
            half = C // 2
            attended = flow.cat(
                [
                    self.attns[0](qkv[:, :, :, :half]),
                    self.attns[1](qkv[:, :, :, half:]),
                ],
                dim=2,
            )
        else:
            attended = self.attns[0](qkv)

        x = x + self.drop_path(self.proj(attended))
        return x + self.drop_path(self.mlp(self.norm2(x)))


class Merge_Block(nn.Module):
    """Downsample a square token grid with a 3x3 stride-2 conv, then normalize.

    Args:
        dim (int): Input channels.
        dim_out (int): Output channels.
        norm_layer: Normalization class applied over ``dim_out``.
    """

    def __init__(self, dim, dim_out, norm_layer=nn.LayerNorm):
        super().__init__()
        # kernel 3, stride 2, padding 1
        self.conv = nn.Conv2d(dim, dim_out, 3, 2, 1)
        self.norm = norm_layer(dim_out)

    def forward(self, x):
        """Map ``(B, N, C)`` tokens (N a perfect square) to ``(B, N', dim_out)``."""
        batch, num_tokens, channels = x.shape
        side = int(np.sqrt(num_tokens))
        grid = x.transpose(-2, -1).contiguous().view(batch, channels, side, side)
        grid = self.conv(grid)
        out_batch, out_channels = grid.shape[:2]
        tokens = grid.view(out_batch, out_channels, -1).transpose(-2, -1).contiguous()
        return self.norm(tokens)


# TODO: Add OneFlow backend into einops
class Rearrange(nn.Module):
    """Flatten a ``(B, C, H, W)`` feature map into ``(B, H*W, C)`` tokens.

    Stand-in for einops ``Rearrange('b c h w -> b (h w) c')``.

    Args:
        img_size (int): Stored as ``self.img_size``; not used by ``forward``.
    """

    def __init__(self, img_size=224):
        super().__init__()
        self.img_size = img_size

    def forward(self, x):
        B, C, H, W = x.shape
        # reshape rather than view: view on the permuted tensor only works when
        # the H and W strides happen to be mergeable; reshape copies otherwise.
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        return x


class CSWinTransformer(nn.Module):
    """CSWin Transformer image classifier.

    Pipeline:
    - a 7x7 stride-4 convolutional stem turns the image into tokens;
    - four stages of ``CSWinBlock`` follow, separated by ``Merge_Block``
      layers that double the channel width;
    - a final norm, a mean over tokens and a classification head finish it.

    Args:
        img_size (int): Input resolution. Stage ``k`` (1-based) runs on an
            ``img_size // 2 ** (k + 1)`` square token grid.
        patch_size (int): Unused; kept for interface compatibility.
        in_chans (int): Input image channels.
        num_classes (int): Head outputs; ``0`` or less uses ``nn.Identity``.
        embed_dim (int): Channel width of stage 1.
        depth (sequence of int): Blocks per stage. ``depth[0..2]`` and
            ``depth[-1]`` are used.
        split_size (sequence of int): Stripe width per stage.
            ``split_size[0..2]`` is used for stages 1-3 and ``split_size[-1]``
            for stage 4.
        num_heads (int or sequence of int): Heads per stage; a single int is
            used for all four stages.
        mlp_ratio (float): MLP hidden width relative to the block width.
        qkv_bias (bool): Bias in the qkv projections.
        qk_scale (float, optional): Attention scale override.
        drop_rate (float): Dropout inside blocks.
        attn_drop_rate (float): Dropout on attention weights.
        drop_path_rate (float): Largest stochastic-depth rate. Rates increase
            linearly from 0 across all blocks.
        hybrid_backbone: Unused; kept for interface compatibility.
        norm_layer: Normalization class for the blocks and the final norm.
        use_chk (bool): Stored as ``self.use_chk``; not used by this class.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depth=(2, 2, 6, 2),
        split_size=(3, 5, 7),
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        hybrid_backbone=None,
        norm_layer=nn.LayerNorm,
        use_chk=False,
    ):
        super().__init__()
        self.use_chk = use_chk
        self.num_classes = num_classes
        # num_features for consistency with other models. This is the stem
        # width; the pooled features fed to the head have self.out_dim channels.
        self.num_features = self.embed_dim = embed_dim
        # A bare int means "same head count in every stage". Indexing the int
        # directly (as with the default num_heads=12) raised TypeError.
        heads = (num_heads,) * 4 if isinstance(num_heads, int) else num_heads

        self.stage1_conv_embed = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, 7, 4, 2),
            # equivalent of einops Rearrange('b c h w -> b (h w) c')
            Rearrange(img_size=img_size),
            nn.LayerNorm(embed_dim),
        )

        # stochastic depth decay rule: one rate per block, linear in block index
        dpr = [x.item() for x in flow.linspace(0, drop_path_rate, sum(depth))]

        def make_stage(dim, stage_heads, reso, split, num_blocks, dpr_offset, last_stage=False):
            # Build one stage; block i gets drop-path rate dpr[dpr_offset + i].
            return nn.ModuleList(
                [
                    CSWinBlock(
                        dim=dim,
                        num_heads=stage_heads,
                        reso=reso,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        split_size=split,
                        drop=drop_rate,
                        attn_drop=attn_drop_rate,
                        drop_path=dpr[dpr_offset + i],
                        norm_layer=norm_layer,
                        last_stage=last_stage,
                    )
                    for i in range(num_blocks)
                ]
            )

        # Modules are registered in the same order as before
        # (stage1, merge1, stage2, ...), so state_dict keys and init order match.
        curr_dim = embed_dim
        self.stage1 = make_stage(
            curr_dim, heads[0], img_size // 4, split_size[0], depth[0], 0
        )

        self.merge1 = Merge_Block(curr_dim, curr_dim * 2)
        curr_dim = curr_dim * 2
        self.stage2 = make_stage(
            curr_dim, heads[1], img_size // 8, split_size[1], depth[1], sum(depth[:1])
        )

        self.merge2 = Merge_Block(curr_dim, curr_dim * 2)
        curr_dim = curr_dim * 2
        self.stage3 = make_stage(
            curr_dim, heads[2], img_size // 16, split_size[2], depth[2], sum(depth[:2])
        )

        self.merge3 = Merge_Block(curr_dim, curr_dim * 2)
        curr_dim = curr_dim * 2
        self.stage4 = make_stage(
            curr_dim,
            heads[3],
            img_size // 32,
            split_size[-1],
            depth[-1],
            sum(depth[:-1]),
            last_stage=True,
        )

        # Width of the pooled features consumed by the head (used by reset_classifier).
        self.out_dim = curr_dim
        self.norm = norm_layer(curr_dim)
        # Classifier head
        self.head = (
            nn.Linear(curr_dim, num_classes) if num_classes > 0 else nn.Identity()
        )
        # nn.Identity has no weight; only initialise a real Linear head.
        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; unit/zero affine norms."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            # Norms built without affine parameters have weight/bias set to None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=""):
        """Replace the head with a freshly initialised one.

        Args:
            num_classes (int): New number of outputs; ``0`` or less installs
                ``nn.Identity``. Nothing happens if it equals the current value.
            global_pool: Unused; kept for interface compatibility.
        """
        if self.num_classes == num_classes:
            return
        print("reset head to", num_classes)
        self.num_classes = num_classes
        # Follow the device of the existing parameters instead of forcing
        # CUDA, which failed on CPU-only setups and could split the model
        # across devices.
        reference = next(iter(self.parameters()), None)
        head = (
            nn.Linear(self.out_dim, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        if reference is not None:
            head = head.to(reference.device)
        if isinstance(head, nn.Linear):
            trunc_normal_(head.weight, std=0.02)
            if head.bias is not None:
                nn.init.constant_(head.bias, 0)
        self.head = head

    def forward_features(self, x):
        """Run stem and all stages; return token-averaged features of width ``out_dim``."""
        x = self.stage1_conv_embed(x)
        for blk in self.stage1:
            x = blk(x)
        for merge, blocks in zip(
            (self.merge1, self.merge2, self.merge3),
            (self.stage2, self.stage3, self.stage4),
        ):
            x = merge(x)
            for blk in blocks:
                x = blk(x)
        x = self.norm(x)
        # mean over the token dimension -> (B, out_dim)
        return flow.mean(x, dim=1)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


def _create_cswin_transformer(arch, pretrained=False, progress=True, **model_kwargs):
    """Build a ``CSWinTransformer`` and optionally load pretrained weights.

    Args:
        arch (str): Key into ``model_urls``; only read when ``pretrained`` is true.
        pretrained (bool): Download the checkpoint for ``arch`` and load it.
        progress (bool): Forwarded to ``load_state_dict_from_url``.
        **model_kwargs: Forwarded to ``CSWinTransformer``.
    """
    model = CSWinTransformer(**model_kwargs)
    if not pretrained:
        return model
    checkpoint = load_state_dict_from_url(model_urls[arch], progress=progress)
    model.load_state_dict(checkpoint)
    return model


# 224 model
@ModelCreator.register_model
def cswin_tiny_224(pretrained=False, progress=True, **kwargs):
    """
    Constructs CSwin-T 224x224 model.

    .. note::
        CSwin-T 224x224 model from `"CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows" <https://arxiv.org/pdf/2107.00652.pdf>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``CSWinTransformer``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> cswin_tiny_224 = flowvision.models.cswin_tiny_224(pretrained=False, progress=True)

    """
    model_kwargs = dict(
        patch_size=4,
        embed_dim=64,
        depth=(1, 2, 21, 1),
        split_size=(1, 2, 7, 7),
        num_heads=(2, 4, 8, 16),
        mlp_ratio=4.0,
        **kwargs
    )
    return _create_cswin_transformer(
        "cswin_tiny_224", pretrained=pretrained, progress=progress, **model_kwargs
    )
@ModelCreator.register_model
def cswin_small_224(pretrained=False, progress=True, **kwargs):
    """
    Constructs CSwin-S 224x224 model.

    .. note::
        CSwin-S 224x224 model from `"CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows" <https://arxiv.org/pdf/2107.00652.pdf>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``CSWinTransformer``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> cswin_small_224 = flowvision.models.cswin_small_224(pretrained=False, progress=True)

    """
    model_kwargs = dict(
        patch_size=4,
        embed_dim=64,
        depth=(2, 4, 32, 2),
        split_size=(1, 2, 7, 7),
        num_heads=(2, 4, 8, 16),
        mlp_ratio=4.0,
        **kwargs
    )
    return _create_cswin_transformer(
        "cswin_small_224", pretrained=pretrained, progress=progress, **model_kwargs
    )
@ModelCreator.register_model
def cswin_base_224(pretrained=False, progress=True, **kwargs):
    """
    Constructs CSwin-B 224x224 model.

    .. note::
        CSwin-B 224x224 model from `"CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows" <https://arxiv.org/pdf/2107.00652.pdf>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``CSWinTransformer``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> cswin_base_224 = flowvision.models.cswin_base_224(pretrained=False, progress=True)

    """
    model_kwargs = dict(
        patch_size=4,
        embed_dim=96,
        depth=(2, 4, 32, 2),
        split_size=(1, 2, 7, 7),
        num_heads=(4, 8, 16, 32),
        mlp_ratio=4.0,
        **kwargs
    )
    return _create_cswin_transformer(
        "cswin_base_224", pretrained=pretrained, progress=progress, **model_kwargs
    )
@ModelCreator.register_model
def cswin_large_224(pretrained=False, progress=True, **kwargs):
    """
    Constructs CSwin-L 224x224 model.

    .. note::
        CSwin-L 224x224 model from `"CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows" <https://arxiv.org/pdf/2107.00652.pdf>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``CSWinTransformer``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> cswin_large_224 = flowvision.models.cswin_large_224(pretrained=False, progress=True)

    """
    model_kwargs = dict(
        patch_size=4,
        embed_dim=144,
        depth=(2, 4, 32, 2),
        split_size=(1, 2, 7, 7),
        num_heads=(6, 12, 24, 24),
        mlp_ratio=4.0,
        **kwargs
    )
    return _create_cswin_transformer(
        "cswin_large_224", pretrained=pretrained, progress=progress, **model_kwargs
    )
# 384 model
@ModelCreator.register_model
def cswin_base_384(pretrained=False, progress=True, **kwargs):
    """
    Constructs CSwin-B 384x384 model.

    .. note::
        CSwin-B 384x384 model from `"CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows" <https://arxiv.org/pdf/2107.00652.pdf>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``CSWinTransformer``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> cswin_base_384 = flowvision.models.cswin_base_384(pretrained=False, progress=True)

    """
    model_kwargs = dict(
        img_size=384,
        patch_size=4,
        embed_dim=96,
        depth=(2, 4, 32, 2),
        split_size=(1, 2, 12, 12),
        num_heads=(4, 8, 16, 32),
        mlp_ratio=4.0,
        **kwargs
    )
    return _create_cswin_transformer(
        "cswin_base_384", pretrained=pretrained, progress=progress, **model_kwargs
    )
@ModelCreator.register_model
def cswin_large_384(pretrained=False, progress=True, **kwargs):
    """
    Constructs CSwin-L 384x384 model.

    .. note::
        CSwin-L 384x384 model from `"CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows" <https://arxiv.org/pdf/2107.00652.pdf>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``CSWinTransformer``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> cswin_large_384 = flowvision.models.cswin_large_384(pretrained=False, progress=True)

    """
    model_kwargs = dict(
        img_size=384,
        patch_size=4,
        embed_dim=144,
        depth=(2, 4, 32, 2),
        split_size=(1, 2, 12, 12),
        num_heads=(6, 12, 24, 24),
        mlp_ratio=4.0,
        **kwargs
    )
    return _create_cswin_transformer(
        "cswin_large_384", pretrained=pretrained, progress=progress, **model_kwargs
    )