【AI】基于已有模型训练自己的模型(迁移)

实际工作中,我们可能缺乏算力去从头到尾训练一个模型,使用别人训练好的模型(通常是经典模型)就成了一个很好的选择,这样我们就不需要设置每一层的初始参数,极大的提高了训练的效率;但是在使用别人的模型时,有时候会有一些不适应的地方,以分类项目而言,可能不同的数据集的分类类别就不一样,需要修改模型最后的输出类别。我们以最简单的resnet18模型为例,来进行我们分类任务的迁移学习。

1.准备工作

导包准备

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
#pip install torchvision
from torchvision import transforms, models, datasets
#https://pytorch.org/docs/stable/torchvision/index.html
import imageio
import time
import warnings
warnings.filterwarnings("ignore")
import random
import sys
import copy
import json
from PIL import Image

我们还是使用之前用过的花朵分类任务,先制作Dataset和Dataloader

data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'data_transforms = {'train': transforms.Compose([transforms.Resize([96, 96]),transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选transforms.CenterCrop(64),#从中心开始裁剪transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差]),'valid': transforms.Compose([transforms.Resize([64, 64]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 这里的batchsize根据自己的显卡水平来进行设置,如果是cpu跑,尽量使用小一点的batchsize
batch_size = 512image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes# 读入标签对应文件
with open('cat_to_name.json', 'r') as f:cat_to_name = json.load(f)

2.准备模型

#是否用人家训练好的特征来做
feature_extract = True #都用人家特征,咱先不更新def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = False# 是否用GPU训练
train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available.  Training on CPU ...')
else:print('CUDA is available!  Training on GPU ...')device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model_ft = models.resnet18()#选用18层的网络能快点

可以打印一下模型,看一下模型的情况

model_ft#以下是输出
ResNet((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): BasicBlock((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(1): BasicBlock((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer2): Sequential((0): BasicBlock((conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer3): Sequential((0): BasicBlock((conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer4): Sequential((0): BasicBlock((conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=512, out_features=1000, bias=True)
)

从最后一行可以看到输出类别(out_features)为1000,我们花朵分类的任务是102,所以我们需要修改模型最后的全连接层输出为102

开始修改模型

def initialize_model( num_classes, feature_extract, use_pretrained=True):model_ft = models.resnet18(pretrained=use_pretrained)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, num_classes)#类别数自己根据自己任务来input_size = 512 #输入大小根据自己配置来return model_ft, input_size

配置模型修改的层,根据需求,我们先只训练最后一层

model_ft, input_size = initialize_model(102, feature_extract, use_pretrained=True)#GPU还是CPU计算
model_ft = model_ft.to(device)
# 模型保存,名字自己起
filename='checkpoint.pth'
# 是否训练所有层
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:params_to_update = []for name,param in model_ft.named_parameters():if param.requires_grad == True:params_to_update.append(param)print("\t",name)
else:for name,param in model_ft.named_parameters():if param.requires_grad == True:print("\t",name)

3.训练

定义优化器、损失函数等

# 优化器设置
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)#要训练啥参数,你来定
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.1#学习率每5个epoch衰减成原来的1/10
criterion = nn.CrossEntropyLoss()

定义训练函数

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,filename='best.pt'):#咱们要算时间的since = time.time()#也要记录最好的那一次best_acc = 0#模型也得放到你的CPU或者GPUmodel.to(device)#训练过程中打印一堆损失和指标val_acc_history = []train_acc_history = []train_losses = []valid_losses = []#学习率LRs = [optimizer.param_groups[0]['lr']]#最好的那次模型,后续会变的,先初始化best_model_wts = copy.deepcopy(model.state_dict())#一个个epoch来遍历for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 训练和验证for phase in ['train', 'valid']:if phase == 'train':model.train()  # 训练else:model.eval()   # 验证running_loss = 0.0running_corrects = 0# 把数据都取个遍for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)#放到你的CPU或GPUlabels = labels.to(device)# 清零optimizer.zero_grad()# 只有训练的时候计算和更新梯度outputs = model(inputs)loss = criterion(outputs, labels)_, preds = torch.max(outputs, 1)# 训练阶段更新权重if phase == 'train':loss.backward()optimizer.step()# 计算损失running_loss += loss.item() * inputs.size(0)#0表示batch那个维度running_corrects += torch.sum(preds == labels.data)#预测结果最大的和真实值是否一致epoch_loss = running_loss / len(dataloaders[phase].dataset)#算平均epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)time_elapsed = time.time() - since#一个epoch我浪费了多少时间print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 得到最好那次的模型if phase == 'valid' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())state = {'state_dict': model.state_dict(),#字典里key就是各层的名字,值就是训练好的权重'best_acc': best_acc,'optimizer' : optimizer.state_dict(),}torch.save(state, filename)if phase == 'valid':val_acc_history.append(epoch_acc)valid_losses.append(epoch_loss)#scheduler.step(epoch_loss)#学习率衰减if phase == 'train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRs.append(optimizer.param_groups[0]['lr'])print()scheduler.step()#学习率衰减time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# 训练完后用最好的一次当做模型最终的结果,等着一会测试model.load_state_dict(best_model_wts)return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 

开始训练

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20)

训练了20个epoch,模型的精度最高只有40%

然后放开所有层参数,让他们全部参与进来

for param in model_ft.parameters():param.requires_grad = True# 再继续训练所有的参数,学习率调小一点
optimizer = optim.Adam(model_ft.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.1)# 损失函数
criterion = nn.CrossEntropyLoss()model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=10,)

这次训练的精度有了显著的提高

Epoch 6/9
----------
Time elapsed 4m 48s
train Loss: 0.4505 Acc: 0.8646
Time elapsed 4m 51s
valid Loss: 1.5857 Acc: 0.6259
Optimizer learning rate : 0.0010000

精度达到了62%。

这次的工作只是对迁移学习的简单尝试,代码借用了一些教程的。学习的过程就是小步快跑,当我们能力不足时,适当的借鉴可以提高我们成长的速度;这也正是迁移学习的思想吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/191841.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python中的加法测试题实现

随机生成5道10以内的加法测试题,用户在10秒内使用键盘输入答案。完成全部5道答题之后,计算机生成答题记录报告,并对答题情况进行分析,显示“答对了”,或“答错了”、并显示正确答案。如果未能按时完成,则显…

opencv知识库:基于cv2.flip()函数对图像进行随机翻转(水平/垂直)

需求场景 欲对RGB格式的lena图像进行随机翻转,要求这些图像不翻转、水平翻转、垂直翻转的概率都为1/3。 功能代码 import cv2 import random# 读取并展示图像 img cv2.imread("lena.jpg") cv2.imshow(lena, img) cv2.waitKey(0)for i in range(6): #…

Python concurrent.futures实现多进程多线程编程

Python的concurrent.futures模块可以很方便的实现多进程、多线程运行,减少了多进程带来的的同步和共享数据问题。 Executor是一个抽象类,表示一个可执行的上下文。Future则代表一个将要执行的任务,并提供了一些方法来获取任务的状态和结果。T…

Hdoop学习笔记(HDP)-Part.18 安装Flink

目录 Part.01 关于HDP Part.02 核心组件原理 Part.03 资源规划 Part.04 基础环境配置 Part.05 Yum源配置 Part.06 安装OracleJDK Part.07 安装MySQL Part.08 部署Ambari集群 Part.09 安装OpenLDAP Part.10 创建集群 Part.11 安装Kerberos Part.12 安装HDFS Part.13 安装Ranger …

技术面选股的方法

技术面选股是根据股票的价格走势、成交量等技术指标,来预测股票未来走势的一种方法。以下是一些常用的技术面选股方法的详细解析: 均线分析:这是一种常见且基础的技术面选股方法,投资者会计算股票的移动平均线(如5日、…

头歌JUnit单元测试相关实验入门

一、入门实验 1.1第一个Junit测试程序 任务描述 请学员写一个名为testSub()的测试函数,来测试给定的减法函数是否正确。 相关知识 Junit编写原则 1、简化测试的编写,这种简化包括测试框架的学习和实际测试单元的编写。 2、测试单元保持持久性。 3、利用…

【已解决】AttributeError: module ‘gradio‘ has no attribute ‘Image‘

问题描述 AttributeError: module gradio has no attribute Image 不知道作者用的是哪个gradio版本,最新的版本报错AttributeError: module gradio has no attribute outputs , 换一个老一点的版本会报错AttributeError: module gradio has no attribute…

短线买入卖出有哪些交易技巧?

前面两节课,我们认识了短线交易,知道了短线交易常见的买入卖出时机,这节课,我们来讲解一下短线买入卖出的一些交易技巧。话不多时,直接进入重点! 一、短线交易要果断 短线波动快,在出现买卖信号…

排序算法总结(Python、Java)

Title of Content 1 冒泡排序 Bubble sort:两两交换,大的冒到最后概念排序可视化代码实现Python - 基础实现Python - 优化实现Java - 优化实现C - 优化实现C - 优化实现 2 选择排序 Selection sort:第i轮遍历时,将未排序序列中最小…

华为OD机试 - CPU算力分配(Java JS Python C)

题目描述 现有两组服务器A和B,每组有多个算力不同的CPU,其中 A[i] 是 A 组第 i 个CPU的运算能力,B[i] 是 B组 第 i 个CPU的运算能力。 一组服务器的总算力是各CPU的算力之和。 为了让两组服务器的算力相等,允许从每组各选出一个CPU进行一次交换, 求两组服务器中,用于…

反序列化漏洞详解(一)

目录 一、php面向对象 二、类 2.1 类的定义 2.2 类的修饰符介绍 三、序列化 3.1 序列化的作用 3.2 序列化之后的表达方式/格式 ① 简单序列化 ② 数组序列化 ③ 对象序列化 ④ 私有修饰符序列化 ⑤ 保护修饰符序列化 ⑥ 成员属性调用对象 序列化 四、反序列化 …

【笔记】常用的Linux命令之解压缩:tar、zip、rar 命令

1、tar 常用压缩和解压缩 # 压缩文件 file1 和目录 dir2 到 test.tar.gz tar -zcvf test.tar.gz file1 dir2 # 解压 test.tar.gz(将 c 换成 x 即可) tar -zxvf test.tar.gz 额外知识:查看压缩文件内容 # 列出压缩文件的内容 tar -ztvf test…

unity学习笔记

一、线段渲染器 在Unity中,线段渲染器(Line Renderer)是一种用于在场景中绘制线段的组件。线段渲染器非常适合用于创建轨迹、路径、光束等效果。 1. 创建Line Renderer:在Unity编辑器中,你可以通过创建空对象 -> …

Linux - 动态库的加载 和 重谈进程地址空间 - vscode 当中的 Remote - SSH 插件

推书:《现代操作系统》《操作系统--精髓于设计原理》《UNIX环境高级编程》 目录 前言 程序的加载 程序没有加载之前的地址(此时还是程序) 程序被加载到内存之后(此时是进程) 动态库的地址 静态库的不加载&#xff…

力扣labuladong一刷day24天

力扣labuladong一刷day24天 文章目录 力扣labuladong一刷day24天一、875. 爱吃香蕉的珂珂二、1011. 在 D 天内送达包裹的能力三、410. 分割数组的最大值 一、875. 爱吃香蕉的珂珂 题目链接:https://leetcode.cn/problems/koko-eating-bananas/?utm_sourceLCUS&…

数据结构——堆排序的topk问题

呀哈喽,我是结衣 前言 今天给大家带来的堆排序的topk问题。topk就是在许多数中,找出前k个大的数,可能是几十个数,也可能是几千万个数中找。今天我们将要在1000000(一百万)个数中找出前10大的数。 知识点 C…

【c】角谷猜想

#include<stdio.h> int coll(int x)//定义函数 {int count0;while(x>1){if(x%20){xx/2;count;}else{x3*x1;count;}}return count; } int main() {int n,num;scanf("%d",&n);int arr[n1];for(int i1;i<n;i)//输入n组数据保存到数组中{scanf("%d&…

数据结构之哈希表

数据结构之哈希表 文章目录 数据结构之哈希表一、哈希概念二、哈希冲突三、哈希函数常见哈希函数 四、哈希冲突解决闭散列闭散列的思考线性探测线性探测的实现 二次探测 开散列开散列概念开散列的思考开散列实现 五、开散列与闭散列比较 一、哈希概念 顺序结构以及平衡树中&am…

为获取导入百分比,使用easyexcel获取导入excel表总行数

背景 分批读取大量数据的excel文件&#xff0c;每次读取1000行数据&#xff0c;然后插入数据库&#xff0c;并且去执行一个方法&#xff0c;执行完毕后更新此行数据的状态。需要获取已更新数据的占比&#xff0c;即计算百分比。 因为是分批读取的&#xff0c;我们不可以直接用已…

CPP-SCNUOJ-Problem P26. [算法课动态规划] 打家劫舍

Problem P26. [算法课动态规划] 打家劫舍 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统&#xff0c;如果两间相邻的房屋在同一晚上被小偷闯入&#xff0c;系统会自动…