Alnet网络分析与demo实例

参考自 

  • up主的b站链接:霹雳吧啦Wz的个人空间-霹雳吧啦Wz个人主页-哔哩哔哩视频
  • 这位大佬的博客 Fun'_机器学习,pytorch图像分类,工具箱-CSDN博客

数据集下载

http://download.tensorflow.org/example_images/flower_photos.tgz

包含 5 中类型的花,每种类型有600~900张图像不等。

训练集与测试集划分

由于此数据集不像 CIFAR10 那样下载时就划分好了训练集和测试集,因此需要自己划分。具体操作可以看b站那个up 的视频,这里不再赘述

AlexNet详解

重点关注它和上一个模型不一样的地方

1.首次用GPU进行加速训练,上图上下两部分是完全一样的,这是因为用了两块GPU加速训练

2.使用了Relu函数

3.用了Dropout随即失活部分神经元 以减少过拟合

具体网络分析:

Conv1

输入:224*224*3

卷积:11*11*3    48个

  • padding = [1, 2] (左上围加半圈0,右下围加2倍的半圈0
  • stride = 4

输出:(224-11+3)/4+1 = 55   55*55*48

Maxpool1

输入:55*55*48

池化层:

  • kernel_size = 3
  • padding = 0
  • stride = 2

输出:(55-3)/2+1 = 27  27*27*48

Conv2

输入:27*27*48

卷积:5*5*48  128个

  • padding = [2, 2]
  • stride = 1

输出:(27-5+4)/1+1 = 27    27*27*128

Maxpool2

输入:27*27*128

  • 池化层:(只改变尺寸,不改变深度channel)
    • kernel_size = 3
    • padding = 0
    • stride = 2

输出:13*13*128

Conv3

输入:13*13*128

  • 卷积层:
    • 3*3 192个
    • padding = [1, 1]
    • stride = 1

输出:13*13*192

Conv4

输入:13*13*192

  • 卷积层:
    • 3*3 192个
    • padding = [1, 1]
    • stride = 1

输出: 13*13**192

Conv5

输入:13*13*192

  • 卷积层:
    • 3*3*128
    • padding = [1, 1]
    • stride = 1

输出:13*13*128

Maxpool3

输入:13*13*128

  • 池化层:
    • kernel_size = 3
    • padding = 0
    • stride = 2

输出:6*6*128

FC1、FC2、FC3

Maxpool3 → (6*6*256) → FC1 → 2048 → FC2 → 2048 → FC3 → 1000

总结

分析可以发现,除 Conv1 外,AlexNet 的其余卷积层都是在改变特征矩阵的深度,而池化层则只改变(减小)其尺寸。

构建模型;

import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()# 用nn.Sequential()将网络打包成一个模块,精简代码self.features = nn.Sequential(   # 卷积层提取图像特征nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]nn.ReLU(inplace=True), 									# 直接修改覆盖原值,节省运算内存nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6])self.classifier = nn.Sequential(   # 全连接层对图像分类nn.Dropout(p=0.5),			   # Dropout 随机失活神经元,默认比例为0.5nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)if init_weights:self._initialize_weights()# 前向传播过程def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)	# 展平后再传入全连接层x = self.classifier(x)return x# 网络权重初始化,实际上 pytorch 在构建网络时会自动初始化权重def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):                            # 若是卷积层nn.init.kaiming_normal_(m.weight, mode='fan_out',   # 用(何)kaiming_normal_法初始化权重nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)                    # 初始化偏重为0elif isinstance(m, nn.Linear):            # 若是全连接层nn.init.normal_(m.weight, 0, 0.01)    # 正态分布初始化nn.init.constant_(m.bias, 0)          # 初始化偏重为0

Dropout  :  发现都有具体的api 想具体研究的可以去看看它的函数是怎么写的

数据预处理 - 图像增强

需要注意的是,对训练集的预处理,多了随机裁剪和水平翻转这两个步骤。可以起到扩充数据集的作用,增强模型泛化能力

data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),       # 随机裁剪,再缩放成 224×224transforms.RandomHorizontalFlip(p=0.5),  # 水平方向随机翻转,概率为 0.5, 即一半的概率翻转, 一半的概率不翻转transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

训练:

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json# 预处理
data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load image
img = Image.open("蒲公英.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)# read class_indict
try:json_file = open('./class_indices.json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))# 关闭 Dropout
model.eval()
with torch.no_grad():# predict classoutput = torch.squeeze(model(img))     # 将输出压缩,即压缩掉 batch 这个维度predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

预测:

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json# 预处理
data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load image
img = Image.open("蒲公英.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)# read class_indict
try:json_file = open('./class_indices.json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))# 关闭 Dropout
model.eval()
with torch.no_grad():# predict classoutput = torch.squeeze(model(img))     # 将输出压缩,即压缩掉 batch 这个维度predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

打印出预测的标签以及概率值:

dandelion 0.7221569418907166

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/241851.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文降重方法同义词替换的效果对比与评价 快码论文

大家好,今天来聊聊论文降重方法同义词替换的效果对比与评价,希望能给大家提供一点参考。 以下是针对论文重复率高的情况,提供一些修改建议和技巧,可以借助此类工具: 标题:论文降重方法同义词替换的效果对比…

LIMS实验室管理软件多少钱 智慧实验室管理系统价格

随着实验室管理需求的日益复杂,实验室信息管理系统(LIMS)逐渐成为实验室管理的必备工具。然而,对于许多实验室而言,选择合适的LIMS系统是一项挑战,其中价格是一个重要的考虑因素。本文将以白码LIMS系统为例,探讨LIMS实…

LangChain 31 模块复用Prompt templates 提示词模板

LangChain系列文章 LangChain 实现给动物取名字,LangChain 2模块化prompt template并用streamlit生成网站 实现给动物取名字LangChain 3使用Agent访问Wikipedia和llm-math计算狗的平均年龄LangChain 4用向量数据库Faiss存储,读取YouTube的视频文本搜索I…

【模式识别】探秘分类奥秘:最近邻算法解密与实战

​🌈个人主页:Sarapines Programmer🔥 系列专栏:《模式之谜 | 数据奇迹解码》⏰诗赋清音:云生高巅梦远游, 星光点缀碧海愁。 山川深邃情难晤, 剑气凌云志自修。 目录 🌌1 初识模式识…

智能优化算法应用:基于北方苍鹰算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于北方苍鹰算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于北方苍鹰算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.北方苍鹰算法4.实验参数设定5.算法结果6.…

nodejs微信小程序+python+PHP医疗机构药品及耗材信息管理系统-计算机毕业设计推荐

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性:…

RT-Thread启动过程

RT-Thread启动流程 一般了解一份代码大多从启动部分开始,同样这里也采用这种方式,先寻找启动的源头。 RT-Thread支持多种平台和多种编译器,而rtthread_startup()函数是RT-Thread规定的统一启动入口。 一般执行顺序是:系统先从启…

CSS3:绘制多边形

clip-path&#xff1a;该属性使用裁剪方式创建元素的可显示区域&#xff0c;区域内的显示&#xff0c;区域外的不显示。 构建一个三角形 <div class"mybox"></div><style>.mybox {width: 100px;height: 100px;background-color: yellow;clip-path…

JavaOOP篇----第十四篇

系列文章目录 文章目录 系列文章目录前言一、Hashcode的作用二、Java的四种引用,强弱软虚三、Java创建对象有几种方式?前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码…

Android可折叠设备完全指南:展开未来

Android可折叠设备完全指南&#xff1a;展开未来 探索如何使用Android Jetpack组件折叠和展开设备。 近年来&#xff0c;科技界见证了可折叠设备的革命性趋势。这些设备融合了便携性和功能性的创新特点&#xff0c;使用户能够在不同的形态之间无缝切换。在本博客中&#xff0c…

浅析海博深造

文章目录 深造作用 留学种类 选专业 择校 申请流程 申请方式 深造作用 1、个人能力提升&#xff08;学术专业、语言、新文化或新生活方式&#xff09; 2、更好的职业发展&#xff08;起点更高、结交新朋友或扩大社交圈&#xff09; 3、北京上海落户优惠 4、海外居留福…

前端问题记录

jenkins安装vue依赖报错 jenkins 安装依赖&#xff0c;报错cannot find module ‘/root/.jenkins/workspace/项目路径/node_modules/xxxx’&#xff0c;如图上 解决&#xff1a;执行 npm install vue/cli-service --unsafe-perm&#xff0c;再执行npm i

解决 MATLAB 遗传算法中 exitflg=4 的问题

一、优化问题简介 以求解下述优化问题为例&#xff1a; P 1 : min ⁡ p ∑ k 1 K p k s . t . { ∑ k 1 K R k r e q l o g ( 1 α k ∗ p k ) ≤ B b s , ∀ k ∈ K p k ≥ 0 , ∀ k ∈ K \begin{align} {P_1:}&\mathop{\min}_{\bm{p}}{ \sum\limits_{k1}^K p_k } \no…

通过 Nginx 代理实现网页内容替换

突发奇想&#xff0c;用 Nginx 代理一个网站&#xff0c;把网站的一些关键字替换掉&#xff0c;蛮有意思的。 如下图&#xff1a; 一、编译安装 Nginx 一般 Nginx 中不包含 subs_filter 文本替换的模块&#xff0c;需要自己手动编译安装&#xff0c;步骤如下。 克隆 subs_fi…

物联网产品设计,聊聊设备OTA的升级

物联网产品设计部分的OTA设备固件是一个非常重要的部分&#xff0c;能够实现升级用户服务、保障系统安全等功能。 在迅速变化和发展的物联网市场&#xff0c;新的产品需求不断涌现&#xff0c;因此对于智能硬件设备的更新需求就变得空前高涨&#xff0c;设备不再像传统设备一样…

linux循环调度执行

9.2 循环调度执行 9.2.1 简介 cron的概念和crontab是不可分割的。 ​ crontab是一个命令&#xff0c;常见于Unix和Linux的操作系统之中用于设置周期性被执行的指令。 ​ 该命令从标准输入设备读取指令&#xff0c;并将其存放于“crontab”文件中&#xff0c;以供之后读取和执…

移除石子使总数最小(LeetCode日记)

LeetCode-1962-移除石子使总数最小 题目信息: 给你一个整数数组 p i l e s piles piles &#xff0c;数组 下标从 0 0 0 开始 &#xff0c;其中 p i l e s [ i ] piles[i] piles[i] 表示第 i i i 堆石子中的石子数量。另给你一个整数 k k k &#xff0c;请你执行下述操作…

Django开发2

Django开发2 Django开发1.新建项目2.创建app3.设计表结构&#xff08;django&#xff09;4.在MySQL中生成表5.静态文件管理6.部门管理7.模板的继承8.用户管理8.1 初识Form1. views.py2.user_add.html 8.3 ModelForm&#xff08;推荐&#xff09;0. models.py1. views.py2.user_…

机器人创新实验室任务三参考文档

一、JAVA环境配置 需要在Linux里面下载并且安装java。 sudo apt-get install openjdk-17-jre-headless 打开终端并且运行指令&#xff0c;用apt下载安装java。官方用的好像是java11&#xff0c;我安装的是java17。 如果无法定位软件安装包&#xff0c;可以试试更新一下 sudo …

直接插入排序【从0-1学数据结构】

文章目录 &#x1f497; 直接插入排序Java代码C代码JavaScript代码稳定性时间复杂度空间复杂度 我们先来学习 直接插入排序, 直接排序算是所有排序中最简单的了,代码也非常好实现,尽管直接插入排序很简单,但是我们依旧不可以上来就直接写代码,一定要分析之后才开始写,这样可以提…