The paper introduces UniRepLKNet, a new large-kernel convolutional neural network (ConvNet) that performs strongly on image recognition as well as audio, video, point-cloud, and time-series tasks, demonstrating the considerable potential of ConvNets in multimodal settings.
After creating the file (ultralytics/nn/backbone/UniRepLKNet.py, matching the import used later), copy and paste the following code into it:
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import trunc_normal_, DropPath, to_2tuple
from functools import partial
import torch.utils.checkpoint as checkpoint
import numpy as np
__all__ = ['unireplknet_a', 'unireplknet_f', 'unireplknet_p', 'unireplknet_n', 'unireplknet_t', 'unireplknet_s', 'unireplknet_b', 'unireplknet_l', 'unireplknet_xl']
class GRNwithNHWC(nn.Module):
""" GRN (Global Response Normalization) layer
Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808)
This implementation is more efficient than the original (https://github.com/facebookresearch/ConvNeXt-V2)
We assume the inputs to this layer are (N, H, W, C)
"""
def __init__(self, dim, use_bias=True):
super().__init__()
self.use_bias = use_bias
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
if self.use_bias:
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
if self.use_bias:
return (self.gamma * Nx + 1) * x + self.beta
else:
return (self.gamma * Nx + 1) * x
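# Illustrative usage, kept as a comment so the file remains paste-ready (shapes
# are example values, not from the paper): GRN rescales each channel by its
# global L2 response, so the output shape always equals the input shape.
#   grn = GRNwithNHWC(dim=64)
#   x = torch.randn(2, 7, 7, 64)   # (N, H, W, C)
#   print(grn(x).shape)            # torch.Size([2, 7, 7, 64])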
class NCHWtoNHWC(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 2, 3, 1)
class NHWCtoNCHW(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 3, 1, 2)
#================== This function decides which conv implementation (the native one or iGEMM) to use
# Note that the iGEMM large-kernel conv impl will be used only if
# - you attempt to do so (attempt_use_lk_impl=True), and
# - it has been installed (follow https://github.com/AILab-CVC/UniRepLKNet), and
# - the conv layer is depth-wise, stride == 1, non-dilated, kernel_size > 5, and padding == kernel_size // 2
def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias,
attempt_use_lk_impl=True):
kernel_size = to_2tuple(kernel_size)
if padding is None:
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
else:
padding = to_2tuple(padding)
need_large_impl = kernel_size[0] == kernel_size[1] and kernel_size[0] > 5 and padding == (kernel_size[0] // 2, kernel_size[1] // 2)
# if attempt_use_lk_impl and need_large_impl:
# print('---------------- trying to import iGEMM implementation for large-kernel conv')
# try:
# from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
# print('---------------- found iGEMM implementation ')
# except:
# DepthWiseConv2dImplicitGEMM = None
# print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')
# if DepthWiseConv2dImplicitGEMM is not None and need_large_impl and in_channels == out_channels \
# and out_channels == groups and stride == 1 and dilation == 1:
# print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')
# return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
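# Illustrative call (comment only; values are examples): with the iGEMM branch
# commented out above, this always falls back to a native depth-wise nn.Conv2d.
#   conv = get_conv2d(64, 64, kernel_size=13, stride=1, padding=None,
#                     dilation=1, groups=64, bias=False)
#   # -> nn.Conv2d(64, 64, kernel_size=(13, 13), padding=(6, 6), groups=64)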
def get_bn(dim, use_sync_bn=False):
if use_sync_bn:
return nn.SyncBatchNorm(dim)
else:
return nn.BatchNorm2d(dim)
class SEBlock(nn.Module):
"""
Squeeze-and-Excitation Block proposed in SENet (https://arxiv.org/abs/1709.01507)
We assume the inputs to this layer are (N, C, H, W)
"""
def __init__(self, input_channels, internal_neurons):
super(SEBlock, self).__init__()
self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons,
kernel_size=1, stride=1, bias=True)
self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels,
kernel_size=1, stride=1, bias=True)
self.input_channels = input_channels
self.nonlinear = nn.ReLU(inplace=True)
def forward(self, inputs):
x = F.adaptive_avg_pool2d(inputs, output_size=(1, 1))
x = self.down(x)
x = self.nonlinear(x)
x = self.up(x)
        x = torch.sigmoid(x)  # torch.sigmoid is preferred over the deprecated F.sigmoid
return inputs * x.view(-1, self.input_channels, 1, 1)
def fuse_bn(conv, bn):
conv_bias = 0 if conv.bias is None else conv.bias
std = (bn.running_var + bn.eps).sqrt()
return conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1), bn.bias + (conv_bias - bn.running_mean) * bn.weight / std
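# Sanity check for fuse_bn (comment only; values are illustrative): folding BN
# into the preceding conv must leave the eval-mode function unchanged.
#   conv = nn.Conv2d(8, 8, 3, padding=1, groups=8, bias=False)
#   bn = nn.BatchNorm2d(8).eval()
#   x = torch.randn(1, 8, 16, 16)
#   w, b = fuse_bn(conv, bn)
#   print((F.conv2d(x, w, b, padding=1, groups=8) - bn(conv(x))).abs().max())  # ~1e-7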
def convert_dilated_to_nondilated(kernel, dilate_rate):
identity_kernel = torch.ones((1, 1, 1, 1)).to(kernel.device)
if kernel.size(1) == 1:
# This is a DW kernel
dilated = F.conv_transpose2d(kernel, identity_kernel, stride=dilate_rate)
return dilated
else:
# This is a dense or group-wise (but not DW) kernel
slices = []
for i in range(kernel.size(1)):
dilated = F.conv_transpose2d(kernel[:,i:i+1,:,:], identity_kernel, stride=dilate_rate)
slices.append(dilated)
return torch.cat(slices, dim=1)
def merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):
large_k = large_kernel.size(2)
dilated_k = dilated_kernel.size(2)
equivalent_kernel_size = dilated_r * (dilated_k - 1) + 1
equivalent_kernel = convert_dilated_to_nondilated(dilated_kernel, dilated_r)
rows_to_pad = large_k // 2 - equivalent_kernel_size // 2
merged_kernel = large_kernel + F.pad(equivalent_kernel, [rows_to_pad] * 4)
return merged_kernel
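# Sanity check (comment only; values are illustrative): a depth-wise 3x3 conv
# with dilation 2 equals a plain 5x5 conv with the converted kernel, since the
# equivalent size is 2 * (3 - 1) + 1 = 5.
#   k = torch.randn(8, 1, 3, 3); x = torch.randn(1, 8, 16, 16)
#   y_dil = F.conv2d(x, k, padding=2, dilation=2, groups=8)
#   y_big = F.conv2d(x, convert_dilated_to_nondilated(k, 2), padding=2, groups=8)
#   print((y_dil - y_big).abs().max())  # ~0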
class DilatedReparamBlock(nn.Module):
"""
Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet)
We assume the inputs to this block are (N, C, H, W)
"""
def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True):
super().__init__()
self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,
padding=kernel_size//2, dilation=1, groups=channels, bias=deploy,
attempt_use_lk_impl=attempt_use_lk_impl)
self.attempt_use_lk_impl = attempt_use_lk_impl
# Default settings. We did not tune them carefully. Different settings may work better.
if kernel_size == 17:
self.kernel_sizes = [5, 9, 3, 3, 3]
self.dilates = [1, 2, 4, 5, 7]
elif kernel_size == 15:
self.kernel_sizes = [5, 7, 3, 3, 3]
self.dilates = [1, 2, 3, 5, 7]
elif kernel_size == 13:
self.kernel_sizes = [5, 7, 3, 3, 3]
self.dilates = [1, 2, 3, 4, 5]
elif kernel_size == 11:
self.kernel_sizes = [5, 5, 3, 3, 3]
self.dilates = [1, 2, 3, 4, 5]
elif kernel_size == 9:
self.kernel_sizes = [5, 5, 3, 3]
self.dilates = [1, 2, 3, 4]
elif kernel_size == 7:
self.kernel_sizes = [5, 3, 3]
self.dilates = [1, 2, 3]
elif kernel_size == 5:
self.kernel_sizes = [3, 3]
self.dilates = [1, 2]
else:
raise ValueError('Dilated Reparam Block requires kernel_size >= 5')
if not deploy:
self.origin_bn = get_bn(channels, use_sync_bn)
for k, r in zip(self.kernel_sizes, self.dilates):
self.__setattr__('dil_conv_k{}_{}'.format(k, r),
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,
padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,
bias=False))
self.__setattr__('dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn))
def forward(self, x):
if not hasattr(self, 'origin_bn'): # deploy mode
return self.lk_origin(x)
out = self.origin_bn(self.lk_origin(x))
for k, r in zip(self.kernel_sizes, self.dilates):
conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))
bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))
out = out + bn(conv(x))
return out
def merge_dilated_branches(self):
if hasattr(self, 'origin_bn'):
origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn)
for k, r in zip(self.kernel_sizes, self.dilates):
conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))
bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))
branch_k, branch_b = fuse_bn(conv, bn)
origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r)
origin_b += branch_b
merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1,
padding=origin_k.size(2)//2, dilation=1, groups=origin_k.size(0), bias=True,
attempt_use_lk_impl=self.attempt_use_lk_impl)
merged_conv.weight.data = origin_k
merged_conv.bias.data = origin_b
self.lk_origin = merged_conv
self.__delattr__('origin_bn')
for k, r in zip(self.kernel_sizes, self.dilates):
self.__delattr__('dil_conv_k{}_{}'.format(k, r))
self.__delattr__('dil_bn_k{}_{}'.format(k, r))
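# Reparameterization check (comment only; sizes are illustrative): after
# merge_dilated_branches(), the single-conv deploy form must reproduce the
# training-time multi-branch output (in eval mode, so BN uses running stats).
#   blk = DilatedReparamBlock(16, kernel_size=13, deploy=False).eval()
#   x = torch.randn(1, 16, 32, 32)
#   y_train = blk(x)
#   blk.merge_dilated_branches()
#   print((blk(x) - y_train).abs().max())  # ~1e-6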
class UniRepLKNetBlock(nn.Module):
def __init__(self,
dim,
kernel_size,
drop_path=0.,
layer_scale_init_value=1e-6,
deploy=False,
attempt_use_lk_impl=True,
with_cp=False,
use_sync_bn=False,
ffn_factor=4):
super().__init__()
self.with_cp = with_cp
# if deploy:
# print('------------------------------- Note: deploy mode')
# if self.with_cp:
# print('****** note with_cp = True, reduce memory consumption but may slow down training ******')
self.need_contiguous = (not deploy) or kernel_size >= 7
if kernel_size == 0:
self.dwconv = nn.Identity()
self.norm = nn.Identity()
elif deploy:
self.dwconv = get_conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=kernel_size // 2,
dilation=1, groups=dim, bias=True,
attempt_use_lk_impl=attempt_use_lk_impl)
self.norm = nn.Identity()
elif kernel_size >= 7:
self.dwconv = DilatedReparamBlock(dim, kernel_size, deploy=deploy,
use_sync_bn=use_sync_bn,
attempt_use_lk_impl=attempt_use_lk_impl)
self.norm = get_bn(dim, use_sync_bn=use_sync_bn)
elif kernel_size == 1:
self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=kernel_size // 2,
dilation=1, groups=1, bias=deploy)
self.norm = get_bn(dim, use_sync_bn=use_sync_bn)
else:
assert kernel_size in [3, 5]
self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=kernel_size // 2,
dilation=1, groups=dim, bias=deploy)
self.norm = get_bn(dim, use_sync_bn=use_sync_bn)
self.se = SEBlock(dim, dim // 4)
ffn_dim = int(ffn_factor * dim)
self.pwconv1 = nn.Sequential(
NCHWtoNHWC(),
nn.Linear(dim, ffn_dim))
self.act = nn.Sequential(
nn.GELU(),
GRNwithNHWC(ffn_dim, use_bias=not deploy))
if deploy:
self.pwconv2 = nn.Sequential(
nn.Linear(ffn_dim, dim),
NHWCtoNCHW())
else:
self.pwconv2 = nn.Sequential(
nn.Linear(ffn_dim, dim, bias=False),
NHWCtoNCHW(),
get_bn(dim, use_sync_bn=use_sync_bn))
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim),
requires_grad=True) if (not deploy) and layer_scale_init_value is not None \
and layer_scale_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, inputs):
def _f(x):
if self.need_contiguous:
x = x.contiguous()
y = self.se(self.norm(self.dwconv(x)))
y = self.pwconv2(self.act(self.pwconv1(y)))
if self.gamma is not None:
y = self.gamma.view(1, -1, 1, 1) * y
return self.drop_path(y) + x
if self.with_cp and inputs.requires_grad:
return checkpoint.checkpoint(_f, inputs)
else:
return _f(inputs)
def reparameterize(self):
if hasattr(self.dwconv, 'merge_dilated_branches'):
self.dwconv.merge_dilated_branches()
if hasattr(self.norm, 'running_var') and hasattr(self.dwconv, 'lk_origin'):
std = (self.norm.running_var + self.norm.eps).sqrt()
self.dwconv.lk_origin.weight.data *= (self.norm.weight / std).view(-1, 1, 1, 1)
self.dwconv.lk_origin.bias.data = self.norm.bias + (self.dwconv.lk_origin.bias - self.norm.running_mean) * self.norm.weight / std
self.norm = nn.Identity()
if self.gamma is not None:
final_scale = self.gamma.data
self.gamma = None
else:
final_scale = 1
if self.act[1].use_bias and len(self.pwconv2) == 3:
grn_bias = self.act[1].beta.data
self.act[1].__delattr__('beta')
self.act[1].use_bias = False
linear = self.pwconv2[0]
grn_bias_projected_bias = (linear.weight.data @ grn_bias.view(-1, 1)).squeeze()
bn = self.pwconv2[2]
std = (bn.running_var + bn.eps).sqrt()
new_linear = nn.Linear(linear.in_features, linear.out_features, bias=True)
new_linear.weight.data = linear.weight * (bn.weight / std * final_scale).view(-1, 1)
linear_bias = 0 if linear.bias is None else linear.bias.data
linear_bias += grn_bias_projected_bias
new_linear.bias.data = (bn.bias + (linear_bias - bn.running_mean) * bn.weight / std) * final_scale
self.pwconv2 = nn.Sequential(new_linear, self.pwconv2[1])
default_UniRepLKNet_A_F_P_kernel_sizes = ((3, 3),
(13, 13),
(13, 13, 13, 13, 13, 13),
(13, 13))
default_UniRepLKNet_N_kernel_sizes = ((3, 3),
(13, 13),
(13, 13, 13, 13, 13, 13, 13, 13),
(13, 13))
default_UniRepLKNet_T_kernel_sizes = ((3, 3, 3),
(13, 13, 13),
(13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3),
(13, 13, 13))
default_UniRepLKNet_S_B_L_XL_kernel_sizes = ((3, 3, 3),
(13, 13, 13),
(13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3),
(13, 13, 13))
UniRepLKNet_A_F_P_depths = (2, 2, 6, 2)
UniRepLKNet_N_depths = (2, 2, 8, 2)
UniRepLKNet_T_depths = (3, 3, 18, 3)
UniRepLKNet_S_B_L_XL_depths = (3, 3, 27, 3)
default_depths_to_kernel_sizes = {
UniRepLKNet_A_F_P_depths: default_UniRepLKNet_A_F_P_kernel_sizes,
UniRepLKNet_N_depths: default_UniRepLKNet_N_kernel_sizes,
UniRepLKNet_T_depths: default_UniRepLKNet_T_kernel_sizes,
UniRepLKNet_S_B_L_XL_depths: default_UniRepLKNet_S_B_L_XL_kernel_sizes
}
class UniRepLKNet(nn.Module):
r""" UniRepLKNet
A PyTorch impl of UniRepLKNet
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 27, 3)
dims (int): Feature dimension at each stage. Default: (96, 192, 384, 768)
drop_path_rate (float): Stochastic depth rate. Default: 0.
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
kernel_sizes (tuple(tuple(int))): Kernel size for each block. None means using the default settings. Default: None.
deploy (bool): deploy = True means using the inference structure. Default: False
with_cp (bool): with_cp = True means using torch.utils.checkpoint to save GPU memory. Default: False
        init_cfg (dict): weights to load. The easiest way to use UniRepLKNet with the OpenMMLab family. Default: None
        attempt_use_lk_impl (bool): try to load the efficient iGEMM large-kernel impl. Setting it to False disables the iGEMM impl. Default: True
use_sync_bn (bool): use_sync_bn = True means using sync BN. Use it if your batch size is small. Default: False
"""
def __init__(self,
in_chans=3,
num_classes=1000,
depths=(3, 3, 27, 3),
dims=(96, 192, 384, 768),
drop_path_rate=0.,
layer_scale_init_value=1e-6,
head_init_scale=1.,
kernel_sizes=None,
deploy=False,
with_cp=False,
init_cfg=None,
attempt_use_lk_impl=True,
use_sync_bn=False,
**kwargs
):
super().__init__()
depths = tuple(depths)
if kernel_sizes is None:
if depths in default_depths_to_kernel_sizes:
# print('=========== use default kernel size ')
kernel_sizes = default_depths_to_kernel_sizes[depths]
else:
raise ValueError('no default kernel size settings for the given depths, '
'please specify kernel sizes for each block, e.g., '
'((3, 3), (13, 13), (13, 13, 13, 13, 13, 13), (13, 13))')
# print(kernel_sizes)
for i in range(4):
assert len(kernel_sizes[i]) == depths[i], 'kernel sizes do not match the depths'
self.with_cp = with_cp
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
# print('=========== drop path rates: ', dp_rates)
self.downsample_layers = nn.ModuleList()
self.downsample_layers.append(nn.Sequential(
nn.Conv2d(in_chans, dims[0] // 2, kernel_size=3, stride=2, padding=1),
LayerNorm(dims[0] // 2, eps=1e-6, data_format="channels_first"),
nn.GELU(),
nn.Conv2d(dims[0] // 2, dims[0], kernel_size=3, stride=2, padding=1),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")))
for i in range(3):
self.downsample_layers.append(nn.Sequential(
nn.Conv2d(dims[i], dims[i + 1], kernel_size=3, stride=2, padding=1),
LayerNorm(dims[i + 1], eps=1e-6, data_format="channels_first")))
self.stages = nn.ModuleList()
cur = 0
for i in range(4):
main_stage = nn.Sequential(
*[UniRepLKNetBlock(dim=dims[i], kernel_size=kernel_sizes[i][j], drop_path=dp_rates[cur + j],
layer_scale_init_value=layer_scale_init_value, deploy=deploy,
attempt_use_lk_impl=attempt_use_lk_impl,
with_cp=with_cp, use_sync_bn=use_sync_bn) for j in
range(depths[i])])
self.stages.append(main_stage)
cur += depths[i]
self.output_mode = 'features'
norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
for i_layer in range(4):
layer = norm_layer(dims[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
        # infer the per-stage output channels once with a dummy input (no_grad avoids building an autograd graph)
        with torch.no_grad():
            self.channel = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
        if self.output_mode == 'logits':
            # Note: this YOLO-adapted version does not build self.norm / self.head,
            # so only the 'features' branch below is used in practice.
for stage_idx in range(4):
x = self.downsample_layers[stage_idx](x)
x = self.stages[stage_idx](x)
x = self.norm(x.mean([-2, -1]))
x = self.head(x)
return x
elif self.output_mode == 'features':
outs = []
for stage_idx in range(4):
x = self.downsample_layers[stage_idx](x)
x = self.stages[stage_idx](x)
outs.append(self.__getattr__(f'norm{stage_idx}')(x))
return outs
else:
            raise ValueError(f'Undefined output_mode: {self.output_mode}')
def switch_to_deploy(self):
for m in self.modules():
if hasattr(m, 'reparameterize'):
m.reparameterize()
class LayerNorm(nn.Module):
r""" LayerNorm implementation used in ConvNeXt
LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", reshape_last_to_first=False):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)
self.reshape_last_to_first = reshape_last_to_first
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
def update_weight(model_dict, weight_dict):
idx, temp_dict = 0, {}
for k, v in weight_dict.items():
if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
temp_dict[k] = v
idx += 1
model_dict.update(temp_dict)
print(f'loading weights... {idx}/{len(model_dict)} items')
return model_dict
def unireplknet_a(weights='', **kwargs):
model = UniRepLKNet(depths=UniRepLKNet_A_F_P_depths, dims=(40, 80, 160, 320), **kwargs)
if weights:
model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)))
return model
def unireplknet_f(weights='', **kwargs):
model = UniRepLKNet(depths=UniRepLKNet_A_F_P_depths, dims=(48, 96, 192, 384), **kwargs)
if weights:
model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)))
return model
def unireplknet_p(weights='', **kwargs):
model = UniRepLKNet(depths=UniRepLKNet_A_F_P_depths, dims=(64, 128, 256, 512), **kwargs)
if weights:
model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)))
return model
def unireplknet_n(weights='', **kwargs):
model = UniRepLKNet(depths=UniRepLKNet_N_depths, dims=(80, 160, 320, 640), **kwargs)
if weights:
model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)))
return model
def unireplknet_t(weights='', **kwargs):
model = UniRepLKNet(depths=UniRepLKNet_T_depths, dims=(80, 160, 320, 640), **kwargs)
if weights:
model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)))
return model
def unireplknet_s(weights='', **kwargs):
model = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(96, 192, 384, 768), **kwargs)
if weights:
model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)))
return model
def unireplknet_b(weights='', **kwargs):
model = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(128, 256, 512, 1024), **kwargs)
if weights:
model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)))
return model
def unireplknet_l(weights='', **kwargs):
model = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(192, 384, 768, 1536), **kwargs)
if weights:
model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)))
return model
def unireplknet_xl(weights='', **kwargs):
model = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(256, 512, 1024, 2048), **kwargs)
if weights:
model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)))
return model
if __name__ == '__main__':
    inputs = torch.randn((1, 3, 640, 640))
    model = unireplknet_a('unireplknet_a_in1k_224_acc77.03.pth')
    model.eval()  # BN must use running stats, otherwise the fused model cannot match
    res = model(inputs)[-1]
    model.switch_to_deploy()
    res_fuse = model(inputs)[-1]
    print(torch.mean(torch.abs(res_fuse - res)))  # should be close to 0
Next, open ultralytics/nn/tasks.py and add the import below, then replace the existing _predict_once method with the code that follows:
from ultralytics.nn.backbone.UniRepLKNet import *
def _predict_once(self, x, profile=False, visualize=False, embed=None):
"""
Perform a forward pass through the network.
Args:
x (torch.Tensor): The input tensor to the model.
profile (bool): Print the computation time of each layer if True, defaults to False.
visualize (bool): Save the feature maps of the model if True, defaults to False.
embed (list, optional): A list of feature vectors/embeddings to return.
Returns:
(torch.Tensor): The last output of the model.
"""
y, dt, embeddings = [], [], [] # outputs
for idx, m in enumerate(self.model):
if m.f != -1: # if not from previous layer
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
if profile:
self._profile_one_layer(m, x, dt)
            if hasattr(m, 'backbone'):
                x = m(x)  # the backbone returns a list of multi-scale feature maps
                # pad to 5 entries so that 'from' indices 0-4 always map to P1..P5
                for _ in range(5 - len(x)):
                    x.insert(0, None)
                for i_idx, i in enumerate(x):
                    if i_idx in self.save:
                        y.append(i)
                    else:
                        y.append(None)
                # print(f'layer id:{idx:>2} {m.type:>50} output shape:{", ".join([str(x_.size()) for x_ in x if x_ is not None])}')
                x = x[-1]  # continue the forward pass from the deepest feature map
else:
x = m(x) # run
y.append(x if m.i in self.save else None) # save output
# if type(x) in {list, tuple}:
# if idx == (len(self.model) - 1):
# if type(x[1]) is dict:
# print(f'layer id:{idx:>2} {m.type:>50} output shape:{", ".join([str(x_.size()) for x_ in x[1]["one2one"]])}')
# else:
# print(f'layer id:{idx:>2} {m.type:>50} output shape:{", ".join([str(x_.size()) for x_ in x[1]])}')
# else:
# print(f'layer id:{idx:>2} {m.type:>50} output shape:{", ".join([str(x_.size()) for x_ in x if x_ is not None])}')
# elif type(x) is dict:
# print(f'layer id:{idx:>2} {m.type:>50} output shape:{", ".join([str(x_.size()) for x_ in x["one2one"]])}')
# else:
# if not hasattr(m, 'backbone'):
# print(f'layer id:{idx:>2} {m.type:>50} output shape:{x.size()}')
if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize)
if embed and m.i in embed:
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
if m.i == max(embed):
return torch.unbind(torch.cat(embeddings, 1), dim=0)
return x
Likewise, replace the existing parse_model function in the same file with the code below:
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
"""
Parse a YOLO model.yaml dictionary into a PyTorch model.
Args:
d (dict): Model dictionary.
ch (int): Input channels.
verbose (bool): Whether to print model details.
Returns:
(tuple): Tuple containing the PyTorch model and sorted list of output layers.
"""
import ast
# Args
max_channels = float("inf")
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
if scales:
scale = d.get("scale")
if not scale:
scale = tuple(scales.keys())[0]
LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
if len(scales[scale]) == 3:
depth, width, max_channels = scales[scale]
elif len(scales[scale]) == 4:
depth, width, max_channels, threshold = scales[scale]
if act:
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
if verbose:
LOGGER.info(f"{colorstr('activation:')} {act}") # print
if verbose:
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<60}{'arguments':<50}")
ch = [ch]
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
is_backbone = False
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
try:
if m == 'node_mode':
m = d[m]
if len(args) > 0:
if args[0] == 'head_channel':
args[0] = int(d[args[0]])
t = m
m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m] # get module
except:
pass
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
try:
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
except:
args[j] = a
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in {
Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
BottleneckCSP, C1, C2, C2f, ELAN1, AConv, SPPELAN, C2fAttn, C3, C3TR,
C3Ghost, nn.Conv2d, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3, PSA, SCDown, C2fCIB
}:
if args[0] == 'head_channel':
args[0] = d[args[0]]
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(min(c2, max_channels) * width, 8)
if m is C2fAttn:
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8) # embed channels
args[2] = int(
max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
) # num heads
args = [c1, c2, *args[1:]]
elif m in {AIFI}:
args = [ch[f], *args]
c2 = args[0]
elif m in (HGStem, HGBlock):
c1, cm, c2 = ch[f], args[0], args[1]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(min(c2, max_channels) * width, 8)
cm = make_divisible(min(cm, max_channels) * width, 8)
args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:  # (HGBlock) is not a tuple, so test the single class with 'is'
args.insert(4, n) # number of repeats
n = 1
elif m is ResNetLayer:
c2 = args[1] if args[3] else args[1] * 4
elif m is nn.BatchNorm2d:
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[x] for x in f)
elif m in frozenset({Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}):
args.append([ch[x] for x in f])
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
args.insert(1, [ch[x] for x in f])
elif m is CBLinear:
c2 = make_divisible(min(args[0][-1], max_channels) * width, 8)
c1 = ch[f]
args = [c1, [make_divisible(min(c2_, max_channels) * width, 8) for c2_ in args[0]], *args[1:]]
elif m is CBFuse:
c2 = ch[f[-1]]
elif isinstance(m, str):
t = m
if len(args) == 2:
m = timm.create_model(m, pretrained=args[0], pretrained_cfg_overlay={'file': args[1]},
features_only=True)
elif len(args) == 1:
m = timm.create_model(m, pretrained=args[0], features_only=True)
c2 = m.feature_info.channels()
elif m in {unireplknet_a, unireplknet_f, unireplknet_p, unireplknet_n, unireplknet_t, unireplknet_s, unireplknet_b, unireplknet_l, unireplknet_xl
}:
m = m(*args)
c2 = m.channel
else:
c2 = ch[f]
if isinstance(c2, list):
is_backbone = True
m_ = m
m_.backbone = True
else:
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
t = str(m)[8:-2].replace('__main__.', '') # module type
m.np = sum(x.numel() for x in m_.parameters()) # number params
m_.i, m_.f, m_.type = i + 4 if is_backbone else i, f, t # attach index, 'from' index, type
if verbose:
LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<60}{str(args):<50}") # print
save.extend(x % (i + 4 if is_backbone else i) for x in ([f] if isinstance(f, int) else f) if
x != -1) # append to savelist
layers.append(m_)
if i == 0:
ch = []
if isinstance(c2, list):
ch.extend(c2)
for _ in range(5 - len(ch)):
ch.insert(0, 0)
else:
ch.append(c2)
return nn.Sequential(*layers), sorted(save)
Next, create yolov8-unireplknet.yaml in the model configuration directory (e.g., ultralytics/cfg/models/v8) with the following content:
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
# 0-P1/2
# 1-P2/4
# 2-P3/8
# 3-P4/16
# 4-P5/32
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, unireplknet_a, []] # 4
- [-1, 1, SPPF, [1024, 5]] # 5
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 6
- [[-1, 3], 1, Concat, [1]] # 7 cat backbone P4
- [-1, 3, C2f, [512]] # 8
- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 9
- [[-1, 2], 1, Concat, [1]] # 10 cat backbone P3
- [-1, 3, C2f, [256]] # 11 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]] # 12
- [[-1, 8], 1, Concat, [1]] # 13 cat head P4
- [-1, 3, C2f, [512]] # 14 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]] # 15
- [[-1, 5], 1, Concat, [1]] # 16 cat head P5
- [-1, 3, C2f, [1024]] # 17 (P5/32-large)
- [[11, 14, 17], 1, Detect, [nc]] # Detect(P3, P4, P5)
Finally, modify the model configuration path in train.py to point to yolov8-unireplknet.yaml and run training.
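For reference, a minimal train.py might look like the sketch below; the dataset path and hyper-parameters are placeholders to adjust for your project. Pretrained backbone weights, if desired, can be passed through the backbone's weights argument in the YAML args list.
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('yolov8-unireplknet.yaml')  # build the model from the new config
    model.train(data='coco128.yaml', epochs=100, imgsz=640, batch=16)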