介绍 C++ 中的智能指针及其应用:以 AutogradMeta
为例
在 C++ 中,智能指针(Smart Pointer)是用于管理动态分配内存的一种工具。它们不仅自动管理内存的生命周期,还能帮助避免内存泄漏和野指针等问题。在深度学习框架如 PyTorch 的实现中,智能指针被广泛应用于复杂的数据结构和计算图的管理中。本文将结合 AutogradMeta
类,详细介绍 C++ 中的智能指针,解释 std::shared_ptr
、std::weak_ptr
、std::unique_ptr
等智能指针的使用场景及区别。
Source: https://github.com/pytorch/pytorch/blob/00df63f09f07546bacec734f37132edc58ccf574/torch/csrc/autograd/variable.h#L102
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// AutogradMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~/// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
/// metadata fields that are necessary for tracking the Variable's autograd
/// history. As an optimization, a Variable may store a nullptr, in lieu of a
/// default constructed AutogradMeta.struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {std::string name_;Variable grad_;std::shared_ptr<Node> grad_fn_;std::weak_ptr<Node> grad_accumulator_;// This field is used to store all the forward AD gradients// associated with this AutogradMeta (and the Tensor it corresponds to)// There is a semantic 1:1 correspondence between AutogradMeta and// ForwardGrad but:// - This field is lazily populated.// - This field is a shared_ptr but it must never be// shared by multiple Tensors. See Note [ Using ForwardGrad ]// Any transition from not_initialized to initialized// must be protected by mutex_mutable std::shared_ptr<ForwardGrad> fw_grad_;// The hooks_ field is actually reused by both python and cpp logic// For both cases, we have a data structure, cpp_hooks_list_ (cpp)// or dict (python) which is the canonical copy.// Then, for both cases, we always register a single hook to// hooks_ which wraps all the hooks in the list/dict.// And, again in both cases, if the grad_fn exists on that tensor// we will additionally register a single hook to the grad_fn.//// Note that the cpp and python use cases aren't actually aware of// each other, so using both is not defined behavior.std::vector<std::unique_ptr<FunctionPreHook>> hooks_;std::shared_ptr<hooks_list> cpp_hooks_list_;// The post_acc_grad_hooks_ field stores only Python hooks// (PyFunctionTensorPostAccGradHooks) that are called after the// .grad field has been accumulated into. This is less complicated// than the hooks_ field, which encapsulates a lot more.std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;// Only meaningful on leaf variables (must be false otherwise)bool requires_grad_{false};// Only meaningful on non-leaf variables (must be false otherwise)bool retains_grad_{false};bool is_view_{false};// The "output number" of this variable; e.g., if this variable// was the second output of a function, then output_nr == 1.// We use this to make sure we can setup the backwards trace// correctly when this variable is passed to another function.uint32_t output_nr_;// Mutex to ensure that concurrent read operations that modify internal// state are still thread-safe. Used by grad_fn(), grad_accumulator(),// fw_grad() and set_fw_grad()// This is mutable because we need to be able to acquire this from const// version of this class for the functions abovemutable std::mutex mutex_;/// Sets the `requires_grad` property of `Variable`. This should be true for/// leaf variables that want to accumulate gradients, and false for all other/// variables.void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) final {TORCH_CHECK(!requires_grad ||isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())),"Only Tensors of floating point and complex dtype can require gradients");requires_grad_ = requires_grad;}bool requires_grad() const override {return requires_grad_ || grad_fn_;}/// Accesses the gradient `Variable` of this `Variable`.Variable& mutable_grad() override {return grad_;}const Variable& grad() const override {return grad_;}const Variable& fw_grad(uint64_t level, const at::TensorBase& self)const override;void set_fw_grad(const at::TensorBase& new_grad,const at::TensorBase& self,uint64_t level,bool is_inplace_op) override;AutogradMeta(at::TensorImpl* self_impl = nullptr,bool requires_grad = false,Edge gradient_edge = Edge()): grad_fn_(std::move(gradient_edge.function)),output_nr_(gradient_edge.input_nr) {// set_requires_grad also checks error conditions.if (requires_grad) {TORCH_INTERNAL_ASSERT(self_impl);set_requires_grad(requires_grad, self_impl);}TORCH_CHECK(!grad_fn_ || !requires_grad_,"requires_grad should be false if grad_fn is set");}~AutogradMeta() override {// If AutogradMeta is being destroyed, it means that there is no other// reference to its corresponding Tensor. It implies that no other thread// can be using this object and so there is no need to lock mutex_ here to// guard the check if fw_grad_ is populated.if (fw_grad_) {// See note [ Using ForwardGrad ]fw_grad_->clear();}}
};
1. AutogradMeta
类中的智能指针
AutogradMeta
是一个用于存储与自动求导(Autograd)相关元数据的数据结构。它包含了多种智能指针,例如:
std::shared_ptr<Node> grad_fn_;
std::weak_ptr<Node> grad_accumulator_;
mutable std::shared_ptr<ForwardGrad> fw_grad_;
std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_;
这些智能指针的应用各有不同,它们的主要作用是管理计算图中的节点、梯度、钩子函数等数据结构的生命周期。
2. std::shared_ptr
和 std::weak_ptr
的区别
首先,让我们从 std::shared_ptr<Node>
和 std::weak_ptr<Node>
开始讲解。
-
std::shared_ptr<Node> grad_fn_;
是一个共享指针,表示该对象(grad_fn_
)的所有者。一个std::shared_ptr
会通过引用计数来管理对象的生命周期。当一个shared_ptr
被复制时,引用计数会增加,而当指针超出作用域或被重置时,引用计数会减少,直到计数为 0 时对象会被销毁。示例代码:
std::shared_ptr<Node> grad_fn = std::make_shared<Node>(); // 在此,grad_fn 是 Node 类型对象的所有者
应用场景:
std::shared_ptr
适用于需要共享资源所有权的场景。比如,在AutogradMeta
中,grad_fn_
指向的是梯度计算的计算图节点,该节点可能会被多个Variable
共享,因此使用std::shared_ptr
可以确保计算图在不再使用时被自动销毁。
-
std::weak_ptr<Node> grad_accumulator_;
是一个弱指针,通常与std::shared_ptr
配合使用。它不会影响对象的引用计数,因此不会阻止对象的销毁。std::weak_ptr
适用于观察共享资源但不拥有其所有权的场景。示例代码:
std::shared_ptr<Node> shared_ptr_node = std::make_shared<Node>(); std::weak_ptr<Node> weak_ptr_node = shared_ptr_node;
应用场景:
std::weak_ptr
常用于防止循环引用。在AutogradMeta
中,grad_accumulator_
可能指向一个梯度累加器对象,但我们并不想让它拥有该对象的所有权,因此使用std::weak_ptr
。这样,当没有任何shared_ptr
指向该对象时,累加器会被销毁,避免内存泄漏。
3. mutable std::shared_ptr<ForwardGrad> fw_grad_;
mutable
关键字在这里的作用是允许即使在 const
对象上也能修改 fw_grad_
成员变量。在 AutogradMeta
中,fw_grad_
用于存储与正向自动求导相关的梯度。由于该对象的生命周期是动态管理的,所以它使用了 std::shared_ptr
。
示例代码:
mutable std::shared_ptr<ForwardGrad> fw_grad_;
应用场景:
- 在
AutogradMeta
类中,fw_grad_
可能在对象生命周期内多次更新,因此需要一个std::shared_ptr
来管理其内存。同时,mutable
允许即使AutogradMeta
对象是const
类型时,也可以修改fw_grad_
,这对线程安全和优化非常重要。
4. std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_;
std::unique_ptr
是独占指针,表示某个资源只能由一个指针管理。当 std::unique_ptr
被销毁时,它所管理的资源会被释放。
示例代码:
std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;
应用场景:
- 在
AutogradMeta
中,post_acc_grad_hooks_
用于存储 Python 特定的钩子函数。这些钩子函数会在梯度累加后执行,因此使用std::unique_ptr
确保钩子对象的独占管理,避免多个指针同时拥有该对象的所有权。
5. std::shared_ptr
、std::weak_ptr
、std::unique_ptr
对比
智能指针类型 | 主要特点 | 使用场景 |
---|---|---|
std::shared_ptr | 共享所有权,通过引用计数管理资源的生命周期;多个指针可以共享资源 | 用于共享资源所有权,确保资源在最后一个指针被销毁时被释放。 |
std::weak_ptr | 不增加资源的引用计数,不控制资源的生命周期;可以观察资源 | 用于避免循环引用和观察对象的生命周期。 |
std::unique_ptr | 独占所有权,确保资源只被一个指针管理 | 用于资源的独占管理,确保资源在超出作用域时被释放。 |
6. 总结与应用场景
在 C++ 中,智能指针是非常强大的工具,可以有效避免内存泄漏、野指针和循环引用等问题。std::shared_ptr
、std::weak_ptr
和 std::unique_ptr
各有各的特点,能够应对不同的资源管理需求。结合 AutogradMeta
这样的复杂数据结构,智能指针帮助我们确保计算图、梯度和钩子等资源的安全管理。
std::shared_ptr
适用于需要共享资源所有权的场景,如计算图的节点。std::weak_ptr
适用于观察资源但不控制其生命周期的场景,如梯度累加器。std::unique_ptr
适用于独占资源所有权的场景,如梯度累加后的钩子函数。
通过合理选择智能指针类型,能够显著提升代码的安全性和可维护性,减少内存管理上的错误。
以上就是对 C++ 中智能指针的详细介绍及其在 AutogradMeta
类中的应用。希望通过这个例子,读者能够更加清晰地理解智能指针的区别及其适用场景。
附录:override关键字
override
关键字详解
在 C++ 中,override
是一个用于显式声明虚函数重写的关键字。它告诉编译器,当前成员函数是用来重写基类中的虚函数的。如果基类中没有对应的虚函数,编译器将生成一个错误,从而帮助开发者捕获潜在的错误。
语法和作用
在类的成员函数后面加上 override
,表示该函数是重写了基类中的一个虚函数。如果基类中没有定义该函数,或者该函数的签名不匹配,编译器将报错。
语法示例:
class Base {
public:virtual void foo() {// 基类的实现}
};class Derived : public Base {
public:void foo() override { // 重写基类的 foo 函数// 派生类的实现}
};
在上面的代码中,Derived
类中的 foo
函数用 override
关键字显式地声明为重写基类 Base
中的 foo
函数。如果基类的 foo
函数没有定义为虚函数,或者派生类中的 foo
函数签名与基类的不一致,编译器会给出错误。
为什么要使用 override
?
-
避免拼写错误和签名错误: 使用
override
可以帮助程序员确保派生类中的函数签名完全匹配基类中的虚函数签名。如果有拼写错误或签名不匹配,编译器会在编译时提醒我们,避免在运行时遇到潜在的问题。 -
提高代码可读性:
override
显示了一个函数是基类虚函数的重写,有助于代码阅读者理解该函数是被派生类特意重写的,而不是无意间添加的。 -
增强可维护性: 如果将来基类的虚函数发生了修改,
override
可以帮助发现派生类中需要更新的地方,从而避免一些潜在的bug。
override
在 set_fw_grad
中的应用
在您提供的代码片段中:
void set_fw_grad(const at::TensorBase& new_grad,const at::TensorBase& self,uint64_t level,bool is_inplace_op) override;
override
表示 set_fw_grad
函数重写了基类中的一个虚函数。这个函数的作用可能是设置正向梯度(forward gradient)。如果基类中没有定义 set_fw_grad
或其签名不同,编译器会报错,提醒开发者检查是否正确实现了虚函数。
示例:虚函数重写的完整示例
#include <iostream>class Base {
public:// 声明一个虚函数virtual void set_fw_grad(const std::string& new_grad) {std::cout << "Base class set_fw_grad: " << new_grad << std::endl;}
};class Derived : public Base {
public:// 重写基类的虚函数,并加上 override 关键字void set_fw_grad(const std::string& new_grad) override {std::cout << "Derived class set_fw_grad: " << new_grad << std::endl;}
};int main() {Base* obj = new Derived();obj->set_fw_grad("Gradient Data"); // 调用的是 Derived 类的重写函数delete obj;return 0;
}
输出:
Derived class set_fw_grad: Gradient Data
总结
override
是 C++11 引入的一个关键字,用于显式标识派生类中的成员函数重写了基类中的虚函数。- 它增强了代码的安全性,帮助避免常见的编程错误,如函数签名不匹配等问题。
- 在
set_fw_grad
中,override
确保该函数是正确地重写了基类的虚函数。如果基类中没有定义相应的虚函数,编译器会发出错误提示。
后记
2025年1月3日15点33分于上海,在GPT4o大模型辅助下完成。