基于Pytorch框架的深度学习densenet121神经网络鸟类行为识别分类系统源码

 第一步:准备数据

5种鸟类行为数据:self.class_indict =

 ["bowing_status", "grooming", "headdown", "vigilance_status", "walking"]

,总共有23790张图片,每个文件夹单独放一种数据

第二步:搭建模型

简介:

DenseNet(Dense Convolutional Network)稠密卷积网络
CVPR2017的优秀文章
从feature入手,通过对feature的极致利用达到更好的效果和更少的参数。


优点:

减轻了vanishing-gradient(梯度消失)
加强了feature的传递
更有效地利用了feature
一定程度上较少了参数数量


在深度学习网络中,随着网络深度的加深,梯度消失问题会愈加明显,解决方法是创建浅层与深层之间的短路径。在DenseNet中,在保证网络中层与层之间最大程度的信息传输的前提下,直接将所有层连接起来。

在传统卷积神经网络中,如果你有L层,那么就会有L个连接,但是在DenseNet中,会有(L+1)/2个连接。简单来说,就是每一层的输入来自前面所有层的输出。如下图是dense block的结构图,x是数据,H是网络层。

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_schedulerfrom model import densenet121, load_state_dict
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")print(args)print('Start Tensorboard with "cd", view at http://localhost:6006/')tb_writer = SummaryWriter()if os.path.exists("./weights") is False:os.makedirs("./weights")train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)# 如果存在预训练权重则载入model = densenet121(num_classes=args.num_classes).to(device)if args.weights != "":if os.path.exists(args.weights):load_state_dict(model, args.weights)else:raise FileNotFoundError("not found weights file: {}".format(args.weights))# 是否冻结权重if args.freeze_layers:for name, para in model.named_parameters():# 除最后的全连接层外,其他权重全部冻结if "classifier" not in name:para.requires_grad_(False)pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4, nesterov=True)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# trainmean_loss = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateacc = evaluate(model=model,data_loader=val_loader,device=device)print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))tags = ["loss", "accuracy", "learning_rate"]tb_writer.add_scalar(tags[0], mean_loss, epoch)tb_writer.add_scalar(tags[1], acc, epoch)tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=5)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=16)parser.add_argument('--lr', type=float, default=0.001)parser.add_argument('--lrf', type=float, default=0.1)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"E:\20240717\data")# densenet121 官方权重下载地址# https://download.pytorch.org/models/densenet121-a639ec97.pthparser.add_argument('--weights', type=str, default='densenet121.pth',help='initial weights path')parser.add_argument('--freeze-layers', type=bool, default=False)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

视频演示地址:基于Pytorch框架的深度学习densenet121神经网络鸟类行为识别分类系统源码_哔哩哔哩_bilibili

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码见:基于Pytorch框架的深度学习densenet121神经网络鸟类行为识别分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/50813.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从零搭建pytorch模型教程(八)实践部分(二)目标检测数据集格式转换

前言 图像目标检测领域有一个非常著名的数据集叫做COCO,基本上现在在目标检测领域发论文,COCO是不可能绕过的Benchmark。因此许多的开源目标检测算法框架都会支持解析COCO数据集格式。通过将其他数据集格式转换成COCO格式可以无痛的使用这些开源框架来训…

【计算机网络】静态路由实验

一:实验目的 1:掌握通过静态路由方法实现网络的连通性。 二:实验仪器设备及软件 硬件:RCMS-C服务器、网线、Windows 2019/2003操作系统的计算机等。 软件:记事本、WireShark、Chrome浏览器等。 三:实验方…

《分析模式:可重用对象模型》学习笔记之四:企业财务分析中的观察和测量02

这个模型基本解决问题,可以方便定义层次,以及反映了三个不同的维数元素,也反映了企业部门单元和维数元素的关系,但是很快可以看到,在这里,维数被局限在三个:也就是说,如果维数需要改…

静止轨道卫星大气校正(Atmospheric Correction)和BRDF校正

文章内容仅用于自己知识学习和分享,如有侵权,还请联系并删除 :) 目的: TOA reflectance 转为 surface refletance。 主要包含两步: 1)大气校正; 2)BRDF校正 进度&#x…

抖音矩阵管理系统开发:全面解析与推荐

在数字时代,短视频平台如抖音已经成为人们生活中不可或缺的一部分。随着内容创作者数量的激增,如何高效地管理多个抖音账号,实现内容矩阵化运营,成为了众多创作者关注的焦点。今天,我们就来全面解析抖音矩阵管理系统的…

Java_如何在IDEA中使用Git

注意:进行操作前首先要确保已经下载git,在IDEA中可以下载git,但是速度很慢,可以挂梯子下载。 导入git仓库代码 第一次导入: 首先得到要加载的git仓库的url: 在git仓库中点击 “克隆/下载” 按钮&#xf…

SpringBoot教程(十七) | SpringBoot集成swagger

SpringBoot教程(十七) | SpringBoot集成swagger 一、Swagger的简述二、SpringBoot集成swagger21. 引入依赖2. 新建SwaggerConfig配置类当 SpringBoot为2.6.x及以上时 需要注意 3.配置Swagger开关4. 给Controller 添加注解(正式使用&#xff0…

PCIe 以太网芯片 RTL8125B 的 spec 和 Linux driver 分析备忘

1,下载 RTL8125B driver 下载页: https://www.realtek.com/Download/List?cate_id584 2,RTL8125B datasheet下载 下载页: https://file.elecfans.com/web2/M00/44/D8/poYBAGKHVriAHnfWADAT6T6hjVk715.pdf3, 编译driver 解压: $ tar xj…

鸿蒙OpenHarmony Native API【drawing_color.h与drawing_font_collection.h】 头文件

drawing_color.h Overview Related Modules: [Drawing] Description: 文件中定义了与颜色相关的功能函数 Since: 8 Version: 1.0 Summary Functions FunctionDescription[OH_Drawing_ColorSetArgb] (uint32_t alpha, uint32_t red, uint32_t green, uint32_t blue)u…

机器学习第四十九周周报 GT

文章目录 week49 GY摘要Abstract1. 题目2. Abstract3. 网络结构3.1 graphon3.2 框架概览 4. 文献解读4.1 Introduction4.2 创新点4.3 实验过程4.3.1 有效性4.3.2 可转移性4.3.3 消融研究4.3.4 运行时间 5. 结论6.代码复现小结参考文献 week49 GY 摘要 本周阅读了题为Fine-tun…

几个小创新模型,Transformer与SVM、LSTM、BiLSTM、Adaboost的结合,MATLAB分类全家桶再更新!...

截止到本期MATLAB机器学习分类全家桶,一共发了5篇,参考文章如下: 1.机器学习分类全家桶,模式识别,故障诊断的看这一篇绝对够了!MATLAB代码 2. 再更新,机器学习分类全家桶,模式识别&a…

【四】jdk8基于m2芯片arm架构Ubuntu24虚拟机下载与安装

文章目录 1. 安装版本2. 开始安装3. 集群安装 1. 安装版本 如无特别说明,本文均在root权限下安装。进入oracle官网:https://www.oracle.com/java/technologies/downloads/找到最下面Java SE 看到java 8,下载使用 ARM64 Compressed Archive版…

vue3+vite纯前端实现自动触发浏览器刷新更新版本内容,并在打包时生成版本号文件

前言 在前端项目中,有时候为了实现自动触发浏览器刷新并更新版本内容,可以采取一系列巧妙的措施。我的项目中是需要在打包时候生成一个version.js文件,用当前打包时间作为版本的唯一标识,然后打包发版 ,从实现对版本更…

五大设备制造商的 200 多种机型的安全启动功能完全失效

2012 年,一个由硬件和软件制造商组成的行业联盟采用了安全启动技术,以防范长期存在的安全威胁。这种威胁是恶意软件的幽灵,它可以感染 BIOS,即每次计算机启动时加载操作系统的固件。从那里,它可以保持不受检测和删除&a…

从零开始学Java(超详细韩顺平老师笔记梳理)08——面向对象编程中级(上)IDEA常用快捷键、包、封装、继承

文章目录 前言一、IDEA使用常用快捷键模板/自定义模板 二、包package1. 基本介绍2. 包的命名规范3. 常用的包和如何引入4. 注意事项和细节 三、访问修饰符(四类)四、封装Encapsulation(重点)1. 封装介绍2. 封装步骤3. 快速入门4. …

SpringCloud Nacos的配置与使用

Spring Cloud Nacos的配置与使用 文章目录 Spring Cloud Nacos的配置与使用1. 简单介绍2. 环境搭建3. 服务注册/服务发现4. Nacos 负载均衡4.1 服务下线4.2 权重配置4.3 同集群优先访问 5. Nacos 健康检查5.1 两种健康检查机制5.2 服务实例类型 6.Nacos 环境隔离6.1 创建namesp…

【MySQL进阶之路 | 高级篇】表级锁之S锁,X锁,意向锁

1. 从数据操作的粒度划分:表级锁,页级锁,行锁 为了尽可能提高数据库的并发度,每次锁定的数据范围越小越好,理论上每次只锁定当前操作的数据的方案会得到最大的并发度,但是管理锁是很耗资源的事情&#xff…

驾驭代码的无形疆界:动态内存管理揭秘

目录 1.:为什么要有动态内存分配 2.malloc和free 2.1:malloc 2.2:free 3.calloc和realloc 3.1:calloc 3.1.1:代码1(malloc) 3.1.2:代码2(calloc) 3.2:realloc 3.2.1:原地扩容 3.2.2:异地扩容 3.2.3:代码1(原地扩容) 3.2.3:代码2(异地扩容) 4:常见的动态内存的错误…

vite + xlsx + xlsx-style 导出 Excel

如下 npm i 依赖 npm i xlsxnpm i xlsx-style-vite1、简单的使用:.vue文件中使用 const dataSource ref([]) // 数据源const columns [{title: 用户名,key: userName,width: 120,},{title: 用户组,key: userGroup,width: 120,},{title: 状态,key: enable,width: …

鸿蒙(HarmonyOS)下拉选择控件

一、操作环境 操作系统: Windows 11 专业版、IDE:DevEco Studio 3.1.1 Release、SDK:HarmonyOS 3.1.0(API 9) 二、效果图 三、代码 SelectPVComponent.ets Component export default struct SelectPVComponent {Link selection: SelectOption[]priva…