Train
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from datetime import datetime
from major_dataset import LoadDataset
from major_evalution import eval_semantic_segmentation
import major_config
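# major_config is assumed to be a small settings module; the attributes used below are
# train_image, train_label, val_image, val_label, crop_size, batchsize, device, model,
# num_epoch and path_saved_model.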
# ****************************************step1 Data processing**********************************************#
Load_train = LoadDataset([major_config.train_image, major_config.train_label], major_config.crop_size)
Load_val = LoadDataset([major_config.val_image, major_config.val_label], major_config.crop_size)
train_data = DataLoader(Load_train, batch_size=major_config.batchsize, shuffle=True, num_workers=1)
val_data = DataLoader(Load_val, batch_size=major_config.batchsize, shuffle=True, num_workers=1)
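# Each batch yielded by these loaders is assumed to be a dict with an 'img' tensor of
# shape (batchsize, C, H, W) and a 'label' tensor of per-pixel class indices of shape
# (batchsize, H, W); this is what the training and evaluation loops below index into.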
# *****************************************step2 Model*********************************************#
net = major_config.model
net = net.to(major_config.device)
# ******************************************step3 Loss function********************************************#
criterion = nn.NLLLoss().to(major_config.device)  # NLLLoss makes it easy to swap out the final activation layer
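# nn.NLLLoss applied to log-probabilities is equivalent to nn.CrossEntropyLoss applied to
# raw logits; keeping log_softmax explicit in the forward pass (see the training loop
# below) is what makes the final activation easy to swap. A minimal self-check sketch
# with dummy shapes (it is not called during training):
def _nll_equals_cross_entropy_check():
    logits = torch.randn(2, 5, 4, 4)         # dummy (N, C, H, W) network output
    target = torch.randint(0, 5, (2, 4, 4))  # dummy (N, H, W) class-index labels
    ce = F.cross_entropy(logits, target)
    nll = F.nll_loss(F.log_softmax(logits, dim=1), target)
    assert torch.allclose(ce, nll)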
# ******************************************step4 Optimizer********************************************#
optimizer = optim.Adam(net.parameters(), lr=1e-4)
# ******************************************step5 Training********************************************#
def train(model):
    best = [0]  # store the best metric so far, used to decide when to save the model
    net = model.train()  # switch the model to training mode so parameters can be updated
    # training epochs
    for epoch in range(major_config.num_epoch):
        print('Epoch is [{}/{}]'.format(epoch + 1, major_config.num_epoch))
        # halve the learning rate every 20 epochs
        if epoch % 20 == 0 and epoch != 0:
            for group in optimizer.param_groups:
                group['lr'] *= 0.5
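        # An equivalent, more idiomatic alternative (a sketch, not used in this script) would
        # be to create torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
        # next to the optimizer and call scheduler.step() once per epoch.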
        # reset the running metrics
        train_loss = 0
        train_pa = 0
        train_mpa = 0
        train_miou = 0
        train_fwiou = 0
        # training batches
        for i, sample in enumerate(train_data):
            # load data
            img_data = sample['img'].to(major_config.device)
            img_label = sample['label'].to(major_config.device)
            # training step
            out = net(img_data)
            out = F.log_softmax(out, dim=1)
            loss = criterion(out, img_label)  # compute the loss
            optimizer.zero_grad()  # clear gradients before backpropagation
            loss.backward()  # backpropagation
            optimizer.step()  # parameter update
            train_loss += loss.item()  # accumulate the loss
            # evaluation
            # predictions: [1] returns the argmax indices along the class dimension
            pre_label = out.max(dim=1)[1].data.cpu().numpy()
            pre_label = [i for i in pre_label]
            # ground truth
            true_label = img_label.data.cpu().numpy()
            true_label = [i for i in true_label]
            # compute all evaluation metrics
            eval_metrix = eval_semantic_segmentation(pre_label, true_label)
            # accumulate each metric
            train_pa += eval_metrix['pa']
            train_mpa += eval_metrix['mpa']
            train_miou += eval_metrix['miou']
            train_fwiou += eval_metrix['fwiou']
            # print the batch loss
            print('|batch[{}/{}]|batch_loss {: .8f}|'.format(i + 1, len(train_data), loss.item()))
        # format the epoch-level metrics
        metric_description = '|Train PA|: {:.5f}|\n|Train MPA|: {:.5f}|\n|Train MIou|: {:.5f}|\n|Train FWIou|: {:.5f}|'.format(
            train_pa / len(train_data),
            train_mpa / len(train_data),
            train_miou / len(train_data),
            train_fwiou / len(train_data),
        )
        # print the metrics
        print(metric_description)
        # save the best model according to train_miou
        if max(best) <= train_miou / len(train_data):
            best.append(train_miou / len(train_data))
            torch.save(net.state_dict(), major_config.path_saved_model)
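# To reuse the best checkpoint later, the weights saved above can be restored along these
# lines (a sketch, assuming the same model definition from major_config):
#     net = major_config.model
#     net.load_state_dict(torch.load(major_config.path_saved_model, map_location=major_config.device))
#     net.to(major_config.device).eval()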
# ******************************************step6 Evaluation********************************************#
def evaluate(model):
    net = model.eval()  # switch the model to evaluation mode
    eval_loss = 0
    eval_acc = 0
    eval_miou = 0
    eval_class_acc = 0  # note: never accumulated below, so it always prints 0
    prec_time = datetime.now()
    for j, sample in enumerate(val_data):
        valImg = sample['img'].to(major_config.device)
        valLabel = sample['label'].long().to(major_config.device)
        out = net(valImg)
        out = F.log_softmax(out, dim=1)
        loss = criterion(out, valLabel)
        eval_loss = loss.item() + eval_loss
        pre_label = out.max(dim=1)[1].data.cpu().numpy()
        pre_label = [i for i in pre_label]
        true_label = valLabel.data.cpu().numpy()
        true_label = [i for i in true_label]
        eval_metrics = eval_semantic_segmentation(pre_label, true_label)
        eval_acc = eval_metrics['mean_class_accuracy'] + eval_acc
        eval_miou = eval_metrics['miou'] + eval_miou
    cur_time = datetime.now()
    h, remainder = divmod((cur_time - prec_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = 'Time: {:.0f}:{:.0f}:{:.0f}'.format(h, m, s)
    val_str = ('|Valid Loss|: {:.5f} \n|Valid Acc|: {:.5f} \n|Valid Mean IU|: {:.5f} \n|Valid Class Acc|:{:}'.format(
        eval_loss / len(val_data),
        eval_acc / len(val_data),
        eval_miou / len(val_data),
        eval_class_acc / len(val_data)))
    print(val_str)
    print(time_str)
if __name__ == "__main__":
    train(net)
    # evaluate(net)  # you can decide how many epochs to train between validations, so evaluate() can also be called from inside train(); see the sketch below
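# A sketch of periodic validation inside the training loop (an assumption about how one
# might wire it up; eval_interval is a hypothetical setting, not part of major_config):
#     eval_interval = 5
#     for epoch in range(major_config.num_epoch):
#         ...  # one training epoch, as in train() above
#         if (epoch + 1) % eval_interval == 0:
#             evaluate(net)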