昇思25天学习打卡营第5天 | 网络构建

目录

1.定义模型类

2.模型层

nn.Flatten

nn.Dense

nn.ReLU

nn.SequentialCell

nn.Softmax

3.模型参数

代码实现:

总结


神经网络模型是由神经网络层和Tensor操作构成的,

mindspore.nn提供了常见神经网络层的实现,

在MindSpore中,Cell类是构建所有网络的基类,也是网络的基本单元。

一个神经网络模型表示为一个Cell,它由不同的子Cell构成。

使用这样的嵌套结构,可以简单地使用面向对象编程的思维,对神经网络结构进行构建和管理

1.定义模型类

定义神经网络时,可以继承nn.Cell类,在__init__方法中进行子Cell的实例化和状态管理,在construct方法中实现Tensor操作。

construct意为神经网络(计算图)构建

构建完成后,实例化Network对象,并查看其结构:

三个全连接层(Dense)和两个ReLU激活函数的序列模型

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),nn.ReLU(),nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),nn.ReLU(),nn.Dense(512, 10, weight_init="normal", bias_init="zeros"))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()
print(model)

我们构造一个输入数据,直接调用模型,可以获得一个十维的Tensor输出,其包含每个类别的原始预测值。

model.construct()方法不可直接调用。

在此基础上,我们通过一个nn.Softmax层实例来获得预测概率。

X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
logits
pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

2.模型层

我们分解上面构造的神经网络模型中的每一层。首先我们构造一个shape为(3, 28, 28)的随机数据(3个28x28的图像),依次通过每一个神经网络层来观察其效果。

input_image = ops.ones((3, 28, 28), mindspore.float32)
print(input_image.shape)

nn.Flatten

实例化nn.Flatten层,将28x28的2D张量转换为784大小的连续数组。

nn.Dense

nn.Dense为全连接层,其使用权重和偏差对输入进行线性变换。

nn.ReLU

nn.ReLU层给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。

nn.SequentialCell

nn.SequentialCell是一个有序的Cell容器。输入Tensor将按照定义的顺序通过所有Cell。我们可以使用SequentialCell来快速组合构造一个神经网络模型。

nn.Softmax

最后使用nn.Softmax将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。axis指定的维度数值和为1。

3.模型参数

网络内部神经网络层具有权重参数和偏置参数(如nn.Dense),这些参数会在训练过程中不断进行优化,可通过 model.parameters_and_names() 来获取参数名及对应的参数详情。

代码实现:

总结

构建网络,定义模型类时,要有这个框架,继承类,在他里面进行实例化和状态管理:

  1. class Network(nn.Cell): 定义了一个类,继承自 nn.Cell

  2. def __init__(self):  Network 类的构造函数,初始化类的属性。

  3. super().__init__(): 调用父类 nn.Cell 的构造函数。

  4. def construct(self, x): 定义了 Network 类的 construct 方法,它是MindSpore中定义模型前向传播逻辑的方法。参数 x 表示输入数据。

  5. x = self.flatten(x): 使用 self.flatten 层将输入数据 x 展平。

  6. logits = self.dense_relu_sequential(x): 将展平后的数据 x 通过 self.dense_relu_sequential 序列模型进行前向传播,得到模型的原始输出 logits。在分类任务中,logits 是模型的线性输出

  7. return logits: 返回模型的输出 logits

class Network(nn.Cell):def __init__(self):super().__init__()def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logits
  1. self.flatten = nn.Flatten(): 初始化一个 nn.Flatten 层,这个层用于将多维输入数据展平为一维数据。在处理图像数据时,通常需要将图像的二维数据(例如,28x28像素)展平为一维向量。

  2. self.dense_relu_sequential = nn.SequentialCell(...): 初始化一个序列模型,包含三个全连接层(nn.Dense)和两个ReLU激活函数(nn.ReLU)。这个序列模型的初始化与之前解释的相同。

预测的时候:

1. pred_probab = nn.Softmax(axis=1)(logits): 使用了 nn.Softmax 函数来将模型的输出 logits 转换为概率分布。

Softmax 函数通常用于多类分类问题的输出层,它可以将一个向量的元素转换为一个概率分布,使得所有元素的和为1。

参数 axis=1 表示 Softmax 函数将在第二个维度(通常是特征维度)上应用,即对于每个样本,将其对应的 logits 转换为概率。

2. y_pred = pred_probab.argmax(1): 这行代码使用了 argmax 函数来找到每个样本概率最高的类别索引。argmax 函数返回输入数组中最大元素的索引。在这里,它沿着第二个维度(即每个样本的概率分布)找到最大值的索引,这代表了模型预测的类别。

pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

别的也没什么了吧~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/38180.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

启动spring boot项目停止 提示80端口已经被占用

可能的情况: 检查并结束占用进程: 首先,你需要确定哪个进程正在使用80端口。在Windows上,可以通过命令行输入netstat -ano | findstr LISTENING | findstr :80来查看80端口的PID,然后在任务管理器中结束该进程。在

AI智能客服项目拆解(1) 产品大纲

本文作为拆解AI智能客服项目的首篇,以介绍产品大纲为主。后续以某AI智能客服产品为例,拆解相关技术细节。 AI智能客服是一种基于人工智能技术的客户服务解决方案,旨在提高客户满意度和优化企业运营。利用人工智能和自然语言处理技术&#xff…

MySQL之索引失效的情况

什么情况下索引会失效? 违反最左前缀原则范围查询右边的列不能使用索引不要在索引列上进行运算操作字符串不加单引号导致索引失效以%开头的like模糊查询 什么情况下索引会失效? 示例,有user表如下 CREATE TABLE user (id bigint(20) NOT NU…

实验1 多层感知器设计(MLP)

1.实验目的 掌握多层感知器的原理。掌握多层感知器的设计、训练和测试。2.实验要求 设计一个多层感知器,用于对给定的数据进行分类。要求代码格式规范,注释齐全,程序可正常运行。 3.模型设计 实验设计一个多层感知机,三层机构,只含一个隐藏层,输入层,隐藏层,输出层 1…

JAVA期末速成库(11)第十二章

一、习题介绍 第十二章 Check Point:P454 12.1,12.9,12.10,12,12 二、习题及答案 12.1 What is the advantage of using exception handling? 12.1使用异常处理的优势是什么? 答:使用异常处理有以下优势: 1. 提高…

C++ 模板类的示例-数组

类模板可以有非通用类型参数:1)通常是整型(C20标准可以用其它的类型);2)实例化模板时必须用常量表达式;3)模板中不能修改参数的值;4)可以为非通用类型参数提供…

Android中使用performClick触发点击事件

Android中使用performClick触发点击事件 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿!今天我们将探讨在Android开发中如何使用 performClick() 方法来触发点击…

数据库-python SQLite3

数据库-python SQLite3 一:sqlite3 简介二: sqlite3 流程1> demo2> sqlite3 流程 三:sqlite3 step1> create table2> insert into3> update4> select1. fetchall()2. fetchone()3. fetchmany() 5> delete6> other step 四&#…

Spark join数据倾斜调优

Spark中常见的两种数据倾斜现象如下 stage部分task执行特别慢 一般情况下是某个task处理的数据量远大于其他task处理的数据量,当然也不排除是程序代码没有冗余,异常数据导致程序运行异常。 作业重试多次某几个task总会失败 常见的退出码143、53、137…

【电路笔记】-放大器类型

放大器类型 文章目录 放大器类型1、概述2、关于偏置的注意事项3、A类(Class A)放大器4、B类(Class B)放大器5、AB类(Class AB)放大器6、C类(Class C)放大器7、总结1、概述 放大器通常根据输出级的结构进行分类。 事实上,功率放大确实发生在该阶段,因此输出信号的质量和…

Arduino (esp ) 下String的内存释放

在个人的开源项目 GitHub - StarCompute/tftziku: 这是一个通过单片机在各种屏幕上显示中文的解决方案 中为了方便快速检索使用了string,于是这个string在esp8266中占了40多k,原本以为当string设置为""的时候这个40k就可以回收,结果发觉不行…

【JS异步编程】async/await——用同步代码写异步

历史小剧场 懂得暴力的人,是强壮的;懂得克制暴力的人,才是强大的。----《明朝那些事儿》 什么是 async/await async: 声明一个异步函数 自动将常规函数转换成Promise,返回值也是一个Promise对象;只有async函数内部的异…

Java SE入门及基础(59) 线程的实现(上) 线程的创建方式 线程内存模型 线程安全

目录 线程(上) 1. 线程的创建方式 Thread类常用构造方法 Thread类常用成员方法 Thread类常用静态方法 示例 总结 2. 线程内存模型 3.线程安全 案例 代码实现 执行结果 线程(上) 1. 线程的创建方式 An application t…

利用 Docker 简化 Nacos 部署:快速搭建 Nacos 服务

利用 Docker 简化 Nacos 部署:快速搭建 Nacos 服务 引言 在微服务架构中,服务注册与发现是确保服务间通信顺畅的关键组件。Nacos(Dynamic Naming and Configuration Service)作为阿里巴巴开源的一个服务发现和配置管理平台&…

任务调度器——任务切换

一、开启任务调度器 函数原型: void vTaskStartScheduler( void ) 作用:用于启动任务调度器,任务调度器启动后, FreeRTOS 便会开始进行任务调度 内部实现机制(以动态创建为例): &#xff0…

Linux 安装、配置Tomcat 的HTTPS

Linux 安装 、配置Tomcat的HTTPS 安装Tomcat 这里选择的是 tomcat 10.X ,需要Java 11及更高版本 下载页 ->Binary Distributions ->Core->选择 tar.gz包 下载、上传到内网服务器 /opt 目录tar -xzf 解压将解压的根目录改名为 tomat-10 并移动到 /opt 下, 形成个人…

测评推荐:企业管理u盘的软件有哪些?

U盘作为一种便携的存储设备,方便易用,被广泛应用于企业办公、个人学习及日常工作中。然而,U盘的使用也带来了数据泄露、病毒传播等安全隐患。为了解决这些问题,企业管理U盘的软件应运而生。 本文将对市面上流行的几款U盘管理软件…

Hadoop3:Yarn容量调度器配置多队列案例

一、情景描述 需求1: default队列占总内存的40%,最大资源容量占总资源60%,hive队列占总内存的60%,最大资源容量占总资源80%。 二、多队列优点 (1)因为担心员工不小心,写递归死循环代码&#…

数据处理:四选一、四关联

今天去面试,面试官们给我一个‘选择’,有四个选项:‘展示你的才华’、‘展示你的美貌’、‘展示你的才华与美貌’、‘都不展示’ {label: “选择”,children: [{label: “展示你的才华”,children: [],isShow: talentModal,click: () > {i…

电路笔记(电源模块): 基于FT2232HL实现的jtag下载器硬件+jtag的通信引脚说明

JTAG接口说明 JTAG 接口根据需求可以选择20针或14针的配置,具体选择取决于应用场景和需要连接的功能。比如之前的可编程逻辑器件XC9572XL使用JTAG引脚(TCK、TDI、TDO、TMS、VREF、GND)用于与器件进行调试和编程通信。更详细的内容可以阅读11…