使用pytorch进行迁移学习的两个步骤

1. 步骤及代码

迁移学习一般都会使用两个步骤进行训练:

  1. 固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛;
  2. 使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节。
import os
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms as T
import numpy as np# 设置均值、方差
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]# 还原减均值除以方差之前的数据,用于可视化
def reduction_img_show(tensor, mean, std) -> None:to_img = T.ToPILImage()reduced_img = to_img(tensor * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1))reduced_img.show()def getResNet(*, class_names: str, loadfile: str = None):if loadfile is not None:model = torchvision.models.resnet18()model.load_state_dict(torch.load('resnet18-f37072fd.pth'))  # 加载权重else:model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)  # 模型自动下载到C:\Users\GaryLau\.cache\torch\hub\checkpoints# 将所有的参数层冻结,设置模型除最后一层以外都不可以进行训练,使模型只针对最后一层进行微调for param in model.parameters():param.requires_grad = False# 输出全连接层信息print(model.fc)x = model.fc.in_features  # 获取全连接层输入维度model.fc = torch.nn.Linear(in_features=x, out_features=len(class_names))  # 创建新的全连接层print(model.fc)  # 输出新的全连接层return model# 定义训练函数
def train(model, device, train_loader, criterion, optimizer, epoch):model.train()all_loss = []for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()y_pred = model(data)loss = criterion(y_pred, target)loss.backward()all_loss.append(loss.item())optimizer.step()if batch_idx % 10 == 0:print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),np.mean(all_loss)))def val(model, device, val_loader, criterion):model.eval()test_loss = []correct = []with torch.no_grad():for data, target in val_loader:data, target = data.to(device), target.to(device)y_pred = model(data)test_loss.append(criterion(y_pred, target).item())pred = y_pred.argmax(dim=1, keepdim=True)correct.append(pred.eq(target.view_as(pred)).sum().item()/pred.size(0))print('-->Test: Average loss:{:.4f}, Accuracy:({:.0f}%)\n'.format(np.mean(test_loss), 100 * sum(correct) / len(correct)))# 训练,验证时的预处理
transform = {'train': T.Compose([T.RandomResizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize(mean=mean, std=std)]),'val': T.Compose([T.Resize((224,224)),T.ToTensor(),T.Normalize(mean=mean, std=std)])}# 加载训练、验证数据
dataset_train = ImageFolder(r'./train', transform=transform['train'])
dataset_val = ImageFolder(r'./test', transform=transform['val'])# 类别标签
class_names = dataset_train.classes
print(dataset_train.class_to_idx)
print(dataset_val.class_to_idx)# 显示一张训练、验证图
# reduction_img_show(dataset_train[0][0], mean, std)
# reduction_img_show(dataset_val[0][0], mean, std)# 使用DataLoader遍历数据
dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True, sampler=None, num_workers=0,pin_memory=False, drop_last=False)
dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=False, sampler=None, num_workers=0,pin_memory=False, drop_last=False)# 使用方式一,使用next不断获取一个batch的数据
dataiter_train = iter(dataloader_train)
imgs, labels = next(dataiter_train)
print(imgs.size())
# reduction_img_show(imgs[0], mean, std)
# reduction_img_show(imgs[1], mean, std)
multi_imgs = torchvision.utils.make_grid(imgs, nrow=10)  # 拼接一个batch的图像用于展示
# reduction_img_show(multi_imgs, mean, std)# 获取ResNet模型,并加载预训练模型权重,将最后一层(输出层)去掉,换成一个新的全连接层,新全连接层输出的节点数是新数据的类别数
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)# 构建模型
model = getResNet(class_names=class_names, loadfile='resnet18-f37072fd.pth')
model.to(device)# 构建损失函数
criterion = torch.nn.CrossEntropyLoss()
# 指定新加的全连接层为要更新的参数
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)  # 只需要更新最后一层fc的参数if __name__ == '__main__':### 步骤一,微调最后一层first_model = 'resnet18-f37072fd_finetune_fcLayer.pth'for epoch in range(1, 6):train(model, device, dataloader_train, criterion, optimizer, epoch)val(model, device, dataloader_val, criterion)# 仅保存了最后新添加的全连接层的参数#torch.save(model.fc.state_dict(), first_model)torch.save(model.state_dict(), first_model)### 步骤二,小学习率微调所有层second_model = 'resnet18-f37072fd_finetune_allLayer.pth'optimizer2 = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=3, gamma=0.9)# 将所有的参数层设为可训练的for param in model.parameters():param.requires_grad = Trueif os.path.exists(second_model):model.load_state_dict(torch.load(second_model))   # 加载本地模型else:model.load_state_dict(torch.load(first_model))    # 加载步骤一训练得到的本地模型print('Finetune all layers with small learning rate......')for epoch in range(1, 101):train(model, device, dataloader_train, criterion, optimizer2, epoch)if optimizer2.state_dict()['param_groups'][0]['lr'] > 0.00001:exp_lr_scheduler.step()print(f"learning rate: {optimizer2.state_dict()['param_groups'][0]['lr']}")val(model, device, dataloader_val, criterion)# 保存整个模型torch.save(model.state_dict(), second_model)print('Done.')

2. 完整资源

https://download.csdn.net/download/liugan528/89833913

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/55464.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

利用 notepad++ 初步净化 HaE Linkfinder 规则所提取的内容(仅留下接口行)

去掉接口的带参部分 \?.*去掉文件行 .*\.(docx|doc|xlsx|xls|txt|xml|html|pdf|ppt|pptx|odt|ods|odp|rtf|md|epub|css|scss|less|sass|styl|png|jpg|jpeg|gif|svg|ico|bmp|tiff|webp|heic|dds|raw|vue|js|ts|mp4|avi|mov|wmv|mkv|flv|webm|mp3|wav|aac|flac|ogg|m4a).*(\r\…

Redission · 可重入锁(Reentrant Lock)

前言 Redisson是一个强大的分布式Java对象和服务库,专为简化在分布式环境中的Java开发而设计。通过Redisson,开发人员可以轻松地在分布式系统中共享数据、实现分布式锁、创建分布式对象,并处理各种分布式场景的挑战。 Redisson的设计灵感来…

【AI大模型】使用Embedding API

一、使用OpenAI API 目前GPT embedding mode有三种,性能如下所示: 模型每美元页数MTEB得分MIRACL得分text-embedding-3-large9,61554.964.6text-embedding-3-small62,50062.344.0text-embedding-ada-00212,50061.031.4 MTEB得分为embedding model分类…

如何解决深拷贝循环引用的问题

深拷贝循环引用的问题是JavaScript中一个常见且需要仔细处理的问题。循环引用指的是对象之间存在相互引用的关系,形成一个闭环,这样在深拷贝过程中可能会导致递归无限循环,占用大量内存,并最终导致堆栈溢出。以下是一些解决深拷贝…

快速上手C语言【上】(非常详细!!!)

目录 1. 基本数据类型 2. 变量 2.1 定义格式 和 命名规范 2.2 格式化输入和输出(scanf 和 printf) ​编辑 2.3 作用域和生命周期 3. 常量 4. 字符串转义字符注释 5. 操作符 5.1 双目操作符 5.1.1 算数操作符 5.1.2 移位操作符 5.1.3 位操作符…

【C/C++】错题记录(四)

题目一 一个函数可以有很多个返回值(有很多个return语句),但是最终只能有一个return语句执行。 题目二 题目三 题目四 题目五 程序数据结构算法 题目六 题目七 题目八 题目九 D选项是语句……

前端知识汇总(持续更新)

见:GitHub - eHackyd/Front-End: 前端知识汇总 包含:html,css,js,ts等等(语法使用实例)

Top4免费音频剪辑软件大比拼,2024年你选哪一款?

现在我们生活在一个数字化的时代,音频内容对我们来说很重要。不管是给自己拍的视频配背景音乐、整理开会时的录音,还是自己写歌,有个好用的音频剪辑软件都特别重要。今天,我要给大家介绍几款特别好用的音频剪辑软件免费的&#xf…

模型 SECI(知识的创造)

系列文章 分享 模型,了解更多👉 模型_思维模型目录。知识创造的螺旋转化模型。 1 SECI的应用 1.1 Tech Innovations移动应用创新 Tech Innovations是一家软件开发公司,致力于开发创新的移动应用程序。为了提升团队的知识共享和创新能力&…

Unity3D 单例模式

Unity3D 泛型单例 单例模式 单例模式是一种创建型设计模式,能够保证一个类只有一个实例,提供访问实例的全局节点。 通常会把一些管理类设置成单例,例如 GameManager、UIManager 等,可以很方便地使用这些管理类单例,…

【Qt】Qt学习笔记(一):Qt界面初识

Qt 是一个跨平台应用程序和 UI 开发框架。使用 Qt 您只需一次性开发应用程序,无须重新编写源代码,便可跨不同桌面和嵌入式操作系统部署这些应用程序。Qt Creator是跨平台的Qt集成开发环境。 创建项目 Qt的一些界面,初学时一般选择Qt Widgets …

在线教育系统开发:SpringBoot框架的实战应用

4系统概要设计 4.1概述 本系统采用B/S结构(Browser/Server,浏览器/服务器结构)和基于Web服务两种模式,是一个适用于Internet环境下的模型结构。只要用户能连上Internet,便可以在任何时间、任何地点使用。系统工作原理图如图4-1所示: 图4-1系统工作原理…

Linux下静态库与动态库制作及分文件编程

Linux下静态库与动态库制作及分文件编程 文章目录 Linux下静态库与动态库制作及分文件编程1.分文件编程1.1优点1.2操作逻辑1.3示例 2.Linux库的概念3.静态库的制作与使用3.1优缺点3.2命名规则3.3制作步骤3.4开始享用 4.动态库的制作与使用4.1优缺点4.2动态库命名规则4.3制作步骤…

说说你对es6中promise的理解?

ES6中的Promise是一个非常重要的特性,它为异步编程提供了一种更优雅、更简洁的解决方案。以下是我对ES6中Promise的理解: 一、Promise的基本概念 Promise是异步编程的一种解决方案,它代表了一个异步操作的最终完成(或失败&#…

GO实战课】第五讲:电子商务网站(5)——用户管理和注册

1. 简介 本课程将探讨电子商务网站的用户管理和注册功能,以及使用GO语言实现。在本课程中,我们将介绍如何设计一个可扩展、可靠和高性能的用户管理和注册系统,并演示如何使用GO语言编写相关代码。 本课程的目标是帮助学生理解电子商务网站的用户管理和注册功能,并提供一个…

【数据结构】双向链表(Doubly Linked List)

双向链表(Doubly Linked List)是一种链式数据结构,它的每个节点都包含三个部分:数据、指向前一个节点的指针(prev),以及指向下一个节点的指针(next)。与单向链表不同&…

基于Vue的汽车维修配件综合管理系统设计与实现SpringBoot后端源码

目录 1. 系统背景 2. 系统目标 3. 功能模块 4. 技术选型 5. 关键技术点 6. 实现步骤 7. 项目意义 8. 后期展望 1. 系统背景 市场需求分析:随着汽车保有量的不断增加,汽车维修和保养的需求日益增长。车主对维修质量和配件质量的要求也越来越高。汽…

class 004 选择 冒泡 插入排序

我感觉这个真是没有什么好讲的, 这个是比较简单的, 感觉没有什么必要写一篇博客, 而且这个这么简单的排序问题肯定有人已经有写好的帖子了, 肯定写的比我好, 所以我推荐大家直接去看“左程云”老师的讲解就很好了, 一定是能看懂的, 要是用文字形式再写一遍, 反而有点画蛇添足了…

java中有两个list列表,尽量少的去循环

java中有两个list列表,一个list列表是paymentRecord,另外一个list是listApplyBase,paymentRecord中的lendCode字段值跟listApplyBase中的repaymentCode字段值是对应的,用stream流去循环paymentRecord列表,然后判断当pa…

javascript中如何实现继承(附案例)

在 JavaScript 中,有多种实现继承的方法,最常用的有原型链继承、构造函数继承、组合继承和 class 继承(ES6)。下面以 ES6 的 class 继承为例,展示如何实现继承: 示例: // 父类 class Animal {…