GoogLeNet(pytorch)

亮点与创新:

1. 引入Inception基础结构

2. 引入PW(pointwise,即1×1)卷积进行通道维度变换,启发了后续对参数量的优化

3. 丢弃全连接层,使用平均池化层(大大减少模型参数)

4. 添加两个辅助分类器帮助训练(避免梯度消失,用于向前传导梯度,也有一定的正则化效果,防止过拟合)

 model.py

import torch.nn as nn
import torch
import torch.nn.functional as F


class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) classifier with two optional auxiliary heads.

    Args:
        num_classes: Number of output classes of the main and auxiliary
            classifiers.
        aux_logits: When True, build the two auxiliary classifiers. They are
            only evaluated while the module is in training mode.
        init_weights: When True, initialize conv and linear layers via
            ``_initialize_weights``.

    The shape comments in ``forward`` assume a 224 x 224 input, which the
    auxiliary classifiers require (their first linear layer expects 2048
    features).
    """

    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run the network.

        Returns:
            In training mode with ``aux_logits`` enabled, the tuple
            ``(logits, aux2_logits, aux1_logits)``; otherwise only ``logits``.
        """
        # The auxiliary heads are skipped in eval mode.
        use_aux = self.training and self.aux_logits

        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if use_aux:
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if use_aux:
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x num_classes
        if use_aux:
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        """Kaiming-normal init for conv layers, N(0, 0.01) for linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


class Inception(nn.Module):
    """Inception block: four parallel branches concatenated along channels.

    Args:
        in_channels: Channels of the input feature map.
        ch1x1: Output channels of the 1x1 branch.
        ch3x3red: Channels of the 1x1 reduction before the 3x3 conv.
        ch3x3: Output channels of the 3x3 branch.
        ch5x5red: Channels of the 1x1 reduction before the 5x5 conv.
        ch5x5: Output channels of the 5x5 branch.
        pool_proj: Output channels of the 1x1 projection after max pooling.
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            # padding=1 keeps the output spatial size equal to the input
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            # The official torchvision implementation actually uses a 3x3
            # kernel here instead of 5x5; this implementation keeps 5x5.
            # Please see https://github.com/pytorch/vision/issues/906 for details.
            # padding=2 keeps the output spatial size equal to the input
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        """Apply all four branches to ``x`` and concatenate on dim 1."""
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier attached to an intermediate feature map.

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Number of output classes.
    """

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]

        # 2048 = 128 channels * 4 * 4, so the input feature map must be 14 x 14.
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return class logits of shape ``N x num_classes``."""
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    """Conv2d followed by an in-place ReLU.

    Extra keyword arguments (kernel_size, stride, padding, ...) are passed
    straight through to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution, then ReLU."""
        x = self.conv(x)
        x = self.relu(x)
        return x

train.py

训练模式下模型有三个返回结果:一个主分类器的预测输出,以及两个辅助分类器的输出(验证/预测模式下只返回主分类器的输出)

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import GoogLeNet


def main():
    """Train GoogLeNet on the flower dataset and keep the best checkpoint.

    Reads images from ``<cwd>/../../data_set/flower_data/{train,val}``,
    writes the index-to-class mapping to ``class_indices.json`` and saves the
    weights with the best validation accuracy to ``./googleNet.pth``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w', encoding='utf-8') as json_file:
        json_file.write(json_str)

    batch_size = 32
    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so min() does not compare None with int.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    # To use the official pretrained weights, load them into the official
    # torchvision model, not into this custom implementation: the official
    # model uses BN layers and changed some parameters, so the two cannot be
    # mixed.
    # import torchvision
    # net = torchvision.models.googlenet(num_classes=5)
    # model_dict = net.state_dict()
    # # pretrained weights: https://download.pytorch.org/models/googlenet-1378be20.pth
    # pretrain_model = torch.load("googlenet.pth")
    # del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
    #             "aux2.fc2.weight", "aux2.fc2.bias",
    #             "fc.weight", "fc.bias"]
    # pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
    # model_dict.update(pretrain_dict)
    # net.load_state_dict(model_dict)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    epochs = 30
    best_acc = 0.0
    save_path = './googleNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            # Move the labels to the device once instead of once per loss term.
            labels = labels.to(device)
            optimizer.zero_grad()
            logits, aux_logits2, aux_logits1 = net(images.to(device))
            loss0 = loss_function(logits, labels)
            loss1 = loss_function(aux_logits1, labels)
            loss2 = loss_function(aux_logits2, labels)
            # Auxiliary losses are down-weighted by 0.3.
            loss = loss0 + loss1 * 0.3 + loss2 * 0.3
            loss.backward()
            optimizer.step()

            # print statistics
            loss_value = loss.item()
            running_loss += loss_value

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss_value)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))  # eval model only have last output layer
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Only keep the checkpoint with the best validation accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

注意训练代码里的这段注释代码:

    import torch
    import torchvision

    # If you use the official torchvision.models.googlenet model, the
    # downloaded official weights were trained on the 1000-class ImageNet
    # task, so the weights of some layers have to be removed before training.
    # The model being trained is the official model itself, and the result is
    # a set of weights adapted to your own training set. In short: official
    # model + official weights + your own dataset -> slightly modify the
    # official weights, then use them to train the official model on your own
    # dataset.
    # How to freeze / unfreeze weights is not covered here (when this model
    # was studied, the earlier VGG network was still being trained); it is
    # left for the ResNet article.
    # If you want to use the official weights to train your own model
    # implementation instead, the official weights would likewise have to be
    # modified first. Details follow in the ResNet article.
    net = torchvision.models.googlenet(num_classes=5)
    model_dict = net.state_dict()
    # pretrained weights: https://download.pytorch.org/models/googlenet-1378be20.pth
    pretrain_model = torch.load("googlenet.pth")
    # Classifier layers whose shapes depend on the number of classes; their
    # ImageNet weights are dropped so the freshly initialized ones are kept.
    del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
                "aux2.fc2.weight", "aux2.fc2.bias",
                "fc.weight", "fc.bias"]
    pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
    model_dict.update(pretrain_dict)
    net.load_state_dict(model_dict)

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import GoogLeNet


def main():
    """Classify a single image with trained GoogLeNet weights and plot it.

    Loads ``../tulip.jpg``, the class mapping ``./class_indices.json`` and the
    weights ``./googleNet.pth``, prints the probability of every class and
    shows the image titled with the top prediction.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    # Force 3 channels: Normalize above is configured for RGB, so grayscale
    # or RGBA files would otherwise fail.
    img = Image.open(img_path).convert("RGB")
    plt.imshow(img)
    # [C, H, W]
    img = data_transform(img)
    # Expand the batch dimension so the single image matches the model input:
    # [N, C, H, W]. Batch prediction would need a different design.
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model (auxiliary classifiers are not needed for inference)
    model = GoogLeNet(num_classes=5, aux_logits=False).to(device)

    # load model weights
    weights_path = "./googleNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    # strict=False because the checkpoint was saved with aux_logits=True and
    # therefore contains aux1/aux2 keys this model does not have.
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device),
                                                          strict=False)
    if missing_keys:
        # Extra (aux) keys are expected; missing ones mean part of the model
        # keeps its random initialization.
        print("warning: weights missing for keys: {}".format(missing_keys))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).item()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].item())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].item()))
    plt.show()


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/227495.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NXP应用随记(五):eMios功能点阅读随记

目录 1、概念点 2、eMios功能点 2.1、eMIOS - Single Action Input Capture (SAIC) 2.2、eMIOS - Single Action Output Compare (SAOC) 2.3、eMIOS - Double Action Output Compare (DAOC) 2.4、eMIOS - Pulse/Edge Counting (PEC) – Single Shot 2.5、eMIOS - Pulse/E…

rpc和http的区别,使用场景

rpc和http的区别,使用场景 区别如下:传输协议传输效率性能消耗负载均衡性能表现使用场景: 区别如下: 传输协议 RPC:可以基于TCP协议,也可以基于HTTP协议HTTP:基于HTTP协议 传输效率 RPC&…

贪吃蛇小游戏

目录 头文件代码 函数实现代码 测试时代码 本游戏的实现需要用到链表,结构体,win32API,枚举等相关知识。 头文件代码 #pragma once#include<locale.h> #include<stdlib.h> #include<Windows.h> #include<stdbool.h&…

K8S(五)—命名空间与资源配额

目录 命名空间(Namespace)命令计算资源配额创建命名空间绑定一个ResourceQuota资源将命名空间和资源限制对象进行绑定尝试创建第二个 Pod查看ResourceQuota 绑定第二个ResourceQuota为命名空间配置默认的 CPU 、memory请求和限制(1)Pod 中所有容器都没有…

[Verilog] 设计方法和设计流程

主页: 元存储博客 文章目录 1. 设计方法2. 设计流程 3 Vivado软件设计流程总结 1. 设计方法 Verilog 的设计多采用自上而下的设计方法(top-down)。设计流程是指从一个项目开始从项目需求分析,架构设计,功能验证&#…

智能客服的应用——政务领域

#本文来源清华大学数据治理研究中心政务热线数智化发展报告 ,如有侵权,请联系删除。 面对地方政务热线发展所面临的挑战,数智化转型已经成为了热线系统突破当前发展瓶颈、实现整体提质增效的关键手段。《意见》中也明确指出,政务…

ChatGPT4 Excel 高级复杂函数案例实践

案例需求: 需求中需要判断多个条件进行操作。 可以让ChatGPT来实现这样的操作。 Prompt:有一个表格B2单元格为入职日期,C2单元格为员工等级(A,B,C),D2单元格为满意度分数(1,2,3,4,5)请给入职一年以上,员工等级为A级并且满意度在3分以上的人发4000元奖金,给入…

SoloLinker第一次使用记录,解决新手拿到板子的无所适从

本文目录 一、简介二、进群获取资料2.1 需要下载资料2.2 SDK 包解压 三、SDK 编译3.1 依赖安装3.2 编译配置3.3 启动编译3.4 编译后的固件目录 四、固件烧录4.1 RV1106 驱动安装4.2 打开烧录工具4.3 进入boot 模式(烧录模式)4.4 烧录启动固件4.5 烧录升级…

AntDesignBlazor示例——分页查询

本示例是AntDesign Blazor的入门示例,在学习的同时分享出来,以供新手参考。 示例代码仓库:https://gitee.com/known/BlazorDemo 1. 学习目标 分页查询框架天气数据分页功能表格自定义分页 2. 创建分页查询框架 Table组件分页默认为前端分…

1.electron之纯原生js/jquery的桌面应用程序(基础篇)

如果可以实现记得点赞分享,谢谢老铁~ Electron是一个使用 JavaScript、HTML 和 CSS 构建桌面应用程序的框架。 Electron 将 Chromium 和 Node.js 嵌入到了一个二进制文件中,因此它允许你仅需一个代码仓库,就可以撰写支持 Windows、…

Mybatis-Plus——01搭建环境、快速入门(新注解、依赖)

搭建环境、快速入门 一、准备数据库二、创建项目三、导入依赖四、配置连接数据库五、编写实体类六、编写mapper接口七、主程序加MapperScan八、测试,输出查询结果————————创作不易,如觉不错,随手点赞,关注,收藏…

《科技风》期刊发表投稿方式、收稿方向

《科技风》杂志是经国家新闻出版总署批准,河北省科学技术协会主管,河北省科技咨询服务中心主办的国内公开发行的大型综合类科技期刊。 该刊集科技性、前瞻性、创新性和专业性于一体,始终以“把脉科技创新 引领发展风尚”为办刊宗旨&#xff…

spark-常用算子

一,Transformation变换/转换算子: 这种变换并不触发提交作业,这种算子是延迟执行的,也就是说从一个RDD转换生成另一个RDD的转换操作不是马上执行,需要等到有Action操作的时候才会真正触发。 1.Value数据类型的Transf…

vue中qrcanvas生成二维码并且下载二维码

vue中qrcanvas生成带logo二维码并且下载二维码 1.引入qrcanvas模块 cnpm install --save qrcanvas //parkage.json 中引入 "qrcanvas": "^3.1.2" import { qrcanvas } from qrcanvas2.前端vue页面展示 <el-buttonsize"mini"type"tex…

FFmpeg转码流程和常见概念

视频格式:mkv,flv,mov,wmv,avi,mp4,m3u8,ts等等 FFmpeg的转码工具,它的处理流程是这样的: 从输入源获得原始的音视频数据,解封装得到压缩封装的音…

企业微信机器人发送文本、图片、文件、markdown、图文信息

import requests import base64 import hashlib import json # 机器人地址的key值 key"811a1652-60e8-4f51-a1d9-231783399ad2" def path2base64(path):"""文件转换为base64:param path: 文件路径:return:"""with open(path, "rb…

设计模式-模板模式

设计模式专栏 模式介绍模式特点应用场景模板模式和工厂模式区别代码示例Java实现模板模式python实现模板模式 模板模式在spring中的应用 模式介绍 模板模式是一种行为型设计模式,它通过将算法的骨架抽象成一个模板方法,将具体的操作留给子类来实现。这种…

iPhone 与三星手机:哪一款最好?

三星比苹果好吗?还是苹果比三星更好? 小米公司如何称霸全球智能手机市场?小米公司,由雷军创立于2010年,是一家领先的电子巨头。以其MIUI系统和互联网服务闻名,小米公司在全球智能手机市场中稳居前列。小米…

网络(七)路由协议以及相关配置

目录 一、路由器的工作原理 二、路由表的形成 2.1 直连网段 2.2 非直连网 2.3 路由表解析 2.3.1 查看路由表 2.3.2 解析 三、静态路由和默认路由 1. 静态路由 1.1 定义 1.2 特点 2. 默认路由 2.1 定义 2.2 特点 四、静态路由和默认路由的配置 1. 静态路由配置…

flutter学习-day12-可滚动组件和监听

📚 目录 简介可滚动组件 SingleChildScrollViewListView separated分割线无限加载列表带标题列表 滚动监听和控制 ScrollController滚动监听NotificationListener滚动监听 AnimatedList动画列表滚动网格布局GridView 横轴子元素为固定数量横轴子元素为固定最大长度…