提示
从 PyTorch 2.4 开始,本教程已被废弃。请参考 PyTorch 自定义操作符,了解关于通过自定义 C++/CUDA 扩展扩展 PyTorch 的最新指南。
PyTorch 提供了大量与神经网络、任意张量代数、数据处理等相关的操作。然而,您可能仍然会发现自己需要一个更自定义的操作。例如,您可能想要使用论文中找到的一个新颖激活函数,或者实现您在研究中开发的一个操作。
将此类自定义操作集成到 PyTorch 中的最简单方法是通过扩展 Function
和 Module
来在 Python 中编写它们,正如这里所述。这样,您可以充分利用自动微分的功能(无需编写导数函数),并且保持 Python 的常规表达力。然而,也有可能您的操作更适合用 C++ 实现。例如,您的代码可能需要非常高效,因为它在模型中被频繁调用,或者即使是少数调用也非常昂贵。另一个可能的原因是,它依赖或与其他 C 或 C++ 库交互。为了解决这些情况,PyTorch 提供了一种非常简单的方式来编写自定义 C++ 扩展。
C++ 扩展是一种我们为用户(您)开发的机制,它允许您创建在源代码之外定义的 PyTorch 操作,即与 PyTorch