霹雳吧啦Wz《pytorch图像分类》-p3VGG网络

《pytorch图像分类》p3VGG网络详解及感受野的计算

  • 一、零碎知识点
    • 1.nn.Sequential
    • 2.**kwargs
  • 二、VGG网络模型详解
    • 1.感受野
    • 2.模型手算
  • 三、代码
    • 1.module.py
    • 2.train.py
    • 3.predict.py

一、零碎知识点

论文连接:VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITION
代码链接:霹雳吧啦Wzdeep-learning-for-image-processing

1.nn.Sequential

nn.Sequential是PyTorch中的一个类,用于按顺序组织和堆叠神经网络的层或模块。它提供了一种便捷的方式来构建简单的前向传播网络。

import torch
import torch.nn as nnmodel = nn.Sequential(
in_channels,out_channels,kernel_sizenn.Conv2d(in_channels,out_channels,kernel_size)nn.ReLU(),                                # 添加激活函数nn.Linear(hidden_features, out_features)  # 添加线性层
)

2.**kwargs

**kwargs是一个特殊的参数传递方式,它允许函数接受不定数量的关键字参数(Keyword Arguments)并将它们作为一个字典进行处理。

下面是一个简单的示例说明**kwargs的用法:

def example_func(**kwargs):for key, value in kwargs.items():print(key, value)example_func(name='Maverick', age=22, location='cheng du')

输出结果:

name Maverick
age 22
location cheng du

二、VGG网络模型详解

1.感受野

感受野(receptive field)是指在卷积神经网络(CNN)中的某一层输出特征图上的像素位置所对应的输入图像上的区域大小。
随着卷积核的增多(即网络的加深),感受野会越来越大。
在这里插入图片描述
当我们说一个神经网络层的感受野大小为N时,可以简单解释为:在该层输出特征图上的一个像素点,它所"看到"的输入图像区域大小是N×N。
随着网络的层数增加,感受野也会逐渐增大。最早的卷积层(例如卷积核为3x3)的感受野较小,但后续的层会通过池化或步幅更大的卷积来逐渐增加感受野的大小。

在这里插入图片描述

2.模型手算

VGG网络的常用配置是D,有16个层(包括13个卷积层和3个全连接层)

LRN是一种对神经网络中的特征图进行局部归一化的操作。其目的是增加网络的鲁棒性,防止某些特征具有过大的响应值而抑制其他特征的重要性。
具有鲁棒性的模型能够在输入数据中存在一定程度的扰动、噪声或异常情况下仍然保持良好的性能。
在这里插入图片描述
反复记忆:输出的特征矩阵的深度out_channels和卷积核的个数相同
因为彩色图形有rgb三个通道,所以最开始的特征矩阵深度为3
后面都是根据卷积核个数的不同产生不同的改变。
在这里插入图片描述

三、代码

1.module.py

import torch.nn as nn
import torch# official pretrain weights
model_urls = {'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth','vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth','vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth','vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}class VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=False):super(VGG, self).__init__()self.features = featuresself.classifier = nn.Sequential(nn.Linear(512*7*7, 4096),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, num_classes))if init_weights:self._initialize_weights()def forward(self, x):# N x 3 x 224 x 224x = self.features(x)# N x 512 x 7 x 7x = torch.flatten(x, start_dim=1)# N x 512*7*7x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')nn.init.xavier_uniform_(m.weight)if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)# nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def make_features(cfg: list):layers = []in_channels = 3for v in cfg:if v == "M":layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)layers += [conv2d, nn.ReLU(True)]in_channels = vreturn nn.Sequential(*layers)cfgs = {'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}def vgg(model_name="vgg16", **kwargs):assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)cfg = cfgs[model_name]model = VGG(make_features(cfg), **kwargs)return model

2.train.py

import os
import sys
import jsonimport torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdmfrom model import vggdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 2nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=0)print("using {} images for training, {} images for validation.".format(train_num,val_num))# test_data_iter = iter(validate_loader)# test_image, test_label = test_data_iter.next()model_name = "vgg16"net = vgg(model_name=model_name, num_classes=5, init_weights=True)net.to(device)loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0001)epochs = 30best_acc = 0.0save_path = './{}Net.pth'.format(model_name)train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()

用的是老师的代码,我的gpu内存不够,我已经将批处理大小(batch size)减少到2了,还是运行不起来
CUDA out of memory. Tried to allocate 392.00 MiB (GPU 0; 2.00 GiB total capacity; 718.01 MiB already allocated; 341.00 MiB free; 740.00 MiB reserved in total by PyTorch)

3.predict.py

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model import vggdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load imageimg_path = "../tulip.jpg"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indictjson_path = './class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)with open(json_path, "r") as f:class_indict = json.load(f)# create modelmodel = vgg(model_name="vgg16", num_classes=5).to(device)# load model weightsweights_path = "./vgg16Net.pth"assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)model.load_state_dict(torch.load(weights_path, map_location=device))model.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))plt.show()if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/593674.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

结算时间和可组合性助力Sui上DeFi蓬勃发展

结算时间是基于Sui交易处理模型的度量标准,确保DeFi用户几乎立即看到交易结果。可组合性则是深深融入Sui的编程环境,扩展了其对对象和智能合约的影响。Sui深度的可组合性赋予DeFi构建者引入创新产品的能力,使其在其他区块链上的DeFi应用中独树…

基于Java校园招待所管理系统

基于Java校园招待所管理系统 功能需求 1、客房管理:系统需要管理招待所的客房信息,包括房间类型、数量、价格、状态等,并能够实时更新客房状态。 2、客人管理:系统需要记录客人的基本信息,包括姓名、性别、年龄、联…

Crypto的简单应用-前后端加密传输

最近遇到一个数据脱敏处理的需求,想要用一种轻量级的技术实现,必须足够简单并且适用于所有场合如前后端加密传输、路由加密、数据脱敏等。抽时间研究了一下Crypto加密库的一些API,发现完全符合上述需求,扩展也比较容易。 1、前端加…

[C#]使用onnxruntime部署Detic检测2万1千种类别的物体

【源码地址】 github地址:https://github.com/facebookresearch/Detic/tree/main 【算法介绍】 Detic论文:https://arxiv.org/abs/2201.02605v3 项目源码:https://github.com/facebookresearch/Detic 在Detic论文中,Detic提到…

2023.12.27 关于 Redis 数据类型 List 常用命令

目录 List 类型基本概念 List 类型特点 List 操作命令 LPUSH LPUSHX RPUSH RPUSHX LRANGE LPOP RPOP LINDEX LINSERT LREM LTRIM LSET 阻塞版本的命令 阻塞版本 和 非阻塞版本的区别 BLPOP & BRPOP List 类型基本概念 Redis 中的列表(list&am…

定时器PWM控制RGB彩灯案例

1.脉冲宽度调制PWM PWM(Pulse Width Modulation)简称脉宽调制,是利用微处理器的数字输出来对模拟电路进行控制的一种非常有效的技术,广泛应用在测量、通信、工控等方面。   PWM的一个优点是从处理器到​​ ​被控系统​​​信号…

x-cmd pkg | bit - 实验性的现代化 git CLI

目录 简介首次用户功能特点竞品和相关作品进一步探索 简介 bit,由 Chris Walz 于 2020 年使用 Go 语言开发,提供直观的命令行补全提示和建立在 git 命令之上的封装命令,旨在建立完全兼容 git 命令的现代化 CLI。 首次用户 使用 x bit 即可自…

flutter获取本地图片高度、宽度

/*获取本地图片宽度* */getLocalImageWidth(String path){int width;Completer<int> completer new Completer<int>();Image image Image.file(File.fromUri(Uri.parse(path)));// 预先获取图片信息image.image.resolve(new ImageConfiguration()).addListener(n…

test ui-03-cypress 入门介绍

cypress 是什么&#xff1f; 简而言之&#xff0c;Cypress 是一款专为现代Web构建的下一代前端测试工具。我们解决了开发人员和质量保证工程师在测试现代应用程序时面临的关键问题。 我们使以下操作成为可能&#xff1a; 设置测试编写测试运行测试调试测试 Cypress经常与Se…

使用宝塔在Linux面板搭建网站,并实现公网远程访问

文章目录 前言1. 环境安装2. 安装cpolar内网穿透3. 内网穿透4. 固定http地址5. 配置二级子域名6. 创建一个测试页面 前言 宝塔面板作为简单好用的服务器运维管理面板&#xff0c;它支持Linux/Windows系统&#xff0c;我们可用它来一键配置LAMP/LNMP环境、网站、数据库、FTP等&…

基于多反应堆的高并发服务器【C/C++/Reactor】(中)处理任务队列中的任务

一、处理任务队列中的任务 &#xff08;1&#xff09;EventLoop启动 EventLoop初始化和启动 // 启动反应堆模型 int eventLoopRun(struct EventLoop* evLoop) {assert(evLoop ! NULL);// 取出事件分发和检测模型struct Dispatcher* dispatcher evLoop->dispatcher;// 比较…

2024阿里云Alibaba Cloud Linux 3镜像版本大全说明

Alibaba Cloud Linux阿里云打造的Linux服务器操作系统发行版&#xff0c;Alibaba Cloud Linux完全兼容完全兼容CentOS/RHEL生态和操作方式&#xff0c;目前已经推出Alibaba Cloud Linux 3&#xff0c;阿里云百科aliyunbaike.com分享Alibaba Cloud Linux 3版本特性说明&#xff…

面试算法83:没有重复元素集合的全排列

题目 给定一个没有重复数字的集合&#xff0c;请找出它的所有全排列。例如&#xff0c;集合[1&#xff0c;2&#xff0c;3]有6个全排列&#xff0c;分别是[1&#xff0c;2&#xff0c;3]、[1&#xff0c;3&#xff0c;2]、[2&#xff0c;1&#xff0c;3]、[2&#xff0c;3&…

如何在anaconda里安装basemap和pyproj库

当直接使用conda命令进行安装basemap和pyproj库时&#xff0c;会出现版本不对应的报错问题(如下图)&#xff0c;所以此篇博客用以展示如何安装basemap和pyproj库 题主默认使用的anaconda源已经切换成了清华大学源&#xff0c;但是仍然会出现报错&#xff0c;所以不是源的问题&a…

haproxy笔记

文章目录 场景haproxy配置文档地址 场景 还得先从场景说起。 生产环境redis检查&#xff0c;发现配置的redis地址不对。 redis有3个节点。 192.168.0.1 192.168.0.2 192.168.0.3 但是配置的是 192.168.0.9 端口是16379。 好奇怪有没有&#xff0c;是不是配错了? 问了下部署大…

CMake入门教程【核心篇】函数(function)

&#x1f608;「CSDN主页」&#xff1a;传送门 &#x1f608;「Bilibil首页」&#xff1a;传送门 &#x1f608;「本文的内容」&#xff1a;CMake入门教程 &#x1f608;「动动你的小手」&#xff1a;点赞&#x1f44d;收藏⭐️评论&#x1f4dd; 文章目录 1. 函数的定义与基本…

vue3对比vue2是怎样的

一、前言 Vue 3通过引入Composition API、升级响应式系统、优化性能等一系列的改进和升级,提供了更好的开发体验和更好的性能,使得开发者能够更方便地开发出高质量的Web应用。它在Vue.js 2的基础上进行了一系列的改进和升级,以提供更好的性能、更好的开发体验和更好的扩展性…

labview 与三菱FX 小型PLC通信(OPC)

NI OPC服务器与三菱FX3U PLC通讯方法 一、新建通道名称为&#xff1a;MIT 二、选择三菱FX系列 三、确认端口号相关的参数&#xff08;COM端&#xff1a;7.波特率&#xff1a;9600&#xff0c;数据位&#xff1a;7&#xff0c;校验&#xff1a;奇校验&#xff0c;停止位&#xf…

海外住宅IP代理的工作原理和应用场景分析,新手必看

海外住宅IP代理作为一种技术解决方案&#xff0c;为用户提供了访问全球网络资源和维护隐私安全的方法。本文将介绍海外住宅IP代理的工作原理和应用场景&#xff0c;帮助读者更好地理解和利用这一技术。 一、工作原理 海外住宅IP代理的工作原理基于代理服务器和IP地址的转发。它…

ITSS服务工程师vs ITSS服务经理:哪个职位更适合你?

✨在信息技术服务领域&#xff0c;ITSS服务工程师和ITSS服务经理是两个极具吸引力的职位。但它们各自的特点和要求是什么&#xff1f;哪个更适合你的职业规划和个人兴趣&#xff1f;接下来&#xff0c;我们将为你详细解读这两个职位的区别&#xff0c;帮助你做出明智的选择&…