[PyTorch][chapter 52][迁移学习]

前言:

     迁移学习(Transfer Learning)是一种机器学习方法,它通过将一个领域中的知识和经验迁移到另一个相关领域中,来加速和改进新领域的学习和解决问题的能力。

      这里面主要结合前面ResNet18 例子,详细讲解一下迁移学习的流程


一  简介

    

迁移学习可以通过以下几种方式实现:

1.1 基于预训练模型的迁移:

      将已经在大规模数据集上预训练好的模型(如BERT、GPT等)作为一个通用的特征提取器,然后在新领域的任务上进行微调。

1.2  网络结构迁移:

将在一个领域中训练好的模型的网络结构应用到另一个领域中,并在此基础上进行微调。

1.3  特征迁移:

     将在一个领域中训练好的某些特征应用到另一个领域中,并在此基础上进行微调。

     word2vec

1.4 参数迁移:

       将在一个领域中训练好的模型的参数应用到另一个领域中,并在此基础上进行微调。

本文主要例子用的是 参数迁移


二  Flatten

    作用:

     输入的向量x [batch, c, w, h]=>[batch, c*w*h]

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 15:11:35 2023@author: chengxf2
"""import torch
from torch import optim,nnclass Flatten(nn.Module):def __init__(self):super(Flatten,self).__init__()def forward(self, x):a = torch.tensor(x.shape[1:])#dim 中 input 张量的每一行的乘积。shape = torch.prod(a).item()#print("\n ---new shape--- ",shape)return x.view(-1,shape)

三 迁移学习

   torchvision 已经提供好了一些分类器 resnet18,resnet152, 利用其训练好的参数,把最后的分类类型更改掉。

   from torchvision.models import resnet152
  from torchvision.models import resnet18

   注意:

          现有分类器分类的类型 > = 新分类器类型,再做transfer.

才能取得好的效果.

         

分类器分类类型
已有分类器[猫,狗,鸡,鸭】
新分类器[猫,狗]

     

   

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 14:56:35 2023@author: chengxf2
"""# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023@author: chengxf2
"""import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from PokeDataset import Pokemon
from torchvision.models import resnet152
from torchvision.models import resnet18from util import FlattenbatchNum = 32
lr = 1e-3
epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)root ='pokemon'
resize =224csvfile ='data.csv'
train_db = Pokemon(root, resize, 'train',csvfile)
val_db = Pokemon(root, resize, 'val',csvfile)
test_db = Pokemon(root, resize, 'test',csvfile)train_loader = DataLoader(train_db, batch_size =batchNum,shuffle= True,num_workers=4)
val_loader = DataLoader(val_db, batch_size =batchNum,shuffle= True,num_workers=2)
test_loader = DataLoader(test_db, batch_size =batchNum,shuffle= True,num_workers=2)
viz = visdom.Visdom()def evalute(model, loader):total =len(loader.dataset)correct =0for x,y in loader:x = x.to(device)y = y.to(device)with torch.no_grad():logits = model(x)pred = logits.argmax(dim=1)correct += torch.eq(pred, y).sum().float().item()acc = correct/totalreturn acc   def main():trained_model = resnet152(pretrained=True)model = nn.Sequential(*list(trained_model.children())[:-1],Flatten(),nn.Linear(in_features=2048, out_features=5))optimizer = optim.Adam(model.parameters(),lr =lr) criteon = nn.CrossEntropyLoss()best_epoch=0,best_acc=0viz.line([0],[-1],win='train_loss',opts =dict(title='train loss'))viz.line([0],[-1],win='val_loss',  opts =dict(title='val_acc'))global_step =0for epoch in range(epochs):print("\n --main---: ",epoch)for step, (x,y) in enumerate(train_loader):#x:[b,3,224,224] y:[b]x = x.to(device)y = y.to(device)#print("\n --x---: ",x.shape)logits =model(x)loss = criteon(logits, y)#print("\n --loss---: ",loss.shape)optimizer.zero_grad()loss.backward()optimizer.step()viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')global_step +=1if epoch %2 ==0:val_acc = evalute(model, val_loader)if val_acc>best_acc:best_acc = val_accbest_epoch =epochtorch.save(model.state_dict(),'best.mdl')print("\n val_acc ",val_acc)viz.line([val_acc],[global_step],win='val_loss',update='append')print('\n best acc',best_acc, "best_epoch: ",best_epoch)model.load_state_dict(torch.load('best.mdl'))print('loaded from ckpt')test_acc = evalute(model, test_loader)print('\n test acc',test_acc)if __name__ == "__main__":main()


参考:
https://blog.csdn.net/qq_44089890/article/details/130460700

课时107 迁移学习实战_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/44399.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux:Temporary failure in name resolutionCouldn’t resolve host

所有域名无法正常解析。 ping www.baidu.com 等域名提示 Temporary failure in name resolution错误。 rootlocalhost:~# ping www.baidu.com ping: www.baidu.com: Temporary failure in name resolution rootlocalhost:~# 一、ubuntu/debian(emporary failure i…

云原生反模式

通过了解这些反模式并遵循云原生最佳实践,您可以设计、构建和运营更加强大、可扩展和成本效益高的云原生应用程序。 1.单体架构:在云上运行一个大而紧密耦合的应用程序,妨碍了可扩展性和敏捷性。2.忽略成本优化:云服务可能昂贵&am…

发布一个开源的新闻api(整理后就开源)

目录 说明: 基础说明 其他说明: 通用接口: 登录: 注册: 更改密码(需要token) 更换头像(需要token) 获取用户列表(需要token): 上传文件(5000端口): 获取文件(5000端口)源码文件,db文件均不能获取: 验证token(需要token): 获取系统时间: 文件…

AlexNet阅读笔记

ImageNet classification with deep convolutional neural networks 原文链接:https://dl.acm.org/doi/abs/10.1145/3065386 中文翻译:https://blog.csdn.net/qq_38473254/article/details/132307508 使用深度卷积神经网络进行 ImageNet 分类 摘要 大…

【Redis】Redisson分布式锁原理与使用

【Redis】Redisson分布式锁原理与使用 什么是Redisson? Redisson - 是一个高级的分布式协调Redis客服端,能帮助用户在分布式环境中轻松实现一些Java的对象,Redisson、Jedis、Lettuce 是三个不同的操作 Redis 的客户端,Jedis、Le…

【Golang系统开发】搜索引擎(3) 压缩倒排索引表

写在前面 假设我们的数据集中有 800000 篇文章,每篇文章有 200 词条,每个词条有6个字符,倒排记录数目是 1 亿。那么如果我们倒排索引表中单单记录文档id,不记录文档内的频率和偏移信息。 那么 文档id 的长度就必须是 l o g 2 8…

【不带权重的TOPSIS模型详解】——数学建模

目录索引 定义:问题引入:不合理之处:进行修改: 指标分类:指标正向化:极小型指标正向化公式:中间型指标正向化公式:区间型指标正向化公式: 标准化处理(消去单位)&#xff…

基于Java/springboot铁路物流数据平台的设计与实现

摘要 随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,铁路物流数据平台当然也不能排除在外,从文档信息、铁路设计的统计和分析,在过程中会产生大量的、各…

浙大数据结构第八周之08-图7 公路村村通

题目详情: 现有村落间道路的统计数据表中,列出了有可能建设成标准公路的若干条道路的成本,求使每个村落都有公路连通所需要的最低成本。 输入格式: 输入数据包括城镇数目正整数N(≤1000)和候选道路数目M&#xff08…

【C++】模板进阶

🌇个人主页:平凡的小苏 📚学习格言:命运给你一个低的起点,是想看你精彩的翻盘,而不是让你自甘堕落,脚下的路虽然难走,但我还能走,比起向阳而生,我更想尝试逆风…

华为PPPOE配置实验

华为PPPOE配置实验 网络拓扑图拓扑说明电信ISP设备配置用户拨号路由器配置查看是否拨上号是否看不懂? 看不懂就对了,只是记录一下命令。至于所有原理,等想写了再写 网络拓扑图 拓扑说明 用户路由器用于模拟家用拨号路由器,该设备…

最新AI系统ChatGPT程序源码/支持GPT4/自定义训练知识库/GPT联网/支持ai绘画(Midjourney)+Dall-E2绘画/支持MJ以图生图

一、前言 SparkAi系统是基于国外很火的ChatGPT进行开发的Ai智能问答系统。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。 那么如何搭建部署AI创作ChatGPT?小编这里写一个详细图文教程吧&#xff01…

自动执行探索性数据分析 (EDA),更快、更轻松地理解数据

一、说明 EDA是 exploratory data analysis (探索性数据分析 )的缩写。所谓EDA就是在数据分析之前需要对数据进行以此系统性研判,在这个研判后,得到基本的数据先验知识,在这个基础上进行数据分析。本文将在R语言和python语言的探索性处理。 摄…

K8S系列四:服务管理

写在前面 本文是K8S系列第四篇,主要面向对k8s新手同学。阅读本文需要读者对k8s的基本概念,比如Pod、Deployment、Service、Namespace等基础概念有所了解,尚且不了解的同学推荐先阅读本系列的第一篇文章《K8S系列一:概念入门》[1]…

JVM——分代收集理论和垃圾回收算法

一、分代收集理论 1、三个假说 弱分代假说:绝大多数对象都是朝生夕灭的。 强分代假说:熬过越多次垃圾收集过程的对象越难以消亡。 这两个分代假说共同奠定了多款常用的垃圾收集器的一致的设计原则:收集器应该将Java堆划分出不同的区域&…

kafka--kafka的基本概念-topic和partition

一、kafka的基本概念-topic和partition 1、topic (主题 ) topic是逻辑概念 以Topic机制来对消息进行分类的,同一类消息属于同一个Topic,你可以将每个topic看成是一个消息队列。 生产者(producer)将消息发…

利用Jackson封装常用的JsonUtil工具类

在实际开发中,我们对于 JSON 数据的处理,通常有这么几个第三方工具包可以使用: gson:谷歌的fastjson:阿里巴巴的jackson:美国FasterXML公司的,Spring框架默认用的 由于以前一直用习惯了阿里的…

Spring事务和事务传播机制(1)

前言🍭 ❤️❤️❤️SSM专栏更新中,各位大佬觉得写得不错,支持一下,感谢了!❤️❤️❤️ Spring Spring MVC MyBatis_冷兮雪的博客-CSDN博客 在Spring框架中,事务管理是一种用于维护数据库操作的一致性和…

Mac OS下应用Python+Selenium实现web自动化测试

在Mac环境下应用PythonSelenium实现web自动化测试 在这个过程中要注意两点: 1.在终端联网执行命令“sudo pip install –U selenium”如果失败了的话,可以尝试用命令“sudo easy_install selenium”来安装selenium; 2.安装好PyCharm后新建project&…

DTC 19服务学习1

在UDS(统一诊断服务)协议中,0x19是用于DTC(诊断故障代码)信息的服务。以下是你提到的子服务的功能和作用: 0x01 - 报告DTC按状态掩码。这个子服务用于获取当前存储在ECU中的DTC列表。状态掩码用于过滤DTC&a…