3.AlexNet--CNN经典网络模型详解(pytorch实现)

        看博客AlexNet--CNN经典网络模型详解(pytorch实现)_alex的cnn-CSDN博客,该博客的作者写的很详细,是一个简单的目标分类的代码,可以通过该代码深入了解目标检测的简单框架。在这里不作详细的赘述,如果想更深入的了解,可以看另一个博客实现pytorch实现MobileNet-v2(CNN经典网络模型详解) - 知乎 (zhihu.com)。

在这里,直接写AlexNet--CNN的代码。

1.首先建立一个model.py文件,用来写神经网络,代码如下:

#model.pyimport torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()self.features = nn.Sequential(  #打包nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] 自动舍去小数点后nn.ReLU(inplace=True), #inplace 可以载入更大模型nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27] kernel_num为原论文一半nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),#全链接nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1) #展平   或者view()x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') #何教授方法if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)  #正态分布赋值nn.init.constant_(m.bias, 0)

2.下载数据集

DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

3.下载完后写一个spile_data.py文件,将数据集进行分类

#spile_data.pyimport os
from shutil import copy
import randomdef mkfile(file):if not os.path.exists(file):os.makedirs(file)file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]
mkfile('flower_data/train')
for cla in flower_class:mkfile('flower_data/train/'+cla)mkfile('flower_data/val')
for cla in flower_class:mkfile('flower_data/val/'+cla)split_rate = 0.1
for cla in flower_class:cla_path = file + '/' + cla + '/'images = os.listdir(cla_path)num = len(images)eval_index = random.sample(images, k=int(num*split_rate))for index, image in enumerate(images):if image in eval_index:image_path = cla_path + imagenew_path = 'flower_data/val/' + clacopy(image_path, new_path)else:image_path = cla_path + imagenew_path = 'flower_data/train/' + clacopy(image_path, new_path)print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing barprint()print("processing done!")

之后应该是这样:
在这里插入图片描述

4.再写一个train.py文件,用来训练模型

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time#device : GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)#数据转换data_transform = {#具体是对图像进行各种转换操作,并用函数compose将这些转换操作组合起来#以下操作步骤:# 1.图片随机裁剪为224X224# 2.随机水平旋转,默认为概率0.5# 3.将给定图像转为Tensor# 4.归一化处理"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.getcwd()
image_path = data_root + "/flower_data/"  # flower data set path
train_dataset = datasets.ImageFolder(root=image_path + "/train",transform=data_transform["train"])
train_num = len(train_dataset)# print(train_dataset)
# print(train_dataset[1][0].size())
# print(train_dataset.imgs[0][0])# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)
i=1
# for step, data in enumerate(train_loader, start=0):
#     images, labels = data
#     print(i,"==>",images.shape,labels.shape)
#     i+=1validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=True,num_workers=0)test_data_iter = iter(validate_loader)
test_image, test_label = test_data_iter.__next__()net = AlexNet(num_classes=5, init_weights=True)net.to(device)
#损失函数:这里用交叉熵
loss_function = nn.CrossEntropyLoss()
#优化器 这里用Adam
optimizer = optim.Adam(net.parameters(), lr=0.0002)
#训练参数保存路径
save_path = './AlexNet.pth'
#训练过程中最高准确率
best_acc = 0.0#开始进行训练和测试,训练一轮,测试一轮
for epoch in range(10):# trainnet.train()    #训练过程中,使用之前定义网络中的dropoutrunning_loss = 0.0t1 = time.perf_counter()for step, data in enumerate(train_loader, start=0):images, labels = data#print("step:     ",images.shape,labels)optimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()# print train processrate = (step + 1) / len(train_loader)a = "*" * int(rate * 50)b = "." * int((1 - rate) * 50)print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")print()print(time.perf_counter()-t1)# validatenet.eval()    #测试过程中不需要dropout,使用所有的神经元acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == val_labels.to(device)).sum().item()val_accurate = acc / val_numif val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, running_loss / step, val_accurate))print('Finished Training')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/1783.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何使用rdtsc和C/C++来测量运行时间(如何使用内联汇编和获取CPU的TSC时钟频率)

本文主要是一个实验和思维扩展,除非你有特殊用途,不然不要使用汇编指令来实现这个功能。扩展阅读就列出了一些不需要内联汇编实现的 写本文是因为为了《Windows上的类似clock_gettime(CLOCK_MONOTONIC)的高精度测量时间函数》这篇文章找资料的时候&…

不同版本vue安装vue-router

vue-router 是vue官网发布的一个插件库,单页面路由。vue 和 vue-router 之间版本也需要对应。 vue2.x版本使用vue-router3.x版本,vue3.x使用vue-router4.x版本,根据自己的需要选择合适的版本 1、可以在安装前查看vue-router版本,…

说一说什么是并发队列,并发队列和并发集合的区别是什么

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:并发队列和并发集合,以及他们的区别 并发队列是一种特殊的队列数据结构,它能够支持多个线程同时对队列进行操作,包括插入和删除操作,而不需要…

陈奂仁联手 The Sandbox 推出“Hamsterz Doodles”人物化身系列

全新人物化身系列结合艺术与实用性 开创元宇宙新篇章 著名亚洲唱作歌手兼香港电影金像奖得主陈奂仁携手 The Sandbox,兴奋地宣布推出新的元宇宙人物化身系列 —— Hamsterz Doodles 仓鼠涂鸦。 陈奂仁在 The Sandbox 推出 Hamsterz Doodles 系列,将艺术与…

波士顿动力抛弃液压机器人Atlas,推出全新电动化机器人,动作超灵活

本周,机器人科技巨头波士顿动力宣布液压Atlas退役,并推出了下一代产品——专为实际应用而设计的全电动Atlas机器人,这也意味着人形机器人迈出了商业化的第一步。 Atlas——人形机器人鼻祖 Atlas(阿特拉斯)这个名字最…

STM32F407,429参考手册(中文)

发布一个适用STM32F405XX、STM32F407XX、STM32F415XX、STM32F417XX、STM32F427XX、STM32F437XX的中文数据手册,具体内容见下图: 点击下载(提取码:spnn) 链接: https://pan.baidu.com/s/1zqjKFdSV8PnHAHWLYPGyUA 提取码…

计算请假时间,只包含工作时间,不包含中午午休和非工作时间及星期六星期天,结束时间不能小于开始时间

1.计算相差小时,没有休息时间 computed: {// 计算相差小时time() {let time 0;if (this.ruleForm.date1 &&this.ruleForm.date2 &&this.ruleForm.date3 &&this.ruleForm.date4) {// 开始时间let date1 this.ruleForm.date1;let y date…

[笔试训练](二)

004 牛牛的快递_牛客题霸_牛客网 (nowcoder.com) 题目&#xff1a; 题解&#xff1a; 使用向上取整函数ceil()&#xff0c;&#xff08;记得添加头文件#include<cmath>&#xff09; #include <iostream> #include <cmath> using namespace std;int main(…

【深度学习实战(15)】使用训练好的语义分割模型进行推理测试

一、语义分割推理测试的一般流程 前处理 &#xff08;1&#xff09;get image &#xff08;2&#xff09;letter_box&#xff1a;o_h&#xff0c;o_w&#xff0c;i_h&#xff0c;i_w&#xff0c;n_h&#xff0c;n_w &#xff08;3&#xff09;1/250&#xff0c;CHW&#xff0c…

Java中ArrayList和顺序表

目录 1.线性表 2.顺序表 3 ArrayList简介 4. ArrayList使用 4.1 ArrayList的构造 4.2 ArrayList常见操作 4.3 ArrayList的遍历 1.线性表 线性表 &#xff08; linear list &#xff09; 是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结 …

Linux-延迟任务and定时任务

一.在系统中设定延迟任务要求如下 在系统中建立easylee用户&#xff0c;设定其密码为easylee 延迟任务由root用户建立 要求在5小时后备份系统中的用户信息文件到/backup中 确保延迟任务是使用非交互模式建立 确保系统中只有root用户和easylee用户可以执行延迟任务的设定 二.在…

ArrayList与顺序表(1)

前言~&#x1f973;&#x1f389;&#x1f389;&#x1f389; hellohello~&#xff0c;大家好&#x1f495;&#x1f495;&#xff0c;这里是E绵绵呀✋✋ &#xff0c;如果觉得这篇文章还不错的话还请点赞❤️❤️收藏&#x1f49e; &#x1f49e; 关注&#x1f4a5;&#x…

【苍穹外卖】HttpClient-快速理解入门

目录 HttpClient-快速理解&入门1. 需求2. 如何使用3. 具体示例4. 大致优点5. 大致缺点 HttpClient-快速理解&入门 1. 需求 在平常访问服务器里面的资源的时候&#xff0c;我们通常是通过浏览器输入网址&#xff08;或者在浏览器点击某个连接&#xff09;这种方式&…

测试的分类(2)

目录 按照执行方式分类 静态测试 动态测试 按照测试方法 灰盒测试 按照测试阶段分类 单元测试 集成测试 系统测试 冒烟测试 回归测试 按照执行方式分类 静态测试 所谓静态测试就是不实际运行被测软件,只是静态地检查程序代码, 界面或文档中可能存在错误的过程. 不以…

ffmpeg安装使用(详细)

目录结构 前言ffmpeg下载ffmpeg环境变量配置ffmpeg环境变量配置验证ffmpeg使用举例说明.mp4 转 .wav.mp3 转 .wav.ogg 转 .wav 参考链接 前言 本文主要记录ffmpeg在Windows系统中的安装使用方法。 ffmpeg下载 FFmpeg官网下载 ffmpeg环境变量配置 解压后将“.\ffmpeg\bin”…

vscode 如何断点调试ros1工程

在vscode中断点调试ros1工程主要分为以下几步&#xff1a; 1. 第一步就是修改cmakelist.txt&#xff0c;到调试模式。 将CMAKE_BUILD_TYPE原来对应的代码注释掉&#xff0c;原来的一般都不是调试模式。加上下面一行代码&#xff0c;意思是设置调试模式。 # 断点调试 SET(CMAK…

Python 运行时的目录信息

摘要说明&#xff1a; 在 python 执行过程中&#xff0c;会涉及各种目录信息&#xff0c;了解各种目录的含义和获取方式&#xff0c;可以让我们更好地进行代码控制&#xff0c;并进行相应的处理。 一. 操作场景说明 1. 几个目录和文件 Windows的命令行窗口所在目录 &#xf…

【Linux】文件基本属性

Linux 系统是一种典型的多用户系统&#xff0c;不同的用户处于不同的地位&#xff0c;拥有不同的权限。 为了保护系统的安全性&#xff0c;Linux 系统对不同的用户访问同一文件&#xff08;包括目录文件&#xff09;的权限做了不同的规定。 在 Linux 中我们可以使用 ll 或者 …

python 中pandas安装教程

1.打开cmd 输入代码 conda info --envs #查看环境 conda activate daiyi_Python #进入环境 conda list #查看环境 conda install pandas #下载pandas 输入y即可 conda uninstall pandas #删除pandas 输入y即可

《你想活出怎样的人生》上映,AOC带你打开宫崎骏的动画世界大门!

摘要&#xff1a;宫崎骏式美学&#xff0c;每一帧都是治愈&#xff01; 近日&#xff0c;宫崎骏新作《你想活出怎样的人生》正式公映。苍鹭与少年的冒险、奇幻瑰丽的场景、爱与成长的主题&#xff0c;让观众们收获到满满的爱与感动。宫崎骏总能以细腻的画面、温柔的音乐&#…