分享人工智能技术干货,专注深度学习与计算机视觉领域!
相较于Tensorflow,Pytorch一开始就是以动态图构建神经网络图的,其获取模型参数的方法也比较容易,既可以根据其内建接口自己写代码获取模型参数情况,也可以借助第三方库来获取模型参数情况,下面,就让我们一起来了解Pytorch获取模型参数情况的这两种方法!
Pytorch依据其内建接口自己写代码获取模型参数情况,我们主要是借助该框架提供的模型parameters()接口并获取对应参数的size来实现的,对于该参数是否属于可训练参数,那么我们可以依据Pytorch提供的requires_grad标志位来进行判断,具体方法如下代码所示:
# 定义总参数量、可训练参数量及非可训练参数量变量
Total_params = 0
Trainable_params = 0
NonTrainable_params = 0# 遍历model.parameters()返回的全局参数列表
for param in model.parameters():mulValue = np.prod(param.size()) # 使用numpy prod接口计算参数数组所有元素之积Total_params += mulValue # 总参数量if param.requires_grad:Trainable_params += mulValue # 可训练参数量else:NonTrainable_params += mulValue # 非可训练参数量print(f'Total params: {Total_params}')
print(f'Trainable params: {Trainable_params}')
print(f'Non-trainable params: {NonTrainable_params}')
如无特殊设定,一般来说,因为我们是直接获取的model网络参数,因此很少有不可训练参数,往往NonTrainable_params输出结果是0。
这里的第三方库是指torchsummary,欲要使用该库,首先我们得安装它,命令如下:
pip install torchsummary
然后,引入该库的summary方法:
from torchsummary import summary
最后,直接调用一条命令即可获取到Pytorch模型参数情况:
summary(model, input_size=(ch, h, w), batch_size=-1)
这里的ch是指输入张量的channel数量,h表示输入张量的高,w表示输入张量的宽。
我们从以上代码可以看到,借助第三方库torchsummary来获取Pytorch的模型参数情况非常之简便,只需确认好输入图像shape即可,那么,torchsummary的输出是如何的呢?
上图是应用torchsummary获得输出结果的一个示例,这与Tensorflow V2.x及其之后的版本的模型summary()输出是差不多的,输出信息里也是有各个类别的参数量情况、每层网络的参数量、额外的层名称及其输出shape大小,此外,torchsummary库还为我们计算了输入大小、模型参数大小及前向/反向传播参数量大小,可谓信息非常细致,这极大地方便了我们查看Pytorch模型的构造情况。
除了上述两种获取Pytorch模型参数情况的方法,我们当然也可以直接使用model.state_dict()接口获取Pytorch网络参数,但是此种方法打印出来的信息结构非常混乱,也没有为我们进行有效的信息整理,因此很不建议该方法。