神经网络框架的基本设计

一、神经网络框架设计的基本流程

确定网络结构、激活函数、损失函数、优化算法,模型的训练与验证,模型的评估和优化,模型的部署。

二、网络结构与激活函数

1、网络架构

这里我们使用的是多层感知机模型MLP(multilayer prrceptron):

MLP一般分为三层:输入层、隐藏层和输出层。
输入层:接收输入数据。
隐藏层:负责处理数据,可以有一个或多个隐藏层,每个隐藏层包含若干个神经元。
输出层:输出最终结果,通常是一个softmax层,用于多分类任务,或者是一个sigmoid层,用于二分类任务。
MLP的核心思想是通过增加神经元的数量和层次,提高模型的表达能力。由于有多个层,参数需要在这些层之间传递。

每个隐藏层神经元中计算过程如下:
将输入数据传递给第一个隐藏层的神经元。
对于每个神经元,计算其加权和,即将输入与对应的权重相乘并求和,再加上偏置项。
将加权和输入到激活函数中,得到激活值,作为该神经元的输出。
将每个神经元的输出传递到下一层的神经元,直至输出层。
在这个过程中,数据和权重是前向传播的主要传播内容。

class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28,312),nn.ReLU(),nn.Linear(312, 256),nn.ReLU(),nn.Linear(256, 10))def forward(self, input):x = self.flatten(input)logits = self.linear_relu_stack(x)return logits

2、激活函数:

用于引入非线性因素,使得神经网络具有更强的表达能力。常见的激活函数有Sigmoid、ReLU、Tanh等。这里我们暂时略过

3、损失函数:

损失函数是用来量化模型预测值与真实值差距的方式(这么说可能过于抽象,我们初中数学应该学过标准差和方差的概念,方差和标准差是量化数组各元素与平均值的差距的。)而能反映二者间差距的值有很多,常见的有均方损失、交叉熵损失、铰链损失、绝对值损失。

具体什么含义可以自行学习。

损失函数的作用是用来评估模型的准确度的,

在上次实验中,我们使用的是均方损失,这一次我们使用交叉熵损失函数:

loss_fu = torch.nn.CrossEntropyLoss()
loss = loss_fu(pred, label_batch)

torch.nn模块:是构建神经网络的基石,提供了各种类型的层,包括卷积层、池化层、激活函数、循环层和全连接层

4、优化算法

本次还是采用:自适应学习率的优化算法,adam优化器。

三、模型的训练

1、数据集调整

此次的数据集依然是mnist数据集,下载与处理可以看Hello World!-CSDN博客。不过对于idx3-udyte文件,如果我们每次都需要这么读一下,那多少有些麻烦,可以将idx3-udyte文件转换成npy文件以便长期存储。

npy文件是NumPy数组的一种存储格式,用于将Python中的NumPy数组保存到磁盘上,并可以在以后需要时加载回来。它不仅存储了数组的数据,还存储了数组的形状、数据类型等信息。这使得.npy文件非常紧凑且占用空间小,并且可以快速地读写。由于它是二进制格式,所以不能直接用文本编辑器打开查看内容。非常适合在需要频繁保存和加载数组数据的场景下使用,例如机器学习中的模型参数保存、数据集的存储等。

我们回到上期的读取文件的部分

​
def read_data3(self, roadurl):with open(roadurl, 'rb') as f:content = f.read()# print(content)fmt_header = '>iiii'  # 网络字节序offset = 0magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, content, offset)print('幻数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))# 定义一张图片需要的数据个数(每个像素一个字节,共需要行*列个字节的数据)img_size = num_rows * num_cols# struct.calcsize(fmt)用来计算fmt格式所描述的结构的大小offset += struct.calcsize(fmt_header)# '>784B'是指用大端法读取784个unsigned bytefmt_image = '>' + str(img_size) + 'B'# 定义了一个三维数组,这个数组共有num_images个 num_rows*num_cols尺寸的矩阵。images = np.empty((num_images, num_rows, num_cols))for i in range(num_images):images[i] = np.array(struct.unpack_from(fmt_image, content, offset)).reshape((num_rows, num_cols))offset += struct.calcsize(fmt_image)return images​

我们已经接收到了images这个numpy数组,接下来只需要增加一点代码

# 保存到
np.save('文件路径/文件名.npy', array)
# 读取文件
data1 = np.load('文件路径/文件名.npy')

2、模型搭建

import osos.environ['CUDA_VISIBLE_DEVICES'] = '0'  # 指定GPU编
import torch
import numpy as np
from tqdm import tqdmbatch_size = 320  # 设定每次训练的批次数
epochs = 1024  # 设定训练次数# device = "cpu"                         # Pytorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"  # 在这里读者默认使用GPU,如果读者出现运行问题可以将其改成cpu模式# 设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28 * 28, 312),torch.nn.ReLU(),torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10))def forward(self, input):x = self.flatten(input)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork()
model = model.to(device)  # 将计算模型传入GPU硬件等待计算
torch.save(model, './model.pth')
# model = torch.compile(model)            # Pytorch2.0的特性,加速计算速度
loss_fu = torch.nn.CrossEntropyLoss()       # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)  # 设定优化函数# 载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")train_num = len(x_train) // batch_size# 开始计算
for epoch in range(20):train_loss = 0for i in range(train_num):start = i * batch_sizeend = (i + 1) * batch_sizetrain_batch = torch.tensor(x_train[start:end]).to(device)label_batch = torch.tensor(y_train_label[start:end]).to(device)pred = model(train_batch)loss = loss_fu(pred, label_batch)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()  # 记录每个批次的损失值# 计算并打印损失值train_loss /= train_numaccuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_sizeprint("epoch:", epoch, "train_loss:", round(train_loss, 2), "accuracy:", round(accuracy, 2))
torch.save(model, './model.pth')
print("模型已更新")

3、训练完成

四、可视化操作

1、代码

import torch
import torch.nn as nnif __name__ == '__main__':model = NeuralNetwork()#print(model)params = list(model.parameters())k = 0for i in params:l = 1print("该层的结构:" + str(list(i.size())))for j in i.size():l *= jprint("该层参数和:" + str(l))k = k + lprint("总参数数量和:" + str(k))

2、使用netron软件进行可视化(推荐)

在github上下载:https://github.com/lutzroeder/netron

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/598677.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何委婉地告诉老板,BI连接金蝶ERP,对决策更有利?

网友:新入职一家企业,发现这家企业依旧是从金蝶ERP中导出数据做分析,这样数据量一大、科目变动多就很难保证数据分析的及时性、灵活性,说真的这对决策来说并不是什么好事。但老板似乎并不觉得这有什么不对。我该如何委婉地告诉老板…

Java 新手如何使用Spring MVC 中的查询字符串和查询参数

目录 前言 什么是查询字符串和查询参数? Spring MVC中的查询参数 处理可选参数 处理多个值 处理查询参数的默认值 处理查询字符串 示例:创建一个RESTful服务 总结 作者简介: 懒大王敲代码,计算机专业应届生 今天给大家…

基于SSM的班级事务管理系统设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用Vue技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

《手把手教你》系列练习篇之6-python+ selenium自动化测试(详细教程)

1. 简介 今天我们还是继续练习练习基本功,各位小伙伴要耐住住性子,要耐得住寂寞啊,不要急躁,后面你会感谢你在前边的不断练习的。到后面也是检验你前边的学习成果的一次很好实践。 本文介绍如何通过link text、partial link text…

PHP运行环境之宝塔软件安装及Web站点部署流程

PHP运行环境之宝塔软件安装及Web站点部署流程 1.1安装宝塔软件 官网:https://www.bt.cn/new/index.html 自行注册账号,稍后有用 下载安装页面:宝塔面板下载,免费全能的服务器运维软件 1.1.1Linux 安装 如图所示,宝…

Image - 体积最小的 base64 encode 1*1透明图片,透明背景图片base64编码

背景 前端开发时&#xff0c;有些<img>标签的src属性的值来源于接口&#xff0c;在接口获取结果之前&#xff0c;这个src应该设置为什么呢&#xff1f; 误区&#xff1a;设置为# 有人把src设置为<img src"#" />。 这是有问题的&#xff0c;浏览器解析…

2024.1.5每日一题

LeetCode每日一题 1944.队列中可以看到的人数 1944. 队列中可以看到的人数 - 力扣&#xff08;LeetCode&#xff09; 题目描述 有 n 个人排成一个队列&#xff0c;从左到右 编号为 0 到 n - 1 。给你以一个整数数组 heights &#xff0c;每个整数 互不相同&#xff0c;heig…

Windows 下用 C++ 调用 Python

文章目录 Part.I IntroductionChap.I InformationChap.II 预备知识 Part.II 语法Chap.I PyRun_SimpleStringChap.II C / Python 变量之间的相互转换 Part.III 实例Chap.I 文件内容Chap.II 基于 Visual Studio IDEChap.III 基于 cmakeChap.IV 运行结果 Part.IV 可能出现的问题Ch…

硬盘检测软件 SMART Utility mac功能特色

SMART Utility for mac是一款苹果电脑上磁盘诊断工具&#xff0c;能够自动检测磁盘的状态和错误情况&#xff0c;分析并提供错误报告,以直观的界面让用户可明确地知道自己的磁盘状况。SMART Utility 支持普通硬盘HDD和固态硬盘SSD&#xff0c;能够显示出详细的磁盘信息&#xf…

C# OpenCvSharp DNN Gaze Estimation

目录 介绍 效果 模型信息 项目 代码 frmMain.cs GazeEstimation.cs 下载 C# OpenCvSharp DNN Gaze Estimation 介绍 训练源码地址&#xff1a;https://github.com/deepinsight/insightface/tree/master/reconstruction/gaze 效果 模型信息 Inputs ----------------…

Java重修第一天—学习数组

1. 认识数组 建议1.5倍速学习&#xff0c;并且关闭弹幕。 数组的定义&#xff1a;数组是一个容器&#xff0c;用来存储一批同种类型的数据。 下述图&#xff1a;是生成数字数组和字符串数组。 为什么有了变量还需要定义数组呢&#xff1f;为了解决在某些场景下&#xff0c;变…

ZigBee协议栈 -- 协议栈版本与IAR版本适配说明(Zstack2.5.1a + IAR10.30.1)

文章目录 协议栈安装工程适配 在讲到ZigBee协议栈的文章中所用的协议栈版本是Zstack2.5.1a&#xff0c;对于Zstack2.5.1a运行在IAR8.10中是可以完全适配进行编译开发的&#xff0c;现在较新版本的IAR都是10的版本以上了&#xff0c;有部分开发者习惯使用最新版本来获得更好的开…

正则表达式的语法

如果要想灵活的运用正则表达式&#xff0c;必须了解其中各种元素字符的功能&#xff0c;元字符从功能上大致分为&#xff1a; 限定符 选择匹配符 分组组合和反向引用符 特殊字符 字符匹配符 定位符 我们先说一下元字符的转义号 元字符(Metacharacter)-转义号 \\ \\ 符号…

关于系统设计的一些思考

0.前言 当我们站在系统设计的起点&#xff0c;面对一个新的需求&#xff0c;我们该如何开始呢&#xff1f;这是许多处于系统分析与设计领域的新手常常思考的问题。有些人可能会误以为&#xff0c;只要掌握了诸如面向对象、统一建模语言、设计模式、微服务、Serverless、Servic…

《操作系统导论》笔记

操作系统三个关键&#xff1a;虚拟化( virtualization) 并发(concurrency) 持久性&#xff08;persistence&#xff09; 1 CPU虚拟化 1.1 进程 虚拟化CPU&#xff1a;许多任务共享物理CPU&#xff0c;让它们看起来像是同时运行。 时分共享&#xff1a;运行一个进程一段时间…

前端常用的几种算法的特征、复杂度、分类及用法示例演示

算法&#xff08;Algorithm&#xff09;可以理解为有基本运算及规定的运算顺序所构成的完整的解题步骤&#xff0c;或者看成按照要求设计好的有限的确切的计算序列&#xff0c;并且这样的步骤和序列可以解决一类问题。算法代表着用系统的方法描述解决问题的策略机制&#xff0c…

嵌入式Linux之MX6ULL裸机开发学习笔记(IMX启动方式-启动设备的选择)

一,硬件启动方式选择 1.启动方式的选择 6ull支持多种启动方式。 比如可以从 SD/EMMC、 NAND Flash、 QSPI Flash等启动。 6ull是怎么支持多种外置flash启动程序的。 1.启动方式选择&#xff1a; BOOT_MODE0 and BOOT_MODE1&#xff0c;这两个是两个IO来控制的&#xff0c;…

C#上位机与欧姆龙PLC的通信09----开发专用的通讯工具软件(Winform版)

1、介绍 上节文章已经完成了通讯库的开发&#xff0c;可以看到库还是蛮厉害的&#xff0c;在项目中就可以直接拿来应用&#xff0c;这节要做的就是做一个工具软件&#xff0c;形成自己专业的通讯工具&#xff0c;也是对通讯库的直接利用&#xff0c;本节要写的工具软件是一个w…

【HBase】——优化

1 RowKey设计 重要&#xff1a;一条数据的唯一标识就是 rowkey&#xff0c;那么这条数据存储于哪个分区&#xff0c;取决于 rowkey 处于 哪个一个预分区的区间内&#xff0c;设计 rowkey的主要目的 &#xff0c;就是让数据均匀的分布于所有的 region 中&#xff0c;在一定程度…

如何查找iPhone中所有的应用程序

​ ​ Apple 的 App Store 共有约 200 万个适用于 iPhone 和 iPad 的应用程序。如果您像我们一样&#xff0c;您的 iOS 或 iPadOS 设备上可能有数十个应用程序&#xff0c;但没有机会将它们全部整理好。您很容易忘记主屏幕上应用程序图标的位置。 幸运的是&#xff0c;iPhone…