import math
from functools import partial
import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.init as init
from flowvision.layers import lecun_normal_, DropPath, PatchEmbed
from .helpers import named_apply
from .utils import load_state_dict_from_url
from .registry import ModelCreator
model_urls = {
"mlp_mixer_s16_224": None,
"mlp_mixer_s32_224": None,
"mlp_mixer_b16_224": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Mlp-Mixer/mlp_mixer_b16_224.zip",
"mlp_mixer_b32_224": None,
"mlp_mixer_b16_224_in21k": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Mlp-Mixer/mlp_mixer_b16_224_in21k.zip",
"mlp_mixer_l16_224": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Mlp-Mixer/mlp_mixer_l16_224.zip",
"mlp_mixer_l32_224": None,
"mlp_mixer_l16_224_in21k": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Mlp-Mixer/mlp_mixer_l16_224_in21k.zip",
"mlp_mixer_b16_224_miil": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Mlp-Mixer/mlp_mixer_b16_224_miil.zip",
"mlp_mixer_b16_224_miil_in21k": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Mlp-Mixer/mlp_mixer_b16_224_miil_in21k.zip",
"gmlp_ti16_224": None,
"gmlp_s16_224": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/Mlp-Mixer/gmlp_s16_224.zip",
"gmlp_b16_224": None,
}
# helpers
def pair(x):
    return x if isinstance(x, tuple) else (x, x)
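# For example, pair(4) returns (4, 4), while pair((0.5, 4.0)) is returned unchanged.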
class Mlp(nn.Module):
"""
You can also import Mlp Block in flowvision.layers.blocks like this:
from flowvision.layers.blocks import Mlp
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
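# A minimal usage sketch for Mlp (the shapes below are illustrative assumptions,
# not part of the API); the block acts on the last dimension only:
#
#     mlp = Mlp(in_features=512, hidden_features=2048)
#     y = mlp(flow.randn(1, 196, 512))  # shape: (1, 196, 512)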
class GatedMlp(nn.Module):
""" MLP as used in gMLP
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
gate_layer=None,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.drop1 = nn.Dropout(drop)
if gate_layer is not None:
assert hidden_features % 2 == 0
self.gate = gate_layer(hidden_features)
hidden_features = hidden_features // 2
else:
self.gate = nn.Identity()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.gate(x)
x = self.fc2(x)
x = self.drop2(x)
return x
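# A minimal usage sketch for GatedMlp (illustrative sizes, assumed for this
# example): with a gate layer, the hidden width is halved by the gate, so fc2
# maps hidden_features // 2 back to out_features. SpatialGatingUnit is defined
# just below; partial is imported at the top of this module.
#
#     sgu = partial(SpatialGatingUnit, num_patches=196)
#     gmlp = GatedMlp(256, hidden_features=1536, gate_layer=sgu)
#     y = gmlp(flow.randn(1, 196, 256))  # shape: (1, 196, 256)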
class SpatialGatingUnit(nn.Module):
""" Spatial Gating Unit
"""
def __init__(self, dim, num_patches, norm_layer=nn.LayerNorm):
super().__init__()
gate_dim = dim // 2
self.norm = norm_layer(gate_dim)
self.proj = nn.Linear(num_patches, num_patches)
def init_weights(self):
        # special init for the projection gate; called as an override by the base model init
nn.init.normal_(self.proj.weight, std=1e-6)
nn.init.ones_(self.proj.bias)
    def forward(self, x):
        # TODO: use the chunk op once it is fixed:
        # u, v = x.chunk(2, dim=-1)
        u, v = flow.split(x, x.size(-1) // 2, dim=-1)
v = self.norm(v)
v = self.proj(v.transpose(-1, -2))
return u * v.transpose(-1, -2)
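# A minimal usage sketch for SpatialGatingUnit (illustrative sizes, assumed for
# this example); note that the output channel dimension is dim // 2:
#
#     sgu = SpatialGatingUnit(dim=1536, num_patches=196)
#     y = sgu(flow.randn(1, 196, 1536))  # shape: (1, 196, 768)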
class SpatialGatingBlock(nn.Module):
""" Residual Block w/ Spatial Gating
Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
"""
def __init__(
self,
dim,
num_patches,
mlp_ratio=4,
mlp_layer=GatedMlp,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
drop=0.0,
drop_path=0.0,
):
super().__init__()
channel_dim = int(dim * mlp_ratio)
self.norm = norm_layer(dim)
sgu = partial(SpatialGatingUnit, num_patches=num_patches)
self.mlp_channels = mlp_layer(
dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x):
x = x + self.drop_path(self.mlp_channels(self.norm(x)))
return x
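# A minimal usage sketch for SpatialGatingBlock (illustrative sizes, assumed for
# this example); the residual connection keeps the input shape unchanged:
#
#     blk = SpatialGatingBlock(dim=256, num_patches=196, mlp_ratio=6)
#     y = blk(flow.randn(1, 196, 256))  # shape: (1, 196, 256)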
class MixerBlock(nn.Module):
""" Residual Block w/ token mixing and channel MLPs
Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
"""
def __init__(
self,
dim,
num_patches,
mlp_ratio=(0.5, 4.0),
mlp_layer=Mlp,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
drop=0.0,
drop_path=0.0,
):
super().__init__()
tokens_dim, channels_dim = [int(x * dim) for x in pair(mlp_ratio)]
self.norm1 = norm_layer(dim)
self.mlp_tokens = mlp_layer(
num_patches, tokens_dim, act_layer=act_layer, drop=drop
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(
self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)
)
x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
return x
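# A minimal usage sketch for MixerBlock (illustrative sizes, assumed for this
# example): token mixing operates across the patch dimension, channel mixing
# across the embedding dimension, and both paths are residual:
#
#     blk = MixerBlock(dim=512, num_patches=196)
#     y = blk(flow.randn(1, 196, 512))  # shape: (1, 196, 512)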
class MlpMixer(nn.Module):
def __init__(
self,
num_classes=1000,
img_size=224,
in_chans=3,
patch_size=16,
num_blocks=8,
embed_dim=512,
mlp_ratio=(0.5, 4.0),
block_layer=MixerBlock,
mlp_layer=Mlp,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
drop_rate=0.0,
drop_path_rate=0.0,
nlhb=False,
stem_norm=False,
):
super().__init__()
self.num_classes = num_classes
        # num_features kept for consistency with other models
        self.num_features = self.embed_dim = embed_dim
self.stem = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if stem_norm else None,
)
        # TODO: make the drop-path-rate rule consistent with the original repo
self.blocks = nn.Sequential(
*[
block_layer(
embed_dim,
self.stem.num_patches,
mlp_ratio,
mlp_layer=mlp_layer,
norm_layer=norm_layer,
act_layer=act_layer,
drop=drop_rate,
drop_path=drop_path_rate,
)
for _ in range(num_blocks)
]
)
self.norm = norm_layer(embed_dim)
self.head = (
nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
)
self.init_weights(nlhb=nlhb)
def init_weights(self, nlhb=False):
head_bias = -math.log(self.num_classes) if nlhb else 0.0
named_apply(
partial(_init_weights, head_bias=head_bias), module=self
) # depth-first
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=""):
self.num_classes = num_classes
self.head = (
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
)
def forward_features(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.norm(x)
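        # global average pooling over patch tokens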
x = x.mean(dim=1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
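# A minimal usage sketch for MlpMixer (the constructor defaults above correspond
# to the Mixer-S/16 configuration; the input shape below is an assumption for
# illustration):
#
#     model = MlpMixer(num_classes=1000, img_size=224, patch_size=16,
#                      num_blocks=8, embed_dim=512)
#     logits = model(flow.randn(1, 3, 224, 224))  # shape: (1, 1000)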
def _init_weights(module: nn.Module, name: str, head_bias: float = 0.0, flax=False):
""" Mixer weight initialization (trying to match Flax defaults)
"""
if isinstance(module, nn.Linear):
if name.startswith("head"):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
if flax:
# Flax defaults
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
else:
                # like the MLP init in ViT
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
if "mlp" in name:
nn.init.normal_(module.bias, std=1e-6)
else:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif hasattr(module, "init_weights"):
# NOTE if a parent module contains init_weights method, it can override the init of the
# child modules as this will be called in depth-first order.
module.init_weights()
def _create_mlp_mixer(arch, pretrained=False, progress=True, **model_kwargs):
    model = MlpMixer(**model_kwargs)
    if pretrained:
        if model_urls[arch] is None:
            raise ValueError(
                "No pretrained weights are available for {}".format(arch)
            )
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def mlp_mixer_s16_224(pretrained=False, progress=True, **kwargs):
"""
Constructs the Mixer-S/16 224x224 model.
.. note::
Mixer-S/16 224x224 model from `"MLP-Mixer: An all-MLP Architecture for Vision" <https://arxiv.org/pdf/2105.01601.pdf>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> mlp_mixer_s16_224 = flowvision.models.mlp_mixer_s16_224(pretrained=False, progress=True)
"""
model_kwargs = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs)
return _create_mlp_mixer(
"mlp_mixer_s16_224", pretrained=pretrained, progress=progress, **model_kwargs
)
@ModelCreator.register_model
def mlp_mixer_s32_224(pretrained=False, progress=True, **kwargs):
"""
Constructs the Mixer-S/32 224x224 model.
.. note::
Mixer-S/32 224x224 model from `"MLP-Mixer: An all-MLP Architecture for Vision" <https://arxiv.org/pdf/2105.01601.pdf>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> mlp_mixer_s32_224 = flowvision.models.mlp_mixer_s32_224(pretrained=False, progress=True)
"""
model_kwargs = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs)
return _create_mlp_mixer(
"mlp_mixer_s32_224", pretrained=pretrained, progress=progress, **model_kwargs
)
@ModelCreator.register_model
def mlp_mixer_b16_224(pretrained=False, progress=True, **kwargs):
"""
Constructs the Mixer-B/16 224x224 model.
.. note::
Mixer-B/16 224x224 model from `"MLP-Mixer: An all-MLP Architecture for Vision" <https://arxiv.org/pdf/2105.01601.pdf>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> mlp_mixer_b16_224 = flowvision.models.mlp_mixer_b16_224(pretrained=False, progress=True)
"""
model_kwargs = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
return _create_mlp_mixer(
"mlp_mixer_b16_224", pretrained=pretrained, progress=progress, **model_kwargs
)
@ModelCreator.register_model
def mlp_mixer_b32_224(pretrained=False, progress=True, **kwargs):
"""
Constructs the Mixer-B/32 224x224 model.
.. note::
Mixer-B/32 224x224 model from `"MLP-Mixer: An all-MLP Architecture for Vision" <https://arxiv.org/pdf/2105.01601.pdf>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> mlp_mixer_b32_224 = flowvision.models.mlp_mixer_b32_224(pretrained=False, progress=True)
"""
model_kwargs = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs)
return _create_mlp_mixer(
"mlp_mixer_b32_224", pretrained=pretrained, progress=progress, **model_kwargs
)
@ModelCreator.register_model
def mlp_mixer_b16_224_in21k(pretrained=False, progress=True, **kwargs):
"""
Constructs the Mixer-B/16 224x224 ImageNet21k pretrained model.
.. note::
Mixer-B/16 224x224 ImageNet21k pretrained model from `"MLP-Mixer: An all-MLP Architecture for Vision" <https://arxiv.org/pdf/2105.01601.pdf>`_.
        Note that this model is pretrained on ImageNet21k and intended for fine-tuning on downstream datasets.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> mlp_mixer_b16_224_in21k = flowvision.models.mlp_mixer_b16_224_in21k(pretrained=False, progress=True)
"""
model_kwargs = dict(
num_classes=21843, patch_size=16, num_blocks=12, embed_dim=768, **kwargs
)
return _create_mlp_mixer(
"mlp_mixer_b16_224_in21k",
pretrained=pretrained,
progress=progress,
**model_kwargs
)
@ModelCreator.register_model
def mlp_mixer_l16_224(pretrained=False, progress=True, **kwargs):
"""
Constructs the Mixer-L/16 224x224 model.
.. note::
Mixer-L/16 224x224 model from `"MLP-Mixer: An all-MLP Architecture for Vision" <https://arxiv.org/pdf/2105.01601.pdf>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> mlp_mixer_l16_224 = flowvision.models.mlp_mixer_l16_224(pretrained=False, progress=True)
"""
model_kwargs = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
return _create_mlp_mixer(
"mlp_mixer_l16_224", pretrained=pretrained, progress=progress, **model_kwargs
)
@ModelCreator.register_model
def mlp_mixer_l32_224(pretrained=False, progress=True, **kwargs):
"""
Constructs the Mixer-L/32 224x224 model.
.. note::
Mixer-L/32 224x224 model from `"MLP-Mixer: An all-MLP Architecture for Vision" <https://arxiv.org/pdf/2105.01601.pdf>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> mlp_mixer_l32_224 = flowvision.models.mlp_mixer_l32_224(pretrained=False, progress=True)
"""
model_kwargs = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs)
return _create_mlp_mixer(
"mlp_mixer_l32_224", pretrained=pretrained, progress=progress, **model_kwargs
)
@ModelCreator.register_model
def mlp_mixer_l16_224_in21k(pretrained=False, progress=True, **kwargs):
"""
Constructs the Mixer-L/16 224x224 ImageNet21k pretrained model.
.. note::
Mixer-L/16 224x224 ImageNet21k pretrained model from `"MLP-Mixer: An all-MLP Architecture for Vision" <https://arxiv.org/pdf/2105.01601.pdf>`_.
        Note that this model is pretrained on ImageNet21k and intended for fine-tuning on downstream datasets.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> mlp_mixer_l16_224_in21k = flowvision.models.mlp_mixer_l16_224_in21k(pretrained=False, progress=True)
"""
model_kwargs = dict(
num_classes=21843, patch_size=16, num_blocks=24, embed_dim=1024, **kwargs
)
return _create_mlp_mixer(
"mlp_mixer_l16_224_in21k",
pretrained=pretrained,
progress=progress,
**model_kwargs
)
@ModelCreator.register_model
def mlp_mixer_b16_224_miil(pretrained=False, progress=True, **kwargs):
"""
    Constructs the Mixer-B/16 224x224 model with MIIL-trained weights.
.. note::
Mixer-B/16 224x224 model from `"MLP-Mixer: An all-MLP Architecture for Vision" <https://arxiv.org/pdf/2105.01601.pdf>`_.
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> mlp_mixer_b16_224_miil = flowvision.models.mlp_mixer_b16_224_miil(pretrained=False, progress=True)
"""
model_kwargs = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
return _create_mlp_mixer(
"mlp_mixer_b16_224_miil",
pretrained=pretrained,
progress=progress,
**model_kwargs
)
@ModelCreator.register_model
def mlp_mixer_b16_224_miil_in21k(pretrained=False, progress=True, **kwargs):
"""
Constructs the Mixer-B/16 224x224 ImageNet21k pretrained model.
.. note::
Mixer-B/16 224x224 ImageNet21k pretrained model from `"MLP-Mixer: An all-MLP Architecture for Vision" <https://arxiv.org/pdf/2105.01601.pdf>`_.
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> mlp_mixer_b16_224_miil_in21k = flowvision.models.mlp_mixer_b16_224_miil_in21k(pretrained=False, progress=True)
"""
model_kwargs = dict(
num_classes=11221, patch_size=16, num_blocks=12, embed_dim=768, **kwargs
)
return _create_mlp_mixer(
"mlp_mixer_b16_224_miil_in21k",
pretrained=pretrained,
progress=progress,
**model_kwargs
)
@ModelCreator.register_model
def gmlp_ti16_224(pretrained=False, progress=True, **kwargs):
"""
Constructs the gMLP-tiny-16 224x224 model.
.. note::
gMLP-tiny-16 224x224 model from `"Pay Attention to MLPs" <https://arxiv.org/pdf/2105.08050.pdf>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> gmlp_ti16_224 = flowvision.models.gmlp_ti16_224(pretrained=False, progress=True)
"""
model_kwargs = dict(
patch_size=16,
num_blocks=30,
embed_dim=128,
mlp_ratio=6,
block_layer=SpatialGatingBlock,
mlp_layer=GatedMlp,
**kwargs
)
return _create_mlp_mixer(
"gmlp_ti16_224", pretrained=pretrained, progress=progress, **model_kwargs
)
@ModelCreator.register_model
def gmlp_s16_224(pretrained=False, progress=True, **kwargs):
"""
Constructs the gMLP-small-16 224x224 model.
.. note::
gMLP-small-16 224x224 model from `"Pay Attention to MLPs" <https://arxiv.org/pdf/2105.08050.pdf>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> gmlp_s16_224 = flowvision.models.gmlp_s16_224(pretrained=False, progress=True)
"""
model_kwargs = dict(
patch_size=16,
num_blocks=30,
embed_dim=256,
mlp_ratio=6,
block_layer=SpatialGatingBlock,
mlp_layer=GatedMlp,
**kwargs
)
return _create_mlp_mixer(
"gmlp_s16_224", pretrained=pretrained, progress=progress, **model_kwargs
)
@ModelCreator.register_model
def gmlp_b16_224(pretrained=False, progress=True, **kwargs):
"""
Constructs the gMLP-base-16 224x224 model.
.. note::
gMLP-base-16 224x224 model from `"Pay Attention to MLPs" <https://arxiv.org/pdf/2105.08050.pdf>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> gmlp_b16_224 = flowvision.models.gmlp_b16_224(pretrained=False, progress=True)
"""
model_kwargs = dict(
patch_size=16,
num_blocks=30,
embed_dim=512,
mlp_ratio=6,
block_layer=SpatialGatingBlock,
mlp_layer=GatedMlp,
**kwargs
)
return _create_mlp_mixer(
"gmlp_b16_224", pretrained=pretrained, progress=progress, **model_kwargs
)