"""
Modified from https://github.com/facebookresearch/deit/blob/main/models_v2.py
"""
import collections.abc
from itertools import repeat
import oneflow
import oneflow.nn as nn
import oneflow.nn.functional as F
from functools import partial
from .vision_transformer import Mlp, PatchEmbed
from ..layers import DropPath, trunc_normal_
from .utils import load_state_dict_from_url
from .registry import ModelCreator
# Download locations of the DeiT III checkpoints converted for OneFlow.
# Every archive lives under the same bucket prefix and is named after its key.
model_urls = {
    ckpt: "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/DeiT_III/"
    + ckpt
    + ".zip"
    for ckpt in (
        "deit_3_base_224_1k",
        "deit_3_base_224_21k",
        "deit_3_base_384_1k",
        "deit_3_base_384_21k",
        "deit_3_huge_224_1k",
        "deit_3_huge_224_21k_v1",
        "deit_3_large_224_1k",
        "deit_3_large_224_21k",
        "deit_3_large_384_1k",
        "deit_3_large_384_21k",
        "deit_3_small_224_1k",
        "deit_3_small_224_21k",
        "deit_3_small_384_1k",
        "deit_3_small_384_21k",
    )
}
def to_2tuple(x):
    """Normalize ``x`` to a tuple of (height, width)-style values.

    A non-string iterable (tuple, list, generator, ...) is converted to a
    tuple as-is; any other value, including a string, is repeated twice.

    Examples:
        >>> to_2tuple(224)
        (224, 224)
        >>> to_2tuple([16, 8])
        (16, 8)
    """
    # Strings are iterable but should be treated as a single scalar value.
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        # Materialize so callers can index even when given a one-shot iterator.
        return tuple(x)
    return tuple(repeat(x, 2))
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of tokens.

    Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Args:
        dim: Token embedding width; split evenly across the heads.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        qk_scale: Optional override for the query scaling factor. When falsy,
            ``(dim // num_heads) ** -0.5`` is used.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        # Falsy overrides (None or 0) fall back to the standard 1/sqrt(d) scale.
        self.scale = qk_scale if qk_scale else per_head ** -0.5
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, tokens, dim); returns the same shape."""
        batch, tokens, width = x.shape
        heads = self.num_heads
        # (B, N, 3C) -> (B, N, 3, heads, C/heads) -> (3, B, heads, N, C/heads)
        packed = self.qkv(x).reshape(batch, tokens, 3, heads, width // heads)
        packed = packed.permute(2, 0, 3, 1, 4)
        query, key, value = packed[0], packed[1], packed[2]
        scores = (query * self.scale) @ key.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))
        # Merge heads back into the channel dimension.
        mixed = (weights @ value).transpose(1, 2).reshape(batch, tokens, width)
        return self.proj_drop(self.proj(mixed))
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP residual branches.

    Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Both residual branches share one stochastic-depth module (``drop_path``),
    which is an identity when ``drop_path`` is not positive. ``init_values`` is
    accepted only for signature compatibility with the LayerScale blocks and
    is not used.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth on the residual branches.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Apply the attention branch, then the MLP branch, each with a residual add."""
        attn_branch = self.drop_path(self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.drop_path(self.mlp(self.norm2(x)))
        return x + mlp_branch
class Layer_scale_init_Block(nn.Module):
    """Pre-norm transformer block with LayerScale on both residual branches.

    Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications.

    Each branch output is multiplied element-wise by a learnable per-channel
    vector (``gamma_1`` for attention, ``gamma_2`` for the MLP) that starts at
    ``init_values``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()

        def _layer_scale():
            # Per-channel learnable scale, constant-initialized to init_values.
            return nn.Parameter(init_values * oneflow.ones(dim), requires_grad=True)

        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth on the residual branches.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.gamma_1 = _layer_scale()
        self.gamma_2 = _layer_scale()

    def forward(self, x):
        """Apply the scaled attention branch, then the scaled MLP branch."""
        attn_branch = self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x + mlp_branch
class Layer_scale_init_Block_paralx2(nn.Module):
    """LayerScale block with two parallel attention and two parallel MLP branches.

    Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications.

    Both attention branches read the same input and their outputs are summed
    into the residual stream; the two MLP branches then do the same. Every
    branch has its own norm layer and its own LayerScale vector initialized to
    ``init_values``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        attn_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        mlp_kwargs = dict(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        def _layer_scale():
            # Per-channel learnable scale, constant-initialized to init_values.
            return nn.Parameter(init_values * oneflow.ones(dim), requires_grad=True)

        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_kwargs)
        self.attn1 = Attention_block(dim, **attn_kwargs)
        # Stochastic depth shared by all four branches.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        self.mlp = Mlp_block(**mlp_kwargs)
        self.mlp1 = Mlp_block(**mlp_kwargs)
        self.gamma_1 = _layer_scale()
        self.gamma_1_1 = _layer_scale()
        self.gamma_2 = _layer_scale()
        self.gamma_2_1 = _layer_scale()

    def forward(self, x):
        """Add both scaled attention branches, then both scaled MLP branches."""
        # Both attention branches see the same (pre-update) input.
        attn_paths = [
            self.drop_path(gamma * layer(norm(x)))
            for norm, layer, gamma in (
                (self.norm1, self.attn, self.gamma_1),
                (self.norm11, self.attn1, self.gamma_1_1),
            )
        ]
        x = x + attn_paths[0] + attn_paths[1]
        mlp_paths = [
            self.drop_path(gamma * layer(norm(x)))
            for norm, layer, gamma in (
                (self.norm2, self.mlp, self.gamma_2),
                (self.norm21, self.mlp1, self.gamma_2_1),
            )
        ]
        return x + mlp_paths[0] + mlp_paths[1]
class Block_paralx2(nn.Module):
    """Transformer block with two parallel attention and two parallel MLP branches.

    Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications.

    Both attention branches read the same input and their outputs are summed
    into the residual stream; the two MLP branches then do the same. Each
    branch has its own norm layer. ``init_values`` is accepted only for
    signature compatibility and is not used.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        attn_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        mlp_kwargs = dict(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_kwargs)
        self.attn1 = Attention_block(dim, **attn_kwargs)
        # Stochastic depth shared by all four branches.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        self.mlp = Mlp_block(**mlp_kwargs)
        self.mlp1 = Mlp_block(**mlp_kwargs)

    def forward(self, x):
        """Add both attention branches, then both MLP branches."""
        # Both attention branches see the same (pre-update) input.
        attn_paths = [
            self.drop_path(layer(norm(x)))
            for norm, layer in ((self.norm1, self.attn), (self.norm11, self.attn1))
        ]
        x = x + attn_paths[0] + attn_paths[1]
        mlp_paths = [
            self.drop_path(layer(norm(x)))
            for norm, layer in ((self.norm2, self.mlp), (self.norm21, self.mlp1))
        ]
        return x + mlp_paths[0] + mlp_paths[1]
class hMLP_stem(nn.Module):
    """hMLP_stem: https://arxiv.org/pdf/2203.09795.pdf

    Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications.

    Three non-overlapping convolutions (kernel = stride = 4, 2, 2) downsample
    the image by a total factor of 16. Each convolution is followed by
    ``norm_layer``, and all but the last are followed by GELU.

    Note:
        ``patch_size`` is only used to compute ``num_patches``; the convolution
        strides are fixed. ``num_patches`` therefore matches the number of
        tokens actually produced only when ``patch_size`` is 16.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=nn.BatchNorm2d,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        hidden_dim = embed_dim // 4
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, hidden_dim, kernel_size=4, stride=4),
            norm_layer(hidden_dim),
            nn.GELU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=2, stride=2),
            norm_layer(hidden_dim),
            nn.GELU(),
            nn.Conv2d(hidden_dim, embed_dim, kernel_size=2, stride=2),
            norm_layer(embed_dim),
        )

    def forward(self, x):
        """Map images of shape (B, C, H, W) to tokens of shape (B, (H//16)*(W//16), embed_dim)."""
        # proj output is (B, embed_dim, H//16, W//16); flatten the spatial grid
        # and move the token axis in front of the channel axis.
        return self.proj(x).flatten(2).transpose(1, 2)
class vit_models(nn.Module):
    """Vision Transformer with LayerScale (https://arxiv.org/abs/2103.17239) support.

    Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications.

    The positional embedding covers only the patch tokens: it is added before
    the class token is prepended (see ``forward_features``). Every block uses
    the same stochastic-depth rate ``drop_path_rate``. ``drop_rate`` is not
    applied inside the blocks (they receive ``drop=0.0``); it is used as a
    single dropout on the pooled features right before the classifier head.

    Note:
        ``global_pool``, ``dpr_constant`` and ``mlp_ratio_clstk`` are accepted
        for signature compatibility but are not used by this implementation.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        global_pool=None,
        block_layers=Block,
        Patch_layer=PatchEmbed,
        act_layer=nn.GELU,
        Attention_block=Attention,
        Mlp_block=Mlp,
        dpr_constant=True,
        init_scale=1e-4,
        mlp_ratio_clstk=4.0,
    ):
        super().__init__()
        self.dropout_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.patch_embed = Patch_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(oneflow.zeros(1, 1, embed_dim))
        # One positional vector per patch token; the class token gets none.
        self.pos_embed = nn.Parameter(oneflow.zeros(1, num_patches, embed_dim))
        self.blocks = nn.ModuleList(
            [
                block_layers(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=0.0,
                    attn_drop=attn_drop_rate,
                    # Constant (not depth-increasing) stochastic-depth rate.
                    drop_path=drop_path_rate,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    Attention_block=Attention_block,
                    Mlp_block=Mlp_block,
                    init_values=init_scale,
                )
                for _ in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)
        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module="head")]
        self.head = (
            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )
        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear (truncated normal, zero bias) and LayerNorm (unit weight, zero bias)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # A LayerNorm built without affine parameters may expose weight/bias
            # as None; skip those instead of failing inside constant_().
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0)
            if getattr(m, "weight", None) is not None:
                nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {"pos_embed", "cls_token"}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    def reset_classifier(self, num_classes, global_pool=""):
        """Replace the head with a fresh one for ``num_classes`` outputs.

        ``global_pool`` is accepted for interface compatibility and ignored.
        """
        self.num_classes = num_classes
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

    def forward_features(self, x):
        """Return the normalized class-token features, shape (B, embed_dim)."""
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        # Positions are added to patch tokens only, before the class token is prepended.
        x = x + self.pos_embed
        x = oneflow.cat((cls_tokens, x), dim=1)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        """Compute class logits (or features when the head is an Identity)."""
        x = self.forward_features(x)
        if self.dropout_rate:
            x = F.dropout(x, p=float(self.dropout_rate), training=self.training)
        x = self.head(x)
        return x
@ModelCreator.register_model
def deit_small_patch16_LS_224(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Small-patch16-LS-224 model.

    .. note::
        DeiT-Small-patch16-LS-224 model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 224x224.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_small_patch16_LS_224 = flowvision.models.deit_small_patch16_LS_224(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=224,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_small_224_1k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_small_patch16_LS_384(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Small-patch16-LS-384 model.

    .. note::
        DeiT-Small-patch16-LS-384 model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 384x384.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_small_patch16_LS_384 = flowvision.models.deit_small_patch16_LS_384(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=384,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_small_384_1k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_small_patch16_LS_224_in21k(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Small-patch16-LS-224 ImageNet21k pretrained model.

    .. note::
        DeiT-Small-patch16-LS-224 ImageNet21k pretrained model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 224x224.

    Args:
        pretrained (bool): Whether to download the ImageNet21k pre-trained weights. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_small_patch16_LS_224_in21k = flowvision.models.deit_small_patch16_LS_224_in21k(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=224,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_small_224_21k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_small_patch16_LS_384_in21k(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Small-patch16-LS-384 ImageNet21k pretrained model.

    .. note::
        DeiT-Small-patch16-LS-384 ImageNet21k pretrained model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 384x384.

    Args:
        pretrained (bool): Whether to download the ImageNet21k pre-trained weights. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_small_patch16_LS_384_in21k = flowvision.models.deit_small_patch16_LS_384_in21k(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=384,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_small_384_21k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_base_patch16_LS_224(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Base-patch16-LS-224 model.

    .. note::
        DeiT-Base-patch16-LS-224 model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 224x224.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_base_patch16_LS_224 = flowvision.models.deit_base_patch16_LS_224(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_base_224_1k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_base_patch16_LS_384(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Base-patch16-LS-384 model.

    .. note::
        DeiT-Base-patch16-LS-384 model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 384x384.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_base_patch16_LS_384 = flowvision.models.deit_base_patch16_LS_384(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_base_384_1k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_base_patch16_LS_224_in21k(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Base-patch16-LS-224 ImageNet21k pretrained model.

    .. note::
        DeiT-Base-patch16-LS-224 ImageNet21k pretrained model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 224x224.

    Args:
        pretrained (bool): Whether to download the ImageNet21k pre-trained weights. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_base_patch16_LS_224_in21k = flowvision.models.deit_base_patch16_LS_224_in21k(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_base_224_21k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_base_patch16_LS_384_in21k(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Base-patch16-LS-384 ImageNet21k pretrained model.

    .. note::
        DeiT-Base-patch16-LS-384 ImageNet21k pretrained model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 384x384.

    Args:
        pretrained (bool): Whether to download the ImageNet21k pre-trained weights. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_base_patch16_LS_384_in21k = flowvision.models.deit_base_patch16_LS_384_in21k(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_base_384_21k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_large_patch16_LS_224(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Large-patch16-LS-224 model.

    .. note::
        DeiT-Large-patch16-LS-224 model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 224x224.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_large_patch16_LS_224 = flowvision.models.deit_large_patch16_LS_224(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_large_224_1k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_large_patch16_LS_384(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Large-patch16-LS-384 model.

    .. note::
        DeiT-Large-patch16-LS-384 model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 384x384.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_large_patch16_LS_384 = flowvision.models.deit_large_patch16_LS_384(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=384,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_large_384_1k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_large_patch16_LS_224_in21k(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Large-patch16-LS-224 ImageNet21k pretrained model.

    .. note::
        DeiT-Large-patch16-LS-224 ImageNet21k pretrained model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 224x224.

    Args:
        pretrained (bool): Whether to download the ImageNet21k pre-trained weights. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_large_patch16_LS_224_in21k = flowvision.models.deit_large_patch16_LS_224_in21k(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_large_224_21k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_large_patch16_LS_384_in21k(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Large-patch16-LS-384 ImageNet21k pretrained model.

    .. note::
        DeiT-Large-patch16-LS-384 ImageNet21k pretrained model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 384x384.

    Args:
        pretrained (bool): Whether to download the ImageNet21k pre-trained weights. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_large_patch16_LS_384_in21k = flowvision.models.deit_large_patch16_LS_384_in21k(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=384,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_large_384_21k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_huge_patch14_LS_224(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Huge-patch14-LS-224 model.

    .. note::
        DeiT-Huge-patch14-LS-224 model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 224x224.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_huge_patch14_LS_224 = flowvision.models.deit_huge_patch14_LS_224(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=224,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_huge_224_1k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def deit_huge_patch14_LS_224_in21k(pretrained=False, progress=True, **kwargs):
    """
    Constructs the DeiT-Huge-patch14-LS-224 ImageNet21k pretrained model.

    .. note::
        DeiT-Huge-patch14-LS-224 ImageNet21k pretrained model from `"DeiT III: Revenge of the ViT" <https://arxiv.org/pdf/2204.07118.pdf>`_.
        The required input size of the model is 224x224.

    Args:
        pretrained (bool): Whether to download the ImageNet21k pre-trained weights. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Extra keyword arguments forwarded to ``vit_models``.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> deit_huge_patch14_LS_224_in21k = flowvision.models.deit_huge_patch14_LS_224_in21k(pretrained=False, progress=True)

    """
    model = vit_models(
        img_size=224,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["deit_3_huge_224_21k_v1"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model