使用pytorch导入自己的数据有两种方法:

第一种:使用torchvision工具包中的datasets.ImageFolder(该方法较为简单)
第二种:使用torch.utils.data.Dataset,自定义导入数据的方式(需要根据不同情况编写代码)

第一种:torchvision.datasets.ImageFolder

要求:专门对于分类问题,将不同标签的图片分别放在不同的文件夹下,如图(将猫狗的图片分别放在两个不同的文件夹下),cat和dog文件夹放在data文件夹下。

pytorch导入本地数据集 pytorch导入自己的数据集_pytorch导入本地数据集

dataset = torchvision.datasets.ImageFolder('path')  # path:data文件夹的路径

第二种:自定义读取方式

要求:没有要求,可以是分类问题,也可以是回归问题(例如输入和输出同为图片)

需要自定义一个Dataset

from PIL import Image
from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset):

    def __init__(self, data_dir, transform=None):
        self.imgs = self.get_imgs(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        img_path, label = self.imgs[index]
        img = Image.open(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)
    
    def get_images(data_dir):
    	imgs = []
    	for root, dirs, _ in os.walk(data_dir):     # dirs 为各类名
    		for sub_dirs in dirs:
    			img_names = os.listdir(os.path.join(root, sub_dir))  # 图片路径
    			for i in range(len(img_names)):
    				img_name = img_names[i]    # 图片名
    				path_img = os.path.join(root, sub_dir, img_name)
    				imgs.append((path_img, int(dirs)))
trainset = MyDataset(train_dir,transforms)
trainloader = DataLoader(trainset, batch_size=1)

整个代码分三步:

  1. 需要自己先定义一个类,继承torch.utils.data.Dataset,并初始化参数:主要为设置图片的路径和预处理方法
class MyDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.imgs = self.get_imgs(data_dir)
        self.transform = transform

data_dir:图片保存的位置
transform:图像预处理方法(可以看博主的博客transforms了解)

  1. 自定义读取文件的路径
    创建一个空的list,将输入图片的路径和输出图片的路径以tuple的形式逐个存入。
    在本例中,图片以输入1,标签1,输入2,标签2,…的形式保存的。
def get_images(data_dir):
    	imgs = []      # 创建一个空的list
    	for root, dirs, _ in os.walk(data_dir):     # 得到data_dir文件夹下所有的文件名(得到的dirs 为各类名)
    		for sub_dirs in dirs: 
    			img_names = os.listdir(os.path.join(root, sub_dir))  # 获得文件夹下所有图片路径
    			for i in range(len(img_names)//2):
    				img_input_name = img_names[i]    # 提取一个input图片名
    				img_label_name = img_name[i+1]  # 提取一个label图片名
    				path_img_1 = os.path.join(root, sub_dir, img_name) # 获得图片路径
    				path_img_2 = os.path.join(root, sub_dir, img_name) # 获得图片路径
    				imgs.append((path_img_1, path_img_2))

3.定义getitem,逐个读入图片
getitem为父类torch.utils.data.Dataset已经定义好的,它会逐个进行index=0,1,2,…。
只需要打开图片,进行图片预处理后,return即可。
定义len,返回样本数。

def __getitem__(self, index):
        img_path, label = self.imgs[index]
        img = Image.open(img_path)    # 打开图片
        if self.transform is not None:
            img = self.transform(img)    # 图片预处理
        return img, label

    def __len__(self):
        return len(self.imgs)

补充知识点:

DataLoader

torch.utils.data.DataLoader:构建可迭代的数据装载器

DataLoader(dataset, batch_size=1, shuffle=False, num_works=0)

dataset:Dataset类,决定数据从哪儿读取及如何读取
batch_size:批大小
shuffle:每个epoch是否乱序
num_works:是否多进程读取数据

Dataset

torch.utils.data.Dataset:所有自定义的Dataset需要继承它,并且复写

class Dataset(object):
	def __init__(self):
		pass
	def __getitem__(self, index):
		pass
	def __len__(self, other):
		pass

len:返回数据集的大小
getitem:接受一个样本,返回一个索引