深度学习入门:自建数据集完成花鸟二分类任务

自建数据集完成二分类任务(参考文章)

1 图片预处理

1 .1 统一图片格式

找到的图片需要首先做相同尺寸的裁剪,归一化,否则会因为图片大小不同报错

RuntimeError: stack expects each tensor to be equal size,
but got [3, 667, 406] at entry 0 and [3, 600, 400] at entry 1

pytorch的torchvision.transforms模块提供了许多用于图片变换/增强的函数。

1.1.1 把图片不等比例压缩为固定大小
transforms.Resize((600,600)),
1.1.2 裁剪保留核心区

因为主体要识别的图像一般在中心位置,所以使用CenterCrop,这里设置为(400, 400)

transforms.CenterCrop((400,400)),
1.1.3 处理成统一数据类型

这里统一成torch.float64方便神经网络计算,也可以统一成其他比如uint32等类型

transforms.ConvertImageDtype(torch.float64),
1.1.4 归一化进一步缩小图片范围

对于图片来说0~255的范围有点大,并不利于模型梯度计算,我们应该进行归一化。pytorch当中也提供了归一化的函数torchvision.transforms.Normalize(mean,std)

  • 我们可以使用[0.5,0.5,0.5]mean,std来把数据归一化至[-1,1]
  • 也可以手动计算出所有的图片mean,std来归一化至均值为0,标准差为1的正态分布,
  • 一些深度学习代码常常使用mean=[0.485, 0.456, 0.406] ,std=[0.229, 0.224, 0.225]的归一化数据,这是在ImageNet的几百万张图片数据计算得出的结果
  • BN等方法也具有很出色的归一化表现,我们也会使用到

Juliuszh:详解深度学习中的Normalization,BN/LN/WN
Algernon:【基础算法】六问透彻理解BN(Batch Normalization)

我们这里使用简单的[0.5,0.5,0.5]归一化方法,更新cls_dataset,加入transform操作 ,作为图片裁剪的预处理。

transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])

关于transforms的操作大体分为裁剪/翻转和旋转/图像变换/transform自身操作,具体见余霆嵩:PyTorch 学习笔记(三):transforms的二十二个方法,这里不进行详细展开。

1.2 数据增强

当数据集较小时,可以通过对已有图片做数据增强,利用之前提到的transforms中的函数 ,也可以混合使用来根据已有数据创造新数据

        self.data_enhancement = transforms.Compose([transforms.RandomHorizontalFlip(p=1),transforms.RandomRotation(30)])

2 创建自制数据集

2.1 以Dataset类接口为模版

class cls_dataset(Dataset):def __init__(self) -> None:# initializationdef __getitem__(self, index):# return data,label in set def __len__(self):# return the length of the dataset

2.2 创建set

2.2.1定义两个空列表data_list和target_list
2.2.2遍历文件夹
2.2.3读取图片对象,对每一个图片对象预处理后,分别将图片对象和对应的标签加入data_list和target_list中
2.2.4将data_list和target_list加入h5df_ile中
import os
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import h5py
from torchvision.io import read_imagetrain_pic_path = 'test-set'
test_pic_path = 'training-set'def create_h5_file(file_name):all_type = ['flower', 'bird']h5df_file = h5py.File(file_name, "w") #file_name指向比如"train.hdf5"这种文件路径,但这句话之前file_name指向路径为空#图片统一化处理transform = transforms.Compose([transforms.Resize((600, 600)),transforms.CenterCrop((400, 400)),transforms.ConvertImageDtype(torch.float64),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])#数据增强data_list = []   #建立一个保存图片张量的空列表target_list = [] #建立一个保存图片标签的空列表#遍历文件夹建立数据集'''文件夹组成| —— train|   | —— flower|   |   | —— 图片1|   | —— bird|   | —— | —— 图片2| —— test|   | —— flower|   | —— bird'''dataset_kind = file_name.split('.')[0]#先判断缺失的文件是训练集还是测试集if dataset_kind == 'train':pic_file_name = train_pic_pathelse:pic_file_name = test_pic_path#再循环遍历文件夹for file_name_dir, _, files in tqdm(os.walk(pic_file_name)):target = file_name_dir.split('/')[-1]if target in all_type:for file in files:pic = read_image(os.path.join(file_name_dir, file))  #以张量形式读取图片对象pic = transform(pic)    #预处理图片pic = np.array(pic).astype(np.float64)data_list.append(pic)   #将pic对象添加到列表里target_list.append(target.encode()) #将target编码后添加到列表里h5df_file.create_dataset("image", data=data_list)h5df_file.create_dataset("target", data=target_list)h5df_file.close()class h5py_dataset(Dataset):def __init__(self, file_name) -> None:super().__init__()self.file_name = file_name    #指向文件的路径名#如果file_name指向的h5文件不存在,就新建一个if not os.path.exists(file_name):create_h5_file(file_name)def __getitem__(self, index):with h5py.File(self.file_name, 'r') as f:if f['target'][index].decode() == 'bird':   #如果在f文件的target列表中查找到index下标对应的标签是birdtarget = torch.tensor(0)else:target = torch.tensor(1)return f['image'][index], targetdef __len__(self):with h5py.File(self.file_name, 'r') as f:return len(f['target'])def h5py_loader():train_file = 'train.hdf5'test_file = 'test.hdf5'train_dataset = h5py_dataset(train_file)test_dataset = h5py_dataset(test_file)train_data_loader = DataLoader(train_dataset, batch_size=4)test_data_loader = DataLoader(test_dataset, batch_size=4)return train_data_loader, test_data_loader

2.3 创建loader

实例化set对象后利用torch.utils.data.DataLoader

3 搭建网络

3.1 网络结构

在这里插入图片描述

3.2 参数计算

卷积后,池化后尺寸计算公式:
(图像尺寸-卷积核尺寸 + 2*填充值)/步长+1
(图像尺寸-池化窗尺寸 + 2*填充值)/步长+1

参考文章

3.3 不成文规定

池化参数一般就是(2, 2)

中间的channel数量都是自己设定的,二的次方就行

kernelsize一般3或者5之类的

4 训练

加深对前面数据集组成理解

    for _, data in enumerate(train_loader):if isinstance(data, list):image = data[0].type(torch.FloatTensor).to(device)target = data[1].to(device)elif isinstance(data, dict):image = data['image'].type(torch.FloatTensor).to(device)target = data['target'].to(device)else:print(type(data))raise TypeError

for 循环中data的组成来源于构建set时,

    h5df_file.create_dataset("image", data=data_list)h5df_file.create_dataset("target", data=target_list)

写入了h5df文件中两个dataset,但在文件中是以嵌套列表形式保存,其中data[0]等价于引用image这个dataset,data[1]等价于引用target这个集合

在这里插入图片描述

5 测试

6 保存模型

改进

投影概率放到网络里面

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/155454.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【OpenCV实现图像:在Python中使用OpenCV进行直线检测】

文章目录 概要霍夫变换举个栗子执行边缘检测进行霍夫变换小结 概要 图像处理作为计算机视觉领域的重要分支,广泛应用于图像识别、模式识别以及计算机视觉任务中。在图像处理的众多算法中,直线检测是一项关键而常见的任务。该任务的核心目标是从图像中提…

bitmap基础介绍+holo实现离线UV计算

bitmap 基础介绍bitmaping 数据结构bitmap计算算子集成二阶段分布式计算:RoaringBitmap构造方案分桶方案建序方案 holo官网 离线UV计算创建用户映射表创建聚合结果表更新用户映射表和聚合结果表更新聚合结果表UV、PV查询 基础介绍 RoaringBitmap主要为了解决UV指标…

第十一章 目标检测中的NMS

精度提升 众所周知,非极大值抑制NMS是目标检测常用的后处理算法,用于剔除冗余检测框,本文将对可以提升精度的各种NMS方法及其变体进行阶段性总结。 总体概要: 对NMS进行分类,大致可分为以下六种,这里是依…

手机弱网测试工具:Charles

我们在测试app的时候,需要测试弱网情况下的一些场景,那么使用Charles如何设置弱网呢,请看以下步骤: 前提条件: 手机和电脑要在同一局域网内 Charles连接手机抓包 一、打开Charles,点击代理,…

如何搭建测试环境?一文解决你所有疑惑!

什么是测试环境 测试环境,指为了完成软件测试工作所必需的计算机硬件、软件、网络设备、历史数据的总称,简而言之,测试环境硬件软件网络数据准备测试工具。 硬件:指测试必需的服务器、客户端、网络连接等辅助设备。 软件&#…

Java 省考试院自学考试考籍管理系统

1) 项目简介 考籍管理系统是省考试院自学考试管理系统的一部分,包括考生考籍档案管理、考生免考管理、课程顶替、考籍转入转出管理、毕业管理和日志管理等功能模块。该项目的建设方便和加强了省考试院对自学考试考籍的一系列管理操作,社会效应明显。…

React函数组件状态Hook—useState《进阶-对象数组》

React函数组件状态-state 对象 state state 中可以保存任意类型的 JavaScript 值,包括对象。但是,你不应该直接修改存放在 React state 中的对象。相反,当你想要更新⼀个对象时,你需要创建⼀个新的对象(或者将其拷⻉⼀…

股票指标信息(六)

6-指标信息 文章目录 6-指标信息一. 展示股票的K线图数据,用于数据统计二. 展示股票指标数据,使用Java处理,集合形式展示三. 展示股票目前的最新的指标数据信息四. 展示股票指标数据,某一个属性使用Java处理五. 展示股票的指标数据,用于 Echarts 页面数据统计六. 展示股票指标数…

MAX/MSP SDK学习05:A_GIMME方法

今天终于将A_GIMME方法部分的描述看懂了,上周因为太赶时间加上这文档很抽象一直没看懂。也就那么一回事,记录一下。 A_GIMME方法用于接收多个参数: #include "ext.h" // standard Max include, always required #include "…

RedisConnectionFactory is required已解决!!!!

1.起因🤶🤶🤶🤶 redis搭建完成后,准备启动主程序,异常兴奋,结果报错了!!!! 2.究竟是何原因 😭😭😭&#x1f…

关于在3dsmax中制作的模型导入UE后尺寸大小不对的问题

现象 在3dsmax中的基本单位为毫米 在UE中基本单位是厘米 我在3dsmax中创建一个长宽高均为1000mm的方块 然后导入到UE中的世界坐标原点 方块向X轴正方向移动100个单位100cm1000mm,按理来说,新方块的此时应该和旧方块是贴着的,但是现象确是两者…

力扣 2. 两数相加

Problem: 2. 两数相加 思路与算法 Code /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}* ListNode(int val) { this.val val; }* ListNode(int val, ListNode next) { this.val val; this…

NSSCTF web刷题记录6

文章目录 [HZNUCTF 2023 final]eznode[MoeCTF 2021]地狱通讯-改[红明谷CTF 2022] Smarty Calculator方法一 CVE-2021-26120方法二 CVE-2021-29454方法三 写马蚁剑连接 [HZNUCTF 2023 final]eznode 考点:vm2沙箱逃逸、原型链污染 打开题目,提示找找源码 …

QT打包圆心识别

圆心点识别QT界面封装 最近在练习QT相关内容,找了个相关功能集成了下,主要是为了熟悉各个组件,功能主要是进行圆心识别。 主要涉及的QT功能点: 1.日志可视化 2.按钮及各类参数添加组件 3.水印添加及图片可视化 4.许可添加 5.主线…

OpenLayers实战,WebGL图层根据Feature要素的变量动态渲染多种颜色的三角形,适用于大量三角形渲染不同颜色

专栏目录: OpenLayers实战进阶专栏目录 前言 本章使用OpenLayers根据Feature要素的变量动态渲染不同颜色的三角形。 通过一个WebGL图层生成四种不同颜色的图形要素,适用于WebGL图层需要根据大量点要素区分颜色显示的需求。 更多的WebGL图层使用运算符动态生成样式的内容将会…

测试用例的8大设计原则

我们看到的大部分关于测试用例设计的文章,都在讲等价类、因果图、流程法等内容,这是关于测试用例的具体设计方法层面。本文想讨论的重点是,测试用例设计该遵循什么原则,有哪些思维和观点有助于产出更好的测试设计,这些…

CNP实现应用CD部署

上一篇整体介绍了cnp的功能,这篇重点介绍下CNP产品应用开发的功能。 简介 CNP的应用开发,主要是指的应用CD部署的配置管理。 应用列表,用来创建一个应用,一般与项目对应,也可以多个应用对应到一个项目。具体很灵活。…

结合两个Python小游戏,带你复习while循环、if判断、函数等知识点

💐作者:insist-- 💐个人主页:insist-- 的个人主页 理想主义的花,最终会盛开在浪漫主义的土壤里,我们的热情永远不会熄灭,在现实平凡中,我们终将上岸,阳光万里 ❤️欢迎点…

Ubuntu18.04安装LeGO-LOAM保姆级教程

系统环境:Ubuntu18.04.6 LTS 1.LeGO-LOAM的安装前要求: 1.1 ROS安装:参考我的另一篇博客Ubuntu18.04安装ROS-melodic保姆级教程_灬杨三岁灬的博客-CSDN博客文章浏览阅读168次。Ubuntu18.04安装ROS-melodic保姆级教程https://blog.csdn.net/…