微调及代码

一、微调:迁移学习(transfer learning)将从源数据集学到的知识迁移到目标数据集

二、步骤

1、在源数据集(例如ImageNet数据集)上预训练神经网络模型,即源模型

2、创建一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数(输出层除外)。

3、向目标模型添加输出层,其输出数是目标数据集中的类别数。然后随机初始化该层的模型参数。

4、在目标数据集(如椅子数据集)上训练目标模型。输出层将从头开始进行训练,而所有其他层的参数将根据源模型的参数进行微调。

5、目标数据集比源数据集小得多时,微调有助于提高模型的泛化能力。就相当于在别人训练好的基础上训练

三、网络架构:神经网络一般分为两块:特征抽取(将原始像素变成容易线性分割的特征)和线性分类器

四、训练

1、微调是在目标数据集上的正常训练任务,但使用更强的正则化(有更强的lr和更少的epoch)

2、源数据集远复杂于目标数据,通常微调效果会更好

3、源数据集中可能也有目标数据中的部分标号

4、固定底层训练高层

五、总结

1、迁移学习将从源数据集中学到的知识迁移到目标数据集,微调是迁移学习的常见技巧。

2、除输出层外,目标模型从源模型中复制所有模型设计及其参数,并根据目标数据集对这些参数进行微调。但是,目标模型的输出层需要从头开始训练。

3、通常,微调参数使用较小的学习率,而从头开始训练输出层可以使用更大的学习率。

六、代码

1、导入数据

train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize])

2、定义和初始化模型

pretrained_net = torchvision.models.resnet18(pretrained=True)finetune_net = torchvision.models.resnet18(pretrained=True)
#最后一层类别为2
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);

3、微调模型

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/870726.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大数据基础:Hadoop之Yarn重点架构原理

文章目录 Hadoop之Yarn重点架构原理 一、Yarn介绍 二、Yarn架构 三、Yarn任务运行流程 四、Yarn三种资源调度器特点及使用场景 Hadoop之Yarn重点架构原理 一、Yarn介绍 Apache Hadoop Yarn(Yet Another Reasource Negotiator,另一种资源协调者)是Hadoop2.x版…

LLM-向量数据库中的索引算法总结

文章目录 前言向量数据库介绍索引方法倒排索引KNN 搜索近似 KNN 搜索Product Quantization(PQ)NSW 算法搜索HNSW 前言 向量数据库是当今大模型知识库检索落地实践的核心组件,下图是构建知识库检索的架构图: 首先会将相关文档数据向量化嵌入到向量化数据…

Python Linux下编译

注意 本教程针对较新Linux系统,没有升级依赖、处理旧版本Linux的openssl等步骤,如有需要可以查看往期文章,例如:在Centos7.6镜像中安装Python3.9 教程中没有使用默认位置、默认可执行文件名,请注意甄别 安装路径&#…

vue3中echarts的使用

1.下载 echartsnpm i -s echarts 2.在main.js中引入import { createApp } from vue import App from ./App.vue// 引入 echarts import * as echarts from echarts const app createApp(App) // 全局挂载 echarts app.config.globalProperties.$echarts echartsapp.mount(#ap…

I18N/L10N 历史 / I18N Guidelines I18N 指南 / libi18n 模块说明

注&#xff1a;机翻&#xff0c;未校对。 文章虽然从 Netscape 客户端展开 I18N/L10N 历史&#xff0c;但 I18N/L10N 的演化早已不仅限适用于 Netscape 客户端。 Netscape Client I18N/L10N History Netscape 客户端 I18N/L10N 历史 Contact: Bob Jung <bobjnetscape.com&…

阿里生态体系

阿里巴巴的“16N”战略框架是一种业务布局战略。具体来说&#xff0c;“1”代表核心电商平台&#xff0c;“6”代表阿里的六大板块&#xff0c;“N”代表众多的新业务和创新业务。以下是对“16N”具体内容的详细说明&#xff1a; 1. 核心电商平台 阿里巴巴电子商务业务&#…

Go语言入门之数组切片

Go语言入门之数组切片 1.数组的定义 数组是一组连续内存空间存储的具有相同类型的数据&#xff0c;是一种线性结构。 在Go语言中&#xff0c;数组的长度是固定的。 数组是值传递&#xff0c;有较大的内存开销&#xff0c;可以使用指针解决 数组声明 var name [size]typename&…

达梦数据库dm8安装步骤及迁移

目录 前言: 一、安装部署 1、下载 2、创建用户及安装目录 3、挂载下载的镜像 4、环境配置 5、安装 二、基本使用 1、DM工具使用 2、兼容性配置 2.1 兼容GBK字符集编码 2.2 兼容UTF-8字符集编码 3、创建用户和密码,表空间 4、整理数据库配置 5、启动脚本设置 …

华为OD机考题(HJ74 参数解析)

前言 经过前期的数据结构和算法学习&#xff0c;开始以OD机考题作为练习题&#xff0c;继续加强下熟练程度。 描述 在命令行输入如下命令&#xff1a; xcopy /s c:\\ d:\\e&#xff0c; 各个参数如下&#xff1a; 参数1&#xff1a;命令字xcopy 参数2&#xff1a;字符串…

JavaSE学习笔记之内部类、枚举类和基本类型包装类

今天我们继续复习Java相关的知识&#xff0c;和大家分享有关内部类等方面的知识&#xff0c;希望大家喜欢。 目录​​​​​​​ 内部类 成员内部类 ​编辑 静态内部类 局部内部类 匿名内部类 枚举类 定义方法 基本类型包装类 自动装箱和拆箱 内部类 成员内部类 成…

使用 Google 的 Generative AI 服务时,请求没有包含足够的认证范围(scopes)

题意&#xff1a; Google generativeai 403 Request had insufficient authentication scopes. [reason: "ACCESS_TOKEN_SCOPE_INSUFFICIENT" 问题背景&#xff1a; I have tried the simple POC for generativeai on its own to do generate_content and it works…

WPS点击Zotero插入没有任何反应

wps个人版没有内置vba&#xff0c;因此即便一下插件安装上了&#xff08;如Axmath&#xff0c;zotero&#xff09;&#xff0c;当点击插件的时候会出现“点不动”、“点击插件没反应的现象。至于islide一类的插件&#xff0c;干脆连装都装不上。 这就需要手动安装一下vba。 针…

Python酷库之旅-第三方库Pandas(017)

目录 一、用法精讲 41、pandas.melt函数 41-1、语法 41-2、参数 41-3、功能 41-4、返回值 41-5、说明 41-5-1、宽格式数据(Wide Format) 41-5-2、长格式数据(Long Format) 41-6、用法 41-6-1、数据准备 41-6-2、代码示例 41-6-3、结果输出 42、pandas.pivot函数 …

【单片机毕业设计选题24059】-太阳能嵌入式智能充电系统研究

系统功能: 系统由太阳能电池板提供电源&#xff0c; 系统上电后显示“欢迎使用智能充电系统请稍后”&#xff0c; 两秒钟后进入主页面显示。 第一行显示太阳能电池板输入的电压值 第二行显示系统输出的电压值 第三行显示采集到的太阳能电池板温度 第四行显示设置的太阳能…

回归损失和分类损失

回归损失和分类损失是机器学习模型训练过程中常用的两类损失函数&#xff0c;分别适用于回归任务和分类任务。 回归损失函数 回归任务的目标是预测一个连续值&#xff0c;因此回归损失函数衡量预测值与真实值之间的差异。常见的回归损失函数有&#xff1a; 均方误差&#xff…

【UNI-APP】阿里NLS一句话听写typescript模块

阿里提供的demo代码都是javascript&#xff0c;自己捏个轮子。参考着自己写了一个阿里巴巴一句话听写Nls的typescript模块。VUE3的组合式API形式 startClient&#xff1a;开始听写&#xff0c;注意下一步要尽快开启识别和传数据&#xff0c;否则6秒后会关闭 startRecognition…

004-基于Sklearn的机器学习入门:回归分析(下)

本节及后续章节将介绍机器学习中的几种经典回归算法&#xff0c;包括线性回归&#xff0c;多项式回归&#xff0c;以及正则项的岭回归等&#xff0c;所选方法都在Sklearn库中聚类模块有具体实现。本节为下篇&#xff0c;将介绍多项式回归和岭回归等。 目录 2.3 多项式回归 2…

Point Cloud Library (PCL) for Python - pclpy 安装指南 (1)

以下所有的版本号务必按照说明安装。 1.安装 Python 3.6 https://www.python.org/ftp/python/3.6.8/python-3.6.8-amd64.exe #或 百度网盘 2.确认 Python 版本为 3.6.x python #Python 3.6.8 (tags/v3.6.8:3c6b436a57, Dec 24 2018, 00:16:47) [MSC v.1916 64 bit (AMD64)] on…

给后台写了一个优雅的自定义风格的数据日志上报页面

highlight: atelier-cave-dark 查看后台数据日志是非常常见的场景,经常看到后台的小伙伴从服务器日志复制一段json数据字符串,然后找一个JSON工具网页打开,在线JSON格式化校验。有的时候,一些业务需要展示mqtt或者socket的实时信息展示,如果不做任何修改直接展示一串字符…

将有序数组转化成二叉搜索数

1 问题 将一个按照升序排列的有序数组&#xff0c;转换为一棵高度平衡二叉搜索树。本题中&#xff0c;一个高度平衡二叉树是指一个二叉树每个节点的左右两个子树的高度差的绝对值不超过1。 2 方法 采用递归的方法找到root结点&#xff0c;以及左子树和右子树。 代码清单 1 clas…