import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision import datasets
from transform import preprocess
# from model import initital_model,class_id2name
from models.resnet import *
# Dataset roots, read with torchvision's ImageFolder (labels come from sub-directory names).
TRAIN = '/media/dell/Elements/trainset/train'
VALID = '/media/dell/Elements/trainset/val'

# Both splits share the same `preprocess` pipeline imported from transform.py.
train_data = datasets.ImageFolder(root=TRAIN, transform=preprocess)
val_data = datasets.ImageFolder(root=VALID, transform=preprocess)

# Loader settings shared by both splits.
batch_size = 16
num_workers = 2
_loader_opts = dict(batch_size=batch_size, num_workers=num_workers)

# Training batches are shuffled; evaluation batches keep dataset order.
# NOTE(review): train_loader is not used by the evaluation code visible in this file,
# yet the train directory is still scanned at import time.
train_loader = DataLoader(train_data, shuffle=True, **_loader_opts)
test_loader = DataLoader(val_data, shuffle=False, **_loader_opts)
def test(model, test_loader):
model.eval()
def evalute(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` over every batch yielded by ``loader``.

    The model is switched to eval mode and moved to ``device`` once, before the loop,
    so it always lives on the same device as the input batches. (Previously the loop
    called ``model.cuda()`` on every batch even when inputs were sent to the CPU.)

    Args:
        model: classifier; the argmax over dim 1 of its output is taken as the
            predicted class (assumes output shaped (batch, num_classes) -- confirm).
        loader: iterable of ``(inputs, labels)`` batches whose ``.dataset``
            supports ``len()``.
        device: device to evaluate on; defaults to CUDA when available, else CPU.

    Returns:
        Fraction of correctly classified samples, or 0.0 for an empty dataset.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)  # move once, matching the device the inputs are sent to
    model.eval()

    correct = 0
    total = len(loader.dataset)
    print(total)
    # Gradients are never needed during evaluation; disable them for the whole loop.
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # argmax returns the index of the largest score, i.e. the predicted class.
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().item()
    print(correct)
    # Guard against ZeroDivisionError on an empty evaluation set.
    return correct / total if total else 0.0
# Build a ResNet-50 whose classifier head matches the checkpoint:
# dropout followed by a 3-way linear layer.
model = resnet50()
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.2),  # regularization against overfitting; inactive in eval mode
    nn.Linear(in_features=num_ftrs, out_features=3),
)

model_path = '/data/classify/checkpoint/el_model_355_9941_10000.pth'
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only
# machine. load_state_dict copies the values into the model's own parameters, and
# evalute() moves the model to its evaluation device before running it.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
# NOTE: because fc is a Sequential, the head's keys are 'fc.1.*'. An earlier variant
# remapped keys with k.replace('fc.1', 'fc') for checkpoints with a differently
# shaped head.
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict)
print(evalute(model, test_loader))