【深度学习实验】线性模型(二):使用NumPy实现线性模型:梯度下降法

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入库

1. 初始化参数

2. 线性模型 linear_model

3. 损失函数loss_function

4. 梯度计算函数compute_gradients

5. 梯度下降函数gradient_descent

6. 调用函数


一、实验介绍

        使用NumPy实现线性模型:梯度下降法

二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

         线性模型梯度下降法是一种常用的优化算法,用于求解线性回归模型中的参数。它通过迭代的方式不断更新模型参数,使得模型在训练数据上的损失函数逐渐减小,从而达到优化模型的目的。

        梯度下降法的基本思想是沿着损失函数梯度的反方向更新模型参数。在每次迭代中,根据当前的参数值计算损失函数的梯度,然后乘以一个学习率的因子,得到参数的更新量。学习率决定了参数更新的步长,过大的学习率可能导致错过最优解,而过小的学习率则会导致收敛速度过慢。

具体而言,对于线性回归模型,梯度下降法的步骤如下:

  1. 初始化模型参数,可以随机初始化或者使用一些启发式的方法。

  2. 循环迭代以下步骤,直到满足停止条件(如达到最大迭代次数或损失函数变化小于某个阈值):

    a. 根据当前的参数值计算模型的预测值。

    b. 计算损失函数关于参数的梯度,即对每个参数求偏导数。

    c. 根据梯度和学习率更新参数值。

    d. 计算新的损失函数值,并检查是否满足停止条件。

  3. 返回优化后的模型参数。

       本实验中,gradient_descent函数实现了梯度下降法的具体过程。它通过调用initialize_parameters函数初始化模型参数,然后在每次迭代中计算模型预测值、梯度以及更新参数值。

0. 导入库

import numpy as np

1. 初始化参数

        在梯度下降算法中,需要初始化待优化的参数,即权重 w 和偏置 b。可以使用随机初始化的方式。

def initialize_parameters():w = np.random.randn(5)b = np.random.randn(5)return w, b

2. 线性模型 linear_model

def linear_model(x, w, b):output = np.dot(x, w) + breturn output

3. 损失函数loss_function

         该函数接受目标值y和模型预测值prediction,计算均方误差损失。

def loss_function(y, prediction):loss = (prediction - y) * (prediction - y)return loss

4. 梯度计算函数compute_gradients

        为了使用梯度下降算法,需要计算损失函数关于参数 w 和 b 的梯度。可以使用数值计算的方法来近似计算梯度。

def compute_gradients(x, y, w, b):h = 1e-6  # 微小的数值,用于近似计算梯度grad_w = (loss_function(y, linear_model(x, w + h, b)) - loss_function(y, linear_model(x, w - h, b))) / (2 * h)grad_b = (loss_function(y, linear_model(x, w, b + h)) - loss_function(y, linear_model(x, w, b - h))) / (2 * h)return grad_w, grad_b

5. 梯度下降函数gradient_descent

        根据梯度计算的结果更新参数 w 和 b,从而最小化损失函数。

def gradient_descent(x, y, learning_rate, num_iterations):w, b = initialize_parameters()for i in range(num_iterations):prediction = linear_model(x, w, b)grad_w, grad_b = compute_gradients(x, y, w, b)w -= learning_rate * grad_wb -= learning_rate * grad_bloss = loss_function(y, prediction)print("Iteration", i, "Loss:", loss)return w, b

6. 调用函数

        执行梯度下降优化:调用 gradient_descent 函数并传入数据 x 和 y,设置学习率和迭代次数进行优化。

x = np.random.rand(5)
y = np.array([1, -1, 1, -1, 1]).astype('float')
learning_rate = 0.1
num_iterations = 100
w_optimized, b_optimized = gradient_descent(x, y, learning_rate, num_iterations)

        在上述代码中,每一次迭代都会打印出当前迭代次数和对应的损失值。通过不断更新参数 w 和 b,使得损失函数逐渐减小,达到最小化损失函数的目的。

希望这个详细解析能够帮助你优化代码并使用梯度下降算法最小化损失函数。如果还有其他问题,请随时提问!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/83057.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

只需4步使用Redis缓存优化Node.js应用

介绍 通过API获取数据时,会向服务器发出网络请求,收到响应数据。但是,此过程可能非常耗时,并且可能会导致程序响应时间变慢。 我们使用缓存来解决这个问题,客户端程序首先向API发送请求,将返回的数据存储…

文档丢失怎么找回?学会这3个方法就足够!

场景1:“不是吧!我辛辛苦苦写的文档好像忘记保存就退出了!谁能救救我!帮我找回丢失的文档?” 场景2:“电脑里的文档太多了,每次在清理时都容易误删。有什么方法可以找回我丢失的文档吗&#xff…

成集云 | 用友T+集成聚水潭ERP(用友T+主管库存)| 解决方案

源系统成集云目标系统 方案介绍 用友T是一款由用友畅捷通推出的新型互联网企业管理系统,它主要满足成长型小微企业对其灵活业务流程的管控需求,并重点解决往来业务管理、订单跟踪、资金、库存等管理难题。 聚水潭是一款以SaaS ERP为核心,集…

嵌入式笔试面试刷题(day15)

文章目录 前言一、Linux中的主设备号和次设备号1.查看方法2.主设备号和次设备号的作用 二、软件IIC和硬件IIC的区别三、变量的声明和定义区别四、static在C和C中的区别五、串口总线空闲时候的电平状态总结 前言 本篇文章继续讲解嵌入式笔试面试刷题,希望大家坚持跟…

pgzrun 拼图游戏制作过程详解(10)

10. 拼图游戏继续升级——多关卡拼图 初始化列表Photos用来储存拼图文件名,Photo_ID用来统计当下是第几张拼图,Squares储存当下拼图的24张小拼图的文件名,Gird储存当下窗口上显示的24个小拼图及坐标。 Photos["girl_","boy_…

基于Java+SpringBoot+Vue+Element的OA系统的设计和实现

基于JavaSpringBootVueElement的OA系统的设计和实现 源码传送入口前言主要技术系统设计功能截图数据库设计代码论文目录订阅经典源码专栏Java项目精品实战案例《500套》 源码获取 源码传送入口 前言 在Internet高速发展的今天,我们生活的各个领域都涉及到计算机的…

AI写作工具,智能ai写作工具

在信息化时代,内容创作已经成为了许多行业的核心。从营销广告到新闻报道,从博客文章到学术论文,人们需要不断地产生高质量的文字内容。创作是一项耗时耗力的工作,需要丰富的知识和创造性思维。 AI写作工具,是一类基于人…

基于Spring Boot的医院预约挂号系统设计与实现

前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 👇🏻…

SAP ABAP基础知识 访问外部数据库-开发篇

前言 本文主要介绍通过ABAP语言访问外部数据库的几种方式 一、外部数据库配置 本文示例中的代码访问了两个外部数据库 MTD : 外部oracle数据库,其中示例表 ZTTEMP 字段( ZZTNO,WERKS) S4Q : 外部HANA数据库(开发系统访问测试系统的数据库), 使用表USR02,ZTTEMP 二、ABAP访问…

IDEA(2023)解决运行乱码问题

😇作者介绍:一个有梦想、有理想、有目标的,且渴望能够学有所成的追梦人。 🎆学习格言:不读书的人,思想就会停止。——狄德罗 ⛪️个人主页:进入博主主页 🗼专栏系列:无 &#x1f33c…

【计算机组成原理】读书笔记第三期:内存和磁盘的关系

目录 写在开头 内存与磁盘的关系 基本关系 磁盘缓存 虚拟内存 节约内存的编程方法 通过DLL文件实现函数共有 通过调用_stdcall来降低文件程序的大小 磁盘的物理结构 结尾 写在开头 本文继续阅读总结《程序是怎样跑起来的》这本书(作者:矢泽…

操作系统(5-7分)

内容概述 进程管理 进程的状态 前驱图 同步和互斥 PV操作(难点) PV操作由P操作原语和V操作原语组成(原语是不可中断的过程),对信号量进行操作,具体定义如下: P(S)&#…

【计算机网络】——传输层

//图片取自王道,仅做交流学习 一、传输层提供的服务 物理层、数据链路层、网络层是通信子网。 传输层:它属于面向通信部分的最高层,同时也是用户功能的最低层 为应用层提供通信服务使用网络层的服务 网络层提供主机之间的逻辑通信。 1、传输…

SpringMVC之JSR303与拦截器

目录 一.JSR303 1.什么是JSR303 2.为什么使用JSR303 3.JSR303常用注解 4.快速入门 4.1导入Maven依赖 4.2 配置校验规则 4.3 对服务端数据添加进行校验 4.4 结果测试 二.拦截器 1.什么是拦截器 2.拦截器与过滤器 3.应用场景 4.基本拦截器配置 5 案例演示&#xff0…

区块链实验室(23) - FISCO中PBFT耗时与流量特征

前面的实验(区块链实验室(11) - PBFT耗时与流量特征)用仿真的PBFT观察耗时。现在用真实的Fisco网络再次观察其特征。同样地,用相同的网络,即100个节点构成的无标度网络。在每个节点上发起10次交易,记录每次交易的耗时。结果见下图所示。 前半…

VSCode『SSH』连接服务器『GUI界面』传输

前言 最近需要使用实验室的服务器训练带有 GUI 画面的 AI 算法模型(pygame),但是我是使用 SSH 连接的,不能很好的显示模型训练的效果画面,所以下面将会讲解如何实现 SSH 连接传输 Linux GUI 画面的 注:我们…

Postman —— HTTP请求基础组成部分

一般来说,所有的HTTP Request都有最基础的4个部分组成:URL、 Method、 Headers和body。 (1)Method 要选择Request的Method是很简单的,Postman支持所有的请求方式。 (2)URL 要组装一条Request…

华为HCIA(五)

Vlan id 在802.1Q中 高级ACL不能匹配用户名和源MAC 2.4G频段被分为14个交叠的,错列的20MHz信道,信道编码从1到14,邻近的信道之间存在一定的重叠范围 STA通过Probe获取SSID信息 Snmp报文 网络管理设备异常发生时会发送trap报文 D类地址是…

基于Java网络书店商城设计实现(源码+lw+部署文档+讲解等)

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

亚马逊云科技 Amazon Lightsail :一种在云服务器上运行容器的简单方法

当向开发人员介绍亚马逊云科技云服务时,通常会花一点时间来介绍并演示 Amazon Lightsail 。它是迄今为止开始使用亚马逊云科技的最简单方法。使用它,您在几分钟内即可在自己的虚拟服务器上运行您的应用程序。而后增加了在 Amazon Lightsail 上部署基于容…