微调及代码

一、微调:迁移学习(transfer learning)将从源数据集学到的知识迁移到目标数据集

二、步骤

1、在源数据集(例如ImageNet数据集)上预训练神经网络模型,即源模型

2、创建一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数(输出层除外)。

3、向目标模型添加输出层,其输出数是目标数据集中的类别数。然后随机初始化该层的模型参数。

4、在目标数据集(如椅子数据集)上训练目标模型。输出层将从头开始进行训练,而所有其他层的参数将根据源模型的参数进行微调。

5、目标数据集比源数据集小得多时,微调有助于提高模型的泛化能力。就相当于在别人训练好的基础上训练

三、网络架构:神经网络一般分为两块:特征抽取(将原始像素变成容易线性分割的特征)和线性分类器

四、训练

1、微调是在目标数据集上的正常训练任务,但使用更强的正则化(有更强的lr和更少的epoch)

2、源数据集远复杂于目标数据,通常微调效果会更好

3、源数据集中可能也有目标数据中的部分标号

4、固定底层训练高层

五、总结

1、迁移学习将从源数据集中学到的知识迁移到目标数据集,微调是迁移学习的常见技巧。

2、除输出层外,目标模型从源模型中复制所有模型设计及其参数,并根据目标数据集对这些参数进行微调。但是,目标模型的输出层需要从头开始训练。

3、通常,微调参数使用较小的学习率,而从头开始训练输出层可以使用更大的学习率。

六、代码

1、导入数据

train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize])

2、定义和初始化模型

pretrained_net = torchvision.models.resnet18(pretrained=True)finetune_net = torchvision.models.resnet18(pretrained=True)
#最后一层类别为2
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);

3、微调模型

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/870726.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大数据基础:Hadoop之Yarn重点架构原理

文章目录 Hadoop之Yarn重点架构原理 一、Yarn介绍 二、Yarn架构 三、Yarn任务运行流程 四、Yarn三种资源调度器特点及使用场景 Hadoop之Yarn重点架构原理 一、Yarn介绍 Apache Hadoop Yarn(Yet Another Reasource Negotiator,另一种资源协调者)是Hadoop2.x版…

LLM-向量数据库中的索引算法总结

文章目录 前言向量数据库介绍索引方法倒排索引KNN 搜索近似 KNN 搜索Product Quantization(PQ)NSW 算法搜索HNSW 前言 向量数据库是当今大模型知识库检索落地实践的核心组件,下图是构建知识库检索的架构图: 首先会将相关文档数据向量化嵌入到向量化数据…

达梦数据库dm8安装步骤及迁移

目录 前言: 一、安装部署 1、下载 2、创建用户及安装目录 3、挂载下载的镜像 4、环境配置 5、安装 二、基本使用 1、DM工具使用 2、兼容性配置 2.1 兼容GBK字符集编码 2.2 兼容UTF-8字符集编码 3、创建用户和密码,表空间 4、整理数据库配置 5、启动脚本设置 …

JavaSE学习笔记之内部类、枚举类和基本类型包装类

今天我们继续复习Java相关的知识,和大家分享有关内部类等方面的知识,希望大家喜欢。 目录​​​​​​​ 内部类 成员内部类 ​编辑 静态内部类 局部内部类 匿名内部类 枚举类 定义方法 基本类型包装类 自动装箱和拆箱 内部类 成员内部类 成…

使用 Google 的 Generative AI 服务时,请求没有包含足够的认证范围(scopes)

题意: Google generativeai 403 Request had insufficient authentication scopes. [reason: "ACCESS_TOKEN_SCOPE_INSUFFICIENT" 问题背景: I have tried the simple POC for generativeai on its own to do generate_content and it works…

Python酷库之旅-第三方库Pandas(017)

目录 一、用法精讲 41、pandas.melt函数 41-1、语法 41-2、参数 41-3、功能 41-4、返回值 41-5、说明 41-5-1、宽格式数据(Wide Format) 41-5-2、长格式数据(Long Format) 41-6、用法 41-6-1、数据准备 41-6-2、代码示例 41-6-3、结果输出 42、pandas.pivot函数 …

【单片机毕业设计选题24059】-太阳能嵌入式智能充电系统研究

系统功能: 系统由太阳能电池板提供电源, 系统上电后显示“欢迎使用智能充电系统请稍后”, 两秒钟后进入主页面显示。 第一行显示太阳能电池板输入的电压值 第二行显示系统输出的电压值 第三行显示采集到的太阳能电池板温度 第四行显示设置的太阳能…

回归损失和分类损失

回归损失和分类损失是机器学习模型训练过程中常用的两类损失函数,分别适用于回归任务和分类任务。 回归损失函数 回归任务的目标是预测一个连续值,因此回归损失函数衡量预测值与真实值之间的差异。常见的回归损失函数有: 均方误差&#xff…

【UNI-APP】阿里NLS一句话听写typescript模块

阿里提供的demo代码都是javascript,自己捏个轮子。参考着自己写了一个阿里巴巴一句话听写Nls的typescript模块。VUE3的组合式API形式 startClient:开始听写,注意下一步要尽快开启识别和传数据,否则6秒后会关闭 startRecognition…

004-基于Sklearn的机器学习入门:回归分析(下)

本节及后续章节将介绍机器学习中的几种经典回归算法,包括线性回归,多项式回归,以及正则项的岭回归等,所选方法都在Sklearn库中聚类模块有具体实现。本节为下篇,将介绍多项式回归和岭回归等。 目录 2.3 多项式回归 2…

Point Cloud Library (PCL) for Python - pclpy 安装指南 (1)

以下所有的版本号务必按照说明安装。 1.安装 Python 3.6 https://www.python.org/ftp/python/3.6.8/python-3.6.8-amd64.exe #或 百度网盘 2.确认 Python 版本为 3.6.x python #Python 3.6.8 (tags/v3.6.8:3c6b436a57, Dec 24 2018, 00:16:47) [MSC v.1916 64 bit (AMD64)] on…

给后台写了一个优雅的自定义风格的数据日志上报页面

highlight: atelier-cave-dark 查看后台数据日志是非常常见的场景,经常看到后台的小伙伴从服务器日志复制一段json数据字符串,然后找一个JSON工具网页打开,在线JSON格式化校验。有的时候,一些业务需要展示mqtt或者socket的实时信息展示,如果不做任何修改直接展示一串字符…

学习笔记——动态路由——IS-IS中间系统到中间系统(特性之路由撤销)

6、路由撤销 ISIS路由协议的路由信息是封装在LSP报文中的TLV中的,但是它对撤销路由的处理和OSPF的处理方式类似。 在ISIS中撤销一条路由实则是将接口下的ISIS关闭: 撤销内部路由: 在ISIS中路由信息是由IP接口TLV和IP内部可达性TLV共同来描…

合宙 Air780E模块 AT 指令 MQTT连接

固件说明 重启模块 //tx ATRESET//rx ATRESETOK ^boot.romv!\n RDY^MODE: 17,17E_UTRAN ServiceCGEV: ME PDN ACT 1NITZ: 2024/07/10,08:33:440,0查询模块版本信息 //tx ATCGMR//rx ATCGMRCGMR: "AirM2M_780E_V1161_LTE_AT"OK基本流程 4G模块支持MQTT和MQTT SSl协…

顶顶通呼叫中心中间件-私有化asrproxy配置热词模型

顶顶通呼叫中心中间件-私有化asrproxy配置热词模型 1、配置热词文件 将热词存在一个txt文件中,比如:hotword.txttxt文本里面写热词,一个热词一行,用utf8编码把热词文件上传到asrproxy程序目录中,路径:/dd…

读人工智能全传10深度思维

1. 深度思维 1.1. DeepMind 1.1.1. 深度思维 1.1.2. 2014年的员工不足25人 1.1.3. 深度思维公司公开宣称其任务是解决智能问题 1.1.4. 2014年谷歌收购DeepMind,人工智能突然成了新闻热点,以及商业热点 1.1.4.1. 收购报价高达4亿英镑 1.1.4.2. 深度…

头歌资源库(26)方格填数

一、 问题描述 二、算法思想 这是一个排列组合问题。我们可以使用动态规划的思想来求解。 假设dp[i]表示填入前i个位置的数字的方案数。考虑第i个位置,它有9种填法(0~9减去前一个位置上的数字),则有dp[i] 9 * dp[i-1]。由于第…

240711_昇思学习打卡-Day23-LSTM+CRF序列标注(2)

240711_昇思学习打卡-Day23-LSTMCRF序列标注(2) 今天记录LSTMCRF序列标注的第二部分。仅作简单记录 Score计算 首先计算正确标签序列所对应的得分,这里需要注意,除了转移概率矩阵𝐏外,还需要维护两个大小…

html5——CSS基础选择器

目录 标签选择器 类选择器 id选择器 三种选择器优先级 标签指定式选择器 包含选择器 群组选择器 通配符选择器 Emmet语法&#xff08;扩展补充&#xff09; 标签选择器 HTML标签作为标签选择器的名称&#xff1a; <h1>…<h6>、<p>、<img/> 语…

如何做好漏洞扫描工作提高网络安全

在数字化浪潮席卷全球的今天&#xff0c;企业数字化转型已成为提升竞争力、实现可持续发展的关键路径。然而&#xff0c;这一转型过程并非坦途&#xff0c;其中网络安全问题如同暗礁般潜伏&#xff0c;稍有不慎便可能引发数据泄露、服务中断乃至品牌信誉受损等严重后果。因此&a…