"""
Modified from https://github.com/idstcv/GPU-Efficient-Networks/blob/master/GENet/__init__.py
"""
import uuid
import oneflow as flow
from oneflow import nn
import oneflow.nn.functional as F
import numpy as np
from .registry import ModelCreator
from .utils import load_state_dict_from_url
# Download URLs of the pretrained GENet checkpoints (hosted on Aliyun OSS),
# keyed by the model-creator name registered with ModelCreator.
model_urls = {
    "genet_small": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/GENet/GENet_small.zip",
    "genet_normal": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/GENet/GENet_normal.zip",
    "genet_large": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/GENet/GENet_large.zip",
}
def _fuse_convkx_and_bn_(convkx, bn):
    """Fold the statistics of BatchNorm2d ``bn`` into conv ``convkx``, in place.

    The conv weight absorbs the BN scale, and ``bn`` is rewritten so that it
    reduces to ``y = x + bias``.  The shifted BN bias is also attached to the
    conv as a bias parameter, so the BN layer can then be dropped from the
    sequence entirely (see remove_bn_in_superblock / fuse_bn).
    """
    # scale = gamma / sqrt(running_var + eps); broadcast over out-channels.
    the_weight_scale = bn.weight / flow.sqrt(bn.running_var + bn.eps)
    convkx.weight[:] = convkx.weight * the_weight_scale.view((-1, 1, 1, 1))
    the_bias_shift = (bn.weight * bn.running_mean) / flow.sqrt(bn.running_var + bn.eps)
    # Neutralize the BN: weight=1, mean=0, var=1-eps -> sqrt(var + eps) == 1,
    # so if the BN were kept it would only add its (shifted) bias.
    bn.weight[:] = 1
    bn.bias[:] = bn.bias - the_bias_shift
    bn.running_var[:] = 1.0 - bn.eps
    bn.running_mean[:] = 0.0
    # Move the fused bias onto the conv (original conv was created bias=False),
    # so removing the BN afterwards preserves the output.
    convkx.bias = nn.Parameter(bn.bias)
def _fuse_and_strip_bn_in_seq_list_(seq_module_list):
    """Fuse each BatchNorm2d in a list of nn.Sequential into its preceding
    layer and return a new nn.ModuleList without the BN layers.

    Each BN is folded (via _fuse_convkx_and_bn_) into the block that
    immediately precedes it in the same nn.Sequential, then dropped.
    """
    new_list = []
    for the_seq in seq_module_list:
        assert isinstance(the_seq, nn.Sequential)
        kept_blocks = []
        last_block = None
        for block in the_seq:
            if isinstance(block, nn.BatchNorm2d):
                # Fold the BN into the layer right before it; the BN itself
                # is not appended, i.e. it is removed from the sequence.
                _fuse_convkx_and_bn_(last_block, block)
            else:
                kept_blocks.append(block)
            last_block = block
        new_list.append(nn.Sequential(*kept_blocks))
    return nn.ModuleList(new_list)


def remove_bn_in_superblock(super_block):
    """Fuse and remove every BatchNorm2d inside a super block, in place.

    Applies the same BN-folding pass to both the shortcut branches
    (``shortcut_list``) and the residual branches (``conv_list``) of a
    SuperRes* block.  The original code duplicated the loop for the two
    lists; it is factored into _fuse_and_strip_bn_in_seq_list_ here.
    """
    super_block.shortcut_list = _fuse_and_strip_bn_in_seq_list_(
        super_block.shortcut_list
    )
    super_block.conv_list = _fuse_and_strip_bn_in_seq_list_(super_block.conv_list)
def fuse_bn(model):
    """Fuse all standalone BN blocks of ``model`` into their preceding blocks
    and strip them from ``model.block_list``, in place.  Returns ``model``.

    Pass 1: every top-level BN block is folded into the ``netblock`` of the
    most recent non-BN block and dropped from the list.
    Pass 2: blocks that expose a ``shortcut_list`` (the SuperRes* blocks)
    get their internal BNs fused via remove_bn_in_superblock.
    """
    the_block_list = model.block_list
    last_block = the_block_list[0]
    new_block_list = [last_block]
    for the_block in the_block_list[1:]:
        if isinstance(the_block, BN):
            # Fold this BN into the conv of the last kept block; the BN is
            # not appended to the new list.
            _fuse_convkx_and_bn_(last_block.netblock, the_block.netblock)
        else:
            new_block_list.append(the_block)
            last_block = the_block
    # Fuse the BNs hidden inside super blocks.
    for the_block in new_block_list:
        if hasattr(the_block, "shortcut_list"):
            remove_bn_in_superblock(the_block)
    model.block_list = new_block_list
    model.module_list = nn.ModuleList(new_block_list)
    return model
def _create_netblock_list_from_str_(s, no_create=False):
    """Parse a sequence of serialized block definitions from the head of ``s``.

    Repeatedly matches the start of ``s`` against the registered class names
    in ``_all_netblocks_dict_`` (defined elsewhere in this module) and
    delegates parsing to that class's ``create_from_str``.  Returns
    ``(block_list, remaining_s)``: parsing stops at the first top-level ';'
    (the tail after it is returned), or when ``s`` is fully consumed
    (returns "").  Raises AssertionError when the head of ``s`` matches no
    registered class name.

    NOTE(review): the match is a plain startswith over dict order, so a name
    that is a prefix of another (e.g. "SuperResK1KX" vs "SuperResK1KXK1")
    could be tried first; the inner create_from_str assert ("Name(" prefix)
    would then fire.  Presumably dict insertion order avoids this — confirm.
    """
    block_list = []
    while len(s) > 0:
        is_found_block_class = False
        for the_block_class_name in _all_netblocks_dict_.keys():
            if s.startswith(the_block_class_name):
                is_found_block_class = True
                the_block_class = _all_netblocks_dict_[the_block_class_name]
                the_block, remaining_s = the_block_class.create_from_str(
                    s, no_create=no_create
                )
                if the_block is not None:
                    block_list.append(the_block)
                s = remaining_s
                # ';' terminates this block list; hand back the tail.
                if len(s) > 0 and s[0] == ";":
                    return block_list, s[1:]
                break
            pass
        pass
        assert is_found_block_class
        pass
    return block_list, ""
def _get_right_parentheses_index_(s):
left_paren_count = 0
for index, x in enumerate(s):
if x == "(":
left_paren_count += 1
elif x == ")":
left_paren_count -= 1
if left_paren_count == 0:
return index
else:
pass
return None
class PlainNetBasicBlockClass(nn.Module):
    """Base class for all plain-net blocks.

    Behaves as an identity op while recording in/out channel counts, stride
    and a block name.  Serialized form:
    ``PlainNetBasicBlockClass([name|]in_channels,out_channels,stride)``.
    """

    def __init__(
        self,
        in_channels=0,
        out_channels=0,
        stride=1,
        no_create=False,
        block_name=None,
        **kwargs
    ):
        super(PlainNetBasicBlockClass, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.no_create = no_create
        self.block_name = block_name

    def forward(self, x):
        # The base block performs no computation.
        return x

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse the serialized form at the head of ``s``.

        Returns (block, remainder of ``s`` after the closing parenthesis).
        """
        assert PlainNetBasicBlockClass.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("PlainNetBasicBlockClass(") : idx]
        # Optional "name|" prefix; generate a unique name when absent.
        sep = param_str.find("|")
        if sep < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[:sep]
            param_str = param_str[sep + 1 :]
        fields = param_str.split(",")
        block = PlainNetBasicBlockClass(
            in_channels=int(fields[0]),
            out_channels=int(fields[1]),
            stride=int(fields[2]),
            block_name=tmp_block_name,
            no_create=no_create,
        )
        return block, s[idx + 1 :]

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized PlainNetBasicBlockClass."""
        return s.startswith("PlainNetBasicBlockClass(") and s[-1] == ")"
class AdaptiveAvgPool(PlainNetBasicBlockClass):
    """Adaptive average-pooling block.

    ``out_channels`` reports the flattened feature count, i.e.
    channels * output_size ** 2.  Serialized form:
    ``AdaptiveAvgPool([name|]out_channels,output_size)``.
    """

    def __init__(
        self, out_channels, output_size, no_create=False, block_name=None, **kwargs
    ):
        super(AdaptiveAvgPool, self).__init__(**kwargs)
        self.in_channels = out_channels
        # After pooling + flattening, each channel contributes output_size^2 values.
        self.out_channels = out_channels * output_size ** 2
        self.output_size = output_size
        self.block_name = block_name
        if not no_create:
            self.netblock = nn.AdaptiveAvgPool2d(
                output_size=(self.output_size, self.output_size)
            )

    def forward(self, x):
        return self.netblock(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse the serialized form at the head of ``s`` -> (block, rest)."""
        assert AdaptiveAvgPool.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("AdaptiveAvgPool(") : idx]
        # Optional "name|" prefix; generate a unique name when absent.
        sep = param_str.find("|")
        if sep < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[:sep]
            param_str = param_str[sep + 1 :]
        fields = param_str.split(",")
        block = AdaptiveAvgPool(
            out_channels=int(fields[0]),
            output_size=int(fields[1]),
            block_name=tmp_block_name,
            no_create=no_create,
        )
        return block, s[idx + 1 :]

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized AdaptiveAvgPool."""
        return s.startswith("AdaptiveAvgPool(") and s[-1] == ")"
class BN(PlainNetBasicBlockClass):
    """BatchNorm2d block; may wrap an existing layer via ``copy_from``.

    Serialized form: ``BN([name|]out_channels)``.
    """

    def __init__(
        self,
        out_channels=None,
        copy_from=None,
        no_create=False,
        block_name=None,
        **kwargs
    ):
        super(BN, self).__init__(**kwargs)
        self.block_name = block_name
        if copy_from is None:
            self.in_channels = out_channels
            self.out_channels = out_channels
            if not no_create:
                self.netblock = nn.BatchNorm2d(num_features=self.out_channels)
        else:
            # Adopt an existing BatchNorm2d; infer channels from its weight.
            assert isinstance(copy_from, nn.BatchNorm2d)
            num_features = copy_from.weight.shape[0]
            self.in_channels = num_features
            self.out_channels = num_features
            assert out_channels is None or out_channels == self.out_channels
            self.netblock = copy_from

    def forward(self, x):
        return self.netblock(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse the serialized form at the head of ``s`` -> (block, rest)."""
        assert BN.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("BN(") : idx]
        # Optional "name|" prefix; generate a unique name when absent.
        sep = param_str.find("|")
        if sep < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[:sep]
            param_str = param_str[sep + 1 :]
        block = BN(
            out_channels=int(param_str),
            block_name=tmp_block_name,
            no_create=no_create,
        )
        return block, s[idx + 1 :]

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized BN block."""
        return s.startswith("BN(") and s[-1] == ")"
class ConvDW(PlainNetBasicBlockClass):
    """Depthwise KxK convolution block (groups == channels, no bias).

    Serialized form: ``ConvDW([name|]out_channels,kernel_size,stride)``.
    """

    def __init__(
        self,
        out_channels=None,
        kernel_size=None,
        stride=None,
        copy_from=None,
        no_create=False,
        block_name=None,
        **kwargs
    ):
        super(ConvDW, self).__init__(**kwargs)
        self.block_name = block_name
        self.use_weight_mean_zero_constrain = False
        if copy_from is not None:
            # Adopt an existing depthwise conv; explicit args must agree.
            assert isinstance(copy_from, nn.Conv2d)
            self.in_channels = copy_from.in_channels
            self.out_channels = copy_from.out_channels
            self.kernel_size = copy_from.kernel_size[0]
            self.stride = copy_from.stride[0]
            # Depthwise conv preserves the channel count.
            assert self.in_channels == self.out_channels
            assert out_channels is None or out_channels == self.out_channels
            assert kernel_size is None or kernel_size == self.kernel_size
            assert stride is None or stride == self.stride
            self.netblock = copy_from
        else:
            self.in_channels = out_channels
            self.out_channels = out_channels
            self.stride = stride
            self.kernel_size = kernel_size
            # "Same" padding at stride 1.
            self.padding = (self.kernel_size - 1) // 2
            skip_create = no_create or 0 in (
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
            )
            if not skip_create:
                self.netblock = nn.Conv2d(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    padding=self.padding,
                    bias=False,
                    groups=self.in_channels,
                )

    def forward(self, x):
        return self.netblock(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse the serialized form at the head of ``s`` -> (block, rest)."""
        assert ConvDW.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("ConvDW(") : idx]
        # Optional "name|" prefix; generate a unique name when absent.
        sep = param_str.find("|")
        if sep < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[:sep]
            param_str = param_str[sep + 1 :]
        fields = param_str.split(",")
        block = ConvDW(
            out_channels=int(fields[0]),
            kernel_size=int(fields[1]),
            stride=int(fields[2]),
            no_create=no_create,
            block_name=tmp_block_name,
        )
        return block, s[idx + 1 :]

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized ConvDW block."""
        return s.startswith("ConvDW(") and s[-1] == ")"
class ConvKX(PlainNetBasicBlockClass):
    """Plain KxK convolution block (no bias, "same" padding at stride 1).

    Serialized form:
    ``ConvKX([name|]in_channels,out_channels,kernel_size,stride)``.
    """

    def __init__(
        self,
        in_channels=None,
        out_channels=None,
        kernel_size=None,
        stride=None,
        copy_from=None,
        no_create=False,
        block_name=None,
        **kwargs
    ):
        super(ConvKX, self).__init__(**kwargs)
        self.block_name = block_name
        self.use_weight_mean_zero_constrain = False
        if copy_from is not None:
            # Adopt an existing conv; explicit args, when given, must agree.
            assert isinstance(copy_from, nn.Conv2d)
            self.in_channels = copy_from.in_channels
            self.out_channels = copy_from.out_channels
            self.kernel_size = copy_from.kernel_size[0]
            self.stride = copy_from.stride[0]
            assert in_channels is None or in_channels == self.in_channels
            assert out_channels is None or out_channels == self.out_channels
            assert kernel_size is None or kernel_size == self.kernel_size
            assert stride is None or stride == self.stride
            self.netblock = copy_from
        else:
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.stride = stride
            self.kernel_size = kernel_size
            # "Same" padding at stride 1.
            self.padding = (self.kernel_size - 1) // 2
            skip_create = no_create or 0 in (
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
            )
            if not skip_create:
                self.netblock = nn.Conv2d(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    padding=self.padding,
                    bias=False,
                )

    def forward(self, x):
        return self.netblock(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse the serialized form at the head of ``s`` -> (block, rest)."""
        assert ConvKX.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("ConvKX(") : idx]
        # Optional "name|" prefix; generate a unique name when absent.
        sep = param_str.find("|")
        if sep < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[:sep]
            param_str = param_str[sep + 1 :]
        fields = param_str.split(",")
        block = ConvKX(
            in_channels=int(fields[0]),
            out_channels=int(fields[1]),
            kernel_size=int(fields[2]),
            stride=int(fields[3]),
            no_create=no_create,
            block_name=tmp_block_name,
        )
        return block, s[idx + 1 :]

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized ConvKX block."""
        return s.startswith("ConvKX(") and s[-1] == ")"
class Flatten(PlainNetBasicBlockClass):
    """Block that flattens everything after the batch dimension.

    Serialized form: ``Flatten([name|]out_channels)``.
    """

    def __init__(self, out_channels, no_create=False, block_name=None, **kwargs):
        super(Flatten, self).__init__(**kwargs)
        self.block_name = block_name
        # Flattening does not change the reported channel count.
        self.in_channels = out_channels
        self.out_channels = out_channels

    def forward(self, x):
        return flow.flatten(x, 1)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse the serialized form at the head of ``s`` -> (block, rest)."""
        assert Flatten.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("Flatten(") : idx]
        # Optional "name|" prefix; generate a unique name when absent.
        sep = param_str.find("|")
        if sep < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[:sep]
            param_str = param_str[sep + 1 :]
        block = Flatten(
            out_channels=int(param_str),
            no_create=no_create,
            block_name=tmp_block_name,
        )
        return block, s[idx + 1 :]

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized Flatten block."""
        return s.startswith("Flatten(") and s[-1] == ")"
class Linear(PlainNetBasicBlockClass):
    """Fully-connected block wrapping ``nn.Linear``.

    Serialized form: ``Linear([name|]in_channels,out_channels,bias)`` with
    ``bias`` encoded as 0/1.
    """

    def __init__(
        self,
        in_channels=None,
        out_channels=None,
        bias=None,
        copy_from=None,
        no_create=False,
        block_name=None,
        **kwargs
    ):
        super(Linear, self).__init__(**kwargs)
        self.block_name = block_name
        if copy_from is not None:
            assert isinstance(copy_from, nn.Linear)
            # BUGFIX: nn.Linear exposes `in_features` / `out_features` (the
            # original read non-existent .in_channels/.out_channels), and its
            # `bias` attribute is a parameter tensor or None — store the
            # boolean flag, not the tensor, so the assert below is meaningful.
            self.in_channels = copy_from.in_features
            self.out_channels = copy_from.out_features
            self.bias = copy_from.bias is not None
            assert in_channels is None or in_channels == self.in_channels
            assert out_channels is None or out_channels == self.out_channels
            assert bias is None or bias == self.bias
            self.netblock = copy_from
        else:
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.bias = bias
            if not no_create:
                self.netblock = nn.Linear(
                    self.in_channels, self.out_channels, bias=self.bias
                )

    def forward(self, x):
        return self.netblock(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse the serialized form at the head of ``s`` -> (block, rest)."""
        assert Linear.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("Linear(") : idx]
        # Optional "name|" prefix; generate a unique name when absent.
        tmp_idx = param_str.find("|")
        if tmp_idx < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1 :]
        split_str = param_str.split(",")
        in_channels = int(split_str[0])
        out_channels = int(split_str[1])
        bias = int(split_str[2])
        return (
            Linear(
                in_channels=in_channels,
                out_channels=out_channels,
                bias=bias == 1,
                block_name=tmp_block_name,
                no_create=no_create,
            ),
            s[idx + 1 :],
        )

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized Linear block."""
        return s.startswith("Linear(") and s[-1] == ")"
class MaxPool(PlainNetBasicBlockClass):
    """Max-pooling block with padding of (kernel_size - 1) // 2.

    Serialized form: ``MaxPool([name|]out_channels,kernel_size,stride)``.
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        stride,
        no_create=False,
        block_name=None,
        **kwargs
    ):
        super(MaxPool, self).__init__(**kwargs)
        self.block_name = block_name
        # Pooling keeps the channel count.
        self.in_channels = out_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = (kernel_size - 1) // 2
        if not no_create:
            self.netblock = nn.MaxPool2d(
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding
            )

    def forward(self, x):
        return self.netblock(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse the serialized form at the head of ``s`` -> (block, rest)."""
        assert MaxPool.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("MaxPool(") : idx]
        # Optional "name|" prefix; generate a unique name when absent.
        sep = param_str.find("|")
        if sep < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[:sep]
            param_str = param_str[sep + 1 :]
        fields = param_str.split(",")
        block = MaxPool(
            out_channels=int(fields[0]),
            kernel_size=int(fields[1]),
            stride=int(fields[2]),
            no_create=no_create,
            block_name=tmp_block_name,
        )
        return block, s[idx + 1 :]

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized MaxPool block."""
        return s.startswith("MaxPool(") and s[-1] == ")"
class MultiSumBlock(PlainNetBasicBlockClass):
    """Parallel block whose forward sums the outputs of all inner branches.

    Serialized form: ``MultiSumBlock([name|]branch1;branch2;...)`` where each
    branch is itself a serialized block sequence.
    """

    def __init__(self, inner_block_list, no_create=False, block_name=None, **kwargs):
        super(MultiSumBlock, self).__init__(**kwargs)
        self.block_name = block_name
        self.inner_block_list = inner_block_list
        if not no_create:
            # Register branches so their parameters are tracked by the framework.
            self.inner_module_list = nn.ModuleList(inner_block_list)
        # Branch outputs are summed, so report the max channel counts.
        self.in_channels = np.max([x.in_channels for x in inner_block_list])
        self.out_channels = np.max([x.out_channels for x in inner_block_list])
        # Derive the overall stride by probing the first branch with a nominal
        # 1024-pixel resolution.
        # NOTE(review): get_output_resolution() is not defined anywhere in this
        # chunk -- presumably implemented by the concrete block classes
        # elsewhere in the module; confirm.
        res = 1024
        res = self.inner_block_list[0].get_output_resolution(res)
        self.stride = 1024 // res

    def forward(self, x):
        """Feed ``x`` to every branch and return the elementwise sum."""
        output = self.inner_block_list[0](x)
        for inner_block in self.inner_block_list[1:]:
            output2 = inner_block(x)
            output = output + output2
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse the serialized form at the head of ``s`` -> (block, rest).

        Branches are split on top-level ';' by _create_netblock_list_from_str_;
        a multi-block branch is wrapped in a Sequential.  Returns (None, rest)
        when no branch could be parsed.
        """
        assert MultiSumBlock.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("MultiSumBlock(") : idx]
        # find block_name
        tmp_idx = param_str.find("|")
        if tmp_idx < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1 :]
        the_s = param_str
        the_inner_block_list = []
        while len(the_s) > 0:
            tmp_block_list, remaining_s = _create_netblock_list_from_str_(
                the_s, no_create=no_create
            )
            the_s = remaining_s
            if tmp_block_list is None:
                pass
            elif len(tmp_block_list) == 1:
                the_inner_block_list.append(tmp_block_list[0])
            else:
                # Multi-block branch: wrap it so it acts as a single module.
                the_inner_block_list.append(
                    Sequential(inner_block_list=tmp_block_list, no_create=no_create)
                )
            pass  # end while
        if len(the_inner_block_list) == 0:
            return None, s[idx + 1 :]
        return (
            MultiSumBlock(
                inner_block_list=the_inner_block_list,
                block_name=tmp_block_name,
                no_create=no_create,
            ),
            s[idx + 1 :],
        )

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized MultiSumBlock."""
        if s.startswith("MultiSumBlock(") and s[-1] == ")":
            return True
        else:
            return False
class RELU(PlainNetBasicBlockClass):
    """Stateless ReLU activation block.

    Serialized form: ``RELU([name|]out_channels)``.
    """

    def __init__(self, out_channels, no_create=False, block_name=None, **kwargs):
        super(RELU, self).__init__(**kwargs)
        self.block_name = block_name
        # The activation keeps the channel count.
        self.in_channels = out_channels
        self.out_channels = out_channels

    def forward(self, x):
        return F.relu(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse the serialized form at the head of ``s`` -> (block, rest)."""
        assert RELU.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("RELU(") : idx]
        # Optional "name|" prefix; generate a unique name when absent.
        sep = param_str.find("|")
        if sep < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[:sep]
            param_str = param_str[sep + 1 :]
        block = RELU(
            out_channels=int(param_str),
            no_create=no_create,
            block_name=tmp_block_name,
        )
        return block, s[idx + 1 :]

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized RELU block."""
        return s.startswith("RELU(") and s[-1] == ")"
class ResBlock(PlainNetBasicBlockClass):
    """Residual wrapper around an inner block sequence.

    Serialized form: ``ResBlock([name|][in_channels,[stride,]]inner_blocks)``.
    If in_channels is missing, ``inner_block_list[0].in_channels`` is used;
    if stride is missing, it is inferred by probing the inner blocks with a
    nominal 1024-pixel input.
    """

    def __init__(
        self,
        inner_block_list,
        in_channels=None,
        stride=None,
        no_create=False,
        block_name=None,
        **kwargs
    ):
        super(ResBlock, self).__init__(**kwargs)
        self.block_name = block_name
        self.inner_block_list = inner_block_list
        self.stride = stride
        if not no_create:
            # Register inner blocks so their parameters are tracked.
            self.inner_module_list = nn.ModuleList(inner_block_list)
        if in_channels is None:
            self.in_channels = inner_block_list[0].in_channels
        else:
            self.in_channels = in_channels
        self.out_channels = max(self.in_channels, inner_block_list[-1].out_channels)
        if self.stride is None:
            # NOTE(review): get_output_resolution() is not defined in this
            # chunk -- presumably provided elsewhere in the module; confirm.
            tmp_input_res = 1024
            tmp_output_res = self.get_output_resolution(tmp_input_res)
            self.stride = tmp_input_res // tmp_output_res

    def forward(self, x):
        """Return inner(x) + shortcut(x).

        The shortcut is the identity at stride 1, otherwise an average pool
        with kernel stride+1 so the shortcut matches the branch resolution.
        With an empty inner list the block reduces to the shortcut alone.
        """
        if self.stride > 1:
            downsampled_x = F.avg_pool2d(
                x,
                kernel_size=self.stride + 1,
                stride=self.stride,
                padding=self.stride // 2,
            )
        else:
            downsampled_x = x
        if len(self.inner_block_list) == 0:
            return downsampled_x
        output = x
        for inner_block in self.inner_block_list:
            output = inner_block(output)
        output = output + downsampled_x
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse the serialized form at the head of ``s`` -> (block, rest).

        Both the leading in_channels and stride fields are optional; a field
        is treated as present only when the text before the next comma is all
        digits.  Returns (None, rest) when no inner block could be parsed.
        """
        assert ResBlock.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        the_stride = None
        param_str = s[len("ResBlock(") : idx]
        # find block_name
        tmp_idx = param_str.find("|")
        if tmp_idx < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1 :]
        first_comma_index = param_str.find(",")
        if (
            first_comma_index < 0 or not param_str[0:first_comma_index].isdigit()
        ):  # cannot parse in_channels, missing, use default
            in_channels = None
            the_inner_block_list, remaining_s = _create_netblock_list_from_str_(
                param_str, no_create=no_create
            )
        else:
            in_channels = int(param_str[0:first_comma_index])
            param_str = param_str[first_comma_index + 1 :]
            # A second leading integer, when present, is the explicit stride.
            second_comma_index = param_str.find(",")
            if second_comma_index < 0 or not param_str[0:second_comma_index].isdigit():
                the_inner_block_list, remaining_s = _create_netblock_list_from_str_(
                    param_str, no_create=no_create
                )
            else:
                the_stride = int(param_str[0:second_comma_index])
                param_str = param_str[second_comma_index + 1 :]
                the_inner_block_list, remaining_s = _create_netblock_list_from_str_(
                    param_str, no_create=no_create
                )
            pass
        pass
        assert len(remaining_s) == 0
        if the_inner_block_list is None or len(the_inner_block_list) == 0:
            return None, s[idx + 1 :]
        return (
            ResBlock(
                inner_block_list=the_inner_block_list,
                in_channels=in_channels,
                stride=the_stride,
                no_create=no_create,
                block_name=tmp_block_name,
            ),
            s[idx + 1 :],
        )

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized ResBlock."""
        if s.startswith("ResBlock(") and s[-1] == ")":
            return True
        else:
            return False
class Sequential(PlainNetBasicBlockClass):
    """Sequential container block: applies the inner blocks in order."""

    def __init__(self, inner_block_list, no_create=False, block_name=None, **kwargs):
        super(Sequential, self).__init__(**kwargs)
        self.block_name = block_name
        self.inner_block_list = inner_block_list
        if not no_create:
            # Register inner blocks so their parameters are tracked.
            self.inner_module_list = nn.ModuleList(inner_block_list)
        self.in_channels = inner_block_list[0].in_channels
        self.out_channels = inner_block_list[-1].out_channels
        # Total stride = product of inner strides, derived by probing a
        # nominal 1024-pixel input.
        # NOTE(review): get_output_resolution() is not defined in this chunk
        # -- presumably provided elsewhere in the module; confirm.
        res = 1024
        for block in self.inner_block_list:
            res = block.get_output_resolution(res)
        self.stride = 1024 // res

    def forward(self, x):
        """Apply each inner block in order and return the final output."""
        output = x
        for inner_block in self.inner_block_list:
            output = inner_block(output)
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse "Sequential(...)" and return (block, ""): the parser always
        consumes the remainder of the string.

        NOTE(review): the ``+ 1`` below skips one extra character after
        "Sequential(", unlike every other block parser in this file --
        presumably the serialized format puts a separator there; confirm
        against the string generator before changing.
        """
        assert Sequential.is_instance_from_str(s)
        the_right_paraen_idx = _get_right_parentheses_index_(s)
        param_str = s[len("Sequential(") + 1 : the_right_paraen_idx]
        # find block_name
        tmp_idx = param_str.find("|")
        if tmp_idx < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1 :]
        the_inner_block_list, remaining_s = _create_netblock_list_from_str_(
            param_str, no_create=no_create
        )
        assert len(remaining_s) == 0
        if the_inner_block_list is None or len(the_inner_block_list) == 0:
            return None, ""
        return (
            Sequential(
                inner_block_list=the_inner_block_list,
                no_create=no_create,
                block_name=tmp_block_name,
            ),
            "",
        )

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` starts a serialized Sequential.

        NOTE(review): unlike sibling classes, no trailing ')' check --
        consistent with this parser consuming the whole string.
        """
        if s.startswith("Sequential("):
            return True
        else:
            return False
"""
Super Blocks
"""
class SuperResKXKX(PlainNetBasicBlockClass):
    """Super residual block of ``sublayers`` stacked KxK-KxK residual units.

    Sub-layer 0 maps in_channels -> out_channels at ``stride``; every later
    sub-layer is a stride-1, out_channels -> out_channels unit.  Each unit is
    conv(KxK)-BN-ReLU-conv(KxK)-BN plus a shortcut (identity when the shape
    is preserved, 1x1 conv + BN otherwise), with a ReLU after the sum.
    """

    def __init__(
        self,
        in_channels=0,
        out_channels=0,
        kernel_size=3,
        stride=1,
        expansion=1.0,
        sublayers=1,
        no_create=False,
        block_name=None,
        **kwargs
    ):
        super(SuperResKXKX, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.expansion = expansion
        self.stride = stride
        self.sublayers = sublayers
        self.no_create = no_create
        self.block_name = block_name
        # conv_list[i] is the residual branch of sub-layer i, shortcut_list[i]
        # its skip path.  Append order fixes the ModuleList indices and hence
        # the parameter names in the state dict -- do not reorder.
        self.shortcut_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        for layerID in range(self.sublayers):
            if layerID == 0:
                # Only the first sub-layer changes channels / resolution.
                current_in_channels = self.in_channels
                current_out_channels = self.out_channels
                current_stride = self.stride
                current_kernel_size = self.kernel_size
            else:
                current_in_channels = self.out_channels
                current_out_channels = self.out_channels
                current_stride = 1
                current_kernel_size = self.kernel_size
            # Width of the intermediate (expanded) feature map.
            current_expansion_channel = int(
                round(current_out_channels * self.expansion)
            )
            the_conv_block = nn.Sequential(
                nn.Conv2d(
                    current_in_channels,
                    current_expansion_channel,
                    kernel_size=current_kernel_size,
                    stride=current_stride,
                    padding=(current_kernel_size - 1) // 2,
                    bias=False,
                ),
                nn.BatchNorm2d(current_expansion_channel),
                nn.ReLU(),
                nn.Conv2d(
                    current_expansion_channel,
                    current_out_channels,
                    kernel_size=current_kernel_size,
                    stride=1,
                    padding=(current_kernel_size - 1) // 2,
                    bias=False,
                ),
                nn.BatchNorm2d(current_out_channels),
            )
            self.conv_list.append(the_conv_block)
            if current_stride == 1 and current_in_channels == current_out_channels:
                # Identity shortcut when the shape is unchanged.
                shortcut = nn.Sequential()
            else:
                # Projection shortcut: 1x1 conv matches channels and stride.
                shortcut = nn.Sequential(
                    nn.Conv2d(
                        current_in_channels,
                        current_out_channels,
                        kernel_size=1,
                        stride=current_stride,
                        padding=0,
                        bias=False,
                    ),
                    nn.BatchNorm2d(current_out_channels),
                )
            self.shortcut_list.append(shortcut)
            pass  # end for

    def forward(self, x):
        """Apply the sub-layers sequentially: relu(conv(x) + shortcut(x))."""
        output = x
        for block, shortcut in zip(self.conv_list, self.shortcut_list):
            conv_output = block(output)
            output = conv_output + shortcut(output)
            output = F.relu(output)
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse
        ``SuperResKXKX([name|]in,out,kernel,stride,expansion,sublayers)``
        and return (block, remainder of ``s``)."""
        assert SuperResKXKX.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("SuperResKXKX(") : idx]
        # find block_name
        tmp_idx = param_str.find("|")
        if tmp_idx < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1 :]
        param_str_split = param_str.split(",")
        in_channels = int(param_str_split[0])
        out_channels = int(param_str_split[1])
        kernel_size = int(param_str_split[2])
        stride = int(param_str_split[3])
        expansion = float(param_str_split[4])
        sublayers = int(param_str_split[5])
        return (
            SuperResKXKX(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                expansion=expansion,
                sublayers=sublayers,
                block_name=tmp_block_name,
                no_create=no_create,
            ),
            s[idx + 1 :],
        )

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized SuperResKXKX."""
        if s.startswith("SuperResKXKX(") and s[-1] == ")":
            return True
        else:
            return False
class SuperResK1KX(PlainNetBasicBlockClass):
    """Super residual block of ``sublayers`` stacked 1x1-KxK residual units.

    Sub-layer 0 maps in_channels -> out_channels at ``stride``; every later
    sub-layer is a stride-1, out_channels -> out_channels unit.  Each unit is
    conv(1x1)-BN-ReLU-conv(KxK, strided)-BN plus a shortcut (identity when
    the shape is preserved, 1x1 conv + BN otherwise), with a ReLU after the
    sum.
    """

    def __init__(
        self,
        in_channels=0,
        out_channels=0,
        kernel_size=3,
        stride=1,
        expansion=1.0,
        sublayers=1,
        no_create=False,
        block_name=None,
        **kwargs
    ):
        super(SuperResK1KX, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.expansion = expansion
        self.stride = stride
        self.sublayers = sublayers
        self.no_create = no_create
        self.block_name = block_name
        # conv_list[i] is the residual branch of sub-layer i, shortcut_list[i]
        # its skip path.  Append order fixes the ModuleList indices and hence
        # the parameter names in the state dict -- do not reorder.
        self.shortcut_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        for layerID in range(self.sublayers):
            if layerID == 0:
                # Only the first sub-layer changes channels / resolution.
                current_in_channels = self.in_channels
                current_out_channels = self.out_channels
                current_stride = self.stride
                current_kernel_size = self.kernel_size
            else:
                current_in_channels = self.out_channels
                current_out_channels = self.out_channels
                current_stride = 1
                current_kernel_size = self.kernel_size
            # Width of the intermediate (expanded) feature map.
            current_expansion_channel = int(
                round(current_out_channels * self.expansion)
            )
            the_conv_block = nn.Sequential(
                nn.Conv2d(
                    current_in_channels,
                    current_expansion_channel,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                ),
                nn.BatchNorm2d(current_expansion_channel),
                nn.ReLU(),
                nn.Conv2d(
                    current_expansion_channel,
                    current_out_channels,
                    kernel_size=current_kernel_size,
                    stride=current_stride,
                    padding=(current_kernel_size - 1) // 2,
                    bias=False,
                ),
                nn.BatchNorm2d(current_out_channels),
            )
            self.conv_list.append(the_conv_block)
            if current_stride == 1 and current_in_channels == current_out_channels:
                # Identity shortcut when the shape is unchanged.
                shortcut = nn.Sequential()
            else:
                # Projection shortcut: 1x1 conv matches channels and stride.
                shortcut = nn.Sequential(
                    nn.Conv2d(
                        current_in_channels,
                        current_out_channels,
                        kernel_size=1,
                        stride=current_stride,
                        padding=0,
                        bias=False,
                    ),
                    nn.BatchNorm2d(current_out_channels),
                )
            self.shortcut_list.append(shortcut)
            pass  # end for

    def forward(self, x):
        """Apply the sub-layers sequentially: relu(conv(x) + shortcut(x))."""
        output = x
        for block, shortcut in zip(self.conv_list, self.shortcut_list):
            conv_output = block(output)
            output = conv_output + shortcut(output)
            output = F.relu(output)
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse
        ``SuperResK1KX([name|]in,out,kernel,stride,expansion,sublayers)``
        and return (block, remainder of ``s``)."""
        assert SuperResK1KX.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("SuperResK1KX(") : idx]
        # find block_name
        tmp_idx = param_str.find("|")
        if tmp_idx < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1 :]
        param_str_split = param_str.split(",")
        in_channels = int(param_str_split[0])
        out_channels = int(param_str_split[1])
        kernel_size = int(param_str_split[2])
        stride = int(param_str_split[3])
        expansion = float(param_str_split[4])
        sublayers = int(param_str_split[5])
        return (
            SuperResK1KX(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                expansion=expansion,
                sublayers=sublayers,
                block_name=tmp_block_name,
                no_create=no_create,
            ),
            s[idx + 1 :],
        )

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized SuperResK1KX."""
        if s.startswith("SuperResK1KX(") and s[-1] == ")":
            return True
        else:
            return False
class SuperResK1KXK1(PlainNetBasicBlockClass):
    """Super residual block of ``sublayers`` stacked 1x1-KxK-1x1 bottleneck
    units.

    Sub-layer 0 maps in_channels -> out_channels at ``stride``; every later
    sub-layer is a stride-1, out_channels -> out_channels unit.  Each unit is
    conv(1x1)-BN-ReLU-conv(KxK, strided)-BN-ReLU-conv(1x1)-BN plus a shortcut
    (identity when the shape is preserved, 1x1 conv + BN otherwise), with a
    ReLU after the sum.
    """

    def __init__(
        self,
        in_channels=0,
        out_channels=0,
        kernel_size=3,
        stride=1,
        expansion=1.0,
        sublayers=1,
        no_create=False,
        block_name=None,
        **kwargs
    ):
        super(SuperResK1KXK1, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.expansion = expansion
        self.stride = stride
        self.sublayers = sublayers
        self.no_create = no_create
        self.block_name = block_name
        # conv_list[i] is the residual branch of sub-layer i, shortcut_list[i]
        # its skip path.  Append order fixes the ModuleList indices and hence
        # the parameter names in the state dict -- do not reorder.
        self.shortcut_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        for layerID in range(self.sublayers):
            if layerID == 0:
                # Only the first sub-layer changes channels / resolution.
                current_in_channels = self.in_channels
                current_out_channels = self.out_channels
                current_stride = self.stride
                current_kernel_size = self.kernel_size
            else:
                current_in_channels = self.out_channels
                current_out_channels = self.out_channels
                current_stride = 1
                current_kernel_size = self.kernel_size
            # Width of the bottleneck's middle (expanded) feature map.
            current_expansion_channel = int(
                round(current_out_channels * self.expansion)
            )
            the_conv_block = nn.Sequential(
                nn.Conv2d(
                    current_in_channels,
                    current_expansion_channel,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                ),
                nn.BatchNorm2d(current_expansion_channel),
                nn.ReLU(),
                nn.Conv2d(
                    current_expansion_channel,
                    current_expansion_channel,
                    kernel_size=current_kernel_size,
                    stride=current_stride,
                    padding=(current_kernel_size - 1) // 2,
                    bias=False,
                ),
                nn.BatchNorm2d(current_expansion_channel),
                nn.ReLU(),
                nn.Conv2d(
                    current_expansion_channel,
                    current_out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                ),
                nn.BatchNorm2d(current_out_channels),
            )
            self.conv_list.append(the_conv_block)
            if current_stride == 1 and current_in_channels == current_out_channels:
                # Identity shortcut when the shape is unchanged.
                shortcut = nn.Sequential()
            else:
                # Projection shortcut: 1x1 conv matches channels and stride.
                shortcut = nn.Sequential(
                    nn.Conv2d(
                        current_in_channels,
                        current_out_channels,
                        kernel_size=1,
                        stride=current_stride,
                        padding=0,
                        bias=False,
                    ),
                    nn.BatchNorm2d(current_out_channels),
                )
            self.shortcut_list.append(shortcut)
            pass  # end for

    def forward(self, x):
        """Apply the sub-layers sequentially: relu(conv(x) + shortcut(x))."""
        output = x
        for block, shortcut in zip(self.conv_list, self.shortcut_list):
            conv_output = block(output)
            output = conv_output + shortcut(output)
            output = F.relu(output)
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse
        ``SuperResK1KXK1([name|]in,out,kernel,stride,expansion,sublayers)``
        and return (block, remainder of ``s``)."""
        assert SuperResK1KXK1.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("SuperResK1KXK1(") : idx]
        # find block_name
        tmp_idx = param_str.find("|")
        if tmp_idx < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1 :]
        param_str_split = param_str.split(",")
        in_channels = int(param_str_split[0])
        out_channels = int(param_str_split[1])
        kernel_size = int(param_str_split[2])
        stride = int(param_str_split[3])
        expansion = float(param_str_split[4])
        sublayers = int(param_str_split[5])
        return (
            SuperResK1KXK1(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                expansion=expansion,
                sublayers=sublayers,
                block_name=tmp_block_name,
                no_create=no_create,
            ),
            s[idx + 1 :],
        )

    @staticmethod
    def is_instance_from_str(s):
        """True when ``s`` is a serialized SuperResK1KXK1."""
        if s.startswith("SuperResK1KXK1(") and s[-1] == ")":
            return True
        else:
            return False
class SuperResK1DWK1(PlainNetBasicBlockClass):
    """A stack of residual sub-layers, each 1x1 conv -> KxK depthwise conv -> 1x1 conv.

    The first sub-layer applies ``stride`` and maps ``in_channels`` to
    ``out_channels``; the remaining ``sublayers - 1`` keep stride 1 and the
    output width. Every sub-layer carries a shortcut (identity, or a strided
    1x1 projection when shapes differ).
    """

    def __init__(
        self,
        in_channels=0,
        out_channels=0,
        kernel_size=3,
        stride=1,
        expansion=1.0,
        sublayers=1,
        no_create=False,
        block_name=None,
        **kwargs
    ):
        super(SuperResK1DWK1, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.expansion = expansion
        self.stride = stride
        self.sublayers = sublayers
        self.no_create = no_create
        self.block_name = block_name
        self.shortcut_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        for layer_id in range(self.sublayers):
            # Only the very first sub-layer downsamples / changes width.
            first = layer_id == 0
            cur_in = self.in_channels if first else self.out_channels
            cur_out = self.out_channels
            cur_stride = self.stride if first else 1
            cur_kernel = self.kernel_size
            mid = int(round(cur_out * self.expansion))
            self.conv_list.append(
                self._make_conv_block(cur_in, mid, cur_out, cur_kernel, cur_stride)
            )
            self.shortcut_list.append(self._make_shortcut(cur_in, cur_out, cur_stride))

    @staticmethod
    def _make_conv_block(cur_in, mid, cur_out, kernel, stride):
        """Build the 1x1 -> KxK depthwise -> 1x1 bottleneck branch."""
        return nn.Sequential(
            nn.Conv2d(cur_in, mid, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(mid),
            nn.ReLU(),
            nn.Conv2d(
                mid,
                mid,
                kernel_size=kernel,
                stride=stride,
                padding=(kernel - 1) // 2,
                bias=False,
                groups=mid,  # depthwise: one group per channel
            ),
            nn.BatchNorm2d(mid),
            nn.ReLU(),
            nn.Conv2d(mid, cur_out, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(cur_out),
        )

    @staticmethod
    def _make_shortcut(cur_in, cur_out, stride):
        """Identity when shapes match, otherwise a strided 1x1 projection."""
        if stride == 1 and cur_in == cur_out:
            return nn.Sequential()
        return nn.Sequential(
            nn.Conv2d(
                cur_in, cur_out, kernel_size=1, stride=stride, padding=0, bias=False
            ),
            nn.BatchNorm2d(cur_out),
        )

    def forward(self, x):
        """Run every residual sub-layer: conv branch + shortcut, then ReLU."""
        out = x
        for conv_branch, skip_branch in zip(self.conv_list, self.shortcut_list):
            out = F.relu(conv_branch(out) + skip_branch(out))
        return out

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse a ``SuperResK1DWK1(...)`` spec string.

        Returns a tuple ``(block, remaining)`` where ``remaining`` is the part
        of *s* after the closing parenthesis.
        """
        assert SuperResK1DWK1.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("SuperResK1DWK1(") : idx]
        # An optional "<name>|" prefix carries the block name; otherwise a
        # fresh unique one is generated.
        sep = param_str.find("|")
        if sep < 0:
            block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            block_name = param_str[:sep]
            param_str = param_str[sep + 1 :]
        fields = param_str.split(",")
        block = SuperResK1DWK1(
            in_channels=int(fields[0]),
            out_channels=int(fields[1]),
            kernel_size=int(fields[2]),
            stride=int(fields[3]),
            expansion=float(fields[4]),
            sublayers=int(fields[5]),
            block_name=block_name,
            no_create=no_create,
        )
        return block, s[idx + 1 :]

    @staticmethod
    def is_instance_from_str(s):
        """Return True iff *s* looks like a ``SuperResK1DWK1(...)`` spec string."""
        return s.startswith("SuperResK1DWK1(") and s[-1] == ")"
class SuperResK1DW(PlainNetBasicBlockClass):
    """A stack of residual sub-layers, each 1x1 conv -> KxK depthwise conv.

    Unlike :class:`SuperResK1DWK1`, this block has no channel-expansion stage,
    so ``expansion`` must be exactly 1. The first sub-layer applies ``stride``
    and maps ``in_channels`` to ``out_channels``; the remaining
    ``sublayers - 1`` keep stride 1 and the output width. Every sub-layer
    carries a shortcut (identity, or a strided 1x1 projection).
    """

    def __init__(
        self,
        in_channels=0,
        out_channels=0,
        kernel_size=3,
        stride=1,
        expansion=1.0,
        sublayers=1,
        no_create=False,
        block_name=None,
        **kwargs
    ):
        super(SuperResK1DW, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.expansion = expansion
        # No expansion stage exists in this topology, so expansion must be 1.
        assert abs(expansion - 1) < 1e-6
        self.stride = stride
        self.sublayers = sublayers
        self.no_create = no_create
        self.block_name = block_name
        self.shortcut_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        for layerID in range(self.sublayers):
            # Only the first sub-layer downsamples / changes width.
            if layerID == 0:
                current_in_channels = self.in_channels
                current_out_channels = self.out_channels
                current_stride = self.stride
                current_kernel_size = self.kernel_size
            else:
                current_in_channels = self.out_channels
                current_out_channels = self.out_channels
                current_stride = 1
                current_kernel_size = self.kernel_size
            the_conv_block = nn.Sequential(
                nn.Conv2d(
                    current_in_channels,
                    current_out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                ),
                # Fix: normalize over the conv's actual output width. The old
                # code passed round(out * expansion) here, which only happened
                # to match because the assert above pins expansion to 1.
                nn.BatchNorm2d(current_out_channels),
                nn.ReLU(),
                nn.Conv2d(
                    current_out_channels,
                    current_out_channels,
                    kernel_size=current_kernel_size,
                    stride=current_stride,
                    padding=(current_kernel_size - 1) // 2,
                    bias=False,
                    groups=current_out_channels,  # depthwise convolution
                ),
                nn.BatchNorm2d(current_out_channels),
            )
            self.conv_list.append(the_conv_block)
            # Identity shortcut when shapes match, else a 1x1 projection.
            if current_stride == 1 and current_in_channels == current_out_channels:
                shortcut = nn.Sequential()
            else:
                shortcut = nn.Sequential(
                    nn.Conv2d(
                        current_in_channels,
                        current_out_channels,
                        kernel_size=1,
                        stride=current_stride,
                        padding=0,
                        bias=False,
                    ),
                    nn.BatchNorm2d(current_out_channels),
                )
            self.shortcut_list.append(shortcut)

    def forward(self, x):
        """Run every residual sub-layer: conv branch + shortcut, then ReLU."""
        output = x
        for block, shortcut in zip(self.conv_list, self.shortcut_list):
            conv_output = block(output)
            output = conv_output + shortcut(output)
            output = F.relu(output)
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse a ``SuperResK1DW(...)`` spec string.

        Returns a tuple ``(block, remaining)`` where ``remaining`` is the part
        of *s* after the closing parenthesis.
        """
        assert SuperResK1DW.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len("SuperResK1DW(") : idx]
        # An optional "<name>|" prefix carries the block name; otherwise a
        # fresh unique one is generated.
        tmp_idx = param_str.find("|")
        if tmp_idx < 0:
            tmp_block_name = "uuid{}".format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1 :]
        param_str_split = param_str.split(",")
        in_channels = int(param_str_split[0])
        out_channels = int(param_str_split[1])
        kernel_size = int(param_str_split[2])
        stride = int(param_str_split[3])
        expansion = float(param_str_split[4])
        sublayers = int(param_str_split[5])
        return (
            SuperResK1DW(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                expansion=expansion,
                sublayers=sublayers,
                block_name=tmp_block_name,
                no_create=no_create,
            ),
            s[idx + 1 :],
        )

    @staticmethod
    def is_instance_from_str(s):
        """Return True iff *s* looks like a ``SuperResK1DW(...)`` spec string."""
        if s.startswith("SuperResK1DW(") and s[-1] == ")":
            return True
        else:
            return False
# Registry mapping each textual block-type prefix used in architecture spec
# strings (e.g. "SuperResKXKX(...)") to the class that implements it; the
# spec-string parser looks up constructors here when building a PlainNet.
_all_netblocks_dict_ = {
    "AdaptiveAvgPool": AdaptiveAvgPool,
    "BN": BN,
    "ConvDW": ConvDW,
    "ConvKX": ConvKX,
    "Flatten": Flatten,
    "Linear": Linear,
    "MaxPool": MaxPool,
    "MultiSumBlock": MultiSumBlock,
    "PlainNetBasicBlockClass": PlainNetBasicBlockClass,
    "RELU": RELU,
    "ResBlock": ResBlock,
    "Sequential": Sequential,
    "SuperResKXKX": SuperResKXKX,
    "SuperResK1KXK1": SuperResK1KXK1,
    "SuperResK1DWK1": SuperResK1DWK1,
    "SuperResK1KX": SuperResK1KX,
    "SuperResK1DW": SuperResK1DW,
}
class PlainNet(nn.Module):
    """A classifier assembled from a textual architecture description string.

    The structure string is parsed into a list of net blocks; a trailing
    ``AdaptiveAvgPool`` (explicit in the string, or synthesized) feeds a final
    fully-connected classification head.
    """

    def __init__(
        self, num_classes=1000, plainnet_struct=None, no_create=False, **kwargs
    ):
        super(PlainNet, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.plainnet_struct = plainnet_struct
        struct_str = self.plainnet_struct  # type: str
        block_list, remaining_s = _create_netblock_list_from_str_(
            struct_str, no_create=no_create
        )
        # The parser must consume the entire spec string.
        assert len(remaining_s) == 0
        # Detach the trailing global pooling layer from the block list, or
        # synthesize one when the spec did not end with it.
        if isinstance(block_list[-1], AdaptiveAvgPool):
            self.adptive_avg_pool = block_list.pop(-1)
        else:
            self.adptive_avg_pool = AdaptiveAvgPool(
                out_channels=block_list[-1].out_channels, output_size=1
            )
        self.block_list = block_list
        if not no_create:
            self.module_list = nn.ModuleList(block_list)  # register parameters
        self.last_channels = self.adptive_avg_pool.out_channels
        if no_create:
            self.fc_linear = None
        else:
            self.fc_linear = nn.Linear(self.last_channels, self.num_classes, bias=True)
        self.plainnet_struct = str(self) + str(self.adptive_avg_pool)

    def forward(self, x):
        """Run the block stack, pool, flatten, and classify."""
        out = x
        for the_block in self.block_list:
            out = the_block(out)
        out = self.adptive_avg_pool(out)
        out = flow.flatten(out, 1)
        return self.fc_linear(out)
# Architecture spec strings for the three GENet variants. Each token is
# "<BlockType>(<uuid-name>|<params>)" with comma-separated params
# (in_channels, out_channels, kernel_size, stride[, expansion, sublayers]);
# PlainNet parses the whole string into its block list.
CONFIG = {
    "small": "ConvKX(uuid46ff2328b77f40ff88aed69a5318d771|3,13,3,2)BN(uuid43b72f65311c42d9a1af485c594a6ab4|13)"
    "RELU(uuid282901aaa7f84b028e3c5bd7d37ae056|13)"
    "SuperResKXKX(uuiddb56d6f9a60b4455966e13b06a8ff723|13,48,3,2,1.0,1)"
    "SuperResKXKX(uuidd964406e6fdf4e9abac225afaeb1fe0b|48,48,3,2,1.0,3)"
    "SuperResK1KXK1(uuid39819ad4f4da405583de614af437b568|48,384,3,2,0.25,7)"
    "SuperResK1DWK1(uuid420593fe7b1e46f690b76bac3786d4b7|384,560,3,2,3.0,2)"
    "SuperResK1DWK1(uuid96236b3c50774f1ab2d3049d6aca6d85|560,256,3,1,3.0,1)"
    "ConvKX(uuid89ed263767a14f21b7426cccb120ad1d|256,1920,1,1)BN(uuidd6ad568b290544be9f4b47dc3fa271c9|1920)"
    "RELU(uuid823ced7441394fb9b3a96a5f7c40da2b|1920)AdaptiveAvgPool(1920,1)",
    "normal": "ConvKX(uuid70de938099844017bd745349f7a1d35a|3,32,3,2)BN(uuid10f8a99f83294067bfdf5fc5a5c9bffd|32)"
    "RELU(uuideffe03bd73254e7c8027364ba71d25cd|32)"
    "SuperResKXKX(uuidb023bea8c7b34c22a1650e07dfc8e2c1|32,128,3,2,1.0,1)"
    "SuperResKXKX(uuidf829740023044b879eefaf7fc7d1ad8e|128,192,3,2,1.0,2)"
    "SuperResK1KXK1(uuid33bfe77cb8864357a840ca3341ea629a|192,640,3,2,0.25,6)"
    "SuperResK1DWK1(uuide2c948d819fb4869980e30d67a773244|640,640,3,2,3.0,4)"
    "SuperResK1DWK1(uuid53c308e481c24154b7a81fcbaf99edbf|640,640,3,1,3.0,1)"
    "ConvKX(uuidbc6953bfd8de45fc8534787a66b96430|640,2560,1,1)BN(uuida8acaaae74ed47a4a7514b41c643eb23|2560)"
    "RELU(uuida5d71c4fd5d24a7b848472f0383df467|2560)AdaptiveAvgPool(2560,1)",
    "large": "ConvKX(uuid9d1dca0f098143aaa1a947acf1100787|3,32,3,2)BN(uuid7d10ba10dc524ffb8863ae97c4a21797|32)"
    "RELU(uuidccd810d3d10a48158ccfa48ca975915c|32)"
    "SuperResKXKX(uuid5ba1db21fce64b16a34ad577c258fd6c|32,128,3,2,1.0,1)"
    "SuperResKXKX(uuida09fc4e4946444bf9b912f8c666c4b12|128,192,3,2,1.0,2)"
    "SuperResK1KXK1(uuidfa45c5f5cc96435dbd54801f31c83ca8|192,640,3,2,0.25,6)"
    "SuperResK1DWK1(uuid99bf6442b33643579dc680045da7549d|640,640,3,2,3.0,5)"
    "SuperResK1DWK1(uuid615cbfd4ed284cbc8589d84cbe9b0e92|640,640,3,1,3.0,4)"
    "ConvKX(uuid002fa25f74f14cdeb89a5aacd6ce64ff|640,2560,1,1)BN(uuidc5d6c88c326343efa2a8700907f87732|2560)"
    "RELU(uuidd2b39caab4cb4ac2b6905b18858c0037|2560)AdaptiveAvgPool(2560,1)",
}
@ModelCreator.register_model
def genet_large(pretrained: bool = False, progress: bool = True, **kwargs):
    """
    Constructs GENet-large 256x256 model pretrained on ImageNet-1k.

    .. note::
        GENet-large 256x256 model from `"Neural Architecture Design for GPU-Efficient Networks" <https://arxiv.org/pdf/2006.14090>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> genet_large = flowvision.models.genet_large(pretrained=False, progress=True)

    """
    # Fix: the line previously began with a literal "[docs]" (a Sphinx HTML
    # scrape artifact fused onto the decorator), which is a syntax error.
    plainnet_struct = CONFIG["large"]
    model = PlainNet(plainnet_struct=plainnet_struct, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["genet_large"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def genet_normal(pretrained: bool = False, progress: bool = True, **kwargs):
    """
    Constructs GENet-normal 192x192 model pretrained on ImageNet-1k.

    .. note::
        GENet-normal 192x192 model from `"Neural Architecture Design for GPU-Efficient Networks" <https://arxiv.org/pdf/2006.14090>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> genet_normal = flowvision.models.genet_normal(pretrained=False, progress=True)

    """
    # Fix: the line previously began with a literal "[docs]" (a Sphinx HTML
    # scrape artifact fused onto the decorator), which is a syntax error.
    plainnet_struct = CONFIG["normal"]
    model = PlainNet(plainnet_struct=plainnet_struct, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["genet_normal"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
@ModelCreator.register_model
def genet_small(pretrained: bool = False, progress: bool = True, **kwargs):
    """
    Constructs GENet-small 192x192 model pretrained on ImageNet-1k.

    .. note::
        GENet-small 192x192 model from `"Neural Architecture Design for GPU-Efficient Networks" <https://arxiv.org/pdf/2006.14090>`_.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> genet_small = flowvision.models.genet_small(pretrained=False, progress=True)

    """
    # Fix: the line previously began with a literal "[docs]" (a Sphinx HTML
    # scrape artifact fused onto the decorator), which is a syntax error.
    plainnet_struct = CONFIG["small"]
    model = PlainNet(plainnet_struct=plainnet_struct, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls["genet_small"], progress=progress
        )
        model.load_state_dict(state_dict)
    return model