GoogleNet网络训练集和测试集搭建

测试集和训练集都是在之前搭建好的基础上进行修改的,重点记录与之前不同的代码。

还是使用的花分类的数据集进行训练和测试的。

一、训练集

1、搭建网络

设置参数:使用辅助分类器,采用权重初始化

net = GoogleNet(num_classes=5, aux_logits=True, init_weights=True)

2、参数输出

之前的模型只有 1 个输出,但由于GoogleNet使用了两个辅助分类器,所以会有 3 个输出。

定义三个输出,分别计算主分类器、辅助分类器1、辅助分类器2的损失函数并相加,最后将损失函数反向传播,使用优化器更新参数模型。 

不单独放代码了,不知道哪里是改动的。图片中红色框中是改动的

整个训练集的代码

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib as plt
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import GoogleNet
import os
import json
import timedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train",transform=data_transform["train"])
train_num = len(train_dataset)# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入接送文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:json_file.write(json_str)batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,shuffle=False, num_workers=0)# test_data_iter = iter(validate_loader)
# test_image, test_label = next(test_data_iter)
#
# # 查看图片
# def imshow(img):
#     img = img / 2 + 0.5
#     nping = img.numpy()
#     plt.imshow(np.transpose(nping, (1, 2, 0)))
#     plt.show()
# # print labels
# print(' '.join('%5s' % str(cla_dict[test_label[j].item()]) for j in range(4)))
# # show images
# imshow(utils.make_grid(test_image))net = GoogleNet(num_classes=5, aux_logits=True, init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0003)best_acc = 0.0
save_path = './GoogleNet.pth'
# best_acc = 0.0
for epoch in range(2):# trainnet.train()running_loss = 0.0t1 = time.perf_counter()for step, data in enumerate(train_loader, start=0):images, labels = dataoptimizer.zero_grad()logits, aux_logits2, aux_logits1 = net(images.to(device))loss0 = loss_function(logits, labels.to(device))loss1 = loss_function(aux_logits1, labels.to(device))loss2 = loss_function(aux_logits2, labels.to(device))loss = loss0 + loss1 * 0.3 + loss2 * 0.3loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()rate = (step+1) / len(train_loader)a = "*" * int(rate*50)b = "." *int((1-rate)*50)print("\rtrain loss: (:3.0f)%[()->:.3f)".format(int(rate * 100), a, b, loss), end="")print()print(time.perf_counter()-t1)net.eval()acc = 0.0with torch.no_grad():for data_test in validate_loader:test_images, test_labels = data_testoutputs = net(test_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == test_labels.to(device)).sum().item()accurate_test = acc / val_numif accurate_test > best_acc:best_acc = accurate_testtorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, running_loss / step, acc / val_num))
print("Finished Training")

训练完成 

 中间有几次报错,不过在看懂报错后很快改过来了。

二、测试集

载入模型

在创建模型的时候,aux_logits不会构建辅助分类器,但是之前训练的参数会保存。

所以,在载入模型的时候,要设置参数strict=False, 它可以精准匹配当前模型与所需要载入的权重模型的结构。

辅助分类器中的参数全部存放在unexpecte_keys中。

测试集全部代码

 可以自己找图片进行预测看准确率。

import torch
import matplotlib.pyplot as plt
import json
from model import GoogleNet
from PIL import Image
from torchvision import transformsdata_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load image
img = Image.open("8.jpeg")
plt.imshow(img)
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)# read class_indent
try:json_file = open('./class_indices,json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
model = GoogleNet(num_classes=5, aux_logits=False)
model_weight_path = "./GoogleNet.pth"
missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path), strict=False)
model.eval()
with torch.no_grad():output = torch.squeeze(model(img))predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

准确率好低,可能是模型训练的还不够吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/823779.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Midjourney常见玩法及prompt关键词技巧

今天系统给大家讲讲Midjourney的常见玩法和prompt关键词的一些注意事项,带大家入门~(多图预警,建议收藏~) 一、入门及常见玩法 1、注册并添加服务器(会的童鞋可跳过~) …

实在智能携AI Agent智能体亮相2024年度QCon全球软件开发大会

4月11日-13日,以“全面进化”作为2024年度主题的「QCon全球软件开发大会暨智能软件开发生态展」在北京隆重举行。作为AI准独角兽和超自动化头部企业,实在智能应邀出席发表《面向办公自动化领域的AI Agent建设思考与分享》演讲及圆桌交流,展示…

反转二叉树(力扣226)

解题思路:用队列进行前序遍历的同时把节点的左节点和右节点交换 具体代码如下: class Solution { public:TreeNode* invertTree(TreeNode* root) {if (root NULL) return root;swap(root->left, root->right); // 中invertTree(root->left)…

OpenHarmony南向开发案例【智慧中控面板(基于 Bearpi-Micro)】

1 开发环境搭建 【从0开始搭建开发环境】【快速搭建开发环境】 参考鸿蒙开发指导文档:gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或复制转到。 【注意】:快速上手教程第六步出拉取代码时需要修改代码仓库地址 在MobaXterm中输入…

【模拟】Leetcode 提莫攻击

题目讲解 495. 提莫攻击 算法讲解 前后的两个数字之间的关系要么是相减之差 > 中毒时间 &#xff0c;要么反之 那即可通过示例&#xff0c;进行算法的模拟&#xff0c;得出上图的计算公式 class Solution { public:int findPoisonedDuration(vector<int>& time…

redis热key问题如何解决?

今天跟大家分享一个干货——redis热key问题 那么什么是redis热key呢&#xff1f; 在Redis中&#xff0c;热键问题是指那些经常被访问的键&#xff0c;它们会消耗大量的Redis服务器资源&#xff0c;当大量的请求集中在一个key上&#xff0c;会导致这个节点的cpu&#xff0c;内…

至少需要[XXXXMB]内存才能安装(宝塔导入数据库提示)

①我的2g内存腾讯云服务器想安装mysql8.0 ②宝塔提示“至少需要[3700MB]内存才能安装” 将数据库部署到宝塔上的时候提示-----》至少需要[XXXXMB]内存才能安装&#xff0c;解决的方法其实也很简单。 首先&#xff0c;进入文件夹/www/server/panel/class&#xff0c;找到找到…

使用Python插入100万条数据到MySQL数据库并将数据逐步写出到多个Excel

Python插入100万条数据到MySQL数据库 步骤一&#xff1a;导入所需模块和库 首先&#xff0c;我们需要导入 MySQL 连接器模块和 Faker 模块。MySQL 连接器模块用于连接到 MySQL 数据库&#xff0c;而 Faker 模块用于生成虚假数据。 import mysql.connector # 导入 MySQL 连接…

PostgreSQL的学习心得和知识总结(一百三十八)|深入理解PostgreSQL数据库之Protocol message构造和解析逻辑

目录结构 注&#xff1a;提前言明 本文借鉴了以下博主、书籍或网站的内容&#xff0c;其列表如下&#xff1a; 1、参考书籍&#xff1a;《PostgreSQL数据库内核分析》 2、参考书籍&#xff1a;《数据库事务处理的艺术&#xff1a;事务管理与并发控制》 3、PostgreSQL数据库仓库…

【Java开发指南 | 第十一篇】Java运算符

读者可订阅专栏&#xff1a;Java开发指南 |【CSDN秋说】 文章目录 算术运算符关系运算符位运算符逻辑运算符赋值运算符条件运算符&#xff08;?:&#xff09;instanceof 运算符Java运算符优先级 Java运算符包括&#xff1a;算术运算符、关系运算符、位运算符、逻辑运算符、赋值…

关于ContentProvider这一遍就够了

ContentProvider是什么&#xff1f; ContentProvider是Android四大组件之一&#xff0c;主要用于不同应用程序之间或者同一个应用程序的不同部分之间共享数据。它是Android系统中用于存储和检索数据的抽象层&#xff0c;允许不同的应用程序通过统一的接口访问数据&#xff0c;…

java实现请求缓冲合并

业务背景&#xff1a; 我们对外提供了一个rest接口给第三方业务进行调用&#xff0c;但是由于第三方框架限制&#xff0c;导致会发送大量相似无效请求&#xff0c;例如&#xff1a;接口入参json包含两个字段&#xff0c;createBy和receiverList&#xff0c;完整的入参json示例…

Postman之版本信息查看

Postman之版本信息查看 一、为何需要查看版本信息&#xff1f;二、查看Postman的版本信息的步骤 一、为何需要查看版本信息&#xff1f; 不同的版本之间可能存在功能和界面的差异。 二、查看Postman的版本信息的步骤 1、打开 Postman 2、打开设置项 点击页面右上角的 “Set…

Java中的容器

Java中的容器主要包括以下几类&#xff1a; Collection接口及其子接口/实现类&#xff1a; List 接口及其实现类&#xff1a; ArrayList&#xff1a;基于动态数组实现的列表&#xff0c;支持随机访问&#xff0c;插入和删除元素可能导致大量元素移动。LinkedList&#xff1a;基…

只用键盘的技巧

技巧一&#xff1a;将常用软件固定在任务栏使用winnum/winT(shift)打开 技巧二&#xff1a;winX快捷键&#xff08;显示快捷键的快捷键&#xff09; ALT F4    关闭当前应用程序 技巧三&#xff1a;使用好Chrome快键键 ctrl h&#xff1b;历史纪录。 ctrl shift esc&am…

充电桩--OCPP 充电通讯协议介绍

一、OCPP协议介绍 OCPP的全称是 Open Charge Point Protocol 即开放充电点协议&#xff0c; 它是免费开放的协议&#xff0c;该协议由位于荷兰的组织 OCA&#xff08;开放充电联盟&#xff09;进行制定。Open Charge Point Protocol (OCPP) 开放充电点协议用于充电站(CS)和任何…

大模型开发轻松入门——(1)从搭建自己的环境开始

pip install openai import openai import osfrom dotenv import load_dotenv, find_dotenv _ load_dotenv(find_dotenv())openai.api_key os.getenv(OPENAI_API_KEY)

mac IDEA激活 亲测有效

1、官网下载mac版本IDEA并安装 2、打开激活页面 3、下载脚本文件 链接: https://pan.baidu.com/s/1I2BqdfxSJv1A96422rflnA?pwdm494 提取码: m494 4、命令行到该界面&#xff0c;执行 sudo bash idea.sh 可能出现的问题&#xff1a; 查看sh文件&#xff0c;targetFilePath…

vue快速入门(十六)事件修饰符

注释很详细&#xff0c;直接上代码 上一篇 新增内容 事件修饰符之阻止冒泡事件修饰符之阻止默认行为 源码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdev…

重定向原理和缓冲区

文章目录 重定向缓冲区 正文开始前给大家推荐个网站&#xff0c;前些天发现了一个巨牛的 人工智能学习网站&#xff0c; 通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。 点击跳转到网站。 重定向 内核中为了管理被打开的文件&#xff0c;一定会存在描述一…