PINN:用深度学习PyTorch求解微分方程

神经网络技术已在计算机视觉与自然语言处理等多个领域实现了突破性进展。然而在微分方程求解领域,传统神经网络因其依赖大规模标记数据集的特性而表现出明显局限性。物理信息神经网络(Physics-Informed Neural Networks, PINN)通过将物理定律直接整合到学习过程中,有效弥补了这一不足,使其成为求解常微分方程(ODE)和偏微分方程(PDE)的高效工具。

传统神经网络模型需要依赖规模庞大的标记数据集,而这类数据的采集往往成本高昂且耗时显著。PINN通过将物理定律(具体表现为微分方程)融入训练过程,显著提高了数据利用效率。这种方法使得在流体动力学、量子力学和气候系统建模等科学领域实现基于数据的科学发现成为可能,为跨学科研究提供了新的技术路径。

求解微分方程一般方法

有如下微分方程:

图片

边界条件

图片

由于

图片

对 x 积分一次可得

图片

再次积分,我们得到

图片

现在,应用边界条件:

  1. 对于 y(0)=1

图片

  1. 对于 y(2)=5:

图片

因此,解析解为:

图片

用神经网络解决微分方程

该方法称为 PINN(物理信息神经网络),在我们的示例中的工作方式如下:

神经网络近似:

  • 我们定义一个神经网络 y(θ,x),其中 θ 表示网络参数(权重和偏差)。该网络旨在近似微分方程的解 y(x)。

  • 在我们的例子中,神经网络是一个小型全连接网络(具有一个或多个隐藏层),它以空间坐标 x 作为输入并输出 y(x) 的近似值。

自动微分:

  • 在这种情况下使用神经网络的一个主要好处是大多数现代深度学习库(如 PyTorch)都支持自动区分。

  • 这意味着我们可以直接从网络输出计算关于输入 x 的导数 y′(x) 和 y′′(x)。

残差计算:

  • 对于 ODE

图片

我们将残差 r(x) 定义为:

图片

在网络近似精确的点处,残差应该为零。

损失函数:

  • PINN 方法中的损失函数由两部分组成:

  • 残差损失:在域内的一组内部搭配点处计算残差 r(x) 的均方误差 (MSE)。该项强制网络的预测满足微分方程。

  • 边界条件损失:网络预测与给定边界条件之间的差异的 MSE。

    图片

  • 总损失为:

图片

PINN的技术特性与创新点

PINN与传统神经网络的根本区别在于,它不依赖于标记数据集进行学习,而是将微分方程约束直接嵌入到损失函数中。这意味着模型学习得到的函数*yNN(x)*需同时满足:

  • 给定的微分方程约束条件

  • 特定的边界条件和初始条件

PINN框架中的偏微分方程(PDE)通常表示为:

图片

其中

图片

以二阶微分方程为例:

图片

这表明所求函数y(x)必须严格满足该方程。

基于PINN求解微分方程的实践案例

步骤1: 导入必要的库函数

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

步骤2: 定义 y(x) 的神经网络近似值

class ODE_Net(nn.Module):def __init__(self, hidden_units=20):super(ODE_Net, self).__init__()self.layer1 = nn.Linear(1, hidden_units)self.layer2 = nn.Linear(hidden_units, hidden_units)self.layer3 = nn.Linear(hidden_units, 1)self.activation = nn.Tanh()def forward(self, x):out = self.activation(self.layer1(x))out = self.activation(self.layer2(out))out = self.layer3(out)return out

步骤3:计算 ODE 残差

def residual(model, x):"""Compute the ODE residual:y''(x) - 3 = 0."""# Enable gradients for xx.requires_grad_(True)y = model(x)# Compute first derivative: y'(x)dydx = torch.autograd.grad(y, x,grad_outputs=torch.ones_like(y),create_graph=True)[0]# Compute second derivative: y''(x)d2ydx2 = torch.autograd.grad(dydx, x,grad_outputs=torch.ones_like(dydx),create_graph=True)[0]# Compute the residual of the ODE: y''(x) - 3res = d2ydx2 - 3.0return res

步骤4:损失函数

def boundary_loss(model):"""Compute the loss associated with the boundary conditions:y(0)=1 and y(2)=5."""# Boundary condition at x=0: y(0)=1x0 = torch.tensor([[0.0]], device=device, requires_grad=True)y0 = model(x0)loss_bc0 = (y0 - 1.0)**2# Boundary condition at x=2: y(2)=5x2 = torch.tensor([[2.0]], device=device, requires_grad=True)y2 = model(x2)loss_bc2 = (y2 - 5.0)**2return loss_bc0 + loss_bc2

步骤5:模型训练

  # Initialize the model and optimizermodel = ODE_Net(hidden_units=20).to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)num_epochs = 5000# Generate interior points in the domain [0,2]N_interior = 50x_interior = 2 * torch.rand(N_interior, 1, device=device)  # uniformly distributed in [0,2]# Training loop
for epoch in range(num_epochs):model.train()optimizer.zero_grad()# Compute the residual loss at interior pointsr_interior = residual(model, x_interior)loss_res = torch.mean(r_interior**2)# Compute the boundary condition lossloss_bc = boundary_loss(model)# Total loss is the sum of the residual and boundary lossesloss = loss_res + loss_bcloss.backward()optimizer.step()if epoch % 500 == 0:print(f"Epoch {epoch}, Loss: {loss.item():.6e}")# Evaluate and compare the solutionmodel.eval()x_test = torch.linspace(0, 2, 100, device=device).unsqueeze(1)y_pred = model(x_test).detach().cpu().numpy().flatten()x_test_np = x_test.cpu().numpy().flatten()
Epoch 0, Loss: 3.222174e+01
Epoch 500, Loss: 1.378794e-01
Epoch 1000, Loss: 5.264541e-03
Epoch 1500, Loss: 3.903809e-03
Epoch 2000, Loss: 3.040434e-03
Epoch 2500, Loss: 2.319159e-03
Epoch 3000, Loss: 1.656389e-03
Epoch 3500, Loss: 9.695904e-04
Epoch 4000, Loss: 4.545122e-04
Epoch 4500, Loss: 2.485181e-04

步骤6:对比精确度

  # Analytical solution: y(x) = (3/2)x^2 - x + 1y_true = 1.5 * x_test_np**2 - x_test_np + 1plt.figure(figsize=(8, 4))plt.plot(x_test_np, y_pred, label="PINN Solution")plt.plot(x_test_np, y_true, '--', label="Analytical Solution")plt.xlabel("x")plt.ylabel("y(x)")plt.legend()plt.title("ODE: y''(x) - 3 = 0 with y(0)=1, y(2)=5")plt.show()

图片

使用 PINN 求解更复杂的 ODE

图片

class ODE_Net(nn.Module):def __init__(self, hidden_units=20):super(ODE_Net, self).__init__()self.layer1 = nn.Linear(1, hidden_units)self.layer2 = nn.Linear(hidden_units, hidden_units)self.layer3 = nn.Linear(hidden_units, hidden_units)self.layer4 = nn.Linear(hidden_units, 1)self.activation = nn.Tanh()def forward(self, x):out = self.activation(self.layer1(x))out = self.activation(self.layer2(out))out = self.activation(self.layer3(out))out = self.layer4(out)return outdef residual(model, x):x.requires_grad_(True)y = model(x)y_x = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y),create_graph=True)[0]y_xx = torch.autograd.grad(y_x, x, grad_outputs=torch.ones_like(y_x),create_graph=True)[0]y_xxx = torch.autograd.grad(y_xx, x, grad_outputs=torch.ones_like(y_xx),create_graph=True)[0]y_xxxx = torch.autograd.grad(y_xxx, x, grad_outputs=torch.ones_like(y_xxx),create_graph=True)[0]    res = y_xxxx - 2*y_xxx + y_xxreturn resdef boundary_loss(model):x0 = torch.tensor([[0.0]], device=device, requires_grad=True)y0 = model(x0)y0_x = torch.autograd.grad(y0, x0, grad_outputs=torch.ones_like(y0),create_graph=True)[0]y0_xx = torch.autograd.grad(y0_x, x0, grad_outputs=torch.ones_like(y0_x),create_graph=True)[0]y0_xxx = torch.autograd.grad(y0_xx, x0, grad_outputs=torch.ones_like(y0_xx),create_graph=True)[0]bc1 = y0 - 1.0      # y(0) = 1bc2 = y0_x - 0.0    # y'(0) = 0bc3 = y0_xx - (-1.0)  # y''(0) = -1  -> y0_xx + 1 = 0bc4 = y0_xxx - 2.0# y'''(0) = 2loss_bc = bc1**2 + bc2**2 + bc3**2 + bc4**2return loss_bcdef main():model = ODE_Net(hidden_units=20).to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)num_epochs = 10000N_interior = 50x_interior = torch.rand(N_interior, 1, device=device)for epoch in range(num_epochs):model.train()optimizer.zero_grad()r_interior = residual(model, x_interior)loss_res = torch.mean(r_interior**2)loss_bc = boundary_loss(model)        loss = loss_res + loss_bcloss.backward()optimizer.step()if epoch % 500 == 0:print(f"Epoch {epoch}, Loss: {loss.item():.6e}")model.eval()x_test = torch.linspace(0, 1, 100, device=device).unsqueeze(1)y_pred = model(x_test).detach().cpu().numpy().flatten()x_test_np = x_test.cpu().numpy().flatten()# Analytical solution: y(x) = 8 + 4x - 7e^x + 3xe^xy_true = 8 + 4*x_test_np - 7*np.exp(x_test_np) + 3*x_test_np*np.exp(x_test_np)plt.figure(figsize=(8,4))plt.plot(x_test_np, y_pred, label="Solution using PINN")plt.plot(x_test_np, y_true, '--', label="Analytical solution")plt.xlabel("x")plt.ylabel("y(x)")plt.legend()plt.show()if __name__ == "__main__":main()
Epoch 0, Loss: 6.779857e+00
Epoch 500, Loss: 2.092192e-01
Epoch 1000, Loss: 4.828146e-02
Epoch 1500, Loss: 3.233620e-02
Epoch 2000, Loss: 3.518355e-04
Epoch 2500, Loss: 2.392017e-04
Epoch 3000, Loss: 1.745588e-04
Epoch 3500, Loss: 1.332138e-04
Epoch 4000, Loss: 1.039377e-04
Epoch 4500, Loss: 3.754102e-03
Epoch 5000, Loss: 7.414911e-05
Epoch 5500, Loss: 5.272599e-05
Epoch 6000, Loss: 4.189969e-05
Epoch 6500, Loss: 1.759992e-03
Epoch 7000, Loss: 1.593289e-04
Epoch 7500, Loss: 2.400937e-05
Epoch 8000, Loss: 8.885263e-03
Epoch 8500, Loss: 6.434955e-05
Epoch 9000, Loss: 1.761451e-05
Epoch 9500, Loss: 1.477061e-05

图片

通过结果可以看出,我们已经成功地使用PINN方法求解了上述微分方程,并获得了与解析解高度一致的数值解。

写在最后

物理信息神经网络(PINN)代表了一种在微分方程求解领域的重要技术突破,它将深度学习与物理定律有机结合,为传统数值求解方法提供了一种高效、数据驱动的替代方案。PINN方法不仅在理论上具有创新性,同时在实际应用中展现出广阔的应用前景,为复杂物理系统的建模与分析提供了新的研究路径。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/76876.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

程序化广告行业(89/89):广告创意审核的关键要点与实践应用

程序化广告行业(89/89):广告创意审核的关键要点与实践应用 在程序化广告这个充满机遇与挑战的领域,持续学习和知识共享是我们不断进步的动力。一直以来,我都希望能和大家一同深入探索这个行业,今天让我们聚…

【ES6新特性】Proxy进阶实战

🌟ES6 Proxy终极指南:从拦截器到响应式框架实现🔥 一、💡 为什么Proxy是革命性的?先看痛点场景 1.1 Object.defineProperty的局限 😫 // Vue2响应式实现 let data { count: 0 }; Object.defineProperty(…

c++解决动态规划

一、引言: 在我们学习了算法之后,我们一定遇到过贪心算法。而在贪心算法中就有着这样一个经典的例子——凑钱。 Eg: 你有面额为10、5、1的纸币,当你买菜时需要花费26元,请问需要最少的纸币张数是多少。 当我们用贪心算法去解决这个问题的时候,我们…

Qwen 2.5 VL 多种推理方案

Qwen 2.5 VL 多种推理方案 flyfish 单图推理 from modelscope import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor from qwen_vl_utils import process_vision_info import torchmodel_path "/media/model/Qwen/Qwen25-VL-7B-Instruct/"m…

机器视觉检测Pin针歪斜应用

在现代电子制造业中,Pin针(插针)是连接器、芯片插座、PCB板等元器件的关键部件。如果Pin针歪斜,可能导致接触不良、短路,甚至整机失效。传统的人工检测不仅效率低,还容易疲劳漏检。 MasterAlign 机器视觉对…

经典算法问题解析:两数之和与三数之和的Java实现

文章目录 1. 问题背景2. 两数之和(Two Sum)2.1 问题描述2.2 哈希表解法代码实现关键点解析复杂度对比 3. 三数之和(3Sum)3.1 问题描述3.2 排序双指针解法代码实现关键点解析复杂度分析 4. 对比总结5. 常见问题解答6. 扩展练习 1. …

1022 Digital Library

1022 Digital Library 分数 30 全屏浏览 切换布局 作者 CHEN, Yue 单位 浙江大学 A Digital Library contains millions of books, stored according to their titles, authors, key words of their abstracts, publishers, and published years. Each book is assigned an u…

地理人工智能中位置编码的综述:方法与应用

以下是对论文 《A Review of Location Encoding for GeoAI: Methods and Applications》 的大纲和摘要整理: A Review of Location Encoding for GeoAI: Methods and Applications 摘要(Summary) 本文系统综述了地理人工智能(G…

(C语言)算法复习总结2——分治算法

1. 分治算法的定义 分治算法(Divide and Conquer)是一种重要的算法设计策略。 “分治” 从字面意义上理解,就是 “分而治之”。 它将一个复杂的问题分解成若干个规模较小、相互独立且与原问题形式相同的子问题,然后递归地解决这…

爱普生FC1610AN5G手机中替代传统晶振的理想之选

在 5G 技术引领的通信新时代,手机性能面临前所未有的挑战与机遇。从高速数据传输到多任务高效处理,从长时间续航到紧凑轻薄设计,每一项提升都离不开内部精密组件的协同优化。晶振,作为为手机各系统提供稳定时钟信号的关键元件&…

Android 接口定义语言 (AIDL)

目录 1. 本地进程调用(同一进程内)2. 远程进程调用(跨进程)3 `oneway` 关键字用于修改远程调用的行为Android 接口定义语言 (AIDL) 与其他 IDL 类似: 你可以利用它定义客户端与服务均认可的编程接口,以便二者使用进程间通信 (IPC) 进行相互通信。 在 Android 上,一个进…

关于QT5项目只生成一个CmakeLists.txt文件

编译器自动检测明明可以检测,Kit也没有报红 但是最后生成项目只有一个文件 一:检查cmake版本,我4.1版本cmake一直报错 cmake3.10可以用 解决之后还是有问题 把环境变量加上去:

uniapp小程序位置授权弹框与隐私协议耦合(合而为一)(只在真机上有用,模拟器会分开弹 )

注意: 只在真机上有用,模拟器会分开弹 效果图: 模拟器效果图(授权框跟隐私政策会分开弹,先弹隐私政策,同意再弹授权弹框): manifest-template.json配置( "__usePr…

[Godot] C#人物移动抖动解决方案

在写一个2D平台跳跃的游戏代码发现,移动的时候会抖动卡顿的厉害,后来研究了一下抖动问题,有了几种解决方案 1.垂直同步和物理插值问题 这是最常见的可能导致画面撕裂和抖动的原因,大家可以根据自己的需要调整项目设置&#xff0…

红帽Linux网页访问问题

配置网络,手动配置 搭建yum仓库红帽Linux网页访问问题 下载httpd 网页访问问题:首先看httpd的状态---selinux的工作模式(强制)---上下文类型(semanage-fcontext)---selinux端口有没有放行semanage port ---防火墙有没有active---…

Android12编译x86模拟器报找不到userdata-qemu.img

qemu-system-x86_64: Could not open out/target/product/generic_x86_64/userdata-qemu.img: No such file or directory 选择编译aosp_x86-eng时没有生成模拟器,报 qemu-system-x86_64: Could not open out/target/product/generic_x86_64/userdata-qemu.img: No…

【AI论文】PixelFlow:基于流的像素空间生成模型

摘要:我们提出PixelFlow,这是一系列直接在原始像素空间中运行的图像生成模型,与主流的潜在空间模型形成对比。这种方法通过消除对预训练变分自编码器(VAE)的需求,并使整个模型能够端到端训练,从…

AI大模型学习九:‌Sealos cloud+k8s云操作系统私有化一键安装脚本部署完美教程(单节点)

一、说明 ‌Sealos‌是一款基于Kubernetes(K8s)的云操作系统发行版,它将K8s以及常见的分布式应用如Docker、Dashboard、Ingress等进行了集成和封装,使得用户可以在不深入了解复杂的K8s底层原理的情况下,快速搭建起一个…

【HDFS入门】HDFS核心组件DataNode详解:角色职责、存储机制与健康管理

目录 1 DataNode的角色定位 2 DataNode的核心职责 2.1 数据块管理 2.2 与NameNode的协作 3 DataNode的存储机制 3.1 数据存储目录结构 3.2 数据块文件组织 4 DataNode的工作流程 4.1 数据写入流程 4.2 数据读取流程 5 DataNode的健康管理 5.1 心跳机制(…

BufferedOutputStream 终极解析与记忆指南

BufferedOutputStream 终极解析与记忆指南 一、核心本质 BufferedOutputStream 是 Java 提供的缓冲字节输出流,继承自 FilterOutputStream,通过内存缓冲区显著提升 I/O 性能。 核心特性速查表 特性说明继承链OutputStream → FilterOutputStream → …