机器学习7:逻辑回归

一、说明

        逻辑回归模型是处理分类问题的最常见机器学习模型之一。二项式逻辑回归只是逻辑回归模型的一种类型。它指的是两个变量的分类,其中概率用于确定二元结果,因此“二项式”中的“bi”。结果为真或假 — 0 或 1。

        二项式逻辑回归的一个例子是预测人群中 COVID-19 的可能性。一个人要么感染了COVID-19,要么没有,必须建立一个阈值以尽可能准确地区分这些结果。

二、sigmoid函数

        这些预测不适合一条线,就像线性回归模型一样。相反,逻辑回归模型拟合到右侧所示的 sigmoid 函数。

        对于每个 x,生成的 y 值表示结果为 True 的概率。在 COVID-19 示例中,这表示医生对某人感染病毒的信心。在右图中,阴性结果为蓝色,阳性结果为红色。

图片来源:作者

三、过程

        要进行二项式逻辑回归,我们需要做各种事情:

  1. 创建训练数据集。
  2. 使用 PyTorch 创建我们的模型。
  3. 将我们的数据拟合到模型中。

        逻辑回归问题的第一步是创建训练数据集。首先,我们应该设置一个种子来确保我们的随机数据的可重复性。

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seed

我们必须使用 PyTorch 的线性模型,因为我们正在处理一个输入 x 和一个输出 y。因此,我们的模型是线性的。为此,我们将使用 PyTorch 的函数:Linear

model = Linear(in_features=1, out_features=1) # use a linear model

接下来,我们必须生成蓝色 X 和红色 X 数据,确保将它们从行向量重塑为列向量。蓝色的在 0 到 7 之间,红色的在 7 到 10 之间。对于 y 值,蓝点表示 COVID-19 测试阴性,因此它们都将是

  1. 对于红点,它们代表 COVID-19 测试呈阳性,因此它们将为 1。下面是代码及其输出:
blue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y values

现在,我们的代码应如下所示:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seedmodel = Linear(in_features=1, out_features=1) # use a linear modelblue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y values

四、优化

        我们将使用梯度下降过程来优化 S 形函数的损失。损失是根据函数拟合数据的优度计算的,数据由 S 形曲线的斜率和截距控制。我们需要梯度下降来找到最佳斜率和截距。

        我们还将使用二进制交叉熵(BCE)作为我们的损失函数,或对数损失函数。对于一般的逻辑回归,不包含对数的损失函数将不起作用。

        为了实现BCE作为我们的损失函数,我们将它设置为我们的标准,并将随机梯度下降作为我们优化它的手段。由于这是我们将要优化的函数,我们需要传入模型参数和学习率。

epochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descent

        现在,我们准备开始梯度下降以优化我们的损失。我们必须将梯度归零,通过将我们的数据插入 sigmoid 函数来找到 y-hat 值,计算损失,并找到损失函数的梯度。然后,我们必须迈出一步,确保存储我们的新斜率并为下一次迭代进行拦截。

optimizer.zero_grad()
Yhat = torch.sigmoid(model(X)) 
loss = criterion(Yhat,Y)
loss.backward()
optimizer.step() 

五、收尾

        为了找到最佳斜率和截距,我们本质上是在训练我们的模型。我们必须对多次迭代或纪元应用梯度下降。在此示例中,我们将使用 2,000 个纪元进行演示。

epochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descentfor i in range(epochs):optimizer.zero_grad()Yhat = torch.sigmoid(model(X))loss = criterion(Yhat,Y)loss.backward()optimizer.step()print(f"epoch: {i+1}")print(f"loss: {loss: .5f}")print(f"slope: {model.weight.item(): .5f}")print(f"intercept: {model.bias.item(): .5f}")print()

将所有代码片段放在一起,我们应该得到以下代码:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seedmodel = Linear(in_features=1, out_features=1) # use a linear modelblue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y valuesepochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descentfor i in range(epochs):optimizer.zero_grad()Yhat = torch.sigmoid(model(X))loss = criterion(Yhat,Y)loss.backward()optimizer.step()print(f"epoch: {i+1}")print(f"loss: {loss: .5f}")print(f"slope: {model.weight.item(): .5f}")print(f"intercept: {model.bias.item(): .5f}")print()
两千个时期后的最终输出:epoch: 2000
loss:  0.53861
slope:  0.61276
intercept: -3.17314

两千个时期后的最终输出:

epoch: 2000
loss:  0.53861
slope:  0.61276
intercept: -3.17314 

六、可视化

        最后,我们可以将数据与 sigmoid 函数一起绘制,以获得以下可视化效果:

x = np.arange(0,10,.1)
y = model.weight.item()*x + model.bias.item()plt.plot(x, 1/(1 + np.exp(-y)), color="green")plt.xlim(0,10)
plt.scatter(blue_x, blue_y, color="blue")
plt.scatter(red_x, red_y, color="red")plt.show()

图片来源:作者

七、局限性

        二元分类的最大问题之一是需要阈值。在逻辑回归的情况下,此阈值应为 x 值,其中 y 为 50%。我们试图回答的问题是将阈值放在哪里?

        在 COVID-19 测试的情况下,原始示例说明了这种困境。如果我们将阈值设置为 x=5,我们可以清楚地看到应该是红色的蓝点和应该是蓝色的红点。

        悬垂的红点称为误报,即模型错误地预测正类的区域。悬垂的蓝点称为假阴性 - 模型错误地预测负类的区域。

 八、结论

        成功的二项式逻辑回归模型将减少假阴性的数量,因为这些假阴性通常会导致最大的危险。患有COVID-19但检测呈阴性对他人的健康和安全构成严重风险。

        通过对可用数据使用二项式逻辑回归,我们可以确定放置阈值的最佳位置,从而有助于减少不确定性并做出更明智的决策。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/93873.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【PickerView案例12-info_plist-PCH文件介绍 Objective-C语言】

一、给大家介绍一下我们项目的一些文件: 1.这个呢,是项目的基础文件: 一些类啊: 一些图片啊: 还有加载图片, 最主要,就是这个东西:info.plist:文件 info.plist: 2.那,需要大家了解一点,关于它的历史啊: 我们现在用的时候,都是从xcode6.4开始的, 或者说,直…

Python 数据分析与挖掘(一)

Python 数据分析与挖掘(数据探索) 数据探索 1.1 需要掌握的工具(库) 1.1.1 Nump库 Numpy 提供多维数组对象和各种派生对象(类矩阵),利用应用程序接口可以实现大量且繁琐的数据运算。可以构建…

Linux 5种网络模型

[参考]:《黑马程序员Redis》https://www.bilibili.com/video/BV1cr4y1671t/?p166&share_sourcecopy_web&vd_source9e65300ccca322aeb367bb1eb677b0fc [参考]:《操作系统》 [参考]:《UNIX网络编程》 为了避免用户应用导致冲突甚至内…

基于SSM的奶茶店管理系统

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

WebSocket实战之三遇上PAC

一、前言 前两天销售数据实时刷新功能开发测试完成,开开心心部署到生产环境,然后直接懵逼傻眼了,竟然连接不上WebSocket服务端,浏览器端请求头报 Provisional headers are shown 信息,然后采用一系列操作排查问题。 …

89、Redis 的 value 所支持的数据类型(String、List、Set、Zset、Hash)---->Zset 相关命令

本次讲解要点: ** Set相关命令:是指value中的数据类型** 启动redis服务器: 打开小黑窗: C:\Users\JH>e: E:>cd E:\install\Redis6.0\Redis-x64-6.0.14\bin E:\install\Redis6.0\Redis-x64-6.0.14\bin>redis-server.exe …

创建型设计模式 原型模式 建造者模式 创建者模式对比

创建型设计模式 单例 工厂模式 看这一篇就够了_软工菜鸡的博客-CSDN博客 4.3 原型模式 4.3.1 概述 用一个已经创建的实例作为原型,通过复制该原型对象来创建一个和原型对象相同的新对象。 4.3.2 结构 原型模式包含如下角色: 抽象原型类:规定了…

excel中将一个sheet表根据条件分成多个sheet表

有如下excel表,要求:按月份将每月的情况放在一个sheet中。 目测有6个月,就应该有6个sheet,每个sheet中体现本月的情况。 一、首先增加一个辅助列,月份,使用month函数即可。 填充此列所有。然后复制【月份】…

力扣练习——链表在线OJ

目录 提示: 一、移除链表元素 题目: 解答: 二、反转链表 题目: 解答: 三、找到链表的中间结点 题目: 解答: 四、合并两个有序链表(经典) 题目: 解…

Redis与分布式-分布式锁

接上文 Redis与分布式-集群搭建 1.分布式锁 为了解决上述问题,可以利用分布式锁来实现。 重新复制一份redis,配置文件都是刚下载时候的不用更改,然后启动redis服务和redis客户。 redis存在这样的命令:和set命令差不多&#xff0…

十四天学会C++之第二天(函数和库)

1. 函数的定义和调用 在C中,函数是组织和结构化代码的关键工具之一。它们允许您将一段代码封装成一个可重复使用的模块,这有助于提高代码的可读性和维护性。 为什么使用函数? 函数在编程中的作用不可小觑。它们有以下几个重要用途&#xf…

ASUS华硕飞行堡垒5笔记本FX504GM_FX80GM原装出厂Windows10系统

系统自带所有驱动、出厂主题壁纸、系统属性华硕专属LOGO标志、Office办公软件、MyASUS华硕电脑管家等预装程序 下载链接:https://pan.baidu.com/s/1C8vPvqiwqoUY3PxC915LXg?pwdv079

基于被囊群优化的BP神经网络(分类应用) - 附代码

基于被囊群优化的BP神经网络(分类应用) - 附代码 文章目录 基于被囊群优化的BP神经网络(分类应用) - 附代码1.鸢尾花iris数据介绍2.数据集整理3.被囊群优化BP神经网络3.1 BP神经网络参数设置3.2 被囊群算法应用 4.测试结果&#x…

小白自己​制作一个苹果.ios安卓.apk文件app应用手机下载的代码合并文件一码双端的落地页面详细教程

小白自己制作一个苹果.ios安卓.apk文件app应用手机下载的代码落地页面详细教程 图片取自这里哈 我们在这篇文章中教你如何制作一个手机下载引导落地页。这个落地页将可以自动识别访问者使用的是安卓还是苹果设备,并引导下载相应的应用程序。让我们按照以下步骤一…

Selenium 浏览器坐标转桌面坐标

背景: 做图表自动化项目需要做拖拽操作,但是selenium提供的拖拽API无效,因此借用pyautogui实现拖拽,但是pyautogui的拖拽是基于Windows桌面坐标实现的,另外浏览器中的坐标与windows桌面坐标并不是一比一对应的关系&am…

【计算机组成原理】考研真题攻克与重点知识点剖析 - 第 1 篇:计算机系统概述

前言 本文基础知识部分来自于b站:分享笔记的好人儿的思维导图,感谢大佬的开源精神,习题来自老师划的重点以及考研真题。此前我尝试了完全使用Python或是结合大语言模型对考研真题进行数据清洗与可视化分析,本人技术有限&#xff…

基于SpringBoot+MyBatis实现的个人博客系统(一)

这篇主要讲解一下如何基于SpringBoot和MyBatis技术实现一个简易的博客系统(前端页面主要是利用CSS,HTML进行布局书写),前端的静态页面代码可以直接复制粘贴,后端的接口以及前端发送的Ajax请求需要自己书写. 博客系统需要完成的接口: 注册登录博客列表页展示博客详情页展示发布博…

如何在 Google Earth 中创建轨迹、路线并制作动画

如何创建航迹 https://kurviger.de/en Google 地球飞行教程(天桥动画) 选择合适的点 (可调整视图快照)点击录制,依次点击图标即可

WebSocket实战之六心跳重连机制

一、前言 WebSocket应用部署到生产环境,我们除了会碰到因为经过代理服务器无法连接的问题(注:该问题可以通过搭建WSS来解决,具体配置请看 WebSocket实战之四WSS配置 ),另外一个问题就是外网环境不稳定经常…

基于SSM的餐厅点菜管理系统的设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用Vue技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…