霹雳吧啦Wz《pytorch图像分类》-p2AlexNet网络

《pytorch图像分类》p2AlexNet网络基础及代码

  • 一、零碎知识点
    • 1.过拟合
    • 2.使用dropout后的正向传播
    • 3.正则化regularization
    • 4.代码中所用的知识点
  • 二、总体架构分析
    • 1.ReLU激活函数
    • 2.手算
    • 3.模型代码
  • 三、训练花分类课程代码
    • 1.model.py
    • 2.train.py
    • 3.predict.py

一、零碎知识点

1.过拟合

模型假设过于复杂,参数过多,训练数据过少,噪声过多,导致拟合的函数完美的预测训练集,但对新数据的测试集预测结果差。 过度的拟合了训练数据,而没有考虑到泛化能力。
举个栗子:当我们在学习一门新的学科时会做一些例题,我们把这些例题完完整整背下来,但是一旦给出新的题目,还是不会做,这就是学习没有泛化能力,只记住了例题的细节而忽视了更普遍的规律。

2.使用dropout后的正向传播

dropout会在每一层当中随机失活一部分神经元,从而减少了神经元之间的共适应性,防止过拟合。
nn.Dropout(p= )p代表的是随机失活的比例,默认p=0.5

3.正则化regularization

是一种通过添加额外的约束或惩罚项来控制模型的复杂度的技术,其目的也是防止过拟合。
假设我们的模型是一个二次多项式,我们的目标是最小化损失函数(均方误差MSE),让预测的曲线与真实值尽可能接近。
L2正则化将惩罚项加到损失函数中,使用权重的平方和乘以一个正则化系数λ。
带正则化的损失函数为:
L o s s = M S E + λ ∗ ∣ ∣ w ∣ ∣ 2 Loss = MSE + λ * ||w||^2 Loss=MSE+λ∣∣w2
很抽象,以后再深入学习。

4.代码中所用的知识点

  1. 设备设置
    如果有可以使用的gpu设备,默认使用第一个,没有的话即使用cpu。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  1. RandomResizedCrop随机裁剪
    将裁剪后的图像调整为指定的大小(224x224)
transforms.RandomResizedCrop(224)
  1. root=“. ./. .”
    返回上上层目录

  2. CrossEntropyLoss交叉熵损失函数
    它是PyTorch中的一个损失函数,常用于多分类问题的训练中。它结合了Softmax激活函数和负对数似然损失(NLLLoss)
    详情请见我之前写的博客:多分类问题

  3. net.train( )和net.eval( )
    当调用net.train()时,模型将被设置为训练模式。在训练模式下,模型会启用一些特定的操作,如Batch Normalization归一化处理和Dropout随机失活一些神经元,防止过拟合。而调用net.eval()时,模型将被设置为评估(evaluate)模式。

二、总体架构分析

1.ReLU激活函数

ReLU(Rectified Linear Unit)线性整流函数,又称修正线性单元。在PyTorch中,可以使用torch.nn.ReLU类来表示。
其公式为:
f ( x ) = M a x ( 0 , x ) f(x)=Max(0,x) f(x)=Max(0,x)
它将小于零的输入映射为0,大于等于零的输入保持不变,求导简单方便。
我总结了这些天学习的几个激活函数:激活函数总结

2.手算

在这里插入图片描述

3.模型代码

nn.sequential将一系列的层结构打包组合成一个新的结构
classifer包括了三个全连接层
在这里插入图片描述
1.Conv1 第一个卷积层
代码表示为: nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding)
彩色图像有rgb三个通道,所以输入通道数为3,输出通道数=卷积核的个数kernel_num
为了方便训练,卷积核个数只取一半,padding直接取2,训练效果与原本的影响不大

 nn.Conv2d(3,48,kernel_size=11,stride=4,padding=2)

2.Maxpool1 第一个池化层
代码表示为:nn.MaxPool2d(kernel_size,stride)

nn.MaxPool2d(kernel_size=3,stride=2)

3.后续代码

nn.Conv2d(48,128,kernel_size=5,padding=2)
nn.MaxPool2d(kernel_size=3,stride=2)
nn.Conv2d(128,192,kernel_size=3,padding=1)
nn.Conv2d(192,192,kernel_size=3,padding=1)
nn.Conv2d(192,128,kernel_size=3,padding=1)
nn.MaxPool2d(kernel_size=3,stride=2)

三、训练花分类课程代码

1.model.py

import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)

2.train.py

import os
import sys
import jsonimport torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdmfrom model import AlexNetdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=4, shuffle=True,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))# test_data_iter = iter(validate_loader)# test_image, test_label = test_data_iter.next()## def imshow(img):#     img = img / 2 + 0.5  # unnormalize#     npimg = img.numpy()#     plt.imshow(np.transpose(npimg, (1, 2, 0)))#     plt.show()## print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))# imshow(utils.make_grid(test_image))net = AlexNet(num_classes=5, init_weights=True)net.to(device)loss_function = nn.CrossEntropyLoss()# pata = list(net.parameters())optimizer = optim.Adam(net.parameters(), lr=0.0002)epochs = 10save_path = './AlexNet.pth'best_acc = 0.0train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()

3.predict.py

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model import AlexNetdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load imageimg_path = "1..jpg"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indictjson_path = './class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)with open(json_path, "r") as f:class_indict = json.load(f)# create modelmodel = AlexNet(num_classes=5).to(device)# load model weightsweights_path = "./AlexNet.pth"assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)model.load_state_dict(torch.load(weights_path))model.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))plt.show()if __name__ == '__main__':main()

1.jpg是郁金香
在这里插入图片描述
2.jpg是向日葵
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/586530.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java多线程<三>常见的多线程设计模式

多线程的设计模式 两阶段线程终止 park方法 interrupted() 会让他失效。 使用volatile关键字进行改写 单例模式 双锁检测 保护性暂停 实现1: package threadBase.model;/*** author: Zekun Fu* date: 2022/5/29 19:01* Description:* 保护性暂停,* …

打砖块,Android休闲小游戏开发

A. 项目描述 《打砖块》是一款经典的休闲小游戏 ,结合了经典的图形和音效,给玩家带来了轻松愉快的游戏体验。 该游戏操作简单易上手。玩家只需通过触摸屏幕控制底部的“拍子”左右移动,以反弹“小球” 击碎 顶部的砖块。玩家可以根据球的角度…

基于JAVA+SSM+VUE的前后端分离的大学竞赛管理系统

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取项目下载方式🍅 一、项目背景介绍: 随着互联网技术的快速…

青龙面板的安装

一、安装docker 首先,需要在服务器上安装docker。 没有服务器的可以使用虚拟机,或申请一台三丰云的免费云服务器体验一下,独立IP地址,送免备案服务,可以满足基本的使用,三丰云上还有免费虚拟主机等其他免费…

ES6之解构赋值详解

✨ 专栏介绍 在现代Web开发中,JavaScript已经成为了不可或缺的一部分。它不仅可以为网页增加交互性和动态性,还可以在后端开发中使用Node.js构建高效的服务器端应用程序。作为一种灵活且易学的脚本语言,JavaScript具有广泛的应用场景&#x…

Springboot整合MybatisPlus的基本CRUD

目录 前言1. 搭建项目2. 基本的CRUD 前言 发现项目框架是MybatisPlus的,由于个人使用该框架的CRUD比较少 对此学习过程中,从零到有开始搭建学习还是比较重要的,感悟会比较多 关于各个类的使用,可看如下文章: 剖析Ja…

Java—AOP案例-记录操作日志

简介:上一篇文章“JAVA语言—AOP基础”已经详细的介绍了AOP的各个功能接口,已经使用步骤,这篇文章就是基于此来做的一个小案例。案例的功能是记录登录的用户对于数据库表的相关信息进行增、删、查、改的操作记录下来,并且存储到数…

腾讯云轻量应用服务器详细介绍(全网超详细说明)

腾讯云轻量应用服务器开箱即用、运维简单的轻量级云服务器,CPU内存带宽配置高并且价格特别优惠,轻量2核2G3M带宽62元一年、2核2G4M优惠价118元一年,540元三年、2核4G5M带宽218元一年,756元3年、4核8G12M带宽646元15个月等&#xf…

微信小程序开发系列-08自定义组件模版特性

微信小程序开发系列目录 《微信小程序开发系列-01创建一个最小的小程序项目》《微信小程序开发系列-02注册小程序》《微信小程序开发系列-03全局配置中的“window”和“tabBar”》《微信小程序开发系列-04获取用户图像和昵称》《微信小程序开发系列-05登录小程序》《微信小程序…

点成案例 | 如何利用细胞计数仪在单细胞测序中评估细胞

一、概述 单细胞测序技术能够用来表征异常细胞群,分析稀有细胞和细胞图谱网络,发现异质性等。由于单细胞测序巨大的应用潜力,目前此技术正在经历爆炸性增长。然而,单细胞测序需要成本和时间的大量投资。为了确保时间和资源的投资…

正确的认识 字节码文件

上一篇中认识了JVM的基本组成,我们说JVM只认识字节码文件。那么在字节码文件进入JVM之前,我们先认识了解字节码文件长什么样,我们作为工程师不需要去死扣底层的理论知识,但是我们只是需要正确的打开字节码文件 知道里面有哪些部分…

[Angular] 笔记 22:ElementRef

chatgpt: ElementRef 是 Angular 中的一个类,它用于包装对 DOM 元素的引用。它允许开发者直接访问与 Angular 组件关联的宿主 DOM 元素。 当在 Angular 中需要直接操作 DOM 元素时,可以使用 ElementRef。通常情况下,最好避免直接操作 DOM&a…

Prism介绍

Prism介绍 Prism是一个框架,用于在WPF、Xamarin Forms、Uno Platform和WinUI中构建松散耦合、可维护和可测试的XAML应用程序。 设计目标 为了实现下列目的: 创建能够由模块组成的程序,这些模块能够被单独地编写、组装、部署,并且对…

十三:爬虫-Scrapy框架(下)

一:各文件的使用回顾 1.items的使用 items 文件主要用于定义储存爬取到的数据的数据结构,方便在爬虫和 Item Pipeline 之间传递数据。 items.pyimport scrapyclass TencentItem(scrapy.Item):# define the fields for your item here like:title scr…

jmeter函数助手-常用汇总

一.函数助手介绍 1.介绍及作用 介绍: jmeter自带的一个特性,可以通过指定的函数规则创建后进行调用该函数,在后续接口请求参数中进行调用 作用 (1)做参数化。 2.如何使用 jmeter工具栏-->工具-->函数助手…

LabVIEW在大型风电机组状态监测系统开发中的应用

LabVIEW在大型风电机组状态监测系统开发中的应用 风电作为一种清洁能源,近年来在全球范围内得到了广泛研究和开发。特别是大型风力发电机组,由于其常常位于边远地区如近海、戈壁、草原等,面临着恶劣自然环境和复杂设备运维挑战。为了提高风电…

DockerCompose - 容器编排、模板命令、compose命令、Pottainer 可视化界面管理(一文通关)

目录 一、DockerCompose 容器编排 1.1、简介 1.2、Docker-Compose 安装 1.2.1、在线安装 1.2.2、离线安装 1.3、docker-compose.yml 中的模板命令 前置说明 模板命令 1.4、DockerCompse 命令 前置说明 up down exec ps restart rm top pause暂停 和 unpause恢…

linux下的进程布局与ububtu操作系统下的proc文件夹学习笔记一

相关内容我写在公众号,写的挺详细的,欢迎关注我的公众号。请使用鼠标右键,新建标签页打开,直接点击显示参数错误,不知道怎么回事?linux下的进程布局与ububtu操作系统下的proc文件夹学习笔记 (qq.com)https:…

Windows下配置GCC(MinGW)环境

一、下载并安装MinGW 步骤1:下载MinGW安装器 前往MinGW的官方下载源,通过以下链接可以获取到最新版的MinGW安装程序: 网页地址:https://sourceforge.net/projects/mingw/files/ [MinGW 下载地址](https://sourceforge.net/proj…

二级路由的配置以及注意项

二级路由 比如说LayOut组件是父亲,LayOut和ArtComp是儿子,那我们怎么给儿子配路由呢? 1、首先在router下的index.js导入组件,配置规则,详细如下 // 导入路由相关组件 import LayOut from /views/LayOut import UserC…