# 模型的评测 (Model evaluation)
# for major_test
import torch
import major_config
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from major_dataset import LoadDataset
def evaluteTop1(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` over every sample in ``loader``.

    Args:
        model: A ``torch.nn.Module`` whose output has one score per class
            along dim 1 (the prediction is taken with ``argmax(dim=1)``).
        loader: A ``DataLoader`` yielding ``(inputs, labels)`` batches.
        device: Device each batch is moved to before the forward pass.
            Defaults to ``major_config.device`` so existing callers behave
            as before.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.

    Raises:
        ZeroDivisionError: If ``loader.dataset`` is empty.
    """
    if device is None:
        device = major_config.device
    model.eval()  # switch Dropout/BatchNorm layers to inference behaviour
    correct = 0
    total = len(loader.dataset)
    with torch.no_grad():  # inference only: no autograd graph is needed
        for x, y in loader:
            # Move the batch itself onto the device (previously x and y were
            # overwritten with the device object).
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().item()
    return correct / total
def evaluteTop5(model, loader, device=None, k=5):
    """Return the top-k (default top-5) accuracy of ``model`` over ``loader``.

    A sample counts as correct when its label is among the ``k`` highest
    scores along dim 1 of the model output.

    Args:
        model: A ``torch.nn.Module`` whose output has one score per class
            along dim 1.
        loader: A ``DataLoader`` yielding ``(inputs, labels)`` batches.
        device: Device each batch is moved to before the forward pass.
            Defaults to ``major_config.device`` so existing callers behave
            as before.
        k: Number of top-scoring classes to consider. If the model outputs
            fewer than ``k`` classes, all classes are considered (so every
            sample counts as correct) instead of ``topk`` raising.

    Returns:
        Fraction of samples whose label is within the top-k predictions.

    Raises:
        ZeroDivisionError: If ``loader.dataset`` is empty.
    """
    if device is None:
        device = major_config.device
    model.eval()  # switch Dropout/BatchNorm layers to inference behaviour
    correct = 0
    total = len(loader.dataset)
    with torch.no_grad():  # inference only: no autograd graph is needed
        for x, y in loader:
            # Move the batch itself onto the device (previously x and y were
            # overwritten with the device object).
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # topk raises if k exceeds the class dimension; clamp it.
            maxk = min(k, logits.size(1))
            _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
            # pred is (batch, maxk); labels are reshaped to (batch, 1) so the
            # comparison broadcasts. topk indices are unique per row, so each
            # row contributes at most one match.
            correct += torch.eq(pred, y.view(-1, 1)).sum().item()
    return correct / total
if __name__ == "__main__":
    device = major_config.device

    # 1. Load the test data
    # 1.1 Preprocessing
    test_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(major_config.norm_mean, major_config.norm_std),
    ])
    # 1.2 Dataset and loader. Shuffling is a training-time concern; it is
    # unnecessary for evaluation, so keep the order deterministic.
    test_data = LoadDataset(data_dir=major_config.test_image, transform=test_transform)
    test_loader = DataLoader(dataset=test_data, batch_size=major_config.batchsize, shuffle=False)

    # 2. Load the model
    net = major_config.model  # change the model here, e.g. net = se_resnet50(num_classes=5, pretrained=True)
    path_model_state_dict = major_config.path_test_model
    # map_location lets a checkpoint saved on a GPU be loaded onto the
    # configured device (e.g. on a CPU-only machine).
    net.load_state_dict(torch.load(path_model_state_dict, map_location=device))
    # The weights must live on the same device the evaluation batches are moved to.
    net.to(device)

    # 3. Evaluate and report (the results were previously discarded).
    top1 = evaluteTop1(net, test_loader)
    top5 = evaluteTop5(net, test_loader)
    print(f"Top-1 accuracy: {top1:.4f}")
    print(f"Top-5 accuracy: {top5:.4f}")