[PyTorch][chapter 2][李宏毅深度学习-Regression]

前言:

     Regression 模型主要用于股票预测,自动驾驶,推荐系统等领域.

这个模型的输出是一个scalar。这里主要以下一个线性模型为基础

它是神经网络的基础模块,

目录:

  1.    总体流程
  2.    常见问题
  3.    Numpy 例子
  4.    PyTorch 例子

一    总体流程

       1 : 建模(model)

             y=w^Tx+b

                 =\sum_{i}w_ix_i+b 

           其中: w: weight权重系数 向量

                       b:bias 偏置系数  scalar

                       x: 输入数据 向量

       2: TrainData

             收集N个样本,其中

               \hat{y}: 标签 target

              x: [n,1]的列向量

               x^i,\hat{y}^{i}: 一组训练数据

            数据集如下:

              \begin{bmatrix} (x^1,\hat{y}^1)\\ (x^2,\hat{y}^2) \\ .... \\ (x^N,\hat{y}^N) \\ \end{bmatrix}

     

     3: 损失函数

                         用来度量Goodness of Function,一般用MSE loss

                        L(w.b)=\sum_{n=1}^{N}(\hat{y}^{n}-(w^Tx^n+b))^2

                        我们训练的目标是找到最优的参数w,b使得loss 最小

                         w^{*},b^{*}=argminL_{w,b}(w,b)

       4: 训练

                训练的时候,主要通过梯度下降方法 Gradient Descent


二  常见问题

          2.1  Gradient Discent  局部极小值问题

    当损失函数是非凸函数,可能有多个局部极小点.如下图

                 

 如上图,当梯度为0 的时候,此刻参数就无法更新了,导致

loss 陷入了局部极小值点,无法更新到全局最小值点.当多维参数的时候

在saddle point 陷入到局部极小值点.

  解决方案:

Adagrad: 自适应学习率,会根据之前的梯度信息自动调整每个参数的学习率。
RMSprop: 自适应学习率,根据历史梯度信息,采用了指数加权移动平均,
stochastic GD: 其每次都只使用一个样本进行参数更新,这样更新次数大大增加也就不容易陷入局部最优
Mini-Batch GD: 每次更新使用一小批样本进行参数更新
每次更新前加入部分上一次的梯度量,这样整个梯度方向就不容易过于随机。
一些常见情况时,如上次梯度过大,导致进入局部最小点时,下一次更新能很容易借助上次的大梯度跳出局部最小点。

    

   2.2 过拟合问题

        如上图, 五个拟合函数。函数5最复杂,在Training 数据集上的loss 最小

但是在Testing 数据集上loss 最大,这种称为过拟合.

         

      解决方案:

             在Training Data 收集更多的数据集

             L2 正规化

         


三  Numpy 代码例子

      

# -*- coding: utf-8 -*-
"""
Created on Tue Nov 28 16:19:48 2023@author: chengxf2
"""# -*- coding: utf-8 -*-
"""
Created on Tue Nov 28 14:47:32 2023@author: chengxf2
"""import numpy as np
import matplotlib.pyplot as pltdef load_trainData():x_data =[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]y_data =[3, 5, 7, 9, 11, 13, 15, 17, 19, 21]#y_data = b+w*xdata return x_data, y_datadef draw(x_data,y_data,b_history, w_history):#plot the figure#x = np.arange(-200, -100, 1.0) #biasx = np.arange(-5, 5, 0.1) #biasy = np.arange(-10, 10, 0.1) #weightZ = np.zeros((len(y),len(x))) #ndarrayX, Y = np.meshgrid(x,y)print("\n-----z ---",Z.shape)for i in range(len(x)):#100for j in range(len(y)):#40w = x[i]b = y[j]#print(j,i)Z[j,i]=0for n in range(len(x_data)):Z[j,i]=Z[j,i]+(y_data[n] -b-w*x_data[n])**2Z[j,i]== Z[j,i]/len(x_data)#绘制等高线,cmp = plt.get_cmap('jet')plt.contourf(x,y,Z, 50, alpha = 0.5, cmap='rainbow')plt.plot([1.0], [2.0], 'x', ms=10, markeredgewidth =30, color ='red')plt.plot(w_history, b_history, 'o-', ms=3, lw =5, color ='green')plt.xlim(-10, 10)plt.ylim(-10, 10)plt.xlabel(r'$b$',fontsize =16)plt.ylabel(r'$w$',fontsize =16)plt.show()class regression():def learn(self,x_data,y_data):N = len(x_data)b_history = []w_history = []for i in range(self.iteration): w_grad = 0.0b_grad = 0.0for n in range(N):w_grad = w_grad +2.0*(self.b+self.w*x_data[n]-y_data[n])*x_data[n]b_grad = b_grad +2.0*(self.b+self.w*x_data[n]-y_data[n])*1.0self.b = self.b -self.lr*b_gradself.w = self.w- self.lr*w_gradb_history.append(self.b)w_history.append(self.w)print("\n 偏置: ",round(self.b,1), "\n 权重系数 ",round(self.w,1))return w_history,b_historydef initPara(self):self.b = -120 #initial bself.w = -4.0 #initial wself.lr = 1e-3 #learning rateself.iteration = 10000 #iteration Numberdef __init__(self):self.b = 0 self.w = 0self.lr = 0self.iteration = 0if __name__ == "__main__":model = regression()x_data,y_data = load_trainData()model.initPara()w_history,b_history = model.learn(x_data, y_data)draw(x_data,y_data,b_history, w_history)

四 PyTorch API 例子

loss 情况

torch.optim.SGD是PyTorch中实现的Stochastic Gradient Descent(SGD)优化器,用于更新神经网络中的参数,以最小化损失函数,从而提高模型的精度。它的一些重要参数如下:

- lr:学习率(learning rate),控制每次参数更新的步长。默认值为0.001。
- momentum:动量(momentum),加速SGD在相关方向上前进,抑制震荡。常常取值为0.9。若设为0,则为经典的SGD算法。
- dampening:阻尼(dampening),用于防止动量的发散。默认值为0。
- weight_decay:权重衰减(weight decay),也称权重衰减(weight regularization),用于防止过拟合。默认值为0。
- nesterov:采用Nesterov加速梯度法(Nesterov accelerated gradient,NAG)。默认值为False
 

  

# -*- coding: utf-8 -*-
"""
Created on Tue Nov 28 16:27:30 2023@author: chengxf2
"""import torch
import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary#加载PyTorch数据集
def load_data():x_data =[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]y_data =[3, 5, 7, 9, 11, 13, 15, 17, 19, 21]x = torch.Tensor(x_data)   #将列表a转成tensor类型one = torch.ones_like(x)Y = torch.Tensor(y_data)  X = torch.stack([x,one],dim=1)return  X,Y.view(10,1)#建模
class NN(torch.nn.Module):def __init__(self, xDim, yDim):super(NN,self).__init__()self.predict = torch.nn.Linear(xDim, yDim,bias=False)def forward(self, x):y = self.predict(x)return y#训练
def learn(X,Y):#X 数据集 Y 标签集model = NN(2,1)optimizer = torch.optim.SGD(model.parameters(),lr=0.001)  #优化方法为随机梯度下降loss_f = torch.nn.MSELoss()lossHistory =[]for i in range(10000):predict = model(X)loss = loss_f(predict, Y)#print("\n loss ",loss.item())lossHistory.append(loss.item())optimizer.zero_grad()loss.backward()optimizer.step()#y_predict = y_predict.detach().numpy()N = len(lossHistory)x = np.arange(0, N)plt.plot(x, lossHistory)summary(model, input_size=(1,2), batch_size=-1)for param in model.parameters():print("\n 参数 ",param)if __name__ == "__main__":x,y = load_data() learn(x,y )

参考:

sam+yolov8 


1 [2023] 李宏毅机器学习完整课程 43
      https://www.bilibili.com/video/BV1NX4y1r7nP/?spm_id_from=333.337.search-card.all.click&vd_source=a624c4a1aea4b867c580cc82f03c1745

2 【2022】最新 李宏毅大佬的深度学习与机器学  P90

    https://www.bilibili.com/video/BV1J94y1f7u5/?spm_id_from=333.337.search-card.all.click&vd_source=a624c4a1aea4b867c580cc82f03c1745


3 [2020 ]李宏毅机器学习深度学习(完整版)国语  P119

     https://www.bilibili.com/video/BV1JE411g7XF/?spm_id_from=333.337.search-card.all.click&vd_source=a624c4a1aea4b867c580cc82f03c1745


4 [2017 ]李宏毅机器学习 P40

      https://www.bilibili.com/video/BV13x411v7US/?spm_id_from=333.337.search-card.all.click&vd_source=a624c4a1aea4b867c580cc82f03c1745


5 李宏毅: 强化学习 P11

     https://www.bilibili.com/video/BV1XP4y1d7Bk/?spm_id_from=333.337.search-card.all.click&vd_source=a624c4a1aea4b867c580cc82f03c1745
————————————————
版权声明:本文为CSDN博主「明朝百晓生」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/chengxf2/article/details/134643845

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/181535.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

智慧工地解决方案,Spring Cloud智慧工地项目平台源码

智慧工地一体化信息管理平台源码,微服务架构JavaSpring Cloud UniApp MySql 智慧工地云平台是专为建筑施工领域所打造的一体化信息管理平台。通过大数据、云计算、人工智能、物联网和移动互联网等高科技技术手段,将施工区域各系统数据汇总,建…

15.Docker-Compose的概念理解及安装

1.Docker-Compose是什么? Docker-Compose是实现对Docker容器集群的快速编排的工具软件。它是Docker官方开源的一个工具软件,可以管理多个Docker容器组成一个应用。你需要定义一个YAML格式的配置文件docker-compose.yml.写好多个容器间的调用关系&#x…

京东大数据(京东运营数据采集):2023年10月京东牛奶乳品行业品牌销售排行榜

鲸参谋监测的京东平台10月份牛奶乳品市场销售数据已出炉! 10月份,牛奶乳品整体销售上涨。鲸参谋数据显示,今年10月,京东平台上牛奶乳品的销量将近1700万,同比增长1%;销售额将近17亿,同比增长约5…

C语言常见算法

算法(Algorithm):计算机解题的基本思想方法和步骤。算法的描述:是对要解决一个问题或要完成一项任务所采取的方法和步骤的描述,包括需要什么数据(输入什么数据、输出什么结果)、采用什么结构、使…

在ubuntu系统安装SVN服务端,并通过客户端进行远程访问

文章目录 前言1. Ubuntu安装SVN服务2. 修改配置文件2.1 修改svnserve.conf文件2.2 修改passwd文件2.3 修改authz文件 3. 启动svn服务4. 内网穿透4.1 安装cpolar内网穿透4.2 创建隧道映射本地端口 5. 测试公网访问6. 配置固定公网TCP端口地址6.1 保留一个固定的公网TCP端口地址6…

【C 语言经典100例】C 练习实例7

题目:输出特殊图案,请在c环境中运行,看一看,Very Beautiful! 程序分析:字符共有256个。不同字符,图形不一样。 VC6.0下出现中文乱码(原因解决方法): 176的16进制是B0,219的16进制是DB&#xf…

常见的6种工业主板盘点

无论您涉及哪种类型的工业环境,主板都是所有电子元件的关键部件之一。可靠且高效的主板是任何功能系统的核心和灵魂。 不同的主板旨在满足不同的需求,如果您希望系统发挥最佳性能,则必须了解这些需求。本文提供了有关当今流行的6种工业主板的…

点云凹凸缺陷检测 最高层点云 点云聚类

文章目录 0. 数据说明1. 凹凸缺陷基本内容2. 详细检测思路结果: 0. 数据说明 如上图所示,需要检测的内容为红色框内标出的缺陷部分。简单示例如下红色线条。 但是,由于噪声的影响,点云的平面度并不好,且横梁边缘处存在连接,如下: 基于上述问题,首先需要获取有效点云(最…

P27 C++this 关键字

目录 前言 01 this关键字的引入 02 this关键字 前言 本章的主题是 C 中的 this 关键字。 以前第一次学qt的时候就遇到了this关键字,那时候还不是很会C,所以有点懵,现在我们就来讲解以下C中的this关键字 C 中有一个关键字 this&#xff0…

示波器高压探头的操作说明及使用注意事项

操作说明: 连接探头衰减端的地线(鳄鱼夹)到好的接地点或可靠的接地测试端。连接BNC头到示波器的BNC输入端口。选择示波器要求的量程范围。 注意:请务必在连接测试前把高压电源关闭。 注意事项: 请勿将测试设备的接地线从地面接线柱上移开。…

拒绝随波逐流!设计与实现可控的水下机器人

这个“长着三个触角”的水下机器人看上去是不是很萌?它使用的是一种新型的由三个球形磁耦合矢量推进器组成的推进系统。与传统的水下机器人使用多个固定推进器来实现多自由度(DOF)推进相比,矢量推进器具有多自由度、寄生推力小&am…

数据结构:哈希表讲解

哈希表 1.哈希概念2.通过关键码确定存储位置2.1哈希方法2.2直接定址法2.3除留余数法 3.哈希冲突概念4.解决哈希冲突4.1闭散列4.1.1概念4.1.2哈希表扩容4.1.3存储位置的状态4.1.4关于键值类型4.1.5代码实现 4.2开散列4.2.1概念4.2.2哈希表扩容4.2.3代码实现 4.3开闭散列的对比 1…

界面控件DevExpress WinForms Sunburst组件,轻松可视化分层扁平数据!

DevExpress WinForms Sunburst控件允许用户以紧凑和视觉上吸引人的方式可视化分层和扁平数据。 DevExpress WinForms有180组件和UI库,能为Windows Forms平台创建具有影响力的业务解决方案。同时能完美构建流畅、美观且易于使用的应用程序,无论是Office风…

ChatGPT到底是如何运作?

自从2022年11月30日发布以来,ChatGPT一直占据着科技届的头条位置,随着苹果的创新能力下降,ChatGPT不断给大家带来震撼,2023年11月7日,首届OpenAI开发者大会在洛杉矶举行,业界普遍认为,OpenAI的开…

11.28C++

#include <iostream>using namespace std;int main() {string str;cout << "请输入一个字符串&#xff1a;" << endl;getline(cin,str);int size str.size();int a0,b0,c0,d0,e0;for(int i0; i < size; i){if(str.at(i) > A && str…

Element-ui合并table表格列方法

merageCell({ row, column, rowIndex, columnIndex }) {if (columnIndex 0 || columnIndex 1) {const property columnIndex 0 ? name : firstDeptName;// 获取当前行的property&#xff0c;这里看自己的需要&#xff0c;改成根据哪个去判断const currentPropertyVal row…

Webshell流量分析

Webshell流量分析 常见的一句话木马: asp一句话 <%eval request("pass")%> aspx一句话 <%@ Page Language="Jscript"%><%eval(Request.Item["pass"],"unsafe");%> php一句话 <?php @eval($_POST["pass&…

【华为数通HCIP | 网络工程师】821刷题日记-BFD和VRRP 及重点(1)

个人名片&#xff1a; &#x1f43c;作者简介&#xff1a;一名大三在校生&#xff0c;喜欢AI编程&#x1f38b; &#x1f43b;‍❄️个人主页&#x1f947;&#xff1a;落798. &#x1f43c;个人WeChat&#xff1a;hmmwx53 &#x1f54a;️系列专栏&#xff1a;&#x1f5bc;️…

iconfont 使用彩色图标

1、下载iconfont到本地 2、全局安装 iconfont-tools npm install -g iconfont-tools 3、在iconfont解压目录下执行命令、一直回车 iconfont-tools 4、文件拷贝 执行完上述命令后会生成iconfont-weapp目录&#xff0c;将iconfont-weapp目录下的iconfont-weapp- icon.css文件…

【23真题】比985还难的双非!

今天分享的是23年长春工业大学807的信号与系统试题及解析。 本套试卷难度分析&#xff1a;本套试题难度中等偏上&#xff0c;题量不少&#xff0c;难度不小&#xff01;状态方程考察的淋漓尽致。另外还有电路题。这所双非院校的真题比90%的211难&#xff0c;甚至比一部分985更…