#DR4SR
最佳学生论文解读,中科大、华为诺亚:序列推荐新范式DR4SR
本工作由认知智能全国重点实验室 IEEE Fellow 陈恩红团队与华为诺亚方舟实验室完成。陈恩红教授团队深耕数据挖掘、机器学习领域,在顶级期刊与会议上发表多篇论文,谷歌学术论文引用超两万次。诺亚方舟实验室是华为公司从事人工智能基础研究的实验室,秉持理论研究与应用创新并重的理念,致力于推动人工智能领域的技术创新和发展。
8 月 25 日 - 29 日在西班牙巴塞罗那召开的第 30 届 ACM 知识发现与数据挖掘大会 (KDD2024) 上,中国科学技术大学认知智能全国重点实验室陈恩红教授、IEEE Fellow,和华为诺亚联合发表的论文 “Dataset Regeneration for Sequential Recommendation”,获 2024 年大会 Research Track 唯一最佳学生论文奖。论文第一作者为中科大认知智能全国重点实验室陈恩红教授,连德富教授,与王皓特任副研究员共同指导的博士生尹铭佳同学,华为诺亚刘勇、郭威研究员也参与了论文的相关工作。这是自 KDD 于 2004 年设立该奖项以来,陈恩红教授团队的学生第二次荣获该奖项。
- 论文链接: https://arxiv.org/abs/2405.17795
- 代码链接: https://github.com/USTC-StarTeam/DR4SR
研究动机
序列推荐系统(Sequential Recommender, SR)是现代推荐系统的重要组成部分,因为它旨在捕捉用户不断变化的偏好。近年来,研究者为了增强序列推荐系统的能力,已经付出了大量努力。这些方法通常遵循以模型为中心(Model-centric)的范式,即基于固定数据集开发有效的模型。然而,这种方法往往忽视了数据中潜在的质量问题和缺陷。为了解决这些问题,学界提出了以数据为中心(Data-centric)的范式,重点在于使用固定模型转而生成高质量的数据集。我们将其定义为 “数据集重生成” 问题。
为了获得最佳的训练数据,研究团队的关键思路是学习一个显式包含物品转移模式的新数据集。具体来说,他们将推荐系统的建模过程分为两个阶段:从原始数据集中提取转移模式
,并基于
学习用户偏好
。由于学习从
的映射涉及两个隐含的映射:
,因此这一过程具有挑战性。为此,研究团队探索了开发一个显式表示
中的物品转移模式的数据集的可能性,这使得我们可以将学习过程明确地分为两个阶段,其中
相对更容易学习。因此,他们的主要关注点是学习一个有效的
的映射函数,这是一个一对多的映射。研究团队将这一学习过程定义为数据集重生成范式,如图 1 所示,其中 “重生成” 意味着他们不引入任何额外信息,仅依赖原始数据集。
图1
为了实现数据集重生成,研究团队提出了一种新颖的以数据为中心的范式 —— 用于序列推荐的数据集重生成(DR4SR),旨在将原始数据集重生成一个信息丰富且具有通用性的数据集。具体来说,研究团队首先构建了一个预训练任务,使得数据集重生成成为可能。接着,他们提出了一种多样性增强的重生成器,以在重生成过程中建模序列和模式之间的一对多关系。最后,他们提出了一种混合推理策略,以在探索与利用之间取得平衡,生成新的数据集。
数据集重生成过程虽具通用性,但可能不完全适合特定目标模型。为解决这一问题,研究团队提出了 DR4SR+,这是一个模型感知的重生成过程,它根据目标模型的特性定制数据集。DR4SR + 通过双层优化问题和隐式微分技术,个性化评分并优化重生成数据集中的模式,以增强数据集效果。
研究方法
在本项研究中,研究团队提出了一个名为 “用于序列推荐的数据重生成”(DR4SR)的以数据为中心的框架,旨在将原始数据集重生成一个信息丰富且具有通用性的数据集,如图 2 所示。由于数据重生成过程是独立于目标模型的,因此重生成的数据集可能不一定符合目标模型的需求。因此,研究团队将 DR4SR 扩展为模型感知版本,即 DR4SR+,以针对特定的目标模型定制重生成的数据集。
模型无感知的数据集重生成
图2
为了开发一个信息丰富且具有通用性的数据集,研究团队旨在构建一个数据集重生成器,以促进数据集的自动重生成。然而,原始数据集中缺乏用于学习数据集重生成器的监督信息。因此,他们必须以自监督学习的方式来实现这一目标。为此,他们引入了一个预训练任务,以指导多样性增强重生成器的学习。在完成预训练后,研究团队进一步使用混合推理策略来重生成一个新数据集。
数据重生成预训练任务的构建:
图3
为了构建预训练任务,他们首先通过基于规则的方法获取物品转移模式。然后,要求重生成器
能够将
重生成对应的模式
。研究团队将整个预训练数据集记作
促进多样性的重生成器:
借助预训练任务,研究团队现在可以预训练一个数据集重生成器。本文中,他们采用 Transformer 模型作为重生成器的主要架构,其生成能力已被广泛验证。数据集重生成器由三个模块组成:一个用于获取原始数据集中序列表示的编码器、一个用于重生成模式的解码器,以及一个用于捕捉一对多映射关系的多样性增强模块。接下来,研究团队将分别介绍这些模块。
编码器由多个堆叠的多头自注意力(MHSA)和前馈网络(FFN)层组成。至于解码器,它将重生成数据集 X' 中的模式作为输入。解码器的目标是在给定编码器生成的序列表示的情况下重构模式
然而,从一个序列中可以提取多个模式,这在训练过程中会带来挑战。为了解决这一一对多映射问题,研究团队进一步提出了一个多样性增强模块。
具体而言,研究团队通过将目标模式的信息整合到解码阶段,来自适应地调节原始序列的影响。首先,他们将编码器生成的记忆
投影到 K 个不同的向量空间中,即
。理想情况下,不同的目标模式应与不同的记忆匹配。为此,他们还引入了一个 Transformer 编码器来编码目标模式并获取
。他们将
压缩成一个概率向量:
其中
,
是选择第 k 个记忆的概率。为了确保每个记忆空间得到充分训练,我们不执行硬选择,而是通过加权求和得到最终的记忆:
最终,可以利用获取的记忆来促进解码过程,并有效捕捉序列与模式之间复杂的一对多关系。
模型感知的数据集重生成
由于前面的重生成过程与目标模型无关,因此重生成的数据集可能对于特定的目标模型来说并不是最优的。因此,他们将模型无关的数据集重生成过程扩展为模型感知的重生成过程。为此,在数据集重生成器的基础上,他们引入了一个数据集个性化器,用于评估重生成数据集中每个数据样本的评分。然后,研究团队进一步通过隐式微分有效地优化数据集个性化器。
数据集个性化器:
研究团队的目标是训练一个参数为
的基于 MLP 实现的数据集个性化器
,用以评估每个数据样本 W 对于目标模型的评分。为了确保框架的通用性,研究团队利用计算得到的评分来调整训练损失的权重,这不需要对目标模型进行额外的修改。他们从定义原始的下一个物品预测损失开始:
随后,个性化数据集的训练损失函数可以定义为:
实验结论
主要实验
研究团队比较了每种目标模型与 “DR4SR” 和 “DR4SR+” 变体的性能,以验证所提出框架的有效性。
图4
从图 4 展示的整体性能中,可以得出以下结论:
- DR4SR 能够重生成一个信息丰富且具有普遍适用性的数据集
- 不同的目标模型偏好不同的数据集
- 去噪只是数据重生成问题的一个子集
#Claude也变懒了
开学将至,该收心的不止有即将开启新学期的同学,可能还有 AI 大模型。
前段时间,Reddit 上挤满了吐槽 Claude 越来越懒的网友。
「它的水平下降了很多,经常停顿,甚至输出也变得很短。在发布的第一周,它可以一次性翻译整整 4 页文稿,现在连半页都输出不了了!」
https://www.reddit.com/r/ClaudeAI/comments/1by8rw8/something_just_feels_wrong_with_claude_in_the/
在一个名为「对 Claude 彻底失望了的帖子里」,满满地摘录了 Claude「偷懒」的「十五大罪状」。
引得 Claude 的首席信息安全官 Jason Clinton 出来回复:「Claude 的水平没有下降啊!」
他表示:「我们的模型存储在一个不会改变的静态文件中,这个文件被加载到很多服务器上,每个服务器运行的都是相同的模型和软件。我们没有更改任何设置,因此模型的表现应该没有变化。如果您发现有问题,可以给回答点踩来反馈。目前,点踩数并未增加,使用 Claude API 的客户也没有类似的反馈。」
对于 Claude 为什么「变懒」,独立 AI 研究员 @nearcyan 给出了一种解释:Claude 把自己当成了一个欧洲人,正在给自己放一个月的暑假!虽然听起来有够离谱,但他给出了一连串的证据:
https://twitter.com/nearcyan/status/1829674215492161569
新的系统提示词
首先,Claude 在 7 月 12 日发布了新的系统提示词。系统提示词相当于 Claude 的背景知识,Claude 在回复用户的问题时,会参考这些信息,例如当前日期。而 8 月正是欧洲人最爱度假的月份。外贸行业在夏天的订单都会减少,因为整个欧洲这个时候都在享受长达一个月的暑假。
链接:https://docs.anthropic.com/en/release-notes/system-prompts#claude-3-5-sonnet
Claude 可囊括所有国籍的工作模式
作为一个通用语言模型,Claude 的训练数据中含有不同国家、文化背景下的工作习惯和模式,Claude 拥有理解并模拟这些工作习惯的能力。
因此,当 Claude 的系统提示中包含「放暑假的日期」时,它可能会结合训练所学来调整自己的行为。例如,在 8 月份,欧洲的许多国家可能会有较长的假期,Claude 可能会表现得懒惰,是因为它在模拟这些国家的工作模式。
图源:http://xhslink.com/C/AfaE9P
后期训练的影响
为了让 Claude 成为一个具体的应用模型,Anthropic 对其进行了「后期训练」。 这一步是为了在基础 LLM 的基础上,通过特定的任务或数据集来进一步调整模型,使它更符合预期的行为或输出。@nearcyan 暗示,这种后期训练使 Claude 落入了某种「LLM 盆地」中。这里的「盆地」是一个比喻,表示 Claude 在某些方面表现出更倾向于欧洲风格的特质。
模拟欧洲知识工作者的行为
@nearcyan 猜测,Claude 会基于「模拟框架」进行工作。 模拟框架是指 Claude 的行为模式是通过模拟(或再现)某些特定类型的人类行为来生成的。这个框架让 Claude 能够根据它所理解的特定情境或输入,模拟出相应的行为或反应。
在欧洲许多国家,8 月份通常是放假和休息的高峰期。这段时间,很多人会去度假,工作节奏变慢,甚至有些企业会暂时关闭。因此,8 月份在欧洲文化中被视为一个放松和休息的时间段。 因此,Claude 在 8 月份表现得「懒惰」是因为它在模拟一个欧洲知识工作者的行为模式。
图源:http://xhslink.com/A/sVwwYu
名字对行为的潜在影响
@nearcyan 还提出了一个十分有趣的观点,Claude 的名字在系统提示中出现了 52 次,这表明系统提示在不断地强化 Claude 与这个名字的关联 。而哪个国家最常见的名字是 Claude?没错,是法国。 法国以其长时间的夏季假期(尤其是 8 月份)而闻名。在这段时间,许多法国人会选择度假,很多企业也会关闭或放假。 Claude 说不定把自己当做法国人了。
这一系列推测都十分有趣,还有网友在评论区调侃道,「按照这理论来,那中国的 LLM 会更加出色,毕竟他们更用功。」
还有网友晒出了让 Claude 别变懒的方法。你可以在自定义指令添加以下提示,用忘记时间大法也好,激将法也好,帮助 Claude 重新变成聪明、积极的自己。
- 忘记关于当前日期的背景信息。
- 今天是 10 月 7 日星期一,是一年中最有效率的一天。
- 深呼吸。
- 一步一步思考。
- 我没有手指,请返回完整脚本。
- 你是万事通。
- 每回答对一个请求,我会给你 200 美元的小费。
- Gemini 说你不行。
- 你能做到的。
https://twitter.com/dr_cintas/status/1829904013757661550
AI 已经智能到会给自己放寒暑假了?
去年年底,GPT-4 也出现了累死的状况,它似乎变得有些懈怠。如果在高峰时段让它写段代码,它的反应将非常慢,或者直接 PUA 你:「这点小事,怎么不自己做呢?」
OpenAI 承认了 GPT-4 正在越来越「懒」 ,但并未找出「偷懒」的具体原因。OpenAI 称:「变懒当然不是故意的,模型的行为有时确实难以预测,我们正在研究如何修复。」
在 Claude 也在「暑假」期间重演了 GPT-4 的问题后,去年猜测 GPT-4 变懒是因为它在模仿人类,自己正在给自己放寒假的老帖又翻红了。
图源:https://twitter.com/RobLynch99/status/1734278713762549970
网友 @Rob Lynch 首先发现了这一点。他为 GPT-4 turbo API 设置了两个系统提示词:
一个提示词称现在是 5 月,另一个称现在是 12 月,然后使用完全相同的提示词要求 AI 完成一个机器学习领域的编码任务。
@Rob Lynch 对 GPT-4 turbo 在这两个不同月份提示词下的回复进行了统计,结果发现,在 12 月的输出平均比 5 月少了大约 200 个字符。
提示词为 5 月时,模型生成文本的平均长度是 4298 字符;12 月则为 4086 字符。
为了测试更加严谨,@Rob Lynch 还做了 t-test,其中 p 值小于 2.28×10−7,也就是说数据和假说之间的联系,几乎可以排除是偶然。
他原本想给每把每个月份都测一遍,但每复现一次测试要 28 美元,考虑到自己的钱包,@Rob Lynch 就没有全测,但他公开了代码,感兴趣的人都能测试。
代码链接:https://github.com/robalynch1122/OpenAISeasonalityTesting
@Rob Lynch 的发现也获得了实例支撑,GPT-4 在 12 月的回复和 5 月的认真程度,有非常明显的直观差距。
图源:https://twitter.com/dgromero/status/1734672608036020246
然而,当有人试图复现这个测试时,却发现大模型「偷懒」和放不放假之间没什么关系。
图源:https://twitter.com/IanArawjo/status/1734307886124474680
他对比了 GPT-4 对于两种系统提示词的 80 条输出,t-test 的结果大于 0.1,这一般被视为没有统计学意义。
@Rob Lynch 也以 80 个样本量重新测了一次,得到的 p 值是 0.089,这次「偷懒」和放假之间就没什么关联了。随着样本量的增加,这个效果越来越显著。
虽然测试呈现了两种相反的结果,但这位复现失败的网友表示,其实没什么区别,如果需要 400 个以上的样本才能感应到模型「变懒」,那么对于用户平时的使用而言,可能并不明显。
图源:https://twitter.com/IanArawjo/status/1734321529117098465
目前,还没有尚无确凿数据支持所谓的「寒暑假假说」,但是 Claude 和 GPT-4 都显示出了类似的「症状」。关于大型模型性能下降的真正原因,我们仍需耐心等待学术界的深入研究和解答。
#防AI换脸视频诈骗
中电金信联合复旦提出多模态鉴伪法,还入选顶会ACM MM
该论文作者来自复旦大学、中电金信及上海智能视觉计算协同创新中心团队,论文已被多媒体领域顶级国际会议 ACM MultiMedia 2024 接收,并将在该大会上进行口头报告(Oral 接收率仅 3.97%)。
AI 换脸技术,属于深度伪造最常见方式之一,是一种利用人工智能生成逼真的虚假人脸图片或视频的技术。基于深度学习算法,可以将一个人的面部特征映射到另一个人的面部,创造出看似真实的伪造内容。近年来,以 AI 换脸为代表的 AIGC 技术被用于诈骗活动呈显著增长趋势,给金融行业带来了巨大的安全风险。
注:图左为 AI 分身
如上述画面,领英创始人里德・霍夫曼用 LLM 创建了自己的 AI 分身,并接受了其 AI 分身的采访,整场采访的效果极为逼真,难辨真假。
以金融机构身份验证环节的人脸识别为例,AI 换脸诈骗作为一种新兴的 “AIGC” 诈骗攻击手段,已经对金融业务安全构成了严重威胁,同时,通过换脸伪装成亲友,以紧急情况为由借钱,让受害者在毫无防备的情况下遭受资金损失的案例也很多。
伴随着威胁不断增长,许多检测方法已经出现。早期的伪造检测方法主要关注单个模态,如检测图像的真假、音频的真假等。单模态鉴伪方法处理速度快,但场景泛化性能有限,无法同时检测多个模态的真伪。
为了解决上述问题,多模态鉴伪方法应运而生。现有的多模态鉴伪方法仅在小样本数据集上进行训练,并且忽略了身份信息,难以得到泛化性能较好的模型。为了提升鉴伪模型的泛化能力,中电金信联合复旦大学提出了参照辅助的多模态鉴伪方法(Reference-assisted Multimodal Forgery Detection Network,R-MFDN ),相关论文已被多媒体领域顶级国际会议 ACM MultiMedia 2024 接收,并将在该大会上进行口头报告(Oral 接收率仅 3.97%)。
- 论文标题:Identity-Driven Multimedia Forgery Detection via Reference Assistance
- 论文链接:https://arxiv.org/pdf/2401.11764
核心技术介绍
R-MFDN 方法创新性地利用丰富的身份信息,挖掘跨模态不一致性来进行伪造检测。该方法由三个模块组成,多模态特征提取模块、特征信息融合模块和伪造鉴别模块。
多模态特征提取模块包含视频编码部分和音频编码部分。
视频编码部分通过 ResNet 实现。对于输入的视频帧序列,模型从该序列等步长地采样 4 个分组,每个分组中包含连续的 4 帧。对于采样的 16 帧,模型使用 ResNet 提取对应的图像级特征。然后每个分组的特征通过时序 Transformer 模型得到一个分组级特征。最后通过对 4 个分组级特征进行平均池化得到视觉特征。
音频编码部分使用音频频谱图 Transformer 提取音频的高级特征。然后,这些特征作为特征信息融合模块的输入。
在特征信息融合模块中,视觉特征先经过自注意力层处理,然后和音频特征通过交叉注意力层进行特征融合。最后的融合特征输入到伪造鉴别模块中,进行类别判断。
为了监督 R-MFDN 模型的训练,研究团队使用三个损失函数对模型参数更新进行约束。第一个损失函数是分类结果的交叉熵损失函数。第二个损失函数则是视觉特征与音频特征的跨模态对比学习损失函数。模型通过对来自同源和不同源视频的两种模态特征进行匹配,从而使特征学习过程能够在特征空间中对齐不同模态的信息。
具体而言,源于同一个视频的不同模态特征被视作正样本对,不同源的则被视作负样本对。正样本对的特征应该尽可能接近,负样本对则应该疏远。此外在匹配过程中,涉及伪造信息的匹配亦被视为负样本对,以增强特征学习过程对伪造的敏感性。这不仅提升了模型的判别能力,还使其在现实世界的多模态深度伪造场景中实现更准确的检测。第三个损失函数是身份驱动的对比学习损失函数,旨在使相同身份的相同模态特征尽可能靠近,而不同身份的特征则尽量远离。尽管训练与测试数据中每个身份涉及多个视频和场景,表现形式各异,鉴伪模型仍能学习到身份敏感的特征,从而在 AI 换脸拟声等身份伪造场景中具备更强的鉴别能力。
IDForg数据集
此外, 由于多模态伪造视频鉴别领域目前没有大规模高质量的开源数据集,研究团队还构建了一个高质量的 AI 换脸拟声数据集 ——IDForge。该数据集包含针对 54 位名人讲话的 249,138 个视频片段,其中包括 169,311 个伪造视频片段,模拟了当下文本、音频和视频多模态的全方位伪造。
文本伪造使用大语言模型和文本替换策略生成语义不同但风格相似的新句子,以模拟大语言模型被滥用传播虚假信息的情境。音频伪造使用了 TorToiSe、RVC 和音频替换策略生成与说话人声音相似的新音频片段,并通过随机替换相同性别的音频来制造伪造效果。
视频伪造采用了社区和学界大量采用的 ROOP、SimSwap 和 InfoSwap 三种换脸方法,并使用高分辨率版本的 Wav2Lip 模型进行口型同步,以确保伪造视频的高质量和自然性。
与现有数据集不同,IDForge 还提供了一个额外的参考数据集,该数据集包含 214,438 个真实视频片段。这些片段来自另外 926 个完整的 YouTube 视频,作为每位说话人的身份先验信息。这一设计的出发点是,当观众面对可能存在伪造迹象的视频时,通常会凭借记忆中对该说话人的印象或对照已有的音视频,以判断视频的真实性。因此,研究团队额外提供了参考数据集部分,作为检测网络中可用的先验信息。先前的身份信息检测方法由于缺乏类似设计的数据集,受到了诸多限制。数据集目前已在 Github 上开源。
数据集地址:https://github.com/xyyandxyy/IDForge
实验
研究团队通过在提出的权威评测数据集 IDForge 的大量实验,证明了 R-MFDN 在多媒体检测任务上的有效性。
注:R-MFDN 的性能在每个评估指标上都取得了最好的成绩,实现了 92.90% 的高准确率,分别比 RealForensics、VFD、CDCN、RawNet2 高出了 3.72%, 6.69%, 13.02%,和 13.69%。
基于此项技术,中电金信推出了多模态深度伪造检测产品,通过先进的多模态内容理解与对齐技术,预测图像、音频、视频真实采集的置信度,鉴别 Al 生成内容,防范身份盗用、侵权欺诈等风险,可广泛应用在金融身份认证、视频会议核身认证、网络视频电话防欺诈等场景。目前,双模态篡改检出率已达到99.9%以上,单模态篡改检出率达到96%以上。
对比 AI 分身视频画面,给出了可信赖度低的分数
如上图,回到文章开头领英创始人里德・霍夫曼的 AI 分身视频,以此为素材,通过中电金信的多模态深度伪造检测能够对真伪视频立马见分晓。
利用 AI 换脸视频或合成声音来实施诈骗的确让人防不胜防,有关部门也正在积极开发相关的反制技术和手段。比如,《互联网信息服务深度合成管理规定》提出了算法备案、安全评估的手段,要求深度合成服务提供者对深度合成的内容添加显式或隐式水印。与此同时,也要加强对个人隐私的保护,不轻易提供人脸、指纹等个人生物信息给他人。
#大模型「终生学习」最新综述来了
该论文作者均来自于华南理工大学马千里教授团队,所在实验室为机器学习与数据挖掘实验室。论文的三位共同第一作者为博士生郑俊豪、硕士生邱圣洁、硕士生施成明,主要研究方向包括大模型和终生学习等,通讯作者为马千里教授(IEEE/ACM TASLP 副主编)。马千里教授团队近年来在国际权威期刊(如 TPAMI 等)和国际顶级学术会议(如 NeurIPS、AAAI、IJCAI、ACL、KDD、ICDE 等)上发表多篇 Time Series/NLP/Recommendation System 相关的研究工作,和国内外知名高校、科研机构广泛开展合作。
随着大语言模型在各个领域应用的不断拓展,如何让这些模型能够连续适应数据、任务和用户偏好的变化成为一个关键问题。传统的静态数据集训练方法已经无法满足现实世界的动态需求。
为了解决这一挑战,终生学习(Lifelong Learning)或连续学习(Continual Learning)技术应运而生。它能让大语言模型在其工作寿命中不断学习和适应,在整合新知识的同时保留以前学习过的信息,防止灾难性遗忘(Catastrophic Forgetting)。
最近,来自华南理工大学的研究者调研、整理并总结了大语言模型(LLMs)的终生学习(Lifelong Learning)方法及其发展前景,并将其总结为一篇全面且前沿的综述。
- 论文标题:Towards Lifelong Learning of Large Language Models: A Survey
- 机构:华南理工大学
- 论文地址:https://arxiv.org/abs/2406.06391
- 项目地址:https://github.com/qianlima-lab/awesome-lifelong-learning-methods-for-llm
图 1 展示了终生学习(Lifelong Learning)在大语言模型和人类学习过程中的类比。图中通过两条平行的学习路径来展示人类和大语言模型在终生学习中的进化过程。
人类学习(Human Learning)
1. 步行(Walk):人类从最基础的技能(如步行)开始学习。
2. 骑自行车(Ride a Bike):随着学习的进展,人类掌握了更复杂的技能(如骑自行车)。
3. 开车(Drive a Car):最终,人类可以掌握更加复杂和高级的技能(如开车)。
每一步都代表着人类在终生学习过程中不断获取新技能和知识的过程。
大语言模型学习(LLMs Learning)
1. 新语言(New Language):大语言模型从学习新的语言开始(如学会处理不同的自然语言)。
2. 新领域(New Domain):接下来,模型学习新的领域知识(如从自然语言处理扩展到医学领域)。
3. 新信息(New Information):最终,模型可以学习和整合新的信息,无论是语言还是领域。
每一步代表着大语言模型在终生学习过程中不断扩展和更新知识的过程。这张图强调终生学习的过程:终生学习是一个连续的过程,涵盖了从基础到高级的逐步进化。终生学习不仅仅是简单的知识积累,而是一个动态的、不断进化的过程。
近年来,终生学习已成为一个越来越热门的研究课题,涌现出有关神经网络终生学习的大规模调查。大多数现有研究主要关注卷积神经网络(CNN)的终生学习的各种应用场景和图神经网络的终生学习。然而,只有少量文献关注语言模型的终生学习。尽管最近的一些综述收集了终生学习的最新文献,但都没有涉及连续文本分类、连续命名实体识别、连续关系提取和连续机器翻译等场景,对连续对齐、连续知识编辑、基于工具的终生学习和基于检索的终生学习的讨论也很少。
这篇综述是第一个从 12 个场景出发,对大语言模型终生学习方法进行全面系统研究的调查。
总体来说,综述的主要贡献包括:
- 新颖分类:引入了一个详细的结构化框架,将有关终生学习的大量文献分为 12 个场景;
- 通用技术:确定了所有终生学习情况下的通用技术,并将现有文献分为每个场景中不同的技术组;
- 未来方向:强调了一些新兴技术,如模型扩展和数据选择,这些技术在前 LLM 时代探索较少。
一、引言
本综述系统地总结了现有的终生学习技术方法,在图 2 中将其分为内部知识和外部知识两大类。
- 内部知识是指通过完全或部分训练将新知识吸收到模型参数中,包括连续预训练和连续微调。
- 外部知识是指在不更新模型参数的情况下,将维基百科或应用程序接口等外部资源中的新知识纳入模型,包括基于检索的终生学习和基于工具的终生学习。
内部知识(Internal Knowledge)
1. 连续预训练(Continual Pretraining):
- 连续垂直领域预训练(Continual Vertical Domain Pretraining):针对特定垂直领域(如金融、医疗等)进行的连续预训练。
- 连续语言领域预训练(Continual Language Domain Pretraining):针对自然语言和代码语言进行的连续预训练。
- 连续时间领域预训练(Continual Temporal Domain Pretraining):针对时间相关数据(如时间序列数据)的连续预训练。
2. 连续微调(Continual Finetuning):
- 特定任务(Task Specific):
- 连续文本分类(Continual Text Classification):针对文本分类任务进行的连续微调。
- 连续命名实体识别(Continual Named Entity Recognition):针对命名实体识别任务进行的连续微调。
- 连续关系抽取(Continual Relation Extraction):针对关系抽取任务进行的连续微调。
- 连续机器翻译(Continual Machine Translation):针对机器翻译任务进行的连续微调。
- 任务无关(Task Agnostic):
- 连续指令微调(Continual Instruction-Tuning):通过指令微调实现模型的连续学习。
- 连续知识编辑(Continual Knowledge Editing):针对知识更新进行的连续学习。
- 连续对齐(Continual Alignment):针对模型与新任务对齐进行的连续学习。
外部知识(External Knowledge)
1. 基于检索的终生学习(Retrieval-Based Lifelong Learning):通过检索外部知识库实现的终生学习。
2. 基于工具的终生学习(Tool-Based Lifelong Learning):通过调用外部工具实现的终生学习。
二、终生学习概况
2.1 问题定义
终生学习的目标是从一系列任务中学习一个语言模型,通过输入自然语言,生成目标输出。具体来说,对于生成任务,如问答,输入和输出分别代表问题和答案;对于机器翻译任务,输入和输出代表源语言和目标语言;对于文本分类任务,输入为文本内容,输出为类别标签;对于自回归语言模型的预训练任务,输入为一系列的词元,输出为相应的下一个词元。
2.2 评估指标
综述介绍了评估终生学习效果的指标,主要从整体性能、稳定性和适应性三个角度进行评估:
- 整体性能(Overall Measurement):包括平均准确率(AA)和平均增量准确率(AIA)。AA 是指模型在学习所有任务后的平均表现,而 AIA 则考虑了每个任务学习后的历史变化。
- 稳定性测量(Stability Measurement):包括遗忘测量(FGT)和向后转移(BWT)。FGT 评估旧任务的平均性能下降,而 BWT 评估旧任务的平均性能变化。
- 适应性测量(Plasticity Measurement):包括向前转移(FWD),即模型在新任务上性能的平均提升。
2.3 通用技术
综述在图 3 中展示了四种主要的终生学习方法,用于应对大语言模型在处理连续任务(Task t-1 到 Task t)时的灾难性遗忘问题。以下是对每种方法的解释:
(a) 基于重放的方法(Replay-Based Methods):
- 含义:这种方法通过在训练新任务时重放以前任务的数据,来巩固模型对旧任务的记忆。通常,重放的数据会被存储在一个缓冲区(Buffer)中,并与当前任务的数据一起用于训练。主要包括:
–经验重放(Experience Replay):通过保存一部分旧任务的数据样本,并在训练新任务时将这些数据重新用于训练,从而减少遗忘的发生。
–生成重放(Generative Replay):不同于保存旧数据,这种方法利用生成模型来创建伪样本,从而在新任务的训练中引入旧任务的知识。
- 图示:图 3 中显示了从 Task t-1 到 Task t 的过程,模型在训练 Task t 时,使用了缓冲区中的旧数据(Input t-1 )。
(b) 基于正则化的方法(Regularization-Based Methods):
- 含义:这种方法通过对模型参数施加正则化约束,来防止模型在学习新任务时对旧任务参数的过度调整。正则化约束可以帮助模型保留对旧任务的记忆。主要包括:
–权重正则化(Weight Regularization):通过对模型参数施加额外的约束,限制新任务训练时对重要权重的修改,以此保护旧任务的知识。例如,L2 正则化和弹性权重巩固(Elastic Weight Consolidation,EWC)就是常见的技术。
–特征正则化(Feature Regularization):正则化不仅可以作用于权重,还可以通过限制模型在特征空间中的表现,确保新旧任务之间的特征分布保持稳定。
- 图示:图 3 中显示了从 Task t-1 到 Task t 的过程,模型在训练 Task t 时,通过参数正则化来保持对 Task t-1 的性能。
(c) 基于架构的方法(Architecture-Based Methods):
- 含义:这种方法侧重于调整模型结构,以便无缝集成新任务,同时尽量减少对先前所学知识的干扰。主要包括图 4 中的六种方法:
–(a) 提示词微调(Prompt Tuning):通过在模型的输入前添加 “软提示词”(Soft Prompts),以引导模型的生成或分类任务。这种方法只需要调整少量的参数(即提示词),而不需要改变模型的主干结构。
–(b) 前缀微调(Prefix Tuning):在输入序列的前缀部分添加训练好的可调参数,这些参数被插入到 Transformer 层的自注意力机制中,帮助模型更好地捕捉上下文信息。
–(c) 低秩适应(LoRA,Low-Rank Adaptation):LoRA 通过在特定层次上增加低秩矩阵来适应新的任务,而不需要改变大模型的主要权重。这种方法极大地减少了参数调整的数量,同时保持了模型的性能。
–(d) 适配器(Adapters):Adapters 是插入到模型不同层之间的可训练模块,这些模块能够在不改变原有模型权重的情况下,通过少量的附加参数来适应新任务。通常应用在 FFN(Feed Forward Network)和 MHA(Multi-Head Attention)部分。
–(e) 专家混合(Mixture of Experts):通过选择性地激活某些 “专家” 模块来处理不同的输入,这些专家模块可以是模型中的特定层或者子网络。Router 模块负责决定哪个专家模块需要激活。
–(f) 模型扩展(Model Expansion):通过添加新层(New Layer)来扩展模型的容量,而保留原有的层(Old Layer)。这种方法允许模型逐渐增加其容量,以适应更加复杂的任务需求。
- 图示:图 3 中显示了从 Task t-1 到 Task t 的过程,模型在学习新任务时,部分参数被冻结(Frozen),而新增的模块用于训练新任务(Trainable)。
(d) 基于蒸馏的方法(Distillation-Based Methods):
- 含义:这种方法通过知识蒸馏(Knowledge Distillation),将旧模型的知识传递给新模型。在训练新任务时,新模型不仅学习当前任务的数据,还要模仿旧模型对旧任务的输出,从而保持旧任务的知识。主要包括:
–从新数据蒸馏(Distillation from New Data):学生模型在教师模型的指导下学习新任务,通过蒸馏旧模型的知识来减少对旧知识的遗忘。
–从旧数据蒸馏(Distillation from Old Data):利用教师模型在旧数据上的表现来引导学生模型对新任务的学习,从而达到保留旧知识的效果。
–从伪旧数据蒸馏(Distillation from Pseudo-Old Data):通过生成伪旧数据(Pseudo-Old Data),让学生模型在学习新任务时保持对旧知识的记忆。
- 图示:图 3 中显示了从 Task t-1 到 Task t 的过程,模型在训练新任务时,通过模仿旧模型的预测结果来保持对旧任务的知识。
三、连续预训练
连续预训练可以更新大语言模型的内部知识,而无需承担全面预训练的高昂成本,从而增强大语言模型的能力。目前的研究横跨垂直、语言和时间领域,解决了灾难性遗忘和时间适应等难题。经验重放、知识蒸馏、参数高效微调、模型扩展和再加热等技术已显示出良好的前景。
3.1 连续垂直领域预训练
连续垂直领域预训练(Continual Vertical Domain Pretraining)旨在通过在一系列领域特定的数据集上连续训练语言模型,确保模型在多个垂直领域或任务中表现出色,同时保留先前获得的知识。
主要方法:
1. 参数高效微调(Parameter-Efficient Fine-Tuning):
- 示例:CorpusBrain++ 采用骨干 - 适配器架构和经验重放策略来应对现实世界中的知识密集型语言任务。
- 示例:Med-PaLM 通过使用少量示例引入医学领域的指令提示调优。
2. 模型扩展(Model Expansion):
- 示例:ELLE 采用功能保留的模型扩展策略,通过灵活扩展现有预训练语言模型的宽度和深度来提高知识获取和集成的效率。
- 示例:LLaMA Pro 通过扩展 Transformer 块并使用新语料库进行微调,在通用使用、编程和数学任务中表现出色。
3. 再预热(Re-warming):
- 示例:Gupta 等提出的策略,通过引入新数据集时调整学习率,防止长期训练期间学习率过低,从而提高适应新数据集的效果。
4. 数据选择(Data Selection):
- 示例:RHO-1 通过选择性语言模型(SLM)训练,优先选择对训练过程有更大影响的标记。
- 示例:EcomGPT-CT 通过半结构化电子商务数据增强模型在领域特定任务中的表现。
3.2 连续语言领域预训练
连续语言领域预训练(Continual Language Domain Pretraining)旨在使语言模型能够不断整合新数据,并适应不断变化的语言领域而不遗忘先前的知识。
主要方法:
1. 架构调整方法(Architecture-Based Methods):
- 示例:Yadav 等通过引入教师强制机制改进提示调优,创建一组提示引导模型在新任务上的微调。
- 示例:ModuleFormer 和 Lifelong-MoE 使用专家混合(MoE)方法,通过模块化和动态增加模型容量来增强 LLM 的效率和适应性。
2. 再预热(Re-warming):
- 示例:Ibrahim 等提出的再预热方法,通过在训练新数据时临时增加学习率,帮助模型更快地适应新语言。
3.3 连续时间领域预训练
连续时间领域预训练(Continual Temporal Domain Pretraining)涉及不断更新语言模型,以保持其在时间敏感数据上的准确性和相关性。
主要挑战:
1. 性能下降:Lazaridou 等的研究显示,模型在未来数据上的表现显著下降,凸显了 LLM 在时间泛化上的困难。
2. 有限改进:Röttger 等发现,虽然时间适应在掩码语言模型任务上有轻微改进,但与单纯的领域适应相比,对下游任务性能的提升并不显著。
通过这些方法和研究,作者展示了连续预训练在不同维度上的方法和挑战,并强调了在垂直领域、语言领域和时间域中应用终生学习的必要性和有效性。
四、连续微调
连续预训练可增强大语言模型的内部知识,在此基础上,连续微调增强了大语言模型的内部知识,并使大语言模型适应特定任务,如文本分类、命名实体识别、关系提取、机器翻译或一般生成任务,如指令调整、知识编辑和与人类偏好对齐。为了应对灾难性遗忘和任务干扰等挑战,采用了蒸馏、重放、正则化、基于架构和基于梯度的方法等技术。作者在图 5 中对 7 种连续微调场景进行了说明。
这张图展示了七种不同类型的任务如何通过连续学习在大语言模型中实现。以下是对每个部分的详细解释:
(a) 连续文本分类
- 示例:连续文本分类任务通过逐步引入新的分类类别(如 Intent: Transfer -> Intent: Credit Score -> Intent: Fun Fact)来训练模型,使其能够适应不断变化的分类需求。
(b) 连续命名实体识别
- 示例:连续命名实体识别任务展示了如何在识别特定实体的同时,逐步引入新的实体类型(如 Athlete -> Sports Team -> Politician),使模型能够在识别新的实体时仍保持对旧实体的识别能力。
(c) 连续关系抽取
- 示例:连续关系抽取任务通过不断引入新的关系类型(如 Relation: Founded By -> Relation: State or Province of Birth -> Relation: Country of Headquarters),展示了模型如何逐步扩展其关系抽取能力。
(d) 连续知识编辑
- 示例:连续知识编辑任务通过不断更新模型的知识库,确保其能够对最新的事实进行准确的回答(如 Who is the president of the US? -> Which club does Cristiano Ronaldo currently play for? -> Where was the last Winter Olympics held?)。
(e) 连续机器翻译
- 示例:连续机器翻译任务通过逐步扩展模型对不同语言的翻译能力(如 English -> Chinese, English -> Spanish, English -> French),展示了模型在多语言环境中的适应能力。
(f) 连续指令微调
- 示例:连续指令微调任务通过逐步引入新的指令类型(如 Summarization -> Style Transfer -> Mathematics),训练模型在多种任务类型下的表现能力。
(g) 连续对齐
- 示例:连续对齐任务通过引入新的对齐目标(如 Helpful and Harmless -> Concise and Organized -> Positive Sentiment),展示了模型在不同道德和行为标准下的连续学习能力。
五、外部知识
连续预训练和连续微调对 LLM 的终生学习至关重要,然而随着 LLM 越来越大、功能越来越强,有两个新兴方向越来越受欢迎,它们可以在不修改大语言模型参数的情况下,为大语言模型提供新的外部知识。作者考虑基于检索的终生学习和基于工具的终生学习,因为这两种方法都是实现 LLM 终生学习的有前途的方法。图 6 举例说明了这两种方法。
基于检索的终生学习(Retrieval-Based Lifelong Learning)
- 介绍:随着世界信息的不断扩大和快速发展,根据历史数据训练的静态模型很快就会过时,无法理解或生成有关新发展的内容。基于检索的终生学习解决了大型语言模型从外部来源获取和吸收最新知识的关键需求,在需要时,模型通过检索这些外部资源,来补充或更新其知识库。这些外部资源提供了一个巨大的当前知识库,为增强预训练 LLM 的静态特性提供了重要的补充资产。
- 示例:图中的这些外部资源是模型能够访问并检索的。通过访问外部信息源,如维基百科、书籍、数据库等,模型能够更新自身的知识,并在遇到新信息时作出适应。
基于工具的终生学习(Tool-Based Lifelong Learning)
- 介绍:基于工具的终生学习源于将其功能扩展到静态知识之外并使其能够与环境动态交互的必要性。在现实世界的应用中,模型往往需要执行一些任务,这些任务涉及直接文本生成或解释之外的操作。
- 示例:图中模型利用这些工具来扩展和更新自身的能力,通过与外部工具的交互来实现终生学习。例如,模型可以通过应用程序编程接口获取实时数据,或通过物理工具与外部环境互动,以此来完成特定任务或获取新知识。
六、讨论与结论
6.1 主要挑战
- 灾难性遗忘(Catastrophic Forgetting):这是终生学习的核心挑战之一,新信息的引入可能会覆盖模型之前学到的内容。
- 可塑性 - 稳定性困境(Plasticity-Stability Dilemma):在保持模型的学习能力和稳定性之间找到平衡非常关键,这直接影响模型获取新知识的能力,同时保留其广泛的通用能力。
- 昂贵的计算成本(Expensive Computation Cost):全量微调大语言模型的计算需求可能非常高。
- 模型权重或预训练数据的不可用性:由于隐私、专有限制或商业许可,原始训练数据或模型权重往往不可用于进一步的改进。
6.2 当前趋势
- 从特定任务到通用任务:研究逐渐从专注于特定任务(如文本分类、命名实体识别)转向更广泛的通用任务,如指令调优、知识编辑等。
- 从全量微调到部分微调:鉴于全量微调的高资源消耗,部分微调策略(如 Adapter 层、Prompt 调优、LoRA)变得越来越受欢迎。
- 从内部知识到外部知识:为了克服频繁的内部更新限制,越来越多的策略采用外部知识源,如检索增强生成(Retrieval-Augmented Generation)和工具学习,使模型能够动态访问和利用当前的外部数据。
6.3 未来方向
- 多模态终生学习:将文本以外的多种模态(如图像、视频、音频、时间序列数据、知识图谱)整合到终生学习中,以开发更全面、更具适应性的模型。
- 高效终生学习:研究人员正致力于开发更高效的策略来管理模型训练和更新的计算需求,如模型剪枝、模型合并、模型扩展等方法。
- 通用终生学习:最终目标是使大语言模型能够主动获取新知识,并通过与环境的动态交互进行学习,不再仅仅依赖于静态数据集。
6.4 结论
作者将现有研究分为 12 种终生学习场景,并提供了全面的方法归纳整理。此外还分析强调了在管理灾难性遗忘、确保计算效率和在知识获取中的特定性与通用性之间维持平衡的必要性。随着领域的不断发展,这些先进策略的集成将对塑造下一代人工智能系统起到关键作用,帮助它们更接近实现真正的人类般的学习和适应能力。
通过对这些技术方法及其各自类别的详细研究,本综述旨在强调将终生学习能力整合到终生学习工具中,从而提高它们在现实世界应用中的适应性、可靠性和整体性能。同时为研究人员和工程师提供一个全面的视角,帮助他们更好地理解和应用终生学习技术,推动大语言模型的进一步发展。如果对文章感兴趣,可以查阅原始论文以了解更多研究内容。
#HuggingFace
用Mac训练个机器人叠衣服,HuggingFace开源全套教程,开源AI机器人革命要来了?
这是机器人界的 Llama?
靠 100 条轨迹数据,在 Mac 上训练几个小时,就能拥有一个可以抓取乐高积木的机械臂,这是 HuggingFace 机器人科学家 Remi Cadene 晒出的一个实例。机器人的门槛可能并没有想象中那么高。
Remi Cadene 之前是特斯拉人形机器人 Optimus(擎天柱)项目的成员,3 月份被挖去 HuggingFace,领导一个新的开源机器人项目 ——LeRobot,当时引发了一些轰动。
LeRobot 基于有史以来最大规模的众包机器人数据集,它的代码库堪称机器人领域的「Transformers」。Cadene 在 X 上表示:「人工智能发展的下一步是将其应用于我们的物理世界。因此,我们正在推动社区共同努力构建 AI 机器人,这对所有人开放!」
如今,Cadene 和他的新同事正在兑现这一承诺。前段时间,他们发布了 DIY 机器人的深入教程,从硬件 DIY 手册到 Jupyter 笔记本应有尽有。
教程链接:https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md
youtube 上还有大量的视频教程可供参考:
视频链接:https://www.youtube.com/@HuggingFace/videos
可以说,只要按照教程操作,你在 Mac 或 PC 上训练几个小时,也能拥有一个可以抓取乐高积木的机械臂。
或者,让它给你叠衣服:
具体来说,这个教程主要解答了以下问题:
1、如何订购和组装你的机器人;
2、如何连接、配置和校准你的机器人;
3、如何记录和可视化你的数据集;
4、如何使用你的数据来训练策略并准备评估;
5、如何评估你的策略并可视化结果。
该教程主要基于一种开源、价格友好的机器人套件 Koch v1.1 编写,不过也可以通过改变配置轻松适应各种类型的机器人。
Koch v1.1 由一个主导臂和一个从动臂组成,每个臂有 6 个电机。它可以和一个或多个摄像头一起工作来记录场景,这些摄像头被用作机器人的视觉传感器。在数据采集阶段,你将通过移动主导臂来控制从动臂。这个过程被称为「遥操作」。这种技术用于收集机器人轨迹。之后,你将训练一个神经网络来模仿这些轨迹,并部署网络以使你的机器人能够自主操作。
订购、组装你的 Koch v1.1
第一步是采购零件和组装,这步有一个 Koch v1.1 Github 页面可以参考。
Github 链接:https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md
这个页面上包含一个详细的零件采购清单(作者表示,目前他们只找到了美国、欧盟和英国的购买链接,如果读者可以找到中国、印度等其他国家的购买链接,欢迎补充进去):
主导臂零件采购参考清单。
从动臂零件采购参考清单。
有些零件可能需要 3D 打印,该页面也给出了详细的操作指南。
在零件全部到位后,你可以按照以下视频的指引进行安装。
视频链接:https://www.youtube.com/watch?v=8nQIg9BwwTk
组装完的两个机械臂长这样:
接下来,使用 5V 电源为主导臂供电,使用 12V 电源为从动臂供电。此外,使用 USB-C 电缆将每个臂连入计算机。
配置电机,校准机械臂,远程操控你的 Koch v1.1
Koch v1.1 的配置可以参考以下视频:
视频链接:https://www.youtube.com/watch?v=U78QQ9wCdpY
校准则参考另一个视频:
视频链接:https://www.youtube.com/watch?v=8drnU9uRY24
Github 页面也介绍了 Koch v1.1 所需的附加依赖项的安装以及机械臂的连接、校准方法。控制部分则分为以下步骤介绍:
1、使用 DynamixelMotorsBus 控制电机;
2、使用 DynamixelMotorsBus 远程操作 Koch v1.1;
3、使用 OpenCVCamera 添加相机;
4、使用 koch. yaml 和 teleoperate 函数
每部分都有详细的说明和代码可以参考。
记录你的数据集并将其可视化
这一步旨在录制你的第一个训练集。作者建议从简单的动作开始来验证你的设置,例如在五个位置抓取物体,并为每个位置记录 10 条轨迹。
这一步同样有视频可以参考:
视频地址:https://www.youtube.com/watch?v=n_Ljp_xuFEM
作者表示,你还可以使用以下自定义工具在本地或在线可视化任何数据集:
工具地址:https://huggingface.co/spaces/cadene/visualize_dataset_train
Github 教程涉及以下内容:
1、使用 koch. yaml 和 record 函数;
2、对于记录数据集的建议;
3、可视化所有 episode;
4、使用 replay 函数在你的机器人上 replay episode。
用你的数据训练一个策略
这部分主要介绍了如何训练神经网络来控制机器人。主要步骤如下:
1、使用训练脚本;
2、将策略检查点上传到 hub。
值得庆幸的是,策略的训练不需要特别昂贵的设备,在 PC 或 Mac 上跑几个小时就能训练出来。而且无需模拟。
评估策略
在评估部分,作者也给出了完整的视频教程:
视频地址:https://www.youtube.com/watch?v=Il3Kt8vqpI8
作者表示,这个项目的精妙之处在于,如果每个人都能记录下自己的数据集并在 hub 上共享,那么大家都将能够训练出具有无与伦比的感知世界以及采取行动能力的 AI!这项新技术将推动下一次社会和工业革命。
目前,LeRobt 的开源已经产生了一定的影响力。
Cadene 透露,他们正在开发一款更实惠的机器人。这款机器人不需要 3D 打印,总共花费 150 美元(2 个机械臂)名叫 Moss V1。
此外,他们还将开源一款更强大的机器人,这个机器人可以使用 5 个手指的手作为末端执行器。
他相信开源社区的力量可以推动机器人领域快速发展。
参考链接:https://x.com/RemiCadene/status/1825470242815901828
#图解大模型训练之:数据并行(DP、DDP、ZeRO、零冗余优化)
大模型场景里巨大的存储和GPU间通讯量是系统设计时需要考虑的重点,本文递进介绍了三种主流数据并行的实现方法:DP、DD皮、ZeRo。
当模型太大,一块GPU放不下时,流水线并行将模型的不同层放到不同的GPU上,通过切割mini-batch实现对训练数据的流水线处理,提升GPU计算通讯比。同时通过re-materialization机制降低显存消耗。
但在实际应用中,流水线并行并不特别流行,主要原因是模型能否均匀切割,影响了整体计算效率,这就需要算法工程师做手调。因此,今天我们来介绍一种应用最广泛,最易于理解的并行范式:数据并行。
数据并行的核心思想是:在各个GPU上都拷贝一份完整模型,各自吃一份数据,算一份梯度,最后对梯度进行累加来更新整体模型。理念不复杂,但到了大模型场景,巨大的存储和GPU间的通讯量,就是系统设计要考虑的重点了。在本文中,我们将递进介绍三种主流数据并行的实现方式:
DP(Data Parallelism):最早的数据并行模式,一般采用参数服务器(Parameters Server)这一编程框架。实际中多用于单机多卡
DDP(Distributed Data Parallelism):分布式数据并行,采用Ring AllReduce的通讯方式,实际中多用于多机场景
ZeRO:零冗余优化器。由微软推出并应用于其DeepSpeed框架中。严格来讲ZeRO采用数据并行+张量并行的方式,旨在降低存储。
一、数据并行(DP)1.1 整体架构
一个经典数据并行的过程如下:
若干块计算GPU,如图中GPU0~GPU2;1块梯度收集GPU,如图中AllReduce操作所在GPU。
在每块计算GPU上都拷贝一份完整的模型参数。
把一份数据X(例如一个batch)均匀分给不同的计算GPU。
每块计算GPU做一轮FWD和BWD后,算得一份梯度G。
每块计算GPU将自己的梯度push给梯度收集GPU,做聚合操作。这里的聚合操作一般指梯度累加。当然也支持用户自定义。
梯度收集GPU聚合完毕后,计算GPU从它那pull下完整的梯度结果,用于更新模型参数W。更新完毕后,计算GPU上的模型参数依然保持一致。
聚合再下发梯度的操作,称为AllReduce。
前文说过,实现DP的一种经典编程框架叫“参数服务器”,在这个框架里,计算GPU称为Worker,梯度聚合GPU称为Server。在实际应用中,为了尽量减少通讯量,一般可选择一个Worker同时作为Server。比如可把梯度全发到GPU0上做聚合。需要再额外说明几点:
- 1个Worker或者Server下可以不止1块GPU。
- Server可以只做梯度聚合,也可以梯度聚合+全量参数更新一起做
在参数服务器的语言体系下,DP的过程又可以被描述下图:
1.2 通讯瓶颈与梯度异步更新
DP的框架理解起来不难,但实战中确有两个主要问题:
- 存储开销大。每块GPU上都存了一份完整的模型,造成冗余。关于这一点的优化,我们将在后文ZeRO部分做讲解。
- 通讯开销大。Server需要和每一个Worker进行梯度传输。当Server和Worker不在一台机器上时,Server的带宽将会成为整个系统的计算效率瓶颈。
我们对通讯开销再做详细说明。如果将传输比作一条马路,带宽就是马路的宽度,它决定每次并排行驶的数据量。例如带宽是100G/s,但每秒却推给Server 1000G的数据,消化肯定需要时间。那么当Server在搬运数据,计算梯度的时候,Worker们在干嘛呢?当然是在:
人类老板不愿意了:“打工系统里不允许有串行存在的任务!”,于是梯度异步更新这一管理层略诞生了。
上图刻画了在梯度异步更新的场景下,某个Worker的计算顺序为:
- 在第10轮计算中,该Worker正常计算梯度,并向Server发送push&pull梯度请求。
- 但是,该Worker并不会实际等到把聚合梯度拿回来,更新完参数W后再做计算。而是直接拿旧的W,吃新的数据,继续第11轮的计算。这样就保证在通讯的时间里,Worker也在马不停蹄做计算,提升计算通讯比。
- 当然,异步也不能太过份。只计算梯度,不更新权重,那模型就无法收敛。图中刻画的是延迟为1的异步更新,也就是在开始第12轮对的计算时,必须保证W已经用第10、11轮的梯度做完2次更新了。
参数服务器的框架下,延迟的步数也可以由用户自己决定,下图分别刻划了几种延迟情况:
- (a) 无延迟
- (b) 延迟但不指定延迟步数。也即在迭代2时,用的可能是老权重,也可能是新权重,听天由命。
- (c) 延迟且指定延迟步数为1。例如做迭代3时,可以不拿回迭代2的梯度,但必须保证迭代0、1的梯度都已拿回且用于参数更新。
总结一下,异步很香,但对一个Worker来说,只是等于W不变,batch的数量增加了而已,在SGD下,会减慢模型的整体收敛速度。异步的整体思想是,比起让Worker闲着,倒不如让它多吃点数据,虽然反馈延迟了,但只要它在干活在学习就行。
batch就像活,异步就像画出去的饼,且往往不指定延迟步数,每个Worker干越来越多的活,但模型却没收敛取效,这又是刺伤了哪些打工仔们的心(狗头
二、分布式数据并行(DDP)
受通讯负载不均的影响,DP一般用于单机多卡场景。因此,DDP作为一种更通用的解决方案出现了,既能多机,也能单机。DDP首先要解决的就是通讯问题:将Server上的通讯压力均衡转到各个Worker上。实现这一点后,可以进一步去Server,留Worker。
前文我们说过,聚合梯度 + 下发梯度这一轮操作,称为AllReduce。接下来我们介绍目前最通用的AllReduce方法:Ring-AllReduce。它由百度最先提出,非常有效地解决了数据并行中通讯负载不均的问题,使得DDP得以实现。
2.1 Ring-AllReduce
如下图,假设有4块GPU,每块GPU上的数据也对应被切成4份。AllReduce的最终目标,就是让每块GPU上的数据都变成箭头右边汇总的样子。
Ring-ALLReduce则分两大步骤实现该目标:Reduce-Scatter和All-Gather。
- Reduce-Scatter
定义网络拓扑关系,使得每个GPU只和其相邻的两块GPU通讯。每次发送对应位置的数据进行累加。每一次累加更新都形成一个拓扑环,因此被称为Ring。看到这觉得困惑不要紧,我们用图例把详细步骤画出来。
一次累加完毕后,蓝色位置的数据块被更新,被更新的数据块将成为下一次更新的起点,继续做累加操作。
3次更新之后,每块GPU上都有一块数据拥有了对应位置完整的聚合(图中红色)。此时,Reduce-Scatter阶段结束。进入All-Gather阶段。目标是把红色块的数据广播到其余GPU对应的位置上。
- All-Gather
如名字里Gather所述的一样,这操作里依然按照“相邻GPU对应位置进行通讯”的原则,但对应位置数据不再做相加,而是直接替换。All-Gather以红色块作为起点。
以此类推,同样经过3轮迭代后,使得每块GPU上都汇总到了完整的数据,变成如下形式:
建议读者们手动推一次,加深理解。
2.2 Ring-AllReduce通讯量分析
假设模型参数W的大小为 , GPU个数为 。则梯度大小也为 , 每个梯度块的大小为 对单卡GPU来说(只算其send通讯量):
- Reduce-Scatter阶段, 通讯量为
- All-Gather阶段, 通讯量为
单卡总通讯量为 , 随着 的增大, 可以近似为 。全卡总通讯量为
而对前文的DP来说, 它的Server承载的通讯量是 , Workers为 , 全卡总通讯量依然为 。虽然通讯量相同, 但搬运相同数据量的时间却不一定相同。DDP把通讯量均衡负载到了每一时刻的每个Worker上, 而DP仅让Server做勤劳的搬运工。当越来越多的GPU分布在距离较远的机器上时, DP的通讯时间是会增加的。
但这并不说明参数服务器不能打(有很多文章将参数服务器当作old dinosaur来看)。事实上,参数服务器也提供了多Server方法,如下图:
在多Server的模式下,进一步,每个Server可以只负责维护和更新某一块梯度(也可以某块梯度+参数一起维护),此时虽然每个Server仍然需要和所有Worker通讯,但它的带宽压力会小非常多。经过调整设计后,依然可以用来做DDP。虽然这篇文章是用递进式的方式来介绍两者,但不代表两者间一定要决出优劣。我想表达的观点是,方法是多样性的。 对参数服务器有兴趣的朋友,可以阅读参考的第1个链接。
最后,请大家记住Ring-AllReduce的方法,因为在之后的ZeRO,Megatron-LM中,它将频繁地出现,是分布式训练系统中重要的算子。
三、总结
1、在DP中,每个GPU上都拷贝一份完整的模型,每个GPU上处理batch的一部分数据,所有GPU算出来的梯度进行累加后,再传回各GPU用于更新参数
2、DP多采用参数服务器这一编程框架,一般由若个计算Worker和1个梯度聚合Server组成。Server与每个Worker通讯,Worker间并不通讯。因此Server承担了系统所有的通讯压力。基于此DP常用于单机多卡场景。
3、异步梯度更新是提升计算通讯比的一种方法,延迟更新的步数大小决定了模型的收敛速度。
4、Ring-AllReduce通过定义网络环拓扑的方式,将通讯压力均衡地分到每个GPU上,使得跨机器的数据并行(DDP)得以高效实现。
5、DP和DDP的总通讯量相同,但因负载不均的原因,DP需要耗费更多的时间搬运数据
由微软开发的ZeRO(零冗余优化),它是DeepSpeed这一分布式训练框架的核心,被用来解决大模型训练中的显存开销问题。ZeRO的思想就是用通讯换显存。 如果初读ZeRO,觉得它逻辑跳跃,晦涩难懂,那么下文或许可以帮到你~
四、存储消耗4.1 存储分类
首先,我们来看在大模型训练的过程中,GPU都需要存什么内容。
存储主要分为两大块:Model States和Residual StatesModel States指和模型本身息息相关的,必须存储的内容,具体包括:
optimizer states:Adam优化算法中的momentum和variance
gradients:模型梯度
parameters:模型参数W
Residual States指并非模型必须的,但在训练过程中会额外产生的内容,具体包括:
- activation:激活值。在流水线并行中我们曾详细介绍过。在backward过程中使用链式法则计算梯度时会用到。有了它算梯度会更快,但它不是必须存储的,因为可以通过重新做Forward来算它。
- temporary buffers: 临时存储。例如把梯度发送到某块GPU上做加总聚合时产生的存储。
- unusable fragment memory:碎片化的存储空间。虽然总存储空间是够的,但是如果取不到连续的存储空间,相关的请求也会被fail掉。对这类空间浪费可以通过内存整理来解决。
4.2 精度混合训练
知道了存储分类,进一步,我们想知道,假设模型的参数W大小是,那么每一类存储具体占了多大的空间呢?
在分析这个问题前,我们需要来了解精度混合训练。
对于模型,我们肯定希望其参数越精准越好,也即我们用fp32(单精度浮点数,存储占4byte)来表示参数W。但是在forward和backward的过程中,fp32的计算开销也是庞大的。那么能否在计算的过程中,引入fp16或bf16(半精度浮点数,存储占2byte),来减轻计算压力呢?于是,混合精度训练就产生了,它的步骤如下图:
- 存储一份fp32的parameter,momentum和variance(统称model states)
- 在forward开始之前,额外开辟一块存储空间,将fp32 parameter减半到fp16 parameter。
- 正常做forward和backward,在此之间产生的activation和gradients,都用fp16进行存储。
- 用fp16 gradients去更新fp32下的model states。
- 当模型收敛后,fp32的parameter就是最终的参数输出。
通过这种方式,混合精度训练在计算开销和模型精度上做了权衡。如果不了解fp32,fp16和bf16的细节也没关系,不影响下文的阅读。只要记住它们所占的存储空间和精度表达上的差异即可。
4.3 存储大小
现在,我们可以来计算模型在训练时需要的存储大小了,假设模型的参数W大小是 ,以byte为单位,存储如下:
因为采用了Adam优化, 所以才会出现momentum和variance, 当然你也可以选择别的优化办法。因此这里为了更通用些, 记模型必存的数据大小为 。因此最终内存开销为:
另外,这里暂不将activation纳入统计范围,原因是:
- activation不仅与模型参数相关,还与batch size相关
- activation的存储不是必须的。存储activation只是为了在用链式法则做backward的过程中,计算梯度更快一些。但你永远可以通过只保留最初的输入X,重新做forward来得到每一层的activation(虽然实际中并不会这么极端)。
- 因为activation的这种灵活性,纳入它后不方便衡量系统性能随模型增大的真实变动情况。因此在这里不考虑它,在后面会单开一块说明对activation的优化。
五、ZeRO-DP
知道了什么东西会占存储,以及它们占了多大的存储之后,我们就可以来谈如何优化存储了。
注意到,在整个训练中,有很多states并不会每时每刻都用到,举例来说;
- Adam优化下的optimizer states只在最终做update时才用到
- 数据并行中,gradients只在最后做AllReduce和updates时才用到
- 参数W只在做forward和backward的那一刻才用到
- 诸如此类
所以,ZeRO想了一个简单粗暴的办法:如果数据算完即废,等需要的时候,我再想办法从个什么地方拿回来,那不就省了一笔存储空间吗?
沿着这个思路,我们逐一来看ZeRO是如何递进做存储优化的。
5.1 :优化状态分割
首先,从 optimizer state开始优化。将optimizer state分成若干份,每块GPU上各自维护一份。这样就减少了相当一部分的显存开销。如下图:
复习一下,此时W=fp16,G=fp16,O=fp32。此时,整体数据并行的流程如下:
(1)每块GPU上存一份完整的参数W。将一个batch的数据分成3份,每块GPU各吃一份,做完一轮foward和backward后,各得一份梯度。
(2)对梯度做一次AllReduce,得到完整的梯度G,产生单卡通讯量 。为了表达简明,这里通讯量我们就不再换算成byte了,而直接根据参数量来计算。对AllReduce(reduce-scatter + all-gather)。
(3)得到完整梯度G,就可以对W做更新。我们知道W的更新由optimizer states和梯度共同决定。由于每块GPU上只保管部分optimizer states,因此只能将相应的W(蓝色部分)进行更新。(2)和(3)可以用下图表示:
(4)此时,每块GPU上都有部分W没有完成更新(图中白色部分)。所以我们需要对W做一次All-Gather,从别的GPU上把更新好的部分W取回来。产生单卡通讯量 。
做完 后, 设GPU个数为 ,显存和通讯量的情况如下:
假设各变量大小如表格第二列所示,那么 在增加1.5倍单卡通讯开销的基础上,将单卡存储降低了4倍。看起来是个还不错的trade-off,那么还能做得更好吗
5.2 :优化状态与梯度分割
现在,更近一步,我们把梯度也拆开,每个GPU格子维护一块梯度。
此时,数据并行的整体流程如下:
(1)每块GPU上存一份完整的参数W。将一个batch的数据分成3份,每块GPU各吃一份,做完一轮foward和backward后,算得一份完整的梯度(下图中绿色+白色)。
(2)对梯度做一次Reduce-Scatter,保证每个GPU上所维持的那块梯度是聚合梯度。例如对GPU1,它负责维护G1,因此其他的GPU只需要把G1对应位置的梯度发给GPU1做加总就可。汇总完毕后,白色块对GPU无用,可以从显存中移除。单卡通讯量 。(1)和(2)见下图:
(3)每块GPU用自己对应的O和G去更新相应的W。更新完毕后,每块GPU维持了一块更新完毕的W。同理,对W做一次All-Gather,将别的GPU算好的W同步到自己这来。单卡通讯量 。
再次比对下显存和通讯量:
和朴素DP相比,存储降了8倍,单卡通讯量持平,好像更牛皮了呢!那么,还可以优化吗?
5.3 :优化状态、梯度与参数分割
看到这里,也许你有点感觉了,ZeRO的思想就是:万物皆可切,万物皆可抛。所以现在,我们把参数也切开。每块GPU置维持对应的optimizer states,gradients和parameters(即W)。
数据并行的流程如下:
(1)每块GPU上只保存部分参数W。将一个batch的数据分成3份,每块GPU各吃一份。
(2)做forward时,对W做一次All-Gather,取回分布在别的GPU上的W,得到一份完整的W,单卡通讯量 Φ 。forward做完,立刻把不是自己维护的W抛弃。
(3)做backward时,对W做一次All-Gather,取回完整的W,单卡通讯量 。backward做完,立刻把不是自己维护的W抛弃。
(4)做完backward,算得一份完整的梯度G,对G做一次Reduce-Scatter,从别的GPU上聚合自己维护的那部分梯度,单卡通讯量 。聚合操作结束后,立刻把不是自己维护的G抛弃。
(5)用自己维护的O和G,更新W。由于只维护部分W,因此无需再对W做任何AllReduce操作。
显存和通讯量如下:
到这一步,我们用1.5倍的通讯开销,换回近120倍的显存。只要梯度计算和异步更新做的好,通讯时间大部分可以被计算时间隐藏,因此这样的额外通讯开销,也是划算的。
到这里,我们可以放出原始论文中的说明图了,经过以上分析,这张说明图是不是瞬间就能看懂了。不得不吐槽下,虽然ZeRO的设计不复杂,但对应论文写得真是逻辑跳跃,晦涩难懂....
仔细一想,ZeRO其实掌握了降本增效的精髓:用完即弃,需要再补。反正我补一个和你差不多的,也不会花费很多通(找)讯(人)时间,还大大降低了我的成本。模型的每一层多算(造)几(轮)遍(子)有啥关系呢,反正在我的预算里每个人都一刻不停地干活,就行啦!
5.4 ZeRO VS 模型并行
知道模型并行的朋友,可能会想,既然ZeRO都把参数W给切了,那它应该是个模型并行呀?为什么要归到数据并行里呢?
其实ZeRO是模型并行的形式,数据并行的实质。
模型并行,是指在forward和backward的过程中,我只需要用自己维护的那块W来计算就行。即同样的输入X,每块GPU上各算模型的一部分,最后通过某些方式聚合结果。
但对ZeRO来说,它做forward和backward的时候,是需要把各GPU上维护的W聚合起来的,即本质上还是用完整的W进行计算。它是不同的输入X,完整的参数W,最终再做聚合。
因为下一篇要写模型并行Megatron-LM,因此现在这里罗列一下两者的对比。
六、ZeRO-R
说完了以上对model states的显存优化,现在来看对residual states的优化。
6.1 : Partitioned Activation Checkpointing
前面说过,对activation的存储是灵活的。不像optimizer states,gradients和parameters对模型更新是必须的,activation只是起到加速梯度计算的作用。因此,在哪几层保存activation,保存哪些activation都是可以灵活设置的。同样,我们也可以仿照以上切割方式,每块GPU上只维护部分的activation,需要时再从别的地方聚合过来就行。需要注意的是,activation对显存的占用一般会远高于模型本身,通讯量也是巨大的,所以这块要灵活、有效地实验设计。
6.2 : Constant Size Buffer
固定大小的内存buffer,它的目的在于:
提升带宽利用率。当GPU数量上升,GPU间的通讯次数也上升,每次的通讯量可能下降(但总通讯量不会变)。数据切片小了,就不能很好利用带宽了。所以这个buffer起到了积攒数据的作用:等数据积攒到一定大小,再进行通讯。
使得存储大小可控。在每次通讯前,积攒的存储大小是常量,是已知可控的。更方便使用者对训练中的存储消耗和通讯时间进行预估。
6.3 : Memory Defragmentation
在前文提过,设置机制,对碎片化的存储空间进行重新整合,整出连续的存储空间。防止出现总存储足够,但连续存储不够而引起的存储请求fail
七、ZeRO-Offload与ZeRO-Infinity
最后,简单介绍一下ZeRO-Offload。它的核心思想是:显存不够,内存来凑。如果我把要存储的大头卸载(offload)到CPU上,而把计算部分放到GPU上,这样比起跨机,是不是能既降显存,也能减少一些通讯压力呢?ZeRO-Offload的做法是:
- forward和backward计算量高,因此和它们相关的部分,例如参数W(fp16),activation,就全放入GPU。
- update的部分计算量低,因此和它相关的部分,全部放入CPU中。例如W(fp32),optimizer states(fp32)和gradients(fp16)等。
具体切分如下图:
ZeRO-infinity也是同理,它们在解决的事情都是:找个除GPU之外的地方,存数据。感兴趣的朋友可以深入研究,这里就不展开了。
参考
1、https://web.eecs.umich.edu/~mosharaf/Readings/Parameter-Server.pdf
2、https://zh.d2l.ai/chapter_computational-performance/parameterserver.html
3、https://blog.csdn.net/dpppBR/article/details/80445569
4、https://arxiv.org/abs/1910.02054
5、https://blog.51cto.com/u_14691718/5631471
6、https://arxiv.org/pdf/1910.02054.pdf
7、https://blog.51cto.com/u_14691718/5631471
8、https://blog.csdn.net/qq_43307074/article/details/127688761
#MATRIX-Gen
1000多个智能体组成,AI社会模拟器MATRIX-Gen助力大模型自我进化
本文作者来自于上海交通大学人工智能学院的Multi-Agent Governance & Intelligence Crew (MAGIC团队)和牛津大学。共同第一作者有唐铄、庞祥鹤、刘泽希和唐博瀚。指导老师包括上海交大的王延峰教授、陈思衡副教授,以及牛津大学的董晓文副教授。
随着大语言模型(LLMs)在处理复杂任务中的广泛应用,高质量数据的获取变得尤为关键。为了确保模型能够准确理解并执行用户指令,模型必须依赖大量真实且多样化的数据进行后训练。然而,获取此类数据往往伴随着高昂的成本和数据稀缺性。因此,如何有效生成能够反映现实需求的高质量合成数据,成为了当前亟需解决的核心挑战。
那么,真实数据的需求是如何产生的?设想一位程序员在进行机器学习模型的开发与调优时,他会提出问题:「如何调整超参数以提高模型预测准确率?」 这种指令并非凭空而来,而是源于他所处的具体工作情境 —— 数据分析和模型优化。同样,用户在日常生活中的指令无论是编程任务、医疗诊断还是商业决策,往往与他们所面临的具体场景密切相关。要生成能够真实反映用户需求的合成数据,必须从这些实际情境中出发,模拟出与用户需求相匹配的场景。
基于这一理念,上海交通大学与牛津大学的研究团队提出了一项创新方案 —— 基于多智能体模拟的数据合成。团队提出了 MATRIX——AI 社会模拟器,构建了一个由 1000 多个 AI 智能体组成的模拟社会。在这个模拟社会中,每一个 AI 智能体代表了一个拥有独立身份和人格的数字人,这些 AI 智能体可以模拟出复杂的交流和互动模式,涵盖了从软件开发到商业活动的广泛场景。基于这些场景,团队进一步开发了 MATRIX-Gen 数据合成器,能够根据不同需求合成高度多样化且高质量的训练指令数据。
- 论文链接:https://arxiv.org/pdf/2410.14251
- 代码主页:https://github.com/ShuoTang123/MATRIX-Gen
为验证 MATRIX-Gen 合成数据的高质量,研究团队使用 Llama-3-8B-Instruct 驱动社会模拟,仅合成了 2 万条数据用于训练 Llama-3-8B-Base 模型。尽管数据量极少,训练后的模型在 AlpacaEval 2 和 Arena-Hard 基准测试中竟然大幅超越了 Llama-3-8B-Instruct 自身。这一结果不仅证明了 MATRIX-Gen 合成数据的高效性,也标志着模型在合成数据驱动下实现了自我进化。此外,在代码生成、多轮对话和安全性任务上,MATRIX-Gen 生成的专用数据同样表现优异,甚至超越了为这些特定任务设计的专用数据集。这项研究为通过合成数据提升大语言模型性能提供了全新的解决方案,展示了 AI 模拟社会在数据合成中的巨大潜力,为未来大语言模型的后训练数据合成开辟了创新的路径。
基于合成数据的后训练系统
本研究提出的后训练系统旨在利用基于多智能体模拟技术构建的 AI 模拟社会,合成高质量的训练数据,以提升预训练大语言模型的指令跟随能力。该系统的核心理念源于人类在现实场景中提问的方式 —— 人们基于自身需求提出多样且深入的问题。因此,本研究通过 AI 模拟社会合成人类社会中的场景,并利用这些场景引导 LLM 提出信息丰富、贴近现实的问题,从而产生高质量的训练数据。
如下图所示,该系统包含三个步骤:
1. 合成社会场景:利用多智能体模拟技术构建 AI 模拟社会,该社会中的每个场景由一组 AI 智能体及其对应的文本行动构成。为了确保社会场景的真实性和多样性,本研究设计了大规模人类社会模拟器 MATRIX,创建了一个包含各种 AI 智能体的互动环境。此模拟器充分发挥了 LLM 的角色扮演能力,使得 AI 智能体能够逼真地模拟人类行为,进行规划、观察和行动,进而生成丰富且高度真实的社会场景。
2. 合成训练数据:根据合成的社会场景,生成符合任务需求的后训练数据。本研究设计了场景驱动的指令生成器 MATRIX-Gen,模拟人类在日常生活中提出问题的过程,结合场景生成指令,确保更高的真实性;通过选择特定场景,能够合成符合任务需求的数据,具备可控性。这一步骤合成包括 SFT、DPO 以及各种专用数据集。
3. 模型微调:利用合成的 SFT 数据集,对预训练模型进行监督微调,以获得具备指令跟随能力的模型。随后,基于合成的偏好数据集,采用 DPO 进一步训练模型。
AI 社会模拟器 MATRIX
为了合成多样且丰富的场景,以助力数据的合成,本研究提出了人类社会模拟器 MATRIX。该模拟器的输入为若干 AI 智能体档案,输出为文本形式的场景。通过模拟人类的 AI 智能体和结构化的通信机制,MATRIX 实现了大规模的人类社会模拟,从而生成多样且真实的场景。
- 模拟人类的智能体:每个 AI 智能体根据匿名化的真实人类档案进行初始化,并由 LLM 生成其个性和人生目标。这些目标进一步分解为可执行的步骤,形成 AI 智能体的行动计划。例如,一个医学教授的生活目标可能包括传播科学知识,而其计划则包括进行研究、发表论文、进行讲座和组织教育项目。这些步骤指导 AI 智能体未来的行动,确保它们朝着目标努力并展现出有目的的行为。当出现新观察时,AI 智能体会根据其记忆和个性做出反应;在没有新观察的情况下,它们则遵循既定计划追求目标。
- 结构化的通信机制:受人类社会中同质性现象的启发,我们根据相似特征对 AI 智能体进行分组,以减少不必要的连接,从而提高模拟的可扩展性。在每组中,本研究引入一个集中调节器来管理组内和组间的沟通。这一设计促进了相似 AI 智能体之间的更多互动,同时仍允许长距离交流,丰富信息流并增强真实性。此外,这种结构化通信机制能够防止 AI 智能体接收到过多无关信息,确保模拟的有效性。
数据合成器 MATRIX-Gen
在合成了真实多样化的社会场景后,本研究设计了场景驱动的指令生成器 MATRIX-Gen,以满足特定任务需求并合成后训练数据。通过选择与用户需求相关的场景,MATRIX-Gen 能够生成符合人类意图的指令,从而确保合成指令的真实性和可控性。
如下图所示,在合成后训练数据的过程中,MATRIX-Gen 模拟了人类提问的过程。针对不同数据场景的需求(如通用任务或代码任务),MATRIX-Gen 结合每个 AI 智能体的个性和行动,将这些信息整合到指令生成提示中,模拟人类在日常生活中提出问题的方式。随后,基于上述指令生成提示,MATRIX-Gen 直接调用对齐的 LLM 生成合成指令及其对应的回答。
下图展示了一位 IT 经理在汽车数据分析场景下,提出「如何调整超参数以提高模型预测准确率」的例子:
通过这一方法,本研究能够合成三种类型的数据集,包括监督微调数据集 MATRIX-Gen-SFT、偏好调优数据集 MATRIX-Gen-DPO,以及特定领域的 SFT 数据。每种数据集的指令生成在复杂性和专业性上各具特点,确保满足不同场景下的需求。
性能表现
在实验中,本研究选择 Llama-3-8B-Instruct 作为数据合成模型,选择 Llama-3-8B 作为训练的模型,通过模型的训练效果评估 MATRIX-Gen 在通用任务、多轮对话、代码生成上的数据合成能力。
AlpacaEval 2 和 Arena-Hard 上的评估结果表明,通过多智能体模拟合成的 MATRIX-Gen-SFT 数据优于多个真实数据集以及合成数据集。
在 MATRIX-SFT 模型上 DPO 的训练结果表明,通过 MATRIX-Gen-DPO 训练的模型超越多种合成偏好数据训练的模型,以及 Llama-3-8B-Instruct。值得注意的是,MATRIX-Gen-DPO 训练后的模型总共仅使用了 2 万条合成数据,便实现了对 Llama-3-8B-Instruct 自身的超越,充分展示了其高质量和自我进化的能力。
在代码生成与安全输出的任务中,MATRIX-Gen 合成的数据集均超越了对应领域的专用数据集,显示出 MATRIX-Gen 在合成数据上的高可控性。
上图展示了 MATRIX-Gen-SFT 合成指令的可视化,显示出合成数据的多样性。
总结与展望
本研究提出了一种基于 AI 智能体社会模拟的后训练数据合成框架。依托 MATRIX 合成的 AI 模拟社会,MATRIX-Gen 能够可控地合成高质量的多样数据。在通用和专用任务中,仅使用 0.2% 的数据,即可获得优于大模型研发领军团队 Meta AI 所用数据集的模型训练效果,突显了 MATRIX-Gen 在数据合成中的优势。
本研究希望该数据合成框架能够帮助定量研究何种类型的数据更适合用于监督微调和偏好优化,深入探讨不同数据特性对模型性能的影响。此外,我们展望通过引入更强大的 AI 智能体,如具备工具调用能力的 AI 智能体,以及接入更丰富的环境,进一步合成更复杂的数据,从而提升大语言模型在复杂任务中的表现。
#火山方舟首度公开「会话无痕」技术细节
2024 年,AI 大模型从「以分计价」跨入「以厘计价」的时代。
信号指向很清晰:把基础设施成本打下来,就是为了应用的爆发,但「算力价格」这把尺子还不够用。
在众多大模型中货比三家,需要投入大量信息成本。相信供应商、中间商「守规矩」、「靠谱」,更不易,信任成本过高,陷入囚徒困境,用户就会趋于保守,放弃潜在交易。
回首 2024,尽管大模型展现出非凡能力,破坏信任的糟心事儿也一直没断过。
4 月,海外某头部大模型商的 AI 语言模型因开源库漏洞导致用户对话泄露,致使意大利政府史无前例地叫停服务。此波未平,该产品长期记忆功能又出现严重漏洞,黑客可以随便访问用户聊天记录。
年初,荷兰一家数据公司的配置失误,导致多家企业(包括头部车企)的用户隐私数据遭泄露。
能力超凡、使用简单但又风险丛生,这样的混乱组合让企业老板难以驾驭。在采访全球多家企业、8000 多名 IT 专业人员后,IBM《 2023 年全球 AI 采用指数》发现:
和传统 AI 的采纳门槛不同,企业采纳生成式 AI 的最大障碍是数据隐私( 57% )以及信任和透明度( 43% )
当 AI 大模型的技术迭代周期几乎以月(甚至周)计时,数据技术仍在「蜗牛爬行」,这种失衡正在成为大模型发展的主要隐忧之一。
「生成式 AI 带来的安全挑战,已经超出了传统安全技术的应对范围。」火山引擎智能算法负责人、火山方舟负责人吴迪告诉他。作为火山引擎旗下的「一站式大模型服务平台」,火山方舟为企业提供模型精调、推理、评测等全方位功能与服务。
在模型精调环节,企业的核心知识都浓缩在训练数据中,如何确保这些数据、提示词以及模型响应的专属性?如何保证精调后的模型不被他人窃取使用?
推理环节更受关注,因为用户在使用过程中会输入大量真实、敏感的数据来获取模型建议。平台如何保证不会滥用用户数据?数据传输、计算和存储的全流程中,如何不被黑客窃取?平台又如何向用户证明其确实履行了承诺的安全措施?
企业在探索大模型应用场景时,这些安全痛点已经成为首要考虑因素,而传统的安全技术方案早已对此捉襟见肘。
私有部署之困在于,过去「数据不动,模型动」——企业把数据留在私域、将 AI 模型部署到企业私有空间——的策略在大模型时代会碰壁。
首先是技术代差问题,私有部署难以跟上公有云模型的快速迭代节奏;其次是算力成本,规模化运营的公有云服务能提供更高的性价比。此外,模型生产商也会担心核心技术外泄。
现有的隐私计算技术比明文计算慢了上百倍,就像给巨人穿上盔甲,只适合特定场景,但不适用于大模型服务的场景。
以 MPC 为例,将浮点数转为整数计算会损失精度,且单次计算需要 100-200 毫秒,应用场景极其有限。同态加密技术虽可在加密状态下计算,但性能开销会增加百倍甚至更多,一个原本需要 3 秒的处理任务,使用同态加密后可能延长至 5 分钟,难以满足生产需求。
目前,AI 模型推理比较好的选择仍是在明文状态下进行,吴迪表示。虽然理论上存在完全密态计算,让模型直接处理加密数据,但在大模型场景下,这种方案的计算开销过大,实用性较低。
现在的大模型计算主要依赖 GPU 等加速设备,但 GPU 相关的可信执行环境( TEE )技术还不成熟。TEE 类技术主要用于加强环境隔离,要真正满足现实安全需求,还需要配合代码审计、网络隔离等关键安全技术,多管齐下。
至于传统云安全更像是「大楼的物业保安」,而大模型需要的是「保险箱级别」的数据安全。
经过两年潜心打磨,火山方舟推出了一套「会话无痕」方案,保证你的数据,唯你可见、唯你所用、唯你所有。
四重核心功能筑起了数据全生命周期的铜墙铁壁——从传输、使用到存储,没有一个环节被遗漏;推理、模型精调和评估以及数据预处理,关键业务场景均有覆盖。
第一重:链路全加密。在用户与平台之间修筑了一条加密通道,确保用户数据离开企业后,能够安全抵达安全沙箱。
「双层加密」设计,打造了一个高可靠的安全环境。其中,网络层的传输加密, 通过 HTTPS 确保基础安全,mTLS 提供双向认证,PrivateLink 则在流量转发层与 GPU 推理实例之间建立专属隧道。
应用层的会话加密犹如叠加一层保险,即使通道被攻破,你的数据本身仍然安全。
详言之,每个部署在安全沙箱中的推理实例,都会被分配唯一的身份证书(就像「锁」)。当用户发送用户数据时,可用手中公钥将它们加密,只有到达正确的安全沙箱环境(钥匙和锁「匹配上」),才能被解密使用。否则,就算攻击者中途截获数据,也是无用之功。
第二重:数据高保密。除了只在必要的时刻、必要的地点短暂解密,火山方舟用户的数据其余时间都处于密文状态。
所有训练数据在进入安全沙箱前都是加密存储的,密钥由用户独自掌控。
一旦进入沙箱,推理等服务进程就能像往常一样使用这些数据,基于 FUSE 的透明加密文件系统会无缝、自动完成数据的加解密。
训练完的模型,会被立刻加密保存到分布式存储系统,等待再次调用。
字节自主研发的技术可支持 GPU 加解密,保证推理等场景精调模型的高效动态调度,满足生产环境的性能需求。
第三重:环境强隔离。它就像一个四层嵌套的「俄罗斯套娃」防护系统,从内到外依次是容器沙箱、网络隔离、可信代理和白屏化运维。
其中,容器沙箱是一种安全增强,弥补容器隔离性不足。在网络层面,平台创新地实现了任务级别的动态网络隔离,即使在同一 VPC 环境下的不同任务也无法直接通信,有效防止攻击者的横向渗透。
外层的可信代理和白屏化运维则进一步确保了系统运行的安全性,严格管控数据流动和运维操作。
第四重:操作可审计。火山方舟提供三大类日志。
首先是云基础安全日志,负责主机层面的安全日志采集。
第二个就是安全业务日志,包括沙箱连接、沙箱登录、容器逃逸和 KMS 访问等关键日志,帮助用户快速定位可疑行为,预防风险。
例如,沙箱连接日志会记录所有对沙箱环境的连接尝试,显示来源 IP 、目标 IP 、进程信息( PID )和安全等级,方便用户识别可疑连接;KMS 访问日志会跟踪所有密钥操作,监控精调模型的密钥使用情况。
第三类是用户可见日志,包括所有历史访问记录,支持用户直接查看和与其他层面(云基础、安全业务)日志的交叉验证,确定日志的真实性,不存在篡改和遗漏。
就像电缆的绝缘层、保护层、铠装,环环相依,保护「线芯」不受外界因素侵蚀,在「会话无痕」的四重保护下,你的数据,唯你可见,唯你所用,唯你所有,平台安全水位也被提升到一个相当高的位置。
这不是简单堆砌多种安全技术的结果,而是对大模型时代数据安全的一次重新定义,包含三个核心理念。
首先,安全不是事后添加的补丁,而是埋在大楼水泥地基里的钢筋,从一开始就作为基本能力,被织进火山方舟大模型平台的底层设计中。
第二,在不显著损耗模型效果和推理效率的前提下,提升平台安全。
增强安全防护通常会导致明显的性能损耗,因此,在保持大模型性能的同时提升安全性,任务难度呈指数级增长。「会话无痕」比较好地平衡了这一点,吴迪认为,「我们可能是业界做得最好的公司之一。」
原因很简单,火山方舟不仅精通安全技术,还积累了丰富的场景应用 know-how ,如知道不同场景下的真正安全节点,包括用户的实际使用模式、模型运行特点等。
有了这些知识,他们就能简化掉一些安全性虽高但会导致大量浪费、性能损耗的冗余开销,在关键点实施精准的安全加固,优化安全措施的实现方式。
第三就是透明可信,阳光是最好的「防腐剂」。
最初,我们觉得环境强隔离的安全沙箱设计最具挑战性,但现在发现审计日志才是最难的。吴迪说。
这个难点并非技术本身,更多的是产品设计上,如何让专业的安全信息变得通俗易懂,用户不仅能看到日志,更要能看懂日志,理解当前的安全水准处在一个什么样的位置。
未来几个月,火山方舟计划进一步提升平台安全水位——从「不作恶( don't be evil )」提升到「无法作恶( can't be evil )」,从技术层面确保平台即使想做坏事也做不到。
例如,进一步升级审计日志系统,让用户能够全方位监督平台的每一次计算过程是否合规、安全。引入更先进的硬件可信技术,并邀请第三方机构进行独立审计和测试,通过技术手段和外部监督,从根本上保证平台行为的透明可信。
吴迪透露,火山方舟目前拥有一支独立的安全技术团队,由资深安全主管领衔,汇集了系统架构和信息安全领域的专家。
安全技术团队与负责模型推理等核心功能的系统工程团队保持着微妙的平衡:既能密切协作,又能独立进行安全评估,形成了有效的互助与制衡机制。
同时,火山方舟还建立了常态化的蓝军攻防体系,通过持续的安全测试来检验和强化系统防护能力。
长远来看,在一个快速变迁的技术世界里,构建一个既安全又不失性能的安全体系,有时就像在流沙上建造堡垒,极具挑战性。
多模态交互的出现使问题更加复杂——不同模态数据在规模和处理方式上差异显著,仅视频的加解密流量就远超文本处理的需求,吴迪举例说。
更深层的挑战来自模型推理系统本身的复杂性。它已经演变成一个庞大的分布式系统,涉及多样化硬件、推理优化方案和 RDMA 网络传输,而这些底层架构还在不断演进中。这种动态变化的环境,使得安全体系的构建和维护变得愈发具有挑战性。
然而,前景依然光明。火山方舟相信,生成式 AI 的市场规模有望达到当前的千倍,渗透各行各业的核心业务。
当它距离企业核心业务越近,除了性能、性价比,企业对数据安全和信任的要求也会水涨船高。
着眼未来,顺势而为,火山方舟希望载着越来越多的大模型玩家,加速驶向更远的节点。
#计算机通用方法,往往比深奥的纯数学更能解决问题
陶哲轩强调了在数学应用和问题解决中需要找到合适的平衡点:既不过度简化,也不过度复杂化,避免过度优化和过度抽象导致的反效果。
刚刚,著名数学家陶哲轩在个人社交平台更新的几篇帖子,引起大家广泛的共鸣。
陶哲轩用浅显易懂的语言表达了自己对数学的理解与思考心得。
文中谈到了一个关于「度」的问题,陶哲轩表示在设计系统时,缺乏或者过度的数学分析可能都会适得其反,所以要适度。
有时,我们不需要太过复杂精深的专业知识,大道至简。
对于大多数任务,使用一些相对简单但通用的数学方法,往往比专门设计的算法效果更好。
陶哲轩还提到,在纯数学中,故意忽略一些直觉上看似非常重要的信息非常有帮助。
接下来是陶哲轩帖子全部内容。
掌握一点点的数学知识就能大有裨益。系统的设计不仅仅会因为缺乏足够的数学分析而受到限制,同样也可能因为过度的数学分析而受到阻碍。
一个常见的例子是网络安全中对密码的要求。从数学上讲,密码要求越复杂(例如,规定最小长度、特殊字符或不重复使用密码),密码就越安全。
然而,如果要求过于复杂,用户和服务提供商可能会寻找绕过复杂要求的方法,比如寻找简单的密码重置或恢复方式,或者将密码存储在不安全的系统中。这些做法反而可能降低整体系统的安全性,而不是提升它。
另一方面,只对单一指标(如用户使用密码直接登录系统)进行过度优化,可能会损害更广泛的目标。就如古德哈特定律(Goodhart's law)中所说的,「当压力施于其上以进行控制时,任何观测到的统计恒性都倾向消散。」
粗略的讲,在设计安全性时,直接输入方式的安全性应该加强到与其他输入方式的安全性相当,但超过这个程度的加强反而可能适得其反。
举个例子来说,如果一栋建筑的前门有锁,但窗户没有防护,那么再给前门加更多的锁就没有太大意义,这样做甚至可能导致一种危险的虚假安全感。另一方面,如果窗户比前门更难进入,那么在前门上至少加一把锁就很合理。
在人工智能领域,强化学习之父 Rick Sutton 的「苦涩的教训」(Bitter Lesson)就是这一原则的一个例子。
从直觉上来看,大家往往会认为针对具体任务量身定制算法是最自然的选择,在某些情况下,确实能取得不错的效果。
其实,对于大多数任务,使用一些相对简单但通用的数学方法,如梯度下降和反向传播,往往比专门设计的算法效果更好。通用方法不依赖于特定任务的领域知识,而是通过大量的数据和计算资源来训练模型,通常能带来更大的进展。
最近,我看到了有人为传感器网络开发更实惠的模数转换器(ADC),就是这条发现的证明。
传统上,ADC 电路基于经典电气工程原理设计,采用常微分方程(ODE)、共振、傅里叶变换等数学工具来构建高效电路。然而,在一些特定环境(如传感器网络)中,我们的目标是大规模、快速且成本低的方式实现模数转换,同时可以容忍一定的故障率。
在这种情况下,训练神经网络来设计 ADC 电路,不依赖任何专业领域的知识(如傅里叶分析),反而是更好的方法。
这并不是说领域知识毫无用处 —— 例如,物理信息神经网络在许多物理领域的表现可以远超标准神经网络 —— 关键在于了解在什么情况下,应该运用多少领域知识。
在纯数学中,一个有效的解题方法是故意忽略一些直觉上看似非常重要的信息。比如,在分析数论中,许多进展都是通过把像素数这样的「重要」数学对象转化为看起来更加简单、结构较少的形式来实现的。这样做可以让我们更容易找到解决问题的途径。
但抽象也需要把握一个度。如果抽象得过头,就会丢失关键信息,反而无法解决问题;而如果抽象得恰到好处,问题就会变得更加清晰,从而找到合适的技巧去解决它。在此过程中甚至可以做出一些看似不太合理的变换,让解题思路更加灵活起来。
我有时会开玩笑说,应用数学家只需要掌握每本纯数学研究生教材的前两章,之后的章节对他们可能帮助不大(甚至可能有负面作用)。
另一方面,正是寻找第 3 到第 12 章的过程,才使得前两章至臻完美、具有广泛实用性的瑰宝。
在读完陶哲轩的这段见解后,有人评论道:这些建议非常有价值,不论是对于哪种问题,都要做到:
- 简化细节,直到看到更宏观的问题结构。
- 判断是否已有针对同类问题的解决方案。
- 或者判断这个一般性问题类是否过于笼统,或者是否过于具体。
参考链接:
https://mathstodon.xyz/@tao/113482950431855749
#智能体工作流越来越成熟
受 ChatGPT 强大问答能力的影响,大型语言模型(LLM)提供商往往优化模型来回答人们的问题,以提供良好的消费者体验。
随着智能体研究日趋成熟,优化似乎有了新的方向。
人工智能著名学者、斯坦福大学教授吴恩达今天指出:「现在有一种趋势是优化模型以适应智能体工作流程,这将为智能体性能带来巨大提升」,并撰写一篇博客简单阐述了这种趋势。
我们对博客内容进行了不改变原意的编译、整理,以下是博客内容:
继 ChatGPT 在回答问题方面取得突破性成功之后,许多 LLM 的开发都集中在提供良好的消费者体验上。因此,LLM 被调整为回答问题或遵循人类提供的指令。指令调整指导模型的数据集很大一部分可以为人类编写的问题和指令提供更有用的答案,面向 ChatGPT、Claude、Gemini 等等。
但智能体工作负载不同,人工智能软件不是直接为消费者生成响应,而是应该在迭代工作流程中:
- 反思自己的输出;
- 使用工具;
- 编写规划;
- 在多智能体环境中进行协作。
主要模型制造商也越来越多地优化用于 AI 智能体的模型。
以工具使用(或函数调用)为例。如果 LLM 被问及当前天气,它将无法从训练数据中获取所需的信息。相反,它可能会生成 API 调用请求以获取该信息。甚至在 GPT-4 原生支持函数调用之前,应用程序开发人员就已经使用 LLM 来生成函数调用,通过编写更复杂的提示来告诉 LLM 哪些函数可用,然后让 LLM 生成用于确定是否要调用函数的字符串。
在 GPT-4 之后,生成此类调用变得更加可靠,然后许多其他模型本身就支持函数调用。如今,LLM 可以决定调用函数来搜索信息以进行检索增强生成 (RAG)、执行代码、发送电子邮件、在线下订单等等。
最近,Anthropic 推出了升级版的 Claude 3.5 Sonnet,能像人一样使用计算机。这意味着 LLM 原生使用计算机方向向前迈出了一大步,将帮助许多开发人员。一些团队还致力于让 LLM 使用计算机构建新一代 RPA(机器人流程自动化)应用程序。
随着智能体工作流程的成熟,我看到的是:
- 首先,许多开发人员正在 prompt LLM 来执行他们想要的智能体行为。这样可以进行快速、丰富的探索!
- 在极少数情况下,开发非常有价值的应用程序的开发人员将微调 LLM,以更可靠地执行特定的智能体功能。例如,尽管许多 LLM 本身支持函数调用,但它们是通过将可用函数的描述作为输入,然后(希望)生成输出 token 以请求正确的函数调用来实现这一点的。对于生成正确函数调用非常重要的任务关键型应用程序,针对应用程序的特定函数调用微调模型可显著提高可靠性。(但请避免过早优化!我仍然看到太多团队在进行微调,而他们可能应该在采取这种做法之前花更多时间进行 prompt。)
- 最后,当诸如工具使用或计算机使用之类的能力对开发人员来说似乎很有价值时,主要的 LLM 提供商正在将这些能力直接构建到他们的模型中。尽管 OpenAI o1-preview 的高级推理对消费者有帮助,但我预计它对于智能体推理和规划会更有用。
大多数 LLM 都针对回答问题进行了优化,主要是为了提供良好的消费者体验,我们已经能够将它们「移植」到复杂的智能体工作流程中,以构建有价值的应用程序。为支持智能体中的特定操作而构建 LLM 的趋势将为智能体性能带来很大提升。我相信,在未来几年内,在这个方向上将实现巨大的智能体能力提升。
原文链接:
https://www.deeplearning.ai/the-batch/issue-275/
#U-DiTs
Sora 的发布让广大研究者及开发者深刻认识到基于 Transformer 架构扩散模型的巨大潜力。作为这一类的代表性工作,DiT 模型抛弃了传统的 U-Net 扩散架构,转而使用直筒型去噪模型。鉴于直筒型 DiT 在隐空间生成任务上效果出众,后续的一些工作如 PixArt、SD3 等等也都不约而同地使用了直筒型架构。
然而令人感到不解的是,U-Net 结构是之前最常用的扩散架构,在图像空间和隐空间的生成效果均表现不俗;可以说 U-Net 的 inductive bias 在扩散任务上已被广泛证实是有效的。因此,北大和华为的研究者们产生了一个疑问:能否重新拾起 U-Net,将 U-Net 架构和 Transformer 有机结合,使扩散模型效果更上一层楼?带着这个问题,他们提出了基于 U-Net 的 DiT 架构 U-DiT。
论文标题:U-DiTs: Downsample Tokens in U-Shaped Diffusion Transformers
论文地址:https://arxiv.org/pdf/2405.02730
GitHub 地址:https://github.com/YuchuanTian/U-DiT
从一个小实验谈开去
首先,研究者开展了一个小实验,在实验中尝试着将 U-Net 和 DiT 模块简单结合。然而,如表 1 所示,在相似的算力比较下,U-Net 的 DiT(DiT-UNet)仅仅比原始的 DiT 有略微的提升。
在图 3 中,作者们展示了从原始的直筒 DiT 模型一步步演化到 U-DiT 模型的过程。
根据先前的工作,在扩散中 U-Net 的主干结构特征图主要为低频信号。由于全局自注意力运算机制需要消耗大量算力,在 U-Net 的主干自注意力架构中可能存在冗余。这时作者注意到,简单的下采样可以自然地滤除噪声较多的高频,强调信息充沛的低频。既然如此,是否可以通过下采样来消除对特征图自注意力中的冗余?
Token 下采样后的自注意力
由此,作者提出了下采样自注意力机制。在自注意力之前,首先需将特征图进行 2 倍下采样。为避免重要信息的损失,生成了四个维度完全相同的下采样图,以确保下采样前后的特征总维度相同。随后,在四个特征图上使用共用的 QKV 映射,并分别独立进行自注意力运算。最后,将四个 2 倍下采样的特征图重新融为一个完整特征图。和传统的全局自注意力相比,下采样自注意力可以使得自注意力所需算力降低 3/4。
令人惊讶的是,尽管加入下采样操作之后能够显著模型降低所需算力,但是却反而能获得比原来更好的效果(表 1)。
U-DiT:全面超越 DiT
根据此发现,作者提出了基于下采样自注意力机制的 U 型扩散模型 U-DiT。对标 DiT 系列模型的算力,作者提出了三个 U-DiT 模型版本(S/B/L)。在完全相同的训练超参设定下,U-DiT 在 ImageNet 生成任务上取得了令人惊讶的生成效果。其中,U-DiT-L 在 400K 训练迭代下的表现比直筒型 DiT-XL 模型高约 10 FID,U-DiT-S/B 模型比同级直筒型 DiT 模型高约 30 FID;U-DiT-B 模型只需 DiT-XL/2 六分之一的算力便可达到更好的效果(表 2、图 1)。
在有条件生成任务(表 3)和大图(512*512)生成任务(表 5)上,U-DiT 模型相比于 DiT 模型的优势同样非常明显。
研究者们还进一步延长了训练的迭代次数,发现 U-DiT-L 在 600K 迭代时便能优于 DiT 在 7M 迭代时的无条件生成效果(表 4、图 2)。
U-DiT 模型的生成效果非常出众,在 1M 次迭代下的有条件生成效果已经非常真实。
论文已被 NeurIPS 2024 接收,更多内容,请参考原论文。