基于torch的图像识别训练策略与常用模块

数据预处理部分:

  • 数据增强:torchvision中transforms模块自带功能,比较实用
  • 数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可
  • DataLoader模块直接读取batch数据

网络模块设置:

  • 加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习
  • 需要注意的是别人训练好的任务跟咱们的可不是完全一样,需要把最后的head层改一改,一般也就是最后的全连接层,改成咱们自己的任务
  • 训练时可以全部重头训练,也可以只训练最后咱们任务的层,因为前几层都是做特征提取的,本质任务目标是一致的

网络模型保存与测试

  • 模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存
  • 读取模型进行实际测试
data_transforms = {'train': transforms.Compose([transforms.Resize([96, 96]),transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选transforms.CenterCrop(64),#从中心开始裁剪transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差]),'valid': transforms.Compose([transforms.Resize([64, 64]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

选择性的权重更新

def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = False

自定义修改模型输出层,以resnet18为例

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):model_ft = models.resnet18(pretrained=use_pretrained)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, 102)#类别数自己根据自己任务来input_size = 64#输入大小根据自己配置来return model_ft, input_size

训练权重 选择

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)#GPU还是CPU计算
model_ft = model_ft.to(device)# 模型保存,名字自己起
filename='checkpoint.pth'# 是否训练所有层
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:params_to_update = []for name,param in model_ft.named_parameters():if param.requires_grad == True:params_to_update.append(param)print("\t",name)
else:for name,param in model_ft.named_parameters():if param.requires_grad == True:print("\t",name)

基本训练代码

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,filename='best.pt'):#咱们要算时间的since = time.time()#也要记录最好的那一次best_acc = 0#模型也得放到你的CPU或者GPUmodel.to(device)#训练过程中打印一堆损失和指标val_acc_history = []train_acc_history = []train_losses = []valid_losses = []#学习率LRs = [optimizer.param_groups[0]['lr']]#最好的那次模型,后续会变的,先初始化best_model_wts = copy.deepcopy(model.state_dict())#一个个epoch来遍历for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 训练和验证for phase in ['train', 'valid']:if phase == 'train':model.train()  # 训练else:model.eval()   # 验证running_loss = 0.0running_corrects = 0# 把数据都取个遍for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)#放到你的CPU或GPUlabels = labels.to(device)# 清零optimizer.zero_grad()# 只有训练的时候计算和更新梯度outputs = model(inputs)loss = criterion(outputs, labels)_, preds = torch.max(outputs, 1)# 训练阶段更新权重if phase == 'train':loss.backward()optimizer.step()# 计算损失running_loss += loss.item() * inputs.size(0)#0表示batch那个维度running_corrects += torch.sum(preds == labels.data)#预测结果最大的和真实值是否一致epoch_loss = running_loss / len(dataloaders[phase].dataset)#算平均epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)time_elapsed = time.time() - since#一个epoch我浪费了多少时间print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 得到最好那次的模型if phase == 'valid' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())state = {'state_dict': model.state_dict(),#字典里key就是各层的名字,值就是训练好的权重'best_acc': best_acc,'optimizer' : optimizer.state_dict(),}torch.save(state, filename)if phase == 'valid':val_acc_history.append(epoch_acc)valid_losses.append(epoch_loss)#scheduler.step(epoch_loss)#学习率衰减if phase == 'train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRs.append(optimizer.param_groups[0]['lr'])print()scheduler.step()#学习率衰减time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# 训练完后用最好的一次当做模型最终的结果,等着一会测试model.load_state_dict(best_model_wts)return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 

调用训练

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20)
def im_convert(tensor):""" 展示数据"""image = tensor.to("cpu").clone().detach()image = image.numpy().squeeze()image = image.transpose(1,2,0)image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))image = image.clip(0, 1)return image

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/814762.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【SQL】数据库SQL语句

1、主键 主键值唯一,不可修改,不能为空,删除不能重用 2、数据类型(常用) char int float date timestamp 3、select select * from data; select xx,xxx from data;//取部分行 select * from data limit 100; //限…

Bezier曲线的绘制 matlab

式中: 称为基函数。 。 因为n表示次数,点数为n1,显然i表示第i个控制点。 显然在Matlab中可以同矩阵的形式来计算C(u)。 关键代码为: clc clear % 假设控制点P取值为: P [4,7;13,12;19,4;25,12;30,3]; % 因此&a…

vscode debug 配置:launch.json

打开新项目左边的“运行和调试” 点击蓝色字体“创建 launch.json 文件” 选择上方“python” 选择“Python 文件 调试当前正在运行的Python文件” 配置launch.json文件内容: {// 使用 IntelliSense 了解相关属性// 悬停以查看现有属性的描述。// 欲了解更多信息&a…

设计模式-单一职责原则

基本介绍 对类来说的,即一个类应该只负责一项职责。如类A负责两个不同的职责,职责1,职责2.当职责1需求变更而改变A时,可能造成职责2执行错误,所以需要将类A的粒度分解为A1,A2 应用实例 方案1 public cl…

大厂MVP技术JAVA架构师培养

课程介绍 这是一个很强悍的架构师涨薪计划课程,课程由专家级MVP讲师进行教学,分为是一个章节进行分解式面试及讲解,不仅仅是面试,更像是一个专业的架构师研讨会课程。课程内容从数据结构与算法、Spring Framwork、JVM原理、 JUC并…

JS时间戳转换 时间戳转时间 js转换时间戳为时间类型显示

JS时间戳转换 时间戳转时间 js转换时间戳为时间类型显示 本方法已经抽离出年月日时分秒 更多的时间格式搭配大家可以随意添加!!! convertToEnglishMonthAbbreviation方法把月份转化成英文显示 大家可以看下 并非使用通过自定义枚举的方式实现…

opencv基础图行展示

"""试用opencv创建画布并显示矩形框(适用于目标检测图像可视化) """ # 创建一个黑色的画布,图像格式(BGR) img np.zeros((512, 512, 3), np.uint8)# 画一个矩形:给定左上角和右下角坐标&#xff0…

速成英语语法(2)

be动词 表示事物的性质和状态 ..是.. ..有.. I am Tom 我是汤姆 There are seven days in a week 一个星期有七天be动词的种类 am/is/are 我am 它/她/他is 你arebe的疑问句 含be的陈述句变为疑问句 主语和be对换位置 Are you ready? 你准备好了吗?be的否定否定句 b…

数据结构:去发现顺序表的魅力所在

✨✨小新课堂开课了,欢迎欢迎~✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 所属专栏:http://t.csdnimg.cn/oHJAK(数据结构与算法) 小新的主页:编程版小新-CSDN博客 …

Docker入门实战教程

文章目录 Docker引擎的安装Docker比vm虚拟机快 Docker常用命令帮助启动类命令镜像命令docker imagesdocker searchdocker pulldocker system dfdocker rmi 容器命令redis前台交互式启动redis后台守护式启动Nginx容器运行ubuntu交互式运行tomcat交互式运行对外暴露访问端口 Dock…

Linux的内存管理子系统

大家好,今天给大家介绍Linux的内存管理子系统,文章末尾附有分享大家一个资料包,差不多150多G。里面学习内容、面经、项目都比较新也比较全!可进群免费领取。 Linux的内存管理子系统是Linux内核中一个非常重要且复杂的子系统&#…

防火墙操作!

当小编在Linux服务器上部署好程序以后,但是输入URL出现下述情况,原来是防火墙的原因!! 下面是一些防火墙操作! 为保证系统安全,服务器的防火墙不建议关闭!! 但是,我们可…

【网络安全】WebPack源码(前端源码)泄露 + jsmap文件还原

前言 webpack是一个JavaScript应用程序的静态资源打包器。它构建一个依赖关系图,其中包含应用程序需要的每个模块,然后将所有这些模块打包成一个或多个bundle。大部分Vue等项目应用会使用webpack进行打包,使用webpack打包应用程序会在网站js…

20240327-1-评测指标面试题

评测指标面试题 metric主要用来评测机器学习模型的好坏程度,不同的任务应该选择不同的评价指标,分类,回归和排序问题应该选择不同的评价函数. 不同的问题应该不同对待,即使都是分类问题也不应该唯评价函数论,不同问题不同分析. 回归(Regression) 平均绝对误差(MAE) 平均绝对…

CMake 学习笔记2

其他很好的总结 CMake教程系列-01-最小配置示例 - 知乎 CMake 保姆级教程(上) | 爱编程的大丙 10-补充(完结)_哔哩哔哩_bilibili 1、基本关键字 SET命令的补充 (1)SET命令设置执行标准 #增加-stdc11 set(CMAKE_CXX_STANDARD…

并查集的延伸--克鲁斯卡尔法求最小生成树MST

并查集的延伸--克鲁斯卡尔法求最小生成树MST 力扣 1135 力扣 1584并查集 UnionFind.java 力扣 1135 力扣 1584 package com.caoii;/**program:labu-pratice-study*package:com.caoii*author: Alan*Time: 2024/4/14 9:09*description: 最小生成树相关题目测试*/import org.jun…

Terminal 美化

摘自:Mac 系统终端美化与 ZSH 多设备配置同步共享 个人对 iTerm2 等第三方终端工具不太感冒,一直在用系统内置终端。 相比之下,系统自带的 Terminal 可谓是简陋啊。 安装了 Oh My Zsh,加上一些插件,感觉还行。 再调…

如何防止软件过度封装和抽象?

一、合适的软件架构 构建可读性强、高内聚、低耦合的软件架构是软件工程中的重要原则,这有助于提高代码的维护性、扩展性和复用性。以下是一些实践方法: 1. **模块化设计**:将系统划分为一系列职责单一、功能明确的模块或组件,每…

ELK企业级日志分析系统以及多种部署

目录 ELK简介 ELK简介 ELK平台是一套完整的日志集中处理解决方案,将 ElasticSearch、Logstash 和 Kiabana 三个开源工具配合使用, 完成更强大的用户对日志的查询、排序、统计需求。 ●ElasticSearch:是基于Lucene(一个全文检索引…

windows应急中的快捷键

windows应急中的快捷键 应急的时候,快捷键很重要,记录一下windows主机排查需要用到的快捷键 windows快捷键 appwiz.cpl 是打开安装面板 程序和功能 控制面板程序和功能 搜索程序和功能 控制而板主页 卸载或更改程序 若要卸酸程序,请从列表中将其…