机器学习7:pytorch的逻辑回归

一、说明

        逻辑回归模型是处理分类问题的最常见机器学习模型之一。二项式逻辑回归只是逻辑回归模型的一种类型。它指的是两个变量的分类,其中概率用于确定二元结果,因此“二项式”中的“bi”。结果为真或假 — 0 或 1。

        二项式逻辑回归的一个例子是预测人群中 COVID-19 的可能性。一个人要么感染了COVID-19,要么没有,必须建立一个阈值以尽可能准确地区分这些结果。

二、sigmoid函数

        这些预测不适合一条线,就像线性回归模型一样。相反,逻辑回归模型拟合到右侧所示的 sigmoid 函数。

        对于每个 x,生成的 y 值表示结果为 True 的概率。在 COVID-19 示例中,这表示医生对某人感染病毒的信心。在右图中,阴性结果为蓝色,阳性结果为红色。

图片来源:作者

三、过程

        要进行二项式逻辑回归,我们需要做各种事情:

  1. 创建训练数据集。
  2. 使用 PyTorch 创建我们的模型。
  3. 将我们的数据拟合到模型中。

        逻辑回归问题的第一步是创建训练数据集。首先,我们应该设置一个种子来确保我们的随机数据的可重复性。

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seed

我们必须使用 PyTorch 的线性模型,因为我们正在处理一个输入 x 和一个输出 y。因此,我们的模型是线性的。为此,我们将使用 PyTorch 的函数:Linear

model = Linear(in_features=1, out_features=1) # use a linear model

接下来,我们必须生成蓝色 X 和红色 X 数据,确保将它们从行向量重塑为列向量。蓝色的在 0 到 7 之间,红色的在 7 到 10 之间。对于 y 值,蓝点表示 COVID-19 测试阴性,因此它们都将是

  1. 对于红点,它们代表 COVID-19 测试呈阳性,因此它们将为 1。下面是代码及其输出:
blue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y values

现在,我们的代码应如下所示:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seedmodel = Linear(in_features=1, out_features=1) # use a linear modelblue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y values

四、优化

        我们将使用梯度下降过程来优化 S 形函数的损失。损失是根据函数拟合数据的优度计算的,数据由 S 形曲线的斜率和截距控制。我们需要梯度下降来找到最佳斜率和截距。

        我们还将使用二进制交叉熵(BCE)作为我们的损失函数,或对数损失函数。对于一般的逻辑回归,不包含对数的损失函数将不起作用。

        为了实现BCE作为我们的损失函数,我们将它设置为我们的标准,并将随机梯度下降作为我们优化它的手段。由于这是我们将要优化的函数,我们需要传入模型参数和学习率。

epochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descent

        现在,我们准备开始梯度下降以优化我们的损失。我们必须将梯度归零,通过将我们的数据插入 sigmoid 函数来找到 y-hat 值,计算损失,并找到损失函数的梯度。然后,我们必须迈出一步,确保存储我们的新斜率并为下一次迭代进行拦截。

optimizer.zero_grad()
Yhat = torch.sigmoid(model(X)) 
loss = criterion(Yhat,Y)
loss.backward()
optimizer.step() 

五、收尾

        为了找到最佳斜率和截距,我们本质上是在训练我们的模型。我们必须对多次迭代或纪元应用梯度下降。在此示例中,我们将使用 2,000 个纪元进行演示。

epochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descentfor i in range(epochs):optimizer.zero_grad()Yhat = torch.sigmoid(model(X))loss = criterion(Yhat,Y)loss.backward()optimizer.step()print(f"epoch: {i+1}")print(f"loss: {loss: .5f}")print(f"slope: {model.weight.item(): .5f}")print(f"intercept: {model.bias.item(): .5f}")print()

将所有代码片段放在一起,我们应该得到以下代码:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seedmodel = Linear(in_features=1, out_features=1) # use a linear modelblue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y valuesepochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descentfor i in range(epochs):optimizer.zero_grad()Yhat = torch.sigmoid(model(X))loss = criterion(Yhat,Y)loss.backward()optimizer.step()print(f"epoch: {i+1}")print(f"loss: {loss: .5f}")print(f"slope: {model.weight.item(): .5f}")print(f"intercept: {model.bias.item(): .5f}")print()
两千个时期后的最终输出:epoch: 2000
loss:  0.53861
slope:  0.61276
intercept: -3.17314

两千个时期后的最终输出:

epoch: 2000
loss:  0.53861
slope:  0.61276
intercept: -3.17314 

六、可视化

        最后,我们可以将数据与 sigmoid 函数一起绘制,以获得以下可视化效果:

x = np.arange(0,10,.1)
y = model.weight.item()*x + model.bias.item()plt.plot(x, 1/(1 + np.exp(-y)), color="green")plt.xlim(0,10)
plt.scatter(blue_x, blue_y, color="blue")
plt.scatter(red_x, red_y, color="red")plt.show()

图片来源:作者

七、局限性

        二元分类的最大问题之一是需要阈值。在逻辑回归的情况下,此阈值应为 x 值,其中 y 为 50%。我们试图回答的问题是将阈值放在哪里?

        在 COVID-19 测试的情况下,原始示例说明了这种困境。如果我们将阈值设置为 x=5,我们可以清楚地看到应该是红色的蓝点和应该是蓝色的红点。

        悬垂的红点称为误报,即模型错误地预测正类的区域。悬垂的蓝点称为假阴性 - 模型错误地预测负类的区域。

 八、结论

        成功的二项式逻辑回归模型将减少假阴性的数量,因为这些假阴性通常会导致最大的危险。患有COVID-19但检测呈阴性对他人的健康和安全构成严重风险。

        通过对可用数据使用二项式逻辑回归,我们可以确定放置阈值的最佳位置,从而有助于减少不确定性并做出更明智的决策。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/97198.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HarmonyOS学习路之方舟开发框架—学习ArkTS语言(状态管理 八)

其他状态管理概述 除了前面章节提到的组件状态管理和应用状态管理,ArkTS还提供了Watch和$$来为开发者提供更多功能: Watch用于监听状态变量的变化。$$运算符:给内置组件提供TS变量的引用,使得TS变量和内置组件的内部状态保持同步…

Python环境安装

1、下载python安装包 (1)可以从官网下载需要的版本:Python Releases for Windows | Python.org (2)或者从我的百度网盘下载3.11.1版本: 链接:https://pan.baidu.com/s/1qNH3KU0iHIi-tS9wYBVrtQ …

freertos信号量之二值信号量

freertos信号量之二值信号量 简介例程 简介 FreeRTOS的二值信号量(Binary Semaphore)是用于实现进程间同步和临界资源保护的重要工具。以下是一些二值信号量的常用函数及其说明: 1)xSemaphoreCreateBinary() 创建一个二值信号量…

【论文阅读】通过3D和2D网络的交叉示教实现稀疏标注的3D医学图像分割(CVPR2023)

目录 前言方法标注3D-2D Cross Teaching伪标签选择Hard-Soft Confidence Threshold Consistent Prediction Fusion 结论 论文:3D Medical Image Segmentation with Sparse Annotation via Cross-Teaching between 3D and 2D Networks 代码:https://githu…

95、Spring Data Redis 之使用RedisTemplate 实现自定义查询 及 Spring Data Redis 的样本查询

Spring Data Redis 之使用RedisTemplate 实现自定义查询 Book实体类 原本的接口,再继承我们自定义的接口 自定义查询接口----CustomBookDao 实现类:CustomBookDaoImpl 1、自定义添加hash对象的方法 2、自定义查询价格高于某个点的Book对象 测试&a…

【JavaEE】线程安全的集合类

文章目录 前言多线程环境使用 ArrayList多线程环境使用队列多线程环境使用哈希表1. HashTable2. ConcurrentHashMap 前言 前面我们学习了很多的Java集合类,像什么ArrayList、Queue、HashTable、HashMap等等一些常用的集合类,之前使用这些都是在单线程中…

Amber中的信息传递——章节1.1-第二部分

Amber中的信息传递在实操中共分为预备程序、模拟程序和分析程序三个部分,具体相关文件如下: 1. 预备程序 **LEaP:**是在 Amber 中创建新系统或修改现有系统的主要程序。 它有命令行程序 tleap 和图形用户界面 xleap 两种形式。它结合了 Ambe…

【ARM CoreLink 系列 4 -- NIC-400 控制器详细介绍】

文章目录 1.1 ARM NIC-400(Network interconnect)1.1.1 NIC-400 系统框图1.1.2 NIC-400 Network Interconnect1.2 NIC-400 特点1.2.1 QoS-400 Advanced Quality of Service1.2.2 QVN-400 QoS Virtual Networks1.2.3 TLX-400 Thin Links1.3 NIC-400 Top1.4 NIC-400 Terminology1…

RabbitMQ之Fanout(扇形) Exchange解读

目录 基本介绍 适用场景 springboot代码演示 演示架构 工程概述 RabbitConfig配置类:创建队列及交换机并进行绑定 MessageService业务类:发送消息及接收消息 主启动类RabbitMq01Application:实现ApplicationRunner接口 基本介绍 Fa…

使用华为eNSP组网试验⑸-访问控制

今天练习使用华为sNSP模拟网络设备上的访问控制,这样的操作我经常在华为的S7706、S5720、S5735或者H3C的S5500、S5130、S7706上进行,在网络设备上根据情况应用访问控制的策略是一个网管必须熟练的操作,只是在真机上操作一般比较谨慎&#xff…

『力扣每日一题14』:消失的数字

昨天忙过头,等想起来已经 12 点多了,于是乎断更了。在这里先跟广大读者说声抱歉,并且稍后我会再更一篇。 一、题目 数组nums包含从0到n的所有整数,但其中缺了一个。请编写代码找出那个缺失的整数。你有办法在O(n)时间内完成吗&…

微服务技术栈-Gateway服务网关

文章目录 前言一、为什么需要网关二、Spring Cloud Gateway三、断言工厂和过滤器1.断言工厂2.过滤器3.全局过滤器4.过滤器执行顺序 四、跨域问题总结 前言 在之前的文章中我们已经介绍了微服务技术中eureka、nacos、ribbon、Feign这几个组件,接下来将介绍另外一个组…

using 语句 - 确保正确使用可释放对象

using语句块的几种用法。 1、using 语句可确保正确使用 IDisposable 实例&#xff1a; var numbers new List<int>(); using (StreamReader reader File.OpenText("numbers.txt")) {string line;while ((line reader.ReadLine()) is not null){if (int.Try…

Android源码下载

文章目录 一、Android源码下载 一、Android源码下载 AOSP 是 Android Open Source Project 的缩写。 git 常用命令总结 git 远程仓库相关的操作 # 查看 remote.origin.url 配置项的值 git config --list Android9.0之前代码在线查看地址&#xff1a;http://androidxref.com/ …

【LeetCode高频SQL50题-基础版】打卡第2天:第11-15题

文章目录 【LeetCode高频SQL50题-基础版】打卡第2天&#xff1a;第11-15题⛅前言 员工奖金&#x1f512;题目&#x1f511;题解 学生们参加各科测试的次数&#x1f512;题目&#x1f511;题解 至少有5名直接下属的经理&#x1f512;题目&#x1f511;题解 确认率&#x1f512;题…

使用python利用merge+sort函数对excel进行连接并排序

好久没更新了&#xff0c;天天玩短视频了。现在发现找点学习资料真的好难。 10.1期间偶然拿到一本书 本书是2022年出版的&#xff0c;看了一下不错&#xff0c;根据上面的案例结合&#xff0c;公司经营整合案例&#xff0c;分享一下。 数据内容来源于书中内容&#xff0c;仅供…

docker部署Vaultwarden密码共享管理系统

Vaultwarden是一个开源的密码管理器&#xff0c;它是Bitwarden密码管理器的自托管版本。它提供了类似于Bitwarden的功能&#xff0c;允许用户安全地存储和管理密码、敏感数据和身份信息。 Vaultwarden的主要特点包括&#xff1a; 1. 安全的数据存储&#xff1a;Vaultwarden使…

手机投屏电脑软件AirServer5.6.3.0最新免费版本下载

随着智能手机的普及&#xff0c;越来越多的人喜欢用手机观看视频、玩游戏、办公等。但是&#xff0c;有时候手机屏幕太小&#xff0c;不够清晰&#xff0c;也不方便操作。这时候&#xff0c;如果能把手机屏幕投射到电脑上&#xff0c;就可以享受更大的视野&#xff0c;更流畅的…

【javaweb】学习日记Day11 - tlias智能管理系统 - 文件上传 新增 修改员工 配置文件

目录 一、员工管理功能开发 1、新增员工 postman报错500的原因 &#xff08;1&#xff09;Controller类 &#xff08;2&#xff09;Service类 &#xff08;3&#xff09;Mapper类 2、根据ID查询 &#xff08;1&#xff09;Controller类 &#xff08;2&#xff09;Serv…

第11章 Redis(一)

11.1 谈谈你对Redis的理解 难度:★★★ 重点:★★ 白话解析 对Redis的理解无非从三个方面去说一说:背景,是什么,特性。 背景:数据直接存磁盘太慢了,虽然MySQL用到了BufferPool等缓存,但是为了保证数据不丢失,MySQL采用的RedoLog依然要直接写磁盘。所以,数据的存储就…