import torch
from torch import optim, nn

from torch.utils.data import DataLoader
from torchvision import datasets
from transform import preprocess
# from model import initital_model,class_id2name
from models.resnet import *

# Dataset roots; ImageFolder treats each sub-directory as one class.
TRAIN = '/media/dell/Elements/trainset/train'
VALID = '/media/dell/Elements/trainset/val'

# Loader settings shared by the training and validation loaders.
batch_size = 16
num_workers = 2

train_data = datasets.ImageFolder(root=TRAIN, transform=preprocess)
val_data = datasets.ImageFolder(root=VALID, transform=preprocess)

_loader_opts = dict(batch_size=batch_size, num_workers=num_workers)
# Training batches are shuffled; validation order stays fixed for reproducible evaluation.
train_loader = DataLoader(train_data, shuffle=True, **_loader_opts)
test_loader = DataLoader(val_data, shuffle=False, **_loader_opts)

def test(model, test_loader):
    model.eval()
def evalute(model, loader, device=None):
    """Compute top-1 classification accuracy of *model* over *loader*.

    Prints the dataset size and the number of correct predictions as it goes.

    Args:
        model: a ``torch.nn.Module`` whose output has class scores along
            dim 1 (the prediction is ``argmax(dim=1)``).
        loader: iterable of ``(inputs, targets)`` batches; ``loader.dataset``
            must support ``len()``.
        device: device to run inference on. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        The fraction of samples whose predicted class equals the target,
        as a float; 0.0 when the dataset is empty.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Move the model once, onto the same device as the inputs, so evaluation
    # also works on CPU-only machines (no unconditional .cuda() call).
    model.to(device)
    model.eval()  # inference mode: affects layers such as Dropout/BatchNorm
    total = len(loader.dataset)
    print(total)
    correct = 0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # argmax returns the index of the largest score, i.e. the predicted class
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
    print(correct)
    if total == 0:
        return 0.0  # accuracy is undefined on an empty set; avoid ZeroDivisionError
    return correct / total
# model_ft = initital_model(model_name, num_classes, feature_extract=True)

# Number of output classes; must match the head the checkpoint was trained with.
num_classes = 3

# ResNet-50 with a Sequential(Dropout, Linear) head, so the head's state-dict
# keys are 'fc.1.weight' / 'fc.1.bias'.
model = resnet50()
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.2),  # regularization to reduce overfitting
    nn.Linear(in_features=num_ftrs, out_features=num_classes),
)
model_path = '/data/classify/checkpoint/el_model_355_9941_10000.pth'
# Alternative for a model whose fc is a plain Linear: strip the Sequential
# index from the checkpoint's head keys ('fc.1.*' -> 'fc.*').
# model.load_state_dict({k.replace('fc.1','fc'):v for k,v in torch.load(model_path).items()})
# map_location='cpu' deserializes tensors onto the CPU regardless of the
# device they were saved from, so loading does not require CUDA.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
print(evalute(model, test_loader))


# PyTorch: measuring model accuracy / overfitting (pytorch统计模型精度_过拟合)