TensorFlow 03(Keras)

一、tf.keras

tf.keras是TensorFlow 2.0的高阶API接口,为TensorFlow的代码提供了新的风格和设计模式,大大提升了TF代码的简洁性和复用性,官方也推荐使用tf.keras来进行模型设计和开发。

1.1 tf.keras中常用模块

如下表所示:

1.2 常用方法

深度学习实现的主要流程:

1.数据获取,

2 数据处理,

3 模型创建与训练,

4 模型测试与评估,

5.模型预测

导入tf.keras

使用 tf.keras,首先需要在代码开始时导入tf.keras

import tensorflow as tf
from tensorflow import keras

数据输入

 对于小的数据集,可以直接使用numpy格式的数据进行训练、评估模型,对于大型数据集或者要进行跨设备训练时使用tf.data.datasets来进行数据输入。

模型构建

  • 简单模型使用Sequential进行构建
  • 复杂模型使用函数式编程来构建
  • 自定义layers

训练与评估

  • 配置训练过程
# 配置优化方法,损失函数和评价指标
model.compile(optimizer=tf.train.AdamOptimizer(0.001),loss='categorical_crossentropy',metrics=['accuracy'])
  • 模型训练
# 指明训练数据集,训练epoch,批次大小和验证集数据
model.fit/fit_generator(dataset, epochs=10, batch_size=3,validation_data=val_dataset,)
  • 模型评估
# 指明评估数据集和批次大小
model.evaluate(x, y, batch_size=32)
  • 模型预测
# 对新的样本进行预测
model.predict(x, batch_size=32)

回调函数(callbacks)

回调函数用在模型训练过程中,来控制模型训练行为,可以自定义回调函数,也可使用tf.keras.callbacks 内置的 callback :

ModelCheckpoint:定期保存 checkpoints。 LearningRateScheduler:动态改变学习速率。 EarlyStopping:当验证集上的性能不再提高时,终止训练。 TensorBoard:使用 TensorBoard 监测模型的状态。

模型的保存和恢复

  • 只保存参数
# 只保存模型的权重
model.save_weights('./my_model')
# 加载模型的权重
model.load_weights('my_model')
  • 保存整个模型
# 保存模型架构与权重在h5文件中
model.save('my_model.h5')
# 加载模型:包括架构和对应的权重
model = keras.models.load_model('my_model.h5')

二、keras构建模型

 

2.1 相关的库的导入

在这里使用sklearn和tf.keras完成鸢尾花分类,导入相关的工具包:

# 绘图
import seaborn as sns
# 数值计算
import numpy as np
# sklearn中的相关工具
# 划分训练集和测试集
from sklearn.model_selection import train_test_split
# 逻辑回归
from sklearn.linear_model import LogisticRegressionCV
# tf.keras中使用的相关工具
# 用于模型搭建
from tensorflow.keras.models import Sequential
# 构建模型的层和激活方法
from tensorflow.keras.layers import Dense, Activation
# 数据处理的辅助工具
from tensorflow.keras import utils

 

2.2 数据展示和划分

利用seborn导入相关的数据,iris数据以dataFrame的方式在seaborn进行存储,我们读取后并进行展示;

将数据划分为训练集和测试集:从iris dataframe中提取原始数据,将花瓣和萼片数据保存在数组X中,标签保存在相应的数组y中;

# 读取数据
iris = sns.load_dataset("iris")
# 展示数据的前五行
iris.head()# 花瓣和花萼的数据
X = iris.values[:, :4]
# 标签值
y = iris.values[:, 4]# 将数据集划分为训练集和测试集
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.5, test_size=0.5, random_state=0)

另外,利用seaborn中pairplot函数探索数据特征间的关系:

# 将数据之间的关系进行可视化
sns.pairplot(iris, hue='species')

2.3 sklearn实现

利用逻辑回归的分类器,并使用交叉验证的方法来选择最优的超参数,实例化LogisticRegressionCV分类器,并使用fit方法进行训练:

# 实例化分类器
lr = LogisticRegressionCV()
# 训练
lr.fit(train_X, train_y)# 计算准确率并进行打印
print("Accuracy = {:.2f}".format(lr.score(test_X, test_y)))Accuracy = 0.93

2.4 tf.keras实现

数据准备

在sklearn中我们只要实例化分类器并利用fit方法进行训练,最后衡量它的性能就可以了,那在tf.keras中与在sklearn非常相似,不同的是:

  • 构建分类器时需要进行模型搭建
  • 数据采集时,sklearn可以接收字符串型的标签,如:“setosa”,但是在tf.keras中需要对标签值进行热编码,如下所示:

有很多方法可以实现热编码,比如pandas中的get_dummies(),在这里我们使用tf.keras中的方法进行热编码:

# 进行热编码
def one_hot_encode_object_array(arr):# 去重获取全部的类别uniques, ids = np.unique(arr, return_inverse=True)# 返回热编码的结果return utils.to_categorical(ids, len(uniques))#对标签值进行热编码 
# 训练集热编码
train_y_ohe = one_hot_encode_object_array(train_y)
# 测试集热编码
test_y_ohe = one_hot_encode_object_array(test_y)

 

模型搭建

在sklearn中,模型都是现成的。tf.Keras是一个神经网络库,我们需要根据数据和标签值构建神经网络。

神经网络可以发现特征与标签之间的复杂关系。

神经网络是一个高度结构化的图,其中包含一个或多个隐藏层。

每个隐藏层都包含一个或多个神经元。

神经网络有多种类别,该程序使用的是密集型神经网络,也称为全连接神经网络:一个层中的神经元将从上一层中的每个神经元获取输入连接。例如,图 2 显示了一个密集型神经网络,其中包含 1 个输入层、2 个隐藏层以及 1 个输出层,如下图所示:

上图 中的模型经过训练并馈送未标记的样本时,它会产生 3 个预测结果:相应鸢尾花属于指定品种的可能性。对于该示例,输出预测结果的总和是 1.0。该预测结果分解如下:山鸢尾为 0.02,变色鸢尾为 0.95,维吉尼亚鸢尾为 0.03。这意味着该模型预测某个无标签鸢尾花样本是变色鸢尾的概率为 95%。

TensorFlow tf.keras API 是创建模型和层的首选方式。通过该 API,您可以轻松地构建模型并进行实验,而将所有部分连接在一起的复杂工作则由 Keras 处理。

tf.keras.Sequential 模型是层的线性堆叠。该模型的构造函数会采用一系列层实例;在本示例中,采用的是 2 个密集层(分别包含 10 个节点)以及 1 个输出层(包含 3 个代表标签预测的节点)。第一个层的 input_shape 参数对应该数据集中的特征数量:

# 利用sequential方式构建模型
model = Sequential([# 隐藏层1,激活函数是relu,输入大小有input_shape指定Dense(10, activation="relu", input_shape=(4,)),  # 隐藏层2,激活函数是reluDense(10, activation="relu"),# 输出层Dense(3,activation="softmax")
])

通过model.summary可以查看模型的架构:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 10)                50        
_________________________________________________________________
dense_1 (Dense)              (None, 10)                110       
_________________________________________________________________
dense_2 (Dense)              (None, 3)                 33        
=================================================================
Total params: 193
Trainable params: 193
Non-trainable params: 0
_________________________________________________________________             

激活函数可决定层中每个节点的输出形状。这些非线性关系很重要,如果没有它们,模型将等同于单个层。激活函数有很多,但隐藏层通常使用 ReLU

隐藏层和神经元的理想数量取决于问题和数据集。与机器学习的多个方面一样,选择最佳的神经网络形状需要一定的知识水平和实验基础。一般来说,增加隐藏层和神经元的数量通常会产生更强大的模型,而这需要更多数据才能有效地进行训练。

模型训练和预测

在训练和评估阶段,我们都需要计算模型的损失。这样可以衡量模型的预测结果与预期标签有多大偏差,也就是说,模型的效果有多差。我们希望尽可能减小或优化这个值,所以我们设置优化策略和损失函数,以及模型精度的计算方法:

# 设置模型的相关参数:优化器,损失函数和评价指标
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=["accuracy"])

接下来与在sklearn中相同,分别调用fit和predict方法进行预测即可。

# 模型训练:epochs,训练样本送入到网络中的次数,batch_size:每次训练的送入到网络中的样本个数
model.fit(train_X, train_y_ohe, epochs=10, batch_size=1, verbose=1);
  1. 迭代每个epoch。通过一次数据集即为一个epoch。
  2. 在一个epoch中,遍历训练 Dataset 中的每个样本,并获取样本的特征 (x) 和标签 (y)。
  3. 根据样本的特征进行预测,并比较预测结果和标签。衡量预测结果的不准确性,并使用所得的值计算模型的损失和梯度。
  4. 使用 optimizer 更新模型的变量。
  5. 对每个epoch重复执行以上步骤,直到模型训练完成。

与sklearn中不同,对训练好的模型进行评估时,与sklearn.score方法对应的是tf.keras.evaluate()方法,返回的是损失函数和在compile模型时要求的指标: 

# 计算模型的损失和准确率
loss, accuracy = model.evaluate(test_X, test_y_ohe, verbose=1)
print("Accuracy = {:.2f}".format(accuracy))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/80302.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TCP/IP网络江湖——数据链路层的协议与传承(数据链路层中篇:数据链路层的协议与帧)

0、引言 网络江湖,宛如千年武林,承载着代代传承的文化和传统。在这个广袤的江湖之中,数据链路层犹如武林门派,代代传承着网络通信的精华。这一部分将带领我们深入探讨数据链路层的协议与传承,揭示其在网络江湖中的精彩故事。 就如同江湖中的武者需要基本功夫作为修行的基础…

mysql如何实现根据经纬度判断某一个坐标是否在一个多边形区域范围内

要根据经纬度判断一个坐标是否在一个多边形区域内,MySQL提供了几种函数来处理地理空间数据,其中包括用于处理多边形区域的函数。 1.创建一个包含多边形区域的表: 首先,创建一个表来存储多边形区域。可以使用ST_GeomFromText函数将…

kuiper 规则sql写法

创建规则对接收到的报文数据进行业务过滤,报文有各种结构的,下面对各种结构报文sql过滤使用进行说明 下面sql规则统一对temperature大于20的数据进行过滤 1:单层结构报文 {"temperature": 35,"humidity": 66 } sql写…

【Leetcode Sheet】Weekly Practice 6

Leetcode Test 2605 从两个数字数组里生成最小数字(9.5) 给你两个只包含 1 到 9 之间数字的数组 nums1 和 nums2 &#xff0c;每个数组中的元素 互不相同 &#xff0c;请你返回 最小 的数字&#xff0c;两个数组都 至少 包含这个数字的某个数位。 提示&#xff1a; 1 < …

MySQL8--my.cnf配置文件的设置

原文网址&#xff1a;MySQL8--my.cfg配置文件的设置_IT利刃出鞘的博客-CSDN博客 简介 本文介绍MySQL8的my.cnf的配置。 典型配置 [client] default-character-setutf8mb4[mysql] default-character-setutf8mb4[mysqld] #服务端口号 默认3306 port3306datadir /work/docker…

一个FlutterCocoapods项目打包问题集锦

一个Flutter&Cocoapods项目打包问题集锦 问题1 github加速问题 cocoapods项目需要访问https://github.com/CocoaPods/Specs.git&#xff0c;众所周知&#xff0c;github经常被墙&#xff0c;导致经常需要借助加速来下载和访问&#xff0c;这里可以使用油猴脚本或者Fastgi…

kibana报错内存溢出问题解决

一、背景&#xff1a; kibana内存溢出&#xff0c;进程被kill掉&#xff0c;导致前端页面访问不到。 报错内容 二、报错原因&#xff1a; 发现是前端 js 报的内存 oom 异常&#xff0c;通过网上资料发现node.js 的默认内存大小为1.4G Node 中通过 JavaScript 使用内存时只能…

Promethues(五)查询-PromQL 语言-保证易懂好学

文章目录 一、介绍二、PromQL 数据类型三、常量1 字符串2 浮点 四、时间序列选择器 Time series Selectors1 即时矢量&#xff08;Instant vector&#xff09;选择器2 范围矢量选择器2.1 时间长度2.2 偏移修饰符2.3 修饰符 3 避免慢速查询和过载 五、子查询六、操作符 Operato…

【C++】深拷贝和浅拷贝 ② ( 默认拷贝构造函数是浅拷贝 | 代码示例 - 浅拷贝造成的问题 )

文章目录 一、默认拷贝构造函数是浅拷贝1、默认拷贝构造函数2、默认拷贝构造函数是浅拷贝机制 二、代码示例 - 浅拷贝造成的问题 一、默认拷贝构造函数是浅拷贝 1、默认拷贝构造函数 如果 C 类中 没有定义拷贝构造函数 , C 编译器会自动为该类提供一个 " 默认的拷贝构造函…

连接MySQL时报错:Public Key Retrieval is not allowed的解决方法

问题描述&#xff1a; DBeaver 连接 mysql 时报错&#xff1a;Public Key Retrieval is not allowed&#xff08;不允许公钥检索&#xff09; 解决方法&#xff1a; 连接设置 -> 驱动属性 -> allowPublicKeyRetrievalfalse&#xff08;这里的运输公钥检索是默认关闭的&a…

如何在RK3568开发板上实现USBNET?——飞凌嵌入式/USB Gadget/USB-NET/网络

本文将借助飞凌嵌入式OK3568-C开发板为大家介绍实现USBNET模式的方法&#xff0c;在这之前需要先知道什么是USB Gadget——USB Gadget是指所开发的电子设备以USB从设备的模式通过USB连接到主机。举个例子&#xff1a;将手机通过USB线插入PC后&#xff0c;手机就是USB Gadget。同…

pt24django教程

静态文件访问 不能与服务器端做动态交互的文件都是静态文件&#xff0c;如: 图片,css,js,音频,视频,html文件(部分) 静态文件配置 在 settings.py 中配置一下两项内容: STATIC_URL 静态文件的访问路径&#xff0c;通过哪个url地址找静态文件 &#xff0c;STATIC_URL ‘/s…

[Linux入门]---搭建Linux环境

1.Linux环境的搭建方式 使用Linux操作系统的三种途径&#xff1a; 1.直接安装在物理机上&#xff0c;但是由于 Linux 桌面使用起来非常不友好&#xff0c;不推荐。 2.使用虚拟机软件&#xff0c;将 Linux 搭建在虚拟机上&#xff0c;但是由于当前的虚拟机软件(如 VMWare 之类的…

多线程案例(3) - 定时器,线程池

一&#xff0c;定时器 定时器作用&#xff1a;约定一个时间间隔&#xff0c;时间到达后&#xff0c;执行某段代码逻辑。实际上就是一个 "闹钟" 。 1.1使用标准库中的定时器 标准库中提供了一个 Timer 类. Timer 类的核心方法为 schedule .Timer 类中含有一个扫描线…

element-ui文件下载(单个)

1. 单个附件下载 <el-buttontype"text"size"small"click.native.prevent"download(scope.row)" >下载</el-button>export default {data() {return {downloadUrl: http://127.0.0.1:8881/XX/XX, // 下载接口}},methods: {download(…

国庆中秋特辑(一)浪漫祝福方式 用循环神经网络(RNN)或长短时记忆网络(LSTM)生成祝福诗词

目录 一、使用深度学习中的循环神经网络&#xff08;RNN&#xff09;或长短时记忆网络&#xff08;LSTM&#xff09;生成诗词二、优化&#xff1a;使用双向 LSTM 或 GRU 单元来更好地捕捉上下文信息三、优化&#xff1a;使用生成对抗网络&#xff08;GAN&#xff09;或其他技术…

YOLOV7改进-添加基于注意力机制的目标检测头(DYHEAD)

DYHEAD 复制到这&#xff1a; 1、models下新建文件 2、yolo.py中import一下 3、改IDetect这里 4、论文中说6的效果最好&#xff0c;但参数量不少&#xff0c;做一下工作量 5、在进入IDetect之前&#xff0c;会对RepConv做卷积 5、因为DYHEAD需要三个层输入的特征层一致&am…

Jetpack Compose 介绍和快速上手

Compose版本发展 19年&#xff0c;Compose在Google IO大会横空出世&#xff0c;大家都议论纷纷&#xff0c;为其前途堪忧。 21年7月Compose 1.0的正式发布&#xff0c;却让大家看到了Google在推广Compose上的坚决&#xff0c;这也注定Compose会成为UI开发的新风向。 23年1月…

can‘t sync to target.

飞翔仿真器 无法 与S12单片机 建立联系&#xff0c;仿真时显示 cant sync to target. 但是使用仿真器与其他板子连接仿真是没问题的。 首先怀疑硬件问题&#xff1a;没发现问题&#xff1b; 然后&#xff0c;勇敢的点击菜单中 设置速度&#xff0c;根据自己晶振和建议设置如…

套接字通信之 端口

端口 端口的本质? 无符号短整型数-> unsigned short端口取值范围? 可以有多少个端口? 2的16次方取值范围:0 - 65535 端口的作用? 定位某台主机上运行的某个进程 在电脑上运行了微信和QQ&#xff0c;小明给我的的微信发消息&#xff0c;电脑上的微信就收到了消息&#…