EarlyStopping in Keras is very convenient to use, but when I tested early stopping in PyTorch I fell into a pit!!!
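For context, this is all it takes in Keras: one callback passed to model.fit (a minimal sketch, assuming a compiled model and x_train/y_train arrays already exist):

from tensorflow.keras.callbacks import EarlyStopping

# Stop once val_loss has not improved for 20 epochs, then roll back to the best weights
early_stop = EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)

model.fit(x_train, y_train,
          validation_split=0.2,
          epochs=100,
          callbacks=[early_stop])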

PyTorch has no built-in early stopping; the pytorchtools package provides one, but in my tests I basically could not get it to work, it kept throwing errors.

from pytorchtools import EarlyStopping

Usage in PyTorch:

import torch
import torch.nn as nn
import torch.utils.data as Data

model = yourModel()  # pseudocode: substitute your own model
# Pick a loss function; any criterion that fits your task works
criterion = nn.CrossEntropyLoss()  # cross-entropy; note it averages the loss over the batch by default
# Pick an optimizer; any other optimizer works too
optimizer = torch.optim.Adam(model.parameters())
# Initialize the early_stopping object
patience = 20  # stop once the validation loss has failed to decrease for 20 consecutive epochs, to avoid overfitting
early_stopping = EarlyStopping(patience, verbose=True)  # the EarlyStopping code itself is shown right after this snippet

batch_size = 64  # or anything else; this is a hyperparameter (see my previous post on choosing hyperparameters)
n_epochs = 100   # set this generously, since you expect early stopping to end training anyway

# Build the DataLoader for the training data once, before the epoch loop
training_dataset = Data.TensorDataset(X_train, y_train)
data_loader = Data.DataLoader(
    dataset=training_dataset,
    batch_size=batch_size,  # batch size
    shuffle=True            # shuffle the data each epoch
)
# ----------------------------------------------------------------
# Train until epoch == n_epochs or early stopping fires
for epoch in range(1, n_epochs + 1):
    # ---------------------------------------------------
    model.train()  # put the model in training mode
    # Train in mini-batches
    for batch, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()             # clear the gradients of all parameters
        output = model(data)              # forward pass: model predictions
        loss = criterion(output, target)  # compute the loss
        loss.backward()                   # compute gradients of the loss w.r.t. each parameter
        optimizer.step()                  # single optimization step: update the parameters
    # ----------------------------------------------------
    model.eval()  # evaluation mode: switches off dropout and uses the running batch-norm statistics
    # If the validation set is small there is no need to batch it,
    # but make sure the input dimensions are what the model expects
    with torch.no_grad():  # no gradients are needed for validation
        valid_output = model(X_val)
        valid_loss = criterion(valid_output, y_val)  # mind the input shapes; kept simple here

    early_stopping(valid_loss, model)
    # If the early stopping criterion is met
    if early_stopping.early_stop:
        print("Early stopping")
        # end training
        break
# Load the model parameters saved when early stopping triggered
model.load_state_dict(torch.load('checkpoint.pt'))
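
And here is the EarlyStopping class referenced above. This is a minimal sketch in the spirit of the pytorchtools implementation (https://github.com/Bjarten/early-stopping-pytorch); if you use the installed package, check your version, since the exact constructor signature may differ:

import torch

class EarlyStopping:
    """Stops training when the validation loss has not improved for `patience` epochs."""
    def __init__(self, patience=7, verbose=False, delta=0.0, path='checkpoint.pt'):
        self.patience = patience  # epochs to wait after the last improvement
        self.verbose = verbose    # print a message on every improvement
        self.delta = delta        # minimum decrease that counts as an improvement
        self.path = path          # where the best model is checkpointed
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        val_loss = float(val_loss)  # accepts a tensor or a plain float
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            # Improvement: checkpoint the model and reset the counter
            if self.verbose:
                print(f'Validation loss decreased to {val_loss:.6f}; saving model to {self.path}')
            torch.save(model.state_dict(), self.path)
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

Because the class checkpoints state_dict() to checkpoint.pt on every improvement, the load_state_dict call at the end of training restores the best weights.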

My own hand-rolled early stopping, based on: https://zhuanlan.zhihu.com/p/350982073

from model.unet_model import UNet
from utils.dataset import ISBI_Loader
from torch import optim
# from model.cbam_resunet_se_resunet_deep_supervision import CbamResUNet_deep
# from model.resnet50_unet import Resnet_Unet
# from model.cenet_se import CE_Net
import torch.nn as nn
import torch
from torch.autograd import Variable
from torchsummary import summary
from tqdm import tqdm
from dice_loss import DiceLoss
# from pytorchtools import EarlyStopping
from torch.utils.tensorboard import SummaryWriter
import numpy as np
# from dice_loss2 import dice_loss
# https://github.com/milesial/Pytorch-UNet



def train_net(net, device):
    # Load the training set
    epochs = 190
    batch_size = 6
    lr = 0.0001
    n_classes = 1
    data_path = r"data/train/"
    isbi_dataset = ISBI_Loader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    print(len(isbi_dataset))
    print(len(train_loader))
    writer = SummaryWriter()
    global_step = 0
    # RMSprop is an alternative optimizer
    # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.99))
    # Lower the learning rate when the monitored loss plateaus
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)
    # early_stopping = EarlyStopping(patience=10, verbose=True)
    # Define the loss
    # if n_classes > 1:
    #     criterion = nn.CrossEntropyLoss()
    # else:
    #     criterion = nn.BCEWithLogitsLoss()

    criterion = DiceLoss()
    # Track the best loss so far, initialized to +infinity
    best_loss = float('inf')
    # Train for `epochs` epochs

    es = 0  # epochs without improvement (the early stopping counter)

    for epoch in range(epochs):
        # Training mode
        net.train()
        train_losses = []
        # Train batch by batch
        print("epoch......: ", epoch)
        # pbar counts images, so its total is the dataset size
        with tqdm(total=len(isbi_dataset), desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for image, label in train_loader:
                optimizer.zero_grad()
                # Move the data to the device; masks are float for 1 class, long otherwise
                image = image.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if n_classes == 1 else torch.long
                label = label.to(device=device, dtype=mask_type)

                # Forward pass: compute the predictions
                pred = net(image)
                # Compute the loss
                loss = criterion(pred, label)
                train_losses.append(loss.item())

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Backward pass and parameter update
                loss.backward()
                optimizer.step()

                # Alternative: log per optimization step instead of per epoch
                # writer.add_scalar('Loss/train', loss.item(), global_step)
                # writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
                writer.add_scalar('Loss/train', loss.item(), epoch)
                writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

                pbar.update(image.shape[0])
                global_step += 1

            train_epoch_loss = np.average(train_losses)

            # Save the best model
            # https://zhuanlan.zhihu.com/p/350982073
            scheduler.step(train_epoch_loss)
            # Note: with no validation split, this stops on the *training* loss
            if float(train_epoch_loss) < best_loss:
                best_loss = train_epoch_loss
                # torch.save(net.state_dict(), 'best_model.pth')
                torch.save(net.state_dict(), r"./checkpoint/" + f'CP_epoch{epoch + 1}_{train_epoch_loss}.pth')
                es = 0
            else:
                es += 1
                if es > 5:
                    print(f"Early stopping with best_loss: {best_loss}")
                    break

            # The pytorchtools version of the same logic:
            # early_stopping(train_epoch_loss, net)
            # if early_stopping.early_stop:
            #     print("Early stopping")
            #     break

if __name__ == "__main__":
    # Pick the device: CUDA if available, otherwise CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the network: single-channel input images, 1 output class
    net = UNet(n_channels=1, n_classes=1)
    # Move the network to the device; a separate net.cuda() call would be
    # redundant and would crash on CPU-only machines
    net.to(device=device)

    # net = CbamResUNet_deep(n_channels=1, n_classes=1)
    # net.cuda()

    # net = CE_Net()
    # net.cuda()
    # summary(net, (3, 256, 256))
    # Point at the training data and start training
    train_net(net, device)
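
The ad-hoc es/best_loss bookkeeping above can be factored into a small reusable helper, which keeps train_net readable. A minimal sketch; the PatienceCounter name and its API are mine, not from any library:

class PatienceCounter:
    """Counts epochs without improvement and signals when to stop."""
    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.bad_epochs = 0

    def step(self, loss):
        """Returns True when training should stop."""
        if loss < self.best - self.min_delta:
            self.best = loss      # improvement: remember it and reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1  # no improvement this epoch
        return self.bad_epochs > self.patience

# Inside the epoch loop, replacing the es/best_loss bookkeeping:
# stopper = PatienceCounter(patience=5)
# ...
# if stopper.step(train_epoch_loss):
#     print(f"Early stopping with best_loss: {stopper.best}")
#     break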

Finally, a clean way to structure all of this that I recommend:

def train_model(model, batch_size, patience, n_epochs):
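    # note: train_loader, valid_loader, criterion and optimizer are assumed
    # to be defined in the enclosing scope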
    
    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = [] 
    
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    
    for epoch in range(1, n_epochs + 1):

        ###################
        # train the model #
        ###################
        model.train() # prep model for training
        for batch, (data, target) in enumerate(train_loader, 1):
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the loss
            loss = criterion(output, target)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # record training loss
            train_losses.append(loss.item())

        ######################    
        # validate the model #
        ######################
        model.eval() # prep model for evaluation
        with torch.no_grad():  # gradients are not needed for validation
            for data, target in valid_loader:
                # forward pass: compute predicted outputs by passing inputs to the model
                output = model(data)
                # calculate the loss
                loss = criterion(output, target)
                # record validation loss
                valid_losses.append(loss.item())

        # print training/validation statistics 
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        
        epoch_len = len(str(n_epochs))
        
        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')
        
        print(print_msg)
        
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        
        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        early_stopping(valid_loss, model)
        
        if early_stopping.early_stop:
            print("Early stopping")
            break
        
    # load the last checkpoint with the best model
    model.load_state_dict(torch.load('checkpoint.pt'))

    return model, avg_train_losses, avg_valid_losses
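
A minimal usage sketch to show how the pieces fit together (the toy data and the two-layer model are purely illustrative; as noted above, train_model reads train_loader, valid_loader, criterion, and optimizer from the enclosing scope):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from pytorchtools import EarlyStopping

# Toy data: 1000 samples, 10 features, 3 classes (purely illustrative)
X = torch.randn(1000, 10)
y = torch.randint(0, 3, (1000,))
train_loader = DataLoader(TensorDataset(X[:800], y[:800]), batch_size=64, shuffle=True)
valid_loader = DataLoader(TensorDataset(X[800:], y[800:]), batch_size=64)

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

model, train_hist, valid_hist = train_model(model, batch_size=64, patience=20, n_epochs=100)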