Pytorch深度学习实践(5)逻辑回归

逻辑回归

逻辑回归主要是解决分类问题

  • 回归任务:结果是一个连续的实数
  • 分类任务:结果是一个离散的值

分类任务不能直接使用回归去预测,比如在手写识别中(识别手写 0 − − 9 0 -- 9 09),因为各个类别之间没有大小之差。

因此,对于分类问题,我们最终的输出是个概率,即属于某个类别的概率是多少,然后从概率集合里找最大值,作为当前预测的结果

下载MNIST数据集

import torchvision
train_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=True, download=True)
test_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=False, dowload=True)
  • 通过train参数来指定训练集和测试集

逻辑回归

将之前的学习时长—考试分数转化为二分类任务,即学习时长—是否通过考试

x(hours)y(pass/fail)
10(fail)
20(fail)
31(pass)
4?

其中, P ( y ^ = 1 ) + P ( y ^ = 0 ) = 1 P(\hat y = 1) + P(\hat y = 0) = 1 P(y^=1)+P(y^=0)=1

当输出的概率在 0.5 0.5 0.5附近时,即模型不确定,因此通常会输出一个不确定的值

对于二分类任务,逻辑回归会先使用回归,生成一个得分值,即落在实数集区间内,然后再使用 s i g m o i d sigmoid sigmoid函数,将得分值映射到 [ 0 , 1 ] [0, 1] [0,1]区间内,得到预测概率

s i g m o i d sigmoid sigmoid函数
σ ( x ) = 1 1 + e − x \sigma (x) = \frac{1}{1+e^{-x}} σ(x)=1+ex1
在这里插入图片描述

S i g m o i d Sigmoid Sigmoid常用来做二分类任务,常具备三个特征:

  • 饱和函数
  • 单调递增
  • 有极限

当我们使用线性回归来得到逻辑回归的得分值时,逻辑回归模型的函数定义就如下所示:
y ^ = σ ( x ∗ ω + b ) \hat y = \sigma (x*\omega + b) y^=σ(xω+b)

损失函数

线性回归使用的损失函数是计算预测值和真实值之差

而对于逻辑回归,由于我们得到的是概率,是一个 0 − 1 0-1 01分布,因此需要修改损失函数
l o s s = − ( y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ) loss = -(ylog\hat y + (1-y)log(1-\hat y)) loss=(ylogy^+(1y)log(1y^))
即我们比较的是分布之间的差异

交叉熵 c r o s s − e n t r o p y cross-entropy crossentropy

存在两个分布 P D 1 ( x ) P_{D1}(x) PD1(x) P D 2 ( x ) P_{D2}(x) PD2(x)

两个分布的差异程度使用公式: ∑ i = 1 n P D 1 ( x i ) l n P D 2 ( x i ) \sum_{i=1}^{n}P_{D1}(x_i)lnP_{D2}(x_i) i=1nPD1(xi)lnPD2(xi) 来衡量

上述公式越大时,两个分布的差异越小

模型的改变

模型构造的改变
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred

需要先将输入写入到linear()线性模型中,再使用Sigmoid()函数

模型损失函数的改变

使用交叉熵函数BCELoss

criterion = torch.nn.BCELoss(size_average=False)

整体代码

import torch
import matplotlib.pyplot as plt
import numpy as np########## 数据集准备 ##########
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])########## 模型定义 ##########
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = torch.sigmoid(self.linear(x))return y_predmodel = LogisticRegressionModel()########## 损失函数和优化器的设置 ##########
criterion = torch.nn.BCELoss(size_average=False) # BCELoss -- 交叉熵函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)########## 模型训练 ##########
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()########## 模型测试 ##########
x = np.linspace(0, 10, 200)
x_test = torch.Tensor(x).view((200, 1)) # view()相当于reshape
y_test = model(x_test)
y = y_test.data.numpy()  # 转化为np类型
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], 'r--')
plt.xlabel("Hours")
plt.ylabel("Probability of Pass")
plt.grid()
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/50505.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

了解Java虚拟机(JVM)

前言👀~ 上一章我们介绍网络原理相关的知识点,今天我们浅浅来了解一下java虚拟机JVM JVM( Java Virtual Machine ) JVM内存区域划分 方法区/元数据区(线程共享) 堆(线程共享) 虚…

iOS object-C 解答算法:找到所有数组中消失的数字(leetCode-448)

找到所有数组中消失的数字(leetCode-448) 题目如下图:(也可以到leetCode上看完整题目,题号448) 光看题看可能有点难以理解,我们结合示例1来理解一下这道题. 有8个整数的数组 nums [4,3,2,7,8,2,3,1], 求在闭区间[1,8]范围内(即1,2,3,4,5,6,7,8)的数字,哪几个没有出现在数组 …

Spring Boot的Web开发

目录 Spring Boot的Web开发 1.静态资源映射规则 第一种静态资源映射规则 2.enjoy模板引擎 3.springMVC 3.1请求处理 RequestMapping DeleteMapping 删除 PutMapping 修改 GetMapping 查询 PostMapping 新增 3.2参数绑定 一.支持数据类型: 3.3常用注解 一.Request…

网闸(Network Gatekeeper或Security Gateway)

本心、输入输出、结果 文章目录 网闸(Network Gatekeeper或Security Gateway)前言网闸主要功能网闸工作原理网闸使用场景网闸网闸(Network Gatekeeper或Security Gateway) 编辑 | 简简单单 Online zuozuo 地址 | https://blog.csdn.net/qq_15071263 如果觉得本文对你有帮助…

c++如何理解多态与虚函数

目录 **前言****1. 何为多态**1.1 **编译时多态**1.1.1 函数重载1.1.2 模板 **1.2 运行时多态****1.2.1 虚函数****1.2.2 为什么要用父类指针去调用子类函数** **2. 注意****2.1 基类的析构函数应写为虚函数****2.2 构造函数不能设为虚函数** **本文参考** 前言 在学习 c 的虚…

C++ | Leetcode C++题解之第275题H指数II

题目&#xff1a; 题解&#xff1a; class Solution { public:int hIndex(vector<int>& citations) {int n citations.size();int left 0, right n - 1;while (left < right) {int mid left (right - left) / 2;if (citations[mid] > n - mid) {right m…

全球“微软蓝屏”事件:IT基础设施韧性与安全性的考验

人不走空 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌赋&#xff1a;斯是陋室&#xff0c;惟吾德馨 目录 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌…

git配置环境变量

一.找到git安装目录 打开此git安装目录下的bin文件&#xff0c;复制此文件路径 二.配置环境变量 2.1 右键点击此电脑的属性栏 2.2 点击高级系统配置 2.3 点击环境变量 2.4 按图中步骤进行配置 三.配置完成 win r 输入cmd打开终端 终端页面中输入 git --version 如图所示…

20240725java的Controller、DAO、DO、Mapper、Service层、反射、AOP注解等内容的学习

在Java开发中&#xff0c;‌controller、‌dao、‌do、‌mapper等概念通常与MVC&#xff08;‌Model-View-Controller&#xff09;‌架构和分层设计相关。‌这些概念各自承担着不同的职责&#xff0c;‌共同协作以构建和运行一个应用程序。‌以下是这些概念的解释&#xff1a;‌…

Redis的两种持久化方式---RDB、AOF

rdb其实就是一种快照持久化的方式&#xff0c;它会将Redis在某个时间点的所有的数据状态以二进制的方式保存到硬盘上的文件当中&#xff0c;它相对于aof文件会小很多&#xff0c;因为知识某个时间点的数据&#xff0c;当然&#xff0c;这就会导致它的实时性不够高&#xff0c;如…

【游戏制作】使用Python创建一个美观的贪吃蛇游戏,附完整代码

目录 前言 项目运行结果 项目简介 环境配置 代码实现 主体结构 主要功能详解 界面和菜单 控制蛇的移动 食物生成和碰撞检测 游戏结束 运行游戏 总结 前言 贪吃蛇游戏是一款经典的电脑游戏&#xff0c;许多人都曾经玩过。今天我们将使用Python和ttkbootstrap库来实…

Mysql注意事项(一)

Mysql注意事项&#xff08;一&#xff09; 最近回顾了一下MySQL&#xff0c;发现了一些MySQL需要注意的事项&#xff0c;同时也作为学习笔记&#xff0c;记录下来。–2020年05月13日 1、通配符* 检索所有的列。 不建议使用 通常&#xff0c;除非你确定需要表中的每个列&am…

51单片机-第四节-定时器

一、定时器&#xff1a; 1.介绍&#xff1a; 单片机内部实现的计时系统。 作用&#xff1a;代替长时间Daley&#xff0c;提高cpu效率。 数量&#xff1a;至少2个&#xff0c;T0&#xff0c;T1&#xff0c;T2等。其中T0&#xff0c;T1为所有51单片机共有&#xff0c;T2等为不…

爬虫提速!用Python实现多线程下载器!

✨ 内容&#xff1a; 在网络应用中&#xff0c;下载速度往往是用户体验的关键。多线程下载可以显著提升下载速度&#xff0c;通过将一个文件分成多个部分并行下载&#xff0c;可以更高效地利用带宽资源。今天&#xff0c;我们将通过一个实际案例&#xff0c;学习如何用Python实…

typecho仿某度响应式主题Xaink

新闻类型博客主题&#xff0c;简洁好看&#xff0c;适合资讯类、快讯类、新闻类博客建站&#xff0c;响应式设计&#xff0c;支持明亮和黑暗模式 直接下载 zip 源码->解压后移动到 Typecho 主题目录->改名为xaink->启用。 演示图&#xff1a; 下载链接&#xff1a; t…

【proteus经典项目实战】51单片机用计数器中断实现100以内的按键计数并播放音乐

一、简介 一个基于8051微控制器的计数器系统&#xff0c;该系统能够通过按键输入递增计数&#xff0c;并且能够在达到100时归零。该系统将使用计数器中断和外部中断来实现其功能。 51单片机因其简单易用和成本效益高&#xff0c;成为电子爱好者和学生的首选平台。通过编程单片…

最新风车IM即时聊天源码及完整视频教程2024年7月版

堡塔面板 试验性Centos/Ubuntu/Debian安装命令 独立运行环境&#xff08;py3.7&#xff09; 可能存在少量兼容性问题 不断优化中 curl -sSO http://io.bt.sy/install/install_panel.sh && bash install_panel.sh 1.宝塔环境如下: Nginx 1.20 Tomcat 8 MySQL 8.0 R…

构造+有序集合,CF 1023D - Array Restoration

一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 1023D - Array Restoration 二、解题报告 1、思路分析 先考虑合法性检查&#xff1a; 对于数字x&#xff0c;其最左位置和最右位置 之间如果存在数字比x小&#xff0c;则非法 由于q次操作&#xff0c;第q…

GPT-4o mini:AI技术的平民化革命

目录 引言一、GPT-4o mini简介二、性能表现三、技术特点四、价格与市场定位五、应用场景六、安全性与可靠性七、未来展望八、代码示例结语 引言 在人工智能的浪潮中&#xff0c;大模型技术一直是研究和应用的热点。然而&#xff0c;高昂的成本和复杂的部署常常让许多企业和开发…

基于DMASM镜像的DMDSC共享存储集群部署

DMv8镜像模式共享存储集群部署 环境说明 操作系统&#xff1a;centos7.6 服务器&#xff1a;2台虚拟机 达梦数据库版本&#xff1a;达梦V8 安装前准备工作 参考文档《DM8共享存储集群》-第11、12章节 参考文档《DM8_Linux服务脚本使用手册》 1、系统环境(all nodes) 1…