从numpy里加载_PyTorch强化:01.PyTorch 数据加载和处理

PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性。

1.下载安装包

  • scikit-image:用于图像的IO和变换
  • pandas:用于更容易地进行csv解析
from __future__ import print_function, divisionimport osimport torchimport pandas as pd #用于更容易地进行csv解析from skimage import io, transform #用于图像的IO和变换import numpy as npimport matplotlib.pyplot as pltfrom torch.utils.data import Dataset, DataLoaderfrom torchvision import transforms, utils# 忽略警告import warningswarnings.filterwarnings("ignore")plt.ion() # interactive mode

2.下载数据集

从此处下载数据集, 数据存于“data / faces /”的目录中。这个数据集实际上是imagenet数据集标注为face的图片当中在 dlib 面部检测 (dlib’s pose estimation) 表现良好的图片。我们要处理的是一个面部姿态的数据集。也就是按如下方式标注的人脸:

464cfebbaf6cb0df7018f45bf094d343.png

2.1 数据集注释

数据集是按如下规则打包成的csv文件:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y0805personali01.jpg,27,83,27,98, ... 84,1341084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

3.读取数据集

将csv中的标注点数据读入(N,2)数组中,其中N是特征点的数量。读取数据代码如下:

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')n = 65img_name = landmarks_frame.iloc[n, 0]landmarks = landmarks_frame.iloc[n, 1:].as_matrix()landmarks = landmarks.astype('float').reshape(-1, 2)print('Image name: {}'.format(img_name))print('Landmarks shape: {}'.format(landmarks.shape))print('First 4 Landmarks: {}'.format(landmarks[:4]))

3.1 数据结果

输出:

Image name: person-7.jpgLandmarks shape: (68, 2)First 4 Landmarks: [[32. 65.][33. 76.][34. 86.][34. 97.]]

4 编写函数

写一个简单的函数来展示一张图片和它对应的标注点作为例子。

def show_landmarks(image, landmarks):"""显示带有地标的图片""" plt.imshow(image) plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r') plt.pause(0.001) # pause a bit so that plots are updatedplt.figure()show_landmarks(io.imread(os.path.join('data/faces/', img_name)), landmarks)plt.show()

函数展示结果如下图所示:

01588fa69375bb6cfc19d3591a4fd701.png

5.数据集类

torch.utils.data.Dataset是表示数据集的抽象类,因此自定义数据集应继承Dataset并覆盖以下方法

  • __len__ 实现 len(dataset) 返还数据集的尺寸。
  • __getitem__用来获取一些索引数据,例如 dataset[i] 中的(i)。

5.1 建立数据集类

为面部数据集创建一个数据集类。我们将在 __init__ 中读取csv的文件内容,在 __getitem__中读取图片。这么做是为了节省内存 空间。只有在需要用到图片的时候才读取它而不是一开始就把图片全部存进内存里。

我们的数据样本将按这样一个字典{'image': image, 'landmarks': landmarks}组织。 我们的数据集类将添加一个可选参数transform 以方便对样本进行预处理。下一节我们会看到什么时候需要用到transform参数。 __init__方法如下图所示:

class FaceLandmarksDataset(Dataset):"""面部标记数据集."""def __init__(self, csv_file, root_dir, transform=None):""" csv_file(string):带注释的csv文件的路径。 root_dir(string):包含所有图像的目录。 transform(callable, optional):一个样本上的可用的可选变换 """ self.landmarks_frame = pd.read_csv(csv_file) self.root_dir = root_dir self.transform = transformdef __len__(self):return len(self.landmarks_frame)def __getitem__(self, idx): img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0]) image = io.imread(img_name) landmarks = self.landmarks_frame.iloc[idx, 1:] landmarks = np.array([landmarks]) landmarks = landmarks.astype('float').reshape(-1, 2) sample = {'image': image, 'landmarks': landmarks}if self.transform: sample = self.transform(sample)return sample

6.数据可视化

实例化这个类并遍历数据样本。我们将会打印出前四个例子的尺寸并展示标注的特征点。 代码如下图所示:

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/')fig = plt.figure()for i in range(len(face_dataset)): sample = face_dataset[i]print(i, sample['image'].shape, sample['landmarks'].shape) ax = plt.subplot(1, 4, i + 1) plt.tight_layout() ax.set_title('Sample #{}'.format(i)) ax.axis('off')show_landmarks(**sample)if i == 3: plt.show() break

数据结果:

6.1 图形展示结果

69050814e6b75cfaa13076b11a97d57d.png

6.2 控制台输出结果:

0 (324, 215, 3) (68, 2)1 (500, 333, 3) (68, 2)2 (250, 258, 3) (68, 2)3 (434, 290, 3) (68, 2)

7.数据变换

通过上面的例子我们会发现图片并不是同样的尺寸。绝大多数神经网络都假定图片的尺寸相同。因此我们需要做一些预处理。让我们创建三个转换:

  • Rescale:缩放图片
  • RandomCrop:对图片进行随机裁剪。这是一种数据增强操作
  • ToTensor:把numpy格式图片转为torch格式图片 (我们需要交换坐标轴).

我们会把它们写成可调用的类的形式而不是简单的函数,这样就不需要每次调用时传递一遍参数。我们只需要实现__call__方法,必 要的时候实现 __init__ 方法。我们可以这样调用这些转换:

tsfm = Transform(params)transformed_sample = tsfm(sample)

观察下面这些转换是如何应用在图像和标签上的。

class Rescale(object):"""将样本中的图像重新缩放到给定大小。. Args: output_size(tuple或int):所需的输出大小。 如果是元组,则输出为 与output_size匹配。 如果是int,则匹配较小的图像边缘到output_size保持纵横比相同。 """def __init__(self, output_size): assert isinstance(output_size, (int, tuple)) self.output_size = output_sizedef __call__(self, sample): image, landmarks = sample['image'], sample['landmarks'] h, w = image.shape[:2]if isinstance(self.output_size, int):if h > w: new_h, new_w = self.output_size * h / w, self.output_sizeelse: new_h, new_w = self.output_size, self.output_size * w / helse: new_h, new_w = self.output_size new_h, new_w = int(new_h), int(new_w) img = transform.resize(image, (new_h, new_w)) # h and w are swapped for landmarks because for images, # x and y axes are axis 1 and 0 respectively landmarks = landmarks * [new_w / w, new_h / h]return {'image': img, 'landmarks': landmarks}class RandomCrop(object):"""随机裁剪样本中的图像. Args: output_size(tuple或int):所需的输出大小。 如果是int,方形裁剪是。 """def __init__(self, output_size): assert isinstance(output_size, (int, tuple))if isinstance(output_size, int): self.output_size = (output_size, output_size)else: assert len(output_size) == 2 self.output_size = output_sizedef __call__(self, sample): image, landmarks = sample['image'], sample['landmarks'] h, w = image.shape[:2] new_h, new_w = self.output_size top = np.random.randint(0, h - new_h) left = np.random.randint(0, w - new_w) image = image[top: top + new_h, left: left + new_w] landmarks = landmarks - [left, top]return {'image': image, 'landmarks': landmarks}class ToTensor(object):"""将样本中的ndarrays转换为Tensors."""def __call__(self, sample): image, landmarks = sample['image'], sample['landmarks'] # 交换颜色轴因为 # numpy包的图片是: H * W * C # torch包的图片是: C * H * W image = image.transpose((2, 0, 1))return {'image': torch.from_numpy(image),'landmarks': torch.from_numpy(landmarks)}

8.组合转换

接下来我们把这些转换应用到一个例子上。

我们想要把图像的短边调整为256,然后随机裁剪(randomcrop)为224大小的正方形。也就是说,我们打算组合一个Rescale和 RandomCrop的变换。 我们可以调用一个简单的类 torchvision.transforms.Compose来实现这一操作。具体实现如下图:

scale = Rescale(256)crop = RandomCrop(128)composed = transforms.Compose([Rescale(256),RandomCrop(224)])# 在样本上应用上述的每个变换。fig = plt.figure()sample = face_dataset[65]for i, tsfrm in enumerate([scale, crop, composed]): transformed_sample = tsfrm(sample) ax = plt.subplot(1, 3, i + 1) plt.tight_layout() ax.set_title(type(tsfrm).__name__)show_landmarks(**transformed_sample)plt.show()
  • 输出效果:
d6f882daa3e571d2ab32ef3eef2837c2.png

9.迭代数据集

让我们把这些整合起来以创建一个带组合转换的数据集。总结一下,每次这个数据集被采样时:

  • 及时地从文件中读取图片
  • 对读取的图片应用转换
  • 由于其中一步操作是随机的 (randomcrop) , 数据被增强了

我们可以像之前那样使用for i in range循环来对所有创建的数据集执行同样的操作。

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/', transform=transforms.Compose([Rescale(256),RandomCrop(224),ToTensor()]))for i in range(len(transformed_dataset)): sample = transformed_dataset[i]print(i, sample['image'].size(), sample['landmarks'].size())if i == 3: break
  • 输出结果:
0 torch.Size([3, 224, 224]) torch.Size([68, 2])1 torch.Size([3, 224, 224]) torch.Size([68, 2])2 torch.Size([3, 224, 224]) torch.Size([68, 2])3 torch.Size([3, 224, 224]) torch.Size([68, 2])

但是,对所有数据集简单的使用for循环牺牲了许多功能,尤其是:

  • 批量处理数据
  • 打乱数据
  • 使用多线程multiprocessingworker 并行加载数据。

torch.utils.data.DataLoader是一个提供上述所有这些功能的迭代器。下面使用的参数必须是清楚的。一个值得关注的参数是collate_fn, 可以通过它来决定如何对数据进行批处理。但是绝大多数情况下默认值就能运行良好。

dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=4)# 辅助功能:显示批次def show_landmarks_batch(sample_batched):"""Show image with landmarks for a batch of samples.""" images_batch, landmarks_batch =  sample_batched['image'], sample_batched['landmarks'] batch_size = len(images_batch) im_size = images_batch.size(2) grid_border_size = 2 grid = utils.make_grid(images_batch) plt.imshow(grid.numpy().transpose((1, 2, 0)))for i in range(batch_size): plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size, landmarks_batch[i, :, 1].numpy() + grid_border_size, s=10, marker='.', c='r') plt.title('Batch from dataloader')for i_batch, sample_batched in enumerate(dataloader):print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size()) # 观察第4批次并停止。if i_batch == 3: plt.figure()show_landmarks_batch(sample_batched) plt.axis('off') plt.ioff() plt.show() break
edfe3ac0ef4efbd43c93453edf3344e0.png
  • 输出
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

10.后记:torchvision

在这篇教程中我们学习了如何构造和使用数据集类(datasets),转换(transforms)和数据加载器(dataloader)。torchvision包提供了 常用的数据集类(datasets)和转换(transforms)。你可能不需要自己构造这些类。torchvision中还有一个更常用的数据集类ImageFolder。 它假定了数据集是以如下方式构造的:

root/ants/xxx.pngroot/ants/xxy.jpegroot/ants/xxz.png...root/bees/123.jpgroot/bees/nsdf3.pngroot/bees/asd932_.png

其中’ants’,bees’等是分类标签。在PIL.Image中你也可以使用类似的转换(transforms)例如RandomHorizontalFlip,Scale。利 用这些你可以按如下的方式创建一个数据加载器(dataloader) :

import torchfrom torchvision import transforms, datasetsdata_transform = transforms.Compose([ transforms.RandomSizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train', transform=data_transform)dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset, batch_size=4, shuffle=True, num_workers=4

2020未来杯AI挑战赛-图像赛道-语音赛道同时开启,30万大奖等你来挑战!

https://ai.futurelab.tv/tournament/6

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/557244.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redmine使用手册

一、Redmine简介 Redmine是基于ROR框架开发的一套跨平台项目管理系统,是项目管理系统的后起之秀,据说是源于Basecamp的ror版而来,支持多种数据库,除了和DotProject的功能大致相当外,还有不少自己独特的功能&#xff0…

swagger2maven依赖_Maven + SpringMVC项目集成Swagger

Swagger 是一个规范和完整的框架,用于生成、描述、调用和可视化 RESTful 风格的 Web 服务。总体目标是使客户端和文件系统作为服务器以同样的速度来更新。文件的方法,参数和模型紧密集成到服务器端的代码,允许API来始终保持同步。作用&#x…

IDEA2019版最新配置SVN及上传教程-超详细图文详解

IDEA2019版配置SVN图文详解 1. 查看svn仓库 调出svn视图: 连接svn服务器: 连接后效果如下: 补充:如果输入正确的连接地址后出现错误—系统找不到指定的文件 请到设置中检查(File | Settings | Version Control | Subversion)SVC客户端路径…

dubbo:reference、dubbo:service和@Service、@Reference使用情况

以前在同一模块中Spring依赖注入&#xff0c;可以通过Service和Autowired Dubbo是远程服务调用&#xff0c;消费方需要注入提供方定义的接口实例&#xff0c;可以通过xml配置 dubbo:reference、dubbo:service <dubbo:service interface"fei.CustomerServices" …

SSM+Maven+Dubbo+Zookeeper简单项目实战以及易错注意点

最近为了熟悉Dubbo远程过程调用架构的使用&#xff0c;并结合SSMMaven整合了简单的一套项目实战 直接看项目结构图 各模块介绍 dubbo-common&#xff1a;存放项目需要的公众类&#xff0c;像查询模型、数据库实体模型等 dubbo-config&#xff1a;存放项目所需的公众配置文件&…

c++二叉树的层序遍历_leetcode 103. 二叉树的锯齿形层序遍历

按层次遍历&#xff0c;记录下对应节点的val和所在层&#xff0c;然后经过一定变换得到输出。python代码如下&#xff1a;# Definition for a binary tree node.# class TreeNode(object):# def __init__(self, x):# self.val x# self.left None# …

TCP和UDP的区别(Socket)

TCP和UDP区别 TCP和UDP编程区别 TCP编程的服务器端一般步骤是&#xff1a;   1、创建一个socket&#xff0c;用函数socket()&#xff1b;   2、设置socket属性&#xff0c;用函数setsockopt(); * 可选   3、绑定IP地址、端口等信息到socket上&#xff0c;用函数bind(); …

mysql open table_MySQL open table

背景&#xff1a;MySQL经常会遇到Too many open files&#xff0c;MySQL上的open_files_limit和OS层面上设置的open file limit有什么关系&#xff1f;源码中也会看到不同的数据结构&#xff0c;TABLE, TABLE_SHARE&#xff0c;跟表是什么关系&#xff1f;MySQL flush tables又…

JUC详解

JUC 前言&#xff1a; 在Java中&#xff0c;线程部分是一个重点&#xff0c;本篇文章说的JUC也是关于线程的。JUC就是java.util .concurrent工具包的简称。这是一个处理线程的工具包&#xff0c;JDK 1.5开始出现的。下面一起来看看它怎么使用。 一、volatile关键字与内存可见…

抓包工具,知道手机app上面使用的接口是哪个

fiddler。大家可以百度上面好多选择一个安装。这里随便扔一个 在电脑上安装以后。你再配置手机上的一些设置。 首先保证手机和电脑在同一个局域网上&#xff0c;连得wifi域名前面一样的&#xff0c;在电脑的cmd输入ipconfig 然后打开手机的设置。wifi页面点开查看你连的wifi的…

munin mysql_munin 监控 mysql 2种方法

munin自带的有mysql监控功能&#xff0c;但是没有启用。试了二种方法&#xff0c;都可以监控mysql。一&#xff0c;安装munin mysql的perl扩展# yum install perl-Cache-Cache perl-IPC-ShareLite perl-DBD-MySQL二&#xff0c;为监控创建mysql用户mysql> CREATE USER munin…

使用fiddler实现手机抓包

使用fiddler实现手机抓包 手机上无法直接查看网络请求数据&#xff0c;需要使用抓包工具。Fiddler是一个免费的web调试代理&#xff0c;可以用它实现记录、查看和调试手机终端和远程服务器之间的http/https通信。 一、PC端fiddler配置 1. 安装HTTPS证书 手机上的应用很多涉及…

小米手机上安装https证书(例如pem证书,crt证书)详解

小米手机上安装https证书&#xff08;例如pem证书&#xff0c;crt证书&#xff09;关键三步&#xff1a; 1.使用第三方浏览器下载.pem 格式的文件 &#xff08;我使用的是QQ浏览器&#xff09; 2.将这个文件放入小米的 DownLoad 文件夹下 (这步也可以不做&#xff0c;只要在4…

python django图书管理系统_Python框架:Django写图书管理系统(LMS)

Django模版文件配置文件路径 test_site -- test_site -- settings.pyTEMPLATES [ { BACKEND: django.template.backends.django.DjangoTemplates, DIRS: [os.path.join(BASE_DIR, "template")], # template文件夹位置 APP_DIRS: True, OPTIONS: { context_processor…

springsecurity中session失效后怎样处理_结合Spring Security进行web应用会话安全管理

结合Spring Security进行web应用会话安全管理在本文中&#xff0c;将为大家说明如何结合Spring Security 管理web应用的会话。如果您阅读后觉得本文对您有帮助&#xff0c;期待您能关注、转发&#xff01;您的支持是我不竭的创作动力&#xff01;一、Spring Security创建使用se…

如何把数据库从sql变成mysql_如何将数据库从SQL Server迁移到MySQL

一、迁移Database Schema。首先使用Sybase Powerdesigner的逆向工程功能&#xff0c;逆向出SQL Server数据库的物理模型。具体操作是在Powerdesigner中选择“File”&#xff0c;“Reverse Engine”再选择Database&#xff0c;将DBMS选择为SQL Server&#xff0c;如图&#xff1…

linux转mysql_[转] linux下安装mysql服务器

[转自&#xff1a;http://www.extmail.org/forum/archive/2/0510/563.html]安装MySQL服务器你可以根据服务器的CPU类型&#xff0c;下载适合你所用CPU和操作系统的MySQL发行包。从下面的URL下载MySQL 4.1.16以tar.gz形式发布的二进制发行包&#xff1a;http://www.mysql.com增加…

HTTP 学习,程序员不懂网络怎么行,一篇HTTP入门 不收藏都可惜

文章目录&#x1f4e2;前言HTTP 必备干货学习&#xff0c;程序员不懂网络怎么行HTTP 协议五个特点&#xff1a;网络结构图解HTTP概述&#x1f3f3;️‍&#x1f308;基于 HTTP 的系统的组件客户端&#xff1a;用户代理网络服务器代理HTTP 的基本方面HTTP 很简单HTTP 是可扩展的…

Java面试——Redis系列总结

文章目录&#xff1a; 1.什么是Redis&#xff1f; 2.为什么要用 Redis / 为什么要用缓存&#xff1f; 3.Redis为什么这么快&#xff1f; 4.Redis都有哪些数据类型&#xff1f; 5.什么是Redis持久化&#xff1f;Redis 的持久化有哪些实现方式&#xff1f; 6.什么是Redis事…

java运行环境_Windows系统java运行环境配置 | 吴文辉博客

在进行java开发之前&#xff0c;我们最重要的步骤就是如何获取JDK版本及正确的安装、配置java环境。只有正确的安装了java运行环境&#xff0c;才能继续java的学习和实践。一、下载JDK安装1、我系统是win7 64位&#xff0c;所以我下载了jdk-8u74-windows-x64&#xff1b;下载地…