实验13 使用预训练resnet18实现CIFAR-10分类

1.数据预处理

首先利用函数transforms.Compose定义了一个预处理函数transform,里面定义了两种操作,一个是将图像转换为Tensor,一个是对图像进行标准化。然后利用函数torchvision.datasets.CIFAR10下载数据集,这个函数有四个常见的初始化参数:root为数据存储的路径,如果数据已经下载,会直接从这个路径加载数据。train如果为True,表示加载训练集,train如果为False,加载测试集。download如果设置为True,表示如果本地不存在数据集,会自动从互联网上下载。transform指定一个转换函数,对数据进行预处理和数据增强等操作。所以下载训练集train_full时,train赋值为True,下载测试集时,train赋值为False。之后对下载的训练集train_full进行划分,先规定指定的大小,然后利用random_split进行划分,最后就是创建Dataloader,batch_size设为64,得到train_loader,val_loader,test_loader。

代码:

# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 数据预处理和增强
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为Tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 图像标准化
])# 下载 CIFAR-10 数据集
train_full = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)# 划分训练集(40,000)和验证集(10,000)
train_size = int(0.8 * len(train_full))  # 80% 用于训练
val_size = len(train_full) - train_size  # 剩余 20% 用于验证
train_data, val_data = random_split(train_full, [train_size, val_size])# 创建 DataLoader
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
test_loader = DataLoader(test, batch_size=64, shuffle=False)

2.模型构建

模型构建就比较简单,直接使用使用pytorch定义的库函数,只有一行代码:

model = models.resnet18(pretrained=False),pretrained=False表示不使用在Imagenet上预训练的权重,pretrained=True表示使用在Imagenet上预训练的权重。因为这个模型是训练Imagenet构建的模型,要想让这个模型适应新任务,需要获取最后一层的输入特征数,然后利用一个全连接层将输出改为10。

代码:

# 初始化 ResNet-18 模型
model = models.resnet18(pretrained=True)
# 修改最后一层(全连接层),适应新的任务
num_ftrs = model.fc.in_features  # 获取最后一层的输入特征数
model.fc = torch.nn.Linear(num_ftrs, 10)  # 将输出改为 10 个类别(例如 CIFAR-10)

3.模型训练

创建Runner类,管理训练、评估、测试和预测过程。还是之前的一套东西,首先是一个init函数,用于初始化数据集、损失函数、优化器等。train函数用于计算在训练集上的loss,并反向传播更新参数。evaluate函数用于计算在验证集上的损失,不用反向传播更新模型的参数,同时根据evaluate函数得到的损失判断是否保存最优模型,利用state_dict函数保存最优模型。test函数首先加载最优模型,然后在测试集计算最优模型的准确率。predict函数预测某个图像属于某个类别的概率,虽然resnet最后一层没有softmax,但是也可以根据最后一层得到的10个logits(未经过归一化的原始输出)取最大来判断图像属于某一类(因为这10个值也是有大小关系的,softmax函数不会修改这10个值的大小关系)。

定义学习率=0.01、批次大小=30、损失函数为交叉熵损失nn.CrossEntropyLoss()、优化器为Adam。

实例化Runner,调用train函数,开始训练。

代码:

class Runner:def __init__(self, model, train_loader, val_loader, test_loader, criterion, optimizer, device):self.model = model.to(device)  # 将模型移到GPUself.train_loader = train_loaderself.val_loader = val_loaderself.test_loader = test_loaderself.criterion = criterionself.optimizer = optimizerself.device = deviceself.best_model = Noneself.best_val_loss = float('inf')self.train_losses = []  # 存储训练损失self.val_losses = []  # 存储验证损失def train(self, epochs=10):for epoch in range(epochs):self.model.train()running_loss = 0.0for inputs, labels in self.train_loader:# 将数据移到GPUinputs, labels = inputs.to(self.device), labels.to(self.device)self.optimizer.zero_grad()outputs = self.model(inputs)loss = self.criterion(outputs, labels)loss.backward()self.optimizer.step()running_loss += loss.item()# 计算平均训练损失train_loss = running_loss / len(self.train_loader)self.train_losses.append(train_loss)# 计算验证集上的损失val_loss = self.evaluate()self.val_losses.append(val_loss)print(f'Epoch [{epoch + 1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')# 如果验证集上的损失最小,保存模型if val_loss < self.best_val_loss:self.best_val_loss = val_lossself.best_model = self.model.state_dict()def evaluate(self):self.model.eval()val_loss = 0.0with torch.no_grad():for inputs, labels in self.val_loader:# 将数据移到GPUinputs, labels = inputs.to(self.device), labels.to(self.device)outputs = self.model(inputs)loss = self.criterion(outputs, labels)val_loss += loss.item()return val_loss / len(self.val_loader)def test(self):self.model.load_state_dict(self.best_model)self.model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in self.test_loader:# 将数据移到GPUinputs, labels = inputs.to(self.device), labels.to(self.device)outputs = self.model(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()test_accuracy = correct / totalprint(f'Test Accuracy: {test_accuracy:.4f}')def predict(self, image):self.model.eval()image = image.to(self.device)  # 将图像移到GPUwith torch.no_grad():output = self.model(image)_, predicted = torch.max(output, 1)return predicted.item()def visualize_and_predict(self, index):"""针对训练集中的某一张图片进行预测,并可视化图片。:param index: 训练集中的图片索引"""# 获取训练集中的第 index 张图片image, label = self.train_loader.dataset[index]# 将图像移到GPU(如果需要)image = image.unsqueeze(0).to(self.device)  # 增加一个维度作为batch size# 可视化图像plt.imshow(image.cpu().squeeze().numpy(), cmap='gray')  # 假设是灰度图,若是彩色图像要调整plt.title(f"True Label: {label}")plt.show()# 预测该图片的类别predicted_label = self.predict(image)print(f"Predicted Label: {predicted_label}")
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 实例化Runner类
runner = Runner(model, train_loader, val_loader, test_loader, criterion, optimizer, device)# 训练模型
runner.train(epochs=30)
# 绘制损失曲线
plt.figure(figsize=(10, 6))
plt.plot(runner.train_losses, label='Train Loss')
plt.plot(runner.val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.legend()
plt.grid()
plt.show()

4.模型评价

调用test函数,计算在测试集上的准确率。

代码:

# 在最优模型上评估测试集准确率
runner.test()

5.模型预测

在训练集任意选取一个图像,获取图像的image和标签label,因为图像已经经过了transform的变换,所以这个图像不需要transform,只需要添加一个维度1作为batch_size,可视化图像和真实标签,然后调用predict函数进行预测,输出真实类别。

代码:

# CIFAR-10 是 RGB 图像,确保正确显示
# 将 Tensor 转换为 numpy 数组并调整维度顺序为 HWC (Height, Width, Channels)
image_np = image.numpy().transpose((1, 2, 0))  # 从 CHW 转为 HWC# 可视化图像
plt.imshow(image_np)
plt.title(f"True Label: {label}")
plt.show()# 直接将图像传递给预测函数,不再需要 transform
# 但是要确保图像传入时是正确的 batch size 形状,即增加一个 batch 维度
image_transformed = image.unsqueeze(0).to(device)  # 增加一个维度作为 batch size# 预测该图片的类别
predicted_label = runner.predict(image_transformed)
print(f"Predicted Label: {predicted_label}")

6.实验结果与分析

不使用预训练权重的损失变化、准确率和预测结果

使用预训练权重的损失变化、准确率和预测结果

通过观察损失变化,我们发现两个模型在训练集上的loss一直在减小,说明模型的参数一直在更新。但是在验证集上的损失一开始是下降的,但是后来不断增大,我觉得是因为模型过拟合了。但是可以发现在没有预训练权重上的最优验证损失是比有预训练权重的模型上的最优验证损失大的。通过保存最优模型,在最优模型上计算准确率,发现在没有预训练权重的模型得到的准确率是0.7332,在使用预训练权重的模型得到的准确率是0.7431。

结论:通过对比在验证集上的最优验证损失和在测试集上的准确率,得到结论使用了预训练的模型效果要更好。

7.总结与心得体会

总结:

1.预训练模型:

预训练模型是指在一个大规模数据集上(如 ImageNet、COCO 等)经过训练的模型。这个模型已经学习到了一些通用的特征,比如图像中的边缘、纹理、颜色、形状等,或者文本中的语法、词汇关系等。这些特征是从数据中自动学习的,并且在很多不同的任务中都有用。

例子:

在图像分类任务中,ResNet、VGG、Inception 等深度神经网络在 ImageNet 上经过训练后,它们可以识别成千上万种不同的物体。由于这些物体特征具有广泛的普适性,我们可以将这些模型用于其他图像分类任务(例如 Cifar-10、Cifar-100),而无需从头开始训练。

在自然语言处理(NLP)中,像 BERT、GPT 等预训练语言模型已经在大量的文本数据上训练过,学习了丰富的语言知识。因此,我们可以将这些模型应用于文本分类、情感分析、问答等任务。

预训练模型的优势:

节省计算资源:训练深度神经网络需要大量的计算资源和时间,尤其是在大规模数据集上。通过使用预训练模型,用户可以避免从零开始训练,直接利用现成的知识。

提高效果:预训练模型已经学习到了一些通用的特征,可以加速学习过程,并且通常能够取得比从头开始训练更好的效果。
2. 迁移学习(Transfer Learning)

迁移学习是一种利用在一个任务上学到的知识,来帮助在另一个相关任务上进行学习的技术。换句话说,它将一个任务中的学习成果迁移到另一个任务中,特别是在目标任务的数据较少时。

迁移学习的核心思想是:如果一个模型在某个任务上已经学到了一些有用的特征,那么这些特征可以迁移到另一个任务上,帮助模型更好地学习。

迁移学习的典型流程:

模型加载:加载一个在大数据集上预训练的模型(如 ResNet、VGG、BERT 等)。

模型微调:对模型的部分层进行微调,或者只训练新添加的层(如分类层)。

应用于新任务:将经过微调的模型应用于新的、可能较小的数据集。

迁移学习的类型

迁移学习有多种不同的方式,常见的有以下几种:

微调(Fine-Tuning):使用预训练模型的权重,并对某些层或整个模型进行微调,以适应新的任务和数据。

通常会冻结前几层(因为它们学习的是通用特征),只训练后几层(专门针对当前任务)。

特征提取(Feature Extraction):使用预训练模型的特征提取能力,将前几层的权重固定,不更新,仅训练新加的全连接层或输出层。

零-shot 学习:在一些任务中,预训练模型被直接应用到目标任务,而不进行微调,特别是当目标任务的标注数据非常少时。

迁移学习的应用:

计算机视觉:在一个大规模的数据集(如 ImageNet)上训练的模型可以用于许多不同的图像分类任务,例如识别猫、狗、车、飞机等物体,或者在医疗影像、无人驾驶等领域中应用。

自然语言处理(NLP):例如,BERT 和 GPT 等模型可以在情感分析、命名实体识别、机器翻译等任务上进行迁移学习。

3. 预训练模型和迁移学习的关系

预训练模型和迁移学习是紧密相关的。迁移学习通常依赖于预训练模型,使用在一个任务中学到的知识来帮助另一个任务。在迁移学习中,预训练模型提供了一个良好的起点,减少了从头开始训练的难度和所需的数据量。

预训练模型与迁移学习的关系:

预训练模型是迁移学习的基础,因为迁移学习的一个关键步骤是使用已经在其他任务上训练好的模型。

迁移学习则是使用这些预训练模型的技术,它通过微调或特征提取等方式,将预训练模型的知识应用到新任务中。

使用torchvision.datasets的常见参数:

root:数据存储的路径。如果数据已经下载,它会直接从该路径加载数据。

train:如果设置为 True,加载训练集;如果设置为 False,加载测试集。

download:如果设置为 True,如果本地不存在数据集,它会自动从互联网上下载。

transform:指定一个转换函数,对数据进行预处理和数据增强等操作。

transforms.Compose 是 torchvision.transforms 模块中的一个函数,用于将多个图像预处理操作组合成一个复合操作。在神经网络训练中,常常需要对输入图像进行多种预处理,例如将图像转换为张量(Tensor)、标准化、数据增强等。transforms.Compose 允许你将这些操作按顺序组合在一起,并一次性应用于输入图像。

心得体会:

这个实验直接调用预训练的resnet18进行CIFAR-10数据集的分类,因为这个模型是在Imagenet数据集上训练得到的,所以适用于新的任务需要微调模型。通过对比没有预训练权重的模型和有预训练权重的模型的训练效果,发现还是有预训练权重得到的结果比较好,因为预训练模型已经学习到了一些通用的特征,可以加速学习过程,通常能够取得比从头开始训练更好的效果。在实际应用中在理解模型内部实现的基础上,直接调用高层API是一个不错的选择,可以减少代码量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/63261.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

区块链概念 Web 3.0 实操

1. Web 3.0 概述 1.1 定义与背景 Web 3.0&#xff0c;也称为第三代互联网&#xff0c;是一个新兴的概念&#xff0c;它代表着互联网的未来发展和演进方向。Web 3.0的核心理念是去中心化、用户主权和智能化。这一概念的提出&#xff0c;旨在解决Web 2.0时代中用户数据隐私泄露…

linux下使用gdb运行程序,查看程序崩溃原因

1.什么是gdb? gdb 是 GNU Debugger 的缩写&#xff0c;是一个功能强大的用于调试程序的开源调试器工具。它可以帮助开发人员诊断和解决程序中的错误、跟踪程序执行过程、查看变量的值等。gdb 支持多种编程语言&#xff0c;包括 C、C、Objective-C、Fortran 等&#xff0c;并可…

鸿蒙arkts怎么打印一个方法的调用堆栈

做鸿蒙开发的时候&#xff0c;也想看一下一个方法到底是哪里调用的&#xff0c;工程太大&#xff0c;断点太麻烦&#xff0c;可以加堆栈日志。 在你的方法中加上这两句&#xff0c;就可以跟到堆栈日志 let err new Error() console.log(>>>>>>err.stack) …

Elasticsearch scroll 之滚动查询

Elasticsearch scroll 之滚动查询 Elasticsearch 的 Scroll API 是一种用于处理大规模数据集的机制&#xff0c;特别是在需要从索引中检索大量数据时。通常情况下&#xff0c;Elasticsearch 的搜索请求会有一个结果集大小的限制 (fromsize 的检索数量默认是 10,000 条记录)&am…

【漏洞复现】网动统一通信平台(ActiveUC)接口iactiveEnterMeeting存在信息泄露漏洞

🏘️个人主页: 点燃银河尽头的篝火(●’◡’●) 如果文章有帮到你的话记得点赞👍+收藏💗支持一下哦 @TOC 一、漏洞概述 1.1漏洞简介 漏洞名称:网动统一通信平台(ActiveUC)接口iactiveEnterMeeting存在信息泄露漏洞漏洞编号:无漏洞类型:信息泄露漏洞威胁等级:高危影…

掌握小程序地理位置服务插件,让用户体验再升级

在小程序开发中&#xff0c;地理位置服务插件扮演着至关重要的角色&#xff0c;它们不仅能够帮助开发者轻松获取用户的地理位置信息&#xff0c;还能够基于位置数据提供丰富的功能&#xff0c;如地图展示、周边搜索、路径规划等。 一、插件的基本概念与引入 插件定义&#xf…

IDE如何安装插件实现Go to Definition

项目背景 框架&#xff1a;Cucumber Cypress 语言&#xff1a;Javascript IDE&#xff1a;vscode 需求 项目根目录cypress-automation的cypress/integration是测试用例的存放路径&#xff0c;按照不同模块不同功能创建了很多子目录&#xff0c;cucumber测试用例.feature文…

如何通过 Windows 自带的启动管理功能优化电脑启动程序

在日常使用电脑的过程中&#xff0c;您可能注意到开机后某些程序会自动运行。这些程序被称为“自启动”或“启动项”&#xff0c;它们可以在系统启动时自动加载并开始运行&#xff0c;有时甚至在后台默默工作。虽然一些启动项可能是必要的&#xff08;如杀毒软件&#xff09;&a…

探索自然语言处理奥秘(NLP)

摘要 自然语言处理&#xff08;NLP&#xff09;是人工智能领域的一个重要分支&#xff0c;它致力于使计算机能够理解、解释和生成人类语言。这项技术让机器能够阅读文本、听懂语音&#xff0c;并与人类进行基本的对话交流。 通俗理解 自然语言处理&#xff08;NLP&#xff09…

html ul li 首页渲染多条数据 但只展示八条,其余的数据全部隐藏,通过icon图标 进行展示

<div style"float: left;" id"showMore"> 展开 </div> <div style"float: left;“id"hideLess"> 收起 </div> var data document.querySelectorAll(.allbox .item h3 a); const list document.querySelectorAl…

# issue 8 TCP内部原理和UDP编程

TCP 通信三大步骤&#xff1a; 1 三次握手建立连接; 2 开始通信&#xff0c;进行数据交换; 3 四次挥手断开连接&#xff1b; 一、TCP内部原理--三次握手 【第一次握手】套接字A∶"你好&#xff0c;套接字B。我这儿有数据要传给你&#xff0c;建立连接吧。" 【第二次…

力扣--543.二叉树的直径

题目 给你一棵二叉树的根节点&#xff0c;返回该树的 直径 。 二叉树的 直径 是指树中任意两个节点之间最长路径的 长度 。这条路径可能经过也可能不经过根节点 root 。 两节点之间路径的 长度 由它们之间边数表示。 代码 /** Definition for a binary tree node.public…

报错 JSON.parse: expected property name or ‘}‘,JSON数据中对象的key值不为字符串

报错 JSON.parse: expected property name or ‘}’ 原因 多是因为数据转换时出错&#xff0c;可能是存在单引号或者对象key值不为string导致 这里记录下我遇见的问题&#xff08;后端给的JSON数据里&#xff0c;对象key值不为string&#xff09; 现在后端转换JSON数据大多…

在ensp进行IS-IS网络架构配置

一、实验目的 1. 理解IS-IS协议的工作原理 2. 熟练ensp路由连接配置 二、实验要求 需求&#xff1a; 路由器可以互相ping通 实验设备&#xff1a; 路由器router6台 使用ensp搭建实验坏境&#xff0c;结构如图所示 三、实验内容 R1 u t m sys undo info en sys R1 #设…

挑战用React封装100个组件【010】

Hello&#xff0c;大家好&#xff0c;今天我挑战的组件是这样的&#xff01; 今天这个组件是一个打卡成功&#xff0c;或者获得徽章后的组件。点击按钮后&#xff0c;会弹出礼花。项目中的勋章是我通过AI生成的&#xff0c;还是很厉害的哈&#xff01;稍微抠图直接使用。最后面…

Mybatis-Plus的主要API

一、实体类操作相关API BaseMapper<T>接口 功能&#xff1a;这是 MyBatis - Plus 为每个实体类对应的 Mapper 接口提供的基础接口。它提供了一系列基本的 CRUD&#xff08;增删改查&#xff09;操作方法。例如insert(T entity)方法用于插入一条记录&#xff0c;d…

C++类与对象(二)

一、默认成员函数 class A{}; 像上面一样&#xff0c;一个什么都没有的类叫做空类&#xff0c;但是这个什么都没有并不是真正的什么都没有&#xff0c;只是我们看不见&#xff0c;空类里面其实是有6个默认成员函数的&#xff0c;当我们在类里面什么都不写的时候&#xff0c;编译…

数据结构与算法-03链表-03

递归与迭代 由一个问题引出 假设我们要计算 一个正整数的阶乘, N! 。 从数学上看 1&#xff01; 1 2&#xff01; 2 x 1 3! 3 x 2 x 1 4! 4 x 3 x 2 x 1 5! 5 x 4 x 3 x 2 x 1 : n! n x (n-1) x (n-2) x (n-3) x ... 1我们推出一般公式 f(1) 1 f(n) n * f(n-1…

spring6:2入门

spring6&#xff1a;2入门 目录 spring6&#xff1a;2入门2.1、环境要求2.2、构建模块2.3、程序开发2.3.1、引入依赖2.3.2、创建java类2.3.3、创建配置文件2.3.4、创建测试类测试2.3.5、运行测试程序 2.4、程序分析2.5、启用Log4j2日志框架2.5.1、Log4j2日志概述2.5.2、引入Log…

汽车IVI中控开发入门及进阶(三十五):架构QML App Architecture Best Practices

在Qt/QML工程的架构中,架构很重要,虽然本身它有分层,比如QML调用资源文件(图片等)显示GUI界面,后面的CPP文件实现界面逻辑,但是这个分类还有点粗。在实际开发中,界面逻辑也就是基于类cpp的实现,也开始使用各种面向对象的设计模式,实现更加优秀的开发架构,这点尤其在…