DataParallel & DistributedDataParallel分布式训练
参考博客 《DataParallel & DistributedDataParallel分布式训练》:
细节参考博客(推荐)
###DDP
# 引入包
import argparse
import torch.distributed as dist
# 设置可选参数
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=0, type=int,
help='node rank for distributed training')
args = parser.parse_args()
# print(args.local_rank)
dist.init_process_group(backend='nccl')
# 1.上面讲到的初始化进程组
dist.init_process_group(backend='nccl')
torch.cuda.set_device(args.local_rank)
# 2.使用DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle = (train_sampler is None), sampler=train_sampler, pin_memory=False)
# 3.创建DDP模型进行分布式训练
model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
# 4.命令行开始训练 --nproc_per_node参数指定为当前主机创建的进程数(比如我当前可用但卡数是2 那就为这个主机创建两个进程,每个进程独立执行训练脚本)
# 我是单机多卡, 所以nnode=1, 就是一台主机, 一台主机上--nproc_per_node个进程
python -m torch.distributed.run --nnodes=1 --nproc_per_node=2 --node_rank=0 --master_port=6005 train.py
使用DP或者DDP在保存和使用模型时需要注意的地方
《使用DP或者DDP在保存和使用模型时需要注意的地方》:
在保存模型的时候建议用net.module.state_dict(),这是因为如果裁剪了DP或者DDP,网络结构变为nn.Sequential()这种数据类型了,而完美常用的保存方式是:
net = torch.nn.Linear(10,1)
# 先构造一个网络
net = torch.nn.DataParallel(net, device_ids=[0,3])
torch.save(net.module.state_dict(), './tmp.pth')
有了上述的知识基础,在加载模型的时候建议先用
def get_bare_model(net):
if isinstance(net, (nn.DataParallel, nn.parallel.DistributedDataParalleled)):
net = net.module
return net
清澈的爱,只为中国