Pytorch深度学习实践(5)逻辑回归

逻辑回归

逻辑回归主要是解决分类问题

  • 回归任务:结果是一个连续的实数
  • 分类任务:结果是一个离散的值

分类任务不能直接使用回归去预测,比如在手写识别中(识别手写 0 − − 9 0 -- 9 09),因为各个类别之间没有大小之差。

因此,对于分类问题,我们最终的输出是个概率,即属于某个类别的概率是多少,然后从概率集合里找最大值,作为当前预测的结果

下载MNIST数据集

import torchvision
train_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=True, download=True)
test_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=False, dowload=True)
  • 通过train参数来指定训练集和测试集

逻辑回归

将之前的学习时长—考试分数转化为二分类任务,即学习时长—是否通过考试

x(hours)y(pass/fail)
10(fail)
20(fail)
31(pass)
4?

其中, P ( y ^ = 1 ) + P ( y ^ = 0 ) = 1 P(\hat y = 1) + P(\hat y = 0) = 1 P(y^=1)+P(y^=0)=1

当输出的概率在 0.5 0.5 0.5附近时,即模型不确定,因此通常会输出一个不确定的值

对于二分类任务,逻辑回归会先使用回归,生成一个得分值,即落在实数集区间内,然后再使用 s i g m o i d sigmoid sigmoid函数,将得分值映射到 [ 0 , 1 ] [0, 1] [0,1]区间内,得到预测概率

s i g m o i d sigmoid sigmoid函数
σ ( x ) = 1 1 + e − x \sigma (x) = \frac{1}{1+e^{-x}} σ(x)=1+ex1
在这里插入图片描述

S i g m o i d Sigmoid Sigmoid常用来做二分类任务,常具备三个特征:

  • 饱和函数
  • 单调递增
  • 有极限

当我们使用线性回归来得到逻辑回归的得分值时,逻辑回归模型的函数定义就如下所示:
y ^ = σ ( x ∗ ω + b ) \hat y = \sigma (x*\omega + b) y^=σ(xω+b)

损失函数

线性回归使用的损失函数是计算预测值和真实值之差

而对于逻辑回归,由于我们得到的是概率,是一个 0 − 1 0-1 01分布,因此需要修改损失函数
l o s s = − ( y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ) loss = -(ylog\hat y + (1-y)log(1-\hat y)) loss=(ylogy^+(1y)log(1y^))
即我们比较的是分布之间的差异

交叉熵 c r o s s − e n t r o p y cross-entropy crossentropy

存在两个分布 P D 1 ( x ) P_{D1}(x) PD1(x) P D 2 ( x ) P_{D2}(x) PD2(x)

两个分布的差异程度使用公式: ∑ i = 1 n P D 1 ( x i ) l n P D 2 ( x i ) \sum_{i=1}^{n}P_{D1}(x_i)lnP_{D2}(x_i) i=1nPD1(xi)lnPD2(xi) 来衡量

上述公式越大时,两个分布的差异越小

模型的改变

模型构造的改变
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred

需要先将输入写入到linear()线性模型中,再使用Sigmoid()函数

模型损失函数的改变

使用交叉熵函数BCELoss

criterion = torch.nn.BCELoss(size_average=False)

整体代码

import torch
import matplotlib.pyplot as plt
import numpy as np########## 数据集准备 ##########
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])########## 模型定义 ##########
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = torch.sigmoid(self.linear(x))return y_predmodel = LogisticRegressionModel()########## 损失函数和优化器的设置 ##########
criterion = torch.nn.BCELoss(size_average=False) # BCELoss -- 交叉熵函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)########## 模型训练 ##########
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()########## 模型测试 ##########
x = np.linspace(0, 10, 200)
x_test = torch.Tensor(x).view((200, 1)) # view()相当于reshape
y_test = model(x_test)
y = y_test.data.numpy()  # 转化为np类型
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], 'r--')
plt.xlabel("Hours")
plt.ylabel("Probability of Pass")
plt.grid()
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/50505.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

了解Java虚拟机(JVM)

前言👀~ 上一章我们介绍网络原理相关的知识点,今天我们浅浅来了解一下java虚拟机JVM JVM( Java Virtual Machine ) JVM内存区域划分 方法区/元数据区(线程共享) 堆(线程共享) 虚…

iOS object-C 解答算法:找到所有数组中消失的数字(leetCode-448)

找到所有数组中消失的数字(leetCode-448) 题目如下图:(也可以到leetCode上看完整题目,题号448) 光看题看可能有点难以理解,我们结合示例1来理解一下这道题. 有8个整数的数组 nums [4,3,2,7,8,2,3,1], 求在闭区间[1,8]范围内(即1,2,3,4,5,6,7,8)的数字,哪几个没有出现在数组 …

Spring Boot的Web开发

目录 Spring Boot的Web开发 1.静态资源映射规则 第一种静态资源映射规则 2.enjoy模板引擎 3.springMVC 3.1请求处理 RequestMapping DeleteMapping 删除 PutMapping 修改 GetMapping 查询 PostMapping 新增 3.2参数绑定 一.支持数据类型: 3.3常用注解 一.Request…

网闸(Network Gatekeeper或Security Gateway)

本心、输入输出、结果 文章目录 网闸(Network Gatekeeper或Security Gateway)前言网闸主要功能网闸工作原理网闸使用场景网闸网闸(Network Gatekeeper或Security Gateway) 编辑 | 简简单单 Online zuozuo 地址 | https://blog.csdn.net/qq_15071263 如果觉得本文对你有帮助…

c++如何理解多态与虚函数

目录 **前言****1. 何为多态**1.1 **编译时多态**1.1.1 函数重载1.1.2 模板 **1.2 运行时多态****1.2.1 虚函数****1.2.2 为什么要用父类指针去调用子类函数** **2. 注意****2.1 基类的析构函数应写为虚函数****2.2 构造函数不能设为虚函数** **本文参考** 前言 在学习 c 的虚…

CentOS7停止维护,更换国内镜像源

众所周知,截至6月30号,CentOS7就停止维护了,由此带来了一系列影响,最直接的影响就是通过yum安装软件时,链接404了,这就很不友好了。。。 不过好在咱们国内的各大厂很早就做了国内的镜像源,比如阿…

Vue中如何获取dom元素?vue方法

在Vue中,通常我们不建议直接操作DOM元素,因为Vue的设计哲学是声明式渲染,即你通过Vue的模板语法(如插值表达式、指令等)来描述你的界面,Vue会负责将你的数据渲染成DOM,并在数据变化时自动更新DO…

C++ | Leetcode C++题解之第275题H指数II

题目&#xff1a; 题解&#xff1a; class Solution { public:int hIndex(vector<int>& citations) {int n citations.size();int left 0, right n - 1;while (left < right) {int mid left (right - left) / 2;if (citations[mid] > n - mid) {right m…

全球“微软蓝屏”事件:IT基础设施韧性与安全性的考验

人不走空 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌赋&#xff1a;斯是陋室&#xff0c;惟吾德馨 目录 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌…

git配置环境变量

一.找到git安装目录 打开此git安装目录下的bin文件&#xff0c;复制此文件路径 二.配置环境变量 2.1 右键点击此电脑的属性栏 2.2 点击高级系统配置 2.3 点击环境变量 2.4 按图中步骤进行配置 三.配置完成 win r 输入cmd打开终端 终端页面中输入 git --version 如图所示…

记一次老旧项目的整体技术升级

最近给公司采购的老旧的 node8 vue2.6 webpack3 npm 项目做构建优化 背景&#xff1a;整个项目 build 一次 20 min &#xff0c;本地冷启动和热更新也忒慢&#xff0c;依赖 npm i 一下也得装个 20 min 众所周知&#xff0c;Node 版本&#xff0c;依赖包管理工具 和 构建工…

20240725java的Controller、DAO、DO、Mapper、Service层、反射、AOP注解等内容的学习

在Java开发中&#xff0c;‌controller、‌dao、‌do、‌mapper等概念通常与MVC&#xff08;‌Model-View-Controller&#xff09;‌架构和分层设计相关。‌这些概念各自承担着不同的职责&#xff0c;‌共同协作以构建和运行一个应用程序。‌以下是这些概念的解释&#xff1a;‌…

小团队应该如何选择 DevOps 工具链?

极狐GitLab&#xff08;GitLab 中国发行版&#xff09;正式推出面向个人、小团队的极狐GitLab SaaS 团队版&#xff0c;仅需 499/人/年。即可享受 AI DevOps平台 专业技术服务支持&#xff01;&#xff01;&#xff01;详情查看 https://dl.gitlab.cn/50lnrrrq 回答这个问题…

Redis的两种持久化方式---RDB、AOF

rdb其实就是一种快照持久化的方式&#xff0c;它会将Redis在某个时间点的所有的数据状态以二进制的方式保存到硬盘上的文件当中&#xff0c;它相对于aof文件会小很多&#xff0c;因为知识某个时间点的数据&#xff0c;当然&#xff0c;这就会导致它的实时性不够高&#xff0c;如…

关于RabbitMQ你了解多少?

关于RabbitMQ你了解多少&#xff1f; 文章目录 关于RabbitMQ你了解多少&#xff1f;基础篇使用消息队列RabbitMQRabbitMQ 7种用法类型Exchange 交换机 进阶篇目录客户端整合SpringBoot消息可靠性投递消费端限流消息超时死信和死信队列延迟队列事务消息惰性队列优先级队列 集群篇…

【游戏制作】使用Python创建一个美观的贪吃蛇游戏,附完整代码

目录 前言 项目运行结果 项目简介 环境配置 代码实现 主体结构 主要功能详解 界面和菜单 控制蛇的移动 食物生成和碰撞检测 游戏结束 运行游戏 总结 前言 贪吃蛇游戏是一款经典的电脑游戏&#xff0c;许多人都曾经玩过。今天我们将使用Python和ttkbootstrap库来实…

Mysql注意事项(一)

Mysql注意事项&#xff08;一&#xff09; 最近回顾了一下MySQL&#xff0c;发现了一些MySQL需要注意的事项&#xff0c;同时也作为学习笔记&#xff0c;记录下来。–2020年05月13日 1、通配符* 检索所有的列。 不建议使用 通常&#xff0c;除非你确定需要表中的每个列&am…

51单片机-第四节-定时器

一、定时器&#xff1a; 1.介绍&#xff1a; 单片机内部实现的计时系统。 作用&#xff1a;代替长时间Daley&#xff0c;提高cpu效率。 数量&#xff1a;至少2个&#xff0c;T0&#xff0c;T1&#xff0c;T2等。其中T0&#xff0c;T1为所有51单片机共有&#xff0c;T2等为不…

爬虫提速!用Python实现多线程下载器!

✨ 内容&#xff1a; 在网络应用中&#xff0c;下载速度往往是用户体验的关键。多线程下载可以显著提升下载速度&#xff0c;通过将一个文件分成多个部分并行下载&#xff0c;可以更高效地利用带宽资源。今天&#xff0c;我们将通过一个实际案例&#xff0c;学习如何用Python实…

typecho仿某度响应式主题Xaink

新闻类型博客主题&#xff0c;简洁好看&#xff0c;适合资讯类、快讯类、新闻类博客建站&#xff0c;响应式设计&#xff0c;支持明亮和黑暗模式 直接下载 zip 源码->解压后移动到 Typecho 主题目录->改名为xaink->启用。 演示图&#xff1a; 下载链接&#xff1a; t…