昇思MindSpore学习笔记6--网络构建

摘要:

记录了昇思MindSpore神经网络模型的定义构建、模型分层,以及获取模型参数的步骤方法。

一、神经网络模型概念

神经网络模型是由神经网络层和Tensor操作构成

mindspore.nn实现了神经网络层

mindspore.nn.Cell是构建神经网络的基类

支持子Cell嵌套形成复杂的神经网络。

二、环境准备

安装minspore模块

!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1

导入minspore、nn、ops等相关模块

import mindspore
from mindspore import nn, ops

三、定义模型类

继承nn.Cell类

__init__ 初始化方法实现子Cell实例化和状态管理

construct 构建方法实现Tensor操作

示例:Mnist数据集分类神经网络模型。

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),nn.ReLU(),nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),nn.ReLU(),nn.Dense(512, 10, weight_init="normal", bias_init="zeros"))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()
print(model)

输出:

Network<(flatten): Flatten<>(dense_relu_sequential): SequentialCell<(0): Dense<input_channels=784, output_channels=512, has_bias=True>(1): ReLU<>(2): Dense<input_channels=512, output_channels=512, has_bias=True>(3): ReLU<>(4): Dense<input_channels=512, output_channels=10, has_bias=True>>>

构造输入数据,直接调用模型,获得一个十维的Tensor输出,包含每个类别的原始预测值。

【为什么?model.construct()方法不可直接调用。

X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
logits

输出:

Tensor(shape=[1, 10], dtype=Float32, value=
[[-6.12840615e-03, -9.06550791e-03,  7.44015072e-03 ... -1.21108280e-03,  8.75189435e-05,  1.14824511e-02]])

nn.Softmax()获得预测概率

pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

输出:

Predicted class: [9]

四、模型层

神经网络模型层分解

构造3个28x28的随机图像,shape为(3, 28, 28)

input_image = ops.ones((3, 28, 28), mindspore.float32)
print(input_image.shape)

输出:

(3, 28, 28)

1. nn.Flatten

实例化nn.Flatten层,将28x28的2D张量转换为784大小的连续数组。

类似理解为将2维矩阵转换为1维数组。

flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.shape)

输出:

(3, 784)

2. nn.Dense

nn.Dense为全连接层,使用权重和偏差对输入进行线性变换。

【猜测理解】排除噪声数据

layer1 = nn.Dense(in_channels=28*28, out_channels=20)
hidden1 = layer1(flat_image)
print(hidden1.shape)

输出:

(3, 20)

3. nn.ReLU

nn.ReLU层给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。

【猜测理解】提取特征模型

print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")

输出:

Before ReLU: 
[[ 0.22167806 -0.6745967  -0.27886766  0.22455142 -0.23248096 -0.133361430.10336601  0.55169076  1.5191591   0.73823357  0.14714427  0.70888543-0.38694382 -0.49788377  0.06992912  0.1532075  -0.5910665   0.69121903-0.3803944   1.0794045 ][ 0.22167806 -0.6745967  -0.27886766  0.22455142 -0.23248096 -0.133361430.10336601  0.55169076  1.5191591   0.73823357  0.14714427  0.70888543-0.38694382 -0.49788377  0.06992912  0.1532075  -0.5910665   0.69121903-0.3803944   1.0794045 ][ 0.22167806 -0.6745967  -0.27886766  0.22455142 -0.23248096 -0.133361430.10336601  0.55169076  1.5191591   0.73823357  0.14714427  0.70888543-0.38694382 -0.49788377  0.06992912  0.1532075  -0.5910665   0.69121903-0.3803944   1.0794045 ]]After ReLU: 
[[0.22167806 0.         0.         0.22455142 0.         0.0.10336601 0.55169076 1.5191591  0.73823357 0.14714427 0.708885430.         0.         0.06992912 0.1532075  0.         0.691219030.         1.0794045 ][0.22167806 0.         0.         0.22455142 0.         0.0.10336601 0.55169076 1.5191591  0.73823357 0.14714427 0.708885430.         0.         0.06992912 0.1532075  0.         0.691219030.         1.0794045 ][0.22167806 0.         0.         0.22455142 0.         0.0.10336601 0.55169076 1.5191591  0.73823357 0.14714427 0.708885430.         0.         0.06992912 0.1532075  0.         0.691219030.         1.0794045 ]]

4. nn.SequentialCell

nn.SequentialCell 顺序Cell容器

输入张量Tensor按照定义的顺序通过所有Cell

seq_modules = nn.SequentialCell(flatten,layer1,nn.ReLU(),nn.Dense(20, 10)
)logits = seq_modules(input_image)
print(logits.shape)

输出:

(3, 10)

5. nn.Softmax

最后使用nn.Softmax将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。axis指定的维度数值和为1。

【猜测理解】nn.Softmax根据神经网络返回的特征解推出预测概率

softmax = nn.Softmax(axis=1)
pred_probab = softmax(logits)

五、模型参数

包括网络内部神经网络层具有权重参数和偏置参数nn.Dense等。

model.parameters_and_names() 获取参数名及对应的参数

print(f"Model structure: {model}\n\n")
​
for name, param in model.parameters_and_names():print(f"Layer: {name}\nSize: {param.shape}\nValues : {param[:2]} \n")

输出:

Model structure: Network<(flatten): Flatten<>(dense_relu_sequential): SequentialCell<(0): Dense<input_channels=784, output_channels=512, has_bias=True>(1): ReLU<>(2): Dense<input_channels=512, output_channels=512, has_bias=True>(3): ReLU<>(4): Dense<input_channels=512, output_channels=10, has_bias=True>>>Layer: dense_relu_sequential.0.weight
Size: (512, 784)
Values : 
[[-0.00240309  0.01673641  0.01800174 ... -0.00494339  0.00156037-0.01077523][ 0.01024065  0.00036838  0.01390939 ... -0.03533866  0.015644050.00675983]] Layer: dense_relu_sequential.0.bias
Size: (512,)
Values : [0. 0.] Layer: dense_relu_sequential.2.weight
Size: (512, 512)
Values : 
[[ 0.01293653 -0.00968055  0.00156187 ... -0.00389436  0.00231833-0.00164853][ 0.01165598  0.00445346 -0.00812163 ...  0.00903445 -0.001936590.00366774]] Layer: dense_relu_sequential.2.bias
Size: (512,)
Values : [0. 0.] Layer: dense_relu_sequential.4.weight
Size: (10, 512)
Values : 
[[ 5.1974057e-05 -1.5710767e-02 -1.3622168e-02 ...  8.3433613e-03-2.0038823e-02  2.4895865e-02][-1.8089322e-02 -1.6186651e-02  1.5664777e-02 ... -8.6768856e-034.5306431e-03  8.7257465e-03]] Layer: dense_relu_sequential.4.bias
Size: (10,)
Values : [0. 0.] 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/37944.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《学透 Spring》学习笔记 | 笔记1 | Spring 家族

Spring 家族 概述 2003 年&#xff0c;Rod Johnson 对 J2EE&#xff08;现在的 Java EE&#xff09; 存在的各种问题进行深入的剖析&#xff0c;并提出了一套解决方案 —— Spring Framework&#xff0c;这就是最早期的 Spring。Spring Framework 对 J2EE 进行了一系列补充&a…

在线一起学习平台设计

设计一个在线一起学习平台&#xff0c;旨在促进远程协作学习&#xff0c;提升学习效率和体验。以下是设计的基本框架和关键功能&#xff1a; 1. 用户管理模块 注册与登录&#xff1a;用户可以通过邮箱、手机号或社交媒体账号注册和登录。个人资料&#xff1a;用户可以完善个人…

OpenCV 调用自定义训练的 YOLO-V8 Onnx 模型

一、YOLO-V8 转 Onnx 在本专栏的前面几篇文章中&#xff0c;我们使用 ultralytics 公司开源发布的 YOLO-V8 模型&#xff0c;分别 Fine-Tuning 实验了 目标检测、关键点检测、分类 任务&#xff0c;实验后发现效果都非常的不错&#xff0c;但是前面的演示都是基于 ultralytics…

【贪心】【哈希表】个人练习-Leetcode-846. Hand of Straights

题目链接&#xff1a;https://leetcode.cn/problems/hand-of-straights/ 题目大意&#xff1a;给出一数列&#xff0c;求是否能刚好将它们分成若干组&#xff0c;每组的元素数量为groupSize&#xff0c;并且元素连续。 思路&#xff1a;因为题目的限制很死&#xff0c;如果能…

C语言分支和循环(下)

C语言分支和循环&#xff08;下&#xff09; 1. 随机数生成1.1 rand1.2 srand1.3 time1.4 设置随机数的范围 2. 猜数字游戏实现 掌握了前面学习的这些知识&#xff0c;我们就可以写⼀些稍微有趣的代码了&#xff0c;比如&#xff1a; 写⼀个猜数字游戏 游戏要求&#xff1a; 电…

第6章 复制

文章目录 前言1.配置1.1建立复制1.2断开复制1.3 安全性1.4 只读1.5 传输延迟 2. 拓扑2.1.一主一从结构2.2.一主多从结构2.3.树状主从结构 3.原理3.1复制过程3.2数据同步3.3全量复制 前言 复制功能&#xff0c;实现了相同数据的多个Redis副本。复制功能是高可用Redis的基础&…

智能交通(2)——IntelliLight智能交通灯

论文分享&#xff1a;IntelliLight | Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mininghttps://dl.acm.org/doi/10.1145/3219819.3220096摘要 智能交通灯控制对于高效的交通系统至关重要。目前现有的交通信号灯大多由手…

【Python系列】列表推导式:简洁而强大的数据操作工具

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

hnust 1949: 顺序表实现(第二部分)

hnust 1949: 顺序表实现(第二部分) 题目描述 拷贝下面的代码&#xff0c;然后将顺序表实现第一部分的工作填入&#xff0c;再完成顺序表的GetElem&#xff0c;LocateElem和ListDelete操作&#xff0c;其他地方不得改动。 #include #include #include using namespace std; #…

已成功与服务器建立连接,但是在登录过程中发生错误。(provider: SSL提供程序,error:0-证书链是由不受信任的颁发机构颁发的。)

已成功与服务器建立连接&#xff0c;但是在登录过程中发生错误。(provider: SSL提供程序,error:0-证书链是由不受信任的颁发机构颁发的。) 在连接SQL Server2008R2数据库时发生错误。 连接字符串&#xff1a;server127.0.0.1;uidsa;pwd1;databasedb; 解决办法&#xff1a; 方…

PySide(PyQt)在图像上画线

1、按鼠标左键任意画线 import sys from PySide6.QtWidgets import QApplication, QLabel, QVBoxLayout, QWidget from PySide6.QtGui import QPainter, QPixmap, QMouseEvent, QColor, QPen from PySide6.QtCore import Qt, QPointclass PaintLabel(QLabel):def __init__(self…

如何使用FlowUs打造爆款自媒体内容?内容资产管理沉淀的先进工具选息流

FlowUs 是一款流行的在线协作工具&#xff0c;它以其灵活的块编辑器、看板视图、数据库管理等功能受到众多个人和团队的喜爱。将其应用于内容资产管理&#xff0c;尤其是对于追求打造爆款自媒体的创作者而言&#xff0c;可以极大地提升内容创作、组织、分发及分析的效率。 内容…

无刷直流电机(BLDCM)仿真建模

无刷直流电机&#xff0c;即BLDCM在各个行业应用非常广泛。在汽车电子领域&#xff0c;BLDCM被广泛用于电动汽车、混合动力汽车、电动自行车等车辆的驱动系统中。由于BLDCM具有高效率、高力矩密度和快速响应的优势&#xff0c;它可以提供可靠的动力输出&#xff0c;并且可以通过…

JavaScript 动态网页实例 —— 背景效果

页面背景是网页设计中必不可少的重要内容之一,其背景的好坏直接影响网页浏览者的浏览兴趣。网页背景分为背景图和背景色两种,对于普通的背景图和背景色,完全可以通过HTML实现,而要实现复杂的背景效果,则需要借助于JavaScript。本章介绍页面背景的一些实现效果。首先是一个…

idea常用配置 | 快捷注释

idea快速注释 一、类上快速注释 &#xff08;本方法是IDEA环境自带的&#xff0c;设置特别方便简单易使用&#xff09; 1、偏好设置->编辑器->文件和代码模版 | File-Settings-Editor-File and Code Templates 2、右下方的“描述”中有相对应的自动注注释配置格式 贴…

力扣 单词规律

所用数据结构 哈希表 核心方法 判断字符串pattern 和字符串s 是否存在一对一的映射关系&#xff0c;按照题意&#xff0c;双向连接的对应规律。 思路以及实现步骤 1.字符串s带有空格&#xff0c;因此需要转换成字符数组进行更方便的操作&#xff0c;将字符串s拆分成单词列表…

Java单体架构项目_云霄外卖-特殊点

项目介绍&#xff1a; 定位&#xff1a; 专门为餐饮企业&#xff08;餐厅、饭店&#xff09;定制的一款软件商品 分为&#xff1a; 管理端&#xff1a;外卖商家使用 用户端&#xff08;微信小程序&#xff09;&#xff1a;点餐用户使用。 功能架构&#xff1a; &#xff08…

Python学习笔记20:进阶篇(九)常见标准库使用之sys模块和re模块

前言 本文是根据python官方教程中标准库模块的介绍&#xff0c;自己查询资料并整理&#xff0c;编写代码示例做出的学习笔记。 根据模块知识&#xff0c;一次讲解单个或者多个模块的内容。 教程链接&#xff1a;https://docs.python.org/zh-cn/3/tutorial/index.html 错误输出…

电商平台数据爬取经验分享

一、引言 在电商领域&#xff0c;数据的重要性不言而喻。无论是市场趋势分析、竞争对手研究&#xff0c;还是用户行为洞察&#xff0c;都离不开数据的支持。而数据爬虫作为获取这些数据的重要工具&#xff0c;其技术的掌握和运用对于电商平台来说至关重要。本文将结合个人实际…

AI绘画 Stable Diffusion【实战进阶】:图片的创成式填充,竖图秒变横屏壁纸!想怎么扩就怎么扩!

大家好&#xff0c;我是向阳。 所谓图片的创成式填充&#xff0c;就是基于原有图片进行扩展或延展&#xff0c;在保证图片合理性的同时实现与原图片的高度契合。是目前图像处理中常见应用之一。之前大部分都是通过PS工具来处理的。今天我们来看看在AI绘画工具 Stable Diffusio…