政安晨:【Keras机器学习实践要点】(三)—— 编写组件与训练数据

政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras实战演绎机器学习

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

介绍

通过 Keras,您可以编写自定义层、模型、度量指标、损失和优化器,并在同一代码库中跨 TensorFlow、JAX 和 PyTorch 运行

老规矩,咱们还是先准备环境(参考我本专栏目录中的文章,其中有搭建环境的部分):

政安晨:【TensorFlow与Keras实战演绎机器学习】专栏 —— 目录icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/136985399

准备好环境后,咱们开始。

编写组件

让我们先来看看自定义层

{keras.ops 命名空间包含}
1. NumPy API 的实现,例如 keras.ops.stack 或 keras.ops.matmul
2. 一组 NumPy 中没有的神经网络特定操作,如 keras.ops.conv 或 keras.ops.binary_crossentropy

让我们创建一个可与所有后端配合使用的自定义密集层

class MyDense(keras.layers.Layer):def __init__(self, units, activation=None, name=None):super().__init__(name=name)self.units = unitsself.activation = keras.activations.get(activation)def build(self, input_shape):input_dim = input_shape[-1]self.w = self.add_weight(shape=(input_dim, self.units),initializer=keras.initializers.GlorotNormal(),name="kernel",trainable=True,)self.b = self.add_weight(shape=(self.units,),initializer=keras.initializers.Zeros(),name="bias",trainable=True,)def call(self, inputs):# Use Keras ops to create backend-agnostic layers/metrics/etc.x = keras.ops.matmul(inputs, self.w) + self.breturn self.activation(x)

接下来,让我们制作一个依赖于keras.random命名空间的自定义Dropout层

class MyDropout(keras.layers.Layer):def __init__(self, rate, name=None):super().__init__(name=name)self.rate = rate# Use seed_generator for managing RNG state.# It is a state element and its seed variable is# tracked as part of `layer.variables`.self.seed_generator = keras.random.SeedGenerator(1337)def call(self, inputs):# Use `keras.random` for random ops.return keras.random.dropout(inputs, self.rate, seed=self.seed_generator)

接下来,让我们编写一个自定义子类模型,使用我们的两个自定义层:

class MyModel(keras.Model):def __init__(self, num_classes):super().__init__()self.conv_base = keras.Sequential([keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),keras.layers.MaxPooling2D(pool_size=(2, 2)),keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),keras.layers.GlobalAveragePooling2D(),])self.dp = MyDropout(0.5)self.dense = MyDense(num_classes, activation="softmax")def call(self, x):x = self.conv_base(x)x = self.dp(x)return self.dense(x)

让我们编译并适配它:

model = MyModel(num_classes=10)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),optimizer=keras.optimizers.Adam(learning_rate=1e-3),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),],
)model.fit(x_train,y_train,batch_size=batch_size,epochs=1,  # For speedvalidation_split=0.15,
)

现在咱们演绎如下

在本地的TensorFlow虚拟环境中,首先导入keras:

from tensorflow import keras

(可以在Jupyter Notebook中运行)

如果在演绎执行中出错,可能是Keras版本问题,使用如下命令升级keras

sudo pip install --upgrade keras

执行结果:

训练模型

在任意数据源上训练模型

所有的Keras模型都可以在各种数据来源上进行训练和评估,与您使用的后端无关。这包括:

NumPy数组 Pandas数据框 TensorFlow tf.data.Dataset对象 PyTorch DataLoader对象 Keras PyDataset对象 无论您使用TensorFlow、JAX还是PyTorch作为Keras后端,它们都可以工作。

让我们尝试使用PyTorch DataLoader:

import torch# Create a TensorDataset
train_torch_dataset = torch.utils.data.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_torch_dataset = torch.utils.data.TensorDataset(torch.from_numpy(x_test), torch.from_numpy(y_test)
)# Create a DataLoader
train_dataloader = torch.utils.data.DataLoader(train_torch_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(val_torch_dataset, batch_size=batch_size, shuffle=False
)model = MyModel(num_classes=10)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),optimizer=keras.optimizers.Adam(learning_rate=1e-3),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),],
)
model.fit(train_dataloader, epochs=1, validation_data=val_dataloader)

现在让我们尝试使用tf.data来完成这个任务

import tensorflow as tftrain_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
)
test_dataset = (tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
)model = MyModel(num_classes=10)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),optimizer=keras.optimizers.Adam(learning_rate=1e-3),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),],
)
model.fit(train_dataset, epochs=1, validation_data=test_dataset)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/770661.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

位段详细解释

结构体位段的使用原则 在C语言中,结构体(Struct)是一种复合数据类型,它允许我们将多个不同类型的数据项组合成一个单一的实体。位段(Bit Field)是结构体中的一个特殊成员,它允许我们只取结构体…

常用中间件redis,kafka及其测试方法

常用消息中间件及其测试方法 一、中间件的使用场景引入中间件的目的一般有两个:1、提升性能常用的中间件:1) 高速缓存:redis2) 全文检索:ES3) 存日志:ELK架构4) 流量削峰:kafka 2、提升可用性产品架构中高可…

Spring Cloud 网关Gateway + 配置中心

网关 网络的接口,负责请求的路由、转发、身份校验 路由:告诉请求去哪找 转发:请求找不到直接带请求过去 路由及转发 判断前端请求的规则就这么配 当前情况下只需要访问8080端口 就可以完成对全部微服务的访问 路由属性 登录校验 没必要在每…

sonar+gitlab提交阻断 增量扫描

通过本文,您将可以学习到 sonarqube、git\gitlab、shell、sonar-scanner、sonarlint 一、前言 sonarqube 是一款开源的静态代码扫描工具。 实际生产应用中,sonarqube 如何落地,需要考虑以下四个维度: 1、规则的来源 现在规则的…

java一和零(力扣Leetcode474)

一和零 力扣原题 给定一个二进制字符串数组 strs 和两个整数 m 和 n,请你找出并返回 strs 的最大子集的长度,该子集中最多有 m 个 0 和 n 个 1。 示例 1: 输入:strs [“10”, “0001”, “111001”, “1”, “0”], m 5, n …

【msyql】mysqldump: 未找到命令...

使用mysqldump备份数据库出现错误提示: mysqldump: 未找到命令... 执行的命令如下: mysqldump -uroot -proot --databases db_user > /home/backups/databackup.sql 解决方法 确认mysql是否安装 查看mysql版本 mysql --version 查找mysql安装路…

php反序列化刷题1

[SWPUCTF 2021 新生赛]ez_unserialize 查看源代码想到robots协议 看这个代码比较简单 直接让adminadmin passwdctf就行了 poc <?php class wllm {public $admin;public $passwd; }$p new wllm(); $p->admin "admin"; $p->passwd "ctf"; ec…

极光笔记|极光消息推送服务的云原生实践

摘要 极光始终秉承“以开发者为中心”的战略导向&#xff0c;极光推送&#xff08;JPush&#xff09;是国内领先的消息推送服务。极光推送&#xff08;JPush&#xff09;本质上是一种软件付费应用程序&#xff0c;结合当前主流云厂商基础施设&#xff0c;逐渐演进成了云上SaaS…

Java后端设置服务器允许跨域

文章目录 1、实现2、一些问题关于各项请求头的作用关于预检请求 3、一些补充4、疑问点 1、实现 以下通过servlet的Filter给所有响应的header加了一些跨域相关的数据&#xff0c;以实现允许跨域。 import org.springframework.context.annotation.Configuration; import org.s…

数据可视化基础与应用-04-seaborn库从入门到精通01-02

总结 本系列是数据可视化基础与应用的第04篇seaborn&#xff0c;是seaborn从入门到精通系列第1-2篇。本系列的目的是可以完整的完成seaborn从入门到精通。主要介绍基于seaborn实现数据可视化。 参考 参考:数据可视化-seaborn seaborn从入门到精通01-seaborn介绍与load_datas…

RabbitMQ3.x之二_RabbitMQ所有端口说明及开启后台管理功能

RabbitMQ3.x之二_RabbitMQ所有端口说明及开启后台管理功能 文章目录 RabbitMQ3.x之二_RabbitMQ所有端口说明及开启后台管理功能1. RabbitMQ端口说明2. 开启Rabbitmq后台管理功能1. 查看rabbitmq已安装的插件2. 开启rabbitmq后台管理平台插件3. 开启插件后&#xff0c;再次查看插…

RSTP环路避免实验(华为)

思科设备参考&#xff1a;RSTP环路避免实验&#xff08;思科&#xff09; 一&#xff0c;技术简介 RSTP (Rapid Spanning Tree Protocol) 是从STP发展而来 • RSTP标准版本为IEEE802.1w • RSTP具备STP的所有功能&#xff0c;可以兼容STP运行 • RSTP和STP有所不同 减少了…

Tomcat下载安装以及配置

一、Tomcat介绍 二、Tomcat下载安装 进入tomcat官网&#xff0c;https://tomcat.apache.org/ 1、选择需要下载的版本&#xff0c;点击下载 下载路径一定要记住&#xff0c;并且路径中尽量不要有中文 8、9、10都可以&#xff0c;本博文以8为例 2、将下载后的安装包解压到指定位…

linux-开发板移植MQTT

将源码复制到共享文件夹 链接&#xff1a;https://pan.baidu.com/s/1kvvO-HhDMDXkQ_wlNtyW_A?pwd332i 提取码&#xff1a;332i 以下步骤教程里都写了&#xff0c;我这里边进行&#xff0c;方便大家对照 pc端 1.进入mqtt_lib, 解压open压缩包 2.按照教程复制这一句并运行&…

服务端应用多级缓存架构方案

服务端应用多级缓存架构方案 场景 20w的QPS的场景下&#xff0c;服务端架构应如何设计&#xff1f; 常规解决方案 可使用分布式缓存来抗&#xff0c;比如redis集群&#xff0c;6主6从&#xff0c;主提供读写&#xff0c;从作为备&#xff0c;不提供读写服务。1台平均抗3w并…

【算法专题--双指针算法】leecode-15.三数之和(medium)、leecode-18. 四数之和(medium)

&#x1f341;你好&#xff0c;我是 RO-BERRY &#x1f4d7; 致力于C、C、数据结构、TCP/IP、数据库等等一系列知识 &#x1f384;感谢你的陪伴与支持 &#xff0c;故事既有了开头&#xff0c;就要画上一个完美的句号&#xff0c;让我们一起加油 目录 前言1. 三数之和2. 解法&…

选择最佳图像处理工具OpenCV、JAI、ImageJ、Thumbnailator和Graphics2D

文章目录 1、前言2、 图像处理工具效果对比2.1 Graphics2D实现2.2 Thumbnailator实现2.3 ImageJ实现2.4 JAI&#xff08;Java Advanced Imaging&#xff09;实现2.5 OpenCV实现 3、图像处理工具结果 1、前言 SVD(stable video diffusion)开放了图生视频的API&#xff0c;但是限…

Ubuntu deb文件 安装 MySQL

更新系统软件依赖 sudo apt update && sudo apt upgrade下载安装包 输入命令查看Ubuntu系统版本 lsb_release -a2. 网站下载对应版本的安装包 下载地址. 解压安装 mkdir /home/mysqlcd /home/mysqltar -xvf mysql-server_8.0.36-1ubuntu20.04_amd64.deb-bundle.tar# …

【考研数学】张宇最新全年学习包

考研数学冲高分必备&#xff0c;张宇老师肯定榜上有名&#xff01; 考研数学&#xff0c;其实就像一场没有硝烟的战斗。基础题是常规武器&#xff0c;中难题就是重型火炮&#xff0c;而压轴题呢&#xff0c;那就是核弹级别的存在&#xff01;考研的战场&#xff0c;关键就在那…

使用ChatGPT的场景之gpt写研究报告,如何ChatGPT写研究报告

推荐写研究报告使用智能站&#xff1a; dayfire.cn/ 1. 确定研究主题 明确主题&#xff1a;在开始之前&#xff0c;你需要有一个清晰的研究主题。这将帮助AI更好地理解你的需求…