新手小白的pytorch学习第八弹------分类问题模型和简单预测

目录

  • 1 启动损失函数和优化器
  • 2 训练模型
    • 创建训练和测试循环
  • 3 预测和评估模型

这篇是接着新手小白的pytorch学习第七弹------分类问题模型这一篇的,代码也是哟~

1 启动损失函数和优化器

对于我们的二分类问题,我们经常使用 binary cross entropy 作为损失函数
可以使用torch.optim.SGD()torch.optim.Adam() 作为优化器

有两个 binary cross entropy 函数

  1. torch.nn.BCELoss()-在label(target)和features(input)之间进行衡量
  2. torch.nn.BCEWithLogitsLoss()-这个和上面这个一样,不过它有一个sigmoid嵌入层(nn.Sigmoid)[之后我们会看这个方式的]

下面我们会创建损失函数和优化器,优化器我们使用SGD,优化器使用模型的参数,学习率为0.1

import torch.nn as nn
import torch.optim as optim
# 创建一个损失函数
loss_fn = nn.BCEWithLogitsLoss() # 嵌入sigmoid()函数# 创建一个优化器
optimizer = optim.SGD(params=model_0.parameters(),lr=0.1)

我们再引入一个新的东西,评估标准,它也可以像损失函数一样,来衡量你的模型怎么样。毕竟使用多个角度来衡量模型,能够让模型更加的公正客观。Accuracy准确度,可以看出在总样本中正确样本的数量,所以100%是最好的,毕竟我们期望它全部预测对,对吧。

# 创建一个计算准确率的accuracy函数
def accuracy_fn(y_true, y_pred):correct = torch.eq(y_true, y_pred).sum().item() # torch.eq()计算两个相同的张量acc = (correct/len(y_pred))*100return acc

现在我们可以使用这个函数来衡量我们的模型啦。

2 训练模型

这里使用的损失函数是nn.BCEWithLogits(),因此这个损失函数的输入是logits.
什么是logits呢,我的理解就是我们的模型输出的原始值,不经过处理的值,由于这个损失函数是有一个torch.sigmoid()函数的,所以数据的转换有三个步骤:logits -> prediction probability -> prediction labels

# 查看测试数据的前5个输出
with torch.inference_mode():y_logits = model_0(X_test.to(device))[:5]
y_logits

tensor([[0.6003],
[0.6430],
[0.5095],
[0.6260],
[0.5431]], device=‘cuda:0’)

因为我们的模型没有被训练,因此这些输出都是随机的。
并且我们模型的原始输出是logits,这些数字难以解释,我们需要能够和真实数据相比较的数据。
我们可以使用torch.sigmoid()激活函数来将数据转换为我们需要的形式.

# 使用 torch.sigmoid() 激活函数
y_pred_probs = torch.sigmoid(y_logits)
y_pred_probs

tensor([[0.6457],
[0.6554],
[0.6247],
[0.6516],
[0.6325]], device=‘cuda:0’)

y_pred_probs 现在是 prediction probability 预测概率的形式,概率就是有多大的可能,有多大的几率。在我们的情况中,我们理想的输出是0或1,所以这些值可以被看做一个决定的边界。比如说值越靠近零,那模型就将这个样本分类为0, 值越接近1,模型就将这个样本分类为1.

更具体地说:
if y_pred_probs >= 0.5, y=1(class 1)
if y_pred_probs < 0.5, y=0(class 0)

将预测概率转变成预测标签,我们四舍五入torch.sigmoid()函数的输出即可

# 将概率转变为标签
y_preds = torch.round(y_pred_probs)# 将刚才的过程连起来放在一起
y_preds_labels = torch.round(torch.sigmoid(model_0(X_test.to(device))[:5]))# 查看预测值和标签相等
print(torch.eq(y_preds.squeeze(), y_preds_labels.squeeze()))# 去掉额外的维度
y_preds.squeeze()

tensor([True, True, True, True, True], device=‘cuda:0’)
tensor([1., 1., 1., 1., 1.], device=‘cuda:0’)

y_test[:5]

y_test[:5]

创建训练和测试循环

# 设置随机种子,有利于代码的复现
torch.manual_seed(42)epochs = 100# 将数据放到指定的设备上
X_train, y_train = X_train.to(device), y_train.to(device)
X_test, y_test = X_test.to(device), y_test.to(device)# 创建训练和测试循环
for epoch in range(epochs):# 进入训练模式model_0.train()# 预测y_logits = model_0(X_train).squeeze()y_pred_prob = torch.sigmoid(y_logits)y_pred = torch.round(y_pred_prob)# 计算损失函数和准确率loss = loss_fn(y_logits, y_train)acc = accuracy_fn(y_true = y_train, y_pred = y_pred)optimizer.zero_grad()loss.backward()optimizer.step()# 测试model_0.eval()with torch.inference_mode():test_logits = model_0(X_test).squeeze()test_pred = torch.round(torch.sigmoid(test_logits))test_loss = loss_fn(test_logits, y_test)test_acc = accuracy_fn(y_true = y_test,y_pred = test_pred)# 打印出内容if epoch % 10 == 0:print(f"Epoch:{epoch} | Loss:{loss:.5f} | Accuracy:{acc:.2f}% | Test loss:{test_loss:.2f} | Test accuracy:{test_acc:.2f}%")

Epoch:0 | Loss:0.69313 | Accuracy:51.75% | Test loss:0.69 | Test accuracy:48.50%
Epoch:10 | Loss:0.69310 | Accuracy:51.75% | Test loss:0.69 | Test accuracy:48.00%
Epoch:20 | Loss:0.69308 | Accuracy:51.25% | Test loss:0.69 | Test accuracy:49.00%
Epoch:30 | Loss:0.69307 | Accuracy:50.75% | Test loss:0.69 | Test accuracy:48.00%
Epoch:40 | Loss:0.69306 | Accuracy:50.38% | Test loss:0.69 | Test accuracy:48.00%
Epoch:50 | Loss:0.69305 | Accuracy:51.12% | Test loss:0.69 | Test accuracy:47.50%
Epoch:60 | Loss:0.69304 | Accuracy:51.12% | Test loss:0.69 | Test accuracy:48.00%
Epoch:70 | Loss:0.69303 | Accuracy:50.75% | Test loss:0.69 | Test accuracy:47.50%
Epoch:80 | Loss:0.69303 | Accuracy:50.75% | Test loss:0.69 | Test accuracy:47.00%
Epoch:90 | Loss:0.69303 | Accuracy:50.38% | Test loss:0.69 | Test accuracy:46.50%

通过上面的数据,损失函数几乎没变化,精确度50%左右,感觉模型啥也没有学到,这就意味着它分类是随机的

3 预测和评估模型

从上面的数据,感觉我们的模型好像是随机猜测,我们来可视化一下看看究竟是怎么个事儿。

我们接着会写代码下载并导入helper_functions.py script来自Learn PyTorch for Deep Learning repo.

在这里有一个叫做 plot_decision_boundary() 的函数,它来可视化我们模型的分类的不同的点

我们也会导入我们在 01 中自己写的 plot_predictions()

import requests
from pathlib import Path
# 从仓库下载文档
if Path("helper_functions.py").is_file():print("helper_functions.py already exists, skipping download")
else:print("Downloading helper_functions.py")request = requests.get("https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/helper_functions.py")with open("helper_functions.py", "wb") as f:f.write(request.content)from helper_functions import plot_predictions, plot_decision_boundary

Downloading helper_functions.py

这里可能需要科学上网,我把文件helper_functions.py的代码放到文末了,可以自己创建一个.py文件粘进去。

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Train")
plot_decision_boundary(model_0, X_train, y_train)
plt.subplot(1, 2, 2)
plt.title("Test")
plot_decision_boundary(model_0, X_test, y_test)

在这里插入图片描述
看中间这条白色的线,模型是通过这条直线来区分红色和蓝色的点,所以是50%的准确率,明显是不对的,因为我们的数据明显是圈圈。

从机器学习的方面看,我们的模型欠拟合(underfitting),即没有从数据中学习到数据的模式.

那我们如何改善呢?请听下回分解

终于把今天的学习整理出来了,BB啊,今天中午吃了个超级物美价廉的套餐,里面的土豆炖牛腩和我平常吃的不一样,它这个带汤,尊嘟很好吃,熏过的香干一定要尝尝啊,皮蛋也很八错。还喝了一杯瑞幸的美式,一般般吧,室友说苦,我就喜欢喝这种苦苦的,嘻嘻嘻。

师姐通过一个电话说我喜欢一个男孩子,就说我喜欢他,哈哈哈哈,不知道咋听出来的,乌龙可是闹大了呢,话说,我们见面次数确实不多,但听到她的声音,莫名有点想她了,晚上就是多愁善感啊,别管我!

BB啊,今天就到这吧,不敢想象明天的学习有多开心,终于到了要改善模型啦~

如果文章对您有帮助的话,记得给俺点个呐!

靴靴,谢谢~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/49266.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器视觉系列之【硬件知识】-工业相机(四)

目录 几个高频面试题目 工业彩色相机如何调节白平衡解决偏色问题 算法原理 多光谱成像技术和相机选型 多光谱相机技术 选择多光谱成像相机技术时的主要考虑因素 智慧工厂机器视觉感知与控制 1 智慧工厂与机器视觉检测控制技术 2 智慧工厂机器视觉感知与控制 基于机器视…

详解yolov5和yolov8以及目标检测相关面试

一、与yoloV4相比&#xff0c;yoloV5的改进 输入端&#xff1a;在模型训练阶段&#xff0c;使用了Mosaic数据增强、自适应锚框计算、自适应图片缩放基准网络&#xff1a;使用了FOCUS结构和CSP结构Neck网络&#xff1a;在Backbone和最后的Head输出层之间插入FPN_PAN结构Head输出…

[NOIP2009 提高组] 最优贸易(含代码题解)

[NOIP2009 提高组] 最优贸易 题目描述 C C C 国有 n n n 个大城市和 m m m 条道路&#xff0c;每条道路连接这 n n n 个城市中的某两个城市。任意两个城市之间最多只有一条道路直接相连。这 m m m 条道路中有一部分为单向通行的道路&#xff0c;一部分为双向通行的道路&am…

NLP-使用Word2vec实现文本分类

Word2Vec模型通过学习大量文本数据&#xff0c;将每个单词表示为一个连续的向量&#xff0c;这些向量可以捕捉单词之间的语义和句法关系。本文做文本分类是结合Word2Vec文本内容text&#xff0c;预测其文本标签label。以下使用mock商品数据的代码实现过程过下&#xff1a; 1、…

JMeter的使用方法

软件安装&#xff1a; 参考链接&#xff1a;JMeter 下载安装及环境配置&#xff08;包含jdk1.8安装及配置&#xff09;_jmeter5.2.1需要什么版本的jdk-CSDN博客 前置知识储备&#xff1a; JMeter的第一个案例 增加线程数 线程&#xff08;thread&#xff09;是操作系统能够进…

ROS2入门到精通—— 2-8 ROS2实战:机器人安全通过狭窄区域的方案

0 前言 室内机器人需要具备适应性和灵活性&#xff0c;以便在狭窄的空间中进行安全、高效的导航。本文提供一些让机器人在狭窄区域安全通过的思路&#xff0c;希望帮助读者根据实际开发适当调整和扩展 1 Voronoi图 Voronoi图&#xff1a;根据给定的一组“种子点”&#xff0…

【数据挖掘】词云分析

目录 1. 词云分析 2. Python 中的 WordCloud 库 1. 词云分析 词云&#xff08;Word Cloud&#xff09;是数据可视化的一种形式&#xff0c;主要用于展示文本数据中单词的频率和重要性。它具有以下几种主要用途和意义&#xff1a; 1. 文本分析 • 识别关键主题&#xff1a;通…

AI学习记录 - 图像识别的基础入门

代码实现&#xff0c;图像识别入门其实非常简单&#xff0c;这里使用的是js&#xff0c;其实就是把二维数组进行公式化处理&#xff0c;处理方式如上图&#xff0c;不同的公式代表的不同的意义&#xff0c;这些意义网上其实非常多&#xff0c;这里就不细讲了。 const getSpecif…

JavaScript构造函数小挑战

// 编码挑战 #1 /* 使用构造函数实现一辆汽车。一辆汽车有一个品牌和一个速度属性。speed 属性是汽车当前的速度&#xff0c;单位为 km/h&#xff1b; a. 执行一个 “accelerate ”方法&#xff0c;将汽车的速度提高 10&#xff0c;并将新速度记录到控制台&#xff1b; 3. a.…

VSCode python autopep8 格式化 长度设置

ctrl, 打开设置 > 搜索autopep8 > 找到Autopep8:Args > 添加项--max-line-length150

等保测评练习卷17

等级保护初级测评师试题17 姓名: 成绩: 判断题(101=10分)1. 关于安全区域边界的安全审计,三级系统的要求包括应对审计进程进行保护,防止未经授权的中断。( F ) 是安全计算环境的安全审计 2.…

秋招突击——7/22——复习{堆——前K个高频元素}——新作{回溯——单次搜索、分割回文串。链表——环形链表II,合并两个有序链表}

文章目录 引言复习堆堆——前K个高频元素个人实现复习实现二参考实现 新作单词搜索个人实现参考实现 分割回文串个人实现参考实现 环形链表II个人实现参考实现 两个有序链表个人实现 总结 引言 又是充满挑战性的一天&#xff0c;继续完成我们的任务吧&#xff01;继续往下刷&a…

WebRTC QoS方法十三.2(Jitter延时的计算)

一、背景介绍 一些报文在网络传输中&#xff0c;会存在丢包重传和延时的情况。渲染时需要进行适当缓存&#xff0c;等待丢失被重传的报文或者正在路上传输的报文。 jitter延时计算是确认需要缓存的时间 另外&#xff0c;在检测到帧有重传情况时&#xff0c;也可适当在渲染时…

【目标检测实验系列】EMA高效注意力机制,融合多尺度特征,助力YOLOv5检测模型涨点(文内附源码)

1. 文章主要内容 本篇博客主要涉及多尺度高效注意力机制&#xff0c;融合到YOLOv5s模型中&#xff0c;增加模型提取多尺度特征的能力&#xff0c;助力模型涨点。&#xff08;通读本篇博客需要7分钟左右的时间&#xff09;。 2. 简要概括 论文地址&#xff1a;EMA论文地址 如下…

Blender材质-PBR与纹理材质

1.PBR PBR:Physically Based Rendering 基于物理的渲染 BRDF:Bidirection Reflectance Distribution Function 双向散射分散函数 材质着色操作如下图&#xff1a; 2.纹理材质 左上角&#xff1a;编辑器类型中选择&#xff0c;着色器编辑器 新建着色器 -> 新建纹理 -> 新…

音视频入门基础:H.264专题(17)——FFmpeg源码获取H.264裸流文件信息(视频压缩编码格式、色彩格式、视频分辨率、帧率)的总流程

音视频入门基础&#xff1a;H.264专题系列文章&#xff1a; 音视频入门基础&#xff1a;H.264专题&#xff08;1&#xff09;——H.264官方文档下载 音视频入门基础&#xff1a;H.264专题&#xff08;2&#xff09;——使用FFmpeg命令生成H.264裸流文件 音视频入门基础&…

【开源库编译 | zlib】 zlib库最新版本(zlib-1.3.1)在Ubuntu(Linux)系统下的 编译 、交叉编译(移植)

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…

《书生大模型实战营第3期》入门岛 学习笔记与作业:Git 基础知识

文章大纲 Git 是什么&#xff1f;-- 分布式版本控制系统版本控制系统简介Git 基本概念1. 安装 Git1.1 Windows 系统1.2 Linux 系统 2. Git 托管平台3. 常用 Git 操作4. tips4.1 全局设置 vs. 本地设置4.2 如何配置4.3 验证设置4.4 Git 四步曲 5. 常用插件6. 常规开发流程 作业其…

js+css侧边导航菜单 可收缩

jscss侧边导航菜单 可收缩https://www.bootstrapmb.com/item/14774 创建一个可收缩的侧边导航菜单需要使用JavaScript来处理交互&#xff0c;而CSS则用来设置样式和动画效果。以下是一个简单的示例&#xff0c;展示了如何创建一个可收缩的侧边导航菜单。 HTML 结构 html<!…

重修之路1

我也不知道我现在处于个什么状态&#xff0c;我在以前写代码时知道部分方法如何使用&#xff0c;但是也仅限于此我并不了其如何实现&#xff0c;让我感到迷茫我是越来越菜了随着AI的发展它写出的代码简洁高效甚至让我有些看不懂&#xff0c;以至于我开始怀疑自己的JS基本功因此…