一、Dataset和Dataloader
1. 定义
Dataset本质上就是一个抽象类,可以把数据封装成Python可以识别的数据结构。
Dataset类不能实例化,所以在使用Dataset的时候,我们需要定义自己的数据集类,也是Dataset的子类,来继承Dataset类的属性和方法。
Dataset可作为DataLoader的参数传入DataLoader,实现基于张量的数据预处理
Dataset和DataLoader都是用来帮助我们加载数据集的两个重要工具类:Dataset用来构造支持索引的数据集。在训练时需要在全部样本中拿出小批量数据参与每次的训练,因此我们需要使用DataLoader,即DataLoader是用来在Dataset里取出一组数据(mini-batch)供训练时快速使用的。
2. Dataset和Dataloader详解
- _init_():构造函数,初始化。
- _getitem_():_DataLoaderIter()类中有调用:
- _len_():调用len()函数:
# Dataset和Dataloader需要的方法和类,导入一下
from torch.utils.data import Dataset, DataLoader
class Myclass(Dataset) ##继承Dataset类
#首先需要初始化,根据创建实例的时候,需要运行的函数
def __init__(self): #可以传入一些参数,如数据路径,label路径等
def __getitem__(self, idx): # idx相当于一个编号,可以在该函数下对数据集中每条数据进行处理,如返回数据,label等
def __len__(self):# 可以通过len知道数据集长度
- collate_fn(self, batch):
在 PyTorch 中,数据集通常需要经过一些预处理才能用于训练模型。在预处理数据时,常常需要将多个样本组合成一个 mini-batch,并对这些样本进行 padding,以保证它们具有相同的维度。
在这种情况下,我们可以使用 DataLoader 类来加载数据集。DataLoader 类有一个可选参数 collate_fn,用于定义如何组合样本成 mini-batch。
collate_fn 是一个函数,它接受一个样本列表作为输入,将这些样本组合成一个 mini-batch,并返回该 mini-batch。这个函数可以根据数据集的不同而不同。通常情况下,我们需要将样本中的文本序列进行 padding,以使它们具有相同的长度。
对该函数的理解参考:pytorch中DataLoader中的collate_fn什么意思,例子 和pytorch中collate_fn函数的使用&如何向collate_fn函数传参
首先Dataloader会根据batch参数生成一个长度为batch值的列表,列表的值是myDataset()类中__getitem__()的参数,如果shuffle为True ,列表的值就是从0到len(data)中随机抽样索引。然后列表的索引值会依次送入__getitem__()方法,最终返回一个列表(index即idx)的数据,该列表数据会作为collate_fn 函数的默认参数传入,最终得到一个batch的数据。如下__getitem__返回(x,y)两项,则shape为(batch_size(一个batch大小), 2,…)。
loader = Dataloader(dataset, batch_size, shuffle, collate_fn, ...)
'''
dataset:要载入的数据集
batch_size:批大小,每个批中的样本数
shuffle:是否载入数据集时是否要随机选取(打乱顺序),True为打乱顺序,False为不打乱。布尔型,只能取None、True、False
num_workers:进程数。用来实现并行化加载数据
drop_last (bool, optional):默认false,是否舍弃最后不满足批大小的数据
......
'''
class Mydataset(Dataset):
def __init__(self):
super().__init__()
def __getitem__(self, idx):
return x, y
collate_fn 函数可以使用系统默认的也可以自己设计,非常灵活,也可以自定义取出一个batch数据的格式,如列表,字典等。
二、tensorboard的使用
参考博客:Pytorch中使用TensorBoard 首先导入包
import torch
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("Log_dir")
# SummaryWriter函数中有很多参数如(log_dir=None, comment='', purge_step=None, max_queue=10, flush_secs=120, filename_suffix='')
# log_dir (str):指定了数据保存的文件夹的位置,如果该文件夹不存在则会创建一个出来。如果没有指定的话,默认的保存的文件夹是./runs/现在的时间_主机名,例如:Feb04_22-42-47_Alienware,因此每次运行之后都会创建一个新的文件夹。
# comment (string):给默认的log_dir添加的后缀,如果我们已经指定了log_dir具体的值,那么这个参数就不会有任何的效果
# purge_step (int):TensorBoard在记录数据的时候有可能会崩溃,例如在某一个epoch中,进行到第T + X T+XT+X个step的时候由于各种原因(内存溢出)导致崩溃,那么当服务重启之后,就会从T TT个step重新开始将数据写入文件,而中间的X XX,即purge_step指定的step内的数据都被被丢弃。
# max_queue (int):在记录数据的时候,在内存中开的队列的长度,当队列慢了之后就会把数据写入磁盘(文件)中。
# flush_secs (int):以秒为单位的写入磁盘的间隔,默认是120秒,即两分钟。
# filename_suffix (string):添加到log_dir中每个文件的后缀。更多文件名称设置要参考
可能会使用到一些方法如:
writer.add_image()
writer.add_scalar()
1. add_scalar()
def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
"""Add scalar(标量) data to summary.
Args:
tag (string): Data identifier 对应标签
scalar_value (float or string/blobname): Value to save 对应Y轴
global_step (int): Global step value to record 对应X轴
walltime (float): Optional override default walltime (time.time())
with seconds after epoch of event
#使用该函数得到文件之后可以使用下面命令根据得到的地址在网页查看
tensorboard --logdir=事件文件所在文件夹名(地址)
#如:tensorboard --logdir=/home/work_nfs7/zxzhao/shixi/temp/train
若担心端口冲突,可通过如下命令指定端口:
tensorboard --logdir=文件夹名 --port=端口号
2. add_image()
def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
"""Add image data to summary.
Note that this requires the ``pillow`` package.
Args:
tag (string): Data identifier
img_tensor (torch.Tensor, numpy.array, or string/blobname): 图像数据,注意图像数据类型要符合要求
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time())
seconds after epoch of event
Shape:
img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as
corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``.
默认数据为3,H,W即(CHW)。若不符合则可用如dataformats='HWC'
# 想重新显示把tag改一下就好了