#FlashAttention

FlashAttention新升级!斯坦福博士一人重写算法,第二代实现了最高9倍速提升。Transformer上下文长度史诗级提升

继超快且省内存的注意力算法FlashAttention爆火后,升级版的2代来了。

FlashAttention-2是一种从头编写的算法,可以加快注意力并减少其内存占用,且没有任何近似值。

比起第一代,FlashAttention-2速度提升了2倍。

甚至,相较于PyTorch的标准注意力,其运行速度最高可达9倍。

一年前,StanfordAILab博士Tri Dao发布了FlashAttention,让注意力快了2到4倍,如今,FlashAttention已经被许多企业和研究室采用,广泛应用于大多数LLM库。

如今,随着长文档查询、编写故事等新用例的需要,大语言模型的上下文以前比过去变长了许多——GPT-4的上下文长度是32k,MosaicML的MPT上下文长度是65k,Anthropic的Claude上下文长度是100k。

但是,扩大Transformer的上下文长度是一项极大的挑战,因为作为其核心的注意力层的运行时间和内存要求,是输入序列长度的二次方。

Tri Dao一直在研究FlashAttention-2,它比v1快2倍,比标准的注意力快5到9倍,在A100上已经达到了225 TFLOP/s的训练速度!

论文地址:https://tridao.me/publications/flash2/flash2.pdf

项目地址:https://github.com/Dao-AILab/flash-attention

FlashAttention-2:更好的算法、并行性和工作分区

端到端训练GPT模型,速度高达225 TFLOP/s

虽说FlashAttention在发布时就已经比优化的基线快了2-4倍,但还是有相当大的进步空间。

比方说,FlashAttention仍然不如优化矩阵乘法(GEMM)运算快,仅能达到理论最大FLOPs/s的25-40%(例如,在A100 GPU上的速度可达124 TFLOPs/s)。

w~深度学习~合集4_深度学习

GEMM如何用于卷积

在过去的几个月里,研究人员一直在开发FlashAttention-2,它的性能指标比第一代更强。

研究人员表示,2代相当于完全从头重写,使用英伟达的CUTLASS 3.x及其核心库CuTe。从速度上看,FlashAttention-2比之前的版本快了2倍,在A100 GPU上的速度可达230 TFLOPs/s。

当使用端到端来训练GPT之类的语言模型时,研究人员的训练速度高达225 TFLOPs/s(模型的FLOP利用率为72%)。

对注意力计算重新排序

我们知道,FlashAttention是一种对注意力计算进行重新排序的算法,利用平铺、重新计算来显著加快计算速度,并将序列长度的内存使用量从二次减少到线性。

w~深度学习~合集4_深度学习_02

研究人员将输入块从HBM(GPU内存)加载到SRAM(快速缓存),并对该模块执行注意,更新HBM中的输出。

由于没有将大型中间注意力矩阵写入HBM,内存的读/写量也跟着减少,进而带来了2-4倍的执行时间加速。

下图是FlashAttention的前向传递图:通过平铺和softmax重新缩放,研究人员人员按模块进行操作,避免从HBM读取或是写入,同时获得正确输出,无需近似。

w~深度学习~合集4_深度学习_03

然而,FlashAttention仍然存在一些低效率的问题,这是由于不同线程块之间的工作划分并不理想,以及GPU上的warp——导致低占用率或不必要的共享内存读写。

更少的non-matmul FLOP(非矩阵乘法浮点计算数)

研究人员通过调整FlashAttention的算法来减少non-matmul FLOP的次数。这非常重要,因为现代GPU有专门的计算单元(比如英伟达GPU上的张量核心),这就使得matmul的速度更快。

例如,A100 GPU FP16/BF16 matmul的最大理论吞吐量为312 TFLOPs/s,但non-matmul FP32的理论吞吐量仅为 19.5 TFLOPs/s。

另外,每个非matmul FLOP比matmul FLOP要贵16倍。

所以为了保持高吞吐量,研究人员希望在matmul FLOP上花尽可能多的时间。

研究人员还重新编写了FlashAttention中使用的在线softmax技巧,以减少重新缩放操作的数量,以及边界检查和因果掩码操作,而无需更改输出。

更好的并行性

FlashAttention v1在批大小和部数量上进行并行化处理。研究人员使用1个线程块来处理一个注意力头,共有 (batch_size * head number) 个线程块。

w~深度学习~合集4_深度学习_04

在前向处理(左图)中,研究者将Worker(线程块)并行化,每个Worker负责处理注意力矩阵的一个行块。在后向处理过程中(右图),每个Worker处理注意力矩阵的一个列块

每个线程块都在流式多处理器 (SM)运行,例如,A100 GPU上有108个这样的处理器。当这个数字很大(比如 ≥80)时,这种调度是有效的,因为在这种情况下,可以有效地使用GPU上几乎所有的计算资源。

在长序列的情况下(通常意味着更小批或更少的头),为了更好地利用GPU上的多处理器,研究人员在序列长度的维度上另外进行了并行化,使得该机制获得了显著加速。

更好的工作分区

即使在每个线程块内,研究人员也必须决定如何在不同的warp(线程束)之间划分工作(一组32个线程一起工作)。研究人员通常在每个线程块使用4或8个warp,分区方案如下图所示。

研究人员在FlashAttention-2中改进了这种分区,减少了不同warp之间的同步和通信量,从而减少共享内存读/写。

w~深度学习~合集4_深度学习_05

对于每个块,FlashAttention将K和V分割到4个warp上,同时保持Q可被所有warp访问。这称为「sliced-K」方案。

然而,这样做的效率并不高,因为所有warp都需要将其中间结果写入共享内存,进行同步,然后再将中间结果相加。

而这些共享内存读/写会减慢FlashAttention中的前向传播速度。

在FlashAttention-2中,研究人员将Q拆分为4个warp,同时保持所有warp都可以访问K和V。

在每个warp执行矩阵乘法得到Q K^T的一个切片后,它们只需与共享的V切片相乘,即可得到相应的输出切片。

这样一来,warp之间就不再需要通信。共享内存读写的减少就可以提高速度。

新功能:头的维度高达256,多查询注意力

FlashAttention仅支持最大128的头的维度,虽说适用于大多数模型,但还是有一些模型被排除在外。

FlashAttention-2现在支持256的头的维度,这意味着GPT-J、CodeGen、CodeGen2以及Stable Diffusion 1.x等模型都可以使用FlashAttention-2来获得加速和节省内存。

v2还支持多查询注意力(MQA)以及分组查询注意力(GQA)。

w~深度学习~合集4_深度学习_06

GQA为每组查询头共享单个key和value的头,在多头和多查询注意之间进行插值

这些都是注意力的变体,其中多个查询头会指向key和value的同一个头,以减少推理过程中KV缓存的大小,并可以显著提高推理的吞吐量。

注意力基准

研究人员人员在A100 80GB SXM4 GPU 上测量不同设置(有无因果掩码、头的维度是64或128)下不同注意力方法的运行时间。

w~深度学习~合集4_深度学习_07

研究人员发现FlashAttention-2比第一代快大约2倍(包括在xformers库和Triton中的其他实现)。与PyTorch中的标准注意力实现相比,FlashAttention-2的速度最高可达其9倍。

w~深度学习~合集4_深度学习_08

A100 GPU上的前向+后向速度只需在H100 GPU上运行相同的实现(不需要使用特殊指令来利用TMA和第四代Tensor Core等新硬件功能),研究人员就可以获得高达335 TFLOPs/s的速度。

w~深度学习~合集4_深度学习_09

H100 GPU上的前向+后向速度

当用于端到端训练GPT类模型时,FlashAttention-2能在A100 GPU上实现高达225TFLOPs/s的速度(模型FLOPs利用率为72%)。

与已经非常优化的FlashAttention模型相比,端到端的加速进一步提高了1.3倍。

w~深度学习~合集4_深度学习_10

未来的工作

速度上快2倍,意味着研究人员可以用与之前训练8k上下文模型相同的成本,来训练16k上下文长度的模型。这些模型可以理解长篇书籍和报告、高分辨率图像、音频和视频。

同时,FlashAttention-2还将加速现有模型的训练、微调和推理。

在不久的将来,研究人员还计划扩大合作,使FlashAttention广泛适用于不同类型的设备(例如H100 GPU、AMD GPU)以及新的数据类型(例如fp8)。

下一步,研究人员计划针对H100 GPU进一步优化FlashAttention-2,以使用新的硬件功能(TMA、第四代Tensor Core、fp8等等)。

将FlashAttention-2中的低级优化与高级算法更改(例如局部、扩张、块稀疏注意力)相结合,可以让研究人员用更长的上下文来训练AI模型。

研究人员也很高兴与编译器研究人员合作,使这些优化技术更好地应用于编程。

参考资料:

https://princeton-nlp.github.io/flash-atttention-2/

别再「浪费」GPU了,FlashAttention重磅升级,实现长文本推理速度8倍提升

处理小说、法律文件等长文本是大模型的一个重要应用方向,但也面临速度上的挑战。FlashAttention 作者 Tri Dao 等人提出的「Flash-Decoding」通过充分利用 GPU,可以将大模型的长上下文推理速度提高至 8 倍。

最近,像 ChatGPT 或 Llama 这样的大型语言模型(LLM)引起了前所未有的关注。然而,它们的运行成本仍然极高。虽然生成单个响应可能仅需 0.01 美元(在 AWS 上的 8xA100 实例上运行几秒钟),但当扩大规模以满足数十亿用户的需求时,成本会迅速累积。而且,这些用户可能每天与 LLM 进行多次互动。某些用例的成本更高,例如代码自动生成,因为它会随着每次输入新字符而运行。随着 LLM 应用的不断增加,即使在生成时间方面实现细微的效率提升,也将产生巨大的影响。

LLM 推理(或「解码」)是一个迭代的过程:token 逐个生成。生成包含 N 个 token 的完整句子需要通过模型进行 N 次前向传递。幸运的是,我们可以缓存先前计算的 token:这意味着单个生成步骤不依赖于上下文长度,除了一个单独的操作 —— 注意力。这个操作导致上下文长度不能很好地扩展。

在 LLM 的重要新兴用例中,有一些需要利用更长的上下文。只有拥有了更长的上下文窗口,LLM 才能对更长的文档进行推理,无论是总结文档还是回答其中的问题。此外,它们还可以保持更长的对话历史,甚至在编写代码之前处理整个代码库。举个例子,在 2022 年,大多数 LLM 的上下文长度最多为 2k(例如 GPT-3),但现在,有些开源 LLM 已经可以扩展到 32k(比如 Llama-2-32k),甚至有些模型已经达到了 100k(比如 CodeLlama)。在这些情境中,注意力操作在推理过程中占据了相当大的时间比例。

在扩展 batch size 维度时,即使上下文相对较短,注意力也可能成为一个瓶颈。这是因为随着 batch 维度的增加,需要读取的内存量也会增加,而对于模型的其余部分,内存需求只取决于模型的大小。

为了解决上述问题,FlashAttention 的作者 Tri Dao 等人提出了一项名为「Flash-Decoding」的技术,它显著加速了推理过程中的注意力计算,使长序列的处理生成速度提高到了原来的 8 倍。其主要思想是以最快的速度并行加载键和值,然后分别重新缩放和合并结果,以维持正确的注意力输出。

解码时的多头注意力

在解码期间,生成的每个新 token 都需要关注所有先前的 token,以计算:softmax (queries @ keys.transpose) @ values

这个操作已经在训练阶段通过 FlashAttention 进行了优化(包括最近的 v1 和 v2 版本),瓶颈是读写中间结果的内存带宽(如 Q @ K^T)。然而,这些优化并不直接适用于推理情况,因为瓶颈不同。在训练中,FlashAttention 并行处理 batch size 和查询长度两个维度。而在推理过程中,查询长度通常为 1:这意味着,如果 batch size 小于 GPU 上的流多处理器(streaming multiprocessor,SM)数量(例如 A100 有 108 个),该操作只会利用 GPU 的一小部分!特别是在处理长上下文时,情况尤为明显,因为它需要较小的 batch size 以适应 GPU 内存。当 batch size 为 1 时,FlashAttention 将使用不到 1% 的 GPU!

w~深度学习~合集4_深度学习_11

FlashAttention 只在查询块和 batch size 之间并行,并且在解码期间不会设法占用整个 GPU。

使用矩阵乘法基元也能执行注意力计算,这样就不需要使用 FlashAttention 了。在这种情况下,该操作会占用整个 GPU,但会启动许多写入和读取中间结果的内核,因此并不是最优的做法。

更快的注意力解码:Flash-Decoding

新方法 Flash-Decoding 基于 FlashAttention,同时引入了一个新的并行维度:键值序列的长度。它综合了上述两种方法的优点。与 FlashAttention 类似,它在全局内存中存储的额外数据很少。然而,只要上下文足够长,即使 batch size 较小,它也能充分利用 GPU。

w~深度学习~合集4_深度学习_12

Flash-Decoding 也在键和值之间并行化,代价是一个小的最终归约(reduction 步骤。

Flash-Decoding 主要有三个工作步骤:

  1. 首先,将键 / 值分成更小的块;
  2. 使用 FlashAttention 并行计算查询与每个这些分块的注意力,为每行和每个分块额外写入一个标量值:注意力值的 log-sum-exp
  3. 最后,通过对所有分块进行归约来计算实际输出,使用 log-sum-exp 来调整每个分块的贡献。

这一切之所以可行,都是因为注意力 /softmax 可以进行迭代计算。在 Flash-Decoding 中,它在两个级别上被使用:在分块内部(类似 FlashAttention),以及跨分块进行最终的归约计算。

实际操作中,步骤(1)不涉及任何 GPU 操作,因为键 / 值块是完整键 / 值张量的视图。然后,有两个独立的核函数,分别用于执行步骤(2)和(3)。

在 CodeLlama 34B 上进行的基准测试

为了验证上述新方法,研究者对 CodeLLaMa-34b 的解码吞吐量进行了基准测试。该模型与 Llama 2 具有相同的架构,一般来说,结果应该适用于许多大型语言模型。研究者在不同序列长度下(从 512 到 64k),以 tok/s 为单位来测量解码速度,并比较了多种计算注意力的方式:

  • Pytorch:使用纯粹的 PyTorch 基元来运行注意力计算(不使用 FlashAttention);
  • FlashAttention v2;
  • FasterTransformer:使用 FasterTransformer 的注意力内核;
  • Flash-Decoding;
  • 以及一个上限值,该值计算了从内存中读取整个模型和 KV-cache 所需的时间

对于非常大的序列,Flash-Decoding 可以将解码速度提高至 8 倍,并且比其他方法的扩展性要好得多。

w~深度学习~合集4_深度学习_13

在 prompt 比较小时,所有方法表现接近。但是当序列长度从 512 增加到 64k 时,除了 Flash-Decoding,其他方法的可扩展性都很差。在 Flash-Decoding 的这种模式下(batch size 为 1),扩展序列长度对生成速度的影响很小。

组件级微基准测试

研究者还在 A100 上对多头注意力进行了微基准测试,输入为 f16,考虑了不同的序列长度和 batch size。他们将 batch size 设置为 1,并且使用 16 个 128 维的查询头,以及 2 个键 / 值头(分组查询注意力),这与在 4 个 GPU 上运行的 CodeLLaMa-34b 使用的维度相匹配。

w~深度学习~合集4_深度学习_14

上述微基准测试展示了多头注意力的运行时间,单位为微秒。Flash-Decoding 在序列长度扩展到高达 64k 时,几乎实现了恒定的运行时间。

之前测量的高达 8 倍的端到端加速是可能的,因为注意力本身的速度比 FlashAttention 快高达 50 倍。在序列长度达到 32k 之前,注意力的时间大致是恒定的,因为 Flash-Decoding 能够完全利用 GPU。

使用 Flash-Decoding

Flash-decoding 可以在以下链接中找到:

  • FlashAttention 包,从 v2.2 开始:https://github.com/Dao-AILab/flash-attention/tree/main
  • xFormers 包(搜索 xformers.ops.memory_efficient_attention),从 0.0.22 开始:调度程序将根据问题的大小自动使用 Flash-Decoding 或 FlashAttention 方法。当这些方法不受支持时,它可以调度到一个高效的 triton 内核,该内核实现了 Flash-Decoding 算法。

一个完整的使用 LLaMa v2 / CodeLLaMa 的解码示例可以在 FlashAttention  repo 和 xFormers  repo 中找到。此外,作者还提供了一个简单的 LLaMa v1/v2 模型的高效解码代码示例,旨在快速、易读、有教育意义和易于修改。

参考链接:https://princeton-nlp.github.io/flash-decoding/








#FlashAttention~2

这是一个简单的学习总结,核心逻辑以及V1 V2差异总结,主要从计算的角度总结FlashAttention怎么做到save memory & perf speedup的,讲一些其他文章提的比较少的点,不提供等效计算变换的公式证明。

摸了一下FlashAttention的CUTLASS实现(https//github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_kernel.h)和Triton实现(https//github.com/openai/triton/blob/main/python/triton/ops/flash_attention.py),之前做过FlashAttention V1和V2算法下,这两种框架最终的Kernel Perf benchmark(A100 & H100)所以对V1和V2的差异很好奇,本文是一个简单的学习总结,主要从计算的角度总结FlashAttention怎么做到save memory & perf speedup的,讲一些其他文章提的比较少的点,不提供等效计算变换的公式证明(这部分知乎其他大佬写的非常详细清晰了)。

理解FlashAttention核心逻辑

本节列举从0搞懂FlashAttention的核心步骤

首先需要理解Naive Attention是怎么计算的:

w~深度学习~合集4_深度学习_15

  1. Google Research的工作重点在减少整个过程的memory footprint;FlashAttention重点在减少memory reads/writes次数。可以说FlashAttention主要是从GPU block/thread并行度的视角对访存进行了优化。
  2. Google Research的工作每个block会产出一份中间结果,所有block执行完毕之后,再将他们的中间结果计算获得一个最终结果;FlashAttention则采用类似滑动窗口的方式,第i个block会将累积的中间结果传递给第 i+1 个block,也就是说最后一个block计算完毕后,可以保证整行的Softmax逻辑计算正确性。锐评:我认为这个点并没有什么独创性,Google Research这么考虑的原因也大概率是因为TPU的计算逻辑粒度适合沿sequence length切并行,切的越小越有利于TPU并行,最后再有一个逻辑来处理中间数据很正常。
  3. Google Research的工作在后向backward的时候做了一些冗余计算,FlashAttention把后向的计算简化了,减少了backward阶段的memory traffic。

FlashAttention V1的公式和推导不细嗦了,其他文章讲得非常好。列一下个人觉得从0开始的最佳学习路线:

首先看文章,公式推导和Tiling细节强烈这篇文章:From online softmax to FlashAttention(https//courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf),写的非常好,由浅入深,公式推导这块看完这篇其他都不用看了。然后辅助看一些知乎文章把不明白的地方搞清楚。

理解From online softmax to FlashAttention需要四个步骤

  1. softmax
  2. safe softmax
  3. online softmax Tiling
  4. FlashAttention Tiling

总之:FlashAttention之所以可以省显存(显存开销随Seq length线性增加),是因为解开了softmax以及后面GEMM的行方向依赖,并且通过辅助数组保存的辅助信息re-scale到正确的数值。

其次,了解一些背景信息,这里附一下其他可能便于理解FlashAttention项目发展的背景信息:

  • FlashAttention V1 在NVIDIA apex fmha基础上实现(最早的FlashAttention Alpha Realease(https//github.com/Dao-AILab/flash-attention/blob/1fcbe6f0d088d807ba585ddb850eb3497cd7b65b/csrc/stream_attn/src/fmha_kernel.h)),V2基于CUTLASS 3.0 & CUTE 重构(CUTE真是个好东西)
  • FlashAttention目前最方便的调库途径主要有两个
  • 最新实现在ops里面(https//github.com/openai/triton/blob/main/python/triton/ops/flash_attention.py)
  • 稳定的实现在tutorial里面(https//github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py)
  • 【更新】最近测了一下,H100上tutorial速度也很快,看上去tutorial的kernel算法也经常有小优化
  • pip install flash-attn,官方的库,编译时间稍长,基于CUTLASS有大量的模板,如果想进一步魔改(比如加bias或者加mask,或者稀疏化等)学习和Debug成本比较大
  • 使用Triton的实现,性能实测非常不错

最后,看代码,跑代码,Profile Kernel with Nsight Compute,改代码...

这里我推荐基于Triton FlashAttention(https//github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py)上手,原因如下:

  1. Tri-Dao的FlashAttention基于CUTLASS 3.0重构,新手小白想要编译跑起来首先要搞定几个环境问题;里面一个头文件几千行,想看懂需要首先搞懂CUTLASS的基本用法,想改代码更需要一些模板猿编程debug技巧,想优化性能的话...你得学学CUTE,再学学各种GPU Features...如果你只是想学习一下FlashAttention,或者只是想基于这个Fusion idea定制你自己的kernel,没必要从CUTLASS开始。CUTLASS适合CUDA熟手并且期望拿到Peak Performance的玩家。
  2. Triton语法很容易上手,方便魔改你自己的Attention Kernel,或者你有其他的想法也很容易实践实验。例子:FlagAttention(https//github.com/FlagOpen/FlagAttention),Sparse Flash Attention(https//github.com/epfml/dynamic-sparse-flash-attention) (所以非常适合发paper啦,至少迭代CUDA kernel速度直接飙升,加快idea的反馈。从实验时间成本来说,用CUDA写半个月的Kernel+调一个月性能 >> 用CUTLASS写一周Kenrel+调两天性能 >> 用Triton写3天Kernel+调1天性能)
  3. Triton FlashAttention在Hopper-Triton PR(https//github.com/openai/triton/commit/f1512bded1934e34f104bf1ac8547e97e24b2fe8)之后,目前main分支已经集成了大量的Hopper相关的优化Pass,相比官方库还没有稳定实现Hopper features来说,在某些problem size下可能有优势。
  4. 关于Triton推荐阅读:杨:谈谈对OpenAI Triton的一些理解(https://zhuanlan.zhihu.com/p/613244988),杨:OpenAI Triton Conference参会随感兼谈Triton Hopper(https://zhuanlan.zhihu.com/p/659348024)

OK,进入正题

FlashAttention V2相比V1有哪些改进

w~深度学习~合集4_深度学习_16

FlashAttention V1 前向后向 Kernel示意草图

w~深度学习~合集4_深度学习_17

FlashAttention V2 前向后向 Kernel示意草图

V2主要从两个方面改进:

算法的改进

fwd和bwd都简化了非matmul计算,这里也是对rescale重新优化了一下;其中bwd不需要m了,只需要logsumexp L 即可。

其实FlashAttention不管V1还是V2都有一个缺点,就是为了rescale方便并行,需要把很多计算逻辑顺序排在后面(尤其是浮点数的乘除),这会改变计算的数值精度稳定性,造成在某些使用到Attention结构的网络中收敛不了的问题。

fwd和bwd都根据casual mask的特性尽可能减少冗余计算和访存:

  1. 右侧上三角block无须计算,直接跳过;
  2. 每行只用对最后一个block设定casual mask的逻辑即可。

w~深度学习~合集4_深度学习_18

FlashAttention V1 & V2 forward

w~深度学习~合集4_深度学习_19

FlashAttention V1 & V2 backward

红框的就是这里算法部分的优化和改动(左图多了mask和droupout的逻辑,忽略即可)

这个优化其实不是critical path,所以提升并不大。fwd做2个GEMM,bwd做5个GEMM,整个Kernel fwd & bwd都是memory bound,此时应该优化的是GEMM相关的memory dependency,做multi-stages,更灵活的异步调度(比如warp specialization),最后可能还需要考虑优化data reuse,优化L2 cache等等,当然一切都需要基于Nsight Compute结果分析,不然都是幻觉。

Sequence Length 并行

非常赞同@方佳瑞(https://www.zhihu.com/people/8c89d6f733cb2b81ce36a2daf0a81a82) 方佳瑞:大模型训练加速之FlashAttention系列:爆款工作背后的产品观(https://zhuanlan.zhihu.com/p/664061672#:~:text=我觉得V2最重要的提升点是参考Phil Tillet的Tirton版本,更改了Tiling循环的顺序,也就是笔者本文图1的效果。) 提到的,V2 能够把特定输入下的一个CUDA Kernel提升2X,这只能说明baseline(V1)选的太好了(笑),总之,就是因为改变了Tiling循环的顺序,把Q循环挪到了最外层,所以刚好就可以把Q循环直接给到Thread Block并行维度来计算了,本来这个方向没有依赖就是可以并行的。话说我最开始的也很纳闷,这个idea其实最早就有了,PyTorch的实现以及NV Apex FMHA的实现都有这个版本的kernel。

考虑一下,为什么K/V上的seq length方向不给到Thread Block做并行?答案是,如果可以在Q seq length上拆block并行了,那么一般来说GPU occupancy已经够了,再多拆K/V的话也不是不行,但是会额外带来通信开销;Flash Decoding其实就是在inference阶段,面对Q的seq length=1的情况,在K/V方向做了block并行,来提高GPU Utilization从而加速的。

w~深度学习~合集4_深度学习_20

FlashAttention V1 - Tile and 2D-Loop

Thread Block Level 并行

交换了Q loop顺序到最外层之后,最大的好处是可以把这一维度的并行度从串行的loop改成并行的thread block。

所以,FlashAttention V2的实现中,fwd除了在Batch和Head上分配Thread Block并行,还在seq length上增加了一维并行度(之前是需要M N方向做loop的,现在只在N方向loop了,横着切),注意:bwd没有改变这里的循环,跟V1一样,但是也在seq length上增加了一维并行度(N方向并行,竖着切)。

造成fwd和bwd区别的主要原因是:

  1. fwd的目的是计算QK GEMM之后沿着行方向online softmax,所以需要沿着行方向loop,不然就需要额外的reduce逻辑了。因此fwd kernel选择一行Tile为一个block。如下左图一行同色块为一个block。

w~深度学习~合集4_深度学习_21

细节1:从V1 (KV外循环,QO内循环) 到V2 (Q外循环,KV内循环,O在 smem初始化,最后只写出一次), memory traffic是否降低了? memory traffic of V1:

w~深度学习~合集4_深度学习_22

细节2:现在确定了fwd kernel要在B, H, Q_N_CTX三个维度Launch Kernel了,有两种选择:grid_dim = [Q_N_CTX, B, H], grid_dim = [B, H, Q_N_CTX],哪种更好?

答案是第一种更好,因为Q_N_CTX放ThreadBlock.X维度的话,对于同一个B和H的Q_N_CTX是连续调度的,也就是说算第一行用到的K/V Tile大概率还在L2上,第二行计算可以直接从L2拿到,这样可以显著提高L2 cache hit rate。这个优化在大seq_length的时候优化很明显。

Warp Level 并行

说完了thread block的并行,再来看一个block内的warp怎么并行的

w~深度学习~合集4_深度学习_23

把V横着画,有一种上下对称的美感

首先看fwd,相比V1,V2改进了Warp Partition:4个warp会从smem的K/V tile load同样的数据做mma计算,但是load 不同Q,把V1 sliced-K sliced-V 改成了v2 sliced-Q,V1的做法是需要warp之间产生同步通信的,因为在计算QK结果乘V的时候,如图所示需要跨warp reduction得到O的结果,而且fwd的目的是沿着行方向计算softmax,行方向信息最后要汇总的,这也需要跨warp不同。V2就不需要了,这样可以减少同步开销。

对于bwd来说,如果按照右图做warp partition:1、QK的结果是P,dV=P x dO,计算dV也是需要cross warp sync的;2、dO x V的结果是dP,跟P是对称的,计算dS = P o(点乘) dP的时候不需要cross warp sync;3、计算dQ是dS x K,不需要cross warp sync;4、计算dK = dS x Q,需要cross warp sync;

如果按照左图来做warp partition,那么:1、计算dV不需要cross warp sync;2、计算dS不需要cross warp sync;3、计算dQ,需要cross warp sync;4、计算dK,不需要cross warp sync;

这里有个疑问,对于bwd来说,左图不需要cross warp sync的场景是更多的,如果Br Bc d这三个reduction的维度差不多,按道理来说bwd kernel更应该采用V1的方式做warp partition,原文:

Similarly for the backward pass, we choose to partition the warps to avoid the "split-K" scheme. However, it still requires some synchronization due to the more complicated dependency between all the different inputs and gradients Q, K, V, O, dO, dQ, dK, dV. Nevertheless, avoiding "split-K" reduces shared memory reads/writes and again yields speedup

但作者并没有细说,也没有实验数据证明这里确实提升了bwd kernel性能,感到困惑。以后有空再测一下这里的策略是不是负优化(笑

其他优化

Sequence Parallel

https//github.com/openai/triton/blob/main/python/triton/ops/flash_attention.py

这是一个没有放在FlashAttentionV2 Release的优化点(应该,反正V2 paper没提到)。之前提到的bwd中,对dQ的计算是需要跨Block做全局Atomic Reduction的,如果Block数太多,就会产生Atomic竞争,效率很低(AtomicAdd指令吞吐和延迟都比正常的LDG慢2X,而且如果dQ的数据类型是fp16或者bf16,进行Atomic操作将是性能灾难);所以如果想避免使用Atomic指令做Reduction,有两种方案:

Loop n

最简单的方式,就是bwd不要seq length方向并行度了,直接串行循环就ok;

Sequence Parallel

把dQ开大一维度,N方向有多少列block就开几个buffer,并且N方向纯block并行,最后开另外一个kernel做一个reduce,这个逻辑体现在Triton代码里就是:

https//github.com/openai/triton/blob/f9b2b822dfc0980df4c39286713414dfcd27cf8e/python/triton/ops/flash_attention.py%23L323C1-L361C1

num_block_n = tl.cdiv(N_CTX, BLOCK_N)
    if not SEQUENCE_PARALLEL:
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO,  #
                                      DQ, DK, DV,  #
                                      L,  #
                                      D,  #
                                      Q_block_ptr, K_block_ptr, V_block_ptr,  #
                                      DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,  #
                                      stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,  #
                                      stride_kz, stride_kh, stride_kn, stride_kk,  #
                                      stride_vz, stride_vh, stride_vn, stride_vk,  #
                                      Z, H, N_CTX,  #
                                      off_h, off_z, off_hz, start_n, num_block_n,  #
                                      BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,  #
                                      BLOCK_N=BLOCK_N,  #
                                      SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,  #
                                      CAUSAL=CAUSAL,  #
                                      MMA_V3=MMA_V3  #
                                      )
    else:
        start_n = tl.program_id(1)
        _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO,  #
                                  DQ, DK, DV,  #
                                  L,  #
                                  D,  #
                                  Q_block_ptr, K_block_ptr, V_block_ptr,  #
                                  DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,  #
                                  stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,  #
                                  stride_kz, stride_kh, stride_kn, stride_kk,  #
                                  stride_vz, stride_vh, stride_vn, stride_vk,  #
                                  Z, H, N_CTX,  #
                                  off_h, off_z, off_hz, start_n, num_block_n,  #
                                  BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,  #
                                  BLOCK_N=BLOCK_N,  #
                                  SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,  #
                                  CAUSAL=CAUSAL,  #
                                  MMA_V3=MMA_V3  #
                                  )

Atomic和Sequence Parallel肯定是两个极端了,所以根据不同的problem size取个折中肯定会有机会拿到更好的性能,根据具体的计算任务和GPU型号,还有些取巧的方式来避免Atomic,比如把即将冲突的Atomic Tile安排在不同的wave上,这样GPU也不会因为Atomic带来性能损耗。

Flash Decoding优化的核心跟这里的Sequence Parallel很类似:K/V Seq Length 方向切并行度做forward softmax并且最后使用一个reduction kernel对output进行累加和rescale。

指令级别的优化

基于FlashAttentionV2继续优化下去(假设把mma相关指令和memory latency掩盖的很好的话),最后大概率会bound在softmax相关的指令上,这个时候细扣这些浮点数计算指令也是有帮助的:

  1. FMUL+MUFU.EX2指令替换为fastmath指令expf
  2. 把scale相关的数据在cpu提前计算好,省一条GPU指令
  3. 把FMUL和FADDs合成FFMA指令

其实还有很多点可以优化FlashAttention的性能,不过都是些没有profile kernel的拍脑袋幻觉,就不讲了,之后有空做了实验可以再写一篇。









#IMMA~~

搬来自斯坦福的研究者提出了 IMMA, 一种利用隐空间多层图 (multiplex latent graphs) 来表征多种独立的交互类型,并使用一种新型的多层图注意力机制 (multiplex attention mechanism) 来描述个体间交互强度的行为及轨迹预测模型。该方法不仅大幅提升了预测的准确度,同时也具有很强的可解释性 (interpretability) 和泛化能力 (zero-shot generalizability)。

  • 论文链接:https://arxiv.org/abs/2208.10660 
  • 代码链接:https://github.com/fanyun-sun/IMMA

一.研究背景

对多智能体系统 (multi-agent systems) 的建模在很多领域和应用中起到重要作用,包括但不限于自动驾驶,移动机器人导航,以及人机协作。由于个体的行为会受到不同类型社会性交互 (social interactions) 的影响,多智能体系统的动力学建模面临着极高的挑战。

过去已有的方法例如 NRI[1]和 EvolveGraph[2]利用图神经网络 (GNN) 来推测每一对智能体之间的关系类型,但是这样并不能显式地对在多智能体系统中出现的不同层级的交互关系进行建模,从而导致模型的预测效果和可解释性下降。下面 Figure 1 介绍了生活中人与人之间交互的一个实例。

该研究提出了 IMMA(Interaction Modeling with Multiplex Attention)方法,通过利用隐空间中的多层关系图结构 (multiplex latent interaction graphs), 对不同层级中不同类型的交互关系进行推理,同时,该研究还设计了一种新的多层图注意力机制 (multiplex attention mechanism) 来学习每种交互关系的强度。另外,该研究还提出了一种逐层训练 (progressive layer training) 的方法来加强不同层的关系图之间的解耦,从而进一步提升了模型的可解释性 (interpretability) 和泛化能力 (zero-shot generalizability)。本文方法在多种不同领域的问题中都取得了最优效果,包括 social navigation, cooperative task achievement 和 team sports.

w~深度学习~合集4_深度学习_24

二.研究方法

问题描述:假设场景中包含 N 个智能体,模型的输入包含这 N 个智能体的轨迹,任务目标是根据过去一段时间内的轨迹观测来预测未来一段时间内的轨迹,同时要对智能体之间的交互关系进行建模和推理。

核心观点:在已有的方法中 ([1][2]),隐变量 z 中的每个元素表示交互关系图中对应的 edge 属于每一种可能的关系的概率,意味着用于 Decoder 的关系图中每一对智能体的关系只能是其中的一种。然而,复杂的交互系统中可能存在某些智能体之间同时存在多种相互独立的关系的情况,并且每种关系的强度也可能有所区别,仅通过一层关系图不能准确描述具有这类性质的多体系统。因此,该研究提出利用多层关系图拓扑结构来进行更精准的建模,不仅可以提升预测效果,同时也能用模型学到的多层关系图提供对模型预测的解释,一定程度上可以分析各智能体行为之间的因果关系。另外,由于多层关系图也可以提升模型在训练数据中没有包含的场景中取得更好的效果,增强泛化能力和训练样本效率(sample efficiency)。

模型介绍:该研究提出的模型 IMMA 由一个 Encoder 和一个 Decoder 组成。Encoder 的输入是一系列轨迹观测,输出是一个在隐空间内的交互关系图。Decoder 通过对推测出的关系图里面的信息进行传递和整合来生成对每个智能体未来轨迹的预测。下面 Figure 2 提供了模型示意图和进一步解释。模型整体基于 Conditional Variational Autoencoder 框架,并通过以下目标函数来训练模型参数:

w~深度学习~合集4_深度学习_25

三.实验结果和分析

本文的实验试图回答以下四个问题:

1. 本文方法 (IMMA w/ PLT) 是否在各种 social multiagent systems 数据集测试中始终优于已有的基准方法?

2.Multiplex attentional latent graph 的使用是否给模型提供了更多可解释性?

3. 模型中的每个模块对模型效果的提升有多大贡献?

4. 本文方法相比于基准方法是否可以提升 sample efficiency?它是否可以很好地泛化到新环境或新场景?

1. 该研究主要用以下三个数据集进行实验:Social Navigation Environment  (基于 ORCA), PHASE dataset 和 NBA dataset,所有数据集上的结果如 Table 1 所示。该研究发现使用 Multiplex attentional latent graph 和渐进层训练 (MG w/ PLT) ,结果在所有三个数据集上都优于已有的最强基线模型。 

w~深度学习~合集4_深度学习_26

对于 Social Navigation dataset 的可视化结果如下图所示。第一排表示预测轨迹。圆圈越小表示离当前时间点越远。最左边是真实的未来轨迹和交互关系图。第二排表示预测的 latent graph。智能体 i 和 j 之间的 relation 由 heatmap 中的第 i 行第 j 列的元素表示。RFM 错误地预测了智能体之间的关系——用箭头突出显示,绿色 agent 被错误地赋予了比蓝色 agent 更高的权重。因此,预测的轨迹偏离了事实。相反,本文方法准确地预测了交互关系和未来轨迹。

w~深度学习~合集4_深度学习_27

本文方法在 PHASE(左)和 NBA 数据集 (右) 的结果可视化如下图所示。在右侧的 NBA 图中,橙色代表篮球,不同轨迹颜色代表不同球队。    

w~深度学习~合集4_深度学习_28

 2. 以上实验证明本文方法可以更准确地预测运动轨迹,之后,该研究进一步探究了关系推理能力和对轨迹预测的影响。首先,本文对于关系推理更加准确 (见 Table 2),这不仅帮助模型预测运动轨迹,也提供了更好的 disentanglement 和可解释性。如下图所示,IMMA 中改变 agent 的 leader 会显著改变预测的轨迹,以新 leader 为目标,同时保持对其他 agent 的预测不变。而 RFM 生成的轨迹包括不切实际的转弯 (如红色 agent) 并且对其他 agent 的轨迹预测变差。

w~深度学习~合集4_深度学习_29

3. Ablation study 结果如 Table 3 所示,实验结果证明 multiplex attention graph 在模型中起到了至关重要的作用,逐层训练进一步提升了轨迹预测和关系推理准确度。

w~深度学习~合集4_深度学习_30

4. Table 4 显示 IMMA 的 zero-shot 泛化能力比基线方法更好。 

w~深度学习~合集4_深度学习_31

另外,下图显示相比最优基线方法 (RFM), IMMA 需要更少的训练数据就可以得到更好的结果。

w~深度学习~合集4_深度学习_32

 

四.结论

由于存在潜在的多层社会交互 (social interactions) 关系,多智能体系统 (multi-agent systems) 的动力学 (dynamics) 通常很复杂。智能体 (agent) 的行为可能会受到与其他每个智能体多种独立关系类型的影响,例如在物理系统中通常不存在的复杂属性 (意向性或合作关系)。本文提出了一种包含交互建模的预测方法 (IMMA),该方法使用 multiplex latent graph 作为隐空间表征 (latent representation) 来建模这种多层交互类型可能产生的行为。本文方法在行为和轨迹建模以及关系推理方面均优于其他最先进的方法,并有很强的可解释性 (interpretability) 和泛化能力 (zero-shot generalizability)。

 








#深度强化学习中SAC算法

数学原理、网络架构及其PyTorch实现

本文详细介绍了深度强化学习中的软演员-评论家算法(SAC),包括其数学原理、网络架构设计以及PyTorch实现。

深度强化学习是人工智能领域最具挑战性的研究方向之一,其设计理念源于生物学习系统从经验中优化决策的机制。在众多深度强化学习算法中,软演员-评论家算法(Soft Actor-Critic, SAC)因其在样本效率、探索效果和训练稳定性等方面的优异表现而备受关注。

传统的深度强化学习算法往往在探索-利用权衡、训练稳定性等方面面临挑战。SAC算法通过引入最大熵强化学习框架,在策略优化过程中自动调节探索程度,有效解决了这些问题。其核心创新在于将熵最大化作为策略优化的额外目标,在保证收敛性的同时维持策略的多样性。

本文将系统阐述SAC算法的技术细节,主要包括:

  1. 基于最大熵框架的SAC算法数学原理
  2. 演员网络与评论家网络的具体架构设计
  3. 基于PyTorch的详细实现方案
  4. 网络训练的关键技术要点

SAC算法采用演员-评论家架构,演员网络负责生成动作策略,评论家网络评估动作价值。通过两个网络的协同优化,实现策略的逐步改进。整个训练过程中,演员网络致力于最大化评论家网络预测的Q值,同时保持适度的策略探索;评论家网络则不断优化其Q值估计的准确性。

接下来,我们将从演员网络的数学原理开始,详细分析SAC算法的各个技术组件:

演员(策略)网络

演员是由参数φ确定的策略网络,表示为:

w~深度学习~合集4_深度学习_33

这是一个基于状态输出动作的随机策略。它使用神经网络估计均值和对数标准差,从而得到给定状态下动作的分布及其对数概率。对数概率用于熵正则化,即目标函数中包含一个用于最大化概率分布广度(熵)的项,以促进智能体的探索行为。关于熵正则化的具体内容将在后文详述。演员网络的架构如图所示:

w~深度学习~合集4_深度学习_34

均值μ(s)和对数σ(s)用于动作采样:

w~深度学习~合集4_深度学习_35

其中N表示正态分布。但这个操作存在梯度不可微的问题,需要通过重参数化技巧来解决。

w~深度学习~合集4_深度学习_36

这里d表示动作空间维度,每个分量ε_i从标准正态分布(均值0,标准差1)中采样。应用重参数化技巧:

w~深度学习~合集4_深度学习_37

这样就解决了梯度截断问题。接下来通过激活函数将x_t转换为标准化动作:

w~深度学习~合集4_深度学习_38

该转换确保动作被限制在[-1,1]区间内。

动作对数概率计算

完成动作计算后,就可以计算奖励和预期回报。演员的损失函数中还包含熵正则化项,用于最大化分布的广度。计算采样动作𝑎_t的对数概率Log(π_ϕ)时,从预tanh变换x_t开始分析更为便利。

由于x_t来自均值μ(s)和标准差σ(s)的高斯分布,其概率密度函数(PDF)为:

w~深度学习~合集4_深度学习_39

其中各独立分量x_t,i的分布为:

w~深度学习~合集4_深度学习_40

对两边取对数可简化PDF:

w~深度学习~合集4_深度学习_41

要将其转换为log(π_ϕ),需要考虑x_t到a_t的tanh变换,这可通过微分链式法则实现:

w~深度学习~合集4_深度学习_42

这个关系的推导基于概率守恒原理:两个变量在给定区间内的概率必须相等:

w~深度学习~合集4_深度学习_43

其中a_i = tanh(x_i)。将区间缩小到无穷小的dx和da:

w~深度学习~合集4_深度学习_44

tanh的导数形式为:

w~深度学习~合集4_深度学习_45

代入得到:

w~深度学习~合集4_深度学习_46

最终可得完整表达式:

w~深度学习~合集4_深度学习_47

至此完成了演员部分的推导,这里有动作又有对数概率,就可以进行损失函数的计算。下面是这些数学表达式的PyTorch实现:

import gymnasium as gym    
 from src.utils.logger import logger    
 from src.models.callback import PolicyGradientLossCallback    
 from pydantic import Field, BaseModel, ConfigDict    
 from typing import Dict, List    
 import numpy as np    
 import os    
 from pathlib import Path    
 import torch    
 import torch.nn as nn    
 import torch.optim as optim    
 import torch.nn.functional as F    
 from torch.distributions import Normal   

 

 '''演员网络:估计均值和对数标准差用于熵正则化计算'''    

 class Actor(nn.Module):    
  def __init__(self,state_dim,action_dim):  
  super(Actor,self).__init__()  

  self.net = nn.Sequential(  
  nn.Linear(state_dim, 100),  
  nn.ReLU(),  
  nn.Linear(100,100),  
  nn.ReLU()  
  )  
  self.mean_linear = nn.Linear(100, action_dim)  
  self.log_std_linear = nn.Linear(100, action_dim)  

  def forward(self, state):  
  x = self.net(state)  
  mean = self.mean_linear(x)  
  log_std =self.log_std_linear(x)  
  log_std = torch.clamp(log_std, min=-20, max=2)  
  return mean, log_std  

  def sample(self, state):  
  mean, log_std = self.forward(state)  
  std = log_std.exp()  
  normal = Normal(mean, std)  
  x_t = normal.rsample() # 重参数化技巧  
  y_t = torch.tanh(x_t)  
  action = y_t  
  log_prob = normal.log_prob(x_t)  
  log_prob -= torch.log(1-y_t.pow(2)+1e-6)  
  log_prob = log_prob.sum(dim=1, keepdim =True)  

  return action, log_prob

在讨论损失函数定义和演员网络的训练过程之前,需要先介绍评论家网络的数学原理。

评论家网络

评论家网络的核心功能是估计状态-动作对的预期回报(Q值)。这些估计值在训练过程中为演员网络提供指导。评论家网络采用双网络结构,分别提供预期回报的两个独立估计,并选取较小值作为最终估计。这种设计可以有效避免过度估计偏差,同时提升训练稳定性。其结构如图所示:

w~深度学习~合集4_深度学习_48

需要说明的是,此时的示意图是简化版本,主要用于理解演员和评论家网络的基本角色,暂不考虑训练稳定性的细节。另外,"智能体"实际上是演员和评论家网络的统称而非独立实体,图中分开表示只是为了清晰展示结构。假设评论家网络暂不需要训练,因为这样可以专注于如何利用评论家网络估计的Q值来训练演员网络。演员网络的损失函数表达式为:

w~深度学习~合集4_深度学习_49

更常见的形式是:

w~深度学习~合集4_深度学习_50

其中ρ_D表示状态分布。损失函数通过对所有动作空间和状态空间的熵项与Q值进行积分得到。但在实际应用中,无法直接获取完整的状态分布,因此ρ_D实际上是基于重放缓冲区样本的经验状态分布,期望其能较好地表征整体状态分布特征。

基于该损失函数可以通过反向传播对演员网络进行训练。以下是评论家网络的PyTorch实现:

'''评论家网络:定义q1和q2'''    
 class Critic(nn.Module):    
  def __init__(self, state_dim, action_dim):  
  super(Critic, self).__init__()  

  # Q1网络架构  
  self.q1_net = nn.Sequential(  
  nn.Linear(state_dim + action_dim, 256),  
  nn.ReLU(),  
  nn.Linear(256, 256),  
  nn.ReLU(),  
  nn.Linear(256, 1),  
  )  

  # Q2网络架构  
  self.q2_net = nn.Sequential(  
  nn.Linear(state_dim + action_dim, 256),  
  nn.ReLU(),  
  nn.Linear(256, 256),  
  nn.ReLU(),  
  nn.Linear(256, 1),  
  )  

  def forward(self, state, action):  
  sa = torch.cat([state, action], dim=1)  
  q1 = self.q1_net(sa)  
  q2 = self.q2_net(sa)  
  return q1, q2

前述内容尚未涉及评论家网络自身的训练机制。从重放缓冲区采样的每个数据点包含[s_t, s_{t+1}, a_t, R]。对于状态-动作对的Q值,我们可以获得两种不同的估计。

第一种方法是直接将a_t和s_t输入评论家网络:

w~深度学习~合集4_深度学习_51

第二种方法是基于贝尔曼方程:

w~深度学习~合集4_深度学习_52

这种方法使用s_t+1、a_t+1以及执行动作a_t获得的奖励来重新估计。这里使用目标网络而非第一种方法中的评论家网络进行估计。采用目标评论家网络的主要目的是解决训练不稳定性问题。如果同一个评论家网络同时用于生成当前状态和下一状态的Q值(用于目标Q值),这种耦合会导致网络更新在目标计算的两端产生不一致的传播,从而引起训练不稳定。因此引入独立的目标网络为下一状态的Q值提供稳定估计。目标网络作为评论家网络的缓慢更新版本,确保目标Q值能够平稳演化。具体结构如图所示:

w~深度学习~合集4_深度学习_53

评论家网络的损失函数定义为:

w~深度学习~合集4_深度学习_54

通过该损失函数可以利用反向传播更新评论家网络,而目标网络则采用软更新机制:

w~深度学习~合集4_深度学习_55

其中ε是一个较小的常数,用于限制目标评论家的更新幅度,从而维持训练稳定性。

完整流程

以上内容完整阐述了SAC智能体的各个组件。下图展示了完整SAC智能体的结构及其计算流程:

w~深度学习~合集4_深度学习_56

下面是一个综合了前述演员网络、评论家网络及其更新机制的完整SAC智能体实现

'''SAC智能体的实现:整合演员网络和评论家网络'''    

 class SACAgent:    
  def __init__(self, state_dim, action_dim, learning_rate, device):  
  self.device = device  

  self.actor = Actor(state_dim, action_dim).to(device)  
  self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate)  

  self.critic = Critic(state_dim, action_dim).to(device)  
  self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=learning_rate)  

  # 目标网络初始化  
  self.critic_target = Critic(state_dim, action_dim).to(device)  
  self.critic_target.load_state_dict(self.critic.state_dict())  

  # 熵温度参数  
  self.target_entropy = -action_dim   
  self.log_alpha = torch.zeros(1, requires_grad=True, device=device)  
  self.alpha_optimizer = optim.Adam([self.log_alpha], lr=learning_rate)  

  def select_action(self, state, evaluate=False):  
  state = torch.FloatTensor(state).to(self.device).unsqueeze(0)  
  if evaluate:  
  with torch.no_grad():  
  mean, _ = self.actor(state)  
  action = torch.tanh(mean)  
  return action.cpu().numpy().flatten()  
  else:  
  with torch.no_grad():  
  action, _ = self.actor.sample(state)  
  return action.cpu().numpy().flatten()  

  def update(self, replay_buffer, batch_size=256, gamma=0.99, tau=0.005):  
  # 从经验回放中采样训练数据  
  batch = replay_buffer.sample_batch(batch_size)  
  state = torch.FloatTensor(batch['state']).to(self.device)  
  action = torch.FloatTensor(batch['action']).to(self.device)  
  reward = torch.FloatTensor(batch['reward']).to(self.device)  
  next_state = torch.FloatTensor(batch['next_state']).to(self.device)  
  done = torch.FloatTensor(batch['done']).to(self.device)  

  # 评论家网络更新  
  with torch.no_grad():  
  next_action, next_log_prob = self.actor.sample(next_state)  
  q1_next, q2_next = self.critic_target(next_state, next_action)  
  q_next = torch.min(q1_next, q2_next) - torch.exp(self.log_alpha) * next_log_prob  
  target_q = reward + (1 - done) * gamma * q_next  

  q1_current, q2_current = self.critic(state, action)  
  critic_loss = F.mse_loss(q1_current, target_q) + F.mse_loss(q2_current, target_q)  

  self.critic_optimizer.zero_grad()  
  critic_loss.backward()  
  self.critic_optimizer.step()  

  # 演员网络更新  
  action_new, log_prob = self.actor.sample(state)  
  q1_new, q2_new = self.critic(state, action_new)  
  q_new = torch.min(q1_new, q2_new)  
  actor_loss = (torch.exp(self.log_alpha) * log_prob - q_new).mean()  

  self.actor_optimizer.zero_grad()  
  actor_loss.backward()  
  self.actor_optimizer.step()  

  # 温度参数更新  
  alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()  

  self.alpha_optimizer.zero_grad()  
  alpha_loss.backward()  
  self.alpha_optimizer.step()  

  # 目标网络软更新  
  for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):  
  target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
总结

本文系统地阐述了SAC算法的数学基础和实现细节。通过对演员网络和评论家网络的深入分析,我们可以看到SAC算法在以下几个方面具有显著优势:

理论框架

  • 基于最大熵强化学习的理论基础保证了算法的收敛性
  • 双Q网络设计有效降低了值函数估计的过度偏差
  • 自适应温度参数实现了探索-利用的动态平衡

实现特点

  • 采用重参数化技巧确保了策略梯度的连续性
  • 软更新机制提升了训练稳定性
  • 基于PyTorch的向量化实现提高了计算效率

实践价值

  • 算法在连续动作空间中表现优异
  • 样本效率高,适合实际应用场景
  • 训练过程稳定,调参难度相对较小

未来研究可以在以下方向继续深化:

  • 探索更高效的策略表达方式
  • 研究多智能体场景下的SAC算法扩展
  • 结合迁移学习提升算法的泛化能力
  • 针对大规模状态空间优化网络架构

强化学习作为人工智能的核心研究方向之一,其理论体系和应用场景都在持续发展。深入理解算法的数学原理和实现细节,将有助于我们在这个快速演进的领域中把握技术本质,开发更有效的解决方案。作者:Najib Sharifi, Ph.D









#xxxx

#xxxx