[Hands On ML] 4. 训练模型

本文为《机器学习实战:基于Scikit-Learn和TensorFlow》的读书笔记。
中文翻译参考

1. 线性回归

如何得到模型的参数

1.1 正规方程求解

  • 先生成带噪声的线性数据
import numpy as np
import matplotlib.pyplot as plt
X = 2*np.random.rand(100,1)
y = 4+3*X+np.random.randn(100,1)
plt.plot(X,y,"b.")
plt.axis([0,2,0,15])

在这里插入图片描述

  • 采用矩阵解方程,得到参数
X_b = np.c_[np.ones((100,1)),X]
theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)
theta_best
array([[4.46927218],[2.71589368]])
  • 预测新的数据
X_new = np.array([[0],[2]])
X_new_b = np.c_[np.ones((2,1)),X_new]
y_pred = X_new_b.dot(theta_best)
y_pred
array([[4.46927218],[9.90105954]])
  • 画出模型回归线
plt.plot(X_new,y_pred,"r-")
plt.plot(X,y,"b.")
plt.axis([0,2,0,15])
plt.show()

在这里插入图片描述

  • 使用sklearn求解
from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
lin_reg.fit(X,y)
lin_reg.intercept_, lin_reg.coef_ # (array([4.15725481]), array([[2.97840411]]))
lin_reg.predict(X_new)
array([[ 4.15725481],[10.11406304]])

1.2 时间复杂度

求解过程需要矩阵求逆,矩阵求逆时间复杂度在 O(n2.4)O(n^{2.4})O(n2.4)O(n3)O(n^3)O(n3) 之间,n 为特征数

  • 特征个数很多的时候,这种计算方法将会非常慢

1.3 梯度下降

整体思路:通过的迭代来逐渐调整参数使得损失函数达到最小值

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
由上图右侧可见,一开始的方向跟梯度方向几乎垂直,走了弯路。

当我们使用梯度下降的时候,应该确保所有的特征有着相近的尺度范围

(例如:使用 Scikit Learn 的 StandardScaler类),否则它将需要很长的时间才能够收敛。

  • 参数越多,找到最佳参数的难度也越大

1.4 批量梯度下降

  • 会使用全部的训练数据
  • 在大数据集上会变得很慢
eta = 0.1 # 学习率
n_iter = 1000
m = 100
theta = np.random.randn(2,1)for iter in range(n_iter):gradients = 2/m*X_b.T.dot(X_b.dot(theta)-y)theta = theta - eta*gradients
theta
array([[4.33118102],[2.8597418 ]])
  • 不同的学习率下,学习情况对比
eta = 0.1 # 学习率
n_iter = 1000
m = 100
theta = np.random.randn(2,1)plt.figure(figsize=(8,6))
plt.ion()# 打开交互模式
plt.axis([0,2,0,15])
plt.rcParams["font.sans-serif"] = "SimHei"for iter in range(n_iter):plt.cla() # 清除原图像gradients = 2/m*X_b.T.dot(X_b.dot(theta)-y)theta = theta - eta*gradientsX_new = np.array([[0],[2]])X_new_b = np.c_[np.ones((2,1)),X_new]y_pred = X_new_b.dot(theta)plt.plot(X,y,"b.")plt.plot(X_new,y_pred,"r-")plt.title("学习率:{:.2f}".format(eta))plt.pause(0.1) # 暂停一会display.clear_output(wait=True)# 刷新图像
plt.ioff()# 关闭交互模式    
plt.show()
theta

求解过程动图请参看博文:matplotlib 绘制梯度下降求解过程

  • 实际使用时,设置较大的迭代次数,和容差,当梯度向量变得非常小的时候,小于容差时,认为收敛,结束迭代

1.5 随机梯度下降

每一步梯度计算只随机选取训练集中的一个样本。这使得算法变得非常快。

  • 随机梯度算法可以在大规模训练集上使用
  • 由于随机性,它到达最小值不是平缓下降,损失函数会忽高忽低,大体呈下降趋势
  • 迭代点不会停止在一个值上,会一直在这个值附近摆动,最后的参数还不错,但不是最优值

由于其随机性,它能跳过局部最优解,但同时它却不能达到最小值。

解决办法:逐渐降低学习率

  • 开始时,走大步,快速前进+跳过局部最优解
  • 然后逐步降低学习率,使算法到达全局最小值。 这个过程被称为模拟退火,因为它类似于熔融金属慢慢冷却的冶金学退火过程

决定每次迭代的学习率的函数称为 learning schedule

  • 如果学习速度降得过快,可能陷入局部最小值,或者迭代次数到了半路就停止了
  • 如果学习速度降得太慢,可能在最小值附近震荡,如果过早停止训练,只得到次优解
from sklearn.linear_model import SGDRegressor
# help(SGDRegressor)
sgd_reg = SGDRegressor(max_iter=100, penalty=None, eta0=0.1)
sgd_reg.fit(X,y.ravel())
sgd_reg.intercept_, sgd_reg.coef_
(array([3.71001759]), array([2.99883799]))

1.6 小批量梯度下降

每次迭代的时候,使用一个随机的小型实例集

2. 多项式回归

依然可以使用线性模型来拟合非线性数据

  • 一个简单的方法:对每个特征进行加权后作为新的特征
  • 然后训练一个线性模型基于这个扩展的特征集。 这种方法称为多项式回归。
m = 100
X = 6*np.random.rand(m,1)-3
y = 0.5*X**2 + X + 2 + np.random.randn(m,1)
plt.rcParams["axes.unicode_minus"] = False # 显示负号
plt.plot(X, y, "g.")

在这里插入图片描述

from sklearn.preprocessing import PolynomialFeatures
pf = PolynomialFeatures(degree=2, include_bias=False)
# help(PolynomialFeatures)
X_ploy = pf.fit_transform(X)
print(X[0])
print(X_ploy[0])
  • 对原始特征进行2阶多项式转换后,多出了 X2
[2.43507761]
[2.43507761 5.92960298]
  • 进行线性回归
lin_reg = LinearRegression()
lin_reg.fit(X_ploy, y)
lin_reg.intercept_, lin_reg.coef_
(array([1.95147614]), array([[1.0462516 , 0.48003845]]))
  • 绘出预测线
plt.plot(X, y, "g.")
x = np.linspace(-3.5, 3.5, 500)
print(x.shape)
y_pred = lin_reg.intercept_ + lin_reg.coef_[0][0]*x + lin_reg.coef_[0][1]*x**2
plt.plot(x, y_pred, 'r-')

在这里插入图片描述
注意,阶数变大时,特征的维度会急剧上升,不仅有 ana^nan,还有 an−1b,an−2b2a^{n-1}b,a^{n-2}b^2an1b,an2b2

如何确定选择多少阶:

1、交叉验证

  • 在训练集上表现良好,但泛化能力很差,过拟合
  • 如果这两方面都不好,欠拟合。可知模型是太复杂还是太简单

2、观察学习曲线

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_splitdef plot_learning_curves(model, X, y):X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(X_train)):model.fit(X_train[:m], y_train[:m])y_train_predict = model.predict(X_train[:m])y_val_predict = model.predict(X_val)train_errors.append(mean_squared_error(y_train_predict, y_train[:m]))val_errors.append(mean_squared_error(y_val_predict, y_val))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")lin_reg = LinearRegression()
plot_learning_curves(lin_reg, X, y)

在这里插入图片描述

  • 上图显示训练集和测试集在数据不断增加的情况下,曲线趋于稳定,同时误差都非常大,欠拟合
  • 欠拟合,添加样本是没用的,需要更复杂的模型或更好的特征

模型的泛化误差由三个不同误差的和决定:

  • 偏差:模型假设不贴合,高偏差的模型最容易出现欠拟合
  • 方差:模型对训练数据的微小变化较为敏感,多自由度的模型更容易有高的方差(如高阶多项式),会导致过拟合
  • 不可约误差:数据噪声,可进行数据清洗

3. 线性模型正则化

限制模型的自由度,降低过拟合

  • 岭(Ridge)回归 L2正则
  • Lasso 回归 L1正则
  • 弹性网络(ElasticNet),以上两者的混合,r=0, 就是L2,r=1,就是 L1
    J(θ)=MSE(θ)+rα∑i=1n∣θi∣+1−r2α∑i=1nθi2J(\theta)=M S E(\theta)+r \alpha \sum_{i=1}^{n}\left|\theta_{i}\right|+\frac{1-r}{2} \alpha \sum_{i=1}^{n} \theta_{i}^{2}J(θ)=MSE(θ)+rαi=1nθi+21rαi=1nθi2
from sklearn.linear_model import Ridge
ridge_reg = Ridge(alpha=1, solver="cholesky")
ridge_reg.fit(X, y)
ridge_reg.predict([[1.5]]) # array([[5.04581676]])from sklearn.linear_model import Lasso
lasso_reg = Lasso(alpha=0.1)
lasso_reg.fit(X, y)
lasso_reg.predict([[1.5]])  # array([5.00189893])from sklearn.linear_model import ElasticNet
elastic_net = ElasticNet(alpha=0.1, l1_ratio=0.5)
elastic_net.fit(X, y)
elastic_net.predict([[1.5]]) # array([4.99822842])

4. 早期停止法(Early Stopping)

在这里插入图片描述
验证集 误差达到最小值,并开始上升时(出现过拟合),结束迭代,回滚到之前的最小值处

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/475026.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode 895. 最大频率栈(哈希+按频数存储)

文章目录1. 题目2. 解题1. 题目 实现 FreqStack,模拟类似栈的数据结构的操作的一个类。 FreqStack 有两个函数: push(int x),将整数 x 推入栈中。pop(),它移除并返回栈中出现最频繁的元素。 如果最频繁的元素不只一个&#xff…

猪八戒背媳妇用计算机弹出来,猪八戒背媳妇?杭城游泳馆爆笑一幕:浙大学霸果然机智!...

游泳有时候不光比速度,还得比机智。6月19日,2021年“三好杯”游泳比赛在浙大紫金港校区游泳馆举行。这场比赛浙大各院系(学院)共有23支代表队、近220名运动员参赛。游泳项目是浙大的招牌项目之一,前不久的全国大学生阳光组(普通生)游泳比赛中…

html5支持多线程,html5 多线程

html5 多线程版本:HTML5运行者 Worker 接口是Web Workers API 的一部分,代表一个后台任务,它容易被创建并向创建者发回消息。创建一个运行者只要简单的调用Worker()构造函数,指定一个脚本,在工作线程中执行。运行者能够…

LeetCode 269. 火星词典(拓扑排序)

文章目录1. 题目2. 解题1. 题目 现有一种使用字母的全新语言,这门语言的字母顺序与英语顺序不同。 假设,您并不知道其中字母之间的先后顺序。 但是,会收到词典中获得一个 不为空的 单词列表。 因为是从词典中获得的,所以该单词列…

LeetCode 301. 删除无效的括号(回溯)

文章目录1. 题目2. 解题1. 题目 删除最小数量的无效括号,使得输入的字符串有效,返回所有可能的结果。 说明: 输入可能包含了除 ( 和 ) 以外的字符。 示例 1: 输入: "()())()" 输出: ["()()()", "(())()"]示例 2: 输入:…

javascript数组去重方法性能测试比较

昨天参加的一个前端面试,其中有一题数组去重,首先想到的是对象存键值的方法,代码如下 方法一:(简单存键值) Array.prototype.distinct1 function() {var i0,tmp{},thatthis.slice(0)this.length0;for(;i&l…

LeetCode 428. 序列化和反序列化 N 叉树(DFS)

文章目录1. 题目2. 解题1. 题目 序列化是指将一个数据结构转化为位序列的过程,因此可以将其存储在文件中或内存缓冲区中,以便稍后在相同或不同的计算机环境中恢复结构。 设计一个序列化和反序列化 N 叉树的算法。 一个 N 叉树是指每个节点都有不超过 N…

计算机进入休眠状态后,Win7电脑进入休眠状态后又自动重启该怎么处理

在使用win7系统的时候,有的小伙伴遇到了一个莫名其妙的问题:当电脑进入休眠状态后却突然自动重启了,那么这是怎么一回事呢?又该如何解决呢?别着急,接下来,小编就给大家分享一下Win7电脑进入休眠…

LeetCode 325. 和等于 k 的最长子数组长度(哈希表记录第一次出现的状态)

文章目录1. 题目2. 解题1. 题目 给定一个数组 nums 和一个目标值 k,找到和等于 k 的最长子数组长度。 如果不存在任意一个符合要求的子数组,则返回 0。 注意: nums 数组的总和是一定在 32 位有符号整数范围之内的。 示例 1: 输入: nums [1, -1, 5, -…

测试网上哪款软件最好,手机测试软件哪款好用?4款测试软件推荐

手机强不强测试上见真章!不服测个试呗!虽不能代表作手机的品质,但可以直观的反馈出手机硬件性能。通过专业的手机测试软件可以对手机硬件进行评分,了解手机每个硬件性能情况。鲁大师:《鲁大师》是一款支持Android、平板…

LeetCode 218. 天际线问题(multiset优先队列)*

文章目录1. 题目2. 解题1. 题目 城市的天际线是从远处观看该城市中所有建筑物形成的轮廓的外部轮廓。 现在,假设您获得了城市风光照片(图A)上显示的所有建筑物的位置和高度,请编写一个程序以输出由这些建筑物形成的天际线&#x…

LeetCode 277. 搜寻名人(思维题)

文章目录1. 题目2. 解题2.1 暴力解2.2 高效解1. 题目 假设你是一个专业的狗仔,参加了一个 n 人派对,其中每个人被从 0 到 n - 1 标号。 在这个派对人群当中可能存在一位 “名人”。 所谓 “名人” 的定义是:其他所有 n - 1 个人都认识他/她&…

最近很火的计算机歌曲,抖音日活跃用户数超4亿 2019年度最火音乐竟是它

抖音今日发布《2019抖音数据报告》(以下简称报告),报告显示,截至2020年1月5日,抖音日活跃用户数超过4亿。根据报告,抖音上不同年龄段用户最爱拍摄的内容不尽相同,00后喜欢拍摄二次元相关视频,90后用户喜欢拍…

LeetCode 432. 全 O(1) 的数据结构(设计题)*

文章目录1. 题目2. 解题1. 题目 请你实现一个数据结构支持以下操作: Inc(key) - 插入一个新的值为 1 的 key。 或者使一个存在的 key 增加一,保证 key 不为空字符串。Dec(key) - 如果这个 key 的值是 1,那么把他从数据结构中移除掉。 否则使…

vs2010 rdlc 报表及报表控件

有个winfrom项目要使用报表,数据来源于自定义类(model),从网上找了好多教程,都是说如何拖控件,如何设值之类的。没有我想要的效果。 我想要的效果:将rdlc文件放到Debug目录下,以便一…

设置 NSZombieEnabled 定位 EXC_BAD_ACCESS 错误

http://unmi.cc/nszombieenabled-locate-exc_bad_access-error, 来自 隔叶黄莺 Unmi Blog 我们做 iOS 程序开发时经常用遇到 EXC_BAD_ACCESS 错误导致 Crash,出现这种错误时一般 Xcode 不会给我们太多的信息来定位错误来源,只是在应用 Delegate 上留下像…

LeetCode 785. 判断二分图(染色法)

文章目录1. 题目2. 解题1. 题目 给定一个无向图graph,当这个图为二分图时返回true。 如果我们能将一个图的节点集合分割成两个独立的子集A和B,并使图中的每一条边的两个节点一个来自A集合,一个来自B集合,我们就将这个图称为二分…

css检测,CSS检测工具 CSS Lint简介

Nicholas C. Zakas最近发布了CSS Lint,旨在检测CSS代码中存在的各种问题,从而写出更高效的CSS。CSS Lint现有的一些规则:修复解析错误(Parsing errors should be fixed)避免使用多类选择符(Dont use adjoining classes)IE6以及更古老的浏览器…

LeetCode 1506. Find Root of N-Ary Tree(异或)

文章目录1. 题目2. 解题1. 题目 Given all the nodes of an N-ary tree as an array Node[] tree where each node has a unique value. Find and return the root of the N-ary tree. Follow up: Could you solve this problem in constant space complexity with a linea…

客户端与服务器之间的文件传输,客户端与服务器的文件传输

客户端与服务器的文件传输 内容精选换一换使用FTP上传文件时,写入失败,文件传输失败。该文档适用于Windows系统上的FTP服务。FTP服务端在NAT环境下,客户端需使用被动模式连接服务端。在这种情况下,服务端的IP地址无法从路由器外部…