使用TensorFlow训练深度学习模型实战(下)

大家好,本文接TensorFlow训练深度学习模型的上半部分继续进行讲述,下面将介绍有关定义深度学习模型、训练模型和评估模型的内容。

定义深度学习模型

数据准备完成后,下一步是使用TensorFlow搭建神经网络模型,搭建模型有两个选项:

可以使用各种层,包括Dense、Conv2D和LSTM,从头开始搭建模型。这些层定义了模型的架构及数据流经过它的方式,可基于TensorFlow Hub提供的预训练模型搭建模型。这些模型已经在大型数据集上进行了训练,并可以在特定数据集上进行微调,以达到在较短的训练时间内达到较高的准确度。

可以根据TensorFlow Hub中的预训练模型来建立模型。这些模型已经在大型数据集上进行了训练,并且可以在你的特定数据集上进行微调,以达到较少的训练时间,达到较高的准确性。

  • 从头开始定义深度学习模型

TensorFlow中的tf.keras.Sequential函数允许我们逐层定义神经网络模型,我们可以选择各种层,如Dense、Conv2D和LSTM,来搭建定制的模型架构。以下是示例: 

# 定义模型架构
model = tf.keras.Sequential([tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dropout(0.5),tf.keras.layers.Dense(10)
])

在这个示例中,我们定义了一个模型,包含以下六个层(4个隐藏层):

  1. Conv2D层,具有32个过滤器,3x3的内核大小和ReLU激活。此层以形状为(28,28,1)的输入图像作为输入。

  2. MaxPooling2D层,具有默认的2x2池大小。此层对从上一层获得的特征映射进行下采样。

  3. Flatten层,将2D特征映射展平为1D向量。

  4. Dense层,具有128个神经元和ReLU激活。此层对展平的特征映射执行完全连接操作。

  5. Dropout层,在训练期间随机丢弃50%的连接以防止过拟合。

  6. Dense层,具有十个神经元,无激活函数。此层表示模型的输出层,神经元的数量对应于分类任务中的类别数目。

这个模型遵循典型的卷积神经网络架构,包括多个卷积层和池化层,以及一个或多个全连接层。

  • 从预训练模型定义深度学习模型 

利用TensorFlow Hub提供的预训练模型可能是一个不错的选择,因为它们已经在大量的数据集上进行了训练,可以帮助在减少训练时间的同时实现高准确度。在实现任何这些模型之前,让我们先了解一些TensorFlow Hub提供的常见预训练模型。

  1. VGG:The Visual Geometry Group(VGG)模型是由牛津大学开发的。这些模型广泛用于图像分类任务,并在各种基准数据集上取得了最先进的结果。

  2. ResNet:The Residual Network(ResNet)模型是由微软研究院开发的。这些模型具有独特的架构,可以训练非常深的神经网络(高达1000层)。

  3. Inception:Inception模型是由Google开发的。这些模型具有独特的架构,使用不同尺度的多个并行卷积,Inception模型广泛用于目标检测和图像分类任务。

  4. MobileNet:MobileNet模型是由Google开发的。这些模型具有针对移动设备和嵌入式设备进行优化的独特架构,MobileNet模型广泛用于移动设备上的图像分类和目标检测任务。

可以通过向预训练模型添加额外层并在特定数据集上训练模型来应用迁移学习。与从头开始训练模型相比,这种技术可以节省大量时间和计算资源。但是,在选择预训练模型并将数据集转换为该格式以确保兼容之前,了解预训练模型所需的输入格式非常重要。

在这个示例中,MobileNet模型被作为基本模型使用。在使用基本模型之前,检查模型所需的格式非常重要, 在本示例中,格式为(224,224,3)。然而,MNIST数据集是一个灰度图像,大小为(28,28,1),其中单个值表示像素的亮度。图像大小也比所需的格式要小得多。因此,需要重新调整数据集。以下是调整大小的主要思路:

使用image.resize函数将图像调整为所需的大小。该函数使用双线性插值来保留原始图像中的信息,同时将其调整为新大小。因此,此步骤可以将原始形状(28,28,1)调整为(224,224,1)的形状。

使用image.grayscale_to_rgb函数将图像转换为新的RGB图像,通过将单个灰度通道复制到新的RGB图像的所有三个通道中,从而将原始形状(224,224,1)调整为(224,224,3)的形状。

# 调整输入图像的大小为224x224,并将其转换为三通道的RGB图像
X_train = tf.image.grayscale_to_rgb(tf.image.resize(X_train, [224, 224]))
X_test = tf.image.grayscale_to_rgb(tf.image.resize(X_test, [224, 224]))

 现在让我们基于MobileNet模型定义我们的模型:

# 加载MobileNet模型,不包括顶层
base_model = MobileNet(include_top=False, input_shape=(224, 224, 3))# 添加一个全局平均池化层和一个全连接输出层
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.5)(x)
x = Dense(10, activation='softmax')(x)# 将基础模型和新层结合起来,创建完整的模型
model = tf.keras.models.Model(inputs=base_model.input, outputs=x)# 冻结基础模型中的各层
for layer in base_model.layers:layer.trainable = False

在上面的示例中,我们定义了一个模型,如下所示:

  1. 使用MobileNet()定义基本模型

  2. GlobalAveragePooling2D层,使用基本模型的最后一个卷积层的输出,计算每个特征映射的平均值,从而得到一个固定长度的向量,总结了特征映射中的空间信息。

  3. Dropout层,在训练期间随机丢弃50%的连接以防止过拟合。

  4. Dense层,使用十个单元的完全连接层和softmax激活。它接收来自上一层的输出并生成覆盖十个可能类别的概率分布。

编译和训练模型 

在创建模型之后,必须通过指定在训练期间使用的损失函数、优化器和指标来编译它。以下是一个编译模型的示例代码:

# 编译该模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

由于这是一个多分类问题,因此此示例代码使用了稀疏交叉熵损失函数,我们使用的是Adam优化器和准确率指标。

在训练模型之后可以在测试集上评估它,以查看它在未见过的数据上的表现如何,以下是一个评估模型的示例代码:

# 在测试数据上评估该模型
test_loss, test_acc = model.evaluate(X_test, y_test)
print('Test loss: ', test_loss)
print('Test Acc: ', test_acc)

 在此示例代码中,我们在测试集上评估模型,并输出测试损失和准确率。

进行预测

一旦训练和评估了模型,就可以使用它来预测新数据。以下是一个进行预测的示例代码:

# 对新数据进行预测
y_pred = model.predict(X_test)
y_pred_labels = np.argmax(y_pred, axis=1)
print(y_pred_labels)

在此示例代码中,我们在模型上使用predict()方法对整个测试集进行预测。

如果我们想要预测单个图像并返回预测标签与真实标签,那么就需要对Keras模型的predict()方法进行更改。因为Keras模型的predict()方法期望输入数据形式为一批图像,而我们想要传递单个图像给predict()方法,所以需要将其重新调整为批次大小为1。

def predict_and_compare(model, X_test, y_test, index):# 从X_test中获取给定索引的例子example = X_test[index]# 将例子重塑为预期的输入形状example = np.reshape(example, (1, 28, 28, 1))# 预测这个例子的标签y_pred = model.predict(example)# 将预测的概率转换为类别标签y_pred_label = np.argmax(y_pred, axis=1)[0]# 使用索引从y_test获取真实标签y_test_array = y_test.values# Get the label for the first example in the test set y_true = y_test_array[index]# 输出预测的和真实的标签print("Predicted label:", y_pred_label)print("True label:", y_true)# 返回预测的和真实的标签return y_pred_label, y_true# 预测并比较测试集中第一个例子的标签
y_pred_label, y_true = predict_and_compare(model, X_test, y_test, 0)

在上面的示例中,我们通过添加一个额外的维度来代表批次大小,从而将输入图像从(28,28,1)调整为(1,28,28,1)。这样,我们就可以传递单个图像给predict()方法,并获得该图像的预测结果。当我们调用上面的函数时,可以自定义要预测的图像:

 这就是在TensorFlow中实现深度学习的步骤。当然,这只是一个基本示例。你可以搭建具有更多层、不同类型的层和不同超参数的更复杂的模型,以便在数据集上获得更好的性能。

综上,本文我们演示了如何对数据进行预处理、搭建和训练模型、在单独的测试集上评估其性能以及使用简单的卷积神经网络(CNN)进行图像分类的预测,通过学习可以获得如何在TensorFlow中构建深度学习模型以及如何将这些概念应用于真实世界数据集的理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/13315.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

x86架构ubuntu22用docker部署zsnes

0. 环境 x86 ubuntu22 1. 安装docker $ sudo apt remove docker docker-engine docker $ sudo apt update $ sudo apt install -y apt-transport-https ca-certificates curl software-properties-common$ curl -fsSL http://mirrors.aliyun.com/docker-ce/linux/ubuntu/gpg …

JDK17 中的新特性初步了解

1. Switch 语句的增强 jdk12 ,switch语句不用写break了,直接写箭头和对应的值。 jdk 17中, 加了一个逗号,用于匹配多对一。 如果要在每个case里写逻辑,可以写在花括号里。 在返回值的前面加上yield的关键字。 也可以对…

C++ 编程入门(一)—— Hello World

C 是什么环境搭建第一个 C 程序本篇结语 C 是什么 C 是一种面向对象的计算机程序设计语言,由美国 AT&T 贝尔实验室的 Bjarne Stroustrup 在 20 世纪 80 年代初期发明并实现(最初这种语言被称作 “C with Classes” 带类的 C 语言)。它是一…

多线程案例 | 单例模式、阻塞队列、定时器、线程池

多线程案例 1、案例一:线程安全的单例模式 单例模式 单例模式是设计模式的一种 什么是设计模式? 设计模式好比象棋中的 “棋谱”,红方当头炮,黑方马来跳,针对红方的一些走法,黑方应招的时候有一些固定的…

ScheduledThreadPoolExecutor 及 ThreadPoolExecutor的基本使用及说明

关于作者:CSDN内容合伙人、技术专家, 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 ,擅长java后端、移动开发、人工智能等,希望大家多多支持。 目录 一、导读二、概览2.1 为什么不推荐使用Executors去创建线程池 三、…

C++ list底层实现原理

文章目录 一、list底层实现二、类构成三、构造函数四、迭代器五、获取第一个元素六、获取最后一个元素七、插入元素 一句话:list底层实现一个双向循环链表 一、list底层实现 一个双向循环链表 二、类构成 class list : protected_List_base_list_base.lsit_impl…

python进阶书籍的推荐 知乎,python入门后如何进阶

本篇文章给大家谈谈python进阶书籍的推荐 知乎,以及python入门后如何进阶,希望对各位有所帮助,不要忘了收藏本站喔。 1、Python应该怎么学_python应该怎么学 想要学习Python,需要掌握的内容还是比较多的,对于自学的同…

【MySQL】索引特性

​🌠 作者:阿亮joy. 🎆专栏:《零基础入门MySQL》 🎇 座右铭:每个优秀的人都有一段沉默的时光,那段时光是付出了很多努力却得不到结果的日子,我们把它叫做扎根 目录 👉没…

重大更新|Sui主网即将上线流动性质押,助力资产再流通

Sui社区一直提议官方上线流动质押功能,现在通过SIP过程,已经升级该协议以实现这一功能。 Sui使用委托权益证明机制(DPoS)来选择和奖励负责运营网络的验证节点。为了保障网络安全,验证节点通过质押SUI token获得质押奖…

抖音短视频矩阵系统源码:SEO优化开发解析

抖音短视频矩阵系统源码是一个基于抖音短视频平台的应用程序。它允许用户上传和观看短视频,以及与其他用户交互。SEO优化开发解析是指对该系统进行搜索引擎优化的开发解析。 一、 在进行SEO优化开发解析时,可以考虑以下几点: 关键词优化&…

Java Stream流

Java 8 版本新增的Stream,配合同版本出现的Lambda ,给我们操作集合(Collection)提供了极大的便利。Stream流是JDK8新增的成员,允许以声明性方式处理数据集合,可以把Stream流看作是遍历数据集合的一个高级迭…

linux 系统编程

C标准函数与系统函数的区别 什么是系统调用 由操作系统实现并提供给外部应用程序的编程接口。(Application Programming Interface,API)。是应用程序同系统之间数据交互的桥梁。 一个helloworld如何打印到屏幕。 每一个FILE文件流(标准C库函数&#xff…

前端调用合约如何避免出现transaction fail

前言: 作为开发,你一定经历过调用合约的时候发现 gas fee 超出限制,但是不知道报了什么错。这个时候一般都是触发了require错误合约校验。对于用户来说他不理解为什么一笔交易会花费如此大的gas,那我们作为开发如何尽量避免这种情…

Jvm的一些技巧

反编译字节码文件 找到对应的class文件所在的目录,使用javap -v -p 命令 查询运行中某个Java进程的Jvm参数 【案例】查询 MethodAreaDemo 这个类运行过程中,初始的元空间大小 MetaspaceSize jps 查询 Java 进程的进程ID ![在这里插入图片描述](https…

新零售行业如何做会员管理和会员营销

蚓链数字化营销系统全渠道会员管理解决方案,线上线下统一管理,打造私域流量,微信、门店会员全渠道管理,打通私域流量池,实现裂变营销: 开启新零售之路,必然要摒弃原有的管理模式,大…

C# NDArray System.IO.FileLoadException报错原因分析

C# NDArray System.IO.FileLoadException 报错原因分析: 1.NuGet程序包版本有冲突 2.统一项目版本 1.打开解决方案NuGet程序包设置 2.查看是否有版本冲突 3.统一版本冲突

C++终止cin输入while循环时多读取^Z或^D的问题

原代码&#xff1a; istream& operator>>(istream& is, map<string, int>&mm) {string ss"";int ii0;is >> ss>>ii;mm[ss]ii;return is; }int main() {map<string,int>msi;while(cin>>msi);return 0; } 问题&…

【探讨】Java POI 处理 Excel 中的名称管理器

前言 最近遇到了一些导表的问题。原本的导表工具导不了使用名称管理器的Excel。 首先我们有两个Sheet。B1用的是名称管理器中的AAA, 而B2用的对应的公式。 第二个sheet&#xff0c;名为Test2: 这是一段简化的代码&#xff1a; public class Main {public static void mai…

【Python】将M4A\AAC录音文件转换为MP3文件

文章目录 m4aaac 基础环境&#xff1a; sudo apt-get install ffmpegm4a 要将M4A文件转换为MP3文件&#xff0c;你可以使用Python中的第三方库pydub。pydub使得音频处理变得非常简单。在开始之前&#xff0c;请确保你已经安装了pydub库&#xff0c;如果没有&#xff0c;可以通…

7.25 Qt

制作一个登陆界面 login.pro文件 QT core guigreaterThan(QT_MAJOR_VERSION, 4): QT widgetsCONFIG c11# The following define makes your compiler emit warnings if you use # any Qt feature that has been marked deprecated (the exact warnings # depend on …