【NLP】一个使用PyTorch实现图像分类的迁移学习实例

一个使用PyTorch实现图像分类的迁移学习实例

  • 1. 导入模块
  • 2. 加载数据
  • 3. 模型处理
  • 4. 训练及验证模型
  • 5. 微调
  • 6. 其他代码

在特征提取中,可以在预先训练好的网络结构后修改或添加一个简单的分类器,然后将源任务上预先训练好的网络作为另一个目标任务的特征提取器,只对最后增加的分类器参数重新学习,而预先训练好的网络参数不被修改或冻结。

在完成新任务的特征提取时使用的是源任务中学习到的参数,而不用重新学习所有参数。下面的示例用一个实例具体说明如何通过特征提取的方法进行图像分类。

1. 导入模块

from datetime import datetimeimport matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torchvision import models

2. 加载数据

这里需要事先将CIFAR10数据下载到本地,因为比较耗时,因此,将download=False。除此之外,还增加了一些预处理功能,比如数据标准化、对图片进行裁剪等。

def load_data(data, batch_size=64, num_workers=2, mean=None, std=None):if std is None:std = [0.229, 0.224, 0.225]if mean is None:mean = [0.485, 0.456, 0.406]trans_train = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(),transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])trans_valid = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),transforms.Normalize(mean=mean, std=std)])train_set = torchvision.datasets.CIFAR10(root=data, train=True, download=True, transform=trans_train)trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_set = torchvision.datasets.CIFAR10(root=data, train=False, download=True, transform=trans_valid)testloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)return trainloader, testloader

3. 模型处理

这个部分包含三个操作:

  • 下载预训练模型:使用的预训练模型为resnet18,且已经在ImageNet大数据集上训练好了
  • 冻结模型参数:使其在反向传播时,不会更新
  • 修改最后一层的输出类别数:该数据集中有1000个类别,即原始输出为512×1000,现将其修改为512×10,因为这里使用的新数据集有10个类别
def freeze_net(num_class=10):# 下载预训练模型net = models.resnet18(pretrained=True)# 冻结模型参数for params in net.parameters():params.requires_grad = False# 修改最后一层的输出类别数net.fc = nn.Linear(512, num_class)# 查看冻结前后的参数情况total_params = sum(p.numel() for p in net.parameters())print(f'原总参数个数:{total_params}')total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)print(f'需训练参数个数:{total_trainable_params}')return net

原总参数个数:11181642
需训练参数个数:5130

从输出上可知,如果不冻结,需要更新的参数太多了,冻结之后只需要更新全连接层的参数即可。

4. 训练及验证模型

这里选用交叉熵作为损失函数,使用SGD作为优化器,学习率为1e-3,权重衰减设为1e-3,代码如下:

# 训练及验证模型
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):prev_time = datetime.now()for epoch in range(num_epochs):train_loss = 0train_acc = 0net = net.train()for im, label in train_data:im = im.to(device)  # (bs, 3, h, w)label = label.to(device)  # (bs, h, w)# forwardoutput = net(im)loss = criterion(output, label)# backwardoptimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()train_acc += get_acc(output, label)cur_time = datetime.now()h, remainder = divmod((cur_time - prev_time).seconds, 3600)m, s = divmod(remainder, 60)time_str = "Time %02d:%02d:%02d" % (h, m, s)if valid_data is not None:valid_loss = 0valid_acc = 0net = net.eval()for im, label in valid_data:im = im.to(device)  # (bs, 3, h, w)label = label.to(device)  # (bs, h, w)output = net(im)loss = criterion(output, label)valid_loss += loss.item()valid_acc += get_acc(output, label)epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "% (epoch, train_loss / len(train_data),train_acc / len(train_data), valid_loss / len(valid_data),valid_acc / len(valid_data)))else:epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %(epoch, train_loss / len(train_data),train_acc / len(train_data)))prev_time = cur_timeprint(epoch_str + time_str)

运行结果:
Epoch 0. Train Loss: 1.474121, Train Acc: 0.498322, Valid Loss: 0.901339, Valid Acc: 0.713177, Time 00:03:26
Epoch 1. Train Loss: 1.222752, Train Acc: 0.576946, Valid Loss: 0.818926, Valid Acc: 0.730494, Time 00:04:35
Epoch 2. Train Loss: 1.172832, Train Acc: 0.592651, Valid Loss: 0.777265, Valid Acc: 0.737759, Time 00:04:23
Epoch 3. Train Loss: 1.158157, Train Acc: 0.596228, Valid Loss: 0.761969, Valid Acc: 0.746517, Time 00:04:28
Epoch 4. Train Loss: 1.143113, Train Acc: 0.600643, Valid Loss: 0.757134, Valid Acc: 0.742138, Time 00:04:24
Epoch 5. Train Loss: 1.128991, Train Acc: 0.607797, Valid Loss: 0.745840, Valid Acc: 0.747014, Time 00:04:24
Epoch 6. Train Loss: 1.131602, Train Acc: 0.603561, Valid Loss: 0.740176, Valid Acc: 0.748109, Time 00:04:21
Epoch 7. Train Loss: 1.127840, Train Acc: 0.608336, Valid Loss: 0.738235, Valid Acc: 0.751990, Time 00:04:19
Epoch 8. Train Loss: 1.122831, Train Acc: 0.609275, Valid Loss: 0.730571, Valid Acc: 0.751692, Time 00:04:18
Epoch 9. Train Loss: 1.118955, Train Acc: 0.609715, Valid Loss: 0.731084, Valid Acc: 0.751692, Time 00:04:13
Epoch 10. Train Loss: 1.111291, Train Acc: 0.612052, Valid Loss: 0.728281, Valid Acc: 0.749602, Time 00:04:09
Epoch 11. Train Loss: 1.108454, Train Acc: 0.612712, Valid Loss: 0.719465, Valid Acc: 0.752787, Time 00:04:15
Epoch 12. Train Loss: 1.111189, Train Acc: 0.612012, Valid Loss: 0.726525, Valid Acc: 0.751294, Time 00:04:09
Epoch 13. Train Loss: 1.114475, Train Acc: 0.610594, Valid Loss: 0.717852, Valid Acc: 0.754080, Time 00:04:06
Epoch 14. Train Loss: 1.112658, Train Acc: 0.608596, Valid Loss: 0.723336, Valid Acc: 0.751393, Time 00:04:14
Epoch 15. Train Loss: 1.109367, Train Acc: 0.614950, Valid Loss: 0.721230, Valid Acc: 0.752588, Time 00:04:06
Epoch 16. Train Loss: 1.107644, Train Acc: 0.614230, Valid Loss: 0.711586, Valid Acc: 0.755275, Time 00:04:08
Epoch 17. Train Loss: 1.100239, Train Acc: 0.613411, Valid Loss: 0.722191, Valid Acc: 0.749303, Time 00:04:11
Epoch 18. Train Loss: 1.108576, Train Acc: 0.611013, Valid Loss: 0.721263, Valid Acc: 0.753483, Time 00:04:08
Epoch 19. Train Loss: 1.098069, Train Acc: 0.618027, Valid Loss: 0.705413, Valid Acc: 0.757962, Time 00:04:06

从结果上看,验证集的准确率达到75%左右。下面采用微调+数据增强的方法继续提升准确率。

5. 微调

微调允许修改预训练好的网络参数来学习目标任务,所以训练时间要比特征抽取方法长,但精度更高。微调的大致过程是再预训练的网络上添加新的随机初始化层,此外预训练的网络参数也会被更新,但会使用较小的学习率以防止预训练好的参数发生较大改变

常用的方法是固定底层的参数,调整一些顶层或具体层的参数。这样可以减少训练参数的数量,也可以避免过拟合的发生。尤其是针对目标任务的数据量不够大的时候,该方法会很有效。

实际上,微调优于特征提取,因为它能对迁移过来的预训练网络参数进行优化,使其更加适合新的任务。
(1)数据预处理
对训练数据添加了几种数据增强方法,比如图片裁剪、旋转、颜色改变等方法。测试数据与特征提取的方法一样。

    if fine_tuning is False:trans_train = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std)])else:trans_train = transforms.Compose([transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),transforms.RandomRotation(degrees=15),transforms.ColorJitter(),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std)])

(2)修改模型的分类器层
修改最后全连接层,把类别数由原来的1000改为10。

def freeze_net(num_class=10, fine_tuning=False):# 下载预训练模型net = models.resnet18(pretrained=True)print(net)if fine_tuning is False:# 冻结模型参数for params in net.parameters():params.requires_grad = False# 修改最后一层的输出类别数net.fc = nn.Linear(512, num_class)# 查看冻结前后的参数情况total_params = sum(p.numel() for p in net.parameters())print(f'原总参数个数:{total_params}')total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)print(f'需训练参数个数:{total_trainable_params}')# 打印出第一层的权重print(f'第一层的权重:{net.conv1.weight.type()}')return net

训练结果:
Epoch 0. Train Loss: 1.455535, Train Acc: 0.488460, Valid Loss: 0.832547, Valid Acc: 0.721400, Time 00:14:48
Epoch 1. Train Loss: 1.342625, Train Acc: 0.530280, Valid Loss: 0.815430, Valid Acc: 0.723500, Time 10:31:48
Epoch 2. Train Loss: 1.319122, Train Acc: 0.535680, Valid Loss: 0.866512, Valid Acc: 0.699000, Time 00:12:02
Epoch 3. Train Loss: 1.310949, Train Acc: 0.541700, Valid Loss: 0.789511, Valid Acc: 0.728000, Time 00:12:03
Epoch 4. Train Loss: 1.313486, Train Acc: 0.538500, Valid Loss: 0.762553, Valid Acc: 0.741300, Time 00:12:19
Epoch 5. Train Loss: 1.309776, Train Acc: 0.540680, Valid Loss: 0.777906, Valid Acc: 0.736100, Time 00:11:43
Epoch 6. Train Loss: 1.302117, Train Acc: 0.541780, Valid Loss: 0.779318, Valid Acc: 0.737200, Time 00:12:00
Epoch 7. Train Loss: 1.304539, Train Acc: 0.544320, Valid Loss: 0.795917, Valid Acc: 0.726500, Time 00:13:16
Epoch 8. Train Loss: 1.311748, Train Acc: 0.542400, Valid Loss: 0.785983, Valid Acc: 0.728000, Time 00:14:48
Epoch 9. Train Loss: 1.302069, Train Acc: 0.544820, Valid Loss: 0.781665, Valid Acc: 0.734700, Time 00:14:15
Epoch 10. Train Loss: 1.298019, Train Acc: 0.547040, Valid Loss: 0.771555, Valid Acc: 0.742200, Time 00:16:11
Epoch 11. Train Loss: 1.310127, Train Acc: 0.538700, Valid Loss: 0.764313, Valid Acc: 0.739300, Time 00:17:33
Epoch 12. Train Loss: 1.300172, Train Acc: 0.544720, Valid Loss: 0.765881, Valid Acc: 0.734200, Time 00:12:04
Epoch 13. Train Loss: 1.289607, Train Acc: 0.546980, Valid Loss: 0.753371, Valid Acc: 0.742500, Time 00:11:49
Epoch 14. Train Loss: 1.295938, Train Acc: 0.546280, Valid Loss: 0.821099, Valid Acc: 0.721900, Time 00:11:43

使用微调训练方式的时间明显大于使用特征提取方式的时间,但是验证集上的准确率并没有提高,这是因为由于GPU内存限制,这里将batch_size设为了16。

6. 其他代码

if __name__ == '__main__':data_path = './data'classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'forg', 'horse', 'ship', 'truck')if torch.cuda.is_available():device = torch.device('cuda:0')torch.cuda.empty_cache()else:device = torch.device('cpu')# 加载数据train_loader, test_loader = load_data(data=data_path, fine_tuning=True)# 随机获取部分训练数据data_iter = iter(train_loader)images, labels = data_iter.next()# 显示图像imshow(torchvision.utils.make_grid(images[:4]))# 打印标签print(' '.join('%5s' % classes[labels[j]] for j in range(4)))# 加载模型net = freeze_net(num_class=len(classes), fine_tuning=True)net = net.to(device)# 定义损失函数及优化器criterion = nn.CrossEntropyLoss()# 只需要优化最后一层参数optimizer = torch.optim.SGD(net.fc.parameters(), lr=1e-3, weight_decay=1e-3, momentum=0.9)# 训练及验证模型train(net, train_loader, test_loader, 20, optimizer, criterion)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/10637.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据决定AIGC的高度,什么又决定着数据的深度?

有人曾言,数据决定人工智能发展的天花板。深以为然。 随着ChatGPT等AIGC应用所展现出的强大能力,人们意识到通用人工智能的奇点正在来临,越来越多的企业开始涌入这条赛道。在AIGC浪潮席卷全球之际,数据的重要性也愈发被业界所认同…

HTML5 的离线储存怎么使用,工作原理

TML5提供了一种称为离线储存(Offline Storage)的功能,它允许网页在离线时缓存和存储数据,以便用户可以在没有网络连接的情况下访问这些数据。离线储存是通过使用Web Storage API或者应用程序缓存(Application Cache&am…

[SQL挖掘机] - 字符串函数 - lower

介绍: lower函数是mysql中的一个字符串函数,其作用是将给定的字符串转换为小写形式。它接受一个字符串作为参数,并返回一个新的字符串,其中所有的字母字符均被转换为小写形式。 使用lower函数可以帮助我们在字符串处理中实现标准化和规范化…

MySQL基础(四)数据库备份

目录 前言 一、概述 1.数据备份的重要性 2.造成数据丢失的原因 二、备份类型 (一)、物理与逻辑角度 1.物理备份 2.逻辑备份 (二)、数据库备份策略角度 1.完整备份 2.增量备份 三、常见的备份方法 四、备份&#xff08…

通讯录系统

目录 通讯录系统头文件&#xff1a; 通讯录系统Test&#xff1a; 通讯录系统函数源代码&#xff1a; 通讯录系统头文件&#xff1a; #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h> #include <stdlib.h> #include <string.h> #include <assert…

python 操作sqlite3数据库

sqlite3 import sqlite3 db sqlite3.connect("c:/tmp/test2.db") #连接数据库&#xff0c;若不存在则自动创建 #文件夹 c:/tmp 必须事先存在,connect不会创建文件夹 cur db.cursor() #获取光标&#xff0c;要操作数据库一般要通过光标进行 sql CREATE TABLE if n…

vue实现flv格式视频播放

公司项目需要实现摄像头实时视频播放&#xff0c;flv格式的视频。先百度使用flv.js插件实现&#xff0c;但是两个摄像头一个能放一个不能放&#xff0c;没有找到原因。&#xff08;开始两个都能放&#xff0c;后端更改地址后不有一个不能放&#xff09;但是在另一个系统上是可以…

Blocking Analyzer 1.5 For MySQL 8.0

快速获取MySQL 8.0的blocking信息 1&#xff09;super_read_only 2&#xff09;read_only 3&#xff09;innodb lock waits 4&#xff09;schema table lock waits 5&#xff09;data lock waits 6&#xff09;metadata locks 7&#xff09;data locks 通过以上信息快速…

wangEditor初探

1、前言 现有的Quill比较简单&#xff0c;无法满足业务需求&#xff08;例如SEO的图片属性编辑需求&#xff09; Quill已经有比较长的时间没有更新了&#xff0c;虽然很灵活&#xff0c;但是官方demo都没有一个。 业务前期也没有这块的需求&#xff0c;也没有考虑到这块的扩展…

Xilinx P4使用方法--架构篇

Xilinx P4使用方法--架构篇 1 P4 IP架构2 P4接口说明3 P4使用方法3.1 P4程序3.2 命令文件3.3 数据流文件本文主要介绍Xilinx P4的基本架构、接口和仿真测试文件。 1 P4 IP架构 P4 IP的架构如下图所示,主要由解析器(Parser)、匹配-动作引擎(Match-Action Engine)、逆解析器(De…

接口自动化测试-Python+Requests+Pytest+YAML+Allure配套撸码(详细)

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 接口自动化框架&a…

[Java] 观察者模式简述

模式定义&#xff1a;定义了对象之间的一对多依赖&#xff0c;让多个观察者对象同时监听某一个主题对象&#xff0c;当主题对象发生变化时&#xff0c;他的所有依赖者都会收到通知并且更新 依照这个图&#xff0c;简单的写一个代码 package Section1.listener;import java.ut…

前端程序员入门:先学Vue3还是Vue2?

一、前言 对于新手来说&#xff0c;学习Vue.js框架时往往会有这样一个疑问&#xff1a;应该先学习Vue2还是直接学习Vue3&#xff1f;在回答这个问题之前&#xff0c;我们先简单介绍一下Vue.js框架。 Vue.js是一个轻量级的MVVM(Model-View-ViewModel)框架&#xff0c;它以数据驱…

el-table表格自动滚动

实现效果如下&#xff1a; 功能点&#xff1a; 1. 当表格内容超出时&#xff0c;自动滚动&#xff0c;滚动到最后一条之后在从头滚动。 2. 表格中的数据会定时刷新&#xff0c;刷新后数据更新。 3. 鼠标移入表格中&#xff0c;停止滚动&#xff1b;移出后&#xff0c;继续滚…

nginx几种常见的负载均衡策略

在服务器集群中&#xff0c;Nginx起到一个代理服务器的角色&#xff08;即反向代理&#xff09;&#xff0c;为了避免单独一个服务器压力过大&#xff0c;将来自用户的请求转发给不同的服务器。 根据权重负载均衡 指定轮询几率&#xff0c;weight和访问比率成正比&#xff0c…

算法训练营第四十九天||● 121. 买卖股票的最佳时机 ● 122.买卖股票的最佳时机II

● 121. 买卖股票的最佳时机 暴力和贪心都可以解决 主要讲解动态规划 dp数组&#xff1a;我们把dp数组定义为一个二维的vector容器 dp[i][0]表示第i天持有股票手中的金额&#xff1a;它可以是当天买入的也可以是之前买入的 dp[i][1]表示第i天不持有股票手中的金额&#xf…

VXLAN集中式网关部署(静态方式)

目录 1. 网络拓扑1.1 配置思路1.2 数据准备2. 配置Underlay网络2.1 配置CE12.2 配置CE22.3 配置CE32.4 查看OSPF结果2.5 配置LSW12.6 配置LSW23. 配置Overlay网络二层互通(同网段)3.1 配置CE13.2 配置CE23.3 配置CE33.4 Server13.5 Server23.6 Server33.7 Server43.8 抓包分析…

Kafka入门到起飞系列 - 副本机制,什么是副本因子呢?

我们一直在讲一个主题会有多个分区&#xff0c;这多个分区可以分布在一台服务器上&#xff0c;也可以分布在多台服务器上&#xff0c;还可以增加分区&#xff08;Kafka目前只支持分区&#xff09;&#xff0c;这是Kafka提供的一种横向扩展的手段 比如我们创建了一个主题&#x…

YAML+PyYAML笔记 2 | YAML缩进、分离、注释简单使用

2 | YAML缩进、分离、注释简单使用 1 简介2 缩进3 分离4 多行文本4.1 折叠块4.2 字面块4.3 引用块 5 注释5.1 行内注释5.2 块注释5.3 完美注释示例 1 简介 YAML 不是一种标记语言&#xff0c;而是一种数据格式&#xff1b;使用缩进和分离来表示数据结构&#xff0c;不需要使用…

与 ChatGPT 进行有效交互的几种策略

在这篇文章中&#xff0c;您将了解即时工程。尤其&#xff0c; 如何在提示中提供对响应影响最大的信息什么是角色、正面和负面提示、零样本提示等如何迭代使用提示来利用 ChatGPT 的对话性质 废话不多说直接开始吧&#xff01;&#xff01;&#xff01; 提示原则 快速工程是有…