数据增强,迁移学习,Resnet分类实战

目录

1. 数据增强(Data Augmentation)

2. 迁移学习

3. 模型保存    

4. 102种类花分类实战

1. 数据集

2.导入包

3. 数据读取与预处理操作 

4. Datasets制作输入数据

5.将标签的名字读出 

6.展示原始数据 

7.加载models中提供的模型 

8.初始化 

9.优化器设置 

10.训练模块


1. 数据增强(Data Augmentation)

        数据不够怎么办?采用翻转,镜像,增加数据

        如何更加高效利用数据?多利用几次

        在pytorch中有数据预处理部分:

            数据增强:torchvision中transforms模块自带功能,比较实用

            数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可

            DataLoader模块直接读取batch数据

        pyorch官网:https://pytorch.org/vision/stable

2. 迁移学习

        在训练自己的模型时出现一些问题:

        1. 自己的数据不够好

        2. 训练参数花费时间多

        3. 训练模型太难

        解决方法:

        有前人已经训练好了模型,其实就是将训练的参数保留下来,而且目标都差不多。那么把别人的模型参数当成初始化参数,所有的结构和前人模型一样。

        网络模块设置:

    加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且也可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习

    需要注意的是别人训练好的任务根咱们的可不是完全一样的,需要把最后的head层改一改,一般也就是最后的全连接层,改成咱们自己的任务

    训练时可以完全重头训练,也可以只训练最后咱们任务层,因为前几层都是做特征提取的,本质任务目标一致的。

        总结:迁移学习策略

                1. 将卷积层当成初始化权重参数

                2.将卷积层权重参数冻住不变,全连接层重新训练(一般是,数据量少,冻住的层数多)

3. 模型保存    

        网络模型保存与测试

            模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存

            读取模型进行实际测试

4. 102种类花分类实战

      1. 数据集

        有训练集,测试集。一共102种花,每种花有25~100个图像

2.导入包

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

3. 数据读取与预处理操作 

data_dir = './flower_data'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

制作好数据源:

    data_transforms中指定了所有图像预处理操作

    ImageFolder假设所有文件按文件夹保存好,每个文件夹下面存储同一类别的图片,文件夹的名字为分类的名字

data_transforms = {'train' : transforms.Compose([transforms.RandomRotation(45),#随即旋转,-45度到45度之间随机选transforms.CenterCrop(224), #从中心点开始裁剪transforms.RandomHorizontalFlip(p=0.5),#随即水平翻转,选择一个概率transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1), #参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),#转换成tensor格式transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#均值,标准差]), 'valid':transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])                   ])
}

 4. Datasets制作输入数据

        采用batch,将数据分组输入。

batch_size  = 8image_datasets = {x : datasets.ImageFolder(os.path.join(data_dir,x),data_transforms[x]) for x in ['train','valid']}
dataloaders = {x : torch.utils.data.DataLoader(image_datasets[x],batch_size=batch_size,shuffle=True) for x in ['train','valid']}
dataset_szies = {x :len(image_datasets[x]) for x in ['train','valid']}
class_names = image_datasets['train'].classes
print(image_datasets)
print(dataloaders)

5.将标签的名字读出 

        用123....打标签好像不好,用花的名字作为标签

#读取标签对应的实际名字
with open('./flower_data/cat_to_name.json','r') as f:cat_to_name = json.load(f)
print(cat_to_name)

6.展示原始数据 

        展示下数据

            注意tensor的数据需要转换成numpy格式,而且还需要还原成标准化的结果

def im_convert(tensor):'''展示数据'''image = tensor.to('cpu').clone().detach()image = image.numpy().squeeze()image = image.transpose(1,2,0)image = image * np.array((0.229,0.224,0.225)) + np.array((0.485,0.456,0.406))image = image.clip(0,1)return imagefig = plt.figure(figsize=(20,12))
colunms = 4
rows = 2dataiter = iter(dataloaders['valid'])
inputs,classes =next(dataiter)for idx in range(colunms * rows):ax = fig.add_subplot(rows,colunms,idx+1,xticks = [],yticks = [])ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])plt.imshow(im_convert(inputs[idx]))
plt.show()

7.加载models中提供的模型 

        加载models中提供的模型,并且直接用训练好的权重当作初始化参数

            第一次执行需要下载,可能会比较慢

model_name = 'resnet' #可选的会比较多['resnet','alexnet','vgg','squeezenet','densenet','inception']
# 是否用人家训练好的特征来做
feature_extract = True#是否用GPU训练
train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available .  Training on CPU...')
else:print('CUDA is  available .  Training on GPU...')device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')def set_parameter_requires_grad(model,feature_extracting):if feature_extract:for param in model.parameters():param.requires_grad = False #冻不冻住model_ft = models.resnet152()
print(model_ft)

8.初始化 

        迁移学习,用前人的参数,改变全连接层。

def initalize_model(model_name,num_classes,feature_extract,use_pretrained=True):#选择合适的模型,不同模型的初始化方法稍微有点区别model_ft = Noneinput_size = 0if model_name == 'resnet':model_ft = models.resnet152(pretrained=use_pretrained)set_parameter_requires_grad(model_ft,feature_extract)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Sequential(nn.Linear(num_ftrs,num_classes),nn.LogSoftmax(dim=1))input_size = 224return model_ft,input_sizefeature_extract = True
model_ft,input_size = initalize_model(model_name,102,feature_extract,use_pretrained=True)#GPU计算
model_ft = model_ft.to(device)#模型保存
filename = 'checkpoint.pth'#是否训练所有层params_to_updata = model_ft.parameters()
print('Params to learn')
if feature_extract:params_to_updata = []for name,param in model_ft.named_parameters():if param.requires_grad == True:params_to_updata.append(param)print("\t",name)
else:for name,param in model_ft.named_parameters():if param.requires_grad == True:print("\t",name)print(model_ft)

 

9.优化器设置 

#优化器设置
optimizer_ft = optim.Adam(params_to_updata,lr=1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1) #学习率每7个epoch衰减成原来的1/10
#最后一层已经LogSoftmax()了,所以不能nn.CrossEntropyLoss()来计算了,nn.CrossEntropyLoss()相当于logSoftmax()和nn.NLLoss()整合
criterion = nn.NLLLoss()

10.训练模块

def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False,filename=filename):since = time.time()best_acc = 0'''checkpoint = torch.laod(filename)best_acc = checkpoint['best_acc]model.load_state_dict(checkpoint['optimizer'])model.class_to_idx = checkpoint['mapping']'''model.to(device)val_acc_history = []train_acc_history = []train_losses = []vaild_losses = []LRs = [optimizer.param_groups[0]['lr']]best_model_wts = copy.deepcopy(model.state_dict())for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch,num_epochs-1))print('-'*10)#训练和验证for phase in ['train','valid']:if phase == 'train':model.train() #训练else:model.eval()  #验证running_loss = 0.0running_corrects = 0#把数据都取个遍for inputs,labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)#清零optimizer.zero_grad()#只有训练的时候计算和更新梯度with torch.set_grad_enabled(phase=='train'):if is_inception and phase == 'train':outputs,aux_outputs = model(inputs)loss1 = criterion(outputs,labels)loss2 = criterion(aux_outputs,labels)loss = loss1 + loss2else: #resnet执行的是这里outputs= model(inputs)loss = criterion(outputs,labels)_,preds = torch.max(outputs,1)#训练阶段更新权重if phase == 'train':loss.backward()optimizer.step()#计算损失running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(dataloaders[phase].dataset)epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)time_elapsed = time.time() - sinceprint('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase,epoch_loss,epoch_acc))#得到最好的那次的模型if phase == 'valid' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())state ={'state_dict' : model.state_dict(),'best_acc': best_acc,'optimizer': optimizer.state_dict(),}torch.save(state,filename)if phase == 'valid':val_acc_history.append(epoch_acc)vaild_losses.append(epoch_loss)scheduler.step(epoch_loss)if phase=='train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRs.append(optimizer.param_groups[0]['lr'])print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60,time_elapsed % 60))print('Best val Acc: {:.4f}'.format(best_acc))#训练完后用最好的一次当作模型的最终结果model.load_state_dict(best_model_wts)return model, val_acc_history,train_acc_history,vaild_losses,train_losses,LRs
#开始训练!!!
model_ft,val_acc_history,train_acc_history,vaild_losses,train_losses,LRs = train_model(model_ft,dataloaders,criterion,optimizer_ft,num_epochs=5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/836325.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android Studio在android Emulator中运行的项目黑屏

前言: 最近在做一个Android相关的小项目,因为之前这方面的项目做的比较的少。今天在使用虚拟机调试的时候经常出现一些莫名其妙的问题,经过自己多次的尝试和搜索终于解决了这些问题。 问题: 每次run(运行&#xff09…

【机器学习300问】88、什么是Batch Norm算法?

一、什么是Batch Norm? (1)Batch Norm的本质 神经网络中的Batch Normalization(批量归一化,简称BatchNorm或BN)是一种改进神经网络训练过程的规范化方法,BatchNorm的主要目的是加速神经网络的训…

构建教育新未来:智慧校园平台的深度解读与全景呈现

引言 在全球数字化转型的大潮中,智慧校园平台作为教育信息化的重要载体,正以前所未有的姿态颠覆传统的教育模式,引领教育行业步入一个崭新的时代。这个融合了大数据、人工智能、云计算、物联网等一系列前沿科技的平台,以其强大的功…

mybatis-plus使用指南(1)

快速开始 首先 我们 在创建了一个基本的springboot的基础框架以后&#xff0c;在 pom文件中 引入 mybatisplus的相关依赖 <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3.5…

PyTorch的卷积和池化

卷积计算 input 表示输入的图像filter 表示卷积核, 也叫做滤波器input 经过 filter 的得到输出为最右侧的图像&#xff0c;该图叫做特征图 卷积的计算是将卷积核放入左上角&#xff0c;在局部区域间做点积&#xff0c;然后将卷积核在Input上面依次从左向右&#xff0c;从上到下…

免费证件照一键换底色

最近星期天在家搞了一个小工具&#xff0c;在这里分享下! 废话不多说看看效果&#xff1a; 效果还不错&#xff0c;需要的可以联系我!!!!!!!!! 别的网上可都是一次五块钱这种。太贵了。。&#xff01;&#xff01;

【Dash】开始学习dash

安装Dash 网上很多安装dash的教程&#xff0c;不再赘述 开始Dash 一个dash页面的基本写法 # dash 的基本写法 import dash from dash import html,dcc,callback,Input,Output# 创建一个 dash 应用 app dash.Dash()# 定义布局&#xff0c;定义一个输入框和一个输出框 app.l…

VS项目Debug下生成的EXE在生产机器上运行

使用Visual Studio开发应用程序时&#xff0c;为了临时在非开发机上看一下效果&#xff0c;就直接把Debug下的文件全部拷贝到该机器上&#xff0c;直接双击exe运行。双击之后&#xff0c;没有直接打开应用程序&#xff0c;而是弹出了一个Error弹框。  赶快在网上搜了一遍&…

MFC窗口更新与重绘

窗口更新与重绘 窗口或控件更新其外观的情况通常包括以下几种&#xff1a; 窗口大小变化&#xff1a; 当用户调整窗口大小时&#xff0c;窗口的客户区大小会改变&#xff0c;需要重新绘制窗口内容以适应新的大小。 窗口重叠或暴露&#xff1a; 当窗口被其他窗口遮挡部分或完…

「 安全设计 」68家国内外科技巨头和安全巨头参与了CISA发起的安全设计承诺,包含MFA、默认密码、CVE、VDP等七大承诺目标

美国网络安全和基础设施安全局&#xff08;CISA&#xff0c;CyberSecurity & Infrastructure Security Agency&#xff09;于2024年5月开始呼吁企业是时候将网络安全融入到技术产品的设计和制造中了&#xff0c;并发起了安全设计承诺行动&#xff0c;该承诺旨在补充和建立现…

一个物业管理服务项目的思考——智慧停车场无人值守呼叫系统到电梯五方对讲再到呼叫中心

目录 起源智慧停车场无人值守呼叫系统然后电梯五方对讲系统又然后物业呼叫中心集控E控中心怎么做 之前介绍过一个关于 点这个链接&#xff1a;门卫、岗亭、值班室、门房、传达室如果距离办公室和机房比较远的情况下怎么实现电话通话&#xff0c;基本上属于物业管理服务的范围。…

强化学习在一致性模型中的应用与实验验证

在人工智能领域&#xff0c;文本到图像的生成任务一直是研究的热点。近年来&#xff0c;扩散模型和一致性模型因其在图像生成中的卓越性能而受到广泛关注。然而&#xff0c;这些模型在生成速度和微调灵活性上存在局限。为了解决这些问题&#xff0c;康奈尔大学的研究团队提出了…

【STM32+HAL+Proteus】系列学习教程---中断(NVIC、EXTI、按键)

实现目标 1、掌握STM32的中断知识 2、学会STM32CubeMX软件关于中断的配置 3、具体目标&#xff1a;1、外部中断检测按键&#xff0c;每按一次计一次数&#xff0c;满5次LED1状态取反。 一、中断概述 1.1、中断定义 CPU执行程序时&#xff0c;由于发生了某种随机的事件(包括…

实验室纳新宣讲会(java后端)

前言 这是陈旧已久的草稿2021-09-16 15:41:38 当时我进入实验室&#xff0c;也是大二了&#xff0c;实验室纳新需要宣讲&#xff0c; 但是当时有疫情&#xff0c;又没宣讲成。 现在2024-5-12 22:00:39&#xff0c;发布到[个人]专栏中。 实验室纳新宣讲会&#xff08;java后…

基于GD32的简易数字示波器(4)- 软件

这期记录的是项目实战&#xff0c;做一个简易的数字示波器。 教程来源于嘉立创&#xff0c;帖子主要做学习记录&#xff0c;方便以后查看。 本期主要介绍GD32的keil5环境和串口下载。详细教程可观看下方链接。 软件-第1讲-工程模板新建_哔哩哔哩_bilibili 2.1 开发环境搭建 …

logback日志持久化

1、问题描述 使用logback持久化记录日志。 2、我的代码 logback是Springboot框架里自带的&#xff0c;所以只要引入“spring-boot-starter”就行了。无需额外引入logback依赖。 pom.xml <?xml version"1.0" encoding"UTF-8"?> <project xmlns&…

2005-2022年各省居民人均可支配收入数据(含城镇居民人均可支配收入、农村居民人均可支配收入)(无缺失)

2005-2022年各省居民人均可支配收入数据&#xff08;含城镇居民人均可支配收入、农村居民人均可支配收入&#xff09;&#xff08;无缺失&#xff09; 1、时间&#xff1a;2005-2022年 2、来源&#xff1a;国家统计局、统计年鉴 3、指标&#xff1a;全体居民人均可支配收入、…

探索大型语言模型(LLM)的世界

​ 引言 大型语言模型&#xff08;LLM&#xff09;作为人工智能领域的前沿技术&#xff0c;正在重塑我们与机器的交流方式&#xff0c;在医疗、金融、技术等多个行业领域中发挥着重要作用。本文将从技术角度深入分析LLM的工作原理&#xff0c;探讨其在不同领域的应用&#xff0…

开源软件托管平台gogs操作注意事项

文章目录 一、基本说明二、gogs私有化部署三、设置仓库git链接自动生成参数四、关闭新用户注册入口 私有化部署gogs托管平台&#xff0c;即把gogs安装在我们自己的电脑或者云服务器上。 一、基本说明 系统环境&#xff1a;ubuntu 20.4docker安装 二、gogs私有化部署 前期准…

Ansible常用变量【上】

转载说明&#xff1a;如果您喜欢这篇文章并打算转载它&#xff0c;请私信作者取得授权。感谢您喜爱本文&#xff0c;请文明转载&#xff0c;谢谢。 在Ansible中会用到很多的变量&#xff0c;Ansible常用变量包括以下几种&#xff1a; 1. 自定义变量——在playbook中用户自定义…