基于Pytorch框架的深度学习densenet121神经网络鸟类行为识别分类系统源码

 第一步:准备数据

5种鸟类行为数据:self.class_indict =

 ["bowing_status", "grooming", "headdown", "vigilance_status", "walking"]

,总共有23790张图片,每个文件夹单独放一种数据

第二步:搭建模型

简介:

DenseNet(Dense Convolutional Network)稠密卷积网络
CVPR2017的优秀文章
从feature入手,通过对feature的极致利用达到更好的效果和更少的参数。


优点:

减轻了vanishing-gradient(梯度消失)
加强了feature的传递
更有效地利用了feature
一定程度上较少了参数数量


在深度学习网络中,随着网络深度的加深,梯度消失问题会愈加明显,解决方法是创建浅层与深层之间的短路径。在DenseNet中,在保证网络中层与层之间最大程度的信息传输的前提下,直接将所有层连接起来。

在传统卷积神经网络中,如果你有L层,那么就会有L个连接,但是在DenseNet中,会有(L+1)/2个连接。简单来说,就是每一层的输入来自前面所有层的输出。如下图是dense block的结构图,x是数据,H是网络层。

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_schedulerfrom model import densenet121, load_state_dict
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")print(args)print('Start Tensorboard with "cd", view at http://localhost:6006/')tb_writer = SummaryWriter()if os.path.exists("./weights") is False:os.makedirs("./weights")train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)# 如果存在预训练权重则载入model = densenet121(num_classes=args.num_classes).to(device)if args.weights != "":if os.path.exists(args.weights):load_state_dict(model, args.weights)else:raise FileNotFoundError("not found weights file: {}".format(args.weights))# 是否冻结权重if args.freeze_layers:for name, para in model.named_parameters():# 除最后的全连接层外,其他权重全部冻结if "classifier" not in name:para.requires_grad_(False)pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4, nesterov=True)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# trainmean_loss = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateacc = evaluate(model=model,data_loader=val_loader,device=device)print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))tags = ["loss", "accuracy", "learning_rate"]tb_writer.add_scalar(tags[0], mean_loss, epoch)tb_writer.add_scalar(tags[1], acc, epoch)tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=5)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=16)parser.add_argument('--lr', type=float, default=0.001)parser.add_argument('--lrf', type=float, default=0.1)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"E:\20240717\data")# densenet121 官方权重下载地址# https://download.pytorch.org/models/densenet121-a639ec97.pthparser.add_argument('--weights', type=str, default='densenet121.pth',help='initial weights path')parser.add_argument('--freeze-layers', type=bool, default=False)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

视频演示地址:基于Pytorch框架的深度学习densenet121神经网络鸟类行为识别分类系统源码_哔哩哔哩_bilibili

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码见:基于Pytorch框架的深度学习densenet121神经网络鸟类行为识别分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/50813.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode 2824. 统计和小于目标的下标对数目

2824. 统计和小于目标的下标对数目 2824. 统计和小于目标的下标对数目 一、题目描述二、我的想法 一、题目描述 给你一个下标从 0 开始长度为 n 的整数数组 nums 和一个整数 target &#xff0c;请你返回满足 0 < i < j < n 且 nums[i] nums[j] < target 的下标对…

从零搭建pytorch模型教程(八)实践部分(二)目标检测数据集格式转换

前言 图像目标检测领域有一个非常著名的数据集叫做COCO&#xff0c;基本上现在在目标检测领域发论文&#xff0c;COCO是不可能绕过的Benchmark。因此许多的开源目标检测算法框架都会支持解析COCO数据集格式。通过将其他数据集格式转换成COCO格式可以无痛的使用这些开源框架来训…

【计算机网络】静态路由实验

一&#xff1a;实验目的 1&#xff1a;掌握通过静态路由方法实现网络的连通性。 二&#xff1a;实验仪器设备及软件 硬件&#xff1a;RCMS-C服务器、网线、Windows 2019/2003操作系统的计算机等。 软件&#xff1a;记事本、WireShark、Chrome浏览器等。 三&#xff1a;实验方…

Spring集成ES

RestAPI ES官方提供的java语言客户端用以组装DSL语句,再通过http请求发送给ES RestClient初始化 引入依赖 <dependency><groupId>org.elasticsearch.client</groupId><artifactId>elasticsearch-rest-high-level-client</artifactId> </d…

《分析模式:可重用对象模型》学习笔记之四:企业财务分析中的观察和测量02

这个模型基本解决问题&#xff0c;可以方便定义层次&#xff0c;以及反映了三个不同的维数元素&#xff0c;也反映了企业部门单元和维数元素的关系&#xff0c;但是很快可以看到&#xff0c;在这里&#xff0c;维数被局限在三个&#xff1a;也就是说&#xff0c;如果维数需要改…

ROS2教程(10) - 编写接收程序、添加frame - Linux

注意 : 本篇文章接上节 (点击此处跳转到上节) 编写接收程序 cpp <the_work_ws>/src/learning_tf2_cpp/src/turtle_tf2_listener.cpp #include <chrono> #include <functional> #include <memory> #include <string>#include "geometry_…

【c++】多线程

多线程可以解决什么问题&#xff0c;最重要的用途是什么&#xff1f; 多线程技术在现代软件开发中扮演着至关重要的角色&#xff0c;它可以解决多种问题并带来显著的好处。以下是多线程最重要的几个用途&#xff1a; 资源利用最大化: 多线程可以充分利用多核处理器的能力&…

#如何在PDF文件中添加图片和文本框?

在PDF文件中添加图片 可以通过多种方法实现&#xff0c;以下是一些常用的方法&#xff1a; 一、使用PDF编辑器 下载并安装PDF编辑器&#xff1a;首先&#xff0c;需要在官网或可靠来源下载并安装一个PDF编辑器&#xff0c;如福昕PDF编辑器、Adobe Acrobat等。打开PDF文件&am…

静止轨道卫星大气校正(Atmospheric Correction)和BRDF校正

文章内容仅用于自己知识学习和分享&#xff0c;如有侵权&#xff0c;还请联系并删除 &#xff1a;&#xff09; 目的&#xff1a; TOA reflectance 转为 surface refletance。 主要包含两步&#xff1a; 1&#xff09;大气校正&#xff1b; 2&#xff09;BRDF校正 进度&#x…

抖音矩阵管理系统开发:全面解析与推荐

在数字时代&#xff0c;短视频平台如抖音已经成为人们生活中不可或缺的一部分。随着内容创作者数量的激增&#xff0c;如何高效地管理多个抖音账号&#xff0c;实现内容矩阵化运营&#xff0c;成为了众多创作者关注的焦点。今天&#xff0c;我们就来全面解析抖音矩阵管理系统的…

Java_如何在IDEA中使用Git

注意&#xff1a;进行操作前首先要确保已经下载git&#xff0c;在IDEA中可以下载git&#xff0c;但是速度很慢&#xff0c;可以挂梯子下载。 导入git仓库代码 第一次导入&#xff1a; 首先得到要加载的git仓库的url&#xff1a; 在git仓库中点击 “克隆/下载” 按钮&#xf…

SpringBoot教程(十七) | SpringBoot集成swagger

SpringBoot教程&#xff08;十七&#xff09; | SpringBoot集成swagger 一、Swagger的简述二、SpringBoot集成swagger21. 引入依赖2. 新建SwaggerConfig配置类当 SpringBoot为2.6.x及以上时 需要注意 3.配置Swagger开关4. 给Controller 添加注解&#xff08;正式使用&#xff0…

PCIe 以太网芯片 RTL8125B 的 spec 和 Linux driver 分析备忘

1,下载 RTL8125B driver 下载页&#xff1a; https://www.realtek.com/Download/List?cate_id584 2,RTL8125B datasheet下载 下载页&#xff1a; https://file.elecfans.com/web2/M00/44/D8/poYBAGKHVriAHnfWADAT6T6hjVk715.pdf3, 编译driver 解压&#xff1a; $ tar xj…

鸿蒙OpenHarmony Native API【drawing_color.h与drawing_font_collection.h】 头文件

drawing_color.h Overview Related Modules: [Drawing] Description: 文件中定义了与颜色相关的功能函数 Since: 8 Version: 1.0 Summary Functions FunctionDescription[OH_Drawing_ColorSetArgb] (uint32_t alpha, uint32_t red, uint32_t green, uint32_t blue)u…

webrtc Android源码分析一

nativeCreateVideoSource 初始化 PeerConnectionFactory(pc/peerconnectionfactory) 创建PeerConnection方法中: rtc::scoped_refptr<PeerConnectionInterface> PeerConnectionFactory::CreatePeerConnection(const PeerConnectionInterface::RTCConfiguration& c…

机器学习第四十九周周报 GT

文章目录 week49 GY摘要Abstract1. 题目2. Abstract3. 网络结构3.1 graphon3.2 框架概览 4. 文献解读4.1 Introduction4.2 创新点4.3 实验过程4.3.1 有效性4.3.2 可转移性4.3.3 消融研究4.3.4 运行时间 5. 结论6.代码复现小结参考文献 week49 GY 摘要 本周阅读了题为Fine-tun…

46、PHP实现矩阵中的路径

题目&#xff1a; PHP实现矩阵中的路径 描述&#xff1a; 请设计一个函数&#xff0c;用来判断在一个矩阵中是否存在一条包含某字符串所有字符的路径。 路径可以从矩阵中的任意一个格子开始&#xff0c;每一步可以在矩阵中向左&#xff0c;向右&#xff0c;向上&#xff0c;向…

几个小创新模型,Transformer与SVM、LSTM、BiLSTM、Adaboost的结合,MATLAB分类全家桶再更新!...

截止到本期MATLAB机器学习分类全家桶&#xff0c;一共发了5篇&#xff0c;参考文章如下&#xff1a; 1.机器学习分类全家桶&#xff0c;模式识别&#xff0c;故障诊断的看这一篇绝对够了&#xff01;MATLAB代码 2. 再更新&#xff0c;机器学习分类全家桶&#xff0c;模式识别&a…

【四】jdk8基于m2芯片arm架构Ubuntu24虚拟机下载与安装

文章目录 1. 安装版本2. 开始安装3. 集群安装 1. 安装版本 如无特别说明&#xff0c;本文均在root权限下安装。进入oracle官网&#xff1a;https://www.oracle.com/java/technologies/downloads/找到最下面Java SE 看到java 8&#xff0c;下载使用 ARM64 Compressed Archive版…

vue3+vite纯前端实现自动触发浏览器刷新更新版本内容,并在打包时生成版本号文件

前言 在前端项目中&#xff0c;有时候为了实现自动触发浏览器刷新并更新版本内容&#xff0c;可以采取一系列巧妙的措施。我的项目中是需要在打包时候生成一个version.js文件&#xff0c;用当前打包时间作为版本的唯一标识&#xff0c;然后打包发版 &#xff0c;从实现对版本更…