基于小土堆学习
如何把数据集和Transform结合袭来
https://pytorch.org/
上述网址是pytorch的官网
这里会有详细的使用介绍
下述是对图像处理的专门文档
单击后可查看详细介绍
选择CIFAR10数据集
CIFAR10 数据集是一个广泛使用的计算机视觉数据集,包含了60000张32x32的彩色图像,这些图像分为10个类别,每个类别6000张图像。这些数据集被分为50000张训练图像和10000张测试图像。
参数解释如下:
- -root(str或pathlib.Path):数据集的根目录,其中应存在cifar-10-batches-py目录,或者如果设置download为True,则会在此目录下下载并保存数据集。
- -train(bool,可选):如果为True,则从训练集创建数据集;否则,从测试集创建数据集。
- -transform(callable,可选):一个函数/变换,它接受一个PIL图像并返回变换后的版本。例如,transforms.RandomCrop。
- -target_transform(callable,可选):一个函数/变换,它接受目标(标签)并对其进行变换。
- -download(bool,可选):如果为True,则从互联网下载数据集并将其放在根目录中。如果数据集已经下载,则不会再次下载。
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./CIFAR",train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR",train=False,download=True)
#下载训练集和测试机print(test_set[0])#获取数据类型
print("test_set.classes",test_set.classes)#获取分类目标img,target = test_set[0]
print("img:",img)
print("target:",target)
#输出结果target: 3,对应类别0,1,2,3;也就是当前类别是猫cat
print("test_set.classesp[target]当前类型为",test_set.classes[target])
img.show()
运行结果为
C:\Anaconda3\envs\pytorch_test\python.exe H:\Python\Test\P10_dataset_transforms.py
Files already downloaded and verified
Files already downloaded and verified
(<PIL.Image.Image image mode=RGB size=32x32 at 0x21F676692D0>, 3)
test_set.classes ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
img: <PIL.Image.Image image mode=RGB size=32x32 at 0x21F6A68E560>
target: 3
test_set.classesp[target]当前类型为 cat进程已结束,退出代码0
数据集全部转换为tensor数据类型
import torchvisiondataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor() ])train_set = torchvision.datasets.CIFAR10(root="./CIFAR",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR",train=False,transform=dataset_transform,download=True)
#transform=dataset_transform,将数据集中的每个数据都转换为Tensor格式
#下载训练集和测试机print(test_set[0])#获取数据类型
print("test_set.classes",test_set.classes)#获取分类目标img,target = test_set[0]
print("img:",img)
print("target:",target)
#输出结果target: 3,对应类别0,1,2,3;也就是当前类别是猫cat
print("test_set.classesp[target]当前类型为",test_set.classes[target])
输出结果为:
C:\Anaconda3\envs\pytorch_test\python.exe H:\Python\Test\P10_dataset_transforms.py
Files already downloaded and verified
Files already downloaded and verified
(tensor([[[0.6196, 0.6235, 0.6471, ..., 0.5373, 0.4941, 0.4549],[0.5961, 0.5922, 0.6235, ..., 0.5333, 0.4902, 0.4667],[0.5922, 0.5922, 0.6196, ..., 0.5451, 0.5098, 0.4706],...,[0.2667, 0.1647, 0.1216, ..., 0.1490, 0.0510, 0.1569],[0.2392, 0.1922, 0.1373, ..., 0.1020, 0.1137, 0.0784],[0.2118, 0.2196, 0.1765, ..., 0.0941, 0.1333, 0.0824]],[[0.4392, 0.4353, 0.4549, ..., 0.3725, 0.3569, 0.3333],[0.4392, 0.4314, 0.4471, ..., 0.3725, 0.3569, 0.3451],[0.4314, 0.4275, 0.4353, ..., 0.3843, 0.3725, 0.3490],...,[0.4863, 0.3922, 0.3451, ..., 0.3804, 0.2510, 0.3333],[0.4549, 0.4000, 0.3333, ..., 0.3216, 0.3216, 0.2510],[0.4196, 0.4118, 0.3490, ..., 0.3020, 0.3294, 0.2627]],[[0.1922, 0.1843, 0.2000, ..., 0.1412, 0.1412, 0.1294],[0.2000, 0.1569, 0.1765, ..., 0.1216, 0.1255, 0.1333],[0.1843, 0.1294, 0.1412, ..., 0.1333, 0.1333, 0.1294],...,[0.6941, 0.5804, 0.5373, ..., 0.5725, 0.4235, 0.4980],[0.6588, 0.5804, 0.5176, ..., 0.5098, 0.4941, 0.4196],[0.6275, 0.5843, 0.5176, ..., 0.4863, 0.5059, 0.4314]]]), 3)
test_set.classes ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
img: tensor([[[0.6196, 0.6235, 0.6471, ..., 0.5373, 0.4941, 0.4549],[0.5961, 0.5922, 0.6235, ..., 0.5333, 0.4902, 0.4667],[0.5922, 0.5922, 0.6196, ..., 0.5451, 0.5098, 0.4706],...,[0.2667, 0.1647, 0.1216, ..., 0.1490, 0.0510, 0.1569],[0.2392, 0.1922, 0.1373, ..., 0.1020, 0.1137, 0.0784],[0.2118, 0.2196, 0.1765, ..., 0.0941, 0.1333, 0.0824]],[[0.4392, 0.4353, 0.4549, ..., 0.3725, 0.3569, 0.3333],[0.4392, 0.4314, 0.4471, ..., 0.3725, 0.3569, 0.3451],[0.4314, 0.4275, 0.4353, ..., 0.3843, 0.3725, 0.3490],...,[0.4863, 0.3922, 0.3451, ..., 0.3804, 0.2510, 0.3333],[0.4549, 0.4000, 0.3333, ..., 0.3216, 0.3216, 0.2510],[0.4196, 0.4118, 0.3490, ..., 0.3020, 0.3294, 0.2627]],[[0.1922, 0.1843, 0.2000, ..., 0.1412, 0.1412, 0.1294],[0.2000, 0.1569, 0.1765, ..., 0.1216, 0.1255, 0.1333],[0.1843, 0.1294, 0.1412, ..., 0.1333, 0.1333, 0.1294],...,[0.6941, 0.5804, 0.5373, ..., 0.5725, 0.4235, 0.4980],[0.6588, 0.5804, 0.5176, ..., 0.5098, 0.4941, 0.4196],[0.6275, 0.5843, 0.5176, ..., 0.4863, 0.5059, 0.4314]]])
target: 3
test_set.classesp[target]当前类型为 cat进程已结束,退出代码0
继续用Tensorboard进行图片的显示:显示前20张图片
import torchvision
from torch.utils.tensorboard import SummaryWriterdataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor() ])train_set = torchvision.datasets.CIFAR10(root="./CIFAR",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR",train=False,transform=dataset_transform,download=True)
#transform=dataset_transform,将数据集中的每个数据都转换为Tensor格式
#下载训练集和测试机# print(test_set[0])#获取数据类型
# print("test_set.classes",test_set.classes)#获取分类目标
#
# img,target = test_set[0]
# print("img:",img)
# print("target:",target)
# #输出结果target: 3,对应类别0,1,2,3;也就是当前类别是猫cat
# print("test_set.classesp[target]当前类型为",test_set.classes[target])
write = SummaryWriter("logs")
for i in range(20):img, target = test_set[i]write.add_image("img", img, i)
write.close()
结果为:
C:\Anaconda3\envs\pytorch_test\python.exe H:\Python\Test\P10_dataset_transforms.py
Files already downloaded and verified
Files already downloaded and verified进程已结束,退出代码0
local的结果
**(pytorch_test) PS H:\Python\Test> tensorboard --logdir logs --port=6007
TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.17.1 at http://localhost:6007/ (Press CTRL+C to quit)
**
拖动可以查看20张图片