如何在深度学习中调用CAME

1、介绍

CAME:一种以置信度为导向的策略,以减少现有内存高效优化器的不稳定性。基于此策略,我们提出CAME同时实现两个目标:传统自适应方法的快速收敛和内存高效方法的低内存使用。大量的实验证明了CAME在各种NLP任务(如BERT和GPT-2训练)中的训练稳定性和优异的性能。

2、Pytorch中调用该优化算法

(1)定义CAME

import mathimport torch
import torch.optimclass CAME(torch.optim.Optimizer):"""Implements CAME algorithm.This implementation is based on:`CAME: Confidence-guided Adaptive Memory Efficient Optimization`Args:params (iterable): iterable of parameters to optimize or dicts definingparameter groupslr (float, optional): external learning rate (default: None)eps (tuple[float, float]): regularization constants for square gradientand instability respectively (default: (1e-30, 1e-16))clip_threshold (float): threshold of root-mean-square offinal gradient update (default: 1.0)betas (tuple[float, float, float]): coefficient used for computing running averages ofupdate, square gradient and instability (default: (0.9, 0.999, 0.9999)))weight_decay (float, optional): weight decay (L2 penalty) (default: 0)"""def __init__(self,params,lr=None,eps=(1e-30, 1e-16),clip_threshold=1.0,betas=(0.9, 0.999, 0.9999),weight_decay=0.0,):assert lr > 0.assert all([0. <= beta <= 1. for beta in betas])defaults = dict(lr=lr,eps=eps,clip_threshold=clip_threshold,betas=betas,weight_decay=weight_decay,)super(CAME, self).__init__(params, defaults)@propertydef supports_memory_efficient_fp16(self):return True@propertydef supports_flat_params(self):return Falsedef _get_options(self, param_shape):factored = len(param_shape) >= 2return factoreddef _rms(self, tensor):return tensor.norm(2) / (tensor.numel() ** 0.5)def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):r_factor = ((exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1))c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()return torch.mul(r_factor, c_factor)def step(self, closure=None):"""Performs a single optimization step.Args:closure (callable, optional): A closure that reevaluates the modeland returns the loss."""loss = Noneif closure is not None:loss = closure()for group in self.param_groups:for p in group["params"]:if p.grad is None:continuegrad = p.grad.dataif grad.dtype in {torch.float16, torch.bfloat16}:grad = grad.float()if grad.is_sparse:raise RuntimeError("CAME does not support sparse gradients.")state = self.state[p]grad_shape = grad.shapefactored = self._get_options(grad_shape)# State Initializationif len(state) == 0:state["step"] = 0state["exp_avg"] = torch.zeros_like(grad)if factored:state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).type_as(grad)state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).type_as(grad)state["exp_avg_res_row"] = torch.zeros(grad_shape[:-1]).type_as(grad)state["exp_avg_res_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).type_as(grad)else:state["exp_avg_sq"] = torch.zeros_like(grad)state["RMS"] = 0state["step"] += 1state["RMS"] = self._rms(p.data)update = (grad**2) + group["eps"][0]if factored:exp_avg_sq_row = state["exp_avg_sq_row"]exp_avg_sq_col = state["exp_avg_sq_col"]exp_avg_sq_row.mul_(group["betas"][1]).add_(update.mean(dim=-1), alpha=1.0 - group["betas"][1])exp_avg_sq_col.mul_(group["betas"][1]).add_(update.mean(dim=-2), alpha=1.0 - group["betas"][1])# Approximation of exponential moving average of square of gradientupdate = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)update.mul_(grad)else:exp_avg_sq = state["exp_avg_sq"]exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1])update = exp_avg_sq.rsqrt().mul_(grad)update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))exp_avg = state["exp_avg"]exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])# Confidence-guided strategy# Calculation of instabilityres = (update - exp_avg)**2 + group["eps"][1]if factored:exp_avg_res_row = state["exp_avg_res_row"]exp_avg_res_col = state["exp_avg_res_col"]exp_avg_res_row.mul_(group["betas"][2]).add_(res.mean(dim=-1), alpha=1.0 - group["betas"][2])exp_avg_res_col.mul_(group["betas"][2]).add_(res.mean(dim=-2), alpha=1.0 - group["betas"][2])# Approximation of exponential moving average of instabilityres_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)update = res_approx.mul_(exp_avg)else:update = exp_avgif group["weight_decay"] != 0:p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])update.mul_(group["lr"])p.data.add_(-update)return loss

(2)在深度学习中调用CAME优化器

本文以使用LSTM算法对鸢尾花数据集进行分类为例,并且在代码中加入早停和十折交叉验证技术。

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score# 定义 LSTM 模型
class LSTMClassifier(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(LSTMClassifier, self).__init__()self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, num_classes)def forward(self, x):_, (hn, _) = self.lstm(x)out = self.fc(hn[-1])  # 选择最后一个 LSTM 隐层输出return out# 早停
class EarlyStopping:def __init__(self, patience=5, min_delta=0):self.patience = patienceself.min_delta = min_deltaself.best_loss = float('inf')self.counter = 0self.early_stop = Falsedef step(self, val_loss):if val_loss < self.best_loss - self.min_delta:self.best_loss = val_lossself.counter = 0else:self.counter += 1if self.counter >= self.patience:self.early_stop = True# 读取数据
iris = load_iris()
X = iris.data
y = iris.target# 标准化数据
scaler = StandardScaler()
X = scaler.fit_transform(X)# 将数据转换为 PyTorch 张量
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)# 配置模型参数
input_size = X.shape[1]  # 特征数量
hidden_size = 32
num_classes = 3
batch_size = 16
num_epochs = 100
learning_rate = 0.001
patience = 5# 进行十折交叉验证
kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
fold_idx = 0for train_index, val_index in kf.split(X, y):fold_idx += 1print(f"Fold {fold_idx}")# 划分训练集和验证集X_train, X_val = X[train_index], X[val_index]y_train, y_val = y[train_index], y[val_index]# 定义模型和优化器model = LSTMClassifier(input_size, hidden_size, num_classes)optimizer = CAME(model.parameters(), lr=2e-4, weight_decay=1e-2, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))# optimizer = optim.Adam(model.parameters(), lr=learning_rate)criterion = nn.CrossEntropyLoss()# 早停设置early_stopping = EarlyStopping(patience=patience)# 训练模型for epoch in range(num_epochs):# 训练阶段model.train()optimizer.zero_grad()outputs = model(X_train.unsqueeze(1))loss = criterion(outputs, y_train)loss.backward()optimizer.step()# 验证阶段model.eval()with torch.no_grad():val_outputs = model(X_val.unsqueeze(1))val_loss = criterion(val_outputs, y_val)# 打印每轮迭代的损失值print(f"Epoch {epoch + 1}: Train Loss = {loss.item():.4f}, Val Loss = {val_loss.item():.4f}")# 早停检查early_stopping.step(val_loss.item())if early_stopping.early_stop:print(f"Early stopping at epoch {epoch + 1}")break# 评估模型model.eval()with torch.no_grad():val_outputs = model(X_val.unsqueeze(1))_, predicted = torch.max(val_outputs, 1)accuracy = accuracy_score(y_val, predicted)print(f"Fold {fold_idx} Validation Accuracy: {accuracy:.4f}\n")

使用LSTM算法对鸢尾花数据集分类结果
由于CAME主要面向NLP数据集,因此对于鸢尾花效果不算好,本文仅展示CAME的使用方法,并非提升acc和epoch。

参考文献:Luo, Yang, et al. “CAME: Confidence-guided Adaptive Memory Efficient Optimization.” arXiv preprint arXiv:2307.02047 (2023).

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/820480.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android Jetpack 组件

1、ViewModel 用于将数据与Activity分离&#xff0c;这样在Activity声明周期中&#xff0c;数据不会丢失。 &#xff08;1&#xff09;简单使用 implementation ("androidx.lifecycle:lifecycle-extensions:2.2.0") // 使用ViewModel组件需要额外添加<LinearLay…

必应bing竞价广告推广开户联系方式?

随着互联网广告市场的日益繁荣与细分&#xff0c;必应Bing作为全球重要的搜索引擎之一&#xff0c;在国内市场也逐渐展现出强大的潜力与吸引力。越来越多的企业开始关注并探索必应Bing搜索广告所带来的巨大商机。其中&#xff0c;云衔科技以其卓越的专业素养和全面的服务体系&a…

stable diffusion--小白学习步骤

1.看一下Unet网络的讲解_哔哩哔哩_bilibili&#xff0c;了解Unet网络 2.看一下【生成式AI】Diffusion Model 原理剖析 (1/4)_哔哩哔哩_bilibili&#xff0c;起码要看前3/6个视频 3.看一下超详细的扩散模型&#xff08;Diffusion Models&#xff09;原理代码 - 知乎 (zhihu.co…

鑫鹿助贷CRM系统:助力助贷行业实现智能商业转型

数字化时代&#xff0c;商业竞争愈发激烈&#xff0c;助贷行业如何把握商机、实现高效管理、打造高回报率的商业模式&#xff0c;成为了助贷行业老板们比较关注的问题&#xff0c;而鑫鹿助贷CRM管理系统&#xff0c;正是这场商业变革中的得力助手&#xff0c;系统功能完善&…

途游游戏,科锐国际(计算机类),快手,得物,蓝禾,奇安信,顺丰,康冠科技,金证科技24春招内推

途游游戏&#xff0c;科锐国际&#xff08;计算机类&#xff09;&#xff0c;快手&#xff0c;得物&#xff0c;蓝禾&#xff0c;奇安信&#xff0c;顺丰&#xff0c;康冠科技&#xff0c;金证科技24春招内推 ①得物 【岗位】技术&#xff0c;设计&#xff0c;供应链&#xff0…

每帧纵享丝滑——ToDesk云电脑、网易云游戏、无影云评测分析及ComfyUI部署

目录 一、前言二、云电脑性能测评分析2.1、基本配置分析2.1.1、处理器方面2.1.2、显卡方面2.1.3、内存与存储方面2.1.4、软件功能方面 2.2、综合跑分评测 三、软件应用实测分析3.1、云电竞测评3.2、AIGC科研测评——ComfyUI部署3.2.1、下载与激活工作台3.2.2、加载模型与体验3.…

vba学习系列(4)-- index()提取指定单元格并保留字体格式

系列文章目录 文章目录 系列文章目录一、目标需求二、使用步骤1.VBA程序2.VBA简要程序 总结 一、目标需求 工作表2 B列中姓名&#xff0c;在工作表1 C列中存在相同姓名时&#xff0c;提取工作表2 AK列的对应单元格内容&#xff1b; 工作表2名称&#xff1a;OQC 工作表1名称&…

AGI的智力有可能在两年内超过人类水平

特斯拉CEO埃隆马斯克近日与挪威银行投资管理基金CEO坦根的访谈中表示&#xff0c;AGI的智力将在两年内可能超过人类智力&#xff0c;在未来五年内&#xff0c;AI的能力很可能超过所有人类。 马斯克透漏&#xff0c;去年人工智能发展过程中的主要制约因素是缺少高性能芯片&#…

基于springboot实现人事管理系统项目【项目源码+论文说明】

基于springboot实现人事管理系统演示 摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;作为学校以及一些培训机构&#xff0c;都在用信息化战术来部署线上学习以及线上考试&#xff0c;可以与线下的考试有机的结合在一起&#xff0c;实现基于vue的人事系统在技术…

c++ 栈溢出问题

示例代码: #include <iostream> #include <chrono> #include <fstream>int main() {// 测量内存操作的执行时间int num = 1024 * 1024;int arry[num] = {2};int arry_tmp[num] = {0};std::ofstream outfile("data.bin", std::ios::binary | std:…

【LeetCode热题100】【二分查找】搜索二维矩阵

题目链接&#xff1a;74. 搜索二维矩阵 - 力扣&#xff08;LeetCode&#xff09; 在一个有序二维数组里面查找元素&#xff0c;同【LeetCode热题100】【矩阵】搜索二维矩阵 II-CSDN博客 如果用二分查找&#xff0c;时间复杂度是log(mn)&#xff0c;但是可以实现时间复杂度为O…

交通大模型与时序大模型整理【共15篇工作】【附开源代码】

随着城市化进程的加速和交通系统的不断发展&#xff0c;对交通数据和时序数据的整理与分析变得尤为重要。本文旨在探讨交通大模型与时序大模型的整理及其在城市规划、交通管理等领域的应用。交通大模型涉及交通流量、道路网络、交通规划等方面的数据&#xff0c;而时序大模型则…

代码随想录-算法训练营day10【栈与队列01:理论基础、用栈实现队列、用队列实现栈】

代码随想录-035期-算法训练营【博客笔记汇总表】-CSDN博客 第五章 栈与队列part01 ● day 1 任务以及具体安排&#xff1a;https://docs.qq.com/doc/DUG9UR2ZUc3BjRUdY ● day 2 任务以及具体安排&#xff1a;https://docs.qq.com/doc/DUGRwWXNOVEpyaVpG ● day 3 任务以及…

ZooKeeper临时有序节点生成过程以及序号超过最大值的处理思路

目录 ZooKeeper临时有序节点生成过程 ZooKeeper序号超过最大值的处理 ZooKeeper临时有序节点生成过程 创建节点时指定类型 当客户端向ZooKeeper请求创建节点时&#xff0c;需要指定节点类型。对于临时有序节点&#xff0c;应使用CreateMode.EPHEMERAL_SEQUENTIAL标志。这告诉…

IP定位技术原理详细阐述

IP定位技术原理主要基于IP地址与地理位置之间的关联&#xff0c;通过一系列的技术手段&#xff0c;实现对网络设备的物理位置进行精确或大致的定位。以下是对IP定位技术原理的详细阐述。 首先&#xff0c;我们需要了解IP地址的基本概念。IP地址是互联网协议地址的简称&#xff…

大模型日报|今日必读的10篇大模型论文

大家好&#xff0c;今日必读的大模型论文来啦&#xff01; 1.谷歌推出新型 Transformer 架构&#xff1a;反馈注意力就是工作记忆 虽然 Transformer 给深度学习带来了革命性的变化&#xff0c;但二次注意复杂性阻碍了其处理无限长输入的能力。 谷歌研究团队提出了一种新型 T…

前端开发攻略---从源码角度分析Vue3的Propy比Vue2的defineproperty到底好在哪里。一篇文章让你彻底弄懂响应式原理。

1、思考 Vue的响应式到底要干什么&#xff1f; 无非就是要知道当你读取对象的时候&#xff0c;要知道它读了。要做一些别的事情无非就是要知道当你修改对象的时候&#xff0c;要知道它改了。要做一些别的事情所以要想一个办法&#xff0c;把读取和修改的动作变成一个函数&#…

xcode c++项目设置运行时参数

在 Xcode 项目中&#xff0c;你可以通过配置 scheme 来指定在运行时传递的参数。以下是在 Xcode 中设置运行时参数的步骤&#xff1a; 打开 Xcode&#xff0c;并打开你的项目。在 Xcode 菜单栏中&#xff0c;选择 "Product" -> "Scheme" -> "E…

前端实现下载的2种方法(个人总结)

一.后端在接口指明了下载的类型是blob类型 要实现下载项目数据并成为excel格式的 以这个接口为例: export const conversationDown () > {return http({url: /conversation/down,method: GET,responseType: blob}) } const handleDownload async () > {const res …

每日一练

这题我主要用的思想是:动态规划 1.状态表示&#xff1a;以i位置为结尾的字符串是否可以用字典表示&#xff0c;然后就可以拆分成 j ~ i 为字典中的最后一个单词&#xff0c;此时 0 < j < i (1.有可能全部为字典的一个单词&#xff0c;2.有可能只有一个字母的单词)&#x…