这里和大家简单介绍一下批处理以及优化器optimizer
批处理
在处理数据的过程中,为了使得整个网络有着更好的学习效果并且不会有过多的资源的浪费,所以有批处理的概念,具体的原理不多说,直接上代码
1、导包
import torch
import torch.utils.data as Data
我们设置BATCH_SIZE = 5
,在不同的训练任务中可以根据自己的需求或者硬件的需求进行设置,较为常见的为8,16等
BATCH_SIZE = 5
随机生成两组数据,为了直观给画出来
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)
# plt.scatter(x,y)
# plt.show()
classtorch.utils.data.Dataset
表示Dataset
的抽象类。[1]
所有其他数据集都应该进行子类化。所有子类应该override__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
class torch.utils.data.TensorDataset(data_tensor, target_tensor)
包装数据和目标张量的数据集。
通过沿着第一个维度索引两个张量来恢复每个样本。
参数:
-
data_tensor (Tensor)
- 包含样本数据 -
target_tensor (Tensor)
- 包含样本目标(标签) -
classtorch.utils.data.DataLoader
数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。[2]
参数:
- dataset (Dataset) – 加载数据的数据集。
- batch_size (int, optional) –每个batch加载多少个样本(默认: 1)。
- shuffle (bool, optional) –设置为True时会在每个epoch重新打乱数据(默认: False).
- sampler (Sampler, optional) –定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
- num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
- collate_fn (callable, optional) –
- pin_memory (bool, optional) –
- drop_last (bool, optional) –如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)
torch_dataset = Data.TensorDataset(x , y )
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True
)
该版本不需要输入data_tensor,target_tensor,
否则会报错;同时对于Windows平台下不要使用多个子进程加载数据,否则会报错,Windows平台下多线程有点问题,具体原因不好说。
如果此时输出loader
会对应的为loader
加载的数据在物理硬件上的存储地址
通过批处理,来输出每一批的数据,来达到直观的效果
for epoch in range(3):
for step ,(batch_x,batch_y) in enumerate(loader):
print('Epoch:',epoch,'| step:',step,'|batch_x:',batch_x.numpy(),'|batch_y:',batch_y.numpy())
附完整代码
import torch
import torch.utils.data as Data
import matplotlib.pyplot as plt
BATCH_SIZE = 5
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)
# plt.scatter(x,y)
# plt.show()
# torch_dataset = Data.TensorDataset(data_tensor=x , target_tensor = y )
torch_dataset = Data.TensorDataset(x , y )
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True
)
# print(loader)
for epoch in range(3):
for step ,(batch_x,batch_y) in enumerate(loader):
print('Epoch:',epoch,'| step:',step,'|batch_x:',batch_x.numpy(),'|batch_y:',batch_y.numpy())