【深度学习】线性回归

Linear Regression

  • 一个例子
  • 线性回归
  • 机器学习中的表达
  • 评价函数好坏的度量:损失(Loss)
  • 损失函数(Loss function)
    • 哪个数据集的均方误差 (MSE) 高
  • 如何找出最优b和w?
  • 寻找最优b和w
  • 如何降低损失 (Reducing Loss)
  • 梯度下降法
    • 梯度
    • 计算梯度f^'^(x)i
    • 梯度下降法(gradient descent)
    • 学习速率(Learning rate)
    • 金发女孩原则(Goldilocks principle)
  • 如何训练模型
  • 习题

一个例子

相比凉爽的天气,蟋蟀在较为炎热的天气里鸣叫更频繁。
现有数据如下图,请预测鸣叫声与温度的关系。
在这里插入图片描述

线性回归

如图这种由点确定线的过程叫回归,既找出因变量和自变量之间的关系

y = mx + b

y指的是温度(摄氏度),即我们试图预测的值
m指的是直线的斜率
x指的是每分钟的鸣叫声次数,即输入特征的值
b指的是y轴截距

机器学习中的表达

y=𝑤𝑥+𝑏

x 特征(feature)
y 预测值(target)
w 权重(weight)
b 偏差(bias)

在这里插入图片描述
样本(samples) (x,y)

在这里插入图片描述

如何找出最优的函数 , 即找出一组最佳的w和b?
首先需要对函数的好坏进行评价

评价函数好坏的度量:损失(Loss)

反应模型预测(prediction)的准确程度。

如果模型(model)的预测完全准确,则损失为零
训练模型的目标是从所有样本中找到一组平均损失 “较小” 的权重和偏差。

在这里插入图片描述

显然,相较于左侧曲线图中的蓝线,右侧曲线图中的蓝线代表的是预测效果更好的模型

损失函数(Loss function)

损失函数是指汇总各样本损失的数学函数。

· 平方损失(又称为 L2 损失)
· 单个样本的平方损失= ( observation - prediction(x) )2 = ( y - y’ )2
· 均方误差 (mean-square error, MSE) 指的是样本的平均平方损失

在这里插入图片描述

哪个数据集的均方误差 (MSE) 高

以下两幅图显示的两个数据集,哪个数据集的均方误差 (MSE) 较高?

在这里插入图片描述

B图线上的 8 个样本产生的总损失为 0。不过,尽管只有两个点在线外,但这两个点的离线距离依然是左图中离群点的 2 倍。平方损失进一步加大差异,因此两个点的偏移量产生的损失是一个点的 4 倍。
· 根据损失函数的定义
· 对于A:MSE = (02 + 12 + 02 + 12 + 02 + 12 + 02 + 12 + 02 + 02)/10 = 0.4
· 对于B:MSE = (02 + 02 + 02 + 22 + 02 + 02 + 02 + 22 + 02 + 02)/10 = 0.8

→因此B的MSE较高。

如何找出最优b和w?

定义了损失函数,我们就可以评价任一函数的好坏,下一步如何找出最优b和w?
靠猜~
在这里插入图片描述

像猜价格游戏,参与者给出初始价钱,通过“高了”或“低了”的提示,逐渐接近正确的价格。
在这里插入图片描述

寻找最优b和w

1. 首先随机给出一组参数b=1,w=1 
2. 评价这组参数的好坏,例如用MSE 
3. 改变w和b的值
4. 转到步骤2,直到总体损失不再变化或变化极其缓慢为止,该模型已收敛

在这里插入图片描述

最后一个问题:如何改变w、b ???

如何降低损失 (Reducing Loss)

简化问题,以只有一个参数w为例,所产生的损失Loss与 w 的图形是凸形(convex)。如下所示:

在这里插入图片描述

只有一个最低点 → 即只存在一个斜率正好为 0 的位置。
损失函数取到最小值的地方。

但是,如何找到这一点呢?

  1. 为w选择一个起点↓ 这里选择了一个稍大于 0 的起点
    在这里插入图片描述

梯度下降法

梯度

· 梯度是偏导数的矢量;有方向和大小。

· 梯度即是某一点最大的方向导数,沿梯度方向函数有最大的变化率
(沿梯度方向函数增加,负梯度方向函数减少)

· 损失Loss相对于单个权重的梯度大小就等于导数f(x)i

二元函数的梯度:
在这里插入图片描述

计算梯度f(x)i

求导数,即切线的斜率,在这个例子中,是负的,因此负梯度是w的正方向。

在这里插入图片描述

梯度下降法(gradient descent)

w=w-lr*w_grad

lr (learning rate,学习速率)

在这里插入图片描述

然后,重复此过程,逐渐接近最低点。
在这里插入图片描述
⭐负梯度指向损失函数下降的方向

学习速率(Learning rate)

通常梯度下降法用梯度乘以学习速率(步长),以确定下一个点的位置:

w=w-lr*w_grad

例如,如果学习速率为 0.01,梯度大小为 2.5,则:
w=w-0.01*2.5

学习速率是机器学习算法中人为引入的,是用于调整机器学习算法的旋钮,这种称为超参数。
在这里插入图片描述

金发女孩原则(Goldilocks principle)

每个回归问题都存在一个“Goldilocks principle”学习速率,该值与损失函数的平坦程度相关。 例如,如果损失函数的梯度较小,则可以采用更大的学习速率,以补偿较小的梯度并获得更大的步长。

(西方有一个儿童故事叫 “ The Three Bears(金发女孩与三只小熊)”,迷路了的金发姑娘未经允许就进入了熊的房子,她尝了三只碗里的粥,试了三把椅子,又在三张床上躺了躺。最后发现不烫不冷的粥最可口,不大不小的椅子坐着最舒服,不高不矮的床上躺着最惬意。道理很简单,刚刚好就是最适合的,just the right amount,这样做选择的原则被称为 Goldilocks principle(金发女孩原则)。)

采取恰当学习速率,可以高效的到达最低点,如下图所示:

在这里插入图片描述
降低损失:优化学习速率-模拟动图

如何训练模型

  1. 定义一个函数的集合
  2. 给出评价函数好坏的方法
  3. 利用梯度下降法找到最佳函数

在这里插入图片描述

习题

在这里插入图片描述

import numpy as np
import matplotlib.pyplot as pltdef gradient_descent(x, y, theta, learning_rate, epochs):ws = []bs = []for i in range(epochs):y_pred = x.dot(theta)diff = y_pred - yloss = 0.5 * np.mean(diff ** 2)g = x.T.dot(diff)theta -= learning_rate * gws.append(theta[0][0])bs.append(theta[1][0])learning_rate = learning_rate_shedule(i + 1)print(f'第{i + 1}次梯度下降后,损失为{round(loss, 5)},w为{round(theta[0][0], 5)},b为{round(theta[1][0], 5)}')if loss < 0.001:breakreturn ws, bsdef learning_rate_shedule(t):return t0 / (t + t1)if __name__ == '__main__':xdata = np.array([8, 3, 9, 7, 16, 5, 3, 10, 4, 6])ydata = np.array([30, 21, 35, 27, 42, 24, 10, 38, 22, 25])x_data = np.array(xdata).reshape(-1, 1)y_data = np.array(ydata).reshape(-1, 1)X = np.concatenate([x_data, np.full_like(x_data, fill_value=1)], axis=1)theta = np.random.randn(2, 1)t0 = 1.5t1 = 1000ws, bs = gradient_descent(X, y_data, theta, learning_rate_shedule(0), 10000)ax = plt.subplot(3, 2, 1)bx = plt.subplot(3, 2, 2)cx = plt.subplot(3, 1, (2, 3))# 散点+预测线ax.scatter(x_data, y_data)x = np.linspace(x_data.min() - 1, x_data.max() + 1, 100)y = ws[-1] * x + bs[-1]ax.plot(x, y, color='red')ax.set_xlabel('x')ax.set_ylabel('y')#w,b变化x = np.arange(1, len(ws) + 1)bx.plot(x, ws, label='w')bx.plot(x, bs, label='b')bx.set_xlabel('epoch')bx.set_ylabel('change')bx.legend()#等高线def get_loss(w, b):return 0.5 * np.mean((y_data - (w * x_data + b)) ** 2)b_range = np.linspace(-100, 100, 100)w_range = np.linspace(-10, 10, 100)losses = np.zeros((len(b_range), len(w_range)))for i in range(len(b_range)):for j in range (len(w_range)):losses[i, j] = get_loss(w_range[j], b_range[i])cx.contour(b_range, w_range, losses, cmap='summer')cx.contourf(b_range, w_range, losses)cx.set_xlabel('b')cx.set_ylabel('w')plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/737345.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

我们都想要一个好的前景

大家好&#xff0c;我是记得诚。 有个读者向我咨询了一下他的就业问题。 问题&#xff1a; 大佬好&#xff0c;我咨询一下就业方向问题。我本身是大专毕业的&#xff0c;专业是应用电子技术&#xff0c;学了一部分硬件知识和软件。 毕业后第一份工作是去一家比较小的医疗机…

天猫魔盒解码报错

最近有个天猫魔盒&#xff08;Tmall,MagicBox_M17,MagicBox_M17&#xff09;有报错&#xff0c;报错信息如下&#xff1a; generic_decoder.cc, (line 98): Too many frames backed up in the decoder, dropping frame with timestamp 4219980314https://chromium.googlesourc…

工人安全绳穿戴识别系统---豌豆云

工人安全绳穿戴识别系统采用视频图像自动识别的形式&#xff0c;豌豆云工人安全绳穿戴识别系统通过安装在作业区域的监控摄像头。 一旦发现工人高空作业未佩戴安全带的情况&#xff0c;系统可以立即发出告警&#xff0c;相关人员可以迅速采取措施&#xff0c;防止事故的发生。…

记一次多线程写入文件出现IOException:Stream Closed的问题

背景 网关在解析1000个05文件&#xff08;txt&#xff09;写入到SFTP文件时&#xff0c;是每次读取1000 * 5条数据&#xff0c;然后每1000笔数据创建一个线程逐条数据进行字段数值映射转换&#xff0c;一共创建5个线程扔到线程池进行处理。每条数据解析完都会将数据写入到SFTP的…

绪论——算法设计原则【数据科学与工程算法基础】

一、题记 最近情绪不太稳定&#xff0c;些许烦躁&#xff0c;也就一直没践行前边说的“学习记录”的想法。现在开始做了&#xff0c;春华易逝&#xff0c;正当时&#xff0c;有想法就去做&#xff0c;踌躇懊悔是这个年纪最不该做的事。 二、前言 之前说了分块做这个系列&#x…

101. Go单测系列1---使用monkey打桩

本文将介绍如何在单元测试中使用monkey进行打桩。 monkey支持为任意函数及方法进行打桩。 monkey介绍 monkey是一个Go单元测试中十分常用的打桩工具&#xff0c;它在运行时通过汇编语言重写可执行文件&#xff0c;将目标函数或方法的实现跳转到桩实现&#xff0c;其原理类似…

我用 Python 做了个小仙女代码蹦迪视频

前言 最近在B站上看到一个漂亮的仙女姐姐跳舞视频&#xff0c;循环看了亿遍又亿遍&#xff0c;久久不能离开&#xff01; 看着仙紫小姐姐的蹦迪视频&#xff0c;除了一键三连还能做什么&#xff1f;突发奇想&#xff0c;能不能把小仙女的蹦迪视频转成代码舞呢&#xff1f; 说…

uniapp引入jQuery

安装 npm install jquery --saveoryarn add jquery引入 import Vue from vue import jquery from "jquery"; Vue.prototype.$ jquery;<template><view>abc</view> </template><script>export default {data() {return {}}} </scr…

Vue3全家桶 - VueRouter - 【1】快速使用(创建路由模块 + 规定路由模式 + 使用路由规则 + RouterView-RouterLink)

VueRouter Vue-Router官网&#xff1b;vue-router 是 vue.js 官方给出的路由解决方案&#xff0c;能够轻松的管理 SPA 项目中组件的切换&#xff1b;安装&#xff1a;yarn add vue-router4&#xff1b; 快速使用 1.1 创建路由模块 在项目中的 src 文件夹中创建一个 router …

【智慧公寓】东胜物联嵌入式硬件解决方案,为智慧公寓解决方案商降本增效,更快实现产品规模化生产

方案背景 东胜物联本次服务的客户是一家专注于提供智慧公寓解决方案的欧洲企业&#xff0c;该公司旨在为用户提供智能&#xff0c;便捷&#xff0c;安全的生活体验。其解决方案涵盖智慧公寓控制、自动化、能源管理和智能建筑&#xff0c;它的使命是通过复杂的控制系统使用户能…

【OpenGL实现 03】纹理贴图原理和实现

目录 一、说明二、纹理贴图原理2.1 纹理融合原理2.2 UV坐标原理 三、生成纹理对象3.1 需要在VAO上绑定纹理坐标3.2 纹理传递3.3 纹理buffer生成 四、代码实现&#xff1a;五、着色器4.1 片段4.2 顶点 五、后记 一、说明 本篇叙述在画出图元的时候&#xff0c;如何贴图纹理图片…

局域网管理工具

每个组织的业务运营方法都是独一无二的&#xff0c;其网络基础设施也是如此&#xff0c;由于随着超融合基础设施等新计算技术的发展&#xff0c;局域网变得越来越复杂&#xff0c;因此局域网管理也应该如此&#xff0c;组织需要量身定制的局域网管理解决方案&#xff0c;这些解…

php 面试题目

当涉及到PHP排序的面试题目时&#xff0c;面试官通常会希望了解你对PHP内置排序函数的理解&#xff0c;以及你如何处理复杂的排序需求。以下是一些可能的PHP排序面试题目&#xff1a; 解释PHP中sort(), rsort(), asort(), arsort(), ksort(), 和 krsort()等函数的区别和用途。…

probiller怎么订阅

很多小伙伴想订阅probiller&#xff0c;但是不知道怎么订阅&#xff0c;这里我使用的是556150的卡订阅的&#xff0c;亲测~~ 所以有想订阅的小伙伴可以点击获取5561卡片&#xff0c;此卡0年费、0月费 下面请看订阅记录 开卡步骤请看图 卡信息在卡中心&#xff0c;cvc安全码 …

(五)关系数据库标准语言SQL

注&#xff1a;课堂讲义使用的数据库 5.1利用SQL语言建立数据库 5.1.1 create Database 5.1.2 create schema...authorization... 创建数据库和创建模式的区别&#xff1a; 数据库是架构的集合&#xff0c;架构是表的集合。但在MySQL中&#xff0c;他们使用的方式是相同的。 …

网络模型的保存和读取

1. 网络保存 import torch import torchvision from torch import nnvgg16 torchvision.models.vgg16(pretrainedFalse)#保存方式1 不仅保存了网络模型结构也保存了参数 torch.save(vgg16,vgg16_method1.pth)#保存方式2 获取模型状态&#xff08;参数&#xff09;并且保存…

深入探索HAProxy:高性能负载均衡器的奥秘

目录 引言 一、HAProxy基础知识 &#xff08;一&#xff09;HAProxy概述 &#xff08;二&#xff09;核心特性 &#xff08;三&#xff09;支持调度算法 二、安装haproxy &#xff08;一&#xff09;下载源码包 &#xff08;二&#xff09;解决依赖环境 &#xff08;三…

Linux系统安装APITable智能表格并结合内网穿透实现公网访问本地服务

文章目录 前言1. 部署APITable2. cpolar的安装和注册3. 配置APITable公网访问地址4. 固定APITable公网地址 前言 vika维格表作为新一代数据生产力平台&#xff0c;是一款面向 API 的智能多维表格。它将复杂的可视化数据库、电子表格、实时在线协同、低代码开发技术四合为一&am…

TextView实现打印机效果 ,字符串逐字显示

public class FadeInTextView extends TextView { private Rect textRect new Rect(); private StringBuffer stringBuffer new StringBuffer(); private String[] arr; private int textCount; private int currentIndex -1; /** * 每个字出现的时间 */ priv…

力扣:118. 杨辉三角

力扣&#xff1a;118. 杨辉三角 描述 给定一个非负整数 numRows&#xff0c;生成「杨辉三角」的前 numRows 行。 在「杨辉三角」中&#xff0c;每个数是它左上方和右上方的数的和。 示例 1: 输入: numRows 5 输出: [[1],[1,1],[1,2,1],[1,3,3,1],[1,4,6,4,1]] 示例 2: 输…