霹雳吧啦Wz《pytorch图像分类》-p3VGG网络

《pytorch图像分类》p3VGG网络详解及感受野的计算

  • 一、零碎知识点
    • 1.nn.Sequential
    • 2.**kwargs
  • 二、VGG网络模型详解
    • 1.感受野
    • 2.模型手算
  • 三、代码
    • 1.module.py
    • 2.train.py
    • 3.predict.py

一、零碎知识点

论文连接:VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITION
代码链接:霹雳吧啦Wzdeep-learning-for-image-processing

1.nn.Sequential

nn.Sequential是PyTorch中的一个类,用于按顺序组织和堆叠神经网络的层或模块。它提供了一种便捷的方式来构建简单的前向传播网络。

import torch
import torch.nn as nnmodel = nn.Sequential(
in_channels,out_channels,kernel_sizenn.Conv2d(in_channels,out_channels,kernel_size)nn.ReLU(),                                # 添加激活函数nn.Linear(hidden_features, out_features)  # 添加线性层
)

2.**kwargs

**kwargs是一个特殊的参数传递方式,它允许函数接受不定数量的关键字参数(Keyword Arguments)并将它们作为一个字典进行处理。

下面是一个简单的示例说明**kwargs的用法:

def example_func(**kwargs):for key, value in kwargs.items():print(key, value)example_func(name='Maverick', age=22, location='cheng du')

输出结果:

name Maverick
age 22
location cheng du

二、VGG网络模型详解

1.感受野

感受野(receptive field)是指在卷积神经网络(CNN)中的某一层输出特征图上的像素位置所对应的输入图像上的区域大小。
随着卷积核的增多(即网络的加深),感受野会越来越大。
在这里插入图片描述
当我们说一个神经网络层的感受野大小为N时,可以简单解释为:在该层输出特征图上的一个像素点,它所"看到"的输入图像区域大小是N×N。
随着网络的层数增加,感受野也会逐渐增大。最早的卷积层(例如卷积核为3x3)的感受野较小,但后续的层会通过池化或步幅更大的卷积来逐渐增加感受野的大小。

在这里插入图片描述

2.模型手算

VGG网络的常用配置是D,有16个层(包括13个卷积层和3个全连接层)

LRN是一种对神经网络中的特征图进行局部归一化的操作。其目的是增加网络的鲁棒性,防止某些特征具有过大的响应值而抑制其他特征的重要性。
具有鲁棒性的模型能够在输入数据中存在一定程度的扰动、噪声或异常情况下仍然保持良好的性能。
在这里插入图片描述
反复记忆:输出的特征矩阵的深度out_channels和卷积核的个数相同
因为彩色图形有rgb三个通道,所以最开始的特征矩阵深度为3
后面都是根据卷积核个数的不同产生不同的改变。
在这里插入图片描述

三、代码

1.module.py

import torch.nn as nn
import torch# official pretrain weights
model_urls = {'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth','vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth','vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth','vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}class VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=False):super(VGG, self).__init__()self.features = featuresself.classifier = nn.Sequential(nn.Linear(512*7*7, 4096),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, num_classes))if init_weights:self._initialize_weights()def forward(self, x):# N x 3 x 224 x 224x = self.features(x)# N x 512 x 7 x 7x = torch.flatten(x, start_dim=1)# N x 512*7*7x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')nn.init.xavier_uniform_(m.weight)if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)# nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def make_features(cfg: list):layers = []in_channels = 3for v in cfg:if v == "M":layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)layers += [conv2d, nn.ReLU(True)]in_channels = vreturn nn.Sequential(*layers)cfgs = {'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}def vgg(model_name="vgg16", **kwargs):assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)cfg = cfgs[model_name]model = VGG(make_features(cfg), **kwargs)return model

2.train.py

import os
import sys
import jsonimport torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdmfrom model import vggdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 2nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=0)print("using {} images for training, {} images for validation.".format(train_num,val_num))# test_data_iter = iter(validate_loader)# test_image, test_label = test_data_iter.next()model_name = "vgg16"net = vgg(model_name=model_name, num_classes=5, init_weights=True)net.to(device)loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0001)epochs = 30best_acc = 0.0save_path = './{}Net.pth'.format(model_name)train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()

用的是老师的代码,我的gpu内存不够,我已经将批处理大小(batch size)减少到2了,还是运行不起来
CUDA out of memory. Tried to allocate 392.00 MiB (GPU 0; 2.00 GiB total capacity; 718.01 MiB already allocated; 341.00 MiB free; 740.00 MiB reserved in total by PyTorch)

3.predict.py

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model import vggdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load imageimg_path = "../tulip.jpg"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indictjson_path = './class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)with open(json_path, "r") as f:class_indict = json.load(f)# create modelmodel = vgg(model_name="vgg16", num_classes=5).to(device)# load model weightsweights_path = "./vgg16Net.pth"assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)model.load_state_dict(torch.load(weights_path, map_location=device))model.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))plt.show()if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/593674.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

结算时间和可组合性助力Sui上DeFi蓬勃发展

结算时间是基于Sui交易处理模型的度量标准,确保DeFi用户几乎立即看到交易结果。可组合性则是深深融入Sui的编程环境,扩展了其对对象和智能合约的影响。Sui深度的可组合性赋予DeFi构建者引入创新产品的能力,使其在其他区块链上的DeFi应用中独树…

[C#]使用onnxruntime部署Detic检测2万1千种类别的物体

【源码地址】 github地址:https://github.com/facebookresearch/Detic/tree/main 【算法介绍】 Detic论文:https://arxiv.org/abs/2201.02605v3 项目源码:https://github.com/facebookresearch/Detic 在Detic论文中,Detic提到…

2023.12.27 关于 Redis 数据类型 List 常用命令

目录 List 类型基本概念 List 类型特点 List 操作命令 LPUSH LPUSHX RPUSH RPUSHX LRANGE LPOP RPOP LINDEX LINSERT LREM LTRIM LSET 阻塞版本的命令 阻塞版本 和 非阻塞版本的区别 BLPOP & BRPOP List 类型基本概念 Redis 中的列表(list&am…

定时器PWM控制RGB彩灯案例

1.脉冲宽度调制PWM PWM(Pulse Width Modulation)简称脉宽调制,是利用微处理器的数字输出来对模拟电路进行控制的一种非常有效的技术,广泛应用在测量、通信、工控等方面。   PWM的一个优点是从处理器到​​ ​被控系统​​​信号…

x-cmd pkg | bit - 实验性的现代化 git CLI

目录 简介首次用户功能特点竞品和相关作品进一步探索 简介 bit,由 Chris Walz 于 2020 年使用 Go 语言开发,提供直观的命令行补全提示和建立在 git 命令之上的封装命令,旨在建立完全兼容 git 命令的现代化 CLI。 首次用户 使用 x bit 即可自…

test ui-03-cypress 入门介绍

cypress 是什么? 简而言之,Cypress 是一款专为现代Web构建的下一代前端测试工具。我们解决了开发人员和质量保证工程师在测试现代应用程序时面临的关键问题。 我们使以下操作成为可能: 设置测试编写测试运行测试调试测试 Cypress经常与Se…

使用宝塔在Linux面板搭建网站,并实现公网远程访问

文章目录 前言1. 环境安装2. 安装cpolar内网穿透3. 内网穿透4. 固定http地址5. 配置二级子域名6. 创建一个测试页面 前言 宝塔面板作为简单好用的服务器运维管理面板,它支持Linux/Windows系统,我们可用它来一键配置LAMP/LNMP环境、网站、数据库、FTP等&…

如何在anaconda里安装basemap和pyproj库

当直接使用conda命令进行安装basemap和pyproj库时,会出现版本不对应的报错问题(如下图),所以此篇博客用以展示如何安装basemap和pyproj库 题主默认使用的anaconda源已经切换成了清华大学源,但是仍然会出现报错,所以不是源的问题&a…

CMake入门教程【核心篇】函数(function)

😈「CSDN主页」:传送门 😈「Bilibil首页」:传送门 😈「本文的内容」:CMake入门教程 😈「动动你的小手」:点赞👍收藏⭐️评论📝 文章目录 1. 函数的定义与基本…

labview 与三菱FX 小型PLC通信(OPC)

NI OPC服务器与三菱FX3U PLC通讯方法 一、新建通道名称为:MIT 二、选择三菱FX系列 三、确认端口号相关的参数(COM端:7.波特率:9600,数据位:7,校验:奇校验,停止位&#xf…

海外住宅IP代理的工作原理和应用场景分析,新手必看

海外住宅IP代理作为一种技术解决方案,为用户提供了访问全球网络资源和维护隐私安全的方法。本文将介绍海外住宅IP代理的工作原理和应用场景,帮助读者更好地理解和利用这一技术。 一、工作原理 海外住宅IP代理的工作原理基于代理服务器和IP地址的转发。它…

ITSS服务工程师vs ITSS服务经理:哪个职位更适合你?

✨在信息技术服务领域,ITSS服务工程师和ITSS服务经理是两个极具吸引力的职位。但它们各自的特点和要求是什么?哪个更适合你的职业规划和个人兴趣?接下来,我们将为你详细解读这两个职位的区别,帮助你做出明智的选择&…

Win32 基本程序设计原理总结

目录 1. Windows系统 基本原理 2. 需要什么函数库(.LIB) 2.1 C Runtimes: 2.2 Windows API 3. 需要什么头文件(.H) 4. Windows 程序运行的本质 5. 窗口类的注册与窗口的诞生 6.消息 6.1 消息分类:…

咖啡茶饮营销不止「9 块 9」,门店「VACS」需要全面提升

每一座城市 CBD 的写字楼下和热门商圈的街边,都是咖啡茶饮的战场。作为餐饮行业的热门赛道,咖啡茶饮近年来一直保持高速增长。据统计,截至今年 10 月 31 日,陆陆续续又有约 15 万家店铺开门营业…… 白热化竞争下,茶饮…

2023到2024年:前端发展趋势展望

本文探讨了2023年至2024年之间前端领域的发展趋势。我们将关注以下几个方面的变化:无代码/低代码开发的兴起、WebAssembly的广泛应用、跨平台技术的发展、人工智能在前端的应用以及用户体验的不断优化。 随着技术的飞速发展,前端开发在推动互联网与移动应…

Google Gemini接口调用(node版)

一、打开Google AI Studio https://makersuite.google.com/app/apikey 二、在国外服务器上部署一个接口用于真正的请求 const sdAxiosOnAzure async (req, res) > {let {config {url: https://sinkin.ai/api/inference,method: post,data: {},timeout: 30 * 60 * 1000,}…

Python爬虫中的协程

协程 基本概念 协程:当程序执行的某一个任务遇到了IO操作时(处于阻塞状态),不让CPU切换走(就是不让CPU去执行其他程序),而是选择性的切换到其他任务上,让CPU执行新的任务&#xff…

网络安全—认证技术

文章目录 加密认证对称密钥体制公钥密码体制公钥的加密公钥身份认证和加密 鉴别码认证MAC鉴别码 报文摘要认证认证 加密只认证数字签名 通过了解以前前辈们使用的消息认证慢慢渐进到现代的完整的认证体系。所以在学习的时候也很蒙圈,因为前期的很多技术都是有很严重…

这次,数据泄露的目标受害者指向了---救护车服务公司

已停业的救护车服务遭到勒索软件攻击导致近百万人受到威胁! 此次数据泄露的目标受害者是法伦救护车服务公司,该公司是Transformative Healthcare的子公司。ALPHV勒索软件团伙声称对2023年4月下旬对Transformative Healthcare的攻击负责,并导…

SpringBoot 集成支付宝支付

网页操作步骤 1.进入支付宝开发平台—沙箱环境 使用开发者账号登录开放平台控制平台 2.点击沙箱进入沙箱环境 说明:沙箱环境支持的产品,可以在沙箱控制台 沙箱应用 > 产品列表 中查看。 3.进入沙箱,配置接口加签方式 在沙箱进行调试前…