"""
Modified from https://github.com/NVlabs/FAN/blob/master/models/fan.py
"""
import math
from functools import partial
from collections import OrderedDict, abc
from typing import Callable
from itertools import repeat
import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.functional as F
from ..layers import DropPath
from .registry import ModelCreator
from .utils import load_state_dict_from_url
model_urls = {
"fan_vit_tiny": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_vit_tiny.zip",
"fan_vit_small": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_vit_small.zip",
"fan_vit_base": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_vit_base.zip",
"fan_vit_large": None,
"fan_hybrid_tiny": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_tiny.zip",
"fan_hybrid_small": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_small.zip",
"fan_hybrid_base": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_base.zip",
"fan_hybrid_large": None,
"fan_hybrid_base_in22k_1k": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_base_in22k_1k.zip",
"fan_hybrid_base_in22k_1k_384": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_base_in22k_1k_384.zip",
"fan_hybrid_large_in22k_1k": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_large_in22k_1k.zip",
"fan_hybrid_large_in22k_1k_384": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_large_in22k_1k_384.zip",
}
def to_2tuple(x):
if isinstance(x, abc.Iterable):
return x
return tuple(repeat(x, 2))
def _is_contiguous(tensor: flow.Tensor) -> bool:
return tensor.is_contiguous()
def named_apply(
fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = ".".join((name, child_name)) if name else child_name
named_apply(
fn=fn,
module=child_module,
name=child_name,
depth_first=depth_first,
include_root=True,
)
if depth_first and include_root:
fn(module=module, name=name)
return module
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
print(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect."
)
with flow.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
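# A minimal usage sketch (shape chosen for illustration): the tensor is filled
# in place from N(mean, std^2) and clamped so every value lands in [a, b].
#
#   w = flow.empty(768, 384)
#   trunc_normal_(w, std=0.02)              # default bounds a = -2.0, b = 2.0
#   assert float(w.min()) >= -2.0 and float(w.max()) <= 2.0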
def _create_fc(num_features, num_classes, use_conv=False):
if num_classes <= 0:
fc = nn.Identity() # pass-through (no classifier)
elif use_conv:
fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
else:
fc = nn.Linear(num_features, num_classes, bias=True)
return fc
class MlpOri(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=True,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class ClassifierHead(nn.Module):
"""Classifier head w/ configurable global pooling and dropout."""
def __init__(
self, in_chs, num_classes, pool_type="avg", drop_rate=0.0, use_conv=False
):
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.global_pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(1))
self.fc = _create_fc(in_chs, num_classes, use_conv=use_conv)
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
if pre_logits:
return x.flatten(1)
else:
x = self.fc(x)
return self.flatten(x)
class ClassAttn(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to do CA
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.k = nn.Linear(dim, dim, bias=qkv_bias)
self.v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x, use_attn=False):
B, N, C = x.shape
q = (
self.q(x[:, 0])
.unsqueeze(1)
.reshape(B, 1, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
k = (
self.k(x)
.reshape(B, N, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
q = q * self.scale
v = (
self.v(x)
.reshape(B, N, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
x_cls = self.proj(x_cls)
x_cls = self.proj_drop(x_cls)
        if use_attn:
            return x_cls, attn
        return x_cls
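# Shape sketch for ClassAttn (sizes assumed for illustration): only the class
# token forms the query, so the attention map is (B, heads, 1, N) and a single
# refined class token comes back.
#
#   ca = ClassAttn(dim=192, num_heads=4)
#   tokens = flow.randn(2, 197, 192)        # [CLS] + 196 patch tokens
#   cls_out = ca(tokens)                    # (2, 1, 192)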
class ConvMlp(nn.Module):
""" MLP using 1x1 convs that keeps spatial dims
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.ReLU,
norm_layer=None,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.norm(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
return x
class LayerNorm2d(nn.LayerNorm):
r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
"""
def __init__(self, normalized_shape, eps=1e-6):
super().__init__(normalized_shape, eps=eps)
def forward(self, x) -> flow.Tensor:
if _is_contiguous(x):
x = x.permute(0, 2, 3, 1)
x = super(LayerNorm2d, self).forward(x)
return x.permute(0, 3, 1, 2)
else:
s, u = flow.var_mean(x, dim=1, keepdim=True)
x = (x - u) * flow.rsqrt(s + self.eps)
x = x * self.weight[:, None, None] + self.bias[:, None, None]
return x
class ConvNeXtBlock(nn.Module):
""" ConvNeXt Block
There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    Unlike the official impl, this one allows a choice of (1) or (2). The 1x1 convs can be faster with an
    appropriate LayerNorm impl, but as model size increases the tradeoff shifts and nn.Linear becomes the
    better choice. This was observed with PyTorch 1.10 on a 3090 GPU; it could change over time and with different HW.
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""
def __init__(
self,
dim,
drop_path=0.0,
ls_init_value=1e-6,
conv_mlp=True,
mlp_ratio=4,
norm_layer=None,
):
super().__init__()
if not norm_layer:
norm_layer = (
partial(LayerNorm2d, eps=1e-6)
if conv_mlp
else partial(nn.LayerNorm, eps=1e-6)
)
mlp_layer = ConvMlp if conv_mlp else MlpOri
self.use_conv_mlp = conv_mlp
self.conv_dw = nn.Conv2d(
dim, dim, kernel_size=7, padding=3, groups=dim
) # depthwise conv
self.norm = norm_layer(dim)
self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU)
self.gamma = (
nn.Parameter(ls_init_value * flow.ones(dim)) if ls_init_value > 0 else None
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x):
shortcut = x
x = self.conv_dw(x)
if self.use_conv_mlp:
x = self.norm(x)
x = self.mlp(x)
else:
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = self.mlp(x)
x = x.permute(0, 3, 1, 2)
if self.gamma is not None:
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
x = self.drop_path(x) + shortcut
return x
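# A hedged instantiation sketch: with conv_mlp=True the block stays channels-first
# end to end; with conv_mlp=False it permutes to channels-last around the MLP.
# Either way the residual output keeps the input shape.
#
#   blk = ConvNeXtBlock(dim=96, drop_path=0.1, conv_mlp=True)
#   y = blk(flow.randn(1, 96, 56, 56))      # (1, 96, 56, 56)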
class ConvNeXtStage(nn.Module):
def __init__(
self,
in_chs,
out_chs,
stride=2,
depth=2,
dp_rates=None,
ls_init_value=1.0,
conv_mlp=True,
norm_layer=None,
cl_norm_layer=None,
cross_stage=False,
no_downsample=False,
):
super().__init__()
if in_chs != out_chs or stride > 1:
self.downsample = nn.Sequential(
norm_layer(in_chs),
nn.Conv2d(
in_chs,
out_chs,
kernel_size=stride,
stride=stride if not no_downsample else 1,
),
)
else:
self.downsample = nn.Identity()
dp_rates = dp_rates or [0.0] * depth
self.blocks = nn.Sequential(
*[
ConvNeXtBlock(
dim=out_chs,
drop_path=dp_rates[j],
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
norm_layer=norm_layer if conv_mlp else cl_norm_layer,
)
for j in range(depth)
]
)
def forward(self, x):
x = self.downsample(x)
x = self.blocks(x)
return x
class ConvNeXt(nn.Module):
r""" ConvNeXt
A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_rate (float): Head dropout rate
drop_path_rate (float): Stochastic depth rate. Default: 0.
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""
def __init__(
self,
in_chans=3,
img_size=224,
num_classes=1000,
global_pool="avg",
output_stride=32,
patch_size=4,
depths=(3, 3, 9, 3),
dims=(96, 192, 384, 768),
ls_init_value=1e-6,
conv_mlp=True,
use_head=True,
head_init_scale=1.0,
head_norm_first=False,
norm_layer=None,
drop_rate=0.0,
drop_path_rate=0.0,
remove_last_downsample=False,
):
super().__init__()
assert output_stride == 32
if norm_layer is None:
norm_layer = partial(LayerNorm2d, eps=1e-6)
cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
else:
assert (
conv_mlp
), "If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input"
cl_norm_layer = norm_layer
self.num_classes = num_classes
self.drop_rate = drop_rate
self.feature_info = []
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size),
norm_layer(dims[0]),
)
self.stages = nn.Sequential()
dp_rates = [
x.tolist()
for x in flow.linspace(0, drop_path_rate, sum(depths)).split(depths)
]
curr_stride = patch_size
prev_chs = dims[0]
stages = []
# 4 feature resolution stages, each consisting of multiple residual blocks
for i in range(len(depths)):
stride = 2 if i > 0 else 1
curr_stride *= stride
out_chs = dims[i]
no_downsample = remove_last_downsample and (i == len(depths) - 1)
stages.append(
ConvNeXtStage(
prev_chs,
out_chs,
stride=stride,
depth=depths[i],
dp_rates=dp_rates[i],
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
norm_layer=norm_layer,
cl_norm_layer=cl_norm_layer,
no_downsample=no_downsample,
)
)
prev_chs = out_chs
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
self.feature_info += [
dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}")
]
self.stages = nn.Sequential(*stages)
self.num_features = prev_chs
if head_norm_first:
# norm -> global pool -> fc ordering, like most other nets (not compat with FB weights)
self.norm_pre = norm_layer(
self.num_features
) # final norm layer, before pooling
if use_head:
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
)
else:
# pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.norm_pre = nn.Identity()
if use_head:
self.head = nn.Sequential(
OrderedDict(
[
(
"global_pool",
nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(1)),
),
("norm", norm_layer(self.num_features)),
(
"flatten",
nn.Flatten(1) if global_pool else nn.Identity(),
),
("drop", nn.Dropout(self.drop_rate)),
(
"fc",
nn.Linear(self.num_features, num_classes)
if num_classes > 0
else nn.Identity(),
),
]
)
)
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes=0, global_pool="avg"):
if isinstance(self.head, ClassifierHead):
# norm -> global pool -> fc
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=self.drop_rate,
)
else:
# pool -> norm -> fc
self.head = nn.Sequential(
OrderedDict(
[
(
"global_pool",
nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(1)),
),
("norm", self.head.norm),
("flatten", nn.Flatten(1) if global_pool else nn.Identity()),
("drop", nn.Dropout(self.drop_rate)),
(
"fc",
nn.Linear(self.num_features, num_classes)
if num_classes > 0
else nn.Identity(),
),
]
)
)
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
x = self.norm_pre(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _init_weights(module, name=None, head_init_scale=1.0):
if isinstance(module, nn.Conv2d):
trunc_normal_(module.weight, std=0.02)
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=0.02)
nn.init.constant_(module.bias, 0)
if name and "head." in name:
module.weight.data.mul_(head_init_scale)
module.bias.data.mul_(head_init_scale)
def checkpoint_filter_fn(state_dict, model):
""" Remap FB checkpoints -> timm """
if "model" in state_dict:
state_dict = state_dict["model"]
out_dict = {}
import re
for k, v in state_dict.items():
k = k.replace("downsample_layers.0.", "stem.")
k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k)
k = re.sub(
r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k
)
k = k.replace("dwconv", "conv_dw")
k = k.replace("pwconv", "mlp.fc")
k = k.replace("head.", "head.fc.")
        if k.startswith("norm."):
            k = k.replace("norm", "head.norm")
        if k not in model.state_dict().keys():
            continue
        if v.ndim == 2 and "head" not in k:
            model_shape = model.state_dict()[k].shape
            v = v.reshape(model_shape)
        out_dict[k] = v
return out_dict
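# Remapping sketch (key names assumed from FB ConvNeXt checkpoints), showing how
# the replace/regex chain above renames parameters:
#   downsample_layers.0.0.weight -> stem.0.weight
#   stages.1.2.dwconv.weight     -> stages.1.blocks.2.conv_dw.weight
#   head.weight                  -> head.fc.weight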
def _create_hybrid_backbone(**kwargs):
model = ConvNeXt(**kwargs)
return model
class PositionalEncodingFourier(nn.Module):
"""
    Positional encoding relying on a Fourier kernel, matching the one used in the "Attention Is All You Need" paper.
"""
def __init__(self, hidden_dim=32, dim=768, temperature=10000):
super().__init__()
self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
self.scale = 2 * math.pi
self.temperature = temperature
self.hidden_dim = hidden_dim
self.dim = dim
self.eps = 1e-6
def forward(self, B: int, H: int, W: int):
device = self.token_projection.weight.device
y_embed = (
flow.arange(1, H + 1, dtype=flow.float32, device=device)
.unsqueeze(1)
.repeat(1, 1, W)
)
x_embed = flow.arange(1, W + 1, dtype=flow.float32, device=device).repeat(
1, H, 1
)
y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
dim_t = flow.arange(self.hidden_dim, dtype=flow.float32, device=device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.hidden_dim)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = flow.stack(
[pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4
).flatten(3)
pos_y = flow.stack(
[pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], dim=4
).flatten(3)
pos = flow.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
pos = self.token_projection(pos)
return pos.repeat(B, 1, 1, 1) # (B, C, H, W)
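# Usage sketch (grid size assumed): the encoding is computed once for an H x W
# grid and broadcast over the batch.
#
#   pe = PositionalEncodingFourier(dim=192)
#   pos = pe(2, 14, 14)                     # (2, 192, 14, 14)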
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution + batch norm"""
return flow.nn.Sequential(
nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
),
nn.BatchNorm2d(out_planes),
)
def sigmoid(x, inplace=False):
return x.sigmoid_() if inplace else x.sigmoid()
def make_divisible(v, divisor=8, min_value=None):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
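# Worked examples: make_divisible rounds to the nearest multiple of `divisor`,
# then bumps the result up one step if rounding down lost more than 10%.
#   make_divisible(25)                      # int(25 + 4) // 8 * 8 = 24, kept since 24 >= 0.9 * 25
#   make_divisible(7)                       # rounds to 8 via min_value = divisor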
class SqueezeExcite(nn.Module):
def __init__(
self,
in_chs,
se_ratio=0.25,
reduced_base_chs=None,
act_layer=nn.ReLU,
gate_fn=sigmoid,
divisor=1,
**_,
):
super(SqueezeExcite, self).__init__()
self.gate_fn = gate_fn
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.act1 = act_layer(inplace=True)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
def forward(self, x):
x_se = self.avg_pool(x)
x_se = self.conv_reduce(x_se)
x_se = self.act1(x_se)
x_se = self.conv_expand(x_se)
x = x * self.gate_fn(x_se)
return x
class DWConv(nn.Module):  # NOTE: superseded at runtime by the DWConv redefined later in this file
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
class SEMlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
linear=False,
use_se=True,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
self.gamma = nn.Parameter(flow.ones(hidden_features), requires_grad=True)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
self.linear = linear
if self.linear:
self.relu = nn.ReLU(inplace=True)
self.se = (
SqueezeExcite(out_features, se_ratio=0.25) if use_se else nn.Identity()
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
B, N, C = x.shape
x = self.fc1(x)
if self.linear:
x = self.relu(x)
x = self.drop(self.gamma * self.dwconv(x, H, W)) + x
x = self.fc2(x)
x = self.drop(x)
x = (
self.se(x.permute(0, 2, 1).reshape(B, C, H, W))
.reshape(B, C, N)
.permute(0, 2, 1)
)
return x, H, W
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
linear=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.gamma = nn.Parameter(flow.ones(hidden_features), requires_grad=True)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
self.linear = linear
if self.linear:
self.relu = nn.ReLU(inplace=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
if self.linear:
x = self.relu(x)
x = self.drop(self.gamma * self.dwconv(x, H, W)) + x
x = self.fc2(x)
x = self.drop(x)
return x
class ConvPatchEmbed(nn.Module):
"""Image to Patch Embedding using multiple convolutional layers"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, act_layer=nn.GELU
):
super().__init__()
img_size = to_2tuple(img_size)
num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
if patch_size == 16:
self.proj = flow.nn.Sequential(
conv3x3(in_chans, embed_dim // 8, 2),
act_layer(),
conv3x3(embed_dim // 8, embed_dim // 4, 2),
act_layer(),
conv3x3(embed_dim // 4, embed_dim // 2, 2),
act_layer(),
conv3x3(embed_dim // 2, embed_dim, 2),
)
elif patch_size == 8:
self.proj = flow.nn.Sequential(
conv3x3(in_chans, embed_dim // 4, 2),
act_layer(),
conv3x3(embed_dim // 4, embed_dim // 2, 2),
act_layer(),
conv3x3(embed_dim // 2, embed_dim, 2),
)
elif patch_size == 4:
self.proj = flow.nn.Sequential(
conv3x3(in_chans, embed_dim // 4, 2),
act_layer(),
                conv3x3(embed_dim // 4, embed_dim, 2),
)
        else:
            raise ValueError(
                "For convolutional projection, patch size has to be in [4, 8, 16]"
            )
def forward(self, x):
x = self.proj(x)
Hp, Wp = x.shape[2], x.shape[3]
x = x.flatten(2).transpose(1, 2) # (B, N, C)
return x, (Hp, Wp)
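# Shape sketch: a 224x224 image with patch_size=16 passes through four stride-2
# conv3x3 stages, giving a 14x14 token grid.
#
#   embed = ConvPatchEmbed(img_size=224, patch_size=16, embed_dim=192)
#   tokens, (Hp, Wp) = embed(flow.randn(1, 3, 224, 224))   # tokens: (1, 196, 192), Hp = Wp = 14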
class DWConv(nn.Module):  # NOTE: redefinition; this version (conv -> act -> BN -> conv) is the one Mlp/SEMlp actually use
def __init__(
self, in_features, out_features=None, act_layer=nn.GELU, kernel_size=3
):
super().__init__()
out_features = out_features or in_features
padding = kernel_size // 2
self.conv1 = flow.nn.Conv2d(
in_features,
in_features,
kernel_size=kernel_size,
padding=padding,
groups=in_features,
)
self.act = act_layer()
self.bn = nn.BatchNorm2d(in_features)
self.conv2 = flow.nn.Conv2d(
in_features,
out_features,
kernel_size=kernel_size,
padding=padding,
groups=out_features,
)
def forward(self, x, H: int, W: int):
B, N, C = x.shape
x = x.permute(0, 2, 1).reshape(B, C, H, W)
x = self.conv1(x)
x = self.act(x)
x = self.bn(x)
x = self.conv2(x)
x = x.reshape(B, C, N).permute(0, 2, 1)
return x
class ClassAttentionBlock(nn.Module):
"""Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239"""
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
eta=1.0,
tokens_norm=False,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = ClassAttn(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = MlpOri(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
if eta is not None: # LayerScale Initialization (no layerscale when None)
self.gamma1 = nn.Parameter(eta * flow.ones(dim), requires_grad=True)
self.gamma2 = nn.Parameter(eta * flow.ones(dim), requires_grad=True)
else:
self.gamma1, self.gamma2 = 1.0, 1.0
self.tokens_norm = tokens_norm
def forward(self, x, return_attention=False):
x_norm1 = self.norm1(x)
if return_attention:
x1, attn = self.attn(x_norm1, use_attn=return_attention)
else:
x1 = self.attn(x_norm1)
x_attn = flow.cat([x1, x_norm1[:, 1:]], dim=1)
x = x + self.drop_path(self.gamma1 * x_attn)
if self.tokens_norm:
x = self.norm2(x)
else:
x = flow.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)
x_res = x
cls_token = x[:, 0:1]
cls_token = self.gamma2 * self.mlp(cls_token)
x = flow.cat([cls_token, x[:, 1:]], dim=1)
x = x_res + self.drop_path(x)
if return_attention:
return attn
return x
class TokenMixing(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
sr_ratio=1,
linear=False,
share_atten=False,
drop_path=0.0,
emlp=False,
sharpen_attn=False,
mlp_hidden_dim=None,
act_layer=nn.GELU,
drop=None,
norm_layer=nn.LayerNorm,
):
super().__init__()
assert (
dim % num_heads == 0
), f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.share_atten = share_atten
self.emlp = emlp
cha_sr = 1
self.q = nn.Linear(dim, dim // cha_sr, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2 // cha_sr, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.linear = linear
self.sr_ratio = sr_ratio
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W, atten=None, return_attention=False):
B, N, C = x.shape
q = (
self.q(x)
.reshape(B, N, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
kv = (
self.kv(x)
.reshape(B, -1, 2, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
k, v = kv[0], kv[1]
        attn = (q * self.scale) @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn @ v
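# Shape sketch for TokenMixing (plain multi-head self-attention over the token
# sequence; sizes assumed for illustration):
#
#   tm = TokenMixing(dim=192, num_heads=4)
#   y, attn_v = tm(flow.randn(2, 196, 192), H=14, W=14)    # y: (2, 196, 192)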
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(
self,
backbone,
img_size=224,
patch_size=2,
feature_size=None,
in_chans=3,
embed_dim=384,
):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.backbone = backbone
if feature_size is None:
with flow.no_grad():
# NOTE Most reliable way of determining output dims is to run forward pass
training = backbone.training
if training:
backbone.eval()
o = self.backbone.forward_features(
flow.zeros(1, in_chans, img_size[0], img_size[1])
)
if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
if hasattr(self.backbone, "feature_info"):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
assert (
feature_size[0] % patch_size[0] == 0
and feature_size[1] % patch_size[1] == 0
)
self.grid_size = (
feature_size[0] // patch_size[0],
feature_size[1] // patch_size[1],
)
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Conv2d(
feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size
)
    def forward(self, x):
        x = self.backbone.forward_features(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x, (H // self.patch_size[0], W // self.patch_size[1])
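# A hedged composition sketch: the truncated ConvNeXt backbone downsamples by 8
# (stride-4 stem plus one stride-2 stage), and patch_size=2 brings the total
# stride to 16, matching the FAN-Hybrid models below.
#
#   backbone = _create_hybrid_backbone(depths=[3, 3], dims=[128, 256, 512, 1024], use_head=False)
#   embed = HybridEmbed(backbone, img_size=224, patch_size=2, embed_dim=384)
#   tokens, (Hp, Wp) = embed(flow.randn(1, 3, 224, 224))   # tokens: (1, 196, 384), Hp = Wp = 14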
class ChannelProcessing(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.0,
linear=False,
drop_path=0.0,
mlp_hidden_dim=None,
act_layer=nn.GELU,
        drop=0.0,
norm_layer=nn.LayerNorm,
cha_sr_ratio=1,
c_head_num=None,
):
super().__init__()
assert (
dim % num_heads == 0
), f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
num_heads = c_head_num or num_heads
self.num_heads = num_heads
self.temperature = nn.Parameter(flow.ones(num_heads, 1, 1))
self.cha_sr_ratio = cha_sr_ratio if num_heads > 1 else 1
# config of mlp for v processing
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.mlp_v = Mlp(
in_features=dim // self.cha_sr_ratio,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
linear=linear,
)
self.norm_v = norm_layer(dim // self.cha_sr_ratio)
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def _gen_attn(self, q, k):
q = q.softmax(-2).transpose(-1, -2)
_, _, N, _ = k.shape
k = flow.nn.functional.adaptive_avg_pool2d(k.softmax(-2), (N, 1))
attn = flow.nn.functional.sigmoid(q @ k)
return attn * self.temperature
def forward(self, x, H, W, atten=None):
B, N, C = x.shape
v = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
q = (
self.q(x)
.reshape(B, N, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
k = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
attn = self._gen_attn(q, k)
attn = self.attn_drop(attn)
Bv, Hd, Nv, Cv = v.shape
v = (
self.norm_v(self.mlp_v(v.transpose(1, 2).reshape(Bv, Nv, Hd * Cv), H, W))
.reshape(Bv, Nv, Hd, Cv)
.transpose(1, 2)
)
repeat_time = N // attn.shape[-1]
attn = (
attn.repeat_interleave(repeat_time, dim=-1) if attn.shape[-1] > 1 else attn
)
x = (attn * v.transpose(-1, -2)).permute(0, 3, 1, 2).reshape(B, N, C)
        return x, (attn * v.transpose(-1, -2)).transpose(-1, -2)
def no_weight_decay(self):
return {"temperature"}
class FANBlock_SE(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
drop=0.0,
attn_drop=0.0,
sharpen_attn=False,
use_se=False,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
eta=1.0,
sr_ratio=1.0,
qk_scale=None,
linear=False,
downsample=None,
c_head_num=None,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = TokenMixing(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
mlp_hidden_dim=int(dim * mlp_ratio),
sharpen_attn=sharpen_attn,
attn_drop=attn_drop,
proj_drop=drop,
drop=drop,
drop_path=drop_path,
sr_ratio=sr_ratio,
linear=linear,
emlp=False,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = SEMlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.gamma1 = nn.Parameter(eta * flow.ones(dim), requires_grad=True)
self.gamma2 = nn.Parameter(eta * flow.ones(dim), requires_grad=True)
def forward(self, x, H: int, W: int, attn=None):
x_new, _ = self.attn(self.norm1(x), H, W)
x = x + self.drop_path(self.gamma1 * x_new)
x_new, H, W = self.mlp(self.norm2(x), H, W)
x = x + self.drop_path(self.gamma2 * x_new)
return x, H, W
class FANBlock(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
drop=0.0,
attn_drop=0.0,
sharpen_attn=False,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
eta=1.0,
sr_ratio=1.0,
downsample=None,
c_head_num=None,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = TokenMixing(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
mlp_hidden_dim=int(dim * mlp_ratio),
sharpen_attn=sharpen_attn,
attn_drop=attn_drop,
proj_drop=drop,
drop=drop,
drop_path=drop_path,
sr_ratio=sr_ratio,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = ChannelProcessing(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
drop_path=drop_path,
drop=drop,
mlp_hidden_dim=int(dim * mlp_ratio),
c_head_num=c_head_num,
)
self.gamma1 = nn.Parameter(eta * flow.ones(dim), requires_grad=True)
self.gamma2 = nn.Parameter(eta * flow.ones(dim), requires_grad=True)
self.downsample = downsample
self.H = None
self.W = None
def forward(self, x, attn=None, return_attention=False):
H, W = self.H, self.W
x_new, attn_s = self.attn(self.norm1(x), H, W)
x = x + self.drop_path(self.gamma1 * x_new)
x_new, attn_c = self.mlp(self.norm2(x), H, W, atten=attn)
x = x + self.drop_path(self.gamma2 * x_new)
if return_attention:
return x, attn_s
if self.downsample is not None:
x, H, W = self.downsample(x, H, W)
self.H, self.W = H, W
return x
class OverlapPatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2),
)
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(-1, -2).reshape(B, C, H, W)
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
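# Downsampling sketch: between stages with different channel_dims, FAN inserts an
# OverlapPatchEmbed (3x3 conv, stride 2, padding 1) that halves the token grid.
#
#   ds = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=384, embed_dim=448)
#   x, H, W = ds(flow.randn(1, 196, 384), 14, 14)           # x: (1, 49, 448), H = W = 7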
class FAN(nn.Module):
"""
Based on timm code bases
https://github.com/rwightman/pytorch-image-models/tree/master/timm
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
sharpen_attn=False,
channel_dims=None,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
sr_ratio=None,
backbone=None,
use_checkpoint=False,
act_layer=None,
norm_layer=None,
se_mlp=False,
cls_attn_layers=2,
use_pos_embed=True,
eta=1.0,
tokens_norm=False,
c_head_num=None,
hybrid_patch_size=2,
head_init_scale=1.0,
):
super().__init__()
img_size = to_2tuple(img_size)
self.use_checkpoint = use_checkpoint
        assert (img_size[0] % patch_size == 0) and (
            img_size[1] % patch_size == 0
        ), "`patch_size` should divide image dimensions evenly"
self.num_classes = num_classes
num_heads = (
[num_heads] * depth if not isinstance(num_heads, list) else num_heads
)
channel_dims = [embed_dim] * depth if channel_dims is None else channel_dims
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
        if backbone is None:
self.patch_embed = ConvPatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
act_layer=act_layer,
)
else:
self.patch_embed = HybridEmbed(
backbone=backbone, patch_size=hybrid_patch_size, embed_dim=embed_dim
)
self.use_pos_embed = use_pos_embed
if use_pos_embed:
self.pos_embed = PositionalEncodingFourier(dim=embed_dim)
self.pos_drop = nn.Dropout(p=drop_rate)
if se_mlp:
build_block = FANBlock_SE
else:
build_block = FANBlock
self.blocks = nn.ModuleList([])
for i in range(depth):
if i < depth - 1 and channel_dims[i] != channel_dims[i + 1]:
downsample = OverlapPatchEmbed(
img_size=img_size,
patch_size=3,
stride=2,
in_chans=channel_dims[i],
embed_dim=channel_dims[i + 1],
)
else:
downsample = None
self.blocks.append(
build_block(
dim=channel_dims[i],
num_heads=num_heads[i],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
sr_ratio=sr_ratio[i],
attn_drop=attn_drop_rate,
drop_path=drop_path_rate,
act_layer=act_layer,
norm_layer=norm_layer,
eta=eta,
downsample=downsample,
c_head_num=c_head_num[i] if c_head_num is not None else None,
)
)
self.num_features = self.embed_dim = channel_dims[i]
self.cls_token = nn.Parameter(flow.zeros(1, 1, channel_dims[i]))
self.cls_attn_blocks = nn.ModuleList(
[
ClassAttentionBlock(
dim=channel_dims[-1],
num_heads=num_heads[-1],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
act_layer=act_layer,
norm_layer=norm_layer,
eta=eta,
tokens_norm=tokens_norm,
)
for _ in range(cls_attn_layers)
]
)
# Classifier head
self.norm = norm_layer(channel_dims[i])
self.head = (
nn.Linear(self.num_features, num_classes)
if num_classes > 0
else nn.Identity()
)
# Init weights
trunc_normal_(self.cls_token, std=0.02)
self.apply(self._init_weights)
self.head.weight.data.mul_(head_init_scale)
self.head.bias.data.mul_(head_init_scale)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def no_weight_decay(self):
return {"pos_embed", "cls_token"} # , 'patch_embed'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=""):
self.num_classes = num_classes
self.head = (
nn.Linear(self.num_features, num_classes)
if num_classes > 0
else nn.Identity()
)
def forward_features(self, x):
B = x.shape[0]
x, (Hp, Wp) = self.patch_embed(x)
if self.use_pos_embed:
pos_encoding = (
self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
)
x = x + pos_encoding
x = self.pos_drop(x)
H, W = Hp, Wp
for blk in self.blocks:
blk.H, blk.W = H, W
x = blk(x)
H, W = blk.H, blk.W
cls_tokens = self.cls_token.expand(B, -1, -1)
x = flow.cat((cls_tokens, x), dim=1)
for blk in self.cls_attn_blocks:
x = blk(x)
x = self.norm(x)[:, 0]
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
    def get_last_selfattention(self, x, use_cls_attn=False, layer_idx=11):
        B = x.shape[0]
        x, (Hp, Wp) = self.patch_embed(x)
        if self.use_pos_embed:
            pos_encoding = (
                self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
            )
            x = x + pos_encoding
        x = self.pos_drop(x)
        return_idx = layer_idx or len(self.blocks) - 1
        for i, blk in enumerate(self.blocks):
            blk.H, blk.W = Hp, Wp
            if i == return_idx:
                # FANBlock reads H/W from attributes; with return_attention=True
                # it returns (x, token-mixing attention) without downsampling
                x, attn = blk(x, return_attention=True)
            else:
                x = blk(x)
            Hp, Wp = blk.H, blk.W
        if use_cls_attn:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = flow.cat((cls_tokens, x), dim=1)
            for i, blk in enumerate(self.cls_attn_blocks):
                if i < len(self.cls_attn_blocks) - 1:
                    x = blk(x)
                else:
                    attn = blk(x, return_attention=True)
            return attn
        else:
            return attn
# FAN-ViT Models
@ModelCreator.register_model
def fan_tiny_12_p16_224(pretrained: bool = False, progress: bool = True, **kwargs):
"""
Constructs FAN-ViT-tiny 224x224 model pretrained on ImageNet-1k.
.. note::
FAN-ViT-tiny 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> fan_tiny_12_p16_224 = flowvision.models.fan_tiny_12_p16_224(pretrained=False, progress=True)
"""
depth = 12
sr_ratio = [1] * (depth // 2) + [1] * (depth // 2)
model_kwargs = dict(
patch_size=16,
embed_dim=192,
depth=depth,
num_heads=4,
eta=1.0,
tokens_norm=True,
sharpen_attn=False,
**kwargs,
)
model = FAN(sr_ratio=sr_ratio, **model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls["fan_vit_tiny"], progress=progress
)
model.load_state_dict(state_dict)
return model
@ModelCreator.register_model
def fan_small_12_p16_224(pretrained: bool = False, progress: bool = True, **kwargs):
"""
Constructs FAN-ViT-small 224x224 model pretrained on ImageNet-1k.
.. note::
FAN-ViT-small 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> fan_small_12_p16_224 = flowvision.models.fan_small_12_p16_224(pretrained=False, progress=True)
"""
depth = 12
sr_ratio = [1] * depth
model_kwargs = dict(
patch_size=16,
embed_dim=384,
depth=depth,
num_heads=8,
eta=1.0,
tokens_norm=True,
**kwargs,
)
model = FAN(sr_ratio=sr_ratio, **model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls["fan_vit_small"], progress=progress
)
model.load_state_dict(state_dict)
return model
@ModelCreator.register_model
def fan_base_18_p16_224(pretrained: bool = False, progress: bool = True, **kwargs):
"""
Constructs FAN-ViT-base 224x224 model pretrained on ImageNet-1k.
.. note::
FAN-ViT-base 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> fan_base_18_p16_224 = flowvision.models.fan_base_18_p16_224(pretrained=False, progress=True)
"""
depth = 18
sr_ratio = [1] * (depth // 2) + [1] * (depth // 2)
model_kwargs = dict(
patch_size=16,
embed_dim=448,
depth=depth,
num_heads=8,
eta=1.0,
tokens_norm=True,
sharpen_attn=False,
**kwargs,
)
model = FAN(sr_ratio=sr_ratio, **model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls["fan_vit_base"], progress=progress
)
model.load_state_dict(state_dict)
return model
@ModelCreator.register_model
def fan_large_24_p16_224(pretrained: bool = False, progress: bool = True, **kwargs):
"""
Constructs FAN-ViT-large 224x224 model pretrained on ImageNet-1k.
.. note::
FAN-ViT-large 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> fan_large_24_p16_224 = flowvision.models.fan_large_24_p16_224(pretrained=False, progress=True)
"""
depth = 24
sr_ratio = [1] * (depth // 2) + [1] * (depth // 2)
model_kwargs = dict(
patch_size=16,
embed_dim=480,
depth=depth,
num_heads=10,
eta=1.0,
tokens_norm=True,
sharpen_attn=False,
**kwargs,
)
model = FAN(sr_ratio=sr_ratio, **model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls["fan_vit_large"], progress=progress
)
model.load_state_dict(state_dict)
return model
# FAN-Hybrid Models
# CNN backbones are based on the ConvNeXt architecture, using only the first two stages for downsampling.
# This has been verified to be beneficial for downstream tasks.
@ModelCreator.register_model
def fan_tiny_8_p4_hybrid(pretrained: bool = False, progress: bool = True, **kwargs):
"""
Constructs FAN-Hybrid-tiny 224x224 model pretrained on ImageNet-1k.
.. note::
FAN-Hybrid-tiny 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> fan_tiny_8_p4_hybrid = flowvision.models.fan_tiny_8_p4_hybrid(pretrained=False, progress=True)
"""
depth = 8
sr_ratio = [1] * (depth // 2) + [1] * (depth // 2 + 1)
model_args = dict(depths=[3, 3], dims=[128, 256, 512, 1024], use_head=False)
backbone = _create_hybrid_backbone(**model_args)
model_kwargs = dict(
patch_size=16,
embed_dim=192,
depth=depth,
num_heads=8,
eta=1.0,
tokens_norm=True,
sharpen_attn=False,
**kwargs,
)
model = FAN(sr_ratio=sr_ratio, backbone=backbone, **model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls["fan_hybrid_tiny"], progress=progress
)
model.load_state_dict(state_dict)
return model
@ModelCreator.register_model
def fan_small_12_p4_hybrid(pretrained: bool = False, progress: bool = True, **kwargs):
"""
Constructs FAN-Hybrid-small 224x224 model pretrained on ImageNet-1k.
.. note::
FAN-Hybrid-small 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> fan_small_12_p4_hybrid = flowvision.models.fan_small_12_p4_hybrid(pretrained=False, progress=True)
"""
depth = 10
channel_dims = [384] * 10 + [384] * (depth - 10)
sr_ratio = [1] * (depth // 2) + [1] * (depth // 2)
model_args = dict(depths=[3, 3], dims=[128, 256, 512, 1024], use_head=False)
backbone = _create_hybrid_backbone(**model_args)
model_kwargs = dict(
patch_size=16,
embed_dim=384,
depth=depth,
num_heads=8,
eta=1.0,
tokens_norm=True,
sharpen_attn=False,
**kwargs,
)
model = FAN(
sr_ratio=sr_ratio, backbone=backbone, channel_dims=channel_dims, **model_kwargs
)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls["fan_hybrid_small"], progress=progress
)
model.load_state_dict(state_dict)
return model
@ModelCreator.register_model
def fan_base_16_p4_hybrid(pretrained: bool = False, progress: bool = True, **kwargs):
"""
Constructs FAN-Hybrid-base 224x224 model pretrained on ImageNet-1k.
.. note::
FAN-Hybrid-base 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> fan_base_16_p4_hybrid = flowvision.models.fan_base_16_p4_hybrid(pretrained=False, progress=True)
"""
depth = 16
sr_ratio = [1] * (depth // 2) + [1] * (depth // 2)
model_args = dict(depths=[3, 3], dims=[128, 256, 512, 1024], use_head=False)
backbone = _create_hybrid_backbone(**model_args)
model_kwargs = dict(
patch_size=16,
embed_dim=448,
depth=depth,
num_heads=8,
eta=1.0,
tokens_norm=True,
sharpen_attn=False,
**kwargs,
)
model = FAN(sr_ratio=sr_ratio, backbone=backbone, **model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls["fan_hybrid_base"], progress=progress
)
model.load_state_dict(state_dict)
return model
@ModelCreator.register_model
def fan_large_16_p4_hybrid(pretrained: bool = False, progress: bool = True, **kwargs):
"""
Constructs FAN-Hybrid-large 224x224 model pretrained on ImageNet-1k.
.. note::
FAN-Hybrid-large 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> fan_large_16_p4_hybrid = flowvision.models.fan_large_16_p4_hybrid(pretrained=False, progress=True)
"""
depth = 22
sr_ratio = [1] * (depth // 2) + [1] * (depth // 2)
model_args = dict(depths=[3, 5], dims=[128, 256, 512, 1024], use_head=False)
backbone = _create_hybrid_backbone(**model_args)
model_kwargs = dict(
patch_size=16,
embed_dim=480,
depth=depth,
num_heads=10,
eta=1.0,
tokens_norm=True,
sharpen_attn=False,
head_init_scale=0.001,
**kwargs,
)
model = FAN(sr_ratio=sr_ratio, backbone=backbone, **model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls["fan_hybrid_large"], progress=progress
)
model.load_state_dict(state_dict)
return model
@ModelCreator.register_model
def fan_base_16_p4_hybrid_in22k_1k(
pretrained: bool = False, progress: bool = True, **kwargs
):
"""
    Constructs FAN-Hybrid-base 224x224 model pretrained on ImageNet-22k and fine-tuned on ImageNet-1k.
.. note::
FAN-Hybrid-base 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> fan_base_16_p4_hybrid_in22k_1k = flowvision.models.fan_base_16_p4_hybrid_in22k_1k(pretrained=False, progress=True)
"""
depth = 16
sr_ratio = [1] * (depth // 2) + [1] * (depth // 2)
model_args = dict(depths=[3, 3], dims=[128, 256, 512, 1024], use_head=False)
backbone = _create_hybrid_backbone(**model_args)
model_kwargs = dict(
patch_size=16,
embed_dim=448,
depth=depth,
num_heads=8,
eta=1.0,
tokens_norm=True,
sharpen_attn=False,
**kwargs,
)
model = FAN(sr_ratio=sr_ratio, backbone=backbone, **model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls["fan_hybrid_base_in22k_1k"], progress=progress
)
model.load_state_dict(state_dict)
return model
@ModelCreator.register_model
def fan_base_16_p4_hybrid_in22k_1k_384(
pretrained: bool = False, progress: bool = True, **kwargs
):
"""
    Constructs FAN-Hybrid-base 384x384 model pretrained on ImageNet-22k and fine-tuned on ImageNet-1k.
.. note::
FAN-Hybrid-base 384x384 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> fan_base_16_p4_hybrid_in22k_1k_384 = flowvision.models.fan_base_16_p4_hybrid_in22k_1k_384(pretrained=False, progress=True)
"""
depth = 16
sr_ratio = [1] * (depth // 2) + [1] * (depth // 2)
model_args = dict(depths=[3, 3], dims=[128, 256, 512, 1024], use_head=False)
backbone = _create_hybrid_backbone(**model_args)
model_kwargs = dict(
patch_size=16,
embed_dim=448,
depth=depth,
num_heads=8,
eta=1.0,
tokens_norm=True,
sharpen_attn=False,
**kwargs,
)
model = FAN(img_size=384, sr_ratio=sr_ratio, backbone=backbone, **model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls["fan_hybrid_base_in22k_1k_384"], progress=progress
)
model.load_state_dict(state_dict)
return model
@ModelCreator.register_model
def fan_large_16_p4_hybrid_in22k_1k(
pretrained: bool = False, progress: bool = True, **kwargs
):
"""
    Constructs FAN-Hybrid-large 224x224 model pretrained on ImageNet-22k and fine-tuned on ImageNet-1k.
.. note::
FAN-Hybrid-large 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> fan_large_16_p4_hybrid_in22k_1k = flowvision.models.fan_large_16_p4_hybrid_in22k_1k(pretrained=False, progress=True)
"""
depth = 22
sr_ratio = [1] * (depth // 2) + [1] * (depth // 2)
model_args = dict(depths=[3, 5], dims=[128, 256, 512, 1024], use_head=False)
backbone = _create_hybrid_backbone(**model_args)
model_kwargs = dict(
patch_size=16,
embed_dim=480,
depth=depth,
num_heads=10,
eta=1.0,
tokens_norm=True,
sharpen_attn=False,
head_init_scale=0.001,
**kwargs,
)
model = FAN(sr_ratio=sr_ratio, backbone=backbone, **model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls["fan_hybrid_large_in22k_1k"], progress=progress
)
model.load_state_dict(state_dict)
return model
@ModelCreator.register_model
def fan_large_16_p4_hybrid_in22k_1k_384(
pretrained: bool = False, progress: bool = True, **kwargs
):
"""
    Constructs FAN-Hybrid-large 384x384 model pretrained on ImageNet-22k and fine-tuned on ImageNet-1k.
.. note::
FAN-Hybrid-large 384x384 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> fan_large_16_p4_hybrid_in22k_1k_384 = flowvision.models.fan_large_16_p4_hybrid_in22k_1k_384(pretrained=False, progress=True)
"""
depth = 22
sr_ratio = [1] * (depth // 2) + [1] * (depth // 2)
model_args = dict(depths=[3, 5], dims=[128, 256, 512, 1024], use_head=False)
backbone = _create_hybrid_backbone(**model_args)
model_kwargs = dict(
patch_size=16,
embed_dim=480,
depth=depth,
num_heads=10,
eta=1.0,
tokens_norm=True,
sharpen_attn=False,
head_init_scale=0.001,
**kwargs,
)
model = FAN(img_size=384, sr_ratio=sr_ratio, backbone=backbone, **model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls["fan_hybrid_large_in22k_1k_384"], progress=progress
)
model.load_state_dict(state_dict)
return model