注意区分打印网络参数的个数和打印网络参数(权重和偏置)的个数
在TensorFlow 1.0 中,可以通过使用tf.trainable_variables()获取模型的所有可训练参数(即权重和偏置),并使用sess.run()在会话中运行这些变量来打印它们的值。
打印网络参数(权重和偏置)
import tensorflow as tf# 构建模型# 创建会话
with tf.Session() as sess:# 初始化所有变量sess.run(tf.global_variables_initializer())# 获取所有可训练的变量trainable_vars = tf.trainable_variables()# 打印每个变量的名称和值for var in trainable_vars:print(var.name)print(sess.run(var))
打印出网络参数的个数,需要获取每个可训练参数的形状,然后计算它们的乘积来得到每个参数的元素个数。最后,将所有参数的元素个数相加即可得到网络参数的总个数。
import tensorflow as tf
import numpy as np# 构建模型# 创建会话
with tf.Session() as sess:# 初始化所有变量sess.run(tf.global_variables_initializer())# 获取所有可训练的变量trainable_vars = tf.trainable_variables()# 计算所有参数的总个数total_parameters = 0for variable in trainable_vars:# 获取变量的形状,例如[5, 5, 1, 32]表示一个5x5的32通道卷积核shape = variable.get_shape()# 计算当前变量的参数个数,为形状的各维大小的乘积variable_parametes = 1for dim in shape:variable_parametes *= dim.value# 将当前变量的参数个数加到总个数上total_parameters += variable_parametesprint("Total number of parameters in the network: {}".format(total_parameters))