昇思25天学习打卡营第5天 | 网络构建

目录

1.定义模型类

2.模型层

nn.Flatten

nn.Dense

nn.ReLU

nn.SequentialCell

nn.Softmax

3.模型参数

代码实现:

总结


神经网络模型是由神经网络层和Tensor操作构成的,

mindspore.nn提供了常见神经网络层的实现,

在MindSpore中,Cell类是构建所有网络的基类,也是网络的基本单元。

一个神经网络模型表示为一个Cell,它由不同的子Cell构成。

使用这样的嵌套结构,可以简单地使用面向对象编程的思维,对神经网络结构进行构建和管理

1.定义模型类

定义神经网络时,可以继承nn.Cell类,在__init__方法中进行子Cell的实例化和状态管理,在construct方法中实现Tensor操作。

construct意为神经网络(计算图)构建

构建完成后,实例化Network对象,并查看其结构:

三个全连接层(Dense)和两个ReLU激活函数的序列模型

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),nn.ReLU(),nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),nn.ReLU(),nn.Dense(512, 10, weight_init="normal", bias_init="zeros"))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()
print(model)

我们构造一个输入数据,直接调用模型,可以获得一个十维的Tensor输出,其包含每个类别的原始预测值。

model.construct()方法不可直接调用。

在此基础上,我们通过一个nn.Softmax层实例来获得预测概率。

X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
logits
pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

2.模型层

我们分解上面构造的神经网络模型中的每一层。首先我们构造一个shape为(3, 28, 28)的随机数据(3个28x28的图像),依次通过每一个神经网络层来观察其效果。

input_image = ops.ones((3, 28, 28), mindspore.float32)
print(input_image.shape)

nn.Flatten

实例化nn.Flatten层,将28x28的2D张量转换为784大小的连续数组。

nn.Dense

nn.Dense为全连接层,其使用权重和偏差对输入进行线性变换。

nn.ReLU

nn.ReLU层给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。

nn.SequentialCell

nn.SequentialCell是一个有序的Cell容器。输入Tensor将按照定义的顺序通过所有Cell。我们可以使用SequentialCell来快速组合构造一个神经网络模型。

nn.Softmax

最后使用nn.Softmax将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。axis指定的维度数值和为1。

3.模型参数

网络内部神经网络层具有权重参数和偏置参数(如nn.Dense),这些参数会在训练过程中不断进行优化,可通过 model.parameters_and_names() 来获取参数名及对应的参数详情。

代码实现:

总结

构建网络,定义模型类时,要有这个框架,继承类,在他里面进行实例化和状态管理:

  1. class Network(nn.Cell): 定义了一个类,继承自 nn.Cell

  2. def __init__(self):  Network 类的构造函数,初始化类的属性。

  3. super().__init__(): 调用父类 nn.Cell 的构造函数。

  4. def construct(self, x): 定义了 Network 类的 construct 方法,它是MindSpore中定义模型前向传播逻辑的方法。参数 x 表示输入数据。

  5. x = self.flatten(x): 使用 self.flatten 层将输入数据 x 展平。

  6. logits = self.dense_relu_sequential(x): 将展平后的数据 x 通过 self.dense_relu_sequential 序列模型进行前向传播,得到模型的原始输出 logits。在分类任务中,logits 是模型的线性输出

  7. return logits: 返回模型的输出 logits

class Network(nn.Cell):def __init__(self):super().__init__()def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logits
  1. self.flatten = nn.Flatten(): 初始化一个 nn.Flatten 层,这个层用于将多维输入数据展平为一维数据。在处理图像数据时,通常需要将图像的二维数据(例如,28x28像素)展平为一维向量。

  2. self.dense_relu_sequential = nn.SequentialCell(...): 初始化一个序列模型,包含三个全连接层(nn.Dense)和两个ReLU激活函数(nn.ReLU)。这个序列模型的初始化与之前解释的相同。

预测的时候:

1. pred_probab = nn.Softmax(axis=1)(logits): 使用了 nn.Softmax 函数来将模型的输出 logits 转换为概率分布。

Softmax 函数通常用于多类分类问题的输出层,它可以将一个向量的元素转换为一个概率分布,使得所有元素的和为1。

参数 axis=1 表示 Softmax 函数将在第二个维度(通常是特征维度)上应用,即对于每个样本,将其对应的 logits 转换为概率。

2. y_pred = pred_probab.argmax(1): 这行代码使用了 argmax 函数来找到每个样本概率最高的类别索引。argmax 函数返回输入数组中最大元素的索引。在这里,它沿着第二个维度(即每个样本的概率分布)找到最大值的索引,这代表了模型预测的类别。

pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

别的也没什么了吧~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/38180.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI智能客服项目拆解(1) 产品大纲

本文作为拆解AI智能客服项目的首篇,以介绍产品大纲为主。后续以某AI智能客服产品为例,拆解相关技术细节。 AI智能客服是一种基于人工智能技术的客户服务解决方案,旨在提高客户满意度和优化企业运营。利用人工智能和自然语言处理技术&#xff…

MySQL之索引失效的情况

什么情况下索引会失效? 违反最左前缀原则范围查询右边的列不能使用索引不要在索引列上进行运算操作字符串不加单引号导致索引失效以%开头的like模糊查询 什么情况下索引会失效? 示例,有user表如下 CREATE TABLE user (id bigint(20) NOT NU…

JAVA期末速成库(11)第十二章

一、习题介绍 第十二章 Check Point:P454 12.1,12.9,12.10,12,12 二、习题及答案 12.1 What is the advantage of using exception handling? 12.1使用异常处理的优势是什么? 答:使用异常处理有以下优势: 1. 提高…

Spark join数据倾斜调优

Spark中常见的两种数据倾斜现象如下 stage部分task执行特别慢 一般情况下是某个task处理的数据量远大于其他task处理的数据量,当然也不排除是程序代码没有冗余,异常数据导致程序运行异常。 作业重试多次某几个task总会失败 常见的退出码143、53、137…

【电路笔记】-放大器类型

放大器类型 文章目录 放大器类型1、概述2、关于偏置的注意事项3、A类(Class A)放大器4、B类(Class B)放大器5、AB类(Class AB)放大器6、C类(Class C)放大器7、总结1、概述 放大器通常根据输出级的结构进行分类。 事实上,功率放大确实发生在该阶段,因此输出信号的质量和…

Java SE入门及基础(59) 线程的实现(上) 线程的创建方式 线程内存模型 线程安全

目录 线程(上) 1. 线程的创建方式 Thread类常用构造方法 Thread类常用成员方法 Thread类常用静态方法 示例 总结 2. 线程内存模型 3.线程安全 案例 代码实现 执行结果 线程(上) 1. 线程的创建方式 An application t…

利用 Docker 简化 Nacos 部署:快速搭建 Nacos 服务

利用 Docker 简化 Nacos 部署:快速搭建 Nacos 服务 引言 在微服务架构中,服务注册与发现是确保服务间通信顺畅的关键组件。Nacos(Dynamic Naming and Configuration Service)作为阿里巴巴开源的一个服务发现和配置管理平台&…

任务调度器——任务切换

一、开启任务调度器 函数原型: void vTaskStartScheduler( void ) 作用:用于启动任务调度器,任务调度器启动后, FreeRTOS 便会开始进行任务调度 内部实现机制(以动态创建为例): &#xff0…

Linux 安装、配置Tomcat 的HTTPS

Linux 安装 、配置Tomcat的HTTPS 安装Tomcat 这里选择的是 tomcat 10.X ,需要Java 11及更高版本 下载页 ->Binary Distributions ->Core->选择 tar.gz包 下载、上传到内网服务器 /opt 目录tar -xzf 解压将解压的根目录改名为 tomat-10 并移动到 /opt 下, 形成个人…

测评推荐:企业管理u盘的软件有哪些?

U盘作为一种便携的存储设备,方便易用,被广泛应用于企业办公、个人学习及日常工作中。然而,U盘的使用也带来了数据泄露、病毒传播等安全隐患。为了解决这些问题,企业管理U盘的软件应运而生。 本文将对市面上流行的几款U盘管理软件…

Hadoop3:Yarn容量调度器配置多队列案例

一、情景描述 需求1: default队列占总内存的40%,最大资源容量占总资源60%,hive队列占总内存的60%,最大资源容量占总资源80%。 二、多队列优点 (1)因为担心员工不小心,写递归死循环代码&#…

电路笔记(电源模块): 基于FT2232HL实现的jtag下载器硬件+jtag的通信引脚说明

JTAG接口说明 JTAG 接口根据需求可以选择20针或14针的配置,具体选择取决于应用场景和需要连接的功能。比如之前的可编程逻辑器件XC9572XL使用JTAG引脚(TCK、TDI、TDO、TMS、VREF、GND)用于与器件进行调试和编程通信。更详细的内容可以阅读11…

51单片机STC8H8K64U通过RA8889/RA8876如何控制彩屏(SPI源码下载)

【硬件部份】 一、硬件连接实物: STC8H系列单片机不需要外部晶振和外部复位,在相同的工作频率下,速度比传统的8051单片机要快12倍,具有高可靠抗干扰的优秀特性,与瑞佑的RA8889/RA8876控制芯片刚好可以完美搭配用于工…

redis实战-缓存雪崩问题及解决方案

定义理解 缓存雪崩是指在同一时间段,大量缓存的key同时失效,或者Redis服务宕机,导致大量请求到达数据库,带来巨大压力 和缓存击穿的区别: 缓存雪崩是由于缓存中的大量数据同时失效或缓存服务器故障引起的&#xff1b…

机器学习周记(第四十五周:Graphformer)2024.6.24~2024.6.30

目录 摘要ABSTRACT1 论文信息1.1 论文标题1.2 论文摘要1.3 论文引言1.4 论文贡献 2 论文模型2.1 问题定义2.2 模型架构2.2.1 自注意下采样模块(Self-attention down-sampling module)2.2.2 稀疏图自注意力机制(Sparse graph self-attention m…

【C++】using namespace std 到底什么意思

📢博客主页:https://blog.csdn.net/2301_779549673 📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正! 📢本文作为 JohnKi 的学习笔记,引用了部分大佬的案例 📢未来很长&a…

新手练习项目 7:猜数字游戏

名人说:莫听穿林打叶声,何妨吟啸且徐行。—— 苏轼《定风波莫听穿林打叶声》 Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder) 目录 一、项目描述二、项目实现三、项目步骤四、项目扩展方向 更多项目内容,请关注我、订…

打靶记录——靶机medium_socnet

靶机下载地址 https://www.vulnhub.com/entry/boredhackerblog-social-network,454/ 打靶过程 由于靶机和我的Kali都处于同一个网段,所以使用arpscan二次发现技术来识别目标主机的IP地址 arpscan -l除了192.168.174.133,其他IP都是我VMware虚拟机正…

【Spring Boot】认识 JPA 的接口

认识 JPA 的接口 1.JPA 接口 JpaRepository2.分页排序接口 PagingAndSortingRepository3.数据操作接口 CrudRepository4.分页接口 Pageable 和 Page5.排序类 Sort JPA 提供了操作数据库的接口。在开发过程中继承和使用这些接口,可简化现有的持久化开发工作。可以使 …

springboot学习,如何用redission实现分布式锁

目录 一、springboot框架介绍二、redission是什么三、什么是分布式锁四、如何用redission实现分布式锁 一、springboot框架介绍 Spring Boot是一个开源的Java框架,由Pivotal团队(现为VMware的一部分)于2013年推出。它旨在简化Spring应用程序…