基于ResNet框架的CNN

数据准备

DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

一、训练集和验证集的划分

#spile_data.pyimport os
from shutil import copy
import randomdef mkfile(file):if not os.path.exists(file):os.makedirs(file)file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla] #['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
mkfile('flower_data/train') #生成train文件夹
for cla in flower_class:mkfile('flower_data/train/'+cla) #在train文件夹下生成各个类别的文件夹mkfile('flower_data/val')
for cla in flower_class:mkfile('flower_data/val/'+cla)split_rate = 0.1
for cla in flower_class:cla_path = file + '/' + cla + '/'images = os.listdir(cla_path)num = len(images)eval_index = random.sample(images, k=int(num*split_rate)) #在images中随机获取0.1的图片for index, image in enumerate(images):if image in eval_index:image_path = cla_path + imagenew_path = 'flower_data/val/' + clacopy(image_path, new_path)else:image_path = cla_path + imagenew_path = 'flower_data/train/' + clacopy(image_path, new_path)print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing barprint()print("processing done!")

二、ResNet网络

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import mmd
from attention import ChannelAttention
from attention import SpatialAttention
import torch__all__ = ['ResNet', 'resnet50']model_urls = {'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}def conv3x3(in_planes, out_planes, stride=1,groups=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, bias=False)def conv1x1(in_planes, out_planes, stride=1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)'''
Resnet中BasicBlock结构,ResNet中使用的网络结构。分2步走:3x3; 3x3
'''
class BasicBlock(nn.Module):expansion = 1  # 最后输出的通道数扩充的比例# BN层来加快网络模型的收敛速度/训练速度/解决梯度消失或者梯度爆炸的问题# 对batch中所有的同一个channel的数据元素进行标准化处理。一个batch共享一套参数# 即如果有C个通道,对N*H*W进行标准化处理,一共进行C次。def __init__(self, inplanes, planes, stride=1, downsample=None):super(BasicBlock, self).__init__()self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = nn.BatchNorm2d(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)# downsample是用一个1x1的卷积核处理,改变通道数,如果H/W尺度也不一样就设计strideif self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return out'''
Resnet中Bottleneck结构,ResNet中使用的网络结构。目的是为了降低参数量,分三步走:
1数据降维(1x1),2常规卷积核的卷积(3x3),3数据升维(1x1)
结果图片长宽不变,通道数扩大4倍
'''
class Bottleneck(nn.Module):expansion = 4  # 最后输出的通道数扩充的比例def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, norm_layer=None):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)# BN层来加快网络模型的收敛速度/训练速度/解决梯度消失或者梯度爆炸的问题# 对batch中所有的同一个channel的数据元素进行标准化处理。一个batch共享一套参数# 即如果有C个通道,对N*H*W进行标准化处理,一共进行C次。self.bn1 = nn.BatchNorm2d(planes)  # 卷积层后加BatchNorm2d,按照channel进行归一化self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,padding=1, bias=False)self.bn2 = nn.BatchNorm2d(planes)self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(planes * 4)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.dropout=nn.Dropout()self.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)# downsample是用一个1x1的卷积核处理,改变通道数,如果H/W尺度也不一样就设计strideif self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return out'''
ResNet由以下组成:
1.conv1、norm1、relu(当指定了deep_stem,这三个将被stem代替)
2.maxpool
3.layer1~layer4(定义为ResLayer类,分别由多个BasicBlock或Bottleneck组成)
'''
class ResNet(nn.Module):# 参数block指明残差块是两层或三层,参数layers指明每个卷积层需要的残差块数量,num_classes指明分类数,zero_init_residual是否初始化为0def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,groups=1, width_per_group=64, norm_layer=None):super(ResNet, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dself.inplanes = 64self.groups = groupsself.base_width = width_per_groupself.conv1 = nn.Conv2d(12, self.inplanes, kernel_size=7, stride=2, padding=3,bias=False)self.bn1 = norm_layer(self.inplanes)self.relu = nn.ReLU(inplace=True)# 网络的第一层加入注意力机制# self.ca = ChannelAttention(self.inplanes)# self.sa = SpatialAttention()self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)# 网络的卷积层的最后一层加入注意力机制# self.ca1 = ChannelAttention(self.inplanes)# self.sa1 = SpatialAttention()self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # 自适应平均池化,指定输出(H,W)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck):nn.init.constant_(m.bn3.weight, 0)elif isinstance(m, BasicBlock):nn.init.constant_(m.bn2.weight, 0)# 构造ResLayer类,layer1~layer4# block:BasicBlock/Bottleneck; planes:块的输入通道数; blocks:块的数目def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):if norm_layer is None:norm_layer = nn.BatchNorm2ddownsample = None  # downSample的作用于在残差连接时 将输入的图像的通道数变成和卷积操作的尺寸一致if stride != 1 or self.inplanes != planes * block.expansion:# 通道数恢复成一致/长宽恢复一致downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),norm_layer(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample, self.groups,self.base_width, norm_layer))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes, planes, groups=self.groups,base_width=self.base_width, norm_layer=norm_layer))return nn.Sequential(*layers)'''卷积/池化后的tensor维度为(batchsize,channels,x,y),其中x.size(0)指batchsize的值,通过x.view(x.size(0), -1)将tensor的结构转换为了(batchsize, channels*x*y)即将(channels,x,y)拉直,然后就可以和fc层连接因为最后avgpool(1,1)指定输出长*宽为1*1,通道为512*4,所以channels*x*y=2048'''def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)# x = self.ca(x) * x# x = self.sa(x) * xx = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)# x = self.ca1(x) * x# x = self.sa1(x) * xx = self.avgpool(x)x = x.view(x.size(0), -1)#x=self.fc(x)return xclass DANNet(nn.Module):def __init__(self, num_classes=2):super(DANNet, self).__init__()self.sharedNet = resnet50(False)self.cls_fc = nn.Linear(2048, num_classes)  # channels*x*y=2048*1*1,见上面的备注def forward(self, source, target):loss = 0source = self.sharedNet(source)if self.training == True:target = self.sharedNet(target)# loss += mmd.mmd_rbf_accelerate(source, target)loss += mmd.mmd_rbf_noaccelerate(source, target)source = self.cls_fc(source)#target = self.cls_fc(target)return source, lossdef resnet50(pretrained=False, **kwargs):"""Constructs a ResNet-50 model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))return model

三、训练模型

#train.pyimport torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import resnet34, resnet101
import torchvision.models.resnetdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)#数据增强操作,训练集:随机裁剪(RandomResizedCrop)、随机水平翻转(RandomHorizontalFlip)、转换为张量(ToTensor)以及归一化(Normalize)
#验证集:大小调整(Resize)、中心裁剪(CenterCrop)、转换为张量(ToTensor)以及归一化(Normalize)
data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#来自官网参数"val": transforms.Compose([transforms.Resize(256),#将最小边长缩放到256transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}data_root = os.getcwd()
image_path = data_root + "/flower_data/"  # flower data set pathtrain_dataset = datasets.ImageFolder(root=image_path + "train",transform=data_transform["train"])
train_num = len(train_dataset) #3306flower_list = train_dataset.class_to_idx   #{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
cla_dict = dict((val, key) for key, val in flower_list.items())  #{0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)  #将cla_dict字典对象转换为JSON格式的字符串,并通过indent=4参数指定缩进为4个空格
with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)  #364
validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=0)
#net = resnet34()
net = resnet34(num_classes=5)
# load pretrain weights# model_weight_path = "./resnet34-pre.pth"
# missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#载入模型参数# for param in net.parameters():
#     param.requires_grad = False
# change fc layer structure# inchannel = net.fc.in_features
# net.fc = nn.Linear(inchannel, 5)net.to(device)  #将神经网络模型net移动到指定的设备上,这样模型就可以在GPU/CPU上计算loss_function = nn.CrossEntropyLoss()  #损失函数
optimizer = optim.Adam(net.parameters(), lr=0.0001)  #优化器best_acc = 0.0
save_path = './resNet34.pth'
#一个epoch表示对整个训练数据集进行一次完整的迭代训练
for epoch in range(3):# trainnet.train()running_loss = 0.0#step表示当前的步数(或者称为批次数),data则表示从train_loader中加载的数据对象for step, data in enumerate(train_loader, start=0):images, labels = data  #images:(16,3,224,224) labels:(16,)optimizer.zero_grad()logits = net(images.to(device)) #logits:(16,5)将输入的图像数据images传入神经网络netloss = loss_function(logits, labels.to(device)) #1.6871 计算模型输出logits进行标准化(softmax),再计算每个样本预测标签和真实标签的交叉熵,对于整个批次的样本,计算平均交叉熵损失loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()  #累加每个批次的损失值# print train processrate = (step+1)/len(train_loader)a = "*" * int(rate * 50)b = "." * int((1 - rate) * 50)print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")print()# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))  # eval model only have last output layer# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]  #torch.max包含两个维度信息,第一个维度是最大值,第二个维度是最大值对应的索引acc += (predict_y == val_labels.to(device)).sum().item() #每一次的validate_loader中预测正确的个数,.item() 方法转换为标量val_accurate = acc / val_numif val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)  #state_dict()方法返回模型的参数字典,save_path保存模型参数的文件路径print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, running_loss / step, val_accurate))print('Finished Training')

四、预测

#predict.pyimport torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import jsondata_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# load image
img = Image.open("./roses.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)# read class_indict
try:json_file = open('./class_indices.json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():# predict classoutput = torch.squeeze(model(img))predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/152039.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

11.20 知识总结(choices参数、MVC和MTV的模式、Django与Ajax技术)

一、 choices参数的使用 1.1 作用 针对某个可以列举完全的可能性字段,我们应该如何存储 .只要某个字段的可能性是可以列举完全的,那么一般情况下都会采用choices参数 1.2 应用场景 应用场景: 学历: 小学 初中 高中 本科 硕士…

第八部分:JSP

目录 JSP概述 8.1:什么是JSP,它有什么作用? 8.2:JSP的本质是什么? 8.3:JSP的三种语法 8.3.1:jsp头部的page指令 8.3.2:jsp中的常用脚本 ①声明脚本(极少使用&#xf…

打不开github网页解决方法

问题: 1、composer更新包总是失败 2、github打不开,访问不了 解决方法:下载一个Watt Toolkit工具,勾选上,一键加速就可以打开了。 下载步骤: 1、打开网址: Watt Toolkit 2、点击【下载wind…

k8s的高可用集群搭建,详细过程实战版

kubernetes高可用集群的搭建 前面介绍过了k8s单master节点的安装部署 今天介绍一下k8s高可用集群搭建 环境准备: vip :192.168.121.99 keeplive master01:192.168.121.153 centos7 master02:192.168.121.154 centos7 master03&a…

jupyter修改默认打开目录

当我们打开jupyter notebook(不管用什么样的方式打开,使用菜单打开或者是命令行打开是一样的)会在默认的浏览器中看到这样的界面、 但是每一台不同的电脑打开之后的界面是不同的,仔细观察就会发现,这里面现实的一些文件…

基于Vue+SpringBoot的超市账单管理系统 开源项目

项目编号: S 032 ,文末获取源码。 \color{red}{项目编号:S032,文末获取源码。} 项目编号:S032,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统设计3.1 总体设计3.2 前端设计3…

【自动化测试】基于Selenium + Python的web自动化框架!

一、什么是Selenium? Selenium是一个基于浏览器的自动化工具,她提供了一种跨平台、跨浏览器的端到端的web自动化解决方案。Selenium主要包括三部分:Selenium IDE、Selenium WebDriver 和Selenium Grid:  1、Selenium IDE&…

10_6 input输入子系统,流程解析

简单分层 应用层 内核层 --------------------------- input handler 数据处理层 driver/input/evdev.c1.和用户空间交互,实现fops2.不知道数据怎么得到的,但是可以把数据上传给用户--------------------------- input core层1.维护上面和下面的两个链表2.为上下两层提供接口--…

【面试经典150 | 数学】回文数

文章目录 写在前面Tag题目来源题目解读解题思路方法一:反转一半数字 其他语言python3 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法,两到三天更新一篇文章,欢迎催更…… 专栏内容以分析题目为主,并附带一些对于本…

StoneDB顺利通过中科院软件所 2023 开源之夏 结项审核

近日,中科院软件所-开源软件供应链点亮计划-开源之夏2023的结项名单正式出炉,经过三个月的项目开发和一个多月的严格审核,共产生 418个成功结项项目!其中,StoneDB 作为本次参与开源社区,社区入选的两个项目…

单片非晶磁性测量系统磁参量指标

1. 概述 单片非晶磁性测量系统是专用于测量非晶或纳米晶薄片(带)交流磁特性的装置,由精密励磁及测量装置 ( 40 Hz~65 Hz,可定制至400 Hz )、单片磁导计、全自动测量软件组成。使用该装置可在能耗、效率、材料均匀性/一致性、可靠性、整个生命…

CSS的选择器(一篇文章齐全)

目录 Day26:CSS的选择器 1、CSS的引入方式 2、CSS的选择器 2.1 基本选择器​编辑 2.2 组合选择器 2.3 属性选择器 2.4 伪类选择器 2.5 样式继承 2.6 选择器优先级 3、CSS的属性操作 3.1 文本属性 3.2 背景属性 3.3 边框属性 3.4 列表属性 3.5 dispal…

【10套模拟】【7】

关键字: 二叉排序树插入一定是叶子、单链表简单选择排序、子串匹配、层次遍历

力扣刷题-二叉树-二叉树的高度与深度

二叉树最大深度 给定一个二叉树 root ,返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出:3 递归法 本题可以使用前序(中左…

Sql Server 2017主从配置之:事务日志传送

使用事务日志传送模式搭建Sql Server 2017主从同步,该模式有一定的延迟,是通过3个不同的定时任务,将主库的日志同步到从库进行恢复来实现数据库同步操作。 该模式在同步时候,从库不可以被使用,否则同步就会失败。 环…

现货白银MACD实战分析例子

MACD这个技术指标的全称是平滑异同移动平均线,主要表示经过平滑处理后均线的差异程度,一般用来研判现货白银价格变化的方向、强度和趋势。MT4中的MACD指标,主要是由信号线、(上升/下跌)动能柱、0轴这三部分组成。 MACD…

【开源】基于Vue.js的衣物搭配系统的设计和实现

项目编号: S 016 ,文末获取源码。 \color{red}{项目编号:S016,文末获取源码。} 项目编号:S016,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容2.1 衣物档案模块2.2 衣物搭配模块2.3 衣…

MyBatis查询数据库(全是精髓)

1. 什么是MyBatis? 简单说,MyBatis就是一个完成程序与数据库交互的工具,也就是更简单的操作和读取数据库的工具。 2. 怎么学习Mybatis Mybatis学习只分为两部分: 配置MyBatis开发环境使用MyBatis模式和语法操作数据库 3. 第一…

2023年(第六届)电力机器人应用与创新发展论坛-核心PPT资料下载

一、峰会简介 大会以“聚焦电力机器人创新、助力行业数字化转型、促进产业链协同发展”为主题,展示电力机器人产业全景创新技术,探讨数字化战略下电力机器人应用前景和发展趋势。为加快推进电力机器人应用拓新,助力电网数字化转型升级&#…

彻底弄清Python软件包安装流程并解决安装错误

彻底弄清Python软件包安装流程并解决安装错误 前言:写这篇文章的初衷也是因为以前饱受Python环境配置和软件包安装的摧残,所以写下这篇文章希望帮助同样深陷泥潭的小伙伴们,该文会带你理解关于安装软件包的流程。(tips&#xff1…