"""
Modified from https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
"""
import oneflow as flow
import oneflow.nn as nn
from flowvision.layers import DropPath, trunc_normal_
from .registry import ModelCreator
from .utils import load_state_dict_from_url
model_urls = {
"crossformer_tiny_patch4_group7_224": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/CrossFormer/crossformer_tiny_patch4_group7_224.zip",
"crossformer_small_patch4_group7_224": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/CrossFormer/crossformer_small_patch4_group7_224.zip",
"crossformer_base_patch4_group7_224": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/CrossFormer/crossformer_base_patch4_group7_224.zip",
"crossformer_large_patch4_group7_224": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/CrossFormer/crossformer_large_patch4_group7_224.zip",
}
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class DynamicPosBias(nn.Module):
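    r""" Dynamic Position Bias (DPB) module.

    A small MLP that maps 2D relative position offsets to a per-head attention
    bias; used by the Attention module below in place of a fixed relative
    position bias table.

    Args:
        dim (int): Base channel dimension; the hidden width of the MLP is ``dim // 4``
        num_heads (int): Number of attention heads (output size of the bias)
        residual (bool): If True, add residual connections between the MLP stages
    """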
def __init__(self, dim, num_heads, residual):
super(DynamicPosBias, self).__init__()
self.residual = residual
self.num_heads = num_heads
self.pos_dim = dim // 4
self.pos_proj = nn.Linear(2, self.pos_dim)
self.pos1 = nn.Sequential(
nn.LayerNorm(self.pos_dim),
nn.ReLU(inplace=True),
nn.Linear(self.pos_dim, self.pos_dim),
)
self.pos2 = nn.Sequential(
nn.LayerNorm(self.pos_dim),
nn.ReLU(inplace=True),
nn.Linear(self.pos_dim, self.pos_dim),
)
self.pos3 = nn.Sequential(
nn.LayerNorm(self.pos_dim),
nn.ReLU(inplace=True),
nn.Linear(self.pos_dim, self.num_heads),
)
def forward(self, biases):
if self.residual:
            pos = self.pos_proj(biases)  # (2Wh-1)*(2Ww-1), pos_dim
pos = pos + self.pos1(pos)
pos = pos + self.pos2(pos)
pos = self.pos3(pos)
else:
pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
return pos
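# Shape note (a sketch, assuming a 7x7 attention group): Attention below builds
# DynamicPosBias(dim // 4, num_heads, residual=False) and feeds it the
# (2*7-1) * (2*7-1) = 169 relative (dy, dx) offsets, so the returned bias has
# shape (169, num_heads), i.e. one value per head for every relative offset.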
class Attention(nn.Module):
r""" Multi-head self attention module with dynamic position bias.
Args:
dim (int): Number of input channels
group_size (tuple[int]): The height and width of the group
num_heads (int): Number of attention heads
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: ``True``
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: ``0.0``
proj_drop (float, optional): Dropout ratio of output. Default: ``0.0``
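        position_bias (bool, optional): If True, use DynamicPosBias as the relative position bias. Default: ``True``

    A minimal usage sketch (shapes are illustrative, assuming a 7x7 group):

    .. code-block:: python

        attn = Attention(dim=96, group_size=(7, 7), num_heads=3)
        x = flow.randn(8, 49, 96)  # (num_groups * B, G * G, C)
        out = attn(x)              # (num_groups * B, G * G, C)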
"""
def __init__(
self,
dim,
group_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
position_bias=True,
):
super(Attention, self).__init__()
self.dim = dim
self.group_size = group_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.position_bias = position_bias
if position_bias:
self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
# generate mother-set
position_bias_h = flow.arange(
1 - self.group_size[0], self.group_size[0]
) # height index
position_bias_w = flow.arange(
1 - self.group_size[1], self.group_size[1]
) # width index
biases = flow.stack(
flow.meshgrid(position_bias_h, position_bias_w)
            )  # 2, 2Wh-1, 2Ww-1
biases = biases.flatten(1).transpose(0, 1).float()
self.register_buffer("biases", biases)
# get pair-wise relative position index for each token inside the group
coords_h = flow.arange(self.group_size[0])
coords_w = flow.arange(self.group_size[1])
coords = flow.stack(flow.meshgrid(coords_h, coords_w)) # 2, Wh, Ww
coords_flatten = flow.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.group_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.group_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: Input features with shape of (num_groups*B, N, C)
mask: (0/-inf) mask with shape of (num_groups, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2] # obtain query, key and value
q = q * self.scale
# TODO: unsupported operation type: @, using flow.matmul
# attn = (q @ k.transpose(-2, -1))
attn = flow.matmul(q, k.transpose(-2, -1))
if self.position_bias:
pos = self.pos(self.biases) # 2Wh-1 * 2Ww-1, heads
# select position bias
relative_position_bias = pos[self.relative_position_index.view(-1)].view(
self.group_size[0] * self.group_size[1],
self.group_size[0] * self.group_size[1],
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
1
).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
# x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = flow.matmul(attn, v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class CrossFormerBlock(nn.Module):
r""" CrossFormer Block.
Args:
dim (int): Number of input channels
        input_resolution (tuple[int]): Input resolution
num_heads (int): Number of attention heads
group_size (int): Group size
        lsda_flag (int): Use SDA or LDA; 0 for SDA (short-distance attention) and 1 for LDA (long-distance attention)
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: ``True``
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
drop (float, optional): Dropout rate. Default: ``0.0``
attn_drop (float, optional): Attention dropout rate. Default: ``0.0``
drop_path (float, optional): Stochastic depth rate. Default: ``0.0``
act_layer (nn.Module, optional): Activation layer. Default: ``nn.GELU``
norm_layer (nn.Module, optional): Normalization layer. Default: ``nn.LayerNorm``
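
    A minimal usage sketch (shapes are illustrative):

    .. code-block:: python

        block = CrossFormerBlock(dim=96, input_resolution=(56, 56), num_heads=3)
        x = flow.randn(1, 56 * 56, 96)
        out = block(x)  # (1, 56 * 56, 96)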
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
group_size=7,
lsda_flag=0,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
num_patch_size=1,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.group_size = group_size
self.lsda_flag = lsda_flag
self.mlp_ratio = mlp_ratio
self.num_patch_size = num_patch_size
if min(self.input_resolution) <= self.group_size:
            # if the input resolution is not larger than the group size, we don't partition groups
self.lsda_flag = 0
self.group_size = min(self.input_resolution)
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
group_size=pair(self.group_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
position_bias=True,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
        assert L == H * W, "input feature has wrong size: L=%d, H=%d, W=%d" % (L, H, W)
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# group embeddings
G = self.group_size
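        # SDA (lsda_flag == 0) groups G x G adjacent embeddings together, while
        # LDA (lsda_flag == 1) samples embeddings at a fixed interval of
        # (H // G, W // G), so each group spans the whole feature map.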
if self.lsda_flag == 0: # 0 for SDA
x = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)
else: # 1 for LDA
x = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)
x = x.reshape(B * H * W // G ** 2, G ** 2, C)
# multi-head self-attention
x = self.attn(x, mask=self.attn_mask) # nW*B, G*G, C
# ungroup embeddings
x = x.reshape(B, H // G, W // G, G, G, C)
if self.lsda_flag == 0:
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)
else:
x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, H, W, C)
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature
dim (int): Number of input channels
norm_layer (nn.Module, optional): Normalization layer. Default: ``nn.LayerNorm``
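        patch_size (list[int]): Kernel sizes of the reduction convolutions. Default: ``[2]``

    A minimal usage sketch (shapes are illustrative):

    .. code-block:: python

        merge = PatchMerging(input_resolution=(56, 56), dim=96, patch_size=[2, 4])
        x = flow.randn(1, 56 * 56, 96)
        out = merge(x)  # (1, 28 * 28, 192)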
"""
def __init__(
self,
input_resolution,
dim,
norm_layer=nn.LayerNorm,
patch_size=[2],
num_input_patch_size=1,
):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reductions = nn.ModuleList()
self.patch_size = patch_size
self.norm = norm_layer(dim)
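        # Each reduction conv contributes a slice of the output channels; the
        # concatenated output always has 2 * dim channels no matter how many
        # kernel sizes are given.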
for i, ps in enumerate(patch_size):
if i == len(patch_size) - 1:
out_dim = 2 * dim // 2 ** i
else:
out_dim = 2 * dim // 2 ** (i + 1)
stride = 2
padding = (ps - stride) // 2
self.reductions.append(
nn.Conv2d(dim, out_dim, kernel_size=ps, stride=stride, padding=padding)
)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"input size ({H}*{W}) is not even."
x = self.norm(x)
x = x.view(B, H, W, C).permute(0, 3, 1, 2)
xs = []
for i in range(len(self.reductions)):
tmp_x = self.reductions[i](x).flatten(2).transpose(1, 2)
xs.append(tmp_x)
x = flow.cat(xs, dim=2)
return x
class Stage(nn.Module):
""" CrossFormer blocks for one stage.
Args:
dim (int): Number of input channels
input_resolution (tuple[int]): Input resolution
depth (int): Number of blocks
num_heads (int): Number of attention heads
group_size (int): Variable G in the paper, one group has GxG embeddings
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: ``True``
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
drop (float, optional): Dropout rate. Default: ``0.0``
attn_drop (float, optional): Attention dropout rate. Default: ``0.0``
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: ``0.0``
norm_layer (nn.Module, optional): Normalization layer. Default: ``nn.LayerNorm``
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: ``None``
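
    A minimal usage sketch (shapes are illustrative):

    .. code-block:: python

        stage = Stage(dim=96, input_resolution=(56, 56), depth=2, num_heads=3,
                      group_size=7, downsample=PatchMerging, patch_size_end=[2])
        x = flow.randn(1, 56 * 56, 96)
        out = stage(x)  # (1, 28 * 28, 192) after the PatchMerging downsample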
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
group_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
patch_size_end=[4],
num_patch_size=None,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
# TODO: add checkpoint to save memory
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList()
for i in range(depth):
lsda_flag = 0 if (i % 2 == 0) else 1
self.blocks.append(
CrossFormerBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
group_size=group_size,
lsda_flag=lsda_flag,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i]
if isinstance(drop_path, list)
else drop_path,
norm_layer=norm_layer,
num_patch_size=num_patch_size,
)
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution,
dim=dim,
norm_layer=norm_layer,
patch_size=patch_size_end,
num_input_patch_size=num_patch_size,
)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: ``224``
        patch_size (list[int]): Patch token size(s) of the cross-scale embedding. Default: ``[4]``
in_chans (int): Number of input image channels. Default: ``3``
embed_dim (int): Number of linear projection output channels. Default: ``96``
norm_layer (nn.Module, optional): Normalization layer. Default: ``None``
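
    A minimal usage sketch (shapes are illustrative):

    .. code-block:: python

        embed = PatchEmbed(img_size=224, patch_size=[4, 8, 16, 32], embed_dim=96)
        x = flow.randn(1, 3, 224, 224)
        out = embed(x)  # (1, 56 * 56, 96)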
"""
def __init__(
self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None
):
super().__init__()
img_size = pair(img_size)
# patch_size = to_2tuple(patch_size)
        patches_resolution = [
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[0],
        ]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.projs = nn.ModuleList()
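        # Cross-scale embedding: each kernel size embeds the image at a different
        # receptive field, with per-branch widths chosen so that the concatenated
        # output has exactly embed_dim channels.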
for i, ps in enumerate(patch_size):
if i == len(patch_size) - 1:
dim = embed_dim // 2 ** i
else:
dim = embed_dim // 2 ** (i + 1)
stride = patch_size[0]
padding = (ps - patch_size[0]) // 2
self.projs.append(
nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding)
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
xs = []
for i in range(len(self.projs)):
tx = self.projs[i](x).flatten(2).transpose(1, 2)
xs.append(tx) # B Ph*Pw C
x = flow.cat(xs, dim=2)
if self.norm is not None:
x = self.norm(x)
return x
class CrossFormer(nn.Module):
r""" CrossFormer
    A OneFlow impl of: `CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention` - https://arxiv.org/abs/2108.00154
Args:
img_size (int | tuple(int)): Input image size. Default: ``224``
        patch_size (list(int)): Patch size(s) of the cross-scale embedding layer. Default: ``[4]``
in_chans (int): Number of input image channels. Default: ``3``
num_classes (int): Number of classes for classification head. Default: ``1000``
embed_dim (int): Patch embedding dimension. Default: ``96``
depths (tuple(int)): Depth of each stage
num_heads (tuple(int)): Number of attention heads in different layers
        group_size (int | list(int)): Group size G; indexed per stage in the constructor, so a per-stage list such as ``[7, 7, 7, 7]`` is expected. Default: ``7``
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: ``4``
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: ``True``
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: ``None``
drop_rate (float): Dropout rate. Default: ``0``
attn_drop_rate (float): Attention dropout rate. Default: ``0``
drop_path_rate (float): Stochastic depth rate. Default: ``0.1``
norm_layer (nn.Module): Normalization layer. Default: ``nn.LayerNorm``
ape (bool): If True, add absolute position embedding to the patch embedding. Default: ``False``
patch_norm (bool): If True, add normalization after patch embedding. Default: ``True``
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: ``False``
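        merge_size (list(list(int))): Kernel sizes of the PatchMerging layer after each of the first ``len(depths) - 1`` stages. Default: ``[[2], [2], [2]]``

    A minimal usage sketch (the per-stage arguments are illustrative, not a released configuration):

    .. code-block:: python

        model = CrossFormer(
            img_size=224,
            patch_size=[4, 8, 16, 32],
            group_size=[7, 7, 7, 7],
            merge_size=[[2, 4], [2, 4], [2, 4]],
        )
        logits = model(flow.randn(1, 3, 224, 224))  # (1, 1000)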
"""
def __init__(
self,
img_size=224,
patch_size=[4],
in_chans=3,
num_classes=1000,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
group_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
use_checkpoint=False,
merge_size=[[2], [2], [2]],
**kwargs,
):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(
flow.zeros(1, num_patches, embed_dim)
)
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in flow.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
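        # e.g. with the defaults drop_path_rate=0.1 and depths=[2, 2, 6, 2] this
        # yields 12 values rising linearly from 0.0 to 0.1, one per block.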
# build layers
self.layers = nn.ModuleList()
num_patch_sizes = [len(patch_size)] + [len(m) for m in merge_size]
for i_layer in range(self.num_layers):
patch_size_end = (
merge_size[i_layer] if i_layer < self.num_layers - 1 else None
)
num_patch_size = num_patch_sizes[i_layer]
layer = Stage(
dim=int(embed_dim * 2 ** i_layer),
input_resolution=(
patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer),
),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
group_size=group_size[i_layer],
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
patch_size_end=patch_size_end,
num_patch_size=num_patch_size,
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = (
nn.Linear(self.num_features, num_classes)
if num_classes > 0
else nn.Identity()
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = flow.flatten(x, 1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _create_cross_former(arch, pretrained=False, progress=True, **model_kwargs):
model = CrossFormer(**model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model