使用Pytorch完成多分类问题

多分类问题在最后的输出层采用的Softmax Layer,其具有两个特点:1.每个输出的值都是在(0,1);2.所有值加起来和为1.

假设是最后线性层的输出,则对应的Softmax function为:

     


输出经过sigmoid运算即可是西安输出的分类概率都大于0且总和为1。


上图的交叉熵损失就包含了softmax计算和右边的标签输入计算(即框起来的部分)


所以在使用交叉熵损失的时候,神经网络的最后一层是不要做激活的,因为把它做成分布的激活是包含在交叉熵损失里面的,最后一层不要做非线性变换,直接交给交叉熵损失。

如上图,做交叉熵损失时要求y是一个长整型的张量,构造时直接用

criterion = torch.nn.CrossEntropyLoss()


3个类别,分别是2,0,1
Y_pred1 ,Y_pred2还是线性输出,没经过softmax,还不是概率分布,比如Y_pred1,0.9最大,表示对应为第3个的概率最大,和2吻合,1.1最大,表示对应为第1个的概率最大,和0吻合,2.1最大,表示对应为第2个的概率最大,和1吻合,那么Y_pred1 的损失会比较小
对于Y_pred2,0.8最大,表示对应为第1个的概率最大,和0不吻合,0.5最大,表示对应为第3个的概率最大,和2不吻合,0.5最大,表示对应为第3个的概率最大,和2不吻合,那么Y_pred2 的损失会比较大

 

Exercise 9-1: CrossEntropyLoss vs NLLLoss
What are the differences?
• Reading the document:
• https://pytorch.org/docs/stable/nn.html#crossentropyloss
• https://pytorch.org/docs/stable/nn.html#nllloss
• Try to know why:
• CrossEntropyLoss <==> LogSoftmax + NLLLoss

为什么要用transform

transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ])

PyTorch读图像用的是python的imageLibrary,就是PIL,现在用的都是pillow,pillow读进来的图像用神经网络处理的时候,神经网络有一个特点就是希望输入的数值比较小,最好是在-1到+1之间,最好是输入遵从正态分布,这样的输入对神经网络训练是最有帮助的

原始图像是28*28的像素值在0到255之间,我们把它转变成图像张量,像素值是0到1

在视觉里面,灰度图就是一个矩阵,但实际上并不是一个矩阵,我们把它叫做单通道图像,彩色图像是3通道,通道有宽度和高度,一般我们读进来的图像张量是WHC(宽高通道)
在PyTorch里面我们需要转化成CWH,把通道放在前面是为了在PyTorch里面进行更高效的图像处理,卷积运算。所以拿到图像之后,我们就把它先转化成pytorch里面的一个Tensor,把0到255的值变成0到1的浮点数,然后把维度由2828变成128*28的张量,由单通道变成多通道,

这个过程可以用transforms的ToTensor这个函数实现


归一化


transforms.Normalize((0.1307, ), (0.3081, ))

这里的0.1307,0.3081是对Mnist数据集所有的像素求均值方差得到的
也就是说,将来拿到了图像,先变成张量,然后Normalize,切换到0,1分布,然后供神经网络训练
如上图,定义好transform变换之后,直接把它放到数据集里面,为什么要放在数据集里面呢,是为了在读取第i个数据的时候,直接用transform处理

 

模型

输入是一组图像,激活层改用Relu
全连接神经网络要求输入是一个矩阵
所以需要把输入的张量变成一阶的,这里的N表示有N个图片

 view函数可以改变张量的形状,-1表示将来自动去算它的值是多少,比如输入是n128*28
将来会自动把n算出来,输入了张量就知道形状,就知道有多少个数值

最后输出是(N,10)因为是有0-9这10个标签嘛,10表示该图像属于某一个标签的概率,现在还是线性值,我们再用softmax把它变成概率

 #沿着第一个维度找最大值的下标,返回值有两个,因为是10列嘛,返回值一个是每一行的最大值,另一个是最大值的下标(每一个样本就是一行,每一行有10个量)(行是第0个维度,列是第1个维度)


 MNIST数据集训练代码

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim# prepare datasetbatch_size = 64transform = transforms.Compose([transforms.ToTensor(), #先将图像变换成一个张量tensor。transforms.Normalize((0.1307,), (0.3081,))#其中的0.1307是MNIST数据集的均值,0.3081是MNIST数据集的标准差。
])  # 归一化,均值和方差train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True,download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False,download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)# design model using class
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.l1 = torch.nn.Linear(784, 512)self.l2 = torch.nn.Linear(512, 256)self.l3 = torch.nn.Linear(256, 128)self.l4 = torch.nn.Linear(128, 64)self.l5 = torch.nn.Linear(64, 10)def forward(self, x):# 28 * 28 = 784# 784 = 28 * 28,即将N *1*28*28转化成 N *1*784x = x.view(-1, 784)  # -1其实就是自动获取mini_batchx = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))return self.l5(x)  # 最后一层不做激活,不进行非线性变换model = Net()#CrossEntropyLoss <==> LogSoftmax + NLLLoss。
#也就是说使用CrossEntropyLoss最后一层(线性层)是不需要做其他变化的;
#使用NLLLoss之前,需要对最后一层(线性层)先进行SoftMax处理,再进行log操作。# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
#momentum 是带有优化的一个训练过程参数
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)# training cycle forward, backward, updatedef train(epoch):running_loss = 0.0#enumerate()函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,#同时列出数据和数据下标,一般用在 for 循环当中。#enumerate(sequence, [start=0])for batch_idx, data in enumerate(train_loader, 0):# 获得一个批次的数据和标签inputs, target = dataoptimizer.zero_grad()#forward + backward + update# 获得模型预测结果(64, 10)outputs = model(inputs)# 交叉熵代价函数outputs(64,10),target(64)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))running_loss = 0.0def test():correct = 0total = 0with torch.no_grad():#不需要计算梯度。for data in test_loader:images, labels = dataoutputs = model(images)#orch.max的返回值有两个,第一个是每一行的最大值是多少,第二个是每一行最大值的下标(索引)是多少。_, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0个维度,行是第1个维度total += labels.size(0)correct += (predicted == labels).sum().item()  # 张量之间的比较运算print('accuracy on test set: %d %% ' % (100 * correct / total))if __name__ == '__main__':for epoch in range(10):train(epoch)test()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/469793.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PyTorch的nn.Linear()详解

1. nn.Linear() nn.Linear()&#xff1a;用于设置网络中的全连接层&#xff0c;需要注意的是全连接层的输入与输出都是二维张量 一般形状为[batch_size, size]&#xff0c;不同于卷积层要求输入输出是四维张量。其用法与形参说明如下&#xff1a; in_features指的是输入的二维…

罗彻斯特大学计算机科学系专业排名,罗切斯特大学计算机科学专业

罗切斯特大学(University of Rochester&#xff0c;U of R)建立于1850年&#xff0c;是一所美国著名的私立研究型大学&#xff0c;“新常春藤”联盟之一&#xff0c;北美大学协会(AAU)成员、世界大学联盟成员。360老师介绍&#xff0c;学校的7位学者是美国国家科学院院士&#…

系统权限管理设计 (转)

权限设计&#xff08;初稿&#xff09; 1. 前言&#xff1a; 权限管理往往是一个极其复杂的问题&#xff0c;但也可简单表述为这样的逻辑表达式&#xff1a;判断“Who对What(Which)进行How的操作”的逻辑表达式是否为真。针对不同的应用&#xff0c;需要根据项目的实…

卷积神经网络(基础篇)

说明 0、前一部分叫做Feature Extraction&#xff0c;后一部分叫做classification 1、每一个卷积核它的通道数量要求和输入通道是一样的。这种卷积核的总数有多少个和你输出通道的数量是一样的。 2、卷积(convolution)后&#xff0c;C(Channels)变&#xff0c;W(width)和H(Heig…

Inception(Pytorch实现)

论文在此: Going deeper with convolutions 论文下载: https://arxiv.org/pdf/1409.4842.pdf 网络结构图: import torch import torch.nn as nn import torch.nn.functional as Fclass Inception3(nn.Module):def __init__(self, num_classes1000, aux_logitsTrue, transform…

SecureCRT 用来当串口工具的设置

今天从淘宝网上买的USB转串口线终于到了&#xff0c;从网上下载了驱动&#xff0c;关于USB转串口驱动在我上传的资源里面有&#xff0c;关于SecureCRT这个串口调试工具我也上传了&#xff0c;是个绿色免安装版本。 刚开始的时候一步一步的设置串口&#xff0c;连接串口也可以连…

Brainstorm-the walkthrough example: Image Classification

(1) 运行create data&#xff0c;其中包括下载cifar10&#xff0c;并转换为hdf5格式&#xff08;详见百度百科&#xff1a;http://baike.baidu.com/view/771949.htm#4_2&#xff09;: cifar10的数据简介见&#xff1a;http://www.cs.toronto.edu/~kriz/cifar.html cd data pyth…

卷积神经网络(高级篇) Inception Moudel

Inception Moudel 1、卷积核超参数选择困难&#xff0c;自动找到卷积的最佳组合。 2、1x1卷积核&#xff0c;不同通道的信息融合。使用1x1卷积核虽然参数量增加了&#xff0c;但是能够显著的降低计算量(operations) 3、Inception Moudel由4个分支组成&#xff0c;要分清哪些…

计算机谈音乐薛之谦,明星浮世绘之薛之谦:分析了50多首音乐作品,为其总结了五个特点...

原标题&#xff1a;明星浮世绘之薛之谦&#xff1a;分析了50多首音乐作品&#xff0c;为其总结了五个特点薛之谦&#xff0c;才华横溢思维敏捷&#xff0c;性格搞怪却又忧郁。我曾经用四个字来形容他&#xff0c;沙雕其外&#xff0c;金玉其中。记得老薛曾经发布了一个动态&…

linux内核下载 编译

linux内核下载网址 今天去看了一场电影“疯狂的原始人”----回来的车上看到一个老奶奶传教士,我想对自己多,加油,加油学习,深思深思 我们现在用的安霸系统,每搞一次我都会进行一次备份,一个系统加上GUI一起都有差不多一G多,而今天下载了最新的linux内核版本,才不80M左…

Deep learning

论文&#xff1a;doi:10.1038/nature14539 论文意义和主要内容 三巨头从机器学习谈起&#xff0c;指出传统机器学习的不足&#xff0c;总览深度学习理论、模型&#xff0c;给出了深度学习的发展历史&#xff0c;以及DL中最重要的算法和理论。 概念&#xff1a; 原理&#xff…

第一周:深度学习引言(Introduction to Deep Learning)

1.1 欢迎(Welcome) 深度学习改变了传统互联网业务&#xff0c;例如如网络搜索和广告。但是深度学习同时也使得许多新产品和企业以很多方式帮助人们&#xff0c;从获得更好的健康关注。 深度学习做的非常好的一个方面就是读取X光图像&#xff0c;到生活中的个性化教育&#xf…

无忧计算机二级试题题库,全国计算机二级MS Office试题

考无忧小编为各位考生搜集整理了的二级MS Office试题&#xff0c;希望可以为各位的备考锦上添花&#xff0c;雪中送炭&#xff01;记得刷计算机等级考试题库哟&#xff01;1、被选中要筛选的数据单元格的下拉箭头中有哪几种筛选方式( ABD)A、全部B、前十个C、后十个D、自定义2、…

第二周:神经网络的编程基础之Python与向量化

本节课我们将来探讨Python和向量化的相关知识。 1. Vectorization 深度学习算法中&#xff0c;数据量很大&#xff0c;在程序中应该尽量减少使用循环语句&#xff0c;而可以使用向量运算来提高程序运行速度。 向量化&#xff08;Vectorization&#xff09;就是利用矩阵运算的…

U-boot移槙

1、我是照着这里去移植的 http://blog.chinaunix.net/uid-26306203-id-3716785.html 2、然后make 出现问题&#xff0c;到这里去有解决办法&#xff1a;http://blog.csdn.net/zjt289198457/article/details/6854177 : http://blog.csdn.net/zjt289198457/article/details/68…

第三周:浅层神经网络

1. 神经网络综述 首先&#xff0c;我们从整体结构上来大致看一下神经网络模型。 前面的课程中&#xff0c;我们已经使用计算图的方式介绍了逻辑回归梯度下降算法的正向传播和反向传播两个过程。如下图所示。神经网络的结构与逻辑回归类似&#xff0c;只是神经网络的层数比逻辑…

智慧交通day00-项目简介

汽车的日益普及在给人们带来极大便利的同时&#xff0c;也导致了拥堵的交通路况&#xff0c;以及更为频发的交通事故。智能交通技术已成为推动现代技术交通技术发展的重要力量&#xff0c;智能交通不仅能够提供实时的交通路况信息&#xff0c;帮助交通管理者规划管理策略&#…

智慧交通day01-算法库01:numba

1 numba介绍 numba是一个用于编译Python数组和数值计算函数的编译器&#xff0c;这个编译器能够大幅提高直接使用Python编写的函数的运算速度。 numba使用LLVM编译器架构将纯Python代码生成优化过的机器码&#xff0c;通过一些添加简单的注解&#xff0c;将面向数组和使用大量…

计算机语言恢复,win10系统找回消失不见语言栏的恢复方法

win10系统使用久了&#xff0c;好多网友反馈说关于对win10系统找回消失不见语言栏设置的方法&#xff0c;在使用win10系统的过程中经常不知道如何去对win10系统找回消失不见语言栏进行设置&#xff0c;有什么好的办法去设置win10系统找回消失不见语言栏呢&#xff1f;在这里小编…

智慧交通day01-算法库02:imutils

1.imutils功能简介 imutils是在OPenCV基础上的一个封装&#xff0c;达到更为简结的调用OPenCV接口的目的&#xff0c;它可以轻松的实现图像的平移&#xff0c;旋转&#xff0c;缩放&#xff0c;骨架化等一系列的操作。 安装方法&#xff1a; pip install imutils在安装前应确…