使用PyTorch解决多分类问题:构建、训练和评估深度学习模型

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

  • 🍋引言
  • 🍋什么是多分类问题?
  • 🍋处理步骤
  • 🍋多分类问题
  • 🍋MNIST dataset的实现
  • 🍋NLLLoss 和 CrossEntropyLoss

🍋引言

当处理多分类问题时,PyTorch是一种非常有用的深度学习框架。在这篇博客中,我们将讨论如何使用PyTorch来解决多分类问题。我们将介绍多分类问题的基本概念,构建一个简单的多分类神经网络模型,并演示如何准备数据、训练模型和评估结果。

🍋什么是多分类问题?

多分类问题是一种机器学习任务,其中目标是将输入数据分为多个不同的类别或标签。与二分类问题不同,多分类问题涉及到三个或更多类别的分类任务。例如,图像分类问题可以将图像分为不同的类别,如猫、狗、鸟等。

🍋处理步骤

  • 准备数据
    收集和准备数据集,确保每个样本都有相应的标签,以指明其所属类别。
    划分数据集为训练集、验证集和测试集,以便进行模型训练、调优和性能评估。

  • 数据预处理
    对数据进行预处理,例如归一化、标准化、缺失值处理或数据增强,以确保模型训练的稳定性和性能。

  • 选择模型架构
    选择适当的深度学习模型架构,通常包括卷积神经网络(CNN)、循环神经网络(RNN)、Transformer等,具体取决于问题的性质。

  • 定义损失函数
    为多分类问题选择适当的损失函数,通常是交叉熵损失(Cross-Entropy Loss)。

  • 选择优化器
    选择合适的优化算法,如随机梯度下降(SGD)、Adam、RMSprop等,以训练模型并调整权重。

  • 训练模型
    使用训练数据集来训练模型。在每个训练迭代中,通过前向传播和反向传播来更新模型参数,以减小损失函数的值。

  • 评估模型
    使用验证集来评估模型性能。常见的性能指标包括准确性、精确度、召回率、F1分数等。

  • 调优模型
    根据验证集的性能,对模型进行调优,可以尝试不同的超参数设置、模型架构变化或数据增强策略。

  • 测试模型
    最终,在独立的测试数据集上评估模型的性能,以获得最终性能评估。

  • 部署模型
    将训练好的模型部署到实际应用中,用于实时或批处理多分类任务。

🍋多分类问题

之前我们讨论的问题都是二分类居多,对于二分类问题,我们若求得p(0),南无p(1)=1-p(0),还是比较容易的,但是本节我们将引入多分类,那么我们所求得就转化为p(i)(i=1,2,3,4…),同时我们需要满足以上概率中每一个都大于0;且总和为1

处理多分类问题,这里我们新引入了一个称为Softmax Layer
在这里插入图片描述

接下来我们一起讨论一下Softmax Layer层
在这里插入图片描述
首先我们计算指数计算e的zi次幂,原因很简单e的指数函数恒大于0;分母就是e的z1次幂+e的z2次幂+e的z3次幂…求和,这样所有的概率和就为1了。


下图形象的展示了Softmax,Exponent这里指指数,和上面我们说的一样,先求指数,这样有了分子,再将所有指数求和,最后一一divide,得到了每一个概率。
在这里插入图片描述


接下来我们一起来看看损失函数
在这里插入图片描述
如果使用numpy进行实现,根据刘二大人的代码,可以进行如下的实现

import numpy as np
y = np.array([1,0,0])
z = np.array([0.2,0.1,-0.1])
y_pred = np.exp(z)/np.exp(z).sum()
loss = (-y * np.log(y_pred)).sum()
print(loss)

运行结果如下
在这里插入图片描述
注意:神经网络的最后一层不需要激活


在pytorch中

import torch
y = torch.LongTensor([0])  # 长整型
z = torch.Tensor([[0.2, 0.1, -0.1]])
criterion = torch.nn.CrossEntropyLoss() 
loss = criterion(z, y)
print(loss)

运行结果如下

在这里插入图片描述
下面根据一个例子进行演示

criterion = torch.nn.CrossEntropyLoss()
Y = torch.LongTensor([2,0,1])
Y_pred1 = torch.Tensor([[0.1, 0.2, 0.9], [1.1, 0.1, 0.2], [0.2, 2.1, 0.1]]) 
Y_pred2 = torch.Tensor([[0.8, 0.2, 0.3], [0.2, 0.3, 0.5], [0.2, 0.2, 0.5]])
l1 = criterion(Y_pred1, Y)
l2 = criterion(Y_pred2, Y)
print("Batch Loss1 = ", l1.data, "\nBatch Loss2=", l2.data)

运行结果如下
在这里插入图片描述

根据上面的代码可以看出第一个损失比第二个损失要小。原因很简单,想对于Y_pred1每一个预测的分类与Y是一致的,而Y_pred2则相差了一下,所以损失自然就大了些

🍋MNIST dataset的实现

首先第一步还是导包

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader 
import torch.nn.functional as F 
import torch.optim as optim

之后是数据的准备

batch_size = 64
# transform可以将其转化为0-1,形状的转换从28×28转换为,1×28×28
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307, ), (0.3081, ))   # 均值mean和标准差std
])
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True,download=True,transform=transform)  
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False,download=True,transform=transform)
test_loader = DataLoader(test_dataset,shuffle=False,batch_size=batch_size)

在这里插入图片描述
接下来我们构建网络

class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.l1 = torch.nn.Linear(784, 512) self.l2 = torch.nn.Linear(512, 256) self.l3 = torch.nn.Linear(256, 128) self.l4 = torch.nn.Linear(128, 64) self.l5 = torch.nn.Linear(64, 10)def forward(self, x):x = x.view(-1, 784)x = F.relu(self.l1(x)) x = F.relu(self.l2(x)) x = F.relu(self.l3(x)) x = F.relu(self.l4(x)) return self.l5(x)  # 注意最后一层不做激活
model = Net()

之后定义损失和优化器

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

接下来就进行训练了

def train(epoch):running_loss = 0.0for batch_idx, data in enumerate(train_loader, 0): inputs, target = dataoptimizer.zero_grad()# forward + backward + updateoutputs = model(inputs)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300)) running_loss = 0.0
def test():correct = 0total = 0with torch.no_grad(): # 这里可以防止内嵌代码不会执行梯度for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on test set: %d %%' % (100 * correct / total))

最后调用执行

if __name__ == '__main__': for epoch in range(10): train(epoch)test()

🍋NLLLoss 和 CrossEntropyLoss

NLLLoss 和 CrossEntropyLoss(也称为交叉熵损失)是深度学习中常用的两种损失函数,用于测量模型的输出与真实标签之间的差距,通常用于分类任务。它们有一些相似之处,但也有一些不同之处。

相同点:

用途:两者都用于分类任务,评估模型的输出和真实标签之间的差异,以便进行模型的训练和优化。
数学基础:NLLLoss 和 CrossEntropyLoss 本质上都是交叉熵损失的不同变种,它们都以信息论的概念为基础,衡量两个概率分布之间的相似度。
输入格式:它们通常期望模型的输出是一个概率分布,表示各个类别的预测概率,以及真实的标签。

不同点:

输入格式:NLLLoss 通常期望输入是对数概率(log probabilities),而 CrossEntropyLoss 通常期望输入是未经对数化的概率。在实际应用中,CrossEntropyLoss 通常与softmax操作结合使用,将原始模型输出转化为概率分布,而NLLLoss可以直接使用对数概率。
对数化:NLLLoss 要求将模型输出的概率经过对数化(取对数)以获得对数概率,然后与真实标签的离散概率分布进行比较。CrossEntropyLoss 通常在 softmax 操作之后直接使用未对数化的概率值与真实标签比较。
输出维度:NLLLoss 更通用,可以用于多种情况,包括多类别分类和序列生成等任务,因此需要更多的灵活性。CrossEntropyLoss 通常用于多类别分类任务。

总之,NLLLoss 和 CrossEntropyLoss 都用于分类任务,但它们在输入格式和使用上存在一些差异。通常,选择哪个损失函数取决于你的模型输出的格式以及任务的性质。如果你的模型输出已经是对数概率形式,通常使用NLLLoss,否则通常使用CrossEntropyLoss

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/107291.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue配置全局变量config.js

Vue配置全局变量config.js 若config.js在public目录下 在index.html中引入 这样配置是为了防止路由前缀&#xff0c;如果直接“/config.js”&#xff0c;若路由没有前缀还好&#xff0c;要是有就需要配置为“<% BASE_URL %>config.js”

Xilinx IP 10G Ethernet PCS/PMA IP Core

Vivado 10G Ethernet PCS/PMA介绍 1介绍 完整的10G以太网接口如下图,分为10G PHY和10G MAC两部分。 这篇文章重点讲 10G Ethernet PCS/PMA。 2 IP的基本介绍 10G以太网物理编码子层/物理介质连接(PCS/PMA)核心在Xilinx 10G以太网介质访问控制器(MAC)核心和具有10Gb/s…

Linux网络编程系列之UDP组播

Linux网络编程系列 &#xff08;够吃&#xff0c;管饱&#xff09; 1、Linux网络编程系列之网络编程基础 2、Linux网络编程系列之TCP协议编程 3、Linux网络编程系列之UDP协议编程 4、Linux网络编程系列之UDP广播 5、Linux网络编程系列之UDP组播 6、Linux网络编程系列之服务器编…

尚硅谷Flink(一)

目录 ☄️前置工作 fenfa脚本 &#x1f30b;概述 ☄️Flink是什么 ☄️特点&#xff08;多nb&#xff09; ☄️应用场景&#xff08;不用看&#xff09; ☄️分层API &#x1f30b;配环境 ☄️wordcount ☄️WcDemoUnboundStreaming &#x1f30b;集群部署 ☄️集…

xml的语法

<!-- 1、每一个xml,有且只有一个根标签&#xff0c;所有xml标签必须写在根标签中 2、标签必须要有合闭 3、xml格式是否正确&#xff0c;可以通过浏览器打开xml。来校验xml格式是否正确 4、xml是区别大小写的 5、xml书写标签名时&#xff0c;不要出现空格等特…

微信小程序的框架

目录 一、视图层 1. WXML 数据绑定 列表渲染 条件渲染 模板 2. WXSS 尺寸单位 样式导入 内联样式 选择器 3. WXS事件 二、逻辑层 1. 页面生命周期 2.跳转 1. 一级跳一级 2. 一级跳二级 3. 二级跳二级 4. 二级跳一级 总结 带给我们的收获 一、视图层 1. …

idea 启动项目报错 Command line is too long

1.idea 启动报错 Command line is too long&#xff0c;启动报错信息&#xff1a;Error running ‘Application‘: Command line is too long. 2.如何解决&#xff1f; 1&#xff09;idea打开一个项目。 2.打开项目下的*.idea* 文件夹下的 workspace.xml 文件。 3.在<co…

springboot 导出word模板

一、安装依赖 <dependency><groupId>com.deepoove</groupId><artifactId>poi-tl</artifactId><version>1.12.1</version></dependency>二、定义工具类 package com.example.springbootmp.utils;import com.deepoove.poi.XWP…

5.Python-使用XMLHttpRequest对象来发送Ajax请求

题记 使用XMLHttpRequest对象来发送Ajax请求&#xff0c;以下是一个简单的实例和操作过程。 安装flask模块 pip install flask 安装mysql.connector模块 pip install mysql-connector-python 编写app.py文件 app.py文件如下&#xff1a; from flask import Flask, reque…

好用的跨平台同步笔记工具,手机和电脑可同步的笔记工具

在这个快节奏的工作环境中&#xff0c;每个人都在寻找一种方便又高效的方式来记录工作笔记。记录工作笔记可以帮助大家统计工作进展&#xff0c;了解工作进程&#xff0c;而如果工作中常在一个地方办公&#xff0c;直接选择电脑或者手机中笔记工具来记录即可&#xff0c;但是对…

[部署网站]01安装宝塔面板搭建WordPress

宝塔面板安装WordPress&#xff08;超详细&#xff09;_Wordpress主题网 参考教程 宝塔面板 - 简单好用的Linux/Windows服务器运维管理面板 官网 1.首先你需要一个服务器或者主机 &#xff08;Windows系统或者Linux系统都可以&#xff09; 推荐Linux系统更稳定&#xff0c;…

Axure RP医疗在线挂号问诊原型图医院APP原形模板

医疗在线挂号问诊Axure RP原型图医院APP原形模板&#xff0c;是一款原创的医疗类APP&#xff0c;设计尺寸采用iPhone13&#xff08;375*812px&#xff09;&#xff0c;原型图上加入了仿真手机壳&#xff0c;使得预览效果更加逼真。 本套原型图主要功能有医疗常识科普、医院挂号…

海思平台SS528V100编译 linux kernel tun.c eth_get_headlen 编译 出错的问题

osdrv目录下 make kernel 会去opensource目录下解压linux内核压缩包 同时打上很多补丁 如上图 查看Makefile 如下 有打补丁的命令 然后 然后我们的应用程序用到一个特性 需要打开tun/tab这两个属性 打开之后编译内核出错 查到最后发现 没打补丁之前的文件 没问题 …

Qt拖拽文件到窗口、快捷方式打开

大部分客户端都支持拖拽文件的功能&#xff0c;本篇博客介绍Qt如何实现文件拖拽到窗口、快捷方式打开&#xff0c;以我的开源视频播放器项目为例&#xff0c;介绍拖拽视频到播放器窗口打开。   需要注意的是&#xff0c;Qt拖拽文件的功能&#xff0c;不支持以管理员权限启动的…

maven 新建模块 导入后 按Ctrl 点不进新建模块pom定义

新建的ruoyi-common-mybatisplus 模块,导入一直不正常 画出的模块一直导入不进来 这是提示信息 这是正常的提示信息 加上 <version>3.6.3</version> 后,才一切正常

攻防世界题目练习——Web引导模式(三)(持续更新)

题目目录 1. mfw2.3.4.5. 1. mfw 进去看到网页和页面内容如下&#xff1a; 看到url的参数 ?pageabout &#xff0c;我以为是文件包含什么的&#xff0c;反复试了几次&#xff0c;想用 …/…/…/…/etc/passwd &#xff0c;但是发现.似乎被过滤了&#xff0c;实在不知道怎么做…

uniapp(uncloud) 使用生态开发接口详情2(使用 schema创建数据, schema2code创建页面, iconfont 引入项目)

上一篇介绍如何创建项目,接下来该是如何使用 在项目中pages 目录下,新建界面 项目运行,浏览器中用账号密码登录, 新建一级和二级页面 2.1 系统管理 > 菜单管理 (新增一级界面) 2.2 找到刚刚创建的菜单, 操作行有 子菜单(点击) 用DB Schema创建页面, 3.1 在uniCloud > d…

开路、断路和短路区别

文章目录 开路和断路击穿电源短路、用电器短路、对地短路和对电源短路 开路和断路 开路和断路是电路中两种用于描述电流流动情况的状态。 两者易混淆&#xff0c;常被混淆使用&#xff0c;但是它们还是有所不同。 开路表示电路中存在一个断链&#xff0c;电流无法从一个点流到…

git强制删除本地分支 git branch -D

git强制删除本地分支 git branch -D git删除本地分支_zhangphil的博客-CSDN博客git branch -d <分支名>可以通过: git branch 查看所有本地分支及其名字&#xff0c;然后删除特定分支。https://blog.csdn.net/zhangphil/article/details/82255002 使用git branch -d删除…

退税政策线上VR互动科普展厅为税收工作带来了强大活力

缴税纳税是每个公民应尽的义务和责任&#xff0c;由于很多人缺乏专业的缴税纳税操作专业知识和经验&#xff0c;因此为了提高大家的缴税纳税办事效率和好感度&#xff0c;越来越多地区税务局开始引进VR虚拟现实、web3d开发和多媒体等技术手段&#xff0c;基于线上为广大公民提供…