RESNET的复现pytorch版本

RESNET的复现pytorch版本

使用的数据为Object_102_CaDataset,可以在网上下载,也可以在评论区问。

RESNET模型的亮点

1.提出了残差模块。

2.使用Batch Normalization加速训练

3.残差网络:易于收敛,很好的解决了退化问题,模型可以很深,准确率大大提高了。

残差结构如下所示:

image-20240314180554463

首先,是模型构建部分

class ResBlock(nn.Module):def __init__(self, in_channels, out_channels, stride_1=1, stride_2=1, padding=1, kernel_size=(3, 3), short_cut=None):super(ResBlock, self).__init__()self.short_cut = short_cutself.model = Sequential(# 1.1Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride_1,padding=padding),BatchNorm2d(out_channels),ReLU(),Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride_2,padding=padding),BatchNorm2d(out_channels),ReLU(),)self.short_layer = Sequential(Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1), stride=2, padding=0),BatchNorm2d(out_channels),ReLU(),)self.R = ReLU()def forward(self, x):f1 = xif self.short_cut is not None:f1 = self.short_layer(x)out = self.model(x)out = self.R(f1+out)return out

该部分为模型的残差块,使用了3*3的卷积,然后进行归一化。

对于整个模型的构建部分:

class Resnet_easier(nn.Module):def __init__(self, num_classes):super(Resnet_easier, self).__init__()self.model0 = Sequential(# 0# 输入3通道、输出64通道、卷积核大小、步长、补零、Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=2, padding=3),BatchNorm2d(64),ReLU(),MaxPool2d(kernel_size=(3, 3), stride=2, padding=1),)self.model1 = ResBlock(64, 64)self.model2 = ResBlock(64, 64)self.model3 = ResBlock(64, 128, stride_1=2, stride_2=1, short_cut=True)self.model4 = ResBlock(128, 128)self.model5 = ResBlock(128, 256, stride_1=2, stride_2=1, short_cut=True)self.model6 = ResBlock(256, 256)self.model7 = ResBlock(256, 512, stride_1=2, stride_2=1, short_cut=True)self.model8 = ResBlock(512, 512)# AAP 自适应平均池化self.aap = AdaptiveAvgPool2d((1, 1))# flatten 维度展平self.flatten = Flatten(start_dim=1)# FC 全连接层self.fc = Linear(512, num_classes)def forward(self, x):x = x.to(torch.float32)x = self.model0(x)x = self.model1(x)x = self.model2(x)x = self.model3(x)x = self.model4(x)x = self.model5(x)x = self.model6(x)x = self.model7(x)x = self.model8(x)# 最后3个x = self.aap(x)x = self.flatten(x)x = self.fc(x)return x

接下来是读入数据模块

class Object_102_CaDataset(Dataset):def __init__(self, folder):mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225]self.file_list = []label_names = [item for item in os.listdir(folder) if os.path.isdir(os.path.join(folder, item))]  # 获取文件夹下的所有标签label_to_index = dict((label, index) for index, label in enumerate(label_names))  # 将label转为数字self.all_picture_paths = self.get_all_picture(folder)  # 获取所有图片路径self.all_picture_labels = [label_to_index[os.path.split(os.path.dirname(os.path.abspath(path)))[1]] for path inself.file_list]self.mean = np.array(mean).reshape((1, 1, 3))self.std = np.array(std).reshape((1, 1, 3))def __getitem__(self, index):img = cv2.imread(self.all_picture_paths[index])if img is None:print(os.path.join("image", self.all_picture_paths[index]))img = cv2.resize(img, (224, 224))  #统一图片的尺寸img = img / 255img = (img - self.mean) / self.stdimg = np.transpose(img, [2, 0, 1])label = self.all_picture_labels[index]img = torch.tensor(img)label = torch.tensor(label)return img, labeldef __len__(self):return len(self.all_picture_paths)def get_all_picture(self, folder):for filename in os.listdir(folder):file_path = os.path.join(folder, filename)if os.path.isfile(file_path):self.file_list.append(file_path)elif os.path.isdir(file_path):self.file_list = self.get_all_picture(file_path)return self.file_list

使用上述dataloader可以方便的对数据进行读取操作。

接下来就是整个的训练模块

import torch
from torch import nn
from torch.utils.data import DataLoaderfrom ResNet.ResNet18 import Resnet18
from ResNet.ResNet18_easier import Resnet_easier
from ResNet.dataset import Object_102_CaDataset
from ResNet.res_net import ResNet, ResBlockfrom torchsummary import summary
data_dir = 'E:\PostGraduate\Paper_review\computer_view_model/ResNet/data/101_ObjectCategories'
Object_102 = Object_102_CaDataset(data_dir)
train_size = int(len(Object_102) * 0.7)
# print(train_size)
test_size = len(Object_102) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(Object_102, [train_size, test_size])
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
#显示数据,此处的注释内容可以让我们看到读取的图片
# import random
# from matplotlib import pyplot as plt
# import matplotlib
# matplotlib.use('TkAgg')
# def denorm(img):
#     for i in range(img.shape[0]):
#         img[i] = img[i] * std[i] + mean[i]
#     img = torch.clamp(img, 0., 1.)
#     return img
# plt.figure(figsize=(8, 8))
# for i in range(9):
#     img, label = train_dataset[random.randint(0, len(train_dataset))]
#     img = denorm(img)
#     img = img.permute(1, 2, 0)
#     ax = plt.subplot(3, 3, i + 1)
#     ax.imshow(img.numpy()[:, :, ::-1])
#     ax.set_title("label = %d" % label)
#     ax.set_xticks([])
#     ax.set_yticks([])
# plt.show()train_iter = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_iter = DataLoader(train_dataset, batch_size=64)
model = Resnet_easier(102)
# print(summary(model, (3, 224, 224)))
epoch = 50  # 训练轮次
optmizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# optmizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()#.cuda()  # 定义交叉熵损失函数
log_interval = 10train_losses = []
train_counter = []
test_losses = []
test_counter = [i * len(train_iter.dataset) for i in range(epoch + 1)]# test_loop(model,'cpu',test_iter)
def train_loop(n_epochs, optimizer, model, loss_fn, train_loader):for epoch in range(1, n_epochs + 1):model.train()for i, data in enumerate(train_loader):correct = 0(images, label) = dataimages = images#.cuda()label = label#.cuda()# print(len(images))output = model(images)loss = loss_fn(output, label)optimizer.zero_grad()loss.backward()optimizer.step()pred = output.data.max(1, keepdim=True)[1]pred = torch.tensor(pred, dtype=torch.float32)for index in range(0, len(pred)):if pred[index] == label[index]:correct += 1# correct = torch.eq(pred, label).sum()# print(correct)if i % log_interval == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\t accuracy:[{}/{} ({:.0f}%)] \tLoss: {:.6f}'.format(epoch, i * len(images), len(train_loader.dataset),100. * i / len(train_loader), correct, len(pred), 100. * correct / len(pred), loss.item()))train_losses.append(loss.item())train_counter.append((i * 64) + ((epoch - 1) * len(train_loader.dataset)))torch.save(model.state_dict(), 'model_paramter/test/model.pth')torch.save(optimizer.state_dict(), 'model_paramter/test/optimizer.pth')# test_loop(model, 'cpu', test_iter)# PATH = 'E:\\PostGraduate\\Paper_review\\computer_view_model\\ResNet/model_paramter/model.pth'
# dictionary = torch.load(PATH)
# model.load_state_dict(dictionary)
train_loop(epoch, optmizer, model, loss_fn, train_iter)# PATH = 'E:\\PostGraduate\\Paper_review\\computer_view_model\\ResNet/model_paramter/model.pth'
# dictionary = torch.load(PATH)
# model.load_state_dict(dictionary)
# test_loop(model, 'cpu', test_iter)

若要测试数据的准确度等内容可以参考之前的博文使用LSTm进行情感分析,对test部分进行修改即可。

也可以参考下面的

PATH = 'E:\\PostGraduate\\Paper_review\\computer_view_model\\ResNet/model_paramter/model.pth'
dictionary = torch.load(PATH)
model.load_state_dict(dictionary)
def test_loop(model, device, test_iter):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_iter:data = data.to(device)target = target.to(device)output = model(data)output = output.data.max(1, keepdim=True)[1]output = torch.tensor(output, dtype=torch.float32)# loss_func = loss_fn(output, target)# test_loss += loss_funcpred = outputfor index in range(0, len(pred)):if pred[index] == target[index]:correct += 1test_loss /= len(test_iter.dataset)test_losses.append(test_loss)print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_iter.dataset),100. * correct / len(test_iter.dataset)))test_loop(model,'cpu',test_iter)

loss /= len(test_iter.dataset)
test_losses.append(test_loss)
print(‘\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n’.format(
test_loss, correct, len(test_iter.dataset),
100. * correct / len(test_iter.dataset)))

test_loop(model,‘cpu’,test_iter)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/760645.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C++】狗屁不通文章生成器2.0

【C】狗屁不通文章生成器2.0 1 前言2 改进2.1 字词的前后关系2.2 文章生成系统 3 实现(部分)3.1 class wordpair3.1.1 转化为 json3.1.2 添加后缀词3.1.3 选择后缀词 3.2 class createArticle3.2.1文本分割3.2.2生成文章 4演示4.1 wordpair(3x2), 启动词(春天)4.2 wordpair(2x1…

Vue按需加载:提升应用性能的利器

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

【进程概念】进程控制块task_struct-PCB

文章目录 进程的概念如何描述进程?**为什么要描述一个进程**?进程描述--PCBtask_struct 组织进程查看进程通过系统调用获取进程标示符getpid()以及getppid() 进程的概念 在【百度百科】中,关于进程---- 狭义定义:进程是 正在运行 的程序的实…

若依ruoyi-vue中的文件上传和下载

文章目录 文件上传后端实现前端实现 文件下载后端实现前端实现 在若依(Ruoyi)框架中,结合 Vue 前端框架,文件的上传和下载通常使用以下方法实现: 文件上传 若依现成的功能里面没有文件上传,但是集成了文件…

基于php健身房管理系统flask-django-python

根据现实需要,此系统我们设计出一下功能,主要有以下功能模板。 (1)前台功能:首页、运动器材、教练信息、营业信息、公告栏、在线留言、后台管理、个人中心。 (2)会员功能:首页、个人…

Springboot笔记(web开启)-08

有一些日志什么的后续我会补充 1.使用springboot: 创建SpringBoot应用,选中我们需要的模块;SpringBoot已经默认将这些场景配置好了,只需要在配置文件中指定少量配置就可以运行起来自己编写业务代码; 2.SpringBoot对静态资源的映…

【记录39】html element-ui 加载

环境 html使用element-ui组件、用vue框架搭建 方法一: 方法二(推荐) 将相关资源下载下来,在对应的html文件中相对路径引入。注意:css加载放在js之前

Controller中接收数组参数

1、场景 需要根据用户id集合批量删除用户数据,前端使用post请求,controller中参数接收数组参数并根据用户id删除用户基本信息 2、分析处理: 2.1、前端请求类型contentType:application/json 请求体中为json字符串,后端新建一个U…

javaSwing愤怒的小鸟

一、简介 游戏名称是“愤怒的小鸟”,英文称为“AngryBird”。 “愤怒的小鸟”是著名游戏公司Rovio偶然间开发出来的益智游戏,从2009年12月上市到iOS。,讲述了鸟类和猪因为猪偷鸟蛋反生的一系列故事。游戏的类型版本是横向版本的水平视角&…

怎么在Linux系统下Docker部署Excalidraw白板工具并实现无公网IP远程访问?

文章目录 1. 安装Docker2. 使用Docker拉取Excalidraw镜像3. 创建并启动Excalidraw容器4. 本地连接测试5. 公网远程访问本地Excalidraw5.1 内网穿透工具安装5.2 创建远程连接公网地址5.3 使用固定公网地址远程访问 本文主要介绍如何在Ubuntu系统使用Docker部署开源白板工具Excal…

C++临时变量

本博客将讲述我学习过程中对临时变量的疑惑与理解 为什么写这篇文章? 我在学习C过程中,发现C在发生隐式转换时或者出现未命名的变量如字符串再或者在求值的时候,会出现C临时变量(系统自动生成),而这个临时…

PgSQL根据身份证号查询年龄

1、需求:数据库中有身份证号码,也有年龄字段,但是年龄字段不会自动更新,现在需要返回最新的年龄数据。 2、思路:获取当前年份,截取省份证中的年龄部分数据,再进行相减即可; 3、具体…

MySQL高级学习笔记

1、MySQL架构组成 1.1 高级MySQL介绍 什么是DBA? 数据库管理员,英文是Database Administrator,简称DBA; 百度百科介绍 数据库管理员(简称DBA),是从事管理和维护数据库管理系统(D…

搜索测试题题解(3月21号总结)

目录 1.Shufflem Up 2.Pots 3.Open the Lock 1.Shufflem Up 样例 InputcopyOutputcopy 2 4 AHAH HAHA HHAAAAHH 3 CDE CDE EEDDCC 1 2 2 -1 题意:本题要求将s1和s2合并,再将合并的s分为s1和s2,知道s为我们需要得到的期望s,输…

巨细!Python爬虫详解

爬虫(又称为网页蜘蛛,网络机器人,在 FOAF 社区中间,更经常的称为网页追逐者);它是一种按照一定的规则,自动地抓取网络信息的程序或者脚本。 如果我们把互联网比作一张大的蜘蛛网,那…

北航最新!基于条纹投影的半透明物体3D重建方法

作者:小柠檬 | 来源:3DCV 在公众号「3DCV」后台,回复「原论文」可获取论文pdf 添加微信:dddvision,备注:3D高斯,拉你入群。文末附行业细分群 详细内容请关注3DCV 3D视觉精品课程:…

雷池 WAF 社区版:下一代 Web 应用防火墙的革新

黑客的挑战 智能语义分析算法: 黑客们常利用复杂技术进行攻击,但雷池社区版的智能语义分析算法能深入解析攻击本质,即使是最复杂的攻击手法也难以逃脱。 0day攻击防御: 传统防火墙难以防御未知攻击,但雷池社区版能有效…

01_Kubernetes基础

Kubernetes为什么叫K8S:因为K和S之间有8个字母 为什么需要K8S 对于云计算来说有自己的交互标准 Paas的下一代标准就是容器化,容器的集群化有没有很好的方案?有需求就会有产品,这个产品就叫做资源管理器。 首先是Apache的MESOS&…

LeetCode每日一题【206. 反转链表】

思路:双指针,一前一后,逐个把指向后面的指针指向前面。 /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode() : val(0), next(nullptr) {}* ListNode(int x) : val(x), ne…

刷题训练之滑动窗口

> 作者简介:დ旧言~,目前大二,现在学习Java,c,c,Python等 > 座右铭:松树千年终是朽,槿花一日自为荣。 > 目标:熟练掌握滑动窗口算法,并且能把下面的…