介绍 C++ 中的智能指针及其应用:以 PyTorch框架自动梯度AutogradMeta为例

介绍 C++ 中的智能指针及其应用:以 AutogradMeta 为例

在 C++ 中,智能指针(Smart Pointer)是用于管理动态分配内存的一种工具。它们不仅自动管理内存的生命周期,还能帮助避免内存泄漏和野指针等问题。在深度学习框架如 PyTorch 的实现中,智能指针被广泛应用于复杂的数据结构和计算图的管理中。本文将结合 AutogradMeta 类,详细介绍 C++ 中的智能指针,解释 std::shared_ptrstd::weak_ptrstd::unique_ptr 等智能指针的使用场景及区别。

Source: https://github.com/pytorch/pytorch/blob/00df63f09f07546bacec734f37132edc58ccf574/torch/csrc/autograd/variable.h#L102

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//                            AutogradMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~/// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
/// metadata fields that are necessary for tracking the Variable's autograd
/// history. As an optimization, a Variable may store a nullptr, in lieu of a
/// default constructed AutogradMeta.struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {std::string name_;Variable grad_;std::shared_ptr<Node> grad_fn_;std::weak_ptr<Node> grad_accumulator_;// This field is used to store all the forward AD gradients// associated with this AutogradMeta (and the Tensor it corresponds to)// There is a semantic 1:1 correspondence between AutogradMeta and// ForwardGrad but://   - This field is lazily populated.//   - This field is a shared_ptr but it must never be//     shared by multiple Tensors. See Note [ Using ForwardGrad ]// Any transition from not_initialized to initialized// must be protected by mutex_mutable std::shared_ptr<ForwardGrad> fw_grad_;// The hooks_ field is actually reused by both python and cpp logic// For both cases, we have a data structure, cpp_hooks_list_ (cpp)// or dict (python) which is the canonical copy.// Then, for both cases, we always register a single hook to// hooks_ which wraps all the hooks in the list/dict.// And, again in both cases, if the grad_fn exists on that tensor// we will additionally register a single hook to the grad_fn.//// Note that the cpp and python use cases aren't actually aware of// each other, so using both is not defined behavior.std::vector<std::unique_ptr<FunctionPreHook>> hooks_;std::shared_ptr<hooks_list> cpp_hooks_list_;// The post_acc_grad_hooks_ field stores only Python hooks// (PyFunctionTensorPostAccGradHooks) that are called after the// .grad field has been accumulated into. This is less complicated// than the hooks_ field, which encapsulates a lot more.std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;// Only meaningful on leaf variables (must be false otherwise)bool requires_grad_{false};// Only meaningful on non-leaf variables (must be false otherwise)bool retains_grad_{false};bool is_view_{false};// The "output number" of this variable; e.g., if this variable// was the second output of a function, then output_nr == 1.// We use this to make sure we can setup the backwards trace// correctly when this variable is passed to another function.uint32_t output_nr_;// Mutex to ensure that concurrent read operations that modify internal// state are still thread-safe. Used by grad_fn(), grad_accumulator(),// fw_grad() and set_fw_grad()// This is mutable because we need to be able to acquire this from const// version of this class for the functions abovemutable std::mutex mutex_;/// Sets the `requires_grad` property of `Variable`. This should be true for/// leaf variables that want to accumulate gradients, and false for all other/// variables.void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) final {TORCH_CHECK(!requires_grad ||isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())),"Only Tensors of floating point and complex dtype can require gradients");requires_grad_ = requires_grad;}bool requires_grad() const override {return requires_grad_ || grad_fn_;}/// Accesses the gradient `Variable` of this `Variable`.Variable& mutable_grad() override {return grad_;}const Variable& grad() const override {return grad_;}const Variable& fw_grad(uint64_t level, const at::TensorBase& self)const override;void set_fw_grad(const at::TensorBase& new_grad,const at::TensorBase& self,uint64_t level,bool is_inplace_op) override;AutogradMeta(at::TensorImpl* self_impl = nullptr,bool requires_grad = false,Edge gradient_edge = Edge()): grad_fn_(std::move(gradient_edge.function)),output_nr_(gradient_edge.input_nr) {// set_requires_grad also checks error conditions.if (requires_grad) {TORCH_INTERNAL_ASSERT(self_impl);set_requires_grad(requires_grad, self_impl);}TORCH_CHECK(!grad_fn_ || !requires_grad_,"requires_grad should be false if grad_fn is set");}~AutogradMeta() override {// If AutogradMeta is being destroyed, it means that there is no other// reference to its corresponding Tensor. It implies that no other thread// can be using this object and so there is no need to lock mutex_ here to// guard the check if fw_grad_ is populated.if (fw_grad_) {// See note [ Using ForwardGrad ]fw_grad_->clear();}}
};
1. AutogradMeta 类中的智能指针

AutogradMeta 是一个用于存储与自动求导(Autograd)相关元数据的数据结构。它包含了多种智能指针,例如:

  • std::shared_ptr<Node> grad_fn_;
  • std::weak_ptr<Node> grad_accumulator_;
  • mutable std::shared_ptr<ForwardGrad> fw_grad_;
  • std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_;

这些智能指针的应用各有不同,它们的主要作用是管理计算图中的节点、梯度、钩子函数等数据结构的生命周期。

2. std::shared_ptrstd::weak_ptr 的区别

首先,让我们从 std::shared_ptr<Node>std::weak_ptr<Node> 开始讲解。

  • std::shared_ptr<Node> grad_fn_; 是一个共享指针,表示该对象(grad_fn_)的所有者。一个 std::shared_ptr 会通过引用计数来管理对象的生命周期。当一个 shared_ptr 被复制时,引用计数会增加,而当指针超出作用域或被重置时,引用计数会减少,直到计数为 0 时对象会被销毁。

    示例代码:

    std::shared_ptr<Node> grad_fn = std::make_shared<Node>();
    // 在此,grad_fn 是 Node 类型对象的所有者
    

    应用场景:

    • std::shared_ptr 适用于需要共享资源所有权的场景。比如,在 AutogradMeta 中,grad_fn_ 指向的是梯度计算的计算图节点,该节点可能会被多个 Variable 共享,因此使用 std::shared_ptr 可以确保计算图在不再使用时被自动销毁。
  • std::weak_ptr<Node> grad_accumulator_; 是一个弱指针,通常与 std::shared_ptr 配合使用。它不会影响对象的引用计数,因此不会阻止对象的销毁。std::weak_ptr 适用于观察共享资源但不拥有其所有权的场景。

    示例代码:

    std::shared_ptr<Node> shared_ptr_node = std::make_shared<Node>();
    std::weak_ptr<Node> weak_ptr_node = shared_ptr_node;
    

    应用场景:

    • std::weak_ptr 常用于防止循环引用。在 AutogradMeta 中,grad_accumulator_ 可能指向一个梯度累加器对象,但我们并不想让它拥有该对象的所有权,因此使用 std::weak_ptr。这样,当没有任何 shared_ptr 指向该对象时,累加器会被销毁,避免内存泄漏。
3. mutable std::shared_ptr<ForwardGrad> fw_grad_;

mutable 关键字在这里的作用是允许即使在 const 对象上也能修改 fw_grad_ 成员变量。在 AutogradMeta 中,fw_grad_ 用于存储与正向自动求导相关的梯度。由于该对象的生命周期是动态管理的,所以它使用了 std::shared_ptr

示例代码:

mutable std::shared_ptr<ForwardGrad> fw_grad_;

应用场景:

  • AutogradMeta 类中,fw_grad_ 可能在对象生命周期内多次更新,因此需要一个 std::shared_ptr 来管理其内存。同时,mutable 允许即使 AutogradMeta 对象是 const 类型时,也可以修改 fw_grad_,这对线程安全和优化非常重要。
4. std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_;

std::unique_ptr 是独占指针,表示某个资源只能由一个指针管理。当 std::unique_ptr 被销毁时,它所管理的资源会被释放。

示例代码:

std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;

应用场景:

  • AutogradMeta 中,post_acc_grad_hooks_ 用于存储 Python 特定的钩子函数。这些钩子函数会在梯度累加后执行,因此使用 std::unique_ptr 确保钩子对象的独占管理,避免多个指针同时拥有该对象的所有权。
5. std::shared_ptrstd::weak_ptrstd::unique_ptr 对比
智能指针类型主要特点使用场景
std::shared_ptr共享所有权,通过引用计数管理资源的生命周期;多个指针可以共享资源用于共享资源所有权,确保资源在最后一个指针被销毁时被释放。
std::weak_ptr不增加资源的引用计数,不控制资源的生命周期;可以观察资源用于避免循环引用和观察对象的生命周期。
std::unique_ptr独占所有权,确保资源只被一个指针管理用于资源的独占管理,确保资源在超出作用域时被释放。
6. 总结与应用场景

在 C++ 中,智能指针是非常强大的工具,可以有效避免内存泄漏、野指针和循环引用等问题。std::shared_ptrstd::weak_ptrstd::unique_ptr 各有各的特点,能够应对不同的资源管理需求。结合 AutogradMeta 这样的复杂数据结构,智能指针帮助我们确保计算图、梯度和钩子等资源的安全管理。

  • std::shared_ptr 适用于需要共享资源所有权的场景,如计算图的节点。
  • std::weak_ptr 适用于观察资源但不控制其生命周期的场景,如梯度累加器。
  • std::unique_ptr 适用于独占资源所有权的场景,如梯度累加后的钩子函数。

通过合理选择智能指针类型,能够显著提升代码的安全性和可维护性,减少内存管理上的错误。


以上就是对 C++ 中智能指针的详细介绍及其在 AutogradMeta 类中的应用。希望通过这个例子,读者能够更加清晰地理解智能指针的区别及其适用场景。

附录:override关键字

override 关键字详解

在 C++ 中,override 是一个用于显式声明虚函数重写的关键字。它告诉编译器,当前成员函数是用来重写基类中的虚函数的。如果基类中没有对应的虚函数,编译器将生成一个错误,从而帮助开发者捕获潜在的错误。

语法和作用

在类的成员函数后面加上 override,表示该函数是重写了基类中的一个虚函数。如果基类中没有定义该函数,或者该函数的签名不匹配,编译器将报错。

语法示例:

class Base {
public:virtual void foo() {// 基类的实现}
};class Derived : public Base {
public:void foo() override { // 重写基类的 foo 函数// 派生类的实现}
};

在上面的代码中,Derived 类中的 foo 函数用 override 关键字显式地声明为重写基类 Base 中的 foo 函数。如果基类的 foo 函数没有定义为虚函数,或者派生类中的 foo 函数签名与基类的不一致,编译器会给出错误。

为什么要使用 override
  1. 避免拼写错误和签名错误: 使用 override 可以帮助程序员确保派生类中的函数签名完全匹配基类中的虚函数签名。如果有拼写错误或签名不匹配,编译器会在编译时提醒我们,避免在运行时遇到潜在的问题。

  2. 提高代码可读性: override 显示了一个函数是基类虚函数的重写,有助于代码阅读者理解该函数是被派生类特意重写的,而不是无意间添加的。

  3. 增强可维护性: 如果将来基类的虚函数发生了修改,override 可以帮助发现派生类中需要更新的地方,从而避免一些潜在的bug。

overrideset_fw_grad 中的应用

在您提供的代码片段中:

void set_fw_grad(const at::TensorBase& new_grad,const at::TensorBase& self,uint64_t level,bool is_inplace_op) override;

override 表示 set_fw_grad 函数重写了基类中的一个虚函数。这个函数的作用可能是设置正向梯度(forward gradient)。如果基类中没有定义 set_fw_grad 或其签名不同,编译器会报错,提醒开发者检查是否正确实现了虚函数。

示例:虚函数重写的完整示例
#include <iostream>class Base {
public:// 声明一个虚函数virtual void set_fw_grad(const std::string& new_grad) {std::cout << "Base class set_fw_grad: " << new_grad << std::endl;}
};class Derived : public Base {
public:// 重写基类的虚函数,并加上 override 关键字void set_fw_grad(const std::string& new_grad) override {std::cout << "Derived class set_fw_grad: " << new_grad << std::endl;}
};int main() {Base* obj = new Derived();obj->set_fw_grad("Gradient Data"); // 调用的是 Derived 类的重写函数delete obj;return 0;
}

输出:

Derived class set_fw_grad: Gradient Data
总结
  • override 是 C++11 引入的一个关键字,用于显式标识派生类中的成员函数重写了基类中的虚函数。
  • 它增强了代码的安全性,帮助避免常见的编程错误,如函数签名不匹配等问题。
  • set_fw_grad 中,override 确保该函数是正确地重写了基类的虚函数。如果基类中没有定义相应的虚函数,编译器会发出错误提示。

后记

2025年1月3日15点33分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/64915.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python +t kinter绘制彩虹和云朵

python t kinter绘制彩虹和云朵 彩虹&#xff0c;简称虹&#xff0c;是气象中的一种光学现象&#xff0c;当太阳光照射到半空中的水滴&#xff0c;光线被折射及反射&#xff0c;在天空上形成拱形的七彩光谱&#xff0c;由外圈至内圈呈红、橙、黄、绿、蓝、靛、紫七种颜色。事实…

Zabbix5.0版本(监控Nginx+PHP服务状态信息)

目录 1.监控Nginx服务状态信息 &#xff08;1&#xff09;通过Nginx监控模块&#xff0c;监控Nginx的7种状态 &#xff08;2&#xff09;开启Nginx状态模块 &#xff08;3&#xff09;配置监控项 &#xff08;4&#xff09;创建模板 &#xff08;5&#xff09;用默认键值…

Python入门教程 —— 字符串

字符串介绍 字符串可以理解为一段普通的文本内容,在python里,使用引号来表示一个字符串,不同的引号表示的效果会有区别。 字符串表示方式 a = "Im Tom" # 一对双引号 b = Tom said:"I am Tom" # 一对单引号c = Tom said:"I\m Tom" # 转义…

AcWing练习题:差

读取四个整数 A,B,C,D&#xff0c;并计算 (AB−CD)的值。 输入格式 输入共四行&#xff0c;第一行包含整数 A&#xff0c;第二行包含整数 B&#xff0c;第三行包含整数 C&#xff0c;第四行包含整数 D。 输出格式 输出格式为 DIFERENCA X&#xff0c;其中 X 为 (AB−CD) 的…

小程序添加购物车业务逻辑

数据库设计 DTO设计 实现步骤 1 判断当前加入购物车中的的商品是否已经存在了 2 如果已经存在 只需要将数量加一 3 如果不存在 插入一条购物车数据 4 判断加到本次购物车的是菜品还是套餐 Impl代码实现 Service public class ShoppingCartServiceImpl implements Shoppin…

如何在谷歌浏览器中使用自定义搜索快捷方式

在数字时代&#xff0c;浏览器已经成为我们日常生活中不可或缺的一部分。作为最常用的浏览器之一&#xff0c;谷歌浏览器凭借其简洁的界面和强大的功能深受用户喜爱。本文将详细介绍如何自定义谷歌浏览器的快捷工具栏&#xff0c;帮助你更高效地使用这一工具。 一、如何找到谷歌…

Python 3 与 Python 2 的主要区别

文章目录 1. 语法与关键字print 函数整数除法 2. 字符串处理默认字符串类型字符串格式化 3. 输入函数4. 迭代器和生成器range 函数map, filter, zip 5. 标准库变化urllib 模块configparser 模块 6. 异常处理7. 移除的功能8. 其他重要改进数据库操作多线程与并发类型注解 9. 总结…

关于IDE的相关知识之二【插件推荐】

成长路上不孤单&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a; 【14后&#x1f60a;///计算机爱好者&#x1f60a;///持续分享所学&#x1f60a;///如有需要欢迎收藏转发///&#x1f60a;】 今日分享关于ide插件推荐的相关内容&#xff01…

基于微信小程序的校园点餐平台的设计与实现(源码+SQL+LW+部署讲解)

文章目录 摘 要1. 第1章 选题背景及研究意义1.1 选题背景1.2 研究意义1.3 论文结构安排 2. 第2章 相关开发技术2.1 前端技术2.2 后端技术2.3 数据库技术 3. 第3章 可行性及需求分析3.1 可行性分析3.2 系统需求分析 4. 第4章 系统概要设计4.1 系统功能模块设计4.2 数据库设计 5.…

spring中使用@Validated,什么是JSR 303数据校验,spring boot中怎么使用数据校验

文章目录 一、JSR 303后台数据校验1.1 什么是 JSR303&#xff1f;1.2 为什么使用 JSR 303&#xff1f; 二、Spring Boot 中使用数据校验2.1 基本注解校验2.1.1 使用步骤2.1.2 举例Valid注解全局统一异常处理 2.2 分组校验2.2.1 使用步骤2.2.2 举例Validated注解Validated和Vali…

应用架构模式-总体思路

采用引导式设计方法&#xff1a;以企业级架构为指导&#xff0c;形成较为齐全的规范指引。在实践中总结重要设计形成决策要点&#xff0c;一个决策要点对应一个设计模式。自底向上总结采用该设计模式的必备条件&#xff0c;将之转化通过简单需求分析就能得到的业务特点&#xf…

【数据结构】双向循环链表的使用

双向循环链表的使用 1.双向循环链表节点设计2.初始化双向循环链表-->定义结构体变量 创建头节点&#xff08;1&#xff09;示例代码&#xff1a;&#xff08;2&#xff09;图示 3.双向循环链表节点头插&#xff08;1&#xff09;示例代码&#xff1a;&#xff08;2&#xff…

【Java设计模式-3】门面模式——简化复杂系统的魔法

在软件开发的世界里&#xff0c;我们常常会遇到复杂的系统&#xff0c;这些系统由多个子系统或模块组成&#xff0c;各个部分之间的交互错综复杂。如果直接让外部系统与这些复杂的子系统进行交互&#xff0c;不仅会让外部系统的代码变得复杂难懂&#xff0c;还会增加系统之间的…

Linux一些问题

修改YUM源 Centos7将yum源更换为国内源保姆级教程_centos使用中科大源-CSDN博客 直接安装包&#xff0c;走链接也行 Index of /7.9.2009/os/x86_64/Packages 直接复制里面的安装包链接&#xff0c;在命令行直接 yum install https://vault.centos.org/7.9.2009/os/x86_64/Pa…

HTML——57. type和name属性

<!DOCTYPE html> <html><head><meta charset"UTF-8"><title>type和name属性</title></head><body><!--1.input元素是最常用的表单控件--><!--2.input元素不仅可以在form标签内使用也可以在form标签外使用-…

uniapp本地加载腾讯X5浏览器内核插件

概述 TbsX5webviewUTS插件封装腾讯x5webview离线内核加载模块&#xff0c;可以把uniapp的浏览器内核直接替换成Android X5 Webview(腾讯TBS)最新内核&#xff0c;提高交互体验和流畅度。 功能说明 下载SDK插件 1.集成x5内核后哪些页面会由x5内核渲染&#xff1f; 所有plus…

设计模式 创建型 单例模式(Singleton Pattern)与 常见技术框架应用 解析

单例模式&#xff08;Singleton Pattern&#xff09;是一种创建型设计模式&#xff0c;旨在确保某个类在应用程序的生命周期内只有一个实例&#xff0c;并提供一个全局访问点来获取该实例。这种设计模式在需要控制资源访问、避免频繁创建和销毁对象的场景中尤为有用。 一、核心…

您的公司需要小型语言模型

当专用模型超越通用模型时 “越大越好”——这个原则在人工智能领域根深蒂固。每个月都有更大的模型诞生&#xff0c;参数越来越多。各家公司甚至为此建设价值100亿美元的AI数据中心。但这是唯一的方向吗&#xff1f; 在NeurIPS 2024大会上&#xff0c;OpenAI联合创始人伊利亚…

uniapp-vue3(下)

关联链接&#xff1a;uniapp-vue3&#xff08;上&#xff09; 文章目录 七、咸虾米壁纸项目实战7.1.咸虾米壁纸项目概述7.2.项目初始化公共目录和设计稿尺寸测量工具7.3.banner海报swiper轮播器7.4.使用swiper的纵向轮播做公告区域7.5.每日推荐滑动scroll-view布局7.6.组件具名…

使用 Python 实现随机中点位移法生成逼真的裂隙面

使用 Python 实现随机中点位移法生成逼真的裂隙面 一、随机中点位移法简介 1. 什么是随机中点位移法&#xff1f;2. 应用领域 二、 Python 代码实现 1. 导入必要的库2. 函数定义&#xff1a;随机中点位移法核心逻辑3. 设置随机数种子4. 初始化二维裂隙面5. 初始化网格的四个顶点…