机器学习-11-卷积神经网络-基于paddle实现神经网络

文章目录

    • 总结
    • 参考
    • 本门课程的目标
    • 机器学习定义
    • 第一步:数据准备
    • 第二步:定义网络
    • 第三步:训练网络
    • 第四步:测试训练好的网络

总结

本系列是机器学习课程的系列课程,主要介绍基于paddle实现神经网络。

参考

MNIST 训练_副本

本门课程的目标

完成一个特定行业的算法应用全过程:

懂业务+会选择合适的算法+数据处理+算法训练+算法调优+算法融合
+算法评估+持续调优+工程化接口实现

机器学习定义

关于机器学习的定义,Tom Michael Mitchell的这段话被广泛引用:
对于某类任务T性能度量P,如果一个计算机程序在T上其性能P随着经验E而自我完善,那么我们称这个计算机程序从经验E中学习
在这里插入图片描述

使用MNIST数据集训练和测试模型。

第一步:数据准备

MNIST数据集

import paddle
from paddle.vision.datasets import MNIST
from paddle.vision.transforms import ToTensortrain_dataset = MNIST(mode='train', transform=ToTensor())
test_dataset = MNIST(mode='test', transform=ToTensor())

展示数据集图片

import matplotlib.pyplot as plt
import numpy as nptrain_data0, train_label_0 = train_dataset[0][0], train_dataset[0][1]
train_data0 = train_data0.reshape([28, 28])
plt.figure(figsize=(2, 2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 的标签为: ' + str(train_label_0))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingfrom collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingfrom collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingfrom collections import Sized
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingif isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingreturn list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))train_data0 的标签为: [5]

第二步:定义网络

import paddle
import paddle.nn.functional as F
from paddle.nn import Conv2D, MaxPool2D, Linearclass MyModel(paddle.nn.Layer):def __init__(self):super(MyModel, self).__init__()self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)self.linear1 = Linear(in_features=16*5*5, out_features=120)self.linear2 = Linear(in_features=120, out_features=84)self.linear3 = Linear(in_features=84, out_features=10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.max_pool1(x)x = F.relu(x)x = self.conv2(x)x = self.max_pool2(x)x = paddle.flatten(x, start_axis=1, stop_axis=-1)x = self.linear1(x)x = F.relu(x)x = self.linear2(x)x = F.relu(x)x = self.linear3(x)return x

模型可视化

import paddle
mnist = MyModel()
paddle.summary(mnist, (1, 1, 28, 28))
---------------------------------------------------------------------------Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================Conv2D-1       [[1, 1, 28, 28]]      [1, 6, 28, 28]          156      MaxPool2D-1     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0       Conv2D-2       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416     MaxPool2D-2    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0       Linear-1          [[1, 400]]            [1, 120]           48,120     Linear-2          [[1, 120]]            [1, 84]            10,164     Linear-3          [[1, 84]]             [1, 10]              850      
===========================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.24
Estimated Total Size (MB): 0.30
---------------------------------------------------------------------------{'total_params': 61706, 'trainable_params': 61706}

第三步:训练网络

import paddle
from paddle.metric import Accuracy
from paddle.static import InputSpecinputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')# 用Model封装模型
model = paddle.Model(MyModel(), inputs, labels)# 定义损失函数
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())# 配置模型
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())
# 训练模型
model.fit(train_dataset, test_dataset, epochs=3, batch_size=64, save_dir='mnist_checkpoint', verbose=1)
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/3
step 938/938 [==============================] - loss: 0.0208 - acc: 0.9456 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/0
Eval begin...
step 157/157 [==============================] - loss: 0.0041 - acc: 0.9777 - 19ms/step          
Eval samples: 10000
Epoch 2/3
step 938/938 [==============================] - loss: 0.0021 - acc: 0.9820 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/1
Eval begin...
step 157/157 [==============================] - loss: 2.1037e-04 - acc: 0.9858 - 19ms/step      
Eval samples: 10000
Epoch 3/3
step 938/938 [==============================] - loss: 0.0126 - acc: 0.9876 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/2
Eval begin...
step 157/157 [==============================] - loss: 4.7168e-04 - acc: 0.9884 - 19ms/step      
Eval samples: 10000
save checkpoint at /home/aistudio/mnist_checkpoint/final

第四步:测试训练好的网络

import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.metric import Accuracy
from paddle.static import InputSpecinputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')
model = paddle.Model(MyModel(), inputs, labels)
model.load('./mnist_checkpoint/final')
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())# results = model.evaluate(test_dataset, batch_size=64, verbose=1)
# print(results)results = model.predict(test_dataset, batch_size=64)test_data0, test_label_0 = test_dataset[0][0], test_dataset[0][1]
test_data0 = test_data0.reshape([28, 28])
plt.figure(figsize=(2,2))
plt.imshow(test_data0, cmap=plt.cm.binary)print('test_data0 的标签为: ' + str(test_label_0))
print('test_data0 预测的数值为:%d' % np.argsort(results[0][0])[0][-1])
Predict begin...
step 157/157 [==============================] - 27ms/step          
Predict samples: 10000
test_data0 的标签为: [7]
test_data0 预测的数值为:7/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/4893.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

编译Qt6.5.3LTS版本(Mac/Windows)的mysql驱动(附带编译后的全部文件)

文章目录 0 背景1 编译过程2 福利参考 0 背景 因为项目要用到对MYSQL数据库操作,所以需要连接到MYSQL数据库。但是连接需要MYSQL驱动,但是Qt本身不自带MYSQL驱动,需要自行编译。网上有很多qt之前版本的mysql驱动,但是没有找到qt6…

使用Ollama和OpenWebUI在CPU上玩转Meta Llama3-8B

2024年4月18日,meta开源了Llama 3大模型[1],虽然只有8B[2]和70B[3]两个版本,但Llama 3表现出来的强大能力还是让AI大模型界为之震撼了一番,本人亲测Llama3-70B版本的推理能力十分接近于OpenAI的GPT-4[4],何况还有一个4…

降薪、调岗、裁员,库迪咖啡没有“钱景”

文 | 螳螂观察 作者 | 青玥 2024年,连锁咖啡行业依然激战正酣。 卷价格、卷联名、卷代言人,部分咖啡小店已经在这样内卷的氛围中,率先被淘汰,咖门的统计数据显示,2023年前九个月,注销的咖啡企业有9825家…

Unity 数字字符串逗号千分位

使用InputField时处理输入的数字型字符串千分位自动添加逗号,且自动保留两位有效数字 输入:123 输出:123.00 输入:12345 输出:12,345.00 代码非常简单 using UnityEngine; using TMPro;public class …

前端高并发的出现场景及解决方法——技能提升——p-limit的使用

最近在写后台管理系统的时候,遇到一个场景,就是打印的页面需要根据传入的多个id,分别去请求详情接口。 比如id有10个,则需要调用10次详情接口获取到数据,最后对所有的数据进行整合后页面渲染。 相信大家或多或少都遇到…

ASP.NET实验室预约系统的设计

摘 要 实验室预约系统的设计主要是基于B/S模型,在Windows系统下,运用ASP.NET平台和SQLServer2000数据库实现实验室预约功能。该设计主要实现了实验室的预约和管理功能。预约功能包括老师对实验室信息、实验项目和实验预约情况的查询以及对实验室的预约…

Linux--05---相对路径与绝对路径、终端的认识

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 1. 相对路径与绝对路径切换到用户家目录:cd ~当前目录:./ 2. 对终端的认识3. 文件的类型颜色表示的文件类型:文件类型和权限的表…

基于深度学习的实时人脸检测与情绪分类

情绪分类 实时人脸检测与情绪分类 Kaggle Competion 数据集 fer2013 中的测试准确率为 66%CK数据集的检验准确率为99.87%情绪分类器模型预测从网络摄像头捕获的实时视频中的平均成本时间为 4~ 10ms 关键技术要点: 实时人脸检测:系统采用了前沿的人脸检…

明日周刊-第8期

现在露营的人越来越多了,都是带着帐篷或者遮阳篷聚在一起喝喝茶聊聊天,这是一种很好的放松方式。最近我养了一只金毛,目前两个月大,非常可爱。 文章目录 一周热点资源分享言论歌曲推荐 一周热点 一、人工智能领域 本周&#xff…

迈威通信首秀成都工博会,迸发品牌强势能

4月26日,在繁花似锦的成都,一场工业界的盛会刚刚落下帷幕。而在这场盛会的璀璨星辰中,迈威通信以其首秀之姿,迸发出耀眼的光芒,强势展现了品牌的实力与魅力。 此次展会,迈威通信以“工业互联赋能新数字化智…

C语言求 MD5 值

MD5值常被用于验证数据的完整性,嵌入式开发时经常用到。md5sum命令可以求MD5码,下面介绍如何用C语言实现MD5功能。 一、求字符串MD5值 1、md5sum命令 $ echo -n "12345678" | md5sum //获取"12345678"字符串的md5值 结果&…

react 学习笔记二:ref、状态、继承

基础知识 1、ref 创建变量时,需要运用到username React.createRef(),并将其绑定到对应的节点。在使用时需要获取当前的节点; 注意:vue直接使用里面的值,不需要再用this。 2、状态 组件描述某种显示情况的数据&#…

VCSA6.7重置root密码

VCSA6.7重置root密码 1、登录VCSA所运行的ESXI主机 2、打开VCSA虚拟机Web控制台,先拍摄一个快照,然后重启虚拟机,在如下界面按"e" 3、找到linux开头的段落,在末尾追加rw init/bin/bash; 4、输入完成后,按&…

LPDDR5和LPDDR5X区别

发布时间 LPDDR5和LPDDR5X的发布时间如下: LPDDR5的具体发布时间没有直接提及,但它在市场上的应用早于LPDDR5X。LPDDR5作为LPDDR4(X)的继任者,其规范发布和商用化大致发生在2019年至2020年间,具体技术细节和产品商用则依据各制造…

【ruoyi-vue】关于密码重置

文章目录 前言解决问题 前言 在qq群里经常看到问ruoyi的账号密码是多少?有源代码忘记了登录密码怎么办? 解决问题 在 ruoyi-admin 模块内 SysUserController找到新增用户或修改用户密码的相关接口在里面就可以找相关创建密码的方法ruoyi里的创建密码的…

MySQL从入门到高级 --- 3.DML基本操作

文章目录 第三章:3.基本操作 - DML3.1 数据插入3.2 数据修改3.3 数据删除3.4 练习 第三章: 3.基本操作 - DML DML:数据操作语言,用来对数据中表的数据记录进行更新 关键字: insert 插入 delete 删除 update 更新 …

OceanBase V4.3 发布—— 迈向实时分析 AP 的重要里程

OceanBase在2023年初,发布了4.x架构的第一个重要版本,V4.1。该版本采用了单机分布式一体化架构,并在该架构的基础上,将代表数据库可靠性的RTO降低至 8 秒以内,从而确保在意外故障发生后,系统能够在极短时间…

碳化硅片有哪些比较重要的参数?

知识星球(星球名:芯片制造与封测社区)里的学员问:请问碳化硅衬底片到客户端验证主要测试什么项目,比较重要的参数有哪些? Lattice Parameters:晶格参数。确保衬底的晶格常数与将要生长的外延层…

面对网络安全,做好风险评估对企业会带来哪些帮助

随着信息技术的飞速发展,网络安全问题日益凸显,成为企业不容忽视的重要议题。企业作为社会经济活动的主要参与者,其网络安全不仅关系到自身的生存与发展,更与国家的经济安全、社会稳定息息相关。因此,企业必须高度重视…

盲人手机导航:科技之光引领无障碍出行新纪元

在这个日新月异的数字时代,科技不仅改变了我们获取信息的方式,更在无声中拓宽了视障人士的生活半径。盲人手机导航这一创新技术,正逐步成为他们探索世界、实现独立出行的重要伙伴。 对于大多数人而言,日常出行或许只是一次…