AlexNet——训练花数据集

目录

一、网络结构

二、创新点分析

三、知识点

1. nn.ReLU(inplace) 

2. os.getcwd与os.path.abspath 

3. 使用torchvision下的datasets包 

4. items()与dict()用法 

5. json文件  

6. tqdm

7. net.train()与net.val()

四、代码


AlexNet是由Alex Krizhevsky、Ilya Sutskever和Geoffrey Hinton在2012年ImageNet图像分类竞赛中提出的一种经典的卷积神经网络。AlexNet使用了Dropout层,减少过拟合现象的发生。

一、网络结构

二、数据集 

文件存放:

dataset->flower_data->flower_photos

再使用split_data.py 将数据集根据比例划分成训练集和预测集

详细请查看b站up主霹雳吧啦Wz:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/pytorch_classification

三、创新点分析

1. deeper网络结构

    通过增加网络深度,AlexNet可以更好的学习数据集的特征,并提高分类的准确率。

2. 使用ReLU激活函数,克服梯度消失以及求梯度复杂的问题。

3. 使用LRN局部响应归一化

    LRN是在卷积与池化层间添加归一化操作。卷积过程中,每个卷积核都对应一个feature map,LRN对这些feature map进行归一化操作。即,对每个特征图的每个位置,计算该位置周围的像素平方和,然后将当前位置像素值除以这个和。LRN可抑制邻近神经元的响应,在一定程度上能够避免过拟合,提高网络泛化能力。

4. 使用Dropout层

Dropout层:在训练过程中随机删除一定比例的神经元,以减少过拟合。Dropout一般放在全连接层与全连接层之间。

四、知识点

1. nn.ReLU(inplace) 默认参数为:inplace=False

inplace=False:不会修改输入对象的值,而是返回一个新创建的对象,即打印出的对象存储地址不同。(值传递)

inplace=True:会修改输入对象的值,即打印的对象存储地址相同,可以节省申请与释放内存的空间与时间。(地址传递)

import torch
import numpy as np
import torch.nn as nn# id()方法返回对象的内存地址
relu1 = nn.ReLU(inplace=False)
relu2 = nn.ReLU(inplace=True)
data = np.random.randn(2, 4)
input = torch.from_numpy(data)  # 转换成tensor类型
print("input address:", id(input))
output1 = relu1(input)
print("replace=False -- output address:", id(output1))
output2 = relu2(input)
print("replace=True -- output address:", id(output2))
# input address: 1669839583200
# replace=False -- output address: 1669817512352
# replace=True -- output address: 1669839583200

2. os.getcwd与os.path.abspath 

os.getcwd():获取当前工作目录

os.path.abspath('xxx.py'):获取文件当前的完整路径

import osprint(os.getcwd())  # D:\Code
print(os.path.abspath('test.py'))  # D:\Code\test.py

3. 使用torchvision下的datasets包 

train_dataset=datasets.ImageFolder(root=os.path.join(image_path,'train'),transform=data_transform['train'])

可以得出这些信息: 

4. items()与dict()用法 

items():把字典中的每对key和value组成一个元组,并将这些元组放在列表中返回。

obj = {'dog': 0,'cat': 1,'fish': 2
}
print(obj)  # {'dog': 0, 'cat': 1, 'fish': 2}
print(obj.items())  # dict_items([('dog', 0), ('cat', 1), ('fish', 2)])
print(dict((v, k) for k, v in obj.items()))  # {0: 'dog', 1: 'cat', 2: 'fish'}

5. json文件  

(1)json.dumps:将Python对象编码成JSON字符串

(2)json.loads:将已编码的JSON字符串编码为Python对象

import jsondata = [1, 2, 3]
data_json = json.dumps(data)  # <class 'str'>
data = json.loads(data_json)
print(type(data))  # <class 'list'>

6. tqdm

train_bar = tqdm(train_loader, file=sys.stdout)
使用tqdm函数,对train_loader进行迭代,将进度条输出到标准输出流sys.stdout中。可以方便用户查看训练进度。

from tqdm import tqdm
import timefor i in tqdm(range(10)):time.sleep(0.1)

7. net.train()与net.val()

net.train():启用BatchNormalization和Dropout

net.eval|():不启用BatchNormalization和Dropout

五、代码

model.py

import torch
import torch.nn as nnclass AlexNet(nn.Module):def __init__(self, num_classes=1000):super(AlexNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, padding=2, stride=4),  # input[3,224,224] output[96,55,55]nn.ReLU(inplace=True),  # inplace=True 址传递nn.MaxPool2d(kernel_size=3, stride=2),  # output[96,27,27]nn.Conv2d(96, 256, kernel_size=5, padding=2),  # output[256,27,27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),  # output[256,13,13]nn.Conv2d(256, 384, kernel_size=3, padding=1),  # output[384,13,13]nn.ReLU(inplace=True),nn.Conv2d(384, 384, kernel_size=3, padding=1),  # output[384,13,13]nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),  # output[256,13,13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),  # output[256,6,6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(256 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes))def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)  # batch这一维度不用,从channel开始x = self.classifier(x)return x

train.py 

import os
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
from model import AlexNet
import torch.optim as optim
from tqdm import tqdmdef main():device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')print("using:{}".format(device))data_transform = {'train': transforms.Compose([# 将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为指定大小transforms.RandomResizedCrop(224),# 水平方向随机翻转transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),'val': transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}# get data root pathdata_root = os.path.abspath(os.getcwd())  # D:\Code\AlexNet# get flower data set pathimage_path = os.path.join(data_root, 'data_set', 'flower_data')  # D:\Code\AlexNet\data_set\flower_data# 使用assert断言语句:出现错误条件时,就触发异常assert os.path.exists(image_path), '{} path does not exist!'.format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'), transform=data_transform['train'])val_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'val'), transform=data_transform['val'])train_num = len(train_dataset)val_num = len(val_dataset)# write class_dict into json fileflower_list = train_dataset.class_to_idxclass_dict = dict((v, k) for k, v in flower_list.items())json_str = json.dumps(class_dict)with open('class_indices.json', 'w') as file:file.write(json_str)batch_size = 32train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0)net = AlexNet(num_classes=5)net.to(device)loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0002)epochs = 5save_path = './model/AlexNet.pth'best_acc = 0.0train_steps = len(train_loader)  # train_num / batch_sizetrain_bar = tqdm(train_loader)val_bar = tqdm(val_loader)for epoch in range(epochs):# trainnet.train()epoch_loss = 0.0# 加入进度条train_bar = tqdm(train_loader)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()  # update x by optimizer# print statisticsepoch_loss += loss.item()train_bar.desc = 'train eporch[{}/{}] loss:{:.3f}'.format(epoch + 1, epochs, loss)# validatenet.eval()acc = 0.0with torch.no_grad():val_bar = tqdm(val_loader)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]  # [1]取每行最大值的索引acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_acc = acc / val_numprint('[epoch %d] train_loss: %.3f val_accuracy: %.3f' % (epoch + 1, epoch_loss / train_steps, val_acc))# find best accuracyif val_acc > best_acc:best_acc = val_acctorch.save(net.state_dict(), save_path)print('Train finished!')if __name__ == '__main__':main()

class_indices.json

{"0": "daisy", "1": "dandelion", "2": "roses", "3": "sunflowers", "4": "tulips"}

predict.py 

import os
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import json
from model import AlexNetdef main():device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load imageimg_path = './2.jpg'assert os.path.exists(img_path), "file:'{}' does not exist".format(img_path)img = Image.open(img_path)plt.imshow(img)# input [N,C,H,W]img = transform(img)img = torch.unsqueeze(img, dim=0)# read class_indicesjson_path = './class_indices.json'assert os.path.exists(json_path), "file:'{}' does not exist".format(json_path)with open(json_path, 'r') as file:class_dict = json.load(file)  # {'0': 'daisy', '1': 'dandelion', '2': 'roses', '3': 'sunflowers', '4': 'tulips'}# load modelnet = AlexNet(num_classes=5).to(device)# load model weightsweight_path = './model/AlexNet.pth'assert os.path.exists(weight_path), "file:'{}' does not exist".format(weight_path)net.load_state_dict(torch.load(weight_path))# predictnet.eval()with torch.no_grad():output = torch.squeeze(net(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_class = torch.argmax(predict).numpy()print_res = 'class:{} probability:{:.3}'.format(class_dict[str(predict_class)], predict[predict_class].numpy())plt.title(print_res)plt.show()for i in range(len(predict)):print('class:{:10} probability:{:.3}'.format(class_dict[str(i)], predict[i]))if __name__ == '__main__':main()

Result:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/83315.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NLP技术如何为搜索引擎赋能

目录 1. NLP关键词提取与匹配在搜索引擎中的应用1. 关键词提取例子 2. 关键词匹配例子 Python实现 2. NLP语义搜索在搜索引擎中的应用1. 语义搜索的定义例子 2. 语义搜索的重要性例子 Python/PyTorch实现 3. NLP个性化搜索建议在搜索引擎中的应用1. 个性化搜索建议的定义例子 2…

Java:JSR 310日期时间体系LocalDateTime、OffsetDateTime、ZonedDateTime

JSR 310日期时间体系&#xff1a; LocalDateTime&#xff1a;本地日期时间OffsetDateTime&#xff1a;带偏移量的日期时间ZonedDateTime&#xff1a;带时区的日期时间 目录 构造计算格式化参考文章 日期时间包 import java.time.LocalDateTime; import java.time.OffsetDateT…

Eclipse如何打开debug变量窗口

今天笔者在使用Eclipse调试的时候&#xff0c;发现没有变量&#xff08;Variables&#xff09;监视窗口&#xff0c;真是头痛得很&#xff0c;最后摸索出一套显示变量窗口的操作如下&#xff1a; 点击other&#xff0c;找到Variables并点击 最后调试代码&#xff0c;调试后如图…

机器学习(17)---支持向量机(SVM)

支持向量机 一、概述1.1 介绍1.2 工作原理1.3 三层理解 二、sklearn.svm.SVC2.1 查看数据集2.2 contour函数2.3 画决策边界&#xff1a;制作网格2.4 建模画图 三、非线性情况推广3.1 查看数据集3.2 线性画图3.3 为非线性数据增加维度并绘制3D图像 四、核函数 一、概述 1.1 介绍…

免杀对抗-Python-混淆算法+反序列化-打包生成器-Pyinstall

Python-MSF/CS生成shellcode-上线 cs上线 1.生成shellcode-c或者python 2.打开pycharm工具&#xff0c;创建一个py文件&#xff0c;将原生态执行代码复制进去 shellcode执行代码&#xff1a; import ctypesfrom django.contrib.gis import ptr#cs#shellcodebytearray(b"生…

IMX6ULL移植篇-Linux内核源码目录分析一

一. Linux内核源码目录 之前文章对 Linux内核源码的文件做了大体的了解&#xff0c;如下&#xff1a; IMX6ULL移植篇-Linux内核源码文件表_凌肖战的博客-CSDN博客 本文具体说明 Linux内核源码的一些重要文件含义。 二. Linux内核源码中重要文件分析 1. arch 目录 这个目录…

用了 TCP 协议,就一定不会丢包吗?

表面上我是个技术博主。 但没想到今天成了个情感博主。 我是没想到有一天&#xff0c;我会通过技术知识&#xff0c;来挽救粉丝即将破碎的感情。 掏心窝子的说。这件事情多少是沾点功德无量了。 事情是这样的。 最近就有个读者加了我的绿皮聊天软件&#xff0c;女生&#xff0c…

01强化学习的数学原理:大纲

01强化学习学习路线大纲 前言强化学习脉络图章节介绍Chapter 1&#xff1a;Basic ConceptsChapter 2&#xff1a;Bellman EquationChapter 3&#xff1a;Bellman Optimality EquationChapter 4&#xff1a;Value Iteration / Policy IterationChapter 5&#xff1a;Monte Carlo…

华为OD机试 - 靠谱的车 - 逻辑分析(Java 2023 B卷 100分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入2、输出3、说明 华为OD机试 2023B卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试&#xff08;JAVA&#xff09;真题&#xff08;A卷B卷&#…

JOSEF约瑟 智能电流继电器KWJL-20/L KWLD26 零序孔径45mm 柜内导轨式安装

KWJL-20智能电流继电器 零序互感器&#xff1a; KWLD80 KWLD45 KWLD26 KWJL-20 一、产品概述 KWJL-20系列智能剩余电流继电器&#xff08;以下简称继电器&#xff09;适用于交流电压至660V或更高的TN、TT、和IT系统&#xff0c;频率为50Hz。通过零序电流互感器检测出超过…

IOTE 2023国际物联网展直击:芯与物发布全新定位芯片,助力多领域智能化发展

IOTE 2023国际物联网展&#xff0c;作为全球物联网领域的盛会&#xff0c;于9月20日在中国深圳拉开帷幕。北斗星通集团应邀参展&#xff0c;旗下专业从事物联网、消费类GNSS芯片研发设计的芯与物公司也随其亮相本届盛会。 展会上&#xff0c;芯与物展示了一系列创新的GNSS定位…

消费盲返模式:一种让消费者和商家都受益的新型消费返利模式

消费盲返是一种新型的消费返利模式&#xff0c;它的核心思想是&#xff1a;消费者在平台购买商品后&#xff0c;可以获得后续一定数量的订单的部分利润作为奖励。这样&#xff0c;消费者不仅可以享受商品的优惠&#xff0c;还有可能赚取更多的钱。 这种模式对于平台和消费者都有…

iOS蓝牙 Connection Parameters 关键参数说明

1. 先贴苹果文档 《 Accessory Design Guidelines for Apple Devices 》 2. 几个关键词 connection Event Interval 事件间隔&#xff0c;为1.25ms的倍数。可以简单理解为,是两个连接着的蓝牙设备发送“心跳包”的时间间隔&#xff1b; 范围是 6 ~ 3200&#xff0c;即 7.5…

Jmeter性能测试吞吐量控制器使用小结

吞吐量控制器(Throughput Controller)场景: 在同一个线程组里, 有10个并发, 7个做A业务, 3个做B业务,要模拟这种场景,可以通过吞吐量模拟器来实现.。 jmeter性能测试&#xff1a;2023最新的大厂jmeter性能测试全过程项目实战详解&#xff0c;悄悄收藏&#xff0c;后面就看不到…

Pytorch史上最全torch全版本离线文件下载地址大全(9月最新)

以下为pytorch官网的全版本torch文件离线下载地址 torch全版本whl文件离线下载大全https://download.pytorch.org/whl/torch/其中的文件版本信息如下所示&#xff08;部分版本信息&#xff0c;根据需要仔细寻找进行下载&#xff09;&#xff1a;

Web(1) 搭建漏洞环境(metasploitable2靶场/DVWA靶场)

简述渗透测试的步骤&#xff1b; 前期交互阶段→情报搜集阶段→威胁建模阶段→漏洞分析阶段→渗透攻击阶段→后渗透攻击阶段→报告阶段 (2)配置好metasploitable2靶场&#xff0c;截图 下载metasploitable2&#xff0c;VMware打开.vmx文件&#xff0c;登录&#xff0c;登陆用…

React 全栈体系(五)

第三章&#xff1a;React 应用(基于 React 脚手架) 一、使用 create-react-app 创建 react 应用 1. react 脚手架 xxx 脚手架: 用来帮助程序员快速创建一个基于 xxx 库的模板项目 包含了所有需要的配置&#xff08;语法检查、jsx 编译、devServer…&#xff09;下载好了所有…

一、8086

1、三大总线&#xff1a; &#xff08;1&#xff09;基础&#xff1a; 地址总线、数据总线、控制总线 &#xff08;2&#xff09;例题&#xff1a; 2、8086CPU &#xff08;1&#xff09;通用寄存器&#xff1a; 数据寄存器&#xff1a; 指针寄存器和变址寄存器&#xff1a…

国内首个潮玩行业沉浸式IP主题乐园,泡泡玛特城市乐园即将开园

近年来&#xff0c;泡泡玛特以潮玩IP为核心&#xff0c;不断拓展业务版图&#xff0c;推进国际化布局同时实现集团化运营&#xff0c;而泡泡玛特首个城市乐园将于9月下旬开业。据了解&#xff0c;泡泡玛特城市乐园是由泡泡玛特精心打造的沉浸式IP主题乐园&#xff0c;占地约4万…

linux新版本io框架 io_uring

从别的博主那copy过来&#xff1a; 1 io_uring是Linux内核的一个新型I/O事件通知机制&#xff0c;具有以下特点&#xff1a; 高性能&#xff1a;相比传统的select/poll/epoll等I/O多路复用机制&#xff0c;io_uring采用了更高效的ring buffer实现方式&#xff0c;可以在处理大量…