深度学习:tf.keras实现模型搭建、模型训练和预测

在sklearn中,模型都是现成的。tf.Keras是一个神经网络库,我们需要根据数据和标签值构建神经网络。神经网络可以发现特征与标签之间的复杂关系。神经网络是一个高度结构化的图,其中包含一个或多个隐藏层。每个隐藏层都包含一个或多个神经元。神经网络有多种类别,该程序使用的是密集型神经网络,也称为全连接神经网络:一个层中的神经元将从上一层中的每个神经元获取输入连接。例如,图 2 显示了一个密集型神经网络,其中包含 1 个输入层、2 个隐藏层以及 1 个输出层,如下图所示:

神经网络

上图 中的模型经过训练并馈送未标记的样本时,它会产生 3 个预测结果:相应鸢尾花属于指定品种的可能性。对于该示例,输出预测结果的总和是 1.0。该预测结果分解如下:山鸢尾为 0.02,变色鸢尾为 0.95,维吉尼亚鸢尾为 0.03。这意味着该模型预测某个无标签鸢尾花样本是变色鸢尾的概率为 95%。

TensorFlow tf.keras API 是创建模型和层的首选方式。通过该 API,您可以轻松地构建模型并进行实验,而将所有部分连接在一起的复杂工作则由 Keras 处理。

tf.keras.Sequential 模型是层的线性堆叠。该模型的构造函数会采用一系列层实例;在本示例中,采用的是 2 个密集层(分别包含 10 个节点)以及 1 个输出层(包含 3 个代表标签预测的节点)。第一个层的 input_shape 参数对应该数据集中的特征数量:

# 利用sequential方式构建模型model = Sequential([# 隐藏层1,激活函数是relu,输入大小有input_shape指定Dense(10, activation="relu", input_shape=(4,)),  # 隐藏层2,激活函数是reluDense(10, activation="relu"),# 输出层Dense(3,activation="softmax")])

通过model.summary可以查看模型的架构:

激活函数可决定层中每个节点的输出形状。这些非线性关系很重要,如果没有它们,模型将等同于单个层。激活函数有很多,但隐藏层通常使用 ReLU。

隐藏层和神经元的理想数量取决于问题和数据集。与机器学习的多个方面一样,选择最佳的神经网络形状需要一定的知识水平和实验基础。一般来说,增加隐藏层和神经元的数量通常会产生更强大的模型,而这需要更多数据才能有效地进行训练。

模型训练和预测

在训练和评估阶段,我们都需要计算模型的损失。这样可以衡量模型的预测结果与预期标签有多大偏差,也就是说,模型的效果有多差。我们希望尽可能减小或优化这个值,所以我们设置优化策略和损失函数,以及模型精度的计算方法:

# 设置模型的相关参数:优化器,损失函数和评价指标mode
l.compile(optimizer='adam', loss='categorical_crossentropy', metrics=["accuracy"])

接下来与在sklearn中相同,分别调用fit和predict方法进行预测即可。

# 模型训练:epochs,训练样本送入到网络中的次数,batch_size:每次训练的送入到网络中的样本个数
model.fit(train_X, train_y_ohe, epochs=10, batch_size=1, verbose=1);

上述代码完成的是:

  1. 迭代每个epoch。通过一次数据集即为一个epoch。

  2. 在一个epoch中,遍历训练 Dataset 中的每个样本,并获取样本的特征 (x) 和标签 (y)。

  3. 根据样本的特征进行预测,并比较预测结果和标签。衡量预测结果的不准确性,并使用所得的值计算模型的损失和梯度。

  4. 使用 optimizer 更新模型的变量。

  5. 对每个epoch重复执行以上步骤,直到模型训练完成。

训练过程展示如下:

Epoch 1/10
75/75 [==============================] - 0s 616us/step - loss: 0.0585 - accuracy: 0.9733
Epoch 2/10
75/75 [==============================] - 0s 535us/step - loss: 0.0541 - accuracy: 0.9867
Epoch 3/10
75/75 [==============================] - 0s 545us/step - loss: 0.0650 - accuracy: 0.9733
Epoch 4/10
75/75 [==============================] - 0s 542us/step - loss: 0.0865 - accuracy: 0.9733
Epoch 5/10
75/75 [==============================] - 0s 510us/step - loss: 0.0607 - accuracy: 0.9733
Epoch 6/10
75/75 [==============================] - 0s 659us/step - loss: 0.0735 - accuracy: 0.9733
Epoch 7/10
75/75 [==============================] - 0s 497us/step - loss: 0.0691 - accuracy: 0.9600
Epoch 8/10
75/75 [==============================] - 0s 497us/step - loss: 0.0724 - accuracy: 0.9733
Epoch 9/10
75/75 [==============================] - 0s 493us/step - loss: 0.0645 - accuracy: 0.9600
Epoch 10/10
75/75 [==============================] - 0s 482us/step - loss: 0.0660 - accuracy: 0.9867

与sklearn中不同,对训练好的模型进行评估时,与sklearn.score方法对应的是tf.keras.evaluate()方法,返回的是损失函数和在compile模型时要求的指标:

# 计算模型的损失和准确率
loss, accuracy = model.evaluate(test_X, test_y_ohe, verbose=1)
print("Accuracy = {:.2f}".format(accuracy))

分类器的准确率为:

3/3 [==============================] - 0s 591us/step - loss: 0.1031 - accuracy: 0.9733
Accuracy = 0.97

到此我们对tf.kears的使用有了一个基本的认知,在接下来的课程中会给大家解释神经网络以及在计算机视觉中的常用的CNN的使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/6156.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【微信小程序】使用iView组件库中的icons资源

要在微信小程序中使用iView组件库中的icons资源,需要先下载并引入iView组件库,并按照iView的文档进行配置和使用。 以下是一般的使用步骤: 下载iView组件库的源码或使用npm安装iView。 在小程序项目的app.json文件中添加iView组件库的引入配…

mac端好用的多功能音频软件 AVTouchBar for mac 3.0.7

AVTouchBar是来自触摸栏的视听播放器,将跳动笔记的内容带到触摸栏,触摸栏可显示有趣的音频内容,拥有更多乐趣,以一种有趣的方式播放音乐,该软件支持多种音频播放软件,可在Mac上自动更改音乐~ 音频选择-与内…

Flask Bootstrap 导航条

(43条消息) Flask 导航栏,模版渲染多页面_U盘失踪了的博客-CSDN博客 (43条消息) 学习记录:Bootstrap 导航条示例_bootstrap导航栏案例_U盘失踪了的博客-CSDN博客 1,引用Bootstrap css样式,导航栏页面跳转 2,页面两列…

实验五 分支限界法

实验五 分支限界法 01背包问题的分治限界法的实现 剪枝函数 限界函数 1.实验目的 1、理解分支限界法的剪枝搜索策略,掌握分支限界法的算法框架 2、设计并实现问题,掌握分支限界算法。 2.实验环境 java 3.问题描述 给定n种物品和一背包。物品i的重…

Cesium态势标绘专题-位置点(标绘+编辑)

标绘专题介绍:态势标绘专题介绍_总要学点什么的博客-CSDN博客 入口文件:Cesium态势标绘专题-入口_总要学点什么的博客-CSDN博客 辅助文件:Cesium态势标绘专题-辅助文件_总要学点什么的博客-CSDN博客 本专题没有废话,只有代码,代码中涉及到的引入文件方法,从上面三个链…

企业微信在ios机型无法吊起打开个人信息页接口(openUserProfile)

wx.qy.openUserProfile({type: 1,//1表示该userid是企业成员,2表示该userid是外部联系人userid: "wmEQlEGwAAHxbWYDOK5u3Af13xlYAAAA", //可以是企业成员,也可以是外部联系人success: function(res) {// 回调} });遇到的问题:调用打…

动态规划入门第1课

1、从计数到选择 ---- 递推与DP(动态规划) 2、从递归到记忆 ---- 子问题与去重复运算 3、动态规划的要点 第1题 网格路1(grid1) 小x住在左下角(0,0)处,小y在右上角(n,n)处。小x需要通过一段网格路才能到小y家。每次,小x可以选…

macOS mysql 8.0 忘记密码

╰─➤ mysql -V mysql Ver 8.0.33 for macos13.3 on arm64 (Homebrew)mysql.server status mysql.server stopskip-grant-tables 启动mysql ─➤ /opt…

云计算和云架构是什么 有什么用途?

云计算是一种基于互联网的计算方式,它通过网络将计算资源(如计算能力、存储、网络带宽等)以服务的形式提供给用户,并允许用户根据需求进行灵活的资源调配和管理。云计算通常分为三个层次,即基础设施即服务(IaaS)、平台即服务(PaaS)和软件即服…

MongoDB常用语句

MongoDB常用语句 使用创建和删除查询条件查询模糊查询分页排序聚合两表连接 插入 使用 展示数据库 show dbs 或 show databases 查看当前在使用的数据库 db展示数据库下所有表 show collections 或 show tables;终端内容过多,用该指令清屏 cls创建和删除 如果…

[SQL挖掘机] - 比较运算符

介绍: 在 sql 中,比较运算符用于比较表达式或值之间的关系,并生成逻辑真(true)或逻辑假(false)的结果。比较运算符在 sql 查询中扮演着重要的角色,具有以下作用和地位: 条件筛选&a…

【Matlab】基于径向基神经网络的数据回归预测(Excel可直接替换数据)

【Matlab】基于径向基神经网络的数据回归预测(Excel可直接替换数据) 1.模型原理2.数学公式3.文件结构4.Excel数据5.分块代码6.完整代码7.运行结果1.模型原理 基于径向基神经网络(Radial Basis Function Neural Network,RBFNN)的数据回归预测是一种基于神经网络的回归模型…

【Matlab】基于遗传算法优化 BP 神经网络的时间序列预测(Excel可直接替换数据)

【Matlab】基于遗传算法优化 BP 神经网络的时间序列预测(Excel可直接替换数据) 1.模型原理2.文件结构3.Excel数据4.分块代码4.1 arithXover.m4.2 delta.m4.3 ga.m4.4 gabpEval.m4.5 initializega.m4.6 maxGenTerm.m4.7 nonUnifMutation.m4.8 normGeomSel…

使用Redis实现双平面部署的最佳实践

引言: 双平面部署是一种常见的系统架构模式,用于提高系统的可靠性和性能。在这种架构中,拥有相同功能的两个平面同时运行,其中一个平面作为主平面处理请求,而另一个平面则作为备份平面。在传统的双平面部署中&#xf…

操作系统笔记、面试八股(三)—— 系统调用与内存管理

文章目录 3. 系统调用3.1 用户态与内核态3.2 系统调用分类3.3 如何从用户态切换到内核态(系统调用举例) 4. 内存管理4.1 内存管理是做什么的4.1.1 为什么需要虚拟地址空间4.1.2 使用虚拟地址访问内存有什么优势 4.2 常见的内存管理机制4.3 分页管理4.3.1…

Android-WebRTC-实现摄像头显示

EglBase是什么? 它提供了一个接口,用于在Android平台上创建和管理EGL(嵌入式系统图形库)上下文,以便在WebRTC中进行图像和视频的处理和渲染。 Capturer, Source, Track, Sink分别是什么? Capturer&#xff…

2023C语言暑假作业day3

1 选择题 1 已知函数的原型是: int fun(char b[10], int a); ,设定义: char c[10];int d; ,正确的调用语句是 A: fun(c,&d); B: fun(c,d); C: fun(&c,&d); D: fun(&c,d); 答案解析: 正确答案&#x…

kettle开发-Day40-AI分流之case/switch

前言: 前面我们讲到了很多关于数据流的AI方面的介绍,包括自定义组件和算力提升这块的,今天我们来学习一个关于kettle数据分流处理非常重要的组件Switch / Case 。当我们的数据来源于类似日志、csv文件等半结构化数据时,我们需要在…

Vmware+CentOS+KGDB内核双机调试

1.准备两台CentOS系统的vmware虚拟机 其中一台作为调试机,另一台则作为被调试机。如下图,CentOS7.9x64为被调试机,CentOS7.9x64-Debugger为调试机 2.配置串口设备 若虚拟机有串口设备(如打印机),需要先删…

黑马 pink h5+css3+移动端前端

网页概念 网页是网站的一页,网页有很多元素组成,包括视频图片文字视频链接等等,以.htm和.html后缀结尾,俗称html文件 HTML 超文本标记语言,描述网页语言,不是编程语言,是标记语言,有标签组成 超文本指的是不光文本,还有图片视频等等标签 常用浏览器 firefox google safari…