基于ResNet框架的CNN

数据准备

DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

一、训练集和验证集的划分

#spile_data.pyimport os
from shutil import copy
import randomdef mkfile(file):if not os.path.exists(file):os.makedirs(file)file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla] #['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
mkfile('flower_data/train') #生成train文件夹
for cla in flower_class:mkfile('flower_data/train/'+cla) #在train文件夹下生成各个类别的文件夹mkfile('flower_data/val')
for cla in flower_class:mkfile('flower_data/val/'+cla)split_rate = 0.1
for cla in flower_class:cla_path = file + '/' + cla + '/'images = os.listdir(cla_path)num = len(images)eval_index = random.sample(images, k=int(num*split_rate)) #在images中随机获取0.1的图片for index, image in enumerate(images):if image in eval_index:image_path = cla_path + imagenew_path = 'flower_data/val/' + clacopy(image_path, new_path)else:image_path = cla_path + imagenew_path = 'flower_data/train/' + clacopy(image_path, new_path)print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing barprint()print("processing done!")

二、ResNet网络

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import mmd
from attention import ChannelAttention
from attention import SpatialAttention
import torch__all__ = ['ResNet', 'resnet50']model_urls = {'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}def conv3x3(in_planes, out_planes, stride=1,groups=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, bias=False)def conv1x1(in_planes, out_planes, stride=1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)'''
Resnet中BasicBlock结构,ResNet中使用的网络结构。分2步走:3x3; 3x3
'''
class BasicBlock(nn.Module):expansion = 1  # 最后输出的通道数扩充的比例# BN层来加快网络模型的收敛速度/训练速度/解决梯度消失或者梯度爆炸的问题# 对batch中所有的同一个channel的数据元素进行标准化处理。一个batch共享一套参数# 即如果有C个通道,对N*H*W进行标准化处理,一共进行C次。def __init__(self, inplanes, planes, stride=1, downsample=None):super(BasicBlock, self).__init__()self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = nn.BatchNorm2d(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)# downsample是用一个1x1的卷积核处理,改变通道数,如果H/W尺度也不一样就设计strideif self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return out'''
Resnet中Bottleneck结构,ResNet中使用的网络结构。目的是为了降低参数量,分三步走:
1数据降维(1x1),2常规卷积核的卷积(3x3),3数据升维(1x1)
结果图片长宽不变,通道数扩大4倍
'''
class Bottleneck(nn.Module):expansion = 4  # 最后输出的通道数扩充的比例def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, norm_layer=None):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)# BN层来加快网络模型的收敛速度/训练速度/解决梯度消失或者梯度爆炸的问题# 对batch中所有的同一个channel的数据元素进行标准化处理。一个batch共享一套参数# 即如果有C个通道,对N*H*W进行标准化处理,一共进行C次。self.bn1 = nn.BatchNorm2d(planes)  # 卷积层后加BatchNorm2d,按照channel进行归一化self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,padding=1, bias=False)self.bn2 = nn.BatchNorm2d(planes)self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(planes * 4)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.dropout=nn.Dropout()self.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)# downsample是用一个1x1的卷积核处理,改变通道数,如果H/W尺度也不一样就设计strideif self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return out'''
ResNet由以下组成:
1.conv1、norm1、relu(当指定了deep_stem,这三个将被stem代替)
2.maxpool
3.layer1~layer4(定义为ResLayer类,分别由多个BasicBlock或Bottleneck组成)
'''
class ResNet(nn.Module):# 参数block指明残差块是两层或三层,参数layers指明每个卷积层需要的残差块数量,num_classes指明分类数,zero_init_residual是否初始化为0def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,groups=1, width_per_group=64, norm_layer=None):super(ResNet, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dself.inplanes = 64self.groups = groupsself.base_width = width_per_groupself.conv1 = nn.Conv2d(12, self.inplanes, kernel_size=7, stride=2, padding=3,bias=False)self.bn1 = norm_layer(self.inplanes)self.relu = nn.ReLU(inplace=True)# 网络的第一层加入注意力机制# self.ca = ChannelAttention(self.inplanes)# self.sa = SpatialAttention()self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)# 网络的卷积层的最后一层加入注意力机制# self.ca1 = ChannelAttention(self.inplanes)# self.sa1 = SpatialAttention()self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # 自适应平均池化,指定输出(H,W)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck):nn.init.constant_(m.bn3.weight, 0)elif isinstance(m, BasicBlock):nn.init.constant_(m.bn2.weight, 0)# 构造ResLayer类,layer1~layer4# block:BasicBlock/Bottleneck; planes:块的输入通道数; blocks:块的数目def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):if norm_layer is None:norm_layer = nn.BatchNorm2ddownsample = None  # downSample的作用于在残差连接时 将输入的图像的通道数变成和卷积操作的尺寸一致if stride != 1 or self.inplanes != planes * block.expansion:# 通道数恢复成一致/长宽恢复一致downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),norm_layer(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample, self.groups,self.base_width, norm_layer))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes, planes, groups=self.groups,base_width=self.base_width, norm_layer=norm_layer))return nn.Sequential(*layers)'''卷积/池化后的tensor维度为(batchsize,channels,x,y),其中x.size(0)指batchsize的值,通过x.view(x.size(0), -1)将tensor的结构转换为了(batchsize, channels*x*y)即将(channels,x,y)拉直,然后就可以和fc层连接因为最后avgpool(1,1)指定输出长*宽为1*1,通道为512*4,所以channels*x*y=2048'''def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)# x = self.ca(x) * x# x = self.sa(x) * xx = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)# x = self.ca1(x) * x# x = self.sa1(x) * xx = self.avgpool(x)x = x.view(x.size(0), -1)#x=self.fc(x)return xclass DANNet(nn.Module):def __init__(self, num_classes=2):super(DANNet, self).__init__()self.sharedNet = resnet50(False)self.cls_fc = nn.Linear(2048, num_classes)  # channels*x*y=2048*1*1,见上面的备注def forward(self, source, target):loss = 0source = self.sharedNet(source)if self.training == True:target = self.sharedNet(target)# loss += mmd.mmd_rbf_accelerate(source, target)loss += mmd.mmd_rbf_noaccelerate(source, target)source = self.cls_fc(source)#target = self.cls_fc(target)return source, lossdef resnet50(pretrained=False, **kwargs):"""Constructs a ResNet-50 model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))return model

三、训练模型

#train.pyimport torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import resnet34, resnet101
import torchvision.models.resnetdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)#数据增强操作,训练集:随机裁剪(RandomResizedCrop)、随机水平翻转(RandomHorizontalFlip)、转换为张量(ToTensor)以及归一化(Normalize)
#验证集:大小调整(Resize)、中心裁剪(CenterCrop)、转换为张量(ToTensor)以及归一化(Normalize)
data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#来自官网参数"val": transforms.Compose([transforms.Resize(256),#将最小边长缩放到256transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}data_root = os.getcwd()
image_path = data_root + "/flower_data/"  # flower data set pathtrain_dataset = datasets.ImageFolder(root=image_path + "train",transform=data_transform["train"])
train_num = len(train_dataset) #3306flower_list = train_dataset.class_to_idx   #{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
cla_dict = dict((val, key) for key, val in flower_list.items())  #{0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)  #将cla_dict字典对象转换为JSON格式的字符串,并通过indent=4参数指定缩进为4个空格
with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)  #364
validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=0)
#net = resnet34()
net = resnet34(num_classes=5)
# load pretrain weights# model_weight_path = "./resnet34-pre.pth"
# missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#载入模型参数# for param in net.parameters():
#     param.requires_grad = False
# change fc layer structure# inchannel = net.fc.in_features
# net.fc = nn.Linear(inchannel, 5)net.to(device)  #将神经网络模型net移动到指定的设备上,这样模型就可以在GPU/CPU上计算loss_function = nn.CrossEntropyLoss()  #损失函数
optimizer = optim.Adam(net.parameters(), lr=0.0001)  #优化器best_acc = 0.0
save_path = './resNet34.pth'
#一个epoch表示对整个训练数据集进行一次完整的迭代训练
for epoch in range(3):# trainnet.train()running_loss = 0.0#step表示当前的步数(或者称为批次数),data则表示从train_loader中加载的数据对象for step, data in enumerate(train_loader, start=0):images, labels = data  #images:(16,3,224,224) labels:(16,)optimizer.zero_grad()logits = net(images.to(device)) #logits:(16,5)将输入的图像数据images传入神经网络netloss = loss_function(logits, labels.to(device)) #1.6871 计算模型输出logits进行标准化(softmax),再计算每个样本预测标签和真实标签的交叉熵,对于整个批次的样本,计算平均交叉熵损失loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()  #累加每个批次的损失值# print train processrate = (step+1)/len(train_loader)a = "*" * int(rate * 50)b = "." * int((1 - rate) * 50)print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")print()# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))  # eval model only have last output layer# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]  #torch.max包含两个维度信息,第一个维度是最大值,第二个维度是最大值对应的索引acc += (predict_y == val_labels.to(device)).sum().item() #每一次的validate_loader中预测正确的个数,.item() 方法转换为标量val_accurate = acc / val_numif val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)  #state_dict()方法返回模型的参数字典,save_path保存模型参数的文件路径print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, running_loss / step, val_accurate))print('Finished Training')

四、预测

#predict.pyimport torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import jsondata_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# load image
img = Image.open("./roses.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)# read class_indict
try:json_file = open('./class_indices.json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():# predict classoutput = torch.squeeze(model(img))predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/152039.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PHP笔记-->读取JSON数据以及获取读取到的JSON里边的数据

由于我以前是写C#的,现在学一下PHP, 在读取json数据的时候被以前的思维卡住了。 以前用C#读取的时候,是先定义一个数组,将反序列化的json存到数组里面,在从数组里面获取jaon中的“data”数据。 其实PHP的思路也是一样…

ASP.NET 开发几个知识点

1、 皮肤设定&#xff1a; 项目右键&#xff0c;建立皮肤 app_themes 文件夹&#xff0c;右键 建立 web from 皮肤文件&#xff0c; 设定皮肤样式。全局使用皮肤 web.config 增加 <pages styleSheetTheme"Skin1" /> &#xff0c;或在 具体页面 头 增加 sty…

算法leetcode|89. 格雷编码(rust重拳出击)

文章目录 89. 格雷编码&#xff1a;样例 1&#xff1a;样例 2&#xff1a;提示&#xff1a; 分析&#xff1a;题解&#xff1a;rust&#xff1a;go&#xff1a;c&#xff1a;python&#xff1a;java&#xff1a; 89. 格雷编码&#xff1a; n 位格雷码序列 是一个由 2n 个整数组…

11.20 知识总结(choices参数、MVC和MTV的模式、Django与Ajax技术)

一、 choices参数的使用 1.1 作用 针对某个可以列举完全的可能性字段&#xff0c;我们应该如何存储 .只要某个字段的可能性是可以列举完全的&#xff0c;那么一般情况下都会采用choices参数 1.2 应用场景 应用场景&#xff1a; 学历&#xff1a; 小学 初中 高中 本科 硕士…

第八部分:JSP

目录 JSP概述 8.1&#xff1a;什么是JSP&#xff0c;它有什么作用&#xff1f; 8.2&#xff1a;JSP的本质是什么&#xff1f; 8.3&#xff1a;JSP的三种语法 8.3.1&#xff1a;jsp头部的page指令 8.3.2&#xff1a;jsp中的常用脚本 ①声明脚本&#xff08;极少使用&#xf…

打不开github网页解决方法

问题&#xff1a; 1、composer更新包总是失败 2、github打不开&#xff0c;访问不了 解决方法&#xff1a;下载一个Watt Toolkit工具&#xff0c;勾选上&#xff0c;一键加速就可以打开了。 下载步骤&#xff1a; 1、打开网址&#xff1a; Watt Toolkit 2、点击【下载wind…

k8s的高可用集群搭建,详细过程实战版

kubernetes高可用集群的搭建 前面介绍过了k8s单master节点的安装部署 今天介绍一下k8s高可用集群搭建 环境准备&#xff1a; vip &#xff1a;192.168.121.99 keeplive master01&#xff1a;192.168.121.153 centos7 master02&#xff1a;192.168.121.154 centos7 master03&a…

jupyter修改默认打开目录

当我们打开jupyter notebook&#xff08;不管用什么样的方式打开&#xff0c;使用菜单打开或者是命令行打开是一样的&#xff09;会在默认的浏览器中看到这样的界面、 但是每一台不同的电脑打开之后的界面是不同的&#xff0c;仔细观察就会发现&#xff0c;这里面现实的一些文件…

基于Vue+SpringBoot的超市账单管理系统 开源项目

项目编号&#xff1a; S 032 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S032&#xff0c;文末获取源码。} 项目编号&#xff1a;S032&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统设计3.1 总体设计3.2 前端设计3…

【自动化测试】基于Selenium + Python的web自动化框架!

一、什么是Selenium&#xff1f; Selenium是一个基于浏览器的自动化工具&#xff0c;她提供了一种跨平台、跨浏览器的端到端的web自动化解决方案。Selenium主要包括三部分&#xff1a;Selenium IDE、Selenium WebDriver 和Selenium Grid&#xff1a;  1、Selenium IDE&…

11.17 知识总结(事务、常见的字段类型等)

一、 事务 1.1 如何开启事务 前言 事务是MySQL数据库中得一个重要概念 事务的目的&#xff1a;为了保证多个SQL语句执行成功&#xff0c;执行失败&#xff0c;前后保持一致&#xff0c;保证数据安全 ACID属性&#xff1a; A&#xff1a; C&#xff1a; I&#xff1a; D&#x…

10_6 input输入子系统,流程解析

简单分层 应用层 内核层 --------------------------- input handler 数据处理层 driver/input/evdev.c1.和用户空间交互,实现fops2.不知道数据怎么得到的,但是可以把数据上传给用户--------------------------- input core层1.维护上面和下面的两个链表2.为上下两层提供接口--…

window拖拽操作的实现

调用DragAcceptFiles&#xff0c;让控件或者窗体支持文件拖动操作 void DragAcceptFiles(HWND hWnd, //指明目标窗体的句柄BOOL fAccept //为True时 则hWnd所指向的窗体可以接受拖放的文件. );窗口消息过程处理WM_DROPFILES消息。 在WM_DROPFILES消息处理过程中&…

【面试经典150 | 数学】回文数

文章目录 写在前面Tag题目来源题目解读解题思路方法一&#xff1a;反转一半数字 其他语言python3 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法&#xff0c;两到三天更新一篇文章&#xff0c;欢迎催更…… 专栏内容以分析题目为主&#xff0c;并附带一些对于本…

StoneDB顺利通过中科院软件所 2023 开源之夏 结项审核

近日&#xff0c;中科院软件所-开源软件供应链点亮计划-开源之夏2023的结项名单正式出炉&#xff0c;经过三个月的项目开发和一个多月的严格审核&#xff0c;共产生 418个成功结项项目&#xff01;其中&#xff0c;StoneDB 作为本次参与开源社区&#xff0c;社区入选的两个项目…

EXPLAIN命令使用及功能介绍

当你不确定某个select语句执行会不会影响数据库cpu怎么办&#xff1f;&#xff0c;使用EXPLAIN命令&#xff0c;给你分析该不该执行&#xff01;&#xff01; EXPLAIN命令介绍 MySQL的EXPLAIN命令是一个查询优化工具&#xff0c;用于分析和评估SELECT语句的执行计划。它提供了…

制作含有音频、视频的网页

参考代码如下 <!DOCTYPE html> <html> <head><title>视频音乐网页</title> </head> <body><!-- 视频 --><video width"320" height"240" controls><source src"movie.mp4" type"…

python——第十一天

今日目标&#xff1a; 模块和包的基本概念 python中模块的导入问题 main函数的作用和使用 常见内置模块的使用 IO流相关 模块和包的基本概念&#xff1a; 模块&#xff08;module&#xff09;&#xff1a;一个.py文件就是一个模块 包&#xff08;package&#xff09;&#xff1…

单片非晶磁性测量系统磁参量指标

1. 概述 单片非晶磁性测量系统是专用于测量非晶或纳米晶薄片(带)交流磁特性的装置&#xff0c;由精密励磁及测量装置 ( 40 Hz&#xff5e;65 Hz&#xff0c;可定制至400 Hz )、单片磁导计、全自动测量软件组成。使用该装置可在能耗、效率、材料均匀性/一致性、可靠性、整个生命…

CSS的选择器(一篇文章齐全)

目录 Day26&#xff1a;CSS的选择器 1、CSS的引入方式 2、CSS的选择器 2.1 基本选择器​编辑 2.2 组合选择器 2.3 属性选择器 2.4 伪类选择器 2.5 样式继承 2.6 选择器优先级 3、CSS的属性操作 3.1 文本属性 3.2 背景属性 3.3 边框属性 3.4 列表属性 3.5 dispal…