PyTorch系列 (二): pytorch数据读取自制数据集并

PyTorch系列 (二): pytorch数据读取

PyTorch 1: How to use data in pytorch

Posted by WangW on February 1, 2019

参考:

  1. PyTorch documentation
  2. PyTorch 码源

本文首先介绍了有关预处理包的源码,接着介绍了在数据处理中的具体应用;

1 PyTorch数据预处理以及源码分析 (torch.utils.data)

torch.utils.data脚本码源

1.1 Dataset

Dataset

 
1
class torch.utils.data.Dataset

表示Dataset的抽象类。所有其他数据集都应该进行子类化。 所有子类应该override__len____getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。

 
1
2
3
4
5
6
7
8
9
10
11
12
class Dataset(object):# 强制所有的子类override getitem和len两个函数,否则就抛出错误;# 输入数据索引,输出为索引指向的数据以及标签;def __getitem__(self, index):raise NotImplementedError# 输出数据的长度def __len__(self):raise NotImplementedErrordef __add__(self, other):return ConcatDataset([self, other])

TensorDataset

 
1
class torch.utils.data.TensorDataset(*tensors)

Dataset的子类。包装tensors数据集;输入输出都是元组; 通过沿着第一个维度索引一个张量来回复每个样本。 个人感觉比较适用于数字类型的数据集,比如线性回归等。

 
1
2
3
4
5
6
7
8
9
10
class TensorDataset(Dataset):def __init__(self, *tensor):assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)self.tensors = tensorsdef __getitem__(self, index):return tuple(tensor[index] for tensor in tensorsdef __len__(self):return self.tensors[0].size(0)

ConcatDateset

 
1
class torch.utils.data.ConcatDateset(datasets)

连接多个数据集。 目的:组合不同的数据集,可能是大规模数据集,因为连续操作是随意连接的。 datasets的参数:要连接的数据集列表 datasets的样式:iterable

 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class ConcatDataset(Dataset):@staticmethoddef cumsum(sequence):# sequence是一个列表,e.g. [[1,2,3], [a,b], [4,h]]# return 一个数据大小列表,[3, 5, 7], 明显看的出来包含数据多少,第一个代表第一个数据的大小,第二个代表第一个+第二数据的大小,最后代表所有的数据大学;...def __getitem__(self, idx):# 主要是这个函数,通过bisect的类实现了任意索引数据的输出;dataset_idx = bisect.bisect_right(self.cumulative_size, idx)if dataset_idx == 0:sample_idx == idxelse:sample_idx = idx - self.cumulative_sizes[dataset_idx -1]return self.datasets[dataset_idx][sample_idx]...

Subset

 
1
class torch.utils.data.Subset(dataset, indices)

选取特殊索引下的数据子集; dataset:数据集; indices:想要选取的数据的索引;

random_split

 
1
class torch.utils.data.random_split(dataset, lengths):

随机不重复分割数据集; dataset:要被分割的数据集 lengths:长度列表,e.g. [7, 3], 保证7+3=len(dataset)

1.2 DataLoader

DataLoader

 
1
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

数据加载器。 组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。 参数:

  • dataset (Dataset) - 从中加载数据的数据集。
  • batch_size (int, optional) - 批训练的数据个数。
  • shuffle (bool, optional) - 是否打乱数据集(一般打乱较好)。
  • sampler (Sampler, optional) - 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
  • batch_sampler (Sample, optional) - 和sampler类似,返回批中的索引。
  • num_workers (int, optional) - 用于数据加载的子进程数。
  • collate_fn (callable, optional) - 合并样本列表以形成小批量。
  • pin_memory (bool, optional) - 如果为True,数据加载器在返回去将张量复制到CUDA固定内存中。
  • drop_last (bool, optional) - 如果数据集大小不能被batch_size整除, 设置为True可以删除最后一个不完整的批处理。
  • timeout (numeric, optional) - 正数,收集数据的超时值。
  • worker_init_fn (callabel, optional) - If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

特别重要:DataLoader中是不断调用DataLoaderIter

DataLoaderIter

 
1
class _DataLoaderIter(loader)

从DataLoader’s数据中迭代一次。其上面DataLoader功能都在这里; 插个眼,有空在分析这个

1.3 sampler

Sampler

 
1
class torch.utils.data.sampler.Sampler(data_source)

所有采样器的基础类; 每个采样器子类必须提供一个__iter__方法,提供一种迭代数据集元素的索引的方法,以及返回迭代器长度__len__方法。

class Sampler(object):def __init__(self, data_source):passdef __iter__(self):raise NotImplementedErrordef __len__(self):raise NotImplementedError

SequentialSampler

 
1
class torch.utils.data.SequentialSampler(data_source)

样本元素顺序排列,始终以相同的顺序。 参数:-data_source (Dataset) - 采样的数据

RandomSampler

 
1
class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)

样本随机排列,如果没有Replacement,将会从打乱的数据采样,否则,。。 参数:

  • data_source (Dataset) - 采样数据
  • num_samples (int) - 采样数据大小,默认是全部。
  • replacement (bool) - 是否放回

SubsetRandomSampler

 
1
class torch.utils.data.SubsetRandomSampler(indices)

从给出的索引中随机采样,without replacement。 参数:

  • indices (sequence) - 索引序列。

BatchSampler

 
1
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

将采样封装到批处理索引。 参数:

  • sampler (sampler) - 基本采样
  • batch_size (int) - 批大小
  • drop_last (bool) - 是否删掉最后的批次

weightedRandomSampler

 
1
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)

样本元素来自[0,…,len(weights)-1], 给定概率(权重)。 参数:

  • weights (list) - 权重列表。不需要加起来为1
  • num_samplers (int) - 要采样数目
  • replacement (bool) -

1.4 Distributed

DistributedSampler

 
1
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None)

????没读呢

1.5 其它链接

  1. PyTorch源码解读之torch.utils.data.DataLoader

2 torchvision

计算机视觉用到的库,文档以及码源如下:

  1. torchvision documentation
  2. torchvision 其库主要包含一下内容:
  • torchvision.datasets
    • MNIST
    • Fashion-MNIST
    • EMNIST
    • COCO
    • LSUN
    • ImageFolder
    • DatasetFolder
    • Imagenet-12
    • CIFAR
    • STL10
    • SVHN
    • Photo Tour
    • SBU
    • Flickr
    • VOC
  • torchvision.models
    • Alexnet
    • VGG
    • ResNet
    • SqueezeNet
    • DenseNet
    • Inception v3
  • torchvision.transforms
    • Transforms on PIL Image
    • Transfroms on torch.* Tensor
    • Conversion Transforms
    • Generic Transforms
    • Functional Transforms
  • torchvision.utils

3 应用

3.1 init

具有一下图像数据如下表示:

  • train
    • normal
      • 1.png
      • 2.png
      • 8000.png
    • tumor
      • 1.png
      • 2.png
      • 8000.png
  • validation
    • normal
      • 1.png
    • tumor
      • 1.png

希望能够训练模型,使得能够识别tumor, normal两类,将tumor–>1, normal–>0。

3.2 数据读取

在PyTorch中数据的读取借口需要经过,Dataset和DatasetLoader (DatasetloaderIter)。下面就此分别介绍。

Dataset

首先导入必要的包。

import osimport numpy as np
from torch.utils.data import Dataset
from PIL import Imagenp.random.seed(0)

其次定义MyDataset类,为了代码整洁精简,将不必要的操作全删,e.g. 图像剪切等。

class MyDataset(Dataset):def __init__(self, root, size=229, ):"""Initialize the data producer"""self._root = rootself._size = sizeself._num_image = len(os.listdir(root))self._img_name = os.listdir(root)def __len__(self):return self._num_imagedef __getitem__(self, index):img = Image.open(os.path.join(self._root, self._img_name[index]))# PIF image: H × W × C# torch image: C × H × Wimg = np.array(img, dtype-np.float32).transpose((2, 0, 1))return img

DataLoader

将MyDataset封装到loader器中。

from torch.utils.data import DataLoader# 实例化MyData
dataset_tumor_train = MyDataset(root=/img/train/tumor/)
dataset_normal_train = MyDataset(root=/img/train/normal/)
dataset_tumor_validation = MyDataset(root=/img/validation/tumor/)
dataset_normal_validation = MyDataset(root=/img/validation/normal/)# 封装到loader
dataloader_tumor_train = DataLoader(dataset_tumor_train, batch_size=10)
dataloader_normal_train = DataLoader(dataset_normal_train, batch_size=10)
dataloader_tumor_validation = DataLoader(dataset_tumor_validation, batch_size=10)
dataloader_normal_validation = DataLoader(dataset_normal_validation, batch_size=10)

3.3 train_epoch

简单将数据流接口与训练连接起来

def train_epoch(model, loss_fn, optimizer, dataloader_tumor, dataloader_normal):model.train()# 由于tumor图像和normal图像一样多,所以将tumor,normal连接起来,steps=len(tumor_loader)=len(normal_loader)steps = len(dataloader_tumor)batch_size = dataloader_tumor.batch_sizedataiter_tumor = iter(dataloader_tumor)dataiter_normal = iter(dataloader_normal)for step in range(steps):data_tumor = next(dataiter_tumor)target_tumor = [1, 1,..,1] # 和data_tumor长度相同的tensordata_tumor = Variable(data_tumor.cuda(async=True))target_tumor = Variable(target_tumor.cuda(async=True))data_normal = next(dataiter_normal)target_normal = [0, 0,..,0] # data_normal = Variable(data_normal.cuda(async=True))target_normal = Variable(target_normal.cuda(async=True))idx_rand = Variable(torch.randperm(batch_size*2).cuda(async=True))data = torch.cat([data_tumor, data_normal])[idx_rand]target = torch.cat([target_tumor, target_normal])[idx_rand]output = model(data)loss = loss_fn(output, target)optimizer.zero_grad()loss.backward()optimizer.step()

任何程序错误,以及技术疑问或需要解答的,请添加

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/547057.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nodejs+nginx获取真实ip

nodejs nginx获取真实ip分为两部分&#xff1a; 第一、配置nginx&#xff1b;第二、通过nodejs代码获取&#xff1b; 其他语言也是一样的&#xff0c;都是配置nginx之后&#xff0c;在http头里面获取“x-forwarded-for”. 第一、配置nginx location / {proxy_set_header Ho…

【OSChina-MoPaaS应用开发大赛】豪美创新后台业务管理系统

2019独角兽企业重金招聘Python工程师标准>>> 应用名称&#xff1a;豪美创新后台业务管理系统 应用URL地址&#xff1a;http://tyz.sturgeon.mopaas.com/admin/index.html 登录&#xff1a;admin/admin 投票地址&#xff1a;http://www.oschina.net/mopaas-app-co…

QT5更改应用程序图标

1.准备好.ico的图片放在工程目录下&#xff0c;并添加到项目的资源文件中 2.在项目配置.pro文件中添加一下内容 RC_ICONS AppIcon.icoAppIcon为你的ico图片名字 3.在可视化设计文件.ui中选择主窗口&#xff0c;将其属性中的windowIcon一项右侧下三角单击&#xff0c;从“选择…

python List中元素两两组合

python List中元素两两组合 import itertools aa [a, b, c] bb list(itertools.permutations(aa, 2)) print(bb) print("######################") cc list(itertools.combinations(aa, 2)) print(cc) 话不多说&#xff0c;运行结果解释一些 任何程序错误&…

xcode编译报错unknown error -1=ffffffffffffffff Command /bin/sh failed with exit code 1

升级完xcode9.1之后&#xff0c;编译项目出现如下错误&#xff1a; CI今日构建时报出如下错误&#xff1a; /Users/xxx/Library/Developer/Xcode/DerivedData/Snowball-ebllohyukujrncbaldsfojfjxwep/Build/Intermediates.noindex/ArchiveIntermediates/ProjectName/Installat…

android button的selector

实现按钮的selector <?xml version"1.0" encoding"utf-8"?><selector xmlns:android"http://schemas.android.com/apk/res/android"><item android:drawable"drawable/common_topbar_route_bus_pressed" android:sta…

Windows Qt5下用QAxObject快速读写Excel指南

Qt Windows 下快速读写Excel指南 很多人搜如何读写excel都会看到用QAxObject来进行操作&#xff0c;很多人试了之后都会发现一个问题&#xff0c;就是慢&#xff0c;非常缓慢&#xff01;因此很多人得出结论是QAxObject读写excel方法不可取&#xff0c;效率低。 后来我曾试过…

python opencv过滤红色

OpenCV简易视频处理框架OpenCV主要色彩空间OpenCV的位操作方法 找出视频中红色物体 import cv2 import numpy as npdef filter_out_black(src_frame):if src_frame is not None:hsv cv2.cvtColor(src_frame, cv2.COLOR_BGR2HSV)lower_red np.array([0, 0, 0])upper_red np.…

#39;boost/iterator/iterator_adaptor.hpp#39; file not found之xcode生成时报错的解决方案

xcode生成rn&#xff08;0.49.3&#xff09;项目的时候出现“boost/iterator/iterator_adaptor.hpp file not found之xcode”报错。 原因&#xff1a; /Users/xxx/.rncache 中 boost_1_63_0.tar.gz&#xff0c; double-conversion-1.1.5.tar.gz&#xff0c; folly-2016.09.26.…

经典面试题SALES TAXES思路分析和源码分享

题目&#xff1a; SALES TAXESBasic sales tax is applicable at a rate of 10% on all goods, except books, food, and medical products that are exempt. Import duty is an additional sales tax 除书籍 食品 药品外其他商品基本税为10%。进口税附加5%&#xff0c;不免税。…

Snipaste在Window运行后遇到提示计算机中丢失 api-ms-win-crt-runtime-l1-1-0.dll 错误

故障排除 以下为运行 Snipaste 可能遇到的错误及其解决方案。 Windows 运行后遇到提示计算机中丢失 api-ms-win-crt-runtime-l1-1-0.dll 错误 请根据你操作系统的版本&#xff08;32位/64位&#xff09;&#xff0c;下载并安装相应的微软 Visual C 2015 可再发行组件包: 32…

Windows10安装ubuntu18.04双系统教程

Windows10安装ubuntu18.04双系统教程 写在前面&#xff1a;本教程为windows10安装ubuntu18.04&#xff08;64位&#xff09;双系统教程&#xff0c;是我多次安装双系统的经验总结&#xff0c;安装方法同样适用于ubuntu16.04&#xff08;64位&#xff09;。为了直观和易于理解&…

ffmpeg h264+ts +udp传输

http://bbs.csdn.net/topics/370246456 http://1229363.blog.163.com/blog/static/19743427201001244711137/ ffmpeg windows 下编译 http://www.360doc.com/content/13/0913/15/13084517_314201133.shtml h264帧边界识别 http://fs-linux.org/forum.php?modviewthread&ti…

ReactNative实现图集功能

需求描述&#xff1a;  图片缩放、拖动、长按保存等基础图片查看的功能&#xff1b; 展示每张图片文本描述&#xff1b; 实现效果&#xff0c;如图&#xff1a; 实现步骤 使用第三方插件&#xff1a;react-native-image-zoom-viewer 插件GitHub地址&#xff1a;https://git…

C++或C 实现AES ECB模式加密解密,支持官方验证

本文主要介绍 AES 算法的加解密方法。本文使用的语言为 C&#xff0c;调用的 AES 库为&#xff1a;cryptopp。 1 概述 AES 加密算法的介绍如下&#xff08;摘自 WikiPedia&#xff09;&#xff1a; 高级加密标准&#xff08;英语&#xff1a;Advanced Encryption Standard&am…

Kali Linux 2019.4用U盘安装以及解决Kali Linux 2019.4中文乱码问题

一、利用Win32 Disk Imager 实现U盘刻录ISO 1.Kali Linux官网下载 2.Win32 Disk Imager官网下载地址 3.打开Win32 Disk Imager软件&#xff0c;添加下载的镜像文件&#xff0c;选择制作镜像的U盘&#xff0c;点击“”“写入”&#xff0c;等待写入成功完成&#xff01; 二、…

Javascript实现AES加密解密(ECB/CBC)

环境配置 js文件https://code.google.com/archive/p/crypto-js/downloads在线AES加密解密地址在线AES加密解密、AES在线加密解密、AES encryption and decryption--查错网下载完成后在页面中引入 rollups/aes.jscomponents/mode-ecb.jscomponents/pad-nopadding.js引入后页面 …

在PHP中利用wsdl创建标准webservice

参照整理&#xff1a; http://bbs.php100.com/read-htm-tid-95228.htmlhttp://www.ieliwb.com/wsdl-create-soapdiscovery/ 说明&#xff1a; 非标准的webservice&#xff0c;可能只能PHP才能访问 标准的webservice&#xff0c;就必须要使用wsdl在这里我只介绍标准的webserv…

Kali-Linux2019.04设置中文输入法

1.打开超级终端&#xff0c;输入 apt-get install fcitx 首先安装输入法框架 2.输入apt-get install fcitx-googlepinyin 然后安装google输入法 3.如下图&#xff0c;打开fcitx输入法配置 4.通过左下角的“”“”添加&#xff0c;选择刚才安装的google中文输入法&#xff0c…