在yolov10改进的时候,经常可以看到需要修改parse_model方法,但是相信很多东西都不知道这个方法是干嘛的,以及流程方式,所以今天给大家详细介绍一下这些变量的含义和作用,方便大家理解原理。

yolov10中,tasks.py的parse_model方法详解_parse_model

源代码

def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """Parse a YOLO model.yaml dictionary into a PyTorch model."""
    import ast

    # Args
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    if scales:
        scale = d.get("scale")
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            (f"{colorstr('activation:')} {act}")  # print

    if verbose:
        (f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in {
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            RepNCSPELAN4,
            ADown,
            SPPELAN,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
            PSA,
            SCDown,
            C2fCIB
        }:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            if m is C2fAttn:
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)  # embed channels
                args[2] = int(
                    max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
                )  # num heads

            args = [c1, c2, *args[1:]]
            if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB):
                args.insert(2, n)  # number of repeats
                n = 1
        elif m in {CARAFE}:
            c2 = ch[f]
            args = [c2,*args]
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in {HGStem, HGBlock}:
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        elif m is CBLinear:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]
        elif m is CBFuse:
            c2 = ch[f[-1]]
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
         = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            (f"{i:>3}{str(f):>20}{n_:>3}{:10.0f}  {t:<45}{str(args):<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)


代码解析

parse_model 函数的作用是将 YOLO 模型的配置(通常在 YAML 文件中定义)解析并构建成 PyTorch 模型。这个函数接收一个模型配置字典 d,输入通道数 ch,以及一个可选的布尔值 verbose 来控制是否打印详细的构建信息。下面是对这个函数的逐行解释:

  1. 导入模块
import ast

导入 ast 模块,用于将字符串形式的 Python 表达式转换为 Python 对象。

  1. 提取配置参数
max_channels = float("inf")
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))

从模型配置字典 d 中提取必要的参数,包括类别数 nc、激活函数 act、尺度 scales、深度倍数 depth、宽度倍数 width 和关键点形状 kpt_shape。如果 scales 存在,则进一步提取模型的尺度参数。

  1. 设置默认激活函数
if act:
    Conv.default_act = eval(act)
    if verbose:
        (f"{colorstr('activation:')} {act}")

如果配置中指定了激活函数,则设置为 Conv 类的默认激活函数,并在 verbose 模式下打印。

  1. 初始化日志信息
if verbose:
    (f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")

如果 verbose 为 True,初始化日志信息的格式。

  1. 初始化构建参数
ch = [ch]
layers, save, c2 = [], [], ch[-1]

初始化通道列表 ch,层列表 layers,保存列表 save 和当前输出通道数 c2

  1. 遍历模型配置并构建层
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):

遍历模型的 backbonehead 配置,f 表示输入来源,n 表示重复次数,m 表示模块类型,args 表示模块参数。

  1. 获取模块
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]

根据模块类型 m 获取对应的 PyTorch 模块。

  1. 处理字符串参数
for j, a in enumerate(args):
    if isinstance(a, str):
        with contextlib.suppress(ValueError):
            args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

args 中的字符串参数转换为实际的 Python 对象。

  1. 调整深度和宽度
n = n_ = max(round(n * depth), 1) if n > 1 else n

根据深度倍数调整重复次数。

  1. 构建特定模块
  • 对于不同的模块类型,如 ClassifyConvConvTranspose 等,根据输入通道 ch[f]、输出通道 args[0] 和其他参数构建模块。
  • 对于 C2fAttn 模块,特别处理嵌入通道数和头数。
  1. 处理特殊模块
  • 对于 CARAFEAIFIHGStemHGBlockResNetLayernn.BatchNorm2dConcatDetectRTDETRDecoderCBLinear 和 CBFuse 等特殊模块类型,进行特定的参数处理。
  1. 构建模块并附加信息
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)
t = str(m)[8:-2].replace("__main__.", "")
 = sum(x.numel() for x in m_.parameters())
m_.i, m_.f, m_.type = i, f, t
if verbose:
    (f"{i:>3}{str(f):>20}{n_:>3}{:10.0f}  {t:<45}{str(args):<30}")
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)
layers.append(m_)

构建模块 m_,计算参数数量 ,并附加索引、来源和类型信息。如果 verbose 为 True,打印模块信息,并更新保存列表。

  1. 更新通道列表
if i == 0:
    ch = []
ch.append(c2)

更新通道列表 ch,为下一层的构建做准备。

  1. 返回构建的模型和保存列表
return nn.Sequential(*layers), sorted(save)

返回构建的 PyTorch 模型和需要保存的层的索引列表。

总结:parse_model 函数的作用是根据 YOLO 模型的配置字典,构建并返回一个 PyTorch 模型和需要保存的层的索引列表。这个函数处理了多种模块类型和参数,能够灵活地构建复杂的 YOLO 模型架构。