用PyTorch轻松实现二分类:逻辑回归入门

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

  • 🥦引言
  • 🥦什么是逻辑回归?
  • 🥦分类问题
  • 🥦交叉熵
  • 🥦代码实现
  • 🥦总结

🥦引言

当谈到机器学习和深度学习时,逻辑回归是一个非常重要的算法,它通常用于二分类问题。在这篇博客中,我们将使用PyTorch来实现逻辑回归。PyTorch是一个流行的深度学习框架,它提供了强大的工具来构建和训练神经网络,适用于各种机器学习任务。

在机器学习中已经使用了sklearn库介绍过逻辑回归,这里重点使用pytorch这个深度学习框架

🥦什么是逻辑回归?

我们首先来回顾一下什么是逻辑回归?

逻辑回归是一种用于二分类问题的监督学习算法。它的主要思想是通过一个S形曲线(通常是Sigmoid函数)将输入特征映射到0和1之间的概率值,然后根据这些概率值进行分类决策。在逻辑回归中,我们使用一个线性模型和一个激活函数来实现这个映射。

🥦分类问题

这里以MINIST Dataset手写数字集为例
在这里插入图片描述

这个数据集中包含了6w个训练集1w个测试集,类别10个
这里我们不再向之前线性回归那样,根据属于判断具体的数值大小;而是根据输入的值判断从0-9每个数字的概率大小记为p(0)、p(1)…而且十个概率值和为1,我们的目标就是根据输入得到这十个分类对于输入的每一个的概率值,哪个大就是我们需要的。

这里介绍一下与torch相关联的库—torchvision
torchvision:

  • “torchvision” 是一个PyTorch的附加库,专门用于处理图像和视觉任务。
    它包含了一系列用于数据加载、数据增强、计算机视觉任务(如图像分类、目标检测等)的工具和数据集。
  • “torchvision” 提供了许多预训练的视觉模型(例如,ResNet、VGG、AlexNet等),可以用于迁移学习或作为基准模型。
    此外,它还包括了用于图像预处理、转换和可视化的函数。

上图已经清楚的显示了,这个库包含了一些自带的数据集,但是并不是我们安装完这个库就有了,而且需要进行调用的,类似在线下载,root指定下载的路径,train表示你需要训练集还是测试集,通常情况下就是两个一个训练,一个测试,download就是判断你下没下载,下载了就是摆设,没下载就给你下载了

我们再来看一个数据集(CIFAR-10)
在这里插入图片描述
包含了5w训练样本,1w测试样本,10类。调用方式与上一个类似。

接下来我们从一张图更加直观的查看分类和回归
在这里插入图片描述

左边的是回归,右边的是分类


在这里插入图片描述

过去我们使用回归例如 y ^ \hat{y} y^=wx+b∈R,这是属于一个实数的;但是在分类问题, y ^ \hat{y} y^∈[0,1]
这说明我们需要寻找一个函数,将原本实数的值经过函数的映射转化为[0,1]之间。这里我们引入Logistic函数,使用极限很清楚的得出x趋向于正无穷的时候函数为1,x趋向于负无穷的时候,函数为0,x=0的时候,函数为0.5,当我们计算的时候将 y ^ \hat{y} y^带入这样就会出现一个0到1的概率了。

下图展示一些其他的Sigmoid函数
在这里插入图片描述

🥦交叉熵

过去我们所使用的损失函数普遍都是MSE,这里引入一个新的损失函数—交叉熵

==交叉熵(Cross-Entropy)==是一种用于衡量两个概率分布之间差异的数学方法,常用于机器学习和深度学习中,特别是在分类问题中。它是一个非常重要的损失函数,用于衡量模型的预测与真实标签之间的差异,从而帮助优化模型参数。

在交叉熵的上下文中,通常有两个概率分布:

  • 真实分布(True Distribution): 这是指问题的实际概率分布,表示样本的真实标签分布。通常用 p ( x ) p(x) p(x)表示,其中 x x x表示样本或类别。

  • 预测分布(Predicted Distribution): 这是指模型的预测概率分布,表示模型对每个类别的预测概率。通常用 q ( x ) q(x) q(x)表示,其中 x x x表示样本或类别。

交叉熵的一般定义如下:
在这里插入图片描述其中, H ( p , q ) H(p, q) H(p,q) 表示真实分布 p p p 和预测分布 q q q 之间的交叉熵。

交叉熵的主要特点和用途包括:

  • 度量差异性: 交叉熵度量了真实分布和预测分布之间的差异。当两个分布相似时,交叉熵较小;当它们之间的差异增大时,交叉熵增大。

  • 损失函数: 在机器学习中,交叉熵通常用作损失函数,用于衡量模型的预测与真实标签之间的差异。在分类任务中,通常使用交叉熵作为模型的损失函数,帮助模型优化参数以提高分类性能。

  • 反向传播: 交叉熵在训练神经网络时非常有用。通过计算交叉熵的梯度,可以使用反向传播算法来调整神经网络的权重,从而使模型的预测更接近真实标签。

在分类问题中,常见的交叉熵损失函数包括二元交叉熵(Binary Cross-Entropy)和多元交叉熵(Categorical Cross-Entropy)。二元交叉熵用于二分类问题,多元交叉熵用于多类别分类问题。

刘二大人的PPT中也介绍了
在这里插入图片描述
右边的表格中每组y与 y ^ \hat{y} y^对应的BCE,BCE越高说明越可能,最后将其求均值

🥦代码实现

在这里插入图片描述

根据上图可知,线性回归和逻辑回归的流程与函数只区别于Sigmoid函数
在这里插入图片描述
这里就是BCEloss的调用,里面的参数代表求不求均值

完整代码如下

import torch.nn.functional as F
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__() self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred
model = LogisticRegressionModel() 
criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad() loss.backward()optimizer.step()

最后绘制一下

import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))  # 相当于reshape
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r') 
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

运行结果如下
在这里插入图片描述

🥦总结

这就是使用PyTorch实现逻辑回归的基本步骤。逻辑回归是一个简单但非常有用的算法,可用于各种分类问题。希望这篇博客能帮助你开始使用PyTorch构建自己的逻辑回归模型。如果你想进一步扩展你的知识,可以尝试在更大的数据集上训练模型或探索其他深度学习算法。祝你好运!

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/99104.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ViewPager、RecycleView实现轮播图

1.ViewPager实现轮播图形效果。 1&#xff09;layout中&#xff0c;PageIndicatorView轮播的View <RelativeLayoutandroid:layout_width"match_parent"android:layout_height"200dp"android:orientation"vertical"><androidx.viewpager…

巧用excel实现试卷向表格的转换

MID($E$10,FIND(D14,$E$10,1),FIND(D15,$E$10,1)-FIND(D14,$E$10,1)) MID($E$10,FIND(D15,$E$10,1),FIND(D16,$E$10,1)-FIND(D15,$E$10,1)) 中华人民共和国司法部

HarmonyOS/OpenHarmony原生应用-ArkTS万能卡片组件Span

作为Text组件的子组件&#xff0c;用于显示行内文本的组件。无子组件 一、接口 Span(value: string | Resource) 从API version 9开始&#xff0c;该接口支持在ArkTS卡片中使用。 参数&#xff1a; 参数名 参数类型 必填 参数描述 value string | Resource 是 文本内…

c++视觉---中值滤波处理

中值滤波&#xff08;Median Filter&#xff09;是一种常用的非线性平滑滤波方法&#xff0c;用于去除图像中的噪声。它不像线性滤波&#xff08;如均值滤波或高斯滤波&#xff09;那样使用权重来计算平均值或加权平均值&#xff0c;而是选择滤波窗口内的像素值中的中间值作为输…

docker搭建jenkins

1.拉取镜像 docker pull jenkinsci/blueocean 2.启动容器 docker run -d -u root -p 8666:8080 -p 50000:50000 -v /var/jenkins_home:/var/jenkins_home -v /etc/localtime:/etc/localtime --name MyJenkins jenkinsci/blueocean 3.访问ip:port,就能访问了 4.docker logs 容器…

主从复制的实现方案

读写分离技术架构图 实现读写分离的技术架构选型如上;需要自己去实践主从复制;为了节省资源&#xff0c;当然系统并发量并没有那么大,选择一主一丛;强制读主库,为了解决主从同步延迟带来的影响&#xff1b;对于实时性要求高的强制读主库&#xff1b;GTID 主要是一种事务标识技术…

docker搭建nginx

1.docker pull nginx 2.docker run --name nginx-test -p 8082:80 -d nginx 3.访问ip:8082

Linux 系统性能瓶颈分析(超详细)

Author&#xff1a;rab 目录 前言一、性能指标1.1 进程1.1.1 进程定义1.1.2 进程状态1.1.3 进程优先级1.1.4 进程与程序间的关系1.1.5 进程与进程间的关系1.1.6 进程与线程的关系 1.2 内存1.2.1 物理内存与虚拟内存1.2.2 页高速缓存与页写回机制1.2.3 Swap Space 1.3 文件系统1…

在PicGo上使用github图床解决typora上传csdn图片不显示问题(保姆级教程)

文章目录 在PicGo上使用github图床解决typora上传csdn图片不显示问题&#xff08;保姆级教程&#xff09;1、typora上传csdn图片不显示&#xff08;外链图片转存失败&#xff09;2、PicGo2.1、PicGo下载2.2、PicGo使用2.2.1、对PicGo完成基本的配置2.2.2、配置github图床2.2.3、…

R实现地图相关图形绘制

大家好&#xff0c;我是带我去滑雪&#xff01; 地图相关图形绘制具有许多优点&#xff0c;这些优点使其在各种领域和应用中非常有用。例如&#xff1a;地图相关图形提供了一种直观的方式来可视化数据&#xff0c;使数据更容易理解和分析。通过地图&#xff0c;可以看到数据的空…

UE4 Unlua 初使用小记

function M:Construct()print(Hello World)print(self.Va)local mySubsystem UE4.UHMSGameInstanceSubsystemUE4.UKismetSystemLibrary.PrintString(self,"Get Click Msg From UnLua ")end unlua中tick不能调用的问题&#xff1a; 把该类的Event Tick为灰色显示的删…

【数据库审计】2023年数据库审计厂家汇总

我们大家都知道数据库审计的重要意义&#xff0c;不仅可以满足等保合规&#xff0c;还能进行风险告警&#xff0c;保障数据安全。那你知道目前市面上数据库审计厂家有哪些吗&#xff1f;这里小编就给大家汇总一下。 2023年数据库审计厂家汇总 1、行云管家 2、安恒信息 3、…

MongoDB-介绍与安装部署

介绍与安装部署 1.MongoDB简介a) 体系结构b) 数据模型c) MongoDB的特点c.1) 高性能c.2) 高性可用性c.3) 高拓展性c.4) 丰富的查询支持 2.单机部署a) Windows系统中的安装启动b) Shell连接(mongo命令)c) Linux系统中的安装启动和连接 1.MongoDB简介 MongoDB是一个开源、高性能、…

多头注意力机制

1、什么是多头注意力机制 从多头注意力的结构图中&#xff0c;貌似这个所谓的多个头就是指多组线性变换&#xff0c;但是并不是&#xff0c;只使用了一组线性变换层&#xff0c;即三个变换张量对 Q、K、V 分别进行线性变换&#xff0c;这些变化不会改变原有张量的尺寸&#xf…

MyBatisPlus(十六)逻辑删除

说明 实际生产中的数据&#xff0c;一般不采用物理删除&#xff0c;而采用逻辑删除&#xff0c;也就是将一条记录的状态改为已删除。 逻辑删除&#xff0c;本质上是更新操作。 MyBatis Plus 框架&#xff0c;提供了逻辑删除功能。在配置了逻辑删除后&#xff0c;增删改查和统…

在Remix中编写你的第一份智能合约

智能合约简单来讲就是&#xff1a;部署在去中心化区块链上的一个合约或者一组指令&#xff0c;当这个合约或者这组指令被部署以后&#xff0c;它就不能被改变了&#xff0c;并会自动执行&#xff0c;每个人都可以看到合约里面的条款。更深层次的理解就是&#xff1a;这些代码会…

亚马逊电子产品日本站PSE认证,TELEC认证如何办理?

日本市场准入认证——PSE认证&#xff0c;TELEC认证 日本作为第三大经济体国家&#xff0c;是中国商品对外出口的最多的国家之一&#xff0c;无论是在日本亚马逊销售还是在日本当地销售&#xff0c;都需要符合日本市场准入许可。需要注意的是日本的电气安全标准都是自主特色的…

R语言 一种功能强大的数据分析、统计建模 可视化 免费、开源且跨平台 的编程语言

R语言是一种广泛应用于数据分析、统计建模和可视化的编程语言。它由新西兰奥克兰大学的罗斯伊哈卡和罗伯特杰特曼开发&#xff0c;并于1993年首次发布。R语言是一个免费、开源且跨平台的语言&#xff0c;它在统计学和数据科学领域得到了广泛的应用。 R语言具有丰富的数据处理、…

基于docker+Keepalived+Haproxy高可用前后的分离技术

基于dockerKeepalivedHaproxy高可用前后端分离技术 架构图 服务名docker-ip地址docker-keepalived-vip-iphaproxy-01docker-ip自动分配 未指定ip192.168.31.252haproxy-02docker-ip自动分配 未指定ip192.168.31.253 安装haproxy 宿主机ip 192.168.31.254 宿主机keepalived虚…

《DevOps 精要:业务视角》- 读书笔记(三)

DevOps 精要:业务视角&#xff08;三&#xff09; 第3章 原则3.1 价值流3.2 部署流水线3.3 一切都应存储在版本控制系统中3.4 自动化配置管理3.5 完成的定义3.6 小结 第3章 原则 将原则从实践中分离出来&#xff0c;这是一种很有用的做法。当然了&#xff0c;这两个词分别有着…