基于安卓的虫害识别软件设计--(1)模型训练与可视化

引言

  • 简介:使用pytorch框架,从模型训练、模型部署完整地实现了一个基础的图像识别项目
  • 计算资源:使用的是Kaggle(每周免费30h的GPU)

1.创建名为“utils_1”的模块

模块中包含:训练和验证的加载器函数训练函数验证函数

import os
import sysimport torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdmdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def get_train_loader(image_path):train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform = train_transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32,shuffle=True, num_workers= 0)return train_loaderdef get_val_loader(image_path):val_transform = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])val_dataset = datasets.ImageFolder(root=os.path.join(image_path, "validation"),transform = val_transform)val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=32,shuffle = False, num_workers = 0)return val_loaderdef train(train_loader,net):net.train()train_correct = 0.0train_loss = 0.0  # 初始化训练损失train_bar = tqdm(train_loader, file=sys.stdout)loss_function = nn.CrossEntropyLoss()loss_function = loss_function.to(device)optimizer = optim.Adam(net.parameters(), lr=0.001)for step, data in enumerate(train_bar):images, labels = dataimages, labels = images.to(device),labels.to(device)# 梯度清零optimizer.zero_grad()# 训练outputs = net(images)# 计算损失loss = loss_function(outputs, labels)# 反向传播loss.backward()# 更新权重optimizer.step()# 统计_, preds = outputs.max(1)correct = preds.eq(labels).sum()train_correct += correcttrain_loss += loss.item()  # 累加损失值train_bar.desc = 'Training Epoch:[{trained_samples}/{total_samples}]\t Loss: {:0.4f}\t Accuracy: {:0.4f}\t'.format(loss.item(),(100. * correct) / len(outputs),trained_samples=step * train_loader.batch_size + len(images),total_samples=len(train_loader.dataset))train_correct = (100. * train_correct) / len(train_loader.dataset)train_loss /= len(train_loader)  # 计算平均损失值return train_correct, train_loss  # 返回训练正确率和平均损失值def val(val_loader,net):net.eval()val_correct = 0.0val_loss = 0.0  # 初始化验证损失loss_function = nn.CrossEntropyLoss()loss_function = loss_function.to(device)val_bar = tqdm(val_loader, file=sys.stdout)for step, data in enumerate(val_bar):images, labels = dataimages, labels = images.to(device), labels.to(device)with torch.no_grad():# 验证outputs = net(images)# 计算损失loss = loss_function(outputs, labels)# 统计_, preds = outputs.max(1)correct = preds.eq(labels).sum()val_correct += correctval_loss += loss.item()  # 累加损失值val_bar.desc = 'Valing Epoch:[{trained_samples}/{total_samples}]\t Loss: {:0.4f}\t Accuracy: {:0.4f}\t'.format(loss.item(),(100. * correct) / len(outputs),trained_samples=step * val_loader.batch_size + len(images),total_samples=len(val_loader.dataset))val_correct = (100. * val_correct) / len(val_loader.dataset)val_loss /= len(val_loader)  # 计算平均损失值return val_correct , val_loss  # 返回验证正确率和平均损失值

注意:若使用Kaggle,想要导入该模块,需要添加以下代码

import sys
sys.path.append(r'/kaggle/input/mycode2')

其中,模块路径如下图


2.主函数 

主函数包含:使用模型函数训练主函数画图代码

2.1使用模型函数 

【若使用其他模型,可chatgpt创建其函数】

(1)resnet101 

def get_resnet101(class_num):net_name = "resnet101"net = torchvision.models.resnet101(pretrained=True)net.fc = Linear(in_features=2048, out_features=class_num, bias=True)  # ResNet101's fully connected layer expects 2048 input featuresnet = net.to(device)return net_name, net

(2)resnet34 

def get_resnet34(class_num):net_name = "resnet34"net = torchvision.models.resnet34(pretrained=True)net.fc = Linear(in_features=512, out_features=class_num, bias=True)net = net.to(device)return net_name,net

(3)mobilenetv2

def get_mobilenet_v2(class_num):net_name = "mobilenet_v2"net = torchvision.models.mobilenet_v2(pretrained=True)net.classifier[1] = Linear(in_features=1280, out_features=class_num, bias=True)net = net.to(device)return net_name,net

 2.2画图代码 

    save_path="/kaggle/working/"  plt.figure(figsize=(12, 4))# lossplt.subplot(1, 2, 1)plt.plot(range(1, epochs + 1), train_losses, "r-",label='Train loss')plt.plot(range(1, epochs + 1), val_losses, "b-",label='Val loss')plt.legend()plt.xlabel('Epoch')plt.ylabel('Loss')# accplt.subplot(1, 2, 2)plt.plot(range(1, epochs + 1), train_accs,"r-", label='Train acc')plt.plot(range(1, epochs + 1), val_accs,"b-" ,label='Val acc')plt.legend()plt.xlabel('Epoch')plt.ylabel('Acc')plt.legend()plt.savefig(os.path.join(save_path, 'result.png')) # 保存plt.show()

2.3完整代码 

import torch
import torchvision.models
from matplotlib import pyplot as plt
from torch.nn import Linear
import os# 导入自己创建的模块
from utils_1 import get_train_loader, train, val, get_val_loader# 模型选择
def get_resnet101(class_num):net_name = "resnet101"net = torchvision.models.resnet101(pretrained=True)net.fc = Linear(in_features=2048, out_features=class_num, bias=True)  # ResNet101's fully connected layer expects 2048 input featuresnet = net.to(device)return net_name, net# def get_resnet34(class_num):
#     net_name = "resnet34"
#     net = torchvision.models.resnet34(pretrained=True)
#     net.fc = Linear(in_features=512, out_features=class_num, bias=True)
#     net = net.to(device)
#     return net_name,net# def get_mobilenet_v2(class_num):
#     net_name = "mobilenet_v2"
#     net = torchvision.models.mobilenet_v2(pretrained=True)
#     net.classifier[1] = Linear(in_features=1280, out_features=class_num, bias=True)
#     net = net.to(device)
#     return net_name,net# 训练主函数
if __name__ == '__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#1 加载数据image_path = r"/kaggle/input/fruits3"train_loader = get_train_loader(image_path)val_loader = get_val_loader(image_path)#2 加载模型net_name,net = get_resnet34(class_num=5)#3 训练epochs = 5best_acc = 0train_losses = []val_losses = []train_accs = []val_accs = []for epoch in range(epochs):train_acc,train_loss = train(train_loader, net)val_acc,val_loss = val(val_loader, net)train_losses.append(train_loss)val_losses.append(val_loss)train_accs.append(train_acc.item())val_accs.append(val_acc.item())if best_acc<val_acc:best_acc = val_acctorch.save(net, os.path.join("/kaggle/working/", net_name + ".pt"))# 画图save_path="/kaggle/working/" # 图片保存路径plt.figure(figsize=(12, 4))# lossplt.subplot(1, 2, 1)plt.plot(range(1, epochs + 1), train_losses, "r-",label='Train loss')plt.plot(range(1, epochs + 1), val_losses, "b-",label='Val loss')plt.legend()plt.xlabel('Epoch')plt.ylabel('Loss')# accplt.subplot(1, 2, 2)plt.plot(range(1, epochs + 1), train_accs,"r-", label='Train acc')plt.plot(range(1, epochs + 1), val_accs,"b-" ,label='Val acc')plt.legend()plt.xlabel('Epoch')plt.ylabel('Acc')plt.legend()plt.savefig(os.path.join(save_path, 'result.png')) # 保存plt.show()

2.4训练效果与模型文件

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/19771.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python爬虫--scrapy+selenium框架】超详细的Python爬虫scrapy+selenium框架学习笔记(保姆级别的,非常详细)

六&#xff0c;selenium 想要下载PDF或者md格式的笔记请点击以下链接获取 python爬虫学习笔记点击我获取 Scrapyselenium详细学习笔记点我获取 Python超详细的学习笔记共21万字点我获取 1&#xff0c;下载配置 ## 安装&#xff1a; pip install selenium## 它与其他库不同…

【C++】C++11新特性:列表初始化、声明、新容器、右值引用、万能引用和完美转发

目录 一、列表初始化 1.1 { } 初始化 1.2 std::initializer_list 二、声明 2.1 auto 2.2 decltype 2.3 nullptr 三、新容器 四、右值引用和移动语义 4.1 左值和左值引用 4.2 右值和右值引用 4.3 左值引用与右值引用比较 4.4 右值引用使用场景和意义&#xff1a;移…

万字长文深度解析Agent反思工作流框架Reflexion下篇:ReflectionAgent workflow

在前文[LLM-Agents]万字长文深度解析Agent反思工作流框架Reflexion中篇&#xff1a;React中&#xff0c;我们详细解析了ReactAgent的工作流程&#xff0c;而本文则将在此基础上探讨反思技巧的应用。之前的文章中[LLM-Agents]反思Reflection 工作流我们已经对反思技巧进行了探讨…

多维数组操作,不要再用遍历循环foreach了!来试试数组展平的小妙招!array.flat()用法与array.flatMap() 用法及二者差异详解

目录 一、array.flat&#xff08;&#xff09;方法 1.1、array.flat&#xff08;&#xff09;的语法及使用 ①语法 ②返回值 ③用途 二、array.flatMap() 方法 2.1、array.flatMap()的语法及作用 ①语法 ②返回值 ③用途 三、array.flat&#xff08;&#xff09;与a…

CLIP 源码分析:simple_tokenizer.py

tokenizer的含义 from .clip import *引入头文件时为什么有个. 正文 import gzip import html import os from functools import lru_cacheimport ftfy import regex as re# 上面的都是头文件# 这段代码定义了一个函数 default_bpe()&#xff0c;它使用了装饰器 lru_cache()。…

一维时间序列信号的改进小波降噪方法(MATLAB R2021B)

目前国内外对于小波分析在降噪方面的方法研究中&#xff0c;主要有小波分解与重构法降噪、小波阈值降噪、小波变换模极大值法降噪等三类方法。 (1)小波分解与重构法降噪 早在1988 年&#xff0c;Mallat提出了多分辨率分析的概念&#xff0c;利用小波分析的多分辨率特性进行分…

DAQmx Connect Terminals (VI) 信号路由作用及意义

DAQmx Connect Terminals是一个LabVIEW虚拟仪器&#xff08;VI&#xff09;&#xff0c;用于配置和连接数据采集系统中的物理终端或虚拟终端。这一功能在配置复杂的数据采集&#xff08;DAQ&#xff09;系统时非常重要&#xff0c;因为它允许用户在不改变硬件连接的情况下&…

Python——Selenium快速上手+方法(一站式解决问题)

目录 前言 一、Selenium是什么 二、Python安装Selenium 1、安装Selenium第三方库 2、下载浏览器驱动 3、使用Python来打开浏览器 三、Selenium的初始化 四、Selenium获取网页元素 4.1、获取元素的实用方法 1、模糊匹配获取元素 & 联合多个样式 2、使用拉姆达表达式 3、加上…

Python零基础-下【详细】

接上篇继续&#xff1a; Python零基础-中【详细】-CSDN博客 目录 十七、网络编程 1、初识socket &#xff08;1&#xff09;socket理解 &#xff08;2&#xff09;图解socket &#xff08;3&#xff09;戏说socket &#xff08;4&#xff09;网络服务 &#xff08;5&a…

实战经验分享之移动云快速部署Stable Diffusion SDXL 1.0

本文目录 前言产品优势部署环境准备模型安装测试运行 前言 移动云是中国移动面向政府、企业和公众的新型资源服务。 客户以购买服务的方式&#xff0c;通过网络快速获取虚 拟计算机、存储、网络等基础设施服务&#xff1b;软件开发工具、运行环境、数据库等平台服务&#xff1…

【评价类模型】熵权法

1.客观赋权法&#xff1a; 熵权法是一种客观求权重的方法&#xff0c;所有客观求权重的模型中都要有以下几步&#xff1a; 1.正向化处理&#xff1a; 极小型指标&#xff1a;取值越小越好的指标&#xff0c;例如错误率、缺陷率等。 中间项指标&#xff1a;取值在某个范围内较…

[ubuntu18.04]搭建mptcp测试环境说明

MPTCP介绍 Multipath TCP — Multipath TCP -- documentation 2022 documentation 安装ubuntu18.04&#xff0c;可以使用虚拟机安装 点击安装VMware Tool 桌面会出现如下图标 双击打开VMware Tools&#xff0c;复制如下图所示的文件到Home目录 打开终端&#xff0c;切换到管…

[Linux]重定向

一、struct file内核对象 struct file是在内核中创建&#xff0c;专门用来管理被打开文件的结构体。struct file中包含了打开文件的所有属性&#xff0c;文件的操作方法集以及文件缓冲区&#xff08;无论读写&#xff0c;我们都需要先将数据加载到文件缓冲区中。&#xff09;等…

基于JSP的高校二手交易平台

开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;JSP技术 工具&#xff1a;浏览器&#xff08;如360浏览器、谷歌浏览器、QQ浏览器等&#xff09;、MySQL数据库 系统展示 系统功能界面 用户注册与登录界面 个人中心界面 商品信息界面 摘要 本文研究了高…

为何懂行的人都在选海信Mini LED?

今年的618大促比往年来得要更早一些。纵览各电商平台的电视产品&#xff0c;能发现Mini LED电视的出镜率很高&#xff0c;成了各大品牌的主推产品。 对于什么样的Mini LED更值得买&#xff0c;各品牌都有自己的说辞。因为缺乏科学系统的选购标准&#xff0c;消费者容易在各方说…

【Qt】【模型-视图架构】代理模型示例

文章目录 1. 基本排序/过滤模型Basic Sort/Filter Model Example2. 自定义排序/过滤模型Custom Sort/Filter Model ExampleFilterLineEdit类定义及实现MySortFilterProxyModel类定义及实现 1. 基本排序/过滤模型Basic Sort/Filter Model Example 官方提供的基本排序/过滤模型示…

docker镜像体积优化攻略参考—— 筑梦之路

简单介绍 镜像的本质是镜像层和运行配置文件组成的压缩包&#xff0c;构建镜像是通过运行 Dockerfile 中的 RUN 、COPY 和 ADD 等指令生成镜像层和配置文件的过程。 和镜像体积大小有关的关键点&#xff1a; RUN、COPY 和 ADD 指令会在已有镜像层的基础上创建一个新的镜像层&…

【数据结构】详解二叉树

文章目录 1.树的结构及概念1.1树的概念1.2树的相关结构概念1.3树的表示1.4树在实际中的应用 2.二叉树的结构及概念2.1二叉树的概念2.2特殊的二叉树2.2.1满二叉树2.2.2完全二叉树 2.3 二叉树的性质2.4二叉树的存储结构2.4.1顺序结构2.4.2链表结构 1.树的结构及概念 1.1树的概念…

基于SSM的车辆租赁管理系统(含源码+sql+视频导入教程)

&#x1f449;文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1 、功能描述 基于SSM的车辆租赁管理系统1拥有两种角色 管理员&#xff1a;用户管理、用户租车、用户换车和车辆入库、添加汽车、添加客户、生成出租单、客户选车、出租单管理、查询出租单、角色权限管…

登录校验及全局异常处理器

登录校验 会话技术 会话:用户打开浏览器,访问web服务器的资源,会话建立,直到有一方断开连接,会话结束.在一次会话中可以包含多次请求和响应会话跟踪:一种维护浏览器状态的方法,服务器需要识别多次请求是否来自于同一浏览器,以便在同一次会话请求间共享数据会话跟踪方案 客户端…