[Hands On ML] 4. 训练模型

本文为《机器学习实战:基于Scikit-Learn和TensorFlow》的读书笔记。
中文翻译参考

1. 线性回归

如何得到模型的参数

1.1 正规方程求解

  • 先生成带噪声的线性数据
import numpy as np
import matplotlib.pyplot as plt
X = 2*np.random.rand(100,1)
y = 4+3*X+np.random.randn(100,1)
plt.plot(X,y,"b.")
plt.axis([0,2,0,15])

在这里插入图片描述

  • 采用矩阵解方程,得到参数
X_b = np.c_[np.ones((100,1)),X]
theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)
theta_best
array([[4.46927218],[2.71589368]])
  • 预测新的数据
X_new = np.array([[0],[2]])
X_new_b = np.c_[np.ones((2,1)),X_new]
y_pred = X_new_b.dot(theta_best)
y_pred
array([[4.46927218],[9.90105954]])
  • 画出模型回归线
plt.plot(X_new,y_pred,"r-")
plt.plot(X,y,"b.")
plt.axis([0,2,0,15])
plt.show()

在这里插入图片描述

  • 使用sklearn求解
from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
lin_reg.fit(X,y)
lin_reg.intercept_, lin_reg.coef_ # (array([4.15725481]), array([[2.97840411]]))
lin_reg.predict(X_new)
array([[ 4.15725481],[10.11406304]])

1.2 时间复杂度

求解过程需要矩阵求逆,矩阵求逆时间复杂度在 O(n2.4)O(n^{2.4})O(n2.4)O(n3)O(n^3)O(n3) 之间,n 为特征数

  • 特征个数很多的时候,这种计算方法将会非常慢

1.3 梯度下降

整体思路:通过的迭代来逐渐调整参数使得损失函数达到最小值

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
由上图右侧可见,一开始的方向跟梯度方向几乎垂直,走了弯路。

当我们使用梯度下降的时候,应该确保所有的特征有着相近的尺度范围

(例如:使用 Scikit Learn 的 StandardScaler类),否则它将需要很长的时间才能够收敛。

  • 参数越多,找到最佳参数的难度也越大

1.4 批量梯度下降

  • 会使用全部的训练数据
  • 在大数据集上会变得很慢
eta = 0.1 # 学习率
n_iter = 1000
m = 100
theta = np.random.randn(2,1)for iter in range(n_iter):gradients = 2/m*X_b.T.dot(X_b.dot(theta)-y)theta = theta - eta*gradients
theta
array([[4.33118102],[2.8597418 ]])
  • 不同的学习率下,学习情况对比
eta = 0.1 # 学习率
n_iter = 1000
m = 100
theta = np.random.randn(2,1)plt.figure(figsize=(8,6))
plt.ion()# 打开交互模式
plt.axis([0,2,0,15])
plt.rcParams["font.sans-serif"] = "SimHei"for iter in range(n_iter):plt.cla() # 清除原图像gradients = 2/m*X_b.T.dot(X_b.dot(theta)-y)theta = theta - eta*gradientsX_new = np.array([[0],[2]])X_new_b = np.c_[np.ones((2,1)),X_new]y_pred = X_new_b.dot(theta)plt.plot(X,y,"b.")plt.plot(X_new,y_pred,"r-")plt.title("学习率:{:.2f}".format(eta))plt.pause(0.1) # 暂停一会display.clear_output(wait=True)# 刷新图像
plt.ioff()# 关闭交互模式    
plt.show()
theta

求解过程动图请参看博文:matplotlib 绘制梯度下降求解过程

  • 实际使用时,设置较大的迭代次数,和容差,当梯度向量变得非常小的时候,小于容差时,认为收敛,结束迭代

1.5 随机梯度下降

每一步梯度计算只随机选取训练集中的一个样本。这使得算法变得非常快。

  • 随机梯度算法可以在大规模训练集上使用
  • 由于随机性,它到达最小值不是平缓下降,损失函数会忽高忽低,大体呈下降趋势
  • 迭代点不会停止在一个值上,会一直在这个值附近摆动,最后的参数还不错,但不是最优值

由于其随机性,它能跳过局部最优解,但同时它却不能达到最小值。

解决办法:逐渐降低学习率

  • 开始时,走大步,快速前进+跳过局部最优解
  • 然后逐步降低学习率,使算法到达全局最小值。 这个过程被称为模拟退火,因为它类似于熔融金属慢慢冷却的冶金学退火过程

决定每次迭代的学习率的函数称为 learning schedule

  • 如果学习速度降得过快,可能陷入局部最小值,或者迭代次数到了半路就停止了
  • 如果学习速度降得太慢,可能在最小值附近震荡,如果过早停止训练,只得到次优解
from sklearn.linear_model import SGDRegressor
# help(SGDRegressor)
sgd_reg = SGDRegressor(max_iter=100, penalty=None, eta0=0.1)
sgd_reg.fit(X,y.ravel())
sgd_reg.intercept_, sgd_reg.coef_
(array([3.71001759]), array([2.99883799]))

1.6 小批量梯度下降

每次迭代的时候,使用一个随机的小型实例集

2. 多项式回归

依然可以使用线性模型来拟合非线性数据

  • 一个简单的方法:对每个特征进行加权后作为新的特征
  • 然后训练一个线性模型基于这个扩展的特征集。 这种方法称为多项式回归。
m = 100
X = 6*np.random.rand(m,1)-3
y = 0.5*X**2 + X + 2 + np.random.randn(m,1)
plt.rcParams["axes.unicode_minus"] = False # 显示负号
plt.plot(X, y, "g.")

在这里插入图片描述

from sklearn.preprocessing import PolynomialFeatures
pf = PolynomialFeatures(degree=2, include_bias=False)
# help(PolynomialFeatures)
X_ploy = pf.fit_transform(X)
print(X[0])
print(X_ploy[0])
  • 对原始特征进行2阶多项式转换后,多出了 X2
[2.43507761]
[2.43507761 5.92960298]
  • 进行线性回归
lin_reg = LinearRegression()
lin_reg.fit(X_ploy, y)
lin_reg.intercept_, lin_reg.coef_
(array([1.95147614]), array([[1.0462516 , 0.48003845]]))
  • 绘出预测线
plt.plot(X, y, "g.")
x = np.linspace(-3.5, 3.5, 500)
print(x.shape)
y_pred = lin_reg.intercept_ + lin_reg.coef_[0][0]*x + lin_reg.coef_[0][1]*x**2
plt.plot(x, y_pred, 'r-')

在这里插入图片描述
注意,阶数变大时,特征的维度会急剧上升,不仅有 ana^nan,还有 an−1b,an−2b2a^{n-1}b,a^{n-2}b^2an1b,an2b2

如何确定选择多少阶:

1、交叉验证

  • 在训练集上表现良好,但泛化能力很差,过拟合
  • 如果这两方面都不好,欠拟合。可知模型是太复杂还是太简单

2、观察学习曲线

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_splitdef plot_learning_curves(model, X, y):X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(X_train)):model.fit(X_train[:m], y_train[:m])y_train_predict = model.predict(X_train[:m])y_val_predict = model.predict(X_val)train_errors.append(mean_squared_error(y_train_predict, y_train[:m]))val_errors.append(mean_squared_error(y_val_predict, y_val))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")lin_reg = LinearRegression()
plot_learning_curves(lin_reg, X, y)

在这里插入图片描述

  • 上图显示训练集和测试集在数据不断增加的情况下,曲线趋于稳定,同时误差都非常大,欠拟合
  • 欠拟合,添加样本是没用的,需要更复杂的模型或更好的特征

模型的泛化误差由三个不同误差的和决定:

  • 偏差:模型假设不贴合,高偏差的模型最容易出现欠拟合
  • 方差:模型对训练数据的微小变化较为敏感,多自由度的模型更容易有高的方差(如高阶多项式),会导致过拟合
  • 不可约误差:数据噪声,可进行数据清洗

3. 线性模型正则化

限制模型的自由度,降低过拟合

  • 岭(Ridge)回归 L2正则
  • Lasso 回归 L1正则
  • 弹性网络(ElasticNet),以上两者的混合,r=0, 就是L2,r=1,就是 L1
    J(θ)=MSE(θ)+rα∑i=1n∣θi∣+1−r2α∑i=1nθi2J(\theta)=M S E(\theta)+r \alpha \sum_{i=1}^{n}\left|\theta_{i}\right|+\frac{1-r}{2} \alpha \sum_{i=1}^{n} \theta_{i}^{2}J(θ)=MSE(θ)+rαi=1nθi+21rαi=1nθi2
from sklearn.linear_model import Ridge
ridge_reg = Ridge(alpha=1, solver="cholesky")
ridge_reg.fit(X, y)
ridge_reg.predict([[1.5]]) # array([[5.04581676]])from sklearn.linear_model import Lasso
lasso_reg = Lasso(alpha=0.1)
lasso_reg.fit(X, y)
lasso_reg.predict([[1.5]])  # array([5.00189893])from sklearn.linear_model import ElasticNet
elastic_net = ElasticNet(alpha=0.1, l1_ratio=0.5)
elastic_net.fit(X, y)
elastic_net.predict([[1.5]]) # array([4.99822842])

4. 早期停止法(Early Stopping)

在这里插入图片描述
验证集 误差达到最小值,并开始上升时(出现过拟合),结束迭代,回滚到之前的最小值处

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/475026.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

非常经典的C字符串函数的实现

1. strlen(),计算字符串长度 1 int strlen(const char string) 2 3 { 4 5 int i0; 6 7 while(string[i]) i; 8 9 return i; 10 11 } 2. strcpy(), 字符串拷贝. 1 char *strcpy(char *destination, const char *source) 2 3 { 4 5 while…

LeetCode 895. 最大频率栈(哈希+按频数存储)

文章目录1. 题目2. 解题1. 题目 实现 FreqStack,模拟类似栈的数据结构的操作的一个类。 FreqStack 有两个函数: push(int x),将整数 x 推入栈中。pop(),它移除并返回栈中出现最频繁的元素。 如果最频繁的元素不只一个&#xff…

七年级上册计算机重点知识点,初一上册数学重点知识点

为了方便大家更好的学习和复习初一上册数学课本内容,现将初一上册数学重要内容整理分享出来。有理数重点知识点(1)定义:由整数和分数组成的数。包括:正整数、0、负整数,正分数、负分数。可以写成两个整之比的形式。(2)数轴&#x…

友元关系

友元关系转载于:https://www.cnblogs.com/LoveFishC/archive/2012/08/01/3846663.html

猪八戒背媳妇用计算机弹出来,猪八戒背媳妇?杭城游泳馆爆笑一幕:浙大学霸果然机智!...

游泳有时候不光比速度,还得比机智。6月19日,2021年“三好杯”游泳比赛在浙大紫金港校区游泳馆举行。这场比赛浙大各院系(学院)共有23支代表队、近220名运动员参赛。游泳项目是浙大的招牌项目之一,前不久的全国大学生阳光组(普通生)游泳比赛中…

html5支持多线程,html5 多线程

html5 多线程版本:HTML5运行者 Worker 接口是Web Workers API 的一部分,代表一个后台任务,它容易被创建并向创建者发回消息。创建一个运行者只要简单的调用Worker()构造函数,指定一个脚本,在工作线程中执行。运行者能够…

magento tab(easy tables)标签应用

我介绍的主要是magento 1.7.0.2版本。 因为彼人刚接触magento一星期,了解有限,理解有误的地方 还请多多包含。 easy tables 在1.7.0.2版本中,默认是在app/design/frontend/default/modern/layout/template/catalog.xml; 让我们先找到这个文件…

LeetCode 269. 火星词典(拓扑排序)

文章目录1. 题目2. 解题1. 题目 现有一种使用字母的全新语言,这门语言的字母顺序与英语顺序不同。 假设,您并不知道其中字母之间的先后顺序。 但是,会收到词典中获得一个 不为空的 单词列表。 因为是从词典中获得的,所以该单词列…

南工大计算机学院,江南-欢迎访问湖北工业大学计算机学院官方网站

科研情况介绍(研究方向、研究课题、现正进行的科研项目)研究方向:计算机软件与理论。近3年来主要个人成果、参加学术团体及社会兼职情况:1、机械化定理证明研究综述.第一作者.软件学报. 20192、mJava到Micro-Dalvik虚拟机的编译验证.第一作者.电子学报20…

邻接表的两种实现(链表和数组模拟)

struct node {int v; //边的结束顶点 int w; //边的长度node* next; //指向以同一起点的下一条边的指针 }*first[N]; //first[u]指向以u为起始点的第一条边 void init() {memset(first,NULL,sizeof(first)); } void add(int u, int v, int w)//添加边 {node* p new node;p->…

LeetCode 301. 删除无效的括号(回溯)

文章目录1. 题目2. 解题1. 题目 删除最小数量的无效括号,使得输入的字符串有效,返回所有可能的结果。 说明: 输入可能包含了除 ( 和 ) 以外的字符。 示例 1: 输入: "()())()" 输出: ["()()()", "(())()"]示例 2: 输入:…

计算机程序专利实用新型,涉及计算机程序的实用新型专利保护的思考

随着信息技术的不断发展,与计算机程序相关的计算机技术以及通信技术渗透到各个领域,越来越多的专利涉及了与计算机程序相关的技术。那么,是否包含计算机程序的相关专利申请都不能被授予实用新型专利权呢?本文从一件复审案例出发&a…

javascript数组去重方法性能测试比较

昨天参加的一个前端面试,其中有一题数组去重,首先想到的是对象存键值的方法,代码如下 方法一:(简单存键值) Array.prototype.distinct1 function() {var i0,tmp{},thatthis.slice(0)this.length0;for(;i&l…

LeetCode 428. 序列化和反序列化 N 叉树(DFS)

文章目录1. 题目2. 解题1. 题目 序列化是指将一个数据结构转化为位序列的过程,因此可以将其存储在文件中或内存缓冲区中,以便稍后在相同或不同的计算机环境中恢复结构。 设计一个序列化和反序列化 N 叉树的算法。 一个 N 叉树是指每个节点都有不超过 N…

计算机进入休眠状态后,Win7电脑进入休眠状态后又自动重启该怎么处理

在使用win7系统的时候,有的小伙伴遇到了一个莫名其妙的问题:当电脑进入休眠状态后却突然自动重启了,那么这是怎么一回事呢?又该如何解决呢?别着急,接下来,小编就给大家分享一下Win7电脑进入休眠…

Syslistview32+Systreeview32系统操作动态链接库和实际的商业化

Syslistview32和Systreeview32 是两个极其常用的系统控件,一个是列表控件,一个是树形框,只要能随意操控这两个控件就能够从外部控制住大多应用到这两个控件的软件。 一开始是想要控制VS平台的列表框来操作自动进房间,但是苦于没有…

LeetCode 325. 和等于 k 的最长子数组长度(哈希表记录第一次出现的状态)

文章目录1. 题目2. 解题1. 题目 给定一个数组 nums 和一个目标值 k,找到和等于 k 的最长子数组长度。 如果不存在任意一个符合要求的子数组,则返回 0。 注意: nums 数组的总和是一定在 32 位有符号整数范围之内的。 示例 1: 输入: nums [1, -1, 5, -…

测试网上哪款软件最好,手机测试软件哪款好用?4款测试软件推荐

手机强不强测试上见真章!不服测个试呗!虽不能代表作手机的品质,但可以直观的反馈出手机硬件性能。通过专业的手机测试软件可以对手机硬件进行评分,了解手机每个硬件性能情况。鲁大师:《鲁大师》是一款支持Android、平板…

Android4开发入门经典 之 第七部分:数据存储

数据存储基本知识 Android系统提供了多种数据存储的方式,如下: 1:Shared Preferences:用来存储私有的、原始类型的、简单的数据,通常是Key-value对2:Internal Storage:在设备内部存储器中存储数…

LeetCode 218. 天际线问题(multiset优先队列)*

文章目录1. 题目2. 解题1. 题目 城市的天际线是从远处观看该城市中所有建筑物形成的轮廓的外部轮廓。 现在,假设您获得了城市风光照片(图A)上显示的所有建筑物的位置和高度,请编写一个程序以输出由这些建筑物形成的天际线&#x…