Resnet50的细节讲解

残差神经网络 (ResNet)也是需要掌握的模型,需要自己手动实现理解细节。本文就是对代码的细节讲解,话不多说,开始了。

pytorch如何调用SSIM loss pytorch如何调用resnet_深度学习

首先你需要了解它的结构。本文围绕resnet50进行讲解:网络的输入图片大小是224x224,依次经过conv1、conv2、conv3、conv4、conv5,最后再经过平均池化和全连接层。由于中间有重复使用的模块,所以我们需要将它们写成一个类,以便重复调用。

Resnet之所以能够训练那么深的原因就是它的结构,在不断向后训练的过程中依旧保留浅层特征,我个人的理解,正常情况下经过一次又一次的卷积,浅层的特征逐渐消失,然而Resnet在向后训练的过程中不断的加上前面浅层的特征,这样更加丰富特征的全局性

pytorch如何调用SSIM loss pytorch如何调用resnet_2d_02


这张图就是我们需要写模块的根据,它们的区别是卷积层数目的不同,我们本文讲解resnet50,所以是以右边的版块为例。先看代码

class Block(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        in_channels: number of channels of the input feature map.
        filters: three output channel counts, one for each convolution.
        stride: stride of the first 1x1 convolution and of the projection
            shortcut (when present).
        is_1x1conv: if True, the shortcut is a 1x1 convolution followed by
            batch norm; if False, the input is added back unchanged.
    """

    def __init__(self, in_channels, filters, stride=1, is_1x1conv=False):
        super().__init__()
        filter1, filter2, filter3 = filters
        self.is_1x1conv = is_1x1conv
        self.relu = nn.ReLU(inplace=True)

        # Main path layers, listed in the order they are applied.
        squeeze = [
            nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(filter1),
            nn.ReLU(),
        ]
        spatial = [
            nn.Conv2d(filter1, filter2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(filter2),
            nn.ReLU(),
        ]
        # No ReLU here: the activation is applied after the shortcut is added.
        expand = [
            nn.Conv2d(filter2, filter3, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filter3),
        ]
        self.conv1 = nn.Sequential(*squeeze)
        self.conv2 = nn.Sequential(*spatial)
        self.conv3 = nn.Sequential(*expand)

        if is_1x1conv:
            # Projection shortcut: brings the input to filter3 channels and
            # applies the same stride as the main path.
            projection = [
                nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filter3),
            ]
            self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        out = self.conv3(self.conv2(self.conv1(x)))
        identity = self.shortcut(x) if self.is_1x1conv else x
        return self.relu(out + identity)

我们将其写成一个类,这样多次使用方便。

def __init__(self, in_channels, filters, stride=1, is_1x1conv=False):

in_channels是输入通道数。filter1, filter2, filter3 = filters 这样写是为了与板块的结构对应:一个板块依次进行3次卷积,这三个值分别是这3次卷积的输出通道数。is_1x1conv 控制捷径(shortcut)的形式:为True时,输入的浅层特征图只经过一次1x1卷积(加BatchNorm)就与主路径的结果相加;为False(默认值)时,输入不做任何处理直接相加。主路径则始终经过三次卷积。这里有一点细节:我在写self.conv3 这个卷积的时候没有加上ReLU()函数,因为需要先把主路径的输出与捷径的输出相加,相加之后才能一起做ReLU。

self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride),
                nn.BatchNorm2d(filter3)

这段代码就是特征图捷径,浅层特征图就经历一次卷积直接与进行三次卷积之后的特征图相加

def _make_layer(self, in_channels, filters, blocks, stride=1):
    """Stack ``blocks`` Block units into one nn.Sequential stage.

    The first unit uses the projection shortcut and the given ``stride``;
    the remaining ``blocks - 1`` units take ``filters[2]`` input channels,
    use stride 1 and add their input back unchanged.
    """
    block_1 = Block(in_channels, filters, stride=stride, is_1x1conv=True)
    layers = [block_1]
    layers.extend(
        Block(filters[2], filters, stride=1, is_1x1conv=False)
        for _ in range(1, blocks)
    )
    return nn.Sequential(*layers)

写这个函数就是利用for循环多次使用重复的板块,这里也有一些细节

block_1 = Block(in_channels, filters, stride=stride, is_1x1conv=True)
 layers.append(Block(filters[2], filters, stride=1, is_1x1conv=False))

每个板块的第一次卷积和后面的卷积stride的设置是不同的,主要是板块从conv3开始第一次就是进行stride=2的设定,这样经过这层卷积,特征图的大小变为原来的二分之一。

self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = self._make_layer(64, (64, 64, 256), Layers[0])
        self.conv3 = self._make_layer(256, (128, 128, 512), Layers[1], 2)
        self.conv4 = self._make_layer(512, (256, 256, 1024), Layers[2], 2)
        self.conv5 = self._make_layer(1024, (512, 512, 2048), Layers[3], 2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(2048, 1000)
        )

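这便是整个resnet50的网络设置了,我在上一篇pytorch实现inception模型有讲到如何计算特征图经过卷积之后的大小。举个例子,输入的图像为224x224,经过conv1之后的大小为 $\lfloor (224 + 2 \times 3 - 7) / 2 \rfloor + 1 = 112$,通用公式是 $\lfloor (n + 2 \times \text{padding} - \text{kernel}) / \text{stride} \rfloor + 1$(除法向下取整)。整个的特征图变化如下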

pytorch如何调用SSIM loss pytorch如何调用resnet_2d_03

self.conv2 = self._make_layer(64, (64, 64, 256), Layers[0])

Layers[0]是我之前提前设置好的需要重复的次数:Layers = [3, 4, 6, 3]。整个的代码流程是这样的

import torch
import torch.nn as nn

# Number of Block units in each of the four stages (conv2..conv5); read by
# Resnet50 when it builds its stages.
Layers = [3, 4, 6, 3]
class Block(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        in_channels: number of channels of the input feature map.
        filters: three output channel counts, one for each convolution.
        stride: stride of the first 1x1 convolution and of the projection
            shortcut (when present).
        is_1x1conv: if True, the shortcut is a 1x1 convolution followed by
            batch norm; if False, the input is added back unchanged.
    """

    def __init__(self, in_channels, filters, stride=1, is_1x1conv=False):
        super().__init__()
        filter1, filter2, filter3 = filters
        self.is_1x1conv = is_1x1conv
        self.relu = nn.ReLU(inplace=True)

        # Main path layers, listed in the order they are applied.
        squeeze = [
            nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(filter1),
            nn.ReLU(),
        ]
        spatial = [
            nn.Conv2d(filter1, filter2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(filter2),
            nn.ReLU(),
        ]
        # No ReLU here: the activation is applied after the shortcut is added.
        expand = [
            nn.Conv2d(filter2, filter3, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filter3),
        ]
        self.conv1 = nn.Sequential(*squeeze)
        self.conv2 = nn.Sequential(*spatial)
        self.conv3 = nn.Sequential(*expand)

        if is_1x1conv:
            # Projection shortcut: brings the input to filter3 channels and
            # applies the same stride as the main path.
            projection = [
                nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filter3),
            ]
            self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        out = self.conv3(self.conv2(self.conv1(x)))
        identity = self.shortcut(x) if self.is_1x1conv else x
        return self.relu(out + identity)


class Resnet50(nn.Module):
    """ResNet-50 classifier assembled from bottleneck Block units.

    Layout: 7x7 stem convolution -> max pool -> four stages (conv2..conv5)
    whose depths come from the module-level ``Layers`` list -> global average
    pool -> linear classifier.

    Args:
        num_classes: size of the classifier output. Defaults to 1000, the
            value that was previously hard-coded.
    """

    def __init__(self, num_classes=1000):
        super(Resnet50, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # conv2 keeps the default stride of 1; conv3..conv5 use stride 2 in
        # their first block.
        self.conv2 = self._make_layer(64, (64, 64, 256), Layers[0])
        self.conv3 = self._make_layer(256, (128, 128, 512), Layers[1], 2)
        self.conv4 = self._make_layer(512, (256, 256, 1024), Layers[2], 2)
        self.conv5 = self._make_layer(1024, (512, 512, 2048), Layers[3], 2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(2048, num_classes)
        )

    def forward(self, input):
        """Return the classifier output for an image batch ``input``."""
        x = self.conv1(input)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.avgpool(x)
        # Flatten everything after the batch dimension before the linear layer.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def _make_layer(self, in_channels, filters, blocks, stride=1):
        """Build one stage: a projection Block followed by ``blocks - 1`` identity Blocks."""
        layers = [Block(in_channels, filters, stride=stride, is_1x1conv=True)]
        for _ in range(1, blocks):
            layers.append(Block(filters[2], filters, stride=1, is_1x1conv=False))

        return nn.Sequential(*layers)

# Sanity check: push a random batch through the network one top-level child
# at a time and print the output shape after each step.
net = Resnet50()
x = torch.rand((10, 3, 224, 224))
for name, layer in net.named_children():
    if name == "fc":
        # The linear classifier needs a (batch, features) tensor.
        x = x.view(x.size(0), -1)
    x = layer(x)
    print(name, 'output shape:', x.shape)

训练效果

pytorch如何调用SSIM loss pytorch如何调用resnet_ide_04


可以看出明显过拟合了,因为数据集的数量太小,resnet50的框架比较大,这里为了方便训练,所以还是使用的cifar 10分类数据集

训练代码 可直接运行resnet50网络

import torch
import torch.nn as nn

# Number of Block units in each of the four stages (conv2..conv5); read by
# Resnet50 when it builds its stages.
Layers = [3, 4, 6, 3]
class Block(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        in_channels: number of channels of the input feature map.
        filters: three output channel counts, one for each convolution.
        stride: stride of the first 1x1 convolution and of the projection
            shortcut (when present).
        is_1x1conv: if True, the shortcut is a 1x1 convolution followed by
            batch norm; if False, the input is added back unchanged.
    """

    def __init__(self, in_channels, filters, stride=1, is_1x1conv=False):
        super().__init__()
        filter1, filter2, filter3 = filters
        self.is_1x1conv = is_1x1conv
        self.relu = nn.ReLU(inplace=True)

        # Main path layers, listed in the order they are applied.
        squeeze = [
            nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(filter1),
            nn.ReLU(),
        ]
        spatial = [
            nn.Conv2d(filter1, filter2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(filter2),
            nn.ReLU(),
        ]
        # No ReLU here: the activation is applied after the shortcut is added.
        expand = [
            nn.Conv2d(filter2, filter3, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filter3),
        ]
        self.conv1 = nn.Sequential(*squeeze)
        self.conv2 = nn.Sequential(*spatial)
        self.conv3 = nn.Sequential(*expand)

        if is_1x1conv:
            # Projection shortcut: brings the input to filter3 channels and
            # applies the same stride as the main path.
            projection = [
                nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filter3),
            ]
            self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        out = self.conv3(self.conv2(self.conv1(x)))
        identity = self.shortcut(x) if self.is_1x1conv else x
        return self.relu(out + identity)


class Resnet50(nn.Module):
    """ResNet-50 classifier assembled from bottleneck Block units.

    Layout: 7x7 stem convolution -> max pool -> four stages (conv2..conv5)
    whose depths come from the module-level ``Layers`` list -> global average
    pool -> linear classifier with ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        super().__init__()
        stem = [
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        ]
        self.conv1 = nn.Sequential(*stem)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # conv2 keeps the default stride of 1; conv3..conv5 use stride 2 in
        # their first block.
        self.conv2 = self._make_layer(64, (64, 64, 256), Layers[0])
        self.conv3 = self._make_layer(256, (128, 128, 512), Layers[1], stride=2)
        self.conv4 = self._make_layer(512, (256, 256, 1024), Layers[2], stride=2)
        self.conv5 = self._make_layer(1024, (512, 512, 2048), Layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(nn.Linear(2048, num_classes))

    def forward(self, input):
        """Return the classifier output for an image batch ``input``."""
        feats = self.maxpool(self.conv1(input))
        for stage in (self.conv2, self.conv3, self.conv4, self.conv5):
            feats = stage(feats)
        pooled = self.avgpool(feats)
        # Flatten everything after the batch dimension before the linear layer.
        return self.fc(torch.flatten(pooled, 1))

    def _make_layer(self, in_channels, filters, blocks, stride=1):
        """Build one stage: a projection Block followed by ``blocks - 1`` identity Blocks."""
        block_1 = Block(in_channels, filters, stride=stride, is_1x1conv=True)
        layers = [block_1]
        layers.extend(
            Block(filters[2], filters, stride=1, is_1x1conv=False)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*layers)





import time
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


def load_dataset(batch_size):
    """Build shuffled CIFAR-10 train and test loaders.

    The dataset is downloaded to ``data/cifar-10`` if needed and images are
    converted with ``transforms.ToTensor()``.

    Args:
        batch_size: batch size used by both loaders.

    Returns:
        A ``(train_iter, test_iter)`` tuple of DataLoaders.
    """
    loaders = []
    for is_train in (True, False):
        dataset = torchvision.datasets.CIFAR10(
            root="data/cifar-10",
            train=is_train,
            download=True,
            transform=transforms.ToTensor(),
        )
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True, num_workers=4
        )
        loaders.append(loader)

    train_iter, test_iter = loaders
    return train_iter, test_iter


def train(net, train_iter, criterion, optimizer, num_epochs, device, num_print, lr_scheduler=None, test_iter=None):
    """Train ``net`` for ``num_epochs`` epochs and record per-epoch accuracy.

    Args:
        net: model to train; it is put into training mode.
        train_iter: iterable of ``(X, y)`` batches.
        criterion: loss function called as ``criterion(output, y)``.
        optimizer: optimizer updating the parameters of ``net``.
        num_epochs: number of passes over ``train_iter``.
        device: device the batches are moved to.
        num_print: a progress line is printed every ``num_print`` steps.
        lr_scheduler: optional scheduler stepped once at the end of each epoch.
        test_iter: optional evaluation batches; when given, ``test`` is run
            after every epoch.

    Returns:
        ``(record_train, record_test)``: the training accuracy (percent) at
        the end of each epoch, and the test accuracies (empty when
        ``test_iter`` is None).
    """
    net.train()
    record_train = list()
    record_test = list()

    for epoch in range(num_epochs):
        print("========== epoch: [{}/{}] ==========".format(epoch + 1, num_epochs))
        total, correct, train_loss = 0, 0, 0
        # Defined up front so that an empty loader records 0.0 instead of
        # raising UnboundLocalError at the end of the epoch.
        train_acc = 0.0
        start = time.time()

        for i, (X, y) in enumerate(train_iter):
            X, y = X.to(device), y.to(device)
            output = net(X)
            loss = criterion(output, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running statistics over the current epoch.
            train_loss += loss.item()
            total += y.size(0)
            correct += (output.argmax(dim=1) == y).sum().item()
            train_acc = 100.0 * correct / total

            if (i + 1) % num_print == 0:
                print("step: [{}/{}], train_loss: {:.3f} | train_acc: {:6.3f}% | lr: {:.6f}" \
                    .format(i + 1, len(train_iter), train_loss / (i + 1), \
                            train_acc, get_cur_lr(optimizer)))


        if lr_scheduler is not None:
            lr_scheduler.step()

        print("--- cost time: {:.4f}s ---".format(time.time() - start))

        if test_iter is not None:
            record_test.append(test(net, test_iter, criterion, device))
        record_train.append(train_acc)

    return record_train, record_test


def test(net, test_iter, criterion, device):
    total, correct = 0, 0
    net.eval()

    with torch.no_grad():
        print("*************** test ***************")
        for X, y in test_iter:
            X, y = X.to(device), y.to(device)

            output = net(X)
            loss = criterion(output, y)

            total += y.size(0)
            correct += (output.argmax(dim=1) == y).sum().item()

    test_acc = 100.0 * correct / total

    print("test_loss: {:.3f} | test_acc: {:6.3f}%"\
          .format(loss.item(), test_acc))
    print("************************************\n")
    net.train()

    return test_acc


def get_cur_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)


def learning_curve(record_train, record_test=None):
    """Plot per-epoch accuracy curves and show the figure.

    Args:
        record_train: training accuracies (percent), one per epoch.
        record_test: optional test accuracies (percent); plotted as a second
            curve when not None.
    """
    plt.style.use("ggplot")

    train_epochs = range(1, len(record_train) + 1)
    plt.plot(train_epochs, record_train, label="train acc")
    if record_test is not None:
        test_epochs = range(1, len(record_test) + 1)
        plt.plot(test_epochs, record_test, label="test acc")

    plt.title("learning curve")
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    # A tick every 5 epochs on x and every 5 percent on y.
    plt.xticks(range(0, len(record_train) + 1, 5))
    plt.yticks(range(0, 101, 5))
    plt.legend(loc=4)

    plt.show()


import torch.optim as optim


# Training configuration read by main().
BATCH_SIZE = 128        # batch size passed to load_dataset
NUM_EPOCHS = 20         # number of training epochs
NUM_CLASSES = 10        # size of the Resnet50 classifier output
LEARNING_RATE = 0.02    # initial learning rate for SGD
MOMENTUM = 0.9          # SGD momentum
WEIGHT_DECAY = 0.0005   # SGD weight decay
NUM_PRINT = 100         # print a progress line every NUM_PRINT steps
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def main():
    """Train Resnet50 with SGD and a step LR schedule, then plot the accuracy curves."""
    model = Resnet50(num_classes=NUM_CLASSES).to(DEVICE)

    train_iter, test_iter = load_dataset(BATCH_SIZE)

    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(
        model.parameters(),
        lr=LEARNING_RATE,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY,
        nesterov=True,
    )
    # Multiply the learning rate by 0.1 every 7 epochs.
    scheduler = optim.lr_scheduler.StepLR(sgd, step_size=7, gamma=0.1)

    train_history, test_history = train(
        model, train_iter, loss_fn, sgd,
        NUM_EPOCHS, DEVICE, NUM_PRINT,
        lr_scheduler=scheduler, test_iter=test_iter,
    )

    learning_curve(train_history, test_history)


# Start training only when the file is run as a script, not when it is imported.
if __name__ == '__main__':
    main()

分布式训练代码

终端运行指令
python -m torch.distributed.launch --nproc_per_node=2 test2.py

import torch
import torch.nn as nn
import argparse
import torch.distributed as dist  #1. DDP相关包
import torch.utils.data.distributed




# Number of Block units in each of the four stages (conv2..conv5); read by
# Resnet50 when it builds its stages.
Layers = [3, 4, 6, 3]
class Block(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        in_channels: number of channels of the input feature map.
        filters: three output channel counts, one for each convolution.
        stride: stride of the first 1x1 convolution and of the projection
            shortcut (when present).
        is_1x1conv: if True, the shortcut is a 1x1 convolution followed by
            batch norm; if False, the input is added back unchanged.
    """

    def __init__(self, in_channels, filters, stride=1, is_1x1conv=False):
        super().__init__()
        filter1, filter2, filter3 = filters
        self.is_1x1conv = is_1x1conv
        self.relu = nn.ReLU(inplace=True)

        # Main path layers, listed in the order they are applied.
        squeeze = [
            nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(filter1),
            nn.ReLU(),
        ]
        spatial = [
            nn.Conv2d(filter1, filter2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(filter2),
            nn.ReLU(),
        ]
        # No ReLU here: the activation is applied after the shortcut is added.
        expand = [
            nn.Conv2d(filter2, filter3, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filter3),
        ]
        self.conv1 = nn.Sequential(*squeeze)
        self.conv2 = nn.Sequential(*spatial)
        self.conv3 = nn.Sequential(*expand)

        if is_1x1conv:
            # Projection shortcut: brings the input to filter3 channels and
            # applies the same stride as the main path.
            projection = [
                nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filter3),
            ]
            self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        out = self.conv3(self.conv2(self.conv1(x)))
        identity = self.shortcut(x) if self.is_1x1conv else x
        return self.relu(out + identity)


class Resnet50(nn.Module):
    """ResNet-50 classifier assembled from bottleneck Block units.

    Layout: 7x7 stem convolution -> max pool -> four stages (conv2..conv5)
    whose depths come from the module-level ``Layers`` list -> global average
    pool -> linear classifier with ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        super().__init__()
        stem = [
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        ]
        self.conv1 = nn.Sequential(*stem)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # conv2 keeps the default stride of 1; conv3..conv5 use stride 2 in
        # their first block.
        self.conv2 = self._make_layer(64, (64, 64, 256), Layers[0])
        self.conv3 = self._make_layer(256, (128, 128, 512), Layers[1], stride=2)
        self.conv4 = self._make_layer(512, (256, 256, 1024), Layers[2], stride=2)
        self.conv5 = self._make_layer(1024, (512, 512, 2048), Layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(nn.Linear(2048, num_classes))

    def forward(self, input):
        """Return the classifier output for an image batch ``input``."""
        feats = self.maxpool(self.conv1(input))
        for stage in (self.conv2, self.conv3, self.conv4, self.conv5):
            feats = stage(feats)
        pooled = self.avgpool(feats)
        # Flatten everything after the batch dimension before the linear layer.
        return self.fc(torch.flatten(pooled, 1))

    def _make_layer(self, in_channels, filters, blocks, stride=1):
        """Build one stage: a projection Block followed by ``blocks - 1`` identity Blocks."""
        block_1 = Block(in_channels, filters, stride=stride, is_1x1conv=True)
        layers = [block_1]
        layers.extend(
            Block(filters[2], filters, stride=1, is_1x1conv=False)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*layers)





import time
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader


# Restrict this process to GPUs 0 and 1.
# NOTE(review): presumably this must run before CUDA is first initialised to
# take effect — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

def load_dataset(batch_size, opt):
    """Build the CIFAR-10 loaders for distributed training.

    The training loader draws its indices from a shuffling DistributedSampler;
    the test loader iterates the whole test set in order.

    Args:
        batch_size: batch size used by both loaders.
        opt: parsed command-line options (not used here).

    Returns:
        A ``(train_iter, test_iter)`` tuple of DataLoaders.
    """
    data_root = "data/cifar-10"

    train_set = torchvision.datasets.CIFAR10(
        root=data_root,
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    # Shuffling is delegated to the sampler, so the loader itself gets no
    # ``shuffle`` argument.
    train_sampler = DistributedSampler(train_set, shuffle=True)
    train_iter = DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=4,
        sampler=train_sampler,
    )

    test_set = torchvision.datasets.CIFAR10(
        root=data_root,
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    )
    test_iter = DataLoader(test_set, batch_size=batch_size, num_workers=4)

    return train_iter, test_iter


def train(net, train_iter, criterion, optimizer, num_epochs, device, num_print, lr_scheduler=None, test_iter=None, opt=None):
    """Train ``net`` for ``num_epochs`` epochs, optionally under DDP.

    Args:
        net: model to train (plain module or DistributedDataParallel wrapper).
        train_iter: iterable of ``(X, y)`` batches; when running distributed
            its ``sampler`` must provide ``set_epoch``.
        criterion: loss function called as ``criterion(output, y)``.
        optimizer: optimizer updating the parameters of ``net``.
        num_epochs: number of passes over ``train_iter``.
        device: device the batches are moved to.
        num_print: a progress line is printed every ``num_print`` steps.
        lr_scheduler: optional scheduler stepped once at the end of each epoch.
        test_iter: optional evaluation batches, evaluated on rank -1/0 only.
        opt: parsed options carrying ``local_rank``; None is treated like
            ``local_rank == -1`` (non-distributed run).

    Returns:
        ``(record_train, record_test)``: per-epoch training accuracy (percent)
        of this process, and the test accuracies (empty on other ranks or
        when ``test_iter`` is None).
    """
    # ``opt`` defaults to None, so do not dereference it unconditionally.
    local_rank = opt.local_rank if opt is not None else -1
    is_main_process = local_rank in [-1, 0]

    net.train()
    record_train = list()
    record_test = list()

    for epoch in range(num_epochs):
        if local_rank != -1:
            # Tell the distributed sampler which epoch this is so that it
            # reshuffles between epochs.
            train_iter.sampler.set_epoch(epoch)
        print("========== epoch: [{}/{}] ==========".format(epoch + 1, num_epochs))
        total, correct, train_loss = 0, 0, 0
        # Defined up front so that an empty loader records 0.0 instead of
        # raising UnboundLocalError at the end of the epoch.
        train_acc = 0.0
        start = time.time()

        for i, (X, y) in enumerate(train_iter):
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            output = net(X)
            loss = criterion(output, y)

            loss.backward()
            optimizer.step()

            # Running statistics over the current epoch (this process only).
            train_loss += loss.item()
            total += y.size(0)
            correct += (output.argmax(dim=1) == y).sum().item()
            train_acc = 100.0 * correct / total

            if is_main_process and (i + 1) % num_print == 0:
                print("step: [{}/{}], train_loss: {:.3f} | train_acc: {:6.3f}% | lr: {:.6f}" \
                    .format(i + 1, len(train_iter), train_loss / (i + 1), \
                            train_acc, get_cur_lr(optimizer)))


        if lr_scheduler is not None:
            lr_scheduler.step()

        print("--- cost time: {:.4f}s ---".format(time.time() - start))

        if is_main_process and test_iter is not None:
            # Evaluate the wrapped module directly: a forward pass through the
            # DDP wrapper on a single rank can issue collective calls (buffer
            # broadcast) that the other ranks never match.
            eval_net = net.module if isinstance(net, nn.parallel.DistributedDataParallel) else net
            record_test.append(test(eval_net, test_iter, device, criterion))
        record_train.append(train_acc)

    return record_train, record_test


def test(net, test_iter, device,criterion):
    total, correct = 0, 0
    net.eval()

    with torch.no_grad():
        print("*************** test ***************")
        for X, y in test_iter:
            X, y = X.to(device), y.to(device)

            output = net(X)
            loss = criterion(output, y)

            total += y.size(0)
            correct += (output.argmax(dim=1) == y).sum().item()

    test_acc = 100.0 * correct / total

    print("test_loss: {:.3f} | test_acc: {:6.3f}%"\
          .format(loss.item(), test_acc))
    print("************************************\n")
    net.train()

    return test_acc


def get_cur_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)


def learning_curve(record_train, record_test=None):
    """Plot per-epoch accuracy curves and show the figure.

    Args:
        record_train: training accuracies (percent), one per epoch.
        record_test: optional test accuracies (percent); plotted as a second
            curve when not None.
    """
    plt.style.use("ggplot")

    train_epochs = range(1, len(record_train) + 1)
    plt.plot(train_epochs, record_train, label="train acc")
    if record_test is not None:
        test_epochs = range(1, len(record_test) + 1)
        plt.plot(test_epochs, record_test, label="test acc")

    plt.title("learning curve")
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    # A tick every 5 epochs on x and every 5 percent on y.
    plt.xticks(range(0, len(record_train) + 1, 5))
    plt.yticks(range(0, 101, 5))
    plt.legend(loc=4)

    plt.show()


import torch.optim as optim


# Training configuration read by main().
BATCH_SIZE = 256        # batch size passed to load_dataset (used by each process's loaders)
NUM_EPOCHS = 20         # number of training epochs
NUM_CLASSES = 10        # size of the Resnet50 classifier output
LEARNING_RATE = 0.02    # initial learning rate for SGD
MOMENTUM = 0.9          # SGD momentum
WEIGHT_DECAY = 0.0005   # SGD weight decay
NUM_PRINT = 40          # print a progress line every NUM_PRINT steps


def main(opt):
    """Run one process of distributed (DDP) ResNet-50 training.

    Args:
        opt: parsed command-line namespace. ``opt.local_rank`` selects the GPU
            of this process and ``opt.init_method`` is passed to
            ``dist.init_process_group``.
    """
    use_cuda = torch.cuda.is_available()
    # NCCL only works with CUDA devices; fall back to gloo on CPU-only hosts.
    dist.init_process_group(backend='nccl' if use_cuda else 'gloo', init_method=opt.init_method)
    if use_cuda:
        # Bind this process to its own GPU before any tensor is allocated.
        torch.cuda.set_device(opt.local_rank)
        device = torch.device('cuda', opt.local_rank)
    else:
        device = torch.device('cpu')
    print("Using device:{}\n".format(device))


    train_iter, test_iter= load_dataset(BATCH_SIZE,opt)

    net = Resnet50(num_classes=NUM_CLASSES)
    net = net.to(device)
    # Wrap the model for distributed training; device_ids only applies to
    # CUDA modules.
    if use_cuda:
        net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[opt.local_rank], output_device=opt.local_rank)
    else:
        net = torch.nn.parallel.DistributedDataParallel(net)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        net.parameters(),
        lr=LEARNING_RATE,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY,
        nesterov=True
    )
    # Multiply the learning rate by 0.1 every 7 epochs.
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    record_train, record_test = train(net, train_iter, criterion, optimizer, \
          NUM_EPOCHS, device, NUM_PRINT, lr_scheduler, test_iter,opt)

    # Release the communication resources once training is finished.
    dist.destroy_process_group()

    # Only the main process has test records, so only it draws the figure.
    if opt.local_rank in [-1, 0]:
        learning_curve(record_train, record_test)



if __name__ == '__main__':

    parser = argparse.ArgumentParser('DDP training script.')
    # The launcher tells each process its local rank. Depending on the PyTorch
    # version it is passed as --local_rank or --local-rank, or exported in the
    # LOCAL_RANK environment variable, so accept all three. -1 means
    # "not launched in distributed mode".
    parser.add_argument('--local_rank', '--local-rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', -1)),
                        help='local_rank of current process')
    # Initialisation method of the process group; 'env://' reads the
    # rendezvous information from environment variables.
    parser.add_argument('--init_method', default='env://')
    opt = parser.parse_args()
    if opt.local_rank in [-1, 0]:
        print("opt:", opt)

    main(opt)