[PyTorch][chapter 52][迁移学习]

前言:

     迁移学习(Transfer Learning)是一种机器学习方法,它通过将一个领域中的知识和经验迁移到另一个相关领域中,来加速和改进新领域的学习和解决问题的能力。

      这里面主要结合前面ResNet18 例子,详细讲解一下迁移学习的流程


一  简介

    

迁移学习可以通过以下几种方式实现:

1.1 基于预训练模型的迁移:

      将已经在大规模数据集上预训练好的模型(如BERT、GPT等)作为一个通用的特征提取器,然后在新领域的任务上进行微调。

1.2  网络结构迁移:

将在一个领域中训练好的模型的网络结构应用到另一个领域中,并在此基础上进行微调。

1.3  特征迁移:

     将在一个领域中训练好的某些特征应用到另一个领域中,并在此基础上进行微调。

     word2vec

1.4 参数迁移:

       将在一个领域中训练好的模型的参数应用到另一个领域中,并在此基础上进行微调。

本文主要例子用的是 参数迁移


二  Flatten

    作用:

     输入的向量x [batch, c, w, h]=>[batch, c*w*h]

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 15:11:35 2023@author: chengxf2
"""import torch
from torch import optim,nnclass Flatten(nn.Module):def __init__(self):super(Flatten,self).__init__()def forward(self, x):a = torch.tensor(x.shape[1:])#dim 中 input 张量的每一行的乘积。shape = torch.prod(a).item()#print("\n ---new shape--- ",shape)return x.view(-1,shape)

三 迁移学习

   torchvision 已经提供好了一些分类器 resnet18,resnet152, 利用其训练好的参数,把最后的分类类型更改掉。

   from torchvision.models import resnet152
  from torchvision.models import resnet18

   注意:

          现有分类器分类的类型 > = 新分类器类型,再做transfer.

才能取得好的效果.

         

分类器分类类型
已有分类器[猫,狗,鸡,鸭】
新分类器[猫,狗]

     

   

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 14:56:35 2023@author: chengxf2
"""# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023@author: chengxf2
"""import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from PokeDataset import Pokemon
from torchvision.models import resnet152
from torchvision.models import resnet18from util import FlattenbatchNum = 32
lr = 1e-3
epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)root ='pokemon'
resize =224csvfile ='data.csv'
train_db = Pokemon(root, resize, 'train',csvfile)
val_db = Pokemon(root, resize, 'val',csvfile)
test_db = Pokemon(root, resize, 'test',csvfile)train_loader = DataLoader(train_db, batch_size =batchNum,shuffle= True,num_workers=4)
val_loader = DataLoader(val_db, batch_size =batchNum,shuffle= True,num_workers=2)
test_loader = DataLoader(test_db, batch_size =batchNum,shuffle= True,num_workers=2)
viz = visdom.Visdom()def evalute(model, loader):total =len(loader.dataset)correct =0for x,y in loader:x = x.to(device)y = y.to(device)with torch.no_grad():logits = model(x)pred = logits.argmax(dim=1)correct += torch.eq(pred, y).sum().float().item()acc = correct/totalreturn acc   def main():trained_model = resnet152(pretrained=True)model = nn.Sequential(*list(trained_model.children())[:-1],Flatten(),nn.Linear(in_features=2048, out_features=5))optimizer = optim.Adam(model.parameters(),lr =lr) criteon = nn.CrossEntropyLoss()best_epoch=0,best_acc=0viz.line([0],[-1],win='train_loss',opts =dict(title='train loss'))viz.line([0],[-1],win='val_loss',  opts =dict(title='val_acc'))global_step =0for epoch in range(epochs):print("\n --main---: ",epoch)for step, (x,y) in enumerate(train_loader):#x:[b,3,224,224] y:[b]x = x.to(device)y = y.to(device)#print("\n --x---: ",x.shape)logits =model(x)loss = criteon(logits, y)#print("\n --loss---: ",loss.shape)optimizer.zero_grad()loss.backward()optimizer.step()viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')global_step +=1if epoch %2 ==0:val_acc = evalute(model, val_loader)if val_acc>best_acc:best_acc = val_accbest_epoch =epochtorch.save(model.state_dict(),'best.mdl')print("\n val_acc ",val_acc)viz.line([val_acc],[global_step],win='val_loss',update='append')print('\n best acc',best_acc, "best_epoch: ",best_epoch)model.load_state_dict(torch.load('best.mdl'))print('loaded from ckpt')test_acc = evalute(model, test_loader)print('\n test acc',test_acc)if __name__ == "__main__":main()


参考:
https://blog.csdn.net/qq_44089890/article/details/130460700

课时107 迁移学习实战_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/44399.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux:Temporary failure in name resolutionCouldn’t resolve host

所有域名无法正常解析。 ping www.baidu.com 等域名提示 Temporary failure in name resolution错误。 rootlocalhost:~# ping www.baidu.com ping: www.baidu.com: Temporary failure in name resolution rootlocalhost:~# 一、ubuntu/debian(emporary failure i…

云原生反模式

通过了解这些反模式并遵循云原生最佳实践,您可以设计、构建和运营更加强大、可扩展和成本效益高的云原生应用程序。 1.单体架构:在云上运行一个大而紧密耦合的应用程序,妨碍了可扩展性和敏捷性。2.忽略成本优化:云服务可能昂贵&am…

发布一个开源的新闻api(整理后就开源)

目录 说明: 基础说明 其他说明: 通用接口: 登录: 注册: 更改密码(需要token) 更换头像(需要token) 获取用户列表(需要token): 上传文件(5000端口): 获取文件(5000端口)源码文件,db文件均不能获取: 验证token(需要token): 获取系统时间: 文件…

AlexNet阅读笔记

ImageNet classification with deep convolutional neural networks 原文链接:https://dl.acm.org/doi/abs/10.1145/3065386 中文翻译:https://blog.csdn.net/qq_38473254/article/details/132307508 使用深度卷积神经网络进行 ImageNet 分类 摘要 大…

Django进阶-文件上传

普通文件上传 定义 用户可以通过浏览器将图片等文件上传到网站 场景 用户上传头像 上传流动性的文档【pdf,txt】等 上传规范-后端 1.视图函数中,用request。FILES取文件框的内容 file request.FILES[xxx] 说明: 1.FILES的key对应页面中…

【Redis】Redisson分布式锁原理与使用

【Redis】Redisson分布式锁原理与使用 什么是Redisson? Redisson - 是一个高级的分布式协调Redis客服端,能帮助用户在分布式环境中轻松实现一些Java的对象,Redisson、Jedis、Lettuce 是三个不同的操作 Redis 的客户端,Jedis、Le…

Redis基于内存的key-value结构化NOSQL(非关系型)数据库

Redis Redis介绍Redis的优点Redis的缺点Redis的安装Redis的连接Redis的使用Redis中的数据类型String的使用get setsetex(expire)ttlsetnx(not exit)HashList列表(队列)Set集合ZSet集合Redis 通用命令Redis图形客户端Redis在Java中的使用RedisTemplate

【Golang系统开发】搜索引擎(3) 压缩倒排索引表

写在前面 假设我们的数据集中有 800000 篇文章,每篇文章有 200 词条,每个词条有6个字符,倒排记录数目是 1 亿。那么如果我们倒排索引表中单单记录文档id,不记录文档内的频率和偏移信息。 那么 文档id 的长度就必须是 l o g 2 8…

git如何检查和修改忽略文件和忽略规则

查询忽略规则 使用命令行&#xff1a;git status --ignored&#xff0c;进行查询&#xff0c; 例&#xff1a; $ git status --ignored On branch develop Your branch is up to date with origin/develop.Ignored files:(use "git add -f <file>..." to inc…

Effective C++条款07——为多态基类声明virtual析构函数(构造/析构/赋值运算)

有许多种做法可以记录时间&#xff0c;因此&#xff0c;设计一个TimeKeeper base class和一些derived classes 作为不同的计时方法&#xff0c;相当合情合理&#xff1a; class TimeKeeper { public:TimeKeeper();~TimeKeeper();// ... };class AtomicClock: public TimeKeepe…

Go 1.21新增的 slices 包详解(二)

Go 1.21新增的 slices 包提供了很多和切片相关的函数&#xff0c;可以用于任何类型的切片。 slices.Delete 定义如下&#xff1a; func Delete[S ~[]E, E any](s S, i, j int) S 从 s 中删除元素 s[i:j]&#xff0c;返回修改后的切片。如果 s[i:j] 不是 s 的有效切片&#…

域名子目录发布问题(nginx、vue-element-admin、uni-app)

域名子目录发布问题&#xff08;nginx、vue-element-admin、uni-app&#xff09; 说明Vue-Element-Admin 代码打包nginx配置&#xff1a;uni-app打包 说明 使用一个域名下子目录进行打包&#xff1a; 比如&#xff1a; http://www.xxx.com/merchant 商户端代码 http://www.xx…

【不带权重的TOPSIS模型详解】——数学建模

目录索引 定义&#xff1a;问题引入&#xff1a;不合理之处&#xff1a;进行修改&#xff1a; 指标分类&#xff1a;指标正向化&#xff1a;极小型指标正向化公式&#xff1a;中间型指标正向化公式&#xff1a;区间型指标正向化公式&#xff1a; 标准化处理(消去单位)&#xff…

基于Java/springboot铁路物流数据平台的设计与实现

摘要 随着科学技术的飞速发展&#xff0c;社会的方方面面、各行各业都在努力与现代的先进技术接轨&#xff0c;通过科技手段来提高自身的优势&#xff0c;铁路物流数据平台当然也不能排除在外&#xff0c;从文档信息、铁路设计的统计和分析&#xff0c;在过程中会产生大量的、各…

浙大数据结构第八周之08-图7 公路村村通

题目详情&#xff1a; 现有村落间道路的统计数据表中&#xff0c;列出了有可能建设成标准公路的若干条道路的成本&#xff0c;求使每个村落都有公路连通所需要的最低成本。 输入格式: 输入数据包括城镇数目正整数N&#xff08;≤1000&#xff09;和候选道路数目M&#xff08…

【C++】模板进阶

&#x1f307;个人主页&#xff1a;平凡的小苏 &#x1f4da;学习格言&#xff1a;命运给你一个低的起点&#xff0c;是想看你精彩的翻盘&#xff0c;而不是让你自甘堕落&#xff0c;脚下的路虽然难走&#xff0c;但我还能走&#xff0c;比起向阳而生&#xff0c;我更想尝试逆风…

华为PPPOE配置实验

华为PPPOE配置实验 网络拓扑图拓扑说明电信ISP设备配置用户拨号路由器配置查看是否拨上号是否看不懂&#xff1f; 看不懂就对了&#xff0c;只是记录一下命令。至于所有原理&#xff0c;等想写了再写 网络拓扑图 拓扑说明 用户路由器用于模拟家用拨号路由器&#xff0c;该设备…

深入浅出Pytorch函数——torch.nn.init.normal_

分类目录&#xff1a;《深入浅出Pytorch函数》总目录 相关文章&#xff1a; 深入浅出Pytorch函数——torch.nn.init.calculate_gain 深入浅出Pytorch函数——torch.nn.init.uniform_ 深入浅出Pytorch函数——torch.nn.init.normal_ 深入浅出Pytorch函数——torch.nn.init.c…

最新AI系统ChatGPT程序源码/支持GPT4/自定义训练知识库/GPT联网/支持ai绘画(Midjourney)+Dall-E2绘画/支持MJ以图生图

一、前言 SparkAi系统是基于国外很火的ChatGPT进行开发的Ai智能问答系统。本期针对源码系统整体测试下来非常完美&#xff0c;可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。 那么如何搭建部署AI创作ChatGPT&#xff1f;小编这里写一个详细图文教程吧&#xff01…

自动执行探索性数据分析 (EDA),更快、更轻松地理解数据

一、说明 EDA是 exploratory data analysis (探索性数据分析 )的缩写。所谓EDA就是在数据分析之前需要对数据进行以此系统性研判&#xff0c;在这个研判后&#xff0c;得到基本的数据先验知识&#xff0c;在这个基础上进行数据分析。本文将在R语言和python语言的探索性处理。 摄…