《迁移学习》—— 将 ResNet18 模型迁移到食物分类项目中

文章目录

  • 一、迁移学习的简单介绍
    • 1.迁移学习是什么?
    • 2.迁移学习的步骤
  • 二、数据集介绍
  • 三、代码实现
    • 1. 步骤
    • 2.所用到方法介绍的文章链接
    • 3. 完整代码

一、迁移学习的简单介绍

1.迁移学习是什么?

  • 迁移学习是指利用已经训练好的模型,在新的任务上进行微调。
  • 迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作。

2.迁移学习的步骤

  • (1) 选择预训练的模型和适当的层:通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。
  • (2) 冻结预训练模型的参数:保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。
  • (3) 在新数据集上训练新增加的层:在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。
  • (4) 微调预训练模型的层:在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。
  • (5) 评估和测试:在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

二、数据集介绍

  • 下图是数据集的结构
    • 在 food_dataset2 文件夹下含有训练数据和测试数据
    • 训练集和测试集数据中都含有 20 种食物图片,数量在200~400不等
    • trainda.txt 和 testda.txt 文本中存放了每张图片的路径及标签,用 0~19 这20个数字分别对20种食物进行标签
    • 在代码中通过trainda.txt 和 testda.txt 文本中的内容来获取每张图片及对应的标签
      在这里插入图片描述
    • 下面是trainda.txt文本中的部分内容(testda.txt 中的内容格式相同)
      在这里插入图片描述
  • 送福利!!! 私信送此数据集 !!!

三、代码实现

1. 步骤

  • 1.调用resnet18模型,并保存需要训练的模型参数
  • 2.定义一个图像预处理和数据增强字典
  • 3.定义获取每张食物图片和标签的类方法
  • 4.获取训练集和测试集数据
  • 5.对数据集进行打包
  • 6.调用交叉熵损失函数并创建优化器
  • 7.定义训练模型的函数
  • 8.定义测试模型的函数
  • 9.训练模型,并每训练一轮测试一次

2.所用到方法介绍的文章链接

  • ResNet 残差网络神经网络
    • https://blog.csdn.net/weixin_73504499/article/details/142575775?spm=1001.2014.3001.5501
  • 数据增强
    • https://blog.csdn.net/weixin_73504499/article/details/142499263?spm=1001.2014.3001.5501
  • 调整学习率
    • https://blog.csdn.net/weixin_73504499/article/details/142526863?spm=1001.2014.3001.5501

3. 完整代码

import torchimport torchvision.models as models  # 导入存有各种深度学习模型的模块from torch import nn  # 导入神经网络模块from torch.utils.data import Dataset, DataLoader  # Dataset: 抽象类,一种用于获取数据的方法  DataLoader:数据包管理工具,打包数据from torchvision import transforms  # transforms模块提供了一系列用于图像预处理和数据增强的函数和类from PIL import Image  # 用于处理图片import numpy as np""" 调用resnet18模型 """resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)for param in resnet_model.parameters():param.requires_grad = False# 模型所有参数(即权重和偏差)的 requires_grad 属性设置成 False,从而冻结所有模型参数# 使得在反向传播过程中不会计算他们的梯度,从此减少模型的计算量,提高推理速度in_features = resnet_model.fc.in_features  # 获取resnet18模型全连接层原输入的特征个数resnet_model.fc = nn.Linear(in_features, 20)  # 创建一个全连接层输入特征个数为: in_features  输出特征个数为:数据集中事务的种类数量params_to_update = []  # 保存需要训练的参数,仅仅包含全连接层的参数for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)""" 图像预处理和数据增强 """data_transforms = {'train':transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机transforms.CenterCrop(224),  # 中心裁剪transforms.RandomHorizontalFlip(p=0.5),  # 随机水平反转 选择一个概率transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转# transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # 亮度、对比度transforms.RandomGrayscale(p=0.1),  # 概率转换成灰度率,3通道就是R G Btransforms.ToTensor(),  # 转化为神经网络可以识别的 Tensor 类型transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 对图片数据进行归一化,[均值],[标准差]]),'valid':transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}""" 定义获取每张食物图片和标签的类方法 """class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label""" 获取训练集和测试集数据 """training_data = food_dataset(file_path='trainda.txt', transform=data_transforms['train'])test_data = food_dataset(file_path='testda.txt', transform=data_transforms['valid'])""" 对数据集进行打包 """train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)  # 64张图片为一个包,shuffle --> 打乱顺序test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)""" 判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU """device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"print(f"Using {device} device")# 把模型传入到 gpu 或 cpumodel = resnet_model.to(device)""" 调用交叉熵损失函数 """loss_fn = nn.CrossEntropyLoss()"""" 创建优化器并调整优化器中的学习率--> lr """optimizer = torch.optim.Adam(params_to_update, lr=0.001)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)""" 定义训练模型的函数 """def train(dataloader, model, loss_fn, optimizer):model.train()  # 告诉模型,开始训练# pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。# 一般用法是:在训练开始之前写上model.trian(),在测试时写上model.for X, y in dataloader:X, y = X.to(device), y.to(device)  # 把训练数据集和标签传入cpu或gpupred = model.forward(X)  # .forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值 lossoptimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向传播计算得到每个参数的梯度值woptimizer.step()  # 根据梯度更新网络w参数""" 定义测试模型的函数 """best_acc = 0  # 用于更新准确率def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 测试,w就不能再更新test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()  # test_loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()  # correct是会自动累加每一个批次的正确率test_loss /= num_batches  # 平均的损失值correct /= size  # 平均的正确率print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")# 找到最好的准确率if correct > best_acc:best_acc = correct""" 定义模型训练的轮数,并每训练一轮测试一次 """epochs = 30for e in range(epochs):print(f"Epoch {e + 1}\n---------------------------")train(train_dataloader, model, loss_fn, optimizer)scheduler.step()  # 在每个epoch的训练中,使用scheduler.step()语句进行学习率更新test(test_dataloader, model, loss_fn)print('最优的训练结果为:', best_acc)
  • 结果如下
    • 此结果只是训练了30轮后的结果,可以训练更多轮,最后的准确率还会有所提高
      在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/54354.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙开发(NEXT/API 12)【状态查询与订阅】手机侧应用开发

注意 该接口的调用需要在开发者联盟申请设备基础信息权限与穿戴用户状态权限,穿戴用户状态权限还需获得用户授权。 实时查询穿戴设备可用空间、电量状态。订阅穿戴设备连接状态、低电量告警、用户心率告警。查询和订阅穿戴设备充电状态、佩戴状态、设备模式。 使…

初识Django

前言: 各位观众老爷们好,最近几个月都没怎么更新,主要是最近的事情太多了,我也在继续学习Django框架,之前还参加了一些比赛,现在我会开始持续更新Django的学习,这个过程会比较久,我会把我学习的…

MySQL--三大范式(超详解)

目录 一、前言二、三大范式2.1概念2.2第一范式(1NF)2.3第二范式(2NF)2.3第三范式(3NF) 一、前言 欢迎大家来到权权的博客~欢迎大家对我的博客进行指导,有什么不对的地方,我会及时改进…

嘴尚绝卤味:健康美味的双重奏

在当今快节奏的生活中,人们对美食的追求不再仅仅停留于味蕾的满足,更加注重食物的健康与营养。在这一背景下,"嘴尚绝卤味"以其独特的健康理念与精湛的制作工艺,成为了市场上备受瞩目的卤味品牌。本文将从"嘴尚绝卤…

Django学习笔记九:RESTAPI添加用户认证和授权

在Django REST Framework中添加用户认证和授权,通常涉及以下几个步骤: 1. 认证(Authentication) 认证是指确定用户身份的过程。Django REST Framework提供了多种认证方式: Token Authentication:通过一个…

Kotlin基本知识

Kotlin是一种现代的静态类型编程语言,由JetBrains公司在2010年推出,并被Google在2019年宣布为Android开发的首选语言。 超过 50% 的专业 Android 开发者使用 Kotlin 作为主要语言,而只有 30% 使用 Java 作为主要语言。 70% 以 Kotlin 为主要语…

APK安装包arm64-v8a、armeabi-v7a、x86、x86_64如何区别?(2024年10月1日)

其实就是安卓CPU的进步史 安卓CPU类型: arm64-v8a: 第8代、64位ARM处理器,目前手机大多数是此架构(新手机,可以无脑选择)armeabiv-v7a: 第七代及以上的 ARM 处理器。2011年5月以后生产的大部分安卓设备都使用它armeabi: 第5代、第6代的ARM处理器&#…

文章解读与仿真程序复现思路——电力自动化设备EI\CSCD\北大核心《考虑光伏不确定性的配电网谐波监测优化配置方法 》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

华为云技术深度解析:以系统性创新加速智能化升级

华为云技术深度解析:以系统性创新加速智能化升级 在当今数字化转型的浪潮中,云计算作为关键的基础设施,正以前所未有的速度推动着各行各业的智能化升级。作为全球领先的云服务提供商,华为云凭借其深厚的技术积累和创新实力&#…

Azure DevOps Server:不能指派新增的用户

Contents 1. 概述2. 解决方案 1. 概述 近期和微软Azure DevOps项目组解决了一个“无法指派开发人员”的问题,在此分享给大家。问题描述: 在一个数据量比较大的Azure DevOps Server的部署环境中,用户发现将新用户的AD域账户添加到Azure DevOps…

《15分钟轻松学 Python》教程目录

为什么要写这个教程呢,主要是因为即使是AI技术突起的时代,想要用好AI做开发,那肯定离不开Python,就算最轻量级的智能体都有代码块要写,所以不一定要掌握完完整整的Python,只要掌握基础就能应对大部分场景。…

数据看板如何提升决策效率?

数据看板作为一种直观、高效的数据可视化工具,在这一过程中发挥着至关重要的作用。以一家中型制造企业为例,每天面临着生产计划的安排、原材料的采购、产品质量的把控以及市场销售的策略制定等诸多业务场景。在生产线上,需要确保设备的高效运…

《计算机原理与系统结构》学习系列

系列文章目录 一、计算机概要与技术 二、指令:计算机的语言(上) 三、指令:计算机的语言(中) 四、指令:计算机的语言(下) 五、计算机的算数运算(上&#…

【隐私计算篇】多方安全计算之函数秘密共享(FSS)

1. 函数秘密共享(FSS)定义 秘密共享是一种将一个值拆分为多个份额的方法,形式有多种,可以参考《安全多方计算(MPC)矩阵乘法算子的原理分析》。这里主要提及加法秘密共享,使得:这些份额可以重新组合以还原出秘密值;任…

【HTML并不简单】笔记1-常用rel总结:nofollow、noopener、opener、noreferrer,relList

文章目录 rel"nofollow"rel"noopener"与rel"opener"rel"noreferrer"relList对象 《HTML并不简单:Web前端开发精进秘籍》张鑫旭,一些摘要: HTML,这门语言的知识体系非常庞杂,涉…

Python数据结构与算法问题详解

Python数据结构与算法问题详解 Python 作为一种高级编程语言,凭借其简洁的语法和强大的内置库,成为了数据结构与算法学习的绝佳工具。本文将深入解析几种常见的数据结构,并结合具体的算法,展示如何在实际问题中高效解决问题。通过…

《PMI-PBA认证与商业分析实战精析》第7章 解决方案评价

第7章 解决方案评价 本章主要内容: 评价的建议思维 解决方案的评价计划 确定评价什么 何时以及如何验证解决方案的结果 评价验收标准和解决缺陷 促进通过/不通过的决策 获得解决方案的签字确认 评价解决方案的长期绩效 解决方案替换/淘汰 本章涵盖的考试…

36 指针与 const 的多种搭配:指向常量的指针、常量指针、指向常量的常量指针、指针到指针的常量(涉及双重指针)

目录 1 指向常量的指针 1.1 概念 1.2 语法格式 1.3 声明与初始化 1.4 实际应用 1.4.1 保护数据 1.4.2 函数参数 1.4.3 字符串常量 2 常量指针 2.1 概念 2.2 语法格式 2.3 声明与初始化 2.4 实际应用 2.4.1 保护指针指向的地址 2.4.2 数组处理 2.4.3 函数参数 …

ASP.NET Core 创建使用异步队列

示例图 在 ASP.NET Core 应用程序中,执行耗时任务而不阻塞线程的一种有效方法是使用异步队列。在本文中,我们将探讨如何使用 .NET Core 和 C# 创建队列结构以及如何使用此队列异步执行操作。 步骤 1:创建 EmailMessage 类 首先&#xff0c…

1、Spring Boot 3.x 集成 Eureka Server/Client

一、前言 基于 Spring Boot 3.x 版本开发,因为 Spring Boot 3.x 暂时没有正式发布,所以很少有 Spring Boot 3.x 开发的项目,自己也很想了踩踩坑,看看 Spring Boot 3.x 与 2.x 有什么区别。自己与记录一下在 Spring Boot 3.x 过程…