如何查看resnet网络的中间输出特征和卷积核的参数

查看中间层的特征,需要在定义Model时,在forward时,将中间要显示的层输出。

    def forward(self, x):outputs = []x = self.conv1(x)outputs.append(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)outputs.append(x)# x = self.layer2(x)# x = self.layer3(x)# x = self.layer4(x)## if self.include_top:#     x = self.avgpool(x)#     x = torch.flatten(x, 1)#     x = self.fc(x)return outputs

这里在convert1后和layer1后添加到一个列表中,然后输出。后面的就不进行卷积操作了。

 

import torch.nn as nn
import torchclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channel, out_channel, stride=1, downsample=None):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4def __init__(self, in_channel, out_channel, stride=1, downsample=None):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(out_channel)# -----------------------------------------self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(out_channel)# -----------------------------------------self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, blocks_num, num_classes=1000, include_top=True):super(ResNet, self).__init__()self.include_top = include_topself.in_channel = 64self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, blocks_num[0])self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)if self.include_top:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')def _make_layer(self, block, channel, block_num, stride=1):downsample = Noneif stride != 1 or self.in_channel != channel * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))self.in_channel = channel * block.expansionfor _ in range(1, block_num):layers.append(block(self.in_channel, channel))return nn.Sequential(*layers)def forward(self, x):outputs = []x = self.conv1(x)outputs.append(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)outputs.append(x)# x = self.layer2(x)# x = self.layer3(x)# x = self.layer4(x)## if self.include_top:#     x = self.avgpool(x)#     x = torch.flatten(x, 1)#     x = self.fc(x)return outputsdef resnet34(num_classes=1000, include_top=True):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet101(num_classes=1000, include_top=True):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

 然后就可以在预测的时候输出中间层。

import torch
from alexnet_model import AlexNet
from resnet_model import resnet34
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torchvision import transformsdata_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# data_transform = transforms.Compose(
#     [transforms.Resize(256),
#      transforms.CenterCrop(224),
#      transforms.ToTensor(),
#      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# create model
model = AlexNet(num_classes=5)
# model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"  # "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
print(model)# load image
img = Image.open("../tulip.jpg")
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)# forward
out_put = model(img)
for feature_map in out_put:# [N, C, H, W] -> [C, H, W]im = np.squeeze(feature_map.detach().numpy())# [C, H, W] -> [H, W, C]im = np.transpose(im, [1, 2, 0])# show top 12 feature mapsplt.figure()for i in range(12):ax = plt.subplot(3, 4, i+1)# [H, W, C]plt.imshow(im[:, :, i], cmap='gray')plt.show()

输出卷积核的参数

import torch
from alexnet_model import AlexNet
from resnet_model import resnet34
import matplotlib.pyplot as plt
import numpy as np# create model
model = AlexNet(num_classes=5)
# model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"  # "resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
print(model)weights_keys = model.state_dict().keys()
for key in weights_keys:# remove num_batches_tracked para(in bn)if "num_batches_tracked" in key:continue# [kernel_number, kernel_channel, kernel_height, kernel_width]weight_t = model.state_dict()[key].numpy()# read a kernel information# k = weight_t[0, :, :, :]# calculate mean, std, min, maxweight_mean = weight_t.mean()weight_std = weight_t.std(ddof=1)weight_min = weight_t.min()weight_max = weight_t.max()print("mean is {}, std is {}, min is {}, max is {}".format(weight_mean,weight_std,weight_max,weight_min))# plot hist imageplt.close()weight_vec = np.reshape(weight_t, [-1])plt.hist(weight_vec, bins=50)plt.title(key)plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/734181.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于MapReduce的汽车数据清洗与统计案例

数据简介 ecar168.csv(汽车销售数据表): 字段数据类型字段说明rankingString排名manufacturerString厂商vehicle_typeString车型monthly_sales_volumeString月销量accumulated_this_yearString本年累计last_monthString上月chain_ratioStri…

BC134 蛇形矩阵

一:题目 二:思路分析 2.1 蛇形矩阵含义 首先,这道题我们要根据这个示例,找到蛇形矩阵是怎么移动的 这是,我们可以标记一下每次移动到方向 我们根据上图可以看出,蛇形矩阵一共有两种方向,橙色…

【Pytorch】新手入门:基于sklearn实现鸢尾花数据集的加载

【Pytorch】新手入门:基于sklearn实现鸢尾花数据集的加载 🌈 个人主页:高斯小哥 🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程👈 希望…

数据挖掘助力零售业务增长:从数据分析到策略制定的全过程

在数字化时代,数据挖掘已经成为企业获取竞争优势的关键手段之一。通过深入挖掘和分析海量数据,企业能够洞察消费者行为、市场趋势和潜在商机,从而制定更为精准和有效的业务策略。本文将通过一个具体的零售业务案例,分析数据挖掘的应用过程,展示如何从数据中发现价值,并将…

Hadoop运行搭建——系统配置和Hadoop的安装

Hadoop运行搭建 前言: 本文原文发在我自己的博客小站,直接复制文本过来,所以图片不显示(我还是太懒啦!)想看带图版的请移步我的博客小站~ Linux镜像:CentOS7 系统安装:CentOS安装参考教程 系统网卡设置…

C语言——函数指针——函数指针变量详解

函数指针变量 函数指针变量的作用 函数指针变量是指向函数的指针,它可以用来存储函数的地址,并且可以通过该指针调用相应的函数。函数指针变量的作用主要有以下几个方面: 回调函数:函数指针变量可以作为参数传递给其他函数&…

三菱PLC基础指令

LD指令(a触点的逻辑运算开 指令表程序 0000 LD X000 0001 OUT Y000 LDI指令(b触点的逻辑运算开 指令表程序 0000 LDI X000 0001 OUT Y000 3.数据寄存器(D)的位指定*1(仅对应FX3u,FX3uc可编程控制器) 指令表程序 0000 LD D0.3 0001 OUT Y000 4.定时器 0000 LDI X00…

Objects类 --java学习笔记

Objects类 Objects是一个工具类,提供了很多操作对象的静态方法给我们使用 Objects类常用的三个方法 Objects.equals 比直接equals更安全,因为Objects.equals里面做了非空校验 Objects.isNull(A) 等价于 A null Objects.non…

Redisson学习

简介 Redisson 是一个在 Redis 的基础上实现的 Java 驻留内存数据网格(In-Memory Data Grid)。它提供了许多分布式 Java 对象和服务,包括分布式锁、分布式集合、分布式执行服务、分布式调度任务等。 使用 依赖 相关依赖,注意版…

【兔子机器人】修改GO电机id(软件方法、硬件方法)

一、硬件方法 利用上位机直接修改GO电机的id号: 打开调试助手,点击“调试”,查询电机,修改id号,即可。 但先将四个GO电机连接线拔掉,不然会将连接的电机一并修改。 利用24V电源给GO电机供电。 二、软件方…

回溯算法12-全排列II(Java/排列数去重操作)

12.全排列II 题目描述 给定一个可包含重复数字的序列 nums ,按任意顺序 返回所有不重复的全排列。 示例 1: 输入:nums [1,1,2] 输出: [[1,1,2],[1,2,1],[2,1,1]]示例 2: 输入:nums [1,2,3] 输出&…

Spring Boot整合zxing实现二维码登录

zxing是google的一个二维码生成库,使用时需配置依赖: implementation("com.google.zxing:core:3.4.1") implementation("com.google.zxing:javase:3.4.1") zxing的基本使用 我们可以通过MultiFormatWriter().encode()方法获取一个…

AI预测福彩3D第3弹【2024年3月6日预测】

书接上回,经过连续两期使用人工神经网络对福彩3D进行预测,经过不断的调参优化,并及时总结规律,感觉还是有一定的信心提高七码的命中概率。 今天,咱们继续来验证,直接上今天的统计结果,首先&…

手写分布式配置中心(五)整合springboot(不自动刷新的)

springboot中使用配置方式有四种,分别是environment、BeanDefinition、Value、ConfigurationProperties。具体的原理可以看我之前的一篇文章https://blog.csdn.net/cjc000/article/details/132800290。代码在https://gitee.com/summer-cat001/config-center 原理 …

斐波那契算法

斐波那契数列 斐波那契数列(Fibonacci sequence)是一个非常著名的数学序列,它是由意大利数学家莱昂纳多斐波那契(Leonardo Fibonacci)在1202年的著作《计算之书》(Liber Abaci)中首次引入的。这…

SAP 凭证分割功能 - 概述系统配置 - Chapter01/02

目录 1. 概述 2. 系统配置 2.1.为文档拆分给总分类账科目分类 2.2.对凭证拆分的凭证类型进行分类 2.3.定义总账会计核算的凭证分解特征 2.4.定义零余额结算科目 配置步骤: 路径:IMG->财务会计->总帐会计->业务交易->凭证拆分->定义零余额结算科 2.5.定义…

移动端uni-app小程序搜索高亮前端处理,同时可设置相关样式,兼顾性能

在uni-app中我们会遇到搜索高亮显示的需求 如下图: 起初用的是富文本实现 使用replaceAll方法取代搜索字段为一个 标签并设置相应的样式,但是小程序的并没有把 标签渲染出来,所以放弃了,下面原代码: /* 搜索字体变色…

C++进阶:详细讲解继承

现在也是结束了初阶部分的内容,今天开始就进入进阶部分了。一刻也没有为初阶的结束而哀悼,立刻赶来“战场”的是进阶部分里的继承 文章目录 1.继承的概念和定义1.1继承的概念1.2继承的定义1.2.1继承的格式1.2.2再讲访问限定符(详讲protected)1.2.3**继承…

【被c++11弃用的智能指针auto_ptr】

C11标准之前的auto_ptr这个智能指针不被广泛使用的原因就是:在某些应用场景下,拷贝构造函数的意义不明确,同理赋值语句也是这个道理,意义同样不明确,因为C11标准之前并不存在移动赋值和移动构造的概念,还有…

NFT交易市场开发(一)

实现的基本功能 (一) 发行一个符合ERC20标准的测试Token,要求如下: 总量::1亿精度:18名称:Fake USDT in CBI简称:cUSDT (二) 发行一个符合ERC72…