深度学习:迁移学习

迁移学习

标题1.什么是迁移学习

迁移学习(Transfer Learning)是一种机器学习方法,就是把为任务 A 开发 的模型作为初始点,重新使用在为任务 B 开发模型的过程中。迁移学习是通过 从已学习的相关任务中转移知识来改进学习的新任务,虽然大多数机器学习算 法都是为了解决单个任务而设计的,但是促进迁移学习的算法的开发是机器学 习社区持续关注的话题。 迁移学习对人类来说很常见,例如,我们可能会发现 学习识别苹果可能有助于识别梨,或者学习弹奏电子琴可能有助于学习钢琴。
找到目标问题的相似性,迁移学习任务就是从相似性出发,将旧领域 (domain)学习过的模型应用在新领域上

标题2.迁移学习的步骤

1、选择预训练的模型和适当的层
通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。
2、冻结预训练模型的参数
保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。
3、在新数据集上训练新增加的层
在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。
4、微调预训练模型的层
在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。
5、评估和测试
在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

标题3.迁移学习实例

该实例使用的模型是ResNet-18残差神经网络模型
###1. 导入必要的库

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

这里导入了后续代码会用到的库,具体如下:
torch:PyTorch 深度学习框架的核心库。
torchvision.models:包含了预训练的模型,这里会用到 ResNet-18。
torch.nn:用于构建神经网络的模块。
torch.utils.data.Dataset 和 torch.utils.data.DataLoader:用于自定义数据集和加载数据。
torchvision.transforms:用于图像的预处理。
PIL.Image:用于读取图像。
numpy:用于数值计算。
###2. 加载预训练模型并修改全连接层

resnet_model= models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():print(param)param.requires_grad=False
in_features=resnet_model.fc.in_features
resnet_model.fc=nn.Linear(in_features,20)
params_to_update=[]
for param in resnet_model.parameters():if param.requires_grad==True:params_to_update.append(param)

加载预训练的 ResNet-18 模型。
把模型中所有参数的 requires_grad 设置为 False,也就是冻结这些参数,使其在训练时不更新。
获取原模型全连接层的输入特征数,然后将全连接层替换为一个新的全连接层,输出维度为 20。
收集所有 requires_grad 为 True 的参数,这些参数会在训练时更新。
###3. 定义图像预处理变换

data_transforms = {'train':transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45),transforms.CenterCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])]),'valid':transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

定义了两个图像预处理的组合变换,分别用于训练集和验证集。
训练集的变换包含了数据增强操作,像随机旋转、水平翻转、垂直翻转等。
验证集的变换只包含了调整大小、转换为张量和标准化操作。

4. 自定义数据集类

class food_dataset(Dataset):def __init__(self,file_path,transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return  len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label,dtype=np.int64))return image,label

自定义了一个 food_dataset 类,继承自 torch.utils.data.Dataset。 init 方法:解析包含图像路径和标签的文本文件,把图像路径和标签分别存到 self.imgs 和 self.labels 中。
len 方法:返回数据集的大小。
getitem 方法:根据索引读取图像,对图像进行预处理,将标签转换为张量,然后返回图像和标签。

5. 创建数据集和数据加载器

training_data = food_dataset(file_path='./trainbig.txt',transform=data_transforms['train'])
test_data = food_dataset(file_path='./testbig.txt',transform=data_transforms['valid'])
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)

创建训练集和测试集的数据集对象。
创建训练集和测试集的数据加载器,设置批量大小为 64,并且打乱数据
###6. 配置训练设备、损失函数、优化器和学习率调度器

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
model=resnet_model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update,lr=0.001)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.5)

选择合适的训练设备(GPU 或 CPU)。
把模型移动到所选设备上。
定义交叉熵损失函数。
定义 Adam 优化器,只对之前收集的需要更新的参数进行优化。
定义学习率调度器,每 5 个 epoch 将学习率乘以 0.5。
###7. 定义训练和测试函数

def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num = 1for X,y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()def test(dataloader, model,loss_fn):global best_accsize = len(dataloader.dataset)num_batches =len(dataloader)model.eval()test_loss,correct =0,0with torch.no_grad():for X, y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X)test_loss+=loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test result:\n Accuracy:{(100 * correct)}%, Avg loss: {test_loss}")acc_s.append(correct)loss_s.append(test_loss)if correct>best_acc:best_acc=correct

train 函数:将模型设置为训练模式,遍历训练数据加载器,计算损失,反向传播并更新模型参数。
test 函数:将模型设置为评估模式,遍历测试数据加载器,计算测试集的准确率和平均损失,记录最佳准确率。
8. 训练模型并保存

epochs = 20
acc_s = []
loss_s =[]
for t in range(epochs):print(f"Epoch {t + 1}\n-----------")train(train_dataloader, model,loss_fn, optimizer)scheduler.step()test(test_dataloader,model,loss_fn)
print('最优训练结果为:',best_acc)
torch.save(model.state_dict(), 'food_classification_model.pt')

训练模型 20 个 epoch。
每个 epoch 结束后,更新学习率并进行测试。
打印最优训练结果。
保存模型的参数到 food_classification_model.pt 文件中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/903344.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Rabbitmq下载和安装(Windows系统,百度网盘)

一.下载安装Erlang 1.百度云下载 链接:https://pan.baidu.com/s/1k_U25KKngEf1iXWD1ANOeg 提取码:8ilc 2.安装 傻瓜式安装 直接下一步 选择自己要安装的路径 3.配置环境变量 增加变量名为:ERLANG_HOME 变量值填写自己的安装路径&#x…

(一)Linux的历史与环境搭建

【知识预告】 Linux背景介绍Linux操作系统特性Linux的应用场景Linux的发行版本搭建Linux环境 1 Linux背景介绍 1.1 什么是Linux? Linux是一种自由、开源的操作系统。严格来说,它是基于类Unix设计思想,旨在为用户提供稳定、安全、高效的计…

光流法:从传统方法到深度学习方法

1 光流法简介 光流(Optical Flow)是指图像中像素灰度值随时间的变化而产生的运动场。 简单来说,它描述了图像中每个像素点的运动速度和方向。 光流法是一种通过分析图像序列中像素灰度值来计算光流的方法。对于图像数据计算出来的光流是一个二…

解决ssh拉取服务器数据,要多次输入密码的问题

问题在于,每次循环调用 rsync 都是新开一个连接,所以每次都需要输入一次密码。为了只输入一次密码,有以下几种方式可以解决: ✅ 推荐方案:设置 SSH 免密登录 最稳最安全的方式是:配置 SSH 免密登录&#x…

web技术与Nginx网站服务

目录 一. web基础 1. 域名概念 2. Hosts 文件 3. DNS 4. 域名注册 5. 网页与 HTML 二. 网页概述 1. HTML 概述 2. HTML 基本标签 3. 网站和主页 三. 静态网页与动态网页 1. 静态网页 2. 动态网页 3. 动态网页语言 四. HTTP 协议 1. HTTP 协议概述 2. HTTP …

信创系统资产清单采集脚本:主机名+IP+MAC 一键生成 CSV

原文链接:信创系统资产清单采集脚本:主机名IPMAC 一键生成 CSV Hello,大家好啊!今天给大家带来一篇在信创终端操作系统上自动批量采集主机名、IP 和 MAC 并导出为 CSV 表格的实战文章!本方案使用 sshpass 和 Bash 脚本…

【dify+docker安装教程】

目录 一、dify安装包下载 二、运行环境配置 1、下载docker 2、安装 2.1 新建文件夹 2.2 安装 2.3 命令安装 3.下载完成后需要重启电脑,注意保存文档!!注意保存!!注意!!(血的教…

HTML 地理定位(Geolocation)教程

HTML 地理定位(Geolocation)教程 简介 HTML5 的 Geolocation API 允许网页应用获取用户的地理位置信息。这个功能可用于提供基于位置的服务,如导航、本地搜索、天气预报等。本教程将详细介绍如何在网页中实现地理定位功能。 工作原理 浏览器可以通过多种方式确定…

协作开发攻略:Git全面使用指南 — 引言

协作开发攻略:Git全面使用指南 — 引言 Git 是一种分布式版本控制系统,用于跟踪文件和目录的变更。它能帮助开发者有效管理代码版本,支持多人协作开发,方便代码合并与冲突解决,广泛应用于软件开发领域。 文中内容仅限技…

毕业设计-基于预训练语言模型与深度神经网络的Web入侵检测系统

项目技术说明 基于预训练语言模型与深度神经网络的Web入侵检测系统,通过预训练模型CodeBert分词,将分词输入给BiGRU的深度学习模型训练。通过sniff函数实时捕获http流量信息,将流量信息输入给模型进行检测,模型可以检测的类别有S…

[计算机科学#4]:二进制如何塑造数字世界(0和1的力量)

【核知坊】:释放青春想象,码动全新视野。 我们希望使用精简的信息传达知识的骨架,启发创造者开启创造之路!!! 内容摘要: 二进制是计算机世界的基石,数学是世界的…

JUC中各种锁机制的应用和原理及死锁问题定位

JUC中各种锁机制的应用和原理及死锁问题定位 在互联网大厂Java求职者的面试中,经常会被问到关于JUC(Java Util Concurrency)中的各种锁机制及其应用和原理的问题。本文通过一个故事场景来展示这些问题的实际解决方案。 第一轮提问 面试官&…

配置Ubuntu18.04中的Qt Creator为中文(图文详解)

配置Qt Creator为中文 1、前言2、先设置Ubuntu系统语言为中文3、配置Qt Creator中文环境2.1 IBus输入法(方法一)2.2、测试IBus输入法2.21IBus输入法终端中测试2.2.2IBus输入法Qt Creator中测试 2.3、Fcitx输入法(方法二)2.3.1安装…

高性能服务器配置经验指南3——安装服务器可能遇到的问题及解决方法

文章目录 1、重装系统后VScode远程连接失败问题2、XRDP连接黑屏问题1. 打开文件2. 添加配置3. 重启xrdp服务 3、VScode远程免密连接问题4、Vim编辑文件时出现不同用户冲突编辑的问题 在完成 服务器基本配置和 深度学习环境准备后,大家应该就可以正常使用服务器了&…

PyQt6基础_QThread

目录 前置 代码: 运行 正常运行 QThread运行报错 视频 前置 1 PySide6.QtCore.QThread - Qt for Python QThread官方文档 2 长时间任务可以放到QThread中执行,避免占用主线程导致界面卡顿无法操作 代码: import traceback,sys fro…

Spring Boot 应用运行指南

🚀 Spring Boot 应用运行指南 ⚙️ 使用 Maven 🔧 运行命令 $ mvn spring-boot:run✨ 启动效果 . ____ _ __ _ _/\\ / ____ __ _ _(_)_ __ __ _ \ \ \ \ ( ( )\___ | _ | _| | _ \/ _ | \ \ \ \\\/ ___)| |_)| | | | | || (_…

jeecgboot 3.8.0 集成knife4j问题一文解决

问题描述: ​ 在cloud环境下,若应用系统配置了context-path,则无法通过网关进入后台接口管理系统 原因分析: ​ 查看请求信息发现少拼接了系统的context-path,导致无法正确请求到数据。直接使用正确的地址可以正常通过网关访问。故此确定为集成knife4j的问题。 解决办法…

【Flutter】Flutter + Unity 插件结构与通信接口封装

关联文档:【方案分享】Flutter Unity 跨平台三维渲染架构设计全解:插件封装、通信机制与热更新机制—— 支持 Android/iOS/Web 的 3D 内容嵌入与远程资源管理,助力 XR 项目落地 —— 支持 Android/iOS/Web 的 3D 内容嵌入与远程资源管理&…

推荐 1 款 9.3k stars 的全景式开源数据分析与可视化工具

Orama 是一个开源的数据分析与可视化项目,由askorama团队开发和维护。该项目旨在为用户提供一套强大而易用的工具集,帮助用户轻松处理和理解大规模数据,通过创建交互式且引人入胜的数据可视化图表,揭示隐藏在数据背后的深层次洞察…

关于windows API 的键鼠可控可测

相关函数解释 GetAsyncKeyState 是 Windows API 中的一个函数,用于判断某个虚拟键是否被按下。GetAsyncKeyState(VK_ESCAPE) 专门用于检测 Esc 键的状态。下面为你详细介绍其用法: 函数原型 cpp SHORT GetAsyncKeyState( int vKey ); 参数 vKey&a…