使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作

使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作

总共分为四步

  • 构造一个my_dataset类,继承自torch.utils.data.Dataset
  • 重写__getitem____len__ 类函数
  • 建立两个函数find_classeshas_file_allowed_extension,直接从这copy过去
  • 建立my_make_dataset函数用来构造(path,lable)对

一、构造一个my_dataset类,继承自torch.utils.data.Dataset

二、 重写__getitem____len__ 类函数

要构造Dataset的子类,就必须要实现两个方法:

  • getitem_(self, index):根据index来返回数据集中标号为index的元素及其标签。
  • len_(self):返回数据集的长度。
class my_dataset(Dataset):def __init__(self,root_original, root_cdtfed, transform=None):super(my_dataset, self).__init__()self.transform = transformself.root_original = root_originalself.root_cdtfed = root_cdtfedself.original_imgs = []self.cdtfed_imgs = []#add (img_path, label) to listsself.original_imgs = my_make_dataset(root_original, class_to_idx=None, extensions=('.jpg', '.png'), is_valid_file=None)self.cdtfed_imgs = my_make_dataset(root_original, class_to_idx=None, extensions=('.jpg', '.png'), is_valid_file=None)# super(my_dataset, self).__init__()def __getitem__(self, index):    #这个方法是必须要有的,用于按照索引读取每个元素的具体内容fn1, label1 = self.original_imgs[index] #fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息fn2, label2 = self.cdtfed_imgs[index]img1 = Image.open(fn1).convert('RGB') #按照path读入图片from PIL import Image # 按照路径读取图片img2 = Image.open(fn2).convert('RGB') #按照path读入图片from PIL import Image # 按照路径读取图片if self.transform is not None:img1 = self.transform(img1) #是否进行transformimg2 = self.transform(img2) #是否进行transformimg_list = [img1, img2]label = label1name = fn1return img_list,label,name  #return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分return len(self.original_imgs)

三、建立两个函数find_classeshas_file_allowed_extension,直接从这copy过去

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:"""Finds the class folders in a dataset.See :class:`DatasetFolder` for details."""classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())if not classes:raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}return classes, class_to_idxdef has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:"""Checks if a file is an allowed extension.Args:filename (string): path to a fileextensions (tuple of strings): extensions to consider (lowercase)Returns:bool: True if the filename ends with one of given extensions"""return filename.lower().endswith(extensions)
  • 建立my_make_dataset函数用来构造(path,lable)对
def my_make_dataset(directory: str,class_to_idx: Optional[Dict[str, int]] = None,extensions: Optional[Tuple[str, ...]] = None,is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:"""Generates a list of samples of a form (path_to_sample, class).See :class:`DatasetFolder` for details.Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` functionby default."""directory = os.path.expanduser(directory)if class_to_idx is None:_, class_to_idx = find_classes(directory)elif not class_to_idx:raise ValueError("'class_to_index' must have at least one entry to collect any samples.")both_none = extensions is None and is_valid_file is Noneboth_something = extensions is not None and is_valid_file is not Noneif both_none or both_something:raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")if extensions is not None:def is_valid_file(x: str) -> bool:return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))is_valid_file = cast(Callable[[str], bool], is_valid_file)instances = []available_classes = set()for target_class in sorted(class_to_idx.keys()):class_index = class_to_idx[target_class]target_dir = os.path.join(directory, target_class)if not os.path.isdir(target_dir):continuefor root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):for fname in sorted(fnames):if is_valid_file(fname):path = os.path.join(root, fname)# item = path, [int(cl) for cl in target_class.split('_')]item = path, target_classinstances.append(item)if target_class not in available_classes:available_classes.add(target_class)empty_classes = set(class_to_idx.keys()) - available_classesif empty_classes:msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "if extensions is not None:msg += f"Supported extensions are: {', '.join(extensions)}"raise FileNotFoundError(msg)return instances #instance:[item:(path, int(class_name)), ]

附录:完整代码

我这里传入两个root_dir,因为我要用一个dataset加载两个数据集,分别放在data1和data2里

class my_dataset(Dataset):def __init__(self,root_original, root_cdtfed, transform=None):super(my_dataset, self).__init__()self.transform = transformself.root_original = root_originalself.root_cdtfed = root_cdtfedself.original_imgs = []self.cdtfed_imgs = []#add (img_path, label) to listsself.original_imgs = my_make_dataset(root_original, class_to_idx=None, extensions=('.jpg', '.png'), is_valid_file=None)self.cdtfed_imgs = my_make_dataset(root_original, class_to_idx=None, extensions=('.jpg', '.png'), is_valid_file=None)# super(my_dataset, self).__init__()def __getitem__(self, index):    #这个方法是必须要有的,用于按照索引读取每个元素的具体内容fn1, label1 = self.original_imgs[index] #fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息fn2, label2 = self.cdtfed_imgs[index]img1 = Image.open(fn1).convert('RGB') #按照path读入图片from PIL import Image # 按照路径读取图片img2 = Image.open(fn2).convert('RGB') #按照path读入图片from PIL import Image # 按照路径读取图片if self.transform is not None:img1 = self.transform(img1) #是否进行transformimg2 = self.transform(img2) #是否进行transformimg_list = [img1, img2]label = label1name = fn1return img_list,label,name  #return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分return len(self.original_imgs)def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:"""Finds the class folders in a dataset.See :class:`DatasetFolder` for details."""classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())if not classes:raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}return classes, class_to_idxdef has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:"""Checks if a file is an allowed extension.Args:filename (string): path to a fileextensions (tuple of strings): extensions to consider (lowercase)Returns:bool: True if the filename ends with one of given extensions"""return filename.lower().endswith(extensions)def my_make_dataset(directory: str,class_to_idx: Optional[Dict[str, int]] = None,extensions: Optional[Tuple[str, ...]] = None,is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:"""Generates a list of samples of a form (path_to_sample, class).See :class:`DatasetFolder` for details.Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` functionby default."""directory = os.path.expanduser(directory)if class_to_idx is None:_, class_to_idx = find_classes(directory)elif not class_to_idx:raise ValueError("'class_to_index' must have at least one entry to collect any samples.")both_none = extensions is None and is_valid_file is Noneboth_something = extensions is not None and is_valid_file is not Noneif both_none or both_something:raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")if extensions is not None:def is_valid_file(x: str) -> bool:return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))is_valid_file = cast(Callable[[str], bool], is_valid_file)instances = []available_classes = set()for target_class in sorted(class_to_idx.keys()):class_index = class_to_idx[target_class]target_dir = os.path.join(directory, target_class)if not os.path.isdir(target_dir):continuefor root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):for fname in sorted(fnames):if is_valid_file(fname):path = os.path.join(root, fname)# item = path, [int(cl) for cl in target_class.split('_')]item = path, target_classinstances.append(item)if target_class not in available_classes:available_classes.add(target_class)empty_classes = set(class_to_idx.keys()) - available_classesif empty_classes:msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "if extensions is not None:msg += f"Supported extensions are: {', '.join(extensions)}"raise FileNotFoundError(msg)return instances #instance:[item:(path, int(class_name)), ]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/507907.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python文件和数据的格式化_Python在文本文件中格式化特定数据

谢谢你们的帮助。作为一个新手,我最终得到的代码不是那么优雅,但它仍然起作用:)。在#open the file and create the CSV after filtering the input file.def openFile(filename, keyword): #defines the function to open the file. User to…

windows功能_这 12 个好用 Windows 软件,让你也能用上 macOS 的独占功能

在离开 macOS 这段时间,每天在家依赖 Windows To Go 为生,感到日常工作流程在四处冒烟。这才发现 macOS 的有些特性就如同空气一样,虽然毫无存在感,却不可缺失。关于「如何在 Windows 中实现 macOS 的 xxx」,随便上网一…

Batch Normalization、Layer Normalization、Group Normalization、Instance Normalization原理、适用场景和实际使用经验

Batch Normalization、Layer Normalization、Group Normalization、Instance Normalization原理、适用场景和使用经验 一、 简单介绍各种Normalization 先放一张来自Group Normalization原论文中的图,个人认为这个图很形象,以此图直观感受一下各种归一…

blockly自定义中文出问题_[BlocklyNukkit入门]#5自定义物品

自定义物品创建一个木棍item blockitem.buildItem(280, 0, 1);设置名字item.setCustomName("棍");设置信息,用分号隔开换行blockitem.setItemLore(item, "第一行;第二行;第三行;第四行");添加有序合成添加有序合成,设置G为橡木原木的键,G就代表原木.参数1…

收发一体超声波测距离传感器模块_芜湖低功耗超声波液位计物位计设备排名

KUS 超声波液位物位计 8种工作状态设置指导 1), 窗口常开模式(模拟量输出产品为正线性工作模式或者距离测量模式)2), 窗口常闭模式(模拟量输出产品为负线性工作模式或者液位测量模式)3), 单点常开模4), 单点常闭模式。5), 单点常开带大滞回区间模式6), 单点常闭带大滞回区间模式…

学术写作科研工具推荐

最近在写论文,然后过程中觉得一些工具可以提升效率,所以就简单总结一下,以后也会逐渐更新 Ccf deadlines:我该赶哪个ddl呢? 一些主要ccf A、B、C类会议的截稿日期都会被统计显示在此哦,看看你想投哪个会吧…

如何加声调口诀_声母韵母口诀顺口溜歌曲(怎么快速记住声母韵母)

家有一年级的小朋友,或者是孩子即将上小学的爸爸妈妈们,孩子的拼音学习进行得怎么样了?拼音是孩子进行语文学习的第一课,也是基础。但对很多小朋友来说真的是一道拦路虎。很多孩子由于一年级拼音基础不牢,到了四五年级…

pytorch model.train() 和model.eval() 对 BN 层的影响

model.train() BN做归一化时,使用的均值和方差是当前这个Batch的如果这时 track_running_statsTrue, 则会更新running_mean 和 running_var但是,running_mean 和 running_var不用在训练阶段 model.eval() BN 做归一化时,使用的…

联想用u盘重装系统步骤_详解联想如何使用u盘重装win10系统

联想是国内知名的品牌之一,很多朋友都购买了联想品牌的电脑,但是在使用的过程中难免会出现些磕磕碰碰的问题。所以今天小编就大家详细的介绍一下联想电脑使用u盘重装win10系统的方法。联系怎么使用u盘重装win10系统呢?最近有不少朋友在询问这…

笔记本电脑键盘切换_真想本小新13pro搭档,笔记本电脑周边好物清单推荐

原标题:真想本小新13pro搭档,笔记本电脑周边好物清单推荐真想本小新13pro搭档,笔记本电脑周边好物清单推荐 2020-10-24 15:21:493点赞4收藏2评论9月28日 - 11月12日,参与#双11购物攻略#征稿活动,赢取苹果全家桶8888元超…

pytorch 训练模型很慢,卡在数据读取,卡I/O的有效解决方案

多线程加载 在 datalaoder中指定num_works > 0,多线程加载数据集,最大可设置为 cpu 核数设置 pin_memory True, 固定内存访问单元,节约内存调度时间示例如下: loader DataLoader(dataset,batch_sizebatch_size * group_size,shuffleTr…

python达梦数据库_python 操作达 梦数据库

python 达梦数据库操作流程连接数据库 dm.connect( ... )获取游标 dm_conn.cursor()编写SQL语句 sql_str执行SQL语句 dm_cursor.execute()获取结果列表 dt_breakpoint dm_cursor.fetchall()关闭游标 dm_cursor.close()关闭数据库连接 dm_conn.close()代码示例import pandas as…

C++求复数的角度_11.初中数学:方程5x2m=4x的解,在2与10之间,怎么求m的取值范围?...

欢迎您来到方老师数学课堂,请点击上方蓝色字体,关注方老师数学课堂。所有的视频内容,全部免费,请大家放心关注,放心订阅。初中数学:方程5x-2m-4-x的解,在2与10之间,怎么求m的取值范围…

python3 beautifulsoup 模块详解_关于beautifulsoup模块的详细介绍

这篇文章主要给大家介绍了python中 Beautiful Soup 模块的搜索方法函数。 方法不同类型的过滤参数能够进行不同的过滤,得到想要的结果。文中介绍的非常详细,对大家具有一定的参考价值,需要的朋友们下面来一起看看吧。前言我们将利用 Beautifu…

python解zuobiaoxi方程_欧式期权定价的python实现

0. pre 在《给你的二叉树期权定价》中就挖了坑要写期权定价的代码,这会有时间来填坑啦。本文将会用python实现欧式期权定价。具体的定价算法分别是基于BS公式的、蒙特卡洛的以及二叉树的。对于二叉树和BS公式还不熟悉的小伙伴可以移步至往期关于二叉树期权定价和BS公…

去除标签_有效去除“狗皮膏药”标签,快学起来吧

去除商品标签向来是比较头疼一件事,有时候在去掉标签后会留下粘性残留物,它会粘上灰尘和其他脏东西,把表面变成脏兮兮的颜色,让人看着太不舒服了。其实去除标签残留粘胶并不难,可能家里就有去除它的工具哦~那今天小编就…

win10很多软件显示模糊_还在使用第三方软件?Win10可以直接显示显卡温度啦

微软刚刚开始向参与快速通道测试的用户推送Windows 10 20H1 Build 18963 版带来部分新功能和优化等。这个版本也是常规优化版本因此带来的新功能很少,但这次更新为任务管理器带来原生的显示显卡温度功能。用户打开任务管理器点击性能选项卡然后找到「独立显卡」即可…

分数怎么化成带分数_小升初数学总复习第三个基础模块:分数的认识

今天我们开始小升初数学总复习第三个基础模块的复习:分数的认识分数的认识一共分为8个知识考点。第一,分数的意义把单位“1”.平均分成若干份,表示这样的一份或者几份的数叫做分数。表示其中一份的数叫做分数单位。第二&#xff0…

active mq topic消费后删除_《我想进大厂》之MQ夺命连环11问

继之前的mysql夺命连环之后,我发现我这个标题被好多套用的,什么夺命zookeeper,夺命多线程一大堆,这一次,开始面试题系列MQ专题,消息队列作为日常常见的使用中间件,面试也是必问的点之一&#xf…

嘀嗒还是滴答_2021年顺风车车主口碑榜!滴滴、滴答、一喂顺风车成TOP3

出行平台烧钱抢用户抢司机,大家都见怪不怪了,只是近期平台为自身利益而牺牲司机的例子层出不穷,在司机刚进入平台补贴多流水多,没多久司机收入都不够交车租的,司机踩坑,全家受罪,很多司机表示自…