在yolov10改进的时候,经常可以看到需要修改parse_model方法,但是相信很多东西都不知道这个方法是干嘛的,以及流程方式,所以今天给大家详细介绍一下这些变量的含义和作用,方便大家理解原理。
源代码
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
"""Parse a YOLO model.yaml dictionary into a PyTorch model."""
import ast
# Args
max_channels = float("inf")
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
if scales:
scale = d.get("scale")
if not scale:
scale = tuple(scales.keys())[0]
LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
depth, width, max_channels = scales[scale]
if act:
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
if verbose:
(f"{colorstr('activation:')} {act}") # print
if verbose:
(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
ch = [ch]
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in {
Classify,
Conv,
ConvTranspose,
GhostConv,
Bottleneck,
GhostBottleneck,
SPP,
SPPF,
DWConv,
Focus,
BottleneckCSP,
C1,
C2,
C2f,
RepNCSPELAN4,
ADown,
SPPELAN,
C2fAttn,
C3,
C3TR,
C3Ghost,
nn.ConvTranspose2d,
DWConvTranspose2d,
C3x,
RepC3,
PSA,
SCDown,
C2fCIB
}:
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(min(c2, max_channels) * width, 8)
if m is C2fAttn:
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8) # embed channels
args[2] = int(
max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
) # num heads
args = [c1, c2, *args[1:]]
if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB):
args.insert(2, n) # number of repeats
n = 1
elif m in {CARAFE}:
c2 = ch[f]
args = [c2,*args]
elif m is AIFI:
args = [ch[f], *args]
elif m in {HGStem, HGBlock}:
c1, cm, c2 = ch[f], args[0], args[1]
args = [c1, cm, c2, *args[2:]]
if m is HGBlock:
args.insert(4, n) # number of repeats
n = 1
elif m is ResNetLayer:
c2 = args[1] if args[3] else args[1] * 4
elif m is nn.BatchNorm2d:
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[x] for x in f)
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
args.append([ch[x] for x in f])
if m is Segment:
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
args.insert(1, [ch[x] for x in f])
elif m is CBLinear:
c2 = args[0]
c1 = ch[f]
args = [c1, c2, *args[1:]]
elif m is CBFuse:
c2 = ch[f[-1]]
else:
c2 = ch[f]
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
t = str(m)[8:-2].replace("__main__.", "") # module type
= sum(x.numel() for x in m_.parameters()) # number params
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
if verbose:
(f"{i:>3}{str(f):>20}{n_:>3}{:10.0f} {t:<45}{str(args):<30}") # print
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
if i == 0:
ch = []
ch.append(c2)
return nn.Sequential(*layers), sorted(save)
代码解析
parse_model
函数的作用是将 YOLO 模型的配置(通常在 YAML 文件中定义)解析并构建成 PyTorch 模型。这个函数接收一个模型配置字典 d
,输入通道数 ch
,以及一个可选的布尔值 verbose
来控制是否打印详细的构建信息。下面是对这个函数的逐行解释:
- 导入模块:
import ast
导入 ast
模块,用于将字符串形式的 Python 表达式转换为 Python 对象。
- 提取配置参数:
max_channels = float("inf")
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
从模型配置字典 d
中提取必要的参数,包括类别数 nc
、激活函数 act
、尺度 scales
、深度倍数 depth
、宽度倍数 width
和关键点形状 kpt_shape
。如果 scales
存在,则进一步提取模型的尺度参数。
- 设置默认激活函数:
if act:
Conv.default_act = eval(act)
if verbose:
(f"{colorstr('activation:')} {act}")
如果配置中指定了激活函数,则设置为 Conv
类的默认激活函数,并在 verbose
模式下打印。
- 初始化日志信息:
if verbose:
(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
如果 verbose
为 True,初始化日志信息的格式。
- 初始化构建参数:
ch = [ch]
layers, save, c2 = [], [], ch[-1]
初始化通道列表 ch
,层列表 layers
,保存列表 save
和当前输出通道数 c2
。
- 遍历模型配置并构建层:
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):
遍历模型的 backbone
和 head
配置,f
表示输入来源,n
表示重复次数,m
表示模块类型,args
表示模块参数。
- 获取模块:
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]
根据模块类型 m
获取对应的 PyTorch 模块。
- 处理字符串参数:
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
将 args
中的字符串参数转换为实际的 Python 对象。
- 调整深度和宽度:
n = n_ = max(round(n * depth), 1) if n > 1 else n
根据深度倍数调整重复次数。
- 构建特定模块:
- 对于不同的模块类型,如
Classify
、Conv
、ConvTranspose
等,根据输入通道ch[f]
、输出通道args[0]
和其他参数构建模块。 - 对于
C2fAttn
模块,特别处理嵌入通道数和头数。
- 处理特殊模块:
- 对于
CARAFE
、AIFI
、HGStem
、HGBlock
、ResNetLayer
、nn.BatchNorm2d
、Concat
、Detect
、RTDETRDecoder
、CBLinear
和CBFuse
等特殊模块类型,进行特定的参数处理。
- 构建模块并附加信息:
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)
t = str(m)[8:-2].replace("__main__.", "")
= sum(x.numel() for x in m_.parameters())
m_.i, m_.f, m_.type = i, f, t
if verbose:
(f"{i:>3}{str(f):>20}{n_:>3}{:10.0f} {t:<45}{str(args):<30}")
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)
layers.append(m_)
构建模块 m_
,计算参数数量 ,并附加索引、来源和类型信息。如果
verbose
为 True,打印模块信息,并更新保存列表。
- 更新通道列表:
if i == 0:
ch = []
ch.append(c2)
更新通道列表 ch
,为下一层的构建做准备。
- 返回构建的模型和保存列表:
return nn.Sequential(*layers), sorted(save)
返回构建的 PyTorch 模型和需要保存的层的索引列表。
总结:parse_model
函数的作用是根据 YOLO 模型的配置字典,构建并返回一个 PyTorch 模型和需要保存的层的索引列表。这个函数处理了多种模块类型和参数,能够灵活地构建复杂的 YOLO 模型架构。