利用梯度下降法求解一元线性回归和多元线性回归

文章目录

  • 原理以及公式
    • 【1】一元线性回归问题
    • 【2】多元线性回归问题
    • 【3】学习率
    • 【4】流程分析(一元线性回归)
    • 【5】流程分析(多元线性回归)
      • 归一化原理以及每种归一化适用的场合
  • 一元线性回归代码以及可视化结果
  • 多元线性回归代码以及可视化结果
  • 总结


原理以及公式

【1】一元线性回归问题

原函数是一元函数(关于x),它的损失函数是二元函数(关于w和b)

这里介绍两种损失函数:平方损失函数和均方差损失函数
在这里插入图片描述

【2】多元线性回归问题

X和W都是m+1维的向量,损失函数是高维空间中的凸函数

【3】学习率

学习率属于超参数(超参数:在开始学习之前设置,不是通过训练得到的)
可以选择在迭代次数增加时减少学习率大小.
下图是学习率正常或较小、稍大、过大的迭代图。

【4】流程分析(一元线性回归)

过程分析:

1、加载样本数据x,y
2、设置超参数学习率,迭代次数
3、设置模型参数初值w0, b0
4、训练模型w, b
5、结果可视化

														流程图:

在这里插入图片描述

【5】流程分析(多元线性回归)

归一化原理以及每种归一化适用的场合

在这里插入图片描述

线性归一化:适用于样本分布均匀且集中的情况,如果最大值(或者最小值)不稳定,和绝大数样本数据相差较大,使用这种方法得到的结果也不稳定.为了抑制这个问题,在实际问题中可以用经验值来代替最大值和最小值
标准差归一化适用于样本近似正态分布,或者最大最小值未知的情况,有时当最大最小值处于孤立点时也可以使用标准差归一化
非线性映射归一化,通常用于数据分化较大的情况(有的很大有的很小)
总结:样本属性归一化需要根据属性样本分布规律定制

过程分析:

加载样本数据area,room,price
数据处理归一化,X,Y
设置超参数学习率,迭代次数
设置模型参数初值W0(w0,w1,w2)
训练模型W
结果可视化

在这里插入图片描述

一元线性回归代码以及可视化结果

#解析法实现一元线性回归 
# #Realization of one variable linear regression by analytic method
#导入库
import numpy as np
import matplotlib.pyplot as plt 
#设置字体
plt.rcParams['font.sans-serif'] =['SimHei']
#加载样本数据
x=np.array([137.97,104.50,100.00,124.32,79.20,99.00,124.00,114.00,106.69,138.05,53.75,46.91,68.00,63.02,81.26,86.21])
y=np.array([145.00,110.00,93.00,116.00,65.32,104.00,118.00,91.00,62.00,133.00,51.00,45.00,78.50,69.65,75.69,95.30])
#设置超参数,学习率
learn_rate=0.00001
#迭代次数
iter=100
#每10次迭代显示一下效果
display_step=10
#设置模型参数初值
np.random.seed(612)
w=np.random.randn()
b=np.random.randn()
#训练模型
#存放每次迭代的损失值
mse=[]
for i in range(0,iter+1):#求偏导dL_dw=np.mean(x*(w*x+b-y))dL_db=np.mean(w*x+b-y)#更新模型参数w=w-learn_rate*dL_dwb=b-learn_rate*dL_db#得到估计值pred=w*x+b#计算损失(均方误差)Loss=np.mean(np.square(y-pred))/2mse.append(Loss)#显示模型#plt.plot(x,pred)if i%display_step==0:print("i:%i,Loss:%f,w:%f,b:%f"%(i,mse[i],w,b))
#模型和数据可视化
plt.figure(figsize=(20,4))
plt.subplot(1,3,1)
#绘制散点图
#张量和数组都可以作为散点函数的输入提供点坐标
plt.scatter(x,y,color="red",label="销售记录")
plt.scatter(x,pred,color="blue",label="梯度下降法")
plt.plot(x,pred,color="blue")#设置坐标轴的标签文字和字号
plt.xlabel("面积(平方米)",fontsize=14)
plt.xlabel("价格(万元)",fontsize=14)#在左上方显示图例
plt.legend(loc="upper left")#损失变化可视化
plt.subplot(1,3,2)
plt.plot(mse)
plt.xlabel("迭代次数",fontsize=14)
plt.ylabel("损失值",fontsize=14)
#估计值与标签值比较可视化
plt.subplot(1,3,3)
plt.plot(y,color="red",marker="o",label="销售记录")
plt.plot(pred,color="blue",marker="o",label="梯度下降法")
plt.legend()
plt.xlabel("sample",fontsize=14)
plt.ylabel("price",fontsize=14)
#显示整个绘图
plt.show()

在这里插入图片描述

多元线性回归代码以及可视化结果

#解析法实现多元线性回归
#Realization of multiple linear regression by analytic method
#导入库与模块
import numpy as np
import matplotlib.pyplot as plt 
from mpl_toolkits.mplot3d import Axes3D
#=======================【1】加载样本数据===============================================
area=np.array([137.97,104.50,100.00,124.32,79.20,99.00,124.00,114.00,106.69,138.05,53.75,46.91,68.00,63.02,81.26,86.21])
room=np.array([3,2,2,3,1,2,3,2,2,3,1,1,1,1,2,2])
price=np.array([145.00,110.00,93.00,116.00,65.32,104.00,118.00,91.00,62.00,133.00,51.00,45.00,78.50,69.65,75.69,95.30])
num=len(area) #样本数量
#=======================【2】数据处理===============================================
x0=np.ones(num)
#归一化处理,这里使用线性归一化
x1=(area-area.min())/(area.max()-area.min())
x2=(room-room.min())/(room.max()-room.min())
#堆叠属性数组,构造属性矩阵
#从(16,)到(16,3),因为新出现的轴是第二个轴所以axis为1
X=np.stack((x0,x1,x2),axis=1)
print(X)
#得到形状为一列的数组
Y=price.reshape(-1,1)
print(Y)
#=======================【3】设置超参数===============================================
learn_rate=0.001
#迭代次数
iter=500
#每10次迭代显示一下效果
display_step=50
#=======================【4】设置模型参数初始值===============================================
np.random.seed(612)
W=np.random.randn(3,1)
#=======================【4】训练模型=============================================
mse=[]
for i in range(0,iter+1):#求偏导dL_dW=np.matmul(np.transpose(X),np.matmul(X,W)-Y)   #XT(XW-Y)#更新模型参数W=W-learn_rate*dL_dW#得到估计值PRED=np.matmul(X,W)#计算损失(均方误差)Loss=np.mean(np.square(Y-PRED))/2mse.append(Loss)#显示模型#plt.plot(x,pred)if i % display_step==0:print("i:%i,Loss:%f"%(i,mse[i]))
#=======================【5】结果可视化============================================
plt.rcParams['font.sans-serif'] =['SimHei']
plt.figure(figsize=(12,4))
#损失变化可视化
plt.subplot(1,2,1)
plt.plot(mse)
plt.xlabel("迭代次数",fontsize=14)
plt.ylabel("损失值",fontsize=14)
#估计值与标签值比较可视化
plt.subplot(1,2,2)
PRED=PRED.reshape(-1)
plt.plot(price,color="red",marker="o",label="销售记录")
plt.plot(PRED,color="blue",marker="o",label="预测房价")
plt.xlabel("sample",fontsize=14)
plt.ylabel("price",fontsize=14)
plt.legend()
plt.show()

在这里插入图片描述

总结

注意点:选择归一化方式


喜欢的话点个赞和关注呗!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/378356.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux x64 asm 参数传递,NASM汇编学习系列(1)——系统调用和参数传递

0. 说明本学习系列代码几乎完全摘自:asmtutor.com,如果英文可以的(也可以用谷歌浏览器翻译看),可以直接看asmtutor.com上的教程系统环境搭建:(我用的是ubuntu18.04.4 server,安装gcc、g)sudo apt install nasmsudo apt…

Javascript之创建对象(原型模式)

我们创建的每个函数都有一个prototype(原型)属性,这个属性是一个指针,指向一个对象,它的用途是包含可以有特定类型的所有实例共享的属性和方法。 prototype就是通过构造函数而创建的那个对象的原型对象。使用原型的好处就是可以让所有对象实例…

treeset java_Java TreeSet pollLast()方法与示例

treeset javaTreeSet类pollLast()方法 (TreeSet Class pollLast() method) pollLast() method is available in java.util package. pollLast()方法在java.util包中可用。 pollLast() method is used to return the last highest element and then remove the element from thi…

第五章 条件、循环及其他语句

第五章 条件、循环及其他语句 再谈print和import print现在实际上是一个函数 1,打印多个参数 用逗号分隔,打印多个表达式 sep自定义分隔符,默认空格 end自定义结束字符串,默认换行 print("beyond",yanyu,23)#结果为…

两种方法将Android NDK samples中hello-neon改成C++

一、第一种方法:1.修改helloneon.c 中代码 a.将 char* str; 改为 char str[512] {0}; b.将 asprintf(&str, "FIR Filter benchmark:\nC version : %g ms\n", time_c); 改为 sprintf(str, "FIR Filter benchmark:\nC ve…

【视觉项目】【day6】8.26关于matchTemplate()以及NCC的思考整理

NCC与matchTemplate()函数中match_method TM_CCOEFF_NORMED是否一样? 先看公式: TM_CCOEFF_NORMED NCCTM_CCOEFF_NORMED:归一化的相关性系数匹配方法 NCC:normalized cross correlation:归一化互相关系数 公式是一样的。 参考: 模板匹配的几…

linux待机流程,Linux睡眠喚醒機制--Kernel態

一、對於休眠(suspend)的簡單介紹 在Linux中,休眠主要分三個主要的步驟: 1) 凍結用戶態進程和內核態任務2) 調用注冊的設備的suspend的回調函數, 順序是按照注冊順序3) 休眠核心設備和使CPU進入休眠態, 凍結進程是內核把進程列表中所有的進程的狀態都設置為停止,並且保存下…

strictmath_Java StrictMath log1p()方法与示例

strictmathStrictMath类log1p()方法 (StrictMath Class log1p() method) log1p() method is available in java.lang package. log1p()方法在java.lang包中可用。 log1p() method is used to return (the logarithm of the sum of the given argument and 1 like log(1d) in th…

第六章 抽象

第六章 抽象 自定义函数 要判断某个对象是否可调用,可使用内置函数callable import math x 1 y math.sqrt callable(x)#结果为:False callable(y)#结果为:True使用def(表示定义函数)语句,来定义函数 …

HTTP 状态代码

如果向您的服务器发出了某项请求要求显示您网站上的某个网页(例如,当用户通过浏览器访问您的网页或在 Googlebot 抓取该网页时),那么,您的服务器会返回 HTTP 状态代码以响应该请求。 此状态代码提供了有关请求状态的信…

TensorFlow的可训练变量和自动求导机制

文章目录一些概念、函数、用法TensorFlow实现一元线性回归TensorFlow实现多元线性回归一些概念、函数、用法 对象Variable 创建对象Variable: tf.Variable(initial_value,dtype)利用这个方法,默认整数为int32,浮点数为float32,…

linux samba安装失败,用aptitude安装samba失败

版本:You are using Ubuntu 10.04 LTS- the Lucid Lynx - released in April 2010 and supported until April 2013.root下执行aptitude install sambaReading package lists... DoneBuilding dependency treeReading state information... DoneReading extended st…

django第二个项目--使用模板做一个站点访问计数器

上一节讲述了django和第一个项目HelloWorld,这节我们讲述如何使用模板,并做一个简单的站点访问计数器。 1、建立模板 在myblog模块文件夹(即包含__init__.py的文件夹)下面新建一个文件夹templates,用于存放HTML模板,在…

strictmath_Java StrictMath log10()方法与示例

strictmathStrictMath类log10()方法 (StrictMath Class log10() method) log10() method is available in java.lang package. log10()方法在java.lang包中可用。 log10() method is used to return the logarithm of the given (base 10) of the given argument in the method…

30、深入理解计算机系统笔记,并发编程(concurrent)(2)

1、共享变量 1)线程存储模型 线程由内核自动调度,每个线程都有它自己的线程上下文(thread context),包括一个惟一的整数线程ID(Thread ID,TID),栈,栈指针,程序…

PostgreSQL在何处处理 sql查询之十三

继续: /*--------------------* grouping_planner* Perform planning steps related to grouping, aggregation, etc.* This primarily means adding top-level processing to the basic* query plan produced by query_planner.** tuple_fraction i…

【视觉项目】基于梯度的NCC模板匹配代码以及效果

文章目录流程分析工程代码【1】NCC代码【Ⅰ】sttPxGrdnt结构体【Ⅱ】sttTemplateModel模板结构体【Ⅲ】calcAccNCC计算ncc系数函数【Ⅳ】searchNcc NCC模板匹配函数【Ⅴ】searchSecondNcc 二级搜索:在某一特定点周围再以步进为1搜索【2】测试图转外轮廓【Ⅰ】孔洞填…

第七章 再谈抽象

第七章 再谈抽象 对象魔法 多态:可对不同类型的对象执行相同的操作,而这些操作就像“被施了魔法”一样能够正常运行。(即:无需知道对象的内部细节就可使用它)(无需知道对象所属的类(对象的类型)就能调用其…

c语言math乘法,JavaScript用Math.imul()方法进行整数相乘

1. 基本概念Math.imul()方法用于计算两个32位整数的乘积,它的结果也是32位的整数。JavaScript的Number类型同时包含了整数和浮点数,它没有专门的整型和浮点型。因此,Math.imul()方法能提供类似C语言的整数相乘的功能。我们将Math.imul()方法的…

java scanner_Java Scanner nextLong()方法与示例

java scanner扫描器类的nextLong()方法 (Scanner Class nextLong() method) Syntax: 句法: public long nextLong();public long nextLong(int rad);nextLong() method is available in java.util package. nextLong()方法在java.util包中可用。 nextLong() method…