建立数据集
- 前言
- 1.要点总结
- 1. 将训练集、测试集、验证集的图片放在三个文件夹中,尽量保证这三个文件夹不在移动
- 2.将图片切割,生成真正的训练集
- *文件夹需要提前建立,要不会报错*
- 3.测试集、验证集同理
- 4.将切割好的图片数据集建立文件,包含图片位置和标签两个信息。
- 保存txt文件后一定要读一行测试一下文件路径是否正确,我当时忘记了'/'
- 此时数据集建立完成,但教程还没完,因为pytorch读取也很难受
- 5.需要建立一个类,读取数据集txt文件
- 6.重点事项
- !!!注意大坑,返回图片格式为[通道数, 长或宽,长或宽]如果用Image.open方法读取的图片维度为[ 长或宽,长或宽,通道数],是不能*卷积*的
- 返回三个变量的原因
- 图片数据必须为float格式
- plot显示图片数据,需要先imshow再show,直接show无法正常显示。
- pytorch自己会转换独热编码,你不能输入独热编码,不然会报错
- 输入必须为数字,不能是字符‘a’
- 有多少种结果转换多少个数字
前言
新学cnn,尝试建立数据集。这是过了几天的描述,时间再长估计就要忘记了。
参考文章Pytorch学习(三)定义自己的数据集及加载训练 我是建立图片数据集,图片为四位验证码。
图片已经准备好,图片名称为图中的四位验证码。
如
图片名称为0S22.jpg
1.要点总结
1. 将训练集、测试集、验证集的图片放在三个文件夹中,尽量保证这三个文件夹不在移动
比例我忘记了,大概是6:2:2吧,以后查到了我再改
2.将图片切割,生成真正的训练集
由于四位验证码的可能性太大,节点太多(26+10)**4 = 1679616个节点,仅最后一层就这么多节点,所以将图片切割,生成单个字母的photo_cut,这样训练最后一层仅36个节点,大大减少计算量,最后检测四位验证码也就0.2秒(循环检测的,以后有时间学一下进程,直接四个进程检测)
将test_photo的图片保存在test_cut_photo中
文件夹需要提前建立,要不会报错
from PIL import Image
import os
def cut_photo(file_name_from, name, save_name):
"""
输入图片名称,按固定位置将图片分割为四部分,以便更好地识别字母/数字
:param name: 图片中字符名称
:return:
"""
img = Image.open(str(file_name_from)+"/" + str(name) + ".jpg")
if img.mode == "P":
img = img.convert('RGB')
# name = '1D61'
list_x_start = [2, 18, 30, 44]
list_x_end = [16, 32, 44, 58]
for i in range(4):
x_start = list_x_start[i]
x_end = list_x_end[i]
y_start = 4
y_end = 26
box = [x_start, y_start, x_end, y_end]
# print(box)
# box1 = (0, 0, 16, 28)
# box2 = (16, 0, 32, 28)
# box3 = (32, 0, 46, 28)
# box4 = (46, 0, 60, 28)
region = img.crop(box)
# region.show()
region.resize((22, 22))
file_name = str(save_name)+'/' + str(name) + str(i) + '_' + str(name[i]) + '.jpg'
# print(file_name)
region.save(file_name)
def get_photo_name(file_name_from):
"""
获取photo文件下图片的名字
:return:图片名称列表
"""
photo_name_chr = os.listdir('./' + str(file_name_from))
list_photo_name = []
return photo_name_chr
if __name__ == '__main__':
list_photo_name = get_photo_name('test_photo')
for name in list_photo_name:
cut_photo('test_photo', name[0:4], 'test_cut_photo')
直接将切割后的图片保存,不需要返回
切割后的图片名称保存格式[四位验证码]+[切割字符在的位置]+[切割字符的名称]
3.测试集、验证集同理
可以建立列表一次循环解决。
4.将切割好的图片数据集建立文件,包含图片位置和标签两个信息。
一开始学习时pytorch官方文档给的数据集是cv文件,我以为只能识别cv文件,其实什么文件都可以,我就是用的txt文件,excel应该也可以
注意一点,图片位置应该和标签在同一行,中间加个空格或者’,'逗号随意。
D:/python/report_new_2233/photo_cut/00163_6.jpg 6
D:/python/report_new_2233/photo_cut/00170_0.jpg 0
D:/python/report_new_2233/photo_cut/00171_0.jpg 0
由于有三个训练集、测试集、验证集,所以需要三个txt文件。
import os
from PIL import Image
from photo_cut import *
def make_name(file_name_from, img_file_txt):
"""
生成图片、标签文件对应的文件
:return:
"""
photo_name_chr = os.listdir('./' + str(file_name_from))
# print(photo_name_chr)
with open(img_file_txt, 'w') as fp:
for i in photo_name_chr:
root_name = 'D:/python/report_new_2233/' + str(file_name_from) + '/' + str(i)
fp.write(root_name)
fp.write(' ')
fp.write(i[-5])
fp.write('\n')
def show_photo():
"""
查看图片路径是否正确
:return:
"""
with open('name.txt', 'r') as fp2:
a = fp2.readline().split()
# print(a[0])
img = Image.open(a[0])
img.show()
保存txt文件后一定要读一行测试一下文件路径是否正确,我当时忘记了’/’
此时数据集建立完成,但教程还没完,因为pytorch读取也很难受
5.需要建立一个类,读取数据集txt文件
类中包含三中方法(隐藏方法,是叫这个吧)
主要是第三个方法__getitem__(self, index),他负责读取图片地址和标签,返回图片张量和标签张量。
6.重点事项
!!!注意大坑,返回图片格式为[通道数, 长或宽,长或宽]如果用Image.open方法读取的图片维度为[ 长或宽,长或宽,通道数],是不能卷积的
想转换维度,要么直接Tensor转换,要么转置
class Mydataset(Dataset):
"""
TypeError: show() takes 1 positional argument but 2 were given
现在直接读取6张图片,为什么
"""
def __init__(self, txt_name, train=True, transform=None, target_tranform=None, loader=default_loader, identify=None):
super(Mydataset, self).__init__()
# self.img_label =
line_list = []
if train:
file_name = r'D:/python/pytorch_learn/install/name.txt'
else:
file_name = r'D:/python/pytorch_learn/install/test_ph.txt'
if identify:
file_name = r'D:/python/pytorch_learn/install/CESHI.TXT'
print('这是验证集')
with open(file_name, 'r') as fp:
lines = fp.readlines()
for i in lines:
line = i.split()
line_list.append(line)
self.img = line_list
self.transform = transform
self.target_transform = target_tranform
self.loader = loader
def __len__(self):
return len(self.img)
def __getitem__(self, index):
path, label = self.img[index]
img = self.loader(path) # 读取出来维度为3, 22, 14 ?? 无法正常显示图片
# img = Image.open(path)
img = np.array(img, dtype=np.float32)
# # img = np.array(img, dtype=np.float32).reshape(14, 22, 3)
# img = read_image(path) # 读取数据不需要transform转换
# print(img.size()) # torch.Size([3, 22, 14]) ————需要转至
# img = np.array(img) # 924
# print(img.size)
label_one_hot = one_hot(label) # 不是独热编码,只是字符都变成了数字
if self.transform:
# img = torch.from_numpy(img) # 所以numpy转Tensor唯独不变
# print(img.size())
img = self.transform(img) # Totensor维度改变 torch.Size([3, 22, 14])
# print(img.size())
# img = img.squeeze(0) # 只能去除维度=1的维度
# img = img.T # 现在可以正常显示,但是图片完全翻转了
# print(img.size())
# 转置前维度为(3, 22, 14), 转置后torch.Size([14, 22, 3])
# print(label, img)
return img, label_one_hot, path
if __name__ == "__main__":
# mydata = Mydataset('name.txt')
# print(mydata.__getitem__(5))
train_data = Mydataset(
txt_name='name.txt', # 暂时未使用
train=True,
transform=ToTensor()
)
# data_train = DataLoader(train_data, batch_size=64)
img, label, path = train_data[6]
photo = Image.open(path)
plt.title(label)
plt.imshow(photo)
plt.show()
返回三个变量的原因
正常情况下 __getitem__应该只返回图片张量和标签,但是我的图片经过Tensor转换后无法正常显示,因为图片Tensor转换后维度为[通道数, 长或宽,长或宽],正常图片显示维度为[ 长或宽,长或宽,通道数]
2021.10.22补充
写rnn分类图片的时候突然发现的代码,可以直接显示ToTensor转换后的图片
利用dataset的数据集提取可以做到
train_data = torchvision.datasets.FashionMNIST(
root='D\python\pytorch_learn\install\data',
train=True,
transform=torchvision.transforms.ToTensor(),
download=False,
)
plt.imshow(train_data.data[0].numpy(), cmap='gray')
plt.show()
这个代码可以将显示Tensor的维度的代码
torch.Size([1, 28, 28])
# 一般读取图片的格式为[ 28, 28, 1]即通道数在尺寸后面
图片数据必须为float格式
直接读取的图片像素为整数,需要转换为np的同时调整为float再tranform要不然无法正常读取。
这一点最坑的是他不显示问题,就是预测不出来。不管什么图片,预测结果是’RRRR’。就是找不到问题。
plot显示图片数据,需要先imshow再show,直接show无法正常显示。
这点我还未找到原因,有时间再说。
或者也可以使用
plt.ion()
打开自动显示图像模式(我自己还没试)
也可以参考这篇文章链接: matplotlib中plot.show()不显示图片的问题.
pytorch自己会转换独热编码,你不能输入独热编码,不然会报错
输入必须为数字,不能是字符‘a’
但是也不能输入字符,必须将输入转换为数字,因为’a’无法进行矩阵运算,也无法转换为独热编码,需要你自己将字符转换为数字,怎么转化随意,比如我将’a’作为11,这样0-9,a-z转换为0-36。
有多少种结果转换多少个数字
一开始我想讲字符转为ASCii码,毕竟方便,z转换为96,此时生成的独热编码维度为[36, 96],大大浪费了计算资源。
z转换为36,独热编码维度为[36, 36],就很舒服。