用神经网络求解微分方程

微分方程是物理科学的主角之一,在工程、生物、经济甚至社会科学中都有广泛的应用。粗略地说,它们告诉我们一个量如何随时间变化(或其他参数,但通常我们对时间变化感兴趣)。我们可以了解人口、股票价格,甚至某个社会对某些主题的看法如何随时间变化。

通常,用于解决微分方程的方法不是分析性的(即没有解决方案的“封闭公式”),我们必须利用数值方法。然而,从计算的角度来看,数值方法可能很昂贵,更糟糕的是:累积误差可能非常大。

本文将展示神经网络如何成为解决微分方程的宝贵盟友,以及我们如何借用物理信息神经网络的概念来解决这个问题:我们可以使用机器学习方法来解决微分方程吗?

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割

1、物理学信息神经网络

在本节中,我将简要介绍物理学信息神经网络(PINN)。我想你知道“神经网络”部分,但是是什么让它们受到物理学的影响?好吧,它们并不是完全由物理学决定的,而是由(微分)方程决定的。

通常,神经网络经过训练可以找到模式并弄清楚一组训练数据发生了什么。但是,当你训练神经网络遵循训练数据的行为并希望拟合看不见的数据时,你的模型高度依赖于数据本身,而不是系统的底层性质。这听起来几乎像一个哲学问题,但它比这更实际:如果你的数据来自对洋流的测量,这些洋流必须遵循描述洋流的物理方程。但是请注意,你的神经网络对这些方程完全不可知,并且只试图拟合数据点。

这就是物理学信息发挥作用的地方。如果你的模型除了学习如何拟合数据之外,还学习如何拟合控制该系统的方程,那么你的神经网络的预测将更加精确,并且泛化能力会更好,这只是物理信息模型的一些优点。

请注意,系统的控制方程根本不需要涉及物理,“物理信息”只是一种命名法(而且这种技术无论如何都是物理学家最常用的)。如果你的系统是城市中的交通,并且你恰好有一个很好的数学模型,你希望神经网络的预测遵循该模型,那么物理信息神经网络非常适合你。

3、如何告知模型物理信息?

希望我已经说服了你,让模型了解控制我们系统的基础方程是值得的。但是,我们该怎么做呢?有几种方法可以做到这一点,但主要方法是调整损失函数,使其除了通常的数据相关部分之外,还有一个考虑控制方程的项。也就是说,损失函数 L 将由总和组成

这里,数据损失是通常的损失:均方差,或其他适合的损失函数形式;但方程部分是迷人的。想象一下你的系统由以下微分方程控制:

我们如何将其拟合到损失函数中?好吧,由于我们在训练神经网络时的任务是最小化损失函数,我们想要的是最小化以下表达式:

所以我们的方程相关损失函数结果是

也就是说,它是我们的 DE 的均方差。如果我们设法最小化这个值(即使这个项尽可能接近零),我们就会自动满足系统的控制方程。很聪明,对吧?

现在,需要解决损失函数中的额外项 L_IC:它考虑了系统的初始条件。如果没有提供系统的初始条件,则微分方程有无数个解。

例如,从地面扔出的球的轨迹由与从 10 楼扔出的球相同的微分方程控制;但是,我们确信这些球的路径不会相同。这里发生的变化是系统的初始条件。我们的模型如何知道我们正在讨论哪些初始条件?此时,我们自然会使用损失函数项来强制执行它!

对于我们的 DE,让我们规定当 t = 0 时,y = 1。因此,我们希望最小化初始条件损失函数,该函数的内容为:

如果我们最小化这个项,那么我们就会自动满足系统的初始条件。现在,剩下需要理解的是如何使用它来解决微分方程。

4、求解微分方程

如果神经网络既可以用损失函数的数据相关项进行训练(这通常是在经典架构中完成的),也可以用数据和方程相关项进行训练(这就是​​我刚才提到的物理信息神经网络),那么它一定可以训练为仅最小化方程相关项。这正是我们要做的!这里使用的唯一损失函数将是 L_equation。希望下面的图表能够说明我刚才所说的内容:今天我们的目标是右下角的模型类型,即我们的 DE 求解器 NN。

图 1:显示了各种神经网络及其损失函数的图表。在本文中,我们针对右下方的神经网络。

5、代码实现

为了展示我们刚刚学到的理论知识,我将使用机器学习的 PyTorch 库,在 Python 代码中实现所提出的解决方案。

首先要做的是创建一个神经网络架构:

import torch
import torch.nn as nnclass NeuralNet(nn.Module):def __init__(self, hidden_size, output_size=1,input_size=1):super(NeuralNet, self).__init__()self.l1 = nn.Linear(input_size, hidden_size)self.relu1 = nn.LeakyReLU()self.l2 = nn.Linear(hidden_size, hidden_size)self.relu2 = nn.LeakyReLU()self.l3 = nn.Linear(hidden_size, hidden_size)self.relu3 = nn.LeakyReLU()self.l4 = nn.Linear(hidden_size, output_size)def forward(self, x):out = self.l1(x)out = self.relu1(out)out = self.l2(out)out = self.relu2(out)out = self.l3(out)out = self.relu3(out)out = self.l4(out)return out

这只是具有 LeakyReLU 激活函数的简单 MLP。然后,我将定义损失函数,以便在训练循环中稍后计算它们:

# Create the criterion that will be used for the DE part of the loss
criterion = nn.MSELoss()# Define the loss function for the initial condition
def initial_condition_loss(y, target_value):return nn.MSELoss()(y, target_value)

现在,我们将创建一个用作训练数据的时间数组,并实例化模型,并选择一种优化算法:

# Time vector that will be used as input of our NN
t_numpy = np.arange(0, 5+0.01, 0.01, dtype=np.float32)
t = torch.from_numpy(t_numpy).reshape(len(t_numpy), 1)
t.requires_grad_(True)# Constant for the model
k = 1# Instantiate one model with 50 neurons on the hidden layers
model = NeuralNet(hidden_size=50)# Loss and optimizer
learning_rate = 8e-3
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)# Number of epochs
num_epochs = int(1e4)

最后,让我们开始训练循环:

for epoch in range(num_epochs):# Randomly perturbing the training points to have a wider range of timesepsilon = torch.normal(0,0.1, size=(len(t),1)).float()t_train = t + epsilon# Forward passy_pred = model(t_train)# Calculate the derivative of the forward pass w.r.t. the input (t)dy_dt = torch.autograd.grad(y_pred, t_train, grad_outputs=torch.ones_like(y_pred), create_graph=True)[0]# Define the differential equation and calculate the lossloss_DE = criterion(dy_dt + k*y_pred, torch.zeros_like(dy_dt))# Define the initial condition lossloss_IC = initial_condition_loss(model(torch.tensor([[0.0]])), torch.tensor([[1.0]]))loss = loss_DE + loss_IC# Backward pass and weight updateoptimizer.zero_grad()loss.backward()optimizer.step()

请注意使用 torch.autograd.grad 函数自动对输出 y_pred 相对于输入 t 进行微分,以计算损失函数。

6、结果

经过训练,我们可以看到损失函数迅速收敛。图 2 显示了损失函数与 epoch 数的关系图,其中的插图显示了损失函数下降最快的区域。

图 2:按时期划分的损失函数。在插图中,我们可以看到收敛速度最快的区域。

你可能已经注意到,这个神经网络并不常见。它没有训练数据(我们的训练数据是手工制作的时间戳向量,这只是我们想要研究的时间域),因此它从系统获得的所有信息都以损失函数的形式出现。它的唯一目的是在它被设计用于解决的时间域内求解微分方程。因此,为了测试它,我们使用它训练的时间域是公平的。图 3 显示了 NN 预测与理论答案(即解析解)之间的比较。

图 3:所示神经网络预测和微分方程的解析解预测。

我们可以看到两者之间有相当好的一致性,这对神经网络来说非常好。

这种方法的一个缺点是它不能很好地概括未来的时间。图 4 显示了如果我们将时间数据点向前移动五步会发生什么,结果简直是一片混乱。

图 4:神经网络和未见数据点的解析解。

因此,这里的教训是,这种方法是作为时间域内微分方程的数值求解器,不应将其用作常规神经网络,使用未见的训练域外数据进行预测并期望它能很好地推广。

7、结束语

毕竟,还有一个问题是:

为什么要费心训练一个不能很好地推广到未见数据的神经网络,而且它显然比解析解更差,因为它有内在的统计误差?

首先,这里提供的示例是一个微分方程的示例,其解析解是已知的。对于未知的解,仍然必须使用数值方法。话虽如此,用于微分方程求解的数值方法通常会累积误差。这意味着如果你试图在许多时间步骤中求解方程,解将在此过程中失去其准确性。另一方面,神经网络求解器学习如何在其每个训练时期为所有数据点求解 DE。

另一个原因是神经网络是良好的插值器,因此如果你想知道看不见的数据中的函数值(但这种“看不见的数据”必须位于你训练的时间间隔内),神经网络将迅速为你提供经典数值方法无法迅速给出的值。


原文链接:用神经网络求解微分方程 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/48630.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

艺术成分很高的完全自定义的UITabBar(很简单)

引言 在iOS应用开发中,UITabBar是一个非常场景且重要的UI组件。系统为我们提供的UITabBar虽然功能强大,但是在某些情况下,它的标准样式并不能满足我们特定的设计需求,它的灵活性也有一些局限。为了打造更具个性化好的用户友好的交…

显卡驱动程序下载失败的原因及对策

在数字时代,显卡作为电脑的心脏部件之一,其驱动程序的正常运行是保证图形处理性能的关键。然而,不少用户在尝试下载显卡驱动程序时遭遇失败,这不仅影响了日常使用体验,还可能埋下系统不稳定的风险。本文将深入探讨显卡…

Mamba中的Mamba:在标记化Mamba模型中的集中式Mamba跨扫描高光谱图像分类

摘要 https://arxiv.org/pdf/2405.12003 高光谱图像(HSI)分类在遥感(RS)领域至关重要,尤其是随着深度学习技术的不断进步。顺序模型,如循环神经网络(RNNs)和Transformer&#xff0…

java题目之数字加密以及如何解密

public class Main6 {public static void main(String[] args) {// 某系统的数字密码&#xff08;大于0&#xff09;&#xff0c;比如1983&#xff0c;采用加密方式进行传输//定义了一个静态数组int []arr{1,9,8,3};//1.加密//先给每位数加上5for (int i 0; i <arr.length …

随机变量的数学期望

目录 简介 基本概念 数学期望的定义 数学期望的性质 数学期望的应用 计算实例 数学期望在解决哪些具体问题时最为有效&#xff1f; 如何计算两个或多个随机变量的组合概率及其期望值&#xff1f; 1. 计算组合概率 2. 计算期望值 当涉及到两个或多个随机变量的组合时&…

git实操之线上分支合并

线上分支合并 【 1 】本地dev分支合并到本地master上 # 本地dev分支合并到本地master上# 远程(线上)分支合并# 本地dev分支合并到本地master上# 远程(线上)分支合并#####本地和线上分支同步################ #### 远程创建分支&#xff0c;拉取到本地####-远程创建分支&#…

自定义Bean转换工具类

BeanConvertor工具类&#xff1a;简化Java对象转换的利器 在Java开发中,我们经常需要在不同的对象之间转换数据。这可能是因为我们需要将数据从一个层(如数据访问层)转移到另一个层(如服务层或表示层),或者是因为我们需要将外部API的数据结构转换为我们的内部数据结构。这种转…

华为云.云日志服务LTS及其基本使用

云计算 云日志服务LTS及其基本使用 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/qq_28550…

2024最新版虚拟便携空调小程序源码 支持流量主切换空调型号

产品截图 部分源代码展示 urls.js Object.defineProperty(exports, "__esModule", {value: !0 }), exports.default ["9c5f1fa582bee88300ffb7e28dce8b68_3188_128_128.png", "E-116154b04e91de689fb1c4ae99266dff_960.svg", "573eee719…

mysql的索引、事务和存储引擎

目录 索引 索引的概念 索引的作用 作用 索引的副作用 创建索引 创建索引的原则和依据 索引的类型 创建索引 查看索引 删除索引 drop 主键索引 普通索引 添加普通索引 唯一索引 添加唯一索引 组合索引 添加组合索引 查询组合索引 全文索引 添加全文索引 …

构建高效Node.js中间层:探索请求合并转发的艺术

&#x1f389; 博客主页&#xff1a;【剑九 六千里-CSDN博客】 &#x1f3a8; 上一篇文章&#xff1a;【CSS盒模型&#xff1a;掌握网页布局的核心】 &#x1f3a0; 系列专栏&#xff1a;【面试题-八股系列】 &#x1f496; 感谢大家点赞&#x1f44d;收藏⭐评论✍ 引言&#x…

接口测试JMeter-1.接口测试初识

第一章 接口测试初识 1. 接口测试理论基础 “接口测试”一个让人觉得非常高大上的名词&#xff0c;特别是对于刚入门的测试同学而言。随着测试技术不断的深化&#xff0c;“接口测试”出现在我们视野中的频次越来越高。那么接口测试到底是如何做的&#xff1f;接口测试的优势又…

Flowable-SpringBoot项目集成

在前面的介绍中&#xff0c;虽然实现了绘制流程图&#xff0c;然后将流程图存储到数据库中&#xff0c;然后从数据库中获取流程信息&#xff0c;并部署和启动流程&#xff0c;但是部署的流程绘制器是在tomcat中部署的&#xff0c;可能在部分的项目中&#xff0c;需要我们将流程…

<数据集>pcb板缺陷检测数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;693张 标注数量(xml文件个数)&#xff1a;693 标注数量(txt文件个数)&#xff1a;693 标注类别数&#xff1a;6 标注类别名称&#xff1a;[missing_hole, mouse_bite, open_circuit, short, spurious_copper, spur…

git 提交的进阶操作

cherry-pick cherry-pick 是 Git 中的一种操作,允许你从一个分支中选择特定的 commit,并将其应用到另一个分支。它的主要用途是将特定的更改引入到其他分支,而无需合并整个分支历史。这在修复 bug 或者移植某些功能时特别有用。 cherry-pick 的使用场景 Bug 修复: 例如,你…

Python面试宝典第16题:跳跃游戏

题目 给你一个非负整数数组 nums &#xff0c;你最初位于数组的第一个下标 &#xff0c;数组中的每个元素代表你在该位置可以跳跃的最大长度。判断你是否能够到达最后一个下标&#xff0c;如果可以&#xff0c;返回 true。否则&#xff0c;返回 false。 示例 1&#xff1a; 输…

detection_segmentation

目标检测和实例分割(OBJECT_DETECTION AND INSTANCE SEGMENTATION) 文章目录 目标检测和实例分割(OBJECT_DETECTION AND INSTANCE SEGMENTATION)一. 计算机视觉(AI VISION)1. 图像分类2. 目标检测与定位3. 语义分割和实例分割目标检测算法可以分为两大类&#xff1a; R-CNN生成…

Linux系统:揭开它神秘面纱的科普之旅

在这个数字化时代&#xff0c;电脑和手机成了我们生活中不可或缺的一部分。而提到这些设备的操作系统&#xff0c;大家可能首先想到的是Windows、macOS或是Android。 但你知道吗&#xff0c;在技术的海洋里&#xff0c;还有一个强大而灵活的操作系统家族&#xff0c;它就是Lin…

python-多任务编程

2. 多任务编程 2.1 多任务概述 多任务 即操作系统中可以同时运行多个任务。比如我们可以同时挂着qq&#xff0c;听音乐&#xff0c;同时上网浏览网页。这是我们看得到的任务&#xff0c;在系统中还有很多系统任务在执行,现在的操作系统基本都是多任务操作系统&#xff0c;具备…

JVM--HostSpot算法细节实现

1.根节点枚举 定义&#xff1a; 我们以可达性分析算法中从GC Roots 集合找引用链这个操作作为介绍虚拟机高效实现的第一个例 子。固定可作为GC Roots 的节点主要在全局性的引用&#xff08;例如常量或类静态属性&#xff09;与执行上下文&#xff08;例如 栈帧中的本地变量表&a…