#所需要的包:
import torch
from torch import nn, optim # nn:神经网络模块,optim:优化器
from torch.optim.lr_scheduler import CosineAnnealingLR # 学习率衰减:导入余弦退火学习率调度器
from torchinfo import summary#打印网络模型各层信息的库
import timm#timm是一个用于图像分类的PyTorch库
import torchvision#torchvision是一个用于处理图像和视频的库
import torchvision.transforms as transforms  #用于图像预处理的库
from torch.utils.data import DataLoader  #用于加载数据的库
from matplotlib import pyplot as plt #用于画图的库
import numpy as np #用于数学计算的库
from tqdm import tqdm#进度条的库
import Ranger #Ranger优化器

#--------------对训练函数进行封装
#net(神经网络模型)、loss(损失函数)、train_dataloader和valid_dataloader(训练集和验证集数据加载器)、device(训练设备)、batch_size(批量大小)、num_epoch(训练轮数)、lr(初始学习率)、lr_min(最小学习率)、optim(优化器类型)、init(是否进行权重初始化)和scheduler_type(学习率调度器类型)
def train(net, loss, train_dataloader, valid_dataloader, device, batch_size, num_epoch, lr, lr_min, optim='sgd', init=True, scheduler_type='Cosine'):#训练函数
    def init_xavier(m): #参数初始化
        #if type(m) == nn.Linear or type(m) == nn.Conv2d:#如果是全连接层或者卷积层
        if type(m) == nn.Linear:#如果是全连接层
            nn.init.xavier_normal_(m.weight) #权重初始化

    if init:  #是否进行权重初始化
        net.apply(init_xavier)#应用权重初始化

    print('training on:', device)   #打印训练设备
    net.to(device)  #将网络模型放到训练设备上
    #优化器选择
    if optim == 'sgd':  #sgd优化器
        optimizer = torch.optim.SGD((param for param in net.parameters() if param.requires_grad), lr=lr,    #param.requires_grad:是否需要梯度
                                    weight_decay=0) #weight_decay:权重衰减
    elif optim == 'adam': #adam优化器
        optimizer = torch.optim.Adam((param for param in net.parameters() if param.requires_grad), lr=lr, #param.requires_grad:是否需要梯度
                                     weight_decay=0)
    elif optim == 'adamW': #adamW优化器
        optimizer = torch.optim.AdamW((param for param in net.parameters() if param.requires_grad), lr=lr, #param.requires_grad:是否需要梯度
                                      weight_decay=0)
    elif optim == 'ranger': #ranger优化器
        optimizer = Ranger((param for param in net.parameters() if param.requires_grad), lr=lr,#param.requires_grad:是否需要梯度
                           weight_decay=0)
    if scheduler_type == 'Cosine':#余弦退火学习率调度器
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epoch, eta_min=lr_min)
    #用来保存每个epoch的Loss和acc以便最后画图
    train_losses = [] #训练集Loss
    train_acces = [] #训练集acc
    eval_acces = [] #验证集acc
    best_acc = 0.0 #最好的acc
    #训练
    for epoch in range(num_epoch):#训练轮数

        print("——————第 {} 轮训练开始——————".format(epoch + 1))  #打印训练轮数

        # 训练开始
        net.train()#训练模式
        train_acc = 0 #训练集acc
        for batch in tqdm(train_dataloader, desc='训练'): #tqdm:进度条
            imgs, targets = batch #imgs:图片,targets:标签
            imgs = imgs.to(device) #将图片放到训练设备上
            targets = targets.to(device) #将标签放到训练设备上
            output = net(imgs) #输出

            Loss = loss(output, targets) #损失函数
          
            optimizer.zero_grad() #梯度清零
            Loss.backward() #反向传播
            optimizer.step() #优化器更新参数

            _, pred = output.max(1) #预测
            num_correct = (pred == targets).sum().item() #预测正确的数量
            acc = num_correct / (batch_size) #acc
            train_acc += acc #训练集acc
        scheduler.step() #学习率调度器更新学习率
        print("epoch: {}, Loss: {}, Acc: {}".format(epoch, Loss.item(), train_acc / len(train_dataloader))) #打印训练轮数、Loss和acc
        train_acces.append(train_acc / len(train_dataloader)) #保存训练集acc
        train_losses.append(Loss.item()) #保存训练集Loss

        # 测试步骤开始
        net.eval() #测试模式
        eval_loss = 0 #验证集Loss
        eval_acc = 0 #验证集acc
        with torch.no_grad():  #不进行梯度计算
            for imgs, targets in valid_dataloader:  #验证集
                imgs = imgs.to(device) #将图片放到训练设备上
                targets = targets.to(device) #将标签放到训练设备上
                output = net(imgs) #输出
                Loss = loss(output, targets) #损失函数
                _, pred = output.max(1) #预测
                num_correct = (pred == targets).sum().item() #预测正确的数量
                eval_loss += Loss  #累加验证集Loss
                acc = num_correct / imgs.shape[0] #acc
                eval_acc += acc #累加验证集acc

            eval_losses = eval_loss / (len(valid_dataloader)) #整体验证集上的Loss
            eval_acc = eval_acc / (len(valid_dataloader)) #整体验证集上的acc
            if eval_acc > best_acc: #如果acc大于最好的acc
                best_acc = eval_acc #更新最好的acc
                torch.save(net.state_dict(),'best_acc.pth') #保存最好的acc
            eval_acces.append(eval_acc) #保存验证集acc
            print("整体验证集上的Loss: {}".format(eval_losses)) #打印整体验证集上的Loss
            print("整体验证集上的正确率: {}".format(eval_acc)) #打印整体验证集上的acc
    return train_losses, train_acces, eval_acces #返回训练集Loss、训练集acc、验证集acc
#-----------------------对整体进行封装(以CIFAR-10数据集为例):

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchinfo import summary
import timm
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
import Ranger

def get_dataloader(batch_size): #获取数据集
    data_transform = { #数据预处理
        "train": transforms.Compose([transforms.RandomResizedCrop(224),#随机裁剪
                                     transforms.RandomHorizontalFlip(), #随机水平翻转
                                     transforms.ToTensor(), #转换为张量
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), #标准化
        "val": transforms.Compose([transforms.Resize(256), #调整大小
                                   transforms.CenterCrop(224), #中心裁剪
                                   transforms.ToTensor(), #转换为张量
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])} #标准化
    train_dataset = torchvision.datasets.CIFAR10('./p10_dataset', train=True, transform=data_transform["train"], download=True) #训练集
    test_dataset = torchvision.datasets.CIFAR10('./p10_dataset', train=False, transform=data_transform["val"], download=True) #测试集
    print('训练数据集长度: {}'.format(len(train_dataset))) #打印训练集长度
    print('测试数据集长度: {}'.format(len(test_dataset))) #打印测试集长度
    # DataLoader创建数据集
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) #训练集
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) #测试集
    return train_dataloader,test_dataloader  #返回训练集和测试集

def show_pic(dataloader):#展示dataloader里的6张图片
    examples = enumerate(dataloader)  # 组合成一个索引序列
    batch_idx, (example_data, example_targets) = next(examples)  # 从examples中取出一个batch
    classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') #类别
    fig = plt.figure() #创建一个窗口
    for i in range(6): #展示6张图片
        plt.subplot(2, 3, i + 1) #2行3列
        # plt.tight_layout()
        img = example_data[i]  # 取出其中一张图片
        print('pic shape:',img.shape)  #打印图片形状
        img = img.swapaxes(0, 1) #交换维度
        img = img.swapaxes(1, 2) #交换维度
        plt.imshow(img, interpolation='none') #展示图片
        plt.title(classes[example_targets[i].item()]) #图片标题
        plt.xticks([]) #x轴刻度
        plt.yticks([]) #y轴刻度
    plt.show()#展示图片

def get_net(): #获得预训练模型并冻住前面层的参数
    net = timm.create_model('resnet50', pretrained=True, num_classes=10)#获取预训练模型
    print(summary(net, input_size=(128, 3, 224, 224))) #打印模型结构
    '''Freeze all layers except the last layer(fc or classifier)''' #冻住前面层的参数
    for param in net.parameters(): #冻住前面层的参数
        param.requires_grad = False #冻住前面层的参数
    # nn.init.xavier_normal_(model.fc.weight) #初始化最后一层的参数
    # nn.init.zeros_(model.fc.bias) #初始化最后一层的参数
    net.fc.weight.requires_grad = True #最后一层的权重参数需要梯度
    net.fc.bias.requires_grad = True #最后一层的偏置参数需要梯度
    return net #返回模型

def train(net, loss, train_dataloader, valid_dataloader, device, batch_size, num_epoch, lr, lr_min, optim='sgd', init=True, scheduler_type='Cosine'):
    def init_xavier(m):#初始化
        #if type(m) == nn.Linear or type(m) == nn.Conv2d: #初始化全连接层和卷积层
        if type(m) == nn.Linear: #初始化全连接层
            nn.init.xavier_normal_(m.weight) #初始化权重

    if init: #初始化
        net.apply(init_xavier) #初始化

    print('training on:', device) #打印训练设备
    net.to(device)  # 将网络移动到device上

    if optim == 'sgd':#优化器
        optimizer = torch.optim.SGD((param for param in net.parameters() if param.requires_grad), lr=lr, #优化器
                                    weight_decay=0)
    elif optim == 'adam':#优化器
        optimizer = torch.optim.Adam((param for param in net.parameters() if param.requires_grad), lr=lr, #优化器
                                     weight_decay=0)
    elif optim == 'adamW':#优化器
        optimizer = torch.optim.AdamW((param for param in net.parameters() if param.requires_grad), lr=lr,
                                      weight_decay=0)
    elif optim == 'ranger':
        optimizer = Ranger((param for param in net.parameters() if param.requires_grad), lr=lr,
                           weight_decay=0)
    if scheduler_type == 'Cosine': #学习率衰减
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epoch, eta_min=lr_min)#学习率衰减

    train_losses = []#训练损失
    train_acces = [] #训练准确率
    eval_acces = [] #测试准确率
    best_acc = 0.0 #最好的准确率
    for epoch in range(num_epoch): #训练num_epoch轮

        print("——————第 {} 轮训练开始——————".format(epoch + 1)) #打印训练开始

        # 训练开始
        net.train() #训练模式
        train_acc = 0 #训练准确率
        for batch in tqdm(train_dataloader, desc='训练'): #训练集
            imgs, targets = batch #取出一个batch的数据
            imgs = imgs.to(device) #将数据移动到device上
            targets = targets.to(device) #将数据移动到device上
            output = net(imgs) #前向传播

            Loss = loss(output, targets) #计算损失
        
            optimizer.zero_grad() #梯度清零
            Loss.backward() #反向传播
            optimizer.step() #更新参数

            _, pred = output.max(1) #取出最大值
            num_correct = (pred == targets).sum().item() #计算正确个数
            acc = num_correct / (batch_size) #计算准确率
            train_acc += acc #计算训练准确率
        scheduler.step() #学习率衰减
        print("epoch: {}, Loss: {}, Acc: {}".format(epoch, Loss.item(), train_acc / len(train_dataloader))) #打印训练损失和准确率
        train_acces.append(train_acc / len(train_dataloader)) #保存训练准确率
        train_losses.append(Loss.item()) #保存训练损失

        # 测试步骤开始
        net.eval() #测试模式
        eval_loss = 0 #测试损失
        eval_acc = 0 #测试准确率
        with torch.no_grad(): #不计算梯度
            for imgs, targets in valid_dataloader:#测试集
                imgs = imgs.to(device) #将数据移动到device上
                targets = targets.to(device) #将数据移动到device上
                output = net(imgs) #前向传播
                Loss = loss(output, targets) #计算损失
                _, pred = output.max(1) #取出最大值
                num_correct = (pred == targets).sum().item() #计算正确个数
                eval_loss += Loss #计算测试损失
                acc = num_correct / imgs.shape[0] #计算准确率
                eval_acc += acc #计算测试准确率

            eval_losses = eval_loss / (len(valid_dataloader)) #计算测试损失
            eval_acc = eval_acc / (len(valid_dataloader)) #计算测试准确率
            if eval_acc > best_acc: #保存最好的模型
                best_acc = eval_acc #保存最好的准确率
                torch.save(net.state_dict(),'best_acc.pth') #保存模型
            eval_acces.append(eval_acc) #保存测试准确率
            print("整体验证集上的Loss: {}".format(eval_losses)) #打印测试损失
            print("整体验证集上的正确率: {}".format(eval_acc))  # 打印测试准确率
    return train_losses, train_acces, eval_acces #返回训练损失,训练准确率,测试准确率

def show_acces(train_losses, train_acces, valid_acces, num_epoch):#对准确率和loss画图显得直观
    plt.plot(1 + np.arange(len(train_losses)), train_losses, linewidth=1.5, linestyle='dashed', label='train_losses') #画图
    plt.plot(1 + np.arange(len(train_acces)), train_acces, linewidth=1.5, linestyle='dashed', label='train_acces') #画图
    plt.plot(1 + np.arange(len(valid_acces)), valid_acces, linewidth=1.5, linestyle='dashed', label='valid_acces') #画图
    plt.grid()#网格
    plt.xlabel('epoch') #x轴标签
    plt.xticks(range(1, 1 + num_epoch, 1)) #x轴刻度
    plt.legend() #图例
    plt.show() #展示

if __name__ == '__main__':
    train_dataloader, test_dataloader = get_dataloader(batch_size=64) #获取数据集
    show_pic(train_dataloader) #展示数据集
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #判断是否有GPU
    net = get_net() #获取模型
    loss = nn.CrossEntropyLoss() #定义损失函数
    train_losses, train_acces, eval_acces = train(net, loss, train_dataloader, test_dataloader, device, batch_size=64, num_epoch=20, lr=0.1, lr_min=1e-4, optim='sgd', init=False) #训练模型
    show_acces(train_losses, train_acces, eval_acces, num_epoch=20) #展示准确率和loss