机器学习-10-基于paddle实现神经网络

文章目录

    • 总结
    • 参考
    • 本门课程的目标
    • 机器学习定义
    • 第一步:数据准备
    • 第二步:定义网络
    • 第三步:训练网络
    • 第四步:测试训练好的网络

总结

本系列是机器学习课程的系列课程,主要介绍基于paddle实现神经网络。

参考

MNIST 训练_副本

本门课程的目标

完成一个特定行业的算法应用全过程:

懂业务+会选择合适的算法+数据处理+算法训练+算法调优+算法融合
+算法评估+持续调优+工程化接口实现

机器学习定义

关于机器学习的定义,Tom Michael Mitchell的这段话被广泛引用:
对于某类任务T性能度量P,如果一个计算机程序在T上其性能P随着经验E而自我完善,那么我们称这个计算机程序从经验E中学习
在这里插入图片描述

使用MNIST数据集训练和测试模型。

第一步:数据准备

MNIST数据集

import paddle
from paddle.vision.datasets import MNIST
from paddle.vision.transforms import ToTensortrain_dataset = MNIST(mode='train', transform=ToTensor())
test_dataset = MNIST(mode='test', transform=ToTensor())

展示数据集图片

import matplotlib.pyplot as plt
import numpy as nptrain_data0, train_label_0 = train_dataset[0][0], train_dataset[0][1]
train_data0 = train_data0.reshape([28, 28])
plt.figure(figsize=(2, 2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 的标签为: ' + str(train_label_0))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingfrom collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingfrom collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingfrom collections import Sized
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingif isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingreturn list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))train_data0 的标签为: [5]

第二步:定义网络

import paddle
import paddle.nn.functional as F
from paddle.nn import Conv2D, MaxPool2D, Linearclass MyModel(paddle.nn.Layer):def __init__(self):super(MyModel, self).__init__()self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)self.linear1 = Linear(in_features=16*5*5, out_features=120)self.linear2 = Linear(in_features=120, out_features=84)self.linear3 = Linear(in_features=84, out_features=10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.max_pool1(x)x = F.relu(x)x = self.conv2(x)x = self.max_pool2(x)x = paddle.flatten(x, start_axis=1, stop_axis=-1)x = self.linear1(x)x = F.relu(x)x = self.linear2(x)x = F.relu(x)x = self.linear3(x)return x

模型可视化

import paddle
mnist = MyModel()
paddle.summary(mnist, (1, 1, 28, 28))
---------------------------------------------------------------------------Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================Conv2D-1       [[1, 1, 28, 28]]      [1, 6, 28, 28]          156      MaxPool2D-1     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0       Conv2D-2       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416     MaxPool2D-2    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0       Linear-1          [[1, 400]]            [1, 120]           48,120     Linear-2          [[1, 120]]            [1, 84]            10,164     Linear-3          [[1, 84]]             [1, 10]              850      
===========================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.24
Estimated Total Size (MB): 0.30
---------------------------------------------------------------------------{'total_params': 61706, 'trainable_params': 61706}

第三步:训练网络

import paddle
from paddle.metric import Accuracy
from paddle.static import InputSpecinputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')# 用Model封装模型
model = paddle.Model(MyModel(), inputs, labels)# 定义损失函数
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())# 配置模型
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())
# 训练模型
model.fit(train_dataset, test_dataset, epochs=3, batch_size=64, save_dir='mnist_checkpoint', verbose=1)
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/3
step 938/938 [==============================] - loss: 0.0208 - acc: 0.9456 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/0
Eval begin...
step 157/157 [==============================] - loss: 0.0041 - acc: 0.9777 - 19ms/step          
Eval samples: 10000
Epoch 2/3
step 938/938 [==============================] - loss: 0.0021 - acc: 0.9820 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/1
Eval begin...
step 157/157 [==============================] - loss: 2.1037e-04 - acc: 0.9858 - 19ms/step      
Eval samples: 10000
Epoch 3/3
step 938/938 [==============================] - loss: 0.0126 - acc: 0.9876 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/2
Eval begin...
step 157/157 [==============================] - loss: 4.7168e-04 - acc: 0.9884 - 19ms/step      
Eval samples: 10000
save checkpoint at /home/aistudio/mnist_checkpoint/final

第四步:测试训练好的网络

import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.metric import Accuracy
from paddle.static import InputSpecinputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')
model = paddle.Model(MyModel(), inputs, labels)
model.load('./mnist_checkpoint/final')
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())# results = model.evaluate(test_dataset, batch_size=64, verbose=1)
# print(results)results = model.predict(test_dataset, batch_size=64)test_data0, test_label_0 = test_dataset[0][0], test_dataset[0][1]
test_data0 = test_data0.reshape([28, 28])
plt.figure(figsize=(2,2))
plt.imshow(test_data0, cmap=plt.cm.binary)print('test_data0 的标签为: ' + str(test_label_0))
print('test_data0 预测的数值为:%d' % np.argsort(results[0][0])[0][-1])
Predict begin...
step 157/157 [==============================] - 27ms/step          
Predict samples: 10000
test_data0 的标签为: [7]
test_data0 预测的数值为:7/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/1685.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Node.js】01 —— fs模块全解析

🔥【Node.js】 fs模块全解析 📢 引言 在Node.js开发中,fs模块犹如一把万能钥匙,解锁着整个文件系统的操作。从读取文件、写入文件、检查状态到目录管理,无所不能。接下来,我们将逐一揭开fs模块中最常用的那…

vue ant form validate如何对数组下的表单校验

问题 使用Ant Design Vue校验表单时&#xff0c;通过validateFields&#xff0c;但是如何一个数组内部的校验呢&#xff1f; 效果图&#xff1a; 实现方式&#xff1a; 通过 v-for 循环渲染:name"[]"实现&#xff0c;我们直接看代码。 <template><a-for…

Spring Boot中JUnit 4与JUnit 5的如何共存

文章目录 前言一、先上答案二、稍微深入了解2.1 maven-surefire-plugin是什么2.2 JUnit4和JUnit5有什么区别2.2.1 不同的注解2.2.2 架构 前言 在maven项目中&#xff0c;生成单测时是否有这样的疑问&#xff1a;该选JUnit4还是JUnit5&#xff1f;在执行 mvn test 命令时有没有…

三、SpringBoot整合MyBatis

本章节主要描述MyBatis的整合&#xff0c;以及使用mybatis-generator-maven-plugin生成代码骨架&#xff0c;源码&#xff1a; jun/learn-springboot - Gitee.com 一、首先建数据库 本示例用的是MySQL8.0.23&#xff0c;建表t_goods、t_orders&#xff0c;略... 二、goods模块…

Java | Leetcode Java题解之第36题有效的数独

题目&#xff1a; 题解&#xff1a; class Solution {public boolean isValidSudoku(char[][] board) {int[][] rows new int[9][9];int[][] columns new int[9][9];int[][][] subboxes new int[3][3][9];for (int i 0; i < 9; i) {for (int j 0; j < 9; j) {char …

随机森林原理及应用

目录 一、随机森林原理、优点、应用场景 1.1基本原理 1.2主要优点 1.3使用场景 二、具体实例 一、随机森林原理、优点、应用场景 随机森林是一种流行且强大的机器学习算法&#xff0c;属于集成学习方法的一部分&#xff0c;主要用于分类和回归任务。它通过组合多个决策树…

SSTV音频转图片

SSTV工具有很多&#xff0c;这里使用RX-SSTV慢扫描工具 下载安装 RX-SSTV解码软件 下载地址&#xff1a;https://www.qsl.net/on6mu/rxsstv.htm 一直点下一步&#xff0c;安装成功如下图: 虚拟声卡e2eSoft 由于SSTV工具是根据音频传递图片信息&#xff0c;正常解法需要一…

在【laravel框架】学习中遇到的常见的问题以及解决方法

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;开发者-曼亿点 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 曼亿点 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a…

Marching Cubes算法

Marching Cubes算法 1. 简介2. 算法原理的理解2.1 如何找到面经过的这些小块(六面体)&#xff1f;2.2 找到后&#xff0c;如何又进一步的找到面与这些小块(六面体)的交点&#xff1b;2.3 这些交点按照怎么的拓扑连接关系连接&#xff0c;是怎么操作的&#xff1f; 3. 总结4. 参…

金融时报:波场亮相哈佛大学并举办TRON Builder Tour活动

近日,波场TRON作为顶级白金赞助商出席哈佛区块链会议并成功举办TRON Builder Tour哈佛站活动,引发海外媒体热议。美联社、金融时报、Cointelegraph等国际主流媒体及加密知名媒体均对此给予了高度评价,认为本次大会对TRON Builder Tour活动具有里程碑意义,彰显了波场TRON致力于促…

mysql基础5——设置主键

业务字段尽量不要用做主键 删除主键&#xff0c;只是主键被删除&#xff0c;字段还存在 alter table demo.membermaster drop primary key; 添加一个字段设置为主键并给主键添加自增约束 alter table demo.membermaster add column id int primary key auto_increment; 自增…

Gitea 简单介绍、用法以及使用注意事项!

Gitea 是一个轻量级的代码托管解决方案&#xff0c;它提供了一个简单而强大的平台&#xff0c;用于托管和协作开发项目。基于 Go 语言编写&#xff0c;与 GitLab 和 GitHub Enterprise 类似&#xff0c;但专为自托管而设计。以下是对 Gitea 的详细介绍&#xff0c;包括常用命令…

anaconda配置的环境对应的地址查看,环境安装位置

打开conda指令窗口 这个和上面的都一样&#xff0c;哪个都行 点开后&#xff0c;输入 conda env list 这里显示的就是自己的每个环境对应的地址了

游戏黑灰产识别和溯源取证

参考&#xff1a;游戏黑灰产识别和溯源取证 1. 游戏中的黑灰产 1. 黑灰产简介 黑色产业&#xff1a;从事具有违法性活动且以此来牟取利润的产业&#xff1b; 灰色产业&#xff1a;不明显触犯法律和违背道德&#xff0c;游走于法律和道德边缘&#xff0c;以打擦边球的方式为“…

巧用断点设置查找bug【debug】

默认设置的断点&#xff0c;当代码运行到断点处MCU就会被挂起&#xff0c;从而停在断点处。 但在某些情况下&#xff0c;如调试FCCU时&#xff0c;如果设置断点&#xff0c;MCU停下后将会导致 FCCU 配置WDG超时。或在调试类似电机控制类的应用时&#xff0c;不适当的断点会导 致…

复合升降机器人教学科研平台——技术方案

一&#xff1a;功能概述 1.1 功能简介 复合升降机器人是一款集成移动底盘、机械臂、末端执行器、边缘计算平台等机构形成的教学科研平台&#xff0c;可实现机器人建图导航、路径规划&#xff0c;机械臂运动学、动力学、轨迹规划、视觉识别等算法功能和应用&#xff0c;提供例如…

Python中列表数据的保存与读取:以txt文件为例

目录 引言 一、列表数据的保存 二、列表数据的读取 三、进阶用法与注意事项 1. 处理嵌套列表 2. 处理大量数据 3. 注意事项 四、总结 引言 在Python编程中&#xff0c;我们经常需要处理各种类型的数据&#xff0c;包括列表。列表是一种非常灵活的数据结构&#xff0c;…

边缘计算的优势

边缘计算的优势 边缘计算是一种在数据生成地点附近处理数据的技术&#xff0c;而非传统的将数据发送到远端数据中心或云进行处理。这种计算模式对于需要快速响应的场景特别有效&#xff0c;以下详述了边缘计算的核心优势。 1. 降低延迟 边缘计算通过在数据源近处处理数据&…

imx6ull设备树驱动--pinctl、ioctl

添加pinctl节点 进入arch/arm/boot/dts目录下dts文件 在iomuxc下添加pinctlled节点 将 GPIO1_IO03 这个 PIN 复用为 GPIO1_IO03&#xff0c;电气属性&#xff08;配置GPIO一些列寄存器&#xff09;值为 0X10B0 添加led设备节点 与上一节一样&#xff0c;在 / 下面添加设备节…

数电期末复习(四)组合逻辑电路

这里写目录标题 4.1 概述4.2 组合逻辑电路的分析方法4.3 组合逻辑电路的设计方法4.4 若干常用组合逻辑电路4.4.1 编码器&#xff08;encoder&#xff09;4.4.2 译码器(decoder)4.4.3 数据选择器 (data selector)4.3.4 加法器&#xff08;Adder&#xff09;4.4.4 数值比较器&…