3.AlexNet--CNN经典网络模型详解(pytorch实现)

        看博客AlexNet--CNN经典网络模型详解(pytorch实现)_alex的cnn-CSDN博客,该博客的作者写的很详细,是一个简单的目标分类的代码,可以通过该代码深入了解目标检测的简单框架。在这里不作详细的赘述,如果想更深入的了解,可以看另一个博客实现pytorch实现MobileNet-v2(CNN经典网络模型详解) - 知乎 (zhihu.com)。

在这里,直接写AlexNet--CNN的代码。

1.首先建立一个model.py文件,用来写神经网络,代码如下:

#model.pyimport torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()self.features = nn.Sequential(  #打包nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] 自动舍去小数点后nn.ReLU(inplace=True), #inplace 可以载入更大模型nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27] kernel_num为原论文一半nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),#全链接nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1) #展平   或者view()x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') #何教授方法if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)  #正态分布赋值nn.init.constant_(m.bias, 0)

2.下载数据集

DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

3.下载完后写一个spile_data.py文件,将数据集进行分类

#spile_data.pyimport os
from shutil import copy
import randomdef mkfile(file):if not os.path.exists(file):os.makedirs(file)file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]
mkfile('flower_data/train')
for cla in flower_class:mkfile('flower_data/train/'+cla)mkfile('flower_data/val')
for cla in flower_class:mkfile('flower_data/val/'+cla)split_rate = 0.1
for cla in flower_class:cla_path = file + '/' + cla + '/'images = os.listdir(cla_path)num = len(images)eval_index = random.sample(images, k=int(num*split_rate))for index, image in enumerate(images):if image in eval_index:image_path = cla_path + imagenew_path = 'flower_data/val/' + clacopy(image_path, new_path)else:image_path = cla_path + imagenew_path = 'flower_data/train/' + clacopy(image_path, new_path)print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing barprint()print("processing done!")

之后应该是这样:
在这里插入图片描述

4.再写一个train.py文件,用来训练模型

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time#device : GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)#数据转换data_transform = {#具体是对图像进行各种转换操作,并用函数compose将这些转换操作组合起来#以下操作步骤:# 1.图片随机裁剪为224X224# 2.随机水平旋转,默认为概率0.5# 3.将给定图像转为Tensor# 4.归一化处理"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.getcwd()
image_path = data_root + "/flower_data/"  # flower data set path
train_dataset = datasets.ImageFolder(root=image_path + "/train",transform=data_transform["train"])
train_num = len(train_dataset)# print(train_dataset)
# print(train_dataset[1][0].size())
# print(train_dataset.imgs[0][0])# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)
i=1
# for step, data in enumerate(train_loader, start=0):
#     images, labels = data
#     print(i,"==>",images.shape,labels.shape)
#     i+=1validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=True,num_workers=0)test_data_iter = iter(validate_loader)
test_image, test_label = test_data_iter.__next__()net = AlexNet(num_classes=5, init_weights=True)net.to(device)
#损失函数:这里用交叉熵
loss_function = nn.CrossEntropyLoss()
#优化器 这里用Adam
optimizer = optim.Adam(net.parameters(), lr=0.0002)
#训练参数保存路径
save_path = './AlexNet.pth'
#训练过程中最高准确率
best_acc = 0.0#开始进行训练和测试,训练一轮,测试一轮
for epoch in range(10):# trainnet.train()    #训练过程中,使用之前定义网络中的dropoutrunning_loss = 0.0t1 = time.perf_counter()for step, data in enumerate(train_loader, start=0):images, labels = data#print("step:     ",images.shape,labels)optimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()# print train processrate = (step + 1) / len(train_loader)a = "*" * int(rate * 50)b = "." * int((1 - rate) * 50)print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")print()print(time.perf_counter()-t1)# validatenet.eval()    #测试过程中不需要dropout,使用所有的神经元acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == val_labels.to(device)).sum().item()val_accurate = acc / val_numif val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, running_loss / step, val_accurate))print('Finished Training')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/1783.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何使用rdtsc和C/C++来测量运行时间(如何使用内联汇编和获取CPU的TSC时钟频率)

本文主要是一个实验和思维扩展,除非你有特殊用途,不然不要使用汇编指令来实现这个功能。扩展阅读就列出了一些不需要内联汇编实现的 写本文是因为为了《Windows上的类似clock_gettime(CLOCK_MONOTONIC)的高精度测量时间函数》这篇文章找资料的时候&…

不同版本vue安装vue-router

vue-router 是vue官网发布的一个插件库,单页面路由。vue 和 vue-router 之间版本也需要对应。 vue2.x版本使用vue-router3.x版本,vue3.x使用vue-router4.x版本,根据自己的需要选择合适的版本 1、可以在安装前查看vue-router版本,…

陈奂仁联手 The Sandbox 推出“Hamsterz Doodles”人物化身系列

全新人物化身系列结合艺术与实用性 开创元宇宙新篇章 著名亚洲唱作歌手兼香港电影金像奖得主陈奂仁携手 The Sandbox,兴奋地宣布推出新的元宇宙人物化身系列 —— Hamsterz Doodles 仓鼠涂鸦。 陈奂仁在 The Sandbox 推出 Hamsterz Doodles 系列,将艺术与…

波士顿动力抛弃液压机器人Atlas,推出全新电动化机器人,动作超灵活

本周,机器人科技巨头波士顿动力宣布液压Atlas退役,并推出了下一代产品——专为实际应用而设计的全电动Atlas机器人,这也意味着人形机器人迈出了商业化的第一步。 Atlas——人形机器人鼻祖 Atlas(阿特拉斯)这个名字最…

STM32F407,429参考手册(中文)

发布一个适用STM32F405XX、STM32F407XX、STM32F415XX、STM32F417XX、STM32F427XX、STM32F437XX的中文数据手册,具体内容见下图: 点击下载(提取码:spnn) 链接: https://pan.baidu.com/s/1zqjKFdSV8PnHAHWLYPGyUA 提取码…

计算请假时间,只包含工作时间,不包含中午午休和非工作时间及星期六星期天,结束时间不能小于开始时间

1.计算相差小时,没有休息时间 computed: {// 计算相差小时time() {let time 0;if (this.ruleForm.date1 &&this.ruleForm.date2 &&this.ruleForm.date3 &&this.ruleForm.date4) {// 开始时间let date1 this.ruleForm.date1;let y date…

[笔试训练](二)

004 牛牛的快递_牛客题霸_牛客网 (nowcoder.com) 题目&#xff1a; 题解&#xff1a; 使用向上取整函数ceil()&#xff0c;&#xff08;记得添加头文件#include<cmath>&#xff09; #include <iostream> #include <cmath> using namespace std;int main(…

ArrayList与顺序表(1)

前言~&#x1f973;&#x1f389;&#x1f389;&#x1f389; hellohello~&#xff0c;大家好&#x1f495;&#x1f495;&#xff0c;这里是E绵绵呀✋✋ &#xff0c;如果觉得这篇文章还不错的话还请点赞❤️❤️收藏&#x1f49e; &#x1f49e; 关注&#x1f4a5;&#x…

【苍穹外卖】HttpClient-快速理解入门

目录 HttpClient-快速理解&入门1. 需求2. 如何使用3. 具体示例4. 大致优点5. 大致缺点 HttpClient-快速理解&入门 1. 需求 在平常访问服务器里面的资源的时候&#xff0c;我们通常是通过浏览器输入网址&#xff08;或者在浏览器点击某个连接&#xff09;这种方式&…

测试的分类(2)

目录 按照执行方式分类 静态测试 动态测试 按照测试方法 灰盒测试 按照测试阶段分类 单元测试 集成测试 系统测试 冒烟测试 回归测试 按照执行方式分类 静态测试 所谓静态测试就是不实际运行被测软件,只是静态地检查程序代码, 界面或文档中可能存在错误的过程. 不以…

ffmpeg安装使用(详细)

目录结构 前言ffmpeg下载ffmpeg环境变量配置ffmpeg环境变量配置验证ffmpeg使用举例说明.mp4 转 .wav.mp3 转 .wav.ogg 转 .wav 参考链接 前言 本文主要记录ffmpeg在Windows系统中的安装使用方法。 ffmpeg下载 FFmpeg官网下载 ffmpeg环境变量配置 解压后将“.\ffmpeg\bin”…

vscode 如何断点调试ros1工程

在vscode中断点调试ros1工程主要分为以下几步&#xff1a; 1. 第一步就是修改cmakelist.txt&#xff0c;到调试模式。 将CMAKE_BUILD_TYPE原来对应的代码注释掉&#xff0c;原来的一般都不是调试模式。加上下面一行代码&#xff0c;意思是设置调试模式。 # 断点调试 SET(CMAK…

【Linux】文件基本属性

Linux 系统是一种典型的多用户系统&#xff0c;不同的用户处于不同的地位&#xff0c;拥有不同的权限。 为了保护系统的安全性&#xff0c;Linux 系统对不同的用户访问同一文件&#xff08;包括目录文件&#xff09;的权限做了不同的规定。 在 Linux 中我们可以使用 ll 或者 …

《你想活出怎样的人生》上映,AOC带你打开宫崎骏的动画世界大门!

摘要&#xff1a;宫崎骏式美学&#xff0c;每一帧都是治愈&#xff01; 近日&#xff0c;宫崎骏新作《你想活出怎样的人生》正式公映。苍鹭与少年的冒险、奇幻瑰丽的场景、爱与成长的主题&#xff0c;让观众们收获到满满的爱与感动。宫崎骏总能以细腻的画面、温柔的音乐&#…

设计模式——模板方法

1)模板方法模式(Template Method Pattem)&#xff0c;又叫模板模式(Template Patern)&#xff0c;在一个抽象类公开定义了执行它的方法的模板。它的子类可以按需要重写方法实现&#xff0c;但调用将以抽象类中定义的方式进行。 2)简单说&#xff0c;模板方法模式 定义一个操作中…

ETLCloud中多并行分支运行的设计技巧

在大数据处理领域&#xff0c;ETL&#xff08;Extract, Transform, Load&#xff09;流程是至关重要的一环&#xff0c;它涉及数据的提取、转换和加载&#xff0c;以确保数据的质量和可用性。而在ETL流程中&#xff0c;多并行分支的运行设计是一项关键技巧&#xff0c;可以有效…

华为ensp中MSTP多网段传输协议(原理及配置命令)

作者主页&#xff1a;点击&#xff01; ENSP专栏&#xff1a;点击&#xff01; 创作时间&#xff1a;2024年4月22日15点29分 在华为ENSP中&#xff0c;MSTP&#xff08;多段传输协议&#xff09;是重要的生成树协议&#xff0c;它扩展了STP&#xff08;生成树协议&#xff09…

猴子摘桃问题(C语言)

一、N-S流程图&#xff1b; 二、运行结果&#xff1b; 三、源代码&#xff1b; # define _CRT_SECURE_NO_WARNINGS # include <stdio.h>int main() {//初始化变量值&#xff1b;int sum 1;int i 0;//运算&#xff1b;for (i 1; i < 10; i){//运算&#xff1b;sum …

typecho博客的相对地址实现

typecho其中的博客地址,必须写上绝对地址,否则在迁移网址的时候会出现问题,例如页面记载异常 修改其中的 typecho\var\Widget\Options\General.php 中的165行左右, /** 站点地址 */if (!defined(__TYPECHO_SITE_URL__)) {$siteUrl new Form\Element\Text(siteUrl,null,$this-…

怎么把3d模型旋转加复制---模大狮模型网

在3D设计中&#xff0c;旋转和复制模型是常见且重要的操作&#xff0c;它们可以帮助设计师创建复杂的场景并节省时间。本文将介绍如何在3D建模软件中旋转并复制模型&#xff0c;以及一些技巧和注意事项&#xff0c;帮助您轻松实现这些操作。 旋转3D模型&#xff1a; 旋转3D模型…