18.多分类问题代码实现

在机器学习中,多分类问题是一类常见的问题,它涉及到将输入数据划分为多个类别中的一个。例如,在图像识别中,我们可能需要将图像分为不同的类别,如手写数字识别(MNIST数据集)就是将手写数字图像分类为0-9的十个数字。本文将介绍如何使用PyTorch框架来构建一个简单的神经网络模型来解决多分类问题,并以MNIST数据集为例进行说明。

数据集

MNIST是一个包含手写数字图像的大型数据集,由NIST(美国国家标准与技术研究院)发起整理,包含了60,000个训练样本和10,000个测试样本。每个样本都是一张28x28像素的灰度图像,表示一个0-9之间的手写数字。

构建神经网络模型

首先,我们需要导入必要的库,并定义神经网络模型。这里我们将使用一个简单的全连接神经网络,包含两个隐藏层和一个输出层。

import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms # 定义神经网络模型 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(28 * 28, 500) self.fc2 = nn.Linear(500, 100) self.fc3 = nn.Linear(100, 10) def forward(self, x): x = x.view(-1, 28 * 28) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return torch.log_softmax(x, dim=1) # 实例化模型 model = Net()

数据加载和预处理

接下来,我们需要加载MNIST数据集,并进行必要的预处理。这里我们使用torchvision.datasets.MNIST来加载数据集,并使用torch.utils.data.DataLoader来加载数据。

# 数据预处理:转换为Tensor并归一化  
transform = transforms.Compose([  transforms.ToTensor(),  transforms.Normalize((0.5,), (0.5,))  
])  # 加载训练集和测试集  
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)  
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)  testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)  
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

定义损失函数和优化器

对于多分类问题,我们通常使用交叉熵损失函数(CrossEntropyLoss)。在PyTorch中,nn.CrossEntropyLoss结合了LogSoftmax和NLLLoss,所以我们不需要在模型输出时显式使用LogSoftmax。

对于优化器,我们选择随机梯度下降(SGD)。

# 定义损失函数和优化器  
criterion = nn.CrossEntropyLoss()  
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

训练模型

现在我们可以开始训练模型了。在每个训练周期(epoch)中,我们将遍历整个训练集,计算损失,反向传播梯度,并更新模型参数。

# 训练模型  
num_epochs = 10  
for epoch in range(num_epochs):  for i, (images, labels) in enumerate(trainloader, 0):  # 清零梯度缓存  optimizer.zero_grad()  # 前向传播  outputs = model(images)  loss = criterion(outputs, labels)  # 反向传播和优化  loss.backward()  optimizer.step()  if (i+1) % 1000 == 0:  print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(trainloader)}], Loss: {loss.item()}')  print('Finished Training')

评估模型

训练完成后,我们可以使用测试集来评估模型的性能。这里我们计算了模型在测试集上的准确率。

# 评估模型  
correct = 0  
total = 0  
with torch.no_grad():  # 不需要计算梯度,节省内存和计算资源  for images, labels in testloader:  outputs = model(images)  _, predicted = torch.max(outputs.data, 1)  # 获取预测结果中概率最大的类别索引  total += labels.size(0)  # 总样本数  correct += (predicted == labels).sum().item()  # 正确预测的样本数  print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

总结

本文介绍了如何使用PyTorch框架来构建和训练一个用于多分类问题的神经网络模型。我们以MNIST手写数字数据集为例,展示了从数据加载和预处理、模型定义、损失函数和优化器选择,到模型训练和评估的整个流程。

在实际应用中,我们可以根据具体的问题和数据集来调整模型的结构和参数,以获得更好的性能。此外,还可以使用更高级的技术和策略来优化模型的训练和评估过程,例如数据增强、正则化、学习率调整等。

通过本文的介绍,读者应该能够掌握使用PyTorch进行多分类问题建模的基本流程和关键技术,为后续的深度学习项目打下坚实的基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/16580.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

反对加征关税,特斯拉上海厂传减产20% | 百能云芯

特斯拉公司首席执行官马斯克近日在公开场合表达了对美国计划对中国电动车加征关税的反对立场,强调特斯拉不支持任何扭曲市场的举措。据知情人士透露,特斯拉上海工厂正计划在今年第二季度至少削减Model Y车型的产量20%,以应对市场需求的变化。…

Linux配置PyTorch GPU环境

本文是基于系统已经安装了驱动和CUDA的,假如不会安装驱动和CUDA的,可以参考我写的上一篇文章:https://blog.csdn.net/pdc31czy/article/details/136072017?spm1001.2014.3001.5501 并且本文是基于HPC写的笔记,普通电脑跳过步骤1…

C#读取.sql文件并执行文件中的sql脚本

有些时候我们需要在程序中编写读取sql脚本文件并执行这些sql语句,但是我们在有些时候会遇到读出来的sql语句不能执行,其实不能执行并不是你的sql脚本文件有错误,而是去执行sql语句的时候,C#代码里面执行sql语句的代码对sql里面的一…

低代码与人工智能:改变软件开发的未来

引言 在当今快速发展的科技时代,软件开发行业也在不断地创新和演进。其中,低代码开发和人工智能技术是两个备受关注的领域,低代码开发通过简化开发流程和降低编码难度,使得软件开发变得更加高效和便捷,而人工智能技术…

正宇软件:引领数字人大新纪元,开启甘肃人大代表履职新篇章

在数字化强国的主旋律之下,政府工作的数字化、智能化转型已成为提升治理效能、增强人民满意度的关键一环。在这个大背景下,正宇软件技术开发有限公司以其卓越的技术实力和丰富的行业经验,成为了政府信息化建设的杰出代表。甘肃省人大代表履职…

基于 Wireshark 分析 TCP 协议

一、TCP 协议 TCP(Transmission Control Protocol)是一种面向连接的、可靠的传输层协议。它在网络通信中扮演着重要的角色,用于保证数据的可靠传输。 TCP协议的特点如下: 1. 面向连接:在通信前需要先建立连接&#x…

Hunyuan-DiT环境搭建推理测试

引子 最近鹅厂竟然开源了一个多模态的大模型,之前分享福报厂的多模态视觉大模型(Qwen-VL环境搭建&推理测试-CSDN博客)感兴趣的可以移步。鹅厂开源的,我还是头一回部署。好的,那就让我们看看这个多模态视觉大模型有…

强化学习,第 3 部分:蒙特卡罗方法

文章目录 一、介绍二、关于此文章三、无模型方法与基于模型的方法四、V函数估计4.1 基本概念4.2 V-功能 五、Q 函数估计5.1 V函数概念5.2 优势5.3 Q函数 六、勘探与勘探的权衡七、结论 一、介绍 从赌场到人工智能:揭示蒙特卡罗方法在复杂环境中的强大功能    强化…

企微运营SOP:构建高效、规范的运营流程

随着企业微信在企业内部沟通协作中的广泛应用,如何构建一套高效、规范的企微运营流程成为了众多企业关注的焦点。本文将详细探讨企微运营SOP(Standard Operating Procedure,标准操作程序)的重要性、构建方法以及实施效果&#xff…

zstd库数据压缩与解压缩

在 Visual Studio 2019 中使用 C 的 zstd 库进行数据压缩与解压缩 在今天的博客中,我们将探讨如何在 Visual Studio 2019 中使用 zstd 库进行高效的数据压缩和解压缩。zstd(也称为 Zstandard 或 zstd)是由 Facebook 开发的开源压缩库&#x…

动手学深度学习22 池化层

动手学深度学习22 池化层 1. 池化层2. 实现3. QA 课本: https://zh-v2.d2l.ai/chapter_convolutional-neural-networks/pooling.html 视频: https://www.bilibili.com/video/BV1EV411j7nX/?spm_id_fromautoNext&vd_sourceeb04c9a33e87ceba9c9a2e5f0…

CTF_RE周报(五)

这周感觉题目都开始上难度了,很多题都需要很多的基础知识,也是练到哪学到那,所以刷题的速度还是降了一点 angr符号化执行 上上周就已经遇到了,这周一个buu题也是可以用,就开始学学了,目前还差一半 [WUST…

算法刷题笔记 高精度加法(C++实现)

文章目录 题目描述题目思路和代码 题目描述 给定两个正整数(不含前导0),计算它们的和。 输入格式 共两行,每行包含一个整数。 输出格式 共一行,包含所求的和。 题目思路和代码 基本思路:模拟竖式计算…

关于单元测试

关于单元测试的一些总结:

【408真题】2009-17

“接”是针对题目进行必要的分析,比较简略; “化”是对题目中所涉及到的知识点进行详细解释; “发”是对此题型的解题套路总结,并结合历年真题或者典型例题进行运用。 涉及到的知识全部来源于王道各科教材(2025版&…

618值得买的东西有哪些?买什么最划算?超全品类大清单总结

平日里让许多人心动不已的收藏加购好物,是否常常因为价格昂贵而让人望而却步?然而,618活动期间的到来,恰恰为我们提供了一个难得的购物盛宴!相信在第一波活动中,许多消费者已经跃跃欲试,开始享受…

SuperSocket 自定义AppServer、AppSession、CommandBase

1、预期效果如下图。 2、自定义AppServer,代码如下。 using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; using SuperSocket.SocketBase; using SuperSocket.SocketBase.Config;namespace Co…

做抖音电商,可以没有货源和经验,但不能没有耐心

我是王路飞。 在抖音做电商这件事,不需要怀疑其可行性。 经过四五年的发展,平台和商家已经证明了抖音电商的前景,它就是我们普通人做抖音最适合的一个渠道。 想在抖音做电商的,再给你们一个经验之谈,你可以没有货源…

linux 查看磁盘使用情况

在Linux系统中,你可以使用以下命令来查看磁盘的情况: 1.df命令:用于显示文件系统的磁盘空间使用情况。 df -h该命令会以人类可读的方式显示文件系统的磁盘空间使用情况,包括文件系统、已用空间、可用空间、已用百分比、挂载点等…

hudi0.13版本clean策略

hudi0.13版本clean策略 在 Apache Hudi 0.13 版本中,清理策略对于数据管理和存储优化起着关键作用。为了确保数据湖的有效利用和性能优化,了解和正确配置清理策略至关重要。以下是 Hudi 0.13 版本的清理策略详细说明及注意事项。 清理策略概述 Hudi 提…