AlexNet(pytorch)

AlexNet是2012年ISLVRC 2012(ImageNet Large Scale Visual Recognition Challenge)竞赛的冠军网络,分类准确率由传统的 70%+提升到 80%+

该网络的亮点在于:

(1)首次利用 GPU 进行网络加速训练。

(2)使用了 ReLU 激活函数,而不是传统的 Sigmoid 激活函数以及 Tanh 激活函数。

(3)使用了 LRN 局部响应归一化。

(4)在全连接层的前两层中使用了 Dropout 随机失活神经元操作,以减少过拟合

模型:

模型参数表:

model.py

import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]: (55-3+0)/4 + 1=27nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)

train.py

import os
import sys
import jsonimport torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdmfrom model import AlexNetdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))#前期的网络还是用的Normalize标准化,之后的网络会用到BN批标准化data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../../"))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)#注意这里的数据加载还是直接用的torchvision.datasets.ImageFolder加载,#并不需要定义数据加载的脚本,可能是数据比较简单吧#定义数据集时候直接定义数据处理方法,之后torch.utils.data.DataLoader加载数据集加载时候直接调用这里定义的数据处理参数的方法#train文件夹下还有五种花的文件夹,这个具体处理看下面的代码,可能是ImageFolder直接加载文件夹里的图片文件train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])#训练集图片的个数train_num = len(train_dataset)#train_dataset.class_to_idx 是一个字典,将类别名称映射到相应的索引。#下行注释就是flower_list具体内容# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}# cla_dict是一个反转字典,将原始字典 flower_list 的键和值进行交换flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# json.dumps() 将 cla_dict 转换为格式化的 JSON 字符串。# 最后,将 JSON 字符串写入名为 class_indices.json 的文件中# indent 参数表示有几类json_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32#这个代码片段的目的是为了确定在并行计算时使用的最大工作进程数,并确保不超过系统的逻辑 CPU 核心数量和其他限制nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=4, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))# test_data_iter = iter(validate_loader)# test_image, test_label = test_data_iter.next()## def imshow(img):#     img = img / 2 + 0.5  # unnormalize#     npimg = img.numpy()#     plt.imshow(np.transpose(npimg, (1, 2, 0)))#     plt.show()## print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))# imshow(utils.make_grid(test_image))net = AlexNet(num_classes=5, init_weights=True)net.to(device)loss_function = nn.CrossEntropyLoss()# pata = list(net.parameters())optimizer = optim.Adam(net.parameters(), lr=0.0002)epochs = 10save_path = './AlexNet.pth'best_acc = 0.0#一个epoch训练多少批次的数据,一批数据32个CWH,即32张图片train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0#这段代码使用了 tqdm 库来创建一个进度条,用于迭代训练数据集 train_loader 中的批次数据#file=sys.stdout 的作用是将进度条的输出定向到标准输出流,即将进度条显示在终端窗口中train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()#更新进度条的描述信息,显示当前训练的轮数、总轮数和损失值#这个loss是批次损失,在进度条上显示出来train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# 验证是训练完一个epoch后进行在验证集上验证,验证准确率net.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:#val_bar 的类型是 tqdm.tqdm,它是 tqdm 库中的一个类。该类提供了迭代器的功能,# 可以用于包装迭代器对象,并在循环中显示进度条和相关信息val_images, val_labels = val_dataoutputs = net(val_images.to(device))   #outputs:[batch_size,num_classes]predict_y = torch.max(outputs, dim=1)[1]  #torch.max  返回的第一个元素是张量数值,第二个是对应的索引acc += torch.eq(predict_y, val_labels.to(device)).sum().item()#验证完后计算验证集里所有的正确个数/总个数val_accurate = acc / val_num#总损失/训练总批次,求得平均每批的损失print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()

训练过程:

using cuda:0 device.
Using 8 dataloader workers every process
using 3306 images for training, 364 images for validation.
train epoch[1/10] loss:1.215: 100%|██████████| 104/104 [00:23<00:00,  4.38it/s]
100%|██████████| 91/91 [00:15<00:00,  5.73it/s]
[epoch 1] train_loss: 1.342  val_accuracy: 0.478
train epoch[2/10] loss:1.111: 100%|██████████| 104/104 [00:19<00:00,  5.30it/s]
100%|██████████| 91/91 [00:15<00:00,  5.75it/s]
[epoch 2] train_loss: 1.183  val_accuracy: 0.533
train epoch[3/10] loss:1.252: 100%|██████████| 104/104 [00:19<00:00,  5.30it/s]
100%|██████████| 91/91 [00:15<00:00,  5.75it/s]
[epoch 3] train_loss: 1.097  val_accuracy: 0.604
train epoch[4/10] loss:0.730: 100%|██████████| 104/104 [00:19<00:00,  5.32it/s]
100%|██████████| 91/91 [00:15<00:00,  5.74it/s]
[epoch 4] train_loss: 1.025  val_accuracy: 0.607
train epoch[5/10] loss:0.961: 100%|██████████| 104/104 [00:19<00:00,  5.28it/s]
100%|██████████| 91/91 [00:16<00:00,  5.65it/s]
[epoch 5] train_loss: 0.941  val_accuracy: 0.676
train epoch[6/10] loss:0.853: 100%|██████████| 104/104 [00:19<00:00,  5.31it/s]
100%|██████████| 91/91 [00:15<00:00,  5.82it/s]
[epoch 6] train_loss: 0.915  val_accuracy: 0.659
train epoch[7/10] loss:1.032: 100%|██████████| 104/104 [00:19<00:00,  5.34it/s]
100%|██████████| 91/91 [00:15<00:00,  5.82it/s]
[epoch 7] train_loss: 0.864  val_accuracy: 0.684
train epoch[8/10] loss:0.704: 100%|██████████| 104/104 [00:19<00:00,  5.32it/s]
100%|██████████| 91/91 [00:15<00:00,  5.80it/s]
[epoch 8] train_loss: 0.842  val_accuracy: 0.706
train epoch[9/10] loss:1.279: 100%|██████████| 104/104 [00:19<00:00,  5.30it/s]
100%|██████████| 91/91 [00:15<00:00,  5.83it/s]
[epoch 9] train_loss: 0.825  val_accuracy: 0.714
train epoch[10/10] loss:0.796: 100%|██████████| 104/104 [00:19<00:00,  5.31it/s]
100%|██████████| 91/91 [00:15<00:00,  5.82it/s]
[epoch 10] train_loss: 0.801  val_accuracy: 0.703
Finished TrainingProcess finished with exit code 0

predict.py:

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model import AlexNetdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load imageimg_path = "./test.jpg"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indictjson_path = './class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)with open(json_path, "r") as f:class_indict = json.load(f)# create modelmodel = AlexNet(num_classes=5).to(device)# load model weightsweights_path = "./AlexNet.pth"assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)#torch.load() 函数会根据路径加载模型的权重,并返回一个包含模型参数的字典#load_state_dict() 函数将加载的模型参数字典应用到 model 中,从而将预训练模型的参数加载到 model 中model.load_state_dict(torch.load(weights_path))model.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))plt.show()if __name__ == '__main__':main()

预测结果:

我感觉pycharm的plt显示并不是特别明了

class: daisy        prob: 4.2e-06
class: dandelion    prob: 9.61e-07
class: roses        prob: 0.000773
class: sunflowers   prob: 1.28e-05
class: tulips       prob: 0.999

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/227319.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Idea中操作Git使用cherry pick

Idea中操作Git使用cherry pick 使用场景使用功能步骤 使用场景 代码开发中,新功能还未开发完,但是master分支需要使用带新功能中的一次提交的代码,就可以使用cherry pack(优选). 使用功能步骤 切换到master分支选中dev分支双击选择需要使用的提交右键,如果有冲突就会弹窗解…

Netty—NIO万字详解

文章目录 NIO基本介绍同步、异步、阻塞、非阻塞IO的分类NIO 和 BIO 的比较NIO 三大核心原理示意图NIO的多路复用说明 核心一&#xff1a;缓存区 (Buffer)Buffer类及其子类Buffer缓冲区的分类MappedByteBuffer类说明&#xff1a; 核心二&#xff1a;通道 (Channel)Channel类及其…

防止反编译,保护你的SpringBoot项目

ClassFinal-maven-plugin插件是一个用于加密Java字节码的工具&#xff0c;它能够保护你的Spring Boot项目中的源代码和配置文件不被非法获取或篡改。下面是如何使用这个插件来加密test.jar包的详细步骤&#xff1a; 安装并设置Maven&#xff1a; 首先确保你已经在你的开发环境中…

windows 10 安装和配置nginx

1 下载nginx 1.1 下载地址&#xff1a;http://nginx.org/en/download.html 1.2 使用解压到安装目录 1.3 更改配置 conf目录下nginx.conf 修改为未被占用的端口&#xff0c;地址改成你的地址 server {listen 9999;server_name localhost;#charset koi8-r;#access_lo…

2 使用postman进行接口测试

上一篇&#xff1a;1 接口测试介绍-CSDN博客 拿到开发提供的接口文档后&#xff0c;结合需求文档开始做接口测试用例设计&#xff0c;下面用最常见也最简单的注册功能介绍整个流程。 说明&#xff1a;以演示接口测试流程为主&#xff0c;不对演示功能做详细的测试&#xff0c;…

【数据结构】双链表的定义和操作

目录 1.双链表的定义 2.双链表的创建和初始化 3.双链表的插入节点操作 4.双链表的删除节点操作 5.双链表的查找节点操作 6.双链表的更新节点操作 7.完整代码 &#x1f308;嗨&#xff01;我是Filotimo__&#x1f308;。很高兴与大家相识&#xff0c;希望我的博客能对你有所帮助…

WPF-UI HandyControl 控件简单实战

文章目录 前言UserControl简单使用新建项目直接新建项目初始化UserControlGeometry:矢量图形额外Icon导入最优解决方案 按钮Button切换按钮ToggleButton默认按钮图片可切换按钮加载按钮切换按钮 单选按钮和复选按钮没有太大特点&#xff0c;就不展开写了总结 DataGrid数据表格G…

详细了解stm32---按键

提示&#xff1a;永远支持知识文档免费开源&#xff0c;喜欢的朋友们&#xff0c;点个关注吧&#xff01;蟹蟹&#xff01; 目录 一、了解按键 二、stm32f103按键分析 三、按键应用 一、了解按键 同学们&#xff0c;又见面了o(*&#xffe3;▽&#xffe3;*)ブ&#xff0c;最…

C++ Qt开发:Tab与Tree组件实现分页菜单

Qt 是一个跨平台C图形界面开发库&#xff0c;利用Qt可以快速开发跨平台窗体应用程序&#xff0c;在Qt中我们可以通过拖拽的方式将不同组件放到指定的位置&#xff0c;实现图形化开发极大的方便了开发效率&#xff0c;本章将重点介绍tabWidget选择夹组件与TreeWidget树形选择组件…

升华 RabbitMQ:解锁一致性哈希交换机的奥秘【RabbitMQ 十】

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 升华 RabbitMQ&#xff1a;解锁一致性哈希交换机的奥秘【RabbitMQ 十】 前言第一&#xff1a;该插件需求为什么需要一种更智能的消息路由方式&#xff1f;一致性哈希的基本概念&#xff1a; 第二&…

【Linux】MySQL 数据库安装配置教程(Ubuntu 22.04)

前言 MySQL是一个流行的开源关系型数据库管理系统&#xff08;RDBMS&#xff09;&#xff0c;广泛用于Web应用程序的后端数据存储&#xff0c;如许多动态网站、电子商务系统和在线出版物等。 MySQL具有高性能、可靠性和易用性的特点&#xff0c;它支持大型数据库&#xff0c;…

【Java】使用递归的方法获取层级关系数据demo

使用递归来完善各种业务数据的层级关系的获取 引言&#xff1a;在Java开发中&#xff0c;我们通常会遇到层层递进的关系型数据的获取问题&#xff0c;有时是树状解构&#xff0c;或金字塔结构&#xff0c;怎么描述都行&#xff0c;错综复杂的关系在程序中还是可以理清的。 这…

uniGUI之上传文件UniFileUploadButton

TUniFileUploadButton主要属性&#xff1a; Filter: 文件类型过滤&#xff0c;有图片image/* audio/* video/*三种过滤 MaxAllowedSize: 设置文件最大上传尺寸&#xff1b; Message&#xff1a;标题以及消息文本&#xff0c;可翻译成中文 TUniFileUploadButton控件 支持多…

云原生之深入解析Linkerd Service Mesh的功能和使用

一、简介 Linkerd 是 Kubernetes 的一个完全开源的服务网格实现&#xff0c;它通过为你提供运行时调试、可观测性、可靠性和安全性&#xff0c;使运行服务更轻松、更安全&#xff0c;所有这些都不需要对代码进行任何更改。Linkerd 通过在每个服务实例旁边安装一组超轻、透明的…

MX6ULL学习笔记(十二)Linux 自带的 LED 灯

前言 前面我们都是自己编写 LED 灯驱动&#xff0c;其实像 LED 灯这样非常基础的设备驱动&#xff0c;Linux 内 核已经集成了。Linux 内核的 LED 灯驱动采用 platform 框架&#xff0c;因此我们只需要按照要求在设备 树文件中添加相应的 LED 节点即可&#xff0c;本章我们就来学…

Python基础05-函数

零、文章目录 Python基础05-函数 1、函数的作用及其使用步骤 &#xff08;1&#xff09;函数的作用 在Python实际开发中&#xff0c;我们使用函数的目的只有一个“让我们的代码可以被重复使用” 函数的作用有两个&#xff1a; ① 代码重用&#xff08;代码重复使用&#xf…

【AI工具】GitHub Copilot IDEA安装与使用

GitHub Copilot是一款AI编程助手&#xff0c;它可以帮助开发者编写代码&#xff0c;提供代码建议和自动完成功能。以下是GitHub Copilot在IDEA中的安装和使用步骤&#xff1a; 安装步骤&#xff1a; 打开IDEA&#xff0c;点击File -> Settings -> Plugins。在搜索框中输…

windows10 php8连接sql server

一、环境安装 文章目录 一、环境安装1.安装php拓展2.在 Windows 上安装PHP驱动程序3.在 Windows 上安装ODBC驱动 二、php连接sqlserver三、注意事项数据库相关设置相关语法sqlsrv_fetch_array 的示例&#xff1a;sqlsrv_fetch 的示例&#xff1a;echo 和 print_r 的不同 所用资…

Webrtc 学习交流

花了几周的时间研究了一下webrtc &#xff0c;并开发了一个小项目&#xff0c;用来点对点私密聊天 交流传输文件等…后续会继续扩展其功能。 体验地址&#xff0c;大狗子的ID,我在线时可以连接测试到我 f3e0d6d0-cfd7-44a4-b333-e82c821cd927 项目特点 除了交换信令与stun 没…

ES6 面试题 | 01.精选 ES6 面试题

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…