【Pytorch神经网络理论篇】 34 样本均衡+分类模型常见损失函数

同学你好!本文章于2021年末编写,获得广泛的好评!

故在2022年末对本系列进行填充与更新,欢迎大家订阅最新的专栏,获取基于Pytorch1.10版本的理论代码(2023版)实现,

Pytorch深度学习·理论篇(2023版)目录地址为:

CSDN独家 | 全网首发 | Pytorch深度学习·理论篇(2023版)目录本专栏将通过系统的深度学习实例,从可解释性的角度对深度学习的原理进行讲解与分析,通过将深度学习知识与Pytorch的高效结合,帮助各位新入门的读者理解深度学习各个模板之间的关系,这些均是在Pytorch上实现的,可以有效的结合当前各位研究生的研究方向,设计人工智能的各个领域,是经过一年时间打磨的精品专栏!https://v9999.blog.csdn.net/article/details/127587345欢迎大家订阅(2023版)理论篇

以下为2021版原文~~~~

 

1 样本均衡

当训练样本不均衡时,可以采用过采样、欠采样、数据增强等手段来避免过拟合。

1.1 使用权重采样类

Sampler类中有一个派生的权重采样类WeightedRandomSampler,能够在加载数据时,按照指定的概率进行随机顺序采样。

WeightedRandomSampler(samples_weight, samples_num)

1、weights:对应的是“样本”的权重而不是“类别的权重”。 也就是说:有一千个样本,weight的数值就有一千个,因此有 len(weight)= 样本数。
2、num_sampes:提取的样本数目,待选取的样本数目一般小于全部的样本数
3、replacement:用于指定是否可以重复选取某一个样本,默认为True,即允许在个epoch中重复选取某一个样本。如果设为False,则当某一类的样本被全部选完,但其样本数自仍未达到num_samples时,sampler将不会再从该类中选取本,此时可能导致weights参数失效。

1.2 WeightedRandomSampler图文解释

如下图,weight是一些tensor,代表每个位置的样本对应的权重,WeightedRandomSampler(weights, 6, True) ,表示按照weight给出的权重,生成六个索引,而且是重复取样。

 从输出可以看出,位置 [1] = 10 由于权重较大,被采样的次数较多,位置[0]由于权重为0所以没有被采样到,其余位置权重低所以都仅仅被采样一次。

1.2.1 获得某个数据集的权重手动计算方法

weight = [ ] 里面每一项代表该样本种类占总样本的倒数。

例如: 数据集 animal = [ cat, cat, dog, dog, dog],cat有两个,dog有三个。

解:
        第一步:先计算每种动物的占比,
                cat_count = 2/5 = 0.4
                dog_count = 3/5 = 0.6

        第二步:再计算count的倒数,也就是占比的倒数,这个数值就是weight
                cat_weight = 1/count = 1/0.4 = 2.5
                dog_weight = 1/count = 1/0.6 = 1.67

        第三步:生成权重
                weight 列表就可以写作:weight = [2.5, 2.5, 1.67, 1.67, 1.67]

1.3 WeightedRandomSampler代码实战

1.3.1 把1000条数据,概率相等的采样,采200条数据:

from torch.utils.data import WeightedRandomSamplerweights=[1]*1000bbb=list(WeightedRandomSampler(weights, 200, replacement=True))print(bbb)

1.3.2 dataset类上的实现

weights =aaa=[1]*20000
sampler=WeightedRandomSampler(weights,num_samples=200,replacement=True)_image_size = 32
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]
trans = transforms.Compose([transforms.RandomCrop(_image_size),# transforms.RandomHorizontalFlip(),# transforms.ColorJitter(.3, .3, .3),transforms.ToTensor(),# transforms.Normalize(_mean, _std),
])if __name__ == '__main__':train_ds = DogsCatsDataset(r"D:\data\ocr\wanqu\archive", "train", transform=trans)train_dl = DataLoader(train_ds, batch_size=2,num_workers=1,sampler=sampler)# train_dl = DataLoader(train_ds, batch_size=20,num_workers=1,shuffle=True)for i, (data, target) in enumerate(train_dl):# print(i,target)if len(np.where(target.numpy() == 1)[0])>0:print('find 1')

1.3.3  Tip

在Dataloader类中,使用了采样器Sampler类就不能使用shume参数。

1.4 权重采样的影响

通过采样的方式进行样本均衡,只是一种辅助手段,它也会引入一些新的问题。在条件允许的情况下,还是推荐将所收集的样本尽量趋于均衡。

1.4.1 过采样

重复正比例数据,实际上没有为模型引入更多数据,过分强调正比例数据,会放大正比例噪声对模型的影响。

1.4.2 欠采样

丢弃大量数据,和过采样一样会存在过拟合的问题。

1.5 通过权重损失控制样本均衡

在多标签非互斥的分类任务(一个对象可以被预测出多种分类)中,还可以使用   BCEWithLogitsLoss函数,在计算损失时为每个类别分配不同的权重。

这种方式可以使模型对每个类别的预测能力达到均衡。例如,多分类的个数是6,则可以使用类似的代码指定每个分类的权重:

pos_weight = torch.ones( [6] )#为每个分类指定权重为1
criterion = torch.nn.BCEwithLogitsioss( posweight = pos_weight)

2 分类模型常用的损失函数

2.1 BCELoss

用于单标签二分类或者多标签二分类,即一个样本可以有多个分类,彼此不互斥。输出和目标的维度是(batch,C),batch是样本数量,C是类别数量。每个C值代表属于一类标签的概率。

2.2 BCEWthLogtsLoss

用于单标签二分类或者多标签二分类,它相当于Sigmoid+BCELoss,即对网络输出的结果先做一次Sigmoid将其值域变为[0,1],再对其与标签之间做BCELoss。

当网络最后一层使用nn.Sigmoid时,就用BCELoss。

当网络最后一层不使用nn.Sigmoid时,就用BCEWithLogitsLoss。

2.3 CrossEntropyLoss

用于多类别分类,输出和目标的维度是(batch,C),batch是样本数量,C是类别数量。每一个C之间是互斥的,相互关联的。

对于每一个batch的C个值,一起求每个C的softmax,所以每个batch的所有C个值之和是1,哪个值大,代表其属于哪一类。

若用于二分类,那输出和目标的维度是(batch,2)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/469251.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

安卓 camera 调用流程_安卓如何做出微信那样的界面仿微信“我”的界面1/5

本系列目标通过安卓编程仿写微信“我”的界面,让大家也能做出类似微信界面.效果图如下:本文目标做出页面顶部的相机部分(其他部分在后续文章中逐步分享).效果图如下:实现方案通过截图工具或者下载一张照相机照片,放到工程的src/main/res/drawable目录下,命名为camera.png添加一…

【Pytorch神经网络实战案例】26 MaskR-CNN内置模型实现目标检测

1 Pytorch中的目标检测内置模型 在torchvision库下的modelsldetecton目录中,找到__int__.py文件。该文件中存放着可以导出的PyTorch内置的目标检测模型。 2 MaskR-CNN内置模型实现目标检测 2.1 代码逻辑简述 将COCO2017数据集上的预训练模型maskrcnm_resnet50_fp…

【Pytorch神经网络实战案例】27 MaskR-CNN内置模型实现语义分割

1 PyTorch中语义分割的内置模型 在torchvision库下的models\segmentation目录中,找到segmentation.Py文件。该文件中存放着PyTorch内置的语义分割模型。 2 MaskR-CNN内置模型实现语义分割 2.1 代码逻辑简述 将COCO 2017数据集上的预训练模型dceplabv3_resnet101…

怎么查看电脑内存和配置_电脑内存不足处理方法,电脑卡死处理方法。

超过10万人正在关注赶快来关注吧,这里有你想找的热点资讯,这里有你想要的各种资料,还有海量的资源,还在等什么。快来关注,大佬带你开车。电脑系统经常奔溃,软件经常运行不了,开不了机&#xff0…

前端开源项目周报0307

由OpenDigg 出品的前端开源项目周报第十一期来啦。我们的前端开源周报集合了OpenDigg一周来新收录的优质的前端开源项目,方便前端开发人员便捷的找到自己需要的项目工具等。 react-trend 简单优雅的光线 react-progressive-web-app 优化ProgressiveWeb应用开发 pull…

【Pytorch神经网络理论篇】 35 GaitSet模型:步态识别思路+水平金字塔池化+三元损失

同学你好!本文章于2021年末编写,获得广泛的好评! 故在2022年末对本系列进行填充与更新,欢迎大家订阅最新的专栏,获取基于Pytorch1.10版本的理论代码(2023版)实现, Pytorch深度学习理论篇(2023版)目录地址…

win7分区软件_神奇的工作室win7旗舰版重装系统连不上网怎么解决

深度技术win7系统下载有的时刻我们的电脑安装、重装了win10操作系统之后有的小伙伴们就发现了自己的电脑连不上网了。对于这种问题小编以为可能是我们的电脑在安装系统的过程中泛起了一些内部组件的冲突或者是由于网卡驱动没有安装好导致的,可以通过重新安装、重装驱…

【Pytorch神经网络实战案例】28 GitSet模型进行步态与身份识别(CASIA-B数据集)

1 CASIA-B数据集 本例使用的是预处理后的CASIA-B数据集, 数据集下载网址如下。 http://www.cbsr.ia.ac.cn/china/Gait%20Databases%20cH.asp 该数据集是一个大规模的、多视角的步态库。其中包括124个人,每个人有11个视角(0,18&am…

Android Camera调用流程

一个流程图画的非常好的文章 http://blog.csdn.net/lushengchu_luis/article/details/11033095 1、Packages/apps/到framework 打开Camera ./packages/apps/Camera/src/com/android/camera/Camera.java 进来第一个肯定是onCreate(Bundle icicle) { 这里是开始了一个Camera…

【Pytorch神经网络实战案例】29 【代码汇总】GitSet模型进行步态与身份识别(CASIA-B数据集)

1 GaitSet_DataLoader.py import numpy as np # 引入基础库 import os import torch.utils.data as tordata from PIL import Image from tqdm import tqdm import random# 1.1定义函数,加载文件夹的文件名称# load_data函数, 分为3个步骤:…

linq from 多个sum_快手重拳打击劣质电商 7月以来封禁700多个团伙账号

何为劣质电商?炒作演戏?PK售卖劣质商品?私下交易?夸大其词?……在快手电商的定义里,有上述不良行为的,都可以定义为劣质电商。快手电商站内官方号“快手卖货助手”日前发布第 11 期“自售或PK销…

win10怎么更改账户名称_Win10邮件功能如何查看邮件

win10的用户当中,一方面有说系统臃肿的,另外一方面有说功能多了不少,好用。不管是出于前者还是后者,win10功能确实多了不少,尤其是一些比较常用的功能,比如说邮箱功能,一般用户可能会选择登录网…

AttributeError: ‘set‘ object has no attribute ‘items‘

AttributeError: ‘set’ object has no attribute ‘items’ 出现这个问题,原因可能是定义的header有问题 正确如下: header{“key”:“value”} 如果是直接在请求数据中复制,很有可能会忽略键和值的冒号。

使用eclipse以及Juint进行测试

打开eclipse后,点击左上角的File,新建一个project,命名为testJunit,然后在src目录下新建两个包,分别命名为TestScore和Test(这是文件夹里没有文件所以是白色)。 在TestScore中新建一个class,命名为Score.ja…

excel单元格下拉选项怎么设置_单元格下拉效果怎么实现?

单元格右边的下拉菜单怎么做的?感觉逼格略有提升啊上视频单元格下来效果https://www.zhihu.com/video/1249633577441800192

电脑如何测网速_物联网卡的网速到底怎么样呢

最近不少朋友发私信问我,物联网卡网速到底怎么样,和手机卡的网速有没有什么区别?其实关于网速这个问题,我已经重复解释了很多遍。只要是走公网的流量卡,在不限速的情况下,基本是和你手机卡网速是一致的&…

dll可以在linux下使用吗_Linux下安装和使用杀毒软件AntiVir

小白玩转智能数据湖,20分钟开发实时豆瓣评分Top20电影的脚本!>>> 提起计算机病毒来,可谓人人皆知,有些吃过病毒苦头的人更是有点谈虎色变的感觉。其实无论对于企业还是个人,病毒的危害都是不可避免的&#xf…

[转]微信小程序登录逻辑梳理

本文转自:http://www.jianshu.com/p/d9996cafdb31 官方文档 文档相关地址: 用户登录 获取用户数据 用户数据的签名验证和加解密 登录时序图.png微信两个api所拿到的信息:login和getUserInfo 返回的信息.png注册/登录 小程序端: 通过上面wx.login和wx.getUserInfo两个…

转一篇写的比较好的camera文档[Camera 图像处理原理分析]

色彩篇(一) 1 前言 做为拍照手机的核心模块之一,camera sensor效果的调整,涉及到众多的参数,如果对基本的光学原理及sensor软/硬件对图像处理的原理能有深入的理解和把握的话,对我们的工作将会起…