SparkNet: Training Deep Network in Spark
训练深度神经网络是一个非常耗时的过程,比如用卷积神经网络去训练一个目标识别任务需要好几天来训练。因此,充分利用集群的资源,加快训练速度成了一个非常重要的领域。不过,当前非常热门的批处理计算架构(例如:MapReduce 和 Spark)都不是设计用来专门支持异步计算和现有的一些通信密集型的深度学习系统。
SparkNet 是基于Spark的深度神经网络架构,
- 它提供了便捷的接口能够去访问Spark RDDs;
- 同时提供Scala接口去调用caffe;
- 还拥有一个轻量级的tensor 库;
- 使用了一个简单的并行机制来实现SGD的并行化,使得SparkNet能够很好的适应集群的大小并且能够容忍极高的通信延时;
- 它易于部署,并且不需要对参数进行调整;
- 它还能很好的兼容现有的caffe模型;
下面这张图是SparkNet的架构:
从上图可以看出,Master 向每个worker 分发任务之后,各个worker都单独的使用Caffe(利用GPU)来进行训练。每个worker完成任务之后,把参数传回Master。论文用了5个节点的EC2集群,broadcast 和 collect 参数(每个worker几百M),耗时20秒,而一个minibatch的计算时间是2秒。
Implementation
SparkNet 是建立在Apache Spark和Caffe深度学习库的基础之上的。SparkNet 用Java来访问Caffe的数据,用Scala来访问Caffe的参数,用ScalaBuff来使得Caffe网络在运行时保持动态结构。SparkNet能够兼容Caffe的一些模型定义文件,并且支持Caffe模型参数的载入。
下面简单贴一下SparkNet的api和模型定义、模型训练代码。
并行化的SGD
为了让模型能够在带宽受限的环境下也能运行得很好,论文提出了一种SGD的并行化机制使得最大幅度减小通信,这也是全文最大了亮点。这个方法也不是只针对SGD,实际上对Caffe的各种优化求解方法都有效。
在将SparkNet的并行化机制之前,先介绍一种Naive的并行机制。
Naive SGD Parallelization
Spark拥有一个master节点和一些worker节点。数据分散在各个worker中的。
在每一次的迭代中,Spark master节点都会通过broadcast(广播)的方式,把模型参数传到各个worker节点中。
各个worker节点在自己分到的部分数据,在同一个模型上跑一个minibatch 的SGD。
完成之后,各个worker把训练的模型参数再发送回master,master将这些参数进行一个平均操作,作为新的(下一次迭代)的模型参数。
这是很多人都会采用的方法,看上去很对,不过它有一些缺陷。
Naive 并行化的缺陷
这个缺陷就是需要消耗太多的通信带宽,因为每一次minibatch训练都要broadcast 和 collect 一次,而这个过程特别消耗时间(20秒左右)。
令 Na(b) 表示,在batch-size为 b 的情况下,到达准确率 a 所需要的迭代次数。
令 C(b) 表示,在batch-size 为 b 的情况下,SGD训练一个batch的训练时间(约2秒)。
显然,使用SGD达到准确率为a所需要的时间消耗是:
Na(b)C(b)
假设有K个机器,通信(broadcast 和 collect)的时间为 S,那么Naive 并行 SGD
的时间消耗就是:
Na(b)(C(b)/K+S)
SparkNet 的并行化机制
基本上过程和Naive 并行化差不多。唯一的区别在于,各个worker节点每次不再只跑一个迭代,而是在自己分到的一个minibatch数据集上,迭代多次,迭代次数是一个固定值τ。
SparkNet的并行机制是分好几个rounds来跑的。在每一个round中,每个机器都在batch size为b 的数据集上跑 τ
我们用Ma(b,K,τ) 表示达到准确率 a 所需要的 round 次数。
因此,SparkNet需要的时间消耗就是:
Ma(b,K,τ)∗(τC(b)+S)
下面这张图,很直观的对比了Naive 并行机制跟 SparkNet 并行机制的区别:
Naive 并行机制:
SparkNet 并行机制:
论文还做了各种对比实验,包括时间,准确率等。实验模型采用AlexNet,数据集是ImageNet的子集(100类,每类1000张)。
假设S=0,那么τMa(b,K,τ)/Na(b) 就是SparkNet的加速倍数。论文通过改变τ 和 K
上面的表格还是体现了一些趋势的:
(1). 看第一行,当K=1,因为只有一个worker节点,所以异步计算的τ这时并没有起到什么作用,可以看到第一行基本的值基本都是接近1.
(2). 看最右边这列,当τ=1,这其实就相当于是Naive 并行机制,只不过,Naive的batch是b/K,这里是b. 这一列基本上是跟K成正比。
(3). 注意到每一行的值并不是从左到右一直递增的。
当S!=0
可以看到,当S接近与0的时候(带宽高),Naive会比SparkNet速度更快,但是,当S 变大(带宽受限),SparkNet的性能将超过Naive,并且可以看出,Naive受S变化剧烈, 而SparkNet相对平稳。
而作者实验用EC2环境,S大概是20秒,所以,显然,SparkNet会比Naive好很多。
论文还做了一些事情,比如:
令 τ=50,分别测试K=1、3、5、10时,准确率与时间的关系;
当K=5,分别测试τ=20、50、100、150时,准确率与时间的关系。
总结一下,这篇论文其实没有太多复杂的创新(除了SGD并行化时候的一点小改进),不过我很期待后续的工作,同时也希望这个SparkNet能够维护的越来越好。有时间的话,还是很想试试这个SparkNet的