【深度学习实验】前馈神经网络(四):自定义逻辑回归模型:前向传播、反向传播算法

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 逻辑回归Logistic类

a. 构造函数__init__

b. __call__(self, x)方法

c. 前向传播forward

d. 反向传播backward

2. 模型训练

3. 代码整合


一、实验介绍

  • 实现逻辑回归模型(Logistic类)
    • 实现前向传播forward
    • 实现反向传播backward

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png

0. 导入必要的工具包

import torch

1. 逻辑回归Logistic

a. 构造函数__init__

 def __init__(self):self.inputs = Noneself.outputs = Noneself.params = None

         初始化了类的成员变量self.inputsself.outputsself.params,它们分别用于保存输入、输出和参数。

b. __call__(self, x)方法

    __call__(self, x)方法使得该类的实例可以像函数一样被调用。它调用了forward(x)方法,将输入的x传递给前向传播方法。

 def __call__(self, x):return self.forward(x)

c. 前向传播forward

  def forward(self, inputs):outputs = 1.0 / (1.0 + torch.exp(-inputs))self.outputs = outputsreturn outputs

    forward(self, inputs)方法执行逻辑回归的前向传播。它接受输入inputs作为参数,并通过逻辑回归的公式计算输出值outputs。最后,将计算得到的输出保存在self.outputs中,并返回输出值。

d. 反向传播backward

    def backward(self, outputs_grads=None):if outputs_grads is None:outputs_grads = torch.ones(self.outputs.shape)outputs_grad_inputs = torch.multiply(self.outputs, (1.0 - self.outputs))return torch.multiply(outputs_grads, outputs_grad_inputs)

    backward(self, outputs_grads=None)方法执行逻辑回归的反向传播。

  • 接受一个可选的参数outputs_grads,用于传递输出的梯度。
  • 如果没有提供outputs_grads,则默认为全1的张量,表示对输出的梯度都为1。
  • 根据逻辑回归的导数公式,可以将输出值与(1-输出值)相乘,然后再乘以传入的梯度值,得到输入的梯度。
  • 返回计算得到的输入梯度。

2. 模型训练

act = Logistic()
x = torch.tensor([3,3,4,2])
y = act(x)z = act.backward()
print(z)
  • 创建一个Logistic的实例act;
  • 传入张量x进行前向传播,得到输出张量y;
  • 调用act.backward()进行反向传播,得到输入x的梯度;
  • 将结果打印输出。
tensor([0.0452, 0.0452, 0.0177, 0.1050])

3. 代码整合

# 导入必要的工具包
import torchclass Logistic():def __init__(self):self.inputs = Noneself.outputs = Noneself.params = Nonedef __call__(self, x):return self.forward(x)def forward(self, inputs):outputs = 1.0 / (1.0 + torch.exp(-inputs))self.outputs = outputsreturn outputsdef backward(self, outputs_grads=None):if outputs_grads is None:outputs_grads = torch.ones(self.outputs.shape)outputs_grad_inputs = torch.multiply(self.outputs, (1.0 - self.outputs))return torch.multiply(outputs_grads, outputs_grad_inputs)act = Logistic()
x = torch.tensor([3,3,4,2])
y = act(x)z = act.backward()
print(z)

注意:

        本实验仅实现了逻辑回归的前向传播和反向传播部分,缺少了模型的参数更新和训练部分。完整的逻辑回归,需要进一步编写训练循环、损失函数和优化器等部分,欲知后事如何,请听下回分解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/86756.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux 目录结构介绍

对上面的说明: root 目录 : linux 超级权限 root 的主目录 home 目录 : 系统默认的用户主目录,如果添加用户是不指定用户的主目录,默认在/home 下创建与用户同名的文件夹 bin 目录 : 存放系统所需要的重要命令&am…

uniapp Echart X轴Y轴文字被遮挡怎么办,或未能铺满整个容器

有时候布局太小,使用echarts,x轴y轴文字容易被遮挡,怎么解决这个问题呢,或者是未能铺满整个容器。 方法1: 直接设置 containLabel 字段 options: { grid: { containLabel: true, },} 方法2: 间接设置,但是…

【新版】系统架构设计师 - 案例分析 - 信息安全

个人总结,仅供参考,欢迎加好友一起讨论 文章目录 架构 - 案例分析 - 信息安全安全架构安全模型分类BLP模型Biba模型Chinese Wall模型 信息安全整体架构设计WPDRRC模型各模型安全防范功能 网络安全体系架构设计开放系统互联安全体系结构安全服务与安全机制…

mysql workbench常用操作

1、No database selected Select the default DB to be used by double-clicking its name in the SCHEMAS list in the sidebar 方法一:双击你要使用的库 方法二:USE 数据库名 2、复制表名,字段名 3、保存链接

vue3+ts 实现移动端分页

current 开始页码 pageSize 结束页码 const sizeref<number>(10) //一页显示十条 const eachCurrentPageref<number>(1) //默认是第一页interface ITdata {current: number,pageSize: number,// xxxx 其他参数... } const selectApplyList ref<…

联想电脑打开exe提示要在Microsoft Store中搜索应用

问题&#xff1a; 你需要为此任务安装应用。 是否要在Microsoft Store中搜索一个&#xff1f; 如图&#xff1a; 出现此情况&#xff0c;仅需要做如下操作&#xff0c;在要打开的exe文件上右键&#xff0c;属性&#xff1a; 如图箭头所示&#xff0c;点击“解除锁定”出现对钩&…

<十二>objectARX开发:Arx注册命令类型的含义以及颜色索引对应RGB值

1、注册命令类型 我们经常在acrxEntryPoint.cpp中看到注册命令如下: 那么各个宏定义代表什么意思呢? 主标识:(常用的) ACRX_CMD_MODAL: 在别的命令执行的时候该命令不会在其中执行。ACRX_CMD_TRANSPARENT: 命令可以再其它命令中执行,但在该标志下ads_sssetfirst()不能使…

LeetCode 494.目标和 (动态规划 + 性能优化)二维数组 压缩成 一维数组

494. 目标和 - 力扣&#xff08;LeetCode&#xff09; 给你一个非负整数数组 nums 和一个整数 target 。 向数组中的每个整数前添加 或 - &#xff0c;然后串联起所有整数&#xff0c;可以构造一个 表达式 &#xff1a; 例如&#xff0c;nums [2, 1] &#xff0c;可以在 2…

用Redis做数据排名

1.背景 用Redis做数据缓存用的比较多&#xff0c;大家都能熟练使用String和Hash结构去存储数据&#xff0c;今天讲下如何使用ZSet来做数据排名。 假设场景是需要按天存储全国城市的得分数据&#xff0c;可以查询前十名的城市排名。 这个case可以使用传统关系型数据库做…

如何修复wmvcore.dll缺失问题,wmvcore.dll下载修复方法分享

近年来&#xff0c;电脑使用的普及率越来越高&#xff0c;人们在日常生活中离不开电脑。然而&#xff0c;有时候我们可能会遇到一些问题&#xff0c;其中之一就是wmvcore.dll缺失的问题。wmvcore.dll是Windows平台上用于支持Windows Media Player的动态链接库文件&#xff0c;如…

SD-MTSP:萤火虫算法(FA)求解单仓库多旅行商问题MATLAB(可更改数据集,旅行商的数量和起点)

一、萤火虫算法&#xff08;FA&#xff09;简介 萤火虫算法(Firefly Algorithm&#xff0c;FA)是Yang等人于2009年提出的一种仿生优化算法。 参考文献&#xff1a;田梦楚, 薄煜明, 陈志敏, et al. 萤火虫算法智能优化粒子滤波[J]. 自动化学报, 2016, 42(001):89-97. 二、单仓…

数量关系(刘文超)

解题技巧 代入排除法 数字特性法 整除特性 比例倍数特性&#xff08;找比例&#xff0c;比例不明显时找等式&#xff09; 看不懂式子时&#xff0c;把所有的信息像表格一样列出来 看不懂式子时&#xff0c;把所有的信息像表格一样列出来

【机器学习】期望最大算法(EM算法)解析:Expectation Maximization Algorithm

【机器学习】期望最大算法&#xff08;EM算法&#xff09;&#xff1a;Expectation Maximization Algorithm 文章目录 【机器学习】期望最大算法&#xff08;EM算法&#xff09;&#xff1a;Expectation Maximization Algorithm1. 介绍2. EM算法数学描述3. EM算法流程4. 两个问…

性能测试 —— Tomcat监控与调优:Jconsole监控

JConsole的图形用户界面是一个符合Java管理扩展(JMX)规范的监测工具&#xff0c;JConsole使用Java虚拟机(Java VM)&#xff0c;提供在Java平台上运行的应用程序的性能和资源消耗的信息。在Java平台&#xff0c;标准版(Java SE平台)6&#xff0c;JConsole的已经更新到目前的外观…

Linux查看哪些进程占用的系统 buffer/cache 较高 (hcache,lsof)命令

1、什么是buffer/cache &#xff1f; buffer/cache 其实是作为服务器系统的文件数据缓存使用的&#xff0c;尤其是针对进程对文件存在 read/write 操作的时候&#xff0c;所以当你的服务进程在对文件进行读写的时候&#xff0c;Linux内核为了提高服务的读写速度&#xff0c;则将…

linux 约束

linux 约束 1、约束的概念1.1什么是约束1.2约束的优劣势 2、约束的作用3、约束的分类4、约束的应用场景5、约束的管理5.1创建5.2查看5.3插入5.4删除 6、总结 1、约束的概念 1.1什么是约束 在关系型数据库中&#xff0c;约束是用于限制表中数据规则的一种机制。它可以确保表中…

数据库数据恢复-ORACLE常见故障有哪些?恢复数据的可能性高吗?

ORACLE数据库常见故障&#xff1a; 1、ORACLE数据库无法启动或无法正常工作。 2、ORACLE数据库ASM存储破坏。 3、ORACLE数据库数据文件丢失。 4、ORACLE数据库数据文件部分损坏。 5、ORACLE数据库DUMP文件损坏。 ORACLE数据库数据恢复可能性分析&#xff1a; 1、ORACLE数据库无…

基于STM32的宠物托运智能控制系统的设计(第十七届研电赛)

一、功能介绍 使用STM32作为主控设备&#xff0c;通过DHT11温湿度传感器、多合一空气质量检测传感器以及压力传感器对宠物的托运环境中的温湿度、二氧化碳浓度和食物与水的重量进行采集&#xff0c;将采集到的信息在本地LCD显示屏上显示&#xff0c;同时&#xff0c;使用4G模块…

小程序如何关联公众号来发送模板消息

有时候我们可能需要通过公众号来发送一些小程序的服务通知&#xff0c;比如订单提醒、活动通知等。那么要如何操作呢&#xff1f; 1. 有一个通过了微信认证的服务号。需要确保小程序和公众号是同一个主体的。也就是说&#xff0c;小程序和公众号应该都是属于同一个企业。如果还…

RestTemplate:简化HTTP请求的强大工具

文章目录 什么是RestTemplateRestTemplate的作用代码示例 RestTemplate与HttpClient 什么是RestTemplate RestTemplate是一个在Java应用程序中发送RESTful HTTP请求的强大工具。本文将介绍RestTemplate的定义、作用以及与HttpClient的对比&#xff0c;以帮助读者更好地理解和使…