PyTorch学习8:多分类问题

文章目录

  • 前言
  • 一、说明
  • 二、示例
    • 1.步骤
    • 2.示例代码
  • 总结


前言

介绍如何利用PyTorch中Softmax 分类器实现多分类问题。

一、说明

1.多分类问题的输出是一个分布,满足和为1.
2.Softmax 分类器
在这里插入图片描述
3.损失函数:交叉熵损失
torch.nn.CrossEntropyLoss()
在这里插入图片描述

二、示例

1.步骤

1.建立模型
2.定义训练函数
3.定义测试函数
4.主函数:定义训练集和测试集,定义损失函数和优化器,进行训练,存储结果,绘图

2.示例代码

代码如下(示例):

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import pickle
# prepare dataset# batch_size = 64
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])  # 归一化,均值和方差
#
# train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
# train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
# test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
# test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)# design model using classclass Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.l1 = torch.nn.Linear(784, 512)self.l2 = torch.nn.Linear(512, 256)self.l3 = torch.nn.Linear(256, 128)self.l4 = torch.nn.Linear(128, 64)self.l5 = torch.nn.Linear(64, 10)def forward(self, x):x = x.view(-1, 784)  # -1其实就是自动获取mini_batchx = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))return self.l5(x)  # 最后一层不做激活,不进行非线性变换# model = Net()
#
# # construct loss and optimizer
# criterion = torch.nn.CrossEntropyLoss()
# optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)# training cycle forward, backward, updatedef train(epoch):running_loss = 0.0loss_s = 0.0for batch_idx, data in enumerate(train_loader, 0):# 获得一个批次的数据和标签inputs, target = dataoptimizer.zero_grad()# 获得模型预测结果(64, 10)outputs = model(inputs)# 交叉熵代价函数outputs(64,10),target(64)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()loss_s += loss.item()if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))running_loss = 0.0return loss_s / len(train_loader)def test():correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0个维度,行是第1个维度total += labels.size(0)correct += (predicted == labels).sum().item()  # 张量之间的比较运算print('accuracy on test set: %d %% ' % (100 * correct / total))return 100 * correct / totalif __name__ == '__main__':batch_size = 64transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])  # 归一化,均值和方差train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)model = Net()# construct loss and optimizercriterion = torch.nn.CrossEntropyLoss(reduction='mean')optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)epoch_list = []loss_list = []accuracy_list = []for epoch in range(10):epoch_list.append(epoch)loss_lis = train(epoch)loss_list.append(loss_lis)tes = test()accuracy_list.append(tes)with open('8/epoch_list.pkl', 'wb') as f:pickle.dump(epoch_list, f)with open('8/loss_list.pkl', 'wb') as f:pickle.dump(loss_list, f)with open('8/accuracy_list.pkl', 'wb') as f:pickle.dump(accuracy_list, f)

画图程序如下:

import pickle
import matplotlib.pyplot as pltwith open('8/epoch_list.pkl', 'rb') as f:loaded_epoch_list = pickle.load(f)
with open('8/loss_list.pkl', 'rb') as f:loaded_loss_list = pickle.load(f)
with open('8/accuracy_list.pkl', 'rb') as f:loaded_acc_list = pickle.load(f)plt.subplot(2, 1, 1)  # 创建子图,2行1列,第1个子图
plt.plot(loaded_epoch_list, loaded_loss_list)
plt.xlabel('epoch')
plt.ylabel('loss 1')plt.subplot(2, 1, 2)  # 创建子图,2行1列,第2个子图
plt.plot(loaded_epoch_list, loaded_acc_list,'r')
plt.xlabel('epoch')
plt.ylabel('acc 1')
plt.show()

得到如下结果:
在这里插入图片描述

总结

PyTorch学习8:多分类问题

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/25889.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

运维开发详解:DevOps 理念下的高效运维实践

目录 前言 1、 运维开发的核心概念 2、 运维开发的技术栈 3、运维开发的实践案例 4、 运维开发的挑战与机遇 5、 运维开发的未来发展趋势 6、运维开发概念 7、运维开发的角色 8、成为一名优秀的运维开发工程师 9、总结 前言 随着互联网业务的快速发展,传…

虚拟化 之一 详解 jailhouse 架构及原理、软硬件要求、源码文件、基本组件

Jailhouse 是一个基于 Linux 实现的针对创建工业级应用程序的小型 Hypervisor,是由西门子公司的 Jan Kiszka 于 2013 年开发的,并得到了官方 Linux 内核的支持,在开源社区中获得了知名度和吸引力。 Jailhouse Jailhouse 是一种轻量级的虚拟化…

微软如何打造数字零售力航母系列科普13 - Prime Focus Technologies在NAB 2024上推出CLEAR®对话人工智能联合试点

Prime Focus Technologies在NAB 2024上推出CLEAR对话人工智能联合试点 彻底改变您与内容的互动方式,从内容的创建到分发 洛杉矶,2024年4月9日/PRNewswire/-媒体和娱乐(M&E)行业人工智能技术解决方案的先驱Prime Focus Techn…

架构师如何评估团队成员的成熟度

评估团队成员的成熟度是一个涉及观察、沟通和反馈的过程。以下是一些方法和步骤,可以帮助你评估团队成员的成熟度,无论是在技术能力、还是职业发展方面: 设定评估标准:首先,明确你希望评估的成熟度方面,比…

人工智能在医学领域的应用及技术实现

欢迎来到 Papicatch的博客 目录 🍉引言 🍉 医学影像分析 🍈技术实现 🍍数据准备 🍍模型构建 🍍模型训练 🍍模型评估 🍍应用部署 🍈示例代码 🍉 基因…

操作系统真象还原:内存管理系统

第8章-内存管理系统 这是一个网站有所有小节的代码实现,同时也包含了Bochs等文件 8.1 Makefile简介 8.1.1 Makefile是什么 8.1.2 makefile基本语法 make 给咱们提供了方法,可以在命令之前加个字符’@’,这样就不会输出命令本身…

微信小程序使用 “云函数“ 获取 “openid“

文章目录 1.前期准备2.具体操作步骤 1.前期准备 必须使用云开发已经配置好云开发 2.具体操作步骤 1.进入小程序开发工具→在云函数目录上右键→选中新建云函数 创建结束,自动上传(必须确认已经上传才生效) 2.进入对应页面的js文件&#…

QT 信号和槽 信号关联到信号示例 信号除了可以绑定槽以外,信号还可以绑定信号

信号除了可以关联到槽函数,还可以关联到类型匹配的信号,实现信号的接力触发。上个示例中因为 clicked 信号没有参数,而 SendMsg 信号有参数,所以不方便直接关联。本小节示范一个信号到信号的关联,将按钮的 clicked 信号…

【优化过往代码】关于vue自定义事件的运用

【优化过往代码】关于vue自定义事件的运用 需求说明过往代码优化思路优化后代码(Vue2)遇到问题记录 Vue2官方自定义指令说明文档 Vue3官方自定义指令说明文档 需求说明 进入某些页面需要加载一些外部资源,并在资源加载完后进行一些处理&…

51单片机数码管显示的计数器,按键按下暂定,再次按下继续。(按键功能使用中断实现)

1、功能描述 数码管显示的计数器,按键按下暂定,再次按下继续。(按键功能使用中断实现) 2、实验原理 按键与中断:使用单片机的外部中断功能来检测按键动作,实现非阻塞的按键检测。 中断服务程序&…

十四、OpenAI之助手API(Asistants API)

助手API允许你在自己的应用系统中构建一个AI助手。助手有指令,能利用模型、工具和文件响应用户的查询。助手API目前支持3种类型的工具:代码交互,文件搜索和函数调用。 你可以使用助手后台探索助手的能力,或通过这个指南的大纲一步…

【栈】2751. 机器人碰撞

本文涉及知识点 栈 LeetCode2751. 机器人碰撞 现有 n 个机器人,编号从 1 开始,每个机器人包含在路线上的位置、健康度和移动方向。 给你下标从 0 开始的两个整数数组 positions、healths 和一个字符串 directions(directions[i] 为 ‘L’ …

MySQL-数据处理函数

026-distinct去重 select job from emp;加个 distinct 就行了 select distinct job from emp;注意:这个去重只是将显示的结果去重,原表数据不会被更改。 select 永远不会改变原数据 select distinct deptno, job from emp order by deptno asc;027-数…

步态控制之足旋转点(Foot Rotation Indicator, FRI)

足旋转点(Foot Rotation Indicator, FRI) 足旋转点是人形机器人步态规划中的一个关键概念,用于描述步态过程中机器人脚部的旋转和稳定性。FRI 可以帮助确定机器人在行走时是否稳定,以及如何调整步态以保持稳定。下面详细介绍FRI的原理,并举例说明其应用。 足旋转点(FRI…

R语言统计分析——图形的简单示例

参考资料:R语言实战【第2版】 1、示例一 # 绑定数据框mtcars attach(mtcars)# 打开一个图形窗口并生成一个散点图plot(wt,mpg)# 添加一条最优拟合曲线abline(lm(mpg~wt))# 添加标题title("Regression of MPG on weight") # 解除数据框绑定 detach(mtcar…

ES8.13 _bulk报错Malformed content, found extra data after parsing: START_OBJECT解决

在使用elaticsearch8.13.0使用批量创建索引时,根据谷粒中说的es7.9方法去批量操作请求: http://127.0.0.1:9200/shop/_doc/_bulk 注意1:设置header为Content-Type:application/x-ndjson,否则请求报错: {"error": &qu…

量化视频2---miniqmt的使用配置

量化视频2---miniqmt的使用配置 量化视频2---miniqmt的使用配置 (qq.com)

机器学习笔记:focal loss

1 介绍 Focal Loss 是一种在类别不平衡的情况下改善模型性能的损失函数最初在 2017 年的论文《Focal Loss for Dense Object Detection》中提出这种损失函数主要用于解决在有挑战性的对象检测任务中,易分类的负样本占据主导地位的问题,从而导致模型难以…

什么是async/await?

async/await 是 JavaScript 中处理异步操作的一种新方式,它使得异步代码能够以同步的方式书写,从而提高了代码的可读性和可维护性。 async async 是一个函数修饰符,用于声明一个函数是异步的。一个 async 函数总是返回一个 Promise 对象。如…

6.10 c语言

7.1 if-else语句 简化形式 if(表达式)语句块 阶梯形式 if(表达式1)语句块1 else if(表达式2&#xff09;语句块2 嵌套形式 if() if() 语句1 else 语句2 else if() 语句3 else 语句4 表达式一般情况下为逻辑表达式或关系表达式 #include <stdio.h>//从小到大排序,输出顺…