图像识别模型与训练策略

图像预处理

1.需要将图像Resize到相同大小输入到卷积网络中
2.翻转、裁剪、色彩偏移等操作
3.转化为Tensor数据格式
4.对RGB三种颜色通道进行标准化

data_transforms = {'train': transforms.Compose([transforms.Resize([96, 96]),transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选transforms.CenterCrop(64),#从中心开始裁剪transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差]),'valid': transforms.Compose([transforms.Resize([64, 64]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

读取数据

将训练集中各个类别文件夹中的数据经过Transforms增强后进行统一读取封装

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
batch_size = 128image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

迁移学习

使用官方发布的模型和参数,将参数冻住不更新

def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = Falsemodel_ft = models.resnet18()#18层的能快点,条件好点的也可以选152
model_ft

修改输出层

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):model_ft = models.resnet18(pretrained=use_pretrained)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, 102)#类别数自己根据自己任务来input_size = 64#输入大小根据自己配置来return model_ft, input_size

更新输出层参数

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)#GPU还是CPU计算
model_ft = model_ft.to(device)# 模型保存,名字自己起
filename='checkpoint.pth'# 是否训练所有层
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:params_to_update = []for name,param in model_ft.named_parameters():if param.requires_grad == True:params_to_update.append(param)print("\t",name)
else:for name,param in model_ft.named_parameters():if param.requires_grad == True:print("\t",name)

优化器设置

optimizer_ft = optim.Adam(params_to_update, lr=1e-2)#要训练啥参数,你来定
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)#学习率每7个epoch衰减成原来的1/10
criterion = nn.CrossEntropyLoss()

训练策略

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,filename='best.pt'):#咱们要算时间的since = time.time()#也要记录最好的那一次best_acc = 0#模型也得放到你的CPU或者GPUmodel.to(device)#训练过程中打印一堆损失和指标val_acc_history = []train_acc_history = []train_losses = []valid_losses = []#学习率LRs = [optimizer.param_groups[0]['lr']]#最好的那次模型,后续会变的,先初始化best_model_wts = copy.deepcopy(model.state_dict())#一个个epoch来遍历for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 训练和验证for phase in ['train', 'valid']:if phase == 'train':model.train()  # 训练else:model.eval()   # 验证running_loss = 0.0running_corrects = 0# 把数据都取个遍for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)#放到你的CPU或GPUlabels = labels.to(device)# 清零optimizer.zero_grad()# 只有训练的时候计算和更新梯度outputs = model(inputs)loss = criterion(outputs, labels)_, preds = torch.max(outputs, 1)# 训练阶段更新权重if phase == 'train':loss.backward()optimizer.step()# 计算损失running_loss += loss.item() * inputs.size(0)#0表示batch那个维度running_corrects += torch.sum(preds == labels.data)#预测结果最大的和真实值是否一致epoch_loss = running_loss / len(dataloaders[phase].dataset)#算平均epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)time_elapsed = time.time() - since#一个epoch我浪费了多少时间print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 得到最好那次的模型if phase == 'valid' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())state = {'state_dict': model.state_dict(),#字典里key就是各层的名字,值就是训练好的权重'best_acc': best_acc,'optimizer' : optimizer.state_dict(),}torch.save(state, filename)if phase == 'valid':val_acc_history.append(epoch_acc)valid_losses.append(epoch_loss)#scheduler.step(epoch_loss)#学习率衰减if phase == 'train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRs.append(optimizer.param_groups[0]['lr'])print()scheduler.step()#学习率衰减time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# 训练完后用最好的一次当做模型最终的结果,等着一会测试model.load_state_dict(best_model_wts)return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/31421.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

unable to write symref for HEAD: Permission denied

今天从gitee上面克隆项目到本地时报错如下 warning: unable to unlink ‘D:/IDEAcode/ruiji1.0/.git/HEAD.lock’: Invalid argument error: unable to write symref for HEAD: Permission denied 解决方法:将要存放项目的文件夹权限修改为完全控制 原先权限&…

GO学习之 接口(Interface)

GO系列 1、GO学习之Hello World 2、GO学习之入门语法 3、GO学习之切片操作 4、GO学习之 Map 操作 5、GO学习之 结构体 操作 6、GO学习之 通道(Channel) 7、GO学习之 多线程(goroutine) 8、GO学习之 函数(Function) 9、GO学习之 接口(Interface) 文章目录 GO系列前言一、什么是…

什么是MVCC

问题描述 对于 MVCC 的理解,我觉得可以先从数据库的三种并发场景说起: 第一种:读读 线程 A 与线程 B 同时在进行读操作,这种情况下不会出现任何并发问题。 第二种:读写 线程 A 与线程 B 在同一时刻分别进行读和写…

W5100S-EVB-PICO 做TCP Server进行回环测试(六)

前言 上一章我们用W5100S-EVB-PICO开发板做TCP 客户端连接服务器进行数据回环测试,那么本章将用开发板做TCP服务器来进行数据回环测试。 TCP是什么?什么是TCP Server?能干什么? TCP (Transmission Control Protocol) 是一种面向连…

十一、结合数字孪生与时间技术进行多维分析设计与实施

大数据可视化中心以主题为分析对象,选择业务分类下的某个主题,可以在数据面板中展示其二维图表,在地图中标记其空间分布,并叠加其相应的二维或三维图层。 1、界面设计 其主界面设计详上图,各部分功能介绍如下: 1.1、主题与图层面板,从上到下,从左到右分别是: ①折…

【1++的数据结构】之二叉搜索树

👍作者主页:进击的1 🤩 专栏链接:【1的数据结构】 文章目录 一,什么是二叉搜索树二,二叉搜索树的操作及其实现2.1 插入操作及其实现2.2 查找操作及其实现2.3 删除操作及其实现 三,构造及其析构四…

分布式链路追踪概述

分布式链路追踪概述 文章目录 分布式链路追踪概述1.分布式链路追踪概述1.1.什么是 Tracing1.2.为什么需要Distributed Tracing 2.Google Dapper2.1.Dapper的分布式跟踪2.1.1.跟踪树和span2.1.2.Annotation2.1.3.采样率 3.OpenTracing3.1.发展历史3.2.数据模型 4.java探针技术-j…

TOMCAT部署及优化(Tomcat配置文件参数优化,Java虚拟机(JVM)调优)

TOMCAT tomcat :是一个开放源代码的web应用服务器,基于java代码开发的。也可以理解为tomacat就是处理动态请求和基于java代码的页面开发。可以在html当中写入java代码,tomcat可以解析html页面当中的java,执行动态请求,…

Java算法_ LRU 缓存(LeetCode_Hot100)

题目描述:请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 获得更多?算法思路:代码文档,算法解析的私得。 运行效果 完整代码 import java.util.HashMap; import java.util.Map;/*** 2 * Author: L…

makefile include 使用介绍

文章目录 前言一、include 关键字1. 语法介绍2. 处理方式示例: 二、- include 操作总结 前言 一、include 关键字 1. 语法介绍 在 Makefile 中,include 指令: 类似于 C 语言中的 include 。将其他文件的内容原封不动的搬入当前文件。 当 …

打破音频语言障碍,英语音频翻译成文字软件助你畅快对话

要理解外语歌曲对我来说难如登天。不过,这种痛苦没有持续太久,我发现了一种音频翻译技术,它像一个语言转换器,可以即时将外语歌曲翻译成我听得懂的语言!我惊喜地试用后,终于可以在听歌的同时看到翻译的歌词…

QT压缩解压文件

文章目录 前言一、下载Quazip二、编译Quazip1.使用vs2019打开quazip.sln2.使用Qt VS Tools打开外层的.pro工程3.编译 三、工程使用1.配置头文件路径2.配置静态库lib目录3.添加库4.动态库dll放到.exe同级目录下5.使用 前言 Qt工程中需要用到zip压缩解压功能,网上搜索…

C++类型查询模板之std::is_array

2023年8月10日&#xff0c;周四上午 概述 std::is_array是一个C类型查询(type trait)模板,它可以用来判断一个类型是否是数组类型。 std::is_array定义在头文件<type_traits>中。 使用方法 可以通过std::is_array::value成员常量来判断一个类型是否是数组类型。 std:…

【Tool】win to go 制作随身硬盘

前言 话说我一冲动买了512G固态硬盘&#xff0c;原本是装个ubuntu系统的&#xff0c;这个好装&#xff0c;但是用处太少&#xff0c;就像改成win10的 经历一堆坑之后&#xff0c;终于使用WTG安装好了 步骤 1.下载个WTG辅助工具 Windows To Go 辅助工具|WTG辅助工具 v5.6.1…

leetcode - 75. 颜色分类(java)

颜色分类 leetcode - 75. 颜色分类题目描述双指针代码演示 双指针算法专题 leetcode - 75. 颜色分类 难度 - 中等 原题链接 - 颜色分类 题目描述 给定一个包含红色、白色和蓝色、共 n 个元素的数组 nums &#xff0c;原地对它们进行排序&#xff0c;使得相同颜色的元素相邻&…

容器化相关面试题

Docker相关面试题 (1)Docker的组件包含哪些? 客户端:dockerclient服务端:dockerserver## 能看到相关的信息 docker info## docker client向docker daemon发送请求,docker daemon完成相应的任务,并把结果返还给容器 Docker镜像: docker镜像是一个只读的模板,是启动一…

【安装部署】Mysql下载及其安装的详细步骤

1.下载压缩包 官网地址&#xff1a;www.mysql.com 2.环境配置 1.先解压压缩包 2.配置环境变量 添加环境变量&#xff1a;我的电脑--->属性-->高级-->环境变量-->系统变量-->path 3.在mysql安装目录下新建my.ini文件并&#xff0c;编辑my.ini文件 编辑内容如…

Centos7.9安装lrzsz进行文件传输---Linux工作笔记059

这里咱们lrzsz命令,需要用来进行文件传输,因为如果不安装这个命令的话,那么 传输安装包什么的就不方便因为只有少数传输工具,才支持,直接拖拽的.没有的时候就可以用这个工具,用命令来传输 直接就是: sz 文件名 就可以把文件下载下来 rz 选择一个文件, 就可以把文件上传到当…

SpringMVC简介搭建环境快速入门

1.简介 SpringMVC是一个基于Spring开发的MVC轻量级框架&#xff0c;Spring3.0后发布的组件&#xff0c;SpringMVC和Spring可以无 缝整合&#xff0c;使用DispatcherServlet作为前端控制器&#xff0c;且内部提供了处理器映射器、处理器适配器、视图解析器等组 件&#xff0c;可…

F12诡异Bug分享

Bug本身情况 java运行的时候会产生class文件&#xff0c;其本身是跑class文件的&#xff0c;但某个实施反馈一个经典版本的长久bug。 当使用模糊查询时&#xff0c;一页一页查看&#xff0c;在倒数第二页时&#xff0c;点击下一页&#xff0c;页面静止不动。&#xff08;正常情…