pytorch 训练过程acc_Pytorch之Softmax多分类任务

dc73925982f5a6979ac84683f399e725.png

上一篇文章中,笔者介绍了什么是Softmax回归及其原理。因此在接下来的这篇文章中,我们就来开始动手实现一下Softmax回归,并且最后要完成利用Softmax模型对Fashion MINIST进行分类的任务。在开始实现Softmax之前,我们先来了解一下Fashion MINIST这一数据集。

1 数据集

1.1 FashionMNIST

数据集FashionMNIST虽然名字里面有'MNIST'这个词,但是其与手写体识别一点关系也没有,仅仅只是因为FashionMNIST数据集在数据集规模、类别数量和图片大小上与MINIST手写体数据集一致。

ed07e2727b740bec1ff00cffa35927bd.png
图 1. Fashion MINIST数据集

如图1所示便为Fashion MNIST数据集的部分可视化结果,其包含有训练集6万张和测试集1万张,每张图片的大小为[28,28]。在Pytorch中,我们可以通过如下代码对其进行载入:

def load_dataset():
    mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',  train=True, download=True,transform=transforms.ToTensor())
    mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',   train=False, download=True,transform=transforms.ToTensor())
    return mnist_train, mnist_test

其中参数root表示指定数据集的保存目录;train表示返回训练集还是测试集;download表示数据集不存在时是否需要下载;transform表示指定一种转换方法,而指定transforms.ToTensor()就是将尺寸为(H x W x C)且数据位于[0,255]的PIL图片或者数据类型为np.unit8的numpy数组转换为尺寸为(C x H x W)且数据类型为torch.float32,范围在的Tensor

同时,我们还可以通过代码image= mnist_train[0][0]label=mnist_train[0][1]来分别访问一张图片和其对应的标签值。

1.2 构造数据集

在模型实际训练过程中,数据读取经常是训练的性能瓶颈。同时,为了能够更好的训练模型我们通常会对数据进行打乱,以及分批(batch)的将数据输入的模型中。在Pytorch中,我们可以通过DataLoader这个类来方便的完成上述功能。

start = time.time()
train_iter = torch.utils.data.DataLoader(mnist_test, batch_size=1024, shuffle=True, num_workers=2)
for x_test, y_test in train_iter:
    print(x_test.shape)
    print('%.2f sec' % (time.time() - start))

#结果
torch.Size([1024, 1, 28, 28])
torch.Size([1024, 1, 28, 28])
torch.Size([1024, 1, 28, 28])
torch.Size([1024, 1, 28, 28])
torch.Size([1024, 1, 28, 28])
torch.Size([1024, 1, 28, 28])
torch.Size([1024, 1, 28, 28])
torch.Size([1024, 1, 28, 28])
torch.Size([1024, 1, 28, 28])
torch.Size([784, 1, 28, 28])
2.60 sec

其中batsh_size表示指定每次返回batsh_size个样本;shuffle=True表示对数据集进行打乱;num_workers=2表示用两个进程来读取数据。

但需要注意的是,这里的数据集mnist_test是Pytorch内置的,那如果是我们自己读入的数据集该怎么使用DataLoader呢?我们只需要首先将自己的数据集转换成tensor,然后再通过TensorDataset这个类来构造一个数据集即可。

def make_dataset():
    x = torch.linspace(0, 100, 100, dtype=torch.float32).reshape(-1, 2)
    y = torch.randn(50 )
    dataset = torch.utils.data.TensorDataset(x, y)
    return dataset

此时返回的dataset数据集也就同样能够通过DataLoader进行读取。

2 Softmax多分类

在实现这个分类模型之前,我们先来介绍一下几个需要用到的函数。

2.1 softmax计算实现

在上一篇文章中我们介绍了softmax的计算公式,其实现可以通过如下代码来完成:

def softmax(x):
    s = torch.exp(x)
    return s / torch.sum(s, dim=1, keepdim=True)# 此处触发了广播机制

a = torch.tensor([[1,2,3.],[2,3,1.]])
print(softmax(a))
#结果:
tensor([[0.0900, 0.2447, 0.6652],
        [0.2447, 0.6652, 0.0900]])

其中torch.exp()为计算每个元素的指数次方;sum(s, dim=1, keepdim=True)表示计算得到每一行的和;最后是按位除操作。需要注意的是传入的x必须是浮点类型的,不然会报错。

2.2 交叉熵计算实现

假设我们现在有两个样本,其预测得到的概率分布为[0.1,0.3,0.6][0.5,0.4,0.1]。同时,正确的标签分布为[0,0,1][0,1,0],则对应的交叉熵为。但是,我们在用代码实现的时候完全不用这么麻烦,只需要通过正确的标签找到预测概率分布中对应的值,再取对数即可。

例如[0,0,1][0,1,0]这两个真实分布对应的标签就应该是2和1(从0开始),因此我们只需要分别取[0.1,0.3,0.6][0.5,0.4,0.1]中第2个元素0.6和第1个原始0.4,再取对数就能实现交叉熵的计算。

上述过程通过如下代码便可完成:

def crossEntropy(logits,y):
    c = -torch.log(logits.gather(1,y.reshape(-1,1)))
    return torch.sum(c)# 注意这里返回的是和

logits = torch.tensor([[0.1, 0.3, 0.6], [0.5, 0.4, 0.1]])
y = torch.LongTensor([2, 1])
c = crossEntropy(logits,y)
print(c)

#结果
tensor(1.4271)

其中.gather()就是根据指定维度和索引,选择对应位置上的元素。同时,需要注意的是logits的每一行为一个样本的概率分布,因此我们需要在行上进行索引操作,故gather()的第一个参数应该是1,这一点一定要注意。

2.3 准确率计算实现

在前面介绍softmax时说到,对于每个样本的预测类别,我们会选择对应概率值最大的类别作为输出结果。因此,在计算预测的准确率时,我们首先需要通过torch.argmax()这个函数来返回预测得到的标签。

y_true = torch.tensor([[2,1]])
logits = torch.tensor([[0.1,0.3,0.6],[0.5,0.4,0.1]])
y_pred = logits.argmax(1)
print(y_pred)

#结果
tensor([2, 0])

最后,我们将预测得到的标签同正确标签进行对比即可求得准确率。

def accuracy(y_true,logits):
    acc = (logits.argmax(1) == y_true).float().mean()
    return acc.item()

print(accuracy(y_true,logits))
#结果
0.5

2.4 评估模型

一般我们训练得到一个模型后都需要对其在测试集上进行评估,也就是在测试集上计算其总的准确率。因此,我们首先需要计算得到所有预测对的样本(而不仅仅只是一个batch),然后再除以总的样本数即可。

def evaluate(data_iter, forward, input_nodes, w, b):
    acc_sum, n = 0.0, 0
    for x, y in data_iter:
        logits = forward(x, input_nodes, w, b)
        acc_sum += (logits.argmax(1) == y).float().sum().item()
        n += len(y)
    return acc_sum / n

2.5 分类模型实现

w = torch.tensor(np.random.normal(0, 0.5, [input_nodes, output_nodes]),
                 dtype=torch.float32, requires_grad=True)
b = torch.tensor(np.random.randn(output_nodes), dtype=torch.float32, requires_grad=True)
for epoch in range(epochs):
    for i, (x, y) in enumerate(train_iter):
        logits = forward(x, input_nodes, w, b)
        l = crossEntropy(y, logits)
        l.backward()
        gradientDescent([w, b], lr)
        acc = accuracy(y, logits)
        if i % 50 == 0:
            print("Epoches[{}/{}]---batch[{}/{}]---acc{:.4}---loss {:.4}".format(
                epoches, epoch, len(mnist_train) // batch_size, i, acc,l))
            acc = evaluate(test_iter, forward, input_nodes, w, b)
            print("Epoches[{}/{}]--acc on test{:.4}".format(epochs, epoch, acc))
# 结果:
Epochs[8000/20]--acc on test0.8323
Epochs[8000/21]---batch[468/0]---acc0.8516---loss 47.13
Epochs[8000/21]---batch[468/50]---acc0.8203---loss 67.22
Epochs[8000/21]---batch[468/100]---acc0.9219---loss 38.74
Epochs[8000/21]---batch[468/150]---acc0.8516---loss 57.39
Epochs[8000/21]---batch[468/200]---acc0.8281---loss 74.76
Epochs[8000/21]---batch[468/250]---acc0.8672---loss 55.32
Epochs[8000/21]---batch[468/300]---acc0.8281---loss 60.19

可以看到,大约20轮迭代后,softmax模型在测试集上的准确率就达到了0.83左右。

3 总结

在这篇文章中,笔者首先介绍了FashionMNIST数据集。然后接着介绍了如何使用Pytorch中的DataLoader来构造训练数据迭代器。最后,介绍了如何通过Pytorch来一步步的实现Softmax分类模型,包括如何实现softmax操作、如何快捷的计算交叉熵、如何计算模型的准确率等等。本次内容就到此结束,感谢您的阅读!

本次内容就到此结束,感谢您的阅读!如果你觉得上述内容对你有所帮助,欢迎关注并传播本公众号!若有任何疑问与建议,请添加笔者微信'nulls8'进行交流。青山不改,绿水长流,我们月来客栈见!

引用

[1]动手深度学习

[2]示例代码:https://github.com/moon-hotel/DeepLearningWithMe

推荐阅读

[1]想明白多分类必须得谈逻辑回归

[2]Pytorch之Linear与MSELoss

[3]Pytorch之拟合正弦函数你会吗?

[4]你告诉我什么是深度学习

[5]《跟我一起深度学习》终于来了

c9a78e0337f802ee2973f701586f69df.png

052b81601840b30d2c92efe58f50d5c8.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/500988.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

进程调度实验_Linux应用编程之进程的PID与PPID

关注、星标公众号,直达精彩内容ID:嵌入式情报局作者:情报小哥1进程PID首先介绍PID的相关知识,为后面介绍fork函数进行铺垫。01PID与PPID PID不是控制理论的PID算法,而是Prcess ID的简写。进程PID是当操作系统运行进程时…

操作Windows文件夹时,弹出文件夹正在使用,操作无法完成【解决】

在windows系统上,有时候在删除系统文件或文件夹时出现弹框,提示操作无法完成。这种情况的出现是因为你要删除的文件或文件夹被打开,或者被系统占用。遇到这种情况要怎么处理呢,本文介绍下具体的操作方法来帮助你解决这个问题。 方…

邀请合作如何表达_适时表达想法 才有利于彼此的合作

丹尼跟珍妮合作主持一个podcast节目,两人对这个节目兴致勃勃,并花很多时间投入,珍妮想邀请自己身边朋友一起参加,认为特别来宾可以增加节目的丰富度;丹尼却觉得现在节目才刚开始起步,要建立好两人的节目定位…

idea代码可以编译但是爆红_推荐一款 IDEA 生成代码神器,写代码再也不用加班了...

作者:HeloWxl链接:https://www.jianshu.com/p/e4192d7c6844Easycode是idea的一个插件,可以直接对数据的表生成entity,controller,service,dao,mapper,无需任何编码,简单而强大。1、安装(EasyCode)我这里的话是已经那装好了。建议大…

html跑马灯_用Excel居然能做“跑马灯”,而且还这么简单!

我的目标:让中国的大学生走出校门的那一刻就已经具备这些office技能,让职场人士能高效使用office为其服务。支持我,也为自己加油!你没看错,上面这个就是用Excel做出来的,不过要用到窗体和控件。步骤如下&am…

c语言双链表排序交换节点_图解:单链表翻转的三种方式!

当我们在聊到链表反转的时候,一定说的都是单链表,双链表本身就具有前驱指针 Prev 和后续指针 next,无需进行翻转。单链表反转,反转后的效果如下:看起来很简单,只需要将单链表所有结点的 next 指向&#xff…

wsdl文档中的soap:address的生成规则_BAT大牛都在使用的数据库文档生成插件,不来看一下?...

一、概述在企业级开发中、我们经常会有编写数据库表结构文档的时间付出,从业以来,待过几家企业,关于数据库表结构文档状态:要么没有、要么有、但都是手写、后期运维开发,需要手动进行维护到文档中,很是繁琐…

修订模式怎么彻底关闭_电脑玩游戏卡顿怎么办?

电脑玩游戏卡怎么办?在玩游戏时电脑卡真的是会气死人的,特别是在打团的时候卡了,想砸电脑有木有?那么电脑玩游戏卡怎么办呢?给大家介绍几个方法,可以尝试改善卡顿。软件方面:1、 开启电源性能模…

datepicker不能选择是为什么_为什么客厅不好看?休闲椅选错了

为什么客厅不好看?休闲椅选错了很多装修完毕的小伙伴们经常有这样一个疑问: 为什么我家的客厅看上去这么凌乱,一点都没有想象中井然有致?这其中的潜在原因有很多,比如沙发墙的装饰设计有误,比如各类家具的款式搭配不对…

如何打开屏幕坏的手机_每天打开手机屏幕20次?打开10次以上的朋友进~

现代社会最很普遍的现象就是不管有没有事,不断地打开手机屏幕看时间或者刷各种信息和段子。 一块小小的屏幕却有着巨大的魅力。明明没有任何事情要干,却还是忍不住诱惑(cant resist temptation [tɛmpˈteʃən])想要打开屏幕,仿佛潘多拉的盒…

mvc 两个控制器session 丢失_用纯 JavaScript 撸一个 MVC 程序

前言我想用 model-view-controller 架构模式在纯 JavaScript 中写一个简单的程序,于是我这样做了。希望它可以帮你理解 MVC,因为当你刚开始接触它时,它是一个难以理解的概念。我做了这个todo应用程序,这是一个简单小巧的浏览器应用…

redis线程阻塞原因排插_每次面试都要被问:为什么采用单线程的Redis也会如此之快?...

众所周知,Redis在内存库数据库领域非常地火热,它极高的性能和丰富的数据结构为我们的开发提供了极大的便利。但我们也听说了,Redis是单线程的,为什么采用单线程的Redis也会如此之快呢?这篇文章我们来分析一下其中的缘由…

审计日志_Oracle审计日志过大?如何清理及关闭审计机制?

概述oracle 11g推出了审计功能,但这个功能会针对很多操作都产生审计文件.aud,日积月累下来这些文件也很多,默认情况下,系统为了节省资源,减少I/0操作,其审计功能是关闭的。这段时间发现审计占了比较多空间&…

servlet如何使用session把用户的手机号修改_SpringBoot源码学习系列之嵌入式Servlet容器...

1、前言简单介绍SpringBoot的自动配置就是SpringBoot的精髓所在;对于SpringBoot项目是不需要配置Tomcat、jetty等等Servlet容器,直接启动application类既可,SpringBoot为什么能做到这么简捷?原因就是使用了内嵌的Servlet容器&…

mybatisplus新增返回主键_第17期:索引设计(主键设计)

表的主键指的针对一张表中的一列或者多列,其结果必须能标识表中每行记录的唯一性。InnoDB 表是索引组织表,主键既是数据也是索引。主键的设计原则1. 对空间占用要小上一篇我们介绍过 InnoDB 主键的存储方式,主键占用空间越小,每个…

mysql 集群与主从_Mysql集群和主从

1、Mysql cluster: share-nothing,分布式节点架构的存储方案,以便于提供容错性和高性能。需要用到mysql cluster安装包,在集群中的每一个机器上安装。有三个关键概念:Sql节点(多个),数据节点(多个),管理节点(一个)&…

redis缓存原理与实现_基于Redis实现范围查询的IP库缓存设计方案

点击上方“码农沉思录” 发现更多精彩我先说下结果。我现在还不敢放线上去测,这是本地测的数据,我4g内存的电脑本地开redis,一次都没写完过全部数据,都是写一半后不是redis挂就是测试程序挂。可以肯定的是总记录数是以千万为单位…

mysql原生库_Mysql数据库的一些简单原生sql语句

原生sql语句查询:select * from 表名 :查找表内所有数据, * 代表所有where 具体条件 :where作位查询sql语句条件,例 select * from 表名 where 字段名指定值order by 升降序:与desc和asc使用,通常以int类型字段进行升…

有向图生成树是如何画的_漫画:什么是最小生成树?

作者 | 小灰来源 | 程序员小灰————— 第二天 —————————————————首先看看第一个例子,有下面这样一个带权图:它的最小生成树是什么样子呢?下图绿色加粗的边可以把所有顶点连接起来,又保证了边的权值之和最小&a…

printf 指针地址_c语言对指针的理解

先来讲一下本人学指针的经历:大一的时候刚接触c语言对指针这东西真的是太迷了,感觉麻烦难懂不想其他语言一样。但是搞懂以后就被指针的魅力吸引甚至喜欢上c语言。不多讲,开始!(文章可能有些长,但放心全是基础的东西&am…