吴恩达机器学习-可选实验:梯度下降逻辑回归(Gradient Descent for Logistic Regression)

文章目录

    • 目标
    • 数据集
    • Logistic梯度下降
    • 梯度下降实现
      • 计算梯度,代码描述
    • 另一个数据集

目标

在本实验中,你将:

  • 更新逻辑回归的梯度下降
  • 在一个熟悉的数据集上探索梯度下降
  • 使用梯度下降给逻辑回归更新参数
import copy, math
import numpy as np
%matplotlib widget
import matplotlib.pyplot as plt
from lab_utils_common import  dlc, plot_data, plt_tumor_data, sigmoid, compute_cost_logistic
from plt_quad_logistic import plt_quad_logistic, plt_prob
plt.style.use('./deeplearning.mplstyle')

数据集

让我们从决策边界实验室中使用的相同的两个特征数据集开始。

X_train = np.array([[0.5, 1.5], [1,1], [1.5, 0.5], [3, 0.5], [2, 2], [1, 2.5]])
y_train = np.array([0, 0, 0, 1, 1, 1])

和前面一样,我们将使用一个辅助函数来绘制这些数据。标签y = 1的数据点显示为红色叉,而标签y = 0的数据点显示为蓝色圆。

fig,ax = plt.subplots(1,1,figsize=(4,4))
plot_data(X_train, y_train, ax)ax.axis([0, 4, 0, 3.5])
ax.set_ylabel('$x_1$', fontsize=12)
ax.set_xlabel('$x_0$', fontsize=12)
plt.show()

在这里插入图片描述

Logistic梯度下降

在这里插入图片描述
回想一下梯度下降算法利用了梯度计算:
在这里插入图片描述
在这里插入图片描述

梯度下降实现

梯度下降算法的实现有两个部分:

  • 实现上述公式(1)的循环。这是下面的gradient_descent,通常在可选和实践实验室中提供给您。
  • 当前梯度的计算,如式(2、3)所示。这是下面的compute_gradient_logistic。你将被要求完成本周的实践实验。

计算梯度,代码描述

对所有w和b实现上述式(2)、(3)。实现方法有很多,如下所示:

  • 初始化变量,累加dj_dw和dj_db
  • 对于每个例子
    • 计算该示例的误差g(w·x^i +b)-y ^ i
    • 对于本例中的每个输入值Xj^i,
      • 将误差乘以输入的Xj^i,并加上dj_dw的相应元素。(上式2)
    • 将错误添加到dj_db(上面的公式3)
  • 用dj_db和dj_dw除以样本总数(m)
  • 请注意,在numpy中X[i,:]或X[i]中的x^i和Xj ^i是X[i, j]
def compute_gradient_logistic(X, y, w, b): """Computes the gradient for linear regression Args:X (ndarray (m,n): Data, m examples with n featuresy (ndarray (m,)): target valuesw (ndarray (n,)): model parameters  b (scalar)      : model parameterReturnsdj_dw (ndarray (n,)): The gradient of the cost w.r.t. the parameters w. dj_db (scalar)      : The gradient of the cost w.r.t. the parameter b. """m,n = X.shapedj_dw = np.zeros((n,))                           #(n,)dj_db = 0.for i in range(m):f_wb_i = sigmoid(np.dot(X[i],w) + b)          #(n,)(n,)=scalarerr_i  = f_wb_i  - y[i]                       #scalarfor j in range(n):dj_dw[j] = dj_dw[j] + err_i * X[i,j]      #scalardj_db = dj_db + err_idj_dw = dj_dw/m                                   #(n,)dj_db = dj_db/m                                   #scalarreturn dj_db, dj_dw  

使用下面的单元格检查梯度函数的实现。

X_tmp = np.array([[0.5, 1.5], [1,1], [1.5, 0.5], [3, 0.5], [2, 2], [1, 2.5]])
y_tmp = np.array([0, 0, 0, 1, 1, 1])
w_tmp = np.array([2.,3.])
b_tmp = 1.
dj_db_tmp, dj_dw_tmp = compute_gradient_logistic(X_tmp, y_tmp, w_tmp, b_tmp)
print(f"dj_db: {dj_db_tmp}" )
print(f"dj_dw: {dj_dw_tmp.tolist()}" )

在这里插入图片描述
梯度下降代码
实现上述方程(1)的代码如下所示。花点时间定位和比较例程中的函数与上面的方程。

def gradient_descent(X, y, w_in, b_in, alpha, num_iters): """Performs batch gradient descentArgs:X (ndarray (m,n)   : Data, m examples with n featuresy (ndarray (m,))   : target valuesw_in (ndarray (n,)): Initial values of model parameters  b_in (scalar)      : Initial values of model parameteralpha (float)      : Learning ratenum_iters (scalar) : number of iterations to run gradient descentReturns:w (ndarray (n,))   : Updated values of parametersb (scalar)         : Updated value of parameter """# An array to store cost J and w's at each iteration primarily for graphing laterJ_history = []w = copy.deepcopy(w_in)  #avoid modifying global w within functionb = b_infor i in range(num_iters):# Calculate the gradient and update the parametersdj_db, dj_dw = compute_gradient_logistic(X, y, w, b)   # Update Parameters using w, b, alpha and gradientw = w - alpha * dj_dw               b = b - alpha * dj_db               # Save cost J at each iterationif i<100000:      # prevent resource exhaustion J_history.append( compute_cost_logistic(X, y, w, b) )# Print cost every at intervals 10 times or as many iterations if < 10if i% math.ceil(num_iters / 10) == 0:print(f"Iteration {i:4d}: Cost {J_history[-1]}   ")return w, b, J_history         #return final w,b and J history for graphing

让我们对数据集运行梯度下降。

w_tmp  = np.zeros_like(X_train[0])
b_tmp  = 0.
alph = 0.1
iters = 10000w_out, b_out, J = gradient_descent(X_train, y_train, w_tmp, b_tmp, alph, iters) 
print(f"\nupdated parameters: w:{w_out}, b:{b_out}")

在这里插入图片描述
我们来绘制梯度下降的结果:

fig,ax = plt.subplots(1,1,figsize=(5,4))
# plot the probability 
plt_prob(ax, w_out, b_out)# Plot the original data
ax.set_ylabel(r'$x_1$')
ax.set_xlabel(r'$x_0$')   
ax.axis([0, 4, 0, 3.5])
plot_data(X_train,y_train,ax)# Plot the decision boundary
x0 = -b_out/w_out[1]
x1 = -b_out/w_out[0]
ax.plot([0,x0],[x1,0], c=dlc["dlblue"], lw=1)
plt.show()

这段代码看起来是用于在图表上绘制决策边界的部分。让我来解释每一行:

  1. x0 = -b_out/w_out[1]
    • 这行代码计算了决策边界上的两个点的 x 坐标值。假设 b_out 是模型的偏置项,w_out 是模型的权重参数。
  2. x1 = -b_out/w_out[0]
    • 这行代码计算了决策边界上的两个点的 y 坐标值。
  3. ax.plot([0,x0],[x1,0], c=dlc["dlblue"], lw=1)
    • 这行代码使用 Matplotlib 库中的 plot 函数,在图表上绘制了决策边界。它连接了两个点 (0, x0) 和 (x1, 0),即通过前面计算得到的两个点,从而画出了决策边界。参数 c 指定了线的颜色,lw 则指定了线的宽度

在这里插入图片描述
在上图中:

  • 阴影反映了概率y=1(决策边界之前的结果)
  • 决策边界是概率=0.5处的那条线

另一个数据集

让我们回到单变量数据集。只需要两个参数,w, b,就可以绘制出成本函数使用等高线图来更好地了解梯度下降是什么。

x_train = np.array([0., 1, 2, 3, 4, 5])
y_train = np.array([0,  0, 0, 1, 1, 1])

和前面一样,我们将使用一个辅助函数来绘制这些数据。标签y = 1的数据点显示为红色叉,而标签y = 0的数据点显示为黑色圆。

fig,ax = plt.subplots(1,1,figsize=(4,3))
plt_tumor_data(x_train, y_train, ax)
plt.show()

在这里插入图片描述
在这里插入图片描述

w_range = np.array([-1, 7])
b_range = np.array([1, -14])
quad = plt_quad_logistic( x_train, y_train, w_range, b_range )

(1)在这里插入图片描述
(2)
在这里插入图片描述
(3)在这里插入图片描述
(4)点击运行之后在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/739854.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Go微服务: 基于GRPC结合Consul实现微服务调用

基于GRPC结合Consul实现微服务调用 1 &#xff09;环境准备 基于 go workspace 准备3个包: protos&#xff0c;server, client新建 demo目录&#xff0c;其内部结构如下├── protos │ ├── go.mod │ └── users │ ├── users.proto │ …

怎么判断你的模型是好是坏?模型性能评估指标大全!

模型性能评估指标&#xff0c;大家一定不陌生&#xff01;很多小伙伴们都说难&#xff0c;但是它真的很重要很重要很重要&#xff01;它会对我们的模型有很多的指导&#xff0c;也会给我们真正做模型的时候提供一些指导性的思想&#xff0c;不然我们看到别人的东西只能跟着人家…

centos7.9升级ssh和openssl

一、环境 [roottmp179 package]# ssh -V OpenSSH_7.4p1, OpenSSL 1.0.2k-fips 26 Jan 2017 [roottmp179 package]# cat /etc/redhat-release CentOS Linux release 7.9.2009 (Core) 二、 升级前准备 mkdir /opt/package cd /opt/package wget https://www.openssl.org/source…

【linux】冯诺依曼体系与操作系统的理解

本篇文章是进程的预备知识&#xff0c;但也不仅仅是进程的预备知识&#xff0c; 也可以更好地帮助我们理解整个计算机体系。 目录 冯诺依曼体系结构&#xff1a;进一步理解操作系统&#xff1a; 冯诺依曼体系结构&#xff1a; 关于这张图先进行一下必要的解释&#xff1a; 输…

怎样通过IT服务台来增强IT项目管理?

当下&#xff0c;越来越多的企业和组织重视IT项目管理的重要性。而如何通过IT服务台来增强和提升IT项目管理效率&#xff0c;成为了许多企业领导和IT专业人员共同关注的话题。如何充分利用IT服务台&#xff0c;以促进IT项目管理水平的提升和项目成功率的增加变得至关重要。 1…

Codeforces Round 924 (Div. 2)---->B. Equalize

总思路&#xff1a;首先我们做这题的时候有两个点一定要知道&#xff1a; 1.当数组中有重复元素的时候&#xff0c;只有其中的一个才能贡献一个相同元素&#xff0c;其他的都不行&#xff08;因为是排列&#xff0c;一个数只出现一次&#xff09;&#xff0c;所以我们可以用使…

怎么免费下载无水印视频素材?赶快收藏这六个网站。

今天来教大家怎么下载无水印视频素材&#xff0c;其中一些是免费的&#xff0c;并且可以在商业项目中使用&#xff0c;这些网站都是无水印视频素材&#xff0c;可以放心使用。 蛙学网&#xff1a; 网站的内容非常丰富多彩&#xff0c;包括风景&#xff0c;夜景&#xff0c;食物…

论文阅读:Editing Large Language Models: Problems, Methods, and Opportunities

Editing Large Language Models: Problems, Methods, and Opportunities 论文链接 代码链接 摘要 由于大语言模型&#xff08;LLM&#xff09;中可能存在一些过时的、不适当的和错误的信息&#xff0c;所以有必要纠正模型中的相关信息。如何高效地修改模型中的相关信息而不影…

【JS】自动下拉网页刷新,当出现指定关键字,就打印出来

批量检查域名是否可以注册 1、有的网站数据是通过下拉发生请求&#xff0c;间隔x毫秒自动下拉 2、查找某个关键字&#xff0c;找到就打印出来 3、打印数据自动去重 4、当连续n次下拉&#xff0c;没有新div元素出来&#xff0c;就停止该循环 var map {}; var count 0; var l…

qt如何将QHash中的数据有序地放入到QList中

在qt中&#xff0c;要将QHash中的数据有序地放入到QList中&#xff0c;首先要明白&#xff1a; 我们可以遍历QHash中的键值对&#xff0c;并将其按照键的顺序或值的大小插入到QList中&#xff0c;直接用for循环即可。 #include <QCoreApplication> #include <QHas…

java学习(Arrays类和System类)

目录 目录 一.Arrays类 二.System常见方法 三、Biglnteger和BigDecimal&#xff08;高精度&#xff09; 1.Biglnter的常用方法 2.BigDecimal常见方法 3.日期类 1)第一代日期类 2&#xff09;第二代日期类 3)第三代日期类 一.Arrays类 Arrays包含了一系 列静态方法&am…

11、Linux-安装和配置Redis

目录 第一步&#xff0c;传输文件和解压 第二步&#xff0c;安装gcc编译器 第三步&#xff0c;编译Redis 第四步&#xff0c;安装Redis服务 第五步&#xff0c;配置Redis ①开启后台启动 ②关闭保护模式&#xff08;关闭之后才可以远程连接Redis&#xff09; ③设置远程…

12双体系Java学习之局部变量和作用域

局部变量 局部变量的作用域 参数变量

理解记忆相关

foreach循环 在 Java 中&#xff0c;foreach 循环&#xff08;也称为增强型 for 循环&#xff09;是一种简洁的语法&#xff0c;用于遍历数组或集合&#xff08;如 List、Set、Map 等&#xff09;。以下是 foreach 循环的基本用法&#xff1a; 遍历数组&#xff1a; String[] …

在 Python 中从键盘读取用户输入

文章目录 如何在 Python 中从键盘读取用户输入input 函数使用input读取键盘输入使用input读取特定类型的数据处理错误从用户输入中读取多个值 getpass 模块使用 PyInputPlus 自动执行用户输入评估总结 如何在 Python 中从键盘读取用户输入 原文《How to Read User Input From t…

Rust:为 Trait 定义默认的方法

当你提到“指定 trait 的实现”并使用 :: 符号时&#xff0c;你可能是指在某些情况下&#xff0c;你想直接通过 trait 而不是具体的类型来调用方法。这在 trait 提供了默认方法实现时尤其有用&#xff0c;因为你可以不依赖任何具体的类型实现来调用这些方法。 然而&#xff0c…

AI写真变现项目丨超级训练营SOP手册

出品方&#xff1a; 吴东子团队 x AI破局俱乐部 以下只是该SOP手册的部分介绍&#xff0c;AI写真变现项目上手到变现全流程&#xff0c;需要完整手册的可以dd我。 AI写真 首先什么是AI写真&#xff0c;顾名思义的话可以说成是用AI生成写真照&#xff0c;我们先暂且这么理解&am…

PostgreSQL教程(三十二):服务器管理(十四)之监控磁盘使用

本章讨论如何监控PostgreSQL数据库系统的磁盘使用情况。 一、判断磁盘用量 每个表都有一个主要的堆磁盘文件&#xff0c;大多数数据都存储在其中。如果一个表有着可能会很宽&#xff08;尺寸大&#xff09;的列&#xff0c; 则另外还有一个TOAST文件与这个表相关联&#xff0…

Java详解:单列 | 双列集合 | Collections类

○ 前言&#xff1a; 在开发实践中&#xff0c;我们需要一些能够动态增长长度的容器来保存我们的数据&#xff0c;java中为了解决数据存储单一的情况&#xff0c;java中就提供了不同结构的集合类&#xff0c;可以让我们根据不同的场景进行数据存储的选择&#xff0c;如Java中提…