DataParallel & DistributedDataParallel分布式训练

参考博客 《DataParallel & DistributedDataParallel分布式训练》:

细节参考博客(推荐)

###DDP

# 引入包
import argparse
import torch.distributed as dist

# 设置可选参数
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=0, type=int,
help='node rank for distributed training')
args = parser.parse_args()
# print(args.local_rank)
dist.init_process_group(backend='nccl')


# 1.上面讲到的初始化进程组
dist.init_process_group(backend='nccl')
torch.cuda.set_device(args.local_rank)

# 2.使用DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle = (train_sampler is None), sampler=train_sampler, pin_memory=False)

# 3.创建DDP模型进行分布式训练
model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)

# 4.命令行开始训练 --nproc_per_node参数指定为当前主机创建的进程数(比如我当前可用但卡数是2 那就为这个主机创建两个进程,每个进程独立执行训练脚本)
# 我是单机多卡, 所以nnode=1, 就是一台主机, 一台主机上--nproc_per_node个进程
python -m torch.distributed.run --nnodes=1 --nproc_per_node=2 --node_rank=0 --master_port=6005 train.py

使用DP或者DDP在保存和使用模型时需要注意的地方

《使用DP或者DDP在保存和使用模型时需要注意的地方》:

在保存模型的时候建议用net.module.state_dict(),这是因为如果裁剪了DP或者DDP,网络结构变为nn.Sequential()这种数据类型了,而完美常用的保存方式是:

net = torch.nn.Linear(10,1)
# 先构造一个网络
net = torch.nn.DataParallel(net, device_ids=[0,3])
torch.save(net.module.state_dict(), './tmp.pth')

有了上述的知识基础,在加载模型的时候建议先用

def get_bare_model(net):    
if isinstance(net, (nn.DataParallel, nn.parallel.DistributedDataParalleled)):
net = net.module
return net


清澈的爱,只为中国