深度学习入门:自建数据集完成花鸟二分类任务

自建数据集完成二分类任务(参考文章)

1 图片预处理

1 .1 统一图片格式

找到的图片需要首先做相同尺寸的裁剪,归一化,否则会因为图片大小不同报错

RuntimeError: stack expects each tensor to be equal size,
but got [3, 667, 406] at entry 0 and [3, 600, 400] at entry 1

pytorch的torchvision.transforms模块提供了许多用于图片变换/增强的函数。

1.1.1 把图片不等比例压缩为固定大小
transforms.Resize((600,600)),
1.1.2 裁剪保留核心区

因为主体要识别的图像一般在中心位置,所以使用CenterCrop,这里设置为(400, 400)

transforms.CenterCrop((400,400)),
1.1.3 处理成统一数据类型

这里统一成torch.float64方便神经网络计算,也可以统一成其他比如uint32等类型

transforms.ConvertImageDtype(torch.float64),
1.1.4 归一化进一步缩小图片范围

对于图片来说0~255的范围有点大,并不利于模型梯度计算,我们应该进行归一化。pytorch当中也提供了归一化的函数torchvision.transforms.Normalize(mean,std)

  • 我们可以使用[0.5,0.5,0.5]mean,std来把数据归一化至[-1,1]
  • 也可以手动计算出所有的图片mean,std来归一化至均值为0,标准差为1的正态分布,
  • 一些深度学习代码常常使用mean=[0.485, 0.456, 0.406] ,std=[0.229, 0.224, 0.225]的归一化数据,这是在ImageNet的几百万张图片数据计算得出的结果
  • BN等方法也具有很出色的归一化表现,我们也会使用到

Juliuszh:详解深度学习中的Normalization,BN/LN/WN
Algernon:【基础算法】六问透彻理解BN(Batch Normalization)

我们这里使用简单的[0.5,0.5,0.5]归一化方法,更新cls_dataset,加入transform操作 ,作为图片裁剪的预处理。

transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])

关于transforms的操作大体分为裁剪/翻转和旋转/图像变换/transform自身操作,具体见余霆嵩:PyTorch 学习笔记(三):transforms的二十二个方法,这里不进行详细展开。

1.2 数据增强

当数据集较小时,可以通过对已有图片做数据增强,利用之前提到的transforms中的函数 ,也可以混合使用来根据已有数据创造新数据

        self.data_enhancement = transforms.Compose([transforms.RandomHorizontalFlip(p=1),transforms.RandomRotation(30)])

2 创建自制数据集

2.1 以Dataset类接口为模版

class cls_dataset(Dataset):def __init__(self) -> None:# initializationdef __getitem__(self, index):# return data,label in set def __len__(self):# return the length of the dataset

2.2 创建set

2.2.1定义两个空列表data_list和target_list
2.2.2遍历文件夹
2.2.3读取图片对象,对每一个图片对象预处理后,分别将图片对象和对应的标签加入data_list和target_list中
2.2.4将data_list和target_list加入h5df_ile中
import os
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import h5py
from torchvision.io import read_imagetrain_pic_path = 'test-set'
test_pic_path = 'training-set'def create_h5_file(file_name):all_type = ['flower', 'bird']h5df_file = h5py.File(file_name, "w") #file_name指向比如"train.hdf5"这种文件路径,但这句话之前file_name指向路径为空#图片统一化处理transform = transforms.Compose([transforms.Resize((600, 600)),transforms.CenterCrop((400, 400)),transforms.ConvertImageDtype(torch.float64),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])#数据增强data_list = []   #建立一个保存图片张量的空列表target_list = [] #建立一个保存图片标签的空列表#遍历文件夹建立数据集'''文件夹组成| —— train|   | —— flower|   |   | —— 图片1|   | —— bird|   | —— | —— 图片2| —— test|   | —— flower|   | —— bird'''dataset_kind = file_name.split('.')[0]#先判断缺失的文件是训练集还是测试集if dataset_kind == 'train':pic_file_name = train_pic_pathelse:pic_file_name = test_pic_path#再循环遍历文件夹for file_name_dir, _, files in tqdm(os.walk(pic_file_name)):target = file_name_dir.split('/')[-1]if target in all_type:for file in files:pic = read_image(os.path.join(file_name_dir, file))  #以张量形式读取图片对象pic = transform(pic)    #预处理图片pic = np.array(pic).astype(np.float64)data_list.append(pic)   #将pic对象添加到列表里target_list.append(target.encode()) #将target编码后添加到列表里h5df_file.create_dataset("image", data=data_list)h5df_file.create_dataset("target", data=target_list)h5df_file.close()class h5py_dataset(Dataset):def __init__(self, file_name) -> None:super().__init__()self.file_name = file_name    #指向文件的路径名#如果file_name指向的h5文件不存在,就新建一个if not os.path.exists(file_name):create_h5_file(file_name)def __getitem__(self, index):with h5py.File(self.file_name, 'r') as f:if f['target'][index].decode() == 'bird':   #如果在f文件的target列表中查找到index下标对应的标签是birdtarget = torch.tensor(0)else:target = torch.tensor(1)return f['image'][index], targetdef __len__(self):with h5py.File(self.file_name, 'r') as f:return len(f['target'])def h5py_loader():train_file = 'train.hdf5'test_file = 'test.hdf5'train_dataset = h5py_dataset(train_file)test_dataset = h5py_dataset(test_file)train_data_loader = DataLoader(train_dataset, batch_size=4)test_data_loader = DataLoader(test_dataset, batch_size=4)return train_data_loader, test_data_loader

2.3 创建loader

实例化set对象后利用torch.utils.data.DataLoader

3 搭建网络

3.1 网络结构

在这里插入图片描述

3.2 参数计算

卷积后,池化后尺寸计算公式:
(图像尺寸-卷积核尺寸 + 2*填充值)/步长+1
(图像尺寸-池化窗尺寸 + 2*填充值)/步长+1

参考文章

3.3 不成文规定

池化参数一般就是(2, 2)

中间的channel数量都是自己设定的,二的次方就行

kernelsize一般3或者5之类的

4 训练

加深对前面数据集组成理解

    for _, data in enumerate(train_loader):if isinstance(data, list):image = data[0].type(torch.FloatTensor).to(device)target = data[1].to(device)elif isinstance(data, dict):image = data['image'].type(torch.FloatTensor).to(device)target = data['target'].to(device)else:print(type(data))raise TypeError

for 循环中data的组成来源于构建set时,

    h5df_file.create_dataset("image", data=data_list)h5df_file.create_dataset("target", data=target_list)

写入了h5df文件中两个dataset,但在文件中是以嵌套列表形式保存,其中data[0]等价于引用image这个dataset,data[1]等价于引用target这个集合

在这里插入图片描述

5 测试

6 保存模型

改进

投影概率放到网络里面

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/155454.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【OpenCV实现图像:在Python中使用OpenCV进行直线检测】

文章目录 概要霍夫变换举个栗子执行边缘检测进行霍夫变换小结 概要 图像处理作为计算机视觉领域的重要分支,广泛应用于图像识别、模式识别以及计算机视觉任务中。在图像处理的众多算法中,直线检测是一项关键而常见的任务。该任务的核心目标是从图像中提…

bitmap基础介绍+holo实现离线UV计算

bitmap 基础介绍bitmaping 数据结构bitmap计算算子集成二阶段分布式计算:RoaringBitmap构造方案分桶方案建序方案 holo官网 离线UV计算创建用户映射表创建聚合结果表更新用户映射表和聚合结果表更新聚合结果表UV、PV查询 基础介绍 RoaringBitmap主要为了解决UV指标…

第十一章 目标检测中的NMS

精度提升 众所周知,非极大值抑制NMS是目标检测常用的后处理算法,用于剔除冗余检测框,本文将对可以提升精度的各种NMS方法及其变体进行阶段性总结。 总体概要: 对NMS进行分类,大致可分为以下六种,这里是依…

训练跳跃(青蛙跳台阶),剑指offer,力扣

目录 题目地址: 题目: 青蛙跳台阶问题 我们直接看题解吧: 相似题目,斐波那契数列: 解题方法: 难度分析: 审题目事例提示: 解题思路: 代码实现: 小鸡识补充 题…

手机弱网测试工具:Charles

我们在测试app的时候,需要测试弱网情况下的一些场景,那么使用Charles如何设置弱网呢,请看以下步骤: 前提条件: 手机和电脑要在同一局域网内 Charles连接手机抓包 一、打开Charles,点击代理,…

如何搭建测试环境?一文解决你所有疑惑!

什么是测试环境 测试环境,指为了完成软件测试工作所必需的计算机硬件、软件、网络设备、历史数据的总称,简而言之,测试环境硬件软件网络数据准备测试工具。 硬件:指测试必需的服务器、客户端、网络连接等辅助设备。 软件&#…

Java 省考试院自学考试考籍管理系统

1) 项目简介 考籍管理系统是省考试院自学考试管理系统的一部分,包括考生考籍档案管理、考生免考管理、课程顶替、考籍转入转出管理、毕业管理和日志管理等功能模块。该项目的建设方便和加强了省考试院对自学考试考籍的一系列管理操作,社会效应明显。…

React函数组件状态Hook—useState《进阶-对象数组》

React函数组件状态-state 对象 state state 中可以保存任意类型的 JavaScript 值,包括对象。但是,你不应该直接修改存放在 React state 中的对象。相反,当你想要更新⼀个对象时,你需要创建⼀个新的对象(或者将其拷⻉⼀…

股票指标信息(六)

6-指标信息 文章目录 6-指标信息一. 展示股票的K线图数据,用于数据统计二. 展示股票指标数据,使用Java处理,集合形式展示三. 展示股票目前的最新的指标数据信息四. 展示股票指标数据,某一个属性使用Java处理五. 展示股票的指标数据,用于 Echarts 页面数据统计六. 展示股票指标数…

【开题报告】基于uni-app的汽车租赁app的设计与实现

1.项目背景及意义 项目背景: 随着人们生活水平的提高,汽车租赁服务在城市中变得越来越普及。传统的租车方式存在一些问题,比如租车流程繁琐、费用不透明、选择有限等。因此,开发一款基于uni-app的汽车租赁app成为了满足用户需求…

MAX/MSP SDK学习05:A_GIMME方法

今天终于将A_GIMME方法部分的描述看懂了,上周因为太赶时间加上这文档很抽象一直没看懂。也就那么一回事,记录一下。 A_GIMME方法用于接收多个参数: #include "ext.h" // standard Max include, always required #include "…

vue3拖拽排序 使用 vuedraggable

vue.draggable.next vue3 拖拽排序 vue.draggable.next 下载 pnpm add vuedraggablenext使用 <script lang"ts" setup> import { reactive } from "vue"; import draggable from "vuedraggable";const list reactive([{id: 1,name: &…

在Uni-app中实现计时器效果

本文将介绍如何在Uni-app中使用Vue.js的计时器功能实现一个简单的计时器效果。 首先&#xff0c;我们需要创建一个包含计时器的组件。以下是一个基本的计时器组件示例&#xff1a; <template><div class"timer"><p>{{ formatTime }}</p><…

Android 12.0 默认授予应用权限

Android 12.0 默认授予应用权限 最近接到客户需求提到每当首次点开某个应用时都会弹出申请权限的弹窗&#xff0c;操作起来感觉很麻烦&#xff0c;需要将指定的这个应用默认授予权限&#xff0c;具体修改参照如下&#xff1a; frameworks/base/services/core/java/com/androi…

uniapp 微信小程序如何实现多个item列表的分享

以下代码是某个循环里面的item <button class"cu-btn" style"background-color: transparent;padding: 0;"open-type"share" :data-tree"item.treeId" :data-project"item.projectId"v-if"typeId1 && userI…

阿里云3M固定带宽服务器速度快吗?是否够用?

阿里云服务器3M带宽下载速度是384KB/秒&#xff0c;上传速度是1280KB/s&#xff08;折合1.25M/秒&#xff09;&#xff0c;3M固定带宽够用吗&#xff1f;对于一般流量不是太大的个人博客、企业官网、论坛社区、小型电商网站或搭建个人学习环境或测试环境是完全够用的&#xff0…

Spring AOP用法(待完善)

Cglib实现AOP // 切所有方法Testpublic void cglib1() {UserService target new UserService();// 通过cglib实现AOPEnhancer enhancer new Enhancer();enhancer.setSuperclass(UserService.class);// 定义额外逻辑&#xff0c;也就是代理逻辑// o:代理对象; method:被代…

mysql 怎么做定时备份 / mysql 备份 / sql文件导出

在MySQL数据库中&#xff0c;你可以使用不同的方法来定时备份数据库。以下是其中的一种方法&#xff0c;使用Linux系统中的cron任务和mysqldump命令来创建定时备份&#xff1a; 创建备份脚本&#xff1a; 首先&#xff0c;创建一个脚本文件&#xff0c;比如backup_script.sh&am…

【ceph】ceph集群的故障域是怎么快速修改导入导出

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》&#xff1a;python零基础入门学习 《python运维脚本》&#xff1a; python运维脚本实践 《shell》&#xff1a;shell学习 《terraform》持续更新中&#xff1a;terraform_Aws学习零基础入门到最佳实战 《k8…