PyTorch的distributed训练已经逐渐代替nn.Dataparallel的方式,因为官方对其有更好的支持,并且训练速度更快。大家可能知道一种启动方法,就是用torch.launch启动。但有没有被冗长的代码段惹的不开心呢。今天阿杰为大家带来一种更简单的启动方式,那就是torch.multiprocessing
Note: torch.multiprocessing的启动和用torch.launch本质是一样的,就是单纯的代码量少。
使用方法
使用头文件
import torch.multiprocessing as mp
import torch.distributed as dist
大家需要知道的是,分布式多进程是每一个进程都要定义模型,定义任何训练需要的东西。所以在主进程,只完成超参数定义,以及启动多进程所需要的设置即可。
if __name__ == '__main__':
# import configuration file
# load json or yaml, argsparse
args = xxxxx
# 接下来是设置多进程启动的代码
# 1.首先设置端口,采用随机的办法,被占用的概率几乎很低.
port_id = 10000 + np.random.randint(0, 1000)
args.dist_url = 'tcp://127.0.0.1:' + str(port_id)
# 2. 然后统计能使用的GPU,决定我们要开几个进程,也被称为world size
args.num_gpus = torch.cuda.device_count()
if args.num_gpus == 1:
main_worker(rank=0, args=args)
else:
# 3. 多进程的启动
torch.multiprocessing.set_start_method('spawn')
mp.spawn(main_worker, nprocs=args.num_gpus, args=(args,))
我必须要详细介绍一下mp.spawn第一个参数是一个函数,这个函数将执行训练的所有步骤。从这一步开始,python将建立多个进程,每个进程都会执行main_worker函数。 第二个参数是开启的进程数目。第三个参数是main_worker的函数实参。
然后看main_worker的定义,特别注意一下。我们送入的两个实参,但实际形参有两个。没错,第一个形参是进程id号(必须要多加一个形参,且放到第一个位置上)。id号是从0到(总进程数目-1)的。id为0的进程我们就叫做主进程。之所以需要区分进程,因为我们一般打印日志和存权重文件,总不会希望每个进程都做一次相同的事情吧。我们只在主进程完成这个事情就行了(用if 判断一下)。
def main_worker(gpu, args): # gpu参数控制进程号
接下来就是常规操作
需要完成Pytorch分布式必须写的几段代码。
args.rank = gpu # 用rank记录进程id号
dist.init_process_group(backend='nccl', init_method=args.dist_url, world_size=args.num_gpus,
rank=args.rank)
torch.cuda.set_device(gpu) # 设置默认GPU 最好方法哦init之后,这样你使用.cuda(),数据就是去指定的gpu上了
# 定义模型, 转同步BN
model = xxx
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu],
find_unused_parameters=True )
#定义数据集
train_dataset = xxxx
# 注意这一步,和单卡的唯一区别。这个sample能保证多个进程不会取重复的数据。shuffle必须设置为False(默认)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=args.batch_size,
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
接下来就和单卡的操作一样了。执行的之后,直接用普通命令就可以执行多进程分布式训练了。
python xxx.py
而不需要使用torch.launch.xxxxx 一大串东西了。
最好友情提示一下,在打印日志或者保存模型,需要在主进程下执行哦
if args.rank == 0 : # 主进程
torch.save(xx)
print(log)