文章目录
- 一、Transforms的使用
- 二、Tensor数据类型
- 三、常见的Transforms
- 总结
一、Transforms的使用
torchvision中的transforms主要是对图片进行一些变换。
tranforms对应 tranforms.py 文件,里面定义了很多类,输入一个图片对象,返回经过处理的图片对象。
transforms.py就像一个工具箱,里面定义的各种类就像各种工具,图片就是输入对象,经过工具处理,输出期望的图片结果。
现在通过 transforms.ToTensor去看两个问题:
- 1、transforms该如何使用(python)
- 2、为什么我们需要 Tensor 数据类型
ToTensor功能是将 PIL Image 类型 或者numpy.ndarray类型的图片对象转换为 tensor类型。
使用Demo:
from torchvision import transforms
from PIL import Image
img_path = "testdata/train/ants_image/6743948_2b8c096dda.jpg"
img = Image.open(img_path)
print(img)
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
print(tensor_img)
所以使用transforms的方法就是 先实例化选中的类,然后用实例化的对象去处理图片就行。
二、Tensor数据类型
将第一节中的代码复制到 python 控制台,回车,可在右侧看到各种变量和对象的具体信息:
tensor 数据类型可以理解为包装了反向神经网络一些理论基础参数。在神经网络中,要将数据先转换为Tensor类型,再进行训练。
测试代码:
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
img_path = "testdata/train/ants_image/6743948_2b8c096dda.jpg"
img = Image.open(img_path)
writer = SummaryWriter("logs")
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
writer.add_image("Tensor_Image",tensor_img)
writer.close()
结果:
三、常见的Transforms
常用的输入图片对象的数据类型
- PIL : Image.open()
- tensor : ToTensor()
- ndarrays: cv.imread()
常用的Transform有:
- ToTensor() :将图片对象类型转为 tensor
- Normalize() :对图像像素进行归一化计算
- Resize():重新设置 PIL Image的大小,返回也是PIL Image格式
- Compose(): 输入为 transforms类型参数的列表,即
Compose([transforms参数1, transforms参数2], ...)
目的是将几个 transforms操作打包成一个,比如要先进行大小调整,然后进行归一化计算,返回tensor类型,则可以将 ToTensor、Normalize、Resize,按操作顺序输入到Compose中。
示例代码:
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import os
root_path = "hymenoptera_data/train/ants"
img_name = "7759525_1363d24e88.jpg"
img_path = os.path.join(root_path,img_name)
img = Image.open(img_path)
writer = SummaryWriter("logs")
# ToTensor
trans_totensor = transforms.ToTensor() # instantiation
img_tensor = trans_totensor(img)
writer.add_image("Tensor", img_tensor)
# Normalize
print(img_tensor[0][0][0])
trans_norm = transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
writer.add_image("Normalize", img_norm)
#Resize
print(img.size)
trans_resize = transforms.Resize((512,512))
img_resize = trans_resize(img) # return type still is PIL image
img_resize = trans_totensor(img_resize)
writer.add_image("Resize", img_resize)
# Compose - resize -2
trans_resize_2 = transforms.Resize(512)
tran_compose = transforms.Compose([trans_resize_2, trans_totensor])
img_resize2 = tran_compose(img)
writer.add_image("Compose", img_resize2)
writer.close()
结果:
总结
- 关注输入和输出类型
- 多看官方文档
- 关注方法需要什么参数