文章目录
- 前言
- 一、Dataset、DataLoader是什么?
- 二、如何定义Dataset?
- 1.定义 Dataset
- 三、如何使用DataLoader?
- 1. 使用Dataloader加载数据集
- 四、可视化源数据
- 五、完整代码
- 参考
前言
深度学习初入门小白,技艺不精,写下笔记记录自己的学习过程。欢迎评论区交流提问,力所能及之问题,定当毫无保留之相授。
一、Dataset、DataLoader是什么?
Dataset:是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中。
Dataloader:通过DataLoader这个函数,我们在加载数据集时候,批次读取数据及多线程并行处理,这样可以加快我们读取数据集的速度。
二、如何定义Dataset?
Dataset类是Pytorch中数据集加载类中应该继承的父类。通常包括这三部分:
1.*def __init__(self)*
2.*def __getitem__(self, index):*
3.*def __len__(self):*
其中父类中的两个私有成员函数,__len__和__getitem__必须被重载!
1.定义 Dataset
#root1和root2分别为训练集,验证集存放图片路径及标签的txt路径
root1 = r"C:\Users\asus\Desktop\mstar_classification\mstar\train.txt"
root2 = r"C:\Users\asus\Desktop\mstar_classification\mstar\val.txt"
# 1、构建数据集类
class Mydata(Dataset):
# __init__
# 该函数可以包含多个参数,如数据的读取路径和对数据的处理设置等一系列设定
# txt:存放着图片数据的路径和标签信息,words[0]为图片的路径,words[1]为图片的标签,如下图所示。(txt需要事先生成,如何生成先挖个坑)
# imgs:按行读取txt,并依次存放到列表中
# transform为:图片数据增强,下文中会讲
def __init__(self, txt, transform=None, target_transform=None):
super(Mydata, self).__init__()
imgs = []
fh = open(txt, 'r')
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1]))) # imgs中包含有图像路径和标签
self.txt = txt
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
# __getitem__
# 接收一个index,然后返回图片路径和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。
# 在本代码中,这个list为imgs[]
# 图片打开方式为Image.open,三通道RGB格式。若数据集图片为单通道,可在transform中添加transforms.Grayscale(1)函数。
def __getitem__(self, index):
fn, label = self.imgs[index]
img = Image.open(os.path.join(self.txt[:-4], fn))#self.txt[:-4],下文加载txt时,路径中不需要有后缀,所以去掉.txt四个字符
if self.transform is not None:
img = self.transform(img)
return img, label
#__len__
#返回样本的总数量, 该方法提供了dataset的大小
def __len__(self):
return len(self.imgs)
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ColorJitter(),
transforms.Grayscale(1), transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])])
test_transform = transforms.Compose([transforms.Grayscale(1), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
train_data = Mydata(txt=root1, transform=train_transform)
test_data = Mydata(txt=root2, transform=test_transform)
txt中存放着图片的路径及标签
三、如何使用DataLoader?
该函数的作用是将数据整理成一个batch,即根据batch_size的大小一次性在数据集中取出batch_size个数据。例如数据集中有100条数据,batch_size的值为20,则每次在100条数据中取出20条数据。
torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
# dataset: 加载torch.utils.data.Dataset对象数据,即为上文中的train_data和test_data
# batch_size: 每个batch的大小
# shuffle:是否对数据进行打乱
# drop_last:是否对无法整除的最后一个datasize进行丢弃
# um_workers:表示加载的时候子进程数,一般GPU使用
1. 使用Dataloader加载数据集
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)
四、可视化源数据
examples = enumerate(train_loader)
batch_idx, (examples_data, examples_targets) = next(examples)
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i+1)
plt.tight_layout()#自动调整子图参数,使之填充满整个图像区域
plt.imshow(examples_data[i][0], interpolation='none')
plt.title("Category:{}".format(examples_targets[i]))
plt.xticks([])
plt.yticks([])
plt.show()
五、完整代码
注意:
1.数据集的路径需要改成自己的
2.前提需要生成相应的txt文件
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import os
root1 = r"C:\Users\asus\Desktop\mstar_classification\mstar\train.txt"
root2 = r"C:\Users\asus\Desktop\mstar_classification\mstar\val.txt"
# 1、构建数据集
class Mydata(Dataset):
def __init__(self, txt, transform=None, target_transform=None):
super(Mydata, self).__init__()
self.txt = txt
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1]))) # imgs中包含有图像路径和标签
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index]
img = Image.open(os.path.join(self.txt[:-4], fn))
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)
# 2.数据增强、加载数据
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ColorJitter(),
transforms.Grayscale(1), transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])])
test_transform = transforms.Compose(
[transforms.Grayscale(1), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
# 是被封装进DataLoader里,实现该方法封装自己的数据和标签
train_data = Mydata(txt=root1, transform=train_transform)
test_data = Mydata(txt=root2, transform=test_transform)
# DataLoader被封装入DataLoader里,实现该方法达到数据的划分
# train_data 和test_data包含多有的训练与测试数据,调用DataLoader批量加载
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)
# 3.可视化源数据
examples = enumerate(train_loader)
batch_idx, (examples_data, examples_targets) = next(examples)
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout() # 自动调整子图参数,使之填充满整个图像区域
plt.imshow(examples_data[i][0], interpolation='none')
plt.title("Category:{}".format(examples_targets[i]))
plt.xticks([])
plt.yticks([])
plt.show()