基于安卓的虫害识别软件设计--(1)模型训练与可视化

引言

  • 简介:使用pytorch框架,从模型训练、模型部署完整地实现了一个基础的图像识别项目
  • 计算资源:使用的是Kaggle(每周免费30h的GPU)

1.创建名为“utils_1”的模块

模块中包含:训练和验证的加载器函数训练函数验证函数

import os
import sysimport torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdmdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def get_train_loader(image_path):train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform = train_transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32,shuffle=True, num_workers= 0)return train_loaderdef get_val_loader(image_path):val_transform = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])val_dataset = datasets.ImageFolder(root=os.path.join(image_path, "validation"),transform = val_transform)val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=32,shuffle = False, num_workers = 0)return val_loaderdef train(train_loader,net):net.train()train_correct = 0.0train_loss = 0.0  # 初始化训练损失train_bar = tqdm(train_loader, file=sys.stdout)loss_function = nn.CrossEntropyLoss()loss_function = loss_function.to(device)optimizer = optim.Adam(net.parameters(), lr=0.001)for step, data in enumerate(train_bar):images, labels = dataimages, labels = images.to(device),labels.to(device)# 梯度清零optimizer.zero_grad()# 训练outputs = net(images)# 计算损失loss = loss_function(outputs, labels)# 反向传播loss.backward()# 更新权重optimizer.step()# 统计_, preds = outputs.max(1)correct = preds.eq(labels).sum()train_correct += correcttrain_loss += loss.item()  # 累加损失值train_bar.desc = 'Training Epoch:[{trained_samples}/{total_samples}]\t Loss: {:0.4f}\t Accuracy: {:0.4f}\t'.format(loss.item(),(100. * correct) / len(outputs),trained_samples=step * train_loader.batch_size + len(images),total_samples=len(train_loader.dataset))train_correct = (100. * train_correct) / len(train_loader.dataset)train_loss /= len(train_loader)  # 计算平均损失值return train_correct, train_loss  # 返回训练正确率和平均损失值def val(val_loader,net):net.eval()val_correct = 0.0val_loss = 0.0  # 初始化验证损失loss_function = nn.CrossEntropyLoss()loss_function = loss_function.to(device)val_bar = tqdm(val_loader, file=sys.stdout)for step, data in enumerate(val_bar):images, labels = dataimages, labels = images.to(device), labels.to(device)with torch.no_grad():# 验证outputs = net(images)# 计算损失loss = loss_function(outputs, labels)# 统计_, preds = outputs.max(1)correct = preds.eq(labels).sum()val_correct += correctval_loss += loss.item()  # 累加损失值val_bar.desc = 'Valing Epoch:[{trained_samples}/{total_samples}]\t Loss: {:0.4f}\t Accuracy: {:0.4f}\t'.format(loss.item(),(100. * correct) / len(outputs),trained_samples=step * val_loader.batch_size + len(images),total_samples=len(val_loader.dataset))val_correct = (100. * val_correct) / len(val_loader.dataset)val_loss /= len(val_loader)  # 计算平均损失值return val_correct , val_loss  # 返回验证正确率和平均损失值

注意:若使用Kaggle,想要导入该模块,需要添加以下代码

import sys
sys.path.append(r'/kaggle/input/mycode2')

其中,模块路径如下图


2.主函数 

主函数包含:使用模型函数训练主函数画图代码

2.1使用模型函数 

【若使用其他模型,可chatgpt创建其函数】

(1)resnet101 

def get_resnet101(class_num):net_name = "resnet101"net = torchvision.models.resnet101(pretrained=True)net.fc = Linear(in_features=2048, out_features=class_num, bias=True)  # ResNet101's fully connected layer expects 2048 input featuresnet = net.to(device)return net_name, net

(2)resnet34 

def get_resnet34(class_num):net_name = "resnet34"net = torchvision.models.resnet34(pretrained=True)net.fc = Linear(in_features=512, out_features=class_num, bias=True)net = net.to(device)return net_name,net

(3)mobilenetv2

def get_mobilenet_v2(class_num):net_name = "mobilenet_v2"net = torchvision.models.mobilenet_v2(pretrained=True)net.classifier[1] = Linear(in_features=1280, out_features=class_num, bias=True)net = net.to(device)return net_name,net

 2.2画图代码 

    save_path="/kaggle/working/"  plt.figure(figsize=(12, 4))# lossplt.subplot(1, 2, 1)plt.plot(range(1, epochs + 1), train_losses, "r-",label='Train loss')plt.plot(range(1, epochs + 1), val_losses, "b-",label='Val loss')plt.legend()plt.xlabel('Epoch')plt.ylabel('Loss')# accplt.subplot(1, 2, 2)plt.plot(range(1, epochs + 1), train_accs,"r-", label='Train acc')plt.plot(range(1, epochs + 1), val_accs,"b-" ,label='Val acc')plt.legend()plt.xlabel('Epoch')plt.ylabel('Acc')plt.legend()plt.savefig(os.path.join(save_path, 'result.png')) # 保存plt.show()

2.3完整代码 

import torch
import torchvision.models
from matplotlib import pyplot as plt
from torch.nn import Linear
import os# 导入自己创建的模块
from utils_1 import get_train_loader, train, val, get_val_loader# 模型选择
def get_resnet101(class_num):net_name = "resnet101"net = torchvision.models.resnet101(pretrained=True)net.fc = Linear(in_features=2048, out_features=class_num, bias=True)  # ResNet101's fully connected layer expects 2048 input featuresnet = net.to(device)return net_name, net# def get_resnet34(class_num):
#     net_name = "resnet34"
#     net = torchvision.models.resnet34(pretrained=True)
#     net.fc = Linear(in_features=512, out_features=class_num, bias=True)
#     net = net.to(device)
#     return net_name,net# def get_mobilenet_v2(class_num):
#     net_name = "mobilenet_v2"
#     net = torchvision.models.mobilenet_v2(pretrained=True)
#     net.classifier[1] = Linear(in_features=1280, out_features=class_num, bias=True)
#     net = net.to(device)
#     return net_name,net# 训练主函数
if __name__ == '__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#1 加载数据image_path = r"/kaggle/input/fruits3"train_loader = get_train_loader(image_path)val_loader = get_val_loader(image_path)#2 加载模型net_name,net = get_resnet34(class_num=5)#3 训练epochs = 5best_acc = 0train_losses = []val_losses = []train_accs = []val_accs = []for epoch in range(epochs):train_acc,train_loss = train(train_loader, net)val_acc,val_loss = val(val_loader, net)train_losses.append(train_loss)val_losses.append(val_loss)train_accs.append(train_acc.item())val_accs.append(val_acc.item())if best_acc<val_acc:best_acc = val_acctorch.save(net, os.path.join("/kaggle/working/", net_name + ".pt"))# 画图save_path="/kaggle/working/" # 图片保存路径plt.figure(figsize=(12, 4))# lossplt.subplot(1, 2, 1)plt.plot(range(1, epochs + 1), train_losses, "r-",label='Train loss')plt.plot(range(1, epochs + 1), val_losses, "b-",label='Val loss')plt.legend()plt.xlabel('Epoch')plt.ylabel('Loss')# accplt.subplot(1, 2, 2)plt.plot(range(1, epochs + 1), train_accs,"r-", label='Train acc')plt.plot(range(1, epochs + 1), val_accs,"b-" ,label='Val acc')plt.legend()plt.xlabel('Epoch')plt.ylabel('Acc')plt.legend()plt.savefig(os.path.join(save_path, 'result.png')) # 保存plt.show()

2.4训练效果与模型文件

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/19771.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python爬虫--scrapy+selenium框架】超详细的Python爬虫scrapy+selenium框架学习笔记(保姆级别的,非常详细)

六&#xff0c;selenium 想要下载PDF或者md格式的笔记请点击以下链接获取 python爬虫学习笔记点击我获取 Scrapyselenium详细学习笔记点我获取 Python超详细的学习笔记共21万字点我获取 1&#xff0c;下载配置 ## 安装&#xff1a; pip install selenium## 它与其他库不同…

【C++】C++11新特性:列表初始化、声明、新容器、右值引用、万能引用和完美转发

目录 一、列表初始化 1.1 { } 初始化 1.2 std::initializer_list 二、声明 2.1 auto 2.2 decltype 2.3 nullptr 三、新容器 四、右值引用和移动语义 4.1 左值和左值引用 4.2 右值和右值引用 4.3 左值引用与右值引用比较 4.4 右值引用使用场景和意义&#xff1a;移…

「前端+鸿蒙」核心技术HTML5+CSS3(二)

1、开发者文档 开发者文档通常由浏览器厂商或技术社区提供,包含有关Web技术(如HTML、CSS、JavaScript)的详细信息,API文档,以及最佳实践。例如,MDN Web Docs是一个广泛认可的开发者资源。 2、块级元素与行列元素 块级元素:在页面上占据整行的元素,如<div>、<…

万字长文深度解析Agent反思工作流框架Reflexion下篇:ReflectionAgent workflow

在前文[LLM-Agents]万字长文深度解析Agent反思工作流框架Reflexion中篇&#xff1a;React中&#xff0c;我们详细解析了ReactAgent的工作流程&#xff0c;而本文则将在此基础上探讨反思技巧的应用。之前的文章中[LLM-Agents]反思Reflection 工作流我们已经对反思技巧进行了探讨…

多维数组操作,不要再用遍历循环foreach了!来试试数组展平的小妙招!array.flat()用法与array.flatMap() 用法及二者差异详解

目录 一、array.flat&#xff08;&#xff09;方法 1.1、array.flat&#xff08;&#xff09;的语法及使用 ①语法 ②返回值 ③用途 二、array.flatMap() 方法 2.1、array.flatMap()的语法及作用 ①语法 ②返回值 ③用途 三、array.flat&#xff08;&#xff09;与a…

CLIP 源码分析:simple_tokenizer.py

tokenizer的含义 from .clip import *引入头文件时为什么有个. 正文 import gzip import html import os from functools import lru_cacheimport ftfy import regex as re# 上面的都是头文件# 这段代码定义了一个函数 default_bpe()&#xff0c;它使用了装饰器 lru_cache()。…

一维时间序列信号的改进小波降噪方法(MATLAB R2021B)

目前国内外对于小波分析在降噪方面的方法研究中&#xff0c;主要有小波分解与重构法降噪、小波阈值降噪、小波变换模极大值法降噪等三类方法。 (1)小波分解与重构法降噪 早在1988 年&#xff0c;Mallat提出了多分辨率分析的概念&#xff0c;利用小波分析的多分辨率特性进行分…

Vue基础(4)组件基础

1. 导入 要使用一个子组件&#xff0c;需要在父组件中进行导入 <script setup> import MyComponent from ./MyComponent.vue </script><template><MyComponent /> </template>2. 传递 props 在使用 <script setup> 的单文件组件中&…

服务器内存不足该怎么办?

当服务器的物理内存使用率达到或者是接近百分之百时&#xff0c;会导致系统没有办法为新的进程或者是请求分配足够的内存空间&#xff0c;在这种情况下&#xff0c;服务器的性能很有可能会受到一定的影响&#xff0c;严重的会导致系统崩溃或者服务出现中断。 那我们面临服务器内…

DAQmx Connect Terminals (VI) 信号路由作用及意义

DAQmx Connect Terminals是一个LabVIEW虚拟仪器&#xff08;VI&#xff09;&#xff0c;用于配置和连接数据采集系统中的物理终端或虚拟终端。这一功能在配置复杂的数据采集&#xff08;DAQ&#xff09;系统时非常重要&#xff0c;因为它允许用户在不改变硬件连接的情况下&…

TS-类型转换(显式)

1.将其他类型转换为布尔类型 要将其他类型转换为布尔类型&#xff0c;只需要将待转换的值传入Boolean()函数 var msg: string "ok"; var msgToBollean: boolean Boolean(msg); //得到trueBoolean()函数会判断传入的值是空值还是非空值。 若表示非空值&#xff0…

Python——Selenium快速上手+方法(一站式解决问题)

目录 前言 一、Selenium是什么 二、Python安装Selenium 1、安装Selenium第三方库 2、下载浏览器驱动 3、使用Python来打开浏览器 三、Selenium的初始化 四、Selenium获取网页元素 4.1、获取元素的实用方法 1、模糊匹配获取元素 & 联合多个样式 2、使用拉姆达表达式 3、加上…

Python零基础-下【详细】

接上篇继续&#xff1a; Python零基础-中【详细】-CSDN博客 目录 十七、网络编程 1、初识socket &#xff08;1&#xff09;socket理解 &#xff08;2&#xff09;图解socket &#xff08;3&#xff09;戏说socket &#xff08;4&#xff09;网络服务 &#xff08;5&a…

线程进阶-1 线程池

一.说一下线程池的执行原理 1.线程池的七大核心参数 &#xff08;1&#xff09;int corePoolSize&#xff1a;核心线程数。默认情况下核心线程会一直存活&#xff0c;当设置allowCoreThreadTimeout为true时&#xff0c;核心线程也会被超时回收。 &#xff08;2&#xff09;i…

Linux基础指令磁盘管理001

Linux磁盘管理涉及多个方面&#xff0c;包括硬盘分区、文件系统创建、挂载、检查磁盘空间、优化性能和维护等。今天我们讲一下磁盘的分区挂载&#xff0c;文件系统的创建。 操作系统 CentOS Stream 9 磁盘的分区 当我们新插入一块磁盘后&#xff0c;首先使用fdisk -l查看磁盘…

实战经验分享之移动云快速部署Stable Diffusion SDXL 1.0

本文目录 前言产品优势部署环境准备模型安装测试运行 前言 移动云是中国移动面向政府、企业和公众的新型资源服务。 客户以购买服务的方式&#xff0c;通过网络快速获取虚 拟计算机、存储、网络等基础设施服务&#xff1b;软件开发工具、运行环境、数据库等平台服务&#xff1…

【评价类模型】熵权法

1.客观赋权法&#xff1a; 熵权法是一种客观求权重的方法&#xff0c;所有客观求权重的模型中都要有以下几步&#xff1a; 1.正向化处理&#xff1a; 极小型指标&#xff1a;取值越小越好的指标&#xff0c;例如错误率、缺陷率等。 中间项指标&#xff1a;取值在某个范围内较…

[ubuntu18.04]搭建mptcp测试环境说明

MPTCP介绍 Multipath TCP — Multipath TCP -- documentation 2022 documentation 安装ubuntu18.04&#xff0c;可以使用虚拟机安装 点击安装VMware Tool 桌面会出现如下图标 双击打开VMware Tools&#xff0c;复制如下图所示的文件到Home目录 打开终端&#xff0c;切换到管…

[Linux]重定向

一、struct file内核对象 struct file是在内核中创建&#xff0c;专门用来管理被打开文件的结构体。struct file中包含了打开文件的所有属性&#xff0c;文件的操作方法集以及文件缓冲区&#xff08;无论读写&#xff0c;我们都需要先将数据加载到文件缓冲区中。&#xff09;等…

基于JSP的高校二手交易平台

开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;JSP技术 工具&#xff1a;浏览器&#xff08;如360浏览器、谷歌浏览器、QQ浏览器等&#xff09;、MySQL数据库 系统展示 系统功能界面 用户注册与登录界面 个人中心界面 商品信息界面 摘要 本文研究了高…