Pytorch深度学习实践(8)多分类任务

多分类问题

多分类问题主要是利用了Softmax分类器,数据集采用MNIST手写数据集

设计方法:

  • 把每一个类别看成一个二分类的问题,分别输出10个概率

    在这里插入图片描述

    但是这种方法存在一种问题:不存在抑制问题,即按照常规来讲,当 y ^ 1 \hat y_1 y^1较大时,其他的值应该较小,换句话说,就是这10个概率相加不为1,即最终的输出结果不满足离散化分布的要求

  • 引入Softmax算子,即在输出层使用softmax
    在这里插入图片描述

Softmax

softmax函数定义:
P ( y = i ) = e z i ∑ j = 0 K − 1 e z j P(y = i) = \frac{e^{z_i}}{\sum _{j=0}^{K-1}e^{z_j}} P(y=i)=j=0K1ezjezi

  • 首先,使用指数 e z i e^{z_i} ezi,保证所有的输出都大于0
  • 其次,分母中对所有的输出进行指数计算后求和,相当于对结果归一化,保证了最后10个类别的概率相加为1

损失函数

NLL损失

L o s s ( Y ^ , Y ) = − Y l o g Y ^ Loss(\hat Y, Y) = -Ylog \hat Y Loss(Y^,Y)=YlogY^

在这里插入图片描述

y_pred = np.exp(z) / np.exp(z).sum()
loss = (-y * np.log(y_pred)).sum

也可以直接将标签为1的类对应的 l o g Y ^ log\hat Y logY^拿出来直接做运算

交叉熵损失

交叉熵损失 = LogSoftmax + NLLLoss
在这里插入图片描述

注意:

  • 交叉熵损失中,神经网络的最后一层不做激活,激活函数包括在交叉熵中了
  • y需要使用长整型的张量 y = torch.LongTensor([0])
  • 直接调用函数使用,criterion = torch.nn.CrossEntropyLoss()

代码实现

import torch
from torchvision import transforms  # 对图像进行处理
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F  # 非线性函数的包
import torch.optim as optim  # 优化

处理数据集

要把图像数据转化为Tensor形式,因此要定义transform对象

########### 数据集处理 ##########
batch_size = 64
## 将图像转化为张量
transform = transforms.Compose([transforms.ToTensor(),  # 使用ToTensor()方法将图像转化为张量transforms.Normalize((0.1307, ), (0.3081, ))  # 归一化 均值-标准差 切换到01分布
])
tran_dataset = datasets.MNIST(root='./dataset/mnist/',train=True,download=True,transform=transform)  # 将transform放到数据集里 直接处理train_loader = DataLoader(tran_dataset,shuffle=True,batch_size=batch_size)test_dataset = datasets.MNIST(root='./dataset/mnist/',train=False,download=True,transform=transform)test_loader = DataLoader(test_dataset,shuffle=True,batch_size=batch_size)

模型定义

由于线性模型的输入必须是一个向量,而图像为 28 × 28 × 1 28×28×1 28×28×1的矩阵,因此要先把图像拉成成一个 1 × 784 1×784 1×784的向量

这种做法会导致图像中的一些局部信息丢失

########## 模型定义 ##########
## 1×28×28 --view-> 1×784
## 784 --> 512 --> 256 --> 128 --> 128 --> 64 -->10
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.l1 = torch.nn.Linear(784, 512)self.l2 = torch.nn.Linear(512, 256)self.l3 = torch.nn.Linear(256, 128)self.l4 = torch.nn.Linear(128, 64)self.l5 = torch.nn.Linear(64, 10)def forward(self, x):x = x.view(-1, 784)  # 图像矩阵拉伸成向量x = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))x = self.l5(x)  # 最后一层不需要激活 sigmoid包含在cross entropy里return xmodel = Net()

注意最后一层不需要添加softmax,因为损失函数使用的cross entorpy已经包含softmax函数

损失函数和优化器设置

由于模型比较复杂,所以使用动量梯度下降方法,参数设置为0.5

########## 损失函数和优化器设置 ##########
criterion = torch.nn.CrossEntropyLoss()
## momentum指带冲量的梯度下降
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.5)

动量梯度下降(Gradient descent with momentum)

动量梯度下降方法是对小批量mini-batch的一种优化,可以有效地减少收敛过程中摆动幅度的大小,提高效率

动量梯度下降的过程类似于一个带有质量的小球在函数曲线上向下滚落,当小求滚落在最低点后,由于惯性还会继续上升一段距离,然后再滚落回来,再经最低点上去…最终小球停留在最低点处。而且由于小球具有惯性这一特点,当函数曲面比较复杂陡峭时,它便可以越过这些而尽快达到最低点

动量梯度下降每次更新参数时,对各个mini-batch求得的梯度 ∇ W \nabla W W ∇ b \nabla b b使用指数加权平均得到 V ∇ w V_{\nabla w} Vw V ∇ b V_{\nabla b} Vb,即通过历史数据来更新当前参数,从而得到更新公式:
V ∇ W n + 1 = β V ∇ W n + ( 1 − β ) ∇ W n V_{\nabla W_{n+1}} = \beta V_{\nabla W_n} + (1 - \beta)\nabla W_n VWn+1=βVWn+(1β)Wn

V ∇ b n + 1 = β V ∇ b n + ( 1 − β ) ∇ b n V_{\nabla b_{n+1}} = \beta V_{\nabla b_n} + (1 - \beta)\nabla b_n Vbn+1=βVbn+(1β)bn

模型训练

将训练阶段封装成一个函数

########## 模型训练 ##########
def train(epoch):running_loss = 0.0for batch_idx, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()  # 优化器清零## forwardoutputs = model(inputs)loss = criterion(outputs, labels)## backwardloss.backward()## updataoptimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:  # 每300轮 输出一次损失print('[%d %5d] loss: %.3f' % (epoch + 1, batch_idx+1, running_loss / 300))running_loss = 0.0

模型测试

将模型的测试封装成一个函数,且每训练一轮,测试一次

def test():correct = 0total = 0with torch.no_grad():  # 以下代码不会计算梯度for data in test_loader:images, labels = dataoutputs = model(images)  # 得到十个概率## 返回 最大值和最大值下标 我们只需要最大值下标即可_, predicted = torch.max(outputs.data, dim=1)  # 取最大值total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on test set: %d %%' % (100 * correct / total))

主函数

########### main ##########
if __name__ == '__main__':accuracy_history = []epoch_history = []for epoch in range(50):train(epoch)accuracy = test()accuracy_history.append(accuracy)epoch_history.append(epoch)plt.plot(epoch_history, accuracy_history)plt.xlabel('epoch')plt.ylabel('accuracy')plt.show()

完整代码

import torch
import matplotlib.pyplot as plt
from torchvision import transforms  # 对图像进行处理
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F  # 非线性函数的包
import torch.optim as optim  # 优化########### 数据集处理 ##########
batch_size = 64
## 将图像转化为张量
transform = transforms.Compose([transforms.ToTensor(),  # 使用ToTensor()方法将图像转化为张量transforms.Normalize((0.1307, ), (0.3081, ))  # 归一化 均值-标准差 切换到01分布
])
tran_dataset = datasets.MNIST(root='./dataset/mnist/',train=True,download=False,transform=transform)  # 将transform放到数据集里 直接处理train_loader = DataLoader(tran_dataset,shuffle=True,batch_size=batch_size)test_dataset = datasets.MNIST(root='./dataset/mnist/',train=False,download=False,transform=transform)test_loader = DataLoader(test_dataset,shuffle=True,batch_size=batch_size)########## 模型定义 ##########
## 1×28×28 --view-> 1×784
## 784 --> 512 --> 256 --> 128 --> 128 --> 64 -->10
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.l1 = torch.nn.Linear(784, 512)self.l2 = torch.nn.Linear(512, 256)self.l3 = torch.nn.Linear(256, 128)self.l4 = torch.nn.Linear(128, 64)self.l5 = torch.nn.Linear(64, 10)def forward(self, x):x = x.view(-1, 784)  # 图像矩阵拉伸成向量x = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))x = self.l5(x)  # 最后一层不需要激活 sigmoid包含在cross entropy里return xmodel = Net()########## 损失函数和优化器设置 ##########
criterion = torch.nn.CrossEntropyLoss()
## momentum指带冲量的梯度下降
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.5)########## 模型训练 ##########
def train(epoch):running_loss = 0.0for batch_idx, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()  # 优化器清零## forwardoutputs = model(inputs)loss = criterion(outputs, labels)## backwardloss.backward()## updataoptimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:  # 每300轮 输出一次损失print('[%d %5d] loss: %.3f' % (epoch + 1, batch_idx+1, running_loss / 300))running_loss = 0.0
########## 模型测试 ##########
def test():correct = 0total = 0with torch.no_grad():  # 以下代码不会计算梯度for data in test_loader:images, labels = dataoutputs = model(images)  # 得到十个概率## 返回 最大值和最大值下标 我们只需要最大值下标即可_, predicted = torch.max(outputs.data, dim=1)  # 取最大值total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on test set: %d %%' % (100 * correct / total))return 100*correct / total########### main ##########
if __name__ == '__main__':accuracy_history = []epoch_history = []for epoch in range(50):train(epoch)accuracy = test()accuracy_history.append(accuracy)epoch_history.append(epoch)plt.plot(epoch_history, accuracy_history)plt.xlabel('epoch')plt.ylabel('accuracy(%)')plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/876491.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

stm32h7串口发送寄存器空中断

关于stm32串口的发送完成中断UART_IT_TC网上资料挺多的,但是使用发送寄存器空中断UART_IT_TXE的不太多 UART_IT_TC 和 UART_IT_TXE区别 UART_IT_TC 和 UART_IT_TXE 是两种不同的 UART 中断源,用于表示不同的发送状态。它们的主要区别如下: …

raise JSONDecodeError(“Expecting value”, s, err.value) from None

raise JSONDecodeError(“Expecting value”, s, err.value) from None 目录 raise JSONDecodeError(“Expecting value”, s, err.value) from None 【常见模块错误】 【解决方案】 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是…

数字图像处理笔记(三) ---- 傅里叶变换的基本原理

系列文章目录 数字图像处理笔记(一)---- 图像数字化与显示 数字图像处理笔记(二)---- 像素加图像统计特征 数字图像处理笔记(三) ---- 傅里叶变换的基本原理 文章目录 系列文章目录前言一、傅里叶变换二、离散傅里叶变…

ChatTTS(文本转语音) 一键本地安装爆火语音模型

想不想让你喜欢的文章,有着一个动听的配音,没错,他就可以实现。 ChatTTS 是一款专为对话场景设计的文本转语音模型,例如 LLM 助手对话任务。它支持英语和中文两种语言。 当下爆火模型,在Git收获23.5k的Star&#xff…

【Pod 详解】Pod 的概念、使用方法、容器类型

《Pod 详解》系列,共包含以下几篇文章: Pod 的概念、使用方法、容器类型Pod 的生命周期(一):Pod 阶段与状况、容器的状态与重启策略Pod 的生命周期(二):Pod 的健康检查之容器探针Po…

C++入门基础:C++中的常用操作符练习

开头介绍下C语言先,C是一种广泛使用的计算机程序设计语言,起源于20世纪80年代,由比雅尼斯特劳斯特鲁普在贝尔实验室开发。它是C语言的扩展,增加了面向对象编程的特性。C的应用场景广泛,包括系统软件、游戏开发、嵌入式…

智慧医院临床检验管理系统源码(LIS),全套LIS系统源码交付,商业源码,自主版权,支持二次开发

实验室信息系统是集申请、采样、核收、计费、检验、审核、发布、质控、查询、耗材控制等检验科工作为一体的网络管理系统。它的开发和应用将加快检验科管理的统一化、网络化、标准化的进程。一体化设计,与其他系统无缝连接,全程化条码管理。支持危机值管…

DataX(二):DataX安装与入门

1. 官方地址 下载地址:http://datax-opensource.oss-cn-hangzhou.aliyuncs.com/datax.tar.gz 源码地址:GitHub - alibaba/DataX: DataX是阿里云DataWorks数据集成的开源版本。 2. 前置要求 Linux JDK(1.8 以上,推荐 1.8) Python(推荐 Pyt…

一文总结代理:代理模式、代理服务器

概述 代理在计算机编程领域,是一个很通用的概念,包括:代理设计模式,代理服务器等。 代理类持有具体实现类的实例,将在代理类上的操作转化为实例上方法的调用。为某个对象提供一个代理,以控制对这个对象的…

测试分类篇

按测试对象划分 这里可以分为界面测试, 可靠性测试, 容错率测试, 文档测试, 兼容性测试, 安装卸载测试, 安全测试, 性能测试, 内存泄露测试. 界面测试 界面测试(简称UI测试),指按照界面的需求(一般是UI设计稿)和界面的设计规则…

DOS攻击实验

实验背景 Dos 攻击是指故意的攻击网络协议实现的缺陷或直接通过野蛮手段,残忍地耗尽被攻击对象的资源,目的是让目标计算机或网络无法提供正常的服务或资源访问,使目标系统服务系统停止响应甚至崩溃。 实验设备 一个网络 net:cloud0 一台模…

基于微信小程序+SpringBoot+Vue的儿童预防接种预约系统(带1w+文档)

基于微信小程序SpringBootVue的儿童预防接种预约系统(带1w文档) 基于微信小程序SpringBootVue的儿童预防接种预约系统(带1w文档) 开发合适的儿童预防接种预约微信小程序,可以方便管理人员对儿童预防接种预约微信小程序的管理,提高信息管理工作效率及查询…

24暑假算法刷题 | Day22 | LeetCode 77. 组合,216. 组合总和 III,17. 电话号码的字母组合

目录 77. 组合题目描述题解 216. 组合总和 III题目描述题解 17. 电话号码的字母组合题目描述题解 77. 组合 点此跳转题目链接 题目描述 给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。 你可以按 任何顺序 返回答案。 示例 1: 输…

移动UI:排行榜单页面如何设计,从这五点入手,附示例。

移动UI的排行榜单页面设计需要考虑以下几个方面: 1. 页面布局: 排行榜单页面的布局应该清晰明了,可以采用列表的形式展示排行榜内容,同时考虑到移动设备的屏幕大小,应该设计合理的滚动和分页机制,确保用户…

贪心算法.

哈夫曼树 哈夫曼树(Huffman Tree),又称为霍夫曼树或最优二叉树,是一种带权路径长度最短的二叉树,常用于数据压缩。 定义:给定N个权值作为N个叶子结点,构造一棵二叉树,若该树…

普乐蛙VR航天航空体验馆知识走廊VR体验带你登陆月球

VR航天航空设备是近年来随着虚拟现实(VR)技术的快速发展而兴起的一种新型设备,它结合了航天航空领域的专业知识与VR技术的沉浸式体验,为用户提供了前所未有的航天航空体验。以下是对VR航天航空设备的详细介绍: 一、设备…

UGUI优化篇--UGUI合批

UGUI合批 UGUI合批规则概述UGUI性能查看工具合批部分的特殊例子一个白色image、蓝色image覆盖了Text,白色image和Text哪个先渲染 Mask合批Mask为什么会产生两个drawcallMask为什么不能合批Mask注意要点 RectMask2D为什么RecMask2D比Mask性能更好主要代码RectMask2D注…

Golang | Leetcode Golang题解之第295题数据流的中位数

题目: 题解: type MedianFinder struct {nums *redblacktree.Treetotal intleft, right iterator }func Constructor() MedianFinder {return MedianFinder{nums: redblacktree.NewWithIntComparator()} }func (mf *MedianFinder) AddNum(…

MySQL中多表查询之外连接

首先先来介绍一下我做的两个表,然后再用他们两个举例说明。 -- 创建教师表 create table teachers( id_t int primary key auto_increment, -- 老师编号 name_t varchar(5) -- 姓名 ); -- 创建学生表 create table students( id_s int primary key auto_increment,…

数据结构——单链表OJ题(下)

目录 一、链表的回文结构 思路一:数组法 (1)注意 (2)解题 思路二:反转链表法 (1) 注意 (2)解题 二、相交链表 (1)思路&#…