用神经网络求解微分方程

微分方程是物理科学的主角之一,在工程、生物、经济甚至社会科学中都有广泛的应用。粗略地说,它们告诉我们一个量如何随时间变化(或其他参数,但通常我们对时间变化感兴趣)。我们可以了解人口、股票价格,甚至某个社会对某些主题的看法如何随时间变化。

通常,用于解决微分方程的方法不是分析性的(即没有解决方案的“封闭公式”),我们必须利用数值方法。然而,从计算的角度来看,数值方法可能很昂贵,更糟糕的是:累积误差可能非常大。

本文将展示神经网络如何成为解决微分方程的宝贵盟友,以及我们如何借用物理信息神经网络的概念来解决这个问题:我们可以使用机器学习方法来解决微分方程吗?

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割

1、物理学信息神经网络

在本节中,我将简要介绍物理学信息神经网络(PINN)。我想你知道“神经网络”部分,但是是什么让它们受到物理学的影响?好吧,它们并不是完全由物理学决定的,而是由(微分)方程决定的。

通常,神经网络经过训练可以找到模式并弄清楚一组训练数据发生了什么。但是,当你训练神经网络遵循训练数据的行为并希望拟合看不见的数据时,你的模型高度依赖于数据本身,而不是系统的底层性质。这听起来几乎像一个哲学问题,但它比这更实际:如果你的数据来自对洋流的测量,这些洋流必须遵循描述洋流的物理方程。但是请注意,你的神经网络对这些方程完全不可知,并且只试图拟合数据点。

这就是物理学信息发挥作用的地方。如果你的模型除了学习如何拟合数据之外,还学习如何拟合控制该系统的方程,那么你的神经网络的预测将更加精确,并且泛化能力会更好,这只是物理信息模型的一些优点。

请注意,系统的控制方程根本不需要涉及物理,“物理信息”只是一种命名法(而且这种技术无论如何都是物理学家最常用的)。如果你的系统是城市中的交通,并且你恰好有一个很好的数学模型,你希望神经网络的预测遵循该模型,那么物理信息神经网络非常适合你。

3、如何告知模型物理信息?

希望我已经说服了你,让模型了解控制我们系统的基础方程是值得的。但是,我们该怎么做呢?有几种方法可以做到这一点,但主要方法是调整损失函数,使其除了通常的数据相关部分之外,还有一个考虑控制方程的项。也就是说,损失函数 L 将由总和组成

这里,数据损失是通常的损失:均方差,或其他适合的损失函数形式;但方程部分是迷人的。想象一下你的系统由以下微分方程控制:

我们如何将其拟合到损失函数中?好吧,由于我们在训练神经网络时的任务是最小化损失函数,我们想要的是最小化以下表达式:

所以我们的方程相关损失函数结果是

也就是说,它是我们的 DE 的均方差。如果我们设法最小化这个值(即使这个项尽可能接近零),我们就会自动满足系统的控制方程。很聪明,对吧?

现在,需要解决损失函数中的额外项 L_IC:它考虑了系统的初始条件。如果没有提供系统的初始条件,则微分方程有无数个解。

例如,从地面扔出的球的轨迹由与从 10 楼扔出的球相同的微分方程控制;但是,我们确信这些球的路径不会相同。这里发生的变化是系统的初始条件。我们的模型如何知道我们正在讨论哪些初始条件?此时,我们自然会使用损失函数项来强制执行它!

对于我们的 DE,让我们规定当 t = 0 时,y = 1。因此,我们希望最小化初始条件损失函数,该函数的内容为:

如果我们最小化这个项,那么我们就会自动满足系统的初始条件。现在,剩下需要理解的是如何使用它来解决微分方程。

4、求解微分方程

如果神经网络既可以用损失函数的数据相关项进行训练(这通常是在经典架构中完成的),也可以用数据和方程相关项进行训练(这就是​​我刚才提到的物理信息神经网络),那么它一定可以训练为仅最小化方程相关项。这正是我们要做的!这里使用的唯一损失函数将是 L_equation。希望下面的图表能够说明我刚才所说的内容:今天我们的目标是右下角的模型类型,即我们的 DE 求解器 NN。

图 1:显示了各种神经网络及其损失函数的图表。在本文中,我们针对右下方的神经网络。

5、代码实现

为了展示我们刚刚学到的理论知识,我将使用机器学习的 PyTorch 库,在 Python 代码中实现所提出的解决方案。

首先要做的是创建一个神经网络架构:

import torch
import torch.nn as nnclass NeuralNet(nn.Module):def __init__(self, hidden_size, output_size=1,input_size=1):super(NeuralNet, self).__init__()self.l1 = nn.Linear(input_size, hidden_size)self.relu1 = nn.LeakyReLU()self.l2 = nn.Linear(hidden_size, hidden_size)self.relu2 = nn.LeakyReLU()self.l3 = nn.Linear(hidden_size, hidden_size)self.relu3 = nn.LeakyReLU()self.l4 = nn.Linear(hidden_size, output_size)def forward(self, x):out = self.l1(x)out = self.relu1(out)out = self.l2(out)out = self.relu2(out)out = self.l3(out)out = self.relu3(out)out = self.l4(out)return out

这只是具有 LeakyReLU 激活函数的简单 MLP。然后,我将定义损失函数,以便在训练循环中稍后计算它们:

# Create the criterion that will be used for the DE part of the loss
criterion = nn.MSELoss()# Define the loss function for the initial condition
def initial_condition_loss(y, target_value):return nn.MSELoss()(y, target_value)

现在,我们将创建一个用作训练数据的时间数组,并实例化模型,并选择一种优化算法:

# Time vector that will be used as input of our NN
t_numpy = np.arange(0, 5+0.01, 0.01, dtype=np.float32)
t = torch.from_numpy(t_numpy).reshape(len(t_numpy), 1)
t.requires_grad_(True)# Constant for the model
k = 1# Instantiate one model with 50 neurons on the hidden layers
model = NeuralNet(hidden_size=50)# Loss and optimizer
learning_rate = 8e-3
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)# Number of epochs
num_epochs = int(1e4)

最后,让我们开始训练循环:

for epoch in range(num_epochs):# Randomly perturbing the training points to have a wider range of timesepsilon = torch.normal(0,0.1, size=(len(t),1)).float()t_train = t + epsilon# Forward passy_pred = model(t_train)# Calculate the derivative of the forward pass w.r.t. the input (t)dy_dt = torch.autograd.grad(y_pred, t_train, grad_outputs=torch.ones_like(y_pred), create_graph=True)[0]# Define the differential equation and calculate the lossloss_DE = criterion(dy_dt + k*y_pred, torch.zeros_like(dy_dt))# Define the initial condition lossloss_IC = initial_condition_loss(model(torch.tensor([[0.0]])), torch.tensor([[1.0]]))loss = loss_DE + loss_IC# Backward pass and weight updateoptimizer.zero_grad()loss.backward()optimizer.step()

请注意使用 torch.autograd.grad 函数自动对输出 y_pred 相对于输入 t 进行微分,以计算损失函数。

6、结果

经过训练,我们可以看到损失函数迅速收敛。图 2 显示了损失函数与 epoch 数的关系图,其中的插图显示了损失函数下降最快的区域。

图 2:按时期划分的损失函数。在插图中,我们可以看到收敛速度最快的区域。

你可能已经注意到,这个神经网络并不常见。它没有训练数据(我们的训练数据是手工制作的时间戳向量,这只是我们想要研究的时间域),因此它从系统获得的所有信息都以损失函数的形式出现。它的唯一目的是在它被设计用于解决的时间域内求解微分方程。因此,为了测试它,我们使用它训练的时间域是公平的。图 3 显示了 NN 预测与理论答案(即解析解)之间的比较。

图 3:所示神经网络预测和微分方程的解析解预测。

我们可以看到两者之间有相当好的一致性,这对神经网络来说非常好。

这种方法的一个缺点是它不能很好地概括未来的时间。图 4 显示了如果我们将时间数据点向前移动五步会发生什么,结果简直是一片混乱。

图 4:神经网络和未见数据点的解析解。

因此,这里的教训是,这种方法是作为时间域内微分方程的数值求解器,不应将其用作常规神经网络,使用未见的训练域外数据进行预测并期望它能很好地推广。

7、结束语

毕竟,还有一个问题是:

为什么要费心训练一个不能很好地推广到未见数据的神经网络,而且它显然比解析解更差,因为它有内在的统计误差?

首先,这里提供的示例是一个微分方程的示例,其解析解是已知的。对于未知的解,仍然必须使用数值方法。话虽如此,用于微分方程求解的数值方法通常会累积误差。这意味着如果你试图在许多时间步骤中求解方程,解将在此过程中失去其准确性。另一方面,神经网络求解器学习如何在其每个训练时期为所有数据点求解 DE。

另一个原因是神经网络是良好的插值器,因此如果你想知道看不见的数据中的函数值(但这种“看不见的数据”必须位于你训练的时间间隔内),神经网络将迅速为你提供经典数值方法无法迅速给出的值。


原文链接:用神经网络求解微分方程 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/48630.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

黑龙江等级保护测评深度解析

一、黑龙江等级保护测评概述 黑龙江等级保护测评(以下简称“等保测评”)是一项针对信息系统安全等级保护的综合性评估活动,旨在确保信息系统符合国家网络安全等级保护制度的要求,保障信息系统的安全稳定运行。 二、等保测评的重…

艺术成分很高的完全自定义的UITabBar(很简单)

引言 在iOS应用开发中,UITabBar是一个非常场景且重要的UI组件。系统为我们提供的UITabBar虽然功能强大,但是在某些情况下,它的标准样式并不能满足我们特定的设计需求,它的灵活性也有一些局限。为了打造更具个性化好的用户友好的交…

显卡驱动程序下载失败的原因及对策

在数字时代,显卡作为电脑的心脏部件之一,其驱动程序的正常运行是保证图形处理性能的关键。然而,不少用户在尝试下载显卡驱动程序时遭遇失败,这不仅影响了日常使用体验,还可能埋下系统不稳定的风险。本文将深入探讨显卡…

黑龙江网络安全等级保护测评策略概述

一、简介 黑龙江省网络安全等级保护测评策略是为了保障信息系统安全稳定运行,根据《网络安全法》和相关国家标准制定的综合性安全评估和加固过程。该策略不仅要求企业和机构明确自身信息系统的安全等级,还指导其实施相应的技术防护与管理措施&#xff0…

算法学习4——动态规划

动态规划(Dynamic Programming,简称DP)是一种用于解决具有重叠子问题和最优子结构性质的问题的算法设计技术。它通过将复杂问题分解为更小的子问题,并保存子问题的解来避免重复计算,从而提高算法的效率。 基本思想 动…

Mamba中的Mamba:在标记化Mamba模型中的集中式Mamba跨扫描高光谱图像分类

摘要 https://arxiv.org/pdf/2405.12003 高光谱图像(HSI)分类在遥感(RS)领域至关重要,尤其是随着深度学习技术的不断进步。顺序模型,如循环神经网络(RNNs)和Transformer&#xff0…

接近50个实用编程相关学习资源网站

Date: 2024.07.17 09:45:10 author: lijianzhan 编程语言以及编程相关工具等实用性官方文档网站 C语言文档:https://learn.microsoft.com/zh-cn/cpp/c-languageMicrosoft C、C和汇编程序文档:https://learn.microsoft.com/zh-cn/cppJAVA官方文档&#…

java题目之数字加密以及如何解密

public class Main6 {public static void main(String[] args) {// 某系统的数字密码&#xff08;大于0&#xff09;&#xff0c;比如1983&#xff0c;采用加密方式进行传输//定义了一个静态数组int []arr{1,9,8,3};//1.加密//先给每位数加上5for (int i 0; i <arr.length …

随机变量的数学期望

目录 简介 基本概念 数学期望的定义 数学期望的性质 数学期望的应用 计算实例 数学期望在解决哪些具体问题时最为有效&#xff1f; 如何计算两个或多个随机变量的组合概率及其期望值&#xff1f; 1. 计算组合概率 2. 计算期望值 当涉及到两个或多个随机变量的组合时&…

Hadoop基础组件介绍!

Hadoop是一个由Apache基金会所开发的分布式系统基础架构&#xff0c;Hadoop生态系统已经远远超出了这些基本组件&#xff0c;现在包括了多种组件和技术&#xff0c;详情介绍如下&#xff1a; HDFS&#xff08;Hadoop Distributed File System&#xff09; HDFS是Hadoop的核心组…

git实操之线上分支合并

线上分支合并 【 1 】本地dev分支合并到本地master上 # 本地dev分支合并到本地master上# 远程(线上)分支合并# 本地dev分支合并到本地master上# 远程(线上)分支合并#####本地和线上分支同步################ #### 远程创建分支&#xff0c;拉取到本地####-远程创建分支&#…

自定义Bean转换工具类

BeanConvertor工具类&#xff1a;简化Java对象转换的利器 在Java开发中,我们经常需要在不同的对象之间转换数据。这可能是因为我们需要将数据从一个层(如数据访问层)转移到另一个层(如服务层或表示层),或者是因为我们需要将外部API的数据结构转换为我们的内部数据结构。这种转…

企业级-PDF图片水印

作者&#xff1a;fyupeng 技术专栏&#xff1a;☞ https://github.com/fyupeng 项目地址&#xff1a;☞ https://github.com/fyupeng/distributed-blog-system-api 留给读者 遇到签名&#xff0c;往往很无奈签名的位置、大小。 一、介绍 直接提供PDF路径和图片路径&#xff0…

RK RGA _MMU unsupported memory larger then 4G!问题解决

使用RGA程序,长时间运行的过程中出现了rga_mm: RGA_MMU unsupported memory larger than 4G! rga_mm: RGA_MMU unsupported memory larger than 4G! rga_mm: scheduler core[4] unsupported mm_flag[0x8]! rga_mm: rga_mm_map_buffer map virtual address error! rga_mm: job…

华为云.云日志服务LTS及其基本使用

云计算 云日志服务LTS及其基本使用 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/qq_28550…

2024最新版虚拟便携空调小程序源码 支持流量主切换空调型号

产品截图 部分源代码展示 urls.js Object.defineProperty(exports, "__esModule", {value: !0 }), exports.default ["9c5f1fa582bee88300ffb7e28dce8b68_3188_128_128.png", "E-116154b04e91de689fb1c4ae99266dff_960.svg", "573eee719…

酱酒七个轮次口感与特点,哪个轮次最好喝?

都知道酱香型白酒是按照“12987”工艺酿造而成,这12987便是以一整年为一个生产周期,中间经历润沙下沙,并在多次的蒸煮发酵后,完成七次取酒。 所以酱香型白酒是由7个轮次的基酒勾调而成的,这七轮次酒口感特点各不相同,品质也是有着极大的差异。而这各个轮次基酒的勾调配比又直接…

【踩坑日记26】Connection timed out fatal: expected flush after ref listing ```

问题描述 (base) XXXomega:/home/XXX/code$ git clone https://github.com/comeeasy/DALS.git Cloning into DALS... error: RPC failed; curl 28 Failed to connect to github.com port 443: Connection timed out fatal: expected flush after ref listing解决方法 直接换一…

mysql的索引、事务和存储引擎

目录 索引 索引的概念 索引的作用 作用 索引的副作用 创建索引 创建索引的原则和依据 索引的类型 创建索引 查看索引 删除索引 drop 主键索引 普通索引 添加普通索引 唯一索引 添加唯一索引 组合索引 添加组合索引 查询组合索引 全文索引 添加全文索引 …

构建高效Node.js中间层:探索请求合并转发的艺术

&#x1f389; 博客主页&#xff1a;【剑九 六千里-CSDN博客】 &#x1f3a8; 上一篇文章&#xff1a;【CSS盒模型&#xff1a;掌握网页布局的核心】 &#x1f3a0; 系列专栏&#xff1a;【面试题-八股系列】 &#x1f496; 感谢大家点赞&#x1f44d;收藏⭐评论✍ 引言&#x…