TensorFlow神经网络中间层的可视化

TensorFlow神经网络中间层的可视化

  • TensorFlow神经网络中间层的可视化
    • 1. 训练网络并保存为.h5文件
    • 2. 通过.h5文件导入网络
    • 3. 可视化网络中间层结果
      • (1)索引取层可视化
      • (2)通过名字取层可视化

TensorFlow神经网络中间层的可视化

1. 训练网络并保存为.h5文件

我们使用AlexNet为例,任务是手写数字识别,训练集使用手写数字集(mnist)。

网络的结构(我们使用的是28x28的黑白图):
在这里插入图片描述

网络搭建和训练的代码

# 最终版
import os.path
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2# 画出训练过程的准确率和损失值的图像
def plotTrainHistory(history, train, val):plt.plot(history[train])plt.plot(history[val])plt.title('Train History')plt.xlabel('Epoch')plt.ylabel(train)plt.legend(['train', 'validation'], loc = 'upper left')plt.show()(xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()xTrain = tf.expand_dims(xTrain, axis = 3)
xTest = tf.expand_dims(xTest, axis = 3)
print(f"训练集数据大小:{xTrain.shape}")
print(f"训练集标签大小:{yTrain.shape}")
print(f"测试集数据大小:{xTest.shape}")
print(f"测试集标签大小:{yTest.shape}")# 归一化
xTrainNormalize = tf.cast(xTrain, tf.float32) / 255
xTestNormalize = tf.cast(xTest, tf.float32) / 255
# 数据独热编码
yTrainOneHot = tf.keras.utils.to_categorical(yTrain)
yTestOneHot = tf.keras.utils.to_categorical(yTest)model = tf.keras.models.Sequential([tf.keras.layers.Conv2D(filters = 96, kernel_size = 11, strides = 4, input_shape = (28, 28, 1),padding = 'SAME', activation = tf.keras.activations.relu),tf.keras.layers.BatchNormalization(),tf.keras.layers.MaxPool2D(pool_size = 3, strides = 2, padding = 'SAME'),tf.keras.layers.Conv2D(filters = 256, kernel_size = 5, strides = 1,padding = 'SAME', activation = tf.keras.activations.relu),tf.keras.layers.BatchNormalization(),tf.keras.layers.MaxPool2D(pool_size = 3, strides = 2, padding = 'SAME'),tf.keras.layers.Conv2D(filters = 384, kernel_size = 3, strides = 1,padding = 'SAME', activation = tf.keras.activations.relu),tf.keras.layers.Conv2D(filters = 384, kernel_size = 3, strides = 1,padding = 'SAME', activation = tf.keras.activations.relu),tf.keras.layers.Conv2D(filters = 256, kernel_size = 3, strides = 1,padding = 'SAME', activation = tf.keras.activations.relu),tf.keras.layers.MaxPool2D(pool_size = 3, strides = 2, padding = 'SAME'),tf.keras.layers.Flatten(),tf.keras.layers.Dense(4096, activation = tf.keras.activations.relu),tf.keras.layers.Dropout(0.5),tf.keras.layers.Dense(4096, activation = tf.keras.activations.relu),tf.keras.layers.Dropout(0.5),tf.keras.layers.Dense(10, activation = tf.keras.activations.softmax)
])weightsPath = './AlexNetModel/'callback = tf.keras.callbacks.ModelCheckpoint(filepath = weightsPath,save_best_only = True,save_weights_only = True,verbose = 1
)model.compile(loss = tf.keras.losses.CategoricalCrossentropy(),optimizer = tf.keras.optimizers.Adam(),metrics = ['accuracy']
)model.summary()# 不存在就训练模型
print('参数文件不存在,即将训练模型')
modelTrain = model.fit(xTrainNormalize, yTrainOneHot, validation_split = 0.2,epochs = 20, batch_size = 300, verbose = 1, callbacks = [callback]
)
model.save("./model.h5")
plotTrainHistory(modelTrain.history, 'loss', 'val_loss')
plotTrainHistory(modelTrain.history, 'accuracy', 'val_accuracy')

2. 通过.h5文件导入网络

把刚才训练得到的模型重新读取,并且重新加载数据集

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as npdef plot_images(images, number, path, title, gray = False):plt.figure()plt.title(title)order = 1for i in range(0, number):plt.subplot(3, 3, order)if gray:plt.imshow(images[:, :, 0, i], cmap = 'gray')else:plt.imshow(images[:, :, 0, i])plt.colorbar()order = order + 1plt.savefig("./{}.png".format(path))plt.show()if __name__ == '__main__':weightsPath = './AlexNetModel/'(xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()xTrain = tf.expand_dims(xTrain, axis = 3)xTest = tf.expand_dims(xTest, axis = 3)# print(f"训练集数据大小:{xTrain.shape}")# print(f"训练集标签大小:{yTrain.shape}")# print(f"测试集数据大小:{xTest.shape}")# print(f"测试集标签大小:{yTest.shape}")# 归一化xTrainNormalize = tf.cast(xTrain, tf.float32) / 255xTestNormalize = tf.cast(xTest, tf.float32) / 255# 数据独热编码yTrainOneHot = tf.keras.utils.to_categorical(yTrain)yTestOneHot = tf.keras.utils.to_categorical(yTest)model = tf.keras.models.load_model("model.h5")model.summary()print('Layer Number', len(model.layers))sample = xTrainNormalize[0]plt.imshow(sample)plt.colorbar()plt.savefig('./train.png')

3. 可视化网络中间层结果

测试的数字,5

在这里插入图片描述

(1)索引取层可视化

model.layers中存放着这个神经网络的全部层,它是一个list类型变量

AlexNet一共16层(卷积层、全连接层、池化层等都算入),全部存储在里面

model = tf.keras.models.load_model("model.h5")
print('Layer Number', len(model.layers))

可视化的时候我们取出一部分层,然后来预测,预测结果就是取出来这部分层的结果,因此就看到了中间层的结果

output = tf.keras.models.Sequential([tf.keras.layers.InputLayer(input_shape = (28, 28, 1)),model.layers[0],model.layers[1],model.layers[2],
]).predict(sample)
print('output.shape', output.shape)
plot_images(output, 9, '5_Conv2D_BN_MP_1', str(output.shape))

查看三层的结果,即Conv2D+BN+MaxPool,结果是 (28, 4, 1, 96),这里画出前9个

在这里插入图片描述
把这96个叠加在一起的结果

t = output[:, :, 0, 0]
for i in range(1, output.shape[3]):t = t + output[:, :, 0, i]
plt.imshow(t)
plt.colorbar()
plt.savefig('./5_Conv2D_BN_MP_1_All.png')

在这里插入图片描述
下面的代码是画出神经网络三个中间层的结果

在这里插入图片描述

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as npdef plot_images(images, number, path, title, gray = False):plt.figure()plt.title(title)order = 1for i in range(0, number):plt.subplot(3, 3, order)if gray:plt.imshow(images[:, :, 0, i], cmap = 'gray')else:plt.imshow(images[:, :, 0, i])plt.colorbar()order = order + 1plt.savefig("./{}.png".format(path))plt.show()if __name__ == '__main__':weightsPath = './AlexNetModel/'(xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()xTrain = tf.expand_dims(xTrain, axis = 3)xTest = tf.expand_dims(xTest, axis = 3)# print(f"训练集数据大小:{xTrain.shape}")# print(f"训练集标签大小:{yTrain.shape}")# print(f"测试集数据大小:{xTest.shape}")# print(f"测试集标签大小:{yTest.shape}")# 归一化xTrainNormalize = tf.cast(xTrain, tf.float32) / 255xTestNormalize = tf.cast(xTest, tf.float32) / 255# 数据独热编码yTrainOneHot = tf.keras.utils.to_categorical(yTrain)yTestOneHot = tf.keras.utils.to_categorical(yTest)model = tf.keras.models.load_model("model.h5")model.summary()print('Layer Number', len(model.layers))sample = xTrainNormalize[0]plt.imshow(sample)plt.colorbar()plt.savefig('./train.png')output = tf.keras.models.Sequential([tf.keras.layers.InputLayer(input_shape = (28, 28, 1)),model.layers[0],model.layers[1],model.layers[2],]).predict(sample)print('output.shape', output.shape)plot_images(output, 9, '5_Conv2D_BN_MP_1', str(output.shape))t = output[:, :, 0, 0]for i in range(1, output.shape[3]):t = t + output[:, :, 0, i]plt.imshow(t)plt.colorbar()plt.savefig('./5_Conv2D_BN_MP_1_All.png')output = tf.keras.models.Sequential([tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),model.layers[0],model.layers[1],model.layers[2],model.layers[3],model.layers[4],model.layers[5],]).predict(sample)print('output.shape', output.shape)plot_images(output, 9, '5_Conv2D_BN_MP_2', str(output.shape))t = output[:, :, 0, 0]for i in range(1, output.shape[3]):t = t + output[:, :, 0, i]plt.imshow(t)plt.colorbar()plt.savefig('./5_Conv2D_BN_MP_2_All.png')output = tf.keras.models.Sequential([tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),model.layers[0],model.layers[1],model.layers[2],model.layers[3],model.layers[4],model.layers[5],model.layers[6],model.layers[7],model.layers[8],model.layers[9],]).predict(sample)print('output.shape', output.shape)plot_images(output, 9, '5_Conv2D_3_MP', str(output.shape))t = output[:, :, 0, 0]for i in range(1, output.shape[3]):t = t + output[:, :, 0, i]plt.imshow(t)plt.colorbar()plt.savefig('./5_Conv2D_3_MP_All.png')

0和5的结果

在这里插入图片描述

(2)通过名字取层可视化

模型的**summary()**成员函数可以查看网络每一层名字和参数情况

model.summary()

博客中使用的AlexNet每一层名字和参数情况
在这里插入图片描述
通过名字来取中间层,并且预测得到中间层可视化结果

在这里插入图片描述
如果我们要看这个池化层的结果,这样写代码

model = tf.keras.models.load_model("../model.h5")
model.summary()sample = xTrainNormalize[0]
plt.imshow(sample)
plt.colorbar()
plt.savefig('./train.png')output = tf.keras.models.Model(inputs=model.get_layer('conv2d').input,outputs=model.get_layer('max_pooling2d').output
).predict(sample)

通过get_layer获取指定名字的层

inputs指定输入层,outputs指定输出层

每一层的名字可以在创建的时候使用name参数指定

...
tf.keras.layers.Conv2D(filters = 96, kernel_size = 11, strides = 4, input_shape = (28, 28, 1),padding = 'SAME', activation = tf.keras.activations.relu, name = 'Conv2D_1'),
...

每一层的名字红色框框出

在这里插入图片描述

下面是例子:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as npdef plot_images(images, number, path, title, gray = False):plt.figure()plt.title(title)order = 1for i in range(0, number):plt.subplot(3, 3, order)if gray:plt.imshow(images[:, :, 0, i], cmap = 'gray')else:plt.imshow(images[:, :, 0, i])plt.colorbar()order = order + 1plt.savefig("./{}.png".format(path))plt.show()if __name__ == '__main__':(xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()xTrain = tf.expand_dims(xTrain, axis = 3)xTest = tf.expand_dims(xTest, axis = 3)# 归一化xTrainNormalize = tf.cast(xTrain, tf.float32) / 255xTestNormalize = tf.cast(xTest, tf.float32) / 255# 数据独热编码yTrainOneHot = tf.keras.utils.to_categorical(yTrain)yTestOneHot = tf.keras.utils.to_categorical(yTest)model = tf.keras.models.load_model("../model.h5")model.summary()sample = xTrainNormalize[0]plt.imshow(sample)plt.colorbar()plt.savefig('./train.png')output = tf.keras.models.Model(inputs=model.get_layer('conv2d').input,outputs=model.get_layer('max_pooling2d').output).predict(sample)# output = tf.keras.models.Sequential([#     tf.keras.layers.InputLayer(input_shape = (28, 28, 1)),#     model.layers[0],#     model.layers[1],#     model.layers[2],# ]).predict(sample)print('output.shape', output.shape)# plot_images(output, 9, '5_Conv2D_BN_MP_1', str(output.shape))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/227237.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何在PHP中实现文件下载?

在PHP中实现文件下载通常涉及以下几个步骤: 确保文件存在并可供下载: 首先,您需要确保要下载的文件存在,并且具有合适的文件权限。您可以使用file_exists函数来检查文件是否存在。 设置HTTP响应头: 在向客户端发送文件…

Redis系列之简单实现watchDog自动续期机制

在分布锁的实际使用中,可能会遇到一种情况,一个业务执行时间很长,已经超过redis加锁的时间,也就是锁已经释放了,但是业务还没执行完成,这时候其它线程还是可以获取锁,那就没保证线程安全 项目环…

完美解决labelimg xml转可视化中文乱码问题,不用matplotlib

背景简述 我们有一批标注项目要转可视化,因为之前没有做过,然后网上随意找了一段代码测试完美(并没有)搞定,开始疯狂标注,当真正要转的时候傻眼了,因为测试的时候用的是英文标签,实…

基于linux系统的Tomcat+Mysql+Jdk环境搭建(三)centos7 安装Tomcat

Tomcat下载官网: Apache Tomcat - Which Version Do I Want? JDK下载官网: Java Downloads | Oracle 中国 如果不知道Tomcat的哪个版本应该对应哪个版本的JDK可以打开官网,点击Whitch Version 下滑,有低版本的,如…

Flutter实现Android拖动到垃圾桶删除效果-Draggable和DragTarget的详细讲解

文章目录 Draggable介绍构造函数参数说明使用示例 DragTarget 介绍构造函数参数说明使用示例 DragTarget 如何接收Draggable传递过来的数据? Draggable介绍 Draggable是Flutter框架中的一个小部件,用于支持用户通过手势拖动一个子部件。它是基于手势的一…

知识付费小程序开发:技术实践示例

随着知识付费小程序的兴起,让我们一起来看一个简单的示例,使用Node.js和Express框架搭建一个基础的知识付费小程序后端。 首先,确保你已经安装了Node.js和npm。接下来,创建一个新的项目文件夹,然后通过以下步骤创建你…

大文件分块上传的代码,C++转delphi,由delphi实现。

在 Delphi 中,我们通常使用 IdHTTP 或 TNetHTTPClient 等组件来处理 HTTP 请求 原文章链接: 掌握分片上传:优化大文件传输的关键策略 【C】【WinHttp】【curl】-CSDN博客 改造思路: 文件分块处理:使用 TFileStream 来…

适用于 Windows 和 Mac 的 10 款最佳照片恢复软件(免费和付费)

丢失照片很容易。这里点击错误,那里贴错标签的 SD 卡,然后噗的一声,一切都消失了。值得庆幸的是,在技术领域,你可以纠正一些错误。其中包括删除数据或格式化错误的存储设备。 那么,让我们看看可用于从 SD …

从memcpy()函数中学习函数的设计思想

memcpy()函数:可以理解为内存拷贝。 他的函数定义如下的 my_memcpy()函数相同。 下面这个函数是我的模拟实现,现在让我们一起来学习一下这个函数的设计思想: void * my_memcpy(void * des, const void* src, size_t size) {void * p des;…

Before an Exam

题目名字 Before an Exam 题目链接 题意 给定天数和目标,然后接下来输入每天的最少时间和最多时间,先判断在每天的范围内能否完成目标,如果不能输出no,如果能就输出每天在给定范围内完成的时间 思路 先用maxsum来将每天的最大时…

[c++]—vector类___提升版(带你了解vector底层的运用)

我写我 不论主谓宾 可以反复错 🌈vector的介绍 1.vector是表示可变大小数组的序列容器2.就像数组一样,vector也采用的连续存储空间来存储元素,也就是意味着可以采用下标对vector的元素进行访问,和数组一样高效。但是又不像数组&…

工业性能CCD图像处理+

目录 硬件部分 ​编辑 软件部分 CCD新相机的调试处理(更换相机处理,都要点执行检测来查看图像变化) 问题:新相机拍摄出现黑屏,图像拍摄不清晰,(可以点击图像,向下转动鼠标的滚轮&#xff08…

华为OD机试 - 符号运算(Java JS Python C)

题目描述 给定一个表达式,求其分数计算结果。 表达式的限制如下: 所有的输入数字皆为正整数(包括0)仅支持四则运算(+-*/)和括号结果为整数或分数,分数必须化为最简格式(比如6,3/4,7/8,90/7)除数可能为0,如果遇到这种情况,直接输出"ERROR"输入和最终计…

基于linux系统的Tomcat+Mysql+Jdk环境搭建(一)vmare centos7 设置静态ip和连接MobaXterm

特别注意,Windows10以上版本操作系统需要下载安装VMware Workstation Pro16及以上版本,安装方式此处略。 (可忽略 my*** 记录设置的vamare centos7 账号root/aaa 密码:Aa123456 ) 1、命令行和图形界面切换 如果使用的是VMware虚拟机&…

金智融门户(统一身份认证)同步数据至钉钉通讯录

前言:因全面使用金智融门户和数据资产平台,二十几个信息系统已实现统一身份认证和数据同步,目前单位使用的钉钉尚未同步组织机构和用户信息,职工入职、离职、调岗时都需要手工在钉钉后台操作,一是操作繁琐,二是钉钉通讯录更新不及时或经常遗漏,带来管理问题。通过金智融…

编写第一个Selenium脚本

目录 安装Selenium类库 请求对应的程序语言 Pip 下载 在项目中使用 编写第一个Selenium脚本 八个基本组成部分 1. 使用驱动实例开启会话 本地驱动 驱动自动管理 驱动选项 浏览器选项 Capabilities Timeouts 2. 在浏览器上执行操作 3. 请求 浏览器信息 4. 建立等…

CAD 审图意见的导出

看图的时候喜欢在图上直接标注意见,但是如果还要再把意见一行一行的导出到word里面就很麻烦,在网上看了一个审图软件,报价要980,而且那个审图意见做的太复杂了。 我的需求就是把图上标的单行文字和多行文字直接导出来就行&#x…

debug点f8step over会进入class文件

File->settings->Bulid.Executiong.Deployment->Debugger->Stepping 取消如图对钩即可

二十七、读写文件

二十七、读写文件 27.1 文件类QFile #include <QCoreApplication>#include<QFile> #include<QDebug>int main(int argc, char *argv[]) {QCoreApplication a(argc, argv);QFile file("D:/main.txt");if(!file.open(QIODevice::WriteOnly | QIODe…

three.js模拟太阳系

地球的旋转轨迹目前设置为了圆形&#xff0c;效果&#xff1a; <template><div><el-container><el-main><div class"box-card-left"><div id"threejs" style"border: 1px solid red"></div><div c…