神经网络基础-神经网络搭建和参数计算

文章目录

    • 1.构建神经网络
    • 2. 神经网络的优缺点

1.构建神经网络

在 pytorch 中定义深度神经网络其实就是层堆叠的过程,继承自nn.Module,实现两个方法:

  • __init__方法中定义网络中的层结构,主要是全连接层,并进行初始化。
  • forward方法,在实例化模型的时候,底层会自动调用该函数。该函数中可以定义学习率,为初始化定义的layer传入数据等。

我们来构建如下图所示的神经网络模型:
在这里插入图片描述

编码设计如下:

  1. 第1个隐藏层:权重初始化采用标准化的xavier初始化 激活函数使用sigmoid。
  2. 第2个隐藏层:权重初始化采用标准化的He初始化 激活函数采用relu。
  3. out输出层线性层 假若二分类,采用softmax做数据归一化。
# 创建神经网络
import torch
import torch.nn as nn
# pip install torchsummary
from torchsummary import summary # 计算模型参数,查看模型结构 pip install torchsummary
# 创建神经网络模型类
class Model(nn.Module):# 初始化属性值def __init__(self):# 调用父类的初始化属性值super(Model, self).__init__()# 创建第一个隐藏层模型,3个输入特征,3个输出特征self.linear1 = nn.Linear(3, 3)# 初始化权重 xavier 均匀分布初始化nn.init.xavier_uniform_(self.linear1.weight)# 创建第二个隐藏层,3个输入特征(上一层的输出特征),2个输出特征self.linear2 = nn.Linear(3, 2)# 初始化权重 kaiming 正太分布初始化nn.init.kaiming_normal_(self.linear2.weight)# 创建输出层模型self.out = nn.Linear(2, 2)# 创建向前传播方法,自动执行 forward()方法def forward(self, x):# 数据经过第一个线性层x = self.linear1(x)# 使用 sigmoid 激活函数x = torch.sigmoid(x)# 数据经过第二个线性层x = self.linear2(x)# 使用 relu 激活函数x = torch.relu(x)# 数据经过输出层x = self.out(x)# 使用 softmax 激活函数# dim=-1:每一维度行数据相机为1x = torch.softmax(x, dim=-1)return xif __name__ == '__main__':# 实例化model对象model = Model()# 随机产生数据data = torch.randn(5,3)print('data.shape',data.shape)# 数据经过神经网络模型训练out = model(data)print('out.shape',out.shape)# 计算模型参数# 计算每层每个神经元的 w 和 b 个数总和summary(model,input_size=(3,),batch_size=5)# 查看模型参数print("======查看模型参数w和b======")for name, param in model.named_parameters():print(name, param)
  • 神经网络的输入数据是为[batch_size, in_features]的张量经过网络处理后获取了[batch_size, out_features]的输出张量。

  • 在上述例子中,batch_size=5, in_features=3,out_features=2,结果如下所示:

    data.shape torch.Size([5, 3])
    out.shape torch.Size([5, 2])
    

    模型参数输出:

    ----------------------------------------------------------------Layer (type)               Output Shape         Param #
    ================================================================Linear-1                     [5, 3]              12Linear-2                     [5, 2]               8Linear-3                     [5, 2]               6
    ================================================================
    Total params: 26
    Trainable params: 26
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.00
    Forward/backward pass size (MB): 0.00
    Params size (MB): 0.00
    Estimated Total Size (MB): 0.00
    ----------------------------------------------------------------
    ======查看模型参数w和b======
    linear1.weight Parameter containing:
    tensor([[ 0.3857,  0.4809, -0.0346],[ 0.3645,  0.2803, -0.6291],[ 0.1999, -0.6617,  0.7724]], requires_grad=True)
    linear1.bias Parameter containing:
    tensor([0.3084, 0.5636, 0.4501], requires_grad=True)
    linear2.weight Parameter containing:
    tensor([[ 0.1063,  0.7494,  0.4311],[-1.4152,  0.3396, -0.8590]], requires_grad=True)
    linear2.bias Parameter containing:
    tensor([-0.3771,  0.2937], requires_grad=True)
    out.weight Parameter containing:
    tensor([[-0.6012,  0.4727],[-0.2953, -0.5854]], requires_grad=True)
    out.bias Parameter containing:
    tensor([-0.3271,  0.4940], requires_grad=True)
    

模型参数的计算:

  1. 以第一个隐层为例:该隐层有3个神经元,每个神经元的参数为:4个(w1,w2,w3,b1),所以一共用3x4=12个参数。
  2. 输入数据和网络权重是两个不同的事儿!对于初学者理解这一点十分重要,要分得清。
    在这里插入图片描述

2. 神经网络的优缺点

  1. 优点
    ➢ 精度高,性能优于其他的机器学习算法,甚至在某些领域超过了人类。
    ➢ 可以近似任意的非线性函数。
    ➢ 近年来在学界和业界受到了热捧,有大量的框架和库可供调。
  2. 缺点
    ➢ 黑箱,很难解释模型是怎么工作的。
    ➢ 训练时间长,需要大量的计算资源。
    ➢ 网络结构复杂,需要调整超参数。
    ➢ 部分数据集上表现不佳,容易发生过拟合。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/64492.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Dcoker Redis哨兵模式集群介绍与搭建 故障转移 分布式 Java客户端连接

介绍 Redis 哨兵模式(Sentinel)是 Redis 集群的高可用解决方案,它主要用于监控 Redis 主从复制架构中的主节点和从节点的状态,并提供故障转移和通知功能。通过 Redis 哨兵模式,可以保证 Redis 服务的高可用性和自动故…

机器学习之交叉熵

交叉熵(Cross-Entropy)是机器学习中用于衡量预测分布与真实分布之间差异的一种损失函数,特别是在分类任务中非常常见。它源于信息论,反映了两个概率分布之间的距离。 交叉熵的数学定义 对于分类任务,假设我们有&#…

Scala的泛型界限

泛型界限 上限 泛型的上限,下限。对类型的更加具体的约束! 如果给某个泛型设置了上界:这里的类型必须是上界 如果给某个泛型设置了下界:这里的类型必须是下界

vscode中同时运行两个python文件(不用安装插件)

如何在vscode中同时运行两个python文件呢?今天在工作中遇到了这个问题。 查了网上的方法是安装coder runner插件,后来发现自身就有这个功能。所以记录一下,方便后续查找: 这是我的第一个文件,点击右上角的运行旁边的小箭头,有一…

python rabbitmq实现简单/持久/广播/组播/topic/rpc消息异步发送可配置Django

windows首先安装rabbitmq 点击参考安装 1、环境介绍 Python 3.10.16 其他通过pip安装的版本(Django、pika、celery这几个必须要有最好版本一致) amqp 5.3.1 asgiref 3.8.1 async-timeout 5.0.1 billiard 4.2.1 celery 5.4.0 …

Numpy基本介绍

目录 1、Numpy的优势 1.1、ndarray介绍 1.2、ndarray与Python原生list运算效率对比 1.3、ndarray的优势 1.3.1、内存块风格 1.3.2、ndarray支持并行化运算(向量化运算) 1.3.3、效率远高于纯Python代码 2、N维数组-ndarray 2.1、ndarray的属性 2.2、ndarray的形状 2…

XXE练习

pikachu-XXE靶场 1.POC:攻击测试 <?xml version"1.0"?> <!DOCTYPE foo [ <!ENTITY xxe "a">]> <foo>&xxe;</foo> 2.EXP:查看文件 <?xml version"1.0"?> <!DOCTYPE foo [ <!ENTITY xxe SY…

正则表达式在线校验(RegExp) - 加菲工具

正则表达式在线校验 - 加菲工具 打开网站 加菲工具 选择“正则表达式在线校验” 或者直接打开https://www.orcc.online/tools/regexp 输入待校验的源文本与正则表达式&#xff0c;点击“校验”按钮 需要注意检验后的内容可能存在多空格&#xff0c;可以拉下去看看~

java后端环境配置

因为现在升学了&#xff0c;以前本来想毕业干java的&#xff0c;很多java的环境配置早就忘掉了&#xff08;比如mysql maven jdk idea&#xff09;&#xff0c;想写个博客记录下来&#xff0c;以后方便自己快速搭建环境 JAVA后端开发配置 环境配置jdkideamavenMySQLnavicate17…

51c嵌入式~合集3

我自己的原文哦~ https://blog.51cto.com/whaosoft/12847563 一、UCIe 2.0 日前&#xff0c;通用芯粒互连&#xff08;UCIe&#xff09;产业联盟最新公布了 UCIe 2.0 规范&#xff0c;支持可管理性标准化系统架构&#xff0c;并全面解决了系统级封装&#xff08;SiP&#x…

解决电脑网速慢问题:硬件检查与软件设置指南

电脑网速慢是许多用户在使用过程中常见的问题&#xff0c;它不仅会降低工作效率&#xff0c;还可能影响娱乐体验。导致电脑网速慢的原因多种多样&#xff0c;包括硬件问题、软件设置和网络环境等。本文将从不同角度分析这些原因&#xff0c;并提供提高电脑网速的方法。 一、检查…

6、AI测试辅助-测试报告编写(生成Bug分析柱状图)

AI测试辅助-测试报告编写&#xff08;生成Bug分析柱状图&#xff09; 一、测试报告1. 创建测试报告2. 报告补充优化2.1 Bug图表分析 3. 风险评估 总结 一、测试报告 测试报告内容应该包含&#xff1a; 1、测试结论 2、测试执行情况 3、测试bug结果分析 4、风险评估 5、改进措施…

【C++ 】for 循环系统深入解析与实现法比较

博客主页&#xff1a; [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C 文章目录 &#x1f4af;前言&#x1f4af;for 循环的基本语法格式语法格式&#xff1a;格式一&#xff1a;单行语句的 for 循环格式二&#xff1a;多行语句的 for 循环循环流程图实例代码 for 循环中变量初始化的作…

Vue3安装配置、开发环境搭建(组件安装卸载)(图文详细)

Vue3安装配置、开发环境搭建(组件安装卸载)&#xff08;图文详细&#xff09; 本文目录&#xff1a; 一、vue的主要安装使用方式 二、node.js安装和配置 1、支持运行 Node.js的平台 2、Node.js 版本开发发布时间表&#xff08;日期可能会有变化&#xff09; 3、下载安装n…

Oracle 适配 OpenGauss 数据库差异语法汇总

背景 国产化进程中&#xff0c;需要将某项目的数据库从 Oracle 转为 OpenGauss &#xff0c;项目初期也是规划了适配不同数据库的&#xff0c;MyBatis 配置加载路径设计的是根据数据库类型加载指定文件夹的 xml 文件。 后面由于固定了数据库类型为 Oracle 后&#xff0c;只写…

Vue进阶之状态管理,解锁项目开发超能力

一、概念 状态管理是指对应用程序中状态的管理。在软件领域&#xff0c;状态是指在某个特定时刻&#xff0c;应用程序的数据和行为表现。 以一个简单的购物网站为例&#xff0c;购物车中的商品列表、用户的登录状态等都是状态。状态管理主要涉及这些状态如何被存储、更新和在…

操作系统(16)I/O软件

前言 操作系统I/O软件是负责管理和控制计算机系统与外围设备&#xff08;例如键盘、鼠标、打印机、存储设备等&#xff09;之间交互的软件。 一、I/O软件的定义与功能 定义&#xff1a;I/O软件&#xff0c;也称为输入/输出软件&#xff0c;是计算机系统中用于管理和控制设备与主…

游戏AI实现-寻路算法(Dijkstra)

戴克斯特拉算法&#xff08;英语&#xff1a;Dijkstras algorithm&#xff09;&#xff0c;又称迪杰斯特拉算法、Dijkstra算法&#xff0c;是由荷兰计算机科学家艾兹赫尔戴克斯特拉在1956年发现的算法。 算法过程&#xff1a; 1.首先设置开始节点的成本值为0&#xff0c;并将…

CTFshow-文件上传(Web151-170)

CTFshow-文件上传(Web151-170) 参考了CTF show 文件上传篇&#xff08;web151-170&#xff0c;看这一篇就够啦&#xff09;-CSDN博客 Web151 要求png&#xff0c;然后上传带有一句话木马的a.png&#xff0c;burp抓包后改后缀为a.php&#xff0c;然后蚁剑连接&#xff0c;找fl…

Unity超优质动态天气插件(含一年四季各种天气变化,可用于单机局域网VR)

效果展示&#xff1a;https://www.bilibili.com/video/BV1CkkcYHENf/?spm_id_from333.1387.homepage.video_card.click 在你的项目中设置enviro真的很容易&#xff01;导入包裹并按照以下步骤操作开始的步骤&#xff01; 1. 拖拽“EnviroSky”预制件&#xff08;“environme…