迁移学习就是下载别人训练好的神经网络结构作为自己的预训练网络,并用到自己感兴趣的地方去。在计算机视觉领域用的十分广泛。

本篇文章实现的是使用matlab自带的数据文件对AlexNet进行迁移学习(文中路径相关代码可能需要修改才能成功运行),并使用迁移学习后的网络对图片进行识别。

AlexNet是一个深达8层的卷积神经网络。您可以从ImageNet数据库加载经过一百万个图像训练的网络的预训练版本。预先训练的网络可以将图像分类为1000个对象类别,例如键盘,鼠标,铅笔和许多动物。结果,网络已经学会了针对各种图像的丰富的特征表示。网络的图像输入大小为227×227。

1首先下载alexnet网络

net = alexnet;

2导入数据集,并将其分为训练集和验证数据集,最后随机显示16张示例图像

path='E:\matlab\installpath\bin\MerchData';
imds = imageDatastore(path, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames'); 
% imds = imageDatastore('E:\matlab\installpath\bin\MerchData',...
%     'IncludeSubfolders',true, ...
%     'LabelSource','foldernames');
%解压缩新图像并将其加载为图像数据存储。imageDatastore根据文件夹名称自动标记图像,
% 并将数据存储为ImageDatastore对象。图像数据存储库使您可以存储大的图像数据(包括不适合内存的数据),
% 并在训练卷积神经网络期间有效地读取一批图像。
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7);
% 将数据分为训练和验证数据集。使用70%的图像进行训练,并使用30%的图像进行验证。
% splitEachLabel将images数据存储分为两个新的数据存储。
numTrainImages = numel(imdsTrain.Labels)
%获取文件总数
idx = randperm(numTrainImages,16);
%产生一个随机数idx(i)
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsTrain,idx(i));
    imshow(I)
end
%for函数是为了显示i个图片

3加载预训练网络:选取网络中除了最后三层的所有层(对于新的分类问题最后三层必须微调),最后三层替换为:完全连接的层,softmax层和分类输出层。根据新数据指定新的完全连接层的选项,将完全连接的层设置为与新数据中的类数相同的大小。numClasses为类别数。

net = alexnet;
analyzeNetwork(net);
% 使用analyzeNetwork以显示网络架构的交互式可视化和有关网络层的详细信息。
inputSize = net.Layers(1).InputSize;
% 第一层是图像输入层,要求输入图像的尺寸为227×227×3,其中3是颜色通道的数量。
layerTransfer = net.Layers(1:end-3);
%替换最终层,预训练网络的最后三层net配置为1000个班级。
% 对于新的分类问题,必须对这三层进行微调。从预训练的网络中提取除最后三层以外的所有层。
numClasses = numel(categories(imdsTrain.Labels))
% 通过将最后三层替换为完全连接的层,softmax层和分类输出层,
% 将层转移到新的分类任务。根据新数据指定新的完全连接层的选项。
% 将完全连接的层设置为与新数据中的类数相同的大小。
% 要在新层中学习比在传输层中学习更快,
% 请增加完全连接层的WeightLearnRateFactor和BiasLearnRateFactor值。
layers = [
    layerTransfer
    fullyConnectedLayer(numClasses,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20)
    softmaxLayer
    classificationLayer];
%layers是经过微调的网络

4.训练网络:首先对数据集进行增强,设置培训选项,然后在被微调过的网络上进行训练。

%training net
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);
% 网络要求输入图像的大小为227×227×3,但是图像数据存储区中的图像具有不同的大小。
% 使用增强图像数据存储区来自动调整训练图像的大小。指定要在训练图像上执行的其他增强操作:
% 沿垂直轴随机翻转训练图像,并在水平和垂直方向上将它们随机平移最多30个像素。数据增强有助于防止网络过度拟合和记忆训练图像的确切细节。

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
% %augmentedImageSource()用于数据增强,
% 第一个入口参数表示增强后数据的输出尺寸,第二个参数表示需要增强的数据,在此是对训练集进行增强。

options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',3, ...
    'Verbose',false, ...
    'Plots','training-progress');
% 指定培训选项。对于转移学习,请保留预培训网络的早期层(转移层重量)的功能。
% 要在转移层中减慢学习速度,将初始学习速率设置为小值。
% 在前一步中,您增加了完全连接层的学习速率因素,以加快新最终层的学习速度。
% 这种学习速率设置的组合仅在新层中产生快速学习,而在其他层中导致学习速度较慢。
% 在进行转学学习时,您不需要为那么多的时代进行培训。
% 一个时代是整个培训数据集的完整培训周期。指定微型批次大小和验证数据。
% 该软件在培训期间对网络进行每次迭代验证。ValidationFrequency
netTransfer = trainNetwork(augimdsTrain,layers,options);
%迁移学习后的网络是netTransfer

5分类验证图像:首先对增强后的数据集进行classify()产生标签和可能性,然后随机显示四个被测试样本进行验证,最后用mean函数计算验证集分类精度。

% 将进行增强后的数据进行分类,产生标签和可能性大小
[YPred,scores] = classify(netTransfer,augimdsValidation);

% 显示四个样本验证图像及其预测标签。
idx = randperm(numel(imdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label));
end

% 计算验证集的分类精度。准确性是网络正确预测的标签的分数
YValidation = imdsValidation.Labels;
accuracy = mean(YPred == YValidation)

6运行结果:随机显示验证集的几个数据及其标签和验证集的正确率:

matlab保存训练好的生成对抗网络 matlab训练的网络保存在哪里_神经网络

matlab保存训练好的生成对抗网络 matlab训练的网络保存在哪里_matlab保存训练好的生成对抗网络_02

7.使用训练前后网络对图片进行识别:

%识别单张图片,并展示结果
I = imread('C:\Users\admin\Desktop\download.jpg');
I = imresize(I,inputSize(1:2));
label1 = classify(net,I)
label2 = classify(netTransfer,I)
figure('numbertitle','off','name','net','color','white')
imshow(I)
title(label1)
figure('numbertitle','off','name','netTransfer','color','white')
imshow(I)
title(label2)

 

 

matlab保存训练好的生成对抗网络 matlab训练的网络保存在哪里_神经网络_03