Source code for flowvision.models.fan

"""
Modified from https://github.com/NVlabs/FAN/blob/master/models/fan.py
"""
import math
from functools import partial
from collections import OrderedDict, abc
from typing import Callable
from itertools import repeat

import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.functional as F

from ..layers import DropPath
from .registry import ModelCreator
from .utils import load_state_dict_from_url

# Download locations of the pretrained checkpoints, keyed by model name.
# Entries mapped to ``None`` have no checkpoint URL listed in this table.
model_urls = {
    "fan_vit_tiny": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_vit_tiny.zip",
    "fan_vit_small": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_vit_small.zip",
    "fan_vit_base": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_vit_base.zip",
    "fan_vit_large": None,
    "fan_hybrid_tiny": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_tiny.zip",
    "fan_hybrid_small": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_small.zip",
    "fan_hybrid_base": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_base.zip",
    "fan_hybrid_large": None,
    "fan_hybrid_base_in22k_1k": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_base_in22k_1k.zip",
    "fan_hybrid_base_in22k_1k_384": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_base_in22k_1k_384.zip",
    "fan_hybrid_large_in22k_1k": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_large_in22k_1k.zip",
    "fan_hybrid_large_in22k_1k_384": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/FAN/fan_hybrid_large_in22k_1k_384.zip",
}


def to_2tuple(x):
    """Expand a single value into a pair; pass real sequences through.

    Args:
        x: Either one value (e.g. ``224`` or ``True``) or an iterable that
            already holds the per-dimension values.

    Returns:
        ``x`` itself when it is a non-string iterable, otherwise the tuple
        ``(x, x)``.
    """
    # str/bytes are iterable but represent one value, so they are duplicated
    # like any other scalar instead of being handed back unchanged.
    if isinstance(x, abc.Iterable) and not isinstance(x, (str, bytes)):
        return x
    return tuple(repeat(x, 2))


def _is_contiguous(tensor: flow.Tensor) -> bool:
    return tensor.is_contiguous()


def named_apply(
    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` over a module tree.

    Args:
        fn: Callback invoked with keyword arguments ``module`` and ``name``.
        module: Root of the tree; children come from ``named_children()``.
        name: Dotted prefix used to build the qualified name of each child.
        depth_first: When true a module is visited after its children,
            otherwise before them.
        include_root: Whether ``fn`` is also called on ``module`` itself.
            Descendants are always visited.

    Returns:
        The ``module`` that was passed in.
    """
    visit_self = include_root
    if visit_self and not depth_first:
        fn(module=module, name=name)

    for key, child in module.named_children():
        qualified = name + "." + key if name else key
        named_apply(
            fn=fn,
            module=child,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )

    if visit_self and depth_first:
        fn(module=module, name=name)
    return module


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill ``tensor`` in place with samples from a truncated normal.

    Adapted from the PyTorch implementation; method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: Tensor overwritten in place.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower clamp bound.
        b: Upper clamp bound.

    Returns:
        The same ``tensor`` object.
    """
    root_two = math.sqrt(2.0)

    def norm_cdf(x):
        # Standard normal cumulative distribution function.
        return (1.0 + math.erf(x / root_two)) / 2.0

    mean_far_outside = (mean < a - 2 * std) or (mean > b + 2 * std)
    if mean_far_outside:
        print(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect."
        )

    with flow.no_grad():
        # CDF values at the two truncation points.
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Sample uniformly from [2*lower - 1, 2*upper - 1] ...
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
        # ... and push through the inverse error function to obtain a
        # truncated standard normal.
        tensor.erfinv_()

        # Rescale and shift to the requested std / mean.
        tensor.mul_(std * root_two)
        tensor.add_(mean)

        # Final clamp guarantees the result stays inside [a, b].
        tensor.clamp_(min=a, max=b)
    return tensor


def _create_fc(num_features, num_classes, use_conv=False):
    if num_classes <= 0:
        fc = nn.Identity()  # pass-through (no classifier)
    elif use_conv:
        fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
    else:
        fc = nn.Linear(num_features, num_classes, bias=True)
    return fc


class MlpOri(nn.Module):
    """ Two-layer MLP (Linear -> act -> dropout -> Linear -> dropout), as used
    in Vision Transformer, MLP-Mixer and related networks.

    Args:
        in_features: Size of the input feature dimension.
        hidden_features: Hidden width; falls back to ``in_features``.
        out_features: Output width; falls back to ``in_features``.
        act_layer: Zero-argument factory for the activation module.
        bias: Bias flag, or a pair of flags for (fc1, fc2).
        drop: Dropout probability, or a pair for (drop1, drop2).
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
        drop=0.0,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        # Scalars are broadcast to one value per linear layer.
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias_pair[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias_pair[1])
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        """Apply fc1, activation, drop1, fc2 and drop2 in that order."""
        for layer in (self.fc1, self.act, self.drop1, self.fc2, self.drop2):
            x = layer(x)
        return x


class ClassifierHead(nn.Module):
    """Classifier head: global average pooling, optional dropout, then fc.

    Pooling is always adaptive average pooling to 1x1; ``pool_type`` only
    decides whether the output of a conv classifier is flattened.

    Args:
        in_chs: Number of input channels.
        num_classes: Number of classes; ``<= 0`` makes ``fc`` an identity.
        pool_type: Truthy value enables the final flatten when ``use_conv``.
        drop_rate: Dropout probability applied after pooling.
        use_conv: Use a 1x1 conv classifier instead of a linear layer.
    """

    def __init__(
        self, in_chs, num_classes, pool_type="avg", drop_rate=0.0, use_conv=False
    ):
        super(ClassifierHead, self).__init__()
        self.drop_rate = drop_rate
        # A linear classifier wants (B, C), so flatten inside the pool.  A conv
        # classifier needs the pooled map kept as (B, C, 1, 1); flattening
        # first (as was done unconditionally before) makes Conv2d fail.
        pool_layers = [nn.AdaptiveAvgPool2d(1)]
        if not use_conv:
            pool_layers.append(nn.Flatten(1))
        self.global_pool = nn.Sequential(*pool_layers)
        self.fc = _create_fc(in_chs, num_classes, use_conv=use_conv)
        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

    def forward(self, x, pre_logits: bool = False):
        """Pool ``x`` and classify it.

        Args:
            x: Feature map passed to the adaptive average pool.
            pre_logits: When true, return the pooled (and dropped-out)
                features flattened from dim 1, skipping ``fc``.
        """
        x = self.global_pool(x)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        if pre_logits:
            return x.flatten(1)
        return self.flatten(self.fc(x))


class ClassAttn(nn.Module):
    """Class attention: only the first (class) token queries all tokens.

    Taken from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications to do CA.

    Args:
        dim: Token embedding size.
        num_heads: Number of attention heads.
        qkv_bias: Whether the q/k/v projections carry a bias.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, use_attn=False):
        """Compute the attended class token.

        Args:
            x: Tensor unpacked as ``(B, N, C)``; token 0 is used as the query.
            use_attn: When true, also return the softmax attention weights
                (taken before attention dropout).  This is the keyword that
                ``ClassAttentionBlock.forward`` passes when it is asked for
                the attention map.

        Returns:
            The class token of shape ``(B, 1, C)``, or the pair
            ``(class_token, attn)`` with ``attn`` of shape
            ``(B, num_heads, 1, N)`` when ``use_attn`` is true.
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads
        q = (
            self.q(x[:, 0])
            .unsqueeze(1)
            .reshape(B, 1, self.num_heads, head_dim)
            .permute(0, 2, 1, 3)
        )
        k = self.k(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)

        q = q * self.scale
        v = self.v(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)

        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn_weights = attn  # kept for callers that want the attention map
        attn = self.attn_drop(attn)

        x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
        x_cls = self.proj(x_cls)
        x_cls = self.proj_drop(x_cls)

        if use_attn:
            return x_cls, attn_weights
        return x_cls


class ConvMlp(nn.Module):
    """ MLP built from 1x1 convolutions, so spatial dimensions are kept.

    Args:
        in_features: Number of input channels.
        hidden_features: Hidden channels; falls back to ``in_features``.
        out_features: Output channels; falls back to ``in_features``.
        act_layer: Zero-argument factory for the activation module.
        norm_layer: Optional factory called with ``hidden_features``; an
            identity is used when it is falsy.
        drop: Dropout probability applied after the activation.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.ReLU,
        norm_layer=None,
        drop=0.0,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
        if norm_layer:
            self.norm = norm_layer(hidden_features)
        else:
            self.norm = nn.Identity()
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1, norm, activation, dropout and fc2 in that order."""
        for stage in (self.fc1, self.norm, self.act, self.drop, self.fc2):
            x = stage(x)
        return x


class LayerNorm2d(nn.LayerNorm):
    r""" LayerNorm over the channel axis of channels_first tensors with 2d
    spatial dimensions (ie N, C, H, W).

    Args:
        normalized_shape: Number of channels ``C``.
        eps: Value added to the variance for numerical stability.
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__(normalized_shape, eps=eps)

    def forward(self, x) -> flow.Tensor:
        """Normalize ``x`` over dim 1 and apply the affine weight and bias."""
        if x.is_contiguous():
            # Move channels last, reuse the stock LayerNorm, move them back.
            x = x.permute(0, 2, 3, 1)
            x = super(LayerNorm2d, self).forward(x)
            return x.permute(0, 3, 1, 2)

        # Non-contiguous input: normalize over dim 1 directly.  The variance
        # is the population (biased) one, computed explicitly so this branch
        # matches nn.LayerNorm; the earlier var_mean call did not request the
        # biased estimator.
        u = x.mean(dim=1, keepdim=True)
        centered = x - u
        s = (centered * centered).mean(dim=1, keepdim=True)
        x = centered * flow.rsqrt(s + self.eps)
        return x * self.weight[:, None, None] + self.bias[:, None, None]


class ConvNeXtBlock(nn.Module):
    """ ConvNeXt Block
    Two equivalent layouts are supported and selected by ``conv_mlp``:
      (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
      (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
    choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
    is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        ls_init_value (float): Init value for Layer Scale; no layer scale
            parameter is created when it is not positive. Default: 1e-6.
        conv_mlp (bool): Use layout (1) when true, layout (2) otherwise.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        norm_layer: Optional norm factory called with ``dim``.
    """

    def __init__(
        self,
        dim,
        drop_path=0.0,
        ls_init_value=1e-6,
        conv_mlp=True,
        mlp_ratio=4,
        norm_layer=None,
    ):
        super().__init__()
        # Pick the norm / MLP pair that matches the requested tensor layout.
        if conv_mlp:
            default_norm, mlp_cls = LayerNorm2d, ConvMlp
        else:
            default_norm, mlp_cls = nn.LayerNorm, MlpOri
        if not norm_layer:
            norm_layer = partial(default_norm, eps=1e-6)

        self.use_conv_mlp = conv_mlp
        # Depthwise 7x7 convolution.
        self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = norm_layer(dim)
        self.mlp = mlp_cls(dim, int(mlp_ratio * dim), act_layer=nn.GELU)
        if ls_init_value > 0:
            self.gamma = nn.Parameter(ls_init_value * flow.ones(dim))
        else:
            self.gamma = None
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Run the block on ``x`` and add the residual connection."""
        residual = x
        out = self.conv_dw(x)
        if self.use_conv_mlp:
            out = self.mlp(self.norm(out))
        else:
            # Channels-last round trip for the Linear-based MLP.
            out = out.permute(0, 2, 3, 1)
            out = self.mlp(self.norm(out))
            out = out.permute(0, 3, 1, 2)
        if self.gamma is not None:
            out = out.mul(self.gamma.reshape(1, -1, 1, 1))
        return self.drop_path(out) + residual


class ConvNeXtStage(nn.Module):
    """One ConvNeXt stage: optional norm + strided conv, then a block stack.

    Args:
        in_chs: Input channels.
        out_chs: Output channels of the stage.
        stride: Kernel size (and, unless ``no_downsample``, stride) of the
            downsampling convolution.
        depth: Number of ``ConvNeXtBlock`` modules.
        dp_rates: Per-block drop-path rates; zeros when falsy.
        ls_init_value: Layer-scale init value forwarded to each block.
        conv_mlp: Forwarded to each block.
        norm_layer: Norm factory for the downsample layer, and for the
            blocks when ``conv_mlp`` is true.  It is called whenever a
            downsample layer is built, so it must not be ``None`` then.
        cl_norm_layer: Norm factory for the blocks when ``conv_mlp`` is false.
        cross_stage: Accepted but not used here.
        no_downsample: Keep the conv kernel size but force its stride to 1.
    """

    def __init__(
        self,
        in_chs,
        out_chs,
        stride=2,
        depth=2,
        dp_rates=None,
        ls_init_value=1.0,
        conv_mlp=True,
        norm_layer=None,
        cl_norm_layer=None,
        cross_stage=False,
        no_downsample=False,
    ):
        super().__init__()

        changes_shape = in_chs != out_chs or stride > 1
        if changes_shape:
            conv_stride = 1 if no_downsample else stride
            self.downsample = nn.Sequential(
                norm_layer(in_chs),
                nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=conv_stride),
            )
        else:
            self.downsample = nn.Identity()

        if not dp_rates:
            dp_rates = [0.0] * depth
        block_norm = norm_layer if conv_mlp else cl_norm_layer
        block_list = []
        for j in range(depth):
            block_list.append(
                ConvNeXtBlock(
                    dim=out_chs,
                    drop_path=dp_rates[j],
                    ls_init_value=ls_init_value,
                    conv_mlp=conv_mlp,
                    norm_layer=block_norm,
                )
            )
        self.blocks = nn.Sequential(*block_list)

    def forward(self, x):
        """Apply the downsample layer followed by all blocks."""
        return self.blocks(self.downsample(x))


class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s`  - https://arxiv.org/pdf/2201.03545.pdf
    Args:
        in_chans (int): Number of input image channels. Default: 3
        img_size (int): Accepted but not used by this class.
        num_classes (int): Number of classes for classification head. Default: 1000
        global_pool (str): Forwarded to the head; pooling itself is always
            adaptive average pooling, a falsy value only disables the
            flatten step of the default head.
        output_stride (int): Must be 32.
        patch_size (int): Kernel size and stride of the stem convolution.
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
        conv_mlp (bool): Use the 1x1-conv MLP in the blocks.
        use_head (bool): Build ``self.head``; when false no ``head``
            attribute exists and only ``forward_features`` is usable.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
        head_norm_first (bool): Use norm -> pool -> fc instead of the default
            pool -> norm -> fc ordering.
        norm_layer: Optional norm factory (requires ``conv_mlp``).
        drop_rate (float): Head dropout rate
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        remove_last_downsample (bool): Use stride 1 in the last stage's
            downsample convolution.
    """

    def __init__(
        self,
        in_chans=3,
        img_size=224,
        num_classes=1000,
        global_pool="avg",
        output_stride=32,
        patch_size=4,
        depths=(3, 3, 9, 3),
        dims=(96, 192, 384, 768),
        ls_init_value=1e-6,
        conv_mlp=True,
        use_head=True,
        head_init_scale=1.0,
        head_norm_first=False,
        norm_layer=None,
        drop_rate=0.0,
        drop_path_rate=0.0,
        remove_last_downsample=False,
    ):
        super().__init__()
        assert output_stride == 32
        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)
            cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
        else:
            assert (
                conv_mlp
            ), "If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input"
            cl_norm_layer = norm_layer

        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.feature_info = []

        # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size),
            norm_layer(dims[0]),
        )

        # Linearly increasing drop-path rates, split into one list per stage.
        dp_rates = [
            x.tolist()
            for x in flow.linspace(0, drop_path_rate, sum(depths)).split(depths)
        ]
        curr_stride = patch_size
        prev_chs = dims[0]
        stages = []
        # 4 feature resolution stages, each consisting of multiple residual blocks
        for i in range(len(depths)):
            stride = 2 if i > 0 else 1
            curr_stride *= stride
            out_chs = dims[i]
            no_downsample = remove_last_downsample and (i == len(depths) - 1)
            stages.append(
                ConvNeXtStage(
                    prev_chs,
                    out_chs,
                    stride=stride,
                    depth=depths[i],
                    dp_rates=dp_rates[i],
                    ls_init_value=ls_init_value,
                    conv_mlp=conv_mlp,
                    norm_layer=norm_layer,
                    cl_norm_layer=cl_norm_layer,
                    no_downsample=no_downsample,
                )
            )
            prev_chs = out_chs
            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
            self.feature_info += [
                dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}")
            ]
        self.stages = nn.Sequential(*stages)

        self.num_features = prev_chs

        if head_norm_first:
            # norm -> global pool -> fc ordering, like most other nets (not compat with FB weights)
            self.norm_pre = norm_layer(
                self.num_features
            )  # final norm layer, before pooling
            if use_head:
                self.head = ClassifierHead(
                    self.num_features,
                    num_classes,
                    pool_type=global_pool,
                    drop_rate=drop_rate,
                )
        else:
            # pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
            self.norm_pre = nn.Identity()
            if use_head:
                self.head = self._make_pool_norm_head(
                    norm_layer(self.num_features), num_classes, global_pool
                )

        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)

    def _make_pool_norm_head(self, norm, num_classes, global_pool):
        """Assemble the default ``pool -> norm -> flatten -> drop -> fc`` head.

        The pooled map is kept as (B, C, 1, 1) until after ``norm`` because
        the norm layers used here expect channels-first rank-4 input;
        flattening inside the pool step would hand the norm a rank-2 tensor.
        """
        if num_classes > 0:
            fc = nn.Linear(self.num_features, num_classes)
        else:
            fc = nn.Identity()
        return nn.Sequential(
            OrderedDict(
                [
                    ("global_pool", nn.AdaptiveAvgPool2d(1)),
                    ("norm", norm),
                    ("flatten", nn.Flatten(1) if global_pool else nn.Identity()),
                    ("drop", nn.Dropout(self.drop_rate)),
                    ("fc", fc),
                ]
            )
        )

    def get_classifier(self):
        """Return the final classification layer ``self.head.fc``."""
        return self.head.fc

    def reset_classifier(self, num_classes=0, global_pool="avg"):
        """Replace the head with a fresh one for ``num_classes`` outputs.

        The existing norm layer of a pool -> norm -> fc head is reused.
        """
        if isinstance(self.head, ClassifierHead):
            # norm -> global pool -> fc
            self.head = ClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
            )
        else:
            # pool -> norm -> fc
            self.head = self._make_pool_norm_head(
                self.head.norm, num_classes, global_pool
            )

    def forward_features(self, x):
        """Run stem, stages and the pre-head norm; returns the feature map."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm_pre(x)
        return x

    def forward(self, x):
        """Run the backbone followed by the classification head."""
        x = self.forward_features(x)
        x = self.head(x)
        return x


def _init_weights(module, name=None, head_init_scale=1.0):
    """Initialize a single ``nn.Conv2d`` / ``nn.Linear`` module in place.

    Weights get a truncated normal (std 0.02) and biases are zeroed.  Linear
    layers whose qualified ``name`` contains ``"head."`` are additionally
    scaled by ``head_init_scale``.  Other module types are left untouched.

    Args:
        module: Module to initialize.
        name: Qualified module name, as supplied by ``named_apply``.
        head_init_scale: Multiplier for classifier weights and biases.
    """
    is_linear = isinstance(module, nn.Linear)
    if not (is_linear or isinstance(module, nn.Conv2d)):
        return

    trunc_normal_(module.weight, std=0.02)
    # Layers built with bias=False have no bias tensor to initialize.
    has_bias = module.bias is not None
    if has_bias:
        nn.init.constant_(module.bias, 0)

    if is_linear and name and "head." in name:
        module.weight.data.mul_(head_init_scale)
        if has_bias:
            module.bias.data.mul_(head_init_scale)


def checkpoint_filter_fn(state_dict, model):
    """ Remap FB checkpoints -> timm

    Args:
        state_dict: Checkpoint mapping; when it has a ``"model"`` entry, that
            nested mapping is used instead.
        model: Target model; only keys present in ``model.state_dict()``
            after renaming are kept.

    Returns:
        A new dict with renamed keys.  Rank-2 tensors outside the head are
        reshaped to the shape of the matching model parameter.
    """
    import re

    if "model" in state_dict:
        state_dict = state_dict["model"]
    # state_dict() builds a fresh mapping on every call; fetch it once
    # instead of once (or twice) per checkpoint entry.
    model_state = model.state_dict()
    out_dict = {}

    for k, v in state_dict.items():
        k = k.replace("downsample_layers.0.", "stem.")
        k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k)
        k = re.sub(
            r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k
        )
        k = k.replace("dwconv", "conv_dw")
        k = k.replace("pwconv", "mlp.fc")
        k = k.replace("head.", "head.fc.")
        if k not in model_state:
            continue
        if k.startswith("norm."):
            k = k.replace("norm", "head.norm")
        if v.ndim == 2 and "head" not in k:
            # e.g. Linear weights loaded into 1x1 conv layers.
            v = v.reshape(model_state[k].shape)
        out_dict[k] = v
    return out_dict


def _create_hybrid_backbone(**kwargs):
    """Build a ``ConvNeXt`` from ``kwargs`` and return it."""
    return ConvNeXt(**kwargs)


class PositionalEncodingFourier(nn.Module):
    """
    Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper.

    Sine/cosine features of the normalized row and column indices are
    concatenated and projected to ``dim`` channels with a 1x1 convolution.
    """

    def __init__(self, hidden_dim=32, dim=768, temperature=10000):
        super().__init__()
        self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.dim = dim
        self.eps = 1e-6

    def forward(self, B: int, H: int, W: int):
        """Return the encoding for a ``H`` x ``W`` grid, repeated ``B`` times."""
        device = self.token_projection.weight.device
        row_ids = flow.arange(1, H + 1, dtype=flow.float32, device=device)
        col_ids = flow.arange(1, W + 1, dtype=flow.float32, device=device)

        # Broadcast the 1-based indices to (1, H, W) grids.
        y_embed = row_ids.unsqueeze(1).repeat(1, 1, W)
        x_embed = col_ids.repeat(1, H, 1)

        # Normalize by the last row / column and scale to [0, 2*pi].
        y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale

        # Per-channel frequencies; consecutive channel pairs share one.
        dim_t = flow.arange(self.hidden_dim, dtype=flow.float32, device=device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.hidden_dim)

        def _sin_cos(angles):
            # sin on even channels, cos on odd channels, interleaved again.
            paired = flow.stack(
                [angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()], dim=4
            )
            return paired.flatten(3)

        pos_x = _sin_cos(x_embed[:, :, :, None] / dim_t)
        pos_y = _sin_cos(y_embed[:, :, :, None] / dim_t)

        pos = flow.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        pos = self.token_projection(pos)
        return pos.repeat(B, 1, 1, 1)  # (B, C, H, W)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution + batch norm

    Returns a two-layer ``Sequential``: a bias-free 3x3 convolution with
    padding 1 and the given stride, followed by ``BatchNorm2d``.
    """
    conv = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    )
    norm = nn.BatchNorm2d(out_planes)
    return flow.nn.Sequential(conv, norm)


def sigmoid(x, inplace=False):
    """Apply sigmoid via ``x.sigmoid_()`` when ``inplace`` else ``x.sigmoid()``."""
    if inplace:
        return x.sigmoid_()
    return x.sigmoid()


def make_divisible(v, divisor=8, min_value=None):
    """Round ``v`` to a multiple of ``divisor``, never below ``min_value``.

    Args:
        v: Value to round.
        divisor: Step the result is rounded to.
        min_value: Lower bound; ``divisor`` is used when it is falsy.

    Returns:
        The rounded value, bumped by one ``divisor`` if rounding would have
        dropped it below 90% of ``v``.
    """
    lower_bound = min_value or divisor
    nearest = int(v + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest)
    # Make sure that round down does not go down by more than 10%.
    if result < 0.9 * v:
        result += divisor
    return result


class SqueezeExcite(nn.Module):
    """Squeeze-and-excitation channel gate.

    Args:
        in_chs: Number of input (and output) channels.
        se_ratio: Reduction ratio applied to the base channel count.
        reduced_base_chs: Base channel count for the reduction; falls back
            to ``in_chs``.
        act_layer: Activation factory, called with ``inplace=True``.
        gate_fn: Function mapping the excitation tensor to the gate.
        divisor: Passed to ``make_divisible`` for the reduced width.
        **_: Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        in_chs,
        se_ratio=0.25,
        reduced_base_chs=None,
        act_layer=nn.ReLU,
        gate_fn=sigmoid,
        divisor=1,
        **_,
    ):
        super(SqueezeExcite, self).__init__()
        self.gate_fn = gate_fn
        base_chs = reduced_base_chs or in_chs
        reduced_chs = make_divisible(base_chs * se_ratio, divisor)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
        self.act1 = act_layer(inplace=True)
        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)

    def forward(self, x):
        """Scale ``x`` by the gate computed from its globally pooled features."""
        squeezed = self.avg_pool(x)
        excited = self.conv_expand(self.act1(self.conv_reduce(squeezed)))
        return x * self.gate_fn(excited)


class DWConv(nn.Module):
    """Single 3x3 depthwise convolution applied to a token sequence.

    NOTE(review): a second class named ``DWConv`` is defined later in this
    file with a different constructor signature.  That later definition
    rebinds the module-level name, so code that looks up ``DWConv`` at call
    time resolves to the later class, not this one.
    """

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        # Depthwise: one 3x3 filter per channel (groups == channels).
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Reshape ``x`` from (B, N, C) to (B, C, H, W), convolve, reshape back."""
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        # Back to the token layout (B, N, C).
        x = x.flatten(2).transpose(1, 2)

        return x


class SEMlp(nn.Module):
    """Token MLP with a depthwise-conv residual and optional squeeze-excite.

    Args:
        in_features: Input channel count.
        hidden_features: Hidden width; falls back to ``in_features``.
        out_features: Output width; falls back to ``in_features``.
        act_layer: Factory for ``self.act`` (constructed, but not applied in
            ``forward``).
        drop: Dropout probability.
        linear: Apply a ReLU after ``fc1`` when true.
        use_se: Append a ``SqueezeExcite`` gate over the output channels.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
        linear=False,
        use_se=True,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        # Learnable per-channel scale of the depthwise-conv branch.
        self.gamma = nn.Parameter(flow.ones(hidden_features), requires_grad=True)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.linear = linear
        if self.linear:
            self.relu = nn.ReLU(inplace=True)
        self.se = (
            SqueezeExcite(out_features, se_ratio=0.25) if use_se else nn.Identity()
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear, LayerNorm and Conv2d submodules in place."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Run the MLP on tokens ``x`` laid out on an ``H`` x ``W`` grid.

        Args:
            x: Tensor unpacked as ``(B, N, C)``; it is reshaped to
                ``(B, C_out, H, W)`` for the squeeze-excite step.
            H: Grid height.
            W: Grid width.

        Returns:
            Tuple ``(x, H, W)`` with ``x`` of shape ``(B, N, C_out)``.
        """
        B, N, _ = x.shape
        x = self.fc1(x)
        if self.linear:
            x = self.relu(x)
        x = self.drop(self.gamma * self.dwconv(x, H, W)) + x
        x = self.fc2(x)
        x = self.drop(x)
        # Use the channel count actually produced by fc2: reshaping with the
        # input width is only valid when out_features == in_features.
        c_out = x.shape[-1]
        x = (
            self.se(x.permute(0, 2, 1).reshape(B, c_out, H, W))
            .reshape(B, c_out, N)
            .permute(0, 2, 1)
        )
        return x, H, W


class Mlp(nn.Module):
    """Token MLP whose hidden features get a scaled depthwise-conv residual.

    Args:
        in_features: Input channel count.
        hidden_features: Hidden width; falls back to ``in_features``.
        out_features: Output width; falls back to ``in_features``.
        act_layer: Factory for ``self.act`` (constructed, but not applied in
            ``forward``).
        drop: Dropout probability.
        linear: Apply a ReLU after ``fc1`` when true.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
        linear=False,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        # Learnable per-channel scale of the depthwise-conv branch.
        self.gamma = nn.Parameter(flow.ones(hidden_features), requires_grad=True)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.linear = linear
        if self.linear:
            self.relu = nn.ReLU(inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear, LayerNorm and Conv2d submodules in place."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            k_h, k_w = m.kernel_size[0], m.kernel_size[1]
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Run fc1, the depthwise-conv residual, fc2 and dropout on ``x``."""
        hidden = self.fc1(x)
        if self.linear:
            hidden = self.relu(hidden)
        conv_branch = self.gamma * self.dwconv(hidden, H, W)
        hidden = self.drop(conv_branch) + hidden
        return self.drop(self.fc2(hidden))


class ConvPatchEmbed(nn.Module):
    """Image to Patch Embedding using multiple convolutional layers

    A stack of stride-2 ``conv3x3`` blocks (with ``act_layer`` in between)
    reduces the resolution by ``patch_size``.

    Args:
        img_size: Image size, a scalar or a pair.
        patch_size: Overall reduction factor; one of 4, 8 or 16.
        in_chans: Number of input channels.
        embed_dim: Number of output channels.
        act_layer: Factory for the activation between conv blocks.

    Raises:
        ValueError: If ``patch_size`` is not 4, 8 or 16.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, act_layer=nn.GELU
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        # Channel width after each stride-2 conv block, per patch size.
        if patch_size == 16:
            widths = [embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim]
        elif patch_size == 8:
            widths = [embed_dim // 4, embed_dim // 2, embed_dim]
        elif patch_size == 4:
            widths = [embed_dim // 4, embed_dim // 1]
        else:
            # Raising a plain string is itself a TypeError; use a real
            # exception and list every supported size.
            raise ValueError(
                "For convolutional projection, patch size has to be in [4, 8, 16]"
            )

        layers = []
        prev = in_chans
        for idx, width in enumerate(widths):
            if idx > 0:
                layers.append(act_layer())
            layers.append(conv3x3(prev, width, 2))
            prev = width
        self.proj = flow.nn.Sequential(*layers)

    def forward(self, x):
        """Embed ``x`` and return ``(tokens, (Hp, Wp))``.

        ``tokens`` has shape (B, N, C) with ``N = Hp * Wp``, where ``Hp`` and
        ``Wp`` are the spatial size of the projected feature map.
        """
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)  # (B, N, C)
        return x, (Hp, Wp)


class DWConv(nn.Module):
    """Two grouped convolutions with activation and batch norm in between,
    applied to a token sequence laid out on an ``H`` x ``W`` grid.

    Args:
        in_features: Input channel count (also the group count of ``conv1``).
        out_features: Output channels and group count of ``conv2``; falls
            back to ``in_features``.
        act_layer: Factory for the activation module.
        kernel_size: Kernel size of both convolutions; padding is
            ``kernel_size // 2``.
    """

    def __init__(
        self, in_features, out_features=None, act_layer=nn.GELU, kernel_size=3
    ):
        super().__init__()
        if not out_features:
            out_features = in_features

        pad = kernel_size // 2

        self.conv1 = flow.nn.Conv2d(
            in_features,
            in_features,
            kernel_size=kernel_size,
            padding=pad,
            groups=in_features,
        )
        self.act = act_layer()
        self.bn = nn.BatchNorm2d(in_features)
        self.conv2 = flow.nn.Conv2d(
            in_features,
            out_features,
            kernel_size=kernel_size,
            padding=pad,
            groups=out_features,
        )

    def forward(self, x, H: int, W: int):
        """Reshape (B, N, C) tokens to (B, C, H, W), convolve, reshape back."""
        B, N, C = x.shape
        grid = x.permute(0, 2, 1).reshape(B, C, H, W)
        grid = self.bn(self.act(self.conv1(grid)))
        grid = self.conv2(grid)
        return grid.reshape(B, C, N).permute(0, 2, 1)


class ClassAttentionBlock(nn.Module):
    """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239

    Args:
        dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention projections carry a bias.
        drop: Dropout probability for the projection and the MLP.
        attn_drop: Dropout probability on the attention weights.
        drop_path: Stochastic depth rate.
        act_layer: Activation factory for the MLP.
        norm_layer: Norm factory called with ``dim``.
        eta: LayerScale init value; ``None`` disables LayerScale.
        tokens_norm: Apply ``norm2`` to all tokens rather than only the
            class token.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        eta=1.0,
        tokens_norm=False,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)

        self.attn = ClassAttn(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden = int(dim * mlp_ratio)
        self.mlp = MlpOri(
            in_features=dim,
            hidden_features=mlp_hidden,
            act_layer=act_layer,
            drop=drop,
        )

        if eta is None:
            # No LayerScale: plain unit scaling.
            self.gamma1, self.gamma2 = 1.0, 1.0
        else:
            # LayerScale initialization.
            self.gamma1 = nn.Parameter(eta * flow.ones(dim), requires_grad=True)
            self.gamma2 = nn.Parameter(eta * flow.ones(dim), requires_grad=True)
        self.tokens_norm = tokens_norm

    def forward(self, x, return_attention=False):
        """Update the class token (token 0) of ``x``.

        Args:
            x: Token tensor whose first token along dim 1 is the class token.
            return_attention: When true, return the attention tensor produced
                by ``self.attn`` instead of the updated tokens.
                NOTE(review): this path calls ``self.attn`` with the keyword
                ``use_attn`` and unpacks a pair from it - confirm the
                attention module supports that.
        """
        normed = self.norm1(x)
        if return_attention:
            cls_out, attn = self.attn(normed, use_attn=return_attention)
        else:
            cls_out = self.attn(normed)

        # Attended class token followed by the normalized patch tokens.
        attn_branch = flow.cat([cls_out, normed[:, 1:]], dim=1)
        x = x + self.drop_path(self.gamma1 * attn_branch)

        if self.tokens_norm:
            x = self.norm2(x)
        else:
            x = flow.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)

        skip = x
        # Only the class token goes through the MLP.
        cls_mlp = self.gamma2 * self.mlp(x[:, 0:1])
        mixed = flow.cat([cls_mlp, x[:, 1:]], dim=1)
        x = skip + self.drop_path(mixed)

        return attn if return_attention else x


class TokenMixing(nn.Module):
    """Multi-head self-attention across tokens.

    Queries come from one linear layer and keys/values from a second one; the
    attended values are merged back to ``dim`` channels and projected.

    Args:
        dim: Channel width of the tokens; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query and key/value projections have a bias.
        qk_scale: Overrides the default ``head_dim ** -0.5`` scaling when truthy.
        attn_drop: Dropout rate applied to the attention weights.
        proj_drop: Dropout rate applied after the output projection.
        sr_ratio, linear, share_atten, emlp: Stored on the module but not read
            by ``forward``.
        drop_path, sharpen_attn, mlp_hidden_dim, act_layer, drop, norm_layer:
            Accepted for signature compatibility and ignored.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        sr_ratio=1,
        linear=False,
        share_atten=False,
        drop_path=0.0,
        emlp=False,
        sharpen_attn=False,
        mlp_hidden_dim=None,
        act_layer=nn.GELU,
        drop=None,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.share_atten = share_atten
        self.emlp = emlp

        cha_sr = 1
        self.q = nn.Linear(dim, dim // cha_sr, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2 // cha_sr, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.linear = linear
        self.sr_ratio = sr_ratio
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; used as the callback of ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W, atten=None, return_attention=False):
        """Attend over the token axis.

        Args:
            x: Tensor unpacked as ``B, N, C`` (batch, tokens, channels).
            H, W, atten, return_attention: Unused; kept so every block type
                can be called the same way.

        Returns:
            A pair ``(out, context)``: ``out`` is the projected result of
            shape ``(B, N, C)`` and ``context`` is ``attn @ v`` per head, of
            shape ``(B, num_heads, N, C // num_heads)``.
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads
        q = self.q(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        kv = (
            self.kv(x)
            .reshape(B, -1, 2, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]

        # ``*`` and ``@`` share one precedence level and group left to right,
        # so the queries are scaled before the dot product.
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Computed once and reused for both return values.
        context = attn @ v
        x = context.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, context


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.

    Args:
        backbone: ``nn.Module`` exposing ``forward_features``.
        img_size: Image size (int or pair) used for the shape probe.
        patch_size: Kernel size and stride (int or pair) of the projection
            convolution applied to the feature map.
        feature_size: Spatial size of the backbone output. When ``None`` it is
            measured by running the backbone once on a zero image; otherwise
            the channel count is read from ``backbone.feature_info`` or
            ``backbone.num_features``.
        in_chans: Channel count of the zero image used for the probe.
        embed_dim: Output channels of the projection.
    """

    def __init__(
        self,
        backbone,
        img_size=224,
        patch_size=2,
        feature_size=None,
        in_chans=3,
        embed_dim=384,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            # NOTE Most reliable way of determining output dims is to run forward pass
            training = backbone.training
            if training:
                backbone.eval()
            try:
                with flow.no_grad():
                    o = self.backbone.forward_features(
                        flow.zeros(1, in_chans, img_size[0], img_size[1])
                    )
            finally:
                # Restore the caller's mode even when the probe raises.
                backbone.train(training)
            if isinstance(o, (list, tuple)):
                o = o[-1]  # last feature if backbone outputs list/tuple of features
            feature_size = o.shape[-2:]
            feature_dim = o.shape[1]
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, "feature_info"):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        assert (
            feature_size[0] % patch_size[0] == 0
            and feature_size[1] % patch_size[1] == 0
        )
        self.grid_size = (
            feature_size[0] // patch_size[0],
            feature_size[1] // patch_size[1],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(
            feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Embed an image batch.

        Returns:
            ``(tokens, (Hp, Wp))`` where ``tokens`` is the projected feature
            map flattened over its spatial axes and moved to
            ``(B, Hp * Wp, embed_dim)`` order, and ``Hp``/``Wp`` are the
            feature-map height/width divided by the patch size.
        """
        x = self.backbone.forward_features(x)
        # Pick the last feature map *before* reading the shape; the original
        # unpacked ``x.shape`` first and failed on list/tuple outputs.
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        _, _, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x, (H // self.patch_size[0], W // self.patch_size[1])


class ChannelProcessing(nn.Module):
    """Channel re-weighting used in place of the MLP in ``FANBlock``.

    A per-channel gate is derived from the tokens (see ``_gen_attn``) and
    multiplied onto values that were passed through ``Mlp`` and a norm layer.

    Args:
        dim: Channel width of the tokens.
        num_heads: Number of channel groups; replaced by ``c_head_num`` when
            that is truthy.
        qkv_bias: Whether the query projection has a bias.
        attn_drop: Dropout rate applied to the gate.
        linear: Forwarded to ``Mlp``.
        drop_path: Stochastic-depth rate; the resulting ``self.drop_path``
            module is created but not used in ``forward``.
        mlp_hidden_dim: Hidden width forwarded to ``Mlp``.
        act_layer: Activation class forwarded to ``Mlp``.
        drop: Dropout rate forwarded to ``Mlp``.
        norm_layer: Constructor of the norm applied to the values.
        cha_sr_ratio: Divides the width given to ``Mlp`` and the norm; forced
            to 1 when only one head is used.
            NOTE(review): ``forward`` feeds all ``C`` channels to ``mlp_v``,
            so values other than 1 look unsupported - confirm.
        c_head_num: Optional override for ``num_heads``.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        linear=False,
        drop_path=0.0,
        mlp_hidden_dim=None,
        act_layer=nn.GELU,
        drop=None,
        norm_layer=nn.LayerNorm,
        cha_sr_ratio=1,
        c_head_num=None,
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        num_heads = c_head_num or num_heads
        # ``c_head_num`` may have replaced the value checked above, so also
        # validate the head count that ``forward`` really reshapes with.
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."
        self.num_heads = num_heads
        self.temperature = nn.Parameter(flow.ones(num_heads, 1, 1))

        self.cha_sr_ratio = cha_sr_ratio if num_heads > 1 else 1

        # config of mlp for v processing
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.mlp_v = Mlp(
            in_features=dim // self.cha_sr_ratio,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            linear=linear,
        )
        self.norm_v = norm_layer(dim // self.cha_sr_ratio)

        self.q = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; used as the callback of ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def _gen_attn(self, q, k):
        """Build the channel gate from 4-D per-head queries and keys.

        Both inputs are soft-maxed over dim -2; ``k`` is additionally
        average-pooled to a last dimension of size 1. The sigmoid of their
        product is scaled by the learnable per-head ``temperature``.
        """
        q = q.softmax(-2).transpose(-1, -2)
        _, _, N, _ = k.shape
        k = flow.nn.functional.adaptive_avg_pool2d(k.softmax(-2), (N, 1))

        attn = flow.nn.functional.sigmoid(q @ k)
        return attn * self.temperature

    def forward(self, x, H, W, atten=None):
        """Re-weight the channels of ``x``.

        Args:
            x: Tensor unpacked as ``B, N, C``.
            H, W: Forwarded to ``mlp_v``.
            atten: Unused.

        Returns:
            ``(out, gated)``: ``out`` has shape ``(B, N, C)``; ``gated`` is
            the same product kept per head, shape
            ``(B, num_heads, N, C // num_heads)``.
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads

        # Keys and (pre-MLP) values are the same head-split view of ``x``.
        heads = x.reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        q = (
            self.q(x)
            .reshape(B, N, self.num_heads, head_dim)
            .permute(0, 2, 1, 3)
        )

        attn = self._gen_attn(q, heads)
        attn = self.attn_drop(attn)

        Bv, Hd, Nv, Cv = heads.shape
        v = (
            self.norm_v(
                self.mlp_v(heads.transpose(1, 2).reshape(Bv, Nv, Hd * Cv), H, W)
            )
            .reshape(Bv, Nv, Hd, Cv)
            .transpose(1, 2)
        )

        repeat_time = N // attn.shape[-1]
        attn = (
            attn.repeat_interleave(repeat_time, dim=-1) if attn.shape[-1] > 1 else attn
        )
        # Computed once and reused for both return values.
        gated = attn * v.transpose(-1, -2)
        x = gated.permute(0, 3, 1, 2).reshape(B, N, C)
        return x, gated.transpose(-1, -2)  # attn

    def no_weight_decay(self):
        """Names of parameters to exclude from weight decay."""
        return {"temperature"}


class FANBlock_SE(nn.Module):
    """FAN block that pairs ``TokenMixing`` with an ``SEMlp`` feed-forward.

    Both sub-layers are pre-normalised, scaled by a learnable per-channel
    vector initialised to ``eta`` and added back through ``drop_path``.

    ``use_se``, ``qk_scale``, ``downsample`` and ``c_head_num`` are accepted
    but not used by this class.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        sharpen_attn=False,
        use_se=False,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        eta=1.0,
        sr_ratio=1.0,
        qk_scale=None,
        linear=False,
        downsample=None,
        c_head_num=None,
    ):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)

        # Token-mixing branch.
        self.norm1 = norm_layer(dim)
        self.attn = TokenMixing(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            mlp_hidden_dim=hidden_width,
            sharpen_attn=sharpen_attn,
            attn_drop=attn_drop,
            proj_drop=drop,
            drop=drop,
            drop_path=drop_path,
            sr_ratio=sr_ratio,
            linear=linear,
            emlp=False,
        )
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        self.mlp = SEMlp(
            in_features=dim,
            hidden_features=hidden_width,
            act_layer=act_layer,
            drop=drop,
        )

        # Per-channel residual scales, one per branch.
        self.gamma1 = nn.Parameter(eta * flow.ones(dim), requires_grad=True)
        self.gamma2 = nn.Parameter(eta * flow.ones(dim), requires_grad=True)

    def forward(self, x, H: int, W: int, attn=None):
        """Run both branches; returns ``(x, H, W)`` with ``H``/``W`` as
        reported back by the ``SEMlp``. ``attn`` is unused."""
        mixed, _ = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(self.gamma1 * mixed)

        fed_forward, H, W = self.mlp(self.norm2(x), H, W)
        x = x + self.drop_path(self.gamma2 * fed_forward)
        return x, H, W


class FANBlock(nn.Module):
    """FAN block: ``TokenMixing`` followed by ``ChannelProcessing``.

    The token grid is not a ``forward`` argument. The caller stores it on the
    module as ``self.H`` / ``self.W`` before the call and reads the (possibly
    down-sampled) grid back from the same attributes afterwards.

    ``act_layer`` is accepted but not forwarded to any sub-module.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        sharpen_attn=False,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        eta=1.0,
        sr_ratio=1.0,
        downsample=None,
        c_head_num=None,
    ):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)

        # Token-mixing branch.
        self.norm1 = norm_layer(dim)
        self.attn = TokenMixing(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            mlp_hidden_dim=hidden_width,
            sharpen_attn=sharpen_attn,
            attn_drop=attn_drop,
            proj_drop=drop,
            drop=drop,
            drop_path=drop_path,
            sr_ratio=sr_ratio,
        )
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # Channel-processing branch.
        self.norm2 = norm_layer(dim)
        self.mlp = ChannelProcessing(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            drop_path=drop_path,
            drop=drop,
            mlp_hidden_dim=hidden_width,
            c_head_num=c_head_num,
        )

        # Per-channel residual scales, one per branch.
        self.gamma1 = nn.Parameter(eta * flow.ones(dim), requires_grad=True)
        self.gamma2 = nn.Parameter(eta * flow.ones(dim), requires_grad=True)

        self.downsample = downsample
        # Token grid; filled in by the caller before every forward pass.
        self.H = None
        self.W = None

    def forward(self, x, attn=None, return_attention=False):
        """Run the block on ``x`` using the grid stored in ``self.H``/``self.W``.

        Returns ``x`` alone, or ``(x, attn_s)`` when ``return_attention`` is
        true, where ``attn_s`` is the second output of the token mixer. In
        that case the down-sampling step and the ``self.H``/``self.W`` update
        are skipped.
        """
        grid_h, grid_w = self.H, self.W

        token_out, attn_s = self.attn(self.norm1(x), grid_h, grid_w)
        x = x + self.drop_path(self.gamma1 * token_out)

        channel_out, _ = self.mlp(self.norm2(x), grid_h, grid_w, atten=attn)
        x = x + self.drop_path(self.gamma2 * channel_out)

        if return_attention:
            return x, attn_s

        if self.downsample is not None:
            x, grid_h, grid_w = self.downsample(x, grid_h, grid_w)
        self.H, self.W = grid_h, grid_w
        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding

    A strided convolution padded by half the kernel size, followed by a
    LayerNorm over the embedding channels. Input and output are token
    sequences; the spatial grid is passed in and returned explicitly.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H = img_size[0] // patch_size[0]
        self.W = img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W

        half_kernel = (patch_size[0] // 2, patch_size[1] // 2)
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=half_kernel,
        )
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; used as the callback of ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Fold the tokens into a ``(B, C, H, W)`` map, convolve and re-flatten.

        Returns:
            ``(tokens, H, W)`` with ``H``/``W`` read from the convolution
            output.
        """
        B, N, C = x.shape
        feature_map = self.proj(x.transpose(-1, -2).reshape(B, C, H, W))
        H, W = feature_map.shape[-2:]

        tokens = self.norm(feature_map.flatten(2).transpose(1, 2))
        return tokens, H, W


class FAN(nn.Module):
    """
    FAN classifier: patch embedding, optional Fourier positional encoding,
    ``depth`` encoder blocks, ``cls_attn_layers`` class-attention blocks, a
    final norm and a linear head.

    Based on timm code bases
    https://github.com/rwightman/pytorch-image-models/tree/master/timm

    Args:
        img_size: Input size (int or pair); both sides must be divisible by
            ``patch_size``.
        patch_size: Patch size for ``ConvPatchEmbed`` (no-backbone case).
        in_chans: Number of input image channels.
        num_classes: Size of the linear head; ``0`` or less gives an identity head.
        embed_dim: Width of the patch embedding, and of every block when
            ``channel_dims`` is ``None``.
        depth: Number of encoder blocks.
        sharpen_attn: Accepted but not used by this class.
        channel_dims: Per-block widths; a down-sampling ``OverlapPatchEmbed``
            is attached to block ``i`` when the next width differs.
        num_heads: Head count, either one number or one entry per block.
        mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path_rate:
            Forwarded unchanged to every block.
        sr_ratio: One entry per block, forwarded to the block; ``None`` means
            ``1`` for every block.
        backbone: Optional CNN; when given, ``HybridEmbed`` replaces
            ``ConvPatchEmbed``.
        use_checkpoint: Stored on the module; not read by this class.
        act_layer: Activation class, default ``nn.GELU``.
        norm_layer: Norm constructor, default ``LayerNorm`` with ``eps=1e-6``.
        se_mlp: Build ``FANBlock_SE`` blocks instead of ``FANBlock``.
        cls_attn_layers: Number of ``ClassAttentionBlock`` layers.
        use_pos_embed: Add ``PositionalEncodingFourier`` to the tokens.
        eta: Initial LayerScale value forwarded to all blocks.
        tokens_norm: Forwarded to the class-attention blocks.
        c_head_num: Optional per-block head counts for channel processing.
        hybrid_patch_size: Patch size given to ``HybridEmbed``.
        head_init_scale: Factor applied to the head weight and bias after
            initialisation.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        sharpen_attn=False,
        channel_dims=None,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        sr_ratio=None,
        backbone=None,
        use_checkpoint=False,
        act_layer=None,
        norm_layer=None,
        se_mlp=False,
        cls_attn_layers=2,
        use_pos_embed=True,
        eta=1.0,
        tokens_norm=False,
        c_head_num=None,
        hybrid_patch_size=2,
        head_init_scale=1.0,
    ):

        super().__init__()
        img_size = to_2tuple(img_size)
        self.use_checkpoint = use_checkpoint
        # Check height *and* width (the original tested the height twice).
        assert (img_size[0] % patch_size == 0) and (
            img_size[1] % patch_size == 0
        ), "`patch_size` should divide image dimensions evenly"

        self.num_classes = num_classes
        if not isinstance(num_heads, (list, tuple)):
            num_heads = [num_heads] * depth
        if channel_dims is None:
            channel_dims = [embed_dim] * depth
        if sr_ratio is None:
            # The default used to crash on ``sr_ratio[i]`` below.
            sr_ratio = [1] * depth
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        if backbone is None:
            self.patch_embed = ConvPatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
                act_layer=act_layer,
            )
        else:
            # Pass the real image geometry so the shape probe (and the derived
            # grid_size / num_patches) matches the model being built.
            self.patch_embed = HybridEmbed(
                backbone=backbone,
                img_size=img_size,
                patch_size=hybrid_patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )

        self.use_pos_embed = use_pos_embed
        if use_pos_embed:
            self.pos_embed = PositionalEncodingFourier(dim=embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)

        build_block = FANBlock_SE if se_mlp else FANBlock
        self.blocks = nn.ModuleList([])
        for i in range(depth):
            if i < depth - 1 and channel_dims[i] != channel_dims[i + 1]:
                downsample = OverlapPatchEmbed(
                    img_size=img_size,
                    patch_size=3,
                    stride=2,
                    in_chans=channel_dims[i],
                    embed_dim=channel_dims[i + 1],
                )
            else:
                downsample = None
            self.blocks.append(
                build_block(
                    dim=channel_dims[i],
                    num_heads=num_heads[i],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    sr_ratio=sr_ratio[i],
                    attn_drop=attn_drop_rate,
                    drop_path=drop_path_rate,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    eta=eta,
                    downsample=downsample,
                    c_head_num=c_head_num[i] if c_head_num is not None else None,
                )
            )

        # Width of the last encoder block; named explicitly instead of relying
        # on the loop variable leaking out of the ``for`` above.
        final_dim = channel_dims[depth - 1]
        self.num_features = self.embed_dim = final_dim
        self.cls_token = nn.Parameter(flow.zeros(1, 1, final_dim))
        self.cls_attn_blocks = nn.ModuleList(
            [
                ClassAttentionBlock(
                    dim=final_dim,
                    num_heads=num_heads[-1],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    eta=eta,
                    tokens_norm=tokens_norm,
                )
                for _ in range(cls_attn_layers)
            ]
        )

        # Classifier head
        self.norm = norm_layer(final_dim)
        self.head = (
            nn.Linear(self.num_features, num_classes)
            if num_classes > 0
            else nn.Identity()
        )

        # Init weights
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)
        # An identity head (num_classes <= 0) has no weight/bias to scale.
        if isinstance(self.head, nn.Linear):
            self.head.weight.data.mul_(head_init_scale)
            self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        """Initialise one sub-module; used as the callback of ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        """Names of parameters to exclude from weight decay."""
        return {"pos_embed", "cls_token"}  # , 'patch_embed'}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=""):
        """Replace the head by a fresh one for ``num_classes`` outputs.

        ``global_pool`` is accepted but unused.
        """
        self.num_classes = num_classes
        self.head = (
            nn.Linear(self.num_features, num_classes)
            if num_classes > 0
            else nn.Identity()
        )

    def _embed_tokens(self, x):
        """Patch-embed ``x``, add the positional encoding and apply dropout.

        Returns ``(tokens, Hp, Wp)``.
        """
        B = x.shape[0]
        x, (Hp, Wp) = self.patch_embed(x)

        if self.use_pos_embed:
            pos_encoding = (
                self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
            )
            x = x + pos_encoding

        return self.pos_drop(x), Hp, Wp

    def _run_block(self, blk, x, H, W):
        """Run one encoder block and return ``(x, H, W)``.

        ``FANBlock_SE`` takes and returns the grid explicitly, whereas
        ``FANBlock`` exchanges it through its ``H``/``W`` attributes.
        """
        if isinstance(blk, FANBlock_SE):
            return blk(x, H, W)
        blk.H, blk.W = H, W
        x = blk(x)
        return x, blk.H, blk.W

    def _prepend_cls_token(self, x, B):
        """Concatenate the class token in front of the patch tokens."""
        cls_tokens = self.cls_token.expand(B, -1, -1)
        return flow.cat((cls_tokens, x), dim=1)

    def forward_features(self, x):
        """Return the normalised class token for every image in ``x``."""
        B = x.shape[0]
        x, H, W = self._embed_tokens(x)

        for blk in self.blocks:
            x, H, W = self._run_block(blk, x, H, W)

        x = self._prepend_cls_token(x, B)

        for blk in self.cls_attn_blocks:
            x = blk(x)

        x = self.norm(x)[:, 0]
        return x

    def forward(self, x):
        """Classify ``x``: ``head(forward_features(x))``."""
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def get_last_selfattention(self, x, use_cls_attn=False, layer_idx=11):
        """Return an attention output for inspection.

        Args:
            x: Input image batch.
            use_cls_attn: If true, run the whole encoder and return what the
                last class-attention block yields with
                ``return_attention=True``; ``layer_idx`` is then ignored.
            layer_idx: Index of the encoder block to read from. ``None`` or an
                index past the end selects the last block; negative values
                count from the end.

        Returns:
            For ``use_cls_attn=False`` the second value that the selected
            ``FANBlock`` returns with ``return_attention=True``.

        Raises:
            IndexError: ``layer_idx`` is below ``-len(self.blocks)``.
            NotImplementedError: the selected block is a ``FANBlock_SE``.
            ValueError: ``use_cls_attn`` is set but the model has no
                class-attention block.
        """
        B = x.shape[0]
        x, H, W = self._embed_tokens(x)

        if use_cls_attn:
            for blk in self.blocks:
                x, H, W = self._run_block(blk, x, H, W)
            x = self._prepend_cls_token(x, B)

            last = len(self.cls_attn_blocks) - 1
            for i, blk in enumerate(self.cls_attn_blocks):
                if i == last:
                    return blk(x, return_attention=True)
                x = blk(x)
            raise ValueError(
                "use_cls_attn=True requires at least one class-attention block"
            )

        num_blocks = len(self.blocks)
        # ``layer_idx or ...`` would treat block 0 as "not given".
        if layer_idx is None or layer_idx >= num_blocks:
            return_idx = num_blocks - 1
        elif layer_idx < 0:
            return_idx = layer_idx + num_blocks
        else:
            return_idx = layer_idx
        if return_idx < 0:
            raise IndexError(
                f"layer_idx {layer_idx} is out of range for {num_blocks} blocks"
            )

        for i, blk in enumerate(self.blocks):
            if i == return_idx:
                if isinstance(blk, FANBlock_SE):
                    raise NotImplementedError(
                        "FANBlock_SE does not support return_attention"
                    )
                blk.H, blk.W = H, W
                _, attn = blk(x, return_attention=True)
                # Later blocks cannot change this value, so stop here.
                return attn
            x, H, W = self._run_block(blk, x, H, W)


# FAN-ViT Models
@ModelCreator.register_model
def fan_tiny_12_p16_224(pretrained: bool = False, progress: bool = True, **kwargs):
    """
    Constructs FAN-ViT-tiny 224x224 model pretrained on ImageNet-1k.

    .. note::
        FAN-ViT-tiny 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fan_tiny_12_p16_224 = flowvision.models.fan_tiny_12_p16_224(pretrained=False, progress=True)

    """
    depth = 12
    sr_ratio = [1] * depth
    model_kwargs = dict(
        patch_size=16,
        embed_dim=192,
        depth=depth,
        num_heads=4,
        eta=1.0,
        tokens_norm=True,
        sharpen_attn=False,
        **kwargs,
    )
    model = FAN(sr_ratio=sr_ratio, **model_kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["fan_vit_tiny"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model


@ModelCreator.register_model
def fan_small_12_p16_224(pretrained: bool = False, progress: bool = True, **kwargs):
    """
    Constructs FAN-ViT-small 224x224 model pretrained on ImageNet-1k.

    .. note::
        FAN-ViT-small 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fan_small_12_p16_224 = flowvision.models.fan_small_12_p16_224(pretrained=False, progress=True)

    """
    depth = 12
    sr_ratio = [1] * depth
    model_kwargs = dict(
        patch_size=16,
        embed_dim=384,
        depth=depth,
        num_heads=8,
        eta=1.0,
        tokens_norm=True,
        **kwargs,
    )
    model = FAN(sr_ratio=sr_ratio, **model_kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["fan_vit_small"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model


@ModelCreator.register_model
def fan_base_18_p16_224(pretrained: bool = False, progress: bool = True, **kwargs):
    """
    Constructs FAN-ViT-base 224x224 model pretrained on ImageNet-1k.

    .. note::
        FAN-ViT-base 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fan_base_18_p16_224 = flowvision.models.fan_base_18_p16_224(pretrained=False, progress=True)

    """
    depth = 18
    sr_ratio = [1] * depth
    model_kwargs = dict(
        patch_size=16,
        embed_dim=448,
        depth=depth,
        num_heads=8,
        eta=1.0,
        tokens_norm=True,
        sharpen_attn=False,
        **kwargs,
    )
    model = FAN(sr_ratio=sr_ratio, **model_kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["fan_vit_base"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model


@ModelCreator.register_model
def fan_large_24_p16_224(pretrained: bool = False, progress: bool = True, **kwargs):
    """
    Constructs FAN-ViT-large 224x224 model pretrained on ImageNet-1k.

    .. note::
        FAN-ViT-large 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    Raises:
        ValueError: ``pretrained`` is True but no checkpoint URL is registered
            for this model.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fan_large_24_p16_224 = flowvision.models.fan_large_24_p16_224(pretrained=False, progress=True)

    """
    depth = 24
    sr_ratio = [1] * depth
    model_kwargs = dict(
        patch_size=16,
        embed_dim=480,
        depth=depth,
        num_heads=10,
        eta=1.0,
        tokens_norm=True,
        sharpen_attn=False,
        **kwargs,
    )
    model = FAN(sr_ratio=sr_ratio, **model_kwargs)
    if pretrained:
        url = model_urls["fan_vit_large"]
        if url is None:
            # The registry entry is None, so fail with a readable message
            # instead of passing None to the downloader.
            raise ValueError(
                "No pretrained weights are available for fan_large_24_p16_224"
            )
        state_dict = load_state_dict_from_url(url, progress=progress)
        model.load_state_dict(state_dict)
    return model


# FAN-Hybrid Models
# CNN backbones are based on the ConvNeXt architecture, using only its first two stages for downsampling.
# This has been verified to be beneficial for downstream tasks.
@ModelCreator.register_model
def fan_tiny_8_p4_hybrid(pretrained: bool = False, progress: bool = True, **kwargs):
    """
    Constructs FAN-Hybrid-tiny 224x224 model pretrained on ImageNet-1k.

    .. note::
        FAN-Hybrid-tiny 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fan_tiny_8_p4_hybrid = flowvision.models.fan_tiny_8_p4_hybrid(pretrained=False, progress=True)

    """
    depth = 8
    # FAN reads one entry per block; the original list carried one unused
    # extra entry.
    sr_ratio = [1] * depth
    model_args = dict(depths=[3, 3], dims=[128, 256, 512, 1024], use_head=False)
    backbone = _create_hybrid_backbone(**model_args)
    model_kwargs = dict(
        patch_size=16,
        embed_dim=192,
        depth=depth,
        num_heads=8,
        eta=1.0,
        tokens_norm=True,
        sharpen_attn=False,
        **kwargs,
    )
    model = FAN(sr_ratio=sr_ratio, backbone=backbone, **model_kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["fan_hybrid_tiny"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model


@ModelCreator.register_model
def fan_small_12_p4_hybrid(pretrained: bool = False, progress: bool = True, **kwargs):
    """
    Constructs FAN-Hybrid-small 224x224 model pretrained on ImageNet-1k.

    .. note::
        FAN-Hybrid-small 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fan_small_12_p4_hybrid = flowvision.models.fan_small_12_p4_hybrid(pretrained=False, progress=True)

    """
    # NOTE: the model is built with 10 encoder blocks although its name says 12.
    depth = 10
    channel_dims = [384] * depth
    sr_ratio = [1] * depth
    model_args = dict(depths=[3, 3], dims=[128, 256, 512, 1024], use_head=False)
    backbone = _create_hybrid_backbone(**model_args)
    model_kwargs = dict(
        patch_size=16,
        embed_dim=384,
        depth=depth,
        num_heads=8,
        eta=1.0,
        tokens_norm=True,
        sharpen_attn=False,
        **kwargs,
    )
    model = FAN(
        sr_ratio=sr_ratio, backbone=backbone, channel_dims=channel_dims, **model_kwargs
    )
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["fan_hybrid_small"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model


@ModelCreator.register_model
def fan_base_16_p4_hybrid(pretrained: bool = False, progress: bool = True, **kwargs):
    """
    Constructs FAN-Hybrid-base 224x224 model pretrained on ImageNet-1k.

    .. note::
        FAN-Hybrid-base 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fan_base_16_p4_hybrid = flowvision.models.fan_base_16_p4_hybrid(pretrained=False, progress=True)

    """
    depth = 16
    sr_ratio = [1] * depth
    model_args = dict(depths=[3, 3], dims=[128, 256, 512, 1024], use_head=False)
    backbone = _create_hybrid_backbone(**model_args)
    model_kwargs = dict(
        patch_size=16,
        embed_dim=448,
        depth=depth,
        num_heads=8,
        eta=1.0,
        tokens_norm=True,
        sharpen_attn=False,
        **kwargs,
    )
    model = FAN(sr_ratio=sr_ratio, backbone=backbone, **model_kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["fan_hybrid_base"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model


@ModelCreator.register_model
def fan_large_16_p4_hybrid(pretrained: bool = False, progress: bool = True, **kwargs):
    """
    Constructs FAN-Hybrid-large 224x224 model pretrained on ImageNet-1k.

    .. note::
        FAN-Hybrid-large 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    Raises:
        ValueError: ``pretrained`` is True but no checkpoint URL is registered
            for this model.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fan_large_16_p4_hybrid = flowvision.models.fan_large_16_p4_hybrid(pretrained=False, progress=True)

    """
    # NOTE: the model is built with 22 encoder blocks although its name says 16.
    depth = 22
    sr_ratio = [1] * depth
    model_args = dict(depths=[3, 5], dims=[128, 256, 512, 1024], use_head=False)
    backbone = _create_hybrid_backbone(**model_args)
    model_kwargs = dict(
        patch_size=16,
        embed_dim=480,
        depth=depth,
        num_heads=10,
        eta=1.0,
        tokens_norm=True,
        sharpen_attn=False,
        head_init_scale=0.001,
        **kwargs,
    )
    model = FAN(sr_ratio=sr_ratio, backbone=backbone, **model_kwargs)
    if pretrained:
        url = model_urls["fan_hybrid_large"]
        if url is None:
            # The registry entry is None, so fail with a readable message
            # instead of passing None to the downloader.
            raise ValueError(
                "No pretrained weights are available for fan_large_16_p4_hybrid"
            )
        state_dict = load_state_dict_from_url(url, progress=progress)
        model.load_state_dict(state_dict)
    return model


@ModelCreator.register_model
def fan_base_16_p4_hybrid_in22k_1k(
    pretrained: bool = False, progress: bool = True, **kwargs
):
    """
    Constructs FAN-Hybrid-base 224x224 model pretrained on ImageNet-21k.

    .. note::
        FAN-Hybrid-base 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fan_base_16_p4_hybrid_in22k_1k = flowvision.models.fan_base_16_p4_hybrid_in22k_1k(pretrained=False, progress=True)

    """
    depth = 16
    sr_ratio = [1] * depth
    model_args = dict(depths=[3, 3], dims=[128, 256, 512, 1024], use_head=False)
    backbone = _create_hybrid_backbone(**model_args)
    model_kwargs = dict(
        patch_size=16,
        embed_dim=448,
        depth=depth,
        num_heads=8,
        eta=1.0,
        tokens_norm=True,
        sharpen_attn=False,
        **kwargs,
    )
    model = FAN(sr_ratio=sr_ratio, backbone=backbone, **model_kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["fan_hybrid_base_in22k_1k"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model


@ModelCreator.register_model
def fan_base_16_p4_hybrid_in22k_1k_384(
    pretrained: bool = False, progress: bool = True, **kwargs
):
    """
    Constructs FAN-Hybrid-base 384x384 model pretrained on ImageNet-21k.

    .. note::
        FAN-Hybrid-base 384x384 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fan_base_16_p4_hybrid_in22k_1k_384 = flowvision.models.fan_base_16_p4_hybrid_in22k_1k_384(pretrained=False, progress=True)

    """
    depth = 16
    sr_ratio = [1] * depth
    model_args = dict(depths=[3, 3], dims=[128, 256, 512, 1024], use_head=False)
    backbone = _create_hybrid_backbone(**model_args)
    model_kwargs = dict(
        patch_size=16,
        embed_dim=448,
        depth=depth,
        num_heads=8,
        eta=1.0,
        tokens_norm=True,
        sharpen_attn=False,
        **kwargs,
    )
    model = FAN(img_size=384, sr_ratio=sr_ratio, backbone=backbone, **model_kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["fan_hybrid_base_in22k_1k_384"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model


@ModelCreator.register_model
def fan_large_16_p4_hybrid_in22k_1k(
    pretrained: bool = False, progress: bool = True, **kwargs
):
    """
    Builds the FAN-Hybrid-large 224x224 model (ImageNet-21k pretrained variant).

    .. note::
        FAN-Hybrid-large 224x224 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
        The model is a ``FAN`` with ``depth=22``, ``embed_dim=480``, ``num_heads=10`` and
        ``head_init_scale=0.001`` on top of a backbone produced by ``_create_hybrid_backbone``.
        No ``img_size`` is passed here, so the input resolution is whatever ``FAN`` uses by default.

    Args:
        pretrained (bool): When True, fetch the ``fan_hybrid_large_in22k_1k`` weights and load them into the model. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Additional keyword arguments forwarded to the ``FAN`` constructor.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fan_large_16_p4_hybrid_in22k_1k = flowvision.models.fan_large_16_p4_hybrid_in22k_1k(pretrained=False, progress=True)

    """
    depth = 22
    # Every entry is 1, so the per-layer list is simply uniform.
    sr_ratio = [1] * depth
    hybrid_backbone = _create_hybrid_backbone(
        depths=[3, 5], dims=[128, 256, 512, 1024], use_head=False
    )
    # Built with dict(...) so that a caller kwarg duplicating one of the fixed
    # settings raises TypeError instead of silently overriding it.
    fan_kwargs = dict(
        patch_size=16,
        embed_dim=480,
        depth=depth,
        num_heads=10,
        eta=1.0,
        tokens_norm=True,
        sharpen_attn=False,
        head_init_scale=0.001,
        **kwargs,
    )
    model = FAN(sr_ratio=sr_ratio, backbone=hybrid_backbone, **fan_kwargs)
    if not pretrained:
        return model
    weights = load_state_dict_from_url(
        model_urls["fan_hybrid_large_in22k_1k"], progress=progress
    )
    model.load_state_dict(weights)
    return model
@ModelCreator.register_model
def fan_large_16_p4_hybrid_in22k_1k_384(
    pretrained: bool = False, progress: bool = True, **kwargs
):
    """
    Builds the FAN-Hybrid-large 384x384 model (ImageNet-21k pretrained variant).

    .. note::
        FAN-Hybrid-large 384x384 model from `"Understanding The Robustness in Vision Transformers" <https://arxiv.org/pdf/2204.12451>`_.
        The model is a ``FAN`` with ``img_size=384``, ``depth=22``, ``embed_dim=480``,
        ``num_heads=10`` and ``head_init_scale=0.001`` on top of a backbone produced by
        ``_create_hybrid_backbone``.

    Args:
        pretrained (bool): When True, fetch the ``fan_hybrid_large_in22k_1k_384`` weights and load them into the model. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
        **kwargs: Additional keyword arguments forwarded to the ``FAN`` constructor.

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> fan_large_16_p4_hybrid_in22k_1k_384 = flowvision.models.fan_large_16_p4_hybrid_in22k_1k_384(pretrained=False, progress=True)

    """
    depth = 22
    # Every entry is 1, so the per-layer list is simply uniform.
    sr_ratio = [1] * depth
    hybrid_backbone = _create_hybrid_backbone(
        depths=[3, 5], dims=[128, 256, 512, 1024], use_head=False
    )
    # Built with dict(...) so that a caller kwarg duplicating one of the fixed
    # settings raises TypeError instead of silently overriding it.
    fan_kwargs = dict(
        patch_size=16,
        embed_dim=480,
        depth=depth,
        num_heads=10,
        eta=1.0,
        tokens_norm=True,
        sharpen_attn=False,
        head_init_scale=0.001,
        **kwargs,
    )
    model = FAN(
        img_size=384, sr_ratio=sr_ratio, backbone=hybrid_backbone, **fan_kwargs
    )
    if not pretrained:
        return model
    weights = load_state_dict_from_url(
        model_urls["fan_hybrid_large_in22k_1k_384"], progress=progress
    )
    model.load_state_dict(weights)
    return model