PyTorch学习8:多分类问题

文章目录

  • 前言
  • 一、说明
  • 二、示例
    • 1.步骤
    • 2.示例代码
  • 总结


前言

介绍如何利用PyTorch中Softmax 分类器实现多分类问题。

一、说明

1.多分类问题的输出是一个分布,满足和为1.
2.Softmax 分类器
在这里插入图片描述
3.损失函数:交叉熵损失
torch.nn.CrossEntropyLoss()
在这里插入图片描述

二、示例

1.步骤

1.建立模型
2.定义训练函数
3.定义测试函数
4.主函数:定义训练集和测试集,定义损失函数和优化器,进行训练,存储结果,绘图

2.示例代码

代码如下(示例):

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import pickle
# prepare dataset# batch_size = 64
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])  # 归一化,均值和方差
#
# train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
# train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
# test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
# test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)# design model using classclass Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.l1 = torch.nn.Linear(784, 512)self.l2 = torch.nn.Linear(512, 256)self.l3 = torch.nn.Linear(256, 128)self.l4 = torch.nn.Linear(128, 64)self.l5 = torch.nn.Linear(64, 10)def forward(self, x):x = x.view(-1, 784)  # -1其实就是自动获取mini_batchx = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))return self.l5(x)  # 最后一层不做激活,不进行非线性变换# model = Net()
#
# # construct loss and optimizer
# criterion = torch.nn.CrossEntropyLoss()
# optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)# training cycle forward, backward, updatedef train(epoch):running_loss = 0.0loss_s = 0.0for batch_idx, data in enumerate(train_loader, 0):# 获得一个批次的数据和标签inputs, target = dataoptimizer.zero_grad()# 获得模型预测结果(64, 10)outputs = model(inputs)# 交叉熵代价函数outputs(64,10),target(64)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()loss_s += loss.item()if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))running_loss = 0.0return loss_s / len(train_loader)def test():correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0个维度,行是第1个维度total += labels.size(0)correct += (predicted == labels).sum().item()  # 张量之间的比较运算print('accuracy on test set: %d %% ' % (100 * correct / total))return 100 * correct / totalif __name__ == '__main__':batch_size = 64transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])  # 归一化,均值和方差train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)model = Net()# construct loss and optimizercriterion = torch.nn.CrossEntropyLoss(reduction='mean')optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)epoch_list = []loss_list = []accuracy_list = []for epoch in range(10):epoch_list.append(epoch)loss_lis = train(epoch)loss_list.append(loss_lis)tes = test()accuracy_list.append(tes)with open('8/epoch_list.pkl', 'wb') as f:pickle.dump(epoch_list, f)with open('8/loss_list.pkl', 'wb') as f:pickle.dump(loss_list, f)with open('8/accuracy_list.pkl', 'wb') as f:pickle.dump(accuracy_list, f)

画图程序如下:

import pickle
import matplotlib.pyplot as pltwith open('8/epoch_list.pkl', 'rb') as f:loaded_epoch_list = pickle.load(f)
with open('8/loss_list.pkl', 'rb') as f:loaded_loss_list = pickle.load(f)
with open('8/accuracy_list.pkl', 'rb') as f:loaded_acc_list = pickle.load(f)plt.subplot(2, 1, 1)  # 创建子图,2行1列,第1个子图
plt.plot(loaded_epoch_list, loaded_loss_list)
plt.xlabel('epoch')
plt.ylabel('loss 1')plt.subplot(2, 1, 2)  # 创建子图,2行1列,第2个子图
plt.plot(loaded_epoch_list, loaded_acc_list,'r')
plt.xlabel('epoch')
plt.ylabel('acc 1')
plt.show()

得到如下结果:
在这里插入图片描述

总结

PyTorch学习8:多分类问题

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/25889.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

虚拟化 之一 详解 jailhouse 架构及原理、软硬件要求、源码文件、基本组件

Jailhouse 是一个基于 Linux 实现的针对创建工业级应用程序的小型 Hypervisor,是由西门子公司的 Jan Kiszka 于 2013 年开发的,并得到了官方 Linux 内核的支持,在开源社区中获得了知名度和吸引力。 Jailhouse Jailhouse 是一种轻量级的虚拟化…

微软如何打造数字零售力航母系列科普13 - Prime Focus Technologies在NAB 2024上推出CLEAR®对话人工智能联合试点

Prime Focus Technologies在NAB 2024上推出CLEAR对话人工智能联合试点 彻底改变您与内容的互动方式,从内容的创建到分发 洛杉矶,2024年4月9日/PRNewswire/-媒体和娱乐(M&E)行业人工智能技术解决方案的先驱Prime Focus Techn…

人工智能在医学领域的应用及技术实现

欢迎来到 Papicatch的博客 目录 🍉引言 🍉 医学影像分析 🍈技术实现 🍍数据准备 🍍模型构建 🍍模型训练 🍍模型评估 🍍应用部署 🍈示例代码 🍉 基因…

操作系统真象还原:内存管理系统

第8章-内存管理系统 这是一个网站有所有小节的代码实现,同时也包含了Bochs等文件 8.1 Makefile简介 8.1.1 Makefile是什么 8.1.2 makefile基本语法 make 给咱们提供了方法,可以在命令之前加个字符’@’,这样就不会输出命令本身…

微信小程序使用 “云函数“ 获取 “openid“

文章目录 1.前期准备2.具体操作步骤 1.前期准备 必须使用云开发已经配置好云开发 2.具体操作步骤 1.进入小程序开发工具→在云函数目录上右键→选中新建云函数 创建结束,自动上传(必须确认已经上传才生效) 2.进入对应页面的js文件&#…

QT 信号和槽 信号关联到信号示例 信号除了可以绑定槽以外,信号还可以绑定信号

信号除了可以关联到槽函数,还可以关联到类型匹配的信号,实现信号的接力触发。上个示例中因为 clicked 信号没有参数,而 SendMsg 信号有参数,所以不方便直接关联。本小节示范一个信号到信号的关联,将按钮的 clicked 信号…

【优化过往代码】关于vue自定义事件的运用

【优化过往代码】关于vue自定义事件的运用 需求说明过往代码优化思路优化后代码(Vue2)遇到问题记录 Vue2官方自定义指令说明文档 Vue3官方自定义指令说明文档 需求说明 进入某些页面需要加载一些外部资源,并在资源加载完后进行一些处理&…

【栈】2751. 机器人碰撞

本文涉及知识点 栈 LeetCode2751. 机器人碰撞 现有 n 个机器人,编号从 1 开始,每个机器人包含在路线上的位置、健康度和移动方向。 给你下标从 0 开始的两个整数数组 positions、healths 和一个字符串 directions(directions[i] 为 ‘L’ …

MySQL-数据处理函数

026-distinct去重 select job from emp;加个 distinct 就行了 select distinct job from emp;注意:这个去重只是将显示的结果去重,原表数据不会被更改。 select 永远不会改变原数据 select distinct deptno, job from emp order by deptno asc;027-数…

步态控制之足旋转点(Foot Rotation Indicator, FRI)

足旋转点(Foot Rotation Indicator, FRI) 足旋转点是人形机器人步态规划中的一个关键概念,用于描述步态过程中机器人脚部的旋转和稳定性。FRI 可以帮助确定机器人在行走时是否稳定,以及如何调整步态以保持稳定。下面详细介绍FRI的原理,并举例说明其应用。 足旋转点(FRI…

R语言统计分析——图形的简单示例

参考资料:R语言实战【第2版】 1、示例一 # 绑定数据框mtcars attach(mtcars)# 打开一个图形窗口并生成一个散点图plot(wt,mpg)# 添加一条最优拟合曲线abline(lm(mpg~wt))# 添加标题title("Regression of MPG on weight") # 解除数据框绑定 detach(mtcar…

ES8.13 _bulk报错Malformed content, found extra data after parsing: START_OBJECT解决

在使用elaticsearch8.13.0使用批量创建索引时,根据谷粒中说的es7.9方法去批量操作请求: http://127.0.0.1:9200/shop/_doc/_bulk 注意1:设置header为Content-Type:application/x-ndjson,否则请求报错: {"error": &qu…

机器学习笔记:focal loss

1 介绍 Focal Loss 是一种在类别不平衡的情况下改善模型性能的损失函数最初在 2017 年的论文《Focal Loss for Dense Object Detection》中提出这种损失函数主要用于解决在有挑战性的对象检测任务中,易分类的负样本占据主导地位的问题,从而导致模型难以…

【recast-navigation-js】使用three.js辅助绘制Agent寻路路径

目录 说在前面setAgentTarget绘制寻路路径结果问题其他 说在前面 操作系统:windows 11浏览器:edge版本 124.0.2478.97recast-navigation-js版本:0.29.0golang版本:1.21.5上一篇:【recast-navigation-js】使用three.js辅…

linux:centos7升级libstdc++版本到3.4.26

下载,解压 wget http://www.vuln.cn/wp-content/uploads/2019/08/libstdc.so_.6.0.26.zip unzip libstdc.so_.6.0.26.zip 复制到【/usr/lib64】: cp libstdc.so.6.0.26 /usr/lib64创建软链接 cd /usr/lib64 sln libstdc.so.6.0.26 libstdc.so.6查看一…

Python | Leetcode Python题解之第144题二叉树的前序遍历

题目: 题解: class Solution:def preorderTraversal(self, root: TreeNode) -> List[int]:res list()if not root:return resp1 rootwhile p1:p2 p1.leftif p2:while p2.right and p2.right ! p1:p2 p2.rightif not p2.right:res.append(p1.val)…

机器学习笔记 - LoRA:大型语言模型的低秩适应

一、简述 1、模型微调 随着大型语言模型 (LLM) 的规模增加到数千亿,对这些模型进行微调成为一项挑战。传统上,要微调模型,我们需要更新所有模型参数。这也称为完全微调 (FFT) 。下图详细概述了此方法的工作原理。 完全微调FFT 的计算成本和资源需求很大,因为更新每…

Vmess协议是什么意思? VLESS与VMess有什么区别?

VMess 是一个基于 TCP 的加密传输协议,所有数据使用 TCP 传输,是由 V2Ray 原创并使用于 V2Ray 的加密传输协议,它分为入站和出站两部分,其作用是帮助客户端跟服务器之间建立通信。在 V2Ray 上客户端与服务器的通信主要是通过 VMes…

【InternLM实战营第二期笔记】06:Lagent AgentLego 智能体应用搭建

文章目录 讲解为什么要有智能体什么是 Agent智能体的组成智能体框架AutoGPTReWooReAct Lagent & Agent LegoAgentLego 实操Lagent Web Demo自定义工具 AgentLego:组装智能体“乐高”直接使用作为智能体,WebUI文生图测试 Agent 工具能力微调 讲解 为…

idea如何使用git reset进行回退以及如何使用git stash将暂存区文件储藏,打包后重新恢复暂存区文件

最近遇到一个棘手的问题,本来按照计划表开发,但是项目经理突然让你改一个小bug,改完需要马上部署到线上,但是你手上的活做到一半还没做完,提交上去那肯定是不可行的。这时就可以使用git stash命令先把当前进度&#xf…