自动微分技术在 AI for science 中的应用

本文简记我在学习自动微分相关技术时遇到的知识点。

反向传播和自动微分

以 NN 为代表的深度学习技术展现出了强大的参数拟合能力,人们通过堆叠固定的 layer 就能轻松设计出满足要求的参数拟合器。

例如,大部分图神经网络均基于消息传递的架构。在推理阶段,用户只需给出分子坐标及原子类型,就能得到整个分子的性质。因此其整体架构与下图类似:

img

在模型设计阶段,我们用 pytorch 即可满足大部分需求,以 schnetpack 为例:

  1. 我们 from torch import nn 导入了设计 nn 常用的模块。在初始化模型时,我们直接继承了 pytorch 内置的模块 class AtomisticModel(nn.Module)
  2. 有一些函数是重新编写的,例如激活函数 shiftedsoftplus

我们可以看到,模型的整体框架依然是基于 pytorch 的,但针对具体的应用场景,我们做了很多优化。

一方面,使用 pytorch 可以帮助我们快速建立类似上图的模型网络,pytorch 会自动执行梯度的反向传播。从 loss function 开始,逐层递进直至输入层。pytorch 还会帮助我们完成整个网络的参数迭代,学习率的迭代等等。。。

另一方面,针对一些特殊的需求,用户需要自行 DIY,完成需要的功能。

这其中隐含着,用户在程序设计时灵活性与便利性之间的折中。

注意到,刚才提到了梯度的反向传播,事实上,这种常用算法只是自动微分算法中的一种。引用 Gemini 的一个例子:

  • 反向传播好像是计算小山丘斜率(仅限于 NN)的一种算法;
  • 自动微分则可以计算除了小山丘以外的所有物品的斜率(涵盖所有链式求导法则);

写到这里,自动微分技术的应用场景就很好理解了:

  • 有一些应用场景不适合无脑堆叠 NN,但仍然需要优化参数,此时 from torch import nn 就不管用了,套用固定模版已经很难带来便利性;
  • 由于整个网络的框架已经不再是上图所示,规整的一层层的 NN 结构,反向传播算法就不再适用于参数优化了,需要更加灵活的自动微分方法;

pytorch 与 jax

我们可以将参数优化的相关框架归结为两个应用场景:

  1. 用户调用标准函数,搭建层级式标准 NN;
  2. 用户自行设计函数,搭建非标准拟合器(仍需优化参数)

针对第一个场景,我们可以使用 pytorch,因为 pytorch 对常用网络架构封装很好。

针对第二个场景,使用 pytorch 会更加繁琐,此时可以切换为 jax ,因为 jax 对用户自定义函数形式更加友好,其内置自动微分算法使用起来更加方便。

除了应用场景的区别外,二者还有以下几个区别:

  1. pytorch 支持静态/动态计算图,而 jax 仅支持静态图
  2. pytorch debug 起来更加方便
  3. jax 针对 GPU, TPU 等硬件优化更多,结合其 JIT(Just In Time) 特性,jax 模型一般比 pytorch 模型快得多
  4. 二者间的相互转换难度不大(参见:一文打通PyTorch与JAX)

AI for Science 领域内三个应用案例

DMFF

余旷老师在他的系列博文里系统阐释了为什么 DMFF 要基于 jax 开发(参见:漫谈分子力场、自动微分与DMFF项目:4. DMFF和JAX概述)

总结一下,使用 jax 的原因有以下几点:

  1. 传统分子力场的形式不适合用 NN 建模
    • 为方便大家理解,我举一个中学物理的例子。苹果从树上落下,遵从自由落体运动,位移随时间变化的规律:h=1/2 * g * t^2, 其中 g 作为引力常数就是需要通过多次落体实验测定的量。我们当然可以用多层 NN 拟合这一参数,但假如我们已经知道了这样一个表达式,此时直接使用该表达式即可。
    • 传统分子力场就是高度参数化的方程,发展至今已经有了一套函数形式,无需从头用 NN 的形式拟合
  2. 反向传播算法只适用与 NN,不适应上述高度参数化的方程,但优化力场参数仍需要自动微分技术
    • 计算原子受力,整个盒子的维里均需要微分技术,使用 jax 编程会更加方便
  3. jax 性能更高,速度快
  4. jax 可拓展性好
    • 余旷老师在 漫谈分子力场、自动微分与DMFF项目:5. DMFF中势函数的生成和拓展 举了一个例子,使用 DMFF 能有效复用前人开发势函数模块,无需从头造轮子

E3x

在 Oliver T. Unke 近期的一篇论文中,作者介绍了名为 E3x 的神经网络框架,对标 pytorch_geometric。

其目的在于,方便用户设计具有 E3 等变性的图神经网络。

使用 E3x 能将所有 AI for Science 领域的 GNN 从 pytorch 迁移至 jax 框架,再结合 jax-MD,获得大幅性能提升。

作者在另一篇论文中透露了这种改造的效果:

请添加图片描述

在稳定性和受力误差不变的情况下,NequIP 提速 28 倍,SchNet 提速 15 倍。那么,E3x 做了哪些关键改动呢?

  1. e3x 对不可约张量进行了压缩,降低了其稀疏性

    请添加图片描述

  2. e3x 设计了开箱即用的激活函数,全连接层、张量层等,这些网络结构都是 E3 等变的

DLDFPT

神经网络与密度泛函围绕理论的结合,论文地址

这是李贺大神今年上半年的一篇 PRL,说实话,我也没看懂。我只是理解到:

  • 传统的 DFPT 理论在计算某一个矩阵的时候遇到了计算瓶颈;
  • 使用自动微分技术能绕开这一瓶颈

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/21012.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

带交互的卡尔曼滤滤波|一维滤波|源代码

背景 一维卡尔曼滤波的MATLAB例程,​背景为温度估计。 代码介绍 运行程序后,可以自己输入温度真实值: 以20℃为例,得到如下的估计值​: 滤波前的值和滤波后的值分别于期望值(真实值)作差…

基于Jenkins+Kubernetes+GitLab+Harbor构建CICD平台

1. 实验环境 1.1 k8s环境 1)Kubernetes 集群版本是 1.20.6 2)k8s控制节点: IP:192.168.140.130 主机名:k8s-master 配置:4C6G 3)k8s工作节点 节点1: IP:192.1…

【机器学习】基于OpenCV和TensorFlow的MobileNetV2模型的物种识别与个体相似度分析

在计算机视觉领域,物种识别和图像相似度比较是两个重要的研究方向。本文通过结合深度学习和图像处理技术,基于OpenCV和TensorFlow的MobileNetV2的预训练模型模,实现物种识别和个体相似度分析。本文详细介绍该实验过程并提供相关代码。 一、名…

JVM运行时数据区 - 程序计数器

运行时数据区 Java虚拟机在执行Java程序的过程中,会把它管理的内存划分成若干个不同的区域,这些区域有各自的用途、创建及销毁时间,有些区域随着虚拟机的启动一直存在,有些区域则随着用户线程的启动和结束而建立和销毁&#xff0…

前端组件业务数据选择功能优雅写法

1. 业务场景 后台管理在实际业务中,经常可见的功能为:在当前的页面中从其他列表中选择数据。 例如,在一个商品活动列表页面中 需要选择配置的商品。 2. 遇到问题 从代码划分的角度来说,每个业务列表代码首先分散开来&#xff0…

LeetCode刷题之HOT100之在排序数组中查找元素的第一个和最后一个位置

下午雨变小了,但我并未去实验室,难得的一天呆在宿舍。有些无聊,看看这个,弄弄那个,听听歌,消磨时间。不知觉中时间指针蹦到了九点,做题啦!朋友推荐了 Eason 的 2010-DUO 演唱会&…

2024年06月数据库流行度最新排名

点击查看最新数据库流行度最新排名(每月更新) 2024年06月数据库流行度最新排名 TOP DB顶级数据库索引是通过分析在谷歌上搜索数据库名称的频率来创建的 一个数据库被搜索的次数越多,这个数据库就被认为越受欢迎。这是一个领先指标。原始数…

低代码是什么?开发系统更有什么优势?

低代码(Low-Code)是一种应用开发方法,它采用图形化界面和预构建的模块,使得开发者能够通过少量的手动编程来快速创建应用程序。这种方法显著减少了传统软件开发中的手动编码量,提高了开发效率,降低了技术门…

thingsboard物联网平台快速入门教程

第一步,搭建服务器 使用我已经建好的服务器,thingsboard测试账号,租户管理员账号,物联网测试平台-CSDN博客 第二步,创建一个设备,获取设备Token 用租户管理员账户登录,左侧找到实体->设备&#xff0c…

Oracle导出clob字段到csv

使用UTL_FILE ref: How to Export The Table with a CLOB Column Into a CSV File using UTL_FILE ?(Doc ID 1967617.1) --preapre data CREATE TABLE TESTCLOB(ID NUMBER, MYCLOB1 CLOB, MYCLOB2 CLOB ); INSERT INTO TESTCLOB(ID,MYCLOB1,MYCLOB2) VALUES(1,Sample row 11…

Fiddler抓包工具的使用

目录 1、抓包原理:👇 2、抓包结果👇 1)如何查看一个http请求的原始摸样: 2)分析数据格式: 3、请求格式分析👇 4、响应格式分析👇 官网下载:安装过程比较…

【评价类模型】Topsis

综合赋权法:Topsis法: 主要适用情况:题目提供了足够的评价指标和数据,数据已知,评价指标的类型差异较大 基本思想:将所有方案与理想解和夫理想解进行比较,通过激素那方案与这两个解的举例去欸的…

深度学习复盘与论文复现B

文章目录 1、Knowledge Review1.1 NLLLoss vs CrossEntropyLoss1.2 MNIST dataset1.2.1 Repare Dataset1.2.2 Design Model1.2.3 Construct Loss and Optimizer1.2.4 Train and Test1.2.5 Training results Pytorch-Lightning MNIST :rocket::fire:1.3 Basic Convolutional Neu…

笔记:美团的测试

0.先启动appium 1.编写代码 如下: from appium import webdriver from appium.webdriver.extensions.android.nativekey import AndroidKeydesired_caps {platformName: Android,platformVersion: 10,deviceName: :VOG_AL10,appPackage: com.sankuai.meituan,ap…

Android关闭硬件加速对PorterDuffXfermode的影响

Android关闭硬件加速对PorterDuffXfermode的影响 跑的版本minSdk33 编译SDK34 import android.content.Context import android.graphics.Bitmap import android.graphics.Canvas import android.graphics.Color import android.graphics.Paint import android.graphics.Port…

OpenMV学习笔记3——画图函数汇总

画图,即在摄像头对应位置画出图形,对于需要反馈信息的程序来说很直观。就如上一篇文章颜色识别当中的例子一样,我们在识别出的色块上画出矩形方框,并在中间标出十字,可以直观的看到OpenMV现在识别出的色块。 目录 一…

执法装备管理系统DW-S304的概念与特点

执法装备管理系统(DW-S304)适用于多种警务和安保场景,如警察局、特警队、边防检查站、监狱管理系统、生态环境局、执法大队等。它可以帮助这些机构提高对装备的控制能力,确保装备在需要时能够迅速到位,同时也减少了因装…

API开发秘籍:揭秘Swagger与Spring REST Docs的文档自动化神技

在这个数字化时代,如何让你的业务像外卖一样快速送达顾客手中?本文将带你走进Spring Boot的世界,学习如何利用RESTful API构建一个高效、直观的“外卖帝国”。从基础的REST架构风格,到Spring MVC的魔力,再到Swagger和S…

追寻美的指引--纪念西蒙斯

周六早上醒来,James Simons(西蒙斯)辞世的消息刷屏了。多数人知道他,是因为他的财富和量化对冲基金公司-文艺复兴。但他更值得为人纪念的身份,则是数学家和慈善家。 西蒙斯1938年生于麻省,毕业于MIT&#…

探索 Python 的 vars() 函数

大家好,在软件开发的过程中,调试是一个不可或缺的环节。无论你是在解决 bug,优化代码,还是探索代码的执行流程,都需要一些有效的工具来帮助你更好地理解和调试代码。在 Python 编程中,vars() 函数是一个非常…