【机器学习】单变量线性回归

文章目录

  • 线性回归模型(linear regression model)
  • 损失/代价函数(cost function)——均方误差(mean squared error)
  • 梯度下降算法(gradient descent algorithm)
  • 参数(parameter)和超参数(hyperparameter)
  • 代码实现样例
  • 运行结果

线性回归模型(linear regression model)

  • 线性回归模型:

f w , b ( x ) = w x + b f_{w,b}(x) = wx + b fw,b(x)=wx+b

其中, w w w 为权重(weight), b b b 为偏置(bias)

  • 预测值(通常加一个帽子符号):

y ^ ( i ) = f w , b ( x ( i ) ) = w x ( i ) + b \hat{y}^{(i)} = f_{w,b}(x^{(i)}) = wx^{(i)} + b y^(i)=fw,b(x(i))=wx(i)+b

损失/代价函数(cost function)——均方误差(mean squared error)

  • 一个训练样本: ( x ( i ) , y ( i ) ) (x^{(i)}, y^{(i)}) (x(i),y(i))
  • 训练样本总数 = m m m
  • 损失/代价函数是一个二次函数,在图像上是一个开口向上的抛物线的形状。

J ( w , b ) = 1 2 m ∑ i = 1 m [ f w , b ( x ( i ) ) − y ( i ) ] 2 = 1 2 m ∑ i = 1 m [ w x ( i ) + b − y ( i ) ] 2 \begin{aligned} J(w, b) &= \frac{1}{2m} \sum^{m}_{i=1} [f_{w,b}(x^{(i)}) - y^{(i)}]^2 \\ &= \frac{1}{2m} \sum^{m}_{i=1} [wx^{(i)} + b - y^{(i)}]^2 \end{aligned} J(w,b)=2m1i=1m[fw,b(x(i))y(i)]2=2m1i=1m[wx(i)+by(i)]2

  • 为什么需要乘以 1/2?因为对平方项求偏导后会出现系数 2,是为了约去这个系数。

梯度下降算法(gradient descent algorithm)

  • α \alpha α:学习率(learning rate),用于控制梯度下降时的步长,以抵达损失函数的最小值处。若 α \alpha α 太小,梯度下降太慢;若 α \alpha α 太大,下降过程可能无法收敛。
  • 梯度下降算法:

r e p e a t { t m p _ w = w − α ∂ J ( w , b ) w t m p _ b = b − α ∂ J ( w , b ) b w = t m p _ w b = t m p _ b } u n t i l c o n v e r g e \begin{aligned} repeat \{ \\ & tmp\_w = w - \alpha \frac{\partial J(w, b)}{w} \\ & tmp\_b = b - \alpha \frac{\partial J(w, b)}{b} \\ & w = tmp\_w \\ & b = tmp\_b \\ \} until \ & converge \end{aligned} repeat{}until tmp_w=wαwJ(w,b)tmp_b=bαbJ(w,b)w=tmp_wb=tmp_bconverge

其中,偏导数为

∂ J ( w , b ) w = 1 m ∑ i = 1 m [ f w , b ( x ( i ) ) − y ( i ) ] x ( i ) ∂ J ( w , b ) b = 1 m ∑ i = 1 m [ f w , b ( x ( i ) ) − y ( i ) ] \begin{aligned} & \frac{\partial J(w, b)}{w} = \frac{1}{m} \sum^{m}_{i=1} [f_{w,b}(x^{(i)}) - y^{(i)}] x^{(i)} \\ & \frac{\partial J(w, b)}{b} = \frac{1}{m} \sum^{m}_{i=1} [f_{w,b}(x^{(i)}) - y^{(i)}] \end{aligned} wJ(w,b)=m1i=1m[fw,b(x(i))y(i)]x(i)bJ(w,b)=m1i=1m[fw,b(x(i))y(i)]

参数(parameter)和超参数(hyperparameter)

  • 超参数(hyperparameter):训练之前人为设置的任何数量都是超参数,例如学习率 α \alpha α
  • 参数(parameter):模型在训练过程中创建或修改的任何数量都是参数,例如 w , b w, b w,b

代码实现样例

import numpy as np
import matplotlib.pyplot as plt# 计算误差均方函数 J(w,b)
def cost_function(x, y, w, b):m = x.shape[0] # 训练集的数据样本数cost_sum = 0.0for i in range(m):f_wb = w * x[i] + bcost = (f_wb - y[i]) ** 2cost_sum += costreturn cost_sum / (2 * m)# 计算梯度值 dJ/dw, dJ/db
def compute_gradient(x, y, w, b):m = x.shape[0] # 训练集的数据样本数d_w = 0.0d_b = 0.0for i in range(m):f_wb = w * x[i] + bd_wi = (f_wb - y[i]) * x[i]d_bi = (f_wb - y[i])d_w += d_wid_b += d_bidj_dw = d_w / mdj_db = d_b / mreturn dj_dw, dj_db# 梯度下降算法
def linear_regression(x, y, w, b, learning_rate=0.01, epochs=1000):J_history = [] # 记录每次迭代产生的误差值for epoch in range(epochs):dj_dw, dj_db = compute_gradient(x, y, w, b)# w 和 b 需同步更新w = w - learning_rate * dj_dwb = b - learning_rate * dj_dbJ_history.append(cost_function(x, y, w, b)) # 记录每次迭代产生的误差值return w, b, J_history# 绘制线性方程的图像
def draw_line(w, b, xmin, xmax, title):x = np.linspace(xmin, xmax)y = w * x + b# plt.axis([0, 10, 0, 50]) # xmin, xmax, ymin, ymaxplt.xlabel("X-axis", size=15)plt.ylabel("Y-axis", size=15)plt.title(title, size=20)plt.plot(x, y)# 绘制散点图
def draw_scatter(x, y, title):plt.xlabel("X-axis", size=15)plt.ylabel("Y-axis", size=15)plt.title(title, size=20)plt.scatter(x, y)# 从这里开始执行
if __name__ == '__main__':# 训练集样本x_train = np.array([1, 2, 3, 5, 6, 7])y_train = np.array([15.5, 19.7, 24.4, 35.6, 40.7, 44.8])w = 0.0 # 权重b = 0.0 # 偏置epochs = 10000 # 迭代次数learning_rate = 0.01 # 学习率J_history = [] # # 记录每次迭代产生的误差值w, b, J_history = linear_regression(x_train, y_train, w, b, learning_rate, epochs)print(f"result: w = {w:0.4f}, b = {b:0.4f}") # 打印结果# 绘制迭代计算得到的线性回归方程plt.figure(1)draw_line(w, b, 0, 10, "Linear Regression")plt.scatter(x_train, y_train) # 将训练数据集也表示在图中plt.show()# 绘制误差值的散点图plt.figure(2)x_axis = list(range(0, 10000))draw_scatter(x_axis, J_history, "Cost Function in Every Epoch")plt.show()

运行结果

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/672950.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Nginx限流设置

1.反向代理(建议先看正向代理,反向代理则是同样你要与对方服务器建立连接,但是,代理服务器和目标服务器在一个LAN下,所以我们需要与代理服务器先建交,再由他获取与目标服务器的交互,好比一个带刀侍卫守护着目标服务器) 屏蔽目标服务器的真实地址,相对安全性较好&am…

ubuntu下修改hosts读写权限

ubuntu下修改hosts文件的操作: 由于需要在hosts文件下添加ip地址信息,但是初始情况下系统该文件为只读权限无法修改,具体操作如下所示; 1.cd到系统etc目录下,执行如下命令,此时会提示输入密码,直接输入回…

PgSQL技术内幕 - case when表达式实现机制

PgSQL技术内幕 - case when表达式实现机制 CASE表达式如同 C语言中的if/else语句一样,为SQL添加了条件逻辑处理能力,可以根据不同条件返回不同结果。PgSQL支持两种语法:简单表达式和搜索表达式。 1、搜索表达式 语法如下: CASE WH…

掼蛋牌桌上的默契-牌语解读篇

掼蛋不仅仅是个人战斗,也是和队友之间的默契与配合的战斗。长时间合作的玩家间往往能够通过一些特定的出牌方式传递信息,这些“暗号”或“牌语”成为了他们都顺利夺取胜利的秘密武器。 这些技巧都需要在日常实践中留心捕捉,用心理解和领悟&am…

1978-2022年各省家庭恩格尔系数(分城镇、农村)

1978-2022年各省家庭恩格尔系数(分城镇、农村) 1、时间:1978-2022年 2、指标:城镇家庭恩格尔系数、农村家庭恩格尔系数 3、来源:统计年鉴、省统计公报 4、范围:31省 5、指标解释:恩格尔系数…

springboot整合rabbitmq,及各类型交换机详解

RabbitMQ交换机: 一.交换机的作用 如果直接发送信息给一条队列,而这一消息需要多个队列的的多个消费者共同执行,可此时只会有一个队列的一个消费者接收该消息并处理,其他队列的消费者无法获取消息并执行。所以此时就需要交换机接…

如何使用phpStudy搭建网站并结合内网穿透远程访问本地站点

文章目录 [toc]使用工具1. 本地搭建web网站1.1 下载phpstudy后解压并安装1.2 打开默认站点,测试1.3 下载静态演示站点1.4 打开站点根目录1.5 复制演示站点到站网根目录1.6 在浏览器中,查看演示效果。 2. 将本地web网站发布到公网2.1 安装cpolar内网穿透2…

飞马座卫星

1960年代马歇尔太空飞行中心的历史显然与建造土星五号月球火箭有关。然而,鲜为人知的是该中心在设计科学有效载荷方面的早期工作。 Fairchild 技术人员正在检查扩展的 Pegasus 流星体探测表面。Pegasus 由马里兰州黑格斯敦的 Fairchild Stratos Corporation 通过马歇…

Verilog刷题笔记22

题目: Build a priority encoder for 8-bit inputs. Given an 8-bit vector, the output should report the first (least significant) bit in the vector that is 1. Report zero if the input vector has no bits that are high. For example, the input 8’b100…

springboot164党员教育和管理系统

简介 【毕设源码推荐 javaweb 项目】基于springbootvue 的 适用于计算机类毕业设计,课程设计参考与学习用途。仅供学习参考, 不得用于商业或者非法用途,否则,一切后果请用户自负。 看运行截图看 第五章 第四章 获取资料方式 **项…

gh0st远程控制——客户端界面编写(四)

本节任务点 ◉ 为所有菜单项添加测试响应函数 ◉ 添加删除列表指定条目的功能 为所有菜单项添加测试响应函数: 添加菜单响应函数: void CPCRemoteDlg::OnOnlineCmd() {if (TEST_MODE) {MessageBox("终端管理界面");} }void CPCRemoteDlg:…

Linux 研究文件描述符fd的分配规则

目标:研究fd的分配规则 方式:做实验 我们写一段代码,需要实现的功能如下:利用系统调用接口实现读入字符,并且把读入的字符打印在屏幕上。 实验1 我们需要用到read()函数。 read是系统调用接口,头文件和…

矩阵的正定(positive definite)性质的作用

1. 定义 注意,本文中正定和半正定矩阵不要求是对称或Hermite的。 2. 性质 3. 作用 (1)Axb直接法求解 cholesky实对称正定矩阵求解复共轭对称正定矩阵求解LDL实对称非正定矩阵求解复共轭对称非正定矩阵求解复对称矩阵求解LU实非对称矩阵求解…

假期作业5

TCP和UDP区别 TCP ----稳定 1、提供面向连接的,可靠的数据传输服务; 2、传输过程中,数据无误、数据无丢失、数据无失序、数据无重复; 3、数据传输效率低,耗费资源多; 4、数据收发是不同步的; U…

APIfox自动化编排场景(二)

测试流程控制条件 你可以在测试场景中新增流程控制条件(循环、判断、等待、分组)等。进一步满足了更复杂的测试场景/流程配置的使用,最终借助自动化测试功能解决复杂场景的测试工作。 分组​ 当测试流程中多个步骤存在相关联关系时&#xf…

相机图像质量研究(4)常见问题总结:光学结构对成像的影响--焦距

系列文章目录 相机图像质量研究(1)Camera成像流程介绍 相机图像质量研究(2)ISP专用平台调优介绍 相机图像质量研究(3)图像质量测试介绍 相机图像质量研究(4)常见问题总结:光学结构对成像的影响--焦距 相机图像质量研究(5)常见问题总结:光学结构对成…

作业2.7

一、填空题 1、在下列程序的空格处填上适当的字句&#xff0c;使输出为&#xff1a;0&#xff0c;2&#xff0c;10。 #include <iostream> #include <math.h> class Magic {double x; public: Magic(double d0.00):x(fabs(d)) {} Magic operator(__const Magic&…

2024年2月11日(星期天)骑行金色螳川

2024年2月11日 (星期天&#xff0c;年初二) 骑行金色螳川(赏油菜花&#xff09;&#xff0c;早8:30到9:00&#xff0c; 大观公园门囗集合&#xff0c;9:30准时出发【因迟到者&#xff0c;骑行速度快者&#xff0c;可自行追赶偶遇。】 偶遇地点:大观公园门囗集合 &#xff0c;家…

Redis 命令大全

文章目录 启动与连接Key&#xff08;键&#xff09;相关命令String&#xff08;字符串&#xff09;Hash&#xff08;哈希&#xff09;List&#xff08;列表&#xff09;Set&#xff08;集合&#xff09;Sorted Set&#xff08;有序集合&#xff09;其他常见命令HyperLogLog&…

【Leetcode】LCP 30. 魔塔游戏

文章目录 题目思路代码结果 题目 题目链接 小扣当前位于魔塔游戏第一层&#xff0c;共有 N 个房间&#xff0c;编号为 0 ~ N-1。每个房间的补血道具/怪物对于血量影响记于数组 nums&#xff0c;其中正数表示道具补血数值&#xff0c;即血量增加对应数值&#xff1b;负数表示怪…