终于把tensorflow输入层和输出层搞懂了!fit函数与输入层,输出层,tf.keras.Model输入和输出的关系

结论

fit函数与输入层,输出层,tf.keras.Model输入和输出的关系

  • fit函数使用dataset格式,输入为字典格式,假设tf.keras.Model中输入和输出为字典格式(2.2或2.3),dataset的key必须和2.2或2.3中字典的key一致,否则报错
  • fit函数使用dataset格式,输入为仍然是字典格式,假设tf.keras.Model中输入和输出为list格式(2.1),dataset的key必须和2.1涉及到的**输入层和输出层(1.1和1.2)**的层名一致,否则报错

1. 定义模型输入和输出

1.1 定义模型输入层

continuous_input = {key: tf.keras.layers.Input(shape=(), name=key) for key in continuous_feature}
discrete_input = {key: tf.keras.layers.Input(shape=(), name=key) for key in discrete_feature}  

1.2 定义模型输出层

output_1 = tf.keras.layers.Dense(1, activation='sigmoid', name='is_click')(x)
output_2 = tf.keras.layers.Dense(1, activation='sigmoid', name='is_play')(x)
output_3 = tf.keras.layers.Dense(1, activation='sigmoid', name='is_pay')(x)

2. tf.keras.Model输入和输出

2.1 输入和输出为list格式

model_func = tf.keras.Model(inputs=list(continuous_input.values()) + list(discrete_input.values()),outputs=[output_1, output_2,  output_3])

2.2 输出为dict格式

model_func = tf.keras.Model(inputs=list(continuous_input.values()) + list(discrete_input.values()),outputs={'is_click': output_1, 'is_play': output_2, 'is_pay': output_3})

2.3 输入为dict格式

# 构造输入字典,也可以其他方式构造,此处只是为了说明,continuous_input为字典
continuous_input.update(discrete_input)
model_func = tf.keras.Model(inputs=continuous_input,outputs=[output_1, output_2,  output_3])

3. fit函数中输入和输出-dataset(tfrecord格式)

3.1 dataset定义

def _parse_function(example_proto, feature_description):# Parse the input `tf.Example` proto using the dictionary above.data = tf.io.parse_single_example(example_proto, feature_description)is_click = data.pop('is_click')is_play = data.pop('is_play')is_pay = data.pop('is_pay')return data, {'is_click': is_click, 'is_play': is_play, 'is_pay': is_pay}

3.2 dataset示例-batch_size=1024

dataset为字典格式,请注意!

({'age': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.3448276 , 0.27586207, 0.31034482, ..., 0.37931034, 0.44827586,0.1724138 ], dtype=float32)>, 'first_class_id': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([ 4,  1,  1, ...,  4, 15,  1], dtype=int64)>, 'gender': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([2, 1, 2, ..., 2, 2, 1], dtype=int64)>, 'married': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([2, 2, 2, ..., 1, 1, 1], dtype=int64)>, 'province': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([22, 23, 25, ..., 14, 18, 20], dtype=int64)>, 'second_class_id': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([23, 53, 24, ..., 29, 11, 47], dtype=int64)>, 'tag_id': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([ 5, 58,  6, ..., 17, 76, 49], dtype=int64)>, 'target_item_id': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([ 5, 58,  6, ..., 17, 76, 49], dtype=int64)>, 'type': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([1, 4, 1, ..., 2, 1, 1], dtype=int64)>, 'user_click_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.5       , 0.5833333 , 1.        , ..., 0.41666666, 0.33333334,0.6666667 ], dtype=float32)>, 'user_click_video_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.5       , 0.5833333 , 0.8333333 , ..., 0.41666666, 0.33333334,0.6666667 ], dtype=float32)>, 'user_exp_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.47826087, 0.5797101 , 0.6231884 , ..., 0.3768116 , 0.6086956 ,0.26086956], dtype=float32)>, 'user_exp_video_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.61290324, 0.61290324, 0.58064514, ..., 0.32258064, 0.61290324,0.4516129 ], dtype=float32)>, 'user_id': <tf.Tensor: shape=(1024,), dtype=string, numpy=
array([b'ffb07508-9acc-4253-a1a0-e3e7fc6fad58',b'1ac654df-2b93-47b8-80ba-ca15642b5919',b'69daac99-ad14-4fc8-80f7-8c80cbc221b3', ...,b'97366ccc-b10d-47cb-9ad6-956c535ccf87',b'a7f43278-9e9b-4500-98b7-6d536d680ac1',b'297e2eec-a491-4aab-bc96-24f806751eb1'], dtype=object)>, 'user_name': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([234, 239, 285, ..., 753, 222, 563], dtype=int64)>, 'user_pay_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'user_pay_video_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'user_play_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'user_play_video_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'video_click_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.4722222, 0.5      , 0.6666667, ..., 0.8333333, 0.6666667,0.4722222], dtype=float32)>, 'video_click_user_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.4       , 0.51428574, 0.6857143 , ..., 0.7714286 , 0.71428573,0.4857143 ], dtype=float32)>, 'video_duration': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.8046324 , 0.19939578, 0.52970797, ..., 0.40584087, 0.5800604 ,0.15005036], dtype=float32)>, 'video_exp_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.7175141 , 0.4858757 , 0.7627119 , ..., 0.69491524, 0.4519774 ,0.6384181 ], dtype=float32)>, 'video_exp_user_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.85057473, 0.4827586 , 0.83908045, ..., 0.7011494 , 0.3908046 ,0.54022986], dtype=float32)>, 'video_id': <tf.Tensor: shape=(1024,), dtype=string, numpy=
array([b'HVgLcemGqaFAYgyEemtb', b'YNfPZPQwWggZRBkSsjMG',b'AvTonQbyvahPSCjsLvqN', ..., b'tvcZUdJBXAzJxsOZkXIc',b'HxnekvQEXBAgptCkNpXQ', b'RGumXWzhSqSoikFAZcWH'], dtype=object)>, 'video_name': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([15, 60,  1, ..., 25, 70, 47], dtype=int64)>, 'video_pay_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'video_pay_user_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'video_play_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.6666667 , 0.5555556 , 0.7777778 , ..., 0.6666667 , 0.44444445,0.33333334], dtype=float32)>, 'video_play_user_cnt': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([0.6666667 , 0.5555556 , 0.7777778 , ..., 0.6666667 , 0.44444445,0.33333334], dtype=float32)>, 'work': <tf.Tensor: shape=(1024,), dtype=int64, numpy=array([1, 2, 2, ..., 2, 3, 3], dtype=int64)>}, {'is_click': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'is_play': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>, 'is_pay': <tf.Tensor: shape=(1024,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>})

4. fit函数与输入层,输出层,tf.keras.Model输入和输出的关系

  • fit函数使用dataset格式,输入为字典格式,假设tf.keras.Model中输入和输出为字典格式(2.2或2.3),dataset的key必须和2.2或2.3中字典的key一致,否则报错
  • fit函数使用dataset格式,输入为仍然是字典格式,假设tf.keras.Model中输入和输出为list格式(2.1),dataset的key必须和2.1涉及到的**输入层和输出层(1.1和1.2)**的层名一致,否则报错

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/25754.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL逻辑备份

目录 一.mysqldump 基本命令&#xff1a; 常用选项&#xff1a; 示例 备份整个数据库 备份多个数据库 备份所有数据库 仅备份数据库结构 仅备份特定表 添加选项以有效处理锁表问题 恢复数据库 从逻辑备份文件恢复 注意事项 二. mysqlpump mysqlpump 特点 基…

BoardLight - hackthebox

简介 靶机名称&#xff1a;BoardLight 难度&#xff1a;简单 靶场地址&#xff1a;https://app.hackthebox.com/machines/603 本地环境 靶机IP &#xff1a;10.10.11.11 ubuntu渗透机IP(ubuntu 22.04)&#xff1a;10.10.16.17 windows渗透机IP&#xff08;windows11&…

在 RISC-V 设计中发现可远程利用的漏洞

在移动CPU领域&#xff0c;主流的CPU构架除了intel 的X86构架&#xff0c;甲骨文的arm 构架&#xff0c;其实还有RISC-V 构架。但是因为国际间竞争关系&#xff0c;现在RISC-V技术路线被国外废止了&#xff0c;目前只有中国在继续开发&#xff08;早期RISC-V是买断过来的&#…

从欧盟弹性法案看软件物料清单(SBOM)

随着网络安全意识的提升和相关法规的推动&#xff0c;SBOM在国际上网络安全实践中的重要性日益凸显。 例如&#xff1a;美国国土安全部&#xff08;DHS&#xff09;的 “软件供应链评估工具包”&#xff08;SCAT&#xff09;就鼓励软件供应商提供SBOM&#xff0c;以帮助买方评…

重新认识Word —— 制作简历

重新认识Word —— 制作简历 PPT的图形减除功能word中的设置调整页边距进行排版表格使用 我们之前把word长排版文本梳理了一遍&#xff0c;其实word还有另外的功能&#xff0c;比如说——制作简历。 在这之前&#xff0c;我们先讲一个小技巧&#xff1a; PPT的图形减除功能 …

【数据结构】栈和队列-->理解和实现(赋源码)

Toc 欢迎光临我的Blog&#xff0c;喜欢就点歌关注吧♥ 前面介绍了顺序表、单链表、双向循环链表&#xff0c;基本上已经结束了链表的讲解&#xff0c;今天谈一下栈、队列。可以简单的说是前面学习的一特殊化实现&#xff0c;但是总体是相似的。 前言 栈是一种特殊的线性表&…

VISIO安装教程+安装包

文章目录 01、什么是VISIO&#xff1f;02、安装教程03、常见安装问题解析 01、什么是VISIO&#xff1f; Visio是由微软开发的流程图和图表绘制软件&#xff0c;它是Microsoft Office套件的一部分。Visio提供了各种模板和工具&#xff0c;使用户能够轻松创建和编辑各种类型的图…

【微信小程序开发(从零到一)】——个人中心页面的实战项目(二)

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;开发者-曼亿点 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 曼亿点 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a…

VS2022+Qt雕刻机单片机马达串口上位机控制系统

程序示例精选 VS2022Qt雕刻机单片机马达串口上位机控制系统 如需安装运行环境或远程调试&#xff0c;见文章底部个人QQ名片&#xff0c;由专业技术人员远程协助&#xff01; 前言 这篇博客针对《VS2022Qt雕刻机单片机马达串口上位机控制系统》编写代码&#xff0c;代码整洁&a…

C#面:阐述对DDD的理解

C#是一种面向对象的编程语言&#xff0c;而领域驱动设计&#xff08;Domain-Driven Design&#xff0c;简称DDD&#xff09;是一种软件开发方法论&#xff0c;它强调将业务领域的知识和逻辑直接融入到软件设计和开发中。 在C#中实施DDD的关键是将业务领域划分为不同的领域模型…

PHP“well”运动健身APP-计算机毕业设计源码87702

【摘要】 随着互联网的趋势的到来&#xff0c;各行各业都在考虑利用互联网将自己的信息推广出去&#xff0c;最好方式就是建立自己的平台信息&#xff0c;并对其进行管理&#xff0c;随着现在智能手机的普及&#xff0c;人们对于智能手机里面的应用“well”运动健身app也在不断…

vue中插槽的本质

定义slotCompoent.vue 组件 <template><slot></slot><slot nameslot1></slot><slot name"slot2" msg"hello"></slot> </template>使用组件&#xff1a; <slotComponent><p>默认的</p>…

gcc:coverage:gcda文件没有生成的另一个例子:dlopen

根据gcc的文档&#xff0c; 如果是使用dlopen的方式来打开一个函数&#xff0c;需要记录coverage的数据&#xff0c;就需要使用下面这个链接。 If an executable loads a dynamic shared object via dlopen functionality, ‘-Wl,–dynamic-list-data’ is needed to dump all …

【系统架构】架构演进

系列文章目录 第一章 系统架构的演进 本篇文章目录 系列文章目录前言一、原始分布式二、单体系统时代三、SOA时代烟囱架构微内核架构事件驱动架构 四、微服务架构五、后微服务时代六、无服务时代总结 前言 最近笔者一直在学习系统架构的相关知识&#xff0c;对系统架构的演进…

6.7 作业

搭建一个货币的场景&#xff0c;创建一个名为 RMB 的类&#xff0c;该类具有整型私有成员变量 yuan&#xff08;元&#xff09;、jiao&#xff08;角&#xff09;和 fen&#xff08;分&#xff09;&#xff0c;并且具有以下功能&#xff1a; (1)重载算术运算符 和 -&#xff…

Day34

Day34 三大范式及反范式设计 第一范式&#xff1a; 存在问题&#xff1a; 1.存在非常严重的数据冗余(重复) 2.数据添加存在问题 3.数据删除存在问题 第二范式&#xff1a; 解决了一部分数据冗余&#xff0c;但仍然存在较严重的数据冗余问题&#xff0c;数据添加和删除问题依然…

Java学习-JDBC(五)

JDBC优化及工具类封装 现有问题 ①创建连接池②获取连接③连接回收 ThreadLocal 为解决多线程程序的并发问题提供了一种新的思路&#xff0c;使用这个工具类可以很简洁地编写出优美的多线程程序&#xff0c;通常用在多线程中管理共享数据库连接、Session等ThreadLocal用于保…

leetcode hot100 补充

除了 hot100 外&#xff0c;还有一些常见的题目&#xff0c;也是值得我们复习的。我们新开一个 补充 栏目&#xff0c;进行梳理。 hot 100补充 回溯法 回溯法 medium 组合之和 II dfs

6.全开源源码---小红书卡片-跳转微信-自动回复跳转卡片-商品卡片-发私信-发群聊-安全导流不封号-企业号白号都可以用

现在用我们的方法&#xff0c;可以规避违规风险&#xff0c;又可以丝滑引流&#xff0c;因为会以笔记的形式发给客户&#xff0c;点击之后直接跳微信&#xff0c;我们来看看演示效果吧&#xff08;没有风险提示&#xff09; 无论是引流还是销售产品都会事半功倍。

关于如何设置 TMOD (定时/计数 高低 共 8 位 寄存器)

TMOD 寄存器简介 TMOD 是 8051 单片机的定时器模式寄存器。它是一个 8 位寄存器&#xff0c;用于配置定时器/计数器的工作模式。TMOD 的每一位有特定的含义。 TMOD 的结构如下&#xff1a; GATE | C/T | M1 | M0 | GATE | C/T | M1 | M07 | 6 | 5 | 4 | 3 | 2 | …