自定义神经网络时的注意事项

问题描述

`

通过继承tf.keras.Model自定义神经网络模型时遇到的一系列问题。

代码如下,

class STFT_ConV2D(tf.keras.Model):def __init__(self, *args, **kwargs):super().__init__(*args, **kwargs)self.pre_layer = tf.keras.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(768, activation='relu')])self.add = tf.keras.layers.Add()self.output_dense = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):x, y = inputsx = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_x)(x)x = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_x)(x)x = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_x)(x)x = self.pre_layer(x)y = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_y)(y)y = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_y)(y)y = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_y)(y)y = self.pre_layer(y)output = self.add([x, y])output = self.output_dense(output)return output

产生的bug为,

  ValueError: Exception encountered when calling layer 'sequential' (type Sequential).Input 0 of layer "dense" is incompatible with the layer: expected axis -1 of input shape to have value 11368, but received input with shape (None, 210680)

x输入和y输入都使用了成员变量pre_layer,共享了pre_layer层,也就共享了pre_layer层的参数矩阵和结构。
由于x先经过三层卷积层后shape由原来的shape=(360, 256, 109, 1)变成了shape=(360, 203, 56, 1)
再经过pre_layer层里的Flatten时,除“ batchsize ”轴(axis=0)外,其余轴被铺平,输出shape=(360,11368)。接着处理y输入,经过三层卷积层后,shape由原来的shape=(360, 511, 513, 1)变成了shape=(360,458, 460, 1),之后执行到y = self.pre_layer(y)时,如果执行成功,则输出shape=(360,21068),此时与x的shape=(360,11368)维度冲突,从而产生异常。

要点归纳:

  1. 通过继承tf.keras.Model写神经网络模型时,每一个神经网络层只能被同一个输入占有。
  2. 所有tf.keras.layers下的层对象不能直接出现在call()方法中,必须以成员变量的形式在构造器中定义,然后在call()方法中通过self.成员变量的方式调用
  3. 卷积层tf.keras.layers.Conv2D()当神经网络第一层时,必须通过参数input_shape指定输入shape,该shape中不能包含“ batchsize ”轴,例如输入x的shape为(a, b, c, d),其中a代表样本数,b代表行像素,c代表列像素,d代表通道数。则应该指定input_shape=x.shape[1:],去除a所在轴,以免卷积层对该轴造成影响。

解决方案:

class STFT_ConV2D(tf.keras.Model):def __init__(self, *args, **kwargs):super().__init__(*args, **kwargs)self.conV2d_x1 = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_x)self.conV2d_x2 = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_x)self.conV2d_x3 = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_x)self.conV2d_y1 = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_y)self.conV2d_y2 = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_y)self.conV2d_y3 = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_y)self.flatten_x = tf.keras.layers.Flatten()self.flatten_y = tf.keras.layers.Flatten()self.dense_x = tf.keras.layers.Dense(768, activation='relu')self.dense_y = tf.keras.layers.Dense(768, activation='relu')self.add = tf.keras.layers.Add()self.output_dense = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# x.shape = (360, 256, 109, 1) , y.shape = (360, 511, 513, 1)# inputs = (x, y)x, y = inputs  x = self.conV2d_x1(x) # (360, 249, 102, 3)x = self.conV2d_x2(x) # (360, 234, 87, 3)x = self.conV2d_x3(x) # (360, 203, 56, 1)x = self.flatten_x(x) # (360, 11368)x = self.dense_x(x)  # (360, 768)y = self.conV2d_y1(y)y = self.conV2d_y2(y)y = self.conV2d_y3(y)y = self.flatten_y(y)y = self.dense_y(y)output = self.add([x, y]) # (360, 768)output = self.output_dense(output)return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/814510.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Security Oauth2 之 理解OAuth 2.0授权流程

1. Oauth 定义 1.1 角色 OAuth定义了四个角色: 资源所有者 一个能够授权访问受保护资源的实体。当资源所有者是一个人时,它被称为最终用户。 资源服务器 托管受保护资源的服务器能够使用访问令牌接受和响应受保护的资源请求。 客户 代表资源所有…

Linux系统编程---文件IO

一、系统调用 由操作系统实现并提供给外部应用程序的编程接口(Application Programming Interface,API),用户程序可以通过这个特殊接口来获得操作系统内核提供的服务 系统调用和库函数的区别: 系统调用(系统函数) 内核提供的函数 库调用 …

一起学习python——基础篇(19)

今天来说一下python的如何修改文件名称、获取文件大小、读取文中指定的某一行内容。 1、修改文件名称: import os testPath"D:/pythonFile/test.txt" testPath2"D:/pythonFile/test2.txt" #修改文件名称使用rename方法, #第一个参…

TQ15EG开发板教程:在MPSOC上运行ADRV9009(vivado2018.3)

首先需要在github上下载两个文件,本例程用到的文件以及最终文件我都会放在网盘里面, 地址放在最后面。在github搜索hdl选择第一个,如下图所示 GitHub网址:https://github.com/analogdevicesinc/hdl/releases 点击releases选择版…

31省结婚、离婚、再婚等面板数据(1990-2022年)

01、数据介绍 一般来说,经济发达地区的结婚和离婚率相对较高,而经济欠发达地区的结婚和离婚率相对较低。此外,不同省份的文化、习俗、社会观念等因素也会对结婚和离婚情况产生影响。 本数据从1990年至2022年,对各地区的结婚、离…

Vue-router的编程式导航有哪些方法

Vue Router 的编程式导航主要提供了以下方法&#xff1a; push&#xff1a;这个方法会向 history 栈添加一个新的记录&#xff0c;所以当用户点击浏览器后退按钮时&#xff0c;则回到之前的 URL。当你点击 <router-link> 时&#xff0c;这个方法会在内部被调用&#xff…

6-169 删除递增链表两个值之间的元素 - 人邮DS(C 第2版)线性表习题2(8)

设计一个算法,删除递增有序链表中值大于mink且小于maxk的所有元素(mink和maxk是给定的两个参数,其值可以和表中的元素相同,也可以不同 )。 函数接口定义: void DeleteMinMax(LinkList const &L, int mink, int maxk); L - 递增链表的指针 mink - 被删除元素值的最…

【C++】每日一题 392 判断子序列

给定字符串 s 和 t &#xff0c;判断 s 是否为 t 的子序列。 字符串的一个子序列是原始字符串删除一些&#xff08;也可以不删除&#xff09;字符而不改变剩余字符相对位置形成的新字符串。&#xff08;例如&#xff0c;"ace"是"abcde"的一个子序列&#…

014_files_in_MATLAB中的文件读写

MATLAB中的文件读写 这一篇就要简单介绍MATLAB中的典型文件类型和文件操作。 基于字节流的接口 Matlab本身提供的文件操作是比较接近底层的&#xff0c;这一套底层的文件原语&#xff0c;主要是fopen、fclose、fread、fwrite、fseek、ftell、feof、ferror等函数。这些函数的…

Github 2024-04-14 php开源项目日报Top9

根据Github Trendings的统计,今日(2024-04-14统计)共有9个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量PHP项目9TypeScript项目1Laravel: 以优雅语法简化Web开发 创建周期:4028 天开发语言:PHP协议类型:MIT LicenseStar数量:30824 个Fork数量:1…

《青少年成长管理2024》046 “成长目标:你是谁呀?”2/3

《青少年成长管理2024》046 “成长目标&#xff1a;你是谁呀&#xff1f;”2/3 七、机器智能&#xff1f;八、天赋没有对错&#xff08;一&#xff09;天赋的客观性&#xff08;二&#xff09;我笨我没错&#xff08;三&#xff09;我聪明只是我幸运&#xff08;四&#xff09;…

在Linux驱动中,如何确保中断上下文的正确保存和恢复?

大家好&#xff0c;今天给大家介绍在Linux驱动中&#xff0c;如何确保中断上下文的正确保存和恢复&#xff1f;&#xff0c;文章末尾附有分享大家一个资料包&#xff0c;差不多150多G。里面学习内容、面经、项目都比较新也比较全&#xff01;可进群免费领取。 在Linux驱动中&am…

windows系统搭建OCR半自动标注工具PaddleOCR

深度学习 文章目录 深度学习前言一、环境搭建准备方式1&#xff1a;安装Anaconda搭建1. Anaconda下载地址: [点击](https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/?CM&OD)2. 创建新的conda环境 方式2. 直接安装python 二、安装CPU版本1. 安装PaddlePaddle2、安装…

中国省级人口结构数据集(2002-2022年)

01、数据简介 人口结构数据不仅反映了地域特色&#xff0c;更是预测地区未来发展趋势的重要工具。在这些数据中&#xff0c;总抚养比、少年儿童抚养比和老年人口抚养比是三大核心指标。 少儿抚养比0-14周岁人口数/15-64周岁人口数 老年抚养比65周岁及以上人口数/15-64周岁人…

leetcode每日一题(1702. 修改后的最大二进制字符串)

题目描述 题解 这道题贪心的思想&#xff0c;我们只需要尽可能多的把0变成1&#xff0c;而且进行操作1才能使其变大。观察发现以下几点&#xff1a; 不论原字符串有多少个0&#xff0c;最后都会剩余1个0。 假设原字符串只有一个0&#xff0c;不能进行任何操作&#xff0c;显然…

一招将vscode自动补全的双引号改为单引号

打开设置&#xff0c;搜索quote&#xff0c;在结果的HTML选项下找到自动完成&#xff0c;设置默认引号类型即可。 vscode版本&#xff1a;1.88.1&#xff0c; vscode更新日期&#xff1a;2024-4-10

利用Java代码调用Lua脚本改造分布式锁

4.8 利用Java代码调用Lua脚本改造分布式锁 lua脚本本身并不需要大家花费太多时间去研究&#xff0c;只需要知道如何调用&#xff0c;大致是什么意思即可&#xff0c;所以在笔记中并不会详细的去解释这些lua表达式的含义。 我们的RedisTemplate中&#xff0c;可以利用execute方…

共轭梯度法 Conjugate Gradient Method (线性及非线性)

1. 线性共轭梯度法 共轭梯度法&#xff08;英语&#xff1a;Conjugate gradient method&#xff09;&#xff0c;是求解系数矩阵为对称正定矩阵的线性方程组的数值解的方法。 共轭梯度法是一个迭代方法&#xff0c;它适用于 1. 求解线性方程组&#xff0c; 2. 共轭梯度法也可…

学习基于pytorch的VGG图像分类 day5

注&#xff1a;本系列博客在于汇总CSDN的精华帖&#xff0c;类似自用笔记&#xff0c;不做学习交流&#xff0c;方便以后的复习回顾&#xff0c;博文中的引用都注明出处&#xff0c;并点赞收藏原博主. 目录 VGG的数据集处理 1.数据的分类 2.对数据集的处理 VGG的分类标签设置 …

2款Notepad++平替工具(实用、跨平台的文本编辑器)

前言 今天大姚给大家分享2款Notepad平替工具&#xff0c;实用、跨平台&#xff08;支持Window/MacOS/Linux操作系统平台&#xff09;的文本编辑器。 NotepadNext NotepadNext是一个跨平台的 Notepad 的重新实现。开发是使用 QtCreator 和 Microsft Visual C (msvc) 编译器完…