GoogLeNet(pytorch)

亮点与创新:

1. 引入Inception基础结构

2. 引入PW维度变换卷积,启迪后续参数量的优化

3. 丢弃全连接层,使用平均池化层(大大减少模型参数)

4. 添加两个辅助分类器帮助训练(避免梯度消失,用于向前传导梯度,也有一定的正则化效果,防止过拟合)

 model.py

import torch.nn as nn
import torch
import torch.nn.functional as Fclass GoogLeNet(nn.Module):def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):super(GoogLeNet, self).__init__()self.aux_logits = aux_logitsself.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)self.conv2 = BasicConv2d(64, 64, kernel_size=1)self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)if self.aux_logits:self.aux1 = InceptionAux(512, num_classes)self.aux2 = InceptionAux(528, num_classes)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.dropout = nn.Dropout(0.4)self.fc = nn.Linear(1024, num_classes)if init_weights:self._initialize_weights()def forward(self, x):# N x 3 x 224 x 224x = self.conv1(x)# N x 64 x 112 x 112x = self.maxpool1(x)# N x 64 x 56 x 56x = self.conv2(x)# N x 64 x 56 x 56x = self.conv3(x)# N x 192 x 56 x 56x = self.maxpool2(x)# N x 192 x 28 x 28x = self.inception3a(x)# N x 256 x 28 x 28x = self.inception3b(x)# N x 480 x 28 x 28x = self.maxpool3(x)# N x 480 x 14 x 14x = self.inception4a(x)# N x 512 x 14 x 14if self.training and self.aux_logits:    # eval model lose this layeraux1 = self.aux1(x)x = self.inception4b(x)# N x 512 x 14 x 14x = self.inception4c(x)# N x 512 x 14 x 14x = self.inception4d(x)# N x 528 x 14 x 14if self.training and self.aux_logits:    # eval model lose this layeraux2 = self.aux2(x)x = self.inception4e(x)# N x 832 x 14 x 14x = self.maxpool4(x)# N x 832 x 7 x 7x = self.inception5a(x)# N x 832 x 7 x 7x = self.inception5b(x)# N x 1024 x 7 x 7x = self.avgpool(x)# N x 1024 x 1 x 1x = torch.flatten(x, 1)# N x 1024x = self.dropout(x)x = self.fc(x)# N x 1000 (num_classes)if self.training and self.aux_logits:   # eval model lose this layerreturn x, aux2, aux1return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)class Inception(nn.Module):def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):super(Inception, self).__init__()self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)self.branch2 = nn.Sequential(BasicConv2d(in_channels, ch3x3red, kernel_size=1),BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)   # 保证输出大小等于输入大小)self.branch3 = nn.Sequential(BasicConv2d(in_channels, ch5x5red, kernel_size=1),# 在官方的实现中,其实是3x3的kernel并不是5x5,这里我也懒得改了,具体可以参考下面的issue# Please see https://github.com/pytorch/vision/issues/906 for details.BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)   # 保证输出大小等于输入大小)self.branch4 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=1, padding=1),BasicConv2d(in_channels, pool_proj, kernel_size=1))def forward(self, x):branch1 = self.branch1(x)branch2 = self.branch2(x)branch3 = self.branch3(x)branch4 = self.branch4(x)outputs = [branch1, branch2, branch3, branch4]return torch.cat(outputs, 1)class InceptionAux(nn.Module):def __init__(self, in_channels, num_classes):super(InceptionAux, self).__init__()self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]self.fc1 = nn.Linear(2048, 1024)self.fc2 = nn.Linear(1024, num_classes)def forward(self, x):# aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14x = self.averagePool(x)# aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4x = self.conv(x)# N x 128 x 4 x 4x = torch.flatten(x, 1)x = F.dropout(x, 0.5, training=self.training)# N x 2048x = F.relu(self.fc1(x), inplace=True)x = F.dropout(x, 0.5, training=self.training)# N x 1024x = self.fc2(x)# N x num_classesreturn xclass BasicConv2d(nn.Module):def __init__(self, in_channels, out_channels, **kwargs):super(BasicConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)x = self.relu(x)return x

train.py

模型有三个返回结果,一个预测,两个辅助分类器

import os
import sys
import jsonimport torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdmfrom model import GoogLeNetdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))# test_data_iter = iter(validate_loader)# test_image, test_label = test_data_iter.next()net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)# 如果要使用官方的预训练权重,注意是将权重载入官方的模型,不是我们自己实现的模型# 官方的模型中使用了bn层以及改了一些参数,不能混用# import torchvision# net = torchvision.models.googlenet(num_classes=5)# model_dict = net.state_dict()# # 预训练权重下载地址: https://download.pytorch.org/models/googlenet-1378be20.pth# pretrain_model = torch.load("googlenet.pth")# del_list = ["aux1.fc2.weight", "aux1.fc2.bias",#             "aux2.fc2.weight", "aux2.fc2.bias",#             "fc.weight", "fc.bias"]# pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}# model_dict.update(pretrain_dict)# net.load_state_dict(model_dict)net.to(device)loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0003)epochs = 30best_acc = 0.0save_path = './googleNet.pth'train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()logits, aux_logits2, aux_logits1 = net(images.to(device))loss0 = loss_function(logits, labels.to(device))loss1 = loss_function(aux_logits1, labels.to(device))loss2 = loss_function(aux_logits2, labels.to(device))loss = loss0 + loss1 * 0.3 + loss2 * 0.3loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))  # eval model only have last output layerpredict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()

注意训练代码里的这段注释代码:

    import torchvisionnet = torchvision.models.googlenet(num_classes=5)#如果用的是官方提供的torchvision.models.googlenet模型,下载下来官方训练的权重是在imagenet1000分类上训练到的权重,需要删除一些层的权重来训练,而且训练的模型直接就是官方的模型,得到自己适应训练集的权重,简单总结,官方模型,官方权重,自己数据集,稍微修改官方权重,使用稍微修改的官方权重训练官方模型下自己数据集的权重
#对于怎么冻结权重,解冻权重,由于这个模型学习时还在训练前面的VGG网络,到resnet进行学习#若想用官方权重训练自己的模型得到适应数据集的权重,也应该是修改一下官方权重
#之后resnet细说model_dict = net.state_dict()# 预训练权重下载地址: https://download.pytorch.org/models/googlenet-1378be20.pthpretrain_model = torch.load("googlenet.pth")del_list = ["aux1.fc2.weight", "aux1.fc2.bias","aux2.fc2.weight", "aux2.fc2.bias","fc.weight", "fc.bias"]pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}model_dict.update(pretrain_dict)net.load_state_dict(model_dict)

predict.py

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model import GoogLeNetdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load imageimg_path = "../tulip.jpg"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimension#对单张图片添加维度,以适应模型#如果是批次预测,之后注意批次预测的代码怎么设计img = torch.unsqueeze(img, dim=0)# read class_indictjson_path = './class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)with open(json_path, "r") as f:class_indict = json.load(f)# create modelmodel = GoogLeNet(num_classes=5, aux_logits=False).to(device)# load model weightsweights_path = "./googleNet.pth"assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device),strict=False)model.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))plt.show()if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/227495.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NXP应用随记(五):eMios功能点阅读随记

目录 1、概念点 2、eMios功能点 2.1、eMIOS - Single Action Input Capture (SAIC) 2.2、eMIOS - Single Action Output Compare (SAOC) 2.3、eMIOS - Double Action Output Compare (DAOC) 2.4、eMIOS - Pulse/Edge Counting (PEC) – Single Shot 2.5、eMIOS - Pulse/E…

K8S(五)—命名空间与资源配额

目录 命名空间(Namespace)命令计算资源配额创建命名空间绑定一个ResourceQuota资源将命名空间和资源限制对象进行绑定尝试创建第二个 Pod查看ResourceQuota 绑定第二个ResourceQuota为命名空间配置默认的 CPU 、memory请求和限制(1)Pod 中所有容器都没有…

[Verilog] 设计方法和设计流程

主页: 元存储博客 文章目录 1. 设计方法2. 设计流程 3 Vivado软件设计流程总结 1. 设计方法 Verilog 的设计多采用自上而下的设计方法(top-down)。设计流程是指从一个项目开始从项目需求分析,架构设计,功能验证&#…

智能客服的应用——政务领域

#本文来源清华大学数据治理研究中心政务热线数智化发展报告 ,如有侵权,请联系删除。 面对地方政务热线发展所面临的挑战,数智化转型已经成为了热线系统突破当前发展瓶颈、实现整体提质增效的关键手段。《意见》中也明确指出,政务…

ChatGPT4 Excel 高级复杂函数案例实践

案例需求: 需求中需要判断多个条件进行操作。 可以让ChatGPT来实现这样的操作。 Prompt:有一个表格B2单元格为入职日期,C2单元格为员工等级(A,B,C),D2单元格为满意度分数(1,2,3,4,5)请给入职一年以上,员工等级为A级并且满意度在3分以上的人发4000元奖金,给入…

SoloLinker第一次使用记录,解决新手拿到板子的无所适从

本文目录 一、简介二、进群获取资料2.1 需要下载资料2.2 SDK 包解压 三、SDK 编译3.1 依赖安装3.2 编译配置3.3 启动编译3.4 编译后的固件目录 四、固件烧录4.1 RV1106 驱动安装4.2 打开烧录工具4.3 进入boot 模式(烧录模式)4.4 烧录启动固件4.5 烧录升级…

AntDesignBlazor示例——分页查询

本示例是AntDesign Blazor的入门示例,在学习的同时分享出来,以供新手参考。 示例代码仓库:https://gitee.com/known/BlazorDemo 1. 学习目标 分页查询框架天气数据分页功能表格自定义分页 2. 创建分页查询框架 Table组件分页默认为前端分…

1.electron之纯原生js/jquery的桌面应用程序(基础篇)

如果可以实现记得点赞分享,谢谢老铁~ Electron是一个使用 JavaScript、HTML 和 CSS 构建桌面应用程序的框架。 Electron 将 Chromium 和 Node.js 嵌入到了一个二进制文件中,因此它允许你仅需一个代码仓库,就可以撰写支持 Windows、…

Mybatis-Plus——01搭建环境、快速入门(新注解、依赖)

搭建环境、快速入门 一、准备数据库二、创建项目三、导入依赖四、配置连接数据库五、编写实体类六、编写mapper接口七、主程序加MapperScan八、测试,输出查询结果————————创作不易,如觉不错,随手点赞,关注,收藏…

《科技风》期刊发表投稿方式、收稿方向

《科技风》杂志是经国家新闻出版总署批准,河北省科学技术协会主管,河北省科技咨询服务中心主办的国内公开发行的大型综合类科技期刊。 该刊集科技性、前瞻性、创新性和专业性于一体,始终以“把脉科技创新 引领发展风尚”为办刊宗旨&#xff…

设计模式-模板模式

设计模式专栏 模式介绍模式特点应用场景模板模式和工厂模式区别代码示例Java实现模板模式python实现模板模式 模板模式在spring中的应用 模式介绍 模板模式是一种行为型设计模式,它通过将算法的骨架抽象成一个模板方法,将具体的操作留给子类来实现。这种…

iPhone 与三星手机:哪一款最好?

三星比苹果好吗?还是苹果比三星更好? 小米公司如何称霸全球智能手机市场?小米公司,由雷军创立于2010年,是一家领先的电子巨头。以其MIUI系统和互联网服务闻名,小米公司在全球智能手机市场中稳居前列。小米…

网络(七)路由协议以及相关配置

目录 一、路由器的工作原理 二、路由表的形成 2.1 直连网段 2.2 非直连网 2.3 路由表解析 2.3.1 查看路由表 2.3.2 解析 三、静态路由和默认路由 1. 静态路由 1.1 定义 1.2 特点 2. 默认路由 2.1 定义 2.2 特点 四、静态路由和默认路由的配置 1. 静态路由配置…

[wp]第四届江西省赣网杯网络安全大赛-web 部分wp

第四届江西省赣网杯网络安全大赛(gwb)线上预选赛 因为学业繁忙 只玩了1小时,后续看看补一下这些 2023gwb-web1 九宫格拼图 2023gwb-web2 $filexxx;extract($_GET);if(isset($fun)){$contenttrim(file_get_contents($file));if($fun!&…

MLX:苹果 专为统一内存架构(UMA) 设计的机器学习框架

“晨兴理荒秽,带月荷锄归” 夜深闻讯,有点兴奋~ 苹果为 UMA 设计的深度学习框架真的来了 统一内存架构 得益于 CPU 与 GPU 内存的共享,同时与 MacOS 和 M 芯片 交相辉映,在效率上,实现对其他框架的降维打…

uniapp之屏幕右侧出现滚动条去掉、隐藏、删除【好用!】

目录 问题解决大佬地址最后 问题 解决 在最外层view上加上class“content”;输入以下样式。注意:两个都必须存在在生效。 .content {/* 跟屏幕高度一样高,不管view中有没有内容,都撑开屏幕高的高度 */height: 100vh; overflow: auto; } .content::-webkit-scrollb…

【JavaWeb】往浏览器打印一个hello world

上集:建一个web项目 第一步:建好Servlet类的文件 右键src,建一个class 就行 第二步:编代码 可以直接复制粘贴 用来测试的类 import javax.servlet.annotation.WebServlet; import javax.servlet.http.HttpServlet; //↓是注解&#xff0…

Java 基础学习(十一)File类与I/O操作

1 File类 1.1 File类概述 1.1.1 什么是File类 File是java.io包下作为文件和目录的类。File类定义了一些与平台无关的方法来操作文件,通过调用File类中的方法可以得到文件和目录的描述信息,包括名称、所在路径、读写性和长度等,还可以对文件…

Graphics Profiler 使用教程

GraphicsProfiler 使用教程 1.工具简介:2.Navigation介绍2.1.打开安装好的Graphics Profiler。2.2.将手机连接到计算机,软件会在手机中安装一个GraphicsProfiler应用(该应用是无界面的)。2.3.Show files list2.4.Record new trace2.4.1.Appli…

听GPT 讲Rust源代码--src/tools(13)

File: rust/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/incoherent_impl.rs 在Rust源代码中,路径为rust/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/incoherent_impl.rs的文件是为了处理Rust代码中的不一致实现问题而存在的。…