AlexNet——训练花数据集

目录

一、网络结构

二、创新点分析

三、知识点

1. nn.ReLU(inplace) 

2. os.getcwd与os.path.abspath 

3. 使用torchvision下的datasets包 

4. items()与dict()用法 

5. json文件  

6. tqdm

7. net.train()与net.val()

四、代码


AlexNet是由Alex Krizhevsky、Ilya Sutskever和Geoffrey Hinton在2012年ImageNet图像分类竞赛中提出的一种经典的卷积神经网络。AlexNet使用了Dropout层,减少过拟合现象的发生。

一、网络结构

二、数据集 

文件存放:

dataset->flower_data->flower_photos

再使用split_data.py 将数据集根据比例划分成训练集和预测集

详细请查看b站up主霹雳吧啦Wz:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/pytorch_classification

三、创新点分析

1. deeper网络结构

    通过增加网络深度,AlexNet可以更好的学习数据集的特征,并提高分类的准确率。

2. 使用ReLU激活函数,克服梯度消失以及求梯度复杂的问题。

3. 使用LRN局部响应归一化

    LRN是在卷积与池化层间添加归一化操作。卷积过程中,每个卷积核都对应一个feature map,LRN对这些feature map进行归一化操作。即,对每个特征图的每个位置,计算该位置周围的像素平方和,然后将当前位置像素值除以这个和。LRN可抑制邻近神经元的响应,在一定程度上能够避免过拟合,提高网络泛化能力。

4. 使用Dropout层

Dropout层:在训练过程中随机删除一定比例的神经元,以减少过拟合。Dropout一般放在全连接层与全连接层之间。

四、知识点

1. nn.ReLU(inplace) 默认参数为:inplace=False

inplace=False:不会修改输入对象的值,而是返回一个新创建的对象,即打印出的对象存储地址不同。(值传递)

inplace=True:会修改输入对象的值,即打印的对象存储地址相同,可以节省申请与释放内存的空间与时间。(地址传递)

import torch
import numpy as np
import torch.nn as nn# id()方法返回对象的内存地址
relu1 = nn.ReLU(inplace=False)
relu2 = nn.ReLU(inplace=True)
data = np.random.randn(2, 4)
input = torch.from_numpy(data)  # 转换成tensor类型
print("input address:", id(input))
output1 = relu1(input)
print("replace=False -- output address:", id(output1))
output2 = relu2(input)
print("replace=True -- output address:", id(output2))
# input address: 1669839583200
# replace=False -- output address: 1669817512352
# replace=True -- output address: 1669839583200

2. os.getcwd与os.path.abspath 

os.getcwd():获取当前工作目录

os.path.abspath('xxx.py'):获取文件当前的完整路径

import osprint(os.getcwd())  # D:\Code
print(os.path.abspath('test.py'))  # D:\Code\test.py

3. 使用torchvision下的datasets包 

train_dataset=datasets.ImageFolder(root=os.path.join(image_path,'train'),transform=data_transform['train'])

可以得出这些信息: 

4. items()与dict()用法 

items():把字典中的每对key和value组成一个元组,并将这些元组放在列表中返回。

obj = {'dog': 0,'cat': 1,'fish': 2
}
print(obj)  # {'dog': 0, 'cat': 1, 'fish': 2}
print(obj.items())  # dict_items([('dog', 0), ('cat', 1), ('fish', 2)])
print(dict((v, k) for k, v in obj.items()))  # {0: 'dog', 1: 'cat', 2: 'fish'}

5. json文件  

(1)json.dumps:将Python对象编码成JSON字符串

(2)json.loads:将已编码的JSON字符串编码为Python对象

import jsondata = [1, 2, 3]
data_json = json.dumps(data)  # <class 'str'>
data = json.loads(data_json)
print(type(data))  # <class 'list'>

6. tqdm

train_bar = tqdm(train_loader, file=sys.stdout)
使用tqdm函数,对train_loader进行迭代,将进度条输出到标准输出流sys.stdout中。可以方便用户查看训练进度。

from tqdm import tqdm
import timefor i in tqdm(range(10)):time.sleep(0.1)

7. net.train()与net.val()

net.train():启用BatchNormalization和Dropout

net.eval|():不启用BatchNormalization和Dropout

五、代码

model.py

import torch
import torch.nn as nnclass AlexNet(nn.Module):def __init__(self, num_classes=1000):super(AlexNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, padding=2, stride=4),  # input[3,224,224] output[96,55,55]nn.ReLU(inplace=True),  # inplace=True 址传递nn.MaxPool2d(kernel_size=3, stride=2),  # output[96,27,27]nn.Conv2d(96, 256, kernel_size=5, padding=2),  # output[256,27,27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),  # output[256,13,13]nn.Conv2d(256, 384, kernel_size=3, padding=1),  # output[384,13,13]nn.ReLU(inplace=True),nn.Conv2d(384, 384, kernel_size=3, padding=1),  # output[384,13,13]nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),  # output[256,13,13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),  # output[256,6,6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(256 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes))def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)  # batch这一维度不用,从channel开始x = self.classifier(x)return x

train.py 

import os
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
from model import AlexNet
import torch.optim as optim
from tqdm import tqdmdef main():device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')print("using:{}".format(device))data_transform = {'train': transforms.Compose([# 将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为指定大小transforms.RandomResizedCrop(224),# 水平方向随机翻转transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),'val': transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}# get data root pathdata_root = os.path.abspath(os.getcwd())  # D:\Code\AlexNet# get flower data set pathimage_path = os.path.join(data_root, 'data_set', 'flower_data')  # D:\Code\AlexNet\data_set\flower_data# 使用assert断言语句:出现错误条件时,就触发异常assert os.path.exists(image_path), '{} path does not exist!'.format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'), transform=data_transform['train'])val_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'val'), transform=data_transform['val'])train_num = len(train_dataset)val_num = len(val_dataset)# write class_dict into json fileflower_list = train_dataset.class_to_idxclass_dict = dict((v, k) for k, v in flower_list.items())json_str = json.dumps(class_dict)with open('class_indices.json', 'w') as file:file.write(json_str)batch_size = 32train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0)net = AlexNet(num_classes=5)net.to(device)loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0002)epochs = 5save_path = './model/AlexNet.pth'best_acc = 0.0train_steps = len(train_loader)  # train_num / batch_sizetrain_bar = tqdm(train_loader)val_bar = tqdm(val_loader)for epoch in range(epochs):# trainnet.train()epoch_loss = 0.0# 加入进度条train_bar = tqdm(train_loader)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()  # update x by optimizer# print statisticsepoch_loss += loss.item()train_bar.desc = 'train eporch[{}/{}] loss:{:.3f}'.format(epoch + 1, epochs, loss)# validatenet.eval()acc = 0.0with torch.no_grad():val_bar = tqdm(val_loader)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]  # [1]取每行最大值的索引acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_acc = acc / val_numprint('[epoch %d] train_loss: %.3f val_accuracy: %.3f' % (epoch + 1, epoch_loss / train_steps, val_acc))# find best accuracyif val_acc > best_acc:best_acc = val_acctorch.save(net.state_dict(), save_path)print('Train finished!')if __name__ == '__main__':main()

class_indices.json

{"0": "daisy", "1": "dandelion", "2": "roses", "3": "sunflowers", "4": "tulips"}

predict.py 

import os
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import json
from model import AlexNetdef main():device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load imageimg_path = './2.jpg'assert os.path.exists(img_path), "file:'{}' does not exist".format(img_path)img = Image.open(img_path)plt.imshow(img)# input [N,C,H,W]img = transform(img)img = torch.unsqueeze(img, dim=0)# read class_indicesjson_path = './class_indices.json'assert os.path.exists(json_path), "file:'{}' does not exist".format(json_path)with open(json_path, 'r') as file:class_dict = json.load(file)  # {'0': 'daisy', '1': 'dandelion', '2': 'roses', '3': 'sunflowers', '4': 'tulips'}# load modelnet = AlexNet(num_classes=5).to(device)# load model weightsweight_path = './model/AlexNet.pth'assert os.path.exists(weight_path), "file:'{}' does not exist".format(weight_path)net.load_state_dict(torch.load(weight_path))# predictnet.eval()with torch.no_grad():output = torch.squeeze(net(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_class = torch.argmax(predict).numpy()print_res = 'class:{} probability:{:.3}'.format(class_dict[str(predict_class)], predict[predict_class].numpy())plt.title(print_res)plt.show()for i in range(len(predict)):print('class:{:10} probability:{:.3}'.format(class_dict[str(i)], predict[i]))if __name__ == '__main__':main()

Result:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/83315.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2023最全Java面试题及答案汇总

前言 面试前还是很有必要针对性的刷一些题&#xff0c;很多朋友的实战能力很强&#xff0c;但是理论比较薄弱&#xff0c;面试前不做准备是很吃亏的。这里整理了很多面试常考的一些面试题&#xff0c;希望能帮助到你面试前的复习并且找到一个好的工作&#xff0c;也节省你在网…

NLP技术如何为搜索引擎赋能

目录 1. NLP关键词提取与匹配在搜索引擎中的应用1. 关键词提取例子 2. 关键词匹配例子 Python实现 2. NLP语义搜索在搜索引擎中的应用1. 语义搜索的定义例子 2. 语义搜索的重要性例子 Python/PyTorch实现 3. NLP个性化搜索建议在搜索引擎中的应用1. 个性化搜索建议的定义例子 2…

Java:JSR 310日期时间体系LocalDateTime、OffsetDateTime、ZonedDateTime

JSR 310日期时间体系&#xff1a; LocalDateTime&#xff1a;本地日期时间OffsetDateTime&#xff1a;带偏移量的日期时间ZonedDateTime&#xff1a;带时区的日期时间 目录 构造计算格式化参考文章 日期时间包 import java.time.LocalDateTime; import java.time.OffsetDateT…

Stable Diffusion WebUI 使用

想要正常运行 Stable Diffusion WebUI 需要机器上有 Nvidia 显卡才行, 简单体验可以 RTX 3070 起步, 正常玩需要 RTX 3080 起步, 要训练模型就要 RTX 3090 起步。 修改配置 通常 Stable Diffusion WebUI 的配置信息写在 stable-diffusion-webui/webui-user.sh 文件中: $ cd …

Eclipse如何打开debug变量窗口

今天笔者在使用Eclipse调试的时候&#xff0c;发现没有变量&#xff08;Variables&#xff09;监视窗口&#xff0c;真是头痛得很&#xff0c;最后摸索出一套显示变量窗口的操作如下&#xff1a; 点击other&#xff0c;找到Variables并点击 最后调试代码&#xff0c;调试后如图…

机器学习(17)---支持向量机(SVM)

支持向量机 一、概述1.1 介绍1.2 工作原理1.3 三层理解 二、sklearn.svm.SVC2.1 查看数据集2.2 contour函数2.3 画决策边界&#xff1a;制作网格2.4 建模画图 三、非线性情况推广3.1 查看数据集3.2 线性画图3.3 为非线性数据增加维度并绘制3D图像 四、核函数 一、概述 1.1 介绍…

免杀对抗-Python-混淆算法+反序列化-打包生成器-Pyinstall

Python-MSF/CS生成shellcode-上线 cs上线 1.生成shellcode-c或者python 2.打开pycharm工具&#xff0c;创建一个py文件&#xff0c;将原生态执行代码复制进去 shellcode执行代码&#xff1a; import ctypesfrom django.contrib.gis import ptr#cs#shellcodebytearray(b"生…

Java skill - 服务同时开始https和http端口

Java skill - 服务同时开始https和http端口 添加ssl配置代码开启http端口讲解大坑 添加ssl配置 在配置文件中添加配置 server:# ssl证书配置ssl:# 双向证书配置# 证书文件路径key-store: /opt/ops/cert/xes.p12# 证书密码key-store-password: 123456# 证书类型key-store-type…

IMX6ULL移植篇-Linux内核源码目录分析一

一. Linux内核源码目录 之前文章对 Linux内核源码的文件做了大体的了解&#xff0c;如下&#xff1a; IMX6ULL移植篇-Linux内核源码文件表_凌肖战的博客-CSDN博客 本文具体说明 Linux内核源码的一些重要文件含义。 二. Linux内核源码中重要文件分析 1. arch 目录 这个目录…

用了 TCP 协议,就一定不会丢包吗?

表面上我是个技术博主。 但没想到今天成了个情感博主。 我是没想到有一天&#xff0c;我会通过技术知识&#xff0c;来挽救粉丝即将破碎的感情。 掏心窝子的说。这件事情多少是沾点功德无量了。 事情是这样的。 最近就有个读者加了我的绿皮聊天软件&#xff0c;女生&#xff0c…

01强化学习的数学原理:大纲

01强化学习学习路线大纲 前言强化学习脉络图章节介绍Chapter 1&#xff1a;Basic ConceptsChapter 2&#xff1a;Bellman EquationChapter 3&#xff1a;Bellman Optimality EquationChapter 4&#xff1a;Value Iteration / Policy IterationChapter 5&#xff1a;Monte Carlo…

华为OD机试 - 靠谱的车 - 逻辑分析(Java 2023 B卷 100分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入2、输出3、说明 华为OD机试 2023B卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试&#xff08;JAVA&#xff09;真题&#xff08;A卷B卷&#…

JOSEF约瑟 智能电流继电器KWJL-20/L KWLD26 零序孔径45mm 柜内导轨式安装

KWJL-20智能电流继电器 零序互感器&#xff1a; KWLD80 KWLD45 KWLD26 KWJL-20 一、产品概述 KWJL-20系列智能剩余电流继电器&#xff08;以下简称继电器&#xff09;适用于交流电压至660V或更高的TN、TT、和IT系统&#xff0c;频率为50Hz。通过零序电流互感器检测出超过…

qt qml RadioButton如何设置字体颜色,style提示找不到怎么办?

qt QML中设置RadioButton的字体颜色&#xff0c;可以使用RadioButton的label属性来设置文本的样式。下面是一个示例代码&#xff1a; import QtQuick 2.6 import QtQuick.Controls 2.2 import QtQuick.Controls 1.4 as Controls1_4 import QtQuick.Controls.Styles 1.4 import…

如何使用Python和Numpy实现简单的2D FDTD仿真:详细指南与完整代码示例

第一部分&#xff1a;引言及FDTD简介 引言: 计算机模拟在许多科学和工程领域中都得到了广泛应用。在电磁学领域&#xff0c;有许多不同的数值方法用于模拟波的传播和散射。其中最为知名和广泛使用的一种方法是有限差分时域方法(Finite Difference Time Domain, FDTD)。在这篇…

IOTE 2023国际物联网展直击:芯与物发布全新定位芯片,助力多领域智能化发展

IOTE 2023国际物联网展&#xff0c;作为全球物联网领域的盛会&#xff0c;于9月20日在中国深圳拉开帷幕。北斗星通集团应邀参展&#xff0c;旗下专业从事物联网、消费类GNSS芯片研发设计的芯与物公司也随其亮相本届盛会。 展会上&#xff0c;芯与物展示了一系列创新的GNSS定位…

消费盲返模式:一种让消费者和商家都受益的新型消费返利模式

消费盲返是一种新型的消费返利模式&#xff0c;它的核心思想是&#xff1a;消费者在平台购买商品后&#xff0c;可以获得后续一定数量的订单的部分利润作为奖励。这样&#xff0c;消费者不仅可以享受商品的优惠&#xff0c;还有可能赚取更多的钱。 这种模式对于平台和消费者都有…

【教程】AERMOD高斯稳态扩散模型

查看原文>>>基于AERMOD模型在大气环境影响评价中的实践应用 随着我国经济快速发展&#xff0c;我国面临着日益严重的大气污染问题。近年来&#xff0c;严重的大气污染问题已经明显影响国计民生&#xff0c;引起政府、学界和人们越来越多的关注。大气污染是工农业生产…

iOS蓝牙 Connection Parameters 关键参数说明

1. 先贴苹果文档 《 Accessory Design Guidelines for Apple Devices 》 2. 几个关键词 connection Event Interval 事件间隔&#xff0c;为1.25ms的倍数。可以简单理解为,是两个连接着的蓝牙设备发送“心跳包”的时间间隔&#xff1b; 范围是 6 ~ 3200&#xff0c;即 7.5…

Jmeter性能测试吞吐量控制器使用小结

吞吐量控制器(Throughput Controller)场景: 在同一个线程组里, 有10个并发, 7个做A业务, 3个做B业务,要模拟这种场景,可以通过吞吐量模拟器来实现.。 jmeter性能测试&#xff1a;2023最新的大厂jmeter性能测试全过程项目实战详解&#xff0c;悄悄收藏&#xff0c;后面就看不到…