用PyTorch轻松实现二分类:逻辑回归入门

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

  • 🥦引言
  • 🥦什么是逻辑回归?
  • 🥦分类问题
  • 🥦交叉熵
  • 🥦代码实现
  • 🥦总结

🥦引言

当谈到机器学习和深度学习时,逻辑回归是一个非常重要的算法,它通常用于二分类问题。在这篇博客中,我们将使用PyTorch来实现逻辑回归。PyTorch是一个流行的深度学习框架,它提供了强大的工具来构建和训练神经网络,适用于各种机器学习任务。

在机器学习中已经使用了sklearn库介绍过逻辑回归,这里重点使用pytorch这个深度学习框架

🥦什么是逻辑回归?

我们首先来回顾一下什么是逻辑回归?

逻辑回归是一种用于二分类问题的监督学习算法。它的主要思想是通过一个S形曲线(通常是Sigmoid函数)将输入特征映射到0和1之间的概率值,然后根据这些概率值进行分类决策。在逻辑回归中,我们使用一个线性模型和一个激活函数来实现这个映射。

🥦分类问题

这里以MINIST Dataset手写数字集为例
在这里插入图片描述

这个数据集中包含了6w个训练集1w个测试集,类别10个
这里我们不再向之前线性回归那样,根据属于判断具体的数值大小;而是根据输入的值判断从0-9每个数字的概率大小记为p(0)、p(1)…而且十个概率值和为1,我们的目标就是根据输入得到这十个分类对于输入的每一个的概率值,哪个大就是我们需要的。

这里介绍一下与torch相关联的库—torchvision
torchvision:

  • “torchvision” 是一个PyTorch的附加库,专门用于处理图像和视觉任务。
    它包含了一系列用于数据加载、数据增强、计算机视觉任务(如图像分类、目标检测等)的工具和数据集。
  • “torchvision” 提供了许多预训练的视觉模型(例如,ResNet、VGG、AlexNet等),可以用于迁移学习或作为基准模型。
    此外,它还包括了用于图像预处理、转换和可视化的函数。

上图已经清楚的显示了,这个库包含了一些自带的数据集,但是并不是我们安装完这个库就有了,而且需要进行调用的,类似在线下载,root指定下载的路径,train表示你需要训练集还是测试集,通常情况下就是两个一个训练,一个测试,download就是判断你下没下载,下载了就是摆设,没下载就给你下载了

我们再来看一个数据集(CIFAR-10)
在这里插入图片描述
包含了5w训练样本,1w测试样本,10类。调用方式与上一个类似。

接下来我们从一张图更加直观的查看分类和回归
在这里插入图片描述

左边的是回归,右边的是分类


在这里插入图片描述

过去我们使用回归例如 y ^ \hat{y} y^=wx+b∈R,这是属于一个实数的;但是在分类问题, y ^ \hat{y} y^∈[0,1]
这说明我们需要寻找一个函数,将原本实数的值经过函数的映射转化为[0,1]之间。这里我们引入Logistic函数,使用极限很清楚的得出x趋向于正无穷的时候函数为1,x趋向于负无穷的时候,函数为0,x=0的时候,函数为0.5,当我们计算的时候将 y ^ \hat{y} y^带入这样就会出现一个0到1的概率了。

下图展示一些其他的Sigmoid函数
在这里插入图片描述

🥦交叉熵

过去我们所使用的损失函数普遍都是MSE,这里引入一个新的损失函数—交叉熵

==交叉熵(Cross-Entropy)==是一种用于衡量两个概率分布之间差异的数学方法,常用于机器学习和深度学习中,特别是在分类问题中。它是一个非常重要的损失函数,用于衡量模型的预测与真实标签之间的差异,从而帮助优化模型参数。

在交叉熵的上下文中,通常有两个概率分布:

  • 真实分布(True Distribution): 这是指问题的实际概率分布,表示样本的真实标签分布。通常用 p ( x ) p(x) p(x)表示,其中 x x x表示样本或类别。

  • 预测分布(Predicted Distribution): 这是指模型的预测概率分布,表示模型对每个类别的预测概率。通常用 q ( x ) q(x) q(x)表示,其中 x x x表示样本或类别。

交叉熵的一般定义如下:
在这里插入图片描述其中, H ( p , q ) H(p, q) H(p,q) 表示真实分布 p p p 和预测分布 q q q 之间的交叉熵。

交叉熵的主要特点和用途包括:

  • 度量差异性: 交叉熵度量了真实分布和预测分布之间的差异。当两个分布相似时,交叉熵较小;当它们之间的差异增大时,交叉熵增大。

  • 损失函数: 在机器学习中,交叉熵通常用作损失函数,用于衡量模型的预测与真实标签之间的差异。在分类任务中,通常使用交叉熵作为模型的损失函数,帮助模型优化参数以提高分类性能。

  • 反向传播: 交叉熵在训练神经网络时非常有用。通过计算交叉熵的梯度,可以使用反向传播算法来调整神经网络的权重,从而使模型的预测更接近真实标签。

在分类问题中,常见的交叉熵损失函数包括二元交叉熵(Binary Cross-Entropy)和多元交叉熵(Categorical Cross-Entropy)。二元交叉熵用于二分类问题,多元交叉熵用于多类别分类问题。

刘二大人的PPT中也介绍了
在这里插入图片描述
右边的表格中每组y与 y ^ \hat{y} y^对应的BCE,BCE越高说明越可能,最后将其求均值

🥦代码实现

在这里插入图片描述

根据上图可知,线性回归和逻辑回归的流程与函数只区别于Sigmoid函数
在这里插入图片描述
这里就是BCEloss的调用,里面的参数代表求不求均值

完整代码如下

import torch.nn.functional as F
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__() self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred
model = LogisticRegressionModel() 
criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad() loss.backward()optimizer.step()

最后绘制一下

import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))  # 相当于reshape
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r') 
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

运行结果如下
在这里插入图片描述

🥦总结

这就是使用PyTorch实现逻辑回归的基本步骤。逻辑回归是一个简单但非常有用的算法,可用于各种分类问题。希望这篇博客能帮助你开始使用PyTorch构建自己的逻辑回归模型。如果你想进一步扩展你的知识,可以尝试在更大的数据集上训练模型或探索其他深度学习算法。祝你好运!

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/99104.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《golang设计模式》第二部分·结构型模式-07-代理模式(Proxy)

文章目录 1. 概述1.1 角色1.2 模式类图 2. 代码示例2.1 设计2.2 代码2.3 示例类图 1. 概述 代理(Proxy)是用于控制客户端访问目标对象的占位对象。 需求:在调用接口实现真是主题之前需要一些提前处理。 解决:写一个代理&#xff…

ViewPager、RecycleView实现轮播图

1.ViewPager实现轮播图形效果。 1&#xff09;layout中&#xff0c;PageIndicatorView轮播的View <RelativeLayoutandroid:layout_width"match_parent"android:layout_height"200dp"android:orientation"vertical"><androidx.viewpager…

http.header.Set()与Add()区别;

在Go语言中进行HTTP请求时&#xff0c;http.Header对象表示HTTP请求或响应的头部信息。http.Header是一个map[string][]string类型的结构&#xff0c;用于存储键值对&#xff0c;其中键表示HTTP头字段的名称&#xff0c;值是一个字符串切片&#xff0c;可以存储多个相同名称的头…

centos openssh升级

注意&#xff1a; openssh升级异常会造成服务失联&#xff0c;如果在允许的情况下可以安装talent服务&#xff0c;使用talent升级&#xff1b; 如果不能安装talent服务&#xff0c;可以打开多个终端&#xff0c;启动ping命令&#xff0c;防止升级终端失败后&#xff0c;作为备用…

巧用excel实现试卷向表格的转换

MID($E$10,FIND(D14,$E$10,1),FIND(D15,$E$10,1)-FIND(D14,$E$10,1)) MID($E$10,FIND(D15,$E$10,1),FIND(D16,$E$10,1)-FIND(D15,$E$10,1)) 中华人民共和国司法部

HarmonyOS/OpenHarmony原生应用-ArkTS万能卡片组件Span

作为Text组件的子组件&#xff0c;用于显示行内文本的组件。无子组件 一、接口 Span(value: string | Resource) 从API version 9开始&#xff0c;该接口支持在ArkTS卡片中使用。 参数&#xff1a; 参数名 参数类型 必填 参数描述 value string | Resource 是 文本内…

c++视觉---中值滤波处理

中值滤波&#xff08;Median Filter&#xff09;是一种常用的非线性平滑滤波方法&#xff0c;用于去除图像中的噪声。它不像线性滤波&#xff08;如均值滤波或高斯滤波&#xff09;那样使用权重来计算平均值或加权平均值&#xff0c;而是选择滤波窗口内的像素值中的中间值作为输…

docker搭建jenkins

1.拉取镜像 docker pull jenkinsci/blueocean 2.启动容器 docker run -d -u root -p 8666:8080 -p 50000:50000 -v /var/jenkins_home:/var/jenkins_home -v /etc/localtime:/etc/localtime --name MyJenkins jenkinsci/blueocean 3.访问ip:port,就能访问了 4.docker logs 容器…

主从复制的实现方案

读写分离技术架构图 实现读写分离的技术架构选型如上;需要自己去实践主从复制;为了节省资源&#xff0c;当然系统并发量并没有那么大,选择一主一丛;强制读主库,为了解决主从同步延迟带来的影响&#xff1b;对于实时性要求高的强制读主库&#xff1b;GTID 主要是一种事务标识技术…

linux centos运行C语言程序

1.安装gcc。 yum install gcc [rootlinux ~]# yum install gcc 已加载插件&#xff1a;fastestmirror, langpacks Repository updates is listed more than once in the configuration Repository extras is listed more than once in the configuration Loading mirror spee…

Ubuntu 22.04 铭瑄 MS-终结者 B760M D4 WIFI 驱动安装

wifi芯片为intel ax101ngw 直接装最新稳定版本linux 6.5.6 源码地址 https://cdn.kernel.org/pub/linux/kernel/v6.x/linux-6.5.6.tar.xz .config参考 /boot/config-6.2.0-33-generic修改&#xff0c;完整内容如下 # # Automatically generated file; DO NOT EDIT. # Linux/…

OpenCV Python – 使用SIFT算法实现两张图片的特征匹配

OpenCV Python – 使用SIFT算法实现两张图片的特征匹配 1.要实现在大图中找到任意旋转、缩放等情况下的小图位置&#xff0c;可以使用特征匹配算法&#xff0c;如 SIFT (尺度不变特征变换) 或 SURF (加速稳健特征)。这些算法可以在不同尺度和旋转情况下寻找匹配的特征点 impo…

docker搭建nginx

1.docker pull nginx 2.docker run --name nginx-test -p 8082:80 -d nginx 3.访问ip:8082

Linux 系统性能瓶颈分析(超详细)

Author&#xff1a;rab 目录 前言一、性能指标1.1 进程1.1.1 进程定义1.1.2 进程状态1.1.3 进程优先级1.1.4 进程与程序间的关系1.1.5 进程与进程间的关系1.1.6 进程与线程的关系 1.2 内存1.2.1 物理内存与虚拟内存1.2.2 页高速缓存与页写回机制1.2.3 Swap Space 1.3 文件系统1…

在PicGo上使用github图床解决typora上传csdn图片不显示问题(保姆级教程)

文章目录 在PicGo上使用github图床解决typora上传csdn图片不显示问题&#xff08;保姆级教程&#xff09;1、typora上传csdn图片不显示&#xff08;外链图片转存失败&#xff09;2、PicGo2.1、PicGo下载2.2、PicGo使用2.2.1、对PicGo完成基本的配置2.2.2、配置github图床2.2.3、…

R实现地图相关图形绘制

大家好&#xff0c;我是带我去滑雪&#xff01; 地图相关图形绘制具有许多优点&#xff0c;这些优点使其在各种领域和应用中非常有用。例如&#xff1a;地图相关图形提供了一种直观的方式来可视化数据&#xff0c;使数据更容易理解和分析。通过地图&#xff0c;可以看到数据的空…

【C++笔记】C++三大特性之多态的概念、定义及使用

1.多态的概念 多态即多种形态&#xff0c;对于C程序设计中指的是在类的实例化对象中&#xff0c;当不同的对象去完成某个行为时会出现不同的状态。 2.多态的分类 静态的多态&#xff1a;函数重载&#xff0c;看起来调用同一个函数有不同行为。静态&#xff1a;原理是编译时实…

UE4 Unlua 初使用小记

function M:Construct()print(Hello World)print(self.Va)local mySubsystem UE4.UHMSGameInstanceSubsystemUE4.UKismetSystemLibrary.PrintString(self,"Get Click Msg From UnLua ")end unlua中tick不能调用的问题&#xff1a; 把该类的Event Tick为灰色显示的删…

【数据库审计】2023年数据库审计厂家汇总

我们大家都知道数据库审计的重要意义&#xff0c;不仅可以满足等保合规&#xff0c;还能进行风险告警&#xff0c;保障数据安全。那你知道目前市面上数据库审计厂家有哪些吗&#xff1f;这里小编就给大家汇总一下。 2023年数据库审计厂家汇总 1、行云管家 2、安恒信息 3、…

MongoDB-介绍与安装部署

介绍与安装部署 1.MongoDB简介a) 体系结构b) 数据模型c) MongoDB的特点c.1) 高性能c.2) 高性可用性c.3) 高拓展性c.4) 丰富的查询支持 2.单机部署a) Windows系统中的安装启动b) Shell连接(mongo命令)c) Linux系统中的安装启动和连接 1.MongoDB简介 MongoDB是一个开源、高性能、…