如何在深度学习中调用CAME

1、介绍

CAME:一种以置信度为导向的策略,以减少现有内存高效优化器的不稳定性。基于此策略,我们提出CAME同时实现两个目标:传统自适应方法的快速收敛和内存高效方法的低内存使用。大量的实验证明了CAME在各种NLP任务(如BERT和GPT-2训练)中的训练稳定性和优异的性能。

2、Pytorch中调用该优化算法

(1)定义CAME

import mathimport torch
import torch.optimclass CAME(torch.optim.Optimizer):"""Implements CAME algorithm.This implementation is based on:`CAME: Confidence-guided Adaptive Memory Efficient Optimization`Args:params (iterable): iterable of parameters to optimize or dicts definingparameter groupslr (float, optional): external learning rate (default: None)eps (tuple[float, float]): regularization constants for square gradientand instability respectively (default: (1e-30, 1e-16))clip_threshold (float): threshold of root-mean-square offinal gradient update (default: 1.0)betas (tuple[float, float, float]): coefficient used for computing running averages ofupdate, square gradient and instability (default: (0.9, 0.999, 0.9999)))weight_decay (float, optional): weight decay (L2 penalty) (default: 0)"""def __init__(self,params,lr=None,eps=(1e-30, 1e-16),clip_threshold=1.0,betas=(0.9, 0.999, 0.9999),weight_decay=0.0,):assert lr > 0.assert all([0. <= beta <= 1. for beta in betas])defaults = dict(lr=lr,eps=eps,clip_threshold=clip_threshold,betas=betas,weight_decay=weight_decay,)super(CAME, self).__init__(params, defaults)@propertydef supports_memory_efficient_fp16(self):return True@propertydef supports_flat_params(self):return Falsedef _get_options(self, param_shape):factored = len(param_shape) >= 2return factoreddef _rms(self, tensor):return tensor.norm(2) / (tensor.numel() ** 0.5)def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):r_factor = ((exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1))c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()return torch.mul(r_factor, c_factor)def step(self, closure=None):"""Performs a single optimization step.Args:closure (callable, optional): A closure that reevaluates the modeland returns the loss."""loss = Noneif closure is not None:loss = closure()for group in self.param_groups:for p in group["params"]:if p.grad is None:continuegrad = p.grad.dataif grad.dtype in {torch.float16, torch.bfloat16}:grad = grad.float()if grad.is_sparse:raise RuntimeError("CAME does not support sparse gradients.")state = self.state[p]grad_shape = grad.shapefactored = self._get_options(grad_shape)# State Initializationif len(state) == 0:state["step"] = 0state["exp_avg"] = torch.zeros_like(grad)if factored:state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).type_as(grad)state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).type_as(grad)state["exp_avg_res_row"] = torch.zeros(grad_shape[:-1]).type_as(grad)state["exp_avg_res_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).type_as(grad)else:state["exp_avg_sq"] = torch.zeros_like(grad)state["RMS"] = 0state["step"] += 1state["RMS"] = self._rms(p.data)update = (grad**2) + group["eps"][0]if factored:exp_avg_sq_row = state["exp_avg_sq_row"]exp_avg_sq_col = state["exp_avg_sq_col"]exp_avg_sq_row.mul_(group["betas"][1]).add_(update.mean(dim=-1), alpha=1.0 - group["betas"][1])exp_avg_sq_col.mul_(group["betas"][1]).add_(update.mean(dim=-2), alpha=1.0 - group["betas"][1])# Approximation of exponential moving average of square of gradientupdate = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)update.mul_(grad)else:exp_avg_sq = state["exp_avg_sq"]exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1])update = exp_avg_sq.rsqrt().mul_(grad)update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))exp_avg = state["exp_avg"]exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])# Confidence-guided strategy# Calculation of instabilityres = (update - exp_avg)**2 + group["eps"][1]if factored:exp_avg_res_row = state["exp_avg_res_row"]exp_avg_res_col = state["exp_avg_res_col"]exp_avg_res_row.mul_(group["betas"][2]).add_(res.mean(dim=-1), alpha=1.0 - group["betas"][2])exp_avg_res_col.mul_(group["betas"][2]).add_(res.mean(dim=-2), alpha=1.0 - group["betas"][2])# Approximation of exponential moving average of instabilityres_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)update = res_approx.mul_(exp_avg)else:update = exp_avgif group["weight_decay"] != 0:p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])update.mul_(group["lr"])p.data.add_(-update)return loss

(2)在深度学习中调用CAME优化器

本文以使用LSTM算法对鸢尾花数据集进行分类为例,并且在代码中加入早停和十折交叉验证技术。

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score# 定义 LSTM 模型
class LSTMClassifier(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(LSTMClassifier, self).__init__()self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, num_classes)def forward(self, x):_, (hn, _) = self.lstm(x)out = self.fc(hn[-1])  # 选择最后一个 LSTM 隐层输出return out# 早停
class EarlyStopping:def __init__(self, patience=5, min_delta=0):self.patience = patienceself.min_delta = min_deltaself.best_loss = float('inf')self.counter = 0self.early_stop = Falsedef step(self, val_loss):if val_loss < self.best_loss - self.min_delta:self.best_loss = val_lossself.counter = 0else:self.counter += 1if self.counter >= self.patience:self.early_stop = True# 读取数据
iris = load_iris()
X = iris.data
y = iris.target# 标准化数据
scaler = StandardScaler()
X = scaler.fit_transform(X)# 将数据转换为 PyTorch 张量
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)# 配置模型参数
input_size = X.shape[1]  # 特征数量
hidden_size = 32
num_classes = 3
batch_size = 16
num_epochs = 100
learning_rate = 0.001
patience = 5# 进行十折交叉验证
kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
fold_idx = 0for train_index, val_index in kf.split(X, y):fold_idx += 1print(f"Fold {fold_idx}")# 划分训练集和验证集X_train, X_val = X[train_index], X[val_index]y_train, y_val = y[train_index], y[val_index]# 定义模型和优化器model = LSTMClassifier(input_size, hidden_size, num_classes)optimizer = CAME(model.parameters(), lr=2e-4, weight_decay=1e-2, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))# optimizer = optim.Adam(model.parameters(), lr=learning_rate)criterion = nn.CrossEntropyLoss()# 早停设置early_stopping = EarlyStopping(patience=patience)# 训练模型for epoch in range(num_epochs):# 训练阶段model.train()optimizer.zero_grad()outputs = model(X_train.unsqueeze(1))loss = criterion(outputs, y_train)loss.backward()optimizer.step()# 验证阶段model.eval()with torch.no_grad():val_outputs = model(X_val.unsqueeze(1))val_loss = criterion(val_outputs, y_val)# 打印每轮迭代的损失值print(f"Epoch {epoch + 1}: Train Loss = {loss.item():.4f}, Val Loss = {val_loss.item():.4f}")# 早停检查early_stopping.step(val_loss.item())if early_stopping.early_stop:print(f"Early stopping at epoch {epoch + 1}")break# 评估模型model.eval()with torch.no_grad():val_outputs = model(X_val.unsqueeze(1))_, predicted = torch.max(val_outputs, 1)accuracy = accuracy_score(y_val, predicted)print(f"Fold {fold_idx} Validation Accuracy: {accuracy:.4f}\n")

使用LSTM算法对鸢尾花数据集分类结果
由于CAME主要面向NLP数据集,因此对于鸢尾花效果不算好,本文仅展示CAME的使用方法,并非提升acc和epoch。

参考文献:Luo, Yang, et al. “CAME: Confidence-guided Adaptive Memory Efficient Optimization.” arXiv preprint arXiv:2307.02047 (2023).

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/820480.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

必应bing竞价广告推广开户联系方式?

随着互联网广告市场的日益繁荣与细分&#xff0c;必应Bing作为全球重要的搜索引擎之一&#xff0c;在国内市场也逐渐展现出强大的潜力与吸引力。越来越多的企业开始关注并探索必应Bing搜索广告所带来的巨大商机。其中&#xff0c;云衔科技以其卓越的专业素养和全面的服务体系&a…

stable diffusion--小白学习步骤

1.看一下Unet网络的讲解_哔哩哔哩_bilibili&#xff0c;了解Unet网络 2.看一下【生成式AI】Diffusion Model 原理剖析 (1/4)_哔哩哔哩_bilibili&#xff0c;起码要看前3/6个视频 3.看一下超详细的扩散模型&#xff08;Diffusion Models&#xff09;原理代码 - 知乎 (zhihu.co…

鑫鹿助贷CRM系统:助力助贷行业实现智能商业转型

数字化时代&#xff0c;商业竞争愈发激烈&#xff0c;助贷行业如何把握商机、实现高效管理、打造高回报率的商业模式&#xff0c;成为了助贷行业老板们比较关注的问题&#xff0c;而鑫鹿助贷CRM管理系统&#xff0c;正是这场商业变革中的得力助手&#xff0c;系统功能完善&…

每帧纵享丝滑——ToDesk云电脑、网易云游戏、无影云评测分析及ComfyUI部署

目录 一、前言二、云电脑性能测评分析2.1、基本配置分析2.1.1、处理器方面2.1.2、显卡方面2.1.3、内存与存储方面2.1.4、软件功能方面 2.2、综合跑分评测 三、软件应用实测分析3.1、云电竞测评3.2、AIGC科研测评——ComfyUI部署3.2.1、下载与激活工作台3.2.2、加载模型与体验3.…

AGI的智力有可能在两年内超过人类水平

特斯拉CEO埃隆马斯克近日与挪威银行投资管理基金CEO坦根的访谈中表示&#xff0c;AGI的智力将在两年内可能超过人类智力&#xff0c;在未来五年内&#xff0c;AI的能力很可能超过所有人类。 马斯克透漏&#xff0c;去年人工智能发展过程中的主要制约因素是缺少高性能芯片&#…

基于springboot实现人事管理系统项目【项目源码+论文说明】

基于springboot实现人事管理系统演示 摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;作为学校以及一些培训机构&#xff0c;都在用信息化战术来部署线上学习以及线上考试&#xff0c;可以与线下的考试有机的结合在一起&#xff0c;实现基于vue的人事系统在技术…

IP定位技术原理详细阐述

IP定位技术原理主要基于IP地址与地理位置之间的关联&#xff0c;通过一系列的技术手段&#xff0c;实现对网络设备的物理位置进行精确或大致的定位。以下是对IP定位技术原理的详细阐述。 首先&#xff0c;我们需要了解IP地址的基本概念。IP地址是互联网协议地址的简称&#xff…

大模型日报|今日必读的10篇大模型论文

大家好&#xff0c;今日必读的大模型论文来啦&#xff01; 1.谷歌推出新型 Transformer 架构&#xff1a;反馈注意力就是工作记忆 虽然 Transformer 给深度学习带来了革命性的变化&#xff0c;但二次注意复杂性阻碍了其处理无限长输入的能力。 谷歌研究团队提出了一种新型 T…

前端开发攻略---从源码角度分析Vue3的Propy比Vue2的defineproperty到底好在哪里。一篇文章让你彻底弄懂响应式原理。

1、思考 Vue的响应式到底要干什么&#xff1f; 无非就是要知道当你读取对象的时候&#xff0c;要知道它读了。要做一些别的事情无非就是要知道当你修改对象的时候&#xff0c;要知道它改了。要做一些别的事情所以要想一个办法&#xff0c;把读取和修改的动作变成一个函数&#…

xcode c++项目设置运行时参数

在 Xcode 项目中&#xff0c;你可以通过配置 scheme 来指定在运行时传递的参数。以下是在 Xcode 中设置运行时参数的步骤&#xff1a; 打开 Xcode&#xff0c;并打开你的项目。在 Xcode 菜单栏中&#xff0c;选择 "Product" -> "Scheme" -> "E…

每日一练

这题我主要用的思想是:动态规划 1.状态表示&#xff1a;以i位置为结尾的字符串是否可以用字典表示&#xff0c;然后就可以拆分成 j ~ i 为字典中的最后一个单词&#xff0c;此时 0 < j < i (1.有可能全部为字典的一个单词&#xff0c;2.有可能只有一个字母的单词)&#x…

光纤网络的星际旅行:SFP与QSFP光模块的技术演进

&#x1f389;光模块作为完成光电转换的光器件&#xff0c;在光通信网络中必不可少&#xff0c;常见的有千兆/万兆光模块、SFP/SFP/QSFP28光模块等&#xff0c;那你知道这些光模块都是如何分类的吗&#xff1f;另外还有哪些类型&#xff1f;接下来我会在本文详细介绍光模块是如…

平台工程在企业数字化转型中的战略价值

要建设成功、有弹性和面向未来的平台&#xff0c;需要做到这三点&#xff1a;了解需求、预测可能面临的挑战并制定经得起时间考验的解决方案。 了解需求是指理解利益相关者的要求和目标&#xff0c;无论他们是最终用户、开发人员还是平台生态系统中的其他相关方。这包括开展全面…

排序算法。

冒泡排序: 基本&#xff1a; private static void sort(int[] a){for (int i 0; i < a.length-1; i) {for (int j 0; j < a.length-i-1; j) {if (a[j]>a[j1]){swap(a,j,j1);}}}} private static void swap(int[] a,int i,int j){int tempa[i];a[i]a[j];a[j]temp;} …

MathType安装导致的Word粘贴操作出现运行时错误‘53’:文件未找到:MathPage.WLL

MathType安装导致的Word粘贴操作出现运行时错误‘53’&#xff1a;文件未找到&#xff1a;MathPage.WLL 解决方案 1、确定自己电脑的位数&#xff1b; 2、右击MathType桌面图标&#xff0c;点击“打开文件所在位置”&#xff0c;然后找到MathPage.WLL &#xff0c;复制一份进行…

二级综合医院云HIS系统源码,B/S架构,采用JAVA编程,集成相关医保接口

二级医院云HIS系统源码 云HIS系统是一款满足基层医院各类业务需要的健康云产品。该产品能帮助基层医院完成日常各类业务&#xff0c;提供病患预约挂号支持、病患问诊、电子病历、开药发药、会员管理、统计查询、医生工作站和护士工作站等一系列常规功能&#xff0c;还能与公卫…

【御控物联】Java JSON结构转换(3):对象To对象——多层属性重组

文章目录 一、JSON结构转换是什么&#xff1f;二、案例之《JSON对象 To JSON对象》三、代码实现四、在线转换工具五、技术资料 一、JSON结构转换是什么&#xff1f; JSON结构转换指的是将一个JSON对象或JSON数组按照一定规则进行重组、筛选、映射或转换&#xff0c;生成新的JS…

【电控笔记6.2】拉式转换与转移函数

概要 laplace&#xff1a;单输入单输出&#xff0c;线性系统 laplace 传递函数 总结

Spring GA、PRE、SNAPSHOT 版本含义及区别

GA:General Availability: 正式发布的版本&#xff0c;推荐使用&#xff08;主要是稳定&#xff09;&#xff0c;与maven的releases类似&#xff1b; PRE: 预览版,内部测试版。主要是给开发人员和测试人员测试和找BUG用的&#xff0c;不建议使用&#xff1b; SNAPSHOT: 快照…

算法学习系列(四十九):最长上升子序列模型(一)

目录 引言一、最长上升子序列二、怪盗基德的滑翔翼三、登山四、合唱队形五、友好城市六、最大上升子序列和 引言 今天学习的是最长上升子序列模型&#xff0c;这种模型我觉得就是我之前说过的就是相当于记忆的过程&#xff0c;记住遇到这种题是用这种模型&#xff0c;下次遇见…