This example uses the PyTorch framework. The model stacks 13 convolutional layers, 2 pooling layers, and 15 fully connected layers. Why stack this many layers? Just for fun.
The FashionMNIST dataset contains 60,000 training images and 10,000 test images; each image is single-channel with size 28×28.
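Before training, it is worth confirming those shapes directly. A minimal sketch (reusing the same dataset root path as the script below; assumes torchvision can download or find the data there):

import torchvision
import torchvision.transforms as transforms

sample_set = torchvision.datasets.FashionMNIST(
    root='F:/dataset/minist', train=True, download=True,
    transform=transforms.ToTensor(),
)
img, label = sample_set[0]
print(img.shape, label)  # expected: torch.Size([1, 28, 28]) and an integer class id in 0..9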
import argparse
import torch
import torch.nn as nn  # alias torch.nn as nn
import torch.optim as optim
import torchvision  # data-loading helpers and common dataset interfaces
import torchvision.transforms as transforms
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # out 14
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # out 7
        )
        self.linear = nn.Sequential(
            nn.Linear(128 * 7 * 7, 128),  # fully connected layers
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10),
        )
    def forward(self, inputs):
        out = self.conv(inputs)
        out = out.view(out.size(0), -1)  # flatten to (batch, 128 * 7 * 7)
        logits = self.linear(out)
        return logits
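# Shape sanity check (a minimal sketch, commented out so the script runs unchanged):
# the two 2x2 max-pools take 28 -> 14 -> 7, so the flattened size is
# 128 * 7 * 7 = 6272, which matches the in_features of the first Linear layer.
# _m = CNN().eval()  # eval mode so BatchNorm accepts a batch of any size
# with torch.no_grad():
#     print(_m(torch.randn(2, 1, 28, 28)).shape)  # expected: torch.Size([2, 10])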
def train(model, device, train_loader, criterion, optimizer):
    model.train()
    total_loss, total_num = 0, 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()  # clear gradients from the previous step
        logits = model(data)
        loss = criterion(logits, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_num += 1  # count batches, so the return value is the mean batch loss
    return total_loss / total_num
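# Note: nn.CrossEntropyLoss applies log-softmax internally, which is why forward()
# returns raw logits and the model has no softmax layer. A quick standalone check
# (a hedged sketch, commented out):
# print(nn.CrossEntropyLoss()(torch.randn(4, 10), torch.tensor([0, 1, 2, 3])))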
def test(model, device, test_loader):
    model.eval()
    t_correct, t_sum = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            logits = model(data)
            pred = logits.argmax(dim=1)
            t_correct += torch.eq(pred, target).float().sum().item()
            t_sum += data.size(0)
    acc = t_correct / t_sum
    return acc
def main():
    fashionmnist_train = torchvision.datasets.FashionMNIST(
        root='F:/dataset/minist', train=True, download=True,
        transform=transforms.ToTensor(),
    )
    fashionmnist_test = torchvision.datasets.FashionMNIST(
        root='F:/dataset/minist', train=False, download=True,
        transform=transforms.ToTensor(),
    )
    print('Train/test set sizes:', len(fashionmnist_train), len(fashionmnist_test))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_iter = torch.utils.data.DataLoader(fashionmnist_train, batch_size=args.batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(fashionmnist_test, batch_size=args.test_batch_size)
    model = CNN().to(device)  # move the model to the GPU if one is available
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.001)  # model.parameters() yields the network's trainable parameters
    lr_sch = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9)
    for epoch in range(args.epochs):
        train_loss = train(model, device, train_iter, criterion, optimizer)
        acc = test(model, device, test_iter)
        print(f'epoch {epoch:2d}: loss = {train_loss:.6f}; acc={acc:.4f}; lr:{lr_sch.get_last_lr()[0]:.8f}')
        lr_sch.step()  # apply one step of decay: lr = lr * gamma
    if args.save_model:
        torch.save(model.state_dict(), "cnn.pt")
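# To reload the saved weights later (a minimal sketch; "cnn.pt" is the file saved above):
# model = CNN()
# model.load_state_dict(torch.load("cnn.pt", map_location="cpu"))
# model.eval()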
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='test')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                        help='input batch size for testing (default: 256)')
    parser.add_argument('--epochs', type=int, default=20, metavar='N',
                        help='number of epochs to train (default: 20)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='for saving the current model')
    # parse command-line arguments
    args = parser.parse_args()
    main()
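To see how heavy this "just for fun" stack actually is, the parameter count can be printed with a one-liner (a minimal sketch, reusing the CNN class above):

model = CNN()
print(sum(p.numel() for p in model.parameters()))  # total number of trainable parameters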
Run output:
Train/test set sizes: 60000 10000
epoch 0: loss = 0.680793; acc=0.8512; lr:0.00100000
epoch 1: loss = 0.413065; acc=0.8831; lr:0.00090000
epoch 2: loss = 0.375292; acc=0.8931; lr:0.00081000
epoch 3: loss = 0.352440; acc=0.8852; lr:0.00072900
epoch 4: loss = 0.342234; acc=0.8936; lr:0.00065610
epoch 5: loss = 0.316583; acc=0.8845; lr:0.00059049
epoch 6: loss = 0.304564; acc=0.8570; lr:0.00053144
epoch 7: loss = 0.282753; acc=0.9003; lr:0.00047830
epoch 8: loss = 0.259815; acc=0.9098; lr:0.00043047
epoch 9: loss = 0.237412; acc=0.9196; lr:0.00038742
epoch 10: loss = 0.218256; acc=0.9124; lr:0.00034868
epoch 11: loss = 0.205545; acc=0.9228; lr:0.00031381
epoch 12: loss = 0.187289; acc=0.9243; lr:0.00028243
epoch 13: loss = 0.173118; acc=0.9268; lr:0.00025419
epoch 14: loss = 0.156642; acc=0.9301; lr:0.00022877
epoch 15: loss = 0.140609; acc=0.9210; lr:0.00020589
epoch 16: loss = 0.128453; acc=0.9296; lr:0.00018530
epoch 17: loss = 0.116559; acc=0.9321; lr:0.00016677
epoch 18: loss = 0.104093; acc=0.9324; lr:0.00015009
epoch 19: loss = 0.089764; acc=0.9308; lr:0.00013509