TensorFlow神经网络中间层的可视化

TensorFlow神经网络中间层的可视化

  • TensorFlow神经网络中间层的可视化
    • 1. 训练网络并保存为.h5文件
    • 2. 通过.h5文件导入网络
    • 3. 可视化网络中间层结果
      • (1)索引取层可视化
      • (2)通过名字取层可视化

TensorFlow神经网络中间层的可视化

1. 训练网络并保存为.h5文件

我们使用AlexNet为例,任务是手写数字识别,训练集使用手写数字集(mnist)。

网络的结构(我们使用的是28x28的黑白图):
在这里插入图片描述

网络搭建和训练的代码

# 最终版
import os.path
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2# 画出训练过程的准确率和损失值的图像
def plotTrainHistory(history, train, val):plt.plot(history[train])plt.plot(history[val])plt.title('Train History')plt.xlabel('Epoch')plt.ylabel(train)plt.legend(['train', 'validation'], loc = 'upper left')plt.show()(xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()xTrain = tf.expand_dims(xTrain, axis = 3)
xTest = tf.expand_dims(xTest, axis = 3)
print(f"训练集数据大小:{xTrain.shape}")
print(f"训练集标签大小:{yTrain.shape}")
print(f"测试集数据大小:{xTest.shape}")
print(f"测试集标签大小:{yTest.shape}")# 归一化
xTrainNormalize = tf.cast(xTrain, tf.float32) / 255
xTestNormalize = tf.cast(xTest, tf.float32) / 255
# 数据独热编码
yTrainOneHot = tf.keras.utils.to_categorical(yTrain)
yTestOneHot = tf.keras.utils.to_categorical(yTest)model = tf.keras.models.Sequential([tf.keras.layers.Conv2D(filters = 96, kernel_size = 11, strides = 4, input_shape = (28, 28, 1),padding = 'SAME', activation = tf.keras.activations.relu),tf.keras.layers.BatchNormalization(),tf.keras.layers.MaxPool2D(pool_size = 3, strides = 2, padding = 'SAME'),tf.keras.layers.Conv2D(filters = 256, kernel_size = 5, strides = 1,padding = 'SAME', activation = tf.keras.activations.relu),tf.keras.layers.BatchNormalization(),tf.keras.layers.MaxPool2D(pool_size = 3, strides = 2, padding = 'SAME'),tf.keras.layers.Conv2D(filters = 384, kernel_size = 3, strides = 1,padding = 'SAME', activation = tf.keras.activations.relu),tf.keras.layers.Conv2D(filters = 384, kernel_size = 3, strides = 1,padding = 'SAME', activation = tf.keras.activations.relu),tf.keras.layers.Conv2D(filters = 256, kernel_size = 3, strides = 1,padding = 'SAME', activation = tf.keras.activations.relu),tf.keras.layers.MaxPool2D(pool_size = 3, strides = 2, padding = 'SAME'),tf.keras.layers.Flatten(),tf.keras.layers.Dense(4096, activation = tf.keras.activations.relu),tf.keras.layers.Dropout(0.5),tf.keras.layers.Dense(4096, activation = tf.keras.activations.relu),tf.keras.layers.Dropout(0.5),tf.keras.layers.Dense(10, activation = tf.keras.activations.softmax)
])weightsPath = './AlexNetModel/'callback = tf.keras.callbacks.ModelCheckpoint(filepath = weightsPath,save_best_only = True,save_weights_only = True,verbose = 1
)model.compile(loss = tf.keras.losses.CategoricalCrossentropy(),optimizer = tf.keras.optimizers.Adam(),metrics = ['accuracy']
)model.summary()# 不存在就训练模型
print('参数文件不存在,即将训练模型')
modelTrain = model.fit(xTrainNormalize, yTrainOneHot, validation_split = 0.2,epochs = 20, batch_size = 300, verbose = 1, callbacks = [callback]
)
model.save("./model.h5")
plotTrainHistory(modelTrain.history, 'loss', 'val_loss')
plotTrainHistory(modelTrain.history, 'accuracy', 'val_accuracy')

2. 通过.h5文件导入网络

把刚才训练得到的模型重新读取,并且重新加载数据集

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as npdef plot_images(images, number, path, title, gray = False):plt.figure()plt.title(title)order = 1for i in range(0, number):plt.subplot(3, 3, order)if gray:plt.imshow(images[:, :, 0, i], cmap = 'gray')else:plt.imshow(images[:, :, 0, i])plt.colorbar()order = order + 1plt.savefig("./{}.png".format(path))plt.show()if __name__ == '__main__':weightsPath = './AlexNetModel/'(xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()xTrain = tf.expand_dims(xTrain, axis = 3)xTest = tf.expand_dims(xTest, axis = 3)# print(f"训练集数据大小:{xTrain.shape}")# print(f"训练集标签大小:{yTrain.shape}")# print(f"测试集数据大小:{xTest.shape}")# print(f"测试集标签大小:{yTest.shape}")# 归一化xTrainNormalize = tf.cast(xTrain, tf.float32) / 255xTestNormalize = tf.cast(xTest, tf.float32) / 255# 数据独热编码yTrainOneHot = tf.keras.utils.to_categorical(yTrain)yTestOneHot = tf.keras.utils.to_categorical(yTest)model = tf.keras.models.load_model("model.h5")model.summary()print('Layer Number', len(model.layers))sample = xTrainNormalize[0]plt.imshow(sample)plt.colorbar()plt.savefig('./train.png')

3. 可视化网络中间层结果

测试的数字,5

在这里插入图片描述

(1)索引取层可视化

model.layers中存放着这个神经网络的全部层,它是一个list类型变量

AlexNet一共16层(卷积层、全连接层、池化层等都算入),全部存储在里面

model = tf.keras.models.load_model("model.h5")
print('Layer Number', len(model.layers))

可视化的时候我们取出一部分层,然后来预测,预测结果就是取出来这部分层的结果,因此就看到了中间层的结果

output = tf.keras.models.Sequential([tf.keras.layers.InputLayer(input_shape = (28, 28, 1)),model.layers[0],model.layers[1],model.layers[2],
]).predict(sample)
print('output.shape', output.shape)
plot_images(output, 9, '5_Conv2D_BN_MP_1', str(output.shape))

查看三层的结果,即Conv2D+BN+MaxPool,结果是 (28, 4, 1, 96),这里画出前9个

在这里插入图片描述
把这96个叠加在一起的结果

t = output[:, :, 0, 0]
for i in range(1, output.shape[3]):t = t + output[:, :, 0, i]
plt.imshow(t)
plt.colorbar()
plt.savefig('./5_Conv2D_BN_MP_1_All.png')

在这里插入图片描述
下面的代码是画出神经网络三个中间层的结果

在这里插入图片描述

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as npdef plot_images(images, number, path, title, gray = False):plt.figure()plt.title(title)order = 1for i in range(0, number):plt.subplot(3, 3, order)if gray:plt.imshow(images[:, :, 0, i], cmap = 'gray')else:plt.imshow(images[:, :, 0, i])plt.colorbar()order = order + 1plt.savefig("./{}.png".format(path))plt.show()if __name__ == '__main__':weightsPath = './AlexNetModel/'(xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()xTrain = tf.expand_dims(xTrain, axis = 3)xTest = tf.expand_dims(xTest, axis = 3)# print(f"训练集数据大小:{xTrain.shape}")# print(f"训练集标签大小:{yTrain.shape}")# print(f"测试集数据大小:{xTest.shape}")# print(f"测试集标签大小:{yTest.shape}")# 归一化xTrainNormalize = tf.cast(xTrain, tf.float32) / 255xTestNormalize = tf.cast(xTest, tf.float32) / 255# 数据独热编码yTrainOneHot = tf.keras.utils.to_categorical(yTrain)yTestOneHot = tf.keras.utils.to_categorical(yTest)model = tf.keras.models.load_model("model.h5")model.summary()print('Layer Number', len(model.layers))sample = xTrainNormalize[0]plt.imshow(sample)plt.colorbar()plt.savefig('./train.png')output = tf.keras.models.Sequential([tf.keras.layers.InputLayer(input_shape = (28, 28, 1)),model.layers[0],model.layers[1],model.layers[2],]).predict(sample)print('output.shape', output.shape)plot_images(output, 9, '5_Conv2D_BN_MP_1', str(output.shape))t = output[:, :, 0, 0]for i in range(1, output.shape[3]):t = t + output[:, :, 0, i]plt.imshow(t)plt.colorbar()plt.savefig('./5_Conv2D_BN_MP_1_All.png')output = tf.keras.models.Sequential([tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),model.layers[0],model.layers[1],model.layers[2],model.layers[3],model.layers[4],model.layers[5],]).predict(sample)print('output.shape', output.shape)plot_images(output, 9, '5_Conv2D_BN_MP_2', str(output.shape))t = output[:, :, 0, 0]for i in range(1, output.shape[3]):t = t + output[:, :, 0, i]plt.imshow(t)plt.colorbar()plt.savefig('./5_Conv2D_BN_MP_2_All.png')output = tf.keras.models.Sequential([tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),model.layers[0],model.layers[1],model.layers[2],model.layers[3],model.layers[4],model.layers[5],model.layers[6],model.layers[7],model.layers[8],model.layers[9],]).predict(sample)print('output.shape', output.shape)plot_images(output, 9, '5_Conv2D_3_MP', str(output.shape))t = output[:, :, 0, 0]for i in range(1, output.shape[3]):t = t + output[:, :, 0, i]plt.imshow(t)plt.colorbar()plt.savefig('./5_Conv2D_3_MP_All.png')

0和5的结果

在这里插入图片描述

(2)通过名字取层可视化

模型的**summary()**成员函数可以查看网络每一层名字和参数情况

model.summary()

博客中使用的AlexNet每一层名字和参数情况
在这里插入图片描述
通过名字来取中间层,并且预测得到中间层可视化结果

在这里插入图片描述
如果我们要看这个池化层的结果,这样写代码

model = tf.keras.models.load_model("../model.h5")
model.summary()sample = xTrainNormalize[0]
plt.imshow(sample)
plt.colorbar()
plt.savefig('./train.png')output = tf.keras.models.Model(inputs=model.get_layer('conv2d').input,outputs=model.get_layer('max_pooling2d').output
).predict(sample)

通过get_layer获取指定名字的层

inputs指定输入层,outputs指定输出层

每一层的名字可以在创建的时候使用name参数指定

...
tf.keras.layers.Conv2D(filters = 96, kernel_size = 11, strides = 4, input_shape = (28, 28, 1),padding = 'SAME', activation = tf.keras.activations.relu, name = 'Conv2D_1'),
...

每一层的名字红色框框出

在这里插入图片描述

下面是例子:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as npdef plot_images(images, number, path, title, gray = False):plt.figure()plt.title(title)order = 1for i in range(0, number):plt.subplot(3, 3, order)if gray:plt.imshow(images[:, :, 0, i], cmap = 'gray')else:plt.imshow(images[:, :, 0, i])plt.colorbar()order = order + 1plt.savefig("./{}.png".format(path))plt.show()if __name__ == '__main__':(xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()xTrain = tf.expand_dims(xTrain, axis = 3)xTest = tf.expand_dims(xTest, axis = 3)# 归一化xTrainNormalize = tf.cast(xTrain, tf.float32) / 255xTestNormalize = tf.cast(xTest, tf.float32) / 255# 数据独热编码yTrainOneHot = tf.keras.utils.to_categorical(yTrain)yTestOneHot = tf.keras.utils.to_categorical(yTest)model = tf.keras.models.load_model("../model.h5")model.summary()sample = xTrainNormalize[0]plt.imshow(sample)plt.colorbar()plt.savefig('./train.png')output = tf.keras.models.Model(inputs=model.get_layer('conv2d').input,outputs=model.get_layer('max_pooling2d').output).predict(sample)# output = tf.keras.models.Sequential([#     tf.keras.layers.InputLayer(input_shape = (28, 28, 1)),#     model.layers[0],#     model.layers[1],#     model.layers[2],# ]).predict(sample)print('output.shape', output.shape)# plot_images(output, 9, '5_Conv2D_BN_MP_1', str(output.shape))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/227237.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis系列之简单实现watchDog自动续期机制

在分布锁的实际使用中,可能会遇到一种情况,一个业务执行时间很长,已经超过redis加锁的时间,也就是锁已经释放了,但是业务还没执行完成,这时候其它线程还是可以获取锁,那就没保证线程安全 项目环…

完美解决labelimg xml转可视化中文乱码问题,不用matplotlib

背景简述 我们有一批标注项目要转可视化,因为之前没有做过,然后网上随意找了一段代码测试完美(并没有)搞定,开始疯狂标注,当真正要转的时候傻眼了,因为测试的时候用的是英文标签,实…

基于linux系统的Tomcat+Mysql+Jdk环境搭建(三)centos7 安装Tomcat

Tomcat下载官网: Apache Tomcat - Which Version Do I Want? JDK下载官网: Java Downloads | Oracle 中国 如果不知道Tomcat的哪个版本应该对应哪个版本的JDK可以打开官网,点击Whitch Version 下滑,有低版本的,如…

Flutter实现Android拖动到垃圾桶删除效果-Draggable和DragTarget的详细讲解

文章目录 Draggable介绍构造函数参数说明使用示例 DragTarget 介绍构造函数参数说明使用示例 DragTarget 如何接收Draggable传递过来的数据? Draggable介绍 Draggable是Flutter框架中的一个小部件,用于支持用户通过手势拖动一个子部件。它是基于手势的一…

知识付费小程序开发:技术实践示例

随着知识付费小程序的兴起,让我们一起来看一个简单的示例,使用Node.js和Express框架搭建一个基础的知识付费小程序后端。 首先,确保你已经安装了Node.js和npm。接下来,创建一个新的项目文件夹,然后通过以下步骤创建你…

适用于 Windows 和 Mac 的 10 款最佳照片恢复软件(免费和付费)

丢失照片很容易。这里点击错误,那里贴错标签的 SD 卡,然后噗的一声,一切都消失了。值得庆幸的是,在技术领域,你可以纠正一些错误。其中包括删除数据或格式化错误的存储设备。 那么,让我们看看可用于从 SD …

[c++]—vector类___提升版(带你了解vector底层的运用)

我写我 不论主谓宾 可以反复错 🌈vector的介绍 1.vector是表示可变大小数组的序列容器2.就像数组一样,vector也采用的连续存储空间来存储元素,也就是意味着可以采用下标对vector的元素进行访问,和数组一样高效。但是又不像数组&…

工业性能CCD图像处理+

目录 硬件部分 ​编辑 软件部分 CCD新相机的调试处理(更换相机处理,都要点执行检测来查看图像变化) 问题:新相机拍摄出现黑屏,图像拍摄不清晰,(可以点击图像,向下转动鼠标的滚轮&#xff08…

基于linux系统的Tomcat+Mysql+Jdk环境搭建(一)vmare centos7 设置静态ip和连接MobaXterm

特别注意,Windows10以上版本操作系统需要下载安装VMware Workstation Pro16及以上版本,安装方式此处略。 (可忽略 my*** 记录设置的vamare centos7 账号root/aaa 密码:Aa123456 ) 1、命令行和图形界面切换 如果使用的是VMware虚拟机&…

金智融门户(统一身份认证)同步数据至钉钉通讯录

前言:因全面使用金智融门户和数据资产平台,二十几个信息系统已实现统一身份认证和数据同步,目前单位使用的钉钉尚未同步组织机构和用户信息,职工入职、离职、调岗时都需要手工在钉钉后台操作,一是操作繁琐,二是钉钉通讯录更新不及时或经常遗漏,带来管理问题。通过金智融…

CAD 审图意见的导出

看图的时候喜欢在图上直接标注意见,但是如果还要再把意见一行一行的导出到word里面就很麻烦,在网上看了一个审图软件,报价要980,而且那个审图意见做的太复杂了。 我的需求就是把图上标的单行文字和多行文字直接导出来就行&#x…

debug点f8step over会进入class文件

File->settings->Bulid.Executiong.Deployment->Debugger->Stepping 取消如图对钩即可

二十七、读写文件

二十七、读写文件 27.1 文件类QFile #include <QCoreApplication>#include<QFile> #include<QDebug>int main(int argc, char *argv[]) {QCoreApplication a(argc, argv);QFile file("D:/main.txt");if(!file.open(QIODevice::WriteOnly | QIODe…

three.js模拟太阳系

地球的旋转轨迹目前设置为了圆形&#xff0c;效果&#xff1a; <template><div><el-container><el-main><div class"box-card-left"><div id"threejs" style"border: 1px solid red"></div><div c…

idea第一次提交到git(码云)

1.先创建一个仓库 2.将idea和仓库地址绑定 2.将idea和仓库地址绑定

CentOS 7系统加固详细方案SSH FTP MYSQL加固

一、删除后门账户 修改强口令 1、修改改密码长度需要编译login.defs文件 vi /etc/login.defs PASS_MIN_LEN 82、注释掉不需要的用户和用户组 或者 检查是否存在除root之外UID为0的用户 使用如下代码&#xff0c;对passwd文件进行检索&#xff1a; awk -F : ($30){print $1) …

『K8S 入门』二:深入 Pod

『K8S 入门』二&#xff1a;深入 Pod 一、基础命令 获取所有 Pod kubectl get pods2. 获取 deploy kubectl get deploy3. 删除 deploy&#xff0c;这时候相应的 pod 就没了 kubectl delete deploy nginx4. 虽然删掉了 Pod&#xff0c;但是这是时候还有 service&#xff0c…

轻松搭建FPGA开发环境:第三课——Vivado 库编译与设置说明

工欲善其事必先利其器&#xff0c;很多人想从事FPGA的开发&#xff0c;但是不知道如何下手。既要装这个软件&#xff0c;又要装那个软件&#xff0c;还要编译仿真库&#xff0c;网上的教程一大堆&#xff0c;不知道到底应该听谁的。所以很多人还没开始就被繁琐的开发环境搭建吓…

电子学会C/C++编程等级考试2021年06月(六级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:逆波兰表达式 逆波兰表达式是一种把运算符前置的算术表达式,例如普通的表达式2 + 3的逆波兰表示法为+ 2 3。逆波兰表达式的优点是运算符之间不必有优先级关系,也不必用括号改变运算次序,例如(2 + 3) * 4的逆波兰表示法为* +…

智能优化算法应用:基于动物迁徙算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于动物迁徙算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于动物迁徙算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.动物迁徙算法4.实验参数设定5.算法结果6.…